narcolepticchicken committed on
Commit
9731829
·
verified ·
1 Parent(s): 11812f7

Upload jobs/occ_humaneval_v2.py

Browse files
Files changed (1) hide show
  1. jobs/occ_humaneval_v2.py +107 -0
jobs/occ_humaneval_v2.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ OCC HumanEval — Direct completion-format evaluation with Qwen3-Coder-30B-A3B-Instruct.
4
+ Uses completion format (not chat), stop-token trimming, prompt+completion concatenation.
5
+ Two-pass: 128 tokens → test → 1024 tokens on failures. Saves real token counts.
6
+ """
7
+ import json, time, sys, torch
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer
9
+ from datasets import load_dataset
10
+
11
# Hugging Face model id used for both the tokenizer and the model weights.
MODEL = "Qwen/Qwen3-Coder-30B-A3B-Instruct"

# max_new_tokens budgets: SHORT for the first pass over every problem,
# LONG for the retry pass over first-pass failures.
SHORT = 128
LONG = 1024

# A decoded completion is cut at the earliest occurrence of any of these
# strings; each one marks the start of a new column-0 construct.
STOP = ["\nclass", "\ndef", "\n#", "\nif __name__", "\nprint("]
14
+
15
def log(msg):
    """Print ``msg`` to stdout with an ``[OCC]`` prefix, flushing immediately."""
    print(f"[OCC] {msg}", flush=True)
16
+
17
# Load the tokenizer and model once at module level; generate() and
# run_eval() read `tok` and `model` as globals.
log(f"Loading {MODEL}...")
tok = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
# Use EOS as the padding token (generate() also passes it as pad_token_id).
tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
log(f"Loaded on {model.device}")

# Materialize the HumanEval test split as a plain list so it can be
# iterated and filtered more than once.
ds = load_dataset("openai/openai_humaneval", split="test")
problems = list(ds)
log(f"Loaded {len(problems)} problems")
26
+
27
def generate(prompt, mx):
    """Generate a completion for ``prompt`` and trim it at the first stop string.

    Args:
        prompt: Raw text fed to the tokenizer as-is (no chat template is
            applied here). Inputs longer than 2048 tokens are truncated by
            the tokenizer.
        mx: Maximum number of new tokens to generate.

    Returns:
        The decoded continuation only (prompt tokens excluded, special
        tokens skipped), cut at the earliest occurrence of any string in
        ``STOP``.
    """
    inp = tok(
        prompt, return_tensors="pt", truncation=True, max_length=2048
    ).to(model.device)
    ilen = inp.input_ids.shape[1]
    with torch.no_grad():
        out = model.generate(
            **inp,
            max_new_tokens=mx,
            do_sample=False,
            pad_token_id=tok.eos_token_id,
        )
    # The output row holds prompt + continuation; keep only the new tokens.
    gen = tok.decode(out[0][ilen:], skip_special_tokens=True)

    # Cut at whichever stop string appears first in the text.
    cut = len(gen)
    for sw in STOP:
        idx = gen.find(sw)
        if idx != -1:
            cut = min(cut, idx)
    return gen[:cut]
37
+
38
def clean_body(gen):
    """Strip leading and trailing blank lines from a generated completion.

    Lines that are empty or whitespace-only are removed from both ends.
    Interior blank lines and the indentation of the remaining lines are
    left untouched.

    Args:
        gen: Completion text.

    Returns:
        The remaining lines joined with ``"\\n"``; an empty string when
        every line is blank.
    """
    lines = gen.split("\n")
    # Move two indices inward instead of re-slicing the list on each step.
    start, end = 0, len(lines)
    while start < end and not lines[start].strip():
        start += 1
    while end > start and not lines[end - 1].strip():
        end -= 1
    return "\n".join(lines[start:end])
43
+
44
def test_problem(prompt, gen, test, entry):
    """Assemble prompt + completion + test code and execute it in-process.

    Args:
        prompt: Source text placed before the completion.
        gen: Generated completion; blank edge lines are stripped with
            ``clean_body`` before use.
        test: Test source appended after the completion. It is expected to
            define a ``check`` callable.
        entry: Name looked up in the executed namespace and passed to
            ``check``.

    Returns:
        A ``(passed, reason)`` tuple. ``reason`` is ``"OK"`` on success,
        otherwise a short failure tag (``"empty"``, ``"SyntaxErr:..."``,
        ``"NoCheck"``, ``"AssertErr"`` or ``"<ExceptionName>:<message>"``).
    """
    body = clean_body(gen)
    if not body.strip():
        return False, "empty"

    code = prompt.rstrip() + "\n" + body + "\n\n" + test
    try:
        compile(code, "<t>", "exec")
    except SyntaxError as e:
        return False, f"SyntaxErr:{e}"

    # SECURITY: this runs model-generated code in the current process with
    # no sandbox and no timeout. Only run it in a disposable environment.
    ns = {}
    try:
        exec(code, ns)
        check = ns.get("check")
        if check is None:
            # Without a check() nothing was verified, so do not report a pass.
            return False, "NoCheck"
        check(ns.get(entry))
    except AssertionError:
        return False, "AssertErr"
    except (Exception, SystemExit) as e:
        # SystemExit is caught so generated code that calls exit() fails
        # one problem instead of ending the whole run.
        return False, f"{type(e).__name__}:{str(e)[:60]}"
    return True, "OK"
57
+
58
def run_eval(probs, mx, label):
    """Generate and test a completion for every problem, logging each result.

    Args:
        probs: Sequence of problem dicts with ``"prompt"``, ``"test"``,
            ``"entry_point"`` and ``"task_id"`` keys.
        mx: Maximum number of new tokens per generation.
        label: Name of this pass, used in the log output.

    Returns:
        ``(ok, fail, tok_total, acc)``: task ids that passed, task ids that
        failed, the summed token count, and the pass fraction (``0.0`` when
        ``probs`` is empty).
    """
    log(f"\n--- {label} ---")
    ok, fail, tok_total = [], [], 0
    for i, p in enumerate(probs):
        t0 = time.time()
        gen = generate(p["prompt"], mx)
        elapsed = time.time() - t0
        # Count tokens by re-encoding the stop-trimmed text; special tokens
        # are excluded so tokenizer-added markers are never counted.
        ntok = len(tok.encode(gen, add_special_tokens=False))
        tok_total += ntok
        passed, reason = test_problem(
            p["prompt"], gen, p["test"], p["entry_point"]
        )
        if passed:
            ok.append(p["task_id"])
            log(f"  [{i+1:3d}] {p['task_id']}: PASS ({ntok}tok {elapsed:.1f}s)")
        else:
            fail.append(p["task_id"])
            log(f"  [{i+1:3d}] {p['task_id']}: FAIL ({ntok}tok) — {reason[:80]}")
    # Guard the ratio so an empty problem list does not raise.
    acc = len(ok) / len(probs) if probs else 0.0
    log(f"{label}: {len(ok)}/{len(probs)} ({acc:.4f}), {tok_total} tokens")
    return ok, fail, tok_total, acc
77
+
78
# Pass 1: every problem with the short token budget.
p1_ok, p1_fail, p1_tok, p1_acc = run_eval(problems, SHORT, "Pass1-SHORT")

# Pass 2: retry only the pass-1 failures with the long token budget.
p2_new = 0
p2_tok = 0
if p1_fail:
    fail_set = set(p1_fail)
    fp = [p for p in problems if p["task_id"] in fail_set]
    p2_ok, _, p2_tok, _ = run_eval(fp, LONG, "Pass2-LONG")
    p2_new = len(p2_ok)
else:
    log("\nPass2: SKIPPED (100% on pass 1)")

# Combined result: pass-1 successes plus problems recovered in pass 2.
final = len(p1_ok) + p2_new
final_acc = final / len(problems) if problems else 0.0
tt = p1_tok + p2_tok
# Baseline assumes every problem consumes the full LONG budget.
base = len(problems) * LONG
sav = (1 - tt / base) * 100 if base else 0.0

log(f"\n{'='*60}")
log("Qwen3-Coder-30B-A3B-Instruct HumanEval OCC")
log(f"P1({SHORT}tok): {len(p1_ok)} passed, P2(+{LONG}tok): +{p2_new}")
log(f"OCC pass@1={final_acc:.4f} ({final}/{len(problems)}), {tt} tok")
log(f"Baseline={base} tok, Savings={sav:.1f}%")
log(f"{'='*60}")

res = {
    "model": MODEL,
    "pass1_ok": len(p1_ok),
    "pass2_new": p2_new,
    "occ_pass1": final_acc,
    "occ_tokens": tt,
    "baseline_tokens": base,
    "savings_pct": sav,
}
# Use a context manager so the results file is flushed and closed.
with open("/app/occ_humaneval_results.json", "w") as fh:
    json.dump(res, fh, indent=2)
log("Saved to /app/occ_humaneval_results.json")