Spaces:
Runtime error
Runtime error
| from abc import ABC, abstractmethod | |
| from typing import Generic, TypeVar | |
| from torch import nn | |
| from dataclasses import dataclass | |
| from src.dataset.types import BatchedViews, DataShim | |
| from ..types import Gaussians | |
| from jaxtyping import Float | |
| from torch import Tensor, nn | |
# Type variable for the configuration object an Encoder is parameterised by
# (used below as Generic[T] and as the type of Encoder.cfg).
T = TypeVar("T")
@dataclass
class EncoderOutput:
    """Record type grouping an encoder's outputs.

    Declared as a dataclass so the annotated fields below become constructor
    arguments and the class gets a generated ``__repr__``/``__eq__``. Only
    ``gaussians`` is annotated as always present; every other field may be
    ``None``.

    NOTE(review): nothing in this file constructs or returns this type, and
    the dict schemas are not visible here - confirm field contents against
    the code that produces them.
    """

    # The Gaussians value carried by this output (project-declared type).
    gaussians: Gaussians
    # List of float tensors annotated with shape "batch view 6", or None.
    # NOTE(review): presumably pose encodings, going by the name - confirm.
    pred_pose_enc_list: list[Float[Tensor, "batch view 6"]] | None
    # Optional dict; keys/values not shown in this file.
    pred_context_pose: dict | None
    # Optional dict; keys/values not shown in this file.
    depth_dict: dict | None
    # Optional dict; keys/values not shown in this file.
    infos: dict | None
    # Optional dict; keys/values not shown in this file.
    distill_infos: dict | None
class Encoder(nn.Module, ABC, Generic[T]):
    """Abstract ``nn.Module`` base class for encoders, generic over config type ``T``.

    Subclasses must implement ``forward``. ``get_data_shim`` has a default
    implementation and only needs overriding when the batch must be adapted.
    """

    # Configuration object handed to the constructor.
    cfg: T

    def __init__(self, cfg: T) -> None:
        """Initialise the ``nn.Module`` base and store ``cfg`` on the instance."""
        super().__init__()
        self.cfg = cfg

    @abstractmethod
    def forward(
        self,
        context: BatchedViews,
    ) -> Gaussians:
        """Map ``context`` to ``Gaussians``; must be implemented by subclasses.

        Marked abstract so that a subclass which does not define ``forward``
        fails at instantiation instead of silently returning ``None``.
        """

    def get_data_shim(self) -> DataShim:
        """The default shim doesn't modify the batch."""
        return lambda x: x