lidavidsh commited on
Commit
5e49721
·
1 Parent(s): c788f41

update app.py to use vggt service on amd gpus

Browse files
Files changed (1) hide show
  1. app.py +920 -609
app.py CHANGED
@@ -1,609 +1,920 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import os
8
- import cv2
9
- import torch
10
- import numpy as np
11
- import gradio as gr
12
- import sys
13
- import shutil
14
- from datetime import datetime
15
- import glob
16
- import gc
17
- import time
18
- import spaces
19
-
20
-
21
- sys.path.append("vggt/")
22
-
23
- from visual_util import predictions_to_glb
24
- from vggt.models.vggt import VGGT
25
- from vggt.utils.load_fn import load_and_preprocess_images
26
- from vggt.utils.pose_enc import pose_encoding_to_extri_intri
27
- from vggt.utils.geometry import unproject_depth_map_to_point_map
28
-
29
- # device = "cuda" if torch.cuda.is_available() else "cpu"
30
-
31
- print("Initializing and loading VGGT model...")
32
- # model = VGGT.from_pretrained("facebook/VGGT-1B") # another way to load the model
33
-
34
- model = VGGT()
35
- _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
36
- model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
37
-
38
-
39
- model.eval()
40
- # model = model.to(device)
41
-
42
-
43
- # -------------------------------------------------------------------------
44
- # 1) Core model inference
45
- # -------------------------------------------------------------------------
46
- @spaces.GPU(duration=120)
47
- def run_model(target_dir, model) -> dict:
48
- """
49
- Run the VGGT model on images in the 'target_dir/images' folder and return predictions.
50
- """
51
- print(f"Processing images from {target_dir}")
52
-
53
- # Device check
54
- device = "cuda" if torch.cuda.is_available() else "cpu"
55
- if not torch.cuda.is_available():
56
- raise ValueError("CUDA is not available. Check your environment.")
57
-
58
- # Move model to device
59
- model = model.to(device)
60
- model.eval()
61
-
62
- # Load and preprocess images
63
- image_names = glob.glob(os.path.join(target_dir, "images", "*"))
64
- image_names = sorted(image_names)
65
- print(f"Found {len(image_names)} images")
66
- if len(image_names) == 0:
67
- raise ValueError("No images found. Check your upload.")
68
-
69
- images = load_and_preprocess_images(image_names).to(device)
70
- print(f"Preprocessed images shape: {images.shape}")
71
-
72
- # Run inference
73
- print("Running inference...")
74
- with torch.no_grad():
75
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
76
- predictions = model(images)
77
-
78
- # Convert pose encoding to extrinsic and intrinsic matrices
79
- print("Converting pose encoding to extrinsic and intrinsic matrices...")
80
- extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
81
- predictions["extrinsic"] = extrinsic
82
- predictions["intrinsic"] = intrinsic
83
-
84
- # Convert tensors to numpy
85
- for key in predictions.keys():
86
- if isinstance(predictions[key], torch.Tensor):
87
- predictions[key] = predictions[key].cpu().numpy().squeeze(0) # remove batch dimension
88
-
89
- # Generate world points from depth map
90
- print("Computing world points from depth map...")
91
- depth_map = predictions["depth"] # (S, H, W, 1)
92
- world_points = unproject_depth_map_to_point_map(depth_map, predictions["extrinsic"], predictions["intrinsic"])
93
- predictions["world_points_from_depth"] = world_points
94
-
95
- # Clean up
96
- torch.cuda.empty_cache()
97
- return predictions
98
-
99
-
100
- # -------------------------------------------------------------------------
101
- # 2) Handle uploaded video/images --> produce target_dir + images
102
- # -------------------------------------------------------------------------
103
- def handle_uploads(input_video, input_images):
104
- """
105
- Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
106
- images or extracted frames from video into it. Return (target_dir, image_paths).
107
- """
108
- start_time = time.time()
109
- gc.collect()
110
- torch.cuda.empty_cache()
111
-
112
- # Create a unique folder name
113
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
114
- target_dir = f"input_images_{timestamp}"
115
- target_dir_images = os.path.join(target_dir, "images")
116
-
117
- # Clean up if somehow that folder already exists
118
- if os.path.exists(target_dir):
119
- shutil.rmtree(target_dir)
120
- os.makedirs(target_dir)
121
- os.makedirs(target_dir_images)
122
-
123
- image_paths = []
124
-
125
- # --- Handle images ---
126
- if input_images is not None:
127
- for file_data in input_images:
128
- if isinstance(file_data, dict) and "name" in file_data:
129
- file_path = file_data["name"]
130
- else:
131
- file_path = file_data
132
- dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
133
- shutil.copy(file_path, dst_path)
134
- image_paths.append(dst_path)
135
-
136
- # --- Handle video ---
137
- if input_video is not None:
138
- if isinstance(input_video, dict) and "name" in input_video:
139
- video_path = input_video["name"]
140
- else:
141
- video_path = input_video
142
-
143
- vs = cv2.VideoCapture(video_path)
144
- fps = vs.get(cv2.CAP_PROP_FPS)
145
- frame_interval = int(fps * 1) # 1 frame/sec
146
-
147
- count = 0
148
- video_frame_num = 0
149
- while True:
150
- gotit, frame = vs.read()
151
- if not gotit:
152
- break
153
- count += 1
154
- if count % frame_interval == 0:
155
- image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
156
- cv2.imwrite(image_path, frame)
157
- image_paths.append(image_path)
158
- video_frame_num += 1
159
-
160
- # Sort final images for gallery
161
- image_paths = sorted(image_paths)
162
-
163
- end_time = time.time()
164
- print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds")
165
- return target_dir, image_paths
166
-
167
-
168
- # -------------------------------------------------------------------------
169
- # 3) Update gallery on upload
170
- # -------------------------------------------------------------------------
171
- def update_gallery_on_upload(input_video, input_images):
172
- """
173
- Whenever user uploads or changes files, immediately handle them
174
- and show in the gallery. Return (target_dir, image_paths).
175
- If nothing is uploaded, returns "None" and empty list.
176
- """
177
- if not input_video and not input_images:
178
- return None, None, None, None
179
- target_dir, image_paths = handle_uploads(input_video, input_images)
180
- return None, target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing."
181
-
182
-
183
- # -------------------------------------------------------------------------
184
- # 4) Reconstruction: uses the target_dir plus any viz parameters
185
- # -------------------------------------------------------------------------
186
- @spaces.GPU(duration=120)
187
- def gradio_demo(
188
- target_dir,
189
- conf_thres=3.0,
190
- frame_filter="All",
191
- mask_black_bg=False,
192
- mask_white_bg=False,
193
- show_cam=True,
194
- mask_sky=False,
195
- prediction_mode="Pointmap Regression",
196
- ):
197
- """
198
- Perform reconstruction using the already-created target_dir/images.
199
- """
200
- if not os.path.isdir(target_dir) or target_dir == "None":
201
- return None, "No valid target directory found. Please upload first.", None, None
202
-
203
- start_time = time.time()
204
- gc.collect()
205
- torch.cuda.empty_cache()
206
-
207
- # Prepare frame_filter dropdown
208
- target_dir_images = os.path.join(target_dir, "images")
209
- all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
210
- all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
211
- frame_filter_choices = ["All"] + all_files
212
-
213
- print("Running run_model...")
214
- with torch.no_grad():
215
- predictions = run_model(target_dir, model)
216
-
217
- # Save predictions
218
- prediction_save_path = os.path.join(target_dir, "predictions.npz")
219
- np.savez(prediction_save_path, **predictions)
220
-
221
- # Handle None frame_filter
222
- if frame_filter is None:
223
- frame_filter = "All"
224
-
225
- # Build a GLB file name
226
- glbfile = os.path.join(
227
- target_dir,
228
- f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
229
- )
230
-
231
- # Convert predictions to GLB
232
- glbscene = predictions_to_glb(
233
- predictions,
234
- conf_thres=conf_thres,
235
- filter_by_frames=frame_filter,
236
- mask_black_bg=mask_black_bg,
237
- mask_white_bg=mask_white_bg,
238
- show_cam=show_cam,
239
- mask_sky=mask_sky,
240
- target_dir=target_dir,
241
- prediction_mode=prediction_mode,
242
- )
243
- glbscene.export(file_obj=glbfile)
244
-
245
- # Cleanup
246
- del predictions
247
- gc.collect()
248
- torch.cuda.empty_cache()
249
-
250
- end_time = time.time()
251
- print(f"Total time: {end_time - start_time:.2f} seconds")
252
- log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
253
-
254
- return glbfile, log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True)
255
-
256
-
257
- # -------------------------------------------------------------------------
258
- # 5) Helper functions for UI resets + re-visualization
259
- # -------------------------------------------------------------------------
260
- def clear_fields():
261
- """
262
- Clears the 3D viewer, the stored target_dir, and empties the gallery.
263
- """
264
- return None
265
-
266
-
267
- def update_log():
268
- """
269
- Display a quick log message while waiting.
270
- """
271
- return "Loading and Reconstructing..."
272
-
273
-
274
- def update_visualization(
275
- target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example
276
- ):
277
- """
278
- Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
279
- and return it for the 3D viewer. If is_example == "True", skip.
280
- """
281
-
282
- # If it's an example click, skip as requested
283
- if is_example == "True":
284
- return None, "No reconstruction available. Please click the Reconstruct button first."
285
-
286
- if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
287
- return None, "No reconstruction available. Please click the Reconstruct button first."
288
-
289
- predictions_path = os.path.join(target_dir, "predictions.npz")
290
- if not os.path.exists(predictions_path):
291
- return None, f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first."
292
-
293
- loaded = np.load(predictions_path, allow_pickle=True)
294
- predictions = {key: loaded[key] for key in loaded.keys()}
295
-
296
- glbfile = os.path.join(
297
- target_dir,
298
- f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
299
- )
300
-
301
- if not os.path.exists(glbfile):
302
- glbscene = predictions_to_glb(
303
- predictions,
304
- conf_thres=conf_thres,
305
- filter_by_frames=frame_filter,
306
- mask_black_bg=mask_black_bg,
307
- mask_white_bg=mask_white_bg,
308
- show_cam=show_cam,
309
- mask_sky=mask_sky,
310
- target_dir=target_dir,
311
- prediction_mode=prediction_mode,
312
- )
313
- glbscene.export(file_obj=glbfile)
314
-
315
- return glbfile, "Updating Visualization"
316
-
317
-
318
- # -------------------------------------------------------------------------
319
- # Example images
320
- # -------------------------------------------------------------------------
321
-
322
- # canyon_video = "examples/videos/Studlagil_Canyon_East_Iceland.mp4"
323
- great_wall_video = "examples/videos/great_wall.mp4"
324
- colosseum_video = "examples/videos/Colosseum.mp4"
325
- room_video = "examples/videos/room.mp4"
326
- kitchen_video = "examples/videos/kitchen.mp4"
327
- fern_video = "examples/videos/fern.mp4"
328
- single_cartoon_video = "examples/videos/single_cartoon.mp4"
329
- single_oil_painting_video = "examples/videos/single_oil_painting.mp4"
330
- pyramid_video = "examples/videos/pyramid.mp4"
331
-
332
-
333
- # -------------------------------------------------------------------------
334
- # 6) Build Gradio UI
335
- # -------------------------------------------------------------------------
336
- theme = gr.themes.Ocean()
337
- theme.set(
338
- checkbox_label_background_fill_selected="*button_primary_background_fill",
339
- checkbox_label_text_color_selected="*button_primary_text_color",
340
- )
341
-
342
- with gr.Blocks(
343
- theme=theme,
344
- css="""
345
- .custom-log * {
346
- font-style: italic;
347
- font-size: 22px !important;
348
- background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
349
- -webkit-background-clip: text;
350
- background-clip: text;
351
- font-weight: bold !important;
352
- color: transparent !important;
353
- text-align: center !important;
354
- }
355
-
356
- .example-log * {
357
- font-style: italic;
358
- font-size: 16px !important;
359
- background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
360
- -webkit-background-clip: text;
361
- background-clip: text;
362
- color: transparent !important;
363
- }
364
-
365
- #my_radio .wrap {
366
- display: flex;
367
- flex-wrap: nowrap;
368
- justify-content: center;
369
- align-items: center;
370
- }
371
-
372
- #my_radio .wrap label {
373
- display: flex;
374
- width: 50%;
375
- justify-content: center;
376
- align-items: center;
377
- margin: 0;
378
- padding: 10px 0;
379
- box-sizing: border-box;
380
- }
381
- """,
382
- ) as demo:
383
-
384
- # Instead of gr.State, we use a hidden Textbox:
385
- is_example = gr.Textbox(label="is_example", visible=False, value="None")
386
- num_images = gr.Textbox(label="num_images", visible=False, value="None")
387
-
388
- gr.HTML(
389
- """
390
- <h1>🏛️ VGGT: Visual Geometry Grounded Transformer</h1>
391
- <p>
392
- <a href="https://github.com/facebookresearch/vggt">🌟 GitHub Repository</a> |
393
- <a href="https://vgg-t.github.io/">🚀 Project Page</a>
394
- </p>
395
-
396
- <div style="font-size: 16px; line-height: 1.5;">
397
- <p>Upload a video or a set of images to create a 3D reconstruction of a scene or object. VGGT takes these images and generates all key 3D attributes, including extrinsic and intrinsic camera parameters, point maps, depth maps, and 3D point tracks.</p>
398
-
399
- <h3>Getting Started:</h3>
400
- <ol>
401
- <li><strong>Upload Your Data:</strong> Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).</li>
402
- <li><strong>Preview:</strong> Your uploaded images will appear in the gallery on the left.</li>
403
- <li><strong>Reconstruct:</strong> Click the "Reconstruct" button to start the 3D reconstruction process.</li>
404
- <li><strong>Visualize:</strong> The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file. Note the visualization of 3D points may be slow for a large number of input images.</li>
405
- <li>
406
- <strong>Adjust Visualization (Optional):</strong>
407
- After reconstruction, you can fine-tune the visualization using the options below
408
- <details style="display:inline;">
409
- <summary style="display:inline;">(<strong>click to expand</strong>):</summary>
410
- <ul>
411
- <li><em>Confidence Threshold:</em> Adjust the filtering of points based on confidence.</li>
412
- <li><em>Show Points from Frame:</em> Select specific frames to display in the point cloud.</li>
413
- <li><em>Show Camera:</em> Toggle the display of estimated camera positions.</li>
414
- <li><em>Filter Sky / Filter Black Background:</em> Remove sky or black-background points.</li>
415
- <li><em>Select a Prediction Mode:</em> Choose between "Depthmap and Camera Branch" or "Pointmap Branch."</li>
416
- </ul>
417
- </details>
418
- </li>
419
- </ol>
420
- <p><strong style="color: #0ea5e9;">Please note:</strong> <span style="color: #0ea5e9; font-weight: bold;">Our model itself usually only needs less than 1 second to reconstruct a scene. However, visualizing 3D points may take tens of seconds due to third-party rendering, which are independent of VGGT's processing time. Please be patient or, for faster visualization, use a local machine to run our demo from our <a href="https://github.com/facebookresearch/vggt">GitHub repository</a>. </span></p>
421
- </div>
422
- """
423
- )
424
-
425
- target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
426
-
427
- with gr.Row():
428
- with gr.Column(scale=2):
429
- input_video = gr.Video(label="Upload Video", interactive=True)
430
- input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
431
-
432
- image_gallery = gr.Gallery(
433
- label="Preview",
434
- columns=4,
435
- height="300px",
436
- object_fit="contain",
437
- preview=True,
438
- )
439
-
440
- with gr.Column(scale=4):
441
- with gr.Column():
442
- gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses)**")
443
- log_output = gr.Markdown(
444
- "Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"]
445
- )
446
- reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5)
447
-
448
- with gr.Row():
449
- submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
450
- clear_btn = gr.ClearButton(
451
- [input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery],
452
- scale=1,
453
- )
454
-
455
- with gr.Row():
456
- prediction_mode = gr.Radio(
457
- ["Depthmap and Camera Branch", "Pointmap Branch"],
458
- label="Select a Prediction Mode",
459
- value="Depthmap and Camera Branch",
460
- scale=1,
461
- elem_id="my_radio",
462
- )
463
-
464
- with gr.Row():
465
- conf_thres = gr.Slider(minimum=0, maximum=100, value=50, step=0.1, label="Confidence Threshold (%)")
466
- frame_filter = gr.Dropdown(choices=["All"], value="All", label="Show Points from Frame")
467
- with gr.Column():
468
- show_cam = gr.Checkbox(label="Show Camera", value=True)
469
- mask_sky = gr.Checkbox(label="Filter Sky", value=False)
470
- mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
471
- mask_white_bg = gr.Checkbox(label="Filter White Background", value=False)
472
-
473
- # ---------------------- Examples section ----------------------
474
- examples = [
475
- [colosseum_video, "22", None, 20.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
476
- [pyramid_video, "30", None, 35.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
477
- [single_cartoon_video, "1", None, 15.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
478
- [single_oil_painting_video, "1", None, 20.0, False, False, True, True, "Depthmap and Camera Branch", "True"],
479
- # [canyon_video, "14", None, 40.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
480
- [room_video, "8", None, 5.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
481
- [kitchen_video, "25", None, 50.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
482
- [fern_video, "20", None, 45.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
483
- ]
484
-
485
- def example_pipeline(
486
- input_video,
487
- num_images_str,
488
- input_images,
489
- conf_thres,
490
- mask_black_bg,
491
- mask_white_bg,
492
- show_cam,
493
- mask_sky,
494
- prediction_mode,
495
- is_example_str,
496
- ):
497
- """
498
- 1) Copy example images to new target_dir
499
- 2) Reconstruct
500
- 3) Return model3D + logs + new_dir + updated dropdown + gallery
501
- We do NOT return is_example. It's just an input.
502
- """
503
- target_dir, image_paths = handle_uploads(input_video, input_images)
504
- # Always use "All" for frame_filter in examples
505
- frame_filter = "All"
506
- glbfile, log_msg, dropdown = gradio_demo(
507
- target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode
508
- )
509
- return glbfile, log_msg, target_dir, dropdown, image_paths
510
-
511
- gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
512
-
513
- gr.Examples(
514
- examples=examples,
515
- inputs=[
516
- input_video,
517
- num_images,
518
- input_images,
519
- conf_thres,
520
- mask_black_bg,
521
- mask_white_bg,
522
- show_cam,
523
- mask_sky,
524
- prediction_mode,
525
- is_example,
526
- ],
527
- outputs=[
528
- reconstruction_output,
529
- log_output,
530
- target_dir_output,
531
- frame_filter,
532
- image_gallery,
533
- ],
534
- fn=example_pipeline,
535
- cache_examples=False,
536
- examples_per_page=50,
537
- )
538
-
539
- # -------------------------------------------------------------------------
540
- # "Reconstruct" button logic:
541
- # - Clear fields
542
- # - Update log
543
- # - gradio_demo(...) with the existing target_dir
544
- # - Then set is_example = "False"
545
- # -------------------------------------------------------------------------
546
- submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
547
- fn=update_log, inputs=[], outputs=[log_output]
548
- ).then(
549
- fn=gradio_demo,
550
- inputs=[target_dir_output, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode],
551
- outputs=[reconstruction_output, log_output, frame_filter],
552
- ).then(
553
- fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False"
554
- )
555
-
556
- # -------------------------------------------------------------------------
557
- # Real-time Visualization Updates
558
- # -------------------------------------------------------------------------
559
- conf_thres.change(
560
- update_visualization,
561
- [target_dir_output, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example],
562
- [reconstruction_output, log_output],
563
- )
564
- frame_filter.change(
565
- update_visualization,
566
- [target_dir_output, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example],
567
- [reconstruction_output, log_output],
568
- )
569
- mask_black_bg.change(
570
- update_visualization,
571
- [target_dir_output, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example],
572
- [reconstruction_output, log_output],
573
- )
574
- mask_white_bg.change(
575
- update_visualization,
576
- [target_dir_output, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example],
577
- [reconstruction_output, log_output],
578
- )
579
- show_cam.change(
580
- update_visualization,
581
- [target_dir_output, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example],
582
- [reconstruction_output, log_output],
583
- )
584
- mask_sky.change(
585
- update_visualization,
586
- [target_dir_output, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example],
587
- [reconstruction_output, log_output],
588
- )
589
- prediction_mode.change(
590
- update_visualization,
591
- [target_dir_output, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example],
592
- [reconstruction_output, log_output],
593
- )
594
-
595
- # -------------------------------------------------------------------------
596
- # Auto-update gallery whenever user uploads or changes their files
597
- # -------------------------------------------------------------------------
598
- input_video.change(
599
- fn=update_gallery_on_upload,
600
- inputs=[input_video, input_images],
601
- outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
602
- )
603
- input_images.change(
604
- fn=update_gallery_on_upload,
605
- inputs=[input_video, input_images],
606
- outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
607
- )
608
-
609
- demo.queue(max_size=20).launch(show_error=True, share=True, ssr_mode=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import cv2
9
+ import torch
10
+ import numpy as np
11
+ import gradio as gr
12
+ import sys
13
+ import shutil
14
+ from datetime import datetime
15
+ import glob
16
+ import gc
17
+ import time
18
+ import spaces
19
+ import requests
20
+ import websocket
21
+ import json
22
+ import uuid
23
+ import base64
24
+ import io
25
+
26
+
27
+ sys.path.append("vggt/")
28
+
29
+ from visual_util import predictions_to_glb
30
+ from vggt.utils.pose_enc import pose_encoding_to_extri_intri
31
+ from vggt.utils.geometry import unproject_depth_map_to_point_map
32
+
33
+ # Remote VGGT service host
34
+ VGGT_HOST = os.getenv("VGGT_HOST", "134.199.132.159")
35
+
36
+ # No longer loading model locally
37
+ # model = None
38
+
39
+
40
+ # -------------------------------------------------------------------------
41
+ # Remote service communication functions
42
+ # -------------------------------------------------------------------------
43
+ def _open_ws(client_id: str, token: str):
44
+ """Open WebSocket connection to remote VGGT service"""
45
+ ws = websocket.WebSocket()
46
+ ws.connect(f"ws://{VGGT_HOST}/ws?clientId={client_id}&token={token}", timeout=1800)
47
+ return ws
48
+
49
+
50
+ def _submit_inference(target_dir: str, client_id: str, token: str) -> str:
51
+ """Submit inference job to remote VGGT service"""
52
+ # Prepare image files for upload
53
+ image_names = glob.glob(os.path.join(target_dir, "images", "*"))
54
+ image_names = sorted(image_names)
55
+
56
+ if len(image_names) == 0:
57
+ raise ValueError("No images found. Check your upload.")
58
+
59
+ # Encode images as base64
60
+ images_data = []
61
+ for img_path in image_names:
62
+ with open(img_path, "rb") as f:
63
+ img_bytes = f.read()
64
+ img_b64 = base64.b64encode(img_bytes).decode("utf-8")
65
+ images_data.append(
66
+ {"filename": os.path.basename(img_path), "data": img_b64}
67
+ )
68
+
69
+ payload = {"images": images_data, "client_id": client_id}
70
+
71
+ resp = requests.post(
72
+ f"http://{VGGT_HOST}/inference?token={token}", json=payload, timeout=1800
73
+ )
74
+
75
+ if resp.status_code != 200:
76
+ raise RuntimeError(f"VGGT service /inference err: {resp.text}")
77
+
78
+ data = resp.json()
79
+ if "job_id" not in data:
80
+ raise RuntimeError(f"/inference no job_id: {data}")
81
+
82
+ return data["job_id"]
83
+
84
+
85
+ def _get_result(job_id: str, token: str) -> dict:
86
+ """Get inference result from remote VGGT service"""
87
+ resp = requests.get(
88
+ f"http://{VGGT_HOST}/result/{job_id}?token={token}", timeout=1800
89
+ )
90
+ resp.raise_for_status()
91
+ result = resp.json()
92
+ return result.get(job_id, {})
93
+
94
+
95
+ # -------------------------------------------------------------------------
96
+ # 1) Core model inference (now forwards to remote service)
97
+ # -------------------------------------------------------------------------
98
+ @spaces.GPU(duration=120)
99
+ def run_model(target_dir, model=None) -> dict:
100
+ """
101
+ Run the VGGT model on images in the 'target_dir/images' folder and return predictions.
102
+ Now forwards to remote VGGT service instead of running locally.
103
+ """
104
+ print(f"Processing images from {target_dir}")
105
+
106
+ # Load image names for validation
107
+ image_names = glob.glob(os.path.join(target_dir, "images", "*"))
108
+ image_names = sorted(image_names)
109
+ print(f"Found {len(image_names)} images")
110
+ if len(image_names) == 0:
111
+ raise ValueError("No images found. Check your upload.")
112
+
113
+ # Generate client ID and token
114
+ client_id = str(uuid.uuid4())
115
+ token = str(uuid.uuid4())
116
+
117
+ # Open WebSocket for progress updates
118
+ print("Connecting to remote VGGT service...")
119
+ ws = _open_ws(client_id, token)
120
+
121
+ # Submit inference job
122
+ print("Submitting inference job...")
123
+ job_id = _submit_inference(target_dir, client_id, token)
124
+
125
+ # Monitor progress via WebSocket
126
+ print("Monitoring inference progress...")
127
+ ws.settimeout(180)
128
+ while True:
129
+ try:
130
+ out = ws.recv()
131
+ if isinstance(out, (bytes, bytearray)):
132
+ continue
133
+
134
+ msg = json.loads(out)
135
+ if msg.get("type") == "executing":
136
+ data = msg.get("data", {})
137
+ if data.get("job_id") != job_id:
138
+ continue
139
+ node = data.get("node")
140
+ if node is None:
141
+ # Job complete
142
+ break
143
+ print(f"Processing node: {node}")
144
+ except Exception as e:
145
+ print(f"WebSocket error: {e}")
146
+ break
147
+
148
+ ws.close()
149
+
150
+ # Get final result
151
+ print("Retrieving inference results...")
152
+ result = _get_result(job_id, token)
153
+
154
+ if "predictions" not in result:
155
+ raise RuntimeError(f"No predictions in result: {result}")
156
+
157
+ # Deserialize predictions from base64-encoded numpy arrays
158
+ predictions = {}
159
+ for key, value in result["predictions"].items():
160
+ if isinstance(value, str):
161
+ # Decode base64 numpy array
162
+ arr_bytes = base64.b64decode(value)
163
+ predictions[key] = np.load(io.BytesIO(arr_bytes), allow_pickle=True)
164
+ else:
165
+ predictions[key] = np.array(value)
166
+
167
+ # Post-process predictions (same as before)
168
+ print("Post-processing predictions...")
169
+
170
+ # Generate world points from depth map if not already present
171
+ if "world_points_from_depth" not in predictions and "depth" in predictions:
172
+ print("Computing world points from depth map...")
173
+ depth_map = predictions["depth"]
174
+ world_points = unproject_depth_map_to_point_map(
175
+ depth_map, predictions["extrinsic"], predictions["intrinsic"]
176
+ )
177
+ predictions["world_points_from_depth"] = world_points
178
+
179
+ return predictions
180
+
181
+
182
+ # -------------------------------------------------------------------------
183
+ # 2) Handle uploaded video/images --> produce target_dir + images
184
+ # -------------------------------------------------------------------------
185
+ def handle_uploads(input_video, input_images):
186
+ """
187
+ Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
188
+ images or extracted frames from video into it. Return (target_dir, image_paths).
189
+ """
190
+ start_time = time.time()
191
+ gc.collect()
192
+ torch.cuda.empty_cache()
193
+
194
+ # Create a unique folder name
195
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
196
+ target_dir = f"input_images_{timestamp}"
197
+ target_dir_images = os.path.join(target_dir, "images")
198
+
199
+ # Clean up if somehow that folder already exists
200
+ if os.path.exists(target_dir):
201
+ shutil.rmtree(target_dir)
202
+ os.makedirs(target_dir)
203
+ os.makedirs(target_dir_images)
204
+
205
+ image_paths = []
206
+
207
+ # --- Handle images ---
208
+ if input_images is not None:
209
+ for file_data in input_images:
210
+ if isinstance(file_data, dict) and "name" in file_data:
211
+ file_path = file_data["name"]
212
+ else:
213
+ file_path = file_data
214
+ dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
215
+ shutil.copy(file_path, dst_path)
216
+ image_paths.append(dst_path)
217
+
218
+ # --- Handle video ---
219
+ if input_video is not None:
220
+ if isinstance(input_video, dict) and "name" in input_video:
221
+ video_path = input_video["name"]
222
+ else:
223
+ video_path = input_video
224
+
225
+ vs = cv2.VideoCapture(video_path)
226
+ fps = vs.get(cv2.CAP_PROP_FPS)
227
+ frame_interval = int(fps * 1) # 1 frame/sec
228
+
229
+ count = 0
230
+ video_frame_num = 0
231
+ while True:
232
+ gotit, frame = vs.read()
233
+ if not gotit:
234
+ break
235
+ count += 1
236
+ if count % frame_interval == 0:
237
+ image_path = os.path.join(
238
+ target_dir_images, f"{video_frame_num:06}.png"
239
+ )
240
+ cv2.imwrite(image_path, frame)
241
+ image_paths.append(image_path)
242
+ video_frame_num += 1
243
+
244
+ # Sort final images for gallery
245
+ image_paths = sorted(image_paths)
246
+
247
+ end_time = time.time()
248
+ print(
249
+ f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds"
250
+ )
251
+ return target_dir, image_paths
252
+
253
+
254
+ # -------------------------------------------------------------------------
255
+ # 3) Update gallery on upload
256
+ # -------------------------------------------------------------------------
257
+ def update_gallery_on_upload(input_video, input_images):
258
+ """
259
+ Whenever user uploads or changes files, immediately handle them
260
+ and show in the gallery. Return (target_dir, image_paths).
261
+ If nothing is uploaded, returns "None" and empty list.
262
+ """
263
+ if not input_video and not input_images:
264
+ return None, None, None, None
265
+ target_dir, image_paths = handle_uploads(input_video, input_images)
266
+ return (
267
+ None,
268
+ target_dir,
269
+ image_paths,
270
+ "Upload complete. Click 'Reconstruct' to begin 3D processing.",
271
+ )
272
+
273
+
274
+ # -------------------------------------------------------------------------
275
+ # 4) Reconstruction: uses the target_dir plus any viz parameters
276
+ # -------------------------------------------------------------------------
277
+ @spaces.GPU(duration=120)
278
+ def gradio_demo(
279
+ target_dir,
280
+ conf_thres=3.0,
281
+ frame_filter="All",
282
+ mask_black_bg=False,
283
+ mask_white_bg=False,
284
+ show_cam=True,
285
+ mask_sky=False,
286
+ prediction_mode="Pointmap Regression",
287
+ ):
288
+ """
289
+ Perform reconstruction using the already-created target_dir/images.
290
+ """
291
+ if not os.path.isdir(target_dir) or target_dir == "None":
292
+ return None, "No valid target directory found. Please upload first.", None, None
293
+
294
+ start_time = time.time()
295
+ gc.collect()
296
+ torch.cuda.empty_cache()
297
+
298
+ # Prepare frame_filter dropdown
299
+ target_dir_images = os.path.join(target_dir, "images")
300
+ all_files = (
301
+ sorted(os.listdir(target_dir_images))
302
+ if os.path.isdir(target_dir_images)
303
+ else []
304
+ )
305
+ all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
306
+ frame_filter_choices = ["All"] + all_files
307
+
308
+ print("Running run_model...")
309
+ with torch.no_grad():
310
+ predictions = run_model(target_dir)
311
+
312
+ # Save predictions
313
+ prediction_save_path = os.path.join(target_dir, "predictions.npz")
314
+ np.savez(prediction_save_path, **predictions)
315
+
316
+ # Handle None frame_filter
317
+ if frame_filter is None:
318
+ frame_filter = "All"
319
+
320
+ # Build a GLB file name
321
+ glbfile = os.path.join(
322
+ target_dir,
323
+ f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
324
+ )
325
+
326
+ # Convert predictions to GLB
327
+ glbscene = predictions_to_glb(
328
+ predictions,
329
+ conf_thres=conf_thres,
330
+ filter_by_frames=frame_filter,
331
+ mask_black_bg=mask_black_bg,
332
+ mask_white_bg=mask_white_bg,
333
+ show_cam=show_cam,
334
+ mask_sky=mask_sky,
335
+ target_dir=target_dir,
336
+ prediction_mode=prediction_mode,
337
+ )
338
+ glbscene.export(file_obj=glbfile)
339
+
340
+ # Cleanup
341
+ del predictions
342
+ gc.collect()
343
+ torch.cuda.empty_cache()
344
+
345
+ end_time = time.time()
346
+ print(f"Total time: {end_time - start_time:.2f} seconds")
347
+ log_msg = (
348
+ f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
349
+ )
350
+
351
+ return (
352
+ glbfile,
353
+ log_msg,
354
+ gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
355
+ )
356
+
357
+
358
+ # -------------------------------------------------------------------------
359
+ # 5) Helper functions for UI resets + re-visualization
360
+ # -------------------------------------------------------------------------
361
+ def clear_fields():
362
+ """
363
+ Clears the 3D viewer, the stored target_dir, and empties the gallery.
364
+ """
365
+ return None
366
+
367
+
368
+ def update_log():
369
+ """
370
+ Display a quick log message while waiting.
371
+ """
372
+ return "Loading and Reconstructing..."
373
+
374
+
375
+ def update_visualization(
376
+ target_dir,
377
+ conf_thres,
378
+ frame_filter,
379
+ mask_black_bg,
380
+ mask_white_bg,
381
+ show_cam,
382
+ mask_sky,
383
+ prediction_mode,
384
+ is_example,
385
+ ):
386
+ """
387
+ Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
388
+ and return it for the 3D viewer. If is_example == "True", skip.
389
+ """
390
+
391
+ # If it's an example click, skip as requested
392
+ if is_example == "True":
393
+ return (
394
+ None,
395
+ "No reconstruction available. Please click the Reconstruct button first.",
396
+ )
397
+
398
+ if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
399
+ return (
400
+ None,
401
+ "No reconstruction available. Please click the Reconstruct button first.",
402
+ )
403
+
404
+ predictions_path = os.path.join(target_dir, "predictions.npz")
405
+ if not os.path.exists(predictions_path):
406
+ return (
407
+ None,
408
+ f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
409
+ )
410
+
411
+ loaded = np.load(predictions_path, allow_pickle=True)
412
+ predictions = {key: loaded[key] for key in loaded.keys()}
413
+
414
+ glbfile = os.path.join(
415
+ target_dir,
416
+ f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
417
+ )
418
+
419
+ if not os.path.exists(glbfile):
420
+ glbscene = predictions_to_glb(
421
+ predictions,
422
+ conf_thres=conf_thres,
423
+ filter_by_frames=frame_filter,
424
+ mask_black_bg=mask_black_bg,
425
+ mask_white_bg=mask_white_bg,
426
+ show_cam=show_cam,
427
+ mask_sky=mask_sky,
428
+ target_dir=target_dir,
429
+ prediction_mode=prediction_mode,
430
+ )
431
+ glbscene.export(file_obj=glbfile)
432
+
433
+ return glbfile, "Updating Visualization"
434
+
435
+
436
+ # -------------------------------------------------------------------------
437
+ # Example images
438
+ # -------------------------------------------------------------------------
439
+
440
+ # Get the absolute directory of app.py to ensure correct resource paths
441
+ APP_DIR = os.path.dirname(os.path.abspath(__file__))
442
+
443
+ # Use absolute paths for all example videos to ensure they load correctly in containerized environments
444
+ # canyon_video = os.path.join(APP_DIR, "examples/videos/Studlagil_Canyon_East_Iceland.mp4")
445
+ great_wall_video = os.path.join(APP_DIR, "examples/videos/great_wall.mp4")
446
+ colosseum_video = os.path.join(APP_DIR, "examples/videos/Colosseum.mp4")
447
+ room_video = os.path.join(APP_DIR, "examples/videos/room.mp4")
448
+ kitchen_video = os.path.join(APP_DIR, "examples/videos/kitchen.mp4")
449
+ fern_video = os.path.join(APP_DIR, "examples/videos/fern.mp4")
450
+ single_cartoon_video = os.path.join(APP_DIR, "examples/videos/single_cartoon.mp4")
451
+ single_oil_painting_video = os.path.join(
452
+ APP_DIR, "examples/videos/single_oil_painting.mp4"
453
+ )
454
+ pyramid_video = os.path.join(APP_DIR, "examples/videos/pyramid.mp4")
455
+
456
+
457
+ # -------------------------------------------------------------------------
458
+ # 6) Build Gradio UI
459
+ # -------------------------------------------------------------------------
460
+ theme = gr.themes.Ocean()
461
+ theme.set(
462
+ checkbox_label_background_fill_selected="*button_primary_background_fill",
463
+ checkbox_label_text_color_selected="*button_primary_text_color",
464
+ )
465
+
466
+ with gr.Blocks(
467
+ theme=theme,
468
+ css="""
469
+ .custom-log * {
470
+ font-style: italic;
471
+ font-size: 22px !important;
472
+ background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
473
+ -webkit-background-clip: text;
474
+ background-clip: text;
475
+ font-weight: bold !important;
476
+ color: transparent !important;
477
+ text-align: center !important;
478
+ }
479
+
480
+ .example-log * {
481
+ font-style: italic;
482
+ font-size: 16px !important;
483
+ background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
484
+ -webkit-background-clip: text;
485
+ background-clip: text;
486
+ color: transparent !important;
487
+ }
488
+
489
+ #my_radio .wrap {
490
+ display: flex;
491
+ flex-wrap: nowrap;
492
+ justify-content: center;
493
+ align-items: center;
494
+ }
495
+
496
+ #my_radio .wrap label {
497
+ display: flex;
498
+ width: 50%;
499
+ justify-content: center;
500
+ align-items: center;
501
+ margin: 0;
502
+ padding: 10px 0;
503
+ box-sizing: border-box;
504
+ }
505
+ """,
506
+ ) as demo:
507
+
508
+ # Instead of gr.State, we use a hidden Textbox:
509
+ is_example = gr.Textbox(label="is_example", visible=False, value="None")
510
+ num_images = gr.Textbox(label="num_images", visible=False, value="None")
511
+
512
+ gr.HTML(
513
+ """
514
+ <h1>🏛️ VGGT: Visual Geometry Grounded Transformer</h1>
515
+ <p>
516
+ <a href="https://github.com/facebookresearch/vggt">🌟 GitHub Repository</a> |
517
+ <a href="https://vgg-t.github.io/">🚀 Project Page</a>
518
+ </p>
519
+
520
+ <div style="font-size: 16px; line-height: 1.5;">
521
+ <p>Upload a video or a set of images to create a 3D reconstruction of a scene or object. VGGT takes these images and generates all key 3D attributes, including extrinsic and intrinsic camera parameters, point maps, depth maps, and 3D point tracks.</p>
522
+
523
+ <h3>Getting Started:</h3>
524
+ <ol>
525
+ <li><strong>Upload Your Data:</strong> Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).</li>
526
+ <li><strong>Preview:</strong> Your uploaded images will appear in the gallery on the left.</li>
527
+ <li><strong>Reconstruct:</strong> Click the "Reconstruct" button to start the 3D reconstruction process.</li>
528
+ <li><strong>Visualize:</strong> The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file. Note the visualization of 3D points may be slow for a large number of input images.</li>
529
+ <li>
530
+ <strong>Adjust Visualization (Optional):</strong>
531
+ After reconstruction, you can fine-tune the visualization using the options below
532
+ <details style="display:inline;">
533
+ <summary style="display:inline;">(<strong>click to expand</strong>):</summary>
534
+ <ul>
535
+ <li><em>Confidence Threshold:</em> Adjust the filtering of points based on confidence.</li>
536
+ <li><em>Show Points from Frame:</em> Select specific frames to display in the point cloud.</li>
537
+ <li><em>Show Camera:</em> Toggle the display of estimated camera positions.</li>
538
+ <li><em>Filter Sky / Filter Black Background:</em> Remove sky or black-background points.</li>
539
+ <li><em>Select a Prediction Mode:</em> Choose between "Depthmap and Camera Branch" or "Pointmap Branch."</li>
540
+ </ul>
541
+ </details>
542
+ </li>
543
+ </ol>
544
+ <p><strong style="color: #0ea5e9;">Please note:</strong> <span style="color: #0ea5e9; font-weight: bold;">Our model itself usually only needs less than 1 second to reconstruct a scene. However, visualizing 3D points may take tens of seconds due to third-party rendering, which are independent of VGGT's processing time. Please be patient or, for faster visualization, use a local machine to run our demo from our <a href="https://github.com/facebookresearch/vggt">GitHub repository</a>. </span></p>
545
+ </div>
546
+ """
547
+ )
548
+
549
+ target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
550
+
551
+ with gr.Row():
552
+ with gr.Column(scale=2):
553
+ input_video = gr.Video(label="Upload Video", interactive=True)
554
+ input_images = gr.File(
555
+ file_count="multiple", label="Upload Images", interactive=True
556
+ )
557
+
558
+ image_gallery = gr.Gallery(
559
+ label="Preview",
560
+ columns=4,
561
+ height="300px",
562
+ object_fit="contain",
563
+ preview=True,
564
+ )
565
+
566
+ with gr.Column(scale=4):
567
+ with gr.Column():
568
+ gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses)**")
569
+ log_output = gr.Markdown(
570
+ "Please upload a video or images, then click Reconstruct.",
571
+ elem_classes=["custom-log"],
572
+ )
573
+ reconstruction_output = gr.Model3D(
574
+ height=520, zoom_speed=0.5, pan_speed=0.5
575
+ )
576
+
577
+ with gr.Row():
578
+ submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
579
+ clear_btn = gr.ClearButton(
580
+ [
581
+ input_video,
582
+ input_images,
583
+ reconstruction_output,
584
+ log_output,
585
+ target_dir_output,
586
+ image_gallery,
587
+ ],
588
+ scale=1,
589
+ )
590
+
591
+ with gr.Row():
592
+ prediction_mode = gr.Radio(
593
+ ["Depthmap and Camera Branch", "Pointmap Branch"],
594
+ label="Select a Prediction Mode",
595
+ value="Depthmap and Camera Branch",
596
+ scale=1,
597
+ elem_id="my_radio",
598
+ )
599
+
600
+ with gr.Row():
601
+ conf_thres = gr.Slider(
602
+ minimum=0,
603
+ maximum=100,
604
+ value=50,
605
+ step=0.1,
606
+ label="Confidence Threshold (%)",
607
+ )
608
+ frame_filter = gr.Dropdown(
609
+ choices=["All"], value="All", label="Show Points from Frame"
610
+ )
611
+ with gr.Column():
612
+ show_cam = gr.Checkbox(label="Show Camera", value=True)
613
+ mask_sky = gr.Checkbox(label="Filter Sky", value=False)
614
+ mask_black_bg = gr.Checkbox(
615
+ label="Filter Black Background", value=False
616
+ )
617
+ mask_white_bg = gr.Checkbox(
618
+ label="Filter White Background", value=False
619
+ )
620
+
621
+ # ---------------------- Examples section ----------------------
622
+ examples = [
623
+ [
624
+ colosseum_video,
625
+ "22",
626
+ None,
627
+ 20.0,
628
+ False,
629
+ False,
630
+ True,
631
+ False,
632
+ "Depthmap and Camera Branch",
633
+ "True",
634
+ ],
635
+ [
636
+ pyramid_video,
637
+ "30",
638
+ None,
639
+ 35.0,
640
+ False,
641
+ False,
642
+ True,
643
+ False,
644
+ "Depthmap and Camera Branch",
645
+ "True",
646
+ ],
647
+ [
648
+ single_cartoon_video,
649
+ "1",
650
+ None,
651
+ 15.0,
652
+ False,
653
+ False,
654
+ True,
655
+ False,
656
+ "Depthmap and Camera Branch",
657
+ "True",
658
+ ],
659
+ [
660
+ single_oil_painting_video,
661
+ "1",
662
+ None,
663
+ 20.0,
664
+ False,
665
+ False,
666
+ True,
667
+ True,
668
+ "Depthmap and Camera Branch",
669
+ "True",
670
+ ],
671
+ # [canyon_video, "14", None, 40.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
672
+ [
673
+ room_video,
674
+ "8",
675
+ None,
676
+ 5.0,
677
+ False,
678
+ False,
679
+ True,
680
+ False,
681
+ "Depthmap and Camera Branch",
682
+ "True",
683
+ ],
684
+ [
685
+ kitchen_video,
686
+ "25",
687
+ None,
688
+ 50.0,
689
+ False,
690
+ False,
691
+ True,
692
+ False,
693
+ "Depthmap and Camera Branch",
694
+ "True",
695
+ ],
696
+ [
697
+ fern_video,
698
+ "20",
699
+ None,
700
+ 45.0,
701
+ False,
702
+ False,
703
+ True,
704
+ False,
705
+ "Depthmap and Camera Branch",
706
+ "True",
707
+ ],
708
+ ]
709
+
710
+ def example_pipeline(
711
+ input_video,
712
+ num_images_str,
713
+ input_images,
714
+ conf_thres,
715
+ mask_black_bg,
716
+ mask_white_bg,
717
+ show_cam,
718
+ mask_sky,
719
+ prediction_mode,
720
+ is_example_str,
721
+ ):
722
+ """
723
+ 1) Copy example images to new target_dir
724
+ 2) Reconstruct
725
+ 3) Return model3D + logs + new_dir + updated dropdown + gallery
726
+ We do NOT return is_example. It's just an input.
727
+ """
728
+ target_dir, image_paths = handle_uploads(input_video, input_images)
729
+ # Always use "All" for frame_filter in examples
730
+ frame_filter = "All"
731
+ glbfile, log_msg, dropdown = gradio_demo(
732
+ target_dir,
733
+ conf_thres,
734
+ frame_filter,
735
+ mask_black_bg,
736
+ mask_white_bg,
737
+ show_cam,
738
+ mask_sky,
739
+ prediction_mode,
740
+ )
741
+ return glbfile, log_msg, target_dir, dropdown, image_paths
742
+
743
+ gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
744
+
745
+ gr.Examples(
746
+ examples=examples,
747
+ inputs=[
748
+ input_video,
749
+ num_images,
750
+ input_images,
751
+ conf_thres,
752
+ mask_black_bg,
753
+ mask_white_bg,
754
+ show_cam,
755
+ mask_sky,
756
+ prediction_mode,
757
+ is_example,
758
+ ],
759
+ outputs=[
760
+ reconstruction_output,
761
+ log_output,
762
+ target_dir_output,
763
+ frame_filter,
764
+ image_gallery,
765
+ ],
766
+ fn=example_pipeline,
767
+ cache_examples=False,
768
+ examples_per_page=50,
769
+ )
770
+
771
+ # -------------------------------------------------------------------------
772
+ # "Reconstruct" button logic:
773
+ # - Clear fields
774
+ # - Update log
775
+ # - gradio_demo(...) with the existing target_dir
776
+ # - Then set is_example = "False"
777
+ # -------------------------------------------------------------------------
778
+ submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
779
+ fn=update_log, inputs=[], outputs=[log_output]
780
+ ).then(
781
+ fn=gradio_demo,
782
+ inputs=[
783
+ target_dir_output,
784
+ conf_thres,
785
+ frame_filter,
786
+ mask_black_bg,
787
+ mask_white_bg,
788
+ show_cam,
789
+ mask_sky,
790
+ prediction_mode,
791
+ ],
792
+ outputs=[reconstruction_output, log_output, frame_filter],
793
+ ).then(
794
+ fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False"
795
+ )
796
+
797
+ # -------------------------------------------------------------------------
798
+ # Real-time Visualization Updates
799
+ # -------------------------------------------------------------------------
800
+ conf_thres.change(
801
+ update_visualization,
802
+ [
803
+ target_dir_output,
804
+ conf_thres,
805
+ frame_filter,
806
+ mask_black_bg,
807
+ mask_white_bg,
808
+ show_cam,
809
+ mask_sky,
810
+ prediction_mode,
811
+ is_example,
812
+ ],
813
+ [reconstruction_output, log_output],
814
+ )
815
+ frame_filter.change(
816
+ update_visualization,
817
+ [
818
+ target_dir_output,
819
+ conf_thres,
820
+ frame_filter,
821
+ mask_black_bg,
822
+ mask_white_bg,
823
+ show_cam,
824
+ mask_sky,
825
+ prediction_mode,
826
+ is_example,
827
+ ],
828
+ [reconstruction_output, log_output],
829
+ )
830
+ mask_black_bg.change(
831
+ update_visualization,
832
+ [
833
+ target_dir_output,
834
+ conf_thres,
835
+ frame_filter,
836
+ mask_black_bg,
837
+ mask_white_bg,
838
+ show_cam,
839
+ mask_sky,
840
+ prediction_mode,
841
+ is_example,
842
+ ],
843
+ [reconstruction_output, log_output],
844
+ )
845
+ mask_white_bg.change(
846
+ update_visualization,
847
+ [
848
+ target_dir_output,
849
+ conf_thres,
850
+ frame_filter,
851
+ mask_black_bg,
852
+ mask_white_bg,
853
+ show_cam,
854
+ mask_sky,
855
+ prediction_mode,
856
+ is_example,
857
+ ],
858
+ [reconstruction_output, log_output],
859
+ )
860
+ show_cam.change(
861
+ update_visualization,
862
+ [
863
+ target_dir_output,
864
+ conf_thres,
865
+ frame_filter,
866
+ mask_black_bg,
867
+ mask_white_bg,
868
+ show_cam,
869
+ mask_sky,
870
+ prediction_mode,
871
+ is_example,
872
+ ],
873
+ [reconstruction_output, log_output],
874
+ )
875
+ mask_sky.change(
876
+ update_visualization,
877
+ [
878
+ target_dir_output,
879
+ conf_thres,
880
+ frame_filter,
881
+ mask_black_bg,
882
+ mask_white_bg,
883
+ show_cam,
884
+ mask_sky,
885
+ prediction_mode,
886
+ is_example,
887
+ ],
888
+ [reconstruction_output, log_output],
889
+ )
890
+ prediction_mode.change(
891
+ update_visualization,
892
+ [
893
+ target_dir_output,
894
+ conf_thres,
895
+ frame_filter,
896
+ mask_black_bg,
897
+ mask_white_bg,
898
+ show_cam,
899
+ mask_sky,
900
+ prediction_mode,
901
+ is_example,
902
+ ],
903
+ [reconstruction_output, log_output],
904
+ )
905
+
906
+ # -------------------------------------------------------------------------
907
+ # Auto-update gallery whenever user uploads or changes their files
908
+ # -------------------------------------------------------------------------
909
+ input_video.change(
910
+ fn=update_gallery_on_upload,
911
+ inputs=[input_video, input_images],
912
+ outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
913
+ )
914
+ input_images.change(
915
+ fn=update_gallery_on_upload,
916
+ inputs=[input_video, input_images],
917
+ outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
918
+ )
919
+
920
+ demo.queue(max_size=20).launch(show_error=True, share=True, ssr_mode=False)