from __future__ import annotations import argparse import copy import json import random import re import sys from pathlib import Path from typing import Any, Dict, Iterable, List, Sequence, Tuple ROOT = Path(__file__).resolve().parents[1] sys.path.insert(0, str(ROOT)) try: from models import DesignGymAction from server.DesignGym_environment import DesignGymEnvironment except Exception: from DesignGym.models import DesignGymAction from DesignGym.server.DesignGym_environment import DesignGymEnvironment ID_RE = re.compile(r"([A-Za-z0-9_]+)@\(") TASKS = [ "poster_basic_v1", "editorial_cover_v1", "dense_flyer_v1", ] SYSTEM_PROMPT = ( "You are a long-horizon spatial layout design agent. " "You receive a design brief, current phase, layout state, metrics, and feedback. " "Output exactly one valid minified JSON action object and nothing else." ) def compact_action(action: DesignGymAction) -> str: """Return concise JSON action for SFT target.""" data = action.model_dump(exclude_none=True) # Drop empty/default fields to avoid teaching noisy long JSON. cleaned: Dict[str, Any] = {} for key, value in data.items(): if key == "element_ids" and value == []: continue if key in {"dx", "dy", "dw", "dh", "strength"} and float(value) == 0.0: continue if key == "grid" and int(value) == 0: continue if key == "anchor" and value == "center" and data.get("action_type") != "resize": continue cleaned[key] = value return json.dumps(cleaned, sort_keys=True, separators=(",", ":")) def ids_in_layout(obs: Any) -> List[str]: return ID_RE.findall(getattr(obs, "layout_summary", "") or "") def has_id(obs: Any, element_id: str) -> bool: return element_id in ids_in_layout(obs) def task_kind(task_id: str) -> str: if "editorial" in task_id: return "editorial" if "dense" in task_id: return "dense" return "poster" def make_template_actions(task_id: str) -> List[DesignGymAction]: kind = task_kind(task_id) if kind == "poster": return [ DesignGymAction(action_type="apply_template", template_id="hero"), DesignGymAction(action_type="apply_template", template_id="split"), ] if kind == "editorial": return [ DesignGymAction(action_type="apply_template", template_id="editorial"), DesignGymAction(action_type="apply_template", template_id="grid"), ] return [ DesignGymAction(action_type="apply_template", template_id="grid"), DesignGymAction(action_type="apply_template", template_id="hero"), ] def candidate_actions(obs: Any, recent_actions: Sequence[str]) -> List[DesignGymAction]: task_id = getattr(obs, "task_id", "") kind = task_kind(task_id) phase = getattr(obs, "phase", "refinement") worst = set(getattr(obs, "worst_metrics", []) or []) brief = getattr(obs, "brief", {}) or {} required_regions = brief.get("required_regions", {}) or {} instruction = float(getattr(obs, "instruction_score", 0.0) or 0.0) actions: List[DesignGymAction] = [] # Structure phase: make global layout choice early. if int(getattr(obs, "step_count", 0) or 0) == 0 or phase == "structure": actions.extend(make_template_actions(task_id)) # Placement: satisfy brief-required regions. priority = [ "cta", "price_badge", "hero_image", "masthead", "title", "subtitle", "headline_1", "headline_2", "headline_3", "details", "sponsor_strip", "logo", ] if instruction < 0.85 or phase in {"placement", "structure"}: for element_id in priority: if element_id in required_regions and has_id(obs, element_id): actions.append( DesignGymAction( action_type="anchor_to_region", element_id=element_id, region_id=str(required_regions[element_id]), mode="center", ) ) # Occupancy / text fit repair. if "occupancy" in worst or "text_fit" in worst: for element_id, dw, dh in [ ("hero_image", 0.03, 0.02), ("details", 0.02, 0.02), ("image_left", 0.02, 0.02), ("image_right", 0.02, 0.02), ("subtitle", 0.02, 0.01), ("headline_2", 0.02, 0.01), ]: if has_id(obs, element_id): actions.append( DesignGymAction( action_type="resize", element_id=element_id, dw=dw, dh=dh, anchor="center", ) ) # Hierarchy repair. if "hierarchy" in worst or phase in {"refinement", "polish"}: for element_id in [ "title", "headline_1", "masthead", "cta", "price_badge", "details", ]: if has_id(obs, element_id): actions.append( DesignGymAction( action_type="promote", element_id=element_id, strength=0.04, ) ) # Alignment repair. if "alignment" in worst or phase in {"placement", "polish"}: # Safe alignment candidates are useful even when alignment is not the worst metric. # This improves SFT coverage for long-horizon refinement behavior. if kind == "poster": ids = [x for x in ["title", "subtitle"] if has_id(obs, x)] if len(ids) >= 2: actions.append( DesignGymAction( action_type="align", element_ids=ids, axis="x", mode="left", ) ) elif kind == "editorial": ids = [x for x in ["masthead", "headline_1", "headline_2"] if has_id(obs, x)] if len(ids) >= 2: actions.append( DesignGymAction( action_type="align", element_ids=ids, axis="x", mode="left", ) ) else: ids = [x for x in ["caption_1", "caption_2"] if has_id(obs, x)] if len(ids) >= 2: actions.append( DesignGymAction( action_type="align", element_ids=ids, axis="y", mode="top", ) ) # Spacing / reading order repair. if "spacing" in worst or "reading_order" in worst or phase == "refinement": if kind == "poster": actions.append(DesignGymAction(action_type="reflow_group", group_id="headline", pattern="stack")) elif kind == "editorial": actions.append(DesignGymAction(action_type="reflow_group", group_id="stories", pattern="stack")) else: actions.append(DesignGymAction(action_type="reflow_group", group_id="support", pattern="row")) # Small polish moves. # Small local movement candidates. # These are important for teaching fine-grained spatial correction. if phase in {"refinement", "polish"} or "balance" in worst or "spacing" in worst: for element_id in [ "hero_image", "title", "subtitle", "cta", "masthead", "headline_1", "headline_2", "details", "price_badge", ]: if has_id(obs, element_id): actions.append( DesignGymAction( action_type="move", element_id=element_id, dx=0.01, dy=-0.01, ) ) actions.append( DesignGymAction( action_type="move", element_id=element_id, dx=-0.01, dy=0.01, ) ) break # Finalize only when plausibly ready. score = float(getattr(obs, "current_score", 0.0) or 0.0) step_count = int(getattr(obs, "step_count", 0) or 0) max_steps = int(getattr(obs, "max_steps", 1) or 1) if step_count >= int(0.75 * max_steps) and score >= 0.72 and instruction >= 0.60: actions.append(DesignGymAction(action_type="finalize")) # Deduplicate. dedup: List[DesignGymAction] = [] seen = set() for action in actions: key = compact_action(action) if key in seen: continue seen.add(key) dedup.append(action) # Avoid repeating identical action if possible. filtered = [a for a in dedup if compact_action(a) not in set(recent_actions[-2:])] return filtered or dedup or [DesignGymAction(action_type="finalize")] def evaluate_candidate(env: DesignGymEnvironment, action: DesignGymAction) -> float: """Score a candidate by simulating it on a copied environment.""" try: tmp = copy.deepcopy(env) obs = tmp.step(action) state = tmp.state if getattr(obs, "last_action_error", None): return -10.0 reward = float(getattr(state, "last_reward", 0.0) or 0.0) layout_score = float(getattr(state, "current_score", 0.0) or 0.0) instruction = float(getattr(state, "instruction_score", 0.0) or 0.0) phase_score = float(getattr(state, "phase_score", 0.0) or 0.0) # Reward is primary. Small tie-breaks prefer better final state. return reward + 0.05 * instruction + 0.03 * layout_score + 0.02 * phase_score except Exception: return -10.0 def preferred_action_type_for_example(episode_idx: int, local_step: int, obs: Any) -> str | None: phase = getattr(obs, "phase", "") bucket = (episode_idx * 31 + local_step * 17) % 100 # Force some safe alignment examples. if phase in {"placement", "refinement", "polish"} and bucket < 12: return "align" # Force some fine-grained movement examples. if phase in {"refinement", "polish"} and 12 <= bucket < 20: return "move" return None def choose_expert_action( env: DesignGymEnvironment, obs: Any, preferred_action_type: str | None = None, ) -> DesignGymAction: recent = list(getattr(env.state, "action_history", []) or []) candidates = candidate_actions(obs, recent) scored = [(a, evaluate_candidate(env, a)) for a in candidates] scored.sort(key=lambda x: x[1], reverse=True) ranked = [a for a, _ in scored] # If this example is scheduled for action diversity, choose the best candidate # of that type, but only if it is not terrible. if preferred_action_type: preferred = [ (a, score) for a, score in scored if a.action_type == preferred_action_type and score > -1.0 ] if preferred: preferred.sort(key=lambda x: x[1], reverse=True) return preferred[0][0] # Normal expert choice with small top-k diversity. if len(ranked) > 1: rng_key = ( f"{getattr(env.state, 'seed', 0)}:" f"{getattr(obs, 'task_id', '')}:" f"{getattr(obs, 'step_count', 0)}:" f"{len(recent)}" ) rng = random.Random(rng_key) if rng.random() < 0.12: top_k = ranked[: min(4, len(ranked))] return rng.choice(top_k) return ranked[0] def prompt_from_obs(obs: Any) -> str: payload = { "task_id": getattr(obs, "task_id", ""), "step_count": getattr(obs, "step_count", 0), "max_steps": getattr(obs, "max_steps", 0), "phase": getattr(obs, "phase", ""), "allowed_actions": getattr(obs, "allowed_actions", []), "current_score": round(float(getattr(obs, "current_score", 0.0) or 0.0), 4), "best_score_so_far": round(float(getattr(obs, "best_score_so_far", 0.0) or 0.0), 4), "instruction_score": round(float(getattr(obs, "instruction_score", 0.0) or 0.0), 4), "phase_score": round(float(getattr(obs, "phase_score", 0.0) or 0.0), 4), "brief": getattr(obs, "brief", {}) or {}, "metrics": getattr(obs, "metrics", {}) or {}, "metric_deltas": getattr(obs, "metric_deltas", {}) or {}, "worst_metrics": getattr(obs, "worst_metrics", []) or [], "focus_elements": getattr(obs, "focus_elements", []) or [], "critic_feedback": getattr(obs, "critic_feedback", []) or [], "layout_summary": getattr(obs, "layout_summary", "") or "", } action_schema = { "action_type": "apply_template | anchor_to_region | resize | move | align | distribute | promote | reflow_group | finalize", "optional_fields": { "template_id": "hero | split | editorial | grid | draft", "element_id": "single element id", "element_ids": "list of element ids", "region_id": "semantic target region", "group_id": "semantic group id", "pattern": "stack | row", "axis": "x | y", "mode": "left | center | top", "dx_dy_dw_dh": "small normalized floats", "strength": "small float for promote", }, } return ( "Choose the next best layout edit.\n" "Output only one valid minified JSON action object.\n\n" f"STATE:\n{json.dumps(payload, sort_keys=True)}\n\n" f"ACTION_SCHEMA:\n{json.dumps(action_schema, sort_keys=True)}" ) def make_example(obs: Any, action: DesignGymAction, metadata: Dict[str, Any]) -> Dict[str, Any]: return { "messages": [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": prompt_from_obs(obs)}, {"role": "assistant", "content": compact_action(action)}, ], **metadata, } def generate_examples( *, episodes: int, seed: int, max_steps_override: int | None, tasks: Sequence[str], ) -> List[Dict[str, Any]]: rng = random.Random(seed) examples: List[Dict[str, Any]] = [] for episode_idx in range(episodes): task_id = tasks[episode_idx % len(tasks)] episode_seed = seed + episode_idx env = DesignGymEnvironment() obs = env.reset(task_id=task_id, seed=episode_seed) max_steps = max_steps_override or int(getattr(obs, "max_steps", 8) or 8) for local_step in range(max_steps): if bool(getattr(obs, "done", False)) or bool(getattr(env.state, "done", False)): break preferred_type = preferred_action_type_for_example(episode_idx, local_step, obs) action = choose_expert_action(env, obs, preferred_action_type=preferred_type) before_obs = obs obs = env.step(action) metadata = { "task_id": task_id, "episode_seed": episode_seed, "episode_index": episode_idx, "step_index": local_step, "expert_action": compact_action(action), "reward_after": round(float(getattr(env.state, "last_reward", 0.0) or 0.0), 6), "score_after": round(float(getattr(env.state, "current_score", 0.0) or 0.0), 6), "instruction_score_after": round(float(getattr(env.state, "instruction_score", 0.0) or 0.0), 6), "phase_after": getattr(env.state, "phase", ""), } examples.append(make_example(before_obs, action, metadata)) rng.shuffle(examples) return examples def write_jsonl(path: Path, rows: Iterable[Dict[str, Any]]) -> int: path.parent.mkdir(parents=True, exist_ok=True) count = 0 with path.open("w", encoding="utf-8") as f: for row in rows: f.write(json.dumps(row, ensure_ascii=False) + "\n") count += 1 return count def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--episodes", type=int, default=300) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--max-steps", type=int, default=None) parser.add_argument("--tasks", nargs="*", default=TASKS) parser.add_argument("--out", type=str, default="data/sft/designgym2_sft_train.jsonl") parser.add_argument("--eval-out", type=str, default="data/sft/designgym2_sft_eval.jsonl") parser.add_argument("--eval-ratio", type=float, default=0.10) args = parser.parse_args() examples = generate_examples( episodes=args.episodes, seed=args.seed, max_steps_override=args.max_steps, tasks=args.tasks, ) split_idx = int(len(examples) * (1.0 - args.eval_ratio)) train_rows = examples[:split_idx] eval_rows = examples[split_idx:] train_count = write_jsonl(Path(args.out), train_rows) eval_count = write_jsonl(Path(args.eval_out), eval_rows) print(f"[OK] generated total={len(examples)} train={train_count} eval={eval_count}") print(f"[TRAIN] {args.out}") print(f"[EVAL] {args.eval_out}") if examples: print("[SAMPLE]") print(json.dumps(examples[0], indent=2)[:3000]) if __name__ == "__main__": main()