diff --git "a/app.py" "b/app.py"
--- "a/app.py"
+++ "b/app.py"
@@ -1,9 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
import gc
import os
import shutil
import sys
import time
from datetime import datetime
+from pathlib import Path
+from collections import defaultdict
+from typing import List, Dict, Tuple
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
@@ -12,11 +20,12 @@ import gradio as gr
import numpy as np
import spaces
import torch
+import trimesh
from PIL import Image
from pillow_heif import register_heif_opener
+from sklearn.cluster import DBSCAN
register_heif_opener()
-
sys.path.append("mapanything/")
from mapanything.utils.geometry import depthmap_to_world_frame, points_to_normals
@@ -27,10 +36,14 @@ from mapanything.utils.hf_utils.css_and_html import (
get_gradio_theme,
)
from mapanything.utils.hf_utils.hf_helpers import initialize_mapanything_model
-from mapanything.utils.hf_utils.viz import predictions_to_glb
+from mapanything.utils.hf_utils.visual_util import predictions_to_glb
from mapanything.utils.image import load_images, rgb
+# ============================================================================
+# Global Configuration
+# ============================================================================
+
# MapAnything Configuration
high_level_config = {
"path": "configs/train.yaml",
@@ -51,13 +64,616 @@ high_level_config = {
"resolution": 518,
}
-# Initialize model - this will be done on GPU when needed
+# GroundingDINO Configuration
+# Hugging Face model id handed to `from_pretrained` in load_grounding_dino_model().
+GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
+# Passed as `threshold` / `text_threshold` to the processor's
+# post_process_grounded_object_detection() in run_grounding_dino_detection().
+GROUNDING_DINO_BOX_THRESHOLD = 0.25
+GROUNDING_DINO_TEXT_THRESHOLD = 0.2
+
+# SAM Configuration
+# Hugging Face model id handed to `from_pretrained` in load_sam_model().
+SAM_MODEL_ID = "facebook/sam-vit-huge"
+
+# Default value of run_model()'s `text_prompt`; category names are separated by " . ".
+DEFAULT_TEXT_PROMPT = "window . table . sofa . tv . book . door"
+
+# Broader prompt in the same " . "-separated format.
+# NOTE(review): not referenced anywhere in the visible part of this file.
+COMMON_OBJECTS_PROMPT = (
+    "person . face . hand . "
+    "chair . sofa . couch . bed . table . desk . cabinet . shelf . drawer . "
+    "door . window . wall . floor . ceiling . curtain . "
+    "tv . monitor . screen . computer . laptop . keyboard . mouse . "
+    "phone . tablet . remote . "
+    "lamp . light . chandelier . "
+    "book . magazine . paper . pen . pencil . "
+    "bottle . cup . glass . mug . plate . bowl . fork . knife . spoon . "
+    "vase . plant . flower . pot . "
+    "clock . picture . frame . mirror . "
+    "pillow . cushion . blanket . towel . "
+    "bag . backpack . suitcase . "
+    "box . basket . container . "
+    "shoe . hat . coat . "
+    "toy . ball . "
+    "car . bicycle . motorcycle . bus . truck . "
+    "tree . grass . sky . cloud . sun . "
+    "dog . cat . bird . "
+    "building . house . bridge . road . street . "
+    "sign . pole . bench"
+)
+
+# DBSCAN Clustering Configuration
+# Base `eps` per normalized label, looked up in match_objects_across_views();
+# labels without an entry use 'default'. The value may still be rescaled by
+# compute_adaptive_eps(). Distances are in the units of the predicted world
+# points (presumably metres - TODO confirm).
+DBSCAN_EPS_CONFIG = {
+    'sofa': 1.5,
+    'bed': 1.5,
+    'couch': 1.5,
+    'desk': 0.8,
+    'table': 0.8,
+    'chair': 0.6,
+    'cabinet': 0.8,
+    'window': 0.5,
+    'door': 0.6,
+    'tv': 0.6,
+    'default': 1.0
+}
+
+# Passed as `min_samples` to DBSCAN in match_objects_across_views().
+DBSCAN_MIN_SAMPLES = 1
+# When False, run_model() stores None instead of calling extract_visual_features().
+ENABLE_VISUAL_FEATURES = False
+
+# Segmentation Quality Control
+# Applied in create_multi_view_segmented_mesh(): detections below this confidence,
+# or whose mask covers fewer than MIN_MASK_AREA pixels, are not painted.
+MIN_DETECTION_CONFIDENCE = 0.35
+MIN_MASK_AREA = 100
+
+# Matching score calculation configuration
+# NOTE(review): not referenced anywhere in the visible part of this file.
+MATCH_3D_DISTANCE_THRESHOLD = 2.5
+
+# Global model variables
+# All start as None and are filled lazily: `model` inside run_model(), the
+# GroundingDINO pair by load_grounding_dino_model(), and `sam_predictor`
+# (a dict with 'model' and 'processor' keys) by load_sam_model().
model = None
+grounding_dino_model = None
+grounding_dino_processor = None
+sam_predictor = None
+
+
+# ============================================================================
+# Segmentation Model Loading
+# ============================================================================
+
+def load_grounding_dino_model(device):
+    """Load GroundingDINO and its processor into the module-level globals.
+
+    A second call is a no-op apart from a log line. Loading errors are
+    printed together with a traceback and then swallowed, so callers must
+    check ``grounding_dino_model`` to find out whether loading worked.
+
+    Args:
+        device: Target passed to ``.to(device)`` for the detection model.
+    """
+    global grounding_dino_model, grounding_dino_processor
+
+    already_loaded = grounding_dino_model is not None
+    if already_loaded:
+        print("GroundingDINO already loaded")
+        return
+
+    try:
+        from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
+
+        print(f"Loading GroundingDINO from HuggingFace: {GROUNDING_DINO_MODEL_ID}")
+        grounding_dino_processor = AutoProcessor.from_pretrained(GROUNDING_DINO_MODEL_ID)
+        detector = AutoModelForZeroShotObjectDetection.from_pretrained(
+            GROUNDING_DINO_MODEL_ID
+        )
+        grounding_dino_model = detector.to(device).eval()
+        print("GroundingDINO loaded successfully")
+    except Exception as exc:
+        import traceback
+
+        print(f"Failed to load GroundingDINO: {exc}")
+        traceback.print_exc()
+
+def load_sam_model(device):
+    """Load SAM and its processor into the module-level ``sam_predictor``.
+
+    On success ``sam_predictor`` becomes a dict with the keys ``'model'`` and
+    ``'processor'``. A second call is a no-op apart from a log line. Loading
+    errors are printed with a traceback and swallowed, which leaves
+    ``sam_predictor`` untouched.
+
+    Args:
+        device: Target passed to ``.to(device)`` for the SAM model.
+    """
+    global sam_predictor
+
+    if sam_predictor is not None:
+        print("SAM already loaded")
+        return
+
+    try:
+        from transformers import SamModel, SamProcessor
+
+        print(f"Loading SAM from HuggingFace: {SAM_MODEL_ID}")
+        segmenter = SamModel.from_pretrained(SAM_MODEL_ID).to(device).eval()
+        preprocessor = SamProcessor.from_pretrained(SAM_MODEL_ID)
+        sam_predictor = dict(model=segmenter, processor=preprocessor)
+        print("SAM loaded successfully")
+    except Exception as exc:
+        import traceback
+
+        print(f"Failed to load SAM: {exc}")
+        print("SAM functionality will be disabled, bounding boxes will be used as masks instead")
+        traceback.print_exc()
+
+
+# ============================================================================
+# Segmentation Functions
+# ============================================================================
+
+def generate_distinct_colors(n):
+    """Return ``n`` RGB tuples whose hues are evenly spaced around the colour wheel.
+
+    Saturation and value are fixed at 0.9 and 0.95. Each channel is an int
+    obtained by truncating ``channel * 255``. ``n == 0`` yields an empty list.
+    """
+    import colorsys
+
+    if n == 0:
+        return []
+    return [
+        tuple(
+            int(channel * 255)
+            for channel in colorsys.hsv_to_rgb(step / max(n, 1), 0.9, 0.95)
+        )
+        for step in range(n)
+    ]
+
+def run_grounding_dino_detection(image_np, text_prompt, device):
+    """Run text-prompted object detection on one image with GroundingDINO.
+
+    Args:
+        image_np: Image array. Non-uint8 input is multiplied by 255 and cast
+            to uint8 (assumes values in [0, 1] - TODO confirm with callers).
+        text_prompt: Prompt string forwarded to the processor.
+        device: Device the processor outputs are moved to.
+
+    Returns:
+        A list of dicts with the keys ``'bbox'`` (the box as a Python list;
+        callers unpack it as x1, y1, x2, y2), ``'label'`` and ``'confidence'``.
+        An empty list is returned when the model is not loaded or when
+        anything raises; the error is printed with a traceback.
+    """
+    if grounding_dino_model is None or grounding_dino_processor is None:
+        print("GroundingDINO is not loaded")
+        return []
+
+    try:
+        print(f"GroundingDINO Detection: {text_prompt}")
+        pixels = image_np if image_np.dtype == np.uint8 else (image_np * 255).astype(np.uint8)
+        pil_image = Image.fromarray(pixels)
+
+        encoded = grounding_dino_processor(images=pil_image, text=text_prompt, return_tensors="pt")
+        inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
+
+        with torch.no_grad():
+            outputs = grounding_dino_model(**inputs)
+
+        # PIL reports (width, height); the post-processor wants (height, width).
+        width, height = pil_image.size
+        results = grounding_dino_processor.post_process_grounded_object_detection(
+            outputs,
+            inputs["input_ids"],
+            threshold=GROUNDING_DINO_BOX_THRESHOLD,
+            text_threshold=GROUNDING_DINO_TEXT_THRESHOLD,
+            target_sizes=[(height, width)],
+        )[0]
+
+        boxes = results["boxes"].cpu().numpy()
+        scores = results["scores"].cpu().numpy()
+        labels = results["labels"]
+
+        print(f"Detected {len(boxes)} objects")
+        detections = []
+        for box, score, label in zip(boxes, scores, labels):
+            detections.append({'bbox': box.tolist(), 'label': label, 'confidence': float(score)})
+            print(f" - {label}: {score:.2f}")
+        return detections
+    except Exception as exc:
+        import traceback
+
+        print(f"GroundingDINO detection failed: {exc}")
+        traceback.print_exc()
+        return []
+
+def _bbox_fallback_masks(image_shape, boxes):
+    """Build one rectangular boolean mask per box, clamped to the image bounds."""
+    height, width = image_shape[:2]
+    masks = []
+    for box in boxes:
+        x1, y1, x2, y2 = map(int, box)
+        # Clamp before slicing: a negative start index would wrap around to
+        # the far edge of the image and produce an empty or misplaced mask.
+        x1, x2 = max(0, min(width, x1)), max(0, min(width, x2))
+        y1, y2 = max(0, min(height, y1)), max(0, min(height, y2))
+        mask = np.zeros((height, width), dtype=bool)
+        mask[y1:y2, x1:x2] = True
+        masks.append(mask)
+    return masks
+
+
+def run_sam_refinement(image_np, boxes):
+    """Turn detection boxes into per-object boolean masks using SAM.
+
+    When SAM is not loaded, or when segmentation raises, every box is
+    converted into a filled rectangle instead, so the caller always gets one
+    mask per box.
+
+    Args:
+        image_np: Image array of shape (H, W, ...). Non-uint8 input is
+            multiplied by 255 and cast to uint8 before being sent to SAM.
+        boxes: Iterable of boxes unpacked as (x1, y1, x2, y2); values are
+            truncated to int.
+
+    Returns:
+        A list of boolean (H, W) masks, in the same order as ``boxes``.
+    """
+    if sam_predictor is None:
+        print("SAM is not loaded, using bbox as mask")
+        return _bbox_fallback_masks(image_np.shape, boxes)
+    try:
+        print(f"SAM accurate segmentation for {len(boxes)} regions...")
+        from PIL import Image
+        sam_model = sam_predictor['model']
+        sam_processor = sam_predictor['processor']
+        device = sam_model.device
+
+        if image_np.dtype == np.uint8:
+            pil_image = Image.fromarray(image_np)
+        else:
+            pil_image = Image.fromarray((image_np * 255).astype(np.uint8))
+
+        masks = []
+        for box in boxes:
+            x1, y1, x2, y2 = map(int, box)
+            # One image, one box: SAM expects a nested [image][box][coords] list.
+            input_boxes = [[[x1, y1, x2, y2]]]
+
+            inputs = sam_processor(pil_image, input_boxes=input_boxes, return_tensors="pt")
+            inputs = {k: v.to(device) for k, v in inputs.items()}
+
+            with torch.no_grad():
+                outputs = sam_model(**inputs)
+
+            # Keep the first mask of the first box of the first image.
+            pred_masks = sam_processor.image_processor.post_process_masks(
+                outputs.pred_masks.cpu(),
+                inputs["original_sizes"].cpu(),
+                inputs["reshaped_input_sizes"].cpu()
+            )[0][0][0]
+
+            masks.append(pred_masks.numpy() > 0.5)
+        print("SAM segmentation completed")
+        return masks
+    except Exception as e:
+        print(f"SAM segmentation failed: {e}")
+        import traceback
+        traceback.print_exc()
+        return _bbox_fallback_masks(image_np.shape, boxes)
+
+def normalize_label(label):
+    """Reduce a detector label to a canonical, lower-case, singular key.
+
+    The key is used to group detections of the same category and to look up
+    per-category clustering settings.
+
+    Rules, in order:
+      1. If the label contains one of the priority furniture words as a
+         substring, that word is returned (so "wooden table" -> "table";
+         note that "tablet" also maps to "table").
+      2. Otherwise only the first word is kept and a simple English plural
+         is stripped. Words ending in "ss" or "us" are left alone so that
+         "glass" and "bus" are not mangled, and "-sses/-xes/-ches/-shes"
+         lose "es" so that "glasses", "boxes" and "couches" map to their
+         singular form.
+
+    Args:
+        label: Raw label text.
+
+    Returns:
+        The normalized key; an empty string for an empty or blank label.
+    """
+    label = label.strip().lower()
+    priority_labels = ['sofa', 'bed', 'table', 'desk', 'chair', 'cabinet', 'window', 'door']
+    for priority in priority_labels:
+        if priority in label:
+            return priority
+    if not label:
+        return label
+
+    first_word = label.split()[0]
+    if len(first_word) <= 1 or not first_word.endswith('s'):
+        return first_word
+    if first_word.endswith(('ss', 'us')):
+        # Already singular ("glass", "grass", "bus").
+        return first_word
+    if first_word.endswith(('sses', 'xes', 'ches', 'shes')):
+        return first_word[:-2]
+    if first_word.endswith('ies'):
+        return first_word[:-3] + 'y'
+    if first_word.endswith('ves'):
+        return first_word[:-3] + 'f'
+    return first_word[:-1]
+
+def labels_match(label1, label2):
+    """Return True when both labels reduce to the same normalized key."""
+    first_key = normalize_label(label1)
+    second_key = normalize_label(label2)
+    return first_key == second_key
+
+def compute_object_3d_center(points, mask):
+    """Return the per-axis median of the points selected by ``mask``.
+
+    Args:
+        points: Array indexed by ``mask``; the median is taken over axis 0 of
+            the selection.
+        mask: Boolean index into ``points``.
+
+    Returns:
+        The median point, or None when the mask selects nothing.
+    """
+    selected = points[mask]
+    return np.median(selected, axis=0) if len(selected) else None
+
+def compute_3d_bbox_iou(center1, size1, center2, size2):
+    """Compute the intersection-over-union of two axis-aligned boxes.
+
+    Each box is given by its centre and its full extent along every axis.
+    Arrays, lists and tuples are accepted.
+
+    Returns:
+        The IoU as a float. 0.0 is returned when the union has zero volume,
+        when any argument is None, or when the inputs cannot be interpreted
+        as numeric arrays of compatible shape.
+    """
+    if any(value is None for value in (center1, size1, center2, size2)):
+        return 0.0
+    try:
+        center1, size1, center2, size2 = (
+            np.asarray(value, dtype=float) for value in (center1, size1, center2, size2)
+        )
+        min1, max1 = center1 - size1 / 2, center1 + size1 / 2
+        min2, max2 = center2 - size2 / 2, center2 + size2 / 2
+
+        # Overlap extent per axis; negative means the boxes are apart there.
+        inter_size = np.maximum(0, np.minimum(max1, max2) - np.maximum(min1, min2))
+        inter_volume = np.prod(inter_size)
+        union_volume = np.prod(size1) + np.prod(size2) - inter_volume
+
+        if union_volume == 0:
+            return 0.0
+        return float(inter_volume / union_volume)
+    except (TypeError, ValueError):
+        # Non-numeric input or shapes that do not broadcast: treat as no overlap.
+        return 0.0
+
+def compute_2d_mask_iou(mask1, mask2):
+    """Compute the intersection-over-union of two boolean masks.
+
+    Returns:
+        The IoU as a float. 0.0 is returned when both masks are empty or
+        when the masks cannot be combined (for example mismatched shapes).
+    """
+    try:
+        intersection = np.logical_and(mask1, mask2).sum()
+        union = np.logical_or(mask1, mask2).sum()
+        if union == 0:
+            return 0.0
+        return float(intersection / union)
+    except (TypeError, ValueError):
+        # Incompatible shapes or non-array input: treat as no overlap.
+        return 0.0
+
+def extract_visual_features(image, mask, encoder):
+    """Embed the image region covered by ``mask`` with ``encoder``.
+
+    The tight bounding rectangle of the mask is cropped, resized to 224x224,
+    normalized with the ImageNet mean/std and fed through the encoder
+    (``forward_features`` when available, otherwise a plain call). Spatial or
+    token dimensions are averaged and the result is L2-normalized.
+
+    Args:
+        image: (H, W, C) array. Float input with max <= 1.0 is scaled by 255.
+        mask: Boolean (H, W) array selecting the object.
+        encoder: Module used to compute features.
+
+    Returns:
+        A 1-D numpy feature vector, or None when the mask is empty, when its
+        bounding rectangle is a single row or column, when the encoder output
+        has an unsupported type, or when anything raises (the error is
+        printed with a traceback).
+    """
+    try:
+        coords = np.argwhere(mask)
+        if len(coords) == 0:
+            return None
+        y_min, x_min = coords.min(axis=0)
+        y_max, x_max = coords.max(axis=0)
+
+        # Rejects degenerate rectangles that are one pixel tall or wide.
+        if y_max <= y_min or x_max <= x_min:
+            return None
+
+        cropped = image[y_min:y_max+1, x_min:x_max+1]
+        if cropped.dtype == np.float32 or cropped.dtype == np.float64:
+            if cropped.max() <= 1.0:
+                cropped = (cropped * 255).astype(np.uint8)
+            else:
+                cropped = cropped.astype(np.uint8)
+
+        from PIL import Image
+        import torchvision.transforms as T
+
+        pil_img = Image.fromarray(cropped)
+        pil_img = pil_img.resize((224, 224), Image.BILINEAR)
+
+        transform = T.Compose([
+            T.ToTensor(),
+            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+
+        try:
+            device = next(encoder.parameters()).device
+        except (AttributeError, StopIteration, TypeError):
+            # Encoder exposes no parameters: fall back to the default device.
+            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+        img_tensor = transform(pil_img).unsqueeze(0).to(device)
+
+        with torch.no_grad():
+            if hasattr(encoder, 'forward_features'):
+                features = encoder.forward_features(img_tensor)
+            else:
+                features = encoder(img_tensor)
+
+        if not isinstance(features, torch.Tensor):
+            if isinstance(features, dict):
+                features = features.get('x', features.get('last_hidden_state', None))
+                if features is None:
+                    return None
+            elif hasattr(features, 'data'):
+                features = features.data
+            else:
+                return None
+
+        # Collapse (B, C, H, W) over space or (B, N, C) over tokens to (B, C).
+        if len(features.shape) == 4:
+            features = features.mean(dim=[2, 3])
+        elif len(features.shape) == 3:
+            features = features.mean(dim=1)
+
+        features = features / (features.norm(dim=1, keepdim=True) + 1e-8)
+        return features.cpu().numpy()[0]
+    except Exception as e:
+        import traceback
+        print(f"Feature extraction failed: {type(e).__name__}: {e}")
+        traceback.print_exc()
+        return None
+
+def compute_feature_similarity(feat1, feat2):
+    """Return the dot product of two feature vectors.
+
+    For L2-normalized vectors this is their cosine similarity.
+
+    Returns:
+        The dot product, or 0.0 when either vector is None or when the two
+        cannot be multiplied (for example mismatched lengths).
+    """
+    if feat1 is None or feat2 is None:
+        return 0.0
+    try:
+        return np.dot(feat1, feat2)
+    except (TypeError, ValueError):
+        # Incompatible shapes or non-numeric input: treat as dissimilar.
+        return 0.0
+
+def compute_adaptive_eps(centers, base_eps):
+    """Scale the DBSCAN radius to how spread out the candidate centres are.
+
+    With ``m`` the median pairwise distance between centres:
+      * ``m > 2 * base_eps``  -> ``min(0.6 * m, 2.5 * base_eps)``
+      * ``m > base_eps``      -> ``0.5 * m``
+      * otherwise             -> ``base_eps``
+
+    Args:
+        centers: Sequence of points, one row per detection.
+        base_eps: Radius to use when there are fewer than two centres or
+            when the centres are already close together.
+    """
+    if len(centers) <= 1:
+        return base_eps
+
+    from scipy.spatial.distance import pdist
+
+    pairwise = pdist(centers)
+    if len(pairwise) == 0:
+        return base_eps
+
+    typical_gap = np.median(pairwise)
+    if typical_gap > base_eps * 2:
+        return min(typical_gap * 0.6, base_eps * 2.5)
+    if typical_gap > base_eps:
+        return typical_gap * 0.5
+    return base_eps
+
+def match_objects_across_views(all_view_detections):
+    """Group per-view detections into scene-level objects.
+
+    Detections are bucketed by normalized label. Inside each bucket the 3D
+    centres are clustered with DBSCAN (radius chosen per label and adapted to
+    the spread of the centres); every cluster becomes one object whose centre
+    is the confidence-weighted mean of its members. A label seen only once,
+    and any point DBSCAN marks as noise, becomes an object of its own.
+
+    Args:
+        all_view_detections: One list per view of detection dicts. Each dict
+            needs ``'label'`` and ``'confidence'``; detections whose
+            ``'center_3d'`` is missing or None are ignored.
+
+    Returns:
+        ``(object_id_map, unique_objects)`` where
+        ``object_id_map[view_idx][det_idx]`` is the global object id and
+        ``unique_objects`` is a list of dicts with the keys ``'global_id'``,
+        ``'label'``, ``'views'`` (list of ``(view_idx, det_idx)``) and
+        ``'center_3d'``. ``object_id_map`` is always a ``defaultdict(dict)``,
+        so indexing it by a view without matches yields an empty dict.
+    """
+    print("\nMatching objects using adaptive DBSCAN clustering...")
+    objects_by_label = defaultdict(list)
+
+    for view_idx, detections in enumerate(all_view_detections):
+        for det_idx, det in enumerate(detections):
+            if det.get('center_3d') is None:
+                continue
+            norm_label = normalize_label(det['label'])
+            objects_by_label[norm_label].append({
+                'view_idx': view_idx,
+                'det_idx': det_idx,
+                'label': det['label'],
+                'norm_label': norm_label,
+                'center_3d': det['center_3d'],
+                'confidence': det['confidence'],
+                'bbox_3d': det.get('bbox_3d'),
+            })
+
+    object_id_map = defaultdict(dict)
+    unique_objects = []
+
+    if not objects_by_label:
+        # Same mapping type as the normal path, so callers can index it by
+        # view without hitting a KeyError.
+        return object_id_map, unique_objects
+
+    def register(label, members, center):
+        """Record one scene-level object and point its member detections at it."""
+        global_id = len(unique_objects)  # ids are assigned densely from 0
+        unique_objects.append({
+            'global_id': global_id,
+            'label': label,
+            'views': [(m['view_idx'], m['det_idx']) for m in members],
+            'center_3d': center,
+        })
+        for member in members:
+            object_id_map[member['view_idx']][member['det_idx']] = global_id
+
+    for norm_label, objects in objects_by_label.items():
+        print(f"Processing {norm_label}: {len(objects)} detections")
+        if len(objects) == 1:
+            only = objects[0]
+            register(only['label'], [only], only['center_3d'])
+            continue
+
+        centers = np.array([obj['center_3d'] for obj in objects])
+        base_eps = DBSCAN_EPS_CONFIG.get(norm_label, DBSCAN_EPS_CONFIG.get('default', 1.0))
+        eps = compute_adaptive_eps(centers, base_eps)
+
+        clustering = DBSCAN(eps=eps, min_samples=DBSCAN_MIN_SAMPLES, metric='euclidean')
+        cluster_labels = clustering.fit_predict(centers)
+
+        for cluster_id in set(cluster_labels):
+            members = [obj for obj, assigned in zip(objects, cluster_labels) if assigned == cluster_id]
+            if cluster_id == -1:
+                # DBSCAN noise: every such detection stays a separate object.
+                for obj in members:
+                    register(obj['label'], [obj], obj['center_3d'])
+                continue
+
+            total_conf = sum(o['confidence'] for o in members)
+            weighted_center = sum(o['center_3d'] * o['confidence'] for o in members) / total_conf
+            register(members[0]['label'], members, weighted_center)
+
+    return object_id_map, unique_objects
+
+def create_multi_view_segmented_mesh(processed_data, all_view_detections, all_view_masks,
+                                     object_id_map, unique_objects, target_dir, use_sam=True):
+    """Export a GLB mesh in which detected objects are painted in label colours.
+
+    For every view that has detections, the view image is recoloured under
+    each accepted segmentation mask (higher-confidence detections win where
+    masks overlap), turned into a mesh from its 3D point map, and all view
+    meshes are concatenated. Views without detections are left out.
+
+    Side effect: every dict in ``unique_objects`` gains the keys ``'color'``
+    and ``'normalized_label'``.
+
+    Args:
+        processed_data: Mapping view index -> dict with ``"image"``,
+            ``"points3d"`` and optionally ``"mask"`` and ``"normal"``.
+        all_view_detections: Per-view lists of detection dicts.
+        all_view_masks: Per-view lists of boolean masks, aligned with the
+            detections of the same view.
+        object_id_map: Mapping view index -> {detection index -> global id}.
+        unique_objects: Scene-level objects with ``'global_id'`` and ``'label'``.
+        target_dir: Directory the GLB file is written to.
+        use_sam: Currently unused; kept for interface compatibility.
+
+    Returns:
+        Path of ``multi_view_segmented_mesh.glb``, or None when no view
+        produced a mesh or when anything raises (the error is printed with a
+        traceback).
+    """
+    try:
+        print("\nGenerating multi-view segmented mesh...")
+        unique_normalized_labels = sorted(set(normalize_label(obj['label']) for obj in unique_objects))
+        colors = generate_distinct_colors(len(unique_normalized_labels))
+        label_colors = dict(zip(unique_normalized_labels, colors))
+
+        # Index objects by id once instead of scanning the list per detection.
+        objects_by_id = {}
+        for obj in unique_objects:
+            norm_label = normalize_label(obj['label'])
+            obj['color'] = label_colors[norm_label]
+            obj['normalized_label'] = norm_label
+            objects_by_id.setdefault(obj['global_id'], obj)
+
+        import utils3d
+        all_meshes = []
+
+        for view_idx in range(len(processed_data)):
+            view_data = processed_data[view_idx]
+            image = view_data["image"]
+            points3d = view_data["points3d"]
+            mask = view_data.get("mask")
+            normal = view_data.get("normal")
+
+            detections = all_view_detections[view_idx]
+            masks = all_view_masks[view_idx]
+
+            if len(detections) == 0:
+                continue
+
+            if image.dtype != np.uint8:
+                if image.max() <= 1.0:
+                    image = (image * 255).astype(np.uint8)
+                else:
+                    image = image.astype(np.uint8)
+
+            colored_image = image.copy()
+            confidence_map = np.zeros((image.shape[0], image.shape[1]), dtype=np.float32)
+
+            # A view that has no entry in the map simply has no matched detections.
+            view_ids = object_id_map.get(view_idx, {})
+
+            detections_info = []
+            for det_idx, (det, seg_mask) in enumerate(zip(detections, masks)):
+                if det['confidence'] < MIN_DETECTION_CONFIDENCE:
+                    continue
+                mask_area = seg_mask.sum()
+                if mask_area < MIN_MASK_AREA:
+                    continue
+                global_id = view_ids.get(det_idx)
+                if global_id is None:
+                    continue
+                unique_obj = objects_by_id.get(global_id)
+                if unique_obj is None:
+                    continue
+                detections_info.append({
+                    'mask': seg_mask,
+                    'color': unique_obj['color'],
+                    'confidence': det['confidence'],
+                    'label': det['label'],
+                    'area': mask_area
+                })
+
+            # Paint in ascending confidence; a pixel is only overwritten by a
+            # detection that is more confident than what is already there.
+            detections_info.sort(key=lambda x: x['confidence'])
+            for info in detections_info:
+                seg_mask = info['mask']
+                color = info['color']
+                conf = info['confidence']
+                update_mask = seg_mask & (conf > confidence_map)
+                colored_image[update_mask] = color
+                confidence_map[update_mask] = conf
+
+            height, width = image.shape[:2]
+            if normal is None:
+                faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh(
+                    points3d,
+                    colored_image.astype(np.float32) / 255,
+                    utils3d.numpy.image_uv(width=width, height=height),
+                    mask=mask if mask is not None else np.ones((height, width), dtype=bool),
+                    tri=True
+                )
+                vertex_normals = None
+            else:
+                faces, vertices, vertex_colors, vertex_uvs, vertex_normals = utils3d.numpy.image_mesh(
+                    points3d,
+                    colored_image.astype(np.float32) / 255,
+                    utils3d.numpy.image_uv(width=width, height=height),
+                    normal,
+                    mask=mask if mask is not None else np.ones((height, width), dtype=bool),
+                    tri=True
+                )
+
+            # Flip the Y and Z axes of positions (and normals) for export.
+            vertices = vertices * np.array([1, -1, -1], dtype=np.float32)
+            if vertex_normals is not None:
+                vertex_normals = vertex_normals * np.array([1, -1, -1], dtype=np.float32)
+
+            view_mesh = trimesh.Trimesh(
+                vertices=vertices,
+                faces=faces,
+                vertex_normals=vertex_normals,
+                vertex_colors=(vertex_colors * 255).astype(np.uint8),
+                process=False
+            )
+            all_meshes.append(view_mesh)
+
+        if len(all_meshes) == 0:
+            return None
+
+        combined_mesh = trimesh.util.concatenate(all_meshes)
+        glb_path = os.path.join(target_dir, 'multi_view_segmented_mesh.glb')
+        combined_mesh.export(glb_path)
+        return glb_path
+    except Exception as e:
+        print(f"Failed to generate multi-view mesh: {e}")
+        import traceback
+        traceback.print_exc()
+        return None
+
+
+# ============================================================================
+# Core Model Inference & View Processing
+# ============================================================================
+
+def process_predictions_for_visualization(
+    predictions, views, high_level_config, filter_black_bg=False, filter_white_bg=False
+):
+    """Collect per-view image, points, depth, normals and mask for the UI tabs.
+
+    Args:
+        predictions: Dict with ``"world_points"``, ``"final_mask"`` and
+            ``"depth"``, each indexable by view.
+        views: Loaded input views; ``view["img"]`` is de-normalized with
+            ``rgb`` using ``high_level_config["data_norm_type"]``.
+        high_level_config: Model configuration dict.
+        filter_black_bg: Also mask out pixels whose channel sum is below 16
+            on a 0-255 scale.
+        filter_white_bg: Also mask out pixels whose three channels all
+            exceed 240 on a 0-255 scale.
+
+    Returns:
+        Dict mapping view index to a dict with the keys ``"image"``,
+        ``"points3d"``, ``"depth"``, ``"normal"`` and ``"mask"``.
+    """
+    processed_data = {}
+    for view_idx, view in enumerate(views):
+        image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
+        first_image = image[0]
+        pred_pts3d = predictions["world_points"][view_idx]
+        mask = predictions["final_mask"][view_idx].copy()
+
+        if filter_black_bg or filter_white_bg:
+            # Bring colours to a 0-255 scale before thresholding.
+            view_colors = first_image * 255 if first_image.max() <= 1.0 else first_image
+
+        if filter_black_bg:
+            mask = mask & (view_colors.sum(axis=2) >= 16)
+
+        if filter_white_bg:
+            near_white = (
+                (view_colors[:, :, 0] > 240)
+                & (view_colors[:, :, 1] > 240)
+                & (view_colors[:, :, 2] > 240)
+            )
+            mask = mask & ~near_white
+
+        depth = predictions["depth"][view_idx].squeeze()
+        normals, _ = points_to_normals(pred_pts3d, mask=mask)
+        processed_data[view_idx] = {
+            "image": first_image,
+            "points3d": pred_pts3d,
+            "depth": depth,
+            "normal": normals,
+            "mask": mask,
+        }
+
+    return processed_data
-# -------------------------------------------------------------------------
-# 1) Core model inference
-# -------------------------------------------------------------------------
@spaces.GPU(duration=120)
def run_model(
target_dir,
@@ -65,29 +681,31 @@ def run_model(
mask_edges=True,
filter_black_bg=False,
filter_white_bg=False,
+ enable_segmentation=False,
+ text_prompt=DEFAULT_TEXT_PROMPT,
+ use_sam=True,
):
- """
- Run the MapAnything model on images in the 'target_dir/images' folder and return predictions.
- """
- global model
- import torch # Ensure torch is available in function scope
+ global model, grounding_dino_model, sam_predictor
+ import torch
print(f"Processing images from {target_dir}")
-
- # Device check
device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device)
- # Initialize model if not already done
if model is None:
+ print("Loading MapAnything from HuggingFace...")
model = initialize_mapanything_model(high_level_config, device)
-
+ print("MapAnything loaded successfully")
else:
model = model.to(device)
model.eval()
- # Load images using MapAnything's load_images function
+ if enable_segmentation:
+ load_grounding_dino_model(device)
+ if use_sam:
+ load_sam_model(device)
+
print("Loading images...")
image_folder_path = os.path.join(target_dir, "images")
views = load_images(image_folder_path)
@@ -96,19 +714,12 @@ def run_model(
if len(views) == 0:
raise ValueError("No images found. Check your upload.")
- # Run model inference
print("Running inference...")
- # apply_mask: Whether to apply the non-ambiguous mask to the output. Defaults to True.
- # mask_edges: Whether to compute an edge mask based on normals and depth and apply it to the output. Defaults to True.
- # Use checkbox values - mask_edges is set to True by default since there's no UI control for it
outputs = model.infer(
views, apply_mask=apply_mask, mask_edges=True, memory_efficient_inference=False
)
- # Convert predictions to format expected by visualization
predictions = {}
-
- # Initialize lists for the required keys
extrinsic_list = []
intrinsic_list = []
world_points_list = []
@@ -116,250 +727,295 @@ def run_model(
images_list = []
final_mask_list = []
- # Loop through the outputs
for pred in outputs:
- # Extract data from predictions
- depthmap_torch = pred["depth_z"][0].squeeze(-1) # (H, W)
- intrinsics_torch = pred["intrinsics"][0] # (3, 3)
- camera_pose_torch = pred["camera_poses"][0] # (4, 4)
-
- # Compute new pts3d using depth, intrinsics, and camera pose
+ depthmap_torch = pred["depth_z"][0].squeeze(-1)
+ intrinsics_torch = pred["intrinsics"][0]
+ camera_pose_torch = pred["camera_poses"][0]
+
pts3d_computed, valid_mask = depthmap_to_world_frame(
depthmap_torch, intrinsics_torch, camera_pose_torch
)
-
- # Convert to numpy arrays for visualization
- # Check if mask key exists in pred, if not, fill with boolean trues in the size of depthmap_torch
+
if "mask" in pred:
mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool)
else:
- # Fill with boolean trues in the size of depthmap_torch
mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool)
- # Combine with valid depth mask
mask = mask & valid_mask.cpu().numpy()
-
image = pred["img_no_norm"][0].cpu().numpy()
- # Append to lists
extrinsic_list.append(camera_pose_torch.cpu().numpy())
intrinsic_list.append(intrinsics_torch.cpu().numpy())
world_points_list.append(pts3d_computed.cpu().numpy())
depth_maps_list.append(depthmap_torch.cpu().numpy())
- images_list.append(image) # Add image to list
- final_mask_list.append(mask) # Add final_mask to list
+ images_list.append(image)
+ final_mask_list.append(mask)
- # Convert lists to numpy arrays with required shapes
- # extrinsic: (S, 3, 4) - batch of camera extrinsic matrices
predictions["extrinsic"] = np.stack(extrinsic_list, axis=0)
-
- # intrinsic: (S, 3, 3) - batch of camera intrinsic matrices
predictions["intrinsic"] = np.stack(intrinsic_list, axis=0)
-
- # world_points: (S, H, W, 3) - batch of 3D world points
predictions["world_points"] = np.stack(world_points_list, axis=0)
-
- # depth: (S, H, W, 1) or (S, H, W) - batch of depth maps
+
depth_maps = np.stack(depth_maps_list, axis=0)
- # Add channel dimension if needed to match (S, H, W, 1) format
if len(depth_maps.shape) == 3:
depth_maps = depth_maps[..., np.newaxis]
-
predictions["depth"] = depth_maps
-
- # images: (S, H, W, 3) - batch of input images
predictions["images"] = np.stack(images_list, axis=0)
-
- # final_mask: (S, H, W) - batch of final masks for filtering
predictions["final_mask"] = np.stack(final_mask_list, axis=0)
- # Process data for visualization tabs (depth, normal, measure)
processed_data = process_predictions_for_visualization(
predictions, views, high_level_config, filter_black_bg, filter_white_bg
)
- # Clean up
+    # Optional multi-view segmentation: detect objects per view, lift each mask
+    # to 3D, match objects across views and export a colour-coded mesh.
+    # Skipped when segmentation is disabled or GroundingDINO did not load
+    # (load_grounding_dino_model() swallows its own errors).
+    segmented_glb = None
+    if enable_segmentation and grounding_dino_model is not None:
+        print("\nStarting multi-view segmentation...")
+        all_view_detections = []
+        all_view_masks = []
+
+        for view_idx, ref_image in enumerate(images_list):
+            # The detector and SAM are fed a uint8 image.
+            if ref_image.dtype != np.uint8:
+                ref_image_np = (ref_image * 255).astype(np.uint8)
+            else:
+                ref_image_np = ref_image
+
+            detections = run_grounding_dino_detection(ref_image_np, text_prompt, device)
+            if len(detections) > 0:
+                boxes = [d['bbox'] for d in detections]
+                # NOTE(review): with use_sam=False `masks` is empty, so the zip
+                # below yields nothing, no detection receives a 'center_3d',
+                # and match_objects_across_views() ignores all of them. The
+                # bbox fallback inside run_sam_refinement() is never reached
+                # on this path.
+                masks = run_sam_refinement(ref_image_np, boxes) if use_sam else []
+                points3d = world_points_list[view_idx]
+                # Does not depend on the view; could be looked up once outside the loop.
+                encoder = model.encoder if hasattr(model, 'encoder') else None
+
+                for det_idx, (det, mask) in enumerate(zip(detections, masks)):
+                    # Attach the 3D centre and, when available, the 3D extent
+                    # of the points under the mask to the detection dict.
+                    center_3d = compute_object_3d_center(points3d, mask)
+                    det['center_3d'] = center_3d
+                    if center_3d is not None:
+                        masked_points = points3d[mask]
+                        if len(masked_points) > 0:
+                            bbox_min = masked_points.min(axis=0)
+                            bbox_max = masked_points.max(axis=0)
+                            det['bbox_3d'] = {
+                                'center': center_3d, 'size': bbox_max - bbox_min,
+                                'min': bbox_min, 'max': bbox_max
+                            }
+                    det['mask_2d'] = mask
+                    if ENABLE_VISUAL_FEATURES and encoder is not None:
+                        det['visual_feature'] = extract_visual_features(ref_image, mask, encoder)
+                    else:
+                        det['visual_feature'] = None
+
+                all_view_detections.append(detections)
+                all_view_masks.append(masks)
+            else:
+                # Keep both lists aligned with the view index.
+                all_view_detections.append([])
+                all_view_masks.append([])
+
+        if any(len(dets) > 0 for dets in all_view_detections):
+            object_id_map, unique_objects = match_objects_across_views(all_view_detections)
+            segmented_glb = create_multi_view_segmented_mesh(
+                processed_data, all_view_detections, all_view_masks,
+                object_id_map, unique_objects, target_dir, use_sam
+            )
+
torch.cuda.empty_cache()
+ return predictions, processed_data, segmented_glb
- return predictions, processed_data
+# ============================================================================
+# View Navigation and Visualization Helpers
+# ============================================================================
-def update_view_selectors(processed_data):
- """Update view selector dropdowns based on available views"""
- if processed_data is None or len(processed_data) == 0:
- choices = ["View 1"]
- else:
- num_views = len(processed_data)
- choices = [f"View {i + 1}" for i in range(num_views)]
+def colorize_depth(depth_map, mask=None):
+    """Render a depth map as an RGB image using the reversed turbo colormap.
+
+    Valid depths (positive, and inside ``mask`` when given) are normalized
+    between their 5th and 95th percentiles; invalid pixels are painted white.
+    When all valid depths share one value the percentile range is zero, so
+    they are all mapped to the middle of the colormap instead of dividing by
+    zero.
+
+    Args:
+        depth_map: 2-D depth array, or None.
+        mask: Optional boolean array of the same shape marking usable pixels.
+
+    Returns:
+        A uint8 (H, W, 3) image, or None when ``depth_map`` is None.
+    """
+    if depth_map is None:
+        return None
+    depth_normalized = depth_map.copy()
+    valid_mask = depth_normalized > 0
+    if mask is not None:
+        valid_mask = valid_mask & mask
- return (
- gr.Dropdown(choices=choices, value=choices[0]), # depth_view_selector
- gr.Dropdown(choices=choices, value=choices[0]), # normal_view_selector
- gr.Dropdown(choices=choices, value=choices[0]), # measure_view_selector
- )
+    if valid_mask.any():
+        valid_depths = depth_normalized[valid_mask]
+        p5 = np.percentile(valid_depths, 5)
+        p95 = np.percentile(valid_depths, 95)
+        spread = p95 - p5
+        if spread > 0:
+            depth_normalized[valid_mask] = (valid_depths - p5) / spread
+        else:
+            # Constant depth: avoid 0/0 and use a single mid-range colour.
+            depth_normalized[valid_mask] = 0.5
+    import matplotlib.pyplot as plt
+    colormap = plt.cm.turbo_r
+    colored = colormap(depth_normalized)
+    # Drop the alpha channel and convert to 8-bit RGB.
+    colored = (colored[:, :, :3] * 255).astype(np.uint8)
+    colored[~valid_mask] = [255, 255, 255]
+    return colored
-def get_view_data_by_index(processed_data, view_index):
- """Get view data by index, handling bounds"""
- if processed_data is None or len(processed_data) == 0:
- return None
+def colorize_normal(normal_map, mask=None):
+    """Render a normal map as an RGB image.
+
+    Components are mapped from [-1, 1] to [0, 255]. Pixels outside ``mask``
+    are treated as the zero vector and therefore come out mid-grey. Values
+    outside [-1, 1] are clipped and NaN components are treated as 0, so the
+    uint8 conversion cannot wrap around.
+
+    Args:
+        normal_map: (H, W, 3) float array, or None. Not modified.
+        mask: Optional boolean (H, W) array marking valid pixels.
+
+    Returns:
+        A uint8 (H, W, 3) image, or None when ``normal_map`` is None.
+    """
+    if normal_map is None:
+        return None
+    normal_vis = normal_map.copy()
+    if mask is not None:
+        normal_vis[~mask] = [0, 0, 0]
+    normal_vis = (normal_vis + 1.0) / 2.0
+    # Guard the cast: NaN -> mid-grey, anything outside [0, 1] -> nearest bound.
+    normal_vis = np.clip(np.nan_to_num(normal_vis, nan=0.5), 0.0, 1.0)
+    return (normal_vis * 255).astype(np.uint8)
+def get_view_data_by_index(processed_data, view_index):
+ if processed_data is None or len(processed_data) == 0: return None
view_keys = list(processed_data.keys())
- if view_index < 0 or view_index >= len(view_keys):
- view_index = 0
-
+ if view_index < 0 or view_index >= len(view_keys): view_index = 0
return processed_data[view_keys[view_index]]
-
def update_depth_view(processed_data, view_index):
- """Update depth view for a specific view index"""
view_data = get_view_data_by_index(processed_data, view_index)
- if view_data is None or view_data["depth"] is None:
- return None
-
+ if view_data is None or view_data["depth"] is None: return None
return colorize_depth(view_data["depth"], mask=view_data.get("mask"))
-
def update_normal_view(processed_data, view_index):
- """Update normal view for a specific view index"""
view_data = get_view_data_by_index(processed_data, view_index)
- if view_data is None or view_data["normal"] is None:
- return None
-
+ if view_data is None or view_data["normal"] is None: return None
return colorize_normal(view_data["normal"], mask=view_data.get("mask"))
-
def update_measure_view(processed_data, view_index):
- """Update measure view for a specific view index with mask overlay"""
view_data = get_view_data_by_index(processed_data, view_index)
- if view_data is None:
- return None, [] # image, measure_points
-
- # Get the base image
+ if view_data is None: return None, []
image = view_data["image"].copy()
-
- # Ensure image is in uint8 format
if image.dtype != np.uint8:
- if image.max() <= 1.0:
- image = (image * 255).astype(np.uint8)
- else:
- image = image.astype(np.uint8)
+ if image.max() <= 1.0: image = (image * 255).astype(np.uint8)
+ else: image = image.astype(np.uint8)
- # Apply mask overlay if mask is available
if view_data["mask"] is not None:
- mask = view_data["mask"]
-
- # Create light grey overlay for masked areas
- # Masked areas (False values) will be overlaid with light grey
- invalid_mask = ~mask # Areas where mask is False
-
+ invalid_mask = ~view_data["mask"]
if invalid_mask.any():
- # Create a light grey overlay (RGB: 192, 192, 192)
overlay_color = np.array([255, 220, 220], dtype=np.uint8)
-
- # Apply overlay with some transparency
- alpha = 0.5 # Transparency level
- for c in range(3): # RGB channels
+ alpha = 0.5
+ for c in range(3):
image[:, :, c] = np.where(
invalid_mask,
(1 - alpha) * image[:, :, c] + alpha * overlay_color[c],
image[:, :, c],
).astype(np.uint8)
-
return image, []
-
-def navigate_depth_view(processed_data, current_selector_value, direction):
- """Navigate depth view (direction: -1 for previous, +1 for next)"""
+def update_view_selectors(processed_data):
if processed_data is None or len(processed_data) == 0:
- return "View 1", None
-
- # Parse current view number
- try:
- current_view = int(current_selector_value.split()[1]) - 1
- except:
- current_view = 0
-
- num_views = len(processed_data)
- new_view = (current_view + direction) % num_views
-
- new_selector_value = f"View {new_view + 1}"
- depth_vis = update_depth_view(processed_data, new_view)
-
- return new_selector_value, depth_vis
+ choices = ["View 1"]
+ else:
+ choices = [f"View {i + 1}" for i in range(len(processed_data))]
+ return (
+ gr.Dropdown(choices=choices, value=choices[0]),
+ gr.Dropdown(choices=choices, value=choices[0]),
+ gr.Dropdown(choices=choices, value=choices[0]),
+ )
+def navigate_depth_view(processed_data, current_selector_value, direction):
+    """Step the depth tab to the previous or next view, wrapping around.
+
+    Args:
+        processed_data: Per-view data; None or empty means nothing to show.
+        current_selector_value: Selector text of the form "View <n>"
+            (1-based). Anything that cannot be parsed counts as the first view.
+        direction: Offset added to the current view index (e.g. -1 or +1).
+
+    Returns:
+        ``(selector_value, depth_image)`` for the new view, or
+        ``("View 1", None)`` when there is no data.
+    """
+    if processed_data is None or len(processed_data) == 0:
+        return "View 1", None
+    try:
+        current_view = int(current_selector_value.split()[1]) - 1
+    except (AttributeError, IndexError, ValueError):
+        # Not a string, no second token, or a non-numeric view number.
+        current_view = 0
+    new_view = (current_view + direction) % len(processed_data)
+    return f"View {new_view + 1}", update_depth_view(processed_data, new_view)
def navigate_normal_view(processed_data, current_selector_value, direction):
- """Navigate normal view (direction: -1 for previous, +1 for next)"""
- if processed_data is None or len(processed_data) == 0:
- return "View 1", None
-
- # Parse current view number
- try:
- current_view = int(current_selector_value.split()[1]) - 1
- except:
- current_view = 0
-
- num_views = len(processed_data)
- new_view = (current_view + direction) % num_views
-
- new_selector_value = f"View {new_view + 1}"
- normal_vis = update_normal_view(processed_data, new_view)
-
- return new_selector_value, normal_vis
-
+ if processed_data is None or len(processed_data) == 0: return "View 1", None
+ try: current_view = int(current_selector_value.split()[1]) - 1
+ except: current_view = 0
+ new_view = (current_view + direction) % len(processed_data)
+ return f"View {new_view + 1}", update_normal_view(processed_data, new_view)
def navigate_measure_view(processed_data, current_selector_value, direction):
- """Navigate measure view (direction: -1 for previous, +1 for next)"""
- if processed_data is None or len(processed_data) == 0:
- return "View 1", None, []
-
- # Parse current view number
- try:
- current_view = int(current_selector_value.split()[1]) - 1
- except:
- current_view = 0
+ if processed_data is None or len(processed_data) == 0: return "View 1", None, []
+ try: current_view = int(current_selector_value.split()[1]) - 1
+ except: current_view = 0
+ new_view = (current_view + direction) % len(processed_data)
+ measure_image, measure_points = update_measure_view(processed_data, new_view)
+ return f"View {new_view + 1}", measure_image, measure_points
- num_views = len(processed_data)
- new_view = (current_view + direction) % num_views
+def populate_visualization_tabs(processed_data):
+ if processed_data is None or len(processed_data) == 0: return None, None, None, []
+ return (
+ update_depth_view(processed_data, 0),
+ update_normal_view(processed_data, 0),
+ update_measure_view(processed_data, 0)[0],
+ []
+ )
- new_selector_value = f"View {new_view + 1}"
- measure_image, measure_points = update_measure_view(processed_data, new_view)
+def measure(processed_data, measure_points, current_view_selector, event: gr.SelectData):
+ try:
+ if processed_data is None or len(processed_data) == 0: return None, [], "No data available"
+ try: current_view_index = int(current_view_selector.split()[1]) - 1
+ except: current_view_index = 0
+ if current_view_index < 0 or current_view_index >= len(processed_data): current_view_index = 0
+
+ view_keys = list(processed_data.keys())
+ current_view = processed_data[view_keys[current_view_index]]
+ if current_view is None: return None, [], "No view data available"
+
+ point2d = event.index[0], event.index[1]
+
+ if (current_view["mask"] is not None and
+ 0 <= point2d[1] < current_view["mask"].shape[0] and
+ 0 <= point2d[0] < current_view["mask"].shape[1]):
+ if not current_view["mask"][point2d[1], point2d[0]]:
+ masked_image, _ = update_measure_view(processed_data, current_view_index)
+ return masked_image, measure_points, 'Cannot measure on masked areas (shown in grey)'
- return new_selector_value, measure_image, measure_points
+ measure_points.append(point2d)
+ image, _ = update_measure_view(processed_data, current_view_index)
+ if image is None: return None, [], "No image available"
+
+ image = image.copy()
+ points3d = current_view["points3d"]
+
+ for p in measure_points:
+ if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
+ image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2)
+ depth_text = ""
+ for i, p in enumerate(measure_points):
+ if (current_view["depth"] is not None and
+ 0 <= p[1] < current_view["depth"].shape[0] and
+ 0 <= p[0] < current_view["depth"].shape[1]):
+ d = current_view["depth"][p[1], p[0]]
+ depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n"
+ else:
+ if (points3d is not None and 0 <= p[1] < points3d.shape[0] and 0 <= p[0] < points3d.shape[1]):
+ z = points3d[p[1], p[0], 2]
+ depth_text += f"- **P{i + 1} Z-coord: {z:.2f}m.**\n"
-def populate_visualization_tabs(processed_data):
- """Populate the depth, normal, and measure tabs with processed data"""
- if processed_data is None or len(processed_data) == 0:
- return None, None, None, []
+ if len(measure_points) == 2:
+ point1, point2 = measure_points
+ if (0 <= point1[0] < image.shape[1] and 0 <= point1[1] < image.shape[0] and
+ 0 <= point2[0] < image.shape[1] and 0 <= point2[1] < image.shape[0]):
+ image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2)
+
+ distance_text = "- **Distance: Unable to compute**"
+ if (points3d is not None and 0 <= point1[1] < points3d.shape[0] and 0 <= point1[0] < points3d.shape[1] and
+ 0 <= point2[1] < points3d.shape[0] and 0 <= point2[0] < points3d.shape[1]):
+ try:
+ p1_3d = points3d[point1[1], point1[0]]
+ p2_3d = points3d[point2[1], point2[0]]
+ distance = np.linalg.norm(p1_3d - p2_3d)
+ distance_text = f"- **Distance: {distance:.2f}m**"
+ except:
+ pass
+
+ measure_points = []
+ text = depth_text + distance_text
+ return [image, measure_points, text]
+ else:
+ return [image, measure_points, depth_text]
- # Use update functions to ensure confidence filtering is applied from the start
- depth_vis = update_depth_view(processed_data, 0)
- normal_vis = update_normal_view(processed_data, 0)
- measure_img, _ = update_measure_view(processed_data, 0)
+ except Exception as e:
+ return None, [], f"Measure function error: {e}"
- return depth_vis, normal_vis, measure_img, []
+# ============================================================================
+# Core Functions for Upload and Interface
+# ============================================================================
-# -------------------------------------------------------------------------
-# 2) Handle uploaded video/images --> produce target_dir + images
-# -------------------------------------------------------------------------
def handle_uploads(unified_upload, s_time_interval=1.0):
- """
- Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
- images or extracted frames from video into it. Return (target_dir, image_paths).
- """
start_time = time.time()
gc.collect()
torch.cuda.empty_cache()
- # Create a unique folder name
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
target_dir = f"input_images_{timestamp}"
target_dir_images = os.path.join(target_dir, "images")
- # Clean up if somehow that folder already exists
if os.path.exists(target_dir):
shutil.rmtree(target_dir)
os.makedirs(target_dir)
@@ -367,7 +1023,6 @@ def handle_uploads(unified_upload, s_time_interval=1.0):
image_paths = []
- # --- Handle uploaded files (both images and videos) ---
if unified_upload is not None:
for file_data in unified_upload:
if isinstance(file_data, dict) and "name" in file_data:
@@ -376,118 +1031,57 @@ def handle_uploads(unified_upload, s_time_interval=1.0):
file_path = str(file_data)
file_ext = os.path.splitext(file_path)[1].lower()
-
- # Check if it's a video file
- video_extensions = [
- ".mp4",
- ".avi",
- ".mov",
- ".mkv",
- ".wmv",
- ".flv",
- ".webm",
- ".m4v",
- ".3gp",
- ]
+ video_extensions = [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"]
+
if file_ext in video_extensions:
- # Handle as video
vs = cv2.VideoCapture(file_path)
fps = vs.get(cv2.CAP_PROP_FPS)
- frame_interval = int(fps * s_time_interval) # frames per interval
-
+ frame_interval = int(fps * s_time_interval)
count = 0
video_frame_num = 0
while True:
gotit, frame = vs.read()
- if not gotit:
- break
+ if not gotit: break
count += 1
if count % frame_interval == 0:
- # Use original filename as prefix for frames
base_name = os.path.splitext(os.path.basename(file_path))[0]
- image_path = os.path.join(
- target_dir_images, f"{base_name}_{video_frame_num:06}.png"
- )
+ image_path = os.path.join(target_dir_images, f"{base_name}_{video_frame_num:06}.png")
cv2.imwrite(image_path, frame)
image_paths.append(image_path)
video_frame_num += 1
vs.release()
- print(
- f"Extracted {video_frame_num} frames from video: {os.path.basename(file_path)}"
- )
-
else:
- # Handle as image
- # Check if the file is a HEIC image
if file_ext in [".heic", ".heif"]:
- # Convert HEIC to JPEG for better gallery compatibility
try:
with Image.open(file_path) as img:
- # Convert to RGB if necessary (HEIC can have different color modes)
- if img.mode not in ("RGB", "L"):
- img = img.convert("RGB")
-
- # Create JPEG filename
+ if img.mode not in ("RGB", "L"): img = img.convert("RGB")
base_name = os.path.splitext(os.path.basename(file_path))[0]
- dst_path = os.path.join(
- target_dir_images, f"{base_name}.jpg"
- )
-
- # Save as JPEG with high quality
+ dst_path = os.path.join(target_dir_images, f"{base_name}.jpg")
img.save(dst_path, "JPEG", quality=95)
image_paths.append(dst_path)
- print(
- f"Converted HEIC to JPEG: {os.path.basename(file_path)} -> {os.path.basename(dst_path)}"
- )
except Exception as e:
- print(f"Error converting HEIC file {file_path}: {e}")
- # Fall back to copying as is
- dst_path = os.path.join(
- target_dir_images, os.path.basename(file_path)
- )
+ dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
shutil.copy(file_path, dst_path)
image_paths.append(dst_path)
else:
- # Regular image files - copy as is
- dst_path = os.path.join(
- target_dir_images, os.path.basename(file_path)
- )
+ dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
shutil.copy(file_path, dst_path)
image_paths.append(dst_path)
- # Sort final images for gallery
image_paths = sorted(image_paths)
-
- end_time = time.time()
- print(
- f"Files processed to {target_dir_images}; took {end_time - start_time:.3f} seconds"
- )
return target_dir, image_paths
-
-# -------------------------------------------------------------------------
-# 3) Update gallery on upload
-# -------------------------------------------------------------------------
-def update_gallery_on_upload(input_video, input_images, s_time_interval=1.0):
- """
- Whenever user uploads or changes files, immediately handle them
- and show in the gallery. Return (target_dir, image_paths).
- If nothing is uploaded, returns "None" and empty list.
- """
- if not input_video and not input_images:
+def update_gallery_on_upload(input_files, s_time_interval=1.0):
+ if not input_files:
return None, None, None, None
- target_dir, image_paths = handle_uploads(input_video, input_images, s_time_interval)
+ target_dir, image_paths = handle_uploads(input_files, s_time_interval)
return (
None,
target_dir,
image_paths,
- "Upload complete. Click 'Reconstruct' to begin 3D processing.",
+ "Upload complete. Click 'Start Reconstruction' to begin 3D processing.",
)
-
-# -------------------------------------------------------------------------
-# 4) Reconstruction: uses the target_dir plus any viz parameters
-# -------------------------------------------------------------------------
@spaces.GPU(duration=120)
def gradio_demo(
target_dir,
@@ -497,384 +1091,80 @@ def gradio_demo(
filter_white_bg=False,
apply_mask=True,
show_mesh=True,
+ enable_segmentation=False,
+ text_prompt=DEFAULT_TEXT_PROMPT,
+ use_sam=True,
):
- """
- Perform reconstruction using the already-created target_dir/images.
- """
if not os.path.isdir(target_dir) or target_dir == "None":
- return None, "No valid target directory found. Please upload first.", None, None
+ return None, None, "No valid target directory found. Please upload first.", None, None, None, None, None, "", None, None, None
start_time = time.time()
gc.collect()
torch.cuda.empty_cache()
- # Prepare frame_filter dropdown
target_dir_images = os.path.join(target_dir, "images")
- all_files = (
- sorted(os.listdir(target_dir_images))
- if os.path.isdir(target_dir_images)
- else []
- )
- all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
- frame_filter_choices = ["All"] + all_files
+ all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
+ all_files_display = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
+ frame_filter_choices = ["All"] + all_files_display
- print("Running MapAnything model...")
with torch.no_grad():
- predictions, processed_data = run_model(target_dir, apply_mask)
+ predictions, processed_data, segmented_glb = run_model(
+ target_dir, apply_mask, True, filter_black_bg, filter_white_bg,
+ enable_segmentation, text_prompt, use_sam
+ )
- # Save predictions
prediction_save_path = os.path.join(target_dir, "predictions.npz")
np.savez(prediction_save_path, **predictions)
- # Handle None frame_filter
- if frame_filter is None:
- frame_filter = "All"
+ if frame_filter is None: frame_filter = "All"
- # Build a GLB file name
glbfile = os.path.join(
target_dir,
f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
)
- # Convert predictions to GLB
glbscene = predictions_to_glb(
predictions,
filter_by_frames=frame_filter,
show_cam=show_cam,
mask_black_bg=filter_black_bg,
mask_white_bg=filter_white_bg,
- as_mesh=show_mesh, # Use the show_mesh parameter
+ as_mesh=show_mesh,
)
glbscene.export(file_obj=glbfile)
- # Cleanup
del predictions
gc.collect()
torch.cuda.empty_cache()
- end_time = time.time()
- print(f"Total time: {end_time - start_time:.2f} seconds")
- log_msg = (
- f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
- )
+ log_msg = f"Reconstruction successful ({len(all_files)} frames)"
- # Populate visualization tabs with processed data
- depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(
- processed_data
- )
+ depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(processed_data)
+ depth_selector, normal_selector, measure_selector = update_view_selectors(processed_data)
- # Update view selectors based on available views
- depth_selector, normal_selector, measure_selector = update_view_selectors(
- processed_data
- )
+ output_reconstruction = segmented_glb if (enable_segmentation and segmented_glb is not None) else glbfile
return (
- glbfile,
+ glbfile, # Raw 3D
+ output_reconstruction, # 3D View (Segmented if applicable)
log_msg,
gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
processed_data,
depth_vis,
normal_vis,
measure_img,
- "", # measure_text (empty initially)
+ "",
depth_selector,
normal_selector,
measure_selector,
)
-
-# -------------------------------------------------------------------------
-# 5) Helper functions for UI resets + re-visualization
-# -------------------------------------------------------------------------
-def colorize_depth(depth_map, mask=None):
- """Convert depth map to colorized visualization with optional mask"""
- if depth_map is None:
- return None
-
- # Normalize depth to 0-1 range
- depth_normalized = depth_map.copy()
- valid_mask = depth_normalized > 0
-
- # Apply additional mask if provided (for background filtering)
- if mask is not None:
- valid_mask = valid_mask & mask
-
- if valid_mask.sum() > 0:
- valid_depths = depth_normalized[valid_mask]
- p5 = np.percentile(valid_depths, 5)
- p95 = np.percentile(valid_depths, 95)
-
- depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / (p95 - p5)
-
- # Apply colormap
- import matplotlib.pyplot as plt
-
- colormap = plt.cm.turbo_r
- colored = colormap(depth_normalized)
- colored = (colored[:, :, :3] * 255).astype(np.uint8)
-
- # Set invalid pixels to white
- colored[~valid_mask] = [255, 255, 255]
-
- return colored
-
-
-def colorize_normal(normal_map, mask=None):
- """Convert normal map to colorized visualization with optional mask"""
- if normal_map is None:
- return None
-
- # Create a copy for modification
- normal_vis = normal_map.copy()
-
- # Apply mask if provided (set masked areas to [0, 0, 0] which becomes grey after normalization)
- if mask is not None:
- invalid_mask = ~mask
- normal_vis[invalid_mask] = [0, 0, 0] # Set invalid areas to zero
-
- # Normalize normals to [0, 1] range for visualization
- normal_vis = (normal_vis + 1.0) / 2.0
- normal_vis = (normal_vis * 255).astype(np.uint8)
-
- return normal_vis
-
-
-def process_predictions_for_visualization(
- predictions, views, high_level_config, filter_black_bg=False, filter_white_bg=False
-):
- """Extract depth, normal, and 3D points from predictions for visualization"""
- processed_data = {}
-
- # Process each view
- for view_idx, view in enumerate(views):
- # Get image
- image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
-
- # Get predicted points
- pred_pts3d = predictions["world_points"][view_idx]
-
- # Initialize data for this view
- view_data = {
- "image": image[0],
- "points3d": pred_pts3d,
- "depth": None,
- "normal": None,
- "mask": None,
- }
-
- # Start with the final mask from predictions
- mask = predictions["final_mask"][view_idx].copy()
-
- # Apply black background filtering if enabled
- if filter_black_bg:
- # Get the image colors (ensure they're in 0-255 range)
- view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
- # Filter out black background pixels (sum of RGB < 16)
- black_bg_mask = view_colors.sum(axis=2) >= 16
- mask = mask & black_bg_mask
-
- # Apply white background filtering if enabled
- if filter_white_bg:
- # Get the image colors (ensure they're in 0-255 range)
- view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
- # Filter out white background pixels (all RGB > 240)
- white_bg_mask = ~(
- (view_colors[:, :, 0] > 240)
- & (view_colors[:, :, 1] > 240)
- & (view_colors[:, :, 2] > 240)
- )
- mask = mask & white_bg_mask
-
- view_data["mask"] = mask
- view_data["depth"] = predictions["depth"][view_idx].squeeze()
-
- normals, _ = points_to_normals(pred_pts3d, mask=view_data["mask"])
- view_data["normal"] = normals
-
- processed_data[view_idx] = view_data
-
- return processed_data
-
-
-def reset_measure(processed_data):
- """Reset measure points"""
- if processed_data is None or len(processed_data) == 0:
- return None, [], ""
-
- # Return the first view image
- first_view = list(processed_data.values())[0]
- return first_view["image"], [], ""
-
-
-def measure(
- processed_data, measure_points, current_view_selector, event: gr.SelectData
-):
- """Handle measurement on images"""
- try:
- print(f"Measure function called with selector: {current_view_selector}")
-
- if processed_data is None or len(processed_data) == 0:
- return None, [], "No data available"
-
- # Use the currently selected view instead of always using the first view
- try:
- current_view_index = int(current_view_selector.split()[1]) - 1
- except:
- current_view_index = 0
-
- print(f"Using view index: {current_view_index}")
-
- # Get view data safely
- if current_view_index < 0 or current_view_index >= len(processed_data):
- current_view_index = 0
-
- view_keys = list(processed_data.keys())
- current_view = processed_data[view_keys[current_view_index]]
-
- if current_view is None:
- return None, [], "No view data available"
-
- point2d = event.index[0], event.index[1]
- print(f"Clicked point: {point2d}")
-
- # Check if the clicked point is in a masked area (prevent interaction)
- if (
- current_view["mask"] is not None
- and 0 <= point2d[1] < current_view["mask"].shape[0]
- and 0 <= point2d[0] < current_view["mask"].shape[1]
- ):
- # Check if the point is in a masked (invalid) area
- if not current_view["mask"][point2d[1], point2d[0]]:
- print(f"Clicked point {point2d} is in masked area, ignoring click")
- # Always return image with mask overlay
- masked_image, _ = update_measure_view(
- processed_data, current_view_index
- )
- return (
- masked_image,
- measure_points,
- 'Cannot measure on masked areas (shown in grey)',
- )
-
- measure_points.append(point2d)
-
- # Get image with mask overlay and ensure it's valid
- image, _ = update_measure_view(processed_data, current_view_index)
- if image is None:
- return None, [], "No image available"
-
- image = image.copy()
- points3d = current_view["points3d"]
-
- # Ensure image is in uint8 format for proper cv2 operations
- try:
- if image.dtype != np.uint8:
- if image.max() <= 1.0:
- # Image is in [0, 1] range, convert to [0, 255]
- image = (image * 255).astype(np.uint8)
- else:
- # Image is already in [0, 255] range
- image = image.astype(np.uint8)
- except Exception as e:
- print(f"Image conversion error: {e}")
- return None, [], f"Image conversion error: {e}"
-
- # Draw circles for points
- try:
- for p in measure_points:
- if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
- image = cv2.circle(
- image, p, radius=5, color=(255, 0, 0), thickness=2
- )
- except Exception as e:
- print(f"Drawing error: {e}")
- return None, [], f"Drawing error: {e}"
-
- depth_text = ""
- try:
- for i, p in enumerate(measure_points):
- if (
- current_view["depth"] is not None
- and 0 <= p[1] < current_view["depth"].shape[0]
- and 0 <= p[0] < current_view["depth"].shape[1]
- ):
- d = current_view["depth"][p[1], p[0]]
- depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n"
- else:
- # Use Z coordinate of 3D points if depth not available
- if (
- points3d is not None
- and 0 <= p[1] < points3d.shape[0]
- and 0 <= p[0] < points3d.shape[1]
- ):
- z = points3d[p[1], p[0], 2]
- depth_text += f"- **P{i + 1} Z-coord: {z:.2f}m.**\n"
- except Exception as e:
- print(f"Depth text error: {e}")
- depth_text = f"Error computing depth: {e}\n"
-
- if len(measure_points) == 2:
- try:
- point1, point2 = measure_points
- # Draw line
- if (
- 0 <= point1[0] < image.shape[1]
- and 0 <= point1[1] < image.shape[0]
- and 0 <= point2[0] < image.shape[1]
- and 0 <= point2[1] < image.shape[0]
- ):
- image = cv2.line(
- image, point1, point2, color=(255, 0, 0), thickness=2
- )
-
- # Compute 3D distance
- distance_text = "- **Distance: Unable to compute**"
- if (
- points3d is not None
- and 0 <= point1[1] < points3d.shape[0]
- and 0 <= point1[0] < points3d.shape[1]
- and 0 <= point2[1] < points3d.shape[0]
- and 0 <= point2[0] < points3d.shape[1]
- ):
- try:
- p1_3d = points3d[point1[1], point1[0]]
- p2_3d = points3d[point2[1], point2[0]]
- distance = np.linalg.norm(p1_3d - p2_3d)
- distance_text = f"- **Distance: {distance:.2f}m**"
- except Exception as e:
- print(f"Distance computation error: {e}")
- distance_text = f"- **Distance computation error: {e}**"
-
- measure_points = []
- text = depth_text + distance_text
- print(f"Measurement complete: {text}")
- return [image, measure_points, text]
- except Exception as e:
- print(f"Final measurement error: {e}")
- return None, [], f"Measurement error: {e}"
- else:
- print(f"Single point measurement: {depth_text}")
- return [image, measure_points, depth_text]
-
- except Exception as e:
- print(f"Overall measure function error: {e}")
- return None, [], f"Measure function error: {e}"
-
-
def clear_fields():
- """
- Clears the 3D viewer, the stored target_dir, and empties the gallery.
- """
- return None
-
+ return None, None
def update_log():
- """
- Display a quick log message while waiting.
- """
return "Loading and Reconstructing..."
-
def update_visualization(
target_dir,
frame_filter,
@@ -884,30 +1174,12 @@ def update_visualization(
filter_white_bg=False,
show_mesh=True,
):
- """
- Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
- and return it for the 3D viewer. If is_example == "True", skip.
- """
-
- # If it's an example click, skip as requested
- if is_example == "True":
- return (
- gr.update(),
- "No reconstruction available. Please click the Reconstruct button first.",
- )
-
- if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
- return (
- gr.update(),
- "No reconstruction available. Please click the Reconstruct button first.",
- )
+ if is_example == "True" or not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
+ return gr.update(), gr.update(), "No reconstruction available. Please run reconstruction first."
predictions_path = os.path.join(target_dir, "predictions.npz")
if not os.path.exists(predictions_path):
- return (
- gr.update(),
- f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
- )
+ return gr.update(), gr.update(), "No reconstruction available. Please run reconstruction first."
loaded = np.load(predictions_path, allow_pickle=True)
predictions = {key: loaded[key] for key in loaded.keys()}
@@ -928,11 +1200,7 @@ def update_visualization(
)
glbscene.export(file_obj=glbfile)
- return (
- glbfile,
- "Visualization updated.",
- )
-
+ return glbfile, gr.update(), "Visualization updated."
def update_all_views_on_filter_change(
target_dir,
@@ -943,11 +1211,6 @@ def update_all_views_on_filter_change(
normal_view_selector,
measure_view_selector,
):
- """
- Update all individual view tabs when background filtering checkboxes change.
- This regenerates the processed data with new filtering and updates all views.
- """
- # Check if we have a valid target directory and predictions
if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
return processed_data, None, None, None, []
@@ -956,722 +1219,288 @@ def update_all_views_on_filter_change(
return processed_data, None, None, None, []
try:
- # Load the original predictions and views
loaded = np.load(predictions_path, allow_pickle=True)
predictions = {key: loaded[key] for key in loaded.keys()}
- # Load images using MapAnything's load_images function
image_folder_path = os.path.join(target_dir, "images")
views = load_images(image_folder_path)
- # Regenerate processed data with new filtering settings
new_processed_data = process_predictions_for_visualization(
predictions, views, high_level_config, filter_black_bg, filter_white_bg
)
- # Get current view indices
- try:
- depth_view_idx = (
- int(depth_view_selector.split()[1]) - 1 if depth_view_selector else 0
- )
- except:
- depth_view_idx = 0
+ try: depth_view_idx = int(depth_view_selector.split()[1]) - 1 if depth_view_selector else 0
+ except: depth_view_idx = 0
- try:
- normal_view_idx = (
- int(normal_view_selector.split()[1]) - 1 if normal_view_selector else 0
- )
- except:
- normal_view_idx = 0
+ try: normal_view_idx = int(normal_view_selector.split()[1]) - 1 if normal_view_selector else 0
+ except: normal_view_idx = 0
- try:
- measure_view_idx = (
- int(measure_view_selector.split()[1]) - 1
- if measure_view_selector
- else 0
- )
- except:
- measure_view_idx = 0
+ try: measure_view_idx = int(measure_view_selector.split()[1]) - 1 if measure_view_selector else 0
+ except: measure_view_idx = 0
- # Update all views with new filtered data
depth_vis = update_depth_view(new_processed_data, depth_view_idx)
normal_vis = update_normal_view(new_processed_data, normal_view_idx)
measure_img, _ = update_measure_view(new_processed_data, measure_view_idx)
return new_processed_data, depth_vis, normal_vis, measure_img, []
-
except Exception as e:
- print(f"Error updating views on filter change: {e}")
return processed_data, None, None, None, []
-
-# -------------------------------------------------------------------------
-# Example scene functions
-# -------------------------------------------------------------------------
+# Example Scenes
def get_scene_info(examples_dir):
- """Get information about scenes in the examples directory"""
import glob
-
scenes = []
- if not os.path.exists(examples_dir):
- return scenes
-
+ if not os.path.exists(examples_dir): return scenes
for scene_folder in sorted(os.listdir(examples_dir)):
scene_path = os.path.join(examples_dir, scene_folder)
if os.path.isdir(scene_path):
- # Find all image files in the scene folder
image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]
image_files = []
for ext in image_extensions:
image_files.extend(glob.glob(os.path.join(scene_path, ext)))
image_files.extend(glob.glob(os.path.join(scene_path, ext.upper())))
-
if image_files:
- # Sort images and get the first one for thumbnail
image_files = sorted(image_files)
- first_image = image_files[0]
- num_images = len(image_files)
-
- scenes.append(
- {
- "name": scene_folder,
- "path": scene_path,
- "thumbnail": first_image,
- "num_images": num_images,
- "image_files": image_files,
- }
- )
-
+ scenes.append({
+ "name": scene_folder,
+ "path": scene_path,
+ "thumbnail": image_files[0],
+ "num_images": len(image_files),
+ "image_files": image_files,
+ })
return scenes
-
def load_example_scene(scene_name, examples_dir="examples"):
- """Load a scene from examples directory"""
scenes = get_scene_info(examples_dir)
-
- # Find the selected scene
- selected_scene = None
- for scene in scenes:
- if scene["name"] == scene_name:
- selected_scene = scene
- break
-
- if selected_scene is None:
- return None, None, None, "Scene not found"
-
- # Create file-like objects for the unified upload system
- # Convert image file paths to the format expected by unified_upload
- file_objects = []
- for image_path in selected_scene["image_files"]:
- file_objects.append(image_path)
-
- # Create target directory and copy images using the unified upload system
+ selected_scene = next((s for s in scenes if s["name"] == scene_name), None)  # first scene whose folder name matches, else None
+ if selected_scene is None: return None, None, None, "Scene not found"
+
+ file_objects = list(selected_scene["image_files"])  # shallow copy of the image path list handed to handle_uploads
target_dir, image_paths = handle_uploads(file_objects, 1.0)
-
+
return (
- None, # Clear reconstruction output
- target_dir, # Set target directory
- image_paths, # Set gallery
- f"Loaded scene '{scene_name}' with {selected_scene['num_images']} images. Click 'Reconstruct' to begin 3D processing.",
+ None,  # wired to reconstruction_output by the thumbnail select handler
+ target_dir,  # wired to target_dir_output
+ image_paths,  # wired to image_gallery
+ f"Loaded scene '{scene_name}'. Click 'Start Reconstruction' to begin.",
)
-# -------------------------------------------------------------------------
-# 6) Build Gradio UI
-# -------------------------------------------------------------------------
+# ============================================================================
+# Gradio UI Construction
+# ============================================================================
+
theme = get_gradio_theme()
+CUSTOM_CSS = GRADIO_CSS + """
+.gradio-container { max-width: 100% !important; }
+.gallery-container { max-height: 350px !important; overflow-y: auto !important; }
+.file-preview { max-height: 200px !important; overflow-y: auto !important; }
+.video-container { max-height: 300px !important; }
+.textbox-container { max-height: 100px !important; }
+.tab-content { min-height: 550px !important; }
+"""
+
with gr.Blocks() as demo:
- # State variables for the tabbed interface
is_example = gr.Textbox(label="is_example", visible=False, value="None")
- num_images = gr.Textbox(label="num_images", visible=False, value="None")
processed_data_state = gr.State(value=None)
measure_points_state = gr.State(value=[])
- current_view_index = gr.State(value=0) # Track current view index for navigation
-
target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
- with gr.Row():
- with gr.Column(scale=2):
- # Unified upload component for both videos and images
+ with gr.Row(equal_height=False):
+ # Input Area
+ with gr.Column(scale=1, min_width=300):
+ gr.Markdown("### Input")
+
unified_upload = gr.File(
file_count="multiple",
label="Upload Video or Images",
interactive=True,
file_types=["image", "video"],
)
+
with gr.Row():
s_time_interval = gr.Slider(
- minimum=0.1,
- maximum=5.0,
- value=1.0,
- step=0.1,
+ minimum=0.1, maximum=5.0, value=1.0, step=0.1,
label="Video sample time interval (take a sample every x sec.)",
- interactive=True,
- visible=True,
- scale=3,
- )
- resample_btn = gr.Button(
- "Resample Video",
- visible=False,
- variant="secondary",
- scale=1,
+ interactive=True, scale=3
)
+ resample_btn = gr.Button("Resample Video", visible=False, variant="secondary", scale=1)
image_gallery = gr.Gallery(
- label="Preview",
- columns=4,
- height="300px",
- show_download_button=True,
- object_fit="contain",
- preview=True,
- )
-
- clear_uploads_btn = gr.ClearButton(
- [unified_upload, image_gallery],
- value="Clear Uploads",
- variant="secondary",
- size="sm",
+ label="Preview", columns=4, height="300px",
+ show_download_button=True, object_fit="contain", preview=True
)
-
- with gr.Column(scale=4):
- with gr.Column():
- gr.Markdown(
- "**Metric 3D Reconstruction (Point Cloud and Camera Poses)**"
- )
- log_output = gr.Markdown(
- "Please upload a video or images, then click Reconstruct.",
- elem_classes=["custom-log"],
- )
-
- # Add tabbed interface similar to MoGe
- with gr.Tabs():
- with gr.Tab("3D View"):
- reconstruction_output = gr.Model3D(
- height=520,
- zoom_speed=0.5,
- pan_speed=0.5,
- clear_color=[0.0, 0.0, 0.0, 0.0],
- key="persistent_3d_viewer",
- elem_id="reconstruction_3d_viewer",
- )
- with gr.Tab("Depth"):
- with gr.Row(elem_classes=["navigation-row"]):
- prev_depth_btn = gr.Button("◀ Previous", size="sm", scale=1)
- depth_view_selector = gr.Dropdown(
- choices=["View 1"],
- value="View 1",
- label="Select View",
- scale=2,
- interactive=True,
- allow_custom_value=True,
- )
- next_depth_btn = gr.Button("Next ▶", size="sm", scale=1)
- depth_map = gr.Image(
- type="numpy",
- label="Colorized Depth Map",
- format="png",
- interactive=False,
- )
- with gr.Tab("Normal"):
- with gr.Row(elem_classes=["navigation-row"]):
- prev_normal_btn = gr.Button(
- "◀ Previous", size="sm", scale=1
- )
- normal_view_selector = gr.Dropdown(
- choices=["View 1"],
- value="View 1",
- label="Select View",
- scale=2,
- interactive=True,
- allow_custom_value=True,
- )
- next_normal_btn = gr.Button("Next ▶", size="sm", scale=1)
- normal_map = gr.Image(
- type="numpy",
- label="Normal Map",
- format="png",
- interactive=False,
- )
- with gr.Tab("Measure"):
- gr.Markdown(MEASURE_INSTRUCTIONS_HTML)
- with gr.Row(elem_classes=["navigation-row"]):
- prev_measure_btn = gr.Button(
- "◀ Previous", size="sm", scale=1
- )
- measure_view_selector = gr.Dropdown(
- choices=["View 1"],
- value="View 1",
- label="Select View",
- scale=2,
- interactive=True,
- allow_custom_value=True,
- )
- next_measure_btn = gr.Button("Next ▶", size="sm", scale=1)
- measure_image = gr.Image(
- type="numpy",
- show_label=False,
- format="webp",
- interactive=False,
- sources=[],
- )
- gr.Markdown(
- "**Note:** Light-grey areas indicate regions with no depth information where measurements cannot be taken."
- )
- measure_text = gr.Markdown("")
-
+
with gr.Row():
- submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
- clear_btn = gr.ClearButton(
- [
- unified_upload,
- reconstruction_output,
- log_output,
- target_dir_output,
- image_gallery,
- ],
- scale=1,
- )
+ submit_btn = gr.Button("Start Reconstruction", variant="primary", scale=2)
+ clear_btn = gr.ClearButton([unified_upload, image_gallery, target_dir_output], value="Clear", scale=1)
+
+ # Output Area
+ with gr.Column(scale=2, min_width=600):
+ gr.Markdown("### Output")
+
+ with gr.Tabs():
+ with gr.Tab("Raw 3D"):
+ raw_3d_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5, clear_color=[0.0, 0.0, 0.0, 0.0])
+ with gr.Tab("3D View"):
+ reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5, clear_color=[0.0, 0.0, 0.0, 0.0])
+ with gr.Tab("Depth"):
+ with gr.Row(elem_classes=["navigation-row"]):
+ prev_depth_btn = gr.Button("Previous", size="sm", scale=1)
+ depth_view_selector = gr.Dropdown(choices=["View 1"], value="View 1", label="Select View", scale=2, interactive=True, allow_custom_value=True)
+ next_depth_btn = gr.Button("Next", size="sm", scale=1)
+ depth_map = gr.Image(type="numpy", label="Colorized Depth Map", format="png", interactive=False)
+ with gr.Tab("Normal"):
+ with gr.Row(elem_classes=["navigation-row"]):
+ prev_normal_btn = gr.Button("Previous", size="sm", scale=1)
+ normal_view_selector = gr.Dropdown(choices=["View 1"], value="View 1", label="Select View", scale=2, interactive=True, allow_custom_value=True)
+ next_normal_btn = gr.Button("Next", size="sm", scale=1)
+ normal_map = gr.Image(type="numpy", label="Normal Map", format="png", interactive=False)
+ with gr.Tab("Measure"):
+ gr.Markdown(MEASURE_INSTRUCTIONS_HTML)
+ with gr.Row(elem_classes=["navigation-row"]):
+ prev_measure_btn = gr.Button("Previous", size="sm", scale=1)
+ measure_view_selector = gr.Dropdown(choices=["View 1"], value="View 1", label="Select View", scale=2, interactive=True, allow_custom_value=True)
+ next_measure_btn = gr.Button("Next", size="sm", scale=1)
+ measure_image = gr.Image(type="numpy", show_label=False, format="webp", interactive=False, sources=[])
+ gr.Markdown("**Note:** Light-grey areas indicate regions with no depth information where measurements cannot be taken.")
+ measure_text = gr.Markdown("")
+
+ log_output = gr.Textbox(
+ value="Please upload images or video, then click 'Start Reconstruction'",
+ label="Status Information", interactive=False, lines=1, max_lines=1
+ )
- with gr.Row():
- frame_filter = gr.Dropdown(
- choices=["All"], value="All", label="Show Points from Frame"
+ with gr.Accordion("Advanced Options", open=False):
+ with gr.Row(equal_height=False):
+ with gr.Column(scale=1, min_width=300):
+ gr.Markdown("#### Visualization Parameters")
+ frame_filter = gr.Dropdown(choices=["All"], value="All", label="Show Points from Frame")
+ show_cam = gr.Checkbox(label="Show Camera", value=True)
+ show_mesh = gr.Checkbox(label="Show Mesh", value=True)
+ filter_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
+ filter_white_bg = gr.Checkbox(label="Filter White Background", value=False)
+
+ with gr.Column(scale=1, min_width=300):
+ gr.Markdown("#### Reconstruction Parameters")
+ apply_mask_checkbox = gr.Checkbox(label="Apply mask for predicted ambiguous depth classes & edges", value=True)
+
+ gr.Markdown("#### Segmentation Parameters")
+ enable_segmentation = gr.Checkbox(label="Enable Semantic Segmentation", value=False)
+ use_sam_checkbox = gr.Checkbox(label="Use SAM for Accurate Segmentation", value=True)
+ text_prompt = gr.Textbox(
+ value=DEFAULT_TEXT_PROMPT,
+ label="Detect Objects (separated by .)",
+ placeholder="e.g. chair . table . sofa", lines=2, max_lines=2
)
- with gr.Column():
- gr.Markdown("### Pointcloud Options: (live updates)")
- show_cam = gr.Checkbox(label="Show Camera", value=True)
- show_mesh = gr.Checkbox(label="Show Mesh", value=True)
- filter_black_bg = gr.Checkbox(
- label="Filter Black Background", value=False
- )
- filter_white_bg = gr.Checkbox(
- label="Filter White Background", value=False
- )
- gr.Markdown("### Reconstruction Options: (updated on next run)")
- apply_mask_checkbox = gr.Checkbox(
- label="Apply mask for predicted ambiguous depth classes & edges",
- value=True,
- )
- # ---------------------- Example Scenes Section ----------------------
- gr.Markdown("## Example Scenes (lists all scenes in the examples folder)")
- gr.Markdown("Click any thumbnail to load the scene for reconstruction.")
-
- # Get scene information
- scenes = get_scene_info("examples")
-
- # Create thumbnail grid (4 columns, N rows)
- if scenes:
- for i in range(0, len(scenes), 4): # Process 4 scenes per row
- with gr.Row():
- for j in range(4):
- scene_idx = i + j
- if scene_idx < len(scenes):
- scene = scenes[scene_idx]
- with gr.Column(scale=1, elem_classes=["clickable-thumbnail"]):
- # Clickable thumbnail
- scene_img = gr.Image(
- value=scene["thumbnail"],
- height=150,
- interactive=False,
- show_label=False,
- elem_id=f"scene_thumb_{scene['name']}",
- sources=[],
- )
-
- # Scene name and image count as text below thumbnail
- gr.Markdown(
- f"**{scene['name']}** \n {scene['num_images']} images",
- elem_classes=["scene-info"],
- )
-
- # Connect thumbnail click to load scene
- scene_img.select(
- fn=lambda name=scene["name"]: load_example_scene(name),
- outputs=[
- reconstruction_output,
- target_dir_output,
- image_gallery,
- log_output,
- ],
- )
- else:
- # Empty column to maintain grid structure
- with gr.Column(scale=1):
- pass
-
- # -------------------------------------------------------------------------
- # "Reconstruct" button logic:
- # - Clear fields
- # - Update log
- # - gradio_demo(...) with the existing target_dir
- # - Then set is_example = "False"
- # -------------------------------------------------------------------------
- submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
- fn=update_log, inputs=[], outputs=[log_output]
- ).then(
- fn=gradio_demo,
- inputs=[
- target_dir_output,
- frame_filter,
- show_cam,
- filter_black_bg,
- filter_white_bg,
- apply_mask_checkbox,
- show_mesh,
- ],
- outputs=[
- reconstruction_output,
- log_output,
- frame_filter,
- processed_data_state,
- depth_map,
- normal_map,
- measure_image,
- measure_text,
- depth_view_selector,
- normal_view_selector,
- measure_view_selector,
- ],
- ).then(
- fn=lambda: "False",
- inputs=[],
- outputs=[is_example], # set is_example to "False"
- )
-
- # -------------------------------------------------------------------------
- # Real-time Visualization Updates
- # -------------------------------------------------------------------------
- frame_filter.change(
- update_visualization,
- [
- target_dir_output,
- frame_filter,
- show_cam,
- is_example,
- filter_black_bg,
- filter_white_bg,
- show_mesh,
- ],
- [reconstruction_output, log_output],
- )
- show_cam.change(
- update_visualization,
- [
- target_dir_output,
- frame_filter,
- show_cam,
- is_example,
- ],
- [reconstruction_output, log_output],
- )
- filter_black_bg.change(
- update_visualization,
- [
- target_dir_output,
- frame_filter,
- show_cam,
- is_example,
- filter_black_bg,
- filter_white_bg,
- ],
- [reconstruction_output, log_output],
- ).then(
- fn=update_all_views_on_filter_change,
- inputs=[
- target_dir_output,
- filter_black_bg,
- filter_white_bg,
- processed_data_state,
- depth_view_selector,
- normal_view_selector,
- measure_view_selector,
- ],
- outputs=[
- processed_data_state,
- depth_map,
- normal_map,
- measure_image,
- measure_points_state,
- ],
- )
- filter_white_bg.change(
- update_visualization,
- [
- target_dir_output,
- frame_filter,
- show_cam,
- is_example,
- filter_black_bg,
- filter_white_bg,
- show_mesh,
- ],
- [reconstruction_output, log_output],
- ).then(
- fn=update_all_views_on_filter_change,
- inputs=[
- target_dir_output,
- filter_black_bg,
- filter_white_bg,
- processed_data_state,
- depth_view_selector,
- normal_view_selector,
- measure_view_selector,
- ],
- outputs=[
- processed_data_state,
- depth_map,
- normal_map,
- measure_image,
- measure_points_state,
- ],
- )
-
- show_mesh.change(
- update_visualization,
- [
- target_dir_output,
- frame_filter,
- show_cam,
- is_example,
- filter_black_bg,
- filter_white_bg,
- show_mesh,
- ],
- [reconstruction_output, log_output],
- )
-
- # -------------------------------------------------------------------------
- # Auto-update gallery whenever user uploads or changes their files
- # -------------------------------------------------------------------------
- def update_gallery_on_unified_upload(files, interval):
- if not files:
- return None, None, None
- target_dir, image_paths = handle_uploads(files, interval)
- return (
- target_dir,
- image_paths,
- "Upload complete. Click 'Reconstruct' to begin 3D processing.",
- )
+ with gr.Row():
+ detect_all_btn = gr.Button("Detect All", size="sm")
+ restore_default_btn = gr.Button("Default", size="sm")
+
+ with gr.Accordion("Example Scenes", open=False):
+ scenes = get_scene_info("examples")
+ if scenes:
+ for i in range(0, len(scenes), 4):
+ with gr.Row(equal_height=True):
+ for j in range(4):
+ scene_idx = i + j
+ if scene_idx < len(scenes):
+ scene = scenes[scene_idx]
+ with gr.Column(scale=1, min_width=150):
+ scene_img = gr.Image(value=scene["thumbnail"], height=150, interactive=False, show_label=False, sources=[], container=False)
+ gr.Markdown(f"**{scene['name']}** ({scene['num_images']} images)", elem_classes=["text-center"])
+ scene_img.select(
+ fn=lambda name=scene["name"]: load_example_scene(name),
+ outputs=[reconstruction_output, target_dir_output, image_gallery, log_output]
+ )
+
+ # Connect UI Components
+ detect_all_btn.click(fn=lambda: COMMON_OBJECTS_PROMPT, outputs=[text_prompt])
+ restore_default_btn.click(fn=lambda: DEFAULT_TEXT_PROMPT, outputs=[text_prompt])
def show_resample_button(files):
- """Show the resample button only if there are uploaded files containing videos"""
- if not files:
- return gr.update(visible=False)
-
- # Check if any uploaded files are videos
- video_extensions = [
- ".mp4",
- ".avi",
- ".mov",
- ".mkv",
- ".wmv",
- ".flv",
- ".webm",
- ".m4v",
- ".3gp",
- ]
- has_video = False
-
- for file_data in files:
- if isinstance(file_data, dict) and "name" in file_data:
- file_path = file_data["name"]
- else:
- file_path = str(file_data)
-
- file_ext = os.path.splitext(file_path)[1].lower()
- if file_ext in video_extensions:
- has_video = True
- break
-
+ if not files: return gr.update(visible=False)
+ video_extensions = [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"]
+ has_video = any(os.path.splitext(str(f["name"] if isinstance(f, dict) else f))[1].lower() in video_extensions for f in files)
return gr.update(visible=has_video)
- def hide_resample_button():
- """Hide the resample button after use"""
- return gr.update(visible=False)
-
- def resample_video_with_new_interval(files, new_interval, current_target_dir):
- """Resample video with new slider value"""
- if not files:
- return (
- current_target_dir,
- None,
- "No files to resample.",
- gr.update(visible=False),
- )
-
- # Check if we have videos to resample
- video_extensions = [
- ".mp4",
- ".avi",
- ".mov",
- ".mkv",
- ".wmv",
- ".flv",
- ".webm",
- ".m4v",
- ".3gp",
- ]
- has_video = any(
- os.path.splitext(
- str(file_data["name"] if isinstance(file_data, dict) else file_data)
- )[1].lower()
- in video_extensions
- for file_data in files
- )
-
- if not has_video:
- return (
- current_target_dir,
- None,
- "No videos found to resample.",
- gr.update(visible=False),
- )
-
- # Clean up old target directory if it exists
- if (
- current_target_dir
- and current_target_dir != "None"
- and os.path.exists(current_target_dir)
- ):
- shutil.rmtree(current_target_dir)
-
- # Process files with new interval
- target_dir, image_paths = handle_uploads(files, new_interval)
-
- return (
- target_dir,
- image_paths,
- f"Video resampled with {new_interval}s interval. Click 'Reconstruct' to begin 3D processing.",
- gr.update(visible=False),
- )
-
unified_upload.change(
- fn=update_gallery_on_unified_upload,
+ fn=update_gallery_on_upload,
inputs=[unified_upload, s_time_interval],
- outputs=[target_dir_output, image_gallery, log_output],
+ outputs=[raw_3d_output, target_dir_output, image_gallery, log_output]
).then(
- fn=show_resample_button,
- inputs=[unified_upload],
- outputs=[resample_btn],
+ fn=show_resample_button, inputs=[unified_upload], outputs=[resample_btn]
)
+ s_time_interval.change(fn=show_resample_button, inputs=[unified_upload], outputs=[resample_btn])
- # Show resample button when slider changes (only if files are uploaded)
- s_time_interval.change(
- fn=show_resample_button,
- inputs=[unified_upload],
- outputs=[resample_btn],
- )
+ def resample_video_with_new_interval(files, new_interval, current_target_dir):  # returns (target dir, gallery image paths, status text, resample_btn update), matching the resample_btn.click outputs
+ if not files: return current_target_dir, None, "No files.", gr.update(visible=False)
+ if current_target_dir and current_target_dir != "None" and os.path.isdir(current_target_dir): shutil.rmtree(current_target_dir)  # isdir rather than exists: rmtree raises on a regular file. NOTE(review): this path is read from the target_dir_output Textbox - confirm it can only name a directory created by handle_uploads
+ target_dir, image_paths = handle_uploads(files, new_interval)
+ return target_dir, image_paths, f"Video resampled with {new_interval}s interval.", gr.update(visible=False)
- # Handle resample button click
resample_btn.click(
fn=resample_video_with_new_interval,
inputs=[unified_upload, s_time_interval, target_dir_output],
- outputs=[target_dir_output, image_gallery, log_output, resample_btn],
+ outputs=[target_dir_output, image_gallery, log_output, resample_btn]
)
- # -------------------------------------------------------------------------
- # Measure tab functionality
- # -------------------------------------------------------------------------
- measure_image.select(
- fn=measure,
- inputs=[processed_data_state, measure_points_state, measure_view_selector],
- outputs=[measure_image, measure_points_state, measure_text],
- )
-
- # -------------------------------------------------------------------------
- # Navigation functionality for Depth, Normal, and Measure tabs
- # -------------------------------------------------------------------------
-
- # Depth tab navigation
- prev_depth_btn.click(
- fn=lambda processed_data, current_selector: navigate_depth_view(
- processed_data, current_selector, -1
- ),
- inputs=[processed_data_state, depth_view_selector],
- outputs=[depth_view_selector, depth_map],
- )
-
- next_depth_btn.click(
- fn=lambda processed_data, current_selector: navigate_depth_view(
- processed_data, current_selector, 1
- ),
- inputs=[processed_data_state, depth_view_selector],
- outputs=[depth_view_selector, depth_map],
+ submit_btn.click(
+ fn=clear_fields,
+ outputs=[raw_3d_output, reconstruction_output]
+ ).then(
+ fn=update_log, outputs=[log_output]
+ ).then(
+ fn=gradio_demo,
+ inputs=[
+ target_dir_output, frame_filter, show_cam, filter_black_bg, filter_white_bg,
+ apply_mask_checkbox, show_mesh, enable_segmentation, text_prompt, use_sam_checkbox
+ ],
+ outputs=[
+ raw_3d_output, reconstruction_output, log_output, frame_filter,
+ processed_data_state, depth_map, normal_map, measure_image, measure_text,
+ depth_view_selector, normal_view_selector, measure_view_selector
+ ]
+ ).then(
+ fn=lambda: "False", outputs=[is_example]
)
- depth_view_selector.change(
- fn=lambda processed_data, selector_value: (
- update_depth_view(
- processed_data,
- int(selector_value.split()[1]) - 1,
- )
- if selector_value
- else None
- ),
- inputs=[processed_data_state, depth_view_selector],
- outputs=[depth_map],
- )
+ clear_btn.add([raw_3d_output, reconstruction_output, log_output])
- # Normal tab navigation
- prev_normal_btn.click(
- fn=lambda processed_data, current_selector: navigate_normal_view(
- processed_data, current_selector, -1
- ),
- inputs=[processed_data_state, normal_view_selector],
- outputs=[normal_view_selector, normal_map],
- )
+ for comp in [frame_filter, show_cam, show_mesh]:
+ comp.change(
+ fn=update_visualization,
+ inputs=[target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh],
+ outputs=[raw_3d_output, reconstruction_output, log_output]
+ )
- next_normal_btn.click(
- fn=lambda processed_data, current_selector: navigate_normal_view(
- processed_data, current_selector, 1
- ),
- inputs=[processed_data_state, normal_view_selector],
- outputs=[normal_view_selector, normal_map],
- )
+ for comp in [filter_black_bg, filter_white_bg]:
+ comp.change(
+ fn=update_visualization,
+ inputs=[target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh],
+ outputs=[raw_3d_output, reconstruction_output, log_output]
+ ).then(
+ fn=update_all_views_on_filter_change,
+ inputs=[target_dir_output, filter_black_bg, filter_white_bg, processed_data_state, depth_view_selector, normal_view_selector, measure_view_selector],
+ outputs=[processed_data_state, depth_map, normal_map, measure_image, measure_points_state]
+ )
- normal_view_selector.change(
- fn=lambda processed_data, selector_value: (
- update_normal_view(
- processed_data,
- int(selector_value.split()[1]) - 1,
- )
- if selector_value
- else None
- ),
- inputs=[processed_data_state, normal_view_selector],
- outputs=[normal_map],
- )
+ measure_image.select(fn=measure, inputs=[processed_data_state, measure_points_state, measure_view_selector], outputs=[measure_image, measure_points_state, measure_text])
- # Measure tab navigation
- prev_measure_btn.click(
- fn=lambda processed_data, current_selector: navigate_measure_view(
- processed_data, current_selector, -1
- ),
- inputs=[processed_data_state, measure_view_selector],
- outputs=[measure_view_selector, measure_image, measure_points_state],
- )
+ prev_depth_btn.click(fn=lambda d, s: navigate_depth_view(d, s, -1), inputs=[processed_data_state, depth_view_selector], outputs=[depth_view_selector, depth_map])
+ next_depth_btn.click(fn=lambda d, s: navigate_depth_view(d, s, 1), inputs=[processed_data_state, depth_view_selector], outputs=[depth_view_selector, depth_map])
+ depth_view_selector.change(fn=lambda d, v: update_depth_view(d, int(v.split()[1])-1) if v else None, inputs=[processed_data_state, depth_view_selector], outputs=[depth_map])
- next_measure_btn.click(
- fn=lambda processed_data, current_selector: navigate_measure_view(
- processed_data, current_selector, 1
- ),
- inputs=[processed_data_state, measure_view_selector],
- outputs=[measure_view_selector, measure_image, measure_points_state],
- )
+ prev_normal_btn.click(fn=lambda d, s: navigate_normal_view(d, s, -1), inputs=[processed_data_state, normal_view_selector], outputs=[normal_view_selector, normal_map])
+ next_normal_btn.click(fn=lambda d, s: navigate_normal_view(d, s, 1), inputs=[processed_data_state, normal_view_selector], outputs=[normal_view_selector, normal_map])
+ normal_view_selector.change(fn=lambda d, v: update_normal_view(d, int(v.split()[1])-1) if v else None, inputs=[processed_data_state, normal_view_selector], outputs=[normal_map])
- measure_view_selector.change(
- fn=lambda processed_data, selector_value: (
- update_measure_view(processed_data, int(selector_value.split()[1]) - 1)
- if selector_value
- else (None, [])
- ),
- inputs=[processed_data_state, measure_view_selector],
- outputs=[measure_image, measure_points_state],
- )
+ prev_measure_btn.click(fn=lambda d, s: navigate_measure_view(d, s, -1), inputs=[processed_data_state, measure_view_selector], outputs=[measure_view_selector, measure_image, measure_points_state])
+ next_measure_btn.click(fn=lambda d, s: navigate_measure_view(d, s, 1), inputs=[processed_data_state, measure_view_selector], outputs=[measure_view_selector, measure_image, measure_points_state])
+ measure_view_selector.change(fn=lambda d, v: update_measure_view(d, int(v.split()[1])-1) if v else (None, []), inputs=[processed_data_state, measure_view_selector], outputs=[measure_image, measure_points_state])
- # -------------------------------------------------------------------------
- # Acknowledgement section
- # -------------------------------------------------------------------------
gr.HTML(get_acknowledgements_html())
- demo.queue(max_size=20).launch(theme=theme, css=GRADIO_CSS, show_error=True, share=True, ssr_mode=False)
\ No newline at end of file
+if __name__ == "__main__":
+ demo.queue(max_size=20).launch(theme=theme, css=CUSTOM_CSS, show_error=True, share=True, ssr_mode=False)
\ No newline at end of file