File size: 11,241 Bytes
cf6c0e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae3dc3c
 
 
cf6c0e0
 
 
 
 
 
 
 
 
 
 
 
 
ae3dc3c
 
cf6c0e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae3dc3c
cf6c0e0
 
 
 
 
 
 
 
ae3dc3c
cf6c0e0
 
 
 
ae3dc3c
 
 
 
 
 
cf6c0e0
 
ae3dc3c
 
 
 
 
 
 
 
 
 
cf6c0e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae3dc3c
 
 
 
 
 
 
 
 
cf6c0e0
 
 
 
 
 
 
 
 
 
 
 
ae3dc3c
cf6c0e0
 
 
 
 
 
 
ae3dc3c
cf6c0e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae3dc3c
 
 
cf6c0e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae3dc3c
cf6c0e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae3dc3c
 
cf6c0e0
 
 
 
 
ae3dc3c
cf6c0e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae3dc3c
cf6c0e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae3dc3c
cf6c0e0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
"""VisionCoder OpenEnv Environment — multi-step, session-aware."""
from __future__ import annotations

import base64
import io
import uuid
from dataclasses import dataclass, field
from typing import Dict, Optional

from PIL import Image

from openenv.models import Action, Observation, RenderRequest, RenderResponse, State
from openenv.dataset import load_websight_dataset
from openenv.server.rewards.color_rewards import color_reward
from openenv.server.rewards.format_rewards import format_reward
from openenv.server.rewards.position_rewards import position_reward
from openenv.server.rewards.ssim_reward import ssim_reward
from openenv.server.rewards.structural_rewards import structural_similarity_reward
from openenv.server.rewards.text_block_rewards import text_block_reward
from openenv.server.rewards.validity_rewards import html_validity_reward
from openenv.server.rewards import extract_html
from openenv.server.rewards.visual_rewards import _render_html, clip_visual_reward

# Episode / rendering defaults. Individual episodes may override max_steps
# via reset(); the resolutions are fixed per environment instance.
DEFAULT_MAX_STEPS = 5
DEFAULT_LOW_RES  = (320, 240)
DEFAULT_FULL_RES = (640, 480)

# Per-component reward weights. The weighted sum is normalized by
# _WEIGHT_SUM in step(), so only the RELATIVE magnitudes matter.
REWARD_WEIGHTS = {
    "format":     0.5,   # was 1.0 — saturates to 1.0 after early training; reduce weight
    "validity":   0.5,   # was 1.0 — saturates quickly; reduce weight
    "structural": 0.5,   # unchanged — inflated by inline-style refs
    "text_block": 3.0,   # unchanged — most discriminative, blank/wrong layout → 0
    "position":   1.0,   # unchanged
    "color":      1.5,   # was 1.0 — increased for near-perfect sensitivity
    "clip":       2.5,   # was 2.0 — most continuous signal at top, increase
    "ssim":       1.5,   # new — pixel-level SSIM, fills variance gap in 0.7-0.97 zone
}
_WEIGHT_SUM = sum(REWARD_WEIGHTS.values())  # 11.0

LOW_RES  = DEFAULT_LOW_RES   # module-level alias kept for external imports
FULL_RES = DEFAULT_FULL_RES

# System prompts handed to the Developer agent per difficulty tier.
# Unknown difficulties fall back to _DEFAULT_PROMPT ("medium").
DIFFICULTY_PROMPTS = {
    "easy": (
        "You are a UI-to-code assistant. Given a screenshot of a simple website, "
        "generate complete HTML with inline CSS. Output only raw HTML."
    ),
    "medium": (
        "You are a UI-to-code assistant. Given a screenshot of a website with navigation "
        "and multiple sections, generate complete HTML with inline CSS. Output only raw HTML."
    ),
    "hard": (
        "You are a UI-to-code assistant. Given a screenshot of a complex website with forms, "
        "tables, and rich layout, generate complete HTML with inline CSS. Output only raw HTML."
    ),
}
_DEFAULT_PROMPT = DIFFICULTY_PROMPTS["medium"]


def _image_to_b64(image: Image.Image, size: Optional[tuple] = None) -> str:
    """Encode *image* as a base64 PNG string.

    Args:
        image: Source image to encode.
        size:  Optional (width, height); when given, the image is resized
               with LANCZOS resampling before encoding.

    Returns:
        Base64-encoded PNG bytes decoded to an ASCII ``str``.
    """
    resized = image if size is None else image.resize(size, Image.LANCZOS)
    with io.BytesIO() as buffer:
        resized.save(buffer, format="PNG")
        payload = buffer.getvalue()
    return base64.b64encode(payload).decode()


@dataclass
class _Session:
    """Mutable per-episode state, stored by the environment keyed on session_id."""

    episode_id: str          # uuid4 assigned at reset(); reported in metadata
    session_id: str          # uuid4 key used by step() to find this episode
    difficulty: str          # difficulty string as passed to reset() (may be "mixed")
    sample: dict             # dataset sample, augmented with the rendered "image"
    ref_image: Image.Image   # reference screenshot used for reward computation
    max_steps: int           # turn budget for this episode
    step_count: int = 0      # number of step() calls consumed so far
    sample_index: int = 0    # round-robin index of the sample within its dataset


class VisionCoderEnvironment:
    """Multi-step, session-aware OpenEnv environment for screenshot-to-HTML generation.

    Each reset() creates an independent session identified by session_id.
    step() accepts session_id in the Action and allows up to max_steps turns
    per episode before returning done=True.

    step() returns render_low and render_full (base64 PNG) alongside the reward
    so the Developer agent can inspect its render without an extra /render call.

    Args:
        max_steps:   Default max developer turns per episode (overridable per reset).
        low_res:     Resolution for the low-res preview returned to the Developer.
        full_res:    Resolution for reward computation and Critic renders.
        max_samples: Max dataset samples to load per difficulty.
    """

    def __init__(
        self,
        max_steps: int = DEFAULT_MAX_STEPS,
        low_res: tuple = DEFAULT_LOW_RES,
        full_res: tuple = DEFAULT_FULL_RES,
        max_samples: int = 2000,
    ):
        self._default_max_steps = max_steps
        self._low_res = low_res
        self._full_res = full_res
        self._max_samples = max_samples
        # Datasets are loaded lazily per difficulty key; "mixed" loads unfiltered.
        self._datasets: Dict[str, list] = {}
        # Round-robin cursor per difficulty so successive resets walk the dataset.
        self._dataset_indices: Dict[str, int] = {"easy": 0, "medium": 0, "hard": 0, "mixed": 0}
        # NOTE(review): sessions are never evicted, so a long-running server
        # grows without bound — consider dropping sessions once done=True.
        self._sessions: Dict[str, _Session] = {}
        self._last_session_id: Optional[str] = None  # backward-compat fallback

    # ------------------------------------------------------------------
    # Dataset helpers
    # ------------------------------------------------------------------

    def _get_dataset(self, difficulty: str) -> list:
        """Return (loading on first use) the dataset for *difficulty*.

        Any difficulty outside easy/medium/hard maps to the "mixed" pool,
        which is loaded without a difficulty filter.
        """
        key = difficulty if difficulty in ("easy", "medium", "hard") else "mixed"
        if key not in self._datasets:
            self._datasets[key] = load_websight_dataset(
                max_samples=self._max_samples,
                difficulty=key if key != "mixed" else None,
            )
        return self._datasets[key]

    # ------------------------------------------------------------------
    # OpenEnv interface
    # ------------------------------------------------------------------

    def reset(self, difficulty: str = "mixed", max_steps: Optional[int] = None) -> Observation:
        """Start a new episode. Returns session_id and the reference screenshot.

        Args:
            difficulty: Task difficulty — easy | medium | hard | mixed.
            max_steps:  Override max turns for this episode; uses env default when None.
        """
        episode_max_steps = max_steps if max_steps is not None else self._default_max_steps

        dataset = self._get_dataset(difficulty)
        key = difficulty if difficulty in ("easy", "medium", "hard") else "mixed"

        # Advance the round-robin cursor; wraps so episodes cycle the dataset.
        idx = self._dataset_indices[key]
        sample = dataset[idx]
        self._dataset_indices[key] = (idx + 1) % len(dataset)

        session_id = str(uuid.uuid4())
        episode_id = str(uuid.uuid4())

        # Render the ground-truth HTML once; fall back to a blank white canvas
        # if rendering fails so the episode can still proceed.
        ref_image = _render_html(sample["solution"])
        if ref_image is None:
            ref_image = Image.new("RGB", self._full_res, color=(255, 255, 255))

        session = _Session(
            episode_id=episode_id,
            session_id=session_id,
            difficulty=difficulty,
            sample={**sample, "image": ref_image},
            ref_image=ref_image,
            max_steps=episode_max_steps,
            sample_index=idx,
        )
        self._sessions[session_id] = session
        self._last_session_id = session_id

        return Observation(
            done=False,
            session_id=session_id,
            screenshot_b64=_image_to_b64(ref_image),
            prompt=DIFFICULTY_PROMPTS.get(difficulty, _DEFAULT_PROMPT),
            metadata={
                "episode_id": episode_id,
                "session_id": session_id,
                "sample_index": idx,
                "difficulty": difficulty,
                "max_steps": episode_max_steps,
                "low_res": list(self._low_res),
                "full_res": list(self._full_res),
            },
        )

    def step(self, action: Action) -> Observation:
        """Score submitted HTML and return reward + rendered images.

        Uses action.session_id to look up the episode. Falls back to the most
        recently created session when session_id is omitted (single-client compat).

        Returns done=True once step_count reaches the session's max_steps.

        Raises:
            RuntimeError: when no session matches and no fallback exists.
        """
        session_id = action.session_id or self._last_session_id
        if session_id is None or session_id not in self._sessions:
            raise RuntimeError("No active session. Call reset() first.")

        session = self._sessions[session_id]
        session.step_count += 1
        done = session.step_count >= session.max_steps

        # Reward functions expect batched inputs; wrap the single sample.
        completions = [[{"content": action.html}]]
        images = [session.ref_image]
        solutions = [session.sample["solution"]]

        # Text/structure rewards computed directly from the HTML string.
        fmt   = format_reward(completions)[0]
        val   = html_validity_reward(completions)[0]
        struct = structural_similarity_reward(completions, solution=solutions)[0]
        tb    = text_block_reward(completions, solution=solutions)[0]
        pos   = position_reward(completions, solution=solutions)[0]

        # Render the prediction at the reference's exact size so pixel-level
        # rewards compare like with like; blank canvas on render failure.
        ref_w, ref_h = session.ref_image.size
        pred_render = _render_html(extract_html(action.html), width=ref_w, height=ref_h)
        if pred_render is None:
            pred_render = Image.new("RGB", (ref_w, ref_h), color=(255, 255, 255))
        pred_renders = [pred_render]

        # Visual rewards computed from the rendered prediction.
        col  = color_reward(completions, image=images, pred_image=pred_renders)[0]
        clip = clip_visual_reward(completions, image=images, pred_image=pred_renders)[0]
        ssim = ssim_reward(completions, image=images, pred_image=pred_renders)[0]

        raw_total = (
            REWARD_WEIGHTS["format"]     * fmt
            + REWARD_WEIGHTS["validity"]   * val
            + REWARD_WEIGHTS["structural"] * struct
            + REWARD_WEIGHTS["text_block"] * tb
            + REWARD_WEIGHTS["position"]   * pos
            + REWARD_WEIGHTS["color"]      * col
            + REWARD_WEIGHTS["clip"]       * clip
            + REWARD_WEIGHTS["ssim"]       * ssim
        )
        # Normalize so total stays in the same range as individual rewards.
        total = raw_total / _WEIGHT_SUM

        return Observation(
            done=done,
            reward=total,
            session_id=session_id,
            render_low=_image_to_b64(pred_render, size=self._low_res),
            render_full=_image_to_b64(pred_render, size=self._full_res),
            metadata={
                "episode_id": session.episode_id,
                "session_id": session_id,
                "step_count": session.step_count,
                "difficulty": session.difficulty,
                "max_steps": session.max_steps,
                "rewards": {
                    "format": fmt,
                    "validity": val,
                    "structural": struct,
                    "text_block": tb,
                    "position": pos,
                    "color": col,
                    "clip": clip,
                    "ssim": ssim,
                    "total": total,
                },
            },
        )

    def render(self, request: RenderRequest) -> RenderResponse:
        """Render HTML to images without computing rewards.

        Used by the Developer agent's render() tool call to self-check
        mid-generation without consuming an episode step.
        """
        image = _render_html(extract_html(request.html))
        if image is None:
            image = Image.new("RGB", self._full_res, color=(255, 255, 255))
        return RenderResponse(
            image_b64=_image_to_b64(image),
            # Fix: use the per-instance low-res setting. The previous code used
            # the module-level LOW_RES alias, silently ignoring a custom
            # `low_res` passed to __init__ (step() already used self._low_res).
            image_low_b64=_image_to_b64(image, size=self._low_res),
        )

    @property
    def state(self) -> State:
        """Return metadata for the most recently created session.

        Returns an empty State when no session has been created yet (or the
        last session id is somehow missing from the registry).
        """
        if self._last_session_id and self._last_session_id in self._sessions:
            s = self._sessions[self._last_session_id]
            return State(
                episode_id=s.episode_id,
                session_id=s.session_id,
                step_count=s.step_count,
                sample_index=s.sample_index,
                max_steps=s.max_steps,
            )
        return State()