# agiformer/src/models/layers.py
## Developer: inkbytefo
## Modified: 2025-11-23
import torch
import torch.nn as nn
import torch.nn.functional as F
from .memory import HebbianMemory


class SlidingWindowAttention(nn.Module):
"""
Local Attention mechanism restricted to a sliding window.
"""
def __init__(self, d_model: int, num_heads: int, window_size: int):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.window_size = window_size
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.head_dim = d_model // num_heads
self.qkv = nn.Linear(d_model, 3 * d_model)
self.proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
B, L, D = x.shape
H = self.num_heads
E = self.head_dim
scale = 1.0 / (E ** 0.5)
qkv = self.qkv(x).reshape(B, L, 3, H, E).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
        # Manual attention (explicit matmul + softmax) instead of a fused kernel,
        # so masked logits can use a finite fill value that is fp16-safe.
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale
        # Build the combined mask (True = blocked).
        ones = torch.ones(L, L, device=x.device, dtype=torch.bool)
        causal_mask = ones.triu(1)                  # block future positions (j > i)
        window_mask = ones.tril(-self.window_size)  # block positions outside the window (j <= i - W)
        mask = causal_mask | window_mask
        scores = scores.masked_fill(mask, -1e4)     # finite fill instead of -inf: safe under AMP/fp16
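        # Illustration: the combined mask for L = 5, window_size = 2
        # (X = blocked, . = visible):
        #   row 0: . X X X X
        #   row 1: . . X X X
        #   row 2: X . . X X
        #   row 3: X X . . X
        #   row 4: X X X . .
        # Each query sees itself plus at most window_size - 1 past tokens.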
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, v)
out = out.transpose(1, 2).reshape(B, L, D)
return self.proj(out)
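

# The forward above materializes a full L x L score matrix explicitly. As a
# minimal alternative sketch (an addition, not used by the model; assumes
# PyTorch >= 2.0), the same sliding-window pattern can go through the fused
# F.scaled_dot_product_attention, whose boolean attn_mask marks VISIBLE
# positions with True:
def sliding_window_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                        window_size: int) -> torch.Tensor:
    """Hypothetical fused equivalent of SlidingWindowAttention's core step."""
    L = q.size(-2)
    ones = torch.ones(L, L, device=q.device, dtype=torch.bool)
    # Visible iff j <= i (causal) and j > i - window_size (inside the window).
    visible = ones.tril(0) & ~ones.tril(-window_size)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=visible)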


class HybridBlock(nn.Module):
"""
Combines Sliding Window Attention (Local) and Hebbian Memory (Global).
"""
    def __init__(self, d_model: int, num_heads: int, window_size: int, dropout: float):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
# Local Precision
self.attn = SlidingWindowAttention(d_model, num_heads, window_size)
# Global Context (Hebbian Memory)
# Replaces the static LinearAttention with dynamic Fast Weights
self.memory = HebbianMemory(d_model, num_heads, dropout)
self.out_proj = nn.Linear(d_model, d_model)
self.mlp = nn.Sequential(
nn.Linear(d_model, 4 * d_model),
nn.GELU(),
nn.Linear(4 * d_model, d_model),
nn.Dropout(dropout)
)
self.norm_mlp = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x_norm = self.norm1(x)
# Parallel Branches: Local Attention + Global Hebbian Memory
attn_out = self.attn(x_norm)
memory_out = self.memory(x_norm)
# Fusion
x = residual + self.out_proj(attn_out + memory_out)
# MLP
x = x + self.mlp(self.norm_mlp(x))
return x
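

if __name__ == "__main__":
    # Minimal smoke test (an added sketch, not part of the original file). It
    # only exercises SlidingWindowAttention; HybridBlock depends on
    # HebbianMemory, whose behavior is defined in .memory. Because of the
    # relative import at the top, run this as a module from the repo root,
    # e.g. `python -m src.models.layers` (path assumed from the file location).
    torch.manual_seed(0)
    B, L, D, H, W = 2, 16, 64, 4, 8
    attn = SlidingWindowAttention(d_model=D, num_heads=H, window_size=W)
    x = torch.randn(B, L, D)
    out = attn(x)
    assert out.shape == (B, L, D)
    # Causality check: perturbing the last token must not change earlier outputs.
    x2 = x.clone()
    x2[:, -1] += 1.0
    out2 = attn(x2)
    assert torch.allclose(out[:, :-1], out2[:, :-1], atol=1e-5)
    print("SlidingWindowAttention smoke test passed:", tuple(out.shape))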