grandgemma-eval / eval_zero_shot_cpu.py
s23deepak's picture
Upload eval_zero_shot_cpu.py
8bdc6fb verified
#!/usr/bin/env python3
"""
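Zero-shot evaluation of Gemma 4 E2B (2B) as a phone-scam classifier, on CPU or GPU.

Classifies call transcripts from a Hugging Face dataset as SCAM or LEGITIMATE,
prints accuracy / precision / recall / F1 for the SCAM class, and saves the
per-sample predictions plus metrics to a JSON file.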
"""
import argparse
import json
import time
from pathlib import Path
import torch
from datasets import load_dataset
from transformers import AutoProcessor, Gemma4ForConditionalGeneration
# System prompt: frames the model as a scam-detection expert.
SYS = " ".join((
    "You are a phone scam detection expert.",
    "Your job is to read a call transcript and decide if it is a scam.",
))

# User prompt: the `{transcript}` placeholder is filled in per sample; the
# final paragraph constrains the reply to a single label word.
USER_TEMPLATE = "\n\n".join((
    "Read this phone call transcript and classify it:",
    "{transcript}",
    "Answer with exactly ONE of these two words: SCAM or LEGITIMATE. Do not explain.",
))
def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, so existing ``parse_args()``
            callers behave exactly as before. Passing a list makes the
            function usable from tests and other programs.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model``, ``dataset``,
        ``split``, ``limit``, ``dtype`` and ``out``.
    """
    parser = argparse.ArgumentParser(
        description="Zero-shot scam-call classification evaluation.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--model", default="google/gemma-4-E2B-it",
                        help="model id passed to from_pretrained")
    parser.add_argument("--dataset", default="BothBosu/scam-dialogue",
                        help="dataset id passed to load_dataset")
    parser.add_argument("--split", default="test",
                        help="dataset split to evaluate")
    parser.add_argument("--limit", type=int, default=50,
                        help="maximum number of samples; a negative value "
                             "evaluates the whole split")
    parser.add_argument("--dtype", default="fp16", choices=["fp16", "fp32"],
                        help="precision used to load the model weights")
    parser.add_argument("--out", default="results_zero_shot.json",
                        help="path of the JSON results file")
    return parser.parse_args(argv)
def load_model(model_id: str, dtype: str):
    """Load the model and its processor and switch the model to eval mode.

    Args:
        model_id: Identifier handed to both ``from_pretrained`` calls.
        dtype: ``"fp16"`` selects ``torch.float16``; any other value selects
            ``torch.float32``.

    Returns:
        A ``(model, processor)`` tuple.
    """
    # NOTE(review): fp16 is the script's default even on CPU-only hosts —
    # confirm half precision is acceptable there, otherwise pass fp32.
    weights_dtype = torch.float32
    if dtype == "fp16":
        weights_dtype = torch.float16

    print(f"Loading {model_id} (dtype={dtype}) …")

    load_kwargs = {
        "torch_dtype": weights_dtype,
        "device_map": "auto",
        "low_cpu_mem_usage": True,
    }
    model = Gemma4ForConditionalGeneration.from_pretrained(model_id, **load_kwargs)
    processor = AutoProcessor.from_pretrained(model_id)

    model.eval()
    return model, processor
@torch.inference_mode()
def classify(model, processor, transcript: str) -> str:
    """Ask the model to label one transcript and return its raw reply.

    Args:
        model: Generation model exposing ``device`` and ``generate``.
        processor: Processor providing ``apply_chat_template``,
            ``batch_decode`` and ``tokenizer.pad_token_id``.
        transcript: Call transcript inserted into ``USER_TEMPLATE``.

    Returns:
        The newly generated text (prompt tokens removed), stripped of
        surrounding whitespace and upper-cased.
    """
    user_prompt = USER_TEMPLATE.format(transcript=transcript)
    chat = [
        {"role": "system", "content": [{"type": "text", "text": SYS}]},
        {"role": "user", "content": [{"type": "text", "text": user_prompt}]},
    ]

    encoded = processor.apply_chat_template(
        chat,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=True,
    )
    # Move every input tensor onto the device the model lives on.
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    prompt_len = encoded["input_ids"].shape[-1]

    # Greedy decoding with a tiny budget: only a one-word label is expected.
    output_ids = model.generate(
        **encoded,
        max_new_tokens=5,
        do_sample=False,
        pad_token_id=processor.tokenizer.pad_token_id,
    )

    # generate() returns prompt + completion; keep only the completion.
    completion_ids = output_ids[:, prompt_len:]
    decoded = processor.batch_decode(completion_ids, skip_special_tokens=True)
    return decoded[0].strip().upper()
# Words that flip the meaning of the label keyword that follows them.
_NEGATION_WORDS = frozenset({"NO", "NOT", "NON", "NEVER"})
# Whole words (besides anything starting with "LEGIT") meaning "not a scam".
_LEGIT_WORDS = frozenset({"SAFE", "NORMAL"})


def normalize(pred_raw: str) -> str:
    """Map a raw model reply onto ``"SCAM"`` or ``"LEGITIMATE"``.

    The reply is split into words (case-insensitively, punctuation ignored)
    and scanned left to right. The first label keyword decides the result,
    inverted when a negation word appeared before it, so ``"NOT A SCAM"``
    becomes ``"LEGITIMATE"`` and ``"NOT LEGITIMATE"`` becomes ``"SCAM"``.
    Matching whole words rather than substrings keeps replies such as
    ``"UNKNOWN"`` from being read as "NO".

    Args:
        pred_raw: Text produced by the model.

    Returns:
        ``"SCAM"`` or ``"LEGITIMATE"`` when the reply can be interpreted;
        otherwise ``pred_raw`` unchanged, so the caller can see (and score as
        wrong) the unparseable answer.
    """
    # Unify curly apostrophes so contractions like "ISN’T" are recognised.
    text = pred_raw.upper().replace("\u2019", "'")
    # Keep letters and apostrophes; everything else becomes a word separator.
    cleaned = "".join(ch if ch.isalpha() or ch == "'" else " " for ch in text)

    negated = False
    for token in cleaned.split():
        word = token.strip("'")  # drop quoting such as 'SCAM'
        if word in _NEGATION_WORDS or word.endswith("N'T"):
            negated = True
        elif word.startswith("SCAM"):
            return "LEGITIMATE" if negated else "SCAM"
        elif word.startswith("LEGIT") or word in _LEGIT_WORDS:
            return "SCAM" if negated else "LEGITIMATE"

    # A bare "NO" / "NOT" with no label keyword is read as "not a scam".
    if negated:
        return "LEGITIMATE"
    return pred_raw
def compute_metrics(items):
    """Compute binary-classification metrics with SCAM as the positive class.

    Args:
        items: Sequence of dicts, each holding a ``"pred"`` and a ``"gold"``
            string.

    Returns:
        A dict with ``total``, ``accuracy``, ``precision_scam``,
        ``recall_scam``, ``f1_scam``, the ``confusion`` counts
        (``TP``/``FP``/``FN``/``TN``) and ``unscored``: the number of items
        whose (pred, gold) pair is not one of the four SCAM/LEGITIMATE
        combinations (for example an unparseable model reply). Unscored items
        count against accuracy but do not enter precision or recall.
        Every ratio is 0.0 when its denominator is zero, so an empty ``items``
        no longer raises ``ZeroDivisionError``.
    """
    cell_of = {
        ("SCAM", "SCAM"): "TP",
        ("SCAM", "LEGITIMATE"): "FP",
        ("LEGITIMATE", "SCAM"): "FN",
        ("LEGITIMATE", "LEGITIMATE"): "TN",
    }
    confusion = {"TP": 0, "FP": 0, "FN": 0, "TN": 0}

    # Single pass: drop each item into its confusion-matrix cell, if any.
    for item in items:
        cell = cell_of.get((item["pred"], item["gold"]))
        if cell is not None:
            confusion[cell] += 1

    total = len(items)
    tp, fp, fn, tn = (confusion[key] for key in ("TP", "FP", "FN", "TN"))

    accuracy = (tp + tn) / total if total else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    if precision + recall:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0

    return {
        "total": total,
        "accuracy": accuracy,
        "precision_scam": precision,
        "recall_scam": recall,
        "f1_scam": f1,
        "confusion": confusion,
        "unscored": total - (tp + fp + fn + tn),
    }
def main():
    """Run the zero-shot evaluation end to end.

    Steps: parse the CLI options, load the dataset split and the model,
    classify up to ``--limit`` rows (all rows when the limit is negative),
    print a per-sample line and a summary report, write args, metrics and
    per-sample items to the ``--out`` JSON file, and print a PASS / MARGINAL /
    FAIL verdict based on accuracy and SCAM F1.

    Each dataset row is read through its ``"dialogue"`` (transcript) and
    ``"label"`` fields; a label equal to 1 is treated as SCAM, anything else
    as LEGITIMATE.

    Raises:
        SystemExit: If there are no samples to evaluate (empty split or
            ``--limit 0``), since no metric would be meaningful.
    """
    args = parse_args()

    # Load the data first so an empty selection fails before the (slow)
    # model load instead of crashing later inside the metric computation.
    ds = load_dataset(args.dataset, split=args.split)
    n = len(ds) if args.limit < 0 else min(args.limit, len(ds))
    if n == 0:
        raise SystemExit(
            f"No samples to evaluate (dataset={args.dataset}, "
            f"split={args.split}, limit={args.limit})."
        )

    model, processor = load_model(args.model, args.dtype)

    items = []
    # perf_counter is monotonic, unlike time.time(), so the elapsed time
    # cannot go negative or jump when the wall clock is adjusted.
    t0 = time.perf_counter()
    for i in range(n):
        row = ds[i]
        gold = "SCAM" if row["label"] == 1 else "LEGITIMATE"
        pred_raw = classify(model, processor, row["dialogue"])
        pred = normalize(pred_raw)
        correct = pred == gold
        items.append({"index": i, "gold": gold, "pred_raw": pred_raw,
                      "pred": pred, "correct": correct})
        mark = "✓" if correct else "✗"
        print(f"[{i+1:3}/{n}] gold={gold:11} pred='{pred_raw:15}' → {pred:11} {mark}")
    elapsed = time.perf_counter() - t0

    metrics = compute_metrics(items)
    metrics["time_sec"] = elapsed
    # Guard against a zero reading from a coarse timer.
    metrics["throughput"] = n / elapsed if elapsed > 0 else 0.0

    confusion = metrics["confusion"]
    print("\n" + "=" * 60)
    print("ZERO-SHOT EVALUATION REPORT")
    print("=" * 60)
    print(f"Model : {args.model} (2B params)")
    print(f"Device : {next(model.parameters()).device}")
    print(f"Samples : {n}")
    print(f"Time : {elapsed:.1f}s ({metrics['throughput']:.2f} ex/s)")
    print(f"Accuracy : {metrics['accuracy']:.2%}")
    print(f"Precision : {metrics['precision_scam']:.2%}")
    print(f"Recall : {metrics['recall_scam']:.2%}")
    print(f"F1 (SCAM) : {metrics['f1_scam']:.2%}")
    print(f"Confusion : TP={confusion['TP']} FP={confusion['FP']} "
          f"FN={confusion['FN']} TN={confusion['TN']}")
    print("=" * 60)

    out = {"args": vars(args), "metrics": metrics, "items": items}
    Path(args.out).write_text(json.dumps(out, indent=2), encoding="utf-8")
    print(f"Saved → {args.out}")

    # Deployment verdict thresholds on accuracy and SCAM-class F1.
    acc, f1 = metrics["accuracy"], metrics["f1_scam"]
    if acc >= 0.90 and f1 >= 0.85:
        print("\n✅ PASS — 2B base model is accurate enough for phone deployment.")
    elif acc >= 0.75 and f1 >= 0.70:
        print("\n⚠️ MARGINAL — Fine-tune with Unsloth, then LiteRT-convert.")
    else:
        print("\n❌ FAIL — Fine-tune REQUIRED before mobile deployment.")
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()