prithivMLmods commited on
Commit
ea340f5
·
verified ·
1 Parent(s): 6ef3d1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +552 -1010
app.py CHANGED
@@ -16,12 +16,12 @@ import torch
16
  from PIL import Image
17
  from pillow_heif import register_heif_opener
18
 
19
- # --- Rerun Imports ---
20
  import rerun as rr
21
  try:
22
  import rerun.blueprint as rrb
23
  except ImportError:
24
  rrb = None
 
25
  from gradio_rerun import Rerun
26
 
27
  register_heif_opener()
@@ -41,10 +41,9 @@ from mapanything.utils.hf_utils.hf_helpers import initialize_mapanything_model
41
  from mapanything.utils.hf_utils.viz import predictions_to_glb
42
  from mapanything.utils.image import load_images, rgb
43
 
44
- # MapAnything Configuration
45
  high_level_config = {
46
  "path": "configs/train.yaml",
47
- "hf_model_name": "facebook/map-anything-v1", # -- facebook/map-anything
48
  "model_str": "mapanything",
49
  "config_overrides": [
50
  "machine=aws",
@@ -61,37 +60,31 @@ high_level_config = {
61
  "resolution": 518,
62
  }
63
 
64
- # Initialize model - this will be done on GPU when needed
65
  model = None
 
 
66
 
67
 
68
  # -------------------------------------------------------------------------
69
- # Rerun Helper Function
70
  # -------------------------------------------------------------------------
71
- def create_rerun_recording(glb_path, output_dir):
72
- """
73
- Takes a generated GLB file, wraps it in a Rerun recording (.rrd),
74
- and returns the path to the .rrd file for the UI to consume.
75
- """
76
  run_id = str(uuid.uuid4())
77
-
78
- # Robustly handle different Rerun SDK versions
 
79
  rec = None
80
  if hasattr(rr, "new_recording"):
81
- rec = rr.new_recording(application_id="MapAnything-3D", recording_id=run_id)
82
  elif hasattr(rr, "RecordingStream"):
83
- rec = rr.RecordingStream(application_id="MapAnything-3D", recording_id=run_id)
84
  else:
85
- rr.init("MapAnything-3D", recording_id=run_id, spawn=False)
86
  rec = rr
87
-
88
- # Clear previous states
89
  rec.log("world", rr.Clear(recursive=True), static=True)
90
-
91
- # Set coordinates
92
  rec.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, static=True)
93
 
94
- # Add optional axes helpers
95
  try:
96
  rec.log("world/axes/x", rr.Arrows3D(vectors=[[0.5, 0, 0]], colors=[[255, 0, 0]]), static=True)
97
  rec.log("world/axes/y", rr.Arrows3D(vectors=[[0, 0.5, 0]], colors=[[0, 255, 0]]), static=True)
@@ -99,27 +92,78 @@ def create_rerun_recording(glb_path, output_dir):
99
  except Exception:
100
  pass
101
 
102
- # Log the 3D Model
103
- rec.log("world/scene", rr.Asset3D(path=glb_path), static=True)
104
-
105
- # Blueprint for clean layout
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  if rrb is not None:
107
  try:
108
  blueprint = rrb.Blueprint(
109
- rrb.Spatial3DView(
110
- origin="/world",
111
- name="3D View",
112
- ),
113
  collapse_panels=True,
114
  )
115
  rec.send_blueprint(blueprint)
116
  except Exception as e:
117
  print(f"Blueprint creation failed (non-fatal): {e}")
118
 
119
- # Save the recording to the target directory
120
- rrd_path = os.path.join(output_dir, f'scene_{run_id}.rrd')
121
  rec.save(rrd_path)
122
-
123
  return rrd_path
124
 
125
 
@@ -127,448 +171,221 @@ def create_rerun_recording(glb_path, output_dir):
127
  # 1) Core model inference
128
  # -------------------------------------------------------------------------
129
  @spaces.GPU(duration=120)
130
- def run_model(
131
- target_dir,
132
- apply_mask=True,
133
- mask_edges=True,
134
- filter_black_bg=False,
135
- filter_white_bg=False,
136
- ):
137
- """
138
- Run the MapAnything model on images in the 'target_dir/images' folder and return predictions.
139
- """
140
  global model
141
- import torch # Ensure torch is available in function scope
142
 
143
  print(f"Processing images from {target_dir}")
 
144
 
145
- # Device check
146
- device = "cuda" if torch.cuda.is_available() else "cpu"
147
- device = torch.device(device)
148
-
149
- # Initialize model if not already done
150
  if model is None:
151
  model = initialize_mapanything_model(high_level_config, device)
152
-
153
  else:
154
  model = model.to(device)
155
-
156
  model.eval()
157
 
158
- # Load images using MapAnything's load_images function
159
  print("Loading images...")
160
  image_folder_path = os.path.join(target_dir, "images")
161
  views = load_images(image_folder_path)
162
-
163
  print(f"Loaded {len(views)} images")
164
  if len(views) == 0:
165
  raise ValueError("No images found. Check your upload.")
166
 
167
- # Run model inference
168
  print("Running inference...")
169
- # apply_mask: Whether to apply the non-ambiguous mask to the output. Defaults to True.
170
- # mask_edges: Whether to compute an edge mask based on normals and depth and apply it to the output. Defaults to True.
171
- # Use checkbox values - mask_edges is set to True by default since there's no UI control for it
172
- outputs = model.infer(
173
- views, apply_mask=apply_mask, mask_edges=True, memory_efficient_inference=False
174
- )
175
 
176
- # Convert predictions to format expected by visualization
177
  predictions = {}
 
 
178
 
179
- # Initialize lists for the required keys
180
- extrinsic_list = []
181
- intrinsic_list = []
182
- world_points_list = []
183
- depth_maps_list = []
184
- images_list = []
185
- final_mask_list = []
186
-
187
- # Loop through the outputs
188
  for pred in outputs:
189
- # Extract data from predictions
190
- depthmap_torch = pred["depth_z"][0].squeeze(-1) # (H, W)
191
- intrinsics_torch = pred["intrinsics"][0] # (3, 3)
192
- camera_pose_torch = pred["camera_poses"][0] # (4, 4)
193
-
194
- # Compute new pts3d using depth, intrinsics, and camera pose
195
- pts3d_computed, valid_mask = depthmap_to_world_frame(
196
- depthmap_torch, intrinsics_torch, camera_pose_torch
197
- )
198
-
199
- # Convert to numpy arrays for visualization
200
- # Check if mask key exists in pred, if not, fill with boolean trues in the size of depthmap_torch
201
- if "mask" in pred:
202
- mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool)
203
- else:
204
- # Fill with boolean trues in the size of depthmap_torch
205
- mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool)
206
-
207
- # Combine with valid depth mask
208
  mask = mask & valid_mask.cpu().numpy()
209
-
210
  image = pred["img_no_norm"][0].cpu().numpy()
211
-
212
- # Append to lists
213
  extrinsic_list.append(camera_pose_torch.cpu().numpy())
214
  intrinsic_list.append(intrinsics_torch.cpu().numpy())
215
  world_points_list.append(pts3d_computed.cpu().numpy())
216
  depth_maps_list.append(depthmap_torch.cpu().numpy())
217
- images_list.append(image) # Add image to list
218
- final_mask_list.append(mask) # Add final_mask to list
219
 
220
- # Convert lists to numpy arrays with required shapes
221
- # extrinsic: (S, 3, 4) - batch of camera extrinsic matrices
222
  predictions["extrinsic"] = np.stack(extrinsic_list, axis=0)
223
-
224
- # intrinsic: (S, 3, 3) - batch of camera intrinsic matrices
225
  predictions["intrinsic"] = np.stack(intrinsic_list, axis=0)
226
-
227
- # world_points: (S, H, W, 3) - batch of 3D world points
228
  predictions["world_points"] = np.stack(world_points_list, axis=0)
229
-
230
- # depth: (S, H, W, 1) or (S, H, W) - batch of depth maps
231
  depth_maps = np.stack(depth_maps_list, axis=0)
232
- # Add channel dimension if needed to match (S, H, W, 1) format
233
  if len(depth_maps.shape) == 3:
234
  depth_maps = depth_maps[..., np.newaxis]
235
-
236
  predictions["depth"] = depth_maps
237
-
238
- # images: (S, H, W, 3) - batch of input images
239
  predictions["images"] = np.stack(images_list, axis=0)
240
-
241
- # final_mask: (S, H, W) - batch of final masks for filtering
242
  predictions["final_mask"] = np.stack(final_mask_list, axis=0)
243
 
244
- # Process data for visualization tabs (depth, normal, measure)
245
- processed_data = process_predictions_for_visualization(
246
- predictions, views, high_level_config, filter_black_bg, filter_white_bg
247
- )
248
-
249
- # Clean up
250
  torch.cuda.empty_cache()
251
-
252
  return predictions, processed_data
253
 
254
 
255
  def update_view_selectors(processed_data):
256
- """Update view selector dropdowns based on available views"""
257
- if processed_data is None or len(processed_data) == 0:
258
- choices = ["View 1"]
259
- else:
260
- num_views = len(processed_data)
261
- choices = [f"View {i + 1}" for i in range(num_views)]
262
-
263
  return (
264
- gr.Dropdown(choices=choices, value=choices[0]), # depth_view_selector
265
- gr.Dropdown(choices=choices, value=choices[0]), # normal_view_selector
266
- gr.Dropdown(choices=choices, value=choices[0]), # measure_view_selector
267
  )
268
 
269
 
270
  def get_view_data_by_index(processed_data, view_index):
271
- """Get view data by index, handling bounds"""
272
- if processed_data is None or len(processed_data) == 0:
273
  return None
274
-
275
  view_keys = list(processed_data.keys())
276
- if view_index < 0 or view_index >= len(view_keys):
277
- view_index = 0
278
-
279
  return processed_data[view_keys[view_index]]
280
 
281
 
282
  def update_depth_view(processed_data, view_index):
283
- """Update depth view for a specific view index"""
284
  view_data = get_view_data_by_index(processed_data, view_index)
285
  if view_data is None or view_data["depth"] is None:
286
  return None
287
-
288
  return colorize_depth(view_data["depth"], mask=view_data.get("mask"))
289
 
290
 
291
  def update_normal_view(processed_data, view_index):
292
- """Update normal view for a specific view index"""
293
  view_data = get_view_data_by_index(processed_data, view_index)
294
  if view_data is None or view_data["normal"] is None:
295
  return None
296
-
297
  return colorize_normal(view_data["normal"], mask=view_data.get("mask"))
298
 
299
 
300
  def update_measure_view(processed_data, view_index):
301
- """Update measure view for a specific view index with mask overlay"""
302
  view_data = get_view_data_by_index(processed_data, view_index)
303
  if view_data is None:
304
- return None, [] # image, measure_points
305
-
306
- # Get the base image
307
  image = view_data["image"].copy()
308
-
309
- # Ensure image is in uint8 format
310
  if image.dtype != np.uint8:
311
- if image.max() <= 1.0:
312
- image = (image * 255).astype(np.uint8)
313
- else:
314
- image = image.astype(np.uint8)
315
-
316
- # Apply mask overlay if mask is available
317
  if view_data["mask"] is not None:
318
- mask = view_data["mask"]
319
-
320
- # Create light grey overlay for masked areas
321
- # Masked areas (False values) will be overlaid with light grey
322
- invalid_mask = ~mask # Areas where mask is False
323
-
324
  if invalid_mask.any():
325
- # Create a light grey overlay (RGB: 192, 192, 192)
326
  overlay_color = np.array([255, 220, 220], dtype=np.uint8)
327
-
328
- # Apply overlay with some transparency
329
- alpha = 0.5 # Transparency level
330
- for c in range(3): # RGB channels
331
- image[:, :, c] = np.where(
332
- invalid_mask,
333
- (1 - alpha) * image[:, :, c] + alpha * overlay_color[c],
334
- image[:, :, c],
335
- ).astype(np.uint8)
336
-
337
  return image, []
338
 
339
 
340
  def navigate_depth_view(processed_data, current_selector_value, direction):
341
- """Navigate depth view (direction: -1 for previous, +1 for next)"""
342
- if processed_data is None or len(processed_data) == 0:
343
  return "View 1", None
344
-
345
- # Parse current view number
346
  try:
347
  current_view = int(current_selector_value.split()[1]) - 1
348
  except:
349
  current_view = 0
350
-
351
- num_views = len(processed_data)
352
- new_view = (current_view + direction) % num_views
353
-
354
- new_selector_value = f"View {new_view + 1}"
355
- depth_vis = update_depth_view(processed_data, new_view)
356
-
357
- return new_selector_value, depth_vis
358
 
359
 
360
  def navigate_normal_view(processed_data, current_selector_value, direction):
361
- """Navigate normal view (direction: -1 for previous, +1 for next)"""
362
- if processed_data is None or len(processed_data) == 0:
363
  return "View 1", None
364
-
365
- # Parse current view number
366
  try:
367
  current_view = int(current_selector_value.split()[1]) - 1
368
  except:
369
  current_view = 0
370
-
371
- num_views = len(processed_data)
372
- new_view = (current_view + direction) % num_views
373
-
374
- new_selector_value = f"View {new_view + 1}"
375
- normal_vis = update_normal_view(processed_data, new_view)
376
-
377
- return new_selector_value, normal_vis
378
 
379
 
380
  def navigate_measure_view(processed_data, current_selector_value, direction):
381
- """Navigate measure view (direction: -1 for previous, +1 for next)"""
382
- if processed_data is None or len(processed_data) == 0:
383
  return "View 1", None, []
384
-
385
- # Parse current view number
386
  try:
387
  current_view = int(current_selector_value.split()[1]) - 1
388
  except:
389
  current_view = 0
390
-
391
- num_views = len(processed_data)
392
- new_view = (current_view + direction) % num_views
393
-
394
- new_selector_value = f"View {new_view + 1}"
395
  measure_image, measure_points = update_measure_view(processed_data, new_view)
396
-
397
- return new_selector_value, measure_image, measure_points
398
 
399
 
400
  def populate_visualization_tabs(processed_data):
401
- """Populate the depth, normal, and measure tabs with processed data"""
402
- if processed_data is None or len(processed_data) == 0:
403
  return None, None, None, []
404
-
405
- # Use update functions to ensure confidence filtering is applied from the start
406
- depth_vis = update_depth_view(processed_data, 0)
407
- normal_vis = update_normal_view(processed_data, 0)
408
- measure_img, _ = update_measure_view(processed_data, 0)
409
-
410
- return depth_vis, normal_vis, measure_img, []
411
 
412
 
413
  # -------------------------------------------------------------------------
414
- # 2) Handle uploaded video/images --> produce target_dir + images
415
  # -------------------------------------------------------------------------
416
  def handle_uploads(unified_upload, s_time_interval=1.0):
417
- """
418
- Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
419
- images or extracted frames from video into it. Return (target_dir, image_paths).
420
- """
421
  start_time = time.time()
422
  gc.collect()
423
  torch.cuda.empty_cache()
424
 
425
- # Create a unique folder name
426
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
427
  target_dir = f"input_images_{timestamp}"
428
  target_dir_images = os.path.join(target_dir, "images")
429
-
430
- # Clean up if somehow that folder already exists
431
  if os.path.exists(target_dir):
432
  shutil.rmtree(target_dir)
433
- os.makedirs(target_dir)
434
  os.makedirs(target_dir_images)
435
 
436
  image_paths = []
 
437
 
438
- # --- Handle uploaded files (both images and videos) ---
439
  if unified_upload is not None:
440
  for file_data in unified_upload:
441
- if isinstance(file_data, dict) and "name" in file_data:
442
- file_path = file_data["name"]
443
- else:
444
- file_path = str(file_data)
445
-
446
  file_ext = os.path.splitext(file_path)[1].lower()
447
 
448
- # Check if it's a video file
449
- video_extensions = [
450
- ".mp4",
451
- ".avi",
452
- ".mov",
453
- ".mkv",
454
- ".wmv",
455
- ".flv",
456
- ".webm",
457
- ".m4v",
458
- ".3gp",
459
- ]
460
  if file_ext in video_extensions:
461
- # Handle as video
462
  vs = cv2.VideoCapture(file_path)
463
  fps = vs.get(cv2.CAP_PROP_FPS)
464
- frame_interval = int(fps * s_time_interval) # frames per interval
465
-
466
- count = 0
467
- video_frame_num = 0
468
  while True:
469
  gotit, frame = vs.read()
470
  if not gotit:
471
  break
472
  count += 1
473
  if count % frame_interval == 0:
474
- # Use original filename as prefix for frames
475
  base_name = os.path.splitext(os.path.basename(file_path))[0]
476
- image_path = os.path.join(
477
- target_dir_images, f"{base_name}_{video_frame_num:06}.png"
478
- )
479
  cv2.imwrite(image_path, frame)
480
  image_paths.append(image_path)
481
  video_frame_num += 1
482
  vs.release()
483
- print(
484
- f"Extracted {video_frame_num} frames from video: {os.path.basename(file_path)}"
485
- )
486
-
487
- else:
488
- # Handle as image
489
- # Check if the file is a HEIC image
490
- if file_ext in [".heic", ".heif"]:
491
- # Convert HEIC to JPEG for better gallery compatibility
492
- try:
493
- with Image.open(file_path) as img:
494
- # Convert to RGB if necessary (HEIC can have different color modes)
495
- if img.mode not in ("RGB", "L"):
496
- img = img.convert("RGB")
497
-
498
- # Create JPEG filename
499
- base_name = os.path.splitext(os.path.basename(file_path))[0]
500
- dst_path = os.path.join(
501
- target_dir_images, f"{base_name}.jpg"
502
- )
503
-
504
- # Save as JPEG with high quality
505
- img.save(dst_path, "JPEG", quality=95)
506
- image_paths.append(dst_path)
507
- print(
508
- f"Converted HEIC to JPEG: {os.path.basename(file_path)} -> {os.path.basename(dst_path)}"
509
- )
510
- except Exception as e:
511
- print(f"Error converting HEIC file {file_path}: {e}")
512
- # Fall back to copying as is
513
- dst_path = os.path.join(
514
- target_dir_images, os.path.basename(file_path)
515
- )
516
- shutil.copy(file_path, dst_path)
517
  image_paths.append(dst_path)
518
- else:
519
- # Regular image files - copy as is
520
- dst_path = os.path.join(
521
- target_dir_images, os.path.basename(file_path)
522
- )
523
  shutil.copy(file_path, dst_path)
524
  image_paths.append(dst_path)
 
 
 
 
525
 
526
- # Sort final images for gallery
527
  image_paths = sorted(image_paths)
528
-
529
- end_time = time.time()
530
- print(
531
- f"Files processed to {target_dir_images}; took {end_time - start_time:.3f} seconds"
532
- )
533
  return target_dir, image_paths
534
 
535
 
536
  # -------------------------------------------------------------------------
537
- # 3) Update gallery on upload
538
- # -------------------------------------------------------------------------
539
- def update_gallery_on_upload(input_video, input_images, s_time_interval=1.0):
540
- """
541
- Whenever user uploads or changes files, immediately handle them
542
- and show in the gallery. Return (target_dir, image_paths).
543
- If nothing is uploaded, returns "None" and empty list.
544
- """
545
- if not input_video and not input_images:
546
- return None, None, None, None
547
- target_dir, image_paths = handle_uploads(input_video, input_images, s_time_interval)
548
- return (
549
- None,
550
- target_dir,
551
- image_paths,
552
- "Upload complete. Click 'Reconstruct' to begin 3D processing.",
553
- )
554
-
555
-
556
- # -------------------------------------------------------------------------
557
- # 4) Reconstruction: uses the target_dir plus any viz parameters
558
  # -------------------------------------------------------------------------
559
  @spaces.GPU(duration=120)
560
- def gradio_demo(
561
- target_dir,
562
- frame_filter="All",
563
- show_cam=True,
564
- filter_black_bg=False,
565
- filter_white_bg=False,
566
- apply_mask=True,
567
- show_mesh=True,
568
- ):
569
- """
570
- Perform reconstruction using the already-created target_dir/images.
571
- """
572
  if not os.path.isdir(target_dir) or target_dir == "None":
573
  return None, "No valid target directory found. Please upload first.", None, None
574
 
@@ -576,411 +393,172 @@ def gradio_demo(
576
  gc.collect()
577
  torch.cuda.empty_cache()
578
 
579
- # Prepare frame_filter dropdown
580
  target_dir_images = os.path.join(target_dir, "images")
581
- all_files = (
582
- sorted(os.listdir(target_dir_images))
583
- if os.path.isdir(target_dir_images)
584
- else []
585
- )
586
- all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
587
- frame_filter_choices = ["All"] + all_files
588
 
589
  print("Running MapAnything model...")
590
  with torch.no_grad():
591
  predictions, processed_data = run_model(target_dir, apply_mask)
592
 
593
- # Save predictions
594
- prediction_save_path = os.path.join(target_dir, "predictions.npz")
595
- np.savez(prediction_save_path, **predictions)
596
 
597
- # Handle None frame_filter
598
  if frame_filter is None:
599
  frame_filter = "All"
600
 
601
- # Build a GLB file name
602
  glbfile = os.path.join(
603
  target_dir,
604
  f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
605
  )
606
-
607
- # Convert predictions to GLB
608
- glbscene = predictions_to_glb(
609
- predictions,
610
- filter_by_frames=frame_filter,
611
- show_cam=show_cam,
612
- mask_black_bg=filter_black_bg,
613
- mask_white_bg=filter_white_bg,
614
- as_mesh=show_mesh, # Use the show_mesh parameter
615
- )
616
  glbscene.export(file_obj=glbfile)
617
-
618
- # ---------------------------------------------------------
619
- # Generate the Rerun recording using the new helper
620
- # ---------------------------------------------------------
621
- rrd_path = create_rerun_recording(glbfile, target_dir)
622
 
623
- # Cleanup
 
624
  del predictions
625
  gc.collect()
626
  torch.cuda.empty_cache()
627
 
628
- end_time = time.time()
629
- print(f"Total time: {end_time - start_time:.2f} seconds")
630
- log_msg = (
631
- f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
632
- )
633
-
634
- # Populate visualization tabs with processed data
635
- depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(
636
- processed_data
637
- )
638
 
639
- # Update view selectors based on available views
640
- depth_selector, normal_selector, measure_selector = update_view_selectors(
641
- processed_data
642
- )
643
 
644
  return (
645
- rrd_path, # Return the Rerun recording path instead of glbfile
646
- log_msg,
647
  gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
648
- processed_data,
649
- depth_vis,
650
- normal_vis,
651
- measure_img,
652
- "", # measure_text (empty initially)
653
- depth_selector,
654
- normal_selector,
655
- measure_selector,
656
  )
657
 
658
 
659
  # -------------------------------------------------------------------------
660
- # 5) Helper functions for UI resets + re-visualization
661
  # -------------------------------------------------------------------------
662
  def colorize_depth(depth_map, mask=None):
663
- """Convert depth map to colorized visualization with optional mask"""
664
  if depth_map is None:
665
  return None
666
-
667
- # Normalize depth to 0-1 range
668
  depth_normalized = depth_map.copy()
669
  valid_mask = depth_normalized > 0
670
-
671
- # Apply additional mask if provided (for background filtering)
672
  if mask is not None:
673
  valid_mask = valid_mask & mask
674
-
675
  if valid_mask.sum() > 0:
676
  valid_depths = depth_normalized[valid_mask]
677
- p5 = np.percentile(valid_depths, 5)
678
- p95 = np.percentile(valid_depths, 95)
679
-
680
  depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / (p95 - p5)
681
-
682
- # Apply colormap
683
  import matplotlib.pyplot as plt
684
-
685
- colormap = plt.cm.turbo_r
686
- colored = colormap(depth_normalized)
687
- colored = (colored[:, :, :3] * 255).astype(np.uint8)
688
-
689
- # Set invalid pixels to white
690
  colored[~valid_mask] = [255, 255, 255]
691
-
692
  return colored
693
 
694
 
695
  def colorize_normal(normal_map, mask=None):
696
- """Convert normal map to colorized visualization with optional mask"""
697
  if normal_map is None:
698
  return None
699
-
700
- # Create a copy for modification
701
  normal_vis = normal_map.copy()
702
-
703
- # Apply mask if provided (set masked areas to [0, 0, 0] which becomes grey after normalization)
704
  if mask is not None:
705
- invalid_mask = ~mask
706
- normal_vis[invalid_mask] = [0, 0, 0] # Set invalid areas to zero
707
 
708
- # Normalize normals to [0, 1] range for visualization
709
- normal_vis = (normal_vis + 1.0) / 2.0
710
- normal_vis = (normal_vis * 255).astype(np.uint8)
711
 
712
- return normal_vis
713
-
714
-
715
- def process_predictions_for_visualization(
716
- predictions, views, high_level_config, filter_black_bg=False, filter_white_bg=False
717
- ):
718
- """Extract depth, normal, and 3D points from predictions for visualization"""
719
  processed_data = {}
720
-
721
- # Process each view
722
  for view_idx, view in enumerate(views):
723
- # Get image
724
  image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
725
-
726
- # Get predicted points
727
  pred_pts3d = predictions["world_points"][view_idx]
728
-
729
- # Initialize data for this view
730
- view_data = {
731
- "image": image[0],
732
- "points3d": pred_pts3d,
733
- "depth": None,
734
- "normal": None,
735
- "mask": None,
736
- }
737
-
738
- # Start with the final mask from predictions
739
  mask = predictions["final_mask"][view_idx].copy()
740
-
741
- # Apply black background filtering if enabled
742
  if filter_black_bg:
743
- # Get the image colors (ensure they're in 0-255 range)
744
  view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
745
- # Filter out black background pixels (sum of RGB < 16)
746
- black_bg_mask = view_colors.sum(axis=2) >= 16
747
- mask = mask & black_bg_mask
748
-
749
- # Apply white background filtering if enabled
750
  if filter_white_bg:
751
- # Get the image colors (ensure they're in 0-255 range)
752
  view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
753
- # Filter out white background pixels (all RGB > 240)
754
- white_bg_mask = ~(
755
- (view_colors[:, :, 0] > 240)
756
- & (view_colors[:, :, 1] > 240)
757
- & (view_colors[:, :, 2] > 240)
758
- )
759
- mask = mask & white_bg_mask
760
-
761
- view_data["mask"] = mask
762
- view_data["depth"] = predictions["depth"][view_idx].squeeze()
763
-
764
- normals, _ = points_to_normals(pred_pts3d, mask=view_data["mask"])
765
- view_data["normal"] = normals
766
-
767
- processed_data[view_idx] = view_data
768
-
769
  return processed_data
770
 
771
 
772
- def reset_measure(processed_data):
773
- """Reset measure points"""
774
- if processed_data is None or len(processed_data) == 0:
775
- return None, [], ""
776
-
777
- # Return the first view image
778
- first_view = list(processed_data.values())[0]
779
- return first_view["image"], [], ""
780
-
781
-
782
- def measure(
783
- processed_data, measure_points, current_view_selector, event: gr.SelectData
784
- ):
785
- """Handle measurement on images"""
786
  try:
787
- print(f"Measure function called with selector: {current_view_selector}")
788
-
789
- if processed_data is None or len(processed_data) == 0:
790
  return None, [], "No data available"
791
-
792
- # Use the currently selected view instead of always using the first view
793
  try:
794
  current_view_index = int(current_view_selector.split()[1]) - 1
795
  except:
796
  current_view_index = 0
797
-
798
- print(f"Using view index: {current_view_index}")
799
-
800
- # Get view data safely
801
- if current_view_index < 0 or current_view_index >= len(processed_data):
802
- current_view_index = 0
803
-
804
- view_keys = list(processed_data.keys())
805
- current_view = processed_data[view_keys[current_view_index]]
806
-
807
  if current_view is None:
808
  return None, [], "No view data available"
809
 
810
  point2d = event.index[0], event.index[1]
811
- print(f"Clicked point: {point2d}")
812
-
813
- # Check if the clicked point is in a masked area (prevent interaction)
814
- if (
815
- current_view["mask"] is not None
816
- and 0 <= point2d[1] < current_view["mask"].shape[0]
817
- and 0 <= point2d[0] < current_view["mask"].shape[1]
818
- ):
819
- # Check if the point is in a masked (invalid) area
820
  if not current_view["mask"][point2d[1], point2d[0]]:
821
- print(f"Clicked point {point2d} is in masked area, ignoring click")
822
- # Always return image with mask overlay
823
- masked_image, _ = update_measure_view(
824
- processed_data, current_view_index
825
- )
826
- return (
827
- masked_image,
828
- measure_points,
829
- '<span style="color: red; font-weight: bold;">Cannot measure on masked areas (shown in grey)</span>',
830
- )
831
 
832
  measure_points.append(point2d)
833
-
834
- # Get image with mask overlay and ensure it's valid
835
  image, _ = update_measure_view(processed_data, current_view_index)
836
  if image is None:
837
  return None, [], "No image available"
838
-
839
  image = image.copy()
 
 
840
  points3d = current_view["points3d"]
841
 
842
- # Ensure image is in uint8 format for proper cv2 operations
843
- try:
844
- if image.dtype != np.uint8:
845
- if image.max() <= 1.0:
846
- # Image is in [0, 1] range, convert to [0, 255]
847
- image = (image * 255).astype(np.uint8)
848
- else:
849
- # Image is already in [0, 255] range
850
- image = image.astype(np.uint8)
851
- except Exception as e:
852
- print(f"Image conversion error: {e}")
853
- return None, [], f"Image conversion error: {e}"
854
-
855
- # Draw circles for points
856
- try:
857
- for p in measure_points:
858
- if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
859
- image = cv2.circle(
860
- image, p, radius=5, color=(255, 0, 0), thickness=2
861
- )
862
- except Exception as e:
863
- print(f"Drawing error: {e}")
864
- return None, [], f"Drawing error: {e}"
865
 
866
  depth_text = ""
867
- try:
868
- for i, p in enumerate(measure_points):
869
- if (
870
- current_view["depth"] is not None
871
- and 0 <= p[1] < current_view["depth"].shape[0]
872
- and 0 <= p[0] < current_view["depth"].shape[1]
873
- ):
874
- d = current_view["depth"][p[1], p[0]]
875
- depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n"
876
- else:
877
- # Use Z coordinate of 3D points if depth not available
878
- if (
879
- points3d is not None
880
- and 0 <= p[1] < points3d.shape[0]
881
- and 0 <= p[0] < points3d.shape[1]
882
- ):
883
- z = points3d[p[1], p[0], 2]
884
- depth_text += f"- **P{i + 1} Z-coord: {z:.2f}m.**\n"
885
- except Exception as e:
886
- print(f"Depth text error: {e}")
887
- depth_text = f"Error computing depth: {e}\n"
888
 
889
  if len(measure_points) == 2:
890
- try:
891
- point1, point2 = measure_points
892
- # Draw line
893
- if (
894
- 0 <= point1[0] < image.shape[1]
895
- and 0 <= point1[1] < image.shape[0]
896
- and 0 <= point2[0] < image.shape[1]
897
- and 0 <= point2[1] < image.shape[0]
898
- ):
899
- image = cv2.line(
900
- image, point1, point2, color=(255, 0, 0), thickness=2
901
- )
902
-
903
- # Compute 3D distance
904
- distance_text = "- **Distance: Unable to compute**"
905
- if (
906
- points3d is not None
907
- and 0 <= point1[1] < points3d.shape[0]
908
- and 0 <= point1[0] < points3d.shape[1]
909
- and 0 <= point2[1] < points3d.shape[0]
910
- and 0 <= point2[0] < points3d.shape[1]
911
- ):
912
- try:
913
- p1_3d = points3d[point1[1], point1[0]]
914
- p2_3d = points3d[point2[1], point2[0]]
915
- distance = np.linalg.norm(p1_3d - p2_3d)
916
- distance_text = f"- **Distance: {distance:.2f}m**"
917
- except Exception as e:
918
- print(f"Distance computation error: {e}")
919
- distance_text = f"- **Distance computation error: {e}**"
920
-
921
- measure_points = []
922
- text = depth_text + distance_text
923
- print(f"Measurement complete: {text}")
924
- return [image, measure_points, text]
925
- except Exception as e:
926
- print(f"Final measurement error: {e}")
927
- return None, [], f"Measurement error: {e}"
928
- else:
929
- print(f"Single point measurement: {depth_text}")
930
- return [image, measure_points, depth_text]
931
-
932
  except Exception as e:
933
- print(f"Overall measure function error: {e}")
934
- return None, [], f"Measure function error: {e}"
935
 
936
 
937
  def clear_fields():
938
- """
939
- Clears the 3D viewer, the stored target_dir, and empties the gallery.
940
- """
941
  return None
942
 
943
 
944
  def update_log():
945
- """
946
- Display a quick log message while waiting.
947
- """
948
- return "Loading and Reconstructing..."
949
-
950
-
951
- def update_visualization(
952
- target_dir,
953
- frame_filter,
954
- show_cam,
955
- is_example,
956
- filter_black_bg=False,
957
- filter_white_bg=False,
958
- show_mesh=True,
959
- ):
960
- """
961
- Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
962
- wrap it in a Rerun recording (.rrd), and return it for the Rerun viewer.
963
- """
964
-
965
- # If it's an example click, skip as requested
966
- if is_example == "True":
967
- return (
968
- gr.update(),
969
- "No reconstruction available. Please click the Reconstruct button first.",
970
- )
971
 
972
- if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
973
- return (
974
- gr.update(),
975
- "No reconstruction available. Please click the Reconstruct button first.",
976
- )
977
 
 
 
 
 
 
978
  predictions_path = os.path.join(target_dir, "predictions.npz")
979
  if not os.path.exists(predictions_path):
980
- return (
981
- gr.update(),
982
- f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
983
- )
984
 
985
  loaded = np.load(predictions_path, allow_pickle=True)
986
  predictions = {key: loaded[key] for key in loaded.keys()}
@@ -989,95 +567,36 @@ def update_visualization(
989
  target_dir,
990
  f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
991
  )
992
-
993
  if not os.path.exists(glbfile):
994
- glbscene = predictions_to_glb(
995
- predictions,
996
- filter_by_frames=frame_filter,
997
- show_cam=show_cam,
998
- mask_black_bg=filter_black_bg,
999
- mask_white_bg=filter_white_bg,
1000
- as_mesh=show_mesh,
1001
- )
1002
  glbscene.export(file_obj=glbfile)
1003
-
1004
- # Generate the Rerun recording using the helper
1005
- rrd_path = create_rerun_recording(glbfile, target_dir)
1006
 
1007
- return (
1008
- rrd_path, # Was glbfile
1009
- "Visualization updated.",
1010
- )
1011
 
1012
 
1013
- def update_all_views_on_filter_change(
1014
- target_dir,
1015
- filter_black_bg,
1016
- filter_white_bg,
1017
- processed_data,
1018
- depth_view_selector,
1019
- normal_view_selector,
1020
- measure_view_selector,
1021
- ):
1022
- """
1023
- Update all individual view tabs when background filtering checkboxes change.
1024
- This regenerates the processed data with new filtering and updates all views.
1025
- """
1026
- # Check if we have a valid target directory and predictions
1027
  if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
1028
  return processed_data, None, None, None, []
1029
-
1030
  predictions_path = os.path.join(target_dir, "predictions.npz")
1031
  if not os.path.exists(predictions_path):
1032
  return processed_data, None, None, None, []
1033
-
1034
  try:
1035
- # Load the original predictions and views
1036
  loaded = np.load(predictions_path, allow_pickle=True)
1037
  predictions = {key: loaded[key] for key in loaded.keys()}
1038
-
1039
- # Load images using MapAnything's load_images function
1040
- image_folder_path = os.path.join(target_dir, "images")
1041
- views = load_images(image_folder_path)
1042
-
1043
- # Regenerate processed data with new filtering settings
1044
- new_processed_data = process_predictions_for_visualization(
1045
- predictions, views, high_level_config, filter_black_bg, filter_white_bg
1046
- )
1047
-
1048
- # Get current view indices
1049
- try:
1050
- depth_view_idx = (
1051
- int(depth_view_selector.split()[1]) - 1 if depth_view_selector else 0
1052
- )
1053
- except:
1054
- depth_view_idx = 0
1055
-
1056
- try:
1057
- normal_view_idx = (
1058
- int(normal_view_selector.split()[1]) - 1 if normal_view_selector else 0
1059
- )
1060
- except:
1061
- normal_view_idx = 0
1062
-
1063
- try:
1064
- measure_view_idx = (
1065
- int(measure_view_selector.split()[1]) - 1
1066
- if measure_view_selector
1067
- else 0
1068
- )
1069
- except:
1070
- measure_view_idx = 0
1071
-
1072
- # Update all views with new filtered data
1073
- depth_vis = update_depth_view(new_processed_data, depth_view_idx)
1074
- normal_vis = update_normal_view(new_processed_data, normal_view_idx)
1075
- measure_img, _ = update_measure_view(new_processed_data, measure_view_idx)
1076
-
1077
  return new_processed_data, depth_vis, normal_vis, measure_img, []
1078
-
1079
  except Exception as e:
1080
- print(f"Error updating views on filter change: {e}")
1081
  return processed_data, None, None, None, []
1082
 
1083
 
@@ -1085,446 +604,469 @@ def update_all_views_on_filter_change(
1085
  # Example scene functions
1086
  # -------------------------------------------------------------------------
1087
  def get_scene_info(examples_dir):
1088
- """Get information about scenes in the examples directory"""
1089
  import glob
1090
-
1091
  scenes = []
1092
  if not os.path.exists(examples_dir):
1093
  return scenes
1094
-
1095
  for scene_folder in sorted(os.listdir(examples_dir)):
1096
  scene_path = os.path.join(examples_dir, scene_folder)
1097
  if os.path.isdir(scene_path):
1098
- # Find all image files in the scene folder
1099
- image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]
1100
  image_files = []
1101
- for ext in image_extensions:
1102
  image_files.extend(glob.glob(os.path.join(scene_path, ext)))
1103
  image_files.extend(glob.glob(os.path.join(scene_path, ext.upper())))
1104
-
1105
  if image_files:
1106
- # Sort images and get the first one for thumbnail
1107
  image_files = sorted(image_files)
1108
- first_image = image_files[0]
1109
- num_images = len(image_files)
1110
-
1111
- scenes.append(
1112
- {
1113
- "name": scene_folder,
1114
- "path": scene_path,
1115
- "thumbnail": first_image,
1116
- "num_images": num_images,
1117
- "image_files": image_files,
1118
- }
1119
- )
1120
-
1121
  return scenes
1122
 
1123
 
1124
  def load_example_scene(scene_name, examples_dir="examples"):
1125
- """Load a scene from examples directory"""
1126
  scenes = get_scene_info(examples_dir)
1127
-
1128
- # Find the selected scene
1129
- selected_scene = None
1130
- for scene in scenes:
1131
- if scene["name"] == scene_name:
1132
- selected_scene = scene
1133
- break
1134
-
1135
  if selected_scene is None:
1136
  return None, None, None, "Scene not found"
 
 
 
1137
 
1138
- # Create file-like objects for the unified upload system
1139
- # Convert image file paths to the format expected by unified_upload
1140
- file_objects = []
1141
- for image_path in selected_scene["image_files"]:
1142
- file_objects.append(image_path)
 
 
 
 
 
1143
 
1144
- # Create target directory and copy images using the unified upload system
1145
- target_dir, image_paths = handle_uploads(file_objects, 1.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1146
 
1147
- return (
1148
- None, # Clear reconstruction output
1149
- target_dir, # Set target directory
1150
- image_paths, # Set gallery
1151
- f"Loaded scene '{scene_name}' with {selected_scene['num_images']} images. Click 'Reconstruct' to begin 3D processing.",
1152
- )
 
 
 
 
 
 
 
 
 
1153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1154
 
1155
  # -------------------------------------------------------------------------
1156
  # 6) Build Gradio UI
1157
  # -------------------------------------------------------------------------
1158
- theme = get_gradio_theme()
1159
-
1160
- with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
1161
- # State variables
1162
- is_example = gr.Textbox(label="is_example", visible=False, value="None")
1163
- num_images = gr.Textbox(label="num_images", visible=False, value="None")
1164
- processed_data_state = gr.State(value=None)
1165
- measure_points_state = gr.State(value=[])
1166
- current_view_index = gr.State(value=0)
1167
- target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
1168
-
1169
- # --- Header Area ---
1170
- with gr.Column(elem_id="header-container"):
1171
- gr.Markdown(
1172
- "<div style='text-align: center; max-width: 800px; margin: 0 auto; padding-top: 10px;'>"
1173
- "<h1>🗺️ Map-Anything-v1</h1>"
1174
- "<h3 style='color: #666; font-weight: 400;'>Metric 3D Reconstruction (Point Cloud and Camera Poses)</h3>"
1175
- "</div>"
1176
- )
1177
- gr.Markdown("---")
1178
-
1179
- # --- Main App Layout ---
1180
- with gr.Row():
1181
-
1182
- # LEFT COLUMN (Sidebar / Controls)
1183
- with gr.Column(scale=1, min_width=350):
1184
-
1185
- with gr.Group():
1186
- gr.Markdown("### 📁 1. Input Media")
1187
  unified_upload = gr.File(
1188
  file_count="multiple",
1189
  label="Upload Video or Images",
1190
- interactive=True,
1191
  file_types=["image", "video"],
 
1192
  )
 
1193
  with gr.Row():
1194
  s_time_interval = gr.Slider(
1195
- minimum=0.1,
1196
- maximum=5.0,
1197
- value=1.0,
1198
- step=0.1,
1199
  label="Video sample interval (sec)",
1200
- interactive=True,
1201
- visible=True,
1202
  )
1203
- resample_btn = gr.Button("Resample", visible=False, variant="secondary")
1204
 
 
1205
  image_gallery = gr.Gallery(
1206
- label="Preview",
1207
- columns=4,
1208
- height="200px",
1209
- object_fit="contain",
1210
  preview=True,
 
 
1211
  )
1212
- clear_uploads_btn = gr.ClearButton(
 
1213
  [unified_upload, image_gallery],
1214
- value="Clear Uploads",
1215
  variant="secondary",
1216
  size="sm",
1217
  )
1218
 
1219
- with gr.Group():
1220
- gr.Markdown("### ⚙️ 2. Reconstruction Settings")
1221
- apply_mask_checkbox = gr.Checkbox(
1222
- label="Apply mask (depth classes & edges)",
1223
- value=True,
1224
- )
1225
-
1226
- with gr.Row():
1227
- submit_btn = gr.Button("🚀 Reconstruct", variant="primary", scale=2)
1228
- clear_btn = gr.ClearButton(
1229
- [
1230
- unified_upload,
1231
- target_dir_output,
1232
- image_gallery,
1233
- ],
1234
- value="Clear All",
1235
- scale=1,
1236
  )
1237
 
1238
- with gr.Accordion("🎨 Visualization Options", open=True):
1239
- gr.Markdown("*Note: Updates automatically applied to viewer.*")
1240
  frame_filter = gr.Dropdown(
1241
- choices=["All"], value="All", label="Show Points from Frame"
 
1242
  )
1243
- show_cam = gr.Checkbox(label="Show Camera Paths", value=True)
1244
- show_mesh = gr.Checkbox(label="Show 3D Mesh", value=True)
1245
- filter_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
1246
- filter_white_bg = gr.Checkbox(label="Filter White Background", value=False)
1247
-
1248
-
1249
- # RIGHT COLUMN (Main Viewer Area)
1250
- with gr.Column(scale=2, min_width=600):
1251
- log_output = gr.Markdown("Status: **Ready**. Please upload media or select an example scene below.", elem_classes=["custom-log"])
1252
-
1253
- with gr.Tabs():
1254
- with gr.Tab("3D View"):
1255
- reconstruction_output = Rerun(
1256
- label="Rerun 3D Viewer",
1257
- height=600,
1258
- )
1259
- with gr.Tab("Depth"):
1260
- with gr.Row(elem_classes=["navigation-row"]):
1261
- prev_depth_btn = gr.Button("◀ Previous", size="sm", scale=1)
1262
- depth_view_selector = gr.Dropdown(
1263
- choices=["View 1"],
1264
- value="View 1",
1265
- label="Select View",
1266
- scale=2,
1267
- interactive=True,
1268
- allow_custom_value=True,
1269
- )
1270
- next_depth_btn = gr.Button("Next ▶", size="sm", scale=1)
1271
- depth_map = gr.Image(
1272
- type="numpy",
1273
- label="Colorized Depth Map",
1274
- format="png",
1275
- interactive=False,
1276
- )
1277
- with gr.Tab("Normal"):
1278
- with gr.Row(elem_classes=["navigation-row"]):
1279
- prev_normal_btn = gr.Button("◀ Previous", size="sm", scale=1)
1280
- normal_view_selector = gr.Dropdown(
1281
- choices=["View 1"],
1282
- value="View 1",
1283
- label="Select View",
1284
- scale=2,
1285
- interactive=True,
1286
- allow_custom_value=True,
1287
- )
1288
- next_normal_btn = gr.Button("Next ▶", size="sm", scale=1)
1289
- normal_map = gr.Image(
1290
- type="numpy",
1291
- label="Normal Map",
1292
- format="png",
1293
- interactive=False,
1294
- )
1295
- with gr.Tab("Measure"):
1296
- gr.Markdown(MEASURE_INSTRUCTIONS_HTML)
1297
- with gr.Row(elem_classes=["navigation-row"]):
1298
- prev_measure_btn = gr.Button("◀ Previous", size="sm", scale=1)
1299
- measure_view_selector = gr.Dropdown(
1300
- choices=["View 1"],
1301
- value="View 1",
1302
- label="Select View",
1303
- scale=2,
1304
- interactive=True,
1305
- allow_custom_value=True,
1306
  )
1307
- next_measure_btn = gr.Button("Next ▶", size="sm", scale=1)
1308
- measure_image = gr.Image(
1309
- type="numpy",
1310
- show_label=False,
1311
- format="webp",
1312
- interactive=False,
1313
- sources=[],
1314
- )
1315
- gr.Markdown("**Note:** Light-grey areas indicate regions with no depth information where measurements cannot be taken.")
1316
- measure_text = gr.Markdown("")
1317
-
1318
- # --- Footer Area (Example Scenes) ---
1319
- gr.Markdown("---")
1320
- gr.Markdown("## 🌟 Example Scenes")
1321
- gr.Markdown("Click any thumbnail below to load a sample dataset for reconstruction.")
1322
-
1323
- scenes = get_scene_info("examples")
1324
-
1325
- if scenes:
1326
- for i in range(0, len(scenes), 4):
1327
- with gr.Row():
1328
- for j in range(4):
1329
- scene_idx = i + j
1330
- if scene_idx < len(scenes):
1331
- scene = scenes[scene_idx]
1332
- with gr.Column(scale=1, elem_classes=["clickable-thumbnail"]):
1333
- scene_img = gr.Image(
1334
- value=scene["thumbnail"],
1335
- height=150,
1336
- interactive=False,
1337
- show_label=False,
1338
- elem_id=f"scene_thumb_{scene['name']}",
1339
- sources=[],
1340
  )
1341
- gr.Markdown(
1342
- f"**{scene['name']}** \n {scene['num_images']} images",
1343
- elem_classes=["scene-info"],
 
 
 
 
 
 
 
 
 
 
 
1344
  )
1345
- # Clicking an example bypasses the manual process and loads everything automatically
1346
- scene_img.select(
1347
- fn=lambda name=scene["name"]: load_example_scene(name),
1348
- outputs=[
1349
- reconstruction_output, # To clear old view
1350
- target_dir_output,
1351
- image_gallery,
1352
- log_output,
1353
- ],
 
 
 
 
 
 
1354
  )
1355
- else:
1356
- with gr.Column(scale=1):
1357
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1358
 
1359
  # =========================================================================
1360
- # Event Bindings & Logic
1361
  # =========================================================================
1362
 
1363
- submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
 
 
 
1364
  fn=update_log, inputs=[], outputs=[log_output]
1365
  ).then(
1366
  fn=gradio_demo,
1367
- inputs=[
1368
- target_dir_output,
1369
- frame_filter,
1370
- show_cam,
1371
- filter_black_bg,
1372
- filter_white_bg,
1373
- apply_mask_checkbox,
1374
- show_mesh,
1375
- ],
1376
  outputs=[
1377
- reconstruction_output,
1378
- log_output,
1379
- frame_filter,
1380
- processed_data_state,
1381
- depth_map,
1382
- normal_map,
1383
- measure_image,
1384
- measure_text,
1385
- depth_view_selector,
1386
- normal_view_selector,
1387
- measure_view_selector,
1388
  ],
1389
- ).then(
1390
- fn=lambda: "False",
1391
- inputs=[],
1392
- outputs=[is_example],
1393
- )
 
 
 
 
1394
 
1395
- # Real-time Visualization Updates
1396
- frame_filter.change(
1397
- update_visualization,
1398
- [target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh],
1399
- [reconstruction_output, log_output],
1400
- )
1401
- show_cam.change(
1402
- update_visualization,
1403
- [target_dir_output, frame_filter, show_cam, is_example],
1404
- [reconstruction_output, log_output],
1405
- )
1406
  filter_black_bg.change(
1407
  update_visualization,
1408
  [target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg],
1409
  [reconstruction_output, log_output],
1410
  ).then(
1411
- fn=update_all_views_on_filter_change,
1412
- inputs=[target_dir_output, filter_black_bg, filter_white_bg, processed_data_state, depth_view_selector, normal_view_selector, measure_view_selector],
1413
- outputs=[processed_data_state, depth_map, normal_map, measure_image, measure_points_state],
1414
  )
1415
  filter_white_bg.change(
1416
  update_visualization,
1417
  [target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh],
1418
  [reconstruction_output, log_output],
1419
  ).then(
1420
- fn=update_all_views_on_filter_change,
1421
- inputs=[target_dir_output, filter_black_bg, filter_white_bg, processed_data_state, depth_view_selector, normal_view_selector, measure_view_selector],
1422
- outputs=[processed_data_state, depth_map, normal_map, measure_image, measure_points_state],
1423
- )
1424
- show_mesh.change(
1425
- update_visualization,
1426
- [target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh],
1427
- [reconstruction_output, log_output],
1428
  )
1429
 
1430
- # Auto-update gallery on upload
1431
  def update_gallery_on_unified_upload(files, interval):
1432
  if not files:
1433
- return None, None, "Ready for upload."
1434
  target_dir, image_paths = handle_uploads(files, interval)
1435
- return target_dir, image_paths, "Upload complete. Click '🚀 Reconstruct' to begin 3D processing."
1436
 
1437
  def show_resample_button(files):
1438
- if not files: return gr.update(visible=False)
 
1439
  video_exts = [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"]
1440
- has_video = False
1441
- for f_data in files:
1442
- f_path = str(f_data["name"] if isinstance(f_data, dict) else f_data)
1443
- if os.path.splitext(f_path)[1].lower() in video_exts:
1444
- has_video = True
1445
- break
1446
  return gr.update(visible=has_video)
1447
 
1448
  def resample_video_with_new_interval(files, new_interval, current_target_dir):
1449
- if not files: return current_target_dir, None, "No files to resample.", gr.update(visible=False)
 
1450
  video_exts = [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"]
1451
- has_video = any(os.path.splitext(str(f["name"] if isinstance(f, dict) else f))[1].lower() in video_exts for f in files)
1452
-
1453
- if not has_video: return current_target_dir, None, "No videos found.", gr.update(visible=False)
1454
-
1455
  if current_target_dir and current_target_dir != "None" and os.path.exists(current_target_dir):
1456
  shutil.rmtree(current_target_dir)
1457
-
1458
  target_dir, image_paths = handle_uploads(files, new_interval)
1459
- return target_dir, image_paths, f"Video resampled ({new_interval}s interval). Click '🚀 Reconstruct'.", gr.update(visible=False)
1460
 
1461
  unified_upload.change(
1462
  fn=update_gallery_on_unified_upload,
1463
  inputs=[unified_upload, s_time_interval],
1464
  outputs=[target_dir_output, image_gallery, log_output],
1465
- ).then(
1466
- fn=show_resample_button,
1467
- inputs=[unified_upload],
1468
- outputs=[resample_btn],
1469
- )
1470
-
1471
- s_time_interval.change(
1472
- fn=show_resample_button,
1473
- inputs=[unified_upload],
1474
- outputs=[resample_btn],
1475
- )
1476
 
 
1477
  resample_btn.click(
1478
  fn=resample_video_with_new_interval,
1479
  inputs=[unified_upload, s_time_interval, target_dir_output],
1480
  outputs=[target_dir_output, image_gallery, log_output, resample_btn],
1481
  )
1482
 
1483
- # Measure Interactions
1484
  measure_image.select(
1485
  fn=measure,
1486
  inputs=[processed_data_state, measure_points_state, measure_view_selector],
1487
  outputs=[measure_image, measure_points_state, measure_text],
1488
  )
1489
 
1490
- # Tab Navigations
1491
  prev_depth_btn.click(
1492
- fn=lambda d, s: navigate_depth_view(d, s, -1),
1493
  inputs=[processed_data_state, depth_view_selector], outputs=[depth_view_selector, depth_map],
1494
  )
1495
  next_depth_btn.click(
1496
- fn=lambda d, s: navigate_depth_view(d, s, 1),
1497
  inputs=[processed_data_state, depth_view_selector], outputs=[depth_view_selector, depth_map],
1498
  )
1499
  depth_view_selector.change(
1500
- fn=lambda d, s: update_depth_view(d, int(s.split()[1]) - 1) if s else None,
1501
  inputs=[processed_data_state, depth_view_selector], outputs=[depth_map],
1502
  )
1503
 
 
1504
  prev_normal_btn.click(
1505
- fn=lambda d, s: navigate_normal_view(d, s, -1),
1506
  inputs=[processed_data_state, normal_view_selector], outputs=[normal_view_selector, normal_map],
1507
  )
1508
  next_normal_btn.click(
1509
- fn=lambda d, s: navigate_normal_view(d, s, 1),
1510
  inputs=[processed_data_state, normal_view_selector], outputs=[normal_view_selector, normal_map],
1511
  )
1512
  normal_view_selector.change(
1513
- fn=lambda d, s: update_normal_view(d, int(s.split()[1]) - 1) if s else None,
1514
  inputs=[processed_data_state, normal_view_selector], outputs=[normal_map],
1515
  )
1516
 
 
1517
  prev_measure_btn.click(
1518
- fn=lambda d, s: navigate_measure_view(d, s, -1),
1519
- inputs=[processed_data_state, measure_view_selector], outputs=[measure_view_selector, measure_image, measure_points_state],
 
1520
  )
1521
  next_measure_btn.click(
1522
- fn=lambda d, s: navigate_measure_view(d, s, 1),
1523
- inputs=[processed_data_state, measure_view_selector], outputs=[measure_view_selector, measure_image, measure_points_state],
 
1524
  )
1525
  measure_view_selector.change(
1526
- fn=lambda d, s: update_measure_view(d, int(s.split()[1]) - 1) if s else (None, []),
1527
- inputs=[processed_data_state, measure_view_selector], outputs=[measure_image, measure_points_state],
 
1528
  )
1529
 
1530
- demo.queue(max_size=20).launch(theme=theme, css=GRADIO_CSS, show_error=True, share=True, ssr_mode=False)
 
16
  from PIL import Image
17
  from pillow_heif import register_heif_opener
18
 
 
19
  import rerun as rr
20
  try:
21
  import rerun.blueprint as rrb
22
  except ImportError:
23
  rrb = None
24
+
25
  from gradio_rerun import Rerun
26
 
27
  register_heif_opener()
 
41
  from mapanything.utils.hf_utils.viz import predictions_to_glb
42
  from mapanything.utils.image import load_images, rgb
43
 
 
44
  high_level_config = {
45
  "path": "configs/train.yaml",
46
+ "hf_model_name": "facebook/map-anything-v1",
47
  "model_str": "mapanything",
48
  "config_overrides": [
49
  "machine=aws",
 
60
  "resolution": 518,
61
  }
62
 
 
63
  model = None
64
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
65
+ os.makedirs(TMP_DIR, exist_ok=True)
66
 
67
 
68
  # -------------------------------------------------------------------------
69
+ # Rerun visualization helper
70
  # -------------------------------------------------------------------------
71
+ def predictions_to_rrd(predictions, glbfile, target_dir, frame_filter="All", show_cam=True):
 
 
 
 
72
  run_id = str(uuid.uuid4())
73
+ timestamp = datetime.now().strftime("%Y-%m-%dT%H%M%S")
74
+ rrd_path = os.path.join(target_dir, f"mapanything_{timestamp}.rrd")
75
+
76
  rec = None
77
  if hasattr(rr, "new_recording"):
78
+ rec = rr.new_recording(application_id="MapAnything-3D-Viewer", recording_id=run_id)
79
  elif hasattr(rr, "RecordingStream"):
80
+ rec = rr.RecordingStream(application_id="MapAnything-3D-Viewer", recording_id=run_id)
81
  else:
82
+ rr.init("MapAnything-3D-Viewer", recording_id=run_id, spawn=False)
83
  rec = rr
84
+
 
85
  rec.log("world", rr.Clear(recursive=True), static=True)
 
 
86
  rec.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, static=True)
87
 
 
88
  try:
89
  rec.log("world/axes/x", rr.Arrows3D(vectors=[[0.5, 0, 0]], colors=[[255, 0, 0]]), static=True)
90
  rec.log("world/axes/y", rr.Arrows3D(vectors=[[0, 0.5, 0]], colors=[[0, 255, 0]]), static=True)
 
92
  except Exception:
93
  pass
94
 
95
+ rec.log("world/model", rr.Asset3D(path=glbfile), static=True)
96
+
97
+ if show_cam and "extrinsic" in predictions and "intrinsic" in predictions:
98
+ try:
99
+ extrinsics = predictions["extrinsic"]
100
+ intrinsics = predictions["intrinsic"]
101
+ for i, (ext, intr) in enumerate(zip(extrinsics, intrinsics)):
102
+ translation = ext[:3, 3]
103
+ rotation_mat = ext[:3, :3]
104
+ rec.log(
105
+ f"world/cameras/cam_{i:03d}",
106
+ rr.Transform3D(translation=translation, mat3x3=rotation_mat),
107
+ static=True,
108
+ )
109
+ fx, fy = intr[0, 0], intr[1, 1]
110
+ cx, cy = intr[0, 2], intr[1, 2]
111
+ if "images" in predictions and i < len(predictions["images"]):
112
+ h, w = predictions["images"][i].shape[:2]
113
+ else:
114
+ h, w = 518, 518
115
+ rec.log(
116
+ f"world/cameras/cam_{i:03d}/image",
117
+ rr.Pinhole(focal_length=[fx, fy], principal_point=[cx, cy], width=w, height=h),
118
+ static=True,
119
+ )
120
+ if "images" in predictions and i < len(predictions["images"]):
121
+ img = predictions["images"][i]
122
+ if img.dtype != np.uint8:
123
+ img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
124
+ rec.log(f"world/cameras/cam_{i:03d}/image/rgb", rr.Image(img), static=True)
125
+ except Exception as e:
126
+ print(f"Camera logging failed (non-fatal): {e}")
127
+
128
+ if "world_points" in predictions and "images" in predictions:
129
+ try:
130
+ world_points = predictions["world_points"]
131
+ images = predictions["images"]
132
+ final_mask = predictions.get("final_mask")
133
+ all_points, all_colors = [], []
134
+ for i in range(len(world_points)):
135
+ pts = world_points[i]
136
+ img = images[i]
137
+ mask = final_mask[i].astype(bool) if final_mask is not None else np.ones(pts.shape[:2], dtype=bool)
138
+ pts_flat = pts[mask]
139
+ img_flat = img[mask]
140
+ if img_flat.dtype != np.uint8:
141
+ img_flat = (np.clip(img_flat, 0, 1) * 255).astype(np.uint8)
142
+ all_points.append(pts_flat)
143
+ all_colors.append(img_flat)
144
+ if all_points:
145
+ all_points = np.concatenate(all_points, axis=0)
146
+ all_colors = np.concatenate(all_colors, axis=0)
147
+ max_pts = 500_000
148
+ if len(all_points) > max_pts:
149
+ idx = np.random.choice(len(all_points), max_pts, replace=False)
150
+ all_points = all_points[idx]
151
+ all_colors = all_colors[idx]
152
+ rec.log("world/point_cloud", rr.Points3D(positions=all_points, colors=all_colors, radii=0.002), static=True)
153
+ except Exception as e:
154
+ print(f"Point cloud logging failed (non-fatal): {e}")
155
+
156
  if rrb is not None:
157
  try:
158
  blueprint = rrb.Blueprint(
159
+ rrb.Spatial3DView(origin="/world", name="3D View"),
 
 
 
160
  collapse_panels=True,
161
  )
162
  rec.send_blueprint(blueprint)
163
  except Exception as e:
164
  print(f"Blueprint creation failed (non-fatal): {e}")
165
 
 
 
166
  rec.save(rrd_path)
 
167
  return rrd_path
168
 
169
 
 
171
  # 1) Core model inference
172
  # -------------------------------------------------------------------------
173
  @spaces.GPU(duration=120)
174
+ def run_model(target_dir, apply_mask=True, mask_edges=True, filter_black_bg=False, filter_white_bg=False):
 
 
 
 
 
 
 
 
 
175
  global model
176
+ import torch
177
 
178
  print(f"Processing images from {target_dir}")
179
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
180
 
 
 
 
 
 
181
  if model is None:
182
  model = initialize_mapanything_model(high_level_config, device)
 
183
  else:
184
  model = model.to(device)
 
185
  model.eval()
186
 
 
187
  print("Loading images...")
188
  image_folder_path = os.path.join(target_dir, "images")
189
  views = load_images(image_folder_path)
 
190
  print(f"Loaded {len(views)} images")
191
  if len(views) == 0:
192
  raise ValueError("No images found. Check your upload.")
193
 
 
194
  print("Running inference...")
195
+ outputs = model.infer(views, apply_mask=apply_mask, mask_edges=True, memory_efficient_inference=False)
 
 
 
 
 
196
 
 
197
  predictions = {}
198
+ extrinsic_list, intrinsic_list, world_points_list = [], [], []
199
+ depth_maps_list, images_list, final_mask_list = [], [], []
200
 
 
 
 
 
 
 
 
 
 
201
  for pred in outputs:
202
+ depthmap_torch = pred["depth_z"][0].squeeze(-1)
203
+ intrinsics_torch = pred["intrinsics"][0]
204
+ camera_pose_torch = pred["camera_poses"][0]
205
+ pts3d_computed, valid_mask = depthmap_to_world_frame(depthmap_torch, intrinsics_torch, camera_pose_torch)
206
+ mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool) if "mask" in pred else np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  mask = mask & valid_mask.cpu().numpy()
 
208
  image = pred["img_no_norm"][0].cpu().numpy()
 
 
209
  extrinsic_list.append(camera_pose_torch.cpu().numpy())
210
  intrinsic_list.append(intrinsics_torch.cpu().numpy())
211
  world_points_list.append(pts3d_computed.cpu().numpy())
212
  depth_maps_list.append(depthmap_torch.cpu().numpy())
213
+ images_list.append(image)
214
+ final_mask_list.append(mask)
215
 
 
 
216
  predictions["extrinsic"] = np.stack(extrinsic_list, axis=0)
 
 
217
  predictions["intrinsic"] = np.stack(intrinsic_list, axis=0)
 
 
218
  predictions["world_points"] = np.stack(world_points_list, axis=0)
 
 
219
  depth_maps = np.stack(depth_maps_list, axis=0)
 
220
  if len(depth_maps.shape) == 3:
221
  depth_maps = depth_maps[..., np.newaxis]
 
222
  predictions["depth"] = depth_maps
 
 
223
  predictions["images"] = np.stack(images_list, axis=0)
 
 
224
  predictions["final_mask"] = np.stack(final_mask_list, axis=0)
225
 
226
+ processed_data = process_predictions_for_visualization(predictions, views, high_level_config, filter_black_bg, filter_white_bg)
 
 
 
 
 
227
  torch.cuda.empty_cache()
 
228
  return predictions, processed_data
229
 
230
 
231
  def update_view_selectors(processed_data):
232
+ choices = [f"View {i + 1}" for i in range(len(processed_data))] if processed_data else ["View 1"]
 
 
 
 
 
 
233
  return (
234
+ gr.Dropdown(choices=choices, value=choices[0]),
235
+ gr.Dropdown(choices=choices, value=choices[0]),
236
+ gr.Dropdown(choices=choices, value=choices[0]),
237
  )
238
 
239
 
240
  def get_view_data_by_index(processed_data, view_index):
241
+ if not processed_data:
 
242
  return None
 
243
  view_keys = list(processed_data.keys())
244
+ view_index = max(0, min(view_index, len(view_keys) - 1))
 
 
245
  return processed_data[view_keys[view_index]]
246
 
247
 
248
  def update_depth_view(processed_data, view_index):
 
249
  view_data = get_view_data_by_index(processed_data, view_index)
250
  if view_data is None or view_data["depth"] is None:
251
  return None
 
252
  return colorize_depth(view_data["depth"], mask=view_data.get("mask"))
253
 
254
 
255
  def update_normal_view(processed_data, view_index):
 
256
  view_data = get_view_data_by_index(processed_data, view_index)
257
  if view_data is None or view_data["normal"] is None:
258
  return None
 
259
  return colorize_normal(view_data["normal"], mask=view_data.get("mask"))
260
 
261
 
262
  def update_measure_view(processed_data, view_index):
 
263
  view_data = get_view_data_by_index(processed_data, view_index)
264
  if view_data is None:
265
+ return None, []
 
 
266
  image = view_data["image"].copy()
 
 
267
  if image.dtype != np.uint8:
268
+ image = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8)
 
 
 
 
 
269
  if view_data["mask"] is not None:
270
+ invalid_mask = ~view_data["mask"]
 
 
 
 
 
271
  if invalid_mask.any():
 
272
  overlay_color = np.array([255, 220, 220], dtype=np.uint8)
273
+ alpha = 0.5
274
+ for c in range(3):
275
+ image[:, :, c] = np.where(invalid_mask, (1 - alpha) * image[:, :, c] + alpha * overlay_color[c], image[:, :, c]).astype(np.uint8)
 
 
 
 
 
 
 
276
  return image, []
277
 
278
 
279
  def navigate_depth_view(processed_data, current_selector_value, direction):
280
+ if not processed_data:
 
281
  return "View 1", None
 
 
282
  try:
283
  current_view = int(current_selector_value.split()[1]) - 1
284
  except:
285
  current_view = 0
286
+ new_view = (current_view + direction) % len(processed_data)
287
+ return f"View {new_view + 1}", update_depth_view(processed_data, new_view)
 
 
 
 
 
 
288
 
289
 
290
  def navigate_normal_view(processed_data, current_selector_value, direction):
291
+ if not processed_data:
 
292
  return "View 1", None
 
 
293
  try:
294
  current_view = int(current_selector_value.split()[1]) - 1
295
  except:
296
  current_view = 0
297
+ new_view = (current_view + direction) % len(processed_data)
298
+ return f"View {new_view + 1}", update_normal_view(processed_data, new_view)
 
 
 
 
 
 
299
 
300
 
301
  def navigate_measure_view(processed_data, current_selector_value, direction):
302
+ if not processed_data:
 
303
  return "View 1", None, []
 
 
304
  try:
305
  current_view = int(current_selector_value.split()[1]) - 1
306
  except:
307
  current_view = 0
308
+ new_view = (current_view + direction) % len(processed_data)
 
 
 
 
309
  measure_image, measure_points = update_measure_view(processed_data, new_view)
310
+ return f"View {new_view + 1}", measure_image, measure_points
 
311
 
312
 
313
  def populate_visualization_tabs(processed_data):
314
+ if not processed_data:
 
315
  return None, None, None, []
316
+ return update_depth_view(processed_data, 0), update_normal_view(processed_data, 0), update_measure_view(processed_data, 0)[0], []
 
 
 
 
 
 
317
 
318
 
319
  # -------------------------------------------------------------------------
320
+ # 2) Handle uploaded video/images
321
  # -------------------------------------------------------------------------
322
  def handle_uploads(unified_upload, s_time_interval=1.0):
 
 
 
 
323
  start_time = time.time()
324
  gc.collect()
325
  torch.cuda.empty_cache()
326
 
 
327
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
328
  target_dir = f"input_images_{timestamp}"
329
  target_dir_images = os.path.join(target_dir, "images")
 
 
330
  if os.path.exists(target_dir):
331
  shutil.rmtree(target_dir)
 
332
  os.makedirs(target_dir_images)
333
 
334
  image_paths = []
335
+ video_extensions = [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"]
336
 
 
337
  if unified_upload is not None:
338
  for file_data in unified_upload:
339
+ file_path = file_data["name"] if isinstance(file_data, dict) and "name" in file_data else str(file_data)
 
 
 
 
340
  file_ext = os.path.splitext(file_path)[1].lower()
341
 
 
 
 
 
 
 
 
 
 
 
 
 
342
  if file_ext in video_extensions:
 
343
  vs = cv2.VideoCapture(file_path)
344
  fps = vs.get(cv2.CAP_PROP_FPS)
345
+ frame_interval = int(fps * s_time_interval)
346
+ count, video_frame_num = 0, 0
 
 
347
  while True:
348
  gotit, frame = vs.read()
349
  if not gotit:
350
  break
351
  count += 1
352
  if count % frame_interval == 0:
 
353
  base_name = os.path.splitext(os.path.basename(file_path))[0]
354
+ image_path = os.path.join(target_dir_images, f"{base_name}_{video_frame_num:06}.png")
 
 
355
  cv2.imwrite(image_path, frame)
356
  image_paths.append(image_path)
357
  video_frame_num += 1
358
  vs.release()
359
+ print(f"Extracted {video_frame_num} frames from: {os.path.basename(file_path)}")
360
+ elif file_ext in [".heic", ".heif"]:
361
+ try:
362
+ with Image.open(file_path) as img:
363
+ if img.mode not in ("RGB", "L"):
364
+ img = img.convert("RGB")
365
+ base_name = os.path.splitext(os.path.basename(file_path))[0]
366
+ dst_path = os.path.join(target_dir_images, f"{base_name}.jpg")
367
+ img.save(dst_path, "JPEG", quality=95)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
  image_paths.append(dst_path)
369
+ except Exception as e:
370
+ print(f"Error converting HEIC {file_path}: {e}")
371
+ dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
 
 
372
  shutil.copy(file_path, dst_path)
373
  image_paths.append(dst_path)
374
+ else:
375
+ dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
376
+ shutil.copy(file_path, dst_path)
377
+ image_paths.append(dst_path)
378
 
 
379
  image_paths = sorted(image_paths)
380
+ print(f"Files processed to {target_dir_images}; took {time.time() - start_time:.3f}s")
 
 
 
 
381
  return target_dir, image_paths
382
 
383
 
384
  # -------------------------------------------------------------------------
385
+ # 3) Reconstruction
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
  # -------------------------------------------------------------------------
387
  @spaces.GPU(duration=120)
388
+ def gradio_demo(target_dir, frame_filter="All", show_cam=True, filter_black_bg=False, filter_white_bg=False, apply_mask=True, show_mesh=True):
 
 
 
 
 
 
 
 
 
 
 
389
  if not os.path.isdir(target_dir) or target_dir == "None":
390
  return None, "No valid target directory found. Please upload first.", None, None
391
 
 
393
  gc.collect()
394
  torch.cuda.empty_cache()
395
 
 
396
  target_dir_images = os.path.join(target_dir, "images")
397
+ all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
398
+ all_files_labeled = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
399
+ frame_filter_choices = ["All"] + all_files_labeled
 
 
 
 
400
 
401
  print("Running MapAnything model...")
402
  with torch.no_grad():
403
  predictions, processed_data = run_model(target_dir, apply_mask)
404
 
405
+ np.savez(os.path.join(target_dir, "predictions.npz"), **predictions)
 
 
406
 
 
407
  if frame_filter is None:
408
  frame_filter = "All"
409
 
 
410
  glbfile = os.path.join(
411
  target_dir,
412
  f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
413
  )
414
+ glbscene = predictions_to_glb(predictions, filter_by_frames=frame_filter, show_cam=show_cam, mask_black_bg=filter_black_bg, mask_white_bg=filter_white_bg, as_mesh=show_mesh)
 
 
 
 
 
 
 
 
 
415
  glbscene.export(file_obj=glbfile)
 
 
 
 
 
416
 
417
+ rrd_path = predictions_to_rrd(predictions, glbfile, target_dir, frame_filter, show_cam)
418
+
419
  del predictions
420
  gc.collect()
421
  torch.cuda.empty_cache()
422
 
423
+ print(f"Total time: {time.time() - start_time:.2f}s")
424
+ log_msg = f" Reconstruction complete {len(all_files)} frames processed."
 
 
 
 
 
 
 
 
425
 
426
+ depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(processed_data)
427
+ depth_selector, normal_selector, measure_selector = update_view_selectors(processed_data)
 
 
428
 
429
  return (
430
+ rrd_path, log_msg,
 
431
  gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
432
+ processed_data, depth_vis, normal_vis, measure_img, "",
433
+ depth_selector, normal_selector, measure_selector,
 
 
 
 
 
 
434
  )
435
 
436
 
437
  # -------------------------------------------------------------------------
438
+ # 4) Helper / visualization functions
439
  # -------------------------------------------------------------------------
440
  def colorize_depth(depth_map, mask=None):
 
441
  if depth_map is None:
442
  return None
 
 
443
  depth_normalized = depth_map.copy()
444
  valid_mask = depth_normalized > 0
 
 
445
  if mask is not None:
446
  valid_mask = valid_mask & mask
 
447
  if valid_mask.sum() > 0:
448
  valid_depths = depth_normalized[valid_mask]
449
+ p5, p95 = np.percentile(valid_depths, 5), np.percentile(valid_depths, 95)
 
 
450
  depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / (p95 - p5)
 
 
451
  import matplotlib.pyplot as plt
452
+ colored = (plt.cm.turbo_r(depth_normalized)[:, :, :3] * 255).astype(np.uint8)
 
 
 
 
 
453
  colored[~valid_mask] = [255, 255, 255]
 
454
  return colored
455
 
456
 
457
  def colorize_normal(normal_map, mask=None):
 
458
  if normal_map is None:
459
  return None
 
 
460
  normal_vis = normal_map.copy()
 
 
461
  if mask is not None:
462
+ normal_vis[~mask] = [0, 0, 0]
463
+ return ((normal_vis + 1.0) / 2.0 * 255).astype(np.uint8)
464
 
 
 
 
465
 
466
def process_predictions_for_visualization(predictions, views, high_level_config, filter_black_bg=False, filter_white_bg=False):
    """Collect the per-view arrays that the depth, normal and measure tabs read.

    For every view the final mask is copied and optionally tightened by the
    black/white background filters, normals are derived from the 3-D points
    with that mask, and the results are stored under the view index.

    Returns:
        Dict mapping view index to a dict with the keys ``"image"``,
        ``"points3d"``, ``"depth"``, ``"normal"`` and ``"mask"``.
    """
    per_view = {}

    for view_idx, view in enumerate(views):
        image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
        frame = image[0]
        points = predictions["world_points"][view_idx]
        mask = predictions["final_mask"][view_idx].copy()

        if filter_black_bg:
            # Bring colours to the 0-255 range when they look normalised.
            colors = frame * 255 if frame.max() <= 1.0 else frame
            not_black = colors.sum(axis=2) >= 16
            mask = mask & not_black

        if filter_white_bg:
            colors = frame * 255 if frame.max() <= 1.0 else frame
            red, green, blue = colors[:, :, 0], colors[:, :, 1], colors[:, :, 2]
            near_white = (red > 240) & (green > 240) & (blue > 240)
            mask = mask & ~near_white

        normals, _ = points_to_normals(points, mask=mask)

        entry = {}
        entry["image"] = frame
        entry["points3d"] = points
        entry["depth"] = predictions["depth"][view_idx].squeeze()
        entry["normal"] = normals
        entry["mask"] = mask
        per_view[view_idx] = entry

    return per_view
487
 
488
 
489
def _point_in_bounds(point, shape):
    """Return True when the ``(x, y)`` pixel lies inside an array of ``shape``."""
    return 0 <= point[1] < shape[0] and 0 <= point[0] < shape[1]


def measure(processed_data, measure_points, current_view_selector, event: gr.SelectData):
    """Handle a click on the measure image.

    The first click records a point and reports its depth; the second click
    also draws a line, reports the 3-D distance between the two points and
    resets the point list.

    Args:
        processed_data: Per-view data dict; each view provides ``"mask"``,
            ``"depth"`` and ``"points3d"``.
        measure_points: Points clicked so far, as ``(x, y)`` tuples.
        current_view_selector: Dropdown label of the form ``"View N"``.
        event: Gradio select event; ``event.index`` is read as ``(x, y)``.

    Returns:
        ``(image, measure_points, markdown_text)``. On any failure the image is
        ``None``, the points are reset and the text carries the error.
    """
    try:
        if not processed_data:
            return None, [], "No data available"

        try:
            current_view_index = int(current_view_selector.split()[1]) - 1
        except (AttributeError, IndexError, ValueError):
            # Unparseable label: fall back to the first view.
            current_view_index = 0
        current_view_index = max(0, min(current_view_index, len(processed_data) - 1))
        current_view = processed_data[list(processed_data.keys())[current_view_index]]
        if current_view is None:
            return None, [], "No view data available"

        point2d = event.index[0], event.index[1]
        view_mask = current_view["mask"]
        if view_mask is not None and _point_in_bounds(point2d, view_mask.shape):
            if not view_mask[point2d[1], point2d[0]]:
                masked_image, _ = update_measure_view(processed_data, current_view_index)
                return masked_image, measure_points, '<span style="color: red; font-weight: bold;">Cannot measure on masked areas</span>'

        # Build a new list rather than mutating the state object in place.
        measure_points = list(measure_points or []) + [point2d]

        image, _ = update_measure_view(processed_data, current_view_index)
        if image is None:
            return None, [], "No image available"
        image = image.copy()
        if image.dtype != np.uint8:
            image = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8)
        points3d = current_view["points3d"]
        depth = current_view["depth"]

        for p in measure_points:
            if _point_in_bounds(p, image.shape):
                image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2)

        depth_lines = []
        for i, p in enumerate(measure_points):
            if depth is not None and _point_in_bounds(p, depth.shape):
                depth_lines.append(f"- **P{i + 1} depth: {depth[p[1], p[0]]:.2f}m**\n")
            elif points3d is not None and _point_in_bounds(p, points3d.shape):
                depth_lines.append(f"- **P{i + 1} Z-coord: {points3d[p[1], p[0], 2]:.2f}m**\n")
        depth_text = "".join(depth_lines)

        if len(measure_points) == 2:
            point1, point2 = measure_points
            if _point_in_bounds(point1, image.shape) and _point_in_bounds(point2, image.shape):
                image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2)
            distance_text = "- **Distance: Unable to compute**"
            if points3d is not None and all(_point_in_bounds(p, points3d.shape) for p in (point1, point2)):
                try:
                    distance = np.linalg.norm(points3d[point1[1], point1[0]] - points3d[point2[1], point2[0]])
                    distance_text = f"- **Distance: {distance:.2f}m**"
                except Exception as e:
                    distance_text = f"- **Distance error: {e}**"
            # A completed pair resets the point list for the next measurement.
            return [image, [], depth_text + distance_text]
        return [image, measure_points, depth_text]
    except Exception as e:
        # UI boundary: report the failure in the tab instead of raising.
        print(f"Measure error: {e}")
        return None, [], f"Error: {e}"
544
 
545
 
546
  def clear_fields():
 
 
 
547
  return None
548
 
549
 
550
  def update_log():
551
+ return "⏳ Loading and reconstructing…"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
552
 
 
 
 
 
 
553
 
554
def update_visualization(target_dir, frame_filter, show_cam, is_example, filter_black_bg=False, filter_white_bg=False, show_mesh=True):
    """Re-render the 3-D view from saved predictions after an option change.

    Loads ``predictions.npz`` from ``target_dir``, exports the GLB for the
    current option combination unless that file already exists, and builds
    the recording with ``predictions_to_rrd``.

    Returns:
        ``(recording_path, log_message)``. When nothing can be rendered the
        first element is ``gr.update()`` so the viewer is left unchanged.
    """
    if is_example == "True":
        return gr.update(), "No reconstruction available. Please click Reconstruct first."
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return gr.update(), "No reconstruction available. Please upload first."
    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return gr.update(), "No reconstruction found. Please run Reconstruct first."

    # A cleared dropdown yields None; treat it like "All", as gradio_demo does.
    if frame_filter is None:
        frame_filter = "All"

    # NOTE(review): allow_pickle=True unpickles object arrays on load. The file
    # is written by this app, but confirm target_dir cannot be user-controlled.
    with np.load(predictions_path, allow_pickle=True) as loaded:
        predictions = {key: loaded[key] for key in loaded.keys()}

    glbfile = os.path.join(
        target_dir,
        f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
    )
    if not os.path.exists(glbfile):
        glbscene = predictions_to_glb(
            predictions,
            filter_by_frames=frame_filter,
            show_cam=show_cam,
            mask_black_bg=filter_black_bg,
            mask_white_bg=filter_white_bg,
            as_mesh=show_mesh,
        )
        glbscene.export(file_obj=glbfile)

    rrd_path = predictions_to_rrd(predictions, glbfile, target_dir, frame_filter, show_cam)
    return rrd_path, "Visualization updated."
 
 
576
 
577
 
578
def update_all_views_on_filter_change(target_dir, filter_black_bg, filter_white_bg, processed_data, depth_view_selector, normal_view_selector, measure_view_selector):
    """Rebuild the per-view data and tab images after a background filter toggle.

    Args:
        target_dir: Folder holding ``predictions.npz`` and the ``images`` folder.
        filter_black_bg: Forwarded to ``process_predictions_for_visualization``.
        filter_white_bg: Forwarded to ``process_predictions_for_visualization``.
        processed_data: Current state, returned unchanged when rebuilding fails.
        depth_view_selector: ``"View N"`` label of the depth tab.
        normal_view_selector: ``"View N"`` label of the normal tab.
        measure_view_selector: ``"View N"`` label of the measure tab.

    Returns:
        ``(processed_data, depth_image, normal_image, measure_image, [])``; the
        three images are ``None`` when nothing could be rebuilt.
    """
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return processed_data, None, None, None, []
    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return processed_data, None, None, None, []

    def safe_idx(sel):
        """Parse a "View N" label to a zero-based index, defaulting to 0."""
        try:
            return int(sel.split()[1]) - 1
        except (AttributeError, IndexError, ValueError):
            return 0

    try:
        # The context manager closes the archive once the arrays are read.
        with np.load(predictions_path, allow_pickle=True) as loaded:
            predictions = {key: loaded[key] for key in loaded.keys()}
        views = load_images(os.path.join(target_dir, "images"))
        new_processed_data = process_predictions_for_visualization(predictions, views, high_level_config, filter_black_bg, filter_white_bg)
        depth_vis = update_depth_view(new_processed_data, safe_idx(depth_view_selector))
        normal_vis = update_normal_view(new_processed_data, safe_idx(normal_view_selector))
        measure_img, _ = update_measure_view(new_processed_data, safe_idx(measure_view_selector))
        return new_processed_data, depth_vis, normal_vis, measure_img, []
    except Exception as e:
        # UI boundary: keep the previous state instead of surfacing the error.
        print(f"Filter change error: {e}")
        return processed_data, None, None, None, []
601
 
602
 
 
604
  # Example scene functions
605
  # -------------------------------------------------------------------------
606
  def get_scene_info(examples_dir):
 
607
  import glob
 
608
  scenes = []
609
  if not os.path.exists(examples_dir):
610
  return scenes
 
611
  for scene_folder in sorted(os.listdir(examples_dir)):
612
  scene_path = os.path.join(examples_dir, scene_folder)
613
  if os.path.isdir(scene_path):
 
 
614
  image_files = []
615
+ for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]:
616
  image_files.extend(glob.glob(os.path.join(scene_path, ext)))
617
  image_files.extend(glob.glob(os.path.join(scene_path, ext.upper())))
 
618
  if image_files:
 
619
  image_files = sorted(image_files)
620
+ scenes.append({"name": scene_folder, "path": scene_path, "thumbnail": image_files[0], "num_images": len(image_files), "image_files": image_files})
 
 
 
 
 
 
 
 
 
 
 
 
621
  return scenes
622
 
623
 
624
  def load_example_scene(scene_name, examples_dir="examples"):
 
625
  scenes = get_scene_info(examples_dir)
626
+ selected_scene = next((s for s in scenes if s["name"] == scene_name), None)
 
 
 
 
 
 
 
627
  if selected_scene is None:
628
  return None, None, None, "Scene not found"
629
+ target_dir, image_paths = handle_uploads(selected_scene["image_files"], 1.0)
630
+ return None, target_dir, image_paths, f"Loaded '{scene_name}' — {selected_scene['num_images']} images. Click Reconstruct."
631
+
632
 
633
+ # -------------------------------------------------------------------------
634
+ # CSS
635
+ # -------------------------------------------------------------------------
636
+ CUSTOM_CSS = (GRADIO_CSS or "") + """
637
+ /* ── Page shell ── */
638
+ #app-shell {
639
+ max-width: 1400px;
640
+ margin: 0 auto;
641
+ padding: 0 16px 40px;
642
+ }
643
 
644
+ /* ── Header ── */
645
+ #app-header {
646
+ padding: 28px 0 20px;
647
+ border-bottom: 1px solid var(--border-color-primary);
648
+ margin-bottom: 24px;
649
+ }
650
+ #app-header h1 {
651
+ font-size: 2rem !important;
652
+ font-weight: 700 !important;
653
+ margin: 0 0 4px !important;
654
+ line-height: 1.2 !important;
655
+ }
656
+ #app-header p {
657
+ margin: 0 !important;
658
+ opacity: 0.65;
659
+ font-size: 0.95rem !important;
660
+ }
661
 
662
+ /* ── Two-panel layout ── */
663
+ #left-panel { min-width: 320px; max-width: 380px; }
664
+ #right-panel { flex: 1; min-width: 0; }
665
+
666
+ /* ── Section labels ── */
667
+ .section-label {
668
+ font-size: 0.7rem !important;
669
+ font-weight: 600 !important;
670
+ letter-spacing: 0.08em !important;
671
+ text-transform: uppercase !important;
672
+ opacity: 0.5 !important;
673
+ margin-bottom: 6px !important;
674
+ margin-top: 16px !important;
675
+ display: block !important;
676
+ }
677
 
678
+ /* ── Upload zone ── */
679
+ #upload-zone .wrap {
680
+ border-radius: 10px !important;
681
+ min-height: 110px !important;
682
+ }
683
+
684
+ /* ── Gallery ── */
685
+ #preview-gallery { border-radius: 10px; overflow: hidden; }
686
+
687
+ /* ── Action buttons ── */
688
+ #btn-reconstruct {
689
+ width: 100% !important;
690
+ font-size: 0.95rem !important;
691
+ font-weight: 600 !important;
692
+ padding: 12px !important;
693
+ border-radius: 8px !important;
694
+ }
695
+
696
+ /* ── Log strip ── */
697
+ #log-strip {
698
+ font-size: 0.82rem !important;
699
+ padding: 8px 12px !important;
700
+ border-radius: 6px !important;
701
+ border: 1px solid var(--border-color-primary) !important;
702
+ background: var(--background-fill-secondary) !important;
703
+ min-height: 36px !important;
704
+ }
705
+
706
+ /* ── Viewer tabs ── */
707
+ #viewer-tabs .tab-nav button {
708
+ font-size: 0.8rem !important;
709
+ font-weight: 500 !important;
710
+ padding: 6px 14px !important;
711
+ }
712
+ #viewer-tabs > .tabitem { padding: 0 !important; }
713
+
714
+ /* ── Navigation rows inside tabs ── */
715
+ .nav-row { align-items: center !important; gap: 6px !important; margin-bottom: 8px !important; }
716
+ .nav-row button { min-width: 80px !important; }
717
+
718
+ /* ── Options panel ── */
719
+ #options-panel {
720
+ border: 1px solid var(--border-color-primary);
721
+ border-radius: 10px;
722
+ padding: 16px;
723
+ margin-top: 12px;
724
+ }
725
+ #options-panel .gr-markdown h3 {
726
+ font-size: 0.72rem !important;
727
+ font-weight: 600 !important;
728
+ letter-spacing: 0.07em !important;
729
+ text-transform: uppercase !important;
730
+ opacity: 0.5 !important;
731
+ margin: 14px 0 6px !important;
732
+ }
733
+ #options-panel .gr-markdown h3:first-child { margin-top: 0 !important; }
734
+
735
+ /* ── Frame filter ── */
736
+ #frame-filter { margin-top: 12px; }
737
+
738
+ /* ── Examples section ── */
739
+ #examples-section { margin-top: 36px; padding-top: 24px; border-top: 1px solid var(--border-color-primary); }
740
+ #examples-section h2 { font-size: 1.1rem !important; font-weight: 600 !important; margin-bottom: 4px !important; }
741
+ #examples-section .scene-caption {
742
+ font-size: 0.75rem !important;
743
+ text-align: center !important;
744
+ opacity: 0.65 !important;
745
+ margin-top: 4px !important;
746
+ }
747
+ .scene-thumb img { border-radius: 8px; transition: opacity .15s; }
748
+ .scene-thumb img:hover { opacity: .85; }
749
+
750
+ /* ── Measure note ── */
751
+ .measure-note { font-size: 0.78rem !important; opacity: 0.6 !important; margin-top: 6px !important; }
752
+ """
753
 
754
  # -------------------------------------------------------------------------
755
  # 6) Build Gradio UI
756
  # -------------------------------------------------------------------------
757
# Declarative UI layout. The components created here are referenced by name
# in the event wiring that follows.
# NOTE(review): nesting below is reconstructed from the component order and
# element ids — confirm against the rendered page.
with gr.Blocks(css=CUSTOM_CSS) as demo:

    # Hidden state shared between callbacks.
    is_example = gr.Textbox(visible=False, value="None")
    num_images = gr.Textbox(visible=False, value="None")
    processed_data_state = gr.State(value=None)
    measure_points_state = gr.State(value=[])
    target_dir_output = gr.Textbox(visible=False, value="None")

    # Page shell: header, the two main panels and the example scenes.
    with gr.Column(elem_id="app-shell"):
        with gr.Column(elem_id="app-header"):
            gr.Markdown("# **Map-Anything-v1**")
            gr.Markdown("Metric 3D Reconstruction (Point Cloud and Camera Poses)")

        # Main two-column layout.
        with gr.Row(equal_height=False):

            # Left panel: upload, preview, run button and options.
            with gr.Column(elem_id="left-panel", scale=0):

                gr.Markdown('<span class="section-label">Input</span>')
                unified_upload = gr.File(
                    file_count="multiple",
                    label="Upload Video or Images",
                    file_types=["image", "video"],
                    elem_id="upload-zone",
                )

                with gr.Row():
                    s_time_interval = gr.Slider(
                        minimum=0.1, maximum=5.0, value=1.0, step=0.1,
                        label="Video sample interval (sec)",
                        scale=3,
                    )
                    # Starts hidden; show_resample_button toggles its visibility.
                    resample_btn = gr.Button("Resample", visible=False, variant="secondary", scale=1)

                gr.Markdown('<span class="section-label">Preview</span>')
                image_gallery = gr.Gallery(
                    label="",
                    columns=3,
                    height="220px",
                    object_fit="cover",
                    preview=True,
                    elem_id="preview-gallery",
                    show_label=False,
                )

                gr.ClearButton(
                    [unified_upload, image_gallery],
                    value="Clear uploads",
                    variant="secondary",
                    size="sm",
                )

                gr.Markdown('<span class="section-label">Run</span>')
                submit_btn = gr.Button("Reconstruct", variant="primary", elem_id="btn-reconstruct")

                # Options accordion, collapsed by default.
                with gr.Accordion("Options", open=False, elem_id="options-panel"):
                    gr.Markdown("### Point Cloud")
                    show_cam = gr.Checkbox(label="Show cameras", value=True)
                    show_mesh = gr.Checkbox(label="Show mesh", value=True)
                    filter_black_bg = gr.Checkbox(label="Filter black background", value=False)
                    filter_white_bg = gr.Checkbox(label="Filter white background", value=False)
                    gr.Markdown("### Reconstruction (next run)")
                    apply_mask_checkbox = gr.Checkbox(
                        label="Apply ambiguous-depth mask & edges", value=True
                    )

                gr.Markdown('<span class="section-label">Filter by frame</span>')
                frame_filter = gr.Dropdown(
                    choices=["All"], value="All", label="",
                    elem_id="frame-filter", show_label=False,
                )

            # Right panel: status line and the viewer tabs.
            with gr.Column(elem_id="right-panel", scale=1):

                # Status log.
                log_output = gr.Markdown(
                    "Upload a video or images, then click **Reconstruct**.",
                    elem_id="log-strip",
                )

                # Viewer tabs.
                with gr.Tabs(elem_id="viewer-tabs"):

                    # 3-D view.
                    with gr.Tab("3D View"):
                        reconstruction_output = Rerun(
                            label="Rerun 3D Viewer",
                            height=560,
                        )

                    # Depth.
                    with gr.Tab("Depth"):
                        with gr.Row(elem_classes=["nav-row"]):
                            prev_depth_btn = gr.Button("◀ Prev", size="sm", scale=1)
                            depth_view_selector = gr.Dropdown(
                                choices=["View 1"], value="View 1",
                                label="View", scale=3, interactive=True,
                                allow_custom_value=True, show_label=False,
                            )
                            next_depth_btn = gr.Button("Next ▶", size="sm", scale=1)
                        depth_map = gr.Image(
                            type="numpy", label="Depth Map",
                            format="png", interactive=False,
                        )

                    # Normal.
                    with gr.Tab("Normal"):
                        with gr.Row(elem_classes=["nav-row"]):
                            prev_normal_btn = gr.Button("◀ Prev", size="sm", scale=1)
                            normal_view_selector = gr.Dropdown(
                                choices=["View 1"], value="View 1",
                                label="View", scale=3, interactive=True,
                                allow_custom_value=True, show_label=False,
                            )
                            next_normal_btn = gr.Button("Next ▶", size="sm", scale=1)
                        normal_map = gr.Image(
                            type="numpy", label="Normal Map",
                            format="png", interactive=False,
                        )

                    # Measure.
                    with gr.Tab("Measure"):
                        gr.Markdown(MEASURE_INSTRUCTIONS_HTML)
                        with gr.Row(elem_classes=["nav-row"]):
                            prev_measure_btn = gr.Button("◀ Prev", size="sm", scale=1)
                            measure_view_selector = gr.Dropdown(
                                choices=["View 1"], value="View 1",
                                label="View", scale=3, interactive=True,
                                allow_custom_value=True, show_label=False,
                            )
                            next_measure_btn = gr.Button("Next ▶", size="sm", scale=1)
                        measure_image = gr.Image(
                            type="numpy", show_label=False,
                            format="webp", interactive=False, sources=[],
                        )
                        gr.Markdown(
                            "Light-grey areas have no depth — measurements cannot be placed there.",
                            elem_classes=["measure-note"],
                        )
                        measure_text = gr.Markdown("")

        # Example scenes, laid out four thumbnails per row.
        with gr.Column(elem_id="examples-section"):
            gr.Markdown("## Example Scenes")
            gr.Markdown("Click a thumbnail to load the scene, then press **Reconstruct**.")

            scenes = get_scene_info("examples")
            if scenes:
                for i in range(0, len(scenes), 4):
                    with gr.Row():
                        for j in range(4):
                            idx = i + j
                            if idx < len(scenes):
                                scene = scenes[idx]
                                with gr.Column(scale=1, min_width=140, elem_classes=["scene-thumb"]):
                                    scene_img = gr.Image(
                                        value=scene["thumbnail"],
                                        height=130,
                                        interactive=False,
                                        show_label=False,
                                        sources=[],
                                    )
                                    gr.Markdown(
                                        f"**{scene['name']}** \n{scene['num_images']} imgs",
                                        elem_classes=["scene-caption"],
                                    )
                                    # The default argument binds this iteration's
                                    # scene name; a plain closure would see only
                                    # the last value of the loop variable.
                                    scene_img.select(
                                        fn=lambda name=scene["name"]: load_example_scene(name),
                                        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
                                    )
                            else:
                                # Empty column pads the last row out to four.
                                with gr.Column(scale=1, min_width=140):
                                    pass
935
 
936
  # =========================================================================
937
+ # Event wiring
938
  # =========================================================================
939
 
940
+ # Reconstruct button
941
# Reconstruct pipeline: clear the viewer, show the progress message, run the
# reconstruction, then set the hidden is_example field to "False".
submit_btn.click(
    fn=clear_fields, inputs=[], outputs=[reconstruction_output]
).then(
    fn=update_log, inputs=[], outputs=[log_output]
).then(
    fn=gradio_demo,
    inputs=[target_dir_output, frame_filter, show_cam, filter_black_bg, filter_white_bg, apply_mask_checkbox, show_mesh],
    # One entry per element of the tuple gradio_demo returns, in the same order.
    outputs=[
        reconstruction_output, log_output, frame_filter, processed_data_state,
        depth_map, normal_map, measure_image, measure_text,
        depth_view_selector, normal_view_selector, measure_view_selector,
    ],
).then(fn=lambda: "False", inputs=[], outputs=[is_example])
954
+
955
+ # Live visualization option updates
956
# Every option change re-renders the 3-D view with the same, complete set of
# inputs. Passing a shorter list makes update_visualization fall back to its
# defaults for the omitted options, so the render would ignore the current
# filter and mesh settings.
visualization_inputs = [
    target_dir_output, frame_filter, show_cam, is_example,
    filter_black_bg, filter_white_bg, show_mesh,
]
visualization_outputs = [reconstruction_output, log_output]

for trigger in (frame_filter.change, show_cam.change, show_mesh.change):
    trigger(update_visualization, visualization_inputs, visualization_outputs)

# The background filters also change the per-view masks, so the depth, normal
# and measure tabs are rebuilt after the 3-D view.
filter_view_inputs = [
    target_dir_output, filter_black_bg, filter_white_bg, processed_data_state,
    depth_view_selector, normal_view_selector, measure_view_selector,
]
filter_view_outputs = [processed_data_state, depth_map, normal_map, measure_image, measure_points_state]

for filter_trigger in (filter_black_bg.change, filter_white_bg.change):
    filter_trigger(
        update_visualization,
        visualization_inputs,
        visualization_outputs,
    ).then(
        update_all_views_on_filter_change,
        filter_view_inputs,
        filter_view_outputs,
    )
981
 
982
+ # Upload handling
983
  def update_gallery_on_unified_upload(files, interval):
984
  if not files:
985
+ return None, None, None
986
  target_dir, image_paths = handle_uploads(files, interval)
987
+ return target_dir, image_paths, "Upload complete. Click **Reconstruct** to begin."
988
 
989
  def show_resample_button(files):
990
+ if not files:
991
+ return gr.update(visible=False)
992
  video_exts = [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"]
993
+ has_video = any(os.path.splitext(str(f["name"] if isinstance(f, dict) else f))[1].lower() in video_exts for f in files)
 
 
 
 
 
994
  return gr.update(visible=has_video)
995
 
996
  def resample_video_with_new_interval(files, new_interval, current_target_dir):
997
+ if not files:
998
+ return current_target_dir, None, "No files to resample.", gr.update(visible=False)
999
  video_exts = [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"]
1000
+ if not any(os.path.splitext(str(f["name"] if isinstance(f, dict) else f))[1].lower() in video_exts for f in files):
1001
+ return current_target_dir, None, "No videos found.", gr.update(visible=False)
 
 
1002
  if current_target_dir and current_target_dir != "None" and os.path.exists(current_target_dir):
1003
  shutil.rmtree(current_target_dir)
 
1004
  target_dir, image_paths = handle_uploads(files, new_interval)
1005
+ return target_dir, image_paths, f"Resampled at {new_interval}s. Click **Reconstruct**.", gr.update(visible=False)
1006
 
1007
  # A new upload stages the files and then decides whether the Resample button
  # should be visible.
  unified_upload.change(
      fn=update_gallery_on_unified_upload,
      inputs=[unified_upload, s_time_interval],
      outputs=[target_dir_output, image_gallery, log_output],
  ).then(fn=show_resample_button, inputs=[unified_upload], outputs=[resample_btn])

  # Moving the interval slider only re-evaluates the button's visibility; frames
  # are re-extracted when Resample is clicked.
  s_time_interval.change(fn=show_resample_button, inputs=[unified_upload], outputs=[resample_btn])
  resample_btn.click(
      fn=resample_video_with_new_interval,
      inputs=[unified_upload, s_time_interval, target_dir_output],
      outputs=[target_dir_output, image_gallery, log_output, resample_btn],
  )
1019
 
1020
+ # Measure tab
1021
  measure_image.select(
1022
  fn=measure,
1023
  inputs=[processed_data_state, measure_points_state, measure_view_selector],
1024
  outputs=[measure_image, measure_points_state, measure_text],
1025
  )
1026
 
1027
+ # Depth tab navigation
1028
  prev_depth_btn.click(
1029
+ fn=lambda pd, sel: navigate_depth_view(pd, sel, -1),
1030
  inputs=[processed_data_state, depth_view_selector], outputs=[depth_view_selector, depth_map],
1031
  )
1032
  next_depth_btn.click(
1033
+ fn=lambda pd, sel: navigate_depth_view(pd, sel, 1),
1034
  inputs=[processed_data_state, depth_view_selector], outputs=[depth_view_selector, depth_map],
1035
  )
1036
  depth_view_selector.change(
1037
+ fn=lambda pd, sel: update_depth_view(pd, int(sel.split()[1]) - 1) if sel else None,
1038
  inputs=[processed_data_state, depth_view_selector], outputs=[depth_map],
1039
  )
1040
 
1041
+ # Normal tab navigation
1042
  prev_normal_btn.click(
1043
+ fn=lambda pd, sel: navigate_normal_view(pd, sel, -1),
1044
  inputs=[processed_data_state, normal_view_selector], outputs=[normal_view_selector, normal_map],
1045
  )
1046
  next_normal_btn.click(
1047
+ fn=lambda pd, sel: navigate_normal_view(pd, sel, 1),
1048
  inputs=[processed_data_state, normal_view_selector], outputs=[normal_view_selector, normal_map],
1049
  )
1050
  normal_view_selector.change(
1051
+ fn=lambda pd, sel: update_normal_view(pd, int(sel.split()[1]) - 1) if sel else None,
1052
  inputs=[processed_data_state, normal_view_selector], outputs=[normal_map],
1053
  )
1054
 
1055
+ # Measure tab navigation
1056
  prev_measure_btn.click(
1057
+ fn=lambda pd, sel: navigate_measure_view(pd, sel, -1),
1058
+ inputs=[processed_data_state, measure_view_selector],
1059
+ outputs=[measure_view_selector, measure_image, measure_points_state],
1060
  )
1061
  next_measure_btn.click(
1062
+ fn=lambda pd, sel: navigate_measure_view(pd, sel, 1),
1063
+ inputs=[processed_data_state, measure_view_selector],
1064
+ outputs=[measure_view_selector, measure_image, measure_points_state],
1065
  )
1066
  measure_view_selector.change(
1067
+ fn=lambda pd, sel: update_measure_view(pd, int(sel.split()[1]) - 1) if sel else (None, []),
1068
+ inputs=[processed_data_state, measure_view_selector],
1069
+ outputs=[measure_image, measure_points_state],
1070
  )
1071
 
1072
+ demo.queue(max_size=20).launch(css=CUSTOM_CSS, show_error=True, share=True, ssr_mode=False)