"""
AETHER-Net Attention Layers
5 types: GDN, Full, Mamba2, Sliding Window, Cross Attention

Each layer follows the same interface:
    forward(hidden_states, attention_mask=None, position_ids=None, **kwargs) -> hidden_states
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple


class RMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned per-channel scale.

    Statistics are computed in float32 and the result is cast back to the
    input dtype, so reduced-precision activations stay reduced-precision.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        # Capture the caller's dtype *before* upcasting; x is rebound below.
        input_dtype = x.dtype
        x = x.float()
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return (self.weight * x).to(input_dtype)


def rotate_half(x):
    """Split the last dim in two halves (x1, x2) and return cat(-x2, x1)."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embedding to q and k (cos/sin must broadcast to them)."""
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) cos/sin tables.

    Args:
        dim: rotary dimension (the attention head size in this file).
        max_seq_len: stored on the module; not used to bound positions here.
        theta: RoPE base frequency.
    """

    def __init__(self, dim: int, max_seq_len: int = 262144, theta: float = 10000000.0):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.max_seq_len = max_seq_len

    def forward(self, x, position_ids=None):
        """Return ``(cos, sin)`` for the given positions.

        Args:
            x: activations shaped ``[B, L, ...]``; supplies the output dtype and,
                when ``position_ids`` is None, the device and sequence length.
            position_ids: ``[L]`` or ``[B, L]`` positions, or None for ``0..L-1``.

        Returns:
            cos, sin shaped ``[1, L, dim]`` for 1-D / missing positions and
            ``[B, L, dim]`` for 2-D positions (each batch row gets its own
            positions), in ``x.dtype``.
        """
        if position_ids is None:
            position_ids = torch.arange(x.shape[1], device=x.device)
        pos = position_ids.float()
        # Angles are computed in float32 for precision at large positions.
        inv_freq = self.inv_freq.to(device=pos.device, dtype=pos.dtype)
        freqs = pos.unsqueeze(-1) * inv_freq
        emb = torch.cat((freqs, freqs), dim=-1)
        if emb.dim() == 2:
            emb = emb.unsqueeze(0)
        # Match the activation dtype so RoPE does not silently upcast q/k
        # (a float32 q/k against a bf16 v makes the attention matmul fail).
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)


# ═══════════════════════════════════════════════════════════
# 1. FULL ATTENTION (Softmax, GQA, RoPE) — O(n²)
# ═══════════════════════════════════════════════════════════

class FullAttention(nn.Module):
    """Standard grouped-query attention with RoPE.
    Kept for 5 layers — provides precise token-to-token reasoning.

    NOTE: no KV cache is implemented here; each call attends only within the
    ``hidden_states`` it receives (query and key lengths are equal).
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim
        self.num_kv_groups = self.num_heads // self.num_kv_heads

        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)

        # Output gate (Qwen3.5 style gated attention)
        self.gate = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)

        self.rotary_emb = RotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """Causal GQA over ``hidden_states`` of shape ``[B, L, hidden]``.

        ``attention_mask``, if given, is added to the ``[B, H, L, L]`` scores
        (additive mask: 0 keeps, -inf drops).
        """
        B, L, _ = hidden_states.shape

        q = self.q_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # RoPE: [*, L, D] -> [*, 1, L, D] to broadcast over heads
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        q, k = apply_rotary_pos_emb(q, k, cos.unsqueeze(1), sin.unsqueeze(1))

        # GQA: expand KV heads to match query heads
        if self.num_kv_groups > 1:
            k = k.repeat_interleave(self.num_kv_groups, dim=1)
            v = v.repeat_interleave(self.num_kv_groups, dim=1)

        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Causal mask: -inf strictly above the diagonal
        causal = torch.triu(torch.full((L, L), float('-inf'), device=attn.device), diagonal=1)
        attn = attn + causal
        if attention_mask is not None:
            attn = attn + attention_mask

        attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, L, -1)

        # Output gating
        gate = torch.sigmoid(self.gate(hidden_states))
        return self.o_proj(out * gate)


# ═══════════════════════════════════════════════════════════
# 2. GATED DELTANET (GDN) — O(n) linear time
# ═══════════════════════════════════════════════════════════

class GatedDeltaNet(nn.Module):
    """Gated DeltaNet: Mamba-style gating + DeltaNet fast-weight update.

    Core linear attention mechanism — 10 layers (40% of model).

    Per head, with a key x value state S (read as o_t = S_t^T q_t):
        S_t = alpha_t * (I - beta_t * k_t k_t^T) S_{t-1} + beta_t * k_t v_t^T
    i.e. decay the memory, erase what it currently recalls for k_t, and write
    v_t in its place. SiLU output gating is applied for gradient flow stability.

    Weight transplant: Q,K,V projections map directly from Qwen3.5 GDN layers.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.state_size = config.gdn_state_size

        # Input projections (transplantable from Qwen3.5 GDN)
        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)

        # Decay gate (alpha): controls memory decay speed
        self.decay_proj = nn.Linear(config.hidden_size, self.num_heads, bias=True)
        # Update gate (beta): controls state update strength
        self.beta_proj = nn.Linear(config.hidden_size, self.num_heads, bias=True)

        # Output gate (SiLU activation for gradient stability)
        self.gate = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)

        # Short depthwise convolution for local context (replaces positional encoding)
        self.conv1d = nn.Conv1d(
            in_channels=config.hidden_size,
            out_channels=config.hidden_size,
            kernel_size=4, padding=3,
            groups=config.hidden_size,
            bias=True,
        )

    @staticmethod
    def _delta_rule_scan(q, k, v, alpha, beta):
        """Sequential gated delta rule.

        Args:
            q, k: ``[B, L, H, Dk]``; v: ``[B, L, H, Dv]``;
            alpha, beta: ``[B, L, H, 1]`` gates.

        Returns:
            ``[B, L, H, Dv]`` outputs in float32 (state accumulated in float32).
        """
        bsz, seq_len, n_heads, dk = k.shape
        dv = v.shape[-1]
        q, k, v = q.float(), k.float(), v.float()
        alpha, beta = alpha.float(), beta.float()

        state = k.new_zeros(bsz, n_heads, dk, dv)
        outputs = []
        for t in range(seq_len):
            k_t = k[:, t]                                                # [B, H, Dk]
            state = alpha[:, t].unsqueeze(-1) * state                    # decay
            recalled = torch.einsum('bhd,bhde->bhe', k_t, state)         # k_t^T S
            delta = (v[:, t] - recalled) * beta[:, t]                    # [B, H, Dv]
            state = state + torch.einsum('bhd,bhe->bhde', k_t, delta)    # erase + write
            outputs.append(torch.einsum('bhd,bhde->bhe', q[:, t], state))
        if not outputs:
            return v.new_zeros(bsz, 0, n_heads, dv)
        return torch.stack(outputs, dim=1)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """Run GDN over ``[B, L, hidden]``; ``attention_mask``/``position_ids`` are not used."""
        B, L, _ = hidden_states.shape

        # Local context mixing: padding=3 on both sides, keeping the first L
        # outputs makes the depthwise conv causal.
        conv_out = self.conv1d(hidden_states.transpose(1, 2))[..., :L].transpose(1, 2)

        q = self.q_proj(conv_out).view(B, L, self.num_heads, self.head_dim)
        k = self.k_proj(conv_out).view(B, L, self.num_heads, self.head_dim)
        v = self.v_proj(hidden_states).view(B, L, self.num_heads, self.head_dim)

        # L2 normalize Q, K (replaces softmax normalization)
        q = F.normalize(q, p=2, dim=-1)
        k = F.normalize(k, p=2, dim=-1)

        # Decay and update gates in (0, 1): [B, L, H, 1]
        alpha = torch.sigmoid(self.decay_proj(hidden_states)).unsqueeze(-1)
        beta = torch.sigmoid(self.beta_proj(hidden_states)).unsqueeze(-1)

        out = self._delta_rule_scan(q, k, v, alpha, beta).to(hidden_states.dtype)
        out = out.reshape(B, L, -1)

        # Output gating with SiLU
        gate = F.silu(self.gate(hidden_states))
        return self.o_proj(out * gate)


# ═══════════════════════════════════════════════════════════
# 3. MAMBA2 — O(n) with SSM state-space duality
# ═══════════════════════════════════════════════════════════

class Mamba2Block(nn.Module):
    """Mamba-2 block with Structured State Space Duality.

    5 layers — provides state compression for memory efficiency.
    Weight transplant: Via MOHAWK SSD duality from Llama-3.1 Q,K,V → C,B,X.

    Each head h holds a ``[head_dim, state_size]`` state per channel:
        h_t = exp(dt_t * A_h) * h_{t-1} + x_t ⊗ (dt_t * B_t)
        y_t = h_t · C_t + D_h * x_t
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        expand = config.mamba2_expand
        self.inner_size = config.hidden_size * expand
        self.state_size = config.mamba2_state_size
        self.conv_size = config.mamba2_conv_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.inner_size // self.num_heads

        # Input projection: x → (z, x_ssm) split
        self.in_proj = nn.Linear(config.hidden_size, self.inner_size * 2, bias=False)

        # Causal depthwise conv1d (pad both sides, keep first L outputs)
        self.conv1d = nn.Conv1d(
            self.inner_size, self.inner_size,
            kernel_size=self.conv_size,
            padding=self.conv_size - 1,
            groups=self.inner_size,
            bias=True,
        )

        # SSM parameters
        self.dt_proj = nn.Linear(self.inner_size, self.num_heads, bias=True)
        self.A_log = nn.Parameter(torch.log(torch.arange(1, self.num_heads + 1, dtype=torch.float32)))
        self.D = nn.Parameter(torch.ones(self.num_heads))

        # B, C projections (state-space), one set per head
        self.B_proj = nn.Linear(self.inner_size, self.state_size * self.num_heads, bias=False)
        self.C_proj = nn.Linear(self.inner_size, self.state_size * self.num_heads, bias=False)

        # Output
        self.out_proj = nn.Linear(self.inner_size, config.hidden_size, bias=False)
        self.norm = RMSNorm(self.inner_size)

    @staticmethod
    def _ssm_scan(x, A_bar, B_bar, C):
        """Sequential selective scan (replace with a parallel kernel for speed).

        Args:
            x: ``[B, L, H, P]`` per-channel inputs.
            A_bar: ``[B, L, H, 1]`` discretized decay.
            B_bar, C: ``[B, L, H, N]`` input / output state projections.

        Returns:
            ``[B, L, H, P]`` outputs (without the D skip term).
        """
        bsz, seq_len, n_heads, p = x.shape
        n = B_bar.shape[-1]
        state = x.new_zeros(bsz, n_heads, p, n)
        outputs = []
        for t in range(seq_len):
            # Every channel p gets its own state row: outer product x_t ⊗ B_bar_t
            state = (A_bar[:, t].unsqueeze(-1) * state
                     + x[:, t].unsqueeze(-1) * B_bar[:, t].unsqueeze(-2))
            outputs.append(torch.einsum('bhpn,bhn->bhp', state, C[:, t]))
        if not outputs:
            return x.new_zeros(bsz, 0, n_heads, p)
        return torch.stack(outputs, dim=1)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """Run the block over ``[B, L, hidden]``; ``attention_mask``/``position_ids`` are not used."""
        bsz, L, _ = hidden_states.shape
        H, P, N = self.num_heads, self.head_dim, self.state_size

        # Input split
        z, x = self.in_proj(hidden_states).chunk(2, dim=-1)

        # Causal conv
        x = self.conv1d(x.transpose(1, 2))[..., :L].transpose(1, 2)
        x = F.silu(x)

        # SSM parameters (computed in float32 for a stable scan)
        A = -torch.exp(self.A_log.float())                   # [H], negative => decay
        dt = F.softplus(self.dt_proj(x)).float()              # [B, L, H]
        B_ssm = self.B_proj(x).view(bsz, L, H, N).float()
        C_ssm = self.C_proj(x).view(bsz, L, H, N).float()

        # Discretize: A_bar = exp(dt * A), B_bar = dt * B
        A_bar = torch.exp(dt * A).unsqueeze(-1)              # [B, L, H, 1]
        B_bar = dt.unsqueeze(-1) * B_ssm                     # [B, L, H, N]

        x_heads = x.reshape(bsz, L, H, P).float()
        y = self._ssm_scan(x_heads, A_bar, B_bar, C_ssm)      # [B, L, H, P]

        # Per-channel skip connection with D
        y = y + self.D.float().view(1, 1, H, 1) * x_heads
        y = y.reshape(bsz, L, self.inner_size).to(x.dtype)

        # Normalize, then gate with z
        y = self.norm(y)
        y = y * F.silu(z)
        return self.out_proj(y)


# ═══════════════════════════════════════════════════════════
# 4. SLIDING WINDOW ATTENTION — O(n * w)
# ═══════════════════════════════════════════════════════════

class SlidingWindowAttention(nn.Module):
    """Sliding window attention for local pattern capture.
    5 layers — complements GDN's global view with fine-grained local context.

    Each query attends to itself and the previous ``window_size - 1`` tokens.
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim
        self.window_size = config.sliding_window_size
        self.num_kv_groups = self.num_heads // self.num_kv_heads

        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
        self.gate = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)

        self.rotary_emb = RotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """Windowed causal GQA; ``attention_mask`` (if given) is added to the scores, as in FullAttention."""
        B, L, _ = hidden_states.shape

        q = self.q_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(hidden_states, position_ids)
        q, k = apply_rotary_pos_emb(q, k, cos.unsqueeze(1), sin.unsqueeze(1))

        if self.num_kv_groups > 1:
            k = k.repeat_interleave(self.num_kv_groups, dim=1)
            v = v.repeat_interleave(self.num_kv_groups, dim=1)

        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Sliding window + causal mask (True = masked out)
        mask = torch.ones(L, L, device=attn.device, dtype=torch.bool)
        mask = torch.triu(mask, diagonal=1)                                        # future keys
        mask = mask | torch.tril(torch.ones_like(mask), diagonal=-self.window_size)  # too-old keys
        attn = attn.masked_fill(mask, float('-inf'))
        if attention_mask is not None:
            attn = attn + attention_mask

        attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, L, -1)

        gate = torch.sigmoid(self.gate(hidden_states))
        return self.o_proj(out * gate)


# ═══════════════════════════════════════════════════════════
# 5. CROSS ATTENTION — for multimodal / tool bridging
# ═══════════════════════════════════════════════════════════

class CrossAttention(nn.Module):
    """Cross attention for PROMETHEUS (world model) and HEPHAESTUS (embodiment) connection.
    5 layers — bridges AETHER-Net to external modalities.

    When no external context: falls back to causal self-attention with gating.
    When ``encoder_hidden_states`` is given, self- and cross-attention outputs
    are blended per token by a learned sigmoid modality gate.
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim

        # Self-attention path (default when no external context)
        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)

        # Cross-attention path (when external context available)
        self.cross_k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.cross_v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)

        # Modality gate: lerp between self and cross
        self.modality_gate = nn.Linear(config.hidden_size, 1, bias=True)
        nn.init.constant_(self.modality_gate.bias, -2.0)  # default: mostly self-attention

        self.gate = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)

    def _causal_self_attention(self, q, hidden_states, attention_mask):
        """Causal self-attention for precomputed ``q`` ``[B, H, L, D]``; returns ``[B, L, H*D]``."""
        B, L, _ = hidden_states.shape
        k = self.k_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        causal = torch.triu(torch.full((L, L), float('-inf'), device=attn.device), diagonal=1)
        attn = attn + causal
        if attention_mask is not None:
            attn = attn + attention_mask
        attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
        return torch.matmul(attn, v).transpose(1, 2).contiguous().view(B, L, -1)

    def forward(self, hidden_states, attention_mask=None, position_ids=None,
                encoder_hidden_states=None, **kwargs):
        """Self-attention, optionally blended with attention over ``encoder_hidden_states``.

        ``encoder_hidden_states`` is viewed as ``[B, S, H, D]`` after projection,
        so its batch size must match ``hidden_states``. ``attention_mask`` is
        applied additively to the self-attention scores only.
        """
        B, L, _ = hidden_states.shape
        q = self.q_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

        # Self-attention path (always runs)
        out = self._causal_self_attention(q, hidden_states, attention_mask)

        if encoder_hidden_states is not None:
            # Cross-attention mode: unmasked attention over the external context
            k_cross = self.cross_k_proj(encoder_hidden_states).view(
                B, -1, self.num_heads, self.head_dim).transpose(1, 2)
            v_cross = self.cross_v_proj(encoder_hidden_states).view(
                B, -1, self.num_heads, self.head_dim).transpose(1, 2)

            attn_cross = torch.matmul(q, k_cross.transpose(-2, -1)) / math.sqrt(self.head_dim)
            attn_cross = F.softmax(attn_cross, dim=-1, dtype=torch.float32).to(q.dtype)
            out_cross = torch.matmul(attn_cross, v_cross)
            out_cross = out_cross.transpose(1, 2).contiguous().view(B, L, -1)

            # Blend via modality gate ([B, L, 1] in (0, 1))
            mg = torch.sigmoid(self.modality_gate(hidden_states))
            out = mg * out_cross + (1 - mg) * out

        gate = torch.sigmoid(self.gate(hidden_states))
        return self.o_proj(out * gate)


# ═══════════════════════════════════════════════════════════
# Factory
# ═══════════════════════════════════════════════════════════

ATTENTION_CLASSES = {
    "gdn": GatedDeltaNet,
    "full": FullAttention,
    "mamba2": Mamba2Block,
    "slide": SlidingWindowAttention,
    "cross": CrossAttention,
}


def build_attention(layer_type: str, config):
    """Instantiate the layer registered under ``layer_type`` in ATTENTION_CLASSES.

    Raises:
        ValueError: if ``layer_type`` is not a registered key.
    """
    cls = ATTENTION_CLASSES.get(layer_type)
    if cls is None:
        raise ValueError(f"Unknown attention type: {layer_type}. Choose from {list(ATTENTION_CLASSES.keys())}")
    return cls(config)