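"""Zero-shot evaluation of Gemma 4 E2B (2B) as a phone-scam classifier.

Runs on CPU or GPU: each call transcript is classified as SCAM or LEGITIMATE
without any fine-tuning, and accuracy / precision / recall / F1 are reported.
"""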
| import argparse |
| import json |
| import time |
| from pathlib import Path |
|
|
| import torch |
| from datasets import load_dataset |
| from transformers import AutoProcessor, Gemma4ForConditionalGeneration |
|
|
# System prompt: fixes the model's role for every classification request.
SYS = (
    "You are a phone scam detection expert. "
    "Your job is to read a call transcript and decide if it is a scam."
)
|
|
# User prompt template; ``{transcript}`` is filled in with ``str.format``.
# The closing instruction asks for a bare one-word label so the short
# generated answer can be mapped to SCAM / LEGITIMATE.
USER_TEMPLATE = (
    "Read this phone call transcript and classify it:\n\n"
    "{transcript}\n\n"
    "Answer with exactly ONE of these two words: SCAM or LEGITIMATE. "
    "Do not explain."
)
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the zero-shot evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script entry
            point relies on; passing a list allows programmatic use.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model``, ``dataset``,
        ``split``, ``limit``, ``dtype`` and ``out``.
    """
    parser = argparse.ArgumentParser(
        description="Zero-shot scam/legitimate classification benchmark.",
    )
    parser.add_argument(
        "--model", default="google/gemma-4-E2B-it",
        help="model identifier passed to from_pretrained",
    )
    parser.add_argument(
        "--dataset", default="BothBosu/scam-dialogue",
        help="dataset identifier passed to load_dataset",
    )
    parser.add_argument("--split", default="test", help="dataset split to use")
    parser.add_argument(
        "--limit", type=int, default=50,
        help="maximum number of examples; a negative value uses the whole split",
    )
    parser.add_argument(
        "--dtype", default="fp16", choices=["fp16", "fp32"],
        help="precision used when loading the model weights",
    )
    parser.add_argument(
        "--out", default="results_zero_shot.json",
        help="path of the JSON results file",
    )
    return parser.parse_args(argv)
|
|
|
|
def load_model(model_id: str, dtype: str):
    """Load the Gemma model and its processor, ready for inference.

    Args:
        model_id: Identifier handed to both ``from_pretrained`` calls.
        dtype: ``"fp16"`` selects ``torch.float16``; any other value selects
            ``torch.float32`` (the CLI only offers ``fp16`` and ``fp32``).

    Returns:
        A ``(model, processor)`` tuple; the model has been put in eval mode.
    """
    torch_dtype = torch.float16 if dtype == "fp16" else torch.float32
    print(f"Loading {model_id} (dtype={dtype}) …")

    model = Gemma4ForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        # Device placement is delegated to the library.
        device_map="auto",
        low_cpu_mem_usage=True,
    )
    processor = AutoProcessor.from_pretrained(model_id)
    model.eval()
    return model, processor
|
|
|
|
@torch.inference_mode()
def classify(model, processor, transcript: str) -> str:
    """Ask the model for a one-word verdict on a single call transcript.

    Args:
        model: Generation model as returned by ``load_model``.
        processor: Matching processor; supplies the chat template, the
            tokenizer's pad token id and ``batch_decode``.
        transcript: Raw text of the phone call to classify.

    Returns:
        The decoded continuation only (prompt tokens removed, special tokens
        skipped), stripped of surrounding whitespace and upper-cased.
    """
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYS}]},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": USER_TEMPLATE.format(transcript=transcript)},
            ],
        },
    ]
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=True,
    )
    # Move every input tensor to the device the model reports.
    inputs = {key: tensor.to(model.device) for key, tensor in inputs.items()}

    # Greedy decoding with a 5-token budget: only a short label is expected.
    gen_ids = model.generate(
        **inputs,
        max_new_tokens=5,
        do_sample=False,
        pad_token_id=processor.tokenizer.pad_token_id,
    )
    # generate() returns prompt + continuation; slice off the prompt part.
    prompt_len = inputs["input_ids"].shape[-1]
    new_ids = gen_ids[:, prompt_len:]
    decoded = processor.batch_decode(new_ids, skip_special_tokens=True)
    return decoded[0].strip().upper()
|
|
|
|
# Words that flip the meaning of the label that follows them.
_NEGATION_WORDS = frozenset({"NOT", "NO", "NON"})
# Whole words (besides anything starting with "LEGIT") that mean "legitimate".
_LEGIT_WORDS = frozenset({"SAFE", "NORMAL"})


def normalize(pred_raw: str) -> str:
    """Map the model's raw answer onto ``"SCAM"`` or ``"LEGITIMATE"``.

    Matching is done on whole words rather than substrings, and a preceding
    negation flips the label. This avoids misreadings such as
    "NOT A SCAM" -> SCAM, "UNKNOWN" -> LEGITIMATE (it contains "NO") or
    "UNSAFE" -> LEGITIMATE (it contains "SAFE").

    Args:
        pred_raw: Text produced by the model; case and punctuation are ignored.

    Returns:
        ``"SCAM"`` or ``"LEGITIMATE"`` when a label can be recognised. A bare
        negation ("NO", "NOT ...") is read as ``"LEGITIMATE"``. Anything else
        is returned unchanged so the caller can see the unparsed text.
    """
    # Treat every non-letter as a separator so decoration around the label
    # ("**SCAM**.", "NON-SCAM") does not hide it.
    words = "".join(
        ch if ch.isalpha() else " " for ch in pred_raw.upper()
    ).split()

    negated = False
    for word in words:
        if word in _NEGATION_WORDS:
            negated = True
        elif word.startswith("SCAM"):
            return "LEGITIMATE" if negated else "SCAM"
        elif word.startswith("LEGIT") or word in _LEGIT_WORDS:
            return "SCAM" if negated else "LEGITIMATE"

    # A lone "NO"/"NOT" answers the implicit question "is this a scam?".
    if negated:
        return "LEGITIMATE"
    return pred_raw
|
|
|
|
def compute_metrics(items):
    """Compute accuracy and SCAM-class precision/recall/F1 for a run.

    Args:
        items: Sequence of dicts with at least the keys ``"pred"`` and
            ``"gold"``. Only the values ``"SCAM"`` and ``"LEGITIMATE"`` are
            tallied in the confusion matrix.

    Returns:
        Dict with ``total``, ``accuracy``, ``precision_scam``, ``recall_scam``,
        ``f1_scam``, ``confusion`` (``TP``/``FP``/``FN``/``TN``) and
        ``unparsed``. ``unparsed`` counts items that fall outside the
        confusion matrix (e.g. a prediction that is neither label); they still
        count in ``total`` and therefore lower ``accuracy``. Every ratio is
        ``0.0`` when its denominator is zero, including for empty ``items``.
    """
    cell_for = {
        ("SCAM", "SCAM"): "TP",
        ("SCAM", "LEGITIMATE"): "FP",
        ("LEGITIMATE", "SCAM"): "FN",
        ("LEGITIMATE", "LEGITIMATE"): "TN",
    }
    confusion = {"TP": 0, "FP": 0, "FN": 0, "TN": 0}
    # Single pass instead of one scan of the list per confusion cell.
    for item in items:
        cell = cell_for.get((item["pred"], item["gold"]))
        if cell is not None:
            confusion[cell] += 1

    total = len(items)
    tp, fp, fn, tn = (confusion[key] for key in ("TP", "FP", "FN", "TN"))

    def ratio(numerator, denominator):
        # Undefined ratios are reported as 0.0 rather than raising.
        return numerator / denominator if denominator else 0.0

    accuracy = ratio(tp + tn, total)
    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)
    f1 = ratio(2 * precision * recall, precision + recall)
    return {
        "total": total,
        "accuracy": accuracy,
        "precision_scam": precision,
        "recall_scam": recall,
        "f1_scam": f1,
        "confusion": confusion,
        "unparsed": total - (tp + fp + fn + tn),
    }
|
|
|
|
def _print_report(args, model, n, elapsed, metrics):
    """Print the human-readable summary of one evaluation run."""
    confusion = metrics["confusion"]
    print("\n" + "=" * 60)
    print("ZERO-SHOT EVALUATION REPORT")
    print("=" * 60)
    print(f"Model : {args.model} (2B params)")
    print(f"Device : {next(model.parameters()).device}")
    print(f"Samples : {n}")
    print(f"Time : {elapsed:.1f}s ({metrics['throughput']:.2f} ex/s)")
    print(f"Accuracy : {metrics['accuracy']:.2%}")
    print(f"Precision : {metrics['precision_scam']:.2%}")
    print(f"Recall : {metrics['recall_scam']:.2%}")
    print(f"F1 (SCAM) : {metrics['f1_scam']:.2%}")
    print(f"Confusion : TP={confusion['TP']} FP={confusion['FP']} "
          f"FN={confusion['FN']} TN={confusion['TN']}")
    print("=" * 60)


def _print_verdict(metrics):
    """Print the PASS / MARGINAL / FAIL verdict from accuracy and SCAM F1."""
    acc, f1 = metrics["accuracy"], metrics["f1_scam"]
    if acc >= 0.90 and f1 >= 0.85:
        print("\n✅ PASS — 2B base model is accurate enough for phone deployment.")
    elif acc >= 0.75 and f1 >= 0.70:
        print("\n⚠️ MARGINAL — Fine-tune with Unsloth, then LiteRT-convert.")
    else:
        print("\n❌ FAIL — Fine-tune REQUIRED before mobile deployment.")


def main():
    """Run the zero-shot evaluation end to end.

    Parses the CLI options, loads the model and the dataset split, classifies
    up to ``--limit`` examples (all of them when the limit is negative),
    prints a per-example line and a summary report, writes arguments, metrics
    and per-example results to the ``--out`` JSON file, and prints a verdict.

    Raises:
        SystemExit: When there is nothing to evaluate (``--limit 0`` or an
            empty split); the metrics would otherwise divide by zero.
    """
    args = parse_args()
    model, processor = load_model(args.model, args.dtype)
    ds = load_dataset(args.dataset, split=args.split)
    n = len(ds) if args.limit < 0 else min(args.limit, len(ds))
    if n == 0:
        raise SystemExit(
            f"No examples to evaluate (split={args.split!r}, limit={args.limit})."
        )

    items = []
    # perf_counter is monotonic, so the measurement is immune to clock changes.
    t0 = time.perf_counter()
    for i in range(n):
        row = ds[i]
        # The dataset encodes scam calls as label 1.
        gold = "SCAM" if row["label"] == 1 else "LEGITIMATE"
        pred_raw = classify(model, processor, row["dialogue"])
        pred = normalize(pred_raw)
        correct = pred == gold
        items.append({"index": i, "gold": gold, "pred_raw": pred_raw,
                      "pred": pred, "correct": correct})
        mark = "✓" if correct else "✗"
        print(f"[{i+1:3}/{n}] gold={gold:11} pred='{pred_raw:15}' → {pred:11} {mark}")

    elapsed = time.perf_counter() - t0
    metrics = compute_metrics(items)
    metrics["time_sec"] = elapsed
    # Guard against a zero reading on coarse timers.
    metrics["throughput"] = n / elapsed if elapsed > 0 else 0.0

    _print_report(args, model, n, elapsed, metrics)

    out = {"args": vars(args), "metrics": metrics, "items": items}
    Path(args.out).write_text(json.dumps(out, indent=2), encoding="utf-8")
    print(f"Saved → {args.out}")

    _print_verdict(metrics)
|
|
|
|
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|