"""
Julian Model Layers.
Core building blocks: RMSNorm, RoPE, Attention, FFN.
"""
import math
from typing import Optional, Tuple

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import initializers

from .config import JulianConfig


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""

    dim: int
    eps: float = 1e-6

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        weight = self.param(
            "weight",
            initializers.ones,
            (self.dim,)
        )
        # RMS norm
        variance = jnp.mean(x ** 2, axis=-1, keepdims=True)
        x = x * jax.lax.rsqrt(variance + self.eps)
        return x * weight
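
# Illustrative usage sketch (standard Flax init/apply; the shapes below are
# examples, not values taken from JulianConfig):
#
#     norm = RMSNorm(dim=512)
#     x = jnp.ones((2, 16, 512))                      # [batch, seq, features]
#     params = norm.init(jax.random.PRNGKey(0), x)    # creates the "weight" param
#     y = norm.apply(params, x)                       # same shape as x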


def precompute_rope_frequencies(
    dim: int,
    max_seq_len: int,
    theta: float = 10000.0,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Precompute RoPE sin/cos frequencies."""
    # Frequency for each dimension pair
    freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2).astype(jnp.float32) / dim))
    # Position indices
    positions = jnp.arange(max_seq_len)
    # Outer product: [max_seq_len, dim/2]
    angles = jnp.outer(positions, freqs)
    # Sin and cos
    sin = jnp.sin(angles)
    cos = jnp.cos(angles)
    return sin, cos
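
# Worked example (illustrative values, not from the model config): with dim=4
# and theta=10000, the per-pair frequencies are
#     freqs = 1 / theta**([0, 2] / 4) = [1.0, 0.01],
# so position p yields angles [p * 1.0, p * 0.01] and sin/cos arrays of shape
# [max_seq_len, 2].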


def apply_rope(
    x: jnp.ndarray,
    sin: jnp.ndarray,
    cos: jnp.ndarray,
) -> jnp.ndarray:
    """Apply rotary position embeddings."""
    # x shape: [batch, seq_len, n_heads, head_dim]
    # sin/cos shape: [max_seq_len, head_dim/2], sliced to the current seq_len
    seq_len = x.shape[1]
    sin = sin[:seq_len]
    cos = cos[:seq_len]
    # Split x into (even, odd) pairs
    x1 = x[..., ::2]  # Even indices
    x2 = x[..., 1::2]  # Odd indices
    # Broadcast sin/cos over batch and head dims: [1, seq, 1, head_dim/2]
    sin = sin[None, :, None, :]
    cos = cos[None, :, None, :]
    # Rotate each pair by its position-dependent angle
    rotated_x1 = x1 * cos - x2 * sin
    rotated_x2 = x1 * sin + x2 * cos
    # Interleave back to the original layout
    rotated = jnp.stack([rotated_x1, rotated_x2], axis=-1)
    rotated = rotated.reshape(x.shape)
    return rotated
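
# Note (not executed by the model): for one (even, odd) feature pair (x1, x2)
# at position p with per-pair frequency f, the lines above apply a 2-D rotation
# by angle a = p * f:
#     x1' = x1 * cos(a) - x2 * sin(a)
#     x2' = x1 * sin(a) + x2 * cos(a)
# As a result, the query/key dot product in attention depends on positions only
# through their relative offset, which is the point of rotary embeddings.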


class Attention(nn.Module):
    """Multi-head self-attention with RoPE."""

    config: JulianConfig

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
        sin: jnp.ndarray,
        cos: jnp.ndarray,
        mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
    ) -> jnp.ndarray:
        batch_size, seq_len, _ = x.shape
        config = self.config
        # QKV projections
        q = nn.Dense(
            config.d_model,
            use_bias=config.use_bias,
            kernel_init=initializers.normal(config.initializer_range),
            name="q_proj",
        )(x)
        k = nn.Dense(
            config.d_model,
            use_bias=config.use_bias,
            kernel_init=initializers.normal(config.initializer_range),
            name="k_proj",
        )(x)
        v = nn.Dense(
            config.d_model,
            use_bias=config.use_bias,
            kernel_init=initializers.normal(config.initializer_range),
            name="v_proj",
        )(x)
        # Reshape for multi-head attention
        # [batch, seq, d_model] -> [batch, seq, n_heads, head_dim]
        q = q.reshape(batch_size, seq_len, config.n_heads, config.head_dim)
        k = k.reshape(batch_size, seq_len, config.n_heads, config.head_dim)
        v = v.reshape(batch_size, seq_len, config.n_heads, config.head_dim)
        # Apply RoPE to Q and K
        q = apply_rope(q, sin, cos)
        k = apply_rope(k, sin, cos)
        # Transpose for attention: [batch, n_heads, seq, head_dim]
        q = jnp.transpose(q, (0, 2, 1, 3))
        k = jnp.transpose(k, (0, 2, 1, 3))
        v = jnp.transpose(v, (0, 2, 1, 3))
        # Scaled dot-product attention in bfloat16 for memory efficiency
        scale = 1.0 / math.sqrt(config.head_dim)
        # Force bfloat16 for attention computation (major memory savings)
        q = q.astype(jnp.bfloat16)
        k = k.astype(jnp.bfloat16)
        v = v.astype(jnp.bfloat16)
        attn_weights = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale
        # Apply causal mask (True = attend, False = mask out)
        if mask is not None:
            attn_weights = jnp.where(mask, attn_weights, jnp.finfo(jnp.bfloat16).min)
        # Softmax in float32 for numerical stability, then back to bfloat16
        attn_weights = jax.nn.softmax(attn_weights.astype(jnp.float32), axis=-1)
        attn_weights = attn_weights.astype(jnp.bfloat16)
        # Dropout
        if not deterministic:
            attn_weights = nn.Dropout(
                rate=config.attention_dropout,
                deterministic=deterministic,
            )(attn_weights)
        # Apply attention to values
        attn_output = jnp.einsum("bhqk,bhkd->bhqd", attn_weights, v)
        # Reshape back: [batch, n_heads, seq, head_dim] -> [batch, seq, d_model]
        attn_output = jnp.transpose(attn_output, (0, 2, 1, 3))
        attn_output = attn_output.reshape(batch_size, seq_len, config.d_model)
        # Output projection
        output = nn.Dense(
            config.d_model,
            use_bias=config.use_bias,
            kernel_init=initializers.normal(config.initializer_range),
            name="o_proj",
        )(attn_output)
        return output
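

# Hedged sketch: the `mask` argument above broadcasts against attention weights
# of shape [batch, n_heads, q_len, k_len], so a boolean lower-triangular mask of
# shape [1, 1, seq_len, seq_len] (True = attend) is sufficient. This helper is
# illustrative only and is not referenced elsewhere in this file.
def make_causal_mask(seq_len: int) -> jnp.ndarray:
    """Build a [1, 1, seq_len, seq_len] boolean causal mask (True = attend)."""
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))[None, None, :, :]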


class FeedForward(nn.Module):
    """SwiGLU Feed-Forward Network."""

    config: JulianConfig

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
        deterministic: bool = True,
    ) -> jnp.ndarray:
        config = self.config
        # Gate and up projections
        gate = nn.Dense(
            config.d_ff,
            use_bias=config.use_bias,
            kernel_init=initializers.normal(config.initializer_range),
            name="gate_proj",
        )(x)
        up = nn.Dense(
            config.d_ff,
            use_bias=config.use_bias,
            kernel_init=initializers.normal(config.initializer_range),
            name="up_proj",
        )(x)
        # SwiGLU activation
        hidden = jax.nn.silu(gate) * up
        # Dropout
        if not deterministic:
            hidden = nn.Dropout(
                rate=config.dropout,
                deterministic=deterministic,
            )(hidden)
        # Down projection
        output = nn.Dense(
            config.d_model,
            use_bias=config.use_bias,
            kernel_init=initializers.normal(config.initializer_range),
            name="down_proj",
        )(hidden)
        return output
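
# Design note (general background, not specific to this repository): SwiGLU uses
# two input projections (gate and up) instead of one, so to keep the parameter
# count comparable to a plain GELU/ReLU MLP, d_ff is often chosen around
# (2/3) * 4 * d_model rather than 4 * d_model; the actual ratio here is whatever
# JulianConfig.d_ff specifies.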


class TransformerBlock(nn.Module):
    """Single transformer decoder block."""

    config: JulianConfig

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
        sin: jnp.ndarray,
        cos: jnp.ndarray,
        mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
    ) -> jnp.ndarray:
        config = self.config
        # Pre-norm attention
        residual = x
        x = RMSNorm(config.d_model, config.rms_norm_eps, name="input_layernorm")(x)
        x = Attention(config, name="self_attn")(x, sin, cos, mask, deterministic)
        if not deterministic:
            x = nn.Dropout(rate=config.dropout, deterministic=deterministic)(x)
        x = residual + x
        # Pre-norm FFN
        residual = x
        x = RMSNorm(config.d_model, config.rms_norm_eps, name="post_attention_layernorm")(x)
        x = FeedForward(config, name="mlp")(x, deterministic)
        if not deterministic:
            x = nn.Dropout(rate=config.dropout, deterministic=deterministic)(x)
        x = residual + x
        return x
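

# Minimal shape-check sketch, guarded so it never runs on import. It assumes
# JulianConfig accepts at least the keyword fields this file reads (d_model,
# n_heads, head_dim, d_ff, dropout, attention_dropout, rms_norm_eps, use_bias,
# initializer_range); the values below are illustrative, not the model's real
# hyperparameters. Run as a module (e.g. `python -m src.model.layers`, assuming
# that package path) so the relative import of JulianConfig resolves.
if __name__ == "__main__":
    cfg = JulianConfig(
        d_model=64,
        n_heads=4,
        head_dim=16,
        d_ff=128,
        dropout=0.0,
        attention_dropout=0.0,
        rms_norm_eps=1e-6,
        use_bias=False,
        initializer_range=0.02,
    )
    seq_len = 8
    sin, cos = precompute_rope_frequencies(cfg.head_dim, max_seq_len=seq_len)
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))[None, None, :, :]
    x = jnp.zeros((1, seq_len, cfg.d_model), dtype=jnp.float32)
    block = TransformerBlock(cfg)
    params = block.init(jax.random.PRNGKey(0), x, sin, cos, mask)
    y = block.apply(params, x, sin, cos, mask)
    print("output shape:", y.shape)  # expected: (1, 8, 64)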