"""VisionCoder OpenEnv Environment — multi-step, session-aware.""" from __future__ import annotations import base64 import io import uuid from dataclasses import dataclass, field from typing import Dict, Optional from PIL import Image from openenv.models import Action, Observation, RenderRequest, RenderResponse, State from openenv.dataset import load_websight_dataset from openenv.server.rewards.color_rewards import color_reward from openenv.server.rewards.format_rewards import format_reward from openenv.server.rewards.position_rewards import position_reward from openenv.server.rewards.ssim_reward import ssim_reward from openenv.server.rewards.structural_rewards import structural_similarity_reward from openenv.server.rewards.text_block_rewards import text_block_reward from openenv.server.rewards.validity_rewards import html_validity_reward from openenv.server.rewards import extract_html from openenv.server.rewards.visual_rewards import _render_html, clip_visual_reward DEFAULT_MAX_STEPS = 5 DEFAULT_LOW_RES = (320, 240) DEFAULT_FULL_RES = (640, 480) REWARD_WEIGHTS = { "format": 0.5, # was 1.0 — saturates to 1.0 after early training; reduce weight "validity": 0.5, # was 1.0 — saturates quickly; reduce weight "structural": 0.5, # unchanged — inflated by inline-style refs "text_block": 3.0, # unchanged — most discriminative, blank/wrong layout → 0 "position": 1.0, # unchanged "color": 1.5, # was 1.0 — increased for near-perfect sensitivity "clip": 2.5, # was 2.0 — most continuous signal at top, increase "ssim": 1.5, # new — pixel-level SSIM, fills variance gap in 0.7-0.97 zone } _WEIGHT_SUM = sum(REWARD_WEIGHTS.values()) # 11.0 LOW_RES = DEFAULT_LOW_RES # module-level alias kept for external imports FULL_RES = DEFAULT_FULL_RES DIFFICULTY_PROMPTS = { "easy": ( "You are a UI-to-code assistant. Given a screenshot of a simple website, " "generate complete HTML with inline CSS. Output only raw HTML." ), "medium": ( "You are a UI-to-code assistant. Given a screenshot of a website with navigation " "and multiple sections, generate complete HTML with inline CSS. Output only raw HTML." ), "hard": ( "You are a UI-to-code assistant. Given a screenshot of a complex website with forms, " "tables, and rich layout, generate complete HTML with inline CSS. Output only raw HTML." ), } _DEFAULT_PROMPT = DIFFICULTY_PROMPTS["medium"] def _image_to_b64(image: Image.Image, size: Optional[tuple] = None) -> str: if size is not None: image = image.resize(size, Image.LANCZOS) buf = io.BytesIO() image.save(buf, format="PNG") return base64.b64encode(buf.getvalue()).decode() @dataclass class _Session: episode_id: str session_id: str difficulty: str sample: dict ref_image: Image.Image max_steps: int step_count: int = 0 sample_index: int = 0 class VisionCoderEnvironment: """Multi-step, session-aware OpenEnv environment for screenshot-to-HTML generation. Each reset() creates an independent session identified by session_id. step() accepts session_id in the Action and allows up to max_steps turns per episode before returning done=True. step() returns render_low and render_full (base64 PNG) alongside the reward so the Developer agent can inspect its render without an extra /render call. Args: max_steps: Default max developer turns per episode (overridable per reset). low_res: Resolution for the low-res preview returned to the Developer. full_res: Resolution for reward computation and Critic renders. max_samples: Max dataset samples to load per difficulty. """ def __init__( self, max_steps: int = DEFAULT_MAX_STEPS, low_res: tuple = DEFAULT_LOW_RES, full_res: tuple = DEFAULT_FULL_RES, max_samples: int = 2000, ): self._default_max_steps = max_steps self._low_res = low_res self._full_res = full_res self._max_samples = max_samples self._datasets: Dict[str, list] = {} self._dataset_indices: Dict[str, int] = {"easy": 0, "medium": 0, "hard": 0, "mixed": 0} self._sessions: Dict[str, _Session] = {} self._last_session_id: Optional[str] = None # backward-compat fallback # ------------------------------------------------------------------ # Dataset helpers # ------------------------------------------------------------------ def _get_dataset(self, difficulty: str) -> list: key = difficulty if difficulty in ("easy", "medium", "hard") else "mixed" if key not in self._datasets: self._datasets[key] = load_websight_dataset( max_samples=self._max_samples, difficulty=key if key != "mixed" else None, ) return self._datasets[key] # ------------------------------------------------------------------ # OpenEnv interface # ------------------------------------------------------------------ def reset(self, difficulty: str = "mixed", max_steps: Optional[int] = None) -> Observation: """Start a new episode. Returns session_id and the reference screenshot. Args: difficulty: Task difficulty — easy | medium | hard | mixed. max_steps: Override max turns for this episode; uses env default when None. """ episode_max_steps = max_steps if max_steps is not None else self._default_max_steps dataset = self._get_dataset(difficulty) key = difficulty if difficulty in ("easy", "medium", "hard") else "mixed" idx = self._dataset_indices[key] sample = dataset[idx] self._dataset_indices[key] = (idx + 1) % len(dataset) session_id = str(uuid.uuid4()) episode_id = str(uuid.uuid4()) ref_image = _render_html(sample["solution"]) if ref_image is None: ref_image = Image.new("RGB", self._full_res, color=(255, 255, 255)) session = _Session( episode_id=episode_id, session_id=session_id, difficulty=difficulty, sample={**sample, "image": ref_image}, ref_image=ref_image, max_steps=episode_max_steps, sample_index=idx, ) self._sessions[session_id] = session self._last_session_id = session_id return Observation( done=False, session_id=session_id, screenshot_b64=_image_to_b64(ref_image), prompt=DIFFICULTY_PROMPTS.get(difficulty, _DEFAULT_PROMPT), metadata={ "episode_id": episode_id, "session_id": session_id, "sample_index": idx, "difficulty": difficulty, "max_steps": episode_max_steps, "low_res": list(self._low_res), "full_res": list(self._full_res), }, ) def step(self, action: Action) -> Observation: """Score submitted HTML and return reward + rendered images. Uses action.session_id to look up the episode. Falls back to the most recently created session when session_id is omitted (single-client compat). Returns done=True when step_count reaches MAX_STEPS. """ session_id = action.session_id or self._last_session_id if session_id is None or session_id not in self._sessions: raise RuntimeError("No active session. Call reset() first.") session = self._sessions[session_id] session.step_count += 1 done = session.step_count >= session.max_steps completions = [[{"content": action.html}]] images = [session.ref_image] solutions = [session.sample["solution"]] fmt = format_reward(completions)[0] val = html_validity_reward(completions)[0] struct = structural_similarity_reward(completions, solution=solutions)[0] tb = text_block_reward(completions, solution=solutions)[0] pos = position_reward(completions, solution=solutions)[0] ref_w, ref_h = session.ref_image.size pred_render = _render_html(extract_html(action.html), width=ref_w, height=ref_h) if pred_render is None: pred_render = Image.new("RGB", (ref_w, ref_h), color=(255, 255, 255)) pred_renders = [pred_render] col = color_reward(completions, image=images, pred_image=pred_renders)[0] clip = clip_visual_reward(completions, image=images, pred_image=pred_renders)[0] ssim = ssim_reward(completions, image=images, pred_image=pred_renders)[0] raw_total = ( REWARD_WEIGHTS["format"] * fmt + REWARD_WEIGHTS["validity"] * val + REWARD_WEIGHTS["structural"] * struct + REWARD_WEIGHTS["text_block"] * tb + REWARD_WEIGHTS["position"] * pos + REWARD_WEIGHTS["color"] * col + REWARD_WEIGHTS["clip"] * clip + REWARD_WEIGHTS["ssim"] * ssim ) total = raw_total / _WEIGHT_SUM return Observation( done=done, reward=total, session_id=session_id, render_low=_image_to_b64(pred_render, size=self._low_res), render_full=_image_to_b64(pred_render, size=self._full_res), metadata={ "episode_id": session.episode_id, "session_id": session_id, "step_count": session.step_count, "difficulty": session.difficulty, "max_steps": session.max_steps, "rewards": { "format": fmt, "validity": val, "structural": struct, "text_block": tb, "position": pos, "color": col, "clip": clip, "ssim": ssim, "total": total, }, }, ) def render(self, request: RenderRequest) -> RenderResponse: """Render HTML to images without computing rewards. Used by the Developer agent's render() tool call to self-check mid-generation without consuming an episode step. """ image = _render_html(extract_html(request.html)) if image is None: image = Image.new("RGB", self._full_res, color=(255, 255, 255)) return RenderResponse( image_b64=_image_to_b64(image), image_low_b64=_image_to_b64(image, size=LOW_RES), ) @property def state(self) -> State: """Return metadata for the most recently created session.""" if self._last_session_id and self._last_session_id in self._sessions: s = self._sessions[self._last_session_id] return State( episode_id=s.episode_id, session_id=s.session_id, step_count=s.step_count, sample_index=s.sample_index, max_steps=s.max_steps, ) return State()