# arxiv_article_classifier/src/streamlit_app.py
# aaaleksandrasimonova's picture
# Update src/streamlit_app.py
# 147f1b7 verified
import streamlit as st
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
@st.cache_resource
def load_model():
    """Load the sequence-classification model and its tokenizer.

    Both objects are loaded from the
    ``aaaleksandrasimonova/arxiv_article_model1`` checkpoint. The function is
    wrapped in ``st.cache_resource`` so the loaded objects can be reused
    instead of being loaded again on every call.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    checkpoint = "aaaleksandrasimonova/arxiv_article_model1"
    classifier = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    return classifier, AutoTokenizer.from_pretrained(checkpoint)
# Module-level inference state shared by the functions below.
model, tokenizer = load_model()
# Put the model into evaluation (inference) mode before any prediction is made.
model.eval()
# Class-id -> label mapping taken from the model config; used to name predictions.
id2label = model.config.id2label
def predict_proba(title, abstract):
    """Return the predicted class probabilities for a single article.

    The title and abstract are passed to the tokenizer together as a pair,
    truncated to at most 256 tokens, and run through the model with gradient
    tracking disabled.

    Args:
        title: Article title.
        abstract: Article abstract.

    Returns:
        NumPy array holding the softmax of the logits for the first (only)
        row of the batch.
    """
    encoded = tokenizer(
        title,
        abstract,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    )
    # Every input tensor has to live on the same device as the model.
    batch = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    with torch.no_grad():
        logits = model(**batch).logits.cpu()
    probabilities = torch.softmax(logits, dim=1)
    # The batch contains exactly one article, so return its row only.
    return probabilities.numpy()[0]
def top_95(probs, threshold=0.95, labels=None):
    """Return the most probable classes that together cover ``threshold``.

    Classes are taken in order of decreasing probability until the running
    total is at least ``threshold``; the class that crosses the threshold is
    included. If the total never reaches the threshold, every class is
    returned.

    Args:
        probs: 1-D sequence or array of class probabilities, indexed by
            class id.
        threshold: Cumulative probability to cover. Defaults to 0.95.
        labels: Mapping from class id to label. Defaults to the module-level
            ``id2label``.

    Returns:
        List of ``(label, probability)`` tuples, most probable first, with
        probabilities converted to plain Python floats.
    """
    if labels is None:
        labels = id2label
    probs = np.asarray(probs)
    ranked_ids = np.argsort(probs)[::-1]
    selected = []
    covered = 0.0
    for class_id in ranked_ids:
        probability = float(probs[class_id])
        covered += probability
        # Look up with a plain int so the mapping never sees a NumPy scalar key.
        selected.append((labels[int(class_id)], probability))
        if covered >= threshold:
            break
    return selected
# Full names of the arXiv categories, keyed by category tag.
class_name = dict(
    [
        ("cs", "Computer Science"),
        ("econ", "Economics"),
        ("eess", "Electrical Engineering and Systems Science"),
        ("math", "Mathematics"),
        ("physics", "Physics"),
        ("q-bio", "Quantitative Biology"),
        ("q-fin", "Quantitative Finance"),
        ("stat", "Statistics"),
    ]
)
# One "✅ tag - full name" entry per category, with a blank line between entries.
tags_str = "\n\n ".join(
    f"✅ {tag} - {full_name}" for tag, full_name in class_name.items()
)
# Page header, short usage instructions, the list of supported categories,
# and the two input widgets whose values the classifier reads.
st.title("📚 Arxiv Article Classifier")
st.markdown(
    "Введите название статьи и abstract (можно оставить пустым).\n\n"
    "Получите предсказание тематики статьи (top-95%)"
)
st.markdown(f"Тематики:\n\n {tags_str}")
title = st.text_input("📝 Title статьи")
abstract = st.text_area("📄 Abstract статьи (можно оставить пустым)")
# Classify only when the button has been pressed.
if st.button("🔍 Классифицировать"):
    if not title.strip():
        # The title is mandatory; only the abstract may be left empty.
        st.error("Пожалуйста, введите название статьи.")
    else:
        probs = predict_proba(title, abstract)
        results = top_95(probs)
        # Needing six or more classes to cover the top-95% mass means the
        # prediction is spread thin, so tell the user the article may not
        # fit any category. The results are still shown below.
        if len(results) >= 6:
            st.error("Возможно, ваша статья не соответствует ни одной из доступных категорий.")
        st.subheader("📊 Результат:")
        for tag, prob in results:
            # Fall back to the raw tag instead of raising KeyError when the
            # model emits a label that has no entry in class_name.
            st.write(f"**{tag}. {class_name.get(tag, tag)}** — {prob:.2%}")
        total_prob = sum(p for _, p in results)
        st.caption(f"Суммарная вероятность: {total_prob:.2%}")