"""arXiv Topic Classifier — Streamlit web UI. Fine-tuned DistilBERT predicts the top-level arxiv category for a paper given its title and (optionally) abstract. The UI shows topics whose cumulative probability covers >=95%, sorted by descending confidence. The model is loaded from one of: 1. HF Hub repo specified in env var ARXIV_MODEL_REPO (e.g. "user/arxiv-clf") 2. Local directory ./model/ (produced by train.ipynb) Set ARXIV_MODEL_REPO before launching to use a hosted model on HF Spaces. Device selection is automatic: MPS (Apple Silicon) → CUDA → CPU. On HF Spaces free tier this falls back to CPU. """ from __future__ import annotations import json import os from pathlib import Path from typing import List, Tuple # Allow rare ops without an MPS kernel to fall back to CPU instead of crashing. # Must be set before torch is imported. os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1") os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") import streamlit as st import torch import torch.nn.functional as F from transformers import AutoModelForSequenceClassification, AutoTokenizer # --------------------------------------------------------------------------- # Configuration # --------------------------------------------------------------------------- DEFAULT_LOCAL_MODEL_DIR = Path(__file__).parent / "model" HF_REPO_ENV_VAR = "ARXIV_MODEL_REPO" MAX_LENGTH_FALLBACK = 256 TOP_P_DEFAULT = 0.95 def _select_device() -> torch.device: """Pick the best available device: MPS → CUDA → CPU. On Apple Silicon (M1/M2/M3) MPS gives a major speedup over CPU. On HF Spaces free tier neither MPS nor CUDA is available so we fall back to CPU automatically. """ if torch.backends.mps.is_available() and torch.backends.mps.is_built(): return torch.device("mps") if torch.cuda.is_available(): return torch.device("cuda") return torch.device("cpu") DEVICE = _select_device() # Human-readable names for arxiv top-level categories. 
# Used as a fallback if label_meta.json does not provide pretty_names for
# some label.
PRETTY_NAMES_FALLBACK = {
    "astro-ph": "Astrophysics",
    "cond-mat": "Condensed Matter Physics",
    "cs": "Computer Science",
    "econ": "Economics",
    "eess": "Electrical Engineering & Systems",
    "gr-qc": "General Relativity & Quantum Cosmology",
    "hep-ex": "High Energy Physics — Experiment",
    "hep-lat": "High Energy Physics — Lattice",
    "hep-ph": "High Energy Physics — Phenomenology",
    "hep-th": "High Energy Physics — Theory",
    "math": "Mathematics",
    "math-ph": "Mathematical Physics",
    "nlin": "Nonlinear Sciences",
    "nucl-ex": "Nuclear Physics — Experiment",
    "nucl-th": "Nuclear Physics — Theory",
    "physics": "Physics (general)",
    "q-bio": "Quantitative Biology",
    "q-fin": "Quantitative Finance",
    "quant-ph": "Quantum Physics",
    "stat": "Statistics",
}

# Sample inputs offered as one-click buttons in the sidebar.
EXAMPLES = [
    {
        "name": "Transformers paper (cs)",
        "title": "Attention Is All You Need",
        "abstract": (
            "The dominant sequence transduction models are based on complex "
            "recurrent or convolutional neural networks that include an "
            "encoder and a decoder. We propose a new simple network "
            "architecture, the Transformer, based solely on attention "
            "mechanisms, dispensing with recurrence and convolutions entirely."
        ),
    },
    {
        "name": "Algebraic geometry (math)",
        "title": "On the Hodge conjecture for products of certain K3 surfaces",
        "abstract": (
            "We prove the Hodge conjecture for self-products of certain K3 "
            "surfaces using transcendental cycles, motivic methods, and "
            "Kuga-Satake constructions."
        ),
    },
    {
        "name": "TeV gamma astronomy (astro-ph)",
        "title": "Observation of TeV gamma rays from blazar Mrk 421 with VERITAS",
        "abstract": (
            "We report observations of very high energy gamma-ray emission "
            "from the blazar Markarian 421 conducted with the VERITAS array "
            "of imaging atmospheric Cherenkov telescopes."
        ),
    },
]


# ---------------------------------------------------------------------------
# Model loading (cached via st.cache_resource, so the load runs once and the
# resulting objects are shared across reruns and sessions)
# ---------------------------------------------------------------------------
@st.cache_resource(show_spinner="Loading model… (only happens once)")
def load_model_and_tokenizer() -> dict:
    """Load model + tokenizer + label metadata.

    The HF Hub repo named by the ARXIV_MODEL_REPO environment variable takes
    precedence; otherwise the local model directory is used.

    Returns:
        A dict with keys: model, tokenizer, id2label, pretty_names,
        max_length, source, device.

    Raises:
        FileNotFoundError: if the environment variable is unset/blank and the
            local model directory does not exist.
    """
    repo = os.environ.get(HF_REPO_ENV_VAR, "").strip()
    source: str
    label_meta_path: Path | None = None

    if repo:
        # Hub: download model files. label_meta.json is fetched separately if
        # present (HF auto-downloads only the model files via from_pretrained).
        source = f"HF Hub: {repo}"
        tokenizer = AutoTokenizer.from_pretrained(repo)
        model = AutoModelForSequenceClassification.from_pretrained(repo)
        try:
            from huggingface_hub import hf_hub_download

            label_meta_path = Path(
                hf_hub_download(repo_id=repo, filename="label_meta.json")
            )
        except Exception:
            # Deliberate best-effort: label_meta.json is optional, so any
            # failure to fetch it just means we use the model config labels.
            label_meta_path = None
    elif DEFAULT_LOCAL_MODEL_DIR.exists():
        source = f"local dir: {DEFAULT_LOCAL_MODEL_DIR}"
        tokenizer = AutoTokenizer.from_pretrained(DEFAULT_LOCAL_MODEL_DIR)
        model = AutoModelForSequenceClassification.from_pretrained(
            DEFAULT_LOCAL_MODEL_DIR
        )
        candidate = DEFAULT_LOCAL_MODEL_DIR / "label_meta.json"
        label_meta_path = candidate if candidate.exists() else None
    else:
        raise FileNotFoundError(
            "No model found. Either set environment variable "
            f"{HF_REPO_ENV_VAR} to a HuggingFace repo id, or place a trained "
            f"model in {DEFAULT_LOCAL_MODEL_DIR} (run train.ipynb)."
        )

    model.to(DEVICE)
    model.eval()

    # Resolve labels: prefer label_meta.json, fall back to model config.
    if label_meta_path is not None:
        # Read explicitly as UTF-8: pretty names may contain non-ASCII
        # characters and the platform default encoding is not always UTF-8.
        meta = json.loads(label_meta_path.read_text(encoding="utf-8"))
        id2label = {int(k): v for k, v in meta["id2label"].items()}
        pretty_names = meta.get("pretty_names", {})
        max_length = int(meta.get("max_length", MAX_LENGTH_FALLBACK))
    else:
        id2label = {int(k): v for k, v in model.config.id2label.items()}
        pretty_names = {}
        max_length = MAX_LENGTH_FALLBACK

    # Fill in any missing pretty names with our fallback table; a label that
    # is unknown to both sources is displayed as its raw id.
    pretty_names = {
        lab: pretty_names.get(lab, PRETTY_NAMES_FALLBACK.get(lab, lab))
        for lab in id2label.values()
    }

    return {
        "model": model,
        "tokenizer": tokenizer,
        "id2label": id2label,
        "pretty_names": pretty_names,
        "max_length": max_length,
        "source": source,
        "device": str(DEVICE),
    }


# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def build_input_text(title: str, abstract: str) -> str:
    """Combine title and (optional) abstract into one input string.

    The training notebook used `title + ". " + abstract`. We replicate that
    so inference matches the training distribution. If the abstract is empty
    we fall back to title-only — the model still works, just with less
    context.

    Both inputs are stripped of surrounding whitespace before being combined.
    """
    title = title.strip()
    abstract = abstract.strip()
    if abstract:
        return f"{title}. {abstract}"
    return title


@torch.inference_mode()
def predict(
    title: str, abstract: str, top_p: float = TOP_P_DEFAULT
) -> list[tuple[str, str, float]]:
    """Return [(label, pretty_name, prob)] covering top-p of the mass.

    Labels are taken in order of descending probability and accumulated until
    their total probability reaches ``top_p``; at least one entry is returned
    whenever the model has at least one label.

    Args:
        title: Paper title.
        abstract: Paper abstract; may be empty.
        top_p: Cumulative probability threshold at which to stop.
    """
    bundle = load_model_and_tokenizer()
    model = bundle["model"]
    tokenizer = bundle["tokenizer"]
    id2label = bundle["id2label"]
    pretty_names = bundle["pretty_names"]
    max_length = bundle["max_length"]

    text = build_input_text(title, abstract)
    enc = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    # Inputs must live on the same device the model was moved to.
    enc = {k: v.to(DEVICE) for k, v in enc.items()}
    # Single input, so take the first (only) row of the batch.
    logits = model(**enc).logits[0]
    probs = F.softmax(logits, dim=-1).cpu().numpy()

    # Class indices ordered from most to least probable.
    order = probs.argsort()[::-1]
    cumulative = 0.0
    out: list[tuple[str, str, float]] = []
    for idx in order:
        label = id2label[int(idx)]
        pretty = pretty_names.get(label, label)
        prob = float(probs[idx])
        out.append((label, pretty, prob))
        cumulative += prob
        if cumulative >= top_p:
            break
    return out


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
def render_results(
    results: list[tuple[str, str, float]], top_p: float = TOP_P_DEFAULT
) -> None:
    """Render the prediction list: a headline best guess plus one row per topic.

    Args:
        results: (label, pretty_name, prob) tuples; the first entry is shown
            as the best guess.
        top_p: Threshold quoted in the caption. Pass the same value that was
            given to ``predict`` so the caption matches what is displayed.
    """
    st.subheader("Predicted topics")
    if not results:
        # Nothing to unpack below; avoid an IndexError on results[0].
        st.info("No predictions to display.")
        return

    st.caption(
        "Showing the smallest set of topics whose total probability is at least "
        f"{top_p:.0%}."
    )

    top_label, top_pretty, top_prob = results[0]
    st.success(f"**Best guess:** {top_pretty} · `{top_label}` · {top_prob:.1%}")

    for label, pretty, prob in results:
        col1, col2 = st.columns([3, 1])
        with col1:
            st.markdown(f"**{pretty}** · `{label}`")
            # Clamp the value into [0.0, 1.0] before handing it to the bar.
            st.progress(min(max(prob, 0.0), 1.0))
        with col2:
            st.metric(label=" ", value=f"{prob:.1%}", label_visibility="collapsed")


def main() -> None:
    """Build the page: load the model, collect input, classify, show results."""
    st.set_page_config(
        page_title="arXiv Topic Classifier",
        page_icon=":bookmark_tabs:",
        layout="centered",
    )

    st.title("arXiv Topic Classifier")
    st.markdown(
        "Paste a paper's **title** (and optionally its **abstract**) — the "
        "model will tell you which arXiv categories it most likely belongs to. "
        "Powered by a fine-tuned DistilBERT."
    )

    # Try to load model up front so config issues surface immediately.
    try:
        bundle = load_model_and_tokenizer()
    except Exception as exc:
        st.error(
            "Failed to load the classification model.\n\n"
            f"**Reason:** {exc}\n\n"
            "If you are running locally, train a model with `train.ipynb` "
            "and re-launch. On HuggingFace Spaces, set the secret "
            f"`{HF_REPO_ENV_VAR}` to a model repo id."
        )
        st.stop()

    with st.sidebar:
        st.header("About")
        st.markdown(
            "**Task.** Classify an arXiv paper into one of "
            f"{len(bundle['id2label'])} top-level categories.\n\n"
            "**Model.** Fine-tuned `distilbert-base-uncased`.\n\n"
            "**Input.** Title is required; abstract is optional but helps a lot."
        )
        st.caption(f"Model source: {bundle['source']}")
        st.caption(f"Inference device: `{bundle['device']}`")

        st.header("Try an example")
        # These writes happen before the input widgets below are created in
        # this run, so the widgets pick the example text up via their keys.
        for ex in EXAMPLES:
            if st.button(ex["name"], use_container_width=True):
                st.session_state["title_input"] = ex["title"]
                st.session_state["abstract_input"] = ex["abstract"]

    st.session_state.setdefault("title_input", "")
    st.session_state.setdefault("abstract_input", "")

    title = st.text_input(
        "Paper title",
        key="title_input",
        placeholder="e.g. Attention Is All You Need",
        help="Required. The full paper title.",
    )
    abstract = st.text_area(
        "Abstract (optional)",
        key="abstract_input",
        height=180,
        placeholder=(
            "Paste the paper abstract here. If you leave this empty the model "
            "will classify by title only — predictions will be less confident."
        ),
        help="Optional but strongly recommended.",
    )

    classify_clicked = st.button("Classify", type="primary", use_container_width=True)

    if not classify_clicked:
        return

    # Input validation
    stripped_title = title.strip()
    if not stripped_title:
        st.warning(
            "Please enter at least a paper title. The abstract is optional, "
            "but the title is required."
        )
        return
    if len(stripped_title) < 5:
        st.warning(
            "That title looks unusually short. Please paste the full paper "
            "title for a meaningful prediction."
        )
        return

    # Inference
    try:
        with st.spinner("Classifying…"):
            results = predict(title, abstract)
    except Exception as exc:
        st.error(
            "Something went wrong while running the model.\n\n"
            f"**Details:** {exc}\n\n"
            "Try shortening the abstract or refresh the page."
        )
        return

    if not results:
        st.error(
            "The model returned no predictions. This should not happen — "
            "please report it."
        )
        return

    render_results(results)

    if not abstract.strip():
        st.info(
            "Tip: paste the abstract for a noticeably more confident "
            "prediction."
        )


if __name__ == "__main__":
    main()