File size: 381 Bytes
7349148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import abc

import torch
import torch.nn as nn
from jaxtyping import Float


class AbstractLoss(nn.Module, abc.ABC):
    """Interface for training losses on image-like tensors.

    Concrete subclasses implement :meth:`forward`, which compares a
    prediction against a ground-truth tensor of the same ``(B, C, H, W)``
    layout and reduces the result to a scalar. Losses that accumulate
    internal state across steps may override :meth:`reset`.
    """

    @abc.abstractmethod
    def forward(
        self,
        pred: Float[torch.Tensor, "B C H W"],
        gt: Float[torch.Tensor, "B C H W"],
        step: int,
        **kwargs,
    ) -> Float[torch.Tensor, ""]:
        """Compute the scalar loss between *pred* and *gt*.

        Args:
            pred: Predicted batch, shape ``(B, C, H, W)``.
            gt: Ground-truth batch, same shape as *pred*.
            step: Current training step (lets schedules weight the loss).
            **kwargs: Extra, subclass-specific inputs.

        Returns:
            A zero-dimensional (scalar) tensor.
        """
        ...

    def reset(self):
        """Clear any accumulated internal state. Default: no-op."""
        ...