File size: 1,436 Bytes
5d2fa0b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 | import torch
from torchmetrics.classification import MulticlassAccuracy, MulticlassAveragePrecision
class MetricTracker:
    """Accumulate classification metrics and a sample-weighted mean loss.

    Wraps torchmetrics' multiclass average precision (reported as "mAP") and
    multiclass accuracy, both with torchmetrics' default averaging, plus a
    running per-sample mean of the loss. Call ``reset`` at the start of each
    epoch, ``update`` per batch and ``compute`` at the end.
    """

    def __init__(self, num_classes, device):
        """
        Args:
            num_classes: Number of target classes passed to both metrics.
            device: Device the metric states are moved to; tensors given to
                ``update`` are expected to live on the same device.
        """
        self.num_classes = num_classes
        self.device = device
        self.map_metric = MulticlassAveragePrecision(num_classes=num_classes).to(device)
        self.acc_metric = MulticlassAccuracy(num_classes=num_classes).to(device)
        self.reset()

    def reset(self):
        """Clear all accumulated metric state and loss statistics."""
        self.map_metric.reset()
        self.acc_metric.reset()
        self.loss_sum = 0
        # Number of samples that contributed a loss value (denominator of the mean).
        self.count = 0
        # Whether any batch has fed the torchmetrics objects since the last reset.
        # Computing them with no updates can raise (average precision has no
        # samples to concatenate), so compute() checks this flag first.
        self._metrics_updated = False

    def update(self, preds, targets, loss=None, skip_metrics=False):
        """Accumulate one batch.

        Args:
            preds: Logits [B, C].
            targets: Class indices [B] or soft labels [B, C]; soft labels are
                reduced to their argmax class before updating the metrics.
            loss: Optional scalar loss (Python number or single-element tensor),
                presumably the mean over the batch since it is weighted by B.
            skip_metrics: If True, only the loss is tracked. Use for MixUp/CutMix
                batches, whose mixed targets make the metrics meaningless.
        """
        if not skip_metrics:
            hard_targets = targets.argmax(dim=1) if targets.ndim > 1 else targets
            self.map_metric.update(preds, hard_targets)
            self.acc_metric.update(preds, hard_targets)
            self._metrics_updated = True
        if loss is not None:
            if isinstance(loss, torch.Tensor):
                # Detach so the running sum does not keep every batch's autograd
                # graph alive for the rest of the epoch.
                loss = loss.detach()
            batch_size = preds.size(0)
            self.loss_sum += loss * batch_size
            # Only samples that carried a loss count toward the loss average.
            self.count += batch_size

    def compute(self):
        """Return the accumulated metrics as plain Python floats.

        Returns:
            dict with keys "mAP", "accuracy" and "loss". "mAP" and "accuracy"
            are NaN when every batch since the last reset was skipped (e.g. an
            all-MixUp epoch); "loss" is 0.0 when no loss was ever supplied.
        """
        if self._metrics_updated:
            mAP = self.map_metric.compute().item()
            acc = self.acc_metric.compute().item()
        else:
            mAP = acc = float("nan")
        avg_loss = float(self.loss_sum) / max(self.count, 1)
        return {"mAP": mAP, "accuracy": acc, "loss": avg_loss}
|