| |
| import concurrent.futures |
|
|
| import numpy as np |
| import torch |
|
|
| executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) |
|
|
| __all__ = ['get_mesh_id', 'save_async', 'data_seq_to_patch'] |
|
|
|
|
| def data_seq_to_patch( |
| patch_size, |
| data_seq, |
| latent_num_frames, |
| latent_height, |
| latent_width, |
| batch_size=1, |
| ): |
| p_t, p_h, p_w = patch_size |
| post_patch_num_frames = latent_num_frames // p_t |
| post_patch_height = latent_height // p_h |
| post_patch_width = latent_width // p_w |
|
|
| data_patch = data_seq.reshape(batch_size, post_patch_num_frames, |
| post_patch_height, post_patch_width, p_t, |
| p_h, p_w, -1) |
| data_patch = data_patch.permute(0, 7, 1, 4, 2, 5, 3, 6) |
| data_patch = data_patch.flatten(6, 7).flatten(4, 5).flatten(2, 3) |
| return data_patch |
|
|
|
|
| def get_mesh_id(f, h, w, t, f_w=1, f_shift=0, action=False): |
| f_idx = torch.arange(f_shift, f + f_shift) * f_w |
| h_idx = torch.arange(h) |
| w_idx = torch.arange(w) |
| ff, hh, ww = torch.meshgrid(f_idx, h_idx, w_idx, indexing='ij') |
| if action: |
| ff_offset = (torch.ones([h]).cumsum(0) / (h + 1)).view(1, -1, 1) |
| ff = ff + ff_offset |
| hh = torch.ones_like(hh) * -1 |
| ww = torch.ones_like(ww) * -1 |
|
|
| grid_id = torch.cat( |
| [ |
| ff.unsqueeze(0), |
| hh.unsqueeze(0), |
| ww.unsqueeze(0), |
| ], |
| dim=0, |
| ).flatten(1) |
| grid_id = torch.cat([grid_id, torch.full_like(grid_id[:1], t)], dim=0) |
| return grid_id |
|
|
|
|
| def save_async(obj, file_path): |
| """ |
| todo |
| """ |
| if torch.is_tensor(obj) or (isinstance(obj, dict) and any( |
| torch.is_tensor(v) for v in obj.values())): |
| if torch.is_tensor(obj): |
| if obj.is_cuda: |
| obj = obj.cpu() |
| elif isinstance(obj, dict): |
| obj = { |
| k: v.cpu() if torch.is_tensor(v) else v |
| for k, v in obj.items() |
| } |
| executor.submit(torch.save, obj, file_path) |
| elif isinstance(obj, np.ndarray): |
| obj_copy = obj.copy() |
| executor.submit(np.save, file_path, obj_copy) |
| else: |
| executor.submit(torch.save, obj, file_path) |
|
|