# Source: Hugging Face Space "4dgs-dpm", file app.py (commit f61284c, verified)
# Page-header text from the upload listing removed so this module parses.
"""
DPM-Splat: End-to-end pipeline for Video β†’ 4D Gaussian Splats
Combines VDPM inference with Dynamic 4DGS training in a single Gradio interface.
"""
import os
import sys
import shutil
import zipfile
import gc
import json
import glob
import time
from pathlib import Path
from datetime import datetime
import cv2
import numpy as np
import gradio as gr
import torch
import imageio
# Set memory optimization
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Add paths
sys.path.insert(0, str(Path(__file__).parent / "vdpm"))
sys.path.insert(0, str(Path(__file__).parent / "gs"))
# Import depth utilities
from vdpm.util.depth import write_depth_to_png
# Check GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
gpu_name = torch.cuda.get_device_name(0)
gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
print(f"βœ“ GPU: {gpu_name} ({gpu_mem:.1f} GB)")
else:
print("⚠ No GPU detected - running on CPU (will be slow)")
# Configuration
VIDEO_SAMPLE_HZ = 1.0
# Set MAX_FRAMES based on VRAM
if device == "cuda":
if gpu_mem >= 20.0:
MAX_FRAMES = 32
elif gpu_mem >= 12.0:
MAX_FRAMES = 16
else:
MAX_FRAMES = 8
else:
MAX_FRAMES = 4
print(f"βœ“ Configured limit: {MAX_FRAMES} frames based on {gpu_mem:.1f} GB VRAM")
# Global model cache
_vdpm_model = None
def get_vdpm_model():
    """Return the VDPM model, loading weights on first use and caching globally."""
    global _vdpm_model
    if _vdpm_model is not None:
        print("βœ“ Using cached VDPM model")
        return _vdpm_model

    print("Loading VDPM model...")
    sys.stdout.flush()

    from hydra import compose, initialize
    from hydra.core.global_hydra import GlobalHydra
    from dpm.model import VDPM

    # Hydra allows only one active global instance; clear any stale state
    # left over from a previous load before re-initializing.
    hydra_state = GlobalHydra.instance()
    if hydra_state.is_initialized():
        hydra_state.clear()
    with initialize(config_path="vdpm/configs"):
        cfg = compose(config_name="visualise")
    model = VDPM(cfg).to(device)

    # Resolve weights on disk, downloading them on the first run only.
    cache_dir = os.path.expanduser("~/.cache/vdpm")
    os.makedirs(cache_dir, exist_ok=True)
    model_path = os.path.join(cache_dir, "vdpm_model.pt")
    _URL = "https://huggingface.co/edgarsucar/vdpm/resolve/main/model.pt"
    if os.path.exists(model_path):
        print(f"βœ“ Loading cached model from {model_path}")
        state = torch.load(model_path, map_location=device)
    else:
        print(f"Downloading VDPM model...")
        state = torch.hub.load_state_dict_from_url(_URL, file_name="vdpm_model.pt", progress=True, map_location=device)
        torch.save(state, model_path)
    model.load_state_dict(state, strict=True)
    model.eval()

    # Half precision halves VRAM use; BF16 needs compute capability >= 8 (Ampere+).
    if device == "cuda":
        if torch.cuda.get_device_capability()[0] >= 8:
            model = model.to(torch.bfloat16)
            print("βœ“ Using BF16 precision")
        else:
            model = model.half()
            print("βœ“ Using FP16 precision")

    _vdpm_model = model
    return model
def process_videos(video_files, target_dir):
    """Extract frames from each uploaded video, interleaved across cameras.

    Frames are written as ``target_dir/images/NNNNNN.png`` in the order
    [cam0_t0, cam1_t0, cam0_t1, cam1_t1, ...], and the view count is
    persisted to ``target_dir/meta.json``.

    Returns:
        (list of saved image paths, number of camera views)
    """
    frames_dir = target_dir / "images"
    frames_dir.mkdir(parents=True, exist_ok=True)

    num_views = len(video_files)
    readers = []
    sample_every = []
    for upload in video_files:
        # Gradio may hand us file objects (with .name) or plain paths.
        src = upload.name if hasattr(upload, 'name') else str(upload)
        cap = cv2.VideoCapture(src)
        fps = float(cap.get(cv2.CAP_PROP_FPS) or 30.0)
        readers.append(cap)
        # How many native frames to skip between samples at VIDEO_SAMPLE_HZ.
        sample_every.append(max(int(fps / max(VIDEO_SAMPLE_HZ, 1e-6)), 1))

    saved_paths = []
    saved_count = 0
    round_idx = 0
    any_alive = True
    # Round-robin over cameras so the saved sequence interleaves views.
    while any_alive:
        any_alive = False
        for cam, cap in enumerate(readers):
            if not cap.isOpened():
                continue
            ok, frame = cap.read()
            if not ok:
                cap.release()
                continue
            any_alive = True
            if round_idx % sample_every[cam] == 0:
                dest = frames_dir / f"{saved_count:06d}.png"
                cv2.imwrite(str(dest), frame)
                saved_paths.append(str(dest))
                saved_count += 1
        round_idx += 1

    for cap in readers:
        if cap.isOpened():
            cap.release()

    # Persist the view count for the downstream inference stage.
    with open(target_dir / "meta.json", "w") as f:
        json.dump({"num_views": num_views}, f)
    return saved_paths, num_views
def decode_poses(pose_enc: np.ndarray, image_hw: tuple) -> tuple:
    """Decode VGGT pose encodings into 4x4 extrinsics and 3x3 intrinsics.

    Falls back to identity extrinsics and a heuristic pinhole intrinsic
    (focal = max(H, W), principal point at the image center) when the
    ``vggt`` package is not importable.
    """
    try:
        from vggt.utils.pose_enc import pose_encoding_to_extri_intri
    except ImportError:
        print("Warning: vggt not available. Using identity poses.")
        n_imgs = pose_enc.shape[1]
        extrinsics = np.tile(np.eye(4, dtype=np.float32), (n_imgs, 1, 1))
        H, W = image_hw
        fx = fy = max(H, W)
        cx, cy = W / 2, H / 2
        K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
        return extrinsics, np.tile(K, (n_imgs, 1, 1))

    enc = torch.from_numpy(pose_enc).float()
    extri, intri = pose_encoding_to_extri_intri(enc, image_hw)
    extri = extri[0].numpy()  # (N, 3, 4)
    intri = intri[0].numpy()  # (N, 3, 3)
    n_cams = extri.shape[0]
    # Promote each 3x4 extrinsic to homogeneous 4x4 by appending [0,0,0,1].
    last_row = np.tile(np.array([0, 0, 0, 1], dtype=np.float32).reshape(1, 1, 4), (n_cams, 1, 1))
    return np.concatenate([extri, last_row], axis=1), intri
def compute_depths(world_points: np.ndarray, extrinsics: np.ndarray, num_views: int) -> np.ndarray:
    """Project world points into each camera and keep the Z coordinate.

    Args:
        world_points: (T, V, H, W, 3) world-space 3D points.
        extrinsics: (N, 4, 4) world-to-camera transforms, ordered as
            frame-major interleaved views (index = t * num_views + v).
        num_views: Number of camera views.

    Returns:
        (T, V, H, W) float32 depth maps (Z in camera coordinates).
    """
    T, V, H, W, _ = world_points.shape
    depths = np.zeros((T, V, H, W), dtype=np.float32)
    for t in range(T):
        for v in range(V):
            cam_idx = t * num_views + v
            if cam_idx >= len(extrinsics):
                # Not enough cameras for this timestep: reuse frame 0's view.
                cam_idx = v
            w2c = extrinsics[cam_idx]
            rot = w2c[:3, :3]
            trans = w2c[:3, 3]
            # World -> camera: x_cam = R @ x_world + t, done row-wise.
            pts = world_points[t, v].reshape(-1, 3)
            cam_pts = pts @ rot.T + trans
            depths[t, v] = cam_pts[:, 2].reshape(H, W)
    return depths
def run_vdpm_inference(target_dir, progress):
    """Run VDPM inference and save outputs in output_4d.npz format.

    Loads frames under ``target_dir/images``, runs the (cached) VDPM model,
    and writes to ``target_dir``:
      - ``output_4d.npz``: points/confidences/images laid out for
        ``gs.train_dynamic`` ((num_timesteps, V*H*W, 3) points per timestep),
      - ``poses.npz``: raw pose encodings (only if the model predicts them),
      - ``depths.npz`` plus ``depths/*.png`` previews (same condition).

    Returns:
        (num_timesteps, num_views)
    """
    from vggt.utils.load_fn import load_and_preprocess_images
    model = get_vdpm_model()
    image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
    if not image_names:
        raise ValueError("No images found")
    # Load metadata (number of synchronized camera views written by process_videos)
    meta_path = target_dir / "meta.json"
    num_views = 1
    if meta_path.exists():
        with open(meta_path) as f:
            num_views = json.load(f).get("num_views", 1)
    # Limit frames to the VRAM budget, rounded down to a whole multiple of
    # num_views so every kept timestep has all of its views.
    if len(image_names) > MAX_FRAMES:
        limit = (MAX_FRAMES // num_views) * num_views
        if limit == 0:
            limit = num_views
        print(f"⚠ Limiting to {limit} frames")
        image_names = image_names[:limit]
    progress(0.15, desc=f"Loading {len(image_names)} images...")
    images = load_and_preprocess_images(image_names).to(device)
    # Store original images for visualization
    images_np = images.cpu().numpy()  # (S, 3, H, W)
    # Construct views. Frames are interleaved [cam0_t0, cam1_t0, cam0_t1, ...],
    # so camera index is i % num_views and timestep is i // num_views.
    views = []
    for i in range(len(image_names)):
        t_idx = i // num_views
        cam_idx = i % num_views
        views.append({
            "img": images[i].unsqueeze(0),
            "view_idxs": torch.tensor([[cam_idx, t_idx]], device=device, dtype=torch.long)
        })
    progress(0.2, desc="Running VDPM forward pass...")
    print(f"Running inference on {len(image_names)} images...")
    sys.stdout.flush()
    with torch.no_grad():
        # NOTE(review): autocast('cuda') is requested even when device == "cpu";
        # presumably a no-op there, but confirm on CPU-only builds.
        with torch.amp.autocast('cuda'):
            predictions = model.inference(views=views)
    # Extract results to numpy and release GPU-side tensors promptly.
    pts_list = [pm["pts3d"].detach().cpu().numpy() for pm in predictions["pointmaps"]]
    conf_list = [pm["conf"].detach().cpu().numpy() for pm in predictions["pointmaps"]]
    pose_enc = None
    if "pose_enc" in predictions:
        pose_enc = predictions["pose_enc"].detach().cpu().numpy()
    del predictions
    torch.cuda.empty_cache()
    world_points_raw = np.concatenate(pts_list, axis=0)  # (T, S, H, W, 3)
    world_points_conf_raw = np.concatenate(conf_list, axis=0)  # (T, S, H, W)
    T = world_points_raw.shape[0]  # number of query timesteps predicted
    S = world_points_raw.shape[1]  # total number of input images
    H, W = world_points_raw.shape[2:4]
    num_timesteps = S // num_views
    print(f"VDPM output shape: T={T}, S={S}, num_views={num_views}")
    progress(0.3, desc="Processing VDPM outputs...")
    # ========================================================================
    # Extract diagonal entries for 4DGS (each image at its natural timestep)
    # Format: (num_timesteps, num_views*H*W, 3) flattened for train_dynamic.py
    # ========================================================================
    world_points_4d = []
    world_points_conf_4d = []
    images_4d = []
    for t in range(num_timesteps):
        # Collect all views for this timestep
        pts_t = []
        conf_t = []
        imgs_t = []
        for v in range(num_views):
            img_idx = t * num_views + v
            if img_idx >= S:
                break
            # Use the pointmap at timestep query = img_idx (diagonal).
            # VDPM outputs: world_points_raw[query_t, input_img_idx, H, W, 3];
            # we want query_t == input_img_idx for single-view consistency.
            # min() clamps in case fewer query timesteps than images exist.
            query_t = min(img_idx, T - 1)
            pts_v = world_points_raw[query_t, img_idx]  # (H, W, 3)
            conf_v = world_points_conf_raw[query_t, img_idx]  # (H, W)
            img_v = images_np[img_idx]  # (3, H, W)
            pts_t.append(pts_v.reshape(-1, 3))  # (H*W, 3)
            conf_t.append(conf_v.reshape(-1))  # (H*W,)
            imgs_t.append(img_v)
        if pts_t:
            # Concatenate all views: (V*H*W, 3)
            world_points_4d.append(np.concatenate(pts_t, axis=0))
            world_points_conf_4d.append(np.concatenate(conf_t, axis=0))
            # Keep only the first view's image per timestep for visualization.
            images_4d.append(imgs_t[0])
    world_points_4d = np.stack(world_points_4d, axis=0)  # (T, N, 3) where N = V*H*W
    world_points_conf_4d = np.stack(world_points_conf_4d, axis=0)  # (T, N)
    images_4d = np.stack(images_4d, axis=0)  # (T, 3, H, W)
    print(f"4DGS format: world_points={world_points_4d.shape}, images={images_4d.shape}")
    progress(0.35, desc="Saving outputs...")
    # Save in output_4d.npz format (compatible with train_dynamic.py)
    np.savez_compressed(
        target_dir / "output_4d.npz",
        world_points=world_points_4d,
        world_points_conf=world_points_conf_4d,
        images=images_4d,
        num_views=num_views,
        num_timesteps=num_timesteps
    )
    if pose_enc is not None:
        np.savez_compressed(target_dir / "poses.npz", pose_enc=pose_enc)
    # ========================================================================
    # COMPUTE AND SAVE DEPTHS (requires camera poses)
    # ========================================================================
    if pose_enc is not None:
        print("Computing depth maps...")
        # Reshape for depth computation: (T, V, H, W, 3)
        world_points_for_depth = world_points_4d.reshape(num_timesteps, num_views, H, W, 3)
        extrinsics, intrinsics = decode_poses(pose_enc, (H, W))
        depths = compute_depths(world_points_for_depth, extrinsics, num_views)
        # Save depths as a compressed array alongside the other outputs
        np.savez_compressed(
            target_dir / "depths.npz",
            depths=depths,
            num_views=num_views,
            num_timesteps=num_timesteps
        )
        # Save per-view depth preview PNGs
        depths_dir = target_dir / "depths"
        depths_dir.mkdir(exist_ok=True)
        for t in range(depths.shape[0]):
            for v in range(depths.shape[1]):
                png_path = depths_dir / f"depth_t{t:04d}_v{v:02d}.png"
                write_depth_to_png(str(png_path), depths[t, v])
        print(f"βœ“ Saved {depths.shape[0] * depths.shape[1]} depth images")
    print(f"βœ“ VDPM complete: {num_timesteps} timesteps, {num_views} views")
    sys.stdout.flush()
    return num_timesteps, num_views
def run_4dgs_training(target_dir, output_dir, initial_iterations, subsequent_iterations, conf_threshold, progress):
    """Train time-varying 4D Gaussian Splats from the VDPM outputs.

    Reads ``output_4d.npz`` (and friends) from ``target_dir``, trains
    sequentially (canonical frame first, then per-frame position updates),
    and returns a dict with artifact paths and summary counts.
    """
    import warp as wp
    from gs.train_dynamic import load_dynamic_data, DynamicGaussianTrainer

    wp.init()
    banner = "=" * 50
    print(f"\n{banner}")
    print("[DYNAMIC 4D GAUSSIANS TRAINING]")
    print(f"Frame 0: {initial_iterations} iterations")
    print(f"Frames 1+: {subsequent_iterations} iterations each")
    print(banner)
    sys.stdout.flush()

    data = load_dynamic_data(str(target_dir))
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    trainer = DynamicGaussianTrainer(
        data=data,
        output_path=str(out_path),
        conf_threshold=conf_threshold,
        initial_iterations=initial_iterations,
        subsequent_iterations=subsequent_iterations,
        simultaneous_mode=False,
    )
    # Map the trainer's 0..1 progress into the pipeline's 0.4..0.9 window.
    trainer.train_sequential(
        progress_callback=lambda frac, desc: progress(0.4 + 0.5 * frac, desc=desc)
    )

    npz_file = out_path / "dynamic_gaussians.npz"
    mp4_file = out_path / "animation.mp4"
    gif_file = out_path / "animation.gif"
    print(f"βœ“ 4DGS training complete: {trainer.num_timesteps} frames, {trainer.num_points} Gaussians")
    sys.stdout.flush()
    return {
        'npz_path': str(npz_file) if npz_file.exists() else None,
        'mp4_path': str(mp4_file) if mp4_file.exists() else None,
        'gif_path': str(gif_file) if gif_file.exists() else None,
        'num_frames': trainer.num_timesteps,
        'num_points': trainer.num_points,
    }
def run_pipeline(video_files, initial_iterations, subsequent_iterations, conf_threshold, progress=gr.Progress()):
    """Run the full VDPM → 4DGS pipeline for one Gradio request.

    Args:
        video_files: Uploaded video file objects from the gr.File input.
        initial_iterations: Training iterations for the canonical frame.
        subsequent_iterations: Training iterations for each later frame.
        conf_threshold: Confidence-threshold percentage for point filtering.
        progress: Gradio progress reporter (injected by the framework).

    Returns:
        (zip_path, video_path, status_message) matching the three bound
        Gradio outputs; the first two are None on failure.
    """
    if not video_files:
        # BUG FIX: this early exit previously returned a 4-tuple, but the
        # click handler binds exactly three outputs (zip, video, status),
        # so Gradio raised instead of showing the message.
        return None, None, "❌ Please upload video file(s)"
    # Start from a clean memory state; a previous run may have left tensors.
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = Path(f"output/pipeline/run_{timestamp}")
    run_dir.mkdir(parents=True, exist_ok=True)
    try:
        # Step 1: Extract interleaved frames from the uploaded videos.
        progress(0.05, desc="Processing uploaded videos...")
        print("=" * 50)
        print("Processing Videos")
        print("=" * 50)
        sys.stdout.flush()
        image_paths, num_views = process_videos(video_files, run_dir)
        print(f"βœ“ Extracted {len(image_paths)} frames from {num_views} videos")
        sys.stdout.flush()
        # Step 2: VDPM inference
        progress(0.1, desc="Running VDPM inference...")
        print("=" * 50)
        print("Running VDPM Inference")
        print("=" * 50)
        sys.stdout.flush()
        num_timesteps, num_views = run_vdpm_inference(run_dir, progress)
        # Drop the cached model so 4DGS training has the whole GPU.
        global _vdpm_model
        _vdpm_model = None
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
            print(f"βœ“ Cleared VRAM: {torch.cuda.memory_allocated()/1024**3:.2f} GB in use")
        sys.stdout.flush()
        # Step 3: 4DGS training
        progress(0.4, desc="Training 4D Gaussian Splats...")
        print("=" * 50)
        print("Training 4D Gaussian Splats")
        print("=" * 50)
        sys.stdout.flush()
        splat_dir = run_dir / "splats"
        results = run_4dgs_training(
            run_dir, splat_dir,
            int(initial_iterations),
            int(subsequent_iterations),
            float(conf_threshold),
            progress
        )
        # Step 4: Package everything the user might want into one ZIP.
        progress(0.95, desc="Packaging results...")
        zip_path = run_dir / "results.zip"
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
            # Main artifacts (each may be missing if training skipped a step)
            if results['npz_path'] and Path(results['npz_path']).exists():
                zf.write(results['npz_path'], "dynamic_gaussians.npz")
            if results['mp4_path'] and Path(results['mp4_path']).exists():
                zf.write(results['mp4_path'], "animation.mp4")
            if results['gif_path'] and Path(results['gif_path']).exists():
                zf.write(results['gif_path'], "animation.gif")
            # Per-frame renders produced during training
            for frame_dir in splat_dir.glob("frame_*"):
                for subdir in ["renders", "training_renders"]:
                    render_dir = frame_dir / subdir
                    if render_dir.exists():
                        for img in render_dir.glob("*.png"):
                            rel_path = img.relative_to(splat_dir)
                            zf.write(img, f"renders/{rel_path}")
            # VDPM intermediate data
            for f in ["output_4d.npz", "poses.npz", "depths.npz", "meta.json"]:
                fp = run_dir / f
                if fp.exists():
                    zf.write(fp, f)
            # Input images (capped to keep the archive small)
            images_dir = run_dir / "images"
            if images_dir.exists():
                for img in sorted(images_dir.glob("*"))[:20]:  # Limit to first 20
                    zf.write(img, f"images/{img.name}")
            # Depth previews (same cap)
            depths_dir = run_dir / "depths"
            if depths_dir.exists():
                for img in sorted(depths_dir.glob("*.png"))[:20]:
                    zf.write(img, f"depths/{img.name}")
        progress(1.0, desc="Complete!")
        status = f"""βœ… Pipeline Complete!
πŸ“Š Results:
β€’ {results['num_frames']} timesteps Γ— {num_views} views
β€’ {results['num_points']:,} Gaussians
β€’ Animation: {'βœ“' if results['mp4_path'] else 'βœ—'}
πŸ“ Output: {run_dir}
πŸ“¦ Download the ZIP for all files"""
        # Return the MP4 path for the preview player (may be None).
        video_path = results.get('mp4_path')
        return str(zip_path), video_path, status
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, None, f"❌ Error: {str(e)}"
# ===== Gradio Interface =====
with gr.Blocks(title="DPM-Splat: 4D Gaussian Splatting", theme=gr.themes.Soft()) as app:
    # Header / description banner
    gr.Markdown("""
# 🎬 DPM-Splat: Video β†’ 4D Gaussian Splats
End-to-end pipeline combining **V-DPM** (Video Dynamic Point Maps) with **4D Gaussian Splatting**.
Upload synchronized videos to generate temporally consistent 4D reconstructions with time-varying Gaussians.
""")
    with gr.Row():
        with gr.Column(scale=1):
            # Left column: upload widget, training settings, run button, status.
            video_input = gr.File(
                label="πŸ“Ή Upload Videos",
                file_count="multiple",
                file_types=[".mp4", ".mov", ".avi", ".webm"]
            )
            gr.Markdown("*Upload 1-4 synchronized video files for multi-view reconstruction*")
            with gr.Accordion("βš™οΈ Training Settings", open=True):
                initial_iterations = gr.Slider(
                    minimum=100, maximum=10000, value=3000, step=100,
                    label="Frame 0 Iterations",
                    info="Training iterations for canonical frame (more = better base quality)"
                )
                subsequent_iterations = gr.Slider(
                    minimum=100, maximum=5000, value=500, step=100,
                    label="Subsequent Frame Iterations",
                    info="Training iterations for frames 1+ (positions only)"
                )
                conf_threshold = gr.Slider(
                    minimum=0, maximum=100, value=0, step=5,
                    label="Confidence Threshold (%)",
                    info="0% keeps all points, higher = filter low confidence"
                )
            run_btn = gr.Button("πŸš€ Run Pipeline", variant="primary", size="lg")
            status_text = gr.Textbox(
                label="Status",
                interactive=False,
                lines=8,
                value="Upload videos and click 'Run Pipeline' to begin."
            )
        with gr.Column(scale=2):
            # Right column: animation preview and downloadable ZIP.
            video_viewer = gr.Video(
                label="🎞️ 4D Gaussian Animation",
                height=500,
                autoplay=True,
                loop=True
            )
            download_btn = gr.File(label="πŸ“¦ Download Results (ZIP)")
    # Footer: documents the ZIP contents and the pipeline stages.
    gr.Markdown("""
---
### πŸ“‹ Output Contents
The downloaded ZIP contains:
- `dynamic_gaussians.npz` - All Gaussian parameters (positions per frame, shared scales/rotations/opacities/SHs)
- `animation.mp4` - Rendered video with smooth camera interpolation
- `renders/` - Per-frame training renders showing RGB and depth
- `output_4d.npz` - VDPM point tracks
- `poses.npz` - Camera poses
- `depths/` - Computed depth maps
- `images/` - Input frames
### 🎯 How It Works
1. **VDPM**: Extracts temporally consistent 3D point maps from video
2. **4DGS Training**:
- Train canonical frame (t=0) with all Gaussian parameters
- Train subsequent frames with position-only updates (shared appearance)
3. **Animation**: Smooth camera path through training viewpoints
**Local runs**: Results saved to `output/pipeline/run_TIMESTAMP/`
""")
    # Wire the button to the pipeline; outputs must stay a 3-tuple
    # (zip file, preview video, status text) to match run_pipeline's returns.
    run_btn.click(
        fn=run_pipeline,
        inputs=[video_input, initial_iterations, subsequent_iterations, conf_threshold],
        outputs=[download_btn, video_viewer, status_text]
    )
if __name__ == "__main__":
# Download model on startup
if device == "cuda":
print("Pre-loading VDPM model...")
try:
get_vdpm_model()
_vdpm_model = None # Free VRAM but keep file cached
gc.collect()
torch.cuda.empty_cache()
print("βœ“ Model pre-loaded and cached")
except Exception as e:
print(f"⚠ Failed to pre-load model: {e}")
app.queue().launch(share=True, show_error=True)