import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import sentencepiece as spm
#------------------------------- Memory-efficient single-head self-attention -------------------------------------------------#
class MemoryEfficientSelfAttention(nn.Module):
    def __init__(self, d_model, dropout_rate=0.2):
        super().__init__()
        self.d_model = d_model
        self.keys = nn.Linear(d_model, d_model, bias=False)
        self.queries = nn.Linear(d_model, d_model, bias=False)
        self.values = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout_rate)
        # Causal mask is built lazily in forward() and cached; keep it out of the state dict
        self.register_buffer("mask", None, persistent=False)

    def forward(self, X):
        B, T, C = X.shape
        K = self.keys(X)     # (B, T, C)
        Q = self.queries(X)  # (B, T, C)
        V = self.values(X)   # (B, T, C)
        # Scaled dot-product attention scores
        scaled_dot = (Q @ K.transpose(-2, -1)) / math.sqrt(C)  # (B, T, T)
        # Lazily (re)build the causal mask: keep only the lower triangle
        if self.mask is None or self.mask.shape[1] != T:
            self.mask = torch.tril(torch.ones(T, T, device=X.device)).unsqueeze(0)  # (1, T, T)
        scaled_dot = scaled_dot.masked_fill(self.mask == 0, float('-inf'))
        attn = F.softmax(scaled_dot, dim=-1)
        attn = self.dropout(attn)
        out = attn @ V  # (B, T, C)
        # Drop references to intermediate tensors promptly
        del scaled_dot, attn
        return out
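# Usage sketch for the block above (illustrative only; the sizes are assumptions, not part of the original module):
#   sa = MemoryEfficientSelfAttention(d_model=64)
#   y = sa(torch.randn(2, 16, 64))   # causally masked output, shape (2, 16, 64)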
#------------------------------- Memory-efficient multi-head attention -------------------------------------------------#
class MemoryEfficientMultiHeadAttention(nn.Module):
    def __init__(self, d_model, h, dropout_rate=0.2):
        super().__init__()
        self.h = h
        self.d_k = d_model // h
        self.heads = nn.ModuleList([MemoryEfficientSelfAttention(self.d_k, dropout_rate) for _ in range(h)])
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, X):
        B, T, C = X.shape
        # Split the channel dimension into h heads
        X_split = X.view(B, T, self.h, self.d_k).transpose(1, 2)  # (B, h, T, d_k)
        out_heads = []
        for i, head in enumerate(self.heads):
            out_heads.append(head(X_split[:, i]))  # (B, T, d_k)
        # Concatenate the heads back to the original dimension
        out = torch.cat(out_heads, dim=-1)  # (B, T, C)
        out = self.proj(out)
        out = self.dropout(out)
        # Drop references to intermediate tensors promptly
        del X_split, out_heads
        return out
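# Usage sketch (illustrative; d_model=64 with h=4 heads is an assumed configuration):
#   mha = MemoryEfficientMultiHeadAttention(d_model=64, h=4)
#   y = mha(torch.randn(2, 16, 64))   # -> (2, 16, 64); each head attends over d_k = 64 // 4 = 16 channels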
#------------------------------- Feed-forward network -------------------------------------------------#
class FeedForward(nn.Module):
    def __init__(self, d_model, dropout_rate=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, 4*d_model),
            nn.ReLU(),
            nn.Linear(4*d_model, d_model),
            nn.Dropout(dropout_rate)
        )

    def forward(self, X):
        return self.net(X)
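# Usage sketch (illustrative): the position-wise MLP expands to 4*d_model and projects back, so shapes are preserved:
#   ff = FeedForward(d_model=64)
#   y = ff(torch.randn(2, 16, 64))   # -> (2, 16, 64)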
#------------------------------- Transformer block -------------------------------------------------#
class MemoryEfficientBlock(nn.Module):
    def __init__(self, d_model, h, dropout_rate=0.2):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.attn = MemoryEfficientMultiHeadAttention(d_model, h, dropout_rate)
        self.ff = FeedForward(d_model, dropout_rate)

    def forward(self, X):
        # Pre-LayerNorm residual connections around attention and the feed-forward network
        attn_out = self.attn(self.ln1(X))
        X = X + attn_out
        ff_out = self.ff(self.ln2(X))
        X = X + ff_out
        # Drop references to intermediate tensors promptly
        del attn_out, ff_out
        return X
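# Usage sketch (illustrative): the block is shape-preserving, so it can be stacked with nn.Sequential:
#   blk = MemoryEfficientBlock(d_model=64, h=4)
#   y = blk(torch.randn(2, 16, 64))   # -> (2, 16, 64)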
#------------------------------- Memory-optimized BigramLM -------------------------------------------------#
class MemoryOptimizedBigramLM(nn.Module):
    def __init__(self, vocab_size, d_model=512, max_seq_len=1024, h=8, Nx=6, dropout_rate=0.2):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        # token embedding
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        # learnable positional embedding
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout_rate)
        # Transformer blocks followed by a final LayerNorm
        blocks = [MemoryEfficientBlock(d_model, h, dropout_rate) for _ in range(Nx)]
        blocks.append(nn.LayerNorm(d_model))
        self.blocks = nn.Sequential(*blocks)
        # output layer
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding(idx)  # (B, T, C)
        pos_idx = torch.arange(T, device=idx.device).unsqueeze(0)  # (1, T)
        pos_emb = self.pos_embedding(pos_idx)
        X = self.dropout(tok_emb + pos_emb)
        X = self.blocks(X)
        logits = self.lm_head(X)
        loss = None
        if targets is not None:
            # Ignore padding tokens (ID 1) in the loss calculation
            loss_mask = (targets != 1).view(-1)
            if loss_mask.sum() > 0:  # only compute a loss if there is at least one non-padding target
                loss = F.cross_entropy(logits.view(B*T, -1), targets.view(-1), ignore_index=1)
            else:
                loss = torch.tensor(0.0, device=idx.device)
        # Drop references to intermediate tensors promptly
        del tok_emb, pos_emb, X
        return logits, loss
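    # Loss sketch (illustrative; pad id 1 follows the convention used in forward above):
    #   logits, loss = model(x, y)   # x, y: (B, T) int64
    #   positions where y == 1 are skipped by cross_entropy via ignore_index=1,
    #   so padding never contributes to the loss.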
    # ---------------- Generation ----------------#
    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.8, top_k=None, eos_token_id=None, repetition_penalty=1.2):
        B = idx.size(0)
        generated = idx.clone()
        if eos_token_id is None:
            sp = spm.SentencePieceProcessor()
            sp.load("tokenizer.model")
            eos_token_id = sp.piece_to_id("<END>")
        finished = torch.zeros(B, dtype=torch.bool, device=generated.device)
        for _ in range(max_new_tokens):
            # Crop the context so the positional embedding is never indexed past max_seq_len
            idx_cond = generated[:, -self.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            # Repetition penalty: push down the scores of tokens that have already appeared
            if repetition_penalty != 1.0:
                for b in range(B):
                    for t in torch.unique(generated[b]):
                        if logits[b, t] < 0:
                            logits[b, t] *= repetition_penalty
                        else:
                            logits[b, t] /= repetition_penalty
            # temperature
            logits = logits / temperature
            # top-k filtering
            if top_k is not None:
                topv, topi = torch.topk(logits, top_k, dim=-1)
                mask = torch.full_like(logits, float('-inf'))
                mask.scatter_(1, topi, topv)
                logits = mask
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            # Once a sequence has emitted <END>, keep padding it with <END>
            finished |= (idx_next.squeeze(1) == eos_token_id)
            idx_next[finished, :] = eos_token_id
            generated = torch.cat((generated, idx_next), dim=1)
            if finished.all():
                break
        return generated
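#------------------------------- Smoke test -------------------------------------------------#
# A minimal, self-contained sketch of how the model is meant to be driven. The vocabulary size,
# hyper-parameters, and random data below are assumptions for illustration only; generate() is
# given an explicit eos_token_id so the example does not depend on a local tokenizer.model file.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = MemoryOptimizedBigramLM(vocab_size=1000, d_model=64, max_seq_len=128, h=4, Nx=2)
    model.eval()                                # disable dropout for the demo
    x = torch.randint(0, 1000, (2, 16))         # (B, T) dummy token ids
    y = torch.randint(0, 1000, (2, 16))         # dummy targets; id 1 would be treated as padding
    logits, loss = model(x, y)
    print(logits.shape, loss.item())            # torch.Size([2, 16, 1000]) and a scalar loss
    out = model.generate(x, max_new_tokens=8, top_k=50, eos_token_id=2)
    print(out.shape)                            # (2, <=24): the prompt plus up to 8 new tokens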