import streamlit as st st.markdown("# Классификатор статей") st.markdown(""" Источник: [evergreens.com.ua](https://evergreens.com.ua/ru/articles/classical-machine-learning.html) """, unsafe_allow_html=True) st.markdown("#### Привет!") st.markdown("""Это классификатор статей. Вы можете подать на вход название статьи, её аннотацию (abstract) или и то и другое и получить наиболее вероятные тематики данной статьи. По введённым данным нейросеть предскажет топ-95% тематик. Разбиение на тематики происходит согласно [arXiv Category Taxonomy](https://arxiv.org/category_taxonomy) в несколько более упрощённом формате (без подтематик и без некоторых не самых частых тематик). Ввод данных осуществляется в поля ниже. - Для того, чтобы сохранить введённую информацию в поле, нажмите Ctrl+Enter. - После того, как Вы заполните хотя бы одно поле, модель сделает предсказание. После заполнения второго поля, предсказание обновится. """) import torch from transformers import AutoTokenizer, AutoModelForSequenceClassification device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # В целом, и без кеширования работало быстро, но с кешированием действительно быстрее @st.cache def cached_model(): model = AutoModelForSequenceClassification.from_pretrained('pretrained_model') return model tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased") model = cached_model() name = st.text_area("Введите название статьи:") abstract = st.text_area("Введите аннотацию статьи:") text = name + '. ' + abstract tokens = tokenizer.encode(text) res = model(torch.as_tensor([tokens], device=device))[0] def top95(res): probs = torch.softmax(res, dim=-1).squeeze().cpu() probs = [(prob.item(), i) for i, prob in enumerate(probs)] probs = sorted(probs, reverse=True) result_probs = [] result_indexes = [] sum_prob = 0 for prob, i in probs: if sum_prob < 0.95: result_probs.append(prob) result_indexes.append(i) sum_prob += prob return result_probs, result_indexes class_names = ['cs', 'stat', 'math', 'q-bio', 'physics'] full_names = ['Computer Science', 'Statistics', 'Mathematics', 'Quantitative Biology', 'Physics'] answer = [f'- {100*prob:.2f}% - {class_names[i]} ({full_names[i]}) \n\n' for prob, i in zip(*top95(res))] st.markdown(f""" Давайте сверимся :) _Ваше название:_ {name if name else '{не задано}'} _Ваша аннотация:_ {abstract if abstract else '{не задано}'} """) st.markdown('**Предсказания тематики Вашей статьи (топ-95%):**\n\n' + ''.join(answer)) # st.markdown(f'Тематика Вашей статьи: {class_names[index]}') st.markdown(""" --- Автор: _Негодин Владислав_ tg: _@kol060k_ """)