# Map-Anything-v1 / app.py
# prithivMLmods's picture
# update app [theme]
# a78ad30 verified
import gc
import os
import shutil
import sys
import time
import uuid
from datetime import datetime
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from pillow_heif import register_heif_opener
import rerun as rr
try:
import rerun.blueprint as rrb
except ImportError:
rrb = None
from gradio_rerun import Rerun
register_heif_opener()
sys.path.append("mapanything/")
from mapanything.utils.geometry import depthmap_to_world_frame, points_to_normals
from mapanything.utils.hf_utils.css_and_html import (
GRADIO_CSS,
MEASURE_INSTRUCTIONS_HTML,
get_acknowledgements_html,
get_description_html,
get_gradio_theme,
get_header_html,
)
from mapanything.utils.hf_utils.hf_helpers import initialize_mapanything_model
from mapanything.utils.hf_utils.viz import predictions_to_glb
from mapanything.utils.image import load_images, rgb
from typing import Iterable
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
# ── Steel-Blue palette ──────────────────────────────────────────────
# Shade ramp for the custom hue, running from the lightest (c50) to the
# darkest (c950) value. The resulting Color is attached to gradio's `colors`
# module so the theme below can refer to it as `colors.steel_blue`.
_STEEL_BLUE_SHADES = {
    "c50": "#EBF3F8",
    "c100": "#D3E5F0",
    "c200": "#A8CCE1",
    "c300": "#7DB3D2",
    "c400": "#529AC3",
    "c500": "#4682B4",
    "c600": "#3E72A0",
    "c700": "#36638C",
    "c800": "#2E5378",
    "c900": "#264364",
    "c950": "#1E3450",
}
colors.steel_blue = colors.Color(name="steel_blue", **_STEEL_BLUE_SHADES)
class SteelBlueTheme(Soft):
    """Gradio ``Soft`` theme variant with gray surfaces and steel-blue accents."""

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        """Build the base ``Soft`` theme, then apply the steel-blue overrides.

        All arguments are keyword-only and are forwarded unchanged to
        ``Soft.__init__``.
        """
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # Theme variables overridden on top of the Soft defaults; `*name`
        # values reference other theme variables (hue shades, shadows).
        overrides = {
            # Page and block backgrounds
            "background_fill_primary": "*primary_50",
            "background_fill_primary_dark": "*primary_900",
            "body_background_fill": "linear-gradient(135deg, *primary_200, *primary_100)",
            "body_background_fill_dark": "linear-gradient(135deg, *primary_900, *primary_800)",
            # Primary buttons
            "button_primary_text_color": "white",
            "button_primary_text_color_hover": "white",
            "button_primary_background_fill": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_dark": "linear-gradient(90deg, *secondary_600, *secondary_800)",
            "button_primary_background_fill_hover_dark": "linear-gradient(90deg, *secondary_500, *secondary_500)",
            # Secondary buttons
            "button_secondary_text_color": "black",
            "button_secondary_text_color_hover": "white",
            "button_secondary_background_fill": "linear-gradient(90deg, *primary_300, *primary_300)",
            "button_secondary_background_fill_hover": "linear-gradient(90deg, *primary_400, *primary_400)",
            "button_secondary_background_fill_dark": "linear-gradient(90deg, *primary_500, *primary_600)",
            "button_secondary_background_fill_hover_dark": "linear-gradient(90deg, *primary_500, *primary_500)",
            # Sliders
            "slider_color": "*secondary_500",
            "slider_color_dark": "*secondary_600",
            # Blocks, shadows and spacing
            "block_title_text_weight": "600",
            "block_border_width": "3px",
            "block_shadow": "*shadow_drop_lg",
            "button_primary_shadow": "*shadow_drop_lg",
            "button_large_padding": "11px",
            "color_accent_soft": "*primary_100",
            "block_label_background_fill": "*primary_200",
        }
        super().set(**overrides)
# Module-level theme instance built with the SteelBlueTheme defaults.
steel_blue_theme = SteelBlueTheme()
# Inline SVG icons (stroke follows `currentColor`) interpolated into the header
# markup by html_header(): SVG_CUBE fills the icon box, SVG_CHIP sits in the
# model badge.
SVG_CUBE = '<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor"><path stroke-linecap="round" stroke-linejoin="round" d="m21 7.5-9-5.25L3 7.5m18 0-9 5.25m9-5.25v9l-9 5.25M3 7.5l9 5.25M3 7.5v9l9 5.25m0-9v9"/></svg>'
SVG_CHIP = '<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor"><path stroke-linecap="round" stroke-linejoin="round" d="M8.25 3v1.5M4.5 8.25H3m18 0h-1.5M4.5 12H3m18 0h-1.5m-15 3.75H3m18 0h-1.5M8.25 19.5V21M12 3v1.5m0 15V21m3.75-18v1.5m0 15V21m-9-1.5h10.5a2.25 2.25 0 0 0 2.25-2.25V6.75a2.25 2.25 0 0 0-2.25-2.25H6.75A2.25 2.25 0 0 0 4.5 6.75v10.5a2.25 2.25 0 0 0 2.25 2.25Z"/></svg>'
def html_header():
    """Return the HTML for the app banner.

    The markup holds the cube icon, the title, a badge naming the model and
    one caption per capability, each caption preceded by a dot separator.
    """
    capability_labels = (
        "3D Reconstruction",
        "Depth Estimation",
        "Normal Maps",
        "Measurements",
    )
    # One "separator + caption" pair per capability, each line newline-terminated.
    capability_markup = "".join(
        f'<span class="meta-sep"></span>\n<span class="meta-cap">{label}</span>\n'
        for label in capability_labels
    )
    opening = (
        "\n"
        '<div class="app-header">\n'
        '<div class="header-content">\n'
        f'<div class="header-icon-wrap">{SVG_CUBE}</div>\n'
        '<div class="header-text">\n'
        "<h1>Map-Anything &mdash; v1</h1>\n"
        '<div class="header-meta">\n'
        f'<span class="meta-badge">{SVG_CHIP} facebook/map-anything-v1</span>\n'
    )
    # Closes header-meta, header-text, header-content and app-header.
    closing = "</div>\n" * 4
    return opening + capability_markup + closing
# Settings handed to initialize_mapanything_model() the first time run_model()
# needs the model; "data_norm_type" is also read when images are
# de-normalised for the visualisation tabs.
high_level_config = {
    "path": "configs/train.yaml",
    "hf_model_name": "facebook/map-anything-v1",
    "model_str": "mapanything",
    "config_overrides": [
        "machine=aws",
        "model=mapanything",
        "model/task=images_only",
        "model.encoder.uses_torch_hub=false",
    ],
    "checkpoint_name": "model.safetensors",
    "config_name": "config.json",
    "trained_with_amp": True,
    "trained_with_amp_dtype": "bf16",
    "data_norm_type": "dinov2",
    "patch_size": 14,
    "resolution": 518,
}
# Lazily created model instance; run_model() fills it in on first use and
# reuses it afterwards.
model = None
# 'tmp' directory next to this file, created at import time.
# NOTE(review): no use of TMP_DIR is visible in this part of the file.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)
# Application stylesheet: the upstream GRADIO_CSS (or an empty string when it
# is falsy) followed by the overrides below for the header banner, the
# two-panel layout, buttons, tabs, the options/examples sections, scrollbars
# and dark mode.
# NOTE(review): the @import rule comes after whatever rules GRADIO_CSS holds.
# The CSS spec only honours @import ahead of other style rules, so confirm the
# fonts still load this way (SteelBlueTheme also requests Outfit and
# IBM Plex Mono through GoogleFont).
CUSTOM_CSS = (GRADIO_CSS or "") + r"""
@import url('https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;500;600;700;800&family=IBM+Plex+Mono:wght@400;500;600&display=swap');
body, .gradio-container { font-family: 'Outfit', sans-serif !important; }
footer { display: none !important; }
/* ── App Header ── */
.app-header {
background: linear-gradient(135deg, #1E3450 0%, #264364 30%, #3E72A0 70%, #4682B4 100%);
border-radius: 16px;
padding: 32px 40px;
margin-bottom: 24px;
position: relative;
overflow: hidden;
box-shadow: 0 8px 32px rgba(30, 52, 80, 0.35);
}
.app-header::before {
content: '';
position: absolute;
top: -50%;
right: -20%;
width: 400px;
height: 400px;
background: radial-gradient(circle, rgba(255, 255, 255, 0.06) 0%, transparent 70%);
border-radius: 50%;
}
.app-header::after {
content: '';
position: absolute;
bottom: -30%;
left: -10%;
width: 300px;
height: 300px;
background: radial-gradient(circle, rgba(70, 130, 180, 0.15) 0%, transparent 70%);
border-radius: 50%;
}
.header-content {
display: flex;
align-items: center;
gap: 24px;
position: relative;
z-index: 1;
}
.header-icon-wrap {
width: 64px;
height: 64px;
background: rgba(255, 255, 255, 0.12);
border-radius: 16px;
display: flex;
align-items: center;
justify-content: center;
flex-shrink: 0;
backdrop-filter: blur(8px);
border: 1px solid rgba(255, 255, 255, 0.15);
}
/* ── Force header SVGs white in ALL modes ── */
.header-icon-wrap svg,
.app-header svg {
width: 36px;
height: 36px;
color: #ffffff !important;
stroke: #ffffff !important;
}
.meta-badge svg {
width: 14px !important;
height: 14px !important;
color: #ffffff !important;
stroke: #ffffff !important;
}
.header-text h1 {
font-family: 'Outfit', sans-serif;
font-size: 2rem;
font-weight: 700;
color: #fff !important;
margin: 0 0 8px 0;
letter-spacing: -0.02em;
line-height: 1.2;
}
.header-meta {
display: flex;
align-items: center;
gap: 12px;
flex-wrap: wrap;
}
.meta-badge {
display: inline-flex;
align-items: center;
gap: 6px;
background: rgba(255, 255, 255, 0.12);
color: rgba(255, 255, 255, 0.9) !important;
padding: 4px 12px;
border-radius: 20px;
font-family: 'IBM Plex Mono', monospace;
font-size: 0.8rem;
font-weight: 500;
border: 1px solid rgba(255, 255, 255, 0.1);
backdrop-filter: blur(4px);
}
.meta-sep {
width: 4px;
height: 4px;
background: rgba(255, 255, 255, 0.35);
border-radius: 50%;
flex-shrink: 0;
}
.meta-cap {
color: rgba(255, 255, 255, 0.65) !important;
font-size: 0.85rem;
font-weight: 400;
}
/* ── Page shell ── */
#app-shell {
max-width: 1400px;
margin: 0 auto;
padding: 0 16px 40px;
}
/* ── Two-panel layout ── */
#left-panel { min-width: 320px; max-width: 380px; }
#right-panel { flex: 1; min-width: 0; }
/* ── Section labels ── */
.section-label {
font-size: 0.7rem !important;
font-weight: 600 !important;
letter-spacing: 0.08em !important;
text-transform: uppercase !important;
opacity: 0.5 !important;
margin-bottom: 6px !important;
margin-top: 16px !important;
display: block !important;
}
/* ── Upload zone ── */
#upload-zone .wrap {
border-radius: 10px !important;
min-height: 110px !important;
}
/* ── Gallery ── */
#preview-gallery { border-radius: 10px; overflow: hidden; }
/* ── Action buttons ── */
#btn-reconstruct {
width: 100% !important;
font-size: 0.95rem !important;
font-weight: 600 !important;
padding: 12px !important;
border-radius: 8px !important;
}
/* ── Buttons ── */
.primary {
border-radius: 10px !important;
font-weight: 600 !important;
letter-spacing: 0.02em !important;
transition: all 0.25s ease !important;
font-family: 'Outfit', sans-serif !important;
}
.primary:hover {
transform: translateY(-2px) !important;
box-shadow: 0 6px 20px rgba(70, 130, 180, 0.35) !important;
}
.primary:active { transform: translateY(0) !important; }
/* ── Log strip ── */
#log-strip {
font-size: 0.82rem !important;
padding: 8px 12px !important;
border-radius: 6px !important;
border: 1px solid var(--border-color-primary) !important;
background: var(--background-fill-secondary) !important;
min-height: 36px !important;
}
/* ── Viewer tabs ── */
#viewer-tabs .tab-nav button {
font-size: 0.8rem !important;
font-weight: 500 !important;
padding: 6px 14px !important;
}
#viewer-tabs > .tabitem { padding: 0 !important; }
/* ── Tab transitions ── */
.gradio-tabitem { animation: tabFadeIn 0.35s ease-out; }
@keyframes tabFadeIn {
from { opacity: 0; transform: translateY(6px); }
to { opacity: 1; transform: translateY(0); }
}
/* ── Navigation rows inside tabs ── */
.nav-row { align-items: center !important; gap: 6px !important; margin-bottom: 8px !important; }
.nav-row button { min-width: 80px !important; }
/* ── Options panel ── */
#options-panel {
border: 1px solid var(--border-color-primary);
border-radius: 10px;
padding: 16px;
margin-top: 12px;
}
#options-panel .gr-markdown h3 {
font-size: 0.72rem !important;
font-weight: 600 !important;
letter-spacing: 0.07em !important;
text-transform: uppercase !important;
opacity: 0.5 !important;
margin: 14px 0 6px !important;
}
#options-panel .gr-markdown h3:first-child { margin-top: 0 !important; }
/* ── Frame filter ── */
#frame-filter { margin-top: 12px; }
/* ── Examples section ── */
#examples-section {
margin-top: 36px;
padding-top: 24px;
border-top: 1px solid var(--border-color-primary);
}
#examples-section h2 {
font-size: 1.1rem !important;
font-weight: 600 !important;
margin-bottom: 4px !important;
}
#examples-section .scene-caption {
font-size: 0.75rem !important;
text-align: center !important;
opacity: 0.65 !important;
margin-top: 4px !important;
}
.scene-thumb img { border-radius: 8px; transition: opacity .15s; }
.scene-thumb img:hover { opacity: .85; }
/* ── Measure note ── */
.measure-note {
font-size: 0.78rem !important;
opacity: 0.6 !important;
margin-top: 6px !important;
}
#col-container {
margin: 0 auto;
max-width: 960px;
}
/* ── Accordion ── */
.gradio-accordion {
border-radius: 10px !important;
border: 1px solid rgba(70, 130, 180, 0.2) !important;
}
.gradio-accordion > .label-wrap { border-radius: 10px !important; }
/* ── Labels ── */
label {
font-weight: 600 !important;
font-family: 'Outfit', sans-serif !important;
}
/* ── Slider ── */
.gradio-slider input[type="range"] { accent-color: #4682B4 !important; }
/* ── Scrollbar ── */
::-webkit-scrollbar { width: 8px; height: 8px; }
::-webkit-scrollbar-track { background: rgba(70, 130, 180, 0.06); border-radius: 4px; }
::-webkit-scrollbar-thumb {
background: linear-gradient(135deg, #4682B4, #3E72A0);
border-radius: 4px;
}
::-webkit-scrollbar-thumb:hover {
background: linear-gradient(135deg, #3E72A0, #2E5378);
}
/* ── Dark-mode overrides for header (keep text/SVG white) ── */
@media (prefers-color-scheme: dark) {
.app-header {
background: linear-gradient(135deg, #1E3450 0%, #264364 30%, #3E72A0 70%, #4682B4 100%);
}
.header-text h1 { color: #fff !important; }
.header-icon-wrap svg,
.app-header svg,
.meta-badge svg {
color: #ffffff !important;
stroke: #ffffff !important;
}
.meta-badge { color: rgba(255, 255, 255, 0.9) !important; }
.meta-cap { color: rgba(255, 255, 255, 0.65) !important; }
}
/* Also handle Gradio's own .dark class */
.dark .header-text h1 { color: #fff !important; }
.dark .header-icon-wrap svg,
.dark .app-header svg,
.dark .meta-badge svg {
color: #ffffff !important;
stroke: #ffffff !important;
}
.dark .meta-badge { color: rgba(255, 255, 255, 0.9) !important; }
.dark .meta-cap { color: rgba(255, 255, 255, 0.65) !important; }
/* ── Responsive ── */
@media (max-width: 768px) {
.app-header { padding: 20px 24px; }
.header-text h1 { font-size: 1.5rem; }
.header-content {
flex-direction: column;
align-items: flex-start;
gap: 16px;
}
.header-meta { gap: 8px; }
}
"""
def predictions_to_rrd(predictions, glbfile, target_dir, frame_filter="All", show_cam=True):
    """Write a Rerun ``.rrd`` recording of the reconstruction and return its path.

    The recording holds the world axes, the exported GLB scene, optionally
    the per-view cameras (pose, pinhole model and RGB image) and a colored
    point cloud. Axis, camera, point-cloud and blueprint logging are
    best-effort: a failure is printed and the rest is still written.

    Args:
        predictions: Mapping of prediction arrays. Keys read here (all
            optional): ``extrinsic``, ``intrinsic``, ``images``,
            ``world_points`` and ``final_mask``.
        glbfile: Path of the GLB file logged under ``world/model``.
        target_dir: Directory the ``.rrd`` file is written into.
        frame_filter: Accepted for call-site compatibility; not used here.
        show_cam: When true, cameras are logged if both ``extrinsic`` and
            ``intrinsic`` are present.

    Returns:
        Path of the saved ``.rrd`` file.
    """
    run_id = str(uuid.uuid4())
    timestamp = datetime.now().strftime("%Y-%m-%dT%H%M%S")
    rrd_path = os.path.join(target_dir, f"mapanything_{timestamp}.rrd")

    # Use whichever recording API the installed rerun exposes, falling back
    # to the module-level global recording.
    if hasattr(rr, "new_recording"):
        rec = rr.new_recording(application_id="MapAnything-3D-Viewer", recording_id=run_id)
    elif hasattr(rr, "RecordingStream"):
        rec = rr.RecordingStream(application_id="MapAnything-3D-Viewer", recording_id=run_id)
    else:
        rr.init("MapAnything-3D-Viewer", recording_id=run_id, spawn=False)
        rec = rr

    rec.log("world", rr.Clear(recursive=True), static=True)
    rec.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, static=True)

    # World axes as red / green / blue arrows of length 0.5.
    try:
        axes = (
            ("x", [0.5, 0, 0], [255, 0, 0]),
            ("y", [0, 0.5, 0], [0, 255, 0]),
            ("z", [0, 0, 0.5], [0, 0, 255]),
        )
        for axis_name, vector, color in axes:
            rec.log(
                f"world/axes/{axis_name}",
                rr.Arrows3D(vectors=[vector], colors=[color]),
                static=True,
            )
    except Exception as e:
        # Report instead of swallowing silently, like the other optional sections.
        print(f"Axis logging failed (non-fatal): {e}")

    rec.log("world/model", rr.Asset3D(path=glbfile), static=True)

    if show_cam and "extrinsic" in predictions and "intrinsic" in predictions:
        try:
            extrinsics = predictions["extrinsic"]
            intrinsics = predictions["intrinsic"]
            cam_images = predictions["images"] if "images" in predictions else None
            for i, (ext, intr) in enumerate(zip(extrinsics, intrinsics)):
                cam_path = f"world/cameras/cam_{i:03d}"
                rec.log(
                    cam_path,
                    rr.Transform3D(translation=ext[:3, 3], mat3x3=ext[:3, :3]),
                    static=True,
                )
                fx, fy = intr[0, 0], intr[1, 1]
                cx, cy = intr[0, 2], intr[1, 2]
                img = cam_images[i] if cam_images is not None and i < len(cam_images) else None
                if img is not None:
                    h, w = img.shape[:2]
                else:
                    # No image for this camera: fall back to 518x518, the same
                    # value as the "resolution" entry in high_level_config.
                    h, w = 518, 518
                rec.log(
                    f"{cam_path}/image",
                    rr.Pinhole(focal_length=[fx, fy], principal_point=[cx, cy], width=w, height=h),
                    static=True,
                )
                if img is not None:
                    if img.dtype != np.uint8:
                        # Non-uint8 images are treated as [0, 1] floats.
                        img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
                    rec.log(f"{cam_path}/image/rgb", rr.Image(img), static=True)
        except Exception as e:
            print(f"Camera logging failed (non-fatal): {e}")

    if "world_points" in predictions and "images" in predictions:
        try:
            world_points = predictions["world_points"]
            images = predictions["images"]
            final_mask = predictions.get("final_mask")
            all_points, all_colors = [], []
            for i, pts in enumerate(world_points):
                img = images[i]
                if final_mask is not None:
                    mask = final_mask[i].astype(bool)
                else:
                    mask = np.ones(pts.shape[:2], dtype=bool)
                pts_flat = pts[mask]
                img_flat = img[mask]
                if img_flat.dtype != np.uint8:
                    img_flat = (np.clip(img_flat, 0, 1) * 255).astype(np.uint8)
                all_points.append(pts_flat)
                all_colors.append(img_flat)
            if all_points:
                all_points = np.concatenate(all_points, axis=0)
                all_colors = np.concatenate(all_colors, axis=0)
                # Randomly subsample very large clouds down to 500k points.
                max_pts = 500_000
                if len(all_points) > max_pts:
                    idx = np.random.choice(len(all_points), max_pts, replace=False)
                    all_points = all_points[idx]
                    all_colors = all_colors[idx]
                rec.log(
                    "world/point_cloud",
                    rr.Points3D(positions=all_points, colors=all_colors, radii=0.002),
                    static=True,
                )
        except Exception as e:
            print(f"Point cloud logging failed (non-fatal): {e}")

    # rrb is None when rerun.blueprint could not be imported.
    if rrb is not None:
        try:
            blueprint = rrb.Blueprint(
                rrb.Spatial3DView(origin="/world", name="3D View"),
                collapse_panels=True,
            )
            rec.send_blueprint(blueprint)
        except Exception as e:
            print(f"Blueprint creation failed (non-fatal): {e}")

    rec.save(rrd_path)
    return rrd_path
@spaces.GPU(duration=120)
def run_model(target_dir, apply_mask=True, mask_edges=True, filter_black_bg=False, filter_white_bg=False):
    """Run MapAnything on the images in ``<target_dir>/images``.

    The model is created on first use, cached in the module-level ``model``
    and moved to the current device on later calls.

    Args:
        target_dir: Directory containing an ``images`` sub-folder.
        apply_mask: Forwarded to ``model.infer``.
        mask_edges: Forwarded to ``model.infer``.
        filter_black_bg: Forwarded to process_predictions_for_visualization.
        filter_white_bg: Forwarded to process_predictions_for_visualization.

    Returns:
        ``(predictions, processed_data)``. ``predictions`` maps ``extrinsic``,
        ``intrinsic``, ``world_points``, ``depth``, ``images`` and
        ``final_mask`` to arrays stacked over views. ``processed_data`` is the
        per-view dict from process_predictions_for_visualization.

    Raises:
        ValueError: If no images could be loaded.
    """
    global model
    import torch
    print(f"Processing images from {target_dir}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model is None:
        model = initialize_mapanything_model(high_level_config, device)
    else:
        model = model.to(device)
    model.eval()

    print("Loading images...")
    image_folder_path = os.path.join(target_dir, "images")
    views = load_images(image_folder_path)
    print(f"Loaded {len(views)} images")
    if len(views) == 0:
        raise ValueError("No images found. Check your upload.")

    print("Running inference...")
    # Fix: honour the caller's mask_edges instead of a hard-coded True.
    outputs = model.infer(
        views,
        apply_mask=apply_mask,
        mask_edges=mask_edges,
        memory_efficient_inference=False,
    )

    extrinsic_list, intrinsic_list, world_points_list = [], [], []
    depth_maps_list, images_list, final_mask_list = [], [], []
    for pred in outputs:
        # Index 0 selects the first entry of each output tensor.
        # NOTE(review): presumably the batch dimension; confirm against model.infer.
        depthmap_torch = pred["depth_z"][0].squeeze(-1)
        intrinsics_torch = pred["intrinsics"][0]
        camera_pose_torch = pred["camera_poses"][0]
        pts3d_computed, valid_mask = depthmap_to_world_frame(
            depthmap_torch, intrinsics_torch, camera_pose_torch
        )
        if "mask" in pred:
            mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool)
        else:
            mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool)
        # A pixel is kept only if the model mask and the geometric validity agree.
        mask = mask & valid_mask.cpu().numpy()

        extrinsic_list.append(camera_pose_torch.cpu().numpy())
        intrinsic_list.append(intrinsics_torch.cpu().numpy())
        world_points_list.append(pts3d_computed.cpu().numpy())
        depth_maps_list.append(depthmap_torch.cpu().numpy())
        images_list.append(pred["img_no_norm"][0].cpu().numpy())
        final_mask_list.append(mask)

    depth_maps = np.stack(depth_maps_list, axis=0)
    if len(depth_maps.shape) == 3:
        # Add a trailing channel axis to the stacked depth maps.
        depth_maps = depth_maps[..., np.newaxis]

    predictions = {
        "extrinsic": np.stack(extrinsic_list, axis=0),
        "intrinsic": np.stack(intrinsic_list, axis=0),
        "world_points": np.stack(world_points_list, axis=0),
        "depth": depth_maps,
        "images": np.stack(images_list, axis=0),
        "final_mask": np.stack(final_mask_list, axis=0),
    }
    processed_data = process_predictions_for_visualization(
        predictions, views, high_level_config, filter_black_bg, filter_white_bg
    )
    torch.cuda.empty_cache()
    return predictions, processed_data
def update_view_selectors(processed_data):
    """Return three Dropdown updates listing "View 1" .. "View N".

    N is the number of entries in ``processed_data``. A falsy value yields a
    single "View 1" choice. The first choice is always selected.
    """
    if processed_data:
        labels = [f"View {number}" for number in range(1, len(processed_data) + 1)]
    else:
        labels = ["View 1"]
    selected = labels[0]
    # One dropdown each for the depth, normal and measure tabs.
    return tuple(gr.Dropdown(choices=labels, value=selected) for _ in range(3))
def get_view_data_by_index(processed_data, view_index):
    """Return the ``view_index``-th entry of ``processed_data`` in key order.

    The index is clamped to the valid range. ``None`` is returned when
    ``processed_data`` is empty or ``None``.
    """
    if not processed_data:
        return None
    ordered_keys = list(processed_data.keys())
    upper = len(ordered_keys) - 1
    # Clamp into [0, upper].
    position = upper if upper < view_index else view_index
    position = position if position > 0 else 0
    return processed_data[ordered_keys[position]]
def update_depth_view(processed_data, view_index):
    """Return the colorized depth image for one view, or ``None`` if unavailable."""
    entry = get_view_data_by_index(processed_data, view_index)
    if entry is None:
        return None
    depth = entry["depth"]
    if depth is None:
        return None
    return colorize_depth(depth, mask=entry.get("mask"))
def update_normal_view(processed_data, view_index):
    """Return the colorized normal map for one view, or ``None`` if unavailable."""
    entry = get_view_data_by_index(processed_data, view_index)
    if entry is None:
        return None
    normals = entry["normal"]
    if normals is None:
        return None
    return colorize_normal(normals, mask=entry.get("mask"))
def update_measure_view(processed_data, view_index):
    """Return ``(image, [])`` for the measure tab of one view.

    The image is a uint8 copy of the view's image with masked-out pixels
    blended 50/50 with a light pink overlay. The empty list resets the
    measurement points. ``(None, [])`` is returned when the view is missing.
    """
    entry = get_view_data_by_index(processed_data, view_index)
    if entry is None:
        return None, []

    canvas = entry["image"].copy()
    if canvas.dtype != np.uint8:
        # Values up to 1.0 are treated as normalised and rescaled to 0..255.
        if canvas.max() <= 1.0:
            canvas = (canvas * 255).astype(np.uint8)
        else:
            canvas = canvas.astype(np.uint8)

    valid = entry["mask"]
    if valid is not None:
        masked_out = ~valid
        if masked_out.any():
            tint = np.array([255, 220, 220], dtype=np.uint8)
            alpha = 0.5
            # Blend the first three channels at once and write back only the
            # masked-out pixels.
            blended = (1 - alpha) * canvas[:, :, :3] + alpha * tint
            canvas[:, :, :3][masked_out] = blended[masked_out].astype(np.uint8)
    return canvas, []
def navigate_depth_view(processed_data, current_selector_value, direction):
    """Step the depth tab by ``direction`` views, wrapping around.

    Returns the new selector label and the depth image for that view, or
    ``("View 1", None)`` when there is no data.
    """
    if not processed_data:
        return "View 1", None
    try:
        # Selector labels look like "View <n>" with n starting at 1.
        label_parts = current_selector_value.split()
        current = int(label_parts[1]) - 1
    except Exception:
        current = 0
    target = (current + direction) % len(processed_data)
    label = f"View {target + 1}"
    return label, update_depth_view(processed_data, target)
def navigate_normal_view(processed_data, current_selector_value, direction):
    """Step the normal tab by ``direction`` views, wrapping around.

    Returns the new selector label and the normal image for that view, or
    ``("View 1", None)`` when there is no data.
    """
    if not processed_data:
        return "View 1", None
    try:
        # Selector labels look like "View <n>" with n starting at 1.
        label_parts = current_selector_value.split()
        current = int(label_parts[1]) - 1
    except Exception:
        current = 0
    target = (current + direction) % len(processed_data)
    label = f"View {target + 1}"
    return label, update_normal_view(processed_data, target)
def navigate_measure_view(processed_data, current_selector_value, direction):
    """Step the measure tab by ``direction`` views, wrapping around.

    Returns the new selector label, the measure image and the reset point
    list, or ``("View 1", None, [])`` when there is no data.
    """
    if not processed_data:
        return "View 1", None, []
    try:
        # Selector labels look like "View <n>" with n starting at 1.
        label_parts = current_selector_value.split()
        current = int(label_parts[1]) - 1
    except Exception:
        current = 0
    target = (current + direction) % len(processed_data)
    image, points = update_measure_view(processed_data, target)
    return f"View {target + 1}", image, points
def populate_visualization_tabs(processed_data):
    """Return the first view's depth, normal and measure images plus an empty point list.

    ``(None, None, None, [])`` is returned when there is no data.
    """
    if not processed_data:
        return None, None, None, []
    depth_image = update_depth_view(processed_data, 0)
    normal_image = update_normal_view(processed_data, 0)
    measure_image, _ = update_measure_view(processed_data, 0)
    return depth_image, normal_image, measure_image, []
def handle_uploads(unified_upload, s_time_interval=1.0):
    """Copy uploaded files into a fresh ``input_images_<timestamp>/images`` folder.

    Videos are sampled to PNG frames, HEIC/HEIF images are converted to JPEG
    (or copied as-is when conversion fails) and every other file is copied
    unchanged.

    Args:
        unified_upload: Iterable of uploads, each a dict with a ``"name"``
            path or an object whose ``str()`` is the path. ``None`` yields an
            empty folder.
        s_time_interval: Seconds between sampled video frames.

    Returns:
        ``(target_dir, image_paths)`` with the paths sorted.
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    target_dir_images = os.path.join(target_dir, "images")
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []
    video_extensions = (".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp")
    heif_extensions = (".heic", ".heif")

    for file_data in unified_upload or []:
        if isinstance(file_data, dict) and "name" in file_data:
            file_path = file_data["name"]
        else:
            file_path = str(file_data)
        file_ext = os.path.splitext(file_path)[1].lower()
        base_name = os.path.splitext(os.path.basename(file_path))[0]

        if file_ext in video_extensions:
            vs = cv2.VideoCapture(file_path)
            try:
                fps = vs.get(cv2.CAP_PROP_FPS)
                # CAP_PROP_FPS can be 0 or NaN (missing metadata, unreadable
                # file) and a short interval can round to 0. Clamp to 1 so the
                # modulo below cannot divide by zero.
                try:
                    frame_interval = max(1, int(fps * s_time_interval))
                except (ValueError, OverflowError):
                    frame_interval = 1
                count, video_frame_num = 0, 0
                while True:
                    gotit, frame = vs.read()
                    if not gotit:
                        break
                    count += 1
                    if count % frame_interval == 0:
                        image_path = os.path.join(
                            target_dir_images, f"{base_name}_{video_frame_num:06}.png"
                        )
                        cv2.imwrite(image_path, frame)
                        image_paths.append(image_path)
                        video_frame_num += 1
            finally:
                # Release the capture even if reading or writing raised.
                vs.release()
            print(f"Extracted {video_frame_num} frames from: {os.path.basename(file_path)}")
        elif file_ext in heif_extensions:
            try:
                with Image.open(file_path) as img:
                    if img.mode not in ("RGB", "L"):
                        img = img.convert("RGB")
                    dst_path = os.path.join(target_dir_images, f"{base_name}.jpg")
                    img.save(dst_path, "JPEG", quality=95)
                    image_paths.append(dst_path)
            except Exception as e:
                # Conversion is best-effort: keep the original file instead.
                print(f"Error converting HEIC {file_path}: {e}")
                dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
                shutil.copy(file_path, dst_path)
                image_paths.append(dst_path)
        else:
            dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
            shutil.copy(file_path, dst_path)
            image_paths.append(dst_path)

    image_paths = sorted(image_paths)
    print(f"Files processed to {target_dir_images}; took {time.time() - start_time:.3f}s")
    return target_dir, image_paths
@spaces.GPU(duration=120)
def gradio_demo(target_dir, frame_filter="All", show_cam=True, filter_black_bg=False, filter_white_bg=False, apply_mask=True, show_mesh=True):
    """Run the reconstruction for an uploaded folder and build every UI output.

    Args:
        target_dir: Folder produced by handle_uploads.
        frame_filter: Frame selection forwarded to predictions_to_glb;
            ``None`` is treated as ``"All"``.
        show_cam: Include cameras in the GLB scene and the Rerun recording.
        filter_black_bg: Mask near-black pixels (GLB and per-view data).
        filter_white_bg: Mask near-white pixels (GLB and per-view data).
        apply_mask: Forwarded to run_model.
        show_mesh: Export the GLB as a mesh (``as_mesh``).

    Returns:
        An 11-tuple: rrd path, log message, frame-filter dropdown update,
        processed per-view data, depth image, normal image, measure image,
        measure text, and the depth/normal/measure selector updates. The
        tuple has the same length on the failure path.
    """
    # Also reject None/empty values, which os.path.isdir cannot handle.
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        # Same arity as the success path so every output gets a value.
        # NOTE(review): the output wiring is outside this section; confirm it
        # lists 11 components.
        return (
            None, "No valid target directory found. Please upload first.",
            None,
            None, None, None, None, "",
            gr.update(), gr.update(), gr.update(),
        )

    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    target_dir_images = os.path.join(target_dir, "images")
    all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
    all_files_labeled = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
    frame_filter_choices = ["All"] + all_files_labeled

    print("Running MapAnything model...")
    with torch.no_grad():
        # Fix: forward the background filters so the per-view masks match the
        # GLB export.
        predictions, processed_data = run_model(
            target_dir,
            apply_mask,
            filter_black_bg=filter_black_bg,
            filter_white_bg=filter_white_bg,
        )

    # Saved so update_visualization and the filter callbacks can reload them.
    np.savez(os.path.join(target_dir, "predictions.npz"), **predictions)

    if frame_filter is None:
        frame_filter = "All"
    # The file name encodes every option that affects the exported scene.
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
    )
    glbscene = predictions_to_glb(
        predictions,
        filter_by_frames=frame_filter,
        show_cam=show_cam,
        mask_black_bg=filter_black_bg,
        mask_white_bg=filter_white_bg,
        as_mesh=show_mesh,
    )
    glbscene.export(file_obj=glbfile)
    rrd_path = predictions_to_rrd(predictions, glbfile, target_dir, frame_filter, show_cam)

    del predictions
    gc.collect()
    torch.cuda.empty_cache()
    print(f"Total time: {time.time() - start_time:.2f}s")

    # Fix: the message was mojibake for a check mark and an em dash.
    log_msg = f"✅ Reconstruction complete — {len(all_files)} frames processed."
    depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(processed_data)
    depth_selector, normal_selector, measure_selector = update_view_selectors(processed_data)
    return (
        rrd_path, log_msg,
        gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
        processed_data, depth_vis, normal_vis, measure_img, "",
        depth_selector, normal_selector, measure_selector,
    )
def colorize_depth(depth_map, mask=None):
    """Render a depth map as an RGB uint8 image with the reversed turbo colormap.

    Valid pixels (finite, positive and inside ``mask`` when given) are scaled
    from their 5th-95th percentile range to [0, 1] before coloring. Invalid
    pixels are painted white.

    Args:
        depth_map: 2-D depth array, or ``None``.
        mask: Optional boolean array of the same shape; ``False`` marks
            pixels to treat as invalid.

    Returns:
        ``(H, W, 3)`` uint8 image, or ``None`` when ``depth_map`` is ``None``.
    """
    if depth_map is None:
        return None

    depth_normalized = depth_map.copy()
    if not np.issubdtype(depth_normalized.dtype, np.floating):
        # Integer maps would truncate the normalised [0, 1] values.
        depth_normalized = depth_normalized.astype(np.float64)

    # NaN/inf depths are invalid too; they would poison the percentiles.
    valid_mask = np.isfinite(depth_normalized) & (depth_normalized > 0)
    if mask is not None:
        valid_mask = valid_mask & mask

    if valid_mask.sum() > 0:
        valid_depths = depth_normalized[valid_mask]
        p5, p95 = np.percentile(valid_depths, 5), np.percentile(valid_depths, 95)
        span = p95 - p5
        if span > 0:
            depth_normalized[valid_mask] = np.clip((valid_depths - p5) / span, 0.0, 1.0)
        else:
            # Constant depth: avoid dividing by zero and use a single color.
            depth_normalized[valid_mask] = 0.0
    # Invalid pixels are overwritten with white below; zero them so the
    # colormap never sees NaN/inf.
    depth_normalized[~valid_mask] = 0.0

    import matplotlib.pyplot as plt
    colored = (plt.cm.turbo_r(depth_normalized)[:, :, :3] * 255).astype(np.uint8)
    colored[~valid_mask] = [255, 255, 255]
    return colored
def colorize_normal(normal_map, mask=None):
    """Map a normal map from [-1, 1] to a uint8 image in [0, 255].

    Pixels where ``mask`` is ``False`` are set to the zero vector first, so
    they come out mid-gray. The input array is not modified. ``None`` is
    returned for a ``None`` input.
    """
    if normal_map is None:
        return None
    shaded = np.array(normal_map, copy=True)
    if mask is not None:
        shaded[np.logical_not(mask)] = (0, 0, 0)
    unit_range = (shaded + 1.0) / 2.0
    return (unit_range * 255).astype(np.uint8)
def process_predictions_for_visualization(predictions, views, high_level_config, filter_black_bg=False, filter_white_bg=False):
    """Build the per-view data used by the depth, normal and measure tabs.

    For every view the image is de-normalised, the final mask is optionally
    tightened by the black/white background filters, and normals are derived
    from the world points under that mask.

    Returns:
        Dict keyed by view index. Each value has ``image``, ``points3d``,
        ``depth``, ``normal`` and ``mask``.
    """
    per_view = {}
    for view_idx, view in enumerate(views):
        image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
        points = predictions["world_points"][view_idx]
        keep = predictions["final_mask"][view_idx].copy()

        if filter_black_bg or filter_white_bg:
            frame = image[0]
            # Bring the colors to a 0..255 scale before thresholding.
            colors_255 = frame * 255 if frame.max() <= 1.0 else frame
            if filter_black_bg:
                # Drop pixels whose channel sum is below 16.
                keep = keep & (colors_255.sum(axis=2) >= 16)
            if filter_white_bg:
                # Drop pixels whose R, G and B are all above 240.
                near_white = (colors_255[:, :, :3] > 240).all(axis=2)
                keep = keep & ~near_white

        normals, _ = points_to_normals(points, mask=keep)
        per_view[view_idx] = {
            "image": image[0],
            "points3d": points,
            "depth": predictions["depth"][view_idx].squeeze(),
            "normal": normals,
            "mask": keep,
        }
    return per_view
def measure(processed_data, measure_points, current_view_selector, event: gr.SelectData):
    """Handle a click on the measure image.

    The click is added to the pending points, all points are drawn on the
    view and each point's depth is reported. Once two points are present the
    3-D distance between them is reported and the point list is reset.

    Args:
        processed_data: Per-view dict from process_predictions_for_visualization.
        measure_points: Pending ``(x, y)`` points; may be ``None``. The
            caller's list is not modified.
        current_view_selector: Selector label such as ``"View 2"``.
        event: Gradio select event; ``event.index`` is the clicked ``(x, y)``.

    Returns:
        ``(image, measure_points, markdown_text)``. On failure the image is
        ``None`` and the text carries the error.
    """

    def inside(point, array):
        # True when the (x, y) pixel lies within the first two dims of `array`.
        return 0 <= point[1] < array.shape[0] and 0 <= point[0] < array.shape[1]

    try:
        if not processed_data:
            return None, [], "No data available"
        try:
            current_view_index = int(current_view_selector.split()[1]) - 1
        except (AttributeError, IndexError, ValueError):
            current_view_index = 0
        current_view_index = max(0, min(current_view_index, len(processed_data) - 1))
        current_view = processed_data[list(processed_data.keys())[current_view_index]]
        if current_view is None:
            return None, [], "No view data available"

        # Work on a copy so the caller's state list is never mutated. Tuples
        # are used because the cv2 drawing calls below take points as tuples.
        points = [tuple(p) for p in (measure_points or [])]
        point2d = (int(event.index[0]), int(event.index[1]))

        mask = current_view["mask"]
        if mask is not None and inside(point2d, mask) and not mask[point2d[1], point2d[0]]:
            masked_image, _ = update_measure_view(processed_data, current_view_index)
            return masked_image, points, '<span style="color: red; font-weight: bold;">Cannot measure on masked areas</span>'

        points.append(point2d)
        image, _ = update_measure_view(processed_data, current_view_index)
        if image is None:
            return None, [], "No image available"
        image = image.copy()
        if image.dtype != np.uint8:
            image = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8)

        points3d = current_view["points3d"]
        depth = current_view["depth"]

        for p in points:
            if inside(p, image):
                image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2)

        depth_lines = []
        for i, p in enumerate(points):
            if depth is not None and inside(p, depth):
                depth_lines.append(f"- **P{i + 1} depth: {depth[p[1], p[0]]:.2f}m**\n")
            elif points3d is not None and inside(p, points3d):
                # No usable depth map: fall back to the point's Z coordinate.
                depth_lines.append(f"- **P{i + 1} Z-coord: {points3d[p[1], p[0], 2]:.2f}m**\n")
        depth_text = "".join(depth_lines)

        if len(points) == 2:
            point1, point2 = points
            if inside(point1, image) and inside(point2, image):
                image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2)
            distance_text = "- **Distance: Unable to compute**"
            if points3d is not None and inside(point1, points3d) and inside(point2, points3d):
                try:
                    distance = np.linalg.norm(
                        points3d[point1[1], point1[0]] - points3d[point2[1], point2[0]]
                    )
                    distance_text = f"- **Distance: {distance:.2f}m**"
                except Exception as e:
                    distance_text = f"- **Distance error: {e}**"
            # The measurement is complete: reset the points for the next pair.
            return image, [], depth_text + distance_text
        return image, points, depth_text
    except Exception as e:
        # UI boundary: report the problem in the output text instead of raising.
        print(f"Measure error: {e}")
        return None, [], f"Error: {e}"
def clear_fields() -> None:
    """Return ``None``, the value used to blank the bound output component."""
    cleared = None
    return cleared
def update_log() -> str:
    """Return the status text shown while a reconstruction is in progress."""
    status = "⏳ Loading and reconstructing…"
    return status
def update_visualization(
    target_dir, frame_filter, show_cam, is_example,
    filter_black_bg=False, filter_white_bg=False, show_mesh=True,
):
    """Rebuild the 3-D viewer output from saved predictions after an option change.

    The GLB scene for the chosen options is reused if it already exists in
    ``target_dir`` and exported otherwise. A new Rerun recording is written
    on every call.

    Args:
        target_dir: Folder holding ``predictions.npz``.
        frame_filter: Frame selection; ``None`` is treated as ``"All"``.
        show_cam: Include cameras in the scene.
        is_example: The string ``"True"`` means an example is loaded that has
            not been reconstructed yet.
        filter_black_bg: Mask near-black pixels in the GLB export.
        filter_white_bg: Mask near-white pixels in the GLB export.
        show_mesh: Export the GLB as a mesh.

    Returns:
        ``(rrd_path, message)``, or ``(gr.update(), message)`` when nothing
        can be shown.
    """
    if is_example == "True":
        return gr.update(), "No reconstruction available. Please click Reconstruct first."
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return gr.update(), "No reconstruction available. Please upload first."
    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return gr.update(), "No reconstruction found. Please run Reconstruct first."

    # A cleared dropdown yields None; treat it as "All" like gradio_demo does.
    if frame_filter is None:
        frame_filter = "All"

    # NOTE(review): allow_pickle=True is kept from the original. Pickle
    # loading is unsafe for untrusted files, so only load predictions.npz
    # written by this app.
    # The context manager closes the archive once the arrays are loaded.
    with np.load(predictions_path, allow_pickle=True) as loaded:
        predictions = {key: loaded[key] for key in loaded.keys()}

    glbfile = os.path.join(
        target_dir,
        f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
    )
    if not os.path.exists(glbfile):
        glbscene = predictions_to_glb(
            predictions,
            filter_by_frames=frame_filter,
            show_cam=show_cam,
            mask_black_bg=filter_black_bg,
            mask_white_bg=filter_white_bg,
            as_mesh=show_mesh,
        )
        glbscene.export(file_obj=glbfile)
    rrd_path = predictions_to_rrd(predictions, glbfile, target_dir, frame_filter, show_cam)
    return rrd_path, "Visualization updated."
def update_all_views_on_filter_change(
    target_dir, filter_black_bg, filter_white_bg, processed_data,
    depth_view_selector, normal_view_selector, measure_view_selector,
):
    """Recompute the per-view depth/normal/measure images after a filter toggle.

    Args:
        target_dir: Working directory holding ``predictions.npz`` and an
            ``images`` sub-directory.
        filter_black_bg: Forwarded to ``process_predictions_for_visualization``.
        filter_white_bg: Forwarded to ``process_predictions_for_visualization``.
        processed_data: Previously processed data; returned unchanged whenever
            the refresh cannot be performed.
        depth_view_selector: Current "View N" label of the depth tab.
        normal_view_selector: Current "View N" label of the normal tab.
        measure_view_selector: Current "View N" label of the measure tab.

    Returns:
        ``(processed_data, depth_image, normal_image, measure_image, [])``.
        The three images are ``None`` when there is nothing to show or an
        error occurred; the trailing empty list resets the measure points.
    """
    nothing_to_show = (processed_data, None, None, None, [])

    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return nothing_to_show
    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return nothing_to_show

    def safe_idx(sel):
        """Turn a "View N" label into a zero-based index, defaulting to 0."""
        try:
            return int(sel.split()[1]) - 1
        except (AttributeError, IndexError, ValueError):
            return 0

    try:
        # NOTE(review): allow_pickle=True unpickles object arrays; confirm the
        # archive is always one this app wrote itself.
        # Close the archive as soon as the arrays are in memory so the file
        # handle is not leaked on every filter toggle.
        with np.load(predictions_path, allow_pickle=True) as loaded:
            predictions = {key: loaded[key] for key in loaded.files}

        views = load_images(os.path.join(target_dir, "images"))
        new_processed_data = process_predictions_for_visualization(
            predictions, views, high_level_config, filter_black_bg, filter_white_bg,
        )
        depth_vis = update_depth_view(new_processed_data, safe_idx(depth_view_selector))
        normal_vis = update_normal_view(new_processed_data, safe_idx(normal_view_selector))
        measure_img, _ = update_measure_view(new_processed_data, safe_idx(measure_view_selector))
        return new_processed_data, depth_vis, normal_vis, measure_img, []
    except Exception as e:
        # Deliberate best-effort: a failed refresh keeps the previous data and
        # blanks the images rather than surfacing an error in the UI.
        print(f"Filter change error: {e}")
        return nothing_to_show
def get_scene_info(examples_dir):
    """Describe every example scene found directly under ``examples_dir``.

    A scene is a sub-directory containing at least one image file
    (``.jpg``, ``.jpeg``, ``.png``, ``.bmp``, ``.tiff`` or ``.tif``). The
    suffix is matched case-insensitively, so each file is listed exactly once
    even on case-insensitive filesystems. Names starting with a dot are
    ignored.

    Args:
        examples_dir: Directory whose immediate sub-directories are scenes.

    Returns:
        A list of dicts, ordered by folder name, with the keys ``name``
        (folder name), ``path`` (folder path), ``thumbnail`` (first image in
        sorted order), ``num_images`` and ``image_files`` (sorted paths).
        Empty when ``examples_dir`` is not a directory.
    """
    image_exts = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"}

    scenes = []
    if not os.path.isdir(examples_dir):
        return scenes

    for scene_folder in sorted(os.listdir(examples_dir)):
        scene_path = os.path.join(examples_dir, scene_folder)
        if not os.path.isdir(scene_path):
            continue

        image_files = []
        for entry in os.listdir(scene_path):
            # Skip dot-files to match shell-style wildcard behaviour.
            if entry.startswith("."):
                continue
            if os.path.splitext(entry)[1].lower() not in image_exts:
                continue
            full_path = os.path.join(scene_path, entry)
            if os.path.isfile(full_path):
                image_files.append(full_path)

        if not image_files:
            continue

        image_files.sort()
        scenes.append({
            "name": scene_folder,
            "path": scene_path,
            "thumbnail": image_files[0],
            "num_images": len(image_files),
            "image_files": image_files,
        })
    return scenes
def load_example_scene(scene_name, examples_dir="examples"):
    """Stage the images of a bundled example scene as if they had been uploaded.

    Args:
        scene_name: Folder name of the scene, as reported by ``get_scene_info``.
        examples_dir: Directory that contains the example scene folders.

    Returns:
        ``(viewer_value, target_dir, image_paths, status_message)``. The first
        element is always ``None``; ``target_dir`` and ``image_paths`` come
        from ``handle_uploads`` and are ``None`` when the scene is unknown.
    """
    scenes = get_scene_info(examples_dir)
    selected_scene = next((s for s in scenes if s["name"] == scene_name), None)
    if selected_scene is None:
        return None, None, None, "Scene not found"

    target_dir, image_paths = handle_uploads(selected_scene["image_files"], 1.0)
    # U+2014 (em dash) is written as an escape; the previous literal had been
    # corrupted into mojibake.
    message = (
        f"Loaded '{scene_name}' \u2014 {selected_scene['num_images']} images. "
        "Click Reconstruct."
    )
    return None, target_dir, image_paths, message
with gr.Blocks() as demo:
    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from the order of the calls; confirm it against the
    # rendered layout.

    # Hidden textboxes and state objects read and written by the handlers below.
    is_example = gr.Textbox(visible=False, value="None")
    num_images = gr.Textbox(visible=False, value="None")
    processed_data_state = gr.State(value=None)
    measure_points_state = gr.State(value=[])
    target_dir_output = gr.Textbox(visible=False, value="None")

    def _view_nav_row():
        """Create a "prev / view dropdown / next" row in the current container."""
        with gr.Row(elem_classes=["nav-row"]):
            # U+25C0 / U+25B6 are the left/right triangle glyphs.
            prev_btn = gr.Button("\u25c0 Prev", size="sm", scale=1)
            selector = gr.Dropdown(
                choices=["View 1"], value="View 1",
                label="View", scale=3, interactive=True,
                allow_custom_value=True, show_label=False,
            )
            next_btn = gr.Button("Next \u25b6", size="sm", scale=1)
        return prev_btn, selector, next_btn

    def _view_index(selection):
        """Turn a "View N" label into a zero-based index, defaulting to 0.

        The dropdowns accept free-form values, so the label may not parse.
        """
        try:
            return int(selection.split()[1]) - 1
        except (AttributeError, IndexError, ValueError):
            return 0

    with gr.Column(elem_id="app-shell"):
        # -- Header --
        gr.HTML(html_header())

        with gr.Row(equal_height=False):
            # -- Left panel: inputs and options --
            with gr.Column(elem_id="left-panel", scale=0):
                unified_upload = gr.File(
                    file_count="multiple",
                    label="Upload Images/Videos",
                    file_types=["image", "video"],
                    height="150",
                )
                with gr.Row():
                    s_time_interval = gr.Slider(
                        minimum=0.1, maximum=5.0, value=1.0, step=0.1,
                        label="Video interval (sec)",
                        scale=3,
                    )
                    resample_btn = gr.Button("Resample", visible=False, variant="secondary", scale=1)
                image_gallery = gr.Gallery(
                    columns=2,
                    height="150",
                )
                gr.ClearButton(
                    [unified_upload, image_gallery],
                    value="Clear uploads",
                    variant="secondary",
                    size="sm",
                )
                submit_btn = gr.Button("Reconstruct", variant="primary")
                with gr.Accordion("Options", open=False):
                    gr.Markdown("### Point Cloud")
                    show_cam = gr.Checkbox(label="Show cameras", value=True)
                    show_mesh = gr.Checkbox(label="Show mesh", value=True)
                    filter_black_bg = gr.Checkbox(label="Filter black background", value=False)
                    filter_white_bg = gr.Checkbox(label="Filter white background", value=False)
                    gr.Markdown("### Reconstruction (next run)")
                    apply_mask_checkbox = gr.Checkbox(
                        label="Apply ambiguous-depth mask & edges", value=True,
                    )

            # -- Right panel: status line and viewers --
            with gr.Column(elem_id="right-panel", scale=1):
                log_output = gr.Markdown(
                    "Upload a video or images, then click **Reconstruct**.",
                    elem_id="log-strip",
                )
                with gr.Tabs(elem_id="viewer-tabs"):
                    with gr.Tab("3D View"):
                        reconstruction_output = Rerun(
                            label="Rerun 3D Viewer",
                            height=672,
                        )
                    with gr.Tab("Depth"):
                        prev_depth_btn, depth_view_selector, next_depth_btn = _view_nav_row()
                        depth_map = gr.Image(
                            type="numpy", label="Depth Map",
                            format="png", interactive=False,
                        )
                    with gr.Tab("Normal"):
                        prev_normal_btn, normal_view_selector, next_normal_btn = _view_nav_row()
                        normal_map = gr.Image(
                            type="numpy", label="Normal Map",
                            format="png", interactive=False,
                        )
                    with gr.Tab("Measure"):
                        gr.Markdown(MEASURE_INSTRUCTIONS_HTML)
                        prev_measure_btn, measure_view_selector, next_measure_btn = _view_nav_row()
                        measure_image = gr.Image(
                            type="numpy", show_label=False,
                            format="webp", interactive=False, sources=[],
                        )
                        gr.Markdown(
                            "Light-grey areas have no depth \u2014 measurements cannot be placed there.",
                            elem_classes=["measure-note"],
                        )
                        measure_text = gr.Markdown("")
                with gr.Column():
                    frame_filter = gr.Dropdown(
                        choices=["All"], value="All", label="Filter by Frame",
                        show_label=True,
                    )

        # -- Example scenes, four thumbnails per row --
        with gr.Column(elem_id="examples-section"):
            gr.Markdown("## Example Scenes")
            gr.Markdown("Click a thumbnail to load the scene, then press **Reconstruct**.")
            scenes = get_scene_info("examples")
            for i in range(0, len(scenes), 4):
                with gr.Row():
                    for j in range(4):
                        idx = i + j
                        if idx >= len(scenes):
                            # Empty placeholder column for an unused slot in
                            # the last row.
                            with gr.Column(scale=1, min_width=140):
                                pass
                            continue
                        scene = scenes[idx]
                        with gr.Column(scale=1, min_width=140, elem_classes=["scene-thumb"]):
                            scene_img = gr.Image(
                                value=scene["thumbnail"],
                                height=130,
                                interactive=False,
                                show_label=False,
                                sources=[],
                            )
                            gr.Markdown(
                                f"**{scene['name']}** \n{scene['num_images']} imgs",
                                elem_classes=["scene-caption"],
                            )
                            # The default argument binds this iteration's scene
                            # name; a plain closure would see only the last one.
                            scene_img.select(
                                fn=lambda name=scene["name"]: load_example_scene(name),
                                outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
                            )

    # -- Reconstruction --
    submit_btn.click(
        fn=clear_fields, inputs=[], outputs=[reconstruction_output],
    ).then(
        fn=update_log, inputs=[], outputs=[log_output],
    ).then(
        fn=gradio_demo,
        inputs=[target_dir_output, frame_filter, show_cam, filter_black_bg, filter_white_bg, apply_mask_checkbox, show_mesh],
        outputs=[
            reconstruction_output, log_output, frame_filter, processed_data_state,
            depth_map, normal_map, measure_image, measure_text,
            depth_view_selector, normal_view_selector, measure_view_selector,
        ],
    ).then(fn=lambda: "False", inputs=[], outputs=[is_example])

    # -- Display-option changes --
    # Every trigger passes the complete option set. Previously the camera and
    # black-background toggles sent a truncated list, so update_visualization
    # fell back to its defaults and ignored the other checkboxes.
    viz_inputs = [
        target_dir_output, frame_filter, show_cam, is_example,
        filter_black_bg, filter_white_bg, show_mesh,
    ]
    viz_outputs = [reconstruction_output, log_output]
    for option in (frame_filter, show_cam, show_mesh):
        option.change(update_visualization, viz_inputs, viz_outputs)

    # Background filters also affect the 2D views, so refresh those afterwards.
    for bg_filter in (filter_black_bg, filter_white_bg):
        bg_filter.change(
            update_visualization, viz_inputs, viz_outputs,
        ).then(
            update_all_views_on_filter_change,
            [target_dir_output, filter_black_bg, filter_white_bg, processed_data_state, depth_view_selector, normal_view_selector, measure_view_selector],
            [processed_data_state, depth_map, normal_map, measure_image, measure_points_state],
        )

    # -- Upload handling --
    video_exts = (".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp")

    def _contains_video(files):
        """Return True when any entry in ``files`` has a known video suffix."""
        return any(
            os.path.splitext(str(f["name"] if isinstance(f, dict) else f))[1].lower() in video_exts
            for f in files
        )

    def update_gallery_on_unified_upload(files, interval):
        """Stage uploaded files; returns (target_dir, gallery paths, status)."""
        if not files:
            return None, None, None
        target_dir, image_paths = handle_uploads(files, interval)
        return target_dir, image_paths, "Upload complete. Click **Reconstruct** to begin."

    def show_resample_button(files):
        """Show the Resample button only when the upload contains a video."""
        if not files:
            return gr.update(visible=False)
        return gr.update(visible=_contains_video(files))

    def resample_video_with_new_interval(files, new_interval, current_target_dir):
        """Re-run ``handle_uploads`` with a new interval, replacing the old staging dir.

        Returns (target_dir, gallery paths, status, update hiding the button).
        """
        if not files:
            return current_target_dir, None, "No files to resample.", gr.update(visible=False)
        if not _contains_video(files):
            return current_target_dir, None, "No videos found.", gr.update(visible=False)
        # NOTE(review): current_target_dir comes from a client-editable hidden
        # textbox; confirm it can only name a directory created by
        # handle_uploads before deleting it recursively.
        if current_target_dir and current_target_dir != "None" and os.path.isdir(current_target_dir):
            shutil.rmtree(current_target_dir)
        target_dir, image_paths = handle_uploads(files, new_interval)
        return target_dir, image_paths, f"Resampled at {new_interval}s. Click **Reconstruct**.", gr.update(visible=False)

    unified_upload.change(
        fn=update_gallery_on_unified_upload,
        inputs=[unified_upload, s_time_interval],
        outputs=[target_dir_output, image_gallery, log_output],
    ).then(fn=show_resample_button, inputs=[unified_upload], outputs=[resample_btn])
    s_time_interval.change(fn=show_resample_button, inputs=[unified_upload], outputs=[resample_btn])
    resample_btn.click(
        fn=resample_video_with_new_interval,
        inputs=[unified_upload, s_time_interval, target_dir_output],
        outputs=[target_dir_output, image_gallery, log_output, resample_btn],
    )

    # -- Measure tab: clicks on the image --
    measure_image.select(
        fn=measure,
        inputs=[processed_data_state, measure_points_state, measure_view_selector],
        outputs=[measure_image, measure_points_state, measure_text],
    )

    # -- Depth tab navigation --
    prev_depth_btn.click(
        fn=lambda pd, sel: navigate_depth_view(pd, sel, -1),
        inputs=[processed_data_state, depth_view_selector],
        outputs=[depth_view_selector, depth_map],
    )
    next_depth_btn.click(
        fn=lambda pd, sel: navigate_depth_view(pd, sel, 1),
        inputs=[processed_data_state, depth_view_selector],
        outputs=[depth_view_selector, depth_map],
    )
    depth_view_selector.change(
        fn=lambda pd, sel: update_depth_view(pd, _view_index(sel)) if sel else None,
        inputs=[processed_data_state, depth_view_selector],
        outputs=[depth_map],
    )

    # -- Normal tab navigation --
    prev_normal_btn.click(
        fn=lambda pd, sel: navigate_normal_view(pd, sel, -1),
        inputs=[processed_data_state, normal_view_selector],
        outputs=[normal_view_selector, normal_map],
    )
    next_normal_btn.click(
        fn=lambda pd, sel: navigate_normal_view(pd, sel, 1),
        inputs=[processed_data_state, normal_view_selector],
        outputs=[normal_view_selector, normal_map],
    )
    normal_view_selector.change(
        fn=lambda pd, sel: update_normal_view(pd, _view_index(sel)) if sel else None,
        inputs=[processed_data_state, normal_view_selector],
        outputs=[normal_map],
    )

    # -- Measure tab navigation --
    prev_measure_btn.click(
        fn=lambda pd, sel: navigate_measure_view(pd, sel, -1),
        inputs=[processed_data_state, measure_view_selector],
        outputs=[measure_view_selector, measure_image, measure_points_state],
    )
    next_measure_btn.click(
        fn=lambda pd, sel: navigate_measure_view(pd, sel, 1),
        inputs=[processed_data_state, measure_view_selector],
        outputs=[measure_view_selector, measure_image, measure_points_state],
    )
    measure_view_selector.change(
        fn=lambda pd, sel: update_measure_view(pd, _view_index(sel)) if sel else (None, []),
        inputs=[processed_data_state, measure_view_selector],
        outputs=[measure_image, measure_points_state],
    )
# Enable the request queue (at most 50 queued events), then start the server.
demo.queue(max_size=50).launch(
    css=CUSTOM_CSS,
    theme=steel_blue_theme,
    show_error=True,
    share=True,
    ssr_mode=False,
    mcp_server=True,
)