"""Streamlit app that classifies an academic paper into arXiv categories.

The user supplies a title and/or an abstract. The text is tokenized, run
through a sequence-classification model, and the most probable categories are
listed until their cumulative probability reaches ``TOP_P``.
"""

from pathlib import Path

import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Model weights are read from a "model" directory next to this file's parent
# directory; the tokenizer is fetched by name.
MODEL_PATH = str(Path(__file__).resolve().parent.parent / "model")
TOKENIZER_NAME = "oracat/bert-paper-classifier-arxiv"

# Inputs longer than this many tokens are truncated by the tokenizer.
MAX_TOKENS = 256
# Categories are listed until their cumulative probability reaches this value.
TOP_P = 0.95


@st.cache_resource
def load_model():
    """Load the tokenizer and the classification model.

    The model is switched to evaluation mode before being returned. The
    function is decorated with ``st.cache_resource`` so the loaded objects are
    reused instead of being re-created on every script rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.eval()
    return tokenizer, model


def _select_top_categories(probs, id2label, threshold=TOP_P):
    """Return ``(label, probability)`` pairs, most probable first.

    Labels are accumulated in descending probability order and the list is cut
    off as soon as the running total reaches ``threshold``.

    Args:
        probs: 1-D tensor of class probabilities.
        id2label: Mapping from class index to label name.
        threshold: Cumulative probability at which to stop.

    Returns:
        A list of ``(label, probability)`` tuples.
    """
    results = []
    cumulative = 0.0
    for idx in torch.argsort(probs, descending=True).tolist():
        prob = probs[idx].item()
        cumulative += prob
        results.append((id2label[idx], prob))
        if cumulative >= threshold:
            break
    return results


def _classify(tokenizer, model, title, abstract):
    """Run the model on the given text and return the top categories.

    Args:
        tokenizer: Tokenizer returned by ``load_model``.
        model: Model returned by ``load_model``.
        title: Paper title; may be empty.
        abstract: Paper abstract; may be empty.

    Returns:
        A list of ``(label, probability)`` tuples, most probable first.
    """
    # With both fields present they are encoded as a (text, text_pair) input;
    # with only one present it is encoded on its own.
    texts = [text for text in (title, abstract) if text]
    inputs = tokenizer(
        *texts, truncation=True, max_length=MAX_TOKENS, return_tensors="pt"
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    # Index the single batch row explicitly: squeeze() would also drop the
    # class dimension if the model had exactly one label.
    probs = F.softmax(logits, dim=-1)[0]
    return _select_top_categories(probs, model.config.id2label)


st.set_page_config(page_title="Arxiv Paper Classifier", layout="centered")
st.title("Arxiv Paper Classifier")
st.markdown(
    "Classify an academic paper into an arxiv category by its title and/or abstract."
)

title = st.text_input(
    "Paper title (optional)", placeholder="e.g. Attention Is All You Need"
)
abstract = st.text_area(
    "Abstract (optional)",
    placeholder="Paste the paper abstract here...",
    height=200,
)

if st.button("Classify", type="primary"):
    # Strip first so whitespace-only input is treated as missing rather than
    # being sent to the model.
    title_text = title.strip()
    abstract_text = abstract.strip()

    if not title_text and not abstract_text:
        st.error("Please enter at least a title or an abstract.")
    else:
        tokenizer, model = load_model()
        results = _classify(tokenizer, model, title_text, abstract_text)

        st.subheader(f"Predicted categories (top-{TOP_P:.0%})")
        for label, prob in results:
            col1, col2 = st.columns([3, 7])
            with col1:
                st.markdown(f"**{label}**")
            with col2:
                st.progress(prob, text=f"{prob:.1%}")