File size: 1,310 Bytes
37163a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch


def make_coord_grid(shape, range=(0, 1), device='cpu', batch_size=None):
    """
        Build a grid holding the coordinates of cell centers.

        Args:
            shape: (s_1, ..., s_k), number of cells along each axis
            range: a single [minv, maxv] pair applied to every axis, or a
                list/tuple of per-axis pairs [[minv_1, maxv_1], ..., [minv_k, maxv_k]]
            device: device passed to torch.arange when building each axis
            batch_size: if not None, a leading axis of this size is added via expand
        Returns:
            (s_1, ..., s_k, k) coordinate grid, or (batch_size, s_1, ..., s_k, k)
            when batch_size is given
    """
    def axis_bounds(axis):
        # A list/tuple as first element means one (minv, maxv) pair per axis;
        # otherwise the same pair is shared by all axes.
        if isinstance(range[0], (list, tuple)):
            return range[axis]
        return range

    axes = []
    for axis, size in enumerate(shape):
        # Cell centers on the unit interval: (j + 0.5) / size for j = 0..size-1.
        unit = (torch.arange(size, device=device) + 0.5) / size
        lo, hi = axis_bounds(axis)
        axes.append(lo + (hi - lo) * unit)

    # 'ij' indexing keeps the output axes in the same order as `shape`.
    grid = torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=-1)
    if batch_size is None:
        return grid

    # -1 leaves each existing dimension unchanged; only the new leading
    # axis is expanded to batch_size.
    keep_dims = [-1] * grid.dim()
    return grid.unsqueeze(0).expand(batch_size, *keep_dims)


def make_coord_scale_grid(shape, range=(0, 1), device='cpu', batch_size=None):
    """
        Build a coordinate grid together with a matching per-axis cell-size grid.

        Args:
            shape: (s_1, ..., s_k), grid shape
            range: a single [minv, maxv] pair applied to every axis, or a
                list/tuple of per-axis pairs [[minv_1, maxv_1], ..., [minv_k, maxv_k]]
            device: forwarded to make_coord_grid
            batch_size: forwarded to make_coord_grid
        Returns:
            (coord, scale): coord is the result of make_coord_grid; scale has the
            same shape as coord and its last-axis channel i holds (maxv_i - minv_i) / s_i
    """
    coord = make_coord_grid(shape, range=range, device=device, batch_size=batch_size)

    # Start from ones so each channel can be scaled in place below.
    scale = torch.ones_like(coord)
    for axis, size in enumerate(shape):
        per_axis = isinstance(range[0], (list, tuple))
        lo, hi = range[axis] if per_axis else range
        # Extent of one cell along this axis.
        cell = (hi - lo) / size
        scale[..., axis] *= cell
    return coord, scale