| | import torch |
| |
|
| |
|
def generate_grid(n_vox, interval):
    """
    Build a dense tensor of sample coordinates over a 3D volume.

    The entry at grid index (i, j, k) holds the coordinate triple
    (i * interval, j * interval, k * interval); with interval == 1 this means
    grid[0, :, x, y, z] == (x, y, z).

    :param n_vox: per-axis extent of the volume; only entries 0, 1 and 2 are read.
    :param interval: step between consecutive coordinates along every axis
        (passed as the step of torch.arange).
    :return: float32 tensor of shape (1, 3, Nx, Ny, Nz), where N_a is the number
        of values produced by torch.arange(0, n_vox[a], interval).
    """
    with torch.no_grad():
        axes = [torch.arange(0, n_vox[i], interval) for i in range(3)]
        # "ij" indexing keeps axis order, so coordinate a varies along grid dim a.
        mesh = torch.meshgrid(*axes, indexing="ij")
        coords = torch.stack(mesh)        # (3, Nx, Ny, Nz)
        batched = coords.unsqueeze(0)     # (1, 3, Nx, Ny, Nz)
        return batched.to(torch.float32)
| |
|
| |
|
if __name__ == "__main__":
    import torch.nn.functional as F

    # Smoke test: look up voxel (1, 2, 3) in the coordinate grid via grid_sample;
    # the sampled channels should read back (1., 2., 3.).
    n_vox = [5, 6, 8]
    grid = generate_grid(n_vox, 1)  # shape (1, 3, 5, 6, 8) = (N, C, D, H, W)

    voxel = torch.tensor([1, 2, 3])
    # Map voxel indices to [-1, 1] with the "(size - 1)" form: -1 / +1 land on the
    # centres of the first / last voxel. That is the align_corners=True convention,
    # so grid_sample below must use align_corners=True; with the default (False)
    # the last index of an axis can round out of bounds in 'nearest' mode and
    # sample the zero padding instead of the grid.
    pts = 2 * voxel / (torch.tensor(n_vox) - 1) - 1
    # 5-D grid_sample expects a grid of shape (N, D_out, H_out, W_out, 3).
    pts = pts.view(1, 1, 1, 1, 3)
    # grid_sample reads the last grid dim as (x, y, z) indexing (W, H, D), the
    # reverse of the volume's (D, H, W) axis order, so flip the coordinates.
    pts = torch.flip(pts, dims=[-1])

    sampled = F.grid_sample(grid, pts, mode='nearest', align_corners=True)
    print(sampled)
| |
|