|
|
import os |
|
|
import tempfile |
|
|
import uuid |
|
|
import traceback |
|
|
import spaces |
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
import cv2 |
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
from glob import glob |
|
|
from natsort import natsorted |
|
|
|
|
|
import rerun as rr |
|
|
import rerun.blueprint as rrb |
|
|
from gradio_rerun import Rerun |
|
|
|
|
|
import hydra |
|
|
from tqdm import tqdm |
|
|
import matplotlib.pyplot as plt |
|
|
import matplotlib as mpl |
|
|
import matplotlib.cm as cm |
|
|
from matplotlib.colors import hsv_to_rgb |
|
|
from PIL import Image |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
from mapanything.utils.image import load_images, rgb |
|
|
from mapanything.utils.misc import seed_everything |
|
|
from mapanything.utils.moge_inference import load_moge_model |
|
|
from mapanything.utils.inference import loss_of_one_batch_multi_view |
|
|
from mapanything.models import init_model |
|
|
from mapanything.utils.geometry import ( |
|
|
quaternion_to_rotation_matrix, |
|
|
recover_pinhole_intrinsics_from_ray_directions, |
|
|
normals_edge, |
|
|
depth_edge, |
|
|
points_to_normals, |
|
|
) |
|
|
|
|
|
def log_data_to_rerun(image, depthmap, pose, intrinsics, pts3d, mask, base_name, pts_name, viz_mask=None):
    """Log a posed pinhole camera and its RGB image to the active Rerun recording.

    Logs three entities under ``base_name``: the camera-to-world transform,
    the pinhole projection, and the RGB image.

    Note: ``depthmap``, ``pts3d``, ``mask``, ``pts_name`` and ``viz_mask`` are
    accepted for signature compatibility with callers but are not used here.

    Args:
        image: HxWx3 RGB image array.
        pose: 4x4 camera-to-world pose matrix.
        intrinsics: 3x3 pinhole intrinsics matrix.
        base_name: Rerun entity path prefix for this camera.
    """
    img_height, img_width = image.shape[:2]

    # Camera-to-world transform (from_parent=False means child-to-parent).
    camera_transform = rr.Transform3D(
        translation=pose[:3, 3],
        mat3x3=pose[:3, :3],
        from_parent=False,
    )
    rr.log(base_name, camera_transform)

    # Pinhole model; RDF matches the OpenCV-style camera convention.
    pinhole = rr.Pinhole(
        image_from_camera=intrinsics,
        height=img_height,
        width=img_width,
        camera_xyz=rr.ViewCoordinates.RDF,
    )
    rr.log(f"{base_name}/pinhole", pinhole)

    rr.log(f"{base_name}/pinhole/rgb", rr.Image(image))
|
|
|
|
|
|
|
def log_points_to_rerun(image, pts3d, pts_name, mask=None):
    """Log a colored 3D point cloud to the active Rerun recording.

    Args:
        image: HxWx3 RGB array supplying per-point colors.
        pts3d: HxWx3 array of 3D points aligned with ``image``.
        pts_name: Rerun entity path for the point cloud.
        mask: Optional boolean mask selecting which pixels to log; when None,
            every pixel is logged.
    """
    if mask is not None:
        points = pts3d[mask]
        point_colors = image[mask]
    else:
        points = pts3d
        point_colors = image

    rr.log(
        pts_name,
        rr.Points3D(
            positions=points.reshape(-1, 3),
            colors=point_colors.reshape(-1, 3),
        ),
    )
|
|
|
|
|
|
|
|
def init_hydra_config(config_path, overrides=None):
    """Initialize Hydra and compose a config from a YAML file path.

    Args:
        config_path: Filesystem path to the YAML config file.
        overrides: Optional list of Hydra override strings (e.g. "model=any4d").

    Returns:
        The composed config object.
    """
    config_dir = os.path.dirname(config_path)
    # splitext (not split(".")) so dotted filenames like "train.v2.yaml"
    # keep their full stem, which is the config name Hydra expects.
    config_name = os.path.splitext(os.path.basename(config_path))[0]
    # hydra.initialize() requires a path relative to this source file.
    relative_path = os.path.relpath(config_dir, os.path.dirname(__file__))
    # Clear any previously-initialized global Hydra state so repeated calls
    # (e.g. multiple demo runs in one process) do not raise.
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    hydra.initialize(version_base=None, config_path=relative_path)
    if overrides is not None:
        cfg = hydra.compose(config_name=config_name, overrides=overrides)
    else:
        cfg = hydra.compose(config_name=config_name)

    return cfg
|
|
|
|
|
|
|
|
def log_point_tracks(point_tracks):
    """
    Log the point tracks for visualization.

    Each track is drawn as a 3D line strip, colored by the x-coordinate of its
    initial position so that spatially-close tracks receive similar colors.

    Args:
        point_tracks: Array of shape (num_timesteps, num_tracks, 3) with the
            3D position of every track at every timestep.
    """
    # Normalize each track's initial x position into [0, 1] for color lookup.
    initial_positions = point_tracks[0]
    x_coords = initial_positions[:, 0]
    x_min = x_coords.min()
    x_max = x_coords.max()

    norm = mpl.colors.Normalize(vmin=x_min, vmax=x_max)
    normalized_x = norm(x_coords)
    # mpl.colormaps[...] replaces cm.get_cmap(), which was deprecated in
    # matplotlib 3.7 and removed in 3.9.
    colormap = mpl.colormaps["rainbow"]
    track_colors = colormap(normalized_x)[:, :3]  # drop alpha channel

    for track_idx in range(point_tracks.shape[1]):
        track = point_tracks[:, track_idx, :]
        rr.log(
            f"pred/point_tracks/track_{track_idx}",
            rr.LineStrips3D(
                strips=[track],
                colors=[track_colors[track_idx]],
            )
        )
|
|
|
|
|
|
|
|
def init_inference_model(config, ckpt_path, device):
    """Initialize the model for inference and optionally load a checkpoint.

    Args:
        config: Either a plain config-file path, or a dict with keys "path"
            (config file) and "config_overrides" (list of Hydra overrides);
            the dict form nests the model config under ``model``.
        ckpt_path: Optional checkpoint path; when given, its "model" state
            dict is loaded non-strictly.
        device: Torch device the model is placed on.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    if isinstance(config, dict):
        cfg = init_hydra_config(config["path"], overrides=config["config_overrides"])
        model = init_model(cfg.model.model_str, cfg.model.model_config)
    else:
        cfg = init_hydra_config(config)
        model = init_model(cfg.model_str, cfg.model_config)
    model.to(device)

    if ckpt_path is not None:
        print("Loading model from: ", ckpt_path)
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        # Non-strict load: print (rather than fail on) missing/unexpected keys.
        print(model.load_state_dict(checkpoint["model"], strict=False))
        model.to(device)

    model.eval()

    return model
|
|
|
|
|
|
|
|
@torch.no_grad()
def sample_inference(model, views, device, use_amp):
    """Run a single gradient-free multi-view forward pass.

    Thin wrapper around ``loss_of_one_batch_multi_view`` with no criterion,
    so only predictions (no loss) are computed.

    Args:
        model: The inference model.
        views: List of view dicts forming one batch.
        device: Torch device for the forward pass.
        use_amp: Whether to run under automatic mixed precision.

    Returns:
        The result dict produced by ``loss_of_one_batch_multi_view``.
    """
    return loss_of_one_batch_multi_view(
        views,
        model,
        None,  # criterion: skip loss computation
        device,
        use_amp=use_amp,
    )
|
|
|
|
|
|
|
|
class Any4DProcessor:
    """Handles Any4D processing and Rerun visualization.

    Owns the model checkpoint (downloading it from the Hugging Face Hub if
    needed) and drives the full pipeline: frame extraction, inference, and
    incremental streaming of results to an embedded Rerun viewer.
    """

    def __init__(self, config_dir, checkpoint_path, machine="local"):
        # config_dir: directory containing the Hydra "train.yaml" config.
        # checkpoint_path: local checkpoint path, or None to auto-download.
        # machine: Hydra "machine" override selecting per-machine settings.
        self.config_dir = config_dir
        self.machine = machine
        self.checkpoint_path = checkpoint_path

        # Fall back to the Hub download whenever no usable local file exists.
        if checkpoint_path is None or not os.path.exists(checkpoint_path):
            print("📥 Downloading checkpoint from Hugging Face Model Hub...")
            self.checkpoint_path = hf_hub_download(
                repo_id="airlabshare/any4d-checkpoint",
                filename="any4d_4v_combined.pth",
                cache_dir="./checkpoints"
            )
            print(f"✓ Checkpoint downloaded to: {self.checkpoint_path}")
        else:
            self.checkpoint_path = checkpoint_path
            print(f"✓ Using local checkpoint: {self.checkpoint_path}")

        # Fixed seed for reproducible sampling/visualization.
        seed_everything(0)

    def get_config(self):
        """Get model configuration"""
        # High-level config consumed by init_inference_model(): Hydra config
        # path plus override strings, and metadata used later in the pipeline
        # (AMP flag, image normalization type).
        return {
            "path": f"{self.config_dir}/train.yaml",
            "config_overrides": [
                f"machine={self.machine}",
                "model=any4d",
                "model.encoder.uses_torch_hub=false",
                "model/task=images_only",
            ],
            "checkpoint_path": self.checkpoint_path,
            "trained_with_amp": True,
            "data_norm_type": "dinov2",
        }

    def extract_frames_from_video(self, video_path, max_frames=100):
        """Extract up to ``max_frames`` evenly-spaced RGB frames from a video.

        Returns a list of HxWx3 uint8 RGB arrays (OpenCV reads BGR; frames
        are converted before being stored).
        """
        cap = cv2.VideoCapture(video_path)
        frames = []
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Sample every `interval`-th frame so the whole clip is covered.
        interval = max(1, total_frames // max_frames)

        frame_count = 0
        while cap.isOpened() and len(frames) < max_frames:
            ret, frame = cap.read()
            if not ret:
                break

            if frame_count % interval == 0:
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(frame_rgb)

            frame_count += 1

        cap.release()
        return frames

    @spaces.GPU
    def process_video_streaming(
        self,
        recording_id: str,
        video_path: str,
        max_frames: int = 40,
        img_idx: int = 0,
        use_scene_flow_mask_refinement: bool = True,
        progress=gr.Progress()
    ):
        """
        Process video and stream results to embedded Rerun viewer.

        Generator: yields chunks of the Rerun binary stream so the Gradio
        Rerun component can render incrementally while processing continues.

        Args:
            recording_id: Unique id for this Rerun recording (new per run).
            video_path: Path of the uploaded/selected video file.
            max_frames: Maximum number of frames extracted from the video.
            img_idx: Index of the reference frame for scene-flow computation.
            use_scene_flow_mask_refinement: Apply depth-consistency refinement
                to the scene-flow mask.
            progress: Gradio progress callback.

        Raises:
            gr.Error: On invalid reference index or any processing failure.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        high_level_config = self.get_config()

        # Models are (re)loaded inside this call so they live in the
        # @spaces.GPU context.
        print("Loading models inside GPU context...")
        model = init_inference_model(
            high_level_config,
            self.checkpoint_path,
            device
        )
        moge_model = load_moge_model(model_code_path="MoGe", device=device)
        print("✓ Models loaded")

        # Fresh recording keyed by recording_id; binary_stream() is what we
        # yield to the embedded viewer.
        rec = rr.RecordingStream(
            application_id="any4d_visualization",
            recording_id=recording_id
        )
        stream = rec.binary_stream()

        try:
            progress(0, desc="Extracting frames...")

            frames = self.extract_frames_from_video(video_path, max_frames=max_frames)

            if img_idx >= len(frames):
                raise gr.Error(f"Reference frame index ({img_idx}) must be less than number of frames ({len(frames)})")

            # load_images() consumes file paths, so write the extracted
            # frames to a temp directory first (back in BGR for imwrite).
            temp_dir = tempfile.mkdtemp()
            image_paths = []
            for idx, frame in enumerate(frames):
                frame_path = os.path.join(temp_dir, f"frame_{idx:05d}.jpg")
                cv2.imwrite(frame_path, cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                image_paths.append(frame_path)

            progress(0.2, desc="Loading images...")

            start_idx = 0
            end_idx = len(image_paths)

            # View 0 is the reference frame; it is prepended, so it appears
            # twice in the list (once as reference, once in sequence order).
            image_list = [image_paths[img_idx]]
            image_list += [image_paths[idx] for idx in range(start_idx, end_idx, 1)]

            views = load_images(
                image_list,
                verbose=True,
                norm_type="dinov2",
                patch_size=14,
                compute_moge_mask=True,
                moge_model=moge_model,
                binary_mask_path=None
            )

            progress(0.4, desc="Running inference...")

            # Single multi-view forward pass over all views at once.
            pred_result = sample_inference(
                model,
                views,
                device,
                use_amp=True
            )

            progress(0.6, desc="Creating visualization...")

            # World coordinate convention for everything under "pred".
            rec.log("pred", rr.ViewCoordinates.RDF, static=True)

            # Single 3D view, white background, grid hidden, side panels
            # collapsed.
            blueprint = rrb.Blueprint(
                rrb.Spatial3DView(
                    origin="pred",
                    name="3D Scene",
                    background=[255, 255, 255],
                    line_grid=rrb.archetypes.LineGrid3D(
                        visible=False,
                    ),
                ),
                collapse_panels=True,
            )
            rec.send_blueprint(blueprint)

            # Flush blueprint + static data to the viewer immediately.
            yield stream.read()

            # Visualize each non-reference view paired with the reference
            # view (view 0), advancing the Rerun timeline per pair.
            num_views = len(views)
            for idx in range(1, num_views):
                progress_val = 0.6 + (0.4 * idx / (num_views - 1))
                progress(progress_val, desc=f"Visualizing frame {idx}/{num_views-1}")

                cur_views = [views[0], views[idx]]
                # Re-package the multi-view result as a 2-view result where
                # "2" is the current view.
                cur_pred_result = {
                    "view1": pred_result["view1"],
                    "view2": pred_result[f"view{idx+1}"],
                    "pred1": pred_result["pred1"],
                    "pred2": pred_result[f"pred{idx+1}"]
                }

                # 0.2 s per frame on the "stable_time" timeline.
                rec.set_time_seconds("stable_time", 0.2*idx)

                self.log_visualization_data(
                    rec,
                    cur_views,
                    cur_pred_result,
                    img_norm_type=high_level_config["data_norm_type"],
                    use_scene_flow_type="allo_scene_flow",
                    use_scene_flow_mask_refinement=use_scene_flow_mask_refinement,
                )

                # Stream this frame's data before processing the next one.
                yield stream.read()

            progress(1.0, desc="Complete!")

            # Final flush of any remaining buffered data.
            yield stream.read()

        except Exception as e:
            # Surface the full traceback both to the server log and to the
            # Gradio UI.
            error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            raise gr.Error(error_msg)

    def log_visualization_data(self, rec, views, pred_output, img_norm_type, use_scene_flow_type, use_scene_flow_mask_refinement=True, start_view_idx=0):
        """Log one (reference, current) view pair to the Rerun recording.

        Per view: builds an outlier mask (depth/normal edges, ambiguous
        pixels, far depth), logs the posed camera + point cloud, and for the
        second view logs the scene-flow arrows anchored at the reference
        view's points.

        Returns:
            The refined scene-flow mask (uint8, 0/255).
            NOTE(review): if no prediction carries a "scene_flow" key this
            return raises NameError (``scene_flow_mask_refined`` is only
            assigned inside that branch) — confirm callers always use a
            scene-flow model.
        """
        # State captured from the reference view (view_idx == 0) and reused
        # when processing the second view's scene flow.
        viz_mask_0 = None
        depth_z_0 = None

        for view_idx, view in enumerate(views):
            # De-normalize the network input back to a displayable RGB image.
            image = rgb(view["img"], norm_type=img_norm_type)

            pts_name = f"pred/pointcloud_view_{view_idx+start_view_idx}"
            pts_key = "pts3d"
            pred_pts3d = pred_output[f"pred{view_idx+1}"][pts_key][0].cpu()
            pred_pts3d_0 = pred_output[f"pred1"]["pts3d"][0].cpu()

            non_ambiguous_mask = view["non_ambiguous_mask"].cpu()

            # Edge masks: drop pixels lying on both a normal discontinuity
            # and a depth discontinuity (typically flying points).
            normals, normals_mask = points_to_normals(pred_pts3d.numpy(), mask=non_ambiguous_mask.numpy())
            normal_edges = normals_edge(normals, tol=5, mask=normals_mask)

            depth_z = pred_output[f"pred{view_idx+1}"]["pts3d_cam"][...,2:3][0].squeeze(-1).cpu().numpy()
            depth_edges = depth_edge(depth_z, rtol=0.03, mask=non_ambiguous_mask.numpy())

            mask = ~(depth_edges & normal_edges)

            mask = non_ambiguous_mask.numpy() & mask

            # Cap the visualized depth range at 40 units to avoid far noise.
            close_depth_mask = depth_z < 40.0
            mask = mask & close_depth_mask

            if view_idx == 0:
                viz_mask_0 = mask

                # Erode the reference mask to pull scene-flow arrows away
                # from object boundaries.
                kernel_size = 3
                kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
                viz_mask_0 = viz_mask_0.astype(np.uint8) * 255
                viz_mask_0 = cv2.erode(viz_mask_0, kernel, iterations=3)

                depth_z_0 = depth_z.copy()

            if "cam_quats" not in pred_output[f"pred{view_idx+1}"].keys():
                # No predicted camera: log the image and point cloud without
                # a camera transform.
                base_name = f"pred/image_view_{view_idx+start_view_idx}"
                rec.log(
                    f"{base_name}/pinhole/rgb",
                    rr.Image(image),
                )

                log_points_to_rerun(image[0], pred_pts3d, pts_name, mask=mask)

                if mask is not None:
                    # NOTE(review): `mask` is a numpy array at this point, so
                    # `.numpy()` would raise AttributeError — this branch
                    # looks unexercised for current models; confirm.
                    rec.log(
                        f"{base_name}/pinhole/mask",
                        rr.SegmentationImage(mask.numpy().astype(int)),
                    )
            else:
                # Predicted camera available: reconstruct pose/intrinsics and
                # log a fully-posed camera plus point cloud.
                base_name = f"pred/image_view_{view_idx+start_view_idx}"
                cam_quats = pred_output[f"pred{view_idx+1}"]["cam_quats"][0].cpu()
                cam_trans = pred_output[f"pred{view_idx+1}"]["cam_trans"][0].cpu()
                ray_directions = pred_output[f"pred{view_idx+1}"]["ray_directions"][0].cpu()
                ray_depth = pred_output[f"pred{view_idx+1}"]["depth_along_ray"][0].cpu()
                local_pts3d = ray_directions * ray_depth
                depth_z = local_pts3d[..., 2:]
                if mask is not None:
                    # Zero out depth at masked-out (ambiguous/outlier) pixels.
                    ambiguous_mask = ~mask
                    depth_z[ambiguous_mask] = 0
                    viz_mask = mask
                else:
                    viz_mask = None

                # Restrict the reference view's visualization to predicted
                # moving regions when a motion mask is available.
                if "motion_mask" in pred_output[f"pred{view_idx+1}"].keys() and view_idx == 0:
                    pred_motion_mask = pred_output[f"pred{view_idx+1}"]["motion_mask"][0].cpu().numpy()
                    viz_mask = viz_mask * pred_motion_mask

                cam_rot = quaternion_to_rotation_matrix(cam_quats)
                cam_pose = torch.eye(4)
                cam_pose[:3, :3] = cam_rot
                cam_pose[:3, 3] = cam_trans
                cam_intrinsics = recover_pinhole_intrinsics_from_ray_directions(ray_directions)
                self.log_data_to_rerun(
                    rec,
                    image[0],
                    depth_z,
                    cam_pose,
                    cam_intrinsics,
                    pred_pts3d,
                    mask,
                    base_name,
                    pts_name,
                    viz_mask=viz_mask,
                )

            # Reference-view camera pose (computed each iteration; currently
            # not consumed below).
            cam_quats_0 = pred_output[f"pred{1}"]["cam_quats"][0].cpu()
            cam_trans_0 = pred_output[f"pred{1}"]["cam_trans"][0].cpu()
            cam_rot_0 = quaternion_to_rotation_matrix(cam_quats_0)
            cam_pose_0 = torch.eye(4)
            cam_pose_0[:3, :3] = cam_rot_0
            cam_pose_0[:3, 3] = cam_trans_0

            if "scene_flow" in pred_output[f"pred{view_idx+1}"].keys():
                if use_scene_flow_type == "allo_scene_flow":
                    pred_scene_flow_vectors = pred_output[f"pred{view_idx+1}"]["scene_flow"][0].cpu()

                    # Keep pixels that move (>0.1 flow magnitude), survived
                    # the eroded reference mask, and lie in the binary mask.
                    motion_sf_mask = pred_scene_flow_vectors.norm(dim=-1) > 1e-1
                    scene_flow_mask = viz_mask_0 & motion_sf_mask.numpy() & views[0]["binary_mask"].cpu().numpy()

                    # Depth-consistency refinement: keep only pixels whose
                    # reference depth lies within ±0.5 std of the median
                    # depth of the moving region (suppresses background
                    # smear near boundaries).
                    mask = (scene_flow_mask > 0)
                    depth_values = depth_z_0[mask]
                    if len(depth_values) > 0:
                        depth_mean = np.median(depth_values)
                        depth_std = np.std(depth_values)
                    else:
                        depth_mean, depth_std = 0, 0
                    depth_tolerance = 0.5 * depth_std
                    depth_min = depth_mean - depth_tolerance
                    depth_max = depth_mean + depth_tolerance
                    depth_consistent_mask = np.logical_and(
                        mask,
                        np.logical_and(depth_z_0 >= depth_min, depth_z_0 <= depth_max)
                    )
                    scene_flow_mask_refined = depth_consistent_mask.astype(np.uint8) * 255

                    if view_idx == 1:
                        # Arrows originate at the reference view's 3D points.
                        mask_to_use = scene_flow_mask_refined if use_scene_flow_mask_refinement else viz_mask_0
                        self.log_scene_flow_to_rerun(rec, image[0], pred_pts3d_0, pred_scene_flow_vectors, f"pred/scene_flow_{view_idx+start_view_idx}", mask=mask_to_use)

        return scene_flow_mask_refined

    def log_data_to_rerun(self, rec, image, depthmap, pose, intrinsics, pts3d, mask, base_name, pts_name, viz_mask=None):
        """Log a posed pinhole camera, its RGB image, and the masked point
        cloud to ``rec``.

        Unlike the module-level helper of the same name, this variant logs to
        an explicit RecordingStream and also logs the point cloud.
        ``depthmap`` and ``viz_mask`` are currently unused.
        """
        height, width = image.shape[0], image.shape[1]
        # Camera-to-world transform (from_parent=False).
        rec.log(
            base_name,
            rr.Transform3D(
                translation=pose[:3, 3],
                mat3x3=pose[:3, :3],
                from_parent=False,
            ),
        )
        # Pinhole projection; RDF matches the OpenCV camera convention.
        rec.log(
            f"{base_name}/pinhole",
            rr.Pinhole(
                image_from_camera=intrinsics,
                height=height,
                width=width,
                camera_xyz=rr.ViewCoordinates.RDF,
            ),
        )
        rec.log(
            f"{base_name}/pinhole/rgb",
            rr.Image(image),
        )

        # Point cloud colored by the image, restricted to `mask`.
        filtered_pts = pts3d[mask]
        filtered_pts_col = image[mask]
        rec.log(
            pts_name,
            rr.Points3D(
                positions=filtered_pts.reshape(-1, 3),
                colors=filtered_pts_col.reshape(-1, 3),
            ),
        )

    def log_scene_flow_to_rerun(self, rec, image, pts3d, scene_flow_vecs, base_name, mask=None, scene_bounds=None, colormap_name='rainbow'):
        """
        Log scene flow in 3D with color visualization based on flow magnitude and direction

        At most 10k arrows are logged; when more points survive the mask,
        arrows are subsampled with probability biased toward larger flow.
        Colors encode direction (hue, XZ-plane angle) and magnitude
        (saturation/value). ``colormap_name`` is currently unused.

        Returns:
            The (mag_min, mag_max) magnitude bounds used for color scaling,
            or the incoming ``scene_bounds`` when no points are visualized.
        """
        if mask is None:
            filtered_pts = pts3d
            filtered_scene_flow_vecs = scene_flow_vecs
            filtered_pts_col = image
        else:
            filtered_pts = pts3d[mask]
            filtered_scene_flow_vecs = scene_flow_vecs[mask]
            filtered_pts_col = image[mask]

        if filtered_pts.numel() == 0:
            print(f"Warning: No valid points found for {base_name}. Skipping visualization.")
            return scene_bounds

        filtered_pts = filtered_pts.reshape(-1, 3)
        filtered_scene_flow_vecs = filtered_scene_flow_vecs.reshape(-1, 3)

        # Subsample to at most `max_arrows` arrows to keep the viewer fast.
        max_arrows = 10000
        if filtered_pts.shape[0] > max_arrows:
            flow_magnitudes = torch.norm(filtered_scene_flow_vecs, dim=1)

            if flow_magnitudes.max() > 1e-6:
                # Sampling probability: 0.2 floor + 0.8 proportional to
                # relative flow magnitude, so fast-moving points are favored.
                probabilities = 0.2 + 0.8 * (flow_magnitudes / (flow_magnitudes.max() + 1e-6))
                probabilities_np = probabilities.cpu().numpy()
                probabilities_np = probabilities_np / probabilities_np.sum()

                indices_np = np.random.choice(
                    filtered_pts.shape[0],
                    size=max_arrows,
                    replace=False,
                    p=probabilities_np
                )
                indices = torch.tensor(indices_np, device=filtered_pts.device)
            else:
                # All flow ~zero: plain uniform subsample.
                indices = torch.randperm(filtered_pts.shape[0], device=filtered_pts.device)[:max_arrows]

            sampled_pts = filtered_pts[indices]
            sampled_vectors = filtered_scene_flow_vecs[indices]
        else:
            sampled_pts = filtered_pts
            sampled_vectors = filtered_scene_flow_vecs

        sampled_pts_np = sampled_pts.cpu().numpy()
        sampled_vectors_np = sampled_vectors.cpu().numpy()

        flow_magnitudes = np.linalg.norm(sampled_vectors_np, axis=1)

        # Magnitude bounds for color normalization; reuse caller-supplied
        # bounds when given so colors stay consistent across frames.
        if scene_bounds is None:
            if len(flow_magnitudes) == 0:
                return (0, 1)
            mag_min = flow_magnitudes.min()
            mag_max = flow_magnitudes.max()
            if mag_min == mag_max:
                mag_max = mag_min + 1e-6
        else:
            mag_min, mag_max = scene_bounds

        if len(flow_magnitudes) > 0 and mag_max > mag_min:
            normalized_flow = sampled_vectors_np / (flow_magnitudes[:, np.newaxis] + 1e-8)

            # Hue from the flow direction projected into the XZ plane,
            # mapped from [-pi, pi] to [0, 1].
            hue = np.arctan2(normalized_flow[:, 2], normalized_flow[:, 0])
            hue = (hue + np.pi) / (2 * np.pi)

            # Larger flow -> more saturated and brighter arrows.
            normalized_magnitude = np.clip((flow_magnitudes - mag_min) / (mag_max - mag_min + 1e-8), 0, 1)
            saturation = 0.3 + 0.7 * normalized_magnitude
            value = 0.5 + 0.5 * normalized_magnitude

            hsv = np.stack([hue, saturation, value], axis=1)
            colors = hsv_to_rgb(hsv)
        else:
            # Degenerate magnitude range: uniform gray arrows.
            colors = np.ones((len(sampled_vectors_np), 3)) * 0.5

        rec.log(
            f"{base_name}/scene_flow",
            rr.Arrows3D(
                origins=sampled_pts_np,
                vectors=sampled_vectors_np,
                colors=colors,
            ),
        )

        return (mag_min, mag_max)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_app(config_dir, checkpoint_path):
    """Create Gradio app with embedded Rerun viewer.

    Builds the full UI: video input + parameter controls, a streaming Rerun
    viewer, clickable example-video thumbnails, and all event wiring.

    Args:
        config_dir: Directory containing the Hydra "train.yaml" config.
        checkpoint_path: Optional local checkpoint path (None -> auto-download).

    Returns:
        The assembled gr.Blocks demo (not yet launched).
    """
    # Single processor instance shared by every request/session.
    processor = Any4DProcessor(config_dir, checkpoint_path)

    with gr.Blocks(
        title="Any4D Scene Flow Visualization",
        css="""
        .gradio-container {
            max-height: 100vh;
            overflow-y: auto !important;
        }
        footer {
            display: none !important;
        }
        /* Style clickable thumbnails */
        .clickable-thumbnail {
            cursor: pointer;
            transition: transform 0.2s ease;
        }
        .clickable-thumbnail:hover {
            transform: scale(1.05);
        }
        .scene-info {
            text-align: center;
            margin-top: 8px;
        }
        """
    ) as demo:

        gr.Markdown("""
        # 🎬 [Any4D](https://any-4d.github.io/) Rerun Scene Flow Visualization

        Upload a video or choose an example video from the thumbnails below to visualize 3D reconstruction and scene flow in real-time.

        Getting Started:
        1. Upload a video using the input panel on the left.
        2. Adjust the parameters as needed:
        - **Max Frames to Extract**: Number of frames to extract from the video for processing.
        - **Reference Frame Index**: Select which frame to use as the reference for scene flow computation.
        - **Use Scene Flow Mask Refinement**: Enable to reduce background smearing in scene flow visualization.
        3. Click "Process Video" to start the visualization.
        4. Explore the live 3D visualization on the right panel.

        Note: Inference time for Any4D for 100 frames is approximately less than 1.5 seconds. However, downloading model weights and visualizing 3D points may take tens of seconds.
        Please be patient or, for faster visualization, use a local machine to run our demo from our [GitHub repository](https://github.com/Any-4D/Any4D).
        """)

        with gr.Row():
            # Left column: inputs and processing parameters.
            with gr.Column(scale=1):
                gr.Markdown("### 📥 Input")

                video_input = gr.Video(
                    label="Upload Video",
                    sources=["upload"]
                )

                with gr.Row():
                    max_frames_slider = gr.Slider(
                        minimum=10,
                        maximum=200,
                        value=40,
                        step=1,
                        label="Max Frames to Extract",
                        info="Number of frames to extract from video"
                    )

                with gr.Row():
                    frame_preview = gr.Image(
                        label="Reference Frame Preview",
                        type="numpy",
                        interactive=False
                    )

                with gr.Row():
                    # Maximum is re-bound after frame extraction.
                    reference_frame_slider = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=0,
                        step=1,
                        label="Reference Frame Index",
                        info="Select which frame to use as reference for scene flow"
                    )

                with gr.Row():
                    use_mask_refinement = gr.Checkbox(
                        value=False,
                        label="Use Scene Flow Mask Refinement",
                        info=" Removes background scene-flow smearing near boundaries by assuming dominant motion in scene. Do not use for scenes with multiple dynamic objects."
                    )

                process_btn = gr.Button("🚀 Process Video", variant="primary", size="lg")

            # Right column: live streaming Rerun viewer + status line.
            with gr.Column(scale=3):
                gr.Markdown("### 🎥 Live 3D Visualization")

                viewer = Rerun(
                    streaming=True,
                    height=700,
                    panel_states={
                        "time": "collapsed",
                        "blueprint": "collapsed",
                        "selection": "collapsed",
                    },
                )

                status_text = gr.Textbox(
                    label="Status",
                    lines=2,
                    interactive=False
                )

        # Session state: extracted frames, a flag suppressing the
        # video-change handler while an example loads, and the example's
        # preset reference frame.
        extracted_frames = gr.State([])
        loading_example = gr.State(False)
        preset_ref_state = gr.State(None)

        gr.Markdown("---")
        gr.Markdown("## 📂 Example Videos")
        gr.Markdown("Click any thumbnail to load the example video with preset configurations")

        # Each example bundles its video path, thumbnail, and the parameter
        # presets applied when the thumbnail is clicked.
        example_videos = [
            {"path": "examples/layup_fpv.mp4", "max_frames": 100, "ref_frame": 50, "mask_refinement": True, "name": "Basketball Layup", "thumbnail": "examples/layup_fpv.jpg"},
            {"path": "examples/cat_jump.mp4", "max_frames": 40, "ref_frame": 0, "mask_refinement": True, "name": "Cat Jump", "thumbnail": "examples/cat_jump.jpg"},
            {"path": "examples/climbing.mp4", "max_frames": 50, "ref_frame": 0, "mask_refinement": True, "name": "Climbing", "thumbnail": "examples/climbing.jpg"},
            {"path": "examples/lady-running.mp4", "max_frames": 30, "ref_frame": 15, "mask_refinement": True, "name": "Lady Running", "thumbnail": "examples/lady-running.jpg"},
            {"path": "examples/mobile_robot.mp4", "max_frames": 40, "ref_frame": 25, "mask_refinement": True, "name": "Mobile Robot", "thumbnail": "examples/mobile_robot.jpg"},
            {"path": "examples/quadruped_climbing.mp4", "max_frames": 90, "ref_frame": 45, "mask_refinement": True, "name": "Quadruped Climbing", "thumbnail": "examples/quadruped_climbing.jpg"},
            {"path": "examples/seagull.mp4", "max_frames": 50, "ref_frame": 25, "mask_refinement": False, "name": "Seagull", "thumbnail": "examples/seagull.jpg"},
            {"path": "examples/Humanoid_Walking.mp4", "max_frames": 50, "ref_frame": 25, "mask_refinement": True, "name": "Humanoid Walking", "thumbnail": "examples/Humanoid_Walking.jpg"},
            {"path": "examples/car.mp4", "max_frames": 90, "ref_frame": 0, "mask_refinement": True, "name": "Car", "thumbnail": "examples/car.jpg"},
            {"path": "examples/Bigfoot.mp4", "max_frames": 60, "ref_frame": 30, "mask_refinement": True, "name": "Bigfoot", "thumbnail": "examples/Bigfoot.jpg"},
            {"path": "examples/humanoid_football.mp4", "max_frames": 22, "ref_frame": 10, "mask_refinement": True, "name": "humanoid_football", "thumbnail": "examples/humanoid_football.png"},
            {"path": "examples/motorbike_desert.mp4", "max_frames": 50, "ref_frame": 25, "mask_refinement": True, "name": "Motorbike Desert", "thumbnail": "examples/motorbike_desert.jpg"},
        ]

        def load_example_video(video_path, max_frames, ref_frame, mask_refinement):
            """Load an example video with its configuration and extract frames"""
            print(f"[Example] Loading {video_path} with ref_frame={ref_frame}, max_frames={max_frames}")

            # extract_and_preview is defined later in this function; by the
            # time this callback fires it is bound in the enclosing scope.
            frame_preview_img, ref_slider_update, frames = extract_and_preview(video_path, max_frames, preset_ref_frame=ref_frame)
            print(f"[Example] Extracted {len(frames)} frames, slider update value: {ref_slider_update}")

            # Third output (True) sets loading_example so the subsequent
            # video_input.change event skips re-extraction.
            return max_frames, mask_refinement, True, video_path, ref_slider_update, frame_preview_img, frames, ref_frame

        # Lay out example thumbnails in rows of 4.
        for i in range(0, len(example_videos), 4):
            with gr.Row():
                for j in range(4):
                    video_idx = i + j
                    if video_idx < len(example_videos):
                        video = example_videos[video_idx]
                        with gr.Column(scale=1, elem_classes=["clickable-thumbnail"]):

                            video_thumb = gr.Image(
                                value=video.get("thumbnail", video["path"]),
                                label="",
                                height=150,
                                interactive=False,
                                show_label=False,
                                elem_id=f"video_thumb_{video['name'].replace(' ', '_')}",
                                sources=[],
                            )

                            gr.Markdown(
                                f"**{video['name']}**",
                                elem_classes=["scene-info"],
                            )

                            # Default-arg lambda binds this video's config
                            # (avoids the late-binding closure pitfall).
                            video_thumb.select(
                                fn=lambda p=video["path"], m=video["max_frames"], r=video["ref_frame"], mr=video["mask_refinement"]:
                                    load_example_video(p, m, r, mr),
                                outputs=[max_frames_slider, use_mask_refinement, loading_example, video_input, reference_frame_slider, frame_preview, extracted_frames, preset_ref_state],
                            )
                    else:
                        # Empty filler column keeps the row width consistent.
                        with gr.Column(scale=1):
                            pass

        def extract_and_preview(video_file, max_frames, preset_ref_frame=None, progress=gr.Progress()):
            # Extract evenly-spaced frames and pick the preview/reference
            # frame; returns (preview image, slider update, frame list).
            if video_file is None:
                return None, gr.update(maximum=0, value=0), []

            try:
                progress(0, desc="Extracting frames...")

                cap = cv2.VideoCapture(video_file)
                frames = []
                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                interval = max(1, total_frames // max_frames)

                frame_count = 0
                while cap.isOpened() and len(frames) < max_frames:
                    ret, frame = cap.read()
                    if not ret:
                        break

                    if frame_count % interval == 0:
                        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        frames.append(frame_rgb)

                    frame_count += 1

                cap.release()

                num_frames = len(frames)

                # Prefer the example's preset reference frame when valid;
                # otherwise default to the middle frame.
                if preset_ref_frame is not None and preset_ref_frame < num_frames:
                    default_idx = preset_ref_frame
                else:
                    default_idx = num_frames // 2 if num_frames > 0 else 0

                return (
                    frames[default_idx] if num_frames > 0 else None,
                    gr.update(maximum=max(0, num_frames - 1), value=default_idx),
                    frames
                )

            except Exception as e:
                # Best-effort: log and reset the UI rather than crash it.
                print(f"Error extracting frames: {e}")
                return None, gr.update(maximum=0, value=0), []

        def update_frame_preview(frames, frame_idx):
            # Show the selected frame; None when the index is out of range.
            if frames is None or len(frames) == 0 or frame_idx >= len(frames):
                return None
            return frames[frame_idx]

        def process_video_wrapper(video_file, max_frames, ref_frame_idx, use_mask_ref, progress=gr.Progress()):
            # Generator bridging the processor's Rerun byte stream to the
            # (viewer, status_text) outputs.
            if video_file is None:
                raise gr.Error("Please upload a video first")

            # Fresh recording id per run so the viewer starts clean.
            new_rec_id = str(uuid.uuid4())

            try:
                for stream_data in processor.process_video_streaming(
                    new_rec_id,
                    video_file,
                    max_frames,
                    ref_frame_idx,
                    use_mask_ref,
                    progress
                ):
                    yield stream_data, "🔄 Processing... Please wait."

                mask_status = "with refinement" if use_mask_ref else "without refinement"
                yield stream_data, f"✅ Processing complete! Reference frame: {ref_frame_idx}, Scene flow mask {mask_status}. Explore the 3D visualization above."

            except Exception as e:
                yield None, f"❌ Error: {str(e)}"

        def handle_video_input_change(video_file, max_frames, is_loading_example):
            """Handle video input changes, skip if loading from example"""
            print(f"[VideoChange] Triggered with is_loading_example={is_loading_example}, video={video_file}")
            if is_loading_example:
                # Example click already extracted frames; just clear the flag.
                print("[VideoChange] Skipping extraction (loading from example)")
                return gr.update(), gr.update(), gr.update(), False

            if video_file is None:
                return None, gr.update(maximum=0, value=0), [], False
            print("[VideoChange] Extracting frames for user upload")
            frame_preview_img, ref_slider_update, frames = extract_and_preview(video_file, max_frames)
            return frame_preview_img, ref_slider_update, frames, False

        video_input.change(
            handle_video_input_change,
            inputs=[video_input, max_frames_slider, loading_example],
            outputs=[frame_preview, reference_frame_slider, extracted_frames, loading_example]
        )

        def handle_max_frames_change(video_file, max_frames, preset_ref):
            """Handle max frames change, respecting preset reference frame"""
            frame_preview_img, ref_slider_update, frames = extract_and_preview(video_file, max_frames, preset_ref_frame=preset_ref)

            # Last output clears preset_ref_state (one-shot preset).
            return frame_preview_img, ref_slider_update, frames, None

        max_frames_slider.change(
            handle_max_frames_change,
            inputs=[video_input, max_frames_slider, preset_ref_state],
            outputs=[frame_preview, reference_frame_slider, extracted_frames, preset_ref_state]
        )

        reference_frame_slider.change(
            update_frame_preview,
            inputs=[extracted_frames, reference_frame_slider],
            outputs=[frame_preview]
        )

        process_btn.click(
            process_video_wrapper,
            inputs=[video_input, max_frames_slider, reference_frame_slider, use_mask_refinement],
            outputs=[viewer, status_text]
        )

    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Directory containing the Hydra config ("configs/train.yaml").
    CONFIG_DIR = "configs"
    # None -> the checkpoint is downloaded from the Hugging Face Hub on startup.
    CHECKPOINT_PATH = None

    app = create_app(CONFIG_DIR, CHECKPOINT_PATH)
    # Bind on all interfaces (containerized deployment) with a public share
    # link. NOTE(review): ssr_mode=False presumably avoids server-side
    # rendering issues with the embedded Rerun component — confirm.
    app.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        ssr_mode=False
    )