| from abc import ABC, abstractmethod |
| from dataclasses import fields |
| from typing import Generic, TypeVar |
|
|
| from jaxtyping import Float |
| from torch import Tensor, nn |
|
|
| from src.dataset.types import BatchedExample |
| from src.model.decoder.decoder import DecoderOutput |
| from src.model.types import Gaussians |
|
|
# Type of the loss-specific configuration that `Loss.__init__` unwraps and
# stores as `Loss.cfg`.
T_cfg = TypeVar("T_cfg")
# Type of the single-field dataclass passed to `Loss.__init__`; its one field
# holds the `T_cfg` value and its field name becomes `Loss.name`.
T_wrapper = TypeVar("T_wrapper")
|
|
|
|
class Loss(nn.Module, ABC, Generic[T_cfg, T_wrapper]):
    """Abstract base class for losses.

    A concrete loss is constructed from a *wrapper* dataclass that has exactly
    one field. The value of that field becomes ``self.cfg`` (the loss-specific
    configuration) and the field's name becomes ``self.name``, so each loss is
    identified by the key it was configured under.

    Subclasses must implement :meth:`forward`.
    """

    cfg: T_cfg  # Loss-specific configuration unwrapped from the wrapper.
    name: str  # Name of the wrapper dataclass's single field.

    def __init__(self, cfg: T_wrapper) -> None:
        """Unwrap the single-field configuration dataclass.

        Args:
            cfg: A dataclass instance with exactly one field, whose value is
                this loss's configuration.

        Raises:
            TypeError: If ``cfg`` is not a dataclass instance (raised by
                ``dataclasses.fields``).
            ValueError: If the wrapper dataclass does not have exactly one
                field.
        """
        super().__init__()

        # The wrapper's single field both carries the config and names the
        # loss, so anything other than exactly one field is ambiguous. Raise
        # ValueError (the same type a failed tuple-unpack would raise) but
        # with a message that says what is actually wrong.
        wrapper_fields = fields(type(cfg))
        if len(wrapper_fields) != 1:
            field_names = [f.name for f in wrapper_fields]
            raise ValueError(
                f"{type(self).__name__} expects a config wrapper dataclass with "
                f"exactly one field, but {type(cfg).__name__} has "
                f"{len(wrapper_fields)}: {field_names}"
            )

        (wrapper_field,) = wrapper_fields
        self.cfg = getattr(cfg, wrapper_field.name)
        self.name = wrapper_field.name

    @abstractmethod
    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Compute this loss.

        Args:
            prediction: Output of the decoder.
            batch: The batched example being trained on.
            gaussians: Gaussians associated with this batch.
            depth_dict: Additional depth-related values; the key schema is
                defined by the caller. NOTE(review): confirm expected keys.
            global_step: Current global training step, available to
                subclasses (e.g. for step-dependent weighting).

        Returns:
            A scalar (0-d) tensor holding the loss value.
        """
|
|