import torch
import torch.nn as nn

from typing import TYPE_CHECKING
from torch.nn import functional as F

from .layers import layer_norm, mlp
from .config import TextConfig

# Imported only for static type checking; at runtime the "RotaryEmbedding"
# annotations below stay plain strings, so .rope is not imported here.
if TYPE_CHECKING:
    from .rope import RotaryEmbedding


def text_encoder(input_ids: torch.Tensor, w: nn.Module) -> torch.Tensor:
    """Look up token embeddings for ``input_ids`` in the ``w.wte`` table.

    Args:
        input_ids: Integer token ids of any shape.
        w: Text model whose ``wte`` parameter is the (vocab_size, dim)
            embedding matrix allocated by ``build_text_model``.

    Returns:
        Embeddings of shape ``input_ids.shape + (dim,)``.
    """
    return F.embedding(input_ids, w.wte)


def attn(
    x: torch.Tensor,
    w: nn.Module,
    attn_mask: torch.Tensor,
    n_heads: int,
    rope: "RotaryEmbedding",
    kv_cache: nn.Module,
    pos_ids: torch.Tensor,
) -> torch.Tensor:
    """Multi-head self-attention with rotary embeddings and a KV cache.

    Args:
        x: Input activations of shape (bsz, q_len, d_model).
        w: Attention weights; ``w.qkv`` is the fused projection producing
            3 * d_model features and ``w.proj`` is the output projection.
        attn_mask: Passed unchanged to ``F.scaled_dot_product_attention``;
            it must be broadcastable to (bsz, n_heads, q_len, kv_len), where
            kv_len is the key length returned by ``kv_cache.update``.
        n_heads: Number of attention heads; must divide d_model.
        rope: Rotary embedding applied to queries and keys via ``rope.apply``.
        kv_cache: Object whose ``update(pos_ids, k, v)`` returns the keys and
            values to attend over (presumably the cached prefix plus the new
            entries -- TODO confirm against the cache implementation).
        pos_ids: Positions of the current tokens, passed to both ``rope`` and
            ``kv_cache``.

    Returns:
        Tensor of shape (bsz, q_len, d_model).

    Raises:
        ValueError: If d_model is not divisible by n_heads.
    """
    bsz, q_len, d_model = x.shape
    # Without this check a non-divisible configuration fails later inside
    # view()/reshape() with a much less helpful message.
    if d_model % n_heads != 0:
        raise ValueError(
            f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        )
    head_dim = d_model // n_heads

    # 1. Fused projection -> (bsz, q_len, 3 * d_model). The view below fixes
    #    the layout as [q | k | v], each third split into n_heads chunks of
    #    head_dim features.
    qkv = w.qkv(x).view(bsz, q_len, 3, n_heads, head_dim)

    # 2. (bsz, q_len, 3, n_heads, head_dim) -> (3, bsz, n_heads, q_len, head_dim)
    #    so heads act as a batch dimension ahead of the sequence axis, which
    #    is the layout scaled_dot_product_attention expects.
    # 3. Split the leading axis into q, k, v, each (bsz, n_heads, q_len, head_dim).
    q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

    q = rope.apply(q, pos_ids)
    k = rope.apply(k, pos_ids)

    # Hand the new keys/values to the cache; attention uses whatever it returns.
    k, v = kv_cache.update(pos_ids, k, v)

    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    # (bsz, n_heads, q_len, head_dim) -> (bsz, q_len, d_model)
    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
    return w.proj(out)


def text_decoder(
    x: torch.Tensor,
    w: nn.Module,
    attn_mask: torch.Tensor,
    config: TextConfig,
    rope: "RotaryEmbedding",
    pos_ids: torch.Tensor,
) -> torch.Tensor:
    """Run every transformer block in ``w.blocks`` over ``x``.

    Each block uses a parallel residual: attention and MLP both read the same
    layer-normed input, and both outputs are added to the residual stream at
    once, i.e. ``x = x + attn(ln(x)) + mlp(ln(x))``.

    Args:
        x: Hidden states, presumably (bsz, seq_len, dim) -- ``attn`` requires
            a rank-3 tensor.
        w: Text model providing ``blocks``. Each block must also expose a
            ``kv_cache`` attribute, which ``build_text_model`` does not
            create; it is presumably attached elsewhere before decoding
            (TODO confirm).
        attn_mask: Mask forwarded to every attention layer.
        config: Text configuration; only ``n_heads`` is read here.
        rope: Rotary embedding shared by all layers.
        pos_ids: Token positions forwarded to every attention layer.

    Returns:
        Hidden states with the same shape as ``x``.
    """
    for block in w.blocks:
        l_in = layer_norm(x, block.ln)
        l_attn = attn(
            l_in,
            block.attn,
            attn_mask=attn_mask,
            n_heads=config.n_heads,
            rope=rope,
            kv_cache=block.kv_cache,
            pos_ids=pos_ids,
        )
        l_mlp = mlp(l_in, block.mlp)
        x = x + l_attn + l_mlp
    return x


def lm_head(hidden_BTC: torch.Tensor, w: nn.Module) -> torch.Tensor:
    """Project the hidden state at the last position to vocabulary logits.

    Args:
        hidden_BTC: Hidden states of shape (batch, seq_len, dim).
        w: Text model providing ``post_ln`` and ``lm_head``.

    Returns:
        Logits of shape (batch, vocab_size) for the final position only.
    """
    hidden_BC = hidden_BTC[:, -1, :]
    hidden_BC = layer_norm(hidden_BC, w.post_ln)
    return w.lm_head(hidden_BC)


def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
    """Allocate the parameter containers for the text model.

    The returned ``nn.ModuleDict`` has:
        * ``blocks``: ``config.n_layers`` blocks, each holding ``ln``,
          ``attn`` (``qkv``: dim -> 3 * dim, ``proj``: dim -> dim) and
          ``mlp`` (``fc1``: dim -> ff_dim, ``fc2``: ff_dim -> dim);
        * ``post_ln``: final LayerNorm applied before ``lm_head``;
        * ``lm_head``: dim -> vocab_size projection;
        * ``wte``: a registered (vocab_size, dim) embedding parameter.

    ``wte`` is allocated with ``torch.empty``, so its contents are
    uninitialized memory. It is presumably overwritten by loading a
    checkpoint (TODO confirm). Per-block ``kv_cache`` modules are not created
    here.

    Args:
        config: Text configuration (dim, ff_dim, n_layers, vocab_size).
        dtype: dtype for every parameter.

    Returns:
        The text model as an ``nn.ModuleDict``.
    """
    qkv_dim = int(config.dim * 3)

    text = nn.ModuleDict(
        {
            "blocks": nn.ModuleList(
                [
                    nn.ModuleDict(
                        {
                            "ln": nn.LayerNorm(config.dim, dtype=dtype),
                            "attn": nn.ModuleDict(
                                {
                                    "qkv": nn.Linear(
                                        config.dim, qkv_dim, dtype=dtype
                                    ),
                                    "proj": nn.Linear(
                                        config.dim, config.dim, dtype=dtype
                                    ),
                                }
                            ),
                            "mlp": nn.ModuleDict(
                                {
                                    "fc1": nn.Linear(
                                        config.dim, config.ff_dim, dtype=dtype
                                    ),
                                    "fc2": nn.Linear(
                                        config.ff_dim, config.dim, dtype=dtype
                                    ),
                                }
                            ),
                        }
                    )
                    for _ in range(config.n_layers)
                ]
            ),
            "post_ln": nn.LayerNorm(config.dim, dtype=dtype),
            "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=dtype),
        }
    )
    # Assigning an nn.Parameter as a Module attribute registers it, so wte
    # appears in parameters() / state_dict() alongside the submodules.
    text.wte = nn.Parameter(
        torch.empty(config.vocab_size, config.dim, dtype=dtype)
    )
    return text