test-old / layers.py
SeaWolf-AI's picture
Upload 6 files
ca19627 verified
"""
AETHER-Net Attention Layers
5 types: GDN, Full, Mamba2, Sliding Window, Cross Attention
Each layer follows the same interface:
forward(hidden_states, attention_mask=None, position_ids=None, **kwargs) -> hidden_states
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learnable per-channel scale.

    The mean of squares over the last dimension is computed in float32 and the
    result is returned in the dtype of the input tensor.

    Args:
        hidden_size: Size of the last (normalized) dimension.
        eps: Constant added to the variance before the reciprocal square root.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` over its last dimension and apply the learned scale."""
        # Capture the dtype up front: multiplying by the float32 statistics
        # promotes half-precision inputs, so reading ``.dtype`` afterwards
        # would no longer give the caller's dtype.
        input_dtype = x.dtype
        variance = x.float().pow(2).mean(-1, keepdim=True)
        normed = x * torch.rsqrt(variance + self.eps)
        return (self.weight * normed).to(input_dtype)
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second.

    For ``x = [a, b]`` (split along the last axis) this returns ``[-b, a]``.
    """
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    return torch.cat([second_half.neg(), first_half], dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embeddings to a query/key pair.

    Each tensor ``t`` is mapped to ``t * cos + rotate_half(t) * sin``; ``cos``
    and ``sin`` must broadcast against both ``q`` and ``k``.

    Returns:
        Tuple ``(q_embed, k_embed)``.
    """

    def _rotate(t):
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)
class RotaryEmbedding(nn.Module):
    """Compute the cos/sin tables used for rotary position embeddings.

    Args:
        dim: Rotary dimension; ``dim // 2`` inverse frequencies are created.
        max_seq_len: Stored on the module; not used by ``forward``.
        theta: Base of the geometric frequency progression.
    """

    def __init__(self, dim: int, max_seq_len: int = 262144, theta: float = 10000000.0):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        # Non-persistent: recomputed at construction, never stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.max_seq_len = max_seq_len

    def forward(self, x, position_ids=None):
        """Return ``(cos, sin)``, each of shape ``[1, L, dim]``.

        Args:
            x: Reference tensor. Its dtype is used for the returned tables and,
                when ``position_ids`` is ``None``, ``x.shape[1]`` is taken as
                the sequence length (callers in this file pass hidden states
                of shape ``[B, L, D]``).
            position_ids: ``[L]`` or ``[B, L]`` positions. For a 2-D tensor
                only the first row is used, i.e. every sample in the batch is
                assumed to share the same positions. ``None`` means
                ``0 .. L-1``.
        """
        if position_ids is None:
            # The layer interface defaults position_ids to None; fall back to
            # consecutive positions instead of failing on ``None.dim()``.
            pos = torch.arange(x.shape[1], device=x.device)
        elif position_ids.dim() == 2:
            pos = position_ids[0]
        else:
            pos = position_ids

        # Angles are always computed in float32 for accuracy.
        inv_freq = self.inv_freq.to(device=pos.device, dtype=torch.float32)
        freqs = torch.outer(pos.float(), inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().unsqueeze(0)
        sin = emb.sin().unsqueeze(0)

        # Hand the tables back in the activation dtype so that applying them
        # does not promote half-precision q/k to float32 (which would later
        # clash with v in matmul, which requires matching dtypes).
        if torch.is_tensor(x) and x.is_floating_point():
            cos = cos.to(x.dtype)
            sin = sin.to(x.dtype)
        return cos, sin
# ═══════════════════════════════════════════════════════════
# 1. FULL ATTENTION (Softmax, GQA, RoPE) β€” O(nΒ²)
# ═══════════════════════════════════════════════════════════
class FullAttention(nn.Module):
    """Causal softmax attention with grouped-query KV heads, RoPE and an output gate.

    Every call recomputes attention over the whole input sequence; no KV cache
    is implemented in this class.

    Config attributes read: ``num_attention_heads``, ``num_kv_heads``,
    ``head_dim``, ``hidden_size``, ``max_position_embeddings``, ``rope_theta``.
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim
        # Number of query heads that share one key/value head.
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
        # Sigmoid output gate computed from the layer input.
        self.gate = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.rotary_emb = RotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """Run causal attention over ``hidden_states``.

        Args:
            hidden_states: ``[B, L, hidden_size]`` input.
            attention_mask: Optional additive mask; it is added to the
                pre-softmax scores and must broadcast to ``[B, H, L, L]``.
            position_ids: Optional ``[L]`` or ``[B, L]`` positions for RoPE;
                defaults to ``0 .. L-1``.

        Returns:
            Tensor of shape ``[B, L, hidden_size]``.
        """
        B, L, _ = hidden_states.shape
        if position_ids is None:
            # The documented default is None; use consecutive positions.
            position_ids = torch.arange(L, device=hidden_states.device)

        q = self.q_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # RoPE: add a head axis so the tables broadcast over [B, H, L, D].
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        q, k = apply_rotary_pos_emb(q, k, cos.unsqueeze(1), sin.unsqueeze(1))
        # The rotary tables may be float32; bring q/k back to the activation
        # dtype so the matmul with v below sees matching dtypes.
        q = q.to(v.dtype)
        k = k.to(v.dtype)

        # GQA: replicate each KV head for the query heads that share it.
        if self.num_kv_groups > 1:
            k = k.repeat_interleave(self.num_kv_groups, dim=1)
            v = v.repeat_interleave(self.num_kv_groups, dim=1)

        # Scaled dot-product scores.
        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Causal mask: -inf strictly above the diagonal.
        causal = torch.triu(torch.full((L, L), float('-inf'), device=attn.device), diagonal=1)
        attn = attn + causal.unsqueeze(0).unsqueeze(0)
        if attention_mask is not None:
            attn = attn + attention_mask

        # Softmax in float32, then back to the activation dtype.
        attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, L, -1)

        # Output gating.
        gate = torch.sigmoid(self.gate(hidden_states))
        return self.o_proj(out * gate)
# ═══════════════════════════════════════════════════════════
# 2. GATED DELTANET (GDN) β€” O(n) linear time
# ═══════════════════════════════════════════════════════════
class GatedDeltaNet(nn.Module):
    """Gated linear-attention layer with a per-head recurrent fast-weight state.

    For every time step the layer keeps a ``[head_dim, head_dim]`` state per
    head and updates it as implemented in ``forward``::

        erase = (beta * k) q^T            # outer product
        write = (beta * k) v^T            # outer product
        state = alpha * (state - state * erase) + write   # '*' is element-wise
        o     = q^T @ state

    followed by a SiLU output gate and the output projection.

    NOTE(review): the author's stated formula is
    ``M_t = α_t * M_{t-1} * (I - k_t * q_t^T) + k_t * v_t^T``; read as a matrix
    product that differs from the element-wise ``state * erase`` used here.
    Confirm which one is intended before changing the math.

    Config attributes read: ``hidden_size``, ``num_attention_heads``,
    ``head_dim``, ``gdn_state_size`` (stored as ``state_size`` but not used by
    ``forward``).
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        self.hidden_size = hidden
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.state_size = config.gdn_state_size
        inner = self.num_heads * self.head_dim

        # Query / key / value / output projections.
        self.q_proj = nn.Linear(hidden, inner, bias=False)
        self.k_proj = nn.Linear(hidden, inner, bias=False)
        self.v_proj = nn.Linear(hidden, inner, bias=False)
        self.o_proj = nn.Linear(inner, hidden, bias=False)
        # Per-head decay gate (alpha): how much of the previous state survives.
        self.decay_proj = nn.Linear(hidden, self.num_heads, bias=True)
        # Per-head update gate (beta): strength of the erase/write step.
        self.beta_proj = nn.Linear(hidden, self.num_heads, bias=True)
        # Output gate, passed through SiLU in forward.
        self.gate = nn.Linear(hidden, inner, bias=False)
        # Depthwise short convolution giving q/k some local context.
        self.conv1d = nn.Conv1d(
            in_channels=hidden,
            out_channels=hidden,
            kernel_size=4,
            padding=3,
            groups=hidden,
            bias=True,
        )

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """Run the recurrent scan over ``hidden_states`` (``[B, L, hidden_size]``).

        ``attention_mask`` and ``position_ids`` are accepted for interface
        parity with the other layers and are not used.
        """
        batch, seq_len, _ = hidden_states.shape
        per_head = (batch, seq_len, self.num_heads, self.head_dim)

        # The conv pads 3 steps on both sides; keeping only the first seq_len
        # outputs makes it causal (position t sees inputs t-3 .. t).
        mixed = self.conv1d(hidden_states.transpose(1, 2))
        mixed = mixed[..., :seq_len].transpose(1, 2)

        # q and k come from the conv-mixed input and are L2-normalized;
        # v comes from the raw input.
        q = F.normalize(self.q_proj(mixed).view(per_head), p=2, dim=-1)
        k = F.normalize(self.k_proj(mixed).view(per_head), p=2, dim=-1)
        v = self.v_proj(hidden_states).view(per_head)

        # Gates in (0, 1), shaped [B, L, H, 1].
        alpha = torch.sigmoid(self.decay_proj(hidden_states)).unsqueeze(-1)
        beta = torch.sigmoid(self.beta_proj(hidden_states)).unsqueeze(-1)

        # Sequential scan over time; state is [B, H, head_dim, head_dim].
        state = hidden_states.new_zeros(batch, self.num_heads, self.head_dim, self.head_dim)
        step_outputs = []
        steps = zip(q.unbind(1), k.unbind(1), v.unbind(1), alpha.unbind(1), beta.unbind(1))
        for q_t, k_t, v_t, a_t, b_t in steps:
            erase = torch.einsum('bhd,bhe->bhde', k_t * b_t, q_t)
            write = torch.einsum('bhd,bhe->bhde', k_t * b_t, v_t)
            state = a_t.unsqueeze(-1) * (state - state * erase) + write
            # Read-out: o_t = q_t^T @ state.
            step_outputs.append(torch.einsum('bhd,bhde->bhe', q_t, state))

        out = torch.stack(step_outputs, dim=1)  # [B, L, H, head_dim]
        out = out.reshape(batch, seq_len, -1)

        # SiLU output gate computed from the layer input.
        out = out * F.silu(self.gate(hidden_states))
        return self.o_proj(out)
# ═══════════════════════════════════════════════════════════
# 3. MAMBA2 β€” O(n) with SSM state-space duality
# ═══════════════════════════════════════════════════════════
class Mamba2Block(nn.Module):
    """Selective state-space block with a sequential per-head scan.

    Pipeline: input projection split into a gate ``z`` and an SSM input ``x``,
    causal depthwise conv + SiLU on ``x``, per-head discretized recurrence,
    ``D`` skip term, RMSNorm, SiLU(z) gating, output projection.

    NOTE(review): the recurrence keeps a ``[B, H, state_size]`` state and is
    driven by only the first channel of each head; the per-head scalar output
    is then broadcast across ``head_dim``. This looks like a simplified
    stand-in for a full Mamba-2 scan — confirm intent.

    Config attributes read: ``hidden_size``, ``mamba2_expand``,
    ``mamba2_state_size``, ``mamba2_conv_size``, ``num_attention_heads``.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.inner_size = config.hidden_size * config.mamba2_expand
        self.state_size = config.mamba2_state_size
        self.conv_size = config.mamba2_conv_size
        self.num_heads = config.num_attention_heads
        ssm_width = self.state_size * self.num_heads

        # Single projection producing both the gate z and the SSM input x.
        self.in_proj = nn.Linear(config.hidden_size, self.inner_size * 2, bias=False)
        # Depthwise conv; padded on both sides and truncated in forward.
        self.conv1d = nn.Conv1d(
            self.inner_size,
            self.inner_size,
            kernel_size=self.conv_size,
            padding=self.conv_size - 1,
            groups=self.inner_size,
            bias=True,
        )
        # Per-head step size, decay (stored as log) and skip weight.
        self.dt_proj = nn.Linear(self.inner_size, self.num_heads, bias=True)
        self.A_log = nn.Parameter(torch.log(torch.arange(1, self.num_heads + 1, dtype=torch.float32)))
        self.D = nn.Parameter(torch.ones(self.num_heads))
        # Input-dependent B and C, one state_size vector per head.
        self.B_proj = nn.Linear(self.inner_size, ssm_width, bias=False)
        self.C_proj = nn.Linear(self.inner_size, ssm_width, bias=False)
        # Output projection and pre-gate normalization.
        self.out_proj = nn.Linear(self.inner_size, config.hidden_size, bias=False)
        self.norm = RMSNorm(self.inner_size)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """Process ``hidden_states`` (``[B, L, hidden_size]``).

        ``attention_mask`` and ``position_ids`` are accepted for interface
        parity with the other layers and are not used.
        """
        batch, seq_len, _ = hidden_states.shape
        heads = self.num_heads
        head_dim = self.inner_size // heads

        # Split the joint projection into gate and SSM input.
        z, x = self.in_proj(hidden_states).chunk(2, dim=-1)

        # Keeping the first seq_len conv outputs makes the conv causal.
        x = self.conv1d(x.transpose(1, 2))[..., :seq_len].transpose(1, 2)
        x = F.silu(x)

        # Input-dependent SSM parameters.
        A = -torch.exp(self.A_log)                       # [H], strictly negative
        dt = F.softplus(self.dt_proj(x)).unsqueeze(-1)   # [B, L, H, 1]
        B_state = self.B_proj(x).view(batch, seq_len, heads, self.state_size)
        C_state = self.C_proj(x).view(batch, seq_len, heads, self.state_size)

        # Discretization: A_bar = exp(dt * A), B_bar = dt * B.
        A_bar = torch.exp(dt * A.view(1, 1, -1, 1))      # [B, L, H, 1]
        B_bar = dt * B_state                             # [B, L, H, N]

        # Sequential scan; only channel 0 of each head feeds the state.
        x_heads = x.view(batch, seq_len, heads, head_dim)
        drive = x_heads[..., :1]                         # [B, L, H, 1]
        state = x.new_zeros(batch, heads, self.state_size)
        readouts = []
        for t in range(seq_len):
            state = A_bar[:, t] * state + B_bar[:, t] * drive[:, t]
            readouts.append((state * C_state[:, t]).sum(dim=-1))  # [B, H]
        y = torch.stack(readouts, dim=1)                 # [B, L, H]

        # Skip connection: D times the per-head mean of x.
        y = y + self.D.view(1, 1, -1) * x_heads.mean(-1)

        # Broadcast each head's scalar over head_dim, normalize, gate with z.
        y = y.unsqueeze(-1).expand(-1, -1, -1, head_dim).reshape(batch, seq_len, self.inner_size)
        y = self.norm(y)
        y = y * F.silu(z)
        return self.out_proj(y)
# ═══════════════════════════════════════════════════════════
# 4. SLIDING WINDOW ATTENTION β€” O(n * w)
# ═══════════════════════════════════════════════════════════
class SlidingWindowAttention(nn.Module):
    """Causal grouped-query attention restricted to a local window, with RoPE.

    Query position ``i`` may attend to key positions ``j`` with
    ``i - window_size < j <= i``. Scores for the full ``L x L`` matrix are
    still materialized; the window is enforced by masking.

    Config attributes read: ``num_attention_heads``, ``num_kv_heads``,
    ``head_dim``, ``hidden_size``, ``sliding_window_size``,
    ``max_position_embeddings``, ``rope_theta``.
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim
        self.window_size = config.sliding_window_size
        # Number of query heads that share one key/value head.
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
        # Sigmoid output gate computed from the layer input.
        self.gate = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.rotary_emb = RotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """Run windowed causal attention over ``hidden_states``.

        Args:
            hidden_states: ``[B, L, hidden_size]`` input.
            attention_mask: Optional additive mask; added to the pre-softmax
                scores (same convention as ``FullAttention``) and must
                broadcast to ``[B, H, L, L]``.
            position_ids: Optional ``[L]`` or ``[B, L]`` positions for RoPE;
                defaults to ``0 .. L-1``.

        Returns:
            Tensor of shape ``[B, L, hidden_size]``.
        """
        B, L, _ = hidden_states.shape
        if position_ids is None:
            # The documented default is None; use consecutive positions.
            position_ids = torch.arange(L, device=hidden_states.device)

        q = self.q_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # RoPE: add a head axis so the tables broadcast over [B, H, L, D].
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        q, k = apply_rotary_pos_emb(q, k, cos.unsqueeze(1), sin.unsqueeze(1))
        # The rotary tables may be float32; bring q/k back to the activation
        # dtype so the matmul with v below sees matching dtypes.
        q = q.to(v.dtype)
        k = k.to(v.dtype)

        # GQA: replicate each KV head for the query heads that share it.
        if self.num_kv_groups > 1:
            k = k.repeat_interleave(self.num_kv_groups, dim=1)
            v = v.repeat_interleave(self.num_kv_groups, dim=1)

        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # True marks a blocked pair: future keys (strict upper triangle) or
        # keys that lie window_size or more steps in the past.
        mask = torch.ones(L, L, device=attn.device, dtype=torch.bool)
        mask = torch.triu(mask, diagonal=1)
        mask = mask | torch.tril(torch.ones_like(mask), diagonal=-self.window_size)
        attn = attn.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))

        # Previously the caller's mask was accepted but silently dropped.
        if attention_mask is not None:
            attn = attn + attention_mask

        # Softmax in float32, then back to the activation dtype.
        attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, L, -1)

        # Output gating.
        gate = torch.sigmoid(self.gate(hidden_states))
        return self.o_proj(out * gate)
# ═══════════════════════════════════════════════════════════
# 5. CROSS ATTENTION β€” for multimodal / tool bridging
# ═══════════════════════════════════════════════════════════
class CrossAttention(nn.Module):
    """Causal self-attention optionally blended with attention over external states.

    Without ``encoder_hidden_states`` the layer is plain multi-head causal
    self-attention. With them, a second attention over the external sequence
    is computed from the same queries and the two results are mixed by a
    learned per-token sigmoid gate (bias initialised to -2, so the mix starts
    out dominated by self-attention). A sigmoid output gate and the output
    projection follow in both cases. No rotary embedding is applied here.

    Per the original author's notes this layer is meant to bridge to external
    modalities (PROMETHEUS / HEPHAESTUS).

    Config attributes read: ``num_attention_heads``, ``head_dim``,
    ``hidden_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        # Self-attention path (always used).
        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
        # Cross-attention path (used when external context is supplied).
        self.cross_k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.cross_v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        # Modality gate: interpolates between self and cross outputs.
        self.modality_gate = nn.Linear(config.hidden_size, 1, bias=True)
        nn.init.constant_(self.modality_gate.bias, -2.0)  # default: mostly self-attention
        self.gate = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)

    def _split_heads(self, proj, states, batch):
        """Project ``states`` and reshape to ``[B, H, S, head_dim]``."""
        return proj(states).view(batch, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def _attend(self, q, k, v, bias=None):
        """Softmax attention of ``q`` over ``k``/``v``; returns ``[B, L, H * head_dim]``."""
        batch, _, q_len, _ = q.shape
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if bias is not None:
            scores = scores + bias
        # Softmax in float32, then back to the activation dtype.
        probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
        out = torch.matmul(probs, v)
        return out.transpose(1, 2).contiguous().view(batch, q_len, -1)

    def _self_attention(self, q, hidden_states, attention_mask):
        """Causal self-attention over ``hidden_states`` using pre-computed queries."""
        batch, seq_len, _ = hidden_states.shape
        k = self._split_heads(self.k_proj, hidden_states, batch)
        v = self._split_heads(self.v_proj, hidden_states, batch)
        # -inf strictly above the diagonal blocks attention to future tokens.
        causal = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=q.device), diagonal=1
        )
        bias = causal.unsqueeze(0).unsqueeze(0)
        if attention_mask is not None:
            bias = bias + attention_mask
        return self._attend(q, k, v, bias)

    def forward(self, hidden_states, attention_mask=None, position_ids=None,
                encoder_hidden_states=None, **kwargs):
        """Run the layer.

        Args:
            hidden_states: ``[B, L, hidden_size]`` input.
            attention_mask: Optional additive mask applied to the
                self-attention scores (same convention as ``FullAttention``);
                must broadcast to ``[B, H, L, L]``. It is not applied to the
                cross-attention scores, whose key length differs.
            position_ids: Accepted for interface parity; not used.
            encoder_hidden_states: Optional ``[B, S, hidden_size]`` external
                sequence to cross-attend to.

        Returns:
            Tensor of shape ``[B, L, hidden_size]``.
        """
        batch = hidden_states.shape[0]
        q = self._split_heads(self.q_proj, hidden_states, batch)

        # The self-attention path runs in both modes.
        out = self._self_attention(q, hidden_states, attention_mask)

        if encoder_hidden_states is not None:
            # Cross-attention over the external sequence; no causal mask.
            k_cross = self._split_heads(self.cross_k_proj, encoder_hidden_states, batch)
            v_cross = self._split_heads(self.cross_v_proj, encoder_hidden_states, batch)
            out_cross = self._attend(q, k_cross, v_cross)
            # Per-token blend between the two paths.
            mg = torch.sigmoid(self.modality_gate(hidden_states))
            out = mg * out_cross + (1 - mg) * out

        gate = torch.sigmoid(self.gate(hidden_states))
        return self.o_proj(out * gate)
# ═══════════════════════════════════════════════════════════
# Factory
# ═══════════════════════════════════════════════════════════
# Registry mapping the layer-type string used in configs to its layer class.
# Insertion order is preserved and shows up in build_attention's error message.
ATTENTION_CLASSES = dict(
    gdn=GatedDeltaNet,
    full=FullAttention,
    mamba2=Mamba2Block,
    slide=SlidingWindowAttention,
    cross=CrossAttention,
)
def build_attention(layer_type: str, config):
    """Instantiate the layer class registered under ``layer_type``.

    Args:
        layer_type: A key of ``ATTENTION_CLASSES``.
        config: Passed unchanged to the layer constructor.

    Raises:
        ValueError: If ``layer_type`` is not a registered key.
    """
    if layer_type not in ATTENTION_CLASSES:
        raise ValueError(f"Unknown attention type: {layer_type}. Choose from {list(ATTENTION_CLASSES.keys())}")
    layer_cls = ATTENTION_CLASSES[layer_type]
    return layer_cls(config)