prithivMLmods commited on
Commit
fee1c33
·
verified ·
1 Parent(s): eb7073f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1122 -416
app.py CHANGED
@@ -1,10 +1,3 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- # conda activate hf3.10
7
-
8
  import gc
9
  import os
10
  import shutil
@@ -27,19 +20,16 @@ register_heif_opener()
27
  sys.path.append("mapanything/")
28
 
29
  from mapanything.utils.geometry import depthmap_to_world_frame, points_to_normals
 
 
 
 
 
 
30
  from mapanything.utils.hf_utils.hf_helpers import initialize_mapanything_model
31
  from mapanything.utils.hf_utils.viz import predictions_to_glb
32
  from mapanything.utils.image import load_images, rgb
33
 
34
- import rerun as rr
35
-
36
- # Attempt to import blueprint for advanced view configuration
37
- try:
38
- import rerun.blueprint as rrb
39
- except ImportError:
40
- rrb = None
41
-
42
- from gradio_rerun import Rerun
43
 
44
  # MapAnything Configuration
45
  high_level_config = {
@@ -61,177 +51,10 @@ high_level_config = {
61
  "resolution": 518,
62
  }
63
 
64
- # Initialize model
65
  model = None
66
 
67
 
68
- # -------------------------------------------------------------------------
69
- # Custom Modern CSS
70
- # -------------------------------------------------------------------------
71
- modern_css = r"""
72
- @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800&family=JetBrains+Mono:wght@400;500;600&display=swap');
73
-
74
- *{box-sizing:border-box;margin:0;padding:0}
75
-
76
- body, .gradio-container {
77
- background:#0f0f13!important;
78
- font-family:'Inter',system-ui,-apple-system,sans-serif!important;
79
- font-size:14px!important;
80
- color:#e4e4e7!important;
81
- min-height:100vh;
82
- }
83
- .dark body, .dark .gradio-container {
84
- background:#0f0f13!important;
85
- color:#e4e4e7!important;
86
- }
87
- footer{display:none!important}
88
-
89
- .app-shell {
90
- background:#18181b;
91
- border:1px solid #27272a;
92
- border-radius:16px;
93
- margin:12px auto;
94
- max-width:1400px;
95
- overflow:hidden;
96
- box-shadow:0 25px 50px -12px rgba(0,0,0,.6), 0 0 0 1px rgba(255,255,255,.03);
97
- display: flex;
98
- flex-direction: column;
99
- }
100
-
101
- .app-header {
102
- background:linear-gradient(135deg,#18181b 0%,#1e1e24 100%);
103
- border-bottom:1px solid #27272a;
104
- padding:14px 24px;
105
- display:flex;
106
- align-items:center;
107
- justify-content:space-between;
108
- }
109
-
110
- .app-header-left {
111
- display:flex;
112
- align-items:center;
113
- gap:12px;
114
- }
115
-
116
- .app-logo {
117
- width:36px;height:36px;
118
- background:linear-gradient(135deg,#6366f1,#8b5cf6,#a78bfa);
119
- border-radius:10px;
120
- display:flex;align-items:center;justify-content:center;
121
- font-size:18px;font-weight:800;color:#fff;
122
- box-shadow:0 4px 12px rgba(99,102,241,.35);
123
- }
124
-
125
- .app-title {
126
- font-size:18px;font-weight:700;
127
- background:linear-gradient(135deg,#e4e4e7,#a1a1aa);
128
- -webkit-background-clip:text;
129
- -webkit-text-fill-color:transparent;
130
- letter-spacing:-.3px;
131
- }
132
-
133
- .app-badge {
134
- font-size:11px;
135
- font-weight:600;
136
- padding:3px 10px;
137
- border-radius:20px;
138
- background:rgba(99,102,241,.15);
139
- color:#818cf8;
140
- border:1px solid rgba(99,102,241,.25);
141
- letter-spacing:.3px;
142
- }
143
-
144
- .app-main-row {
145
- display:flex;
146
- gap:0;
147
- flex:1;
148
- overflow:hidden;
149
- flex-wrap: wrap;
150
- }
151
-
152
- .app-main-left {
153
- flex: 1;
154
- min-width: 350px;
155
- border-right:1px solid #27272a;
156
- padding: 20px;
157
- display: flex;
158
- flex-direction: column;
159
- gap: 16px;
160
- }
161
-
162
- .app-main-right {
163
- flex: 2;
164
- min-width: 500px;
165
- display:flex;
166
- flex-direction:column;
167
- background:#18181b;
168
- }
169
-
170
- .modern-btn {
171
- display:flex;align-items:center;justify-content:center;gap:8px;
172
- width:100%;
173
- background:linear-gradient(135deg,#6366f1,#7c3aed);
174
- border:none;
175
- border-radius:10px;
176
- padding:12px 24px;
177
- cursor:pointer;
178
- font-size:15px;
179
- font-weight:600;
180
- color:#fff!important;
181
- transition:all .2s ease;
182
- box-shadow:0 4px 16px rgba(99,102,241,.3), inset 0 1px 0 rgba(255,255,255,.1);
183
- }
184
- .modern-btn:hover {
185
- background:linear-gradient(135deg,#7c7cf5,#8b5cf6);
186
- transform:translateY(-1px);
187
- }
188
- .modern-btn.secondary {
189
- background:#27272a;
190
- box-shadow:none;
191
- border: 1px solid #3f3f46;
192
- }
193
- .modern-btn.secondary:hover {
194
- background:#3f3f46;
195
- }
196
-
197
- .settings-group {
198
- border:1px solid #27272a;
199
- border-radius:10px;
200
- overflow:hidden;
201
- background:#18181b;
202
- margin-bottom: 12px;
203
- }
204
- .settings-group-title {
205
- font-size:12px;
206
- font-weight:600;
207
- color:#71717a;
208
- text-transform:uppercase;
209
- letter-spacing:.8px;
210
- padding:10px 16px;
211
- border-bottom:1px solid #27272a;
212
- background:rgba(24,24,27,.5);
213
- }
214
- .settings-group-body {
215
- padding:14px 16px;
216
- display:flex;
217
- flex-direction:column;
218
- gap:12px;
219
- }
220
-
221
- /* Gradio Overrides */
222
- .tabs { background: transparent !important; border: none !important; }
223
- .tab-nav { border-bottom: 1px solid #27272a !important; margin-bottom: 10px; }
224
- .tab-nav button { color: #a1a1aa !important; border-radius: 8px 8px 0 0 !important; font-weight: 600 !important; }
225
- .tab-nav button.selected { color: #818cf8 !important; border-bottom: 2px solid #818cf8 !important; background: rgba(99,102,241,.08) !important; }
226
- .gradio-dropdown { background: #09090b !important; border: 1px solid #27272a !important; color: #fff !important; }
227
- .gallery-item { border-radius: 8px !important; overflow: hidden !important; border: 1px solid #27272a !important; }
228
-
229
- /* Custom log styling */
230
- .custom-log { padding: 12px; border-radius: 8px; background: rgba(99, 102, 241, 0.1); border: 1px solid rgba(99, 102, 241, 0.2); color: #c7d2fe; font-family: 'JetBrains Mono', monospace; font-size: 12px; }
231
- """
232
-
233
- DOWNLOAD_SVG = '<svg viewBox="0 0 24 24" width="16" height="16" fill="currentColor" xmlns="http://www.w3.org/2000/svg"><path d="M12 16l-5-5h3V4h4v7h3l-5 5z"/><path d="M20 18H4v2h16v-2z"/></svg>'
234
-
235
  # -------------------------------------------------------------------------
236
  # 1) Core model inference
237
  # -------------------------------------------------------------------------
@@ -243,92 +66,126 @@ def run_model(
243
  filter_black_bg=False,
244
  filter_white_bg=False,
245
  ):
 
 
 
246
  global model
247
- import torch
 
 
248
 
 
249
  device = "cuda" if torch.cuda.is_available() else "cpu"
250
  device = torch.device(device)
251
 
 
252
  if model is None:
253
  model = initialize_mapanything_model(high_level_config, device)
 
254
  else:
255
  model = model.to(device)
256
 
257
  model.eval()
258
 
 
 
259
  image_folder_path = os.path.join(target_dir, "images")
260
  views = load_images(image_folder_path)
261
 
 
262
  if len(views) == 0:
263
  raise ValueError("No images found. Check your upload.")
264
 
 
 
 
 
 
265
  outputs = model.infer(
266
  views, apply_mask=apply_mask, mask_edges=True, memory_efficient_inference=False
267
  )
268
 
 
269
  predictions = {}
270
- extrinsic_list, intrinsic_list, world_points_list = [], [], []
271
- depth_maps_list, images_list, final_mask_list = [], [], []
272
 
 
 
 
 
 
 
 
 
 
273
  for pred in outputs:
274
- depthmap_torch = pred["depth_z"][0].squeeze(-1)
275
- intrinsics_torch = pred["intrinsics"][0]
276
- camera_pose_torch = pred["camera_poses"][0]
 
277
 
 
278
  pts3d_computed, valid_mask = depthmap_to_world_frame(
279
  depthmap_torch, intrinsics_torch, camera_pose_torch
280
  )
281
 
 
 
282
  if "mask" in pred:
283
  mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool)
284
  else:
 
285
  mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool)
286
 
 
287
  mask = mask & valid_mask.cpu().numpy()
 
288
  image = pred["img_no_norm"][0].cpu().numpy()
289
 
 
290
  extrinsic_list.append(camera_pose_torch.cpu().numpy())
291
  intrinsic_list.append(intrinsics_torch.cpu().numpy())
292
  world_points_list.append(pts3d_computed.cpu().numpy())
293
  depth_maps_list.append(depthmap_torch.cpu().numpy())
294
- images_list.append(image)
295
- final_mask_list.append(mask)
296
 
 
 
297
  predictions["extrinsic"] = np.stack(extrinsic_list, axis=0)
 
 
298
  predictions["intrinsic"] = np.stack(intrinsic_list, axis=0)
 
 
299
  predictions["world_points"] = np.stack(world_points_list, axis=0)
300
 
 
301
  depth_maps = np.stack(depth_maps_list, axis=0)
 
302
  if len(depth_maps.shape) == 3:
303
  depth_maps = depth_maps[..., np.newaxis]
 
304
  predictions["depth"] = depth_maps
 
 
305
  predictions["images"] = np.stack(images_list, axis=0)
 
 
306
  predictions["final_mask"] = np.stack(final_mask_list, axis=0)
307
 
 
308
  processed_data = process_predictions_for_visualization(
309
  predictions, views, high_level_config, filter_black_bg, filter_white_bg
310
  )
311
 
 
312
  torch.cuda.empty_cache()
 
313
  return predictions, processed_data
314
 
315
- def generate_rrd(glb_file, target_dir):
316
- """Generates an RRD file from the exported GLB to be visualized via Rerun."""
317
- rrd_path = os.path.join(target_dir, "reconstruction.rrd")
318
- rr.init("MapAnything", spawn=False)
319
- rr.log("scene", rr.Asset3D(path=glb_file))
320
-
321
- if rrb is not None:
322
- blueprint = rrb.Blueprint(
323
- rrb.Spatial3DView(origin="scene"),
324
- rrb.TimePanel(state="collapsed"),
325
- )
326
- rr.save(rrd_path, default_blueprint=blueprint)
327
- else:
328
- rr.save(rrd_path)
329
- return rrd_path
330
 
331
  def update_view_selectors(processed_data):
 
332
  if processed_data is None or len(processed_data) == 0:
333
  choices = ["View 1"]
334
  else:
@@ -336,116 +193,173 @@ def update_view_selectors(processed_data):
336
  choices = [f"View {i + 1}" for i in range(num_views)]
337
 
338
  return (
339
- gr.Dropdown(choices=choices, value=choices[0]),
340
- gr.Dropdown(choices=choices, value=choices[0]),
341
- gr.Dropdown(choices=choices, value=choices[0]),
342
  )
343
 
 
344
  def get_view_data_by_index(processed_data, view_index):
 
345
  if processed_data is None or len(processed_data) == 0:
346
  return None
 
347
  view_keys = list(processed_data.keys())
348
  if view_index < 0 or view_index >= len(view_keys):
349
  view_index = 0
 
350
  return processed_data[view_keys[view_index]]
351
 
 
352
  def update_depth_view(processed_data, view_index):
 
353
  view_data = get_view_data_by_index(processed_data, view_index)
354
  if view_data is None or view_data["depth"] is None:
355
  return None
 
356
  return colorize_depth(view_data["depth"], mask=view_data.get("mask"))
357
 
 
358
  def update_normal_view(processed_data, view_index):
 
359
  view_data = get_view_data_by_index(processed_data, view_index)
360
  if view_data is None or view_data["normal"] is None:
361
  return None
 
362
  return colorize_normal(view_data["normal"], mask=view_data.get("mask"))
363
 
 
364
  def update_measure_view(processed_data, view_index):
 
365
  view_data = get_view_data_by_index(processed_data, view_index)
366
  if view_data is None:
367
- return None, []
368
-
 
369
  image = view_data["image"].copy()
 
 
370
  if image.dtype != np.uint8:
371
  if image.max() <= 1.0:
372
  image = (image * 255).astype(np.uint8)
373
  else:
374
  image = image.astype(np.uint8)
375
 
 
376
  if view_data["mask"] is not None:
377
  mask = view_data["mask"]
378
- invalid_mask = ~mask
 
 
 
 
379
  if invalid_mask.any():
 
380
  overlay_color = np.array([255, 220, 220], dtype=np.uint8)
381
- alpha = 0.5
382
- for c in range(3):
 
 
383
  image[:, :, c] = np.where(
384
  invalid_mask,
385
  (1 - alpha) * image[:, :, c] + alpha * overlay_color[c],
386
  image[:, :, c],
387
  ).astype(np.uint8)
 
388
  return image, []
389
 
 
390
  def navigate_depth_view(processed_data, current_selector_value, direction):
 
391
  if processed_data is None or len(processed_data) == 0:
392
  return "View 1", None
 
 
393
  try:
394
  current_view = int(current_selector_value.split()[1]) - 1
395
  except:
396
  current_view = 0
 
397
  num_views = len(processed_data)
398
  new_view = (current_view + direction) % num_views
 
399
  new_selector_value = f"View {new_view + 1}"
400
  depth_vis = update_depth_view(processed_data, new_view)
 
401
  return new_selector_value, depth_vis
402
 
 
403
  def navigate_normal_view(processed_data, current_selector_value, direction):
 
404
  if processed_data is None or len(processed_data) == 0:
405
  return "View 1", None
 
 
406
  try:
407
  current_view = int(current_selector_value.split()[1]) - 1
408
  except:
409
  current_view = 0
 
410
  num_views = len(processed_data)
411
  new_view = (current_view + direction) % num_views
 
412
  new_selector_value = f"View {new_view + 1}"
413
  normal_vis = update_normal_view(processed_data, new_view)
 
414
  return new_selector_value, normal_vis
415
 
 
416
  def navigate_measure_view(processed_data, current_selector_value, direction):
 
417
  if processed_data is None or len(processed_data) == 0:
418
  return "View 1", None, []
 
 
419
  try:
420
  current_view = int(current_selector_value.split()[1]) - 1
421
  except:
422
  current_view = 0
 
423
  num_views = len(processed_data)
424
  new_view = (current_view + direction) % num_views
 
425
  new_selector_value = f"View {new_view + 1}"
426
  measure_image, measure_points = update_measure_view(processed_data, new_view)
 
427
  return new_selector_value, measure_image, measure_points
428
 
 
429
  def populate_visualization_tabs(processed_data):
 
430
  if processed_data is None or len(processed_data) == 0:
431
  return None, None, None, []
 
 
432
  depth_vis = update_depth_view(processed_data, 0)
433
  normal_vis = update_normal_view(processed_data, 0)
434
  measure_img, _ = update_measure_view(processed_data, 0)
 
435
  return depth_vis, normal_vis, measure_img, []
436
 
 
437
  # -------------------------------------------------------------------------
438
- # 2) Handle uploaded video/images
439
  # -------------------------------------------------------------------------
440
  def handle_uploads(unified_upload, s_time_interval=1.0):
 
 
 
 
441
  start_time = time.time()
442
  gc.collect()
443
  torch.cuda.empty_cache()
444
 
 
445
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
446
  target_dir = f"input_images_{timestamp}"
447
  target_dir_images = os.path.join(target_dir, "images")
448
 
 
449
  if os.path.exists(target_dir):
450
  shutil.rmtree(target_dir)
451
  os.makedirs(target_dir)
@@ -453,6 +367,7 @@ def handle_uploads(unified_upload, s_time_interval=1.0):
453
 
454
  image_paths = []
455
 
 
456
  if unified_upload is not None:
457
  for file_data in unified_upload:
458
  if isinstance(file_data, dict) and "name" in file_data:
@@ -461,12 +376,24 @@ def handle_uploads(unified_upload, s_time_interval=1.0):
461
  file_path = str(file_data)
462
 
463
  file_ext = os.path.splitext(file_path)[1].lower()
464
- video_extensions = [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"]
465
-
 
 
 
 
 
 
 
 
 
 
 
466
  if file_ext in video_extensions:
 
467
  vs = cv2.VideoCapture(file_path)
468
  fps = vs.get(cv2.CAP_PROP_FPS)
469
- frame_interval = int(fps * s_time_interval)
470
 
471
  count = 0
472
  video_frame_num = 0
@@ -476,42 +403,90 @@ def handle_uploads(unified_upload, s_time_interval=1.0):
476
  break
477
  count += 1
478
  if count % frame_interval == 0:
 
479
  base_name = os.path.splitext(os.path.basename(file_path))[0]
480
- image_path = os.path.join(target_dir_images, f"{base_name}_{video_frame_num:06}.png")
 
 
481
  cv2.imwrite(image_path, frame)
482
  image_paths.append(image_path)
483
  video_frame_num += 1
484
  vs.release()
 
 
 
 
485
  else:
 
 
486
  if file_ext in [".heic", ".heif"]:
 
487
  try:
488
  with Image.open(file_path) as img:
 
489
  if img.mode not in ("RGB", "L"):
490
  img = img.convert("RGB")
 
 
491
  base_name = os.path.splitext(os.path.basename(file_path))[0]
492
- dst_path = os.path.join(target_dir_images, f"{base_name}.jpg")
 
 
 
 
493
  img.save(dst_path, "JPEG", quality=95)
494
  image_paths.append(dst_path)
 
 
 
495
  except Exception as e:
496
- dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
 
 
 
 
497
  shutil.copy(file_path, dst_path)
498
  image_paths.append(dst_path)
499
  else:
500
- dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
 
 
 
501
  shutil.copy(file_path, dst_path)
502
  image_paths.append(dst_path)
503
 
 
504
  image_paths = sorted(image_paths)
 
 
 
 
 
505
  return target_dir, image_paths
506
 
 
 
 
 
507
  def update_gallery_on_upload(input_video, input_images, s_time_interval=1.0):
 
 
 
 
 
508
  if not input_video and not input_images:
509
  return None, None, None, None
510
  target_dir, image_paths = handle_uploads(input_video, input_images, s_time_interval)
511
- return None, target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin."
 
 
 
 
 
 
512
 
513
  # -------------------------------------------------------------------------
514
- # 4) Reconstruction
515
  # -------------------------------------------------------------------------
516
  @spaces.GPU(duration=120)
517
  def gradio_demo(
@@ -523,133 +498,233 @@ def gradio_demo(
523
  apply_mask=True,
524
  show_mesh=True,
525
  ):
 
 
 
526
  if not os.path.isdir(target_dir) or target_dir == "None":
527
  return None, "No valid target directory found. Please upload first.", None, None
528
 
 
529
  gc.collect()
530
  torch.cuda.empty_cache()
531
 
 
532
  target_dir_images = os.path.join(target_dir, "images")
533
- all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
 
 
 
 
534
  all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
535
  frame_filter_choices = ["All"] + all_files
536
 
 
537
  with torch.no_grad():
538
  predictions, processed_data = run_model(target_dir, apply_mask)
539
 
 
540
  prediction_save_path = os.path.join(target_dir, "predictions.npz")
541
  np.savez(prediction_save_path, **predictions)
542
 
 
543
  if frame_filter is None:
544
  frame_filter = "All"
545
 
 
546
  glbfile = os.path.join(
547
  target_dir,
548
  f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
549
  )
550
 
 
551
  glbscene = predictions_to_glb(
552
  predictions,
553
  filter_by_frames=frame_filter,
554
  show_cam=show_cam,
555
  mask_black_bg=filter_black_bg,
556
  mask_white_bg=filter_white_bg,
557
- as_mesh=show_mesh,
558
  )
559
  glbscene.export(file_obj=glbfile)
560
-
561
- # Generate RRD
562
- rrd_path = generate_rrd(glbfile, target_dir)
563
 
 
564
  del predictions
565
  gc.collect()
566
  torch.cuda.empty_cache()
567
 
568
- log_msg = f"Reconstruction Success ({len(all_files)} frames). Visualization active."
 
 
 
 
 
 
 
 
 
569
 
570
- depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(processed_data)
571
- depth_selector, normal_selector, measure_selector = update_view_selectors(processed_data)
 
 
572
 
573
  return (
574
- rrd_path,
575
  log_msg,
576
  gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
577
  processed_data,
578
  depth_vis,
579
  normal_vis,
580
  measure_img,
581
- "",
582
  depth_selector,
583
  normal_selector,
584
  measure_selector,
585
  )
586
 
 
 
 
 
587
  def colorize_depth(depth_map, mask=None):
588
- if depth_map is None: return None
 
 
 
 
589
  depth_normalized = depth_map.copy()
590
  valid_mask = depth_normalized > 0
591
- if mask is not None: valid_mask = valid_mask & mask
 
 
 
592
 
593
  if valid_mask.sum() > 0:
594
  valid_depths = depth_normalized[valid_mask]
595
  p5 = np.percentile(valid_depths, 5)
596
  p95 = np.percentile(valid_depths, 95)
 
597
  depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / (p95 - p5)
598
 
 
599
  import matplotlib.pyplot as plt
 
600
  colormap = plt.cm.turbo_r
601
  colored = colormap(depth_normalized)
602
  colored = (colored[:, :, :3] * 255).astype(np.uint8)
 
 
603
  colored[~valid_mask] = [255, 255, 255]
 
604
  return colored
605
 
 
606
  def colorize_normal(normal_map, mask=None):
607
- if normal_map is None: return None
 
 
 
 
608
  normal_vis = normal_map.copy()
 
 
609
  if mask is not None:
610
  invalid_mask = ~mask
611
- normal_vis[invalid_mask] = [0, 0, 0]
 
 
612
  normal_vis = (normal_vis + 1.0) / 2.0
613
  normal_vis = (normal_vis * 255).astype(np.uint8)
 
614
  return normal_vis
615
 
 
616
  def process_predictions_for_visualization(
617
  predictions, views, high_level_config, filter_black_bg=False, filter_white_bg=False
618
  ):
 
619
  processed_data = {}
 
 
620
  for view_idx, view in enumerate(views):
 
621
  image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
 
 
622
  pred_pts3d = predictions["world_points"][view_idx]
623
- view_data = {"image": image[0], "points3d": pred_pts3d, "depth": None, "normal": None, "mask": None}
 
 
 
 
 
 
 
 
 
 
624
  mask = predictions["final_mask"][view_idx].copy()
625
 
 
626
  if filter_black_bg:
 
627
  view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
 
628
  black_bg_mask = view_colors.sum(axis=2) >= 16
629
  mask = mask & black_bg_mask
630
 
 
631
  if filter_white_bg:
 
632
  view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
633
- white_bg_mask = ~((view_colors[:, :, 0] > 240) & (view_colors[:, :, 1] > 240) & (view_colors[:, :, 2] > 240))
 
 
 
 
 
634
  mask = mask & white_bg_mask
635
 
636
  view_data["mask"] = mask
637
  view_data["depth"] = predictions["depth"][view_idx].squeeze()
 
638
  normals, _ = points_to_normals(pred_pts3d, mask=view_data["mask"])
639
  view_data["normal"] = normals
 
640
  processed_data[view_idx] = view_data
641
 
642
  return processed_data
643
 
644
- def measure(processed_data, measure_points, current_view_selector, event: gr.SelectData):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
645
  try:
 
 
646
  if processed_data is None or len(processed_data) == 0:
647
  return None, [], "No data available"
 
 
648
  try:
649
  current_view_index = int(current_view_selector.split()[1]) - 1
650
  except:
651
  current_view_index = 0
652
 
 
 
 
653
  if current_view_index < 0 or current_view_index >= len(processed_data):
654
  current_view_index = 0
655
 
@@ -660,60 +735,146 @@ def measure(processed_data, measure_points, current_view_selector, event: gr.Sel
660
  return None, [], "No view data available"
661
 
662
  point2d = event.index[0], event.index[1]
663
-
664
- if (current_view["mask"] is not None and 0 <= point2d[1] < current_view["mask"].shape[0] and 0 <= point2d[0] < current_view["mask"].shape[1]):
 
 
 
 
 
 
 
665
  if not current_view["mask"][point2d[1], point2d[0]]:
666
- masked_image, _ = update_measure_view(processed_data, current_view_index)
667
- return masked_image, measure_points, 'Cannot measure on masked areas (shown in grey)'
 
 
 
 
 
 
 
 
668
 
669
  measure_points.append(point2d)
 
 
670
  image, _ = update_measure_view(processed_data, current_view_index)
 
 
 
671
  image = image.copy()
672
  points3d = current_view["points3d"]
673
 
674
- if image.dtype != np.uint8:
675
- image = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
676
 
677
- for p in measure_points:
678
- if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
679
- image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2)
 
 
 
 
 
 
 
680
 
681
  depth_text = ""
682
- for i, p in enumerate(measure_points):
683
- if current_view["depth"] is not None and 0 <= p[1] < current_view["depth"].shape[0] and 0 <= p[0] < current_view["depth"].shape[1]:
684
- d = current_view["depth"][p[1], p[0]]
685
- depth_text += f"P{i + 1} depth: {d:.2f}m.\n"
686
- else:
687
- if points3d is not None and 0 <= p[1] < points3d.shape[0] and 0 <= p[0] < points3d.shape[1]:
688
- z = points3d[p[1], p[0], 2]
689
- depth_text += f"P{i + 1} Z-coord: {z:.2f}m.\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
690
 
691
  if len(measure_points) == 2:
692
- point1, point2 = measure_points
693
- if (0 <= point1[0] < image.shape[1] and 0 <= point1[1] < image.shape[0] and 0 <= point2[0] < image.shape[1] and 0 <= point2[1] < image.shape[0]):
694
- image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2)
695
-
696
- distance_text = "Distance: Unable to compute"
697
- if (points3d is not None and 0 <= point1[1] < points3d.shape[0] and 0 <= point1[0] < points3d.shape[1] and 0 <= point2[1] < points3d.shape[0] and 0 <= point2[0] < points3d.shape[1]):
698
- p1_3d = points3d[point1[1], point1[0]]
699
- p2_3d = points3d[point2[1], point2[0]]
700
- distance = np.linalg.norm(p1_3d - p2_3d)
701
- distance_text = f"Distance: {distance:.2f}m"
702
-
703
- measure_points = []
704
- return [image, measure_points, depth_text + "\n" + distance_text]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
705
  else:
 
706
  return [image, measure_points, depth_text]
707
 
708
  except Exception as e:
 
709
  return None, [], f"Measure function error: {e}"
710
 
 
711
  def clear_fields():
 
 
 
712
  return None
713
 
 
714
  def update_log():
 
 
 
715
  return "Loading and Reconstructing..."
716
 
 
717
  def update_visualization(
718
  target_dir,
719
  frame_filter,
@@ -723,15 +884,30 @@ def update_visualization(
723
  filter_white_bg=False,
724
  show_mesh=True,
725
  ):
 
 
 
 
 
 
726
  if is_example == "True":
727
- return gr.update(), "Please reconstruct first."
 
 
 
728
 
729
  if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
730
- return gr.update(), "Please reconstruct first."
 
 
 
731
 
732
  predictions_path = os.path.join(target_dir, "predictions.npz")
733
  if not os.path.exists(predictions_path):
734
- return gr.update(), "No predictions found."
 
 
 
735
 
736
  loaded = np.load(predictions_path, allow_pickle=True)
737
  predictions = {key: loaded[key] for key in loaded.keys()}
@@ -743,16 +919,35 @@ def update_visualization(
743
 
744
  if not os.path.exists(glbfile):
745
  glbscene = predictions_to_glb(
746
- predictions, filter_by_frames=frame_filter, show_cam=show_cam, mask_black_bg=filter_black_bg, mask_white_bg=filter_white_bg, as_mesh=show_mesh,
 
 
 
 
 
747
  )
748
  glbscene.export(file_obj=glbfile)
749
 
750
- rrd_path = generate_rrd(glbfile, target_dir)
751
- return rrd_path, "Visualization updated."
 
 
 
752
 
753
  def update_all_views_on_filter_change(
754
- target_dir, filter_black_bg, filter_white_bg, processed_data, depth_view_selector, normal_view_selector, measure_view_selector,
 
 
 
 
 
 
755
  ):
 
 
 
 
 
756
  if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
757
  return processed_data, None, None, None, []
758
 
@@ -761,211 +956,722 @@ def update_all_views_on_filter_change(
761
  return processed_data, None, None, None, []
762
 
763
  try:
 
764
  loaded = np.load(predictions_path, allow_pickle=True)
765
  predictions = {key: loaded[key] for key in loaded.keys()}
766
- views = load_images(os.path.join(target_dir, "images"))
767
 
 
 
 
 
 
768
  new_processed_data = process_predictions_for_visualization(
769
  predictions, views, high_level_config, filter_black_bg, filter_white_bg
770
  )
771
 
772
- depth_view_idx = int(depth_view_selector.split()[1]) - 1 if depth_view_selector else 0
773
- normal_view_idx = int(normal_view_selector.split()[1]) - 1 if normal_view_selector else 0
774
- measure_view_idx = int(measure_view_selector.split()[1]) - 1 if measure_view_selector else 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
775
 
 
776
  depth_vis = update_depth_view(new_processed_data, depth_view_idx)
777
  normal_vis = update_normal_view(new_processed_data, normal_view_idx)
778
  measure_img, _ = update_measure_view(new_processed_data, measure_view_idx)
779
 
780
  return new_processed_data, depth_vis, normal_vis, measure_img, []
 
781
  except Exception as e:
 
782
  return processed_data, None, None, None, []
783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
784
  # -------------------------------------------------------------------------
785
- # Build Gradio UI
786
  # -------------------------------------------------------------------------
 
787
 
788
- with gr.Blocks() as demo:
789
- is_example = gr.Textbox(visible=False, value="None")
790
- num_images = gr.Textbox(visible=False, value="None")
 
791
  processed_data_state = gr.State(value=None)
792
  measure_points_state = gr.State(value=[])
793
- current_view_index = gr.State(value=0)
794
- target_dir_output = gr.Textbox(visible=False, value="None")
795
-
796
- with gr.Column(elem_classes="app-shell"):
797
- # Header
798
- gr.HTML(f"""
799
- <div class="app-header">
800
- <div class="app-header-left">
801
- <div class="app-logo">M3D</div>
802
- <span class="app-title">MapAnything 3D Reconstruction</span>
803
- <span class="app-badge">Rerun View</span>
804
- </div>
805
- </div>
806
- """)
807
-
808
- with gr.Row(elem_classes="app-main-row"):
809
- # Left Panel
810
- with gr.Column(elem_classes="app-main-left"):
811
-
812
- gr.Markdown("### Input Media")
813
- unified_upload = gr.File(
814
- file_count="multiple",
815
- label="Upload Video or Images",
816
  interactive=True,
817
- file_types=["image", "video"],
 
818
  )
819
-
820
- with gr.Row():
821
- s_time_interval = gr.Slider(
822
- minimum=0.1, maximum=5.0, value=1.0, step=0.1,
823
- label="Video Sample Interval (seconds)",
824
- interactive=True,
825
- scale=3,
826
- )
827
- resample_btn = gr.Button("Resample", visible=False, elem_classes="modern-btn secondary", scale=1)
828
-
829
- image_gallery = gr.Gallery(
830
- label="Preview",
831
- columns=4,
832
- height="200px",
833
- show_download_button=True,
834
- object_fit="contain",
835
- preview=True,
836
  )
837
-
838
- with gr.Row():
839
- clear_uploads_btn = gr.ClearButton([unified_upload, image_gallery], value="Clear Media", elem_classes="modern-btn secondary")
840
- submit_btn = gr.Button("Reconstruct", elem_classes="modern-btn")
841
-
842
- log_output = gr.Markdown("Ready to reconstruct.", elem_classes="custom-log")
843
-
844
- with gr.Column(elem_classes="settings-group"):
845
- gr.HTML("<div class='settings-group-title'>Reconstruction Options</div>")
846
- with gr.Column(elem_classes="settings-group-body"):
847
- frame_filter = gr.Dropdown(choices=["All"], value="All", label="Show Points from Frame")
848
- show_cam = gr.Checkbox(label="Show Camera", value=True)
849
- show_mesh = gr.Checkbox(label="Show Mesh", value=True)
850
- filter_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
851
- filter_white_bg = gr.Checkbox(label="Filter White Background", value=False)
852
- apply_mask_checkbox = gr.Checkbox(label="Apply mask for predicted depth & edges", value=True)
853
-
854
- # Right Panel
855
- with gr.Column(elem_classes="app-main-right"):
 
 
 
 
 
 
 
 
 
856
  with gr.Tabs():
857
  with gr.Tab("3D View"):
858
- rerun_output = Rerun(label="Rerun 3D Viewer")
859
-
 
 
 
 
 
 
860
  with gr.Tab("Depth"):
861
- with gr.Row():
862
- prev_depth_btn = gr.Button("Previous", size="sm", elem_classes="modern-btn secondary")
863
- depth_view_selector = gr.Dropdown(choices=["View 1"], value="View 1", label="Select View", interactive=True)
864
- next_depth_btn = gr.Button("Next", size="sm", elem_classes="modern-btn secondary")
865
- depth_map = gr.Image(type="numpy", label="Colorized Depth Map", interactive=False)
866
-
 
 
 
 
 
 
 
 
 
 
 
867
  with gr.Tab("Normal"):
868
- with gr.Row():
869
- prev_normal_btn = gr.Button("Previous", size="sm", elem_classes="modern-btn secondary")
870
- normal_view_selector = gr.Dropdown(choices=["View 1"], value="View 1", label="Select View", interactive=True)
871
- next_normal_btn = gr.Button("Next", size="sm", elem_classes="modern-btn secondary")
872
- normal_map = gr.Image(type="numpy", label="Normal Map", interactive=False)
873
-
 
 
 
 
 
 
 
 
 
 
 
 
 
874
  with gr.Tab("Measure"):
875
- with gr.Row():
876
- prev_measure_btn = gr.Button("Previous", size="sm", elem_classes="modern-btn secondary")
877
- measure_view_selector = gr.Dropdown(choices=["View 1"], value="View 1", label="Select View", interactive=True)
878
- next_measure_btn = gr.Button("Next", size="sm", elem_classes="modern-btn secondary")
879
- measure_image = gr.Image(type="numpy", show_label=False, interactive=False)
880
- measure_text = gr.Markdown("Select points to measure depth and distance.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
881
 
882
  # -------------------------------------------------------------------------
883
- # Event Listeners
 
 
 
 
884
  # -------------------------------------------------------------------------
885
- submit_btn.click(
886
- fn=clear_fields, inputs=[], outputs=[rerun_output]
887
- ).then(
888
  fn=update_log, inputs=[], outputs=[log_output]
889
  ).then(
890
  fn=gradio_demo,
891
  inputs=[
892
- target_dir_output, frame_filter, show_cam, filter_black_bg, filter_white_bg, apply_mask_checkbox, show_mesh,
 
 
 
 
 
 
893
  ],
894
  outputs=[
895
- rerun_output, log_output, frame_filter, processed_data_state,
896
- depth_map, normal_map, measure_image, measure_text,
897
- depth_view_selector, normal_view_selector, measure_view_selector,
 
 
 
 
 
 
 
 
898
  ],
899
- ).then(fn=lambda: "False", inputs=[], outputs=[is_example])
 
 
 
 
900
 
 
 
 
901
  frame_filter.change(
902
- update_visualization, [target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh], [rerun_output, log_output]
 
 
 
 
 
 
 
 
 
 
903
  )
904
  show_cam.change(
905
- update_visualization, [target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh], [rerun_output, log_output]
906
- )
907
- show_mesh.change(
908
- update_visualization, [target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh], [rerun_output, log_output]
 
 
 
 
909
  )
910
-
911
  filter_black_bg.change(
912
- update_visualization, [target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh], [rerun_output, log_output]
 
 
 
 
 
 
 
 
 
913
  ).then(
914
- update_all_views_on_filter_change,
915
- [target_dir_output, filter_black_bg, filter_white_bg, processed_data_state, depth_view_selector, normal_view_selector, measure_view_selector],
916
- [processed_data_state, depth_map, normal_map, measure_image, measure_points_state]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
917
  )
918
-
919
  filter_white_bg.change(
920
- update_visualization, [target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh], [rerun_output, log_output]
 
 
 
 
 
 
 
 
 
 
921
  ).then(
922
- update_all_views_on_filter_change,
923
- [target_dir_output, filter_black_bg, filter_white_bg, processed_data_state, depth_view_selector, normal_view_selector, measure_view_selector],
924
- [processed_data_state, depth_map, normal_map, measure_image, measure_points_state]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
925
  )
926
 
 
 
 
927
  def update_gallery_on_unified_upload(files, interval):
928
- if not files: return None, None, None
 
929
  target_dir, image_paths = handle_uploads(files, interval)
930
- return target_dir, image_paths, "Upload complete. Click 'Reconstruct'."
 
 
 
 
931
 
932
  def show_resample_button(files):
933
- if not files: return gr.update(visible=False)
934
- video_extensions = [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
935
  has_video = False
 
936
  for file_data in files:
937
- file_path = file_data["name"] if isinstance(file_data, dict) else str(file_data)
938
- if os.path.splitext(file_path)[1].lower() in video_extensions:
 
 
 
 
 
939
  has_video = True
940
  break
 
941
  return gr.update(visible=has_video)
942
 
 
 
 
 
943
  def resample_video_with_new_interval(files, new_interval, current_target_dir):
944
- if not files: return current_target_dir, None, "No files.", gr.update(visible=False)
945
- if current_target_dir and os.path.exists(current_target_dir): shutil.rmtree(current_target_dir)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
946
  target_dir, image_paths = handle_uploads(files, new_interval)
947
- return target_dir, image_paths, "Video resampled.", gr.update(visible=False)
 
 
 
 
 
 
948
 
949
  unified_upload.change(
950
- update_gallery_on_unified_upload, [unified_upload, s_time_interval], [target_dir_output, image_gallery, log_output]
951
- ).then(show_resample_button, [unified_upload], [resample_btn])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
952
 
953
- s_time_interval.change(show_resample_button, [unified_upload], [resample_btn])
954
- resample_btn.click(resample_video_with_new_interval, [unified_upload, s_time_interval, target_dir_output], [target_dir_output, image_gallery, log_output, resample_btn])
 
 
 
 
 
 
955
 
956
- measure_image.select(measure, [processed_data_state, measure_points_state, measure_view_selector], [measure_image, measure_points_state, measure_text])
 
 
 
 
 
 
957
 
958
- prev_depth_btn.click(lambda pd, sel: navigate_depth_view(pd, sel, -1), [processed_data_state, depth_view_selector], [depth_view_selector, depth_map])
959
- next_depth_btn.click(lambda pd, sel: navigate_depth_view(pd, sel, 1), [processed_data_state, depth_view_selector], [depth_view_selector, depth_map])
960
- depth_view_selector.change(lambda pd, sel: update_depth_view(pd, int(sel.split()[1])-1) if sel else None, [processed_data_state, depth_view_selector], [depth_map])
 
 
 
 
 
 
 
 
 
961
 
962
- prev_normal_btn.click(lambda pd, sel: navigate_normal_view(pd, sel, -1), [processed_data_state, normal_view_selector], [normal_view_selector, normal_map])
963
- next_normal_btn.click(lambda pd, sel: navigate_normal_view(pd, sel, 1), [processed_data_state, normal_view_selector], [normal_view_selector, normal_map])
964
- normal_view_selector.change(lambda pd, sel: update_normal_view(pd, int(sel.split()[1])-1) if sel else None, [processed_data_state, normal_view_selector], [normal_map])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
965
 
966
- prev_measure_btn.click(lambda pd, sel: navigate_measure_view(pd, sel, -1), [processed_data_state, measure_view_selector], [measure_view_selector, measure_image, measure_points_state])
967
- next_measure_btn.click(lambda pd, sel: navigate_measure_view(pd, sel, 1), [processed_data_state, measure_view_selector], [measure_view_selector, measure_image, measure_points_state])
968
- measure_view_selector.change(lambda pd, sel: update_measure_view(pd, int(sel.split()[1])-1) if sel else (None, []), [processed_data_state, measure_view_selector], [measure_image, measure_points_state])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
969
 
970
- if __name__ == "__main__":
971
- demo.queue(max_size=20).launch(css=modern_css, show_error=True, share=True, ssr_mode=False)
 
 
 
 
 
 
 
 
1
  import gc
2
  import os
3
  import shutil
 
20
  sys.path.append("mapanything/")
21
 
22
  from mapanything.utils.geometry import depthmap_to_world_frame, points_to_normals
23
+ from mapanything.utils.hf_utils.css_and_html import (
24
+ GRADIO_CSS,
25
+ MEASURE_INSTRUCTIONS_HTML,
26
+ get_acknowledgements_html,
27
+ get_gradio_theme,
28
+ )
29
  from mapanything.utils.hf_utils.hf_helpers import initialize_mapanything_model
30
  from mapanything.utils.hf_utils.viz import predictions_to_glb
31
  from mapanything.utils.image import load_images, rgb
32
 
 
 
 
 
 
 
 
 
 
33
 
34
  # MapAnything Configuration
35
  high_level_config = {
 
51
  "resolution": 518,
52
  }
53
 
54
+ # Initialize model - this will be done on GPU when needed
55
  model = None
56
 
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  # -------------------------------------------------------------------------
59
  # 1) Core model inference
60
  # -------------------------------------------------------------------------
 
66
  filter_black_bg=False,
67
  filter_white_bg=False,
68
  ):
69
+ """
70
+ Run the MapAnything model on images in the 'target_dir/images' folder and return predictions.
71
+ """
72
  global model
73
+ import torch # Ensure torch is available in function scope
74
+
75
+ print(f"Processing images from {target_dir}")
76
 
77
+ # Device check
78
  device = "cuda" if torch.cuda.is_available() else "cpu"
79
  device = torch.device(device)
80
 
81
+ # Initialize model if not already done
82
  if model is None:
83
  model = initialize_mapanything_model(high_level_config, device)
84
+
85
  else:
86
  model = model.to(device)
87
 
88
  model.eval()
89
 
90
+ # Load images using MapAnything's load_images function
91
+ print("Loading images...")
92
  image_folder_path = os.path.join(target_dir, "images")
93
  views = load_images(image_folder_path)
94
 
95
+ print(f"Loaded {len(views)} images")
96
  if len(views) == 0:
97
  raise ValueError("No images found. Check your upload.")
98
 
99
+ # Run model inference
100
+ print("Running inference...")
101
+ # apply_mask: Whether to apply the non-ambiguous mask to the output. Defaults to True.
102
+ # mask_edges: Whether to compute an edge mask based on normals and depth and apply it to the output. Defaults to True.
103
+ # Use checkbox values - mask_edges is set to True by default since there's no UI control for it
104
  outputs = model.infer(
105
  views, apply_mask=apply_mask, mask_edges=True, memory_efficient_inference=False
106
  )
107
 
108
+ # Convert predictions to format expected by visualization
109
  predictions = {}
 
 
110
 
111
+ # Initialize lists for the required keys
112
+ extrinsic_list = []
113
+ intrinsic_list = []
114
+ world_points_list = []
115
+ depth_maps_list = []
116
+ images_list = []
117
+ final_mask_list = []
118
+
119
+ # Loop through the outputs
120
  for pred in outputs:
121
+ # Extract data from predictions
122
+ depthmap_torch = pred["depth_z"][0].squeeze(-1) # (H, W)
123
+ intrinsics_torch = pred["intrinsics"][0] # (3, 3)
124
+ camera_pose_torch = pred["camera_poses"][0] # (4, 4)
125
 
126
+ # Compute new pts3d using depth, intrinsics, and camera pose
127
  pts3d_computed, valid_mask = depthmap_to_world_frame(
128
  depthmap_torch, intrinsics_torch, camera_pose_torch
129
  )
130
 
131
+ # Convert to numpy arrays for visualization
132
+ # Check if mask key exists in pred, if not, fill with boolean trues in the size of depthmap_torch
133
  if "mask" in pred:
134
  mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool)
135
  else:
136
+ # Fill with boolean trues in the size of depthmap_torch
137
  mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool)
138
 
139
+ # Combine with valid depth mask
140
  mask = mask & valid_mask.cpu().numpy()
141
+
142
  image = pred["img_no_norm"][0].cpu().numpy()
143
 
144
+ # Append to lists
145
  extrinsic_list.append(camera_pose_torch.cpu().numpy())
146
  intrinsic_list.append(intrinsics_torch.cpu().numpy())
147
  world_points_list.append(pts3d_computed.cpu().numpy())
148
  depth_maps_list.append(depthmap_torch.cpu().numpy())
149
+ images_list.append(image) # Add image to list
150
+ final_mask_list.append(mask) # Add final_mask to list
151
 
152
+ # Convert lists to numpy arrays with required shapes
153
+ # extrinsic: (S, 3, 4) - batch of camera extrinsic matrices
154
  predictions["extrinsic"] = np.stack(extrinsic_list, axis=0)
155
+
156
+ # intrinsic: (S, 3, 3) - batch of camera intrinsic matrices
157
  predictions["intrinsic"] = np.stack(intrinsic_list, axis=0)
158
+
159
+ # world_points: (S, H, W, 3) - batch of 3D world points
160
  predictions["world_points"] = np.stack(world_points_list, axis=0)
161
 
162
+ # depth: (S, H, W, 1) or (S, H, W) - batch of depth maps
163
  depth_maps = np.stack(depth_maps_list, axis=0)
164
+ # Add channel dimension if needed to match (S, H, W, 1) format
165
  if len(depth_maps.shape) == 3:
166
  depth_maps = depth_maps[..., np.newaxis]
167
+
168
  predictions["depth"] = depth_maps
169
+
170
+ # images: (S, H, W, 3) - batch of input images
171
  predictions["images"] = np.stack(images_list, axis=0)
172
+
173
+ # final_mask: (S, H, W) - batch of final masks for filtering
174
  predictions["final_mask"] = np.stack(final_mask_list, axis=0)
175
 
176
+ # Process data for visualization tabs (depth, normal, measure)
177
  processed_data = process_predictions_for_visualization(
178
  predictions, views, high_level_config, filter_black_bg, filter_white_bg
179
  )
180
 
181
+ # Clean up
182
  torch.cuda.empty_cache()
183
+
184
  return predictions, processed_data
185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
  def update_view_selectors(processed_data):
188
+ """Update view selector dropdowns based on available views"""
189
  if processed_data is None or len(processed_data) == 0:
190
  choices = ["View 1"]
191
  else:
 
193
  choices = [f"View {i + 1}" for i in range(num_views)]
194
 
195
  return (
196
+ gr.Dropdown(choices=choices, value=choices[0]), # depth_view_selector
197
+ gr.Dropdown(choices=choices, value=choices[0]), # normal_view_selector
198
+ gr.Dropdown(choices=choices, value=choices[0]), # measure_view_selector
199
  )
200
 
201
+
def get_view_data_by_index(processed_data, view_index):
    """Return the per-view entry at position ``view_index``.

    Args:
        processed_data: Mapping of view key -> per-view data, or ``None``.
        view_index: Zero-based position in the mapping's key order.

    Returns:
        The entry stored at that position. An index outside
        ``[0, len(processed_data))`` falls back to the first entry.
        ``None`` is returned when ``processed_data`` is ``None`` or empty.
    """
    if processed_data is None or len(processed_data) == 0:
        return None

    ordered_keys = list(processed_data.keys())
    # Out-of-range (including negative) indices select the first view.
    in_range = 0 <= view_index < len(ordered_keys)
    chosen_key = ordered_keys[view_index if in_range else 0]
    return processed_data[chosen_key]
212
 
213
+
def update_depth_view(processed_data, view_index):
    """Build the colorized depth image for the view at ``view_index``.

    Returns ``None`` when no view data is available or the selected view
    carries no depth map; otherwise returns whatever ``colorize_depth``
    produces for that view's depth and (optional) mask.
    """
    selected = get_view_data_by_index(processed_data, view_index)
    has_depth = selected is not None and selected["depth"] is not None
    if has_depth:
        return colorize_depth(selected["depth"], mask=selected.get("mask"))
    return None
221
 
222
+
def update_normal_view(processed_data, view_index):
    """Build the colorized normal image for the view at ``view_index``.

    Returns ``None`` when no view data is available or the selected view
    carries no normal map; otherwise returns whatever ``colorize_normal``
    produces for that view's normals and (optional) mask.
    """
    selected = get_view_data_by_index(processed_data, view_index)
    has_normal = selected is not None and selected["normal"] is not None
    if has_normal:
        return colorize_normal(selected["normal"], mask=selected.get("mask"))
    return None
230
 
231
+
def update_measure_view(processed_data, view_index):
    """Prepare the measure-tab image for the view at ``view_index``.

    The view's image is copied, converted to ``uint8`` if needed, and every
    pixel whose mask value is ``False`` is blended 50/50 with a light pink
    tint so invalid regions are visible to the user.

    Args:
        processed_data: Mapping of view key -> per-view data, or ``None``.
        view_index: Zero-based index of the view to show.

    Returns:
        ``(image, measure_points)``: the ``uint8`` image (or ``None`` when no
        view data exists) and an empty list, which resets the measure points.
    """
    view_data = get_view_data_by_index(processed_data, view_index)
    if view_data is None:
        return None, []  # image, measure_points

    # Work on a copy so the cached view data is never modified.
    image = view_data["image"].copy()

    # Ensure uint8: values with max <= 1.0 are treated as normalized floats.
    if image.dtype != np.uint8:
        if image.max() <= 1.0:
            image = (image * 255).astype(np.uint8)
        else:
            image = image.astype(np.uint8)

    # A missing "mask" key is treated like an explicit None, matching how
    # the depth and normal views read the mask.
    mask = view_data.get("mask")
    if mask is not None:
        invalid_mask = ~mask  # True where the prediction was masked out
        if invalid_mask.any():
            # Light pink tint (RGB: 255, 220, 220), blended at 50% opacity.
            overlay_color = np.array([255, 220, 220], dtype=np.uint8)
            alpha = 0.5
            # Only the first three channels are tinted; `rgb` is a view, so
            # the assignment below writes straight into `image`.
            rgb = image[:, :, :3]
            blended = (1 - alpha) * rgb[invalid_mask] + alpha * overlay_color
            rgb[invalid_mask] = blended.astype(np.uint8)

    return image, []
270
 
271
+
def navigate_depth_view(processed_data, current_selector_value, direction):
    """Step the depth tab to the previous or next view, wrapping around.

    Args:
        processed_data: Mapping of view key -> per-view data, or ``None``.
        current_selector_value: Current dropdown label, e.g. ``"View 3"``.
        direction: ``-1`` for previous, ``+1`` for next.

    Returns:
        ``(new_selector_value, depth_vis)``. When there is no processed data
        the result is ``("View 1", None)``.
    """
    if processed_data is None or len(processed_data) == 0:
        return "View 1", None

    # Parse the 1-based view number out of the label. Only the failures a
    # malformed or missing label can cause are caught; anything else
    # (e.g. KeyboardInterrupt) must propagate.
    try:
        current_view = int(current_selector_value.split()[1]) - 1
    except (AttributeError, IndexError, ValueError):
        current_view = 0

    num_views = len(processed_data)
    new_view = (current_view + direction) % num_views  # wrap at both ends

    new_selector_value = f"View {new_view + 1}"
    depth_vis = update_depth_view(processed_data, new_view)

    return new_selector_value, depth_vis
290
 
291
+
def navigate_normal_view(processed_data, current_selector_value, direction):
    """Step the normal tab to the previous or next view, wrapping around.

    Args:
        processed_data: Mapping of view key -> per-view data, or ``None``.
        current_selector_value: Current dropdown label, e.g. ``"View 3"``.
        direction: ``-1`` for previous, ``+1`` for next.

    Returns:
        ``(new_selector_value, normal_vis)``. When there is no processed data
        the result is ``("View 1", None)``.
    """
    if processed_data is None or len(processed_data) == 0:
        return "View 1", None

    # Parse the 1-based view number out of the label. Only the failures a
    # malformed or missing label can cause are caught; anything else
    # (e.g. KeyboardInterrupt) must propagate.
    try:
        current_view = int(current_selector_value.split()[1]) - 1
    except (AttributeError, IndexError, ValueError):
        current_view = 0

    num_views = len(processed_data)
    new_view = (current_view + direction) % num_views  # wrap at both ends

    new_selector_value = f"View {new_view + 1}"
    normal_vis = update_normal_view(processed_data, new_view)

    return new_selector_value, normal_vis
310
 
311
+
def navigate_measure_view(processed_data, current_selector_value, direction):
    """Step the measure tab to the previous or next view, wrapping around.

    Args:
        processed_data: Mapping of view key -> per-view data, or ``None``.
        current_selector_value: Current dropdown label, e.g. ``"View 3"``.
        direction: ``-1`` for previous, ``+1`` for next.

    Returns:
        ``(new_selector_value, measure_image, measure_points)``. When there
        is no processed data the result is ``("View 1", None, [])``.
    """
    if processed_data is None or len(processed_data) == 0:
        return "View 1", None, []

    # Parse the 1-based view number out of the label. Only the failures a
    # malformed or missing label can cause are caught; anything else
    # (e.g. KeyboardInterrupt) must propagate.
    try:
        current_view = int(current_selector_value.split()[1]) - 1
    except (AttributeError, IndexError, ValueError):
        current_view = 0

    num_views = len(processed_data)
    new_view = (current_view + direction) % num_views  # wrap at both ends

    new_selector_value = f"View {new_view + 1}"
    measure_image, measure_points = update_measure_view(processed_data, new_view)

    return new_selector_value, measure_image, measure_points
330
 
331
+
def populate_visualization_tabs(processed_data):
    """Produce the initial depth, normal and measure images (first view).

    Returns:
        ``(depth_vis, normal_vis, measure_img, measure_points)``, where
        ``measure_points`` is always an empty list. All three images are
        ``None`` when ``processed_data`` is ``None`` or empty.
    """
    nothing_to_show = processed_data is None or len(processed_data) == 0
    if nothing_to_show:
        return None, None, None, []

    # Go through the per-view update helpers so the first view is rendered
    # exactly the same way as any later view selection.
    first_view = 0
    measure_img, _ = update_measure_view(processed_data, first_view)
    return (
        update_depth_view(processed_data, first_view),
        update_normal_view(processed_data, first_view),
        measure_img,
        [],
    )
343
 
344
+
345
  # -------------------------------------------------------------------------
346
+ # 2) Handle uploaded video/images --> produce target_dir + images
347
  # -------------------------------------------------------------------------
348
  def handle_uploads(unified_upload, s_time_interval=1.0):
349
+ """
350
+ Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
351
+ images or extracted frames from video into it. Return (target_dir, image_paths).
352
+ """
353
  start_time = time.time()
354
  gc.collect()
355
  torch.cuda.empty_cache()
356
 
357
+ # Create a unique folder name
358
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
359
  target_dir = f"input_images_{timestamp}"
360
  target_dir_images = os.path.join(target_dir, "images")
361
 
362
+ # Clean up if somehow that folder already exists
363
  if os.path.exists(target_dir):
364
  shutil.rmtree(target_dir)
365
  os.makedirs(target_dir)
 
367
 
368
  image_paths = []
369
 
370
+ # --- Handle uploaded files (both images and videos) ---
371
  if unified_upload is not None:
372
  for file_data in unified_upload:
373
  if isinstance(file_data, dict) and "name" in file_data:
 
376
  file_path = str(file_data)
377
 
378
  file_ext = os.path.splitext(file_path)[1].lower()
379
+
380
+ # Check if it's a video file
381
+ video_extensions = [
382
+ ".mp4",
383
+ ".avi",
384
+ ".mov",
385
+ ".mkv",
386
+ ".wmv",
387
+ ".flv",
388
+ ".webm",
389
+ ".m4v",
390
+ ".3gp",
391
+ ]
392
  if file_ext in video_extensions:
393
+ # Handle as video
394
  vs = cv2.VideoCapture(file_path)
395
  fps = vs.get(cv2.CAP_PROP_FPS)
396
+ frame_interval = int(fps * s_time_interval) # frames per interval
397
 
398
  count = 0
399
  video_frame_num = 0
 
403
  break
404
  count += 1
405
  if count % frame_interval == 0:
406
+ # Use original filename as prefix for frames
407
  base_name = os.path.splitext(os.path.basename(file_path))[0]
408
+ image_path = os.path.join(
409
+ target_dir_images, f"{base_name}_{video_frame_num:06}.png"
410
+ )
411
  cv2.imwrite(image_path, frame)
412
  image_paths.append(image_path)
413
  video_frame_num += 1
414
  vs.release()
415
+ print(
416
+ f"Extracted {video_frame_num} frames from video: {os.path.basename(file_path)}"
417
+ )
418
+
419
  else:
420
+ # Handle as image
421
+ # Check if the file is a HEIC image
422
  if file_ext in [".heic", ".heif"]:
423
+ # Convert HEIC to JPEG for better gallery compatibility
424
  try:
425
  with Image.open(file_path) as img:
426
+ # Convert to RGB if necessary (HEIC can have different color modes)
427
  if img.mode not in ("RGB", "L"):
428
  img = img.convert("RGB")
429
+
430
+ # Create JPEG filename
431
  base_name = os.path.splitext(os.path.basename(file_path))[0]
432
+ dst_path = os.path.join(
433
+ target_dir_images, f"{base_name}.jpg"
434
+ )
435
+
436
+ # Save as JPEG with high quality
437
  img.save(dst_path, "JPEG", quality=95)
438
  image_paths.append(dst_path)
439
+ print(
440
+ f"Converted HEIC to JPEG: {os.path.basename(file_path)} -> {os.path.basename(dst_path)}"
441
+ )
442
  except Exception as e:
443
+ print(f"Error converting HEIC file {file_path}: {e}")
444
+ # Fall back to copying as is
445
+ dst_path = os.path.join(
446
+ target_dir_images, os.path.basename(file_path)
447
+ )
448
  shutil.copy(file_path, dst_path)
449
  image_paths.append(dst_path)
450
  else:
451
+ # Regular image files - copy as is
452
+ dst_path = os.path.join(
453
+ target_dir_images, os.path.basename(file_path)
454
+ )
455
  shutil.copy(file_path, dst_path)
456
  image_paths.append(dst_path)
457
 
458
+ # Sort final images for gallery
459
  image_paths = sorted(image_paths)
460
+
461
+ end_time = time.time()
462
+ print(
463
+ f"Files processed to {target_dir_images}; took {end_time - start_time:.3f} seconds"
464
+ )
465
  return target_dir, image_paths
466
 
467
+
468
+ # -------------------------------------------------------------------------
469
+ # 3) Update gallery on upload
470
+ # -------------------------------------------------------------------------
def update_gallery_on_upload(input_video, input_images, s_time_interval=1.0):
    """Process newly uploaded files and return values for the gallery outputs.

    ``handle_uploads`` takes a single list of uploads followed by the sampling
    interval, so the video and image inputs are merged into one list before
    the call (passing them as separate positional arguments raised
    ``TypeError``).

    Args:
        input_video: Uploaded video file or list of video files (may be
            ``None`` or empty).
        input_images: Uploaded image file or list of image files (may be
            ``None`` or empty).
        s_time_interval: Sampling interval in seconds for video frames.

    Returns:
        ``(None, target_dir, image_paths, status_message)``, or four ``None``
        values when nothing was uploaded.
    """
    if not input_video and not input_images:
        return None, None, None, None

    # Flatten both inputs into the single upload list handle_uploads expects.
    uploads = []
    for group in (input_video, input_images):
        if not group:
            continue
        if isinstance(group, (list, tuple)):
            uploads.extend(group)
        else:
            uploads.append(group)

    target_dir, image_paths = handle_uploads(uploads, s_time_interval)
    return (
        None,
        target_dir,
        image_paths,
        "Upload complete. Click 'Reconstruct' to begin 3D processing.",
    )
486
+
487
 
488
  # -------------------------------------------------------------------------
489
+ # 4) Reconstruction: uses the target_dir plus any viz parameters
490
  # -------------------------------------------------------------------------
491
  @spaces.GPU(duration=120)
492
  def gradio_demo(
 
498
  apply_mask=True,
499
  show_mesh=True,
500
  ):
501
+ """
502
+ Perform reconstruction using the already-created target_dir/images.
503
+ """
504
  if not os.path.isdir(target_dir) or target_dir == "None":
505
  return None, "No valid target directory found. Please upload first.", None, None
506
 
507
+ start_time = time.time()
508
  gc.collect()
509
  torch.cuda.empty_cache()
510
 
511
+ # Prepare frame_filter dropdown
512
  target_dir_images = os.path.join(target_dir, "images")
513
+ all_files = (
514
+ sorted(os.listdir(target_dir_images))
515
+ if os.path.isdir(target_dir_images)
516
+ else []
517
+ )
518
  all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
519
  frame_filter_choices = ["All"] + all_files
520
 
521
+ print("Running MapAnything model...")
522
  with torch.no_grad():
523
  predictions, processed_data = run_model(target_dir, apply_mask)
524
 
525
+ # Save predictions
526
  prediction_save_path = os.path.join(target_dir, "predictions.npz")
527
  np.savez(prediction_save_path, **predictions)
528
 
529
+ # Handle None frame_filter
530
  if frame_filter is None:
531
  frame_filter = "All"
532
 
533
+ # Build a GLB file name
534
  glbfile = os.path.join(
535
  target_dir,
536
  f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
537
  )
538
 
539
+ # Convert predictions to GLB
540
  glbscene = predictions_to_glb(
541
  predictions,
542
  filter_by_frames=frame_filter,
543
  show_cam=show_cam,
544
  mask_black_bg=filter_black_bg,
545
  mask_white_bg=filter_white_bg,
546
+ as_mesh=show_mesh, # Use the show_mesh parameter
547
  )
548
  glbscene.export(file_obj=glbfile)
 
 
 
549
 
550
+ # Cleanup
551
  del predictions
552
  gc.collect()
553
  torch.cuda.empty_cache()
554
 
555
+ end_time = time.time()
556
+ print(f"Total time: {end_time - start_time:.2f} seconds")
557
+ log_msg = (
558
+ f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
559
+ )
560
+
561
+ # Populate visualization tabs with processed data
562
+ depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(
563
+ processed_data
564
+ )
565
 
566
+ # Update view selectors based on available views
567
+ depth_selector, normal_selector, measure_selector = update_view_selectors(
568
+ processed_data
569
+ )
570
 
571
  return (
572
+ glbfile,
573
  log_msg,
574
  gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
575
  processed_data,
576
  depth_vis,
577
  normal_vis,
578
  measure_img,
579
+ "", # measure_text (empty initially)
580
  depth_selector,
581
  normal_selector,
582
  measure_selector,
583
  )
584
 
585
+
586
+ # -------------------------------------------------------------------------
587
+ # 5) Helper functions for UI resets + re-visualization
588
+ # -------------------------------------------------------------------------
589
def colorize_depth(depth_map, mask=None):
    """Convert a depth map into a colorized RGB visualization.

    Pixels are considered valid when their depth is > 0 and, if ``mask`` is
    given, the mask is True there. Valid depths are normalized between their
    5th and 95th percentiles and mapped through matplotlib's reversed "turbo"
    colormap; every invalid pixel is painted white.

    Args:
        depth_map: 2D numeric array of depths, or None.
        mask: Optional boolean array with the same shape as ``depth_map``;
            False marks pixels to exclude (e.g. filtered background).

    Returns:
        An ``(H, W, 3)`` uint8 RGB array, or None when ``depth_map`` is None.
    """
    if depth_map is None:
        return None

    # Work on a copy so the caller's array is never modified.
    depth_normalized = depth_map.copy()
    valid_mask = depth_normalized > 0

    # Apply additional mask if provided (for background filtering)
    if mask is not None:
        valid_mask = valid_mask & mask

    if valid_mask.sum() > 0:
        valid_depths = depth_normalized[valid_mask]
        p5 = np.percentile(valid_depths, 5)
        p95 = np.percentile(valid_depths, 95)
        depth_range = p95 - p5

        if depth_range > 0:
            depth_normalized[valid_mask] = (valid_depths - p5) / depth_range
        else:
            # Flat depth (p5 == p95): dividing would yield NaN/inf, which the
            # colormap renders as its "bad" color. Use the colormap midpoint.
            depth_normalized[valid_mask] = 0.5

    # Imported lazily so the None fast path does not need matplotlib.
    import matplotlib.pyplot as plt

    colormap = plt.cm.turbo_r
    colored = colormap(depth_normalized)
    # Drop the alpha channel and convert [0, 1] floats to 8-bit RGB.
    colored = (colored[:, :, :3] * 255).astype(np.uint8)

    # Set invalid pixels to white
    colored[~valid_mask] = [255, 255, 255]

    return colored
620
 
621
+
622
def colorize_normal(normal_map, mask=None):
    """Convert a normal map into an 8-bit RGB visualization.

    Each component is mapped linearly from [-1, 1] to [0, 255]. Masked-out
    pixels are zeroed first, so they come out as mid-grey (127, 127, 127).

    Args:
        normal_map: Array of normals whose last axis has 3 components, or None.
        mask: Optional boolean array over the pixel axes; False marks pixels
            to blank out.

    Returns:
        A uint8 array with the same shape as ``normal_map``, or None when
        ``normal_map`` is None.
    """
    if normal_map is None:
        return None

    # Work on a copy so the caller's array is never modified.
    normal_vis = normal_map.copy()

    # Zeroed normals become grey after the [-1, 1] -> [0, 1] remap below.
    if mask is not None:
        invalid_mask = ~mask
        normal_vis[invalid_mask] = [0, 0, 0]

    # NaNs (e.g. from degenerate geometry) are treated like masked pixels.
    normal_vis = np.nan_to_num(normal_vis, nan=0.0)

    # Clip after remapping: components slightly outside [-1, 1] would otherwise
    # wrap around when cast to uint8 (e.g. 1.01 -> 256 -> 0).
    normal_vis = np.clip((normal_vis + 1.0) / 2.0, 0.0, 1.0)
    normal_vis = (normal_vis * 255).astype(np.uint8)

    return normal_vis
640
 
641
+
642
def process_predictions_for_visualization(
    predictions, views, high_level_config, filter_black_bg=False, filter_white_bg=False
):
    """Collect per-view image, 3D points, depth, normals and mask for the UI tabs.

    Args:
        predictions: Mapping with "world_points", "final_mask" and "depth"
            entries, each indexable by view index.
        views: Sequence of view dicts; each provides an "img" entry.
        high_level_config: Mapping providing "data_norm_type" for ``rgb``.
        filter_black_bg: Also mask out pixels whose RGB sum is below 16.
        filter_white_bg: Also mask out pixels whose R, G and B all exceed 240.

    Returns:
        Dict keyed by view index; each value is a dict with the keys
        "image", "points3d", "depth", "normal" and "mask".
    """
    results = {}

    for idx, view in enumerate(views):
        frame = rgb(view["img"], norm_type=high_level_config["data_norm_type"])[0]
        world_pts = predictions["world_points"][idx]

        # Start from the model's own validity mask; copy so the filters below
        # never write back into ``predictions``.
        keep = predictions["final_mask"][idx].copy()

        if filter_black_bg:
            # Colors are compared on a 0-255 scale.
            colors = frame * 255 if frame.max() <= 1.0 else frame
            keep = keep & (colors.sum(axis=2) >= 16)

        if filter_white_bg:
            colors = frame * 255 if frame.max() <= 1.0 else frame
            near_white = (
                (colors[:, :, 0] > 240)
                & (colors[:, :, 1] > 240)
                & (colors[:, :, 2] > 240)
            )
            keep = keep & ~near_white

        depth = predictions["depth"][idx].squeeze()
        normals, _ = points_to_normals(world_pts, mask=keep)

        results[idx] = {
            "image": frame,
            "points3d": world_pts,
            "depth": depth,
            "normal": normals,
            "mask": keep,
        }

    return results
697
 
698
+
699
def reset_measure(processed_data):
    """Discard any picked measurement points.

    Returns:
        A ``(image, points, text)`` triple: the first view's stored image (or
        None when there is no processed data), an empty point list and an
        empty message.
    """
    has_views = processed_data is not None and len(processed_data) > 0
    if not has_views:
        return None, [], ""

    # Views are stored in insertion order; show the first one.
    first_view = next(iter(processed_data.values()))
    return first_view["image"], [], ""
707
+
708
+
709
+ def measure(
710
+ processed_data, measure_points, current_view_selector, event: gr.SelectData
711
+ ):
712
+ """Handle measurement on images"""
713
  try:
714
+ print(f"Measure function called with selector: {current_view_selector}")
715
+
716
  if processed_data is None or len(processed_data) == 0:
717
  return None, [], "No data available"
718
+
719
+ # Use the currently selected view instead of always using the first view
720
  try:
721
  current_view_index = int(current_view_selector.split()[1]) - 1
722
  except:
723
  current_view_index = 0
724
 
725
+ print(f"Using view index: {current_view_index}")
726
+
727
+ # Get view data safely
728
  if current_view_index < 0 or current_view_index >= len(processed_data):
729
  current_view_index = 0
730
 
 
735
  return None, [], "No view data available"
736
 
737
  point2d = event.index[0], event.index[1]
738
+ print(f"Clicked point: {point2d}")
739
+
740
+ # Check if the clicked point is in a masked area (prevent interaction)
741
+ if (
742
+ current_view["mask"] is not None
743
+ and 0 <= point2d[1] < current_view["mask"].shape[0]
744
+ and 0 <= point2d[0] < current_view["mask"].shape[1]
745
+ ):
746
+ # Check if the point is in a masked (invalid) area
747
  if not current_view["mask"][point2d[1], point2d[0]]:
748
+ print(f"Clicked point {point2d} is in masked area, ignoring click")
749
+ # Always return image with mask overlay
750
+ masked_image, _ = update_measure_view(
751
+ processed_data, current_view_index
752
+ )
753
+ return (
754
+ masked_image,
755
+ measure_points,
756
+ '<span style="color: red; font-weight: bold;">Cannot measure on masked areas (shown in grey)</span>',
757
+ )
758
 
759
  measure_points.append(point2d)
760
+
761
+ # Get image with mask overlay and ensure it's valid
762
  image, _ = update_measure_view(processed_data, current_view_index)
763
+ if image is None:
764
+ return None, [], "No image available"
765
+
766
  image = image.copy()
767
  points3d = current_view["points3d"]
768
 
769
+ # Ensure image is in uint8 format for proper cv2 operations
770
+ try:
771
+ if image.dtype != np.uint8:
772
+ if image.max() <= 1.0:
773
+ # Image is in [0, 1] range, convert to [0, 255]
774
+ image = (image * 255).astype(np.uint8)
775
+ else:
776
+ # Image is already in [0, 255] range
777
+ image = image.astype(np.uint8)
778
+ except Exception as e:
779
+ print(f"Image conversion error: {e}")
780
+ return None, [], f"Image conversion error: {e}"
781
 
782
+ # Draw circles for points
783
+ try:
784
+ for p in measure_points:
785
+ if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
786
+ image = cv2.circle(
787
+ image, p, radius=5, color=(255, 0, 0), thickness=2
788
+ )
789
+ except Exception as e:
790
+ print(f"Drawing error: {e}")
791
+ return None, [], f"Drawing error: {e}"
792
 
793
  depth_text = ""
794
+ try:
795
+ for i, p in enumerate(measure_points):
796
+ if (
797
+ current_view["depth"] is not None
798
+ and 0 <= p[1] < current_view["depth"].shape[0]
799
+ and 0 <= p[0] < current_view["depth"].shape[1]
800
+ ):
801
+ d = current_view["depth"][p[1], p[0]]
802
+ depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n"
803
+ else:
804
+ # Use Z coordinate of 3D points if depth not available
805
+ if (
806
+ points3d is not None
807
+ and 0 <= p[1] < points3d.shape[0]
808
+ and 0 <= p[0] < points3d.shape[1]
809
+ ):
810
+ z = points3d[p[1], p[0], 2]
811
+ depth_text += f"- **P{i + 1} Z-coord: {z:.2f}m.**\n"
812
+ except Exception as e:
813
+ print(f"Depth text error: {e}")
814
+ depth_text = f"Error computing depth: {e}\n"
815
 
816
  if len(measure_points) == 2:
817
+ try:
818
+ point1, point2 = measure_points
819
+ # Draw line
820
+ if (
821
+ 0 <= point1[0] < image.shape[1]
822
+ and 0 <= point1[1] < image.shape[0]
823
+ and 0 <= point2[0] < image.shape[1]
824
+ and 0 <= point2[1] < image.shape[0]
825
+ ):
826
+ image = cv2.line(
827
+ image, point1, point2, color=(255, 0, 0), thickness=2
828
+ )
829
+
830
+ # Compute 3D distance
831
+ distance_text = "- **Distance: Unable to compute**"
832
+ if (
833
+ points3d is not None
834
+ and 0 <= point1[1] < points3d.shape[0]
835
+ and 0 <= point1[0] < points3d.shape[1]
836
+ and 0 <= point2[1] < points3d.shape[0]
837
+ and 0 <= point2[0] < points3d.shape[1]
838
+ ):
839
+ try:
840
+ p1_3d = points3d[point1[1], point1[0]]
841
+ p2_3d = points3d[point2[1], point2[0]]
842
+ distance = np.linalg.norm(p1_3d - p2_3d)
843
+ distance_text = f"- **Distance: {distance:.2f}m**"
844
+ except Exception as e:
845
+ print(f"Distance computation error: {e}")
846
+ distance_text = f"- **Distance computation error: {e}**"
847
+
848
+ measure_points = []
849
+ text = depth_text + distance_text
850
+ print(f"Measurement complete: {text}")
851
+ return [image, measure_points, text]
852
+ except Exception as e:
853
+ print(f"Final measurement error: {e}")
854
+ return None, [], f"Measurement error: {e}"
855
  else:
856
+ print(f"Single point measurement: {depth_text}")
857
  return [image, measure_points, depth_text]
858
 
859
  except Exception as e:
860
+ print(f"Overall measure function error: {e}")
861
  return None, [], f"Measure function error: {e}"
862
 
863
+
864
def clear_fields():
    """Return None, the value that blanks a single output component.

    The Reconstruct handler wires this to the 3D viewer only, so the previous
    scene disappears while a new reconstruction is running.
    """
    return None
869
 
870
+
871
def update_log():
    """Return the progress message shown while a reconstruction is running."""
    progress_message = "Loading and Reconstructing..."
    return progress_message
876
 
877
+
878
  def update_visualization(
879
  target_dir,
880
  frame_filter,
 
884
  filter_white_bg=False,
885
  show_mesh=True,
886
  ):
887
+ """
888
+ Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
889
+ and return it for the 3D viewer. If is_example == "True", skip.
890
+ """
891
+
892
+ # If it's an example click, skip as requested
893
  if is_example == "True":
894
+ return (
895
+ gr.update(),
896
+ "No reconstruction available. Please click the Reconstruct button first.",
897
+ )
898
 
899
  if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
900
+ return (
901
+ gr.update(),
902
+ "No reconstruction available. Please click the Reconstruct button first.",
903
+ )
904
 
905
  predictions_path = os.path.join(target_dir, "predictions.npz")
906
  if not os.path.exists(predictions_path):
907
+ return (
908
+ gr.update(),
909
+ f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
910
+ )
911
 
912
  loaded = np.load(predictions_path, allow_pickle=True)
913
  predictions = {key: loaded[key] for key in loaded.keys()}
 
919
 
920
  if not os.path.exists(glbfile):
921
  glbscene = predictions_to_glb(
922
+ predictions,
923
+ filter_by_frames=frame_filter,
924
+ show_cam=show_cam,
925
+ mask_black_bg=filter_black_bg,
926
+ mask_white_bg=filter_white_bg,
927
+ as_mesh=show_mesh,
928
  )
929
  glbscene.export(file_obj=glbfile)
930
 
931
+ return (
932
+ glbfile,
933
+ "Visualization updated.",
934
+ )
935
+
936
 
937
  def update_all_views_on_filter_change(
938
+ target_dir,
939
+ filter_black_bg,
940
+ filter_white_bg,
941
+ processed_data,
942
+ depth_view_selector,
943
+ normal_view_selector,
944
+ measure_view_selector,
945
  ):
946
+ """
947
+ Update all individual view tabs when background filtering checkboxes change.
948
+ This regenerates the processed data with new filtering and updates all views.
949
+ """
950
+ # Check if we have a valid target directory and predictions
951
  if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
952
  return processed_data, None, None, None, []
953
 
 
956
  return processed_data, None, None, None, []
957
 
958
  try:
959
+ # Load the original predictions and views
960
  loaded = np.load(predictions_path, allow_pickle=True)
961
  predictions = {key: loaded[key] for key in loaded.keys()}
 
962
 
963
+ # Load images using MapAnything's load_images function
964
+ image_folder_path = os.path.join(target_dir, "images")
965
+ views = load_images(image_folder_path)
966
+
967
+ # Regenerate processed data with new filtering settings
968
  new_processed_data = process_predictions_for_visualization(
969
  predictions, views, high_level_config, filter_black_bg, filter_white_bg
970
  )
971
 
972
+ # Get current view indices
973
+ try:
974
+ depth_view_idx = (
975
+ int(depth_view_selector.split()[1]) - 1 if depth_view_selector else 0
976
+ )
977
+ except:
978
+ depth_view_idx = 0
979
+
980
+ try:
981
+ normal_view_idx = (
982
+ int(normal_view_selector.split()[1]) - 1 if normal_view_selector else 0
983
+ )
984
+ except:
985
+ normal_view_idx = 0
986
+
987
+ try:
988
+ measure_view_idx = (
989
+ int(measure_view_selector.split()[1]) - 1
990
+ if measure_view_selector
991
+ else 0
992
+ )
993
+ except:
994
+ measure_view_idx = 0
995
 
996
+ # Update all views with new filtered data
997
  depth_vis = update_depth_view(new_processed_data, depth_view_idx)
998
  normal_vis = update_normal_view(new_processed_data, normal_view_idx)
999
  measure_img, _ = update_measure_view(new_processed_data, measure_view_idx)
1000
 
1001
  return new_processed_data, depth_vis, normal_vis, measure_img, []
1002
+
1003
  except Exception as e:
1004
+ print(f"Error updating views on filter change: {e}")
1005
  return processed_data, None, None, None, []
1006
 
1007
+
1008
+ # -------------------------------------------------------------------------
1009
+ # Example scene functions
1010
+ # -------------------------------------------------------------------------
1011
def get_scene_info(examples_dir):
    """List the example scenes found under ``examples_dir``.

    Every immediate sub-directory that contains at least one image file is a
    scene. Image files are matched by extension, case-insensitively, and
    dotfiles are ignored.

    Args:
        examples_dir: Directory whose sub-folders are the example scenes.

    Returns:
        A list of dicts sorted by folder name, each with the keys "name",
        "path", "thumbnail" (first image in sorted order), "num_images" and
        "image_files" (sorted paths). Empty when ``examples_dir`` is missing.
    """
    image_suffixes = (".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif")

    scenes = []
    if not os.path.isdir(examples_dir):
        return scenes

    for scene_folder in sorted(os.listdir(examples_dir)):
        scene_path = os.path.join(examples_dir, scene_folder)
        if not os.path.isdir(scene_path):
            continue

        # A single directory listing with a case-insensitive suffix test lists
        # each file exactly once (globbing "*.jpg" and "*.JPG" returns every
        # file twice on case-insensitive filesystems), also accepts mixed-case
        # extensions, and is not confused by glob metacharacters in folder names.
        image_files = sorted(
            os.path.join(scene_path, file_name)
            for file_name in os.listdir(scene_path)
            if not file_name.startswith(".")
            and file_name.lower().endswith(image_suffixes)
            and os.path.isfile(os.path.join(scene_path, file_name))
        )
        if not image_files:
            continue

        scenes.append(
            {
                "name": scene_folder,
                "path": scene_path,
                "thumbnail": image_files[0],
                "num_images": len(image_files),
                "image_files": image_files,
            }
        )

    return scenes
1046
+
1047
+
1048
def load_example_scene(scene_name, examples_dir="examples"):
    """Stage the images of a named example scene for reconstruction.

    Args:
        scene_name: Folder name of the scene, as reported by ``get_scene_info``.
        examples_dir: Directory that holds the example scene folders.

    Returns:
        A 4-tuple ``(reconstruction_output, target_dir, image_paths, message)``.
        The first item is always None so the 3D viewer is cleared. When the
        scene is unknown, the middle two items are None as well.
    """
    match = next(
        (
            candidate
            for candidate in get_scene_info(examples_dir)
            if candidate["name"] == scene_name
        ),
        None,
    )
    if match is None:
        return None, None, None, "Scene not found"

    # The upload handler receives the plain image paths; 1.0 is passed as its
    # second argument.
    source_paths = list(match["image_files"])
    target_dir, image_paths = handle_uploads(source_paths, 1.0)

    status = f"Loaded scene '{scene_name}' with {match['num_images']} images. Click 'Reconstruct' to begin 3D processing."
    return None, target_dir, image_paths, status
1077
+
1078
+
1079
  # -------------------------------------------------------------------------
1080
+ # 6) Build Gradio UI
1081
  # -------------------------------------------------------------------------
1082
theme = get_gradio_theme()

with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
    # NOTE(review): the nesting of the layout containers below was reconstructed
    # from a de-indented source; confirm it against the rendered page.

    # Hidden components that carry state between event handlers.
    is_example = gr.Textbox(label="is_example", visible=False, value="None")
    num_images = gr.Textbox(label="num_images", visible=False, value="None")
    processed_data_state = gr.State(value=None)
    measure_points_state = gr.State(value=[])
    current_view_index = gr.State(value=0)  # Track current view index for navigation

    target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")

    with gr.Row():
        with gr.Column(scale=2):
            # Unified upload component for both videos and images
            unified_upload = gr.File(
                file_count="multiple",
                label="Upload Video or Images",
                interactive=True,
                file_types=["image", "video"],
            )
            with gr.Row():
                s_time_interval = gr.Slider(
                    minimum=0.1,
                    maximum=5.0,
                    value=1.0,
                    step=0.1,
                    label="Video sample time interval (take a sample every x sec.)",
                    interactive=True,
                    visible=True,
                    scale=3,
                )
                # Hidden until an uploaded file is recognized as a video.
                resample_btn = gr.Button(
                    "Resample Video",
                    visible=False,
                    variant="secondary",
                    scale=1,
                )

            image_gallery = gr.Gallery(
                label="Preview",
                columns=4,
                height="300px",
                show_download_button=True,
                object_fit="contain",
                preview=True,
            )

            clear_uploads_btn = gr.ClearButton(
                [unified_upload, image_gallery],
                value="Clear Uploads",
                variant="secondary",
                size="sm",
            )

        with gr.Column(scale=4):
            with gr.Column():
                gr.Markdown(
                    "**Metric 3D Reconstruction (Point Cloud and Camera Poses)**"
                )
                log_output = gr.Markdown(
                    "Please upload a video or images, then click Reconstruct.",
                    elem_classes=["custom-log"],
                )

                # Add tabbed interface similar to MoGe
                with gr.Tabs():
                    with gr.Tab("3D View"):
                        reconstruction_output = gr.Model3D(
                            height=520,
                            zoom_speed=0.5,
                            pan_speed=0.5,
                            clear_color=[0.0, 0.0, 0.0, 0.0],
                            key="persistent_3d_viewer",
                            elem_id="reconstruction_3d_viewer",
                        )
                    with gr.Tab("Depth"):
                        with gr.Row(elem_classes=["navigation-row"]):
                            # Label made consistent with the Measure tab.
                            prev_depth_btn = gr.Button("◀ Previous", size="sm", scale=1)
                            depth_view_selector = gr.Dropdown(
                                choices=["View 1"],
                                value="View 1",
                                label="Select View",
                                scale=2,
                                interactive=True,
                                allow_custom_value=True,
                            )
                            next_depth_btn = gr.Button("Next ▶", size="sm", scale=1)
                        depth_map = gr.Image(
                            type="numpy",
                            label="Colorized Depth Map",
                            format="png",
                            interactive=False,
                        )
                    with gr.Tab("Normal"):
                        with gr.Row(elem_classes=["navigation-row"]):
                            # Label made consistent with the Measure tab.
                            prev_normal_btn = gr.Button(
                                "◀ Previous", size="sm", scale=1
                            )
                            normal_view_selector = gr.Dropdown(
                                choices=["View 1"],
                                value="View 1",
                                label="Select View",
                                scale=2,
                                interactive=True,
                                allow_custom_value=True,
                            )
                            next_normal_btn = gr.Button("Next ▶", size="sm", scale=1)
                        normal_map = gr.Image(
                            type="numpy",
                            label="Normal Map",
                            format="png",
                            interactive=False,
                        )
                    with gr.Tab("Measure"):
                        gr.Markdown(MEASURE_INSTRUCTIONS_HTML)
                        with gr.Row(elem_classes=["navigation-row"]):
                            prev_measure_btn = gr.Button(
                                "◀ Previous", size="sm", scale=1
                            )
                            measure_view_selector = gr.Dropdown(
                                choices=["View 1"],
                                value="View 1",
                                label="Select View",
                                scale=2,
                                interactive=True,
                                allow_custom_value=True,
                            )
                            next_measure_btn = gr.Button("Next ▶", size="sm", scale=1)
                        measure_image = gr.Image(
                            type="numpy",
                            show_label=False,
                            format="webp",
                            interactive=False,
                            sources=[],
                        )
                        gr.Markdown(
                            "**Note:** Light-grey areas indicate regions with no depth information where measurements cannot be taken."
                        )
                        measure_text = gr.Markdown("")

            with gr.Row():
                submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
                clear_btn = gr.ClearButton(
                    [
                        unified_upload,
                        reconstruction_output,
                        log_output,
                        target_dir_output,
                        image_gallery,
                    ],
                    scale=1,
                )

            with gr.Row():
                frame_filter = gr.Dropdown(
                    choices=["All"], value="All", label="Show Points from Frame"
                )
                with gr.Column():
                    gr.Markdown("### Pointcloud Options: (live updates)")
                    show_cam = gr.Checkbox(label="Show Camera", value=True)
                    show_mesh = gr.Checkbox(label="Show Mesh", value=True)
                    filter_black_bg = gr.Checkbox(
                        label="Filter Black Background", value=False
                    )
                    filter_white_bg = gr.Checkbox(
                        label="Filter White Background", value=False
                    )
                    gr.Markdown("### Reconstruction Options: (updated on next run)")
                    apply_mask_checkbox = gr.Checkbox(
                        label="Apply mask for predicted ambiguous depth classes & edges",
                        value=True,
                    )

    # ---------------------- Example Scenes Section ----------------------
    gr.Markdown("## Example Scenes (lists all scenes in the examples folder)")
    gr.Markdown("Click any thumbnail to load the scene for reconstruction.")

    # Get scene information
    scenes = get_scene_info("examples")

    # Create thumbnail grid (4 columns, N rows)
    if scenes:
        for i in range(0, len(scenes), 4):  # Process 4 scenes per row
            with gr.Row():
                for j in range(4):
                    scene_idx = i + j
                    if scene_idx < len(scenes):
                        scene = scenes[scene_idx]
                        with gr.Column(scale=1, elem_classes=["clickable-thumbnail"]):
                            # Clickable thumbnail
                            scene_img = gr.Image(
                                value=scene["thumbnail"],
                                height=150,
                                interactive=False,
                                show_label=False,
                                elem_id=f"scene_thumb_{scene['name']}",
                                sources=[],
                            )

                            # Scene name and image count as text below thumbnail
                            gr.Markdown(
                                f"**{scene['name']}** \n {scene['num_images']} images",
                                elem_classes=["scene-info"],
                            )

                            # The default argument binds this scene's name at
                            # definition time, so every thumbnail loads its own
                            # scene rather than the last one of the loop.
                            scene_img.select(
                                fn=lambda name=scene["name"]: load_example_scene(name),
                                outputs=[
                                    reconstruction_output,
                                    target_dir_output,
                                    image_gallery,
                                    log_output,
                                ],
                            )
                    else:
                        # Empty column to maintain grid structure
                        with gr.Column(scale=1):
                            pass
1301
 
1302
  # -------------------------------------------------------------------------
1303
+ # "Reconstruct" button logic:
1304
+ # - Clear fields
1305
+ # - Update log
1306
+ # - gradio_demo(...) with the existing target_dir
1307
+ # - Then set is_example = "False"
1308
  # -------------------------------------------------------------------------
1309
+ submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
 
 
1310
  fn=update_log, inputs=[], outputs=[log_output]
1311
  ).then(
1312
  fn=gradio_demo,
1313
  inputs=[
1314
+ target_dir_output,
1315
+ frame_filter,
1316
+ show_cam,
1317
+ filter_black_bg,
1318
+ filter_white_bg,
1319
+ apply_mask_checkbox,
1320
+ show_mesh,
1321
  ],
1322
  outputs=[
1323
+ reconstruction_output,
1324
+ log_output,
1325
+ frame_filter,
1326
+ processed_data_state,
1327
+ depth_map,
1328
+ normal_map,
1329
+ measure_image,
1330
+ measure_text,
1331
+ depth_view_selector,
1332
+ normal_view_selector,
1333
+ measure_view_selector,
1334
  ],
1335
+ ).then(
1336
+ fn=lambda: "False",
1337
+ inputs=[],
1338
+ outputs=[is_example], # set is_example to "False"
1339
+ )
1340
 
1341
+ # -------------------------------------------------------------------------
1342
+ # Real-time Visualization Updates
1343
+ # -------------------------------------------------------------------------
1344
  frame_filter.change(
1345
+ update_visualization,
1346
+ [
1347
+ target_dir_output,
1348
+ frame_filter,
1349
+ show_cam,
1350
+ is_example,
1351
+ filter_black_bg,
1352
+ filter_white_bg,
1353
+ show_mesh,
1354
+ ],
1355
+ [reconstruction_output, log_output],
1356
  )
1357
  show_cam.change(
1358
+ update_visualization,
1359
+ [
1360
+ target_dir_output,
1361
+ frame_filter,
1362
+ show_cam,
1363
+ is_example,
1364
+ ],
1365
+ [reconstruction_output, log_output],
1366
  )
 
1367
  filter_black_bg.change(
1368
+ update_visualization,
1369
+ [
1370
+ target_dir_output,
1371
+ frame_filter,
1372
+ show_cam,
1373
+ is_example,
1374
+ filter_black_bg,
1375
+ filter_white_bg,
1376
+ ],
1377
+ [reconstruction_output, log_output],
1378
  ).then(
1379
+ fn=update_all_views_on_filter_change,
1380
+ inputs=[
1381
+ target_dir_output,
1382
+ filter_black_bg,
1383
+ filter_white_bg,
1384
+ processed_data_state,
1385
+ depth_view_selector,
1386
+ normal_view_selector,
1387
+ measure_view_selector,
1388
+ ],
1389
+ outputs=[
1390
+ processed_data_state,
1391
+ depth_map,
1392
+ normal_map,
1393
+ measure_image,
1394
+ measure_points_state,
1395
+ ],
1396
  )
 
1397
  filter_white_bg.change(
1398
+ update_visualization,
1399
+ [
1400
+ target_dir_output,
1401
+ frame_filter,
1402
+ show_cam,
1403
+ is_example,
1404
+ filter_black_bg,
1405
+ filter_white_bg,
1406
+ show_mesh,
1407
+ ],
1408
+ [reconstruction_output, log_output],
1409
  ).then(
1410
+ fn=update_all_views_on_filter_change,
1411
+ inputs=[
1412
+ target_dir_output,
1413
+ filter_black_bg,
1414
+ filter_white_bg,
1415
+ processed_data_state,
1416
+ depth_view_selector,
1417
+ normal_view_selector,
1418
+ measure_view_selector,
1419
+ ],
1420
+ outputs=[
1421
+ processed_data_state,
1422
+ depth_map,
1423
+ normal_map,
1424
+ measure_image,
1425
+ measure_points_state,
1426
+ ],
1427
+ )
1428
+
1429
+ show_mesh.change(
1430
+ update_visualization,
1431
+ [
1432
+ target_dir_output,
1433
+ frame_filter,
1434
+ show_cam,
1435
+ is_example,
1436
+ filter_black_bg,
1437
+ filter_white_bg,
1438
+ show_mesh,
1439
+ ],
1440
+ [reconstruction_output, log_output],
1441
  )
1442
 
1443
+ # -------------------------------------------------------------------------
1444
+ # Auto-update gallery whenever user uploads or changes their files
1445
+ # -------------------------------------------------------------------------
1446
  def update_gallery_on_unified_upload(files, interval):
1447
+ if not files:
1448
+ return None, None, None
1449
  target_dir, image_paths = handle_uploads(files, interval)
1450
+ return (
1451
+ target_dir,
1452
+ image_paths,
1453
+ "Upload complete. Click 'Reconstruct' to begin 3D processing.",
1454
+ )
1455
 
1456
  def show_resample_button(files):
1457
+ """Show the resample button only if there are uploaded files containing videos"""
1458
+ if not files:
1459
+ return gr.update(visible=False)
1460
+
1461
+ # Check if any uploaded files are videos
1462
+ video_extensions = [
1463
+ ".mp4",
1464
+ ".avi",
1465
+ ".mov",
1466
+ ".mkv",
1467
+ ".wmv",
1468
+ ".flv",
1469
+ ".webm",
1470
+ ".m4v",
1471
+ ".3gp",
1472
+ ]
1473
  has_video = False
1474
+
1475
  for file_data in files:
1476
+ if isinstance(file_data, dict) and "name" in file_data:
1477
+ file_path = file_data["name"]
1478
+ else:
1479
+ file_path = str(file_data)
1480
+
1481
+ file_ext = os.path.splitext(file_path)[1].lower()
1482
+ if file_ext in video_extensions:
1483
  has_video = True
1484
  break
1485
+
1486
  return gr.update(visible=has_video)
1487
 
1488
+ def hide_resample_button():
1489
+ """Hide the resample button after use"""
1490
+ return gr.update(visible=False)
1491
+
1492
  def resample_video_with_new_interval(files, new_interval, current_target_dir):
1493
+ """Resample video with new slider value"""
1494
+ if not files:
1495
+ return (
1496
+ current_target_dir,
1497
+ None,
1498
+ "No files to resample.",
1499
+ gr.update(visible=False),
1500
+ )
1501
+
1502
+ # Check if we have videos to resample
1503
+ video_extensions = [
1504
+ ".mp4",
1505
+ ".avi",
1506
+ ".mov",
1507
+ ".mkv",
1508
+ ".wmv",
1509
+ ".flv",
1510
+ ".webm",
1511
+ ".m4v",
1512
+ ".3gp",
1513
+ ]
1514
+ has_video = any(
1515
+ os.path.splitext(
1516
+ str(file_data["name"] if isinstance(file_data, dict) else file_data)
1517
+ )[1].lower()
1518
+ in video_extensions
1519
+ for file_data in files
1520
+ )
1521
+
1522
+ if not has_video:
1523
+ return (
1524
+ current_target_dir,
1525
+ None,
1526
+ "No videos found to resample.",
1527
+ gr.update(visible=False),
1528
+ )
1529
+
1530
+ # Clean up old target directory if it exists
1531
+ if (
1532
+ current_target_dir
1533
+ and current_target_dir != "None"
1534
+ and os.path.exists(current_target_dir)
1535
+ ):
1536
+ shutil.rmtree(current_target_dir)
1537
+
1538
+ # Process files with new interval
1539
  target_dir, image_paths = handle_uploads(files, new_interval)
1540
+
1541
+ return (
1542
+ target_dir,
1543
+ image_paths,
1544
+ f"Video resampled with {new_interval}s interval. Click 'Reconstruct' to begin 3D processing.",
1545
+ gr.update(visible=False),
1546
+ )
1547
 
1548
  unified_upload.change(
1549
+ fn=update_gallery_on_unified_upload,
1550
+ inputs=[unified_upload, s_time_interval],
1551
+ outputs=[target_dir_output, image_gallery, log_output],
1552
+ ).then(
1553
+ fn=show_resample_button,
1554
+ inputs=[unified_upload],
1555
+ outputs=[resample_btn],
1556
+ )
1557
+
1558
+ # Show resample button when slider changes (only if files are uploaded)
1559
+ s_time_interval.change(
1560
+ fn=show_resample_button,
1561
+ inputs=[unified_upload],
1562
+ outputs=[resample_btn],
1563
+ )
1564
+
1565
+ # Handle resample button click
1566
+ resample_btn.click(
1567
+ fn=resample_video_with_new_interval,
1568
+ inputs=[unified_upload, s_time_interval, target_dir_output],
1569
+ outputs=[target_dir_output, image_gallery, log_output, resample_btn],
1570
+ )
1571
+
1572
+ # -------------------------------------------------------------------------
1573
+ # Measure tab functionality
1574
+ # -------------------------------------------------------------------------
1575
+ measure_image.select(
1576
+ fn=measure,
1577
+ inputs=[processed_data_state, measure_points_state, measure_view_selector],
1578
+ outputs=[measure_image, measure_points_state, measure_text],
1579
+ )
1580
+
1581
+ # -------------------------------------------------------------------------
1582
+ # Navigation functionality for Depth, Normal, and Measure tabs
1583
+ # -------------------------------------------------------------------------
1584
 
1585
+ # Depth tab navigation
1586
+ prev_depth_btn.click(
1587
+ fn=lambda processed_data, current_selector: navigate_depth_view(
1588
+ processed_data, current_selector, -1
1589
+ ),
1590
+ inputs=[processed_data_state, depth_view_selector],
1591
+ outputs=[depth_view_selector, depth_map],
1592
+ )
1593
 
1594
+ next_depth_btn.click(
1595
+ fn=lambda processed_data, current_selector: navigate_depth_view(
1596
+ processed_data, current_selector, 1
1597
+ ),
1598
+ inputs=[processed_data_state, depth_view_selector],
1599
+ outputs=[depth_view_selector, depth_map],
1600
+ )
1601
 
1602
+ depth_view_selector.change(
1603
+ fn=lambda processed_data, selector_value: (
1604
+ update_depth_view(
1605
+ processed_data,
1606
+ int(selector_value.split()[1]) - 1,
1607
+ )
1608
+ if selector_value
1609
+ else None
1610
+ ),
1611
+ inputs=[processed_data_state, depth_view_selector],
1612
+ outputs=[depth_map],
1613
+ )
1614
 
1615
# Normal tab navigation.
# NOTE(review): indentation was not recoverable from the source; re-indent to
# match the enclosing block (presumably a gr.Blocks context) - confirm.
# The previous/next buttons call `navigate_normal_view` with -1 / 1 as the
# third argument (presumably the step direction - confirm against the helper)
# and write its results to the selector and the normal map.
prev_normal_btn.click(
    lambda processed_data, current_selector: navigate_normal_view(
        processed_data, current_selector, -1
    ),
    inputs=[processed_data_state, normal_view_selector],
    outputs=[normal_view_selector, normal_map],
)

next_normal_btn.click(
    lambda processed_data, current_selector: navigate_normal_view(
        processed_data, current_selector, 1
    ),
    inputs=[processed_data_state, normal_view_selector],
    outputs=[normal_view_selector, normal_map],
)

# Changing the selector redraws the normal map. The view index passed to
# `update_normal_view` is the second whitespace-separated token of the
# selector value, converted to int, minus one (presumably labels look like
# "View 3" - confirm). A falsy selector value yields None.
normal_view_selector.change(
    lambda processed_data, selector_value: (
        None
        if not selector_value
        else update_normal_view(
            processed_data,
            int(selector_value.split()[1]) - 1,
        )
    ),
    inputs=[processed_data_state, normal_view_selector],
    outputs=[normal_map],
)
1644
 
1645
# Measure tab navigation.
# NOTE(review): indentation was not recoverable from the source; re-indent to
# match the enclosing block (presumably a gr.Blocks context) - confirm.
# The previous/next buttons call `navigate_measure_view` with -1 / 1 as the
# third argument (presumably the step direction - confirm against the helper)
# and write its results to the selector, the measure image and the
# measure-points state.
prev_measure_btn.click(
    lambda processed_data, current_selector: navigate_measure_view(
        processed_data, current_selector, -1
    ),
    inputs=[processed_data_state, measure_view_selector],
    outputs=[measure_view_selector, measure_image, measure_points_state],
)

next_measure_btn.click(
    lambda processed_data, current_selector: navigate_measure_view(
        processed_data, current_selector, 1
    ),
    inputs=[processed_data_state, measure_view_selector],
    outputs=[measure_view_selector, measure_image, measure_points_state],
)

# Changing the selector refreshes the measure image and the points state via
# `update_measure_view`. The view index is the second whitespace-separated
# token of the selector value, converted to int, minus one (presumably labels
# look like "View 3" - confirm). A falsy selector value yields no image and
# an empty points list.
measure_view_selector.change(
    lambda processed_data, selector_value: (
        (None, [])
        if not selector_value
        else update_measure_view(processed_data, int(selector_value.split()[1]) - 1)
    ),
    inputs=[processed_data_state, measure_view_selector],
    outputs=[measure_image, measure_points_state],
)
1671
+
1672
# -------------------------------------------------------------------------
# Acknowledgements
# -------------------------------------------------------------------------
# Render the markup returned by `get_acknowledgements_html()` as an HTML
# component.
# NOTE(review): indentation was not recoverable from the source; re-indent to
# match the enclosing block (presumably a gr.Blocks context) - confirm.
gr.HTML(value=get_acknowledgements_html())
1676
 
1677
# Configure the queue with max_size=20, then launch the app with error
# display enabled, a share link requested and SSR mode turned off.
# NOTE(review): indentation was not recoverable from the source; re-indent to
# match the enclosing scope when applying - confirm.
demo.queue(max_size=20)
demo.launch(show_error=True, share=True, ssr_mode=False)