|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
from einops import rearrange |
|
|
|
|
|
from cosmos_predict1.diffusion.inference.forward_warp_utils_pytorch import ( |
|
|
forward_warp, |
|
|
reliable_depth_mask_range_batch, |
|
|
unproject_points, |
|
|
) |
|
|
from cosmos_predict1.diffusion.inference.camera_utils import align_depth |
|
|
|
|
|
class Cache3D_Base:
    """Base class for a 3D point cache built from posed RGB-D inputs.

    The cache stores images, world-space points (unprojected from depth), and
    optional validity/boundary masks on CPU in a canonical 7-D layout
    ``(B, F, N, V, C, H, W)`` = (batch, frame, buffer entry, view, channel,
    height, width).  ``render_cache`` forward-warps the cached points into
    novel target views on ``self.device``; subclasses implement
    ``update_cache``.
    """

    def __init__(
        self,
        input_image,
        input_depth,
        input_w2c,
        input_intrinsics,
        input_mask=None,
        input_format=None,
        input_points=None,
        weight_dtype=torch.float32,
        is_depth=True,
        device="cuda",
        filter_points_threshold=1.0,
        foreground_masking=False,
    ):
        """
        Args:
            input_image: Tensor with varying dimensions, described by ``input_format``.
            input_depth: Per-pixel depth map; only used when ``input_points`` is None.
                NaNs are replaced with 100 and values clamped to [0, 100]
                (additionally clamped to 70 under float16 to avoid overflow-ish
                extremes in half precision).
            input_w2c: World-to-camera extrinsics, reshaped to (-1, 4, 4).
            input_intrinsics: Camera intrinsics, reshaped to (-1, 3, 3).
            input_mask: Optional validity mask; concatenated onto the image along
                the channel axis so it undergoes the same permute/unsqueeze, then
                split back off afterwards.  Assumes the image itself has 3 channels
                (split at channel index 3) -- TODO confirm for non-RGB inputs.
            input_format: List of dimension labels corresponding to input_image's
                dimensions.  E.g., ['B', 'C', 'H', 'W'], ['B', 'F', 'C', 'H', 'W'],
                etc.  Defaults to ['B', 'C', 'H', 'W'] (requires a 4-D image).
            input_points: Optional precomputed world points reshaped to
                (B, F, N, V, H, W, 3); bypasses depth unprojection entirely.
            weight_dtype: Storage dtype for images/points.
            is_depth: Forwarded to ``unproject_points`` (semantics defined there).
            device: Device used for warping; persistent storage stays on CPU.
            filter_points_threshold: If < 1.0, pixels whose depth fails
                ``reliable_depth_mask_range_batch`` are masked out.
            foreground_masking: If True, store the inverted reliability mask as
                ``self.boundary_mask`` for occlusion handling in forward_warp.
                NOTE(review): this path reshapes ``input_depth`` and will fail if
                ``input_points`` was supplied with ``input_depth=None`` -- confirm
                callers never combine those.
        """
        self.weight_dtype = weight_dtype
        self.is_depth = is_depth
        self.device = device
        self.filter_points_threshold = filter_points_threshold
        self.foreground_masking = foreground_masking
        if input_format is None:
            assert input_image.dim() == 4
            input_format = ["B", "C", "H", "W"]

        # Map each dimension label to its position in the incoming tensor.
        format_to_indices = {dim: idx for idx, dim in enumerate(input_format)}
        input_shape = input_image.shape
        if input_mask is not None:
            # Ride the mask through the same permutation by stacking it on channels.
            input_image = torch.cat([input_image, input_mask], dim=format_to_indices.get("C"))

        # Missing dimensions default to size 1 (B/F/N/V) or None (H/W must exist).
        B = input_shape[format_to_indices.get("B", 0)] if "B" in format_to_indices else 1
        F = input_shape[format_to_indices.get("F", 0)] if "F" in format_to_indices else 1
        N = input_shape[format_to_indices.get("N", 0)] if "N" in format_to_indices else 1
        V = input_shape[format_to_indices.get("V", 0)] if "V" in format_to_indices else 1
        H = input_shape[format_to_indices.get("H", 0)] if "H" in format_to_indices else None
        W = input_shape[format_to_indices.get("W", 0)] if "W" in format_to_indices else None

        # Canonical dimension order for all cached tensors.
        desired_dims = ["B", "F", "N", "V", "C", "H", "W"]

        # Build the permutation: present dims keep their source index, absent
        # dims are recorded as None so they can be unsqueezed in afterwards.
        permute_order = []
        for dim in desired_dims:
            idx = format_to_indices.get(dim)
            if idx is not None:
                permute_order.append(idx)
            else:
                permute_order.append(None)

        # Permute the dims that exist...
        permute_indices = [idx for idx in permute_order if idx is not None]
        input_image = input_image.permute(*permute_indices)

        # ...then insert size-1 axes for the dims that were absent.
        for i, idx in enumerate(permute_order):
            if idx is None:
                input_image = input_image.unsqueeze(i)

        if input_mask is not None:
            # Split the stacked mask back off the channel axis (image assumed 3-channel).
            self.input_image, self.input_mask = input_image[:, :, :, :, :3], input_image[:, :, :, :, 3:]
            self.input_mask = self.input_mask.to("cpu")
        else:
            self.input_mask = None
            self.input_image = input_image
        # Persistent storage lives on CPU; chunks move to self.device at warp time.
        self.input_image = self.input_image.to(weight_dtype).to("cpu")

        if input_points is not None:
            self.input_points = input_points.reshape(B, F, N, V, H, W, 3).to("cpu")
            self.input_depth = None
        else:
            # Sanitize depth before unprojection (NaNs -> far plane, clamp range).
            input_depth = torch.nan_to_num(input_depth, nan=100)
            input_depth = torch.clamp(input_depth, min=0, max=100)
            if weight_dtype == torch.float16:
                # Tighter clamp in half precision.
                input_depth = torch.clamp(input_depth, max=70)
            self.input_points = (
                self._compute_input_points(
                    input_depth.reshape(-1, 1, H, W),
                    input_w2c.reshape(-1, 4, 4),
                    input_intrinsics.reshape(-1, 3, 3),
                )
                .to(weight_dtype)
                .reshape(B, F, N, V, H, W, 3)
                .to("cpu")
            )
            self.input_depth = input_depth

        # Optionally drop pixels with unreliable depth by AND-ing into the mask.
        if self.filter_points_threshold < 1.0 and input_depth is not None:
            input_depth = input_depth.reshape(-1, 1, H, W)
            depth_mask = reliable_depth_mask_range_batch(input_depth, ratio_thresh=self.filter_points_threshold).reshape(B, F, N, V, 1, H, W)
            if self.input_mask is None:
                self.input_mask = depth_mask.to("cpu")
            else:
                self.input_mask = self.input_mask * depth_mask.to(self.input_mask.device)
        self.boundary_mask = None
        if foreground_masking:
            # Inverted reliability mask marks depth-discontinuity boundaries.
            input_depth = input_depth.reshape(-1, 1, H, W)
            depth_mask = reliable_depth_mask_range_batch(input_depth)
            self.boundary_mask = (~depth_mask).reshape(B, F, N, V, 1, H, W).to("cpu")

    def _compute_input_points(self, input_depth, input_w2c, input_intrinsics):
        """Unproject depth maps to world-space points via the project helper."""
        input_points = unproject_points(
            input_depth,
            input_w2c,
            input_intrinsics,
            is_depth=self.is_depth,
        )
        return input_points

    def update_cache(self):
        # Subclasses define how (and whether) new frames enter the cache.
        raise NotImplementedError

    def input_frame_count(self) -> int:
        """Number of cached frames (the F axis of the canonical layout)."""
        return self.input_image.shape[1]

    def render_cache(self, target_w2cs, target_intrinsics, render_depth=False, start_frame_idx=0):
        """Forward-warp cached frames into the given target cameras.

        Args:
            target_w2cs: (B, F_target, 4, 4) target world-to-camera matrices.
            target_intrinsics: (B, F_target, 3, 3) target intrinsics.
            render_depth: If True, return warped depth instead of warped pixels.
            start_frame_idx: First cached frame to pair with the targets.

        Returns:
            (pixels, masks) on ``self.device``: pixels is (B, F_target, N, C, H, W)
            -- or (B, F_target, N, H, W) when ``render_depth`` -- and masks is
            (B, F_target, N, 1, H, W).
        """
        bs, F_target, _, _ = target_w2cs.shape

        B, F, N, V, C, H, W = self.input_image.shape
        assert bs == B

        # Broadcast each target camera across the N buffer entries, then flatten.
        target_w2cs = target_w2cs.reshape(B, F_target, 1, 4, 4).expand(B, F_target, N, 4, 4).reshape(-1, 4, 4)
        target_intrinsics = (
            target_intrinsics.reshape(B, F_target, 1, 3, 3).expand(B, F_target, N, 3, 3).reshape(-1, 3, 3)
        )

        # Slice the requested frame window and flatten (B, F, N) into the batch axis.
        # NOTE(review): expand() assumes the sliced F axis is either F_target or 1 --
        # confirm callers never request a window that partially overlaps the cache.
        first_images = rearrange(self.input_image[:, start_frame_idx:start_frame_idx+F_target].expand(B, F_target, N, V, C, H, W), "B F N V C H W-> (B F N) V C H W")
        first_points = rearrange(
            self.input_points[:, start_frame_idx:start_frame_idx+F_target].expand(B, F_target, N, V, H, W, 3), "B F N V H W C-> (B F N) V H W C"
        )
        first_masks = rearrange(
            self.input_mask[:, start_frame_idx:start_frame_idx+F_target].expand(B, F_target, N, V, 1, H, W), "B F N V C H W-> (B F N) V C H W"
        ) if self.input_mask is not None else None
        boundary_masks = rearrange(
            self.boundary_mask.expand(B, F_target, N, V, 1, H, W), "B F N V C H W-> (B F N) V C H W"
        ) if self.boundary_mask is not None else None

        # Only the single-view (V == 1) path is implemented.
        if first_images.shape[1] == 1:
            # Warp in small chunks to bound peak GPU memory.
            warp_chunk_size = 2
            rendered_warp_images = []
            rendered_warp_masks = []
            rendered_warp_depth = []

            # Drop the singleton view axis.
            first_images = first_images.squeeze(1)
            first_points = first_points.squeeze(1)
            first_masks = first_masks.squeeze(1) if first_masks is not None else None
            for i in range(0, first_images.shape[0], warp_chunk_size):
                with torch.no_grad():
                    # Move just this chunk to the warp device.
                    imgs_chunk = first_images[i : i + warp_chunk_size].to(self.device, non_blocking=True)
                    pts_chunk = first_points[i : i + warp_chunk_size].to(self.device, non_blocking=True)
                    masks_chunk = (
                        first_masks[i : i + warp_chunk_size].to(self.device, non_blocking=True)
                        if first_masks is not None
                        else None
                    )
                    bmask_chunk = (
                        boundary_masks[i : i + warp_chunk_size, 0, 0].to(self.device, non_blocking=True)
                        if boundary_masks is not None
                        else None
                    )
                    (
                        rendered_warp_images_chunk,
                        rendered_warp_masks_chunk,
                        rendered_warp_depth_chunk,
                        _,
                    ) = forward_warp(
                        imgs_chunk,
                        mask1=masks_chunk,
                        depth1=None,
                        transformation1=None,
                        transformation2=target_w2cs[i : i + warp_chunk_size],
                        intrinsic1=target_intrinsics[i : i + warp_chunk_size],
                        intrinsic2=target_intrinsics[i : i + warp_chunk_size],
                        render_depth=render_depth,
                        world_points1=pts_chunk,  # points given directly, so depth1/transformation1 unused
                        foreground_masking=self.foreground_masking,
                        boundary_mask=bmask_chunk,
                    )
                    # Accumulate results on CPU to keep GPU memory flat.
                    rendered_warp_images.append(rendered_warp_images_chunk.to("cpu"))
                    rendered_warp_masks.append(rendered_warp_masks_chunk.to("cpu"))
                    if render_depth:
                        rendered_warp_depth.append(rendered_warp_depth_chunk.to("cpu"))
                    del imgs_chunk, pts_chunk, masks_chunk, bmask_chunk
                    del rendered_warp_images_chunk, rendered_warp_masks_chunk
                    if render_depth:
                        del rendered_warp_depth_chunk
                    torch.cuda.empty_cache()
            rendered_warp_images = torch.cat(rendered_warp_images, dim=0)
            rendered_warp_masks = torch.cat(rendered_warp_masks, dim=0)
            if render_depth:
                rendered_warp_depth = torch.cat(rendered_warp_depth, dim=0)

        else:
            raise NotImplementedError

        # Unflatten (B*F*N) back into separate axes.
        pixels = rearrange(rendered_warp_images, "(b f n) c h w -> b f n c h w", b=bs, f=F_target, n=N)
        masks = rearrange(rendered_warp_masks, "(b f n) c h w -> b f n c h w", b=bs, f=F_target, n=N)
        if render_depth:
            # Depth replaces pixels in the return value (masks are still warp masks).
            pixels = rearrange(rendered_warp_depth, "(b f n) h w -> b f n h w", b=bs, f=F_target, n=N)
        return pixels.to(self.device), masks.to(self.device)
|
|
|
|
|
|
|
|
class Cache3D_Buffer(Cache3D_Base):
    """3D cache with a rolling frame buffer along the N axis.

    New frames (with depth aligned to the existing cache) are prepended to the
    buffer up to ``frame_buffer_max`` entries; when full, the newest slot is
    overwritten.  Rendering optionally adds per-buffer-entry noise so older
    entries are perturbed more than newer ones.
    """

    def __init__(self, frame_buffer_max=0, noise_aug_strength=0, generator=None, **kwargs):
        """
        Args:
            frame_buffer_max: Max entries kept along the N (buffer) axis; a value
                <= 1 makes ``update_cache`` replace the whole cache each call.
            noise_aug_strength: Scale of the per-entry noise ramp in render_cache.
            generator: Optional torch.Generator for reproducible noise.
                NOTE(review): generator device must match the pixel device when
                noise is drawn -- confirm for CUDA pixels with a CPU generator.
            **kwargs: Forwarded to Cache3D_Base.__init__.
        """
        super().__init__(**kwargs)
        self.frame_buffer_max = frame_buffer_max
        self.noise_aug_strength = noise_aug_strength
        self.generator = generator

    def update_cache(self, new_image, new_depth, new_w2c, new_mask=None, new_intrinsics=None, depth_alignment=True, alignment_method="non_rigid"):
        """Insert a new frame into the buffer, optionally aligning its depth.

        Args:
            new_image: New frame pixels.
            new_depth: New frame depth (NaNs -> 1e4, clamped to [0, 1e4]).
            new_w2c: World-to-camera matrix of the new frame.
            new_mask: Optional validity mask for the new frame.
            new_intrinsics: Intrinsics; required when ``depth_alignment`` is True.
            depth_alignment: If True, render the current cache's depth at the new
                view and align ``new_depth`` to it before unprojection.
            alignment_method: "rigid" (scale/shift) or "non_rigid" (optimized,
                ARAP-regularized) -- anything else raises NotImplementedError.
        """
        new_image = new_image.to(self.weight_dtype).to(self.device)
        new_depth = new_depth.to(self.weight_dtype).to(self.device)
        new_w2c = new_w2c.to(self.weight_dtype).to(self.device)
        if new_intrinsics is not None:
            new_intrinsics = new_intrinsics.to(self.weight_dtype).to(self.device)

        # Sanitize incoming depth.
        new_depth = torch.nan_to_num(new_depth, nan=1e4)
        new_depth = torch.clamp(new_depth, min=0, max=1e4)

        if depth_alignment:
            # Render existing cache depth at the new camera to serve as the
            # alignment target; take buffer entry 0 (the most recent frame).
            target_depth, target_mask = self.render_cache(
                new_w2c.unsqueeze(1), new_intrinsics.unsqueeze(1), render_depth=True
            )
            target_depth, target_mask = target_depth[:, :, 0], target_mask[:, :, 0]
            if alignment_method == "rigid":
                new_depth = (
                    align_depth(
                        new_depth.squeeze(),
                        target_depth.squeeze(),
                        target_mask.bool().squeeze(),
                    )
                    .reshape_as(new_depth)
                    .detach()
                )
            elif alignment_method == "non_rigid":
                # Non-rigid alignment optimizes a deformation field, so gradients
                # must be enabled even though the result is detached afterwards.
                with torch.enable_grad():
                    new_depth = (
                        align_depth(
                            new_depth.squeeze(),
                            target_depth.squeeze(),
                            target_mask.bool().squeeze(),
                            k=new_intrinsics.squeeze(),
                            c2w=torch.inverse(new_w2c.squeeze()),
                            alignment_method="non_rigid",
                            num_iters=100,
                            lambda_arap=0.1,
                            smoothing_kernel_size=3,
                        )
                        .reshape_as(new_depth)
                        .detach()
                    )
            else:
                raise NotImplementedError
        new_points = unproject_points(new_depth, new_w2c, new_intrinsics, is_depth=self.is_depth).cpu()
        new_image = new_image.cpu()

        # Apply the same depth-reliability filtering as at init time.
        if self.filter_points_threshold < 1.0:
            B, F, N, V, C, H, W = self.input_image.shape
            new_depth = new_depth.reshape(-1, 1, H, W)
            depth_mask = reliable_depth_mask_range_batch(new_depth, ratio_thresh=self.filter_points_threshold).reshape(B, 1, H, W)
            if new_mask is None:
                new_mask = depth_mask.to("cpu")
            else:
                new_mask = new_mask * depth_mask.to(new_mask.device)
        if new_mask is not None:
            new_mask = new_mask.cpu()
        if self.frame_buffer_max > 1:
            if self.input_image.shape[2] < self.frame_buffer_max:
                # Buffer not full: prepend the new frame along the N axis.
                self.input_image = torch.cat([new_image[:, None, None, None], self.input_image], 2)
                self.input_points = torch.cat([new_points[:, None, None, None], self.input_points], 2)
                if self.input_mask is not None:
                    self.input_mask = torch.cat([new_mask[:, None, None, None], self.input_mask], 2)
            else:
                # Buffer full: overwrite the newest slot (index 0) in place.
                self.input_image[:, :, 0] = new_image[:, None, None]
                self.input_points[:, :, 0] = new_points[:, None, None]
                if self.input_mask is not None:
                    self.input_mask[:, :, 0] = new_mask[:, None, None]
        else:
            # Single-frame mode: replace the whole cache.
            # NOTE(review): new_mask is computed above but not stored here, so
            # self.input_mask keeps its init-time value -- confirm intended.
            self.input_image = new_image[:, None, None, None]
            self.input_points = new_points[:, None, None, None]

    def render_cache(
        self,
        target_w2cs,
        target_intrinsics,
        render_depth: bool = False,
        start_frame_idx: int = 0,
    ):
        """Render the buffer into target views, adding per-entry noise.

        Returns (pixels, masks) on the caller's original device.  When not
        rendering depth, noise scaled by ``noise_aug_strength`` is added with a
        linearly decreasing ramp from the oldest buffer entry to the newest
        (newest entry gets zero noise).
        """
        assert start_frame_idx == 0, "start_frame_idx must be 0 for Cache3D_Buffer"

        # Remember where the caller wants results; warp happens on self.device.
        output_device = target_w2cs.device
        target_w2cs = target_w2cs.to(self.weight_dtype).to(self.device)
        target_intrinsics = target_intrinsics.to(self.weight_dtype).to(self.device)
        pixels, masks = super().render_cache(
            target_w2cs, target_intrinsics, render_depth
        )
        pixels = pixels.to(output_device)
        masks = masks.to(output_device)
        if not render_depth:
            noise = torch.randn(pixels.shape, generator=self.generator, device=pixels.device, dtype=pixels.dtype)
            # Ramp N-1, N-2, ..., 0 over the N axis: older entries get more noise.
            per_buffer_noise = (
                torch.arange(start=pixels.shape[2] - 1, end=-1, step=-1, device=pixels.device)
                * self.noise_aug_strength
            )
            pixels = pixels + noise * per_buffer_noise.reshape(1, 1, -1, 1, 1, 1)
        return pixels, masks
|
|
|
|
|
|
|
|
class Cache3D_BufferSelector(Cache3D_Base):
    def __init__(self, frame_buffer_max=1, mask_for_max_buffer_model: bool = True, mask_full_threshold: float = 0.9, **kwargs):
        """A buffer that holds many initialization frames and selects top-K by overlap per target.

        This class does not support update_cache. It assumes multiple source frames are provided
        at initialization time via the 'N' (buffer) dimension.

        Args:
            frame_buffer_max: Max buffer entries kept per target (clamped to >= 1).
            mask_for_max_buffer_model: If True, when some entry's warp mask covers
                at least ``mask_full_threshold`` of the image, zero out all other
                entries for that (batch, frame) -- keeping only the first
                near-fully-covered one.
            mask_full_threshold: Mean-mask coverage in [0, 1] that counts as
                "near full".
            **kwargs: Forwarded to Cache3D_Base.__init__.
        """
        super().__init__(**kwargs)
        self.frame_buffer_max = max(int(frame_buffer_max), 1)
        self.mask_for_max_buffer_model = bool(mask_for_max_buffer_model)
        self.mask_full_threshold = float(mask_full_threshold)

    def update_cache(self, *args, **kwargs):
        # Selection-only cache: all source frames must be provided at init.
        raise NotImplementedError("Cache3D_BufferSelector does not support update_cache")

    def render_cache(
        self,
        target_w2cs,
        target_intrinsics,
        render_depth: bool = False,
        start_frame_idx: int = 0,
    ):
        """Render all buffer entries, then keep the top-K by warp-mask overlap.

        Returns (pixels_sel, masks_sel) on the caller's original device, with at
        most ``frame_buffer_max`` entries along the N axis.
        """
        output_device = target_w2cs.device
        target_w2cs = target_w2cs.to(self.weight_dtype).to(self.device)
        target_intrinsics = target_intrinsics.to(self.weight_dtype).to(self.device)

        # Warp every buffer entry into every target view.
        pixels_all, masks_all = super().render_cache(
            target_w2cs, target_intrinsics, render_depth, start_frame_idx
        )

        B, F, N = pixels_all.shape[0], pixels_all.shape[1], pixels_all.shape[2]
        if N <= self.frame_buffer_max:
            pixels_sel, masks_sel = pixels_all, masks_all
        else:
            # Score each entry by total mask coverage summed over frames and
            # pixels -> (B, N); higher means more overlap with the targets.
            overlap_scores = masks_all.sum(dim=(1, 3, 4, 5))

            k = min(self.frame_buffer_max, N)
            topk_indices = overlap_scores.topk(k=k, dim=1, largest=True, sorted=True).indices

            # Gather the selected entries per batch element (advanced indexing
            # on the N axis), then re-stack along batch.
            selected_pixels_list = []
            selected_masks_list = []
            for b in range(B):
                idx_b = topk_indices[b]
                selected_pixels_list.append(pixels_all[b : b + 1, :, idx_b])
                selected_masks_list.append(masks_all[b : b + 1, :, idx_b])

            pixels_sel = torch.cat(selected_pixels_list, dim=0)
            masks_sel = torch.cat(selected_masks_list, dim=0)
        if self.mask_for_max_buffer_model and not render_depth:
            # Per-(batch, frame, entry) mean coverage of the warp mask.
            _masks = masks_sel.mean(dim=[3, 4, 5])
            Bm, Fm, Nm = _masks.shape
            _masks_flat = rearrange(_masks, "b t n -> (b t) n")
            result_mask = torch.zeros_like(_masks_flat)
            near_full = _masks_flat >= self.mask_full_threshold
            has_near_full = near_full.any(dim=1)

            # argmax on the bool->float picks the FIRST near-full entry per row.
            indices = near_full.float().argmax(dim=1)
            valid_rows = torch.arange(near_full.size(0), device=_masks_flat.device)[has_near_full]
            valid_indices = indices[has_near_full]
            result_mask[valid_rows, valid_indices] = 1

            # Rows with no near-full entry keep all entries enabled.
            invalid_rows = torch.arange(near_full.size(0), device=_masks_flat.device)[~has_near_full]
            if invalid_rows.numel() > 0:
                result_mask[invalid_rows] = 1
            result_mask = rearrange(result_mask, "(b t) n -> b t n", b=Bm, t=Fm)

            # Broadcast over (C, H, W); zeroed entries are set to -1, which
            # presumably is "empty" for pixels in [-1, 1] -- TODO confirm range.
            result_mask_expanded = result_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            pixels_sel = (pixels_sel + 1) * result_mask_expanded - 1
            masks_sel = masks_sel * result_mask_expanded
        return pixels_sel.to(output_device), masks_sel.to(output_device)
|
|
|
|
|
|
|
|
class Cache4D(Cache3D_Base):
    """Read-only 4D (per-frame 3D) cache; rendering delegates to the base class."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def update_cache(self, **kwargs):
        # The full sequence is supplied at construction; incremental updates
        # are not supported for this cache type.
        raise NotImplementedError

    def render_cache(self, target_w2cs, target_intrinsics, render_depth=False, start_frame_idx=0):
        """Warp the cached frames into the target cameras via the base implementation."""
        warped_pixels, warped_masks = super().render_cache(
            target_w2cs, target_intrinsics, render_depth, start_frame_idx
        )
        return warped_pixels, warped_masks