#!/usr/bin/env python3
"""Zero-shot eval of Gemma 4 E2B (2B) on CPU/GPU.

Loads a chat model, asks it to label each call transcript of a dataset split
as SCAM or LEGITIMATE, prints a per-example log and a summary report, and
writes the arguments, metrics and per-example results to a JSON file.
"""
import argparse
import json
import re
import time
from pathlib import Path

import torch
from datasets import load_dataset
from transformers import AutoProcessor, Gemma4ForConditionalGeneration

SYS = ("You are a phone scam detection expert. "
       "Your job is to read a call transcript and decide if it is a scam.")

USER_TEMPLATE = (
    "Read this phone call transcript and classify it:\n\n"
    "{transcript}\n\n"
    "Answer with exactly ONE of these two words: SCAM or LEGITIMATE. "
    "Do not explain."
)

# Vocabulary used by normalize() to map free-form model output to a label.
_WORD_RE = re.compile(r"[A-Z']+")
_NEGATORS = frozenset({"NOT", "NO", "NON", "NEVER"})
_LEGIT_WORDS = frozenset({"SAFE", "NORMAL"})


def parse_args():
    """Parse the command-line options of the evaluation script."""
    p = argparse.ArgumentParser()
    p.add_argument("--model", default="google/gemma-4-E2B-it",
                   help="model id passed to from_pretrained")
    p.add_argument("--dataset", default="BothBosu/scam-dialogue",
                   help="dataset id passed to load_dataset")
    p.add_argument("--split", default="test", help="dataset split to evaluate")
    p.add_argument("--limit", type=int, default=50,
                   help="max number of examples; negative means the whole split")
    p.add_argument("--dtype", default="fp16", choices=["fp16", "fp32"],
                   help="weight precision used when loading the model")
    p.add_argument("--out", default="results_zero_shot.json",
                   help="path of the JSON results file")
    return p.parse_args()


def load_model(model_id: str, dtype: str):
    """Load the model and its processor and switch the model to eval mode.

    Args:
        model_id: identifier handed to ``from_pretrained``.
        dtype: ``"fp16"`` selects ``torch.float16``; anything else selects
            ``torch.float32``.

    Returns:
        A ``(model, processor)`` tuple.
    """
    torch_dtype = torch.float16 if dtype == "fp16" else torch.float32
    print(f"Loading {model_id} (dtype={dtype}) …")
    model = Gemma4ForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        device_map="auto",
        low_cpu_mem_usage=True,
    )
    processor = AutoProcessor.from_pretrained(model_id)
    model.eval()
    return model, processor


@torch.inference_mode()
def classify(model, processor, transcript: str) -> str:
    """Ask the model to label one transcript and return its raw answer.

    The system and user prompts are rendered through the processor's chat
    template, at most 5 new tokens are generated greedily, and only the newly
    generated tokens are decoded.

    Returns:
        The decoded continuation, stripped and upper-cased.
    """
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYS}]},
        {"role": "user", "content": [
            {"type": "text",
             "text": USER_TEMPLATE.format(transcript=transcript)},
        ]},
    ]
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=True,
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    gen_ids = model.generate(
        **inputs,
        max_new_tokens=5,
        do_sample=False,
        pad_token_id=processor.tokenizer.pad_token_id,
    )
    # generate() returns prompt + continuation; keep only the continuation.
    new_ids = gen_ids[:, inputs["input_ids"].shape[-1]:]
    decoded = processor.batch_decode(new_ids, skip_special_tokens=True)[0]
    return decoded.strip().upper()


def normalize(pred_raw: str) -> str:
    """Map a raw model answer to ``"SCAM"`` or ``"LEGITIMATE"``.

    The answer is scanned word by word (case-insensitively). The first label
    word decides the result, and a preceding negation flips it, so
    ``"NOT A SCAM"`` is LEGITIMATE and ``"NOT SAFE"`` is SCAM. Matching whole
    words keeps ``"NO"`` from firing inside unrelated words such as
    ``"UNKNOWN"``. A bare negation (``"NO"``, ``"NOT"``) with no label word is
    read as LEGITIMATE.

    Returns:
        The normalized label, or ``pred_raw`` unchanged when nothing in the
        answer can be interpreted.
    """
    negated = False
    for raw_word in _WORD_RE.findall(pred_raw.upper()):
        word = raw_word.strip("'")  # drop quote marks around a quoted label
        if word in _NEGATORS or word.endswith("N'T"):
            negated = True
        elif word.startswith("SCAM"):
            return "LEGITIMATE" if negated else "SCAM"
        elif word.startswith("LEGIT") or word in _LEGIT_WORDS:
            return "SCAM" if negated else "LEGITIMATE"
    if negated:
        return "LEGITIMATE"
    return pred_raw


def compute_metrics(items):
    """Compute accuracy and SCAM-class precision/recall/F1.

    Args:
        items: sequence of dicts carrying ``"gold"`` and ``"pred"`` labels.

    Returns:
        A dict with ``total``, ``accuracy``, ``precision_scam``,
        ``recall_scam``, ``f1_scam``, ``unparsed`` (predictions that are
        neither SCAM nor LEGITIMATE) and a ``confusion`` dict. All ratios are
        ``0.0`` when their denominator is zero, including for empty ``items``.
    """
    total = len(items)
    tp = fp = fn = tn = unparsed = 0
    for it in items:
        gold, pred = it["gold"], it["pred"]
        if pred not in ("SCAM", "LEGITIMATE"):
            unparsed += 1
        if pred == "SCAM":
            if gold == "SCAM":
                tp += 1
            elif gold == "LEGITIMATE":
                fp += 1
        elif gold == "SCAM":
            # Any non-SCAM answer on a scam call is a missed scam, even an
            # unparseable one; otherwise recall would be overstated.
            fn += 1
        elif pred == "LEGITIMATE" and gold == "LEGITIMATE":
            tn += 1

    accuracy = (tp + tn) / total if total else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) else 0.0)
    return {
        "total": total,
        "accuracy": accuracy,
        "precision_scam": precision,
        "recall_scam": recall,
        "f1_scam": f1,
        "unparsed": unparsed,
        "confusion": {"TP": tp, "FP": fp, "FN": fn, "TN": tn},
    }


def main():
    """Run the evaluation end to end: load, classify, report, save, verdict."""
    args = parse_args()
    model, processor = load_model(args.model, args.dtype)
    ds = load_dataset(args.dataset, split=args.split)
    # A negative --limit means "evaluate the whole split".
    n = len(ds) if args.limit < 0 else min(args.limit, len(ds))

    items = []
    t0 = time.perf_counter()
    for i, row in enumerate(map(ds.__getitem__, range(n))):
        gold = "SCAM" if row["label"] == 1 else "LEGITIMATE"
        pred_raw = classify(model, processor, row["dialogue"])
        pred = normalize(pred_raw)
        correct = pred == gold
        items.append({"index": i, "gold": gold, "pred_raw": pred_raw,
                      "pred": pred, "correct": correct})
        mark = "✓" if correct else "✗"
        print(f"[{i+1:3}/{n}] gold={gold:11} pred='{pred_raw:15}' → {pred:11} {mark}")
    elapsed = time.perf_counter() - t0

    metrics = compute_metrics(items)
    metrics["time_sec"] = elapsed
    metrics["throughput"] = n / elapsed if elapsed > 0 else 0.0

    print("\n" + "=" * 60)
    print("ZERO-SHOT EVALUATION REPORT")
    print("=" * 60)
    print(f"Model : {args.model} (2B params)")
    print(f"Device : {next(model.parameters()).device}")
    print(f"Samples : {n}")
    print(f"Time : {elapsed:.1f}s ({metrics['throughput']:.2f} ex/s)")
    print(f"Accuracy : {metrics['accuracy']:.2%}")
    print(f"Precision : {metrics['precision_scam']:.2%}")
    print(f"Recall : {metrics['recall_scam']:.2%}")
    print(f"F1 (SCAM) : {metrics['f1_scam']:.2%}")
    print(f"Confusion : TP={metrics['confusion']['TP']} FP={metrics['confusion']['FP']} "
          f"FN={metrics['confusion']['FN']} TN={metrics['confusion']['TN']}")
    print("=" * 60)

    out = {"args": vars(args), "metrics": metrics, "items": items}
    Path(args.out).write_text(json.dumps(out, indent=2), encoding="utf-8")
    print(f"Saved → {args.out}")

    acc, f1 = metrics["accuracy"], metrics["f1_scam"]
    if acc >= 0.90 and f1 >= 0.85:
        print("\n✅ PASS — 2B base model is accurate enough for phone deployment.")
    elif acc >= 0.75 and f1 >= 0.70:
        print("\n⚠️ MARGINAL — Fine-tune with Unsloth, then LiteRT-convert.")
    else:
        print("\n❌ FAIL — Fine-tune REQUIRED before mobile deployment.")


if __name__ == "__main__":
    main()