File size: 283 Bytes
c8b42eb
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
from functools import lru_cache

import torch


@lru_cache(maxsize=10)
def get_meshgrid_torch(W, H, device):
    """Return a cached grid of pixel coordinates of shape ``(H, W, 2)``.

    ``uv[i, j] == (j, i)``: channel 0 holds the x/column index (0..W-1) and
    channel 1 holds the y/row index (0..H-1), both as float32 on ``device``.

    Args:
        W: Grid width (number of columns).
        H: Grid height (number of rows).
        device: Device to allocate the grid on (anything ``torch.arange``
            accepts). Must be hashable because results are memoized; note
            that different spellings of one device (e.g. ``"cuda"`` vs
            ``"cuda:0"``) are distinct cache keys.

    Returns:
        A float32 tensor of shape ``(H, W, 2)``.

    Note:
        Results are memoized with ``lru_cache`` (up to 10 distinct
        ``(W, H, device)`` keys), so repeated calls with the same arguments
        return the *same* tensor object. Do not modify the result in place
        (e.g. ``uv += 0.5``); call ``.clone()`` first if a mutable copy is
        needed, otherwise the cached grid is corrupted for later callers.
    """
    # Build float32 ranges directly instead of int64 + ``.float()``, which
    # would allocate and convert two extra intermediate tensors.
    xs = torch.arange(W, device=device, dtype=torch.float32)
    ys = torch.arange(H, device=device, dtype=torch.float32)
    # "xy" indexing yields (H, W) grids: u varies along columns, v along rows.
    u, v = torch.meshgrid(xs, ys, indexing="xy")
    return torch.stack((u, v), dim=-1)