File size: 5,493 Bytes
0fc3f8b
 
8bdc6fb
0fc3f8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bdc6fb
 
 
0fc3f8b
 
 
8bdc6fb
0fc3f8b
8bdc6fb
 
0fc3f8b
 
 
8bdc6fb
0fc3f8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bdc6fb
0fc3f8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bdc6fb
0fc3f8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bdc6fb
0fc3f8b
 
8bdc6fb
0fc3f8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bdc6fb
0fc3f8b
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#!/usr/bin/env python3
"""
Zero-shot eval of Gemma 4 E2B (2B) on CPU/GPU.
"""
import argparse
import json
import time
from pathlib import Path

import torch
from datasets import load_dataset
from transformers import AutoProcessor, Gemma4ForConditionalGeneration

# System prompt: frames the model as a scam-detection specialist.
SYS = (
    "You are a phone scam detection expert. Your job is to read a call "
    "transcript and decide if it is a scam."
)

# User prompt: the `{transcript}` placeholder is filled in with str.format,
# and the model is asked to reply with a single label word.
USER_TEMPLATE = (
    "Read this phone call transcript and classify it:"
    "\n\n{transcript}\n\n"
    "Answer with exactly ONE of these two words: SCAM or LEGITIMATE."
    " Do not explain."
)


def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace with ``model``, ``dataset``, ``split``, ``limit``
        (int), ``dtype`` (``"fp16"`` or ``"fp32"``) and ``out``.
    """
    parser = argparse.ArgumentParser()

    # (flag, add_argument keyword options), registered in declaration order.
    option_table = (
        ("--model", {"default": "google/gemma-4-E2B-it"}),
        ("--dataset", {"default": "BothBosu/scam-dialogue"}),
        ("--split", {"default": "test"}),
        ("--limit", {"type": int, "default": 50}),
        ("--dtype", {"default": "fp16", "choices": ["fp16", "fp32"]}),
        ("--out", {"default": "results_zero_shot.json"}),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()


def load_model(model_id: str, dtype: str):
    """Load the generation model and its processor from ``model_id``.

    Args:
        model_id: Identifier passed to both ``from_pretrained`` calls.
        dtype: ``"fp16"`` selects ``torch.float16``; any other value selects
            ``torch.float32``.

    Returns:
        A ``(model, processor)`` tuple; the model has been switched to eval
        mode.
    """
    # Anything other than "fp16" falls back to full precision.
    precision = {"fp16": torch.float16}.get(dtype, torch.float32)
    print(f"Loading {model_id} (dtype={dtype}) …")

    load_options = {
        "torch_dtype": precision,
        "device_map": "auto",  # device placement is left to the library
        "low_cpu_mem_usage": True,
    }
    model = Gemma4ForConditionalGeneration.from_pretrained(model_id, **load_options)
    processor = AutoProcessor.from_pretrained(model_id)

    model.eval()
    return model, processor


@torch.inference_mode()
def classify(model, processor, transcript: str) -> str:
    """Ask the model to label one transcript and return its raw answer.

    Builds a system + user chat from ``SYS`` and ``USER_TEMPLATE``, generates
    at most 5 new tokens without sampling, and decodes only the newly
    generated tokens.

    Args:
        model: Model exposing ``device`` and ``generate``.
        processor: Processor exposing ``apply_chat_template``,
            ``batch_decode`` and ``tokenizer.pad_token_id``.
        transcript: Call transcript inserted into the user prompt.

    Returns:
        The decoded completion, stripped of surrounding whitespace and
        upper-cased.
    """

    def text_turn(role, text):
        # One chat message holding a single text part.
        return {"role": role, "content": [{"type": "text", "text": text}]}

    chat = [
        text_turn("system", SYS),
        text_turn("user", USER_TEMPLATE.format(transcript=transcript)),
    ]
    encoded = processor.apply_chat_template(
        chat,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=True,
    )

    # Move every encoded tensor onto the model's device.
    inputs = {}
    for name, tensor in encoded.items():
        inputs[name] = tensor.to(model.device)

    output_ids = model.generate(
        **inputs,
        max_new_tokens=5,
        do_sample=False,
        pad_token_id=processor.tokenizer.pad_token_id,
    )

    # The output starts with the prompt tokens; keep only what comes after.
    prompt_len = inputs["input_ids"].shape[-1]
    completion_ids = output_ids[:, prompt_len:]
    decoded = processor.batch_decode(completion_ids, skip_special_tokens=True)
    return decoded[0].strip().upper()


def normalize(pred_raw: str) -> str:
    """Map a raw model answer onto ``"SCAM"`` or ``"LEGITIMATE"``.

    Matching is done on whole words rather than substrings, so a word such
    as "UNKNOWN" is no longer read as "NO". A negation word ("NO", "NOT",
    "NON") seen before a label flips that label, so "NOT A SCAM" becomes
    LEGITIMATE and "NOT LEGITIMATE" becomes SCAM. Matching is
    case-insensitive and ignores punctuation.

    Args:
        pred_raw: Text produced by the model.

    Returns:
        ``"SCAM"`` or ``"LEGITIMATE"`` when a label can be recognised. A
        negation with no label word (e.g. a bare "NO") is treated as
        LEGITIMATE. Otherwise ``pred_raw`` is returned unchanged so the
        caller can see the unparsed answer.
    """
    negations = {"NO", "NOT", "NON"}
    legit_words = {"SAFE", "NORMAL"}

    # Replace every non-letter with a space so punctuation and markdown
    # (e.g. "**SCAM**", "NON-SCAM") do not glue onto the words.
    letters_only = "".join(ch if ch.isalpha() else " " for ch in pred_raw.upper())

    negated = False
    for word in letters_only.split():
        if word in negations:
            negated = True
        elif word.startswith("SCAM"):
            # Prefix match also covers "SCAMS", "SCAMMER", ...
            return "LEGITIMATE" if negated else "SCAM"
        elif word.startswith("LEGIT") or word in legit_words:
            return "SCAM" if negated else "LEGITIMATE"

    return "LEGITIMATE" if negated else pred_raw


def compute_metrics(items):
    """Compute binary classification metrics with SCAM as the positive class.

    Args:
        items: Sequence of dicts with ``"pred"`` and ``"gold"`` keys. Only
            the values ``"SCAM"`` and ``"LEGITIMATE"`` enter the confusion
            matrix; any other prediction still counts towards ``total`` and
            therefore lowers accuracy.

    Returns:
        Dict with ``total``, ``accuracy``, ``precision_scam``,
        ``recall_scam``, ``f1_scam`` and ``confusion`` (TP/FP/FN/TN counts).
        Every ratio is 0.0 when its denominator is zero, including the case
        of an empty ``items``.
    """
    # (pred, gold) -> confusion-matrix cell.
    cell_of = {
        ("SCAM", "SCAM"): "TP",
        ("SCAM", "LEGITIMATE"): "FP",
        ("LEGITIMATE", "SCAM"): "FN",
        ("LEGITIMATE", "LEGITIMATE"): "TN",
    }
    confusion = {"TP": 0, "FP": 0, "FN": 0, "TN": 0}
    for it in items:
        cell = cell_of.get((it["pred"], it["gold"]))
        if cell is not None:
            confusion[cell] += 1

    total = len(items)
    tp, fp, fn, tn = (confusion[k] for k in ("TP", "FP", "FN", "TN"))

    # Guard every division: an empty run must not raise ZeroDivisionError.
    accuracy = (tp + tn) / total if total else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0

    return {
        "total": total,
        "accuracy": accuracy,
        "precision_scam": precision,
        "recall_scam": recall,
        "f1_scam": f1,
        "confusion": confusion,
    }


def main():
    """Run the zero-shot evaluation end to end.

    Loads the model and the dataset split, classifies the first ``n`` rows,
    prints a per-row line and a summary report, writes arguments, metrics
    and per-item results as JSON to ``args.out``, and prints a
    PASS/MARGINAL/FAIL verdict based on accuracy and F1.

    Raises:
        SystemExit: If no samples are selected (empty split or
            ``--limit 0``), since the metrics would be meaningless.
    """
    args = parse_args()
    model, processor = load_model(args.model, args.dtype)
    ds = load_dataset(args.dataset, split=args.split)
    # A negative --limit means "evaluate the whole split".
    n = len(ds) if args.limit < 0 else min(args.limit, len(ds))
    if n == 0:
        # Exit cleanly rather than failing later with a ZeroDivisionError.
        raise SystemExit("No samples to evaluate: the split is empty or --limit is 0.")

    items = []
    # perf_counter is monotonic, so the interval is immune to clock changes.
    t0 = time.perf_counter()
    for i in range(n):
        row = ds[i]
        # assumes label 1 marks a scam dialogue; verify against the dataset card
        gold = "SCAM" if row["label"] == 1 else "LEGITIMATE"
        pred_raw = classify(model, processor, row["dialogue"])
        pred = normalize(pred_raw)
        correct = pred == gold
        items.append({"index": i, "gold": gold, "pred_raw": pred_raw,
                      "pred": pred, "correct": correct})
        mark = "✓" if correct else "✗"
        print(f"[{i+1:3}/{n}] gold={gold:11} pred='{pred_raw:15}' → {pred:11} {mark}")

    elapsed = time.perf_counter() - t0
    metrics = compute_metrics(items)
    metrics["time_sec"] = elapsed
    # Guard against a zero-length interval on coarse clocks.
    metrics["throughput"] = n / elapsed if elapsed > 0 else 0.0

    print("\n" + "=" * 60)
    print("ZERO-SHOT EVALUATION REPORT")
    print("=" * 60)
    print(f"Model      : {args.model} (2B params)")
    print(f"Device     : {next(model.parameters()).device}")
    print(f"Samples    : {n}")
    print(f"Time       : {elapsed:.1f}s ({metrics['throughput']:.2f} ex/s)")
    print(f"Accuracy   : {metrics['accuracy']:.2%}")
    print(f"Precision  : {metrics['precision_scam']:.2%}")
    print(f"Recall     : {metrics['recall_scam']:.2%}")
    print(f"F1 (SCAM)  : {metrics['f1_scam']:.2%}")
    print(f"Confusion  : TP={metrics['confusion']['TP']} FP={metrics['confusion']['FP']} "
          f"FN={metrics['confusion']['FN']} TN={metrics['confusion']['TN']}")
    print("=" * 60)

    out = {"args": vars(args), "metrics": metrics, "items": items}
    Path(args.out).write_text(json.dumps(out, indent=2), encoding="utf-8")
    print(f"Saved → {args.out}")

    # Deployment verdict thresholds on accuracy and SCAM-class F1.
    acc, f1 = metrics["accuracy"], metrics["f1_scam"]
    if acc >= 0.90 and f1 >= 0.85:
        print("\n✅ PASS — 2B base model is accurate enough for phone deployment.")
    elif acc >= 0.75 and f1 >= 0.70:
        print("\n⚠️ MARGINAL — Fine-tune with Unsloth, then LiteRT-convert.")
    else:
        print("\n❌ FAIL — Fine-tune REQUIRED before mobile deployment.")


if __name__ == "__main__":
    main()