Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import Tensor | |
| from torchmetrics import Metric, Accuracy | |
class AccuracyMine(Accuracy):
    """Accuracy metric that also accepts floating-point targets.

    Intended for Mixup-style training, where the target may arrive as a
    floating-point tensor instead of integer class indices. Such a target
    is reduced with an argmax over its last dimension before being handed
    to ``torchmetrics.Accuracy``; any non-floating target is forwarded
    unchanged.
    """

    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
        """Update the metric state with one batch.

        Args:
            preds: Predictions, passed to the parent ``update`` untouched.
            target: Ground truth. A floating-point tensor is collapsed with
                ``argmax(dim=-1)``; any other dtype is used as given.
        """
        # NOTE(review): assumes the class axis of a floating-point target
        # is the last one -- confirm against callers.
        if target.is_floating_point():
            hard_target = target.argmax(dim=-1)
        else:
            hard_target = target
        super().update(preds, hard_target)