Spaces:
Running
Running
"""
predict.py
──────────
Run inference on new text with any trained classifier.

Usage
─────
    # Single prediction — traditional models
    python predict.py --model lr --text "Federal Reserve cuts rates to near zero"
    python predict.py --model svm --text "Ronaldo scores hat-trick in Champions League"

    # Single prediction — transformer (FP32)
    python predict.py --model transformer --text "NASA launches James Webb successor"

    # Single prediction — INT8 quantized transformer (fast CPU inference)
    python predict.py --model transformer_quantized --text "Federal Reserve raises rates"
    python predict.py --model transformer_quantized --checkpoint roberta-base --text "..."

    # Interactive loop
    python predict.py --model transformer --interactive
    python predict.py --model transformer_quantized --interactive
    python predict.py --model lr --interactive
"""
| import argparse | |
| import logging | |
| import sys | |
| from typing import Dict, Optional | |
| import numpy as np | |
| import torch | |
| from config import CFG | |
| import traditional_model as tm | |
| import transformer_model as trm | |
# Configure the root logger at WARNING so INFO-level records are filtered out
# and don't interleave with prediction output. (basicConfig is a no-op if the
# root logger already has handlers.)
logging.basicConfig(level="WARNING")
# ── Prediction functions ──────────────────────────────────────────────────────
# Loaded sklearn pipelines keyed by model name. Without this, every call
# (e.g. each line of the interactive loop) reloaded the model from disk.
_PIPELINE_CACHE: Dict[str, object] = {}


def _get_pipeline(model_name: str):
    """Return the saved pipeline for ``model_name``, loading it at most once."""
    pipeline = _PIPELINE_CACHE.get(model_name)
    if pipeline is None:
        pipeline = tm.load_model(model_name)
        _PIPELINE_CACHE[model_name] = pipeline
    return pipeline


def predict_traditional(text: str, model_name: str, pipeline=None) -> Dict:
    """Run a single prediction with a saved sklearn pipeline.

    Args:
        text: Raw input string to classify.
        model_name: Name passed to ``tm.load_model`` (e.g. "lr" or "svm").
        pipeline: Optional pre-loaded pipeline. When omitted, the pipeline for
            ``model_name`` is loaded via ``tm.load_model`` on first use and
            reused on later calls.

    Returns:
        Dict with ``text``, ``label_id`` and ``label``. If the pipeline's final
        estimator has ``predict_proba``, it also contains ``probabilities``,
        which maps label name -> probability rounded to 4 decimals.
    """
    if pipeline is None:
        pipeline = _get_pipeline(model_name)

    pred_id = int(pipeline.predict([text])[0])
    result: Dict = {
        "text": text,
        "label_id": pred_id,
        "label": CFG.label_names[pred_id],
    }

    # Use the final pipeline step (the classifier) rather than assuming the
    # step is named after ``model_name``.
    # Logistic Regression supports predict_proba; LinearSVC does not.
    clf = pipeline.steps[-1][1]
    if hasattr(clf, "predict_proba"):
        probs = pipeline.predict_proba([text])[0]
        # predict_proba columns follow clf.classes_, which need not be 0..K-1.
        classes = getattr(clf, "classes_", range(len(probs)))
        result["probabilities"] = {
            CFG.label_names[int(cls)]: round(float(p), 4)
            for cls, p in zip(classes, probs)
        }
    return result
def predict_transformer(
    text: str,
    model=None,
    tokenizer=None,
) -> Dict:
    """Run a single prediction with a fine-tuned transformer (FP32, MPS/CPU).

    Args:
        text: Raw input string to classify.
        model: Pre-loaded classification model whose output exposes ``.logits``.
        tokenizer: Tokenizer matching ``model``. If either ``model`` or
            ``tokenizer`` is None, both are loaded with ``trm.load_model()``
            (its default checkpoint).

    Returns:
        Dict with ``text``, ``label_id``, ``label`` and ``probabilities``
        (label name -> softmax probability rounded to 4 decimals).
    """
    if model is None or tokenizer is None:
        model, tokenizer = trm.load_model()

    encoding = tokenizer(
        text,
        truncation=True,
        max_length=CFG.max_length,
        return_tensors="pt",
    )

    # The tokenizer returns CPU tensors; if the model lives on an accelerator
    # (e.g. MPS) the inputs must be moved there or the forward pass fails.
    device = getattr(model, "device", None)
    if device is None and hasattr(model, "parameters"):
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else None
    if device is not None:
        encoding = {k: v.to(device) for k, v in encoding.items()}

    with torch.no_grad():
        logits = model(**encoding).logits[0]
    # .numpy() only works on CPU tensors; .cpu() is a no-op when already there.
    probs = torch.softmax(logits, dim=-1).cpu().numpy()
    pred_id = int(np.argmax(probs))
    return {
        "text": text,
        "label_id": pred_id,
        "label": CFG.label_names[pred_id],
        "probabilities": {
            CFG.label_names[i]: round(float(p), 4)
            for i, p in enumerate(probs)
        },
    }
def predict_transformer_quantized(
    text: str,
    model=None,
    tokenizer=None,
    is_quantized: bool = True,
) -> Dict:
    """Classify ``text`` with the quantized transformer, running on CPU.

    Args:
        text: Raw input string to classify.
        model: Pre-loaded model (required).
        tokenizer: Tokenizer matching ``model`` (required).
        is_quantized: Selects the ``model_type`` tag: "[INT8]" when True,
            "[FP32]" otherwise.

    Returns:
        Dict with ``text``, ``label_id``, ``label``, ``model_type`` and
        ``probabilities`` (label name -> probability rounded to 4 decimals).

    Raises:
        ValueError: if ``model`` or ``tokenizer`` is missing. Unlike
            predict_transformer, this function never loads a model itself.
    """
    if tokenizer is None or model is None:
        raise ValueError("Pass a pre-loaded model and tokenizer.")

    tokens = tokenizer(
        text,
        truncation=True,
        max_length=CFG.max_length,
        return_tensors="pt",
    )
    # INT8 quantized kernels only run on CPU, so pin every input tensor there.
    cpu_inputs = {name: tensor.to("cpu") for name, tensor in tokens.items()}

    with torch.inference_mode():
        scores = model(**cpu_inputs).logits[0]
    distribution = torch.softmax(scores, dim=-1).numpy()
    best = int(np.argmax(distribution))

    tag = "[INT8]" if is_quantized else "[FP32]"
    names = CFG.label_names
    per_label = {names[idx]: round(float(prob), 4) for idx, prob in enumerate(distribution)}

    return {
        "text": text,
        "label_id": best,
        "label": names[best],
        "model_type": tag,
        "probabilities": per_label,
    }
# ── Display ───────────────────────────────────────────────────────────────────
def display_result(result: Dict) -> None:
    """Print a formatted prediction result to the terminal.

    ``result`` must contain ``text``, ``label_id`` and ``label``. It may also
    contain ``model_type`` (shown as a bracketed tag after the label) and
    ``probabilities`` (label -> probability, rendered as a sorted bar chart).
    """
    bar_width = 28
    max_chars = 90

    snippet = result["text"]
    if len(snippet) > max_chars:
        snippet = snippet[:max_chars] + "…"

    # model_type may already carry brackets (predict_transformer_quantized
    # returns "[INT8]"/"[FP32]"); strip them so the tag renders as " [INT8]"
    # rather than " [[INT8]]".
    model_tag = ""
    if "model_type" in result:
        tag = str(result["model_type"]).strip("[] ")
        model_tag = f" [{tag}]"

    print(f"\n Input : {snippet}")
    print(f" Label : [{result['label_id']}] {result['label']}{model_tag}")

    if "probabilities" in result:
        print(" Scores :")
        sorted_probs = sorted(
            result["probabilities"].items(),
            key=lambda x: x[1],
            reverse=True,
        )
        for label, prob in sorted_probs:
            # Clamp so an out-of-range value can't overflow the bar width.
            filled = min(max(round(prob * bar_width), 0), bar_width)
            bar = "█" * filled
            blank = " " * (bar_width - filled)
            print(f" {label:<12} [{bar}{blank}] {prob:.4f}")
    print()
# ── CLI ───────────────────────────────────────────────────────────────────────
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the inference script.

    Options: ``--model`` (free-form; normalised later by ``main``),
    ``--checkpoint``, ``--text`` and ``--interactive``.
    """
    parser = argparse.ArgumentParser(
        # Was mis-encoded ("β"); the intended character is an em-dash.
        description="Document Classifier — Inference",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--model",
        default="transformer",
        help="Which saved model to load: lr, svm, transformer, "
        "transformer_quantized, or a specific variant like "
        "distilbert_quantized (default: transformer)",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="distilbert-base-uncased",
        help="HuggingFace checkpoint name for transformer/transformer_quantized "
        "(default: distilbert-base-uncased)",
    )
    parser.add_argument(
        "--text",
        type=str,
        default=None,
        help="Single text string to classify",
    )
    parser.add_argument(
        "--interactive",
        action="store_true",
        help="Enter an interactive prediction loop",
    )
    return parser
def _normalise_model_choice(model: str, checkpoint: str) -> tuple:
    """Resolve a free-form ``--model`` value into ``(model_kind, checkpoint)``.

    ``model_kind`` is one of "lr", "svm", "transformer" or
    "transformer_quantized", and matching is case-insensitive. A model family
    named in the value (distilbert / roberta / bert) overrides ``checkpoint``.
    "quant" or "int8" anywhere in the value selects the quantized transformer.
    Any other unrecognised value falls back to "transformer", with a warning
    on stderr.
    """
    name = model.lower()

    # Order matters: "distilbert" and "roberta" both contain "bert".
    if "distilbert" in name:
        checkpoint = "distilbert-base-uncased"
    elif "roberta" in name:
        checkpoint = "roberta-base"
    elif "bert" in name:
        checkpoint = "bert-base-uncased"

    if "quant" in name or "int8" in name:
        return "transformer_quantized", checkpoint
    # Compare the lower-cased name so e.g. "--model LR" selects logistic
    # regression instead of silently falling through to the transformer.
    if name in ("lr", "svm"):
        return name, checkpoint
    if "transformer" not in name and "bert" not in name:
        print(
            f" Warning: unrecognised --model {model!r}; using the transformer.",
            file=sys.stderr,
        )
    return "transformer", checkpoint


def main() -> None:
    """CLI entry point: parse args, pre-load the chosen model, then predict.

    Classifies ``--text`` once, or runs a read-predict loop with
    ``--interactive``. Exits with status 1 if neither is given.
    """
    parser = build_parser()
    args = parser.parse_args()
    args.model, args.checkpoint = _normalise_model_choice(args.model, args.checkpoint)

    # Pre-load the transformer once (avoids reloading on every prediction in loops)
    cached_model = None
    cached_tokenizer = None
    cached_quantized = False
    if args.model == "transformer":
        print(f" Loading transformer model ({args.checkpoint}) …")
        cached_model, cached_tokenizer = trm.load_model(args.checkpoint)
        print(" Model ready.\n")
    elif args.model == "transformer_quantized":
        print(f" Loading quantized model ({args.checkpoint}) …")
        cached_model, cached_tokenizer, cached_quantized = trm.load_quantized_model(
            args.checkpoint
        )
        tag = "INT8 quantized" if cached_quantized else "FP32 (INT8 not found, fell back)"
        print(f" Model ready — {tag}.\n")

    def _predict(text: str) -> Dict:
        """Dispatch one prediction to the backend selected by ``args.model``."""
        if args.model == "transformer":
            return predict_transformer(text, cached_model, cached_tokenizer)
        if args.model == "transformer_quantized":
            return predict_transformer_quantized(
                text, cached_model, cached_tokenizer, is_quantized=cached_quantized
            )
        return predict_traditional(text, args.model)

    if args.interactive:
        print(" Document Classifier — Interactive Mode")
        print(" Type text and press Enter. Type 'q' or 'quit' to exit.\n")
        while True:
            try:
                text = input(" >> ").strip()
            except (KeyboardInterrupt, EOFError):
                print("\n Exiting.")
                break
            if not text:
                continue
            if text.lower() in {"q", "quit", "exit"}:
                print(" Goodbye.")
                break
            display_result(_predict(text))
    elif args.text:
        display_result(_predict(args.text))
    else:
        # Usage errors go to stderr, like argparse's own error reporting.
        print(" Error: provide --text <string> or use --interactive.\n", file=sys.stderr)
        parser.print_help(sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()