# InteriorFusion / src/interiorfusion/models/scene_understanding.py
# Hosting-page header (not code): stevee00 - "Upload src/interiorfusion/models/scene_understanding.py"
# Revision ebc1c85 (verified)
"""Phase 1: Scene Understanding Module.
Combines:
- Metric depth estimation (Depth Anything V2)
- Room layout estimation (SpatialLM)
- Semantic segmentation (Mask2Former-style)
- Object detection & isolation (SAM)
"""
import os
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import (
pipeline,
AutoImageProcessor,
AutoModelForDepthEstimation,
AutoModelForSemanticSegmentation,
CLIPVisionModel,
CLIPImageProcessor,
)
class SceneUnderstandingModule(nn.Module):
    """Extract scene structure from a single interior image.

    Construction is cheap: no network weights are loaded here. The private
    ``_*_model`` slots start out as ``None`` and are meant to be filled in
    lazily, the first time a sub-model is actually needed.
    """

    def __init__(
        self,
        model_size: str = "L",
        device: str = "cuda",
        dtype: torch.dtype = torch.float16,
        cache_dir: Optional[str] = None,
    ):
        """Record the configuration and clear every lazy sub-model slot.

        Args:
            model_size: Size tag for the sub-models; stored as-is.
            device: Device string the sub-models are moved to when loaded.
            dtype: Torch dtype the sub-models are loaded with.
            cache_dir: Optional directory for downloaded checkpoints.
        """
        super().__init__()

        # Lazily populated sub-models; None means "not loaded yet".
        self._room_classifier = None
        self._sam_model = None
        self._segmentation_model = None
        self._depth_processor = None
        self._depth_model = None

        # Plain configuration, kept verbatim for the lazy loaders.
        self.cache_dir = cache_dir
        self.dtype = dtype
        self.device = device
        self.model_size = model_size
@property
def depth_model(self):
    """Return ``(model, processor)`` for metric depth, loading them once.

    On first access the image processor and the depth-estimation network
    are fetched from the ``Depth-Anything-V2-Metric-Indoor-Large-hf``
    checkpoint, the network is moved to ``self.device`` and put in eval
    mode. Later accesses reuse the cached pair.
    """
    # Fast path: everything is already cached.
    if self._depth_model is not None:
        return self._depth_model, self._depth_processor

    checkpoint = "depth-anything/Depth-Anything-V2-Metric-Indoor-Large-hf"
    self._depth_processor = AutoImageProcessor.from_pretrained(
        checkpoint, cache_dir=self.cache_dir
    )
    network = AutoModelForDepthEstimation.from_pretrained(
        checkpoint,
        torch_dtype=self.dtype,
        cache_dir=self.cache_dir,
    )
    self._depth_model = network.to(self.device)
    self._depth_model.eval()
    return self._depth_model, self._depth_processor
@property
def segmentation_model(self):
    """Return the segmentation network, loading it on first access.

    The checkpoint is a generic COCO instance-segmentation Mask2Former;
    a fine-tuned indoor Mask2Former/OneFormer is the intended production
    replacement. The loaded model is moved to ``self.device``, switched to
    eval mode and cached for subsequent accesses.
    """
    if self._segmentation_model is None:
        # Mask2Former checkpoints are served by the universal-segmentation
        # auto class; AutoModelForSemanticSegmentation has no mapping for
        # the Mask2Former config and refuses to load this checkpoint.
        from transformers import AutoModelForUniversalSegmentation

        model_id = "facebook/mask2former-swin-large-coco-instance"
        self._segmentation_model = AutoModelForUniversalSegmentation.from_pretrained(
            model_id,
            torch_dtype=self.dtype,
            cache_dir=self.cache_dir,
        ).to(self.device)
        self._segmentation_model.eval()
    return self._segmentation_model
@property
def room_classifier(self):
    """Return the CLIP vision encoder, loading it on first access.

    The ``openai/clip-vit-large-patch14`` vision tower is loaded with
    ``self.dtype``, moved to ``self.device``, set to eval mode and cached.
    """
    # Fast path: reuse the encoder once it has been loaded.
    if self._room_classifier is not None:
        return self._room_classifier

    encoder = CLIPVisionModel.from_pretrained(
        "openai/clip-vit-large-patch14",
        torch_dtype=self.dtype,
        cache_dir=self.cache_dir,
    )
    self._room_classifier = encoder.to(self.device)
    self._room_classifier.eval()
    return self._room_classifier
def forward(self, image: Image.Image) -> Dict:
    """Process a single interior image and extract all scene understanding.

    Args:
        image: Input photograph. Images that are not in RGB mode are
            converted to RGB first, because the downstream colour
            heuristics index three colour channels.

    Returns:
        Dictionary with keys:
            - depth: metric depth map [H, W]
            - normal: surface normal map [H, W, 3]
            - room_layout: room layout structure
            - semantic_segmentation: pixel-wise class labels [H, W]
            - detected_objects: dict of per-object crops and masks
            - room_type: str (e.g. "living_room")
            - style: str (e.g. "modern")
    """
    # Grayscale / palette / CMYK inputs would break the per-channel
    # statistics used further down, so normalise the mode once up front.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # NumPy view of the pixels (used by the layout estimator).
    img_np = np.array(image)

    # === Metric Depth Estimation ===
    depth_map = self.estimate_depth(image)

    # === Surface Normal Estimation ===
    normal_map = self.estimate_normals(depth_map)

    # === Room Layout Estimation ===
    room_layout = self.estimate_room_layout(depth_map, img_np)

    # === Semantic Segmentation ===
    semantic_seg = self.segment_image(image)

    # === Object Detection & Isolation ===
    detected_objects = self.detect_and_isolate_objects(
        image, depth_map, semantic_seg
    )

    # === Room Type Classification ===
    room_type = self.classify_room_type(image)

    # === Style Classification ===
    style = self.classify_style(image)

    return {
        "depth": depth_map,
        "normal": normal_map,
        "room_layout": room_layout,
        "semantic_segmentation": semantic_seg,
        "detected_objects": detected_objects,
        "room_type": room_type,
        "style": style,
    }
def estimate_depth(self, image: Image.Image) -> np.ndarray:
    """Estimate metric depth using Depth Anything V2.

    Args:
        image: Input image; ``image.size`` is read as ``(width, height)``.

    Returns:
        Float32 array of shape ``[height, width]`` holding the predicted
        depth resampled to the input resolution.
    """
    model, processor = self.depth_model

    inputs = processor(images=image, return_tensors="pt")
    # The model is loaded in ``self.dtype`` (float16 by default) while the
    # processor emits float32 tensors; floating inputs must be cast to the
    # model dtype or the first layer rejects them.
    inputs = {
        name: (
            tensor.to(self.device, dtype=self.dtype)
            if torch.is_floating_point(tensor)
            else tensor.to(self.device)
        )
        for name, tensor in inputs.items()
    }

    with torch.no_grad():
        outputs = model(**inputs)
    predicted_depth = outputs.predicted_depth

    # PIL reports (width, height); interpolate expects (height, width).
    width, height = image.size
    # Resample in float32: more accurate than half precision and supported
    # by the bicubic kernel on every backend.
    prediction = F.interpolate(
        predicted_depth.unsqueeze(1).float(),
        size=(height, width),
        mode="bicubic",
        align_corners=False,
    )

    # Drop batch and channel axes explicitly; a blanket squeeze() would
    # also collapse a size-1 image dimension.
    return prediction[0, 0].cpu().numpy()
def estimate_normals(self, depth: np.ndarray) -> np.ndarray:
    """Compute unit surface normals from a depth map.

    The normal at each pixel is ``[-dz/dx, -dz/dy, 1]`` normalised to unit
    length, where the derivatives are finite differences of ``depth``
    along the column (x) and row (y) axes.

    Args:
        depth: 2-D depth map ``[H, W]``.

    Returns:
        Array ``[H, W, 3]`` of unit normals with components in ``[-1, 1]``.
    """
    # np.gradient needs at least two samples along the axis; a size-1
    # axis has no measurable slope, so treat its derivative as zero.
    dz_dx = np.gradient(depth, axis=1) if depth.shape[1] > 1 else np.zeros_like(depth)
    dz_dy = np.gradient(depth, axis=0) if depth.shape[0] > 1 else np.zeros_like(depth)

    normals = np.stack([-dz_dx, -dz_dy, np.ones_like(depth)], axis=-1)

    # The epsilon guards the division; the z component is 1 so the norm is
    # never actually zero for finite input.
    lengths = np.linalg.norm(normals, axis=-1, keepdims=True)
    return normals / (lengths + 1e-8)
def estimate_room_layout(
    self, depth: np.ndarray, img: np.ndarray
) -> Dict:
    """Estimate room layout from a depth map.

    The depth map is back-projected into a camera-centred point cloud
    (x right, y up, z forward) and simple plane detectors are run on it:
        - Floor plane (lowest y values)
        - Ceiling plane (highest y values)
        - Wall planes (extremes in x and z)

    Args:
        depth: Depth map ``[H, W]``.
        img: Image pixels; currently unused by this method.

    Returns:
        Dict with ``floor``, ``ceiling``, ``walls``, the ``point_cloud``
        ``[H, W, 3]`` and a ``dimensions`` dict (``width``, ``depth``,
        ``height``) derived from the cloud extents and the two
        horizontal planes.
    """
    H, W = depth.shape

    # Assume standard camera intrinsics (can be refined later): focal
    # length equal to the longer image side, principal point at the centre.
    fx = fy = max(W, H)
    cx, cy = W / 2, H / 2

    # Broadcast row/column index vectors instead of materialising a full
    # meshgrid.
    u = np.arange(W)[np.newaxis, :]
    v = np.arange(H)[:, np.newaxis]
    z = depth
    x = (u - cx) * z / fx
    # Image rows grow downwards, but the plane detectors (and their
    # fallbacks: floor at y=0 with normal +y, ceiling at y=2.7) treat +y
    # as "up". Flip the sign so that the bottom of the image maps to the
    # lowest y and is the region picked up as floor.
    y = (cy - v) * z / fy
    points = np.stack([x, y, z], axis=-1)  # [H, W, 3]

    floor_plane = self._detect_floor_plane(points)
    ceiling_plane = self._detect_ceiling_plane(points)
    wall_planes = self._detect_wall_planes(points)

    return {
        "floor": floor_plane,
        "ceiling": ceiling_plane,
        "walls": wall_planes,
        "point_cloud": points,
        "dimensions": {
            "width": float(np.max(x) - np.min(x)),
            "depth": float(np.max(z) - np.min(z)),
            "height": float(ceiling_plane["height"] - floor_plane["height"]),
        },
    }
def _detect_floor_plane(self, points: np.ndarray) -> Dict:
"""Detect floor plane (lowest y values, near-horizontal normal)."""
# Simple heuristic: lowest 20% of points
y_values = points[:, :, 1].flatten()
threshold = np.percentile(y_values, 20)
floor_mask = points[:, :, 1] < threshold
floor_points = points[floor_mask]
if len(floor_points) < 100:
# Fallback: assume floor at y=0
return {"normal": [0, 1, 0], "height": 0.0, "points": None}
# Fit plane
centroid = np.mean(floor_points, axis=0)
centered = floor_points - centroid
_, _, vh = np.linalg.svd(centered, full_matrices=False)
normal = vh[-1] # smallest singular value direction
# Ensure normal points up
if normal[1] < 0:
normal = -normal
return {
"normal": normal.tolist(),
"height": float(centroid[1]),
"centroid": centroid.tolist(),
"points": floor_points.shape[0],
}
def _detect_ceiling_plane(self, points: np.ndarray) -> Dict:
"""Detect ceiling plane (highest y values)."""
y_values = points[:, :, 1].flatten()
threshold = np.percentile(y_values, 90)
ceiling_mask = points[:, :, 1] > threshold
ceiling_points = points[ceiling_mask]
if len(ceiling_points) < 100:
# Fallback: typical room height ~2.7m
return {"normal": [0, -1, 0], "height": 2.7, "points": None}
centroid = np.mean(ceiling_points, axis=0)
return {
"normal": [0, -1, 0],
"height": float(centroid[1]),
"points": ceiling_points.shape[0],
}
def _detect_wall_planes(self, points: np.ndarray) -> List[Dict]:
"""Detect wall planes from remaining points."""
# Simplified: detect 4 walls for rectangular rooms
# In production, use proper RANSAC or SpatialLM
x = points[:, :, 0]
z = points[:, :, 2]
walls = []
# Left wall (minimum x)
x_min = np.percentile(x.flatten(), 5)
left_mask = np.abs(x - x_min) < 0.3
if np.sum(left_mask) > 100:
walls.append({
"normal": [1, 0, 0],
"position": float(x_min),
"direction": "left",
})
# Right wall (maximum x)
x_max = np.percentile(x.flatten(), 95)
right_mask = np.abs(x - x_max) < 0.3
if np.sum(right_mask) > 100:
walls.append({
"normal": [-1, 0, 0],
"position": float(x_max),
"direction": "right",
})
# Back wall (minimum z)
z_min = np.percentile(z.flatten(), 5)
back_mask = np.abs(z - z_min) < 0.3
if np.sum(back_mask) > 100:
walls.append({
"normal": [0, 0, 1],
"position": float(z_min),
"direction": "back",
})
# Front wall (maximum z)
z_max = np.percentile(z.flatten(), 95)
front_mask = np.abs(z - z_max) < 0.3
if np.sum(front_mask) > 100:
walls.append({
"normal": [0, 0, -1],
"position": float(z_max),
"direction": "front",
})
return walls
def segment_image(self, image: Image.Image) -> np.ndarray:
    """Return a coarse, position-based segmentation of the image.

    Placeholder for a fine-tuned indoor segmentation model. Labels are
    assigned purely from pixel position:
        0 = unlabelled (walls + objects in the centre)
        1 = floor      (bottom 30% of the rows)
        2 = ceiling    (top 10% of the rows)
        3 = left wall  (leftmost 15% of the columns, between the two)
        4 = right wall (rightmost 15% of the columns, between the two)

    Args:
        image: Input image (anything ``np.array`` turns into ``[H, W, ...]``).

    Returns:
        ``int32`` label map of shape ``[H, W]``.
    """
    img_np = np.array(image)
    H, W = img_np.shape[:2]

    seg = np.zeros((H, W), dtype=np.int32)

    floor_threshold = int(H * 0.7)
    seg[floor_threshold:] = 1  # floor

    ceiling_threshold = int(H * 0.1)
    seg[:ceiling_threshold] = 2  # ceiling

    wall_threshold = int(W * 0.15)
    seg[ceiling_threshold:floor_threshold, :wall_threshold] = 3  # left wall
    # Index from the left edge: for very narrow images the band width is
    # 0, and a negative-style slice "-0:" would select every column and
    # paint the whole middle band as right wall.
    seg[ceiling_threshold:floor_threshold, W - wall_threshold:] = 4  # right wall

    return seg
def detect_and_isolate_objects(
    self,
    image: Image.Image,
    depth: np.ndarray,
    semantic_seg: np.ndarray,
) -> Dict:
    """Detect and isolate furniture objects from the scene.

    Placeholder for a SAM + fine-tuned detector pipeline. Objects are
    found as connected regions of strong depth discontinuity inside the
    floor area (``semantic_seg == 1``).

    Args:
        image: Source image; used for its size and to cut the crops.
        depth: Depth map ``[H, W]`` aligned with ``image``.
        semantic_seg: Label map ``[H, W]``; label 1 marks the floor.

    Returns:
        Dict mapping object_id (connected-component label minus one) -> {
            crop: PIL Image cut to the padded bounding box,
            mask: boolean component mask restricted to that box,
            bbox: [x1, y1, x2, y2] padded by 20 px and clipped to the image,
            class_name: str (always "furniture" for now),
            depth_range: [min, max] depth over the component,
        }
        At most 10 objects are returned; components smaller than 100
        pixels are ignored.
    """
    from scipy.ndimage import find_objects, label

    max_objects = 10
    min_pixels = 100
    pad = 20

    # The image size is all that is needed here; no pixel copy required.
    W, H = image.size

    # Keep depth only on the floor so discontinuities elsewhere are ignored.
    depth_floor = depth.copy()
    depth_floor[semantic_seg != 1] = 0

    # L1 magnitude of the depth gradient, thresholded at its 85th percentile.
    grad_rows, grad_cols = np.gradient(depth_floor)
    depth_grad = np.abs(grad_rows) + np.abs(grad_cols)
    object_mask = depth_grad > np.percentile(depth_grad, 85)

    labeled, num_features = label(object_mask)

    # One pass for all component sizes and bounding slices, rather than
    # one full-image comparison per label.
    sizes = np.bincount(labeled.ravel(), minlength=num_features + 1)
    boxes = find_objects(labeled)

    # Apply the object cap to components that pass the size filter;
    # capping the raw label range would let tiny components use up the
    # budget and hide larger ones that come later in scan order.
    keep = np.flatnonzero(sizes[1:] >= min_pixels)[:max_objects] + 1

    detected_objects = {}
    for label_id in (int(k) for k in keep):
        row_span, col_span = boxes[label_id - 1]

        # Padded bounding box, clipped to the image.
        x1 = max(0, col_span.start - pad)
        y1 = max(0, row_span.start - pad)
        x2 = min(W, col_span.stop - 1 + pad)
        y2 = min(H, row_span.stop - 1 + pad)

        component = labeled[row_span, col_span] == label_id
        obj_depth = depth[row_span, col_span][component]

        detected_objects[label_id - 1] = {
            "crop": image.crop((x1, y1, x2, y2)),
            "mask": labeled[y1:y2, x1:x2] == label_id,
            "bbox": [x1, y1, x2, y2],
            "class_name": "furniture",  # would be classified in production
            "depth_range": [float(obj_depth.min()), float(obj_depth.max())],
        }

    return detected_objects
def classify_room_type(self, image: Image.Image) -> str:
    """Guess the room type from the image's average colour.

    Placeholder heuristic, to be replaced by CLIP or a fine-tuned
    classifier. The per-channel means are compared with a margin of 20:
    a blue cast yields "kitchen", a red cast yields "bedroom", anything
    else is reported as "living_room".
    """
    channel_means = np.array(image).mean(axis=(0, 1))
    red, green, blue = channel_means[0], channel_means[1], channel_means[2]

    # Blue dominates red -> maybe kitchen/bathroom.
    if blue > red + 20:
        return "kitchen"
    # Red dominates green -> maybe bedroom.
    if red > green + 20:
        return "bedroom"
    return "living_room"
def classify_style(self, image: Image.Image) -> str:
    """Guess the interior design style from simple colour statistics.

    Placeholder heuristic, to be replaced by a fine-tuned style
    classifier (target styles: modern, scandinavian, luxury, indian,
    commercial, minimalist). Current rules, in order:
        - low colour variation (mean channel std < 30) -> "scandinavian"
        - every channel mean below 80 (dark image)     -> "luxury"
        - everything else                              -> "modern"
    """
    pixels = np.array(image)
    channel_mean = pixels.mean(axis=(0, 1))
    channel_std = pixels.std(axis=(0, 1))

    if channel_std.mean() < 30:
        return "scandinavian"

    # A bright image (overall mean > 180) can never have all three channel
    # means below 80, so no separate "bright" test is needed before this.
    is_dark = channel_mean[0] < 80 and channel_mean[1] < 80 and channel_mean[2] < 80
    return "luxury" if is_dark else "modern"