# Yeqing0814's picture
# Upload folder using huggingface_hub
# a6dd040 verified
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, Literal, TypeVar
from jaxtyping import Float
from torch import Tensor, nn
from ...dataset import DatasetCfg
from ..types import Gaussians
# Allowed string values for the ``depth_mode`` argument of ``Decoder.forward``
# (which also accepts ``None``).
DepthRenderingMode = Literal["depth", "log", "disparity", "relative_disparity"]
@dataclass
class DecoderOutput:
    """Container for the tensors produced by ``Decoder.forward``.

    Shapes follow the jaxtyping annotations: ``color`` is
    (batch, view, 3, height, width) and ``depth`` is
    (batch, view, height, width), or ``None``.
    """

    # Three-channel images, one per (batch, view) pair.
    color: Float[Tensor, "batch view 3 height width"]
    # Per-view depth maps; may be None (presumably when depth was not
    # requested — confirm against Decoder implementations).
    depth: Float[Tensor, "batch view height width"] | None
T = TypeVar("T")


class Decoder(nn.Module, ABC, Generic[T]):
    """Abstract base for modules that turn ``Gaussians`` into a ``DecoderOutput``.

    Subclasses choose their config type via the ``T`` type parameter and
    must implement :meth:`forward`. Because ``forward`` is an
    ``abstractmethod`` and the class derives from ``ABC``, ``Decoder``
    itself cannot be instantiated.
    """

    # Decoder-specific configuration object (its type is the ``T`` parameter).
    cfg: T
    # Dataset configuration supplied at construction time.
    dataset_cfg: DatasetCfg

    def __init__(self, cfg: T, dataset_cfg: DatasetCfg) -> None:
        # nn.Module's own initializer must run before attributes are assigned.
        super().__init__()
        self.cfg, self.dataset_cfg = cfg, dataset_cfg

    @abstractmethod
    def forward(
        self,
        gaussians: Gaussians,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
    ) -> DecoderOutput:
        """Decode ``gaussians`` for every (batch, view) pair.

        Args:
            gaussians: The Gaussians to decode.
            extrinsics: Per-view 4x4 extrinsic matrices.
            intrinsics: Per-view 3x3 intrinsic matrices.
            near: Per-view near bound.
            far: Per-view far bound.
            image_shape: Output image size as a pair of ints; presumably
                (height, width) — confirm against implementations.
            depth_mode: Optional depth variant to produce; defaults to None.

        Returns:
            A ``DecoderOutput`` with the color tensor and an optional depth tensor.
        """
        ...