"""Streamlit app that classifies an arXiv article from its title and abstract.

The page loads a fine-tuned DistilBERT sequence classifier from
``./arxiv_classifier`` and lists the most probable categories until their
combined probability reaches the configured cutoff.
"""

import json

import streamlit as st
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

MODEL_DIR = './arxiv_classifier'

# Categories are listed in descending probability until their cumulative
# probability reaches this value.
CUMULATIVE_PROBABILITY_CUTOFF = 0.95

# Streamlit requires the page configuration to be the first Streamlit command
# on a page, so it is issued before the cached loader below is invoked.
st.set_page_config(page_title="ArXiv Article Classifier", layout="centered")


@st.cache_resource
def _load_resources():
    """Load the classifier, its tokenizer and the index-to-category mapping.

    Streamlit re-executes the whole script on every widget interaction;
    caching keeps these objects in memory instead of re-reading them from
    disk on each rerun.

    Returns:
        A ``(model, tokenizer, index_to_category)`` tuple, where
        ``index_to_category`` is the parsed content of
        ``index_to_category.json``.
    """
    loaded_model = DistilBertForSequenceClassification.from_pretrained(MODEL_DIR)
    loaded_tokenizer = DistilBertTokenizer.from_pretrained(MODEL_DIR)
    with open(f'{MODEL_DIR}/index_to_category.json', 'r', encoding='utf-8') as f:
        mapping = json.load(f)
    return loaded_model, loaded_tokenizer, mapping


model, tokenizer, index_to_category = _load_resources()


def predict(title, summary):
    """Return category probabilities for one article.

    The title and summary are joined with a single space, tokenized (with
    truncation) and passed through the classifier.

    Args:
        title: Article title text.
        summary: Article abstract text.

    Returns:
        A tensor holding the softmax of the model logits over the last
        dimension. A single text is encoded, and the caller reads row 0.
    """
    inputs = tokenizer(
        title + " " + summary,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        outputs = model(**inputs)
    return torch.nn.functional.softmax(outputs.logits, dim=-1)


st.title("ArXiv Article Classifier")
st.write("Введите заголовок и аннотацию статьи, чтобы получить предсказание категории.")

title = st.text_input("Заголовок статьи", placeholder="Введите заголовок статьи здесь")
summary = st.text_area("Аннотация статьи", placeholder="Введите аннотацию статьи здесь")

# Classification button
if st.button("Классифицировать"):
    # Only reject the request when both fields are blank.
    if not title.strip() and not summary.strip():
        st.error("Пожалуйста, введите заголовок или аннотацию статьи.")
    else:
        with st.spinner("Классификация..."):
            predictions = predict(title, summary)
            # One sort yields both the ordered probabilities and their
            # class indices; tolist() converts them to plain Python values.
            sorted_probs, sorted_indices = torch.sort(predictions[0], descending=True)

            cumulative_probability = 0.0
            st.subheader("Результаты классификации:")
            for probability, class_index in zip(sorted_probs.tolist(), sorted_indices.tolist()):
                cumulative_probability += probability
                # The mapping is looked up by the string form of the index.
                category_name = index_to_category.get(str(class_index), "Unknown")
                st.write(f"Категория {category_name}: {probability:.2f}")
                if cumulative_probability >= CUMULATIVE_PROBABILITY_CUTOFF:
                    break

# NOTE(review): this markdown block contains only whitespace as seen here;
# confirm whether custom HTML/CSS was meant to be injected.
st.markdown(
    """ """,
    unsafe_allow_html=True
)