import gc
import os
import shutil
import sys
import time
import uuid
from datetime import datetime

# Set before torch is imported below; keep this ordering.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from pillow_heif import register_heif_opener
import rerun as rr

try:
    import rerun.blueprint as rrb
except ImportError:
    # Blueprint support is optional; predictions_to_rrd() skips it when missing.
    rrb = None

from gradio_rerun import Rerun

register_heif_opener()

sys.path.append("mapanything/")

from mapanything.utils.geometry import depthmap_to_world_frame, points_to_normals
from mapanything.utils.hf_utils.css_and_html import (
    GRADIO_CSS,
    MEASURE_INSTRUCTIONS_HTML,
    get_acknowledgements_html,
    get_description_html,
    get_gradio_theme,
    get_header_html,
)
from mapanything.utils.hf_utils.hf_helpers import initialize_mapanything_model
from mapanything.utils.hf_utils.viz import predictions_to_glb
from mapanything.utils.image import load_images, rgb

from typing import Iterable
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes

# ── Steel-Blue palette ──────────────────────────────────────────────
colors.steel_blue = colors.Color(
    name="steel_blue",
    c50="#EBF3F8",
    c100="#D3E5F0",
    c200="#A8CCE1",
    c300="#7DB3D2",
    c400="#529AC3",
    c500="#4682B4",
    c600="#3E72A0",
    c700="#36638C",
    c800="#2E5378",
    c900="#264364",
    c950="#1E3450",
)


class SteelBlueTheme(Soft):
    """Gradio ``Soft`` theme variant using the steel-blue palette as secondary hue."""

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"),
            "Arial",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        """Build the base theme, then override fills, buttons, sliders and blocks."""
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        super().set(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)",
            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)",
            button_secondary_text_color="black",
            button_secondary_text_color_hover="white",
            button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
            button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
            button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
            button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
            slider_color="*secondary_500",
            slider_color_dark="*secondary_600",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
            button_large_padding="11px",
            color_accent_soft="*primary_100",
            block_label_background_fill="*primary_200",
        )


steel_blue_theme = SteelBlueTheme()

SVG_CUBE = ''
SVG_CHIP = ''


def html_header():
    """Return the header snippet rendered at the top of the app via ``gr.HTML``."""
    # NOTE(review): CUSTOM_CSS styles .app-header / .header-content / .meta-badge
    # etc., but no such markup is present in this string or in the SVG constants.
    # Presumably the tags were lost upstream — confirm against the deployed app.
    return f"""
{SVG_CUBE}

Map-Anything — v1

{SVG_CHIP} facebook/map-anything-v1 3D Reconstruction Depth Estimation Normal Maps Measurements
"""


# Configuration handed to initialize_mapanything_model(); "data_norm_type" is
# also read by process_predictions_for_visualization().
high_level_config = {
    "path": "configs/train.yaml",
    "hf_model_name": "facebook/map-anything-v1",
    "model_str": "mapanything",
    "config_overrides": [
        "machine=aws",
        "model=mapanything",
        "model/task=images_only",
        "model.encoder.uses_torch_hub=false",
    ],
    "checkpoint_name": "model.safetensors",
    "config_name": "config.json",
    "trained_with_amp": True,
    "trained_with_amp_dtype": "bf16",
    "data_norm_type": "dinov2",
    "patch_size": 14,
    "resolution": 518,
}

# Lazily created by run_model() on first use and reused afterwards.
model = None

TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)

CUSTOM_CSS = (GRADIO_CSS or "") + r"""
@import url('https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;500;600;700;800&family=IBM+Plex+Mono:wght@400;500;600&display=swap');
body, .gradio-container { font-family: 'Outfit', sans-serif !important; }
footer { display: none !important; }
/* ── App Header ── */
.app-header { background: linear-gradient(135deg, #1E3450 0%, #264364 30%, #3E72A0 70%, #4682B4 100%); border-radius: 16px; padding: 32px 40px; margin-bottom: 24px; position: relative; overflow: hidden; box-shadow: 0 8px 32px rgba(30, 52, 80, 0.35); }
.app-header::before { content: ''; position: absolute; top: -50%; right: -20%; width: 400px; height: 400px; background: radial-gradient(circle, rgba(255, 255, 255, 0.06) 0%, transparent 70%); border-radius: 50%; }
.app-header::after { content: ''; position: absolute; bottom: -30%; left: -10%; width: 300px; height: 300px; background: radial-gradient(circle, rgba(70, 130, 180, 0.15) 0%, transparent 70%); border-radius: 50%; }
.header-content { display: flex; align-items: center; gap: 24px; position: relative; z-index: 1; }
.header-icon-wrap { width: 64px; height: 64px; background: rgba(255, 255, 255, 0.12); border-radius: 16px; display: flex; align-items: center; justify-content: center; flex-shrink: 0; backdrop-filter: blur(8px); border: 1px solid rgba(255, 255, 255, 0.15); }
/* ── Force header SVGs white in ALL modes ── */
.header-icon-wrap svg, .app-header svg { width: 36px; height: 36px; color: #ffffff !important; stroke: #ffffff !important; }
.meta-badge svg { width: 14px !important; height: 14px !important; color: #ffffff !important; stroke: #ffffff !important; }
.header-text h1 { font-family: 'Outfit', sans-serif; font-size: 2rem; font-weight: 700; color: #fff !important; margin: 0 0 8px 0; letter-spacing: -0.02em; line-height: 1.2; }
.header-meta { display: flex; align-items: center; gap: 12px; flex-wrap: wrap; }
.meta-badge { display: inline-flex; align-items: center; gap: 6px; background: rgba(255, 255, 255, 0.12); color: rgba(255, 255, 255, 0.9) !important; padding: 4px 12px; border-radius: 20px; font-family: 'IBM Plex Mono', monospace; font-size: 0.8rem; font-weight: 500; border: 1px solid rgba(255, 255, 255, 0.1); backdrop-filter: blur(4px); }
.meta-sep { width: 4px; height: 4px; background: rgba(255, 255, 255, 0.35); border-radius: 50%; flex-shrink: 0; }
.meta-cap { color: rgba(255, 255, 255, 0.65) !important; font-size: 0.85rem; font-weight: 400; }
/* ── Page shell ── */
#app-shell { max-width: 1400px; margin: 0 auto; padding: 0 16px 40px; }
/* ── Two-panel layout ── */
#left-panel { min-width: 320px; max-width: 380px; }
#right-panel { flex: 1; min-width: 0; }
/* ── Section labels ── */
.section-label { font-size: 0.7rem !important; font-weight: 600 !important; letter-spacing: 0.08em !important; text-transform: uppercase !important; opacity: 0.5 !important; margin-bottom: 6px !important; margin-top: 16px !important; display: block !important; }
/* ── Upload zone ── */
#upload-zone .wrap { border-radius: 10px !important; min-height: 110px !important; }
/* ── Gallery ── */
#preview-gallery { border-radius: 10px; overflow: hidden; }
/* ── Action buttons ── */
#btn-reconstruct { width: 100% !important; font-size: 0.95rem !important; font-weight: 600 !important; padding: 12px !important; border-radius: 8px !important; }
/* ── Buttons ── */
.primary { border-radius: 10px !important; font-weight: 600 !important; letter-spacing: 0.02em !important; transition: all 0.25s ease !important; font-family: 'Outfit', sans-serif !important; }
.primary:hover { transform: translateY(-2px) !important; box-shadow: 0 6px 20px rgba(70, 130, 180, 0.35) !important; }
.primary:active { transform: translateY(0) !important; }
/* ── Log strip ── */
#log-strip { font-size: 0.82rem !important; padding: 8px 12px !important; border-radius: 6px !important; border: 1px solid var(--border-color-primary) !important; background: var(--background-fill-secondary) !important; min-height: 36px !important; }
/* ── Viewer tabs ── */
#viewer-tabs .tab-nav button { font-size: 0.8rem !important; font-weight: 500 !important; padding: 6px 14px !important; }
#viewer-tabs > .tabitem { padding: 0 !important; }
/* ── Tab transitions ── */
.gradio-tabitem { animation: tabFadeIn 0.35s ease-out; }
@keyframes tabFadeIn { from { opacity: 0; transform: translateY(6px); } to { opacity: 1; transform: translateY(0); } }
/* ── Navigation rows inside tabs ── */
.nav-row { align-items: center !important; gap: 6px !important; margin-bottom: 8px !important; }
.nav-row button { min-width: 80px !important; }
/* ── Options panel ── */
#options-panel { border: 1px solid var(--border-color-primary); border-radius: 10px; padding: 16px; margin-top: 12px; }
#options-panel .gr-markdown h3 { font-size: 0.72rem !important; font-weight: 600 !important; letter-spacing: 0.07em !important; text-transform: uppercase !important; opacity: 0.5 !important; margin: 14px 0 6px !important; }
#options-panel .gr-markdown h3:first-child { margin-top: 0 !important; }
/* ── Frame filter ── */
#frame-filter { margin-top: 12px; }
/* ── Examples section ── */
#examples-section { margin-top: 36px; padding-top: 24px; border-top: 1px solid var(--border-color-primary); }
#examples-section h2 { font-size: 1.1rem !important; font-weight: 600 !important; margin-bottom: 4px !important; }
#examples-section .scene-caption { font-size: 0.75rem !important; text-align: center !important; opacity: 0.65 !important; margin-top: 4px !important; }
.scene-thumb img { border-radius: 8px; transition: opacity .15s; }
.scene-thumb img:hover { opacity: .85; }
/* ── Measure note ── */
.measure-note { font-size: 0.78rem !important; opacity: 0.6 !important; margin-top: 6px !important; }
#col-container { margin: 0 auto; max-width: 960px; }
/* ── Accordion ── */
.gradio-accordion { border-radius: 10px !important; border: 1px solid rgba(70, 130, 180, 0.2) !important; }
.gradio-accordion > .label-wrap { border-radius: 10px !important; }
/* ── Labels ── */
label { font-weight: 600 !important; font-family: 'Outfit', sans-serif !important; }
/* ── Slider ── */
.gradio-slider input[type="range"] { accent-color: #4682B4 !important; }
/* ── Scrollbar ── */
::-webkit-scrollbar { width: 8px; height: 8px; }
::-webkit-scrollbar-track { background: rgba(70, 130, 180, 0.06); border-radius: 4px; }
::-webkit-scrollbar-thumb { background: linear-gradient(135deg, #4682B4, #3E72A0); border-radius: 4px; }
::-webkit-scrollbar-thumb:hover { background: linear-gradient(135deg, #3E72A0, #2E5378); }
/* ── Dark-mode overrides for header (keep text/SVG white) ── */
@media (prefers-color-scheme: dark) {
  .app-header { background: linear-gradient(135deg, #1E3450 0%, #264364 30%, #3E72A0 70%, #4682B4 100%); }
  .header-text h1 { color: #fff !important; }
  .header-icon-wrap svg, .app-header svg, .meta-badge svg { color: #ffffff !important; stroke: #ffffff !important; }
  .meta-badge { color: rgba(255, 255, 255, 0.9) !important; }
  .meta-cap { color: rgba(255, 255, 255, 0.65) !important; }
}
/* Also handle Gradio's own .dark class */
.dark .header-text h1 { color: #fff !important; }
.dark .header-icon-wrap svg, .dark .app-header svg, .dark .meta-badge svg { color: #ffffff !important; stroke: #ffffff !important; }
.dark .meta-badge { color: rgba(255, 255, 255, 0.9) !important; }
.dark .meta-cap { color: rgba(255, 255, 255, 0.65) !important; }
/* ── Responsive ── */
@media (max-width: 768px) {
  .app-header { padding: 20px 24px; }
  .header-text h1 { font-size: 1.5rem; }
  .header-content { flex-direction: column; align-items: flex-start; gap: 16px; }
  .header-meta { gap: 8px; }
}
"""


def _selector_to_index(selector_value):
    """Convert a ``"View N"`` selector label to a zero-based index (0 if unparsable)."""
    try:
        return int(selector_value.split()[1]) - 1
    except (AttributeError, IndexError, TypeError, ValueError):
        return 0


def _glb_path(target_dir, frame_filter, show_cam, show_mesh, filter_black_bg, filter_white_bg):
    """Return the GLB path inside ``target_dir`` that encodes the export options."""
    filter_tag = frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')
    return os.path.join(
        target_dir,
        f"glbscene_{filter_tag}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
    )


def _load_predictions(predictions_path):
    """Load a ``predictions.npz`` file into a plain dict and close the file."""
    # NOTE(review): allow_pickle=True is kept from the original code. The file is
    # written by gradio_demo() itself; do not point this at untrusted .npz files.
    with np.load(predictions_path, allow_pickle=True) as loaded:
        return {key: loaded[key] for key in loaded.files}


def predictions_to_rrd(predictions, glbfile, target_dir, frame_filter="All", show_cam=True):
    """Write a Rerun recording (.rrd) for a reconstruction and return its path.

    Logs the exported GLB model, optionally the cameras (pose, pinhole, RGB
    image) and a subsampled colored point cloud. Camera, point-cloud, axes and
    blueprint logging are best-effort: failures are printed and skipped.

    Args:
        predictions: Mapping that may contain "extrinsic", "intrinsic",
            "images", "world_points" and "final_mask" arrays.
        glbfile: Path of the GLB file logged as a 3D asset.
        target_dir: Directory the .rrd file is written to.
        frame_filter: Accepted for interface compatibility; not used here.
        show_cam: Whether to log cameras.

    Returns:
        Path of the saved .rrd file.
    """
    run_id = str(uuid.uuid4())
    timestamp = datetime.now().strftime("%Y-%m-%dT%H%M%S")
    rrd_path = os.path.join(target_dir, f"mapanything_{timestamp}.rrd")

    # Pick whichever recording API this rerun version exposes.
    if hasattr(rr, "new_recording"):
        rec = rr.new_recording(application_id="MapAnything-3D-Viewer", recording_id=run_id)
    elif hasattr(rr, "RecordingStream"):
        rec = rr.RecordingStream(application_id="MapAnything-3D-Viewer", recording_id=run_id)
    else:
        rr.init("MapAnything-3D-Viewer", recording_id=run_id, spawn=False)
        rec = rr

    rec.log("world", rr.Clear(recursive=True), static=True)
    rec.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, static=True)

    try:
        axes = (
            ("x", [0.5, 0, 0], [255, 0, 0]),
            ("y", [0, 0.5, 0], [0, 255, 0]),
            ("z", [0, 0, 0.5], [0, 0, 255]),
        )
        for axis_name, vector, color in axes:
            rec.log(
                f"world/axes/{axis_name}",
                rr.Arrows3D(vectors=[vector], colors=[color]),
                static=True,
            )
    except Exception as e:
        print(f"Axes logging failed (non-fatal): {e}")

    rec.log("world/model", rr.Asset3D(path=glbfile), static=True)

    if show_cam and "extrinsic" in predictions and "intrinsic" in predictions:
        try:
            images = predictions["images"] if "images" in predictions else None
            for i, (ext, intr) in enumerate(zip(predictions["extrinsic"], predictions["intrinsic"])):
                cam_path = f"world/cameras/cam_{i:03d}"
                rec.log(
                    cam_path,
                    rr.Transform3D(translation=ext[:3, 3], mat3x3=ext[:3, :3]),
                    static=True,
                )
                fx, fy = intr[0, 0], intr[1, 1]
                cx, cy = intr[0, 2], intr[1, 2]
                img = images[i] if images is not None and i < len(images) else None
                if img is not None:
                    h, w = img.shape[:2]
                else:
                    # No image for this camera: fall back to the square default size.
                    h, w = 518, 518
                rec.log(
                    f"{cam_path}/image",
                    rr.Pinhole(focal_length=[fx, fy], principal_point=[cx, cy], width=w, height=h),
                    static=True,
                )
                if img is not None:
                    if img.dtype != np.uint8:
                        img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
                    rec.log(f"{cam_path}/image/rgb", rr.Image(img), static=True)
        except Exception as e:
            print(f"Camera logging failed (non-fatal): {e}")

    if "world_points" in predictions and "images" in predictions:
        try:
            final_mask = predictions.get("final_mask")
            all_points, all_colors = [], []
            for i, (pts, img) in enumerate(zip(predictions["world_points"], predictions["images"])):
                if final_mask is not None:
                    mask = final_mask[i].astype(bool)
                else:
                    mask = np.ones(pts.shape[:2], dtype=bool)
                img_flat = img[mask]
                if img_flat.dtype != np.uint8:
                    img_flat = (np.clip(img_flat, 0, 1) * 255).astype(np.uint8)
                all_points.append(pts[mask])
                all_colors.append(img_flat)
            if all_points:
                all_points = np.concatenate(all_points, axis=0)
                all_colors = np.concatenate(all_colors, axis=0)
                # Randomly subsample very large clouds to cap the logged point count.
                max_pts = 500_000
                if len(all_points) > max_pts:
                    idx = np.random.choice(len(all_points), max_pts, replace=False)
                    all_points = all_points[idx]
                    all_colors = all_colors[idx]
                rec.log(
                    "world/point_cloud",
                    rr.Points3D(positions=all_points, colors=all_colors, radii=0.002),
                    static=True,
                )
        except Exception as e:
            print(f"Point cloud logging failed (non-fatal): {e}")

    if rrb is not None:
        try:
            blueprint = rrb.Blueprint(
                rrb.Spatial3DView(origin="/world", name="3D View"),
                collapse_panels=True,
            )
            rec.send_blueprint(blueprint)
        except Exception as e:
            print(f"Blueprint creation failed (non-fatal): {e}")

    rec.save(rrd_path)
    return rrd_path


@spaces.GPU(duration=120)
def run_model(target_dir, apply_mask=True, mask_edges=True, filter_black_bg=False, filter_white_bg=False):
    """Run MapAnything on ``<target_dir>/images`` and collect the outputs.

    Args:
        target_dir: Directory containing an ``images`` sub-folder.
        apply_mask: Forwarded to ``model.infer``.
        mask_edges: Forwarded to ``model.infer``.
        filter_black_bg: Forwarded to process_predictions_for_visualization().
        filter_white_bg: Forwarded to process_predictions_for_visualization().

    Returns:
        ``(predictions, processed_data)``: a dict of stacked numpy arrays
        ("extrinsic", "intrinsic", "world_points", "depth", "images",
        "final_mask") and the per-view visualization dict.

    Raises:
        ValueError: If no images were loaded.
    """
    global model

    print(f"Processing images from {target_dir}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the model once; later calls only move the cached instance to the device.
    if model is None:
        model = initialize_mapanything_model(high_level_config, device)
    else:
        model = model.to(device)
    model.eval()

    print("Loading images...")
    image_folder_path = os.path.join(target_dir, "images")
    views = load_images(image_folder_path)
    print(f"Loaded {len(views)} images")
    if len(views) == 0:
        raise ValueError("No images found. Check your upload.")

    print("Running inference...")
    outputs = model.infer(
        views,
        apply_mask=apply_mask,
        mask_edges=mask_edges,
        memory_efficient_inference=False,
    )

    extrinsic_list, intrinsic_list, world_points_list = [], [], []
    depth_maps_list, images_list, final_mask_list = [], [], []
    for pred in outputs:
        depthmap_torch = pred["depth_z"][0].squeeze(-1)
        intrinsics_torch = pred["intrinsics"][0]
        camera_pose_torch = pred["camera_poses"][0]
        pts3d_computed, valid_mask = depthmap_to_world_frame(
            depthmap_torch, intrinsics_torch, camera_pose_torch
        )
        if "mask" in pred:
            mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool)
        else:
            mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool)
        # A pixel is kept only if both the model mask and the geometric validity agree.
        mask = mask & valid_mask.cpu().numpy()

        extrinsic_list.append(camera_pose_torch.cpu().numpy())
        intrinsic_list.append(intrinsics_torch.cpu().numpy())
        world_points_list.append(pts3d_computed.cpu().numpy())
        depth_maps_list.append(depthmap_torch.cpu().numpy())
        images_list.append(pred["img_no_norm"][0].cpu().numpy())
        final_mask_list.append(mask)

    depth_maps = np.stack(depth_maps_list, axis=0)
    if len(depth_maps.shape) == 3:
        # Add a trailing channel axis when the stacked depth maps are 3-D.
        depth_maps = depth_maps[..., np.newaxis]

    predictions = {
        "extrinsic": np.stack(extrinsic_list, axis=0),
        "intrinsic": np.stack(intrinsic_list, axis=0),
        "world_points": np.stack(world_points_list, axis=0),
        "depth": depth_maps,
        "images": np.stack(images_list, axis=0),
        "final_mask": np.stack(final_mask_list, axis=0),
    }

    processed_data = process_predictions_for_visualization(
        predictions, views, high_level_config, filter_black_bg, filter_white_bg
    )
    torch.cuda.empty_cache()
    return predictions, processed_data


def update_view_selectors(processed_data):
    """Return three dropdowns (depth, normal, measure) listing one entry per view."""
    if processed_data:
        choices = [f"View {i + 1}" for i in range(len(processed_data))]
    else:
        choices = ["View 1"]
    return (
        gr.Dropdown(choices=choices, value=choices[0]),
        gr.Dropdown(choices=choices, value=choices[0]),
        gr.Dropdown(choices=choices, value=choices[0]),
    )


def get_view_data_by_index(processed_data, view_index):
    """Return the view dict at ``view_index`` (clamped to range), or None if empty."""
    if not processed_data:
        return None
    view_keys = list(processed_data.keys())
    clamped = max(0, min(view_index, len(view_keys) - 1))
    return processed_data[view_keys[clamped]]


def update_depth_view(processed_data, view_index):
    """Return the colorized depth image for a view, or None when unavailable."""
    view_data = get_view_data_by_index(processed_data, view_index)
    if view_data is None or view_data["depth"] is None:
        return None
    return colorize_depth(view_data["depth"], mask=view_data.get("mask"))


def update_normal_view(processed_data, view_index):
    """Return the colorized normal image for a view, or None when unavailable."""
    view_data = get_view_data_by_index(processed_data, view_index)
    if view_data is None or view_data["normal"] is None:
        return None
    return colorize_normal(view_data["normal"], mask=view_data.get("mask"))


def update_measure_view(processed_data, view_index):
    """Return ``(image, [])`` for the measure tab with masked pixels tinted.

    The empty list resets the stored measurement points. Returns
    ``(None, [])`` when there is no data.
    """
    view_data = get_view_data_by_index(processed_data, view_index)
    if view_data is None:
        return None, []

    image = view_data["image"].copy()
    if image.dtype != np.uint8:
        # Values at or below 1.0 are scaled to 0-255 before the cast.
        image = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8)

    if view_data["mask"] is not None:
        invalid_mask = ~view_data["mask"]
        if invalid_mask.any():
            # Blend a light tint over pixels that have no valid depth.
            overlay_color = np.array([255, 220, 220], dtype=np.uint8)
            alpha = 0.5
            image[invalid_mask, :3] = (
                (1 - alpha) * image[invalid_mask, :3] + alpha * overlay_color
            ).astype(np.uint8)
    return image, []


def navigate_depth_view(processed_data, current_selector_value, direction):
    """Step the depth view by ``direction`` (wrapping) and return (label, image)."""
    if not processed_data:
        return "View 1", None
    new_view = (_selector_to_index(current_selector_value) + direction) % len(processed_data)
    return f"View {new_view + 1}", update_depth_view(processed_data, new_view)


def navigate_normal_view(processed_data, current_selector_value, direction):
    """Step the normal view by ``direction`` (wrapping) and return (label, image)."""
    if not processed_data:
        return "View 1", None
    new_view = (_selector_to_index(current_selector_value) + direction) % len(processed_data)
    return f"View {new_view + 1}", update_normal_view(processed_data, new_view)


def navigate_measure_view(processed_data, current_selector_value, direction):
    """Step the measure view by ``direction`` and return (label, image, points)."""
    if not processed_data:
        return "View 1", None, []
    new_view = (_selector_to_index(current_selector_value) + direction) % len(processed_data)
    measure_image, measure_points = update_measure_view(processed_data, new_view)
    return f"View {new_view + 1}", measure_image, measure_points


def populate_visualization_tabs(processed_data):
    """Return first-view (depth, normal, measure image, []) for the tabs."""
    if not processed_data:
        return None, None, None, []
    return (
        update_depth_view(processed_data, 0),
        update_normal_view(processed_data, 0),
        update_measure_view(processed_data, 0)[0],
        [],
    )


def handle_uploads(unified_upload, s_time_interval=1.0):
    """Copy uploaded images / extract video frames into a fresh input directory.

    Videos are sampled roughly every ``s_time_interval`` seconds and written as
    PNG frames; HEIC/HEIF files are converted to JPEG (copied unchanged if the
    conversion fails); everything else is copied as-is.

    Args:
        unified_upload: Iterable of uploaded files (dicts with a "name" key or
            objects whose ``str()`` is the path), or None.
        s_time_interval: Seconds between extracted video frames.

    Returns:
        ``(target_dir, image_paths)`` with ``image_paths`` sorted.
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    target_dir_images = os.path.join(target_dir, "images")
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []
    video_extensions = [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"]

    for file_data in unified_upload or []:
        if isinstance(file_data, dict) and "name" in file_data:
            file_path = file_data["name"]
        else:
            file_path = str(file_data)
        file_ext = os.path.splitext(file_path)[1].lower()
        base_name = os.path.splitext(os.path.basename(file_path))[0]

        if file_ext in video_extensions:
            vs = cv2.VideoCapture(file_path)
            try:
                fps = vs.get(cv2.CAP_PROP_FPS)
                # Guard against fps of 0/NaN or tiny intervals, which would make
                # the modulo below divide by zero.
                if fps and fps > 0:
                    frame_interval = max(1, int(fps * s_time_interval))
                else:
                    frame_interval = 1
                count, video_frame_num = 0, 0
                while True:
                    gotit, frame = vs.read()
                    if not gotit:
                        break
                    count += 1
                    if count % frame_interval == 0:
                        image_path = os.path.join(
                            target_dir_images, f"{base_name}_{video_frame_num:06}.png"
                        )
                        cv2.imwrite(image_path, frame)
                        image_paths.append(image_path)
                        video_frame_num += 1
            finally:
                vs.release()
            print(f"Extracted {video_frame_num} frames from: {os.path.basename(file_path)}")
        elif file_ext in [".heic", ".heif"]:
            try:
                with Image.open(file_path) as img:
                    if img.mode not in ("RGB", "L"):
                        img = img.convert("RGB")
                    dst_path = os.path.join(target_dir_images, f"{base_name}.jpg")
                    img.save(dst_path, "JPEG", quality=95)
                    image_paths.append(dst_path)
            except Exception as e:
                # Conversion is best-effort: fall back to copying the original file.
                print(f"Error converting HEIC {file_path}: {e}")
                dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
                shutil.copy(file_path, dst_path)
                image_paths.append(dst_path)
        else:
            dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
            shutil.copy(file_path, dst_path)
            image_paths.append(dst_path)

    image_paths = sorted(image_paths)
    print(f"Files processed to {target_dir_images}; took {time.time() - start_time:.3f}s")
    return target_dir, image_paths


@spaces.GPU(duration=120)
def gradio_demo(target_dir, frame_filter="All", show_cam=True, filter_black_bg=False, filter_white_bg=False, apply_mask=True, show_mesh=True):
    """Run the full reconstruction and build every output of the UI.

    Runs the model, saves ``predictions.npz``, exports the GLB scene, writes
    the Rerun recording and prepares the depth/normal/measure tabs.

    Returns:
        An 11-tuple: rrd path, log message, frame-filter dropdown, processed
        data, depth image, normal image, measure image, measure text, and the
        three view-selector dropdowns. The "no target directory" path returns
        the same number of values so it matches the success path.
    """
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return (
            None,
            "No valid target directory found. Please upload first.",
            gr.update(),
            None,
            None,
            None,
            None,
            "",
            gr.update(),
            gr.update(),
            gr.update(),
        )

    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    target_dir_images = os.path.join(target_dir, "images")
    all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
    all_files_labeled = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
    frame_filter_choices = ["All"] + all_files_labeled

    print("Running MapAnything model...")
    with torch.no_grad():
        # Forward the background filters so the first depth/normal/measure
        # views already reflect the selected options.
        predictions, processed_data = run_model(
            target_dir,
            apply_mask,
            filter_black_bg=filter_black_bg,
            filter_white_bg=filter_white_bg,
        )

    np.savez(os.path.join(target_dir, "predictions.npz"), **predictions)

    if frame_filter is None:
        frame_filter = "All"

    glbfile = _glb_path(target_dir, frame_filter, show_cam, show_mesh, filter_black_bg, filter_white_bg)
    glbscene = predictions_to_glb(
        predictions,
        filter_by_frames=frame_filter,
        show_cam=show_cam,
        mask_black_bg=filter_black_bg,
        mask_white_bg=filter_white_bg,
        as_mesh=show_mesh,
    )
    glbscene.export(file_obj=glbfile)

    rrd_path = predictions_to_rrd(predictions, glbfile, target_dir, frame_filter, show_cam)

    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    print(f"Total time: {time.time() - start_time:.2f}s")
    log_msg = f"✅ Reconstruction complete — {len(all_files)} frames processed."

    depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(processed_data)
    depth_selector, normal_selector, measure_selector = update_view_selectors(processed_data)

    return (
        rrd_path,
        log_msg,
        gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
        processed_data,
        depth_vis,
        normal_vis,
        measure_img,
        "",
        depth_selector,
        normal_selector,
        measure_selector,
    )


def colorize_depth(depth_map, mask=None):
    """Map a depth array to an RGB uint8 image using the reversed turbo colormap.

    Valid pixels (depth > 0 and inside ``mask`` when given) are normalized
    between their 5th and 95th percentiles; invalid pixels are painted white.
    Returns None when ``depth_map`` is None.
    """
    if depth_map is None:
        return None

    depth_normalized = depth_map.copy()
    valid_mask = depth_normalized > 0
    if mask is not None:
        valid_mask = valid_mask & mask

    if valid_mask.sum() > 0:
        valid_depths = depth_normalized[valid_mask]
        p5, p95 = np.percentile(valid_depths, 5), np.percentile(valid_depths, 95)
        span = p95 - p5
        if span > 0:
            depth_normalized[valid_mask] = (valid_depths - p5) / span
        else:
            # Constant depth: avoid dividing by zero and use the colormap midpoint.
            depth_normalized[valid_mask] = 0.5

    import matplotlib.pyplot as plt

    colored = (plt.cm.turbo_r(depth_normalized)[:, :, :3] * 255).astype(np.uint8)
    colored[~valid_mask] = [255, 255, 255]
    return colored


def colorize_normal(normal_map, mask=None):
    """Map normals from [-1, 1] to uint8 [0, 255]; masked-out pixels are zeroed first."""
    if normal_map is None:
        return None
    normal_vis = normal_map.copy()
    if mask is not None:
        normal_vis[~mask] = [0, 0, 0]
    return ((normal_vis + 1.0) / 2.0 * 255).astype(np.uint8)


def process_predictions_for_visualization(predictions, views, high_level_config, filter_black_bg=False, filter_white_bg=False):
    """Build the per-view dict consumed by the depth / normal / measure tabs.

    Args:
        predictions: Dict with "world_points", "final_mask" and "depth".
        views: Loaded views; each provides ``view["img"]``.
        high_level_config: Config dict; "data_norm_type" is passed to ``rgb``.
        filter_black_bg: Drop pixels whose channel sum (0-255 scale) is below 16.
        filter_white_bg: Drop pixels whose three channels are all above 240.

    Returns:
        Dict keyed by view index with "image", "points3d", "depth", "normal"
        and "mask" entries.
    """
    processed_data = {}
    for view_idx, view in enumerate(views):
        image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
        pred_pts3d = predictions["world_points"][view_idx]
        mask = predictions["final_mask"][view_idx].copy()

        if filter_black_bg or filter_white_bg:
            # Thresholds below are on a 0-255 scale.
            view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
            if filter_black_bg:
                mask = mask & (view_colors.sum(axis=2) >= 16)
            if filter_white_bg:
                mask = mask & ~(
                    (view_colors[:, :, 0] > 240)
                    & (view_colors[:, :, 1] > 240)
                    & (view_colors[:, :, 2] > 240)
                )

        normals, _ = points_to_normals(pred_pts3d, mask=mask)
        processed_data[view_idx] = {
            "image": image[0],
            "points3d": pred_pts3d,
            "depth": predictions["depth"][view_idx].squeeze(),
            "normal": normals,
            "mask": mask,
        }
    return processed_data


def measure(processed_data, measure_points, current_view_selector, event: gr.SelectData):
    """Handle a click on the measure image.

    The first click records a point and reports its depth; the second click
    draws a line, reports the 3D distance between both points and resets the
    stored points. Clicks on masked pixels are rejected.

    Returns:
        ``(image, measure_points, markdown_text)``.
    """
    try:
        if not processed_data:
            return None, [], "No data available"

        current_view_index = _selector_to_index(current_view_selector)
        current_view_index = max(0, min(current_view_index, len(processed_data) - 1))
        current_view = processed_data[list(processed_data.keys())[current_view_index]]
        if current_view is None:
            return None, [], "No view data available"

        def in_bounds(p, arr):
            # p is (x, y); arrays are indexed [row, col].
            return 0 <= p[1] < arr.shape[0] and 0 <= p[0] < arr.shape[1]

        point2d = event.index[0], event.index[1]
        mask = current_view["mask"]
        if mask is not None and in_bounds(point2d, mask) and not mask[point2d[1], point2d[0]]:
            masked_image, _ = update_measure_view(processed_data, current_view_index)
            return masked_image, measure_points, 'Cannot measure on masked areas'

        # Work on a copy so the incoming state list is not mutated in place.
        measure_points = [*(measure_points or []), point2d]

        image, _ = update_measure_view(processed_data, current_view_index)
        if image is None:
            return None, [], "No image available"
        image = image.copy()
        if image.dtype != np.uint8:
            image = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8)

        points3d = current_view["points3d"]
        depth = current_view["depth"]

        for p in measure_points:
            if in_bounds(p, image):
                image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2)

        depth_text = ""
        for i, p in enumerate(measure_points):
            if depth is not None and in_bounds(p, depth):
                depth_text += f"- **P{i + 1} depth: {depth[p[1], p[0]]:.2f}m**\n"
            elif points3d is not None and in_bounds(p, points3d):
                depth_text += f"- **P{i + 1} Z-coord: {points3d[p[1], p[0], 2]:.2f}m**\n"

        if len(measure_points) == 2:
            point1, point2 = measure_points
            if in_bounds(point1, image) and in_bounds(point2, image):
                image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2)

            distance_text = "- **Distance: Unable to compute**"
            if points3d is not None and in_bounds(point1, points3d) and in_bounds(point2, points3d):
                try:
                    distance = np.linalg.norm(
                        points3d[point1[1], point1[0]] - points3d[point2[1], point2[0]]
                    )
                    distance_text = f"- **Distance: {distance:.2f}m**"
                except Exception as e:
                    distance_text = f"- **Distance error: {e}**"
            # Two points complete a measurement: reset the stored points.
            return image, [], depth_text + distance_text

        return image, measure_points, depth_text
    except Exception as e:
        print(f"Measure error: {e}")
        return None, [], f"Error: {e}"


def clear_fields():
    """Return None to clear a bound output component."""
    return None


def update_log():
    """Return the in-progress status message."""
    return "⏳ Loading and reconstructing…"


def update_visualization(
    target_dir,
    frame_filter,
    show_cam,
    is_example,
    filter_black_bg=False,
    filter_white_bg=False,
    show_mesh=True,
):
    """Rebuild the 3D view from saved predictions after an option change.

    Reuses an existing GLB export for the same option combination when present.

    Returns:
        ``(rrd_path_or_gr.update(), message)``.
    """
    if is_example == "True":
        return gr.update(), "No reconstruction available. Please click Reconstruct first."
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return gr.update(), "No reconstruction available. Please upload first."

    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return gr.update(), "No reconstruction found. Please run Reconstruct first."

    # Same fallback gradio_demo() applies; a cleared dropdown yields None.
    if frame_filter is None:
        frame_filter = "All"

    predictions = _load_predictions(predictions_path)

    glbfile = _glb_path(target_dir, frame_filter, show_cam, show_mesh, filter_black_bg, filter_white_bg)
    if not os.path.exists(glbfile):
        glbscene = predictions_to_glb(
            predictions,
            filter_by_frames=frame_filter,
            show_cam=show_cam,
            mask_black_bg=filter_black_bg,
            mask_white_bg=filter_white_bg,
            as_mesh=show_mesh,
        )
        glbscene.export(file_obj=glbfile)

    rrd_path = predictions_to_rrd(predictions, glbfile, target_dir, frame_filter, show_cam)
    return rrd_path, "Visualization updated."


def update_all_views_on_filter_change(
    target_dir,
    filter_black_bg,
    filter_white_bg,
    processed_data,
    depth_view_selector,
    normal_view_selector,
    measure_view_selector,
):
    """Recompute the per-view data and tab images when a background filter changes.

    Returns:
        ``(processed_data, depth_image, normal_image, measure_image, [])``.
        On any failure the incoming ``processed_data`` is returned with empty
        images.
    """
    unchanged = (processed_data, None, None, None, [])
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return unchanged

    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return unchanged

    try:
        predictions = _load_predictions(predictions_path)
        views = load_images(os.path.join(target_dir, "images"))
        new_processed_data = process_predictions_for_visualization(
            predictions,
            views,
            high_level_config,
            filter_black_bg,
            filter_white_bg,
        )
        depth_vis = update_depth_view(new_processed_data, _selector_to_index(depth_view_selector))
        normal_vis = update_normal_view(new_processed_data, _selector_to_index(normal_view_selector))
        measure_img, _ = update_measure_view(new_processed_data, _selector_to_index(measure_view_selector))
        return new_processed_data, depth_vis, normal_vis, measure_img, []
    except Exception as e:
        print(f"Filter change error: {e}")
        return unchanged


def get_scene_info(examples_dir):
    """List example scenes: one entry per sub-folder that contains images.

    Returns:
        List of dicts with "name", "path", "thumbnail" (first image),
        "num_images" and "image_files" (sorted), ordered by folder name.
    """
    import glob

    scenes = []
    if not os.path.exists(examples_dir):
        return scenes

    for scene_folder in sorted(os.listdir(examples_dir)):
        scene_path = os.path.join(examples_dir, scene_folder)
        if not os.path.isdir(scene_path):
            continue
        found = set()
        for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]:
            # Both patterns match the same file on case-insensitive filesystems;
            # the set keeps each path once.
            found.update(glob.glob(os.path.join(scene_path, ext)))
            found.update(glob.glob(os.path.join(scene_path, ext.upper())))
        if found:
            image_files = sorted(found)
            scenes.append({
                "name": scene_folder,
                "path": scene_path,
                "thumbnail": image_files[0],
                "num_images": len(image_files),
                "image_files": image_files,
            })
    return scenes


def load_example_scene(scene_name, examples_dir="examples"):
    """Stage an example scene as if its images had been uploaded.

    Returns:
        ``(None, target_dir, image_paths, message)``, or
        ``(None, None, None, "Scene not found")`` for an unknown scene.
    """
    scenes = get_scene_info(examples_dir)
    selected_scene = next((s for s in scenes if s["name"] == scene_name), None)
    if selected_scene is None:
        return None, None, None, "Scene not found"
    target_dir, image_paths = handle_uploads(selected_scene["image_files"], 1.0)
    return None, target_dir, image_paths, f"Loaded '{scene_name}' — {selected_scene['num_images']} images. Click Reconstruct."
with gr.Blocks() as demo:
    # Hidden components / state shared between callbacks.
    is_example = gr.Textbox(visible=False, value="None")
    num_images = gr.Textbox(visible=False, value="None")
    processed_data_state = gr.State(value=None)
    measure_points_state = gr.State(value=[])
    target_dir_output = gr.Textbox(visible=False, value="None")

    with gr.Column(elem_id="app-shell"):
        # ── New styled header ──
        gr.HTML(html_header())

        with gr.Row(equal_height=False):
            # ── Left Panel ──
            with gr.Column(elem_id="left-panel", scale=0):
                unified_upload = gr.File(
                    file_count="multiple",
                    label="Upload Images/Videos",
                    file_types=["image", "video"],
                    height="150",
                )
                with gr.Row():
                    s_time_interval = gr.Slider(
                        minimum=0.1,
                        maximum=5.0,
                        value=1.0,
                        step=0.1,
                        label="Video interval (sec)",
                        scale=3,
                    )
                    resample_btn = gr.Button("Resample", visible=False, variant="secondary", scale=1)
                image_gallery = gr.Gallery(
                    columns=2,
                    height="150",
                )
                gr.ClearButton(
                    [unified_upload, image_gallery],
                    value="Clear uploads",
                    variant="secondary",
                    size="sm",
                )
                submit_btn = gr.Button("Reconstruct", variant="primary")
                with gr.Accordion("Options", open=False):
                    gr.Markdown("### Point Cloud")
                    show_cam = gr.Checkbox(label="Show cameras", value=True)
                    show_mesh = gr.Checkbox(label="Show mesh", value=True)
                    filter_black_bg = gr.Checkbox(label="Filter black background", value=False)
                    filter_white_bg = gr.Checkbox(label="Filter white background", value=False)
                    gr.Markdown("### Reconstruction (next run)")
                    apply_mask_checkbox = gr.Checkbox(
                        label="Apply ambiguous-depth mask & edges",
                        value=True,
                    )

            # ── Right Panel ──
            with gr.Column(elem_id="right-panel", scale=1):
                log_output = gr.Markdown(
                    "Upload a video or images, then click **Reconstruct**.",
                    elem_id="log-strip",
                )
                with gr.Tabs(elem_id="viewer-tabs"):
                    with gr.Tab("3D View"):
                        reconstruction_output = Rerun(
                            label="Rerun 3D Viewer",
                            height=672,
                        )
                    with gr.Tab("Depth"):
                        with gr.Row(elem_classes=["nav-row"]):
                            prev_depth_btn = gr.Button("◀ Prev", size="sm", scale=1)
                            depth_view_selector = gr.Dropdown(
                                choices=["View 1"],
                                value="View 1",
                                label="View",
                                scale=3,
                                interactive=True,
                                allow_custom_value=True,
                                show_label=False,
                            )
                            next_depth_btn = gr.Button("Next ▶", size="sm", scale=1)
                        depth_map = gr.Image(
                            type="numpy",
                            label="Depth Map",
                            format="png",
                            interactive=False,
                        )
                    with gr.Tab("Normal"):
                        with gr.Row(elem_classes=["nav-row"]):
                            prev_normal_btn = gr.Button("◀ Prev", size="sm", scale=1)
                            normal_view_selector = gr.Dropdown(
                                choices=["View 1"],
                                value="View 1",
                                label="View",
                                scale=3,
                                interactive=True,
                                allow_custom_value=True,
                                show_label=False,
                            )
                            next_normal_btn = gr.Button("Next ▶", size="sm", scale=1)
                        normal_map = gr.Image(
                            type="numpy",
                            label="Normal Map",
                            format="png",
                            interactive=False,
                        )
                    with gr.Tab("Measure"):
                        gr.Markdown(MEASURE_INSTRUCTIONS_HTML)
                        with gr.Row(elem_classes=["nav-row"]):
                            prev_measure_btn = gr.Button("◀ Prev", size="sm", scale=1)
                            measure_view_selector = gr.Dropdown(
                                choices=["View 1"],
                                value="View 1",
                                label="View",
                                scale=3,
                                interactive=True,
                                allow_custom_value=True,
                                show_label=False,
                            )
                            next_measure_btn = gr.Button("Next ▶", size="sm", scale=1)
                        measure_image = gr.Image(
                            type="numpy",
                            show_label=False,
                            format="webp",
                            interactive=False,
                            sources=[],
                        )
                        gr.Markdown(
                            "Light-grey areas have no depth — measurements cannot be placed there.",
                            elem_classes=["measure-note"],
                        )
                        measure_text = gr.Markdown("")
                # NOTE(review): nesting of this column was reconstructed from a
                # flattened source; confirm it belongs under the right panel.
                with gr.Column():
                    frame_filter = gr.Dropdown(
                        choices=["All"],
                        value="All",
                        label="Filter by Frame",
                        show_label=True,
                    )

        with gr.Column(elem_id="examples-section"):
            gr.Markdown("## Example Scenes")
            gr.Markdown("Click a thumbnail to load the scene, then press **Reconstruct**.")
            scenes = get_scene_info("examples")
            if scenes:
                # Thumbnails are laid out four per row; short rows are padded
                # with empty columns so the grid stays aligned.
                for i in range(0, len(scenes), 4):
                    with gr.Row():
                        for j in range(4):
                            idx = i + j
                            if idx < len(scenes):
                                scene = scenes[idx]
                                with gr.Column(scale=1, min_width=140, elem_classes=["scene-thumb"]):
                                    scene_img = gr.Image(
                                        value=scene["thumbnail"],
                                        height=130,
                                        interactive=False,
                                        show_label=False,
                                        sources=[],
                                    )
                                    gr.Markdown(
                                        f"**{scene['name']}** \n{scene['num_images']} imgs",
                                        elem_classes=["scene-caption"],
                                    )
                                    # The default argument binds this iteration's
                                    # scene name (avoids the late-binding closure trap).
                                    scene_img.select(
                                        fn=lambda name=scene["name"]: load_example_scene(name),
                                        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
                                    )
                            else:
                                with gr.Column(scale=1, min_width=140):
                                    pass

    # ── Reconstruction ──
    submit_btn.click(
        fn=clear_fields,
        inputs=[],
        outputs=[reconstruction_output],
    ).then(
        fn=update_log,
        inputs=[],
        outputs=[log_output],
    ).then(
        fn=gradio_demo,
        inputs=[target_dir_output, frame_filter, show_cam, filter_black_bg, filter_white_bg, apply_mask_checkbox, show_mesh],
        outputs=[
            reconstruction_output,
            log_output,
            frame_filter,
            processed_data_state,
            depth_map,
            normal_map,
            measure_image,
            measure_text,
            depth_view_selector,
            normal_view_selector,
            measure_view_selector,
        ],
    ).then(fn=lambda: "False", inputs=[], outputs=[is_example])

    # ── 3D view refresh ──
    # Every trigger passes the complete option set. Previously show_cam.change
    # sent only four inputs and filter_black_bg.change left out show_mesh, so
    # those toggles re-rendered the scene with default filter / mesh values
    # instead of the checkboxes' current state.
    viz_inputs = [
        target_dir_output,
        frame_filter,
        show_cam,
        is_example,
        filter_black_bg,
        filter_white_bg,
        show_mesh,
    ]
    viz_outputs = [reconstruction_output, log_output]
    for viz_trigger in (frame_filter.change, show_cam.change, show_mesh.change):
        viz_trigger(update_visualization, viz_inputs, viz_outputs)

    # Background filters also invalidate the 2D depth / normal / measure views.
    filter_view_inputs = [
        target_dir_output,
        filter_black_bg,
        filter_white_bg,
        processed_data_state,
        depth_view_selector,
        normal_view_selector,
        measure_view_selector,
    ]
    filter_view_outputs = [processed_data_state, depth_map, normal_map, measure_image, measure_points_state]
    for filter_checkbox in (filter_black_bg, filter_white_bg):
        filter_checkbox.change(
            update_visualization,
            viz_inputs,
            viz_outputs,
        ).then(
            update_all_views_on_filter_change,
            filter_view_inputs,
            filter_view_outputs,
        )

    # ── Upload handling ──
    def _has_video(files):
        """Return True when any uploaded entry has a known video extension."""
        video_exts = (".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp")
        return any(
            os.path.splitext(str(f["name"] if isinstance(f, dict) else f))[1].lower() in video_exts
            for f in files
        )

    def update_gallery_on_unified_upload(files, interval):
        """Stage uploaded files and return ``(target_dir, image_paths, log message)``."""
        if not files:
            return None, None, None
        target_dir, image_paths = handle_uploads(files, interval)
        return target_dir, image_paths, "Upload complete. Click **Reconstruct** to begin."

    def show_resample_button(files):
        """Show the Resample button only when the upload contains a video."""
        if not files:
            return gr.update(visible=False)
        return gr.update(visible=_has_video(files))

    def resample_video_with_new_interval(files, new_interval, current_target_dir):
        """Re-extract frames at ``new_interval`` seconds, replacing the old working dir.

        Returns ``(target_dir, image_paths, log message, resample-button update)``.
        """
        if not files:
            return current_target_dir, None, "No files to resample.", gr.update(visible=False)
        if not _has_video(files):
            return current_target_dir, None, "No videos found.", gr.update(visible=False)
        # Drop the previously extracted frames before staging the new ones.
        if current_target_dir and current_target_dir != "None" and os.path.exists(current_target_dir):
            shutil.rmtree(current_target_dir)
        target_dir, image_paths = handle_uploads(files, new_interval)
        return target_dir, image_paths, f"Resampled at {new_interval}s. Click **Reconstruct**.", gr.update(visible=False)

    unified_upload.change(
        fn=update_gallery_on_unified_upload,
        inputs=[unified_upload, s_time_interval],
        outputs=[target_dir_output, image_gallery, log_output],
    ).then(fn=show_resample_button, inputs=[unified_upload], outputs=[resample_btn])
    s_time_interval.change(fn=show_resample_button, inputs=[unified_upload], outputs=[resample_btn])
    resample_btn.click(
        fn=resample_video_with_new_interval,
        inputs=[unified_upload, s_time_interval, target_dir_output],
        outputs=[target_dir_output, image_gallery, log_output, resample_btn],
    )

    # ── Measure / per-view navigation ──
    def _view_index(selection):
        """Parse a ``"View N"`` label into a zero-based index; 0 when unparsable.

        The view dropdowns set ``allow_custom_value=True``, so free text can
        reach the callbacks; falling back to the first view mirrors the
        ``safe_idx`` helper used when the background filters change.
        """
        try:
            return int(selection.split()[1]) - 1
        except (AttributeError, IndexError, ValueError):
            return 0

    measure_image.select(
        fn=measure,
        inputs=[processed_data_state, measure_points_state, measure_view_selector],
        outputs=[measure_image, measure_points_state, measure_text],
    )

    prev_depth_btn.click(
        fn=lambda pd, sel: navigate_depth_view(pd, sel, -1),
        inputs=[processed_data_state, depth_view_selector],
        outputs=[depth_view_selector, depth_map],
    )
    next_depth_btn.click(
        fn=lambda pd, sel: navigate_depth_view(pd, sel, 1),
        inputs=[processed_data_state, depth_view_selector],
        outputs=[depth_view_selector, depth_map],
    )
    depth_view_selector.change(
        fn=lambda pd, sel: update_depth_view(pd, _view_index(sel)) if sel else None,
        inputs=[processed_data_state, depth_view_selector],
        outputs=[depth_map],
    )

    prev_normal_btn.click(
        fn=lambda pd, sel: navigate_normal_view(pd, sel, -1),
        inputs=[processed_data_state, normal_view_selector],
        outputs=[normal_view_selector, normal_map],
    )
    next_normal_btn.click(
        fn=lambda pd, sel: navigate_normal_view(pd, sel, 1),
        inputs=[processed_data_state, normal_view_selector],
        outputs=[normal_view_selector, normal_map],
    )
    normal_view_selector.change(
        fn=lambda pd, sel: update_normal_view(pd, _view_index(sel)) if sel else None,
        inputs=[processed_data_state, normal_view_selector],
        outputs=[normal_map],
    )

    prev_measure_btn.click(
        fn=lambda pd, sel: navigate_measure_view(pd, sel, -1),
        inputs=[processed_data_state, measure_view_selector],
        outputs=[measure_view_selector, measure_image, measure_points_state],
    )
    next_measure_btn.click(
        fn=lambda pd, sel: navigate_measure_view(pd, sel, 1),
        inputs=[processed_data_state, measure_view_selector],
        outputs=[measure_view_selector, measure_image, measure_points_state],
    )
    measure_view_selector.change(
        fn=lambda pd, sel: update_measure_view(pd, _view_index(sel)) if sel else (None, []),
        inputs=[processed_data_state, measure_view_selector],
        outputs=[measure_image, measure_points_state],
    )

demo.queue(max_size=50).launch(css=CUSTOM_CSS, theme=steel_blue_theme, show_error=True, share=True, ssr_mode=False, mcp_server=True)