primepake
add training flowvae
4f877a2
raw
history blame
1.31 kB
import torch
def make_coord_grid(shape, range=(0, 1), device='cpu', batch_size=None):
    """Build a grid holding the centre coordinate of every cell.

    Axis ``i`` is divided into ``shape[i]`` equal cells over its
    ``[minv, maxv]`` interval and the midpoint of each cell is sampled.

    Args:
        shape: (s_1, ..., s_k), number of cells along each axis.
        range: either a single ``[minv, maxv]`` pair applied to every axis,
            or one pair per axis
            ``[[minv_1, maxv_1], ..., [minv_k, maxv_k]]``. The per-axis form
            is detected by the first entry being a list or tuple.
        device: device passed to ``torch.arange`` when creating coordinates.
        batch_size: when not None, a leading dimension of this size is
            prepended via ``expand``.

    Returns:
        Coordinate grid of shape (s_1, ..., s_k, k), or
        (batch_size, s_1, ..., s_k, k) when ``batch_size`` is given.
    """
    def _axis_bounds(axis):
        # NOTE: `range` is this function's argument, not the builtin.
        if isinstance(range[0], (list, tuple)):
            return range[axis]
        return range

    per_axis = []
    for axis, size in enumerate(shape):
        # Cell midpoints on the unit interval: (j + 0.5) / size.
        unit = (torch.arange(size, device=device) + 0.5) / size
        lo, hi = _axis_bounds(axis)
        # Map the unit-interval midpoints onto [lo, hi].
        per_axis.append(lo + (hi - lo) * unit)

    mesh = torch.meshgrid(*per_axis, indexing='ij')
    grid = torch.stack(mesh, dim=-1)
    if batch_size is None:
        return grid

    # Prepend a batch axis; -1 keeps every existing dimension's size.
    keep_dims = [-1] * grid.dim()
    return grid.unsqueeze(0).expand(batch_size, *keep_dims)
def make_coord_scale_grid(shape, range=(0, 1), device='cpu', batch_size=None):
    """Build a coordinate grid together with a matching per-cell size grid.

    Args:
        shape: (s_1, ..., s_k), grid shape.
        range: a single ``[minv, maxv]`` pair for every axis, or one pair
            per axis ``[[minv_1, maxv_1], ..., [minv_k, maxv_k]]`` (detected
            by the first entry being a list or tuple).
        device: forwarded to ``make_coord_grid``.
        batch_size: forwarded to ``make_coord_grid``.

    Returns:
        Tuple ``(coord, scale)``. ``coord`` is the result of
        ``make_coord_grid``; ``scale`` has the same shape, and its entry
        ``i`` along the last dimension equals ``(maxv_i - minv_i) / s_i``.
    """
    grid = make_coord_grid(
        shape, range=range, device=device, batch_size=batch_size
    )
    cell = torch.ones_like(grid)
    for axis, size in enumerate(shape):
        # NOTE: `range` is this function's argument, not the builtin.
        per_axis = isinstance(range[0], (list, tuple))
        lo, hi = range[axis] if per_axis else range
        # Scale channel `axis` in place by the cell extent on that axis.
        cell[..., axis].mul_((hi - lo) / size)
    return grid, cell