import os
from typing import List, Union
from copy import deepcopy
import math

import PIL
import numpy as np
from einops import repeat

import torch
import torch.nn.functional as F
import torchvision.transforms as tvf

from diffusers.utils import export_to_gif

import sys
sys.path.append("./extern/CUT3R")
from extern.CUT3R.surfel_inference import run_inference_from_pil
from extern.CUT3R.add_ckpt_path import add_path_to_dust3r
from extern.CUT3R.src.dust3r.model import ARCroco3DStereo

from modeling import VMemWrapper, VMemModel, VMemModelParams
from modeling.modules.autoencoder import AutoEncoder
from modeling.sampling import DDPMDiscretization, DiscreteDenoiser, create_samplers
from modeling.modules.conditioner import CLIPConditioner
from utils import (encode_vae_image,
                   encode_image,
                   visualize_depth,
                   visualize_surfels,
                   tensor_to_pil,
                   Octree,
                   Surfel,
                   get_plucker_coordinates,
                   do_sample,
                   average_camera_pose)


ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
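# ImgNorm maps a PIL image to a float tensor in [-1, 1] (ToTensor gives [0, 1]; the 0.5 mean/std rescales).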
|
|
|
|
|
|
|
|
|
|
|
class VMemPipeline: |
|
|
def __init__(self, config, device="cpu", dtype=torch.float32): |
|
|
self.config = config |
|
|
|
|
|
model_path = self.config.model.get("model_path", None) |
|
|
|
|
|
self.model = VMemModel(VMemModelParams()).to(device, dtype) |
|
|
|
|
|
from huggingface_hub import hf_hub_download |
|
|
state_dict = torch.load(hf_hub_download(repo_id=model_path, filename="vmem_weights.pth"), map_location='cpu') |
|
|
state_dict = {k.replace("module.", "") if "module." in k else k: v for k, v in state_dict.items()} |
|
|
|
|
|
|
|
|
self.model.load_state_dict(state_dict, strict=True) |
|
|
|
|
|
|
|
|
self.model_wrapper = VMemWrapper(self.model) |
|
|
self.model_wrapper.eval() |
|
|
|
|
|
|
|
|
self.vae = AutoEncoder(chunk_size=1).to(device, dtype) |
|
|
self.vae.eval() |
|
|
self.image_encoder = CLIPConditioner().to(device, dtype) |
|
|
self.image_encoder.eval() |
|
|
|
|
|
self.discretization = DDPMDiscretization() |
|
|
self.denoiser = DiscreteDenoiser(discretization=self.discretization, num_idx=1000, device=device) |
|
|
self.sampler = create_samplers(guider_types=config.model.guider_types, |
|
|
discretization=self.discretization, |
|
|
num_frames=config.model.num_frames, |
|
|
num_steps=config.model.inference_num_steps, |
|
|
cfg_min=config.model.cfg_min, |
|
|
device=device) |
|
|
|
|
|
|
|
|
|
|
|
self.dtype = dtype |
|
|
self.device = device |
|
|
|
|
|
|
|
|
surfel_model_path = hf_hub_download(repo_id=self.config.surfel.model_path, filename="cut3r_512_dpt_4_64.pth") |
|
|
print(f"Loading model from {surfel_model_path}...") |
|
|
add_path_to_dust3r(surfel_model_path) |
|
|
self.surfel_model = ARCroco3DStereo.from_pretrained(surfel_model_path).to(device) |
|
|
self.surfel_model.eval() |
|
|
|
|
|
|
|
|
|
|
|
from extern.CUT3R.cloud_opt.dust3r_opt import global_aligner, GlobalAlignerMode |
|
|
self.GlobalAlignerMode = GlobalAlignerMode |
|
|
self.global_aligner = global_aligner |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.use_non_maximum_suppression = self.config.model.use_non_maximum_suppression |
|
|
|
|
|
self.context_num_frames = self.config.model.context_num_frames |
|
|
self.target_num_frames = self.config.model.target_num_frames |
|
|
|
|
|
self.original_height = self.config.model.original_height |
|
|
self.original_width = self.config.model.original_width |
|
|
self.height = self.config.model.height |
|
|
self.width = self.config.model.width |
|
|
|
|
|
self.w_ratio = self.width / self.original_width |
|
|
self.h_ratio = self.height / self.original_height |
|
|
|
|
|
self.camera_scale = self.config.model.camera_scale |
|
|
|
|
|
self.latents = [] |
|
|
self.encoder_embeddings = [] |
|
|
self.poses = [] |
|
|
self.Ks = [] |
|
|
self.surfel_Ks = [] |
|
|
self.surfels = [] |
|
|
|
|
|
self.surfel_depths = [] |
|
|
self.surfel_to_timestep = {} |
|
|
self.pil_frames = [] |
|
|
self.visualize_dir = self.config.model.samples_dir |
|
|
if not os.path.exists(self.visualize_dir): |
|
|
os.makedirs(self.visualize_dir) |
|
|
|
|
|
self.global_step = 0 |
|
|
|
|
|
|
|
|
    def reset(self):
        # Clear all per-scene state so the pipeline can be re-initialized with a new scene.
        self.latents = []
        self.encoder_embeddings = []
        self.c2ws = []
        self.poses = []
        self.Ks = []
        self.surfel_Ks = []
        self.surfels = []
        self.surfel_depths = []
        self.surfel_to_timestep = {}
        self.pil_frames = []
        self.global_step = 0
|
|
|
|
|
|
|
|
def initialize(self, image, c2w, K): |
|
|
""" |
|
|
Initialize the pipeline with a single image and camera parameters. |
|
|
This method sets up internal state without generating additional frames. |
|
|
|
|
|
Args: |
|
|
image: Tensor of input image [1, C, H, W] |
|
|
c2w: Camera-to-world matrix (4x4) |
|
|
K: Camera intrinsic matrix |
|
|
|
|
|
Returns: |
|
|
PIL image of the initial frame |
|
|
""" |
|
|
|
|
|
self.reset() |
|
|
|
|
|
|
|
|
if isinstance(image, torch.Tensor): |
|
|
image_tensor = image |
|
|
else: |
|
|
|
|
|
image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 127.5 - 1.0 |
|
|
image_tensor = image_tensor.unsqueeze(0).to(self.device, self.dtype) |
|
|
|
|
|
|
|
|
self.latents = [encode_vae_image(image_tensor, self.vae, self.device, self.dtype).detach().cpu().numpy()[0]] |
|
|
|
|
|
|
|
|
self.encoder_embeddings = [encode_image(image_tensor, self.image_encoder, self.device, self.dtype).detach().cpu().numpy()[0]] |
|
|
|
|
|
|
|
|
self.c2ws = [c2w] |
|
|
self.Ks = [K] |
|
|
|
|
|
|
|
|
pil_frame = tensor_to_pil(image_tensor) |
|
|
self.pil_frames = [pil_frame] |
|
|
|
|
|
|
|
|
|
|
|
return pil_frame |
|
|
|
|
|
def geodesic_distance(self, |
|
|
camera_pose1, |
|
|
camera_pose2, |
|
|
weight_translation=1,): |
|
|
""" |
|
|
Computes the geodesic distance between two camera poses in SE(3). |
|
|
|
|
|
Parameters: |
|
|
extrinsic1 (torch.Tensor): 4x4 extrinsic matrix of the first pose. |
|
|
extrinsic2 (torch.Tensor): 4x4 extrinsic matrix of the second pose. |
|
|
|
|
|
Returns: |
|
|
float: Geodesic distance between the two poses. |
|
|
""" |
|
|
|
|
|
R1 = camera_pose1[:3, :3] |
|
|
t1 = camera_pose1[:3, 3] |
|
|
R2 = camera_pose2[:3, :3] |
|
|
t2 = camera_pose2[:3, 3] |
|
|
|
|
|
|
|
|
translation_distance = torch.norm(t1 - t2) |
|
|
|
|
|
|
|
|
R_relative = torch.matmul(R1.T, R2) |
|
|
|
|
|
|
|
|
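        # Rotation geodesic: theta = arccos((trace(R1^T R2) - 1) / 2); the clamp below keeps the
        # argument of acos inside its valid range despite numerical noise.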
trace_value = torch.trace(R_relative) |
|
|
|
|
|
trace_value = torch.clamp(trace_value, -1.0, 3.0) |
|
|
angular_distance = torch.acos((trace_value - 1) / 2) |
|
|
|
|
|
|
|
|
geodesic_dist = translation_distance*weight_translation + angular_distance |
|
|
|
|
|
return geodesic_dist |
|
|
|
|
|
def render_surfels_to_image( |
|
|
self, |
|
|
surfels, |
|
|
poses, |
|
|
focal_lengths, |
|
|
principal_points, |
|
|
image_width, |
|
|
image_height, |
|
|
disk_resolution=16 |
|
|
): |
|
|
""" |
|
|
Renders oriented surfels into a 2D RGB image with a simple z-buffer. |
|
|
Each surfel is treated as a 2D disk in 3D, oriented by its normal. |
|
|
The disk is approximated by a polygon of 'disk_resolution' segments. |
|
|
|
|
|
Args: |
|
|
surfels (list): List of Surfel objects, each having: |
|
|
- position: (x, y, z) in world coords |
|
|
- normal: (nx, ny, nz) |
|
|
- radius: float, radius in world units |
|
|
poses (torch.Tensor): Tensor of poses, shape [4, 4] |
|
|
focal_lengths (torch.Tensor): Tensor of focal lengths, shape [2] |
|
|
principal_points (torch.Tensor): Tensor of principal points, shape [2] |
|
|
image_width, image_height (int): output image size |
|
|
disk_resolution (int): number of segments for approximating each disk |
|
|
|
|
|
Returns: |
|
|
Dictionary containing: |
|
|
- depth: depth map |
|
|
- surfel_index_map: map of surfel indices |
|
|
- cos_value_map: map of cosine values between view and normal directions |
|
|
""" |
|
|
if isinstance(focal_lengths, torch.Tensor): |
|
|
focal_lengths = focal_lengths.detach().cpu().numpy() |
|
|
if isinstance(principal_points, torch.Tensor): |
|
|
principal_points = principal_points.detach().cpu().numpy() |
|
|
if isinstance(poses, torch.Tensor): |
|
|
poses = poses.detach().cpu().numpy() |
|
|
|
|
|
|
|
|
surfel_index_map = np.full((image_height, image_width), -1, dtype=np.int32) |
|
|
z_buffer = np.full((image_height, image_width), np.inf, dtype=np.float32) |
|
|
cos_buffer = np.zeros((image_height, image_width), dtype=np.float32) |
|
|
|
|
|
|
|
|
fx, fy, cx, cy = focal_lengths[0], focal_lengths[1], principal_points[0], principal_points[1] |
|
|
R = poses[0:3, 0:3] |
|
|
t = poses[0:3, 3] |
|
|
|
|
|
|
|
|
|
|
|
near_z = 0.1 |
|
|
far_z = 1000.0 |
|
|
|
|
|
|
|
|
positions = np.array([s.position for s in surfels]) |
|
|
positions_h = np.concatenate([positions, np.ones((len(positions), 1))], axis=1) |
|
|
|
|
|
|
|
|
extrinsics = np.zeros((4, 4)) |
|
|
extrinsics[0:3, 0:3] = np.linalg.inv(R) |
|
|
extrinsics[0:3, 3] = -np.linalg.inv(R) @ t |
|
|
extrinsics[3, 3] = 1 |
|
|
|
|
|
|
|
|
cam_points = (extrinsics @ positions_h.T).T |
|
|
cam_points = cam_points[:, :3] / cam_points[:, 3:] |
|
|
|
|
|
|
|
|
in_front = cam_points[:, 2] > near_z |
|
|
behind_far = cam_points[:, 2] < far_z |
|
|
|
|
|
|
|
|
screen_x = fx * (cam_points[:, 0] / cam_points[:, 2]) + cx |
|
|
screen_y = fy * (cam_points[:, 1] / cam_points[:, 2]) + cy |
|
|
|
|
|
|
|
|
margin = 50 |
|
|
in_screen_x = (screen_x >= -margin) & (screen_x < image_width + margin) |
|
|
in_screen_y = (screen_y >= -margin) & (screen_y < image_height + margin) |
|
|
|
|
|
|
|
|
visible_mask = in_front & behind_far & in_screen_x & in_screen_y |
|
|
visible_indices = np.where(visible_mask)[0] |
|
|
|
|
|
def point_in_polygon_2d(px, py, polygon): |
|
|
"""Fast point-in-polygon test using ray casting""" |
|
|
inside = False |
|
|
n = len(polygon) |
|
|
j = n - 1 |
|
|
for i in range(n): |
|
|
if (((polygon[i][1] > py) != (polygon[j][1] > py)) and |
|
|
(px < (polygon[j][0] - polygon[i][0]) * (py - polygon[i][1]) / |
|
|
(polygon[j][1] - polygon[i][1] + 1e-15) + polygon[i][0])): |
|
|
inside = not inside |
|
|
j = i |
|
|
return inside |
|
|
|
|
|
|
|
|
angles = np.linspace(0, 2*math.pi, disk_resolution, endpoint=False) |
|
|
cos_angles = np.cos(angles) |
|
|
sin_angles = np.sin(angles) |
|
|
|
|
|
|
|
|
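        # Rasterize each visible surfel: approximate its disk with a small polygon, project the
        # polygon into the image, and fill its bounding box using a point-in-polygon test against
        # a per-pixel z-buffer (closest surfel wins).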
for idx in visible_indices: |
|
|
surfel = surfels[idx] |
|
|
px, py, pz = surfel.position |
|
|
nx, ny, nz = surfel.normal |
|
|
radius = surfel.radius |
|
|
|
|
|
|
|
|
normal = np.array([nx, ny, nz], dtype=float) |
|
|
norm_len = np.linalg.norm(normal) |
|
|
if norm_len < 1e-12: |
|
|
continue |
|
|
normal /= norm_len |
|
|
|
|
|
|
|
|
            # Unit view ray from the camera center to the surfel, used for back-face culling.
            point_direction = np.array([px, py, pz]) - t
            point_direction = point_direction / np.linalg.norm(point_direction)
            cos_value = np.dot(point_direction, normal)
|
|
|
|
|
|
|
|
if cos_value < 0: |
|
|
continue |
|
|
|
|
|
|
|
|
up = np.array([0, 0, 1], dtype=float) |
|
|
if abs(np.dot(normal, up)) > 0.9: |
|
|
up = np.array([0, 1, 0], dtype=float) |
|
|
xAxis = np.cross(normal, up) |
|
|
xAxis /= np.linalg.norm(xAxis) |
|
|
yAxis = np.cross(normal, xAxis) |
|
|
yAxis /= np.linalg.norm(yAxis) |
|
|
|
|
|
|
|
|
offsets = radius * (cos_angles[:, None] * xAxis + sin_angles[:, None] * yAxis) |
|
|
circle_points = positions[idx] + offsets |
|
|
|
|
|
|
|
|
circle_points_h = np.concatenate([circle_points, np.ones((len(circle_points), 1))], axis=1) |
|
|
cam_circle = (extrinsics @ circle_points_h.T).T |
|
|
depths = cam_circle[:, 2] |
|
|
valid_mask = depths > 0 |
|
|
if not np.any(valid_mask): |
|
|
continue |
|
|
|
|
|
screen_points = np.zeros((len(circle_points), 2)) |
|
|
screen_points[:, 0] = fx * (cam_circle[:, 0] / depths) + cx |
|
|
screen_points[:, 1] = fy * (cam_circle[:, 1] / depths) + cy |
|
|
|
|
|
|
|
|
valid_points = screen_points[valid_mask] |
|
|
if len(valid_points) < 3: |
|
|
continue |
|
|
|
|
|
min_x = max(0, int(np.floor(np.min(valid_points[:, 0])))) |
|
|
max_x = min(image_width - 1, int(np.ceil(np.max(valid_points[:, 0])))) |
|
|
min_y = max(0, int(np.floor(np.min(valid_points[:, 1])))) |
|
|
max_y = min(image_height - 1, int(np.ceil(np.max(valid_points[:, 1])))) |
|
|
|
|
|
|
|
|
avg_depth = float(np.mean(depths[valid_mask])) |
|
|
|
|
|
|
|
|
for py_ in range(min_y, max_y + 1): |
|
|
for px_ in range(min_x, max_x + 1): |
|
|
if point_in_polygon_2d(px_, py_, valid_points): |
|
|
if avg_depth < z_buffer[py_, px_]: |
|
|
z_buffer[py_, px_] = avg_depth |
|
|
surfel_index_map[py_, px_] = idx |
|
|
cos_buffer[py_, px_] = cos_value |
|
|
|
|
|
|
|
|
depth = z_buffer |
|
|
depth[depth == np.inf] = 0 |
|
|
|
|
|
return { |
|
|
"depth": depth, |
|
|
"surfel_index_map": surfel_index_map, |
|
|
"cos_value_map": cos_buffer |
|
|
} |
|
|
|
|
|
def get_frame_distribution(self, |
|
|
n, |
|
|
ratios): |
|
|
""" |
|
|
Given: |
|
|
- an integer n, |
|
|
- a list of k ratios whose sum is 1 (k <= n), |
|
|
return a list of k integers [x1, x2, ..., xk], |
|
|
such that each xi >= 1, sum(xi) = n, and |
|
|
the xi are as proportional to ratios as possible. |
|
|
""" |
|
|
k = len(ratios) |
|
|
if k > n: |
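            # More timesteps than frames requested: give one frame each to the n largest ratios.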
|
|
|
|
|
result = [0] * k |
|
|
sort_indices = np.argsort(ratios)[::-1] |
|
|
for sort_index in sort_indices[:n]: |
|
|
result[sort_index] = 1 |
|
|
return result |
|
|
|
|
|
|
|
|
result = [1] * k |
|
|
|
|
|
|
|
|
leftover = n - k |
|
|
if leftover == 0: |
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
products = [r * leftover for r in ratios] |
|
|
floored = [int(p // 1) for p in products] |
|
|
|
|
|
sum_floors = sum(floored) |
|
|
leftover2 = leftover - sum_floors |
|
|
|
|
|
|
|
|
for i in range(k): |
|
|
result[i] += floored[i] |
|
|
|
|
|
|
|
|
remainders = [(p - f, i) for i, (p, f) in enumerate(zip(products, floored))] |
|
|
remainders.sort(key=lambda x: x[0], reverse=True) |
|
|
|
|
|
|
|
|
        # Hand out the remaining units to the entries with the largest fractional remainders.
        for j in range(leftover2):
            _, idx = remainders[j]
            result[idx] += 1
|
|
|
|
|
return result |
|
|
|
|
|
def process_retrieved_spatial_information(self, retrieved_spatial_information): |
|
|
|
|
|
timestep_count = {} |
|
|
|
|
|
surfel_index_map = retrieved_spatial_information["surfel_index_map"] |
|
|
cos_value_map = retrieved_spatial_information["cos_value_map"] |
|
|
depth_map = retrieved_spatial_information["depth"] |
|
|
filtered_cos_value = cos_value_map[surfel_index_map >= 0] |
|
|
filtered_surfel_index = surfel_index_map[surfel_index_map >= 0] |
|
|
filtered_depth = depth_map[surfel_index_map >= 0] |
|
|
assert len(filtered_cos_value) == len(filtered_surfel_index), "filtered_cos_value and filtered_surfel_index should have the same length" |
|
|
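        # Accumulate a relevance score per past timestep: every rendered surfel votes for the
        # timesteps it was observed in, weighted by how frontally it is seen (cos) and discounted
        # by its rendered depth.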
for j in range(len(filtered_surfel_index)): |
|
|
cos_value = filtered_cos_value[j] |
|
|
depth_value = filtered_depth[j] |
|
|
if cos_value < 0: |
|
|
continue |
|
|
surfel_index = filtered_surfel_index[j] |
|
|
timesteps = self.surfel_to_timestep[surfel_index] |
|
|
|
|
|
for timestep in timesteps: |
|
|
|
|
|
                if timestep not in timestep_count:
                    timestep_count[timestep] = 0.0
                timestep_count[timestep] += cos_value / (1 + depth_value)
|
|
|
|
|
|
|
|
|
|
|
timestep_count_values = np.array(list(timestep_count.values())) |
|
|
timestep_count_ratios = timestep_count_values / np.sum(timestep_count_values) |
|
|
timestep_weights = {k: timestep_count_ratios[i] for i, k in enumerate(timestep_count)} |
|
|
num_retrieved_frames = min(self.config.model.context_num_frames+10, len(timestep_weights)) |
|
|
frame_count = self.get_frame_distribution(num_retrieved_frames, list(timestep_weights.values())) |
|
|
frame_count = {k: int(v) for k, v in zip(timestep_count.keys(), frame_count)} |
|
|
|
|
|
|
|
|
timestep_weights = sorted(timestep_weights.items(), key=lambda x: x[0]) |
|
|
frame_count = sorted(frame_count.items(), key=lambda x: x[0]) |
|
|
|
|
|
|
|
|
|
|
|
return timestep_weights, frame_count |
|
|
|
|
|
|
|
|
def get_context_info(self, target_c2ws, use_non_maximum_suppression=None): |
|
|
"""Get context information for novel view synthesis. |
|
|
|
|
|
Args: |
|
|
target_c2ws: Target camera-to-world matrices |
|
|
Ks: Camera intrinsic matrices |
|
|
current_timestep: Current timestep (used in temporal mode) |
|
|
|
|
|
Returns: |
|
|
Dictionary containing context information for the target view |
|
|
""" |
|
|
|
|
|
def prepare_context_data(indices): |
|
|
c2ws = [self.c2ws[i] for i in indices] |
|
|
latents = [torch.from_numpy(self.latents[i]).to(self.device, self.dtype) for i in indices] |
|
|
embeddings = [torch.from_numpy(self.encoder_embeddings[i]).to(self.device, self.dtype) for i in indices] |
|
|
intrinsics = [self.Ks[i] for i in indices] |
|
|
return c2ws, latents, embeddings, intrinsics, indices |
|
|
|
|
if len(self.pil_frames) == 1: |
|
|
context_time_indices = [0] |
|
|
else: |
|
|
|
|
|
average_c2w = average_camera_pose(target_c2ws[-self.config.model.context_num_frames//4:]) |
|
|
transformed_average_c2w = self.get_transformed_c2ws(average_c2w) |
|
|
target_K = np.mean(self.surfel_Ks, axis=0) |
|
|
|
|
|
retrieved_info = self.render_surfels_to_image( |
|
|
self.surfels, |
|
|
transformed_average_c2w, |
|
|
[target_K*0.65] * 2, |
|
|
principal_points=(int(self.config.surfel.width/2), int(self.config.surfel.height/2)), |
|
|
image_width=int(self.config.surfel.width), |
|
|
image_height=int(self.config.surfel.height) |
|
|
) |
|
|
_, frame_count = self.process_retrieved_spatial_information(retrieved_info) |
|
|
if self.config.inference.visualize: |
|
|
visualize_depth(retrieved_info["depth"], |
|
|
visualization_dir=self.visualize_dir, |
|
|
file_name=f"retrieved_depth_surfels.png", |
|
|
size=(self.width, self.height)) |
|
|
|
|
|
|
|
|
|
|
|
candidates = [] |
|
|
for frame, count in frame_count: |
|
|
candidates.extend([frame] * count) |
|
|
indices_to_frame = { |
|
|
i: frame for i, frame in enumerate(candidates) |
|
|
} |
|
|
|
|
|
|
|
|
distances = [self.geodesic_distance(torch.from_numpy(average_c2w).to(self.device, self.dtype), |
|
|
torch.from_numpy(self.c2ws[frame]).to(self.device, self.dtype), |
|
|
weight_translation=self.config.model.translation_distance_weight).item() |
|
|
for frame in candidates] |
|
|
|
|
|
sorted_indices = torch.argsort(torch.tensor(distances)) |
|
|
sorted_frames = [indices_to_frame[int(i.item())] for i in sorted_indices] |
|
|
max_frames = min(self.config.model.context_num_frames, len(candidates), len(self.latents)) |
|
|
|
|
|
|
|
|
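            # Heuristic flag: a buffer of 5 frames presumably marks the state right after the first
            # generation step, which is when the NMS distance threshold gets calibrated below.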
is_second_step = len(self.pil_frames) == 5 |
|
|
|
|
|
|
|
|
|
|
|
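            # Greedy non-maximum suppression over camera poses: walk the distance-sorted candidates
            # and keep a frame only if it is farther than the current threshold from every frame
            # already selected, relaxing the threshold until enough context frames are found.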
if use_non_maximum_suppression is None: |
|
|
use_non_maximum_suppression = self.use_non_maximum_suppression |
|
|
|
|
|
if use_non_maximum_suppression: |
|
|
if is_second_step: |
|
|
|
|
|
pairwise_distances = [] |
|
|
for i in range(len(self.c2ws)): |
|
|
for j in range(i+1, len(self.c2ws)): |
|
|
sim = self.geodesic_distance( |
|
|
torch.from_numpy(np.array(self.c2ws[i])).to(self.device, self.dtype), |
|
|
torch.from_numpy(np.array(self.c2ws[j])).to(self.device, self.dtype), |
|
|
weight_translation=self.config.model.translation_distance_weight |
|
|
) |
|
|
pairwise_distances.append(sim.item()) |
|
|
|
|
|
if pairwise_distances: |
|
|
|
|
|
pairwise_distances.sort() |
|
|
percentile_idx = int(len(pairwise_distances) * 0.5) |
|
|
self.initial_threshold = pairwise_distances[percentile_idx] |
|
|
else: |
|
|
self.initial_threshold = 1 |
|
|
|
|
|
|
|
|
|
|
|
else: |
|
|
self.initial_threshold = 1e8 |
|
|
|
|
|
selected_indices = [] |
|
|
current_threshold = self.initial_threshold |
|
|
|
|
|
|
|
|
            selected_indices.append(sorted_frames[0])
            if not use_non_maximum_suppression:
                # Keep the most recent frame as extra context when NMS is disabled,
                # skipping it if it is already the closest candidate.
                if (len(self.c2ws) - 1) not in selected_indices:
                    selected_indices.append(len(self.c2ws) - 1)
|
|
|
|
|
|
|
|
while len(selected_indices) < max_frames and current_threshold >= 1e-5 and use_non_maximum_suppression: |
|
|
|
|
|
for idx in sorted_frames[1:]: |
|
|
if len(selected_indices) >= max_frames: |
|
|
break |
|
|
|
|
|
|
|
|
is_too_similar = False |
|
|
for selected_idx in selected_indices: |
|
|
similarity = self.geodesic_distance( |
|
|
torch.from_numpy(np.array(self.c2ws[idx])).to(self.device, self.dtype), |
|
|
torch.from_numpy(np.array(self.c2ws[selected_idx])).to(self.device, self.dtype), |
|
|
weight_translation=self.config.model.translation_distance_weight |
|
|
) |
|
|
if similarity < current_threshold: |
|
|
is_too_similar = True |
|
|
break |
|
|
|
|
|
|
|
|
if not is_too_similar: |
|
|
selected_indices.append(idx) |
|
|
|
|
|
|
|
|
if len(selected_indices) < max_frames: |
|
|
current_threshold /= 1.2 |
|
|
else: |
|
|
break |
|
|
|
|
|
|
|
|
if len(selected_indices) < max_frames: |
|
|
available_indices = [] |
|
|
for idx in sorted_frames: |
|
|
if idx not in selected_indices: |
|
|
available_indices.append(idx) |
|
|
selected_indices.extend(available_indices[:max_frames-len(selected_indices)]) |
|
|
|
|
|
|
|
|
context_time_indices = torch.from_numpy(np.array(selected_indices)) |
|
|
context_data = prepare_context_data(context_time_indices) |
|
|
|
|
|
(context_c2ws, context_latents, context_encoder_embeddings, context_Ks, context_time_indices) = context_data |
|
|
|
|
|
|
|
|
return { |
|
|
"context_c2ws": torch.from_numpy(np.array(context_c2ws)).to(self.device, self.dtype), |
|
|
"context_latents": torch.stack(context_latents).to(self.device, self.dtype), |
|
|
"context_encoder_embeddings": torch.stack(context_encoder_embeddings).to(self.device, self.dtype), |
|
|
"context_Ks": torch.from_numpy(np.array(context_Ks)).to(self.device, self.dtype), |
|
|
"context_time_indices": context_time_indices, |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def merge_surfels( |
|
|
self, |
|
|
new_surfels: list, |
|
|
current_timestep: str, |
|
|
existing_surfels: list, |
|
|
existing_surfel_to_timestep: dict, |
|
|
position_threshold: Union[float, None] = None, |
|
|
normal_threshold: float = 0.7, |
|
|
max_points_per_node: int = 10 |
|
|
): |
|
|
|
|
|
assert len(existing_surfels) == len(existing_surfel_to_timestep), ( |
|
|
"existing_surfels and existing_surfel_to_timestep should have the same length" |
|
|
) |
|
|
|
|
|
|
|
|
if position_threshold is None: |
|
|
|
|
|
all_radii = np.array([s.radius for s in existing_surfels + new_surfels]) |
|
|
if len(all_radii) > 0: |
|
|
|
|
|
mean_radius = np.mean(all_radii) |
|
|
std_radius = np.std(all_radii) |
|
|
|
|
|
position_threshold = mean_radius + 0.5 * std_radius |
|
|
else: |
|
|
|
|
|
position_threshold = 0.025 |
|
|
|
|
|
positions = np.array([s.position for s in existing_surfels]) |
|
|
normals = np.array([s.normal for s in existing_surfels]) |
|
|
|
|
|
if len(positions) > 0: |
|
|
octree = Octree(positions, max_points=max_points_per_node) |
|
|
else: |
|
|
octree = None |
|
|
|
|
|
|
|
|
filtered_surfels = [] |
|
|
|
|
|
merge_count = 0 |
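        # For each new surfel, look for an existing neighbor within position_threshold whose normal
        # agrees (dot product above normal_threshold); if one exists, just record the new timestep
        # on that surfel instead of adding a duplicate.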
|
|
for new_surfel in new_surfels: |
|
|
is_merged = False |
|
|
if octree is not None: |
|
|
neighbor_indices = octree.query_ball_point(new_surfel.position, position_threshold) |
|
|
else: |
|
|
neighbor_indices = [] |
|
|
|
|
|
for idx in neighbor_indices: |
|
|
if np.dot(normals[idx], new_surfel.normal) > normal_threshold: |
|
|
if current_timestep not in existing_surfel_to_timestep[idx]: |
|
|
existing_surfel_to_timestep[idx].append(current_timestep) |
|
|
is_merged = True |
|
|
merge_count += 1 |
|
|
break |
|
|
|
|
|
if not is_merged: |
|
|
filtered_surfels.append(new_surfel) |
|
|
|
|
|
print(f"merge_count: {merge_count}") |
|
|
return filtered_surfels, existing_surfel_to_timestep |
|
|
|
|
|
def pointmap_to_surfels(self, |
|
|
pointmap: torch.Tensor, |
|
|
focal_lengths: torch.Tensor, |
|
|
depths: torch.Tensor, |
|
|
confs: torch.Tensor, |
|
|
poses: torch.Tensor, |
|
|
radius_scale: float = 0.5, |
|
|
estimate_normals: bool = True): |
|
|
""" |
|
|
Vectorized version of pointmap to surfels conversion. |
|
|
All operations are performed on the specified device (self.device) until final numpy conversion. |
|
|
""" |
|
|
if isinstance(poses, np.ndarray): |
|
|
poses = torch.from_numpy(poses).to(self.device) |
|
|
if isinstance(focal_lengths, np.ndarray): |
|
|
focal_lengths = torch.from_numpy(focal_lengths).to(self.device) |
|
|
if isinstance(depths, np.ndarray): |
|
|
depths = torch.from_numpy(depths).to(self.device) |
|
|
if isinstance(confs, np.ndarray): |
|
|
confs = torch.from_numpy(confs).to(self.device) |
|
|
|
|
|
|
|
|
pointmap = pointmap.to(self.device) |
|
|
focal_lengths = focal_lengths.to(self.device) |
|
|
depths = depths.to(self.device) |
|
|
confs = confs.to(self.device) |
|
|
poses = poses.to(self.device) |
|
|
|
|
|
if len(focal_lengths) == 2: |
|
|
focal_lengths = torch.mean(focal_lengths, dim=0) |
|
|
|
|
|
|
|
|
if estimate_normals: |
|
|
normal_map = self.estimate_normal_from_pointmap(pointmap) |
|
|
else: |
|
|
normal_map = torch.zeros_like(pointmap) |
|
|
|
|
|
|
|
|
|
|
|
depth_threshold = torch.quantile(depths, 0.999) |
|
|
valid_mask = (depths <= depth_threshold) & (confs >= self.config.surfel.conf_thresh) |
|
|
|
|
|
|
|
|
positions = pointmap[valid_mask] |
|
|
normals = normal_map[valid_mask] |
|
|
valid_depths = depths[valid_mask] |
|
|
|
|
|
|
|
|
camera_pos = poses[0:3, 3] |
|
|
view_directions = positions - camera_pos.unsqueeze(0) |
|
|
view_directions = F.normalize(view_directions, dim=1) |
|
|
|
|
|
|
|
|
dot_products = torch.sum(view_directions * normals, dim=1) |
|
|
|
|
|
|
|
|
flip_mask = dot_products < 0 |
|
|
normals[flip_mask] = -normals[flip_mask] |
|
|
|
|
|
|
|
|
dot_products = torch.abs(torch.sum(view_directions * normals, dim=1)) |
|
|
|
|
|
|
|
|
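        # Pixel footprint in world units is roughly depth / focal; the (0.2 + 0.8 * cos) divisor
        # enlarges radii for obliquely viewed surfels, with the 0.2 floor capping the growth at 5x.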
adjustment_values = 0.2 + 0.8 * dot_products |
|
|
radii = (radius_scale * valid_depths / focal_lengths / adjustment_values) |
|
|
|
|
|
|
|
|
positions = positions.detach().cpu().numpy() |
|
|
normals = normals.detach().cpu().numpy() |
|
|
radii = radii.detach().cpu().numpy() |
|
|
|
|
|
|
|
|
surfels = [Surfel(pos, norm, rad) for pos, norm, rad in zip(positions, normals, radii)] |
|
|
|
|
|
|
|
|
|
|
|
return surfels |
|
|
|
|
|
def estimate_normal_from_pointmap(self,pointmap: torch.Tensor) -> torch.Tensor: |
|
|
h, w = pointmap.shape[:2] |
|
|
device = pointmap.device |
|
|
dtype = pointmap.dtype |
|
|
|
|
|
|
|
|
normal_map = torch.zeros((h, w, 3), device=device, dtype=dtype) |
|
|
|
|
|
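        # Finite-difference normals: cross product of the offsets to the right and down neighbors.
        # Note this is a pure Python double loop (slow path); it could be vectorized with tensor
        # shifts and torch.cross if it ever becomes a bottleneck.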
for y in range(h): |
|
|
for x in range(w): |
|
|
|
|
|
if x+1 >= w or y+1 >= h: |
|
|
continue |
|
|
|
|
|
p_center = pointmap[y, x] |
|
|
p_right = pointmap[y, x+1] |
|
|
p_down = pointmap[y+1, x] |
|
|
|
|
|
|
|
|
v1 = p_right - p_center |
|
|
v2 = p_down - p_center |
|
|
|
|
|
v1 = v1 / torch.linalg.norm(v1) |
|
|
v2 = v2 / torch.linalg.norm(v2) |
|
|
|
|
|
|
|
|
n_c = torch.cross(v1, v2) |
|
|
|
|
|
|
|
|
|
|
|
norm_len = torch.linalg.norm(n_c) |
|
|
|
|
|
if norm_len < 1e-8: |
|
|
continue |
|
|
|
|
|
|
|
|
normal_map[y, x] = n_c / norm_len |
|
|
|
|
|
return normal_map |
|
|
|
|
|
def get_transformed_c2ws(self, c2ws=None): |
|
|
if c2ws is None: |
|
|
c2ws = self.c2ws |
|
|
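        # Negate the y/z columns of the stored c2ws, presumably to match the camera convention
        # expected by the CUT3R surfel model.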
c2ws_transformed = deepcopy(np.array(c2ws)) |
|
|
c2ws_transformed[..., :, [1, 2]] *= -1 |
|
|
return c2ws_transformed |
|
|
|
|
|
def construct_and_store_scene(self, |
|
|
input_images: List[PIL.Image.Image], |
|
|
time_indices, |
|
|
niter = 1000, |
|
|
lr = 0.01, |
|
|
device = 'cuda', |
|
|
): |
|
|
""" |
|
|
Constructs a scene from input images and stores the resulting surfels. |
|
|
|
|
|
Args: |
|
|
input_images: List of PIL images to process |
|
|
time_indices: The time indices for each image |
|
|
niter: Number of iterations for optimization |
|
|
lr: Learning rate for optimization |
|
|
device: Device to run inference on |
|
|
only_last_frame: Whether to only process the last frame |
|
|
""" |
|
|
|
|
|
c2ws_transformed = self.get_transformed_c2ws() |
|
|
|
|
|
|
|
|
scene = run_inference_from_pil( |
|
|
input_images, |
|
|
self.surfel_model, |
|
|
poses=c2ws_transformed, |
|
|
depths=torch.from_numpy(np.array(self.surfel_depths)) if len(self.surfel_depths) > 0 else None, |
|
|
lr = lr, |
|
|
niter = niter, |
|
|
visualize=self.config.inference.visualize_pointcloud, |
|
|
device=device, |
|
|
) |
|
|
|
|
|
|
|
|
pointcloud = torch.cat(scene['point_clouds'], dim=0) |
|
|
confs = torch.cat(scene['confidences'], dim=0) |
|
|
depths = torch.cat(scene['depths'], dim=0) |
|
|
focal_lengths = scene['camera_info']['focal'] |
|
|
        # Focal lengths are re-estimated for every frame on each call, so overwrite rather than extend.
        self.surfel_Ks = [focal_lengths[i] for i in range(len(focal_lengths))]
|
|
self.surfel_depths = [depths[i].detach().cpu().numpy() for i in range(len(depths))] |
|
|
|
|
|
pointcloud = pointcloud.permute(0, 3, 1, 2) |
|
|
pointcloud = F.interpolate( |
|
|
pointcloud, |
|
|
scale_factor=self.config.surfel.shrink_factor, |
|
|
mode='bilinear' |
|
|
) |
|
|
pointcloud = pointcloud.permute(0, 2, 3, 1) |
|
|
|
|
|
|
|
|
|
|
|
depths = depths.unsqueeze(1) |
|
|
depths = F.interpolate( |
|
|
depths, |
|
|
scale_factor=self.config.surfel.shrink_factor, |
|
|
mode='bilinear' |
|
|
) |
|
|
depths = depths.squeeze(1) |
|
|
|
|
|
confs = confs.unsqueeze(1) |
|
|
confs = F.interpolate( |
|
|
confs, |
|
|
scale_factor=self.config.surfel.shrink_factor, |
|
|
mode='bilinear' |
|
|
) |
|
|
confs = confs.squeeze(1) |
|
|
|
|
|
|
|
|
|
|
|
start_idx = 0 if len(self.surfels) == 0 else len(pointcloud) - self.config.model.target_num_frames |
|
|
end_idx = len(pointcloud) |
|
|
|
|
|
|
|
|
for frame_idx in range(start_idx, end_idx): |
|
|
surfels = self.pointmap_to_surfels( |
|
|
pointmap=pointcloud[frame_idx], |
|
|
focal_lengths=focal_lengths[frame_idx] * self.config.surfel.shrink_factor, |
|
|
depths=depths[frame_idx], |
|
|
confs=confs[frame_idx], |
|
|
poses=c2ws_transformed[frame_idx], |
|
|
estimate_normals=True, |
|
|
radius_scale=self.config.surfel.radius_scale, |
|
|
) |
|
|
|
|
|
if len(self.surfels) > 0: |
|
|
surfels, self.surfel_to_timestep = self.merge_surfels( |
|
|
new_surfels=surfels, |
|
|
current_timestep=frame_idx, |
|
|
existing_surfels=self.surfels, |
|
|
existing_surfel_to_timestep=self.surfel_to_timestep, |
|
|
|
|
|
normal_threshold=self.config.surfel.merge_normal_threshold |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
num_surfels = len(surfels) |
|
|
surfel_start_index = len(self.surfels) |
|
|
for surfel_index in range(num_surfels): |
|
|
self.surfel_to_timestep[surfel_start_index + surfel_index] = [frame_idx] |
|
|
|
|
|
self.surfels.extend(surfels) |
|
|
|
|
|
if self.config.inference.visualize_surfel: |
|
|
visualize_surfels(self.surfels, draw_normals=True, normal_scale=0.0003) |
|
|
|
|
|
|
|
|
|
|
|
def get_translation_scaling_factor(self, c2ws): |
|
|
|
|
|
""" |
|
|
Args: |
|
|
c2ws: camera-to-world matrices, shape: (N, 4, 4) |
|
|
|
|
|
Returns: |
|
|
translation_scaling_factor: translation scaling factor |
|
|
""" |
|
|
ref_c2ws = c2ws |
|
|
camera_dist_2med = torch.norm( |
|
|
ref_c2ws[:, :3, 3] - ref_c2ws[:, :3, 3].median(0, keepdim=True).values, |
|
|
dim=-1, |
|
|
) |
|
|
valid_mask = camera_dist_2med <= torch.clamp( |
|
|
torch.quantile(camera_dist_2med, 0.97) * 10, |
|
|
max=1e6, |
|
|
) |
|
|
c2ws[:, :3, 3] -= ref_c2ws[valid_mask, :3, 3].mean(0, keepdim=True) |
|
|
|
|
|
|
|
|
camera_dists = c2ws[:, :3, 3].clone() |
|
|
translation_scaling_factor = ( |
|
|
self.camera_scale |
|
|
if torch.isclose( |
|
|
torch.norm(camera_dists[0]), |
|
|
torch.zeros(1).to(self.device, self.dtype), |
|
|
atol=1e-5, |
|
|
).any() |
|
|
else (self.camera_scale / torch.norm(camera_dists[0]) + 0.01) |
|
|
) |
|
|
return translation_scaling_factor, c2ws |
|
|
|
|
|
|
|
|
def get_cond(self, context_latents, all_c2ws, all_Ks, translation_scaling_factor, encoder_embeddings, input_masks): |
|
|
context_encoder_embeddings = torch.mean(encoder_embeddings, dim=0) |
|
|
input_masks = input_masks.bool() |
|
|
|
|
|
|
|
|
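        # Negate the y/z camera axes (OpenGL/OpenCV-style convention flip) and rescale the
        # translations before building the pose conditioning.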
all_c2ws[:, :, [1, 2]] *= -1 |
|
|
all_w2cs = torch.linalg.inv(all_c2ws) |
|
|
all_c2ws[:, :3, 3] *= translation_scaling_factor |
|
|
all_w2cs[:, :3, 3] *= translation_scaling_factor |
|
|
num_cameras = all_w2cs.shape[0] |
|
|
|
|
|
|
|
|
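        # Plucker ray embeddings relative to the first (reference) camera; these act as the dense
        # per-pixel pose conditioning for the denoiser.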
pluckers = get_plucker_coordinates( |
|
|
extrinsics_src=all_w2cs[:1], |
|
|
extrinsics=all_w2cs, |
|
|
intrinsics=all_Ks.float().clone(), |
|
|
target_size=(context_latents.shape[-2], context_latents.shape[-1]), |
|
|
) |
|
|
|
|
|
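        # Append a binary indicator channel to the latents: 1 for known (context) frames and 0 for
        # the frames to be generated, whose latent slots start out as zeros.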
target_latents = torch.nn.functional.pad( |
|
|
torch.zeros(self.config.model.num_frames - context_latents.shape[0], *context_latents.shape[1:]), (0, 0, 0, 0, 0, 1), value=0 |
|
|
).to(self.device, self.dtype) |
|
|
context_latents = torch.nn.functional.pad( |
|
|
context_latents, (0, 0, 0, 0, 0, 1), value=1.0 |
|
|
) |
|
|
|
|
|
c_crossattn = repeat(context_encoder_embeddings, "d -> n 1 d", n=num_cameras) |
|
|
|
|
|
|
|
|
uc_crossattn = torch.zeros_like(c_crossattn) |
|
|
c_replace = torch.zeros((num_cameras, *context_latents.shape[1:])).to(self.device) |
|
|
c_replace[input_masks] = context_latents |
|
|
c_replace[~input_masks] = target_latents |
|
|
uc_replace = torch.zeros_like(c_replace) |
|
|
c_concat = torch.cat( |
|
|
[ |
|
|
repeat( |
|
|
input_masks, |
|
|
"n ->n 1 h w", |
|
|
h=pluckers.shape[-2], |
|
|
w=pluckers.shape[-1], |
|
|
), |
|
|
pluckers, |
|
|
], |
|
|
1, |
|
|
) |
|
|
uc_concat = torch.cat( |
|
|
[torch.zeros((num_cameras, 1, *pluckers.shape[-2:])).to(self.device), pluckers], 1 |
|
|
) |
|
|
c_dense_vector = pluckers |
|
|
uc_dense_vector = c_dense_vector |
|
|
c = { |
|
|
"crossattn": c_crossattn, |
|
|
"replace": c_replace, |
|
|
"concat": c_concat, |
|
|
"dense_vector": c_dense_vector, |
|
|
} |
|
|
uc = { |
|
|
"crossattn": uc_crossattn, |
|
|
"replace": uc_replace, |
|
|
"concat": uc_concat, |
|
|
"dense_vector": uc_dense_vector, |
|
|
} |
|
|
|
|
|
return {"c": c, |
|
|
"uc": uc, |
|
|
"all_c2ws": all_c2ws, |
|
|
"all_Ks": all_Ks, |
|
|
"input_masks": input_masks, |
|
|
"num_cameras": num_cameras} |
|
|
|
|
|
|
|
|
def _generate_frames_for_trajectory(self, c2ws_tensor, Ks_tensor, use_non_maximum_suppression=None): |
|
|
""" |
|
|
Internal helper method to generate frames for a trajectory. |
|
|
|
|
|
Args: |
|
|
c2ws: List of camera-to-world matrices |
|
|
Ks: List of camera intrinsic matrices |
|
|
|
|
|
|
|
|
Returns: |
|
|
List of all generated PIL frames |
|
|
""" |
|
|
|
|
|
padding_size = 0 |
|
|
|
|
|
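        # Generate the trajectory in chunks: the first step fills a full context window
        # (num_frames - 1 new frames); each later step adds target_num_frames frames, padding the
        # pose list with the final pose when the remaining trajectory is shorter than a chunk.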
generation_steps = (len(c2ws_tensor) + 1 - self.config.model.num_frames) // self.config.model.target_num_frames + 2 |
|
|
|
|
|
|
|
|
cur_start_idx = 0 |
|
|
for i in range(generation_steps): |
|
|
|
|
|
if i > 0: |
|
|
cur_start_idx = cur_end_idx |
|
|
if len(self.pil_frames) == 1: |
|
|
cur_end_idx = min(cur_start_idx + self.config.model.num_frames - 1, len(c2ws_tensor)) |
|
|
else: |
|
|
cur_end_idx = min(cur_start_idx + self.config.model.target_num_frames, len(c2ws_tensor)) |
|
|
|
|
|
target_length = cur_end_idx - cur_start_idx |
|
|
if target_length <= 0: |
|
|
break |
|
|
|
|
|
|
|
|
if target_length < self.config.model.target_num_frames or (len(self.pil_frames) == 1 and target_length < self.config.model.num_frames - 1): |
|
|
|
|
|
if len(self.pil_frames) == 1: |
|
|
padding_size = self.config.model.num_frames - 1 - target_length |
|
|
else: |
|
|
padding_size = self.config.model.target_num_frames - target_length |
|
|
padding = torch.tile(c2ws_tensor[cur_end_idx-1:cur_end_idx], (padding_size, 1, 1)) |
|
|
c2ws_tensor = torch.cat([c2ws_tensor, padding], dim=0) |
|
|
|
|
|
padding_K = torch.tile(Ks_tensor[cur_end_idx-1:cur_end_idx], (padding_size, 1, 1)) |
|
|
Ks_tensor = torch.cat([Ks_tensor, padding_K], dim=0) |
|
|
|
|
|
if len(self.pil_frames) == 1: |
|
|
cur_end_idx = cur_start_idx + self.config.model.num_frames - 1 |
|
|
else: |
|
|
cur_end_idx = cur_start_idx + self.config.model.target_num_frames |
|
|
|
|
|
target_c2ws = c2ws_tensor[cur_start_idx:cur_end_idx] |
|
|
target_Ks = Ks_tensor[cur_start_idx:cur_end_idx] |
|
|
|
|
|
|
|
|
context_info = self.get_context_info(target_c2ws, use_non_maximum_suppression) |
|
|
|
|
|
(context_c2ws, |
|
|
context_latents, |
|
|
context_encoder_embeddings, |
|
|
context_Ks, |
|
|
context_time_indices) \ |
|
|
= (context_info["context_c2ws"], |
|
|
context_info["context_latents"], |
|
|
context_info["context_encoder_embeddings"], |
|
|
context_info["context_Ks"], |
|
|
context_info["context_time_indices"]) |
|
|
|
|
|
|
|
|
all_c2ws = torch.cat([context_c2ws, target_c2ws], dim=0) |
|
|
all_Ks = torch.cat([context_Ks, target_Ks], dim=0) |
|
|
translation_scaling_factor, all_c2ws = self.get_translation_scaling_factor(all_c2ws) |
|
|
input_masks = torch.cat([torch.ones(len(context_c2ws)), torch.zeros(len(target_c2ws))], dim=0).bool().to(self.device) |
|
|
cond = self.get_cond(context_latents, all_c2ws, all_Ks, translation_scaling_factor, context_encoder_embeddings, input_masks) |
|
|
|
|
|
|
|
|
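            # Run the diffusion sampler over the combined context + target window; `samples` are the
            # decoded images and `samples_z` the corresponding latents, with context positions
            # filtered out below via the input mask.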
samples, samples_z = do_sample(self.model_wrapper, |
|
|
self.vae, |
|
|
self.denoiser, |
|
|
self.sampler[0], |
|
|
cond["c"], |
|
|
cond["uc"], |
|
|
cond["all_c2ws"], |
|
|
cond["all_Ks"], |
|
|
input_masks, |
|
|
H=576, W=576, C=4, F=8, T=8, |
|
|
cfg=self.config.model.cfg, |
|
|
verbose=True, |
|
|
global_pbar=None, |
|
|
return_latents=True, |
|
|
device=self.device) |
|
|
|
|
|
|
|
|
target_num = torch.sum(~input_masks) |
|
|
target_samples = samples[~input_masks] |
|
|
target_pil_frames = [tensor_to_pil(target_samples[j]) for j in range(target_num)] |
|
|
target_encoder_embeddings = encode_image(target_samples, self.image_encoder, self.device, self.dtype) |
|
|
target_latents = samples_z[~input_masks] |
|
|
|
|
|
for j in range(target_num - padding_size if padding_size > 0 else target_num): |
|
|
self.latents.append(target_latents[j].detach().cpu().numpy()) |
|
|
self.encoder_embeddings.append(target_encoder_embeddings[j].detach().cpu().numpy()) |
|
|
self.Ks.append(target_Ks[j].detach().cpu().numpy()) |
|
|
self.c2ws.append(target_c2ws[j].detach().cpu().numpy()) |
|
|
self.pil_frames.append(target_pil_frames[j]) |
|
|
|
|
|
if self.config.inference.visualize: |
|
|
self.pil_frames[-1].save(f"{self.config.visualization_dir}/final_{len(self.pil_frames):07d}.png") |
|
|
|
|
|
|
|
|
|
|
|
self.construct_and_store_scene(self.pil_frames, |
|
|
time_indices=context_time_indices, |
|
|
niter=self.config.surfel.niter, |
|
|
lr=self.config.surfel.lr, |
|
|
device=self.device) |
|
|
self.global_step += 1 |
|
|
|
|
|
if self.config.inference.visualize: |
|
|
export_to_gif(self.pil_frames, f"{self.config.visualization_dir}/inference_all.gif") |
|
|
|
|
|
|
|
|
return self.pil_frames[-self.config.model.target_num_frames:] if len(self.pil_frames) > self.config.model.target_num_frames + 1 else self.pil_frames |
|
|
|
|
|
def generate_trajectory_frames(self, c2ws: List[np.ndarray], Ks: List[np.ndarray], use_non_maximum_suppression=None): |
|
|
""" |
|
|
Generate frames for a new trajectory segment while maintaining the pipeline state. |
|
|
This allows for interactive navigation through a scene. |
|
|
|
|
|
Args: |
|
|
c2ws: List of camera-to-world matrices for the new trajectory segment |
|
|
Ks: List of camera intrinsic matrices for the new trajectory segment |
|
|
|
|
|
Returns: |
|
|
List of PIL images for the newly generated frames |
|
|
""" |
|
|
c2ws_tensor = torch.from_numpy(np.array(c2ws)).to(self.device, self.dtype) |
|
|
Ks_tensor = torch.from_numpy(np.array(Ks)).to(self.device, self.dtype) |
|
|
|
|
|
|
|
|
return self._generate_frames_for_trajectory(c2ws_tensor, Ks_tensor, use_non_maximum_suppression) |
|
|
|
|
|
def undo_latest_move(self): |
|
|
""" |
|
|
Undo the latest move by deleting the most recent batch of camera poses, embeddings, and pil images. |
|
|
This allows stepping back in the trajectory if navigation went in an undesired direction. |
|
|
|
|
|
The method removes the last generated batch of frames (up to target_num_frames) since the pipeline |
|
|
generates multiple frames at once during each generation step. |
|
|
|
|
|
Returns: |
|
|
bool: True if successfully removed the latest frames, False if there's nothing to remove |
|
|
(e.g., only one frame in the pipeline) |
|
|
""" |
|
|
|
|
|
if len(self.pil_frames) <= 1: |
|
|
print("Cannot undo: only one frame in the pipeline") |
|
|
return False |
|
|
|
|
|
|
|
|
frames_to_remove = min(self.config.model.target_num_frames, len(self.pil_frames) - 1) |
|
|
|
|
|
|
|
|
for _ in range(frames_to_remove): |
|
|
self.latents.pop() |
|
|
self.encoder_embeddings.pop() |
|
|
self.c2ws.pop() |
|
|
self.Ks.pop() |
|
|
self.pil_frames.pop() |
|
|
|
|
|
|
|
|
|
|
|
        # One generation step produced this whole batch of frames, so step the counter back by one.
        self.global_step = max(0, self.global_step - 1)
|
|
|
|
|
for _ in range(frames_to_remove): |
|
|
self.surfel_depths.pop() |
|
|
|
|
|
|
|
|
|
|
|
current_frame_count = len(self.pil_frames) |
|
|
removed_timesteps = list(range(current_frame_count, current_frame_count + frames_to_remove)) |
|
|
surfels_to_remove = [] |
|
|
|
|
|
|
|
|
updated_surfel_to_timestep = {} |
|
|
for i, timesteps in self.surfel_to_timestep.items(): |
|
|
|
|
|
if all(ts in removed_timesteps for ts in timesteps): |
|
|
surfels_to_remove.append(i) |
|
|
else: |
|
|
|
|
|
updated_timesteps = [ts for ts in timesteps if ts not in removed_timesteps] |
|
|
updated_surfel_to_timestep[i] = updated_timesteps |
|
|
|
|
|
|
|
|
updated_surfels = [] |
|
|
updated_final_surfel_to_timestep = {} |
|
|
new_idx = 0 |
|
|
|
|
|
for i, surfel in enumerate(self.surfels): |
|
|
if i not in surfels_to_remove: |
|
|
updated_surfels.append(surfel) |
|
|
updated_final_surfel_to_timestep[new_idx] = updated_surfel_to_timestep[i] |
|
|
new_idx += 1 |
|
|
|
|
|
|
|
|
self.surfels = updated_surfels |
|
|
self.surfel_to_timestep = updated_final_surfel_to_timestep |
|
|
|
|
|
print(f"Successfully removed the latest {frames_to_remove} frames. {len(self.pil_frames)} frames remaining.") |
|
|
return True |
|
|
|
|
|
|
|
|
|
|
|
def __call__(self, image:torch.Tensor, c2ws: List[np.ndarray], Ks: List[np.ndarray]): |
|
|
""" |
|
|
Process an initial image and generate frames for a trajectory. |
|
|
|
|
|
Args: |
|
|
image: Initial image tensor |
|
|
c2ws: Camera-to-world matrices for the trajectory |
|
|
Ks: Camera intrinsic matrices for the trajectory |
|
|
|
|
|
Returns: |
|
|
List of PIL images for all generated frames |
|
|
""" |
|
|
|
|
|
c2ws_tensor = torch.from_numpy(np.array(c2ws)).to(self.device, self.dtype) |
|
|
|
|
|
Ks_tensor = torch.from_numpy(np.array(Ks)).to(self.device, self.dtype) |
|
|
|
|
|
|
|
|
|
|
|
self.initialize(image, c2ws_tensor[0].detach().cpu().numpy(), Ks_tensor[0].detach().cpu().numpy()) |
|
|
|
|
|
return self._generate_frames_for_trajectory(c2ws_tensor[1:], Ks_tensor[1:]) |
|
|
|
|
|
|
|
|
|
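# A minimal usage sketch (not part of the pipeline): it assumes an OmegaConf-style config exposing
# the fields read above (config.model.model_path, config.surfel.model_path, config.model.samples_dir,
# and the model/surfel/inference sub-configs); the poses and intrinsics below are placeholders.
#
#   pipeline = VMemPipeline(config, device="cuda", dtype=torch.float16)
#   first_frame = PIL.Image.open("input.png").resize((pipeline.width, pipeline.height))
#   c2ws = [np.eye(4) for _ in range(13)]                     # camera-to-world poses along the path
#   Ks = [np.array([[500.0, 0.0, pipeline.width / 2],
#                   [0.0, 500.0, pipeline.height / 2],
#                   [0.0, 0.0, 1.0]]) for _ in range(13)]      # pinhole intrinsics per pose
#   frames = pipeline(first_frame, c2ws, Ks)                   # initialize and render the trajectory
#   more = pipeline.generate_trajectory_frames(c2ws, Ks)       # keep navigating interactively
#   pipeline.undo_latest_move()                                # step back one generation step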