File size: 2,269 Bytes
2c26172 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import streamlit as st
from transformers import AutoTokenizer, DistilBertForSequenceClassification
import torch
from torch.nn.functional import softmax
base_model_name = 'distilbert-base-uncased'
@st.cache
def load_tags_info():
id_to_description = {}
with open('tags.txt', 'r') as file:
i = 0
for line in file:
description = line[:-1]
id_to_description[i] = description
i += 1
return id_to_description
id_to_description = load_tags_info()
@st.cache
def load_model():
return DistilBertForSequenceClassification.from_pretrained('./')
def load_tokenizer():
return AutoTokenizer.from_pretrained('distilbert-base-uncased')
def top_xx(preds, xx=95):
tops = torch.argsort(preds, 1, descending=True)
total = 0
index = 0
result = []
while total < xx / 100:
next_id = tops[0, index].item()
total += preds[0, next_id]
index += 1
result.append(id_to_description[next_id])
return result
model = load_model()
tokenizer = load_tokenizer()
temperature = 1
st.title('ArXivTager')
st.caption('Напишите тему (Title) и параграф из статьи (Abstract). Поля должны быть ЗАПОЛНЕНЫ для корректной классификации.')
with st.form("ArXivTager"):
title = st.text_area(label='Title', height=30)
abstract = st.text_area(label='Abstract (optional)', height=200)
st.caption('ВЫВОД: набор тем в порядке уменьшения вероятностей.')
submitted = st.form_submit_button("Get tags")
if submitted:
if title == '':
st.markdown("Нужно хоть что-то написать")
else:
prompt = 'Title: ' + title + ' Abstract: ' + abstract
tokens = tokenizer(prompt, truncation=True, padding='max_length', return_tensors='pt')['input_ids']
preds = softmax(model(tokens.reshape(1, -1)).logits / temperature, dim=1)
tags = top_xx(preds)
other_tags = []
st.header('Inferred tags:')
for i, tag_data in enumerate(tags):
st.markdown('* ' + tag_data) |