File size: 8,827 Bytes
a229747
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
"""
predict.py
──────────
Run inference on new text with any trained classifier.

Usage
─────
    # Single prediction — traditional models
    python predict.py --model lr  --text "Federal Reserve cuts rates to near zero"
    python predict.py --model svm --text "Ronaldo scores hat-trick in Champions League"

    # Single prediction — transformer (FP32)
    python predict.py --model transformer --text "NASA launches James Webb successor"

    # Single prediction — INT8 quantized transformer (fast CPU inference)
    python predict.py --model transformer_quantized --text "Federal Reserve raises rates"
    python predict.py --model transformer_quantized --checkpoint roberta-base --text "..."

    # Interactive loop
    python predict.py --model transformer --interactive
    python predict.py --model transformer_quantized --interactive
    python predict.py --model lr --interactive
"""
import argparse
import logging
import sys
from typing import Dict, Optional

import numpy as np
import torch

from config import CFG
import traditional_model as tm
import transformer_model as trm

# Raise the root logging threshold to WARNING so INFO-level messages
# (presumably emitted while loading models in the imported helper modules —
# not verified here) do not clutter the prediction output. NOTE: this runs at
# import time, so it applies to every mode, not only --interactive.
logging.basicConfig(level=logging.WARNING)


# ── Prediction functions ──────────────────────────────────────────────────────

# Per-process cache of loaded sklearn pipelines, keyed by model name, so the
# interactive loop does not reload the pipeline from disk for every input.
_PIPELINE_CACHE: Dict[str, object] = {}


def _load_pipeline(model_name: str):
    """Return the saved pipeline for *model_name*, loading it at most once."""
    pipeline = _PIPELINE_CACHE.get(model_name)
    if pipeline is None:
        pipeline = tm.load_model(model_name)
        _PIPELINE_CACHE[model_name] = pipeline
    return pipeline


def predict_traditional(text: str, model_name: str) -> Dict:
    """Run a single prediction with a saved sklearn pipeline.

    Args:
        text: Raw input text to classify.
        model_name: Name of the saved model (e.g. ``"lr"`` or ``"svm"``),
            passed to ``tm.load_model``; the loaded pipeline is cached.

    Returns:
        Dict with ``text``, ``label_id`` and ``label``. When the classifier
        supports ``predict_proba`` it also holds ``probabilities``: a mapping
        of label name to score rounded to 4 decimals.
    """
    pipeline = _load_pipeline(model_name)
    pred_id  = int(pipeline.predict([text])[0])

    result: Dict = {
        "text":     text,
        "label_id": pred_id,
        "label":    CFG.label_names[pred_id],
    }

    # Logistic Regression supports predict_proba; LinearSVC does not.
    # Fall back to the final pipeline step if no step is named after the model.
    clf = pipeline.named_steps.get(model_name, pipeline.steps[-1][1])
    if hasattr(clf, "predict_proba"):
        probs = pipeline.predict_proba([text])[0]
        # predict_proba columns follow clf.classes_, which need not be 0..n-1.
        classes = getattr(clf, "classes_", range(len(probs)))
        result["probabilities"] = {
            CFG.label_names[int(c)]: round(float(p), 4)
            for c, p in zip(classes, probs)
        }

    return result


def predict_transformer(
    text: str,
    model=None,
    tokenizer=None,
) -> Dict:
    """Run a single prediction with a fine-tuned transformer (FP32, MPS/CPU).

    Args:
        text: Raw input text to classify.
        model: Pre-loaded classification model. If either *model* or
            *tokenizer* is ``None``, both are loaded via ``trm.load_model()``
            with its default arguments.
        tokenizer: Tokenizer matching *model*.

    Returns:
        Dict with ``text``, ``label_id``, ``label`` and ``probabilities``
        (label name -> softmax score rounded to 4 decimals).
    """
    if model is None or tokenizer is None:
        model, tokenizer = trm.load_model()

    encoding = tokenizer(
        text,
        truncation=True,
        max_length=CFG.max_length,
        return_tensors="pt",
    )
    # Put the inputs on the same device as the weights (e.g. MPS) instead of
    # assuming the tokenizer output already matches it.
    device = next(model.parameters()).device
    encoding = {k: v.to(device) for k, v in encoding.items()}

    with torch.no_grad():
        logits = model(**encoding).logits[0]

    # .numpy() only works on CPU tensors, so copy the result back first.
    probs   = torch.softmax(logits, dim=-1).cpu().numpy()
    pred_id = int(np.argmax(probs))

    return {
        "text":     text,
        "label_id": pred_id,
        "label":    CFG.label_names[pred_id],
        "probabilities": {
            CFG.label_names[i]: round(float(p), 4)
            for i, p in enumerate(probs)
        },
    }


def predict_transformer_quantized(
    text: str,
    model=None,
    tokenizer=None,
    is_quantized: bool = True,
) -> Dict:
    """Run inference with the INT8 quantized model (CPU only).

    Args:
        text: Raw input text to classify.
        model: Pre-loaded model (required; this function never loads one).
        tokenizer: Tokenizer matching *model* (required).
        is_quantized: Whether *model* is the INT8 variant or an FP32
            fallback; only affects the ``model_type`` tag in the result.

    Returns:
        Dict with ``text``, ``label_id``, ``label``, ``model_type``
        (``"INT8"`` or ``"FP32"``) and ``probabilities`` (label name ->
        softmax score rounded to 4 decimals).

    Raises:
        ValueError: If *model* or *tokenizer* is ``None``.
    """
    if model is None or tokenizer is None:
        raise ValueError("Pass a pre-loaded model and tokenizer.")

    encoding = tokenizer(
        text,
        truncation=True,
        max_length=CFG.max_length,
        return_tensors="pt",
    )
    # INT8 quantized kernels only run on CPU
    encoding = {k: v.to("cpu") for k, v in encoding.items()}

    with torch.inference_mode():
        logits = model(**encoding).logits[0]

    probs   = torch.softmax(logits, dim=-1).cpu().numpy()
    pred_id = int(np.argmax(probs))

    return {
        "text":       text,
        "label_id":   pred_id,
        "label":      CFG.label_names[pred_id],
        # Bare tag: display_result() wraps model_type in brackets itself, so a
        # pre-bracketed value would be printed as "[[INT8]]".
        "model_type": "INT8" if is_quantized else "FP32",
        "probabilities": {
            CFG.label_names[i]: round(float(p), 4)
            for i, p in enumerate(probs)
        },
    }


# ── Display ───────────────────────────────────────────────────────────────────

def display_result(result: Dict) -> None:
    """Print a formatted prediction result to the terminal.

    Args:
        result: Dict produced by one of the ``predict_*`` functions. The keys
            ``text``, ``label_id`` and ``label`` are required. ``model_type``
            (shown in brackets after the label) and ``probabilities`` (label
            name -> score, rendered as a bar chart sorted high to low) are
            optional.
    """
    bar_width = 28

    snippet = result["text"]
    if len(snippet) > 90:
        snippet = snippet[:90] + "…"

    model_tag = f"  [{result['model_type']}]" if "model_type" in result else ""
    print(f"\n  Input  : {snippet}")
    print(f"  Label  : [{result['label_id']}]  {result['label']}{model_tag}")

    if "probabilities" in result:
        print("  Scores :")
        sorted_probs = sorted(
            result["probabilities"].items(),
            key=lambda item: item[1],
            reverse=True,
        )
        for label, prob in sorted_probs:
            # One full-block glyph per filled cell; clamp so an out-of-range
            # score cannot break the fixed-width bar.
            filled = min(max(round(prob * bar_width), 0), bar_width)
            bar = "█" * filled + " " * (bar_width - filled)
            print(f"    {label:<12}  [{bar}]  {prob:.4f}")
    print()


# ── CLI ───────────────────────────────────────────────────────────────────────

def build_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser for the inference script.

    Returns:
        An ``ArgumentParser`` accepting ``--model``, ``--checkpoint``,
        ``--text`` and ``--interactive``.
    """
    p = argparse.ArgumentParser(
        description="Document Classifier — Inference",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument(
        "--model",
        default="transformer",
        help="Which saved model to load: lr, svm, transformer, "
             "transformer_quantized, or a specific variant like "
             "distilbert_quantized (default: transformer)",
    )
    p.add_argument(
        "--checkpoint",
        type=str,
        default="distilbert-base-uncased",
        help="HuggingFace checkpoint name for transformer/transformer_quantized "
             "(default: distilbert-base-uncased)",
    )
    p.add_argument(
        "--text",
        type=str,
        default=None,
        help="Single text string to classify",
    )
    p.add_argument(
        "--interactive",
        action="store_true",
        help="Enter an interactive prediction loop",
    )
    return p


def main() -> None:
    """Command-line entry point.

    Parses arguments and resolves ``--model``. The name is matched
    case-insensitively, and variant names such as ``roberta_quantized`` also
    select the checkpoint. Loads the chosen transformer once, then classifies
    either a single ``--text`` or lines typed in ``--interactive`` mode.
    Exits with status 1, after printing usage, if neither is supplied.
    """
    parser = build_parser()
    args = parser.parse_args()

    # Validate the request before any (potentially slow) model loading.
    if not args.interactive and not args.text:
        print("  Error: provide --text <string> or use --interactive.\n")
        parser.print_help()
        sys.exit(1)

    # Normalise model selection and automatically extract checkpoint if specified
    model_lower = args.model.lower()
    is_quant = "quant" in model_lower or "int8" in model_lower

    # Check "distilbert" and "roberta" before "bert": both contain "bert".
    if "distilbert" in model_lower:
        args.checkpoint = "distilbert-base-uncased"
    elif "roberta" in model_lower:
        args.checkpoint = "roberta-base"
    elif "bert" in model_lower:
        args.checkpoint = "bert-base-uncased"

    if is_quant:
        args.model = "transformer_quantized"
    elif model_lower in ("lr", "svm"):
        # Use the lower-cased name so e.g. "--model LR" is not misrouted to
        # the transformer branch.
        args.model = model_lower
    else:
        # Any other name is treated as an FP32 transformer variant.
        args.model = "transformer"

    # Pre-load the model once (avoids reloading on every prediction in loops)
    cached_model      = None
    cached_tokenizer  = None
    cached_quantized  = False

    if args.model == "transformer":
        print(f"  Loading transformer model ({args.checkpoint}) …")
        cached_model, cached_tokenizer = trm.load_model(args.checkpoint)
        print("  Model ready.\n")

    elif args.model == "transformer_quantized":
        print(f"  Loading quantized model ({args.checkpoint}) …")
        cached_model, cached_tokenizer, cached_quantized = trm.load_quantized_model(
            args.checkpoint
        )
        tag = "INT8 quantized" if cached_quantized else "FP32 (INT8 not found, fell back)"
        print(f"  Model ready — {tag}.\n")

    def _predict(text: str) -> Dict:
        """Dispatch *text* to the predictor matching the resolved model."""
        if args.model == "transformer":
            return predict_transformer(text, cached_model, cached_tokenizer)
        if args.model == "transformer_quantized":
            return predict_transformer_quantized(
                text, cached_model, cached_tokenizer, is_quantized=cached_quantized
            )
        return predict_traditional(text, args.model)

    if args.interactive:
        print("  Document Classifier — Interactive Mode")
        print("  Type text and press Enter.  Type 'q' or 'quit' to exit.\n")
        while True:
            try:
                text = input("  >> ").strip()
            except (KeyboardInterrupt, EOFError):
                print("\n  Exiting.")
                break
            if not text:
                continue
            if text.lower() in {"q", "quit", "exit"}:
                print("  Goodbye.")
                break
            display_result(_predict(text))
    else:
        # args.text is guaranteed non-empty by the validation above.
        display_result(_predict(args.text))


# Run the CLI only when executed as a script; importing this module (e.g. to
# reuse the predict_* helpers) does not trigger argument parsing.
if __name__ == "__main__":
    main()