Spaces:
Sleeping
Sleeping
File size: 6,073 Bytes
78d2329 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 | from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, Literal, TypeVar
import torch
from jaxtyping import Float, Int32, Bool, UInt8
from torch import Tensor, nn
from ..types import Gaussians
from ...dataset import DatasetCfg
from ...dataset.data_types import BatchedViews, BatchedViewsDict, BatchedExample
from ...scene_trainer.gaussian_module import GaussiansModule
# Allowed values for the ``depth_mode`` argument of the Decoder render methods.
DepthRenderingMode = Literal["depth", "log", "disparity", "relative_disparity"]
@dataclass
class DecoderOutput:
    """Result container returned by ``Decoder.forward``.

    Only ``color`` and ``depth`` have no default. Every other field defaults
    to ``None``; presumably decoders fill only the outputs they actually
    compute (NOTE(review): confirm against concrete decoders).
    Shapes below are taken from the jaxtyping annotations.
    """

    # Rendered images, either floating point or uint8: (batch, view, 3, H, W).
    color: Float[Tensor, "batch view 3 height width"] | UInt8[Tensor, "batch view 3 height width"]
    # Rendered depth map (batch, view, H, W), or None when not produced.
    depth: Float[Tensor, "batch view height width"] | None
    # Optional 3-channel normal map (batch, view, 3, H, W).
    normal: Float[Tensor, "batch view 3 height width"] | None = None
    # Optional single-channel distortion map (batch, view, H, W).
    distortion_map: Float[Tensor, "batch view height width"] | None = None
    # Optional accumulated alpha / opacity map (batch, view, H, W).
    accumulated_alpha: Float[Tensor, "batch view height width"] | None = None
    # Optional int32 pair per view and per element n.
    # NOTE(review): presumably per-Gaussian screen-space radii — confirm.
    radii: Int32[Tensor, "batch view n 2"] | None = None
    # Optional float 2D coordinates per view and per element n.
    # NOTE(review): presumably projected Gaussian centers — confirm.
    means2d: Float[Tensor, "batch view n 2"] | None = None
    # Optional boolean mask per view and per element n.
    # NOTE(review): presumably marks Gaussians visible in each view — confirm.
    visibility_filter: Bool[Tensor, "batch view n"] | None = None
# Type of the decoder-specific configuration stored on ``Decoder.cfg``.
T = TypeVar("T")
class Decoder(nn.Module, ABC, Generic[T]):
    """Abstract base for modules that render Gaussians from a set of cameras.

    Subclasses implement :meth:`forward`. The ``forward_*`` helpers only pick
    camera parameters out of a batch (or out of a subset of views) and
    delegate to :meth:`forward`.
    """

    cfg: T
    dataset_cfg: DatasetCfg

    def __init__(self, cfg: T, dataset_cfg: DatasetCfg) -> None:
        super().__init__()
        self.cfg = cfg
        self.dataset_cfg = dataset_cfg

    @abstractmethod
    def forward(
        self,
        gaussians: Gaussians | GaussiansModule,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
        to_cpu: bool = False,
        iter_batch_size: int = -1,
    ) -> DecoderOutput:
        """Render ``gaussians`` from every camera in the batch.

        Args:
            gaussians: Scene representation to render.
            extrinsics: Per-view camera poses.
            intrinsics: Per-view camera intrinsics.
            near: Per-view near bound.
            far: Per-view far bound.
            image_shape: Output ``(height, width)``.
            depth_mode: Optional depth rendering mode (implementation-defined
                meaning for ``None``).
            to_cpu: Move outputs to the CPU as they are rendered.
            iter_batch_size: Number of views rendered per step; ``-1`` renders
                all views at once. :meth:`forward_batch` always passes this
                argument, so implementations must accept it.

        Returns:
            The rendered :class:`DecoderOutput`.
        """

    @staticmethod
    def _image_shape_of(views):
        """Return the trailing ``(height, width)`` dims for a views dict."""
        # Prefer the historical "image_shape" entry; fall back to the image
        # tensor itself (as forward_batch_subset does) instead of failing with
        # KeyError on batches that only carry "image".
        try:
            reference = views["image_shape"]
        except KeyError:
            reference = views["image"]
        return reference.shape[-2:]

    def forward_batch(
        self,
        gaussians: Gaussians | GaussiansModule,
        batch: BatchedExample,
        image_shape: tuple[int, int] | None = None,
        input_str: Literal["context", "target"] | None = None,
        eval_context_views: bool | None = None,
        depth_mode: DepthRenderingMode | None = None,
        start: int | None = None,
        end: int | None = None,
        camera_poses: Tensor | None = None,
        to_cpu: bool = False,
        iter_batch_size: int = -1,
    ) -> DecoderOutput:
        """Render the context or target views of ``batch``.

        Args:
            gaussians: Scene representation to render.
            batch: Batch holding ``"context"`` and ``"target"`` view dicts.
            image_shape: Output ``(height, width)``; derived from the selected
                views when ``None``.
            input_str: Which views to render. When ``None``, chosen from
                ``eval_context_views`` (truthy -> ``"context"``).
            eval_context_views: Fallback selector used when ``input_str`` is
                ``None``. At least one of the two selectors must be given.
            depth_mode: Forwarded to :meth:`forward`.
            start, end: Slice applied to the view axis (``[:, start:end]``).
            camera_poses: Replacement extrinsics, e.g. when manipulating
                camera poses for stabilization; defaults to the batch's own.
            to_cpu: Move outputs to the CPU as they are rendered.
            iter_batch_size: Views per render step; ``-1`` renders all at once.

        Returns:
            The :class:`DecoderOutput` produced by :meth:`forward`.
        """
        assert input_str is not None or eval_context_views is not None, (
            "Either input_str or eval_context_views must be provided."
        )
        if input_str is None:
            input_str = "context" if eval_context_views else "target"
        views = batch[input_str]
        if image_shape is None:
            image_shape = self._image_shape_of(views)
        if camera_poses is None:
            camera_poses = views["extrinsics"]
        return self.forward(
            gaussians,
            camera_poses[:, start:end],
            views["intrinsics"][:, start:end],
            views["near"][:, start:end],
            views["far"][:, start:end],
            image_shape,
            depth_mode=depth_mode,
            to_cpu=to_cpu,
            iter_batch_size=iter_batch_size,
        )

    def forward_batch_subset(
        self,
        gaussians: Gaussians | GaussiansModule,
        batch_subset: BatchedViewsDict | BatchedViews,
        image_shape: tuple[int, int] | None = None,
        start: int | None = None,
        end: int | None = None,
        indices: torch.Tensor | list | None = None,
        **kwargs,
    ) -> DecoderOutput:
        """Render a selection of views from a single views dict.

        Views are selected by either a ``start``/``end`` range or explicit
        ``indices``, never both. With neither, all views are rendered.

        Args:
            gaussians: Scene representation to render.
            batch_subset: Views dict with ``extrinsics``, ``intrinsics``,
                ``near``, ``far`` (and ``image`` when ``image_shape`` is None).
            image_shape: Output ``(height, width)``; taken from
                ``batch_subset["image"]`` when ``None``.
            start, end: Range of view indices; a missing bound defaults to 0
                or to the number of views respectively.
            indices: Either a sequence of view indices or a 1-D tensor shared
                by every scene, or a 2-D ``(scene_batch, num_indices)`` tensor.
            **kwargs: Forwarded to :meth:`forward`.

        Returns:
            The :class:`DecoderOutput` produced by :meth:`forward`.

        Raises:
            AssertionError: If both a range and ``indices`` are given, or the
                resulting index tensor is not 2-D.
        """
        extrinsics = batch_subset["extrinsics"]
        scene_batch, num_views = extrinsics.shape[0], extrinsics.shape[1]

        use_range = start is not None or end is not None
        assert not (use_range and indices is not None), (
            "Provide either start/end or indices, not both."
        )
        if use_range:
            first = 0 if start is None else start
            stop = num_views if end is None else end
            indices = list(range(first, stop))
        elif indices is None:
            indices = list(range(num_views))

        if not isinstance(indices, Tensor):
            # dtype is forced so that an empty selection still yields an
            # integer (indexable) tensor rather than torch's float default.
            indices = torch.tensor(list(indices), dtype=torch.long, device=extrinsics.device)
        if indices.dim() == 1:
            # One shared selection for every scene in the batch.
            indices = indices.unsqueeze(0).expand(scene_batch, -1)  # (batch, num_indices)
        assert indices.dim() == 2, "Indices tensor must be 2D (scene_batch, num_indices)."

        if image_shape is None:
            image_shape = batch_subset["image"].shape[-2:]

        # Row index per scene, broadcast against ``indices`` for gathering.
        scene_batch_idx = torch.arange(indices.size(0), device=indices.device)[:, None]  # (batch, 1)
        return self.forward(
            gaussians,
            extrinsics[scene_batch_idx, indices],
            batch_subset["intrinsics"][scene_batch_idx, indices],
            batch_subset["near"][scene_batch_idx, indices],
            batch_subset["far"][scene_batch_idx, indices],
            image_shape,
            **kwargs,
        )

    def forward_context(
        self,
        gaussians: Gaussians | GaussiansModule,
        batch: BatchedExample,
        image_shape: tuple[int, int] | None = None,
        depth_mode: DepthRenderingMode | None = None,
        **kwargs,
    ) -> DecoderOutput:
        """Render the ``"context"`` views of ``batch`` via :meth:`forward_batch`."""
        return self.forward_batch(
            gaussians,
            batch,
            image_shape,
            "context",
            depth_mode=depth_mode,
            **kwargs,
        )

    def forward_target(
        self,
        gaussians: Gaussians | GaussiansModule,
        batch: BatchedExample,
        image_shape: tuple[int, int] | None = None,
        depth_mode: DepthRenderingMode | None = None,
        **kwargs,
    ) -> DecoderOutput:
        """Render the ``"target"`` views of ``batch`` via :meth:`forward_batch`."""
        return self.forward_batch(
            gaussians,
            batch,
            image_shape,
            "target",
            depth_mode=depth_mode,
            **kwargs,
        )
|