| |
| """ |
| Benchmark GPT 5.4 against the shader environment via WebSocket. |
| |
| Connects to a running shader environment server and runs a multi-turn |
| agent loop where GPT 5.4 tries to reproduce each reference image in GLSL. |
| |
| Usage: |
| # Start the server first: |
| # uvicorn server.app:app --host 0.0.0.0 --port 8000 |
| # OR: docker run -p 8000:8000 shader |
| |
| python envs/shader/benchmark.py # run 3 episodes |
| python envs/shader/benchmark.py --turns 5 # cap turns |
| python envs/shader/benchmark.py --url ws://localhost:8001/ws # custom server |
| """ |
|
|
| import argparse |
| import asyncio |
| import base64 |
| import json |
| import os |
| import time |
| from pathlib import Path |
|
|
| import websockets |
| from openai import OpenAI |
|
|
| |
| |
| |
# Where per-episode artifacts (reference/turn PNGs, results.json) are written.
OUTPUT_DIR = Path(__file__).resolve().parent / "benchmark_output"


# Fail fast with an actionable message instead of an opaque KeyError when the
# API key is missing from the environment.
_api_key = os.environ.get("OPENAI_API_KEY")
if not _api_key:
    raise RuntimeError(
        "OPENAI_API_KEY environment variable is not set; "
        "export it before running the benchmark."
    )

_client_kwargs = {"api_key": _api_key}
# Optional override for proxies / compatible gateways.
if os.environ.get("OPENAI_BASE_URL"):
    _client_kwargs["base_url"] = os.environ["OPENAI_BASE_URL"]

CLIENT = OpenAI(**_client_kwargs)
MODEL = "gpt-5.4"

# System-style instructions sent with every Responses API call.
INSTRUCTIONS = """\
You are a GLSL shader expert. Your task is to write a Shadertoy-dialect \
GLSL fragment shader that reproduces the given reference image as closely \
as possible.

Rules:
- Write a `void mainImage(out vec4 fragColor, in vec2 fragCoord)` function.
- You may use standard Shadertoy uniforms: iResolution, iTime, iTimeDelta, \
iFrame, iMouse, iDate, iSampleRate.
- Do NOT include #version, precision, or #extension directives.
- Output ONLY the raw GLSL code — no markdown fencing, no explanation.

The rendered output is compared to the reference via SSIM (structural \
similarity). Target: SSIM >= 0.99."""
|
|
|
|
| |
| |
| |
|
|
def strip_fences(text: str) -> str:
    """Strip a surrounding markdown code fence from *text*, if present.

    The opening fence line (including any language tag) is always dropped;
    the last line that is exactly ``` is treated as the closing fence. If
    no closing fence exists, everything after the opening line is kept.
    """
    cleaned = text.strip()
    if not cleaned.startswith("```"):
        return cleaned
    rows = cleaned.split("\n")
    # Scan backwards (excluding the opening line) for the closing fence.
    closing = next(
        (i for i in range(len(rows) - 1, 0, -1) if rows[i].strip() == "```"),
        None,
    )
    body = rows[1:closing] if closing is not None else rows[1:]
    return "\n".join(body)
|
|
|
|
def extract_text(response) -> str:
    """Return the first output_text block from a Responses API object.

    Non-message items (e.g. reasoning) and non-text content blocks are
    skipped; an empty string is returned when no text block exists.
    """
    text_blocks = (
        block.text
        for item in response.output
        if item.type == "message"
        for block in item.content
        if block.type == "output_text"
    )
    return next(text_blocks, "")
|
|
|
|
def save_b64_png(b64: str, path: Path):
    """Decode *b64* (base64-encoded PNG data) and write the bytes to *path*."""
    raw = base64.b64decode(b64)
    with open(path, "wb") as fh:
        fh.write(raw)
|
|
|
|
| |
| |
| |
|
|
async def ws_send(ws, msg_type: str, data: dict) -> dict:
    """Send one typed JSON message over *ws* and await the reply's data.

    Raises RuntimeError when the server answers with an ``error`` frame.
    """
    payload = json.dumps({"type": msg_type, "data": data})
    await ws.send(payload)
    reply = json.loads(await ws.recv())
    if reply.get("type") == "error":
        raise RuntimeError(f"Server error: {reply.get('data', {})}")
    return reply["data"]
|
|
|
|
| |
| |
| |
|
|
async def run_episode(ws, seed: int, episode_dir: Path, max_turns: int) -> dict:
    """Run one benchmark episode over the WebSocket connection *ws*.

    Resets the environment with *seed*, then loops up to *max_turns* model
    calls, submitting each generated shader to the server and feeding the
    score/errors (plus the rendered frame, when available) back to the
    model as the next user turn.

    Returns a dict with the task name, seed, per-turn records, best SSIM
    reached, and whether the episode was solved (SSIM >= 0.99).
    """

    # Reset the environment; the first observation carries the task name,
    # the base64-encoded reference image, and the remaining turn budget.
    data = await ws_send(ws, "reset", {"seed": seed})
    obs = data["observation"]
    task = obs["task"]
    ref_b64 = obs["reference_png"]
    remaining = obs["remaining"]

    print(f" task: {task}, budget: {remaining}")

    # Persist the reference image alongside the per-turn renders.
    episode_dir.mkdir(parents=True, exist_ok=True)
    save_b64_png(ref_b64, episode_dir / "reference.png")

    # Seed the conversation with the request plus the reference image so
    # the model can see what it must reproduce.
    conversation = [
        {
            "role": "user",
            "content": [
                {
                    "type": "input_text",
                    "text": "Write a GLSL shader that reproduces this reference image exactly.",
                },
                {
                    "type": "input_image",
                    "image_url": f"data:image/png;base64,{ref_b64}",
                },
            ],
        }
    ]

    # One record per completed turn (see turn_data below).
    results = []

    for turn in range(1, max_turns + 1):
        print(f" turn {turn}/{max_turns} ...", end=" ", flush=True)

        # One model call per turn; timed so the per-turn record can report
        # API latency separately from server-side evaluation.
        t0 = time.time()
        resp = CLIENT.responses.create(
            model=MODEL,
            instructions=INSTRUCTIONS,
            input=conversation,
            max_output_tokens=8192,
            temperature=0.2,
        )
        api_s = time.time() - t0

        # The model is told to emit raw GLSL, but strip markdown fences
        # defensively in case it wraps the code anyway.
        raw = extract_text(resp)
        code = strip_fences(raw)

        # Submit the shader; the server compiles, renders, and scores it.
        data = await ws_send(ws, "step", {"code": code})
        obs = data["observation"]
        reward = data["reward"]
        done = data["done"]

        compiled = obs["compiled"]
        rendered = obs["rendered"]
        ssim = obs["ssim"]
        errors = obs["errors"]

        # Save the rendered frame when the server produced one.
        if obs.get("agent_png"):
            save_b64_png(obs["agent_png"], episode_dir / f"turn_{turn}.png")

        turn_data = {
            "turn": turn,
            "ssim": round(ssim, 6),
            "reward": reward,
            "compiled": compiled,
            "rendered": rendered,
            "errors": errors,
            "api_seconds": round(api_s, 1),
            "code_len": len(code),
        }
        results.append(turn_data)

        # Console status line for this turn.
        if not compiled:
            status = "COMPILE_FAIL"
        elif not rendered:
            status = "RENDER_FAIL"
        else:
            status = f"ssim={ssim:.4f}"
        print(f"{status} reward={reward} ({api_s:.1f}s)")

        # The server signals `done` on success or budget exhaustion.
        if done:
            if ssim >= 0.99:
                print(f" => SOLVED on turn {turn}")
            else:
                print(f" => budget exhausted")
            break

        # Extend the conversation: the model's own shader as the assistant
        # turn, then the evaluation feedback as the next user turn.
        conversation.append({"role": "assistant", "content": code})

        feedback_parts = []
        if not compiled:
            feedback_parts.append(
                "Compilation FAILED.\nErrors:\n" + "\n".join(errors)
            )
        elif not rendered:
            feedback_parts.append(
                "Render FAILED.\nErrors:\n" + "\n".join(errors)
            )
        else:
            feedback_parts.append(f"SSIM: {ssim:.4f} (need >= 0.99).")
        feedback_parts.append(
            "Below is your current render vs the reference. "
            "Fix the differences. Output ONLY raw GLSL code."
        )

        feedback_content = [
            {"type": "input_text", "text": "\n".join(feedback_parts)}
        ]

        # Attach the model's own render so it can compare visually.
        if obs.get("agent_png"):
            feedback_content.append(
                {"type": "input_image", "image_url": f"data:image/png;base64,{obs['agent_png']}"}
            )

        conversation.append({"role": "user", "content": feedback_content})

    # NOTE(review): assumes max_turns >= 1 — with an empty results list,
    # max() here would raise ValueError. Confirm callers never pass 0.
    return {
        "task": task,
        "seed": seed,
        "turns": results,
        "best_ssim": max(r["ssim"] for r in results),
        "solved": any(r["ssim"] >= 0.99 for r in results),
    }
|
|
|
|
| |
| |
| |
|
|
async def run(url: str, seeds: list[int], max_turns: int):
    """Benchmark one episode per seed against the server at *url*.

    Per-episode artifacts land under OUTPUT_DIR/episode_N/ and an
    aggregate results.json is written at the end.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    all_results = {}

    # A single WebSocket connection is reused across every episode.
    async with websockets.connect(url) as ws:
        for idx, seed in enumerate(seeds, start=1):
            label = f"episode_{idx}"
            banner = "=" * 60
            print(f"\n{banner}")
            print(f" [{label.upper()}] seed={seed}")
            print(f"{banner}")

            all_results[label] = await run_episode(
                ws, seed, OUTPUT_DIR / label, max_turns,
            )

    # Human-readable summary of every episode.
    banner = "=" * 60
    print(f"\n{banner}")
    print(" SUMMARY")
    print(f"{banner}")
    for label, data in all_results.items():
        solved = "YES" if data["solved"] else "no"
        summary = (
            f" {label}: task={data['task']} best_ssim={data['best_ssim']:.4f} "
            f"solved={solved} turns={len(data['turns'])}"
        )
        print(summary)

    # Persist the aggregate results for later analysis.
    out_path = OUTPUT_DIR / "results.json"
    out_path.write_text(json.dumps(all_results, indent=2))
    print(f"\nResults saved to {out_path}")
|
|
|
|
def main():
    """CLI entry point: parse arguments and launch the benchmark loop."""
    cli = argparse.ArgumentParser(description="Benchmark GPT 5.4 on shader env")
    cli.add_argument(
        "--url",
        default="ws://localhost:8000/ws",
        help="WebSocket URL of the shader environment",
    )
    cli.add_argument("--turns", type=int, default=10, help="Max turns per episode")
    cli.add_argument(
        "--seeds",
        type=int,
        nargs="+",
        default=[1, 2, 3],
        help="Seeds for reproducible task selection (one episode per seed)",
    )
    opts = cli.parse_args()

    asyncio.run(run(opts.url, opts.seeds, opts.turns))
|
|
|
|
# Script entry point: only run the benchmark when executed directly.
if __name__ == "__main__":
    main()
|
|