# YSDA_ML / app.py
# Kirill
# fix app.py
# 378fec2
import streamlit as st
from transformers import AutoTokenizer, DistilBertForSequenceClassification, DistilBertConfig
import torch
from torch.nn.functional import softmax
# Hugging Face checkpoint name of the base DistilBERT model.
base_model_name: str = 'distilbert-base-uncased'
@st.cache_data
def load_tags_info():
    """Read ``tags.txt`` and map each line's 0-based index to its text.

    The index presumably matches the classifier's output class id
    (``top_xx`` looks descriptions up by the argsorted class index).

    Returns:
        dict[int, str]: line index -> tag description without its trailing newline.
    """
    with open('tags.txt', 'r', encoding='utf-8') as file:
        # rstrip('\n') rather than line[:-1]: if the last line has no newline,
        # slicing would silently drop its final character.
        return {i: line.rstrip('\n') for i, line in enumerate(file)}


id_to_description = load_tags_info()
@st.cache_resource
def load_model():
    """Build the DistilBERT sequence classifier from local files, ready for inference.

    The architecture is read from ``./config.json`` and the weights from
    ``./pytorch_model.bin`` (mapped onto the CPU). ``st.cache_resource`` keeps
    the model between Streamlit reruns instead of rebuilding it each time.

    Returns:
        DistilBertForSequenceClassification in evaluation mode.
    """
    config = DistilBertConfig.from_json_file('./config.json')
    model = DistilBertForSequenceClassification(config)
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load('./pytorch_model.bin', map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    # A module constructed directly starts in training mode; switch to eval so
    # dropout is disabled and predictions are deterministic.
    model.eval()
    return model
@st.cache_resource
def load_tokenizer():
    """Load the tokenizer of the base checkpoint ``base_model_name``.

    Cached with ``st.cache_resource`` so it is not reloaded on every
    Streamlit rerun of the script.
    """
    return AutoTokenizer.from_pretrained(base_model_name)
def top_xx(preds, xx=95, labels=None):
    """Return the most probable tags until their cumulative probability reaches ``xx`` percent.

    Args:
        preds: tensor of per-class scores. Only row 0 is read, so presumably a
            (1, num_tags) batch of softmax probabilities -- confirm with caller.
        xx: cumulative-probability threshold, in percent (default 95).
        labels: optional mapping from class index to tag description; when
            omitted, the module-level ``id_to_description`` is used.

    Returns:
        list[str]: tag descriptions in decreasing order of probability. If the
        probabilities never reach the threshold (e.g. ``xx=100`` with float
        round-off), every class is returned instead of raising IndexError.
    """
    if labels is None:
        labels = id_to_description
    values, indices = torch.sort(preds[0], descending=True)
    threshold = xx / 100
    total = 0.0
    result = []
    for prob, class_id in zip(values.tolist(), indices.tolist()):
        if total >= threshold:
            break
        total += prob
        result.append(labels[class_id])
    return result
# --- Streamlit page: collect a title/abstract and show the predicted tags ---
model = load_model()
tokenizer = load_tokenizer()
# Softmax temperature; 1 leaves the logits unscaled.
temperature = 1

st.title('ArXivTager')
st.caption('Напишите тему (Title) и параграф из статьи (Abstract). Поля должны быть ЗАПОЛНЕНЫ текстом на АНГЛИЙСКОМ языке для корректной классификации.')

with st.form("ArXivTager"):
    title = st.text_area(label='Title', height=100)
    abstract = st.text_area(label='Abstract (optional)', height=200)
    st.caption('ВЫВОД: набор тем в порядке уменьшения вероятностей.')
    submitted = st.form_submit_button("Get tags")

if submitted:
    # A whitespace-only title carries no information; treat it as empty.
    if not title.strip():
        st.markdown("Нужно хоть что-то написать")
    else:
        prompt = 'Title: ' + title + ' Abstract: ' + abstract
        # Pass the full encoding (input_ids + attention_mask). Previously the
        # input was padded to max_length without a mask, so the model attended
        # to padding tokens; a single sequence needs no padding at all.
        inputs = tokenizer(prompt, truncation=True, return_tensors='pt')
        # Inference only: don't build an autograd graph.
        with torch.no_grad():
            logits = model(**inputs).logits
        preds = softmax(logits / temperature, dim=1)
        tags = top_xx(preds)
        st.header('Inferred tags:')
        for tag in tags:
            st.markdown('* ' + tag)