YSDA_ML2_HW_NLP / app.py
vdnegodin's picture
Update app.py
3579d9f
import streamlit as st
st.markdown("# Классификатор статей")
st.markdown("""
<img width=700px src='https://evergreens.com.ua/assets/images/articles/Kudelya-images/ml/article1-img5%20(1).jpg'>
<em>Источник: [evergreens.com.ua](https://evergreens.com.ua/ru/articles/classical-machine-learning.html)</em>
""", unsafe_allow_html=True)
st.markdown("#### Привет!")
st.markdown("""Это классификатор статей. Вы можете подать на вход название статьи, её аннотацию (abstract) или и то и другое и
получить наиболее вероятные тематики данной статьи. По введённым данным нейросеть предскажет топ-95% тематик.
Разбиение на тематики происходит согласно [arXiv Category Taxonomy](https://arxiv.org/category_taxonomy) в несколько более упрощённом формате (без подтематик и без некоторых не самых частых тематик).
Ввод данных осуществляется в поля ниже.
- Для того, чтобы сохранить введённую информацию в поле, нажмите Ctrl+Enter.
- После того, как Вы заполните хотя бы одно поле, модель сделает предсказание. После заполнения второго поля, предсказание обновится.
""")
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# В целом, и без кеширования работало быстро, но с кешированием действительно быстрее
@st.cache
def cached_model():
model = AutoModelForSequenceClassification.from_pretrained('pretrained_model')
return model
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")
model = cached_model()
name = st.text_area("Введите название статьи:")
abstract = st.text_area("Введите аннотацию статьи:")
text = name + '. ' + abstract
tokens = tokenizer.encode(text)
res = model(torch.as_tensor([tokens], device=device))[0]
def top95(res):
probs = torch.softmax(res, dim=-1).squeeze().cpu()
probs = [(prob.item(), i) for i, prob in enumerate(probs)]
probs = sorted(probs, reverse=True)
result_probs = []
result_indexes = []
sum_prob = 0
for prob, i in probs:
if sum_prob < 0.95:
result_probs.append(prob)
result_indexes.append(i)
sum_prob += prob
return result_probs, result_indexes
class_names = ['cs', 'stat', 'math', 'q-bio', 'physics']
full_names = ['Computer Science', 'Statistics', 'Mathematics', 'Quantitative Biology', 'Physics']
answer = [f'- {100*prob:.2f}% - {class_names[i]} ({full_names[i]}) \n\n' for prob, i in zip(*top95(res))]
st.markdown(f"""
Давайте сверимся :)
_Ваше название:_ {name if name else '{не задано}'}
_Ваша аннотация:_ {abstract if abstract else '{не задано}'}
""")
st.markdown('**Предсказания тематики Вашей статьи (топ-95%):**\n\n' + ''.join(answer))
# st.markdown(f'Тематика Вашей статьи: {class_names[index]}')
st.markdown("""
---
Автор: _Негодин Владислав_
tg: _@kol060k_
""")