"""
Shader environment implementation.
Reference-image-conditioned shader generation with SSIM reward.
"""
import random
from base64 import b64encode
from io import BytesIO
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
from PIL import Image
try:
from ..models import ShaderAction, ShaderObservation
from ..tasks import Task, load as load_tasks, CURATED, CURATED_BY_NAME
from ..reward import ssim
from ..harness import render as harness_render
except ImportError:
from models import ShaderAction, ShaderObservation
from tasks import Task, load as load_tasks, CURATED, CURATED_BY_NAME
from reward import ssim
from harness import render as harness_render
class ShaderEnvironment(Environment):
    """OpenEnv environment for shader generation against a reference image.

    Each episode picks a task (a shader with known ground-truth code), renders
    the reference frame, then challenges the agent to reproduce it via GLSL.
    Every submission is rendered with the task's resolution and time and is
    scored with SSIM against the reference frame.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Task list, loaded on first construction and stored on the class so that
    # later instances reuse it instead of loading the corpus again.
    _tasks: list[Task] | None = None

    def __init__(self, budget: int = 10, seed: int | None = None):
        """Create an environment.

        Args:
            budget: Number of steps granted to the agent per episode.
            seed: Seed for the RNG used for random task selection.
        """
        self._budget = budget
        self._rng = random.Random(seed)
        if ShaderEnvironment._tasks is None:
            try:
                ShaderEnvironment._tasks = load_tasks()
            except FileNotFoundError:
                # Corpus not available — use curated tasks only
                ShaderEnvironment._tasks = list(CURATED)
        self._state = State(episode_id=None, step_count=0)
        # Episode state
        self._task: Task | None = None
        self._ref: bytes | None = None
        self._remaining: int = 0

    def reset(self, seed: int | None = None, episode_id: str | None = None,
              task: str | None = None, **kwargs) -> ShaderObservation:
        """Start a new episode.

        Args:
            seed: RNG seed for reproducible task selection.
            episode_id: Optional episode identifier. When omitted (or empty),
                the first 8 characters of a fresh UUID4 are used.
            task: If given, select a curated task by name (e.g. "gradient",
                "rings", "mandelbrot"). Otherwise, or when the name is not a
                known curated task, picks randomly from the full corpus.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The initial observation: the task name, the full step budget and
            the reference frame as a base64-encoded PNG.

        Raises:
            RuntimeError: If none of the candidate tasks compiles and renders
                to a non-empty frame.
        """
        if seed is not None:
            self._rng = random.Random(seed)
        self._state = State(
            episode_id=episode_id or str(uuid4())[:8],
            step_count=0,
        )
        self._remaining = self._budget
        # Forget the previous episode up front: if no candidate renders below,
        # step() must refuse to run instead of silently scoring against the
        # old task and reference under the new episode id.
        self._task = None
        self._ref = None

        # Select task(s) to try
        if task and task in CURATED_BY_NAME:
            candidates = [CURATED_BY_NAME[task]]
        else:
            # Shuffle a copy so the shared class-level task list keeps its order.
            candidates = list(self._tasks)
            self._rng.shuffle(candidates)

        # Use the first candidate whose ground-truth shader actually renders.
        for candidate in candidates:
            result = harness_render(
                candidate.code,
                resolution=candidate.resolution,
                time=candidate.time,
            )
            if not (result.compiled and result.rendered and result.frame):
                continue
            self._task = candidate
            self._ref = result.frame
            return ShaderObservation(
                task=candidate.name,
                remaining=self._remaining,
                reference_png=self._encode(
                    result.frame, result.width, result.height,
                ),
                done=False,
                reward=None,
            )
        raise RuntimeError("no task rendered successfully")

    def step(self, action: ShaderAction) -> ShaderObservation:
        """Render agent's GLSL, compute SSIM vs reference, return observation.

        Args:
            action: The agent's submission; its ``code`` is rendered with the
                current task's resolution and time.

        Returns:
            An observation carrying the compile/render flags and errors, the
            agent's frame as a base64 PNG (empty string when the submission
            did not produce a frame), and the score as both ``ssim`` and
            ``reward``. ``done`` is set once the budget is exhausted or the
            score reaches 0.99. ``reference_png`` is left empty on steps.

        Raises:
            RuntimeError: If no episode is active (reset() was not called or
                did not succeed).
        """
        if self._task is None:
            raise RuntimeError("call reset() before step()")
        self._state.step_count += 1
        # Floor at zero so steps taken after the budget is spent do not report
        # a negative remaining count; `done` below is unaffected by the clamp.
        self._remaining = max(self._remaining - 1, 0)

        result = harness_render(
            action.code,
            resolution=self._task.resolution,
            time=self._task.time,
        )

        # Compute reward in (0, 1) exclusive — validator rejects 0.0 and 1.0
        if result.compiled and result.rendered and result.frame:
            score = ssim(
                self._ref, result.frame,
                result.width, result.height,
            )
            png = self._encode(
                result.frame, result.width, result.height,
            )
        else:
            # Failed compile/render: lowest score and no image to show.
            score = 0.0
            png = ""
        score = min(max(score, 0.01), 0.99)

        done = self._remaining <= 0 or score >= 0.99
        return ShaderObservation(
            task=self._task.name,
            remaining=self._remaining,
            reference_png="",
            compiled=result.compiled,
            rendered=result.rendered,
            errors=result.errors,
            agent_png=png,
            ssim=score,
            done=done,
            reward=score,
        )

    @property
    def state(self) -> State:
        """Current episode state (episode id and step count)."""
        return self._state

    @staticmethod
    def _encode(rgba: bytes, width: int, height: int) -> str:
        """Convert raw RGBA bytes to a base64-encoded PNG string.

        Args:
            rgba: Raw pixel data interpreted in PIL's "RGBA" mode.
            width: Image width in pixels.
            height: Image height in pixels.

        Returns:
            The PNG file contents, base64-encoded as an ASCII string.
        """
        img = Image.frombytes("RGBA", (width, height), rgba)
        buf = BytesIO()
        img.save(buf, format="PNG")
        return b64encode(buf.getvalue()).decode("ascii")
|