import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import sentencepiece as spm


class SelfAttention(nn.Module):
    """Single-head causal (masked) self-attention."""

    def __init__(self, d_model, dropout_rate=0.2):
        super().__init__()
        self.d_model = d_model
        self.keys = nn.Linear(d_model, d_model, bias=False)
        self.queries = nn.Linear(d_model, d_model, bias=False)
        self.values = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, X):
        B, T, C = X.shape
        K = self.keys(X)
        Q = self.queries(X)
        V = self.values(X)

        # Scaled dot-product attention scores: (B, T, T)
        scaled_dot = (Q @ K.transpose(-2, -1)) / math.sqrt(C)

        # Causal mask: each position attends only to itself and earlier positions
        mask = torch.tril(torch.ones(T, T, device=X.device)).unsqueeze(0)
        scaled_dot = scaled_dot.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(scaled_dot, dim=-1)
        attn = self.dropout(attn)
        out = attn @ V
        return out


class MultiHeadAttention(nn.Module):
    """Multi-head attention: the channels are split into h slices of size d_k,
    each processed by its own SelfAttention head, then re-concatenated."""

    def __init__(self, d_model, h, dropout_rate=0.2):
        super().__init__()
        self.h = h
        self.d_k = d_model // h
        self.heads = nn.ModuleList([SelfAttention(self.d_k, dropout_rate) for _ in range(h)])
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, X):
        B, T, C = X.shape

        # (B, T, C) -> (B, h, T, d_k): one channel slice per head
        X_split = X.view(B, T, self.h, self.d_k).transpose(1, 2)
        out_heads = []
        for i, head in enumerate(self.heads):
            out_heads.append(head(X_split[:, i]))

        # Concatenate head outputs back to (B, T, C), then project
        out = torch.cat(out_heads, dim=-1)
        out = self.proj(out)
        out = self.dropout(out)
        return out


class FeedForward(nn.Module):
    """Position-wise feed-forward network with the usual 4x expansion."""

    def __init__(self, d_model, dropout_rate=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout_rate),
        )

    def forward(self, X):
        return self.net(X)


class Block(nn.Module):
    """Pre-norm transformer block: LayerNorm feeds each sublayer, with residual adds."""

    def __init__(self, d_model, h, dropout_rate=0.2):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, h, dropout_rate)
        self.ff = FeedForward(d_model, dropout_rate)

    def forward(self, X):
        X = X + self.attn(self.ln1(X))
        X = X + self.ff(self.ln2(X))
        return X


class BigramLM(nn.Module):
    """Decoder-only transformer language model (despite the name, it attends
    over the full context, not just the previous token)."""

    def __init__(self, vocab_size, d_model=512, max_seq_len=1024, h=8, Nx=6, dropout_rate=0.2):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        # Token embeddings plus learned positional embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout_rate)

        # Nx transformer blocks followed by a final LayerNorm
        blocks = [Block(d_model, h, dropout_rate) for _ in range(Nx)]
        blocks.append(nn.LayerNorm(d_model))
        self.blocks = nn.Sequential(*blocks)

        # Project hidden states to vocabulary logits
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding(idx)                        # (B, T, d_model)
        pos_idx = torch.arange(T, device=idx.device).unsqueeze(0)  # (1, T)
        pos_emb = self.pos_embedding(pos_idx)                      # (1, T, d_model)
        X = self.dropout(tok_emb + pos_emb)

        X = self.blocks(X)
        logits = self.lm_head(X)                                   # (B, T, vocab_size)

        loss = None
        if targets is not None:
            # Token id 1 is treated as padding and excluded from the loss
            loss_mask = (targets != 1).view(-1)
            if loss_mask.sum() > 0:
                loss = F.cross_entropy(logits.view(B * T, -1), targets.view(-1), ignore_index=1)
            else:
                # Every target is padding: return a zero loss instead of NaN
                loss = torch.tensor(0.0, device=idx.device)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.8, top_k=None, eos_token_id=None, repetition_penalty=1.2):
        B = idx.size(0)
        generated = idx.clone()

        # Fall back to the on-disk SentencePiece tokenizer for the <END> id
        if eos_token_id is None:
            sp = spm.SentencePieceProcessor()
            sp.load("tokenizer.model")
            eos_token_id = sp.piece_to_id("<END>")

        finished = torch.zeros(B, dtype=torch.bool, device=generated.device)

        for _ in range(max_new_tokens):
            # Crop the context to the last max_seq_len tokens so the positional
            # embedding table is never indexed out of range
            idx_cond = generated[:, -self.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]

            # Repetition penalty: dampen logits of tokens already generated
            if repetition_penalty != 1.0:
                for b in range(B):
                    for t in torch.unique(generated[b]):
                        if logits[b, t] < 0:
                            logits[b, t] *= repetition_penalty
                        else:
                            logits[b, t] /= repetition_penalty

            # Temperature scaling
            logits = logits / temperature

            # Top-k filtering: keep only the k highest-scoring tokens
            if top_k is not None:
                topv, topi = torch.topk(logits, top_k, dim=-1)
                mask = torch.full_like(logits, float('-inf'))
                mask.scatter_(1, topi, logits.gather(1, topi))
                logits = mask

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            # Once a sequence emits <END>, keep padding it with <END>
            finished |= (idx_next.squeeze(1) == eos_token_id)
            idx_next[finished, :] = eos_token_id
            generated = torch.cat((generated, idx_next), dim=1)

            if finished.all():
                break

        return generated
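

# Minimal usage sketch: the vocab size, hyperparameters, and eos_token_id below
# are illustrative assumptions; passing eos_token_id explicitly avoids loading
# a local tokenizer.model file.
if __name__ == "__main__":
    model = BigramLM(vocab_size=1000, d_model=128, max_seq_len=64, h=4, Nx=2)
    model.eval()  # dropout off for evaluation and sampling
    dummy = torch.randint(0, 1000, (2, 16))          # (B, T) token ids
    logits, loss = model(dummy, targets=dummy)       # teacher-forced forward pass
    print(logits.shape, loss.item())                 # torch.Size([2, 16, 1000]), scalar loss
    sample = model.generate(dummy[:, :4], max_new_tokens=8, eos_token_id=2)
    print(sample.shape)                              # (2, <=12): prompt plus generated tokens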