""" DPM-Splat: End-to-end pipeline for Video → 4D Gaussian Splats Combines VDPM inference with Dynamic 4DGS training in a single Gradio interface. """ import os import sys import shutil import zipfile import gc import json import glob import time from pathlib import Path from datetime import datetime import cv2 import numpy as np import gradio as gr import torch import imageio # Set memory optimization os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # Add paths sys.path.insert(0, str(Path(__file__).parent / "vdpm")) sys.path.insert(0, str(Path(__file__).parent / "gs")) # Import depth utilities from vdpm.util.depth import write_depth_to_png # Check GPU availability device = "cuda" if torch.cuda.is_available() else "cpu" if device == "cuda": torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True gpu_name = torch.cuda.get_device_name(0) gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3) print(f"✓ GPU: {gpu_name} ({gpu_mem:.1f} GB)") else: print("⚠ No GPU detected - running on CPU (will be slow)") # Configuration VIDEO_SAMPLE_HZ = 1.0 # Set MAX_FRAMES based on VRAM if device == "cuda": if gpu_mem >= 20.0: MAX_FRAMES = 32 elif gpu_mem >= 12.0: MAX_FRAMES = 16 else: MAX_FRAMES = 8 else: MAX_FRAMES = 4 print(f"✓ Configured limit: {MAX_FRAMES} frames based on {gpu_mem:.1f} GB VRAM") # Global model cache _vdpm_model = None def get_vdpm_model(): """Load and cache the VDPM model""" global _vdpm_model if _vdpm_model is not None: print("✓ Using cached VDPM model") return _vdpm_model print("Loading VDPM model...") sys.stdout.flush() from hydra import compose, initialize from hydra.core.global_hydra import GlobalHydra from dpm.model import VDPM if GlobalHydra.instance().is_initialized(): GlobalHydra.instance().clear() with initialize(config_path="vdpm/configs"): cfg = compose(config_name="visualise") model = VDPM(cfg).to(device) # Load weights cache_dir = os.path.expanduser("~/.cache/vdpm") os.makedirs(cache_dir, exist_ok=True) model_path = os.path.join(cache_dir, "vdpm_model.pt") _URL = "https://huggingface.co/edgarsucar/vdpm/resolve/main/model.pt" if not os.path.exists(model_path): print(f"Downloading VDPM model...") sd = torch.hub.load_state_dict_from_url(_URL, file_name="vdpm_model.pt", progress=True, map_location=device) torch.save(sd, model_path) else: print(f"✓ Loading cached model from {model_path}") sd = torch.load(model_path, map_location=device) model.load_state_dict(sd, strict=True) model.eval() # Use half precision if device == "cuda": if torch.cuda.get_device_capability()[0] >= 8: model = model.to(torch.bfloat16) print("✓ Using BF16 precision") else: model = model.half() print("✓ Using FP16 precision") _vdpm_model = model return model def process_videos(video_files, target_dir): """Extract and interleave frames from uploaded videos""" images_dir = target_dir / "images" images_dir.mkdir(parents=True, exist_ok=True) num_views = len(video_files) captures = [] intervals = [] for vid_obj in video_files: video_path = vid_obj.name if hasattr(vid_obj, 'name') else str(vid_obj) vs = cv2.VideoCapture(video_path) fps = float(vs.get(cv2.CAP_PROP_FPS) or 30.0) interval = max(int(fps / max(VIDEO_SAMPLE_HZ, 1e-6)), 1) captures.append(vs) intervals.append(interval) # Interleave frames: [cam0_t0, cam1_t0, cam0_t1, cam1_t1, ...] frame_num = 0 step_count = 0 active = True image_paths = [] while active: active = False for i, vs in enumerate(captures): if not vs.isOpened(): continue ret, frame = vs.read() if ret: active = True if step_count % intervals[i] == 0: out_path = images_dir / f"{frame_num:06d}.png" cv2.imwrite(str(out_path), frame) image_paths.append(str(out_path)) frame_num += 1 else: vs.release() step_count += 1 for vs in captures: if vs.isOpened(): vs.release() # Save metadata meta = {"num_views": num_views} with open(target_dir / "meta.json", "w") as f: json.dump(meta, f) return image_paths, num_views def decode_poses(pose_enc: np.ndarray, image_hw: tuple) -> tuple: """Decode VGGT pose encodings to camera matrices.""" try: from vggt.utils.pose_enc import pose_encoding_to_extri_intri pose_enc_t = torch.from_numpy(pose_enc).float() extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc_t, image_hw) extrinsic = extrinsic[0].numpy() # (N, 3, 4) intrinsic = intrinsic[0].numpy() # (N, 3, 3) N = extrinsic.shape[0] bottom = np.array([0, 0, 0, 1], dtype=np.float32).reshape(1, 1, 4) bottom = np.tile(bottom, (N, 1, 1)) extrinsics_4x4 = np.concatenate([extrinsic, bottom], axis=1) return extrinsics_4x4, intrinsic except ImportError: print("Warning: vggt not available. Using identity poses.") N = pose_enc.shape[1] extrinsics = np.tile(np.eye(4, dtype=np.float32), (N, 1, 1)) H, W = image_hw fx = fy = max(H, W) cx, cy = W / 2, H / 2 intrinsic = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) intrinsics = np.tile(intrinsic, (N, 1, 1)) return extrinsics, intrinsics def compute_depths(world_points: np.ndarray, extrinsics: np.ndarray, num_views: int) -> np.ndarray: """ Compute depth maps from world points and camera extrinsics. Args: world_points: (T, V, H, W, 3) world-space 3D points extrinsics: (N, 4, 4) camera extrinsics (world-to-camera) num_views: Number of camera views Returns: depths: (T, V, H, W) depth maps (Z in camera coordinates) """ T, V, H, W, _ = world_points.shape depths = np.zeros((T, V, H, W), dtype=np.float32) for t in range(T): for v in range(V): # Get camera extrinsic for this view at this timestep img_idx = t * num_views + v if img_idx >= len(extrinsics): img_idx = v # Fallback to first timestep's cameras w2c = extrinsics[img_idx] # (4, 4) R = w2c[:3, :3] # (3, 3) t_vec = w2c[:3, 3] # (3,) # Transform world points to camera coordinates pts_world = world_points[t, v].reshape(-1, 3) # (H*W, 3) pts_cam = (R @ pts_world.T).T + t_vec # (H*W, 3) # Depth is Z in camera coordinates depth = pts_cam[:, 2].reshape(H, W) depths[t, v] = depth return depths def run_vdpm_inference(target_dir, progress): """Run VDPM inference and save outputs in output_4d.npz format""" from vggt.utils.load_fn import load_and_preprocess_images model = get_vdpm_model() image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*"))) if not image_names: raise ValueError("No images found") # Load metadata meta_path = target_dir / "meta.json" num_views = 1 if meta_path.exists(): with open(meta_path) as f: num_views = json.load(f).get("num_views", 1) # Limit frames if len(image_names) > MAX_FRAMES: limit = (MAX_FRAMES // num_views) * num_views if limit == 0: limit = num_views print(f"⚠ Limiting to {limit} frames") image_names = image_names[:limit] progress(0.15, desc=f"Loading {len(image_names)} images...") images = load_and_preprocess_images(image_names).to(device) # Store original images for visualization images_np = images.cpu().numpy() # (S, 3, H, W) # Construct views views = [] for i in range(len(image_names)): t_idx = i // num_views cam_idx = i % num_views views.append({ "img": images[i].unsqueeze(0), "view_idxs": torch.tensor([[cam_idx, t_idx]], device=device, dtype=torch.long) }) progress(0.2, desc="Running VDPM forward pass...") print(f"Running inference on {len(image_names)} images...") sys.stdout.flush() with torch.no_grad(): with torch.amp.autocast('cuda'): predictions = model.inference(views=views) # Extract results pts_list = [pm["pts3d"].detach().cpu().numpy() for pm in predictions["pointmaps"]] conf_list = [pm["conf"].detach().cpu().numpy() for pm in predictions["pointmaps"]] pose_enc = None if "pose_enc" in predictions: pose_enc = predictions["pose_enc"].detach().cpu().numpy() del predictions torch.cuda.empty_cache() world_points_raw = np.concatenate(pts_list, axis=0) # (T, S, H, W, 3) world_points_conf_raw = np.concatenate(conf_list, axis=0) # (T, S, H, W) T = world_points_raw.shape[0] S = world_points_raw.shape[1] H, W = world_points_raw.shape[2:4] num_timesteps = S // num_views print(f"VDPM output shape: T={T}, S={S}, num_views={num_views}") progress(0.3, desc="Processing VDPM outputs...") # ======================================================================== # Extract diagonal entries for 4DGS (each image at its natural timestep) # Format: (num_timesteps, num_views*H*W, 3) flattened for train_dynamic.py # ======================================================================== world_points_4d = [] world_points_conf_4d = [] images_4d = [] for t in range(num_timesteps): # Collect all views for this timestep pts_t = [] conf_t = [] imgs_t = [] for v in range(num_views): img_idx = t * num_views + v if img_idx >= S: break # Use the pointmap at timestep query = img_idx (diagonal) # VDPM outputs: world_points_raw[query_t, input_img_idx, H, W, 3] # We want the point where query_t == input_img_idx for single-view consistency query_t = min(img_idx, T - 1) pts_v = world_points_raw[query_t, img_idx] # (H, W, 3) conf_v = world_points_conf_raw[query_t, img_idx] # (H, W) img_v = images_np[img_idx] # (3, H, W) pts_t.append(pts_v.reshape(-1, 3)) # (H*W, 3) conf_t.append(conf_v.reshape(-1)) # (H*W,) imgs_t.append(img_v) if pts_t: # Concatenate all views: (V*H*W, 3) world_points_4d.append(np.concatenate(pts_t, axis=0)) world_points_conf_4d.append(np.concatenate(conf_t, axis=0)) # Stack images: (V, 3, H, W) -> average to (3, H, W) for visualization # Or just use first view images_4d.append(imgs_t[0]) world_points_4d = np.stack(world_points_4d, axis=0) # (T, N, 3) where N = V*H*W world_points_conf_4d = np.stack(world_points_conf_4d, axis=0) # (T, N) images_4d = np.stack(images_4d, axis=0) # (T, 3, H, W) print(f"4DGS format: world_points={world_points_4d.shape}, images={images_4d.shape}") progress(0.35, desc="Saving outputs...") # Save in output_4d.npz format (compatible with train_dynamic.py) np.savez_compressed( target_dir / "output_4d.npz", world_points=world_points_4d, world_points_conf=world_points_conf_4d, images=images_4d, num_views=num_views, num_timesteps=num_timesteps ) if pose_enc is not None: np.savez_compressed(target_dir / "poses.npz", pose_enc=pose_enc) # ======================================================================== # COMPUTE AND SAVE DEPTHS # ======================================================================== if pose_enc is not None: print("Computing depth maps...") # Reshape for depth computation: (T, V, H, W, 3) world_points_for_depth = world_points_4d.reshape(num_timesteps, num_views, H, W, 3) extrinsics, intrinsics = decode_poses(pose_enc, (H, W)) depths = compute_depths(world_points_for_depth, extrinsics, num_views) # Save depths np.savez_compressed( target_dir / "depths.npz", depths=depths, num_views=num_views, num_timesteps=num_timesteps ) # Save depth images depths_dir = target_dir / "depths" depths_dir.mkdir(exist_ok=True) for t in range(depths.shape[0]): for v in range(depths.shape[1]): png_path = depths_dir / f"depth_t{t:04d}_v{v:02d}.png" write_depth_to_png(str(png_path), depths[t, v]) print(f"✓ Saved {depths.shape[0] * depths.shape[1]} depth images") print(f"✓ VDPM complete: {num_timesteps} timesteps, {num_views} views") sys.stdout.flush() return num_timesteps, num_views def run_4dgs_training(target_dir, output_dir, initial_iterations, subsequent_iterations, conf_threshold, progress): """Run Dynamic 4D Gaussian Splatting training""" import warp as wp from gs.train_dynamic import load_dynamic_data, DynamicGaussianTrainer wp.init() print(f"\n{'='*50}") print("[DYNAMIC 4D GAUSSIANS TRAINING]") print(f"Frame 0: {initial_iterations} iterations") print(f"Frames 1+: {subsequent_iterations} iterations each") print(f"{'='*50}") sys.stdout.flush() data = load_dynamic_data(str(target_dir)) output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) # Create and run trainer trainer = DynamicGaussianTrainer( data=data, output_path=str(output_path), conf_threshold=conf_threshold, initial_iterations=initial_iterations, subsequent_iterations=subsequent_iterations, simultaneous_mode=False, ) def progress_callback(frac, desc): progress(0.4 + 0.5 * frac, desc=desc) trainer.train_sequential(progress_callback=progress_callback) # Return paths to outputs npz_path = output_path / "dynamic_gaussians.npz" mp4_path = output_path / "animation.mp4" gif_path = output_path / "animation.gif" print(f"✓ 4DGS training complete: {trainer.num_timesteps} frames, {trainer.num_points} Gaussians") sys.stdout.flush() return { 'npz_path': str(npz_path) if npz_path.exists() else None, 'mp4_path': str(mp4_path) if mp4_path.exists() else None, 'gif_path': str(gif_path) if gif_path.exists() else None, 'num_frames': trainer.num_timesteps, 'num_points': trainer.num_points, } def run_pipeline(video_files, initial_iterations, subsequent_iterations, conf_threshold, progress=gr.Progress()): """Run the full VDPM → 4DGS pipeline""" if not video_files: return None, None, None, "❌ Please upload video file(s)" gc.collect() if device == "cuda": torch.cuda.empty_cache() timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") run_dir = Path(f"output/pipeline/run_{timestamp}") run_dir.mkdir(parents=True, exist_ok=True) try: # Step 1: Process videos progress(0.05, desc="Processing uploaded videos...") print("=" * 50) print("Processing Videos") print("=" * 50) sys.stdout.flush() image_paths, num_views = process_videos(video_files, run_dir) print(f"✓ Extracted {len(image_paths)} frames from {num_views} videos") sys.stdout.flush() # Step 2: VDPM inference progress(0.1, desc="Running VDPM inference...") print("=" * 50) print("Running VDPM Inference") print("=" * 50) sys.stdout.flush() num_timesteps, num_views = run_vdpm_inference(run_dir, progress) # Clear VRAM before 4DGS training global _vdpm_model _vdpm_model = None gc.collect() if device == "cuda": torch.cuda.empty_cache() print(f"✓ Cleared VRAM: {torch.cuda.memory_allocated()/1024**3:.2f} GB in use") sys.stdout.flush() # Step 3: 4DGS training progress(0.4, desc="Training 4D Gaussian Splats...") print("=" * 50) print("Training 4D Gaussian Splats") print("=" * 50) sys.stdout.flush() splat_dir = run_dir / "splats" results = run_4dgs_training( run_dir, splat_dir, int(initial_iterations), int(subsequent_iterations), float(conf_threshold), progress ) # Step 4: Package results progress(0.95, desc="Packaging results...") zip_path = run_dir / "results.zip" with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf: # Add main outputs if results['npz_path'] and Path(results['npz_path']).exists(): zf.write(results['npz_path'], "dynamic_gaussians.npz") if results['mp4_path'] and Path(results['mp4_path']).exists(): zf.write(results['mp4_path'], "animation.mp4") if results['gif_path'] and Path(results['gif_path']).exists(): zf.write(results['gif_path'], "animation.gif") # Add frame renders for frame_dir in splat_dir.glob("frame_*"): for subdir in ["renders", "training_renders"]: render_dir = frame_dir / subdir if render_dir.exists(): for img in render_dir.glob("*.png"): rel_path = img.relative_to(splat_dir) zf.write(img, f"renders/{rel_path}") # Add VDPM data for f in ["output_4d.npz", "poses.npz", "depths.npz", "meta.json"]: fp = run_dir / f if fp.exists(): zf.write(fp, f) # Add input images images_dir = run_dir / "images" if images_dir.exists(): for img in sorted(images_dir.glob("*"))[:20]: # Limit to first 20 zf.write(img, f"images/{img.name}") # Add depth images depths_dir = run_dir / "depths" if depths_dir.exists(): for img in sorted(depths_dir.glob("*.png"))[:20]: zf.write(img, f"depths/{img.name}") progress(1.0, desc="Complete!") status = f"""✅ Pipeline Complete! 📊 Results: • {results['num_frames']} timesteps × {num_views} views • {results['num_points']:,} Gaussians • Animation: {'✓' if results['mp4_path'] else '✗'} 📁 Output: {run_dir} 📦 Download the ZIP for all files""" # Return video for preview video_path = results.get('mp4_path') return str(zip_path), video_path, status except Exception as e: import traceback traceback.print_exc() return None, None, f"❌ Error: {str(e)}" # ===== Gradio Interface ===== with gr.Blocks(title="DPM-Splat: 4D Gaussian Splatting", theme=gr.themes.Soft()) as app: gr.Markdown(""" # 🎬 DPM-Splat: Video → 4D Gaussian Splats End-to-end pipeline combining **V-DPM** (Video Dynamic Point Maps) with **4D Gaussian Splatting**. Upload synchronized videos to generate temporally consistent 4D reconstructions with time-varying Gaussians. """) with gr.Row(): with gr.Column(scale=1): video_input = gr.File( label="📹 Upload Videos", file_count="multiple", file_types=[".mp4", ".mov", ".avi", ".webm"] ) gr.Markdown("*Upload 1-4 synchronized video files for multi-view reconstruction*") with gr.Accordion("⚙️ Training Settings", open=True): initial_iterations = gr.Slider( minimum=100, maximum=10000, value=3000, step=100, label="Frame 0 Iterations", info="Training iterations for canonical frame (more = better base quality)" ) subsequent_iterations = gr.Slider( minimum=100, maximum=5000, value=500, step=100, label="Subsequent Frame Iterations", info="Training iterations for frames 1+ (positions only)" ) conf_threshold = gr.Slider( minimum=0, maximum=100, value=0, step=5, label="Confidence Threshold (%)", info="0% keeps all points, higher = filter low confidence" ) run_btn = gr.Button("🚀 Run Pipeline", variant="primary", size="lg") status_text = gr.Textbox( label="Status", interactive=False, lines=8, value="Upload videos and click 'Run Pipeline' to begin." ) with gr.Column(scale=2): video_viewer = gr.Video( label="🎞️ 4D Gaussian Animation", height=500, autoplay=True, loop=True ) download_btn = gr.File(label="📦 Download Results (ZIP)") gr.Markdown(""" --- ### 📋 Output Contents The downloaded ZIP contains: - `dynamic_gaussians.npz` - All Gaussian parameters (positions per frame, shared scales/rotations/opacities/SHs) - `animation.mp4` - Rendered video with smooth camera interpolation - `renders/` - Per-frame training renders showing RGB and depth - `output_4d.npz` - VDPM point tracks - `poses.npz` - Camera poses - `depths/` - Computed depth maps - `images/` - Input frames ### 🎯 How It Works 1. **VDPM**: Extracts temporally consistent 3D point maps from video 2. **4DGS Training**: - Train canonical frame (t=0) with all Gaussian parameters - Train subsequent frames with position-only updates (shared appearance) 3. **Animation**: Smooth camera path through training viewpoints **Local runs**: Results saved to `output/pipeline/run_TIMESTAMP/` """) run_btn.click( fn=run_pipeline, inputs=[video_input, initial_iterations, subsequent_iterations, conf_threshold], outputs=[download_btn, video_viewer, status_text] ) if __name__ == "__main__": # Download model on startup if device == "cuda": print("Pre-loading VDPM model...") try: get_vdpm_model() _vdpm_model = None # Free VRAM but keep file cached gc.collect() torch.cuda.empty_cache() print("✓ Model pre-loaded and cached") except Exception as e: print(f"⚠ Failed to pre-load model: {e}") app.queue().launch(share=True, show_error=True)