File size: 1,436 Bytes
5d2fa0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
from torchmetrics.classification import MulticlassAccuracy, MulticlassAveragePrecision


class MetricTracker:
    """Accumulate classification metrics and a sample-weighted loss over many batches.

    Wraps torchmetrics' ``MulticlassAveragePrecision`` (reported as ``"mAP"``)
    and ``MulticlassAccuracy``, both with their library-default averaging,
    plus a running, batch-size-weighted average of the loss.
    """

    def __init__(self, num_classes, device):
        """Create the metric objects on ``device`` and clear all state.

        Args:
            num_classes: Number of classes ``C`` covered by the logits.
            device: Device the torchmetrics states are moved to; presumably
                the device ``preds``/``targets`` live on when ``update`` is
                called — TODO confirm at call sites.
        """
        self.num_classes = num_classes
        self.device = device

        self.map_metric = MulticlassAveragePrecision(num_classes=num_classes).to(device)
        self.acc_metric = MulticlassAccuracy(num_classes=num_classes).to(device)

        self.reset()

    def reset(self):
        """Clear metric states and the running loss (e.g. between epochs)."""
        self.map_metric.reset()
        self.acc_metric.reset()
        self.loss_sum = 0.0
        self.count = 0
        # Number of batches actually fed to the metrics. If every batch was
        # skipped (all-MixUp/CutMix), compute() must not query the metrics.
        self.metric_count = 0

    def update(self, preds, targets, loss=None, skip_metrics=False):
        """Add one batch to the running statistics.

        Args:
            preds: Logits of shape ``[B, C]``.
            targets: Class indices ``[B]`` or soft labels ``[B, C]``; soft
                labels are reduced to their argmax class for the metrics.
            loss: Optional scalar loss for this batch (Python number or a
                single-element tensor). It is weighted by ``B``, so it is
                expected to be the batch *mean*.
            skip_metrics: If True, only the loss is tracked. Use for
                MixUp/CutMix batches.
        """
        batch_size = preds.size(0)

        if not skip_metrics:
            hard_targets = targets.argmax(dim=1) if targets.ndim > 1 else targets
            # The metrics only need values; never let them hold the logits' graph.
            detached_preds = preds.detach()
            self.map_metric.update(detached_preds, hard_targets)
            self.acc_metric.update(detached_preds, hard_targets)
            self.metric_count += 1

        if loss is not None:
            if isinstance(loss, torch.Tensor):
                # Detach so the running sum does not keep each batch's autograd
                # graph alive; stay on-device to avoid a host sync every step
                # (converted to a Python float once, in compute()).
                loss = loss.detach().float()
            self.loss_sum += loss * batch_size
            self.count += batch_size

    def compute(self):
        """Return the aggregated statistics as plain Python floats.

        Returns:
            dict with keys ``"mAP"``, ``"accuracy"`` and ``"loss"``.
            ``mAP`` and ``accuracy`` are NaN when no batch reached the metrics
            (every update used ``skip_metrics=True``); ``loss`` is 0.0 when no
            loss was recorded.
        """
        if self.metric_count > 0:
            mAP = self.map_metric.compute().item()
            acc = self.acc_metric.compute().item()
        else:
            # torchmetrics warns/raises when compute() runs on empty state;
            # report the metrics as undefined instead of crashing.
            mAP = acc = float("nan")
        avg_loss = float(self.loss_sum) / max(self.count, 1)
        return {"mAP": mAP, "accuracy": acc, "loss": avg_loss}