# Provenance: uploaded by hugfaceguy0001 ("upload model and train/infer codes"),
# commit e10f35b (verified), file size 10.7 kB.
import torch
from torch import nn
import torch.nn.functional as F
from dataclasses import dataclass
@dataclass
class TransformerConfig:
    """Hyperparameters shared by every module of the Transformer language model.

    Attributes:
        vocab_size: Size of the token embedding table and of the LM head output.
        block_size: Size of the position embedding table and of the causal mask.
        n_embed: Embedding (hidden) width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        n_layers: Number of stacked Transformer blocks.
        dropout: Dropout probability used by the attention and feed-forward
            modules.
        bias: Whether the attention QKV projection and the feed-forward linear
            layers carry a bias term.
    """

    # Required sizes (no defaults).
    vocab_size: int
    block_size: int
    n_embed: int
    n_heads: int
    n_layers: int
    # Optional regularisation / parameterisation switches.
    dropout: float = 0.0
    bias: bool = True
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with an optional key/value cache.

    ``forward`` is the plain (training / full-sequence) path and uses the
    fused ``scaled_dot_product_attention`` kernel when PyTorch provides it.
    ``forward_with_cache`` is the incremental path: keys and values of
    earlier positions can be passed in and are concatenated in front of the
    newly computed ones.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        assert config.n_embed % config.n_heads == 0
        self.config = config
        self.head_size = config.n_embed // config.n_heads
        # One linear layer produces queries, keys and values side by side.
        self.c_attn = nn.Linear(config.n_embed, config.n_embed * 3, bias=config.bias)
        # NOTE: the output projection always has a bias (config.bias is not
        # applied here); left unchanged so existing checkpoints keep loading.
        self.c_proj = nn.Linear(config.n_embed, config.n_embed)
        self.attention_dropout = nn.Dropout(config.dropout)
        self.residue_dropout = nn.Dropout(config.dropout)
        # Whether the fused (flash) attention kernel is available.
        self.flash_att = hasattr(F, 'scaled_dot_product_attention')
        if not self.flash_att:
            print('警告:未使用Flash Attention, 这可能减慢模型计算速度。')
        # Lower-triangular matrix used as the causal mask on the non-fused path.
        self.register_buffer('mask', torch.tril(torch.ones(config.block_size, config.block_size).view(1, 1, config.block_size, config.block_size)))

    def _project_qkv(self, x):
        """Project ``x`` (B, T, C) to q, k, v, each (B, n_heads, T, head_size)."""
        B, T, _ = x.shape
        q, k, v = self.c_attn(x).split(self.config.n_embed, dim=2)
        per_head = (B, T, self.config.n_heads, self.head_size)
        return tuple(t.view(per_head).transpose(1, 2) for t in (q, k, v))

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (B, T, C).

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.shape
        q, k, v = self._project_qkv(x)
        if self.flash_att:
            # The fused kernel applies dropout itself, so it is switched off
            # explicitly outside training.
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.config.dropout if self.training else 0.0, is_causal=True)
        else:
            scale = self.head_size ** (-0.5)
            weight = q @ k.transpose(-2, -1) * scale
            weight = weight.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
            weight = F.softmax(weight, dim=-1)
            weight = self.attention_dropout(weight)
            out = weight @ v
        # Merge the heads back into the channel dimension.
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        out = self.residue_dropout(self.c_proj(out))
        return out

    def forward_with_cache(self, x, past_key_values=None, use_cache=False):
        """Apply causal self-attention to new tokens, reusing cached keys/values.

        Args:
            x: New-token activations of shape (B, T, C).
            past_key_values: Optional ``(k, v)`` pair for the preceding
                positions, each of shape (B, n_heads, T_past, head_size).
            use_cache: When true, also return the concatenated ``(k, v)``.

        Returns:
            ``(out, kv_cache)`` where ``out`` has shape (B, T, C) and
            ``kv_cache`` is ``(k, v)`` covering past and new positions, or
            ``None`` when ``use_cache`` is false.
        """
        B, T, C = x.shape
        q, k, v = self._project_qkv(x)
        past_len = 0
        if past_key_values is not None:
            past_k, past_v = past_key_values
            past_len = past_k.size(2)
            k = torch.cat([past_k, k], dim=2).contiguous()
            v = torch.cat([past_v, v], dim=2).contiguous()
        scale = self.head_size ** (-0.5)
        weight = q @ k.transpose(-2, -1) * scale
        if T > 1:
            # Query i sits at absolute position past_len + i and may only see
            # keys at positions <= its own. This must also hold when a cache
            # is present, otherwise new tokens would attend to later new
            # tokens. A single new token is always the last position and
            # needs no mask.
            total_len = past_len + T
            key_pos = torch.arange(total_len, device=x.device)
            query_pos = torch.arange(past_len, total_len, device=x.device)
            future = key_pos[None, :] > query_pos[:, None]
            weight = weight.masked_fill(future, float('-inf'))
        weight = F.softmax(weight, dim=-1)
        weight = self.attention_dropout(weight)
        out = weight @ v
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        out = self.residue_dropout(self.c_proj(out))
        kv_cache = (k, v) if use_cache else None
        return (out, kv_cache)
class FeedForward(nn.Module):
    """Position-wise feed-forward network.

    Two linear layers with a GELU in between (the hidden width is four times
    ``n_embed``), followed by dropout.
    """

    def __init__(self, config: TransformerConfig) -> None:
        super().__init__()
        self.config = config
        width = config.n_embed
        hidden_width = 4 * width
        # Expand, apply the non-linearity, then project back to the model width.
        self.layer_1 = nn.Linear(width, hidden_width, bias=config.bias)
        self.gelu = nn.GELU()
        self.layer_2 = nn.Linear(hidden_width, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return the transformed tensor; the last dimension stays ``n_embed``."""
        hidden = self.gelu(self.layer_1(x))
        return self.dropout(self.layer_2(hidden))
class TransformerBlock(nn.Module):
    """One Transformer block.

    LayerNorm is applied before each sub-layer (attention, then feed-forward)
    and each sub-layer's output is added back to its input.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.mha = MultiHeadAttention(config)
        self.fwd = FeedForward(config)
        self.ln1 = nn.LayerNorm(config.n_embed)
        self.ln2 = nn.LayerNorm(config.n_embed)

    def forward(self, x):
        """Run the block on ``x`` without any key/value cache."""
        after_attention = x + self.mha(self.ln1(x))
        return after_attention + self.fwd(self.ln2(after_attention))

    def forward_with_cache(self, x, kv_cache=None, use_cache=False):
        """Run the block, forwarding the cache arguments to the attention module.

        Returns:
            ``(hidden, cache)`` where ``cache`` is whatever the attention
            module's ``forward_with_cache`` returned as its second element.
        """
        attn_out, layer_cache = self.mha.forward_with_cache(self.ln1(x), kv_cache, use_cache)
        hidden = x + attn_out
        hidden = hidden + self.fwd(self.ln2(hidden))
        return (hidden, layer_cache)
class TransformerLanguageModel(nn.Module):
    """Decoder-only Transformer language model with learned absolute positions.

    Token and position embeddings are summed, passed through ``n_layers``
    Transformer blocks and a final LayerNorm, and projected to vocabulary
    logits by a linear head.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        # Token embedding table.
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embed)
        # Learned position embedding table (one row per position < block_size).
        self.position_embedding_table = nn.Embedding(config.block_size, config.n_embed)
        # Model body: a stack of Transformer blocks.
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        # Final LayerNorm.
        self.ln_f = nn.LayerNorm(config.n_embed)
        # Language-model head predicting the next token.
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size)

    def _position_embeddings(self, start, length, device):
        """Return the embeddings of positions ``start .. start + length - 1``."""
        end = start + length
        if end > self.config.block_size:
            # Fail with a readable message instead of an opaque embedding
            # lookup error (a device-side assert on CUDA).
            raise IndexError(
                f"position {end - 1} is out of range for block_size {self.config.block_size}"
            )
        return self.position_embedding_table(torch.arange(start, end, device=device))

    def _logits_and_loss(self, x, targets):
        """Apply the final norm and head; add the cross-entropy loss if asked.

        When ``targets`` is given the returned logits are flattened to
        (B*T, vocab_size), matching what the loss was computed on.
        """
        B, T = x.shape[0], x.shape[1]
        logits = self.lm_head(self.ln_f(x))
        if targets is None:
            return logits, None
        logits = logits.view(B * T, self.config.vocab_size)
        loss = F.cross_entropy(logits, targets.view(B * T))
        return logits, loss

    @staticmethod
    def _trim_cache(kv_cache, keep):
        """Keep only the last ``keep`` cached positions; ``None`` if none remain."""
        if keep <= 0:
            return None
        return [(k[:, :, -keep:, :], v[:, :, -keep:, :]) for k, v in kv_cache]

    def forward(self, idx, targets=None, device=None):
        """Compute logits (and optionally the loss) for token ids ``idx``.

        Args:
            idx: Integer tensor of shape (B, T) with T <= block_size.
            targets: Optional integer tensor of shape (B, T).
            device: Device for the position indices; defaults to
                ``idx.device`` so the model runs wherever its input lives.

        Returns:
            ``(logits, loss)``. ``logits`` is (B, T, vocab_size) without
            targets and (B*T, vocab_size) with targets; ``loss`` is ``None``
            without targets.

        Raises:
            IndexError: If T exceeds ``block_size``.
        """
        B, T = idx.shape  # batch size, context length
        if device is None:
            device = idx.device
        x = self.token_embedding_table(idx) + self._position_embeddings(0, T, device)
        for block in self.blocks:
            x = block(x)
        return self._logits_and_loss(x, targets)

    def forward_with_cache(self, idx, kv_cache=None, use_cache=False, targets=None, device=None):
        """Like ``forward`` but able to consume and produce a key/value cache.

        Args:
            idx: Integer tensor of shape (B, T) holding the new tokens.
            kv_cache: Optional per-layer list of ``(k, v)`` pairs for the
                preceding positions; new tokens are placed right after them.
            use_cache: When true, return the updated per-layer cache.
            targets: Optional integer tensor of shape (B, T).
            device: Device for the position indices; defaults to ``idx.device``.

        Returns:
            ``(logits, new_cache, loss)``; ``new_cache`` is ``None`` unless
            ``use_cache`` is true.

        Raises:
            IndexError: If cached plus new positions exceed ``block_size``.
        """
        B, T = idx.shape  # batch size, context length
        if device is None:
            device = idx.device
        # New tokens continue after the cached positions.
        past_len = 0 if kv_cache is None else kv_cache[0][0].size(2)
        x = self.token_embedding_table(idx) + self._position_embeddings(past_len, T, device)
        new_cache = [] if use_cache else None
        for i, block in enumerate(self.blocks):
            layer_past = None if kv_cache is None else kv_cache[i]
            x, layer_cache = block.forward_with_cache(x, layer_past, use_cache)
            if use_cache:
                new_cache.append(layer_cache)
        logits, loss = self._logits_and_loss(x, targets)
        return logits, new_cache, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=300, temperature=1.0, top_k=0, kv_cache=None, use_cache=False):
        """Sample ``max_new_tokens`` tokens after ``idx``.

        Args:
            idx: Integer tensor of shape (B, T). When ``kv_cache`` is given,
                ``idx`` is treated as the tokens that follow the cached ones.
            max_new_tokens: Number of tokens to sample.
            temperature: Logits are divided by this value before sampling.
            top_k: If positive, sample only among the ``top_k`` largest logits.
            kv_cache: Optional per-layer ``(k, v)`` cache to start from.
            use_cache: When true, later steps feed only the newest token and
                reuse cached keys/values.

        Returns:
            ``idx`` with the sampled tokens appended, shape (B, T + max_new_tokens).

        Note:
            Once the context is full the oldest cached positions are dropped
            without recomputing the remaining ones, so cached keys/values keep
            the position information of the step at which they were computed.
        """
        block_size = self.config.block_size
        curr_kv_cache = kv_cache
        first_step = True
        for _ in range(max_new_tokens):
            if curr_kv_cache is None:
                # No cache: re-encode the most recent context window.
                idx_cond = idx[:, -block_size:]
            elif first_step:
                # Caller-supplied cache: feed the whole prompt (at most one
                # window) and keep only as much cache as still fits before it.
                idx_cond = idx[:, -block_size:]
                curr_kv_cache = self._trim_cache(curr_kv_cache, block_size - idx_cond.shape[1])
            else:
                # Incremental step: one new token. Leave room for it so that
                # its position index stays below block_size.
                idx_cond = idx[:, -1:]
                curr_kv_cache = self._trim_cache(curr_kv_cache, block_size - 1)
            logits, curr_kv_cache, _ = self.forward_with_cache(idx_cond, curr_kv_cache, use_cache)
            logits = logits[:, -1, :] / temperature
            if top_k > 0:
                logits_top, _ = torch.topk(logits, min(top_k, logits.shape[-1]))
                # Everything below the k-th largest logit gets zero probability.
                logits[logits < logits_top[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)  # probability of each token
            next_idx = torch.multinomial(probs, num_samples=1)  # sample
            idx = torch.cat((idx, next_idx), dim=1)  # (B, T+1)
            first_step = False
        return idx

    @torch.no_grad()
    def generate_normal(self, idx, max_new_tokens=300):
        """Sample ``max_new_tokens`` tokens without any cache or logit shaping.

        Every step re-runs ``forward`` on the last ``block_size`` tokens.

        Returns:
            ``idx`` with the sampled tokens appended.
        """
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            last_logits = logits[:, -1, :]
            probs = F.softmax(last_logits, dim=-1)  # probability of each token
            next_idx = torch.multinomial(probs, num_samples=1)  # sample
            idx = torch.cat((idx, next_idx), dim=1)  # (B, T+1)
        return idx