| |
| |
| |
| |
| """hf_job_ab.py — the real Lean Laguna A/B + γ-sweep + reward-invariance, as one HF Jobs run. |
| |
| Runs ON Hugging Face Jobs (a GPU batch job, no ssh, auto-stops when done). In ONE GPU session |
| (so the model load cost is amortized) it produces three pieces of MEASURED evidence: |
| |
| (1) Headline decode A/B — serve Laguna XS.2 baseline, measure tokens/sec over N mixed prompts; |
| re-serve with the DFlash speculator (γ=7), measure again; byte-parity-check the greedy outputs. |
| (2) γ-sweep (lossless throughput-optimal γ) — re-serve DFlash at num_speculative_tokens ∈ GAMMAS |
| (default 5,7,9; one cold serve per γ because vLLM bakes speculative_config at engine init), |
| measure tok/s each, parity-check each. Baseline is measured ONCE (γ-independent). Report the |
| throughput-optimal γ* and its speedup vs γ=7. |
| (3) Reward-invariance (live) — drive the SAME 12-problem HumanEval slice the canonical |
| `prime eval run spec_rl` baseline used (mean reward 0.85) through the baseline and the γ=7 |
| DFlash server via /v1/chat/completions (greedy, thinking off) and score with the VERBATIM |
| spec_rl reward (fraction_passing). baseline_mean_reward == dflash_mean_reward by greedy |
| byte-parity — reward-invariance demonstrated live, not just argued by construction. |
| |
| Submit with (h200 is the proven, best-tested target; bound the spend with --timeout + BUDGET_S): |
| hf jobs uv run --flavor h200 --timeout 2100 \ |
| --secrets HF_TOKEN --env GAMMAS=5,7,9 --env BUDGET_S=1900 scripts/hf_job_ab.py |
| |
| Honesty guards baked in: |
| * Everything is MEASURED — no fabricated numbers. A hard wall-clock budget bounds the spend. |
| * τ (acceptance length) is recorded from /metrics but NOT used as a headline — the counters pin |
| at the γ+1 ceiling at this granularity, so τ is treated as unreliable and never quoted. |
| * The decode tok/s A/B is the throughput headline; eval wall-clock is NOT a throughput claim. |
| * `latency_s_mean` is full-completion latency, NOT true time-to-first-token (the harness does not |
| isolate prefill) — labeled as such, never reported as TTFT. |
| |
| Local dry-run (no GPU, no network) — validates the loop shape + scoring against the stdlib stub: |
| python scripts/stub_server.py --port 8000 & # baseline-shaped stub |
| printf '%s\n' '{"prompt":"def add(a,b):\\n \\"\\"\\"add\\"\\"\\"\\n","test":"def check(c):\\n assert c(1,2)==3\\n","entry_point":"add"}' > /tmp/toy.jsonl |
| DRYRUN=1 GAMMAS=7 REWARD_N=1 SPEC_RL_DATASET=/tmp/toy.jsonl python scripts/hf_job_ab.py |
| """ |
| from __future__ import annotations |
|
|
| import ast |
| import json |
| import os |
| import subprocess |
| import sys |
| import tempfile |
| import time |
| import urllib.request |
| from pathlib import Path |
|
|
# Run configuration. Every knob except PORT/STOP/EXEC_TIMEOUT_S can be overridden from the environment.
MODEL = os.getenv("MODEL", "poolside/Laguna-XS.2")  # passed to vLLM as --model
SPECULATOR = os.getenv("SPECULATOR", "poolside/Laguna-XS.2-speculator.dflash")  # "model" of --speculative-config
GAMMA = int(os.getenv("GAMMA", "7"))  # default num_speculative_tokens for serve()/measure()
GAMMAS = [int(tok) for tok in os.getenv("GAMMAS", "5,7,9").split(",") if tok.strip()]  # γ values swept by main()
REWARD_GAMMA = int(os.getenv("REWARD_GAMMA", "7"))  # the sweep point that also runs the reward eval
REWARD_N = int(os.getenv("REWARD_N", "12"))  # how many problems the reward eval loads
REWARD_MAX_TOKENS = int(os.getenv("REWARD_MAX_TOKENS", "512"))  # max_tokens for each chat completion
N = int(os.getenv("N", "0"))  # number of throughput prompts; <= 0 means "use the built-in list once"
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "256"))  # max_tokens for each /v1/completions request
BUDGET_S = int(os.getenv("BUDGET_S", "1500"))  # wall-clock budget consulted by budget_left()
MIN_SERVE_S = int(os.getenv("MIN_SERVE_S", "300"))  # main() skips a γ point when less budget than this remains
DETERMINISM_REPEATS = int(os.getenv("DETERMINISM_REPEATS", "0"))  # > 0 makes main() run the determinism probe instead
DRYRUN = os.getenv("DRYRUN", "") == "1"  # when set, no vLLM server is launched by this script
PORT = 8000  # localhost port used for every HTTP call
STOP = ["\nclass ", "\ndef ", "\n#", "\nif __name__"]  # stop sequences for raw completions
EXEC_TIMEOUT_S = 8  # default timeout for running a candidate program
T0 = time.time()  # job start time; the budget is measured from here
| |
| |
# Mixed prompt set for the decode-throughput measurement (function/class signature + docstring stubs).
PROMPTS = [
    'def fib(n):\n """Return the n-th Fibonacci number."""\n',
    'def is_prime(n):\n """Return True iff n is prime."""\n',
    'def factorial(n):\n """Return n! (n factorial)."""\n',
    'def reverse_words(s):\n """Reverse the order of words in s."""\n',

    'def binary_search(arr, target):\n """Return the index of target in sorted arr, else -1."""\n',
    'def merge_sorted(a, b):\n """Merge two sorted lists into one sorted list."""\n',
    'def is_balanced(s):\n """Return True iff the brackets ()[]{} in s are balanced."""\n',
    'def roman_to_int(s):\n """Convert a Roman numeral string to an integer."""\n',
    'def flatten(nested):\n """Flatten an arbitrarily nested list of ints into a flat list."""\n',

    'def lcs(a, b):\n """Return the length of the longest common subsequence of strings a and b."""\n',
    "def parse_duration(s):\n \"\"\"Parse strings like '1h30m', '45s', '2d' into total seconds. Raise ValueError on bad input.\"\"\"\n",
    'def group_anagrams(words):\n """Group words that are anagrams of each other into a list of lists."""\n',
    'class LRUCache:\n """A fixed-capacity LRU cache with get(key) and put(key, value)."""\n',
    'def dijkstra(graph, start):\n """graph: dict node -> list of (neighbor, weight). Return dict of shortest distances from start."""\n',
]
# N <= 0 means "one pass over the built-in list".
if N <= 0:
    N = len(PROMPTS)
# Resize to exactly N prompts by cycling through the list (or truncating it when N is smaller).
PROMPTS = [PROMPTS[i % len(PROMPTS)] for i in range(N)]
|
|
| |
| |
# System message sent ahead of every reward-eval problem (see reward_eval()).
RL_SYSTEM_PROMPT = " ".join([
    "You are an expert Python programmer.",
    "You will be given a function signature and docstring.",
    "Complete the function body only.",
    "Do not repeat the signature, do not add explanations, and do not wrap the code in markdown fences.",
    "Output only the indented function body.",
])
|
|
|
|
def budget_left() -> float:
    """Return the seconds of wall-clock budget remaining (negative once BUDGET_S is overspent)."""
    elapsed = time.time() - T0
    return BUDGET_S - elapsed
|
|
|
|
def serve(dflash: bool, gamma: int = GAMMA) -> subprocess.Popen:
    """Launch the vLLM OpenAI-compatible API server as a child process and return its handle.

    Args:
        dflash: when True, add a --speculative-config using SPECULATOR with method "dflash".
        gamma: num_speculative_tokens for that config; ignored when `dflash` is False.

    Returns:
        The running `subprocess.Popen`; the caller waits for health and tears it down.
    """
    # Start from the parent's environment, then pin the toggles this job relies on.
    env = dict(os.environ)
    env.update({
        "VLLM_USE_DEEP_GEMM": "0",
        "VLLM_USE_FLASHINFER_MOE_FP16": "0",
        "VLLM_USE_FLASHINFER_MOE_FP8": "0",
        "VLLM_USE_FLASHINFER_SAMPLER": "0",
    })
    # The attention backend is only defaulted, so an explicit override from the caller wins.
    env.setdefault("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")

    cmd = [
        sys.executable, "-m", "vllm.entrypoints.openai.api_server",
        "--model", MODEL,
        "--port", str(PORT),
        "--tensor-parallel-size", "1",
        "--trust-remote-code",
        "--enforce-eager",
        "--gpu-memory-utilization", "0.9",
        "--max-model-len", os.environ.get("SPECRL_MAX_LEN", "4096"),
        "--max-num-seqs", os.environ.get("MAX_NUM_SEQS", "16"),
        "--default-chat-template-kwargs", json.dumps({"enable_thinking": False}),
    ]
    if dflash:
        spec_config = {"model": SPECULATOR, "num_speculative_tokens": gamma, "method": "dflash"}
        cmd.extend(["--speculative-config", json.dumps(spec_config)])

    mode = ("DFlash(γ=%d)" % gamma) if dflash else "baseline"
    print(f"[job] serving {mode}: {' '.join(cmd)}", flush=True)
    return subprocess.Popen(cmd, env=env)
|
|
|
|
def wait_health(proc: subprocess.Popen, timeout: int = 900) -> None:
    """Block until the server's /health endpoint answers, polling every 5 seconds.

    Args:
        proc: the server process; if it exits while we wait, startup has failed.
        timeout: maximum number of seconds to keep polling.

    Raises:
        RuntimeError: the server process exited before becoming healthy.
        TimeoutError: the endpoint never answered within `timeout` seconds.
    """
    url = f"http://localhost:{PORT}/health"
    # A monotonic clock keeps the deadline immune to wall-clock adjustments during a long startup.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if proc.poll() is not None:
            raise RuntimeError("vLLM server exited during startup (check logs above)")
        try:
            # The context manager closes the response so the probe does not leak a socket.
            with urllib.request.urlopen(url, timeout=5):
                pass
        except Exception:
            # Deliberately broad: refused connection, HTTP error and timeout all mean "not ready yet".
            time.sleep(5)
            continue
        print("[job] server healthy", flush=True)
        return
    raise TimeoutError("server did not become healthy in time")
|
|
|
|
def _post(path: str, payload: dict) -> dict:
    """POST `payload` as JSON to the local server at `path` and return the decoded JSON reply.

    Uses a 300-second timeout; HTTP and connection errors from urllib propagate to the caller.
    """
    request = urllib.request.Request(
        f"http://localhost:{PORT}{path}",
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request, timeout=300) as response:
        raw = response.read()
    return json.loads(raw.decode())
|
|
|
|
def complete(prompt: str) -> tuple[str, float, float]:
    """Run one greedy /v1/completions request and time it.

    Returns:
        (generated text, tokens per second, elapsed seconds). The token count comes from the
        response's usage.completion_tokens; when that is missing or zero, a whitespace word
        count of the text is used instead.
    """
    started = time.time()
    response = _post("/v1/completions", {
        "model": MODEL,
        "prompt": prompt,
        "max_tokens": MAX_TOKENS,
        "temperature": 0.0,
        "stop": STOP,
    })
    elapsed = time.time() - started

    choice = response["choices"][0]
    text = choice.get("text", "") or ""
    usage = response.get("usage") or {}
    n_tokens = usage.get("completion_tokens")
    if not n_tokens:
        n_tokens = len(text.split())
    rate = n_tokens / elapsed if elapsed else 0.0
    return text, rate, elapsed
|
|
|
|
def chat_complete(messages: list[dict], max_tokens: int = REWARD_MAX_TOKENS) -> str:
    """Send one greedy chat completion with thinking disabled and return the reply text.

    Returns an empty string when the first choice carries no message or no content.
    """
    request = {
        "model": MODEL,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": 0.0,
        "chat_template_kwargs": {"enable_thinking": False},
    }
    first_choice = _post("/v1/chat/completions", request)["choices"][0]
    message = first_choice.get("message") or {}
    return message.get("content") or ""
|
|
|
|
def _metric_sample(line: str) -> "tuple[str, float] | None":
    """Split one Prometheus text-format line into (series name without labels, value), or None."""
    line = line.strip()
    if not line or line.startswith("#"):  # blank line or HELP/TYPE comment
        return None
    if "{" in line:
        name, _, rest = line.partition("{")
        fields = rest.rpartition("}")[2].split()  # what follows the label set: value [timestamp]
    else:
        parts = line.split()
        name, fields = parts[0], parts[1:]
    if not fields:
        return None
    try:
        return name.strip(), float(fields[0])
    except ValueError:
        return None


def tau_from_metrics(gamma: int) -> float | None:
    """Estimate the acceptance length τ from the server's /metrics counters.

    Reads the accepted-token and draft-token spec-decode counters and computes
    (accepted + passes) / passes with passes = draft / gamma, i.e. it assumes every pass
    drafts exactly `gamma` tokens (NOTE(review): confirm against the server's drafting behaviour).

    Series are matched by EXACT name (with or without a `_total` suffix). A bare prefix match
    would also pick up any sibling series sharing the prefix (for example a `..._per_pos` or
    `..._created` series, if the server exports one), and its value would silently overwrite
    the counter. Values are summed across label sets.

    Args:
        gamma: num_speculative_tokens the server was started with.

    Returns:
        The τ estimate, or None when the endpoint is unreachable, a counter is missing, no
        tokens have been drafted yet, or `gamma` is not positive.
    """
    if gamma <= 0:
        return None
    try:
        with urllib.request.urlopen(f"http://localhost:{PORT}/metrics", timeout=10) as r:
            body = r.read().decode()
    except Exception:
        # Best effort: τ is recorded for completeness only, so a metrics failure is not fatal.
        return None

    accepted_name = "vllm:spec_decode_num_accepted_tokens"
    drafted_name = "vllm:spec_decode_num_draft_tokens"
    totals: dict[str, float] = {}
    for raw in body.splitlines():
        sample = _metric_sample(raw)
        if sample is None:
            continue
        name, value = sample
        if name.endswith("_total"):
            name = name[: -len("_total")]
        if name in (accepted_name, drafted_name):
            totals[name] = totals.get(name, 0.0) + value

    acc = totals.get(accepted_name)
    draft = totals.get(drafted_name)
    if acc is None or not draft or draft <= 0:
        return None
    passes = draft / gamma
    return (acc + passes) / passes
|
|
|
|
def measure(dflash: bool, gamma: int = GAMMA) -> dict:
    """Measure decode throughput over PROMPTS against the currently running server.

    Stops early (keeping what was collected) once fewer than 120 seconds of budget remain.

    Returns:
        A dict with the run label, model, number of completed prompts, γ (None for baseline),
        the mean per-request tokens/sec, the mean per-request wall-clock latency (the whole
        request, not time-to-first-token), the recorded τ (1.0 for baseline) and the raw texts.
    """
    outputs: list[str] = []
    rates: list[float] = []
    latencies: list[float] = []
    for prompt in PROMPTS:
        if budget_left() < 120:
            print("[job] budget guard hit — stopping measure early", flush=True)
            break
        text, rate, seconds = complete(prompt)
        outputs.append(text)
        rates.append(rate)
        latencies.append(seconds)

    label = ("dflash_g%d" % gamma) if dflash else "baseline"
    mean_rate = sum(rates) / len(rates) if rates else 0.0
    mean_latency = sum(latencies) / len(latencies) if latencies else 0.0
    tau = tau_from_metrics(gamma) if dflash else 1.0
    return {
        "label": label,
        "model": MODEL,
        "n": len(outputs),
        "gamma": gamma if dflash else None,
        "tokens_per_s_mean": mean_rate,
        "latency_s_mean": mean_latency,
        "acceptance_length_tau": tau,
        "texts": outputs,
    }
|
|
|
|
| |
| |
| |
| |
def load_problems(num_examples: int) -> list[dict]:
    """Return the first `num_examples` problems as dicts with prompt / test / entry_point.

    If SPEC_RL_DATASET names an existing `.jsonl` file, it is read (one JSON object per
    non-blank line); this is the dry-run seam. Otherwise the dataset is loaded through
    `datasets.load_dataset`, using SPEC_RL_DATASET or HUMANEVAL_DATASET (default
    "openai/openai_humaneval") and the SPEC_RL_DATASET_SPLIT split (default "test").

    Args:
        num_examples: maximum number of problems to return.

    Returns:
        At most `num_examples` problem dicts, in dataset order.
    """
    src = os.environ.get("SPEC_RL_DATASET")
    if src and src.endswith(".jsonl") and os.path.exists(src):
        # JSON text is UTF-8; pin it so non-ASCII prompts do not depend on the job's locale.
        with open(src, encoding="utf-8") as f:
            rows = [json.loads(line) for line in f if line.strip()]
        return rows[:num_examples]

    # Imported lazily so the JSONL path works without the `datasets` package installed.
    from datasets import load_dataset
    dataset_id = src or os.environ.get("HUMANEVAL_DATASET", "openai/openai_humaneval")
    split = os.environ.get("SPEC_RL_DATASET_SPLIT", "test")
    ds = load_dataset(dataset_id, split=split)
    num_examples = min(num_examples, len(ds))
    return [dict(ds[i]) for i in range(num_examples)]
|
|
|
|
class _AssertCounter(ast.NodeTransformer):
    """Replace every `assert` with statements that tally the outcome instead of raising.

    Each assert becomes: evaluate the condition under a catch-all handler into `__ok`,
    increment `__tally['total']`, and increment `__tally['passed']` when `__ok` is true.
    The assert's optional message is dropped. `__tally` must exist in the scope where the
    rewritten code runs.
    """

    # Source template for the replacement statements; the placeholder is swapped per assert.
    _REPLACEMENT_SRC = (
        "try:\n"
        "    __ok = bool(__condition__)\n"
        "except BaseException:\n"
        "    __ok = False\n"
        "__tally['total'] += 1\n"
        "if __ok:\n"
        "    __tally['passed'] += 1\n"
    )

    def visit_Assert(self, node: ast.Assert):
        # Parse a fresh copy per assert so the returned nodes are never shared between sites.
        replacement = ast.parse(self._REPLACEMENT_SRC).body
        guarded_assign = replacement[0].body[0]
        guarded_assign.value.args = [node.test]  # splice the real condition into bool(...)
        for stmt in replacement:
            ast.copy_location(stmt, node)
            ast.fix_missing_locations(stmt)
        return replacement
|
|
|
|
def passes(problem: dict, completion: str, timeout_s: int = EXEC_TIMEOUT_S) -> bool:
    """Run prompt + completion + test + `check(entry_point)` in a fresh interpreter.

    Args:
        problem: dict with "prompt", "test" and "entry_point".
        completion: the model-generated code appended to the prompt.
        timeout_s: wall-clock limit for the child process.

    Returns:
        True iff the child exits with status 0; False on a non-zero exit or a timeout.
    """
    program = problem["prompt"] + completion + "\n" + problem["test"] + f"\ncheck({problem['entry_point']})\n"
    with tempfile.TemporaryDirectory() as tmp:
        prog_path = Path(tmp) / "candidate.py"
        # Python reads source files as UTF-8 by default, so write UTF-8 regardless of the locale.
        prog_path.write_text(program, encoding="utf-8")
        try:
            # errors="replace": arbitrary bytes printed by the candidate must not raise while decoding.
            result = subprocess.run([sys.executable, str(prog_path)], capture_output=True,
                                    text=True, encoding="utf-8", errors="replace",
                                    timeout=timeout_s, cwd=tmp)
        except subprocess.TimeoutExpired:
            return False
    return result.returncode == 0
|
|
|
|
def fraction_passing(problem: dict, completion: str, timeout_s: int = EXEC_TIMEOUT_S) -> float:
    """Return the fraction of the test's asserts that pass for `completion`, in [0.0, 1.0].

    The test source is rewritten with `_AssertCounter` so each assert is tallied instead of
    aborting, then run in a child interpreter that prints the tally on a `__FRAC__` line.
    If the test cannot be parsed or unparsed, this falls back to all-or-nothing `passes()`.

    Args:
        problem: dict with "prompt", "test" and "entry_point".
        completion: the model-generated code appended to the prompt.
        timeout_s: wall-clock limit for the child process.

    Returns:
        passed / total clamped to [0, 1]; 1.0 or 0.0 by exit status when no assert ran;
        0.0 on a timeout, a missing tally line or an unreadable tally.
    """
    try:
        tree = ast.parse(problem["test"])
    except SyntaxError:
        return 1.0 if passes(problem, completion, timeout_s) else 0.0
    tree = _AssertCounter().visit(tree)
    ast.fix_missing_locations(tree)
    try:
        instrumented_test = ast.unparse(tree)
    except Exception:
        return 1.0 if passes(problem, completion, timeout_s) else 0.0
    # check() is wrapped so a non-assert exception cannot prevent the tally from being printed.
    program = (
        "__tally = {'passed': 0, 'total': 0}\n"
        + problem["prompt"] + completion + "\n" + instrumented_test + "\n"
        + "try:\n" + f" check({problem['entry_point']})\n"
        + "except BaseException:\n pass\n"
        + "import json as __json\nprint('__FRAC__' + __json.dumps(__tally))\n")
    with tempfile.TemporaryDirectory() as tmp:
        prog_path = Path(tmp) / "candidate.py"
        # Python reads source files as UTF-8 by default, so write UTF-8 regardless of the locale.
        prog_path.write_text(program, encoding="utf-8")
        try:
            # errors="replace": stray bytes printed by the candidate must not abort scoring.
            result = subprocess.run([sys.executable, str(prog_path)], capture_output=True,
                                    text=True, encoding="utf-8", errors="replace",
                                    timeout=timeout_s, cwd=tmp)
        except subprocess.TimeoutExpired:
            return 0.0
    for line in result.stdout.splitlines():
        if line.startswith("__FRAC__"):
            try:
                tally = json.loads(line[len("__FRAC__"):])
                total = int(tally.get("total", 0))
                passed = int(tally.get("passed", 0))
            except Exception:
                return 0.0
            if total == 0:
                return 1.0 if result.returncode == 0 else 0.0
            return max(0.0, min(1.0, passed / total))
    return 0.0
|
|
|
|
def score_completion(problem: dict, completion_text: str) -> float:
    """Score a completion with `fraction_passing`, handling replies that echo the signature.

    Markdown fences are stripped first. If the reply contains `def <entry_point>`, the text
    from that point on (cut at the first trailing marker) is treated as the whole function and
    only the part of the prompt before the signature is kept as preamble. Otherwise the reply
    is treated as a bare body and cut at the first STOP sequence.
    """
    entry = problem["entry_point"]
    cleaned = (completion_text or "").replace("```python", "").replace("```", "")
    signature = f"def {entry}"

    start = cleaned.find(signature)
    if start != -1:
        # The model re-emitted the function: keep the prompt's preamble and use the model's definition.
        preamble = problem["prompt"].split(signature, 1)[0]
        func_src = cleaned[start:]
        for tail in ("\n</", "\nif __name__", "\n#", "\nclass "):
            func_src = func_src.partition(tail)[0]
        echoed = {"prompt": preamble, "test": problem["test"], "entry_point": entry}
        return fraction_passing(echoed, func_src)

    body = cleaned
    for stop in STOP:
        body = body.partition(stop)[0]
    return fraction_passing(problem, body)
|
|
|
|
def reward_eval(label: str) -> dict:
    """Run the reward eval against the live server and return its scores.

    Loads REWARD_N problems, sends each as a system + user chat (greedy, thinking off) and
    scores the reply with `score_completion`. Stops early once fewer than 60 seconds of budget
    remain.

    Returns:
        dict with "label", "n" (problems scored), "mean_reward" (None if nothing was scored),
        "per_rollout_reward" (each rounded to 4 places) and the raw "texts".
    """
    scores: list[float] = []
    replies: list[str] = []
    for problem in load_problems(REWARD_N):
        if budget_left() < 60:
            print("[job] budget guard hit — stopping reward eval early", flush=True)
            break
        conversation = [
            {"role": "system", "content": RL_SYSTEM_PROMPT},
            {"role": "user", "content": problem["prompt"]},
        ]
        reply = chat_complete(conversation)
        scores.append(round(score_completion(problem, reply), 4))
        replies.append(reply)

    mean_reward = round(sum(scores) / len(scores), 4) if scores else None
    return {
        "label": label,
        "n": len(scores),
        "mean_reward": mean_reward,
        "per_rollout_reward": scores,
        "texts": replies,
    }
|
|
|
|
def run_phase(dflash: bool, gamma: int, do_reward: bool) -> "tuple[dict, dict | None]":
    """Run one serve → measure → (optional reward eval) → teardown cycle.

    In DRYRUN mode no server is launched and the measurements hit whatever already listens
    on PORT.

    Returns:
        (measure() result, reward result). The reward result is None when `do_reward` is
        False, and an {"error": ...} dict when the reward eval raised (treated as non-fatal).
    """
    server = None if DRYRUN else serve(dflash, gamma)
    try:
        if server is not None:
            wait_health(server)
        measurement = measure(dflash, gamma)
        if not do_reward:
            return measurement, None
        run_label = ("dflash_g%d" % gamma) if dflash else "baseline"
        try:
            reward = reward_eval(run_label)
        except Exception as e:
            reward = {"error": f"{type(e).__name__}: {e}"}
            print(f"[job] reward_eval failed (non-fatal): {reward['error']}", flush=True)
        return measurement, reward
    finally:
        if server is not None:
            # Graceful stop first; escalate to kill if the server has not exited after 30 seconds.
            server.terminate()
            try:
                server.wait(timeout=30)
            except Exception:
                server.kill()
            # Short pause after teardown before the caller starts the next server.
            time.sleep(5)
|
|
|
|
def _expose_wheel_nvcc() -> None:
    """Point CUDA_HOME / CUDA_PATH / PATH at a pip-installed nvcc when no toolkit is visible.

    Does nothing if `nvcc` is already on PATH or /usr/local/cuda exists. Otherwise searches
    each site-packages directory (plus the parent of this script's directory) for
    nvidia/cuda_nvcc/bin/nvcc and exports the first hit. A safety net for residual JIT
    compilation only.
    """
    import shutil
    import site

    if shutil.which("nvcc") or os.path.isdir("/usr/local/cuda"):
        return

    try:
        search_roots = list(site.getsitepackages())
    except Exception:
        # NOTE(review): presumably guards environments where site.getsitepackages is unavailable.
        search_roots = []
    search_roots.append(os.path.dirname(os.path.dirname(__file__)))

    for root in search_roots:
        cand = os.path.join(root, "nvidia", "cuda_nvcc")
        if not os.path.exists(os.path.join(cand, "bin", "nvcc")):
            continue
        os.environ["CUDA_HOME"] = cand
        os.environ["CUDA_PATH"] = cand
        os.environ["PATH"] = os.path.join(cand, "bin") + ":" + os.environ.get("PATH", "")
        print(f"[job] exposed wheel nvcc (CUDA_HOME={cand})", flush=True)
        return
    print("[job] no wheel nvcc found to expose (FlashInfer JIT paths are disabled anyway)", flush=True)
|
|
|
|
def _parity(base_texts: list[str], texts: list[str]) -> dict:
    """Compare two output lists position by position for exact string equality.

    Only the overlapping prefix is compared (the shorter list decides how many pairs).

    Returns:
        {"compared": number of pairs, "mismatches": number of differing pairs,
         "lossless": True/False when at least one pair was compared, None when none was}.
        Zero comparisons is "not measured", not evidence of losslessness.
    """
    pairs = list(zip(base_texts, texts))
    mismatches = sum(1 for base, other in pairs if base != other)
    compared = len(pairs)
    lossless = (mismatches == 0) if compared else None
    return {"compared": compared, "mismatches": mismatches, "lossless": lossless}
|
|
|
|
def run_determinism(repeats: int) -> int:
    """Greedy-determinism probe: one baseline server, `repeats` reward evals on the same engine.

    Reports each run's mean reward and how every later run's completions differ from run 1,
    prints the result, and writes it to results/determinism_check.json. If two greedy runs of
    the same model on the same prompts differ, a reward gap seen in the DFlash A/B cannot be
    attributed to DFlash from the live numbers alone.

    Args:
        repeats: number of reward-eval runs to attempt (cut short by the budget guard).

    Returns:
        0 (process exit status).
    """
    proc = None if DRYRUN else serve(dflash=False, gamma=0)
    try:
        if proc is not None:
            wait_health(proc)
        runs = []
        for i in range(repeats):
            if budget_left() < 60:
                print("[job] budget guard — stopping determinism repeats early", flush=True)
                break
            rw = reward_eval(f"baseline_run{i + 1}")
            runs.append(rw)
            print(f"[job] DET_RUN_{i + 1}_JSON " + json.dumps({k: v for k, v in rw.items() if k != "texts"}), flush=True)
        means = [r["mean_reward"] for r in runs]
        base_texts = runs[0]["texts"] if runs else []
        # Every later run is compared against run 1, so there is no entry for run 1 itself.
        run_vs_run1 = [_parity(base_texts, later["texts"]) for later in runs[1:]]
        det = {
            "repeats": len(runs),
            "per_run_mean_reward": means,
            "per_run_reward": [r["per_rollout_reward"] for r in runs],
            "run_vs_run1_parity": run_vs_run1,
            # None means fewer than two runs completed, so reproducibility was not measured.
            "greedy_bit_reproducible": (all(d["mismatches"] == 0 for d in run_vs_run1) and len(set(means)) <= 1)
                                       if run_vs_run1 else None,
            "note": ("If per_run_mean_reward varies OR run_vs_run1_parity shows mismatches, greedy decoding is "
                     "NOT bit-reproducible run-to-run on this MoE — so the DFlash A/B's 1.0-vs-0.85 reward gap is "
                     "nondeterminism noise, not a DFlash quality change. Reward-invariance holds by construction "
                     "(lossless decode => identical reward); we decline to quote a DFlash reward, like tau."),
        }
        print("[job] DETERMINISM_JSON " + json.dumps(det), flush=True)
        os.makedirs("results", exist_ok=True)
        # Context manager so the report is flushed and the handle closed before returning.
        with open("results/determinism_check.json", "w") as fh:
            json.dump(det, fh, indent=2)
        return 0
    finally:
        if proc is not None:
            # Graceful stop first; escalate to kill if the server has not exited after 30 seconds.
            proc.terminate()
            try:
                proc.wait(timeout=30)
            except Exception:
                proc.kill()
|
|
|
|
def _dump_json(path: str, payload) -> None:
    """Write `payload` to `path` as indented JSON, closing the file handle deterministically."""
    with open(path, "w") as fh:
        json.dump(payload, fh, indent=2)


def main() -> int:
    """Run the whole job: baseline phase, DFlash γ-sweep, reward-invariance check, summary.

    With DETERMINISM_REPEATS > 0 the determinism probe runs instead. Otherwise the baseline
    is measured once (with the reward eval), then one DFlash server is started per γ
    (REWARD_GAMMA first, so the reward comparison happens before the budget runs low).
    Results are printed as tagged JSON lines and written under results/; the sweep file is
    rewritten after every point so a partial run still leaves evidence behind.

    Returns:
        0 (process exit status). A failing γ point is recorded and skipped, not fatal.
    """
    print(f"[job] start; budget {BUDGET_S}s; N={N}; gammas={GAMMAS}; reward_n={REWARD_N}; "
          f"reward_gamma={REWARD_GAMMA}; model={MODEL}; dryrun={DRYRUN}; det_repeats={DETERMINISM_REPEATS}", flush=True)
    _expose_wheel_nvcc()

    os.makedirs("results", exist_ok=True)
    if DETERMINISM_REPEATS > 0:
        return run_determinism(DETERMINISM_REPEATS)

    # Baseline: measured once, and reused as the reference for every γ point.
    base, base_reward = run_phase(dflash=False, gamma=0, do_reward=True)
    base_tps = base["tokens_per_s_mean"]
    print("[job] BASELINE_JSON " + json.dumps({k: v for k, v in base.items() if k != "texts"}), flush=True)
    _dump_json("results/baseline.json", base)
    # A baseline reward eval that errored is not a usable reference for the invariance check.
    base_reward_ok = bool(base_reward) and "error" not in base_reward

    # REWARD_GAMMA goes first (when it is part of the sweep); the rest keep their given order.
    order = ([REWARD_GAMMA] if REWARD_GAMMA in GAMMAS else []) + [g for g in GAMMAS if g != REWARD_GAMMA]
    sweep, reward_inv = [], None
    for g in order:
        if budget_left() < MIN_SERVE_S:
            print(f"[job] budget guard: skipping γ={g} (only {budget_left():.0f}s left)", flush=True)
            continue
        try:
            dfl, dfl_reward = run_phase(dflash=True, gamma=g, do_reward=(g == REWARD_GAMMA))
        except Exception as e:
            print(f"[job] γ={g} serve FAILED (non-fatal, continuing): {type(e).__name__}: {e}", flush=True)
            sweep.append({"gamma": g, "dflash_tps": None, "speedup_vs_baseline": None,
                          "parity": None, "error": f"{type(e).__name__}: {e}"})
            _dump_json("results/gamma_sweep.json", {"baseline_tps": round(base_tps, 3), "gamma_sweep": sweep})
            continue
        parity = _parity(base["texts"], dfl["texts"])
        entry = {"gamma": g, "dflash_tps": dfl["tokens_per_s_mean"],
                 "speedup_vs_baseline": round(dfl["tokens_per_s_mean"] / base_tps, 3) if base_tps else None,
                 "parity": parity, "tau_recorded": dfl["acceptance_length_tau"]}
        sweep.append(entry)
        print("[job] GAMMA_POINT " + json.dumps(entry), flush=True)
        _dump_json("results/gamma_sweep.json", {"baseline_tps": round(base_tps, 3), "gamma_sweep": sweep})
        if g == REWARD_GAMMA and base_reward_ok and dfl_reward and "error" not in dfl_reward:
            reward_parity = _parity(base_reward.get("texts", []), dfl_reward.get("texts", []))
            base_mean = base_reward.get("mean_reward")
            dfl_mean = dfl_reward.get("mean_reward")
            # Only compare means that both exist and cover the same number of problems;
            # otherwise report None ("not measured") rather than a vacuous True/False.
            comparable = (base_mean is not None and dfl_mean is not None
                          and base_reward.get("n") == dfl_reward.get("n"))
            reward_inv = {
                "n": dfl_reward.get("n"),
                "baseline_mean_reward": base_mean,
                "dflash_mean_reward": dfl_mean,
                "reward_invariant": (base_mean == dfl_mean) if comparable else None,
                "eval_byte_parity": reward_parity,
                "baseline_per_rollout": base_reward.get("per_rollout_reward"),
                "dflash_per_rollout": dfl_reward.get("per_rollout_reward"),
            }
            _dump_json("results/reward_invariance.json", reward_inv)
            print("[job] REWARD_INVARIANCE_JSON " + json.dumps(reward_inv), flush=True)

    # Summary over the points that actually produced a throughput number.
    ok = [e for e in sweep if e.get("dflash_tps")]
    sweep_sorted = sorted(ok, key=lambda e: e["dflash_tps"], reverse=True)
    gamma_star = sweep_sorted[0] if sweep_sorted else None
    g7 = next((e for e in ok if e["gamma"] == 7), None)
    gamma_star_vs_g7 = (round(gamma_star["dflash_tps"] / g7["dflash_tps"], 3)
                        if gamma_star and g7 and g7["dflash_tps"] else None)
    all_lossless = all(e["parity"]["lossless"] for e in ok) if ok else None

    summary = {
        "baseline_tps": round(base_tps, 3),
        "gamma_sweep": sweep,
        "gamma_star": gamma_star["gamma"] if gamma_star else None,
        "gamma_star_tps": round(gamma_star["dflash_tps"], 3) if gamma_star else None,
        "gamma_star_speedup_vs_g7": gamma_star_vs_g7,
        "all_points_lossless": all_lossless,
        "reward_invariance": reward_inv,
        "elapsed_s": round(time.time() - T0, 1),
    }
    print("[job] RESULT " + json.dumps(summary), flush=True)
    _dump_json("results/gamma_sweep.json", summary)
    print("[job] SWEEP_JSON " + json.dumps(summary), flush=True)
    if reward_inv:
        print("[job] REWARD_INVARIANCE_JSON " + json.dumps(reward_inv), flush=True)
    if base_reward:
        print("[job] SAMPLE_REWARD_TEXT " + json.dumps((base_reward.get("texts") or [""])[:1]), flush=True)
    return 0
|
|
|
|
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|
|