File size: 2,721 Bytes
ab0f6ec | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 | import torch
import pdb
from functools import lru_cache
import numpy as np
import math
__all__=['build_relative_position', 'make_log_bucket_position']
@lru_cache(maxsize=128)
def make_log_bucket_dict(bucket_size, max_position, device=None):
relative_pos = torch.arange(-max_position, max_position, device=device)
sign = torch.sign(relative_pos)
mid = bucket_size//2
abs_pos = torch.where((relative_pos<mid) & (relative_pos > -mid), torch.tensor(mid-1).to(relative_pos), torch.abs(relative_pos))
log_pos = torch.ceil(torch.log(abs_pos/mid)/math.log((max_position-1)/mid) * (mid-1)) + mid
bucket_pos = torch.where(abs_pos<=mid, relative_pos, (log_pos*sign).to(relative_pos)).to(torch.long)
return bucket_pos
# Faster version
def make_log_bucket_position(relative_pos, bucket_size, max_position):
relative_pos = torch.clamp(relative_pos,-max_position+1, max_position-1) + max_position
bucket_dict = make_log_bucket_dict(bucket_size, max_position, relative_pos.device)
for d in range(relative_pos.dim()-1):
bucket_dict = bucket_dict.unsqueeze(0)
bucket_pos = torch.gather(bucket_dict.expand(list(relative_pos.size())[:-1] + [bucket_dict.size(-1)]), index=relative_pos.long(), dim=-1)
return bucket_pos
@lru_cache(maxsize=128)
def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1, device=None):
q_ids = torch.arange(0, query_size)
k_ids = torch.arange(0, key_size)
if device is not None:
q_ids = q_ids.to(device)
k_ids = k_ids.to(device)
rel_pos_ids = q_ids.view(-1,1) - k_ids.view(1,-1)
#q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0],1))
if bucket_size>0 and max_position > 0:
rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
#rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
rel_pos_ids = rel_pos_ids[:query_size, :]
rel_pos_ids = rel_pos_ids.unsqueeze(0)
return rel_pos_ids
def build_relative_position_from_abs(query_pos, key_pos, bucket_size=-1, max_position=-1, device=None):
if isinstance(query_pos, tuple):
q_ids = torch.tensor(query_pos)
else:
q_ids = query_pos
if isinstance(key_pos, tuple):
k_ids = torch.tensor(key_pos)
else:
k_ids = key_pos
if device is not None:
q_ids = q_ids.to(device)
k_ids = k_ids.to(device)
rel_pos_ids = q_ids.unsqueeze(-1) - k_ids.unsqueeze(-2)
#q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0],1))
if bucket_size>0 and max_position > 0:
rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
#rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
return rel_pos_ids
def test_log_bucket():
x=np.arange(-511,511)
y=make_log_bucket_position(x, 128, 512)
pdb.set_trace()
|