"""Streamlit app that suggests arXiv tags for a paper title and abstract.

A fine-tuned DistilBERT sequence classifier is loaded from local files
(``config.json`` / ``pytorch_model.bin``); class ids are mapped to
human-readable descriptions read from ``tags.txt`` (one description per line,
line number == class id).
"""
import streamlit as st
from transformers import AutoTokenizer, DistilBertForSequenceClassification, DistilBertConfig
import torch
from torch.nn.functional import softmax

base_model_name = 'distilbert-base-uncased'


@st.cache_data
def load_tags_info():
    """Read ``tags.txt`` and return a ``{line_index: description}`` dict.

    Only the line terminator is stripped, so the last line is kept intact
    even when the file does not end with a newline. Blank lines keep their
    index so that ids stay aligned with the file's line numbers.
    """
    with open('tags.txt', 'r', encoding='utf-8') as file:
        return {i: line.rstrip('\r\n') for i, line in enumerate(file)}


id_to_description = load_tags_info()


@st.cache_resource
def load_model():
    """Build the classifier from the local config and weights, in eval mode."""
    config = DistilBertConfig.from_json_file('./config.json')
    model = DistilBertForSequenceClassification(config)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load('./pytorch_model.bin', map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    # A model created through its constructor starts in training mode;
    # switch to eval so dropout is disabled and predictions are deterministic.
    model.eval()
    return model


@st.cache_resource
def load_tokenizer():
    """Return the tokenizer of the base model (cached across reruns)."""
    return AutoTokenizer.from_pretrained(base_model_name)


def top_xx(preds, xx=95):
    """Return tag descriptions covering at least ``xx`` percent of probability.

    Args:
        preds: tensor indexed as ``preds[0, class_id]`` holding the class
            probabilities of a single example.
        xx: cumulative probability threshold, in percent.

    Returns:
        Descriptions of the most probable classes, in decreasing order of
        probability, taken until their cumulative probability reaches
        ``xx / 100``. If the threshold is never reached (e.g. because of
        floating-point rounding), every class is returned.
    """
    ranked_ids = torch.argsort(preds, 1, descending=True)[0].tolist()
    threshold = xx / 100
    total = 0.0
    result = []
    # Iterating over the ranked ids bounds the loop by the number of classes.
    for class_id in ranked_ids:
        if total >= threshold:
            break
        total += preds[0, class_id].item()
        result.append(id_to_description[class_id])
    return result


model = load_model()
tokenizer = load_tokenizer()
# Softmax temperature applied to the logits; 1 leaves them unchanged.
temperature = 1

st.title('ArXivTager')
st.caption(
    'Напишите тему (Title) и параграф из статьи (Abstract). '
    'Поля должны быть ЗАПОЛНЕНЫ текстом на АНГЛИЙСКОМ языке для корректной классификации.'
)

with st.form("ArXivTager"):
    title = st.text_area(label='Title', height=100)
    abstract = st.text_area(label='Abstract (optional)', height=200)
    st.caption('ВЫВОД: набор тем в порядке уменьшения вероятностей.')
    submitted = st.form_submit_button("Get tags")

    if submitted:
        if not title.strip():
            # A title made only of whitespace carries no information either.
            st.markdown("Нужно хоть что-то написать")
        else:
            prompt = 'Title: ' + title + ' Abstract: ' + abstract
            # A single sequence needs no padding; the attention mask is
            # passed along so the model never attends to padding tokens.
            encoded = tokenizer(prompt, truncation=True, return_tensors='pt')
            with torch.no_grad():
                logits = model(
                    input_ids=encoded['input_ids'],
                    attention_mask=encoded['attention_mask'],
                ).logits
            preds = softmax(logits / temperature, dim=1)
            tags = top_xx(preds)
            st.header('Inferred tags:')
            for tag in tags:
                st.markdown('* ' + tag)