"""Visual reward: CLIP image-embedding cosine similarity after rendering HTML. Uses openai/clip-vit-base-patch32 on CPU — no GPU required, stays well within the 16 GB HF Spaces memory limit (~1.5 GB total for model + inference). Falls back to PIL pixel-diff if CLIP fails to load. """ from __future__ import annotations import io import logging from typing import Optional from PIL import Image from openenv.server.rewards import extract_html logger = logging.getLogger(__name__) import os as _os _CLIP_MODEL_NAME = _os.path.expanduser("~/models/clip-vit-base-patch32") if _os.path.isdir(_os.path.expanduser("~/models/clip-vit-base-patch32")) else "openai/clip-vit-base-patch32" _clip_model = None _clip_processor = None def _get_clip(): """Lazy singleton — loads CLIP once, reuses across calls.""" global _clip_model, _clip_processor if _clip_model is None: from transformers import CLIPModel, CLIPProcessor logger.info("Loading CLIP model %s …", _CLIP_MODEL_NAME) _clip_model = CLIPModel.from_pretrained(_CLIP_MODEL_NAME) _clip_model.eval() _clip_processor = CLIPProcessor.from_pretrained(_CLIP_MODEL_NAME) logger.info("CLIP model loaded.") return _clip_model, _clip_processor def _render_html(html: str, width: int = 640, height: int = 480) -> Optional[Image.Image]: """Render HTML to a PIL Image using Playwright headless Chromium. Uses full_page=True so the complete page is captured (no viewport cropping). The viewport width is fixed; height auto-expands to fit content. """ try: from playwright.sync_api import sync_playwright with sync_playwright() as p: browser = p.chromium.launch(args=["--no-sandbox", "--disable-dev-shm-usage"]) page = browser.new_page(viewport={"width": width, "height": height}) page.set_content(html, wait_until="networkidle") png_bytes = page.screenshot(full_page=True) browser.close() return Image.open(io.BytesIO(png_bytes)).convert("RGB") except Exception as exc: logger.warning("HTML rendering failed: %s", exc) return None _CLIP_RENORM_THRESHOLD = 0.65 # raw cosine similarity ≤ this → score 0; 1.0 → 1.0 # Renormalisation makes the metric stricter: only pages visually similar to reference # score meaningfully. Blank pages (raw ~0.45) and unstyled pages (raw ~0.75) get pushed # toward 0, while near-perfect matches (raw ~1.0) remain high. def _clip_similarity(img_a: Image.Image, img_b: Image.Image) -> float: """Compute CLIP image-embedding cosine similarity, renormalised to [0, 1].""" import torch model, processor = _get_clip() inputs = processor(images=[img_a, img_b], return_tensors="pt") with torch.no_grad(): out = model.get_image_features(**inputs) # transformers v5 returns a dataclass; v4 returns a plain tensor features = out.pooler_output if hasattr(out, "pooler_output") else out features = features / features.norm(dim=-1, keepdim=True) raw = (features[0] @ features[1]).item() # Renormalise: threshold → 0, 1.0 → 1.0 scale = 1.0 - _CLIP_RENORM_THRESHOLD return float(max(0.0, min(1.0, (raw - _CLIP_RENORM_THRESHOLD) / scale))) def _pil_similarity(img_a: Image.Image, img_b: Image.Image, size: tuple = (128, 128)) -> float: """Fallback: pixel-wise similarity in [0, 1].""" a = img_a.resize(size).convert("RGB") b = img_b.resize(size).convert("RGB") pa = list(a.getdata()) pb = list(b.getdata()) total_diff = sum( abs(int(ra) - int(rb)) + abs(int(ga) - int(gb)) + abs(int(ba) - int(bb)) for (ra, ga, ba), (rb, gb, bb) in zip(pa, pb) ) max_diff = size[0] * size[1] * 3 * 255 return 1.0 - total_diff / max_diff def clip_visual_reward( completions: list[list[dict]], image: Optional[list[Image.Image]] = None, pred_image: Optional[list[Optional[Image.Image]]] = None, ) -> list[float]: """Score visual similarity between rendered HTML and reference screenshot. Renders each completion's HTML with Playwright (unless pred_image is provided), then computes CLIP cosine similarity against the reference. Falls back to PIL pixel-diff if CLIP is unavailable. Returns 0.5 if rendering fails. Args: completions: List of completion message lists. image: List of reference PIL Images (one per completion). pred_image: Optional pre-rendered prediction images (skips rendering when provided — avoids duplicate Playwright launches). Returns: List of float scores in [0.0, 1.0]. """ # Determine similarity function — prefer CLIP, fall back to pixel-diff try: _get_clip() sim_fn = _clip_similarity except Exception as exc: logger.warning("CLIP unavailable, falling back to pixel-diff: %s", exc) sim_fn = _pil_similarity results = [] for i, completion in enumerate(completions): content = completion[0]["content"] html = extract_html(content) ref_image = image[i] if image and i < len(image) else None # Use pre-rendered image if supplied, otherwise render now if pred_image is not None and i < len(pred_image): rendered = pred_image[i] else: rendered = _render_html(html) if rendered is None or ref_image is None: results.append(0.5) continue try: score = sim_fn(rendered, ref_image.convert("RGB")) except Exception as exc: logger.warning("Similarity scoring failed: %s", exc) score = 0.5 results.append(score) return results