diff --git "a/app.py" "b/app.py" --- "a/app.py" +++ "b/app.py" @@ -1,9 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + import gc import os import shutil import sys import time from datetime import datetime +from pathlib import Path +from collections import defaultdict +from typing import List, Dict, Tuple os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" @@ -12,11 +20,12 @@ import gradio as gr import numpy as np import spaces import torch +import trimesh from PIL import Image from pillow_heif import register_heif_opener +from sklearn.cluster import DBSCAN register_heif_opener() - sys.path.append("mapanything/") from mapanything.utils.geometry import depthmap_to_world_frame, points_to_normals @@ -27,10 +36,14 @@ from mapanything.utils.hf_utils.css_and_html import ( get_gradio_theme, ) from mapanything.utils.hf_utils.hf_helpers import initialize_mapanything_model -from mapanything.utils.hf_utils.viz import predictions_to_glb +from mapanything.utils.hf_utils.visual_util import predictions_to_glb from mapanything.utils.image import load_images, rgb +# ============================================================================ +# Global Configuration +# ============================================================================ + # MapAnything Configuration high_level_config = { "path": "configs/train.yaml", @@ -51,13 +64,616 @@ high_level_config = { "resolution": 518, } -# Initialize model - this will be done on GPU when needed +# GroundingDINO Configuration +GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny" +GROUNDING_DINO_BOX_THRESHOLD = 0.25 +GROUNDING_DINO_TEXT_THRESHOLD = 0.2 + +# SAM Configuration +SAM_MODEL_ID = "facebook/sam-vit-huge" + +DEFAULT_TEXT_PROMPT = "window . table . sofa . tv . book . door" + +COMMON_OBJECTS_PROMPT = ( + "person . face . hand . " + "chair . sofa . couch . bed . table . desk . cabinet . shelf . drawer . " + "door . window . wall . floor . ceiling . curtain . " + "tv . monitor . screen . computer . laptop . keyboard . mouse . " + "phone . tablet . remote . " + "lamp . light . chandelier . " + "book . magazine . paper . pen . pencil . " + "bottle . cup . glass . mug . plate . bowl . fork . knife . spoon . " + "vase . plant . flower . pot . " + "clock . picture . frame . mirror . " + "pillow . cushion . blanket . towel . " + "bag . backpack . suitcase . " + "box . basket . container . " + "shoe . hat . coat . " + "toy . ball . " + "car . bicycle . motorcycle . bus . truck . " + "tree . grass . sky . cloud . sun . " + "dog . cat . bird . " + "building . house . bridge . road . street . " + "sign . pole . bench" +) + +# DBSCAN Clustering Configuration +DBSCAN_EPS_CONFIG = { + 'sofa': 1.5, + 'bed': 1.5, + 'couch': 1.5, + 'desk': 0.8, + 'table': 0.8, + 'chair': 0.6, + 'cabinet': 0.8, + 'window': 0.5, + 'door': 0.6, + 'tv': 0.6, + 'default': 1.0 +} + +DBSCAN_MIN_SAMPLES = 1 +ENABLE_VISUAL_FEATURES = False + +# Segmentation Quality Control +MIN_DETECTION_CONFIDENCE = 0.35 +MIN_MASK_AREA = 100 + +# Matching score calculation configuration +MATCH_3D_DISTANCE_THRESHOLD = 2.5 + +# Global model variables model = None +grounding_dino_model = None +grounding_dino_processor = None +sam_predictor = None + + +# ============================================================================ +# Segmentation Model Loading +# ============================================================================ + +def load_grounding_dino_model(device): + global grounding_dino_model, grounding_dino_processor + if grounding_dino_model is not None: + print("GroundingDINO already loaded") + return + try: + from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection + print(f"Loading GroundingDINO from HuggingFace: {GROUNDING_DINO_MODEL_ID}") + grounding_dino_processor = AutoProcessor.from_pretrained(GROUNDING_DINO_MODEL_ID) + grounding_dino_model = AutoModelForZeroShotObjectDetection.from_pretrained( + GROUNDING_DINO_MODEL_ID + ).to(device).eval() + print("GroundingDINO loaded successfully") + except Exception as e: + print(f"Failed to load GroundingDINO: {e}") + import traceback + traceback.print_exc() + +def load_sam_model(device): + global sam_predictor + if sam_predictor is not None: + print("SAM already loaded") + return + try: + from transformers import SamModel, SamProcessor + print(f"Loading SAM from HuggingFace: {SAM_MODEL_ID}") + sam_model = SamModel.from_pretrained(SAM_MODEL_ID).to(device).eval() + sam_processor = SamProcessor.from_pretrained(SAM_MODEL_ID) + sam_predictor = {'model': sam_model, 'processor': sam_processor} + print("SAM loaded successfully") + except Exception as e: + print(f"Failed to load SAM: {e}") + print("SAM functionality will be disabled, bounding boxes will be used as masks instead") + import traceback + traceback.print_exc() + + +# ============================================================================ +# Segmentation Functions +# ============================================================================ + +def generate_distinct_colors(n): + import colorsys + if n == 0: + return [] + colors = [] + for i in range(n): + hue = i / max(n, 1) + rgb = colorsys.hsv_to_rgb(hue, 0.9, 0.95) + rgb_color = tuple(int(c * 255) for c in rgb) + colors.append(rgb_color) + return colors + +def run_grounding_dino_detection(image_np, text_prompt, device): + if grounding_dino_model is None or grounding_dino_processor is None: + print("GroundingDINO is not loaded") + return [] + try: + print(f"GroundingDINO Detection: {text_prompt}") + if image_np.dtype == np.uint8: + pil_image = Image.fromarray(image_np) + else: + pil_image = Image.fromarray((image_np * 255).astype(np.uint8)) + + inputs = grounding_dino_processor(images=pil_image, text=text_prompt, return_tensors="pt") + inputs = {k: v.to(device) for k, v in inputs.items()} + + with torch.no_grad(): + outputs = grounding_dino_model(**inputs) + + results = grounding_dino_processor.post_process_grounded_object_detection( + outputs, + inputs["input_ids"], + threshold=GROUNDING_DINO_BOX_THRESHOLD, + text_threshold=GROUNDING_DINO_TEXT_THRESHOLD, + target_sizes=[pil_image.size[::-1]] + )[0] + + detections = [] + boxes = results["boxes"].cpu().numpy() + scores = results["scores"].cpu().numpy() + labels = results["labels"] + + print(f"Detected {len(boxes)} objects") + for box, score, label in zip(boxes, scores, labels): + detection = { + 'bbox': box.tolist(), + 'label': label, + 'confidence': float(score) + } + detections.append(detection) + print(f" - {label}: {score:.2f}") + return detections + except Exception as e: + print(f"GroundingDINO detection failed: {e}") + import traceback + traceback.print_exc() + return [] + +def run_sam_refinement(image_np, boxes): + if sam_predictor is None: + print("SAM is not loaded, using bbox as mask") + masks = [] + h, w = image_np.shape[:2] + for box in boxes: + x1, y1, x2, y2 = map(int, box) + mask = np.zeros((h, w), dtype=bool) + mask[y1:y2, x1:x2] = True + masks.append(mask) + return masks + try: + print(f"SAM accurate segmentation for {len(boxes)} regions...") + from PIL import Image + sam_model = sam_predictor['model'] + sam_processor = sam_predictor['processor'] + device = sam_model.device + + if image_np.dtype == np.uint8: + pil_image = Image.fromarray(image_np) + else: + pil_image = Image.fromarray((image_np * 255).astype(np.uint8)) + + masks = [] + for box in boxes: + x1, y1, x2, y2 = map(int, box) + input_boxes = [[[x1, y1, x2, y2]]] + + inputs = sam_processor(pil_image, input_boxes=input_boxes, return_tensors="pt") + inputs = {k: v.to(device) for k, v in inputs.items()} + + with torch.no_grad(): + outputs = sam_model(**inputs) + + pred_masks = sam_processor.image_processor.post_process_masks( + outputs.pred_masks.cpu(), + inputs["original_sizes"].cpu(), + inputs["reshaped_input_sizes"].cpu() + )[0][0][0] + + masks.append(pred_masks.numpy() > 0.5) + print("SAM segmentation completed") + return masks + except Exception as e: + print(f"SAM segmentation failed: {e}") + import traceback + traceback.print_exc() + masks = [] + h, w = image_np.shape[:2] + for box in boxes: + x1, y1, x2, y2 = map(int, box) + mask = np.zeros((h, w), dtype=bool) + mask[y1:y2, x1:x2] = True + masks.append(mask) + return masks + +def normalize_label(label): + label = label.strip().lower() + priority_labels = ['sofa', 'bed', 'table', 'desk', 'chair', 'cabinet', 'window', 'door'] + for priority in priority_labels: + if priority in label: + return priority + first_word = label.split()[0] if label else label + if first_word.endswith('s') and len(first_word) > 1: + singular = first_word[:-1] + if first_word.endswith('sses'): + singular = first_word[:-2] + elif first_word.endswith('ies'): + singular = first_word[:-3] + 'y' + elif first_word.endswith('ves'): + singular = first_word[:-3] + 'f' + return singular + return first_word + +def labels_match(label1, label2): + return normalize_label(label1) == normalize_label(label2) + +def compute_object_3d_center(points, mask): + masked_points = points[mask] + if len(masked_points) == 0: + return None + return np.median(masked_points, axis=0) + +def compute_3d_bbox_iou(center1, size1, center2, size2): + try: + min1 = center1 - size1 / 2 + max1 = center1 + size1 / 2 + min2 = center2 - size2 / 2 + max2 = center2 + size2 / 2 + + inter_min = np.maximum(min1, min2) + inter_max = np.minimum(max1, max2) + inter_size = np.maximum(0, inter_max - inter_min) + inter_volume = np.prod(inter_size) + + volume1 = np.prod(size1) + volume2 = np.prod(size2) + union_volume = volume1 + volume2 - inter_volume + + if union_volume == 0: + return 0.0 + return inter_volume / union_volume + except: + return 0.0 + +def compute_2d_mask_iou(mask1, mask2): + try: + intersection = np.logical_and(mask1, mask2).sum() + union = np.logical_or(mask1, mask2).sum() + if union == 0: + return 0.0 + return intersection / union + except: + return 0.0 + +def extract_visual_features(image, mask, encoder): + try: + coords = np.argwhere(mask) + if len(coords) == 0: + return None + y_min, x_min = coords.min(axis=0) + y_max, x_max = coords.max(axis=0) + + if y_max <= y_min or x_max <= x_min: + return None + + cropped = image[y_min:y_max+1, x_min:x_max+1] + if cropped.dtype == np.float32 or cropped.dtype == np.float64: + if cropped.max() <= 1.0: + cropped = (cropped * 255).astype(np.uint8) + else: + cropped = cropped.astype(np.uint8) + + from PIL import Image + import torchvision.transforms as T + + pil_img = Image.fromarray(cropped) + pil_img = pil_img.resize((224, 224), Image.BILINEAR) + + transform = T.Compose([ + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + try: + device = next(encoder.parameters()).device + except: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + img_tensor = transform(pil_img).unsqueeze(0).to(device) + + with torch.no_grad(): + if hasattr(encoder, 'forward_features'): + features = encoder.forward_features(img_tensor) + else: + features = encoder(img_tensor) + + if not isinstance(features, torch.Tensor): + if isinstance(features, dict): + features = features.get('x', features.get('last_hidden_state', None)) + if features is None: return None + elif hasattr(features, 'data'): + features = features.data + else: + return None + + if len(features.shape) == 4: + features = features.mean(dim=[2, 3]) + elif len(features.shape) == 3: + features = features.mean(dim=1) + + features = features / (features.norm(dim=1, keepdim=True) + 1e-8) + return features.cpu().numpy()[0] + except Exception as e: + import traceback + print(f"Feature extraction failed: {type(e).__name__}: {e}") + return None + +def compute_feature_similarity(feat1, feat2): + if feat1 is None or feat2 is None: + return 0.0 + try: + return np.dot(feat1, feat2) + except: + return 0.0 + +def compute_adaptive_eps(centers, base_eps): + if len(centers) <= 1: + return base_eps + from scipy.spatial.distance import pdist + distances = pdist(centers) + if len(distances) == 0: + return base_eps + median_dist = np.median(distances) + if median_dist > base_eps * 2: + adaptive_eps = min(median_dist * 0.6, base_eps * 2.5) + elif median_dist > base_eps: + adaptive_eps = median_dist * 0.5 + else: + adaptive_eps = base_eps + return adaptive_eps + +def match_objects_across_views(all_view_detections): + print("\nMatching objects using adaptive DBSCAN clustering...") + objects_by_label = defaultdict(list) + + for view_idx, detections in enumerate(all_view_detections): + for det_idx, det in enumerate(detections): + if det.get('center_3d') is None: + continue + norm_label = normalize_label(det['label']) + objects_by_label[norm_label].append({ + 'view_idx': view_idx, + 'det_idx': det_idx, + 'label': det['label'], + 'norm_label': norm_label, + 'center_3d': det['center_3d'], + 'confidence': det['confidence'], + 'bbox_3d': det.get('bbox_3d'), + }) + + if len(objects_by_label) == 0: + return {}, [] + + object_id_map = defaultdict(dict) + unique_objects = [] + next_global_id = 0 + + for norm_label, objects in objects_by_label.items(): + print(f"Processing {norm_label}: {len(objects)} detections") + if len(objects) == 1: + obj = objects[0] + unique_objects.append({ + 'global_id': next_global_id, + 'label': obj['label'], + 'views': [(obj['view_idx'], obj['det_idx'])], + 'center_3d': obj['center_3d'], + }) + object_id_map[obj['view_idx']][obj['det_idx']] = next_global_id + next_global_id += 1 + continue + + centers = np.array([obj['center_3d'] for obj in objects]) + base_eps = DBSCAN_EPS_CONFIG.get(norm_label, DBSCAN_EPS_CONFIG.get('default', 1.0)) + eps = compute_adaptive_eps(centers, base_eps) + + clustering = DBSCAN(eps=eps, min_samples=DBSCAN_MIN_SAMPLES, metric='euclidean') + cluster_labels = clustering.fit_predict(centers) + + for cluster_id in set(cluster_labels): + if cluster_id == -1: + for i, label in enumerate(cluster_labels): + if label == -1: + obj = objects[i] + unique_objects.append({ + 'global_id': next_global_id, + 'label': obj['label'], + 'views': [(obj['view_idx'], obj['det_idx'])], + 'center_3d': obj['center_3d'], + }) + object_id_map[obj['view_idx']][obj['det_idx']] = next_global_id + next_global_id += 1 + else: + cluster_objects = [objects[i] for i, label in enumerate(cluster_labels) if label == cluster_id] + total_conf = sum(o['confidence'] for o in cluster_objects) + weighted_center = sum(o['center_3d'] * o['confidence'] for o in cluster_objects) / total_conf + + unique_objects.append({ + 'global_id': next_global_id, + 'label': cluster_objects[0]['label'], + 'views': [(o['view_idx'], o['det_idx']) for o in cluster_objects], + 'center_3d': weighted_center, + }) + for obj in cluster_objects: + object_id_map[obj['view_idx']][obj['det_idx']] = next_global_id + next_global_id += 1 + + return object_id_map, unique_objects + +def create_multi_view_segmented_mesh(processed_data, all_view_detections, all_view_masks, + object_id_map, unique_objects, target_dir, use_sam=True): + try: + print("\nGenerating multi-view segmented mesh...") + unique_normalized_labels = sorted(set(normalize_label(obj['label']) for obj in unique_objects)) + label_colors = {} + colors = generate_distinct_colors(len(unique_normalized_labels)) + + for i, norm_label in enumerate(unique_normalized_labels): + label_colors[norm_label] = colors[i] + + for obj in unique_objects: + norm_label = normalize_label(obj['label']) + obj['color'] = label_colors[norm_label] + obj['normalized_label'] = norm_label + + import utils3d + all_meshes = [] + + for view_idx in range(len(processed_data)): + view_data = processed_data[view_idx] + image = view_data["image"] + points3d = view_data["points3d"] + mask = view_data.get("mask") + normal = view_data.get("normal") + + detections = all_view_detections[view_idx] + masks = all_view_masks[view_idx] + + if len(detections) == 0: + continue + + if image.dtype != np.uint8: + if image.max() <= 1.0: + image = (image * 255).astype(np.uint8) + else: + image = image.astype(np.uint8) + + colored_image = image.copy() + confidence_map = np.zeros((image.shape[0], image.shape[1]), dtype=np.float32) + + detections_info = [] + for det_idx, (det, seg_mask) in enumerate(zip(detections, masks)): + if det['confidence'] < MIN_DETECTION_CONFIDENCE: + continue + mask_area = seg_mask.sum() + if mask_area < MIN_MASK_AREA: + continue + global_id = object_id_map[view_idx].get(det_idx) + if global_id is None: + continue + unique_obj = next((obj for obj in unique_objects if obj['global_id'] == global_id), None) + if unique_obj is None: + continue + detections_info.append({ + 'mask': seg_mask, + 'color': unique_obj['color'], + 'confidence': det['confidence'], + 'label': det['label'], + 'area': mask_area + }) + + detections_info.sort(key=lambda x: x['confidence']) + for info in detections_info: + seg_mask = info['mask'] + color = info['color'] + conf = info['confidence'] + update_mask = seg_mask & (conf > confidence_map) + colored_image[update_mask] = color + confidence_map[update_mask] = conf + + height, width = image.shape[:2] + if normal is None: + faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh( + points3d, + colored_image.astype(np.float32) / 255, + utils3d.numpy.image_uv(width=width, height=height), + mask=mask if mask is not None else np.ones((height, width), dtype=bool), + tri=True + ) + vertex_normals = None + else: + faces, vertices, vertex_colors, vertex_uvs, vertex_normals = utils3d.numpy.image_mesh( + points3d, + colored_image.astype(np.float32) / 255, + utils3d.numpy.image_uv(width=width, height=height), + normal, + mask=mask if mask is not None else np.ones((height, width), dtype=bool), + tri=True + ) + + vertices = vertices * np.array([1, -1, -1], dtype=np.float32) + if vertex_normals is not None: + vertex_normals = vertex_normals * np.array([1, -1, -1], dtype=np.float32) + + view_mesh = trimesh.Trimesh( + vertices=vertices, + faces=faces, + vertex_normals=vertex_normals, + vertex_colors=(vertex_colors * 255).astype(np.uint8), + process=False + ) + all_meshes.append(view_mesh) + + if len(all_meshes) == 0: + return None + + combined_mesh = trimesh.util.concatenate(all_meshes) + glb_path = os.path.join(target_dir, 'multi_view_segmented_mesh.glb') + combined_mesh.export(glb_path) + return glb_path + except Exception as e: + print(f"Failed to generate multi-view mesh: {e}") + import traceback + traceback.print_exc() + return None + + +# ============================================================================ +# Core Model Inference & View Processing +# ============================================================================ + +def process_predictions_for_visualization( + predictions, views, high_level_config, filter_black_bg=False, filter_white_bg=False +): + processed_data = {} + for view_idx, view in enumerate(views): + image = rgb(view["img"], norm_type=high_level_config["data_norm_type"]) + pred_pts3d = predictions["world_points"][view_idx] + + view_data = { + "image": image[0], + "points3d": pred_pts3d, + "depth": None, + "normal": None, + "mask": None, + } + + mask = predictions["final_mask"][view_idx].copy() + + if filter_black_bg: + view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0] + black_bg_mask = view_colors.sum(axis=2) >= 16 + mask = mask & black_bg_mask + + if filter_white_bg: + view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0] + white_bg_mask = ~( + (view_colors[:, :, 0] > 240) + & (view_colors[:, :, 1] > 240) + & (view_colors[:, :, 2] > 240) + ) + mask = mask & white_bg_mask + + view_data["mask"] = mask + view_data["depth"] = predictions["depth"][view_idx].squeeze() + normals, _ = points_to_normals(pred_pts3d, mask=view_data["mask"]) + view_data["normal"] = normals + processed_data[view_idx] = view_data + + return processed_data -# ------------------------------------------------------------------------- -# 1) Core model inference -# ------------------------------------------------------------------------- @spaces.GPU(duration=120) def run_model( target_dir, @@ -65,29 +681,31 @@ def run_model( mask_edges=True, filter_black_bg=False, filter_white_bg=False, + enable_segmentation=False, + text_prompt=DEFAULT_TEXT_PROMPT, + use_sam=True, ): - """ - Run the MapAnything model on images in the 'target_dir/images' folder and return predictions. - """ - global model - import torch # Ensure torch is available in function scope + global model, grounding_dino_model, sam_predictor + import torch print(f"Processing images from {target_dir}") - - # Device check device = "cuda" if torch.cuda.is_available() else "cpu" device = torch.device(device) - # Initialize model if not already done if model is None: + print("Loading MapAnything from HuggingFace...") model = initialize_mapanything_model(high_level_config, device) - + print("MapAnything loaded successfully") else: model = model.to(device) model.eval() - # Load images using MapAnything's load_images function + if enable_segmentation: + load_grounding_dino_model(device) + if use_sam: + load_sam_model(device) + print("Loading images...") image_folder_path = os.path.join(target_dir, "images") views = load_images(image_folder_path) @@ -96,19 +714,12 @@ def run_model( if len(views) == 0: raise ValueError("No images found. Check your upload.") - # Run model inference print("Running inference...") - # apply_mask: Whether to apply the non-ambiguous mask to the output. Defaults to True. - # mask_edges: Whether to compute an edge mask based on normals and depth and apply it to the output. Defaults to True. - # Use checkbox values - mask_edges is set to True by default since there's no UI control for it outputs = model.infer( views, apply_mask=apply_mask, mask_edges=True, memory_efficient_inference=False ) - # Convert predictions to format expected by visualization predictions = {} - - # Initialize lists for the required keys extrinsic_list = [] intrinsic_list = [] world_points_list = [] @@ -116,250 +727,295 @@ def run_model( images_list = [] final_mask_list = [] - # Loop through the outputs for pred in outputs: - # Extract data from predictions - depthmap_torch = pred["depth_z"][0].squeeze(-1) # (H, W) - intrinsics_torch = pred["intrinsics"][0] # (3, 3) - camera_pose_torch = pred["camera_poses"][0] # (4, 4) - - # Compute new pts3d using depth, intrinsics, and camera pose + depthmap_torch = pred["depth_z"][0].squeeze(-1) + intrinsics_torch = pred["intrinsics"][0] + camera_pose_torch = pred["camera_poses"][0] + pts3d_computed, valid_mask = depthmap_to_world_frame( depthmap_torch, intrinsics_torch, camera_pose_torch ) - - # Convert to numpy arrays for visualization - # Check if mask key exists in pred, if not, fill with boolean trues in the size of depthmap_torch + if "mask" in pred: mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool) else: - # Fill with boolean trues in the size of depthmap_torch mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool) - # Combine with valid depth mask mask = mask & valid_mask.cpu().numpy() - image = pred["img_no_norm"][0].cpu().numpy() - # Append to lists extrinsic_list.append(camera_pose_torch.cpu().numpy()) intrinsic_list.append(intrinsics_torch.cpu().numpy()) world_points_list.append(pts3d_computed.cpu().numpy()) depth_maps_list.append(depthmap_torch.cpu().numpy()) - images_list.append(image) # Add image to list - final_mask_list.append(mask) # Add final_mask to list + images_list.append(image) + final_mask_list.append(mask) - # Convert lists to numpy arrays with required shapes - # extrinsic: (S, 3, 4) - batch of camera extrinsic matrices predictions["extrinsic"] = np.stack(extrinsic_list, axis=0) - - # intrinsic: (S, 3, 3) - batch of camera intrinsic matrices predictions["intrinsic"] = np.stack(intrinsic_list, axis=0) - - # world_points: (S, H, W, 3) - batch of 3D world points predictions["world_points"] = np.stack(world_points_list, axis=0) - - # depth: (S, H, W, 1) or (S, H, W) - batch of depth maps + depth_maps = np.stack(depth_maps_list, axis=0) - # Add channel dimension if needed to match (S, H, W, 1) format if len(depth_maps.shape) == 3: depth_maps = depth_maps[..., np.newaxis] - predictions["depth"] = depth_maps - - # images: (S, H, W, 3) - batch of input images predictions["images"] = np.stack(images_list, axis=0) - - # final_mask: (S, H, W) - batch of final masks for filtering predictions["final_mask"] = np.stack(final_mask_list, axis=0) - # Process data for visualization tabs (depth, normal, measure) processed_data = process_predictions_for_visualization( predictions, views, high_level_config, filter_black_bg, filter_white_bg ) - # Clean up + segmented_glb = None + if enable_segmentation and grounding_dino_model is not None: + print("\nStarting multi-view segmentation...") + all_view_detections = [] + all_view_masks = [] + + for view_idx, ref_image in enumerate(images_list): + if ref_image.dtype != np.uint8: + ref_image_np = (ref_image * 255).astype(np.uint8) + else: + ref_image_np = ref_image + + detections = run_grounding_dino_detection(ref_image_np, text_prompt, device) + if len(detections) > 0: + boxes = [d['bbox'] for d in detections] + masks = run_sam_refinement(ref_image_np, boxes) if use_sam else [] + points3d = world_points_list[view_idx] + encoder = model.encoder if hasattr(model, 'encoder') else None + + for det_idx, (det, mask) in enumerate(zip(detections, masks)): + center_3d = compute_object_3d_center(points3d, mask) + det['center_3d'] = center_3d + if center_3d is not None: + masked_points = points3d[mask] + if len(masked_points) > 0: + bbox_min = masked_points.min(axis=0) + bbox_max = masked_points.max(axis=0) + det['bbox_3d'] = { + 'center': center_3d, 'size': bbox_max - bbox_min, + 'min': bbox_min, 'max': bbox_max + } + det['mask_2d'] = mask + if ENABLE_VISUAL_FEATURES and encoder is not None: + det['visual_feature'] = extract_visual_features(ref_image, mask, encoder) + else: + det['visual_feature'] = None + + all_view_detections.append(detections) + all_view_masks.append(masks) + else: + all_view_detections.append([]) + all_view_masks.append([]) + + if any(len(dets) > 0 for dets in all_view_detections): + object_id_map, unique_objects = match_objects_across_views(all_view_detections) + segmented_glb = create_multi_view_segmented_mesh( + processed_data, all_view_detections, all_view_masks, + object_id_map, unique_objects, target_dir, use_sam + ) + torch.cuda.empty_cache() + return predictions, processed_data, segmented_glb - return predictions, processed_data +# ============================================================================ +# View Navigation and Visualization Helpers +# ============================================================================ -def update_view_selectors(processed_data): - """Update view selector dropdowns based on available views""" - if processed_data is None or len(processed_data) == 0: - choices = ["View 1"] - else: - num_views = len(processed_data) - choices = [f"View {i + 1}" for i in range(num_views)] +def colorize_depth(depth_map, mask=None): + if depth_map is None: return None + depth_normalized = depth_map.copy() + valid_mask = depth_normalized > 0 + if mask is not None: + valid_mask = valid_mask & mask - return ( - gr.Dropdown(choices=choices, value=choices[0]), # depth_view_selector - gr.Dropdown(choices=choices, value=choices[0]), # normal_view_selector - gr.Dropdown(choices=choices, value=choices[0]), # measure_view_selector - ) + if valid_mask.sum() > 0: + valid_depths = depth_normalized[valid_mask] + p5 = np.percentile(valid_depths, 5) + p95 = np.percentile(valid_depths, 95) + depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / (p95 - p5) + import matplotlib.pyplot as plt + colormap = plt.cm.turbo_r + colored = colormap(depth_normalized) + colored = (colored[:, :, :3] * 255).astype(np.uint8) + colored[~valid_mask] = [255, 255, 255] + return colored -def get_view_data_by_index(processed_data, view_index): - """Get view data by index, handling bounds""" - if processed_data is None or len(processed_data) == 0: - return None +def colorize_normal(normal_map, mask=None): + if normal_map is None: return None + normal_vis = normal_map.copy() + if mask is not None: + invalid_mask = ~mask + normal_vis[invalid_mask] = [0, 0, 0] + normal_vis = (normal_vis + 1.0) / 2.0 + normal_vis = (normal_vis * 255).astype(np.uint8) + return normal_vis +def get_view_data_by_index(processed_data, view_index): + if processed_data is None or len(processed_data) == 0: return None view_keys = list(processed_data.keys()) - if view_index < 0 or view_index >= len(view_keys): - view_index = 0 - + if view_index < 0 or view_index >= len(view_keys): view_index = 0 return processed_data[view_keys[view_index]] - def update_depth_view(processed_data, view_index): - """Update depth view for a specific view index""" view_data = get_view_data_by_index(processed_data, view_index) - if view_data is None or view_data["depth"] is None: - return None - + if view_data is None or view_data["depth"] is None: return None return colorize_depth(view_data["depth"], mask=view_data.get("mask")) - def update_normal_view(processed_data, view_index): - """Update normal view for a specific view index""" view_data = get_view_data_by_index(processed_data, view_index) - if view_data is None or view_data["normal"] is None: - return None - + if view_data is None or view_data["normal"] is None: return None return colorize_normal(view_data["normal"], mask=view_data.get("mask")) - def update_measure_view(processed_data, view_index): - """Update measure view for a specific view index with mask overlay""" view_data = get_view_data_by_index(processed_data, view_index) - if view_data is None: - return None, [] # image, measure_points - - # Get the base image + if view_data is None: return None, [] image = view_data["image"].copy() - - # Ensure image is in uint8 format if image.dtype != np.uint8: - if image.max() <= 1.0: - image = (image * 255).astype(np.uint8) - else: - image = image.astype(np.uint8) + if image.max() <= 1.0: image = (image * 255).astype(np.uint8) + else: image = image.astype(np.uint8) - # Apply mask overlay if mask is available if view_data["mask"] is not None: - mask = view_data["mask"] - - # Create light grey overlay for masked areas - # Masked areas (False values) will be overlaid with light grey - invalid_mask = ~mask # Areas where mask is False - + invalid_mask = ~view_data["mask"] if invalid_mask.any(): - # Create a light grey overlay (RGB: 192, 192, 192) overlay_color = np.array([255, 220, 220], dtype=np.uint8) - - # Apply overlay with some transparency - alpha = 0.5 # Transparency level - for c in range(3): # RGB channels + alpha = 0.5 + for c in range(3): image[:, :, c] = np.where( invalid_mask, (1 - alpha) * image[:, :, c] + alpha * overlay_color[c], image[:, :, c], ).astype(np.uint8) - return image, [] - -def navigate_depth_view(processed_data, current_selector_value, direction): - """Navigate depth view (direction: -1 for previous, +1 for next)""" +def update_view_selectors(processed_data): if processed_data is None or len(processed_data) == 0: - return "View 1", None - - # Parse current view number - try: - current_view = int(current_selector_value.split()[1]) - 1 - except: - current_view = 0 - - num_views = len(processed_data) - new_view = (current_view + direction) % num_views - - new_selector_value = f"View {new_view + 1}" - depth_vis = update_depth_view(processed_data, new_view) - - return new_selector_value, depth_vis + choices = ["View 1"] + else: + choices = [f"View {i + 1}" for i in range(len(processed_data))] + return ( + gr.Dropdown(choices=choices, value=choices[0]), + gr.Dropdown(choices=choices, value=choices[0]), + gr.Dropdown(choices=choices, value=choices[0]), + ) +def navigate_depth_view(processed_data, current_selector_value, direction): + if processed_data is None or len(processed_data) == 0: return "View 1", None + try: current_view = int(current_selector_value.split()[1]) - 1 + except: current_view = 0 + new_view = (current_view + direction) % len(processed_data) + return f"View {new_view + 1}", update_depth_view(processed_data, new_view) def navigate_normal_view(processed_data, current_selector_value, direction): - """Navigate normal view (direction: -1 for previous, +1 for next)""" - if processed_data is None or len(processed_data) == 0: - return "View 1", None - - # Parse current view number - try: - current_view = int(current_selector_value.split()[1]) - 1 - except: - current_view = 0 - - num_views = len(processed_data) - new_view = (current_view + direction) % num_views - - new_selector_value = f"View {new_view + 1}" - normal_vis = update_normal_view(processed_data, new_view) - - return new_selector_value, normal_vis - + if processed_data is None or len(processed_data) == 0: return "View 1", None + try: current_view = int(current_selector_value.split()[1]) - 1 + except: current_view = 0 + new_view = (current_view + direction) % len(processed_data) + return f"View {new_view + 1}", update_normal_view(processed_data, new_view) def navigate_measure_view(processed_data, current_selector_value, direction): - """Navigate measure view (direction: -1 for previous, +1 for next)""" - if processed_data is None or len(processed_data) == 0: - return "View 1", None, [] - - # Parse current view number - try: - current_view = int(current_selector_value.split()[1]) - 1 - except: - current_view = 0 + if processed_data is None or len(processed_data) == 0: return "View 1", None, [] + try: current_view = int(current_selector_value.split()[1]) - 1 + except: current_view = 0 + new_view = (current_view + direction) % len(processed_data) + measure_image, measure_points = update_measure_view(processed_data, new_view) + return f"View {new_view + 1}", measure_image, measure_points - num_views = len(processed_data) - new_view = (current_view + direction) % num_views +def populate_visualization_tabs(processed_data): + if processed_data is None or len(processed_data) == 0: return None, None, None, [] + return ( + update_depth_view(processed_data, 0), + update_normal_view(processed_data, 0), + update_measure_view(processed_data, 0)[0], + [] + ) - new_selector_value = f"View {new_view + 1}" - measure_image, measure_points = update_measure_view(processed_data, new_view) +def measure(processed_data, measure_points, current_view_selector, event: gr.SelectData): + try: + if processed_data is None or len(processed_data) == 0: return None, [], "No data available" + try: current_view_index = int(current_view_selector.split()[1]) - 1 + except: current_view_index = 0 + if current_view_index < 0 or current_view_index >= len(processed_data): current_view_index = 0 + + view_keys = list(processed_data.keys()) + current_view = processed_data[view_keys[current_view_index]] + if current_view is None: return None, [], "No view data available" + + point2d = event.index[0], event.index[1] + + if (current_view["mask"] is not None and + 0 <= point2d[1] < current_view["mask"].shape[0] and + 0 <= point2d[0] < current_view["mask"].shape[1]): + if not current_view["mask"][point2d[1], point2d[0]]: + masked_image, _ = update_measure_view(processed_data, current_view_index) + return masked_image, measure_points, 'Cannot measure on masked areas (shown in grey)' - return new_selector_value, measure_image, measure_points + measure_points.append(point2d) + image, _ = update_measure_view(processed_data, current_view_index) + if image is None: return None, [], "No image available" + + image = image.copy() + points3d = current_view["points3d"] + + for p in measure_points: + if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]: + image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2) + depth_text = "" + for i, p in enumerate(measure_points): + if (current_view["depth"] is not None and + 0 <= p[1] < current_view["depth"].shape[0] and + 0 <= p[0] < current_view["depth"].shape[1]): + d = current_view["depth"][p[1], p[0]] + depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n" + else: + if (points3d is not None and 0 <= p[1] < points3d.shape[0] and 0 <= p[0] < points3d.shape[1]): + z = points3d[p[1], p[0], 2] + depth_text += f"- **P{i + 1} Z-coord: {z:.2f}m.**\n" -def populate_visualization_tabs(processed_data): - """Populate the depth, normal, and measure tabs with processed data""" - if processed_data is None or len(processed_data) == 0: - return None, None, None, [] + if len(measure_points) == 2: + point1, point2 = measure_points + if (0 <= point1[0] < image.shape[1] and 0 <= point1[1] < image.shape[0] and + 0 <= point2[0] < image.shape[1] and 0 <= point2[1] < image.shape[0]): + image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2) + + distance_text = "- **Distance: Unable to compute**" + if (points3d is not None and 0 <= point1[1] < points3d.shape[0] and 0 <= point1[0] < points3d.shape[1] and + 0 <= point2[1] < points3d.shape[0] and 0 <= point2[0] < points3d.shape[1]): + try: + p1_3d = points3d[point1[1], point1[0]] + p2_3d = points3d[point2[1], point2[0]] + distance = np.linalg.norm(p1_3d - p2_3d) + distance_text = f"- **Distance: {distance:.2f}m**" + except: + pass + + measure_points = [] + text = depth_text + distance_text + return [image, measure_points, text] + else: + return [image, measure_points, depth_text] - # Use update functions to ensure confidence filtering is applied from the start - depth_vis = update_depth_view(processed_data, 0) - normal_vis = update_normal_view(processed_data, 0) - measure_img, _ = update_measure_view(processed_data, 0) + except Exception as e: + return None, [], f"Measure function error: {e}" - return depth_vis, normal_vis, measure_img, [] +# ============================================================================ +# Core Functions for Upload and Interface +# ============================================================================ -# ------------------------------------------------------------------------- -# 2) Handle uploaded video/images --> produce target_dir + images -# ------------------------------------------------------------------------- def handle_uploads(unified_upload, s_time_interval=1.0): - """ - Create a new 'target_dir' + 'images' subfolder, and place user-uploaded - images or extracted frames from video into it. Return (target_dir, image_paths). - """ start_time = time.time() gc.collect() torch.cuda.empty_cache() - # Create a unique folder name timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") target_dir = f"input_images_{timestamp}" target_dir_images = os.path.join(target_dir, "images") - # Clean up if somehow that folder already exists if os.path.exists(target_dir): shutil.rmtree(target_dir) os.makedirs(target_dir) @@ -367,7 +1023,6 @@ def handle_uploads(unified_upload, s_time_interval=1.0): image_paths = [] - # --- Handle uploaded files (both images and videos) --- if unified_upload is not None: for file_data in unified_upload: if isinstance(file_data, dict) and "name" in file_data: @@ -376,118 +1031,57 @@ def handle_uploads(unified_upload, s_time_interval=1.0): file_path = str(file_data) file_ext = os.path.splitext(file_path)[1].lower() - - # Check if it's a video file - video_extensions = [ - ".mp4", - ".avi", - ".mov", - ".mkv", - ".wmv", - ".flv", - ".webm", - ".m4v", - ".3gp", - ] + video_extensions = [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"] + if file_ext in video_extensions: - # Handle as video vs = cv2.VideoCapture(file_path) fps = vs.get(cv2.CAP_PROP_FPS) - frame_interval = int(fps * s_time_interval) # frames per interval - + frame_interval = int(fps * s_time_interval) count = 0 video_frame_num = 0 while True: gotit, frame = vs.read() - if not gotit: - break + if not gotit: break count += 1 if count % frame_interval == 0: - # Use original filename as prefix for frames base_name = os.path.splitext(os.path.basename(file_path))[0] - image_path = os.path.join( - target_dir_images, f"{base_name}_{video_frame_num:06}.png" - ) + image_path = os.path.join(target_dir_images, f"{base_name}_{video_frame_num:06}.png") cv2.imwrite(image_path, frame) image_paths.append(image_path) video_frame_num += 1 vs.release() - print( - f"Extracted {video_frame_num} frames from video: {os.path.basename(file_path)}" - ) - else: - # Handle as image - # Check if the file is a HEIC image if file_ext in [".heic", ".heif"]: - # Convert HEIC to JPEG for better gallery compatibility try: with Image.open(file_path) as img: - # Convert to RGB if necessary (HEIC can have different color modes) - if img.mode not in ("RGB", "L"): - img = img.convert("RGB") - - # Create JPEG filename + if img.mode not in ("RGB", "L"): img = img.convert("RGB") base_name = os.path.splitext(os.path.basename(file_path))[0] - dst_path = os.path.join( - target_dir_images, f"{base_name}.jpg" - ) - - # Save as JPEG with high quality + dst_path = os.path.join(target_dir_images, f"{base_name}.jpg") img.save(dst_path, "JPEG", quality=95) image_paths.append(dst_path) - print( - f"Converted HEIC to JPEG: {os.path.basename(file_path)} -> {os.path.basename(dst_path)}" - ) except Exception as e: - print(f"Error converting HEIC file {file_path}: {e}") - # Fall back to copying as is - dst_path = os.path.join( - target_dir_images, os.path.basename(file_path) - ) + dst_path = os.path.join(target_dir_images, os.path.basename(file_path)) shutil.copy(file_path, dst_path) image_paths.append(dst_path) else: - # Regular image files - copy as is - dst_path = os.path.join( - target_dir_images, os.path.basename(file_path) - ) + dst_path = os.path.join(target_dir_images, os.path.basename(file_path)) shutil.copy(file_path, dst_path) image_paths.append(dst_path) - # Sort final images for gallery image_paths = sorted(image_paths) - - end_time = time.time() - print( - f"Files processed to {target_dir_images}; took {end_time - start_time:.3f} seconds" - ) return target_dir, image_paths - -# ------------------------------------------------------------------------- -# 3) Update gallery on upload -# ------------------------------------------------------------------------- -def update_gallery_on_upload(input_video, input_images, s_time_interval=1.0): - """ - Whenever user uploads or changes files, immediately handle them - and show in the gallery. Return (target_dir, image_paths). - If nothing is uploaded, returns "None" and empty list. - """ - if not input_video and not input_images: +def update_gallery_on_upload(input_files, s_time_interval=1.0): + if not input_files: return None, None, None, None - target_dir, image_paths = handle_uploads(input_video, input_images, s_time_interval) + target_dir, image_paths = handle_uploads(input_files, s_time_interval) return ( None, target_dir, image_paths, - "Upload complete. Click 'Reconstruct' to begin 3D processing.", + "Upload complete. Click 'Start Reconstruction' to begin 3D processing.", ) - -# ------------------------------------------------------------------------- -# 4) Reconstruction: uses the target_dir plus any viz parameters -# ------------------------------------------------------------------------- @spaces.GPU(duration=120) def gradio_demo( target_dir, @@ -497,384 +1091,80 @@ def gradio_demo( filter_white_bg=False, apply_mask=True, show_mesh=True, + enable_segmentation=False, + text_prompt=DEFAULT_TEXT_PROMPT, + use_sam=True, ): - """ - Perform reconstruction using the already-created target_dir/images. - """ if not os.path.isdir(target_dir) or target_dir == "None": - return None, "No valid target directory found. Please upload first.", None, None + return None, None, "No valid target directory found. Please upload first.", None, None, None, None, None, "", None, None, None start_time = time.time() gc.collect() torch.cuda.empty_cache() - # Prepare frame_filter dropdown target_dir_images = os.path.join(target_dir, "images") - all_files = ( - sorted(os.listdir(target_dir_images)) - if os.path.isdir(target_dir_images) - else [] - ) - all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)] - frame_filter_choices = ["All"] + all_files + all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else [] + all_files_display = [f"{i}: {filename}" for i, filename in enumerate(all_files)] + frame_filter_choices = ["All"] + all_files_display - print("Running MapAnything model...") with torch.no_grad(): - predictions, processed_data = run_model(target_dir, apply_mask) + predictions, processed_data, segmented_glb = run_model( + target_dir, apply_mask, True, filter_black_bg, filter_white_bg, + enable_segmentation, text_prompt, use_sam + ) - # Save predictions prediction_save_path = os.path.join(target_dir, "predictions.npz") np.savez(prediction_save_path, **predictions) - # Handle None frame_filter - if frame_filter is None: - frame_filter = "All" + if frame_filter is None: frame_filter = "All" - # Build a GLB file name glbfile = os.path.join( target_dir, f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb", ) - # Convert predictions to GLB glbscene = predictions_to_glb( predictions, filter_by_frames=frame_filter, show_cam=show_cam, mask_black_bg=filter_black_bg, mask_white_bg=filter_white_bg, - as_mesh=show_mesh, # Use the show_mesh parameter + as_mesh=show_mesh, ) glbscene.export(file_obj=glbfile) - # Cleanup del predictions gc.collect() torch.cuda.empty_cache() - end_time = time.time() - print(f"Total time: {end_time - start_time:.2f} seconds") - log_msg = ( - f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization." - ) + log_msg = f"Reconstruction successful ({len(all_files)} frames)" - # Populate visualization tabs with processed data - depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs( - processed_data - ) + depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(processed_data) + depth_selector, normal_selector, measure_selector = update_view_selectors(processed_data) - # Update view selectors based on available views - depth_selector, normal_selector, measure_selector = update_view_selectors( - processed_data - ) + output_reconstruction = segmented_glb if (enable_segmentation and segmented_glb is not None) else glbfile return ( - glbfile, + glbfile, # Raw 3D + output_reconstruction, # 3D View (Segmented if applicable) log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True), processed_data, depth_vis, normal_vis, measure_img, - "", # measure_text (empty initially) + "", depth_selector, normal_selector, measure_selector, ) - -# ------------------------------------------------------------------------- -# 5) Helper functions for UI resets + re-visualization -# ------------------------------------------------------------------------- -def colorize_depth(depth_map, mask=None): - """Convert depth map to colorized visualization with optional mask""" - if depth_map is None: - return None - - # Normalize depth to 0-1 range - depth_normalized = depth_map.copy() - valid_mask = depth_normalized > 0 - - # Apply additional mask if provided (for background filtering) - if mask is not None: - valid_mask = valid_mask & mask - - if valid_mask.sum() > 0: - valid_depths = depth_normalized[valid_mask] - p5 = np.percentile(valid_depths, 5) - p95 = np.percentile(valid_depths, 95) - - depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / (p95 - p5) - - # Apply colormap - import matplotlib.pyplot as plt - - colormap = plt.cm.turbo_r - colored = colormap(depth_normalized) - colored = (colored[:, :, :3] * 255).astype(np.uint8) - - # Set invalid pixels to white - colored[~valid_mask] = [255, 255, 255] - - return colored - - -def colorize_normal(normal_map, mask=None): - """Convert normal map to colorized visualization with optional mask""" - if normal_map is None: - return None - - # Create a copy for modification - normal_vis = normal_map.copy() - - # Apply mask if provided (set masked areas to [0, 0, 0] which becomes grey after normalization) - if mask is not None: - invalid_mask = ~mask - normal_vis[invalid_mask] = [0, 0, 0] # Set invalid areas to zero - - # Normalize normals to [0, 1] range for visualization - normal_vis = (normal_vis + 1.0) / 2.0 - normal_vis = (normal_vis * 255).astype(np.uint8) - - return normal_vis - - -def process_predictions_for_visualization( - predictions, views, high_level_config, filter_black_bg=False, filter_white_bg=False -): - """Extract depth, normal, and 3D points from predictions for visualization""" - processed_data = {} - - # Process each view - for view_idx, view in enumerate(views): - # Get image - image = rgb(view["img"], norm_type=high_level_config["data_norm_type"]) - - # Get predicted points - pred_pts3d = predictions["world_points"][view_idx] - - # Initialize data for this view - view_data = { - "image": image[0], - "points3d": pred_pts3d, - "depth": None, - "normal": None, - "mask": None, - } - - # Start with the final mask from predictions - mask = predictions["final_mask"][view_idx].copy() - - # Apply black background filtering if enabled - if filter_black_bg: - # Get the image colors (ensure they're in 0-255 range) - view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0] - # Filter out black background pixels (sum of RGB < 16) - black_bg_mask = view_colors.sum(axis=2) >= 16 - mask = mask & black_bg_mask - - # Apply white background filtering if enabled - if filter_white_bg: - # Get the image colors (ensure they're in 0-255 range) - view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0] - # Filter out white background pixels (all RGB > 240) - white_bg_mask = ~( - (view_colors[:, :, 0] > 240) - & (view_colors[:, :, 1] > 240) - & (view_colors[:, :, 2] > 240) - ) - mask = mask & white_bg_mask - - view_data["mask"] = mask - view_data["depth"] = predictions["depth"][view_idx].squeeze() - - normals, _ = points_to_normals(pred_pts3d, mask=view_data["mask"]) - view_data["normal"] = normals - - processed_data[view_idx] = view_data - - return processed_data - - -def reset_measure(processed_data): - """Reset measure points""" - if processed_data is None or len(processed_data) == 0: - return None, [], "" - - # Return the first view image - first_view = list(processed_data.values())[0] - return first_view["image"], [], "" - - -def measure( - processed_data, measure_points, current_view_selector, event: gr.SelectData -): - """Handle measurement on images""" - try: - print(f"Measure function called with selector: {current_view_selector}") - - if processed_data is None or len(processed_data) == 0: - return None, [], "No data available" - - # Use the currently selected view instead of always using the first view - try: - current_view_index = int(current_view_selector.split()[1]) - 1 - except: - current_view_index = 0 - - print(f"Using view index: {current_view_index}") - - # Get view data safely - if current_view_index < 0 or current_view_index >= len(processed_data): - current_view_index = 0 - - view_keys = list(processed_data.keys()) - current_view = processed_data[view_keys[current_view_index]] - - if current_view is None: - return None, [], "No view data available" - - point2d = event.index[0], event.index[1] - print(f"Clicked point: {point2d}") - - # Check if the clicked point is in a masked area (prevent interaction) - if ( - current_view["mask"] is not None - and 0 <= point2d[1] < current_view["mask"].shape[0] - and 0 <= point2d[0] < current_view["mask"].shape[1] - ): - # Check if the point is in a masked (invalid) area - if not current_view["mask"][point2d[1], point2d[0]]: - print(f"Clicked point {point2d} is in masked area, ignoring click") - # Always return image with mask overlay - masked_image, _ = update_measure_view( - processed_data, current_view_index - ) - return ( - masked_image, - measure_points, - 'Cannot measure on masked areas (shown in grey)', - ) - - measure_points.append(point2d) - - # Get image with mask overlay and ensure it's valid - image, _ = update_measure_view(processed_data, current_view_index) - if image is None: - return None, [], "No image available" - - image = image.copy() - points3d = current_view["points3d"] - - # Ensure image is in uint8 format for proper cv2 operations - try: - if image.dtype != np.uint8: - if image.max() <= 1.0: - # Image is in [0, 1] range, convert to [0, 255] - image = (image * 255).astype(np.uint8) - else: - # Image is already in [0, 255] range - image = image.astype(np.uint8) - except Exception as e: - print(f"Image conversion error: {e}") - return None, [], f"Image conversion error: {e}" - - # Draw circles for points - try: - for p in measure_points: - if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]: - image = cv2.circle( - image, p, radius=5, color=(255, 0, 0), thickness=2 - ) - except Exception as e: - print(f"Drawing error: {e}") - return None, [], f"Drawing error: {e}" - - depth_text = "" - try: - for i, p in enumerate(measure_points): - if ( - current_view["depth"] is not None - and 0 <= p[1] < current_view["depth"].shape[0] - and 0 <= p[0] < current_view["depth"].shape[1] - ): - d = current_view["depth"][p[1], p[0]] - depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n" - else: - # Use Z coordinate of 3D points if depth not available - if ( - points3d is not None - and 0 <= p[1] < points3d.shape[0] - and 0 <= p[0] < points3d.shape[1] - ): - z = points3d[p[1], p[0], 2] - depth_text += f"- **P{i + 1} Z-coord: {z:.2f}m.**\n" - except Exception as e: - print(f"Depth text error: {e}") - depth_text = f"Error computing depth: {e}\n" - - if len(measure_points) == 2: - try: - point1, point2 = measure_points - # Draw line - if ( - 0 <= point1[0] < image.shape[1] - and 0 <= point1[1] < image.shape[0] - and 0 <= point2[0] < image.shape[1] - and 0 <= point2[1] < image.shape[0] - ): - image = cv2.line( - image, point1, point2, color=(255, 0, 0), thickness=2 - ) - - # Compute 3D distance - distance_text = "- **Distance: Unable to compute**" - if ( - points3d is not None - and 0 <= point1[1] < points3d.shape[0] - and 0 <= point1[0] < points3d.shape[1] - and 0 <= point2[1] < points3d.shape[0] - and 0 <= point2[0] < points3d.shape[1] - ): - try: - p1_3d = points3d[point1[1], point1[0]] - p2_3d = points3d[point2[1], point2[0]] - distance = np.linalg.norm(p1_3d - p2_3d) - distance_text = f"- **Distance: {distance:.2f}m**" - except Exception as e: - print(f"Distance computation error: {e}") - distance_text = f"- **Distance computation error: {e}**" - - measure_points = [] - text = depth_text + distance_text - print(f"Measurement complete: {text}") - return [image, measure_points, text] - except Exception as e: - print(f"Final measurement error: {e}") - return None, [], f"Measurement error: {e}" - else: - print(f"Single point measurement: {depth_text}") - return [image, measure_points, depth_text] - - except Exception as e: - print(f"Overall measure function error: {e}") - return None, [], f"Measure function error: {e}" - - def clear_fields(): - """ - Clears the 3D viewer, the stored target_dir, and empties the gallery. - """ - return None - + return None, None def update_log(): - """ - Display a quick log message while waiting. - """ return "Loading and Reconstructing..." - def update_visualization( target_dir, frame_filter, @@ -884,30 +1174,12 @@ def update_visualization( filter_white_bg=False, show_mesh=True, ): - """ - Reload saved predictions from npz, create (or reuse) the GLB for new parameters, - and return it for the 3D viewer. If is_example == "True", skip. - """ - - # If it's an example click, skip as requested - if is_example == "True": - return ( - gr.update(), - "No reconstruction available. Please click the Reconstruct button first.", - ) - - if not target_dir or target_dir == "None" or not os.path.isdir(target_dir): - return ( - gr.update(), - "No reconstruction available. Please click the Reconstruct button first.", - ) + if is_example == "True" or not target_dir or target_dir == "None" or not os.path.isdir(target_dir): + return gr.update(), gr.update(), "No reconstruction available. Please run reconstruction first." predictions_path = os.path.join(target_dir, "predictions.npz") if not os.path.exists(predictions_path): - return ( - gr.update(), - f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.", - ) + return gr.update(), gr.update(), "No reconstruction available. Please run reconstruction first." loaded = np.load(predictions_path, allow_pickle=True) predictions = {key: loaded[key] for key in loaded.keys()} @@ -928,11 +1200,7 @@ def update_visualization( ) glbscene.export(file_obj=glbfile) - return ( - glbfile, - "Visualization updated.", - ) - + return glbfile, gr.update(), "Visualization updated." def update_all_views_on_filter_change( target_dir, @@ -943,11 +1211,6 @@ def update_all_views_on_filter_change( normal_view_selector, measure_view_selector, ): - """ - Update all individual view tabs when background filtering checkboxes change. - This regenerates the processed data with new filtering and updates all views. - """ - # Check if we have a valid target directory and predictions if not target_dir or target_dir == "None" or not os.path.isdir(target_dir): return processed_data, None, None, None, [] @@ -956,722 +1219,288 @@ def update_all_views_on_filter_change( return processed_data, None, None, None, [] try: - # Load the original predictions and views loaded = np.load(predictions_path, allow_pickle=True) predictions = {key: loaded[key] for key in loaded.keys()} - # Load images using MapAnything's load_images function image_folder_path = os.path.join(target_dir, "images") views = load_images(image_folder_path) - # Regenerate processed data with new filtering settings new_processed_data = process_predictions_for_visualization( predictions, views, high_level_config, filter_black_bg, filter_white_bg ) - # Get current view indices - try: - depth_view_idx = ( - int(depth_view_selector.split()[1]) - 1 if depth_view_selector else 0 - ) - except: - depth_view_idx = 0 + try: depth_view_idx = int(depth_view_selector.split()[1]) - 1 if depth_view_selector else 0 + except: depth_view_idx = 0 - try: - normal_view_idx = ( - int(normal_view_selector.split()[1]) - 1 if normal_view_selector else 0 - ) - except: - normal_view_idx = 0 + try: normal_view_idx = int(normal_view_selector.split()[1]) - 1 if normal_view_selector else 0 + except: normal_view_idx = 0 - try: - measure_view_idx = ( - int(measure_view_selector.split()[1]) - 1 - if measure_view_selector - else 0 - ) - except: - measure_view_idx = 0 + try: measure_view_idx = int(measure_view_selector.split()[1]) - 1 if measure_view_selector else 0 + except: measure_view_idx = 0 - # Update all views with new filtered data depth_vis = update_depth_view(new_processed_data, depth_view_idx) normal_vis = update_normal_view(new_processed_data, normal_view_idx) measure_img, _ = update_measure_view(new_processed_data, measure_view_idx) return new_processed_data, depth_vis, normal_vis, measure_img, [] - except Exception as e: - print(f"Error updating views on filter change: {e}") return processed_data, None, None, None, [] - -# ------------------------------------------------------------------------- -# Example scene functions -# ------------------------------------------------------------------------- +# Example Scenes def get_scene_info(examples_dir): - """Get information about scenes in the examples directory""" import glob - scenes = [] - if not os.path.exists(examples_dir): - return scenes - + if not os.path.exists(examples_dir): return scenes for scene_folder in sorted(os.listdir(examples_dir)): scene_path = os.path.join(examples_dir, scene_folder) if os.path.isdir(scene_path): - # Find all image files in the scene folder image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"] image_files = [] for ext in image_extensions: image_files.extend(glob.glob(os.path.join(scene_path, ext))) image_files.extend(glob.glob(os.path.join(scene_path, ext.upper()))) - if image_files: - # Sort images and get the first one for thumbnail image_files = sorted(image_files) - first_image = image_files[0] - num_images = len(image_files) - - scenes.append( - { - "name": scene_folder, - "path": scene_path, - "thumbnail": first_image, - "num_images": num_images, - "image_files": image_files, - } - ) - + scenes.append({ + "name": scene_folder, + "path": scene_path, + "thumbnail": image_files[0], + "num_images": len(image_files), + "image_files": image_files, + }) return scenes - def load_example_scene(scene_name, examples_dir="examples"): - """Load a scene from examples directory""" scenes = get_scene_info(examples_dir) - - # Find the selected scene - selected_scene = None - for scene in scenes: - if scene["name"] == scene_name: - selected_scene = scene - break - - if selected_scene is None: - return None, None, None, "Scene not found" - - # Create file-like objects for the unified upload system - # Convert image file paths to the format expected by unified_upload - file_objects = [] - for image_path in selected_scene["image_files"]: - file_objects.append(image_path) - - # Create target directory and copy images using the unified upload system + selected_scene = next((s for s in scenes if s["name"] == scene_name), None) + if selected_scene is None: return None, None, None, "Scene not found" + + file_objects = [image_path for image_path in selected_scene["image_files"]] target_dir, image_paths = handle_uploads(file_objects, 1.0) - + return ( - None, # Clear reconstruction output - target_dir, # Set target directory - image_paths, # Set gallery - f"Loaded scene '{scene_name}' with {selected_scene['num_images']} images. Click 'Reconstruct' to begin 3D processing.", + None, + target_dir, + image_paths, + f"Loaded scene '{scene_name}'. Click 'Start Reconstruction' to begin.", ) -# ------------------------------------------------------------------------- -# 6) Build Gradio UI -# ------------------------------------------------------------------------- +# ============================================================================ +# Gradio UI Construction +# ============================================================================ + theme = get_gradio_theme() +CUSTOM_CSS = GRADIO_CSS + """ +.gradio-container { max-width: 100% !important; } +.gallery-container { max-height: 350px !important; overflow-y: auto !important; } +.file-preview { max-height: 200px !important; overflow-y: auto !important; } +.video-container { max-height: 300px !important; } +.textbox-container { max-height: 100px !important; } +.tab-content { min-height: 550px !important; } +""" + with gr.Blocks() as demo: - # State variables for the tabbed interface is_example = gr.Textbox(label="is_example", visible=False, value="None") - num_images = gr.Textbox(label="num_images", visible=False, value="None") processed_data_state = gr.State(value=None) measure_points_state = gr.State(value=[]) - current_view_index = gr.State(value=0) # Track current view index for navigation - target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None") - with gr.Row(): - with gr.Column(scale=2): - # Unified upload component for both videos and images + with gr.Row(equal_height=False): + # Input Area + with gr.Column(scale=1, min_width=300): + gr.Markdown("### Input") + unified_upload = gr.File( file_count="multiple", label="Upload Video or Images", interactive=True, file_types=["image", "video"], ) + with gr.Row(): s_time_interval = gr.Slider( - minimum=0.1, - maximum=5.0, - value=1.0, - step=0.1, + minimum=0.1, maximum=5.0, value=1.0, step=0.1, label="Video sample time interval (take a sample every x sec.)", - interactive=True, - visible=True, - scale=3, - ) - resample_btn = gr.Button( - "Resample Video", - visible=False, - variant="secondary", - scale=1, + interactive=True, scale=3 ) + resample_btn = gr.Button("Resample Video", visible=False, variant="secondary", scale=1) image_gallery = gr.Gallery( - label="Preview", - columns=4, - height="300px", - show_download_button=True, - object_fit="contain", - preview=True, - ) - - clear_uploads_btn = gr.ClearButton( - [unified_upload, image_gallery], - value="Clear Uploads", - variant="secondary", - size="sm", + label="Preview", columns=4, height="300px", + show_download_button=True, object_fit="contain", preview=True ) - - with gr.Column(scale=4): - with gr.Column(): - gr.Markdown( - "**Metric 3D Reconstruction (Point Cloud and Camera Poses)**" - ) - log_output = gr.Markdown( - "Please upload a video or images, then click Reconstruct.", - elem_classes=["custom-log"], - ) - - # Add tabbed interface similar to MoGe - with gr.Tabs(): - with gr.Tab("3D View"): - reconstruction_output = gr.Model3D( - height=520, - zoom_speed=0.5, - pan_speed=0.5, - clear_color=[0.0, 0.0, 0.0, 0.0], - key="persistent_3d_viewer", - elem_id="reconstruction_3d_viewer", - ) - with gr.Tab("Depth"): - with gr.Row(elem_classes=["navigation-row"]): - prev_depth_btn = gr.Button("◀ Previous", size="sm", scale=1) - depth_view_selector = gr.Dropdown( - choices=["View 1"], - value="View 1", - label="Select View", - scale=2, - interactive=True, - allow_custom_value=True, - ) - next_depth_btn = gr.Button("Next ▶", size="sm", scale=1) - depth_map = gr.Image( - type="numpy", - label="Colorized Depth Map", - format="png", - interactive=False, - ) - with gr.Tab("Normal"): - with gr.Row(elem_classes=["navigation-row"]): - prev_normal_btn = gr.Button( - "◀ Previous", size="sm", scale=1 - ) - normal_view_selector = gr.Dropdown( - choices=["View 1"], - value="View 1", - label="Select View", - scale=2, - interactive=True, - allow_custom_value=True, - ) - next_normal_btn = gr.Button("Next ▶", size="sm", scale=1) - normal_map = gr.Image( - type="numpy", - label="Normal Map", - format="png", - interactive=False, - ) - with gr.Tab("Measure"): - gr.Markdown(MEASURE_INSTRUCTIONS_HTML) - with gr.Row(elem_classes=["navigation-row"]): - prev_measure_btn = gr.Button( - "◀ Previous", size="sm", scale=1 - ) - measure_view_selector = gr.Dropdown( - choices=["View 1"], - value="View 1", - label="Select View", - scale=2, - interactive=True, - allow_custom_value=True, - ) - next_measure_btn = gr.Button("Next ▶", size="sm", scale=1) - measure_image = gr.Image( - type="numpy", - show_label=False, - format="webp", - interactive=False, - sources=[], - ) - gr.Markdown( - "**Note:** Light-grey areas indicate regions with no depth information where measurements cannot be taken." - ) - measure_text = gr.Markdown("") - + with gr.Row(): - submit_btn = gr.Button("Reconstruct", scale=1, variant="primary") - clear_btn = gr.ClearButton( - [ - unified_upload, - reconstruction_output, - log_output, - target_dir_output, - image_gallery, - ], - scale=1, - ) + submit_btn = gr.Button("Start Reconstruction", variant="primary", scale=2) + clear_btn = gr.ClearButton([unified_upload, image_gallery, target_dir_output], value="Clear", scale=1) + + # Output Area + with gr.Column(scale=2, min_width=600): + gr.Markdown("### Output") + + with gr.Tabs(): + with gr.Tab("Raw 3D"): + raw_3d_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5, clear_color=[0.0, 0.0, 0.0, 0.0]) + with gr.Tab("3D View"): + reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5, clear_color=[0.0, 0.0, 0.0, 0.0]) + with gr.Tab("Depth"): + with gr.Row(elem_classes=["navigation-row"]): + prev_depth_btn = gr.Button("Previous", size="sm", scale=1) + depth_view_selector = gr.Dropdown(choices=["View 1"], value="View 1", label="Select View", scale=2, interactive=True, allow_custom_value=True) + next_depth_btn = gr.Button("Next", size="sm", scale=1) + depth_map = gr.Image(type="numpy", label="Colorized Depth Map", format="png", interactive=False) + with gr.Tab("Normal"): + with gr.Row(elem_classes=["navigation-row"]): + prev_normal_btn = gr.Button("Previous", size="sm", scale=1) + normal_view_selector = gr.Dropdown(choices=["View 1"], value="View 1", label="Select View", scale=2, interactive=True, allow_custom_value=True) + next_normal_btn = gr.Button("Next", size="sm", scale=1) + normal_map = gr.Image(type="numpy", label="Normal Map", format="png", interactive=False) + with gr.Tab("Measure"): + gr.Markdown(MEASURE_INSTRUCTIONS_HTML) + with gr.Row(elem_classes=["navigation-row"]): + prev_measure_btn = gr.Button("Previous", size="sm", scale=1) + measure_view_selector = gr.Dropdown(choices=["View 1"], value="View 1", label="Select View", scale=2, interactive=True, allow_custom_value=True) + next_measure_btn = gr.Button("Next", size="sm", scale=1) + measure_image = gr.Image(type="numpy", show_label=False, format="webp", interactive=False, sources=[]) + gr.Markdown("**Note:** Light-grey areas indicate regions with no depth information where measurements cannot be taken.") + measure_text = gr.Markdown("") + + log_output = gr.Textbox( + value="Please upload images or video, then click 'Start Reconstruction'", + label="Status Information", interactive=False, lines=1, max_lines=1 + ) - with gr.Row(): - frame_filter = gr.Dropdown( - choices=["All"], value="All", label="Show Points from Frame" + with gr.Accordion("Advanced Options", open=False): + with gr.Row(equal_height=False): + with gr.Column(scale=1, min_width=300): + gr.Markdown("#### Visualization Parameters") + frame_filter = gr.Dropdown(choices=["All"], value="All", label="Show Points from Frame") + show_cam = gr.Checkbox(label="Show Camera", value=True) + show_mesh = gr.Checkbox(label="Show Mesh", value=True) + filter_black_bg = gr.Checkbox(label="Filter Black Background", value=False) + filter_white_bg = gr.Checkbox(label="Filter White Background", value=False) + + with gr.Column(scale=1, min_width=300): + gr.Markdown("#### Reconstruction Parameters") + apply_mask_checkbox = gr.Checkbox(label="Apply mask for predicted ambiguous depth classes & edges", value=True) + + gr.Markdown("#### Segmentation Parameters") + enable_segmentation = gr.Checkbox(label="Enable Semantic Segmentation", value=False) + use_sam_checkbox = gr.Checkbox(label="Use SAM for Accurate Segmentation", value=True) + text_prompt = gr.Textbox( + value=DEFAULT_TEXT_PROMPT, + label="Detect Objects (separated by .)", + placeholder="e.g. chair . table . sofa", lines=2, max_lines=2 ) - with gr.Column(): - gr.Markdown("### Pointcloud Options: (live updates)") - show_cam = gr.Checkbox(label="Show Camera", value=True) - show_mesh = gr.Checkbox(label="Show Mesh", value=True) - filter_black_bg = gr.Checkbox( - label="Filter Black Background", value=False - ) - filter_white_bg = gr.Checkbox( - label="Filter White Background", value=False - ) - gr.Markdown("### Reconstruction Options: (updated on next run)") - apply_mask_checkbox = gr.Checkbox( - label="Apply mask for predicted ambiguous depth classes & edges", - value=True, - ) - # ---------------------- Example Scenes Section ---------------------- - gr.Markdown("## Example Scenes (lists all scenes in the examples folder)") - gr.Markdown("Click any thumbnail to load the scene for reconstruction.") - - # Get scene information - scenes = get_scene_info("examples") - - # Create thumbnail grid (4 columns, N rows) - if scenes: - for i in range(0, len(scenes), 4): # Process 4 scenes per row - with gr.Row(): - for j in range(4): - scene_idx = i + j - if scene_idx < len(scenes): - scene = scenes[scene_idx] - with gr.Column(scale=1, elem_classes=["clickable-thumbnail"]): - # Clickable thumbnail - scene_img = gr.Image( - value=scene["thumbnail"], - height=150, - interactive=False, - show_label=False, - elem_id=f"scene_thumb_{scene['name']}", - sources=[], - ) - - # Scene name and image count as text below thumbnail - gr.Markdown( - f"**{scene['name']}** \n {scene['num_images']} images", - elem_classes=["scene-info"], - ) - - # Connect thumbnail click to load scene - scene_img.select( - fn=lambda name=scene["name"]: load_example_scene(name), - outputs=[ - reconstruction_output, - target_dir_output, - image_gallery, - log_output, - ], - ) - else: - # Empty column to maintain grid structure - with gr.Column(scale=1): - pass - - # ------------------------------------------------------------------------- - # "Reconstruct" button logic: - # - Clear fields - # - Update log - # - gradio_demo(...) with the existing target_dir - # - Then set is_example = "False" - # ------------------------------------------------------------------------- - submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then( - fn=update_log, inputs=[], outputs=[log_output] - ).then( - fn=gradio_demo, - inputs=[ - target_dir_output, - frame_filter, - show_cam, - filter_black_bg, - filter_white_bg, - apply_mask_checkbox, - show_mesh, - ], - outputs=[ - reconstruction_output, - log_output, - frame_filter, - processed_data_state, - depth_map, - normal_map, - measure_image, - measure_text, - depth_view_selector, - normal_view_selector, - measure_view_selector, - ], - ).then( - fn=lambda: "False", - inputs=[], - outputs=[is_example], # set is_example to "False" - ) - - # ------------------------------------------------------------------------- - # Real-time Visualization Updates - # ------------------------------------------------------------------------- - frame_filter.change( - update_visualization, - [ - target_dir_output, - frame_filter, - show_cam, - is_example, - filter_black_bg, - filter_white_bg, - show_mesh, - ], - [reconstruction_output, log_output], - ) - show_cam.change( - update_visualization, - [ - target_dir_output, - frame_filter, - show_cam, - is_example, - ], - [reconstruction_output, log_output], - ) - filter_black_bg.change( - update_visualization, - [ - target_dir_output, - frame_filter, - show_cam, - is_example, - filter_black_bg, - filter_white_bg, - ], - [reconstruction_output, log_output], - ).then( - fn=update_all_views_on_filter_change, - inputs=[ - target_dir_output, - filter_black_bg, - filter_white_bg, - processed_data_state, - depth_view_selector, - normal_view_selector, - measure_view_selector, - ], - outputs=[ - processed_data_state, - depth_map, - normal_map, - measure_image, - measure_points_state, - ], - ) - filter_white_bg.change( - update_visualization, - [ - target_dir_output, - frame_filter, - show_cam, - is_example, - filter_black_bg, - filter_white_bg, - show_mesh, - ], - [reconstruction_output, log_output], - ).then( - fn=update_all_views_on_filter_change, - inputs=[ - target_dir_output, - filter_black_bg, - filter_white_bg, - processed_data_state, - depth_view_selector, - normal_view_selector, - measure_view_selector, - ], - outputs=[ - processed_data_state, - depth_map, - normal_map, - measure_image, - measure_points_state, - ], - ) - - show_mesh.change( - update_visualization, - [ - target_dir_output, - frame_filter, - show_cam, - is_example, - filter_black_bg, - filter_white_bg, - show_mesh, - ], - [reconstruction_output, log_output], - ) - - # ------------------------------------------------------------------------- - # Auto-update gallery whenever user uploads or changes their files - # ------------------------------------------------------------------------- - def update_gallery_on_unified_upload(files, interval): - if not files: - return None, None, None - target_dir, image_paths = handle_uploads(files, interval) - return ( - target_dir, - image_paths, - "Upload complete. Click 'Reconstruct' to begin 3D processing.", - ) + with gr.Row(): + detect_all_btn = gr.Button("Detect All", size="sm") + restore_default_btn = gr.Button("Default", size="sm") + + with gr.Accordion("Example Scenes", open=False): + scenes = get_scene_info("examples") + if scenes: + for i in range(0, len(scenes), 4): + with gr.Row(equal_height=True): + for j in range(4): + scene_idx = i + j + if scene_idx < len(scenes): + scene = scenes[scene_idx] + with gr.Column(scale=1, min_width=150): + scene_img = gr.Image(value=scene["thumbnail"], height=150, interactive=False, show_label=False, sources=[], container=False) + gr.Markdown(f"**{scene['name']}** ({scene['num_images']} images)", elem_classes=["text-center"]) + scene_img.select( + fn=lambda name=scene["name"]: load_example_scene(name), + outputs=[reconstruction_output, target_dir_output, image_gallery, log_output] + ) + + # Connect UI Components + detect_all_btn.click(fn=lambda: COMMON_OBJECTS_PROMPT, outputs=[text_prompt]) + restore_default_btn.click(fn=lambda: DEFAULT_TEXT_PROMPT, outputs=[text_prompt]) def show_resample_button(files): - """Show the resample button only if there are uploaded files containing videos""" - if not files: - return gr.update(visible=False) - - # Check if any uploaded files are videos - video_extensions = [ - ".mp4", - ".avi", - ".mov", - ".mkv", - ".wmv", - ".flv", - ".webm", - ".m4v", - ".3gp", - ] - has_video = False - - for file_data in files: - if isinstance(file_data, dict) and "name" in file_data: - file_path = file_data["name"] - else: - file_path = str(file_data) - - file_ext = os.path.splitext(file_path)[1].lower() - if file_ext in video_extensions: - has_video = True - break - + if not files: return gr.update(visible=False) + video_extensions = [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"] + has_video = any(os.path.splitext(str(f["name"] if isinstance(f, dict) else f))[1].lower() in video_extensions for f in files) return gr.update(visible=has_video) - def hide_resample_button(): - """Hide the resample button after use""" - return gr.update(visible=False) - - def resample_video_with_new_interval(files, new_interval, current_target_dir): - """Resample video with new slider value""" - if not files: - return ( - current_target_dir, - None, - "No files to resample.", - gr.update(visible=False), - ) - - # Check if we have videos to resample - video_extensions = [ - ".mp4", - ".avi", - ".mov", - ".mkv", - ".wmv", - ".flv", - ".webm", - ".m4v", - ".3gp", - ] - has_video = any( - os.path.splitext( - str(file_data["name"] if isinstance(file_data, dict) else file_data) - )[1].lower() - in video_extensions - for file_data in files - ) - - if not has_video: - return ( - current_target_dir, - None, - "No videos found to resample.", - gr.update(visible=False), - ) - - # Clean up old target directory if it exists - if ( - current_target_dir - and current_target_dir != "None" - and os.path.exists(current_target_dir) - ): - shutil.rmtree(current_target_dir) - - # Process files with new interval - target_dir, image_paths = handle_uploads(files, new_interval) - - return ( - target_dir, - image_paths, - f"Video resampled with {new_interval}s interval. Click 'Reconstruct' to begin 3D processing.", - gr.update(visible=False), - ) - unified_upload.change( - fn=update_gallery_on_unified_upload, + fn=update_gallery_on_upload, inputs=[unified_upload, s_time_interval], - outputs=[target_dir_output, image_gallery, log_output], + outputs=[raw_3d_output, target_dir_output, image_gallery, log_output] ).then( - fn=show_resample_button, - inputs=[unified_upload], - outputs=[resample_btn], + fn=show_resample_button, inputs=[unified_upload], outputs=[resample_btn] ) + s_time_interval.change(fn=show_resample_button, inputs=[unified_upload], outputs=[resample_btn]) - # Show resample button when slider changes (only if files are uploaded) - s_time_interval.change( - fn=show_resample_button, - inputs=[unified_upload], - outputs=[resample_btn], - ) + def resample_video_with_new_interval(files, new_interval, current_target_dir): + if not files: return current_target_dir, None, "No files.", gr.update(visible=False) + if current_target_dir and current_target_dir != "None" and os.path.exists(current_target_dir): shutil.rmtree(current_target_dir) + target_dir, image_paths = handle_uploads(files, new_interval) + return target_dir, image_paths, f"Video resampled with {new_interval}s interval.", gr.update(visible=False) - # Handle resample button click resample_btn.click( fn=resample_video_with_new_interval, inputs=[unified_upload, s_time_interval, target_dir_output], - outputs=[target_dir_output, image_gallery, log_output, resample_btn], + outputs=[target_dir_output, image_gallery, log_output, resample_btn] ) - # ------------------------------------------------------------------------- - # Measure tab functionality - # ------------------------------------------------------------------------- - measure_image.select( - fn=measure, - inputs=[processed_data_state, measure_points_state, measure_view_selector], - outputs=[measure_image, measure_points_state, measure_text], - ) - - # ------------------------------------------------------------------------- - # Navigation functionality for Depth, Normal, and Measure tabs - # ------------------------------------------------------------------------- - - # Depth tab navigation - prev_depth_btn.click( - fn=lambda processed_data, current_selector: navigate_depth_view( - processed_data, current_selector, -1 - ), - inputs=[processed_data_state, depth_view_selector], - outputs=[depth_view_selector, depth_map], - ) - - next_depth_btn.click( - fn=lambda processed_data, current_selector: navigate_depth_view( - processed_data, current_selector, 1 - ), - inputs=[processed_data_state, depth_view_selector], - outputs=[depth_view_selector, depth_map], + submit_btn.click( + fn=clear_fields, + outputs=[raw_3d_output, reconstruction_output] + ).then( + fn=update_log, outputs=[log_output] + ).then( + fn=gradio_demo, + inputs=[ + target_dir_output, frame_filter, show_cam, filter_black_bg, filter_white_bg, + apply_mask_checkbox, show_mesh, enable_segmentation, text_prompt, use_sam_checkbox + ], + outputs=[ + raw_3d_output, reconstruction_output, log_output, frame_filter, + processed_data_state, depth_map, normal_map, measure_image, measure_text, + depth_view_selector, normal_view_selector, measure_view_selector + ] + ).then( + fn=lambda: "False", outputs=[is_example] ) - depth_view_selector.change( - fn=lambda processed_data, selector_value: ( - update_depth_view( - processed_data, - int(selector_value.split()[1]) - 1, - ) - if selector_value - else None - ), - inputs=[processed_data_state, depth_view_selector], - outputs=[depth_map], - ) + clear_btn.add([raw_3d_output, reconstruction_output, log_output]) - # Normal tab navigation - prev_normal_btn.click( - fn=lambda processed_data, current_selector: navigate_normal_view( - processed_data, current_selector, -1 - ), - inputs=[processed_data_state, normal_view_selector], - outputs=[normal_view_selector, normal_map], - ) + for comp in [frame_filter, show_cam, show_mesh]: + comp.change( + fn=update_visualization, + inputs=[target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh], + outputs=[raw_3d_output, reconstruction_output, log_output] + ) - next_normal_btn.click( - fn=lambda processed_data, current_selector: navigate_normal_view( - processed_data, current_selector, 1 - ), - inputs=[processed_data_state, normal_view_selector], - outputs=[normal_view_selector, normal_map], - ) + for comp in [filter_black_bg, filter_white_bg]: + comp.change( + fn=update_visualization, + inputs=[target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh], + outputs=[raw_3d_output, reconstruction_output, log_output] + ).then( + fn=update_all_views_on_filter_change, + inputs=[target_dir_output, filter_black_bg, filter_white_bg, processed_data_state, depth_view_selector, normal_view_selector, measure_view_selector], + outputs=[processed_data_state, depth_map, normal_map, measure_image, measure_points_state] + ) - normal_view_selector.change( - fn=lambda processed_data, selector_value: ( - update_normal_view( - processed_data, - int(selector_value.split()[1]) - 1, - ) - if selector_value - else None - ), - inputs=[processed_data_state, normal_view_selector], - outputs=[normal_map], - ) + measure_image.select(fn=measure, inputs=[processed_data_state, measure_points_state, measure_view_selector], outputs=[measure_image, measure_points_state, measure_text]) - # Measure tab navigation - prev_measure_btn.click( - fn=lambda processed_data, current_selector: navigate_measure_view( - processed_data, current_selector, -1 - ), - inputs=[processed_data_state, measure_view_selector], - outputs=[measure_view_selector, measure_image, measure_points_state], - ) + prev_depth_btn.click(fn=lambda d, s: navigate_depth_view(d, s, -1), inputs=[processed_data_state, depth_view_selector], outputs=[depth_view_selector, depth_map]) + next_depth_btn.click(fn=lambda d, s: navigate_depth_view(d, s, 1), inputs=[processed_data_state, depth_view_selector], outputs=[depth_view_selector, depth_map]) + depth_view_selector.change(fn=lambda d, v: update_depth_view(d, int(v.split()[1])-1) if v else None, inputs=[processed_data_state, depth_view_selector], outputs=[depth_map]) - next_measure_btn.click( - fn=lambda processed_data, current_selector: navigate_measure_view( - processed_data, current_selector, 1 - ), - inputs=[processed_data_state, measure_view_selector], - outputs=[measure_view_selector, measure_image, measure_points_state], - ) + prev_normal_btn.click(fn=lambda d, s: navigate_normal_view(d, s, -1), inputs=[processed_data_state, normal_view_selector], outputs=[normal_view_selector, normal_map]) + next_normal_btn.click(fn=lambda d, s: navigate_normal_view(d, s, 1), inputs=[processed_data_state, normal_view_selector], outputs=[normal_view_selector, normal_map]) + normal_view_selector.change(fn=lambda d, v: update_normal_view(d, int(v.split()[1])-1) if v else None, inputs=[processed_data_state, normal_view_selector], outputs=[normal_map]) - measure_view_selector.change( - fn=lambda processed_data, selector_value: ( - update_measure_view(processed_data, int(selector_value.split()[1]) - 1) - if selector_value - else (None, []) - ), - inputs=[processed_data_state, measure_view_selector], - outputs=[measure_image, measure_points_state], - ) + prev_measure_btn.click(fn=lambda d, s: navigate_measure_view(d, s, -1), inputs=[processed_data_state, measure_view_selector], outputs=[measure_view_selector, measure_image, measure_points_state]) + next_measure_btn.click(fn=lambda d, s: navigate_measure_view(d, s, 1), inputs=[processed_data_state, measure_view_selector], outputs=[measure_view_selector, measure_image, measure_points_state]) + measure_view_selector.change(fn=lambda d, v: update_measure_view(d, int(v.split()[1])-1) if v else (None, []), inputs=[processed_data_state, measure_view_selector], outputs=[measure_image, measure_points_state]) - # ------------------------------------------------------------------------- - # Acknowledgement section - # ------------------------------------------------------------------------- gr.HTML(get_acknowledgements_html()) - demo.queue(max_size=20).launch(theme=theme, css=GRADIO_CSS, show_error=True, share=True, ssr_mode=False) \ No newline at end of file +if __name__ == "__main__": + demo.queue(max_size=20).launch(theme=theme, css=CUSTOM_CSS, show_error=True, share=True, ssr_mode=False) \ No newline at end of file