| import math |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from Affine import Affine |
|
|
| |
def get_relative_mat(height, width, k=0):
    """Return a (height, width) integer matrix of |row - col| distances.

    Row indices run from k to k+height-1 (k shifts the query block's
    absolute position); column indices run from 0 to width-1.
    """
    rows = np.arange(k, height + k).reshape(-1, 1)
    cols = np.arange(width).reshape(1, -1)
    # Broadcasted outer difference, then elementwise absolute value.
    return np.abs(rows - cols)
| |
| |
def get_relative_dist(i, j, block_size, i_end, j_end, is_cross_attention):
    """Relative |query_pos - key_pos| distance matrix for one attention block.

    Args:
        i, j: top-left (query, key) absolute offsets of the block.
        block_size: 0 means dense attention over the whole matrix.
        i_end, j_end: total query / key sequence lengths.
        is_cross_attention: cross attention attends a query row-block to
            every key position.

    Returns:
        A float32 ndarray.  (Fix: the dense and cross-attention paths
        previously returned the raw integer dtype from get_relative_mat,
        inconsistent with the blocked path and reliant on implicit type
        promotion when multiplied by the affine decay downstream.)
    """
    if block_size == 0:
        # Dense attention: one (i_end, j_end) matrix for the whole sequence.
        assert i == 0 and j == 0, "i!=0 or j!=0"
        return get_relative_mat(i_end, j_end, k=0).astype(np.float32)
    if is_cross_attention:
        # One query block (possibly clipped at the end) against all keys.
        return get_relative_mat(min(block_size, i_end - i), j_end, k=i).astype(np.float32)

    # Blocked self attention: each query block sees the previous, current
    # and next key blocks, i.e. a (block_size, 3*block_size) window.
    height = block_size
    width = block_size * 3

    rela_dist = get_relative_mat(height, width, k=block_size + i - j)

    # Clip where the 3-block window sticks out past the sequence ends.
    down_out = max(0, i + height - i_end)
    left_out = max(0, block_size - j)
    right_out = max(0, j + block_size * 2 - j_end)

    rela_dist = rela_dist[:height - down_out, left_out:width - right_out]
    return rela_dist.astype(np.float32)
|
|
| |
def get_absolute_mask(i, j, block_size, i_end, j_end, is_cross_attention):
    """Upper-triangular boolean mask (True on and above the diagonal) for
    one attention block; scores are penalized where this mask is False.
    Mirrors the block geometry of get_relative_dist.
    """
    if block_size == 0:
        # Dense attention over the whole (i_end, j_end) score matrix.
        assert i == 0 and j == 0, "i!=0 or j!=0"
        return np.triu(np.ones((i_end, j_end), dtype='bool'), k=0)
    if is_cross_attention:
        # Query row-block (clipped at the sequence end) against all keys.
        rows = min(block_size, i_end - i)
        return np.triu(np.ones((rows, j_end), dtype='bool'), k=i)

    # Blocked self attention: (block_size, 3*block_size) window.
    rows = block_size
    cols = block_size * 3
    full = np.triu(np.ones((rows, cols), dtype='bool'), k=block_size + i - j)

    # Clip the window where it overhangs the sequence boundaries.
    bottom = rows - max(0, i + rows - i_end)
    left = max(0, block_size - j)
    right = cols - max(0, j + block_size * 2 - j_end)
    return full[:bottom, left:right]
|
|
| |
def get_std_mask(i, j, block_size, i_end, j_end, is_cross_attention):
    """Causal "may attend" mask: True where the key position is not in the
    query's future.  Equivalent to the complement of triu(k+1), expressed
    directly as tril(k).
    """
    if block_size == 0:
        # Dense attention: query t sees keys 0..t inclusive.
        assert i == 0 and j == 0, "i!=0 or j!=0"
        return np.tril(np.ones((i_end, j_end), dtype='bool'), k=0)
    if is_cross_attention:
        # Query row-block (clipped at the sequence end) against all keys.
        rows = min(block_size, i_end - i)
        return np.tril(np.ones((rows, j_end), dtype='bool'), k=i)

    # Blocked self attention: (block_size, 3*block_size) window.
    rows = block_size
    cols = block_size * 3
    visible = np.tril(np.ones((rows, cols), dtype='bool'), k=block_size + i - j)

    # Clip the window where it overhangs the sequence boundaries.
    bottom = rows - max(0, i + rows - i_end)
    left = max(0, block_size - j)
    right = cols - max(0, j + block_size * 2 - j_end)
    return visible[:bottom, left:right]
|
|
| |
def ident(p_list):
    """Build the string cache key for a mask/distance request.

    p_list is [tag, i, j, block_size, i_end, j_end, is_cross_attention],
    where tag is 'r' (relative dist), 'a' (absolute mask) or 'f' (future
    mask).  The key encodes the output geometry: 'f' keys use a diagonal
    offset one greater than 'r'/'a' keys, matching the k=+1 in get_std_mask.
    """
    tag, i, j, block_size, i_end, j_end, is_cross_attention = p_list
    # 'r' and 'a' masks share the triu/distance offset; 'f' shifts it by 1.
    shift = 0 if (tag == 'r' or tag == 'a') else 1
    parts = [tag]
    if block_size == 0:
        parts += [i_end, j_end, shift]
    elif is_cross_attention:
        parts += [min(block_size, i_end - i), j_end, shift + i]
    else:
        height = block_size
        width = block_size * 3
        parts += [height, width, shift + block_size + i - j]
        # Clipped window extents, as in the mask builders.
        down_out = max(0, i + height - i_end)
        left_out = max(0, block_size - j)
        right_out = max(0, j + block_size * 2 - j_end)
        parts += [height - down_out, left_out, width - right_out]
    return str(parts)
|
|
| |
# Module-level cache of precomputed mask/distance tensors, keyed by the
# string produced by ident().  Bounded to roughly 13 entries by reg().
reg_dict = dict()
# Per-key usage counters (never evicted); drive the eviction heuristic in reg().
reg_timer = dict()
|
|
| |
def un_reg(p):
    """Return True iff key p currently has no cached entry in reg_dict."""
    return p not in reg_dict
|
|
| |
def reg(p, v):
    """Cache value v under key p, evicting the least-used entry when full.

    The cache (reg_dict) holds at most ~13 entries; reg_timer counts how
    often each key has been requested and is never pruned.  A new entry is
    admitted when there is room, or when it has been requested more often
    than the current least-used resident.
    """
    keys = list(reg_dict)
    time_min = 0
    if keys:
        # First key with the minimal usage count (min() keeps the first
        # minimum, matching a strict '<' scan).
        key_min = min(keys, key=lambda k: reg_timer[k])
        time_min = reg_timer[key_min]

    # Bump (or start) the request counter for this key.
    if p in reg_timer:
        reg_timer[p] += 1
    else:
        reg_timer[p] = 1

    # Over capacity: drop the least-used resident first.
    if len(keys) > 12:
        del reg_dict[key_min]

    # Admit if there is room, or if p is now hotter than the evictee was.
    if reg_timer[p] > time_min or len(keys) < 12:
        reg_dict[p] = v
|
|
| |
def get_reg(p):
    """Fetch the cached value for key p, bumping its usage counter."""
    reg_timer[p] = reg_timer[p] + 1
    return reg_dict[p]
|
|
| |
def diff_attention(query, q_mask, key, value, \
        absolute_affine, relative_affine, diff_affine, \
        talking_before_softmax, talking_after_softmax, \
        self_attention_block_size, cross_attention_block_size, \
        mask_future, is_cross_attention):
    """Differential attention: weights are softmax(pos scores) minus an
    affine-scaled softmax(neg scores).

    query/key each stack two projections along the last dim ("negative"
    half first, "positive" half second), each query_dim wide.  With
    block_size == 0 the full dense score matrix is formed; otherwise the
    computation is blocked: each query block of block_size rows attends to
    the previous/current/next key blocks (self attention) or to all keys
    (cross attention).  Masks and distance matrices are cached in the
    module-level registry (reg/get_reg) keyed by ident().

    NOTE(review): diff_affine is applied unconditionally, so it must not be
    None here even though DiffMultiHeadAttention sets it to None when
    enable_affine is False — confirm callers always enable affine.
    NOTE(review): assumes q_mask is a boolean (batch, seq) tensor — the
    later `&` with the causal mask requires bool; TODO confirm at callers.
    """
    # Pick the block size for this attention flavor.
    if is_cross_attention:
        block_size = cross_attention_block_size
    else:
        block_size = self_attention_block_size

    # Half-width of the doubled query/key projections.
    query_dim = query.size(-1)//2

    # Broadcast the padding mask over the head dimension:
    # (batch, seq) -> (batch, heads, seq).
    q_mask = q_mask.unsqueeze(1).expand(*(query.size()[:-1]))
    # Split the stacked projections: upper half = positive, lower = negative.
    query_p = query[...,query_dim:]
    query_n = query[...,:query_dim]
    key_p = key[...,query_dim:]
    key_n = key[...,:query_dim]

    if block_size == 0:
        # ---- Dense path: full (seq_q, seq_k) score matrices. ----
        scores_p = torch.matmul(query_p,key_p.transpose(-1,-2))/math.sqrt(query_dim)
        scores_n = torch.matmul(query_n,key_n.transpose(-1,-2))/math.sqrt(query_dim)

        # Relative-position decay: scores *= 1 / (1 + a*|i-j|).
        if relative_affine is not None:
            p = ident(['r',0,0,0,query.size(-2),key.size(-2),is_cross_attention])
            if un_reg(p):
                rela_dist = get_relative_dist(0,0,0,query.size(-2),key.size(-2),is_cross_attention)
                rela_dist = torch.from_numpy(rela_dist).detach().to(query.device)
                reg(p,rela_dist)
            else:
                rela_dist = get_reg(p)
            dist_decay= rela_dist.mul(relative_affine(1.0)).add(1.0).reciprocal()
            scores_p = scores_p.mul(dist_decay)
            scores_n = scores_n.mul(dist_decay)

        # Absolute-position penalty: subtract a learned constant below the
        # diagonal (where abs_mask is False).
        if absolute_affine is not None:
            p = ident(['a',0,0,0,query.size(-2),key.size(-2),is_cross_attention])
            if un_reg(p):
                abs_mask = get_absolute_mask(0,0,0,query.size(-2),key.size(-2),is_cross_attention)
                abs_mask = torch.from_numpy(abs_mask).unsqueeze_(0).unsqueeze_(0).detach().to(query.device)
                reg(p,abs_mask)
            else:
                abs_mask = get_reg(p)
            abs_mask = abs_mask.expand(*(scores_p.size()))
            value_to_sub = absolute_affine(1.0)
            scores_p = torch.where(abs_mask == 0, scores_p - value_to_sub, scores_p)
            scores_n = torch.where(abs_mask == 0, scores_n - value_to_sub, scores_n)

        # Talking-heads mixing across the head dim before softmax
        # (head dim is moved to last via transpose for the Linear).
        if talking_before_softmax is not None:
            scores_p = talking_before_softmax(scores_p.transpose(-1,-3)).transpose(-1,-3)
            scores_n = talking_before_softmax(scores_n.transpose(-1,-3)).transpose(-1,-3)

        # Causal + padding masking.  -1e3 (not -inf) keeps softmax finite
        # even for fully-masked rows.
        if mask_future == True:
            p = ident(['f',0,0,0,query.size(-2),key.size(-2),is_cross_attention])
            if un_reg(p):
                std_mask = get_std_mask(0,0,0,query.size(-2),key.size(-2),is_cross_attention)
                std_mask = torch.from_numpy(std_mask).unsqueeze_(0).unsqueeze_(0).detach().to(query.device)
                reg(p,std_mask)
            else:
                std_mask = get_reg(p)
            std_mask = std_mask.expand(*(scores_p.size()))
            # NOTE(review): in-place unsqueeze_ mutates the local q_mask view;
            # harmless here because q_mask is not reused afterwards.
            std_mask = q_mask.unsqueeze_(-1).expand(*(std_mask.size())) & std_mask
            scores_p.masked_fill_(std_mask == 0.0,-1e3)
            scores_n.masked_fill_(std_mask == 0.0,-1e3)

        # Differential weights: positive softmax minus scaled negative one.
        p_attn = F.softmax(scores_p, dim = -1) - diff_affine(F.softmax(scores_n, dim = -1))

        # Talking-heads mixing after softmax.
        if talking_after_softmax is not None:
            p_attn = talking_after_softmax(p_attn.transpose(-1,-3)).transpose(-1,-3)

        ret = torch.matmul(p_attn, value)
    else:
        # ---- Blocked path: process block_size query rows at a time. ----
        # Output buffer; assumes value's feature dim equals query_dim
        # (key_dim per head) — TODO confirm against DiffMultiHeadAttention.
        ret = torch.zeros_like(query_p)

        for i in range(0,query.size(-2),block_size):
            # Current query block (last one may be shorter).
            query_block_p = query_p[...,i:i+block_size,:]
            query_block_n = query_n[...,i:i+block_size,:]
            q_mask_block = q_mask[...,i:i+block_size]
            if is_cross_attention:
                # Cross attention: the block sees every key/value position.
                key_block_p = key_p
                key_block_n = key_n
                value_block = value
            else:
                # Self attention: previous, current and next key blocks.
                key_block_p = key_p[...,max(0,i-block_size):i+block_size*2,:]
                key_block_n = key_n[...,max(0,i-block_size):i+block_size*2,:]
                value_block = value[...,max(0,i-block_size):i+block_size*2,:]

            scores_p = torch.matmul(query_block_p,key_block_p.transpose(-1,-2))/math.sqrt(query_dim)
            scores_n = torch.matmul(query_block_n,key_block_n.transpose(-1,-2))/math.sqrt(query_dim)

            # Relative-position decay for this block's window.
            if relative_affine is not None:
                p = ident(['r',i,i,block_size,query.size(-2),key.size(-2),is_cross_attention])
                if un_reg(p):
                    rela_dist = get_relative_dist(i,i,block_size,query.size(-2),key.size(-2),is_cross_attention)
                    rela_dist = torch.from_numpy(rela_dist).detach().to(query.device)
                    reg(p,rela_dist)
                else:
                    rela_dist = get_reg(p)

                dist_decay= rela_dist.mul(relative_affine(1.0)).add(1.0).reciprocal()
                scores_p = scores_p.mul(dist_decay)
                scores_n = scores_n.mul(dist_decay)

            # Absolute-position penalty for this block's window.
            if absolute_affine is not None:
                p = ident(['a',i,i,block_size,query.size(-2),key.size(-2),is_cross_attention])
                if un_reg(p):
                    abs_mask = get_absolute_mask(i,i,block_size,query.size(-2),key.size(-2),is_cross_attention)
                    abs_mask = torch.from_numpy(abs_mask).unsqueeze_(0).unsqueeze_(0).detach().to(query.device)
                    reg(p,abs_mask)
                else:
                    abs_mask = get_reg(p)
                abs_mask = abs_mask.expand(*(scores_p.size()))
                value_to_sub = absolute_affine(1.0)
                scores_p = torch.where(abs_mask == 0, scores_p - value_to_sub, scores_p)
                scores_n = torch.where(abs_mask == 0, scores_n - value_to_sub, scores_n)

            # Talking-heads mixing before softmax.
            if talking_before_softmax is not None:
                scores_p = talking_before_softmax(scores_p.transpose(-1,-3)).transpose(-1,-3)
                scores_n = talking_before_softmax(scores_n.transpose(-1,-3)).transpose(-1,-3)

            # Causal + padding masking for this block.
            if mask_future == True:
                p = ident(['f',i,i,block_size,query.size(-2),key.size(-2),is_cross_attention])
                if un_reg(p):
                    std_mask = get_std_mask(i,i,block_size,query.size(-2),key.size(-2),is_cross_attention)
                    std_mask = torch.from_numpy(std_mask).unsqueeze_(0).unsqueeze_(0).detach().to(query.device)
                    reg(p,std_mask)
                else:
                    std_mask = get_reg(p)
                std_mask = std_mask.expand(*(scores_p.size()))
                std_mask = q_mask_block.unsqueeze(-1).expand(*(std_mask.size())) & std_mask
                scores_p.masked_fill_(std_mask == 0.0,-1e3)
                scores_n.masked_fill_(std_mask == 0.0,-1e3)

            # Differential weights for this block.
            p_attn = F.softmax(scores_p, dim = -1) - diff_affine(F.softmax(scores_n, dim = -1))

            # Talking-heads mixing after softmax.
            if talking_after_softmax is not None:
                p_attn = talking_after_softmax(p_attn.transpose(-1,-3)).transpose(-1,-3)

            # Write the block's context vectors into the output buffer.
            ret[...,i:i+block_size,:] = torch.matmul(p_attn, value_block)
    return ret
|
|
| |
class DiffMultiHeadAttention(nn.Module):
    """Multi-head attention module whose weights are the difference of two
    softmaxes (see diff_attention).

    Query/key projections are double width (key_dim*head_number*2) so each
    head carries a "positive" and a "negative" half.  Optional extras:
    per-projection Affine scalers, talking-heads mixing layers, and
    mask-based position information (absolute/relative affines).
    """

    def __init__(self,embedding_dim,key_dim,head_number,position_information_type,enable_affine,enable_talking_head, \
            self_attention_block_size,cross_attention_block_size,dropout_rate):
        super(DiffMultiHeadAttention, self).__init__()
        # Plain configuration, kept for forward().
        self.embedding_dim = embedding_dim
        self.key_dim = key_dim
        self.head_number = head_number
        self.position_information_type = position_information_type
        self.enable_talking_head = enable_talking_head
        self.self_attention_block_size = self_attention_block_size
        self.cross_attention_block_size = cross_attention_block_size
        self.dropout_layer = nn.Dropout(p=dropout_rate)
        self.enable_affine = enable_affine

        # Projections: Q and K are doubled for the positive/negative halves.
        inner = key_dim * head_number
        self.query_w = nn.Linear(embedding_dim, inner * 2, bias=False)
        self.key_w = nn.Linear(embedding_dim, inner * 2, bias=False)
        self.value_w = nn.Linear(embedding_dim, inner, bias=False)
        self.out_w = nn.Linear(inner, embedding_dim, bias=False)

        if enable_affine == True:
            # Learned scalar affine after each projection, plus the scale
            # applied to the negative softmax in diff_attention.
            self.query_a = Affine(1.0)
            self.key_a = Affine(1.0)
            self.value_a = Affine(1.0)
            self.out_a = Affine(1.0)
            self.diff_affine = Affine(1.0)
        else:
            # NOTE(review): diff_attention applies diff_affine unconditionally,
            # so a None here would fail there — confirm intended usage.
            self.diff_affine = None

        if enable_talking_head == True:
            self.talking_before_softmax = nn.Linear(head_number, head_number, bias=False)
            self.talking_after_softmax = nn.Linear(head_number, head_number, bias=False)
        else:
            self.talking_before_softmax = None
            self.talking_after_softmax = None

        if position_information_type == "mask":
            self.absolute_affine = Affine(1.0, grad_factor=1.0)
            self.relative_affine = Affine(0.1, grad_factor=1.0)
        else:
            self.absolute_affine = None
            self.relative_affine = None

    def forward(self, query, q_mask, key_value, mask_future, is_cross_attention):
        # Project inputs (key and value share the same source tensor).
        q = self.query_w(query)
        k = self.key_w(key_value)
        v = self.value_w(key_value)

        if self.enable_affine == True:
            q = self.query_a(q)
            k = self.key_a(k)
            v = self.value_a(v)

        # Reshape to (batch, heads, seq, per-head dim); Q/K keep the
        # doubled width, split again inside diff_attention.
        batch_size = q.size(0)
        q = q.view(batch_size, -1, self.head_number, self.key_dim * 2).transpose(1, 2)
        k = k.view(batch_size, -1, self.head_number, self.key_dim * 2).transpose(1, 2)
        v = v.view(batch_size, -1, self.head_number, self.key_dim).transpose(1, 2)

        out = diff_attention(q, q_mask, k, v, \
                self.absolute_affine, self.relative_affine, self.diff_affine, \
                self.talking_before_softmax, self.talking_after_softmax, \
                self_attention_block_size = self.self_attention_block_size, \
                cross_attention_block_size = self.cross_attention_block_size, \
                mask_future = mask_future, is_cross_attention = is_cross_attention)

        # Merge heads back and apply the output projection + dropout.
        merged = out.transpose(1, 2).contiguous().view(batch_size, -1, self.head_number * self.key_dim)
        if self.enable_affine:
            return self.dropout_layer(self.out_a(self.out_w(merged)))
        return self.dropout_layer(self.out_w(merged))