Dmitry057's picture
Add app.py
e40613d verified
"""arXiv Topic Classifier — Streamlit web UI.
Fine-tuned DistilBERT predicts the top-level arxiv category for a paper given
its title and (optionally) abstract. The UI shows topics whose cumulative
probability covers >=95%, sorted by descending confidence.
The model is loaded from one of:
1. HF Hub repo specified in env var ARXIV_MODEL_REPO (e.g. "user/arxiv-clf")
2. Local directory ./model/ (produced by train.ipynb)
Set ARXIV_MODEL_REPO before launching to use a hosted model on HF Spaces.
Device selection is automatic: MPS (Apple Silicon) → CUDA → CPU. On HF Spaces
free tier this falls back to CPU.
"""
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import List, Tuple
# Allow rare ops without an MPS kernel to fall back to CPU instead of crashing.
# Must be set before torch is imported.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Directory next to this file that is checked for a locally trained model
# when no Hub repo is configured.
DEFAULT_LOCAL_MODEL_DIR = Path(__file__).parent / "model"
# Name of the environment variable holding a HF Hub repo id. When it is set
# to a non-empty value it takes precedence over the local directory.
HF_REPO_ENV_VAR = "ARXIV_MODEL_REPO"
# Tokenizer truncation length used when label_meta.json is unavailable or
# does not provide "max_length".
MAX_LENGTH_FALLBACK = 256
# Default cumulative-probability threshold for predict(); also the value
# quoted in the results caption.
TOP_P_DEFAULT = 0.95
def _select_device() -> torch.device:
"""Pick the best available device: MPS → CUDA → CPU.
On Apple Silicon (M1/M2/M3) MPS gives a major speedup over CPU. On HF
Spaces free tier neither MPS nor CUDA is available so we fall back to
CPU automatically.
"""
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
return torch.device("mps")
if torch.cuda.is_available():
return torch.device("cuda")
return torch.device("cpu")
# Resolved once at import time. The loaded model and every tokenized input
# are moved to this device.
DEVICE = _select_device()
# Human-readable names for arXiv top-level categories, keyed by the raw
# category label. Consulted only for labels that label_meta.json's
# "pretty_names" does not cover; a label missing from both places is shown
# as-is.
PRETTY_NAMES_FALLBACK = {
"astro-ph": "Astrophysics",
"cond-mat": "Condensed Matter Physics",
"cs": "Computer Science",
"econ": "Economics",
"eess": "Electrical Engineering & Systems",
"gr-qc": "General Relativity & Quantum Cosmology",
"hep-ex": "High Energy Physics — Experiment",
"hep-lat": "High Energy Physics — Lattice",
"hep-ph": "High Energy Physics — Phenomenology",
"hep-th": "High Energy Physics — Theory",
"math": "Mathematics",
"math-ph": "Mathematical Physics",
"nlin": "Nonlinear Sciences",
"nucl-ex": "Nuclear Physics — Experiment",
"nucl-th": "Nuclear Physics — Theory",
"physics": "Physics (general)",
"q-bio": "Quantitative Biology",
"q-fin": "Quantitative Finance",
"quant-ph": "Quantum Physics",
"stat": "Statistics",
}
# Sample inputs offered as sidebar buttons. Each entry has:
#   "name"     - the button label,
#   "title"    - text copied into the title field when the button is clicked,
#   "abstract" - text copied into the abstract field when the button is clicked.
EXAMPLES = [
{
"name": "Transformers paper (cs)",
"title": "Attention Is All You Need",
"abstract": (
"The dominant sequence transduction models are based on complex "
"recurrent or convolutional neural networks that include an "
"encoder and a decoder. We propose a new simple network "
"architecture, the Transformer, based solely on attention "
"mechanisms, dispensing with recurrence and convolutions entirely."
),
},
{
"name": "Algebraic geometry (math)",
"title": "On the Hodge conjecture for products of certain K3 surfaces",
"abstract": (
"We prove the Hodge conjecture for self-products of certain K3 "
"surfaces using transcendental cycles, motivic methods, and "
"Kuga-Satake constructions."
),
},
{
"name": "TeV gamma astronomy (astro-ph)",
"title": "Observation of TeV gamma rays from blazar Mrk 421 with VERITAS",
"abstract": (
"We report observations of very high energy gamma-ray emission "
"from the blazar Markarian 421 conducted with the VERITAS array "
"of imaging atmospheric Cherenkov telescopes."
),
},
]
# ---------------------------------------------------------------------------
# Model loading (cached so we only do it once per session)
# ---------------------------------------------------------------------------
def _resolve_label_info(label_meta_path, config_id2label):
    """Derive (id2label, pretty_names, max_length) for the loaded model.

    Args:
        label_meta_path: Path to label_meta.json, or None when the file is
            not available.
        config_id2label: The ``id2label`` mapping from the model config,
            used when label_meta.json is absent or has no "id2label" entry.

    Returns:
        A tuple ``(id2label, pretty_names, max_length)`` where ``id2label``
        maps int class ids to labels, ``pretty_names`` maps every label to a
        display name, and ``max_length`` is the tokenizer truncation length.
    """
    meta = {}
    if label_meta_path is not None:
        # Decode explicitly as UTF-8: display names may contain non-ASCII
        # characters (the fallback table itself uses em dashes) and the
        # platform default encoding is not UTF-8 everywhere.
        meta = json.loads(label_meta_path.read_text(encoding="utf-8"))

    # Prefer label_meta.json, fall back to the model config when the file is
    # missing or does not carry an "id2label" mapping.
    raw_id2label = meta.get("id2label") or config_id2label
    id2label = {int(idx): name for idx, name in raw_id2label.items()}

    declared_names = meta.get("pretty_names") or {}
    raw_max_length = meta.get("max_length")
    max_length = (
        MAX_LENGTH_FALLBACK if raw_max_length is None else int(raw_max_length)
    )

    # Fill in any missing pretty names from the built-in table; a label that
    # is unknown there too is displayed verbatim.
    pretty_names = {
        lab: declared_names.get(lab, PRETTY_NAMES_FALLBACK.get(lab, lab))
        for lab in id2label.values()
    }
    return id2label, pretty_names, max_length


@st.cache_resource(show_spinner="Loading model… (only happens once)")
def load_model_and_tokenizer() -> dict:
    """Load model + tokenizer + label metadata.

    The model comes from the HF Hub repo named by the ARXIV_MODEL_REPO
    environment variable when that is set to a non-empty value, otherwise
    from the local ``model`` directory next to this file.

    Returns:
        A dict with keys: model, tokenizer, id2label, pretty_names,
        max_length, source, device.

    Raises:
        FileNotFoundError: If no repo is configured and the local model
            directory does not exist.
    """
    repo = os.environ.get(HF_REPO_ENV_VAR, "").strip()
    label_meta_path: Path | None = None

    if repo:
        source = f"HF Hub: {repo}"
        tokenizer = AutoTokenizer.from_pretrained(repo)
        model = AutoModelForSequenceClassification.from_pretrained(repo)
        # label_meta.json is an optional extra file, so it is fetched
        # separately. This is deliberately best-effort: a missing file, a
        # network problem, or an absent huggingface_hub package all simply
        # mean "use the model config for labels".
        try:
            from huggingface_hub import hf_hub_download

            label_meta_path = Path(
                hf_hub_download(repo_id=repo, filename="label_meta.json")
            )
        except Exception:
            label_meta_path = None
    elif DEFAULT_LOCAL_MODEL_DIR.exists():
        source = f"local dir: {DEFAULT_LOCAL_MODEL_DIR}"
        tokenizer = AutoTokenizer.from_pretrained(DEFAULT_LOCAL_MODEL_DIR)
        model = AutoModelForSequenceClassification.from_pretrained(
            DEFAULT_LOCAL_MODEL_DIR
        )
        candidate = DEFAULT_LOCAL_MODEL_DIR / "label_meta.json"
        label_meta_path = candidate if candidate.exists() else None
    else:
        raise FileNotFoundError(
            "No model found. Either set environment variable "
            f"{HF_REPO_ENV_VAR} to a HuggingFace repo id, or place a trained "
            f"model in {DEFAULT_LOCAL_MODEL_DIR} (run train.ipynb)."
        )

    model.to(DEVICE)
    model.eval()

    id2label, pretty_names, max_length = _resolve_label_info(
        label_meta_path, model.config.id2label
    )

    return {
        "model": model,
        "tokenizer": tokenizer,
        "id2label": id2label,
        "pretty_names": pretty_names,
        "max_length": max_length,
        "source": source,
        "device": str(DEVICE),
    }
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def build_input_text(title: str, abstract: str) -> str:
    """Join the title and the optional abstract into the model's input string.

    Both pieces are stripped of surrounding whitespace. When the abstract is
    non-empty the result is ``"<title>. <abstract>"``, the same layout the
    training notebook used; otherwise the stripped title is returned alone.
    """
    pieces = [title.strip()]
    body = abstract.strip()
    if body:
        pieces.append(body)
    return ". ".join(pieces)
@torch.inference_mode()
def predict(
    title: str, abstract: str, top_p: float = TOP_P_DEFAULT
) -> List[Tuple[str, str, float]]:
    """Classify one paper and return its most probable topics.

    Args:
        title: Paper title.
        abstract: Paper abstract; may be empty.
        top_p: Cumulative-probability threshold at which the result list is
            cut off.

    Returns:
        ``(label, pretty_name, prob)`` tuples in descending probability
        order, ending with the first entry that brings the running total to
        at least ``top_p``. If the total never reaches ``top_p`` every class
        is returned.
    """
    bundle = load_model_and_tokenizer()
    id2label = bundle["id2label"]
    pretty_names = bundle["pretty_names"]

    encoded = bundle["tokenizer"](
        build_input_text(title, abstract),
        truncation=True,
        max_length=bundle["max_length"],
        return_tensors="pt",
    )
    model_inputs = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}

    # A single text is encoded, so row 0 holds the only set of logits.
    first_row_logits = bundle["model"](**model_inputs).logits[0]
    probs = F.softmax(first_row_logits, dim=-1).cpu().numpy()

    selected: list[tuple[str, str, float]] = []
    mass = 0.0
    # Walk class ids from most to least probable until enough mass is covered.
    for class_id in probs.argsort()[::-1]:
        label = id2label[int(class_id)]
        prob = float(probs[class_id])
        selected.append((label, pretty_names.get(label, label), prob))
        mass += prob
        if mass >= top_p:
            break
    return selected
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
def render_results(results: List[Tuple[str, str, float]]) -> None:
    """Render the predicted topics: a headline for the best one, then a row each.

    Args:
        results: ``(label, pretty_name, prob)`` tuples, best first. Must be
            non-empty, since the first entry is used for the headline.
    """
    st.subheader("Predicted topics")
    st.caption(
        "Showing the smallest set of topics whose total probability is at least "
        f"{int(TOP_P_DEFAULT * 100)}%."
    )

    top_label, top_pretty, top_prob = results[0]
    st.success(f"**Best guess:** {top_pretty} · `{top_label}` · {top_prob:.1%}")

    for label, pretty, prob in results:
        # Wide column for the name and bar, narrow column for the percentage.
        name_col, value_col = st.columns([3, 1])

        name_col.markdown(f"**{pretty}** · `{label}`")
        # Clamp to [0, 1] before handing the value to the progress bar.
        bar_fill = min(max(prob, 0.0), 1.0)
        name_col.progress(bar_fill)

        value_col.metric(
            label=" ", value=f"{prob:.1%}", label_visibility="collapsed"
        )
def _render_sidebar(bundle) -> None:
    """Draw the sidebar: model information and the example buttons."""
    with st.sidebar:
        st.header("About")
        st.markdown(
            "**Task.** Classify an arXiv paper into one of "
            f"{len(bundle['id2label'])} top-level categories.\n\n"
            "**Model.** Fine-tuned `distilbert-base-uncased`.\n\n"
            "**Input.** Title is required; abstract is optional but helps a lot."
        )
        st.caption(f"Model source: {bundle['source']}")
        st.caption(f"Inference device: `{bundle['device']}`")

        st.header("Try an example")
        for ex in EXAMPLES:
            was_clicked = st.button(ex["name"], use_container_width=True)
            if was_clicked:
                # Pre-fill the keyed input widgets created later in the run.
                st.session_state["title_input"] = ex["title"]
                st.session_state["abstract_input"] = ex["abstract"]


def _title_problem(title: str):
    """Return a warning message for an unusable title, or None if it is fine."""
    cleaned = title.strip()
    if not cleaned:
        return (
            "Please enter at least a paper title. The abstract is optional, "
            "but the title is required."
        )
    if len(cleaned) < 5:
        return (
            "That title looks unusually short. Please paste the full paper "
            "title for a meaningful prediction."
        )
    return None


def main() -> None:
    """Build the page: header, sidebar, input form, and prediction output."""
    st.set_page_config(
        page_title="arXiv Topic Classifier",
        page_icon=":bookmark_tabs:",
        layout="centered",
    )
    st.title("arXiv Topic Classifier")
    st.markdown(
        "Paste a paper's **title** (and optionally its **abstract**) — the "
        "model will tell you which arXiv categories it most likely belongs to. "
        "Powered by a fine-tuned DistilBERT."
    )

    # Load the model before drawing anything else so that configuration
    # problems are reported right away.
    try:
        bundle = load_model_and_tokenizer()
    except Exception as exc:
        st.error(
            "Failed to load the classification model.\n\n"
            f"**Reason:** {exc}\n\n"
            "If you are running locally, train a model with `train.ipynb` "
            "and re-launch. On HuggingFace Spaces, set the secret "
            f"`{HF_REPO_ENV_VAR}` to a model repo id."
        )
        st.stop()

    _render_sidebar(bundle)

    # Make sure both widget keys exist before the widgets are created.
    for state_key in ("title_input", "abstract_input"):
        st.session_state.setdefault(state_key, "")

    title = st.text_input(
        "Paper title",
        key="title_input",
        placeholder="e.g. Attention Is All You Need",
        help="Required. The full paper title.",
    )
    abstract = st.text_area(
        "Abstract (optional)",
        key="abstract_input",
        height=180,
        placeholder=(
            "Paste the paper abstract here. If you leave this empty the model "
            "will classify by title only — predictions will be less confident."
        ),
        help="Optional but strongly recommended.",
    )

    # Nothing further to do until the user asks for a classification.
    if not st.button("Classify", type="primary", use_container_width=True):
        return

    # Input validation
    complaint = _title_problem(title)
    if complaint is not None:
        st.warning(complaint)
        return

    # Inference
    try:
        with st.spinner("Classifying…"):
            results = predict(title, abstract)
    except Exception as exc:
        st.error(
            "Something went wrong while running the model.\n\n"
            f"**Details:** {exc}\n\n"
            "Try shortening the abstract or refresh the page."
        )
        return

    if not results:
        st.error("The model returned no predictions. This should not happen — "
                 "please report it.")
        return

    render_results(results)

    abstract_is_blank = not abstract.strip()
    if abstract_is_blank:
        st.info(
            "Tip: paste the abstract for a noticeably more confident "
            "prediction."
        )


if __name__ == "__main__":
    main()