| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from abc import ABC, abstractmethod |
|
|
| import torch |
| import torch.distributed as dist |
| from torch import Tensor |
|
|
|
|
class Metric(ABC):
    """Accumulating metric with cross-process synchronization, similar to TorchMetrics.

    Subclasses register state tensors via :meth:`add_state`, then implement
    :meth:`update` (accumulate statistics from one batch) and
    :meth:`_compute` (reduce the accumulated states to a scalar tensor).
    """

    def __init__(self):
        # Maps state name -> default tensor used to (re)initialize that state.
        self.states = {}

    def add_state(self, name: str, default: Tensor):
        """Register a new state tensor, initialized to ``default``.

        Args:
            name: attribute name under which the state is stored on ``self``.
            default: tensor value the state is reset to by :meth:`reset`.

        Raises:
            ValueError: if ``name`` was already registered.
        """
        # Explicit exception instead of `assert`: asserts are stripped
        # under `python -O`, silently allowing duplicate registration.
        if name in self.states:
            raise ValueError(f"state '{name}' is already registered")
        self.states[name] = default.clone()
        # Clone the live state as well so it never aliases the caller's
        # tensor (previously it aliased `default` until the first reset()).
        setattr(self, name, default.clone())

    def synchronize(self):
        """SUM-reduce every state tensor across ranks; no-op without torch.distributed."""
        if dist.is_initialized():
            for state in self.states:
                # all_reduce mutates the state tensor in place on every rank.
                dist.all_reduce(getattr(self, state), op=dist.ReduceOp.SUM, group=dist.group.WORLD)

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`update`."""
        self.update(*args, **kwargs)

    def reset(self):
        """Restore every state tensor to a fresh clone of its registered default."""
        for name, default in self.states.items():
            setattr(self, name, default.clone())

    def compute(self):
        """Synchronize across ranks, reduce to a Python number, then reset.

        Returns:
            The metric value as a Python scalar (via ``Tensor.item()``).
        """
        self.synchronize()
        value = self._compute().item()
        self.reset()
        return value

    @abstractmethod
    def _compute(self):
        """Reduce the accumulated states to a scalar tensor."""

    @abstractmethod
    def update(self, preds: Tensor, targets: Tensor):
        """Accumulate statistics from one batch of predictions and targets."""
|
|
|
|
class MeanAbsoluteError(Metric):
    """Mean absolute error accumulated over batches (and across ranks on compute)."""

    def __init__(self, device=None):
        """
        Args:
            device: device for the accumulator tensors. Defaults to CUDA when
                available (matching the previous hard-coded ``'cuda'``),
                otherwise CPU — so the metric no longer crashes on CPU-only
                machines.
        """
        super().__init__()
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.add_state('error', torch.tensor(0, dtype=torch.float32, device=device))
        self.add_state('total', torch.tensor(0, dtype=torch.int32, device=device))

    def update(self, preds: Tensor, targets: Tensor):
        """Accumulate the summed absolute error and the sample count.

        Each sample (row of the batch) is flattened; the absolute differences
        are summed over ALL elements, while ``total`` counts only samples —
        so ``compute`` yields summed-error-per-sample, not per-element.
        NOTE(review): assumes ``preds``/``targets`` live on the same device as
        the state tensors — confirm against callers.
        """
        preds = preds.detach()  # drop autograd history; accumulation is not differentiated
        n = preds.shape[0]
        error = torch.abs(preds.view(n, -1) - targets.view(n, -1)).sum()
        self.total += n
        self.error += error

    def _compute(self):
        # Division by zero if compute() is called before any update();
        # original behavior kept.
        return self.error / self.total
|
|