Spaces:
Running
on
Zero
Running
on
Zero
File size: 713 Bytes
c8b42eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
"""
Helper function for positional encoding in UniCeption
"""
import torch
class PositionGetter(object):
    """Helper class to return the (row, col) grid positions of patches.

    Position grids are cached per ``(h, w, device)`` so repeated calls with the
    same grid size on the same device do not rebuild the grid.
    """

    def __init__(self):
        """Initialize the position getter with an empty cache."""
        # Maps (h, w, device) -> tensor of shape (h * w, 2).
        self.cache_positions = {}

    def __call__(self, b: int, h: int, w: int, device) -> torch.Tensor:
        """Get the patch positions for a given batch size, height, and width.

        Args:
            b: Batch size.
            h: Number of patch rows.
            w: Number of patch columns.
            device: Device on which the positions are created (anything accepted
                by ``torch.arange(..., device=...)``).

        Returns:
            Tensor of shape ``(b, h * w, 2)`` whose entry ``[:, i * w + j]`` is
            ``(i, j)``, i.e. (row, column) pairs in row-major order. The result
            is a fresh copy, so callers may modify it without corrupting the
            cache.
        """
        # The device is part of the key: a grid cached on one device must not
        # be handed back for a request on a different device.
        key = (h, w, device)
        if key not in self.cache_positions:
            y = torch.arange(h, device=device)
            x = torch.arange(w, device=device)
            self.cache_positions[key] = torch.cartesian_prod(y, x)  # (h * w, 2)
        positions = self.cache_positions[key]
        # expand() shares memory with the cached tensor; clone() decouples it.
        return positions.view(1, h * w, 2).expand(b, -1, 2).clone()
|