import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import sentencepiece as spm
#------------------------------- Single-head self-attention -------------------------------------------------#
class SelfAttention(nn.Module):
    def __init__(self, d_model, dropout_rate=0.2):
        super().__init__()
        self.d_model = d_model
        self.keys = nn.Linear(d_model, d_model, bias=False)
        self.queries = nn.Linear(d_model, d_model, bias=False)
        self.values = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, X):
        B, T, C = X.shape
        K = self.keys(X)     # (B, T, C)
        Q = self.queries(X)  # (B, T, C)
        V = self.values(X)   # (B, T, C)
        # Scaled dot-product attention scores
        scaled_dot = (Q @ K.transpose(-2, -1)) / math.sqrt(C)  # (B, T, T)
        # Causal mask built for the current sequence length: keep only the lower triangle
        mask = torch.tril(torch.ones(T, T, device=X.device)).unsqueeze(0)  # (1, T, T)
        scaled_dot = scaled_dot.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scaled_dot, dim=-1)
        attn = self.dropout(attn)
        out = attn @ V
        return out
#------------------------------- Multi-head attention -------------------------------------------------#
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, h, dropout_rate=0.2):
        super().__init__()
        self.h = h
        self.d_k = d_model // h
        # Each head runs full self-attention on its own d_k-sized slice of the channels
        self.heads = nn.ModuleList([SelfAttention(self.d_k, dropout_rate) for _ in range(h)])
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, X):
        B, T, C = X.shape
        # Split the channels into h heads
        X_split = X.view(B, T, self.h, self.d_k).transpose(1, 2)  # (B, h, T, d_k)
        out_heads = []
        for i, head in enumerate(self.heads):
            out_heads.append(head(X_split[:, i]))  # (B, T, d_k)
        # Concatenate the heads back to the original dimension
        out = torch.cat(out_heads, dim=-1)  # (B, T, C)
        out = self.proj(out)
        out = self.dropout(out)
        return out
#------------------------------- Feed-forward network -------------------------------------------------#
class FeedForward(nn.Module):
    def __init__(self, d_model, dropout_rate=0.2):
        super().__init__()
        # Position-wise MLP with the usual 4x expansion
        self.net = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout_rate)
        )

    def forward(self, X):
        return self.net(X)
#------------------------------- Transformer block -------------------------------------------------#
class Block(nn.Module):
    def __init__(self, d_model, h, dropout_rate=0.2):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, h, dropout_rate)
        self.ff = FeedForward(d_model, dropout_rate)

    def forward(self, X):
        # Pre-norm residual connections
        X = X + self.attn(self.ln1(X))
        X = X + self.ff(self.ln2(X))
        return X
#------------------------------- BigramLM -------------------------------------------------#
class BigramLM(nn.Module):
    def __init__(self, vocab_size, d_model=512, max_seq_len=1024, h=8, Nx=6, dropout_rate=0.2):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        # token embedding
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        # learnable positional embedding
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout_rate)
        # Transformer blocks followed by a final LayerNorm
        blocks = [Block(d_model, h, dropout_rate) for _ in range(Nx)]
        blocks.append(nn.LayerNorm(d_model))
        self.blocks = nn.Sequential(*blocks)
        # output layer
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding(idx)  # (B, T, C)
        pos_idx = torch.arange(T, device=idx.device).unsqueeze(0)  # (1, T)
        pos_emb = self.pos_embedding(pos_idx)
        X = self.dropout(tok_emb + pos_emb)
        X = self.blocks(X)
        logits = self.lm_head(X)
        loss = None
        if targets is not None:
            # Ignore padding tokens (ID 1) in the loss; fall back to 0 when every target is
            # padding, since cross_entropy would otherwise return NaN
            loss_mask = (targets != 1).view(-1)
            if loss_mask.sum() > 0:
                loss = F.cross_entropy(logits.view(B * T, -1), targets.view(-1), ignore_index=1)
            else:
                loss = torch.tensor(0.0, device=idx.device)
        return logits, loss
    # ---------------- Generation ----------------#
    def generate(self, idx, max_new_tokens, temperature=0.8, top_k=None, eos_token_id=None, repetition_penalty=1.2):
        B = idx.size(0)
        generated = idx.clone()
        if eos_token_id is None:
            # Look up the <END> id from the SentencePiece model when the caller does not supply one
            sp = spm.SentencePieceProcessor()
            sp.load("tokenizer.model")
            eos_token_id = sp.piece_to_id("<END>")
        finished = torch.zeros(B, dtype=torch.bool, device=generated.device)
        for _ in range(max_new_tokens):
            # Crop the context to the positional-embedding window so pos_embedding is never indexed out of range
            logits, _ = self(generated[:, -self.max_seq_len:])
            logits = logits[:, -1, :]
            # Repetition penalty: dampen the logits of tokens that already appear in the sequence
            if repetition_penalty != 1.0:
                for b in range(B):
                    for t in torch.unique(generated[b]):
                        if logits[b, t] < 0:
                            logits[b, t] *= repetition_penalty
                        else:
                            logits[b, t] /= repetition_penalty
            # temperature
            logits = logits / temperature
            # top-k: keep only the k largest logits and mask out the rest
            if top_k is not None:
                topv, topi = torch.topk(logits, top_k, dim=-1)
                mask = torch.full_like(logits, float('-inf'))
                mask.scatter_(1, topi, logits.gather(1, topi))
                logits = mask
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            finished |= (idx_next.squeeze(1) == eos_token_id)
            # Sequences that have already finished keep emitting the EOS token
            idx_next[finished, :] = eos_token_id
            generated = torch.cat((generated, idx_next), dim=1)
            if finished.all():
                break
        return generated
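
# A minimal usage sketch (not part of the original file): builds a small model and runs an
# untrained forward pass and generation as a smoke test. The vocabulary size, model sizes,
# and the explicit eos_token_id below are placeholder assumptions; a real run would load
# trained weights and derive the EOS id from tokenizer.model instead.
if __name__ == "__main__":
    vocab_size = 8000  # assumed vocabulary size
    model = BigramLM(vocab_size, d_model=128, max_seq_len=256, h=4, Nx=2)
    idx = torch.randint(0, vocab_size, (2, 16))      # dummy batch of token ids
    targets = torch.randint(0, vocab_size, (2, 16))  # dummy targets for the loss
    logits, loss = model(idx, targets)
    print(logits.shape, loss.item())  # (2, 16, vocab_size) and a scalar loss
    model.eval()
    out = model.generate(idx, max_new_tokens=8, top_k=50, eos_token_id=2)  # eos id 2 is an assumption
    print(out.shape)  # (2, 24) unless every sequence hits EOS early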