layoutenv / server /layout_environment.py
Ryz3n758's picture
Upload folder using huggingface_hub
a0360ae verified
Raw
History Blame Contribute Delete
24.7 kB
"""
Layout Environment Implementation.
An RL environment for iteratively refining UI poster layouts.
The agent receives a layout and must improve it using discrete actions
(MOVE, RESIZE, ALIGN, SNAP, NO_OP).
Perturbations are the responsibility of the caller (e.g. inference.py);
this environment is agnostic to how the initial layout was produced.
"""
from __future__ import annotations
import base64
import io
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Set
from uuid import uuid4
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from openenv.core.env_server.interfaces import Environment
try:
from ..models import ACTIONS, MAGNITUDES, LayoutAction, LayoutObservation, LayoutState
except (ImportError, ModuleNotFoundError):
from models import ACTIONS, MAGNITUDES, LayoutAction, LayoutObservation, LayoutState
from .metrics import (
_axis_value,
_to_ltrb,
compute_all_metrics,
quality_score,
)
# Names of metrics that need background content in addition to layout
# geometry. In this file they are only activated by
# LayoutEnvironment._active_content_metric_names, i.e. in "vlm" mode with a
# saliency map loaded.
CONTENT_AWARE_METRICS: Set[str] = {"occlusion"}
# Single baked-in training-free layout (normalised bboxes).
# Training code should load the full dataset and pass ``sample=`` into ``reset``.
# Each ``bbox`` is unpacked as [x1, y1, x2, y2] by _bbox_to_centre, and
# ``font_size`` is divided by the canvas height (``canvas_size[1]``) when the
# sample is converted to the internal element list.
DEFAULT_LAYOUT_SAMPLE: Dict[str, Any] = {
    "id": 0,
    "canvas_size": [3556, 2000],
    "elements": [
        {
            "type": "Title",
            "text": "Demo",
            "bbox": [0.2, 0.15, 0.8, 0.25],
            "font_size": 120.0,
        },
        {
            "type": "Bodytext",
            "text": "Stateless default episode",
            "bbox": [0.15, 0.4, 0.85, 0.55],
            "font_size": 90.0,
        },
        {
            "type": "Website",
            "text": "example.com",
            "bbox": [0.35, 0.85, 0.65, 0.92],
            "font_size": 48.0,
        },
    ],
}
def _bbox_to_centre(bbox: List[float]) -> Dict[str, float]:
    """Convert a corner box ``[x1, y1, x2, y2]`` into centre/size form.

    Returns a dict holding the centre point (``cx``, ``cy``) and the extent
    (``w``, ``h``) of the box, in the same units as the input.
    """
    left, top, right, bottom = bbox
    centre_x = (left + right) / 2
    centre_y = (top + bottom) / 2
    width = right - left
    height = bottom - top
    return dict(cx=centre_x, cy=centre_y, w=width, h=height)
def _default_stats_from_sample(sample: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build one Gaussian plausibility prior per element type found in sample.

    The mean of each prior is taken from the first element of that type and is
    the vector (cx, cy, w, h, font_size / canvas height). All types share one
    inverse covariance: an isotropic 0.1**2 variance plus a 1e-6 ridge, which
    keeps the prior loose so perturbed layouts still score smoothly.

    Returns a mapping ``type -> {"mu": ndarray(5), "cov_inv": ndarray(5, 5)}``.
    """
    prior_cov = (0.1**2) * np.eye(5)
    ridge = 1e-6 * np.eye(5)
    shared_cov_inv = np.linalg.inv(prior_cov + ridge)

    priors: Dict[str, Any] = {}
    for element in sample.get("elements", []):
        kind = element.get("type")
        # Untyped elements are ignored; only the first element of a type counts.
        if kind and kind not in priors:
            box = _bbox_to_centre(element["bbox"])
            height_px = float(sample["canvas_size"][1])
            raw_font = float(element.get("font_size", 0.0) or 0.0)
            rel_font = raw_font / height_px if height_px > 0 else 0.0
            mean = np.array(
                [box["cx"], box["cy"], box["w"], box["h"], rel_font],
                dtype=np.float64,
            )
            priors[kind] = {"mu": mean, "cov_inv": shared_cov_inv}
    return priors
# Plausibility priors computed once at import time from the built-in sample.
# LayoutEnvironment uses this mapping when it is constructed without ``stats``.
DEFAULT_STATS: Dict[str, Any] = _default_stats_from_sample(DEFAULT_LAYOUT_SAMPLE)
def _sample_to_elements(sample: Dict) -> List[Dict]:
    """Turn a dataset sample into the environment's internal element list.

    Each output element carries its positional ``id``, ``type``, ``text``,
    centre/size geometry (``cx``, ``cy``, ``w``, ``h``) and a ``font_size``
    expressed as a fraction of the canvas height (0.0 when the height is not
    positive or the font size is missing).
    """
    _width_px, height_px = sample["canvas_size"]

    def build(index: int, raw: Dict) -> Dict:
        box = _bbox_to_centre(raw["bbox"])
        raw_font = float(raw.get("font_size", 0.0) or 0.0)
        return {
            "id": index,
            "type": raw.get("type", "unknown"),
            "text": raw.get("text", ""),
            "cx": box["cx"],
            "cy": box["cy"],
            "w": box["w"],
            "h": box["h"],
            "font_size": raw_font / height_px if height_px > 0 else 0.0,
        }

    return [build(i, raw) for i, raw in enumerate(sample.get("elements", []))]
# Action application
def _apply_action(
    elements: List[Dict],
    action: LayoutAction,
) -> None:
    """Apply one action to its target element, mutating ``elements`` in place.

    MOVE shifts the centre, RESIZE changes width/height (then clamps both to
    [0.01, 1.0]), ALIGN delegates to _apply_align and SNAP rounds the centre
    to a 0.05 grid. Any other action, or an unrecognised param, changes
    nothing. The step size comes from MAGNITUDES, defaulting to "MEDIUM".
    """
    index = action.element_id
    kind = action.action
    option = action.param
    step = MAGNITUDES.get(action.magnitude, MAGNITUDES["MEDIUM"])
    target = elements[index]

    if kind == "MOVE":
        moves = {
            "UP": ("cy", -step),
            "DOWN": ("cy", step),
            "LEFT": ("cx", -step),
            "RIGHT": ("cx", step),
        }
        if option in moves:
            key, offset = moves[option]
            target[key] += offset
        return

    if kind == "RESIZE":
        resizes = {
            "WIDER": ("w", step),
            "NARROWER": ("w", -step),
            "TALLER": ("h", step),
            "SHORTER": ("h", -step),
        }
        if option in resizes:
            key, offset = resizes[option]
            target[key] += offset
        # Keep geometry valid for downstream metric computations.
        for key in ("w", "h"):
            target[key] = max(0.01, min(1.0, target[key]))
        return

    if kind == "ALIGN":
        _apply_align(elements, index, option)
        return

    if kind == "SNAP":
        grid = 0.05
        for key in ("cx", "cy"):
            target[key] = round(target[key] / grid) * grid
# Translates an ALIGN action param into the axis name that _apply_align hands
# to _axis_value and _set_axis_value.
_PARAM_TO_AXIS: Dict[str, str] = {
    "LEFT": "left", "RIGHT": "right", "CENTER_X": "cx",
    "TOP": "top", "BOTTOM": "bottom", "CENTER_Y": "cy",
}
def _apply_align(
    elements: List[Dict],
    eid: int,
    param: str,
    threshold: float = 0.15,
) -> None:
    """Align element ``eid`` along one axis, mutating it in place.

    The element snaps to the closest matching edge/centre of another element
    when that neighbour lies strictly within ``threshold``; otherwise it snaps
    to the corresponding canvas anchor (0.0, 0.5 or 1.0).
    """
    subject = elements[eid]
    neighbours = [e for e in elements if e["id"] != subject["id"]]
    axis = _PARAM_TO_AXIS.get(param, param.lower())
    current = _axis_value(subject, axis)

    nearest_val: Optional[float] = None
    nearest_gap = float("inf")
    for neighbour in neighbours:
        candidate = _axis_value(neighbour, axis)
        gap = abs(current - candidate)
        # Strict comparison: the first of several equally close neighbours wins.
        if gap < nearest_gap:
            nearest_gap, nearest_val = gap, candidate

    if nearest_val is None or not (nearest_gap < threshold):
        canvas_anchors = {
            "left": 0.0, "right": 1.0, "cx": 0.5,
            "top": 0.0, "bottom": 1.0, "cy": 0.5,
        }
        destination = canvas_anchors[axis]
    else:
        destination = nearest_val
    _set_axis_value(subject, axis, destination)
def _set_axis_value(e: Dict, axis: str, val: float) -> None:
    """Move element ``e`` so that its ``axis`` edge or centre sits at ``val``.

    Only ``cx``/``cy`` are rewritten; width and height are untouched. An
    unrecognised axis leaves the element unchanged.
    """
    half_w, half_h = e["w"] / 2, e["h"] / 2

    if axis in ("left", "right", "cx"):
        key, half = "cx", half_w
    elif axis in ("top", "bottom", "cy"):
        key, half = "cy", half_h
    else:
        return

    if axis in ("left", "top"):
        # Leading edge: the centre lies half an extent after the edge.
        e[key] = val + half
    elif axis in ("right", "bottom"):
        # Trailing edge: the centre lies half an extent before the edge.
        e[key] = val - half
    else:
        e[key] = val
# Round helpers
def _round_elements(elements: List[Dict], dp: int = 3) -> List[Dict]:
    """Return observation-ready copies of ``elements`` with rounded floats.

    Each copy keeps ``id`` and ``type`` as-is and rounds ``cx``, ``cy``, ``w``,
    ``h`` and ``font_size`` to ``dp`` decimal places. Other keys (for example
    ``text``) are not carried over. The input elements are not modified.
    """
    numeric_keys = ("cx", "cy", "w", "h", "font_size")
    return [
        {
            "id": item["id"],
            "type": item["type"],
            **{key: round(item[key], dp) for key in numeric_keys},
        }
        for item in elements
    ]
def _resolve_media_path(dataset_json_path: str, relative_path: str) -> Path:
    """
    Locate a media file referenced by a dataset entry.

    ``relative_path`` (e.g. images/0_bg.png) is joined onto the directory that
    contains the dataset JSON, after that JSON path has been made absolute.
    This supports volume-mounted datasets when the server container can access
    the dataset path. The result is not checked for existence.
    """
    dataset_dir = Path(dataset_json_path).resolve().parent
    return dataset_dir.joinpath(relative_path)
def _open_background_rgba(
    bg_path: str | Path | None,
    bg_pil: Image.Image | None,
) -> Image.Image:
    """Return the RGBA canvas to draw on: bg_pil, the file at bg_path, or a placeholder."""
    if bg_pil is not None:
        return bg_pil.convert("RGBA")
    if bg_path is not None:
        path = Path(bg_path)
        if path.is_file():
            with Image.open(path) as img:
                return img.convert("RGBA")
    # No usable background: neutral light-grey 1024x1024 placeholder.
    return Image.new("RGBA", (1024, 1024), (245, 245, 245, 255))


def _element_pixel_rect(
    element: Dict,
    w_px: int,
    h_px: int,
) -> tuple[int, int, int, int]:
    """Map a normalised (cx, cy, w, h) element to a clamped, non-empty pixel box."""
    cx, cy = float(element["cx"]), float(element["cy"])
    ew, eh = float(element["w"]), float(element["h"])
    x1 = int((cx - ew / 2) * w_px)
    y1 = int((cy - eh / 2) * h_px)
    x2 = int((cx + ew / 2) * w_px)
    y2 = int((cy + eh / 2) * h_px)
    # Clamp every corner onto the canvas.
    x1 = max(0, min(x1, w_px - 1))
    y1 = max(0, min(y1, h_px - 1))
    x2 = max(0, min(x2, w_px - 1))
    y2 = max(0, min(y2, h_px - 1))
    # Degenerate boxes are widened to one pixel where the canvas allows it.
    if x2 <= x1:
        x2 = min(w_px - 1, x1 + 1)
    if y2 <= y1:
        y2 = min(h_px - 1, y1 + 1)
    return x1, y1, x2, y2


def _load_label_font(size: int) -> ImageFont.FreeTypeFont | ImageFont.ImageFont:
    """Load DejaVuSans at ``size``, falling back to Pillow's default font."""
    candidates = (
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
        "DejaVuSans.ttf",
    )
    for candidate in candidates:
        try:
            return ImageFont.truetype(candidate, size)
        except OSError:
            continue
    return ImageFont.load_default()


def _render_layout_on_background(
    bg_path: str | Path | None,
    elements: List[Dict],
    bg_pil: Image.Image | None = None,
) -> Image.Image:
    """
    Draw normalized layout (cx, cy, w, h in [0, 1]) on top of the background.

    Filled rectangles use distinct colors; label = type and truncated text.

    Args:
        bg_path: Background image file. If None or not an existing file, a
            neutral placeholder canvas is used.
        elements: Elements to draw; each needs ``cx``, ``cy``, ``w`` and ``h``
            and may carry ``type`` and ``text`` for the label.
        bg_pil: An already-loaded PIL image (e.g. decoded from base64). When
            given it takes precedence over ``bg_path`` so the environment can
            work without filesystem access.

    Returns:
        An RGB image of the background with translucent element boxes and
        centred labels, flattened onto white.
    """
    base = _open_background_rgba(bg_path, bg_pil)
    w_px, h_px = base.size

    # Pixel boxes are needed for both the fill pass and the label pass, so
    # they are computed once and shared.
    rects = [_element_pixel_rect(e, w_px, h_px) for e in elements]

    overlay = Image.new("RGBA", (w_px, h_px), (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)
    palette = [
        (255, 99, 71, 90),
        (60, 179, 113, 90),
        (65, 105, 225, 90),
        (238, 130, 238, 90),
        (255, 215, 0, 90),
        (0, 206, 209, 90),
        (255, 140, 0, 90),
        (147, 112, 219, 90),
    ]
    line_w = max(1, min(w_px, h_px) // 100)
    for i, rect in enumerate(rects):
        fill = palette[i % len(palette)]
        outline = (*fill[:3], 255)
        draw.rectangle(list(rect), fill=fill, outline=outline, width=line_w)

    # Boxes live on a separate transparent layer so their alpha blends with
    # the background; labels are then drawn opaquely on the blended result.
    composed = Image.alpha_composite(base, overlay)
    d2 = ImageDraw.Draw(composed)
    font = _load_label_font(max(8, min(w_px, h_px) // 18))
    for e, (x1, y1, x2, y2) in zip(elements, rects):
        raw_text = str(e.get("text", "") or "").strip()
        label = str(e.get("type", "unknown") or "unknown")
        if raw_text:
            label = f"{label}: {raw_text}"
        if len(label) > 48:
            label = label[:45] + "..."
        tb = d2.textbbox((0, 0), label, font=font)
        tw, th = tb[2] - tb[0], tb[3] - tb[1]
        # Centre the label in the box, but keep at least a 2px inset.
        tx = x1 + max(2, (x2 - x1 - tw) // 2)
        ty = y1 + max(2, (y2 - y1 - th) // 2)
        d2.text(
            (tx, ty),
            label,
            font=font,
            fill=(255, 255, 255, 255),
            stroke_width=max(1, line_w // 2),
            stroke_fill=(0, 0, 0, 255),
        )

    # ``composed`` is always RGBA here (the base is converted to RGBA on every
    # path), so flatten it onto white using its own alpha channel as the mask.
    rgb = Image.new("RGB", composed.size, (255, 255, 255))
    rgb.paste(composed, mask=composed.split()[3])
    return rgb
def _safe_load_saliency_array(path: Path) -> np.ndarray | None:
    """
    Best-effort saliency loader. Returns None on missing/invalid content.

    Expected input is a ``.npy`` file holding a 2D float-compatible array.

    Args:
        path: Location of the saved array.

    Returns:
        A non-negative float32 2D array with NaN/inf replaced by 0.0, or None
        when the file is missing, unreadable, not a plain array (e.g. an
        ``.npz`` archive), not convertible to float, not 2D, or empty.
    """
    if not path.is_file():
        return None
    try:
        # Pickled payloads are refused explicitly: the path may come from a
        # client-supplied sample.
        loaded = np.load(path, allow_pickle=False)
    except Exception:
        # Deliberately broad: any read/parse failure means "no saliency map".
        return None
    if not isinstance(loaded, np.ndarray):
        # .npz archives come back as a lazy container holding an open file
        # handle; release it instead of leaking it.
        close = getattr(loaded, "close", None)
        if callable(close):
            close()
        return None
    try:
        arr = np.asarray(loaded, dtype=np.float32)
    except (TypeError, ValueError):
        # Non-numeric content (e.g. a string array) is invalid, not fatal.
        return None
    if arr.ndim != 2 or arr.size == 0:
        return None
    arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
    return np.clip(arr, 0.0, None)
def _pil_to_png_base64(img: Image.Image) -> str:
    """Encode ``img`` as an optimised PNG and return it as a base64 ASCII string."""
    with io.BytesIO() as stream:
        img.save(stream, format="PNG", optimize=True)
        png_bytes = stream.getvalue()
    return str(base64.b64encode(png_bytes), "ascii")
def _generate_text_feedback(
    delta_q: float,
    metrics: Dict[str, float],
    elements: List[Dict],
) -> str:
    """
    Build a short, actionable hint from the latest quality change and metrics.

    The first sentence reports whether quality improved, dropped or barely
    moved (|delta_q| <= 0.01). It is followed by at most one suggestion: the
    largest penalty (overlap, boundary, occlusion) when it exceeds 0.05,
    otherwise the weakest reward (alignment, spacing) when it is below 0.5.
    Missing metrics default to 0.0 for penalties and 1.0 for rewards.
    """
    if delta_q > 0.01:
        headline = f"Quality improved by +{delta_q:.3f}. Keep going."
    elif delta_q < -0.01:
        headline = f"Quality dropped by {delta_q:.3f}. Undo or try a different action."
    else:
        headline = "Negligible change. Try a different element or direction."

    overlap = metrics.get("overlap", 0.0)
    boundary = metrics.get("boundary", 0.0)
    alignment = metrics.get("alignment", 1.0)
    spacing = metrics.get("spacing", 1.0)
    occlusion = metrics.get("occlusion", 0.0)

    # Ties resolve to the earliest entry, so listing order matters.
    penalty_name, penalty_val = max(
        [("overlap", overlap), ("boundary", boundary), ("occlusion", occlusion)],
        key=lambda item: item[1],
    )
    reward_name, reward_val = min(
        [("alignment", alignment), ("spacing", spacing)],
        key=lambda item: item[1],
    )

    hint: Optional[str] = None
    if penalty_val > 0.05:
        if penalty_name == "overlap":
            hint = (
                f"Overlap is high ({overlap:.3f}). "
                "MOVE overlapping elements apart or RESIZE them smaller."
            )
        elif penalty_name == "boundary":
            oob = [item["id"] for item in elements if _is_out_of_bounds(item)]
            if oob:
                hint = (
                    f"Boundary violation ({boundary:.3f}) on element(s) {oob}. "
                    "MOVE them inward or RESIZE them smaller."
                )
            else:
                hint = (
                    f"Boundary penalty ({boundary:.3f}). "
                    "Some elements may be near the edge; MOVE inward."
                )
        else:
            hint = (
                f"Occlusion is high ({occlusion:.3f}). "
                "MOVE/RESIZE elements away from high-saliency regions."
            )
    elif reward_val < 0.5:
        if reward_name == "alignment":
            hint = (
                f"Alignment is low ({alignment:.3f}). "
                "Use ALIGN (CENTER_X, LEFT, etc.) to snap edges together."
            )
        else:
            hint = (
                f"Spacing is uneven ({spacing:.3f}). "
                "MOVE elements to equalise vertical/horizontal gaps."
            )

    if hint is None:
        return headline
    return " ".join((headline, hint))
def _is_out_of_bounds(e: Dict) -> bool:
    """Tell whether any edge of element ``e`` lies outside the unit canvas.

    Edges exactly on 0 or 1 count as inside.
    """
    half_w = e["w"] / 2
    half_h = e["h"] / 2
    left = e["cx"] - half_w
    top = e["cy"] - half_h
    right = e["cx"] + half_w
    bottom = e["cy"] + half_h
    if left < 0 or top < 0:
        return True
    return right > 1 or bottom > 1
# Environment
# Reward shaping constants used by LayoutEnvironment.step. The raw reward
# assembled from them is squashed by _normalize_visible_reward before it is
# placed in the observation.
# Added to the raw reward when the submitted action fails validation.
INVALID_ACTION_PENALTY = -0.5
# Added to the raw reward on every step, valid or not.
STEP_PENALTY = -0.05
# Multiplier applied to the step-to-step quality change of a valid action.
REWARD_SCALE = 10.0
# Added on the final step when quality rose by at least Q_DELTA_THRESHOLD
# over the episode's initial quality.
TERMINAL_BONUS_SCALE = 5.0
# Added on the final step otherwise.
TERMINAL_PENALTY = -1.0
# Align terminal shaping with the easiest grader delta threshold.
Q_DELTA_THRESHOLD = 0.15
def _normalize_visible_reward(raw_reward: float | int) -> float:
    """Squash an unbounded raw reward into the unit interval.

    The value is passed through tanh (after dividing by 5.0) and rescaled from
    [-1, 1] to [0, 1], so a raw reward of 0 maps to 0.5. The result is rounded
    to four decimal places; no separate clamping is performed here.
    """
    from math import tanh

    # Dividing by 5.0 puts a raw reward of about 5.0 (the terminal bonus) at
    # roughly 0.88 on the output scale.
    scaled = float(raw_reward) / 5.0
    unit_interval = (tanh(scaled) + 1.0) / 2.0
    return round(unit_interval, 4)
class LayoutEnvironment(Environment):
    """
    An RL environment for layout refinement.

    The caller is responsible for producing the initial layout (e.g. by
    perturbing a ground-truth sample) and passing it via reset(sample=...).

    Args:
        max_steps: Maximum actions per episode.
        weights: Optional metric weight overrides for Q.
        stats: Plausibility metric config (e.g. loaded from *_stats.npy);
            immutable for the lifetime of this env instance. If omitted,
            DEFAULT_STATS (derived from DEFAULT_LAYOUT_SAMPLE) is used.
    """

    # This environment stores episode-specific fields on the instance.
    # Do not advertise shared-instance concurrent session safety.
    SUPPORTS_CONCURRENT_SESSIONS: bool = False

    def __init__(
        self,
        max_steps: int = 500,
        weights: Optional[Dict[str, float]] = None,
        stats: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        self._state = LayoutState(episode_id=str(uuid4()), step_count=0)
        self._max_steps = max_steps
        self._weights = weights
        self._stats: Dict[str, Any] = (
            DEFAULT_STATS if stats is None else stats
        )
        # The three settings below persist across episodes until a later
        # reset() overrides them.
        self._mode: Literal["llm", "vlm"] = "llm"
        self._text_feedback: bool = True
        self._render_image_in_observation: bool = True
        self._saliency_map: np.ndarray | None = None

    def _active_content_metric_names(self) -> Set[str]:
        """Return the content-aware metrics that apply to the current episode."""
        # Global policy: content-aware metrics are VLM-only AND require a valid saliency map.
        if self._mode == "vlm" and self._saliency_map is not None:
            # Hand out a copy so callers cannot mutate the module-level set.
            return set(CONTENT_AWARE_METRICS)
        return set()

    def _evaluate_layout(self) -> tuple[Dict[str, float], float]:
        """Compute all metrics and the quality score for the current elements."""
        metrics = compute_all_metrics(
            self._state.elements,
            self._stats,
            saliency_map=self._saliency_map,
            content_metric_names=self._active_content_metric_names(),
        )
        return metrics, quality_score(metrics, self._weights)

    def _terminal_reward(self, q: float) -> float:
        """Return the end-of-episode bonus or penalty for final quality ``q``."""
        improvement = q - self._state.initial_quality
        if improvement >= Q_DELTA_THRESHOLD:
            return TERMINAL_BONUS_SCALE
        return TERMINAL_PENALTY

    def _build_observation(
        self,
        step_num: int,
        done: bool,
        reward: float | int,
        metrics: Dict,
        q: float,
    ) -> LayoutObservation:
        """
        Assemble the observation for the current state.

        Must be called before ``previous_quality`` is advanced to ``q``: the
        text feedback is derived from ``q - previous_quality``. In "vlm" mode
        the observation also carries the sample's image path and, when
        rendering is enabled, a base64 PNG of the layout drawn over the
        background.
        """
        image_path: Optional[str] = None
        rendered_b64: Optional[str] = None
        if self._mode == "vlm":
            image_path = self._state.current_image_rel
        if self._mode == "vlm" and self._render_image_in_observation:
            resolved_bg_path: Optional[Path] = None
            bg_img: Image.Image | None = None
            # An inline base64 background (stored on the state by reset) wins
            # over a path-based one.
            inline_b64 = getattr(self._state, "_bg_image_base64", None)
            if inline_b64:
                with Image.open(io.BytesIO(base64.b64decode(inline_b64))) as decoded:
                    bg_img = decoded.convert("RGBA")
            elif self._state.current_image_rel and self._state.dataset_json_path:
                resolved_bg_path = _resolve_media_path(
                    self._state.dataset_json_path, self._state.current_image_rel
                )
                if resolved_bg_path.is_file():
                    with Image.open(resolved_bg_path) as loaded:
                        bg_img = loaded.convert("RGBA")
            rendered = _render_layout_on_background(
                resolved_bg_path, self._state.elements, bg_pil=bg_img
            )
            rendered_b64 = _pil_to_png_base64(rendered)

        delta_q = q - self._state.previous_quality
        feedback: Optional[str] = None
        if self._text_feedback:
            if step_num == 0:
                feedback = "Episode started. Choose an element and action."
            else:
                feedback = _generate_text_feedback(delta_q, metrics, self._state.elements)

        return LayoutObservation(
            canvas={"width": 1.0, "height": 1.0},
            elements=_round_elements(self._state.elements),
            metrics=metrics,
            step=step_num,
            max_steps=self._max_steps,
            quality_score=q,
            initial_quality_score=self._state.initial_quality,
            text_feedback=feedback,
            reward=reward,
            done=done,
            image_path=image_path,
            rendered_image_base64=rendered_b64,
        )

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        *,
        sample: Optional[Dict[str, Any]] = None,
        dataset_json_path: Optional[str] = None,
        background_image_base64: Optional[str] = None,
        mode: Optional[Literal["llm", "vlm"]] = None,
        text_feedback: Optional[bool] = None,
        render_image_in_observation: Optional[bool] = None,
        **kwargs: Any,
    ) -> LayoutObservation:
        """
        Start a new episode and return its initial observation.

        Args:
            seed: Accepted for interface compatibility; not used here.
            episode_id: Identifier for the episode; a UUID4 is generated if None.
            sample: Layout to refine; DEFAULT_LAYOUT_SAMPLE is used if None.
            dataset_json_path: Path of the dataset JSON, used to resolve the
                sample's relative image and saliency paths.
            background_image_base64: Inline background image for rendering.
            mode: "llm" or "vlm"; None keeps the current mode.
            text_feedback: Enable/disable text hints; None keeps the current value.
            render_image_in_observation: Enable/disable rendering in "vlm"
                mode; None keeps the current value.
            **kwargs: Ignored.

        Raises:
            ValueError: In "vlm" mode when neither sample['image_path'] nor
                background_image_base64 is supplied.
        """
        # Intentionally avoid touching module-global RNG state here.
        # Seeding happens client-side for perturbation reproducibility.
        effective_mode = self._mode if mode is None else mode
        chosen = sample if sample is not None else DEFAULT_LAYOUT_SAMPLE
        if effective_mode == "vlm" and not chosen.get("image_path") and not background_image_base64:
            raise ValueError(
                "VLM mode requires sample['image_path'] or background_image_base64. "
                "Pass a sample from your dataset on reset."
            )
        elements = _sample_to_elements(chosen)

        # Settings are committed only after the request has been validated and
        # the sample converted, so a rejected reset cannot leave the
        # environment stuck in a mode it was never successfully reset into.
        self._mode = effective_mode
        if text_feedback is not None:
            self._text_feedback = text_feedback
        if render_image_in_observation is not None:
            self._render_image_in_observation = render_image_in_observation

        in_vlm_mode = self._mode == "vlm"
        self._state = LayoutState(
            episode_id=episode_id if episode_id is not None else str(uuid4()),
            step_count=0,
            elements=elements,
            previous_quality=0.0,
            initial_quality=0.0,
            current_image_rel=chosen.get("image_path") if in_vlm_mode else None,
            current_saliency_rel=(
                chosen.get("saliency_image_path") if in_vlm_mode else None
            ),
            dataset_json_path=dataset_json_path,
        )
        if background_image_base64:
            self._state._bg_image_base64 = background_image_base64

        self._saliency_map = None
        if (
            in_vlm_mode
            and self._state.current_saliency_rel
            and self._state.dataset_json_path
        ):
            saliency_abs = _resolve_media_path(
                self._state.dataset_json_path, self._state.current_saliency_rel
            )
            self._saliency_map = _safe_load_saliency_array(saliency_abs)

        metrics, q = self._evaluate_layout()
        self._state.previous_quality = q
        self._state.initial_quality = q
        return self._build_observation(0, False, _normalize_visible_reward(0.0), metrics, q)

    def step(self, action: LayoutAction) -> LayoutObservation:  # type: ignore[override]
        """
        Apply one action and return the resulting observation.

        An invalid action leaves the layout untouched and is penalised. A
        NO_OP ends the episode, as does reaching ``max_steps``. On the final
        step a terminal bonus or penalty is added depending on how far
        quality moved from the episode's initial value. The reward in the
        observation is the raw reward after _normalize_visible_reward.
        """
        # NOTE(review): nothing here rejects calls made after an episode has
        # already finished — confirm whether callers rely on that.
        self._state.step_count += 1
        step_num = self._state.step_count
        reached_limit = step_num >= self._max_steps

        if not action.is_valid(len(self._state.elements)):
            metrics, q = self._evaluate_layout()
            reward = INVALID_ACTION_PENALTY + STEP_PENALTY
            if reached_limit:
                reward += self._terminal_reward(q)
            return self._build_observation(
                step_num, reached_limit, _normalize_visible_reward(reward), metrics, q
            )

        is_noop = action.action == "NO_OP"
        if not is_noop:
            _apply_action(self._state.elements, action)

        metrics, q = self._evaluate_layout()
        delta_q = q - self._state.previous_quality
        done = is_noop or reached_limit
        reward = REWARD_SCALE * delta_q + STEP_PENALTY
        if done:
            reward += self._terminal_reward(q)

        obs = self._build_observation(
            step_num, done, _normalize_visible_reward(reward), metrics, q
        )
        # Advance the baseline only after the observation has used the old one.
        self._state.previous_quality = q
        return obs

    @property
    def state(self) -> LayoutState:
        """The current episode state."""
        return self._state