## Developer: inkbytefo
## Modified: 2025-11-23
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class HebbianMemory(nn.Module):
    """
    Hebbian Memory Module (Fast Weights).

    Implements the update rule:
        M_t = lambda * M_{t-1} + K_t * V_t^T
        O_t = Q_t * M_t

    CRITICAL CHANGE:
    To prevent numerical overflow in the parallel computation (cumsum),
    the decay rate (lambda) is constrained to the range [0.99, 1.0].
    This ensures lambda^(-L) does not explode for L = 1024.
    """
    def __init__(self, d_model, num_heads=8, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        # Feature map: ELU + 1 ensures positivity for a valid probability kernel
        self.feature_map = nn.ELU()

        # Learnable decay parameter.
        # Initialized so sigmoid outputs ~0.5; mapped to the [0.99, 1.0] range in forward()
        self.decay_logits = nn.Parameter(torch.zeros(num_heads))

        self.norm = nn.LayerNorm(d_model)

        # Plasticity factor (alpha) - controlled externally via set_plasticity()
        self.plasticity = 1.0
    def set_plasticity(self, alpha):
        """
        Updates the plasticity coefficient (alpha).

        alpha: float in [0, 1].
            0.1  -> Childhood (fast forgetting)
            0.99 -> Adulthood (stable memory)
        """
        self.plasticity = alpha
    @torch.amp.autocast('cuda', enabled=False)
    def forward(self, x):
        # CRITICAL: Bypass AMP for this entire module to prevent NaN.
        # With plasticity=0.1, decay factors become exp(+-50) and the cumsum
        # operations accumulate massive intermediate values that overflow in float16,
        # so all computations, including the linear layers, must run in float32.
        input_dtype = x.dtype  # Remember the original dtype before upcasting
        x = x.float()          # Ensure the input is float32

        B, L, D = x.shape
        H = self.num_heads
        E = self.head_dim
        # 1. Projections
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        # Reshape to (B, L, H, E)
        q = q.view(B, L, H, E)
        k = k.view(B, L, H, E)
        v = v.view(B, L, H, E)

        # 2. Feature Map (Kernel Trick): ELU(x) + 1 keeps features strictly positive
        q = self.feature_map(q) + 1.0
        k = self.feature_map(k) + 1.0

        # Scale Q to prevent magnitude explosion
        q = q / math.sqrt(E)
        # 3. Decay Factor (Lambda) - STABILIZED
        # Map sigmoid output (0, 1) to (0.99, 1.0).
        # This prevents overflow: 0.99^-1024 ~= 29468 (safe for FP32)
        raw_sigmoid = torch.sigmoid(self.decay_logits).view(1, 1, H, 1)
        lambdas = 0.99 + (0.01 * raw_sigmoid)

        # Apply the plasticity schedule: effective lambda = lambda * alpha.
        # If alpha is low (childhood), decay is very fast.
        lambdas = lambdas * self.plasticity
        # 4. Parallel Hebbian Update
        # Formula:        O_i = Q_i * sum_{j=1}^{i} lambda^{i-j} K_j^T V_j
        # Implementation: O_i = (Q_i * lambda^i) * cumsum_j(lambda^{-j} * K_j^T V_j)
        indices = torch.arange(L, device=x.device, dtype=torch.float32).view(1, L, 1, 1)

        # Use log-space arithmetic to prevent overflow/underflow.
        log_lambdas = torch.log(lambdas.clamp(min=1e-10))

        # Clamp the exponent BEFORE exp() to prevent overflow;
        # +-50 is a safe range for float32.
        exp_k = (-indices * log_lambdas).clamp(min=-50, max=50)
        exp_q = (indices * log_lambdas).clamp(min=-50, max=50)

        # Compute decay factors
        decay_k = torch.exp(exp_k)  # lambda^(-indices)
        decay_q = torch.exp(exp_q)  # lambda^(indices)

        k_decayed = k * decay_k

        # Memory state accumulation (KV outer products)
        # (B, L, H, E) x (B, L, H, E) -> (B, L, H, E, E)
        kv = torch.einsum('blhe,blhf->blhef', k_decayed, v)

        # Cumsum (the "write" operation)
        memory_state = torch.cumsum(kv, dim=1)  # (B, L, H, E, E)

        # Denominator accumulation (Z) for normalization
        k_sum_decayed = torch.cumsum(k_decayed, dim=1)  # (B, L, H, E)
        # Read operation (Query * Memory)
        q_decayed = q * decay_q

        # Numerator: (B, L, H, E) x (B, L, H, E, E) -> (B, L, H, E)
        num = torch.einsum('blhe,blhef->blhf', q_decayed, memory_state)

        # Denominator: (B, L, H, E) x (B, L, H, E) -> (B, L, H)
        den = torch.einsum('blhe,blhe->blh', q_decayed, k_sum_decayed)
        den = den.unsqueeze(-1) + 1e-6  # Stability epsilon

        out = num / den

        # Final projection
        out = out.reshape(B, L, D)
        out = self.out_proj(out)

        # LayerNorm and dropout run in float32; cast back to the original dtype on return
        out = self.dropout(self.norm(out))
        return out.to(input_dtype)
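

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The shapes, seed,
# and alpha values below are illustrative assumptions only; they simply
# exercise forward() and set_plasticity() as a quick smoke test.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    layer = HebbianMemory(d_model=256, num_heads=8, dropout=0.1)
    layer.eval()  # Disable dropout so the two runs differ only via plasticity

    x = torch.randn(2, 128, 256)  # (batch, seq_len, d_model), assumed sizes

    # "Childhood" regime: low alpha -> fast forgetting
    layer.set_plasticity(0.1)
    y_child = layer(x)

    # "Adulthood" regime: high alpha -> stable memory
    layer.set_plasticity(0.99)
    y_adult = layer(x)

    print(y_child.shape, y_adult.shape)  # Expected: torch.Size([2, 128, 256]) twice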