# DesignGym/training/generate_sft_data.py
# Uploaded by yashvyasop with huggingface_hub ("Upload folder using huggingface_hub").
# Commit 44c2d9e (verified).
from __future__ import annotations
import argparse
import copy
import json
import random
import re
import sys
from pathlib import Path
from typing import Any, Dict, Iterable, List, Sequence, Tuple
# Put the directory one level above this file's folder on sys.path so the
# flat-layout imports below can resolve when the file is run as a script.
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))

try:
    # Flat layout: modules are importable directly from ROOT.
    from models import DesignGymAction
    from server.DesignGym_environment import DesignGymEnvironment
except ImportError:
    # Packaged layout: the same modules live under the ``DesignGym`` package.
    # Only import failures trigger this fallback; any other exception raised
    # while loading the modules is a real error and is allowed to surface
    # instead of being masked by a second, misleading import failure.
    from DesignGym.models import DesignGymAction
    from DesignGym.server.DesignGym_environment import DesignGymEnvironment
# Captures the identifier (letters, digits, underscores) that directly precedes
# "@(" in a string; ids_in_layout() applies it to an observation's layout summary.
# NOTE(review): presumably summary entries look like `title@(...)` — confirm
# against the environment's summary format.
ID_RE = re.compile(r"([A-Za-z0-9_]+)@\(")
# Default task ids; also the default value of the --tasks CLI option.
TASKS = [
"poster_basic_v1",
"editorial_cover_v1",
"dense_flyer_v1",
]
# System message placed first in every generated chat example.
SYSTEM_PROMPT = (
"You are a long-horizon spatial layout design agent. "
"You receive a design brief, current phase, layout state, metrics, and feedback. "
"Output exactly one valid minified JSON action object and nothing else."
)
def compact_action(action: DesignGymAction) -> str:
    """Serialize *action* as minified, key-sorted JSON without default-valued noise.

    ``None`` fields are excluded by ``model_dump``. On top of that, an empty
    ``element_ids`` list, zero-valued ``dx``/``dy``/``dw``/``dh``/``strength``,
    a zero ``grid`` and a ``"center"`` anchor on any action other than
    ``resize`` are dropped, so the SFT target stays short.
    """
    raw = action.model_dump(exclude_none=True)
    zero_means_unset = {"dx", "dy", "dw", "dh", "strength"}
    is_resize = raw.get("action_type") == "resize"

    def is_noise(name: str, val: Any) -> bool:
        # Each field name is subject to at most one rule.
        if name == "element_ids":
            return val == []
        if name in zero_means_unset:
            return float(val) == 0.0
        if name == "grid":
            return int(val) == 0
        if name == "anchor":
            return val == "center" and not is_resize
        return False

    payload = {name: val for name, val in raw.items() if not is_noise(name, val)}
    return json.dumps(payload, sort_keys=True, separators=(",", ":"))
def ids_in_layout(obs: Any) -> List[str]:
    """Return every identifier that ``ID_RE`` finds in ``obs.layout_summary``.

    A missing or falsy summary is treated as an empty string, which yields an
    empty list.
    """
    summary = getattr(obs, "layout_summary", "") or ""
    return ID_RE.findall(summary)
def has_id(obs: Any, element_id: str) -> bool:
    """Tell whether *element_id* is among the ids parsed from the layout summary."""
    return any(found == element_id for found in ids_in_layout(obs))
def task_kind(task_id: str) -> str:
    """Map a task id to its layout family.

    Ids containing ``"editorial"`` are editorial and ids containing ``"dense"``
    are dense (checked in that order); anything else counts as a poster.
    """
    for marker in ("editorial", "dense"):
        if marker in task_id:
            return marker
    return "poster"
def make_template_actions(task_id: str) -> List[DesignGymAction]:
    """Build the two ``apply_template`` candidates for the task's layout family.

    Poster tasks get ``hero`` then ``split``, editorial tasks get ``editorial``
    then ``grid``, and every other family gets ``grid`` then ``hero``.
    """
    templates_by_kind = {
        "poster": ("hero", "split"),
        "editorial": ("editorial", "grid"),
    }
    template_ids = templates_by_kind.get(task_kind(task_id), ("grid", "hero"))
    return [
        DesignGymAction(action_type="apply_template", template_id=template_id)
        for template_id in template_ids
    ]
def candidate_actions(obs: Any, recent_actions: Sequence[str]) -> List[DesignGymAction]:
    """Propose heuristic candidate actions for the current observation.

    Candidates are collected rule by rule (template choice, region anchoring,
    resize, promote, align, reflow, small moves, finalize), de-duplicated by
    their compact JSON form, and finally filtered so that the two most recent
    actions are not repeated when an alternative exists.

    Args:
        obs: Observation object. Every field is read with ``getattr`` and a
            default, so missing attributes are tolerated.
        recent_actions: Previously taken actions; only the last two entries
            are compared against each candidate's compact JSON.
            NOTE(review): assumes history entries use the same format as
            ``compact_action`` — confirm against the environment.

    Returns:
        A non-empty list: the filtered candidates; or, if filtering removed
        everything, the unfiltered de-duplicated candidates; or, if no rule
        produced anything, a single ``finalize`` action.
    """
    task_id = getattr(obs, "task_id", "")
    kind = task_kind(task_id)
    phase = getattr(obs, "phase", "refinement")
    worst = set(getattr(obs, "worst_metrics", []) or [])
    brief = getattr(obs, "brief", {}) or {}
    required_regions = brief.get("required_regions", {}) or {}
    instruction = float(getattr(obs, "instruction_score", 0.0) or 0.0)
    step_count = int(getattr(obs, "step_count", 0) or 0)

    # Parse the layout summary once; the rules below only need membership
    # tests, so there is no reason to re-run the regex for every probe.
    present = frozenset(ids_in_layout(obs))

    actions: List[DesignGymAction] = []

    # Structure phase: make global layout choice early.
    if step_count == 0 or phase == "structure":
        actions.extend(make_template_actions(task_id))

    # Placement: satisfy brief-required regions, in priority order.
    if instruction < 0.85 or phase in {"placement", "structure"}:
        for element_id in (
            "cta",
            "price_badge",
            "hero_image",
            "masthead",
            "title",
            "subtitle",
            "headline_1",
            "headline_2",
            "headline_3",
            "details",
            "sponsor_strip",
            "logo",
        ):
            if element_id in required_regions and element_id in present:
                actions.append(
                    DesignGymAction(
                        action_type="anchor_to_region",
                        element_id=element_id,
                        region_id=str(required_regions[element_id]),
                        mode="center",
                    )
                )

    # Occupancy / text fit repair: grow the elements that are present.
    if "occupancy" in worst or "text_fit" in worst:
        for element_id, dw, dh in (
            ("hero_image", 0.03, 0.02),
            ("details", 0.02, 0.02),
            ("image_left", 0.02, 0.02),
            ("image_right", 0.02, 0.02),
            ("subtitle", 0.02, 0.01),
            ("headline_2", 0.02, 0.01),
        ):
            if element_id in present:
                actions.append(
                    DesignGymAction(
                        action_type="resize",
                        element_id=element_id,
                        dw=dw,
                        dh=dh,
                        anchor="center",
                    )
                )

    # Hierarchy repair.
    if "hierarchy" in worst or phase in {"refinement", "polish"}:
        for element_id in ("title", "headline_1", "masthead", "cta", "price_badge", "details"):
            if element_id in present:
                actions.append(
                    DesignGymAction(
                        action_type="promote",
                        element_id=element_id,
                        strength=0.04,
                    )
                )

    # Alignment repair.
    # Safe alignment candidates are useful even when alignment is not the worst
    # metric; this improves SFT coverage for long-horizon refinement behavior.
    if "alignment" in worst or phase in {"placement", "polish"}:
        align_specs = {
            "poster": (("title", "subtitle"), "x", "left"),
            "editorial": (("masthead", "headline_1", "headline_2"), "x", "left"),
        }
        names, axis, mode = align_specs.get(kind, (("caption_1", "caption_2"), "y", "top"))
        ids = [name for name in names if name in present]
        # Aligning needs at least two elements to be meaningful.
        if len(ids) >= 2:
            actions.append(
                DesignGymAction(
                    action_type="align",
                    element_ids=ids,
                    axis=axis,
                    mode=mode,
                )
            )

    # Spacing / reading order repair.
    if "spacing" in worst or "reading_order" in worst or phase == "refinement":
        reflow_specs = {
            "poster": ("headline", "stack"),
            "editorial": ("stories", "stack"),
        }
        group_id, pattern = reflow_specs.get(kind, ("support", "row"))
        actions.append(DesignGymAction(action_type="reflow_group", group_id=group_id, pattern=pattern))

    # Small local movement candidates.
    # These are important for teaching fine-grained spatial correction.
    if phase in {"refinement", "polish"} or "balance" in worst or "spacing" in worst:
        for element_id in (
            "hero_image",
            "title",
            "subtitle",
            "cta",
            "masthead",
            "headline_1",
            "headline_2",
            "details",
            "price_badge",
        ):
            if element_id in present:
                # Only the first element found in the layout gets a pair of
                # opposite nudges; the search stops right after it.
                for dx, dy in ((0.01, -0.01), (-0.01, 0.01)):
                    actions.append(
                        DesignGymAction(
                            action_type="move",
                            element_id=element_id,
                            dx=dx,
                            dy=dy,
                        )
                    )
                break

    # Finalize only when plausibly ready.
    score = float(getattr(obs, "current_score", 0.0) or 0.0)
    max_steps = int(getattr(obs, "max_steps", 1) or 1)
    if step_count >= int(0.75 * max_steps) and score >= 0.72 and instruction >= 0.60:
        actions.append(DesignGymAction(action_type="finalize"))

    # Deduplicate on the compact JSON form, keeping the first occurrence.
    # The key is computed once per candidate and reused by the filter below.
    unique: Dict[str, DesignGymAction] = {}
    for action in actions:
        unique.setdefault(compact_action(action), action)
    dedup = list(unique.values())

    # Avoid repeating one of the last two actions if possible.
    recent_keys = tuple(recent_actions[-2:])
    filtered = [action for key, action in unique.items() if key not in recent_keys]
    return filtered or dedup or [DesignGymAction(action_type="finalize")]
def evaluate_candidate(env: DesignGymEnvironment, action: DesignGymAction) -> float:
    """Estimate how good *action* is by applying it to a deep copy of *env*.

    The live environment is never stepped. Returns ``-10.0`` when the
    simulated step reports ``last_action_error`` or when anything raises
    during the simulation. Otherwise returns the post-step reward plus small
    weighted tie-breakers taken from the copied environment's state.
    """
    failure_score = -10.0
    try:
        sandbox = copy.deepcopy(env)
        result = sandbox.step(action)
        after = sandbox.state

        def metric(name: str) -> float:
            # Missing or falsy state fields count as 0.0.
            return float(getattr(after, name, 0.0) or 0.0)

        if getattr(result, "last_action_error", None):
            return failure_score
        # Reward is primary. Small tie-breaks prefer a better resulting state.
        return (
            metric("last_reward")
            + 0.05 * metric("instruction_score")
            + 0.03 * metric("current_score")
            + 0.02 * metric("phase_score")
        )
    except Exception:
        # Best effort: a candidate that cannot be simulated is scored as bad.
        return failure_score
def preferred_action_type_for_example(episode_idx: int, local_step: int, obs: Any) -> str | None:
    """Deterministically schedule some examples to prefer a given action type.

    A slot in ``[0, 100)`` is derived from the episode and step indices.
    Slots below 12 request ``"align"`` in the placement, refinement and polish
    phases; slots 12-19 request ``"move"`` in the refinement and polish
    phases. Every other combination returns ``None`` (no preference).
    """
    phase = getattr(obs, "phase", "")
    if phase not in {"placement", "refinement", "polish"}:
        return None

    slot = (31 * episode_idx + 17 * local_step) % 100
    # Force some safe alignment examples.
    if slot < 12:
        return "align"
    # Force some fine-grained movement examples (never during placement).
    if phase != "placement" and slot < 20:
        return "move"
    return None
def choose_expert_action(
    env: DesignGymEnvironment,
    obs: Any,
    preferred_action_type: str | None = None,
) -> DesignGymAction:
    """Pick the expert action for *obs* by scoring every candidate on *env*.

    Candidates come from ``candidate_actions`` and are ranked by
    ``evaluate_candidate`` (best first). When *preferred_action_type* is given
    and a candidate of that type scores above ``-1.0``, the best such
    candidate wins. Otherwise the top candidate is returned, except that with
    probability 0.12 (drawn from an RNG seeded by environment seed, task id,
    step count and history length) a random pick among the top four is used.
    """
    history = list(getattr(env.state, "action_history", []) or [])
    scored = sorted(
        [(cand, evaluate_candidate(env, cand)) for cand in candidate_actions(obs, history)],
        key=lambda pair: pair[1],
        reverse=True,
    )

    # If this example is scheduled for action diversity, choose the best
    # candidate of that type, but only if it is not terrible.
    if preferred_action_type:
        matching = [
            pair
            for pair in scored
            if pair[0].action_type == preferred_action_type and pair[1] > -1.0
        ]
        if matching:
            # max() keeps the earliest of equally scored candidates.
            return max(matching, key=lambda pair: pair[1])[0]

    ranked = [cand for cand, _ in scored]

    # Normal expert choice with small top-k diversity.
    if len(ranked) > 1:
        env_seed = getattr(env.state, "seed", 0)
        task = getattr(obs, "task_id", "")
        step = getattr(obs, "step_count", 0)
        rng = random.Random(f"{env_seed}:{task}:{step}:{len(history)}")
        if rng.random() < 0.12:
            return rng.choice(ranked[:4])
    return ranked[0]
def prompt_from_obs(obs: Any) -> str:
    """Render the user prompt: instructions, a STATE JSON and an ACTION_SCHEMA JSON.

    All observation fields are read with ``getattr`` and a default. The four
    score fields are rounded to four decimals, and falsy collection/text
    fields are replaced with an empty value of the matching type.
    """

    def read(name: str, fallback: Any) -> Any:
        return getattr(obs, name, fallback)

    def choices(*options: str) -> str:
        return " | ".join(options)

    state: Dict[str, Any] = {
        "task_id": read("task_id", ""),
        "step_count": read("step_count", 0),
        "max_steps": read("max_steps", 0),
        "phase": read("phase", ""),
        "allowed_actions": read("allowed_actions", []),
    }
    for name in ("current_score", "best_score_so_far", "instruction_score", "phase_score"):
        state[name] = round(float(read(name, 0.0) or 0.0), 4)
    for name, empty in (
        ("brief", {}),
        ("metrics", {}),
        ("metric_deltas", {}),
        ("worst_metrics", []),
        ("focus_elements", []),
        ("critic_feedback", []),
        ("layout_summary", ""),
    ):
        state[name] = read(name, empty) or empty

    schema = {
        "action_type": choices(
            "apply_template",
            "anchor_to_region",
            "resize",
            "move",
            "align",
            "distribute",
            "promote",
            "reflow_group",
            "finalize",
        ),
        "optional_fields": {
            "template_id": choices("hero", "split", "editorial", "grid", "draft"),
            "element_id": "single element id",
            "element_ids": "list of element ids",
            "region_id": "semantic target region",
            "group_id": "semantic group id",
            "pattern": choices("stack", "row"),
            "axis": choices("x", "y"),
            "mode": choices("left", "center", "top"),
            "dx_dy_dw_dh": "small normalized floats",
            "strength": "small float for promote",
        },
    }

    sections = [
        "Choose the next best layout edit.",
        "Output only one valid minified JSON action object.",
        "",
        "STATE:",
        json.dumps(state, sort_keys=True),
        "",
        "ACTION_SCHEMA:",
        json.dumps(schema, sort_keys=True),
    ]
    return "\n".join(sections)
def make_example(obs: Any, action: DesignGymAction, metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Assemble one chat-format SFT row.

    The row has a ``messages`` list (system prompt, user prompt rendered from
    *obs*, assistant reply holding the compact JSON of *action*) followed by
    every key of *metadata* merged in at the top level.
    """
    turns = (
        ("system", SYSTEM_PROMPT),
        ("user", prompt_from_obs(obs)),
        ("assistant", compact_action(action)),
    )
    example: Dict[str, Any] = {
        "messages": [{"role": role, "content": text} for role, text in turns],
    }
    example.update(metadata)
    return example
def generate_examples(
    *,
    episodes: int,
    seed: int,
    max_steps_override: int | None,
    tasks: Sequence[str],
) -> List[Dict[str, Any]]:
    """Roll out expert episodes and collect one SFT example per step taken.

    Tasks are assigned round-robin over *tasks*; episode ``i`` is reset with
    seed ``seed + i``. Each example pairs the observation seen *before* the
    step with the chosen action, plus metadata read from the environment state
    *after* the step. The collected examples are shuffled with an RNG seeded
    by *seed* before being returned.

    Args:
        episodes: Number of episodes to roll out.
        seed: Base seed for episode seeds and for the final shuffle.
        max_steps_override: Per-episode step cap. ``None`` (or ``0``, since the
            value is used with ``or``) falls back to the observation's
            ``max_steps``, or 8 when that is missing or falsy.
        tasks: Task ids to cycle through.

    Returns:
        The shuffled list of example rows.

    Raises:
        ValueError: If *tasks* is empty while ``episodes`` is positive.
    """
    # Fail early with a clear message instead of a ZeroDivisionError from the
    # round-robin modulo below.
    if episodes > 0 and len(tasks) == 0:
        raise ValueError("tasks must contain at least one task id")

    rng = random.Random(seed)
    examples: List[Dict[str, Any]] = []
    for episode_idx in range(episodes):
        task_id = tasks[episode_idx % len(tasks)]
        episode_seed = seed + episode_idx
        env = DesignGymEnvironment()
        obs = env.reset(task_id=task_id, seed=episode_seed)
        max_steps = max_steps_override or int(getattr(obs, "max_steps", 8) or 8)
        for local_step in range(max_steps):
            # Stop as soon as either the observation or the state reports done.
            if bool(getattr(obs, "done", False)) or bool(getattr(env.state, "done", False)):
                break
            preferred_type = preferred_action_type_for_example(episode_idx, local_step, obs)
            action = choose_expert_action(env, obs, preferred_action_type=preferred_type)
            # Keep the pre-step observation: it is what the prompt is built from.
            before_obs = obs
            obs = env.step(action)
            metadata = {
                "task_id": task_id,
                "episode_seed": episode_seed,
                "episode_index": episode_idx,
                "step_index": local_step,
                "expert_action": compact_action(action),
                "reward_after": round(float(getattr(env.state, "last_reward", 0.0) or 0.0), 6),
                "score_after": round(float(getattr(env.state, "current_score", 0.0) or 0.0), 6),
                "instruction_score_after": round(float(getattr(env.state, "instruction_score", 0.0) or 0.0), 6),
                "phase_after": getattr(env.state, "phase", ""),
            }
            examples.append(make_example(before_obs, action, metadata))
    rng.shuffle(examples)
    return examples
def write_jsonl(path: Path, rows: Iterable[Dict[str, Any]]) -> int:
    """Write *rows* to *path* as UTF-8 JSON Lines and return how many were written.

    Missing parent directories are created, an existing file is overwritten,
    and non-ASCII characters are written as-is rather than escaped.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    written = 0
    with path.open("w", encoding="utf-8") as sink:
        # Rows are streamed one at a time, so *rows* may be any iterable.
        for written, record in enumerate(rows, start=1):
            sink.write(f"{json.dumps(record, ensure_ascii=False)}\n")
    return written
def main() -> None:
    """CLI entry point: generate expert examples and write train/eval JSONL files.

    The examples returned by ``generate_examples`` are split by position: the
    first ``int(total * (1 - eval_ratio))`` rows go to ``--out`` and the rest
    to ``--eval-out``. A short summary and the first example (truncated to
    3000 characters) are printed afterwards.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--episodes", type=int, default=300)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--max-steps", type=int, default=None)
    # "+" instead of "*": a bare `--tasks` would otherwise produce an empty
    # task list, leaving generation with no task to run.
    parser.add_argument("--tasks", nargs="+", default=TASKS)
    parser.add_argument("--out", type=str, default="data/sft/designgym2_sft_train.jsonl")
    parser.add_argument("--eval-out", type=str, default="data/sft/designgym2_sft_eval.jsonl")
    parser.add_argument("--eval-ratio", type=float, default=0.10)
    args = parser.parse_args()

    # A ratio outside [0, 1] (or NaN) would give a negative or oversized split
    # index and a silently wrong train/eval partition.
    if not 0.0 <= args.eval_ratio <= 1.0:
        parser.error("--eval-ratio must be between 0 and 1")

    examples = generate_examples(
        episodes=args.episodes,
        seed=args.seed,
        max_steps_override=args.max_steps,
        tasks=args.tasks,
    )

    split_idx = int(len(examples) * (1.0 - args.eval_ratio))
    train_rows = examples[:split_idx]
    eval_rows = examples[split_idx:]
    train_count = write_jsonl(Path(args.out), train_rows)
    eval_count = write_jsonl(Path(args.eval_out), eval_rows)

    print(f"[OK] generated total={len(examples)} train={train_count} eval={eval_count}")
    print(f"[TRAIN] {args.out}")
    print(f"[EVAL] {args.eval_out}")
    if examples:
        print("[SAMPLE]")
        print(json.dumps(examples[0], indent=2)[:3000])


if __name__ == "__main__":
    main()