|
|
import torch
|
|
|
import math
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
import sentencepiece as spm
|
|
|
|
|
|
|
|
|
class MemoryEfficientSelfAttention(nn.Module):
    """Single-head causal (masked) self-attention.

    Projects the input to keys/queries/values with bias-free linear layers,
    applies scaled dot-product attention with a lower-triangular mask so each
    position attends only to itself and earlier positions, and applies dropout
    to the attention weights.

    Args:
        d_model: feature size of both the input and the output.
        dropout_rate: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, dropout_rate=0.2):
        super().__init__()
        self.d_model = d_model
        self.keys = nn.Linear(d_model, d_model, bias=False)
        self.queries = nn.Linear(d_model, d_model, bias=False)
        self.values = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout_rate)
        # Lazily-built causal-mask cache. It is derived data, so register it
        # as non-persistent: it must not be written into (or expected from)
        # state_dict checkpoints.
        self.register_buffer("mask", None, persistent=False)

    def forward(self, X):
        """Return attention output of shape (B, T, d_model) for input X of (B, T, d_model)."""
        B, T, C = X.shape

        K = self.keys(X)
        Q = self.queries(X)
        V = self.values(X)

        # Scale by sqrt of the (per-head) feature dimension C.
        scaled_dot = (Q @ K.transpose(-2, -1)) / math.sqrt(C)

        # Rebuild the cached mask when the sequence length changes OR when the
        # module has moved to a different device since the cache was built.
        # (The original checked only the shape, so a stale-device mask could
        # crash masked_fill after a .to(device) call.)
        if (
            self.mask is None
            or self.mask.shape[-1] != T
            or self.mask.device != X.device
        ):
            self.mask = torch.tril(torch.ones(T, T, device=X.device)).unsqueeze(0)

        # Zeros in the lower-triangular mask mark future positions: set their
        # scores to -inf so softmax gives them zero weight.
        scaled_dot = scaled_dot.masked_fill(self.mask == 0, float('-inf'))

        attn = F.softmax(scaled_dot, dim=-1)
        attn = self.dropout(attn)
        out = attn @ V
        return out
|
|
|
|
|
|
|
|
|
class MemoryEfficientMultiHeadAttention(nn.Module):
    """Multi-head causal self-attention built from independent per-head modules.

    The input is split along the feature dimension into ``h`` slices of size
    ``d_k = d_model // h``; each slice is run through its own
    ``MemoryEfficientSelfAttention`` head, and the head outputs are
    concatenated and passed through an output projection plus dropout.

    Args:
        d_model: total model feature size (must be divisible by ``h``).
        h: number of attention heads.
        dropout_rate: dropout probability for heads and the output projection.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``h`` (the original
            code only failed later, inside ``forward``'s ``view``, with a
            confusing shape error).
    """

    def __init__(self, d_model, h, dropout_rate=0.2):
        super().__init__()
        if d_model % h != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by h ({h})")
        self.h = h
        self.d_k = d_model // h
        self.heads = nn.ModuleList([MemoryEfficientSelfAttention(self.d_k, dropout_rate) for _ in range(h)])
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, X):
        """Return (B, T, d_model) attention output for input X of shape (B, T, d_model)."""
        B, T, C = X.shape

        # (B, T, C) -> (B, h, T, d_k): one feature slice per head.
        X_split = X.view(B, T, self.h, self.d_k).transpose(1, 2)

        # Run heads sequentially (trades speed for peak memory).
        out_heads = [head(X_split[:, i]) for i, head in enumerate(self.heads)]

        out = torch.cat(out_heads, dim=-1)
        out = self.proj(out)
        out = self.dropout(out)
        return out
|
|
|
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise feed-forward network.

    Expands features to ``4 * d_model``, applies ReLU, projects back down to
    ``d_model``, and finishes with dropout — applied independently at every
    sequence position.
    """

    def __init__(self, d_model, dropout_rate=0.2):
        super().__init__()
        hidden = 4 * d_model
        self.net = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.ReLU(),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout_rate),
        )

    def forward(self, X):
        out = self.net(X)
        return out
|
|
|
|
|
|
|
|
|
class MemoryEfficientBlock(nn.Module):
    """One pre-norm transformer decoder block.

    Applies ``x + attn(ln1(x))`` followed by ``x + ff(ln2(x))`` — layer norm
    sits *before* each sub-layer, with residual connections around both.
    """

    def __init__(self, d_model, h, dropout_rate=0.2):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.attn = MemoryEfficientMultiHeadAttention(d_model, h, dropout_rate)
        self.ff = FeedForward(d_model, dropout_rate)

    def forward(self, X):
        # Attention sub-layer with residual connection.
        X = X + self.attn(self.ln1(X))
        # Feed-forward sub-layer with residual connection.
        X = X + self.ff(self.ln2(X))
        return X
|
|
|
|
|
|
|
|
|
class MemoryOptimizedBigramLM(nn.Module):
    """Decoder-only transformer language model.

    Token + learned positional embeddings, ``Nx`` pre-norm transformer blocks,
    a final LayerNorm, and a linear head projecting to vocabulary logits.

    Args:
        vocab_size: vocabulary size for embeddings and the output head.
        d_model: model feature width.
        max_seq_len: positional-embedding capacity (longest supported context).
        h: attention heads per block.
        Nx: number of transformer blocks.
        dropout_rate: dropout probability used throughout.
    """

    def __init__(self, vocab_size, d_model=512, max_seq_len=1024, h=8, Nx=6, dropout_rate=0.2):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        # Stored so forward() can validate and generate() can crop the context.
        # The original discarded max_seq_len, so generation crashed with an
        # embedding index error once the sequence outgrew the position table.
        self.max_seq_len = max_seq_len

        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout_rate)

        # Nx transformer blocks followed by a final LayerNorm.
        blocks = [MemoryEfficientBlock(d_model, h, dropout_rate) for _ in range(Nx)]
        blocks.append(nn.LayerNorm(d_model))
        self.blocks = nn.Sequential(*blocks)

        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, idx, targets=None):
        """Return (logits, loss).

        Args:
            idx: (B, T) int tensor of token ids, T <= max_seq_len.
            targets: optional (B, T) int tensor of next-token targets;
                token id 1 is treated as padding and excluded from the loss.

        Returns:
            logits of shape (B, T, vocab_size), and the cross-entropy loss
            (or ``None`` when no targets are given).

        Raises:
            ValueError: if T exceeds the positional-embedding capacity
                (previously this surfaced as an opaque embedding index error).
        """
        B, T = idx.shape
        if T > self.max_seq_len:
            raise ValueError(f"sequence length {T} exceeds max_seq_len {self.max_seq_len}")

        tok_emb = self.token_embedding(idx)
        pos_idx = torch.arange(T, device=idx.device).unsqueeze(0)
        pos_emb = self.pos_embedding(pos_idx)
        X = self.dropout(tok_emb + pos_emb)

        X = self.blocks(X)
        logits = self.lm_head(X)

        loss = None
        if targets is not None:
            # Token id 1 is padding. ignore_index already masks it, but guard
            # against a batch of ONLY padding, where cross_entropy would have
            # nothing to average over and return NaN.
            loss_mask = (targets != 1).view(-1)
            if loss_mask.sum() > 0:
                loss = F.cross_entropy(logits.view(B*T, -1), targets.view(-1), ignore_index=1)
            else:
                loss = torch.tensor(0.0, device=idx.device)

        return logits, loss

    @torch.no_grad()  # sampling needs no autograd graph; saves memory per step
    def generate(self, idx, max_new_tokens, temperature=0.8, top_k=None, eos_token_id=None, repetition_penalty=1.2):
        """Autoregressively sample up to ``max_new_tokens`` tokens.

        Args:
            idx: (B, T) prompt token ids.
            max_new_tokens: maximum number of tokens to append.
            temperature: logit divisor before softmax (lower = greedier).
            top_k: if set, sample only from the k highest-probability tokens.
            eos_token_id: stop token; if ``None``, resolved from a
                sentencepiece model loaded from the hard-coded path
                "tokenizer.model" (NOTE(review): prefer passing it in).
            repetition_penalty: CTRL-style penalty applied to tokens already
                present in the sequence.

        Returns:
            (B, T + n) tensor of the prompt plus generated tokens, n <=
            max_new_tokens (shorter if every row emitted ``eos_token_id``).
        """
        B = idx.size(0)
        generated = idx.clone()

        if eos_token_id is None:
            sp = spm.SentencePieceProcessor()
            sp.load("tokenizer.model")
            eos_token_id = sp.piece_to_id("<END>")

        finished = torch.zeros(B, dtype=torch.bool, device=generated.device)

        for _ in range(max_new_tokens):
            # Crop the context to the positional-embedding capacity; the
            # original fed the full history and crashed past max_seq_len.
            context = generated[:, -self.max_seq_len:]
            logits, _ = self(context)
            logits = logits[:, -1, :]  # distribution for the next token only

            if repetition_penalty != 1.0:
                # CTRL-style: push already-emitted tokens away from selection.
                # Negative logits are multiplied (more negative), positive
                # logits divided (less positive).
                for b in range(B):
                    for t in torch.unique(generated[b]):
                        if logits[b, t] < 0:
                            logits[b, t] *= repetition_penalty
                        else:
                            logits[b, t] /= repetition_penalty

            logits = logits / temperature

            if top_k is not None:
                # Keep only the top-k logits; everything else becomes -inf.
                topv, topi = torch.topk(logits, top_k, dim=-1)
                mask = torch.full_like(logits, float('-inf'))
                mask.scatter_(1, topi, topv)
                logits = mask

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            # Rows that have emitted EOS keep emitting EOS from now on.
            finished |= (idx_next.squeeze(1) == eos_token_id)
            idx_next[finished, :] = eos_token_id
            generated = torch.cat((generated, idx_next), dim=1)

            if finished.all():
                break

        return generated
|
|
|
|