arxiv-classifier / src /streamlit_app.py
whytimmy's picture
Update src/streamlit_app.py
84aba52 verified
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import json
import os
# Resolve the model directory relative to this source file rather than the
# current working directory, so the app finds its checkpoint however it is launched.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
MODEL_DIR = os.path.join(BASE_DIR, "arxiv_dir")

# Page-level settings, applied before any element is drawn.
st.set_page_config(
    layout="wide",
    page_title="Arxiv Classifier",
    page_icon="🚀",
    initial_sidebar_state="collapsed",
)
st.title("Arxiv Classifier")
@st.cache_resource
def load_model():
    """Load the classifier, its tokenizer and both label maps from ``MODEL_DIR``.

    Wrapped in ``st.cache_resource`` so the expensive load runs once and the
    same objects are reused on every script rerun instead of being reloaded.

    Returns:
        ``(model, tokenizer, id2tag, tag2name)`` where ``id2tag`` maps integer
        class indices to tag strings (read from ``id2tag.json``) and
        ``tag2name`` maps tags to display names (read from ``tag2name.json``).
    """
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    # Explicit UTF-8: without it open() uses the platform's locale encoding,
    # which can mis-decode or reject non-ASCII label text (e.g. on Windows).
    with open(os.path.join(MODEL_DIR, "id2tag.json"), encoding="utf-8") as f:
        # JSON object keys are always strings; convert them back to the int
        # indices that come out of sorting the model's logits.
        id2tag = {int(k): v for k, v in json.load(f).items()}
    with open(os.path.join(MODEL_DIR, "tag2name.json"), encoding="utf-8") as f:
        tag2name = json.load(f)
    # Inference mode: disables dropout and similar train-time behavior.
    model.eval()
    return model, tokenizer, id2tag, tag2name
# Runs on every script rerun, but the cache on load_model makes repeat calls
# return the already-loaded objects.
(model, tokenizer, id2tag, tag2name) = load_model()
def predict_top95(title, summary=None, threshold=0.95):
    """Return the top classes covering ``threshold`` of the probability mass.

    The title and optional summary are joined as ``"<title> [SEP] <summary>"``,
    truncated to 512 tokens and scored by the module-level ``model``. Classes
    are walked in descending probability order and kept while the mass
    accumulated *before* each one is still below ``threshold``. The class that
    crosses the threshold is therefore included, and the single most likely
    class is always returned.

    Args:
        title: Paper title.
        summary: Optional abstract; ignored when falsy (``None`` or ``""``).
        threshold: Cumulative-probability cutoff (top-p). Defaults to 0.95.

    Returns:
        List of ``(tag, probability)`` tuples, highest probability first.
    """
    text = title + " [SEP] " + summary if summary else title

    encoded = tokenizer(text, truncation=True, max_length=512, return_tensors="pt")
    encoded = {key: tensor.to(model.device) for key, tensor in encoded.items()}

    # Cheap and idempotent; guards against the shared model having been
    # switched to training mode elsewhere.
    model.eval()
    with torch.no_grad():
        logits = model(**encoded).logits

    # Batch of one: take the single row of class probabilities.
    probs = torch.softmax(logits, dim=-1)[0]
    sorted_probs, sorted_idx = probs.sort(descending=True)

    # Mass accumulated strictly before each class; the first entry is 0.
    mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
    keep = mass_before < threshold
    # Always return at least the top class, even for threshold <= 0.
    keep[0] = True

    kept_probs = sorted_probs[keep].tolist()
    kept_idx = sorted_idx[keep].tolist()
    return [(id2tag[idx], prob) for prob, idx in zip(kept_probs, kept_idx)]
def colored_bar(prob):
    """Render a thin horizontal bar whose filled width is ``prob`` * 100 percent.

    ``prob`` is expected to be a probability in [0, 1]. The fill color encodes
    confidence: green above 0.4, orange above 0.15, red otherwise.
    """
    # (exclusive lower bound, color), checked from the highest bound down.
    bands = (
        (0.4, "#2ecc71"),   # green
        (0.15, "#f39c12"),  # orange
    )
    fill = next((shade for floor, shade in bands if prob > floor), "#e74c3c")  # red fallback
    width = prob * 100
    markup = (
        "\n"
        '<div style="background:#e0e0e0;border-radius:4px;height:10px;margin:4px 0">\n'
        f'<div style="background:{fill};width:{width:.1f}%;height:100%;border-radius:4px"></div>\n'
        "</div>\n"
    )
    st.markdown(markup, unsafe_allow_html=True)
def _ru_plural(n, one, few, many):
    """Pick the Russian plural form of a noun for the count ``n``.

    Russian uses ``one`` for 1, 21, 31, ... (but not 11), ``few`` for
    2-4, 22-24, ... (but not 12-14), and ``many`` for everything else.
    """
    last = n % 10
    last_two = n % 100
    if last == 1 and last_two != 11:
        return one
    if 2 <= last <= 4 and not 12 <= last_two <= 14:
        return few
    return many


def _render_prediction(tag, prob):
    """Show one predicted class: display name, raw tag, colored bar and percentage."""
    label = tag2name.get(tag, tag)
    st.markdown(f"**{label}** `{tag}`")
    colored_bar(prob)
    st.caption(f"{prob:.1%}")


# How many results are shown directly; the rest go into a collapsible expander.
_MAX_VISIBLE = 10

# Left column: inputs. Right column: results.
col1, col2 = st.columns([1, 1])
with col1:
    name = st.text_input("Название статьи")  # "Paper title"
    abstract = st.text_area("Abstract", height=200)
    clicked = st.button("Классифицировать")  # "Classify"
with col2:
    if clicked:
        # Strip so whitespace-only input is treated as empty.
        name = name.strip()
        abstract = abstract.strip()
        if not name and not abstract:
            # "Enter a title or an abstract"
            st.warning("Введите название или abstract")
        else:
            results = predict_top95(name, abstract or None)
            n_results = len(results)
            noun = _ru_plural(n_results, "класс", "класса", "классов")
            st.markdown(f"### Результаты — {n_results} {noun}")
            visible = results[:_MAX_VISIBLE]
            hidden = results[_MAX_VISIBLE:]
            for tag, prob in visible:
                _render_prediction(tag, prob)
            if hidden:
                # "Show N more"
                with st.expander(f"Показать ещё {len(hidden)}"):
                    for tag, prob in hidden:
                        _render_prediction(tag, prob)