Spaces:
Running
Running
"""Gradio Space app — arXiv CS sub-field classifier.

Loads the trained classifier + label encoder + config from the HF model repo
and exposes a simple title+abstract -> top-k labels interface.

Supports two model flavours via ``config.json["feature_kind"]``:

* ``embedding`` — uses a SentenceTransformer named in config
* ``tfidf`` — uses the vectoriser bundled inside classifier.joblib
"""
from __future__ import annotations

import os

# Windows-only workaround for the MKL/OpenMP DLL clash between conda numpy
# and pip torch: import torch first, allow duplicate OpenMP runtimes.
# setdefault() leaves any value already present in the environment alone.
# The variable is set on every platform; presumably it has no effect where
# only one OpenMP runtime gets loaded (e.g. Linux / HF Space).
os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE")
# Kept above the remaining imports so torch really is loaded first, as the
# workaround above requires. Best-effort: the tfidf flavour never uses torch,
# so any import failure is deliberately ignored.
try:
    import torch  # noqa: F401
except Exception:
    pass

import json
import logging
from pathlib import Path

import gradio as gr
import joblib
from huggingface_hub import hf_hub_download

# HF model repo that holds config.json, classifier.joblib and label_encoder.joblib.
MODEL_REPO = "Tiansve/arxiv-subfield-linear-probe"
# classify() asks for more text when the abstract has fewer words than this.
MIN_ABSTRACT_WORDS = 20

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("space")
# ---- Load artefacts ------------------------------------------------------- #
# SECURITY: joblib.load unpickles its input and can execute arbitrary code.
# Only point MODEL_REPO at a repository you trust.
config_path = hf_hub_download(MODEL_REPO, "config.json")
with open(config_path, "r", encoding="utf-8") as f:
    cfg = json.load(f)
clf = joblib.load(hf_hub_download(MODEL_REPO, "classifier.joblib"))
le = joblib.load(hf_hub_download(MODEL_REPO, "label_encoder.joblib"))

FEATURE_KIND = cfg["feature_kind"]
EMBEDDER_NAME = cfg.get("embedder")
CLASS_NAMES = list(le.classes_)

embedder = None
vectorizer = None
classifier = clf  # rebound below for tfidf

if FEATURE_KIND == "embedding":
    if not EMBEDDER_NAME:
        # SentenceTransformer(None / "") would not load any model, so the error
        # would only surface at the first request. Fail at start-up instead.
        raise ValueError(
            "config.json has feature_kind='embedding' but no 'embedder' model name"
        )
    from sentence_transformers import SentenceTransformer

    log.info("Loading SentenceTransformer: %s", EMBEDDER_NAME)
    embedder = SentenceTransformer(EMBEDDER_NAME)
elif FEATURE_KIND == "tfidf":
    # classifier.joblib for tfidf is a dict {"vectorizer", "classifier"}
    if not isinstance(clf, dict) or "vectorizer" not in clf or "classifier" not in clf:
        raise ValueError(
            "tfidf classifier.joblib must be a dict with 'vectorizer' and "
            "'classifier' keys"
        )
    vectorizer = clf["vectorizer"]
    classifier = clf["classifier"]
else:
    raise ValueError(f"Unknown feature_kind: {FEATURE_KIND!r}")

# classify() labels predict_proba column i with CLASS_NAMES[i]. If the class
# counts disagree, that pairing either raises per request or silently drops
# classes, so refuse to start.
if len(getattr(classifier, "classes_", CLASS_NAMES)) != len(CLASS_NAMES):
    raise ValueError(
        f"Classifier has {len(classifier.classes_)} classes but the label "
        f"encoder has {len(CLASS_NAMES)}"
    )
# ---- Inference ------------------------------------------------------------ #


def _featurise(text: str):
    """Encode one piece of text into the single-row feature matrix for the model.

    Dispatches on the module-level ``FEATURE_KIND``. For ``"embedding"`` it uses
    the SentenceTransformer ``embedder`` with normalised embeddings. Otherwise it
    uses the ``vectorizer`` taken from the tfidf bundle.
    """
    batch = [text]  # both back-ends take a list of documents
    if FEATURE_KIND != "embedding":
        return vectorizer.transform(batch)
    return embedder.encode(batch, normalize_embeddings=True)
def classify(title: str, abstract: str):
    """Predict CS sub-field probabilities for one paper.

    Args:
        title: Paper title. Optional: ``None`` or blank counts as empty.
        abstract: Paper abstract. It must contain at least
            ``MIN_ABSTRACT_WORDS`` whitespace-separated words.

    Returns:
        A ``{class_name: probability}`` dict covering every entry of
        ``CLASS_NAMES`` (the shape ``gr.Label`` renders). If the abstract is
        too short, returns a single hint entry with value 1.0 instead.
    """
    title_text = (title or "").strip()
    abstract_text = (abstract or "").strip()

    too_short = not abstract_text or len(abstract_text.split()) < MIN_ABSTRACT_WORDS
    if too_short:
        hint = f"(paste at least {MIN_ABSTRACT_WORDS} words of abstract)"
        return {hint: 1.0}

    # NOTE(review): strip(". ") trims any run of '.'/' ' at BOTH ends. Besides
    # removing the leftover separator when the title is empty, it also drops
    # the abstract's final full stop (and a leading '.' in the title).
    # Presumably this mirrors the training-time text construction; confirm
    # before changing.
    model_input = ". ".join((title_text, abstract_text)).strip(". ").strip()

    first_row = classifier.predict_proba(_featurise(model_input))[0]
    return {label: float(first_row[idx]) for idx, label in enumerate(CLASS_NAMES)}
| # ---- Examples ------------------------------------------------------------- # | |
| def _load_examples(): | |
| path = Path(__file__).parent / "examples.json" | |
| if not path.exists(): | |
| return [] | |
| data = json.loads(path.read_text(encoding="utf-8")) | |
| # Gradio Examples expects a list of [input1, input2, ...] matching the inputs order | |
| return [[row["title"], row["abstract"]] for row in data] | |
# ---- UI ------------------------------------------------------------------- #
# The category count is derived from CLASS_NAMES so the text cannot drift from
# the actual label set.
# NOTE(review): nothing in this Interface renders a link to the report
# mentioned below; add one (e.g. via ``article=``) or reword.
description = (
    "Paste an arXiv-style abstract and the model predicts which CS sub-field it belongs to.\n\n"
    f"Trained on ~4.1k recent abstracts across the {len(CLASS_NAMES)} categories "
    f"({', '.join(CLASS_NAMES)}). "
    f"Current model: **{cfg['config']}** "
    f"(test macro-F1 = {cfg['test_macro_f1']:.3f}). "
    "Predictions are noisy near class boundaries by design — see the linked report for analysis."
)

demo = gr.Interface(
    fn=classify,
    inputs=[
        gr.Textbox(label="Title", placeholder="Paper title (optional)"),
        gr.Textbox(label="Abstract", lines=10, placeholder="Paste the abstract here..."),
    ],
    # Show only the 4 most probable sub-fields.
    outputs=gr.Label(num_top_classes=4, label="Predicted sub-field"),
    title="arXiv CS Sub-field Classifier",
    description=description,
    examples=_load_examples(),
    # Don't pre-run the model over every example when the app starts.
    cache_examples=False,
    # NOTE(review): newer Gradio releases replace allow_flagging with
    # flagging_mode; check the pinned Gradio version before upgrading.
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()