Spaces:
Sleeping
Sleeping
| from contextlib import ContextDecorator | |
| import torch | |
| import torch.nn.functional as F | |
| all_times = [] | |
| class Timer(ContextDecorator): | |
| def __init__(self, name, enabled=True): | |
| self.name = name | |
| self.enabled = enabled | |
| if self.enabled: | |
| self.start = torch.cuda.Event(enable_timing=True) | |
| self.end = torch.cuda.Event(enable_timing=True) | |
| def __enter__(self): | |
| if self.enabled: | |
| self.start.record() | |
| def __exit__(self, type, value, traceback): | |
| global all_times | |
| if self.enabled: | |
| self.end.record() | |
| torch.cuda.synchronize() | |
| elapsed = self.start.elapsed_time(self.end) | |
| all_times.append(elapsed) | |
| print(f"{self.name} {elapsed:.03f}") | |
| def coords_grid(b, n, h, w, **kwargs): | |
| """ coordinate grid """ | |
| x = torch.arange(0, w, dtype=torch.float, **kwargs) | |
| y = torch.arange(0, h, dtype=torch.float, **kwargs) | |
| coords = torch.stack(torch.meshgrid(y, x, indexing="ij")) | |
| return coords[[1,0]].view(1, 1, 2, h, w).repeat(b, n, 1, 1, 1) | |
| def coords_grid_with_index(d, **kwargs): | |
| """ coordinate grid with frame index""" | |
| b, n, h, w = d.shape | |
| i = torch.ones_like(d) | |
| x = torch.arange(0, w, dtype=torch.float, **kwargs) | |
| y = torch.arange(0, h, dtype=torch.float, **kwargs) | |
| y, x = torch.stack(torch.meshgrid(y, x, indexing="ij")) | |
| y = y.view(1, 1, h, w).repeat(b, n, 1, 1) | |
| x = x.view(1, 1, h, w).repeat(b, n, 1, 1) | |
| coords = torch.stack([x, y, d], dim=2) | |
| index = torch.arange(0, n, dtype=torch.float, **kwargs) | |
| index = index.view(1, n, 1, 1, 1).repeat(b, 1, 1, h, w) | |
| return coords, index | |
| def patchify(x, patch_size=3): | |
| """ extract patches from video """ | |
| b, n, c, h, w = x.shape | |
| x = x.view(b*n, c, h, w) | |
| y = F.unfold(x, patch_size) | |
| y = y.transpose(1,2) | |
| return y.reshape(b, -1, c, patch_size, patch_size) | |
| def pyramidify(fmap, lvls=[1]): | |
| """ turn fmap into a pyramid """ | |
| b, n, c, h, w = fmap.shape | |
| pyramid = [] | |
| for lvl in lvls: | |
| gmap = F.avg_pool2d(fmap.view(b*n, c, h, w), lvl, stride=lvl) | |
| pyramid += [ gmap.view(b, n, c, h//lvl, w//lvl) ] | |
| return pyramid | |
| def all_pairs_exclusive(n, **kwargs): | |
| ii, jj = torch.meshgrid(torch.arange(n, **kwargs), torch.arange(n, **kwargs)) | |
| k = ii != jj | |
| return ii[k].reshape(-1), jj[k].reshape(-1) | |
| def set_depth(patches, depth): | |
| patches[...,2,:,:] = depth[...,None,None] | |
| return patches | |
| def flatmeshgrid(*args, **kwargs): | |
| grid = torch.meshgrid(*args, **kwargs) | |
| return (x.reshape(-1) for x in grid) | |