Spaces:
Running on Zero
Running on Zero
| import gc | |
| import os | |
| import shutil | |
| import sys | |
| import time | |
| import uuid | |
| from datetime import datetime | |
| os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| import torch | |
| from PIL import Image | |
| from pillow_heif import register_heif_opener | |
| import rerun as rr | |
| try: | |
| import rerun.blueprint as rrb | |
| except ImportError: | |
| rrb = None | |
| from gradio_rerun import Rerun | |
| register_heif_opener() | |
| sys.path.append("mapanything/") | |
| from mapanything.utils.geometry import depthmap_to_world_frame, points_to_normals | |
| from mapanything.utils.hf_utils.css_and_html import ( | |
| GRADIO_CSS, | |
| MEASURE_INSTRUCTIONS_HTML, | |
| get_acknowledgements_html, | |
| get_description_html, | |
| get_gradio_theme, | |
| get_header_html, | |
| ) | |
| from mapanything.utils.hf_utils.hf_helpers import initialize_mapanything_model | |
| from mapanything.utils.hf_utils.viz import predictions_to_glb | |
| from mapanything.utils.image import load_images, rgb | |
| from typing import Iterable | |
| from gradio.themes import Soft | |
| from gradio.themes.utils import colors, fonts, sizes | |
# ── Steel-Blue palette ──────────────────────────────────────────────
# Register a custom color ramp on gradio's `colors` module so the theme
# below can reference it as `colors.steel_blue` (c50 = lightest tint,
# c950 = darkest shade; c500 is the classic "steel blue" #4682B4).
colors.steel_blue = colors.Color(
    name="steel_blue",
    c50="#EBF3F8",
    c100="#D3E5F0",
    c200="#A8CCE1",
    c300="#7DB3D2",
    c400="#529AC3",
    c500="#4682B4",
    c600="#3E72A0",
    c700="#36638C",
    c800="#2E5378",
    c900="#264364",
    c950="#1E3450",
)
class SteelBlueTheme(Soft):
    """Gradio Soft-based theme wired to the steel-blue palette registered above.

    Uses gray as the primary hue and steel blue as the secondary (accent)
    hue, with the Outfit / IBM Plex Mono Google fonts.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # Component-level overrides: gradient page/button backgrounds,
        # steel-blue slider accents, and heavier block borders/shadows.
        # Variables like "*secondary_500" resolve against the hues above.
        super().set(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)",
            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)",
            button_secondary_text_color="black",
            button_secondary_text_color_hover="white",
            button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
            button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
            button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
            button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
            slider_color="*secondary_500",
            slider_color_dark="*secondary_600",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
            button_large_padding="11px",
            color_accent_soft="*primary_100",
            block_label_background_fill="*primary_200",
        )
# Singleton theme instance handed to the gradio Blocks app.
steel_blue_theme = SteelBlueTheme()
# Inline 24x24 outline icons (heroicons-style) embedded in the header HTML.
SVG_CUBE = '<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor"><path stroke-linecap="round" stroke-linejoin="round" d="m21 7.5-9-5.25L3 7.5m18 0-9 5.25m9-5.25v9l-9 5.25M3 7.5l9 5.25M3 7.5v9l9 5.25m0-9v9"/></svg>'
SVG_CHIP = '<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor"><path stroke-linecap="round" stroke-linejoin="round" d="M8.25 3v1.5M4.5 8.25H3m18 0h-1.5M4.5 12H3m18 0h-1.5m-15 3.75H3m18 0h-1.5M8.25 19.5V21M12 3v1.5m0 15V21m3.75-18v1.5m0 15V21m-9-1.5h10.5a2.25 2.25 0 0 0 2.25-2.25V6.75a2.25 2.25 0 0 0-2.25-2.25H6.75A2.25 2.25 0 0 0 4.5 6.75v10.5a2.25 2.25 0 0 0 2.25 2.25Z"/></svg>'
def html_header():
    """Return the static HTML for the app's hero header (styled via CSS)."""
    return f"""
<div class="app-header">
    <div class="header-content">
        <div class="header-icon-wrap">{SVG_CUBE}</div>
        <div class="header-text">
            <h1>Map-Anything — v1</h1>
            <div class="header-meta">
                <span class="meta-badge">{SVG_CHIP} facebook/map-anything-v1</span>
                <span class="meta-sep"></span>
                <span class="meta-cap">3D Reconstruction</span>
                <span class="meta-sep"></span>
                <span class="meta-cap">Depth Estimation</span>
                <span class="meta-sep"></span>
                <span class="meta-cap">Normal Maps</span>
                <span class="meta-sep"></span>
                <span class="meta-cap">Measurements</span>
            </div>
        </div>
    </div>
</div>
"""
# Model/loading configuration consumed by initialize_mapanything_model():
# checkpoint location on the HF hub, hydra config overrides, AMP settings,
# and the DINOv2-style preprocessing (patch 14, 518px resolution).
high_level_config = {
    "path": "configs/train.yaml",
    "hf_model_name": "facebook/map-anything-v1",
    "model_str": "mapanything",
    "config_overrides": [
        "machine=aws",
        "model=mapanything",
        "model/task=images_only",
        "model.encoder.uses_torch_hub=false",
    ],
    "checkpoint_name": "model.safetensors",
    "config_name": "config.json",
    "trained_with_amp": True,
    "trained_with_amp_dtype": "bf16",
    "data_norm_type": "dinov2",
    "patch_size": 14,
    "resolution": 518,
}
# Lazily-initialized global model handle; built on the first run_model() call.
model = None
# Scratch directory next to this file for temporary artifacts.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)
# App-wide stylesheet appended to the project's shared GRADIO_CSS: gradient
# hero header, steel-blue accents, two-panel layout shell, tab/accordion
# polish, scrollbars, and dark-mode + responsive overrides.
# NOTE(review): the "ββ" runs inside the CSS comments look like mojibake for
# box-drawing dashes; left untouched since they live inside a runtime string.
CUSTOM_CSS = (GRADIO_CSS or "") + r"""
@import url('https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;500;600;700;800&family=IBM+Plex+Mono:wght@400;500;600&display=swap');
body, .gradio-container { font-family: 'Outfit', sans-serif !important; }
footer { display: none !important; }
/* ββ App Header ββ */
.app-header {
background: linear-gradient(135deg, #1E3450 0%, #264364 30%, #3E72A0 70%, #4682B4 100%);
border-radius: 16px;
padding: 32px 40px;
margin-bottom: 24px;
position: relative;
overflow: hidden;
box-shadow: 0 8px 32px rgba(30, 52, 80, 0.35);
}
.app-header::before {
content: '';
position: absolute;
top: -50%;
right: -20%;
width: 400px;
height: 400px;
background: radial-gradient(circle, rgba(255, 255, 255, 0.06) 0%, transparent 70%);
border-radius: 50%;
}
.app-header::after {
content: '';
position: absolute;
bottom: -30%;
left: -10%;
width: 300px;
height: 300px;
background: radial-gradient(circle, rgba(70, 130, 180, 0.15) 0%, transparent 70%);
border-radius: 50%;
}
.header-content {
display: flex;
align-items: center;
gap: 24px;
position: relative;
z-index: 1;
}
.header-icon-wrap {
width: 64px;
height: 64px;
background: rgba(255, 255, 255, 0.12);
border-radius: 16px;
display: flex;
align-items: center;
justify-content: center;
flex-shrink: 0;
backdrop-filter: blur(8px);
border: 1px solid rgba(255, 255, 255, 0.15);
}
/* ββ Force header SVGs white in ALL modes ββ */
.header-icon-wrap svg,
.app-header svg {
width: 36px;
height: 36px;
color: #ffffff !important;
stroke: #ffffff !important;
}
.meta-badge svg {
width: 14px !important;
height: 14px !important;
color: #ffffff !important;
stroke: #ffffff !important;
}
.header-text h1 {
font-family: 'Outfit', sans-serif;
font-size: 2rem;
font-weight: 700;
color: #fff !important;
margin: 0 0 8px 0;
letter-spacing: -0.02em;
line-height: 1.2;
}
.header-meta {
display: flex;
align-items: center;
gap: 12px;
flex-wrap: wrap;
}
.meta-badge {
display: inline-flex;
align-items: center;
gap: 6px;
background: rgba(255, 255, 255, 0.12);
color: rgba(255, 255, 255, 0.9) !important;
padding: 4px 12px;
border-radius: 20px;
font-family: 'IBM Plex Mono', monospace;
font-size: 0.8rem;
font-weight: 500;
border: 1px solid rgba(255, 255, 255, 0.1);
backdrop-filter: blur(4px);
}
.meta-sep {
width: 4px;
height: 4px;
background: rgba(255, 255, 255, 0.35);
border-radius: 50%;
flex-shrink: 0;
}
.meta-cap {
color: rgba(255, 255, 255, 0.65) !important;
font-size: 0.85rem;
font-weight: 400;
}
/* ββ Page shell ββ */
#app-shell {
max-width: 1400px;
margin: 0 auto;
padding: 0 16px 40px;
}
/* ββ Two-panel layout ββ */
#left-panel { min-width: 320px; max-width: 380px; }
#right-panel { flex: 1; min-width: 0; }
/* ββ Section labels ββ */
.section-label {
font-size: 0.7rem !important;
font-weight: 600 !important;
letter-spacing: 0.08em !important;
text-transform: uppercase !important;
opacity: 0.5 !important;
margin-bottom: 6px !important;
margin-top: 16px !important;
display: block !important;
}
/* ββ Upload zone ββ */
#upload-zone .wrap {
border-radius: 10px !important;
min-height: 110px !important;
}
/* ββ Gallery ββ */
#preview-gallery { border-radius: 10px; overflow: hidden; }
/* ββ Action buttons ββ */
#btn-reconstruct {
width: 100% !important;
font-size: 0.95rem !important;
font-weight: 600 !important;
padding: 12px !important;
border-radius: 8px !important;
}
/* ββ Buttons ββ */
.primary {
border-radius: 10px !important;
font-weight: 600 !important;
letter-spacing: 0.02em !important;
transition: all 0.25s ease !important;
font-family: 'Outfit', sans-serif !important;
}
.primary:hover {
transform: translateY(-2px) !important;
box-shadow: 0 6px 20px rgba(70, 130, 180, 0.35) !important;
}
.primary:active { transform: translateY(0) !important; }
/* ββ Log strip ββ */
#log-strip {
font-size: 0.82rem !important;
padding: 8px 12px !important;
border-radius: 6px !important;
border: 1px solid var(--border-color-primary) !important;
background: var(--background-fill-secondary) !important;
min-height: 36px !important;
}
/* ββ Viewer tabs ββ */
#viewer-tabs .tab-nav button {
font-size: 0.8rem !important;
font-weight: 500 !important;
padding: 6px 14px !important;
}
#viewer-tabs > .tabitem { padding: 0 !important; }
/* ββ Tab transitions ββ */
.gradio-tabitem { animation: tabFadeIn 0.35s ease-out; }
@keyframes tabFadeIn {
from { opacity: 0; transform: translateY(6px); }
to { opacity: 1; transform: translateY(0); }
}
/* ββ Navigation rows inside tabs ββ */
.nav-row { align-items: center !important; gap: 6px !important; margin-bottom: 8px !important; }
.nav-row button { min-width: 80px !important; }
/* ββ Options panel ββ */
#options-panel {
border: 1px solid var(--border-color-primary);
border-radius: 10px;
padding: 16px;
margin-top: 12px;
}
#options-panel .gr-markdown h3 {
font-size: 0.72rem !important;
font-weight: 600 !important;
letter-spacing: 0.07em !important;
text-transform: uppercase !important;
opacity: 0.5 !important;
margin: 14px 0 6px !important;
}
#options-panel .gr-markdown h3:first-child { margin-top: 0 !important; }
/* ββ Frame filter ββ */
#frame-filter { margin-top: 12px; }
/* ββ Examples section ββ */
#examples-section {
margin-top: 36px;
padding-top: 24px;
border-top: 1px solid var(--border-color-primary);
}
#examples-section h2 {
font-size: 1.1rem !important;
font-weight: 600 !important;
margin-bottom: 4px !important;
}
#examples-section .scene-caption {
font-size: 0.75rem !important;
text-align: center !important;
opacity: 0.65 !important;
margin-top: 4px !important;
}
.scene-thumb img { border-radius: 8px; transition: opacity .15s; }
.scene-thumb img:hover { opacity: .85; }
/* ββ Measure note ββ */
.measure-note {
font-size: 0.78rem !important;
opacity: 0.6 !important;
margin-top: 6px !important;
}
#col-container {
margin: 0 auto;
max-width: 960px;
}
/* ββ Accordion ββ */
.gradio-accordion {
border-radius: 10px !important;
border: 1px solid rgba(70, 130, 180, 0.2) !important;
}
.gradio-accordion > .label-wrap { border-radius: 10px !important; }
/* ββ Labels ββ */
label {
font-weight: 600 !important;
font-family: 'Outfit', sans-serif !important;
}
/* ββ Slider ββ */
.gradio-slider input[type="range"] { accent-color: #4682B4 !important; }
/* ββ Scrollbar ββ */
::-webkit-scrollbar { width: 8px; height: 8px; }
::-webkit-scrollbar-track { background: rgba(70, 130, 180, 0.06); border-radius: 4px; }
::-webkit-scrollbar-thumb {
background: linear-gradient(135deg, #4682B4, #3E72A0);
border-radius: 4px;
}
::-webkit-scrollbar-thumb:hover {
background: linear-gradient(135deg, #3E72A0, #2E5378);
}
/* ββ Dark-mode overrides for header (keep text/SVG white) ββ */
@media (prefers-color-scheme: dark) {
.app-header {
background: linear-gradient(135deg, #1E3450 0%, #264364 30%, #3E72A0 70%, #4682B4 100%);
}
.header-text h1 { color: #fff !important; }
.header-icon-wrap svg,
.app-header svg,
.meta-badge svg {
color: #ffffff !important;
stroke: #ffffff !important;
}
.meta-badge { color: rgba(255, 255, 255, 0.9) !important; }
.meta-cap { color: rgba(255, 255, 255, 0.65) !important; }
}
/* Also handle Gradio's own .dark class */
.dark .header-text h1 { color: #fff !important; }
.dark .header-icon-wrap svg,
.dark .app-header svg,
.dark .meta-badge svg {
color: #ffffff !important;
stroke: #ffffff !important;
}
.dark .meta-badge { color: rgba(255, 255, 255, 0.9) !important; }
.dark .meta-cap { color: rgba(255, 255, 255, 0.65) !important; }
/* ββ Responsive ββ */
@media (max-width: 768px) {
.app-header { padding: 20px 24px; }
.header-text h1 { font-size: 1.5rem; }
.header-content {
flex-direction: column;
align-items: flex-start;
gap: 16px;
}
.header-meta { gap: 8px; }
}
"""
def predictions_to_rrd(predictions, glbfile, target_dir, frame_filter="All", show_cam=True):
    """Serialize the reconstruction into a Rerun ``.rrd`` file for the viewer.

    Logs the exported GLB scene, optional per-view camera frusta/images, and a
    sub-sampled colored point cloud into a fresh recording, then saves it.

    Args:
        predictions: dict of numpy arrays with optional keys "extrinsic",
            "intrinsic", "world_points", "images", "final_mask".
        glbfile: path to the exported GLB scene embedded as an Asset3D.
        target_dir: directory the ``.rrd`` file is written into.
        frame_filter: accepted for call-site symmetry; not used in this body.
        show_cam: when True, log camera transforms, pinholes, and RGB frames.

    Returns:
        Path of the written ``.rrd`` file.
    """
    run_id = str(uuid.uuid4())
    timestamp = datetime.now().strftime("%Y-%m-%dT%H%M%S")
    rrd_path = os.path.join(target_dir, f"mapanything_{timestamp}.rrd")
    # Prefer the modern recording APIs; fall back to the legacy global init,
    # in which case the module itself acts as the recording handle.
    rec = None
    if hasattr(rr, "new_recording"):
        rec = rr.new_recording(application_id="MapAnything-3D-Viewer", recording_id=run_id)
    elif hasattr(rr, "RecordingStream"):
        rec = rr.RecordingStream(application_id="MapAnything-3D-Viewer", recording_id=run_id)
    else:
        rr.init("MapAnything-3D-Viewer", recording_id=run_id, spawn=False)
        rec = rr
    # Reset the world entity and fix the coordinate convention before logging.
    rec.log("world", rr.Clear(recursive=True), static=True)
    rec.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, static=True)
    # Best-effort RGB axis gizmo; cosmetic only, so failures are swallowed.
    try:
        rec.log("world/axes/x", rr.Arrows3D(vectors=[[0.5, 0, 0]], colors=[[255, 0, 0]]), static=True)
        rec.log("world/axes/y", rr.Arrows3D(vectors=[[0, 0.5, 0]], colors=[[0, 255, 0]]), static=True)
        rec.log("world/axes/z", rr.Arrows3D(vectors=[[0, 0, 0.5]], colors=[[0, 0, 255]]), static=True)
    except Exception:
        pass
    rec.log("world/model", rr.Asset3D(path=glbfile), static=True)
    # Optionally log one camera entity per view: pose, pinhole intrinsics,
    # and (if available) the captured RGB image under the pinhole.
    if show_cam and "extrinsic" in predictions and "intrinsic" in predictions:
        try:
            extrinsics = predictions["extrinsic"]
            intrinsics = predictions["intrinsic"]
            for i, (ext, intr) in enumerate(zip(extrinsics, intrinsics)):
                translation = ext[:3, 3]
                rotation_mat = ext[:3, :3]
                rec.log(
                    f"world/cameras/cam_{i:03d}",
                    rr.Transform3D(translation=translation, mat3x3=rotation_mat),
                    static=True,
                )
                fx, fy = intr[0, 0], intr[1, 1]
                cx, cy = intr[0, 2], intr[1, 2]
                # Fall back to the model's native 518x518 resolution when the
                # actual image for this view is unavailable.
                if "images" in predictions and i < len(predictions["images"]):
                    h, w = predictions["images"][i].shape[:2]
                else:
                    h, w = 518, 518
                rec.log(
                    f"world/cameras/cam_{i:03d}/image",
                    rr.Pinhole(focal_length=[fx, fy], principal_point=[cx, cy], width=w, height=h),
                    static=True,
                )
                if "images" in predictions and i < len(predictions["images"]):
                    img = predictions["images"][i]
                    if img.dtype != np.uint8:
                        img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
                    rec.log(f"world/cameras/cam_{i:03d}/image/rgb", rr.Image(img), static=True)
        except Exception as e:
            print(f"Camera logging failed (non-fatal): {e}")
    # Log a merged, mask-filtered point cloud, randomly sub-sampled to keep the
    # recording (and browser viewer) responsive.
    if "world_points" in predictions and "images" in predictions:
        try:
            world_points = predictions["world_points"]
            images = predictions["images"]
            final_mask = predictions.get("final_mask")
            all_points, all_colors = [], []
            for i in range(len(world_points)):
                pts = world_points[i]
                img = images[i]
                mask = final_mask[i].astype(bool) if final_mask is not None else np.ones(pts.shape[:2], dtype=bool)
                pts_flat = pts[mask]
                img_flat = img[mask]
                if img_flat.dtype != np.uint8:
                    img_flat = (np.clip(img_flat, 0, 1) * 255).astype(np.uint8)
                all_points.append(pts_flat)
                all_colors.append(img_flat)
            if all_points:
                all_points = np.concatenate(all_points, axis=0)
                all_colors = np.concatenate(all_colors, axis=0)
                max_pts = 500_000
                if len(all_points) > max_pts:
                    idx = np.random.choice(len(all_points), max_pts, replace=False)
                    all_points = all_points[idx]
                    all_colors = all_colors[idx]
                rec.log("world/point_cloud", rr.Points3D(positions=all_points, colors=all_colors, radii=0.002), static=True)
        except Exception as e:
            print(f"Point cloud logging failed (non-fatal): {e}")
    # Optional blueprint: a single collapsed 3D view (only when the blueprint
    # API was importable at module load).
    if rrb is not None:
        try:
            blueprint = rrb.Blueprint(
                rrb.Spatial3DView(origin="/world", name="3D View"),
                collapse_panels=True,
            )
            rec.send_blueprint(blueprint)
        except Exception as e:
            print(f"Blueprint creation failed (non-fatal): {e}")
    rec.save(rrd_path)
    return rrd_path
def run_model(target_dir, apply_mask=True, mask_edges=True, filter_black_bg=False, filter_white_bg=False):
    """Run MapAnything inference on the images under ``target_dir/images``.

    Args:
        target_dir: directory containing an ``images`` subfolder.
        apply_mask: forward the model's validity mask to inference.
        mask_edges: forward edge masking to inference.
        filter_black_bg: drop near-black background pixels in the viz data.
        filter_white_bg: drop near-white background pixels in the viz data.

    Returns:
        Tuple ``(predictions, processed_data)`` where ``predictions`` holds
        stacked numpy arrays (extrinsic/intrinsic/world_points/depth/images/
        final_mask) and ``processed_data`` is the per-view dict used by the
        visualization tabs.

    Raises:
        ValueError: if no images are found in the target directory.
    """
    global model
    import torch  # local import mirrors the module-level one; kept for safety

    print(f"Processing images from {target_dir}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Build the model once and reuse it across requests; just move it to the
    # current device on subsequent calls.
    if model is None:
        model = initialize_mapanything_model(high_level_config, device)
    else:
        model = model.to(device)
    model.eval()
    print("Loading images...")
    image_folder_path = os.path.join(target_dir, "images")
    views = load_images(image_folder_path)
    print(f"Loaded {len(views)} images")
    if len(views) == 0:
        raise ValueError("No images found. Check your upload.")
    print("Running inference...")
    # BUGFIX: forward the caller's mask_edges flag (it was hard-coded to True,
    # silently ignoring the parameter).
    outputs = model.infer(views, apply_mask=apply_mask, mask_edges=mask_edges, memory_efficient_inference=False)
    predictions = {}
    extrinsic_list, intrinsic_list, world_points_list = [], [], []
    depth_maps_list, images_list, final_mask_list = [], [], []
    for pred in outputs:
        # Per-view tensors: z-depth, intrinsics, and camera-to-world pose.
        depthmap_torch = pred["depth_z"][0].squeeze(-1)
        intrinsics_torch = pred["intrinsics"][0]
        camera_pose_torch = pred["camera_poses"][0]
        # Back-project depth to world-frame 3D points (also yields validity).
        pts3d_computed, valid_mask = depthmap_to_world_frame(depthmap_torch, intrinsics_torch, camera_pose_torch)
        mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool) if "mask" in pred else np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool)
        mask = mask & valid_mask.cpu().numpy()
        image = pred["img_no_norm"][0].cpu().numpy()
        extrinsic_list.append(camera_pose_torch.cpu().numpy())
        intrinsic_list.append(intrinsics_torch.cpu().numpy())
        world_points_list.append(pts3d_computed.cpu().numpy())
        depth_maps_list.append(depthmap_torch.cpu().numpy())
        images_list.append(image)
        final_mask_list.append(mask)
    predictions["extrinsic"] = np.stack(extrinsic_list, axis=0)
    predictions["intrinsic"] = np.stack(intrinsic_list, axis=0)
    predictions["world_points"] = np.stack(world_points_list, axis=0)
    depth_maps = np.stack(depth_maps_list, axis=0)
    # Keep a trailing channel axis so downstream code can squeeze uniformly.
    if len(depth_maps.shape) == 3:
        depth_maps = depth_maps[..., np.newaxis]
    predictions["depth"] = depth_maps
    predictions["images"] = np.stack(images_list, axis=0)
    predictions["final_mask"] = np.stack(final_mask_list, axis=0)
    processed_data = process_predictions_for_visualization(predictions, views, high_level_config, filter_black_bg, filter_white_bg)
    torch.cuda.empty_cache()
    return predictions, processed_data
def update_view_selectors(processed_data):
    """Rebuild the three per-tab view dropdowns after a reconstruction.

    Returns one ``gr.Dropdown`` update per tab (depth, normal, measure), all
    sharing the same "View N" choice list and defaulting to the first view.
    """
    if processed_data:
        choices = [f"View {idx + 1}" for idx in range(len(processed_data))]
    else:
        choices = ["View 1"]
    return tuple(gr.Dropdown(choices=choices, value=choices[0]) for _ in range(3))
def get_view_data_by_index(processed_data, view_index):
    """Return the per-view entry at ``view_index``, clamping out-of-range indices.

    ``processed_data`` is an ordered mapping of view keys to view dicts;
    returns ``None`` when it is empty or falsy.
    """
    if not processed_data:
        return None
    keys = list(processed_data)
    clamped = min(max(view_index, 0), len(keys) - 1)
    return processed_data[keys[clamped]]
def update_depth_view(processed_data, view_index):
    """Return the colorized depth image for one view, or None if unavailable."""
    view = get_view_data_by_index(processed_data, view_index)
    if view is None:
        return None
    if view["depth"] is None:
        return None
    return colorize_depth(view["depth"], mask=view.get("mask"))
def update_normal_view(processed_data, view_index):
    """Return the colorized normal map for one view, or None if unavailable."""
    view = get_view_data_by_index(processed_data, view_index)
    if view is None:
        return None
    if view["normal"] is None:
        return None
    return colorize_normal(view["normal"], mask=view.get("mask"))
def update_measure_view(processed_data, view_index):
    """Build the measurement-tab image for one view and reset its click points.

    Invalid (masked-out) pixels are tinted pink so the user can see where
    measurements are disallowed. Returns ``(image, [])`` or ``(None, [])``.
    """
    view = get_view_data_by_index(processed_data, view_index)
    if view is None:
        return None, []
    img = view["image"].copy()
    if img.dtype != np.uint8:
        # Float images in [0, 1] get rescaled; anything else is just cast.
        img = (img * 255).astype(np.uint8) if img.max() <= 1.0 else img.astype(np.uint8)
    mask = view["mask"]
    if mask is not None:
        invalid = ~mask
        if invalid.any():
            tint = np.array([255, 220, 220], dtype=np.uint8)
            blend = 0.5
            # Alpha-blend the tint onto invalid pixels, channel by channel.
            for channel in range(3):
                mixed = (1 - blend) * img[:, :, channel] + blend * tint[channel]
                img[:, :, channel] = np.where(invalid, mixed, img[:, :, channel]).astype(np.uint8)
    return img, []
def navigate_depth_view(processed_data, current_selector_value, direction):
    """Step the depth tab by ``direction`` views (wraps around the ends)."""
    if not processed_data:
        return "View 1", None
    try:
        # Selector values look like "View 3" (1-based).
        idx = int(current_selector_value.split()[1]) - 1
    except Exception:
        idx = 0
    idx = (idx + direction) % len(processed_data)
    return f"View {idx + 1}", update_depth_view(processed_data, idx)
def navigate_normal_view(processed_data, current_selector_value, direction):
    """Step the normal-map tab by ``direction`` views (wraps around the ends)."""
    if not processed_data:
        return "View 1", None
    try:
        # Selector values look like "View 3" (1-based).
        idx = int(current_selector_value.split()[1]) - 1
    except Exception:
        idx = 0
    idx = (idx + direction) % len(processed_data)
    return f"View {idx + 1}", update_normal_view(processed_data, idx)
def navigate_measure_view(processed_data, current_selector_value, direction):
    """Step the measure tab by ``direction`` views, clearing picked points."""
    if not processed_data:
        return "View 1", None, []
    try:
        # Selector values look like "View 3" (1-based).
        idx = int(current_selector_value.split()[1]) - 1
    except Exception:
        idx = 0
    idx = (idx + direction) % len(processed_data)
    image, points = update_measure_view(processed_data, idx)
    return f"View {idx + 1}", image, points
def populate_visualization_tabs(processed_data):
    """Produce the initial (first-view) content for the three viz tabs.

    Returns ``(depth_img, normal_img, measure_img, measure_points)``; all
    Nones and an empty point list when there is no data.
    """
    if not processed_data:
        return None, None, None, []
    depth_img = update_depth_view(processed_data, 0)
    normal_img = update_normal_view(processed_data, 0)
    measure_img, _ = update_measure_view(processed_data, 0)
    return depth_img, normal_img, measure_img, []
def handle_uploads(unified_upload, s_time_interval=1.0):
    """Stage uploaded media into a fresh ``input_images_*/images`` directory.

    Videos are sampled roughly every ``s_time_interval`` seconds into PNG
    frames; HEIC/HEIF images are converted to JPEG; everything else is copied
    as-is.

    Args:
        unified_upload: list of uploaded files (paths or dicts with "name").
        s_time_interval: seconds between extracted video frames.

    Returns:
        Tuple ``(target_dir, sorted_image_paths)``.
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()
    # Timestamped directory so concurrent sessions don't collide.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    target_dir_images = os.path.join(target_dir, "images")
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir_images)
    image_paths = []
    video_extensions = [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"]
    if unified_upload is not None:
        for file_data in unified_upload:
            file_path = file_data["name"] if isinstance(file_data, dict) and "name" in file_data else str(file_data)
            file_ext = os.path.splitext(file_path)[1].lower()
            if file_ext in video_extensions:
                vs = cv2.VideoCapture(file_path)
                fps = vs.get(cv2.CAP_PROP_FPS)
                # BUGFIX: fps can be 0 (unknown) or fps * interval < 1, which
                # previously made frame_interval 0 and crashed the modulo
                # below with ZeroDivisionError. Clamp to at least 1.
                frame_interval = max(1, int(fps * s_time_interval))
                count, video_frame_num = 0, 0
                while True:
                    gotit, frame = vs.read()
                    if not gotit:
                        break
                    count += 1
                    if count % frame_interval == 0:
                        base_name = os.path.splitext(os.path.basename(file_path))[0]
                        image_path = os.path.join(target_dir_images, f"{base_name}_{video_frame_num:06}.png")
                        cv2.imwrite(image_path, frame)
                        image_paths.append(image_path)
                        video_frame_num += 1
                vs.release()
                print(f"Extracted {video_frame_num} frames from: {os.path.basename(file_path)}")
            elif file_ext in [".heic", ".heif"]:
                # Convert HEIC/HEIF to JPEG (requires pillow_heif, registered
                # at module import). On failure, fall back to a raw copy.
                try:
                    with Image.open(file_path) as img:
                        if img.mode not in ("RGB", "L"):
                            img = img.convert("RGB")
                        base_name = os.path.splitext(os.path.basename(file_path))[0]
                        dst_path = os.path.join(target_dir_images, f"{base_name}.jpg")
                        img.save(dst_path, "JPEG", quality=95)
                        image_paths.append(dst_path)
                except Exception as e:
                    print(f"Error converting HEIC {file_path}: {e}")
                    dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
                    shutil.copy(file_path, dst_path)
                    image_paths.append(dst_path)
            else:
                dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
                shutil.copy(file_path, dst_path)
                image_paths.append(dst_path)
    image_paths = sorted(image_paths)
    print(f"Files processed to {target_dir_images}; took {time.time() - start_time:.3f}s")
    return target_dir, image_paths
def gradio_demo(target_dir, frame_filter="All", show_cam=True, filter_black_bg=False, filter_white_bg=False, apply_mask=True, show_mesh=True):
    """Run the full reconstruction pipeline and produce all UI outputs.

    Args:
        target_dir: staging directory created by handle_uploads().
        frame_filter: "All" or a "<idx>: <filename>" entry to restrict the GLB.
        show_cam: include camera frusta in the exported scene.
        filter_black_bg / filter_white_bg: drop near-black / near-white pixels.
        apply_mask: forward the model's validity mask to inference.
        show_mesh: export the GLB as a mesh rather than a point cloud.

    Returns:
        An 11-tuple: (rrd_path, log message, frame-filter Dropdown update,
        processed_data, depth image, normal image, measure image, "",
        depth selector, normal selector, measure selector).
    """
    if not os.path.isdir(target_dir) or target_dir == "None":
        # BUGFIX: the early return previously yielded only 4 values while the
        # success path yields 11, which breaks gradio's output wiring.
        return (
            None, "No valid target directory found. Please upload first.",
            None, None, None, None, None, "", None, None, None,
        )
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()
    target_dir_images = os.path.join(target_dir, "images")
    all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
    # BUGFIX: label dropdown entries with their actual filenames; the old code
    # formatted the literal "(unknown)" and left `filename` unused.
    all_files_labeled = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
    frame_filter_choices = ["All"] + all_files_labeled
    print("Running MapAnything model...")
    with torch.no_grad():
        # BUGFIX: forward the background-filter flags, which were accepted by
        # this function but silently dropped before reaching run_model().
        predictions, processed_data = run_model(
            target_dir,
            apply_mask=apply_mask,
            filter_black_bg=filter_black_bg,
            filter_white_bg=filter_white_bg,
        )
    # Persist raw predictions alongside the scene for debugging/reuse.
    np.savez(os.path.join(target_dir, "predictions.npz"), **predictions)
    if frame_filter is None:
        frame_filter = "All"
    # Encode the options into the GLB filename so option changes re-export.
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
    )
    glbscene = predictions_to_glb(
        predictions,
        filter_by_frames=frame_filter,
        show_cam=show_cam,
        mask_black_bg=filter_black_bg,
        mask_white_bg=filter_white_bg,
        as_mesh=show_mesh,
    )
    glbscene.export(file_obj=glbfile)
    rrd_path = predictions_to_rrd(predictions, glbfile, target_dir, frame_filter, show_cam)
    # Free the large prediction arrays before returning to the UI thread.
    del predictions
    gc.collect()
    torch.cuda.empty_cache()
    print(f"Total time: {time.time() - start_time:.2f}s")
    log_msg = f"β Reconstruction complete β {len(all_files)} frames processed."
    depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(processed_data)
    depth_selector, normal_selector, measure_selector = update_view_selectors(processed_data)
    return (
        rrd_path, log_msg,
        gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
        processed_data, depth_vis, normal_vis, measure_img, "",
        depth_selector, normal_selector, measure_selector,
    )
def colorize_depth(depth_map, mask=None):
    """Map a depth image to an RGB visualization via the reversed turbo colormap.

    Depth is robustly normalized between its 5th and 95th percentiles over the
    valid pixels; invalid pixels (depth <= 0, or excluded by ``mask``) are
    rendered white.

    Args:
        depth_map: 2D depth array, or None.
        mask: optional boolean validity mask with the same shape.

    Returns:
        uint8 HxWx3 RGB array, or None when ``depth_map`` is None.
    """
    if depth_map is None:
        return None
    depth_normalized = depth_map.copy()
    valid_mask = depth_normalized > 0
    if mask is not None:
        valid_mask = valid_mask & mask
    if valid_mask.sum() > 0:
        valid_depths = depth_normalized[valid_mask]
        p5, p95 = np.percentile(valid_depths, 5), np.percentile(valid_depths, 95)
        # BUGFIX: guard against p95 == p5 (near-constant depth), which
        # previously divided by zero and fed NaNs into the colormap.
        scale = (p95 - p5) if p95 > p5 else 1.0
        depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / scale
    import matplotlib.pyplot as plt  # local import: only needed for rendering
    colored = (plt.cm.turbo_r(depth_normalized)[:, :, :3] * 255).astype(np.uint8)
    colored[~valid_mask] = [255, 255, 255]
    return colored
def colorize_normal(normal_map, mask=None):
    """Convert unit normals in [-1, 1] to a uint8 RGB visualization.

    Pixels excluded by ``mask`` are zeroed before scaling, so they render as
    mid-gray (127). Returns None when ``normal_map`` is None.
    """
    if normal_map is None:
        return None
    vis = normal_map.copy()
    if mask is not None:
        vis[~mask] = [0, 0, 0]
    scaled = (vis + 1.0) / 2.0 * 255
    return scaled.astype(np.uint8)
def process_predictions_for_visualization(predictions, views, high_level_config, filter_black_bg=False, filter_white_bg=False):
    """Assemble the per-view dict consumed by the depth/normal/measure tabs.

    For each view: de-normalized RGB image, world-frame 3D points, squeezed
    depth, surface normals estimated from the points, and the validity mask
    (optionally tightened by the black/white background filters).
    """
    processed_data = {}
    for view_idx, view in enumerate(views):
        image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
        pts3d = predictions["world_points"][view_idx]
        mask = predictions["final_mask"][view_idx].copy()
        # Both background filters compare colors on a 0-255 scale.
        if filter_black_bg or filter_white_bg:
            frame = image[0]
            colors255 = frame * 255 if frame.max() <= 1.0 else frame
        if filter_black_bg:
            mask = mask & (colors255.sum(axis=2) >= 16)
        if filter_white_bg:
            near_white = (
                (colors255[:, :, 0] > 240)
                & (colors255[:, :, 1] > 240)
                & (colors255[:, :, 2] > 240)
            )
            mask = mask & ~near_white
        normals, _ = points_to_normals(pts3d, mask=mask)
        processed_data[view_idx] = {
            "image": image[0],
            "points3d": pts3d,
            "depth": predictions["depth"][view_idx].squeeze(),
            "normal": normals,
            "mask": mask,
        }
    return processed_data
def measure(processed_data, measure_points, current_view_selector, event: gr.SelectData):
    """Handle a click on the Measure tab image.

    Records up to two clicked points, draws them on the current view image,
    and reports per-point depth plus the 3D distance once two points exist.

    Args:
        processed_data: per-view dict from process_predictions_for_visualization.
        measure_points: list of (x, y) pixel points clicked so far.
        current_view_selector: dropdown string such as "View 2".
        event: Gradio select event carrying the clicked pixel index.

    Returns:
        (annotated image or None, updated measure points, markdown status).
        After the second point the point list is cleared for the next pair.
    """
    try:
        if not processed_data:
            return None, [], "No data available"
        # Parse "View N" -> zero-based index, clamped to the available views.
        try:
            current_view_index = int(current_view_selector.split()[1]) - 1
        except Exception:
            current_view_index = 0
        current_view_index = max(0, min(current_view_index, len(processed_data) - 1))
        current_view = processed_data[list(processed_data.keys())[current_view_index]]
        if current_view is None:
            return None, [], "No view data available"
        point2d = event.index[0], event.index[1]
        # Reject clicks on masked-out (no-depth) pixels.
        mask = current_view["mask"]
        if (
            mask is not None
            and 0 <= point2d[1] < mask.shape[0]
            and 0 <= point2d[0] < mask.shape[1]
        ):
            if not mask[point2d[1], point2d[0]]:
                masked_image, _ = update_measure_view(processed_data, current_view_index)
                return masked_image, measure_points, '<span style="color: red; font-weight: bold;">Cannot measure on masked areas</span>'
        measure_points.append(point2d)
        image, _ = update_measure_view(processed_data, current_view_index)
        if image is None:
            return None, [], "No image available"
        image = image.copy()
        if image.dtype != np.uint8:
            # Normalize to uint8 so OpenCV drawing works on any input range.
            image = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8)
        points3d = current_view["points3d"]
        # Draw a circle at every recorded point that lies inside the image.
        for p in measure_points:
            if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
                image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2)
        # Report depth (preferred) or fall back to the 3D Z coordinate.
        depth_text = ""
        for i, p in enumerate(measure_points):
            if (
                current_view["depth"] is not None
                and 0 <= p[1] < current_view["depth"].shape[0]
                and 0 <= p[0] < current_view["depth"].shape[1]
            ):
                depth_text += f"- **P{i + 1} depth: {current_view['depth'][p[1], p[0]]:.2f}m**\n"
            elif (
                points3d is not None
                and 0 <= p[1] < points3d.shape[0]
                and 0 <= p[0] < points3d.shape[1]
            ):
                depth_text += f"- **P{i + 1} Z-coord: {points3d[p[1], p[0], 2]:.2f}m**\n"
        if len(measure_points) == 2:
            point1, point2 = measure_points
            # Fix: the original wrapped this condition in a degenerate
            # all(... for _ in [1]) generator; a plain boolean is equivalent.
            if (
                0 <= point1[0] < image.shape[1]
                and 0 <= point1[1] < image.shape[0]
                and 0 <= point2[0] < image.shape[1]
                and 0 <= point2[1] < image.shape[0]
            ):
                image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2)
            distance_text = "- **Distance: Unable to compute**"
            if points3d is not None and all(
                0 <= p[1] < points3d.shape[0] and 0 <= p[0] < points3d.shape[1]
                for p in [point1, point2]
            ):
                try:
                    distance = np.linalg.norm(
                        points3d[point1[1], point1[0]] - points3d[point2[1], point2[0]]
                    )
                    distance_text = f"- **Distance: {distance:.2f}m**"
                except Exception as e:
                    distance_text = f"- **Distance error: {e}**"
            # Two points complete a measurement; clear the stored points
            # (returns unified to tuples; the original mixed lists/tuples).
            return image, [], depth_text + distance_text
        return image, measure_points, depth_text
    except Exception as e:
        print(f"Measure error: {e}")
        return None, [], f"Error: {e}"
def clear_fields():
    """Reset the 3D viewer output before a new reconstruction run."""
    return None
def update_log():
    """Return the status text shown while reconstruction is running."""
    return "β³ Loading and reconstructingβ¦"
def update_visualization(
    target_dir, frame_filter, show_cam, is_example,
    filter_black_bg=False, filter_white_bg=False, show_mesh=True,
):
    """Rebuild the Rerun 3D visualization from cached predictions.

    Exports a GLB scene (cached per option combination so toggling is cheap)
    and converts it to an RRD file for the viewer.

    Args:
        target_dir: working directory holding predictions.npz.
        frame_filter: "All" or a specific frame selection string.
        show_cam: whether to render camera frustums.
        is_example: "True" when an example scene is merely loaded (no run).
        filter_black_bg: mask near-black background pixels.
        filter_white_bg: mask near-white background pixels.
        show_mesh: render as mesh instead of point cloud.

    Returns:
        (rrd path or gr.update(), status message string).
    """
    # Nothing to visualize for unprocessed examples or missing runs.
    if is_example == "True":
        return gr.update(), "No reconstruction available. Please click Reconstruct first."
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return gr.update(), "No reconstruction available. Please upload first."
    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return gr.update(), "No reconstruction found. Please run Reconstruct first."
    loaded = np.load(predictions_path, allow_pickle=True)
    predictions = {key: loaded[key] for key in loaded.keys()}
    # Encode every display option into the cache file name.
    safe_filter = frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{safe_filter}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
    )
    if not os.path.exists(glbfile):
        scene = predictions_to_glb(
            predictions,
            filter_by_frames=frame_filter,
            show_cam=show_cam,
            mask_black_bg=filter_black_bg,
            mask_white_bg=filter_white_bg,
            as_mesh=show_mesh,
        )
        scene.export(file_obj=glbfile)
    rrd_path = predictions_to_rrd(predictions, glbfile, target_dir, frame_filter, show_cam)
    return rrd_path, "Visualization updated."
def update_all_views_on_filter_change(
    target_dir, filter_black_bg, filter_white_bg, processed_data,
    depth_view_selector, normal_view_selector, measure_view_selector,
):
    """Recompute the 2D tab views after a background-filter toggle.

    Reloads cached predictions, rebuilds per-view data with the new filters,
    and refreshes the depth/normal/measure images at their current dropdown
    selections. Measure points are cleared because the mask changed.

    Returns:
        (processed_data, depth image, normal image, measure image, []).
        On any failure the previous processed_data is returned unchanged.
    """
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return processed_data, None, None, None, []
    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return processed_data, None, None, None, []
    try:
        loaded = np.load(predictions_path, allow_pickle=True)
        predictions = {key: loaded[key] for key in loaded.keys()}
        views = load_images(os.path.join(target_dir, "images"))
        refreshed = process_predictions_for_visualization(
            predictions, views, high_level_config, filter_black_bg, filter_white_bg,
        )

        def selector_index(selector):
            # "View N" -> N-1; fall back to the first view on parse failure.
            try:
                return int(selector.split()[1]) - 1
            except Exception:
                return 0

        depth_img = update_depth_view(refreshed, selector_index(depth_view_selector))
        normal_img = update_normal_view(refreshed, selector_index(normal_view_selector))
        measure_img, _ = update_measure_view(refreshed, selector_index(measure_view_selector))
        return refreshed, depth_img, normal_img, measure_img, []
    except Exception as e:
        print(f"Filter change error: {e}")
        return processed_data, None, None, None, []
def get_scene_info(examples_dir):
    """Scan an examples directory for scene sub-folders containing images.

    Args:
        examples_dir: directory whose immediate sub-folders are scenes.

    Returns:
        List of dicts with "name", "path", "thumbnail" (first image by sorted
        path), "num_images" and "image_files" (sorted paths). Empty when the
        directory is missing or no folder contains images.
    """
    import glob

    scenes = []
    if not os.path.exists(examples_dir):
        return scenes
    for scene_folder in sorted(os.listdir(examples_dir)):
        scene_path = os.path.join(examples_dir, scene_folder)
        if not os.path.isdir(scene_path):
            continue
        image_files = []
        for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]:
            image_files.extend(glob.glob(os.path.join(scene_path, ext)))
            image_files.extend(glob.glob(os.path.join(scene_path, ext.upper())))
        # Fix: de-duplicate — on case-insensitive filesystems (macOS/Windows)
        # the lower- and upper-case patterns match the same files twice,
        # doubling num_images and the gallery entries.
        image_files = sorted(set(image_files))
        if image_files:
            scenes.append({
                "name": scene_folder,
                "path": scene_path,
                "thumbnail": image_files[0],
                "num_images": len(image_files),
                "image_files": image_files,
            })
    return scenes
def load_example_scene(scene_name, examples_dir="examples"):
    """Stage an example scene's images for reconstruction.

    Args:
        scene_name: folder name of the scene under examples_dir.
        examples_dir: root directory of example scenes.

    Returns:
        (viewer reset, target dir, gallery image paths, status message).
    """
    matching = [s for s in get_scene_info(examples_dir) if s["name"] == scene_name]
    if not matching:
        return None, None, None, "Scene not found"
    scene = matching[0]
    target_dir, image_paths = handle_uploads(scene["image_files"], 1.0)
    return None, target_dir, image_paths, f"Loaded '{scene_name}' β {scene['num_images']} images. Click Reconstruct."
# ---------------------------------------------------------------------------
# Gradio UI definition. The hidden textboxes/states below carry values
# between event handlers without being visible to the user.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    is_example = gr.Textbox(visible=False, value="None")  # "True" while an example scene is merely loaded
    num_images = gr.Textbox(visible=False, value="None")  # image count of the current upload
    processed_data_state = gr.State(value=None)  # per-view dict for the Depth/Normal/Measure tabs
    measure_points_state = gr.State(value=[])  # clicked (x, y) points on the Measure tab
    target_dir_output = gr.Textbox(visible=False, value="None")  # working dir of the staged scene
    with gr.Column(elem_id="app-shell"):
        # ββ New styled header ββ
        gr.HTML(html_header())
        with gr.Row(equal_height=False):
            # ββ Left Panel ββ  (uploads, options, reconstruct button)
            with gr.Column(elem_id="left-panel", scale=0):
                unified_upload = gr.File(
                    file_count="multiple",
                    label="Upload Images/Videos",
                    file_types=["image", "video"],
                    height="150",
                )
                with gr.Row():
                    # Frame-sampling interval applied when a video is uploaded.
                    s_time_interval = gr.Slider(
                        minimum=0.1, maximum=5.0, value=1.0, step=0.1,
                        label="Video interval (sec)",
                        scale=3,
                    )
                    # Hidden until a video upload is detected (see show_resample_button).
                    resample_btn = gr.Button("Resample", visible=False, variant="secondary", scale=1)
                image_gallery = gr.Gallery(
                    columns=2,
                    height="150",
                )
                gr.ClearButton(
                    [unified_upload, image_gallery],
                    value="Clear uploads",
                    variant="secondary",
                    size="sm",
                )
                submit_btn = gr.Button("Reconstruct", variant="primary")
                with gr.Accordion("Options", open=False):
                    # Display-only options; applied without re-running the model.
                    gr.Markdown("### Point Cloud")
                    show_cam = gr.Checkbox(label="Show cameras", value=True)
                    show_mesh = gr.Checkbox(label="Show mesh", value=True)
                    filter_black_bg = gr.Checkbox(label="Filter black background", value=False)
                    filter_white_bg = gr.Checkbox(label="Filter white background", value=False)
                    # Options that only take effect on the next reconstruction run.
                    gr.Markdown("### Reconstruction (next run)")
                    apply_mask_checkbox = gr.Checkbox(
                        label="Apply ambiguous-depth mask & edges", value=True,
                    )
            # ββ Right Panel ββ  (status strip + viewer tabs)
            with gr.Column(elem_id="right-panel", scale=1):
                log_output = gr.Markdown(
                    "Upload a video or images, then click **Reconstruct**.",
                    elem_id="log-strip",
                )
                with gr.Tabs(elem_id="viewer-tabs"):
                    with gr.Tab("3D View"):
                        reconstruction_output = Rerun(
                            label="Rerun 3D Viewer",
                            height=672,
                        )
                    # Depth / Normal / Measure tabs share the same prev/dropdown/next
                    # navigation pattern over the reconstructed views.
                    with gr.Tab("Depth"):
                        with gr.Row(elem_classes=["nav-row"]):
                            prev_depth_btn = gr.Button("β Prev", size="sm", scale=1)
                            depth_view_selector = gr.Dropdown(
                                choices=["View 1"], value="View 1",
                                label="View", scale=3, interactive=True,
                                allow_custom_value=True, show_label=False,
                            )
                            next_depth_btn = gr.Button("Next βΆ", size="sm", scale=1)
                        depth_map = gr.Image(
                            type="numpy", label="Depth Map",
                            format="png", interactive=False,
                        )
                    with gr.Tab("Normal"):
                        with gr.Row(elem_classes=["nav-row"]):
                            prev_normal_btn = gr.Button("β Prev", size="sm", scale=1)
                            normal_view_selector = gr.Dropdown(
                                choices=["View 1"], value="View 1",
                                label="View", scale=3, interactive=True,
                                allow_custom_value=True, show_label=False,
                            )
                            next_normal_btn = gr.Button("Next βΆ", size="sm", scale=1)
                        normal_map = gr.Image(
                            type="numpy", label="Normal Map",
                            format="png", interactive=False,
                        )
                    with gr.Tab("Measure"):
                        gr.Markdown(MEASURE_INSTRUCTIONS_HTML)
                        with gr.Row(elem_classes=["nav-row"]):
                            prev_measure_btn = gr.Button("β Prev", size="sm", scale=1)
                            measure_view_selector = gr.Dropdown(
                                choices=["View 1"], value="View 1",
                                label="View", scale=3, interactive=True,
                                allow_custom_value=True, show_label=False,
                            )
                            next_measure_btn = gr.Button("Next βΆ", size="sm", scale=1)
                        # sources=[] disables uploads; the image is click-only (measure()).
                        measure_image = gr.Image(
                            type="numpy", show_label=False,
                            format="webp", interactive=False, sources=[],
                        )
                        gr.Markdown(
                            "Light-grey areas have no depth β measurements cannot be placed there.",
                            elem_classes=["measure-note"],
                        )
                        measure_text = gr.Markdown("")
                with gr.Column():
                    frame_filter = gr.Dropdown(
                        choices=["All"], value="All", label="Filter by Frame",
                        show_label=True,
                    )
        # ββ Example scene gallery (rows of four thumbnails) ββ
        with gr.Column(elem_id="examples-section"):
            gr.Markdown("## Example Scenes")
            gr.Markdown("Click a thumbnail to load the scene, then press **Reconstruct**.")
            scenes = get_scene_info("examples")
            if scenes:
                for i in range(0, len(scenes), 4):
                    with gr.Row():
                        for j in range(4):
                            idx = i + j
                            if idx < len(scenes):
                                scene = scenes[idx]
                                with gr.Column(scale=1, min_width=140, elem_classes=["scene-thumb"]):
                                    scene_img = gr.Image(
                                        value=scene["thumbnail"],
                                        height=130,
                                        interactive=False,
                                        show_label=False,
                                        sources=[],
                                    )
                                    gr.Markdown(
                                        f"**{scene['name']}** \n{scene['num_images']} imgs",
                                        elem_classes=["scene-caption"],
                                    )
                                    # Bind the scene name via a lambda default argument so
                                    # each thumbnail loads its own scene (avoids the
                                    # late-binding closure pitfall in this loop).
                                    scene_img.select(
                                        fn=lambda name=scene["name"]: load_example_scene(name),
                                        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
                                    )
                            else:
                                # Pad the last row so columns stay evenly sized.
                                with gr.Column(scale=1, min_width=140):
                                    pass
# -- Reconstruction pipeline: clear viewer -> show progress -> run model. --
submit_btn.click(
    fn=clear_fields, inputs=[], outputs=[reconstruction_output],
).then(
    fn=update_log, inputs=[], outputs=[log_output],
).then(
    fn=gradio_demo,
    inputs=[target_dir_output, frame_filter, show_cam, filter_black_bg, filter_white_bg, apply_mask_checkbox, show_mesh],
    outputs=[
        reconstruction_output, log_output, frame_filter, processed_data_state,
        depth_map, normal_map, measure_image, measure_text,
        depth_view_selector, normal_view_selector, measure_view_selector,
    ],
).then(fn=lambda: "False", inputs=[], outputs=[is_example])
# Rebuild the 3D view whenever a display option changes. Every trigger now
# forwards the full option set; previously show_cam.change omitted the
# filter/mesh inputs, silently resetting those options to their defaults.
for trigger_inputs, trigger in [
    ([target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh], frame_filter.change),
    ([target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh], show_cam.change),
    ([target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh], show_mesh.change),
]:
    trigger(update_visualization, trigger_inputs, [reconstruction_output, log_output])
# Background-filter toggles refresh the 3D scene and then the 2D tab views.
# show_mesh is now forwarded by the black-filter handler too, matching the
# white-filter handler (it was previously dropped, forcing mesh back on).
filter_black_bg.change(
    update_visualization,
    [target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh],
    [reconstruction_output, log_output],
).then(
    update_all_views_on_filter_change,
    [target_dir_output, filter_black_bg, filter_white_bg, processed_data_state, depth_view_selector, normal_view_selector, measure_view_selector],
    [processed_data_state, depth_map, normal_map, measure_image, measure_points_state],
)
filter_white_bg.change(
    update_visualization,
    [target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh],
    [reconstruction_output, log_output],
).then(
    update_all_views_on_filter_change,
    [target_dir_output, filter_black_bg, filter_white_bg, processed_data_state, depth_view_selector, normal_view_selector, measure_view_selector],
    [processed_data_state, depth_map, normal_map, measure_image, measure_points_state],
)
def update_gallery_on_unified_upload(files, interval):
    """Stage newly uploaded files and refresh the gallery.

    Args:
        files: uploaded file entries (or None/empty when cleared).
        interval: frame-sampling interval in seconds for videos.

    Returns:
        (target dir, gallery image paths, status) — all None when nothing
        was uploaded.
    """
    if not files:
        return None, None, None
    staged_dir, staged_paths = handle_uploads(files, interval)
    return staged_dir, staged_paths, "Upload complete. Click **Reconstruct** to begin."
def show_resample_button(files):
    """Make the Resample button visible only when the upload has a video."""
    if not files:
        return gr.update(visible=False)
    video_exts = [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"]

    def _extension(entry):
        # Upload entries may be dicts ({"name": ...}) or plain path strings.
        path = entry["name"] if isinstance(entry, dict) else entry
        return os.path.splitext(str(path))[1].lower()

    contains_video = any(_extension(f) in video_exts for f in files)
    return gr.update(visible=contains_video)
def resample_video_with_new_interval(files, new_interval, current_target_dir):
    """Re-extract video frames at a new sampling interval.

    Removes the previous extraction directory, re-runs the upload pipeline
    at the requested interval, and hides the Resample button again.

    Args:
        files: currently uploaded file entries.
        new_interval: new sampling interval in seconds.
        current_target_dir: directory of the previous extraction (discarded).

    Returns:
        (target dir, gallery paths, status message, button visibility update).
    """
    if not files:
        return current_target_dir, None, "No files to resample.", gr.update(visible=False)
    video_exts = [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"]
    has_video = False
    for entry in files:
        path = entry["name"] if isinstance(entry, dict) else entry
        if os.path.splitext(str(path))[1].lower() in video_exts:
            has_video = True
            break
    if not has_video:
        return current_target_dir, None, "No videos found.", gr.update(visible=False)
    # Drop the stale frame directory before re-sampling.
    if current_target_dir and current_target_dir != "None" and os.path.exists(current_target_dir):
        shutil.rmtree(current_target_dir)
    target_dir, image_paths = handle_uploads(files, new_interval)
    return target_dir, image_paths, f"Resampled at {new_interval}s. Click **Reconstruct**.", gr.update(visible=False)
# -- Upload handling: stage files, then reveal Resample when a video exists. --
unified_upload.change(
    fn=update_gallery_on_unified_upload,
    inputs=[unified_upload, s_time_interval],
    outputs=[target_dir_output, image_gallery, log_output],
).then(fn=show_resample_button, inputs=[unified_upload], outputs=[resample_btn])
# Changing the interval re-evaluates whether Resample should be offered.
s_time_interval.change(fn=show_resample_button, inputs=[unified_upload], outputs=[resample_btn])
resample_btn.click(
    fn=resample_video_with_new_interval,
    inputs=[unified_upload, s_time_interval, target_dir_output],
    outputs=[target_dir_output, image_gallery, log_output, resample_btn],
)
# Clicks on the Measure image append measurement points (see measure()).
measure_image.select(
    fn=measure,
    inputs=[processed_data_state, measure_points_state, measure_view_selector],
    outputs=[measure_image, measure_points_state, measure_text],
)
# Depth tab navigation: prev/next buttons plus direct dropdown selection.
prev_depth_btn.click(
    fn=lambda pd, sel: navigate_depth_view(pd, sel, -1),
    inputs=[processed_data_state, depth_view_selector],
    outputs=[depth_view_selector, depth_map],
)
next_depth_btn.click(
    fn=lambda pd, sel: navigate_depth_view(pd, sel, 1),
    inputs=[processed_data_state, depth_view_selector],
    outputs=[depth_view_selector, depth_map],
)
depth_view_selector.change(
    # Dropdown values look like "View N"; convert to a zero-based index.
    fn=lambda pd, sel: update_depth_view(pd, int(sel.split()[1]) - 1) if sel else None,
    inputs=[processed_data_state, depth_view_selector],
    outputs=[depth_map],
)
# Normal tab navigation (same pattern as Depth).
prev_normal_btn.click(
    fn=lambda pd, sel: navigate_normal_view(pd, sel, -1),
    inputs=[processed_data_state, normal_view_selector],
    outputs=[normal_view_selector, normal_map],
)
next_normal_btn.click(
    fn=lambda pd, sel: navigate_normal_view(pd, sel, 1),
    inputs=[processed_data_state, normal_view_selector],
    outputs=[normal_view_selector, normal_map],
)
normal_view_selector.change(
    fn=lambda pd, sel: update_normal_view(pd, int(sel.split()[1]) - 1) if sel else None,
    inputs=[processed_data_state, normal_view_selector],
    outputs=[normal_map],
)
# Measure tab navigation; switching views also resets the stored points.
prev_measure_btn.click(
    fn=lambda pd, sel: navigate_measure_view(pd, sel, -1),
    inputs=[processed_data_state, measure_view_selector],
    outputs=[measure_view_selector, measure_image, measure_points_state],
)
next_measure_btn.click(
    fn=lambda pd, sel: navigate_measure_view(pd, sel, 1),
    inputs=[processed_data_state, measure_view_selector],
    outputs=[measure_view_selector, measure_image, measure_points_state],
)
measure_view_selector.change(
    fn=lambda pd, sel: update_measure_view(pd, int(sel.split()[1]) - 1) if sel else (None, []),
    inputs=[processed_data_state, measure_view_selector],
    outputs=[measure_image, measure_points_state],
)
# Launch with a bounded queue; share=True exposes a public URL, ssr_mode off
# for compatibility, mcp_server=True enables the MCP endpoint.
demo.queue(max_size=50).launch(css=CUSTOM_CSS, theme=steel_blue_theme, show_error=True, share=True, ssr_mode=False, mcp_server=True)