s23deepak commited on
Commit
8bdc6fb
·
verified ·
1 Parent(s): edf8f8f

Upload eval_zero_shot_cpu.py

Browse files
Files changed (1) hide show
  1. eval_zero_shot_cpu.py +13 -36
eval_zero_shot_cpu.py CHANGED
@@ -1,24 +1,7 @@
1
  #!/usr/bin/env python3
2
  """
3
- Zero-shot eval of Gemma 4 E2B (2B) on CPU — no GPU needed.
4
- Use this to test the smallest Gemma 4 before any mobile conversion.
5
-
6
- REQUIREMENTS:
7
- pip install transformers datasets torch huggingface_hub
8
-
9
- USAGE:
10
- # Quick test (20 samples, ~2-3 min on laptop CPU)
11
- python eval_zero_shot_cpu.py --limit 20
12
-
13
- # Full test split (~400 samples, ~30-45 min on CPU)
14
- python eval_zero_shot_cpu.py --limit -1
15
-
16
- MODEL SIZE:
17
- gemma-4-E2B-it = 2B params
18
- FP32 on CPU RAM: ~8 GB peak
19
- Use --dtype fp16 to halve RAM to ~4 GB if your CPU supports it.
20
  """
21
-
22
  import argparse
23
  import json
24
  import time
@@ -44,27 +27,22 @@ def parse_args():
44
  p.add_argument("--model", default="google/gemma-4-E2B-it")
45
  p.add_argument("--dataset", default="BothBosu/scam-dialogue")
46
  p.add_argument("--split", default="test")
47
- p.add_argument("--limit", type=int, default=20,
48
- help="Max rows (-1 = all). Default 20 for quick CPU test.")
49
- p.add_argument("--dtype", default="fp32", choices=["fp16","fp32"],
50
- help="fp16 = half RAM (~4 GB), fp32 = ~8 GB")
51
- p.add_argument("--out", default="results_zero_shot_cpu.json")
52
  return p.parse_args()
53
 
54
 
55
- def load_model_cpu(model_id: str, dtype: str):
56
  torch_dtype = torch.float16 if dtype == "fp16" else torch.float32
57
- print(f"Loading {model_id} on CPU (dtype={dtype}) …")
58
- print(f" Expected RAM: ~{4 if dtype == 'fp16' else 8} GB")
59
- print(f" If OOM: close browser tabs, reduce --limit, or use fp16.\n")
60
-
61
  model = Gemma4ForConditionalGeneration.from_pretrained(
62
  model_id,
63
  torch_dtype=torch_dtype,
64
- device_map=None, # force CPU
65
  low_cpu_mem_usage=True,
66
  )
67
- model = model.to("cpu")
68
  processor = AutoProcessor.from_pretrained(model_id)
69
  model.eval()
70
  return model, processor
@@ -80,7 +58,7 @@ def classify(model, processor, transcript: str) -> str:
80
  messages, tokenize=True, return_dict=True,
81
  return_tensors="pt", add_generation_prompt=True,
82
  )
83
- inputs = {k: v.to("cpu") for k, v in inputs.items()}
84
 
85
  gen_ids = model.generate(
86
  **inputs, max_new_tokens=5, do_sample=False,
@@ -115,7 +93,7 @@ def compute_metrics(items):
115
 
116
  def main():
117
  args = parse_args()
118
- model, processor = load_model_cpu(args.model, args.dtype)
119
  ds = load_dataset(args.dataset, split=args.split)
120
  n = len(ds) if args.limit < 0 else min(args.limit, len(ds))
121
 
@@ -138,10 +116,10 @@ def main():
138
  metrics["throughput"] = n / elapsed
139
 
140
  print("\n" + "=" * 60)
141
- print("CPU ZERO-SHOT REPORT")
142
  print("=" * 60)
143
  print(f"Model : {args.model} (2B params)")
144
- print(f"Device : CPU")
145
  print(f"Samples : {n}")
146
  print(f"Time : {elapsed:.1f}s ({metrics['throughput']:.2f} ex/s)")
147
  print(f"Accuracy : {metrics['accuracy']:.2%}")
@@ -158,8 +136,7 @@ def main():
158
 
159
  acc, f1 = metrics["accuracy"], metrics["f1_scam"]
160
  if acc >= 0.90 and f1 >= 0.85:
161
- print("\n✅ PASS — 2B base model is accurate enough.")
162
- print(" Next: convert to LiteRT 4-bit for phone (~1.5 GB RAM).")
163
  elif acc >= 0.75 and f1 >= 0.70:
164
  print("\n⚠️ MARGINAL — Fine-tune with Unsloth, then LiteRT-convert.")
165
  else:
 
1
  #!/usr/bin/env python3
2
  """
3
+ Zero-shot eval of Gemma 4 E2B (2B) on CPU/GPU.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  """
 
5
  import argparse
6
  import json
7
  import time
 
27
  p.add_argument("--model", default="google/gemma-4-E2B-it")
28
  p.add_argument("--dataset", default="BothBosu/scam-dialogue")
29
  p.add_argument("--split", default="test")
30
+ p.add_argument("--limit", type=int, default=50)
31
+ p.add_argument("--dtype", default="fp16", choices=["fp16","fp32"])
32
+ p.add_argument("--out", default="results_zero_shot.json")
 
 
33
  return p.parse_args()
34
 
35
 
36
+ def load_model(model_id: str, dtype: str):
37
  torch_dtype = torch.float16 if dtype == "fp16" else torch.float32
38
+ print(f"Loading {model_id} (dtype={dtype}) …")
39
+
 
 
40
  model = Gemma4ForConditionalGeneration.from_pretrained(
41
  model_id,
42
  torch_dtype=torch_dtype,
43
+ device_map="auto",
44
  low_cpu_mem_usage=True,
45
  )
 
46
  processor = AutoProcessor.from_pretrained(model_id)
47
  model.eval()
48
  return model, processor
 
58
  messages, tokenize=True, return_dict=True,
59
  return_tensors="pt", add_generation_prompt=True,
60
  )
61
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
62
 
63
  gen_ids = model.generate(
64
  **inputs, max_new_tokens=5, do_sample=False,
 
93
 
94
  def main():
95
  args = parse_args()
96
+ model, processor = load_model(args.model, args.dtype)
97
  ds = load_dataset(args.dataset, split=args.split)
98
  n = len(ds) if args.limit < 0 else min(args.limit, len(ds))
99
 
 
116
  metrics["throughput"] = n / elapsed
117
 
118
  print("\n" + "=" * 60)
119
+ print("ZERO-SHOT EVALUATION REPORT")
120
  print("=" * 60)
121
  print(f"Model : {args.model} (2B params)")
122
+ print(f"Device : {next(model.parameters()).device}")
123
  print(f"Samples : {n}")
124
  print(f"Time : {elapsed:.1f}s ({metrics['throughput']:.2f} ex/s)")
125
  print(f"Accuracy : {metrics['accuracy']:.2%}")
 
136
 
137
  acc, f1 = metrics["accuracy"], metrics["f1_scam"]
138
  if acc >= 0.90 and f1 >= 0.85:
139
+ print("\n✅ PASS — 2B base model is accurate enough for phone deployment.")
 
140
  elif acc >= 0.75 and f1 >= 0.70:
141
  print("\n⚠️ MARGINAL — Fine-tune with Unsloth, then LiteRT-convert.")
142
  else: