| | import torch |
| | import pdb |
| | from functools import lru_cache |
| | import numpy as np |
| | import math |
| |
|
| | __all__=['build_relative_position', 'make_log_bucket_position'] |
| |
|
| | @lru_cache(maxsize=128) |
| | def make_log_bucket_dict(bucket_size, max_position, device=None): |
| | relative_pos = torch.arange(-max_position, max_position, device=device) |
| | sign = torch.sign(relative_pos) |
| | mid = bucket_size//2 |
| | abs_pos = torch.where((relative_pos<mid) & (relative_pos > -mid), torch.tensor(mid-1).to(relative_pos), torch.abs(relative_pos)) |
| | log_pos = torch.ceil(torch.log(abs_pos/mid)/math.log((max_position-1)/mid) * (mid-1)) + mid |
| | bucket_pos = torch.where(abs_pos<=mid, relative_pos, (log_pos*sign).to(relative_pos)).to(torch.long) |
| | return bucket_pos |
| |
|
| | |
| | def make_log_bucket_position(relative_pos, bucket_size, max_position): |
| | relative_pos = torch.clamp(relative_pos,-max_position+1, max_position-1) + max_position |
| | bucket_dict = make_log_bucket_dict(bucket_size, max_position, relative_pos.device) |
| | for d in range(relative_pos.dim()-1): |
| | bucket_dict = bucket_dict.unsqueeze(0) |
| | bucket_pos = torch.gather(bucket_dict.expand(list(relative_pos.size())[:-1] + [bucket_dict.size(-1)]), index=relative_pos.long(), dim=-1) |
| | return bucket_pos |
| |
|
| | @lru_cache(maxsize=128) |
| | def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1, device=None): |
| | q_ids = torch.arange(0, query_size) |
| | k_ids = torch.arange(0, key_size) |
| | if device is not None: |
| | q_ids = q_ids.to(device) |
| | k_ids = k_ids.to(device) |
| | rel_pos_ids = q_ids.view(-1,1) - k_ids.view(1,-1) |
| | |
| | if bucket_size>0 and max_position > 0: |
| | rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position) |
| | |
| | rel_pos_ids = rel_pos_ids[:query_size, :] |
| | rel_pos_ids = rel_pos_ids.unsqueeze(0) |
| | return rel_pos_ids |
| |
|
| | def build_relative_position_from_abs(query_pos, key_pos, bucket_size=-1, max_position=-1, device=None): |
| | if isinstance(query_pos, tuple): |
| | q_ids = torch.tensor(query_pos) |
| | else: |
| | q_ids = query_pos |
| | if isinstance(key_pos, tuple): |
| | k_ids = torch.tensor(key_pos) |
| | else: |
| | k_ids = key_pos |
| |
|
| | if device is not None: |
| | q_ids = q_ids.to(device) |
| | k_ids = k_ids.to(device) |
| | rel_pos_ids = q_ids.unsqueeze(-1) - k_ids.unsqueeze(-2) |
| | |
| | if bucket_size>0 and max_position > 0: |
| | rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position) |
| | |
| | return rel_pos_ids |
| |
|
| | def test_log_bucket(): |
| | x=np.arange(-511,511) |
| | y=make_log_bucket_position(x, 128, 512) |
| | pdb.set_trace() |
| |
|
| |
|