`;
node.querySelector(".kin-node-delete").addEventListener("click", (event) => {
event.stopPropagation();
deleteNode(link.id);
});
const input = node.querySelector("input");
input.addEventListener("focus", () => {
state.selectedNode = link.id;
state.selectedEdge = null;
state.pendingParent = null;
state.connectMode = false;
setStatus(`Link ${link.id} selected`);
updatePromptStatus();
});
input.addEventListener("input", () => {
link.name = input.value;
syncTree();
updatePromptStatus();
});
node.addEventListener("click", (event) => {
if (event.target.tagName === "INPUT" || event.target.tagName === "BUTTON") {
return;
}
handleNodeClick(link.id);
});
node.addEventListener("pointerdown", (event) => {
if (event.target.tagName === "INPUT" || event.target.tagName === "BUTTON") {
return;
}
state.dragging = {
id: link.id,
startX: event.clientX,
startY: event.clientY,
nodeX: link.x,
nodeY: link.y,
moved: false
};
node.classList.add("dragging");
node.setPointerCapture(event.pointerId);
});
node.addEventListener("pointermove", (event) => {
if (!state.dragging || state.dragging.id !== link.id) {
return;
}
const dx = event.clientX - state.dragging.startX;
const dy = event.clientY - state.dragging.startY;
if (Math.abs(dx) + Math.abs(dy) > 3) {
state.dragging.moved = true;
}
const { width: nodeWidth, height: nodeHeight } = currentNodeSize();
const maxX = Math.max(0, canvas.clientWidth - nodeWidth - 10);
const maxY = Math.max(0, canvas.clientHeight - nodeHeight - 10);
link.x = Math.min(maxX, Math.max(0, state.dragging.nodeX + dx));
link.y = Math.min(maxY, Math.max(0, state.dragging.nodeY + dy));
node.style.left = `${link.x}px`;
node.style.top = `${link.y}px`;
renderEdges();
syncTree();
});
node.addEventListener("pointerup", () => {
if (state.dragging && state.dragging.moved) {
state.suppressClick = true;
window.setTimeout(() => {
state.suppressClick = false;
}, 0);
}
state.dragging = null;
node.classList.remove("dragging");
});
nodeLayer.appendChild(node);
});
syncTree();
updatePromptStatus();
}
/**
 * Escape the characters that are significant in HTML markup so that `value`
 * can be interpolated into element content or a double-quoted attribute.
 *
 * @param {*} value - Any value; it is coerced with String() first.
 * @returns {string} The escaped text.
 */
function escapeHtml(value) {
  // "&" has to be handled first, otherwise the ampersands introduced by the
  // other entities would be escaped a second time.
  return String(value)
    .replaceAll("&", "&amp;")
    .replaceAll('"', "&quot;")
    .replaceAll("<", "&lt;")
    .replaceAll(">", "&gt;");
}
/**
 * React to a click on a link node.
 *
 * Outside connect mode the node simply becomes the selection. In connect mode
 * the first click picks the parent and the second click picks the child, at
 * which point a revolute joint is attempted and connect mode ends.
 *
 * @param {number} nodeId - Id of the clicked link.
 */
function handleNodeClick(nodeId) {
  if (state.suppressClick) {
    return;
  }
  state.selectedEdge = null;

  if (state.connectMode && state.pendingParent !== null) {
    // Second click of the connect gesture: try to create the joint.
    const parentId = state.pendingParent;
    const message = addJoint(parentId, nodeId, "revolute");
    state.pendingParent = null;
    state.selectedNode = null;
    state.connectMode = false;
    setStatus(message);
  } else if (state.connectMode) {
    // First click of the connect gesture: remember the parent.
    state.pendingParent = nodeId;
    state.selectedNode = nodeId;
    setStatus(`Parent Link ${nodeId} selected. Click a child link.`);
  } else {
    state.selectedNode = nodeId;
    state.pendingParent = null;
    setStatus(`Link ${nodeId} selected`);
    updatePromptStatus();
  }
  render();
}
/**
 * Append a new link to the editor state and select it.
 *
 * The new id is the current number of links; the first link is named "base"
 * and later ones "link_<id>". Colors cycle through the palette.
 */
function addNode() {
  const nextId = state.links.length;
  const label = nextId === 0 ? "base" : `link_${nextId}`;
  // Position fields come first so id/name/color always win on a key clash.
  const newLink = Object.assign({}, defaultNodePosition(nextId), {
    id: nextId,
    name: label,
    color: palette[nextId % palette.length]
  });
  state.links.push(newLink);

  state.selectedNode = nextId;
  state.selectedEdge = null;
  setStatus(`Added Link ${nextId}`);
  updatePromptStatus();
  render();
}
/**
 * Remove a link and renumber everything that referred to link ids.
 *
 * Surviving links are re-indexed densely from 0 (and recolored from the
 * palette); prompts and joints are rewritten to the new ids, and any joint or
 * prompt that touched the removed link is dropped. The selection is cleared.
 *
 * @param {number} nodeId - Id of the link to delete.
 */
function deleteNode(nodeId) {
  const survivors = state.links.filter((link) => link.id !== nodeId);

  // Old id -> new dense id. Link objects are updated in place.
  const idMap = new Map();
  for (const [position, link] of survivors.entries()) {
    idMap.set(link.id, position);
    link.id = position;
    link.color = palette[position % palette.length];
  }
  state.links = survivors;

  const nextPrompts = {};
  for (const [rawLinkId, prompt] of Object.entries(state.prompts)) {
    const mappedId = idMap.get(Number(rawLinkId));
    if (mappedId !== undefined) {
      nextPrompts[mappedId] = prompt;
    }
  }
  state.prompts = nextPrompts;

  const nextJoints = [];
  for (const joint of state.joints) {
    if (joint.parent === nodeId || joint.child === nodeId) {
      continue;
    }
    const parent = idMap.get(joint.parent);
    const child = idMap.get(joint.child);
    if (parent === undefined || child === undefined) {
      continue;
    }
    nextJoints.push({ parent, child, type: joint.type });
  }
  state.joints = nextJoints;

  state.selectedNode = null;
  state.selectedEdge = null;
  state.pendingParent = null;
  setStatus(`Deleted Link ${nodeId}`);
  syncPrompts();
  updatePromptStatus();
  render();
}
/**
 * Try to add a joint from `parent` to `child`.
 *
 * The joint is rejected when it is a self-loop, a duplicate, would give the
 * child a second parent, or would close a cycle.
 *
 * @param {number} parent - Parent link id.
 * @param {number} child - Child link id.
 * @param {string} type - Joint type label stored on the joint.
 * @returns {string} A status message describing what happened.
 */
function addJoint(parent, child, type) {
  if (parent === child) {
    return "A joint cannot connect a link to itself.";
  }

  // Every joint that already ends at the requested child.
  const incoming = state.joints.filter((joint) => joint.child === child);
  if (incoming.some((joint) => joint.parent === parent)) {
    return "That joint already exists.";
  }
  if (incoming.length > 0) {
    return `Link ${child} already has a parent.`;
  }

  if (wouldCreateCycle(parent, child)) {
    return "That joint would create a cycle.";
  }

  state.joints.push({ parent, child, type });
  return `Added ${type} joint: Link ${parent} to Link ${child}`;
}
/**
 * Report whether `parent` is reachable from `child` by following existing
 * joints in the parent -> child direction, i.e. whether adding the joint
 * parent -> child would close a loop.
 *
 * @param {number} parent - Proposed parent link id.
 * @param {number} child - Proposed child link id.
 * @returns {boolean} True when the new joint would create a cycle.
 */
function wouldCreateCycle(parent, child) {
  const visited = new Set();
  const pending = [child];
  while (pending.length > 0) {
    const nodeId = pending.pop();
    if (nodeId === parent) {
      return true;
    }
    if (!visited.has(nodeId)) {
      visited.add(nodeId);
      for (const joint of state.joints) {
        if (joint.parent === nodeId) {
          pending.push(joint.child);
        }
      }
    }
  }
  return false;
}
// ---- Toolbar buttons -------------------------------------------------------
addNodeButton.addEventListener("click", addNode);
// Toggles connect mode; any half-finished parent pick and edge selection are
// discarded each time the mode flips.
addJointButton.addEventListener("click", () => {
  state.connectMode = !state.connectMode;
  state.pendingParent = null;
  state.selectedEdge = null;
  setStatus(state.connectMode ? "Creating a revolute joint. Click a parent link, then a child link." : "Joint creation cancelled.");
  render();
});
// Deletes the selected link if there is one, otherwise the selected joint.
deleteButton.addEventListener("click", () => {
  if (state.selectedNode !== null) {
    deleteNode(state.selectedNode);
    return;
  }
  if (state.selectedEdge !== null) {
    deleteJoint(state.selectedEdge);
    return;
  }
  setStatus("Select a link or joint to delete.");
});
resetButton.addEventListener("click", () => {
  loadTree(defaultTree);
  setStatus("Reset to the default tree.");
});
// ---- Prompt mesh canvas (optional element) ---------------------------------
if (promptCanvas) {
  // Remember where the gesture started and the camera angles at that moment.
  promptCanvas.addEventListener("pointerdown", (event) => {
    state.promptDrag = {
      startX: event.clientX,
      startY: event.clientY,
      yaw: state.promptCamera.yaw,
      pitch: state.promptCamera.pitch,
      moved: false
    };
    promptCanvas.setPointerCapture(event.pointerId);
  });
  // Once the pointer has travelled more than 3px (|dx| + |dy|) the gesture is
  // treated as an orbit drag: yaw/pitch change by 0.01 per pixel.
  promptCanvas.addEventListener("pointermove", (event) => {
    if (!state.promptDrag) {
      return;
    }
    const dx = event.clientX - state.promptDrag.startX;
    const dy = event.clientY - state.promptDrag.startY;
    if (Math.abs(dx) + Math.abs(dy) > 3) {
      state.promptDrag.moved = true;
    }
    if (state.promptDrag.moved) {
      state.promptCamera.yaw = state.promptDrag.yaw - dx * 0.01;
      state.promptCamera.pitch = state.promptDrag.pitch + dy * 0.01;
      renderPromptMesh();
    }
  });
  // A press/release without movement counts as a click and picks a point.
  promptCanvas.addEventListener("pointerup", (event) => {
    const drag = state.promptDrag;
    state.promptDrag = null;
    if (drag && !drag.moved) {
      pickPromptPoint(event);
    }
  });
  // Wheel zoom: the camera distance is scaled exponentially with deltaY and
  // clamped to a range derived from promptCamera.radius. The listener is
  // registered as non-passive so that preventDefault() is honoured.
  promptCanvas.addEventListener("wheel", (event) => {
    event.preventDefault();
    const factor = Math.exp(event.deltaY * 0.0012);
    const minDistance = Math.max(1e-5, state.promptCamera.radius * 0.35);
    const maxDistance = Math.max(1, state.promptCamera.radius * 8);
    state.promptCamera.distance = Math.max(
      minDistance,
      Math.min(maxDistance, state.promptCamera.distance * factor)
    );
    renderPromptMesh();
  }, { passive: false });
}
// ---- Prompt management buttons (each is optional) --------------------------
if (clearPromptButton) {
  // Removes only the prompt of the currently selected link.
  clearPromptButton.addEventListener("click", () => {
    if (state.selectedNode === null) {
      updatePromptStatus("Select a link before clearing its prompt.");
      return;
    }
    delete state.prompts[state.selectedNode];
    syncPrompts();
    updatePromptStatus();
    render();
    renderPromptMesh();
  });
}
if (clearAllPromptsButton) {
  clearAllPromptsButton.addEventListener("click", () => {
    state.prompts = {};
    syncPrompts();
    updatePromptStatus("Cleared all point prompts.");
    render();
    renderPromptMesh();
  });
}
if (resetPromptViewButton) {
  resetPromptViewButton.addEventListener("click", resetPromptCamera);
}
// ---- External sync boxes ---------------------------------------------------
// Each box is re-read on both "input" and "change" events.
if (promptMeshBox) {
  promptMeshBox.addEventListener("input", loadPromptMeshFromBox);
  promptMeshBox.addEventListener("change", loadPromptMeshFromBox);
}
syncBox.addEventListener("input", loadExternalTreeFromBox);
syncBox.addEventListener("change", loadExternalTreeFromBox);
if (promptSyncBox) {
  promptSyncBox.addEventListener("input", loadExternalPromptsFromBox);
  promptSyncBox.addEventListener("change", loadExternalPromptsFromBox);
}
// Clicking the empty canvas (or one of its bare layers) clears the selection.
canvas.addEventListener("click", (event) => {
  if (event.target === canvas || event.target === edgeLayer || event.target === nodeLayer) {
    state.selectedNode = null;
    state.selectedEdge = null;
    state.pendingParent = null;
    updatePromptStatus();
    render();
  }
});
// ---- Layout ----------------------------------------------------------------
window.addEventListener("resize", scheduleResponsiveRelayout);
// ResizeObserver is feature-detected; without it only window resizes relayout.
if (typeof ResizeObserver !== "undefined") {
  const resizeObserver = new ResizeObserver(scheduleResponsiveRelayout);
  resizeObserver.observe(canvas);
}
// Poll the three boxes every 900 ms in addition to the event listeners above.
// NOTE(review): presumably covers programmatic value updates that fire no DOM
// events - confirm against the code that writes these boxes.
window.setInterval(() => {
  loadPromptMeshFromBox();
  loadExternalTreeFromBox();
  loadExternalPromptsFromBox();
}, 900);
// ---- Initial state ---------------------------------------------------------
// Start from the default tree, then apply whatever the boxes already contain.
loadTree(defaultTree);
loadPromptMeshFromBox();
loadExternalTreeFromBox();
loadExternalPromptsFromBox();
renderPromptMesh();
syncPrompts();
// Relayout immediately and again after 50 ms and 250 ms.
scheduleResponsiveRelayout();
window.setTimeout(scheduleResponsiveRelayout, 50);
window.setTimeout(scheduleResponsiveRelayout, 250);
setStatus("Drag links to arrange the tree.");
}
waitForEditor();
}
"""
)
def _extract_gradio_path(value: Any) -> Path | None:
if value is None:
return None
if isinstance(value, dict):
raw_path = value.get("path") or value.get("name")
else:
raw_path = value
if raw_path is None:
return None
return Path(str(raw_path)).expanduser().resolve()
def _tree_to_pretty_json(tree: dict[str, Any]) -> str:
return json.dumps(tree, indent=2)
def _load_json_object(raw_value: str) -> dict[str, Any]:
try:
parsed = json.loads(raw_value)
except json.JSONDecodeError as exc:
raise ValueError(f"Kinematic tree must be valid JSON: {exc}") from exc
if not isinstance(parsed, dict):
raise ValueError("Kinematic tree JSON must be an object.")
return parsed
def _link_identifier(link: Any, fallback_id: int) -> tuple[int, str]:
if isinstance(link, str):
return fallback_id, link.strip()
if isinstance(link, dict):
raw_id = link.get("id", fallback_id)
raw_name = link.get("name", f"link_{raw_id}")
return int(raw_id), str(raw_name).strip()
raise ValueError(f"Links must be strings or objects, got {type(link).__name__}.")
def _resolve_link_ref(raw_value: Any, *, name_to_id: dict[str, int], num_links: int) -> int:
if isinstance(raw_value, str):
stripped = raw_value.strip()
if stripped in name_to_id:
return int(name_to_id[stripped])
if stripped.lstrip("+-").isdigit():
raw_value = int(stripped)
else:
raise ValueError(f"Unknown link reference {raw_value!r}.")
link_id = int(raw_value)
if not 0 <= link_id < num_links:
raise ValueError(f"Link ID {link_id} is outside [0, {num_links - 1}].")
return link_id
def _joint_type_from_record(record: dict[str, Any]) -> str:
raw_type = (
record.get("type")
or record.get("joint_type")
or record.get("motion_type")
)
if raw_type is None:
if bool(record.get("is_revolute", False)):
raw_type = "revolute"
elif bool(record.get("is_prismatic", False)):
raw_type = "prismatic"
joint_type = str(raw_type or "").strip().lower()
if joint_type not in {"revolute", "prismatic"}:
raise ValueError(
"Each joint must specify type 'revolute' or 'prismatic'."
)
return joint_type
def parse_kinematic_tree(raw_value: str) -> tuple[list[str], list[tuple[int, int, str]]]:
    """Parse and validate a kinematic tree given as JSON text.

    The payload must be an object with a non-empty ``"links"`` (or
    ``"link_names"``) list and an optional ``"joints"`` list. Joints are either
    objects or ``[parent, child, type]`` triples; parent/child may be link
    names or IDs.

    Returns:
        ``(link_names, joint_specs)`` where ``link_names`` is ordered by link
        ID and each joint spec is ``(parent_id, child_id, joint_type)``.

    Raises:
        ValueError: If the payload is malformed. ``build_joint_tensors`` is
            also invoked on the result before returning, so whatever it raises
            propagates as well.
    """

    def _first_present(record: dict[str, Any], keys: tuple[str, ...]) -> Any:
        # Same as nested ``dict.get`` fallbacks: the first key that exists
        # wins, even when its value is ``None``.
        for key in keys:
            if key in record:
                return record[key]
        return None

    payload = _load_json_object(raw_value)
    raw_links = payload["links"] if "links" in payload else payload.get("link_names")
    if not (isinstance(raw_links, list) and raw_links):
        raise ValueError("Kinematic tree must contain a non-empty 'links' list.")

    seen_ids: list[int] = []
    names_by_id: dict[int, str] = {}
    for position, link in enumerate(raw_links):
        link_id, link_name = _link_identifier(link, position)
        seen_ids.append(link_id)
        names_by_id[link_id] = link_name

    link_count = len(seen_ids)
    if sorted(seen_ids) != list(range(link_count)):
        raise ValueError("Link IDs must be dense integers starting at 0.")
    link_names = [names_by_id[link_id] for link_id in range(link_count)]
    if not all(link_names):
        raise ValueError("Link names must be non-empty.")
    name_to_id = {name: link_id for link_id, name in enumerate(link_names)}

    raw_joints = payload.get("joints", [])
    if not isinstance(raw_joints, list):
        raise ValueError("'joints' must be a list.")

    joint_specs: list[tuple[int, int, str]] = []
    for joint in raw_joints:
        if isinstance(joint, dict):
            record = joint
        elif isinstance(joint, (list, tuple)) and len(joint) == 3:
            record = dict(zip(("parent", "child", "type"), joint))
        else:
            raise ValueError(
                "Each joint must be an object or [parent, child, type] list."
            )

        parent_ref = _first_present(record, ("parent", "parent_id", "parent_link_id"))
        child_ref = _first_present(record, ("child", "child_id", "child_link_id"))
        if parent_ref is None or child_ref is None:
            raise ValueError("Each joint must define parent and child links.")

        parent_id = _resolve_link_ref(parent_ref, name_to_id=name_to_id, num_links=link_count)
        child_id = _resolve_link_ref(child_ref, name_to_id=name_to_id, num_links=link_count)
        joint_type = _joint_type_from_record(record)
        joint_specs.append((parent_id, child_id, joint_type))

    # Called only for its validation side effect; the result is discarded.
    build_joint_tensors(link_count, joint_specs)
    return link_names, joint_specs
def _compact_json(value: dict[str, Any]) -> str:
return json.dumps(value, separators=(",", ":"))
def _prompt_mesh_payload(mesh: Any) -> str:
    """Build the compact JSON payload consumed by the prompt picker.

    Only vertices referenced by at least one face are emitted, and faces are
    re-indexed accordingly. Per-face unit normals, the bounding-box center and
    a bounding radius (half the bounding-box diagonal, computed over *all*
    input vertices) are included.

    Raises:
        ValueError: If the mesh has no 3D vertices or no triangular faces.
    """
    vertices = np.asarray(mesh.vertices, dtype=np.float32)
    faces = np.asarray(mesh.faces, dtype=np.int64)
    if vertices.ndim != 2 or vertices.shape[1] != 3 or len(vertices) == 0:
        raise ValueError("Prompt picker requires a mesh with 3D vertices.")
    if faces.ndim != 2 or faces.shape[1] != 3 or len(faces) == 0:
        raise ValueError("Prompt picker requires a triangular mesh.")
    source_face_count = int(faces.shape[0])

    # Drop unreferenced vertices: old index -> position among the used ones.
    used_ids = np.unique(faces.reshape(-1))
    old_to_new = np.full((vertices.shape[0],), -1, dtype=np.int64)
    old_to_new[used_ids] = np.arange(len(used_ids), dtype=np.int64)
    kept_vertices = vertices[used_ids]
    kept_faces = old_to_new[faces].astype(np.int32, copy=False)

    # Per-face normals; degenerate triangles end up with a zero normal.
    corners = kept_vertices[kept_faces]
    edge_a = corners[:, 1] - corners[:, 0]
    edge_b = corners[:, 2] - corners[:, 0]
    face_normals = np.cross(edge_a, edge_b).astype(np.float32, copy=False)
    lengths = np.linalg.norm(face_normals, axis=1, keepdims=True)
    face_normals = face_normals / np.maximum(lengths, np.float32(1e-8))

    lower = vertices.min(axis=0)
    upper = vertices.max(axis=0)
    center = ((lower + upper) * 0.5).astype(np.float32, copy=False)
    radius = float(np.linalg.norm((upper - lower) * 0.5))
    if radius <= 0.0:
        radius = 1.0

    return _compact_json(
        {
            "vertices": np.round(kept_vertices, 6).tolist(),
            "faces": kept_faces.tolist(),
            "normals": np.round(face_normals, 6).tolist(),
            "center": np.round(center, 6).tolist(),
            "radius": radius,
            "source_faces": source_face_count,
            "display_faces": int(kept_faces.shape[0]),
            "sampled": False,
        }
    )
def _up_dir_slug(up_dir: str) -> str:
return up_dir.replace("+", "pos").replace("-", "neg")
def _upright_rendering_placeholder_path() -> str:
    """Return the path of the "Rendering" placeholder image, creating it once.

    The PNG lives under ``OUTPUT_ROOT / "_ui"`` and is only drawn when it does
    not exist yet. DejaVu Sans is used when available, otherwise Pillow's
    default font.
    """
    from PIL import Image, ImageDraw, ImageFont

    placeholder_path = OUTPUT_ROOT / "_ui" / "upright_orientation_rendering.png"
    if placeholder_path.exists():
        return str(placeholder_path)

    placeholder_path.parent.mkdir(parents=True, exist_ok=True)
    canvas_width, canvas_height = 768, 384
    image = Image.new("RGB", (canvas_width, canvas_height), (248, 250, 252))
    draw = ImageDraw.Draw(image)
    try:
        fonts = (
            ImageFont.truetype("DejaVuSans.ttf", 38),
            ImageFont.truetype("DejaVuSans.ttf", 20),
        )
    except Exception:
        # If either size fails to load, both captions use the default font.
        fonts = (ImageFont.load_default(), ImageFont.load_default())
    title_font, subtitle_font = fonts

    # (text, font, top offset in px, RGB fill); each line is centered
    # horizontally using its measured width.
    captions = (
        ("Rendering", title_font, 150, (17, 24, 39)),
        ("Preparing six upright orientation previews...", subtitle_font, 204, (75, 85, 99)),
    )
    for text, font, top, fill in captions:
        left, _top, right, _bottom = draw.textbbox((0, 0), text, font=font)
        draw.text(
            ((canvas_width - (right - left)) / 2, top),
            text,
            fill=fill,
            font=font,
        )
    image.save(placeholder_path)
    return str(placeholder_path)
def _upright_rendering_preview_paths() -> list[str]:
    """Return the placeholder image path repeated once per up-direction choice."""
    placeholder = _upright_rendering_placeholder_path()
    return [placeholder] * len(UP_DIR_CHOICES)
def _upright_preview_paths(gallery_items: list[tuple[str, str]]) -> list[str | None]:
    """Return exactly one image path per up-direction choice.

    Captions are dropped; missing entries are padded with ``None`` and surplus
    entries are cut off.
    """
    slot_count = len(UP_DIR_CHOICES)
    paths: list[str | None] = []
    for image_path, _caption in gallery_items:
        paths.append(image_path)
    paths += [None] * max(0, slot_count - len(paths))
    return paths[:slot_count]
def _mesh_file_sha256(mesh_path: Path) -> str:
digest = hashlib.sha256()
with mesh_path.open("rb") as file:
for chunk in iter(lambda: file.read(1024 * 1024), b""):
digest.update(chunk)
return digest.hexdigest()
def _mesh_cache_root() -> Path:
    """Return the directory under ``OUTPUT_ROOT`` that holds all per-mesh caches."""
    return OUTPUT_ROOT.joinpath("mesh_cache")
def _mesh_cache_dir(mesh_hash: str) -> Path:
    """Return the cache directory for the mesh identified by ``mesh_hash``."""
    cache_root = _mesh_cache_root()
    return cache_root.joinpath(str(mesh_hash))
def _timestamped_mesh_output_dir(output_root: Path, mesh_hash: str, suffix: str = "") -> Path:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
name_parts = [str(mesh_hash)]
if suffix:
name_parts.append(str(suffix))
name_parts.append(timestamp)
return output_root / "_".join(name_parts)
def _copy_original_input_mesh(mesh_path: Path, output_dir: Path) -> Path:
suffix = mesh_path.suffix.lower() or ".mesh"
output_path = output_dir / f"input_mesh_original{suffix}"
output_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(mesh_path, output_path)
return output_path
def _upright_preview_cache_dir(mesh_hash: str) -> Path:
    """Return the cache directory that holds a mesh's upright preview images."""
    mesh_dir = _mesh_cache_dir(mesh_hash)
    return mesh_dir.joinpath("upright_previews")
def _cached_upright_preview_items(mesh_hash: str) -> list[tuple[str, str]] | None:
    """Return cached ``(image path, caption)`` pairs for every up direction.

    Returns:
        One item per entry of ``UP_DIR_CHOICES``, or ``None`` as soon as any
        preview image is missing from the cache.
    """
    preview_dir = _upright_preview_cache_dir(mesh_hash)
    candidates = [
        (preview_dir / f"up_{_up_dir_slug(up_dir)}.png", up_dir)
        for up_dir in UP_DIR_CHOICES
    ]
    if not all(image_path.exists() for image_path, _ in candidates):
        return None
    return [(str(image_path), f"{up_dir} up") for image_path, up_dir in candidates]
def _auto_kinematics_cache_dir(mesh_hash: str) -> Path:
    """Return the cache directory for a mesh's auto-kinematics results."""
    mesh_dir = _mesh_cache_dir(mesh_hash)
    return mesh_dir.joinpath("auto_kinematics")
def _auto_kinematics_cache_complete_path(cache_dir: Path) -> Path:
return cache_dir / "cache_complete.json"
def _cached_auto_kinematics(
    cache_dir: Path,
) -> tuple[str, str] | None:
    """Load cached auto-kinematics output, if a valid cache entry exists.

    Args:
        cache_dir: Directory previously populated by the cache writer.

    Returns:
        ``(kinematic tree JSON, point prompt JSON)`` with surrounding
        whitespace stripped, or ``None`` on any kind of cache miss: no
        completion marker, unreadable/corrupt marker, version mismatch, or a
        missing payload file.
    """
    complete_path = _auto_kinematics_cache_complete_path(cache_dir)
    if not complete_path.exists():
        return None
    try:
        complete_payload = json.loads(complete_path.read_text(encoding="utf-8"))
        version = int(complete_payload.get("version", 0))
    except (OSError, ValueError, TypeError, AttributeError):
        # A truncated, non-object or otherwise unreadable marker means the
        # entry cannot be trusted; treat it as a miss so the caller recomputes
        # instead of crashing. (json.JSONDecodeError and UnicodeDecodeError
        # are both ValueError subclasses.)
        return None
    if version != AUTO_KINEMATICS_CACHE_VERSION:
        return None
    tree_path = cache_dir / "demo_kinematic_tree.json"
    prompt_path = cache_dir / "demo_point_prompts.json"
    if not tree_path.exists() or not prompt_path.exists():
        return None
    try:
        return (
            tree_path.read_text(encoding="utf-8").strip(),
            prompt_path.read_text(encoding="utf-8").strip(),
        )
    except FileNotFoundError:
        # The entry was replaced or removed between the existence check and
        # the read.
        return None
def _store_auto_kinematics_cache(source_dir: Path, cache_dir: Path) -> None:
    """Publish ``source_dir`` as the auto-kinematics cache entry ``cache_dir``.

    The source is copied to a uniquely named temporary sibling directory,
    validated, stamped with a completion marker and only then moved into
    place, replacing any previous entry.

    Args:
        source_dir: Directory holding the freshly produced output.
        cache_dir: Final cache location.

    Raises:
        FileNotFoundError: If the tree JSON, the prompt JSON or every
            ``renders/view_*.png`` image is missing from the output.
    """
    temp_dir = cache_dir.parent / f".{cache_dir.name}.tmp_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"
    if temp_dir.exists():
        shutil.rmtree(temp_dir)
    temp_dir.parent.mkdir(parents=True, exist_ok=True)
    try:
        shutil.copytree(source_dir, temp_dir)
        tree_path = temp_dir / "demo_kinematic_tree.json"
        prompt_path = temp_dir / "demo_point_prompts.json"
        renders_dir = temp_dir / "renders"
        render_paths = sorted(renders_dir.glob("view_*.png")) if renders_dir.exists() else []
        missing_items: list[str] = []
        if not tree_path.exists():
            missing_items.append(str(tree_path.name))
        if not prompt_path.exists():
            missing_items.append(str(prompt_path.name))
        if len(render_paths) == 0:
            missing_items.append("renders/view_*.png")
        if missing_items:
            raise FileNotFoundError(
                "Auto-kinematics output is incomplete; missing "
                + ", ".join(missing_items)
            )
        # The marker is written last so that its presence implies a complete
        # copy.
        _auto_kinematics_cache_complete_path(temp_dir).write_text(
            json.dumps(
                {
                    "version": AUTO_KINEMATICS_CACHE_VERSION,
                    "created_at": datetime.now().isoformat(timespec="seconds"),
                    "render_count": len(render_paths),
                },
                indent=2,
            )
            + "\n",
            encoding="utf-8",
        )
        if cache_dir.exists():
            shutil.rmtree(cache_dir)
        temp_dir.rename(cache_dir)
    except BaseException:
        # Never leave a half-built hidden temp directory behind, whichever
        # step failed (copy, validation, marker write or final swap).
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
def _preview_face_colors(mesh: Any) -> np.ndarray:
faces = np.asarray(mesh.faces, dtype=np.int64)
fallback = np.tile(np.asarray([[178, 190, 205, 255]], dtype=np.uint8), (len(faces), 1))
visual = getattr(mesh, "visual", None)
if visual is None:
return fallback
uv = getattr(visual, "uv", None)
material = getattr(visual, "material", None)
texture = None
if material is not None:
texture = getattr(material, "baseColorTexture", None) or getattr(material, "image", None)
if uv is not None and texture is not None:
try:
tex = np.asarray(texture.convert("RGBA"), dtype=np.float32)
uv_array = np.asarray(uv, dtype=np.float32)
if uv_array.ndim == 2 and uv_array.shape[0] >= int(faces.max()) + 1:
face_uv = uv_array[faces].mean(axis=1)
face_uv = np.clip(face_uv, 0.0, 1.0)
height, width = tex.shape[:2]
x = np.rint(face_uv[:, 0] * (width - 1)).astype(np.int64)
y = np.rint((1.0 - face_uv[:, 1]) * (height - 1)).astype(np.int64)
colors = tex[y, x]
base_factor = getattr(material, "baseColorFactor", None)
if base_factor is not None:
factor = np.asarray(base_factor, dtype=np.float32).reshape(-1)
if factor.size >= 3:
colors[:, :3] *= factor[:3]
if factor.size >= 4:
colors[:, 3] *= factor[3]
return np.clip(colors, 0, 255).astype(np.uint8)
except Exception:
traceback.print_exc()
for attr_name in ("face_colors", "vertex_colors"):
try:
color_array = np.asarray(getattr(visual, attr_name))
except Exception:
continue
if color_array.ndim != 2 or color_array.shape[0] == 0:
continue
colors = color_array.astype(np.float32, copy=False)
if colors.max(initial=0.0) <= 1.0:
colors = colors * 255.0
if colors.shape[1] == 3:
colors = np.concatenate(
[colors, np.full((colors.shape[0], 1), 255.0, dtype=np.float32)],
axis=1,
)
if attr_name == "face_colors" and colors.shape[0] == len(faces):
return np.clip(colors[:, :4], 0, 255).astype(np.uint8)
if attr_name == "vertex_colors" and colors.shape[0] >= int(faces.max()) + 1:
return np.clip(colors[faces].mean(axis=1)[:, :4], 0, 255).astype(np.uint8)
return fallback
def _camera_basis_for_preview(
*,
pitch_deg: float = 58.0,
azimuth_deg: float = 225.0,
camera_distance: float = 2.25,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
pitch = np.deg2rad(float(pitch_deg))
azimuth = np.deg2rad(float(azimuth_deg))
horizontal_distance = float(camera_distance) * np.sin(pitch)
camera = np.asarray(
[
horizontal_distance * np.sin(azimuth),
horizontal_distance * np.cos(azimuth),
float(camera_distance) * np.cos(pitch),
],
dtype=np.float32,
)
forward = -camera
forward /= max(float(np.linalg.norm(forward)), 1e-8)
world_up = np.asarray([0.0, 0.0, 1.0], dtype=np.float32)
right = np.cross(forward, world_up)
right /= max(float(np.linalg.norm(right)), 1e-8)
up = np.cross(right, forward)
up /= max(float(np.linalg.norm(up)), 1e-8)
return camera, right.astype(np.float32), up.astype(np.float32), forward.astype(np.float32)
def _render_up_direction_preview_software(
    mesh: Any,
    *,
    up_dir: str,
    face_colors: np.ndarray,
    output_path: Path,
    resolution: int,
) -> None:
    """Rasterize one flat-shaded preview of ``mesh`` with Pillow and save it.

    The mesh is passed through ``reorient_mesh_to_z_up`` and
    ``normalize_mesh``, projected without perspective division onto the
    preview camera's right/up axes, and drawn back to front as filled
    triangles (painter's algorithm).

    Args:
        mesh: Source mesh.
        up_dir: Up-direction label forwarded to ``reorient_mesh_to_z_up``.
        face_colors: Per-face RGB or RGBA colors on a 0-255 scale; only the
            first ``len(faces)`` rows are used.
        output_path: Image file to write; parent directories are created.
        resolution: Width and height of the square output image in pixels.

    Raises:
        ValueError: If the processed mesh has no vertices or no faces.

    Environment:
        ``UPRIGHT_PREVIEW_MIN_TRIANGLE_AREA`` (default ``0.25``) is the
        threshold applied to the screen-space cross-product magnitude of each
        triangle. ``UPRIGHT_PREVIEW_CULL_BACKFACES`` disables back-face
        culling when set to ``0``, ``false`` or ``no``.
    """
    from PIL import Image, ImageDraw
    reoriented_mesh, _ = reorient_mesh_to_z_up(mesh, up_dir)
    normalized_mesh, _, _ = normalize_mesh(reoriented_mesh)
    vertices = np.asarray(normalized_mesh.vertices, dtype=np.float32)
    faces = np.asarray(normalized_mesh.faces, dtype=np.int64)
    if vertices.size == 0 or faces.size == 0:
        raise ValueError("Cannot render an empty mesh preview.")
    camera, right, up, forward = _camera_basis_for_preview()
    # View-space coordinates: x/y along the camera axes, z is the distance
    # from the camera along its viewing direction (used only for sorting).
    view_x = vertices @ right
    view_y = vertices @ up
    view_z = (vertices - camera) @ forward
    # Fit the projected bounding box into 84% of the image, centered.
    xy = np.stack([view_x, view_y], axis=1)
    xy_min = xy.min(axis=0)
    xy_max = xy.max(axis=0)
    extent = np.maximum(xy_max - xy_min, 1e-6)
    scale = float(resolution) * 0.84 / float(np.max(extent))
    center = (xy_min + xy_max) * 0.5
    screen = np.empty((vertices.shape[0], 2), dtype=np.float32)
    screen[:, 0] = (view_x - center[0]) * scale + float(resolution) * 0.5
    # Image rows grow downwards, so the vertical axis is flipped.
    screen[:, 1] = float(resolution) * 0.5 - (view_y - center[1]) * scale
    # Unit face normals; degenerate triangles get a zero normal.
    tri = vertices[faces]
    normals = np.cross(tri[:, 1] - tri[:, 0], tri[:, 2] - tri[:, 0])
    normal_lengths = np.linalg.norm(normals, axis=1, keepdims=True)
    normals = np.divide(normals, np.maximum(normal_lengths, 1e-8), out=np.zeros_like(normals))
    # Flat shading: a constant 0.64 term plus up to 0.36 from a fixed light.
    light_dir = np.asarray([-0.35, -0.45, 0.82], dtype=np.float32)
    light_dir /= np.linalg.norm(light_dir)
    shade = 0.64 + 0.36 * np.maximum(normals @ light_dir, 0.0)
    colors = face_colors[: len(faces)].astype(np.float32, copy=False)
    if colors.shape[1] == 3:
        alpha = np.ones((colors.shape[0], 1), dtype=np.float32) * 255.0
        colors = np.concatenate([colors, alpha], axis=1)
    rgb = np.clip(colors[:, :3] * shade[:, None], 0, 255)
    # Alpha is blended against the background once here; the polygons
    # themselves are drawn opaque.
    alpha = np.clip(colors[:, 3:4] / 255.0, 0.0, 1.0)
    background = np.asarray([248.0, 250.0, 252.0], dtype=np.float32)
    rgb = rgb * alpha + background[None, :] * (1.0 - alpha)
    image = Image.new("RGB", (int(resolution), int(resolution)), tuple(background.astype(np.uint8)))
    draw = ImageDraw.Draw(image)
    screen_faces = screen[faces]
    face_depth = view_z[faces].mean(axis=1)
    x0 = screen_faces[:, 0, 0]
    y0 = screen_faces[:, 0, 1]
    x1 = screen_faces[:, 1, 0]
    y1 = screen_faces[:, 1, 1]
    x2 = screen_faces[:, 2, 0]
    y2 = screen_faces[:, 2, 1]
    # Magnitude of the 2D cross product, i.e. twice the triangle's area.
    area = np.abs((x1 - x0) * (y2 - y0) - (x2 - x0) * (y1 - y0))
    image_limit = float(resolution + 2)
    # Keep triangles that overlap the image (with a 2px margin) and whose
    # area measure reaches the threshold.
    valid = (
        (screen_faces[:, :, 0].max(axis=1) >= -2.0)
        & (screen_faces[:, :, 0].min(axis=1) <= image_limit)
        & (screen_faces[:, :, 1].max(axis=1) >= -2.0)
        & (screen_faces[:, :, 1].min(axis=1) <= image_limit)
        & (area >= float(os.environ.get("UPRIGHT_PREVIEW_MIN_TRIANGLE_AREA", "0.25")))
    )
    if os.environ.get("UPRIGHT_PREVIEW_CULL_BACKFACES", "1").strip().lower() not in {"0", "false", "no"}:
        # Drop faces whose normal points along the viewing direction; the
        # 0.03 slack keeps near edge-on faces.
        valid &= (normals @ forward) < 0.03
    valid_face_ids = np.flatnonzero(valid)
    # Painter's algorithm: farthest faces first so nearer ones overdraw them.
    order = valid_face_ids[np.argsort(face_depth[valid_face_ids])[::-1]]
    for face_id in order:
        pts = screen_faces[face_id]
        x0, y0 = pts[0]
        x1, y1 = pts[1]
        x2, y2 = pts[2]
        fill = tuple(np.rint(rgb[face_id]).astype(np.uint8).tolist())
        draw.polygon(
            [(float(x0), float(y0)), (float(x1), float(y1)), (float(x2), float(y2))],
            fill=fill,
        )
    output_path.parent.mkdir(parents=True, exist_ok=True)
    image.save(output_path)
def _render_up_direction_previews_software(
    *,
    mesh: Any,
    output_dir: Path,
) -> list[tuple[str, str]]:
    """Render one preview image per up-direction choice into ``output_dir``.

    The image size comes from ``UPRIGHT_PREVIEW_RESOLUTION`` (default 288).
    Face colors are computed once and shared by all views.

    Returns:
        ``(image path, caption)`` pairs in ``UP_DIR_CHOICES`` order.
    """
    resolution = int(os.environ.get("UPRIGHT_PREVIEW_RESOLUTION", "288"))
    output_dir.mkdir(parents=True, exist_ok=True)
    face_colors = _preview_face_colors(mesh)

    def _render_one(up_dir: str) -> tuple[str, str]:
        image_path = output_dir / f"up_{_up_dir_slug(up_dir)}.png"
        _render_up_direction_preview_software(
            mesh,
            up_dir=up_dir,
            face_colors=face_colors,
            output_path=image_path,
            resolution=resolution,
        )
        return str(image_path), f"{up_dir} up"

    return [_render_one(up_dir) for up_dir in UP_DIR_CHOICES]
def _render_up_direction_previews(
    *,
    mesh_path: Path,
    mesh: Any,
    mesh_hash: str | None = None,
) -> list[tuple[str, str]]:
    """Return upright preview gallery items, reusing cached images if possible.

    With a ``mesh_hash`` the previews live in (and are served from) the
    per-mesh cache directory. Without one they are rendered into a fresh
    timestamped directory under ``OUTPUT_ROOT / "up_direction_previews"``.
    """
    if mesh_hash is None:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        target_dir = OUTPUT_ROOT / "up_direction_previews" / f"{mesh_path.stem}_{stamp}"
    else:
        cached = _cached_upright_preview_items(mesh_hash)
        if cached is not None:
            return cached
        target_dir = _upright_preview_cache_dir(mesh_hash)
    return _render_up_direction_previews_software(mesh=mesh, output_dir=target_dir)
def _load_point_prompt_payload(raw_value: str | None) -> dict[str, Any]:
if raw_value is None or not str(raw_value).strip():
return {"prompts": []}
try:
payload = json.loads(str(raw_value))
except json.JSONDecodeError as exc:
raise ValueError(f"Point prompt JSON must be valid JSON: {exc}") from exc
if isinstance(payload, list):
return {"prompts": payload}
if not isinstance(payload, dict):
raise ValueError("Point prompt JSON must be an object or a list.")
return payload
def _parse_point_prompt_arrays(
    raw_value: str | None,
    *,
    num_links: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray] | None:
    """Convert point-prompt JSON into dense per-link arrays.

    Args:
        raw_value: Point-prompt JSON text (may be ``None`` or blank).
        num_links: Number of links; prompt link IDs must be below it.

    Returns:
        ``(points, normals, has_prompt)`` with shapes ``(num_links, 3)``,
        ``(num_links, 3)`` and ``(num_links,)``; normals are unit length and
        rows without a prompt stay zero. ``None`` when no prompt is present.

    Raises:
        ValueError: If the payload or any individual prompt is malformed.
    """
    payload = _load_point_prompt_payload(raw_value)
    raw_prompts = payload.get("prompts", [])
    raw_prompts = [] if raw_prompts is None else raw_prompts
    if not isinstance(raw_prompts, list):
        raise ValueError("'prompts' must be a list.")

    points = np.zeros((num_links, 3), dtype=np.float32)
    normals = np.zeros((num_links, 3), dtype=np.float32)
    has_prompt = np.zeros((num_links,), dtype=np.bool_)

    for entry in raw_prompts:
        if not isinstance(entry, dict):
            raise ValueError("Each point prompt must be an object.")
        link_id = int(entry.get("link_id", entry.get("link", -1)))
        if link_id < 0 or link_id >= num_links:
            raise ValueError(f"Point prompt link_id {link_id} is outside [0, {num_links - 1}].")

        position = np.asarray(entry.get("point"), dtype=np.float32)
        direction = np.asarray(entry.get("normal"), dtype=np.float32)
        if position.shape != (3,):
            raise ValueError(f"Point prompt for link {link_id} must have a 3D point.")
        if direction.shape != (3,):
            raise ValueError(f"Point prompt for link {link_id} must have a 3D normal.")
        length = float(np.linalg.norm(direction))
        if length <= 1e-8:
            raise ValueError(f"Point prompt for link {link_id} has a zero normal.")

        # A later prompt for the same link overwrites an earlier one.
        points[link_id] = position
        normals[link_id] = direction / np.float32(length)
        has_prompt[link_id] = True

    if has_prompt.any():
        return points, normals, has_prompt
    return None
def _normalize_point_prompt_arrays(
*,
points: np.ndarray,
normals: np.ndarray,
mesh_geometry: Any,
) -> tuple[np.ndarray, np.ndarray]:
# Point prompt JSON stores raw upload-space coordinates; inference uses the
# same upright transform as the input mesh before normalization.
rotation = np.asarray(mesh_geometry.up_dir_rotation, dtype=np.float32)
rotated_points = np.asarray(points, dtype=np.float32) @ rotation.T
rotated_normals = np.asarray(normals, dtype=np.float32) @ rotation.T
normal_lengths = np.linalg.norm(rotated_normals, axis=1, keepdims=True)
rotated_normals = np.divide(
rotated_normals,
np.maximum(normal_lengths, np.float32(1e-8)),
out=np.zeros_like(rotated_normals),
)
normalized_points = (
(rotated_points - np.asarray(mesh_geometry.center, dtype=np.float32))
* np.float32(mesh_geometry.scale)
)
return (
normalized_points.astype(np.float32, copy=False),
rotated_normals.astype(np.float32, copy=False),
)
def _duplicate_link_prompt_warning(
link_names: list[str],
point_prompt_arrays: tuple[np.ndarray, np.ndarray, np.ndarray] | None,
) -> str | None:
has_prompt = (
np.zeros((len(link_names),), dtype=np.bool_)
if point_prompt_arrays is None
else np.asarray(point_prompt_arrays[2], dtype=np.bool_)
)
name_to_link_ids: dict[str, list[int]] = {}
for link_id, link_name in enumerate(link_names):
normalized_name = str(link_name).strip()
name_to_link_ids.setdefault(normalized_name, []).append(int(link_id))
ambiguous_groups: list[str] = []
for link_name, link_ids in name_to_link_ids.items():
if len(link_ids) <= 1:
continue
missing_prompt_ids = [link_id for link_id in link_ids if not bool(has_prompt[link_id])]
if missing_prompt_ids:
ambiguous_groups.append(
f"{link_name!r} links {link_ids} missing prompts for {missing_prompt_ids}"
)
if not ambiguous_groups:
return None
return (
"Duplicate link names need point prompts to disambiguate them. "
"Rename the duplicate links or add point prompts for every link in each duplicate-name group: "
+ "; ".join(ambiguous_groups)
)
def _auto_tree_from_parsed_response(parsed_response: dict[str, Any]) -> dict[str, Any]:
return {
"links": [
{
"id": int(link["link_id"]),
"name": str(link["name"]).strip() or f"link_{int(link['link_id'])}",
}
for link in parsed_response["links"]
],
"joints": [
{
"parent": int(joint["parent_link_id"]),
"child": int(joint["child_link_id"]),
"type": str(joint["joint_type"]).strip().lower(),
}
for joint in parsed_response["joints"]
],
}
def _point_prompt_json_from_normalized_prompts(
*,
normalized_points: np.ndarray,
normalized_normals: np.ndarray,
mesh_geometry: Any,
) -> str:
# Auto-kinematics lifting produces normalized model-space prompts. Store
# them back in raw upload-space coordinates so cached prompts are reusable
# across later upright-orientation choices.
rotation = np.asarray(mesh_geometry.up_dir_rotation, dtype=np.float32)
points = np.asarray(normalized_points, dtype=np.float32)
normals = np.asarray(normalized_normals, dtype=np.float32)
model_points = points / np.float32(mesh_geometry.scale) + np.asarray(
mesh_geometry.center,
dtype=np.float32,
)
raw_points = model_points @ rotation
raw_normals = normals @ rotation
normal_lengths = np.linalg.norm(raw_normals, axis=1, keepdims=True)
raw_normals = np.divide(
raw_normals,
np.maximum(normal_lengths, np.float32(1e-8)),
out=np.zeros_like(raw_normals),
)
prompts = [
{
"link_id": int(link_id),
"point": np.round(raw_points[link_id], 6).astype(float).tolist(),
"normal": np.round(raw_normals[link_id], 6).astype(float).tolist(),
}
for link_id in range(int(points.shape[0]))
]
return json.dumps({"prompts": prompts}, indent=2)
def _zip_directory(directory: Path) -> Path:
zip_path = directory.with_suffix(".zip")
if zip_path.exists():
zip_path.unlink()
with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zip_file:
for path in sorted(directory.rglob("*")):
if path.is_file():
zip_file.write(path, path.relative_to(directory))
return zip_path
def _to_cpu_payload(value: Any) -> Any:
if isinstance(value, torch.Tensor):
return value.detach().cpu()
if isinstance(value, dict):
return {key: _to_cpu_payload(item) for key, item in value.items()}
if isinstance(value, list):
return [_to_cpu_payload(item) for item in value]
if isinstance(value, tuple):
return tuple(_to_cpu_payload(item) for item in value)
return value
def _to_device_payload(value: Any, device: torch.device) -> Any:
if isinstance(value, torch.Tensor):
return value.to(device)
if isinstance(value, dict):
return {key: _to_device_payload(item, device) for key, item in value.items()}
if isinstance(value, list):
return [_to_device_payload(item, device) for item in value]
if isinstance(value, tuple):
return tuple(_to_device_payload(item, device) for item in value)
return value
def _ensure_instruct_checkpoint(checkpoint_path: Path) -> Path:
if checkpoint_path.exists():
return checkpoint_path
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
downloaded_path = hf_hub_download(
repo_id=CHECKPOINT_REPO_ID,
filename=CHECKPOINT_REPO_FILENAME,
local_dir=str(checkpoint_path.parent),
local_dir_use_symlinks=False,
token=os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN"),
)
downloaded_path = Path(downloaded_path)
if downloaded_path != checkpoint_path:
shutil.copy2(downloaded_path, checkpoint_path)
if not checkpoint_path.exists():
raise FileNotFoundError(f"Could not download checkpoint to {checkpoint_path}")
return checkpoint_path
def _hf_token() -> str | None:
return os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
def _prefetch_clip_text_assets(config: dict[str, Any]) -> None:
model_config = config.get("model", {})
if not isinstance(model_config, dict) or not bool(model_config.get("use_text_conditioning", True)):
return
clip_model_name = str(model_config.get("clip_model_name", "openai/clip-vit-large-patch14"))
print(f"Prefetching CLIP text assets: {clip_model_name}")
snapshot_download(
repo_id=clip_model_name,
cache_dir=os.environ.get("HF_HOME") or None,
allow_patterns=[
"config.json",
"tokenizer_config.json",
"vocab.json",
"merges.txt",
"tokenizer.json",
"special_tokens_map.json",
"model.safetensors",
],
token=_hf_token(),
)
def _prefetch_partfield_assets(config: dict[str, Any]) -> None:
model_config = config.get("model", {})
if not isinstance(model_config, dict):
return
needs_partfield = any(
bool(model_config.get(key, False))
for key in (
"use_pretrained_features_shape",
"use_pretrained_features_query",
"use_pretrained_features_point_prompt",
)
)
if not needs_partfield:
return
print("Prefetching PartField checkpoint assets")
ensure_partfield_assets_downloaded()
def _prefetch_startup_assets(config: dict[str, Any]) -> None:
if os.environ.get("INSTRUCT_PARTICULATE_PREFETCH_ASSETS", "1").strip().lower() in {
"0",
"false",
"no",
}:
print("Skipping startup asset prefetch because INSTRUCT_PARTICULATE_PREFETCH_ASSETS is disabled")
return
_prefetch_partfield_assets(config)
_prefetch_clip_text_assets(config)
def _preload_weights_enabled() -> bool:
return os.environ.get("INSTRUCT_PARTICULATE_PRELOAD_WEIGHTS", "1").strip().lower() not in {
"0",
"false",
"no",
}
def _spaces_gpu(fn):
duration = max(1, min(int(os.environ.get("SPACES_GPU_DURATION", "20")), 20))
return spaces.GPU(duration=duration)(fn)
def _mesh_face_warning_update_from_face_count(face_count: int):
if face_count <= HIGH_FACE_COUNT_WARNING_THRESHOLD:
return ""
return (
'
'
"Large mesh warning: "
f"this mesh has {face_count:,} faces. "
"That is too dense for this Space and can exhaust CPU resources. "
"Please upload a simplified mesh with fewer than "
f"{HIGH_FACE_COUNT_WARNING_THRESHOLD:,} faces for the best results; "
"this upload will still proceed.
Drag colored link rectangles to arrange the tree. Add links with Add Link. To add a joint, press Add Joint, then click the parent link followed by the child link. New joints start as revolute; use each joint's dropdown to change it to prismatic. Link names are editable inside each rectangle.
Upload one object mesh and provide or extract a kinematic tree. The demo predicts part segmentation and joint parameters, then exports animated GLB and URDF assets.
How to use the demo:
Upload a mesh: use the Input Mesh panel or select an example. You can also generate an object with Hunyuan3D, using the China site or the international site, then bring the mesh here to make it interactive. For efficient processing, select Hunyuan3D's 50k-face generation option.
Choose the upright orientation: click the preview where the object is upright. This orientation is used for both auto-kinematic extraction and inference.
Define the kinematic tree: edit links and joints manually, or click Extract Kinematic Structure to infer a starting tree and point prompts.
Add optional point prompts: select a link in the Kinematic Tree Editor, then click the mesh in the Point Prompt Picker to mark a representative point for that link.
Run inference: keep Connected Component Postprocessing on when the mesh has clean connected components, then click Run Inference.
Review and export: inspect the point query visualization, articulated model, and predicted parts and axes, then export the URDF when the output is ready.
Meshes above 100,000 faces may be slow or fail on the Space CPU; simplified meshes are recommended for reliable runs.
"""
)
best_practice_banner()
mesh_face_warning = gr.HTML(
value="",
visible=True,
elem_id="mesh-face-warning",
elem_classes=["mesh-face-warning"],
container=False,
)
loaded_mesh_path = gr.State(None)
loaded_mesh_hash = gr.State(None)
selected_up_dir = gr.Textbox(
value="",
label="Selected Upright Direction",
elem_id="selected_up_dir",
elem_classes=["kinematic-json-sync"],
)
latest_output_dir = gr.State(None)
inference_payload = gr.State(None)
example_mesh_index = gr.Textbox(
value="",
label="Example Mesh Index",
elem_id="example_mesh_index",
elem_classes=["kinematic-json-sync"],
)
with gr.Row(equal_height=True, elem_classes=["demo-row", "demo-top-row"]):
with gr.Column(scale=1, min_width=300, elem_classes=["demo-panel", "mesh-panel"]):
input_mesh = gr.Model3D(label="Input Mesh", interactive=True, height=300)
examples = _example_meshes()
if examples:
with gr.Column(elem_classes=["mesh-examples"]):
gr.Markdown("Examples")
gr.HTML(
_example_mesh_grid_html(examples),
container=False,
padding=False,
)
with gr.Column(scale=1, min_width=300, elem_classes=["demo-panel", "kin-panel"]):
extract_button = gr.Button("Extract Kinematic Structure")
auto_status = gr.Textbox(
label="Kinematic Extraction Status",
interactive=False,
elem_classes=["kin-extraction-status"],
)
gr.HTML(
_kinematic_tree_editor_html(),
elem_id="kinematic_tree_editor_html",
elem_classes=["kinematic-editor-host"],
min_height=500,
container=False,
padding=False,
)
with gr.Column(scale=1, min_width=300, elem_classes=["demo-panel", "prompt-panel"]):
gr.HTML(
_point_prompt_picker_html(),
elem_classes=["point-prompt-picker-host"],
container=False,
padding=False,
)
num_query_points = gr.State(51200)
per_face_queries = gr.State(3)
query_batch_size = gr.State(51200)
animation_frames = gr.State(50)
connectivity = gr.State(True)
confidence_temperature = gr.State(1.0)
with gr.Column(elem_classes=["inference-params-panel", "inference-params-static"]):
strict = gr.Checkbox(
label="Connected Component Postprocessing",
value=True,
elem_classes=["toggle-switch"],
)
gr.Markdown(
"Enable this when the mesh has clean connected components. "
"Each component will stay intact and be assigned to one predicted part; "
"a part may still merge multiple components.",
elem_classes=["inference-params-help"],
)
with gr.Row(elem_classes=["prompt-run-row"]):
run_button = gr.Button("Run Inference", variant="primary")
kinematic_tree = gr.Textbox(
label="Kinematic Tree JSON Sync",
value=_tree_to_pretty_json(DEFAULT_KINEMATIC_TREE),
lines=8,
elem_id="kinematic_tree_json",
elem_classes=["kinematic-json-sync"],
)
point_prompt_mesh_data = gr.Textbox(
label="Point Prompt Mesh Data",
value="",
lines=1,
elem_id="point_prompt_mesh_data",
elem_classes=["kinematic-json-sync"],
)
point_prompts = gr.Textbox(
label="Point Prompt JSON Sync",
value='{"prompts":[]}',
lines=4,
elem_id="point_prompt_json",
elem_classes=["kinematic-json-sync"],
)
with gr.Row(equal_height=True, elem_classes=["demo-row", "demo-bottom-row"]):
with gr.Column(scale=1, min_width=300, elem_classes=["demo-panel", "orientation-panel"]):
gr.HTML(
"""
Upright Orientation
After uploading or selecting a mesh, click the preview where the object is upright. That choice is used for both kinematic extraction and inference.