""" |
|
|
Julian Model Layers. |
|
|
Core building blocks: RMSNorm, RoPE, Attention, FFN. |
|
|
""" |
|
|
|
|
|
import math |
|
|
from typing import Optional, Tuple |
|
|
|
|
|
import jax |
|
|
import jax.numpy as jnp |
|
|
import flax.linen as nn |
|
|
from flax.linen import initializers |
|
|
|
|
|
from .config import JulianConfig |


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Rescales each vector by the reciprocal of its root mean square and applies a
    learned per-dimension gain. Unlike LayerNorm, there is no mean-centering and
    no bias term.
    """

    dim: int
    eps: float = 1e-6

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # Learned per-dimension scale, initialized to ones.
        weight = self.param("weight", initializers.ones, (self.dim,))

        # Normalize by the root mean square over the last dimension.
        variance = jnp.mean(x ** 2, axis=-1, keepdims=True)
        x = x * jax.lax.rsqrt(variance + self.eps)

        return x * weight


def precompute_rope_frequencies(
    dim: int,
    max_seq_len: int,
    theta: float = 10000.0,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Precompute RoPE sin/cos tables of shape (max_seq_len, dim // 2)."""
    # One inverse frequency per dimension pair: theta ** (-2i / dim).
    freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2).astype(jnp.float32) / dim))

    # Angle for every (position, frequency) pair.
    positions = jnp.arange(max_seq_len)
    angles = jnp.outer(positions, freqs)

    sin = jnp.sin(angles)
    cos = jnp.cos(angles)
    return sin, cos


def apply_rope(
    x: jnp.ndarray,
    sin: jnp.ndarray,
    cos: jnp.ndarray,
) -> jnp.ndarray:
    """Apply rotary position embeddings to x of shape (batch, seq, heads, head_dim)."""
    # Slice the precomputed tables to the current sequence length.
    seq_len = x.shape[1]
    sin = sin[:seq_len]
    cos = cos[:seq_len]

    # Split the head dimension into the even/odd halves of each rotation pair.
    x1 = x[..., ::2]
    x2 = x[..., 1::2]

    # Broadcast over batch and heads: (seq, head_dim // 2) -> (1, seq, 1, head_dim // 2).
    sin = sin[None, :, None, :]
    cos = cos[None, :, None, :]

    # 2-D rotation of each (x1, x2) pair by its position-dependent angle.
    rotated_x1 = x1 * cos - x2 * sin
    rotated_x2 = x1 * sin + x2 * cos

    # Re-interleave the rotated pairs back into the original layout.
    rotated = jnp.stack([rotated_x1, rotated_x2], axis=-1)
    rotated = rotated.reshape(x.shape)
    return rotated
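

# Shape sketch (illustrative, not part of the original module): with head_dim = 64
# and max_seq_len = 2048, precompute_rope_frequencies(64, 2048) returns sin and cos
# of shape (2048, 32); apply_rope then expects x of shape (batch, seq, n_heads, 64),
# slices the tables to the current sequence length, and returns an array of the same
# shape as x.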


class Attention(nn.Module):
    """Multi-head self-attention with RoPE."""

    config: JulianConfig

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
        sin: jnp.ndarray,
        cos: jnp.ndarray,
        mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
    ) -> jnp.ndarray:
        batch_size, seq_len, _ = x.shape
        config = self.config

        # Project the input to queries, keys, and values.
        q = nn.Dense(
            config.d_model,
            use_bias=config.use_bias,
            kernel_init=initializers.normal(config.initializer_range),
            name="q_proj",
        )(x)
        k = nn.Dense(
            config.d_model,
            use_bias=config.use_bias,
            kernel_init=initializers.normal(config.initializer_range),
            name="k_proj",
        )(x)
        v = nn.Dense(
            config.d_model,
            use_bias=config.use_bias,
            kernel_init=initializers.normal(config.initializer_range),
            name="v_proj",
        )(x)

        # Split the model dimension into heads: (batch, seq, n_heads, head_dim).
        q = q.reshape(batch_size, seq_len, config.n_heads, config.head_dim)
        k = k.reshape(batch_size, seq_len, config.n_heads, config.head_dim)
        v = v.reshape(batch_size, seq_len, config.n_heads, config.head_dim)

        # Rotate queries and keys with the precomputed RoPE tables.
        q = apply_rope(q, sin, cos)
        k = apply_rope(k, sin, cos)

        # Move heads before the sequence axis: (batch, n_heads, seq, head_dim).
        q = jnp.transpose(q, (0, 2, 1, 3))
        k = jnp.transpose(k, (0, 2, 1, 3))
        v = jnp.transpose(v, (0, 2, 1, 3))

        scale = 1.0 / math.sqrt(config.head_dim)

        # Run the attention matmuls in bfloat16; the softmax below is done in float32.
        q = q.astype(jnp.bfloat16)
        k = k.astype(jnp.bfloat16)
        v = v.astype(jnp.bfloat16)

        attn_weights = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale

        # Mask convention: True means "may attend"; masked positions receive a large
        # negative score so they vanish after the softmax.
        if mask is not None:
            attn_weights = jnp.where(mask, attn_weights, jnp.finfo(jnp.bfloat16).min)

        # Softmax in float32 for numerical stability, then cast back to bfloat16.
        attn_weights = jax.nn.softmax(attn_weights.astype(jnp.float32), axis=-1)
        attn_weights = attn_weights.astype(jnp.bfloat16)

        if not deterministic:
            attn_weights = nn.Dropout(
                rate=config.attention_dropout,
                deterministic=deterministic,
            )(attn_weights)

        # Weighted sum of values, then merge heads back into d_model.
        attn_output = jnp.einsum("bhqk,bhkd->bhqd", attn_weights, v)
        attn_output = jnp.transpose(attn_output, (0, 2, 1, 3))
        attn_output = attn_output.reshape(batch_size, seq_len, config.d_model)

        # Output projection.
        output = nn.Dense(
            config.d_model,
            use_bias=config.use_bias,
            kernel_init=initializers.normal(config.initializer_range),
            name="o_proj",
        )(attn_output)
        return output
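

# Illustrative helper (not part of the original module): Attention expects `mask`
# to be a boolean array in the "True = may attend" convention used by the
# jnp.where call above. A causal mask for decoder-style attention can be built as
# a lower-triangular boolean matrix broadcastable to (batch, heads, q_len, k_len).
def make_causal_mask(seq_len: int) -> jnp.ndarray:
    """Boolean causal mask of shape (1, 1, seq_len, seq_len); True = may attend."""
    causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    return causal[None, None, :, :]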


class FeedForward(nn.Module):
    """SwiGLU feed-forward network: down_proj(silu(gate_proj(x)) * up_proj(x))."""

    config: JulianConfig

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
        deterministic: bool = True,
    ) -> jnp.ndarray:
        config = self.config

        # Gate and up projections expand d_model -> d_ff.
        gate = nn.Dense(
            config.d_ff,
            use_bias=config.use_bias,
            kernel_init=initializers.normal(config.initializer_range),
            name="gate_proj",
        )(x)
        up = nn.Dense(
            config.d_ff,
            use_bias=config.use_bias,
            kernel_init=initializers.normal(config.initializer_range),
            name="up_proj",
        )(x)

        # SwiGLU gating: the SiLU-activated gate modulates the up projection.
        hidden = jax.nn.silu(gate) * up

        if not deterministic:
            hidden = nn.Dropout(
                rate=config.dropout,
                deterministic=deterministic,
            )(hidden)

        # Project back down to the model dimension.
        output = nn.Dense(
            config.d_model,
            use_bias=config.use_bias,
            kernel_init=initializers.normal(config.initializer_range),
            name="down_proj",
        )(hidden)
        return output


class TransformerBlock(nn.Module):
    """Single pre-norm transformer decoder block: attention and MLP sub-layers with residuals."""

    config: JulianConfig

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
        sin: jnp.ndarray,
        cos: jnp.ndarray,
        mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
    ) -> jnp.ndarray:
        config = self.config

        # Self-attention sub-layer: pre-norm, attention, dropout, residual.
        residual = x
        x = RMSNorm(config.d_model, config.rms_norm_eps, name="input_layernorm")(x)
        x = Attention(config, name="self_attn")(x, sin, cos, mask, deterministic)
        if not deterministic:
            x = nn.Dropout(rate=config.dropout, deterministic=deterministic)(x)
        x = residual + x

        # Feed-forward sub-layer: pre-norm, SwiGLU MLP, dropout, residual.
        residual = x
        x = RMSNorm(config.d_model, config.rms_norm_eps, name="post_attention_layernorm")(x)
        x = FeedForward(config, name="mlp")(x, deterministic)
        if not deterministic:
            x = nn.Dropout(rate=config.dropout, deterministic=deterministic)(x)
        x = residual + x

        return x
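

# Minimal forward-pass sketch (illustrative; not part of the original module). It
# relies only on the JulianConfig fields already used above (d_model, n_heads,
# head_dim, d_ff, rms_norm_eps, ...) and is handy as a quick shape check from a
# REPL or test; the function name is a placeholder.
def _example_forward(config: JulianConfig, batch_size: int = 2, seq_len: int = 8) -> jnp.ndarray:
    """Run one TransformerBlock on random inputs and return its output."""
    rng = jax.random.PRNGKey(0)
    x = jax.random.normal(rng, (batch_size, seq_len, config.d_model))
    # RoPE tables sized for the example sequence length.
    sin, cos = precompute_rope_frequencies(config.head_dim, seq_len)
    # Causal mask in the "True = may attend" convention used by Attention.
    causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))[None, None, :, :]
    block = TransformerBlock(config)
    params = block.init(rng, x, sin, cos, causal_mask)
    return block.apply(params, x, sin, cos, causal_mask)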