# Copyright 2026 Jakub Sykała
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

# Feature indices into the packed [B, T, 9] feature tensor.
F_SYLLABLE = 0
F_ONSET = 1
F_NUCLEUS = 2
F_CODA = 3
F_POSITION = 4
F_CAPITALIZED = 5
F_TOKEN_TYPE = 6
F_SPACE_AFTER = 7
F_WORD_END = 8
N_FEATURES = 9


@dataclass
class LunaConfig:
    """Configuration for Luna."""
    # Vocabulary sizes
    syllable_vocab: int = 32768
    onset_vocab: int = 2048
    nucleus_vocab: int = 512
    coda_vocab: int = 2048
    # Fixed vocab sizes
    position_vocab: int = 4
    capitalized_vocab: int = 2
    token_type_vocab: int = 4
    space_vocab: int = 2
    word_end_vocab: int = 2
    # Embedding dimensions
    syllable_dim: int = 256
    onset_dim: int = 64
    nucleus_dim: int = 64
    coda_dim: int = 64
    position_dim: int = 32
    cap_dim: int = 16
    type_dim: int = 16
    space_dim: int = 32
    word_end_dim: int = 16
    # Transformer
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.1
    max_seq_len: int = 1024
    # Optimization flags
    fuse_output_heads: bool = True
    use_flash_attention: bool = True


# ============================================================================
# Components
# ============================================================================

class RMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction, no bias)."""

    __constants__ = ['eps']

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x / rms(x) * weight, fused into a single expression.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


class RotaryEmbedding(nn.Module):
    """RoPE with pre-computed cos/sin cache, grown lazily if needed."""

    def __init__(self, dim: int, max_seq_len: int = 2048):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        # Re-registering overwrites the previous (shorter) cache; buffers are
        # non-persistent so the cache never pollutes state_dict.
        device = self.inv_freq.device
        t = torch.arange(seq_len, device=device)
        freqs = torch.outer(t, self.inv_freq)
        self.register_buffer("cos_cached", freqs.cos(), persistent=False)
        self.register_buffer("sin_cached", freqs.sin(), persistent=False)

    def forward(self, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin), each [seq_len, dim/2], extending the cache on demand."""
        if seq_len > self.cos_cached.shape[0]:
            self._build_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]


@torch.jit.script
def apply_rotary_emb_fused(q: torch.Tensor, k: torch.Tensor,
                           cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """JIT-compiled rotary embedding application.

    Uses the de-interleaved (even/odd split then concat) RoPE layout; q and k
    are rotated identically, so relative-position attention is preserved.
    """
    cos = cos.unsqueeze(0).unsqueeze(0)  # [1, 1, T, dim/2] for broadcasting
    sin = sin.unsqueeze(0).unsqueeze(0)
    q_even, q_odd = q[..., 0::2], q[..., 1::2]
    k_even, k_odd = k[..., 0::2], k[..., 1::2]
    q_rot = torch.cat([q_even * cos - q_odd * sin,
                       q_even * sin + q_odd * cos], dim=-1)
    k_rot = torch.cat([k_even * cos - k_odd * sin,
                       k_even * sin + k_odd * cos], dim=-1)
    return q_rot, k_rot


class Attention(nn.Module):
    """Causal multi-head self-attention with fused QKV and SDPA."""

    def __init__(self, config: LunaConfig):
        super().__init__()
        self.n_head = config.n_head
        self.head_dim = config.n_embd // config.n_head
        self.dropout_p = config.dropout
        # Fused QKV projection (single matmul instead of 3)
        self.wqkv = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.wo = nn.Linear(config.n_embd, config.n_embd, bias=False)

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        # Fused QKV: single matmul, then split into the three projections.
        qkv = self.wqkv(x)
        q, k, v = qkv.split(C, dim=-1)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        q, k = apply_rotary_emb_fused(q, k, cos, sin)
        # Flash/efficient attention via PyTorch SDPA; dropout only in training.
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=self.dropout_p if self.training else 0.0,
            is_causal=True,
        )
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.wo(out)


class FeedForward(nn.Module):
    """SwiGLU with fused gate computation."""

    def __init__(self, config: LunaConfig):
        super().__init__()
        hidden = int(4 * config.n_embd)
        # Fuse w1 (gate) and w3 (value) into a single matmul.
        self.w13 = nn.Linear(config.n_embd, 2 * hidden, bias=False)
        self.w2 = nn.Linear(hidden, config.n_embd, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Single matmul for both gate and value, then SwiGLU: silu(x1) * x3.
        x13 = self.w13(x)
        x1, x3 = x13.chunk(2, dim=-1)
        return self.dropout(self.w2(F.silu(x1) * x3))


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: x + attn(norm(x)), x + ffn(norm(x))."""

    def __init__(self, config: LunaConfig):
        super().__init__()
        self.norm1 = RMSNorm(config.n_embd)
        self.attn = Attention(config)
        self.norm2 = RMSNorm(config.n_embd)
        self.ffn = FeedForward(config)

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.norm1(x), cos, sin)
        x = x + self.ffn(self.norm2(x))
        return x


# ============================================================================
# Dual Stream Fusion
# ============================================================================

class OptimizedDualStreamFusion(nn.Module):
    """Embeds the 9 packed features into a single n_embd vector per token.

    A semantic stream (syllable id) and a phonetic stream (onset/nucleus/coda)
    are mixed with a learned sigmoid gate; auxiliary features are concatenated
    and the whole thing is projected to n_embd.
    """

    def __init__(self, config: LunaConfig):
        super().__init__()
        self.config = config
        # Semantic stream
        self.syllable_embed = nn.Embedding(config.syllable_vocab, config.syllable_dim)
        # Phonetic stream - combined embedding then project
        self.onset_embed = nn.Embedding(config.onset_vocab, config.onset_dim)
        self.nucleus_embed = nn.Embedding(config.nucleus_vocab, config.nucleus_dim)
        self.coda_embed = nn.Embedding(config.coda_vocab, config.coda_dim)
        phonetic_dim = config.onset_dim + config.nucleus_dim + config.coda_dim
        self.phonetic_proj = nn.Linear(phonetic_dim, config.syllable_dim, bias=False)
        # Scalar gate alpha in (0, 1) blending semantic vs phonetic per token.
        self.gate = nn.Sequential(
            nn.Linear(config.syllable_dim * 2, config.syllable_dim // 2, bias=False),
            nn.SiLU(),  # SiLU is fused in CUDA
            nn.Linear(config.syllable_dim // 2, 1, bias=False),
            nn.Sigmoid(),
        )
        # Auxiliary embeddings (avoid reserved names like 'type')
        self.aux_embeddings = nn.ModuleDict({
            'position': nn.Embedding(config.position_vocab, config.position_dim),
            'cap': nn.Embedding(config.capitalized_vocab, config.cap_dim),
            'tok_type': nn.Embedding(config.token_type_vocab, config.type_dim),  # renamed from 'type'
            'space': nn.Embedding(config.space_vocab, config.space_dim),
            'word_end': nn.Embedding(config.word_end_vocab, config.word_end_dim),
        })
        self.aux_dim = (config.position_dim + config.cap_dim + config.type_dim
                        + config.space_dim + config.word_end_dim)
        # Final projection
        total_dim = config.syllable_dim + self.aux_dim
        self.output_proj = nn.Linear(total_dim, config.n_embd, bias=False)
        self.output_norm = RMSNorm(config.n_embd)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            features: [B, T, 9] stacked feature tensor (long)
        Returns:
            [B, T, n_embd] embedded representation
        """
        # Extract features (compile-friendly static indexing)
        syl_ids = features[:, :, F_SYLLABLE]
        onset_ids = features[:, :, F_ONSET]
        nucleus_ids = features[:, :, F_NUCLEUS]
        coda_ids = features[:, :, F_CODA]
        pos_ids = features[:, :, F_POSITION]
        cap_ids = features[:, :, F_CAPITALIZED]
        type_ids = features[:, :, F_TOKEN_TYPE]
        space_ids = features[:, :, F_SPACE_AFTER]
        word_end_ids = features[:, :, F_WORD_END]

        # Semantic stream
        semantic = self.syllable_embed(syl_ids)

        # Phonetic stream - batch the lookups, project to syllable_dim
        onset = self.onset_embed(onset_ids)
        nucleus = self.nucleus_embed(nucleus_ids)
        coda = self.coda_embed(coda_ids)
        phonetic = self.phonetic_proj(torch.cat([onset, nucleus, coda], dim=-1))

        # Gated fusion: alpha weights semantic, (1 - alpha) weights phonetic.
        gate_in = torch.cat([semantic, phonetic], dim=-1)
        alpha = self.gate(gate_in)
        fused = alpha * semantic + (1 - alpha) * phonetic

        # Auxiliary features - batch all lookups
        aux = torch.cat([
            self.aux_embeddings['position'](pos_ids),
            self.aux_embeddings['cap'](cap_ids),
            self.aux_embeddings['tok_type'](type_ids),  # renamed from 'type'
            self.aux_embeddings['space'](space_ids),
            self.aux_embeddings['word_end'](word_end_ids),
        ], dim=-1)

        # Final output
        combined = torch.cat([fused, aux], dim=-1)
        return self.output_norm(self.output_proj(combined))


# ============================================================================
# Output Heads
# ============================================================================

class FusedOutputHeads(nn.Module):
    """All 8 prediction heads as one matmul, split afterwards.

    NOTE: is_word_end (F_WORD_END) is an input feature only; it has no head.
    """

    def __init__(self, config: LunaConfig):
        super().__init__()
        # Head output sizes (insertion order defines split order).
        self.head_sizes = {
            'syllable': config.syllable_vocab,
            'onset': config.onset_vocab,
            'nucleus': config.nucleus_vocab,
            'coda': config.coda_vocab,
            'position': config.position_vocab,
            'is_capitalized': config.capitalized_vocab,
            'token_type': config.token_type_vocab,
            'has_space_after': config.space_vocab,
        }
        self.head_names = list(self.head_sizes.keys())
        self.total_output = sum(self.head_sizes.values())
        # Single fused projection
        self.fused_head = nn.Linear(config.n_embd, self.total_output, bias=False)
        # Pre-compute split sizes
        self.split_sizes = [self.head_sizes[name] for name in self.head_names]
        # NOTE(review): this buffer is never read; kept (persistent) only so
        # existing checkpoints that contain the key still load strictly.
        self.register_buffer('_split_sizes_tensor', torch.tensor(self.split_sizes))

    def forward(self, h: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            h: [B, T, n_embd]
        Returns:
            Dict of logits for each head, each [B, T, vocab_size]
        """
        # Single matmul, then split along the vocab dimension.
        all_logits = self.fused_head(h)
        splits = all_logits.split(self.split_sizes, dim=-1)
        return {name: logit for name, logit in zip(self.head_names, splits)}


# ============================================================================
# Loss computation
# ============================================================================

class OptimizedMultiTaskLoss(nn.Module):
    """Vectorized multi-task loss computation.

    Weighted sum of per-head cross-entropies, normalized by the weight sum.
    The syllable head additionally gets per-token weights derived from the
    token's position-in-word and token-type targets.
    """

    # Fixed head order; must match FusedOutputHeads and target_indices.
    HEAD_NAMES = ('syllable', 'onset', 'nucleus', 'coda',
                  'position', 'is_capitalized', 'token_type', 'has_space_after')

    def __init__(self, config: LunaConfig):
        super().__init__()
        # Loss weights as buffer (one per head, in HEAD_NAMES order).
        self.register_buffer('loss_weights', torch.tensor([
            1.0,   # syllable
            0.2,   # onset
            0.2,   # nucleus
            0.2,   # coda
            0.3,   # position
            0.1,   # is_capitalized
            0.15,  # token_type
            0.4,   # has_space_after
        ]))
        self.weight_sum = self.loss_weights.sum().item()
        # Position and type weights for syllable loss (indexed by target value).
        self.register_buffer('position_weights', torch.tensor([0.8, 1.0, 1.5, 1.2]))
        self.register_buffer('type_weights', torch.tensor([1.0, 1.2, 2.5, 1.0]))
        # Feature indices for targets, aligned with HEAD_NAMES.
        self.target_indices = [F_SYLLABLE, F_ONSET, F_NUCLEUS, F_CODA,
                               F_POSITION, F_CAPITALIZED, F_TOKEN_TYPE, F_SPACE_AFTER]

    def forward(self, logits: Dict[str, torch.Tensor], targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            logits: Dict of [B, T, V] tensors
            targets: [B, T, 9] target tensor (long)
        Returns:
            Scalar combined loss.
        """
        total_loss = 0.0
        # Get position/type targets for syllable weighting
        pos_targets = targets[:, :, F_POSITION]
        type_targets = targets[:, :, F_TOKEN_TYPE]

        for i, name in enumerate(self.HEAD_NAMES):
            logit = logits[name]
            target = targets[:, :, self.target_indices[i]]
            weight = self.loss_weights[i]

            if name == 'syllable':
                # Per-token CE, reweighted by position/type, then averaged.
                B, T, V = logit.shape
                per_token = F.cross_entropy(
                    logit.reshape(-1, V),
                    target.reshape(-1),
                    reduction='none',
                ).view(B, T)
                pos_w = self.position_weights[pos_targets]
                type_w = self.type_weights[type_targets]
                head_loss = (per_token * pos_w * type_w).mean()
            else:
                # reshape (not view): the target slice is strided.
                head_loss = F.cross_entropy(
                    logit.reshape(-1, logit.size(-1)),
                    target.reshape(-1),
                )
            total_loss = total_loss + weight * head_loss

        return total_loss / self.weight_sum


# ============================================================================
# Main Model
# ============================================================================

class Luna(nn.Module):
    """Dual-stream syllable transformer with multi-task output heads."""

    def __init__(self, config: LunaConfig):
        super().__init__()
        self.config = config

        # Embedding
        self.embedding = OptimizedDualStreamFusion(config)

        # Transformer
        self.rotary = RotaryEmbedding(config.n_embd // config.n_head, config.max_seq_len)
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layer)])
        self.norm = RMSNorm(config.n_embd)

        # Output (fused or separate based on config)
        if config.fuse_output_heads:
            self.heads = FusedOutputHeads(config)
        else:
            self.heads = nn.ModuleDict({
                'syllable': nn.Linear(config.n_embd, config.syllable_vocab, bias=False),
                'onset': nn.Linear(config.n_embd, config.onset_vocab, bias=False),
                'nucleus': nn.Linear(config.n_embd, config.nucleus_vocab, bias=False),
                'coda': nn.Linear(config.n_embd, config.coda_vocab, bias=False),
                'position': nn.Linear(config.n_embd, config.position_vocab, bias=False),
                'is_capitalized': nn.Linear(config.n_embd, config.capitalized_vocab, bias=False),
                'token_type': nn.Linear(config.n_embd, config.token_type_vocab, bias=False),
                'has_space_after': nn.Linear(config.n_embd, config.space_vocab, bias=False),
            })

        self.dropout = nn.Dropout(config.dropout)
        self.loss_fn = OptimizedMultiTaskLoss(config)
        self.apply(self._init_weights)
        self._print_info()

    def _init_weights(self, module):
        # GPT-style init: N(0, 0.02) for linear/embedding weights, zero bias.
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _print_info(self):
        # Parameter-count summary printed once at construction time.
        n_params = sum(p.numel() for p in self.parameters())
        embed_params = sum(p.numel() for p in self.embedding.parameters())
        if isinstance(self.heads, FusedOutputHeads):
            head_params = self.heads.fused_head.weight.numel()
        else:
            head_params = sum(p.numel() for p in self.heads.parameters())
        print(f"\n{'='*60}")
        print("Luna Summary")
        print(f"{'='*60}")
        print(f"Total parameters: {n_params:,}")
        print(f"Embedding parameters: {embed_params:,}")
        print(f"Output head parameters: {head_params:,}")
        print(f"Transformer backbone: {n_params - embed_params - head_params:,}")
        print("\nOptimizations enabled:")
        print("  - Fused QKV projection")
        print("  - Fused FFN gate")
        print(f"  - Fused output heads: {self.config.fuse_output_heads}")
        print("  - JIT rotary embeddings")
        print("  - RMSNorm everywhere")
        print("  - Vectorized loss")
        print(f"{'='*60}\n")

    def forward(
        self,
        features: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[torch.Tensor]]:
        """
        Args:
            features: [B, T, 9] input features
            targets: [B, T, 9] targets (optional)
        Returns:
            logits: Dict of output logits
            loss: Combined loss (if targets provided, else None)
        """
        B, T, _ = features.shape

        # Embedding
        h = self.embedding(features)
        h = self.dropout(h)

        # Transformer
        cos, sin = self.rotary(T)
        for layer in self.layers:
            h = layer(h, cos, sin)
        h = self.norm(h)

        # Output heads
        if isinstance(self.heads, FusedOutputHeads):
            logits = self.heads(h)
        else:
            logits = {name: head(h) for name, head in self.heads.items()}

        # Loss
        loss = None
        if targets is not None:
            loss = self.loss_fn(logits, targets)
        return logits, loss


# ============================================================================
# Helper for Migration
# ============================================================================

def dict_to_tensor(features_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Convert dict features to the stacked [B, T, 9] tensor (F_* order)."""
    return torch.stack([
        features_dict['syllable_id'],
        features_dict['onset_id'],
        features_dict['nucleus_id'],
        features_dict['coda_id'],
        features_dict['position'],
        features_dict['is_capitalized'],
        features_dict['token_type'],
        features_dict['has_space_after'],
        features_dict['is_word_end'],
    ], dim=-1)


# ============================================================================
# Lil Test
# ============================================================================

if __name__ == "__main__":
    import time

    print("Luna - Speed Test")
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    def _sync():
        # FIX: only synchronize when CUDA is actually in use; the previous
        # unconditional torch.cuda.synchronize() crashed on CPU-only hosts.
        if use_cuda:
            torch.cuda.synchronize()

    def _autocast():
        # FIX: torch.cuda.amp.autocast(dtype=...) is deprecated and CPU-blind;
        # torch.amp.autocast works for both device types.
        return torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16)

    config = LunaConfig(
        syllable_vocab=32768,
        onset_vocab=2048,
        nucleus_vocab=512,
        coda_vocab=2048,
        max_seq_len=1024,
        fuse_output_heads=True,
    )
    model = Luna(config).to(device)

    # Test forward pass
    B, T = 8, 1024
    features = torch.stack([
        torch.randint(0, 1000, (B, T)),
        torch.randint(0, 100, (B, T)),
        torch.randint(0, 50, (B, T)),
        torch.randint(0, 100, (B, T)),
        torch.randint(0, 4, (B, T)),
        torch.randint(0, 2, (B, T)),
        torch.randint(0, 4, (B, T)),
        torch.randint(0, 2, (B, T)),
        torch.randint(0, 2, (B, T)),
    ], dim=-1).to(device)
    targets = features.clone()

    # Warmup
    for _ in range(3):
        with _autocast():
            logits, loss = model(features, targets)
    _sync()

    # Benchmark
    n_iters = 50
    start = time.time()
    for _ in range(n_iters):
        with _autocast():
            logits, loss = model(features, targets)
        # FIX: zero grads each step so backward measures a realistic step
        # instead of accumulating into ever-present .grad tensors.
        model.zero_grad(set_to_none=True)
        loss.backward()
    _sync()
    elapsed = time.time() - start

    tokens_per_iter = B * T
    tok_per_sec = (n_iters * tokens_per_iter) / elapsed
    print(f"\nBenchmark Results:")
    print(f"  Batch: {B} x {T} = {B*T:,} tokens")
    print(f"  Iterations: {n_iters}")
    print(f"  Time: {elapsed:.2f}s")
    print(f"  Throughput: {tok_per_sec:,.0f} tok/s")
    print(f"  Loss: {loss.item():.4f}")

    # Test torch.compile
    print("\nTesting torch.compile()...")
    compiled_model = torch.compile(model, mode="reduce-overhead")

    # Warmup compiled
    for _ in range(5):
        with _autocast():
            logits, loss = compiled_model(features, targets)
    _sync()

    # Benchmark compiled
    start = time.time()
    for _ in range(n_iters):
        with _autocast():
            logits, loss = compiled_model(features, targets)
        model.zero_grad(set_to_none=True)
        loss.backward()
    _sync()
    elapsed_compiled = time.time() - start

    tok_per_sec_compiled = (n_iters * tokens_per_iter) / elapsed_compiled
    print(f"\nCompiled Results:")
    print(f"  Throughput: {tok_per_sec_compiled:,.0f} tok/s")
    print(f"  Speedup: {tok_per_sec_compiled/tok_per_sec:.2f}x")

    print(f"\n✓ All tests passed!")