import gc import glob import os import shutil import sys import time from datetime import datetime import cv2 import gradio as gr import matplotlib import numpy as np import plotly.graph_objects as go import torch import spaces from hydra import compose, initialize from hydra.core.global_hydra import GlobalHydra from vggt.utils.load_fn import load_and_preprocess_images sys.path.append("dpm/") from dpm.model import VDPM TRAIL_LENGTH = 20 MAX_TRACKS = 150 STATIC_THRESHOLD = 0.025 VIDEO_SAMPLE_HZ = 1.0 device = "cuda" if torch.cuda.is_available() else "cpu" def load_cfg_from_cli() -> "omegaconf.DictConfig": if GlobalHydra.instance().is_initialized(): GlobalHydra.instance().clear() with initialize(config_path="configs"): return compose(config_name="visualise") def load_model(cfg) -> VDPM: model = VDPM(cfg).to(device) _URL = "https://huggingface.co/edgarsucar/vdpm/resolve/main/model.pt" sd = torch.hub.load_state_dict_from_url( _URL, file_name="vdpm_model.pt", progress=True ) print(model.load_state_dict(sd, strict=True)) model.eval() return model def require_cuda(): if device != "cuda": raise ValueError("CUDA is not available. Check your environment.") def gradio_file_path(file_obj): if file_obj is None: return None if isinstance(file_obj, dict) and "name" in file_obj: return file_obj["name"] return file_obj def ensure_nhwc_images(images: np.ndarray) -> np.ndarray: if images.ndim == 4 and images.shape[1] == 3: return np.transpose(images, (0, 2, 3, 1)) return images def compute_scene_bounds(world_points: np.ndarray): all_pts = world_points.reshape(-1, 3) raw_min = all_pts.min(axis=0) raw_max = all_pts.max(axis=0) center = 0.5 * (raw_min + raw_max) half_extent = 0.5 * (raw_max - raw_min) * 1.05 if np.all(half_extent < 1e-6): half_extent[:] = 1.0 else: half_extent[half_extent < 1e-6] = half_extent.max() global_min = center - half_extent global_max = center + half_extent max_half = half_extent.max() aspectratio = { "x": float(half_extent[0] / max_half), "y": float(half_extent[1] / max_half), "z": float(half_extent[2] / max_half), } return global_min, global_max, aspectratio def compute_point_mask( conf_score: np.ndarray | None, cols: np.ndarray, conf_thres: float, mask_black_bg: bool, mask_white_bg: bool, ) -> np.ndarray: """ conf_score: (N,) or None cols: (N,3) uint8 Returns: (N,) boolean mask """ mask = np.ones(cols.shape[0], dtype=bool) # confidence percentile threshold (same semantics as before) if conf_score is not None and conf_thres > 0: thresh = np.percentile(conf_score, conf_thres) mask &= (conf_score >= thresh) & (conf_score > 1e-5) # background masks (same as before) if mask_black_bg: mask &= (cols.sum(axis=1) >= 16) if mask_white_bg: mask &= ~((cols[:, 0] > 240) & (cols[:, 1] > 240) & (cols[:, 2] > 240)) return mask def sample_frame_points( world_points: np.ndarray, images_nhwc: np.ndarray, conf: np.ndarray | None, idx: int, conf_thres: float, mask_black_bg: bool, mask_white_bg: bool, ): i = int(np.clip(idx, 0, world_points.shape[0] - 1)) pts = world_points[i].reshape(-1, 3) cols = (images_nhwc[i].reshape(-1, 3) * 255).astype(np.uint8) conf_score = conf[i].reshape(-1) if (conf is not None) else None mask = compute_point_mask( conf_score=conf_score, cols=cols, conf_thres=conf_thres, mask_black_bg=mask_black_bg, mask_white_bg=mask_white_bg, ) pts = pts[mask] cols = cols[mask] if pts.size == 0: pts = np.array([[0.0, 0.0, 0.0]]) cols = np.array([[255, 255, 255]], dtype=np.uint8) colors_str = [f"#{r:02x}{g:02x}{b:02x}" for r, g, b in cols] return pts, colors_str def prepare_tracks( world_points: np.ndarray, images_nhwc: np.ndarray, conf: np.ndarray | None, conf_thres: float, mask_black_bg: bool, mask_white_bg: bool, ): S, H, W, _ = world_points.shape N = H * W if S < 2 or N == 0: return None, None, None tracks_xyz = world_points.reshape(S, N, 3) disp = np.linalg.norm(tracks_xyz - tracks_xyz[0:1], axis=-1) dynamic_mask = disp.max(axis=0) > STATIC_THRESHOLD # build a per-point confidence score (across time) conf_score = None if conf is not None: conf_flat = conf.reshape(S, N) conf_score = conf_flat.mean(axis=0) # Use reference-frame colors for background masking (stable, consistent) ref_cols = (images_nhwc[0].reshape(-1, 3) * 255).astype(np.uint8) point_mask = compute_point_mask( conf_score=conf_score, cols=ref_cols, conf_thres=conf_thres, mask_black_bg=mask_black_bg, mask_white_bg=mask_white_bg, ) dynamic_mask &= point_mask idx_tracks = np.nonzero(dynamic_mask)[0] if idx_tracks.size == 0: return None, None, None if idx_tracks.size > MAX_TRACKS: step = int(np.ceil(idx_tracks.size / MAX_TRACKS)) idx_tracks = idx_tracks[::step][:MAX_TRACKS] tracks_xyz = tracks_xyz[:, idx_tracks, :] order = np.argsort(tracks_xyz[0, :, 1]) tracks_xyz = tracks_xyz[:, order, :] num_tracks = tracks_xyz.shape[1] cmap = matplotlib.colormaps.get_cmap("hsv") norm = matplotlib.colors.Normalize(vmin=0, vmax=max(num_tracks - 1, 1)) colorscale = [] for t in range(num_tracks): r, g, b, _ = cmap(norm(t)) r, g, b = int(r * 255), int(g * 255), int(b * 255) pos = t / max(num_tracks - 1, 1) colorscale.append([pos, f"rgb({r},{g},{b})"]) track_ids = np.arange(num_tracks, dtype=float) return tracks_xyz, colorscale, track_ids def track_segments_for_frame(tracks_xyz: np.ndarray | None, track_ids: np.ndarray | None, f: int): if tracks_xyz is None or track_ids is None or f <= 0: return np.array([]), np.array([]), np.array([]), np.array([]) start_t = max(0, f - TRAIL_LENGTH) num_tracks = tracks_xyz.shape[1] xs, ys, zs, cs = [], [], [], [] for j in range(num_tracks): seg = tracks_xyz[start_t : f + 1, j, :] if seg.shape[0] < 2: continue xs.extend([seg[:, 0], np.array([np.nan])]) ys.extend([seg[:, 1], np.array([np.nan])]) zs.extend([seg[:, 2], np.array([np.nan])]) cs.append(np.full(seg.shape[0] + 1, track_ids[j], dtype=float)) x = np.concatenate(xs) if xs else np.array([]) y = np.concatenate(ys) if ys else np.array([]) z = np.concatenate(zs) if zs else np.array([]) c = np.concatenate(cs) if cs else np.array([]) return x, y, z, c def build_pointcloud_figure_update( data, conf_thres: float, mask_black_bg: bool, mask_white_bg: bool, ): if data is None: return go.Figure() world_points = data["world_points"] conf = data.get("world_points_conf") images = ensure_nhwc_images(data["images"]) S = world_points.shape[0] global_min, global_max, aspectratio = compute_scene_bounds(world_points) tracks_xyz, colorscale, track_ids = prepare_tracks( world_points=world_points, images_nhwc=images, conf=conf, conf_thres=conf_thres, mask_black_bg=mask_black_bg, mask_white_bg=mask_white_bg, ) track_cmax = max(len(track_ids) - 1, 1) if track_ids is not None else 1 pts_xyz = [None] * S pts_cols = [None] * S trk_xyz = [None] * S trk_c = [None] * S for i in range(S): pts_i, cols_i = sample_frame_points( world_points=world_points, images_nhwc=images, conf=conf, idx=i, conf_thres=conf_thres, mask_black_bg=mask_black_bg, mask_white_bg=mask_white_bg, ) pts_xyz[i] = pts_i pts_cols[i] = cols_i x, y, z, c = track_segments_for_frame(tracks_xyz, track_ids, f=i) trk_xyz[i] = (x, y, z) trk_c[i] = c p0 = pts_xyz[0] c0 = pts_cols[0] x0, y0, z0 = trk_xyz[0] tc0 = trk_c[0] scene_cfg = dict( xaxis=dict( visible=False, showbackground=False, showgrid=False, zeroline=False, showticklabels=False, range=[float(global_min[0]), float(global_max[0])], ), yaxis=dict( visible=False, showbackground=False, showgrid=False, zeroline=False, showticklabels=False, range=[float(global_min[1]), float(global_max[1])], ), zaxis=dict( visible=False, showbackground=False, showgrid=False, zeroline=False, showticklabels=False, range=[float(global_min[2]), float(global_max[2])], ), aspectmode="manual", aspectratio=aspectratio, dragmode="orbit", camera=dict( eye=dict(x=0.0, y=0.0, z=-1.0), center=dict(x=0.0, y=0.0, z=0.0), up=dict(x=0.0, y=-1.0, z=0.0), ), ) fig = go.Figure( data=[ go.Scatter3d( x=p0[:, 0], y=p0[:, 1], z=p0[:, 2], mode="markers", marker=dict(size=2, color=c0), showlegend=False, name="points", ), go.Scatter3d( x=x0, y=y0, z=z0, mode="lines", line=dict( width=2, color=tc0 if (tc0 is not None and tc0.size) else None, colorscale=colorscale if colorscale is not None else None, cmin=0, cmax=track_cmax, ), hoverinfo="skip", showlegend=False, name="tracks", ), ] ) steps = [] for i in range(S): pi = pts_xyz[i] ci = pts_cols[i] xi, yi, zi = trk_xyz[i] ti = trk_c[i] steps.append( dict( method="update", label=str(i), args=[ { "x": [pi[:, 0], xi], "y": [pi[:, 1], yi], "z": [pi[:, 2], zi], "marker.color": [ci, None], "line.color": [None, ti if (ti is not None and len(ti)) else None], }, {}, ], ) ) sliders = [ dict( active=0, currentvalue={"prefix": "Frame: ", "visible": True, "font": {"size": 14}}, pad={"t": 10}, len=0.6, x=0.2, font={"size": 8}, steps=steps, ) ] fig.update_layout( margin=dict(l=0, r=0, t=30, b=0), scene=scene_cfg, sliders=sliders, showlegend=False, title="Scrub frames with the slider below", uirevision="keep-camera", height=700, ) return fig @spaces.GPU(duration=120) def run_model(target_dir: str, model: VDPM, frame_id_arg=0) -> dict: require_cuda() image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*"))) if not image_names: raise ValueError("No images found. Check your upload.") images = load_and_preprocess_images(image_names).to(device) with torch.no_grad(): predictions = model.inference(None, images=images.unsqueeze(0)) pts_list = [pm["pts3d"].detach().cpu().numpy() for pm in predictions["pointmaps"]] conf_list = [pm["conf"].detach().cpu().numpy() for pm in predictions["pointmaps"]] world_points = np.concatenate(pts_list, axis=0) world_points_conf = np.concatenate(conf_list, axis=0) try: frame_id = int(frame_id_arg) except Exception: frame_id = 0 if frame_id >= world_points.shape[0]: frame_id = 0 world_points_s = world_points[:, frame_id, ::2, ::2, :] single_mask = world_points_conf[frame_id, frame_id, ::2, ::2] world_points_conf_s = np.tile(single_mask[np.newaxis, ...], (world_points.shape[0], 1, 1)) img_np = images.detach().cpu().numpy() img_np = img_np[frame_id : frame_id + 1, :, ::2, ::2] img_np = np.repeat(img_np, world_points.shape[0], axis=0) torch.cuda.empty_cache() return { "world_points": world_points_s, "world_points_conf": world_points_conf_s, "images": img_np, } def handle_uploads(input_video, input_images): start_time = time.time() gc.collect() torch.cuda.empty_cache() timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") target_dir = f"input_images_{timestamp}" target_dir_images = os.path.join(target_dir, "images") if os.path.exists(target_dir): shutil.rmtree(target_dir) os.makedirs(target_dir_images, exist_ok=True) image_paths = [] if input_images: for file_obj in input_images: src = gradio_file_path(file_obj) if not src: continue dst = os.path.join(target_dir_images, os.path.basename(src)) shutil.copy(src, dst) image_paths.append(dst) if input_video: video_path = gradio_file_path(input_video) vs = cv2.VideoCapture(video_path) fps = float(vs.get(cv2.CAP_PROP_FPS) or 0.0) frame_interval = max(int(fps / max(VIDEO_SAMPLE_HZ, 1e-6)), 1) count = 0 frame_num = 0 try: while True: gotit, frame = vs.read() if not gotit: break if count % frame_interval == 0: out_path = os.path.join(target_dir_images, f"{frame_num:06}.png") cv2.imwrite(out_path, frame) image_paths.append(out_path) frame_num += 1 count += 1 finally: vs.release() image_paths.sort() print(f"Files copied to {target_dir_images}; took {time.time() - start_time:.3f} seconds") return target_dir, image_paths def update_gallery_on_upload(input_video, input_images): if not input_video and not input_images: return None, None, None, None target_dir, image_paths = handle_uploads(input_video, input_images) return None, target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing." @spaces.GPU(duration=120) def gradio_reconstruct( target_dir, conf_thres=50.0, mask_black_bg=False, mask_white_bg=False, frame_id_val=0, ): if not os.path.isdir(target_dir) or target_dir == "None": return None, "No valid target directory found. Please upload first." gc.collect() torch.cuda.empty_cache() target_dir_images = os.path.join(target_dir, "images") num_frames = len(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else 0 with torch.no_grad(): predictions = run_model(target_dir, model, frame_id_val) predictions_path = os.path.join(target_dir, "predictions.npz") np.savez_compressed(predictions_path, **predictions) fig = build_pointcloud_figure_update(predictions, conf_thres, mask_black_bg, mask_white_bg) torch.cuda.empty_cache() msg = f"Reconstruction Success ({num_frames} frames processed, showing frame {frame_id_val})." return fig, msg def update_plot( target_dir, conf_thres, mask_black_bg, mask_white_bg, is_example, ): if is_example == "True": return None, "No reconstruction available. Please click the Reconstruct button first." if not target_dir or target_dir == "None" or not os.path.isdir(target_dir): return None, "No reconstruction available. Please click the Reconstruct button first." predictions_path = os.path.join(target_dir, "predictions.npz") if not os.path.exists(predictions_path): return None, f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first." key_list = [ "world_points", "world_points_conf", "images", ] loaded = np.load(predictions_path) predictions = {} for k in key_list: if k in loaded.files: predictions[k] = np.array(loaded[k]) if "world_points" not in predictions or "images" not in predictions: return None, f"Saved predictions at {predictions_path} are missing required keys." fig = build_pointcloud_figure_update(predictions, conf_thres, mask_black_bg, mask_white_bg) return fig, "Updated visualization with new settings. Use the slider below the plot to scrub frames." def clear_fields(): return None def update_log(): return "Loading and Reconstructing..." def example_pipeline( input_video_ex, num_images_str, input_images_ex, conf_thres_val, mask_black_bg_val, mask_white_bg_val, is_example_str, frame_id_val, ): target_dir, image_paths = handle_uploads(input_video_ex, input_images_ex) fig, log_msg = gradio_reconstruct( target_dir, conf_thres_val, mask_black_bg_val, mask_white_bg_val, frame_id_val, ) # We keep the same outputs as before (minus predictions_state) return fig, log_msg, target_dir, image_paths camel_video = "examples/videos/camel.mp4" tennis_video = "examples/videos/tennis.mp4" paragliding_video = "examples/videos/paragliding.mp4" stroller_video = "examples/videos/stroller.mp4" goldfish_video = "examples/videos/goldfish.mp4" horse_video = "examples/videos/horse.mp4" swing_video = "examples/videos/swing.mp4" car_video = "examples/videos/car.mp4" figure1_video = "examples/videos/figure1.mp4" figure2_video = "examples/videos/figure2.mp4" figure3_video = "examples/videos/figure3.mp4" tesla_video = "examples/videos/tesla.mp4" pstudio_video = "examples/videos/pstudio.mp4" theme = gr.themes.Default( primary_hue=gr.themes.colors.slate, secondary_hue=gr.themes.colors.zinc, neutral_hue=gr.themes.colors.slate, ).set( checkbox_label_background_fill_selected="*button_primary_background_fill", checkbox_label_text_color_selected="*button_primary_text_color", body_background_fill="#FFFFFF", ) css = """ .custom-log * { font-style: italic; font-size: 22px !important; background-image: linear-gradient(120deg, #1f2937 0%, #4b5563 100%); -webkit-background-clip: text; background-clip: text; font-weight: bold !important; color: transparent !important; text-align: center !important; } .example-log * { font-style: italic; font-size: 16px !important; background-image: linear-gradient(120deg, #1f2937 0%, #4b5563 100%); -webkit-background-clip: text; background-clip: text; color: transparent !important; } #my_radio .wrap { display: flex; flex-wrap: nowrap; justify-content: center; align-items: center; } #my_radio .wrap label { display: flex; width: 50%; justify-content: center; align-items: center; margin: 0; padding: 10px 0; box-sizing: border-box; } """ cfg = load_cfg_from_cli() model = load_model(cfg) with gr.Blocks(theme=theme, css=css) as demo: is_example = gr.Textbox(label="is_example", visible=False, value="None") num_images = gr.Textbox(label="num_images", visible=False, value="None") frame_id_state = gr.Textbox(label="frame_id", visible=False, value="0") gr.HTML( """
GitHub Repository | Project Page
Upload a video or a set of images to create a dynamic point map reconstruction of a scene or object.