bounty
final working rotary fix and removed image cache waste
505474b
import torch
import torch.nn as nn
from typing import TYPE_CHECKING
from torch.nn import functional as F
from .layers import layer_norm, mlp
from .config import TextConfig
# type checking imports if typechecking
if TYPE_CHECKING:
from .rope import RotaryEmbedding
def text_encoder(input_ids: torch.Tensor, w: nn.Module):
    """Look up the embedding vector of every token id in ``input_ids``.

    The lookup table is ``w.wte``; in ``build_text_model`` this is allocated
    as a (vocab_size, dim) parameter, so the result presumably has shape
    ``input_ids.shape + (dim,)``.
    """
    embedding_table = w.wte
    return F.embedding(input_ids, embedding_table)
def attn(
    x: torch.Tensor,
    w: nn.Module,
    attn_mask: torch.Tensor,
    n_heads: int,
    rope: "RotaryEmbedding",
    kv_cache: "nn.Module | None",
    pos_ids: torch.Tensor,
):
    """Multi-head self-attention with rotary embeddings and an optional KV cache.

    Args:
        x: input hidden states of shape (bsz, q_len, d_model).
        w: module exposing ``qkv`` (fused projection whose last dimension is
            reshaped as 3 x n_heads x head_dim, in Q, K, V order) and ``proj``
            (output projection).
        attn_mask: forwarded to ``F.scaled_dot_product_attention``; it must
            broadcast against (bsz, n_heads, q_len, kv_len), where kv_len is
            the key length after the optional cache update.
        n_heads: number of attention heads; ``d_model`` must be divisible by
            it for the ``view`` below to succeed.
        rope: object whose ``apply(t, pos_ids)`` is applied to Q and K.
        kv_cache: object whose ``update(pos_ids, k, v)`` returns the keys and
            values to attend over. ``None`` disables caching: attention is
            computed over the current call's keys/values only.
        pos_ids: positions of the ``q_len`` tokens, forwarded to ``rope`` and
            ``kv_cache``.

    Returns:
        ``w.proj`` applied to the (bsz, q_len, d_model) attention output.
    """
    bsz, q_len, d_model = x.shape
    head_dim = d_model // n_heads

    # 1. Fused projection, split into (Q/K/V, head, head_dim) on the last axis.
    qkv = w.qkv(x).view(bsz, q_len, 3, n_heads, head_dim)
    # 2. Bring the Q/K/V axis to the front and heads before sequence:
    #    (bsz, q_len, 3, n_heads, head_dim) -> (3, bsz, n_heads, q_len, head_dim)
    # 3. Split along that leading axis (views, no copy).
    q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

    q = rope.apply(q, pos_ids)
    k = rope.apply(k, pos_ids)

    # The cache is optional; without it we attend only over this call's K/V.
    if kv_cache is not None:
        k, v = kv_cache.update(pos_ids, k, v)

    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    # (bsz, n_heads, q_len, head_dim) -> (bsz, q_len, n_heads * head_dim)
    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
    return w.proj(out)
def text_decoder(
    x: torch.Tensor,
    w: nn.Module,
    attn_mask: torch.Tensor,
    config: TextConfig,
    rope: "RotaryEmbedding",
    pos_ids: torch.Tensor,
):
    """Run hidden states through every transformer block in ``w.blocks``.

    Each block applies a single layer norm and feeds the normalized input to
    the attention and MLP branches in parallel; both branch outputs are added
    to the residual stream ``x``.

    Args:
        x: hidden states, unpacked by ``attn`` as (bsz, q_len, d_model).
        w: module whose ``blocks`` each expose ``ln``, ``attn``, ``mlp`` and
            ``kv_cache``. NOTE(review): the blocks built by
            ``build_text_model`` have no ``kv_cache`` entry; presumably it is
            attached elsewhere — confirm.
        attn_mask: forwarded unchanged to ``attn`` for every block.
        config: text configuration; only ``n_heads`` is read here.
        rope: rotary embedding forwarded to ``attn``.
        pos_ids: token positions forwarded to ``attn``.

    Returns:
        ``x`` with every block's attention and MLP outputs added.
    """
    n_heads = config.n_heads  # loop-invariant
    for block in w.blocks:
        normed = layer_norm(x, block.ln)
        # Attention runs before the MLP so the block's KV cache is updated
        # in the same order as before.
        attn_out = attn(
            normed,
            block.attn,
            attn_mask=attn_mask,
            n_heads=n_heads,
            rope=rope,
            kv_cache=block.kv_cache,
            pos_ids=pos_ids,
        )
        mlp_out = mlp(normed, block.mlp)
        # Parallel residual: both branches read the same normed input.
        x = x + attn_out + mlp_out
    return x
def lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
    """Compute logits from the hidden state at the final position.

    Takes index -1 along dimension 1 of ``hidden_BTC`` (the ``_BTC`` suffix
    presumably means batch, time, channel), normalizes it with
    ``w.post_ln`` via ``layer_norm``, and projects it with ``w.lm_head``.
    In ``build_text_model`` that projection is ``Linear(dim, vocab_size)``.
    """
    last_position = hidden_BTC[:, -1, :]
    return w.lm_head(layer_norm(last_position, w.post_ln))
def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
    """Allocate the module tree for the text model.

    Layout:
        blocks[i].ln          LayerNorm(dim)
        blocks[i].attn.qkv    Linear(dim, 3 * dim)  fused Q/K/V projection
        blocks[i].attn.proj   Linear(dim, dim)
        blocks[i].mlp.fc1     Linear(dim, ff_dim)
        blocks[i].mlp.fc2     Linear(ff_dim, dim)
        post_ln               LayerNorm(dim)
        lm_head               Linear(dim, vocab_size)
        wte                   Parameter(vocab_size, dim)

    Every module is created with ``dtype``. ``wte`` comes from
    ``torch.empty`` and so holds uninitialized memory; presumably it is
    overwritten when weights are loaded (confirm).
    """
    dim = config.dim
    ff_dim = config.ff_dim
    fused_qkv_width = int(dim * 3)

    def linear(n_in, n_out):
        return nn.Linear(n_in, n_out, dtype=dtype)

    def new_block() -> nn.ModuleDict:
        # Submodules are created in the same order as the original layout so
        # that default random initialization consumes the RNG identically.
        layer = nn.ModuleDict()
        layer["ln"] = nn.LayerNorm(dim, dtype=dtype)
        layer["attn"] = nn.ModuleDict(
            {"qkv": linear(dim, fused_qkv_width), "proj": linear(dim, dim)}
        )
        layer["mlp"] = nn.ModuleDict(
            {"fc1": linear(dim, ff_dim), "fc2": linear(ff_dim, dim)}
        )
        return layer

    text = nn.ModuleDict()
    text["blocks"] = nn.ModuleList([new_block() for _ in range(config.n_layers)])
    text["post_ln"] = nn.LayerNorm(dim, dtype=dtype)
    text["lm_head"] = linear(dim, config.vocab_size)
    # Assigning an nn.Parameter attribute registers it on the ModuleDict.
    text.wte = nn.Parameter(torch.empty(config.vocab_size, dim, dtype=dtype))
    return text