# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
import concurrent.futures
import numpy as np
import torch
# Background worker used by save_async. With max_workers=1, submitted saves
# run one at a time, in the order they were submitted.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
# Names exported by `from <this module> import *`.
__all__ = ['get_mesh_id', 'save_async', 'data_seq_to_patch']
def data_seq_to_patch(
    patch_size,
    data_seq,
    latent_num_frames,
    latent_height,
    latent_width,
    batch_size=1,
):
    """Fold a sequence of patch tokens back into a dense latent volume.

    ``data_seq`` is reshaped to ``(batch_size, F', H', W', p_t, p_h, p_w, C)``,
    where ``F' = latent_num_frames // p_t`` (likewise ``H'`` and ``W'``) and
    ``C`` is inferred. Its element count must therefore match exactly.
    Presumably it arrives as ``(batch, tokens, p_t * p_h * p_w * C)``;
    TODO confirm with the callers.

    Args:
        patch_size: ``(p_t, p_h, p_w)`` patch extent along frames, height
            and width.
        data_seq: tensor of patch tokens laid out as described above.
        latent_num_frames: number of latent frames to reconstruct.
        latent_height: latent height to reconstruct.
        latent_width: latent width to reconstruct.
        batch_size: leading batch dimension of ``data_seq``.

    Returns:
        Tensor of shape ``(batch_size, C, F' * p_t, H' * p_h, W' * p_w)``.
    """
    p_t, p_h, p_w = patch_size
    grid = (
        latent_num_frames // p_t,
        latent_height // p_h,
        latent_width // p_w,
    )
    tokens = data_seq.reshape(batch_size, *grid, p_t, p_h, p_w, -1)
    # -> (B, C, F', p_t, H', p_h, W', p_w): each grid axis sits next to its
    # in-patch axis so the pairs can be merged.
    volume = tokens.permute(0, 7, 1, 4, 2, 5, 3, 6)
    # Merge (W', p_w), then (H', p_h), then (F', p_t). Working right to left
    # keeps the lower dim indices valid after each merge.
    for start in (6, 4, 2):
        volume = volume.flatten(start, start + 1)
    return volume
def get_mesh_id(f, h, w, t, f_w=1, f_shift=0, action=False):
    """Return position ids for an ``f x h x w`` grid as a ``(4, f*h*w)`` tensor.

    Columns enumerate positions in row-major ``(frame, row, col)`` order.
    The rows are:

    0. frame id ``(f_shift + i) * f_w``. With ``action=True``, the ``j``-th
       of the ``h`` row slots also gets a fractional offset
       ``(j + 1) / (h + 1)``.
    1. row index ``j``, or ``-1`` everywhere when ``action=True``.
    2. column index ``k``, or ``-1`` everywhere when ``action=True``.
    3. the constant ``t``.

    With ``action=True`` the frame row is floating point, and ``torch.cat``
    promotes the whole result to that dtype.
    """
    frame_ids = torch.arange(f_shift, f + f_shift) * f_w
    frames, rows, cols = torch.meshgrid(
        frame_ids, torch.arange(h), torch.arange(w), indexing='ij')
    if action:
        # Fractions 1/(h+1), ..., h/(h+1) broadcast along the row axis, so
        # each slot's id stays strictly between its frame id and that id + 1.
        slot_offsets = torch.ones([h]).cumsum(0) / (h + 1)
        frames = frames + slot_offsets[None, :, None]
        rows = torch.full_like(rows, -1)
        cols = torch.full_like(cols, -1)
    coords = torch.cat(
        [axis.unsqueeze(0) for axis in (frames, rows, cols)],
        dim=0,
    ).flatten(1)
    time_ids = torch.full_like(coords[:1], t)
    return torch.cat([coords, time_ids], dim=0)
def _cpu_snapshot(tensor):
    """Return a detached CPU copy of ``tensor`` that owns its own storage.

    The copy keeps ``tensor.requires_grad``, so the saved flag is unchanged.
    It carries no autograd graph, so the pending write does not keep the
    caller's graph alive.
    """
    snapshot = tensor.detach().to('cpu', copy=True)
    return snapshot.requires_grad_(tensor.requires_grad)


def save_async(obj, file_path):
    """Write ``obj`` to ``file_path`` on the module-level background executor.

    Array and tensor data is snapshotted before this call returns, so the
    caller may keep mutating its own objects while the write is pending:

    * ``np.ndarray``: copied, then written with ``np.save``.
    * tensor: copied to CPU (even when it is already on CPU), then written
      with ``torch.save``.
    * ``dict`` with at least one tensor value: rebuilt as a plain ``dict``
      whose tensor values are CPU copies. Other values are kept by reference.
    * anything else: passed to ``torch.save`` as-is, without a snapshot.

    Because each tensor is copied separately, tensors that shared storage in
    ``obj`` are saved as independent copies.

    Args:
        obj: object to save.
        file_path: destination path, forwarded to ``np.save`` or
            ``torch.save``.

    Returns:
        concurrent.futures.Future for the pending write. Call ``.result()``
        to wait for completion and to re-raise any error raised while
        saving. Callers that ignore the return value are unaffected.
    """
    if isinstance(obj, np.ndarray):
        # Copy so that later in-place edits by the caller cannot leak
        # into the file.
        return executor.submit(np.save, file_path, obj.copy())
    if torch.is_tensor(obj):
        # .cpu() returns the same object for CPU tensors, which would race
        # with the caller's in-place edits; always take a real copy.
        obj = _cpu_snapshot(obj)
    elif isinstance(obj, dict) and any(
            torch.is_tensor(v) for v in obj.values()):
        obj = {
            k: _cpu_snapshot(v) if torch.is_tensor(v) else v
            for k, v in obj.items()
        }
    return executor.submit(torch.save, obj, file_path)