|
|
|
|
|
""" |
|
|
Inference script for the Academic Paper Classifier. |
|
|
|
|
|
Loads a fine-tuned DistilBERT model and predicts the arxiv category for a |
|
|
given paper abstract. Returns the predicted category along with per-class |
|
|
confidence scores. |
|
|
|
|
|
Usage examples: |
|
|
# Use a local model directory |
|
|
python inference.py --model_path ./model --abstract "We propose a novel ..." |
|
|
|
|
|
# Use a HuggingFace Hub model |
|
|
python inference.py --model_path gr8monk3ys/paper-classifier-model \ |
|
|
--abstract "We propose a novel ..." |
|
|
|
|
|
# Interactive mode (reads from stdin) |
|
|
python inference.py --model_path ./model |
|
|
|
|
|
Author: Lorenzo Scaturchio (gr8monk3ys) |
|
|
License: MIT |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import json |
|
|
import logging |
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
import torch |
|
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Configure root logging once at import time: INFO level, timestamped records,
# emitted to stdout so output can be piped or captured alongside print() output.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
# Module-level logger named after this module, per the logging convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PaperClassifier:
    """Thin wrapper around a fine-tuned sequence-classification model.

    Parameters
    ----------
    model_path : str
        Path to a local model directory **or** a HuggingFace Hub model id.
    device : str | None
        Target device (``"cpu"``, ``"cuda"``, ``"mps"``). If *None* the best
        available device is selected automatically.
    """

    def __init__(self, model_path: str, device: str | None = None) -> None:
        chosen = self._pick_device() if device is None else device
        self.device = torch.device(chosen)

        logger.info("Loading tokenizer from: %s", model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        logger.info("Loading model from: %s", model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)

        # Mapping from class index to human-readable label, taken straight
        # from the model config so it always matches the checkpoint.
        self.id2label: dict[int, str] = self.model.config.id2label
        logger.info("Labels: %s", list(self.id2label.values()))

    @staticmethod
    def _pick_device() -> str:
        """Return the best available accelerator name, falling back to CPU."""
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    @torch.no_grad()
    def predict(self, abstract: str, top_k: int | None = None) -> dict:
        """Classify a single paper abstract.

        Parameters
        ----------
        abstract : str
            The paper abstract to classify.
        top_k : int | None
            If given, only the *top_k* categories (by confidence) are returned
            in ``scores``. Pass *None* to return all categories.

        Returns
        -------
        dict
            ``{"label": str, "confidence": float, "scores": {label: prob}}``
        """
        self.model.eval()

        encoded = self.tokenizer(
            abstract,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512,
        ).to(self.device)

        logits = self.model(**encoded).logits
        # Softmax over the class axis; move to CPU so numpy can consume it.
        probabilities = torch.softmax(logits, dim=-1).squeeze(0).cpu().numpy()

        # Indices ranked by descending confidence, optionally truncated.
        ranked = probabilities.argsort()[::-1]
        if top_k is not None:
            ranked = ranked[:top_k]

        scores: dict[str, float] = {}
        for class_idx in ranked:
            scores[self.id2label[int(class_idx)]] = float(probabilities[class_idx])

        winner = int(probabilities.argmax())
        return {
            "label": self.id2label[winner],
            "confidence": float(probabilities[winner]),
            "scores": scores,
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args() -> argparse.Namespace: |
|
|
parser = argparse.ArgumentParser( |
|
|
description="Classify an academic paper abstract into an arxiv category." |
|
|
) |
|
|
parser.add_argument( |
|
|
"--model_path", |
|
|
type=str, |
|
|
default="./model", |
|
|
help="Path to the fine-tuned model directory or HF Hub id (default: %(default)s).", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--abstract", |
|
|
type=str, |
|
|
default=None, |
|
|
help="Paper abstract text. If omitted, the script enters interactive mode.", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--top_k", |
|
|
type=int, |
|
|
default=None, |
|
|
help="Only show the top-k predictions (default: show all).", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--device", |
|
|
type=str, |
|
|
default=None, |
|
|
choices=["cpu", "cuda", "mps"], |
|
|
help="Device to run inference on (default: auto-detect).", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--json", |
|
|
action="store_true", |
|
|
default=False, |
|
|
dest="output_json", |
|
|
help="Output raw JSON instead of human-readable text.", |
|
|
) |
|
|
return parser.parse_args() |
|
|
|
|
|
|
|
|
def _print_result(result: dict, output_json: bool) -> None:
    """Pretty-print or JSON-dump a prediction result."""
    if output_json:
        print(json.dumps(result, indent=2))
        return

    header = [
        f"\n Predicted category : {result['label']}",
        f" Confidence : {result['confidence']:.4f}",
        " ---------------------------------",
    ]
    for text in header:
        print(text)
    # One row per category: label, probability, and a crude 40-char bar chart.
    for label, score in result["scores"].items():
        hashes = "#" * int(score * 40)
        print(f" {label:<10s} {score:6.4f} {hashes}")
    print()
|
|
|
|
|
|
|
|
def main() -> None:
    """CLI entry point: one-shot prediction, or an interactive stdin loop."""
    args = parse_args()
    classifier = PaperClassifier(model_path=args.model_path, device=args.device)

    # One-shot mode: an abstract was supplied on the command line.
    if args.abstract is not None:
        _print_result(classifier.predict(args.abstract, top_k=args.top_k), args.output_json)
        return

    print("Academic Paper Classifier - Interactive Mode")
    print("Enter a paper abstract (or 'quit' to exit).")
    print("For multi-line input, end with an empty line.\n")

    while True:
        try:
            collected: list[str] = []
            # Suppress prompts when input is piped rather than typed.
            current_prompt = "abstract> " if sys.stdin.isatty() else ""
            while True:
                entry = input(current_prompt)
                if entry.strip().lower() == "quit":
                    logger.info("Exiting.")
                    return
                # A blank line terminates the abstract once we have content.
                if entry == "" and collected:
                    break
                collected.append(entry)
                current_prompt = "... " if sys.stdin.isatty() else ""

            text = " ".join(collected).strip()
            if not text:
                continue

            _print_result(classifier.predict(text, top_k=args.top_k), args.output_json)

        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C ends the session cleanly.
            print()
            logger.info("Exiting.")
            return
|
|
|
|
|
|
|
|
# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|