""" Julian Model Layers. Core building blocks: RMSNorm, RoPE, Attention, FFN. """ import math from typing import Optional, Tuple import jax import jax.numpy as jnp import flax.linen as nn from flax.linen import initializers from .config import JulianConfig class RMSNorm(nn.Module): """Root Mean Square Layer Normalization.""" dim: int eps: float = 1e-6 @nn.compact def __call__(self, x: jnp.ndarray) -> jnp.ndarray: weight = self.param( "weight", initializers.ones, (self.dim,) ) # RMS norm variance = jnp.mean(x ** 2, axis=-1, keepdims=True) x = x * jax.lax.rsqrt(variance + self.eps) return x * weight def precompute_rope_frequencies( dim: int, max_seq_len: int, theta: float = 10000.0, ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Precompute RoPE sin/cos frequencies.""" # Frequency for each dimension pair freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2).astype(jnp.float32) / dim)) # Position indices positions = jnp.arange(max_seq_len) # Outer product: [seq_len, dim/2] angles = jnp.outer(positions, freqs) # Sin and cos sin = jnp.sin(angles) cos = jnp.cos(angles) return sin, cos def apply_rope( x: jnp.ndarray, sin: jnp.ndarray, cos: jnp.ndarray, ) -> jnp.ndarray: """Apply rotary position embeddings.""" # x shape: [batch, seq_len, n_heads, head_dim] # sin/cos shape: [seq_len, head_dim/2] seq_len = x.shape[1] sin = sin[:seq_len] cos = cos[:seq_len] # Split x into pairs x1 = x[..., ::2] # Even indices x2 = x[..., 1::2] # Odd indices # Rotate # [batch, seq, heads, dim/2] sin = sin[None, :, None, :] # Add batch and head dims cos = cos[None, :, None, :] rotated_x1 = x1 * cos - x2 * sin rotated_x2 = x1 * sin + x2 * cos # Interleave back rotated = jnp.stack([rotated_x1, rotated_x2], axis=-1) rotated = rotated.reshape(x.shape) return rotated class Attention(nn.Module): """Multi-head self-attention with RoPE.""" config: JulianConfig @nn.compact def __call__( self, x: jnp.ndarray, sin: jnp.ndarray, cos: jnp.ndarray, mask: Optional[jnp.ndarray] = None, deterministic: bool = True, ) -> jnp.ndarray: batch_size, seq_len, _ = x.shape config = self.config # QKV projections q = nn.Dense( config.d_model, use_bias=config.use_bias, kernel_init=initializers.normal(config.initializer_range), name="q_proj", )(x) k = nn.Dense( config.d_model, use_bias=config.use_bias, kernel_init=initializers.normal(config.initializer_range), name="k_proj", )(x) v = nn.Dense( config.d_model, use_bias=config.use_bias, kernel_init=initializers.normal(config.initializer_range), name="v_proj", )(x) # Reshape for multi-head attention # [batch, seq, d_model] -> [batch, seq, n_heads, head_dim] q = q.reshape(batch_size, seq_len, config.n_heads, config.head_dim) k = k.reshape(batch_size, seq_len, config.n_heads, config.head_dim) v = v.reshape(batch_size, seq_len, config.n_heads, config.head_dim) # Apply RoPE to Q and K q = apply_rope(q, sin, cos) k = apply_rope(k, sin, cos) # Transpose for attention: [batch, n_heads, seq, head_dim] q = jnp.transpose(q, (0, 2, 1, 3)) k = jnp.transpose(k, (0, 2, 1, 3)) v = jnp.transpose(v, (0, 2, 1, 3)) # Scaled dot-product attention in bfloat16 for memory efficiency scale = 1.0 / math.sqrt(config.head_dim) # Force bfloat16 for attention computation (major memory savings) q = q.astype(jnp.bfloat16) k = k.astype(jnp.bfloat16) v = v.astype(jnp.bfloat16) attn_weights = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale # Apply causal mask if mask is not None: attn_weights = jnp.where(mask, attn_weights, jnp.finfo(jnp.bfloat16).min) attn_weights = jax.nn.softmax(attn_weights.astype(jnp.float32), axis=-1) attn_weights = 
attn_weights.astype(jnp.bfloat16) # Dropout if not deterministic: attn_weights = nn.Dropout( rate=config.attention_dropout, deterministic=deterministic, )(attn_weights) # Apply attention to values attn_output = jnp.einsum("bhqk,bhkd->bhqd", attn_weights, v) # Reshape back: [batch, n_heads, seq, head_dim] -> [batch, seq, d_model] attn_output = jnp.transpose(attn_output, (0, 2, 1, 3)) attn_output = attn_output.reshape(batch_size, seq_len, config.d_model) # Output projection output = nn.Dense( config.d_model, use_bias=config.use_bias, kernel_init=initializers.normal(config.initializer_range), name="o_proj", )(attn_output) return output class FeedForward(nn.Module): """SwiGLU Feed-Forward Network.""" config: JulianConfig @nn.compact def __call__( self, x: jnp.ndarray, deterministic: bool = True, ) -> jnp.ndarray: config = self.config # Gate and up projections gate = nn.Dense( config.d_ff, use_bias=config.use_bias, kernel_init=initializers.normal(config.initializer_range), name="gate_proj", )(x) up = nn.Dense( config.d_ff, use_bias=config.use_bias, kernel_init=initializers.normal(config.initializer_range), name="up_proj", )(x) # SwiGLU activation hidden = jax.nn.silu(gate) * up # Dropout if not deterministic: hidden = nn.Dropout( rate=config.dropout, deterministic=deterministic, )(hidden) # Down projection output = nn.Dense( config.d_model, use_bias=config.use_bias, kernel_init=initializers.normal(config.initializer_range), name="down_proj", )(hidden) return output class TransformerBlock(nn.Module): """Single transformer decoder block.""" config: JulianConfig @nn.compact def __call__( self, x: jnp.ndarray, sin: jnp.ndarray, cos: jnp.ndarray, mask: Optional[jnp.ndarray] = None, deterministic: bool = True, ) -> jnp.ndarray: config = self.config # Pre-norm attention residual = x x = RMSNorm(config.d_model, config.rms_norm_eps, name="input_layernorm")(x) x = Attention(config, name="self_attn")(x, sin, cos, mask, deterministic) if not deterministic: x = nn.Dropout(rate=config.dropout, deterministic=deterministic)(x) x = residual + x # Pre-norm FFN residual = x x = RMSNorm(config.d_model, config.rms_norm_eps, name="post_attention_layernorm")(x) x = FeedForward(config, name="mlp")(x, deterministic) if not deterministic: x = nn.Dropout(rate=config.dropout, deterministic=deterministic)(x) x = residual + x return x
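

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# It assumes JulianConfig accepts the keyword arguments referenced above
# (d_model, n_heads, head_dim, d_ff, use_bias, initializer_range, dropout,
# attention_dropout, rms_norm_eps); the concrete values below are arbitrary
# and should be adapted to the actual definition in .config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = JulianConfig(  # hypothetical constructor call, see note above
        d_model=64,
        n_heads=4,
        head_dim=16,
        d_ff=256,
        use_bias=False,
        initializer_range=0.02,
        dropout=0.0,
        attention_dropout=0.0,
        rms_norm_eps=1e-6,
    )
    batch, seq_len = 2, 8

    x = jnp.zeros((batch, seq_len, config.d_model))
    sin, cos = precompute_rope_frequencies(config.head_dim, max_seq_len=seq_len)

    # Causal mask, broadcastable to [batch, n_heads, q_len, k_len]
    mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len), dtype=bool))

    block = TransformerBlock(config)
    params = block.init(jax.random.PRNGKey(0), x, sin, cos, mask)
    y = block.apply(params, x, sin, cos, mask)
    print(y.shape)  # Expected: (2, 8, 64)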