Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| st.markdown("# Классификатор статей") | |
| st.markdown(""" | |
| <img width=700px src='https://evergreens.com.ua/assets/images/articles/Kudelya-images/ml/article1-img5%20(1).jpg'> | |
| <em>Источник: [evergreens.com.ua](https://evergreens.com.ua/ru/articles/classical-machine-learning.html)</em> | |
| """, unsafe_allow_html=True) | |
| st.markdown("#### Привет!") | |
| st.markdown("""Это классификатор статей. Вы можете подать на вход название статьи, её аннотацию (abstract) или и то и другое и | |
| получить наиболее вероятные тематики данной статьи. По введённым данным нейросеть предскажет топ-95% тематик. | |
| Разбиение на тематики происходит согласно [arXiv Category Taxonomy](https://arxiv.org/category_taxonomy) в несколько более упрощённом формате (без подтематик и без некоторых не самых частых тематик). | |
| Ввод данных осуществляется в поля ниже. | |
| - Для того, чтобы сохранить введённую информацию в поле, нажмите Ctrl+Enter. | |
| - После того, как Вы заполните хотя бы одно поле, модель сделает предсказание. После заполнения второго поля, предсказание обновится. | |
| """) | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| # В целом, и без кеширования работало быстро, но с кешированием действительно быстрее | |
| def cached_model(): | |
| model = AutoModelForSequenceClassification.from_pretrained('pretrained_model') | |
| return model | |
| tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased") | |
| model = cached_model() | |
| name = st.text_area("Введите название статьи:") | |
| abstract = st.text_area("Введите аннотацию статьи:") | |
| text = name + '. ' + abstract | |
| tokens = tokenizer.encode(text) | |
| res = model(torch.as_tensor([tokens], device=device))[0] | |
| def top95(res): | |
| probs = torch.softmax(res, dim=-1).squeeze().cpu() | |
| probs = [(prob.item(), i) for i, prob in enumerate(probs)] | |
| probs = sorted(probs, reverse=True) | |
| result_probs = [] | |
| result_indexes = [] | |
| sum_prob = 0 | |
| for prob, i in probs: | |
| if sum_prob < 0.95: | |
| result_probs.append(prob) | |
| result_indexes.append(i) | |
| sum_prob += prob | |
| return result_probs, result_indexes | |
| class_names = ['cs', 'stat', 'math', 'q-bio', 'physics'] | |
| full_names = ['Computer Science', 'Statistics', 'Mathematics', 'Quantitative Biology', 'Physics'] | |
| answer = [f'- {100*prob:.2f}% - {class_names[i]} ({full_names[i]}) \n\n' for prob, i in zip(*top95(res))] | |
| st.markdown(f""" | |
| Давайте сверимся :) | |
| _Ваше название:_ {name if name else '{не задано}'} | |
| _Ваша аннотация:_ {abstract if abstract else '{не задано}'} | |
| """) | |
| st.markdown('**Предсказания тематики Вашей статьи (топ-95%):**\n\n' + ''.join(answer)) | |
| # st.markdown(f'Тематика Вашей статьи: {class_names[index]}') | |
| st.markdown(""" | |
| --- | |
| Автор: _Негодин Владислав_ | |
| tg: _@kol060k_ | |
| """) | |