|
|
from torchmetrics import classification |
|
|
from dataclasses import dataclass, field |
|
|
from typing import Any, Dict, Mapping |
|
|
from utils import parse_structure |
|
|
from torch import Tensor, nn |
|
|
|
|
|
import lightning.pytorch as pl |
|
|
import torch |
|
|
|
|
|
@dataclass
class BaseMetricsConfig:
    """Configuration for ``BaseMetrics``.

    ``metrics_names`` and ``metrics_short_names`` are paired positionally:
    the i-th short name is the key under which the i-th metric is stored.
    """

    # Names of metric classes to look up in ``torchmetrics.classification``.
    metrics_names: list = field(default_factory=list)

    # Keys used for the instantiated metrics (and in the logged dict keys).
    metrics_short_names: list = field(default_factory=list)
|
|
|
|
|
class BaseMetrics(pl.LightningModule):
    """Bundle of ``torchmetrics.classification`` metrics built from a config.

    Every entry of ``cfg.metrics_names`` is looked up as an attribute of
    ``torchmetrics.classification``, instantiated with no arguments, and
    registered in ``self.metrics`` (an ``nn.ModuleDict``) under the entry at
    the same position in ``cfg.metrics_short_names``.
    """

    def __init__(self, cfg: Dict, *args: Any, **kwargs: Any) -> None:
        """Build the metric modules described by ``cfg``.

        Args:
            cfg: Raw configuration, converted to ``BaseMetricsConfig`` via
                ``parse_structure``.
            *args, **kwargs: Forwarded unchanged to ``pl.LightningModule``.

        Raises:
            ValueError: If ``metrics_names`` and ``metrics_short_names`` differ
                in length, or a short name appears more than once.
            AttributeError: If a metric name is not an attribute of
                ``torchmetrics.classification``.
        """
        super().__init__(*args, **kwargs)

        self.cfg: BaseMetricsConfig = parse_structure(BaseMetricsConfig, cfg)
        self.metrics_names = self.cfg.metrics_names
        self.metrics_short_names = self.cfg.metrics_short_names

        # zip() below would silently drop the unpaired tail, quietly disabling
        # metrics on a config typo -- fail loudly instead.
        if len(self.metrics_names) != len(self.metrics_short_names):
            raise ValueError(
                "metrics_names and metrics_short_names must have the same length, "
                f"got {len(self.metrics_names)} and {len(self.metrics_short_names)}"
            )
        # A repeated key would silently overwrite the earlier metric in the ModuleDict.
        if len(set(self.metrics_short_names)) != len(self.metrics_short_names):
            raise ValueError(
                f"metrics_short_names contains duplicates: {list(self.metrics_short_names)}"
            )

        self.metrics = nn.ModuleDict()
        for name, short_name in zip(self.metrics_names, self.metrics_short_names):
            try:
                metric_cls = getattr(classification, name)
            except AttributeError as err:
                raise AttributeError(
                    f"torchmetrics.classification has no metric named {name!r}"
                ) from err
            # Instantiated with no arguments: metrics whose constructors require
            # arguments cannot be configured through this class.
            self.metrics[short_name] = metric_cls()

        print(f"[INFO]: Metrics: {self.metrics}")

    def forward(
        self,
        pred: Tensor,
        target: Tensor,
        prefix: str,
        threshold: float = 0.5,
    ) -> Dict[str, Tensor]:
        """Binarize ``pred`` and evaluate every configured metric on it.

        Defined as ``forward`` (not an ``__call__`` override) so that calling
        the module, e.g. ``metrics(pred, target, "val")``, goes through
        ``nn.Module.__call__`` and runs any registered hooks.

        Args:
            pred: Passed through ``torch.sigmoid`` before thresholding, so it
                is presumably raw logits -- confirm against callers.
            target: Passed unchanged to each metric.
            prefix: Prepended to each key as ``"{prefix}/{short_name}"``.
            threshold: Probabilities strictly greater than this become 1,
                the rest 0. The default 0.5 reproduces ``sigmoid(pred).round()``.

        Returns:
            Mapping from ``"{prefix}/{short_name}"`` to the value returned by
            calling that metric (a Tensor).
        """
        probs = torch.sigmoid(pred)
        # On [0, 1], `probs > 0.5` equals `probs.round()`, because round-half-to-even
        # maps exactly 0.5 to 0. Casting back keeps the input's floating dtype.
        hard_pred = (probs > threshold).to(probs.dtype)
        return {
            f"{prefix}/{name}": metric(hard_pred, target)
            for name, metric in self.metrics.items()
        }