tienti0000's picture
fixed local model path
a6c5bcf
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
@st.cache_resource
def load_model():
model = AutoModelForSequenceClassification.from_pretrained("arxiv_model",
trust_remote_code=True,
local_files_only=True)
tokenizer = AutoTokenizer.from_pretrained("arxiv_model", local_files_only=True)
return model, tokenizer
model, tokenizer = load_model()
id2label = model.config.id2label
st.title("🔬 ArXiv Article Classifier")
st.markdown("Введите **название** и (по желанию) **аннотацию** статьи. Сервис предскажет вероятные темы!")
title_input = st.text_input("Название статьи")
abstract_input = st.text_area("Аннотация (необязательно)")
if st.button("Классифицировать") and title_input:
text = title_input.strip()
if abstract_input.strip():
text += " " + abstract_input.strip()
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
with torch.no_grad():
outputs = model(**inputs)
probs = torch.nn.functional.softmax(outputs.logits, dim=1).numpy()[0]
sorted_indices = np.argsort(probs)[::-1]
# top_labels = [(id2label[str(i)], probs[i]) for i in sorted_indices]
top_labels = [(id2label[i], probs[i]) for i in sorted_indices]
cumulative = 0.0
top95 = []
for label, prob in top_labels:
top95.append((label, prob))
cumulative += prob
if cumulative >= 0.95:
break
st.markdown(f"### 🎯 Основная тема: `{top_labels[0][0]}` ({top_labels[0][1]*100:.2f}%)")
st.markdown("### 📋 Категории (до 95% суммарной вероятности):")
for label, prob in top95:
st.write(f"- `{label}`: {prob*100:.2f}%")
else:
st.markdown("_Введите название статьи и нажмите кнопку выше_")