File size: 10,680 Bytes
e10f35b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 | import torch
from torch import nn
import torch.nn.functional as F
from dataclasses import dataclass
@dataclass()
class TransformerConfig:
    """Hyper-parameters shared by every module of the Transformer language model."""

    vocab_size: int  # number of token ids (size of the token embedding and of the output head)
    block_size: int  # maximum context length: rows of the position embedding and causal mask
    n_embed: int  # embedding / hidden width of every block
    n_heads: int  # number of attention heads; must divide n_embed evenly
    n_layers: int  # number of stacked Transformer blocks
    dropout: float = 0.0  # dropout probability used by attention and feed-forward layers
    bias: bool = True  # whether the Q/K/V projection and feed-forward Linear layers use a bias
class MultiHeadAttention(nn.Module):
    """
    Multi-head causal self-attention.

    Input and output are both (B, T, n_embed). A fused projection ``c_attn``
    produces queries, keys and values, which are split into ``n_heads`` heads
    of size ``n_embed // n_heads``.
    """

    def __init__(self, config: "TransformerConfig"):
        super().__init__()
        assert config.n_embed % config.n_heads == 0
        self.config = config
        self.head_size = config.n_embed // config.n_heads
        # Fused Q/K/V projection: n_embed -> 3 * n_embed.
        self.c_attn = nn.Linear(config.n_embed, config.n_embed * 3, bias=config.bias)
        # NOTE(review): unlike c_attn, c_proj always has a bias and ignores
        # config.bias; kept as-is so existing state_dicts still load.
        self.c_proj = nn.Linear(config.n_embed, config.n_embed)
        self.attention_dropout = nn.Dropout(config.dropout)
        self.residue_dropout = nn.Dropout(config.dropout)
        # Use PyTorch's fused scaled_dot_product_attention when this torch build provides it.
        self.flash_att = hasattr(F, 'scaled_dot_product_attention')
        if not self.flash_att:
            print('警告:未使用Flash Attention, 这可能减慢模型计算速度。')
        # Lower-triangular (1, 1, block_size, block_size) matrix: the causal mask for the non-fused path.
        self.register_buffer('mask', torch.tril(torch.ones(config.block_size, config.block_size).view(1, 1, config.block_size, config.block_size)))

    def _split_qkv(self, x):
        """Project x (B, T, n_embed) and return q, k, v, each shaped (B, n_heads, T, head_size)."""
        B, T, _ = x.shape
        q, k, v = self.c_attn(x).split(self.config.n_embed, dim=2)

        def to_heads(t):
            return t.view(B, T, self.config.n_heads, self.head_size).transpose(1, 2)

        return to_heads(q), to_heads(k), to_heads(v)

    def _merge_heads(self, out, B, T, C):
        """Merge (B, n_heads, T, head_size) back to (B, T, C), then apply c_proj and residual dropout."""
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.residue_dropout(self.c_proj(out))

    def forward(self, x):
        """
        Causal self-attention over x of shape (B, T, n_embed).

        In the non-fused path T must not exceed block_size (the mask buffer
        has only block_size rows/columns).
        """
        B, T, C = x.shape
        q, k, v = self._split_qkv(x)
        if self.flash_att:
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.config.dropout if self.training else 0.0, is_causal=True)
        else:
            weight = q @ k.transpose(-2, -1) * self.head_size ** (-0.5)
            weight = weight.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
            weight = self.attention_dropout(F.softmax(weight, dim=-1))
            out = weight @ v
        return self._merge_heads(out, B, T, C)

    def forward_with_cache(self, x, past_key_values=None, use_cache=False):
        """
        Causal self-attention that can continue from previously cached keys/values.

        Args:
            x: hidden states of the new tokens, (B, T, n_embed).
            past_key_values: optional (k, v) from an earlier call, each shaped
                (B, n_heads, S_past, head_size). The new tokens are treated as
                the positions immediately following the cached ones.
            use_cache: when True, also return the concatenated (k, v).

        Returns:
            (out, kv_cache): out is (B, T, n_embed); kv_cache is (k, v) covering
            S_past + T positions, or None when use_cache is False.
        """
        B, T, C = x.shape
        q, k, v = self._split_qkv(x)
        if past_key_values is not None:
            past_k, past_v = past_key_values
            k = torch.cat([past_k, k], dim=2).contiguous()
            v = torch.cat([past_v, v], dim=2).contiguous()
        S = k.size(2)  # total key length = cached positions + new positions
        weight = q @ k.transpose(-2, -1) * self.head_size ** (-0.5)
        if T > 1:
            # New query i sits at absolute position S - T + i and may only see keys
            # 0..S - T + i. This holds with or without a cache, so a multi-token chunk
            # appended to a cache cannot attend to its own later tokens. A single
            # query (T == 1) is the newest position and sees every key, so no mask.
            causal = torch.ones(T, S, dtype=torch.bool, device=x.device).tril(diagonal=S - T)
            weight = weight.masked_fill(~causal, float('-inf'))
        weight = self.attention_dropout(F.softmax(weight, dim=-1))
        out = self._merge_heads(weight @ v, B, T, C)
        kv_cache = (k, v) if use_cache else None
        return (out, kv_cache)
class FeedForward(nn.Module):
    """
    Position-wise feed-forward network applied independently to every token:
    Linear(n_embed -> 4 * n_embed) -> GELU -> Linear(4 * n_embed -> n_embed) -> Dropout.
    """

    def __init__(self, config: TransformerConfig) -> None:
        super().__init__()
        self.config = config
        hidden_width = 4 * config.n_embed  # inner layer is four times the embedding width
        # Submodules are created in this order so parameter registration/initialisation order is stable.
        self.layer_1 = nn.Linear(config.n_embed, hidden_width, bias=config.bias)
        self.gelu = nn.GELU()
        self.layer_2 = nn.Linear(hidden_width, config.n_embed, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Expand, apply GELU, project back to n_embed, then apply dropout."""
        expanded = self.gelu(self.layer_1(x))
        return self.dropout(self.layer_2(expanded))
class TransformerBlock(nn.Module):
    """
    Pre-LayerNorm Transformer block:
    x = x + attention(ln1(x)); x = x + feed_forward(ln2(x)).
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        # Registration order (mha, fwd, ln1, ln2) is kept so parameter order is unchanged.
        self.mha = MultiHeadAttention(config)
        self.fwd = FeedForward(config)
        self.ln1 = nn.LayerNorm(config.n_embed)
        self.ln2 = nn.LayerNorm(config.n_embed)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers, each with a residual connection."""
        attended = x + self.mha(self.ln1(x))
        return attended + self.fwd(self.ln2(attended))

    def forward_with_cache(self, x, kv_cache=None, use_cache=False):
        """
        Same as forward, but the attention sub-layer continues from kv_cache.

        Returns:
            (hidden, present): the block output and the attention layer's
            (k, v) cache, or None for the cache when use_cache is False.
        """
        attn_out, present = self.mha.forward_with_cache(self.ln1(x), kv_cache, use_cache)
        hidden = x + attn_out
        hidden = hidden + self.fwd(self.ln2(hidden))
        return (hidden, present)
class TransformerLanguageModel(nn.Module):
    """
    Decoder-only Transformer language model.

    Token embeddings plus learned absolute position embeddings are passed
    through ``n_layers`` TransformerBlocks, a final LayerNorm and a linear
    head that produces next-token logits over ``vocab_size``.
    """

    def __init__(self, config: "TransformerConfig"):
        super().__init__()
        self.config = config
        # Token embedding table: vocab_size -> n_embed.
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embed)
        # Learned absolute position embeddings; only positions 0..block_size-1 exist.
        self.position_embedding_table = nn.Embedding(config.block_size, config.n_embed)
        # Transformer body: a stack of TransformerBlocks.
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        # Final LayerNorm applied before the head.
        self.ln_f = nn.LayerNorm(config.n_embed)
        # Language-model head predicting the next token.
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size)

    @staticmethod
    def _cache_len(kv_cache):
        """Number of positions stored in a per-layer (k, v) cache; 0 for None or an empty list."""
        # k is (B, n_heads, S, head_size), so dim 2 is the cached sequence length.
        return kv_cache[0][0].size(2) if kv_cache else 0

    def _head(self, x, targets):
        """
        Apply ln_f and lm_head to x (B, T, n_embed).

        Returns (logits, loss). Without targets, logits is (B, T, vocab_size)
        and loss is None. With targets (B, T), logits is flattened to
        (B*T, vocab_size) and loss is the mean cross-entropy.
        """
        B, T, _ = x.shape
        logits = self.lm_head(self.ln_f(x))
        if targets is None:
            return logits, None
        logits = logits.view(B * T, self.config.vocab_size)
        loss = F.cross_entropy(logits, targets.reshape(B * T))
        return logits, loss

    def forward(self, idx, targets=None, device=None):
        """
        Run the model on token ids idx of shape (B, T), with T <= block_size.

        Args:
            idx: token ids, (B, T).
            targets: optional next-token ids, (B, T); enables the loss.
            device: device for the position indices; defaults to idx.device so
                the model works on CPU and on any GPU.

        Returns:
            (logits, loss); see _head for the shapes.
        """
        _, T = idx.shape
        if device is None:
            device = idx.device
        token_embed = self.token_embedding_table(idx)
        pos_embed = self.position_embedding_table(torch.arange(T, device=device))
        x = token_embed + pos_embed
        for block in self.blocks:
            x = block(x)
        return self._head(x, targets)

    def forward_with_cache(self, idx, kv_cache=None, use_cache=False, targets=None, device=None):
        """
        Like forward, but continues from a per-layer key/value cache.

        Args:
            idx: token ids of the new positions, (B, T).
            kv_cache: optional list with one (k, v) pair per layer; idx is
                placed at positions past_len .. past_len + T - 1, where past_len
                is the cached length. past_len + T must not exceed block_size.
            use_cache: when True, return the updated per-layer caches.
            targets: optional next-token ids, (B, T).
            device: device for the position indices; defaults to idx.device.

        Returns:
            (logits, new_cache, loss): new_cache is a list of per-layer (k, v)
            pairs, or None when use_cache is False.
        """
        _, T = idx.shape
        if device is None:
            device = idx.device
        past_len = self._cache_len(kv_cache)
        token_embed = self.token_embedding_table(idx)
        pos_embed = self.position_embedding_table(torch.arange(past_len, past_len + T, device=device))
        x = token_embed + pos_embed
        new_cache = [] if use_cache else None
        for i, block in enumerate(self.blocks):
            layer_past = kv_cache[i] if kv_cache else None
            x, layer_cache = block.forward_with_cache(x, layer_past, use_cache)
            if use_cache:
                new_cache.append(layer_cache)
        logits, loss = self._head(x, targets)
        return logits, new_cache, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=300, temperature=1.0, top_k=0, kv_cache=None, use_cache=False):
        """
        Autoregressively sample max_new_tokens tokens and append them to idx.

        Args:
            idx: (B, T) prompt token ids.
            max_new_tokens: number of tokens to sample.
            temperature: the last position's logits are divided by this before softmax.
            top_k: if > 0, only the top_k largest logits are kept for sampling.
            kv_cache: optional per-layer (k, v) list. On the first step all of
                idx is fed as the positions following the cached ones.
            use_cache: reuse keys/values between steps instead of re-running
                the whole context window every step.

        Returns:
            (B, T + max_new_tokens) tensor of token ids.

        Position embeddings are learned and absolute, so cached keys/values
        depend on the positions they were computed at, and the cache cannot
        simply slide. When cached + new positions would exceed block_size, the
        cache is dropped (including any caller-supplied prefix) and the last
        block_size tokens of idx are re-encoded from position 0. That is the
        same context the uncached path uses.
        """
        block_size = self.config.block_size
        cache = kv_cache
        first_step = True
        for _ in range(max_new_tokens):
            if not cache:
                idx_cond = idx[:, -block_size:]
            else:
                # After the first step only the newest token has not been encoded yet.
                idx_cond = idx if first_step else idx[:, -1:]
                if self._cache_len(cache) + idx_cond.size(1) > block_size:
                    cache = None
                    idx_cond = idx[:, -block_size:]
            logits, cache, _ = self.forward_with_cache(idx_cond, cache, use_cache)
            first_step = False
            logits = logits[:, -1, :] / temperature
            if top_k > 0:
                kth_largest, _ = torch.topk(logits, min(top_k, logits.shape[-1]))
                logits[logits < kth_largest[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)  # probability of each token
            next_idx = torch.multinomial(probs, num_samples=1)  # sample
            idx = torch.cat((idx, next_idx), dim=1)  # (B, T+1)
        return idx

    @torch.no_grad()
    def generate_normal(self, idx, max_new_tokens=300):
        """Sample max_new_tokens tokens without a cache, re-running the last block_size tokens each step."""
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            last_logits = logits[:, -1, :]
            probs = F.softmax(last_logits, dim=-1)  # probability of each token
            next_idx = torch.multinomial(probs, num_samples=1)  # sample
            idx = torch.cat((idx, next_idx), dim=1)  # (B, T+1)
        return idx