|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Data transformation utilities for GSLRM model. |
|
|
|
|
|
This module contains classes and utilities for transforming input and target data |
|
|
for training and inference in the GSLRM (Gaussian Splatting Latent Radiance Model). |
|
|
""" |
|
|
|
|
|
import itertools |
|
|
import random |
|
|
from typing import Dict, Optional, Tuple, Union |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from easydict import EasyDict as edict |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_camera_rays(
    fxfycxcy: torch.Tensor,
    c2w: torch.Tensor,
    h: int,
    w: int,
    device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build one ray per pixel for every camera in the flattened batch.

    Pixel centers (integer index + 0.5) are un-projected through the pinhole
    intrinsics onto the z = 1 plane, normalized, and then rotated into world
    space with the upper-left 3x3 block of ``c2w``.

    Args:
        fxfycxcy: Camera intrinsics (fx, fy, cx, cy) [b*v, 4]
        c2w: Camera-to-world matrices [b*v, 4, 4]
        h: Image height
        w: Image width
        device: Device on which the pixel grid is created

    Returns:
        Tuple of (ray_origins, ray_directions, ray_directions_camera), each of
        shape [b*v, h*w, 3]; pixels are flattened row by row.
    """
    num_cams = fxfycxcy.size(0)

    # Integer pixel indices; rows vary along dim 0 and columns along dim 1.
    rows, cols = torch.meshgrid(
        torch.arange(h, device=device),
        torch.arange(w, device=device),
        indexing="ij",
    )
    px = cols.unsqueeze(0).expand(num_cams, -1, -1).reshape(num_cams, -1)
    py = rows.unsqueeze(0).expand(num_cams, -1, -1).reshape(num_cams, -1)

    # Keep a trailing singleton dim so the intrinsics broadcast over pixels.
    fx = fxfycxcy[:, 0:1]
    fy = fxfycxcy[:, 1:2]
    cx = fxfycxcy[:, 2:3]
    cy = fxfycxcy[:, 3:4]

    cam_x = (px + 0.5 - cx) / fx
    cam_y = (py + 0.5 - cy) / fy
    cam_z = torch.ones_like(cam_x)

    # Unit-length directions in the camera frame.
    dirs_cam = torch.stack([cam_x, cam_y, cam_z], dim=2)
    dirs_cam = dirs_cam / torch.norm(dirs_cam, dim=2, keepdim=True)

    # Row vectors times R^T is equivalent to R applied to column vectors.
    rotation = c2w[:, :3, :3]
    dirs_world = torch.bmm(dirs_cam, rotation.transpose(1, 2))
    dirs_world = dirs_world / torch.norm(dirs_world, dim=2, keepdim=True)

    # Every ray of a camera starts at that camera's translation.
    translation = c2w[:, :3, 3]
    origins = translation[:, None, :].expand_as(dirs_world)

    return origins, dirs_world, dirs_cam
|
|
|
|
|
|
|
|
def sample_patch_rays(
    image: torch.Tensor,
    fxfycxcy: torch.Tensor,
    c2w: torch.Tensor,
    patch_size: int,
    h: int,
    w: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sample rays at patch centers for efficient processing.

    One sample is taken at the center of every ``patch_size`` x ``patch_size``
    tile; trailing rows/columns that do not fill a whole tile are ignored
    (``h // patch_size`` by ``w // patch_size`` samples per image).

    Args:
        image: Input images [b*v, c, h, w]
        fxfycxcy: Camera intrinsics (fx, fy, cx, cy) [b*v, 4]
        c2w: Camera-to-world matrices [b*v, 4, 4]
        patch_size: Size of patches
        h: Image height
        w: Image width

    Returns:
        Tuple of (colors, ray_origins, ray_directions, xy_norm, projection_matrices):
            colors: image values interpolated at the patch centers [b*v, n_patch, c]
            ray_origins: camera positions repeated per patch [b*v, n_patch, 3]
            ray_directions: unit world-space directions [b*v, n_patch, 3]
            xy_norm: patch centers divided by (w, h) [b*v, n_patch, 2]
            projection_matrices: normalized K @ w2c, scaled to (almost) unit
                norm over its 12 entries, with a non-negative [0, 0] entry
                [b*v, 3, 4]
    """
    b_v = image.shape[0]
    device = image.device

    # Patch centers in pixel units, measured from the top-left image corner.
    start_patch_center = patch_size / 2.0
    y, x = torch.meshgrid(
        torch.arange(h // patch_size, device=device) * patch_size + start_patch_center,
        torch.arange(w // patch_size, device=device) * patch_size + start_patch_center,
        indexing="ij",
    )

    x_flat = x[None, :, :].expand(b_v, -1, -1).reshape(b_v, -1)
    y_flat = y[None, :, :].expand(b_v, -1, -1).reshape(b_v, -1)

    # grid_sample expects coordinates in [-1, 1]; with align_corners=False the
    # range spans the outer pixel edges, matching the corner-origin convention
    # used for the patch centers above.
    sample_grid = torch.stack(
        [x_flat / w * 2.0 - 1.0, y_flat / h * 2.0 - 1.0], dim=2
    ).reshape(b_v, -1, 1, 2)
    ray_color = F.grid_sample(
        image,
        sample_grid,
        align_corners=False,
    ).squeeze(-1).permute(0, 2, 1).contiguous()

    # Patch centers normalized to [0, 1].
    ray_xy_norm = torch.stack([x_flat / w, y_flat / h], dim=2)

    # Intrinsics expressed in image-size-normalized units.
    K_norm = torch.eye(3, device=device).unsqueeze(0).repeat(b_v, 1, 1)
    K_norm[:, 0, 0] = fxfycxcy[:, 0] / w
    K_norm[:, 1, 1] = fxfycxcy[:, 1] / h
    K_norm[:, 0, 2] = fxfycxcy[:, 2] / w
    K_norm[:, 1, 2] = fxfycxcy[:, 3] / h

    w2c = torch.inverse(c2w)
    proj_mat = torch.bmm(K_norm, w2c[:, :3, :4])

    # A projection matrix is only defined up to scale: normalize its magnitude
    # (epsilon guards against division by zero) ...
    proj_mat = proj_mat.reshape(b_v, 12)
    proj_mat = proj_mat / (proj_mat.norm(dim=1, keepdim=True) + 1e-6)
    proj_mat = proj_mat.reshape(b_v, 3, 4)

    # ... and its sign, so that entry [0, 0] is non-negative. A zero entry is
    # treated as positive: multiplying by sign(0) == 0 would otherwise erase
    # the whole matrix.
    lead = proj_mat[:, 0:1, 0:1]
    flip = torch.where(lead < 0, -torch.ones_like(lead), torch.ones_like(lead))
    proj_mat = proj_mat * flip

    # Un-project the patch centers to the z = 1 plane of the camera frame.
    x_norm = (x_flat - fxfycxcy[:, 2:3]) / fxfycxcy[:, 0:1]
    y_norm = (y_flat - fxfycxcy[:, 3:4]) / fxfycxcy[:, 1:2]
    z_norm = torch.ones_like(x_norm)

    # Rotate into world space (row vectors times R^T) and normalize.
    ray_d = torch.stack([x_norm, y_norm, z_norm], dim=2)
    ray_d = torch.bmm(ray_d, c2w[:, :3, :3].transpose(1, 2))
    ray_d = ray_d / torch.norm(ray_d, dim=2, keepdim=True)
    ray_o = c2w[:, :3, 3][:, None, :].expand_as(ray_d)

    return ray_color, ray_o, ray_d, ray_xy_norm, proj_mat
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SplitData(nn.Module):
    """
    Divide a multi-view batch into the views fed to the model (input) and the
    views used for supervision (target).
    """

    def __init__(self, config):
        super().__init__()
        # Read as config.training.dataset.num_input_views / .num_views.
        self.config = config

    @torch.no_grad()
    def forward(self, data_batch: Dict[str, torch.Tensor], target_has_input: bool = True) -> Tuple[edict, edict]:
        """
        Produce the input/target split for every entry of the batch.

        Input views are always the first ``num_input_views`` along dim 1.
        Target views are the full entry when it holds at most ``num_views``
        views; otherwise a subset is gathered using one index tensor that is
        generated once and shared by all entries.

        Args:
            data_batch: Mapping from name to a tensor whose first two
                dimensions are indexed as (batch, view)
            target_has_input: Whether target views can overlap with input views

        Returns:
            Tuple of (input_data, target_data)
        """
        inputs = {}
        targets = {}
        shared_index = None

        for name, tensor in data_batch.items():
            dataset_cfg = self.config.training.dataset
            inputs[name] = tensor[:, :dataset_cfg.num_input_views, ...]

            # Nothing to sub-sample when the entry has no surplus views.
            if dataset_cfg.num_views >= tensor.size(1):
                targets[name] = tensor
                continue

            # The first entry that needs sub-sampling fixes the selection so
            # that all entries stay aligned view-for-view.
            if shared_index is None:
                shared_index = self._generate_target_indices(tensor, target_has_input)

            targets[name] = self._gather_target_data(tensor, shared_index)

        return edict(inputs), edict(targets)

    def _generate_target_indices(self, value: torch.Tensor, target_has_input: bool) -> torch.Tensor:
        """
        Build the [b, num_views] index tensor selecting the target views.

        With ``target_has_input`` each batch element draws ``num_views``
        distinct view indices at random from the ``v`` available ones;
        otherwise every row is ``num_views - 1, ..., 0``.
        """
        batch_size, available_views = value.shape[:2]

        dataset_cfg = self.config.training.dataset
        n_input = dataset_cfg.num_input_views
        n_total = dataset_cfg.num_views
        n_target = n_total

        if not target_has_input:
            # NOTE(review): n_target equals n_total here, so this only holds
            # when num_input_views <= 0 — confirm the intended target count.
            assert (
                n_input + n_target <= n_total
            ), "num_input_views + num_target_views must <= num_views to avoid duplicate views"

            descending = range(n_total - 1, n_total - 1 - n_target, -1)
            rows = [list(descending) for _ in range(batch_size)]
        else:
            rows = [
                random.sample(range(available_views), n_target)
                for _ in range(batch_size)
            ]

        return torch.from_numpy(np.array(rows)).long().to(value.device)

    def _gather_target_data(self, value: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
        """
        Pick the views listed in ``index`` (shape [b, k]) along dim 1 of ``value``.

        The index receives one singleton dimension per trailing dimension of
        ``value`` and is then broadcast to them, as ``torch.gather`` requires.
        Failures are reported with the offending shapes and re-raised.
        """
        value_index = index
        extra_dims = value.dim() - 2
        if extra_dims > 0:
            value_index = index.reshape(index.size(0), index.size(1), *([1] * extra_dims))

        trailing_shape = value.size()[2:]
        try:
            gather_index = value_index.expand(-1, -1, *trailing_shape)
            return torch.gather(value, dim=1, index=gather_index)
        except Exception:
            print(f"Error gathering data for key with value shape: {value.size()}")
            print(f"Index shape: {value_index.size()}")
            raise
|
|
|
|
|
|
|
|
class TransformInput(nn.Module):
    """
    Turn a batch of posed input images into per-pixel rays, normalized pixel
    coordinates and (optionally) per-patch samples.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

    @torch.no_grad()
    def forward(self, data_batch: edict, patch_size: Optional[int] = None) -> edict:
        """
        Compute ray-based representations of the input views.

        Args:
            data_batch: Batch exposing ``image`` [b, v, c, h, w], ``fxfycxcy``
                [b, v, 4], ``c2w`` [b, v, 4, 4] and ``index`` (passed through)
            patch_size: When given, patch-center samples are computed as well;
                otherwise the five patch fields of the result are ``None``

        Returns:
            edict holding the original fields, per-pixel ``ray_o`` / ``ray_d`` /
            ``ray_d_cam`` as [b, v, 3, h, w], ``xy_norm`` as [b, v, 2, h, w]
            and the patch fields
        """
        self._validate_input(data_batch)

        image = data_batch.image
        fxfycxcy = data_batch.fxfycxcy
        c2w = data_batch.c2w
        index = data_batch.index

        b, v, c, h, w = image.size()
        num_cams = b * v

        # Merge batch and view dimensions: the ray helpers work per camera.
        image_flat = image.reshape(num_cams, c, h * w)
        intrinsics_flat = fxfycxcy.reshape(num_cams, 4)
        pose_flat = c2w.reshape(num_cams, 4, 4)

        xy_norm = self._compute_normalized_coordinates(b, v, h, w, image.device)

        per_pixel_rays = compute_camera_rays(
            intrinsics_flat, pose_flat, h, w, image.device
        )

        if patch_size is None:
            patch_data = (None, None, None, None, None)
        else:
            patch_data = self._process_patches(
                image_flat, intrinsics_flat, pose_flat, patch_size, h, w, b, v, c
            )
        color_patch, origin_patch, dir_patch, xy_patch, proj_mat = patch_data

        # [b*v, h*w, 3] -> [b, v, 3, h, w]
        ray_o, ray_d, ray_d_cam = (
            rays.reshape(b, v, h, w, 3).permute(0, 1, 4, 2, 3)
            for rays in per_pixel_rays
        )

        return edict(
            image=image,
            ray_o=ray_o,
            ray_d=ray_d,
            ray_d_cam=ray_d_cam,
            fxfycxcy=fxfycxcy,
            c2w=c2w,
            index=index,
            xy_norm=xy_norm,
            ray_color_patch=color_patch,
            ray_o_patch=origin_patch,
            ray_d_patch=dir_patch,
            ray_xy_norm_patch=xy_patch,
            proj_mat=proj_mat,
        )

    def _validate_input(self, data_batch: edict) -> None:
        """Assert that image, fxfycxcy and c2w have 5, 3 and 4 dimensions."""
        expected_dims = (("image", 5), ("fxfycxcy", 3), ("c2w", 4))
        for field, expected in expected_dims:
            actual = getattr(data_batch, field).dim()
            assert actual == expected, f"{field} dim should be {expected}, got {actual}"

    def _compute_normalized_coordinates(self, b: int, v: int, h: int, w: int, device: torch.device) -> torch.Tensor:
        """
        Return pixel-center coordinates scaled to [-1, 1] as [b, v, 2, h, w].

        Channel 0 is x (column), channel 1 is y (row). The batch and view
        dimensions are broadcast views of a single [2, h, w] grid.
        """
        row_coords = (torch.arange(h, device=device) + 0.5) / h * 2 - 1
        col_coords = (torch.arange(w, device=device) + 0.5) / w * 2 - 1
        y_grid, x_grid = torch.meshgrid(row_coords, col_coords, indexing="ij")

        grid = torch.stack([x_grid, y_grid], dim=0)
        return grid[None, None, :, :, :].expand(b, v, -1, -1, -1)

    def _process_patches(
        self,
        image: torch.Tensor,
        fxfycxcy: torch.Tensor,
        c2w: torch.Tensor,
        patch_size: int,
        h: int,
        w: int,
        b: int,
        v: int,
        c: int
    ) -> Tuple[Optional[torch.Tensor], ...]:
        """
        Sample patch-center data and restore the separate batch/view dimensions.

        ``image`` may arrive with its spatial dimensions flattened; it is
        reshaped to [b*v, c, h, w] before sampling.
        """
        image_bchw = image.reshape(b * v, c, h, w)
        colors, origins, directions, xy_norm, proj_mat = sample_patch_rays(
            image_bchw, fxfycxcy, c2w, patch_size, h, w
        )

        num_patches = colors.size(1)

        def split_views(tensor: torch.Tensor, channels: int) -> torch.Tensor:
            # [b*v, n_patch, channels] -> [b, v, n_patch, channels]
            return tensor.reshape(b, v, num_patches, channels)

        return (
            split_views(colors, c),
            split_views(origins, 3),
            split_views(directions, 3),
            split_views(xy_norm, 2),
            proj_mat.reshape(b, v, 3, 4),
        )
|
|
|
|
|
|
|
|
class TransformTarget(nn.Module):
    """
    Handles target image transformations during training.

    Currently implements random cropping for data augmentation.
    """

    def __init__(self, config: "edict"):
        super().__init__()
        # config.training.crop_size (optional) sets the square crop side.
        self.config = config

    @torch.no_grad()
    def forward(self, data_batch: "edict") -> "edict":
        """
        Randomly crop every target view to ``crop_size`` x ``crop_size``.

        The crop size comes from ``config.training.crop_size`` and falls back
        to ``min(h, w)`` when that attribute is missing. Each (batch, view)
        image gets its own uniformly drawn crop offset, and the principal
        point (cx, cy) is shifted by that offset so the intrinsics keep
        describing the cropped image. Nothing is changed when neither
        dimension exceeds the crop size.

        Args:
            data_batch: Dictionary containing 'image' [b, v, c, h, w] and
                'fxfycxcy' [b, v, 4]

        Returns:
            The same ``data_batch`` object; when a crop is applied its 'image'
            and 'fxfycxcy' entries are replaced by the cropped versions.

        Note:
            If exactly one dimension is smaller than the crop size no valid
            offset exists and ``torch.randint`` raises.
        """
        image = data_batch["image"]
        fxfycxcy = data_batch["fxfycxcy"]

        b, v, c, h, w = image.size()
        crop_size = getattr(self.config.training, "crop_size", min(h, w))

        if h <= crop_size and w <= crop_size:
            return data_batch

        crop_image = torch.zeros(
            (b, v, c, crop_size, crop_size),
            dtype=image.dtype,
            device=image.device,
        )
        crop_fxfycxcy = fxfycxcy.clone()

        # torch.randint excludes `high`, so add 1 to make the last valid
        # offset reachable. This also covers a dimension that already equals
        # crop_size (offset 0 is then the only choice).
        max_x = w - crop_size + 1
        max_y = h - crop_size + 1

        for i in range(b):
            for j in range(v):
                idx_x = torch.randint(low=0, high=max_x, size=(1,)).item()
                idx_y = torch.randint(low=0, high=max_y, size=(1,)).item()

                crop_image[i, j] = image[
                    i, j, :, idx_y:idx_y + crop_size, idx_x:idx_x + crop_size
                ]

                # Moving the image origin by the crop offset moves the
                # principal point by the same amount; focal lengths are
                # unaffected.
                crop_fxfycxcy[i, j, 2] -= idx_x
                crop_fxfycxcy[i, j, 3] -= idx_y

        data_batch["image"] = crop_image
        data_batch["fxfycxcy"] = crop_fxfycxcy

        return data_batch
|
|
|