# InteriorFusion: src/interiorfusion/models/scene_assembly.py
# Uploaded by stevee00 (commit dcb20f6, verified).
# Upload message: "Upload src/interiorfusion/models/scene_assembly.py"
"""Phase 4: Scene Assembly Module.
Optimizes room layout, resolves collisions, normalizes scale,
and builds the editable scene graph representation.
"""
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class SceneAssemblyModule(nn.Module):
    """Assemble individual object meshes into a coherent room scene.

    The module holds no learnable parameters. Meshes are handled duck-typed:
    the steps use only ``vertices``, ``bounds``, ``copy()``,
    ``apply_translation()`` and ``apply_scale()`` (plus ``centroid`` when
    present), i.e. the ``trimesh.Trimesh`` API. Every step works on copies,
    so the meshes passed to ``assemble`` are never modified.
    """

    # Extents at or below this size (meters) are treated as degenerate when
    # computing a scale factor, avoiding division by ~zero.
    _MIN_EXTENT = 0.001
    # Uniform scale factors are clamped into this range to reject wild
    # rescaling caused by badly reconstructed meshes.
    _SCALE_RANGE = (0.1, 3.0)
    # Image size (H, W) assumed by _place_objects when no depth map is given.
    _DEFAULT_IMAGE_HW = (512, 512)

    def __init__(
        self,
        device: str = "cuda",
        dtype: torch.dtype = torch.float16,
    ):
        super().__init__()
        # Stored on the instance; not used by any method of this class.
        self.device = device
        self.dtype = dtype
        # Furniture dimension priors (meters) for scale normalization.
        # "furniture" is the fallback entry for unknown classes.
        self.furniture_priors = {
            "sofa": {"width": 2.0, "depth": 0.9, "height": 0.8},
            "chair": {"width": 0.5, "depth": 0.5, "height": 0.9},
            "table": {"width": 1.2, "depth": 0.8, "height": 0.75},
            "coffee_table": {"width": 1.0, "depth": 0.6, "height": 0.45},
            "bed": {"width": 2.0, "depth": 1.5, "height": 0.5},
            "desk": {"width": 1.4, "depth": 0.7, "height": 0.75},
            "bookshelf": {"width": 1.0, "depth": 0.3, "height": 2.0},
            "lamp": {"width": 0.3, "depth": 0.3, "height": 1.5},
            "wardrobe": {"width": 1.5, "depth": 0.6, "height": 2.1},
            "tv_stand": {"width": 1.2, "depth": 0.4, "height": 0.5},
            "rug": {"width": 2.0, "depth": 1.5, "height": 0.02},
            "plant": {"width": 0.3, "depth": 0.3, "height": 1.0},
            "furniture": {"width": 0.8, "depth": 0.8, "height": 0.8},  # default
        }

    # ------------------------------------------------------------------ #
    # Small shared helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _has_vertices(mesh) -> bool:
        """Return True if ``mesh`` exposes a non-empty ``vertices`` array."""
        vertices = getattr(mesh, "vertices", None)
        return vertices is not None and len(vertices) > 0

    @staticmethod
    def _floor_height(room_layout: Dict) -> float:
        """Return ``room_layout["floor"]["height"]``, defaulting to 0.0."""
        floor = (room_layout or {}).get("floor") or {}
        return floor.get("height", 0.0)

    @staticmethod
    def _class_name(detected_objects: Dict, index: int) -> str:
        """Return the detected class of object ``index`` ("furniture" if unknown)."""
        detected_objects = detected_objects or {}
        if index in detected_objects:
            return detected_objects[index].get("class_name", "furniture")
        return "furniture"

    # ------------------------------------------------------------------ #
    # Public entry point
    # ------------------------------------------------------------------ #
    def assemble(
        self,
        room_shell_mesh: "trimesh.Trimesh",  # type: ignore
        object_meshes: List["trimesh.Trimesh"],  # type: ignore
        room_layout: Dict,
        detected_objects: Dict,
        depth_map: np.ndarray,
    ) -> Dict:
        """Assemble a room scene from its individual components.

        Steps:
            1. Place objects at detected positions (uses ``depth_map`` size).
            2. Normalize scales using furniture priors.
            3. Ensure objects rest on the floor.
            4. Resolve collisions.
            5. Build scene graph.
            6. Merge into a unified mesh.

        Args:
            room_shell_mesh: Mesh of the room shell (walls/floor/ceiling).
            object_meshes: One mesh per object; index ``i`` is matched with
                ``detected_objects[i]``.
            room_layout: Layout dict; ``room_layout["floor"]["height"]`` is read.
            detected_objects: Mapping from object index to a dict that may hold
                ``"bbox"``, ``"depth_range"`` and ``"class_name"``.
            depth_map: Depth image; only its height/width are used here.

        Returns:
            Dict with ``"scene_mesh"`` (merged mesh), ``"object_meshes"``
            (final per-object meshes) and ``"scene_graph"``.
        """
        # Step 1: Initial placement from detected positions
        placed_objects = self._place_objects(
            object_meshes, detected_objects, room_layout, depth_map=depth_map
        )
        # Step 2: Scale normalization
        normalized_objects = self._normalize_scales(
            placed_objects, detected_objects, depth_map
        )
        # Step 3: Gravity constraint (objects on floor)
        grounded_objects = self._apply_gravity(normalized_objects, room_layout)
        # Step 4: Collision detection and resolution
        resolved_objects = self._resolve_collisions(grounded_objects, room_layout)
        # Step 5: Build scene graph
        scene_graph = self._build_scene_graph(
            resolved_objects, room_layout, detected_objects
        )
        # Step 6: Merge into unified mesh
        scene_mesh = self._merge_scene(room_shell_mesh, resolved_objects)
        return {
            "scene_mesh": scene_mesh,
            "object_meshes": resolved_objects,
            "scene_graph": scene_graph,
        }

    # ------------------------------------------------------------------ #
    # Pipeline steps
    # ------------------------------------------------------------------ #
    def _place_objects(
        self,
        object_meshes: List["trimesh.Trimesh"],  # type: ignore
        detected_objects: Dict,
        room_layout: Dict,
        depth_map: Optional[np.ndarray] = None,
    ) -> List["trimesh.Trimesh"]:  # type: ignore
        """Place objects at their detected positions in 3D space.

        For each detected object, the bbox centre column is back-projected with
        a pinhole model at the mean of ``depth_range``. The mesh is re-centred
        in x/z on that point and lifted by the floor height. Objects without a
        detection (or without vertices) are copied unchanged.

        Args:
            object_meshes: Meshes to place.
            detected_objects: Index -> detection dict (``"bbox"`` as
                ``[x1, y1, x2, y2]`` pixels, ``"depth_range"``).
            room_layout: Layout dict providing the floor height.
            depth_map: Optional depth image. Its height/width define the image
                plane; when omitted, a 512x512 image is assumed.

        Returns:
            A new list of placed mesh copies, in input order.
        """
        floor_height = self._floor_height(room_layout)
        if depth_map is not None and np.ndim(depth_map) >= 2:
            # NOTE(review): assumes depth_map is (H, W) or (H, W, C) and that
            # bbox pixel coordinates share its resolution -- confirm.
            img_h, img_w = np.shape(depth_map)[:2]
        else:
            img_h, img_w = self._DEFAULT_IMAGE_HW
        # No intrinsics are available: approximate the focal length (pixels)
        # by the larger image side and put the principal point at the centre.
        focal = float(max(img_w, img_h))
        principal_x = img_w / 2
        detected_objects = detected_objects or {}

        placed = []
        for i, mesh in enumerate(object_meshes):
            mesh_copy = mesh.copy()
            if i not in detected_objects or not self._has_vertices(mesh_copy):
                placed.append(mesh_copy)
                continue

            obj_info = detected_objects[i]
            x1, _, x2, _ = obj_info.get("bbox", [0, 0, 100, 100])
            mean_depth = float(np.mean(obj_info.get("depth_range", [1.0, 3.0])))

            # Pinhole back-projection of the bbox centre column; depth is used
            # directly as +z in the camera frame.
            bbox_cx = (x1 + x2) / 2
            x_3d = (bbox_cx - principal_x) * mean_depth / focal
            z_3d = mean_depth

            if hasattr(mesh_copy, "centroid"):
                centroid = mesh_copy.centroid
            else:
                centroid = np.asarray(mesh_copy.bounds).mean(axis=0)
            # Re-centre in x/z on the detected position; y is only offset by
            # the floor height (gravity grounds the object later).
            mesh_copy.apply_translation(
                [x_3d - centroid[0], floor_height, z_3d - centroid[2]]
            )
            placed.append(mesh_copy)
        return placed

    def _normalize_scales(
        self,
        object_meshes: List["trimesh.Trimesh"],  # type: ignore
        detected_objects: Dict,
        depth_map: np.ndarray,
    ) -> List["trimesh.Trimesh"]:  # type: ignore
        """Uniformly rescale each object so its largest extent matches its prior.

        The scale factor is ``max(prior dims) / max(current dims)``, clamped to
        ``_SCALE_RANGE``. Scaling is done about the bounding-box centre so the
        position assigned by ``_place_objects`` is preserved (``apply_scale``
        on its own scales about the origin, which would displace the object).
        Degenerate or empty meshes are copied unchanged. ``depth_map`` is
        accepted for interface compatibility but currently unused.

        Returns:
            A new list of rescaled mesh copies, in input order.
        """
        low, high = self._SCALE_RANGE
        normalized = []
        for i, mesh in enumerate(object_meshes):
            mesh_copy = mesh.copy()
            if self._has_vertices(mesh_copy):
                prior = self.furniture_priors.get(
                    self._class_name(detected_objects, i),
                    self.furniture_priors["furniture"],
                )
                bounds = np.asarray(mesh_copy.bounds, dtype=float)
                max_current = float(np.max(bounds[1] - bounds[0]))
                if max_current > self._MIN_EXTENT:
                    max_prior = max(prior["width"], prior["depth"], prior["height"])
                    scale_factor = float(np.clip(max_prior / max_current, low, high))
                    center = bounds.mean(axis=0)
                    mesh_copy.apply_translation(-center)
                    mesh_copy.apply_scale(scale_factor)
                    mesh_copy.apply_translation(center)
            normalized.append(mesh_copy)
        return normalized

    def _apply_gravity(
        self,
        object_meshes: List["trimesh.Trimesh"],  # type: ignore
        room_layout: Dict,
    ) -> List["trimesh.Trimesh"]:  # type: ignore
        """Translate each object vertically so its lowest vertex sits on the floor.

        Returns:
            A new list of grounded mesh copies; empty meshes are copied as-is.
        """
        floor_height = self._floor_height(room_layout)
        grounded = []
        for mesh in object_meshes:
            mesh_copy = mesh.copy()
            if self._has_vertices(mesh_copy):
                min_y = np.asarray(mesh_copy.vertices)[:, 1].min()
                mesh_copy.apply_translation([0.0, floor_height - min_y, 0.0])
            grounded.append(mesh_copy)
        return grounded

    def _resolve_collisions(
        self,
        object_meshes: List["trimesh.Trimesh"],  # type: ignore
        room_layout: Dict,
        max_iterations: int = 50,
        margin: float = 0.05,
    ) -> List["trimesh.Trimesh"]:  # type: ignore
        """Push apart objects whose axis-aligned bounding boxes overlap.

        A pair collides when its AABBs overlap strictly on all three axes. The
        pair is separated along whichever horizontal axis (x or z) has the
        smaller penetration. The vertical axis is never used: objects have
        already been grounded, and a vertical push would sink one into the
        floor. Each object of the pair moves half the penetration depth plus
        ``margin``, away from the other. Passes repeat until no overlap remains
        or ``max_iterations`` passes have run. ``room_layout`` is currently
        unused.

        Returns:
            A new list of mesh copies; the input meshes are not modified.
        """
        resolved = [mesh.copy() for mesh in object_meshes]
        # Empty meshes have no bounds and cannot collide.
        active = [idx for idx, mesh in enumerate(resolved) if self._has_vertices(mesh)]

        for _ in range(max_iterations):
            collisions_found = False
            for pos, i in enumerate(active):
                for j in active[pos + 1:]:
                    b1 = np.asarray(resolved[i].bounds, dtype=float)
                    b2 = np.asarray(resolved[j].bounds, dtype=float)
                    overlapping = bool(np.all((b1[0] < b2[1]) & (b1[1] > b2[0])))
                    if not overlapping:
                        continue
                    collisions_found = True

                    # Penetration depth along each axis.
                    penetration = np.minimum(b1[1] - b2[0], b2[1] - b1[0])
                    axis = 0 if penetration[0] <= penetration[2] else 2
                    push = np.zeros(3)
                    push[axis] = penetration[axis] * 0.5 + margin
                    center_i = b1.mean(axis=0)
                    center_j = b2.mean(axis=0)
                    if center_i[axis] < center_j[axis]:
                        # i is on the negative side: move it further negative.
                        push = -push
                    resolved[i].apply_translation(push)
                    resolved[j].apply_translation(-push)
            if not collisions_found:
                break
        return resolved

    def _build_scene_graph(
        self,
        object_meshes: List["trimesh.Trimesh"],  # type: ignore
        room_layout: Dict,
        detected_objects: Dict,
        proximity_threshold: float = 2.0,
    ) -> Dict:
        """Build an editable scene graph from the assembled objects.

        Nodes: one ``"room_shell"`` node plus one node per object, holding its
        label, AABB centre (``"position"``) and extents (``"dimensions"``).
        Both are ``None`` for meshes without vertices.

        Edges: every object is ``"in"`` the room. For each pair of objects whose
        centres are closer than ``proximity_threshold``, one relation is added:
        ``"next_to"`` (|dy| < 0.1), ``"on"`` (the higher object is ``from``,
        dy > 0.2), else ``"near"``. NOTE(review): after ``_apply_gravity``
        every object rests on the floor, so "on" reflects a height difference
        heuristic rather than true support -- confirm intended semantics.
        ``room_layout`` is currently unused.
        """
        nodes = [{
            "id": "room_shell",
            "type": "room",
            "label": "room",
            "bbox": None,
        }]
        edges = []
        centers = {}

        for i, mesh in enumerate(object_meshes):
            position = dimensions = None
            if self._has_vertices(mesh):
                bounds = np.asarray(mesh.bounds, dtype=float)
                centers[i] = bounds.mean(axis=0)
                position = centers[i].tolist()
                dimensions = (bounds[1] - bounds[0]).tolist()
            nodes.append({
                "id": i,
                "type": "object",
                "label": self._class_name(detected_objects, i),
                "position": position,
                "dimensions": dimensions,
                "mesh_index": i,
            })
            # Edge: object is IN room
            edges.append({"from": i, "to": "room_shell", "relation": "in"})

        # Infer pairwise spatial relationships between objects.
        with_geometry = sorted(centers)
        for pos, i in enumerate(with_geometry):
            for j in with_geometry[pos + 1:]:
                center_i, center_j = centers[i], centers[j]
                dist = float(np.linalg.norm(center_i - center_j))
                if dist >= proximity_threshold:
                    continue
                dy = center_i[1] - center_j[1]
                source, target = i, j
                if abs(dy) < 0.1:
                    relation = "next_to"
                elif dy > 0.2:
                    relation = "on"
                elif dy < -0.2:
                    # j is the higher object: record "j on i".
                    relation = "on"
                    source, target = j, i
                else:
                    relation = "near"
                edges.append({
                    "from": source,
                    "to": target,
                    "relation": relation,
                    "distance": dist,
                })

        return {
            "nodes": nodes,
            "edges": edges,
        }

    def _merge_scene(
        self,
        room_shell_mesh: "trimesh.Trimesh",  # type: ignore
        object_meshes: List["trimesh.Trimesh"],  # type: ignore
    ) -> "trimesh.Trimesh":  # type: ignore
        """Merge the room shell and objects into one scene mesh.

        Meshes without vertices (or ``None``) are skipped. If
        ``trimesh.util.concatenate`` fails, meshes are added one by one; any
        mesh that still cannot be merged is dropped with a warning.

        Returns:
            The merged mesh, or an empty ``trimesh.Trimesh`` if nothing is valid.
        """
        import warnings

        import trimesh

        meshes = [room_shell_mesh, *object_meshes]
        valid_meshes = [m for m in meshes if self._has_vertices(m)]
        if not valid_meshes:
            return trimesh.Trimesh()
        try:
            return trimesh.util.concatenate(valid_meshes)
        except Exception as err:  # best-effort: fall back to pairwise merge
            warnings.warn(
                f"trimesh.util.concatenate failed ({err!r}); merging meshes one by one"
            )
        # Copy so the result never aliases the caller's first mesh.
        scene_mesh = valid_meshes[0].copy()
        for idx, mesh in enumerate(valid_meshes[1:], start=1):
            try:
                scene_mesh = scene_mesh + mesh
            except Exception as err:
                warnings.warn(f"Dropping mesh {idx} from merged scene: {err!r}")
        return scene_mesh

    def reassemble_with_textures(
        self,
        room_shell_mesh: "trimesh.Trimesh",  # type: ignore
        textured_objects: List["trimesh.Trimesh"],  # type: ignore
        scene_graph: Dict,
    ) -> "trimesh.Trimesh":  # type: ignore
        """Re-merge the scene using textured object meshes.

        ``scene_graph`` is accepted for interface compatibility but unused.
        """
        return self._merge_scene(room_shell_mesh, textured_objects)