arxiv-classifier / src /streamlit_app.py
sofia-ol's picture
Update src/streamlit_app.py
75ad80e verified
Raw
History Blame Contribute Delete
3.96 kB
import streamlit as st
import torch
from transformers import DistilBertForSequenceClassification, AutoTokenizer
from torch.nn import Softmax
import numpy as np
# Настройка страницы
st.set_page_config(
page_title="arXiv Classifier",
page_icon="📚",
layout="centered"
)
@st.cache_resource
def load_model():
model_path = "./best_model"
model = DistilBertForSequenceClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.eval()
return model, tokenizer
def predict(text, model, tokenizer, threshold=0.95):
"""Предсказание с накоплением вероятностей до 95%"""
inputs = tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=512,
padding=True
)
with torch.no_grad():
outputs = model(**inputs)
probs = Softmax(dim=1)(outputs.logits).squeeze().numpy()
indices = np.argsort(probs)[::-1]
cumulative = 0
results = []
for idx in indices:
prob = probs[idx]
cumulative += prob
category = model.config.id2label[idx]
category_names = {
'cs.AI': '🤖 Искусственный интеллект',
'cs.CL': '💬 Обработка естественного языка',
'cs.CV': '👁️ Компьютерное зрение',
'physics': '⚛️ Физика',
'math': '📐 Математика',
'q-bio': '🧬 Биология'
}
display_name = category_names.get(category, category)
results.append((display_name, prob, category))
if cumulative >= threshold:
break
return results
st.title("📚 arXiv Статья Классификатор")
st.markdown("""
Определяет тематику научной статьи по **названию** и **аннотации**.
Модель обучена на 18,000+ статей из arXiv.org.
""")
st.subheader("Введите данные статьи")
title = st.text_input("📌 Название статьи *", placeholder="Например: Attention is All You Need")
abstract = st.text_area(
"📄 Аннотация (необязательно)",
placeholder="Введите аннотацию статьи здесь...",
height=150
)
if title.strip() == "":
st.warning("⚠️ Пожалуйста, введите название статьи")
st.stop()
if abstract.strip():
full_text = title + " [SEP] " + abstract
else:
full_text = title
if st.button("🔍 Определить тематику", type="primary"):
with st.spinner("Анализирую статью..."):
try:
model, tokenizer = load_model()
predictions = predict(full_text, model, tokenizer)
st.subheader("📊 Результаты классификации")
for display_name, prob, cat in predictions:
st.markdown(f"**{display_name}**")
st.progress(float(prob), text=f"{prob*100:.1f}%")
if len(predictions) == 1:
st.success(f"✅ Статья однозначно относится к категории **{predictions[0][0]}**")
else:
st.info(f"📌 Статья может относиться к нескольким областям (топ-{len(predictions)} категорий, суммарная вероятность > 95%)")
except Exception as e:
st.error(f"❌ Ошибка: {str(e)}")
st.markdown("Попробуйте ввести другой текст или проверьте подключение.")
st.markdown("---")
st.caption("Built with DistilBERT | Trained on arXiv papers | Deployed on Hugging Face Spaces")