from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, Literal, TypeVar

import torch
from jaxtyping import Float, Int32, Bool, UInt8
from torch import Tensor, nn

from ..types import Gaussians
from ...dataset import DatasetCfg
from ...dataset.data_types import BatchedViews, BatchedViewsDict, BatchedExample
from ...scene_trainer.gaussian_module import GaussiansModule

# Names of the depth rendering modes a decoder may be asked for.
DepthRenderingMode = Literal[
    "depth",
    "log",
    "disparity",
    "relative_disparity",
]


@dataclass
class DecoderOutput:
    """Container for the tensors returned by ``Decoder.forward``.

    Only ``color`` and ``depth`` have to be passed explicitly; every other
    field defaults to ``None``. Expected shapes are given by the jaxtyping
    annotations on each field.
    """

    color: Float[Tensor, "batch view 3 height width"] | UInt8[Tensor, "batch view 3 height width"]
    depth: Float[Tensor, "batch view height width"] | None
    normal: Float[Tensor, "batch view 3 height width"] | None = None
    distortion_map: Float[Tensor, "batch view height width"] | None = None
    accumulated_alpha: Float[Tensor, "batch view height width"] | None = None
    # NOTE(review): the fields below carry an extra "n" axis, presumably one
    # entry per Gaussian -- confirm against the concrete decoders.
    radii: Int32[Tensor, "batch view n 2"] | None = None
    means2d: Float[Tensor, "batch view n 2"] | None = None
    visibility_filter: Bool[Tensor, "batch view n"] | None = None


T = TypeVar("T")


class Decoder(nn.Module, ABC, Generic[T]):
    """Abstract base class for modules that render views from Gaussians.

    Subclasses implement ``forward``. The ``forward_*`` helpers only select
    camera parameters out of a batch and delegate to ``forward``.
    """

    cfg: T
    dataset_cfg: DatasetCfg

    def __init__(self, cfg: T, dataset_cfg: DatasetCfg) -> None:
        """Store the decoder configuration and the dataset configuration."""
        super().__init__()
        self.cfg = cfg
        self.dataset_cfg = dataset_cfg

    @abstractmethod
    def forward(
        self,
        gaussians: Gaussians | GaussiansModule,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
        to_cpu: bool = False,
        iter_batch_size: int = -1,
    ) -> DecoderOutput:
        """Render the given cameras; must be implemented by subclasses.

        Args:
            gaussians: Scene representation to render.
            extrinsics: Per-view camera extrinsics.
            intrinsics: Per-view camera intrinsics.
            near: Per-view near plane.
            far: Per-view far plane.
            image_shape: Output image shape.
            depth_mode: Optional depth rendering mode.
            to_cpu: Move outputs to CPU as they are rendered.
            iter_batch_size: -1 to render all views at once. Declared here
                because ``forward_batch`` always passes it as a keyword, so
                every concrete decoder has to accept it.

        Returns:
            The rendered outputs.
        """
        pass

    def forward_batch(
        self,
        gaussians: Gaussians | GaussiansModule,
        batch: BatchedExample,
        image_shape: tuple[int, int] | None = None,
        input_str: Literal["context", "target"] | None = None,
        eval_context_views: bool | None = None,
        depth_mode: DepthRenderingMode | None = None,
        start: int | None = None,
        end: int | None = None,
        camera_poses: Float[Tensor, "batch view 4 4"] | None = None,  # In case of manipulating camera poses (e.g. for stabilization)
        to_cpu: bool = False,  # move outputs to cpu as they are rendered
        iter_batch_size: int = -1,  # -1 to render all views at once
    ) -> DecoderOutput:
        """Render the context or target views of ``batch``.

        Args:
            gaussians: Scene representation to render.
            batch: Example holding ``"context"`` and ``"target"`` views.
            image_shape: Output image shape; defaults to the last two
                dimensions of the selected views' ``"image"`` tensor.
            input_str: Which entry of ``batch`` to render. If ``None`` it is
                derived from ``eval_context_views``.
            eval_context_views: Used only when ``input_str`` is ``None``:
                truthy selects ``"context"``, falsy selects ``"target"``.
            depth_mode: Forwarded to ``forward``.
            start: Start of the slice applied along the view axis.
            end: End of the slice applied along the view axis.
            camera_poses: Replacement for the views' ``"extrinsics"``.
            to_cpu: Forwarded to ``forward``.
            iter_batch_size: Forwarded to ``forward``.

        Returns:
            The output of ``forward`` for the selected views.

        Raises:
            AssertionError: If both ``input_str`` and ``eval_context_views``
                are ``None``.
        """
        assert input_str is not None or eval_context_views is not None
        if input_str is None:
            input_str = "context" if eval_context_views else "target"

        # Named ``views`` rather than ``input`` to avoid shadowing the builtin.
        views = batch[input_str]

        if image_shape is None:
            # Read the shape from the image tensor, exactly as
            # forward_batch_subset does. The previous "image_shape" key
            # lookup was inconsistent with that sibling.
            # NOTE(review): confirm the views dict exposes "image".
            image_shape = views["image"].shape[-2:]
        if camera_poses is None:
            camera_poses = views["extrinsics"]

        return self.forward(
            gaussians,
            camera_poses[:, start:end],
            views["intrinsics"][:, start:end],
            views["near"][:, start:end],
            views["far"][:, start:end],
            image_shape,
            depth_mode=depth_mode,
            to_cpu=to_cpu,
            iter_batch_size=iter_batch_size,
        )

    def forward_batch_subset(
        self,
        gaussians: Gaussians | GaussiansModule,
        batch_subset: BatchedViewsDict | BatchedViews,
        image_shape: tuple[int, int] | None = None,
        start: int | None = None,
        end: int | None = None,
        indices: torch.Tensor | list | None = None,
        **kwargs,
    ) -> DecoderOutput:
        """Render a subset of the views in ``batch_subset``.

        The subset is chosen either by a ``start``/``end`` range or by
        explicit ``indices``, never both. With neither, all views are used.

        Args:
            gaussians: Scene representation to render.
            batch_subset: Views providing ``"extrinsics"``, ``"intrinsics"``,
                ``"near"``, ``"far"`` and (when ``image_shape`` is ``None``)
                ``"image"``.
            image_shape: Output image shape; defaults to the last two
                dimensions of ``batch_subset["image"]``.
            start: First view index of the range; defaults to 0 when only
                ``end`` is given.
            end: Exclusive end of the range; defaults to the number of views
                when only ``start`` is given.
            indices: View indices. A list is shared by every scene in the
                batch; a tensor must already be 2D
                ``(scene_batch, num_indices)``.
            **kwargs: Forwarded to ``forward``.

        Returns:
            The output of ``forward`` for the selected views.

        Raises:
            AssertionError: If a range and ``indices`` are both given, or if
                ``indices`` is a tensor that is not 2D.
        """
        extrinsics = batch_subset["extrinsics"]
        num_scenes, num_views = extrinsics.shape[:2]

        range_given = start is not None or end is not None
        assert not (range_given and indices is not None), \
            "Provide either start/end or indices, not both."

        if indices is None:
            # A missing bound falls back to the full extent, so passing only
            # ``start`` or only ``end`` works; with neither, every view is
            # selected.
            first = 0 if start is None else start
            last = num_views if end is None else end
            indices = list(range(first, last))

        if isinstance(indices, list):
            # Convert list to tensor for one flow handling. The dtype is
            # forced because an empty list would otherwise produce a float
            # tensor, which cannot be used for indexing.
            indices = torch.tensor(indices, dtype=torch.long, device=extrinsics.device)
            indices = indices.unsqueeze(0).expand(num_scenes, -1)  # (batch, num_indices)

        if image_shape is None:
            image_shape = batch_subset["image"].shape[-2:]

        assert indices.dim() == 2, "Indices tensor must be 2D (scene_batch, num_indices)."
        scene_batch = indices.size(0)
        # Column vector of scene ids that broadcasts against ``indices`` so
        # each scene picks its own views.
        scene_batch_idx = torch.arange(scene_batch, device=indices.device)[:, None]  # (batch, 1)

        return self.forward(
            gaussians,
            extrinsics[scene_batch_idx, indices],
            batch_subset["intrinsics"][scene_batch_idx, indices],
            batch_subset["near"][scene_batch_idx, indices],
            batch_subset["far"][scene_batch_idx, indices],
            image_shape,
            **kwargs,
        )

    def forward_context(
        self,
        gaussians: Gaussians | GaussiansModule,
        batch: BatchedExample,
        image_shape: tuple[int, int] | None = None,
        depth_mode: DepthRenderingMode | None = None,
        **kwargs,
    ) -> DecoderOutput:
        """Render the ``"context"`` views of ``batch`` via ``forward_batch``."""
        return self.forward_batch(
            gaussians,
            batch,
            image_shape,
            "context",
            depth_mode=depth_mode,
            **kwargs,
        )

    def forward_target(
        self,
        gaussians: Gaussians | GaussiansModule,
        batch: BatchedExample,
        image_shape: tuple[int, int] | None = None,
        depth_mode: DepthRenderingMode | None = None,
        **kwargs,
    ) -> DecoderOutput:
        """Render the ``"target"`` views of ``batch`` via ``forward_batch``."""
        return self.forward_batch(
            gaussians,
            batch,
            image_shape,
            "target",
            depth_mode=depth_mode,
            **kwargs,
        )