from dataclasses import dataclass, field
from typing import Any, Dict, Mapping

import lightning.pytorch as pl
import torch
from torch import Tensor, nn
from torchmetrics import classification

from utils import parse_structure


@dataclass
class BaseMetricsConfig:
    """Configuration for :class:`BaseMetrics`.

    The two lists are paired element by element: ``metrics_names[i]`` is the
    class looked up in ``torchmetrics.classification`` and
    ``metrics_short_names[i]`` is the key it is stored and reported under.
    """

    # Attribute names resolved on ``torchmetrics.classification``.
    metrics_names: list = field(default_factory=list)
    # Keys used in the metrics ModuleDict and in the reported result names.
    metrics_short_names: list = field(default_factory=list)


class BaseMetrics(pl.LightningModule):
    """Container that builds torchmetrics classification metrics from a config.

    Each configured metric is instantiated with no constructor arguments and
    registered in an ``nn.ModuleDict`` under its short name.
    """

    def __init__(self, cfg: Dict, *args: Any, **kwargs: Any) -> None:
        """Build the metric collection described by ``cfg``.

        Args:
            cfg: Raw configuration handed to ``parse_structure`` together with
                ``BaseMetricsConfig``.
            *args: Forwarded to ``pl.LightningModule.__init__``.
            **kwargs: Forwarded to ``pl.LightningModule.__init__``.

        Raises:
            ValueError: If ``metrics_names`` and ``metrics_short_names`` have
                different lengths.
            AttributeError: If a metric name does not exist in
                ``torchmetrics.classification``.
        """
        super().__init__(*args, **kwargs)

        # NOTE(review): parse_structure is a project helper; presumably it
        # validates/converts cfg into a BaseMetricsConfig -- confirm.
        self.cfg: BaseMetricsConfig = parse_structure(BaseMetricsConfig, cfg)
        self.metrics_names = self.cfg.metrics_names
        self.metrics_short_names = self.cfg.metrics_short_names

        # zip() stops at the shorter list, so a mismatch would silently drop
        # metrics; fail loudly on a malformed config instead.
        n_names = len(self.metrics_names)
        n_short = len(self.metrics_short_names)
        if n_names != n_short:
            raise ValueError(
                "metrics_names and metrics_short_names must have the same "
                f"length, got {n_names} and {n_short}"
            )

        self.metrics = nn.ModuleDict()
        for name, short_name in zip(self.metrics_names, self.metrics_short_names):
            metric_cls = getattr(classification, name)
            # Metrics are created with their default constructor arguments.
            self.metrics[short_name] = metric_cls()

        print(f"[INFO]: Metrics: {self.metrics}")

    # NOTE: this overrides nn.Module.__call__ directly, so the usual
    # forward() dispatch and module hooks are not run for this module.
    def __call__(self, pred: Tensor, target: Tensor, prefix: str) -> Dict[str, Any]:
        """Evaluate every configured metric on one batch of predictions.

        ``pred`` is passed through a sigmoid and rounded to the nearest
        integer before being handed to the metrics.

        Args:
            pred: Raw model outputs (presumably logits -- confirm with caller).
            target: Ground-truth values passed unchanged to each metric.
            prefix: String prepended to every result key as ``"{prefix}/"``.

        Returns:
            Mapping from ``"{prefix}/{short_name}"`` to the value returned by
            the corresponding metric call.
        """
        pred = torch.sigmoid(pred).round()
        return {
            f'{prefix}/{name}': metric(pred, target)
            for name, metric in self.metrics.items()
        }