# Luna-150M / model.py
# (Hugging Face listing metadata: JMSykala — "Upload 9 files", commit 9c737ff verified)
# Copyright 2026 Jakub Sykała
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
# Feature indices
# Column layout of the packed [B, T, 9] feature/target tensors used
# throughout the model: feature i lives at tensor[..., F_<NAME>].
F_SYLLABLE = 0     # syllable identity (semantic stream)
F_ONSET = 1        # syllable onset id (phonetic stream)
F_NUCLEUS = 2      # syllable nucleus id (phonetic stream)
F_CODA = 3         # syllable coda id (phonetic stream)
F_POSITION = 4     # position class (vocab of 4; also weights the syllable loss)
F_CAPITALIZED = 5  # capitalization flag (0/1)
F_TOKEN_TYPE = 6   # token-type class (vocab of 4; also weights the syllable loss)
F_SPACE_AFTER = 7  # whether a space follows the token (0/1)
F_WORD_END = 8     # whether the syllable ends a word (0/1; input-only, no head)
N_FEATURES = 9     # total number of feature columns
@dataclass
class LunaConfig:
    """Configuration for Luna.

    Groups vocabulary sizes, per-feature embedding widths, transformer
    hyper-parameters and optimization switches used across the model.
    """
    # Vocabulary sizes (large, learned feature vocabularies)
    syllable_vocab: int = 32768
    onset_vocab: int = 2048
    nucleus_vocab: int = 512
    coda_vocab: int = 2048
    # Fixed vocab sizes (small categorical features)
    position_vocab: int = 4
    capitalized_vocab: int = 2
    token_type_vocab: int = 4
    space_vocab: int = 2
    word_end_vocab: int = 2
    # Embedding dimensions (one per feature stream)
    syllable_dim: int = 256
    onset_dim: int = 64
    nucleus_dim: int = 64
    coda_dim: int = 64
    position_dim: int = 32
    cap_dim: int = 16
    type_dim: int = 16
    space_dim: int = 32
    word_end_dim: int = 16
    # Transformer
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768  # must be divisible by n_head
    dropout: float = 0.1
    max_seq_len: int = 1024  # initial RoPE cache length (cache grows on demand)
    # Optimization flags
    fuse_output_heads: bool = True
    # NOTE(review): use_flash_attention is never read in this file —
    # F.scaled_dot_product_attention is used unconditionally.
    use_flash_attention: bool = True
#-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=--=
#
# Components
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction, no bias).

    Normalizes the last dimension by its RMS, then applies a learned
    per-channel scale initialized to ones.
    """
    __constants__ = ['eps']

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps  # numerical floor inside the rsqrt
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # rms(x)^2 = mean(x^2) over the last dim; single fused expression.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight
class RotaryEmbedding(nn.Module):
    """RoPE frequency table with a lazily-grown cos/sin cache."""

    def __init__(self, dim: int, max_seq_len: int = 2048):
        super().__init__()
        # Standard RoPE inverse frequencies: one per even channel index.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        # Angle table is the outer product of positions and inv frequencies.
        # Re-registering an existing buffer name simply overwrites it.
        positions = torch.arange(seq_len, device=self.inv_freq.device)
        angles = torch.outer(positions, self.inv_freq)
        self.register_buffer("cos_cached", angles.cos(), persistent=False)
        self.register_buffer("sin_cached", angles.sin(), persistent=False)

    def forward(self, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Grow the cache on demand when a longer sequence shows up.
        if seq_len > self.cos_cached.shape[0]:
            self._build_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
@torch.jit.script
def apply_rotary_emb_fused(q: torch.Tensor, k: torch.Tensor,
                           cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate q and k by the cached cos/sin tables (TorchScript-compiled).

    cos/sin are [T, head_dim//2] and broadcast over batch and head dims.
    Channels are paired even/odd and the two rotated halves are written
    back concatenated (split layout, not interleaved); the layout is the
    same for q and k, which is all the relative-position dot product needs.
    """
    cos_b = cos.unsqueeze(0).unsqueeze(0)
    sin_b = sin.unsqueeze(0).unsqueeze(0)
    qe = q[..., 0::2]
    qo = q[..., 1::2]
    ke = k[..., 0::2]
    ko = k[..., 1::2]
    q_out = torch.cat([qe * cos_b - qo * sin_b, qe * sin_b + qo * cos_b], dim=-1)
    k_out = torch.cat([ke * cos_b - ko * sin_b, ke * sin_b + ko * cos_b], dim=-1)
    return q_out, k_out
class Attention(nn.Module):
    """Multi-head causal self-attention with RoPE and a fused QKV matmul."""

    def __init__(self, config: LunaConfig):
        super().__init__()
        self.n_head = config.n_head
        self.head_dim = config.n_embd // config.n_head
        self.dropout_p = config.dropout
        # One projection produces Q, K and V together (1 matmul, not 3).
        self.wqkv = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.wo = nn.Linear(config.n_embd, config.n_embd, bias=False)

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        batch, seq, dim = x.shape
        # Project once, then carve out the three [B, T, C] chunks.
        q, k, v = self.wqkv(x).split(dim, dim=-1)
        # [B, T, C] -> [B, n_head, T, head_dim]
        heads_shape = (batch, seq, self.n_head, self.head_dim)
        q = q.view(heads_shape).transpose(1, 2)
        k = k.view(heads_shape).transpose(1, 2)
        v = v.view(heads_shape).transpose(1, 2)
        q, k = apply_rotary_emb_fused(q, k, cos, sin)
        # SDPA dispatches to Flash / memory-efficient kernels when available;
        # attention dropout only applies during training.
        attn_out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=self.dropout_p if self.training else 0.0,
            is_causal=True,
        )
        attn_out = attn_out.transpose(1, 2).contiguous().view(batch, seq, dim)
        return self.wo(attn_out)
class FeedForward(nn.Module):
    """SwiGLU feed-forward: w2(silu(w1·x) * w3·x), with w1/w3 fused into w13."""

    def __init__(self, config: LunaConfig):
        super().__init__()
        hidden = int(4 * config.n_embd)
        # Gate (w1) and value (w3) projections share a single matmul.
        self.w13 = nn.Linear(config.n_embd, 2 * hidden, bias=False)
        self.w2 = nn.Linear(hidden, config.n_embd, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One matmul, then split into the gate and value halves.
        gate, value = self.w13(x).chunk(2, dim=-1)
        return self.dropout(self.w2(F.silu(gate) * value))
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: x += attn(norm1(x)); x += ffn(norm2(x))."""

    def __init__(self, config: LunaConfig):
        super().__init__()
        self.norm1 = RMSNorm(config.n_embd)
        self.attn = Attention(config)
        self.norm2 = RMSNorm(config.n_embd)
        self.ffn = FeedForward(config)

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # Residual connection around each sub-layer, norm applied first.
        x = x + self.attn(self.norm1(x), cos, sin)
        return x + self.ffn(self.norm2(x))
#-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
#
# Dual Stream Fusion
class OptimizedDualStreamFusion(nn.Module):
    """Embed the 9 packed token features into a single n_embd vector.

    Two main streams are mixed by a learned scalar gate:
      - semantic: one large syllable embedding
      - phonetic: onset/nucleus/coda embeddings projected to syllable_dim
    Five small auxiliary embeddings are then concatenated on, and the whole
    vector is projected to n_embd and RMS-normalized.
    """

    def __init__(self, config: LunaConfig):
        super().__init__()
        self.config = config
        # Semantic stream.
        self.syllable_embed = nn.Embedding(config.syllable_vocab, config.syllable_dim)
        # Phonetic stream: three lookups concatenated, then projected.
        self.onset_embed = nn.Embedding(config.onset_vocab, config.onset_dim)
        self.nucleus_embed = nn.Embedding(config.nucleus_vocab, config.nucleus_dim)
        self.coda_embed = nn.Embedding(config.coda_vocab, config.coda_dim)
        phonetic_dim = config.onset_dim + config.nucleus_dim + config.coda_dim
        self.phonetic_proj = nn.Linear(phonetic_dim, config.syllable_dim, bias=False)
        # Scalar gate alpha in (0, 1) deciding the semantic/phonetic mix.
        self.gate = nn.Sequential(
            nn.Linear(config.syllable_dim * 2, config.syllable_dim // 2, bias=False),
            nn.SiLU(),  # SiLU is fused in CUDA
            nn.Linear(config.syllable_dim // 2, 1, bias=False),
            nn.Sigmoid(),
        )
        # Auxiliary embeddings (avoid reserved names like 'type').
        self.aux_embeddings = nn.ModuleDict({
            'position': nn.Embedding(config.position_vocab, config.position_dim),
            'cap': nn.Embedding(config.capitalized_vocab, config.cap_dim),
            'tok_type': nn.Embedding(config.token_type_vocab, config.type_dim),  # renamed from 'type'
            'space': nn.Embedding(config.space_vocab, config.space_dim),
            'word_end': nn.Embedding(config.word_end_vocab, config.word_end_dim),
        })
        self.aux_dim = (config.position_dim + config.cap_dim + config.type_dim
                        + config.space_dim + config.word_end_dim)
        # Final projection: [fused | aux] -> transformer width.
        total_dim = config.syllable_dim + self.aux_dim
        self.output_proj = nn.Linear(total_dim, config.n_embd, bias=False)
        self.output_norm = RMSNorm(config.n_embd)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Map [B, T, 9] integer features to [B, T, n_embd] embeddings."""
        # Static column indexing keeps this torch.compile-friendly.
        semantic = self.syllable_embed(features[:, :, F_SYLLABLE])
        # Phonetic stream: batch the three lookups through one projection.
        phonetic = self.phonetic_proj(torch.cat([
            self.onset_embed(features[:, :, F_ONSET]),
            self.nucleus_embed(features[:, :, F_NUCLEUS]),
            self.coda_embed(features[:, :, F_CODA]),
        ], dim=-1))
        # Gated fusion of the two streams.
        alpha = self.gate(torch.cat([semantic, phonetic], dim=-1))
        fused = alpha * semantic + (1 - alpha) * phonetic
        # Auxiliary lookups, concatenated feature-wise.
        aux = torch.cat([
            self.aux_embeddings['position'](features[:, :, F_POSITION]),
            self.aux_embeddings['cap'](features[:, :, F_CAPITALIZED]),
            self.aux_embeddings['tok_type'](features[:, :, F_TOKEN_TYPE]),  # renamed from 'type'
            self.aux_embeddings['space'](features[:, :, F_SPACE_AFTER]),
            self.aux_embeddings['word_end'](features[:, :, F_WORD_END]),
        ], dim=-1)
        combined = torch.cat([fused, aux], dim=-1)
        return self.output_norm(self.output_proj(combined))
#-=-=-=-=-=-=-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
#
# Output Heads
class FusedOutputHeads(nn.Module):
    """All eight prediction heads collapsed into one fused linear layer."""

    def __init__(self, config: LunaConfig):
        super().__init__()
        # Output width of each head, in a fixed order (dicts preserve it).
        self.head_sizes = {
            'syllable': config.syllable_vocab,
            'onset': config.onset_vocab,
            'nucleus': config.nucleus_vocab,
            'coda': config.coda_vocab,
            'position': config.position_vocab,
            'is_capitalized': config.capitalized_vocab,
            'token_type': config.token_type_vocab,
            'has_space_after': config.space_vocab,
        }
        self.head_names = list(self.head_sizes.keys())
        self.total_output = sum(self.head_sizes.values())
        # One matmul covers all heads at once.
        self.fused_head = nn.Linear(config.n_embd, self.total_output, bias=False)
        # Split sizes, pre-computed once for the forward-time split.
        self.split_sizes = [self.head_sizes[name] for name in self.head_names]
        # Persistent buffer kept for checkpoint compatibility (unused at runtime).
        self.register_buffer('_split_sizes_tensor', torch.tensor(self.split_sizes))

    def forward(self, h: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Project hidden states [B, T, n_embd] to a dict of per-head logits."""
        fused_logits = self.fused_head(h)  # single matmul
        chunks = fused_logits.split(self.split_sizes, dim=-1)
        return dict(zip(self.head_names, chunks))
#-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
#
# Loss computation
class OptimizedMultiTaskLoss(nn.Module):
    """Weighted sum of per-head cross-entropy losses, normalized by weight sum."""

    def __init__(self, config: LunaConfig):
        super().__init__()
        # Fixed per-head loss weights, ordered to match head_names in forward().
        self.register_buffer('loss_weights', torch.tensor([
            1.0,   # syllable
            0.2,   # onset
            0.2,   # nucleus
            0.2,   # coda
            0.3,   # position
            0.1,   # is_capitalized
            0.15,  # token_type
            0.4,   # has_space_after
        ]))
        self.weight_sum = self.loss_weights.sum().item()
        # Per-token multipliers for the syllable loss, indexed by the token's
        # position class and token-type class respectively.
        self.register_buffer('position_weights', torch.tensor([0.8, 1.0, 1.5, 1.2]))
        self.register_buffer('type_weights', torch.tensor([1.0, 1.2, 2.5, 1.0]))
        # Which target column feeds each head.
        self.target_indices = [F_SYLLABLE, F_ONSET, F_NUCLEUS, F_CODA,
                               F_POSITION, F_CAPITALIZED, F_TOKEN_TYPE, F_SPACE_AFTER]

    def forward(self, logits: Dict[str, torch.Tensor], targets: torch.Tensor) -> torch.Tensor:
        """Combine all head losses into one scalar.

        Args:
            logits: dict of [B, T, V] logit tensors, one per head
            targets: [B, T, 9] integer target tensor
        """
        head_names = ['syllable', 'onset', 'nucleus', 'coda',
                      'position', 'is_capitalized', 'token_type', 'has_space_after']
        # Position/type classes drive the per-token syllable weighting.
        pos_targets = targets[:, :, F_POSITION]
        type_targets = targets[:, :, F_TOKEN_TYPE]
        total_loss = 0.0
        for idx, name in enumerate(head_names):
            logit = logits[name]
            target = targets[:, :, self.target_indices[idx]]
            if name == 'syllable':
                # Per-token CE, reweighted by position and token-type class.
                B, T, V = logit.shape
                per_token = F.cross_entropy(
                    logit.view(-1, V), target.view(-1), reduction='none'
                ).view(B, T)
                pos_w = self.position_weights[pos_targets]
                type_w = self.type_weights[type_targets]
                head_loss = (per_token * pos_w * type_w).mean()
            else:
                head_loss = F.cross_entropy(
                    logit.view(-1, logit.size(-1)), target.view(-1)
                )
            total_loss = total_loss + self.loss_weights[idx] * head_loss
        # Normalize so the overall scale is independent of the weight choices.
        return total_loss / self.weight_sum
#-=-=-=-=-=--=-=-=-==-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
#
# Main Model
class Luna(nn.Module):
    """Syllable-level language model.

    Pipeline: dual-stream feature embedding -> RoPE transformer stack ->
    multi-head output logits, with an optional built-in multi-task loss.
    """

    def __init__(self, config: LunaConfig):
        super().__init__()
        self.config = config
        # Embedding: fuses the 9 packed feature columns into n_embd vectors.
        self.embedding = OptimizedDualStreamFusion(config)
        # Transformer backbone; RoPE operates at per-head dimensionality.
        self.rotary = RotaryEmbedding(config.n_embd // config.n_head, config.max_seq_len)
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layer)])
        self.norm = RMSNorm(config.n_embd)
        # Output (fused or separate based on config)
        if config.fuse_output_heads:
            self.heads = FusedOutputHeads(config)
        else:
            self.heads = nn.ModuleDict({
                'syllable': nn.Linear(config.n_embd, config.syllable_vocab, bias=False),
                'onset': nn.Linear(config.n_embd, config.onset_vocab, bias=False),
                'nucleus': nn.Linear(config.n_embd, config.nucleus_vocab, bias=False),
                'coda': nn.Linear(config.n_embd, config.coda_vocab, bias=False),
                'position': nn.Linear(config.n_embd, config.position_vocab, bias=False),
                'is_capitalized': nn.Linear(config.n_embd, config.capitalized_vocab, bias=False),
                'token_type': nn.Linear(config.n_embd, config.token_type_vocab, bias=False),
                'has_space_after': nn.Linear(config.n_embd, config.space_vocab, bias=False),
            })
        self.dropout = nn.Dropout(config.dropout)
        self.loss_fn = OptimizedMultiTaskLoss(config)
        self.apply(self._init_weights)
        self._print_info()

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights to N(0, 0.02); zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _print_info(self):
        """Print a one-time parameter-count and feature summary to stdout."""
        n_params = sum(p.numel() for p in self.parameters())
        embed_params = sum(p.numel() for p in self.embedding.parameters())
        # Fused heads expose one weight matrix; separate heads are a ModuleDict.
        if isinstance(self.heads, FusedOutputHeads):
            head_params = self.heads.fused_head.weight.numel()
        else:
            head_params = sum(p.numel() for p in self.heads.parameters())
        print(f"\n{'='*60}")
        print("Luna Summary")
        print(f"{'='*60}")
        print(f"Total parameters: {n_params:,}")
        print(f"Embedding parameters: {embed_params:,}")
        print(f"Output head parameters: {head_params:,}")
        print(f"Transformer backbone: {n_params - embed_params - head_params:,}")
        print(f"\nOptimizations enabled:")
        print(f" - Fused QKV projection")
        print(f" - Fused FFN gate")
        print(f" - Fused output heads: {self.config.fuse_output_heads}")
        print(f" - JIT rotary embeddings")
        print(f" - RMSNorm everywhere")
        print(f" - Vectorized loss")
        print(f"{'='*60}\n")

    def forward(
        self,
        features: torch.Tensor,
        targets: Optional[torch.Tensor] = None
    ) -> Tuple[Dict[str, torch.Tensor], Optional[torch.Tensor]]:
        """
        Args:
            features: [B, T, 9] input features (columns per the F_* constants)
            targets: [B, T, 9] targets (optional; same column layout)
        Returns:
            logits: Dict of output logits, one [B, T, V] tensor per head
            loss: Combined multi-task loss (None when targets is None)
        """
        B, T, _ = features.shape
        # Embedding
        h = self.embedding(features)
        h = self.dropout(h)
        # Transformer: one shared cos/sin table for every layer.
        cos, sin = self.rotary(T)
        for layer in self.layers:
            h = layer(h, cos, sin)
        h = self.norm(h)
        # Output heads
        if isinstance(self.heads, FusedOutputHeads):
            logits = self.heads(h)
        else:
            logits = {name: head(h) for name, head in self.heads.items()}
        # Loss is only computed when targets are supplied (training path).
        loss = None
        if targets is not None:
            loss = self.loss_fn(logits, targets)
        return logits, loss
#-=-=-=-=-=-=-=-=-=-=-=--=-=-=---=-=-=-=-=-=-=-=-=-=-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-
#
# Helper for Migration
def dict_to_tensor(features_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Stack per-feature tensors from a dict into the packed [.., 9] layout.

    The stacking order matches the F_* feature-index constants, so the
    result's last-dim column i holds feature i.
    """
    feature_order = (
        'syllable_id', 'onset_id', 'nucleus_id', 'coda_id',
        'position', 'is_capitalized', 'token_type',
        'has_space_after', 'is_word_end',
    )
    return torch.stack([features_dict[key] for key in feature_order], dim=-1)
#-=-=-=-=-=-=-=-=-=-=-=--=-=-=---=-=-=-=-=-=-=-=-=-=-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-
#
# Lil Test
if __name__ == "__main__":
    print("Luna - Speed Test")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"

    def sync():
        # BUGFIX: torch.cuda.synchronize() raises on CPU-only machines, so
        # the original script crashed whenever it fell back to CPU.
        if use_cuda:
            torch.cuda.synchronize()

    def autocast():
        # torch.cuda.amp.autocast is deprecated and CUDA-only;
        # torch.amp.autocast supports bfloat16 on CPU as well.
        return torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16)

    config = LunaConfig(
        syllable_vocab=32768,
        onset_vocab=2048,
        nucleus_vocab=512,
        coda_vocab=2048,
        max_seq_len=1024,
        fuse_output_heads=True,
    )
    model = Luna(config).to(device)

    # Test forward pass (smaller batch on CPU so the fallback finishes).
    B, T = (8, 1024) if use_cuda else (2, 128)
    features = torch.stack([
        torch.randint(0, 1000, (B, T)),  # syllable ids
        torch.randint(0, 100, (B, T)),   # onset ids
        torch.randint(0, 50, (B, T)),    # nucleus ids
        torch.randint(0, 100, (B, T)),   # coda ids
        torch.randint(0, 4, (B, T)),     # position
        torch.randint(0, 2, (B, T)),     # is_capitalized
        torch.randint(0, 4, (B, T)),     # token_type
        torch.randint(0, 2, (B, T)),     # has_space_after
        torch.randint(0, 2, (B, T)),     # is_word_end
    ], dim=-1).to(device)
    targets = features.clone()

    # Warmup
    for _ in range(3):
        with autocast():
            logits, loss = model(features, targets)
    sync()

    # Benchmark (grads intentionally accumulate; we only time fwd+bwd)
    import time
    n_iters = 50
    start = time.time()
    for _ in range(n_iters):
        with autocast():
            logits, loss = model(features, targets)
        loss.backward()
    sync()
    elapsed = time.time() - start
    tokens_per_iter = B * T
    tok_per_sec = (n_iters * tokens_per_iter) / elapsed
    print(f"\nBenchmark Results:")
    print(f" Batch: {B} x {T} = {B*T:,} tokens")
    print(f" Iterations: {n_iters}")
    print(f" Time: {elapsed:.2f}s")
    print(f" Throughput: {tok_per_sec:,.0f} tok/s")
    print(f" Loss: {loss.item():.4f}")

    # Test torch.compile
    print("\nTesting torch.compile()...")
    compiled_model = torch.compile(model, mode="reduce-overhead")

    # Warmup compiled (extra iterations cover the compilation passes)
    for _ in range(5):
        with autocast():
            logits, loss = compiled_model(features, targets)
    sync()

    # Benchmark compiled
    start = time.time()
    for _ in range(n_iters):
        with autocast():
            logits, loss = compiled_model(features, targets)
        loss.backward()
    sync()
    elapsed_compiled = time.time() - start
    tok_per_sec_compiled = (n_iters * tokens_per_iter) / elapsed_compiled
    print(f"\nCompiled Results:")
    print(f" Throughput: {tok_per_sec_compiled:,.0f} tok/s")
    print(f" Speedup: {tok_per_sec_compiled/tok_per_sec:.2f}x")
    print(f"\n✓ All tests passed!")