Spaces:
Paused
Paused
| """ | |
| AETHER-Net Attention Layers | |
| 5 types: GDN, Full, Mamba2, Sliding Window, Cross Attention | |
| Each layer follows the same interface: | |
| forward(hidden_states, attention_mask=None, position_ids=None, **kwargs) -> hidden_states | |
| """ | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Optional, Tuple | |
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-channel scale.

    The last dimension of the input is divided by its root mean square
    (computed in float32) and multiplied element-wise by ``weight``.

    Args:
        hidden_size: Length of the last input dimension and of ``weight``.
        eps: Constant added to the mean square before the reciprocal
            square root is taken.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        """Return ``x`` RMS-normalised over its last dimension, in ``x``'s dtype."""
        # Capture the caller's dtype before upcasting. Reading ``x.dtype``
        # after ``x`` has been rebound to the float32 product would make the
        # final cast a no-op and leak float32 out for fp16/bf16 inputs.
        input_dtype = x.dtype
        x = x.float()
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return (self.weight * x).to(input_dtype)
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second one.

    With the last axis split into two chunks ``[a, b]``, the result is
    ``[-b, a]``.
    """
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    pieces = (second_half.neg(), first_half)
    return torch.cat(pieces, dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply the same rotary position embedding to a query and a key tensor.

    Each tensor ``t`` is mapped to ``t * cos + rotate_half(t) * sin``;
    ``cos`` and ``sin`` must broadcast against both ``q`` and ``k``.

    Returns:
        A ``(q_embed, k_embed)`` tuple.
    """

    def _rotate(t):
        # Standard RoPE combination of the tensor and its half-rotated copy.
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)
class RotaryEmbedding(nn.Module):
    """Compute cos/sin tables for rotary position embeddings (RoPE).

    Args:
        dim: Rotary dimension; ``dim // 2`` inverse frequencies are created
            and duplicated so the returned tables have ``dim`` columns.
        max_seq_len: Stored on the module as ``max_seq_len``; it is not used
            by ``forward``.
        theta: Base of the geometric frequency progression.
    """

    def __init__(self, dim: int, max_seq_len: int = 262144, theta: float = 10000000.0):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        # Non-persistent: recomputed from ``dim``/``theta``, never checkpointed.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.max_seq_len = max_seq_len

    def forward(self, x, position_ids):
        """Return ``(cos, sin)``, each shaped ``[1, L, dim]``.

        Args:
            x: Tensor the embedding will be applied alongside. Its dtype is
                used for the returned tables and, when ``position_ids`` is
                ``None``, ``x.shape[1]`` is taken as the sequence length
                (the callers in this file pass ``[B, L, hidden]``).
            position_ids: ``[L]`` or ``[B, L]`` positions, or ``None`` for
                ``0 .. L-1``. For a 2-D tensor only row 0 is used, i.e. all
                rows are assumed to hold the same positions.
        """
        if position_ids is None:
            # The attention layers default position_ids to None; fall back to
            # consecutive positions instead of failing on ``None.dim()``.
            pos = torch.arange(x.shape[1], device=x.device)
        elif position_ids.dim() == 2:
            pos = position_ids[0]
        else:
            pos = position_ids

        # Angles are computed in float32 for accuracy at large positions.
        freqs = torch.outer(pos.float(), self.inv_freq.to(pos.device))
        emb = torch.cat((freqs, freqs), dim=-1)
        cos, sin = emb.cos(), emb.sin()

        # Hand the tables back in the activations' dtype so that applying them
        # does not promote fp16/bf16 queries and keys to float32.
        if torch.is_tensor(x) and x.is_floating_point():
            cos = cos.to(x.dtype)
            sin = sin.to(x.dtype)
        return cos.unsqueeze(0), sin.unsqueeze(0)
# ═══════════════════════════════════════════════════════════
# 1. FULL ATTENTION (Softmax, GQA, RoPE) — O(n²)
# ═══════════════════════════════════════════════════════════
class FullAttention(nn.Module):
    """Causal softmax attention with grouped-query heads, RoPE and an output gate.

    Reads from ``config``: ``hidden_size``, ``num_attention_heads``,
    ``num_kv_heads``, ``head_dim``, ``max_position_embeddings`` and
    ``rope_theta``.

    NOTE(review): the original docstring said these layers "maintain KV
    cache"; nothing in this class reads or writes a cache, so any caching
    has to be provided by the caller.

    Raises:
        ValueError: If ``num_attention_heads`` is not a multiple of
            ``num_kv_heads``.
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim
        if self.num_heads % self.num_kv_heads != 0:
            # Floor division would otherwise hide the mismatch until the
            # attention matmul fails with an opaque shape error.
            raise ValueError(
                f"num_attention_heads ({self.num_heads}) must be divisible by "
                f"num_kv_heads ({self.num_kv_heads})"
            )
        self.num_kv_groups = self.num_heads // self.num_kv_heads

        q_width = self.num_heads * self.head_dim
        kv_width = self.num_kv_heads * self.head_dim
        self.q_proj = nn.Linear(config.hidden_size, q_width, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, kv_width, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, kv_width, bias=False)
        self.o_proj = nn.Linear(q_width, config.hidden_size, bias=False)
        # Sigmoid output gate applied to the attention result before o_proj.
        self.gate = nn.Linear(config.hidden_size, q_width, bias=False)
        self.rotary_emb = RotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """Run gated causal attention over ``hidden_states``.

        Args:
            hidden_states: ``[B, L, hidden_size]`` input.
            attention_mask: Optional mask that is *added* to the pre-softmax
                scores, so it must be additive and broadcastable to
                ``[B, num_heads, L, L]``.
            position_ids: Optional ``[L]`` or ``[B, L]`` positions for RoPE;
                defaults to ``0 .. L-1``.

        Returns:
            Tensor of shape ``[B, L, hidden_size]``.
        """
        B, L, _ = hidden_states.shape
        if position_ids is None:
            # The signature advertises None as the default, so it must work.
            position_ids = torch.arange(L, device=hidden_states.device)

        q = self.q_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # RoPE. The tables are cast to q's dtype so q/k stay in the same dtype
        # as v; otherwise the ``attn @ v`` matmul below fails in fp16/bf16.
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        cos = cos.unsqueeze(1).to(q.dtype)
        sin = sin.unsqueeze(1).to(q.dtype)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # GQA: repeat each KV head so it lines up with its group of query heads.
        if self.num_kv_groups > 1:
            k = k.repeat_interleave(self.num_kv_groups, dim=1)
            v = v.repeat_interleave(self.num_kv_groups, dim=1)

        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Causal mask: -inf strictly above the diagonal.
        causal = torch.triu(torch.full((L, L), float('-inf'), device=attn.device), diagonal=1)
        attn = attn + causal.unsqueeze(0).unsqueeze(0)
        if attention_mask is not None:
            attn = attn + attention_mask

        # Softmax in float32 for stability, then back to the working dtype.
        attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, L, -1)

        # Output gating
        gate = torch.sigmoid(self.gate(hidden_states))
        out = out * gate
        return self.o_proj(out)
# ═══════════════════════════════════════════════════════════
# 2. GATED DELTANET (GDN) — O(n) linear time
# ═══════════════════════════════════════════════════════════
class GatedDeltaNet(nn.Module):
    """Gated DeltaNet layer: decayed fast-weight memory updated with the delta rule.

    Per head, with a state ``S`` of shape ``[head_dim, head_dim]`` and
    L2-normalised ``q_t``/``k_t``::

        S_t = alpha_t * (I - beta_t * k_t k_t^T) S_{t-1} + beta_t * k_t v_t^T
        o_t = q_t^T S_t

    ``alpha_t`` (decay) and ``beta_t`` (write strength) are sigmoid gates
    computed from the input. The result is multiplied by a SiLU output gate
    and projected back to ``hidden_size``.

    Reads from ``config``: ``hidden_size``, ``num_attention_heads``,
    ``head_dim`` and ``gdn_state_size`` (stored as ``state_size``; not used
    in the computation here).

    Original author notes (not verifiable from this file): intended for 10
    layers of the model, with Q/K/V projections transplanted from Qwen3.5
    GDN layers.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.state_size = config.gdn_state_size

        width = self.num_heads * self.head_dim
        # Input / output projections
        self.q_proj = nn.Linear(config.hidden_size, width, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, width, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, width, bias=False)
        self.o_proj = nn.Linear(width, config.hidden_size, bias=False)
        # Decay gate (alpha): one scalar per head and position
        self.decay_proj = nn.Linear(config.hidden_size, self.num_heads, bias=True)
        # Update gate (beta): one scalar per head and position
        self.beta_proj = nn.Linear(config.hidden_size, self.num_heads, bias=True)
        # Output gate, passed through SiLU in forward
        self.gate = nn.Linear(config.hidden_size, width, bias=False)
        # Depthwise width-4 convolution; padding 3 plus truncation to L in
        # forward makes it causal (output t sees inputs t-3 .. t).
        self.conv1d = nn.Conv1d(
            in_channels=config.hidden_size,
            out_channels=config.hidden_size,
            kernel_size=4, padding=3, groups=config.hidden_size, bias=True
        )

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """Run the gated delta-rule recurrence over the sequence.

        Args:
            hidden_states: ``[B, L, hidden_size]`` input.
            attention_mask: Accepted for interface compatibility; unused.
            position_ids: Accepted for interface compatibility; unused.

        Returns:
            Tensor of shape ``[B, L, hidden_size]``.
        """
        B, L, _ = hidden_states.shape

        # Local context mixing via causal conv1d (feeds q and k only).
        conv_out = self.conv1d(hidden_states.transpose(1, 2))[..., :L].transpose(1, 2)
        q = self.q_proj(conv_out).view(B, L, self.num_heads, self.head_dim)
        k = self.k_proj(conv_out).view(B, L, self.num_heads, self.head_dim)
        v = self.v_proj(hidden_states).view(B, L, self.num_heads, self.head_dim)

        # Unit-norm q/k keep the rank-1 erase term a contraction.
        q = F.normalize(q, p=2, dim=-1)
        k = F.normalize(k, p=2, dim=-1)

        alpha = torch.sigmoid(self.decay_proj(hidden_states)).unsqueeze(-1)  # [B, L, H, 1]
        beta = torch.sigmoid(self.beta_proj(hidden_states)).unsqueeze(-1)    # [B, L, H, 1]
        k_beta = k * beta  # hoisted: used by both the erase and write terms

        # Sequential reference scan; state is [B, H, d_k, d_v].
        state = hidden_states.new_zeros(B, self.num_heads, self.head_dim, self.head_dim)
        outputs = []
        for t in range(L):
            q_t, k_t, v_t = q[:, t], k[:, t], v[:, t]   # each [B, H, D]
            kb_t = k_beta[:, t]
            a_t = alpha[:, t].unsqueeze(-1)             # [B, H, 1, 1]

            # Value currently stored under k_t: k_t^T S -> [B, H, d_v]
            recalled = torch.einsum('bhd,bhde->bhe', k_t, state)
            # Erase: (I - beta k k^T) S == S - (beta k) (k^T S). This is a
            # matrix product, not an element-wise one, so that re-writing an
            # existing key replaces its value instead of piling up on it.
            erased = state - torch.einsum('bhd,bhe->bhde', kb_t, recalled)
            # Write: beta k v^T
            write = torch.einsum('bhd,bhe->bhde', kb_t, v_t)
            state = a_t * erased + write

            # Read: o_t = q^T S
            outputs.append(torch.einsum('bhd,bhde->bhe', q_t, state))

        if outputs:
            out = torch.stack(outputs, dim=1)  # [B, L, H, D]
        else:
            # Zero-length sequence: nothing to stack.
            out = hidden_states.new_zeros(B, 0, self.num_heads, self.head_dim)
        out = out.reshape(B, L, -1)

        # Output gating with SiLU
        gate = F.silu(self.gate(hidden_states))
        out = out * gate
        return self.o_proj(out)
# ═══════════════════════════════════════════════════════════
# 3. MAMBA2 — O(n) with SSM state-space duality
# ═══════════════════════════════════════════════════════════
class Mamba2Block(nn.Module):
    """Mamba-2 style selective state-space block (sequential reference scan).

    Per head ``h`` with channels ``p`` and state index ``n``::

        S_t[p, n] = exp(dt_t * A_h) * S_{t-1}[p, n] + dt_t * B_t[n] * x_t[p]
        y_t[p]    = sum_n C_t[n] * S_t[p, n] + D_h * x_t[p]

    The result is RMS-normalised, gated by ``SiLU(z)`` and projected back to
    ``hidden_size``.

    Reads from ``config``: ``hidden_size``, ``mamba2_expand``,
    ``mamba2_state_size``, ``mamba2_conv_size`` and ``num_attention_heads``.

    Original author notes (not verifiable from this file): intended for 5
    layers, with weights transplanted via MOHAWK from Llama-3.1 Q/K/V.

    Raises:
        ValueError: If ``hidden_size * mamba2_expand`` is not divisible by
            ``num_attention_heads``.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        expand = config.mamba2_expand
        self.inner_size = config.hidden_size * expand
        self.state_size = config.mamba2_state_size
        self.conv_size = config.mamba2_conv_size
        self.num_heads = config.num_attention_heads
        if self.inner_size % self.num_heads != 0:
            # forward() splits inner_size evenly across heads.
            raise ValueError(
                f"inner size ({self.inner_size}) must be divisible by "
                f"num_attention_heads ({self.num_heads})"
            )

        # Input projection: hidden -> (z, x_ssm), split in forward
        self.in_proj = nn.Linear(config.hidden_size, self.inner_size * 2, bias=False)
        # Depthwise conv; padding = kernel - 1 plus truncation makes it causal
        self.conv1d = nn.Conv1d(
            self.inner_size, self.inner_size,
            kernel_size=self.conv_size, padding=self.conv_size - 1,
            groups=self.inner_size, bias=True
        )
        # SSM parameters: per-head step size, decay rate and skip weight
        self.dt_proj = nn.Linear(self.inner_size, self.num_heads, bias=True)
        self.A_log = nn.Parameter(torch.log(torch.arange(1, self.num_heads + 1, dtype=torch.float32)))
        self.D = nn.Parameter(torch.ones(self.num_heads))
        # Input-dependent B, C (one state_size vector per head)
        self.B_proj = nn.Linear(self.inner_size, self.state_size * self.num_heads, bias=False)
        self.C_proj = nn.Linear(self.inner_size, self.state_size * self.num_heads, bias=False)
        # Output
        self.out_proj = nn.Linear(self.inner_size, config.hidden_size, bias=False)
        self.norm = RMSNorm(self.inner_size)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """Run the selective scan over the sequence.

        Args:
            hidden_states: ``[B, L, hidden_size]`` input.
            attention_mask: Accepted for interface compatibility; unused.
            position_ids: Accepted for interface compatibility; unused.

        Returns:
            Tensor of shape ``[B, L, hidden_size]``.
        """
        B, L, _ = hidden_states.shape
        head_dim = self.inner_size // self.num_heads

        # Input split
        z, x = self.in_proj(hidden_states).chunk(2, dim=-1)

        # Causal conv + activation
        x = self.conv1d(x.transpose(1, 2))[..., :L].transpose(1, 2)
        x = F.silu(x)

        # SSM parameters
        A = -torch.exp(self.A_log)                                      # [H], always negative
        dt = F.softplus(self.dt_proj(x))                                # [B, L, H]
        B_state = self.B_proj(x).view(B, L, self.num_heads, self.state_size)
        C_state = self.C_proj(x).view(B, L, self.num_heads, self.state_size)

        # Discretize: A_bar = exp(dt * A), B_bar = dt * B
        A_bar = torch.exp(dt.unsqueeze(-1) * A.view(1, 1, -1, 1))       # [B, L, H, 1]
        B_bar = dt.unsqueeze(-1) * B_state                              # [B, L, H, N]

        x_heads = x.reshape(B, L, self.num_heads, head_dim)             # [B, L, H, P]

        # Sequential scan. The state keeps one N-vector per channel so every
        # channel of x drives the SSM; feeding only channel 0 of each head
        # and broadcasting a per-head scalar back would discard the rest.
        state = x.new_zeros(B, self.num_heads, head_dim, self.state_size)
        outputs = []
        for t in range(L):
            inject = x_heads[:, t].unsqueeze(-1) * B_bar[:, t].unsqueeze(-2)   # [B, H, P, N]
            state = A_bar[:, t].unsqueeze(-1) * state + inject
            outputs.append((state * C_state[:, t].unsqueeze(-2)).sum(dim=-1))  # [B, H, P]

        if outputs:
            y = torch.stack(outputs, dim=1)                             # [B, L, H, P]
        else:
            # Zero-length sequence: nothing to stack.
            y = x.new_zeros(B, 0, self.num_heads, head_dim)

        # Per-head skip connection with D
        y = y + self.D.view(1, 1, -1, 1) * x_heads
        y = y.reshape(B, L, self.inner_size)

        # Normalise and gate with z
        y = self.norm(y)
        y = y * F.silu(z)
        return self.out_proj(y)
# ═══════════════════════════════════════════════════════════
# 4. SLIDING WINDOW ATTENTION — O(n * w)
# ═══════════════════════════════════════════════════════════
class SlidingWindowAttention(nn.Module):
    """Causal grouped-query attention restricted to a fixed-size local window.

    Each query attends to itself and the ``sliding_window_size - 1``
    positions before it. Uses RoPE and a sigmoid output gate.

    Reads from ``config``: ``hidden_size``, ``num_attention_heads``,
    ``num_kv_heads``, ``head_dim``, ``sliding_window_size``,
    ``max_position_embeddings`` and ``rope_theta``.

    Raises:
        ValueError: If ``num_attention_heads`` is not a multiple of
            ``num_kv_heads``.
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim
        self.window_size = config.sliding_window_size
        if self.num_heads % self.num_kv_heads != 0:
            # Floor division would otherwise hide the mismatch until the
            # attention matmul fails with an opaque shape error.
            raise ValueError(
                f"num_attention_heads ({self.num_heads}) must be divisible by "
                f"num_kv_heads ({self.num_kv_heads})"
            )
        self.num_kv_groups = self.num_heads // self.num_kv_heads

        q_width = self.num_heads * self.head_dim
        kv_width = self.num_kv_heads * self.head_dim
        self.q_proj = nn.Linear(config.hidden_size, q_width, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, kv_width, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, kv_width, bias=False)
        self.o_proj = nn.Linear(q_width, config.hidden_size, bias=False)
        self.gate = nn.Linear(config.hidden_size, q_width, bias=False)
        self.rotary_emb = RotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """Run gated sliding-window attention over ``hidden_states``.

        Args:
            hidden_states: ``[B, L, hidden_size]`` input.
            attention_mask: Optional mask that is *added* to the pre-softmax
                scores (same convention as ``FullAttention``); it must be
                additive and broadcastable to ``[B, num_heads, L, L]``.
            position_ids: Optional ``[L]`` or ``[B, L]`` positions for RoPE;
                defaults to ``0 .. L-1``.

        Returns:
            Tensor of shape ``[B, L, hidden_size]``.
        """
        B, L, _ = hidden_states.shape
        if position_ids is None:
            # The signature advertises None as the default, so it must work.
            position_ids = torch.arange(L, device=hidden_states.device)

        q = self.q_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # RoPE. The tables are cast to q's dtype so q/k stay in the same dtype
        # as v; otherwise the ``attn @ v`` matmul below fails in fp16/bf16.
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        cos = cos.unsqueeze(1).to(q.dtype)
        sin = sin.unsqueeze(1).to(q.dtype)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # GQA: repeat each KV head so it lines up with its group of query heads.
        if self.num_kv_groups > 1:
            k = k.repeat_interleave(self.num_kv_groups, dim=1)
            v = v.repeat_interleave(self.num_kv_groups, dim=1)

        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # True marks a blocked (query i, key j) pair: the future (j > i) or
        # anything at least ``window_size`` steps back (j <= i - window_size).
        blocked = torch.ones(L, L, device=attn.device, dtype=torch.bool)
        future = torch.triu(blocked, diagonal=1)
        too_old = torch.tril(blocked, diagonal=-self.window_size)
        attn = attn.masked_fill((future | too_old).unsqueeze(0).unsqueeze(0), float('-inf'))

        # Honour the caller's mask (e.g. padding) exactly as FullAttention
        # does; previously it was accepted but silently dropped.
        if attention_mask is not None:
            attn = attn + attention_mask

        # Softmax in float32 for stability, then back to the working dtype.
        attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, L, -1)

        gate = torch.sigmoid(self.gate(hidden_states))
        out = out * gate
        return self.o_proj(out)
# ═══════════════════════════════════════════════════════════
# 5. CROSS ATTENTION — for multimodal / tool bridging
# ═══════════════════════════════════════════════════════════
class CrossAttention(nn.Module):
    """Attention layer that can blend self-attention with external context.

    Without ``encoder_hidden_states`` this is plain causal multi-head
    self-attention. With it, a second (non-causal) attention over the
    external sequence is computed and mixed with the self-attention output
    through a learned per-token scalar gate. Both modes end with a sigmoid
    output gate and ``o_proj``. No positional encoding is applied here.

    Reads from ``config``: ``hidden_size``, ``num_attention_heads`` and
    ``head_dim``.

    Original author note (not verifiable from this file): intended to bridge
    to the PROMETHEUS (world model) and HEPHAESTUS (embodiment) components.
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        width = self.num_heads * self.head_dim

        # Self-attention path (always used)
        self.q_proj = nn.Linear(config.hidden_size, width, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, width, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, width, bias=False)
        self.o_proj = nn.Linear(width, config.hidden_size, bias=False)
        # Cross-attention path (used when external context is supplied)
        self.cross_k_proj = nn.Linear(config.hidden_size, width, bias=False)
        self.cross_v_proj = nn.Linear(config.hidden_size, width, bias=False)
        # Modality gate: per-token lerp weight between self and cross outputs.
        self.modality_gate = nn.Linear(config.hidden_size, 1, bias=True)
        # sigmoid(-2) ~= 0.12, so the blend starts out mostly self-attention.
        nn.init.constant_(self.modality_gate.bias, -2.0)
        self.gate = nn.Linear(config.hidden_size, width, bias=False)

    def _split_heads(self, projected, batch):
        """Reshape ``[B, S, H * D]`` to ``[B, H, S, D]``."""
        return projected.view(batch, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def _causal_self_attention(self, q, hidden_states, attention_mask=None):
        """Causal self-attention of ``q`` over ``hidden_states``; returns ``[B, L, H * D]``."""
        B, L, _ = hidden_states.shape
        k = self._split_heads(self.k_proj(hidden_states), B)
        v = self._split_heads(self.v_proj(hidden_states), B)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        causal = torch.triu(torch.full((L, L), float('-inf'), device=scores.device), diagonal=1)
        scores = scores + causal.unsqueeze(0).unsqueeze(0)
        # Honour the caller's additive mask (e.g. padding), matching
        # FullAttention; previously it was accepted but silently dropped.
        if attention_mask is not None:
            scores = scores + attention_mask
        probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
        return torch.matmul(probs, v).transpose(1, 2).contiguous().view(B, L, -1)

    def forward(self, hidden_states, attention_mask=None, position_ids=None,
                encoder_hidden_states=None, **kwargs):
        """Run self-attention, optionally blended with cross-attention.

        Args:
            hidden_states: ``[B, L, hidden_size]`` input.
            attention_mask: Optional mask *added* to the self-attention
                scores; must be additive and broadcastable to
                ``[B, num_heads, L, L]``. It is not applied to the
                cross-attention scores.
            position_ids: Accepted for interface compatibility; unused.
            encoder_hidden_states: Optional ``[B, S, hidden_size]`` external
                context. All of its positions are visible to every query.

        Returns:
            Tensor of shape ``[B, L, hidden_size]``.
        """
        B, L, _ = hidden_states.shape
        q = self._split_heads(self.q_proj(hidden_states), B)

        # The self-attention path runs in both modes.
        out = self._causal_self_attention(q, hidden_states, attention_mask)

        if encoder_hidden_states is not None:
            # Cross-attention mode: unmasked attention over the external sequence.
            k_cross = self._split_heads(self.cross_k_proj(encoder_hidden_states), B)
            v_cross = self._split_heads(self.cross_v_proj(encoder_hidden_states), B)
            attn_cross = torch.matmul(q, k_cross.transpose(-2, -1)) / math.sqrt(self.head_dim)
            attn_cross = F.softmax(attn_cross, dim=-1, dtype=torch.float32).to(q.dtype)
            out_cross = torch.matmul(attn_cross, v_cross)
            out_cross = out_cross.transpose(1, 2).contiguous().view(B, L, -1)

            # Blend via modality gate: mg -> 1 favours the external context.
            mg = torch.sigmoid(self.modality_gate(hidden_states))
            out = mg * out_cross + (1 - mg) * out

        gate = torch.sigmoid(self.gate(hidden_states))
        out = out * gate
        return self.o_proj(out)
# ═══════════════════════════════════════════════════════════
# Factory
# ═══════════════════════════════════════════════════════════
# Registry mapping a layer-type key to the module class that implements it.
# Insertion order is what `build_attention` lists in its error message.
ATTENTION_CLASSES = dict(
    gdn=GatedDeltaNet,               # gated delta-rule linear attention
    full=FullAttention,              # causal softmax attention with GQA + RoPE
    mamba2=Mamba2Block,              # selective state-space block
    slide=SlidingWindowAttention,    # local-window causal attention
    cross=CrossAttention,            # self-attention blended with external context
)
def build_attention(layer_type: str, config):
    """Instantiate the attention layer registered under ``layer_type``.

    Args:
        layer_type: A key of ``ATTENTION_CLASSES``.
        config: Passed unchanged to the selected class's constructor.

    Returns:
        A new instance of the selected layer class.

    Raises:
        ValueError: If ``layer_type`` is not a registered key.
    """
    layer_cls = ATTENTION_CLASSES.get(layer_type)
    if layer_cls is not None:
        return layer_cls(config)
    raise ValueError(f"Unknown attention type: {layer_type}. Choose from {list(ATTENTION_CLASSES.keys())}")