| import torch | |
class TorchHijackForUnet:
    """
    A stand-in for the ``torch`` module whose ``cat`` resizes tensors whose spatial
    dimensions do not match before concatenating them; this makes it possible to create
    pictures with dimensions that are multiples of 8 rather than 64.

    Every other attribute is looked up on ``torch`` itself.
    """

    def __getattr__(self, item):
        """Forward any attribute not defined on this class to the ``torch`` module.

        ``__getattr__`` is only invoked after normal attribute lookup fails, so ``cat``
        (defined on this class) never reaches here.

        Raises:
            AttributeError: if ``torch`` has no attribute named ``item``.
        """
        try:
            return getattr(torch, item)
        except AttributeError:
            # ``from None`` keeps the traceback as clean as the original hasattr() check.
            raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) from None

    def cat(self, tensors, *args, **kwargs):
        """Concatenate ``tensors`` like ``torch.cat``, first fixing a spatial-size mismatch.

        When exactly two 4-D tensors are given and their last two dimensions differ, the
        first is resized with nearest-neighbour interpolation to the last two dimensions
        of the second. All other inputs are passed to ``torch.cat`` unchanged.

        Args:
            tensors: sequence of tensors to concatenate.
            *args, **kwargs: forwarded unchanged to ``torch.cat`` (e.g. ``dim``).

        Returns:
            The concatenated tensor produced by ``torch.cat``.
        """
        if len(tensors) == 2:
            a, b = tensors
            # interpolate() with a 2-element ``size`` only accepts 4-D input, so the resize
            # is restricted to that rank. Lower-rank tensors (e.g. 2-D tensors that
            # legitimately differ along the concatenation dim) go straight to torch.cat
            # instead of raising inside interpolate().
            if a.ndim == 4 and b.ndim == 4 and a.shape[-2:] != b.shape[-2:]:
                a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
                tensors = (a, b)
        return torch.cat(tensors, *args, **kwargs)


# Shared instance standing in for ``torch``; presumably bound to a module's ``th`` name
# elsewhere (not visible here) so that its ``cat`` calls get the resizing behaviour.
th = TorchHijackForUnet()