""" predict.py ────────── Run inference on new text with any trained classifier. Usage ───── # Single prediction — traditional models python predict.py --model lr --text "Federal Reserve cuts rates to near zero" python predict.py --model svm --text "Ronaldo scores hat-trick in Champions League" # Single prediction — transformer (FP32) python predict.py --model transformer --text "NASA launches James Webb successor" # Single prediction — INT8 quantized transformer (fast CPU inference) python predict.py --model transformer_quantized --text "Federal Reserve raises rates" python predict.py --model transformer_quantized --checkpoint roberta-base --text "..." # Interactive loop python predict.py --model transformer --interactive python predict.py --model transformer_quantized --interactive python predict.py --model lr --interactive """ import argparse import logging import sys from typing import Dict, Optional import numpy as np import torch from config import CFG import traditional_model as tm import transformer_model as trm # Suppress INFO logs during interactive prediction logging.basicConfig(level=logging.WARNING) # ── Prediction functions ────────────────────────────────────────────────────── def predict_traditional(text: str, model_name: str) -> Dict: """Run a single prediction with a saved sklearn pipeline.""" pipeline = tm.load_model(model_name) pred_id = int(pipeline.predict([text])[0]) result: Dict = { "text": text, "label_id": pred_id, "label": CFG.label_names[pred_id], } # Logistic Regression supports predict_proba; LinearSVC does not clf = pipeline.named_steps[model_name] if hasattr(clf, "predict_proba"): probs = pipeline.predict_proba([text])[0] result["probabilities"] = { CFG.label_names[i]: round(float(p), 4) for i, p in enumerate(probs) } return result def predict_transformer( text: str, model=None, tokenizer=None, ) -> Dict: """Run a single prediction with a fine-tuned transformer (FP32, MPS/CPU).""" if model is None or tokenizer is None: model, tokenizer = trm.load_model() encoding = tokenizer( text, truncation=True, max_length=CFG.max_length, return_tensors="pt", ) with torch.no_grad(): logits = model(**encoding).logits[0] probs = torch.softmax(logits, dim=-1).numpy() pred_id = int(np.argmax(probs)) return { "text": text, "label_id": pred_id, "label": CFG.label_names[pred_id], "probabilities": { CFG.label_names[i]: round(float(p), 4) for i, p in enumerate(probs) }, } def predict_transformer_quantized( text: str, model=None, tokenizer=None, is_quantized: bool = True, ) -> Dict: """Run inference with the INT8 quantized model (CPU only).""" if model is None or tokenizer is None: raise ValueError("Pass a pre-loaded model and tokenizer.") encoding = tokenizer( text, truncation=True, max_length=CFG.max_length, return_tensors="pt", ) # INT8 quantized kernels only run on CPU encoding = {k: v.to("cpu") for k, v in encoding.items()} with torch.inference_mode(): logits = model(**encoding).logits[0] probs = torch.softmax(logits, dim=-1).numpy() pred_id = int(np.argmax(probs)) label = "[INT8] " if is_quantized else "[FP32] " return { "text": text, "label_id": pred_id, "label": CFG.label_names[pred_id], "model_type": label.strip(), "probabilities": { CFG.label_names[i]: round(float(p), 4) for i, p in enumerate(probs) }, } # ── Display ─────────────────────────────────────────────────────────────────── def display_result(result: Dict) -> None: """Print a formatted prediction result to the terminal.""" snippet = result["text"] if len(snippet) > 90: snippet = snippet[:90] + "…" model_tag = f" [{result['model_type']}]" if "model_type" in result else "" print(f"\n Input : {snippet}") print(f" Label : [{result['label_id']}] {result['label']}{model_tag}") if "probabilities" in result: print(" Scores :") sorted_probs = sorted( result["probabilities"].items(), key=lambda x: x[1], reverse=True, ) for label, prob in sorted_probs: bar = "█" * round(prob * 28) blank = " " * (28 - round(prob * 28)) print(f" {label:<12} [{bar}{blank}] {prob:.4f}") print() # ── CLI ─────────────────────────────────────────────────────────────────────── def build_parser() -> argparse.ArgumentParser: p = argparse.ArgumentParser( description="Document Classifier — Inference", formatter_class=argparse.RawDescriptionHelpFormatter, ) p.add_argument( "--model", default="transformer", help="Which saved model to load: lr, svm, transformer, transformer_quantized, or a specific variant like distilbert_quantized (default: transformer)", ) p.add_argument( "--checkpoint", type=str, default="distilbert-base-uncased", help="HuggingFace checkpoint name for transformer/transformer_quantized " "(default: distilbert-base-uncased)", ) p.add_argument( "--text", type=str, default=None, help="Single text string to classify", ) p.add_argument( "--interactive", action="store_true", help="Enter an interactive prediction loop", ) return p def main() -> None: args = build_parser().parse_args() # Normalise model selection and automatically extract checkpoint if specified model_lower = args.model.lower() is_quant = "quant" in model_lower or "int8" in model_lower if "distilbert" in model_lower: args.checkpoint = "distilbert-base-uncased" elif "roberta" in model_lower: args.checkpoint = "roberta-base" elif "bert" in model_lower: args.checkpoint = "bert-base-uncased" if is_quant: args.model = "transformer_quantized" elif args.model not in ["lr", "svm"]: args.model = "transformer" # Pre-load the model once (avoids reloading on every prediction in loops) cached_model = None cached_tokenizer = None cached_quantized = False if args.model == "transformer": print(f" Loading transformer model ({args.checkpoint}) …") cached_model, cached_tokenizer = trm.load_model(args.checkpoint) print(" Model ready.\n") elif args.model == "transformer_quantized": print(f" Loading quantized model ({args.checkpoint}) …") cached_model, cached_tokenizer, cached_quantized = trm.load_quantized_model( args.checkpoint ) tag = "INT8 quantized" if cached_quantized else "FP32 (INT8 not found, fell back)" print(f" Model ready — {tag}.\n") def _predict(text: str) -> Dict: if args.model == "transformer": return predict_transformer(text, cached_model, cached_tokenizer) if args.model == "transformer_quantized": return predict_transformer_quantized( text, cached_model, cached_tokenizer, is_quantized=cached_quantized ) return predict_traditional(text, args.model) if args.interactive: print(" Document Classifier — Interactive Mode") print(" Type text and press Enter. Type 'q' or 'quit' to exit.\n") while True: try: text = input(" >> ").strip() except (KeyboardInterrupt, EOFError): print("\n Exiting.") break if not text: continue if text.lower() in {"q", "quit", "exit"}: print(" Goodbye.") break display_result(_predict(text)) elif args.text: display_result(_predict(args.text)) else: print(" Error: provide --text or use --interactive.\n") build_parser().print_help() sys.exit(1) if __name__ == "__main__": main()