| |
| |
| |
| |
| |
|
|
| from collections import defaultdict |
| from contextlib import contextmanager |
| import math |
| import os |
| import tempfile |
| import typing as tp |
|
|
| import torch |
| from torch.nn import functional as F |
| from torch.utils.data import Subset |
|
|
|
|
def unfold(a, kernel_size, stride):
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.

    This will pad the input so that `F = ceil(T / stride)`: frame `f` covers
    samples `[f * stride, f * stride + K)`. If the last frame runs past the
    end of the input it is completed with zeros. Conversely, when
    `kernel_size < stride`, trailing samples that no frame reaches may be
    cropped, because the padding amount is negative.

    The result is an `as_strided` view over the padded tensor. Consecutive
    frames therefore share memory whenever `stride < kernel_size`, so do not
    write to the result in place.

    see https://github.com/pytorch/pytorch/issues/60466
    """
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, tgt_length - length))
    # `as_strided` reinterprets the raw storage layout, so the last dimension
    # must be densely packed (stride 1). When no padding is added, `F.pad` may
    # return a tensor that keeps the input's layout (e.g. a transposed view).
    # Force a contiguous layout instead of asserting; this is a no-op when the
    # tensor is already contiguous.
    a = a.contiguous()
    strides = list(a.stride())
    # Frames advance by `stride` samples; samples within a frame are adjacent.
    strides = strides[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)
|
|
|
|
def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
    """
    Symmetrically trim the last dimension of `tensor` down to a target length.

    The target length is `reference.size(-1)` when `reference` is a tensor,
    and `reference` itself otherwise. When the excess is odd, the extra
    sample is dropped from the right-hand end.

    Raises:
        ValueError: if `tensor` is shorter than the target length.
    """
    target = reference.size(-1) if isinstance(reference, torch.Tensor) else reference
    excess = tensor.size(-1) - target
    if excess < 0:
        raise ValueError("tensor must be larger than reference. " f"Delta is {excess}.")
    if excess == 0:
        return tensor
    # Floor on the left, so any odd sample comes off the right.
    left = excess // 2
    right = excess - left
    return tensor[..., left:-right]
|
|
|
|
def pull_metric(history: tp.List[dict], name: str):
    """
    Collect one metric from every entry of `history`.

    `name` is a dot-separated path into each (nested) dict, e.g. "train.loss"
    reads `entry["train"]["loss"]`. Returns the values in history order; a
    missing key raises `KeyError`.
    """
    def _lookup(entry):
        node = entry
        for key in name.split("."):
            node = node[key]
        return node

    return [_lookup(entry) for entry in history]
|
|
|
|
def EMA(beta: float = 1):
    """
    Build an exponential-moving-average accumulator for metric dicts.

    Returns a callable `update(metrics, weight=1)`. Each call decays the
    running sums by `beta` and folds in `weight * float(value)` for every
    key in `metrics`. It returns a dict mapping every key seen so far to
    its bias-corrected average (weighted sum / sum of weights).

    With `beta=1` nothing decays, so this is plain (weighted) averaging.
    """
    numerators: tp.Dict[str, float] = defaultdict(float)
    denominators: tp.Dict[str, float] = defaultdict(float)

    def _update(metrics: dict, weight: float = 1) -> dict:
        for name, raw in metrics.items():
            numerators[name] = numerators[name] * beta + weight * float(raw)
            denominators[name] = denominators[name] * beta + weight
        averaged = {}
        for name, num in numerators.items():
            averaged[name] = num / denominators[name]
        return averaged

    return _update
|
|
|
|
def sizeof_fmt(num: float, suffix: str = 'B'):
    """
    Render a byte count as a short human-readable string using binary
    (power-of-1024) prefixes, e.g. 1536 -> "1.5KiB".

    Taken from https://stackoverflow.com/a/1094933
    """
    binary_prefixes = ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi')
    scaled = num
    for prefix in binary_prefixes:
        if abs(scaled) < 1024.0:
            break
        scaled /= 1024.0
    else:
        # Still >= 1024 after eight divisions: report in yobi-units.
        return "%.1f%s%s" % (scaled, 'Yi', suffix)
    return "%3.1f%s%s" % (scaled, prefix, suffix)
|
|
|
|
@contextmanager
def temp_filenames(count: int, delete=True):
    """
    Context manager yielding a list of `count` paths to freshly created,
    empty temporary files.

    Args:
        count: number of temporary files to create.
        delete: if True (default), the files are removed on exit, including
            exit by exception. Files the caller already removed are skipped
            silently.

    The file handles are closed immediately after creation, so callers can
    reopen the paths freely and no file descriptors are held open.
    """
    names: tp.List[str] = []
    try:
        for _ in range(count):
            # `delete=False` keeps the file on disk after the handle closes;
            # the `with` ensures we don't leak the open descriptor.
            with tempfile.NamedTemporaryFile(delete=False) as handle:
                names.append(handle.name)
        yield names
    finally:
        if delete:
            for name in names:
                try:
                    os.unlink(name)
                except FileNotFoundError:
                    # Already removed (e.g. by the caller). Keep cleaning up
                    # the rest instead of aborting or masking another error.
                    pass
|
|
|
|
def random_subset(dataset, max_samples: int, seed: int = 42):
    """
    Return at most `max_samples` items of `dataset`.

    If the dataset already fits, it is returned unchanged (same object).
    Otherwise a `Subset` is returned. Its indices are the first `max_samples`
    entries of a random permutation drawn from a generator seeded with
    `seed`, so the same seed always selects the same items.
    """
    total = len(dataset)
    if total <= max_samples:
        return dataset
    rng = torch.Generator()
    rng.manual_seed(seed)
    chosen = torch.randperm(total, generator=rng)[:max_samples]
    return Subset(dataset, chosen.tolist())
|
|
|
|
class DummyPoolExecutor:
    """
    Serial stand-in for a `concurrent.futures`-style executor.

    `submit` only records the call; the work runs in the calling thread the
    first time `result()` is requested on the returned handle. The
    constructor argument is accepted and ignored so this can replace a real
    pool without changing call sites.
    """

    class DummyResult:
        """Deferred call to `func(*args, **kwargs)`, evaluated at most once."""

        def __init__(self, func, *args, **kwargs):
            self.func = func
            self.args = args
            self.kwargs = kwargs
            self._done = False
            self._value = None
            self._error = None

        def result(self):
            """
            Run the call on first use and return its value. Later calls
            return the same value, or re-raise the same exception, without
            re-running the function. This matches `Future.result()` semantics.
            """
            if not self._done:
                try:
                    self._value = self.func(*self.args, **self.kwargs)
                except Exception as exc:
                    self._error = exc
                    self._done = True
                    raise
                self._done = True
            if self._error is not None:
                raise self._error
            return self._value

    def __init__(self, workers=0):
        # `workers` is ignored: everything runs serially.
        pass

    def submit(self, func, *args, **kwargs):
        """Return a handle whose `result()` evaluates `func(*args, **kwargs)`."""
        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Returning None lets any exception from the `with` body propagate.
        return
|
|