# zh_0.18B_LLM / DiffMultiHeadAttention.py
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from Affine import Affine
# Build the relative-position (distance) matrix.
def get_relative_mat(height, width, k=0):
    posi_i = np.arange(k, height + k)  # row (query) position range
    posi_j = np.arange(0, width)       # column (key) position range
    posi_grid = np.meshgrid(posi_i, posi_j, indexing='ij')
    return abs(posi_grid[0] - posi_grid[1])
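# Illustrative example (values computed from the definition above):
#   get_relative_mat(3, 4, k=1) ->
#   [[1 0 1 2]
#    [2 1 0 1]
#    [3 2 1 0]]
# Entry [r, c] is |(r + k) - c|, i.e. the distance between query position r + k
# and key position c.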
# Relative-distance matrix for one attention block (carries relative position information).
def get_relative_dist(i, j, block_size, i_end, j_end, is_cross_attention):
    if block_size == 0:
        assert i == 0 and j == 0, "i!=0 or j!=0"
        return get_relative_mat(i_end, j_end, k=0)
    if is_cross_attention:
        return get_relative_mat(min(block_size, i_end - i), j_end, k=i)
    # i, j: start positions of the current block
    # block_size: block size
    # i_end, j_end: sequence lengths
    height = block_size     # height: length sliced from the first (query) sequence, equal to the block size
    width = block_size * 3  # width: length sliced from the second (key) sequence; the previous and next blocks are included for longer context
    # Build the relative-distance matrix for this block.
    # Larger i shifts more of the window into view, larger j shifts it out; +block_size accounts for the previous block being visible.
    rela_dist = get_relative_mat(height, width, k=block_size + i - j)
    # Handle the window running past the sequence boundaries.
    # Overflow below the end of the query sequence
    down_out = max(0, i + height - i_end)
    # Overflow past the left edge of the key sequence
    left_out = max(0, block_size - j)
    # Overflow past the right edge of the key sequence
    right_out = max(0, j + block_size * 2 - j_end)
    # Crop to the in-bounds region
    rela_dist = rela_dist[:height - down_out, left_out:width - right_out]
    return rela_dist.astype(np.float32)
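# Illustrative examples (shapes follow from the cropping logic above):
#   get_relative_dist(0, 0, 0, 3, 3, False) ->
#   [[0. 1. 2.]
#    [1. 0. 1.]
#    [2. 1. 0.]]
#   get_relative_dist(4, 4, 4, 10, 10, False).shape == (4, 10)
#   (a 4-query block at global positions 4..7 attending to key positions 0..9,
#    after the right edge of the 3*block_size key window is cropped).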
# Mask used to inject absolute position information.
def get_absolute_mask(i, j, block_size, i_end, j_end, is_cross_attention):
    if block_size == 0:
        assert i == 0 and j == 0, "i!=0 or j!=0"
        return np.triu(np.ones((i_end, j_end), dtype='bool'), k=0)
    if is_cross_attention:
        return np.triu(np.ones((min(block_size, i_end - i), j_end), dtype='bool'), k=i)
    # i, j: start positions of the current block
    # block_size: block size
    # i_end, j_end: sequence lengths
    height = block_size     # height: length sliced from the first (query) sequence, equal to the block size
    width = block_size * 3  # width: length sliced from the second (key) sequence; the previous and next blocks are included for longer context
    # Build the block-level mask; the upper-triangular part marks keys at or after the query position.
    # Larger i shifts more of the window into view, larger j shifts it out; +block_size accounts for the previous block being visible.
    abs_mask = np.triu(np.ones((height, width), dtype='bool'), k=block_size + i - j)
    # Handle the window running past the sequence boundaries.
    # Overflow below the end of the query sequence
    down_out = max(0, i + height - i_end)
    # Overflow past the left edge of the key sequence
    left_out = max(0, block_size - j)
    # Overflow past the right edge of the key sequence
    right_out = max(0, j + block_size * 2 - j_end)
    # Crop to the in-bounds region
    abs_mask = abs_mask[:height - down_out, left_out:width - right_out]
    return abs_mask
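# Illustrative example: get_absolute_mask(0, 0, 0, 3, 3, False) ->
#   [[ True  True  True]
#    [False  True  True]
#    [False False  True]]
# The False entries (keys strictly before the query) are the positions whose
# scores get absolute_affine(1.0) subtracted in diff_attention below, which is
# how absolute position information is injected.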
# Standard causal mask that hides future information.
def get_std_mask(i, j, block_size, i_end, j_end, is_cross_attention):
    if block_size == 0:
        assert i == 0 and j == 0, "i!=0 or j!=0"
        return np.triu(np.ones((i_end, j_end), dtype='bool'), k=1) == False
    if is_cross_attention:
        return np.triu(np.ones((min(block_size, i_end - i), j_end), dtype='bool'), k=1 + i) == False
    # i, j: start positions of the current block
    # block_size: block size
    # i_end, j_end: sequence lengths
    height = block_size     # height: length sliced from the first (query) sequence, equal to the block size
    width = block_size * 3  # width: length sliced from the second (key) sequence; the previous and next blocks are included for longer context
    # Build the standard mask used to hide future information.
    # Larger i shifts more of the window into view, larger j shifts it out; +block_size accounts for the previous block being visible.
    std_mask = np.triu(np.ones((height, width), dtype='bool'), k=1 + block_size + i - j)
    # Handle the window running past the sequence boundaries.
    # Overflow below the end of the query sequence
    down_out = max(0, i + height - i_end)
    # Overflow past the left edge of the key sequence
    left_out = max(0, block_size - j)
    # Overflow past the right edge of the key sequence
    right_out = max(0, j + block_size * 2 - j_end)
    # Crop to the in-bounds region
    std_mask = std_mask[:height - down_out, left_out:width - right_out]
    return std_mask == False
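# Illustrative example: get_std_mask(0, 0, 0, 3, 3, False) ->
#   [[ True False False]
#    [ True  True False]
#    [ True  True  True]]
# True marks visible positions (key index <= query index); future keys are hidden.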
# Build a string key identifying a tensor that will be reused (cache key).
def ident(p_list):
    i, j, block_size, i_end, j_end, is_cross_attention = p_list[1:]
    ret = [p_list[0]]
    if p_list[0] == 'r' or p_list[0] == 'a':
        if block_size == 0:
            ret += [i_end, j_end, 0]
        elif is_cross_attention:
            ret += [min(block_size, i_end - i), j_end, i]
        else:
            height = block_size
            width = block_size * 3
            ret += [height, width, block_size + i - j]
            down_out = max(0, i + height - i_end)
            left_out = max(0, block_size - j)
            right_out = max(0, j + block_size * 2 - j_end)
            ret += [height - down_out, left_out, width - right_out]
    else:
        if block_size == 0:
            ret += [i_end, j_end, 1]
        elif is_cross_attention:
            ret += [min(block_size, i_end - i), j_end, 1 + i]
        else:
            height = block_size
            width = block_size * 3
            ret += [height, width, 1 + block_size + i - j]
            down_out = max(0, i + height - i_end)
            left_out = max(0, block_size - j)
            right_out = max(0, j + block_size * 2 - j_end)
            ret += [height - down_out, left_out, width - right_out]
    return str(ret)
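# Illustrative examples of the generated cache keys:
#   ident(['r', 0, 0, 0, 8, 8, False]) == "['r', 8, 8, 0]"
#   ident(['f', 0, 0, 0, 8, 8, False]) == "['f', 8, 8, 1]"
# The prefix distinguishes relative-distance ('r'), absolute-mask ('a') and
# future-mask ('f') tensors; the remaining fields pin down the exact shape and offsets.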
# Cache dictionary and usage counters.
reg_dict = dict()
reg_timer = dict()
# Check whether a key has not been registered yet.
def un_reg(p):
    return p not in reg_dict
# Register a tensor that will be reused.
def reg(p, v):
    # Find the least-used entry in the cache.
    keys = [k for k in reg_dict]
    time_min = 0
    if len(keys) != 0:
        key_min = keys[0]
        time_min = reg_timer[key_min]
        for k in keys:
            if reg_timer[k] < time_min:
                key_min = k
                time_min = reg_timer[key_min]
    # Count this use.
    if p not in reg_timer:
        reg_timer[p] = 1
    else:
        reg_timer[p] += 1
    # If the cache is full, evict the least-used entry.
    if len(keys) > 12:
        del reg_dict[key_min]
    # Keep the new entry if it is used more than the evicted minimum, or if there is still room.
    if reg_timer[p] > time_min or len(keys) < 12:
        reg_dict[p] = v
# Fetch a reusable tensor from the cache.
def get_reg(p):
    reg_timer[p] += 1
    return reg_dict[p]
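# Typical usage pattern (as in diff_attention below): build the key with ident(),
# reuse the cached tensor when present, otherwise create it once and register it.
#   p = ident(['f', i, i, block_size, q_len, k_len, False])
#   if un_reg(p):
#       mask = ...  # build the tensor
#       reg(p, mask)
#   else:
#       mask = get_reg(p)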
# Differential attention computation.
def diff_attention(query, q_mask, key, value, \
                   absolute_affine, relative_affine, diff_affine, \
                   talking_before_softmax, talking_after_softmax, \
                   self_attention_block_size, cross_attention_block_size, \
                   mask_future, is_cross_attention):
    # Cross attention uses its own block size.
    if is_cross_attention:
        block_size = cross_attention_block_size
    else:
        block_size = self_attention_block_size
    # Per-head key dimension (half of the last dim, since Q/K hold two halves), used to scale the dot products.
    query_dim = query.size(-1) // 2
    # Reshape q_mask up front so it broadcasts cleanly.
    # query:  [batch, head, query_len, emb_dim]
    # q_mask: [batch, query_len]
    # q_mask: [batch, query_len] -> [batch, 1, query_len]
    # q_mask: [batch, 1, query_len] -> [batch, head, query_len]
    q_mask = q_mask.unsqueeze(1).expand(*(query.size()[:-1]))
    query_p = query[..., query_dim:]
    query_n = query[..., :query_dim]
    key_p = key[..., query_dim:]
    key_n = key[..., :query_dim]
    # Decide whether to compute attention in blocks.
    if block_size == 0:
        # No blocking.
        # Compute the scores.
        scores_p = torch.matmul(query_p, key_p.transpose(-1, -2)) / math.sqrt(query_dim)
        scores_n = torch.matmul(query_n, key_n.transpose(-1, -2)) / math.sqrt(query_dim)
        # Optionally add relative position information.
        if relative_affine is not None:
            p = ident(['r', 0, 0, 0, query.size(-2), key.size(-2), is_cross_attention])
            if un_reg(p):
                rela_dist = get_relative_dist(0, 0, 0, query.size(-2), key.size(-2), is_cross_attention)
                # Broadcasting directly is more efficient.
                rela_dist = torch.from_numpy(rela_dist).detach().to(query.device)
                reg(p, rela_dist)
            else:
                rela_dist = get_reg(p)
            dist_decay = rela_dist.mul(relative_affine(1.0)).add(1.0).reciprocal()
            scores_p = scores_p.mul(dist_decay)
            scores_n = scores_n.mul(dist_decay)
        # Optionally add absolute position information.
        if absolute_affine is not None:
            p = ident(['a', 0, 0, 0, query.size(-2), key.size(-2), is_cross_attention])
            if un_reg(p):
                abs_mask = get_absolute_mask(0, 0, 0, query.size(-2), key.size(-2), is_cross_attention)
                # mask: [query_len, key_len] -> [batch, head, query_len, key_len]
                abs_mask = torch.from_numpy(abs_mask).unsqueeze_(0).unsqueeze_(0).detach().to(query.device)
                reg(p, abs_mask)
            else:
                abs_mask = get_reg(p)
            abs_mask = abs_mask.expand(*(scores_p.size()))
            value_to_sub = absolute_affine(1.0)
            scores_p = torch.where(abs_mask == 0, scores_p - value_to_sub, scores_p)
            scores_n = torch.where(abs_mask == 0, scores_n - value_to_sub, scores_n)
        # Apply talking heads before masking, which is more numerically stable.
        if talking_before_softmax is not None:
            scores_p = talking_before_softmax(scores_p.transpose(-1, -3)).transpose(-1, -3)
            scores_n = talking_before_softmax(scores_n.transpose(-1, -3)).transpose(-1, -3)
        # Hide future information if required.
        if mask_future == True:
            p = ident(['f', 0, 0, 0, query.size(-2), key.size(-2), is_cross_attention])
            if un_reg(p):
                # Build the mask that hides future information.
                # mask: [query_len, key_len] -> [batch, head, query_len, key_len]
                std_mask = get_std_mask(0, 0, 0, query.size(-2), key.size(-2), is_cross_attention)
                std_mask = torch.from_numpy(std_mask).unsqueeze_(0).unsqueeze_(0).detach().to(query.device)
                reg(p, std_mask)
            else:
                std_mask = get_reg(p)
            std_mask = std_mask.expand(*(scores_p.size()))
            # q_mask: [batch, head, query_len] -> [batch, head, query_len, key_len]
            std_mask = q_mask.unsqueeze_(-1).expand(*(std_mask.size())) & std_mask
            scores_p.masked_fill_(std_mask == 0.0, -1e3)
            scores_n.masked_fill_(std_mask == 0.0, -1e3)
        # Compute the differential attention weights.
        p_attn = F.softmax(scores_p, dim=-1) - diff_affine(F.softmax(scores_n, dim=-1))
        # Apply talking heads to the weights.
        if talking_after_softmax is not None:
            p_attn = talking_after_softmax(p_attn.transpose(-1, -3)).transpose(-1, -3)
        # Weighted sum of the values.
        ret = torch.matmul(p_attn, value)
    else:
        # With blocking we need a buffer to hold the final result.
        ret = torch.zeros_like(query_p)
        # Blocked loop over query blocks.
        for i in range(0, query.size(-2), block_size):
            # Slice out the current block.
            query_block_p = query_p[..., i:i + block_size, :]
            query_block_n = query_n[..., i:i + block_size, :]
            q_mask_block = q_mask[..., i:i + block_size]
            if is_cross_attention:
                key_block_p = key_p
                key_block_n = key_n
                value_block = value
            else:
                key_block_p = key_p[..., max(0, i - block_size):i + block_size * 2, :]
                key_block_n = key_n[..., max(0, i - block_size):i + block_size * 2, :]
                value_block = value[..., max(0, i - block_size):i + block_size * 2, :]
            # Compute the scores.
            scores_p = torch.matmul(query_block_p, key_block_p.transpose(-1, -2)) / math.sqrt(query_dim)
            scores_n = torch.matmul(query_block_n, key_block_n.transpose(-1, -2)) / math.sqrt(query_dim)
            # Optionally add relative position information.
            if relative_affine is not None:
                p = ident(['r', i, i, block_size, query.size(-2), key.size(-2), is_cross_attention])
                if un_reg(p):
                    rela_dist = get_relative_dist(i, i, block_size, query.size(-2), key.size(-2), is_cross_attention)
                    rela_dist = torch.from_numpy(rela_dist).detach().to(query.device)
                    reg(p, rela_dist)
                else:
                    rela_dist = get_reg(p)
                # dist_decay = 1.0 / (1 + rela_dist * relative_affine(1.0))
                dist_decay = rela_dist.mul(relative_affine(1.0)).add(1.0).reciprocal()
                scores_p = scores_p.mul(dist_decay)
                scores_n = scores_n.mul(dist_decay)
            # Optionally add absolute position information.
            if absolute_affine is not None:
                p = ident(['a', i, i, block_size, query.size(-2), key.size(-2), is_cross_attention])
                if un_reg(p):
                    abs_mask = get_absolute_mask(i, i, block_size, query.size(-2), key.size(-2), is_cross_attention)
                    abs_mask = torch.from_numpy(abs_mask).unsqueeze_(0).unsqueeze_(0).detach().to(query.device)
                    reg(p, abs_mask)
                else:
                    abs_mask = get_reg(p)
                abs_mask = abs_mask.expand(*(scores_p.size()))
                value_to_sub = absolute_affine(1.0)
                scores_p = torch.where(abs_mask == 0, scores_p - value_to_sub, scores_p)
                scores_n = torch.where(abs_mask == 0, scores_n - value_to_sub, scores_n)
            # Apply talking heads before masking, which is more numerically stable.
            if talking_before_softmax is not None:
                scores_p = talking_before_softmax(scores_p.transpose(-1, -3)).transpose(-1, -3)
                scores_n = talking_before_softmax(scores_n.transpose(-1, -3)).transpose(-1, -3)
            # Hide future information if required.
            if mask_future == True:
                p = ident(['f', i, i, block_size, query.size(-2), key.size(-2), is_cross_attention])
                if un_reg(p):
                    # Build the mask that hides future information; extra dims are added because the operation is batched.
                    std_mask = get_std_mask(i, i, block_size, query.size(-2), key.size(-2), is_cross_attention)
                    std_mask = torch.from_numpy(std_mask).unsqueeze_(0).unsqueeze_(0).detach().to(query.device)
                    reg(p, std_mask)
                else:
                    std_mask = get_reg(p)
                std_mask = std_mask.expand(*(scores_p.size()))
                std_mask = q_mask_block.unsqueeze(-1).expand(*(std_mask.size())) & std_mask
                scores_p.masked_fill_(std_mask == 0.0, -1e3)
                scores_n.masked_fill_(std_mask == 0.0, -1e3)
            # Compute the differential attention weights.
            p_attn = F.softmax(scores_p, dim=-1) - diff_affine(F.softmax(scores_n, dim=-1))
            # Apply talking heads to the weights.
            if talking_after_softmax is not None:
                p_attn = talking_after_softmax(p_attn.transpose(-1, -3)).transpose(-1, -3)
            # Write the block's weighted sum into the output buffer.
            ret[..., i:i + block_size, :] = torch.matmul(p_attn, value_block)
    return ret
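# Rough summary of the computation above, per head (talking-head mixing omitted):
#   A_p = softmax(mask(decay * Q_p K_p^T / sqrt(d)))
#   A_n = softmax(mask(decay * Q_n K_n^T / sqrt(d)))
#   out = (A_p - diff_affine(A_n)) @ V
# where Q_p/Q_n and K_p/K_n are the two halves of the projected queries/keys,
# decay = 1 / (1 + dist * relative_affine(1.0)) is the optional relative-position
# decay, and mask(...) subtracts absolute_affine(1.0) on past positions and fills
# future positions with -1e3 when mask_future is set.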
# Multi-head differential attention.
class DiffMultiHeadAttention(nn.Module):
    def __init__(self, embedding_dim, key_dim, head_number, position_information_type, enable_affine, enable_talking_head, \
                 self_attention_block_size, cross_attention_block_size, dropout_rate):
        super(DiffMultiHeadAttention, self).__init__()
        self.embedding_dim = embedding_dim
        self.key_dim = key_dim
        self.head_number = head_number
        self.position_information_type = position_information_type
        self.enable_talking_head = enable_talking_head
        self.self_attention_block_size = self_attention_block_size
        self.cross_attention_block_size = cross_attention_block_size
        self.dropout_layer = nn.Dropout(p=dropout_rate)
        self.enable_affine = enable_affine
        self.query_w = nn.Linear(embedding_dim, key_dim * head_number * 2, bias=False)
        self.key_w = nn.Linear(embedding_dim, key_dim * head_number * 2, bias=False)
        self.value_w = nn.Linear(embedding_dim, key_dim * head_number, bias=False)
        self.out_w = nn.Linear(key_dim * head_number, embedding_dim, bias=False)
        # diff_affine is always created: diff_attention applies it to the second softmax unconditionally.
        self.diff_affine = Affine(1.0)
        if enable_affine == True:
            self.query_a = Affine(1.0)
            self.key_a = Affine(1.0)
            self.value_a = Affine(1.0)
            self.out_a = Affine(1.0)
        if enable_talking_head == True:
            self.talking_before_softmax = nn.Linear(head_number, head_number, bias=False)
            self.talking_after_softmax = nn.Linear(head_number, head_number, bias=False)
        else:
            self.talking_before_softmax = None
            self.talking_after_softmax = None
        if position_information_type == "mask":
            self.absolute_affine = Affine(1.0, grad_factor=1.0)
            self.relative_affine = Affine(0.1, grad_factor=1.0)
        else:
            self.absolute_affine = None
            self.relative_affine = None
    def forward(self, query, q_mask, key_value, mask_future, is_cross_attention):
        # Linear projections give the actual Q, K, V.
        query = self.query_w(query)
        key = self.key_w(key_value)
        value = self.value_w(key_value)
        # Optional affine transforms to speed up training.
        if self.enable_affine == True:
            query = self.query_a(query)
            key = self.key_a(key)
            value = self.value_a(value)
        # Split the embeddings into heads for multi-head attention.
        batch_size = query.size(0)
        query = query.view(batch_size, -1, self.head_number, self.key_dim * 2).transpose(1, 2)
        key = key.view(batch_size, -1, self.head_number, self.key_dim * 2).transpose(1, 2)
        value = value.view(batch_size, -1, self.head_number, self.key_dim).transpose(1, 2)
        # Blocked multi-head differential attention.
        out = diff_attention(query, q_mask, key, value, \
                             self.absolute_affine, self.relative_affine, self.diff_affine, \
                             self.talking_before_softmax, self.talking_after_softmax, \
                             self_attention_block_size=self.self_attention_block_size, \
                             cross_attention_block_size=self.cross_attention_block_size, \
                             mask_future=mask_future, is_cross_attention=is_cross_attention)
        # Merge the heads back together.
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.head_number * self.key_dim)
        if self.enable_affine:
            return self.dropout_layer(self.out_a(self.out_w(out)))
        else:
            return self.dropout_layer(self.out_w(out))
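# Minimal smoke-test sketch (illustrative only: the hyperparameters below are
# arbitrary, and it assumes Affine from Affine.py returns a tensor when called
# with a float, as the code above already relies on).
if __name__ == "__main__":
    batch, seq_len, emb_dim = 2, 16, 64
    attn = DiffMultiHeadAttention(embedding_dim=emb_dim, key_dim=16, head_number=4,
                                  position_information_type="mask", enable_affine=True,
                                  enable_talking_head=True, self_attention_block_size=8,
                                  cross_attention_block_size=0, dropout_rate=0.1)
    x = torch.randn(batch, seq_len, emb_dim)
    q_mask = torch.ones(batch, seq_len, dtype=torch.bool)  # no padding positions
    out = attn(x, q_mask, x, mask_future=True, is_cross_attention=False)
    print(out.shape)  # expected: torch.Size([2, 16, 64])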