YakovPodlesnov's picture
Update app.py
1b2eba5 verified
import streamlit as st
import torch
import joblib
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from huggingface_hub import hf_hub_download
@st.cache_resource
def load_model():
repo_id = "YakovPodlesnov/best-article-classifier"
model = AutoModelForSequenceClassification.from_pretrained(repo_id)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
label_encoder_path = hf_hub_download(
repo_id=repo_id,
filename="label_encoder.joblib"
)
label_encoder = joblib.load(label_encoder_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
return model, tokenizer, label_encoder, device
def predict(text, model, tokenizer, label_encoder, device):
inputs = tokenizer(
text,
padding="max_length",
truncation=True,
max_length=512,
return_tensors="pt"
).to(device)
with torch.no_grad():
outputs = model(**inputs)
probs = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().numpy()[0]
sorted_indices = np.argsort(probs)[::-1]
cumulative = 0.0
results = []
for idx in sorted_indices:
if cumulative >= 0.95:
break
label = label_encoder.inverse_transform([idx])[0]
prob = float(probs[idx])
results.append((label, prob))
cumulative += prob
return results
st.set_page_config(page_title="arXiv Classifier", layout="wide")
st.title("Классификатор научных статей")
st.markdown("Определяем тематику статьи по заголовку и аннотации (arXiv)")
with st.spinner("Инициализация модели..."):
try:
model, tokenizer, label_encoder, device = load_model()
except Exception as e:
st.error(f"Ошибка загрузки модели: {str(e)}")
st.stop()
col1, col2 = st.columns([1, 1])
with col1:
title = st.text_input("Название статьи:",
placeholder="Attention Is All You Need",
help="Введите полное название статьи")
with col2:
abstract = st.text_area("Аннотация:",
height=150,
placeholder="We propose a new simple network architecture...",
help="Введите текст аннотации (необязательно)")
if st.button("Определить категории", type="primary"):
if not title:
st.error("Пожалуйста, введите название статьи")
else:
text = title + (" " + abstract if abstract else "")
with st.spinner("Анализируем текст..."):
try:
predictions = predict(text, model, tokenizer, label_encoder, device)
except Exception as e:
st.error(f"Ошибка предсказания: {str(e)}")
st.stop()
st.subheader(" Результаты классификации")
st.markdown("""
Расшифровку названий, например для cs, можно посмотреть здесь: https://arxiv.org/archive/cs
""")
if not predictions:
st.warning("Не удалось определить категории")
else:
cols = st.columns([2, 3])
with cols[0]:
st.markdown("**Топ-категории:**")
for label, prob in predictions:
st.markdown(f"▸ {label} ({prob:.1%})")
with cols[1]:
labels = [p[0] for p in predictions]
probs = [p[1] for p in predictions]
st.bar_chart(
dict(zip(labels, probs)),
use_container_width=True,
color="#FF4B4B"
)
st.markdown("---")
st.markdown("""
### Как это работает?
1. Введите **название** научной статьи
2. При необходимости добавьте **аннотацию**
3. Нажмите кнопку для классификации статьи
4. Получите предсказанные категории arXiv
**Модель: scibert_scivocab_uncased**
""")
if __name__ == "__main__":
pass