prithivMLmods committed on
Commit
60c5d65
·
verified ·
1 Parent(s): e6c43aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +809 -84
app.py CHANGED
@@ -7,7 +7,9 @@ from datetime import datetime
7
 
8
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
9
 
 
10
  import gradio as gr
 
11
  import numpy as np
12
  import spaces
13
  import torch
@@ -23,14 +25,45 @@ from mapanything.utils.hf_utils.hf_helpers import initialize_mapanything_model
23
  from mapanything.utils.hf_utils.viz import predictions_to_glb
24
  from mapanything.utils.image import load_images
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  register_heif_opener()
27
  sys.path.append("mapanything/")
28
 
 
29
  # ============================================================================
30
  # Global Configuration
31
  # ============================================================================
32
 
33
- # MapAnything Configuration
34
  high_level_config = {
35
  "path": "configs/train.yaml",
36
  "hf_model_name": "facebook/map-anything",
@@ -50,11 +83,12 @@ high_level_config = {
50
  "resolution": 518,
51
  }
52
 
53
- # Global model variables
54
  model = None
55
 
 
56
  # ============================================================================
57
- # Core Model Inference
58
  # ============================================================================
59
 
60
  @spaces.GPU(duration=120)
@@ -108,13 +142,13 @@ def run_model(
108
  images_list = []
109
  final_mask_list = []
110
  confidences = []
111
-
112
  for pred in outputs:
113
  depthmap_torch = pred["depth_z"][0].squeeze(-1)
114
  intrinsics_torch = pred["intrinsics"][0]
115
  camera_pose_torch = pred["camera_poses"][0]
116
  conf = pred["conf"][0].squeeze(-1)
117
-
118
  pts3d_computed, valid_mask = depthmap_to_world_frame(
119
  depthmap_torch, intrinsics_torch, camera_pose_torch
120
  )
@@ -139,12 +173,12 @@ def run_model(
139
  predictions["intrinsic"] = np.stack(intrinsic_list, axis=0)
140
  predictions["world_points"] = np.stack(world_points_list, axis=0)
141
  predictions["conf"] = np.stack(confidences, axis=0)
142
-
143
  depth_maps = np.stack(depth_maps_list, axis=0)
144
  if len(depth_maps.shape) == 3:
145
  depth_maps = depth_maps[..., np.newaxis]
146
  predictions["depth"] = depth_maps
147
-
148
  predictions["images"] = np.stack(images_list, axis=0)
149
  predictions["final_mask"] = np.stack(final_mask_list, axis=0)
150
 
@@ -155,11 +189,362 @@ def run_model(
155
 
156
 
157
  # ============================================================================
158
- # Helper Functions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  # ============================================================================
160
 
161
  def handle_uploads(input_images):
162
- """Handle uploaded images"""
163
  start_time = time.time()
164
  gc.collect()
165
  torch.cuda.empty_cache()
@@ -175,7 +560,6 @@ def handle_uploads(input_images):
175
 
176
  image_paths = []
177
 
178
- # Handle images
179
  if input_images is not None:
180
  for file_data in input_images:
181
  if isinstance(file_data, dict) and "name" in file_data:
@@ -211,7 +595,7 @@ def handle_uploads(input_images):
211
 
212
 
213
  def update_gallery_on_upload(input_images):
214
- """Update gallery on upload"""
215
  if not input_images:
216
  return None, None, None, None
217
  target_dir, image_paths = handle_uploads(input_images)
@@ -223,6 +607,10 @@ def update_gallery_on_upload(input_images):
223
  )
224
 
225
 
 
 
 
 
226
  @spaces.GPU(duration=120)
227
  def gradio_demo(
228
  target_dir,
@@ -234,9 +622,15 @@ def gradio_demo(
234
  apply_mask=True,
235
  show_mesh=True,
236
  ):
237
- """Perform reconstruction"""
238
  if not os.path.isdir(target_dir) or target_dir == "None":
239
- return None, "Please upload files first", None
 
 
 
 
 
 
240
 
241
  start_time = time.time()
242
  gc.collect()
@@ -247,18 +641,19 @@ def gradio_demo(
247
  all_files_display = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
248
  frame_filter_choices = ["All"] + all_files_display
249
 
 
250
  print("Running MapAnything model...")
251
  with torch.no_grad():
252
  predictions = run_model(target_dir, apply_mask)
253
 
254
- # Save prediction results
255
  prediction_save_path = os.path.join(target_dir, "predictions.npz")
256
  np.savez(prediction_save_path, **predictions)
257
 
258
  if frame_filter is None:
259
  frame_filter = "All"
260
 
261
- # Generate raw GLB
262
  glbfile = os.path.join(
263
  target_dir,
264
  f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}.glb",
@@ -275,6 +670,13 @@ def gradio_demo(
275
  )
276
  glbscene.export(file_obj=glbfile)
277
 
 
 
 
 
 
 
 
278
  # Cleanup
279
  del predictions
280
  gc.collect()
@@ -285,19 +687,32 @@ def gradio_demo(
285
  log_msg = f"✅ Reconstruction successful ({len(all_files)} frames)"
286
 
287
  return (
288
- glbfile,
289
- log_msg,
290
- gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
 
 
 
 
 
 
 
 
 
291
  )
292
 
293
 
 
 
 
 
294
  def clear_fields():
295
- """Clear 3D viewer"""
296
  return None
297
 
298
 
299
  def update_log():
300
- """Display log message"""
301
  return "Loading and reconstructing..."
302
 
303
 
@@ -311,7 +726,10 @@ def update_visualization(
311
  filter_white_bg=False,
312
  show_mesh=True,
313
  ):
314
- """Update visualization"""
 
 
 
315
  if is_example == "True":
316
  return gr.update(), "No reconstruction available. Please click the reconstruct button first."
317
 
@@ -344,12 +762,72 @@ def update_visualization(
344
  return glbfile, "Visualization updated."
345
 
346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  # ============================================================================
348
- # Example Scenes
349
  # ============================================================================
350
 
351
  def get_scene_info(examples_dir):
352
- """Get information about scenes in the examples directory"""
353
  import glob
354
 
355
  scenes = []
@@ -384,7 +862,7 @@ def get_scene_info(examples_dir):
384
 
385
 
386
  def load_example_scene(scene_name, examples_dir="examples"):
387
- """Load a scene from examples directory"""
388
  scenes = get_scene_info(examples_dir)
389
 
390
  selected_scene = None
@@ -407,12 +885,11 @@ def load_example_scene(scene_name, examples_dir="examples"):
407
 
408
 
409
  # ============================================================================
410
- # Gradio UI
411
  # ============================================================================
412
 
413
  theme = get_gradio_theme()
414
 
415
- # Custom CSS to prevent UI jitter
416
  APP_CSS = GRADIO_CSS + """
417
  /* Prevent components from expanding the layout */
418
  .gradio-container {
@@ -440,57 +917,150 @@ APP_CSS = GRADIO_CSS + """
440
  .tab-content {
441
  min-height: 550px !important;
442
  }
 
 
 
 
 
 
 
443
  """
444
 
445
  with gr.Blocks() as demo:
 
446
  is_example = gr.Textbox(label="is_example", visible=False, value="None")
447
-
448
  target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
 
 
449
 
450
  with gr.Row(equal_height=False):
451
- # Left Side: Input Area
452
  with gr.Column(scale=1, min_width=300):
453
  gr.Markdown("### 📤 Input")
454
-
455
  input_images = gr.File(
456
- file_count="multiple",
457
- label="Upload multiple images (3-10 recommended)",
458
  interactive=True,
459
- height=200
460
  )
461
-
462
  image_gallery = gr.Gallery(
463
- label="Image Preview", columns=3, height=350,
464
- object_fit="contain", preview=True
 
 
 
465
  )
466
-
467
  with gr.Row():
468
- submit_btn = gr.Button("🚀 Start Reconstruction", variant="primary", scale=2)
 
 
469
  clear_btn = gr.ClearButton(
470
  [input_images, target_dir_output, image_gallery],
471
- value="🗑️ Clear", scale=1
 
472
  )
473
 
474
- # Right Side: Output Area
475
  with gr.Column(scale=2, min_width=600):
476
  gr.Markdown("### 🎯 Output")
477
 
478
  with gr.Tabs():
 
479
  with gr.Tab("🏗️ Raw 3D"):
480
  reconstruction_output = gr.Model3D(
481
- height=550, zoom_speed=0.5, pan_speed=0.5,
482
- clear_color=[0.0, 0.0, 0.0, 0.0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483
  )
484
-
 
 
 
 
485
  log_output = gr.Textbox(
486
  value="📌 Please upload images, then click 'Start Reconstruction'",
487
  label="Status Information",
488
  interactive=False,
489
  lines=1,
490
- max_lines=1
491
  )
492
 
493
- # Advanced Options (Collapsible)
494
  with gr.Accordion("⚙️ Advanced Options", open=False):
495
  with gr.Row(equal_height=False):
496
  with gr.Column(scale=1, min_width=300):
@@ -499,21 +1069,28 @@ with gr.Blocks() as demo:
499
  choices=["All"], value="All", label="Display Frame"
500
  )
501
  conf_thres = gr.Slider(
502
- minimum=0, maximum=100, value=0, step=0.1,
503
- label="Confidence Threshold (Percentile)"
 
 
 
504
  )
505
  show_cam = gr.Checkbox(label="Show Camera", value=True)
506
  show_mesh = gr.Checkbox(label="Show Mesh", value=True)
507
- filter_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
508
- filter_white_bg = gr.Checkbox(label="Filter White Background", value=False)
509
-
 
 
 
 
510
  with gr.Column(scale=1, min_width=300):
511
  gr.Markdown("#### Reconstruction Parameters")
512
  apply_mask_checkbox = gr.Checkbox(
513
  label="Apply Depth Mask", value=True
514
  )
515
 
516
- # Example Scenes (Collapsible)
517
  with gr.Accordion("🖼️ Example Scenes", open=False):
518
  scenes = get_scene_info("examples")
519
  if scenes:
@@ -525,68 +1102,216 @@ with gr.Blocks() as demo:
525
  scene = scenes[scene_idx]
526
  with gr.Column(scale=1, min_width=150):
527
  scene_img = gr.Image(
528
- value=scene["thumbnail"],
529
  height=150,
530
- interactive=False,
531
- show_label=False,
532
  sources=[],
533
- container=False
534
  )
535
  gr.Markdown(
536
  f"**{scene['name']}** ({scene['num_images']} images)",
537
- elem_classes=["text-center"]
538
  )
539
  scene_img.select(
540
- fn=lambda name=scene["name"]: load_example_scene(name),
 
 
541
  outputs=[
542
  reconstruction_output,
543
- target_dir_output, image_gallery, log_output
544
- ]
 
 
545
  )
546
 
547
- # === Event Binding ===
548
-
549
- # Auto update on file upload
 
 
550
  input_images.change(
551
  fn=update_gallery_on_upload,
552
  inputs=[input_images],
553
- outputs=[reconstruction_output, target_dir_output, image_gallery, log_output]
 
 
 
 
 
 
 
 
554
  )
555
-
556
- # Reconstruction button
557
  submit_btn.click(
558
- fn=clear_fields,
559
- outputs=[reconstruction_output]
560
  ).then(
561
  fn=update_log,
562
- outputs=[log_output]
563
  ).then(
564
  fn=gradio_demo,
565
  inputs=[
566
- target_dir_output, frame_filter, show_cam,
567
- filter_black_bg, filter_white_bg, conf_thres,
568
- apply_mask_checkbox, show_mesh
 
 
 
 
 
569
  ],
570
  outputs=[
571
- reconstruction_output, log_output, frame_filter
572
- ]
 
 
 
 
 
 
 
 
 
 
 
573
  ).then(
574
  fn=lambda: "False",
575
- outputs=[is_example]
576
  )
577
-
578
- # Clear button
579
- clear_btn.add([reconstruction_output, log_output])
580
-
581
- # Visualization parameters real-time update
582
- for component in [frame_filter, show_cam, conf_thres, show_mesh, filter_black_bg, filter_white_bg]:
583
  component.change(
584
- fn=update_visualization,
585
  inputs=[
586
- target_dir_output, frame_filter, show_cam, is_example,
587
- conf_thres, filter_black_bg, filter_white_bg, show_mesh
 
 
 
 
 
 
 
 
 
 
 
588
  ],
589
- outputs=[reconstruction_output, log_output]
590
  )
591
 
592
- demo.queue(max_size=20).launch(theme=theme, css=APP_CSS, show_error=True, share=True, ssr_mode=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
9
 
10
+ import cv2
11
  import gradio as gr
12
+ import matplotlib.pyplot as plt
13
  import numpy as np
14
  import spaces
15
  import torch
 
25
  from mapanything.utils.hf_utils.viz import predictions_to_glb
26
  from mapanything.utils.image import load_images
27
 
28
+ # Optional imports with fallbacks
29
# Prefer the library implementation; define a NumPy fallback when it cannot be imported.
try:
    from mapanything.utils.geometry import points_to_normals
except ImportError:
    def points_to_normals(points3d, mask=None):
        """Estimate per-pixel surface normals from a 3-D point map (fallback).

        The normal at each pixel is the normalised cross product of the
        forward differences along the second axis and along the first axis.
        The last column / last row have no forward neighbour, so their
        difference stays zero and the resulting normal is flagged invalid.

        Args:
            points3d: NumPy array whose shape unpacks as (H, W, C); the
                cross product is taken over the last axis.
            mask: optional boolean array AND-ed into the validity result.

        Returns:
            ``(normals, valid)`` where ``normals`` has the shape of
            ``points3d`` and ``valid`` is an (H, W) boolean array.
        """
        _rows, _cols, _ = points3d.shape  # rejects inputs that are not 3-D

        step_x = np.zeros_like(points3d)
        step_y = np.zeros_like(points3d)
        step_x[:, :-1] = np.diff(points3d, axis=1)
        step_y[:-1, :] = np.diff(points3d, axis=0)

        raw = np.cross(step_x, step_y)
        # Clamp the length so degenerate (zero) cross products do not divide by zero.
        length = np.maximum(np.linalg.norm(raw, axis=-1, keepdims=True), 1e-8)
        unit = raw / length

        is_valid = length[..., 0] > 1e-6
        if mask is not None:
            is_valid = is_valid & mask
        return unit, is_valid
47
+
48
# Use the instructions text shipped with the library when it is importable;
# otherwise fall back to the local copy below.
# NOTE(review): despite the "_HTML" suffix, this fallback is written with
# Markdown markup (bold markers, numbered list) — confirm how it is rendered.
try:
    from mapanything.utils.hf_utils.css_and_html import MEASURE_INSTRUCTIONS_HTML
except ImportError:
    MEASURE_INSTRUCTIONS_HTML = """
**📏 Measurement Tool:**
1. Click on the **first point** in the image to mark it
2. Click on the **second point** to measure the 3D distance between them
3. The depth of each point and the computed 3D distance will be displayed below
4. After each measurement, click two new points for a new measurement
"""
58
+
59
  register_heif_opener()
60
  sys.path.append("mapanything/")
61
 
62
+
63
  # ============================================================================
64
  # Global Configuration
65
  # ============================================================================
66
 
 
67
  high_level_config = {
68
  "path": "configs/train.yaml",
69
  "hf_model_name": "facebook/map-anything",
 
83
  "resolution": 518,
84
  }
85
 
86
+ # Global model variable
87
  model = None
88
 
89
+
90
  # ============================================================================
91
+ # Core Model Inference (KEPT AS-IS)
92
  # ============================================================================
93
 
94
  @spaces.GPU(duration=120)
 
142
  images_list = []
143
  final_mask_list = []
144
  confidences = []
145
+
146
  for pred in outputs:
147
  depthmap_torch = pred["depth_z"][0].squeeze(-1)
148
  intrinsics_torch = pred["intrinsics"][0]
149
  camera_pose_torch = pred["camera_poses"][0]
150
  conf = pred["conf"][0].squeeze(-1)
151
+
152
  pts3d_computed, valid_mask = depthmap_to_world_frame(
153
  depthmap_torch, intrinsics_torch, camera_pose_torch
154
  )
 
173
  predictions["intrinsic"] = np.stack(intrinsic_list, axis=0)
174
  predictions["world_points"] = np.stack(world_points_list, axis=0)
175
  predictions["conf"] = np.stack(confidences, axis=0)
176
+
177
  depth_maps = np.stack(depth_maps_list, axis=0)
178
  if len(depth_maps.shape) == 3:
179
  depth_maps = depth_maps[..., np.newaxis]
180
  predictions["depth"] = depth_maps
181
+
182
  predictions["images"] = np.stack(images_list, axis=0)
183
  predictions["final_mask"] = np.stack(final_mask_list, axis=0)
184
 
 
189
 
190
 
191
  # ============================================================================
192
+ # Visualization Processing Functions (NEW - for Depth, Normal, Measure tabs)
193
+ # ============================================================================
194
+
195
def process_predictions_for_visualization(
    predictions, filter_black_bg=False, filter_white_bg=False
):
    """Split stacked predictions into one record per view for the per-view tabs.

    Args:
        predictions: mapping with the stacked arrays ``"images"``,
            ``"world_points"``, ``"depth"`` and ``"final_mask"``, indexed by
            view along the first axis.
        filter_black_bg: if True, also mask out pixels whose channel sum
            (on a 0-255 scale) is below 16.
        filter_white_bg: if True, also mask out pixels whose three channels
            all exceed 240 (on a 0-255 scale).

    Returns:
        dict keyed by view index; each value holds ``"image"``,
        ``"points3d"``, ``"depth"`` (squeezed), ``"normal"`` and ``"mask"``.
    """

    def _as_0_255(rgb):
        # Images whose maximum is <= 1.0 are treated as normalised and rescaled.
        return rgb * 255 if rgb.max() <= 1.0 else rgb.copy()

    per_view = {}
    for idx in range(predictions["images"].shape[0]):
        rgb = predictions["images"][idx]
        points = predictions["world_points"][idx]
        depth_2d = predictions["depth"][idx].squeeze()
        keep = predictions["final_mask"][idx].copy()

        if filter_black_bg:
            keep = keep & (_as_0_255(rgb).sum(axis=2) >= 16)

        if filter_white_bg:
            scaled = _as_0_255(rgb)
            near_white = (
                (scaled[:, :, 0] > 240)
                & (scaled[:, :, 1] > 240)
                & (scaled[:, :, 2] > 240)
            )
            keep = keep & ~near_white

        # Only the normals are kept; the validity mask returned alongside is unused.
        normals, _ = points_to_normals(points, mask=keep)

        per_view[idx] = {
            "image": rgb,
            "points3d": points,
            "depth": depth_2d,
            "normal": normals,
            "mask": keep,
        }

    return per_view
236
+
237
+
238
def colorize_depth(depth_map, mask=None):
    """Render a depth map as an RGB image using the ``turbo_r`` colormap.

    Valid pixels (depth > 0 and, if given, ``mask`` true) are normalised to
    [0, 1] between their 5th and 95th percentile and colour-mapped; every
    other pixel is painted white.

    Args:
        depth_map: 2-D array-like of depths, or None.
        mask: optional array of the same shape; truthy entries mark pixels
            that may be coloured.

    Returns:
        uint8 array of shape (H, W, 3), or None when ``depth_map`` is None.
    """
    if depth_map is None:
        return None

    depth_arr = np.asarray(depth_map)
    # Normalise in a floating dtype: writing values in [0, 1] back into an
    # integer copy of the input would truncate almost all of them to 0.
    depth_normalized = depth_arr.astype(np.result_type(depth_arr.dtype, np.float32))

    valid_mask = depth_normalized > 0
    if mask is not None:
        # Force a boolean mask so the `~valid_mask` indexing below stays a
        # boolean selection even if the caller passes a 0/1 integer mask.
        valid_mask = valid_mask & np.asarray(mask, dtype=bool)

    if valid_mask.any():
        valid_depths = depth_normalized[valid_mask]
        p5, p95 = np.percentile(valid_depths, [5, 95])
        if p95 > p5:
            depth_normalized[valid_mask] = (valid_depths - p5) / (p95 - p5)
        else:
            # Flat depth range: use the middle of the colormap.
            depth_normalized[valid_mask] = 0.5

    colored = plt.cm.turbo_r(np.clip(depth_normalized, 0, 1))
    colored = (colored[:, :, :3] * 255).astype(np.uint8)  # drop the alpha channel

    # Set invalid pixels to white
    colored[~valid_mask] = [255, 255, 255]

    return colored
266
+
267
+
268
def colorize_normal(normal_map, mask=None):
    """Render a normal map as a uint8 colour image.

    Components are remapped from [-1, 1] to [0, 255]. Pixels where ``mask``
    is false are zeroed first, so they come out as mid-grey. The input array
    is not modified.

    Args:
        normal_map: array of normal vectors, or None.
        mask: optional boolean array selecting the pixels to keep.

    Returns:
        uint8 array with the shape of ``normal_map``, or None when
        ``normal_map`` is None.
    """
    if normal_map is None:
        return None

    shaded = normal_map.copy()
    if mask is not None:
        shaded[~mask] = [0, 0, 0]

    unit_range = np.clip((shaded + 1.0) / 2.0, 0, 1)
    return (unit_range * 255).astype(np.uint8)
284
+
285
+
286
def update_view_selectors(processed_data):
    """Build the three view-selector dropdowns for the available views.

    Labels are "View 1" .. "View N"; with no data a single "View 1" entry is
    offered. The first label is pre-selected in every dropdown.

    Returns:
        Tuple of three ``gr.Dropdown`` objects, in the order
        depth selector, normal selector, measure selector.
    """
    has_views = processed_data is not None and len(processed_data) > 0
    if has_views:
        labels = [f"View {number}" for number in range(1, len(processed_data) + 1)]
    else:
        labels = ["View 1"]

    depth_dd = gr.Dropdown(choices=labels, value=labels[0])
    normal_dd = gr.Dropdown(choices=labels, value=labels[0])
    measure_dd = gr.Dropdown(choices=labels, value=labels[0])
    return depth_dd, normal_dd, measure_dd
299
+
300
+
301
def get_view_data_by_index(processed_data, view_index):
    """Return the record at position ``view_index`` in ``processed_data``.

    Positions follow the mapping's key order. An out-of-range (including
    negative) position falls back to the first record. Returns None when
    ``processed_data`` is None or empty.
    """
    if processed_data is None or len(processed_data) == 0:
        return None

    ordered_keys = list(processed_data.keys())
    position = view_index if 0 <= view_index < len(ordered_keys) else 0
    return processed_data[ordered_keys[position]]
311
+
312
+
313
def update_depth_view(processed_data, view_index):
    """Return the colourised depth image for the requested view.

    Yields None when no view record is available or the record's depth is None.
    """
    selected = get_view_data_by_index(processed_data, view_index)
    if selected is not None and selected["depth"] is not None:
        return colorize_depth(selected["depth"], mask=selected.get("mask"))
    return None
319
+
320
+
321
def update_normal_view(processed_data, view_index):
    """Return the colourised normal image for the requested view.

    Yields None when no view record is available or the record's normal map is None.
    """
    selected = get_view_data_by_index(processed_data, view_index)
    if selected is not None and selected["normal"] is not None:
        return colorize_normal(selected["normal"], mask=selected.get("mask"))
    return None
327
+
328
+
329
def update_measure_view(processed_data, view_index):
    """Prepare the measure-tab image for a view and clear the clicked points.

    The view's image is copied and converted to uint8 (rescaled by 255 when
    its maximum is <= 1.0). Pixels outside the view's mask are blended 50/50
    with the tint colour [255, 220, 220].

    Returns:
        ``(image, [])`` — the prepared image and an empty point list — or
        ``(None, [])`` when no view record is available.
    """
    entry = get_view_data_by_index(processed_data, view_index)
    if entry is None:
        return None, []

    frame = entry["image"].copy()

    # Bring the image to uint8 before blending.
    if frame.dtype != np.uint8:
        if frame.max() <= 1.0:
            frame = frame * 255
        frame = frame.astype(np.uint8)

    valid = entry["mask"]
    if valid is not None:
        masked_out = ~valid
        if masked_out.any():
            tint = np.array([255, 220, 220], dtype=np.uint8)
            opacity = 0.5
            # Blend all channels at once; the tint broadcasts over the last axis.
            blended = (1 - opacity) * frame + opacity * tint
            frame = np.where(masked_out[..., None], blended, frame).astype(np.uint8)

    return frame, []
358
+
359
+
360
def navigate_depth_view(processed_data, current_selector_value, direction):
    """Step the depth tab to the previous/next view, wrapping around.

    Args:
        processed_data: per-view records.
        current_selector_value: current selector label such as "View 2".
        direction: offset to apply (-1 for previous, +1 for next).

    Returns:
        ``(new_label, depth_image)``; ``("View 1", None)`` when there is no data.
    """
    if processed_data is None or len(processed_data) == 0:
        return "View 1", None

    try:
        # Labels are 1-based: "View <n>".
        position = int(current_selector_value.split()[1]) - 1
    except Exception:
        position = 0

    target = (position + direction) % len(processed_data)
    label = f"View {target + 1}"
    return label, update_depth_view(processed_data, target)
373
+
374
+
375
def navigate_normal_view(processed_data, current_selector_value, direction):
    """Step the normal tab to the previous/next view, wrapping around.

    Args:
        processed_data: per-view records.
        current_selector_value: current selector label such as "View 2".
        direction: offset to apply (-1 for previous, +1 for next).

    Returns:
        ``(new_label, normal_image)``; ``("View 1", None)`` when there is no data.
    """
    if processed_data is None or len(processed_data) == 0:
        return "View 1", None

    try:
        # Labels are 1-based: "View <n>".
        position = int(current_selector_value.split()[1]) - 1
    except Exception:
        position = 0

    target = (position + direction) % len(processed_data)
    label = f"View {target + 1}"
    return label, update_normal_view(processed_data, target)
388
+
389
+
390
def navigate_measure_view(processed_data, current_selector_value, direction):
    """Step the measure tab to the previous/next view, wrapping around.

    Args:
        processed_data: per-view records.
        current_selector_value: current selector label such as "View 2".
        direction: offset to apply (-1 for previous, +1 for next).

    Returns:
        ``(new_label, measure_image, measure_points)``;
        ``("View 1", None, [])`` when there is no data.
    """
    if processed_data is None or len(processed_data) == 0:
        return "View 1", None, []

    try:
        # Labels are 1-based: "View <n>".
        position = int(current_selector_value.split()[1]) - 1
    except Exception:
        position = 0

    target = (position + direction) % len(processed_data)
    label = f"View {target + 1}"
    image, points = update_measure_view(processed_data, target)
    return label, image, points
403
+
404
+
405
def populate_visualization_tabs(processed_data):
    """Produce the initial depth, normal and measure images from the first view.

    Returns:
        ``(depth_image, normal_image, measure_image, [])`` where the trailing
        empty list is the cleared measure-point list; all images are None
        when there is no data.
    """
    if processed_data is None or len(processed_data) == 0:
        return None, None, None, []

    first = 0
    depth_img = update_depth_view(processed_data, first)
    normal_img = update_normal_view(processed_data, first)
    measure_img, _unused_points = update_measure_view(processed_data, first)
    return depth_img, normal_img, measure_img, []
413
+
414
+
415
def _view_index_from_label(label, num_views):
    """Map a 1-based "View <n>" selector label to a 0-based index; fall back to 0."""
    try:
        index = int(label.split()[1]) - 1
    except (AttributeError, IndexError, TypeError, ValueError):
        index = 0
    return index if 0 <= index < num_views else 0


def _pixel_in_bounds(point, array):
    """Return True if ``array`` is not None and the (x, y) pixel indexes its first two axes."""
    return (
        array is not None
        and 0 <= point[1] < array.shape[0]
        and 0 <= point[0] < array.shape[1]
    )


def _draw_point_markers(image, points):
    """Draw a circle (colour (255, 0, 0)) at every in-bounds (x, y) point; return the image."""
    for p in points:
        if _pixel_in_bounds(p, image):
            image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2)
    return image


def measure(processed_data, measure_points, current_view_selector, event: gr.SelectData):
    """Handle a click on the measure image: two clicks yield a 3-D distance.

    The clicked pixel is read from ``event.index`` as (x, y). Clicks on
    pixels excluded by the view's mask are rejected. After the second
    accepted click the Euclidean distance between the two 3-D points is
    reported and the point list is reset.

    Args:
        processed_data: per-view records (``"image"``, ``"points3d"``,
            ``"depth"``, ``"mask"``).
        measure_points: points clicked so far; treated as read-only — a new
            list is returned instead of mutating this one. None counts as empty.
        current_view_selector: selector label such as "View 2".
        event: selection event carrying the clicked pixel in ``event.index``.

    Returns:
        ``(image, measure_points, text)`` — the annotated image, the updated
        point list and the message to display. On any unexpected error the
        image is None, the list is empty and the text describes the error.
    """
    try:
        if processed_data is None or len(processed_data) == 0:
            return None, [], "No data available"

        view_index = _view_index_from_label(current_view_selector, len(processed_data))
        view_keys = list(processed_data.keys())
        current_view = processed_data[view_keys[view_index]]
        if current_view is None:
            return None, [], "No view data available"

        # Copy the incoming points so the caller's list is never mutated in
        # place; tuples of ints are what the cv2 drawing calls expect.
        points = [(int(p[0]), int(p[1])) for p in (measure_points or [])]
        point2d = (int(event.index[0]), int(event.index[1]))

        # Reject clicks on masked (invalid) areas
        mask = current_view["mask"]
        if _pixel_in_bounds(point2d, mask) and not mask[point2d[1], point2d[0]]:
            masked_image, _ = update_measure_view(processed_data, view_index)
            if masked_image is not None:
                # Keep the markers of points that are still pending so the
                # image stays in sync with the stored point list.
                masked_image = _draw_point_markers(masked_image, points)
            return (
                masked_image,
                points,
                '<span style="color: red; font-weight: bold;">Cannot measure on masked areas (shown in light pink)</span>',
            )

        points.append(point2d)

        # Base image with mask overlay; update_measure_view already returns a
        # uint8 copy, so no further dtype conversion is needed here.
        image, _ = update_measure_view(processed_data, view_index)
        if image is None:
            return None, [], "No image available"
        image = _draw_point_markers(image.copy(), points)

        depth = current_view["depth"]
        points3d = current_view["points3d"]

        # Build depth info text (falls back to the Z coordinate of the 3-D point)
        depth_lines = []
        for number, p in enumerate(points, start=1):
            if _pixel_in_bounds(p, depth):
                d = depth[p[1], p[0]]
                depth_lines.append(f"- **P{number} depth: {d:.2f}m.**\n")
            elif _pixel_in_bounds(p, points3d):
                z = points3d[p[1], p[0], 2]
                depth_lines.append(f"- **P{number} Z-coord: {z:.2f}m.**\n")
        depth_text = "".join(depth_lines)

        if len(points) != 2:
            return image, points, depth_text

        point1, point2 = points

        # Draw line between the two points
        if _pixel_in_bounds(point1, image) and _pixel_in_bounds(point2, image):
            image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2)

        # Compute 3D Euclidean distance
        distance_text = "- **Distance: Unable to compute**"
        if _pixel_in_bounds(point1, points3d) and _pixel_in_bounds(point2, points3d):
            try:
                p1_3d = points3d[point1[1], point1[0]]
                p2_3d = points3d[point2[1], point2[0]]
                distance = np.linalg.norm(p1_3d - p2_3d)
                distance_text = f"- **Distance: {distance:.2f}m**"
            except (TypeError, ValueError) as e:
                distance_text = f"- **Distance computation error: {e}**"

        # Reset points after measurement
        return image, [], depth_text + distance_text

    except Exception as e:
        # UI callback boundary: report the failure in the text output instead of raising.
        print(f"Measure error: {e}")
        return None, [], f"Measure error: {e}"
532
+
533
+
534
def reset_measure(processed_data):
    """Clear the measurement state.

    Returns:
        ``(image, [], "")`` — the first view's stored image, an empty point
        list and an empty message; the image is None when there is no data.
    """
    if processed_data is None or len(processed_data) == 0:
        return None, [], ""

    # NOTE(review): this returns the first view's raw image, whereas
    # update_measure_view returns a uint8 image with the mask overlay for the
    # selected view — confirm whether the reset should match that.
    first_entry = next(iter(processed_data.values()))
    return first_entry["image"], [], ""
540
+
541
+
542
+ # ============================================================================
543
+ # Helper Functions (KEPT AS-IS)
544
  # ============================================================================
545
 
546
  def handle_uploads(input_images):
547
+ """Handle uploaded images."""
548
  start_time = time.time()
549
  gc.collect()
550
  torch.cuda.empty_cache()
 
560
 
561
  image_paths = []
562
 
 
563
  if input_images is not None:
564
  for file_data in input_images:
565
  if isinstance(file_data, dict) and "name" in file_data:
 
595
 
596
 
597
  def update_gallery_on_upload(input_images):
598
+ """Update gallery on upload."""
599
  if not input_images:
600
  return None, None, None, None
601
  target_dir, image_paths = handle_uploads(input_images)
 
607
  )
608
 
609
 
610
+ # ============================================================================
611
+ # Main Reconstruction Function (Extended for new tabs)
612
+ # ============================================================================
613
+
614
  @spaces.GPU(duration=120)
615
  def gradio_demo(
616
  target_dir,
 
622
  apply_mask=True,
623
  show_mesh=True,
624
  ):
625
+ """Perform reconstruction and populate all tabs."""
626
  if not os.path.isdir(target_dir) or target_dir == "None":
627
+ return (
628
+ None, None,
629
+ "Please upload files first",
630
+ None, None,
631
+ None, None, None, "",
632
+ None, None, None,
633
+ )
634
 
635
  start_time = time.time()
636
  gc.collect()
 
641
  all_files_display = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
642
  frame_filter_choices = ["All"] + all_files_display
643
 
644
+ # ---- Run model (KEPT AS-IS) ----
645
  print("Running MapAnything model...")
646
  with torch.no_grad():
647
  predictions = run_model(target_dir, apply_mask)
648
 
649
+ # ---- Save predictions (KEPT AS-IS) ----
650
  prediction_save_path = os.path.join(target_dir, "predictions.npz")
651
  np.savez(prediction_save_path, **predictions)
652
 
653
  if frame_filter is None:
654
  frame_filter = "All"
655
 
656
+ # ---- Generate GLB (KEPT AS-IS) ----
657
  glbfile = os.path.join(
658
  target_dir,
659
  f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}.glb",
 
670
  )
671
  glbscene.export(file_obj=glbfile)
672
 
673
+ # ---- NEW: Process data for Depth / Normal / Measure tabs ----
674
+ processed_data = process_predictions_for_visualization(
675
+ predictions, filter_black_bg, filter_white_bg
676
+ )
677
+ depth_vis, normal_vis, measure_img, _ = populate_visualization_tabs(processed_data)
678
+ depth_selector, normal_selector, measure_selector = update_view_selectors(processed_data)
679
+
680
  # Cleanup
681
  del predictions
682
  gc.collect()
 
687
  log_msg = f"✅ Reconstruction successful ({len(all_files)} frames)"
688
 
689
  return (
690
+ glbfile, # reconstruction_output (Raw 3D)
691
+ glbfile, # reconstruction_output_3d (3D View)
692
+ log_msg, # log_output
693
+ gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True), # frame_filter
694
+ processed_data, # processed_data_state
695
+ depth_vis, # depth_map
696
+ normal_vis, # normal_map
697
+ measure_img, # measure_image
698
+ "", # measure_text
699
+ depth_selector, # depth_view_selector
700
+ normal_selector, # normal_view_selector
701
+ measure_selector, # measure_view_selector
702
  )
703
 
704
 
705
+ # ============================================================================
706
+ # UI Helper Functions
707
+ # ============================================================================
708
+
709
def clear_fields():
    """Produce the empty value used to blank the 3D viewer.

    Returns:
        None: always ``None``; takes no arguments.
    """
    empty_viewer_value = None
    return empty_viewer_value
712
 
713
 
714
def update_log():
    """Return the status line shown while a reconstruction is in progress.

    Returns:
        str: the fixed "in progress" message.
    """
    in_progress_message = "Loading and reconstructing..."
    return in_progress_message
717
 
718
 
 
726
  filter_white_bg=False,
727
  show_mesh=True,
728
  ):
729
+ """
730
+ Reload saved predictions from npz, create (or reuse) the GLB for new parameters.
731
+ KEPT AS-IS from original code.
732
+ """
733
  if is_example == "True":
734
  return gr.update(), "No reconstruction available. Please click the reconstruct button first."
735
 
 
762
  return glbfile, "Visualization updated."
763
 
764
 
765
def update_all_3d_views(
    target_dir,
    frame_filter,
    show_cam,
    is_example,
    conf_thres,
    filter_black_bg,
    filter_white_bg,
    show_mesh,
):
    """Refresh the "Raw 3D" and "3D View" viewers with one shared result.

    All arguments are forwarded, unchanged and in order, to
    ``update_visualization``; this wrapper only duplicates its first return
    value so that two viewer outputs can be fed from a single call.

    Returns:
        tuple: ``(viewer_value, viewer_value, status_message)`` where the
        first two items are the same object returned by
        ``update_visualization`` and the last is its log message.
    """
    viewer_value, status_message = update_visualization(
        target_dir,
        frame_filter,
        show_cam,
        is_example,
        conf_thres,
        filter_black_bg,
        filter_white_bg,
        show_mesh,
    )
    # Both viewers display the very same scene, so reuse the single result.
    return viewer_value, viewer_value, status_message
775
+
776
+
777
def _selector_to_index(selector_value):
    """Convert a ``"View N"`` dropdown label into a zero-based view index.

    Falls back to ``0`` when the value is empty or does not have the expected
    ``"<word> <integer>"`` shape (for example a free-typed custom value).
    """
    if not selector_value:
        return 0
    try:
        return int(selector_value.split()[1]) - 1
    except (AttributeError, IndexError, TypeError, ValueError):
        return 0


def update_all_views_on_filter_change(
    target_dir, filter_black_bg, filter_white_bg, processed_data,
    depth_view_selector, normal_view_selector, measure_view_selector,
):
    """Rebuild the per-view depth / normal / measure images after a
    background-filter checkbox changes.

    The saved ``predictions.npz`` in ``target_dir`` is reloaded, re-processed
    with the new filter flags, and the currently selected view of each tab is
    rendered again.

    Args:
        target_dir: Directory expected to contain ``predictions.npz``. The
            literal string ``"None"``, an empty value, or a path that is not
            a directory means "nothing to update".
        filter_black_bg: Flag forwarded to
            ``process_predictions_for_visualization``.
        filter_white_bg: Flag forwarded to
            ``process_predictions_for_visualization``.
        processed_data: Current processed-data state; returned unchanged
            whenever the update cannot be performed.
        depth_view_selector: Selected label of the depth dropdown
            (``"View N"``).
        normal_view_selector: Selected label of the normal dropdown.
        measure_view_selector: Selected label of the measure dropdown.

    Returns:
        tuple: ``(processed_data, depth_image, normal_image, measure_image,
        measure_points)``. ``measure_points`` is always a fresh empty list.
        On any failure the three images are ``None`` and the incoming
        ``processed_data`` is passed through.
    """
    unchanged = (processed_data, None, None, None, [])

    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return unchanged

    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return unchanged

    try:
        # NOTE(review): allow_pickle=True can execute arbitrary code if the
        # archive is attacker-controlled. The file is expected to be the one
        # this app wrote itself, but target_dir arrives from a UI component -
        # confirm it cannot be pointed at untrusted data.
        # The context manager closes the underlying zip handle; arrays are
        # fully materialised into the dict before the archive is closed.
        with np.load(predictions_path, allow_pickle=True) as loaded:
            predictions = {key: loaded[key] for key in loaded.files}

        new_processed_data = process_predictions_for_visualization(
            predictions, filter_black_bg, filter_white_bg
        )

        # Keep each tab on the view the user currently has selected.
        depth_idx = _selector_to_index(depth_view_selector)
        normal_idx = _selector_to_index(normal_view_selector)
        measure_idx = _selector_to_index(measure_view_selector)

        depth_vis = update_depth_view(new_processed_data, depth_idx)
        normal_vis = update_normal_view(new_processed_data, normal_idx)
        measure_img, _ = update_measure_view(new_processed_data, measure_idx)

        return new_processed_data, depth_vis, normal_vis, measure_img, []

    except Exception as e:
        # Deliberate best-effort boundary for a UI callback: report the
        # problem and leave the existing state untouched instead of crashing.
        print(f"Error updating views on filter change: {e}")
        return unchanged
823
+
824
+
825
  # ============================================================================
826
+ # Example Scenes (KEPT AS-IS)
827
  # ============================================================================
828
 
829
  def get_scene_info(examples_dir):
830
+ """Get information about scenes in the examples directory."""
831
  import glob
832
 
833
  scenes = []
 
862
 
863
 
864
  def load_example_scene(scene_name, examples_dir="examples"):
865
+ """Load a scene from examples directory."""
866
  scenes = get_scene_info(examples_dir)
867
 
868
  selected_scene = None
 
885
 
886
 
887
  # ============================================================================
888
+ # Gradio UI — 5 Tabs: Raw 3D · 3D View · Depth · Normal · Measure
889
  # ============================================================================
890
 
891
  theme = get_gradio_theme()
892
 
 
893
  APP_CSS = GRADIO_CSS + """
894
  /* Prevent components from expanding the layout */
895
  .gradio-container {
 
917
  .tab-content {
918
  min-height: 550px !important;
919
  }
920
+
921
+ /* Navigation row styling */
922
+ .navigation-row {
923
+ display: flex;
924
+ align-items: center;
925
+ gap: 8px;
926
+ }
927
  """
928
 
929
  with gr.Blocks() as demo:
930
+ # Hidden state variables
931
  is_example = gr.Textbox(label="is_example", visible=False, value="None")
 
932
  target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
933
+ processed_data_state = gr.State(value=None)
934
+ measure_points_state = gr.State(value=[])
935
 
936
  with gr.Row(equal_height=False):
937
+ # ==================== Left Side: Input Area ====================
938
  with gr.Column(scale=1, min_width=300):
939
  gr.Markdown("### 📤 Input")
940
+
941
  input_images = gr.File(
942
+ file_count="multiple",
943
+ label="Upload multiple images (3-10 recommended)",
944
  interactive=True,
945
+ height=200,
946
  )
947
+
948
  image_gallery = gr.Gallery(
949
+ label="Image Preview",
950
+ columns=3,
951
+ height=350,
952
+ object_fit="contain",
953
+ preview=True,
954
  )
955
+
956
  with gr.Row():
957
+ submit_btn = gr.Button(
958
+ "🚀 Start Reconstruction", variant="primary", scale=2
959
+ )
960
  clear_btn = gr.ClearButton(
961
  [input_images, target_dir_output, image_gallery],
962
+ value="🗑️ Clear",
963
+ scale=1,
964
  )
965
 
966
+ # ==================== Right Side: Output Area ====================
967
  with gr.Column(scale=2, min_width=600):
968
  gr.Markdown("### 🎯 Output")
969
 
970
  with gr.Tabs():
971
+ # ---------- Tab 1: Raw 3D (KEPT AS-IS) ----------
972
  with gr.Tab("🏗️ Raw 3D"):
973
  reconstruction_output = gr.Model3D(
974
+ height=550,
975
+ zoom_speed=0.5,
976
+ pan_speed=0.5,
977
+ clear_color=[0.0, 0.0, 0.0, 0.0],
978
+ )
979
+
980
+ # ---------- Tab 2: 3D View (NEW) ----------
981
+ with gr.Tab("🌐 3D View"):
982
+ reconstruction_output_3d = gr.Model3D(
983
+ height=550,
984
+ zoom_speed=0.5,
985
+ pan_speed=0.5,
986
+ clear_color=[0.05, 0.05, 0.05, 1.0],
987
+ )
988
+
989
+ # ---------- Tab 3: Depth (NEW) ----------
990
+ with gr.Tab("🔵 Depth"):
991
+ with gr.Row(elem_classes=["navigation-row"]):
992
+ prev_depth_btn = gr.Button("◀ Prev", size="sm", scale=1)
993
+ depth_view_selector = gr.Dropdown(
994
+ choices=["View 1"],
995
+ value="View 1",
996
+ label="Select View",
997
+ scale=2,
998
+ interactive=True,
999
+ allow_custom_value=True,
1000
+ )
1001
+ next_depth_btn = gr.Button("Next ▶", size="sm", scale=1)
1002
+ depth_map = gr.Image(
1003
+ type="numpy",
1004
+ label="Colorized Depth Map",
1005
+ format="png",
1006
+ interactive=False,
1007
+ )
1008
+
1009
+ # ---------- Tab 4: Normal (NEW) ----------
1010
+ with gr.Tab("🟢 Normal"):
1011
+ with gr.Row(elem_classes=["navigation-row"]):
1012
+ prev_normal_btn = gr.Button("◀ Prev", size="sm", scale=1)
1013
+ normal_view_selector = gr.Dropdown(
1014
+ choices=["View 1"],
1015
+ value="View 1",
1016
+ label="Select View",
1017
+ scale=2,
1018
+ interactive=True,
1019
+ allow_custom_value=True,
1020
+ )
1021
+ next_normal_btn = gr.Button("Next ▶", size="sm", scale=1)
1022
+ normal_map = gr.Image(
1023
+ type="numpy",
1024
+ label="Normal Map",
1025
+ format="png",
1026
+ interactive=False,
1027
+ )
1028
+
1029
+ # ---------- Tab 5: Measure (NEW) ----------
1030
+ with gr.Tab("📏 Measure"):
1031
+ gr.Markdown(MEASURE_INSTRUCTIONS_HTML)
1032
+ with gr.Row(elem_classes=["navigation-row"]):
1033
+ prev_measure_btn = gr.Button("◀ Prev", size="sm", scale=1)
1034
+ measure_view_selector = gr.Dropdown(
1035
+ choices=["View 1"],
1036
+ value="View 1",
1037
+ label="Select View",
1038
+ scale=2,
1039
+ interactive=True,
1040
+ allow_custom_value=True,
1041
+ )
1042
+ next_measure_btn = gr.Button("Next ▶", size="sm", scale=1)
1043
+ measure_image = gr.Image(
1044
+ type="numpy",
1045
+ show_label=False,
1046
+ format="webp",
1047
+ interactive=False,
1048
+ sources=[],
1049
  )
1050
+ gr.Markdown(
1051
+ "**Note:** Light-grey areas indicate regions with no depth information where measurements cannot be taken."
1052
+ )
1053
+ measure_text = gr.Markdown("")
1054
+
1055
  log_output = gr.Textbox(
1056
  value="📌 Please upload images, then click 'Start Reconstruction'",
1057
  label="Status Information",
1058
  interactive=False,
1059
  lines=1,
1060
+ max_lines=1,
1061
  )
1062
 
1063
+ # ==================== Advanced Options (Collapsible) ====================
1064
  with gr.Accordion("⚙️ Advanced Options", open=False):
1065
  with gr.Row(equal_height=False):
1066
  with gr.Column(scale=1, min_width=300):
 
1069
  choices=["All"], value="All", label="Display Frame"
1070
  )
1071
  conf_thres = gr.Slider(
1072
+ minimum=0,
1073
+ maximum=100,
1074
+ value=0,
1075
+ step=0.1,
1076
+ label="Confidence Threshold (Percentile)",
1077
  )
1078
  show_cam = gr.Checkbox(label="Show Camera", value=True)
1079
  show_mesh = gr.Checkbox(label="Show Mesh", value=True)
1080
+ filter_black_bg = gr.Checkbox(
1081
+ label="Filter Black Background", value=False
1082
+ )
1083
+ filter_white_bg = gr.Checkbox(
1084
+ label="Filter White Background", value=False
1085
+ )
1086
+
1087
  with gr.Column(scale=1, min_width=300):
1088
  gr.Markdown("#### Reconstruction Parameters")
1089
  apply_mask_checkbox = gr.Checkbox(
1090
  label="Apply Depth Mask", value=True
1091
  )
1092
 
1093
+ # ==================== Example Scenes (Collapsible) ====================
1094
  with gr.Accordion("🖼️ Example Scenes", open=False):
1095
  scenes = get_scene_info("examples")
1096
  if scenes:
 
1102
  scene = scenes[scene_idx]
1103
  with gr.Column(scale=1, min_width=150):
1104
  scene_img = gr.Image(
1105
+ value=scene["thumbnail"],
1106
  height=150,
1107
+ interactive=False,
1108
+ show_label=False,
1109
  sources=[],
1110
+ container=False,
1111
  )
1112
  gr.Markdown(
1113
  f"**{scene['name']}** ({scene['num_images']} images)",
1114
+ elem_classes=["text-center"],
1115
  )
1116
  scene_img.select(
1117
+ fn=lambda name=scene["name"]: load_example_scene(
1118
+ name
1119
+ ),
1120
  outputs=[
1121
  reconstruction_output,
1122
+ target_dir_output,
1123
+ image_gallery,
1124
+ log_output,
1125
+ ],
1126
  )
1127
 
1128
+ # ====================================================================
1129
+ # Event Binding
1130
+ # ====================================================================
1131
+
1132
+ # ---- Auto-update gallery on file upload ----
1133
  input_images.change(
1134
  fn=update_gallery_on_upload,
1135
  inputs=[input_images],
1136
+ outputs=[
1137
+ reconstruction_output,
1138
+ target_dir_output,
1139
+ image_gallery,
1140
+ log_output,
1141
+ ],
1142
+ ).then(
1143
+ fn=lambda: None,
1144
+ outputs=[reconstruction_output_3d],
1145
  )
1146
+
1147
+ # ---- Reconstruction button ----
1148
  submit_btn.click(
1149
+ fn=lambda: (None, None),
1150
+ outputs=[reconstruction_output, reconstruction_output_3d],
1151
  ).then(
1152
  fn=update_log,
1153
+ outputs=[log_output],
1154
  ).then(
1155
  fn=gradio_demo,
1156
  inputs=[
1157
+ target_dir_output,
1158
+ frame_filter,
1159
+ show_cam,
1160
+ filter_black_bg,
1161
+ filter_white_bg,
1162
+ conf_thres,
1163
+ apply_mask_checkbox,
1164
+ show_mesh,
1165
  ],
1166
  outputs=[
1167
+ reconstruction_output, # Raw 3D
1168
+ reconstruction_output_3d, # 3D View
1169
+ log_output,
1170
+ frame_filter,
1171
+ processed_data_state,
1172
+ depth_map,
1173
+ normal_map,
1174
+ measure_image,
1175
+ measure_text,
1176
+ depth_view_selector,
1177
+ normal_view_selector,
1178
+ measure_view_selector,
1179
+ ],
1180
  ).then(
1181
  fn=lambda: "False",
1182
+ outputs=[is_example],
1183
  )
1184
+
1185
+ # ---- Clear button: also clear new tabs ----
1186
+ clear_btn.add([reconstruction_output, reconstruction_output_3d, log_output])
1187
+
1188
+ # ---- 3D visualization param changes (frame_filter, show_cam, conf, mesh) ----
1189
+ for component in [frame_filter, show_cam, conf_thres, show_mesh]:
1190
  component.change(
1191
+ fn=update_all_3d_views,
1192
  inputs=[
1193
+ target_dir_output,
1194
+ frame_filter,
1195
+ show_cam,
1196
+ is_example,
1197
+ conf_thres,
1198
+ filter_black_bg,
1199
+ filter_white_bg,
1200
+ show_mesh,
1201
+ ],
1202
+ outputs=[
1203
+ reconstruction_output,
1204
+ reconstruction_output_3d,
1205
+ log_output,
1206
  ],
 
1207
  )
1208
 
1209
+ # ---- Background filter changes: update 3D viewers AND per-view tabs ----
1210
+ for filter_component in [filter_black_bg, filter_white_bg]:
1211
+ filter_component.change(
1212
+ fn=update_all_3d_views,
1213
+ inputs=[
1214
+ target_dir_output,
1215
+ frame_filter,
1216
+ show_cam,
1217
+ is_example,
1218
+ conf_thres,
1219
+ filter_black_bg,
1220
+ filter_white_bg,
1221
+ show_mesh,
1222
+ ],
1223
+ outputs=[
1224
+ reconstruction_output,
1225
+ reconstruction_output_3d,
1226
+ log_output,
1227
+ ],
1228
+ ).then(
1229
+ fn=update_all_views_on_filter_change,
1230
+ inputs=[
1231
+ target_dir_output,
1232
+ filter_black_bg,
1233
+ filter_white_bg,
1234
+ processed_data_state,
1235
+ depth_view_selector,
1236
+ normal_view_selector,
1237
+ measure_view_selector,
1238
+ ],
1239
+ outputs=[
1240
+ processed_data_state,
1241
+ depth_map,
1242
+ normal_map,
1243
+ measure_image,
1244
+ measure_points_state,
1245
+ ],
1246
+ )
1247
+
1248
+ # ---- Depth tab navigation ----
1249
+ prev_depth_btn.click(
1250
+ fn=lambda pd, cs: navigate_depth_view(pd, cs, -1),
1251
+ inputs=[processed_data_state, depth_view_selector],
1252
+ outputs=[depth_view_selector, depth_map],
1253
+ )
1254
+ next_depth_btn.click(
1255
+ fn=lambda pd, cs: navigate_depth_view(pd, cs, 1),
1256
+ inputs=[processed_data_state, depth_view_selector],
1257
+ outputs=[depth_view_selector, depth_map],
1258
+ )
1259
+ depth_view_selector.change(
1260
+ fn=lambda pd, sv: (
1261
+ update_depth_view(pd, int(sv.split()[1]) - 1) if sv else None
1262
+ ),
1263
+ inputs=[processed_data_state, depth_view_selector],
1264
+ outputs=[depth_map],
1265
+ )
1266
+
1267
+ # ---- Normal tab navigation ----
1268
+ prev_normal_btn.click(
1269
+ fn=lambda pd, cs: navigate_normal_view(pd, cs, -1),
1270
+ inputs=[processed_data_state, normal_view_selector],
1271
+ outputs=[normal_view_selector, normal_map],
1272
+ )
1273
+ next_normal_btn.click(
1274
+ fn=lambda pd, cs: navigate_normal_view(pd, cs, 1),
1275
+ inputs=[processed_data_state, normal_view_selector],
1276
+ outputs=[normal_view_selector, normal_map],
1277
+ )
1278
+ normal_view_selector.change(
1279
+ fn=lambda pd, sv: (
1280
+ update_normal_view(pd, int(sv.split()[1]) - 1) if sv else None
1281
+ ),
1282
+ inputs=[processed_data_state, normal_view_selector],
1283
+ outputs=[normal_map],
1284
+ )
1285
+
1286
+ # ---- Measure tab navigation ----
1287
+ prev_measure_btn.click(
1288
+ fn=lambda pd, cs: navigate_measure_view(pd, cs, -1),
1289
+ inputs=[processed_data_state, measure_view_selector],
1290
+ outputs=[measure_view_selector, measure_image, measure_points_state],
1291
+ )
1292
+ next_measure_btn.click(
1293
+ fn=lambda pd, cs: navigate_measure_view(pd, cs, 1),
1294
+ inputs=[processed_data_state, measure_view_selector],
1295
+ outputs=[measure_view_selector, measure_image, measure_points_state],
1296
+ )
1297
+ measure_view_selector.change(
1298
+ fn=lambda pd, sv: (
1299
+ update_measure_view(pd, int(sv.split()[1]) - 1)
1300
+ if sv
1301
+ else (None, [])
1302
+ ),
1303
+ inputs=[processed_data_state, measure_view_selector],
1304
+ outputs=[measure_image, measure_points_state],
1305
+ )
1306
+
1307
+ # ---- Measure click handler ----
1308
+ measure_image.select(
1309
+ fn=measure,
1310
+ inputs=[processed_data_state, measure_points_state, measure_view_selector],
1311
+ outputs=[measure_image, measure_points_state, measure_text],
1312
+ )
1313
+
1314
+
1315
# Turn on request queueing (queue size limited to 20) and start the app.
demo.queue(
    max_size=20,
).launch(
    theme=theme,
    css=APP_CSS,
    show_error=True,
    share=True,
    ssr_mode=False,
)