# Source: hispath/metrics/base.py
# Author: kohido — commit 8bf25c8 ("init")
from torchmetrics import classification
from dataclasses import dataclass, field
from typing import Any, Dict, Mapping
from utils import parse_structure
from torch import Tensor, nn
import lightning.pytorch as pl
import torch
@dataclass
class BaseMetricsConfig:
    """Configuration consumed by ``BaseMetrics``.

    The two lists are read pairwise: entry ``i`` of ``metrics_names`` is a
    class name looked up in ``torchmetrics.classification`` and instantiated
    with no arguments, and entry ``i`` of ``metrics_short_names`` is the key
    it is stored and reported under.
    """

    # Class names inside ``torchmetrics.classification``; each must be
    # constructible without arguments.
    metrics_names: list = field(default_factory=lambda: [])
    # Reporting keys, matched positionally with ``metrics_names``.
    metrics_short_names: list = field(default_factory=lambda: [])
class BaseMetrics(pl.LightningModule):
    """Config-driven collection of ``torchmetrics.classification`` metrics.

    Each configured class name is instantiated with no arguments and stored
    in an ``nn.ModuleDict`` under its short name, which registers the metrics
    as sub-modules of this module.
    """

    def __init__(self, cfg: Dict, *args: Any, **kwargs: Any) -> None:
        """Build the metric collection from ``cfg``.

        Args:
            cfg: Raw configuration, converted to a ``BaseMetricsConfig`` via
                ``parse_structure``.
            *args: Forwarded to ``pl.LightningModule.__init__``.
            **kwargs: Forwarded to ``pl.LightningModule.__init__``.

        Raises:
            ValueError: If ``metrics_names`` and ``metrics_short_names`` have
                different lengths.
            AttributeError: If a name is not an attribute of
                ``torchmetrics.classification``.
        """
        super().__init__(*args, **kwargs)
        self.cfg: BaseMetricsConfig = parse_structure(BaseMetricsConfig, cfg)
        self.metrics_names = self.cfg.metrics_names
        self.metrics_short_names = self.cfg.metrics_short_names

        # zip() below would silently drop the unmatched tail, so a
        # misconfigured metric (or key) would just never be computed.
        n_names = len(self.metrics_names)
        n_short = len(self.metrics_short_names)
        if n_names != n_short:
            raise ValueError(
                "metrics_names and metrics_short_names must have the same "
                f"length, got {n_names} and {n_short}"
            )

        self.metrics = nn.ModuleDict()
        for name, short_name in zip(self.metrics_names, self.metrics_short_names):
            metric_cls = getattr(classification, name)
            self.metrics[short_name] = metric_cls()
        print(f"[INFO]: Metrics: {self.metrics}")

    def forward(self, pred: Tensor, target: Tensor, prefix: str) -> Dict[str, Tensor]:
        """Binarize ``pred`` and evaluate every configured metric on it.

        Implemented as ``forward`` (rather than overriding ``__call__``) so
        ``nn.Module.__call__`` and its hooks dispatch here; calling the module
        as ``metrics(pred, target, prefix)`` works exactly as before.

        Args:
            pred: Scores to which ``torch.sigmoid`` is applied (presumably
                logits — confirm against callers).
            target: Ground truth passed unchanged to each metric.
            prefix: Namespace prepended to every returned key.

        Returns:
            Mapping ``"{prefix}/{short_name}"`` -> the value returned by
            calling that metric on ``(binarized pred, target)``.
        """
        # sigmoid + round maps scores to {0., 1.}; torch.round rounds half to
        # even, so a score of exactly 0 (sigmoid == 0.5) becomes 0.
        pred = torch.sigmoid(pred).round()
        return {f'{prefix}/{name}': metric(pred, target) for name, metric in self.metrics.items()}