Spaces:
Sleeping
Sleeping
"""
All transforms operate on batches (the first axis is the batch axis)!
VAE transforms are invertible.
"""
| import pickle | |
| from dataclasses import dataclass | |
| from functools import partial, reduce, wraps | |
| import numpy as np | |
| import torch | |
# General helper functions -------------------------------------------------------
# Transformations are simplest to express in PyTorch.
def load(p):
    """Read a pickled object from the file at path ``p`` and return it.

    Security: unpickling can execute arbitrary code -- only call this on
    files you created yourself or otherwise trust.
    """
    with open(p, mode="rb") as handle:
        obj = pickle.load(handle)
    return obj
def save(obj, p):
    """Pickle ``obj`` into the file at path ``p`` (overwriting it)."""
    with open(p, mode="wb") as handle:
        pickle.dump(obj, handle)
def sequential_function(*functions):
    """Compose ``functions`` left to right: ``f(x) = functions[-1](...functions[0](x))``.

    With no functions the returned callable is the identity.
    """

    def composed(x):
        result = x
        for fn in functions:
            result = fn(result)
        return result

    return composed
def np_sample(func):
    """Wrap a batch transform so it can be applied to a single numpy sample.

    The returned callable converts ``sample`` to a float32 tensor, prepends a
    batch axis of size 1, applies ``func`` and returns the first (only) batch
    element of the result as a numpy array.

    Args:
        func: callable that takes and returns a batched ``torch.Tensor``.

    Returns:
        Callable mapping a numpy array to a numpy array.
    """

    def _apply(sample):
        batch = torch.from_numpy(sample).float().unsqueeze(0)
        out = func(batch)
        # detach()/cpu(): ``.numpy()`` refuses tensors that require grad
        # (e.g. outputs of trainable modules) or that live on a GPU.
        return out[0].detach().cpu().numpy()

    return _apply
# Invertible transforms
class SequentialInversable(torch.nn.Sequential):
    """``torch.nn.Sequential`` whose stages can be undone with :meth:`inv`.

    Every child module must provide an ``inv`` method. ``forward`` (inherited
    from ``nn.Sequential``) applies the children in order; ``inv`` applies
    their ``inv`` methods in reverse order. The inverse chain is derived from
    the current children on every call, so it stays correct after slicing
    (``seq[1:]``) or ``add_module``.
    """

    def __init__(self, *functions):
        super().__init__(*functions)
        # Fail fast at construction if any stage cannot be inverted.
        for module in self:
            if not hasattr(module, "inv"):
                raise AttributeError(
                    f"'{type(module).__name__}' object has no attribute 'inv'"
                )

    @property
    def inv_funcs(self):
        """The children's ``inv`` methods, in the order :meth:`inv` applies them."""
        return [module.inv for module in reversed(list(self))]

    def inv(self, x):
        """Undo ``forward`` by applying each child's ``inv`` from last to first."""
        for inverse in self.inv_funcs:
            x = inverse(x)
        return x
class LatentSelector(torch.nn.Module):
    """Keep only the first ``selectdim`` latent dimensions of a batch.

    Works on both ``torch.Tensor`` and ``numpy.ndarray`` batches.
    :meth:`inv` zero-pads the selected columns back to ``ldim`` columns.
    """

    def __init__(self, ldim: int, selectdim: int):
        super().__init__()
        self.ldim = ldim  # number of columns restored by inv()
        self.selectdim = selectdim  # number of leading columns kept by forward()

    def forward(self, x):
        """Return ``x[:, :selectdim]`` (batch axis first)."""
        return x[:, : self.selectdim]

    def inv(self, x):
        """Append zero columns so the result has ``ldim`` columns.

        The padding matches the input's type and dtype (and device for
        tensors), so the output keeps the input's dtype.
        """
        pad_shape = (x.shape[0], self.ldim - x.shape[1])
        if isinstance(x, np.ndarray):
            return np.concatenate([x, np.zeros(pad_shape, dtype=x.dtype)], axis=1)
        return torch.cat([x, x.new_zeros(pad_shape)], dim=1)
class MinMaxScaler(torch.nn.Module):
    """Affinely map values from ``[_min, _max]`` onto ``[min_norm, max_norm]``.

    NOTE: with several signals, take care that ``_min``/``_max`` broadcast
    against the input as intended (``from_array`` yields one value per
    column).
    NOTE(review): ``_min``/``_max`` are plain attributes, not registered
    buffers, so ``.to(device)`` and ``state_dict()`` do not include them; a
    column with ``_max == _min`` divides by zero.
    """

    def __init__(
        self,
        _min: torch.Tensor,
        _max: torch.Tensor,
        min_norm: float = 0.0,
        max_norm: float = 1.0,
    ):
        super().__init__()
        self._min = _min
        self._max = _max
        self.min_norm = min_norm
        self.max_norm = max_norm

    def forward(self, ts):
        """Scale ``ts`` (shape ``(None, no_signals)``) into ``[min_norm, max_norm]``."""
        std = (ts - self._min) / (self._max - self._min)
        return std * (self.max_norm - self.min_norm) + self.min_norm

    def inv(self, ts):
        """Map scaled values back from ``[min_norm, max_norm]`` to ``[_min, _max]``."""
        std = (ts - self.min_norm) / (self.max_norm - self.min_norm)
        return std * (self._max - self._min) + self._min

    @classmethod
    def from_array(
        cls,
        arr: torch.Tensor,
        min_norm: float = 0.0,
        max_norm: float = 1.0,
    ):
        """Build a scaler from the column-wise min/max of ``arr`` (reduced over axis 0).

        Args:
            arr: data tensor whose first axis is the batch axis.
            min_norm: lower bound of the target range.
            max_norm: upper bound of the target range.
        """
        _min = torch.min(arr, dim=0).values
        _max = torch.max(arr, dim=0).values
        return cls(_min, _max, min_norm, max_norm)
class LatentSorter(torch.nn.Module):
    """Reorder latent columns according to the key order of ``kl_dict``.

    ``kl_dict`` maps a latent column index to its KL value; the dict's
    insertion order is the "sorted" column order. (Presumably it is built in
    descending KL order -- confirm with the code that creates it.)
    """

    def __init__(self, kl_dict: dict):
        super().__init__()
        self.kl_dict = kl_dict

    def forward(self, latent):
        """unsorted -> sorted; ``latent`` is (None, latent_dim)."""
        order = [*self.kl_dict]
        return latent[:, order]

    def inv(self, latent):
        """sorted -> unsorted, by indexing with the inverse permutation (argsort of the keys)."""
        key_array = np.array([*self.kl_dict])
        inverse_order = torch.from_numpy(np.argsort(key_array))
        return latent[:, inverse_order]

    def names(self):
        """Labels ``"<index> KL<value>"`` (value to 2 decimals) in sorted order."""
        return [f"{idx} KL{kl:.2f}" for idx, kl in self.kl_dict.items()]
def apply_along_axis(function, x, axis: int = 0):
    """Apply ``function`` to every slice of ``x`` along ``axis`` and restack.

    Torch counterpart of ``numpy.apply_along_axis``: ``x`` is unbound along
    ``axis``, each slice is transformed, and the results are stacked back
    along the same axis.
    """
    pieces = x.unbind(axis)
    results = list(map(function, pieces))
    return torch.stack(results, axis)
# Input shapes stay as they are!
class SumField(torch.nn.Module):
    """Encode time series as "sum field" images and decode them back.

    Layouts:
        time series: [idx, time_step, signal]
        image:       [idx, signal, time_step, time_step]

    For each sample and signal ``x`` the image is
    ``img[i, j] = (x[i] + x[j]) / 2``. Its diagonal equals ``x``, which is
    what :meth:`inv` reads back.
    """

    def forward(self, ts: torch.Tensor) -> torch.Tensor:
        """ts2img: [idx, time_step, signal] -> [idx, signal, time_step, time_step]."""
        x = torch.swapaxes(ts, 1, 2)  # time axis last: [idx, signal, time_step]
        # Pairwise mean via broadcasting -- one vectorized op for every
        # sample/signal instead of a Python loop; also handles empty batches.
        field = (x.unsqueeze(-1) + x.unsqueeze(-2)) / 2
        return field.contiguous()  # standard row-major layout for consumers

    def inv(self, img: torch.Tensor) -> torch.Tensor:
        """img2ts: [idx, signal, time_step, time_step] -> [idx, time_step, signal]."""
        rtn = torch.diagonal(img, dim1=2, dim2=3)  # [idx, signal, time_step]
        return torch.swapaxes(rtn, 1, 2)  # swap signal and time axes

    @staticmethod
    def _mtf_forward(ts):
        """Sum field of one 1-D time series: ``out[i, j] = (ts[i] + ts[j]) / 2``."""
        return torch.add(*torch.meshgrid(ts, ts, indexing="ij")) / 2