Spaces:
Runtime error
Runtime error
| from abc import ABC, abstractmethod | |
| from dataclasses import fields | |
| from typing import Generic, TypeVar | |
| from jaxtyping import Float | |
| from torch import Tensor, nn | |
| from src.dataset.types import BatchedExample | |
| from src.model.decoder.decoder import DecoderOutput | |
| from src.model.types import Gaussians | |
# Type parameters for ``Loss``: ``T_cfg`` is the unwrapped per-loss config
# stored on ``self.cfg``; ``T_wrapper`` is the single-field dataclass that
# ``Loss.__init__`` receives and unwraps.
T_cfg, T_wrapper = TypeVar("T_cfg"), TypeVar("T_wrapper")
class Loss(nn.Module, ABC, Generic[T_cfg, T_wrapper]):
    """Abstract base class for losses configured through a one-field wrapper.

    The constructor receives a dataclass instance (the *wrapper*) that has
    exactly one field. That field's value becomes ``self.cfg`` and its name
    becomes ``self.name``. Subclasses must implement ``forward``.
    """

    cfg: T_cfg
    name: str

    def __init__(self, cfg: T_wrapper) -> None:
        """Unwrap the single-field config wrapper.

        Args:
            cfg: Dataclass instance with exactly one field; that field holds
                this loss's configuration.

        Raises:
            TypeError: If ``cfg`` is not a dataclass instance (raised by
                ``dataclasses.fields``).
            ValueError: If the wrapper does not have exactly one field.
        """
        super().__init__()

        # Extract the configuration from the wrapper. The wrapper must carry
        # exactly one field; its name doubles as the loss's identifier.
        wrapper_fields = fields(type(cfg))
        if len(wrapper_fields) != 1:
            raise ValueError(
                f"{type(self).__name__} expects a config wrapper with exactly "
                f"one field, but {type(cfg).__name__} has "
                f"{len(wrapper_fields)}."
            )
        (field,) = wrapper_fields
        self.cfg = getattr(cfg, field.name)
        self.name = field.name

    # Abstract so that a subclass which forgets to override ``forward`` fails
    # at instantiation instead of silently returning ``None`` as its loss.
    @abstractmethod
    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Compute this loss as a scalar (0-dimensional) float tensor.

        Args:
            prediction: Decoder output to be scored.
            batch: The batched example the prediction was made for.
            gaussians: Gaussian parameters associated with the prediction.
            depth_dict: Extra depth-related data; its keys are not visible
                here (NOTE(review): confirm the schema against callers).
            global_step: Current step counter, presumably the training step
                (e.g. for loss warm-up or scheduling in subclasses).

        Returns:
            Scalar loss tensor.
        """