lean-laguna / scripts /hf_job_ab.py
art87able's picture
Upload folder using huggingface_hub
8cc969e verified
# /// script
# requires-python = ">=3.10"
# dependencies = ["vllm>=0.21", "huggingface_hub>=0.25", "datasets>=2.0"]
# ///
"""hf_job_ab.py — the real Lean Laguna A/B + γ-sweep + reward-invariance, as one HF Jobs run.
Runs ON Hugging Face Jobs (a GPU batch job, no ssh, auto-stops when done). In ONE GPU session
(so the model load cost is amortized) it produces three pieces of MEASURED evidence:
(1) Headline decode A/B — serve Laguna XS.2 baseline, measure tokens/sec over N mixed prompts;
re-serve with the DFlash speculator (γ=7), measure again; byte-parity-check the greedy outputs.
(2) γ-sweep (lossless throughput-optimal γ) — re-serve DFlash at num_speculative_tokens ∈ GAMMAS
(default 5,7,9; one cold serve per γ because vLLM bakes speculative_config at engine init),
measure tok/s each, parity-check each. Baseline is measured ONCE (γ-independent). Report the
throughput-optimal γ* and its speedup vs γ=7.
(3) Reward-invariance (live) — drive the SAME 12-problem HumanEval slice the canonical
`prime eval run spec_rl` baseline used (mean reward 0.85) through the baseline and the γ=7
DFlash server via /v1/chat/completions (greedy, thinking off) and score with the VERBATIM
spec_rl reward (fraction_passing). baseline_mean_reward == dflash_mean_reward by greedy
byte-parity — reward-invariance demonstrated live, not just argued by construction.
Submit with (h200 is the proven, best-tested target; bound the spend with --timeout + BUDGET_S):
hf jobs uv run --flavor h200 --timeout 2100 \
--secrets HF_TOKEN --env GAMMAS=5,7,9 --env BUDGET_S=1900 scripts/hf_job_ab.py
Honesty guards baked in:
* Everything is MEASURED — no fabricated numbers. A hard wall-clock budget bounds the spend.
* τ (acceptance length) is recorded from /metrics but NOT used as a headline — the counters pin
at the γ+1 ceiling at this granularity, so τ is treated as unreliable and never quoted.
* The decode tok/s A/B is the throughput headline; eval wall-clock is NOT a throughput claim.
* `latency_s_mean` is full-completion latency, NOT true time-to-first-token (the harness does not
isolate prefill) — labeled as such, never reported as TTFT.
Local dry-run (no GPU, no network) — validates the loop shape + scoring against the stdlib stub:
python scripts/stub_server.py --port 8000 & # baseline-shaped stub
printf '%s\n' '{"prompt":"def add(a,b):\\n \\"\\"\\"add\\"\\"\\"\\n","test":"def check(c):\\n assert c(1,2)==3\\n","entry_point":"add"}' > /tmp/toy.jsonl
DRYRUN=1 GAMMAS=7 REWARD_N=1 SPEC_RL_DATASET=/tmp/toy.jsonl python scripts/hf_job_ab.py
"""
from __future__ import annotations
import ast
import json
import os
import subprocess
import sys
import tempfile
import time
import urllib.request
from pathlib import Path
# Run configuration. Most knobs are read from the environment once, at import time.
MODEL = os.environ.get("MODEL", "poolside/Laguna-XS.2")  # model id passed to vLLM's --model and sent in requests
SPECULATOR = os.environ.get("SPECULATOR", "poolside/Laguna-XS.2-speculator.dflash")  # draft model for DFlash serves
GAMMA = int(os.environ.get("GAMMA", "7")) # default draft length (the card's value)
GAMMAS = [int(g) for g in os.environ.get("GAMMAS", "5,7,9").split(",") if g.strip()]  # γ values swept by main()
REWARD_GAMMA = int(os.environ.get("REWARD_GAMMA", "7")) # γ used for the live reward-invariance eval
REWARD_N = int(os.environ.get("REWARD_N", "12")) # HumanEval problems (matches the 0.85 baseline)
REWARD_MAX_TOKENS = int(os.environ.get("REWARD_MAX_TOKENS", "512"))  # max_tokens for reward-eval chat requests
N = int(os.environ.get("N", "0")) # 0 => use the full curated prompt set
MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "256"))  # max_tokens for throughput /v1/completions requests
BUDGET_S = int(os.environ.get("BUDGET_S", "1500")) # hard wall-clock cap (credit guard)
MIN_SERVE_S = int(os.environ.get("MIN_SERVE_S", "300")) # don't start a serve we can't finish
DETERMINISM_REPEATS = int(os.environ.get("DETERMINISM_REPEATS", "0")) # >0 => greedy-determinism probe mode
DRYRUN = os.environ.get("DRYRUN", "") == "1" # local stub mode: skip serving, just measure
PORT = 8000  # local port the vLLM server is started on and every client helper targets
STOP = ["\nclass ", "\ndef ", "\n#", "\nif __name__"]  # stop sequences for completions; also used to trim text before scoring
EXEC_TIMEOUT_S = 8  # default per-candidate subprocess timeout (seconds) for the reward scorer
T0 = time.time()  # import-time timestamp; the wall-clock budget is measured from here
# Throughput prompts deliberately span EASY -> HARD so the A/B is not measured only on trivial
# canonical functions (those over-state the win by pinning acceptance at the γ+1 ceiling).
PROMPTS = [
    # --- trivial canonical (high acceptance: the ceiling case) ---
    "def fib(n):\n \"\"\"Return the n-th Fibonacci number.\"\"\"\n",
    "def is_prime(n):\n \"\"\"Return True iff n is prime.\"\"\"\n",
    "def factorial(n):\n \"\"\"Return n! (n factorial).\"\"\"\n",
    "def reverse_words(s):\n \"\"\"Reverse the order of words in s.\"\"\"\n",
    # --- medium ---
    "def binary_search(arr, target):\n \"\"\"Return the index of target in sorted arr, else -1.\"\"\"\n",
    "def merge_sorted(a, b):\n \"\"\"Merge two sorted lists into one sorted list.\"\"\"\n",
    "def is_balanced(s):\n \"\"\"Return True iff the brackets ()[]{} in s are balanced.\"\"\"\n",
    "def roman_to_int(s):\n \"\"\"Convert a Roman numeral string to an integer.\"\"\"\n",
    "def flatten(nested):\n \"\"\"Flatten an arbitrarily nested list of ints into a flat list.\"\"\"\n",
    # --- harder / branchy / rare-token (acceptance should drop here) ---
    "def lcs(a, b):\n \"\"\"Return the length of the longest common subsequence of strings a and b.\"\"\"\n",
    "def parse_duration(s):\n \"\"\"Parse strings like '1h30m', '45s', '2d' into total seconds. Raise ValueError on bad input.\"\"\"\n",
    "def group_anagrams(words):\n \"\"\"Group words that are anagrams of each other into a list of lists.\"\"\"\n",
    "class LRUCache:\n \"\"\"A fixed-capacity LRU cache with get(key) and put(key, value).\"\"\"\n",
    "def dijkstra(graph, start):\n \"\"\"graph: dict node -> list of (neighbor, weight). Return dict of shortest distances from start.\"\"\"\n",
]
# N <= 0 means "use the curated set exactly once".
N = N if N > 0 else len(PROMPTS)
# Cycle through the curated prompts until exactly N are selected (repeats only when a larger N
# is forced; a smaller N takes a prefix).
PROMPTS = [PROMPTS[i % len(PROMPTS)] for i in range(N)]
# System prompt for the live reward eval. It is spec_rl's system prompt, verbatim, so the eval
# sends the EXACT same instruction the canonical `prime eval run spec_rl` baseline used.
# Do not reword it: the text is sent to the model as-is by reward_eval().
RL_SYSTEM_PROMPT = (
"You are an expert Python programmer. You will be given a function "
"signature and docstring. Complete the function body only. Do not repeat "
"the signature, do not add explanations, and do not wrap the code in "
"markdown fences. Output only the indented function body."
)
def budget_left() -> float:
    """Return the seconds remaining before the BUDGET_S wall-clock cap is hit.

    Elapsed time is measured from T0; the result is negative once the cap has been exceeded.
    """
    elapsed_s = time.time() - T0
    return BUDGET_S - elapsed_s
def serve(dflash: bool, gamma: int = GAMMA) -> subprocess.Popen:
    """Launch the vLLM OpenAI-compatible API server as a child process and return its handle.

    Args:
        dflash: when True, attach the DFlash speculator via --speculative-config.
        gamma: num_speculative_tokens for the speculator; ignored when `dflash` is False.

    Returns:
        The running `subprocess.Popen`. The caller must wait for health and tear it down.
    """
    child_env = dict(os.environ)
    # Laguna is an UNQUANTIZED bf16 MoE. The slim uv image ships only pip CUDA *runtime* wheels —
    # no nvcc/toolkit at /usr/local/cuda. vLLM/FlashInfer lazily JIT-compile several kernels on
    # first use (inside profile_run), each needing nvcc, so each dies "Could not find nvcc".
    # Every FlashInfer JIT path is therefore disabled and prebuilt alternatives are pinned:
    #   - MoE       -> Triton fused-MoE (PTX via Triton). [verified: sm90+sm120 cutlass JIT crash]
    #   - sampler   -> torch top-k/top-p (not FlashInfer). [verified: sampling JIT crash]
    #   - attention -> FLASH_ATTN (prebuilt flash-attn wheel, not FlashInfer JIT).
    child_env.update({
        "VLLM_USE_DEEP_GEMM": "0",
        "VLLM_USE_FLASHINFER_MOE_FP16": "0",
        "VLLM_USE_FLASHINFER_MOE_FP8": "0",
        "VLLM_USE_FLASHINFER_SAMPLER": "0",
    })
    # An attention backend chosen by the caller's environment wins; FLASH_ATTN is the fallback.
    child_env.setdefault("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")

    argv = [
        sys.executable, "-m", "vllm.entrypoints.openai.api_server",
        "--model", MODEL,
        "--port", str(PORT),
        "--tensor-parallel-size", "1",
        "--trust-remote-code",  # Laguna's custom MoE arch needs it in vLLM
        "--enforce-eager",  # skip CUDA-graph capture: leaner + faster start; A/B ratio unaffected
        "--gpu-memory-utilization", "0.9",
        "--max-model-len", os.environ.get("SPECRL_MAX_LEN", "4096"),
        # Cap concurrent sequences low: requests are sequential single requests, and DFlash's
        # draft slots scale with max_num_seqs and compete with the scheduler's token budget. At
        # the default seq count, γ=9 drove max_num_scheduled_tokens to 0 (serve refused to
        # start); a low cap lets γ up to ~11 schedule. Single-stream A/B ratio is unaffected.
        "--max-num-seqs", os.environ.get("MAX_NUM_SEQS", "16"),
        # Laguna's chat template defaults enable_thinking false; pin it so the chat-route reward
        # eval is non-thinking (matches the canonical hosted baseline run; greedy A/B stays clean).
        "--default-chat-template-kwargs", json.dumps({"enable_thinking": False}),
    ]
    # NOTE: base poolside/Laguna-XS.2 loads in bf16 at ~62 GiB (full MoE resident). It fits a
    # 96GB-class GPU (rtx-pro-6000) with room for KV; h200 (141GB) is the safe, best-tested target.
    # The earlier failures were NOT OOM — they were the nvcc/FlashInfer-JIT issue handled above.
    if dflash:
        speculative_cfg = {"model": SPECULATOR, "num_speculative_tokens": gamma, "method": "dflash"}
        argv.extend(["--speculative-config", json.dumps(speculative_cfg)])

    mode = "DFlash(γ=%d)" % gamma if dflash else "baseline"
    print(f"[job] serving {mode}: {' '.join(argv)}", flush=True)
    return subprocess.Popen(argv, env=child_env)
def wait_health(proc: subprocess.Popen, timeout: int = 900) -> None:
    """Block until the server's /health endpoint answers, polling every 5 seconds.

    Args:
        proc: the server process; it is checked on every poll so an early crash fails fast.
        timeout: maximum number of seconds to wait for a healthy response.

    Raises:
        RuntimeError: the server process exited before becoming healthy.
        TimeoutError: no successful /health response within `timeout` seconds.
    """
    url = f"http://localhost:{PORT}/health"
    started = time.time()
    while time.time() - started < timeout:
        if proc.poll() is not None:
            raise RuntimeError("vLLM server exited during startup (check logs above)")
        try:
            # Context manager closes the response (and its socket) instead of leaking it.
            with urllib.request.urlopen(url, timeout=5):
                pass
        except Exception:
            # Not up yet (connection refused, non-2xx, ...): best-effort retry after a pause.
            time.sleep(5)
            continue
        print("[job] server healthy", flush=True)
        return
    raise TimeoutError("server did not become healthy in time")
def _post(path: str, payload: dict) -> dict:
    """POST `payload` as JSON to the local server at `path` and return the decoded JSON reply.

    Uses a 300-second request timeout; errors from urllib or JSON decoding propagate to the caller.
    """
    body = json.dumps(payload).encode()
    request = urllib.request.Request(
        f"http://localhost:{PORT}{path}",
        data=body,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request, timeout=300) as response:
        raw = response.read()
    return json.loads(raw.decode())
def complete(prompt: str) -> tuple[str, float, float]:
    """Run one greedy /v1/completions request and time it end to end.

    Returns:
        (text, tokens_per_second, elapsed_seconds). The token count comes from the response's
        `usage.completion_tokens`; when that is missing or zero it falls back to a
        whitespace-split word count of the returned text.
    """
    request_body = {"model": MODEL, "prompt": prompt,
                    "max_tokens": MAX_TOKENS, "temperature": 0.0, "stop": STOP}
    started = time.time()
    response = _post("/v1/completions", request_body)
    elapsed = time.time() - started

    first_choice = response["choices"][0]
    text = first_choice.get("text", "") or ""
    usage = response.get("usage") or {}
    n_tokens = usage.get("completion_tokens") or len(text.split())
    rate = n_tokens / elapsed if elapsed else 0.0  # guard against a zero-length interval
    return text, rate, elapsed
def chat_complete(messages: list[dict], max_tokens: int = REWARD_MAX_TOKENS) -> str:
    """Greedy chat completion (thinking off), matching the spec_rl eval's chat shape.

    Returns the first choice's message content, or "" when the message or content is absent.
    """
    request_body = {
        "model": MODEL,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": 0.0,
        "chat_template_kwargs": {"enable_thinking": False},
    }
    response = _post("/v1/chat/completions", request_body)
    message = response["choices"][0].get("message") or {}
    return message.get("content") or ""
def _spec_decode_counter_totals(body: str) -> dict[str, float]:
    """Sum the accepted/draft spec-decode counters found in a Prometheus text exposition.

    Metric names are matched exactly (with an optional `_total` suffix), so sibling series that
    merely share the prefix — e.g. `..._per_pos` or `..._created`, when the server exposes them —
    are ignored. Samples with different label sets are summed.
    """
    wanted = ("vllm:spec_decode_num_accepted_tokens", "vllm:spec_decode_num_draft_tokens")
    totals: dict[str, float] = {}
    for raw in body.splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):  # skip blanks and HELP/TYPE comment lines
            continue
        if "{" in line and "}" in line:
            # `name{labels} value [timestamp]`
            name = line[:line.index("{")]
            rest = line[line.rindex("}") + 1:]
        else:
            # `name value [timestamp]`
            parts = line.split(None, 1)
            name = parts[0]
            rest = parts[1] if len(parts) > 1 else ""
        name = name.strip().removesuffix("_total")
        if name not in wanted:
            continue
        fields = rest.split()
        if not fields:
            continue
        try:
            value = float(fields[0])
        except ValueError:
            continue
        totals[name] = totals.get(name, 0.0) + value
    return totals


def tau_from_metrics(gamma: int) -> float | None:
    """Estimate the acceptance length τ from the server's /metrics counters.

    With `passes = draft_tokens / gamma` (number of speculative passes), the estimate is
    `(accepted_tokens + passes) / passes`, i.e. accepted drafts per pass plus one.

    Args:
        gamma: draft length the server was started with.

    Returns:
        The τ estimate, or None when `gamma` is not positive, /metrics cannot be fetched, or the
        counters are missing or show no draft tokens.
    """
    if gamma <= 0:  # baseline / invalid γ: no speculative passes to divide by
        return None
    try:
        with urllib.request.urlopen(f"http://localhost:{PORT}/metrics", timeout=10) as r:
            body = r.read().decode()
    except Exception:
        # τ is a best-effort diagnostic; an unreachable /metrics must not fail the measurement.
        return None
    totals = _spec_decode_counter_totals(body)
    acc = totals.get("vllm:spec_decode_num_accepted_tokens")
    draft = totals.get("vllm:spec_decode_num_draft_tokens")
    if acc is None or not draft or draft <= 0:
        return None
    passes = draft / gamma
    return (acc + passes) / passes
def measure(dflash: bool, gamma: int = GAMMA) -> dict:
    """Measure decode throughput over the mixed prompt set against the running server.

    Stops early (keeping what was measured) once fewer than 120 seconds of budget remain.
    τ is recorded for completeness only and is never quoted as a headline.

    Returns:
        A result dict with the label, model, sample count, γ (None for baseline), mean tok/s,
        mean full-completion latency, the recorded τ, and the raw completion texts.
    """
    outputs, rates, latencies = [], [], []
    for prompt in PROMPTS:
        if budget_left() < 120:
            print("[job] budget guard hit — stopping measure early", flush=True)
            break
        text, tok_per_s, elapsed = complete(prompt)
        outputs.append(text)
        rates.append(tok_per_s)
        latencies.append(elapsed)

    label = ("dflash_g%d" % gamma) if dflash else "baseline"
    mean_rate = sum(rates) / len(rates) if rates else 0.0
    mean_latency = sum(latencies) / len(latencies) if latencies else 0.0  # full-completion latency, NOT true TTFT
    tau = tau_from_metrics(gamma) if dflash else 1.0  # recorded, NOT quoted
    return {
        "label": label,
        "model": MODEL,
        "n": len(outputs),
        "gamma": gamma if dflash else None,
        "tokens_per_s_mean": mean_rate,
        "latency_s_mean": mean_latency,
        "acceptance_length_tau": tau,
        "texts": outputs,
    }
# --------------------------------------------------------------------------- #
# Reward core — copied VERBATIM from environments/spec_rl/spec_rl.py so the live
# reward number is computed by the identical scorer the canonical eval used.
# --------------------------------------------------------------------------- #
def load_problems(num_examples: int) -> list[dict]:
    """Return the first `num_examples` problems as dicts with prompt, test and entry_point.

    Source selection:
      * If SPEC_RL_DATASET names an existing `.jsonl` file, read it (the dry-run seam); blank
        lines are skipped and each remaining line is parsed as one JSON object.
      * Otherwise load a dataset through `datasets.load_dataset`, using SPEC_RL_DATASET (if set)
        or HUMANEVAL_DATASET (default "openai/openai_humaneval") as the id, and
        SPEC_RL_DATASET_SPLIT (default "test") as the split.

    The JSONL file is read as UTF-8 explicitly so parsing does not depend on the host locale.
    """
    src = os.environ.get("SPEC_RL_DATASET")
    if src and src.endswith(".jsonl") and os.path.exists(src):
        with open(src, encoding="utf-8") as f:
            rows = [json.loads(line) for line in f if line.strip()]
        return rows[:num_examples]
    # Imported lazily so the local dry-run path works without the `datasets` package.
    from datasets import load_dataset
    dataset_id = src or os.environ.get("HUMANEVAL_DATASET", "openai/openai_humaneval")
    split = os.environ.get("SPEC_RL_DATASET_SPLIT", "test")
    ds = load_dataset(dataset_id, split=split)
    num_examples = min(num_examples, len(ds))
    return [dict(ds[i]) for i in range(num_examples)]
class _AssertCounter(ast.NodeTransformer):
    """AST pass that makes every `assert` COUNTED instead of fatal.

    Each assert becomes three statements: evaluate the condition into `__ok` (any exception
    counts as a failure), bump `__tally['total']`, and bump `__tally['passed']` when `__ok` is
    true. This turns HumanEval's all-or-nothing check() into a fractional pass rate.
    (Same transformation as spec_rl.py.)
    """

    def visit_Assert(self, node: ast.Assert):
        def store_ok(value):
            # `__ok = <value>`
            return ast.Assign(targets=[ast.Name(id="__ok", ctx=ast.Store())], value=value)

        condition = ast.Call(func=ast.Name(id="bool", ctx=ast.Load()),
                             args=[node.test], keywords=[])
        on_error = ast.ExceptHandler(type=ast.Name(id="BaseException", ctx=ast.Load()),
                                     name=None,
                                     body=[store_ok(ast.Constant(value=False))])
        guarded_eval = ast.Try(body=[store_ok(condition)], handlers=[on_error],
                               orelse=[], finalbody=[])
        bump_total = ast.parse("__tally['total'] += 1").body[0]
        bump_passed = ast.parse("if __ok:\n __tally['passed'] += 1").body[0]

        replacement = [guarded_eval, bump_total, bump_passed]
        for stmt in replacement:
            # Give the synthesized statements the location of the assert they replace.
            ast.copy_location(stmt, node)
            ast.fix_missing_locations(stmt)
        return replacement
def passes(problem: dict, completion: str, timeout_s: int = EXEC_TIMEOUT_S) -> bool:
    """All-or-nothing check: run prompt + completion + test + check(entry_point) in a subprocess.

    Args:
        problem: dict with "prompt", "test" and "entry_point".
        completion: candidate code appended directly after the prompt.
        timeout_s: wall-clock limit for the candidate process.

    Returns:
        True iff the program exits with status 0 within the timeout.

    The candidate file is written as UTF-8 explicitly: Python reads source files as UTF-8, so
    relying on the locale's default encoding could corrupt non-ASCII prompts or completions.
    """
    program = problem["prompt"] + completion + "\n" + problem["test"] + f"\ncheck({problem['entry_point']})\n"
    with tempfile.TemporaryDirectory() as tmp:
        prog_path = Path(tmp) / "candidate.py"
        prog_path.write_text(program, encoding="utf-8")
        try:
            result = subprocess.run([sys.executable, str(prog_path)], capture_output=True,
                                    text=True, timeout=timeout_s, cwd=tmp)
        except subprocess.TimeoutExpired:
            return False
    return result.returncode == 0
def fraction_passing(problem: dict, completion: str, timeout_s: int = EXEC_TIMEOUT_S) -> float:
    """Dense reward: the fraction of the test's asserts that the candidate satisfies.

    The test source is instrumented with `_AssertCounter`, executed in a subprocess together
    with prompt + completion, and the tally is read back from a `__FRAC__`-prefixed stdout line.

    Args:
        problem: dict with "prompt", "test" and "entry_point".
        completion: candidate code appended directly after the prompt.
        timeout_s: wall-clock limit for the candidate process.

    Returns:
        A value in [0.0, 1.0]. Falls back to the all-or-nothing `passes()` result when the test
        cannot be parsed or unparsed; 0.0 on timeout, on a missing/unreadable tally line; when no
        assert was counted, 1.0 iff the process exited with status 0.

    The scoring logic follows spec_rl.py; the only deviation is that the candidate file is
    written as UTF-8 explicitly (Python reads source as UTF-8, independent of the locale).
    """
    def all_or_nothing() -> float:
        return 1.0 if passes(problem, completion, timeout_s) else 0.0

    try:
        tree = ast.parse(problem["test"])
    except SyntaxError:
        return all_or_nothing()
    tree = _AssertCounter().visit(tree)
    ast.fix_missing_locations(tree)
    try:
        instrumented_test = ast.unparse(tree)
    except Exception:
        return all_or_nothing()

    # check() runs inside try/except so a crash mid-test still reaches the tally print.
    program = (
        "__tally = {'passed': 0, 'total': 0}\n"
        + problem["prompt"] + completion + "\n" + instrumented_test + "\n"
        + "try:\n" + f" check({problem['entry_point']})\n"
        + "except BaseException:\n pass\n"
        + "import json as __json\nprint('__FRAC__' + __json.dumps(__tally))\n")
    with tempfile.TemporaryDirectory() as tmp:
        prog_path = Path(tmp) / "candidate.py"
        prog_path.write_text(program, encoding="utf-8")
        try:
            result = subprocess.run([sys.executable, str(prog_path)], capture_output=True,
                                    text=True, timeout=timeout_s, cwd=tmp)
        except subprocess.TimeoutExpired:
            return 0.0

    # The first `__FRAC__` line decides the score.
    for line in result.stdout.splitlines():
        if not line.startswith("__FRAC__"):
            continue
        try:
            tally = json.loads(line[len("__FRAC__"):])
            total = int(tally.get("total", 0))
            passed = int(tally.get("passed", 0))
        except Exception:
            return 0.0
        if total == 0:
            return 1.0 if result.returncode == 0 else 0.0
        return max(0.0, min(1.0, passed / total))
    return 0.0
def score_completion(problem: dict, completion_text: str) -> float:
    """Echo-aware dense reward, following spec_rl._score_completion.

    Markdown code fences are stripped first. If the model re-emitted the `def <entry>` signature
    (the chat shape), the echoed function replaces the prompt's own definition: everything in the
    prompt before the signature is kept as preamble and the echoed source is cut at the first
    trailing marker. Otherwise the text is treated as a bare body and cut at the STOP sequences.
    """
    def cut_at(source: str, markers) -> str:
        # Truncating at each marker in turn leaves the text before the earliest marker present.
        for marker_text in markers:
            head, found, _ = source.partition(marker_text)
            if found:
                source = head
        return source

    entry = problem["entry_point"]
    cleaned = (completion_text or "").replace("```python", "").replace("```", "")
    signature = f"def {entry}"

    if signature not in cleaned:
        return fraction_passing(problem, cut_at(cleaned, STOP))

    preamble = problem["prompt"].split(signature, 1)[0]
    echoed = cleaned[cleaned.find(signature):]
    echoed = cut_at(echoed, ("\n</", "\nif __name__", "\n#", "\nclass "))
    rebuilt = {"prompt": preamble, "test": problem["test"], "entry_point": entry}
    return fraction_passing(rebuilt, echoed)
def reward_eval(label: str) -> dict:
    """Run the reward eval against the live server and score each completion.

    Loads the first REWARD_N problems, sends each one through the chat route (greedy, thinking
    off) with the spec_rl system prompt, and scores the reply with `score_completion`. Stops
    early once fewer than 60 seconds of budget remain.

    Returns:
        {"label", "n", "mean_reward", "per_rollout_reward", "texts"}; `mean_reward` is None when
        no problem was scored. Rewards are rounded to 4 decimals.
    """
    scores, replies = [], []
    for problem in load_problems(REWARD_N):
        if budget_left() < 60:
            print("[job] budget guard hit — stopping reward eval early", flush=True)
            break
        conversation = [
            {"role": "system", "content": RL_SYSTEM_PROMPT},
            {"role": "user", "content": problem["prompt"]},
        ]
        reply = chat_complete(conversation)
        scores.append(round(score_completion(problem, reply), 4))
        replies.append(reply)

    mean_reward = round(sum(scores) / len(scores), 4) if scores else None
    return {
        "label": label,
        "n": len(scores),
        "mean_reward": mean_reward,
        "per_rollout_reward": scores,
        "texts": replies,
    }
def run_phase(dflash: bool, gamma: int, do_reward: bool) -> "tuple[dict, dict | None]":
    """Serve once, measure decode tok/s, optionally run the reward eval, then tear the server down.

    Args:
        dflash: serve with the DFlash speculator (True) or the plain baseline (False).
        gamma: draft length for the DFlash serve.
        do_reward: also run the reward eval on this server.

    Returns:
        (measurement, reward). `reward` is None when `do_reward` is False, or an
        {"error": ...} dict when the reward eval raised (it is never allowed to fail the phase).

    Raises:
        Whatever `wait_health` or `measure` raise; the server is torn down in every case.

    In DRYRUN mode no server is started and the measurements hit whatever already listens on the
    port.
    """
    proc = None if DRYRUN else serve(dflash, gamma)
    try:
        if proc is not None:
            wait_health(proc)
        m = measure(dflash, gamma)
        rw = None
        if do_reward:
            try:
                rw = reward_eval(("dflash_g%d" % gamma) if dflash else "baseline")
            except Exception as e:  # never let the reward eval tank the sweep
                rw = {"error": f"{type(e).__name__}: {e}"}
                print(f"[job] reward_eval failed (non-fatal): {rw['error']}", flush=True)
        return m, rw
    finally:
        if proc is not None:
            proc.terminate()
            try:
                proc.wait(timeout=30)
            except subprocess.TimeoutExpired:
                proc.kill()
                try:
                    # Reap the killed server so it does not linger as a zombie. The wait is
                    # bounded so a process that ignores SIGKILL cannot stall the job.
                    proc.wait(timeout=30)
                except subprocess.TimeoutExpired:
                    print("[job] server did not exit after kill; continuing", flush=True)
            # Short pause before the next serve — presumably to let the port and GPU memory be
            # released (NOTE(review): confirm intent).
            time.sleep(5)
def _expose_wheel_nvcc() -> None:
    """Safety net: make the pip `nvidia-cuda-nvcc` wheel's compiler discoverable.

    Does nothing when `nvcc` is already on PATH or /usr/local/cuda exists. Otherwise it looks for
    `nvidia/cuda_nvcc/bin/nvcc` under each site-packages root (plus the directory two levels
    above this file) and, on the first hit, points CUDA_HOME/CUDA_PATH at it and prepends its
    bin directory to PATH, so any residual FlashInfer JIT can compile instead of hard-failing.
    Pure belt-and-suspenders: serve() disables the FlashInfer JIT paths.
    """
    import shutil
    import site

    toolkit_present = shutil.which("nvcc") or os.path.isdir("/usr/local/cuda")
    if toolkit_present:
        return

    try:
        search_roots = list(site.getsitepackages())
    except Exception:
        search_roots = []
    search_roots.append(os.path.dirname(os.path.dirname(__file__)))

    for base in search_roots:
        wheel_dir = os.path.join(base, "nvidia", "cuda_nvcc")
        if not os.path.exists(os.path.join(wheel_dir, "bin", "nvcc")):
            continue
        os.environ["CUDA_HOME"] = wheel_dir
        os.environ["CUDA_PATH"] = wheel_dir
        os.environ["PATH"] = os.path.join(wheel_dir, "bin") + ":" + os.environ.get("PATH", "")
        print(f"[job] exposed wheel nvcc (CUDA_HOME={wheel_dir})", flush=True)
        return
    print("[job] no wheel nvcc found to expose (FlashInfer JIT paths are disabled anyway)", flush=True)
def _parity(base_texts: list[str], texts: list[str]) -> dict:
    """Compare two lists of outputs pairwise for exact (byte-for-byte) equality.

    Only the common prefix is compared; `compared` reports how many pairs that was.

    Returns:
        {"compared": pairs compared, "mismatches": differing pairs, "lossless": bool}.
        `lossless` is True only when at least one pair was compared and none differed — an
        empty comparison is not evidence of parity, so it reports False.
    """
    n = min(len(base_texts), len(texts))
    mism = sum(1 for a, b in zip(base_texts, texts) if a != b)
    return {"compared": n, "mismatches": mism, "lossless": n > 0 and mism == 0}
def run_determinism(repeats: int) -> int:
    """Greedy-determinism probe: serve the baseline ONCE and repeat the reward eval on it.

    Runs the spec_rl reward eval up to `repeats` times on the SAME engine (stopping early when
    fewer than 60 seconds of budget remain) and reports the per-run mean reward plus the
    cross-run completion divergence against run 1. If two greedy runs of the same model on the
    same prompts differ, the 1.0-vs-0.85 reward gap seen in the DFlash A/B is run-to-run MoE
    nondeterminism (FP non-associativity), NOT a DFlash quality change — which closes the
    reward-invariance claim honestly (invariance holds by construction; the live number just
    isn't bit-stable).

    Side effects: prints DET_RUN_*_JSON / DETERMINISM_JSON lines and writes
    results/determinism_check.json.

    Returns:
        0 (process exit status).
    """
    proc = None if DRYRUN else serve(dflash=False, gamma=0)
    try:
        if proc is not None:
            wait_health(proc)
        runs = []
        for i in range(repeats):
            if budget_left() < 60:
                print("[job] budget guard — stopping determinism repeats early", flush=True)
                break
            rw = reward_eval(f"baseline_run{i + 1}")
            runs.append(rw)
            print(f"[job] DET_RUN_{i + 1}_JSON " + json.dumps({k: v for k, v in rw.items() if k != "texts"}), flush=True)
        means = [r["mean_reward"] for r in runs]
        base_texts = runs[0]["texts"] if runs else []
        run_vs_run1 = [_parity(base_texts, later["texts"]) for later in runs[1:]]
        det = {
            "repeats": len(runs),
            "per_run_mean_reward": means,
            "per_run_reward": [r["per_rollout_reward"] for r in runs],
            "run_vs_run1_parity": run_vs_run1,
            # None when fewer than two runs completed: nothing to compare against.
            "greedy_bit_reproducible": (all(d["mismatches"] == 0 for d in run_vs_run1) and len(set(means)) <= 1)
            if run_vs_run1 else None,
            "note": ("If per_run_mean_reward varies OR run_vs_run1_parity shows mismatches, greedy decoding is "
                     "NOT bit-reproducible run-to-run on this MoE — so the DFlash A/B's 1.0-vs-0.85 reward gap is "
                     "nondeterminism noise, not a DFlash quality change. Reward-invariance holds by construction "
                     "(lossless decode => identical reward); we decline to quote a DFlash reward, like tau."),
        }
        print("[job] DETERMINISM_JSON " + json.dumps(det), flush=True)
        os.makedirs("results", exist_ok=True)
        # Context manager guarantees the evidence file is flushed and closed.
        with open("results/determinism_check.json", "w") as fh:
            json.dump(det, fh, indent=2)
        return 0
    finally:
        if proc is not None:
            proc.terminate()
            try:
                proc.wait(timeout=30)
            except subprocess.TimeoutExpired:
                proc.kill()
                try:
                    # Reap the killed server; bounded so a stuck process cannot stall the job.
                    proc.wait(timeout=30)
                except subprocess.TimeoutExpired:
                    print("[job] server did not exit after kill; continuing", flush=True)
def _dump_json(path: str, payload) -> None:
    """Write `payload` to `path` as indented JSON, closing the file deterministically."""
    with open(path, "w") as fh:
        json.dump(payload, fh, indent=2)


def main() -> int:
    """Run the whole job: baseline, γ-sweep, reward-invariance, consolidated summary.

    Flow:
      1. If DETERMINISM_REPEATS > 0, run only the greedy-determinism probe.
      2. Measure the baseline ONCE (γ-independent) together with the baseline reward eval.
      3. Serve DFlash once per γ (REWARD_GAMMA first), measuring tok/s and byte parity against
         the baseline; the reward eval runs only at REWARD_GAMMA.
      4. Print and persist a consolidated summary.

    Evidence is printed and written under results/ after every phase, so a late failure cannot
    erase earlier results.

    Returns:
        0 (process exit status).
    """
    print(f"[job] start; budget {BUDGET_S}s; N={N}; gammas={GAMMAS}; reward_n={REWARD_N}; "
          f"reward_gamma={REWARD_GAMMA}; model={MODEL}; dryrun={DRYRUN}; det_repeats={DETERMINISM_REPEATS}", flush=True)
    _expose_wheel_nvcc()
    os.makedirs("results", exist_ok=True)
    if DETERMINISM_REPEATS > 0:
        return run_determinism(DETERMINISM_REPEATS)

    # 1) Baseline ONCE (γ-independent): decode tok/s + the baseline reward eval. Persist immediately.
    base, base_reward = run_phase(dflash=False, gamma=0, do_reward=True)
    base_tps = base["tokens_per_s_mean"]
    print("[job] BASELINE_JSON " + json.dumps({k: v for k, v in base.items() if k != "texts"}), flush=True)
    _dump_json("results/baseline.json", base)

    # 2) γ-sweep. Process REWARD_GAMMA first so the headline parity + reward-invariance land early.
    #    DURABILITY: each γ is isolated in try/except, and we print + persist after EVERY phase —
    #    so a serve that refuses to start at a high γ (scheduler-budget config error) records a
    #    skipped point and the run CONTINUES, and a late crash can never erase earlier evidence.
    order = ([REWARD_GAMMA] if REWARD_GAMMA in GAMMAS else []) + [g for g in GAMMAS if g != REWARD_GAMMA]
    sweep, reward_inv = [], None
    for g in order:
        if budget_left() < MIN_SERVE_S:
            print(f"[job] budget guard: skipping γ={g} (only {budget_left():.0f}s left)", flush=True)
            continue
        try:
            dfl, dfl_reward = run_phase(dflash=True, gamma=g, do_reward=(g == REWARD_GAMMA))
        except Exception as e:
            print(f"[job] γ={g} serve FAILED (non-fatal, continuing): {type(e).__name__}: {e}", flush=True)
            sweep.append({"gamma": g, "dflash_tps": None, "speedup_vs_baseline": None,
                          "parity": None, "error": f"{type(e).__name__}: {e}"})
            _dump_json("results/gamma_sweep.json",
                       {"baseline_tps": round(base_tps, 3), "gamma_sweep": sweep})
            continue
        parity = _parity(base["texts"], dfl["texts"])
        entry = {"gamma": g, "dflash_tps": dfl["tokens_per_s_mean"],
                 "speedup_vs_baseline": round(dfl["tokens_per_s_mean"] / base_tps, 3) if base_tps else None,
                 "parity": parity, "tau_recorded": dfl["acceptance_length_tau"]}
        sweep.append(entry)
        print("[job] GAMMA_POINT " + json.dumps(entry), flush=True)
        _dump_json("results/gamma_sweep.json",
                   {"baseline_tps": round(base_tps, 3), "gamma_sweep": sweep})  # persist after every point
        # Only compare rewards when BOTH evals actually produced results; an errored baseline
        # eval must not be reported as a reward difference.
        if (g == REWARD_GAMMA and base_reward and dfl_reward
                and "error" not in base_reward and "error" not in dfl_reward):
            reward_parity = _parity(base_reward.get("texts", []), dfl_reward.get("texts", []))
            reward_inv = {
                "n": dfl_reward.get("n"),
                "baseline_mean_reward": base_reward.get("mean_reward"),
                "dflash_mean_reward": dfl_reward.get("mean_reward"),
                "reward_invariant": base_reward.get("mean_reward") == dfl_reward.get("mean_reward"),
                "eval_byte_parity": reward_parity,
                "baseline_per_rollout": base_reward.get("per_rollout_reward"),
                "dflash_per_rollout": dfl_reward.get("per_rollout_reward"),
            }
            _dump_json("results/reward_invariance.json", reward_inv)  # persist NOW
            print("[job] REWARD_INVARIANCE_JSON " + json.dumps(reward_inv), flush=True)  # emit NOW

    # Consolidated summary (gamma_star ignores any failed/None points).
    ok = [e for e in sweep if e.get("dflash_tps")]
    sweep_sorted = sorted(ok, key=lambda e: e["dflash_tps"], reverse=True)
    gamma_star = sweep_sorted[0] if sweep_sorted else None
    g7 = next((e for e in ok if e["gamma"] == 7), None)
    gamma_star_vs_g7 = (round(gamma_star["dflash_tps"] / g7["dflash_tps"], 3)
                        if gamma_star and g7 and g7["dflash_tps"] else None)
    all_lossless = all(e["parity"]["lossless"] for e in ok) if ok else None
    summary = {
        "baseline_tps": round(base_tps, 3),
        "gamma_sweep": sweep,
        "gamma_star": gamma_star["gamma"] if gamma_star else None,
        "gamma_star_tps": round(gamma_star["dflash_tps"], 3) if gamma_star else None,
        "gamma_star_speedup_vs_g7": gamma_star_vs_g7,
        "all_points_lossless": all_lossless,
        "reward_invariance": reward_inv,
        "elapsed_s": round(time.time() - T0, 1),
    }
    print("[job] RESULT " + json.dumps(summary), flush=True)
    _dump_json("results/gamma_sweep.json", summary)
    print("[job] SWEEP_JSON " + json.dumps(summary), flush=True)
    if reward_inv:
        print("[job] REWARD_INVARIANCE_JSON " + json.dumps(reward_inv), flush=True)
    if base_reward:
        print("[job] SAMPLE_REWARD_TEXT " + json.dumps((base_reward.get("texts") or [""])[:1]), flush=True)
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())