# Streamlit app: predicts arXiv subject categories for a computer-science
# paper abstract using a fine-tuned BERT classifier.
import numpy as np
import streamlit as st
import torch
import transformers
from transformers import BertModel, BertTokenizer
# arXiv category codes the classifier can predict. The order matters:
# output logit i of the network corresponds to codes[i].
codes = (
    'cs.AI cs.CL cs.CV cs.NE stat.ML cs.LG stat.AP cs.RO math.OC cs.IR '
    'stat.ME cs.DC q-bio.NC cs.CR cs.HC cs.SD cs.CY cs.IT math.IT cs.SI '
    'cs.LO cs.DS'
).split()
# Human-readable display names for each arXiv category code.
human_readible = dict([
    ('cs.AI', 'Artificial Intelligence'),
    ('cs.CL', 'Computation and Language'),
    ('cs.CV', 'Computer Vision and Pattern Recognition'),
    ('cs.NE', 'Neural and Evolutionary Computing'),
    ('stat.ML', 'Statistic aspects of machine learning'),
    ('cs.LG', 'Machine Learning'),
    ('stat.AP', 'Statistical Applications'),
    ('cs.RO', 'Robotics'),
    ('math.OC', 'Optimization and Control'),
    ('cs.IR', 'Information Retrieval'),
    ('stat.ME', 'Statistics Methodology'),
    ('cs.DC', 'Distributed, Parallel, and Cluster Computing'),
    ('q-bio.NC', 'Neurons and Cognition'),
    ('cs.CR', 'Cryptography and Security'),
    ('cs.HC', 'Human-Computer Interaction'),
    ('cs.SD', 'Sound'),
    ('cs.CY', 'Computers and Society'),
    ('cs.IT', 'Information Theory'),
    ('math.IT', 'Theoretical Information Theory'),
    ('cs.SI', 'Social and Information Networks'),
    ('cs.LO', 'Logic in Computer Science'),
    ('cs.DS', 'Data Structures and Algorithms'),
])
# Run inference on GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Tokenizer matching the 'bert-base-uncased' backbone used by BERTClass below.
# NOTE(review): from_pretrained downloads the vocab on first run — requires network access.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
class BERTClass(torch.nn.Module):
    """BERT encoder with a dropout + linear head for 22-way topic classification."""

    def __init__(self):
        super().__init__()
        # Pre-trained backbone; return_dict=True lets forward() read .pooler_output.
        self.bert_model = BertModel.from_pretrained('bert-base-uncased', return_dict=True)
        self.dropout = torch.nn.Dropout(0.3)
        # 768 = BERT-base hidden size; 22 = number of categories in `codes`.
        self.linear = torch.nn.Linear(768, 22)

    def forward(self, input_ids, attn_mask, token_type_ids):
        """Return raw (unnormalised) class logits, one score per category."""
        encoded = self.bert_model(
            input_ids,
            attention_mask=attn_mask,
            token_type_ids=token_type_ids,
        )
        pooled = self.dropout(encoded.pooler_output)
        return self.linear(pooled)
def load_model(path='model_saved', map_location='cpu'):
    """Load the fine-tuned classifier saved with ``torch.save(model, path)``.

    Args:
        path: file containing the pickled model object (default matches the
            artifact shipped with this Space).
        map_location: device to map the stored tensors onto.

    Returns:
        The model switched to eval mode (disables dropout for inference).
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only ever load
    # trusted checkpoint files.
    model = torch.load(path, map_location=torch.device(map_location))
    model.eval()
    return model

model = load_model()
def process(example):
    """Classify an abstract and return human-readable topic names.

    Runs the text through the tokenizer and model, then returns the topics
    whose predicted probabilities cumulatively cover (just over) 95% of the
    probability mass, most likely topic first.

    Args:
        example: abstract text (str).

    Returns:
        list[str]: human-readable topic names from ``human_readible``.
    """
    encodings = tokenizer.encode_plus(
        example,
        None,
        add_special_tokens=True,
        max_length=256,          # same limit the model was fine-tuned with
        padding='max_length',
        return_token_type_ids=True,
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    model.eval()  # defensive: load_model() already leaves the model in eval mode
    with torch.no_grad():
        input_ids = encodings['input_ids'].to(device, dtype=torch.long)
        attention_mask = encodings['attention_mask'].to(device, dtype=torch.long)
        token_type_ids = encodings['token_type_ids'].to(device, dtype=torch.long)
        logits = model(input_ids, attention_mask, token_type_ids)
        # Single example in the batch -> take row 0 of the probabilities.
        probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()[0]

    # Class indices sorted by descending probability.
    order = np.argsort(probs)[::-1]
    cumulative = 0.0
    top95_topics = []
    for i in order:
        if cumulative > 0.95:
            break  # 95% of the probability mass is already covered
        cumulative += probs[i]
        top95_topics.append(human_readible[codes[i]])
    return top95_topics
# --- Streamlit UI (user-facing strings are in Russian) ---
# Prompt: "Paste a computer-science paper abstract to detect its likely topics."
text = st.text_area("Введите в текстовое поле аннотацию статьи по computer science"
                    " чтобы определить возможные ее темы.")
if st.button("Анализировать текст"):  # "Analyse text" button
    if len(text) == 0:
        # "You haven't written anything yet"
        st.markdown("Вы пока ничего не написали")
    else:
        topics = process(text)
        # "Most likely your article is one of the topics in this list:"
        st.markdown("Скорее всего ваша статья одна из следующих в списке:\n")
        for topic in topics:
            st.markdown(topic + '\n')
#video = open("video1.mp4", 'rb')
#video_data = video.read()
#st.video(video_data)