from dataclasses import dataclass

import torch
from torch import nn
import torch.nn.functional as F


@dataclass
class TransformerConfig:
    """Hyper-parameters shared by every module of the language model."""

    vocab_size: int  # number of rows in the token embedding / size of the logits
    block_size: int  # number of rows in the position embedding (max context length)
    n_embed: int  # embedding width; must be divisible by n_heads
    n_heads: int  # number of attention heads
    n_layers: int  # number of stacked TransformerBlock modules
    dropout: float = 0.0  # dropout probability used by attention and feed-forward
    bias: bool = True  # whether c_attn and the feed-forward linears carry a bias


class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with an optional key/value cache."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        assert config.n_embed % config.n_heads == 0
        self.config = config
        self.head_size = config.n_embed // config.n_heads
        # One fused projection produces q, k and v.
        self.c_attn = nn.Linear(config.n_embed, config.n_embed * 3, bias=config.bias)
        # NOTE: c_proj always has a bias (it does not follow config.bias); kept
        # as-is so that existing checkpoints keep loading.
        self.c_proj = nn.Linear(config.n_embed, config.n_embed)
        self.attention_dropout = nn.Dropout(config.dropout)
        self.residue_dropout = nn.Dropout(config.dropout)
        # Use the fused kernel when this torch build provides it.
        self.flash_att = hasattr(F, 'scaled_dot_product_attention')
        if not self.flash_att:
            print('警告:未使用Flash Attention, 这可能减慢模型计算速度。')
        # Lower-triangular matrix used as the causal mask by the manual path.
        self.register_buffer(
            'mask',
            torch.tril(torch.ones(config.block_size, config.block_size)).view(
                1, 1, config.block_size, config.block_size
            ),
        )

    def _split_heads(self, x):
        """Project x of shape (B, T, C) to q, k, v, each (B, n_heads, T, head_size)."""
        B, T, _ = x.shape
        q, k, v = self.c_attn(x).split(self.config.n_embed, dim=2)
        shape = (B, T, self.config.n_heads, self.head_size)
        return (
            q.view(shape).transpose(1, 2),
            k.view(shape).transpose(1, 2),
            v.view(shape).transpose(1, 2),
        )

    def forward(self, x):
        """Apply causal self-attention to x of shape (B, T, C); returns (B, T, C)."""
        B, T, C = x.shape
        q, k, v = self._split_heads(x)
        if self.flash_att:
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.config.dropout if self.training else 0.0,
                is_causal=True,
            )
        else:
            scale = self.head_size ** (-0.5)
            weight = q @ k.transpose(-2, -1) * scale
            weight = weight.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
            weight = F.softmax(weight, dim=-1)
            weight = self.attention_dropout(weight)
            out = weight @ v
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.residue_dropout(self.c_proj(out))

    def forward_with_cache(self, x, past_key_values=None, use_cache=False):
        """Attention over x plus previously cached keys/values.

        Args:
            x: new inputs of shape (B, T, C).
            past_key_values: optional ``(k, v)`` pair, each of shape
                (B, n_heads, T_past, head_size), from earlier positions.
            use_cache: when true, return the concatenated ``(k, v)``.

        Returns:
            ``(out, kv_cache)`` where ``out`` is (B, T, C) and ``kv_cache`` is
            the updated ``(k, v)`` pair or ``None`` if ``use_cache`` is false.
        """
        B, T, C = x.shape
        q, k, v = self._split_heads(x)

        past_len = 0
        if past_key_values is not None:
            past_k, past_v = past_key_values
            past_len = past_k.size(2)
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        scale = self.head_size ** (-0.5)
        weight = q @ k.transpose(-2, -1) * scale
        if T > 1:
            # Query i sits at absolute position past_len + i and may only see
            # keys at positions <= its own. This must also hold when a cache is
            # present, otherwise a multi-token chunk attends to its own future.
            total_len = past_len + T
            q_pos = torch.arange(past_len, total_len, device=x.device).unsqueeze(1)
            k_pos = torch.arange(total_len, device=x.device).unsqueeze(0)
            weight = weight.masked_fill(k_pos > q_pos, float('-inf'))
        weight = F.softmax(weight, dim=-1)
        weight = self.attention_dropout(weight)
        out = weight @ v
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        out = self.residue_dropout(self.c_proj(out))

        kv_cache = (k, v) if use_cache else None
        return (out, kv_cache)


class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout."""

    def __init__(self, config: TransformerConfig) -> None:
        super().__init__()
        self.config = config
        self.layer_1 = nn.Linear(config.n_embed, 4 * config.n_embed, bias=config.bias)
        self.gelu = nn.GELU()
        self.layer_2 = nn.Linear(4 * config.n_embed, config.n_embed, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return the feed-forward transform of x (last dimension n_embed)."""
        return self.dropout(self.layer_2(self.gelu(self.layer_1(x))))


class TransformerBlock(nn.Module):
    """Pre-LayerNorm Transformer block: attention and feed-forward, each residual."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.mha = MultiHeadAttention(config)
        self.fwd = FeedForward(config)
        self.ln1 = nn.LayerNorm(config.n_embed)
        self.ln2 = nn.LayerNorm(config.n_embed)

    def forward(self, x):
        """Run the block on x of shape (B, T, C) without any cache."""
        x = x + self.mha(self.ln1(x))
        x = x + self.fwd(self.ln2(x))
        return x

    def forward_with_cache(self, x, kv_cache=None, use_cache=False):
        """Run the block with an optional ``(k, v)`` cache for this layer.

        Returns ``(x, new_kv_cache)``; the cache is ``None`` unless ``use_cache``.
        """
        attn_out, new_cache = self.mha.forward_with_cache(self.ln1(x), kv_cache, use_cache)
        x = x + attn_out
        x = x + self.fwd(self.ln2(x))
        return (x, new_cache)


class TransformerLanguageModel(nn.Module):
    """Decoder-only Transformer language model with learned position embeddings."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        # Token embedding table.
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embed)
        # Learned absolute position embedding table.
        self.position_embedding_table = nn.Embedding(config.block_size, config.n_embed)
        # Main body: a stack of Transformer blocks.
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        # Final LayerNorm.
        self.ln_f = nn.LayerNorm(config.n_embed)
        # Language-model head predicting the next token.
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size)

    def _loss(self, logits, targets, B, T):
        """Flatten logits/targets and return ``(flat_logits, cross_entropy)``."""
        logits = logits.view(B * T, self.config.vocab_size)
        targets = targets.view(B * T)
        return logits, F.cross_entropy(logits, targets)

    @staticmethod
    def _crop_cache(kv_cache, keep):
        """Keep only the last ``keep`` cached positions (dim 2) of every layer.

        Returns ``None`` when ``keep <= 0`` because ``t[..., -0:, :]`` would
        keep everything instead of nothing.
        """
        if keep <= 0:
            return None
        return [(k[:, :, -keep:, :], v[:, :, -keep:, :]) for k, v in kv_cache]

    def forward(self, idx, targets=None, device=None):
        """Compute logits (and the loss when ``targets`` is given).

        Args:
            idx: token ids of shape (B, T), with T <= block_size.
            targets: optional token ids of shape (B, T).
            device: device for the position indices; defaults to ``idx.device``.

        Returns:
            ``(logits, loss)``. Without targets, logits are (B, T, vocab_size)
            and loss is ``None``; with targets, logits are flattened to
            (B*T, vocab_size).
        """
        if device is None:
            # Follow the input instead of assuming a specific GPU.
            device = idx.device
        B, T = idx.shape
        token_embed = self.token_embedding_table(idx)
        pos_embed = self.position_embedding_table(torch.arange(T, device=device))
        x = token_embed + pos_embed
        for block in self.blocks:
            x = block(x)
        logits = self.lm_head(self.ln_f(x))

        loss = None
        if targets is not None:
            logits, loss = self._loss(logits, targets, B, T)
        return logits, loss

    def forward_with_cache(self, idx, kv_cache=None, use_cache=False, targets=None, device=None):
        """Forward pass that can consume and/or produce a per-layer KV cache.

        Args:
            idx: new token ids of shape (B, T) that follow the cached positions.
            kv_cache: optional list with one ``(k, v)`` pair per layer.
            use_cache: when true, return the updated cache.
            targets: optional token ids of shape (B, T) for the loss.
            device: device for the position indices; defaults to ``idx.device``.

        Returns:
            ``(logits, new_cache, loss)``; ``new_cache`` is ``None`` unless
            ``use_cache`` is true.
        """
        if device is None:
            device = idx.device
        B, T = idx.shape
        token_embed = self.token_embedding_table(idx)
        # New tokens continue the position count where the cache ends.
        past_len = 0 if kv_cache is None else kv_cache[0][0].size(2)
        positions = torch.arange(past_len, past_len + T, device=device)
        x = token_embed + self.position_embedding_table(positions)

        new_cache = [] if use_cache else None
        for i, block in enumerate(self.blocks):
            layer_past = None if kv_cache is None else kv_cache[i]
            x, layer_cache = block.forward_with_cache(x, layer_past, use_cache)
            if use_cache:
                new_cache.append(layer_cache)

        logits = self.lm_head(self.ln_f(x))

        loss = None
        if targets is not None:
            logits, loss = self._loss(logits, targets, B, T)
        return logits, new_cache, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=300, temperature=1.0, top_k=0, kv_cache=None, use_cache=False):
        """Sample ``max_new_tokens`` tokens and append them to ``idx``.

        Args:
            idx: token ids of shape (B, T). When ``kv_cache`` is supplied these
                are fed as the tokens that follow the cached positions.
            max_new_tokens: number of tokens to sample.
            temperature: divisor applied to the logits before softmax.
            top_k: when > 0, sample only among the ``top_k`` largest logits.
            kv_cache: optional per-layer list of ``(k, v)`` pairs.
            use_cache: when true, reuse the cache between sampling steps.

        Returns:
            Tensor of shape (B, T + max_new_tokens).
        """
        block_size = self.config.block_size
        curr_kv_cache = kv_cache
        first_step = True
        for _ in range(max_new_tokens):
            if curr_kv_cache is None:
                # No cache: recompute over the most recent window.
                idx_cond = idx[:, -block_size:]
            elif not first_step:
                # Feed only the newest token; keep at most block_size - 1 cached
                # positions so its position index stays inside the table.
                idx_cond = idx[:, -1:]
                curr_kv_cache = self._crop_cache(curr_kv_cache, block_size - 1)
            else:
                # Caller-supplied cache: cached + new positions must fit in
                # block_size.
                prompt_len = idx.shape[1]
                if prompt_len >= block_size:
                    idx_cond = idx[:, -block_size:]
                    curr_kv_cache = None
                else:
                    idx_cond = idx
                    curr_kv_cache = self._crop_cache(curr_kv_cache, block_size - prompt_len)

            logits, curr_kv_cache, _ = self.forward_with_cache(
                idx_cond, curr_kv_cache, use_cache, device=idx.device
            )
            logits = logits[:, -1, :] / temperature
            if top_k > 0:
                logits_top, _ = torch.topk(logits, min(top_k, logits.shape[-1]))
                logits[logits < logits_top[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)  # per-token probabilities
            next_idx = torch.multinomial(probs, num_samples=1)  # sample
            idx = torch.cat((idx, next_idx), dim=1)  # (B, T+1)
            first_step = False
        return idx

    @torch.no_grad()
    def generate_normal(self, idx, max_new_tokens=300):
        """Sample ``max_new_tokens`` tokens without any cache; returns (B, T + n)."""
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            probs = F.softmax(logits[:, -1, :], dim=-1)  # per-token probabilities
            next_idx = torch.multinomial(probs, num_samples=1)  # sample
            idx = torch.cat((idx, next_idx), dim=1)  # (B, T+1)
        return idx