File size: 3,605 Bytes
cbe6208 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
"""Utility functions"""
import logging
import numpy as np
import torch
logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
class PredAccumulator:
    """Collects y-prediction sub-batches produced under gradient accumulation.

    Useful when the effective batch is split into several small forward passes:
    each pass appends its predictions, and `flush()` yields the combined tensor.

    Attributes:
        _y_hats (list[torch.Tensor]): Prediction tensors collected so far.
    """

    def __init__(self):
        """Create an empty prediction accumulator."""
        self._y_hats: list[torch.Tensor] = []

    def __bool__(self):
        # Truthy once at least one sub-batch has been appended
        return bool(self._y_hats)

    def append(self, y_hat: torch.Tensor):
        """Store one sub-batch of predictions."""
        self._y_hats.append(y_hat)

    def flush(self) -> torch.Tensor:
        """Concatenate all stored predictions along dim 0, clear the store, and return them."""
        stacked = torch.cat(self._y_hats, dim=0)
        self._y_hats = []
        return stacked
class DictListAccumulator:
    """Abstract base for accumulators that gather dict values into per-key lists."""

    @staticmethod
    def _dict_list_append(d1, d2):
        # Push each value from d2 onto the matching list in d1.
        # Assumes d1 already has a list under every key of d2 — TODO confirm callers guarantee this.
        for key, value in d2.items():
            d1[key].append(value)

    @staticmethod
    def _dict_init_list(d):
        # Seed a fresh accumulator: each value wrapped in a one-element list
        return {key: [value] for key, value in d.items()}
class MetricAccumulator(DictListAccumulator):
    """Dictionary of metrics accumulator.

    A class for accumulating, and finding the mean of, logging metrics when using
    grad accumulation and the batch size is small.

    Attributes:
        _metrics (dict[str, list[float]]): Dictionary containing lists of metrics.
    """

    def __init__(self):
        """Create an empty metrics accumulator."""
        self._metrics: dict[str, list[float]] = {}

    def __bool__(self):
        # Truthy once at least one metrics dict has been appended
        return self._metrics != {}

    def append(self, loss_dict: dict[str, float]):
        """Append dictionary of metrics to self."""
        # First append seeds the per-key lists; later appends extend them.
        if not self:
            self._metrics = self._dict_init_list(loss_dict)
        else:
            self._dict_list_append(self._metrics, loss_dict)

    def flush(self) -> dict[str, float]:
        """Calculate the mean of all accumulated metrics, clear the store, and return the means."""
        mean_metrics = {k: np.mean(v) for k, v in self._metrics.items()}
        self._metrics = {}
        return mean_metrics
class BatchAccumulator(DictListAccumulator):
    """A class for accumulating batches when using grad accumulation and the batch size is small.

    Attributes:
        _batches (dict[str, list[torch.Tensor]]): Dictionary containing lists of
            batch tensors for the kept keys.
    """

    def __init__(self, key_to_keep: str = "gsp"):
        """Batch accumulator.

        Args:
            key_to_keep: Only this batch key and its derived metadata keys
                ("<key>_id", "<key>_t0_idx", "<key>_time_utc") are retained.
        """
        self._batches: dict[str, list[torch.Tensor]] = {}
        self.key_to_keep = key_to_keep

    def __bool__(self):
        # Truthy once at least one batch has been appended
        return self._batches != {}

    # NOTE: cannot be a @staticmethod — it reads self.key_to_keep
    # (removed a stale commented-out @staticmethod decorator here).
    def _filter_batch_dict(self, d):
        """Drop every batch entry except the key of interest and its metadata keys."""
        keep_keys = [
            self.key_to_keep,
            f"{self.key_to_keep}_id",
            f"{self.key_to_keep}_t0_idx",
            f"{self.key_to_keep}_time_utc",
        ]
        return {k: v for k, v in d.items() if k in keep_keys}

    def append(self, batch: dict[str, list[torch.Tensor]]):
        """Append batch to self."""
        # First append seeds the per-key lists; later appends extend them.
        if not self:
            self._batches = self._dict_init_list(self._filter_batch_dict(batch))
        else:
            self._dict_list_append(self._batches, self._filter_batch_dict(batch))

    def flush(self) -> dict[str, list[torch.Tensor]]:
        """Concatenate all accumulated batches, return, and clear self."""
        batch = {}
        for k, v in self._batches.items():
            if k == f"{self.key_to_keep}_t0_idx":
                # t0 index is presumably identical across sub-batches, so the
                # first occurrence is kept rather than concatenated — TODO confirm.
                batch[k] = v[0]
            else:
                batch[k] = torch.cat(v, dim=0)
        self._batches = {}
        return batch
|