| | from abc import ABC, abstractmethod |
| | from typing import Generic, TypeVar |
| |
|
| | from torch import nn |
| | from dataclasses import dataclass |
| | from src.dataset.types import BatchedViews, DataShim |
| | from ..types import Gaussians |
| | from jaxtyping import Float |
| | from torch import Tensor, nn |
| |
|
# Type variable for the configuration object an ``Encoder`` is generic over
# (used below as ``Generic[T]`` and as the type of ``Encoder.cfg``).
T = TypeVar("T")
| |
|
@dataclass
class EncoderOutput:
    """Bundle of values produced by an encoder.

    Only ``gaussians`` is required. Every other field is optional and
    defaults to ``None``, so producers that do not compute a given extra can
    simply omit it instead of passing ``None`` explicitly. Field order is
    unchanged, so existing positional and keyword construction keeps working.

    Attributes:
        gaussians: The predicted ``Gaussians``.
        pred_pose_enc_list: Optional list of float tensors annotated with the
            shape ``"batch view 6"``.
        pred_context_pose: Optional dict; schema is not defined in this
            module (presumably predicted poses for the context views --
            TODO confirm against the producer).
        depth_dict: Optional dict; schema is not defined in this module
            (presumably depth-related outputs -- TODO confirm).
        infos: Optional dict of additional information; schema is not
            defined in this module.
        distill_infos: Optional dict; schema is not defined in this module
            (presumably distillation-related values -- TODO confirm).
    """

    gaussians: Gaussians
    pred_pose_enc_list: list[Float[Tensor, "batch view 6"]] | None = None
    pred_context_pose: dict | None = None
    depth_dict: dict | None = None
    infos: dict | None = None
    distill_infos: dict | None = None
| |
|
class Encoder(nn.Module, ABC, Generic[T]):
    """Abstract base class for encoder modules, generic over the config type.

    The constructor stores the configuration object on ``self.cfg``.
    Concrete subclasses must implement :meth:`forward`.
    """

    cfg: T

    def __init__(self, cfg: T) -> None:
        """Initialise the underlying module and keep ``cfg`` on the instance."""
        # The parent initialiser runs before any attribute is assigned.
        super().__init__()
        self.cfg = cfg

    @abstractmethod
    def forward(self, context: BatchedViews) -> Gaussians:
        """Encode the given context views; must be overridden by subclasses.

        NOTE(review): the return annotation is ``Gaussians``, yet this module
        also defines ``EncoderOutput`` (which wraps a ``Gaussians``) and never
        uses it here -- confirm which type concrete encoders actually return.
        """

    def get_data_shim(self) -> DataShim:
        """Return the data shim for this encoder.

        The default shim is the identity: it hands back its argument
        unchanged.
        """

        def _identity(batch):
            return batch

        return _identity
| |
|