| | """
|
| | Batch VDPM Inference (CLI version of Gradio Demo)
|
| |
|
| | This script replicates the exact logic of gradio_demo.py but for command-line usage.
|
| | It supports processing a folder of video files (treated as synchronized multi-view input)
|
| | or a single video file.
|
| |
|
| | Usage:
|
| | python vdpm/infer.py --input path/to/videos_folder --output output/
|
| | python vdpm/infer.py --input path/to/video.mp4 --output output/
|
| | """
|
| |
|
| | import os
|
| | import sys
|
| | import glob
|
| | import json
|
| | import argparse
|
| | import time
|
| | import shutil
|
| | import gc
|
| | from pathlib import Path
|
| | from datetime import datetime
|
| |
|
| | import cv2
|
| | import numpy as np
|
| | import torch
|
| | from hydra import compose, initialize
|
| | from hydra.core.global_hydra import GlobalHydra
|
| |
|
| |
|
| | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| |
|
| |
|
| | sys.path.insert(0, str(Path(__file__).parent))
|
| |
|
| | from dpm.model import VDPM
|
| | from vggt.utils.load_fn import load_and_preprocess_images
|
| | from util.depth import write_depth_to_png
|
| |
|
| |
|
| |
|
| |
|
| |
|
# ---------------------------------------------------------------------------
# Runtime configuration
# ---------------------------------------------------------------------------
VIDEO_SAMPLE_HZ = 1.0       # target sampling rate used when extracting video frames
USE_HALF_PRECISION = True   # cast the model to BF16/FP16 in load_model
USE_QUANTIZATION = False    # use INT8 dynamic quantization in load_model instead

device = "cuda" if torch.cuda.is_available() else "cpu"

# Conservative default frame budget; replaced below once the VRAM size is known.
MAX_FRAMES = 5
if device == "cuda":
    # Permit TF32 kernels for both cuDNN and matmul.
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True

    # Total memory of GPU 0, in bytes and in GiB.
    vram_bytes = torch.cuda.get_device_properties(0).total_memory
    vram_gb = vram_bytes / (1024**3)

    print(f"✓ GPU Detected: {torch.cuda.get_device_name(0)} ({vram_gb:.1f} GB VRAM)")

    # Pick the frame budget from the VRAM tier (checked from smallest to largest).
    if vram_gb < 7.5:
        MAX_FRAMES = 5
        print(f" -> Low VRAM (<8GB). Keeping MAX_FRAMES at {MAX_FRAMES} to prevent OOM")
    elif vram_gb < 14:
        MAX_FRAMES = 8
        print(f" -> 8GB VRAM detected. Set MAX_FRAMES to {MAX_FRAMES}")
    elif vram_gb < 22:
        MAX_FRAMES = 16
        print(f" -> Medium VRAM detected! Increased MAX_FRAMES to {MAX_FRAMES}")
    else:
        MAX_FRAMES = 80
        print(f" -> High VRAM detected! Increased MAX_FRAMES to {MAX_FRAMES}")
|
| |
|
def require_cuda():
    """Raise ``ValueError`` unless the module-level ``device`` is ``"cuda"``."""
    if device == "cuda":
        return
    raise ValueError("CUDA is not available. Check your environment.")
|
| |
|
| |
|
def decode_poses(pose_enc: np.ndarray, image_hw: tuple) -> tuple:
    """Turn VGGT pose encodings into per-frame camera matrices.

    Args:
        pose_enc: Pose encoding array. Only batch element 0 of the decoded
            result is used; the fallback path reads the frame count from
            ``pose_enc.shape[1]``.
        image_hw: ``(height, width)`` of the images the poses belong to.

    Returns:
        ``(extrinsics, intrinsics)``. ``extrinsics`` has shape ``(N, 4, 4)``
        (the decoded matrices with a ``[0, 0, 0, 1]`` row appended).
        ``intrinsics`` is whatever the VGGT decoder returns for batch 0, or,
        when ``vggt`` cannot be imported, ``(N, 3, 3)`` pinhole matrices with
        focal length ``max(H, W)`` and the principal point at the image centre
        (in that case the extrinsics are identity matrices).
    """
    try:
        from vggt.utils.pose_enc import pose_encoding_to_extri_intri

        enc_tensor = torch.from_numpy(pose_enc).float()
        extri_batch, intri_batch = pose_encoding_to_extri_intri(enc_tensor, image_hw)

        # Drop the batch dimension and move back to NumPy.
        cam_extri = extri_batch[0].numpy()
        cam_intri = intri_batch[0].numpy()

        # Append the homogeneous [0, 0, 0, 1] row to every matrix.
        num_frames = cam_extri.shape[0]
        homog_rows = np.zeros((num_frames, 1, 4), dtype=np.float32)
        homog_rows[:, 0, 3] = 1.0

        return np.concatenate([cam_extri, homog_rows], axis=1), cam_intri

    except ImportError:
        print("Warning: vggt not available. Using identity poses.")
        num_frames = pose_enc.shape[1]
        identity_extri = np.tile(np.eye(4, dtype=np.float32), (num_frames, 1, 1))

        height, width = image_hw
        focal = max(height, width)
        pinhole = np.array(
            [[focal, 0, width / 2], [0, focal, height / 2], [0, 0, 1]],
            dtype=np.float32,
        )

        return identity_extri, np.tile(pinhole, (num_frames, 1, 1))
|
| |
|
| |
|
def compute_depths(world_points: np.ndarray, extrinsics: np.ndarray, num_views: int) -> np.ndarray:
    """Project world-space point maps into each camera and keep the Z coordinate.

    Args:
        world_points: ``(T, V, H, W, 3)`` world-space 3D points.
        extrinsics: ``(N, 4, 4)`` camera extrinsics (world-to-camera).
        num_views: Number of camera views; the pose for timestep ``t`` and
            view ``v`` is looked up at index ``t * num_views + v``.

    Returns:
        ``(T, V, H, W)`` float32 depth maps (Z in camera coordinates).
    """
    num_steps, num_cams, height, width, _ = world_points.shape
    depths = np.zeros((num_steps, num_cams, height, width), dtype=np.float32)
    num_poses = len(extrinsics)

    for step, cam in np.ndindex(num_steps, num_cams):
        # Poses are laid out timestep-major; when there is no pose for this
        # (timestep, view) pair, reuse the pose stored at index ``cam``.
        pose_idx = step * num_views + cam
        if pose_idx >= num_poses:
            pose_idx = cam

        world_to_cam = extrinsics[pose_idx]
        rotation = world_to_cam[:3, :3]
        translation = world_to_cam[:3, 3]

        # Rigid transform of the flattened point map into the camera frame.
        flat_points = world_points[step, cam].reshape(-1, 3)
        cam_points = (rotation @ flat_points.T).T + translation

        depths[step, cam] = cam_points[:, 2].reshape(height, width)

    return depths
|
| |
|
def load_cfg_from_cli() -> "omegaconf.DictConfig":
    """Compose and return the ``visualise`` Hydra config from ``configs``.

    Any already-initialised global Hydra state is cleared before Hydra is
    initialised again.
    """
    hydra_state = GlobalHydra.instance()
    if hydra_state.is_initialized():
        hydra_state.clear()

    with initialize(config_path="configs"):
        cfg = compose(config_name="visualise")
    return cfg
|
| |
|
def load_model(cfg) -> VDPM:
    """Build the VDPM model, load its pretrained weights and prepare it for inference.

    The checkpoint is cached at ``~/.cache/vdpm/vdpm_model.pt`` and downloaded
    on first use. After loading, the model is put in eval mode and then,
    depending on the module-level flags:

    * ``USE_HALF_PRECISION`` (without quantization): cast to BF16 when CUDA is
      available with compute capability major version >= 8, otherwise FP16.
    * ``USE_QUANTIZATION``: INT8 dynamic quantization of Linear/Conv2d layers
      (performed on CPU, then moved back to ``device``); failures are reported
      and the unquantized model is kept.
    * Without quantization the model is wrapped with ``torch.compile``;
      failures are reported and the uncompiled model is returned.

    Args:
        cfg: Configuration object passed straight to the ``VDPM`` constructor.

    Returns:
        The ready-to-use model.
    """
    model = VDPM(cfg).to(device)

    cache_dir = os.path.expanduser("~/.cache/vdpm")
    os.makedirs(cache_dir, exist_ok=True)
    weights_name = "vdpm_model.pt"
    model_path = os.path.join(cache_dir, weights_name)

    _URL = "https://huggingface.co/edgarsucar/vdpm/resolve/main/model.pt"

    if not os.path.exists(model_path):
        print(f"Downloading model to {model_path}...")
        # model_dir + file_name make torch.hub store the checkpoint directly at
        # model_path. This avoids keeping a second copy in the default hub
        # cache and avoids re-serialising with torch.save, which could leave a
        # truncated file at model_path if interrupted (and that file would be
        # treated as a valid cache on the next run).
        sd = torch.hub.load_state_dict_from_url(
            _URL,
            model_dir=cache_dir,
            file_name=weights_name,
            progress=True,
            map_location=device,
        )
        print(f"✓ Model cached at {model_path}")
    else:
        print(f"✓ Loading cached model from {model_path}")
        sd = torch.load(model_path, map_location=device)

    # strict=True: any missing/unexpected key raises instead of being ignored.
    print(model.load_state_dict(sd, strict=True))

    model.eval()

    if USE_HALF_PRECISION and not USE_QUANTIZATION:
        if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
            print("Converting model to BF16 precision...")
            model = model.to(torch.bfloat16)
        else:
            print("Converting model to FP16 precision...")
            model = model.half()

    if USE_QUANTIZATION:
        try:
            print("Applying INT8 dynamic quantization...")
            model = model.cpu()
            model = torch.quantization.quantize_dynamic(
                model,
                {torch.nn.Linear, torch.nn.Conv2d},
                dtype=torch.qint8,
            )
            model = model.to(device)
        except Exception as e:
            # Best effort: keep going with the unquantized model.
            print(f"⚠️ Quantization failed: {e}")
            model = model.to(device)

    if not USE_QUANTIZATION:
        try:
            print("Compiling model with torch.compile...")
            model = torch.compile(model, mode="reduce-overhead")
        except Exception as e:
            # Best effort: fall back to the eager model.
            print(f"Warning: torch.compile failed: {e}")

    return model
|
| |
|
| |
|
| |
|
| |
|
| |
|
def process_videos_interleaved(input_video_list, target_dir_images):
    """Extract frames from several videos in lock-step and write them as PNGs.

    All videos are advanced one frame per round. In round ``k`` a video's
    frame is saved when ``k`` is a multiple of that video's own sampling
    interval (``fps / VIDEO_SAMPLE_HZ``, at least 1; 30 fps is assumed when the
    container reports none). Saved frames are numbered consecutively in the
    order they are written, so within one round the files follow the order of
    ``input_video_list``. Matches the ``handle_uploads`` logic from
    gradio_demo.py.

    Args:
        input_video_list: Paths of the video files to read.
        target_dir_images: Existing directory that receives ``NNNNNN.png`` files.

    Returns:
        List of written image paths, in write order.
    """
    frame_num = 0
    image_paths = []

    captures = []
    capture_meta = []

    try:
        for idx, video_path in enumerate(input_video_list):
            print(f"Preparing video {idx+1}/{len(input_video_list)}: {video_path}")

            vs = cv2.VideoCapture(video_path)
            # Register immediately so the handle is released even if a later
            # step raises.
            captures.append(vs)
            if not vs.isOpened():
                print(f"⚠️ Warning: could not open video {video_path}; it will contribute no frames")

            fps = float(vs.get(cv2.CAP_PROP_FPS) or 0.0)
            if fps <= 0:
                fps = 30.0

            frame_interval = max(int(fps / max(VIDEO_SAMPLE_HZ, 1e-6)), 1)
            capture_meta.append({"interval": frame_interval, "name": video_path})

        print("Processing videos in interleaved mode...")
        step_count = 0
        active_videos = True

        # Keep going until a full round yields no frame from any video.
        while active_videos:
            active_videos = False
            for i, vs in enumerate(captures):
                if not vs.isOpened():
                    continue

                gotit, frame = vs.read()
                if gotit:
                    active_videos = True

                    if step_count % capture_meta[i]["interval"] == 0:
                        out_path = os.path.join(target_dir_images, f"{frame_num:06}.png")
                        cv2.imwrite(out_path, frame)
                        image_paths.append(out_path)
                        frame_num += 1
                else:
                    # End of stream: closing makes isOpened() false from now on.
                    vs.release()

            step_count += 1
    finally:
        # Release every handle, including on errors; releasing twice is harmless.
        for vs in captures:
            vs.release()

    return image_paths
|
| |
|
| |
|
def _read_num_views(meta_path: str) -> int:
    """Return ``num_views`` from ``meta.json``; 1 when missing, unreadable or invalid."""
    if not os.path.exists(meta_path):
        return 1
    try:
        with open(meta_path, "r") as f:
            num_views = json.load(f).get("num_views", 1)
    except (OSError, ValueError, AttributeError) as e:
        # ValueError covers malformed JSON; AttributeError a non-object top level.
        print(f"⚠️ Warning: could not read {meta_path} ({e}). Assuming 1 view.")
        return 1
    if isinstance(num_views, bool) or not isinstance(num_views, int) or num_views < 1:
        # num_views is used as a divisor and a slice step below.
        print(f"⚠️ Warning: invalid num_views={num_views!r} in {meta_path}. Assuming 1 view.")
        return 1
    return num_views


def _select_world_points(points_raw: np.ndarray, conf_raw: np.ndarray, num_views: int) -> tuple:
    """Pick, for every timestep, the point maps that belong to that timestep."""
    T = points_raw.shape[0]
    S = points_raw.shape[1]

    if num_views > 1 and S == num_views * T:
        # Multi-view: timestep t owns the num_views consecutive entries
        # starting at t * num_views along axis 1.
        print(f"DEBUG: Multi-view mode - extracting ALL {num_views} views")
        spans = [(t * num_views, (t + 1) * num_views) for t in range(T)]
        points = np.stack([points_raw[t, a:b] for t, (a, b) in enumerate(spans)], axis=0)
        conf = np.stack([conf_raw[t, a:b] for t, (a, b) in enumerate(spans)], axis=0)
        return points, conf

    if points_raw.ndim == 5 and T == 1:
        return points_raw[0], conf_raw[0]

    if points_raw.ndim == 5:
        # Single view, several timesteps: keep the "diagonal" entries [t, t].
        diag = range(min(T, S))
        points = np.stack([points_raw[t, t] for t in diag], axis=0)
        conf = np.stack([conf_raw[t, t] for t in diag], axis=0)
        return points, conf

    return points_raw, conf_raw


def run_model(target_dir: str, model: VDPM, frame_id_arg=0) -> dict:
    """Run VDPM inference on ``<target_dir>/images`` and write the results.

    Files written into ``target_dir``:

    * ``tracks.npz``: selected world points and confidences.
    * ``poses.npz``: raw pose encodings (only if the model returned them).
    * ``depths.npz`` and ``depths/depth_tTTTT_vVV.png``: depth maps derived from
      the world points and decoded poses (only if poses are available).
    * ``output_4d.npz``: world points, confidences, timestamps and, when
      computed, depths.

    The number of synchronized views is read from ``<target_dir>/meta.json``
    (default 1). When more than ``MAX_FRAMES`` images are present, the input
    is truncated to a whole number of timesteps.

    Args:
        target_dir: Directory containing an ``images`` sub-folder.
        model: Loaded model exposing ``inference(views=...)``.
        frame_id_arg: Unused; kept for interface compatibility.

    Returns:
        Dict with ``tracks_path``, ``output_path`` and ``depths_path``
        (``None`` when no depths were computed).

    Raises:
        ValueError: If CUDA is unavailable or no images are found.
    """
    require_cuda()

    image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
    if not image_names:
        raise ValueError("No images found in target_dir.")

    num_views = _read_num_views(os.path.join(target_dir, "meta.json"))

    if len(image_names) > MAX_FRAMES:
        # Truncate to complete timesteps so every kept timestep has all views.
        limit = (MAX_FRAMES // num_views) * num_views
        if limit == 0:
            limit = num_views
            print(f"⚠️ Warning: MAX_FRAMES={MAX_FRAMES} is smaller than num_views={num_views}. Processing 1 full timestep anyway.")

        print(f"⚠️ Limiting to {limit} frames ({limit // num_views} timesteps * {num_views} views) to fit in GPU memory")
        image_names = image_names[:limit]

    print(f"Loading {len(image_names)} images...")
    images = load_and_preprocess_images(image_names).to(device)

    if device == "cuda":
        print(f"GPU memory before inference: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")

    print(f"Running inference on {len(image_names)} images ({num_views} synchronized views)...")

    # Images are ordered timestep-major: index i -> camera i % V, timestep i // V.
    views = [
        {
            "img": images[i].unsqueeze(0),
            "view_idxs": torch.tensor(
                [[i % num_views, i // num_views]], device=device, dtype=torch.long
            ),
        }
        for i in range(len(image_names))
    ]

    inference_start = time.time()

    with torch.no_grad():
        with torch.amp.autocast('cuda'):
            predictions = model.inference(views=views)

    inference_time = time.time() - inference_start
    print(f"✓ Inference completed in {inference_time:.2f}s ({inference_time/len(image_names):.2f}s per frame)")

    pts_list = [pm["pts3d"].detach().cpu().numpy() for pm in predictions["pointmaps"]]
    conf_list = [pm["conf"].detach().cpu().numpy() for pm in predictions["pointmaps"]]

    pose_enc = None
    if "pose_enc" in predictions:
        pose_enc = predictions["pose_enc"].detach().cpu().numpy()

    # Everything needed is on the CPU now; free the GPU copies.
    del predictions
    if device == "cuda":
        torch.cuda.empty_cache()

    world_points_raw = np.concatenate(pts_list, axis=0)
    world_points_conf_raw = np.concatenate(conf_list, axis=0)
    num_timesteps = world_points_raw.shape[0]

    world_points_full, world_points_conf_full = _select_world_points(
        world_points_raw, world_points_conf_raw, num_views
    )

    tracks_path = os.path.join(target_dir, "tracks.npz")
    print(f"Saving tracks (clean) to {tracks_path}")
    np.savez_compressed(
        tracks_path,
        world_points=world_points_full,
        world_points_conf=world_points_conf_full,
        num_views=num_views,
        num_timesteps=num_timesteps,
    )

    if pose_enc is not None:
        poses_path = os.path.join(target_dir, "poses.npz")
        print(f"Saving poses to {poses_path}")
        np.savez_compressed(poses_path, pose_enc=pose_enc)

    depths = None
    if pose_enc is not None:
        print("Computing depth maps from world points and camera poses...")

        if world_points_full.ndim == 5:
            _, _, H, W, _ = world_points_full.shape
        elif world_points_full.ndim == 4:
            # Insert a view axis. Note this 5D array is also what ends up in
            # output_4d.npz, whereas tracks.npz above keeps the 4D layout.
            _, H, W, _ = world_points_full.shape
            world_points_full = world_points_full[:, np.newaxis, :, :, :]
        else:
            # NOTE: compute_depths unpacks a 5D shape, so this path is expected
            # to raise there.
            H, W = 518, 518
            print(f"Warning: Unexpected world_points shape {world_points_full.shape}")

        extrinsics, intrinsics = decode_poses(pose_enc, (H, W))
        depths = compute_depths(world_points_full, extrinsics, num_views)

        depths_path = os.path.join(target_dir, "depths.npz")
        print(f"Saving depths to {depths_path}")
        np.savez_compressed(
            depths_path,
            depths=depths,
            num_views=num_views,
            num_timesteps=num_timesteps,
        )

        depths_dir = os.path.join(target_dir, "depths")
        os.makedirs(depths_dir, exist_ok=True)
        print(f"Saving depth images to {depths_dir}/")

        T_depth = depths.shape[0]
        V_depth = depths.shape[1]
        for t in range(T_depth):
            for v in range(V_depth):
                png_path = os.path.join(depths_dir, f"depth_t{t:04d}_v{v:02d}.png")
                write_depth_to_png(png_path, depths[t, v])

        print(f"✓ Saved {T_depth * V_depth} depth images")
    else:
        print("⚠ No pose encodings available - skipping depth computation")

    output_path = os.path.join(target_dir, "output_4d.npz")
    save_dict = {
        "world_points": world_points_full,
        "world_points_conf": world_points_conf_full,
        "timestamps": np.arange(num_timesteps),
        "num_views": num_views,
        "num_timesteps": num_timesteps,
    }
    if depths is not None:
        save_dict["depths"] = depths
    np.savez_compressed(output_path, **save_dict)

    return {
        "tracks_path": tracks_path,
        "output_path": output_path,
        "depths_path": os.path.join(target_dir, "depths.npz") if depths is not None else None,
    }
|
| |
|
def main():
    """Command-line entry point: extract frames, run inference, report the result.

    Steps:
        1. Resolve ``--input`` to a list of videos (a single file, or every
           ``.mp4/.mov/.avi/.mkv`` file directly inside a folder; each video is
           treated as one synchronized view).
        2. Create ``<output>/<name>`` (default name
           ``reconstruction_<timestamp>``), deleting it first if it exists.
        3. Extract interleaved frames into ``images/`` and record the view
           count in ``meta.json``.
        4. Load the model and run inference on the extracted frames.

    Input problems (missing path, no videos, no extractable frames) are
    reported with a message and the function returns without raising.
    """
    parser = argparse.ArgumentParser(description="Run VDPM Inference (CLI)")
    parser.add_argument("--input", required=True, help="Input video file or folder containing videos")
    parser.add_argument("--output", required=True, help="Output directory")
    parser.add_argument("--name", help="Optional name for the reconstruction folder")

    args = parser.parse_args()

    input_path = Path(args.input)
    output_root = Path(args.output)

    videos = []
    if input_path.is_file():
        videos = [str(input_path)]
    elif input_path.is_dir():
        # Match lower- and upper-case extensions; the set removes duplicates on
        # case-insensitive file systems where both patterns hit the same file.
        found_videos = set()
        for ext in ['*.mp4', '*.mov', '*.avi', '*.mkv']:
            matches = glob.glob(str(input_path / ext)) + glob.glob(str(input_path / ext.upper()))
            found_videos.update(os.path.abspath(m) for m in matches)

        # Sorted order defines the view index of each video.
        videos = sorted(found_videos)

        if not videos:
            print(f"No videos found in {input_path}")
            return

        print(f"Found {len(videos)} videos in {input_path}")
    else:
        print(f"Input {input_path} not found")
        return

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    folder_name = args.name if args.name else f"reconstruction_{timestamp}"
    target_dir = output_root / folder_name
    target_dir_images = target_dir / "images"

    if target_dir.exists():
        print(f"Cleaning existing output dir: {target_dir}")
        shutil.rmtree(target_dir)
    target_dir_images.mkdir(parents=True, exist_ok=True)

    image_paths = process_videos_interleaved(videos, str(target_dir_images))
    if not image_paths:
        # Fail fast: without frames run_model would raise only after the
        # (potentially slow) model download and load.
        print(f"No frames could be extracted from {input_path}")
        return

    num_views = len(videos)
    with open(target_dir / "meta.json", "w") as f:
        json.dump({"num_views": num_views}, f)

    print(f"Metadata saved: {num_views} view(s)")

    print("Loading model...")
    cfg = load_cfg_from_cli()
    model = load_model(cfg)

    print("Running inference...")
    run_model(str(target_dir), model)

    print(f"\n{'='*60}")
    print(f"Success! Output saved to:\n{target_dir}")
    print("Next step: Train Gaussian Splats using:")
    print(f"python gs/train_vdpm.py --input {target_dir} --output output/splats")
    print(f"{'='*60}")


if __name__ == "__main__":
    main()
|
| |
|