File size: 7,122 Bytes
2be4558 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 |
#!/usr/bin/env python3
"""
Inference script for the Academic Paper Classifier.
Loads a fine-tuned DistilBERT model and predicts the arXiv category for a
given paper abstract. Returns the predicted category along with per-class
confidence scores.

Usage examples:

    # Use a local model directory
    python inference.py --model_path ./model --abstract "We propose a novel ..."

    # Use a HuggingFace Hub model
    python inference.py --model_path gr8monk3ys/paper-classifier-model \
        --abstract "We propose a novel ..."

    # Interactive mode (reads from stdin)
    python inference.py --model_path ./model

Author: Lorenzo Scaturchio (gr8monk3ys)
License: MIT
"""
import argparse
import json
import logging
import sys
from pathlib import Path
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
# Root logging configuration: INFO and above, timestamped records, written to
# standard output.
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s - %(message)s"
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=[_stdout_handler])

# Logger named after this module.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Classifier wrapper
# ---------------------------------------------------------------------------
class PaperClassifier:
"""Thin wrapper around a fine-tuned sequence-classification model.
Parameters
----------
model_path : str
Path to a local model directory **or** a HuggingFace Hub model id.
device : str | None
Target device (``"cpu"``, ``"cuda"``, ``"mps"``). If *None* the best
available device is selected automatically.
"""
def __init__(self, model_path: str, device: str | None = None) -> None:
if device is None:
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
self.device = torch.device(device)
logger.info("Loading tokenizer from: %s", model_path)
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
logger.info("Loading model from: %s", model_path)
self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
self.model.to(self.device)
# Read label mapping stored in the model config
self.id2label: dict[int, str] = self.model.config.id2label
logger.info("Labels: %s", list(self.id2label.values()))
@torch.no_grad()
def predict(self, abstract: str, top_k: int | None = None) -> dict:
"""Classify a single paper abstract.
Parameters
----------
abstract : str
The paper abstract to classify.
top_k : int | None
If given, only the *top_k* categories (by confidence) are returned
in ``scores``. Pass *None* to return all categories.
Returns
-------
dict
``{"label": str, "confidence": float, "scores": {label: prob}}``
"""
self.model.eval()
inputs = self.tokenizer(
abstract,
return_tensors="pt",
truncation=True,
padding=True,
max_length=512,
).to(self.device)
logits = self.model(**inputs).logits
probs = torch.softmax(logits, dim=-1).squeeze(0).cpu().numpy()
sorted_indices = probs.argsort()[::-1]
if top_k is not None:
sorted_indices = sorted_indices[:top_k]
scores = {
self.id2label[int(idx)]: float(probs[idx]) for idx in sorted_indices
}
best_idx = int(probs.argmax())
return {
"label": self.id2label[best_idx],
"confidence": float(probs[best_idx]),
"scores": scores,
}
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Classify an academic paper abstract into an arxiv category."
)
parser.add_argument(
"--model_path",
type=str,
default="./model",
help="Path to the fine-tuned model directory or HF Hub id (default: %(default)s).",
)
parser.add_argument(
"--abstract",
type=str,
default=None,
help="Paper abstract text. If omitted, the script enters interactive mode.",
)
parser.add_argument(
"--top_k",
type=int,
default=None,
help="Only show the top-k predictions (default: show all).",
)
parser.add_argument(
"--device",
type=str,
default=None,
choices=["cpu", "cuda", "mps"],
help="Device to run inference on (default: auto-detect).",
)
parser.add_argument(
"--json",
action="store_true",
default=False,
dest="output_json",
help="Output raw JSON instead of human-readable text.",
)
return parser.parse_args()
def _print_result(result: dict, output_json: bool) -> None:
"""Pretty-print or JSON-dump a prediction result."""
if output_json:
print(json.dumps(result, indent=2))
return
print(f"\n Predicted category : {result['label']}")
print(f" Confidence : {result['confidence']:.4f}")
print(" ---------------------------------")
for label, score in result["scores"].items():
bar = "#" * int(score * 40)
print(f" {label:<10s} {score:6.4f} {bar}")
print()
def _send_stdout_logs_to_stderr() -> None:
    """Re-point every root log handler that writes to stdout at stderr."""
    for handler in logging.getLogger().handlers:
        if isinstance(handler, logging.StreamHandler) and handler.stream is sys.stdout:
            handler.setStream(sys.stderr)


def _read_abstract(is_tty: bool) -> tuple[str | None, bool]:
    """Read one (possibly multi-line) abstract from standard input.

    Lines are collected until an empty line follows at least one collected
    line, the user types ``quit``, or input ends.

    Returns
    -------
    tuple[str | None, bool]
        ``(abstract, at_eof)``. ``abstract`` is the collected lines joined
        with single spaces and stripped, or *None* if the user typed
        ``quit``. ``at_eof`` is True when standard input was exhausted; any
        lines read before that point are still returned.
    """
    lines: list[str] = []
    prompt = "abstract> " if is_tty else ""
    while True:
        try:
            line = input(prompt)
        except EOFError:
            # Piped input often ends without a terminating blank line; hand
            # back what was read so it still gets classified.
            return " ".join(lines).strip(), True
        if line.strip().lower() == "quit":
            return None, False
        if line == "" and lines:
            return " ".join(lines).strip(), False
        lines.append(line)
        prompt = "... " if is_tty else ""


def main() -> None:
    """Run the command-line interface.

    With ``--abstract`` a single prediction is printed and the function
    returns. Otherwise abstracts are read from standard input in a loop until
    ``quit``, end of input, or Ctrl-C.

    With ``--json``, log records and interactive-mode banners are written to
    stderr so that stdout carries only the JSON results.
    """
    args = parse_args()
    if args.output_json:
        # The model-loading log lines would otherwise precede the JSON on
        # stdout and break any consumer that parses it.
        _send_stdout_logs_to_stderr()

    classifier = PaperClassifier(model_path=args.model_path, device=args.device)

    if args.abstract is not None:
        result = classifier.predict(args.abstract, top_k=args.top_k)
        _print_result(result, args.output_json)
        return

    # Interactive mode
    chatter = sys.stderr if args.output_json else sys.stdout
    print("Academic Paper Classifier - Interactive Mode", file=chatter)
    print("Enter a paper abstract (or 'quit' to exit).", file=chatter)
    print("For multi-line input, end with an empty line.\n", file=chatter)

    is_tty = sys.stdin.isatty()
    try:
        while True:
            abstract, at_eof = _read_abstract(is_tty)
            if abstract is None:  # user typed "quit"
                break
            if abstract:
                result = classifier.predict(abstract, top_k=args.top_k)
                _print_result(result, args.output_json)
            if at_eof:
                print(file=chatter)
                break
    except KeyboardInterrupt:
        print(file=chatter)
    logger.info("Exiting.")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|