"""Phase 4: Scene Assembly Module.

Optimizes room layout, resolves collisions, normalizes scale, and builds the
editable scene graph representation.
"""

import warnings
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class SceneAssemblyModule(nn.Module):
    """Assemble individual objects into a coherent room scene.

    Meshes are used through a small duck-typed surface (``copy``,
    ``vertices``, ``bounds``, ``centroid``, ``apply_translation``,
    ``apply_scale``); the ``trimesh`` package itself is only imported when
    meshes are merged in ``_merge_scene``. The vertical axis is y throughout.
    """

    def __init__(
        self,
        device: str = "cuda",
        dtype: torch.dtype = torch.float16,
    ):
        """Store device/dtype settings and the furniture size priors.

        Args:
            device: Stored on the instance; not used by the methods in this
                class (presumably consumed elsewhere -- confirm).
            dtype: Stored on the instance; likewise unused here.
        """
        super().__init__()
        self.device = device
        self.dtype = dtype

        # Furniture dimension priors (meters) for scale normalization.
        # "width" maps to x, "height" to y, "depth" to z.
        self.furniture_priors = {
            "sofa": {"width": 2.0, "depth": 0.9, "height": 0.8},
            "chair": {"width": 0.5, "depth": 0.5, "height": 0.9},
            "table": {"width": 1.2, "depth": 0.8, "height": 0.75},
            "coffee_table": {"width": 1.0, "depth": 0.6, "height": 0.45},
            "bed": {"width": 2.0, "depth": 1.5, "height": 0.5},
            "desk": {"width": 1.4, "depth": 0.7, "height": 0.75},
            "bookshelf": {"width": 1.0, "depth": 0.3, "height": 2.0},
            "lamp": {"width": 0.3, "depth": 0.3, "height": 1.5},
            "wardrobe": {"width": 1.5, "depth": 0.6, "height": 2.1},
            "tv_stand": {"width": 1.2, "depth": 0.4, "height": 0.5},
            "rug": {"width": 2.0, "depth": 1.5, "height": 0.02},
            "plant": {"width": 0.3, "depth": 0.3, "height": 1.0},
            "furniture": {"width": 0.8, "depth": 0.8, "height": 0.8},  # default
        }

    @staticmethod
    def _is_empty(mesh) -> bool:
        """Return True when *mesh* is None, lacks vertices, or has zero vertices."""
        vertices = getattr(mesh, "vertices", None)
        return vertices is None or len(vertices) == 0

    def assemble(
        self,
        room_shell_mesh: "trimesh.Trimesh",  # type: ignore
        object_meshes: List["trimesh.Trimesh"],  # type: ignore
        room_layout: Dict,
        detected_objects: Dict,
        depth_map: np.ndarray,
    ) -> Dict:
        """Assemble the room scene from individual components.

        Steps:
            1. Place objects at detected positions
            2. Normalize scales using furniture priors
            3. Ensure objects rest on floor
            4. Resolve collisions
            5. Build scene graph
            6. Merge into unified mesh

        Args:
            room_shell_mesh: Mesh of the room shell (walls/floor).
            object_meshes: One mesh per object; the list index is used to look
                up that object's entry in ``detected_objects``.
            room_layout: Layout dict; only ``room_layout["floor"]["height"]``
                is read (default 0.0).
            detected_objects: Mapping from object index to detection info
                (``bbox``, ``depth_range``, ``class_name`` keys are read).
            depth_map: Depth image; its height/width define the image size
                used to project detections into 3D.

        Returns:
            Dict with ``scene_mesh`` (merged mesh), ``object_meshes`` (final
            per-object meshes) and ``scene_graph`` (nodes/edges dict).
        """
        # Step 1: Initial placement from detected positions
        placed_objects = self._place_objects(
            object_meshes, detected_objects, room_layout, depth_map=depth_map
        )

        # Step 2: Scale normalization
        normalized_objects = self._normalize_scales(
            placed_objects, detected_objects, depth_map
        )

        # Step 3: Gravity constraint (objects on floor)
        grounded_objects = self._apply_gravity(normalized_objects, room_layout)

        # Step 4: Collision detection and resolution
        resolved_objects = self._resolve_collisions(grounded_objects, room_layout)

        # Step 5: Build scene graph
        scene_graph = self._build_scene_graph(
            resolved_objects, room_layout, detected_objects
        )

        # Step 6: Merge into unified mesh
        scene_mesh = self._merge_scene(room_shell_mesh, resolved_objects)

        return {
            "scene_mesh": scene_mesh,
            "object_meshes": resolved_objects,
            "scene_graph": scene_graph,
        }

    def _place_objects(
        self,
        object_meshes: List["trimesh.Trimesh"],  # type: ignore
        detected_objects: Dict,
        room_layout: Dict,
        depth_map: Optional[np.ndarray] = None,
    ) -> List["trimesh.Trimesh"]:  # type: ignore
        """Place objects at their detected positions in 3D space.

        For each mesh whose index appears in ``detected_objects``, the mesh is
        re-centered horizontally (x/z) on its centroid and translated to:

        * x: the bbox's horizontal center back-projected with a pinhole model
          (focal length ``max(width, height)``, principal point at the image
          center) at the mean of ``depth_range``;
        * y: the floor height (the exact vertical offset is fixed later by
          ``_apply_gravity``);
        * z: the mean of ``depth_range``.

        Args:
            object_meshes: Meshes to place; each is copied, never mutated.
            detected_objects: Index -> detection info; ``bbox`` is read as
                ``[x1, y1, x2, y2]`` in pixels (default ``[0, 0, 100, 100]``)
                and ``depth_range`` as ``[near, far]`` (default ``[1.0, 3.0]``).
            room_layout: Only ``room_layout["floor"]["height"]`` is read.
            depth_map: Optional depth image whose first two dimensions give
                the image height/width; a 512x512 image is assumed when None.

        Returns:
            A list of placed mesh copies, in input order. Meshes without a
            detection entry, or without vertices, are copied unchanged.
        """
        floor_height = room_layout.get("floor", {}).get("height", 0.0)

        # The image resolution defines the pinhole intrinsics below.
        if depth_map is not None:
            img_h, img_w = depth_map.shape[:2]
        else:
            img_h, img_w = 512, 512
        fx = max(img_w, img_h)
        cx_cam = img_w / 2

        placed = []
        for i, mesh in enumerate(object_meshes):
            mesh_copy = mesh.copy()
            if i not in detected_objects or self._is_empty(mesh_copy):
                placed.append(mesh_copy)
                continue

            obj_info = detected_objects[i]
            x1, _y1, x2, _y2 = obj_info.get("bbox", [0, 0, 100, 100])
            mean_depth = np.mean(obj_info.get("depth_range", [1.0, 3.0]))

            # Camera at the origin; depth is used directly as +z.
            # NOTE(review): an earlier comment claimed the camera looks down
            # -z, which disagrees with z = +depth -- confirm the convention.
            x_3d = ((x1 + x2) / 2 - cx_cam) * mean_depth / fx
            z_3d = mean_depth

            # Center the mesh horizontally, keeping its vertical offset.
            if hasattr(mesh_copy, "centroid"):
                centroid = mesh_copy.centroid
            else:
                centroid = mesh_copy.bounds.mean(axis=0)
            mesh_copy.apply_translation([-centroid[0], 0, -centroid[2]])

            # Move to detected position.
            mesh_copy.apply_translation([x_3d, floor_height, z_3d])
            placed.append(mesh_copy)

        return placed

    def _normalize_scales(
        self,
        object_meshes: List["trimesh.Trimesh"],  # type: ignore
        detected_objects: Dict,
        depth_map: np.ndarray,
    ) -> List["trimesh.Trimesh"]:  # type: ignore
        """Uniformly rescale each object so its largest extent matches its prior.

        The scale factor is ``max(prior dims) / max(current dims)``, clamped to
        ``[0.1, 3.0]``, and is applied about the mesh's bounding-box center so
        the placement computed earlier is preserved.

        Args:
            object_meshes: Meshes to rescale; each is copied, never mutated.
            detected_objects: Index -> detection info; ``class_name`` selects
                the prior (unknown classes use the ``"furniture"`` prior).
            depth_map: Accepted for interface compatibility; currently unused.

        Returns:
            A list of rescaled mesh copies, in input order. Meshes that are
            empty or whose largest extent is <= 0.001 are copied unscaled.
        """
        normalized = []
        for i, mesh in enumerate(object_meshes):
            mesh_copy = mesh.copy()
            if self._is_empty(mesh_copy):
                normalized.append(mesh_copy)
                continue

            class_name = "furniture"
            if i in detected_objects:
                class_name = detected_objects[i].get("class_name", "furniture")
            prior = self.furniture_priors.get(
                class_name, self.furniture_priors["furniture"]
            )

            lo, hi = np.asarray(mesh_copy.bounds)
            max_current = float(np.max(hi - lo))
            if max_current > 0.001:  # Avoid division by zero
                max_prior = max(prior["width"], prior["depth"], prior["height"])
                scale_factor = float(np.clip(max_prior / max_current, 0.1, 3.0))

                # apply_scale scales about the origin, so scale about the
                # bbox center instead; otherwise the object would also be
                # moved toward or away from the origin.
                center = (lo + hi) / 2.0
                mesh_copy.apply_translation(-center)
                mesh_copy.apply_scale(scale_factor)
                mesh_copy.apply_translation(center)

            normalized.append(mesh_copy)

        return normalized

    def _apply_gravity(
        self,
        object_meshes: List["trimesh.Trimesh"],  # type: ignore
        room_layout: Dict,
    ) -> List["trimesh.Trimesh"]:  # type: ignore
        """Translate each object vertically so its lowest vertex is on the floor.

        Args:
            object_meshes: Meshes to ground; each is copied, never mutated.
            room_layout: Only ``room_layout["floor"]["height"]`` is read
                (default 0.0).

        Returns:
            A list of grounded mesh copies; empty meshes are copied unchanged.
        """
        floor_height = room_layout.get("floor", {}).get("height", 0.0)
        grounded = []
        for mesh in object_meshes:
            mesh_copy = mesh.copy()
            if not self._is_empty(mesh_copy):
                min_y = mesh_copy.vertices[:, 1].min()
                mesh_copy.apply_translation([0, floor_height - min_y, 0])
            grounded.append(mesh_copy)
        return grounded

    def _resolve_collisions(
        self,
        object_meshes: List["trimesh.Trimesh"],  # type: ignore
        room_layout: Dict,
    ) -> List["trimesh.Trimesh"]:  # type: ignore
        """Separate objects whose axis-aligned bounding boxes overlap.

        Each overlapping pair is pushed apart along whichever horizontal axis
        (x or z) has the smaller overlap. Each object moves half the overlap
        plus a 0.05 margin, away from the other object's center. Vertical
        pushes are not used, so the floor contact established by
        ``_apply_gravity`` is preserved. Up to 50 passes are made; the loop
        stops early once a pass finds no overlaps.

        Args:
            object_meshes: Meshes to separate; each is copied, never mutated.
            room_layout: Accepted for interface compatibility; currently unused.

        Returns:
            A list of (possibly translated) mesh copies, in input order. Empty
            meshes are copied and take no part in collision tests.
        """
        resolved = [mesh.copy() for mesh in object_meshes]
        solid = [k for k, mesh in enumerate(resolved) if not self._is_empty(mesh)]
        max_iterations = 50
        margin = 0.05

        for _ in range(max_iterations):
            collisions_found = False
            for pos, i in enumerate(solid):
                for j in solid[pos + 1:]:
                    lo_i, hi_i = np.asarray(resolved[i].bounds)
                    lo_j, hi_j = np.asarray(resolved[j].bounds)

                    # Strict AABB overlap on all three axes.
                    if not (np.all(lo_i < hi_j) and np.all(hi_i > lo_j)):
                        continue
                    collisions_found = True

                    overlaps = np.minimum(hi_i - lo_j, hi_j - lo_i)
                    # Smaller horizontal overlap; ties go to x.
                    axis = 0 if overlaps[0] <= overlaps[2] else 2

                    push = np.zeros(3)
                    push[axis] = overlaps[axis] * 0.5 + margin

                    center_i = (lo_i + hi_i) / 2.0
                    center_j = (lo_j + hi_j) / 2.0
                    if center_i[axis] < center_j[axis]:
                        resolved[i].apply_translation(-push)
                        resolved[j].apply_translation(push)
                    else:
                        resolved[i].apply_translation(push)
                        resolved[j].apply_translation(-push)

            if not collisions_found:
                break

        return resolved

    def _build_scene_graph(
        self,
        object_meshes: List["trimesh.Trimesh"],  # type: ignore
        room_layout: Dict,
        detected_objects: Dict,
    ) -> Dict:
        """Build an editable scene graph from the assembled objects.

        Nodes: a ``"room_shell"`` node plus one node per object. An object
        node has ``id``/``mesh_index`` equal to its list index, its class
        label, and its bbox center and extents. Both position and dimensions
        are ``None`` for meshes without vertices.

        Edges:
            * every object -> ``"room_shell"`` with relation ``"in"``;
            * for each pair of non-empty objects whose centers are closer than
              2.0: ``"next_to"`` if their center heights differ by < 0.1,
              ``"on"`` from the higher to the lower object if they differ by
              > 0.2, otherwise ``"near"``. Pairwise edges also carry
              ``distance``.

        Args:
            object_meshes: Final per-object meshes.
            room_layout: Accepted for interface compatibility; currently unused.
            detected_objects: Index -> detection info; ``class_name`` is read.

        Returns:
            ``{"nodes": [...], "edges": [...]}``.
        """
        nodes = [{
            "id": "room_shell",
            "type": "room",
            "label": "room",
            "bbox": None,
        }]
        edges = []
        centers = {}

        # Object nodes
        for i, mesh in enumerate(object_meshes):
            class_name = "furniture"
            if i in detected_objects:
                class_name = detected_objects[i].get("class_name", "furniture")

            position = dimensions = None
            if not self._is_empty(mesh):
                bounds = np.asarray(mesh.bounds)
                centers[i] = bounds.mean(axis=0)
                position = centers[i].tolist()
                dimensions = (bounds[1] - bounds[0]).tolist()

            nodes.append({
                "id": i,
                "type": "object",
                "label": class_name,
                "position": position,
                "dimensions": dimensions,
                "mesh_index": i,
            })
            # Edge: object is IN room
            edges.append({"from": i, "to": "room_shell", "relation": "in"})

        # Infer spatial relationships between objects
        indices = sorted(centers)
        for pos, i in enumerate(indices):
            for j in indices[pos + 1:]:
                center_i, center_j = centers[i], centers[j]
                dist = float(np.linalg.norm(center_i - center_j))
                if dist >= 2.0:  # Proximity threshold
                    continue

                dy = center_i[1] - center_j[1]
                source, target = i, j
                if abs(dy) < 0.1:
                    relation = "next_to"
                elif dy > 0.2:
                    relation = "on"
                elif dy < -0.2:
                    # j is the upper object: express the relation as j on i.
                    relation = "on"
                    source, target = j, i
                else:
                    relation = "near"

                edges.append({
                    "from": source,
                    "to": target,
                    "relation": relation,
                    "distance": dist,
                })

        return {
            "nodes": nodes,
            "edges": edges,
        }

    def _merge_scene(
        self,
        room_shell_mesh: "trimesh.Trimesh",  # type: ignore
        object_meshes: List["trimesh.Trimesh"],  # type: ignore
    ) -> "trimesh.Trimesh":  # type: ignore
        """Merge the room shell and objects into a single mesh.

        Meshes that are None or have no vertices are skipped. The method tries
        ``trimesh.util.concatenate`` first. If that raises, it falls back to
        adding meshes one at a time. A ``RuntimeWarning`` is emitted for the
        fallback and for every mesh the fallback cannot add.

        Returns:
            The merged mesh, or an empty ``trimesh.Trimesh`` if nothing is
            left to merge.
        """
        import trimesh

        candidates = [room_shell_mesh, *object_meshes]
        valid_meshes = [m for m in candidates if not self._is_empty(m)]
        if not valid_meshes:
            return trimesh.Trimesh()

        try:
            return trimesh.util.concatenate(valid_meshes)
        except Exception as err:  # best-effort: fall back below
            warnings.warn(
                f"trimesh.util.concatenate failed ({err!r}); "
                "falling back to pairwise merging",
                RuntimeWarning,
                stacklevel=2,
            )

        scene_mesh = valid_meshes[0]
        for index, mesh in enumerate(valid_meshes[1:], start=1):
            try:
                scene_mesh = scene_mesh + mesh
            except Exception as err:  # best-effort: keep what merged so far
                warnings.warn(
                    f"Dropping mesh {index} of the merge list: {err!r}",
                    RuntimeWarning,
                    stacklevel=2,
                )
        return scene_mesh

    def reassemble_with_textures(
        self,
        room_shell_mesh: "trimesh.Trimesh",  # type: ignore
        textured_objects: List["trimesh.Trimesh"],  # type: ignore
        scene_graph: Dict,
    ) -> "trimesh.Trimesh":  # type: ignore
        """Re-merge the room shell with textured object meshes.

        Args:
            room_shell_mesh: Room shell mesh.
            textured_objects: Per-object meshes, assumed to be already placed.
            scene_graph: Accepted for interface compatibility; currently unused.

        Returns:
            The merged mesh produced by ``_merge_scene``.
        """
        return self._merge_scene(room_shell_mesh, textured_objects)