File size: 2,721 Bytes
ab0f6ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
import pdb
from functools import lru_cache
import numpy as np
import math

# Public names exported by `from <this module> import *`.
# NOTE(review): build_relative_position_from_abs and make_log_bucket_dict are
# not listed here; confirm whether that omission is intentional.
__all__=['build_relative_position', 'make_log_bucket_position']

@lru_cache(maxsize=128)
def make_log_bucket_dict(bucket_size, max_position, device=None):
  """Build the lookup table of log-scaled bucket ids for relative positions.

  Slot ``i`` of the returned 1-D ``torch.long`` tensor (length
  ``2 * max_position``) holds the bucket id of relative position
  ``i - max_position``. With ``half = bucket_size // 2``:

  * positions with ``|p| <= half`` map to themselves;
  * farther positions map to
    ``sign(p) * (ceil(log(|p| / half) / log((max_position - 1) / half) * (half - 1)) + half)``.

  The result is memoized by ``lru_cache``. Repeated calls with equal
  arguments return the *same* tensor object, so callers must not modify it
  in place.
  """
  positions = torch.arange(-max_position, max_position, device=device)
  half = bucket_size // 2
  # Inside the linear zone, substitute half - 1 so the log below stays finite
  # (no log(0)). Those slots keep their raw value in the final select anyway.
  near = (positions > -half) & (positions < half)
  filler = torch.tensor(half - 1).to(positions)
  magnitude = torch.where(near, filler, positions.abs())
  log_scale = math.log((max_position - 1) / half)
  log_bucket = torch.ceil(torch.log(magnitude / half) / log_scale * (half - 1)) + half
  signed_log = (log_bucket * torch.sign(positions)).to(positions)
  return torch.where(magnitude <= half, positions, signed_log).to(torch.long)

# Faster version: buckets are read from the lru_cached table built by
# make_log_bucket_dict instead of being recomputed per call.
def make_log_bucket_position(relative_pos, bucket_size, max_position):
  """Map relative positions to log-scaled bucket ids via a cached lookup table.

  Args:
    relative_pos: tensor of relative positions, of any shape (0-d and 1-D
      included). Non-tensor array-likes such as numpy arrays or lists are
      converted with ``torch.as_tensor``.
    bucket_size: bucket budget forwarded to ``make_log_bucket_dict``.
    max_position: positions are clamped to
      ``[-max_position + 1, max_position - 1]`` before the lookup.

  Returns:
    A ``torch.long`` tensor with the same shape as ``relative_pos`` holding
    each element's bucket id. It is a fresh tensor, not a view of the
    cached table.
  """
  if not torch.is_tensor(relative_pos):
    relative_pos = torch.as_tensor(relative_pos)
  # Table slot i corresponds to relative position i - max_position. After
  # clamping and shifting, every index lies in [1, 2 * max_position - 1].
  index = torch.clamp(relative_pos, -max_position + 1, max_position - 1) + max_position
  bucket_dict = make_log_bucket_dict(bucket_size, max_position, relative_pos.device)
  # Advanced indexing of the 1-D table with a long tensor returns a copy
  # shaped like `index`, for every rank. This equals gathering along the last
  # dim of a broadcast table, but also works for 1-D/0-d input and never
  # aliases the cached table.
  return bucket_dict[index.long()]

@lru_cache(maxsize=128)
def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1, device=None):
  """Return the query-minus-key relative position matrix with a leading batch dim.

  Element ``[0, i, j]`` of the result is ``i - j`` for ``i`` in
  ``range(query_size)`` and ``j`` in ``range(key_size)``, giving shape
  ``(1, query_size, key_size)``. When both ``bucket_size`` and
  ``max_position`` are positive, the values are replaced by their log bucket
  ids from ``make_log_bucket_position``. If ``device`` is given, the index
  vectors are moved there first.

  The result is memoized by ``lru_cache``. Calls with equal arguments share
  one tensor object, so do not modify it in place.
  """
  query_index = torch.arange(0, query_size)
  key_index = torch.arange(0, key_size)
  if device is not None:
    query_index, key_index = query_index.to(device), key_index.to(device)
  # Column vector minus row vector broadcasts to (query_size, key_size).
  relative = query_index[:, None] - key_index[None, :]
  use_buckets = bucket_size > 0 and max_position > 0
  if use_buckets:
    relative = make_log_bucket_position(relative, bucket_size, max_position)
  return relative.unsqueeze(0)

def build_relative_position_from_abs(query_pos, key_pos, bucket_size=-1, max_position=-1, device=None):
  """Build relative positions from absolute query and key positions.

  Args:
    query_pos: absolute query positions, as a tensor or any array-like
      (tuple, list, numpy array) accepted by ``torch.as_tensor``.
    key_pos: absolute key positions, with the same accepted types. Leading
      dimensions must broadcast against those of ``query_pos``.
    bucket_size: when this and ``max_position`` are both positive, the result
      is mapped to log buckets by ``make_log_bucket_position``.
    max_position: see ``bucket_size``.
    device: optional device both position tensors are moved to.

  Returns:
    Tensor of shape ``(..., len_q, len_k)``. Element ``[..., i, j]`` is
    ``query_pos[..., i] - key_pos[..., j]``, or its bucket id when bucketing
    is enabled.
  """
  # Accept any array-like, not only tuples. torch.as_tensor converts
  # tuples, lists and numpy arrays, and returns tensors unchanged.
  q_ids = torch.as_tensor(query_pos)
  k_ids = torch.as_tensor(key_pos)
  if device is not None:
    q_ids = q_ids.to(device)
    k_ids = k_ids.to(device)
  # (..., len_q, 1) - (..., 1, len_k) broadcasts to (..., len_q, len_k).
  rel_pos_ids = q_ids.unsqueeze(-1) - k_ids.unsqueeze(-2)
  if bucket_size > 0 and max_position > 0:
    rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
  return rel_pos_ids

def test_log_bucket():
  """Sanity-check make_log_bucket_position over the full clamped range.

  Uses assertions instead of a ``pdb.set_trace()`` breakpoint, which halts
  any run of this ``test_``-prefixed function waiting for debugger input.
  The input is a torch tensor, since ``torch.clamp`` rejects numpy arrays.
  """
  bucket_size, max_position = 128, 512
  mid = bucket_size // 2
  # Shape (1, 1022); positions -511..510 lie inside the clamp range.
  x = torch.arange(-511, 511).unsqueeze(0)
  y = make_log_bucket_position(x, bucket_size, max_position)
  assert y.shape == x.shape
  assert y.dtype == torch.long
  # Positions within bucket_size // 2 keep their raw value.
  near = x.abs() <= mid
  assert torch.equal(y[near], x[near])
  # Bucket ids never decrease as the relative position increases.
  assert bool((y[..., 1:] >= y[..., :-1]).all())