Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,037 Bytes
a6e928c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from torch import nn
from dataclasses import dataclass
from src.dataset.types import BatchedViews, DataShim
from ..types import Gaussians
from jaxtyping import Float
from torch import Tensor, nn
T = TypeVar("T")  # Type variable for the encoder's configuration object (the type of Encoder.cfg).
@dataclass()
class EncoderOutput:
    """Container for the results an encoder produces.

    ``gaussians`` is typed as non-optional; every other field is annotated
    ``... | None`` and may therefore be left empty by a given encoder. All
    fields are required constructor arguments (none has a default), so callers
    must pass ``None`` explicitly for anything they do not produce.
    """

    # Main output of the encoder.
    gaussians: Gaussians

    # List of per-step predictions, each annotated as shape (batch, view, 6).
    pred_pose_enc_list: list[Float[Tensor, "batch view 6"]] | None

    # Auxiliary dictionaries; their key schema is not specified in this module.
    pred_context_pose: dict | None
    depth_dict: dict | None
    infos: dict | None
    distill_infos: dict | None

    # Dense per-pixel tensors. Per the annotations, pts_all carries 3 values per
    # pixel (presumably 3D points) and conf one scalar per pixel (presumably a
    # confidence score) — confirm against the producing encoder.
    pts_all: Float[Tensor, "batch view height width 3"] | None
    conf: Float[Tensor, "batch view height width"] | None
class Encoder(nn.Module, ABC, Generic[T]):
    """Abstract torch module that encodes batched views, generic over its config type.

    Concrete subclasses must override :meth:`forward`. The configuration object
    given to the constructor is stored unchanged as ``self.cfg``.
    """

    cfg: T

    def __init__(self, cfg: T) -> None:
        # Run nn.Module's initialiser first so its internal state exists before
        # any attribute is set on the module.
        super().__init__()
        self.cfg = cfg

    @abstractmethod
    def forward(
        self,
        context: BatchedViews,
    ) -> Gaussians:
        """Encode ``context``; must be implemented by every concrete encoder."""
        # NOTE(review): annotated as returning Gaussians, yet this module also
        # defines EncoderOutput (which wraps Gaussians) — confirm which type
        # subclasses actually return and align the annotation.
        ...

    def get_data_shim(self) -> DataShim:
        """Return the data shim for this encoder.

        The default shim doesn't modify the batch: it returns its argument as-is.
        """

        def _passthrough(batch):
            return batch

        return _passthrough
|