#!/usr/bin/env python3
"""
Benchmark GPT 5.4 against the shader environment via WebSocket.

Connects to a running shader environment server and runs a multi-turn agent
loop where GPT 5.4 tries to reproduce each reference image in GLSL.

Environment:
    OPENAI_API_KEY   required (read at import time)
    OPENAI_BASE_URL  optional; overrides the OpenAI endpoint

Usage:
    # Start the server first:
    #   uvicorn server.app:app --host 0.0.0.0 --port 8000
    # OR: docker run -p 8000:8000 shader

    python envs/shader/benchmark.py                               # run 3 episodes
    python envs/shader/benchmark.py --turns 5                     # cap turns
    python envs/shader/benchmark.py --url ws://localhost:8001/ws  # custom server
    python envs/shader/benchmark.py --seeds 7 8                   # pick seeds
"""

import argparse
import asyncio
import base64
import json
import os
import time
from pathlib import Path

import websockets
from openai import OpenAI

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------

OUTPUT_DIR = Path(__file__).resolve().parent / "benchmark_output"

# ---------------------------------------------------------------------------
# Benchmark settings
# ---------------------------------------------------------------------------

# SSIM score at which this script reports an episode as solved. Keep in sync
# with the target quoted in INSTRUCTIONS and in the per-turn feedback text.
SSIM_TARGET = 0.99

# websockets rejects incoming messages larger than 1 MiB by default. Reset and
# step observations carry base64-encoded PNGs, which can exceed that limit.
WS_MAX_MESSAGE_BYTES = 64 * 2**20

# ---------------------------------------------------------------------------
# OpenAI client (Responses API)
# ---------------------------------------------------------------------------

_client_kwargs = {"api_key": os.environ["OPENAI_API_KEY"]}
if os.environ.get("OPENAI_BASE_URL"):
    _client_kwargs["base_url"] = os.environ["OPENAI_BASE_URL"]
CLIENT = OpenAI(**_client_kwargs)

MODEL = "gpt-5.4"

INSTRUCTIONS = """\
You are a GLSL shader expert. Your task is to write a Shadertoy-dialect \
GLSL fragment shader that reproduces the given reference image as closely \
as possible.

Rules:
- Write a `void mainImage(out vec4 fragColor, in vec2 fragCoord)` function.
- You may use standard Shadertoy uniforms: iResolution, iTime, iTimeDelta, \
iFrame, iMouse, iDate, iSampleRate.
- Do NOT include #version, precision, or #extension directives.
- Output ONLY the raw GLSL code — no markdown fencing, no explanation.

The rendered output is compared to the reference via SSIM (structural \
similarity). Target: SSIM >= 0.99."""


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def strip_fences(text: str) -> str:
    """Remove markdown code fences if present.

    If the stripped text starts with a ``` line, that opening line is dropped,
    along with everything from the last bare ``` line onward. If there is no
    closing fence, only the opening line is dropped. Text without a leading
    fence is returned stripped but otherwise unchanged.
    """
    text = text.strip()
    if not text.startswith("```"):
        return text
    lines = text.split("\n")
    # Scan backwards for the closing fence; index 0 is the opening fence.
    end = len(lines) - 1
    while end > 0 and lines[end].strip() != "```":
        end -= 1
    if end > 0:
        return "\n".join(lines[1:end])
    return "\n".join(lines[1:])


def extract_text(response) -> str:
    """Concatenate every output_text block in a Responses API response.

    Returns "" when the response has no message output (e.g. only reasoning
    items), which the environment will then report as a compile failure.
    """
    return "".join(
        block.text
        for item in response.output
        if item.type == "message"
        for block in item.content
        if block.type == "output_text"
    )


def save_b64_png(b64: str, path: Path):
    """Decode a base64-encoded PNG string and write the bytes to ``path``."""
    path.write_bytes(base64.b64decode(b64))


def _build_feedback(compiled: bool, rendered: bool, ssim: float,
                    errors: list, agent_png) -> list[dict]:
    """Build the user-turn content that reports the last step's outcome.

    The content is a text part describing the compile, render or SSIM result,
    followed by the agent's render as an input image when one is available.
    """
    if not compiled:
        parts = ["Compilation FAILED.\nErrors:\n" + "\n".join(errors)]
    elif not rendered:
        parts = ["Render FAILED.\nErrors:\n" + "\n".join(errors)]
    else:
        parts = [f"SSIM: {ssim:.4f} (need >= {SSIM_TARGET})."]
    parts.append(
        "Below is your current render vs the reference. "
        "Fix the differences. Output ONLY raw GLSL code."
    )
    content = [{"type": "input_text", "text": "\n".join(parts)}]
    if agent_png:
        content.append(
            {"type": "input_image",
             "image_url": f"data:image/png;base64,{agent_png}"}
        )
    return content


def _save_results(all_results: dict) -> Path:
    """Write all results collected so far to OUTPUT_DIR/results.json.

    Returns the path that was written.
    """
    out_path = OUTPUT_DIR / "results.json"
    out_path.write_text(json.dumps(all_results, indent=2))
    return out_path


# ---------------------------------------------------------------------------
# Server communication
# ---------------------------------------------------------------------------

async def ws_send(ws, msg_type: str, data: dict) -> dict:
    """Send a ``{"type", "data"}`` message and return the reply's ``data``.

    Raises:
        RuntimeError: if the server replies with a message of type "error".
    """
    await ws.send(json.dumps({"type": msg_type, "data": data}))
    resp = json.loads(await ws.recv())
    if resp.get("type") == "error":
        raise RuntimeError(f"Server error: {resp.get('data', {})}")
    return resp["data"]


# ---------------------------------------------------------------------------
# Agent loop
# ---------------------------------------------------------------------------

async def run_episode(ws, seed: int, episode_dir: Path, max_turns: int) -> dict:
    """Run one episode via WebSocket. Returns result dict.

    Resets the environment with ``seed`` and saves the reference image. It
    then loops for at most ``max_turns`` turns, stopping early when the
    server reports ``done``. Each turn asks the model for a shader, steps the
    environment with it, and feeds the outcome back to the model.

    Returns:
        A dict with keys "task", "seed", "turns" (one record per turn),
        "best_ssim" (0.0 if no turn ran) and "solved".
    """
    # Reset
    data = await ws_send(ws, "reset", {"seed": seed})
    obs = data["observation"]
    task = obs["task"]
    ref_b64 = obs["reference_png"]
    remaining = obs["remaining"]
    print(f" task: {task}, budget: {remaining}")

    # Save reference image
    episode_dir.mkdir(parents=True, exist_ok=True)
    save_b64_png(ref_b64, episode_dir / "reference.png")

    # Initial conversation with reference image
    conversation = [
        {
            "role": "user",
            "content": [
                {
                    "type": "input_text",
                    "text": "Write a GLSL shader that reproduces this reference image exactly.",
                },
                {
                    "type": "input_image",
                    "image_url": f"data:image/png;base64,{ref_b64}",
                },
            ],
        }
    ]

    results = []
    for turn in range(1, max_turns + 1):
        print(f" turn {turn}/{max_turns} ...", end=" ", flush=True)

        # The OpenAI client is synchronous. Run it in a worker thread so the
        # event loop keeps servicing the WebSocket (keepalive pings) while
        # the model call is in flight.
        t0 = time.monotonic()
        resp = await asyncio.to_thread(
            CLIENT.responses.create,
            model=MODEL,
            instructions=INSTRUCTIONS,
            input=conversation,
            max_output_tokens=8192,
            temperature=0.2,
        )
        api_s = time.monotonic() - t0
        code = strip_fences(extract_text(resp))

        # Step the environment
        data = await ws_send(ws, "step", {"code": code})
        obs = data["observation"]
        reward = data["reward"]
        done = data["done"]
        compiled = obs["compiled"]
        rendered = obs["rendered"]
        # NOTE(review): the server may send a null ssim/errors when compile
        # or render fails (not visible from here). Normalize those values so
        # round(), comparisons and "\n".join() below never see None.
        ssim = obs["ssim"]
        if ssim is None:
            ssim = 0.0
        errors = obs.get("errors") or []

        # Save agent render if available
        if obs.get("agent_png"):
            save_b64_png(obs["agent_png"], episode_dir / f"turn_{turn}.png")

        results.append({
            "turn": turn,
            "ssim": round(ssim, 6),
            "reward": reward,
            "compiled": compiled,
            "rendered": rendered,
            "errors": errors,
            "api_seconds": round(api_s, 1),
            "code_len": len(code),
        })

        if not compiled:
            status = "COMPILE_FAIL"
        elif not rendered:
            status = "RENDER_FAIL"
        else:
            status = f"ssim={ssim:.4f}"
        print(f"{status} reward={reward} ({api_s:.1f}s)")

        if done:
            if ssim >= SSIM_TARGET:
                print(f" => SOLVED on turn {turn}")
            else:
                print(" => budget exhausted")
            break

        # Feedback for next turn
        conversation.append({"role": "assistant", "content": code})
        conversation.append({
            "role": "user",
            "content": _build_feedback(
                compiled, rendered, ssim, errors, obs.get("agent_png")
            ),
        })

    return {
        "task": task,
        "seed": seed,
        "turns": results,
        # default= keeps a zero-turn episode from crashing max() on [].
        "best_ssim": max((r["ssim"] for r in results), default=0.0),
        "solved": any(r["ssim"] >= SSIM_TARGET for r in results),
    }


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

async def run(url: str, seeds: list[int], max_turns: int):
    """Run one episode per seed over a single WebSocket connection.

    Prints a summary at the end. ``results.json`` is rewritten after every
    episode, so results from finished episodes survive a later crash.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    all_results = {}
    out_path = OUTPUT_DIR / "results.json"

    async with websockets.connect(url, max_size=WS_MAX_MESSAGE_BYTES) as ws:
        for i, seed in enumerate(seeds, start=1):
            label = f"episode_{i}"
            print(f"\n{'='*60}")
            print(f" [{label.upper()}] seed={seed}")
            print(f"{'='*60}")

            all_results[label] = await run_episode(
                ws, seed, OUTPUT_DIR / label, max_turns,
            )
            out_path = _save_results(all_results)

    # Summary
    print(f"\n{'='*60}")
    print(" SUMMARY")
    print(f"{'='*60}")
    for label, data in all_results.items():
        best = data["best_ssim"]
        solved = "YES" if data["solved"] else "no"
        turns_used = len(data["turns"])
        print(f" {label}: task={data['task']} best_ssim={best:.4f} "
              f"solved={solved} turns={turns_used}")

    # Also covers the no-seeds case, so results.json always exists.
    out_path = _save_results(all_results)
    print(f"\nResults saved to {out_path}")


def main():
    """Parse command-line arguments and run the benchmark."""
    parser = argparse.ArgumentParser(description="Benchmark GPT 5.4 on shader env")
    parser.add_argument("--url", default="ws://localhost:8000/ws",
                        help="WebSocket URL of the shader environment")
    parser.add_argument("--turns", type=int, default=10,
                        help="Max turns per episode")
    parser.add_argument("--seeds", type=int, nargs="+", default=[1, 2, 3],
                        help="Seeds for reproducible task selection (one episode per seed)")
    args = parser.parse_args()
    if args.turns < 1:
        parser.error("--turns must be >= 1")

    asyncio.run(run(args.url, args.seeds, args.turns))


if __name__ == "__main__":
    main()