|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import math |
|
|
|
|
|
class HebbianMemory(nn.Module):
    """
    Hebbian Memory Module (Fast Weights).

    Implements the update rule:

        M_t = lambda * M_{t-1} + K_t * V_t^T
        O_t = Q_t * M_t

    computed in parallel over the whole sequence via a cumulative sum
    (closed-form decay factoring) instead of a step-by-step recurrence.

    CRITICAL CHANGE:
    To prevent numerical overflow in the parallel computation (cumsum),
    the learned decay rate (lambda) is constrained to the range
    [0.99, 1.0]. This ensures lambda^(-L) does not explode for L=1024.
    Exponents are additionally hard-clamped to [-50, 50] as a safety net.
    """

    def __init__(self, d_model, num_heads=8, dropout=0.1):
        """
        Args:
            d_model: embedding dimension (must be divisible by num_heads).
            num_heads: number of attention heads.
            dropout: dropout probability applied after the output LayerNorm.
        """
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        # ELU(x) + 1 > 0 keeps the kernel feature map strictly positive,
        # which guarantees a positive normalizer (denominator) in forward().
        self.feature_map = nn.ELU()

        # Per-head learnable decay logits; mapped into [0.99, 1.0] via a
        # sigmoid in forward(). Initialized to 0 -> lambda = 0.995.
        self.decay_logits = nn.Parameter(torch.zeros(num_heads))

        self.norm = nn.LayerNorm(d_model)

        # Global plasticity scale applied multiplicatively to lambda.
        # 1.0 = fully stable memory; smaller values forget faster.
        self.plasticity = 1.0

    def set_plasticity(self, alpha):
        """
        Updates the plasticity coefficient (alpha).

        Args:
            alpha: float in [0, 1].
                0.1  -> Childhood (fast forgetting)
                0.99 -> Adulthood (stable memory)
        """
        self.plasticity = alpha

    @torch.amp.autocast('cuda', enabled=False)
    def forward(self, x):
        """
        Args:
            x: tensor of shape (batch, seq_len, d_model). Any float dtype;
               computation runs in float32 and the result is cast back.

        Returns:
            Tensor of shape (batch, seq_len, d_model) in x's original dtype.
        """
        # BUGFIX: capture the caller's dtype BEFORE the float32 upcast.
        # Previously this was read after .float(), so the final
        # .to(input_dtype) was a no-op and the original dtype was lost.
        input_dtype = x.dtype
        # Run the recurrence in float32 for numerical stability; the
        # decorator disables autocast so this is not re-downcast.
        x = x.float()

        B, L, D = x.shape
        H = self.num_heads
        E = self.head_dim

        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(B, L, H, E)
        k = k.view(B, L, H, E)
        v = v.view(B, L, H, E)

        # Positive kernel feature map (ELU + 1 > 0) -> positive normalizer.
        q = self.feature_map(q) + 1.0
        k = self.feature_map(k) + 1.0

        q = q / math.sqrt(E)

        # Map sigmoid output into the stability band [0.99, 1.0].
        raw_sigmoid = torch.sigmoid(self.decay_logits).view(1, 1, H, 1)
        lambdas = 0.99 + (0.01 * raw_sigmoid)

        # NOTE(review): plasticity < 1 pushes lambda below the documented
        # [0.99, 1.0] band; the exponent clamps below are what actually
        # prevent overflow in that regime.
        lambdas = lambdas * self.plasticity

        # Position indices for the closed-form decay factoring:
        #   sum_{j<=t} lambda^(t-j) k_j v_j^T
        #   = lambda^t * cumsum_j(lambda^(-j) k_j v_j^T)
        indices = torch.arange(L, device=x.device, dtype=torch.float32).view(1, L, 1, 1)

        # Work in log space; clamp lambda away from 0 so log() is finite.
        log_lambdas = torch.log(lambdas.clamp(min=1e-10))

        # Clamp exponents so exp() cannot overflow/underflow float32.
        exp_k = (-indices * log_lambdas).clamp(min=-50, max=50)
        exp_q = (indices * log_lambdas).clamp(min=-50, max=50)

        decay_k = torch.exp(exp_k)  # lambda^(-j), applied to keys
        decay_q = torch.exp(exp_q)  # lambda^(t),  applied to queries

        k_decayed = k * decay_k

        # Outer products k_j v_j^T; prefix sum over time = fast-weight memory M_t.
        kv = torch.einsum('blhe,blhf->blhef', k_decayed, v)
        memory_state = torch.cumsum(kv, dim=1)

        # Running (decayed) key sum for the attention normalizer.
        k_sum_decayed = torch.cumsum(k_decayed, dim=1)

        q_decayed = q * decay_q

        # Numerator:   q_t . M_t
        # Denominator: q_t . sum_{j<=t} lambda^(t-j) k_j
        num = torch.einsum('blhe,blhef->blhf', q_decayed, memory_state)

        den = torch.einsum('blhe,blhe->blh', q_decayed, k_sum_decayed)
        den = den.unsqueeze(-1) + 1e-6  # epsilon guards division by ~0

        out = num / den

        out = out.reshape(B, L, D)
        out = self.out_proj(out)

        out = self.dropout(self.norm(out))
        return out.to(input_dtype)
|
|
|
|
|
|