Spaces:
Runtime error
Runtime error
| import abc | |
| import torch | |
| import torch.nn as nn | |
| from jaxtyping import Float | |
class AbstractLoss(nn.Module, abc.ABC):
    """Abstract base class for loss modules comparing a prediction to a target.

    Concrete losses must override :meth:`forward`; because it is declared
    abstract, a subclass that omits it cannot be instantiated (and neither
    can this base class). :meth:`reset` is an optional hook whose default
    implementation does nothing.
    """

    @abc.abstractmethod
    def forward(
        self,
        pred: Float[torch.Tensor, "B C H W"],
        gt: Float[torch.Tensor, "B C H W"],
        step: int,
        **kwargs,
    ) -> Float[torch.Tensor, ""]:
        """Compute the loss between ``pred`` and ``gt``.

        Args:
            pred: Prediction, annotated as a float tensor of shape
                ``(B, C, H, W)``.
            gt: Target, annotated with the same ``(B, C, H, W)`` shape.
            step: Integer step supplied by the caller (presumably the
                training iteration -- confirm against call sites).
            **kwargs: Extra keyword arguments a concrete loss may accept.

        Returns:
            The loss value, annotated as a scalar (zero-dimensional) float
            tensor.
        """

    def reset(self):
        """Optional hook; the base implementation does nothing.

        NOTE(review): presumably intended for subclasses that keep state
        between calls to clear it -- confirm against subclasses.
        """