# Author: HayatoHongoEveryonesAI (initial commit 81845c2)
# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
@dataclass
class ModelConfig:
    """Hyperparameters read by every module in this file."""

    embedding_dim: int  # model width: embedding size and residual-stream size
    hidden_dim: int  # inner width of the FeedForward MLP
    num_attention_heads: int  # heads per attention layer; head_dim = embedding_dim // num_attention_heads
    layer_count: int  # number of TransformerBlocks stacked in GPT
    max_sequence_length: int  # size of the RoPE tables, causal mask and KV cache (in tokens)
    rope_theta: float  # base of the RoPE frequency progression
    vocab_size: int  # rows of the token embedding table / width of the output logits
# ---- Below: TokenEmbedding / RotaryEmbedding / MHA / FFN / Block / GPT ----
# (the provided code, pasted verbatim)
# Added top-p and top-k filtering to the generate function.
# vocab_size is set in config.py.
# MHA with KV cache + RoPE + PyTorch SDPA.
# This traditional implementation is easier to understand and still efficient in practice.
# GQA and MLA are a great way to reduce KV-cache size for long-context inference,
# but both come with a slight loss increase and bring no efficiency benefit during training.
# A KV cache does not help training speed; the codebase is simpler without it.
# The KV cache supports multi-turn continuation by applying RoPE with a position offset.
# No dropout: the dataset is large enough that regularization is not necessary.
import torch
import torch.nn as nn
import torch.nn.functional as F
class TokenEmbedding(nn.Module):
    """Map integer token ids to learned ``embedding_dim``-sized vectors."""

    def __init__(self, config):
        super().__init__()
        # No explicit dtype cast here: parameters stay in the default dtype and
        # mixed precision is left to autocast when it is enabled.
        self.token_embedding_table = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.embedding_dim,
        )

    def forward(self, input_indices):
        """Look up the embedding vector of every id in ``input_indices``."""
        embeddings = self.token_embedding_table(input_indices)
        return embeddings
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) with precomputed cos/sin tables.

    Each adjacent feature pair (2i, 2i+1) of the last dimension is rotated by
    the angle ``position * rope_theta ** (-2i / dim)`` (interleaved layout).

    Args:
        dim: size of the last dimension to rotate (the per-head dimension when
            used by MultiHeadAttention).
        max_seq_len: number of positions precomputed; positions outside
            ``[0, max_seq_len)`` cannot be encoded.
        rope_theta: base of the geometric frequency progression.
    """

    def __init__(self, dim, max_seq_len=2048, rope_theta=1e6):
        super().__init__()
        inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2) / dim))
        position_index = torch.arange(max_seq_len)
        # frequency_matrix[p, i] = p * inv_freq[i]
        frequency_matrix = torch.einsum('i,j->ij', position_index, inv_freq)
        # Leading singleton axes let the tables broadcast over (batch, heads).
        cosine = torch.cos(frequency_matrix)[None, None, :, :]
        sine = torch.sin(frequency_matrix)[None, None, :, :]
        self.register_buffer("cos_cached", cosine, persistent=False)
        self.register_buffer("sin_cached", sine, persistent=False)

    def apply_rotary_emb(self, x, position_offset=0):
        """Rotate ``x`` as if its first sequence element sat at ``position_offset``.

        Args:
            x: tensor whose axis 2 is the sequence axis and whose last axis has
                size ``dim``; MultiHeadAttention passes (batch, heads, seq, head_dim).
            position_offset: absolute position of ``x[:, :, 0]`` (non-zero when
                continuing from a KV cache).

        Returns:
            A new tensor with the same shape and dtype as ``x``.

        Raises:
            RuntimeError: if any requested position lies outside
                ``[0, max_seq_len)``. Without this check an overrun that leaves a
                single table row would silently broadcast one angle over every
                position.
        """
        sequence_length = x.size(2)
        end = position_offset + sequence_length
        max_positions = self.cos_cached.size(2)
        if position_offset < 0 or end > max_positions:
            raise RuntimeError(
                f"RoPE positions [{position_offset}, {end}) out of range "
                f"for max_seq_len={max_positions}"
            )
        cosine = self.cos_cached[:, :, position_offset:end, :]
        sine = self.sin_cached[:, :, position_offset:end, :]
        x_even = x[..., 0::2]
        x_odd = x[..., 1::2]
        # 2-D rotation of every (even, odd) pair; results are written back into a
        # tensor with x's dtype.
        rotated = torch.empty_like(x)
        rotated[..., 0::2] = x_even * cosine - x_odd * sine
        rotated[..., 1::2] = x_odd * cosine + x_even * sine
        return rotated
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with RoPE, PyTorch SDPA and an optional KV cache.

    ``forward`` routes to one of three paths:

    * ``forward_no_cache`` (``use_cache`` falsy): full causal attention over ``x``
      with no state kept. This is the training path.
    * ``forward_prefill`` (``use_cache`` truthy, more than one token): writes the
      chunk's keys/values into the cache at ``current_pos`` and attends over the
      whole cached prefix.
    * ``forward_cached_decoding`` (``use_cache`` truthy, exactly one token): the
      same for a single new token.

    The cached paths apply RoPE at absolute position ``current_pos``. Successive
    cached calls therefore continue one sequence (multi-turn continuation) until
    ``reset_cache`` is called. ``x`` is (batch, seq, embedding_dim) and so is the
    output.
    """

    def __init__(self, config):
        super().__init__()
        if config.embedding_dim % config.num_attention_heads != 0:
            raise ValueError(
                "embedding_dim must be divisible by num_attention_heads "
                f"(got {config.embedding_dim} and {config.num_attention_heads})"
            )
        self.config = config
        self.num_heads = config.num_attention_heads
        self.embed_dim = config.embedding_dim
        self.head_dim = self.embed_dim // self.num_heads
        # QKV projection (bias-free)
        self.query_fc = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.key_fc = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.value_fc = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        # Rotary Positional Embedding (RoPE); its tables bound the usable positions.
        self.rotary_emb = RotaryEmbedding(
            dim=self.head_dim,
            max_seq_len=config.max_sequence_length,
            rope_theta=config.rope_theta,
        )
        # NOTE: unlike the QKV projections this layer has a bias; changing that
        # would change the state_dict keys.
        self.output_projection = nn.Linear(self.embed_dim, self.embed_dim)
        # causal_mask[i, j] is True when query position i may attend to key position j (j <= i).
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones(
                config.max_sequence_length,
                config.max_sequence_length,
                dtype=torch.bool,
            )),
            persistent=False,
        )
        # KV cache: allocated lazily by the first cached call, kept out of state_dict.
        self.register_buffer("cache_k", None, persistent=False)
        self.register_buffer("cache_v", None, persistent=False)
        self.current_pos = 0  # number of positions already written to the cache

    # --------------------------------------------------
    # router
    # --------------------------------------------------
    def forward(self, x, use_cache=False):
        """Dispatch to the no-cache, prefill or single-token decoding path."""
        if not use_cache:
            return self.forward_no_cache(x)
        input_len = x.size(1)
        if input_len > 1:
            return self.forward_prefill(x)
        if input_len == 1:  # a 1-token prompt (e.g. "Hi") also starts here
            return self.forward_cached_decoding(x)
        raise RuntimeError("Unexpected condition in MultiHeadAttention forward")

    # --------------------------------------------------
    # shared helpers
    # --------------------------------------------------
    def _project_qkv(self, x):
        """Project ``x`` (B, T, C) to Q, K, V, each reshaped to (B, num_heads, T, head_dim)."""
        B, T, _ = x.shape

        def split_heads(t):
            return t.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        return split_heads(self.query_fc(x)), split_heads(self.key_fc(x)), split_heads(self.value_fc(x))

    def _merge_heads(self, out, B, T, C):
        """Turn (B, num_heads, T, head_dim) back into (B, T, C) and apply the output projection."""
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.output_projection(out)

    def _ensure_cache(self, K, V):
        """Allocate zeroed caches shaped (B, num_heads, max_sequence_length, head_dim) if absent.

        The caches take the batch size, device and dtype of ``K`` / ``V``.
        """
        if self.cache_k is not None:
            return
        batch_size = K.size(0)
        max_len = self.config.max_sequence_length
        self.cache_k = K.new_zeros(batch_size, self.num_heads, max_len, self.head_dim)
        self.cache_v = V.new_zeros(batch_size, self.num_heads, max_len, self.head_dim)
        self.current_pos = 0

    def _check_capacity(self, new_tokens):
        """Raise if writing ``new_tokens`` more positions would overflow the cache."""
        if self.current_pos + new_tokens > self.config.max_sequence_length:
            raise RuntimeError("KV cache exceeded max_sequence_length")

    def _write_cache(self, K, V):
        """Store K/V at positions [current_pos, current_pos + T) and return the cached prefix up to there."""
        start = self.current_pos
        end = start + K.size(2)
        self.cache_k[:, :, start:end, :] = K
        self.cache_v[:, :, start:end, :] = V
        return self.cache_k[:, :, :end, :], self.cache_v[:, :, :end, :]

    # --------------------------------------------------
    # (1) no cache : training
    # --------------------------------------------------
    def forward_no_cache(self, x):
        """Full causal attention over ``x``; the KV cache is neither read nor written."""
        B, T, C = x.shape
        Q, K, V = self._project_qkv(x)
        # RoPE : offset = 0
        Q = self.rotary_emb.apply_rotary_emb(Q, position_offset=0)
        K = self.rotary_emb.apply_rotary_emb(K, position_offset=0)
        out = F.scaled_dot_product_attention(
            Q, K, V,
            attn_mask=None,
            is_causal=True,
        )
        return self._merge_heads(out, B, T, C)

    # --------------------------------------------------
    # (2) prefill : initialize / extend KV cache
    # --------------------------------------------------
    def forward_prefill(self, x):
        """Append a multi-token chunk to the cache and attend over everything cached so far.

        Raises:
            RuntimeError: if the chunk does not fit in the remaining cache space.
        """
        B, T, C = x.shape
        Q, K, V = self._project_qkv(x)
        self._ensure_cache(K, V)
        # Check capacity *before* RoPE. Past the end of the RoPE tables the
        # rotation would fail (or mis-broadcast) before reaching this clear error.
        self._check_capacity(T)
        # RoPE : offset = current_pos (supports multi-turn continuation)
        Q = self.rotary_emb.apply_rotary_emb(Q, position_offset=self.current_pos)
        K = self.rotary_emb.apply_rotary_emb(K, position_offset=self.current_pos)
        K, V = self._write_cache(K, V)
        # The new queries occupy absolute rows current_pos .. current_pos+T-1 and
        # may see every cached key up to and including their own position.
        attn_mask = self.causal_mask[
            self.current_pos:self.current_pos + T,
            :self.current_pos + T,
        ]
        out = F.scaled_dot_product_attention(
            Q, K, V,
            attn_mask=attn_mask,
            is_causal=False,
        )
        self.current_pos += T
        return self._merge_heads(out, B, T, C)

    # --------------------------------------------------
    # (3) decode : cached decoding (1 token)
    # --------------------------------------------------
    def forward_cached_decoding(self, x):
        """Append one token to the cache and attend over everything cached so far.

        Raises:
            RuntimeError: if the cache is already full.
        """
        B, T, C = x.shape
        assert T == 1, "cached decoding expects T==1"
        Q, K, V = self._project_qkv(x)
        # Usually prefill already allocated the cache; a 1-token prompt such as
        # "Hi" reaches this path first.
        self._ensure_cache(K, V)
        # Same bound as prefill, so the last slot (max_sequence_length - 1) is usable.
        self._check_capacity(1)
        # RoPE : offset = current_pos
        Q = self.rotary_emb.apply_rotary_emb(Q, position_offset=self.current_pos)
        K = self.rotary_emb.apply_rotary_emb(K, position_offset=self.current_pos)
        K, V = self._write_cache(K, V)
        # A single query at the newest position may attend to every cached key: no mask.
        out = F.scaled_dot_product_attention(
            Q, K, V,
            attn_mask=None,
            is_causal=False,
        )
        self.current_pos += 1
        return self._merge_heads(out, B, T, C)

    def reset_cache(self):
        """Drop the KV cache; the next cached call starts again at position 0."""
        self.cache_k = None
        self.cache_v = None
        self.current_pos = 0
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(embedding_dim -> hidden_dim) -> ReLU -> Linear(hidden_dim -> embedding_dim).

    Both linear layers are bias-free.
    """

    def __init__(self, config):
        super().__init__()
        width = config.embedding_dim
        inner = config.hidden_dim
        layers = [
            nn.Linear(width, inner, bias=False),  # expand
            nn.ReLU(),
            nn.Linear(inner, width, bias=False),  # project back to the residual width
        ]
        # Stored as a Sequential named ``net`` so parameters are net.0.weight / net.2.weight.
        self.net = nn.Sequential(*layers)

    def forward(self, input_tensor):
        """Apply the MLP independently at every position of ``input_tensor``."""
        projected = self.net(input_tensor)
        return projected
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block.

    h = x + Attention(LayerNorm1(x)); output = h + FeedForward(LayerNorm2(h)).
    """

    def __init__(self, config):
        super().__init__()
        embed_dim = config.embedding_dim
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        self.multihead_attention = MultiHeadAttention(config=config)
        self.feed_forward = FeedForward(config=config)

    def forward(self, input_tensor, use_cache=False):
        """Run the attention and MLP sub-layers, each wrapped in a residual connection.

        ``use_cache`` is passed straight to the attention layer.
        """
        attended = self.multihead_attention(
            self.layer_norm1(input_tensor), use_cache=use_cache
        )
        hidden = attended + input_tensor
        return self.feed_forward(self.layer_norm2(hidden)) + hidden
class VocabularyLogits(nn.Module):
    """Final LayerNorm followed by a bias-free projection onto the vocabulary."""

    def __init__(self, config):
        super().__init__()
        self.output_norm = nn.LayerNorm(config.embedding_dim)
        self.vocab_projection = nn.Linear(
            config.embedding_dim, config.vocab_size, bias=False
        )

    def forward(self, transformer_block_output):
        """Map the last block's hidden states (..., embedding_dim) to logits (..., vocab_size)."""
        return self.vocab_projection(self.output_norm(transformer_block_output))
class GPT(nn.Module):
    """Decoder-only transformer language model.

    token ids -> TokenEmbedding -> ``config.layer_count`` TransformerBlocks
    -> VocabularyLogits (final LayerNorm + vocabulary projection).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_embedding_layer = TokenEmbedding(config=config)
        self.blocks = nn.ModuleList(
            [TransformerBlock(config=config) for _ in range(config.layer_count)]
        )
        self.vocab_projection = VocabularyLogits(config=config)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, input_indices, target_indices=None, use_cache=False):
        """Compute logits and, when targets are given, the cross-entropy loss.

        Args:
            input_indices: (batch, seq) tensor of token ids.
            target_indices: (batch, seq) tensor of target ids, or ``None`` to
                skip the loss.
            use_cache: forwarded to every attention layer to select the
                no-cache / prefill / cached-decoding path.

        Returns:
            ``(logits, loss)``: logits of shape (batch, seq, vocab_size) and the
            scalar loss from ``self.criterion``, or ``None`` without targets.
        """
        x = self.token_embedding_layer(input_indices)
        for block in self.blocks:
            x = block(x, use_cache=use_cache)
        logits = self.vocab_projection(x)
        if target_indices is None:
            return logits, None
        vocab_size = logits.size(-1)
        # reshape (not view): targets are often a non-contiguous slice of a larger batch.
        loss = self.criterion(
            logits.reshape(-1, vocab_size), target_indices.reshape(-1)
        )
        return logits, loss

    @staticmethod
    def _filter_logits(logits, top_k=None, top_p=None):
        """Set logits outside the top-k and/or nucleus (top-p) set to -inf.

        Args:
            logits: (batch, vocab) tensor of (temperature-scaled) logits.
            top_k: keep only the ``top_k`` largest logits per row (ties with the
                k-th value are kept). ``None`` or a value <= 0 disables it.
            top_p: keep the shortest run of tokens, in descending probability,
                whose cumulative probability exceeds ``top_p``. ``None`` disables it.

        Returns:
            A tensor of the same shape with the removed entries set to -inf.
        """
        if top_k is not None and top_k > 0:
            k = min(top_k, logits.size(-1))
            kth_largest = torch.topk(logits, k).values[..., -1:]
            logits = logits.masked_fill(logits < kth_largest, float("-inf"))
        if top_p is not None:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # Shift right by one: drop a token only when the tokens *before* it
            # already exceed top_p. The threshold-crossing token and the top-1
            # token are always kept.
            remove = cumulative_probs > top_p
            remove[..., 1:] = remove[..., :-1].clone()
            remove[..., 0] = False
            sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
            # sorted_indices is a permutation, so scatter overwrites every entry.
            logits = torch.zeros_like(logits).scatter(-1, sorted_indices, sorted_logits)
        return logits

    @staticmethod
    def _sample_next_token(last_logits, temperature=1.0, top_k=None, top_p=None):
        """Choose the next token id for each row of ``last_logits`` (batch, vocab).

        ``temperature == 0`` selects greedy (argmax) decoding. Otherwise the
        scaled and filtered distribution is sampled. Returns a (batch, 1) tensor
        of token ids.
        """
        if temperature == 0:
            return torch.argmax(last_logits, dim=-1, keepdim=True)
        filtered = GPT._filter_logits(last_logits / temperature, top_k=top_k, top_p=top_p)
        probs = F.softmax(filtered, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    # As a decorator on a generator function, torch.no_grad() is re-entered on
    # every resume and exited at every yield. Generation builds no autograd graph,
    # and the caller's grad mode is left untouched between yields.
    @torch.no_grad()
    def generate(self,
                 input_indices,
                 max_new_tokens,
                 temperature=1.0,
                 use_cache=True,
                 reset_cache=True,
                 top_k=None,
                 top_p=None,
                 ):
        """Sample up to ``max_new_tokens`` tokens autoregressively, yielding each id.

        The model is switched to eval mode (and left there).

        Args:
            input_indices: (1, seq) tensor of prompt token ids. Each sampled token
                is converted with ``.item()``, so only batch size 1 is supported.
            max_new_tokens: number of tokens to sample.
            temperature: logits are divided by this before filtering and sampling.
                ``0`` selects greedy decoding.
            use_cache: if true, prefill the KV cache with the prompt and then feed
                one token per step. Otherwise the whole context (cropped to the
                last ``max_sequence_length`` tokens) is re-encoded every step.
            reset_cache: clear every layer's KV cache first. Pass ``False`` to
                continue a previous conversation. The last token yielded by the
                previous call was never fed to the model, so include it at the
                start of the new ``input_indices``.
            top_k: keep only the ``top_k`` most likely tokens (``None``/<= 0: off).
            top_p: nucleus-sampling threshold (``None``: off).

        Yields:
            int: the sampled token id at each step.
        """
        self.eval()
        if reset_cache:
            for block in self.blocks:
                block.multihead_attention.reset_cache()
        caching = bool(use_cache)
        max_len = self.config.max_sequence_length
        next_token = None
        for step in range(max_new_tokens):
            if caching:
                # Step 0 feeds the whole prompt (prefill, or single-token decode for
                # a 1-token prompt). Later steps feed only the previously sampled token.
                step_input = input_indices if step == 0 else next_token
            else:
                # Crop so the context never exceeds the RoPE tables / causal mask.
                step_input = input_indices[:, -max_len:]
            logits, _ = self.forward(step_input, None, use_cache=caching)
            next_token = self._sample_next_token(
                logits[:, -1, :], temperature=temperature, top_k=top_k, top_p=top_p
            )
            yield int(next_token.item())
            if not caching:
                # The growing context is only needed when it is re-encoded each step.
                input_indices = torch.cat((input_indices, next_token), dim=1)