# InteriorFusion / src / interiorfusion / models / reconstruction_3d.py
# stevee00's picture
# Upload src/interiorfusion/models/reconstruction_3d.py
# 2033370 verified
"""Phase 3: 3D Reconstruction Module.
Reconstructs:
- Room shell (walls, floor, ceiling) as planar meshes
- Per-object 3D meshes using TRELLIS.2 or native InteriorFusion-L
- Scene-level Gaussian Splatting representation
"""
import os
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
class Reconstruction3DModule(nn.Module):
    """Reconstruct 3D geometry from multi-view images.

    The module builds the room shell as planar meshes, reconstructs per-object
    meshes, and converts meshes into a flat tensor of Gaussian primitives.
    ``trimesh`` is imported lazily inside the methods that need it.

    Meshes use a Y-up frame: heights are written to the Y coordinate, width
    spans X and depth spans Z.

    Every Gaussian is one row of 14 values::

        [x, y, z, scale_x, scale_y, scale_z,
         rot_qx, rot_qy, rot_qz, rot_qw, r, g, b, opacity]
    """

    # Number of columns in one Gaussian primitive row (see class docstring).
    GAUSSIAN_DIM = 14

    def __init__(
        self,
        model_size: str = "L",
        device: str = "cuda",
        dtype: torch.dtype = torch.float16,
        cache_dir: Optional[str] = None,
    ):
        """Store configuration; no model weights are loaded here.

        Args:
            model_size: Stored on the instance; not read by the methods below.
            device: Torch device on which Gaussian tensors are created.
            dtype: Stored on the instance; not read by the methods below.
            cache_dir: Stored on the instance; not read by the methods below.
        """
        super().__init__()
        self.model_size = model_size
        self.device = device
        self.dtype = dtype
        self.cache_dir = cache_dir
        # Lazy load reconstruction models
        self._trellis_model = None
        self._native_model = None

    def reconstruct_room_shell(
        self,
        room_shell_views: Dict[str, "Image.Image"],
        room_layout: Dict,
        depth_map: np.ndarray,
    ) -> Optional["trimesh.Trimesh"]:  # type: ignore
        """Reconstruct the room shell (floor, ceiling, walls) as planar meshes.

        Args:
            room_shell_views: Currently unused by this implementation.
            room_layout: Layout dictionary. The keys read here are ``"floor"``,
                ``"ceiling"``, ``"walls"`` and (by the helpers) ``"dimensions"``.
            depth_map: Currently unused by this implementation.

        Returns:
            The concatenated shell mesh, a simple box when the layout yields no
            planes, or ``None`` when ``trimesh`` is not installed.
        """
        try:
            import trimesh
        except ImportError:
            # There is no trimesh-free code path, so say so instead of
            # announcing a fallback that does not exist.
            print("Warning: trimesh not available, cannot reconstruct room shell")
            return None

        meshes = []

        # Floor mesh
        floor = room_layout.get("floor", {})
        if floor:
            floor_mesh = self._create_floor_mesh(floor, room_layout)
            if floor_mesh is not None:
                meshes.append(floor_mesh)

        # Ceiling mesh
        ceiling = room_layout.get("ceiling", {})
        if ceiling:
            ceiling_mesh = self._create_ceiling_mesh(ceiling, room_layout)
            if ceiling_mesh is not None:
                meshes.append(ceiling_mesh)

        # Wall meshes
        for wall in room_layout.get("walls", []):
            wall_mesh = self._create_wall_mesh(wall, room_layout)
            if wall_mesh is not None:
                meshes.append(wall_mesh)

        # Fallback: create simple box room
        if not meshes:
            return self._create_fallback_room(room_layout)

        # Combine all meshes
        try:
            return trimesh.util.concatenate(meshes)
        except Exception:
            # Best effort: fold the meshes together one at a time.
            room_shell = meshes[0]
            for extra in meshes[1:]:
                room_shell = room_shell + extra
            return room_shell

    @staticmethod
    def _horizontal_rectangle(width: float, depth: float, height: float) -> np.ndarray:
        """Return the 4 corners of an origin-centred rectangle at Y = ``height``."""
        half_w = width / 2
        half_d = depth / 2
        return np.array([
            [-half_w, height, -half_d],
            [half_w, height, -half_d],
            [half_w, height, half_d],
            [-half_w, height, half_d],
        ])

    def _create_floor_mesh(self, floor: Dict, room_layout: Dict) -> Optional["trimesh.Trimesh"]:  # type: ignore
        """Create the floor plane mesh with per-vertex UVs.

        Args:
            floor: Floor entry of the layout; ``"height"`` defaults to 0.0.
            room_layout: Layout dictionary; ``"dimensions"`` supplies
                ``"width"`` / ``"depth"`` (default 5.0 each).

        Returns:
            A two-triangle mesh facing up (+Y), or ``None`` without trimesh.
        """
        try:
            import trimesh
        except ImportError:
            return None

        dims = room_layout.get("dimensions", {})
        width = dims.get("width", 5.0)
        depth = dims.get("depth", 5.0)
        height = floor.get("height", 0.0)

        vertices = self._horizontal_rectangle(width, depth, height)
        # Winding chosen so cross(v1 - v0, v2 - v0) points along +Y, i.e. the
        # floor faces up into the room (the mirror image of the ceiling).
        faces = np.array([
            [0, 2, 1],
            [0, 3, 2],
        ])
        mesh = trimesh.Trimesh(vertices=vertices, faces=faces)

        # Add UV coordinates for texture mapping
        uvs = np.array([
            [0, 0],
            [1, 0],
            [1, 1],
            [0, 1],
        ])
        mesh.visual = trimesh.visual.TextureVisuals(uv=uvs)
        return mesh

    def _create_ceiling_mesh(self, ceiling: Dict, room_layout: Dict) -> Optional["trimesh.Trimesh"]:  # type: ignore
        """Create the ceiling plane mesh.

        Args:
            ceiling: Ceiling entry of the layout; ``"height"`` defaults to 2.7.
            room_layout: Layout dictionary; ``"dimensions"`` supplies
                ``"width"`` / ``"depth"`` (default 5.0 each).

        Returns:
            A two-triangle mesh facing down (-Y), or ``None`` without trimesh.
        """
        try:
            import trimesh
        except ImportError:
            return None

        dims = room_layout.get("dimensions", {})
        width = dims.get("width", 5.0)
        depth = dims.get("depth", 5.0)
        height = ceiling.get("height", 2.7)

        vertices = self._horizontal_rectangle(width, depth, height)
        # Ceiling faces point downward: with these corners the winding below
        # makes cross(v1 - v0, v2 - v0) point along -Y.
        faces = np.array([
            [0, 1, 2],
            [0, 2, 3],
        ])
        return trimesh.Trimesh(vertices=vertices, faces=faces)

    def _create_wall_mesh(self, wall: Dict, room_layout: Dict) -> Optional["trimesh.Trimesh"]:  # type: ignore
        """Create one vertical wall plane mesh.

        Args:
            wall: Wall entry with optional ``"normal"`` (default ``[0, 0, 1]``),
                ``"position"`` (default 0.0) and ``"direction"`` (default
                ``"back"``). Any direction other than ``"back"`` / ``"front"``
                is treated as a left/right wall.
            room_layout: Layout dictionary; ``"dimensions"`` supplies
                ``"width"``, ``"depth"`` (default 5.0) and ``"height"``
                (default 2.7).

        Returns:
            A two-triangle mesh whose face normal agrees with ``"normal"``
            whenever the two are not perpendicular, or ``None`` without trimesh.
        """
        try:
            import trimesh
        except ImportError:
            return None

        dims = room_layout.get("dimensions", {})
        width = dims.get("width", 5.0)
        depth = dims.get("depth", 5.0)
        height = dims.get("height", 2.7)

        normal = np.asarray(wall.get("normal", [0, 0, 1]), dtype=np.float64)
        # NOTE(review): a missing "position" places the wall through the room
        # centre; confirm callers always provide it.
        position = wall.get("position", 0.0)
        direction = wall.get("direction", "back")

        # Create wall based on direction
        if direction in ("back", "front"):
            # Wall perpendicular to z-axis
            z = position if direction == "front" else -position
            vertices = np.array([
                [-width / 2, 0, z],
                [width / 2, 0, z],
                [width / 2, height, z],
                [-width / 2, height, z],
            ], dtype=np.float64)
        else:  # left or right
            # Wall perpendicular to x-axis
            x = position if direction == "right" else -position
            vertices = np.array([
                [x, 0, -depth / 2],
                [x, 0, depth / 2],
                [x, height, depth / 2],
                [x, height, -depth / 2],
            ], dtype=np.float64)

        # Orient the faces by comparing the geometric normal of the default
        # winding with the requested one. A fixed component test is not enough:
        # the default winding faces +Z for back/front walls but -X for
        # left/right walls.
        faces = np.array([[0, 1, 2], [0, 2, 3]])
        default_normal = np.cross(vertices[1] - vertices[0], vertices[3] - vertices[0])
        if float(np.dot(default_normal, normal[:3])) < 0.0:
            faces = np.array([[0, 2, 1], [0, 3, 2]])

        return trimesh.Trimesh(vertices=vertices, faces=faces)

    def _create_fallback_room(self, room_layout: Dict) -> "trimesh.Trimesh":  # type: ignore
        """Create a simple box room as fallback.

        The box uses the layout ``"dimensions"`` (defaults 5.0 x 2.7 x 5.0) and
        is shifted up so that its bottom face sits at Y = 0.
        """
        import trimesh

        dims = room_layout.get("dimensions", {})
        width = dims.get("width", 5.0)
        depth = dims.get("depth", 5.0)
        height = dims.get("height", 2.7)

        box = trimesh.creation.box(extents=[width, height, depth])
        box.apply_translation([0, height / 2, 0])
        return box

    def reconstruct_object(
        self,
        multiviews: List["Image.Image"],
        room_layout: Optional[Dict] = None,
        depth_map: Optional[np.ndarray] = None,
        object_info: Optional[Dict] = None,
    ) -> Tuple["trimesh.Trimesh", Optional[torch.Tensor]]:  # type: ignore
        """Reconstruct a single furniture object from multi-view images.

        Tries the TRELLIS path first and otherwise falls back to a
        depth-based reconstruction.

        Args:
            multiviews: Views of the object.
            room_layout: Currently unused by this implementation.
            depth_map: Depth image used by the fallback path.
            object_info: Object metadata; the fallback reads ``"bbox"``.

        Returns:
            ``(mesh, gaussian_cloud)``. The Gaussian cloud is always ``None``
            in this implementation.
        """
        # Try TRELLIS.2 if available
        mesh = self._try_trellis_reconstruction(multiviews)
        if mesh is not None:
            return mesh, None

        # Fallback: simple reconstruction from depth
        return self._fallback_object_reconstruction(multiviews, depth_map, object_info)

    def _try_trellis_reconstruction(
        self,
        multiviews: List["Image.Image"],
    ) -> Optional["trimesh.Trimesh"]:  # type: ignore
        """Try to use TRELLIS.2 for object reconstruction.

        Placeholder: no TRELLIS integration exists yet, so this always returns
        ``None`` and callers use the depth-based fallback.
        """
        # In production: from trellis import TRELLISPipeline
        return None

    def _fallback_object_reconstruction(
        self,
        multiviews: List["Image.Image"],
        depth_map: Optional[np.ndarray] = None,
        object_info: Optional[Dict] = None,
    ) -> Tuple["trimesh.Trimesh", Optional[torch.Tensor]]:  # type: ignore
        """Reconstruct an object from the depth values inside its bounding box.

        The depth crop is back-projected with an approximate pinhole camera
        (focal length ``max(W, H)`` of the crop, principal point at the crop
        centre). The convex hull of the resulting points is the mesh.

        Args:
            multiviews: Currently unused by this implementation.
            depth_map: Depth image indexed as ``[row, column]``.
            object_info: Metadata; ``"bbox"`` is ``[x1, y1, x2, y2]`` in pixels
                (default ``[0, 0, 100, 100]``).

        Returns:
            ``(geometry, None)`` where geometry is, in order of preference, the
            convex hull mesh, a ``trimesh.PointCloud`` of the valid points, or
            a 0.5-unit cube when no usable depth is available.
        """
        import trimesh

        if depth_map is not None and object_info is not None:
            depth = np.asarray(depth_map)
            img_h, img_w = depth.shape[:2]

            # Pixel indices must be ints inside the image: float coordinates
            # would break slicing and negative ones would wrap around.
            bbox = object_info.get("bbox", [0, 0, 100, 100])
            x1, y1, x2, y2 = (int(coord) for coord in bbox)
            x1, x2 = (min(max(v, 0), img_w) for v in (x1, x2))
            y1, y2 = (min(max(v, 0), img_h) for v in (y1, y2))

            # Extract depth region for this object
            obj_depth = depth[y1:y2, x1:x2]

            if obj_depth.size > 0:
                # Create point cloud from depth
                H, W = obj_depth.shape
                fx = fy = max(W, H)
                cx, cy = W / 2, H / 2
                u, v = np.meshgrid(np.arange(W), np.arange(H))
                z = obj_depth
                x = (u - cx) * z / fx
                y = (v - cy) * z / fy
                points = np.stack([x, y, z], axis=-1).reshape(-1, 3)

                # Remove invalid points (too close, NaN or infinite)
                valid = np.isfinite(points).all(axis=1) & (points[:, 2] > 0.1)
                points = points[valid]

                if len(points) > 100:
                    # Create convex hull as simple mesh. convex_hull returns a
                    # Trimesh; hull_points would only return the hull vertices.
                    try:
                        return trimesh.convex.convex_hull(points), None
                    except Exception as exc:
                        print(f"Warning: convex hull reconstruction failed: {exc}")

                # If hull fails, return point cloud as mesh
                if len(points) > 0:
                    return trimesh.PointCloud(points), None

        # Ultimate fallback: small cube
        return trimesh.creation.box(extents=[0.5, 0.5, 0.5]), None

    def build_scene_gaussians(
        self,
        room_shell_mesh: "trimesh.Trimesh",  # type: ignore
        object_gaussians: List[Optional[torch.Tensor]],
        object_meshes: List["trimesh.Trimesh"],  # type: ignore
    ) -> torch.Tensor:
        """Build one Gaussian tensor for the whole scene.

        The room shell is always converted from its mesh. For each object the
        supplied Gaussian tensor is used when present; otherwise the object's
        mesh (same index in ``object_meshes``) is converted instead.

        Args:
            room_shell_mesh: Room shell mesh; skipped when it has no vertices.
            object_gaussians: Per-object ``(N_i, 14)`` tensors or ``None``.
            object_meshes: Per-object meshes, index-aligned with
                ``object_gaussians``. The lists may differ in length.

        Returns:
            A ``(N, 14)`` tensor on ``self.device``; ``(0, 14)`` when nothing
            could be converted.
        """
        gaussians: List[torch.Tensor] = []

        # Convert room shell mesh to Gaussians
        try:
            if hasattr(room_shell_mesh, 'vertices') and len(room_shell_mesh.vertices) > 0:
                gaussians.append(self._mesh_to_gaussians(room_shell_mesh))
        except Exception as e:
            print(f"Warning: could not convert room shell to Gaussians: {e}")

        # Add per-object Gaussians, falling back to the object's mesh
        object_gaussians = object_gaussians or []
        object_meshes = object_meshes or []
        for idx in range(max(len(object_gaussians), len(object_meshes))):
            obj_gauss = object_gaussians[idx] if idx < len(object_gaussians) else None
            if obj_gauss is not None:
                # Keep everything on one device so torch.cat cannot fail.
                gaussians.append(obj_gauss.to(self.device))
                continue

            obj_mesh = object_meshes[idx] if idx < len(object_meshes) else None
            if obj_mesh is None:
                continue
            try:
                gaussians.append(self._mesh_to_gaussians(obj_mesh))
            except Exception as e:
                # E.g. a PointCloud fallback, which has no faces.
                print(f"Warning: could not convert object {idx} mesh to Gaussians: {e}")

        if gaussians:
            return torch.cat(gaussians, dim=0)

        # Fallback: return empty tensor
        return torch.zeros(0, self.GAUSSIAN_DIM, device=self.device)

    @staticmethod
    def _quaternions_from_z_axis(normals: torch.Tensor) -> torch.Tensor:
        """Return unit quaternions ``[qx, qy, qz, qw]`` rotating +Z onto ``normals``.

        Args:
            normals: ``(F, 3)`` unit vectors; all-zero rows (degenerate faces)
                yield the identity rotation.
        """
        nx, ny, nz = normals.unbind(dim=-1)
        # Half-way construction: q = (z x n, 1 + z . n), then normalise.
        quats = torch.stack([-ny, nx, torch.zeros_like(nz), 1.0 + nz], dim=-1)
        # For n == -Z the construction collapses to zero; use a half turn
        # about X, which also maps +Z onto -Z.
        half_turn = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=quats.dtype, device=quats.device)
        antiparallel = ((1.0 + nz) < 1e-6).unsqueeze(-1)
        quats = torch.where(antiparallel, half_turn, quats)
        return F.normalize(quats, dim=-1)

    def _mesh_to_gaussians(
        self,
        mesh: "trimesh.Trimesh",  # type: ignore
        num_gaussians_per_face: int = 4,
    ) -> torch.Tensor:
        """Convert a mesh to 3D Gaussian primitives.

        Each face spawns ``num_gaussians_per_face`` Gaussians with:

        - Position: face centroid plus Gaussian noise (std 0.01)
        - Scale: from the square root of the face area, thinner on local Z
        - Rotation: local Z axis aligned with the face normal
        - Opacity: 0.9
        - Color: constant 0.8 grey

        Args:
            mesh: Object exposing ``vertices`` ``(V, 3)`` and ``faces`` ``(F, 3)``.
            num_gaussians_per_face: Gaussians per face; values <= 0 give an
                empty result.

        Returns:
            A float32 ``(F * num_gaussians_per_face, 14)`` tensor on
            ``self.device``, ordered as one block of all faces per sample.
        """
        if num_gaussians_per_face <= 0 or len(mesh.faces) == 0:
            return torch.zeros(0, self.GAUSSIAN_DIM, device=self.device)

        vertices = torch.tensor(np.asarray(mesh.vertices), dtype=torch.float32, device=self.device)
        faces = torch.tensor(np.asarray(mesh.faces), dtype=torch.long, device=self.device)
        num_faces = faces.shape[0]

        # Get face data
        v0 = vertices[faces[:, 0]]
        v1 = vertices[faces[:, 1]]
        v2 = vertices[faces[:, 2]]

        # Face centroids
        centroids = (v0 + v1 + v2) / 3.0

        # The raw cross product carries the area (|cross| = 2 * area), so the
        # area must be read before normalising it into the unit normal.
        face_cross = torch.cross(v1 - v0, v2 - v0, dim=-1)
        areas = 0.5 * torch.norm(face_cross, dim=-1)
        normals = F.normalize(face_cross, dim=-1)

        # Per-face attributes are identical for every sample of a face, so
        # build them once and tile them.
        in_plane = torch.sqrt(areas) * 0.1 + 0.001
        along_normal = torch.sqrt(areas) * 0.05 + 0.001
        scales = torch.stack([in_plane, in_plane, along_normal], dim=-1)
        rotations = self._quaternions_from_z_axis(normals)
        colors = torch.full((num_faces, 3), 0.8, device=self.device)
        opacity = torch.full((num_faces, 1), 0.9, device=self.device)
        per_face = torch.cat([scales, rotations, colors, opacity], dim=-1)

        # Only the position jitter differs between samples of the same face.
        tiled_centroids = centroids.repeat(num_gaussians_per_face, 1)
        positions = tiled_centroids + torch.randn_like(tiled_centroids) * 0.01

        return torch.cat(
            [positions, per_face.repeat(num_gaussians_per_face, 1)],
            dim=-1,
        )