#!/usr/bin/env python3
"""
Zero-shot eval of Gemma 4 E2B (2B) on CPU/GPU.
"""
import argparse
import json
import time
from pathlib import Path
import torch
from datasets import load_dataset
from transformers import AutoProcessor, Gemma4ForConditionalGeneration
# System prompt: frames the model as a phone-scam classifier.
SYS = ("You are a phone scam detection expert. "
"Your job is to read a call transcript and decide if it is a scam.")
# User prompt template; `{transcript}` is filled in with str.format() by
# classify(). It asks for a one-word answer: SCAM or LEGITIMATE.
USER_TEMPLATE = (
"Read this phone call transcript and classify it:\n\n"
"{transcript}\n\n"
"Answer with exactly ONE of these two words: SCAM or LEGITIMATE. "
"Do not explain."
)
def parse_args():
    """Parse the command-line options for one evaluation run.

    Returns:
        argparse.Namespace with the attributes ``model``, ``dataset``,
        ``split``, ``limit``, ``dtype`` and ``out``.
    """
    # ArgumentDefaultsHelpFormatter appends each option's default to --help,
    # so the help output stays in sync with the values below.
    p = argparse.ArgumentParser(
        description="Zero-shot scam/legitimate classification of call transcripts.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("--model", default="google/gemma-4-E2B-it",
                   help="Model id or local path passed to from_pretrained().")
    p.add_argument("--dataset", default="BothBosu/scam-dialogue",
                   help="Dataset name passed to datasets.load_dataset().")
    p.add_argument("--split", default="test",
                   help="Dataset split to evaluate.")
    p.add_argument("--limit", type=int, default=50,
                   help="Maximum number of rows to evaluate; "
                        "a negative value evaluates the whole split.")
    p.add_argument("--dtype", default="fp16", choices=["fp16", "fp32"],
                   help="Floating-point precision used to load the model weights.")
    p.add_argument("--out", default="results_zero_shot.json",
                   help="Path of the JSON report written at the end of the run.")
    return p.parse_args()
def load_model(model_id: str, dtype: str):
    """Load the checkpoint and its processor, ready for inference.

    Args:
        model_id: Identifier handed to both ``from_pretrained`` calls.
        dtype: ``"fp16"`` selects ``torch.float16``; any other value selects
            ``torch.float32``.

    Returns:
        A ``(model, processor)`` tuple; the model has been put in eval mode.
    """
    precision = {"fp16": torch.float16}.get(dtype, torch.float32)
    print(f"Loading {model_id} (dtype={dtype}) …")

    # device_map="auto" leaves device placement to the library.
    load_options = {
        "torch_dtype": precision,
        "device_map": "auto",
        "low_cpu_mem_usage": True,
    }
    model = Gemma4ForConditionalGeneration.from_pretrained(model_id, **load_options)
    processor = AutoProcessor.from_pretrained(model_id)

    model.eval()
    return model, processor
@torch.inference_mode()
def classify(model, processor, transcript: str) -> str:
    """Ask the model to label one transcript and return its raw answer.

    Args:
        model: Generation model exposing ``device`` and ``generate``.
        processor: Processor exposing ``apply_chat_template``,
            ``batch_decode`` and ``tokenizer.pad_token_id``.
        transcript: Call transcript inserted into ``USER_TEMPLATE``.

    Returns:
        The newly generated text (prompt tokens removed), stripped of
        surrounding whitespace and upper-cased.
    """
    user_prompt = USER_TEMPLATE.format(transcript=transcript)
    chat = [
        {"role": "system", "content": [{"type": "text", "text": SYS}]},
        {"role": "user", "content": [{"type": "text", "text": user_prompt}]},
    ]

    encoded = processor.apply_chat_template(
        chat,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=True,
    )
    batch = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    prompt_len = batch["input_ids"].shape[-1]

    # Greedy decoding with a tiny budget: only a one-word label is expected.
    output_ids = model.generate(
        **batch,
        max_new_tokens=5,
        do_sample=False,
        pad_token_id=processor.tokenizer.pad_token_id,
    )

    # generate() returns prompt + completion; keep only the completion.
    completion_ids = output_ids[:, prompt_len:]
    decoded = processor.batch_decode(completion_ids, skip_special_tokens=True)
    return decoded[0].strip().upper()
# Words that flip the label that follows them ("NOT A SCAM", "NON-SCAM").
_NEGATION_WORDS = frozenset({"NOT", "NO", "NON"})
# Filler words skipped when looking for a negation before a label.
_ARTICLE_WORDS = frozenset({"A", "AN", "THE"})
# Stand-alone words that indicate a legitimate call when no label is present.
_LEGIT_HINT_WORDS = frozenset({"NOT", "NO", "SAFE", "NORMAL"})


def normalize(pred_raw: str) -> str:
    """Map a raw model answer onto ``"SCAM"`` or ``"LEGITIMATE"``.

    Matching is done on whole words rather than substrings, so answers such
    as "UNKNOWN" or "CANNOT" are no longer mistaken for "NO"/"NOT", and a
    negated label ("NOT A SCAM", "NOT LEGITIMATE") is flipped instead of
    being taken at face value.

    Args:
        pred_raw: Text produced by the model (any casing, any punctuation).

    Returns:
        ``"SCAM"`` or ``"LEGITIMATE"`` when the answer can be interpreted;
        otherwise ``pred_raw`` unchanged, so that unparseable answers stay
        visible in the report.
    """
    # Replace every non-letter with a space so punctuation and markdown
    # ("**SCAM**", "SCAM.", "NON-SCAM") do not stick to the words.
    words = "".join(ch if ch.isalpha() else " " for ch in pred_raw.upper()).split()

    for idx, word in enumerate(words):
        if word.startswith("SCAM"):
            label, opposite = "SCAM", "LEGITIMATE"
        elif word.startswith("LEGIT"):
            label, opposite = "LEGITIMATE", "SCAM"
        else:
            continue
        # The first explicit label decides; look at the nearest preceding
        # non-article word to detect a negation.
        previous = next(
            (w for w in reversed(words[:idx]) if w not in _ARTICLE_WORDS),
            None,
        )
        return opposite if previous in _NEGATION_WORDS else label

    # No explicit label: fall back to the looser "legitimate" hints.
    if any(w in _LEGIT_HINT_WORDS for w in words):
        return "LEGITIMATE"
    return pred_raw
def compute_metrics(items):
    """Compute accuracy and SCAM-class precision/recall/F1 for the run.

    Args:
        items: Sequence of dicts, each with a ``"pred"`` and a ``"gold"`` key.

    Returns:
        Dict with ``total``, ``accuracy``, ``precision_scam``,
        ``recall_scam``, ``f1_scam`` and a ``confusion`` dict holding
        ``TP``/``FP``/``FN``/``TN``. Every ratio is 0.0 when its denominator
        is zero, including the case of an empty ``items``.
    """
    total = len(items)

    # One pass over the items. Only the four (pred, gold) label pairs are
    # tallied; an unparseable prediction lands in no confusion cell, but it
    # still counts against accuracy through ``total``.
    tally = {
        ("SCAM", "SCAM"): 0,
        ("SCAM", "LEGITIMATE"): 0,
        ("LEGITIMATE", "SCAM"): 0,
        ("LEGITIMATE", "LEGITIMATE"): 0,
    }
    for it in items:
        key = (it["pred"], it["gold"])
        if key in tally:
            tally[key] += 1

    tp = tally[("SCAM", "SCAM")]
    fp = tally[("SCAM", "LEGITIMATE")]
    fn = tally[("LEGITIMATE", "SCAM")]
    tn = tally[("LEGITIMATE", "LEGITIMATE")]

    accuracy = (tp + tn) / total if total else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0

    return {"total": total, "accuracy": accuracy, "precision_scam": precision,
            "recall_scam": recall, "f1_scam": f1,
            "confusion": {"TP": tp, "FP": fp, "FN": fn, "TN": tn}}
def main():
    """Run the zero-shot evaluation end to end.

    Loads the model and dataset named on the command line, classifies up to
    ``--limit`` rows (all rows when the limit is negative), prints a
    per-sample log and a summary report, writes the results as JSON to
    ``--out`` and finally prints a PASS / MARGINAL / FAIL verdict based on
    accuracy and SCAM-class F1.

    Raises:
        SystemExit: If there is nothing to evaluate (empty split or a limit
            of zero), instead of failing later with a division by zero.
    """
    args = parse_args()
    model, processor = load_model(args.model, args.dtype)
    ds = load_dataset(args.dataset, split=args.split)
    n = len(ds) if args.limit < 0 else min(args.limit, len(ds))
    if n == 0:
        raise SystemExit(
            f"No samples to evaluate (split={args.split!r}, limit={args.limit})."
        )

    items = []
    # perf_counter is monotonic, so the elapsed time cannot go negative or
    # jump if the system clock is adjusted during a long run.
    t0 = time.perf_counter()
    for i in range(n):
        row = ds[i]
        gold = "SCAM" if row["label"] == 1 else "LEGITIMATE"
        pred_raw = classify(model, processor, row["dialogue"])
        pred = normalize(pred_raw)
        correct = pred == gold
        items.append({"index": i, "gold": gold, "pred_raw": pred_raw,
                      "pred": pred, "correct": correct})
        mark = "✓" if correct else "✗"
        print(f"[{i+1:3}/{n}] gold={gold:11} pred='{pred_raw:15}' → {pred:11} {mark}")
    elapsed = time.perf_counter() - t0

    metrics = compute_metrics(items)
    metrics["time_sec"] = elapsed
    # Guard against a zero reading from a coarse timer.
    metrics["throughput"] = n / elapsed if elapsed > 0 else 0.0

    print("\n" + "=" * 60)
    print("ZERO-SHOT EVALUATION REPORT")
    print("=" * 60)
    print(f"Model : {args.model} (2B params)")
    print(f"Device : {next(model.parameters()).device}")
    print(f"Samples : {n}")
    print(f"Time : {elapsed:.1f}s ({metrics['throughput']:.2f} ex/s)")
    print(f"Accuracy : {metrics['accuracy']:.2%}")
    print(f"Precision : {metrics['precision_scam']:.2%}")
    print(f"Recall : {metrics['recall_scam']:.2%}")
    print(f"F1 (SCAM) : {metrics['f1_scam']:.2%}")
    print(f"Confusion : TP={metrics['confusion']['TP']} FP={metrics['confusion']['FP']} "
          f"FN={metrics['confusion']['FN']} TN={metrics['confusion']['TN']}")
    print("=" * 60)

    out = {"args": vars(args), "metrics": metrics, "items": items}
    Path(args.out).write_text(json.dumps(out, indent=2), encoding="utf-8")
    print(f"Saved → {args.out}")

    # Verdict thresholds: both accuracy and SCAM F1 must clear the bar.
    acc, f1 = metrics["accuracy"], metrics["f1_scam"]
    if acc >= 0.90 and f1 >= 0.85:
        print("\n✅ PASS — 2B base model is accurate enough for phone deployment.")
    elif acc >= 0.75 and f1 >= 0.70:
        print("\n⚠️ MARGINAL — Fine-tune with Unsloth, then LiteRT-convert.")
    else:
        print("\n❌ FAIL — Fine-tune REQUIRED before mobile deployment.")
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|