# paper-classifier / inference.py
# Uploaded by gr8monk3ys via huggingface_hub (revision 2be4558, verified)
#!/usr/bin/env python3
"""
Inference script for the Academic Paper Classifier.
Loads a fine-tuned DistilBERT model and predicts the arxiv category for a
given paper abstract. Returns the predicted category along with per-class
confidence scores.
Usage examples:
# Use a local model directory
python inference.py --model_path ./model --abstract "We propose a novel ..."
# Use a HuggingFace Hub model
python inference.py --model_path gr8monk3ys/paper-classifier-model \
--abstract "We propose a novel ..."
# Interactive mode (reads from stdin)
python inference.py --model_path ./model
Author: Lorenzo Scaturchio (gr8monk3ys)
License: MIT
"""
import argparse
import json
import logging
import sys
from pathlib import Path
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
# Configure the root logger once at import time. All records are routed to
# stdout (rather than the default stderr) so log lines interleave cleanly
# with this script's print() output.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
# Module-level logger used throughout this script.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Classifier wrapper
# ---------------------------------------------------------------------------
class PaperClassifier:
"""Thin wrapper around a fine-tuned sequence-classification model.
Parameters
----------
model_path : str
Path to a local model directory **or** a HuggingFace Hub model id.
device : str | None
Target device (``"cpu"``, ``"cuda"``, ``"mps"``). If *None* the best
available device is selected automatically.
"""
def __init__(self, model_path: str, device: str | None = None) -> None:
if device is None:
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
self.device = torch.device(device)
logger.info("Loading tokenizer from: %s", model_path)
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
logger.info("Loading model from: %s", model_path)
self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
self.model.to(self.device)
# Read label mapping stored in the model config
self.id2label: dict[int, str] = self.model.config.id2label
logger.info("Labels: %s", list(self.id2label.values()))
@torch.no_grad()
def predict(self, abstract: str, top_k: int | None = None) -> dict:
"""Classify a single paper abstract.
Parameters
----------
abstract : str
The paper abstract to classify.
top_k : int | None
If given, only the *top_k* categories (by confidence) are returned
in ``scores``. Pass *None* to return all categories.
Returns
-------
dict
``{"label": str, "confidence": float, "scores": {label: prob}}``
"""
self.model.eval()
inputs = self.tokenizer(
abstract,
return_tensors="pt",
truncation=True,
padding=True,
max_length=512,
).to(self.device)
logits = self.model(**inputs).logits
probs = torch.softmax(logits, dim=-1).squeeze(0).cpu().numpy()
sorted_indices = probs.argsort()[::-1]
if top_k is not None:
sorted_indices = sorted_indices[:top_k]
scores = {
self.id2label[int(idx)]: float(probs[idx]) for idx in sorted_indices
}
best_idx = int(probs.argmax())
return {
"label": self.id2label[best_idx],
"confidence": float(probs[best_idx]),
"scores": scores,
}
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
    """Build the CLI parser and parse ``sys.argv`` into a namespace."""
    parser = argparse.ArgumentParser(
        description="Classify an academic paper abstract into an arxiv category."
    )
    add = parser.add_argument  # local alias: every option goes through it
    add(
        "--model_path",
        type=str,
        default="./model",
        help="Path to the fine-tuned model directory or HF Hub id (default: %(default)s).",
    )
    add(
        "--abstract",
        type=str,
        default=None,
        help="Paper abstract text. If omitted, the script enters interactive mode.",
    )
    add(
        "--top_k",
        type=int,
        default=None,
        help="Only show the top-k predictions (default: show all).",
    )
    add(
        "--device",
        type=str,
        default=None,
        choices=["cpu", "cuda", "mps"],
        help="Device to run inference on (default: auto-detect).",
    )
    add(
        "--json",
        action="store_true",
        default=False,
        dest="output_json",
        help="Output raw JSON instead of human-readable text.",
    )
    return parser.parse_args()
def _print_result(result: dict, output_json: bool) -> None:
    """Render a prediction result as raw JSON or as a small text report."""
    if output_json:
        print(json.dumps(result, indent=2))
        return
    # Assemble the human-readable report, then emit it in one write; the
    # joined output is byte-identical to printing line by line.
    rows = [
        f"\n Predicted category : {result['label']}",
        f" Confidence : {result['confidence']:.4f}",
        " ---------------------------------",
    ]
    for label, score in result["scores"].items():
        bar = "#" * int(score * 40)
        rows.append(f" {label:<10s} {score:6.4f} {bar}")
    rows.append("")  # trailing blank line, as the original final print()
    print("\n".join(rows))
def main() -> None:
    """Entry point: classify one abstract passed via ``--abstract``, or drop
    into an interactive read-classify loop when no abstract is supplied."""
    args = parse_args()
    classifier = PaperClassifier(model_path=args.model_path, device=args.device)
    # One-shot mode: an abstract was given on the command line.
    if args.abstract is not None:
        result = classifier.predict(args.abstract, top_k=args.top_k)
        _print_result(result, args.output_json)
        return
    # Interactive mode
    print("Academic Paper Classifier - Interactive Mode")
    print("Enter a paper abstract (or 'quit' to exit).")
    print("For multi-line input, end with an empty line.\n")
    while True:
        try:
            lines: list[str] = []
            # Only show prompts when attached to a real terminal, so piped
            # input stays clean.
            prompt = "abstract> " if sys.stdin.isatty() else ""
            while True:
                line = input(prompt)
                # 'quit' on any line (checked before the blank-line break)
                # ends the whole session.
                if line.strip().lower() == "quit":
                    logger.info("Exiting.")
                    return
                # A blank line terminates the abstract, but only after at
                # least one line (possibly blank) has been collected.
                if line == "" and lines:
                    break
                lines.append(line)
                # Continuation prompt after the first line.
                prompt = "... " if sys.stdin.isatty() else ""
            abstract = " ".join(lines).strip()
            # Nothing but whitespace was entered: prompt again.
            if not abstract:
                continue
            result = classifier.predict(abstract, top_k=args.top_k)
            _print_result(result, args.output_json)
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C ends the session gracefully.
            print()
            logger.info("Exiting.")
            return
if __name__ == "__main__":
    main()