diff --git "a/app.py" "b/app.py" --- "a/app.py" +++ "b/app.py" @@ -16,12 +16,12 @@ import torch from PIL import Image from pillow_heif import register_heif_opener +# --- Rerun Imports --- import rerun as rr try: import rerun.blueprint as rrb except ImportError: rrb = None - from gradio_rerun import Rerun register_heif_opener() @@ -29,13 +29,22 @@ register_heif_opener() sys.path.append("mapanything/") from mapanything.utils.geometry import depthmap_to_world_frame, points_to_normals +from mapanything.utils.hf_utils.css_and_html import ( + GRADIO_CSS, + MEASURE_INSTRUCTIONS_HTML, + get_acknowledgements_html, + get_description_html, + get_gradio_theme, + get_header_html, +) from mapanything.utils.hf_utils.hf_helpers import initialize_mapanything_model from mapanything.utils.hf_utils.viz import predictions_to_glb from mapanything.utils.image import load_images, rgb +# MapAnything Configuration high_level_config = { "path": "configs/train.yaml", - "hf_model_name": "facebook/map-anything-v1", + "hf_model_name": "facebook/map-anything-v1", # -- facebook/map-anything "model_str": "mapanything", "config_overrides": [ "machine=aws", @@ -52,550 +61,37 @@ high_level_config = { "resolution": 518, } +# Initialize model - this will be done on GPU when needed model = None -TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp') -os.makedirs(TMP_DIR, exist_ok=True) - -MEASURE_INSTRUCTIONS_HTML = """ -**How to measure:** Click two points on the image to measure the real-world 3D distance between them. -""" - -# ───────────────────────────────────────────── -# CSS — Dark industrial / sci-fi aesthetic -# ───────────────────────────────────────────── -CUSTOM_CSS = """ -@import url('https://fonts.googleapis.com/css2?family=Syne:wght@400;600;700;800&family=JetBrains+Mono:wght@300;400;500&display=swap'); - -/* ── Root tokens ── */ -:root { - --bg-void: #080c10; - --bg-base: #0d1117; - --bg-surface: #111822; - --bg-raised: #172030; - --bg-hover: #1e2d42; - --border: #1f3048; - --border-bright: #2a4060; - --accent: #00c8b4; - --accent-dim: #007a6e; - --accent-glow: rgba(0,200,180,0.18); - --accent2: #3b8bff; - --text-primary: #e8f0f8; - --text-secondary: #8fa8c0; - --text-dim: #4a6480; - --danger: #ff4d6d; - --success: #00c8b4; - --font-display: 'Syne', sans-serif; - --font-mono: 'JetBrains Mono', monospace; - --radius: 10px; - --radius-lg: 16px; - --transition: all 0.2s cubic-bezier(0.4,0,0.2,1); -} - -/* ── Global reset ── */ -*, *::before, *::after { box-sizing: border-box; } - -body, .gradio-container { - background: var(--bg-void) !important; - color: var(--text-primary) !important; - font-family: var(--font-mono) !important; - min-height: 100vh; -} - -.gradio-container { - max-width: 1400px !important; - margin: 0 auto !important; - padding: 0 !important; -} - -/* ── Hero Header ── */ -#hero-header { - background: linear-gradient(135deg, #080c10 0%, #0a1520 40%, #0d1f30 100%); - border-bottom: 1px solid var(--border-bright); - padding: 48px 56px 40px; - position: relative; - overflow: hidden; - margin-bottom: 0; -} -#hero-header::before { - content: ''; - position: absolute; - inset: 0; - background: - radial-gradient(ellipse 60% 80% at 80% 50%, rgba(0,200,180,0.07) 0%, transparent 60%), - radial-gradient(ellipse 40% 60% at 20% 20%, rgba(59,139,255,0.06) 0%, transparent 50%); - pointer-events: none; -} -#hero-header::after { - content: ''; - position: absolute; - bottom: 0; left: 56px; right: 56px; - height: 1px; - background: linear-gradient(90deg, transparent, var(--accent), var(--accent2), transparent); - opacity: 0.5; -} - -#hero-title { - font-family: var(--font-display) !important; - font-size: 3rem !important; - font-weight: 800 !important; - letter-spacing: -0.02em !important; - line-height: 1 !important; - margin: 0 0 6px !important; - background: linear-gradient(135deg, #e8f0f8 30%, var(--accent) 70%, var(--accent2) 100%); - -webkit-background-clip: text; - -webkit-text-fill-color: transparent; - background-clip: text; -} -#hero-title p, #hero-title h1, #hero-title h2, #hero-title h3 { - font-family: var(--font-display) !important; - font-size: 3rem !important; - font-weight: 800 !important; - background: linear-gradient(135deg, #e8f0f8 30%, var(--accent) 70%, var(--accent2) 100%); - -webkit-background-clip: text; - -webkit-text-fill-color: transparent; - background-clip: text; - margin: 0 !important; -} - -#hero-sub p { - font-family: var(--font-mono) !important; - font-size: 0.8rem !important; - font-weight: 400 !important; - color: var(--text-secondary) !important; - letter-spacing: 0.12em !important; - text-transform: uppercase !important; - margin: 0 !important; -} - -.hero-badge { - display: inline-flex; - align-items: center; - gap: 6px; - background: rgba(0,200,180,0.1); - border: 1px solid var(--accent-dim); - border-radius: 99px; - padding: 4px 12px; - font-size: 0.7rem; - font-family: var(--font-mono); - color: var(--accent); - letter-spacing: 0.08em; - text-transform: uppercase; - margin-bottom: 16px; -} -.hero-badge::before { - content: ''; - width: 6px; height: 6px; - border-radius: 50%; - background: var(--accent); - box-shadow: 0 0 8px var(--accent); - animation: pulse-dot 2s ease-in-out infinite; -} -@keyframes pulse-dot { - 0%, 100% { opacity: 1; transform: scale(1); } - 50% { opacity: 0.4; transform: scale(0.7); } -} - -/* ── Main layout panels ── */ -#main-body { - display: grid; - grid-template-columns: 380px 1fr; - gap: 0; - min-height: calc(100vh - 160px); - background: var(--bg-void); -} - -#left-panel { - background: var(--bg-base); - border-right: 1px solid var(--border); - padding: 28px 24px; - display: flex; - flex-direction: column; - gap: 20px; - overflow-y: auto; -} - -#right-panel { - background: var(--bg-void); - padding: 28px 28px; - display: flex; - flex-direction: column; - gap: 16px; -} - -/* ── Section labels ── */ -.section-label { - font-family: var(--font-mono) !important; - font-size: 0.65rem !important; - font-weight: 500 !important; - color: var(--text-dim) !important; - letter-spacing: 0.15em !important; - text-transform: uppercase !important; - margin-bottom: 8px !important; - padding-bottom: 6px !important; - border-bottom: 1px solid var(--border) !important; -} -.section-label p { margin: 0 !important; color: inherit !important; font: inherit !important; } - -/* ── Gradio component overrides ── */ -.gradio-container .gr-block, -.gradio-container .gr-box, -.gradio-container .gr-form, -.gradio-container .block { - background: transparent !important; - border: none !important; -} - -/* Upload */ -.gradio-container .upload-container, -.gradio-container [data-testid="file-upload"] { - background: var(--bg-surface) !important; - border: 1px dashed var(--border-bright) !important; - border-radius: var(--radius) !important; - transition: var(--transition) !important; -} -.gradio-container [data-testid="file-upload"]:hover { - border-color: var(--accent) !important; - background: var(--bg-raised) !important; -} - -/* Gallery */ -.gradio-container .gallery { - background: var(--bg-surface) !important; - border: 1px solid var(--border) !important; - border-radius: var(--radius) !important; - overflow: hidden !important; -} -.gradio-container .gallery-item { - border: 1px solid var(--border) !important; - border-radius: 6px !important; - overflow: hidden !important; -} - -/* Textbox / inputs */ -.gradio-container input, -.gradio-container textarea, -.gradio-container select { - background: var(--bg-raised) !important; - border: 1px solid var(--border) !important; - color: var(--text-primary) !important; - border-radius: 6px !important; - font-family: var(--font-mono) !important; - font-size: 0.82rem !important; -} -.gradio-container input:focus, -.gradio-container textarea:focus { - border-color: var(--accent) !important; - outline: none !important; - box-shadow: 0 0 0 3px var(--accent-glow) !important; -} - -/* Labels */ -.gradio-container label span, -.gradio-container .label-wrap span { - color: var(--text-secondary) !important; - font-family: var(--font-mono) !important; - font-size: 0.75rem !important; - font-weight: 500 !important; - letter-spacing: 0.04em !important; -} - -/* Sliders */ -.gradio-container input[type=range] { - accent-color: var(--accent) !important; - background: transparent !important; - border: none !important; -} -.gradio-container .range-slider { - background: var(--bg-raised) !important; - border: 1px solid var(--border) !important; - border-radius: var(--radius) !important; - padding: 10px 14px !important; -} - -/* Checkboxes */ -.gradio-container input[type=checkbox] { - accent-color: var(--accent) !important; - width: 14px !important; - height: 14px !important; -} -.gradio-container .checkbox-group, -.gradio-container .checkbox { - background: transparent !important; - border: none !important; -} - -/* Dropdowns */ -.gradio-container .dropdown, -.gradio-container select { - background: var(--bg-raised) !important; - border: 1px solid var(--border) !important; - border-radius: 6px !important; - color: var(--text-primary) !important; - font-family: var(--font-mono) !important; -} - -/* ── Buttons ── */ -.gradio-container button { - font-family: var(--font-mono) !important; - font-size: 0.78rem !important; - font-weight: 500 !important; - letter-spacing: 0.06em !important; - border-radius: var(--radius) !important; - transition: var(--transition) !important; - cursor: pointer !important; -} - -/* Primary — Reconstruct */ -.gradio-container button.primary, -.gradio-container button[variant="primary"], -#reconstruct-btn button { - background: linear-gradient(135deg, var(--accent) 0%, #009e8c 100%) !important; - border: none !important; - color: #060d12 !important; - font-weight: 700 !important; - padding: 12px 28px !important; - box-shadow: 0 4px 20px rgba(0,200,180,0.3), 0 0 0 0 var(--accent-glow) !important; - text-transform: uppercase !important; - letter-spacing: 0.1em !important; -} -.gradio-container button.primary:hover, -#reconstruct-btn button:hover { - transform: translateY(-1px) !important; - box-shadow: 0 6px 28px rgba(0,200,180,0.45) !important; -} - -/* Secondary */ -.gradio-container button.secondary, -.gradio-container button[variant="secondary"] { - background: var(--bg-raised) !important; - border: 1px solid var(--border-bright) !important; - color: var(--text-secondary) !important; -} -.gradio-container button.secondary:hover { - background: var(--bg-hover) !important; - border-color: var(--accent-dim) !important; - color: var(--text-primary) !important; -} - -/* Nav buttons (◀ ▶) */ -.gradio-container button[size="sm"] { - background: var(--bg-surface) !important; - border: 1px solid var(--border) !important; - color: var(--text-secondary) !important; - padding: 6px 12px !important; - font-size: 0.72rem !important; -} -.gradio-container button[size="sm"]:hover { - background: var(--bg-hover) !important; - border-color: var(--accent) !important; - color: var(--accent) !important; -} - -/* ── Tabs ── */ -.gradio-container .tabs { - background: transparent !important; - border: none !important; -} -.gradio-container .tab-nav { - background: var(--bg-surface) !important; - border: 1px solid var(--border) !important; - border-radius: var(--radius) var(--radius) 0 0 !important; - padding: 4px !important; - gap: 2px !important; - display: flex !important; -} -.gradio-container .tab-nav button { - background: transparent !important; - border: none !important; - color: var(--text-dim) !important; - border-radius: 6px !important; - padding: 8px 16px !important; - font-size: 0.72rem !important; - letter-spacing: 0.08em !important; - text-transform: uppercase !important; -} -.gradio-container .tab-nav button.selected { - background: var(--bg-raised) !important; - color: var(--accent) !important; - border-bottom: 2px solid var(--accent) !important; -} -.gradio-container .tabitem { - background: var(--bg-surface) !important; - border: 1px solid var(--border) !important; - border-top: none !important; - border-radius: 0 0 var(--radius) var(--radius) !important; - padding: 16px !important; -} - -/* ── Log output ── */ -#log-output { - background: var(--bg-surface) !important; - border: 1px solid var(--border) !important; - border-left: 3px solid var(--accent) !important; - border-radius: var(--radius) !important; - padding: 12px 16px !important; - font-family: var(--font-mono) !important; - font-size: 0.75rem !important; - color: var(--text-secondary) !important; - min-height: 38px !important; -} -#log-output p { margin: 0 !important; color: var(--text-secondary) !important; font-family: var(--font-mono) !important; font-size: 0.75rem !important; } - -/* ── 3D viewer ── */ -.rerun-viewer, -[data-testid="rerun"] { - border-radius: var(--radius) !important; - border: 1px solid var(--border-bright) !important; - overflow: hidden !important; - background: var(--bg-base) !important; -} - -/* ── Images ── */ -.gradio-container .image-container img, -.gradio-container .image-preview img { - border-radius: 6px !important; -} - -/* ── Accordion ── */ -.gradio-container .accordion { - background: var(--bg-surface) !important; - border: 1px solid var(--border) !important; - border-radius: var(--radius) !important; - overflow: hidden !important; -} -.gradio-container .accordion-header { - background: var(--bg-raised) !important; - padding: 10px 14px !important; - font-family: var(--font-mono) !important; - font-size: 0.75rem !important; - color: var(--text-secondary) !important; - cursor: pointer !important; -} -.gradio-container .accordion-header:hover { - color: var(--text-primary) !important; -} - -/* ── Status indicator row ── */ -.status-row { - display: flex; - align-items: center; - gap: 10px; -} -.status-dot { - width: 8px; height: 8px; - border-radius: 50%; - background: var(--accent); - box-shadow: 0 0 8px var(--accent); - flex-shrink: 0; -} - -/* ── Markdown ── */ -.gradio-container .markdown p, -.gradio-container .prose p { - color: var(--text-secondary) !important; - font-family: var(--font-mono) !important; - font-size: 0.78rem !important; - line-height: 1.6 !important; -} -.gradio-container .markdown strong, -.gradio-container .prose strong { - color: var(--text-primary) !important; -} -.gradio-container .markdown h3, -.gradio-container .prose h3 { - color: var(--accent) !important; - font-family: var(--font-display) !important; - font-size: 0.85rem !important; - font-weight: 700 !important; - letter-spacing: 0.05em !important; - text-transform: uppercase !important; - margin: 16px 0 8px !important; -} - -/* ── Divider ── */ -.divider { - height: 1px; - background: linear-gradient(90deg, transparent, var(--border-bright), transparent); - margin: 4px 0; -} - -/* ── Options group ── */ -.options-group { - background: var(--bg-surface); - border: 1px solid var(--border); - border-radius: var(--radius); - padding: 14px 16px; -} -.options-title { - font-family: var(--font-mono) !important; - font-size: 0.65rem !important; - font-weight: 500 !important; - color: var(--accent) !important; - letter-spacing: 0.15em !important; - text-transform: uppercase !important; - margin-bottom: 10px !important; - padding-bottom: 6px !important; - border-bottom: 1px solid var(--border) !important; -} -.options-title p { margin: 0 !important; color: inherit !important; font: inherit !important; } - -/* ── Scrollbar ── */ -::-webkit-scrollbar { width: 6px; height: 6px; } -::-webkit-scrollbar-track { background: var(--bg-base); } -::-webkit-scrollbar-thumb { background: var(--border-bright); border-radius: 3px; } -::-webkit-scrollbar-thumb:hover { background: var(--accent-dim); } - -/* ── Example thumbnails ── */ -.examples-grid { - display: grid; - grid-template-columns: repeat(auto-fill, minmax(140px, 1fr)); - gap: 12px; - margin-top: 12px; -} -.clickable-thumbnail { - cursor: pointer; - transition: var(--transition); -} -.clickable-thumbnail:hover { - transform: translateY(-2px); -} -.scene-info p { - color: var(--text-secondary) !important; - font-size: 0.7rem !important; - font-family: var(--font-mono) !important; - text-align: center !important; - margin: 4px 0 0 !important; -} - -/* ── Responsive ── */ -@media (max-width: 900px) { - #main-body { grid-template-columns: 1fr !important; } - #hero-header { padding: 32px 24px 28px; } - #hero-title p, #hero-title h1 { font-size: 2rem !important; } -} -""" -# ───────────────────────────────────────────── -# Rerun helper -# ───────────────────────────────────────────── -def predictions_to_rrd(predictions, glbfile, target_dir, frame_filter="All", show_cam=True): +# ------------------------------------------------------------------------- +# Rerun Helper Function +# ------------------------------------------------------------------------- +def create_rerun_recording(glb_path, output_dir): + """ + Takes a generated GLB file, wraps it in a Rerun recording (.rrd), + and returns the path to the .rrd file for the UI to consume. + """ run_id = str(uuid.uuid4()) - timestamp = datetime.now().strftime("%Y-%m-%dT%H%M%S") - rrd_path = os.path.join(target_dir, f"mapanything_{timestamp}.rrd") - + + # Robustly handle different Rerun SDK versions rec = None if hasattr(rr, "new_recording"): - rec = rr.new_recording(application_id="MapAnything-3D-Viewer", recording_id=run_id) + rec = rr.new_recording(application_id="MapAnything-3D", recording_id=run_id) elif hasattr(rr, "RecordingStream"): - rec = rr.RecordingStream(application_id="MapAnything-3D-Viewer", recording_id=run_id) + rec = rr.RecordingStream(application_id="MapAnything-3D", recording_id=run_id) else: - rr.init("MapAnything-3D-Viewer", recording_id=run_id, spawn=False) + rr.init("MapAnything-3D", recording_id=run_id, spawn=False) rec = rr - + + # Clear previous states rec.log("world", rr.Clear(recursive=True), static=True) + + # Set coordinates rec.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, static=True) + # Add optional axes helpers try: rec.log("world/axes/x", rr.Arrows3D(vectors=[[0.5, 0, 0]], colors=[[255, 0, 0]]), static=True) rec.log("world/axes/y", rr.Arrows3D(vectors=[[0, 0.5, 0]], colors=[[0, 255, 0]]), static=True) @@ -603,688 +99,1231 @@ def predictions_to_rrd(predictions, glbfile, target_dir, frame_filter="All", sho except Exception: pass - rec.log("world/model", rr.Asset3D(path=glbfile), static=True) - - if show_cam and "extrinsic" in predictions and "intrinsic" in predictions: - try: - extrinsics = predictions["extrinsic"] - intrinsics = predictions["intrinsic"] - for i, (ext, intr) in enumerate(zip(extrinsics, intrinsics)): - translation = ext[:3, 3] - rotation_mat = ext[:3, :3] - rec.log(f"world/cameras/cam_{i:03d}", rr.Transform3D(translation=translation, mat3x3=rotation_mat), static=True) - fx, fy = intr[0, 0], intr[1, 1] - cx, cy = intr[0, 2], intr[1, 2] - h, w = (predictions["images"][i].shape[:2] if "images" in predictions and i < len(predictions["images"]) else (518, 518)) - rec.log(f"world/cameras/cam_{i:03d}/image", rr.Pinhole(focal_length=[fx, fy], principal_point=[cx, cy], width=w, height=h), static=True) - if "images" in predictions and i < len(predictions["images"]): - img = predictions["images"][i] - if img.dtype != np.uint8: - img = (np.clip(img, 0, 1) * 255).astype(np.uint8) - rec.log(f"world/cameras/cam_{i:03d}/image/rgb", rr.Image(img), static=True) - except Exception as e: - print(f"Camera logging failed (non-fatal): {e}") - - if "world_points" in predictions and "images" in predictions: - try: - world_points = predictions["world_points"] - images = predictions["images"] - final_mask = predictions.get("final_mask") - all_points, all_colors = [], [] - for i in range(len(world_points)): - pts = world_points[i] - img = images[i] - mask = final_mask[i].astype(bool) if final_mask is not None else np.ones(pts.shape[:2], dtype=bool) - pts_flat = pts[mask] - img_flat = img[mask] - if img_flat.dtype != np.uint8: - img_flat = (np.clip(img_flat, 0, 1) * 255).astype(np.uint8) - all_points.append(pts_flat) - all_colors.append(img_flat) - if all_points: - all_points = np.concatenate(all_points, axis=0) - all_colors = np.concatenate(all_colors, axis=0) - if len(all_points) > 500_000: - idx = np.random.choice(len(all_points), 500_000, replace=False) - all_points = all_points[idx] - all_colors = all_colors[idx] - rec.log("world/point_cloud", rr.Points3D(positions=all_points, colors=all_colors, radii=0.002), static=True) - except Exception as e: - print(f"Point cloud logging failed (non-fatal): {e}") - + # Log the 3D Model + rec.log("world/scene", rr.Asset3D(path=glb_path), static=True) + + # Blueprint for clean layout if rrb is not None: try: - blueprint = rrb.Blueprint(rrb.Spatial3DView(origin="/world", name="3D View"), collapse_panels=True) + blueprint = rrb.Blueprint( + rrb.Spatial3DView( + origin="/world", + name="3D View", + ), + collapse_panels=True, + ) rec.send_blueprint(blueprint) except Exception as e: print(f"Blueprint creation failed (non-fatal): {e}") + # Save the recording to the target directory + rrd_path = os.path.join(output_dir, f'scene_{run_id}.rrd') rec.save(rrd_path) + return rrd_path -# ───────────────────────────────────────────── -# Core model inference -# ───────────────────────────────────────────── +# ------------------------------------------------------------------------- +# 1) Core model inference +# ------------------------------------------------------------------------- @spaces.GPU(duration=120) -def run_model(target_dir, apply_mask=True, mask_edges=True, filter_black_bg=False, filter_white_bg=False): +def run_model( + target_dir, + apply_mask=True, + mask_edges=True, + filter_black_bg=False, + filter_white_bg=False, +): + """ + Run the MapAnything model on images in the 'target_dir/images' folder and return predictions. + """ global model - import torch + import torch # Ensure torch is available in function scope print(f"Processing images from {target_dir}") - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # Device check + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + + # Initialize model if not already done if model is None: model = initialize_mapanything_model(high_level_config, device) + else: model = model.to(device) + model.eval() + # Load images using MapAnything's load_images function + print("Loading images...") image_folder_path = os.path.join(target_dir, "images") views = load_images(image_folder_path) + print(f"Loaded {len(views)} images") if len(views) == 0: raise ValueError("No images found. Check your upload.") - outputs = model.infer(views, apply_mask=apply_mask, mask_edges=True, memory_efficient_inference=False) + # Run model inference + print("Running inference...") + # apply_mask: Whether to apply the non-ambiguous mask to the output. Defaults to True. + # mask_edges: Whether to compute an edge mask based on normals and depth and apply it to the output. Defaults to True. + # Use checkbox values - mask_edges is set to True by default since there's no UI control for it + outputs = model.infer( + views, apply_mask=apply_mask, mask_edges=True, memory_efficient_inference=False + ) + # Convert predictions to format expected by visualization predictions = {} - extrinsic_list, intrinsic_list, world_points_list = [], [], [] - depth_maps_list, images_list, final_mask_list = [], [], [] - for pred in outputs: - depthmap_torch = pred["depth_z"][0].squeeze(-1) - intrinsics_torch = pred["intrinsics"][0] - camera_pose_torch = pred["camera_poses"][0] - pts3d_computed, valid_mask = depthmap_to_world_frame(depthmap_torch, intrinsics_torch, camera_pose_torch) + # Initialize lists for the required keys + extrinsic_list = [] + intrinsic_list = [] + world_points_list = [] + depth_maps_list = [] + images_list = [] + final_mask_list = [] + # Loop through the outputs + for pred in outputs: + # Extract data from predictions + depthmap_torch = pred["depth_z"][0].squeeze(-1) # (H, W) + intrinsics_torch = pred["intrinsics"][0] # (3, 3) + camera_pose_torch = pred["camera_poses"][0] # (4, 4) + + # Compute new pts3d using depth, intrinsics, and camera pose + pts3d_computed, valid_mask = depthmap_to_world_frame( + depthmap_torch, intrinsics_torch, camera_pose_torch + ) + + # Convert to numpy arrays for visualization + # Check if mask key exists in pred, if not, fill with boolean trues in the size of depthmap_torch if "mask" in pred: mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool) else: + # Fill with boolean trues in the size of depthmap_torch mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool) + + # Combine with valid depth mask mask = mask & valid_mask.cpu().numpy() + image = pred["img_no_norm"][0].cpu().numpy() + + # Append to lists extrinsic_list.append(camera_pose_torch.cpu().numpy()) intrinsic_list.append(intrinsics_torch.cpu().numpy()) world_points_list.append(pts3d_computed.cpu().numpy()) depth_maps_list.append(depthmap_torch.cpu().numpy()) - images_list.append(pred["img_no_norm"][0].cpu().numpy()) - final_mask_list.append(mask) + images_list.append(image) # Add image to list + final_mask_list.append(mask) # Add final_mask to list + # Convert lists to numpy arrays with required shapes + # extrinsic: (S, 3, 4) - batch of camera extrinsic matrices predictions["extrinsic"] = np.stack(extrinsic_list, axis=0) + + # intrinsic: (S, 3, 3) - batch of camera intrinsic matrices predictions["intrinsic"] = np.stack(intrinsic_list, axis=0) + + # world_points: (S, H, W, 3) - batch of 3D world points predictions["world_points"] = np.stack(world_points_list, axis=0) + + # depth: (S, H, W, 1) or (S, H, W) - batch of depth maps depth_maps = np.stack(depth_maps_list, axis=0) + # Add channel dimension if needed to match (S, H, W, 1) format if len(depth_maps.shape) == 3: depth_maps = depth_maps[..., np.newaxis] + predictions["depth"] = depth_maps + + # images: (S, H, W, 3) - batch of input images predictions["images"] = np.stack(images_list, axis=0) + + # final_mask: (S, H, W) - batch of final masks for filtering predictions["final_mask"] = np.stack(final_mask_list, axis=0) - processed_data = process_predictions_for_visualization(predictions, views, high_level_config, filter_black_bg, filter_white_bg) + # Process data for visualization tabs (depth, normal, measure) + processed_data = process_predictions_for_visualization( + predictions, views, high_level_config, filter_black_bg, filter_white_bg + ) + + # Clean up torch.cuda.empty_cache() + return predictions, processed_data def update_view_selectors(processed_data): - choices = [f"View {i + 1}" for i in range(len(processed_data))] if processed_data else ["View 1"] + """Update view selector dropdowns based on available views""" + if processed_data is None or len(processed_data) == 0: + choices = ["View 1"] + else: + num_views = len(processed_data) + choices = [f"View {i + 1}" for i in range(num_views)] + return ( - gr.Dropdown(choices=choices, value=choices[0]), - gr.Dropdown(choices=choices, value=choices[0]), - gr.Dropdown(choices=choices, value=choices[0]), + gr.Dropdown(choices=choices, value=choices[0]), # depth_view_selector + gr.Dropdown(choices=choices, value=choices[0]), # normal_view_selector + gr.Dropdown(choices=choices, value=choices[0]), # measure_view_selector ) def get_view_data_by_index(processed_data, view_index): - if not processed_data: + """Get view data by index, handling bounds""" + if processed_data is None or len(processed_data) == 0: return None + view_keys = list(processed_data.keys()) - view_index = max(0, min(view_index, len(view_keys) - 1)) + if view_index < 0 or view_index >= len(view_keys): + view_index = 0 + return processed_data[view_keys[view_index]] def update_depth_view(processed_data, view_index): + """Update depth view for a specific view index""" view_data = get_view_data_by_index(processed_data, view_index) if view_data is None or view_data["depth"] is None: return None + return colorize_depth(view_data["depth"], mask=view_data.get("mask")) def update_normal_view(processed_data, view_index): + """Update normal view for a specific view index""" view_data = get_view_data_by_index(processed_data, view_index) if view_data is None or view_data["normal"] is None: return None + return colorize_normal(view_data["normal"], mask=view_data.get("mask")) def update_measure_view(processed_data, view_index): + """Update measure view for a specific view index with mask overlay""" view_data = get_view_data_by_index(processed_data, view_index) if view_data is None: - return None, [] + return None, [] # image, measure_points + + # Get the base image image = view_data["image"].copy() + + # Ensure image is in uint8 format if image.dtype != np.uint8: - image = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8) + if image.max() <= 1.0: + image = (image * 255).astype(np.uint8) + else: + image = image.astype(np.uint8) + + # Apply mask overlay if mask is available if view_data["mask"] is not None: - invalid_mask = ~view_data["mask"] + mask = view_data["mask"] + + # Create light grey overlay for masked areas + # Masked areas (False values) will be overlaid with light grey + invalid_mask = ~mask # Areas where mask is False + if invalid_mask.any(): + # Create a light grey overlay (RGB: 192, 192, 192) overlay_color = np.array([255, 220, 220], dtype=np.uint8) - alpha = 0.5 - for c in range(3): - image[:, :, c] = np.where(invalid_mask, (1 - alpha) * image[:, :, c] + alpha * overlay_color[c], image[:, :, c]).astype(np.uint8) + + # Apply overlay with some transparency + alpha = 0.5 # Transparency level + for c in range(3): # RGB channels + image[:, :, c] = np.where( + invalid_mask, + (1 - alpha) * image[:, :, c] + alpha * overlay_color[c], + image[:, :, c], + ).astype(np.uint8) + return image, [] def navigate_depth_view(processed_data, current_selector_value, direction): - if not processed_data: + """Navigate depth view (direction: -1 for previous, +1 for next)""" + if processed_data is None or len(processed_data) == 0: return "View 1", None + + # Parse current view number try: current_view = int(current_selector_value.split()[1]) - 1 except: current_view = 0 - new_view = (current_view + direction) % len(processed_data) - return f"View {new_view + 1}", update_depth_view(processed_data, new_view) + + num_views = len(processed_data) + new_view = (current_view + direction) % num_views + + new_selector_value = f"View {new_view + 1}" + depth_vis = update_depth_view(processed_data, new_view) + + return new_selector_value, depth_vis def navigate_normal_view(processed_data, current_selector_value, direction): - if not processed_data: + """Navigate normal view (direction: -1 for previous, +1 for next)""" + if processed_data is None or len(processed_data) == 0: return "View 1", None + + # Parse current view number try: current_view = int(current_selector_value.split()[1]) - 1 except: current_view = 0 - new_view = (current_view + direction) % len(processed_data) - return f"View {new_view + 1}", update_normal_view(processed_data, new_view) + + num_views = len(processed_data) + new_view = (current_view + direction) % num_views + + new_selector_value = f"View {new_view + 1}" + normal_vis = update_normal_view(processed_data, new_view) + + return new_selector_value, normal_vis def navigate_measure_view(processed_data, current_selector_value, direction): - if not processed_data: + """Navigate measure view (direction: -1 for previous, +1 for next)""" + if processed_data is None or len(processed_data) == 0: return "View 1", None, [] + + # Parse current view number try: current_view = int(current_selector_value.split()[1]) - 1 except: current_view = 0 - new_view = (current_view + direction) % len(processed_data) - img, pts = update_measure_view(processed_data, new_view) - return f"View {new_view + 1}", img, pts + + num_views = len(processed_data) + new_view = (current_view + direction) % num_views + + new_selector_value = f"View {new_view + 1}" + measure_image, measure_points = update_measure_view(processed_data, new_view) + + return new_selector_value, measure_image, measure_points def populate_visualization_tabs(processed_data): - if not processed_data: + """Populate the depth, normal, and measure tabs with processed data""" + if processed_data is None or len(processed_data) == 0: return None, None, None, [] - return ( - update_depth_view(processed_data, 0), - update_normal_view(processed_data, 0), - update_measure_view(processed_data, 0)[0], - [], - ) + + # Use update functions to ensure confidence filtering is applied from the start + depth_vis = update_depth_view(processed_data, 0) + normal_vis = update_normal_view(processed_data, 0) + measure_img, _ = update_measure_view(processed_data, 0) + + return depth_vis, normal_vis, measure_img, [] -# ───────────────────────────────────────────── -# File handling -# ───────────────────────────────────────────── +# ------------------------------------------------------------------------- +# 2) Handle uploaded video/images --> produce target_dir + images +# ------------------------------------------------------------------------- def handle_uploads(unified_upload, s_time_interval=1.0): + """ + Create a new 'target_dir' + 'images' subfolder, and place user-uploaded + images or extracted frames from video into it. Return (target_dir, image_paths). + """ start_time = time.time() gc.collect() torch.cuda.empty_cache() + # Create a unique folder name timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") target_dir = f"input_images_{timestamp}" target_dir_images = os.path.join(target_dir, "images") + # Clean up if somehow that folder already exists if os.path.exists(target_dir): shutil.rmtree(target_dir) - os.makedirs(target_dir_images, exist_ok=True) + os.makedirs(target_dir) + os.makedirs(target_dir_images) image_paths = [] + + # --- Handle uploaded files (both images and videos) --- if unified_upload is not None: for file_data in unified_upload: - file_path = file_data["name"] if isinstance(file_data, dict) and "name" in file_data else str(file_data) + if isinstance(file_data, dict) and "name" in file_data: + file_path = file_data["name"] + else: + file_path = str(file_data) + file_ext = os.path.splitext(file_path)[1].lower() - video_extensions = [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"] + # Check if it's a video file + video_extensions = [ + ".mp4", + ".avi", + ".mov", + ".mkv", + ".wmv", + ".flv", + ".webm", + ".m4v", + ".3gp", + ] if file_ext in video_extensions: + # Handle as video vs = cv2.VideoCapture(file_path) fps = vs.get(cv2.CAP_PROP_FPS) - frame_interval = int(fps * s_time_interval) - count = video_frame_num = 0 + frame_interval = int(fps * s_time_interval) # frames per interval + + count = 0 + video_frame_num = 0 while True: gotit, frame = vs.read() if not gotit: break count += 1 if count % frame_interval == 0: + # Use original filename as prefix for frames base_name = os.path.splitext(os.path.basename(file_path))[0] - image_path = os.path.join(target_dir_images, f"{base_name}_{video_frame_num:06}.png") + image_path = os.path.join( + target_dir_images, f"{base_name}_{video_frame_num:06}.png" + ) cv2.imwrite(image_path, frame) image_paths.append(image_path) video_frame_num += 1 vs.release() - elif file_ext in [".heic", ".heif"]: - try: - with Image.open(file_path) as img: - if img.mode not in ("RGB", "L"): - img = img.convert("RGB") - base_name = os.path.splitext(os.path.basename(file_path))[0] - dst_path = os.path.join(target_dir_images, f"{base_name}.jpg") - img.save(dst_path, "JPEG", quality=95) + print( + f"Extracted {video_frame_num} frames from video: {os.path.basename(file_path)}" + ) + + else: + # Handle as image + # Check if the file is a HEIC image + if file_ext in [".heic", ".heif"]: + # Convert HEIC to JPEG for better gallery compatibility + try: + with Image.open(file_path) as img: + # Convert to RGB if necessary (HEIC can have different color modes) + if img.mode not in ("RGB", "L"): + img = img.convert("RGB") + + # Create JPEG filename + base_name = os.path.splitext(os.path.basename(file_path))[0] + dst_path = os.path.join( + target_dir_images, f"{base_name}.jpg" + ) + + # Save as JPEG with high quality + img.save(dst_path, "JPEG", quality=95) + image_paths.append(dst_path) + print( + f"Converted HEIC to JPEG: {os.path.basename(file_path)} -> {os.path.basename(dst_path)}" + ) + except Exception as e: + print(f"Error converting HEIC file {file_path}: {e}") + # Fall back to copying as is + dst_path = os.path.join( + target_dir_images, os.path.basename(file_path) + ) + shutil.copy(file_path, dst_path) image_paths.append(dst_path) - except Exception as e: - print(f"HEIC error: {e}") - dst_path = os.path.join(target_dir_images, os.path.basename(file_path)) + else: + # Regular image files - copy as is + dst_path = os.path.join( + target_dir_images, os.path.basename(file_path) + ) shutil.copy(file_path, dst_path) image_paths.append(dst_path) - else: - dst_path = os.path.join(target_dir_images, os.path.basename(file_path)) - shutil.copy(file_path, dst_path) - image_paths.append(dst_path) + # Sort final images for gallery image_paths = sorted(image_paths) - print(f"Files processed; took {time.time() - start_time:.3f}s") + + end_time = time.time() + print( + f"Files processed to {target_dir_images}; took {end_time - start_time:.3f} seconds" + ) return target_dir, image_paths -# ───────────────────────────────────────────── -# Reconstruction -# ───────────────────────────────────────────── +# ------------------------------------------------------------------------- +# 3) Update gallery on upload +# ------------------------------------------------------------------------- +def update_gallery_on_upload(input_video, input_images, s_time_interval=1.0): + """ + Whenever user uploads or changes files, immediately handle them + and show in the gallery. Return (target_dir, image_paths). + If nothing is uploaded, returns "None" and empty list. + """ + if not input_video and not input_images: + return None, None, None, None + target_dir, image_paths = handle_uploads(input_video, input_images, s_time_interval) + return ( + None, + target_dir, + image_paths, + "Upload complete. Click 'Reconstruct' to begin 3D processing.", + ) + + +# ------------------------------------------------------------------------- +# 4) Reconstruction: uses the target_dir plus any viz parameters +# ------------------------------------------------------------------------- @spaces.GPU(duration=120) -def gradio_demo(target_dir, frame_filter="All", show_cam=True, filter_black_bg=False, filter_white_bg=False, apply_mask=True, show_mesh=True): +def gradio_demo( + target_dir, + frame_filter="All", + show_cam=True, + filter_black_bg=False, + filter_white_bg=False, + apply_mask=True, + show_mesh=True, +): + """ + Perform reconstruction using the already-created target_dir/images. + """ if not os.path.isdir(target_dir) or target_dir == "None": - return None, "⚠ No valid target directory found. Please upload first.", None, None, None, None, None, "", None, None, None + return None, "No valid target directory found. Please upload first.", None, None start_time = time.time() gc.collect() torch.cuda.empty_cache() + # Prepare frame_filter dropdown target_dir_images = os.path.join(target_dir, "images") - all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else [] - all_files_labeled = [f"{i}: {filename}" for i, filename in enumerate(all_files)] - frame_filter_choices = ["All"] + all_files_labeled + all_files = ( + sorted(os.listdir(target_dir_images)) + if os.path.isdir(target_dir_images) + else [] + ) + all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)] + frame_filter_choices = ["All"] + all_files + print("Running MapAnything model...") with torch.no_grad(): predictions, processed_data = run_model(target_dir, apply_mask) - np.savez(os.path.join(target_dir, "predictions.npz"), **predictions) + # Save predictions + prediction_save_path = os.path.join(target_dir, "predictions.npz") + np.savez(prediction_save_path, **predictions) + # Handle None frame_filter if frame_filter is None: frame_filter = "All" + # Build a GLB file name glbfile = os.path.join( target_dir, f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb", ) - glbscene = predictions_to_glb(predictions, filter_by_frames=frame_filter, show_cam=show_cam, mask_black_bg=filter_black_bg, mask_white_bg=filter_white_bg, as_mesh=show_mesh) - glbscene.export(file_obj=glbfile) - rrd_path = predictions_to_rrd(predictions, glbfile, target_dir, frame_filter, show_cam) + # Convert predictions to GLB + glbscene = predictions_to_glb( + predictions, + filter_by_frames=frame_filter, + show_cam=show_cam, + mask_black_bg=filter_black_bg, + mask_white_bg=filter_white_bg, + as_mesh=show_mesh, # Use the show_mesh parameter + ) + glbscene.export(file_obj=glbfile) + + # --------------------------------------------------------- + # Generate the Rerun recording using the new helper + # --------------------------------------------------------- + rrd_path = create_rerun_recording(glbfile, target_dir) + # Cleanup del predictions gc.collect() torch.cuda.empty_cache() - elapsed = time.time() - start_time - log_msg = f"✓ Reconstruction complete — {len(all_files)} frames processed in {elapsed:.1f}s" + end_time = time.time() + print(f"Total time: {end_time - start_time:.2f} seconds") + log_msg = ( + f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization." + ) + + # Populate visualization tabs with processed data + depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs( + processed_data + ) - depth_vis, normal_vis, measure_img, _ = populate_visualization_tabs(processed_data) - depth_selector, normal_selector, measure_selector = update_view_selectors(processed_data) + # Update view selectors based on available views + depth_selector, normal_selector, measure_selector = update_view_selectors( + processed_data + ) return ( - rrd_path, log_msg, + rrd_path, # Return the Rerun recording path instead of glbfile + log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True), processed_data, - depth_vis, normal_vis, measure_img, "", - depth_selector, normal_selector, measure_selector, + depth_vis, + normal_vis, + measure_img, + "", # measure_text (empty initially) + depth_selector, + normal_selector, + measure_selector, ) -# ───────────────────────────────────────────── -# Visualization helpers -# ───────────────────────────────────────────── +# ------------------------------------------------------------------------- +# 5) Helper functions for UI resets + re-visualization +# ------------------------------------------------------------------------- def colorize_depth(depth_map, mask=None): + """Convert depth map to colorized visualization with optional mask""" if depth_map is None: return None + + # Normalize depth to 0-1 range depth_normalized = depth_map.copy() valid_mask = depth_normalized > 0 + + # Apply additional mask if provided (for background filtering) if mask is not None: valid_mask = valid_mask & mask + if valid_mask.sum() > 0: valid_depths = depth_normalized[valid_mask] - p5, p95 = np.percentile(valid_depths, 5), np.percentile(valid_depths, 95) + p5 = np.percentile(valid_depths, 5) + p95 = np.percentile(valid_depths, 95) + depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / (p95 - p5) + + # Apply colormap import matplotlib.pyplot as plt - colored = (plt.cm.turbo_r(depth_normalized)[:, :, :3] * 255).astype(np.uint8) + + colormap = plt.cm.turbo_r + colored = colormap(depth_normalized) + colored = (colored[:, :, :3] * 255).astype(np.uint8) + + # Set invalid pixels to white colored[~valid_mask] = [255, 255, 255] + return colored def colorize_normal(normal_map, mask=None): + """Convert normal map to colorized visualization with optional mask""" if normal_map is None: return None + + # Create a copy for modification normal_vis = normal_map.copy() + + # Apply mask if provided (set masked areas to [0, 0, 0] which becomes grey after normalization) if mask is not None: - normal_vis[~mask] = [0, 0, 0] - return ((normal_vis + 1.0) / 2.0 * 255).astype(np.uint8) + invalid_mask = ~mask + normal_vis[invalid_mask] = [0, 0, 0] # Set invalid areas to zero + + # Normalize normals to [0, 1] range for visualization + normal_vis = (normal_vis + 1.0) / 2.0 + normal_vis = (normal_vis * 255).astype(np.uint8) + return normal_vis -def process_predictions_for_visualization(predictions, views, high_level_config, filter_black_bg=False, filter_white_bg=False): + +def process_predictions_for_visualization( + predictions, views, high_level_config, filter_black_bg=False, filter_white_bg=False +): + """Extract depth, normal, and 3D points from predictions for visualization""" processed_data = {} + + # Process each view for view_idx, view in enumerate(views): + # Get image image = rgb(view["img"], norm_type=high_level_config["data_norm_type"]) + + # Get predicted points pred_pts3d = predictions["world_points"][view_idx] - view_data = {"image": image[0], "points3d": pred_pts3d, "depth": None, "normal": None, "mask": None} + + # Initialize data for this view + view_data = { + "image": image[0], + "points3d": pred_pts3d, + "depth": None, + "normal": None, + "mask": None, + } + + # Start with the final mask from predictions mask = predictions["final_mask"][view_idx].copy() + + # Apply black background filtering if enabled if filter_black_bg: + # Get the image colors (ensure they're in 0-255 range) view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0] - mask = mask & (view_colors.sum(axis=2) >= 16) + # Filter out black background pixels (sum of RGB < 16) + black_bg_mask = view_colors.sum(axis=2) >= 16 + mask = mask & black_bg_mask + + # Apply white background filtering if enabled if filter_white_bg: + # Get the image colors (ensure they're in 0-255 range) view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0] - mask = mask & ~((view_colors[:, :, 0] > 240) & (view_colors[:, :, 1] > 240) & (view_colors[:, :, 2] > 240)) + # Filter out white background pixels (all RGB > 240) + white_bg_mask = ~( + (view_colors[:, :, 0] > 240) + & (view_colors[:, :, 1] > 240) + & (view_colors[:, :, 2] > 240) + ) + mask = mask & white_bg_mask + view_data["mask"] = mask view_data["depth"] = predictions["depth"][view_idx].squeeze() - view_data["normal"], _ = points_to_normals(pred_pts3d, mask=mask) + + normals, _ = points_to_normals(pred_pts3d, mask=view_data["mask"]) + view_data["normal"] = normals + processed_data[view_idx] = view_data + return processed_data -def measure(processed_data, measure_points, current_view_selector, event: gr.SelectData): +def reset_measure(processed_data): + """Reset measure points""" + if processed_data is None or len(processed_data) == 0: + return None, [], "" + + # Return the first view image + first_view = list(processed_data.values())[0] + return first_view["image"], [], "" + + +def measure( + processed_data, measure_points, current_view_selector, event: gr.SelectData +): + """Handle measurement on images""" try: - if not processed_data: + print(f"Measure function called with selector: {current_view_selector}") + + if processed_data is None or len(processed_data) == 0: return None, [], "No data available" + + # Use the currently selected view instead of always using the first view try: current_view_index = int(current_view_selector.split()[1]) - 1 except: current_view_index = 0 - current_view_index = max(0, min(current_view_index, len(processed_data) - 1)) + + print(f"Using view index: {current_view_index}") + + # Get view data safely + if current_view_index < 0 or current_view_index >= len(processed_data): + current_view_index = 0 + view_keys = list(processed_data.keys()) current_view = processed_data[view_keys[current_view_index]] + if current_view is None: return None, [], "No view data available" point2d = event.index[0], event.index[1] - if current_view["mask"] is not None and 0 <= point2d[1] < current_view["mask"].shape[0] and 0 <= point2d[0] < current_view["mask"].shape[1]: + print(f"Clicked point: {point2d}") + + # Check if the clicked point is in a masked area (prevent interaction) + if ( + current_view["mask"] is not None + and 0 <= point2d[1] < current_view["mask"].shape[0] + and 0 <= point2d[0] < current_view["mask"].shape[1] + ): + # Check if the point is in a masked (invalid) area if not current_view["mask"][point2d[1], point2d[0]]: - masked_image, _ = update_measure_view(processed_data, current_view_index) - return masked_image, measure_points, 'Cannot measure on masked areas' + print(f"Clicked point {point2d} is in masked area, ignoring click") + # Always return image with mask overlay + masked_image, _ = update_measure_view( + processed_data, current_view_index + ) + return ( + masked_image, + measure_points, + 'Cannot measure on masked areas (shown in grey)', + ) measure_points.append(point2d) + + # Get image with mask overlay and ensure it's valid image, _ = update_measure_view(processed_data, current_view_index) if image is None: return None, [], "No image available" + image = image.copy() - if image.dtype != np.uint8: - image = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8) points3d = current_view["points3d"] - for p in measure_points: - if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]: - cv2.circle(image, p, radius=5, color=(0, 200, 180), thickness=2) + # Ensure image is in uint8 format for proper cv2 operations + try: + if image.dtype != np.uint8: + if image.max() <= 1.0: + # Image is in [0, 1] range, convert to [0, 255] + image = (image * 255).astype(np.uint8) + else: + # Image is already in [0, 255] range + image = image.astype(np.uint8) + except Exception as e: + print(f"Image conversion error: {e}") + return None, [], f"Image conversion error: {e}" + + # Draw circles for points + try: + for p in measure_points: + if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]: + image = cv2.circle( + image, p, radius=5, color=(255, 0, 0), thickness=2 + ) + except Exception as e: + print(f"Drawing error: {e}") + return None, [], f"Drawing error: {e}" depth_text = "" - for i, p in enumerate(measure_points): - if current_view["depth"] is not None and 0 <= p[1] < current_view["depth"].shape[0] and 0 <= p[0] < current_view["depth"].shape[1]: - depth_text += f"- **P{i+1} depth: {current_view['depth'][p[1], p[0]]:.2f}m**\n" - elif points3d is not None and 0 <= p[1] < points3d.shape[0] and 0 <= p[0] < points3d.shape[1]: - depth_text += f"- **P{i+1} Z-coord: {points3d[p[1], p[0], 2]:.2f}m**\n" + try: + for i, p in enumerate(measure_points): + if ( + current_view["depth"] is not None + and 0 <= p[1] < current_view["depth"].shape[0] + and 0 <= p[0] < current_view["depth"].shape[1] + ): + d = current_view["depth"][p[1], p[0]] + depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n" + else: + # Use Z coordinate of 3D points if depth not available + if ( + points3d is not None + and 0 <= p[1] < points3d.shape[0] + and 0 <= p[0] < points3d.shape[1] + ): + z = points3d[p[1], p[0], 2] + depth_text += f"- **P{i + 1} Z-coord: {z:.2f}m.**\n" + except Exception as e: + print(f"Depth text error: {e}") + depth_text = f"Error computing depth: {e}\n" if len(measure_points) == 2: - point1, point2 = measure_points - if 0 <= point1[0] < image.shape[1] and 0 <= point1[1] < image.shape[0] and 0 <= point2[0] < image.shape[1] and 0 <= point2[1] < image.shape[0]: - cv2.line(image, point1, point2, color=(0, 200, 180), thickness=2) - distance_text = "- **Distance: Unable to compute**" - if points3d is not None and 0 <= point1[1] < points3d.shape[0] and 0 <= point1[0] < points3d.shape[1] and 0 <= point2[1] < points3d.shape[0] and 0 <= point2[0] < points3d.shape[1]: - try: - distance = np.linalg.norm(points3d[point1[1], point1[0]] - points3d[point2[1], point2[0]]) - distance_text = f"- **Distance: {distance:.2f}m**" - except Exception as e: - distance_text = f"- **Distance error: {e}**" - return image, [], depth_text + distance_text - return image, measure_points, depth_text + try: + point1, point2 = measure_points + # Draw line + if ( + 0 <= point1[0] < image.shape[1] + and 0 <= point1[1] < image.shape[0] + and 0 <= point2[0] < image.shape[1] + and 0 <= point2[1] < image.shape[0] + ): + image = cv2.line( + image, point1, point2, color=(255, 0, 0), thickness=2 + ) + + # Compute 3D distance + distance_text = "- **Distance: Unable to compute**" + if ( + points3d is not None + and 0 <= point1[1] < points3d.shape[0] + and 0 <= point1[0] < points3d.shape[1] + and 0 <= point2[1] < points3d.shape[0] + and 0 <= point2[0] < points3d.shape[1] + ): + try: + p1_3d = points3d[point1[1], point1[0]] + p2_3d = points3d[point2[1], point2[0]] + distance = np.linalg.norm(p1_3d - p2_3d) + distance_text = f"- **Distance: {distance:.2f}m**" + except Exception as e: + print(f"Distance computation error: {e}") + distance_text = f"- **Distance computation error: {e}**" + + measure_points = [] + text = depth_text + distance_text + print(f"Measurement complete: {text}") + return [image, measure_points, text] + except Exception as e: + print(f"Final measurement error: {e}") + return None, [], f"Measurement error: {e}" + else: + print(f"Single point measurement: {depth_text}") + return [image, measure_points, depth_text] + except Exception as e: - print(f"Measure error: {e}") - return None, [], f"Error: {e}" + print(f"Overall measure function error: {e}") + return None, [], f"Measure function error: {e}" def clear_fields(): + """ + Clears the 3D viewer, the stored target_dir, and empties the gallery. + """ return None def update_log(): - return "⟳ Initialising reconstruction pipeline..." - - -def update_visualization(target_dir, frame_filter, show_cam, is_example, filter_black_bg=False, filter_white_bg=False, show_mesh=True): + """ + Display a quick log message while waiting. + """ + return "Loading and Reconstructing..." + + +def update_visualization( + target_dir, + frame_filter, + show_cam, + is_example, + filter_black_bg=False, + filter_white_bg=False, + show_mesh=True, +): + """ + Reload saved predictions from npz, create (or reuse) the GLB for new parameters, + wrap it in a Rerun recording (.rrd), and return it for the Rerun viewer. + """ + + # If it's an example click, skip as requested if is_example == "True": - return gr.update(), "No reconstruction available. Please click Reconstruct first." + return ( + gr.update(), + "No reconstruction available. Please click the Reconstruct button first.", + ) + if not target_dir or target_dir == "None" or not os.path.isdir(target_dir): - return gr.update(), "No reconstruction available. Please upload first." + return ( + gr.update(), + "No reconstruction available. Please click the Reconstruct button first.", + ) + predictions_path = os.path.join(target_dir, "predictions.npz") if not os.path.exists(predictions_path): - return gr.update(), "No reconstruction available. Please run Reconstruct first." + return ( + gr.update(), + f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.", + ) loaded = np.load(predictions_path, allow_pickle=True) predictions = {key: loaded[key] for key in loaded.keys()} + glbfile = os.path.join( target_dir, f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb", ) + if not os.path.exists(glbfile): - glbscene = predictions_to_glb(predictions, filter_by_frames=frame_filter, show_cam=show_cam, mask_black_bg=filter_black_bg, mask_white_bg=filter_white_bg, as_mesh=show_mesh) + glbscene = predictions_to_glb( + predictions, + filter_by_frames=frame_filter, + show_cam=show_cam, + mask_black_bg=filter_black_bg, + mask_white_bg=filter_white_bg, + as_mesh=show_mesh, + ) glbscene.export(file_obj=glbfile) - rrd_path = predictions_to_rrd(predictions, glbfile, target_dir, frame_filter, show_cam) - return rrd_path, "✓ Visualization updated." + + # Generate the Rerun recording using the helper + rrd_path = create_rerun_recording(glbfile, target_dir) + + return ( + rrd_path, # Was glbfile + "Visualization updated.", + ) -def update_all_views_on_filter_change(target_dir, filter_black_bg, filter_white_bg, processed_data, depth_view_selector, normal_view_selector, measure_view_selector): +def update_all_views_on_filter_change( + target_dir, + filter_black_bg, + filter_white_bg, + processed_data, + depth_view_selector, + normal_view_selector, + measure_view_selector, +): + """ + Update all individual view tabs when background filtering checkboxes change. + This regenerates the processed data with new filtering and updates all views. + """ + # Check if we have a valid target directory and predictions if not target_dir or target_dir == "None" or not os.path.isdir(target_dir): return processed_data, None, None, None, [] + predictions_path = os.path.join(target_dir, "predictions.npz") if not os.path.exists(predictions_path): return processed_data, None, None, None, [] + try: + # Load the original predictions and views loaded = np.load(predictions_path, allow_pickle=True) predictions = {key: loaded[key] for key in loaded.keys()} - views = load_images(os.path.join(target_dir, "images")) - new_processed_data = process_predictions_for_visualization(predictions, views, high_level_config, filter_black_bg, filter_white_bg) - def parse_idx(sel, default=0): - try: - return int(sel.split()[1]) - 1 - except: - return default + # Load images using MapAnything's load_images function + image_folder_path = os.path.join(target_dir, "images") + views = load_images(image_folder_path) + + # Regenerate processed data with new filtering settings + new_processed_data = process_predictions_for_visualization( + predictions, views, high_level_config, filter_black_bg, filter_white_bg + ) + + # Get current view indices + try: + depth_view_idx = ( + int(depth_view_selector.split()[1]) - 1 if depth_view_selector else 0 + ) + except: + depth_view_idx = 0 + + try: + normal_view_idx = ( + int(normal_view_selector.split()[1]) - 1 if normal_view_selector else 0 + ) + except: + normal_view_idx = 0 + + try: + measure_view_idx = ( + int(measure_view_selector.split()[1]) - 1 + if measure_view_selector + else 0 + ) + except: + measure_view_idx = 0 + + # Update all views with new filtered data + depth_vis = update_depth_view(new_processed_data, depth_view_idx) + normal_vis = update_normal_view(new_processed_data, normal_view_idx) + measure_img, _ = update_measure_view(new_processed_data, measure_view_idx) - depth_vis = update_depth_view(new_processed_data, parse_idx(depth_view_selector)) - normal_vis = update_normal_view(new_processed_data, parse_idx(normal_view_selector)) - measure_img, _ = update_measure_view(new_processed_data, parse_idx(measure_view_selector)) return new_processed_data, depth_vis, normal_vis, measure_img, [] + except Exception as e: print(f"Error updating views on filter change: {e}") return processed_data, None, None, None, [] +# ------------------------------------------------------------------------- +# Example scene functions +# ------------------------------------------------------------------------- def get_scene_info(examples_dir): + """Get information about scenes in the examples directory""" import glob + scenes = [] if not os.path.exists(examples_dir): return scenes + for scene_folder in sorted(os.listdir(examples_dir)): scene_path = os.path.join(examples_dir, scene_folder) if os.path.isdir(scene_path): + # Find all image files in the scene folder + image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"] image_files = [] - for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]: + for ext in image_extensions: image_files.extend(glob.glob(os.path.join(scene_path, ext))) image_files.extend(glob.glob(os.path.join(scene_path, ext.upper()))) + if image_files: + # Sort images and get the first one for thumbnail image_files = sorted(image_files) - scenes.append({"name": scene_folder, "path": scene_path, "thumbnail": image_files[0], "num_images": len(image_files), "image_files": image_files}) + first_image = image_files[0] + num_images = len(image_files) + + scenes.append( + { + "name": scene_folder, + "path": scene_path, + "thumbnail": first_image, + "num_images": num_images, + "image_files": image_files, + } + ) + return scenes def load_example_scene(scene_name, examples_dir="examples"): + """Load a scene from examples directory""" scenes = get_scene_info(examples_dir) - selected_scene = next((s for s in scenes if s["name"] == scene_name), None) + + # Find the selected scene + selected_scene = None + for scene in scenes: + if scene["name"] == scene_name: + selected_scene = scene + break + if selected_scene is None: return None, None, None, "Scene not found" - target_dir, image_paths = handle_uploads(selected_scene["image_files"], 1.0) - return None, target_dir, image_paths, f"✓ Loaded '{scene_name}' — {selected_scene['num_images']} images. Click Reconstruct to begin." - - -# ───────────────────────────────────────────── -# Gradio UI -# ───────────────────────────────────────────── -with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Base()) as demo: - - # ── State ── - is_example = gr.Textbox(visible=False, value="None") - num_images = gr.Textbox(visible=False, value="None") - processed_data_state = gr.State(value=None) - measure_points_state = gr.State(value=[]) - current_view_index = gr.State(value=0) - target_dir_output = gr.Textbox(visible=False, value="None") - - # ── Hero Header ── - with gr.Row(elem_id="hero-header"): - with gr.Column(): - gr.HTML('
facebook / map-anything-v1
') - gr.Markdown("# Map-Anything-v1", elem_id="hero-title") - gr.Markdown("Metric 3D reconstruction · Point cloud & camera poses", elem_id="hero-sub") - - # ── Main body ── - with gr.Row(elem_id="main-body"): - - # ── LEFT PANEL ── - with gr.Column(elem_id="left-panel", scale=1): - - gr.Markdown("**Input**", elem_classes=["section-label"]) - - unified_upload = gr.File( - file_count="multiple", - label="Upload Video or Images", - interactive=True, - file_types=["image", "video"], - ) - s_time_interval = gr.Slider( - minimum=0.1, maximum=5.0, value=1.0, step=0.1, - label="Video sample interval (sec)", - interactive=True, - ) - resample_btn = gr.Button("↺ Resample Video", visible=False, variant="secondary", size="sm") - - image_gallery = gr.Gallery( - label="Preview", - columns=3, - height="200px", - object_fit="contain", - preview=True, - ) + # Create file-like objects for the unified upload system + # Convert image file paths to the format expected by unified_upload + file_objects = [] + for image_path in selected_scene["image_files"]: + file_objects.append(image_path) - gr.ClearButton( - [unified_upload, image_gallery], - value="Clear Uploads", - variant="secondary", - size="sm", - ) + # Create target directory and copy images using the unified upload system + target_dir, image_paths = handle_uploads(file_objects, 1.0) - gr.HTML('
') + return ( + None, # Clear reconstruction output + target_dir, # Set target directory + image_paths, # Set gallery + f"Loaded scene '{scene_name}' with {selected_scene['num_images']} images. Click 'Reconstruct' to begin 3D processing.", + ) - # ── Pointcloud options ── - with gr.Group(elem_classes=["options-group"]): - gr.Markdown("**Live Options**", elem_classes=["options-title"]) - with gr.Row(): - show_cam = gr.Checkbox(label="Cameras", value=True) - show_mesh = gr.Checkbox(label="Mesh", value=True) - with gr.Row(): - filter_black_bg = gr.Checkbox(label="Filter Black BG", value=False) - filter_white_bg = gr.Checkbox(label="Filter White BG", value=False) - frame_filter = gr.Dropdown(choices=["All"], value="All", label="Show Points from Frame") +# ------------------------------------------------------------------------- +# 6) Build Gradio UI +# ------------------------------------------------------------------------- +theme = get_gradio_theme() + +with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo: + # State variables + is_example = gr.Textbox(label="is_example", visible=False, value="None") + num_images = gr.Textbox(label="num_images", visible=False, value="None") + processed_data_state = gr.State(value=None) + measure_points_state = gr.State(value=[]) + current_view_index = gr.State(value=0) + target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None") + + # --- Header Area --- + with gr.Column(elem_id="header-container"): + gr.Markdown( + "
" + "

🗺️ Map-Anything-v1

" + "

Metric 3D Reconstruction (Point Cloud and Camera Poses)

" + "
" + ) + gr.Markdown("---") + + # --- Main App Layout --- + with gr.Row(): + + # LEFT COLUMN (Sidebar / Controls) + with gr.Column(scale=1, min_width=350): + + with gr.Group(): + gr.Markdown("### 📁 1. Input Media") + unified_upload = gr.File( + file_count="multiple", + label="Upload Video or Images", + interactive=True, + file_types=["image", "video"], + ) + with gr.Row(): + s_time_interval = gr.Slider( + minimum=0.1, + maximum=5.0, + value=1.0, + step=0.1, + label="Video sample interval (sec)", + interactive=True, + visible=True, + ) + resample_btn = gr.Button("Resample", visible=False, variant="secondary") + + image_gallery = gr.Gallery( + label="Preview", + columns=4, + height="200px", + object_fit="contain", + preview=True, + ) + clear_uploads_btn = gr.ClearButton( + [unified_upload, image_gallery], + value="Clear Uploads", + variant="secondary", + size="sm", + ) - # ── Reconstruction options ── - with gr.Group(elem_classes=["options-group"]): - gr.Markdown("**Reconstruction Options**", elem_classes=["options-title"]) + with gr.Group(): + gr.Markdown("### ⚙️ 2. Reconstruction Settings") apply_mask_checkbox = gr.Checkbox( - label="Apply ambiguity mask & edge mask", + label="Apply mask (depth classes & edges)", value=True, ) + + with gr.Row(): + submit_btn = gr.Button("🚀 Reconstruct", variant="primary", scale=2) + clear_btn = gr.ClearButton( + [ + unified_upload, + target_dir_output, + image_gallery, + ], + value="Clear All", + scale=1, + ) - gr.HTML('
') - - # ── Action buttons ── - with gr.Row(elem_id="reconstruct-btn"): - submit_btn = gr.Button("⬡ Reconstruct", variant="primary", scale=2) - gr.ClearButton( - [unified_upload, target_dir_output, image_gallery], - value="✕", - variant="secondary", - scale=1, + with gr.Accordion("🎨 Visualization Options", open=True): + gr.Markdown("*Note: Updates automatically applied to viewer.*") + frame_filter = gr.Dropdown( + choices=["All"], value="All", label="Show Points from Frame" ) + show_cam = gr.Checkbox(label="Show Camera Paths", value=True) + show_mesh = gr.Checkbox(label="Show 3D Mesh", value=True) + filter_black_bg = gr.Checkbox(label="Filter Black Background", value=False) + filter_white_bg = gr.Checkbox(label="Filter White Background", value=False) - # ── RIGHT PANEL ── - with gr.Column(elem_id="right-panel", scale=3): - - # Status bar - log_output = gr.Markdown( - "Upload images or a video, then click **Reconstruct**.", - elem_id="log-output", - ) - # Tabbed viewer + # RIGHT COLUMN (Main Viewer Area) + with gr.Column(scale=2, min_width=600): + log_output = gr.Markdown("Status: **Ready**. Please upload media or select an example scene below.", elem_classes=["custom-log"]) + with gr.Tabs(): - with gr.Tab("⬡ 3D View"): + with gr.Tab("3D View"): reconstruction_output = Rerun( label="Rerun 3D Viewer", - height=540, + height=600, ) - - with gr.Tab("⬡ Depth"): - with gr.Row(): - prev_depth_btn = gr.Button("◀", size="sm", scale=1) + with gr.Tab("Depth"): + with gr.Row(elem_classes=["navigation-row"]): + prev_depth_btn = gr.Button("◀ Previous", size="sm", scale=1) depth_view_selector = gr.Dropdown( - choices=["View 1"], value="View 1", label="View", - scale=4, interactive=True, allow_custom_value=True, + choices=["View 1"], + value="View 1", + label="Select View", + scale=2, + interactive=True, + allow_custom_value=True, ) - next_depth_btn = gr.Button("▶", size="sm", scale=1) - depth_map = gr.Image(type="numpy", label="Depth Map", format="png", interactive=False) - - with gr.Tab("⬡ Normal"): - with gr.Row(): - prev_normal_btn = gr.Button("◀", size="sm", scale=1) + next_depth_btn = gr.Button("Next ▶", size="sm", scale=1) + depth_map = gr.Image( + type="numpy", + label="Colorized Depth Map", + format="png", + interactive=False, + ) + with gr.Tab("Normal"): + with gr.Row(elem_classes=["navigation-row"]): + prev_normal_btn = gr.Button("◀ Previous", size="sm", scale=1) normal_view_selector = gr.Dropdown( - choices=["View 1"], value="View 1", label="View", - scale=4, interactive=True, allow_custom_value=True, + choices=["View 1"], + value="View 1", + label="Select View", + scale=2, + interactive=True, + allow_custom_value=True, ) - next_normal_btn = gr.Button("▶", size="sm", scale=1) - normal_map = gr.Image(type="numpy", label="Normal Map", format="png", interactive=False) - - with gr.Tab("⬡ Measure"): + next_normal_btn = gr.Button("Next ▶", size="sm", scale=1) + normal_map = gr.Image( + type="numpy", + label="Normal Map", + format="png", + interactive=False, + ) + with gr.Tab("Measure"): gr.Markdown(MEASURE_INSTRUCTIONS_HTML) - with gr.Row(): - prev_measure_btn = gr.Button("◀", size="sm", scale=1) + with gr.Row(elem_classes=["navigation-row"]): + prev_measure_btn = gr.Button("◀ Previous", size="sm", scale=1) measure_view_selector = gr.Dropdown( - choices=["View 1"], value="View 1", label="View", - scale=4, interactive=True, allow_custom_value=True, + choices=["View 1"], + value="View 1", + label="Select View", + scale=2, + interactive=True, + allow_custom_value=True, ) - next_measure_btn = gr.Button("▶", size="sm", scale=1) - measure_image = gr.Image(type="numpy", show_label=False, format="webp", interactive=False, sources=[]) - gr.Markdown("*Light-grey areas = no depth data — measurements unavailable there.*") + next_measure_btn = gr.Button("Next ▶", size="sm", scale=1) + measure_image = gr.Image( + type="numpy", + show_label=False, + format="webp", + interactive=False, + sources=[], + ) + gr.Markdown("**Note:** Light-grey areas indicate regions with no depth information where measurements cannot be taken.") measure_text = gr.Markdown("") - # ── Example Scenes ── - gr.HTML(""" -
-
- Example Scenes -
-
- Click a thumbnail to load the scene -
-
- """) + # --- Footer Area (Example Scenes) --- + gr.Markdown("---") + gr.Markdown("## 🌟 Example Scenes") + gr.Markdown("Click any thumbnail below to load a sample dataset for reconstruction.") scenes = get_scene_info("examples") + if scenes: - for i in range(0, len(scenes), 4): + for i in range(0, len(scenes), 4): with gr.Row(): for j in range(4): scene_idx = i + j @@ -1293,91 +1332,199 @@ with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Base()) as demo: with gr.Column(scale=1, elem_classes=["clickable-thumbnail"]): scene_img = gr.Image( value=scene["thumbnail"], - height=140, + height=150, interactive=False, show_label=False, elem_id=f"scene_thumb_{scene['name']}", sources=[], ) - gr.Markdown(f"**{scene['name']}**\n{scene['num_images']} images", elem_classes=["scene-info"]) + gr.Markdown( + f"**{scene['name']}** \n {scene['num_images']} images", + elem_classes=["scene-info"], + ) + # Clicking an example bypasses the manual process and loads everything automatically scene_img.select( fn=lambda name=scene["name"]: load_example_scene(name), - outputs=[reconstruction_output, target_dir_output, image_gallery, log_output], + outputs=[ + reconstruction_output, # To clear old view + target_dir_output, + image_gallery, + log_output, + ], ) else: with gr.Column(scale=1): pass - # ───────────────────────────────────────────── - # Event wiring - # ───────────────────────────────────────────── + # ========================================================================= + # Event Bindings & Logic + # ========================================================================= + submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then( fn=update_log, inputs=[], outputs=[log_output] ).then( fn=gradio_demo, - inputs=[target_dir_output, frame_filter, show_cam, filter_black_bg, filter_white_bg, apply_mask_checkbox, show_mesh], - outputs=[reconstruction_output, log_output, frame_filter, processed_data_state, depth_map, normal_map, measure_image, measure_text, depth_view_selector, normal_view_selector, measure_view_selector], - ).then(fn=lambda: "False", inputs=[], outputs=[is_example]) - - frame_filter.change(update_visualization, [target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh], [reconstruction_output, log_output]) - show_cam.change(update_visualization, [target_dir_output, frame_filter, show_cam, is_example], [reconstruction_output, log_output]) + inputs=[ + target_dir_output, + frame_filter, + show_cam, + filter_black_bg, + filter_white_bg, + apply_mask_checkbox, + show_mesh, + ], + outputs=[ + reconstruction_output, + log_output, + frame_filter, + processed_data_state, + depth_map, + normal_map, + measure_image, + measure_text, + depth_view_selector, + normal_view_selector, + measure_view_selector, + ], + ).then( + fn=lambda: "False", + inputs=[], + outputs=[is_example], + ) + # Real-time Visualization Updates + frame_filter.change( + update_visualization, + [target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh], + [reconstruction_output, log_output], + ) + show_cam.change( + update_visualization, + [target_dir_output, frame_filter, show_cam, is_example], + [reconstruction_output, log_output], + ) filter_black_bg.change( - update_visualization, [target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg], [reconstruction_output, log_output] - ).then(update_all_views_on_filter_change, [target_dir_output, filter_black_bg, filter_white_bg, processed_data_state, depth_view_selector, normal_view_selector, measure_view_selector], [processed_data_state, depth_map, normal_map, measure_image, measure_points_state]) - + update_visualization, + [target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg], + [reconstruction_output, log_output], + ).then( + fn=update_all_views_on_filter_change, + inputs=[target_dir_output, filter_black_bg, filter_white_bg, processed_data_state, depth_view_selector, normal_view_selector, measure_view_selector], + outputs=[processed_data_state, depth_map, normal_map, measure_image, measure_points_state], + ) filter_white_bg.change( - update_visualization, [target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh], [reconstruction_output, log_output] - ).then(update_all_views_on_filter_change, [target_dir_output, filter_black_bg, filter_white_bg, processed_data_state, depth_view_selector, normal_view_selector, measure_view_selector], [processed_data_state, depth_map, normal_map, measure_image, measure_points_state]) - - show_mesh.change(update_visualization, [target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh], [reconstruction_output, log_output]) + update_visualization, + [target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh], + [reconstruction_output, log_output], + ).then( + fn=update_all_views_on_filter_change, + inputs=[target_dir_output, filter_black_bg, filter_white_bg, processed_data_state, depth_view_selector, normal_view_selector, measure_view_selector], + outputs=[processed_data_state, depth_map, normal_map, measure_image, measure_points_state], + ) + show_mesh.change( + update_visualization, + [target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh], + [reconstruction_output, log_output], + ) + # Auto-update gallery on upload def update_gallery_on_unified_upload(files, interval): if not files: - return None, None, None + return None, None, "Ready for upload." target_dir, image_paths = handle_uploads(files, interval) - return target_dir, image_paths, "✓ Upload complete. Click Reconstruct to begin." + return target_dir, image_paths, "Upload complete. Click '🚀 Reconstruct' to begin 3D processing." def show_resample_button(files): - if not files: - return gr.update(visible=False) - video_extensions = [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"] - has_video = any(os.path.splitext(str(f["name"] if isinstance(f, dict) else f))[1].lower() in video_extensions for f in files) + if not files: return gr.update(visible=False) + video_exts = [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"] + has_video = False + for f_data in files: + f_path = str(f_data["name"] if isinstance(f_data, dict) else f_data) + if os.path.splitext(f_path)[1].lower() in video_exts: + has_video = True + break return gr.update(visible=has_video) def resample_video_with_new_interval(files, new_interval, current_target_dir): - if not files: - return current_target_dir, None, "No files to resample.", gr.update(visible=False) - video_extensions = [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"] - has_video = any(os.path.splitext(str(f["name"] if isinstance(f, dict) else f))[1].lower() in video_extensions for f in files) - if not has_video: - return current_target_dir, None, "No videos to resample.", gr.update(visible=False) + if not files: return current_target_dir, None, "No files to resample.", gr.update(visible=False) + video_exts = [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"] + has_video = any(os.path.splitext(str(f["name"] if isinstance(f, dict) else f))[1].lower() in video_exts for f in files) + + if not has_video: return current_target_dir, None, "No videos found.", gr.update(visible=False) + if current_target_dir and current_target_dir != "None" and os.path.exists(current_target_dir): shutil.rmtree(current_target_dir) + target_dir, image_paths = handle_uploads(files, new_interval) - return target_dir, image_paths, f"✓ Resampled at {new_interval}s interval. Click Reconstruct.", gr.update(visible=False) + return target_dir, image_paths, f"Video resampled ({new_interval}s interval). Click '🚀 Reconstruct'.", gr.update(visible=False) unified_upload.change( fn=update_gallery_on_unified_upload, inputs=[unified_upload, s_time_interval], outputs=[target_dir_output, image_gallery, log_output], - ).then(fn=show_resample_button, inputs=[unified_upload], outputs=[resample_btn]) + ).then( + fn=show_resample_button, + inputs=[unified_upload], + outputs=[resample_btn], + ) - s_time_interval.change(fn=show_resample_button, inputs=[unified_upload], outputs=[resample_btn]) - resample_btn.click(fn=resample_video_with_new_interval, inputs=[unified_upload, s_time_interval, target_dir_output], outputs=[target_dir_output, image_gallery, log_output, resample_btn]) + s_time_interval.change( + fn=show_resample_button, + inputs=[unified_upload], + outputs=[resample_btn], + ) + + resample_btn.click( + fn=resample_video_with_new_interval, + inputs=[unified_upload, s_time_interval, target_dir_output], + outputs=[target_dir_output, image_gallery, log_output, resample_btn], + ) - measure_image.select(fn=measure, inputs=[processed_data_state, measure_points_state, measure_view_selector], outputs=[measure_image, measure_points_state, measure_text]) + # Measure Interactions + measure_image.select( + fn=measure, + inputs=[processed_data_state, measure_points_state, measure_view_selector], + outputs=[measure_image, measure_points_state, measure_text], + ) - prev_depth_btn.click(fn=lambda pd, sel: navigate_depth_view(pd, sel, -1), inputs=[processed_data_state, depth_view_selector], outputs=[depth_view_selector, depth_map]) - next_depth_btn.click(fn=lambda pd, sel: navigate_depth_view(pd, sel, 1), inputs=[processed_data_state, depth_view_selector], outputs=[depth_view_selector, depth_map]) - depth_view_selector.change(fn=lambda pd, sel: update_depth_view(pd, int(sel.split()[1]) - 1) if sel else None, inputs=[processed_data_state, depth_view_selector], outputs=[depth_map]) + # Tab Navigations + prev_depth_btn.click( + fn=lambda d, s: navigate_depth_view(d, s, -1), + inputs=[processed_data_state, depth_view_selector], outputs=[depth_view_selector, depth_map], + ) + next_depth_btn.click( + fn=lambda d, s: navigate_depth_view(d, s, 1), + inputs=[processed_data_state, depth_view_selector], outputs=[depth_view_selector, depth_map], + ) + depth_view_selector.change( + fn=lambda d, s: update_depth_view(d, int(s.split()[1]) - 1) if s else None, + inputs=[processed_data_state, depth_view_selector], outputs=[depth_map], + ) - prev_normal_btn.click(fn=lambda pd, sel: navigate_normal_view(pd, sel, -1), inputs=[processed_data_state, normal_view_selector], outputs=[normal_view_selector, normal_map]) - next_normal_btn.click(fn=lambda pd, sel: navigate_normal_view(pd, sel, 1), inputs=[processed_data_state, normal_view_selector], outputs=[normal_view_selector, normal_map]) - normal_view_selector.change(fn=lambda pd, sel: update_normal_view(pd, int(sel.split()[1]) - 1) if sel else None, inputs=[processed_data_state, normal_view_selector], outputs=[normal_map]) + prev_normal_btn.click( + fn=lambda d, s: navigate_normal_view(d, s, -1), + inputs=[processed_data_state, normal_view_selector], outputs=[normal_view_selector, normal_map], + ) + next_normal_btn.click( + fn=lambda d, s: navigate_normal_view(d, s, 1), + inputs=[processed_data_state, normal_view_selector], outputs=[normal_view_selector, normal_map], + ) + normal_view_selector.change( + fn=lambda d, s: update_normal_view(d, int(s.split()[1]) - 1) if s else None, + inputs=[processed_data_state, normal_view_selector], outputs=[normal_map], + ) - prev_measure_btn.click(fn=lambda pd, sel: navigate_measure_view(pd, sel, -1), inputs=[processed_data_state, measure_view_selector], outputs=[measure_view_selector, measure_image, measure_points_state]) - next_measure_btn.click(fn=lambda pd, sel: navigate_measure_view(pd, sel, 1), inputs=[processed_data_state, measure_view_selector], outputs=[measure_view_selector, measure_image, measure_points_state]) - measure_view_selector.change(fn=lambda pd, sel: update_measure_view(pd, int(sel.split()[1]) - 1) if sel else (None, []), inputs=[processed_data_state, measure_view_selector], outputs=[measure_image, measure_points_state]) + prev_measure_btn.click( + fn=lambda d, s: navigate_measure_view(d, s, -1), + inputs=[processed_data_state, measure_view_selector], outputs=[measure_view_selector, measure_image, measure_points_state], + ) + next_measure_btn.click( + fn=lambda d, s: navigate_measure_view(d, s, 1), + inputs=[processed_data_state, measure_view_selector], outputs=[measure_view_selector, measure_image, measure_points_state], + ) + measure_view_selector.change( + fn=lambda d, s: update_measure_view(d, int(s.split()[1]) - 1) if s else (None, []), + inputs=[processed_data_state, measure_view_selector], outputs=[measure_image, measure_points_state], + ) - demo.queue(max_size=20).launch(css=CUSTOM_CSS, show_error=True, share=True, ssr_mode=False) \ No newline at end of file + demo.queue(max_size=20).launch(theme=theme, css=GRADIO_CSS, show_error=True, share=True, ssr_mode=False) \ No newline at end of file