s23deepak commited on
Commit
0fc3f8b
·
verified ·
1 Parent(s): 0f23a85

Upload eval_zero_shot_cpu.py

Browse files
Files changed (1) hide show
  1. eval_zero_shot_cpu.py +170 -0
eval_zero_shot_cpu.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Zero-shot eval of Gemma 4 E2B (2B) on CPU — no GPU needed.
4
+ Use this to test the smallest Gemma 4 before any mobile conversion.
5
+
6
+ REQUIREMENTS:
7
+ pip install transformers datasets torch huggingface_hub
8
+
9
+ USAGE:
10
+ # Quick test (20 samples, ~2-3 min on laptop CPU)
11
+ python eval_zero_shot_cpu.py --limit 20
12
+
13
+ # Full test split (~400 samples, ~30-45 min on CPU)
14
+ python eval_zero_shot_cpu.py --limit -1
15
+
16
+ MODEL SIZE:
17
+ gemma-4-E2B-it = 2B params
18
+ FP32 on CPU RAM: ~8 GB peak
19
+ Use --dtype fp16 to halve RAM to ~4 GB if your CPU supports it.
20
+ """
21
+
22
+ import argparse
23
+ import json
24
+ import time
25
+ from pathlib import Path
26
+
27
+ import torch
28
+ from datasets import load_dataset
29
+ from transformers import AutoProcessor, Gemma4ForConditionalGeneration
30
+
31
+ SYS = ("You are a phone scam detection expert. "
32
+ "Your job is to read a call transcript and decide if it is a scam.")
33
+
34
+ USER_TEMPLATE = (
35
+ "Read this phone call transcript and classify it:\n\n"
36
+ "{transcript}\n\n"
37
+ "Answer with exactly ONE of these two words: SCAM or LEGITIMATE. "
38
+ "Do not explain."
39
+ )
40
+
41
+
42
+ def parse_args():
43
+ p = argparse.ArgumentParser()
44
+ p.add_argument("--model", default="google/gemma-4-E2B-it")
45
+ p.add_argument("--dataset", default="BothBosu/scam-dialogue")
46
+ p.add_argument("--split", default="test")
47
+ p.add_argument("--limit", type=int, default=20,
48
+ help="Max rows (-1 = all). Default 20 for quick CPU test.")
49
+ p.add_argument("--dtype", default="fp32", choices=["fp16","fp32"],
50
+ help="fp16 = half RAM (~4 GB), fp32 = ~8 GB")
51
+ p.add_argument("--out", default="results_zero_shot_cpu.json")
52
+ return p.parse_args()
53
+
54
+
55
+ def load_model_cpu(model_id: str, dtype: str):
56
+ torch_dtype = torch.float16 if dtype == "fp16" else torch.float32
57
+ print(f"Loading {model_id} on CPU (dtype={dtype}) …")
58
+ print(f" Expected RAM: ~{4 if dtype == 'fp16' else 8} GB")
59
+ print(f" If OOM: close browser tabs, reduce --limit, or use fp16.\n")
60
+
61
+ model = Gemma4ForConditionalGeneration.from_pretrained(
62
+ model_id,
63
+ torch_dtype=torch_dtype,
64
+ device_map=None, # force CPU
65
+ low_cpu_mem_usage=True,
66
+ )
67
+ model = model.to("cpu")
68
+ processor = AutoProcessor.from_pretrained(model_id)
69
+ model.eval()
70
+ return model, processor
71
+
72
+
73
+ @torch.inference_mode()
74
+ def classify(model, processor, transcript: str) -> str:
75
+ messages = [
76
+ {"role": "system", "content": [{"type": "text", "text": SYS}]},
77
+ {"role": "user", "content": [{"type": "text", "text": USER_TEMPLATE.format(transcript=transcript)}]},
78
+ ]
79
+ inputs = processor.apply_chat_template(
80
+ messages, tokenize=True, return_dict=True,
81
+ return_tensors="pt", add_generation_prompt=True,
82
+ )
83
+ inputs = {k: v.to("cpu") for k, v in inputs.items()}
84
+
85
+ gen_ids = model.generate(
86
+ **inputs, max_new_tokens=5, do_sample=False,
87
+ pad_token_id=processor.tokenizer.pad_token_id,
88
+ )
89
+ new_ids = gen_ids[:, inputs["input_ids"].shape[-1]:]
90
+ return processor.batch_decode(new_ids, skip_special_tokens=True)[0].strip().upper()
91
+
92
+
93
+ def normalize(pred_raw: str) -> str:
94
+ if "SCAM" in pred_raw:
95
+ return "SCAM"
96
+ if any(w in pred_raw for w in ["LEGIT", "NOT", "SAFE", "NO", "NORMAL"]):
97
+ return "LEGITIMATE"
98
+ return pred_raw
99
+
100
+
101
+ def compute_metrics(items):
102
+ total = len(items)
103
+ tp = sum(1 for it in items if it["pred"] == "SCAM" and it["gold"] == "SCAM")
104
+ fp = sum(1 for it in items if it["pred"] == "SCAM" and it["gold"] == "LEGITIMATE")
105
+ fn = sum(1 for it in items if it["pred"] == "LEGITIMATE" and it["gold"] == "SCAM")
106
+ tn = sum(1 for it in items if it["pred"] == "LEGITIMATE" and it["gold"] == "LEGITIMATE")
107
+ accuracy = (tp + tn) / total
108
+ precision = tp / (tp + fp) if (tp + fp) else 0
109
+ recall = tp / (tp + fn) if (tp + fn) else 0
110
+ f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0
111
+ return {"total": total, "accuracy": accuracy, "precision_scam": precision,
112
+ "recall_scam": recall, "f1_scam": f1,
113
+ "confusion": {"TP": tp, "FP": fp, "FN": fn, "TN": tn}}
114
+
115
+
116
+ def main():
117
+ args = parse_args()
118
+ model, processor = load_model_cpu(args.model, args.dtype)
119
+ ds = load_dataset(args.dataset, split=args.split)
120
+ n = len(ds) if args.limit < 0 else min(args.limit, len(ds))
121
+
122
+ items = []
123
+ t0 = time.time()
124
+ for i in range(n):
125
+ row = ds[i]
126
+ gold = "SCAM" if row["label"] == 1 else "LEGITIMATE"
127
+ pred_raw = classify(model, processor, row["dialogue"])
128
+ pred = normalize(pred_raw)
129
+ correct = pred == gold
130
+ items.append({"index": i, "gold": gold, "pred_raw": pred_raw,
131
+ "pred": pred, "correct": correct})
132
+ mark = "✓" if correct else "✗"
133
+ print(f"[{i+1:3}/{n}] gold={gold:11} pred='{pred_raw:15}' → {pred:11} {mark}")
134
+
135
+ elapsed = time.time() - t0
136
+ metrics = compute_metrics(items)
137
+ metrics["time_sec"] = elapsed
138
+ metrics["throughput"] = n / elapsed
139
+
140
+ print("\n" + "=" * 60)
141
+ print("CPU ZERO-SHOT REPORT")
142
+ print("=" * 60)
143
+ print(f"Model : {args.model} (2B params)")
144
+ print(f"Device : CPU")
145
+ print(f"Samples : {n}")
146
+ print(f"Time : {elapsed:.1f}s ({metrics['throughput']:.2f} ex/s)")
147
+ print(f"Accuracy : {metrics['accuracy']:.2%}")
148
+ print(f"Precision : {metrics['precision_scam']:.2%}")
149
+ print(f"Recall : {metrics['recall_scam']:.2%}")
150
+ print(f"F1 (SCAM) : {metrics['f1_scam']:.2%}")
151
+ print(f"Confusion : TP={metrics['confusion']['TP']} FP={metrics['confusion']['FP']} "
152
+ f"FN={metrics['confusion']['FN']} TN={metrics['confusion']['TN']}")
153
+ print("=" * 60)
154
+
155
+ out = {"args": vars(args), "metrics": metrics, "items": items}
156
+ Path(args.out).write_text(json.dumps(out, indent=2))
157
+ print(f"Saved → {args.out}")
158
+
159
+ acc, f1 = metrics["accuracy"], metrics["f1_scam"]
160
+ if acc >= 0.90 and f1 >= 0.85:
161
+ print("\n✅ PASS — 2B base model is accurate enough.")
162
+ print(" Next: convert to LiteRT 4-bit for phone (~1.5 GB RAM).")
163
+ elif acc >= 0.75 and f1 >= 0.70:
164
+ print("\n⚠️ MARGINAL — Fine-tune with Unsloth, then LiteRT-convert.")
165
+ else:
166
+ print("\n❌ FAIL — Fine-tune REQUIRED before mobile deployment.")
167
+
168
+
169
+ if __name__ == "__main__":
170
+ main()