# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import cv2
import numpy as np
import gradio as gr
import sys
import shutil
from datetime import datetime
import glob
import gc
import time
import spaces
import requests
import websocket
import json
import uuid
import base64
import io

sys.path.append("vggt/")

from visual_util import predictions_to_glb

# Remote VGGT service host (overridable via environment for deployment).
VGGT_HOST = os.getenv("VGGT_HOST", "134.199.133.34")

# No longer loading model locally
# model = None


# -------------------------------------------------------------------------
# Remote service communication functions
# -------------------------------------------------------------------------
def _open_ws(client_id: str, token: str):
    """Open a WebSocket connection to the remote VGGT service.

    Args:
        client_id: Unique ID identifying this client session.
        token: Auth token passed through to the service.

    Returns:
        A connected ``websocket.WebSocket`` instance.
    """
    ws = websocket.WebSocket()
    ws.connect(f"ws://{VGGT_HOST}/ws?clientId={client_id}&token={token}", timeout=1800)
    return ws


def _submit_inference(target_dir: str, client_id: str, token: str) -> str:
    """Submit an inference job to the remote VGGT service.

    All images under ``target_dir/images`` are base64-encoded and POSTed
    as a single JSON payload.

    Args:
        target_dir: Directory containing the ``images`` subfolder.
        client_id: Client session ID (matches the WebSocket connection).
        token: Auth token for the service.

    Returns:
        The job ID assigned by the service.

    Raises:
        ValueError: If no images are found in the upload directory.
        RuntimeError: If the service rejects the request or returns no job id.
    """
    image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
    if not image_names:
        raise ValueError("No images found. Check your upload.")

    # Encode each image as base64 so it can travel inside the JSON body.
    images_data = []
    for img_path in image_names:
        with open(img_path, "rb") as f:
            img_bytes = f.read()
        images_data.append(
            {
                "filename": os.path.basename(img_path),
                "data": base64.b64encode(img_bytes).decode("utf-8"),
            }
        )

    payload = {"images": images_data, "client_id": client_id}
    resp = requests.post(
        f"http://{VGGT_HOST}/inference?token={token}", json=payload, timeout=1800
    )
    if resp.status_code != 200:
        raise RuntimeError(f"VGGT service /inference err: {resp.text}")
    data = resp.json()
    if "job_id" not in data:
        raise RuntimeError(f"/inference no job_id: {data}")
    return data["job_id"]


def _get_result(job_id: str, token: str) -> dict:
    """Fetch the finished inference result for *job_id* from the service.

    Returns the per-job payload (empty dict if the service response does
    not contain an entry for this job).
    """
    resp = requests.get(
        f"http://{VGGT_HOST}/result/{job_id}?token={token}", timeout=1800
    )
    resp.raise_for_status()
    result = resp.json()
    return result.get(job_id, {})


# -------------------------------------------------------------------------
# 1) Core model inference (now forwards to remote service)
# -------------------------------------------------------------------------
def run_model(target_dir, model=None) -> dict:
    """
    Run the VGGT model on images in the 'target_dir/images' folder and return predictions.
    Now forwards to remote VGGT service instead of running locally; progress is
    streamed over a WebSocket and the final arrays are fetched via HTTP.

    Args:
        target_dir: Directory containing the ``images`` subfolder.
        model: Unused; kept for backward compatibility with the old local API.

    Returns:
        Dict mapping prediction names to numpy arrays.

    Raises:
        ValueError: If no images are found.
        RuntimeError: If the service result contains no predictions.
    """
    print(f"Processing images from {target_dir}")

    # Load image names for validation
    image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
    print(f"Found {len(image_names)} images")
    if not image_names:
        raise ValueError("No images found. Check your upload.")

    # Generate client ID and token
    client_id = str(uuid.uuid4())
    token = str(uuid.uuid4())

    # Open WebSocket for progress updates
    print("Connecting to remote VGGT service...")
    ws = _open_ws(client_id, token)
    # Fix: close the socket even if submission or monitoring raises
    # (previously it leaked on any exception before ws.close()).
    try:
        print("Submitting inference job...")
        job_id = _submit_inference(target_dir, client_id, token)

        # Monitor progress via WebSocket
        print("Monitoring inference progress...")
        ws.settimeout(180)
        while True:
            try:
                out = ws.recv()
                if isinstance(out, (bytes, bytearray)):
                    continue
                msg = json.loads(out)
                if msg.get("type") == "executing":
                    data = msg.get("data", {})
                    if data.get("job_id") != job_id:
                        continue
                    node = data.get("node")
                    if node is None:
                        # A null node signals job completion.
                        break
                    print(f"Processing node: {node}")
            except Exception as e:
                # Best-effort monitoring: on socket errors fall through and
                # try to fetch the result anyway.
                print(f"WebSocket error: {e}")
                break
    finally:
        ws.close()

    # Get final result
    print("Retrieving inference results...")
    result = _get_result(job_id, token)
    if "predictions" not in result:
        raise RuntimeError(f"No predictions in result: {result}")

    # Deserialize predictions: string values are base64-encoded .npy blobs,
    # anything else is converted to an ndarray directly.
    predictions = {}
    for key, value in result["predictions"].items():
        if isinstance(value, str):
            arr_bytes = base64.b64decode(value)
            predictions[key] = np.load(io.BytesIO(arr_bytes), allow_pickle=True)
        else:
            predictions[key] = np.array(value)

    # Backend has already computed world_points_from_depth
    print("Received predictions from backend")
    return predictions


# -------------------------------------------------------------------------
# 2) Handle uploaded video/images --> produce target_dir + images
# -------------------------------------------------------------------------
def handle_uploads(input_video, input_images):
    """
    Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
    images or extracted frames from video into it. Return (target_dir, image_paths).
    """
    start_time = time.time()
    gc.collect()

    # Create a unique folder name
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    target_dir_images = os.path.join(target_dir, "images")

    # Clean up if somehow that folder already exists
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []

    # --- Handle images ---
    if input_images is not None:
        for file_data in input_images:
            # Gradio may hand us either a dict with a "name" key or a bare path.
            if isinstance(file_data, dict) and "name" in file_data:
                file_path = file_data["name"]
            else:
                file_path = file_data
            dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
            shutil.copy(file_path, dst_path)
            image_paths.append(dst_path)

    # --- Handle video ---
    if input_video is not None:
        if isinstance(input_video, dict) and "name" in input_video:
            video_path = input_video["name"]
        else:
            video_path = input_video

        vs = cv2.VideoCapture(video_path)
        fps = vs.get(cv2.CAP_PROP_FPS)
        # 1 frame/sec. Fix: guard against fps <= 0 (broken video metadata),
        # which previously caused ZeroDivisionError in the modulo below.
        frame_interval = max(int(fps * 1), 1)
        count = 0
        video_frame_num = 0
        while True:
            gotit, frame = vs.read()
            if not gotit:
                break
            count += 1
            if count % frame_interval == 0:
                image_path = os.path.join(
                    target_dir_images, f"{video_frame_num:06}.png"
                )
                cv2.imwrite(image_path, frame)
                image_paths.append(image_path)
                video_frame_num += 1
        # Fix: release the capture handle (was leaked before).
        vs.release()

    # Sort final images for gallery
    image_paths = sorted(image_paths)

    end_time = time.time()
    print(
        f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds"
    )
    return target_dir, image_paths


# -------------------------------------------------------------------------
# 3) Update gallery on upload
# -------------------------------------------------------------------------
def update_gallery_on_upload(input_video, input_images):
    """
    Whenever user uploads or changes files, immediately handle them
    and show in the gallery. Return (viewer, target_dir, image_paths, log).
    If nothing is uploaded, returns all Nones.
    """
    if not input_video and not input_images:
        return None, None, None, None
    target_dir, image_paths = handle_uploads(input_video, input_images)
    return (
        None,
        target_dir,
        image_paths,
        "Upload complete. Click 'Reconstruct' to begin 3D processing.",
    )


# -------------------------------------------------------------------------
# 4) Reconstruction: uses the target_dir plus any viz parameters
# -------------------------------------------------------------------------
def gradio_demo(
    target_dir,
    conf_thres=3.0,
    frame_filter="All",
    mask_black_bg=False,
    mask_white_bg=False,
    show_cam=True,
    mask_sky=False,
    prediction_mode="Pointmap Regression",
):
    """
    Perform reconstruction using the already-created target_dir/images.

    Returns (glb_path, log_message, frame_filter_dropdown).
    """
    if not os.path.isdir(target_dir) or target_dir == "None":
        return None, "No valid target directory found. Please upload first.", None, None

    start_time = time.time()
    gc.collect()

    # Prepare frame_filter dropdown: label each frame as "index: filename".
    target_dir_images = os.path.join(target_dir, "images")
    all_files = (
        sorted(os.listdir(target_dir_images))
        if os.path.isdir(target_dir_images)
        else []
    )
    # Fix: the label previously discarded the filename entirely.
    all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
    frame_filter_choices = ["All"] + all_files

    print("Running run_model...")
    predictions = run_model(target_dir)

    # Save predictions so update_visualization can reload them later.
    prediction_save_path = os.path.join(target_dir, "predictions.npz")
    np.savez(prediction_save_path, **predictions)

    # Handle None frame_filter
    if frame_filter is None:
        frame_filter = "All"

    # Build a GLB file name that encodes all visualization parameters, so
    # each parameter combination gets its own cached file.
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
    )

    # Convert predictions to GLB
    glbscene = predictions_to_glb(
        predictions,
        conf_thres=conf_thres,
        filter_by_frames=frame_filter,
        mask_black_bg=mask_black_bg,
        mask_white_bg=mask_white_bg,
        show_cam=show_cam,
        mask_sky=mask_sky,
        target_dir=target_dir,
        prediction_mode=prediction_mode,
    )
    glbscene.export(file_obj=glbfile)

    # Cleanup
    del predictions
    gc.collect()

    end_time = time.time()
    print(f"Total time: {end_time - start_time:.2f} seconds")
    log_msg = (
        f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
    )

    return (
        glbfile,
        log_msg,
        gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
    )


# -------------------------------------------------------------------------
# 5) Helper functions for UI resets + re-visualization
# -------------------------------------------------------------------------
def clear_fields():
    """
    Clears the 3D viewer, the stored target_dir, and empties the gallery.
    """
    return None


def update_log():
    """
    Display a quick log message while waiting.
    """
    return "Loading and Reconstructing..."


def update_visualization(
    target_dir,
    conf_thres,
    frame_filter,
    mask_black_bg,
    mask_white_bg,
    show_cam,
    mask_sky,
    prediction_mode,
    is_example,
):
    """
    Reload saved predictions from npz, create (or reuse) the GLB for new
    parameters, and return it for the 3D viewer. If is_example == "True", skip.
    """
    # If it's an example click, skip as requested
    if is_example == "True":
        return (
            None,
            "No reconstruction available. Please click the Reconstruct button first.",
        )

    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return (
            None,
            "No reconstruction available. Please click the Reconstruct button first.",
        )

    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return (
            None,
            f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
        )

    loaded = np.load(predictions_path, allow_pickle=True)
    predictions = {key: loaded[key] for key in loaded.keys()}

    # Same parameter-encoded name as gradio_demo, so an existing GLB for
    # this combination is reused instead of re-exported.
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
    )

    if not os.path.exists(glbfile):
        glbscene = predictions_to_glb(
            predictions,
            conf_thres=conf_thres,
            filter_by_frames=frame_filter,
            mask_black_bg=mask_black_bg,
            mask_white_bg=mask_white_bg,
            show_cam=show_cam,
            mask_sky=mask_sky,
            target_dir=target_dir,
            prediction_mode=prediction_mode,
        )
        glbscene.export(file_obj=glbfile)

    return glbfile, "Updating Visualization"


# -------------------------------------------------------------------------
# Example images
# -------------------------------------------------------------------------

# Get the absolute directory of app.py to ensure correct resource paths
APP_DIR = os.path.dirname(os.path.abspath(__file__))

# Use absolute paths for all example videos to ensure they load correctly in containerized environments
# canyon_video = os.path.join(APP_DIR, "examples/videos/Studlagil_Canyon_East_Iceland.mp4")
great_wall_video = os.path.join(APP_DIR, "examples/videos/great_wall.mp4")
colosseum_video = os.path.join(APP_DIR, "examples/videos/Colosseum.mp4")
room_video = os.path.join(APP_DIR, "examples/videos/room.mp4")
kitchen_video = os.path.join(APP_DIR, "examples/videos/kitchen.mp4")
fern_video = os.path.join(APP_DIR, "examples/videos/fern.mp4")
single_cartoon_video = os.path.join(APP_DIR, "examples/videos/single_cartoon.mp4")
single_oil_painting_video = os.path.join(
    APP_DIR, "examples/videos/single_oil_painting.mp4"
)
pyramid_video = os.path.join(APP_DIR, "examples/videos/pyramid.mp4")

# -------------------------------------------------------------------------
# 6) Build Gradio UI
#
------------------------------------------------------------------------- theme = gr.themes.Ocean() theme.set( checkbox_label_background_fill_selected="*button_primary_background_fill", checkbox_label_text_color_selected="*button_primary_text_color", ) with gr.Blocks( theme=theme, css=""" .custom-log * { font-style: italic; font-size: 22px !important; background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%); -webkit-background-clip: text; background-clip: text; font-weight: bold !important; color: transparent !important; text-align: center !important; } .example-log * { font-style: italic; font-size: 16px !important; background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%); -webkit-background-clip: text; background-clip: text; color: transparent !important; } #my_radio .wrap { display: flex; flex-wrap: nowrap; justify-content: center; align-items: center; } #my_radio .wrap label { display: flex; width: 50%; justify-content: center; align-items: center; margin: 0; padding: 10px 0; box-sizing: border-box; } """, ) as demo: # Instead of gr.State, we use a hidden Textbox: is_example = gr.Textbox(label="is_example", visible=False, value="None") num_images = gr.Textbox(label="num_images", visible=False, value="None") gr.HTML( """
🌟 GitHub Repository | 🚀 Project Page
Upload a video or a set of images to create a 3D reconstruction of a scene or object. VGGT takes these images and generates all key 3D attributes, including extrinsic and intrinsic camera parameters, point maps, depth maps, and 3D point tracks.
Please note: Our model itself usually only needs less than 1 second to reconstruct a scene. However, visualizing 3D points may take tens of seconds due to third-party rendering, which are independent of VGGT's processing time. Please be patient or, for faster visualization, use a local machine to run our demo from our GitHub repository.