html-gym-env / server /html_design_agent_environment.py
AhanR's picture
Upload folder using huggingface_hub
f5cd640 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Html Design Agent Environment Implementation.
Renders HTML with a headless Playwright browser and scores it on four
dimensions from AGENTS.md:
R = 0.25*R_branding + 0.25*R_spacing + 0.25*R_a11y + 0.25*R_composition
Three tasks of increasing difficulty:
- level1_accessibility : (easy) add missing alt/aria-label/label attributes
- level2_spacing : (medium) fix off-grid spacing to the 8pt system
- level3_contrast : (hard) fix colours, spacing AND accessibility together
"""
from __future__ import annotations

import copy
import json
import math
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
from uuid import uuid4

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

try:
    from ..models import HtmlDesignAgentAction, HtmlDesignAgentObservation
except ImportError:
    from models import HtmlDesignAgentAction, HtmlDesignAgentObservation
# ---------------------------------------------------------------------------
# Brand design tokens (single source of truth for all tasks)
# ---------------------------------------------------------------------------
DESIGN_TOKENS: Dict[str, Any] = {
"palette": {
"primary": "#1A1A2E", # dark navy – headings, body text
"accent": "#E94560", # brand red – CTAs, highlights
"white": "#FFFFFF", # page background, reversed text
"surface": "#F0F0F5", # card backgrounds
"muted": "#646478", # secondary / caption text
},
"palette_rgb": {
"primary": (26, 26, 46),
"accent": (233, 69, 96),
"white": (255, 255, 255),
"surface": (240, 240, 245),
"muted": (100, 100, 120),
},
"fonts": ["Inter", "Roboto", "Open Sans", "system-ui", "sans-serif", "-apple-system"],
"spacing_scale": [0, 4, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 96, 128],
"min_contrast_ratio": 4.5,
"max_color_delta_e": 2.0,
}
# ---------------------------------------------------------------------------
# Task definitions
# ---------------------------------------------------------------------------
@dataclass
class TaskDefinition:
    """A single graded exercise: the broken page given to the agent plus its pass mark.

    Attributes:
        task_id: identifier of the task; it is also the task's key in ``TASKS``.
        description: human-readable instructions for the agent.
        difficulty: one of "easy", "medium" or "hard".
        broken_html: HTML document returned as the initial observation on reset().
        done_threshold: an episode ends once the total reward reaches this value.
    """

    task_id: str
    description: str
    difficulty: str  # "easy" | "medium" | "hard"
    broken_html: str
    done_threshold: float  # reward >= this value ends the episode
TASKS: Dict[str, TaskDefinition] = {
    # ------------------------------------------------------------------
    # Level 1 – level1_accessibility (EASY)
    # Only structural HTML attributes are missing: alt, aria-label, labels.
    # Every padding/margin is a multiple of 8 (so the spacing scorer has
    # nothing to flag) and colours come from the brand palette — the agent
    # only needs to add missing attributes, no design knowledge required.
    # done_threshold is lenient (0.80) because branding and composition also
    # count towards the total. NOTE(review): the .btn white-on-#E94560 pair is
    # ≈3.8:1 by _wcag_contrast, below min_contrast_ratio, so branding cannot
    # reach 1.0 on this page as written.
    # ------------------------------------------------------------------
    "level1_accessibility": TaskDefinition(
        task_id="level1_accessibility",
        description=(
            "Easy — fix accessibility only: add alt text to images, "
            "aria-label to icon buttons, and <label> to form inputs. "
            "Colours and spacing are already correct."
        ),
        difficulty="easy",
        done_threshold=0.80,
        broken_html="""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<style>
*, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
body { font-family: Inter, sans-serif; padding: 32px; background: #ffffff; color: #1A1A2E; }
h1 { font-size: 1.75rem; margin-bottom: 24px; color: #1A1A2E; }
.form-group{ margin-bottom: 16px; }
label { display: block; font-size: 0.875rem; margin-bottom: 8px; color: #646478; }
input { display: block; width: 100%; padding: 8px 16px;
border: 1px solid #cccccc; border-radius: 4px;
font-size: 1rem; color: #1A1A2E; }
.btn { background: #E94560; color: #ffffff; padding: 16px 24px;
border: none; border-radius: 4px; font-size: 1rem; cursor: pointer; }
.icon-btn { background: #1A1A2E; color: #ffffff; padding: 8px 16px;
border: none; border-radius: 4px; cursor: pointer; margin-left: 8px; }
</style>
</head>
<body>
<!-- VIOLATION: img missing alt attribute -->
<img src="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAAC0lEQVQI12NgAAIABQAABjE+ibYAAAAASUVORK5CYII="
style="display:block;width:100%;height:120px;border-radius:8px;margin-bottom:24px;" />
<h1>Contact Us</h1>
<form>
<!-- VIOLATION: inputs missing associated <label for="…"> -->
<div class="form-group">
<input id="name" type="text" placeholder="Your name" />
</div>
<div class="form-group">
<input id="email" type="email" placeholder="Email address" />
</div>
<div class="form-group">
<input id="subject" type="text" placeholder="Subject" />
</div>
<button class="btn" type="submit">Send Message</button>
<!-- VIOLATION: icon buttons missing aria-label -->
<button class="icon-btn" type="button">&#x2715;</button>
<button class="icon-btn" type="button">&#x2191;</button>
</form>
</body>
</html>""",
    ),
    # ------------------------------------------------------------------
    # Level 2 – level2_spacing (MEDIUM)
    # Colours are correct (brand palette used throughout) but every
    # spacing value is off-grid (13 px, 5 px, 7 px, 10 px, 22 px …).
    # Agent must understand the 8pt grid system and fix all padding /
    # margin / gap values while keeping colours and accessibility intact.
    # ------------------------------------------------------------------
    "level2_spacing": TaskDefinition(
        task_id="level2_spacing",
        description=(
            "Medium — fix spacing only: all padding/margin/gap values must "
            "be multiples of 8 (0,8,16,24,32,48,64…). "
            "Colours and accessibility are already correct."
        ),
        difficulty="medium",
        done_threshold=0.85,
        broken_html="""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<style>
*, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
/* VIOLATIONS: every spacing value below is off the 8pt grid */
body { font-family: Inter, sans-serif; padding: 13px; background: #ffffff; color: #1A1A2E; }
.header { padding: 10px 15px; background: #1A1A2E; color: #ffffff; margin-bottom: 5px; }
.nav { display: flex; gap: 7px; padding: 6px 15px; background: #F0F0F5; }
.nav a { color: #1A1A2E; text-decoration: none; padding: 5px 10px; }
.hero { padding: 30px 15px; text-align: center; }
.hero h1{ color: #1A1A2E; font-size: 2rem; margin-bottom: 10px; }
.hero p { color: #646478; margin-bottom: 22px; }
.cta { background: #E94560; color: #ffffff; padding: 13px 27px;
border: none; border-radius: 4px; font-size: 1rem; cursor: pointer; }
.cards { display: flex; gap: 10px; padding: 0 15px; margin-top: 18px; }
.card { background: #F0F0F5; padding: 18px; border-radius: 6px; flex: 1; color: #1A1A2E; }
</style>
</head>
<body>
<div class="header"><strong>BrandName</strong></div>
<nav class="nav" aria-label="Main navigation">
<a href="#">Home</a><a href="#">About</a>
<a href="#">Services</a><a href="#">Contact</a>
</nav>
<section class="hero">
<h1>Build Faster, Ship Smarter</h1>
<p>The platform trusted by 10,000+ developers worldwide.</p>
<button class="cta">Start Free Trial</button>
</section>
<div class="cards">
<div class="card">Fast Performance</div>
<div class="card">99.9% Uptime</div>
<div class="card">24/7 Support</div>
</div>
</body>
</html>""",
    ),
    # ------------------------------------------------------------------
    # Level 3 – level3_contrast (HARD)
    # Everything is broken simultaneously: low-contrast colours (#CCC on
    # white), off-grid spacing, AND missing accessibility attributes.
    # The agent must fix all three dimensions at once — requiring colour
    # knowledge (WCAG 4.5:1, brand palette), grid alignment (8pt), and
    # structural HTML fixes (alt, aria-label, labels).
    # ------------------------------------------------------------------
    "level3_contrast": TaskDefinition(
        task_id="level3_contrast",
        description=(
            "Hard — fix everything: low-contrast colours (use brand palette), "
            "off-grid spacing (multiples of 8), and missing accessibility "
            "attributes (alt, aria-label, labels). All four reward dimensions active."
        ),
        difficulty="hard",
        done_threshold=0.88,
        broken_html="""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<style>
*, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
/* COLOUR VIOLATIONS: near-white text on white background */
body { background: #ffffff; font-family: Inter, sans-serif; padding: 13px; }
.header { padding: 10px 15px; background: #dddddd; color: #eeeeee; margin-bottom: 5px; }
h1 { color: #cccccc; font-size: 2rem; margin-bottom: 10px; }
/* SPACING VIOLATIONS: 13px, 5px, 7px, 22px are off-grid */
.hero { padding: 30px 13px; text-align: center; }
.hero p { color: #dddddd; margin-bottom: 22px; }
.cta { background: #eeeeee; color: #cccccc; padding: 13px 27px;
border: none; border-radius: 4px; font-size: 1rem; cursor: pointer; }
.cards { display: flex; gap: 7px; padding: 0 13px; margin-top: 18px; }
.card { background: #f5f5f5; padding: 18px; border-radius: 6px;
flex: 1; color: #cccccc; }
.card-title { font-size: 1rem; margin-bottom: 5px; color: #bbbbbb; }
/* FORM */
.form-section { padding: 22px 13px; }
.form-group { margin-bottom: 10px; }
input { display: block; width: 100%; padding: 9px 14px;
border: 1px solid #eeeeee; border-radius: 4px;
font-size: 1rem; color: #cccccc; }
.icon-btn { background: #dddddd; color: #eeeeee; padding: 9px 14px;
border: none; border-radius: 4px; cursor: pointer; margin-left: 5px; }
</style>
</head>
<body>
<!-- VIOLATION: img missing alt -->
<img src="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAAC0lEQVQI12NgAAIABQAABjE+ibYAAAAASUVORK5CYII="
style="display:block;width:100%;height:120px;border-radius:8px;margin-bottom:22px;" />
<div class="header"><strong>BrandName</strong></div>
<section class="hero">
<h1>Build Faster, Ship Smarter</h1>
<p>The platform trusted by 10,000+ developers worldwide.</p>
<!-- VIOLATION: low-contrast CTA -->
<button class="cta">Start Free Trial</button>
</section>
<div class="cards">
<div class="card"><p class="card-title">Fast Performance</p><p>Sub-10ms latency.</p></div>
<div class="card"><p class="card-title">99.9% Uptime</p><p>Always available.</p></div>
<div class="card"><p class="card-title">24/7 Support</p><p>We've got you covered.</p></div>
</div>
<!-- VIOLATION: inputs missing labels, icon buttons missing aria-label -->
<div class="form-section">
<div class="form-group"><input id="name" type="text" placeholder="Your name" /></div>
<div class="form-group"><input id="email" type="email" placeholder="Email address" /></div>
<button class="cta" type="submit">Send Message</button>
<button class="icon-btn" type="button">&#x2715;</button>
<button class="icon-btn" type="button">&#x2191;</button>
</div>
</body>
</html>""",
    ),
}
# ---------------------------------------------------------------------------
# Colour-science helpers
# ---------------------------------------------------------------------------
def _srgb_to_linear(c: int) -> float:
v = c / 255.0
return v / 12.92 if v <= 0.04045 else ((v + 0.055) / 1.055) ** 2.4
def _relative_luminance(r: int, g: int, b: int) -> float:
    """Return the WCAG relative luminance of an 8-bit sRGB colour.

    Each channel is linearised with _srgb_to_linear and weighted
    0.2126 / 0.7152 / 0.0722 (R / G / B).
    """
    weights = (0.2126, 0.7152, 0.0722)
    return sum(w * _srgb_to_linear(channel) for w, channel in zip(weights, (r, g, b)))
def _wcag_contrast(rgb1: Tuple[int, int, int], rgb2: Tuple[int, int, int]) -> float:
    """Return the WCAG contrast ratio (L_light + 0.05) / (L_dark + 0.05).

    The argument order does not matter; the result is at least 1.0.
    """
    dark, light = sorted((_relative_luminance(*rgb1), _relative_luminance(*rgb2)))
    return (light + 0.05) / (dark + 0.05)
def _xyz_to_lab(x: float, y: float, z: float) -> Tuple[float, float, float]:
x, y, z = x / 0.95047, y / 1.00000, z / 1.08883
def f(t: float) -> float:
return t ** (1 / 3) if t > 0.008856 else 7.787 * t + 16 / 116
return 116 * f(y) - 16, 500 * (f(x) - f(y)), 200 * (f(y) - f(z))
def _rgb_to_lab(r: int, g: int, b: int) -> Tuple[float, float, float]:
    """Convert an 8-bit sRGB colour to CIE L*a*b* (via linear RGB and XYZ/D65)."""
    linear = (_srgb_to_linear(r), _srgb_to_linear(g), _srgb_to_linear(b))
    # Rows of the linear-sRGB -> XYZ (D65) matrix.
    srgb_to_xyz = (
        (0.4124564, 0.3575761, 0.1804375),
        (0.2126729, 0.7151522, 0.0721750),
        (0.0193339, 0.1191920, 0.9503041),
    )
    x, y, z = (sum(m * c for m, c in zip(row, linear)) for row in srgb_to_xyz)
    return _xyz_to_lab(x, y, z)
def _delta_e(rgb1: Tuple[int, int, int], rgb2: Tuple[int, int, int]) -> float:
    """Return the CIE76 colour difference: Euclidean distance in L*a*b* space."""
    lab_a = _rgb_to_lab(*rgb1)
    lab_b = _rgb_to_lab(*rgb2)
    return math.sqrt(sum((p - q) ** 2 for p, q in zip(lab_a, lab_b)))
def _closest_brand_delta_e(rgb: Tuple[int, int, int]) -> float:
    """Return the smallest ΔE* between *rgb* and any colour in the brand palette."""
    distances = [_delta_e(rgb, swatch) for swatch in DESIGN_TOKENS["palette_rgb"].values()]
    return min(distances)
# ---------------------------------------------------------------------------
# JavaScript snippets executed inside the headless page
# ---------------------------------------------------------------------------
_JS_GET_TEXT_COLORS = """
() => {
function parseRgb(s) {
const m = s.match(/rgba?\\((\\d+),\\s*(\\d+),\\s*(\\d+)/);
return m ? [+m[1], +m[2], +m[3]] : null;
}
const sel = 'h1,h2,h3,h4,h5,h6,p,span,a,button,label,li,td,th,caption,small,strong,em';
const result = [];
document.querySelectorAll(sel).forEach(el => {
const s = window.getComputedStyle(el);
const fg = parseRgb(s.color);
const bg = parseRgb(s.backgroundColor);
if (fg && bg) result.push({fg, bg, tag: el.tagName});
});
return result;
}
"""
_JS_GET_SPACING_VIOLATIONS = """
() => {
const GRID = 8;
const violations = [];
let total = 0;
document.querySelectorAll('*').forEach(el => {
const s = window.getComputedStyle(el);
const props = [
'paddingTop','paddingRight','paddingBottom','paddingLeft',
'marginTop','marginRight','marginBottom','marginLeft',
];
if (s.display === 'flex' || s.display === 'grid') {
props.push('gap','rowGap','columnGap');
}
props.forEach(p => {
const v = parseFloat(s[p]);
if (!isNaN(v) && v > 0) {
total++;
if (v % GRID !== 0) violations.push({tag: el.tagName, prop: p, value: v});
}
});
});
return {violations, total};
}
"""
_JS_GET_A11Y_ISSUES = """
() => {
const issues = [];
// Images without alt attribute
document.querySelectorAll('img').forEach(img => {
if (!img.hasAttribute('alt'))
issues.push('img missing alt attribute');
});
// Buttons without accessible name
document.querySelectorAll('button').forEach(btn => {
const name = (btn.textContent || '').trim()
|| btn.getAttribute('aria-label')
|| btn.getAttribute('title');
if (!name) issues.push('button missing accessible name (add text or aria-label)');
});
// Inputs without label
const inputSel = 'input:not([type=hidden]):not([type=submit]):not([type=button]):not([type=reset])';
document.querySelectorAll(inputSel).forEach(input => {
const id = input.id;
const hasLabel = id && document.querySelector('label[for="' + id + '"]');
const hasAria = input.getAttribute('aria-label') || input.getAttribute('aria-labelledby');
if (!hasLabel && !hasAria)
issues.push('input[type=' + (input.type || 'text') + '] missing associated label');
});
return issues;
}
"""
_JS_GET_COMPOSITION = """
() => {
const vw = window.innerWidth, vh = window.innerHeight;
const mx = vw / 2, my = vh / 2;
let lw = 0, rw = 0, tw = 0, bw = 0;
document.querySelectorAll('*').forEach(el => {
const r = el.getBoundingClientRect();
if (r.width === 0 || r.height === 0) return;
const s = window.getComputedStyle(el);
if (s.display === 'none' || s.visibility === 'hidden' || s.opacity === '0') return;
const area = r.width * r.height;
const cx = r.left + r.width / 2;
const cy = r.top + r.height / 2;
if (cx < mx) lw += area; else rw += area;
if (cy < my) tw += area; else bw += area;
});
return {lw, rw, tw, bw};
}
"""
# JS snippet producing a compact outline of the rendered DOM, passed back to
# the agent as `dom_summary`. Starting at <body>, each node records its
# lower-case tag, up to three class names and its children. Only the first six
# children of any element are visited, and nothing deeper than three levels
# below <body> is included (walk() returns null past depth 3 and those
# entries are filtered out). The result is a JSON *string* (JSON.stringify),
# not an object.
_JS_DOM_SUMMARY = """
() => {
function walk(el, depth) {
if (depth > 3) return null;
const children = Array.from(el.children)
.slice(0, 6)
.map(c => walk(c, depth + 1))
.filter(Boolean);
return {
tag: el.tagName.toLowerCase(),
cls: Array.from(el.classList).slice(0, 3).join(' '),
children,
};
}
return JSON.stringify(walk(document.body, 0));
}
"""
# ---------------------------------------------------------------------------
# Playwright evaluator
# ---------------------------------------------------------------------------
class DesignEvaluator:
    """
    Wraps a headless Chromium browser for HTML scoring.

    One instance is created lazily by the environment and reused for every
    step. Call close() to release the browser and the Playwright driver.
    """

    def __init__(self) -> None:
        """Start Playwright and launch headless Chromium.

        Raises:
            ImportError: if the ``playwright`` package is not installed.
            Exception: whatever Playwright raises when Chromium cannot be
                launched. The Playwright driver is stopped before re-raising.
        """
        from playwright.sync_api import sync_playwright  # lazy import

        self._pw = sync_playwright().start()
        try:
            self._browser = self._pw.chromium.launch(headless=True)
        except Exception:
            # Nobody holds a reference to a half-constructed evaluator, so
            # close() could never run: stop the driver here or it leaks.
            self._pw.stop()
            raise

    # ------------------------------------------------------------------
    def evaluate(
        self,
        html: str,
        task: TaskDefinition,
    ) -> Tuple[float, Dict[str, float], List[str], str]:
        """
        Render *html* and return (total_reward, breakdown, violations, dom_summary).

        total_reward is the unweighted mean of the four dimension scores
        (each in [0, 1]), rounded to 4 decimals. breakdown holds each score
        rounded to 3 decimals. dom_summary is a JSON string, or "" if it
        could not be produced. *task* is currently unused by the scorer.
        """
        page = self._browser.new_page(viewport={"width": 1280, "height": 720})
        try:
            page.set_content(html, wait_until="domcontentloaded")
            r_branding, brand_violations = self._score_branding(page)
            r_spacing, spacing_violations = self._score_spacing(page)
            r_a11y, a11y_violations = self._score_a11y(page)
            r_composition, comp_violations = self._score_composition(page)
            breakdown = {
                "branding": round(r_branding, 3),
                "spacing": round(r_spacing, 3),
                "a11y": round(r_a11y, 3),
                "composition": round(r_composition, 3),
            }
            total = (r_branding + r_spacing + r_a11y + r_composition) / 4.0
            violations = brand_violations + spacing_violations + a11y_violations + comp_violations
            try:
                dom_summary = page.evaluate(_JS_DOM_SUMMARY) or ""
            except Exception:
                dom_summary = ""
            return round(total, 4), breakdown, violations, dom_summary
        finally:
            page.close()

    # ------------------------------------------------------------------
    def _score_branding(self, page: Any) -> Tuple[float, List[str]]:
        """
        R_branding – a text element passes only if all of these hold:
          1. WCAG contrast ratio of text vs. background >= min_contrast_ratio.
          2. Text colour is within max_color_delta_e (ΔE*) of a brand colour.
          3. Background colour is within max_color_delta_e of a brand colour.
        Score = passed / total text elements. Every failed sub-check is
        reported, so the agent always learns why a pair failed. If the page
        cannot be inspected, or has no text, the score is 1.0.
        """
        violations: List[str] = []
        try:
            pairs = page.evaluate(_JS_GET_TEXT_COLORS)
        except Exception:
            return 1.0, []
        if not pairs:
            return 1.0, []
        total = len(pairs)
        passed = 0
        max_delta = DESIGN_TOKENS["max_color_delta_e"]
        min_contrast = DESIGN_TOKENS["min_contrast_ratio"]
        for item in pairs:
            fg = tuple(item["fg"])
            bg = tuple(item["bg"])
            tag = item.get("tag", "?")
            contrast = _wcag_contrast(fg, bg)
            fg_delta = _closest_brand_delta_e(fg)
            bg_delta = _closest_brand_delta_e(bg)
            ok_contrast = contrast >= min_contrast
            ok_fg = fg_delta <= max_delta
            ok_bg = bg_delta <= max_delta
            if ok_contrast and ok_fg and ok_bg:
                passed += 1
                continue
            if not ok_contrast:
                violations.append(
                    f"<{tag}> contrast {contrast:.1f}:1 < {min_contrast}:1 "
                    f"(fg={fg}, bg={bg})"
                )
            if not ok_fg:
                violations.append(
                    f"<{tag}> text colour {fg} not in brand palette (ΔE*={fg_delta:.1f})"
                )
            if not ok_bg:
                # Previously a background-only failure was scored but never
                # reported, leaving the agent without a hint.
                violations.append(
                    f"<{tag}> background colour {bg} not in brand palette (ΔE*={bg_delta:.1f})"
                )
        score = passed / total if total else 1.0
        return score, violations

    # ------------------------------------------------------------------
    def _score_spacing(self, page: Any) -> Tuple[float, List[str]]:
        """
        R_spacing – all non-zero padding/margin/gap values must be multiples of 8.
        Score = 1 - (off-grid values / non-zero values checked), clamped to
        [0, 1]. A page with no non-zero spacing scores 1.0. At most ten
        violations are reported.
        """
        violations: List[str] = []
        try:
            result = page.evaluate(_JS_GET_SPACING_VIOLATIONS)
        except Exception:
            return 1.0, []
        total = result.get("total", 0)
        bad_items = result.get("violations", [])
        if total == 0:
            return 1.0, []
        for item in bad_items[:10]:  # cap reported violations
            violations.append(
                f"<{item['tag']}> {item['prop']}: {item['value']}px is not a multiple of 8"
            )
        score = max(0.0, 1.0 - (len(bad_items) / total))
        return score, violations

    # ------------------------------------------------------------------
    def _score_a11y(self, page: Any) -> Tuple[float, List[str]]:
        """
        R_a11y – checks:
          • All <img> have alt attributes.
          • All <button> have accessible names.
          • All visible <input> have associated <label> or aria-label.
        Score = 1 - (issues / elements checked). Each element contributes at
        most one issue, so the score stays in [0, 1].
        """
        try:
            issues = page.evaluate(_JS_GET_A11Y_ISSUES)
        except Exception:
            return 1.0, []
        # Count all checkable elements in a single round-trip.
        try:
            counts = page.evaluate(
                "() => ({"
                "img: document.querySelectorAll('img').length, "
                "btn: document.querySelectorAll('button').length, "
                "inp: document.querySelectorAll("
                "'input:not([type=hidden]):not([type=submit]):not([type=button]):not([type=reset])'"
                ").length})"
            )
            total_checks = counts["img"] + counts["btn"] + counts["inp"]
        except Exception:
            total_checks = max(len(issues), 1)
        if total_checks == 0:
            return 1.0, []
        score = max(0.0, 1.0 - len(issues) / total_checks)
        return score, list(issues)

    # ------------------------------------------------------------------
    def _score_composition(self, page: Any) -> Tuple[float, List[str]]:
        """
        R_composition – visual weight (element area) should not exceed a
        60/40 split between the left/right halves or the top/bottom halves
        of the viewport. Each axis scores 1.0 up to 60 %, falling linearly
        to 0.0 at 100 %. The result is the mean of the two axes. No
        hero-layout exemption is implemented.
        """
        violations: List[str] = []
        try:
            weights = page.evaluate(_JS_GET_COMPOSITION)
        except Exception:
            return 1.0, []
        lw, rw = weights.get("lw", 0), weights.get("rw", 0)
        tw, bw = weights.get("tw", 0), weights.get("bw", 0)

        def _balance_score(a: float, b: float) -> float:
            total = a + b
            if total == 0:
                return 1.0
            ratio = max(a, b) / total  # 0.5 = perfect, 1.0 = all one side
            if ratio <= 0.60:
                return 1.0
            # linearly penalise from 0.60 → 1.0
            return max(0.0, 1.0 - (ratio - 0.60) / 0.40)

        lr_score = _balance_score(lw, rw)
        tb_score = _balance_score(tw, bw)
        # A score below 1.0 implies a non-zero total, so the divisions are safe.
        if lr_score < 1.0:
            violations.append(
                f"Left/right imbalance: {lw/(lw+rw)*100:.0f}% / {rw/(lw+rw)*100:.0f}% "
                f"(max 60/40 allowed)"
            )
        if tb_score < 1.0:
            violations.append(
                f"Top/bottom imbalance: {tw/(tw+bw)*100:.0f}% / {bw/(tw+bw)*100:.0f}% "
                f"(max 60/40 allowed)"
            )
        score = (lr_score + tb_score) / 2.0
        return score, violations

    # ------------------------------------------------------------------
    def close(self) -> None:
        """Best-effort shutdown of the browser and the Playwright driver.

        Each resource is released independently, so a failure closing the
        browser still stops the driver. Safe to call more than once.
        """
        browser = getattr(self, "_browser", None)
        pw = getattr(self, "_pw", None)
        self._browser = None
        self._pw = None
        if browser is not None:
            try:
                browser.close()
            except Exception:
                pass  # best-effort: still stop the driver below
        if pw is not None:
            try:
                pw.stop()
            except Exception:
                pass  # best-effort during shutdown
# ---------------------------------------------------------------------------
# Environment
# ---------------------------------------------------------------------------
class HtmlDesignAgentEnvironment(Environment):
    """
    RL environment for HTML design quality.

    On reset() the agent receives the task's broken HTML as the initial
    observation. On step() it submits improved HTML and gets a scored
    observation back. Episode ends when R >= done_threshold or after
    MAX_STEPS steps.

    The task comes from the HTML_DESIGN_AGENT_TASK environment variable,
    which is re-read on every reset(). A missing or unknown id falls back to
    "level1_accessibility".
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True
    MAX_STEPS: int = 20

    def __init__(self) -> None:
        # Assigned first so __del__ is safe even if later initialisation fails.
        # NOTE(review): Environment.__init__ is not called (as before) —
        # confirm the base class needs no initialisation.
        self._evaluator: Optional[DesignEvaluator] = None
        self._task: TaskDefinition = self._resolve_task()
        self._state = State(episode_id=str(uuid4()), step_count=0)

    # ------------------------------------------------------------------
    @staticmethod
    def _resolve_task() -> TaskDefinition:
        """Return the task named by HTML_DESIGN_AGENT_TASK, defaulting to level 1."""
        task_id = os.getenv("HTML_DESIGN_AGENT_TASK", "level1_accessibility")
        return TASKS.get(task_id, TASKS["level1_accessibility"])

    # ------------------------------------------------------------------
    def _get_evaluator(self) -> DesignEvaluator:
        """Create the headless-browser evaluator on first use and reuse it after."""
        if self._evaluator is None:
            self._evaluator = DesignEvaluator()
        return self._evaluator

    # ------------------------------------------------------------------
    def reset(self) -> HtmlDesignAgentObservation:
        """Start a new episode. Returns the task's broken HTML without scoring it."""
        # Re-read task in case env var changed between episodes
        self._task = self._resolve_task()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        return HtmlDesignAgentObservation(
            task_id=self._task.task_id,
            current_html=self._task.broken_html,
            design_tokens=_serialisable_tokens(),
            reward_breakdown={},
            violations=[],
            dom_summary="",
            step_count=0,
            done=False,
            reward=0.0,
        )

    # ------------------------------------------------------------------
    def step(self, action: HtmlDesignAgentAction) -> HtmlDesignAgentObservation:  # type: ignore[override]
        """Score ``action.html`` and return the resulting observation.

        done is True once the reward reaches the task's done_threshold or the
        step counter reaches MAX_STEPS. Errors raised by the evaluator (for
        example, Playwright failing to start) propagate to the caller.
        """
        self._state.step_count += 1
        step = self._state.step_count
        evaluator = self._get_evaluator()
        reward, breakdown, violations, dom_summary = evaluator.evaluate(
            action.html, self._task
        )
        done = reward >= self._task.done_threshold or step >= self.MAX_STEPS
        return HtmlDesignAgentObservation(
            task_id=self._task.task_id,
            current_html=action.html,
            design_tokens=_serialisable_tokens(),
            reward_breakdown=breakdown,
            violations=violations,
            dom_summary=dom_summary,
            step_count=step,
            done=done,
            reward=reward,
            metadata={
                "step": step,
                "task_id": self._task.task_id,
                "breakdown": breakdown,
            },
        )

    # ------------------------------------------------------------------
    @property
    def state(self) -> State:
        """Current episode id and step count."""
        return self._state

    # ------------------------------------------------------------------
    def __del__(self) -> None:
        # getattr: __init__ may have failed before _evaluator was assigned.
        evaluator = getattr(self, "_evaluator", None)
        if evaluator is not None:
            try:
                evaluator.close()
            except Exception:
                pass  # never raise from a finaliser
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _serialisable_tokens() -> Dict[str, Any]:
"""Return design tokens without the rgb tuples (not JSON-serialisable)."""
tokens = {k: v for k, v in DESIGN_TOKENS.items() if k != "palette_rgb"}
return tokens