# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# conda activate hf3.10
import json
import gc
import os
import shutil
import sys
import time
from datetime import datetime
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from pillow_heif import register_heif_opener
register_heif_opener()
sys.path.append("mapanything/")
from mapanything.utils.geometry import depthmap_to_world_frame, points_to_normals
from mapanything.utils.hf_utils.css_and_html import (
GRADIO_CSS,
MEASURE_INSTRUCTIONS_HTML,
get_acknowledgements_html,
get_description_html,
get_gradio_theme,
get_header_html,
)
from mapanything.utils.hf_utils.hf_helpers import initialize_mapanything_model
from mapanything.utils.hf_utils.visual_util import predictions_to_glb
from mapanything.utils.image import load_images, rgb
def get_logo_base64():
"""Convert WAI logo to base64 for embedding in HTML"""
import base64
logo_path = "examples/WAI-Logo/wai_logo.png"
try:
with open(logo_path, "rb") as img_file:
img_data = img_file.read()
base64_str = base64.b64encode(img_data).decode()
return f"data:image/png;base64,{base64_str}"
except FileNotFoundError:
return None
def get_camera_poses_json(predictions):
"""Convert camera poses to JSON-serializable format
NOTE: In Map Anything, 'extrinsic' is actually a camera-to-world transform,
NOT world-to-camera as is common in OpenCV. This is confirmed by geometry.py
line 98 which uses: pts3d_world = einsum(camera_pose, pts3d_cam_homo, ...)
"""
cameras = []
if "extrinsic" in predictions and "intrinsic" in predictions:
extrinsic = predictions["extrinsic"] # Shape: (S, 4, 4) - camera-to-world!
intrinsic = predictions["intrinsic"] # Shape: (S, 3, 3)
for i in range(len(extrinsic)):
ext = extrinsic[i]
intr = intrinsic[i]
# Extract camera position (translation part of camera-to-world)
# Since extrinsic IS camera-to-world, position is just the translation column
cam_pos = ext[:3, 3].tolist()
# Extract camera forward direction (Z axis in world coords)
# The rotation part of camera-to-world transforms camera Z to world
cam_forward = ext[:3, 2].tolist() # Third column of rotation
# Calculate FOV from intrinsics
fx = float(intr[0, 0])
fy = float(intr[1, 1])
cx = float(intr[0, 2])
cy = float(intr[1, 2])
            # Estimate image width/height from the principal point (assumes cx, cy lie at the image center)
width = int(cx * 2)
height = int(cy * 2)
fov_x = float(2 * np.arctan(width / (2 * fx)) * 180 / np.pi)
fov_y = float(2 * np.arctan(height / (2 * fy)) * 180 / np.pi)
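            # Illustrative check: with fx equal to the image width, fov_x = 2 * atan(0.5) ≈ 53.13 degrees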
cameras.append({
"index": i,
"cameraName": f"cam_{i}",
"extrinsic": ext.tolist(), # camera-to-world transform
"intrinsic": intr.tolist(),
"position": cam_pos,
"forward": cam_forward,
"fov": fov_y, # Vertical FOV
"fov_x": fov_x,
"width": width,
"height": height,
"fx": fx,
"fy": fy,
"cx": cx,
"cy": cy,
"note": "extrinsic is camera-to-world (not world-to-camera)"
})
return json.dumps(cameras)
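# Illustrative shape of one entry in the JSON list returned above (the values below are
# made up, not real output):
# {"index": 0, "cameraName": "cam_0", "position": [0.0, 0.0, 0.0],
#  "forward": [0.0, 0.0, 1.0], "fov": 45.0, "fov_x": 60.0, "width": 518, "height": 392, ...}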
# MapAnything Configuration
high_level_config = {
"path": "configs/train.yaml",
"hf_model_name": "facebook/map-anything",
"model_str": "mapanything",
"config_overrides": [
"machine=aws",
"model=mapanything",
"model/task=images_only",
"model.encoder.uses_torch_hub=false",
],
"checkpoint_name": "model.safetensors",
"config_name": "config.json",
"trained_with_amp": True,
"trained_with_amp_dtype": "bf16",
"data_norm_type": "dinov2",
"patch_size": 14,
"resolution": 518,
}
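# Note: the inference resolution (518) is an exact multiple of the ViT patch size (14): 518 = 37 * 14.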
# Initialize model - this will be done on GPU when needed
model = None
# -------------------------------------------------------------------------
# 1) Core model inference
# -------------------------------------------------------------------------
@spaces.GPU(duration=120)
def run_model(
target_dir,
apply_mask=True,
mask_edges=True,
filter_black_bg=False,
filter_white_bg=False,
):
"""
Run the MapAnything model on images in the 'target_dir/images' folder and return predictions.
"""
global model
import torch # Ensure torch is available in function scope
print(f"Processing images from {target_dir}")
# Device check
device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device)
# Initialize model if not already done
if model is None:
model = initialize_mapanything_model(high_level_config, device)
else:
model = model.to(device)
model.eval()
# Load images using MapAnything's load_images function
print("Loading images...")
image_folder_path = os.path.join(target_dir, "images")
views = load_images(image_folder_path)
print(f"Loaded {len(views)} images")
if len(views) == 0:
raise ValueError("No images found. Check your upload.")
# Run model inference
print("Running inference...")
# apply_mask: Whether to apply the non-ambiguous mask to the output. Defaults to True.
# mask_edges: Whether to compute an edge mask based on normals and depth and apply it to the output. Defaults to True.
    # Use the checkbox values; mask_edges defaults to True since there is no UI control for it
    outputs = model.infer(
        views, apply_mask=apply_mask, mask_edges=mask_edges, memory_efficient_inference=False
)
# Convert predictions to format expected by visualization
predictions = {}
# Initialize lists for the required keys
extrinsic_list = []
intrinsic_list = []
world_points_list = []
depth_maps_list = []
images_list = []
final_mask_list = []
# Loop through the outputs
for pred in outputs:
# Extract data from predictions
depthmap_torch = pred["depth_z"][0].squeeze(-1) # (H, W)
intrinsics_torch = pred["intrinsics"][0] # (3, 3)
camera_pose_torch = pred["camera_poses"][0] # (4, 4)
# Compute new pts3d using depth, intrinsics, and camera pose
pts3d_computed, valid_mask = depthmap_to_world_frame(
depthmap_torch, intrinsics_torch, camera_pose_torch
)
# Convert to numpy arrays for visualization
# Check if mask key exists in pred, if not, fill with boolean trues in the size of depthmap_torch
if "mask" in pred:
mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool)
else:
# Fill with boolean trues in the size of depthmap_torch
mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool)
# Combine with valid depth mask
mask = mask & valid_mask.cpu().numpy()
image = pred["img_no_norm"][0].cpu().numpy()
# Append to lists
extrinsic_list.append(camera_pose_torch.cpu().numpy())
intrinsic_list.append(intrinsics_torch.cpu().numpy())
world_points_list.append(pts3d_computed.cpu().numpy())
depth_maps_list.append(depthmap_torch.cpu().numpy())
images_list.append(image) # Add image to list
final_mask_list.append(mask) # Add final_mask to list
# Convert lists to numpy arrays with required shapes
    # extrinsic: (S, 4, 4) - batch of camera-to-world transform matrices
predictions["extrinsic"] = np.stack(extrinsic_list, axis=0)
# intrinsic: (S, 3, 3) - batch of camera intrinsic matrices
predictions["intrinsic"] = np.stack(intrinsic_list, axis=0)
# world_points: (S, H, W, 3) - batch of 3D world points
predictions["world_points"] = np.stack(world_points_list, axis=0)
# depth: (S, H, W, 1) or (S, H, W) - batch of depth maps
depth_maps = np.stack(depth_maps_list, axis=0)
# Add channel dimension if needed to match (S, H, W, 1) format
if len(depth_maps.shape) == 3:
depth_maps = depth_maps[..., np.newaxis]
predictions["depth"] = depth_maps
# images: (S, H, W, 3) - batch of input images
predictions["images"] = np.stack(images_list, axis=0)
# final_mask: (S, H, W) - batch of final masks for filtering
predictions["final_mask"] = np.stack(final_mask_list, axis=0)
# Process data for visualization tabs (depth, normal, measure)
processed_data = process_predictions_for_visualization(
predictions, views, high_level_config, filter_black_bg, filter_white_bg
)
# Clean up
torch.cuda.empty_cache()
return predictions, processed_data
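# For downstream consumers of run_model: the returned `predictions` dict holds
# "extrinsic" (S, 4, 4) camera-to-world poses, "intrinsic" (S, 3, 3),
# "world_points" (S, H, W, 3), "depth" (S, H, W, 1), "images" (S, H, W, 3), and
# "final_mask" (S, H, W), all as numpy arrays, where S is the number of views.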
def update_view_selectors(processed_data):
"""Update view selector dropdowns based on available views"""
if processed_data is None or len(processed_data) == 0:
choices = ["View 1"]
else:
num_views = len(processed_data)
choices = [f"View {i + 1}" for i in range(num_views)]
return (
gr.Dropdown(choices=choices, value=choices[0]), # depth_view_selector
gr.Dropdown(choices=choices, value=choices[0]), # normal_view_selector
gr.Dropdown(choices=choices, value=choices[0]), # measure_view_selector
)
def get_view_data_by_index(processed_data, view_index):
"""Get view data by index, handling bounds"""
if processed_data is None or len(processed_data) == 0:
return None
view_keys = list(processed_data.keys())
if view_index < 0 or view_index >= len(view_keys):
view_index = 0
return processed_data[view_keys[view_index]]
def update_depth_view(processed_data, view_index):
"""Update depth view for a specific view index"""
view_data = get_view_data_by_index(processed_data, view_index)
if view_data is None or view_data["depth"] is None:
return None
return colorize_depth(view_data["depth"], mask=view_data.get("mask"))
def update_normal_view(processed_data, view_index):
"""Update normal view for a specific view index"""
view_data = get_view_data_by_index(processed_data, view_index)
if view_data is None or view_data["normal"] is None:
return None
return colorize_normal(view_data["normal"], mask=view_data.get("mask"))
def update_measure_view(processed_data, view_index):
"""Update measure view for a specific view index with mask overlay"""
view_data = get_view_data_by_index(processed_data, view_index)
if view_data is None:
return None, [] # image, measure_points
# Get the base image
image = view_data["image"].copy()
# Ensure image is in uint8 format
if image.dtype != np.uint8:
if image.max() <= 1.0:
image = (image * 255).astype(np.uint8)
else:
image = image.astype(np.uint8)
# Apply mask overlay if mask is available
if view_data["mask"] is not None:
mask = view_data["mask"]
        # Create a light red overlay for masked areas
        # Masked areas (False values) will be tinted light red
        invalid_mask = ~mask  # Areas where mask is False
        if invalid_mask.any():
            # Light red overlay color (RGB: 255, 220, 220)
            overlay_color = np.array([255, 220, 220], dtype=np.uint8)
# Apply overlay with some transparency
alpha = 0.5 # Transparency level
for c in range(3): # RGB channels
image[:, :, c] = np.where(
invalid_mask,
(1 - alpha) * image[:, :, c] + alpha * overlay_color[c],
image[:, :, c],
).astype(np.uint8)
return image, []
def navigate_depth_view(processed_data, current_selector_value, direction):
"""Navigate depth view (direction: -1 for previous, +1 for next)"""
if processed_data is None or len(processed_data) == 0:
return "View 1", None
# Parse current view number
try:
current_view = int(current_selector_value.split()[1]) - 1
    except (AttributeError, IndexError, ValueError):
current_view = 0
num_views = len(processed_data)
new_view = (current_view + direction) % num_views
new_selector_value = f"View {new_view + 1}"
depth_vis = update_depth_view(processed_data, new_view)
return new_selector_value, depth_vis
def navigate_normal_view(processed_data, current_selector_value, direction):
"""Navigate normal view (direction: -1 for previous, +1 for next)"""
if processed_data is None or len(processed_data) == 0:
return "View 1", None
# Parse current view number
try:
current_view = int(current_selector_value.split()[1]) - 1
    except (AttributeError, IndexError, ValueError):
current_view = 0
num_views = len(processed_data)
new_view = (current_view + direction) % num_views
new_selector_value = f"View {new_view + 1}"
normal_vis = update_normal_view(processed_data, new_view)
return new_selector_value, normal_vis
def navigate_measure_view(processed_data, current_selector_value, direction):
"""Navigate measure view (direction: -1 for previous, +1 for next)"""
if processed_data is None or len(processed_data) == 0:
return "View 1", None, []
# Parse current view number
try:
current_view = int(current_selector_value.split()[1]) - 1
except:
current_view = 0
num_views = len(processed_data)
new_view = (current_view + direction) % num_views
new_selector_value = f"View {new_view + 1}"
measure_image, measure_points = update_measure_view(processed_data, new_view)
return new_selector_value, measure_image, measure_points
def populate_visualization_tabs(processed_data):
"""Populate the depth, normal, and measure tabs with processed data"""
if processed_data is None or len(processed_data) == 0:
return None, None, None, []
# Use update functions to ensure confidence filtering is applied from the start
depth_vis = update_depth_view(processed_data, 0)
normal_vis = update_normal_view(processed_data, 0)
measure_img, _ = update_measure_view(processed_data, 0)
return depth_vis, normal_vis, measure_img, []
# -------------------------------------------------------------------------
# 2) Handle uploaded video/images --> produce target_dir + images
# -------------------------------------------------------------------------
def handle_uploads(unified_upload, s_time_interval=1.0):
"""
Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
images or extracted frames from video into it. Return (target_dir, image_paths).
"""
start_time = time.time()
gc.collect()
torch.cuda.empty_cache()
# Create a unique folder name
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
target_dir = f"input_images_{timestamp}"
target_dir_images = os.path.join(target_dir, "images")
# Clean up if somehow that folder already exists
if os.path.exists(target_dir):
shutil.rmtree(target_dir)
os.makedirs(target_dir)
os.makedirs(target_dir_images)
image_paths = []
# --- Handle uploaded files (both images and videos) ---
if unified_upload is not None:
for file_data in unified_upload:
if isinstance(file_data, dict) and "name" in file_data:
file_path = file_data["name"]
else:
file_path = str(file_data)
file_ext = os.path.splitext(file_path)[1].lower()
# Check if it's a video file
video_extensions = [
".mp4",
".avi",
".mov",
".mkv",
".wmv",
".flv",
".webm",
".m4v",
".3gp",
]
if file_ext in video_extensions:
# Handle as video
vs = cv2.VideoCapture(file_path)
fps = vs.get(cv2.CAP_PROP_FPS)
                frame_interval = max(1, int(fps * s_time_interval))  # frames per interval; guard against zero
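                # e.g. a 30 fps video sampled with s_time_interval=1.0 keeps every 30th frame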
count = 0
video_frame_num = 0
while True:
gotit, frame = vs.read()
if not gotit:
break
count += 1
if count % frame_interval == 0:
# Use original filename as prefix for frames
base_name = os.path.splitext(os.path.basename(file_path))[0]
image_path = os.path.join(
target_dir_images, f"{base_name}_{video_frame_num:06}.png"
)
cv2.imwrite(image_path, frame)
image_paths.append(image_path)
video_frame_num += 1
vs.release()
print(
f"Extracted {video_frame_num} frames from video: {os.path.basename(file_path)}"
)
else:
# Handle as image
# Check if the file is a HEIC image
if file_ext in [".heic", ".heif"]:
# Convert HEIC to JPEG for better gallery compatibility
try:
with Image.open(file_path) as img:
# Convert to RGB if necessary (HEIC can have different color modes)
if img.mode not in ("RGB", "L"):
img = img.convert("RGB")
# Create JPEG filename
base_name = os.path.splitext(os.path.basename(file_path))[0]
dst_path = os.path.join(
target_dir_images, f"{base_name}.jpg"
)
# Save as JPEG with high quality
img.save(dst_path, "JPEG", quality=95)
image_paths.append(dst_path)
print(
f"Converted HEIC to JPEG: {os.path.basename(file_path)} -> {os.path.basename(dst_path)}"
)
except Exception as e:
print(f"Error converting HEIC file {file_path}: {e}")
# Fall back to copying as is
dst_path = os.path.join(
target_dir_images, os.path.basename(file_path)
)
shutil.copy(file_path, dst_path)
image_paths.append(dst_path)
else:
# Regular image files - copy as is
dst_path = os.path.join(
target_dir_images, os.path.basename(file_path)
)
shutil.copy(file_path, dst_path)
image_paths.append(dst_path)
# Sort final images for gallery
image_paths = sorted(image_paths)
end_time = time.time()
print(
f"Files processed to {target_dir_images}; took {end_time - start_time:.3f} seconds"
)
return target_dir, image_paths
# -------------------------------------------------------------------------
# 3) Update gallery on upload
# -------------------------------------------------------------------------
def update_gallery_on_upload(unified_upload, s_time_interval=1.0):
    """
    Whenever the user uploads or changes files, immediately handle them
    and show them in the gallery. Returns (viewer, target_dir, image_paths, log_message).
    If nothing is uploaded, returns all Nones.
    """
    if not unified_upload:
        return None, None, None, None
    target_dir, image_paths = handle_uploads(unified_upload, s_time_interval)
return (
None,
target_dir,
image_paths,
"Upload complete. Click 'Reconstruct' to begin 3D processing.",
)
# -------------------------------------------------------------------------
# 4) Reconstruction: uses the target_dir plus any viz parameters
# -------------------------------------------------------------------------
@spaces.GPU(duration=120)
def gradio_demo(
target_dir,
frame_filter="All",
show_cam=True,
filter_black_bg=False,
filter_white_bg=False,
apply_mask=True,
show_mesh=True,
):
"""
Perform reconstruction using the already-created target_dir/images.
"""
if not os.path.isdir(target_dir) or target_dir == "None":
return None, "No valid target directory found. Please upload first.", None, None
start_time = time.time()
gc.collect()
torch.cuda.empty_cache()
# Prepare frame_filter dropdown
target_dir_images = os.path.join(target_dir, "images")
all_files = (
sorted(os.listdir(target_dir_images))
if os.path.isdir(target_dir_images)
else []
)
all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
frame_filter_choices = ["All"] + all_files
print("Running MapAnything model...")
with torch.no_grad():
        predictions, processed_data = run_model(
            target_dir,
            apply_mask,
            filter_black_bg=filter_black_bg,
            filter_white_bg=filter_white_bg,
        )
# Save predictions
prediction_save_path = os.path.join(target_dir, "predictions.npz")
np.savez(prediction_save_path, **predictions)
# Handle None frame_filter
if frame_filter is None:
frame_filter = "All"
# Build a GLB file name
glbfile = os.path.join(
target_dir,
f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
)
# Convert predictions to GLB
glbscene = predictions_to_glb(
predictions,
filter_by_frames=frame_filter,
show_cam=show_cam,
mask_black_bg=filter_black_bg,
mask_white_bg=filter_white_bg,
as_mesh=show_mesh, # Use the show_mesh parameter
)
glbscene.export(file_obj=glbfile)
    # Get camera poses JSON before cleanup deletes predictions
camera_poses_json = get_camera_poses_json(predictions)
# Cleanup
del predictions
gc.collect()
torch.cuda.empty_cache()
end_time = time.time()
print(f"Total time: {end_time - start_time:.2f} seconds")
log_msg = (
f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
)
# Populate visualization tabs with processed data
depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(
processed_data
)
# Update view selectors based on available views
depth_selector, normal_selector, measure_selector = update_view_selectors(
processed_data
)
return (
glbfile,
log_msg,
gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
processed_data,
depth_vis,
normal_vis,
measure_img,
"", # measure_text (empty initially)
depth_selector,
normal_selector,
measure_selector,
        camera_poses_json,  # camera poses JSON for API access
)
def get_predictions_file(target_dir):
"""
Return the predictions.npz file path for API download.
This allows external clients to download the full reconstruction data
including world_points, depth, extrinsics, and intrinsics.
Args:
target_dir: Directory containing predictions.npz from reconstruction
Returns:
Path to predictions.npz file, or None if not found
"""
if not target_dir or not os.path.isdir(target_dir):
return None
predictions_path = os.path.join(target_dir, "predictions.npz")
if os.path.exists(predictions_path):
return predictions_path
return None
# -------------------------------------------------------------------------
# 5) Helper functions for UI resets + re-visualization
# -------------------------------------------------------------------------
def colorize_depth(depth_map, mask=None):
"""Convert depth map to colorized visualization with optional mask"""
if depth_map is None:
return None
# Normalize depth to 0-1 range
depth_normalized = depth_map.copy()
valid_mask = depth_normalized > 0
# Apply additional mask if provided (for background filtering)
if mask is not None:
valid_mask = valid_mask & mask
    if valid_mask.sum() > 0:
        valid_depths = depth_normalized[valid_mask]
        p5 = np.percentile(valid_depths, 5)
        p95 = np.percentile(valid_depths, 95)
        denom = max(p95 - p5, 1e-8)  # guard against constant-depth maps (p95 == p5)
        depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / denom
# Apply colormap
import matplotlib.pyplot as plt
colormap = plt.cm.turbo_r
colored = colormap(depth_normalized)
colored = (colored[:, :, :3] * 255).astype(np.uint8)
# Set invalid pixels to white
colored[~valid_mask] = [255, 255, 255]
return colored
def colorize_normal(normal_map, mask=None):
"""Convert normal map to colorized visualization with optional mask"""
if normal_map is None:
return None
# Create a copy for modification
normal_vis = normal_map.copy()
# Apply mask if provided (set masked areas to [0, 0, 0] which becomes grey after normalization)
if mask is not None:
invalid_mask = ~mask
normal_vis[invalid_mask] = [0, 0, 0] # Set invalid areas to zero
# Normalize normals to [0, 1] range for visualization
normal_vis = (normal_vis + 1.0) / 2.0
normal_vis = (normal_vis * 255).astype(np.uint8)
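    # e.g. a normal of (0, 0, 1) maps to RGB ≈ (127, 127, 255)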
return normal_vis
def process_predictions_for_visualization(
predictions, views, high_level_config, filter_black_bg=False, filter_white_bg=False
):
"""Extract depth, normal, and 3D points from predictions for visualization"""
processed_data = {}
# Process each view
for view_idx, view in enumerate(views):
# Get image
image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
# Get predicted points
pred_pts3d = predictions["world_points"][view_idx]
# Initialize data for this view
view_data = {
"image": image[0],
"points3d": pred_pts3d,
"depth": None,
"normal": None,
"mask": None,
}
# Start with the final mask from predictions
mask = predictions["final_mask"][view_idx].copy()
# Apply black background filtering if enabled
if filter_black_bg:
# Get the image colors (ensure they're in 0-255 range)
view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
# Filter out black background pixels (sum of RGB < 16)
black_bg_mask = view_colors.sum(axis=2) >= 16
mask = mask & black_bg_mask
# Apply white background filtering if enabled
if filter_white_bg:
# Get the image colors (ensure they're in 0-255 range)
view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
# Filter out white background pixels (all RGB > 240)
white_bg_mask = ~(
(view_colors[:, :, 0] > 240)
& (view_colors[:, :, 1] > 240)
& (view_colors[:, :, 2] > 240)
)
mask = mask & white_bg_mask
view_data["mask"] = mask
view_data["depth"] = predictions["depth"][view_idx].squeeze()
normals, _ = points_to_normals(pred_pts3d, mask=view_data["mask"])
view_data["normal"] = normals
processed_data[view_idx] = view_data
return processed_data
def reset_measure(processed_data):
"""Reset measure points"""
if processed_data is None or len(processed_data) == 0:
return None, [], ""
# Return the first view image
first_view = list(processed_data.values())[0]
return first_view["image"], [], ""
def measure(
processed_data, measure_points, current_view_selector, event: gr.SelectData
):
"""Handle measurement on images"""
try:
print(f"Measure function called with selector: {current_view_selector}")
if processed_data is None or len(processed_data) == 0:
return None, [], "No data available"
# Use the currently selected view instead of always using the first view
try:
current_view_index = int(current_view_selector.split()[1]) - 1
        except (AttributeError, IndexError, ValueError):
current_view_index = 0
print(f"Using view index: {current_view_index}")
# Get view data safely
if current_view_index < 0 or current_view_index >= len(processed_data):
current_view_index = 0
view_keys = list(processed_data.keys())
current_view = processed_data[view_keys[current_view_index]]
if current_view is None:
return None, [], "No view data available"
point2d = event.index[0], event.index[1]
print(f"Clicked point: {point2d}")
# Check if the clicked point is in a masked area (prevent interaction)
if (
current_view["mask"] is not None
and 0 <= point2d[1] < current_view["mask"].shape[0]
and 0 <= point2d[0] < current_view["mask"].shape[1]
):
# Check if the point is in a masked (invalid) area
if not current_view["mask"][point2d[1], point2d[0]]:
print(f"Clicked point {point2d} is in masked area, ignoring click")
# Always return image with mask overlay
masked_image, _ = update_measure_view(
processed_data, current_view_index
)
return (
masked_image,
measure_points,
                    '<span style="color: red; font-weight: bold;">Cannot measure on masked areas (shown in light red)</span>',
)
measure_points.append(point2d)
# Get image with mask overlay and ensure it's valid
image, _ = update_measure_view(processed_data, current_view_index)
if image is None:
return None, [], "No image available"
image = image.copy()
points3d = current_view["points3d"]
# Ensure image is in uint8 format for proper cv2 operations
try:
if image.dtype != np.uint8:
if image.max() <= 1.0:
# Image is in [0, 1] range, convert to [0, 255]
image = (image * 255).astype(np.uint8)
else:
# Image is already in [0, 255] range
image = image.astype(np.uint8)
except Exception as e:
print(f"Image conversion error: {e}")
return None, [], f"Image conversion error: {e}"
# Draw circles for points
try:
for p in measure_points:
if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
image = cv2.circle(
image, p, radius=5, color=(255, 0, 0), thickness=2
)
except Exception as e:
print(f"Drawing error: {e}")
return None, [], f"Drawing error: {e}"
depth_text = ""
try:
for i, p in enumerate(measure_points):
if (
current_view["depth"] is not None
and 0 <= p[1] < current_view["depth"].shape[0]
and 0 <= p[0] < current_view["depth"].shape[1]
):
d = current_view["depth"][p[1], p[0]]
depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n"
else:
# Use Z coordinate of 3D points if depth not available
if (
points3d is not None
and 0 <= p[1] < points3d.shape[0]
and 0 <= p[0] < points3d.shape[1]
):
z = points3d[p[1], p[0], 2]
depth_text += f"- **P{i + 1} Z-coord: {z:.2f}m.**\n"
except Exception as e:
print(f"Depth text error: {e}")
depth_text = f"Error computing depth: {e}\n"
if len(measure_points) == 2:
try:
point1, point2 = measure_points
# Draw line
if (
0 <= point1[0] < image.shape[1]
and 0 <= point1[1] < image.shape[0]
and 0 <= point2[0] < image.shape[1]
and 0 <= point2[1] < image.shape[0]
):
image = cv2.line(
image, point1, point2, color=(255, 0, 0), thickness=2
)
# Compute 3D distance
distance_text = "- **Distance: Unable to compute**"
if (
points3d is not None
and 0 <= point1[1] < points3d.shape[0]
and 0 <= point1[0] < points3d.shape[1]
and 0 <= point2[1] < points3d.shape[0]
and 0 <= point2[0] < points3d.shape[1]
):
try:
p1_3d = points3d[point1[1], point1[0]]
p2_3d = points3d[point2[1], point2[0]]
distance = np.linalg.norm(p1_3d - p2_3d)
distance_text = f"- **Distance: {distance:.2f}m**"
except Exception as e:
print(f"Distance computation error: {e}")
distance_text = f"- **Distance computation error: {e}**"
measure_points = []
text = depth_text + distance_text
print(f"Measurement complete: {text}")
return [image, measure_points, text]
except Exception as e:
print(f"Final measurement error: {e}")
return None, [], f"Measurement error: {e}"
else:
print(f"Single point measurement: {depth_text}")
return [image, measure_points, depth_text]
except Exception as e:
print(f"Overall measure function error: {e}")
return None, [], f"Measure function error: {e}"
def clear_fields():
"""
Clears the 3D viewer, the stored target_dir, and empties the gallery.
"""
return None
def update_log():
"""
Display a quick log message while waiting.
"""
return "Loading and Reconstructing..."
def update_visualization(
target_dir,
frame_filter,
show_cam,
is_example,
filter_black_bg=False,
filter_white_bg=False,
show_mesh=True,
):
"""
Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
and return it for the 3D viewer. If is_example == "True", skip.
"""
# If it's an example click, skip as requested
if is_example == "True":
return (
gr.update(),
"No reconstruction available. Please click the Reconstruct button first.",
)
if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
return (
gr.update(),
"No reconstruction available. Please click the Reconstruct button first.",
)
predictions_path = os.path.join(target_dir, "predictions.npz")
if not os.path.exists(predictions_path):
return (
gr.update(),
f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
)
loaded = np.load(predictions_path, allow_pickle=True)
predictions = {key: loaded[key] for key in loaded.keys()}
glbfile = os.path.join(
target_dir,
f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
)
if not os.path.exists(glbfile):
glbscene = predictions_to_glb(
predictions,
filter_by_frames=frame_filter,
show_cam=show_cam,
mask_black_bg=filter_black_bg,
mask_white_bg=filter_white_bg,
as_mesh=show_mesh,
)
glbscene.export(file_obj=glbfile)
return (
glbfile,
"Visualization updated.",
)
def update_all_views_on_filter_change(
target_dir,
filter_black_bg,
filter_white_bg,
processed_data,
depth_view_selector,
normal_view_selector,
measure_view_selector,
):
"""
Update all individual view tabs when background filtering checkboxes change.
This regenerates the processed data with new filtering and updates all views.
"""
# Check if we have a valid target directory and predictions
if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
return processed_data, None, None, None, []
predictions_path = os.path.join(target_dir, "predictions.npz")
if not os.path.exists(predictions_path):
return processed_data, None, None, None, []
try:
# Load the original predictions and views
loaded = np.load(predictions_path, allow_pickle=True)
predictions = {key: loaded[key] for key in loaded.keys()}
# Load images using MapAnything's load_images function
image_folder_path = os.path.join(target_dir, "images")
views = load_images(image_folder_path)
# Regenerate processed data with new filtering settings
new_processed_data = process_predictions_for_visualization(
predictions, views, high_level_config, filter_black_bg, filter_white_bg
)
# Get current view indices
try:
depth_view_idx = (
int(depth_view_selector.split()[1]) - 1 if depth_view_selector else 0
)
        except (AttributeError, IndexError, ValueError):
depth_view_idx = 0
try:
normal_view_idx = (
int(normal_view_selector.split()[1]) - 1 if normal_view_selector else 0
)
        except (AttributeError, IndexError, ValueError):
normal_view_idx = 0
try:
measure_view_idx = (
int(measure_view_selector.split()[1]) - 1
if measure_view_selector
else 0
)
        except (AttributeError, IndexError, ValueError):
measure_view_idx = 0
# Update all views with new filtered data
depth_vis = update_depth_view(new_processed_data, depth_view_idx)
normal_vis = update_normal_view(new_processed_data, normal_view_idx)
measure_img, _ = update_measure_view(new_processed_data, measure_view_idx)
return new_processed_data, depth_vis, normal_vis, measure_img, []
except Exception as e:
print(f"Error updating views on filter change: {e}")
return processed_data, None, None, None, []
# -------------------------------------------------------------------------
# Example scene functions
# -------------------------------------------------------------------------
def get_scene_info(examples_dir):
"""Get information about scenes in the examples directory"""
import glob
scenes = []
if not os.path.exists(examples_dir):
return scenes
for scene_folder in sorted(os.listdir(examples_dir)):
scene_path = os.path.join(examples_dir, scene_folder)
if os.path.isdir(scene_path):
# Find all image files in the scene folder
image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]
image_files = []
for ext in image_extensions:
image_files.extend(glob.glob(os.path.join(scene_path, ext)))
image_files.extend(glob.glob(os.path.join(scene_path, ext.upper())))
if image_files:
# Sort images and get the first one for thumbnail
image_files = sorted(image_files)
first_image = image_files[0]
num_images = len(image_files)
scenes.append(
{
"name": scene_folder,
"path": scene_path,
"thumbnail": first_image,
"num_images": num_images,
"image_files": image_files,
}
)
return scenes
def load_example_scene(scene_name, examples_dir="examples"):
"""Load a scene from examples directory"""
scenes = get_scene_info(examples_dir)
# Find the selected scene
selected_scene = None
for scene in scenes:
if scene["name"] == scene_name:
selected_scene = scene
break
if selected_scene is None:
return None, None, None, "Scene not found"
# Create file-like objects for the unified upload system
# Convert image file paths to the format expected by unified_upload
file_objects = []
for image_path in selected_scene["image_files"]:
file_objects.append(image_path)
# Create target directory and copy images using the unified upload system
target_dir, image_paths = handle_uploads(file_objects, 1.0)
return (
None, # Clear reconstruction output
target_dir, # Set target directory
image_paths, # Set gallery
f"Loaded scene '{scene_name}' with {selected_scene['num_images']} images. Click 'Reconstruct' to begin 3D processing.",
)
# -------------------------------------------------------------------------
# 6) Mask Projection Functions (for multi-angle virtual staging)
# -------------------------------------------------------------------------
def project_mask_to_cameras(
mask_image: np.ndarray,
source_camera_index: int,
target_dir: str,
dilate_radius: int = 5,
flip_y: bool = False
) -> dict:
"""
Project a furniture mask from one camera view to all other camera views.
This uses the raw depth data and camera matrices from the reconstruction
to accurately project masked pixels through 3D space.
Args:
mask_image: Binary mask (H, W) where 255 = furniture, 0 = background
source_camera_index: Index of the camera that captured the staged image
target_dir: Directory containing predictions.npz from reconstruction
dilate_radius: Radius for dilating projected masks to fill gaps
Returns:
dict with:
- 'projected_masks': {camera_index: base64_png_mask}
- 'source_camera': source camera index
- 'num_points_projected': number of 3D points
- 'debug_info': debugging information
"""
import base64
from io import BytesIO
from PIL import Image
# Load predictions
predictions_path = os.path.join(target_dir, "predictions.npz")
if not os.path.exists(predictions_path):
raise ValueError(f"No predictions found at {predictions_path}")
loaded = np.load(predictions_path, allow_pickle=True)
predictions = {key: loaded[key] for key in loaded.keys()}
depth_maps = predictions["depth"] # (S, H, W, 1) or (S, H, W)
extrinsics = predictions["extrinsic"] # (S, 4, 4)
intrinsics = predictions["intrinsic"] # (S, 3, 3)
num_cameras = len(extrinsics)
if source_camera_index < 0 or source_camera_index >= num_cameras:
raise ValueError(f"Invalid source camera index: {source_camera_index}")
# Get source camera data
source_depth = depth_maps[source_camera_index].squeeze() # (H, W)
source_intrinsic = intrinsics[source_camera_index] # (3, 3)
source_extrinsic = extrinsics[source_camera_index] # (4, 4)
H, W = source_depth.shape
# Resize mask to match depth map size if needed
if mask_image.shape[:2] != (H, W):
mask_resized = cv2.resize(mask_image, (W, H), interpolation=cv2.INTER_NEAREST)
else:
mask_resized = mask_image
# Ensure binary mask
if mask_resized.ndim == 3:
mask_resized = cv2.cvtColor(mask_resized, cv2.COLOR_RGB2GRAY)
mask_binary = (mask_resized > 127).astype(np.uint8)
# Get masked pixel coordinates
ys, xs = np.where(mask_binary > 0)
depths = source_depth[ys, xs]
# Filter out invalid depths
valid_mask = depths > 0.01 # Minimum 1cm depth
ys = ys[valid_mask]
xs = xs[valid_mask]
depths = depths[valid_mask]
if len(depths) == 0:
return {
'projected_masks': {},
'source_camera': source_camera_index,
'num_points_projected': 0,
'debug_info': 'No valid depth values in masked region'
}
print(f"[Project Mask] Projecting {len(depths)} points from camera {source_camera_index}")
    # Debug: Print camera info. Extrinsics are camera-to-world here, so the camera
    # position is simply the translation column and the forward axis is the third
    # rotation column (no inversion needed).
    for cam_idx in range(num_cameras):
        ext = extrinsics[cam_idx]
        cam_pos = ext[:3, 3]
        cam_forward = ext[:3, 2]
        print(f"[Project Mask] Camera {cam_idx}: pos=[{cam_pos[0]:.2f}, {cam_pos[1]:.2f}, {cam_pos[2]:.2f}], forward=[{cam_forward[0]:.2f}, {cam_forward[1]:.2f}, {cam_forward[2]:.2f}]")
# Unproject to 3D camera coordinates
fx = source_intrinsic[0, 0]
fy = source_intrinsic[1, 1]
cx = source_intrinsic[0, 2]
cy = source_intrinsic[1, 2]
# Camera coordinates (OpenCV convention: X-right, Y-down, Z-forward)
X_cam = (xs - cx) * depths / fx
Y_cam = (ys - cy) * depths / fy
Z_cam = depths
# Stack as homogeneous coordinates (N, 4)
points_cam = np.stack([X_cam, Y_cam, Z_cam, np.ones_like(X_cam)], axis=1)
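    # Sanity check (illustrative): a masked pixel at the principal point (xs == cx,
    # ys == cy) unprojects to (0, 0, depth) in camera coordinates.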
# Transform to world coordinates
# NOTE: In Map Anything, "extrinsic" is actually camera-to-world (not world-to-camera!)
# See geometry.py line 98: pts3d_world = einsum(camera_pose, pts3d_cam_homo, ...)
# So we use it directly, NOT its inverse
cam_to_world = source_extrinsic # Already camera-to-world!
points_world = (cam_to_world @ points_cam.T).T[:, :3] # (N, 3)
print(f"[Project Mask] World points range: X=[{points_world[:,0].min():.2f}, {points_world[:,0].max():.2f}], "
f"Y=[{points_world[:,1].min():.2f}, {points_world[:,1].max():.2f}], "
f"Z=[{points_world[:,2].min():.2f}, {points_world[:,2].max():.2f}]")
# Project to each target camera
projected_masks = {}
debug_info = []
for target_idx in range(num_cameras):
if target_idx == source_camera_index:
# Skip source camera (or you could project to itself for verification)
continue
target_intrinsic = intrinsics[target_idx]
target_extrinsic = extrinsics[target_idx] # This is camera-to-world!
# Transform world points to target camera coordinates
# Since extrinsic is camera-to-world, we need its inverse for world-to-camera
world_to_target_cam = np.linalg.inv(target_extrinsic)
points_world_homo = np.hstack([points_world, np.ones((len(points_world), 1))])
points_target_cam = (world_to_target_cam @ points_world_homo.T).T[:, :3] # (N, 3)
# Debug: Print Z range before filtering
print(f"[Project Mask] Camera {target_idx}: Z range in camera coords = [{points_target_cam[:, 2].min():.3f}, {points_target_cam[:, 2].max():.3f}]")
# Filter points behind camera (Z <= 0) and very close points (cause projection blowup)
MIN_Z = 0.1 # Increased from 0.01 to avoid extreme projections
in_front = points_target_cam[:, 2] > MIN_Z
points_target_cam = points_target_cam[in_front]
if len(points_target_cam) == 0:
debug_info.append(f"Camera {target_idx}: All points behind camera")
continue
# Project to image plane
target_fx = target_intrinsic[0, 0]
target_fy = target_intrinsic[1, 1]
target_cx = target_intrinsic[0, 2]
target_cy = target_intrinsic[1, 2]
proj_x = (points_target_cam[:, 0] / points_target_cam[:, 2]) * target_fx + target_cx
proj_y = (points_target_cam[:, 1] / points_target_cam[:, 2]) * target_fy + target_cy
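        # This is the inverse of the unprojection above: u = X/Z * fx + cx, v = Y/Z * fy + cy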
# Create output mask
target_H, target_W = depth_maps[target_idx].squeeze().shape
# Flip Y if requested (for coordinate system mismatch)
if flip_y:
proj_y = target_H - proj_y
# Debug: Print projection stats
print(f"[Project Mask] Camera {target_idx}: proj_x range=[{proj_x.min():.1f}, {proj_x.max():.1f}], proj_y range=[{proj_y.min():.1f}, {proj_y.max():.1f}], image size={target_W}x{target_H}")
target_mask = np.zeros((target_H, target_W), dtype=np.uint8)
# Count in-bounds points
in_bounds = (proj_x >= 0) & (proj_x < target_W) & (proj_y >= 0) & (proj_y < target_H)
proj_x_valid = proj_x[in_bounds].astype(np.int32)
proj_y_valid = proj_y[in_bounds].astype(np.int32)
# Plot points on mask
target_mask[proj_y_valid, proj_x_valid] = 255
# Dilate to fill gaps (point cloud is sparse)
if dilate_radius > 0:
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilate_radius * 2 + 1, dilate_radius * 2 + 1))
target_mask = cv2.dilate(target_mask, kernel)
# Optional: erode slightly to tighten edges
erode_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilate_radius, dilate_radius))
target_mask = cv2.erode(target_mask, erode_kernel)
# Fill holes
contours, _ = cv2.findContours(target_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
target_mask_filled = np.zeros_like(target_mask)
        cv2.drawContours(target_mask_filled, contours, -1, 255, -1)  # contourIdx=-1 draws all contours; thickness=-1 fills them
# Convert to base64 PNG
pil_img = Image.fromarray(target_mask_filled)
buffer = BytesIO()
pil_img.save(buffer, format='PNG')
base64_mask = base64.b64encode(buffer.getvalue()).decode('utf-8')
projected_masks[target_idx] = base64_mask
coverage = (target_mask_filled > 0).sum() / (target_H * target_W) * 100
debug_info.append(f"Camera {target_idx}: {in_bounds.sum()}/{len(proj_x)} points in-bounds, {coverage:.1f}% coverage")
print(f"[Project Mask] Camera {target_idx}: {in_bounds.sum()} points projected, {coverage:.1f}% coverage")
return {
'projected_masks': projected_masks,
'source_camera': source_camera_index,
'num_points_projected': len(depths),
'debug_info': '\n'.join(debug_info)
}
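# Illustrative usage of project_mask_to_cameras (the directory name is a placeholder;
# it assumes a prior reconstruction already wrote predictions.npz there):
#   result = project_mask_to_cameras(mask, source_camera_index=0, target_dir="input_images_XXXX")
#   result["projected_masks"]  # {target_camera_index: base64-encoded PNG mask}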
def project_mask_depth_aware(
mask_image: np.ndarray,
source_camera_index: int,
target_dir: str,
dilate_radius: int = 7,
depth_threshold: float = 0.5,
flip_y: bool = False
) -> dict:
"""
Enhanced projection that uses target camera's depth to intelligently expand masks.
Only dilates in regions where depth is similar to projected furniture depth.
Args:
mask_image: Binary mask (H, W) where 255 = furniture, 0 = background
source_camera_index: Index of the camera that captured the staged image
target_dir: Directory containing predictions.npz from reconstruction
dilate_radius: Radius for dilating projected masks to fill gaps
depth_threshold: Maximum depth difference (meters) for mask expansion
flip_y: Flip Y axis if projection appears upside down
Returns:
dict with projected_masks, source_camera, num_points_projected, debug_info
"""
import base64
from io import BytesIO
from PIL import Image
# Load predictions
predictions_path = os.path.join(target_dir, "predictions.npz")
if not os.path.exists(predictions_path):
raise ValueError(f"No predictions found at {predictions_path}")
loaded = np.load(predictions_path, allow_pickle=True)
predictions = {key: loaded[key] for key in loaded.keys()}
depth_maps = predictions["depth"] # (S, H, W, 1) or (S, H, W)
extrinsics = predictions["extrinsic"] # (S, 4, 4)
intrinsics = predictions["intrinsic"] # (S, 3, 3)
num_cameras = len(extrinsics)
if source_camera_index < 0 or source_camera_index >= num_cameras:
raise ValueError(f"Invalid source camera index: {source_camera_index}")
# Get source camera data
source_depth = depth_maps[source_camera_index].squeeze() # (H, W)
source_intrinsic = intrinsics[source_camera_index] # (3, 3)
source_extrinsic = extrinsics[source_camera_index] # (4, 4)
H, W = source_depth.shape
# Resize mask to match depth map size if needed
if mask_image.shape[:2] != (H, W):
mask_resized = cv2.resize(mask_image, (W, H), interpolation=cv2.INTER_NEAREST)
else:
mask_resized = mask_image
# Ensure binary mask
if mask_resized.ndim == 3:
mask_resized = cv2.cvtColor(mask_resized, cv2.COLOR_RGB2GRAY)
mask_binary = (mask_resized > 127).astype(np.uint8)
# Get masked pixel coordinates
ys, xs = np.where(mask_binary > 0)
depths = source_depth[ys, xs]
# Filter out invalid depths
valid_mask = depths > 0.01 # Minimum 1cm depth
ys = ys[valid_mask]
xs = xs[valid_mask]
depths = depths[valid_mask]
if len(depths) == 0:
return {
'projected_masks': {},
'source_camera': source_camera_index,
'num_points_projected': 0,
'debug_info': 'No valid depth values in masked region'
}
print(f"[Depth-Aware Projection] Projecting {len(depths)} points from camera {source_camera_index}")
# Unproject to 3D camera coordinates
fx = source_intrinsic[0, 0]
fy = source_intrinsic[1, 1]
cx = source_intrinsic[0, 2]
cy = source_intrinsic[1, 2]
X_cam = (xs - cx) * depths / fx
Y_cam = (ys - cy) * depths / fy
Z_cam = depths
points_cam = np.stack([X_cam, Y_cam, Z_cam, np.ones_like(X_cam)], axis=1)
# Transform to world coordinates (extrinsic is camera-to-world)
cam_to_world = source_extrinsic
points_world = (cam_to_world @ points_cam.T).T[:, :3]
# Project to each target camera
projected_masks = {}
debug_info = []
for target_idx in range(num_cameras):
if target_idx == source_camera_index:
continue
target_intrinsic = intrinsics[target_idx]
target_extrinsic = extrinsics[target_idx]
target_depth = depth_maps[target_idx].squeeze()
# Transform world points to target camera coordinates
world_to_target_cam = np.linalg.inv(target_extrinsic)
points_world_homo = np.hstack([points_world, np.ones((len(points_world), 1))])
points_target_cam = (world_to_target_cam @ points_world_homo.T).T[:, :3]
# Filter points behind camera
MIN_Z = 0.1
in_front = points_target_cam[:, 2] > MIN_Z
points_target_cam_filtered = points_target_cam[in_front]
if len(points_target_cam_filtered) == 0:
debug_info.append(f"Camera {target_idx}: All points behind camera")
continue
# Project to image plane
target_fx = target_intrinsic[0, 0]
target_fy = target_intrinsic[1, 1]
target_cx = target_intrinsic[0, 2]
target_cy = target_intrinsic[1, 2]
proj_x = (points_target_cam_filtered[:, 0] / points_target_cam_filtered[:, 2]) * target_fx + target_cx
proj_y = (points_target_cam_filtered[:, 1] / points_target_cam_filtered[:, 2]) * target_fy + target_cy
target_H, target_W = target_depth.shape
if flip_y:
proj_y = target_H - proj_y
# Create sparse mask from projected points
target_mask = np.zeros((target_H, target_W), dtype=np.uint8)
in_bounds = (proj_x >= 0) & (proj_x < target_W) & (proj_y >= 0) & (proj_y < target_H)
proj_x_valid = proj_x[in_bounds].astype(np.int32)
proj_y_valid = proj_y[in_bounds].astype(np.int32)
target_mask[proj_y_valid, proj_x_valid] = 255
# DEPTH-AWARE DILATION
if dilate_radius > 0 and len(proj_x_valid) > 0:
# Get depth values at projected furniture points
furniture_depths = target_depth[proj_y_valid, proj_x_valid]
valid_furniture_depths = furniture_depths[furniture_depths > 0]
if len(valid_furniture_depths) > 0:
# Use median depth as reference (robust to outliers)
median_furniture_depth = np.median(valid_furniture_depths)
# Create depth-consistent mask: only pixels with similar depth
depth_diff = np.abs(target_depth - median_furniture_depth)
depth_consistent = (depth_diff < depth_threshold) & (target_depth > 0)
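                # e.g. with median_furniture_depth = 2.0 m and depth_threshold = 0.5 m,
                # only pixels whose depth lies within (1.5 m, 2.5 m) can receive dilated mask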
# Standard dilation
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilate_radius * 2 + 1, dilate_radius * 2 + 1))
dilated = cv2.dilate(target_mask, kernel)
# Constrain dilation to depth-consistent regions
target_mask = (dilated > 0) & depth_consistent
target_mask = (target_mask * 255).astype(np.uint8)
print(f"[Depth-Aware] Camera {target_idx}: median_depth={median_furniture_depth:.2f}m, "
f"depth_consistent_pixels={depth_consistent.sum()}, threshold={depth_threshold}m")
else:
# Fallback to standard dilation if no valid depths
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilate_radius * 2 + 1, dilate_radius * 2 + 1))
target_mask = cv2.dilate(target_mask, kernel)
# Fill contours
contours, _ = cv2.findContours(target_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
target_mask_filled = np.zeros_like(target_mask)
cv2.drawContours(target_mask_filled, contours, -1, 255, -1)
# Convert to base64 PNG
pil_img = Image.fromarray(target_mask_filled)
buffer = BytesIO()
pil_img.save(buffer, format='PNG')
base64_mask = base64.b64encode(buffer.getvalue()).decode('utf-8')
projected_masks[target_idx] = base64_mask
coverage = (target_mask_filled > 0).sum() / (target_H * target_W) * 100
debug_info.append(f"Camera {target_idx}: {in_bounds.sum()}/{len(proj_x)} points in-bounds, {coverage:.1f}% coverage (depth-aware)")
print(f"[Depth-Aware Projection] Camera {target_idx}: {in_bounds.sum()} points projected, {coverage:.1f}% coverage")
return {
'projected_masks': projected_masks,
'source_camera': source_camera_index,
'num_points_projected': len(depths),
'debug_info': '\n'.join(debug_info)
}
def gradio_project_mask(mask_image, source_camera_index, target_dir, dilate_radius, flip_y=False):
"""Gradio wrapper for project_mask_to_cameras"""
if mask_image is None:
return None, None, None, "Please upload a mask image"
if target_dir is None or target_dir == "None" or not os.path.isdir(target_dir):
return None, None, None, "Please run reconstruction first"
try:
# Convert to numpy if needed
if hasattr(mask_image, 'numpy'):
mask_np = mask_image.numpy()
else:
mask_np = np.array(mask_image)
result = project_mask_to_cameras(
mask_np,
int(source_camera_index),
target_dir,
int(dilate_radius),
flip_y=flip_y
)
# Load original images for overlay
predictions_path = os.path.join(target_dir, "predictions.npz")
loaded = np.load(predictions_path, allow_pickle=True)
original_images = loaded["images"] # (S, H, W, 3) or (S, 3, H, W)
    # Convert each projected mask to a displayable image
    import base64
    from io import BytesIO

    from PIL import Image

    display_masks = []
    display_overlays = []
    for cam_idx, base64_mask in result['projected_masks'].items():
        mask_bytes = base64.b64decode(base64_mask)
        mask_img = Image.open(BytesIO(mask_bytes))
        mask_array = np.array(mask_img)
        display_masks.append((mask_array, f"Camera {cam_idx}"))
# Create overlay with original image
cam_idx_int = int(cam_idx)
if cam_idx_int < len(original_images):
orig_img = original_images[cam_idx_int]
# Handle different image formats
if orig_img.ndim == 3 and orig_img.shape[0] == 3: # CHW format
orig_img = np.transpose(orig_img, (1, 2, 0))
# Convert to uint8 if needed
if orig_img.max() <= 1.0:
orig_img = (orig_img * 255).astype(np.uint8)
else:
orig_img = orig_img.astype(np.uint8)
# Resize mask to match original image if needed
if mask_array.shape[:2] != orig_img.shape[:2]:
mask_resized = cv2.resize(mask_array, (orig_img.shape[1], orig_img.shape[0]), interpolation=cv2.INTER_NEAREST)
else:
mask_resized = mask_array
# Create red overlay where mask is white
overlay = orig_img.copy()
mask_bool = mask_resized > 127
# Blend: 50% original + 50% red where mask is active
overlay[mask_bool, 0] = np.clip(overlay[mask_bool, 0] * 0.5 + 255 * 0.5, 0, 255).astype(np.uint8) # Red
overlay[mask_bool, 1] = (overlay[mask_bool, 1] * 0.5).astype(np.uint8) # Reduce green
overlay[mask_bool, 2] = (overlay[mask_bool, 2] * 0.5).astype(np.uint8) # Reduce blue
display_overlays.append((overlay, f"Camera {cam_idx} Overlay"))
status = f"Projected {result['num_points_projected']} points\n{result['debug_info']}"
# Return gallery of masks, overlays, and status
return display_masks, display_overlays, json.dumps(result['projected_masks']), status
except Exception as e:
import traceback
return None, None, None, f"Error: {str(e)}\n{traceback.format_exc()}"
def gradio_project_mask_enhanced(mask_image, source_camera_index, target_dir, dilate_radius, depth_threshold, flip_y=False):
"""Gradio wrapper for depth-aware projection"""
if mask_image is None:
return None, None, "Please upload a mask image"
if target_dir is None or target_dir == "None" or not os.path.isdir(target_dir):
return None, None, "Please run reconstruction first"
try:
# Convert to numpy if needed
if hasattr(mask_image, 'numpy'):
mask_np = mask_image.numpy()
else:
mask_np = np.array(mask_image)
result = project_mask_depth_aware(
mask_np,
int(source_camera_index),
target_dir,
int(dilate_radius),
float(depth_threshold),
flip_y=flip_y
)
# Load original images for overlay
predictions_path = os.path.join(target_dir, "predictions.npz")
loaded = np.load(predictions_path, allow_pickle=True)
original_images = loaded["images"]
    import base64
    from io import BytesIO

    from PIL import Image

    display_masks = []
    display_overlays = []
    for cam_idx, base64_mask in result['projected_masks'].items():
        mask_bytes = base64.b64decode(base64_mask)
        mask_img = Image.open(BytesIO(mask_bytes))
        mask_array = np.array(mask_img)
        display_masks.append((mask_array, f"Camera {cam_idx}"))
# Create overlay
cam_idx_int = int(cam_idx)
if cam_idx_int < len(original_images):
orig_img = original_images[cam_idx_int]
if orig_img.ndim == 3 and orig_img.shape[0] == 3:
orig_img = np.transpose(orig_img, (1, 2, 0))
if orig_img.max() <= 1.0:
orig_img = (orig_img * 255).astype(np.uint8)
else:
orig_img = orig_img.astype(np.uint8)
if mask_array.shape[:2] != orig_img.shape[:2]:
mask_resized = cv2.resize(mask_array, (orig_img.shape[1], orig_img.shape[0]), interpolation=cv2.INTER_NEAREST)
else:
mask_resized = mask_array
overlay = orig_img.copy()
mask_bool = mask_resized > 127
overlay[mask_bool, 0] = np.clip(overlay[mask_bool, 0] * 0.5 + 255 * 0.5, 0, 255).astype(np.uint8)
overlay[mask_bool, 1] = (overlay[mask_bool, 1] * 0.5).astype(np.uint8)
overlay[mask_bool, 2] = (overlay[mask_bool, 2] * 0.5).astype(np.uint8)
display_overlays.append((overlay, f"Camera {cam_idx} Overlay"))
status = f"Projected {result['num_points_projected']} points (depth-aware)\n{result['debug_info']}"
return display_masks, display_overlays, status
except Exception as e:
import traceback
return None, None, f"Error: {str(e)}\n{traceback.format_exc()}"
# -------------------------------------------------------------------------
# 7) Build Gradio UI
# -------------------------------------------------------------------------
theme = get_gradio_theme()
with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
# State variables for the tabbed interface
is_example = gr.Textbox(label="is_example", visible=False, value="None")
num_images = gr.Textbox(label="num_images", visible=False, value="None")
processed_data_state = gr.State(value=None)
measure_points_state = gr.State(value=[])
current_view_index = gr.State(value=0) # Track current view index for navigation
gr.HTML(get_header_html(get_logo_base64()))
gr.HTML(get_description_html())
target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
# Hidden output for camera poses (for API access)
camera_poses_output = gr.Textbox(
label="Camera Poses JSON",
visible=False,
)
with gr.Row():
with gr.Column(scale=2):
# Unified upload component for both videos and images
unified_upload = gr.File(
file_count="multiple",
label="Upload Video or Images",
interactive=True,
file_types=["image", "video"],
)
with gr.Row():
s_time_interval = gr.Slider(
minimum=0.1,
maximum=5.0,
value=1.0,
step=0.1,
label="Video sample time interval (take a sample every x sec.)",
interactive=True,
visible=True,
scale=3,
)
resample_btn = gr.Button(
"Resample Video",
visible=False,
variant="secondary",
scale=1,
)
image_gallery = gr.Gallery(
label="Preview",
columns=4,
height="300px",
show_download_button=True,
object_fit="contain",
preview=True,
)
clear_uploads_btn = gr.ClearButton(
[unified_upload, image_gallery],
value="Clear Uploads",
variant="secondary",
size="sm",
)
with gr.Column(scale=4):
with gr.Column():
gr.Markdown(
"**Metric 3D Reconstruction (Point Cloud and Camera Poses)**"
)
log_output = gr.Markdown(
"Please upload a video or images, then click Reconstruct.",
elem_classes=["custom-log"],
)
# Add tabbed interface similar to MoGe
with gr.Tabs():
with gr.Tab("3D View"):
reconstruction_output = gr.Model3D(
height=520,
zoom_speed=0.5,
pan_speed=0.5,
clear_color=[0.0, 0.0, 0.0, 0.0],
key="persistent_3d_viewer",
elem_id="reconstruction_3d_viewer",
)
with gr.Tab("Depth"):
with gr.Row(elem_classes=["navigation-row"]):
prev_depth_btn = gr.Button("◀ Previous", size="sm", scale=1)
depth_view_selector = gr.Dropdown(
choices=["View 1"],
value="View 1",
label="Select View",
scale=2,
interactive=True,
allow_custom_value=True,
)
next_depth_btn = gr.Button("Next ▶", size="sm", scale=1)
depth_map = gr.Image(
type="numpy",
label="Colorized Depth Map",
format="png",
interactive=False,
)
with gr.Tab("Normal"):
with gr.Row(elem_classes=["navigation-row"]):
prev_normal_btn = gr.Button(
"◀ Previous", size="sm", scale=1
)
normal_view_selector = gr.Dropdown(
choices=["View 1"],
value="View 1",
label="Select View",
scale=2,
interactive=True,
allow_custom_value=True,
)
next_normal_btn = gr.Button("Next ▶", size="sm", scale=1)
normal_map = gr.Image(
type="numpy",
label="Normal Map",
format="png",
interactive=False,
)
with gr.Tab("Measure"):
gr.Markdown(MEASURE_INSTRUCTIONS_HTML)
with gr.Row(elem_classes=["navigation-row"]):
prev_measure_btn = gr.Button(
"◀ Previous", size="sm", scale=1
)
measure_view_selector = gr.Dropdown(
choices=["View 1"],
value="View 1",
label="Select View",
scale=2,
interactive=True,
allow_custom_value=True,
)
next_measure_btn = gr.Button("Next ▶", size="sm", scale=1)
measure_image = gr.Image(
type="numpy",
show_label=False,
format="webp",
interactive=False,
sources=[],
)
gr.Markdown(
"**Note:** Light-grey areas indicate regions with no depth information where measurements cannot be taken."
)
measure_text = gr.Markdown("")
with gr.Tab("Project Mask"):
gr.Markdown("""
### Multi-Angle Virtual Staging
Upload a furniture mask from one camera angle to project it to other views.
**Steps:**
1. Run reconstruction with your images (1 staged + 2 empty)
2. Upload the furniture mask (white/255 = furniture, black/0 = background)
3. Select which camera captured the staged image
4. Click "Project Mask" to generate masks for other cameras
""")
with gr.Row():
mask_upload = gr.Image(
type="numpy",
label="Furniture Mask (binary: white=furniture)",
sources=["upload"],
)
with gr.Column():
source_camera_dropdown = gr.Dropdown(
choices=["0", "1", "2"],
value="0",
label="Source Camera Index (which view has furniture)",
)
dilate_slider = gr.Slider(
minimum=0,
maximum=20,
value=5,
step=1,
label="Dilation Radius (fills gaps in projection)",
)
flip_y_checkbox = gr.Checkbox(
label="Flip Y axis (try if projection is upside down)",
value=False,
)
project_btn = gr.Button("Project Mask", variant="primary")
projected_gallery = gr.Gallery(
label="Projected Masks",
columns=3,
height="250px",
object_fit="contain",
)
overlay_gallery = gr.Gallery(
label="Masks Overlaid on Original Images",
columns=3,
height="250px",
object_fit="contain",
)
projected_json = gr.Textbox(
label="Projected Masks JSON (base64)",
visible=False,
)
project_status = gr.Markdown("")
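# A hypothetical sketch for decoding the hidden JSON output above, assuming
# gradio_project_mask emits a JSON list of base64-encoded PNG masks (the real
# payload layout may differ; "payload" stands for the textbox value):
#   import base64, io, json
#   from PIL import Image
#   masks = [Image.open(io.BytesIO(base64.b64decode(m))) for m in json.loads(payload)]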
# Connect button
project_btn.click(
fn=gradio_project_mask,
inputs=[
mask_upload,
source_camera_dropdown,
target_dir_output,
dilate_slider,
flip_y_checkbox,
],
outputs=[
projected_gallery,
overlay_gallery,
projected_json,
project_status,
],
)
with gr.Tab("Enhanced Projection"):
gr.Markdown("""
### Depth-Aware Projection
Improved mask projection that uses depth information to prevent the mask from bleeding onto the background.
The mask is expanded only in regions whose depth is similar to the furniture's.
""")
with gr.Row():
mask_upload_enhanced = gr.Image(
type="numpy",
label="Furniture Mask (binary: white=furniture)",
sources=["upload"],
)
with gr.Column():
source_camera_enhanced = gr.Dropdown(
choices=["0", "1", "2"],
value="0",
label="Source Camera Index (which view has furniture)",
)
dilate_slider_enhanced = gr.Slider(
minimum=0,
maximum=20,
value=7,
step=1,
label="Dilation Radius",
)
depth_threshold_slider = gr.Slider(
minimum=0.1,
maximum=2.0,
value=0.5,
step=0.1,
label="Depth Threshold (meters) - how much depth variation to allow",
)
flip_y_enhanced = gr.Checkbox(
label="Flip Y axis (try if projection is upside down)",
value=False,
)
project_btn_enhanced = gr.Button("Project (Depth-Aware)", variant="primary")
enhanced_gallery = gr.Gallery(
label="Projected Masks",
columns=3,
height="250px",
object_fit="contain",
)
enhanced_overlay = gr.Gallery(
label="Masks Overlaid on Original Images",
columns=3,
height="250px",
object_fit="contain",
)
enhanced_status = gr.Markdown("")
project_btn_enhanced.click(
fn=gradio_project_mask_enhanced,
inputs=[
mask_upload_enhanced,
source_camera_enhanced,
target_dir_output,
dilate_slider_enhanced,
depth_threshold_slider,
flip_y_enhanced,
],
outputs=[
enhanced_gallery,
enhanced_overlay,
enhanced_status,
],
)
with gr.Row():
submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
clear_btn = gr.ClearButton(
[
unified_upload,
reconstruction_output,
log_output,
target_dir_output,
image_gallery,
],
scale=1,
)
with gr.Row():
frame_filter = gr.Dropdown(
choices=["All"], value="All", label="Show Points from Frame"
)
with gr.Column():
gr.Markdown("### Pointcloud Options: (live updates)")
show_cam = gr.Checkbox(label="Show Camera", value=True)
show_mesh = gr.Checkbox(label="Show Mesh", value=True)
filter_black_bg = gr.Checkbox(
label="Filter Black Background", value=False
)
filter_white_bg = gr.Checkbox(
label="Filter White Background", value=False
)
gr.Markdown("### Reconstruction Options: (updated on next run)")
apply_mask_checkbox = gr.Checkbox(
label="Apply mask for predicted ambiguous depth classes & edges",
value=True,
)
# ---------------------- Example Scenes Section ----------------------
gr.Markdown("## Example Scenes (lists all scenes in the examples folder)")
gr.Markdown("Click any thumbnail to load the scene for reconstruction.")
# Get scene information
scenes = get_scene_info("examples")
# Create thumbnail grid (4 columns, N rows)
if scenes:
for i in range(0, len(scenes), 4): # Process 4 scenes per row
with gr.Row():
for j in range(4):
scene_idx = i + j
if scene_idx < len(scenes):
scene = scenes[scene_idx]
with gr.Column(scale=1, elem_classes=["clickable-thumbnail"]):
# Clickable thumbnail
scene_img = gr.Image(
value=scene["thumbnail"],
height=150,
interactive=False,
show_label=False,
elem_id=f"scene_thumb_{scene['name']}",
sources=[],
)
# Scene name and image count as text below thumbnail
gr.Markdown(
f"**{scene['name']}** \n {scene['num_images']} images",
elem_classes=["scene-info"],
)
# Connect thumbnail click to load scene
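# (binding scene["name"] as a default argument freezes the value per
# iteration; a plain closure would late-bind every thumbnail to the
# last scene)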
scene_img.select(
fn=lambda name=scene["name"]: load_example_scene(name),
outputs=[
reconstruction_output,
target_dir_output,
image_gallery,
log_output,
],
)
else:
# Empty column to maintain grid structure
with gr.Column(scale=1):
pass
# -------------------------------------------------------------------------
# "Reconstruct" button logic:
# - Clear fields
# - Update log
# - gradio_demo(...) with the existing target_dir
# - Then set is_example = "False"
# -------------------------------------------------------------------------
submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
fn=update_log, inputs=[], outputs=[log_output]
).then(
fn=gradio_demo,
inputs=[
target_dir_output,
frame_filter,
show_cam,
filter_black_bg,
filter_white_bg,
apply_mask_checkbox,
show_mesh,
],
outputs=[
reconstruction_output,
log_output,
frame_filter,
processed_data_state,
depth_map,
normal_map,
measure_image,
measure_text,
depth_view_selector,
normal_view_selector,
measure_view_selector,
camera_poses_output,
],
).then(
fn=lambda: "False",
inputs=[],
outputs=[is_example], # set is_example to "False"
)
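# Note: update_visualization takes is_example so it can distinguish example
# scenes from fresh user runs; the chain above resets the flag after every
# reconstruction.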
# -------------------------------------------------------------------------
# Real-time Visualization Updates
# -------------------------------------------------------------------------
frame_filter.change(
update_visualization,
[
target_dir_output,
frame_filter,
show_cam,
is_example,
filter_black_bg,
filter_white_bg,
show_mesh,
],
[reconstruction_output, log_output],
)
show_cam.change(
update_visualization,
[
target_dir_output,
frame_filter,
show_cam,
is_example,
filter_black_bg,
filter_white_bg,
show_mesh,  # match the other handlers so toggling the camera keeps the current filter and mesh state
],
[reconstruction_output, log_output],
)
filter_black_bg.change(
update_visualization,
[
target_dir_output,
frame_filter,
show_cam,
is_example,
filter_black_bg,
filter_white_bg,
show_mesh,  # pass the full option set so the mesh toggle survives filter changes
],
[reconstruction_output, log_output],
).then(
fn=update_all_views_on_filter_change,
inputs=[
target_dir_output,
filter_black_bg,
filter_white_bg,
processed_data_state,
depth_view_selector,
normal_view_selector,
measure_view_selector,
],
outputs=[
processed_data_state,
depth_map,
normal_map,
measure_image,
measure_points_state,
],
)
filter_white_bg.change(
update_visualization,
[
target_dir_output,
frame_filter,
show_cam,
is_example,
filter_black_bg,
filter_white_bg,
show_mesh,
],
[reconstruction_output, log_output],
).then(
fn=update_all_views_on_filter_change,
inputs=[
target_dir_output,
filter_black_bg,
filter_white_bg,
processed_data_state,
depth_view_selector,
normal_view_selector,
measure_view_selector,
],
outputs=[
processed_data_state,
depth_map,
normal_map,
measure_image,
measure_points_state,
],
)
show_mesh.change(
update_visualization,
[
target_dir_output,
frame_filter,
show_cam,
is_example,
filter_black_bg,
filter_white_bg,
show_mesh,
],
[reconstruction_output, log_output],
)
# -------------------------------------------------------------------------
# Auto-update gallery whenever user uploads or changes their files
# -------------------------------------------------------------------------
def update_gallery_on_unified_upload(files, interval):
"""Process newly uploaded videos/images and refresh the preview gallery"""
if not files:
return None, None, None
target_dir, image_paths = handle_uploads(files, interval)
return (
target_dir,
image_paths,
"Upload complete. Click 'Reconstruct' to begin 3D processing.",
)
# Video extensions recognized by the resample helpers (shared so the two
# helpers below stay in sync)
video_extensions = [
".mp4",
".avi",
".mov",
".mkv",
".wmv",
".flv",
".webm",
".m4v",
".3gp",
]
def show_resample_button(files):
"""Show the resample button only if the uploaded files include a video"""
if not files:
return gr.update(visible=False)
# Check if any uploaded files are videos
has_video = False
for file_data in files:
if isinstance(file_data, dict) and "name" in file_data:
file_path = file_data["name"]
else:
file_path = str(file_data)
file_ext = os.path.splitext(file_path)[1].lower()
if file_ext in video_extensions:
has_video = True
break
return gr.update(visible=has_video)
def hide_resample_button():
"""Hide the resample button after use (currently not wired to any event;
resample_video_with_new_interval hides the button via its own return value)"""
return gr.update(visible=False)
def resample_video_with_new_interval(files, new_interval, current_target_dir):
"""Resample any uploaded videos using the new slider interval"""
if not files:
return (
current_target_dir,
None,
"No files to resample.",
gr.update(visible=False),
)
# Check if we have videos to resample (reuses the shared video_extensions list)
has_video = any(
os.path.splitext(
str(file_data["name"] if isinstance(file_data, dict) else file_data)
)[1].lower()
in video_extensions
for file_data in files
)
if not has_video:
return (
current_target_dir,
None,
"No videos found to resample.",
gr.update(visible=False),
)
# Clean up old target directory if it exists
if (
current_target_dir
and current_target_dir != "None"
and os.path.exists(current_target_dir)
):
shutil.rmtree(current_target_dir)
# Process files with new interval
target_dir, image_paths = handle_uploads(files, new_interval)
return (
target_dir,
image_paths,
f"Video resampled with {new_interval}s interval. Click 'Reconstruct' to begin 3D processing.",
gr.update(visible=False),
)
unified_upload.change(
fn=update_gallery_on_unified_upload,
inputs=[unified_upload, s_time_interval],
outputs=[target_dir_output, image_gallery, log_output],
).then(
fn=show_resample_button,
inputs=[unified_upload],
outputs=[resample_btn],
)
# Show resample button when slider changes (only if files are uploaded)
s_time_interval.change(
fn=show_resample_button,
inputs=[unified_upload],
outputs=[resample_btn],
)
# Handle resample button click
resample_btn.click(
fn=resample_video_with_new_interval,
inputs=[unified_upload, s_time_interval, target_dir_output],
outputs=[target_dir_output, image_gallery, log_output, resample_btn],
)
# -------------------------------------------------------------------------
# Measure tab functionality
# -------------------------------------------------------------------------
measure_image.select(
fn=measure,
inputs=[processed_data_state, measure_points_state, measure_view_selector],
outputs=[measure_image, measure_points_state, measure_text],
)
# -------------------------------------------------------------------------
# Navigation functionality for Depth, Normal, and Measure tabs
# -------------------------------------------------------------------------
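# All three view selectors use "View N" labels (repopulated after each
# reconstruction); the handlers below recover the 0-based view index with
# int(label.split()[1]) - 1, so custom dropdown values that do not follow
# the "View N" pattern will not resolve.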
# Depth tab navigation
prev_depth_btn.click(
fn=lambda processed_data, current_selector: navigate_depth_view(
processed_data, current_selector, -1
),
inputs=[processed_data_state, depth_view_selector],
outputs=[depth_view_selector, depth_map],
)
next_depth_btn.click(
fn=lambda processed_data, current_selector: navigate_depth_view(
processed_data, current_selector, 1
),
inputs=[processed_data_state, depth_view_selector],
outputs=[depth_view_selector, depth_map],
)
depth_view_selector.change(
fn=lambda processed_data, selector_value: (
update_depth_view(
processed_data,
int(selector_value.split()[1]) - 1,
)
if selector_value
else None
),
inputs=[processed_data_state, depth_view_selector],
outputs=[depth_map],
)
# Normal tab navigation
prev_normal_btn.click(
fn=lambda processed_data, current_selector: navigate_normal_view(
processed_data, current_selector, -1
),
inputs=[processed_data_state, normal_view_selector],
outputs=[normal_view_selector, normal_map],
)
next_normal_btn.click(
fn=lambda processed_data, current_selector: navigate_normal_view(
processed_data, current_selector, 1
),
inputs=[processed_data_state, normal_view_selector],
outputs=[normal_view_selector, normal_map],
)
normal_view_selector.change(
fn=lambda processed_data, selector_value: (
update_normal_view(
processed_data,
int(selector_value.split()[1]) - 1,
)
if selector_value
else None
),
inputs=[processed_data_state, normal_view_selector],
outputs=[normal_map],
)
# Measure tab navigation
prev_measure_btn.click(
fn=lambda processed_data, current_selector: navigate_measure_view(
processed_data, current_selector, -1
),
inputs=[processed_data_state, measure_view_selector],
outputs=[measure_view_selector, measure_image, measure_points_state],
)
next_measure_btn.click(
fn=lambda processed_data, current_selector: navigate_measure_view(
processed_data, current_selector, 1
),
inputs=[processed_data_state, measure_view_selector],
outputs=[measure_view_selector, measure_image, measure_points_state],
)
measure_view_selector.change(
fn=lambda processed_data, selector_value: (
update_measure_view(processed_data, int(selector_value.split()[1]) - 1)
if selector_value
else (None, [])
),
inputs=[processed_data_state, measure_view_selector],
outputs=[measure_image, measure_points_state],
)
# -------------------------------------------------------------------------
# API Endpoint for downloading predictions.npz
# -------------------------------------------------------------------------
# Expose get_predictions_file as an API endpoint
# Using render=True (default) for components - use elem_id for CSS hiding if needed
# The button must be visible for Gradio to register the API endpoint
predictions_dir_input = gr.Textbox(
label="Target Dir for Predictions",
value="",
elem_id="predictions_api_input",
scale=0, # Minimize space
)
predictions_file_output = gr.File(
label="Predictions NPZ File",
elem_id="predictions_api_output",
scale=0,
)
predictions_api_btn = gr.Button(
"Download Predictions NPZ (API)",
elem_id="predictions_api_btn",
size="sm",
scale=0,
)
predictions_api_btn.click(
fn=get_predictions_file,
inputs=[predictions_dir_input],
outputs=[predictions_file_output],
)
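# A minimal client-side sketch for fetching the NPZ (assumed local URL; with
# auto-naming, Gradio derives the route from the function name, so it is
# expected to be /get_predictions_file; verify with client.view_api()):
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860/")
#   npz_path = client.predict("<target_dir from a previous run>", api_name="/get_predictions_file")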
# -------------------------------------------------------------------------
# Acknowledgement section
# -------------------------------------------------------------------------
gr.HTML(get_acknowledgements_html())
demo.queue(max_size=20).launch(show_error=True, share=True, ssr_mode=False)
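# Usage note: share=True publishes a temporary public Gradio link; for a
# local-only session the same call works with share=False (a deployment
# preference, not a change to the app above):
#   demo.queue(max_size=20).launch(show_error=True, share=False, ssr_mode=False)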