File size: 13,291 Bytes
e40613d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
"""arXiv Topic Classifier — Streamlit web UI.

Fine-tuned DistilBERT predicts the top-level arxiv category for a paper given
its title and (optionally) abstract. The UI shows topics whose cumulative
probability covers >=95%, sorted by descending confidence.

The model is loaded from one of:
  1. HF Hub repo specified in env var ARXIV_MODEL_REPO (e.g. "user/arxiv-clf")
  2. Local directory ./model/ (produced by train.ipynb)

Set ARXIV_MODEL_REPO before launching to use a hosted model on HF Spaces.

Device selection is automatic: MPS (Apple Silicon) → CUDA → CPU. On HF Spaces
free tier this falls back to CPU.
"""

from __future__ import annotations

import json
import os
from pathlib import Path
from typing import List, Tuple

# Allow rare ops without an MPS kernel to fall back to CPU instead of crashing.
# Must be set before torch is imported, which is why these assignments sit
# above the third-party imports. setdefault() leaves alone any value that is
# already present in the environment.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
# NOTE(review): presumably turns off the tokenizers library's internal
# parallelism to avoid its fork-related warning — confirm.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Folder named "model" located next to this script; used when no Hub repo is
# configured.
DEFAULT_LOCAL_MODEL_DIR = Path(__file__).parent.joinpath("model")
# Name of the environment variable that may hold a Hugging Face Hub repo id.
HF_REPO_ENV_VAR = "ARXIV_MODEL_REPO"
# Tokenizer truncation length used when label metadata gives no max_length.
MAX_LENGTH_FALLBACK = 256
# Cumulative-probability threshold for the list of predicted topics.
TOP_P_DEFAULT = 0.95


def _select_device() -> torch.device:
    """Return the preferred torch device, trying MPS, then CUDA, then CPU.

    MPS is chosen only when torch reports it as both available and built.
    Otherwise CUDA is used when available, and CPU is the final fallback
    (the expected outcome on hosts with no accelerator).
    """
    mps_backend = torch.backends.mps
    if mps_backend.is_available() and mps_backend.is_built():
        device_name = "mps"
    elif torch.cuda.is_available():
        device_name = "cuda"
    else:
        device_name = "cpu"
    return torch.device(device_name)


# Resolved once at import time; model loading and inference both use it.
DEVICE = _select_device()

# Human-readable names for arxiv top-level categories, keyed by category code.
# Used as a fallback if label_meta.json does not provide pretty_names for some
# label; a label missing from this table as well is displayed as-is.
PRETTY_NAMES_FALLBACK: dict[str, str] = {
    "astro-ph": "Astrophysics",
    "cond-mat": "Condensed Matter Physics",
    "cs": "Computer Science",
    "econ": "Economics",
    "eess": "Electrical Engineering & Systems",
    "gr-qc": "General Relativity & Quantum Cosmology",
    "hep-ex": "High Energy Physics — Experiment",
    "hep-lat": "High Energy Physics — Lattice",
    "hep-ph": "High Energy Physics — Phenomenology",
    "hep-th": "High Energy Physics — Theory",
    "math": "Mathematics",
    "math-ph": "Mathematical Physics",
    "nlin": "Nonlinear Sciences",
    "nucl-ex": "Nuclear Physics — Experiment",
    "nucl-th": "Nuclear Physics — Theory",
    "physics": "Physics (general)",
    "q-bio": "Quantitative Biology",
    "q-fin": "Quantitative Finance",
    "quant-ph": "Quantum Physics",
    "stat": "Statistics",
}

# Sample papers offered as one-click buttons in the sidebar. "name" is the
# button caption; "title" and "abstract" are copied into the two input fields
# when the button is pressed.
EXAMPLES: list[dict[str, str]] = [
    {
        "name": "Transformers paper (cs)",
        "title": "Attention Is All You Need",
        "abstract": (
            "The dominant sequence transduction models are based on complex "
            "recurrent or convolutional neural networks that include an "
            "encoder and a decoder. We propose a new simple network "
            "architecture, the Transformer, based solely on attention "
            "mechanisms, dispensing with recurrence and convolutions entirely."
        ),
    },
    {
        "name": "Algebraic geometry (math)",
        "title": "On the Hodge conjecture for products of certain K3 surfaces",
        "abstract": (
            "We prove the Hodge conjecture for self-products of certain K3 "
            "surfaces using transcendental cycles, motivic methods, and "
            "Kuga-Satake constructions."
        ),
    },
    {
        "name": "TeV gamma astronomy (astro-ph)",
        "title": "Observation of TeV gamma rays from blazar Mrk 421 with VERITAS",
        "abstract": (
            "We report observations of very high energy gamma-ray emission "
            "from the blazar Markarian 421 conducted with the VERITAS array "
            "of imaging atmospheric Cherenkov telescopes."
        ),
    },
]


# ---------------------------------------------------------------------------
# Model loading (cached so we only do it once per server process)
# ---------------------------------------------------------------------------


@st.cache_resource(show_spinner="Loading model… (only happens once)")
def load_model_and_tokenizer():
    """Load the classifier, its tokenizer and the label metadata.

    The model source is resolved in this order:

    1. The Hugging Face Hub repo named by the ``ARXIV_MODEL_REPO``
       environment variable, when it is set to a non-blank value.
    2. The local ``model`` directory next to this file, when it exists.

    ``label_meta.json`` is optional in both cases. When it is absent, or when
    it carries no ``id2label`` table, the label mapping stored in the model
    config is used instead.

    Returns:
        A dict with keys ``model``, ``tokenizer``, ``id2label``,
        ``pretty_names``, ``max_length``, ``source`` and ``device``.

    Raises:
        FileNotFoundError: If no Hub repo is configured and the local model
            directory does not exist.
    """
    repo = os.environ.get(HF_REPO_ENV_VAR, "").strip()
    source: str
    label_meta_path: Path | None = None

    if repo:
        # Hub: from_pretrained() fetches the model and tokenizer files;
        # label_meta.json is an extra file, so it is requested separately.
        source = f"HF Hub: {repo}"
        tokenizer = AutoTokenizer.from_pretrained(repo)
        model = AutoModelForSequenceClassification.from_pretrained(repo)
        try:
            from huggingface_hub import hf_hub_download

            label_meta_path = Path(
                hf_hub_download(repo_id=repo, filename="label_meta.json")
            )
        except Exception:
            # Deliberate best-effort: the metadata file is optional, so any
            # failure to obtain it just means the model config is used below.
            label_meta_path = None
    elif DEFAULT_LOCAL_MODEL_DIR.exists():
        source = f"local dir: {DEFAULT_LOCAL_MODEL_DIR}"
        tokenizer = AutoTokenizer.from_pretrained(DEFAULT_LOCAL_MODEL_DIR)
        model = AutoModelForSequenceClassification.from_pretrained(
            DEFAULT_LOCAL_MODEL_DIR
        )
        candidate = DEFAULT_LOCAL_MODEL_DIR / "label_meta.json"
        label_meta_path = candidate if candidate.exists() else None
    else:
        raise FileNotFoundError(
            "No model found. Either set environment variable "
            f"{HF_REPO_ENV_VAR} to a HuggingFace repo id, or place a trained "
            f"model in {DEFAULT_LOCAL_MODEL_DIR} (run train.ipynb)."
        )

    model.to(DEVICE)
    model.eval()

    # Resolve labels: prefer label_meta.json, fall back to model config.
    meta: dict = {}
    if label_meta_path is not None:
        # Decode explicitly as UTF-8: display names may contain non-ASCII
        # characters and the platform default encoding is not always UTF-8.
        meta = json.loads(label_meta_path.read_text(encoding="utf-8"))

    # A metadata file without an id2label table must not abort loading; the
    # model config carries its own mapping.
    raw_id2label = meta.get("id2label") or model.config.id2label
    id2label = {int(k): v for k, v in raw_id2label.items()}
    pretty_overrides = meta.get("pretty_names", {})
    max_length = int(meta.get("max_length", MAX_LENGTH_FALLBACK))

    # Fill in any missing pretty names with our fallback table, and finally
    # with the raw label itself.
    pretty_names = {
        lab: pretty_overrides.get(lab, PRETTY_NAMES_FALLBACK.get(lab, lab))
        for lab in id2label.values()
    }

    return {
        "model": model,
        "tokenizer": tokenizer,
        "id2label": id2label,
        "pretty_names": pretty_names,
        "max_length": max_length,
        "source": source,
        "device": str(DEVICE),
    }


# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------


def build_input_text(title: str, abstract: str) -> str:
    """Join title and abstract into the single string fed to the tokenizer.

    Both parts are stripped of surrounding whitespace. With a non-empty
    abstract the result is ``"<title>. <abstract>"``, the same layout the
    training notebook is said to have used; with an empty abstract only the
    stripped title is returned.
    """
    head = title.strip()
    body = abstract.strip()
    return head if not body else head + ". " + body


@torch.inference_mode()
def predict(
    title: str, abstract: str, top_p: float = TOP_P_DEFAULT
) -> List[Tuple[str, str, float]]:
    """Classify one paper and return its most probable categories.

    Args:
        title: Paper title.
        abstract: Paper abstract; may be empty.
        top_p: Cumulative-probability threshold at which the list is cut.

    Returns:
        ``(label, pretty_name, prob)`` tuples in descending probability,
        ending with the first entry that brings the running total to at
        least ``top_p``. If the total never reaches ``top_p``, every label
        is returned.
    """
    bundle = load_model_and_tokenizer()
    id2label = bundle["id2label"]
    pretty_names = bundle["pretty_names"]

    encoded = bundle["tokenizer"](
        build_input_text(title, abstract),
        truncation=True,
        max_length=bundle["max_length"],
        return_tensors="pt",
    )
    # Move every encoded tensor to the device the model was placed on.
    on_device = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}
    # A single text was encoded, so only the first row of logits is used.
    row_logits = bundle["model"](**on_device).logits[0]
    probs = F.softmax(row_logits, dim=-1).cpu().numpy()

    ranked: list[tuple[str, str, float]] = []
    covered = 0.0
    for class_id in probs.argsort()[::-1]:
        label = id2label[int(class_id)]
        share = float(probs[class_id])
        ranked.append((label, pretty_names.get(label, label), share))
        covered += share
        if covered >= top_p:
            break
    return ranked


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------


def render_results(
    results: List[Tuple[str, str, float]], top_p: float = TOP_P_DEFAULT
) -> None:
    """Render the predicted topics: a best-guess banner plus one row each.

    Args:
        results: ``(label, pretty_name, prob)`` tuples; the first entry is
            presented as the best guess.
        top_p: Cumulative-probability threshold quoted in the caption. Pass
            the value that was given to ``predict`` when it differs from the
            default, so the caption matches what was actually computed.
    """
    if not results:
        # Nothing to unpack as a best guess; say so instead of raising.
        st.info("No predictions to display.")
        return

    st.subheader("Predicted topics")
    st.caption(
        "Showing the smallest set of topics whose total probability is at least "
        # ``:.0%`` rounds, whereas int(x * 100) would truncate values such
        # as 0.57 down to 56.
        f"{top_p:.0%}."
    )

    top_label, top_pretty, top_prob = results[0]
    st.success(f"**Best guess:** {top_pretty}  ·  `{top_label}`  ·  {top_prob:.1%}")

    for label, pretty, prob in results:
        col1, col2 = st.columns([3, 1])
        with col1:
            st.markdown(f"**{pretty}**  ·  `{label}`")
            # Clamp to [0, 1] before passing the value to the progress bar.
            st.progress(min(max(prob, 0.0), 1.0))
        with col2:
            st.metric(label=" ", value=f"{prob:.1%}", label_visibility="collapsed")


def main() -> None:
    """Build the page: header, sidebar, input form and, on demand, results.

    The model is loaded first so that configuration problems are reported
    before any input widgets appear. Classification runs only in the script
    run where the "Classify" button was pressed and the title passed
    validation.
    """
    st.set_page_config(
        page_title="arXiv Topic Classifier",
        page_icon=":bookmark_tabs:",
        layout="centered",
    )

    st.title("arXiv Topic Classifier")
    st.markdown(
        "Paste a paper's **title** (and optionally its **abstract**) — the "
        "model will tell you which arXiv categories it most likely belongs to. "
        "Powered by a fine-tuned DistilBERT."
    )

    # Try to load model up front so config issues surface immediately.
    try:
        bundle = load_model_and_tokenizer()
    except Exception as exc:
        st.error(
            "Failed to load the classification model.\n\n"
            f"**Reason:** {exc}\n\n"
            "If you are running locally, train a model with `train.ipynb` "
            "and re-launch. On HuggingFace Spaces, set the secret "
            f"`{HF_REPO_ENV_VAR}` to a model repo id."
        )
        # Nothing below can work without a model, so end this run here.
        st.stop()

    with st.sidebar:
        st.header("About")
        st.markdown(
            "**Task.** Classify an arXiv paper into one of "
            f"{len(bundle['id2label'])} top-level categories.\n\n"
            "**Model.** Fine-tuned `distilbert-base-uncased`.\n\n"
            "**Input.** Title is required; abstract is optional but helps a lot."
        )
        st.caption(f"Model source: {bundle['source']}")
        st.caption(f"Inference device: `{bundle['device']}`")

        st.header("Try an example")
        # NOTE(review): these session_state writes run before the widgets
        # keyed "title_input" / "abstract_input" are created further down.
        # Streamlit is understood to reject writes to a widget's key once the
        # widget exists in the current run, so keep this ordering — confirm.
        for ex in EXAMPLES:
            if st.button(ex["name"], use_container_width=True):
                st.session_state["title_input"] = ex["title"]
                st.session_state["abstract_input"] = ex["abstract"]

    # Make sure both keys exist before the keyed widgets below read them.
    st.session_state.setdefault("title_input", "")
    st.session_state.setdefault("abstract_input", "")

    title = st.text_input(
        "Paper title",
        key="title_input",
        placeholder="e.g. Attention Is All You Need",
        help="Required. The full paper title.",
    )
    abstract = st.text_area(
        "Abstract (optional)",
        key="abstract_input",
        height=180,
        placeholder=(
            "Paste the paper abstract here. If you leave this empty the model "
            "will classify by title only — predictions will be less confident."
        ),
        help="Optional but strongly recommended.",
    )

    classify_clicked = st.button("Classify", type="primary", use_container_width=True)

    # Nothing more to render until the user presses the button.
    if not classify_clicked:
        return

    # Input validation
    if not title.strip():
        st.warning(
            "Please enter at least a paper title. The abstract is optional, "
            "but the title is required."
        )
        return
    # Titles shorter than five characters (after stripping) are rejected.
    if len(title.strip()) < 5:
        st.warning(
            "That title looks unusually short. Please paste the full paper "
            "title for a meaningful prediction."
        )
        return

    # Inference
    # Broad catch on purpose: this is the UI boundary, so any failure is shown
    # to the user as an error box rather than propagating.
    try:
        with st.spinner("Classifying…"):
            results = predict(title, abstract)
    except Exception as exc:
        st.error(
            "Something went wrong while running the model.\n\n"
            f"**Details:** {exc}\n\n"
            "Try shortening the abstract or refresh the page."
        )
        return

    # Defensive check: render_results() indexes results[0].
    if not results:
        st.error("The model returned no predictions. This should not happen — "
                 "please report it.")
        return

    render_results(results)

    # Nudge the user towards supplying an abstract when only a title was given.
    if not abstract.strip():
        st.info(
            "Tip: paste the abstract for a noticeably more confident "
            "prediction."
        )


# Build the page when this file is executed as the main script.
if __name__ == "__main__":
    main()