nexa-classify-api / predict.py
Prototype6239's picture
Upload folder using huggingface_hub
a229747 verified
Raw
History Blame Contribute Delete
8.83 kB
"""
predict.py
──────────
Run inference on new text with any trained classifier.
Usage
─────
# Single prediction — traditional models
python predict.py --model lr --text "Federal Reserve cuts rates to near zero"
python predict.py --model svm --text "Ronaldo scores hat-trick in Champions League"
# Single prediction — transformer (FP32)
python predict.py --model transformer --text "NASA launches James Webb successor"
# Single prediction — INT8 quantized transformer (fast CPU inference)
python predict.py --model transformer_quantized --text "Federal Reserve raises rates"
python predict.py --model transformer_quantized --checkpoint roberta-base --text "..."
# Interactive loop
python predict.py --model transformer --interactive
python predict.py --model transformer_quantized --interactive
python predict.py --model lr --interactive
"""
import argparse
import logging
import sys
from typing import Dict, Optional
import numpy as np
import torch
from config import CFG
import traditional_model as tm
import transformer_model as trm
# Configure the root logger at WARNING so INFO-level messages do not clutter
# the prediction output. This runs at import time, for every CLI mode.
logging.basicConfig(level=logging.WARNING)
# ── Prediction functions ──────────────────────────────────────────────────────
def predict_traditional(text: str, model_name: str, pipeline=None) -> Dict:
    """Run a single prediction with a saved sklearn pipeline.

    Args:
        text: Raw text to classify.
        model_name: Identifier of the saved model (the CLI passes ``"lr"`` or
            ``"svm"``). It is handed to ``tm.load_model`` and is also used as
            the name of the classifier step inside the pipeline.
        pipeline: Optional pre-loaded pipeline. When omitted, the pipeline is
            loaded with ``tm.load_model(model_name)`` on every call; passing
            one mirrors the ``model``/``tokenizer`` arguments of the
            transformer predictors and avoids repeated loading in loops.

    Returns:
        A dict with ``"text"``, ``"label_id"`` and ``"label"``. When the
        classifier step exposes ``predict_proba``, a ``"probabilities"`` entry
        maps every label name to its probability rounded to 4 decimals.

    Raises:
        KeyError: If the pipeline has no step named ``model_name``.
    """
    if pipeline is None:
        pipeline = tm.load_model(model_name)

    pred_id = int(pipeline.predict([text])[0])
    result: Dict = {
        "text": text,
        "label_id": pred_id,
        "label": CFG.label_names[pred_id],
    }

    # Logistic Regression supports predict_proba; LinearSVC does not, so the
    # probability table is only added when the classifier step offers it.
    clf = pipeline.named_steps[model_name]
    if hasattr(clf, "predict_proba"):
        probs = pipeline.predict_proba([text])[0]
        result["probabilities"] = {
            CFG.label_names[i]: round(float(p), 4)
            for i, p in enumerate(probs)
        }
    return result
def predict_transformer(
    text: str,
    model=None,
    tokenizer=None,
) -> Dict:
    """Run a single prediction with a fine-tuned transformer (FP32, MPS/CPU).

    Args:
        text: Raw text to classify.
        model: Pre-loaded model. It is called as ``model(**encoding)`` and the
            result must expose ``.logits``.
        tokenizer: Pre-loaded tokenizer matching ``model``.
            If either ``model`` or ``tokenizer`` is ``None``, both are loaded
            with ``trm.load_model()``.

    Returns:
        A dict with ``"text"``, ``"label_id"``, ``"label"`` and
        ``"probabilities"`` (label name -> softmax probability rounded to
        4 decimals).
    """
    if model is None or tokenizer is None:
        model, tokenizer = trm.load_model()

    encoding = tokenizer(
        text,
        truncation=True,
        max_length=CFG.max_length,
        return_tensors="pt",
    )

    # Inputs must live on the same device as the model weights; otherwise a
    # model placed on an accelerator (e.g. MPS) fails with a device mismatch.
    parameters = getattr(model, "parameters", None)
    first_param = next(parameters(), None) if callable(parameters) else None
    if first_param is not None:
        encoding = {
            name: tensor.to(first_param.device)
            for name, tensor in encoding.items()
        }

    with torch.no_grad():
        logits = model(**encoding).logits[0]
        # Tensor.numpy() only works on CPU tensors, so copy back first
        # (a no-op when the tensor is already on the CPU).
        probs = torch.softmax(logits, dim=-1).cpu().numpy()

    pred_id = int(np.argmax(probs))
    return {
        "text": text,
        "label_id": pred_id,
        "label": CFG.label_names[pred_id],
        "probabilities": {
            CFG.label_names[i]: round(float(p), 4)
            for i, p in enumerate(probs)
        },
    }
def predict_transformer_quantized(
    text: str,
    model=None,
    tokenizer=None,
    is_quantized: bool = True,
) -> Dict:
    """Classify ``text`` with the INT8 quantized model (CPU only).

    Unlike ``predict_transformer`` this function never loads anything itself:
    both ``model`` and ``tokenizer`` must be supplied by the caller.

    Args:
        text: Raw text to classify.
        model: Pre-loaded model, called as ``model(**inputs)``; the result
            must expose ``.logits``.
        tokenizer: Pre-loaded tokenizer matching ``model``.
        is_quantized: Whether ``model`` really is the INT8 variant; only
            affects the ``"model_type"`` tag in the result.

    Returns:
        A dict with ``"text"``, ``"label_id"``, ``"label"``, ``"model_type"``
        (``"[INT8]"`` or ``"[FP32]"``) and ``"probabilities"`` (label name ->
        softmax probability rounded to 4 decimals).

    Raises:
        ValueError: If ``model`` or ``tokenizer`` is ``None``.
    """
    if not (model is not None and tokenizer is not None):
        raise ValueError("Pass a pre-loaded model and tokenizer.")

    tokens = tokenizer(
        text,
        truncation=True,
        max_length=CFG.max_length,
        return_tensors="pt",
    )
    # INT8 quantized kernels only run on CPU, so pin every input tensor there.
    cpu_inputs = {name: tensor.to("cpu") for name, tensor in tokens.items()}

    with torch.inference_mode():
        output = model(**cpu_inputs)
        scores = torch.softmax(output.logits[0], dim=-1).numpy()

    best = int(np.argmax(scores))
    model_type = "[INT8]" if is_quantized else "[FP32]"

    outcome: Dict = {
        "text": text,
        "label_id": best,
        "label": CFG.label_names[best],
        "model_type": model_type,
    }
    per_label = {}
    for idx, score in enumerate(scores):
        per_label[CFG.label_names[idx]] = round(float(score), 4)
    outcome["probabilities"] = per_label
    return outcome
# ── Display ───────────────────────────────────────────────────────────────────
def display_result(result: Dict) -> None:
    """Print a formatted prediction result to the terminal.

    Args:
        result: Prediction dict. ``"text"``, ``"label_id"`` and ``"label"``
            are required. ``"model_type"`` (optional) is shown after the
            label, and ``"probabilities"`` (optional, label -> probability)
            is rendered as a bar chart sorted from highest to lowest score.
    """
    bar_width = 28

    snippet = result["text"]
    if len(snippet) > 90:
        snippet = snippet[:90] + "…"

    model_tag = ""
    if "model_type" in result:
        # The quantized predictor already returns a bracketed tag such as
        # "[INT8]"; strip those brackets so the output is not "[[INT8]]".
        model_tag = f" [{str(result['model_type']).strip('[] ')}]"

    print(f"\n Input : {snippet}")
    print(f" Label : [{result['label_id']}] {result['label']}{model_tag}")

    if "probabilities" in result:
        print(" Scores :")
        sorted_probs = sorted(
            result["probabilities"].items(),
            key=lambda item: item[1],
            reverse=True,
        )
        for label, prob in sorted_probs:
            # Clamp so out-of-range values can never break the fixed width.
            filled = max(0, min(bar_width, round(prob * bar_width)))
            bar = "█" * filled
            blank = " " * (bar_width - filled)
            print(f" {label:<12} [{bar}{blank}] {prob:.4f}")
    print()
# ── CLI ───────────────────────────────────────────────────────────────────────
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the inference script.

    Options:
        --model        Saved model to load (default ``"transformer"``).
        --checkpoint   HuggingFace checkpoint name
                       (default ``"distilbert-base-uncased"``).
        --text         Single text string to classify (default ``None``).
        --interactive  Flag that starts an interactive prediction loop.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    p = argparse.ArgumentParser(
        description="Document Classifier — Inference",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument(
        "--model",
        default="transformer",
        help="Which saved model to load: lr, svm, transformer, "
             "transformer_quantized, or a specific variant like "
             "distilbert_quantized (default: transformer)",
    )
    p.add_argument(
        "--checkpoint",
        type=str,
        default="distilbert-base-uncased",
        help="HuggingFace checkpoint name for transformer/transformer_quantized "
             "(default: distilbert-base-uncased)",
    )
    p.add_argument(
        "--text",
        type=str,
        default=None,
        help="Single text string to classify",
    )
    p.add_argument(
        "--interactive",
        action="store_true",
        help="Enter an interactive prediction loop",
    )
    return p
def main() -> None:
    """CLI entry point: parse arguments, load the model once, then predict.

    The ``--model`` value is matched case-insensitively:

    * names containing ``quant`` or ``int8`` select the quantized transformer;
    * ``lr`` and ``svm`` select the traditional sklearn models;
    * anything else selects the FP32 transformer.

    If the model name contains ``distilbert``, ``roberta`` or ``bert``, the
    matching base checkpoint replaces whatever ``--checkpoint`` was given.

    Exits with status 1 (after printing the help text) when neither ``--text``
    nor ``--interactive`` is supplied. ``--interactive`` takes precedence when
    both are given.
    """
    parser = build_parser()
    args = parser.parse_args()

    # Fail fast: reject unusable invocations before loading a large model.
    if not args.interactive and not args.text:
        print(" Error: provide --text <string> or use --interactive.\n")
        parser.print_help()
        sys.exit(1)

    # Normalise model selection and automatically extract checkpoint if specified
    model_lower = args.model.strip().lower()
    is_quant = "quant" in model_lower or "int8" in model_lower

    # "bert" must be tested last: "distilbert" and "roberta" both contain it.
    if "distilbert" in model_lower:
        args.checkpoint = "distilbert-base-uncased"
    elif "roberta" in model_lower:
        args.checkpoint = "roberta-base"
    elif "bert" in model_lower:
        args.checkpoint = "bert-base-uncased"

    if is_quant:
        args.model = "transformer_quantized"
    elif model_lower in ("lr", "svm"):
        # Use the normalised name so "--model LR" is not mistaken for a
        # transformer request.
        args.model = model_lower
    else:
        args.model = "transformer"

    # Pre-load the model once (avoids reloading on every prediction in loops)
    cached_model = None
    cached_tokenizer = None
    cached_quantized = False
    if args.model == "transformer":
        print(f" Loading transformer model ({args.checkpoint}) …")
        cached_model, cached_tokenizer = trm.load_model(args.checkpoint)
        print(" Model ready.\n")
    elif args.model == "transformer_quantized":
        print(f" Loading quantized model ({args.checkpoint}) …")
        cached_model, cached_tokenizer, cached_quantized = trm.load_quantized_model(
            args.checkpoint
        )
        tag = "INT8 quantized" if cached_quantized else "FP32 (INT8 not found, fell back)"
        print(f" Model ready — {tag}.\n")

    def _predict(text: str) -> Dict:
        """Dispatch ``text`` to the predictor matching the selected model."""
        if args.model == "transformer":
            return predict_transformer(text, cached_model, cached_tokenizer)
        if args.model == "transformer_quantized":
            return predict_transformer_quantized(
                text, cached_model, cached_tokenizer, is_quantized=cached_quantized
            )
        return predict_traditional(text, args.model)

    if not args.interactive:
        display_result(_predict(args.text))
        return

    print(" Document Classifier — Interactive Mode")
    print(" Type text and press Enter. Type 'q' or 'quit' to exit.\n")
    while True:
        try:
            text = input(" >> ").strip()
        except (KeyboardInterrupt, EOFError):
            # Ctrl-C / Ctrl-D end the session cleanly instead of a traceback.
            print("\n Exiting.")
            break
        if not text:
            continue
        if text.lower() in {"q", "quit", "exit"}:
            print(" Goodbye.")
            break
        display_result(_predict(text))
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()