Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import torch | |
| from transformers import DistilBertForSequenceClassification, AutoTokenizer | |
| from torch.nn import Softmax | |
| import numpy as np | |
| # Настройка страницы | |
| st.set_page_config( | |
| page_title="arXiv Classifier", | |
| page_icon="📚", | |
| layout="centered" | |
| ) | |
| def load_model(): | |
| model_path = "./best_model" | |
| model = DistilBertForSequenceClassification.from_pretrained(model_path) | |
| tokenizer = AutoTokenizer.from_pretrained(model_path) | |
| model.eval() | |
| return model, tokenizer | |
| def predict(text, model, tokenizer, threshold=0.95): | |
| """Предсказание с накоплением вероятностей до 95%""" | |
| inputs = tokenizer( | |
| text, | |
| return_tensors="pt", | |
| truncation=True, | |
| max_length=512, | |
| padding=True | |
| ) | |
| with torch.no_grad(): | |
| outputs = model(**inputs) | |
| probs = Softmax(dim=1)(outputs.logits).squeeze().numpy() | |
| indices = np.argsort(probs)[::-1] | |
| cumulative = 0 | |
| results = [] | |
| for idx in indices: | |
| prob = probs[idx] | |
| cumulative += prob | |
| category = model.config.id2label[idx] | |
| category_names = { | |
| 'cs.AI': '🤖 Искусственный интеллект', | |
| 'cs.CL': '💬 Обработка естественного языка', | |
| 'cs.CV': '👁️ Компьютерное зрение', | |
| 'physics': '⚛️ Физика', | |
| 'math': '📐 Математика', | |
| 'q-bio': '🧬 Биология' | |
| } | |
| display_name = category_names.get(category, category) | |
| results.append((display_name, prob, category)) | |
| if cumulative >= threshold: | |
| break | |
| return results | |
| st.title("📚 arXiv Статья Классификатор") | |
| st.markdown(""" | |
| Определяет тематику научной статьи по **названию** и **аннотации**. | |
| Модель обучена на 18,000+ статей из arXiv.org. | |
| """) | |
| st.subheader("Введите данные статьи") | |
| title = st.text_input("📌 Название статьи *", placeholder="Например: Attention is All You Need") | |
| abstract = st.text_area( | |
| "📄 Аннотация (необязательно)", | |
| placeholder="Введите аннотацию статьи здесь...", | |
| height=150 | |
| ) | |
| if title.strip() == "": | |
| st.warning("⚠️ Пожалуйста, введите название статьи") | |
| st.stop() | |
| if abstract.strip(): | |
| full_text = title + " [SEP] " + abstract | |
| else: | |
| full_text = title | |
| if st.button("🔍 Определить тематику", type="primary"): | |
| with st.spinner("Анализирую статью..."): | |
| try: | |
| model, tokenizer = load_model() | |
| predictions = predict(full_text, model, tokenizer) | |
| st.subheader("📊 Результаты классификации") | |
| for display_name, prob, cat in predictions: | |
| st.markdown(f"**{display_name}**") | |
| st.progress(float(prob), text=f"{prob*100:.1f}%") | |
| if len(predictions) == 1: | |
| st.success(f"✅ Статья однозначно относится к категории **{predictions[0][0]}**") | |
| else: | |
| st.info(f"📌 Статья может относиться к нескольким областям (топ-{len(predictions)} категорий, суммарная вероятность > 95%)") | |
| except Exception as e: | |
| st.error(f"❌ Ошибка: {str(e)}") | |
| st.markdown("Попробуйте ввести другой текст или проверьте подключение.") | |
| st.markdown("---") | |
| st.caption("Built with DistilBERT | Trained on arXiv papers | Deployed on Hugging Face Spaces") |