Spaces:
Sleeping
Sleeping
| """arXiv Topic Classifier — Streamlit web UI. | |
| Fine-tuned DistilBERT predicts the top-level arxiv category for a paper given | |
| its title and (optionally) abstract. The UI shows topics whose cumulative | |
| probability covers >=95%, sorted by descending confidence. | |
| The model is loaded from one of: | |
| 1. HF Hub repo specified in env var ARXIV_MODEL_REPO (e.g. "user/arxiv-clf") | |
| 2. Local directory ./model/ (produced by train.ipynb) | |
| Set ARXIV_MODEL_REPO before launching to use a hosted model on HF Spaces. | |
| Device selection is automatic: MPS (Apple Silicon) → CUDA → CPU. On HF Spaces | |
| free tier this falls back to CPU. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import List, Tuple | |
# Environment defaults that must be in place before torch / transformers are
# imported. Values the user already exported are left untouched (only missing
# keys are filled in):
#   - PYTORCH_ENABLE_MPS_FALLBACK=1: ops lacking an MPS kernel fall back to CPU
#     instead of crashing.
#   - TOKENIZERS_PARALLELISM=false: explicitly disable tokenizers parallelism.
os.environ.update(
    {
        name: value
        for name, value in (
            ("PYTORCH_ENABLE_MPS_FALLBACK", "1"),
            ("TOKENIZERS_PARALLELISM", "false"),
        )
        if name not in os.environ
    }
)
| import streamlit as st | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| # --------------------------------------------------------------------------- | |
| # Configuration | |
| # --------------------------------------------------------------------------- | |
# Directory next to this file where train.ipynb's output model is expected.
DEFAULT_LOCAL_MODEL_DIR: Path = Path(__file__).parent / "model"
# Name of the env var holding a HuggingFace Hub repo id; takes priority over
# the local directory when set.
HF_REPO_ENV_VAR: str = "ARXIV_MODEL_REPO"
# Tokenizer truncation length used when label_meta.json gives no max_length.
MAX_LENGTH_FALLBACK: int = 256
# Probability mass the displayed topics must cover together.
TOP_P_DEFAULT: float = 0.95
def _select_device() -> torch.device:
    """Return the preferred torch device for this machine.

    Preference order is Apple-Silicon MPS, then CUDA, then CPU. MPS is chosen
    only when PyTorch reports it as both available and built; CUDA only when
    ``torch.cuda.is_available()`` is true. CPU is the unconditional fallback
    (e.g. on the HF Spaces free tier).
    """
    mps_backend = torch.backends.mps
    if mps_backend.is_available() and mps_backend.is_built():
        device_name = "mps"
    elif torch.cuda.is_available():
        device_name = "cuda"
    else:
        device_name = "cpu"
    return torch.device(device_name)


# Resolved once at import time; model loading and inference both move tensors
# to this device.
DEVICE = _select_device()
# Display names for the arXiv top-level categories. load_model_and_tokenizer
# falls back to this table for any label whose pretty name is not supplied by
# label_meta.json.
PRETTY_NAMES_FALLBACK = dict(
    [
        ("astro-ph", "Astrophysics"),
        ("cond-mat", "Condensed Matter Physics"),
        ("cs", "Computer Science"),
        ("econ", "Economics"),
        ("eess", "Electrical Engineering & Systems"),
        ("gr-qc", "General Relativity & Quantum Cosmology"),
        ("hep-ex", "High Energy Physics — Experiment"),
        ("hep-lat", "High Energy Physics — Lattice"),
        ("hep-ph", "High Energy Physics — Phenomenology"),
        ("hep-th", "High Energy Physics — Theory"),
        ("math", "Mathematics"),
        ("math-ph", "Mathematical Physics"),
        ("nlin", "Nonlinear Sciences"),
        ("nucl-ex", "Nuclear Physics — Experiment"),
        ("nucl-th", "Nuclear Physics — Theory"),
        ("physics", "Physics (general)"),
        ("q-bio", "Quantitative Biology"),
        ("q-fin", "Quantitative Finance"),
        ("quant-ph", "Quantum Physics"),
        ("stat", "Statistics"),
    ]
)
# Canned inputs offered as one-click "Try an example" buttons in the sidebar.
# Each entry has a button label (name) plus the title/abstract to pre-fill.
EXAMPLES = [
    dict(
        name="Transformers paper (cs)",
        title="Attention Is All You Need",
        abstract=(
            "The dominant sequence transduction models are based on complex "
            "recurrent or convolutional neural networks that include an "
            "encoder and a decoder. We propose a new simple network "
            "architecture, the Transformer, based solely on attention "
            "mechanisms, dispensing with recurrence and convolutions entirely."
        ),
    ),
    dict(
        name="Algebraic geometry (math)",
        title="On the Hodge conjecture for products of certain K3 surfaces",
        abstract=(
            "We prove the Hodge conjecture for self-products of certain K3 "
            "surfaces using transcendental cycles, motivic methods, and "
            "Kuga-Satake constructions."
        ),
    ),
    dict(
        name="TeV gamma astronomy (astro-ph)",
        title="Observation of TeV gamma rays from blazar Mrk 421 with VERITAS",
        abstract=(
            "We report observations of very high energy gamma-ray emission "
            "from the blazar Markarian 421 conducted with the VERITAS array "
            "of imaging atmospheric Cherenkov telescopes."
        ),
    ),
]
# ---------------------------------------------------------------------------
# Model loading (cached via st.cache_resource: loaded once per server process
# and reused across reruns, sessions and predict() calls)
# ---------------------------------------------------------------------------
@st.cache_resource(show_spinner="Loading model…")
def load_model_and_tokenizer():
    """Load the classifier, its tokenizer and its label metadata.

    Model source, in priority order:
      1. The HF Hub repo named by the ``ARXIV_MODEL_REPO`` environment variable.
      2. The local directory ``DEFAULT_LOCAL_MODEL_DIR`` (``./model``).

    Label metadata is read from ``label_meta.json`` alongside the model when it
    is available. Each field falls back independently when absent:
      - ``id2label`` falls back to ``model.config.id2label``;
      - ``max_length`` falls back to ``MAX_LENGTH_FALLBACK``;
      - pretty names fall back to ``PRETTY_NAMES_FALLBACK``, then to the raw
        label itself.

    The result is cached with ``st.cache_resource``, so the expensive load runs
    once per process rather than on every Streamlit rerun and every
    ``predict`` call. Consequently a change to ``ARXIV_MODEL_REPO`` requires an
    app restart to take effect.

    Returns:
        dict with keys ``model``, ``tokenizer``, ``id2label`` (int -> label),
        ``pretty_names`` (label -> display name), ``max_length`` (int),
        ``source`` (human-readable origin) and ``device`` (str).

    Raises:
        FileNotFoundError: if no Hub repo is configured and the local model
            directory does not exist.
    """
    repo = os.environ.get(HF_REPO_ENV_VAR, "").strip()
    label_meta_path: Path | None = None

    if repo:
        source = f"HF Hub: {repo}"
        tokenizer = AutoTokenizer.from_pretrained(repo)
        model = AutoModelForSequenceClassification.from_pretrained(repo)
        # label_meta.json is not fetched by from_pretrained, so download it
        # explicitly. Best effort: any failure (file absent, huggingface_hub
        # missing, network error) just means we fall back to the model config.
        try:
            from huggingface_hub import hf_hub_download

            label_meta_path = Path(
                hf_hub_download(repo_id=repo, filename="label_meta.json")
            )
        except Exception:
            label_meta_path = None
    elif DEFAULT_LOCAL_MODEL_DIR.exists():
        source = f"local dir: {DEFAULT_LOCAL_MODEL_DIR}"
        tokenizer = AutoTokenizer.from_pretrained(DEFAULT_LOCAL_MODEL_DIR)
        model = AutoModelForSequenceClassification.from_pretrained(
            DEFAULT_LOCAL_MODEL_DIR
        )
        candidate = DEFAULT_LOCAL_MODEL_DIR / "label_meta.json"
        if candidate.exists():
            label_meta_path = candidate
    else:
        raise FileNotFoundError(
            "No model found. Either set environment variable "
            f"{HF_REPO_ENV_VAR} to a HuggingFace repo id, or place a trained "
            f"model in {DEFAULT_LOCAL_MODEL_DIR} (run train.ipynb)."
        )

    model.to(DEVICE)
    model.eval()

    meta: dict = {}
    if label_meta_path is not None:
        # Explicit UTF-8 so non-ASCII pretty names (like the em dashes used in
        # PRETTY_NAMES_FALLBACK) decode identically on every platform rather
        # than depending on the locale's default encoding.
        meta = json.loads(label_meta_path.read_text(encoding="utf-8"))

    # `or` (rather than a .get default) also covers keys present but null/empty.
    raw_id2label = meta.get("id2label") or model.config.id2label
    id2label = {int(k): v for k, v in raw_id2label.items()}
    max_length = int(meta.get("max_length") or MAX_LENGTH_FALLBACK)

    # Pretty name per label: label_meta.json first, then the fallback table,
    # then the raw label.
    meta_pretty = meta.get("pretty_names") or {}
    pretty_names = {
        lab: meta_pretty.get(lab, PRETTY_NAMES_FALLBACK.get(lab, lab))
        for lab in id2label.values()
    }

    return {
        "model": model,
        "tokenizer": tokenizer,
        "id2label": id2label,
        "pretty_names": pretty_names,
        "max_length": max_length,
        "source": source,
        "device": str(DEVICE),
    }
| # --------------------------------------------------------------------------- | |
| # Inference | |
| # --------------------------------------------------------------------------- | |
def build_input_text(title: str, abstract: str) -> str:
    """Join a title and an optional abstract into the model's input string.

    Mirrors the training notebook's ``title + ". " + abstract`` format so that
    inference text matches the training distribution. Both parts are
    whitespace-stripped first. When the abstract is blank, only the stripped
    title is returned; title-only inputs still work, with less context.

    >>> build_input_text(" A title ", " An abstract. ")
    'A title. An abstract.'
    >>> build_input_text("A title", "   ")
    'A title'
    """
    pieces = [title.strip()]
    body = abstract.strip()
    if body:
        pieces.append(body)
    return ". ".join(pieces)
def predict(
    title: str, abstract: str, top_p: float = TOP_P_DEFAULT
) -> List[Tuple[str, str, float]]:
    """Classify a paper and return the smallest set of topics covering top_p.

    Args:
        title: Paper title.
        abstract: Paper abstract; may be empty. It is combined with the title
            via ``build_input_text``.
        top_p: Probability mass to cover. Labels are added in descending
            probability until their cumulative probability reaches ``top_p``,
            or until every label has been added.

    Returns:
        ``[(label, pretty_name, prob), ...]`` sorted by descending ``prob``.
        Always contains at least the single most likely label.
    """
    bundle = load_model_and_tokenizer()
    model = bundle["model"]
    tokenizer = bundle["tokenizer"]
    id2label = bundle["id2label"]
    pretty_names = bundle["pretty_names"]

    enc = tokenizer(
        build_input_text(title, abstract),
        truncation=True,
        max_length=bundle["max_length"],
        return_tensors="pt",
    )
    enc = {k: v.to(DEVICE) for k, v in enc.items()}

    # Pure inference: disable autograd. Parameters loaded by from_pretrained
    # require grad (eval() does not change that), so without this the logits
    # carry a grad graph and `.numpy()` below raises a RuntimeError. It also
    # avoids building the autograd graph, saving time and memory.
    with torch.inference_mode():
        logits = model(**enc).logits[0]
        probs = F.softmax(logits, dim=-1).cpu().numpy()

    cumulative = 0.0
    out: List[Tuple[str, str, float]] = []
    # Highest probability first; append before checking so the top label is
    # always included even if top_p <= its probability.
    for idx in probs.argsort()[::-1]:
        label = id2label[int(idx)]
        prob = float(probs[idx])
        out.append((label, pretty_names.get(label, label), prob))
        cumulative += prob
        if cumulative >= top_p:
            break
    return out
| # --------------------------------------------------------------------------- | |
| # UI | |
| # --------------------------------------------------------------------------- | |
def render_results(
    results: List[Tuple[str, str, float]], top_p: float = TOP_P_DEFAULT
) -> None:
    """Render predictions: a best-guess banner plus one row per topic.

    Args:
        results: ``(label, pretty_name, prob)`` tuples sorted by descending
            probability, as returned by ``predict``. Nothing is rendered when
            the list is empty.
        top_p: Coverage threshold that produced ``results``; used only in the
            caption. Defaults to ``TOP_P_DEFAULT``, which is also ``predict``'s
            default.
    """
    if not results:
        return

    st.subheader("Predicted topics")
    # Use the percent format spec instead of int(top_p * 100), which truncates
    # float error (int(0.29 * 100) == 28).
    st.caption(
        "Showing the smallest set of topics whose total probability is at least "
        f"{top_p:.0%}."
    )
    top_label, top_pretty, top_prob = results[0]
    st.success(f"**Best guess:** {top_pretty} · `{top_label}` · {top_prob:.1%}")
    for label, pretty, prob in results:
        col1, col2 = st.columns([3, 1])
        with col1:
            st.markdown(f"**{pretty}** · `{label}`")
            # st.progress expects a value in [0.0, 1.0]; clamp against float
            # rounding in the softmax output.
            st.progress(min(max(prob, 0.0), 1.0))
        with col2:
            st.metric(label=" ", value=f"{prob:.1%}", label_visibility="collapsed")
def _render_sidebar(bundle: dict) -> None:
    """Draw the sidebar: model info plus buttons that pre-fill the inputs."""
    with st.sidebar:
        st.header("About")
        st.markdown(
            "**Task.** Classify an arXiv paper into one of "
            f"{len(bundle['id2label'])} top-level categories.\n\n"
            "**Model.** Fine-tuned `distilbert-base-uncased`.\n\n"
            "**Input.** Title is required; abstract is optional but helps a lot."
        )
        st.caption(f"Model source: {bundle['source']}")
        st.caption(f"Inference device: `{bundle['device']}`")
        st.header("Try an example")
        for example in EXAMPLES:
            if st.button(example["name"], use_container_width=True):
                # Written before the keyed input widgets are created in this
                # run, so they pick these values up.
                st.session_state["title_input"] = example["title"]
                st.session_state["abstract_input"] = example["abstract"]


def _title_warning(title: str) -> str | None:
    """Return a user-facing warning when *title* is unusable, else ``None``."""
    cleaned = title.strip()
    if not cleaned:
        return (
            "Please enter at least a paper title. The abstract is optional, "
            "but the title is required."
        )
    if len(cleaned) < 5:
        return (
            "That title looks unusually short. Please paste the full paper "
            "title for a meaningful prediction."
        )
    return None


def main() -> None:
    """Build the Streamlit page and classify the input when requested.

    The whole page is re-rendered top to bottom on every rerun. The model is
    loaded before any input widgets are drawn so a misconfigured model source
    is reported immediately and the script halted with ``st.stop()``.
    Inference only happens on the rerun in which "Classify" was clicked.
    """
    st.set_page_config(
        page_title="arXiv Topic Classifier",
        page_icon=":bookmark_tabs:",
        layout="centered",
    )
    st.title("arXiv Topic Classifier")
    st.markdown(
        "Paste a paper's **title** (and optionally its **abstract**) — the "
        "model will tell you which arXiv categories it most likely belongs to. "
        "Powered by a fine-tuned DistilBERT."
    )

    try:
        bundle = load_model_and_tokenizer()
    except Exception as exc:
        st.error(
            "Failed to load the classification model.\n\n"
            f"**Reason:** {exc}\n\n"
            "If you are running locally, train a model with `train.ipynb` "
            "and re-launch. On HuggingFace Spaces, set the secret "
            f"`{HF_REPO_ENV_VAR}` to a model repo id."
        )
        st.stop()

    _render_sidebar(bundle)

    # Inputs are keyed into session_state so the sidebar examples can fill them.
    for state_key in ("title_input", "abstract_input"):
        st.session_state.setdefault(state_key, "")
    title = st.text_input(
        "Paper title",
        key="title_input",
        placeholder="e.g. Attention Is All You Need",
        help="Required. The full paper title.",
    )
    abstract = st.text_area(
        "Abstract (optional)",
        key="abstract_input",
        height=180,
        placeholder=(
            "Paste the paper abstract here. If you leave this empty the model "
            "will classify by title only — predictions will be less confident."
        ),
        help="Optional but strongly recommended.",
    )

    if not st.button("Classify", type="primary", use_container_width=True):
        return

    warning = _title_warning(title)
    if warning is not None:
        st.warning(warning)
        return

    try:
        with st.spinner("Classifying…"):
            results = predict(title, abstract)
    except Exception as exc:
        st.error(
            "Something went wrong while running the model.\n\n"
            f"**Details:** {exc}\n\n"
            "Try shortening the abstract or refresh the page."
        )
        return

    if results:
        render_results(results)
        if not abstract.strip():
            st.info(
                "Tip: paste the abstract for a noticeably more confident "
                "prediction."
            )
    else:
        st.error(
            "The model returned no predictions. This should not happen — "
            "please report it."
        )


# Run the app when this file is executed as a script.
if __name__ == "__main__":
    main()