# Last change: commit 681f346 by BonanDing - "Reproduce Training & Fix distributed eval"
import os
import random
import math
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode
from PIL import Image
from packaging import version as pver
from einops import rearrange
from tqdm import tqdm
from omegaconf import DictConfig
from lightning.pytorch.utilities.types import STEP_OUTPUT
from algorithms.common.metrics import (
LearnedPerceptualImagePatchSimilarity,
)
from utils.logging_utils import log_video, get_validation_metrics_for_videos
from .df_base import DiffusionForcingBase
from .models.vae import VAE_models
from .models.diffusion import Diffusion
from .models.pose_prediction import PosePredictionNet
import glob
import wandb
# Utility Functions
def euler_to_rotation_matrix(pitch, yaw):
    """
    Build batched 3x3 rotation matrices R = R_yaw @ R_pitch.

    R_pitch rotates about the x-axis and R_yaw about the y-axis.

    Args:
        pitch (torch.Tensor): Pitch angles in radians.
        yaw (torch.Tensor): Yaw angles in radians (same shape as ``pitch``).
    Returns:
        torch.Tensor: Rotation matrices of shape (N, 3, 3), where N is the number
        of elements in ``pitch``.
    """
    p_zero, p_one = torch.zeros_like(pitch), torch.ones_like(pitch)
    y_zero, y_one = torch.zeros_like(yaw), torch.ones_like(yaw)
    cp, sp = torch.cos(pitch), torch.sin(pitch)
    cy, sy = torch.cos(yaw), torch.sin(yaw)

    # Row-major entries of the x-axis rotation.
    pitch_rows = [p_one, p_zero, p_zero,
                  p_zero, cp, -sp,
                  p_zero, sp, cp]
    # Row-major entries of the y-axis rotation.
    yaw_rows = [cy, y_zero, sy,
                y_zero, y_one, y_zero,
                -sy, y_zero, cy]

    rot_pitch = torch.stack(pitch_rows, dim=-1).reshape(-1, 3, 3)
    rot_yaw = torch.stack(yaw_rows, dim=-1).reshape(-1, 3, 3)
    return torch.matmul(rot_yaw, rot_pitch)
def euler_to_camera_to_world_matrix(pose):
    """
    Convert (x, y, z, pitch, yaw) poses to 4x4 camera-to-world matrices.

    Pitch and yaw are given in degrees. They are converted to radians and fed to
    ``euler_to_rotation_matrix``; (x, y, z) becomes the translation column.

    Args:
        pose (torch.Tensor): Pose tensor of shape (..., 5), e.g. (5,), (b, 5) or
            (f, b, 5). Only the first five entries of the last dim are read.
    Returns:
        torch.Tensor: float32 matrices of shape (..., 4, 4), i.e. the input's
        leading shape followed by 4x4 ((4, 4) for a single (5,) pose).
    """
    # Remember the caller's leading shape and restore it at the end. The previous
    # code unsqueezed 2-D input to (1, b, 5) and then viewed the result as
    # (1, 4, 4), which failed for any b > 1.
    leading_shape = pose.shape[:-1]
    x, y, z, pitch, yaw = (pose[..., k] for k in range(5))

    # (N, 3, 3), with N = number of poses.
    rotation = euler_to_rotation_matrix(torch.deg2rad(pitch), torch.deg2rad(yaw))

    camera_to_world = torch.eye(4, dtype=torch.float32, device=pose.device).repeat(rotation.shape[0], 1, 1)
    camera_to_world[:, :3, :3] = rotation
    camera_to_world[:, :3, 3] = torch.stack([x.reshape(-1), y.reshape(-1), z.reshape(-1)], dim=-1)
    return camera_to_world.reshape(*leading_shape, 4, 4)
def is_inside_fov_3d_hv(points, center, center_pitch, center_yaw, fov_half_h, fov_half_v):
    """
    Test which points fall inside a field of view with separate horizontal/vertical extents.

    Azimuth is measured with ``atan2(x, z)`` (z forward, x sideways) and elevation
    with ``atan2(y, sqrt(x^2 + z^2))``, both in degrees. Angular differences to the
    view centre are wrapped into [0, 180].

    :param points: Sample point coordinates, shape (N, B, 3).
    :param center: FOV origin; broadcast against ``points`` (e.g. (3,) or (B, 3)).
    :param center_pitch: Pitch of the view centre in degrees (scalar or (B,)).
    :param center_yaw: Yaw of the view centre in degrees (scalar or (B,)).
    :param fov_half_h: Horizontal half-FOV in degrees.
    :param fov_half_v: Vertical half-FOV in degrees.
    :return: Boolean tensor of shape (N, B); True where the point is strictly inside.
    """
    rad_to_deg = 180 / math.pi
    offset = points - center
    dx, dy, dz = offset[..., 0], offset[..., 1], offset[..., 2]

    azimuth = torch.atan2(dx, dz) * rad_to_deg
    elevation = torch.atan2(dy, torch.sqrt(dx**2 + dz**2)) * rad_to_deg

    def _wrapped_gap(angle, reference):
        # Shortest angular distance on the circle, in [0, 180].
        gap = (angle - reference).abs() % 360
        return torch.where(gap > 180, 360 - gap, gap)

    inside_h = _wrapped_gap(azimuth, center_yaw) < fov_half_h
    inside_v = _wrapped_gap(elevation, center_pitch) < fov_half_v
    return inside_h & inside_v
def generate_points_in_sphere(n_points, radius):
    """
    Sample ``n_points`` points uniformly from the ball of the given radius around the origin.

    The radius is drawn as ``radius * u ** (1/3)``, which gives a uniform density in
    volume. The polar angle is drawn so that cos(theta) is uniform, and the azimuth
    is uniform in [0, 2*pi).

    Args:
        n_points (int): Number of points.
        radius (float): Ball radius.
    Returns:
        torch.Tensor: Cartesian coordinates of shape (n_points, 3).
    """
    # Three independent uniform draws, in the order radius, azimuth, polar.
    u_radius, u_phi, u_theta = (torch.rand(n_points) for _ in range(3))

    r = radius * u_radius.pow(1 / 3)
    phi = 2 * math.pi * u_phi
    theta = torch.acos(1 - 2 * u_theta)

    sin_theta = torch.sin(theta)
    cartesian = (
        r * sin_theta * torch.cos(phi),
        r * sin_theta * torch.sin(phi),
        r * torch.cos(theta),
    )
    return torch.stack(cartesian, dim=1)
def tensor_max_with_number(tensor, number):
    """
    Return the element-wise maximum of ``tensor`` and the scalar ``number``.

    ``number`` is first converted to a 0-dim tensor with ``tensor``'s dtype and device.
    """
    return torch.max(tensor, tensor.new_tensor(number))
def custom_meshgrid(*args):
    """
    Call ``torch.meshgrid`` with 'ij' indexing on every supported torch version.

    The ``indexing=`` keyword only exists from torch 1.10 onwards. Older versions
    always used 'ij', so they are called without it.
    ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    """
    has_indexing_kw = pver.parse(torch.__version__) >= pver.parse('1.10')
    if has_indexing_kw:
        return torch.meshgrid(*args, indexing='ij')
    return torch.meshgrid(*args)
def camera_to_world_to_world_to_camera(camera_to_world: torch.Tensor) -> torch.Tensor:
    """
    Invert rigid camera-to-world matrices into world-to-camera matrices.

    Uses the closed form for rigid transforms: rotation R^T and translation -R^T @ T.

    Args:
        camera_to_world (torch.Tensor): A tensor of shape (f, b, 4, 4), where
            f = number of frames and b = batch size.
    Returns:
        torch.Tensor: A tensor of shape (f, b, 4, 4) with the same dtype and device as
        the input, holding the world-to-camera matrices.
    """
    assert camera_to_world.ndim == 4 and camera_to_world.shape[2:] == (4, 4), \
        "Input must be of shape (f, b, 4, 4)"

    R = camera_to_world[..., :3, :3]  # (f, b, 3, 3)
    T = camera_to_world[..., :3, 3]   # (f, b, 3)
    R_inv = R.transpose(-1, -2)

    # Build the output directly in the input dtype. The previous float32 scratch
    # matrix silently rounded float64 poses before casting back.
    world_to_camera = torch.eye(4, device=camera_to_world.device, dtype=camera_to_world.dtype)
    world_to_camera = world_to_camera.repeat(camera_to_world.size(0), camera_to_world.size(1), 1, 1)
    world_to_camera[..., :3, :3] = R_inv
    world_to_camera[..., :3, 3] = -torch.matmul(R_inv, T.unsqueeze(-1)).squeeze(-1)
    return world_to_camera
def convert_to_plucker(poses, curr_frame, focal_length, image_width, image_height):
    """
    Compute Plücker ray embeddings for camera poses, relative to frame ``curr_frame``.

    Intrinsics are fx = focal_length * image_width and fy = focal_length * image_height,
    with the principal point at the image centre.

    Args:
        poses: Camera-to-world matrices of shape (t, b, 4, 4).
        curr_frame (int): Index (dim 0) of the reference frame for the relative poses.
        focal_length (float): Focal length as a fraction of the image size.
        image_width (int): Width in pixels.
        image_height (int): Height in pixels.
    Returns:
        torch.Tensor: Embedding of shape (t, b, image_height, image_width, 6).
    """
    fx = focal_length * image_width
    fy = focal_length * image_height
    intrinsic = np.asarray([fx, fy, 0.5 * image_width, 0.5 * image_height], dtype=np.float32)

    relative_c2w = rearrange(
        get_relative_pose(poses, zero_first_frame_scale=curr_frame), "t b m n -> b t m n"
    )
    n_batch, n_views = relative_c2w.shape[0], relative_c2w.shape[1]
    K = torch.as_tensor(intrinsic, device=poses.device, dtype=poses.dtype).repeat(n_batch, n_views, 1)  # [B, F, 4]

    rays = ray_condition(K, relative_c2w, image_height, image_width, device=relative_c2w.device)
    return rearrange(rays, "b t h w d -> t b h w d").contiguous()
def get_relative_pose(abs_c2ws, zero_first_frame_scale):
    """
    Re-express camera-to-world poses relative to a reference frame.

    Every pose is left-multiplied by the world-to-camera matrix of frame
    ``zero_first_frame_scale``, so that frame becomes (numerically) the identity.

    Args:
        abs_c2ws (torch.Tensor): Absolute camera-to-world matrices, shape (t, b, 4, 4).
        zero_first_frame_scale (int): Index along dim 0 of the reference frame.
    Returns:
        torch.Tensor: Relative poses of shape (t, b, 4, 4).
    """
    abs_w2cs = camera_to_world_to_world_to_camera(abs_c2ws)
    # The target camera pose is the identity, so the absolute->relative transform is
    # just the reference world-to-camera matrix. Broadcasting (b, 4, 4) against
    # (t, b, 4, 4) replaces the former per-frame Python loop and identity matmul.
    abs2rel = abs_w2cs[zero_first_frame_scale]
    return abs2rel @ abs_c2ws
def ray_condition(K, c2w, H, W, device):
    """
    Compute per-pixel Plücker embeddings (o x d, d) for a batch of cameras.

    Args:
        K: Intrinsics of shape (B, V, 4), ordered (fx, fy, cx, cy).
        c2w: Camera-to-world matrices of shape (B, V, 4, 4).
        H, W: Image height and width in pixels.
        device: Device on which the pixel grid is built.
    Returns:
        torch.Tensor: Shape (B, V, H, W, 6). The first three channels are
        cross(ray_origin, ray_direction); the last three are the unit ray direction
        in world space.
    """
    batch = K.shape[0]
    dtype = c2w.dtype
    rows, cols = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=dtype),
    )
    # Pixel-centre coordinates, flattened row-major: (B, 1, H*W).
    u = cols.reshape([1, 1, H * W]).expand([batch, 1, H * W]) + 0.5
    v = rows.reshape([1, 1, H * W]).expand([batch, 1, H * W]) + 0.5

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # each (B, V, 1)
    depth = torch.ones_like(u, device=device, dtype=dtype)  # (B, 1, H*W)
    dir_x = -(u - cx) / fx * depth  # (B, V, H*W)
    dir_y = -(v - cy) / fy * depth
    directions = torch.stack((dir_x, dir_y, depth.expand_as(dir_y)), dim=-1)  # (B, V, H*W, 3)
    directions = directions / directions.norm(dim=-1, keepdim=True)

    # Rotate into world space and pair each direction with its camera centre.
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # (B, V, H*W, 3)
    rays_o = c2w[..., :3, 3][:, :, None].expand_as(rays_d)    # (B, V, H*W, 3)

    plucker = torch.cat([torch.linalg.cross(rays_o, rays_d), rays_d], dim=-1)
    return plucker.reshape(batch, c2w.shape[1], H, W, 6)
def random_transform(tensor):
    """
    Apply one random affine transform (translate, rotate, scale) to every frame.

    A single set of parameters is drawn from Python's ``random`` module, in the order
    translate-x, translate-y, angle, scale, and shared by all frames and batch entries.

    Args:
        tensor (torch.Tensor): Input tensor of shape (F, B, C, H, W).
    Returns:
        torch.Tensor: Transformed tensor of the same shape.
    Raises:
        ValueError: If ``tensor`` is not 5-dimensional.
    """
    if tensor.ndim != 5:
        raise ValueError("Input tensor must have shape (F, B, 3, H, W)")
    # Named n_frames (not F) so the module-level torch.nn.functional alias is not shadowed.
    n_frames, batch, channels, height, width = tensor.shape

    max_translate = 0.2  # fraction of width/height
    max_rotate = 30      # degrees
    max_scale = 0.2      # +/- relative scale change

    translate_x = random.uniform(-max_translate, max_translate) * width
    translate_y = random.uniform(-max_translate, max_translate) * height
    rotate_angle = random.uniform(-max_rotate, max_rotate)
    scale_factor = 1 + random.uniform(-max_scale, max_scale)

    frames = tensor.reshape(n_frames * batch, channels, height, width)
    warped = TF.affine(
        frames,
        angle=rotate_angle,
        translate=(translate_x, translate_y),
        scale=scale_factor,
        shear=(0, 0),
        interpolation=InterpolationMode.BILINEAR,
        fill=0,
    )
    return warped.reshape(n_frames, batch, channels, height, width)
def save_tensor_as_png(tensor, file_path):
    """
    Write a (3, H, W) image tensor to ``file_path`` as a PNG.

    Args:
        tensor (torch.Tensor): Input tensor of shape (3, H, W).
        file_path (str): Destination path.
    Raises:
        ValueError: If ``tensor`` is not 3-dimensional with 3 channels first.
    """
    is_rgb_chw = tensor.ndim == 3 and tensor.shape[0] == 3
    if not is_rgb_chw:
        raise ValueError("Input tensor must have shape (3, H, W)")
    TF.to_pil_image(tensor).save(file_path)
class WorldMemMinecraft(DiffusionForcingBase):
"""
Video generation for MineCraft with memory.
"""
def __init__(self, cfg: DictConfig):
    """
    Read the WorldMem-specific settings from ``cfg``, then initialise the base class.

    The attributes are assigned before ``super().__init__(cfg)``, presumably because the
    base constructor (e.g. via ``_build_model``) reads them. TODO confirm against
    DiffusionForcingBase.

    Args:
        cfg (DictConfig): Configuration object. ``n_frames``, ``frame_stack``,
            ``memory_condition_length`` and ``log_video`` are required; the other
            options fall back to the defaults listed below.
    """
    # Maximum number of (stacked) tokens; cfg.n_tokens, when present, overrides cfg.n_frames.
    token_frames = cfg.n_tokens if hasattr(cfg, "n_tokens") else cfg.n_frames
    self.n_tokens = token_frames // cfg.frame_stack
    self.n_frames = cfg.n_frames
    self.memory_condition_length = cfg.memory_condition_length
    self.log_video = cfg.log_video

    optional_settings = (
        ("pose_cond_dim", 5),
        ("use_plucker", True),
        ("relative_embedding", True),
        ("state_embed_only_on_qk", True),
        ("use_memory_attention", True),
        ("add_timestamp_embedding", True),
        ("ref_mode", 'sequential'),
        ("log_curve", False),
        ("focal_length", 0.35),
        ("save_local", True),
        ("local_save_dir", None),
        ("lpips_batch_size", 16),
        ("next_frame_length", 1),
        ("require_pose_prediction", False),
    )
    for name, default in optional_settings:
        setattr(self, name, getattr(cfg, name, default))

    super().__init__(cfg)
def _build_model(self):
    """
    Instantiate the sub-networks.

    Builds the DiT diffusion backbone, the LPIPS validation metric, the VAE (put in
    eval mode) and, if ``require_pose_prediction`` is set, the pose-prediction network.
    Construction order is kept as-is, since module initialisation draws from the RNG.
    """
    diffusion_kwargs = dict(
        reference_length=self.memory_condition_length,
        x_shape=self.x_stacked_shape,
        action_cond_dim=self.action_cond_dim,
        pose_cond_dim=self.pose_cond_dim,
        is_causal=self.causal,
        cfg=self.cfg.diffusion,
        is_dit=True,
        use_plucker=self.use_plucker,
        relative_embedding=self.relative_embedding,
        state_embed_only_on_qk=self.state_embed_only_on_qk,
        use_memory_attention=self.use_memory_attention,
        add_timestamp_embedding=self.add_timestamp_embedding,
        ref_mode=self.ref_mode,
    )
    self.diffusion_model = Diffusion(**diffusion_kwargs)

    # Avoid distributed sync inside torchmetrics; reduce metrics manually across ranks.
    self.validation_lpips_model = LearnedPerceptualImagePatchSimilarity(sync_on_compute=False)

    self.vae = VAE_models["vit-l-20-shallow-encoder"]().eval()

    if self.require_pose_prediction:
        self.pose_prediction_model = PosePredictionNet()
def _generate_noise_levels(self, xs: torch.Tensor, masks = None) -> torch.Tensor:
"""
Generate noise levels for training.
"""
num_frames, batch_size, *_ = xs.shape
match self.cfg.noise_level:
case "random_all": # entirely random noise levels
noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device)
case "same":
noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device)
noise_levels[1:] = noise_levels[0]
if masks is not None:
# for frames that are not available, treat as full noise
discard = torch.all(~rearrange(masks.bool(), "(t fs) b -> t b fs", fs=self.frame_stack), -1)
noise_levels = torch.where(discard, torch.full_like(noise_levels, self.timesteps - 1), noise_levels)
return noise_levels
def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
    """
    Perform a single training step.

    Encodes the frames, samples noise levels and computes the diffusion loss. The last
    ``memory_condition_length`` frames of the batch act as memory:
    - they are held at ``diffusion_model.stabilization_level``;
    - their action conditions are zeroed;
    - their loss terms are dropped before reweighting.

    Args:
        batch: Input batch containing frames, conditions, poses and frame indices.
        batch_idx: Index of the current batch.
    Returns:
        dict: ``{"loss": loss}`` with the reweighted training loss.
    """
    xs, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch)
    memory_length = self.memory_condition_length
    image_height, image_width = xs.shape[-2], xs.shape[-1]

    if self.use_plucker:
        if self.relative_embedding:
            # Use an explicit, clamped start index. `[-memory_length:]` would select
            # *every* frame when memory_length == 0.
            memory_c2w = c2w_mat[max(0, c2w_mat.shape[0] - memory_length):]
            memory_frame_idx = frame_idx[max(0, frame_idx.shape[0] - memory_length):]
            input_pose_condition = []
            frame_idx_list = []
            for i in range(self.n_frames):
                # Frame i followed by the memory frames, with poses relative to frame i:
                # [1 + memory_length, B, H, W, 6].
                input_pose_condition.append(
                    convert_to_plucker(
                        torch.cat([c2w_mat[i:i + 1], memory_c2w]).clone(),
                        0,
                        focal_length=self.focal_length,
                        image_height=image_height,
                        image_width=image_width,
                    ).to(xs.dtype)
                )
                # Frame offsets relative to frame i: 0 for frame i, then each memory frame.
                current = frame_idx[i:i + 1]
                frame_idx_list.append(torch.cat([current - current, memory_frame_idx - current]))
            input_pose_condition = torch.cat(input_pose_condition)
            frame_idx_list = torch.cat(frame_idx_list)
        else:
            # The image size was previously omitted here, which raised TypeError.
            input_pose_condition = convert_to_plucker(
                c2w_mat,
                0,
                focal_length=self.focal_length,
                image_width=image_width,
                image_height=image_height,
            ).to(xs.dtype)
            frame_idx_list = frame_idx
    else:
        input_pose_condition = pose_conditions.to(xs.dtype)
        frame_idx_list = None

    xs = self.encode(xs)

    noise_levels = self._generate_noise_levels(xs)
    if memory_length:
        noise_levels[-memory_length:] = self.diffusion_model.stabilization_level
        conditions[-memory_length:] *= 0

    _, loss = self.diffusion_model(
        xs,
        conditions,
        input_pose_condition,
        noise_levels=noise_levels,
        reference_length=memory_length,
        frame_idx=frame_idx_list,
    )

    if memory_length:
        loss = loss[:-memory_length]
    loss = self.reweight_loss(loss, None)

    if batch_idx % 20 == 0:
        self.log("training/loss", loss.cpu())

    return {"loss": loss}
def on_validation_epoch_end(self, namespace="validation") -> None:
    """
    Reduce the streaming metric sums across ranks and log mse / psnr / lpips.

    ``namespace`` is accepted for interface compatibility but is unused; the metric
    keys are always "mse", "psnr" and "lpips".
    """
    if not hasattr(self, "_metric_device"):
        # The accumulators are created by _reset_metric_accumulators; nothing to report.
        return

    if dist.is_available() and dist.is_initialized():
        for tensor in (
            self._mse_sum,
            self._mse_count,
            self._psnr_sum,
            self._psnr_count,
            self._lpips_sum,
            self._lpips_count,
        ):
            dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

    # After the all-reduce every rank holds identical totals, so every rank can log the
    # same values without sync_dist. Logging only on global rank 0 (as before) hid the
    # metrics from callbacks (checkpointing / early stopping) on the other ranks, and
    # `self.trainer is None` raises when no trainer is attached.
    if self._mse_count.item() > 0:
        mse = self._mse_sum / self._mse_count.clamp_min(1.0)
        psnr = self._psnr_sum / self._psnr_count.clamp_min(1.0)
        lpips = self._lpips_sum / self._lpips_count.clamp_min(1.0)
        self.log_dict(
            {"mse": mse, "psnr": psnr, "lpips": lpips},
            sync_dist=False,
        )

    self.validation_step_outputs.clear()
def on_validation_epoch_start(self) -> None:
    """Lightning hook: start every validation epoch with zeroed metric accumulators."""
    reset_metrics = self._reset_metric_accumulators
    reset_metrics()
def on_test_epoch_start(self) -> None:
    """Lightning hook: start every test epoch with zeroed metric accumulators."""
    reset_metrics = self._reset_metric_accumulators
    reset_metrics()
def _reset_metric_accumulators(self) -> None:
    """
    Create zeroed 0-dim sum/count tensors for mse, psnr and lpips.

    They live on the LPIPS model's device, which is also recorded as
    ``self._metric_device``.
    """
    device = next(self.validation_lpips_model.parameters()).device
    self._metric_device = device
    for metric in ("mse", "psnr", "lpips"):
        setattr(self, f"_{metric}_sum", torch.zeros((), device=device))
        setattr(self, f"_{metric}_count", torch.zeros((), device=device))
def _update_metric_accumulators(self, xs_pred: torch.Tensor, xs_gt: torch.Tensor) -> None:
    """
    Add one batch's metrics to the running sums read by ``on_validation_epoch_end``.

    Each metric value is multiplied by a per-batch weight before summing, so the epoch
    value is a weighted mean. The weights are:
    - MSE: the element count;
    - PSNR: dim 1 of ``xs_pred``;
    - LPIPS: dims 0 x 1 of ``xs_pred``.
    """
    device = self._metric_device
    pred = xs_pred.to(device)
    target = xs_gt.to(device)
    metrics = get_validation_metrics_for_videos(
        pred,
        target,
        lpips_model=self.validation_lpips_model,
        lpips_batch_size=self.lpips_batch_size,
    )

    values = {
        "mse": metrics["mse"].detach(),
        "psnr": metrics["psnr"].detach(),
        "lpips": torch.tensor(metrics["lpips"], device=device),
    }
    weights = {
        "mse": torch.tensor(float(pred.numel()), device=device),
        "psnr": torch.tensor(float(pred.shape[1]), device=device),
        "lpips": torch.tensor(float(pred.shape[0] * pred.shape[1]), device=device),
    }
    for metric in ("mse", "psnr", "lpips"):
        getattr(self, f"_{metric}_sum").add_(values[metric] * weights[metric])
        getattr(self, f"_{metric}_count").add_(weights[metric])
def _preprocess_batch(self, batch):
    """
    Unpack a batch and move time to the leading dimension.

    Returns:
        tuple: ``(xs, conditions, pose_conditions, c2w_mat, frame_index)``, where
        - xs is (t, b, c, ...);
        - conditions is (t, b, d), with the first action set to zero;
        - pose_conditions is (t, b, d);
        - c2w_mat is ``euler_to_camera_to_world_matrix(pose_conditions)``;
        - frame_index is (t, b).
    Raises:
        NotImplementedError: If the model has no external action conditioning.
    """
    xs, conditions, pose_conditions, frame_index = batch
    if not self.action_cond_dim:
        raise NotImplementedError("Only support external cond.")

    conditions = torch.cat([torch.zeros_like(conditions[:, :1]), conditions[:, 1:]], 1)
    conditions = rearrange(conditions, "b t d -> t b d").contiguous()
    pose_conditions = rearrange(pose_conditions, "b t d -> t b d").contiguous()
    xs = rearrange(xs, "b t c ... -> t b c ...").contiguous()
    frame_index = rearrange(frame_index, "b t -> t b").contiguous()
    c2w_mat = euler_to_camera_to_world_matrix(pose_conditions)
    return xs, conditions, pose_conditions, c2w_mat, frame_index
def encode(self, x):
    """
    Encode frames of shape (t, b, c, h, w) into VAE latents of shape (t, b, c', h/p, w/p).

    Pixels are mapped with ``x * 2 - 1`` before encoding. The posterior mean is taken
    and multiplied by a fixed scaling factor. ``p`` is ``self.vae.patch_size``. Runs
    without gradient tracking.
    """
    n_frames = x.shape[0]
    height, width = x.shape[-2:]
    latent_scale = 0.07843137255

    flat = rearrange(x, "t b c h w -> (t b) c h w")
    with torch.no_grad():
        latents = self.vae.encode(flat * 2 - 1).mean * latent_scale

    patch = self.vae.patch_size
    return rearrange(
        latents, "(t b) (h w) c -> t b c h w", t=n_frames, h=height // patch, w=width // patch
    )
def decode(self, x):
    """
    Decode latents of shape (t, b, c, h, w) into frames of shape (t, b, C, H, W).

    The latents are divided by the same scaling factor used in ``encode``. The VAE
    output y is returned as ``(y + 1) / 2``. Runs without gradient tracking.
    """
    n_frames = x.shape[0]
    latent_scale = 0.07843137255

    tokens = rearrange(x, "t b c h w -> (t b) (h w) c")
    with torch.no_grad():
        images = (self.vae.decode(tokens / latent_scale) + 1) / 2
    return rearrange(images, "(t b) c h w-> t b c h w", t=n_frames)
def _generate_condition_indices(self, curr_frame, memory_condition_length, xs_pred, pose_conditions, frame_idx, horizon):
    """
    Choose which past frames to use as memory when generating the next chunk.

    While ``curr_frame < memory_condition_length``, all existing frames are used
    (padded with frame 0). Otherwise:
    1. Sample 10k points within radius 30 of the current camera position.
    2. Mark the "target region": the points seen by any pose of the upcoming
       ``horizon`` frames.
    3. Greedily pick past frames by score = (per-sample fraction of the remaining
       target region they see) minus a penalty linear in (curr_frame - frame_idx).
    4. After each pick, remove the region that frame covers.

    Args:
        curr_frame (int): Index of the first frame to be generated.
        memory_condition_length (int): Number of memory frames to select.
        xs_pred: Latents generated so far; only its batch size (dim 1) is read.
        pose_conditions: Poses with time in dim 0 and batch in dim 1, with
            (x, y, z) in the first three entries and (pitch, yaw) in the last two.
        frame_idx: Frame indices with time in dim 0 and batch in dim 1.
        horizon (int): Number of frames generated in this chunk.
    Returns:
        Integer indices of shape (memory_condition_length, B): a numpy array in the
        early-frame case, otherwise a CPU torch tensor.
    """
    if curr_frame < memory_condition_length:
        early_idx = list(range(curr_frame)) + [0] * (memory_condition_length - curr_frame)
        return np.repeat(np.array(early_idx)[:, None], xs_pred.shape[1], -1)

    device = pose_conditions.device
    batch_size = pose_conditions.shape[1]
    batch_range = torch.arange(batch_size, device=device)

    # Sample points in a sphere around the current camera position.
    num_samples = 10000
    radius = 30
    points = generate_points_in_sphere(num_samples, radius).to(device)
    points = points[:, None].repeat(1, batch_size, 1)
    points += pose_conditions[curr_frame, :, :3][None]
    fov_half_h = torch.tensor(105 / 2, device=device)
    fov_half_v = torch.tensor(75 / 2, device=device)

    def _visibility(poses):
        # (len(poses), N, B): which sample points each pose sees.
        return torch.stack([
            is_inside_fov_3d_hv(points, pc[:, :3], pc[:, -2], pc[:, -1], fov_half_h, fov_half_v)
            for pc in poses
        ])

    # Points seen by any pose of the chunk being generated: (N, B).
    target_region = _visibility(pose_conditions[curr_frame:curr_frame + horizon]).sum(0) > 0
    # Visibility of the same points from every past frame: (curr_frame, N, B).
    past_visibility = _visibility(pose_conditions[:curr_frame])
    # Older frames (larger curr_frame - frame_idx) are penalised more.
    recency_penalty = (curr_frame - frame_idx[:curr_frame]) / curr_frame * (-0.2)

    selected = []
    for _ in range(memory_condition_length):
        # Normalise per batch sample; the old scalar sum over the whole batch mixed
        # samples together. clamp_min avoids 0/0 = NaN once the region is fully covered.
        region_size = target_region.sum(0).clamp_min(1)
        overlap_ratio = (target_region & past_visibility).sum(1) / region_size
        confidence = overlap_ratio + recency_penalty
        # Exclude frames already chosen *for the same sample* (per column, not whole rows).
        for prev in selected:
            confidence[prev, batch_range] = -1e10
        best = torch.topk(confidence, k=1, dim=0).indices[0]  # (B,)
        selected.append(best)
        # Drop the part of the target region already covered by the chosen frame.
        covered = past_visibility[best, :, batch_range].permute(1, 0)  # (N, B)
        target_region = target_region & ~covered

    return torch.stack(selected).cpu()
def _prepare_conditions(self,
                        start_frame, curr_frame, horizon, conditions,
                        pose_conditions, c2w_mat, frame_idx, random_idx,
                        image_width, image_height):
    """
    Build the action, pose and frame-index inputs for one sampling window.

    The window is frames ``start_frame .. curr_frame + horizon - 1``, followed by the
    memory frames selected in ``random_idx`` (shape (M, B), one column per sample).

    Returns:
        tuple: ``(input_condition, input_pose_condition, frame_idx_list)``.
        - input_condition holds the window's actions followed by zeros for the memory
          frames.
        - input_pose_condition is either raw poses, absolute Plücker embeddings, or
          per-frame relative Plücker embeddings, depending on ``use_plucker`` and
          ``relative_embedding``.
        - frame_idx_list is only set for the relative embedding: offsets of the
          current frame (0) and of the memory frames, relative to each window frame.
    """
    batch_size = conditions.shape[1]
    # Memory frames are gathered per sample: row random_idx[k, b] from column b.
    batch_cols = range(batch_size)
    memory_rows = random_idx[:, batch_cols]
    window = slice(start_frame, curr_frame + horizon)

    padding = torch.zeros((len(random_idx),) + conditions.shape[1:], device=conditions.device, dtype=conditions.dtype)
    input_condition = torch.cat([conditions[window], padding], dim=0)

    if not self.use_plucker:
        input_pose_condition = torch.cat(
            [pose_conditions[window], pose_conditions[memory_rows, batch_cols]], dim=0
        ).clone()
        return input_condition, input_pose_condition, None

    if self.relative_embedding:
        memory_c2w = c2w_mat[memory_rows, batch_cols]
        memory_frame_idx = frame_idx[memory_rows, batch_cols]
        pose_chunks = []
        idx_chunks = []
        for i in range(start_frame, curr_frame + horizon):
            pose_chunks.append(
                convert_to_plucker(
                    torch.cat([c2w_mat[i:i + 1], memory_c2w]).clone(),
                    0,
                    focal_length=self.focal_length,
                    image_width=image_width,
                    image_height=image_height,
                ).to(conditions.dtype)
            )
            current = frame_idx[i:i + 1]
            idx_chunks.append(torch.cat([current - current, memory_frame_idx - current]))
        return input_condition, torch.cat(pose_chunks), torch.cat(idx_chunks)

    # Absolute Plücker embedding. The image size was previously omitted here, which
    # raised TypeError.
    # NOTE(review): training_step passes frame_idx in this mode while validation passes
    # None. Confirm which the model expects.
    stacked_c2w = torch.cat([c2w_mat[window], c2w_mat[memory_rows, batch_cols]], dim=0).clone()
    input_pose_condition = convert_to_plucker(
        stacked_c2w,
        0,
        focal_length=self.focal_length,
        image_width=image_width,
        image_height=image_height,
    ).to(conditions.dtype)
    return input_condition, input_pose_condition, None
def _prepare_noise_levels(self, scheduling_matrix, m, curr_frame, batch_size, memory_condition_length):
    """
    Build the (from, to) noise levels for sampling step ``m``.

    Each level vector is laid out as:
    - zeros for the ``curr_frame`` already-generated frames;
    - scheduling row m (respectively m + 1) for the chunk being generated;
    - zeros for the ``memory_condition_length`` memory frames.
    The result is tiled over the batch and moved to ``self.device``.
    """
    context_levels = np.zeros((curr_frame,), dtype=np.int64)

    def _expand(schedule_row):
        levels = np.concatenate((context_levels, schedule_row))[:, None].repeat(batch_size, axis=1)
        if memory_condition_length:
            memory_levels = np.zeros((memory_condition_length, levels.shape[-1]), dtype=np.int32)
            levels = np.concatenate([levels, memory_levels], axis=0)
        return levels

    from_levels = _expand(scheduling_matrix[m])
    to_levels = _expand(scheduling_matrix[m + 1])
    return torch.from_numpy(from_levels).to(self.device), torch.from_numpy(to_levels).to(self.device)
def validation_step(self, batch, batch_idx, namespace="validation") -> STEP_OUTPUT:
    """
    Perform a single validation step.

    Encodes the frames and then autoregressively samples every frame after the context
    frames, chunk by chunk, with a sliding window of ``n_tokens`` frames plus the
    selected memory frames. Predictions and ground truth are decoded, optionally
    logged as videos, and folded into the streaming metric accumulators.

    Args:
        batch: Input batch containing frames, conditions, poses and frame indices.
        batch_idx: Index of the current batch.
        namespace: Prefix for the logged videos (default: "validation").
    Returns:
        None. Metrics are accumulated via ``_update_metric_accumulators``.
    """
    memory_condition_length = self.memory_condition_length
    xs_raw, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch)

    # Encode in 10 chunks for long clips to bound VAE memory use.
    total_frame = xs_raw.shape[0]
    if total_frame > 10:
        xs = torch.cat([
            self.encode(xs_raw[int(total_frame * i / 10):int(total_frame * (i + 1) / 10)]).cpu()
            for i in range(10)
        ])
    else:
        xs = self.encode(xs_raw).cpu()
    n_frames, batch_size, *_ = xs.shape

    # Context frames are taken from the ground truth.
    n_context_frames = self.context_frames // self.frame_stack
    xs_pred = xs[:n_context_frames].clone()
    curr_frame = n_context_frames

    # Context manager so the progress bar is always closed.
    with tqdm(total=n_frames, initial=curr_frame, desc="Sampling") as pbar:
        while curr_frame < n_frames:
            horizon = min(n_frames - curr_frame, self.chunk_size) if self.chunk_size > 0 else n_frames - curr_frame
            assert horizon <= self.n_tokens, "Horizon exceeds the number of tokens."

            # Append clipped Gaussian noise for the chunk to be generated.
            scheduling_matrix = self._generate_scheduling_matrix(horizon)
            chunk = torch.randn((horizon, batch_size, *xs_pred.shape[2:]))
            chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise).to(xs_pred.device)
            xs_pred = torch.cat([xs_pred, chunk], 0)

            # Sliding window: only input the last `n_tokens` frames.
            start_frame = max(0, curr_frame + horizon - self.n_tokens)
            pbar.set_postfix({"start": start_frame, "end": curr_frame + horizon})

            if memory_condition_length:
                random_idx = self._generate_condition_indices(
                    curr_frame, memory_condition_length, xs_pred, pose_conditions, frame_idx, horizon
                )
                xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:, range(xs_pred.shape[1])], range(xs_pred.shape[1])].clone()], 0)
            else:
                # No memory frames. The empty (0, B) index keeps _prepare_conditions
                # working; `random_idx` used to be unbound here (UnboundLocalError).
                random_idx = torch.zeros((0, batch_size), dtype=torch.long)

            input_condition, input_pose_condition, frame_idx_list = self._prepare_conditions(
                start_frame, curr_frame, horizon, conditions, pose_conditions, c2w_mat, frame_idx, random_idx,
                image_width=xs_raw.shape[-1], image_height=xs_raw.shape[-2]
            )

            for m in range(scheduling_matrix.shape[0] - 1):
                from_noise_levels, to_noise_levels = self._prepare_noise_levels(
                    scheduling_matrix, m, curr_frame, batch_size, memory_condition_length
                )
                xs_pred[start_frame:] = self.diffusion_model.sample_step(
                    xs_pred[start_frame:].to(input_condition.device),
                    input_condition,
                    input_pose_condition,
                    from_noise_levels[start_frame:],
                    to_noise_levels[start_frame:],
                    current_frame=curr_frame,
                    mode="validation",
                    reference_length=memory_condition_length,
                    frame_idx=frame_idx_list
                ).cpu()

            # Drop the appended memory frames again.
            if memory_condition_length:
                xs_pred = xs_pred[:-memory_condition_length]

            curr_frame += horizon
            pbar.update(horizon)

    xs_pred = self.decode(xs_pred[n_context_frames:].to(conditions.device))
    xs_decode = self.decode(xs[n_context_frames:].to(conditions.device))

    # Save videos for every batch (rank is encoded in filenames).
    if self.logger and self.log_video:
        log_video(
            xs_pred,
            xs_decode,
            step=batch_idx,
            namespace=namespace + "_vis",
            context_frames=self.context_frames,
            logger=self.logger.experiment,
            save_local=self.save_local,
            local_save_dir=self.local_save_dir,
        )

    # Stream metrics to avoid holding all outputs in memory.
    self._update_metric_accumulators(xs_pred, xs_decode)
    return
@torch.no_grad()
def interactive(self, first_frame, new_actions, first_pose, device,
                memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx):
    """
    Generate frames for ``new_actions`` interactively, carrying memory between calls.

    First call (``memory_latent_frames is None``): ``first_frame``, ``new_actions`` and
    ``first_pose`` are numpy arrays. The first frame is encoded and the memory buffers
    are initialised; nothing is generated.

    Later calls:
    1. Predict one pose per action with ``self.pose_prediction_model``. Angle columns
       move in 15-degree steps, mod 360.
    2. Append actions, poses, c2w matrices and frame indices to memory.
    3. Sample ``len(new_actions)`` latent frames in chunks of ``self.next_frame_length``
       and decode them.
    ``new_actions`` is moved with ``.to(device)`` here, so it is presumably already a
    tensor (TODO confirm with callers).

    Returns:
        tuple of numpy arrays: ``(frames, memory_latent_frames, memory_actions,
        memory_poses, memory_c2w, memory_frame_idx)``. ``frames`` is ``first_frame`` on
        the first call and the decoded new frames afterwards.
    """
    memory_condition_length = self.memory_condition_length

    if memory_latent_frames is None:
        first_frame = torch.from_numpy(first_frame)
        new_actions = torch.from_numpy(new_actions)
        first_pose = torch.from_numpy(first_pose)
        first_frame_encode = self.encode(first_frame[None, None].to(device))
        memory_latent_frames = first_frame_encode.cpu()
        memory_actions = new_actions[None, None].to(device)
        memory_poses = first_pose[None, None].to(device)
        new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
        memory_c2w = new_c2w_mat[None, None].to(device)
        memory_frame_idx = torch.tensor([[0]]).to(device)
        return (first_frame.cpu().numpy(), memory_latent_frames.cpu().numpy(),
                memory_actions.cpu().numpy(), memory_poses.cpu().numpy(),
                memory_c2w.cpu().numpy(), memory_frame_idx.cpu().numpy())

    memory_latent_frames = torch.from_numpy(memory_latent_frames)
    memory_actions = torch.from_numpy(memory_actions).to(device)
    memory_poses = torch.from_numpy(memory_poses).to(device)
    memory_c2w = torch.from_numpy(memory_c2w).to(device)
    memory_frame_idx = torch.from_numpy(memory_frame_idx).to(device)
    new_actions = new_actions.to(device)

    batch_size = 1
    horizon = self.next_frame_length

    # All stored latents serve as context.
    n_context_frames = len(memory_latent_frames)
    xs_pred = memory_latent_frames[:n_context_frames].clone()
    curr_frame = n_context_frames

    # Predict a pose for each new action, starting from the last stored pose. Angles are
    # divided (floor) by 15, the rounded predicted offset is added, and the result is
    # scaled back by 15 and wrapped to [0, 360).
    new_pose_condition_list = []
    last_frame = xs_pred[-1].clone()
    last_pose_condition = memory_poses[-1].clone()
    curr_actions = new_actions.clone()
    for hi in range(len(new_actions)):
        last_pose_condition[:, 3:] = last_pose_condition[:, 3:] // 15
        new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None, hi], last_pose_condition)
        new_pose_condition_offset[:, 3:] = torch.round(new_pose_condition_offset[:, 3:])
        new_pose_condition = last_pose_condition + new_pose_condition_offset
        new_pose_condition[:, 3:] = new_pose_condition[:, 3:] * 15
        new_pose_condition[:, 3:] %= 360
        last_pose_condition = new_pose_condition.clone()
        new_pose_condition_list.append(new_pose_condition[None])
    new_pose_condition_list = torch.cat(new_pose_condition_list, 0)

    # The bar covers the context frames plus the frames generated here. The old total
    # (`horizon`) was smaller than the initial value, and the bar was never closed.
    with tqdm(total=n_context_frames + len(new_actions), initial=curr_frame, desc="Sampling") as pbar:
        ai = 0
        while ai < len(new_actions):
            next_horizon = min(horizon, len(new_actions) - ai)
            curr_actions = new_actions[ai:ai + next_horizon].clone()
            new_pose_condition = new_pose_condition_list[ai:ai + next_horizon].clone()
            new_c2w_mat = euler_to_camera_to_world_matrix(new_pose_condition)

            # Extend the memory buffers with this chunk's actions / poses / indices.
            memory_poses = torch.cat([memory_poses, new_pose_condition])
            memory_actions = torch.cat([memory_actions, curr_actions[:, None]])
            memory_c2w = torch.cat([memory_c2w, new_c2w_mat])
            new_indices = memory_frame_idx[-1, 0] + torch.arange(next_horizon, device=memory_frame_idx.device) + 1
            memory_frame_idx = torch.cat([memory_frame_idx, new_indices[:, None]])

            conditions = memory_actions.clone()
            pose_conditions = memory_poses.clone()
            c2w_mat = memory_c2w.clone()
            frame_idx = memory_frame_idx.clone()

            # Append clipped Gaussian noise for the frames to be generated.
            scheduling_matrix = self._generate_scheduling_matrix(next_horizon)
            chunk = torch.randn((next_horizon, batch_size, *xs_pred.shape[2:])).to(xs_pred.device)
            chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise)
            xs_pred = torch.cat([xs_pred, chunk], 0)

            # Sliding window.
            # NOTE(review): validation_step uses `curr_frame + horizon - n_tokens`; this
            # window can exceed n_tokens frames. Confirm it is intended.
            start_frame = max(0, curr_frame - self.n_tokens)
            pbar.set_postfix(
                {
                    "start": start_frame,
                    "end": curr_frame + next_horizon,
                }
            )

            if memory_condition_length:
                random_idx = self._generate_condition_indices(
                    curr_frame, memory_condition_length, xs_pred, pose_conditions, frame_idx, next_horizon
                )
                xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:, range(xs_pred.shape[1])], range(xs_pred.shape[1])].clone()], 0)
            else:
                # No memory frames. The empty (0, B) index keeps _prepare_conditions
                # working; `random_idx` used to be unbound here (UnboundLocalError).
                random_idx = torch.zeros((0, batch_size), dtype=torch.long)

            input_condition, input_pose_condition, frame_idx_list = self._prepare_conditions(
                start_frame, curr_frame, next_horizon, conditions, pose_conditions, c2w_mat, frame_idx, random_idx,
                image_width=first_frame.shape[-1], image_height=first_frame.shape[-2]
            )

            for m in range(scheduling_matrix.shape[0] - 1):
                from_noise_levels, to_noise_levels = self._prepare_noise_levels(
                    scheduling_matrix, m, curr_frame, batch_size, memory_condition_length
                )
                xs_pred[start_frame:] = self.diffusion_model.sample_step(
                    xs_pred[start_frame:].to(input_condition.device),
                    input_condition,
                    input_pose_condition,
                    from_noise_levels[start_frame:],
                    to_noise_levels[start_frame:],
                    current_frame=curr_frame,
                    mode="validation",
                    reference_length=memory_condition_length,
                    frame_idx=frame_idx_list
                ).cpu()

            # Drop the appended memory frames again.
            if memory_condition_length:
                xs_pred = xs_pred[:-memory_condition_length]

            curr_frame += next_horizon
            pbar.update(next_horizon)
            ai += next_horizon

    memory_latent_frames = torch.cat([memory_latent_frames, xs_pred[n_context_frames:]])
    xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
    return xs_pred.cpu().numpy(), memory_latent_frames.cpu().numpy(), memory_actions.cpu().numpy(), \
        memory_poses.cpu().numpy(), memory_c2w.cpu().numpy(), memory_frame_idx.cpu().numpy()