# own_gpt/model_optimized.py
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import sentencepiece as spm
#------------------------------- Memory-efficient single-head self-attention -------------------------------------------------#
class MemoryEfficientSelfAttention(nn.Module):
    def __init__(self, d_model, dropout_rate=0.2):
        super().__init__()
        self.d_model = d_model
        self.keys = nn.Linear(d_model, d_model, bias=False)
        self.queries = nn.Linear(d_model, d_model, bias=False)
        self.values = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout_rate)
        # Causal mask is built lazily on the first forward pass and rebuilt if T or the device changes
        self.register_buffer("mask", None)

    def forward(self, X):
        B, T, C = X.shape
        K = self.keys(X)     # (B, T, C)
        Q = self.queries(X)  # (B, T, C)
        V = self.values(X)   # (B, T, C)
        # Scaled dot product
        scaled_dot = (Q @ K.transpose(-2, -1)) / math.sqrt(C)  # (B, T, T)
        # Dynamic mask: keep only the lower triangle
        if self.mask is None or self.mask.shape[1] != T or self.mask.device != X.device:
            self.mask = torch.tril(torch.ones(T, T, device=X.device)).unsqueeze(0)  # (1, T, T)
        scaled_dot = scaled_dot.masked_fill(self.mask == 0, float('-inf'))
        attn = F.softmax(scaled_dot, dim=-1)
        attn = self.dropout(attn)
        out = attn @ V  # (B, T, C)
        # Free intermediate tensors early
        del scaled_dot, attn
        return out
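# Usage sketch (illustrative shapes, not part of the original upload): a single head maps
# (B, T, d_model) -> (B, T, d_model) under the causal mask, e.g.
#   attn = MemoryEfficientSelfAttention(64)
#   out = attn(torch.randn(2, 16, 64))  # -> (2, 16, 64)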
#------------------------------- Memory-efficient multi-head attention -------------------------------------------------#
class MemoryEfficientMultiHeadAttention(nn.Module):
    def __init__(self, d_model, h, dropout_rate=0.2):
        super().__init__()
        assert d_model % h == 0, "d_model must be divisible by the number of heads"
        self.h = h
        self.d_k = d_model // h
        self.heads = nn.ModuleList([MemoryEfficientSelfAttention(self.d_k, dropout_rate) for _ in range(h)])
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, X):
        B, T, C = X.shape
        # Split the channel dimension into h heads
        X_split = X.view(B, T, self.h, self.d_k).transpose(1, 2)  # (B, h, T, d_k)
        # Run the heads one at a time to keep peak memory low
        out_heads = [head(X_split[:, i]) for i, head in enumerate(self.heads)]  # each (B, T, d_k)
        # Concatenate back to the original dimension
        out = torch.cat(out_heads, dim=-1)  # (B, T, C)
        out = self.proj(out)
        out = self.dropout(out)
        # Free intermediate tensors early
        del X_split, out_heads
        return out
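# Usage sketch (illustrative): each head sees d_k = d_model // h channels, so d_model must divide evenly, e.g.
#   mha = MemoryEfficientMultiHeadAttention(d_model=64, h=4)
#   out = mha(torch.randn(2, 16, 64))  # -> (2, 16, 64)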
#------------------------------- Feed-forward network -------------------------------------------------#
class FeedForward(nn.Module):
    def __init__(self, d_model, dropout_rate=0.2):
        super().__init__()
        # Standard Transformer MLP: expand to 4*d_model, then project back
        self.net = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout_rate)
        )

    def forward(self, X):
        return self.net(X)
#------------------------------- Transformer block -------------------------------------------------#
class MemoryEfficientBlock(nn.Module):
    def __init__(self, d_model, h, dropout_rate=0.2):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.attn = MemoryEfficientMultiHeadAttention(d_model, h, dropout_rate)
        self.ff = FeedForward(d_model, dropout_rate)

    def forward(self, X):
        # Residual connections around the attention and MLP sub-layers
        attn_out = self.attn(self.ln1(X))
        X = X + attn_out
        ff_out = self.ff(self.ln2(X))
        X = X + ff_out
        # Free intermediate tensors early
        del attn_out, ff_out
        return X
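# Design note: the block normalizes before each sub-layer and adds the residual afterwards
# (the pre-LayerNorm arrangement), which tends to train more stably than post-LN at depth.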
#------------------------------- Memory-optimized BigramLM -------------------------------------------------#
class MemoryOptimizedBigramLM(nn.Module):
    def __init__(self, vocab_size, d_model=512, max_seq_len=1024, h=8, Nx=6, dropout_rate=0.2):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        # Token embedding
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        # Learnable positional embedding
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout_rate)
        # Transformer blocks followed by a final LayerNorm
        blocks = [MemoryEfficientBlock(d_model, h, dropout_rate) for _ in range(Nx)]
        blocks.append(nn.LayerNorm(d_model))
        self.blocks = nn.Sequential(*blocks)
        # Output layer
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding(idx)  # (B, T, C)
        pos_idx = torch.arange(T, device=idx.device).unsqueeze(0)  # (1, T)
        pos_emb = self.pos_embedding(pos_idx)  # (1, T, C)
        X = self.dropout(tok_emb + pos_emb)
        X = self.blocks(X)
        logits = self.lm_head(X)  # (B, T, vocab_size)
        loss = None
        if targets is not None:
            # Ignore padding tokens (ID 1) in the loss calculation
            loss_mask = (targets != 1).view(-1)
            if loss_mask.sum() > 0:  # Only compute the loss if there are non-padding tokens
                loss = F.cross_entropy(logits.view(B * T, -1), targets.view(-1), ignore_index=1)
            else:
                loss = torch.tensor(0.0, device=idx.device)
        # Free intermediate tensors early
        del tok_emb, pos_emb, X
        return logits, loss
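    # Usage sketch (illustrative numbers; padding ID 1 assumed, as in the loss above):
    #   model = MemoryOptimizedBigramLM(vocab_size=8000, d_model=128, max_seq_len=64, h=4, Nx=2)
    #   x = torch.randint(2, 8000, (2, 16))      # (B, T) token IDs
    #   logits, loss = model(x, targets=x)       # logits: (2, 16, 8000), loss: scalar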
    # ---------------- Generation ----------------#
    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.8, top_k=None, eos_token_id=None, repetition_penalty=1.2):
        B = idx.size(0)
        generated = idx.clone()
        if eos_token_id is None:
            # Fall back to the bundled SentencePiece tokenizer to look up the <END> token
            sp = spm.SentencePieceProcessor()
            sp.load("tokenizer.model")
            eos_token_id = sp.piece_to_id("<END>")
        finished = torch.zeros(B, dtype=torch.bool, device=generated.device)
        for _ in range(max_new_tokens):
            # Crop the context so it never exceeds the positional-embedding capacity
            idx_cond = generated[:, -self.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            # Repetition penalty: discourage tokens that already appear in the sequence
            if repetition_penalty != 1.0:
                for b in range(B):
                    for t in torch.unique(generated[b]):
                        if logits[b, t] < 0:
                            logits[b, t] *= repetition_penalty
                        else:
                            logits[b, t] /= repetition_penalty
            # Temperature scaling
            logits = logits / temperature
            # Top-k filtering
            if top_k is not None:
                k = min(top_k, logits.size(-1))
                topv, topi = torch.topk(logits, k, dim=-1)
                mask = torch.full_like(logits, float('-inf'))
                mask.scatter_(1, topi, topv)
                logits = mask
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # Mark sequences that just emitted <END> and keep them pinned to <END>
            finished |= (idx_next.squeeze(1) == eos_token_id)
            idx_next[finished, :] = eos_token_id
            generated = torch.cat((generated, idx_next), dim=1)
            if finished.all():
                break
        return generated
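# Minimal smoke test (a sketch with made-up hyperparameters; the real vocabulary size and
# <END> token ID come from the bundled SentencePiece tokenizer, which is not loaded here).
if __name__ == "__main__":
    torch.manual_seed(0)
    vocab_size = 1000
    model = MemoryOptimizedBigramLM(vocab_size, d_model=128, max_seq_len=64, h=4, Nx=2)
    model.eval()
    # Random prompt of 8 tokens per sequence (IDs >= 2 to avoid the assumed padding ID 1)
    prompt = torch.randint(2, vocab_size, (2, 8))
    logits, loss = model(prompt, targets=prompt)
    print("logits:", tuple(logits.shape), "loss:", float(loss))
    # Pass eos_token_id explicitly so the sketch does not depend on tokenizer.model
    out = model.generate(prompt, max_new_tokens=20, top_k=50, eos_token_id=vocab_size - 1)
    print("generated:", tuple(out.shape))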