| """Phase 3: 3D Reconstruction Module. |
| |
| Reconstructs: |
| - Room shell (walls, floor, ceiling) as planar meshes |
| - Per-object 3D meshes using TRELLIS.2 or native InteriorFusion-L |
| - Scene-level Gaussian Splatting representation |
| """ |
|
|
| import os |
| from typing import Dict, List, Optional, Tuple, Union |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from PIL import Image |
|
|
|
|
class Reconstruction3DModule(nn.Module):
    """Reconstruct 3D geometry from multi-view images.

    Provides three groups of functionality:

    * room shell construction from a layout dictionary (planar quads),
    * per-object reconstruction (TRELLIS stub, then a depth-based fallback),
    * packing of meshes into a flat tensor of Gaussian primitives.

    Packed Gaussians are rows of 14 floats laid out as
    ``[x, y, z, sx, sy, sz, q0, q1, q2, q3, r, g, b, opacity]``.

    ``trimesh`` is imported lazily inside the methods that need it so the
    module can be imported without it.
    """

    def __init__(
        self,
        model_size: str = "L",
        device: str = "cuda",
        dtype: torch.dtype = torch.float16,
        cache_dir: Optional[str] = None,
    ):
        """Store the configuration; no model weights are loaded here.

        Args:
            model_size: Model size tag (stored, not interpreted here).
            device: Torch device used for tensors created by this module.
            dtype: Preferred dtype (stored, not interpreted here).
            cache_dir: Optional cache directory (stored, not interpreted here).
        """
        super().__init__()
        self.model_size = model_size
        self.device = device
        self.dtype = dtype
        self.cache_dir = cache_dir

        # Placeholders for lazily loaded back-ends; nothing in this class
        # populates them yet.
        self._trellis_model = None
        self._native_model = None

    # ------------------------------------------------------------------
    # Pure-numpy geometry helpers (no trimesh needed)
    # ------------------------------------------------------------------
    @staticmethod
    def _layout_dims(room_layout: Dict) -> Tuple[float, float, float]:
        """Return ``(width, depth, height)`` from ``room_layout["dimensions"]``.

        Missing entries default to 5.0 x 5.0 x 2.7.
        """
        dims = room_layout.get("dimensions") or {}
        return (
            dims.get("width", 5.0),
            dims.get("depth", 5.0),
            dims.get("height", 2.7),
        )

    @staticmethod
    def _horizontal_quad(width: float, depth: float, height: float) -> np.ndarray:
        """Return the 4 corners of an XZ-aligned rectangle at ``y = height``."""
        half_w = width / 2
        half_d = depth / 2
        return np.array([
            [-half_w, height, -half_d],
            [half_w, height, -half_d],
            [half_w, height, half_d],
            [-half_w, height, half_d],
        ])

    @staticmethod
    def _wall_quad(wall: Dict, room_layout: Dict) -> Tuple[np.ndarray, np.ndarray]:
        """Compute vertices and faces of one wall quad.

        The wall is an axis-aligned rectangle from ``y = 0`` to the room
        height.  ``"front"``/``"back"`` walls lie in a constant-Z plane, any
        other direction is treated as a side wall in a constant-X plane
        (``"right"`` on +X, everything else on -X).

        The triangle winding is chosen so that the geometric face normal
        points the same way as ``wall["normal"]`` along the wall's own axis.

        Args:
            wall: Dict with optional ``"normal"`` (3 floats, default
                ``[0, 0, 1]``), ``"direction"`` (default ``"back"``) and
                ``"position"`` (distance from the origin along the wall axis;
                defaults to half the room extent on that axis so the wall
                meets the floor edge).
            room_layout: Layout dict providing ``"dimensions"``.

        Returns:
            ``(vertices, faces)`` as ``(4, 3)`` float and ``(2, 3)`` int arrays.
        """
        width, depth, height = Reconstruction3DModule._layout_dims(room_layout)
        normal = np.array(wall.get("normal", [0, 0, 1]))
        direction = wall.get("direction", "back")
        position = wall.get("position")

        if direction in ("back", "front"):
            if position is None:
                position = depth / 2
            z = position if direction == "front" else -position
            vertices = np.array([
                [-width / 2, 0, z],
                [width / 2, 0, z],
                [width / 2, height, z],
                [-width / 2, height, z],
            ])
            # With this vertex order, winding (0, 1, 2) faces +Z.
            forward_winding = bool(normal[2] > 0.5)
        else:
            if position is None:
                position = width / 2
            x = position if direction == "right" else -position
            vertices = np.array([
                [x, 0, -depth / 2],
                [x, 0, depth / 2],
                [x, height, depth / 2],
                [x, height, -depth / 2],
            ])
            # With this vertex order, winding (0, 1, 2) faces -X, so a +X
            # normal needs the reversed winding.
            forward_winding = not bool(normal[0] > 0.5)

        if forward_winding:
            faces = np.array([[0, 1, 2], [0, 2, 3]])
        else:
            faces = np.array([[0, 2, 1], [0, 3, 2]])
        return vertices, faces

    @staticmethod
    def _backproject_depth_crop(depth_map, bbox) -> np.ndarray:
        """Back-project the depth pixels inside ``bbox`` to 3D points.

        The bounding box ``(x1, y1, x2, y2)`` is truncated to integers and
        clamped to the depth map so float or out-of-range boxes cannot raise
        or wrap around.  A pinhole model with ``fx = fy = max(crop_w, crop_h)``
        and the principal point at the crop centre is used, so the points are
        expressed relative to the crop, not the full image.

        Returns:
            ``(N, 3)`` array of points with finite coordinates and depth
            greater than 0.1; empty if nothing valid remains.
        """
        depth = np.asarray(depth_map)
        if depth.ndim != 2:
            # Accept (H, W, 1) / (1, H, W) style maps.
            depth = np.squeeze(depth)
        if depth.ndim != 2:
            return np.empty((0, 3))

        img_h, img_w = depth.shape
        x1, y1, x2, y2 = (int(c) for c in bbox)
        x1, x2 = max(0, min(x1, img_w)), max(0, min(x2, img_w))
        y1, y2 = max(0, min(y1, img_h)), max(0, min(y2, img_h))

        crop = depth[y1:y2, x1:x2]
        crop_h, crop_w = crop.shape
        if crop_h == 0 or crop_w == 0:
            return np.empty((0, 3))

        focal = max(crop_w, crop_h)
        cx, cy = crop_w / 2, crop_h / 2
        u, v = np.meshgrid(np.arange(crop_w), np.arange(crop_h))
        x = (u - cx) * crop / focal
        y = (v - cy) * crop / focal
        points = np.stack([x, y, crop], axis=-1).reshape(-1, 3)

        keep = np.isfinite(points).all(axis=1) & (points[:, 2] > 0.1)
        return points[keep]

    @staticmethod
    def _has_faces(mesh) -> bool:
        """Return True if ``mesh`` exposes a non-empty ``faces`` attribute."""
        faces = getattr(mesh, "faces", None)
        return faces is not None and len(faces) > 0

    # ------------------------------------------------------------------
    # Room shell
    # ------------------------------------------------------------------
    def reconstruct_room_shell(
        self,
        room_shell_views: Dict[str, "Image.Image"],
        room_layout: Dict,
        depth_map: np.ndarray,
    ) -> Optional["trimesh.Trimesh"]:
        """Reconstruct the room shell (walls, floor, ceiling) as planar meshes.

        Only ``room_layout`` is consulted; ``room_shell_views`` and
        ``depth_map`` are accepted for interface compatibility and are not
        used.

        Args:
            room_shell_views: Views of the empty room (unused).
            room_layout: Dict with optional ``"floor"``, ``"ceiling"``,
                ``"walls"`` and ``"dimensions"`` entries.
            depth_map: Depth map of the scene (unused).

        Returns:
            The concatenated shell mesh, a box-shaped fallback room when the
            layout yields no planes, or ``None`` when trimesh is not installed.
        """
        try:
            import trimesh
        except ImportError:
            print("Warning: trimesh not available, cannot build room shell")
            return None

        meshes = []

        floor = room_layout.get("floor", {})
        if floor:
            floor_mesh = self._create_floor_mesh(floor, room_layout)
            if floor_mesh is not None:
                meshes.append(floor_mesh)

        ceiling = room_layout.get("ceiling", {})
        if ceiling:
            ceiling_mesh = self._create_ceiling_mesh(ceiling, room_layout)
            if ceiling_mesh is not None:
                meshes.append(ceiling_mesh)

        for wall in room_layout.get("walls") or []:
            wall_mesh = self._create_wall_mesh(wall, room_layout)
            if wall_mesh is not None:
                meshes.append(wall_mesh)

        if not meshes:
            # Nothing usable in the layout: fall back to a plain box.
            return self._create_fallback_room(room_layout)

        try:
            return trimesh.util.concatenate(meshes)
        except Exception:
            # Best-effort fallback: merge pairwise with the mesh "+" operator.
            room_shell = meshes[0]
            for extra in meshes[1:]:
                room_shell += extra
            return room_shell

    def _create_floor_mesh(self, floor: Dict, room_layout: Dict) -> Optional["trimesh.Trimesh"]:
        """Create the floor quad at ``floor["height"]`` (default 0.0) with UVs.

        Returns ``None`` when trimesh is not installed.
        """
        try:
            import trimesh
        except ImportError:
            return None

        width, depth, _ = self._layout_dims(room_layout)
        vertices = self._horizontal_quad(width, depth, floor.get("height", 0.0))
        # NOTE(review): with this winding the geometric normal is -Y; the
        # ceiling uses the opposite winding. Confirm which side should face
        # the camera before changing either.
        faces = np.array([
            [0, 1, 2],
            [0, 2, 3],
        ])

        mesh = trimesh.Trimesh(vertices=vertices, faces=faces)

        # One UV corner per vertex so a floor texture spans the whole quad.
        uvs = np.array([
            [0, 0],
            [1, 0],
            [1, 1],
            [0, 1],
        ])
        mesh.visual = trimesh.visual.TextureVisuals(uv=uvs)

        return mesh

    def _create_ceiling_mesh(self, ceiling: Dict, room_layout: Dict) -> Optional["trimesh.Trimesh"]:
        """Create the ceiling quad at ``ceiling["height"]`` (default 2.7).

        Uses the reverse winding of the floor.  Returns ``None`` when trimesh
        is not installed.
        """
        try:
            import trimesh
        except ImportError:
            return None

        width, depth, _ = self._layout_dims(room_layout)
        vertices = self._horizontal_quad(width, depth, ceiling.get("height", 2.7))
        faces = np.array([
            [0, 2, 1],
            [0, 3, 2],
        ])

        return trimesh.Trimesh(vertices=vertices, faces=faces)

    def _create_wall_mesh(self, wall: Dict, room_layout: Dict) -> Optional["trimesh.Trimesh"]:
        """Create one wall quad; see ``_wall_quad`` for the geometry rules.

        Returns ``None`` when trimesh is not installed.
        """
        try:
            import trimesh
        except ImportError:
            return None

        vertices, faces = self._wall_quad(wall, room_layout)
        return trimesh.Trimesh(vertices=vertices, faces=faces)

    def _create_fallback_room(self, room_layout: Dict) -> "trimesh.Trimesh":
        """Create a simple box room (floor at ``y = 0``) as fallback.

        Raises:
            ImportError: If trimesh is not installed.
        """
        import trimesh

        width, depth, height = self._layout_dims(room_layout)

        box = trimesh.creation.box(extents=[width, height, depth])
        # The box is created centred on the origin; lift it so it sits on y=0.
        box.apply_translation([0, height / 2, 0])

        return box

    # ------------------------------------------------------------------
    # Objects
    # ------------------------------------------------------------------
    def reconstruct_object(
        self,
        multiviews: List["Image.Image"],
        room_layout: Optional[Dict] = None,
        depth_map: Optional[np.ndarray] = None,
        object_info: Optional[Dict] = None,
    ) -> Tuple["trimesh.Trimesh", Optional[torch.Tensor]]:
        """Reconstruct a single furniture object from multi-view images.

        Tries the TRELLIS path first and otherwise falls back to the
        depth-based reconstruction.

        Args:
            multiviews: Views of the object.
            room_layout: Room layout (currently unused).
            depth_map: Optional depth map used by the fallback path.
            object_info: Optional dict; its ``"bbox"`` selects the depth crop.

        Returns:
            (mesh, gaussian_cloud) — ``gaussian_cloud`` is currently always
            ``None``.
        """
        mesh = self._try_trellis_reconstruction(multiviews)
        if mesh is not None:
            return mesh, None

        return self._fallback_object_reconstruction(multiviews, depth_map, object_info)

    def _try_trellis_reconstruction(
        self,
        multiviews: List["Image.Image"],
    ) -> Optional["trimesh.Trimesh"]:
        """Try to use TRELLIS.2 for object reconstruction.

        Not implemented yet: always returns ``None`` so callers use the
        fallback path.
        """
        # TODO: integrate the TRELLIS.2 pipeline; until then signal "unavailable".
        return None

    def _fallback_object_reconstruction(
        self,
        multiviews: List["Image.Image"],
        depth_map: Optional[np.ndarray] = None,
        object_info: Optional[Dict] = None,
    ) -> Tuple["trimesh.Trimesh", Optional[torch.Tensor]]:
        """Simple reconstruction from the object's depth crop.

        Strategy, in order:

        1. more than 100 valid points -> convex hull mesh,
        2. at least one valid point -> ``trimesh.PointCloud``,
        3. otherwise -> a 0.5 m placeholder cube.

        ``multiviews`` is not used by this path.

        Returns:
            (geometry, None)

        Raises:
            ImportError: If trimesh is not installed.
        """
        import trimesh

        if depth_map is not None and object_info is not None:
            bbox = object_info.get("bbox", [0, 0, 100, 100])
            points = self._backproject_depth_crop(depth_map, bbox)

            if len(points) > 100:
                try:
                    # convex_hull returns a Trimesh; hull_points would only
                    # return the hull's vertex array.
                    return trimesh.convex.convex_hull(points), None
                except Exception as exc:
                    # Degenerate (e.g. coplanar) input: keep going with the
                    # point-cloud fallback instead of failing the object.
                    print(f"Warning: convex hull reconstruction failed: {exc}")

            if len(points) > 0:
                return trimesh.PointCloud(points), None

        return trimesh.creation.box(extents=[0.5, 0.5, 0.5]), None

    # ------------------------------------------------------------------
    # Gaussians
    # ------------------------------------------------------------------
    def build_scene_gaussians(
        self,
        room_shell_mesh: "trimesh.Trimesh",
        object_gaussians: List[Optional[torch.Tensor]],
        object_meshes: List["trimesh.Trimesh"],
    ) -> torch.Tensor:
        """Build a unified Gaussian representation for the entire scene.

        The room shell is converted from its mesh.  For every object the
        supplied Gaussian tensor is used when present; otherwise the object's
        mesh (if it has faces) is converted, so objects without Gaussians are
        not dropped from the scene.

        Args:
            room_shell_mesh: Room shell mesh, or ``None``.
            object_gaussians: Per-object ``(N, 14)`` tensors or ``None``.
            object_meshes: Per-object geometry, index-aligned with
                ``object_gaussians``; the lists may differ in length.

        Returns:
            ``(M, 14)`` tensor; ``(0, 14)`` when nothing could be converted.
        """
        from itertools import zip_longest

        gaussians = []

        if self._has_faces(room_shell_mesh):
            try:
                gaussians.append(self._mesh_to_gaussians(room_shell_mesh))
            except Exception as e:
                print(f"Warning: could not convert room shell to Gaussians: {e}")

        for obj_gauss, obj_mesh in zip_longest(object_gaussians, object_meshes):
            if obj_gauss is not None:
                gaussians.append(obj_gauss)
            elif self._has_faces(obj_mesh):
                # Point-cloud fallbacks have no faces and are skipped above.
                try:
                    gaussians.append(self._mesh_to_gaussians(obj_mesh))
                except Exception as e:
                    print(f"Warning: could not convert object mesh to Gaussians: {e}")

        if gaussians:
            return torch.cat(gaussians, dim=0)

        return torch.zeros(0, 14, device=self.device)

    def _mesh_to_gaussians(
        self,
        mesh: "trimesh.Trimesh",
        num_gaussians_per_face: int = 4,
    ) -> torch.Tensor:
        """Convert a mesh to 3D Gaussian primitives.

        Each face spawns ``num_gaussians_per_face`` Gaussians with:
        - Position: face centroid plus N(0, 0.01) jitter per axis
        - Scale: ``sqrt(face_area)`` times 0.1 / 0.1 / 0.05, plus 0.001
        - Rotation: constant ``[0, 0, 0, 1]`` (not yet aligned to the normal)
        - Color: constant 0.8 grey
        - Opacity: 0.9

        Rows are ordered sample-major: all faces for sample 0, then all faces
        for sample 1, and so on.

        Returns:
            ``(num_faces * num_gaussians_per_face, 14)`` float32 tensor on
            ``self.device``; ``(0, 14)`` for a face-less mesh or a
            non-positive sample count.
        """
        if not self._has_faces(mesh) or num_gaussians_per_face <= 0:
            return torch.zeros(0, 14, device=self.device)

        vertices = torch.tensor(
            np.asarray(mesh.vertices), dtype=torch.float32, device=self.device
        )
        faces = torch.tensor(
            np.asarray(mesh.faces), dtype=torch.long, device=self.device
        )
        num_faces = len(faces)

        v0 = vertices[faces[:, 0]]
        v1 = vertices[faces[:, 1]]
        v2 = vertices[faces[:, 2]]
        centroids = (v0 + v1 + v2) / 3.0

        # Triangle area is half the length of the *unnormalized* edge cross
        # product; normalizing first would make every area 0.5.
        face_cross = torch.cross(v1 - v0, v2 - v0, dim=-1)
        areas = 0.5 * torch.norm(face_cross, dim=-1)

        # Per-face attributes are identical for every sample, so build once.
        root_area = torch.sqrt(areas)
        scales = torch.stack([
            root_area * 0.1 + 0.001,
            root_area * 0.1 + 0.001,
            root_area * 0.05 + 0.001,
        ], dim=-1)
        # NOTE(review): [0, 0, 0, 1] is the identity in (x, y, z, w) order;
        # confirm the renderer does not expect (w, x, y, z).
        rotations = torch.tensor(
            [0.0, 0.0, 0.0, 1.0], device=self.device
        ).unsqueeze(0).expand(num_faces, -1)
        colors = torch.full((num_faces, 3), 0.8, device=self.device)
        opacity = torch.full((num_faces, 1), 0.9, device=self.device)
        per_face = torch.cat([scales, rotations, colors, opacity], dim=-1)

        samples = num_gaussians_per_face
        jitter = torch.randn(samples, num_faces, 3, device=self.device) * 0.01
        positions = centroids.unsqueeze(0) + jitter

        packed = torch.cat(
            [positions, per_face.unsqueeze(0).expand(samples, -1, -1)], dim=-1
        )
        return packed.reshape(samples * num_faces, 14)
|
|