HW4-Service / src /streamlit_app.py
kilinkarov's picture
Update src/streamlit_app.py
d6de007 verified
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from pathlib import Path
# Fine-tuned classifier weights live in the "model" directory that sits next
# to the directory containing this file (i.e. two levels up from the file).
MODEL_PATH = str(Path(__file__).resolve().parent.parent.joinpath("model"))
# Identifier handed to AutoTokenizer.from_pretrained to obtain the tokenizer.
TOKENIZER_NAME = "oracat/bert-paper-classifier-arxiv"
@st.cache_resource
def load_model():
    """Build the tokenizer and the sequence-classification model.

    The function is wrapped in ``st.cache_resource`` so the loaded objects
    can be reused instead of being rebuilt on every call.

    Returns:
        A ``(tokenizer, model)`` tuple. The tokenizer is created from
        ``TOKENIZER_NAME``; the model is created from ``MODEL_PATH`` and has
        been switched to evaluation mode.
    """
    text_tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    # nn.Module.eval() returns the module itself, so the call can be chained.
    classifier = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH).eval()
    return text_tokenizer, classifier
# --- Page layout -------------------------------------------------------------
st.set_page_config(page_title="Arxiv Paper Classifier", layout="centered")
st.title("Arxiv Paper Classifier")
st.markdown("Classify an academic paper into an arxiv category by its title and/or abstract.")

title = st.text_input("Paper title (optional)", placeholder="e.g. Attention Is All You Need")
abstract = st.text_area("Abstract (optional)", placeholder="Paste the paper abstract here...", height=200)

# Upper bound on the encoded sequence length passed to the tokenizer.
_MAX_TOKENS = 256
# Labels are listed, most probable first, until their probabilities sum to
# at least this value.
_TOP_MASS = 0.95

if st.button("Classify", type="primary"):
    # Treat whitespace-only fields as empty so that blank input is rejected
    # rather than classified as an empty sequence.
    stripped = [(text or "").strip() for text in (title, abstract)]
    segments = [text for text in stripped if text]

    if not segments:
        st.error("Please enter at least a title or an abstract.")
    else:
        tokenizer, model = load_model()

        # One segment is encoded as a single sequence; two segments are
        # encoded as a (title, abstract) pair, in that order.
        inputs = tokenizer(*segments, truncation=True, max_length=_MAX_TOKENS, return_tensors="pt")

        with torch.no_grad():
            logits = model(**inputs).logits

        # Remove only the batch axis (a single example is encoded above).
        # A bare squeeze() would also collapse the label axis when the model
        # has exactly one label, leaving a 0-d tensor that cannot be iterated.
        probs = F.softmax(logits, dim=-1).squeeze(0)

        id2label = model.config.id2label
        results = []
        cumulative = 0.0
        for idx in torch.argsort(probs, descending=True).tolist():
            prob = probs[idx].item()
            results.append((id2label[idx], prob))
            cumulative += prob
            if cumulative >= _TOP_MASS:
                break

        st.subheader("Predicted categories (top-95%)")
        for label, prob in results:
            col1, col2 = st.columns([3, 7])
            with col1:
                st.markdown(f"**{label}**")
            with col2:
                st.progress(prob, text=f"{prob:.1%}")