# vdpm / app.py — V-DPM Gradio demo (Hugging Face Space)
# Upstream header: Edgar Sucar, "Fix links and title", commit 4e7bf8e
import gc
import glob
import os
import shutil
import sys
import time
from datetime import datetime
import cv2
import gradio as gr
import matplotlib
import numpy as np
import plotly.graph_objects as go
import torch
import spaces
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from vggt.utils.load_fn import load_and_preprocess_images
sys.path.append("dpm/")
from dpm.model import VDPM
TRAIL_LENGTH = 20  # number of past frames drawn in each track trail
MAX_TRACKS = 150  # cap on rendered tracks; evenly subsampled above this
STATIC_THRESHOLD = 0.025  # min displacement from frame 0 for a point to count as dynamic
VIDEO_SAMPLE_HZ = 1.0  # target sampling rate (frames/sec) when extracting video frames
# Global compute device; require_cuda() later enforces "cuda" before inference.
device = "cuda" if torch.cuda.is_available() else "cpu"
def load_cfg_from_cli() -> "omegaconf.DictConfig":
    """Compose the `visualise` Hydra config from the local `configs/` directory.

    Clears any live Hydra singleton first so repeated calls (e.g. app
    reloads) do not raise on re-initialisation.
    """
    hydra_state = GlobalHydra.instance()
    if hydra_state.is_initialized():
        hydra_state.clear()
    with initialize(config_path="configs"):
        return compose(config_name="visualise")
def load_model(cfg) -> VDPM:
    """Build a VDPM from `cfg`, load pretrained Hub weights, and switch to eval mode."""
    weights_url = "https://huggingface.co/edgarsucar/vdpm/resolve/main/model.pt"
    net = VDPM(cfg).to(device)
    state = torch.hub.load_state_dict_from_url(
        weights_url,
        file_name="vdpm_model.pt",
        progress=True,
    )
    # Print the (expected-empty) missing/unexpected-keys report for visibility.
    print(net.load_state_dict(state, strict=True))
    net.eval()
    return net
def require_cuda():
    """Raise ValueError unless the process came up with a CUDA device."""
    if device == "cuda":
        return
    raise ValueError("CUDA is not available. Check your environment.")
def gradio_file_path(file_obj):
    """Normalize a Gradio upload to a filesystem path.

    Gradio may hand back ``None``, a dict carrying a ``"name"`` entry, or a
    plain path string; anything unrecognised passes through unchanged.
    """
    if isinstance(file_obj, dict) and "name" in file_obj:
        return file_obj["name"]
    return file_obj
def ensure_nhwc_images(images: np.ndarray) -> np.ndarray:
    """Convert a channel-first (S, 3, H, W) batch to (S, H, W, 3); NHWC passes through."""
    looks_nchw = images.ndim == 4 and images.shape[1] == 3
    return images.transpose(0, 2, 3, 1) if looks_nchw else images
def compute_scene_bounds(world_points: np.ndarray):
    """Axis-aligned scene bounds (padded 5%) and a Plotly aspect-ratio dict.

    Returns (global_min, global_max, aspectratio) where aspectratio maps
    "x"/"y"/"z" to each half-extent normalised by the largest half-extent.
    """
    flat = world_points.reshape(-1, 3)
    lo = flat.min(axis=0)
    hi = flat.max(axis=0)
    mid = 0.5 * (lo + hi)
    half = 0.5 * (hi - lo) * 1.05
    # Degenerate clouds: a fully collapsed cloud gets a unit box, while any
    # individually-flat axis inherits the largest extent.
    if np.all(half < 1e-6):
        half[:] = 1.0
    else:
        half[half < 1e-6] = half.max()
    biggest = half.max()
    aspect = {axis: float(h / biggest) for axis, h in zip("xyz", half)}
    return mid - half, mid + half, aspect
def compute_point_mask(
conf_score: np.ndarray | None,
cols: np.ndarray,
conf_thres: float,
mask_black_bg: bool,
mask_white_bg: bool,
) -> np.ndarray:
"""
conf_score: (N,) or None
cols: (N,3) uint8
Returns: (N,) boolean mask
"""
mask = np.ones(cols.shape[0], dtype=bool)
# confidence percentile threshold (same semantics as before)
if conf_score is not None and conf_thres > 0:
thresh = np.percentile(conf_score, conf_thres)
mask &= (conf_score >= thresh) & (conf_score > 1e-5)
# background masks (same as before)
if mask_black_bg:
mask &= (cols.sum(axis=1) >= 16)
if mask_white_bg:
mask &= ~((cols[:, 0] > 240) & (cols[:, 1] > 240) & (cols[:, 2] > 240))
return mask
def sample_frame_points(
    world_points: np.ndarray,
    images_nhwc: np.ndarray,
    conf: np.ndarray | None,
    idx: int,
    conf_thres: float,
    mask_black_bg: bool,
    mask_white_bg: bool,
):
    """Flatten frame `idx` into filtered 3D points and '#rrggbb' colour strings.

    `idx` is clamped to the valid frame range. When every point is filtered
    out, a single white origin point is returned so the plot stays valid.
    """
    frame = int(np.clip(idx, 0, world_points.shape[0] - 1))
    xyz = world_points[frame].reshape(-1, 3)
    rgb = (images_nhwc[frame].reshape(-1, 3) * 255).astype(np.uint8)
    scores = None if conf is None else conf[frame].reshape(-1)
    keep = compute_point_mask(
        conf_score=scores,
        cols=rgb,
        conf_thres=conf_thres,
        mask_black_bg=mask_black_bg,
        mask_white_bg=mask_white_bg,
    )
    xyz = xyz[keep]
    rgb = rgb[keep]
    if xyz.size == 0:
        # Plotly needs at least one marker; fall back to a white origin point.
        xyz = np.array([[0.0, 0.0, 0.0]])
        rgb = np.array([[255, 255, 255]], dtype=np.uint8)
    hex_cols = [f"#{r:02x}{g:02x}{b:02x}" for r, g, b in rgb]
    return xyz, hex_cols
def prepare_tracks(
    world_points: np.ndarray,
    images_nhwc: np.ndarray,
    conf: np.ndarray | None,
    conf_thres: float,
    mask_black_bg: bool,
    mask_white_bg: bool,
):
    """Select dynamic points to draw as track trails and build their colours.

    world_points: (S, H, W, 3) per-frame 3D positions.
    images_nhwc: (S, H, W, 3) images in [0, 1]; frame 0 supplies the colours
        used for background masking.
    conf: optional per-point confidences reshaped to (S, H*W) and averaged
        over time.

    Returns (tracks_xyz, colorscale, track_ids), or (None, None, None) when
    there are fewer than 2 frames or no dynamic point survives filtering.
    tracks_xyz is (S, T, 3); colorscale is a Plotly [pos, "rgb(...)"] list;
    track_ids are float ids 0..T-1 used as line colour values.
    """
    S, H, W, _ = world_points.shape
    N = H * W
    if S < 2 or N == 0:
        return None, None, None
    tracks_xyz = world_points.reshape(S, N, 3)
    # A point is "dynamic" if it ever strays more than STATIC_THRESHOLD
    # from its frame-0 position.
    disp = np.linalg.norm(tracks_xyz - tracks_xyz[0:1], axis=-1)
    dynamic_mask = disp.max(axis=0) > STATIC_THRESHOLD
    # build a per-point confidence score (across time)
    conf_score = None
    if conf is not None:
        conf_flat = conf.reshape(S, N)
        conf_score = conf_flat.mean(axis=0)
    # Use reference-frame colors for background masking (stable, consistent)
    ref_cols = (images_nhwc[0].reshape(-1, 3) * 255).astype(np.uint8)
    point_mask = compute_point_mask(
        conf_score=conf_score,
        cols=ref_cols,
        conf_thres=conf_thres,
        mask_black_bg=mask_black_bg,
        mask_white_bg=mask_white_bg,
    )
    dynamic_mask &= point_mask
    idx_tracks = np.nonzero(dynamic_mask)[0]
    if idx_tracks.size == 0:
        return None, None, None
    # Evenly subsample when there are more dynamic points than MAX_TRACKS.
    if idx_tracks.size > MAX_TRACKS:
        step = int(np.ceil(idx_tracks.size / MAX_TRACKS))
        idx_tracks = idx_tracks[::step][:MAX_TRACKS]
    tracks_xyz = tracks_xyz[:, idx_tracks, :]
    # Sort by the frame-0 y-coordinate so the colour gradient is spatially coherent.
    order = np.argsort(tracks_xyz[0, :, 1])
    tracks_xyz = tracks_xyz[:, order, :]
    num_tracks = tracks_xyz.shape[1]
    # One distinct hue per track from the cyclic HSV colormap, expressed as a
    # Plotly colorscale keyed by normalised track index.
    cmap = matplotlib.colormaps.get_cmap("hsv")
    norm = matplotlib.colors.Normalize(vmin=0, vmax=max(num_tracks - 1, 1))
    colorscale = []
    for t in range(num_tracks):
        r, g, b, _ = cmap(norm(t))
        r, g, b = int(r * 255), int(g * 255), int(b * 255)
        pos = t / max(num_tracks - 1, 1)
        colorscale.append([pos, f"rgb({r},{g},{b})"])
    track_ids = np.arange(num_tracks, dtype=float)
    return tracks_xyz, colorscale, track_ids
def track_segments_for_frame(tracks_xyz: np.ndarray | None, track_ids: np.ndarray | None, f: int):
if tracks_xyz is None or track_ids is None or f <= 0:
return np.array([]), np.array([]), np.array([]), np.array([])
start_t = max(0, f - TRAIL_LENGTH)
num_tracks = tracks_xyz.shape[1]
xs, ys, zs, cs = [], [], [], []
for j in range(num_tracks):
seg = tracks_xyz[start_t : f + 1, j, :]
if seg.shape[0] < 2:
continue
xs.extend([seg[:, 0], np.array([np.nan])])
ys.extend([seg[:, 1], np.array([np.nan])])
zs.extend([seg[:, 2], np.array([np.nan])])
cs.append(np.full(seg.shape[0] + 1, track_ids[j], dtype=float))
x = np.concatenate(xs) if xs else np.array([])
y = np.concatenate(ys) if ys else np.array([])
z = np.concatenate(zs) if zs else np.array([])
c = np.concatenate(cs) if cs else np.array([])
return x, y, z, c
def build_pointcloud_figure_update(
    data,
    conf_thres: float,
    mask_black_bg: bool,
    mask_white_bg: bool,
):
    """Build the interactive Plotly figure: per-frame point cloud + track trails.

    data: dict with "world_points" (S, H, W, 3), optional "world_points_conf",
        and "images" (NCHW or NHWC floats in [0, 1]); None yields an empty figure.
    conf_thres: confidence percentile used for point filtering.
    mask_black_bg / mask_white_bg: background-colour filters.

    Returns a go.Figure with two traces ("points" markers and "tracks" lines)
    plus a slider whose steps swap in each precomputed frame's data.
    """
    if data is None:
        return go.Figure()
    world_points = data["world_points"]
    conf = data.get("world_points_conf")
    images = ensure_nhwc_images(data["images"])
    S = world_points.shape[0]
    # Fixed global bounds so the axes/camera do not jump between frames.
    global_min, global_max, aspectratio = compute_scene_bounds(world_points)
    tracks_xyz, colorscale, track_ids = prepare_tracks(
        world_points=world_points,
        images_nhwc=images,
        conf=conf,
        conf_thres=conf_thres,
        mask_black_bg=mask_black_bg,
        mask_white_bg=mask_white_bg,
    )
    # Upper bound for the track-colour axis (>= 1 so cmin < cmax always).
    track_cmax = max(len(track_ids) - 1, 1) if track_ids is not None else 1
    # Precompute every frame's points and trail segments up front so the
    # slider steps are pure data swaps (no server round-trips).
    pts_xyz = [None] * S
    pts_cols = [None] * S
    trk_xyz = [None] * S
    trk_c = [None] * S
    for i in range(S):
        pts_i, cols_i = sample_frame_points(
            world_points=world_points,
            images_nhwc=images,
            conf=conf,
            idx=i,
            conf_thres=conf_thres,
            mask_black_bg=mask_black_bg,
            mask_white_bg=mask_white_bg,
        )
        pts_xyz[i] = pts_i
        pts_cols[i] = cols_i
        x, y, z, c = track_segments_for_frame(tracks_xyz, track_ids, f=i)
        trk_xyz[i] = (x, y, z)
        trk_c[i] = c
    # Frame 0 seeds the initial traces; the slider swaps in the others.
    p0 = pts_xyz[0]
    c0 = pts_cols[0]
    x0, y0, z0 = trk_xyz[0]
    tc0 = trk_c[0]
    # Hidden axes with fixed ranges; camera up is (0, -1, 0), i.e. the scene
    # y-axis points downward in the view.
    scene_cfg = dict(
        xaxis=dict(
            visible=False,
            showbackground=False,
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            range=[float(global_min[0]), float(global_max[0])],
        ),
        yaxis=dict(
            visible=False,
            showbackground=False,
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            range=[float(global_min[1]), float(global_max[1])],
        ),
        zaxis=dict(
            visible=False,
            showbackground=False,
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            range=[float(global_min[2]), float(global_max[2])],
        ),
        aspectmode="manual",
        aspectratio=aspectratio,
        dragmode="orbit",
        camera=dict(
            eye=dict(x=0.0, y=0.0, z=-1.0),
            center=dict(x=0.0, y=0.0, z=0.0),
            up=dict(x=0.0, y=-1.0, z=0.0),
        ),
    )
    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=p0[:, 0],
                y=p0[:, 1],
                z=p0[:, 2],
                mode="markers",
                marker=dict(size=2, color=c0),
                showlegend=False,
                name="points",
            ),
            go.Scatter3d(
                x=x0,
                y=y0,
                z=z0,
                mode="lines",
                line=dict(
                    width=2,
                    # Colour lines by track id when frame 0 has trail data.
                    color=tc0 if (tc0 is not None and tc0.size) else None,
                    colorscale=colorscale if colorscale is not None else None,
                    cmin=0,
                    cmax=track_cmax,
                ),
                hoverinfo="skip",
                showlegend=False,
                name="tracks",
            ),
        ]
    )
    # One slider step per frame; each step restyles both traces in place.
    steps = []
    for i in range(S):
        pi = pts_xyz[i]
        ci = pts_cols[i]
        xi, yi, zi = trk_xyz[i]
        ti = trk_c[i]
        steps.append(
            dict(
                method="update",
                label=str(i),
                args=[
                    {
                        # Per-trace lists: [points trace, tracks trace].
                        "x": [pi[:, 0], xi],
                        "y": [pi[:, 1], yi],
                        "z": [pi[:, 2], zi],
                        "marker.color": [ci, None],
                        "line.color": [None, ti if (ti is not None and len(ti)) else None],
                    },
                    {},
                ],
            )
        )
    sliders = [
        dict(
            active=0,
            currentvalue={"prefix": "Frame: ", "visible": True, "font": {"size": 14}},
            pad={"t": 10},
            len=0.6,
            x=0.2,
            font={"size": 8},
            steps=steps,
        )
    ]
    fig.update_layout(
        margin=dict(l=0, r=0, t=30, b=0),
        scene=scene_cfg,
        sliders=sliders,
        showlegend=False,
        title="Scrub frames with the slider below",
        # uirevision keeps the user's camera position across restyles.
        uirevision="keep-camera",
        height=700,
    )
    return fig
@spaces.GPU(duration=120)
def run_model(target_dir: str, model: VDPM, frame_id_arg=0) -> dict:
    """Run VDPM inference on the images under `target_dir` and package plot arrays.

    target_dir: working directory containing an `images/` subfolder.
    model: loaded VDPM network.
    frame_id_arg: reference frame index; non-numeric or out-of-range values
        fall back to 0.

    Returns a dict with "world_points", "world_points_conf" and "images",
    all spatially downsampled 2x.

    Raises ValueError when CUDA is unavailable or no images were found.
    """
    require_cuda()
    image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
    if not image_names:
        raise ValueError("No images found. Check your upload.")
    images = load_and_preprocess_images(image_names).to(device)
    with torch.no_grad():
        predictions = model.inference(None, images=images.unsqueeze(0))
    # Stack per-pointmap outputs; the concatenated arrays are indexed with two
    # leading axes below, so this presumably forms pairwise (S, S, H, W, ...)
    # pointmaps — TODO confirm against the VDPM model's output contract.
    pts_list = [pm["pts3d"].detach().cpu().numpy() for pm in predictions["pointmaps"]]
    conf_list = [pm["conf"].detach().cpu().numpy() for pm in predictions["pointmaps"]]
    world_points = np.concatenate(pts_list, axis=0)
    world_points_conf = np.concatenate(conf_list, axis=0)
    try:
        frame_id = int(frame_id_arg)
    except Exception:
        frame_id = 0
    if frame_id >= world_points.shape[0]:
        frame_id = 0
    # Track the reference frame's geometry across all times, downsampled 2x.
    world_points_s = world_points[:, frame_id, ::2, ::2, :]
    # Single confidence map tiled over time. NOTE(review): frame_id is used
    # for both leading axes — looks intentional (reference frame in its own
    # view), but verify.
    single_mask = world_points_conf[frame_id, frame_id, ::2, ::2]
    world_points_conf_s = np.tile(single_mask[np.newaxis, ...], (world_points.shape[0], 1, 1))
    img_np = images.detach().cpu().numpy()
    # Repeat the reference frame's image so point colours stay fixed over time.
    img_np = img_np[frame_id : frame_id + 1, :, ::2, ::2]
    img_np = np.repeat(img_np, world_points.shape[0], axis=0)
    torch.cuda.empty_cache()
    return {
        "world_points": world_points_s,
        "world_points_conf": world_points_conf_s,
        "images": img_np,
    }
def handle_uploads(input_video, input_images):
    """Copy uploaded images and/or sampled video frames into a fresh work dir.

    Video frames are sampled at roughly VIDEO_SAMPLE_HZ. Returns
    (target_dir, sorted list of image paths written under target_dir/images).
    """
    t0 = time.time()
    gc.collect()
    torch.cuda.empty_cache()
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{stamp}"
    images_dir = os.path.join(target_dir, "images")
    # Start from a clean directory even on a (very unlikely) timestamp clash.
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(images_dir, exist_ok=True)
    image_paths = []
    # Copy direct image uploads as-is.
    for file_obj in (input_images or []):
        src = gradio_file_path(file_obj)
        if not src:
            continue
        dst = os.path.join(images_dir, os.path.basename(src))
        shutil.copy(src, dst)
        image_paths.append(dst)
    if input_video:
        cap = cv2.VideoCapture(gradio_file_path(input_video))
        fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0)
        # Keep every Nth frame to approximate VIDEO_SAMPLE_HZ; at least every frame.
        stride = max(int(fps / max(VIDEO_SAMPLE_HZ, 1e-6)), 1)
        read_idx = 0
        saved_idx = 0
        try:
            while True:
                ok, frame = cap.read()
                if not ok:
                    break
                if read_idx % stride == 0:
                    out_path = os.path.join(images_dir, f"{saved_idx:06}.png")
                    cv2.imwrite(out_path, frame)
                    image_paths.append(out_path)
                    saved_idx += 1
                read_idx += 1
        finally:
            cap.release()
    image_paths.sort()
    print(f"Files copied to {images_dir}; took {time.time() - t0:.3f} seconds")
    return target_dir, image_paths
def update_gallery_on_upload(input_video, input_images):
    """Stage uploads immediately so the gallery preview fills before reconstruction."""
    if input_video or input_images:
        target_dir, image_paths = handle_uploads(input_video, input_images)
        return None, target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing."
    return None, None, None, None
@spaces.GPU(duration=120)
def gradio_reconstruct(
    target_dir,
    conf_thres=50.0,
    mask_black_bg=False,
    mask_white_bg=False,
    frame_id_val=0,
):
    """Run VDPM on the staged images and return (plotly figure, status message).

    target_dir: working directory created by handle_uploads; the "None"
        sentinel, an empty/None value, or a missing directory short-circuits
        with an error message instead of raising.
    conf_thres: confidence percentile used when filtering points.
    mask_black_bg / mask_white_bg: background-colour filters.
    frame_id_val: reference frame index forwarded to run_model.
    """
    # Check for a usable string before touching the filesystem:
    # os.path.isdir(None) raises TypeError, so the falsy/sentinel tests come first.
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return None, "No valid target directory found. Please upload first."
    gc.collect()
    torch.cuda.empty_cache()
    target_dir_images = os.path.join(target_dir, "images")
    num_frames = len(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else 0
    with torch.no_grad():
        predictions = run_model(target_dir, model, frame_id_val)
    # Persist predictions so update_plot can re-render without re-running the model.
    predictions_path = os.path.join(target_dir, "predictions.npz")
    np.savez_compressed(predictions_path, **predictions)
    fig = build_pointcloud_figure_update(predictions, conf_thres, mask_black_bg, mask_white_bg)
    torch.cuda.empty_cache()
    msg = f"Reconstruction Success ({num_frames} frames processed, showing frame {frame_id_val})."
    return fig, msg
def update_plot(
    target_dir,
    conf_thres,
    mask_black_bg,
    mask_white_bg,
    is_example,
):
    """Re-render the point-cloud figure from saved predictions with new filter settings."""
    no_recon = (None, "No reconstruction available. Please click the Reconstruct button first.")
    # Example rows never have a saved reconstruction to refresh.
    if is_example == "True":
        return no_recon
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return no_recon
    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return None, f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first."
    archive = np.load(predictions_path)
    wanted = ("world_points", "world_points_conf", "images")
    predictions = {k: np.array(archive[k]) for k in wanted if k in archive.files}
    if "world_points" not in predictions or "images" not in predictions:
        return None, f"Saved predictions at {predictions_path} are missing required keys."
    fig = build_pointcloud_figure_update(predictions, conf_thres, mask_black_bg, mask_white_bg)
    return fig, "Updated visualization with new settings. Use the slider below the plot to scrub frames."
def clear_fields():
    """Blank the 3D plot while a new reconstruction is being prepared."""
    return None
def update_log():
    """Status text shown while reconstruction is running."""
    return "Loading and Reconstructing..."
def example_pipeline(
    input_video_ex,
    num_images_str,
    input_images_ex,
    conf_thres_val,
    mask_black_bg_val,
    mask_white_bg_val,
    is_example_str,
    frame_id_val,
):
    """Full pipeline for gr.Examples rows: stage the files, then reconstruct.

    num_images_str / is_example_str are example-table metadata accepted
    positionally but not used here. Returns (figure, log message,
    target directory, staged image paths).
    """
    target_dir, image_paths = handle_uploads(input_video_ex, input_images_ex)
    fig_and_log = gradio_reconstruct(
        target_dir,
        conf_thres_val,
        mask_black_bg_val,
        mask_white_bg_val,
        frame_id_val,
    )
    return (*fig_and_log, target_dir, image_paths)
# Example video assets referenced by the gr.Examples table below.
camel_video = "examples/videos/camel.mp4"
tennis_video = "examples/videos/tennis.mp4"
paragliding_video = "examples/videos/paragliding.mp4"
stroller_video = "examples/videos/stroller.mp4"
goldfish_video = "examples/videos/goldfish.mp4"
horse_video = "examples/videos/horse.mp4"
swing_video = "examples/videos/swing.mp4"
car_video = "examples/videos/car.mp4"
figure1_video = "examples/videos/figure1.mp4"
figure2_video = "examples/videos/figure2.mp4"
figure3_video = "examples/videos/figure3.mp4"
tesla_video = "examples/videos/tesla.mp4"
pstudio_video = "examples/videos/pstudio.mp4"
# Muted slate/zinc Gradio theme on a plain white page background.
theme = gr.themes.Default(
    primary_hue=gr.themes.colors.slate,
    secondary_hue=gr.themes.colors.zinc,
    neutral_hue=gr.themes.colors.slate,
).set(
    checkbox_label_background_fill_selected="*button_primary_background_fill",
    checkbox_label_text_color_selected="*button_primary_text_color",
    body_background_fill="#FFFFFF",
)
# Custom CSS: gradient-text styling for the log banners and a centred
# two-option radio layout.
css = """
.custom-log * {
font-style: italic;
font-size: 22px !important;
background-image: linear-gradient(120deg, #1f2937 0%, #4b5563 100%);
-webkit-background-clip: text;
background-clip: text;
font-weight: bold !important;
color: transparent !important;
text-align: center !important;
}
.example-log * {
font-style: italic;
font-size: 16px !important;
background-image: linear-gradient(120deg, #1f2937 0%, #4b5563 100%);
-webkit-background-clip: text;
background-clip: text;
color: transparent !important;
}
#my_radio .wrap {
display: flex;
flex-wrap: nowrap;
justify-content: center;
align-items: center;
}
#my_radio .wrap label {
display: flex;
width: 50%;
justify-content: center;
align-items: center;
margin: 0;
padding: 10px 0;
box-sizing: border-box;
}
"""
# Load the Hydra config and instantiate the pretrained model once at startup.
cfg = load_cfg_from_cli()
model = load_model(cfg)
# UI layout and event wiring for the Gradio app.
with gr.Blocks(theme=theme, css=css) as demo:
    # Hidden state fields: example-table metadata, the reference frame id,
    # and the working directory shared between callbacks.
    is_example = gr.Textbox(label="is_example", visible=False, value="None")
    num_images = gr.Textbox(label="num_images", visible=False, value="None")
    frame_id_state = gr.Textbox(label="frame_id", visible=False, value="0")
    # Page header with project links.
    gr.HTML(
        """
<h1>V-DPM: 4D Video Reconstruction with Dynamic Point Maps</h1>
<p>
<a href="https://github.com/eldar/vdpm">GitHub Repository</a> |
<a href="https://www.robots.ox.ac.uk/~vgg/research/vdpm">Project Page</a>
</p>
<div style="font-size: 16px; line-height: 1.5;">
<p>Upload a video or a set of images to create a dynamic point map reconstruction of a scene or object.</p>
</div>
"""
    )
    target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
    with gr.Row():
        # Left column: upload widgets and staged-image preview.
        with gr.Column(scale=2):
            input_video = gr.Video(label="Upload Video", interactive=True)
            input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
            image_gallery = gr.Gallery(
                label="Preview",
                columns=4,
                height="300px",
                show_download_button=True,
                object_fit="contain",
                preview=True,
            )
        # Right column: status log and the 3D plot output.
        with gr.Column(scale=5):
            gr.Markdown("**3D Reconstruction (Point Cloud)**")
            log_output = gr.Markdown(
                "Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"]
            )
            reconstruction_output = gr.Plot(label="3D Point Cloud")
            with gr.Row():
                submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
                gr.ClearButton(
                    [input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery],
                    scale=1,
                )
            # Post-hoc filtering controls (re-render without re-running the model).
            with gr.Row():
                conf_thres = gr.Slider(0, 100, value=50, step=1, label="Confidence Threshold (%)")
                with gr.Column():
                    mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
                    mask_white_bg = gr.Checkbox(label="Filter White Background", value=False)
    # Example rows: [video, num_images, images, conf_thres, black_bg, white_bg,
    # is_example, frame_id] matching the gr.Examples inputs below.
    examples = [
        [camel_video, "17", None, 15.0, False, False, "True", "8"],
        [horse_video, "18", None, 50.0, False, False, "True", "2"],
        [tennis_video, "11", None, 5.0, False, False, "True", "0"],
        [paragliding_video, "11", None, 5.0, False, False, "True", "0"],
        [stroller_video, "17", None, 10.0, False, False, "True", "8"],
        [goldfish_video, "11", None, 12.0, False, False, "True", "5"],
        [swing_video, "10", None, 40.0, False, False, "True", "4"],
        [car_video, "13", None, 15.0, False, False, "True", "7"],
        [figure1_video, "10", None, 25.0, False, False, "True", "0"],
        [figure2_video, "12", None, 25.0, False, False, "True", "6"],
        [figure3_video, "13", None, 30.0, False, False, "True", "0"],
        [tesla_video, "18", None, 20.0, False, True, "True", "0"],
        [pstudio_video, "12", None, 0.0, False, False, "True", "6"],
    ]
    gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
    gr.Examples(
        examples=examples,
        inputs=[
            input_video,
            num_images,
            input_images,
            conf_thres,
            mask_black_bg,
            mask_white_bg,
            is_example,
            frame_id_state,
        ],
        outputs=[
            reconstruction_output,
            log_output,
            target_dir_output,
            image_gallery,
        ],
        fn=example_pipeline,
        cache_examples=False,
        examples_per_page=50,
    )
    # Reconstruct click: clear plot -> show progress text -> run model ->
    # mark state as a non-example run.
    submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
        fn=update_log, inputs=[], outputs=[log_output]
    ).then(
        fn=gradio_reconstruct,
        inputs=[
            target_dir_output,
            conf_thres,
            mask_black_bg,
            mask_white_bg,
            frame_id_state,
        ],
        outputs=[reconstruction_output, log_output],
    ).then(
        fn=lambda: "False", inputs=[], outputs=[is_example]
    )
    # Any filter change re-renders from the saved predictions.
    for ctrl in (conf_thres, mask_black_bg, mask_white_bg):
        ctrl.change(
            fn=update_plot,
            inputs=[
                target_dir_output,
                conf_thres,
                mask_black_bg,
                mask_white_bg,
                is_example,
            ],
            outputs=[reconstruction_output, log_output],
        )
    # Uploads immediately stage files and refresh the preview gallery.
    input_video.change(
        fn=update_gallery_on_upload,
        inputs=[input_video, input_images],
        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
    )
    input_images.change(
        fn=update_gallery_on_upload,
        inputs=[input_video, input_images],
        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
    )
demo.queue(max_size=20).launch(show_error=True, share=True, ssr_mode=False)