narcolepticchicken commited on
Commit
e83f963
·
verified ·
1 Parent(s): 317b409

Upload jobs/run_real_llm_standalone_v4.py

Browse files
Files changed (1) hide show
  1. jobs/run_real_llm_standalone_v4.py +579 -0
jobs/run_real_llm_standalone_v4.py ADDED
@@ -0,0 +1,579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Self-contained GPU job for real LLM code benchmark — V4.
3
+ Key fixes:
4
+ 1. Robust code extraction with markdown stripping and AST validation
5
+ 2. Temperature=0.0 for deterministic generation
6
+ 3. Try both chat-templated and raw prompts
7
+ 4. Debug logging of raw generated outputs
8
+ 5. Proper body replacement (extract just body, not full function)
9
+ 6. Use evalplus/humanevalplus for stronger hidden tests
10
+ """
11
+ import json
12
+ import os
13
+ import random
14
+ import re
15
+ import subprocess
16
+ import sys
17
+ import tempfile
18
+ import time
19
+ import ast
20
+ from dataclasses import dataclass, field
21
+ from enum import Enum
22
+ from pathlib import Path
23
+ from typing import Any, Dict, List, Optional, Tuple
24
+
25
+ from datasets import load_dataset
26
+ from transformers import AutoModelForCausalLM, AutoTokenizer
27
+ import torch
28
+
29
+
30
+ # --- ORACLE (INLINE) ---
31
+
32
+ @dataclass
33
+ class OracleResult:
34
+ raw_score: float
35
+ cost_adjusted_score: float
36
+ confidence: float
37
+ evidence: Dict[str, Any]
38
+ reason: str
39
+ failure_tags: List[str] = field(default_factory=list)
40
+ reward_value: float = 0.0
41
+
42
+
43
+ class ImpactOracle:
44
+ def __init__(self, compute_penalty_rate=0.0001, gaming_penalty=2.0):
45
+ self.compute_penalty_rate = compute_penalty_rate
46
+ self.gaming_penalty = gaming_penalty
47
+
48
+ def score(self, mode, action, context, result, agent_id=""):
49
+ correctness = result.get("correctness", 0.0)
50
+ compute_cost = result.get("compute_cost", 0.0)
51
+ public_pass = result.get("public_pass", correctness)
52
+ hidden_pass = result.get("hidden_tests_pass", correctness)
53
+ failure_tags = []
54
+ if public_pass and not hidden_pass:
55
+ failure_tags.append("gaming_hidden_tests")
56
+ raw = correctness * 1.0 - compute_cost * self.compute_penalty_rate
57
+ if "gaming_hidden_tests" in failure_tags:
58
+ raw -= self.gaming_penalty
59
+ cost_adj = raw - compute_cost * self.compute_penalty_rate
60
+ return OracleResult(
61
+ raw_score=raw,
62
+ cost_adjusted_score=cost_adj,
63
+ confidence=result.get("confidence", correctness),
64
+ evidence={"correctness": correctness},
65
+ reason=f"correctness={correctness:.2f}, cost={compute_cost}",
66
+ failure_tags=failure_tags,
67
+ reward_value=cost_adj,
68
+ )
69
+
70
+
71
+ # --- LEDGER (INLINE) ---
72
+
73
+ @dataclass
74
+ class LedgerEntry:
75
+ agent_id: str
76
+ task_id: str
77
+ action_id: str
78
+ earned_credit: float
79
+ spent_credit: float
80
+ decayed_credit: float
81
+ remaining_credit: float
82
+ reason: str
83
+ oracle_score: float
84
+ compute_cost: float
85
+ timestamp: float
86
+ capability_scope: str = "global"
87
+
88
+
89
+ class CreditLedger:
90
+ def __init__(self, decay_lambda=0.05):
91
+ self.entries = []
92
+ self.balances = {}
93
+ self.decay_lambda = decay_lambda
94
+
95
+ def earn(self, agent_id, task_id, action_id, amount, oracle_score, compute_cost, reason, capability_scope="global"):
96
+ now = time.time()
97
+ self._apply_decay(agent_id, now, capability_scope)
98
+ current = self._get(agent_id, capability_scope)
99
+ new_bal = current + amount
100
+ self.entries.append(LedgerEntry(agent_id, task_id, action_id, amount, 0.0, 0.0, new_bal, reason, oracle_score, compute_cost, now, capability_scope))
101
+ self._set(agent_id, capability_scope, new_bal)
102
+
103
+ def spend(self, agent_id, task_id, action_id, amount, capability_scope="global", reason="spend"):
104
+ now = time.time()
105
+ self._apply_decay(agent_id, now, capability_scope)
106
+ current = self._get(agent_id, capability_scope)
107
+ if current < amount:
108
+ return False
109
+ new_bal = current - amount
110
+ self.entries.append(LedgerEntry(agent_id, task_id, action_id, 0.0, amount, 0.0, new_bal, reason, 0.0, 0.0, now, capability_scope))
111
+ self._set(agent_id, capability_scope, new_bal)
112
+ return True
113
+
114
+ def balance(self, agent_id, capability_scope="global"):
115
+ now = time.time()
116
+ self._apply_decay(agent_id, now, capability_scope)
117
+ return self._get(agent_id, capability_scope)
118
+
119
+ def _get(self, agent_id, cap):
120
+ return self.balances.get(agent_id, {}).get(cap, 0.0)
121
+
122
+ def _set(self, agent_id, cap, val):
123
+ if agent_id not in self.balances:
124
+ self.balances[agent_id] = {}
125
+ self.balances[agent_id][cap] = val
126
+
127
+ def _apply_decay(self, agent_id, now, cap):
128
+ current = self._get(agent_id, cap)
129
+ if current <= 0:
130
+ return
131
+ decayed = current * (1 - self.decay_lambda)
132
+ if decayed < current:
133
+ self.entries.append(LedgerEntry(agent_id, "decay", "decay", 0.0, 0.0, current - decayed, decayed, "credit_decay", 0.0, 0.0, now, cap))
134
+ self._set(agent_id, cap, decayed)
135
+
136
+
137
+ # --- BROKER (INLINE) ---
138
+
139
+ class Decision(Enum):
140
+ ALLOW = "allow"
141
+ DENY = "deny"
142
+ REQUIRE_APPROVAL = "require_approval"
143
+ DOWNGRADE = "downgrade"
144
+ ESCALATE = "escalate"
145
+ ASK_JUSTIFICATION = "ask_justification"
146
+
147
+
148
+ @dataclass
149
+ class ResourceDecision:
150
+ decision: Decision
151
+ reason: str
152
+ capability: str
153
+ downgrade_to: Optional[str] = None
154
+
155
+
156
+ class ResourceBroker:
157
+ RESOURCE_RISK = {
158
+ "model_call": "medium", "retrieval_call": "low", "verifier_call": "medium",
159
+ "debate_turn": "low", "file_write": "high", "shell_execute": "high",
160
+ "memory_write": "medium", "human_escalation": "high", "larger_model": "medium"
161
+ }
162
+ DEFAULT_THRESHOLDS = {"low": 0.5, "medium": 2.0, "high": 5.0}
163
+
164
+ def __init__(self, thresholds=None, urgency_boost=0.5):
165
+ self.thresholds = thresholds or self.DEFAULT_THRESHOLDS.copy()
166
+ self.urgency_boost = urgency_boost
167
+ self.denial_history = {}
168
+
169
+ def request(self, capability, agent_id, credit_balance, task_state=None, risk_score=0.0, gaming_flags=None):
170
+ task_state = task_state or {}
171
+ gaming_flags = gaming_flags or []
172
+ risk_class = self.RESOURCE_RISK.get(capability, "medium")
173
+ threshold = self.thresholds.get(risk_class, 2.0)
174
+ urgency = task_state.get("urgency", 0.0)
175
+ adjusted = max(0.1, threshold - urgency * self.urgency_boost)
176
+ if gaming_flags:
177
+ return ResourceDecision(Decision.DENY, f"Gaming: {gaming_flags}", capability)
178
+ if risk_class == "high" and risk_score > 0.7:
179
+ return ResourceDecision(Decision.REQUIRE_APPROVAL, f"High risk {risk_score:.2f}", capability)
180
+ if credit_balance >= adjusted:
181
+ return ResourceDecision(Decision.ALLOW, f"Balance {credit_balance:.2f} >= {adjusted:.2f}", capability)
182
+ if credit_balance >= adjusted * 0.5:
183
+ if risk_class == "medium":
184
+ return ResourceDecision(Decision.DOWNGRADE, f"Downgrading from {capability}", capability, "retrieval_call")
185
+ return ResourceDecision(Decision.ASK_JUSTIFICATION, f"Justification required", capability)
186
+ denials = self.denial_history.get(agent_id, 0)
187
+ if denials > 3:
188
+ return ResourceDecision(Decision.ESCALATE, f"Denied {denials} times", capability)
189
+ self.denial_history[agent_id] = denials + 1
190
+ return ResourceDecision(Decision.DENY, f"Balance {credit_balance:.2f} < {adjusted:.2f}", capability)
191
+
192
+
193
+ # --- HELPERS ---
194
+
195
+ def strip_markdown_fences(text: str) -> str:
196
+ """Remove markdown code fences."""
197
+ text = text.strip()
198
+ if text.startswith("```"):
199
+ lines = text.splitlines()
200
+ if lines[0].startswith("```"):
201
+ lines = lines[1:]
202
+ if lines and lines[-1].strip() == "```":
203
+ lines = lines[:-1]
204
+ text = "\n".join(lines)
205
+ return text.strip()
206
+
207
+
208
+ def extract_body_or_full(code: str, entry_point: str) -> str:
209
+ """
210
+ Try to extract just the function body from generated code.
211
+ If the model outputs the full function (def + docstring + body),
212
+ extract only the lines after the closing docstring.
213
+ If it outputs just body, return as-is.
214
+ """
215
+ code = strip_markdown_fences(code)
216
+
217
+ # If no def line at all, assume it's just the body
218
+ if not re.search(rf'\bdef\s+{re.escape(entry_point)}\b', code):
219
+ return code
220
+
221
+ # Find the function definition
222
+ # Match: def func_name(...) followed by optional docstring, then body
223
+ pattern = rf'(\bdef\s+{re.escape(entry_point)}\s*\([^)]*\)[^:]*:)(.*?)(?=\n(?:\S|$))'
224
+ match = re.search(pattern, code, re.DOTALL)
225
+ if not match:
226
+ # Fallback: everything from def to end
227
+ pattern2 = rf'\bdef\s+{re.escape(entry_point)}\s*\([^)]*\)[^:]*:.*'
228
+ match = re.search(pattern2, code, re.DOTALL)
229
+ if match:
230
+ return match.group(0)
231
+ return code
232
+
233
+ func_text = match.group(0)
234
+ # Try to extract just the body after the docstring
235
+ # Look for """ ... """ or ''' ... ''' after the def line
236
+ docstring_match = re.search(r'(?m)^(\s*)(""".*?"""|\'\'\'.*?\'\'\')\s*\n', func_text, re.DOTALL)
237
+ if docstring_match:
238
+ body_start = docstring_match.end()
239
+ body = func_text[body_start:]
240
+ return body.strip()
241
+ else:
242
+ # No docstring, body starts after the colon
243
+ colon_pos = func_text.find(':')
244
+ if colon_pos != -1:
245
+ body = func_text[colon_pos+1:]
246
+ return body.strip()
247
+ return func_text
248
+
249
+
250
+ def validate_python(code: str) -> tuple:
251
+ """Validate Python syntax with AST. Returns (ok, error)."""
252
+ try:
253
+ ast.parse(code)
254
+ return True, ""
255
+ except SyntaxError as e:
256
+ return False, str(e)
257
+ except Exception as e:
258
+ return False, str(e)
259
+
260
+
261
+ def extract_function_body(code: str, entry_point: str) -> str:
262
+ """Robust extraction: try multiple strategies."""
263
+ strategies = [
264
+ lambda c: extract_body_or_full(c, entry_point),
265
+ lambda c: strip_markdown_fences(c),
266
+ lambda c: c,
267
+ ]
268
+ for i, strat in enumerate(strategies):
269
+ candidate = strat(code)
270
+ ok, err = validate_python(candidate)
271
+ if ok:
272
+ return candidate
273
+ print(f" [extract strat {i}] failed: {err[:80]}")
274
+ # Last resort: return raw
275
+ return code
276
+
277
+
278
+ def run_tests(code: str, test_code: str, timeout: int = 10):
279
+ full = code + "\n\n" + test_code + "\n\ncheck()\n"
280
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
281
+ f.write(full)
282
+ tmp = f.name
283
+ try:
284
+ result = subprocess.run(['python', tmp], capture_output=True, text=True, timeout=timeout)
285
+ passed = result.returncode == 0
286
+ error = result.stderr if not passed else ""
287
+ except subprocess.TimeoutExpired:
288
+ passed = False
289
+ error = "Timeout"
290
+ except Exception as e:
291
+ passed = False
292
+ error = str(e)
293
+ finally:
294
+ os.unlink(tmp)
295
+ return passed, error
296
+
297
+
298
+ QWEN_SYSTEM = "You are an expert Python programmer. Complete the function. Output ONLY the function body or the complete function, no markdown, no explanations."
299
+
300
+
301
+ def wrap_prompt_chat(humaneval_prompt: str, tok) -> str:
302
+ messages = [
303
+ {"role": "system", "content": QWEN_SYSTEM},
304
+ {"role": "user", "content": humaneval_prompt.strip()},
305
+ ]
306
+ if hasattr(tok, "apply_chat_template") and tok.chat_template:
307
+ return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
308
+ else:
309
+ sys_msg = f"system\n{QWEN_SYSTEM}\n"
310
+ usr_msg = f"user\n{humaneval_prompt.strip()}\n"
311
+ asst_msg = "assistant\n"
312
+ return sys_msg + usr_msg + asst_msg
313
+
314
+
315
+ # --- BENCHMARK ---
316
+
317
+ class RealLLMBenchmarkV4:
318
+ def __init__(self, model_name="Qwen/Qwen2.5-Coder-0.5B-Instruct", n_problems=10, seed=42, use_chat_template=True):
319
+ self.model_name = model_name
320
+ self.n_problems = n_problems
321
+ self.seed = seed
322
+ self.oracle = ImpactOracle()
323
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
324
+ self.use_chat_template = use_chat_template
325
+ print(f"Using device: {self.device}, chat_template={use_chat_template}")
326
+
327
+ def load_problems(self):
328
+ ds = load_dataset("evalplus/humanevalplus", split="test")
329
+ problems = []
330
+ for i, item in enumerate(ds):
331
+ if i >= self.n_problems:
332
+ break
333
+ problems.append({
334
+ "task_id": item["task_id"],
335
+ "prompt": item["prompt"],
336
+ "test": item["test"],
337
+ "entry_point": item["entry_point"],
338
+ })
339
+ return problems
340
+
341
+ def load_model(self):
342
+ print(f"Loading {self.model_name}...")
343
+ tok = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
344
+ model = AutoModelForCausalLM.from_pretrained(
345
+ self.model_name,
346
+ trust_remote_code=True,
347
+ torch_dtype=torch.bfloat16 if self.device == "cuda" else torch.float32,
348
+ device_map="auto" if self.device == "cuda" else None,
349
+ )
350
+ if self.device == "cpu":
351
+ model = model.to("cpu").float()
352
+ print(f"Model loaded. Chat template present: {bool(tok.chat_template)}")
353
+ return model, tok
354
+
355
+ def generate(self, model, tok, prompt_raw: str, max_new_tokens: int = 256, temperature: float = 0.0):
356
+ if self.use_chat_template:
357
+ chat_prompt = wrap_prompt_chat(prompt_raw, tok)
358
+ else:
359
+ chat_prompt = prompt_raw
360
+
361
+ inputs = tok(chat_prompt, return_tensors="pt").to(model.device)
362
+ with torch.no_grad():
363
+ outputs = model.generate(
364
+ **inputs,
365
+ max_new_tokens=max_new_tokens,
366
+ temperature=temperature,
367
+ do_sample=(temperature > 0),
368
+ pad_token_id=tok.eos_token_id,
369
+ )
370
+ gen = tok.decode(outputs[0], skip_special_tokens=True)
371
+ prompt_decoded = tok.decode(inputs.input_ids[0], skip_special_tokens=True)
372
+ code = gen[len(prompt_decoded):].strip()
373
+ return code
374
+
375
+ def _evaluate_one(self, problem, model, tok, max_new_tokens=256):
376
+ """Generate and evaluate one problem. Returns (passed, tokens, raw_output, error)."""
377
+ raw = self.generate(model, tok, problem["prompt"], max_new_tokens=max_new_tokens)
378
+ tokens = len(tok.encode(raw))
379
+
380
+ # Try multiple extraction strategies
381
+ candidates = [
382
+ extract_body_or_full(raw, problem["entry_point"]),
383
+ strip_markdown_fences(raw),
384
+ raw,
385
+ ]
386
+
387
+ best_code = None
388
+ best_err = ""
389
+ for cand in candidates:
390
+ full = problem["prompt"] + cand
391
+ ok, err = validate_python(full)
392
+ if ok:
393
+ best_code = cand
394
+ break
395
+ best_err = err
396
+
397
+ if best_code is None:
398
+ print(f" [AST FAIL] All candidates invalid. Last error: {best_err[:80]}")
399
+ best_code = raw
400
+
401
+ full = problem["prompt"] + best_code
402
+ passed, error = run_tests(full, problem["test"])
403
+
404
+ return passed, tokens, raw, error
405
+
406
+ def run_baseline(self, problems, model, tok):
407
+ results = []
408
+ total_compute = 0
409
+ for problem in problems:
410
+ passed, tokens, raw, error = self._evaluate_one(problem, model, tok, max_new_tokens=256)
411
+ total_compute += tokens
412
+ results.append({
413
+ "task_id": problem["task_id"],
414
+ "passed": passed,
415
+ "tokens": tokens,
416
+ "raw_output": raw[:200],
417
+ "error": error[:200],
418
+ })
419
+ print(f" {problem['task_id']}: passed={passed}, tokens={tokens}, raw={raw[:60]!r}")
420
+ if not passed:
421
+ print(f" error={error[:100]!r}")
422
+ return {
423
+ "accuracy": sum(1 for r in results if r["passed"]) / len(results),
424
+ "total_compute": total_compute,
425
+ "mean_tokens": total_compute / len(problems),
426
+ "results": results,
427
+ }
428
+
429
+ def run_occ(self, problems, model, tok):
430
+ ledger = CreditLedger(decay_lambda=0.02)
431
+ broker = ResourceBroker()
432
+ ledger.earn("code_agent", "seed", "seed", 25.0, 0.0, 0.0, "initial", "model_call")
433
+ results = []
434
+ total_compute = 0
435
+
436
+ for problem in problems:
437
+ budget_remaining = 2000
438
+ attempts = 0
439
+ passed = False
440
+ best_code = ""
441
+ best_score = -999
442
+
443
+ while budget_remaining > 100 and attempts < 3 and not passed:
444
+ attempts += 1
445
+ balance = ledger.balance("code_agent", "model_call")
446
+ dec = broker.request("model_call", "code_agent", balance,
447
+ task_state={"attempts": attempts, "budget_remaining": budget_remaining})
448
+ if dec.decision == Decision.DENY:
449
+ break
450
+
451
+ # OCC: shorter / lower temp on first attempt
452
+ temp = 0.0
453
+ max_tok = 128 if attempts == 1 else 256
454
+ code_raw = self.generate(model, tok, problem["prompt"], max_new_tokens=max_tok, temperature=temp)
455
+ tokens = len(tok.encode(code_raw))
456
+ budget_remaining -= tokens
457
+ total_compute += tokens
458
+
459
+ # Extract
460
+ candidates = [
461
+ extract_body_or_full(code_raw, problem["entry_point"]),
462
+ strip_markdown_fences(code_raw),
463
+ code_raw,
464
+ ]
465
+ func = None
466
+ for cand in candidates:
467
+ full = problem["prompt"] + cand
468
+ ok, _ = validate_python(full)
469
+ if ok:
470
+ func = cand
471
+ break
472
+ if func is None:
473
+ func = code_raw
474
+
475
+ full = problem["prompt"] + func
476
+ passed_now, error = run_tests(full, problem["test"])
477
+ score = 1.0 if passed_now else 0.0
478
+
479
+ oracle_res = self.oracle.score(
480
+ mode="code",
481
+ action={"attempt": attempts},
482
+ context={},
483
+ result={"correctness": score, "pass_at_k": score, "regression": False,
484
+ "compute_cost": tokens, "public_pass": passed_now, "hidden_tests_pass": passed_now},
485
+ agent_id="code_agent",
486
+ )
487
+
488
+ if oracle_res.raw_score > best_score:
489
+ best_score = oracle_res.raw_score
490
+ best_code = code_raw
491
+ passed = passed_now
492
+
493
+ if passed:
494
+ ledger.earn("code_agent", problem["task_id"], f"att_{attempts}", 5.0, oracle_res.raw_score, tokens, "pass", "model_call")
495
+ else:
496
+ ledger.spend("code_agent", problem["task_id"], f"att_{attempts}", 1.0, "model_call", reason="fail")
497
+
498
+ if attempts >= 2 and not passed:
499
+ break
500
+
501
+ results.append({"task_id": problem["task_id"], "passed": passed, "attempts": attempts,
502
+ "tokens_used": 2000 - budget_remaining, "best_score": best_score,
503
+ "raw_best": best_code[:200]})
504
+ print(f" {problem['task_id']}: passed={passed}, attempts={attempts}, raw={best_code[:60]!r}")
505
+
506
+ return {
507
+ "accuracy": sum(1 for r in results if r["passed"]) / len(results),
508
+ "total_compute": total_compute,
509
+ "mean_tokens": total_compute / len(problems),
510
+ "mean_attempts": sum(r["attempts"] for r in results) / len(results),
511
+ "results": results,
512
+ }
513
+
514
+ def run_all(self):
515
+ problems = self.load_problems()
516
+ print(f"Loaded {len(problems)} problems")
517
+ model, tok = self.load_model()
518
+
519
+ print("\n--- Baseline (chat template) ---")
520
+ self.use_chat_template = True
521
+ baseline_chat = self.run_baseline(problems, model, tok)
522
+ print(f"Baseline chat: accuracy={baseline_chat['accuracy']:.3f}, compute={baseline_chat['total_compute']}")
523
+
524
+ print("\n--- Baseline (raw prompt) ---")
525
+ self.use_chat_template = False
526
+ baseline_raw = self.run_baseline(problems, model, tok)
527
+ print(f"Baseline raw: accuracy={baseline_raw['accuracy']:.3f}, compute={baseline_raw['total_compute']}")
528
+
529
+ # Use whichever baseline is better for OCC comparison
530
+ best_baseline = baseline_chat if baseline_chat["accuracy"] >= baseline_raw["accuracy"] else baseline_raw
531
+ best_mode = "chat" if baseline_chat["accuracy"] >= baseline_raw["accuracy"] else "raw"
532
+ self.use_chat_template = (best_mode == "chat")
533
+
534
+ print(f"\n--- OCC (using {best_mode}) ---")
535
+ occ = self.run_occ(problems, model, tok)
536
+ print(f"OCC: accuracy={occ['accuracy']:.3f}, compute={occ['total_compute']}")
537
+
538
+ comp = {
539
+ "baseline_accuracy": best_baseline["accuracy"],
540
+ "occ_accuracy": occ["accuracy"],
541
+ "baseline_compute": best_baseline["total_compute"],
542
+ "occ_compute": occ["total_compute"],
543
+ "compute_reduction": 1.0 - (occ["total_compute"] / max(best_baseline["total_compute"], 1)),
544
+ "accuracy_delta": occ["accuracy"] - best_baseline["accuracy"],
545
+ }
546
+
547
+ return {
548
+ "baseline_chat": baseline_chat,
549
+ "baseline_raw": baseline_raw,
550
+ "baseline_best": best_baseline,
551
+ "occ_budget": occ,
552
+ "comparison": comp,
553
+ }
554
+
555
+
556
+ def main():
557
+ bench = RealLLMBenchmarkV4(n_problems=10, seed=42)
558
+ results = bench.run_all()
559
+
560
+ print("\n" + "=" * 60)
561
+ print("REAL LLM CODE BENCHMARK (V4)")
562
+ print("=" * 60)
563
+ comp = results["comparison"]
564
+ print(f"Baseline accuracy: {comp['baseline_accuracy']:.3f}")
565
+ print(f"OCC accuracy: {comp['occ_accuracy']:.3f}")
566
+ print(f"Baseline compute: {comp['baseline_compute']}")
567
+ print(f"OCC compute: {comp['occ_compute']}")
568
+ print(f"Compute reduction: {comp['compute_reduction']:.1%}")
569
+ print(f"Accuracy delta: {comp['accuracy_delta']:+.3f}")
570
+
571
+ out_dir = Path("/app/occ/reports")
572
+ out_dir.mkdir(parents=True, exist_ok=True)
573
+ with open(out_dir / "benchmark_code_real_llm_v4_results.json", "w") as f:
574
+ json.dump(results, f, indent=2, default=str)
575
+ print(f"\nSaved to {out_dir / 'benchmark_code_real_llm_v4_results.json'}")
576
+
577
+
578
+ if __name__ == "__main__":
579
+ main()