ThreeGen / scene.py
bolajiev
Sprint 5: hierarchical GroupNode with layout engine
cf56b0c
Raw
History Blame Contribute Delete
27.2 kB
"""
Scene DSL: the constrained format the small model must emit.
The model never writes Three.js directly. It emits this JSON, which we then
validate + clamp + repair here, and compile to Three.js in compiler.py.
That separation is what keeps the live preview from ever breaking.
"""
from __future__ import annotations

import json
import logging
import math
import re
from typing import Any, Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Field, field_validator
# Module logger; used for the repair/fallback warnings emitted below.
log = logging.getLogger(__name__)
# Primitive mesh shapes an Obj may use; Obj's shape validator falls back to
# "box" for unrecognised values.
SHAPES = {
    "box", "sphere", "cylinder", "cone", "torus", "torusKnot", "plane",
    "tetrahedron", "icosahedron", "dodecahedron", "octahedron",
    "capsule", "ring", "circle", "tube", "roundedBox",
}
# 2-D outlines accepted by ExtrudeNode and the badge_with_text template.
EXTRUDE_SHAPES = {"star", "heart", "hexagon", "badge", "shield"}
# Material kinds; unrecognised values fall back to "standard".
MATERIALS = {"standard", "basic", "phong", "wireframe"}
# Named material presets; unrecognised preset names are dropped (None).
PRESET_NAMES = {"gold", "chrome", "glass", "neon", "matte", "plastic"}
# Light kinds; unrecognised values fall back to "directional".
LIGHT_TYPES = {"ambient", "directional", "point"}
# Scene-level animation kinds; unrecognised values fall back to "rotate".
ANIM_TYPES = {"none", "rotate", "float", "orbit"}
# Strict 6-digit hex colour. NOTE(review): the validators in this module use
# _HEX_RE instead; presumably HEX is imported elsewhere (e.g. compiler.py) —
# confirm before removing.
HEX = re.compile(r"^#[0-9a-fA-F]{6}$")
# ---- Color normalisation (Fix 1) ----
# Colour words models commonly emit that are not CSS names, mapped to hex.
# Both the spaced and the space-less spelling are listed; _sanitize_color
# looks up the lowercased input and also its space/hyphen-stripped form.
# Note that a bare "neon" colour maps to neon green.
_SYNONYMS: Dict[str, str] = {
    "electric blue": "#7df9ff", "electricblue": "#7df9ff",
    "neon green": "#39ff14", "neongreen": "#39ff14",
    "neon": "#39ff14",
    "neon blue": "#4d4dff", "neonblue": "#4d4dff",
    "neon red": "#ff3131", "neonred": "#ff3131",
    "neon pink": "#ff6ec7", "neonpink": "#ff6ec7",
    "neon yellow": "#ffff00", "neonyellow": "#ffff00",
    "neon orange": "#ff6600", "neonorange": "#ff6600",
}
# CSS/X11 named colours, lowercase with no separators. _sanitize_color
# matches the input after lowercasing and removing spaces/hyphens, so
# "Dark Slate Gray" resolves to "darkslategray".
# NOTE(review): CSS Color Level 4 also defines "rebeccapurple", which is not
# listed here — add it if the renderer accepts it.
_CSS_COLORS: frozenset = frozenset({
    "aliceblue", "antiquewhite", "aqua", "aquamarine", "azure", "beige",
    "bisque", "black", "blanchedalmond", "blue", "blueviolet", "brown",
    "burlywood", "cadetblue", "chartreuse", "chocolate", "coral",
    "cornflowerblue", "cornsilk", "crimson", "cyan", "darkblue", "darkcyan",
    "darkgoldenrod", "darkgray", "darkgreen", "darkgrey", "darkkhaki",
    "darkmagenta", "darkolivegreen", "darkorange", "darkorchid", "darkred",
    "darksalmon", "darkseagreen", "darkslateblue", "darkslategray",
    "darkslategrey", "darkturquoise", "darkviolet", "deeppink", "deepskyblue",
    "dimgray", "dimgrey", "dodgerblue", "firebrick", "floralwhite",
    "forestgreen", "fuchsia", "gainsboro", "ghostwhite", "gold", "goldenrod",
    "gray", "green", "greenyellow", "grey", "honeydew", "hotpink",
    "indianred", "indigo", "ivory", "khaki", "lavender", "lavenderblush",
    "lawngreen", "lemonchiffon", "lightblue", "lightcoral", "lightcyan",
    "lightgoldenrodyellow", "lightgray", "lightgreen", "lightgrey",
    "lightpink", "lightsalmon", "lightseagreen", "lightskyblue",
    "lightslategray", "lightslategrey", "lightsteelblue", "lightyellow",
    "lime", "limegreen", "linen", "magenta", "maroon", "mediumaquamarine",
    "mediumblue", "mediumorchid", "mediumpurple", "mediumseagreen",
    "mediumslateblue", "mediumspringgreen", "mediumturquoise",
    "mediumvioletred", "midnightblue", "mintcream", "mistyrose", "moccasin",
    "navajowhite", "navy", "oldlace", "olive", "olivedrab", "orange",
    "orangered", "orchid", "palegoldenrod", "palegreen", "paleturquoise",
    "palevioletred", "papayawhip", "peachpuff", "peru", "pink", "plum",
    "powderblue", "purple", "red", "rosybrown", "royalblue", "saddlebrown",
    "salmon", "sandybrown", "seagreen", "seashell", "sienna", "silver",
    "skyblue", "slateblue", "slategray", "slategrey", "snow", "springgreen",
    "steelblue", "tan", "teal", "thistle", "tomato", "turquoise", "violet",
    "wheat", "white", "whitesmoke", "yellow", "yellowgreen",
})
_HEX_RE = re.compile(r"^#[0-9a-fA-F]{3,8}$")
_RGB_HSL_RE = re.compile(
r"^(rgb|hsl)a?\(\s*[\d.]+%?\s*,\s*[\d.]+%?\s*,\s*[\d.]+%?\s*(?:,\s*[\d.]+)?\s*\)$"
)
def _sanitize_color(v: str, default: str = "#888888") -> str:
    """Normalise a colour value, falling back to *default* when unrecognised.

    Resolution order:
      1. synonym map, first on the lowercased text, then with spaces and
         hyphens removed;
      2. hex matching _HEX_RE (returned with its original casing);
      3. rgb()/hsl() functional notation (returned lowercased);
      4. CSS named colours (returned lowercased, separators removed).
    Anything else logs a warning and yields *default*.
    """
    raw = str(v).strip()
    lowered = raw.lower()
    squashed = lowered.replace(" ", "").replace("-", "")

    for key in (lowered, squashed):
        synonym = _SYNONYMS.get(key)
        if synonym is not None:
            return synonym

    if _HEX_RE.match(raw):
        return raw
    if _RGB_HSL_RE.match(lowered):
        return lowered
    if squashed in _CSS_COLORS:
        return squashed

    log.warning("Unknown color %r, using default %s", raw, default)
    return default
def _clamp(v: float, lo: float, hi: float) -> float:
return max(lo, min(hi, v))
def _shape_extent(shape: str, params: Dict[str, float]) -> tuple:
"""Return (width, height, depth) bounding box for layout size computations."""
def p(k, d): return float(params.get(k, d))
if shape == "box":
return (p("width", 1.0), p("height", 1.0), p("depth", 1.0))
if shape == "sphere":
d = p("radius", 0.6) * 2
return (d, d, d)
if shape == "cylinder":
r = max(p("radiusTop", 0.5), p("radiusBottom", 0.5)) * 2
return (r, p("height", 1.0), r)
if shape == "cone":
return (p("radius", 0.5) * 2, p("height", 1.0), p("radius", 0.5) * 2)
if shape in ("torus", "torusKnot"):
r = (p("radius", 0.5) + p("tube", 0.2)) * 2
return (r, r, r)
if shape in ("tetrahedron", "icosahedron", "dodecahedron", "octahedron"):
d = p("radius", 0.6) * 2
return (d, d, d)
if shape == "plane":
return (p("width", 5.0), 0.01, p("height", 5.0))
if shape == "capsule":
d = p("radius", 0.4) * 2
return (d, p("length", 1.0) + d, d)
if shape in ("ring", "circle"):
r = p("outerRadius", p("radius", 0.6)) * 2
return (r, 0.01, r)
if shape == "tube":
return (1.0, 1.5, 1.0)
if shape == "roundedBox":
return (p("width", 1.0), p("height", 1.0), p("depth", 1.0))
return (1.0, 1.0, 1.0)
class Obj(BaseModel):
    """A single primitive mesh: a shape from SHAPES plus transform and material.

    Every validator runs with ``mode="before"`` so it sees the raw model
    output and can repair it. In pydantic's default "after" mode the built-in
    type check runs first, so values such as ``"metalness": "shiny"``,
    ``"color": null`` or ``"params": {"radius": "big"}`` would reject the whole
    object before any repair code ran.
    """

    shape: str = "box"
    position: List[float] = Field(default_factory=lambda: [0.0, 0.0, 0.0])
    rotation: List[float] = Field(default_factory=lambda: [0.0, 0.0, 0.0])
    scale: List[float] = Field(default_factory=lambda: [1.0, 1.0, 1.0])
    color: str = "#88ccff"
    material: str = "standard"
    preset: Optional[str] = None
    metalness: float = 0.3
    roughness: float = 0.4
    emissive: str = "#000000"
    params: Dict[str, float] = Field(default_factory=dict)

    @field_validator("shape", mode="before")
    @classmethod
    def _shape(cls, v: Any) -> str:
        """Return a canonical SHAPES entry, matching case-insensitively; else "box"."""
        name = str(v).strip()
        if name in SHAPES:
            return name
        # e.g. "Sphere" -> "sphere", "torusknot" -> "torusKnot"
        by_lower = {s.lower(): s for s in SHAPES}
        return by_lower.get(name.lower(), "box")

    @field_validator("material", mode="before")
    @classmethod
    def _material(cls, v: Any) -> str:
        """Return a MATERIALS entry (case-insensitive); else "standard"."""
        name = str(v).lower().strip()
        return name if name in MATERIALS else "standard"

    @field_validator("preset", mode="before")
    @classmethod
    def _preset_field(cls, v: Any) -> Optional[str]:
        """Return a known preset name (lowercased) or None."""
        if v is None:
            return None
        v = str(v).lower().strip()
        return v if v in PRESET_NAMES else None

    @field_validator("position", "rotation", "scale", mode="before")
    @classmethod
    def _vec3(cls, v: Any, info) -> List[float]:
        """Coerce to 3 finite floats; pad/replace with 1.0 for scale, else 0.0."""
        fill = 1.0 if info.field_name == "scale" else 0.0
        out: List[float] = []
        # Only real sequences: a string like "123" must not become [1, 2, 3].
        if isinstance(v, (list, tuple)):
            try:
                out = [float(x) for x in v[:3]]
            except (TypeError, ValueError, OverflowError):
                out = []
        # json.loads accepts NaN/Infinity and vectors are not clamped.
        out = [x if math.isfinite(x) else fill for x in out]
        out.extend([fill] * (3 - len(out)))
        return out

    @field_validator("color", "emissive", mode="before")
    @classmethod
    def _hex(cls, v: Any, info) -> str:
        """Sanitise a colour, defaulting per field."""
        default = "#88ccff" if info.field_name == "color" else "#000000"
        return _sanitize_color(str(v), default)

    @field_validator("metalness", "roughness", mode="before")
    @classmethod
    def _unit(cls, v: Any, info) -> float:
        """Clamp to [0, 1]; unparseable or NaN input yields the field's own default."""
        default = 0.3 if info.field_name == "metalness" else 0.4
        try:
            value = float(v)
        except (TypeError, ValueError, OverflowError):
            return default
        return default if math.isnan(value) else _clamp(value, 0.0, 1.0)

    @field_validator("params", mode="before")
    @classmethod
    def _params(cls, v: Any) -> Dict[str, float]:
        """Keep numeric entries clamped to [-50, 50]; drop non-numeric or NaN ones."""
        clean: Dict[str, float] = {}
        if not isinstance(v, dict):
            return clean
        for k, val in v.items():
            try:
                value = float(val)
            except (TypeError, ValueError, OverflowError):
                continue
            if math.isnan(value):
                continue
            clean[str(k)] = _clamp(value, -50.0, 50.0)
        return clean
class Light(BaseModel):
    """One scene light: type, colour, intensity and position.

    Validators run with ``mode="before"`` so malformed values are repaired to
    defaults. Otherwise pydantic's type check would reject the light, and
    with it the whole Scene that contains it.
    """

    type: str = "directional"
    color: str = "#ffffff"
    intensity: float = 1.0
    position: List[float] = Field(default_factory=lambda: [5.0, 8.0, 6.0])

    @field_validator("type", mode="before")
    @classmethod
    def _type(cls, v: Any) -> str:
        """Return a LIGHT_TYPES entry (case-insensitive); else "directional"."""
        name = str(v).lower().strip()
        return name if name in LIGHT_TYPES else "directional"

    @field_validator("color", mode="before")
    @classmethod
    def _hex(cls, v: Any) -> str:
        """Sanitise the colour, defaulting to white."""
        return _sanitize_color(str(v), "#ffffff")

    @field_validator("intensity", mode="before")
    @classmethod
    def _intensity(cls, v: Any) -> float:
        """Clamp to [0, 10]; unparseable or NaN input yields 1.0."""
        try:
            value = float(v)
        except (TypeError, ValueError, OverflowError):
            return 1.0
        return 1.0 if math.isnan(value) else _clamp(value, 0.0, 10.0)

    @field_validator("position", mode="before")
    @classmethod
    def _vec3(cls, v: Any) -> List[float]:
        """Coerce to 3 finite floats, padding/replacing with 5.0."""
        out: List[float] = []
        # Only real sequences: a string like "123" must not become [1, 2, 3].
        if isinstance(v, (list, tuple)):
            try:
                out = [float(x) for x in v[:3]]
            except (TypeError, ValueError, OverflowError):
                out = []
        out = [x if math.isfinite(x) else 5.0 for x in out]
        out.extend([5.0] * (3 - len(out)))
        return out
class Animation(BaseModel):
    """Scene-level animation: kind (ANIM_TYPES), speed in [0, 5], axis x/y/z.

    Validators run with ``mode="before"`` so malformed values are repaired to
    defaults instead of failing pydantic's type check.
    """

    type: str = "rotate"
    speed: float = 1.0
    axis: str = "y"

    @field_validator("type", mode="before")
    @classmethod
    def _type(cls, v: Any) -> str:
        """Return an ANIM_TYPES entry (case-insensitive); else "rotate"."""
        name = str(v).lower().strip()
        return name if name in ANIM_TYPES else "rotate"

    @field_validator("speed", mode="before")
    @classmethod
    def _speed(cls, v: Any) -> float:
        """Clamp to [0, 5]; unparseable or NaN input yields 1.0."""
        try:
            value = float(v)
        except (TypeError, ValueError, OverflowError):
            return 1.0
        return 1.0 if math.isnan(value) else _clamp(value, 0.0, 5.0)

    @field_validator("axis", mode="before")
    @classmethod
    def _axis(cls, v: Any) -> str:
        """Return "x", "y" or "z" (case-insensitive); else "y"."""
        name = str(v).lower().strip()
        return name if name in {"x", "y", "z"} else "y"
def _vec3_field(v: Any, default: float = 0.0) -> List[float]:
out: List[float] = []
try:
for x in list(v)[:3]:
out.append(float(x))
except Exception:
out = []
while len(out) < 3:
out.append(default)
return out
class LayoutStack(BaseModel):
    """Stack children along an axis, centering the total extent at the node's position.

    ``axis`` is x/y/z (case-insensitive, default y). ``gap`` is clamped to
    [0, 20], the same range GroupNode uses, so an infinite gap cannot reach the
    layout. Validators run with ``mode="before"`` so malformed values are
    repaired rather than rejected by pydantic's type check.
    """

    type: Literal["stack"] = "stack"
    axis: str = "y"
    gap: float = 0.05
    position: List[float] = Field(default_factory=lambda: [0.0, 0.0, 0.0])
    children: List[Any] = Field(default_factory=list)

    @field_validator("axis", mode="before")
    @classmethod
    def _axis(cls, v: Any) -> str:
        name = str(v).lower().strip()
        return name if name in {"x", "y", "z"} else "y"

    @field_validator("gap", mode="before")
    @classmethod
    def _gap(cls, v: Any) -> float:
        try:
            value = float(v)
        except (TypeError, ValueError, OverflowError):
            return 0.05
        return 0.05 if math.isnan(value) else _clamp(value, 0.0, 20.0)

    @field_validator("position", mode="before")
    @classmethod
    def _pos(cls, v: Any) -> List[float]:
        return _vec3_field(v)

    @field_validator("children", mode="before")
    @classmethod
    def _children(cls, v: Any) -> List[Any]:
        # Each child is parsed into its concrete node model.
        return [_parse_scene_item(c) for c in v] if isinstance(v, list) else []
class LayoutRow(BaseModel):
    """Lay out children in a row along the x-axis.

    ``gap`` is clamped to [0, 20], the same range GroupNode uses, so an
    infinite gap cannot reach the layout. Validators run with
    ``mode="before"`` so malformed values are repaired rather than rejected
    by pydantic's type check.
    """

    type: Literal["row"] = "row"
    gap: float = 0.3
    position: List[float] = Field(default_factory=lambda: [0.0, 0.0, 0.0])
    children: List[Any] = Field(default_factory=list)

    @field_validator("gap", mode="before")
    @classmethod
    def _gap(cls, v: Any) -> float:
        try:
            value = float(v)
        except (TypeError, ValueError, OverflowError):
            return 0.3
        return 0.3 if math.isnan(value) else _clamp(value, 0.0, 20.0)

    @field_validator("position", mode="before")
    @classmethod
    def _pos(cls, v: Any) -> List[float]:
        return _vec3_field(v)

    @field_validator("children", mode="before")
    @classmethod
    def _children(cls, v: Any) -> List[Any]:
        # Each child is parsed into its concrete node model.
        return [_parse_scene_item(c) for c in v] if isinstance(v, list) else []
class LayoutGrid(BaseModel):
    """Lay out children in a grid on the x-z plane.

    ``cols`` is at least 1. ``gap_x``/``gap_z`` are clamped to [0, 20] like the
    other layout gaps. Validators run with ``mode="before"`` so malformed
    values (e.g. ``"cols": 2.5`` or a non-numeric gap) are repaired rather
    than rejected by pydantic's type check.
    """

    type: Literal["grid"] = "grid"
    cols: int = 2
    gap_x: float = 0.3
    gap_z: float = 0.3
    position: List[float] = Field(default_factory=lambda: [0.0, 0.0, 0.0])
    children: List[Any] = Field(default_factory=list)

    @field_validator("cols", mode="before")
    @classmethod
    def _cols(cls, v: Any) -> int:
        try:
            return max(1, int(v))
        except (TypeError, ValueError, OverflowError):
            return 2

    @field_validator("gap_x", "gap_z", mode="before")
    @classmethod
    def _gap(cls, v: Any) -> float:
        try:
            value = float(v)
        except (TypeError, ValueError, OverflowError):
            return 0.3
        return 0.3 if math.isnan(value) else _clamp(value, 0.0, 20.0)

    @field_validator("position", mode="before")
    @classmethod
    def _pos(cls, v: Any) -> List[float]:
        return _vec3_field(v)

    @field_validator("children", mode="before")
    @classmethod
    def _children(cls, v: Any) -> List[Any]:
        # Each child is parsed into its concrete node model.
        return [_parse_scene_item(c) for c in v] if isinstance(v, list) else []
class ExtrudeNode(BaseModel):
    """A 2-D shape path extruded into 3-D with a bevel.

    ``shape`` is one of EXTRUDE_SHAPES and ``depth`` is clamped to [0.02, 2].
    Validators run with ``mode="before"`` so malformed values are repaired to
    defaults. Otherwise pydantic's type check would reject the whole node
    before any repair code ran.
    """

    type: Literal["extrude"] = "extrude"
    shape: str = "badge"
    depth: float = 0.2
    bevel: bool = True
    color: str = "#88ccff"
    material: str = "standard"
    preset: Optional[str] = None
    metalness: float = 0.3
    roughness: float = 0.4
    emissive: str = "#000000"
    position: List[float] = Field(default_factory=lambda: [0.0, 0.0, 0.0])
    rotation: List[float] = Field(default_factory=lambda: [0.0, 0.0, 0.0])
    scale: List[float] = Field(default_factory=lambda: [1.0, 1.0, 1.0])

    @field_validator("shape", mode="before")
    @classmethod
    def _shape(cls, v: Any) -> str:
        name = str(v).lower().strip()
        return name if name in EXTRUDE_SHAPES else "badge"

    @field_validator("depth", mode="before")
    @classmethod
    def _depth(cls, v: Any) -> float:
        try:
            value = float(v)
        except (TypeError, ValueError, OverflowError):
            return 0.2
        return 0.2 if math.isnan(value) else _clamp(value, 0.02, 2.0)

    @field_validator("color", "emissive", mode="before")
    @classmethod
    def _hex(cls, v: Any, info) -> str:
        default = "#88ccff" if info.field_name == "color" else "#000000"
        return _sanitize_color(str(v), default)

    @field_validator("material", mode="before")
    @classmethod
    def _material(cls, v: Any) -> str:
        name = str(v).lower().strip()
        return name if name in MATERIALS else "standard"

    @field_validator("preset", mode="before")
    @classmethod
    def _preset_field(cls, v: Any) -> Optional[str]:
        if v is None:
            return None
        v = str(v).lower().strip()
        return v if v in PRESET_NAMES else None

    @field_validator("metalness", "roughness", mode="before")
    @classmethod
    def _unit(cls, v: Any, info) -> float:
        # Unparseable or NaN input falls back to the field's declared default.
        default = 0.3 if info.field_name == "metalness" else 0.4
        try:
            value = float(v)
        except (TypeError, ValueError, OverflowError):
            return default
        return default if math.isnan(value) else _clamp(value, 0.0, 1.0)

    @field_validator("position", "rotation", "scale", mode="before")
    @classmethod
    def _vec3(cls, v: Any, info) -> List[float]:
        fill = 1.0 if info.field_name == "scale" else 0.0
        out: List[float] = []
        # Only real sequences: a string like "123" must not become [1, 2, 3].
        if isinstance(v, (list, tuple)):
            try:
                out = [float(x) for x in v[:3]]
            except (TypeError, ValueError, OverflowError):
                out = []
        # json.loads accepts NaN/Infinity and vectors are not clamped.
        out = [x if math.isfinite(x) else fill for x in out]
        out.extend([fill] * (3 - len(out)))
        return out
class Text3DNode(BaseModel):
    """3-D text rendered via Three.js TextGeometry + FontLoader (Latin chars only).

    ``text`` is reduced to at most 24 printable ASCII characters (``None`` or
    empty becomes "TEXT"). ``size`` is clamped to [0.1, 4] and ``depth`` to
    [0.05, 1]. Validators run with ``mode="before"`` so malformed values are
    repaired to defaults. Otherwise pydantic's type check would reject the
    whole node before any repair code ran.
    """

    type: Literal["text3d"] = "text3d"
    text: str = "TEXT"
    size: float = 0.6
    depth: float = 0.2
    bevel: bool = True
    color: str = "#88ccff"
    material: str = "standard"
    preset: Optional[str] = None
    metalness: float = 0.3
    roughness: float = 0.4
    emissive: str = "#000000"
    position: List[float] = Field(default_factory=lambda: [0.0, 0.0, 0.0])
    rotation: List[float] = Field(default_factory=lambda: [0.0, 0.0, 0.0])
    scale: List[float] = Field(default_factory=lambda: [1.0, 1.0, 1.0])

    @field_validator("text", mode="before")
    @classmethod
    def _text(cls, v: Any) -> str:
        if v is None:  # str(None) would otherwise render the word "None"
            return "TEXT"
        v = "".join(c for c in str(v).strip() if c.isprintable() and ord(c) < 128)[:24]
        return v or "TEXT"

    @field_validator("size", mode="before")
    @classmethod
    def _size(cls, v: Any) -> float:
        try:
            value = float(v)
        except (TypeError, ValueError, OverflowError):
            return 0.6
        return 0.6 if math.isnan(value) else _clamp(value, 0.1, 4.0)

    @field_validator("depth", mode="before")
    @classmethod
    def _depth(cls, v: Any) -> float:
        try:
            value = float(v)
        except (TypeError, ValueError, OverflowError):
            return 0.2
        return 0.2 if math.isnan(value) else _clamp(value, 0.05, 1.0)

    @field_validator("color", "emissive", mode="before")
    @classmethod
    def _hex(cls, v: Any, info) -> str:
        default = "#88ccff" if info.field_name == "color" else "#000000"
        return _sanitize_color(str(v), default)

    @field_validator("material", mode="before")
    @classmethod
    def _material(cls, v: Any) -> str:
        name = str(v).lower().strip()
        return name if name in MATERIALS else "standard"

    @field_validator("preset", mode="before")
    @classmethod
    def _preset_field(cls, v: Any) -> Optional[str]:
        if v is None:
            return None
        v = str(v).lower().strip()
        return v if v in PRESET_NAMES else None

    @field_validator("metalness", "roughness", mode="before")
    @classmethod
    def _unit(cls, v: Any, info) -> float:
        # Unparseable or NaN input falls back to the field's declared default.
        default = 0.3 if info.field_name == "metalness" else 0.4
        try:
            value = float(v)
        except (TypeError, ValueError, OverflowError):
            return default
        return default if math.isnan(value) else _clamp(value, 0.0, 1.0)

    @field_validator("position", "rotation", "scale", mode="before")
    @classmethod
    def _vec3(cls, v: Any, info) -> List[float]:
        fill = 1.0 if info.field_name == "scale" else 0.0
        out: List[float] = []
        # Only real sequences: a string like "123" must not become [1, 2, 3].
        if isinstance(v, (list, tuple)):
            try:
                out = [float(x) for x in v[:3]]
            except (TypeError, ValueError, OverflowError):
                out = []
        # json.loads accepts NaN/Infinity and vectors are not clamped.
        out = [x if math.isfinite(x) else fill for x in out]
        out.extend([fill] * (3 - len(out)))
        return out
# Accepted values for GroupNode.layout; anything else is coerced to "none".
LAYOUT_TYPES = {"none", "row", "column", "stack", "grid"}
class GroupNode(BaseModel):
    """A group of child nodes with optional layout and group-level transform.

    ``layout`` is one of LAYOUT_TYPES. ``gap`` is clamped to [0, 20] and
    ``cols`` to [1, 20]. Validators run with ``mode="before"`` so malformed
    values are repaired. Otherwise pydantic's type check would reject the
    whole group, children included, before any repair code ran.
    """

    type: Literal["group"] = "group"
    layout: str = "none"
    gap: float = 0.2
    cols: int = 3
    position: List[float] = Field(default_factory=lambda: [0.0, 0.0, 0.0])
    rotation: List[float] = Field(default_factory=lambda: [0.0, 0.0, 0.0])
    scale: List[float] = Field(default_factory=lambda: [1.0, 1.0, 1.0])
    children: List[Any] = Field(default_factory=list)

    @field_validator("layout", mode="before")
    @classmethod
    def _layout(cls, v: Any) -> str:
        v = str(v).lower().strip()
        return v if v in LAYOUT_TYPES else "none"

    @field_validator("gap", mode="before")
    @classmethod
    def _gap(cls, v: Any) -> float:
        try:
            value = float(v)
        except (TypeError, ValueError, OverflowError):
            return 0.2
        return 0.2 if math.isnan(value) else _clamp(value, 0.0, 20.0)

    @field_validator("cols", mode="before")
    @classmethod
    def _cols(cls, v: Any) -> int:
        try:
            return max(1, min(int(v), 20))
        except (TypeError, ValueError, OverflowError):
            return 3

    @field_validator("position", "rotation", "scale", mode="before")
    @classmethod
    def _vec3(cls, v: Any, info) -> List[float]:
        fill = 1.0 if info.field_name == "scale" else 0.0
        out: List[float] = []
        # Only real sequences: a string like "123" must not become [1, 2, 3].
        if isinstance(v, (list, tuple)):
            try:
                out = [float(x) for x in v[:3]]
            except (TypeError, ValueError, OverflowError):
                out = []
        # json.loads accepts NaN/Infinity and vectors are not clamped.
        out = [x if math.isfinite(x) else fill for x in out]
        out.extend([fill] * (3 - len(out)))
        return out

    @field_validator("children", mode="before")
    @classmethod
    def _children(cls, v: Any) -> List[Any]:
        # Each child is parsed (recursively) into its concrete node model.
        return [_parse_scene_item(c) for c in v] if isinstance(v, list) else []
def _parse_scene_item(v: Any) -> Any:
    """Parse one raw scene entry into the matching scene-node model.

    Dispatch uses the entry's ``"type"`` key, matched case-insensitively:
    group, stack, row, grid, extrude or text3d. A missing or unknown type is
    treated as a primitive ``Obj``. Node models that are already built pass
    through unchanged. Non-dict input, or a dict that fails validation,
    becomes a default ``Obj()`` so one bad entry cannot break the scene; both
    cases are logged.
    """
    if isinstance(v, (Obj, LayoutStack, LayoutRow, LayoutGrid, ExtrudeNode, Text3DNode, GroupNode)):
        return v
    if not isinstance(v, dict):
        log.warning("Scene item is not an object (%s), using default box", type(v).__name__)
        return Obj()
    node_types = {
        "group": GroupNode,
        "stack": LayoutStack,
        "row": LayoutRow,
        "grid": LayoutGrid,
        "extrude": ExtrudeNode,
        "text3d": Text3DNode,
    }
    kind = str(v.get("type", "")).lower().strip()
    model = node_types.get(kind, Obj)
    # The node models declare `type` as a lowercase Literal, so pass the
    # normalised value ("Group" would otherwise fail validation).
    data = dict(v, type=kind) if kind in node_types else v
    try:
        return model(**data)
    except Exception as e:
        # Deliberately broad: covers validation errors, non-string keys
        # (TypeError) and pathological nesting (RecursionError).
        log.warning("Invalid %s node (%s), using default box", kind or "obj", type(e).__name__)
        return Obj()
# ---- Composite templates (deterministic Python expansions) ----
def _template_burger(params: Dict[str, Any]) -> List[Obj]:
    """Expand the ``burger`` template into four stacked layers on the y-axis.

    Top to bottom: domed bun (sphere), lettuce, patty, bottom bun (cylinders).
    Colours can be overridden with ``color_bun``, ``color_patty`` and
    ``color_lettuce``.
    """
    bun_color = params.get("color_bun", "#c8a96e")
    patty_color = params.get("color_patty", "#5a3a1a")
    lettuce_color = params.get("color_lettuce", "#3a8a3a")
    # (shape, colour, y, geometry params, roughness)
    layers = (
        ("sphere", bun_color, 0.65, {"radius": 0.45}, 0.7),
        ("cylinder", lettuce_color, 0.28, {"radiusTop": 0.52, "radiusBottom": 0.52, "height": 0.1}, 0.9),
        ("cylinder", patty_color, 0.12, {"radiusTop": 0.5, "radiusBottom": 0.5, "height": 0.18}, 0.8),
        ("cylinder", bun_color, -0.18, {"radiusTop": 0.52, "radiusBottom": 0.55, "height": 0.32}, 0.7),
    )
    return [
        Obj(shape=kind, color=tint, position=[0, y, 0], params=geom, roughness=rough)
        for kind, tint, y, geom, rough in layers
    ]
def _template_snowman(params: Dict[str, Any]) -> List[Obj]:
    """Expand the ``snowman`` template: three stacked spheres plus a hat.

    Colours can be overridden with ``color_body`` and ``color_hat``.
    """
    body_color = params.get("color_body", "#e8e8e8")
    hat_color = params.get("color_hat", "#1a1a1a")
    # (y, radius) of each snowball, bottom to top
    snowballs = ((-0.55, 0.5), (0.2, 0.35), (0.82, 0.24))
    pieces = [
        Obj(shape="sphere", color=body_color, position=[0, y, 0], params={"radius": r}, roughness=0.9)
        for y, r in snowballs
    ]
    pieces.append(
        Obj(shape="cylinder", color=hat_color, position=[0, 1.18, 0],
            params={"radiusTop": 0.16, "radiusBottom": 0.26, "height": 0.34})
    )
    return pieces
def _template_tree(params: Dict[str, Any]) -> List[Obj]:
    """Expand the ``tree`` template: a trunk cylinder under three cone tiers.

    Colours can be overridden with ``color_leaves`` and ``color_trunk``.
    """
    leaf_color = params.get("color_leaves", "#2e8b57")
    trunk_color = params.get("color_trunk", "#8b5a2b")
    trunk = Obj(shape="cylinder", color=trunk_color, position=[0, -0.6, 0],
                params={"radiusTop": 0.12, "radiusBottom": 0.15, "height": 0.8}, roughness=0.9)
    # (y, radius, height) of each foliage tier, bottom to top
    tiers = ((0.2, 0.7, 1.0), (0.72, 0.55, 0.8), (1.14, 0.4, 0.65))
    foliage = [
        Obj(shape="cone", color=leaf_color, position=[0, y, 0], params={"radius": r, "height": h}, roughness=0.8)
        for y, r, h in tiers
    ]
    return [trunk] + foliage
def _template_nested_spheres(params: Dict[str, Any]) -> List[Obj]:
    """Expand ``nested_spheres``: two concentric wireframe spheres at the origin.

    The outer sphere (radius 0.8) comes first, then the inner one (radius
    0.45). Colours come from ``color_outer`` / ``color_inner``.
    """
    shells = (
        (params.get("color_outer", "blue"), 0.8),
        (params.get("color_inner", "red"), 0.45),
    )
    return [
        Obj(shape="sphere", color=tint, material="wireframe", position=[0, 0, 0], params={"radius": radius})
        for tint, radius in shells
    ]
# Per-shape: safe text face width (world units after normalization to 1.5)
# and a small y nudge so text sits in the visual body, not the tapered parts.
# Read by _template_badge_with_text, which falls back to
# {"safe_w": 1.0, "text_y": 0.0} for shapes missing from this table.
_BADGE_SHAPE_PARAMS: Dict[str, Dict[str, float]] = {
    "star": {"safe_w": 0.85, "text_y": 0.0},
    "heart": {"safe_w": 0.65, "text_y": 0.2},
    "hexagon": {"safe_w": 1.1, "text_y": 0.0},
    "badge": {"safe_w": 1.2, "text_y": 0.0},
    "shield": {"safe_w": 0.9, "text_y": 0.1},
}
# Used to estimate rendered text width as len(text) * size * _CHAR_W.
_CHAR_W = 0.65  # helvetiker average char width per size=1.0 unit
def _template_badge_with_text(params: Dict[str, Any]) -> List[Any]:
    """Deterministic badge+text: model supplies shape/text/colors; compiler sets layout.

    Optional params: ``shape`` (EXTRUDE_SHAPES, default "star"), ``text``,
    ``color_badge``, ``color_text``, ``metalness``, ``roughness``,
    ``preset_badge`` and ``preset_text``. Every value is sanitised with a
    fallback, so malformed input (``"metalness": "shiny"``, ``"text": null``)
    degrades to defaults instead of raising.

    Returns ``[ExtrudeNode, Text3DNode]``. The text is sized to fill about 80%
    of the shape's safe face width and is placed in front of the badge.
    """
    def unit(key: str, default: float) -> float:
        # Clamp to [0, 1]; unparseable or NaN input yields the default.
        try:
            value = float(params.get(key, default))
        except (TypeError, ValueError, OverflowError):
            return default
        return default if math.isnan(value) else _clamp(value, 0.0, 1.0)

    def preset(key: str) -> Optional[str]:
        # Same normalisation as the node models' preset validators.
        name = str(params.get(key) or "").lower().strip()
        return name if name in PRESET_NAMES else None

    shape = str(params.get("shape", "star")).lower().strip()
    if shape not in EXTRUDE_SHAPES:
        shape = "star"
    raw_text = params.get("text")
    # str(None) would otherwise put the word "None" on the badge.
    raw_text = "TEXT" if raw_text is None else str(raw_text)
    text = "".join(c for c in raw_text if c.isprintable() and ord(c) < 128)[:24].strip() or "TEXT"
    color_badge = _sanitize_color(str(params.get("color_badge", "#3a6bc4")), "#3a6bc4")
    color_text = _sanitize_color(str(params.get("color_text", "#ffffff")), "#ffffff")
    sp = _BADGE_SHAPE_PARAMS.get(shape, {"safe_w": 1.0, "text_y": 0.0})
    # Scale text so it fills ~80 % of the badge face width
    text_size = round(
        _clamp(sp["safe_w"] * 0.8 / (max(1, len(text)) * _CHAR_W), 0.12, 0.55), 3
    )
    return [
        ExtrudeNode(
            shape=shape, depth=0.15, bevel=True,
            color=color_badge, preset=preset("preset_badge"),
            metalness=unit("metalness", 0.5), roughness=unit("roughness", 0.25),
            emissive="#000000", position=[0.0, 0.0, 0.0],
        ),
        Text3DNode(
            text=text, size=text_size, depth=0.06, bevel=True,
            color=color_text, preset=preset("preset_text"),
            metalness=0.1, roughness=0.4, emissive="#000000",
            # z=0.15 always clears the badge front face (badge front ≈ 0.05–0.07 after scale)
            position=[0.0, sp["text_y"], 0.15],
        ),
    ]
# Registry of composite templates keyed by name. When the scene has no
# "objects" of its own, build_scene expands {"template": {"name": <key>, ...}}
# by calling the matching function with that template dict.
TEMPLATES: Dict[str, Any] = {
    "burger": _template_burger,
    "snowman": _template_snowman,
    "tree": _template_tree,
    "nested_spheres": _template_nested_spheres,
    "badge_with_text": _template_badge_with_text,
}
class Scene(BaseModel):
    """Top-level scene: background colour, node list, lights and animation.

    Every field is repaired in a ``mode="before"`` validator. Three cases
    would otherwise fail validation of the whole Scene and make build_scene
    discard every object:

    * a null or odd ``background``;
    * a malformed light entry;
    * a non-object ``animation``.

    Here each is handled locally: bad lights are skipped (with a warning) and
    a bad animation falls back to the default.
    """

    background: str = "#0b0e14"
    objects: List[Union[Obj, LayoutStack, LayoutRow, LayoutGrid, ExtrudeNode, Text3DNode, GroupNode]] = Field(default_factory=list)
    lights: List[Light] = Field(default_factory=list)
    animation: Animation = Field(default_factory=Animation)

    @field_validator("background", mode="before")
    @classmethod
    def _bg(cls, v: Any) -> str:
        return _sanitize_color(str(v), "#0b0e14")

    @field_validator("objects", mode="before")
    @classmethod
    def _objects(cls, v: Any) -> List[Any]:
        if not isinstance(v, list):
            return []
        return [_parse_scene_item(item) for item in v]

    @field_validator("lights", mode="before")
    @classmethod
    def _lights(cls, v: Any) -> List[Light]:
        """Build each light individually, skipping entries that cannot be parsed."""
        if not isinstance(v, list):
            return []
        lights: List[Light] = []
        for item in v:
            if isinstance(item, Light):
                lights.append(item)
                continue
            if not isinstance(item, dict):
                log.warning("Skipping light that is not an object (%s)", type(item).__name__)
                continue
            try:
                lights.append(Light(**item))
            except (TypeError, ValueError) as e:  # pydantic ValidationError is a ValueError
                log.warning("Skipping invalid light (%s)", type(e).__name__)
        return lights

    @field_validator("animation", mode="before")
    @classmethod
    def _animation(cls, v: Any) -> Animation:
        """Accept an Animation, a dict, or a bare type string such as "orbit"."""
        if isinstance(v, Animation):
            return v
        if isinstance(v, str):
            v = {"type": v}
        if not isinstance(v, dict):
            return Animation()
        try:
            return Animation(**v)
        except (TypeError, ValueError) as e:
            log.warning("Invalid animation (%s), using default", type(e).__name__)
            return Animation()
def extract_json(text: str) -> Optional[dict]:
    """Pull the first complete JSON object out of a raw model response.

    A Markdown code fence at the very start or end is stripped first. The
    object starting at the first ``{`` is then decoded with
    ``JSONDecoder.raw_decode``, which stops at the object's matching brace.
    Any text after the object is therefore ignored, even text that contains
    ``}`` or a second object. If that fails, the span from the first ``{``
    to the last ``}`` is tried as a fallback.

    Returns the decoded dict, or None when no object can be parsed.
    """
    if not text:
        return None
    text = text.strip()
    text = re.sub(r"^```(?:json)?", "", text).strip()
    text = re.sub(r"```$", "", text).strip()
    start = text.find("{")
    if start == -1:
        return None
    try:
        obj, _end = json.JSONDecoder().raw_decode(text, start)
    except (ValueError, RecursionError):
        pass
    else:
        if isinstance(obj, dict):
            return obj
    end = text.rfind("}")
    if end <= start:
        return None
    try:
        obj = json.loads(text[start:end + 1])
    except (ValueError, RecursionError):
        return None
    return obj if isinstance(obj, dict) else None
def build_scene(data: Optional[dict]) -> Scene:
    """Validate/repair raw model output into a Scene, guaranteeing something renderable.

    * Non-dict input is treated as an empty scene.
    * ``{"template": {"name": ..., ...}}`` is expanded through TEMPLATES when
      the scene has no ``objects`` of its own. A template that raises during
      expansion is logged and ignored, so it never escapes this function.
    * If Scene validation fails, a default Scene is used.
    * An empty object list gets a default box. An empty light list gets an
      ambient + directional pair.
    """
    if not isinstance(data, dict):
        log.warning("build_scene: no valid dict, using empty scene")
        data = {}
    # Expand named templates into flat object lists before schema validation
    tmpl = data.get("template")
    if isinstance(tmpl, dict) and not data.get("objects"):
        name = tmpl.get("name", "")
        # `name` is arbitrary JSON: an unhashable list/dict would make the
        # `in TEMPLATES` lookup raise TypeError.
        if isinstance(name, str) and name in TEMPLATES:
            log.info("Expanding template: %s", name)
            try:
                expanded = [o.model_dump() for o in TEMPLATES[name](tmpl)]
            except Exception as e:
                # Template params are untrusted model output; fall through to
                # the default-object handling below instead of raising.
                log.warning("Template %s failed (%s), ignoring it", name, type(e).__name__)
            else:
                data = {k: v for k, v in data.items() if k != "template"}
                data["objects"] = expanded
    try:
        scene = Scene(**data)
    except Exception as e:
        log.warning("Scene validation failed (%s), falling back to default", type(e).__name__)
        scene = Scene()
    if not scene.objects:
        log.warning("build_scene: no objects, inserting default box")
        scene.objects = [Obj()]
    if not scene.lights:
        log.warning("build_scene: no lights, inserting defaults")
        scene.lights = [
            Light(type="ambient", intensity=0.5),
            Light(type="directional", intensity=1.3, position=[5, 8, 6]),
        ]
    return scene