"""Utility functions""" import logging import numpy as np import torch logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) class PredAccumulator: """A class for accumulating y-predictions using grad accumulation and small batch size. Attributes: _y_hats (list[torch.Tensor]): List of prediction tensors """ def __init__(self): """Prediction accumulator""" self._y_hats = [] def __bool__(self): return len(self._y_hats) > 0 def append(self, y_hat: torch.Tensor): """Append a sub-batch of predictions""" self._y_hats.append(y_hat) def flush(self) -> torch.Tensor: """Return all appended predictions as single tensor and remove from accumulated store.""" y_hat = torch.cat(self._y_hats, dim=0) self._y_hats = [] return y_hat class DictListAccumulator: """Abstract class for accumulating dictionaries of lists""" @staticmethod def _dict_list_append(d1, d2): for k, v in d2.items(): d1[k].append(v) @staticmethod def _dict_init_list(d): return {k: [v] for k, v in d.items()} class MetricAccumulator(DictListAccumulator): """Dictionary of metrics accumulator. A class for accumulating, and finding the mean of logging metrics when using grad accumulation and the batch size is small. Attributes: _metrics (Dict[str, list[float]]): Dictionary containing lists of metrics. """ def __init__(self): """Dictionary of metrics accumulator.""" self._metrics = {} def __bool__(self): return self._metrics != {} def append(self, loss_dict: dict[str, float]): """Append lictionary of metrics to self""" if not self: self._metrics = self._dict_init_list(loss_dict) else: self._dict_list_append(self._metrics, loss_dict) def flush(self) -> dict[str, float]: """Calculate mean of all accumulated metrics and clear""" mean_metrics = {k: np.mean(v) for k, v in self._metrics.items()} self._metrics = {} return mean_metrics class BatchAccumulator(DictListAccumulator): """A class for accumulating batches when using grad accumulation and the batch size is small. Attributes: _batches (Dict[str, list[torch.Tensor]]): Dictionary containing lists of metrics. """ def __init__(self, key_to_keep: str = "gsp"): """Batch accumulator""" self._batches = {} self.key_to_keep = key_to_keep def __bool__(self): return self._batches != {} # @staticmethod def _filter_batch_dict(self, d): keep_keys = [ self.key_to_keep, f"{self.key_to_keep}_id", f"{self.key_to_keep}_t0_idx", f"{self.key_to_keep}_time_utc", ] return {k: v for k, v in d.items() if k in keep_keys} def append(self, batch: dict[str, list[torch.Tensor]]): """Append batch to self""" if not self: self._batches = self._dict_init_list(self._filter_batch_dict(batch)) else: self._dict_list_append(self._batches, self._filter_batch_dict(batch)) def flush(self) -> dict[str, list[torch.Tensor]]: """Concatenate all accumulated batches, return, and clear self""" batch = {} for k, v in self._batches.items(): if k == f"{self.key_to_keep}_t0_idx": batch[k] = v[0] else: batch[k] = torch.cat(v, dim=0) self._batches = {} return batch