Upload eval_zero_shot_cpu.py
Browse files- eval_zero_shot_cpu.py +13 -36
eval_zero_shot_cpu.py
CHANGED
|
@@ -1,24 +1,7 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
-
Zero-shot eval of Gemma 4 E2B (2B) on CPU
|
| 4 |
-
Use this to test the smallest Gemma 4 before any mobile conversion.
|
| 5 |
-
|
| 6 |
-
REQUIREMENTS:
|
| 7 |
-
pip install transformers datasets torch huggingface_hub
|
| 8 |
-
|
| 9 |
-
USAGE:
|
| 10 |
-
# Quick test (20 samples, ~2-3 min on laptop CPU)
|
| 11 |
-
python eval_zero_shot_cpu.py --limit 20
|
| 12 |
-
|
| 13 |
-
# Full test split (~400 samples, ~30-45 min on CPU)
|
| 14 |
-
python eval_zero_shot_cpu.py --limit -1
|
| 15 |
-
|
| 16 |
-
MODEL SIZE:
|
| 17 |
-
gemma-4-E2B-it = 2B params
|
| 18 |
-
FP32 on CPU RAM: ~8 GB peak
|
| 19 |
-
Use --dtype fp16 to halve RAM to ~4 GB if your CPU supports it.
|
| 20 |
"""
|
| 21 |
-
|
| 22 |
import argparse
|
| 23 |
import json
|
| 24 |
import time
|
|
@@ -44,27 +27,22 @@ def parse_args():
|
|
| 44 |
p.add_argument("--model", default="google/gemma-4-E2B-it")
|
| 45 |
p.add_argument("--dataset", default="BothBosu/scam-dialogue")
|
| 46 |
p.add_argument("--split", default="test")
|
| 47 |
-
p.add_argument("--limit", type=int, default=
|
| 48 |
-
|
| 49 |
-
p.add_argument("--
|
| 50 |
-
help="fp16 = half RAM (~4 GB), fp32 = ~8 GB")
|
| 51 |
-
p.add_argument("--out", default="results_zero_shot_cpu.json")
|
| 52 |
return p.parse_args()
|
| 53 |
|
| 54 |
|
| 55 |
-
def
|
| 56 |
torch_dtype = torch.float16 if dtype == "fp16" else torch.float32
|
| 57 |
-
print(f"Loading {model_id}
|
| 58 |
-
|
| 59 |
-
print(f" If OOM: close browser tabs, reduce --limit, or use fp16.\n")
|
| 60 |
-
|
| 61 |
model = Gemma4ForConditionalGeneration.from_pretrained(
|
| 62 |
model_id,
|
| 63 |
torch_dtype=torch_dtype,
|
| 64 |
-
device_map=
|
| 65 |
low_cpu_mem_usage=True,
|
| 66 |
)
|
| 67 |
-
model = model.to("cpu")
|
| 68 |
processor = AutoProcessor.from_pretrained(model_id)
|
| 69 |
model.eval()
|
| 70 |
return model, processor
|
|
@@ -80,7 +58,7 @@ def classify(model, processor, transcript: str) -> str:
|
|
| 80 |
messages, tokenize=True, return_dict=True,
|
| 81 |
return_tensors="pt", add_generation_prompt=True,
|
| 82 |
)
|
| 83 |
-
inputs = {k: v.to(
|
| 84 |
|
| 85 |
gen_ids = model.generate(
|
| 86 |
**inputs, max_new_tokens=5, do_sample=False,
|
|
@@ -115,7 +93,7 @@ def compute_metrics(items):
|
|
| 115 |
|
| 116 |
def main():
|
| 117 |
args = parse_args()
|
| 118 |
-
model, processor =
|
| 119 |
ds = load_dataset(args.dataset, split=args.split)
|
| 120 |
n = len(ds) if args.limit < 0 else min(args.limit, len(ds))
|
| 121 |
|
|
@@ -138,10 +116,10 @@ def main():
|
|
| 138 |
metrics["throughput"] = n / elapsed
|
| 139 |
|
| 140 |
print("\n" + "=" * 60)
|
| 141 |
-
print("
|
| 142 |
print("=" * 60)
|
| 143 |
print(f"Model : {args.model} (2B params)")
|
| 144 |
-
print(f"Device :
|
| 145 |
print(f"Samples : {n}")
|
| 146 |
print(f"Time : {elapsed:.1f}s ({metrics['throughput']:.2f} ex/s)")
|
| 147 |
print(f"Accuracy : {metrics['accuracy']:.2%}")
|
|
@@ -158,8 +136,7 @@ def main():
|
|
| 158 |
|
| 159 |
acc, f1 = metrics["accuracy"], metrics["f1_scam"]
|
| 160 |
if acc >= 0.90 and f1 >= 0.85:
|
| 161 |
-
print("\n✅ PASS — 2B base model is accurate enough.")
|
| 162 |
-
print(" Next: convert to LiteRT 4-bit for phone (~1.5 GB RAM).")
|
| 163 |
elif acc >= 0.75 and f1 >= 0.70:
|
| 164 |
print("\n⚠️ MARGINAL — Fine-tune with Unsloth, then LiteRT-convert.")
|
| 165 |
else:
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
Zero-shot eval of Gemma 4 E2B (2B) on CPU/GPU.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
|
|
|
| 5 |
import argparse
|
| 6 |
import json
|
| 7 |
import time
|
|
|
|
| 27 |
p.add_argument("--model", default="google/gemma-4-E2B-it")
|
| 28 |
p.add_argument("--dataset", default="BothBosu/scam-dialogue")
|
| 29 |
p.add_argument("--split", default="test")
|
| 30 |
+
p.add_argument("--limit", type=int, default=50)
|
| 31 |
+
p.add_argument("--dtype", default="fp16", choices=["fp16","fp32"])
|
| 32 |
+
p.add_argument("--out", default="results_zero_shot.json")
|
|
|
|
|
|
|
| 33 |
return p.parse_args()
|
| 34 |
|
| 35 |
|
| 36 |
+
def load_model(model_id: str, dtype: str):
|
| 37 |
torch_dtype = torch.float16 if dtype == "fp16" else torch.float32
|
| 38 |
+
print(f"Loading {model_id} (dtype={dtype}) …")
|
| 39 |
+
|
|
|
|
|
|
|
| 40 |
model = Gemma4ForConditionalGeneration.from_pretrained(
|
| 41 |
model_id,
|
| 42 |
torch_dtype=torch_dtype,
|
| 43 |
+
device_map="auto",
|
| 44 |
low_cpu_mem_usage=True,
|
| 45 |
)
|
|
|
|
| 46 |
processor = AutoProcessor.from_pretrained(model_id)
|
| 47 |
model.eval()
|
| 48 |
return model, processor
|
|
|
|
| 58 |
messages, tokenize=True, return_dict=True,
|
| 59 |
return_tensors="pt", add_generation_prompt=True,
|
| 60 |
)
|
| 61 |
+
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
| 62 |
|
| 63 |
gen_ids = model.generate(
|
| 64 |
**inputs, max_new_tokens=5, do_sample=False,
|
|
|
|
| 93 |
|
| 94 |
def main():
|
| 95 |
args = parse_args()
|
| 96 |
+
model, processor = load_model(args.model, args.dtype)
|
| 97 |
ds = load_dataset(args.dataset, split=args.split)
|
| 98 |
n = len(ds) if args.limit < 0 else min(args.limit, len(ds))
|
| 99 |
|
|
|
|
| 116 |
metrics["throughput"] = n / elapsed
|
| 117 |
|
| 118 |
print("\n" + "=" * 60)
|
| 119 |
+
print("ZERO-SHOT EVALUATION REPORT")
|
| 120 |
print("=" * 60)
|
| 121 |
print(f"Model : {args.model} (2B params)")
|
| 122 |
+
print(f"Device : {next(model.parameters()).device}")
|
| 123 |
print(f"Samples : {n}")
|
| 124 |
print(f"Time : {elapsed:.1f}s ({metrics['throughput']:.2f} ex/s)")
|
| 125 |
print(f"Accuracy : {metrics['accuracy']:.2%}")
|
|
|
|
| 136 |
|
| 137 |
acc, f1 = metrics["accuracy"], metrics["f1_scam"]
|
| 138 |
if acc >= 0.90 and f1 >= 0.85:
|
| 139 |
+
print("\n✅ PASS — 2B base model is accurate enough for phone deployment.")
|
|
|
|
| 140 |
elif acc >= 0.75 and f1 >= 0.70:
|
| 141 |
print("\n⚠️ MARGINAL — Fine-tune with Unsloth, then LiteRT-convert.")
|
| 142 |
else:
|