|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from dataclasses import dataclass |
|
|
|
|
|
@dataclass
class ModelConfig:
    """Hyperparameters read by the GPT modules.

    All fields are required; their declaration order is the positional order
    of the generated ``__init__``.
    """

    # Width of token embeddings, layer norms and the residual stream.
    embedding_dim: int
    # Inner width of the two-layer feed-forward sub-layer.
    hidden_dim: int
    # Attention head count; head_dim = embedding_dim // num_attention_heads.
    num_attention_heads: int
    # Number of stacked transformer blocks.
    layer_count: int
    # Length covered by the RoPE tables, the causal mask and the KV cache.
    max_sequence_length: int
    # Base of the rotary-embedding frequency progression.
    rope_theta: float
    # Number of token ids (embedding rows and output logits).
    vocab_size: int
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
class TokenEmbedding(nn.Module):
    """Learned lookup table from token id to an ``embedding_dim`` vector.

    Reads ``config.vocab_size`` and ``config.embedding_dim``.
    """

    def __init__(self, config):
        super().__init__()
        num_tokens = config.vocab_size
        width = config.embedding_dim
        self.token_embedding_table = nn.Embedding(num_tokens, width)

    def forward(self, input_indices):
        """Return the embedding rows for ``input_indices`` (adds a trailing axis)."""
        lookup = self.token_embedding_table
        return lookup(input_indices)
|
|
|
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) backed by precomputed cos/sin tables.

    Channels are rotated in interleaved pairs ``(x[..., 2i], x[..., 2i + 1])``
    by the angle ``position * rope_theta ** (-2i / dim)``.

    Args:
        dim: Size of the last axis of tensors passed to ``apply_rotary_emb``;
            must be even so every channel has a rotation partner.
        max_seq_len: Number of positions covered by the tables.
        rope_theta: Base of the geometric frequency progression.

    Raises:
        ValueError: if ``dim`` is odd.
    """

    def __init__(self, dim, max_seq_len=2048, rope_theta=1e6):
        super().__init__()
        if dim % 2 != 0:
            # With an odd dim the even/odd channel slices differ in length and
            # apply_rotary_emb could not pair them up.
            raise ValueError(f"RotaryEmbedding requires an even dim, got {dim}")

        # Build both operands in float32 explicitly; otherwise the position
        # index is int64 while the frequencies are floating point.
        exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
        inv_freq = 1.0 / (rope_theta ** exponents)
        positions = torch.arange(max_seq_len, dtype=torch.float32)
        angles = torch.outer(positions, inv_freq)  # (max_seq_len, dim // 2)

        # Leading singleton axes let the tables broadcast over axes 0 and 1 of x.
        self.register_buffer("cos_cached", torch.cos(angles)[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", torch.sin(angles)[None, None, :, :], persistent=False)

    def apply_rotary_emb(self, x, position_offset=0):
        """Rotate ``x`` as though its axis 2 starts at absolute ``position_offset``.

        Args:
            x: Tensor with the sequence on axis 2 and ``dim`` channels on the
                last axis; the (1, 1, seq, dim // 2) tables imply a
                (batch, heads, seq, dim) layout.
            position_offset: Absolute position of ``x[:, :, 0]``.

        Returns:
            A new tensor with the same shape and dtype as ``x``.

        Raises:
            RuntimeError: if the requested positions fall outside
                ``[0, max_seq_len)``.
        """
        sequence_length = x.size(2)
        table_length = self.cos_cached.size(2)
        end = position_offset + sequence_length
        # Without this check an overlong request can slice a length-1 table
        # that silently broadcasts one angle across every position.
        if position_offset < 0 or end > table_length:
            raise RuntimeError(
                f"rotary positions [{position_offset}, {end}) are outside the "
                f"precomputed range [0, {table_length})"
            )

        cosine = self.cos_cached[:, :, position_offset:end, :]
        sine = self.sin_cached[:, :, position_offset:end, :]

        x_even = x[..., 0::2]
        x_odd = x[..., 1::2]

        # empty_like keeps x's dtype even though the tables may be float32.
        rotated = torch.empty_like(x)
        rotated[..., 0::2] = x_even * cosine - x_odd * sine
        rotated[..., 1::2] = x_odd * cosine + x_even * sine
        return rotated
|
|
|
|
|
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with RoPE and an optional KV cache.

    ``forward`` selects one of three paths:

    * ``use_cache`` falsy: plain causal attention over ``x``; the cache is untouched.
    * ``use_cache`` truthy and more than one token: prefill. The new keys and
      values are appended to the cache at ``current_pos`` and attended with a
      causal mask.
    * ``use_cache`` truthy and exactly one token: single-step decoding against
      everything in the cache.

    The cache is allocated lazily on the first cached call with shape
    (batch, num_heads, max_sequence_length, head_dim) and cleared by
    ``reset_cache``.

    Raises:
        ValueError: at construction if ``embedding_dim`` is not divisible by
            ``num_attention_heads``.
    """

    def __init__(self, config):
        super().__init__()
        if config.embedding_dim % config.num_attention_heads != 0:
            raise ValueError(
                "embedding_dim must be divisible by num_attention_heads, got "
                f"{config.embedding_dim} and {config.num_attention_heads}"
            )

        self.config = config
        self.num_heads = config.num_attention_heads
        self.embed_dim = config.embedding_dim
        self.head_dim = self.embed_dim // self.num_heads

        # Creation order is kept so seeded parameter initialisation is unchanged.
        self.query_fc = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.key_fc = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.value_fc = nn.Linear(self.embed_dim, self.embed_dim, bias=False)

        self.rotary_emb = RotaryEmbedding(
            dim=self.head_dim,
            max_seq_len=config.max_sequence_length,
            rope_theta=config.rope_theta,
        )

        self.output_projection = nn.Linear(self.embed_dim, self.embed_dim)

        # Lower-triangular boolean mask (True = may attend), sliced per prefill call.
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones(
                config.max_sequence_length,
                config.max_sequence_length,
                dtype=torch.bool,
            )),
            persistent=False,
        )

        # KV cache buffers start empty and are allocated on the first cached call.
        self.register_buffer("cache_k", None, persistent=False)
        self.register_buffer("cache_v", None, persistent=False)
        # Number of positions already written into the cache.
        self.current_pos = 0

    def forward(self, x, use_cache=False):
        """Dispatch to the uncached, prefill or single-token decoding path.

        Args:
            x: Input whose axis 1 is the sequence; the paths reshape it with
                ``view(B, T, num_heads, head_dim)``, so (batch, seq,
                embedding_dim) is expected.
            use_cache: Whether to read from and extend the KV cache.

        Raises:
            RuntimeError: if caching is requested for an input with no tokens,
                or the cache would exceed ``max_sequence_length``.
        """
        input_len = x.size(1)
        if not use_cache:
            return self.forward_no_cache(x)
        if input_len > 1:
            return self.forward_prefill(x)
        if input_len == 1:
            return self.forward_cached_decoding(x)
        raise RuntimeError("Unexpected condition in MultiHeadAttention forward")

    def forward_no_cache(self, x):
        """Causal attention over ``x`` alone, with RoPE positions starting at 0."""
        Q, K, V = self._split_heads(x)
        Q = self.rotary_emb.apply_rotary_emb(Q, position_offset=0)
        K = self.rotary_emb.apply_rotary_emb(K, position_offset=0)
        out = F.scaled_dot_product_attention(Q, K, V, attn_mask=None, is_causal=True)
        return self._merge_heads(out)

    def forward_prefill(self, x):
        """Append a multi-token chunk to the cache and attend causally over it."""
        return self._attend_with_cache(x, use_causal_mask=True)

    def forward_cached_decoding(self, x):
        """Append one token to the cache and attend to every cached position."""
        assert x.size(1) == 1, "cached decoding expects T==1"
        # A single query may see every cached key, so no mask is needed.
        return self._attend_with_cache(x, use_causal_mask=False)

    def reset_cache(self):
        """Drop the KV cache; the next cached call reallocates it at position 0."""
        self.cache_k = None
        self.cache_v = None
        self.current_pos = 0

    def _split_heads(self, x):
        """Project ``x`` to Q, K, V, each reshaped to (B, num_heads, T, head_dim)."""
        B, T, _ = x.shape

        def to_heads(t):
            return t.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        return to_heads(self.query_fc(x)), to_heads(self.key_fc(x)), to_heads(self.value_fc(x))

    def _merge_heads(self, out):
        """Turn (B, num_heads, T, head_dim) back into (B, T, embed_dim) and project."""
        B, _, T, _ = out.shape
        out = out.transpose(1, 2).contiguous().view(B, T, self.embed_dim)
        return self.output_projection(out)

    def _ensure_cache(self, batch_size, device, dtype):
        """Allocate zeroed K/V caches on first use; reject a changed batch size later."""
        if self.cache_k is None:
            shape = (batch_size, self.num_heads, self.config.max_sequence_length, self.head_dim)
            self.cache_k = torch.zeros(shape, device=device, dtype=dtype)
            self.cache_v = torch.zeros(shape, device=device, dtype=dtype)
            self.current_pos = 0
        elif self.cache_k.size(0) != batch_size:
            raise RuntimeError(
                f"KV cache was allocated for batch size {self.cache_k.size(0)}, "
                f"got {batch_size}; call reset_cache() first"
            )

    def _attend_with_cache(self, x, use_causal_mask):
        """Shared prefill/decode path: write K/V at ``current_pos`` and attend.

        Raises:
            RuntimeError: if the new tokens do not fit in the cache.
        """
        B, T, _ = x.shape
        Q, K, V = self._split_heads(x)
        self._ensure_cache(B, x.device, K.dtype)

        start = self.current_pos
        end = start + T
        # Check capacity before RoPE, which has no angles past the table end.
        # Filling exactly up to max_sequence_length is allowed.
        if end > self.config.max_sequence_length:
            raise RuntimeError("KV cache exceeded max_sequence_length")

        Q = self.rotary_emb.apply_rotary_emb(Q, position_offset=start)
        K = self.rotary_emb.apply_rotary_emb(K, position_offset=start)

        self.cache_k[:, :, start:end, :] = K
        self.cache_v[:, :, start:end, :] = V
        keys = self.cache_k[:, :, :end, :]
        values = self.cache_v[:, :, :end, :]

        # Row i (absolute position start + i) may attend to columns 0..start + i.
        attn_mask = self.causal_mask[start:end, :end] if use_causal_mask else None
        out = F.scaled_dot_product_attention(
            Q, keys, values,
            attn_mask=attn_mask,
            is_causal=False,
        )

        self.current_pos = end
        return self._merge_heads(out)
|
|
|
|
|
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise MLP: bias-free Linear -> ReLU -> bias-free Linear.

    Maps ``embedding_dim`` to ``hidden_dim`` and back.
    """

    def __init__(self, config):
        super().__init__()
        expand = nn.Linear(config.embedding_dim, config.hidden_dim, bias=False)
        activation = nn.ReLU()
        contract = nn.Linear(config.hidden_dim, config.embedding_dim, bias=False)
        # A Sequential keeps the parameter names net.0.weight / net.2.weight.
        self.net = nn.Sequential(expand, activation, contract)

    def forward(self, input_tensor):
        """Apply the MLP to the last axis of ``input_tensor``."""
        mlp = self.net
        return mlp(input_tensor)
|
|
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm residual layer.

    Computes ``h = x + Attn(LN1(x))`` and then returns ``h + FFN(LN2(h))``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.embedding_dim
        self.layer_norm1 = nn.LayerNorm(width)
        self.layer_norm2 = nn.LayerNorm(width)
        # Same construction order as before so seeded initialisation matches.
        self.multihead_attention = MultiHeadAttention(config=config)
        self.feed_forward = FeedForward(config=config)

    def forward(self, input_tensor, use_cache=False):
        """Run attention then the MLP, each wrapped in a residual connection.

        ``use_cache`` is forwarded unchanged to the attention sub-layer.
        """
        attended = self.multihead_attention(self.layer_norm1(input_tensor), use_cache=use_cache)
        hidden = input_tensor + attended
        return hidden + self.feed_forward(self.layer_norm2(hidden))
|
|
|
|
|
|
|
|
class VocabularyLogits(nn.Module):
    """Final LayerNorm followed by a bias-free projection to ``vocab_size`` logits."""

    def __init__(self, config):
        super().__init__()
        width = config.embedding_dim
        self.output_norm = nn.LayerNorm(width)
        self.vocab_projection = nn.Linear(width, config.vocab_size, bias=False)

    def forward(self, transformer_block_output):
        """Normalise the last axis, then map it to one logit per vocabulary entry."""
        return self.vocab_projection(self.output_norm(transformer_block_output))
|
|
|
|
|
|
|
|
class GPT(nn.Module):
    """Decoder-only language model.

    The pipeline is token embedding -> ``layer_count`` blocks -> vocabulary
    logits.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Creation order is kept so seeded parameter initialisation is unchanged.
        self.token_embedding_layer = TokenEmbedding(config=config)
        self.blocks = nn.ModuleList([TransformerBlock(config=config) for _ in range(config.layer_count)])
        self.vocab_projection = VocabularyLogits(config=config)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, input_indices, target_indices=None, use_cache=False):
        """Compute logits and, when targets are given, the cross-entropy loss.

        Args:
            input_indices: Token ids, shape (batch, seq).
            target_indices: Optional token ids with the same number of
                elements as ``input_indices``; ``None`` skips the loss.
            use_cache: Forwarded to every block's attention layer.

        Returns:
            ``(logits, loss)``. ``logits`` has shape (batch, seq, vocab_size);
            ``loss`` is ``None`` when no targets are given.
        """
        # Call the submodule (not .forward) so module hooks run.
        x = self.token_embedding_layer(input_indices)
        for block in self.blocks:
            x = block(x, use_cache=use_cache)
        logits = self.vocab_projection(x)

        if target_indices is None:
            return logits, None

        vocab_size = logits.size(-1)
        # reshape (not view) also accepts non-contiguous targets such as slices.
        loss = self.criterion(logits.reshape(-1, vocab_size), target_indices.reshape(-1))
        return logits, loss

    def generate(self,
                 input_indices,
                 max_new_tokens,
                 temperature=1.0,
                 use_cache=True,
                 reset_cache=True,
                 top_k=None,
                 top_p=None,
                 ):
        """Yield ``max_new_tokens`` sampled token ids, one int per step.

        Switches the model to eval mode and leaves it there.

        Args:
            input_indices: Prompt token ids, shape (batch, seq). Each step
                yields ``int(next_token.item())``, so batch size 1 is required.
            max_new_tokens: Number of tokens to generate.
            temperature: Logit divisor; ``0`` selects greedy (argmax) decoding.
            use_cache: Feed only the newest token after the first step and
                reuse the attention KV caches. Without it the full context is
                re-run, cropped to the last ``max_sequence_length`` tokens.
            reset_cache: Clear every block's KV cache before generating.
            top_k: Keep only the ``top_k`` highest logits; ``None`` or ``<= 0``
                disables it.
            top_p: Nucleus threshold on cumulative probability; ``None``
                disables it.

        Raises:
            RuntimeError: in cached mode, if prompt plus generated tokens exceed
                ``max_sequence_length``.
        """
        self.eval()

        if reset_cache:
            for block in self.blocks:
                block.multihead_attention.reset_cache()

        next_token = None
        for step in range(max_new_tokens):
            # Run without autograd so the in-place KV-cache writes do not chain
            # every step's graph. The yield stays outside this context so the
            # caller's grad mode is untouched between tokens.
            with torch.no_grad():
                if not use_cache:
                    context = input_indices[:, -self.config.max_sequence_length:]
                    logits, _ = self(context, None, use_cache=False)
                elif step == 0:
                    logits, _ = self(input_indices, None, use_cache=True)
                else:
                    logits, _ = self(next_token, None, use_cache=True)
                next_token = self._sample_next_token(logits[:, -1, :], temperature, top_k, top_p)

            yield int(next_token.item())

            if not use_cache:
                # Only the uncached path re-reads the growing context.
                input_indices = torch.cat((input_indices, next_token), dim=1)

    @staticmethod
    def _sample_next_token(last_logits, temperature, top_k, top_p):
        """Pick one token id per row of ``last_logits`` (batch, vocab); returns (batch, 1)."""
        if temperature == 0:
            # Dividing by zero would produce NaN probabilities; decode greedily instead.
            return torch.argmax(last_logits, dim=-1, keepdim=True)

        scaled = last_logits / temperature

        if top_k is not None and top_k > 0:
            k = min(top_k, scaled.size(-1))
            kth_largest = torch.topk(scaled, k).values[..., -1:]
            scaled = scaled.masked_fill(scaled < kth_largest, float("-inf"))

        if top_p is not None:
            sorted_logits, sorted_indices = torch.sort(scaled, descending=True)
            cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # Drop a token once the mass strictly before it exceeds top_p. This
            # keeps the token that crosses the threshold, and always the top one.
            drop_sorted = cumulative > top_p
            drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
            drop_sorted[..., 0] = False
            sorted_logits = sorted_logits.masked_fill(drop_sorted, float("-inf"))
            # sorted_indices is a permutation, so the scatter restores every entry.
            scaled = torch.zeros_like(scaled).scatter(-1, sorted_indices, sorted_logits)

        probs = F.softmax(scaled, dim=-1)
        return torch.multinomial(probs, num_samples=1)
|
|
|