Learn2Splat / optgs /model /decoder /decoder.py
SteEsp's picture
Add Docker-based Learn2Splat demo (viser GUI)
78d2329 verified
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, Literal, TypeVar
import torch
from jaxtyping import Float, Int32, Bool, UInt8
from torch import Tensor, nn
from ..types import Gaussians
from ...dataset import DatasetCfg
from ...dataset.data_types import BatchedViews, BatchedViewsDict, BatchedExample
from ...scene_trainer.gaussian_module import GaussiansModule
# Names of the depth-rendering modes a decoder accepts via its ``depth_mode`` argument.
DepthRenderingMode = Literal["depth", "log", "disparity", "relative_disparity"]
@dataclass
class DecoderOutput:
    """Per-view outputs returned by ``Decoder.forward``.

    Only ``color`` and ``depth`` must be supplied; every other field defaults to
    ``None`` and is presumably populated only by decoders that compute it.
    """

    # Rendered images, either floating point or uint8.
    color: Float[Tensor, "batch view 3 height width"] | UInt8[Tensor, "batch view 3 height width"]
    # Per-view depth map, or None when not produced.
    depth: Float[Tensor, "batch view height width"] | None
    # Optional per-pixel normals.
    normal: Float[Tensor, "batch view 3 height width"] | None = None
    # Optional per-pixel distortion map.
    distortion_map: Float[Tensor, "batch view height width"] | None = None
    # Optional per-pixel accumulated opacity.
    accumulated_alpha: Float[Tensor, "batch view height width"] | None = None
    # NOTE(review): `n` presumably indexes the Gaussians; confirm in the concrete decoder.
    radii: Int32[Tensor, "batch view n 2"] | None = None
    # Optional projected 2D means, same (batch, view, n, 2) layout as `radii`.
    means2d: Float[Tensor, "batch view n 2"] | None = None
    # Optional per-view boolean mask over the `n` entries (presumably Gaussian visibility).
    visibility_filter: Bool[Tensor, "batch view n"] | None = None
# Type of the decoder-specific configuration object stored in ``Decoder.cfg``.
T = TypeVar("T")
class Decoder(nn.Module, ABC, Generic[T]):
    """Abstract base for modules that turn Gaussians plus cameras into a ``DecoderOutput``.

    Subclasses implement :meth:`forward`. The ``forward_*`` helpers only pick
    camera parameters out of a batch and delegate to :meth:`forward`.

    Attributes:
        cfg: Decoder-specific configuration (type parameter ``T``).
        dataset_cfg: Dataset configuration given at construction time.
    """

    cfg: T
    dataset_cfg: DatasetCfg

    def __init__(self, cfg: T, dataset_cfg: DatasetCfg) -> None:
        super().__init__()
        self.cfg = cfg
        self.dataset_cfg = dataset_cfg

    @abstractmethod
    def forward(
        self,
        gaussians: Gaussians | GaussiansModule,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
        to_cpu: bool = False,
        iter_batch_size: int = -1,
    ) -> DecoderOutput:
        """Produce a :class:`DecoderOutput` for ``gaussians`` seen from the given cameras.

        Args:
            gaussians: A ``Gaussians`` container or a ``GaussiansModule``.
            extrinsics: Per-view camera extrinsics, ``(batch, view, 4, 4)``.
            intrinsics: Per-view camera intrinsics, ``(batch, view, 3, 3)``.
            near: Per-view near values, ``(batch, view)``.
            far: Per-view far values, ``(batch, view)``.
            image_shape: Output ``(height, width)``.
            depth_mode: Optional depth-rendering mode (see ``DepthRenderingMode``).
            to_cpu: Move outputs to CPU as they are rendered.
            iter_batch_size: ``-1`` renders all views at once; presumably a
                positive value renders views in chunks of that size. Declared
                here because :meth:`forward_batch` always passes it.
        """

    def forward_batch(
        self,
        gaussians: Gaussians | GaussiansModule,
        batch: BatchedExample,
        image_shape: tuple[int, int] | None = None,
        input_str: Literal["context", "target"] | None = None,
        eval_context_views: bool | None = None,
        depth_mode: DepthRenderingMode | None = None,
        start: int | None = None,
        end: int | None = None,
        camera_poses: Float[Tensor, "batch view 4 4"] | None = None,  # e.g. manipulated poses for stabilization
        to_cpu: bool = False,  # move outputs to cpu as they are rendered
        iter_batch_size: int = -1,  # -1 to render all views at once
    ) -> DecoderOutput:
        """Render the context or target views of ``batch`` (optionally a ``start:end`` slice).

        Args:
            gaussians: Gaussians forwarded to :meth:`forward`.
            batch: Example holding ``"context"`` and ``"target"`` view groups.
            image_shape: Output ``(height, width)``; derived from the views when ``None``.
            input_str: Which view group to render. Takes precedence over
                ``eval_context_views``.
            eval_context_views: Used only when ``input_str`` is ``None``: truthy
                selects ``"context"``, otherwise ``"target"``.
            depth_mode: Forwarded to :meth:`forward`.
            start, end: Python-slice bounds over the view axis (dim 1), applied to
                the poses, intrinsics, near and far tensors alike.
            camera_poses: Replacement for the views' ``"extrinsics"``. It is
                sliced with ``start:end`` too, so it must cover all views.
            to_cpu: Forwarded to :meth:`forward`.
            iter_batch_size: Forwarded to :meth:`forward`.

        Raises:
            AssertionError: If neither ``input_str`` nor ``eval_context_views`` is given.
        """
        assert input_str is not None or eval_context_views is not None, (
            "Pass input_str or eval_context_views to choose which views to render."
        )
        if input_str is None:
            input_str = "context" if eval_context_views else "target"
        views = batch[input_str]

        if image_shape is None:
            # NOTE(review): the original only read views["image_shape"]. Keep that
            # when present, else fall back to the image tensor like
            # forward_batch_subset does (avoids a KeyError).
            try:
                shape_source = views["image_shape"]
            except KeyError:
                shape_source = views["image"]
            image_shape = shape_source.shape[-2:]

        if camera_poses is None:
            camera_poses = views["extrinsics"]

        view_slice = slice(start, end)
        return self.forward(
            gaussians,
            camera_poses[:, view_slice],
            views["intrinsics"][:, view_slice],
            views["near"][:, view_slice],
            views["far"][:, view_slice],
            image_shape,
            depth_mode=depth_mode,
            to_cpu=to_cpu,
            iter_batch_size=iter_batch_size,
        )

    def forward_batch_subset(
        self,
        gaussians: Gaussians | GaussiansModule,
        batch_subset: BatchedViewsDict | BatchedViews,
        image_shape: tuple[int, int] | None = None,
        start: int | None = None,
        end: int | None = None,
        indices: torch.Tensor | list | None = None,
        **kwargs,
    ) -> DecoderOutput:
        """Render a selection of the views in ``batch_subset``.

        Views are selected either by ``start``/``end`` (Python-slice semantics
        over the view axis, as in :meth:`forward_batch`; either bound may be
        omitted) or by ``indices``, but not both. With neither, all views are
        rendered.

        Args:
            gaussians: Gaussians forwarded to :meth:`forward`.
            batch_subset: Views with ``"extrinsics"``, ``"intrinsics"``, ``"near"``,
                ``"far"`` and (when ``image_shape`` is ``None``) ``"image"``.
            image_shape: Output ``(height, width)``; defaults to
                ``batch_subset["image"].shape[-2:]``.
            start, end: Slice bounds over the view axis.
            indices: A list / 1-D tensor (same views for every scene) or a 2-D
                ``(scene_batch, num_indices)`` tensor (per-scene selection).
            **kwargs: Forwarded to :meth:`forward`.

        Raises:
            AssertionError: If both a slice bound and ``indices`` are given, or
                ``indices`` does not end up 2-D.
        """
        extrinsics = batch_subset["extrinsics"]
        scene_batch, num_views = extrinsics.shape[:2]

        use_slice = start is not None or end is not None
        assert not (use_slice and indices is not None), (
            "Provide either start/end or indices, not both."
        )
        if use_slice:
            indices = list(range(num_views))[start:end]
        elif indices is None:
            indices = list(range(num_views))

        if not isinstance(indices, torch.Tensor):
            # Explicit long dtype keeps an empty selection usable as an index tensor.
            indices = torch.as_tensor(list(indices), dtype=torch.long, device=extrinsics.device)
        if indices.dim() == 1:
            # Same selection for every scene: (num_indices,) -> (scene_batch, num_indices)
            indices = indices.unsqueeze(0).expand(scene_batch, -1)
        assert indices.dim() == 2, "Indices tensor must be 2D (scene_batch, num_indices)."

        if image_shape is None:
            image_shape = batch_subset["image"].shape[-2:]

        # Pair each selected view index with its scene row for advanced indexing.
        scene_idx = torch.arange(indices.size(0), device=indices.device)[:, None]  # (scene_batch, 1)
        return self.forward(
            gaussians,
            extrinsics[scene_idx, indices],
            batch_subset["intrinsics"][scene_idx, indices],
            batch_subset["near"][scene_idx, indices],
            batch_subset["far"][scene_idx, indices],
            image_shape,
            **kwargs,
        )

    def forward_context(
        self,
        gaussians: Gaussians | GaussiansModule,
        batch: BatchedExample,
        image_shape: tuple[int, int] | None = None,
        depth_mode: DepthRenderingMode | None = None,
        **kwargs,
    ) -> DecoderOutput:
        """Render the ``"context"`` views of ``batch``; ``kwargs`` go to :meth:`forward_batch`."""
        return self.forward_batch(
            gaussians,
            batch,
            image_shape,
            input_str="context",
            depth_mode=depth_mode,
            **kwargs,
        )

    def forward_target(
        self,
        gaussians: Gaussians | GaussiansModule,
        batch: BatchedExample,
        image_shape: tuple[int, int] | None = None,
        depth_mode: DepthRenderingMode | None = None,
        **kwargs,
    ) -> DecoderOutput:
        """Render the ``"target"`` views of ``batch``; ``kwargs`` go to :meth:`forward_batch`."""
        return self.forward_batch(
            gaussians,
            batch,
            image_shape,
            input_str="target",
            depth_mode=depth_mode,
            **kwargs,
        )