| import torch
|
| from torch import nn
|
| import torch.nn.functional as F
|
| from dataclasses import dataclass
|
|
|
@dataclass
class TransformerConfig:
    """Hyper-parameters shared by every module of the Transformer model.

    The first five fields are required; ``dropout`` and ``bias`` are optional.
    """

    # Number of rows of the token embedding table / width of the output logits.
    vocab_size: int
    # Number of rows of the position embedding table and side of the causal mask.
    block_size: int
    # Width of the embeddings and of every hidden state.
    n_embed: int
    # Number of attention heads; the attention module asserts it divides n_embed.
    n_heads: int
    # Number of stacked Transformer blocks.
    n_layers: int
    # Probability handed to every nn.Dropout the model creates.
    dropout: float = 0.0
    # Forwarded as ``bias=`` to the linear layers that read it.
    bias: bool = True
|
|
|
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with an optional key/value cache.

    ``forward`` processes a whole sequence at once. ``forward_with_cache``
    additionally accepts the keys/values of earlier positions so that only the
    new positions have to be projected.
    """

    def __init__(self, config: "TransformerConfig"):
        super().__init__()
        assert config.n_embed % config.n_heads == 0
        self.config = config
        self.head_size = config.n_embed // config.n_heads
        # One projection yields queries, keys and values in a single matmul.
        self.c_attn = nn.Linear(config.n_embed, config.n_embed * 3, bias=config.bias)
        # NOTE(review): unlike c_attn, this projection does not read
        # config.bias and always uses nn.Linear's default. Left as is so the
        # set of parameters (and thus saved checkpoints) does not change.
        self.c_proj = nn.Linear(config.n_embed, config.n_embed)
        self.attention_dropout = nn.Dropout(config.dropout)
        self.residue_dropout = nn.Dropout(config.dropout)

        # Use the fused kernel when this torch build provides it.
        self.flash_att = hasattr(F, 'scaled_dot_product_attention')
        if not self.flash_att:
            print('警告:未使用Flash Attention, 这可能减慢模型计算速度。')

        # Lower-triangular matrix of ones, shaped (1, 1, block_size, block_size).
        causal = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer('mask', causal.view(1, 1, config.block_size, config.block_size))

    def _split_heads(self, x):
        """Project ``x`` (B, T, C) to q, k, v, each (B, n_heads, T, head_size)."""
        B, T, _ = x.shape
        per_head = (B, T, self.config.n_heads, self.head_size)
        q, k, v = self.c_attn(x).split(self.config.n_embed, dim=2)
        return tuple(t.view(per_head).transpose(1, 2) for t in (q, k, v))

    def _causal_mask(self, n_query, n_key):
        """Return a boolean (1, 1, n_query, n_key) mask, True where attention is blocked.

        The queries are taken to be the last ``n_query`` of the ``n_key`` key
        positions, so query ``i`` may see keys ``0 .. n_key - n_query + i``.
        """
        if n_key <= self.mask.size(-1):
            return self.mask[:, :, n_key - n_query:n_key, :n_key] == 0
        # More keys than the registered buffer covers: build the mask on the fly.
        allowed = torch.ones(n_query, n_key, device=self.mask.device).tril(diagonal=n_key - n_query)
        return allowed.view(1, 1, n_query, n_key) == 0

    def _merge_heads(self, out, shape):
        """Concatenate the heads back to ``shape`` and apply the output projection."""
        out = out.transpose(1, 2).contiguous().view(shape)
        return self.residue_dropout(self.c_proj(out))

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (B, T, C); returns (B, T, C)."""
        B, T, C = x.shape
        q, k, v = self._split_heads(x)
        if self.flash_att:
            # The fused kernel applies dropout itself, so it must be told
            # explicitly to skip it outside of training.
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.config.dropout if self.training else 0.0,
                is_causal=True,
            )
        else:
            scale = self.head_size ** (-0.5)
            weight = q @ k.transpose(-2, -1) * scale
            weight = weight.masked_fill(self._causal_mask(T, T), float('-inf'))
            weight = F.softmax(weight, dim=-1)
            weight = self.attention_dropout(weight)
            out = weight @ v
        return self._merge_heads(out, (B, T, C))

    def forward_with_cache(self, x, past_key_values=None, use_cache=False):
        """Causal self-attention over ``x`` plus previously cached positions.

        Args:
            x: new positions, shape (B, T, C).
            past_key_values: optional ``(past_k, past_v)`` pair, each
                (B, n_heads, T_past, head_size); concatenated in front of the
                keys/values computed from ``x``.
            use_cache: when true, the concatenated keys/values are returned.

        Returns:
            ``(out, kv_cache)`` where ``out`` is (B, T, C) and ``kv_cache`` is
            ``(k, v)`` if ``use_cache`` else ``None``.
        """
        B, T, C = x.shape
        q, k, v = self._split_heads(x)
        if past_key_values is not None:
            past_k, past_v = past_key_values
            k = torch.cat([past_k, k], dim=2).contiguous()
            v = torch.cat([past_v, v], dim=2).contiguous()
        scale = self.head_size ** (-0.5)
        weight = q @ k.transpose(-2, -1) * scale
        # A single query is the newest position and may see every key. Several
        # queries need the causal mask even when a cache is present, otherwise
        # the earlier new positions would attend to the later ones.
        if T > 1:
            weight = weight.masked_fill(self._causal_mask(T, k.size(2)), float('-inf'))
        weight = F.softmax(weight, dim=-1)
        weight = self.attention_dropout(weight)
        out = self._merge_heads(weight @ v, (B, T, C))
        kv_cache = (k, v) if use_cache else None
        return (out, kv_cache)
|
|
|
class FeedForward(nn.Module):
    """Feed-forward sub-layer: Linear -> GELU -> Linear -> Dropout.

    The hidden layer is four times as wide as the model (``4 * n_embed``).
    """

    def __init__(self, config: "TransformerConfig") -> None:
        super().__init__()
        self.config = config
        model_width = config.n_embed
        hidden_width = 4 * model_width
        self.layer_1 = nn.Linear(model_width, hidden_width, bias=config.bias)
        self.gelu = nn.GELU()
        self.layer_2 = nn.Linear(hidden_width, model_width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Expand, activate, project back and apply dropout, in that order."""
        hidden = self.gelu(self.layer_1(x))
        return self.dropout(self.layer_2(hidden))
|
|
|
class TransformerBlock(nn.Module):
    """One Transformer block: attention and feed-forward, each with a residual.

    LayerNorm is applied to the input of each sub-layer, and the sub-layer
    output is added back onto the un-normalised stream.
    """

    def __init__(self, config: "TransformerConfig"):
        super().__init__()
        self.config = config
        self.mha = MultiHeadAttention(config)
        self.fwd = FeedForward(config)
        self.ln1 = nn.LayerNorm(config.n_embed)
        self.ln2 = nn.LayerNorm(config.n_embed)

    def forward(self, x):
        """Run the block on a full sequence and return the updated hidden states."""
        after_attention = x + self.mha(self.ln1(x))
        return after_attention + self.fwd(self.ln2(after_attention))

    def forward_with_cache(self, x, kv_cache=None, use_cache=False):
        """Run the block, passing ``kv_cache`` / ``use_cache`` to the attention.

        Returns ``(hidden_states, cache)`` where ``cache`` is whatever the
        attention module's ``forward_with_cache`` returned as its second item.
        """
        attention_out, layer_cache = self.mha.forward_with_cache(self.ln1(x), kv_cache, use_cache)
        hidden = x + attention_out
        hidden = hidden + self.fwd(self.ln2(hidden))
        return (hidden, layer_cache)
|
|
|
class TransformerLanguageModel(nn.Module):
    """Transformer language model with learned absolute position embeddings.

    Token and position embeddings are summed, passed through ``n_layers``
    Transformer blocks and a final LayerNorm, then projected to vocabulary
    logits by ``lm_head``.
    """

    def __init__(self, config: "TransformerConfig"):
        super().__init__()
        self.config = config
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embed)
        self.position_embedding_table = nn.Embedding(config.block_size, config.n_embed)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.ln_f = nn.LayerNorm(config.n_embed)
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size)

    def _embed(self, idx, start, device):
        """Sum token embeddings with position embeddings ``start .. start + T - 1``."""
        if device is None:
            # Follow the input instead of assuming a particular accelerator.
            device = idx.device
        positions = torch.arange(start, start + idx.shape[1], device=device)
        return self.token_embedding_table(idx) + self.position_embedding_table(positions)

    def _logits_and_loss(self, x, targets):
        """Project hidden states to logits and, if ``targets`` is given, the loss.

        With targets, the returned logits are flattened to (B * T, vocab_size),
        the same tensor that is fed to the cross-entropy.
        """
        B, T = x.shape[0], x.shape[1]
        logits = self.lm_head(self.ln_f(x))
        if targets is None:
            return logits, None
        logits = logits.view(B * T, self.config.vocab_size)
        loss = F.cross_entropy(logits, targets.view(B * T))
        return logits, loss

    @staticmethod
    def _trim_cache(kv_cache, keep):
        """Keep the last ``keep`` cached positions of every layer (``None`` if keep <= 0)."""
        if keep <= 0:
            # A "-0:" slice would keep everything, so drop the cache explicitly.
            return None
        return [(k[:, :, -keep:, :], v[:, :, -keep:, :]) for k, v in kv_cache]

    def forward(self, idx, targets=None, device=None):
        """Compute logits (and optionally the loss) for token ids ``idx`` (B, T).

        Args:
            idx: integer token ids, shape (B, T).
            targets: optional token ids of the same shape as ``idx``.
            device: device for the position indices; defaults to ``idx.device``.

        Returns:
            ``(logits, loss)``. Without targets, logits are (B, T, vocab_size)
            and loss is ``None``; with targets, logits are (B * T, vocab_size).
        """
        x = self._embed(idx, 0, device)
        for block in self.blocks:
            x = block(x)
        return self._logits_and_loss(x, targets)

    def forward_with_cache(self, idx, kv_cache=None, use_cache=False, targets=None, device=None):
        """Like ``forward`` but able to reuse and return per-layer key/value caches.

        Args:
            idx: integer token ids of the new positions, shape (B, T).
            kv_cache: optional list with one ``(k, v)`` pair per layer; the
                size of dim 2 of the first key is taken as the number of
                positions already processed.
            use_cache: when true, the updated per-layer caches are returned.
            targets: optional token ids of the same shape as ``idx``.
            device: device for the position indices; defaults to ``idx.device``.

        Returns:
            ``(logits, new_cache, loss)``; ``new_cache`` is ``None`` unless
            ``use_cache`` is true.
        """
        past_len = 0 if kv_cache is None else kv_cache[0][0].size(2)
        # New tokens continue the position numbering where the cache stopped.
        x = self._embed(idx, past_len, device)

        new_cache = []
        for i, block in enumerate(self.blocks):
            layer_past = None if kv_cache is None else kv_cache[i]
            x, layer_cache = block.forward_with_cache(x, layer_past, use_cache)
            new_cache.append(layer_cache)
        if not use_cache:
            new_cache = None

        logits, loss = self._logits_and_loss(x, targets)
        return logits, new_cache, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=300, temperature=1.0, top_k=0, kv_cache=None, use_cache=False):
        """Sample ``max_new_tokens`` tokens and return them appended to ``idx``.

        Args:
            idx: integer token ids, shape (B, T). When ``kv_cache`` is given,
                all of ``idx`` (up to ``block_size`` tokens) is fed on top of
                the cache in the first step.
            max_new_tokens: number of sampling steps.
            temperature: logits are divided by this value before the softmax.
            top_k: if positive, only the ``top_k`` largest logits stay finite.
            kv_cache: optional per-layer ``(k, v)`` list used in the first step.
            use_cache: when true, each step feeds only the newest token and
                reuses the cache returned by the previous step.

        The cache is trimmed so that cached plus new positions never exceed
        ``block_size``, which keeps the position indices inside the position
        embedding table. NOTE: trimmed entries were computed with the position
        embeddings of their original positions, so once trimming starts the
        cached path is no longer equivalent to re-encoding the last
        ``block_size`` tokens the way ``generate_normal`` does.
        """
        block_size = self.config.block_size
        curr_kv_cache = kv_cache
        for step in range(max_new_tokens):
            if curr_kv_cache is None:
                # No cache to build on: encode the most recent window.
                idx_cond = idx[:, -block_size:]
            elif step == 0:
                # Caller-supplied cache: feed the prompt on top of it, keeping
                # only as much of the cache as still fits in the window.
                idx_cond = idx[:, -block_size:]
                curr_kv_cache = self._trim_cache(curr_kv_cache, block_size - idx_cond.shape[1])
            else:
                # Steady state: one new token, so leave room for exactly one.
                idx_cond = idx[:, -1:]
                curr_kv_cache = self._trim_cache(curr_kv_cache, block_size - 1)
            logits, curr_kv_cache, _ = self.forward_with_cache(idx_cond, curr_kv_cache, use_cache)

            logits = logits[:, -1, :] / temperature
            if top_k > 0:
                logits_top, _ = torch.topk(logits, min(top_k, logits.shape[-1]))
                # Everything below the k-th largest logit gets zero probability.
                logits[logits < logits_top[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            next_idx = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_idx), dim=1)
        return idx

    @torch.no_grad()
    def generate_normal(self, idx, max_new_tokens=300):
        """Sample ``max_new_tokens`` tokens without any cache.

        Every step re-encodes the last ``block_size`` tokens and samples from
        the softmax of the final position's logits.
        """
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            probs = F.softmax(logits[:, -1, :], dim=-1)
            next_idx = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_idx), dim=1)
        return idx