File size: 5,184 Bytes
344be51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
## Developer: inkbytefo
## Modified: 2025-11-23

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class HebbianMemory(nn.Module):
    """
    Hebbian Memory Module (Fast Weights).

    Implements the update rule:
    M_t = lambda * M_{t-1} + K_t * V_t^T
    O_t = Q_t * M_t

    CRITICAL CHANGE:
    To prevent numerical overflow in parallel computation (cumsum),
    the decay rate (lambda) is constrained to the range [0.99, 1.0].
    This ensures lambda^(-L) does not explode for L=1024.
    """
    def __init__(self, d_model, num_heads=8, dropout=0.1):
        """
        Args:
            d_model:   Model (embedding) dimension. Must be divisible by num_heads.
            num_heads: Number of attention-style heads, each with its own decay.
            dropout:   Dropout probability applied after the output LayerNorm.
        """
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Fused Q/K/V projection and output projection.
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        # Feature map: ELU + 1 ensures positivity for valid probability kernel
        # (the "+ 1.0" is applied in forward()).
        self.feature_map = nn.ELU()

        # Learnable decay parameter, one per head.
        # Zero-init -> sigmoid(0) = 0.5, mapped to ~0.995 in forward().
        self.decay_logits = nn.Parameter(torch.zeros(num_heads))

        self.norm = nn.LayerNorm(d_model)

        # Plasticity factor (alpha) - controlled externally via set_plasticity().
        self.plasticity = 1.0

    def set_plasticity(self, alpha):
        """
        Updates the plasticity coefficient (alpha).
        alpha: float in [0, 1].
               0.1 -> Childhood (Fast forgetting)
               0.99 -> Adulthood (Stable memory)
        """
        self.plasticity = alpha

    @torch.amp.autocast('cuda', enabled=False)
    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch, seq_len, d_model), any float dtype.

        Returns:
            Tensor of shape (batch, seq_len, d_model) in x's original dtype.
        """
        # CRITICAL: Bypass AMP for this entire module to prevent NaN.
        # With plasticity=0.1, decay factors become exp(±50) and the cumsum
        # operations accumulate massive intermediate values that overflow in
        # float16. We must use float32 for all computations including linear
        # layers.
        #
        # BUG FIX: capture the caller's dtype BEFORE casting to float32.
        # Previously input_dtype was read after x.float(), so it was always
        # torch.float32 and the final .to(input_dtype) never restored the
        # original dtype for fp16/bf16/fp64 inputs.
        input_dtype = x.dtype
        x = x.float()  # Ensure all internal computation runs in float32

        B, L, D = x.shape
        H = self.num_heads
        E = self.head_dim

        # 1. Projections
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        # Reshape (B, L, H, E)
        q = q.view(B, L, H, E)
        k = k.view(B, L, H, E)
        v = v.view(B, L, H, E)

        # 2. Feature Map (Kernel Trick): ELU(x) + 1 > 0, so numerator and
        # denominator below stay strictly positive.
        q = self.feature_map(q) + 1.0
        k = self.feature_map(k) + 1.0

        # Scale Q to prevent magnitude explosion
        q = q / math.sqrt(E)

        # 3. Decay Factor (Lambda) - STABILIZED
        # Map sigmoid (0,1) to (0.990, 1.0).
        # This prevents overflow. 0.99^-1024 ~= 29468 (Safe for FP32)
        raw_sigmoid = torch.sigmoid(self.decay_logits).view(1, 1, H, 1)
        lambdas = 0.99 + (0.01 * raw_sigmoid)

        # Apply Plasticity Schedule
        # Effective Lambda = Lambda * Alpha
        # If Alpha is low (childhood), decay is very fast.
        lambdas = lambdas * self.plasticity

        # 4. Parallel Hebbian Update
        # Formula: O_i = (Q_i * sum_{j=1}^i lambda^{i-j} K_j^T V_j)
        # Implementation: Q_i * lambda^i * cumsum(lambda^-j * K_j * V_j)

        indices = torch.arange(L, device=x.device, dtype=torch.float32).view(1, L, 1, 1)

        # Use log-space arithmetic to prevent overflow/underflow.
        # lambdas <= 1, so log_lambdas <= 0; the clamp guards against log(0)
        # when plasticity is set to 0.
        log_lambdas = torch.log(lambdas.clamp(min=1e-10))

        # Clamp the exponent BEFORE exp() to prevent overflow.
        # We use ±50 as a safe range that works for float32.
        exp_k = (-indices * log_lambdas).clamp(min=-50, max=50)
        exp_q = (indices * log_lambdas).clamp(min=-50, max=50)

        # Compute decay factors
        decay_k = torch.exp(exp_k)  # lambda^-indices
        decay_q = torch.exp(exp_q)  # lambda^indices

        k_decayed = k * decay_k

        # Memory State Accumulation (KV)
        # (B, L, H, E) x (B, L, H, E) -> (B, L, H, E, E) outer products
        kv = torch.einsum('blhe,blhf->blhef', k_decayed, v)

        # Cumsum (The "Write" Operation)
        memory_state = torch.cumsum(kv, dim=1)  # (B, L, H, E, E)

        # Denominator Accumulation (Z) for normalization
        k_sum_decayed = torch.cumsum(k_decayed, dim=1)  # (B, L, H, E)

        # Read Operation (Query * Memory)
        q_decayed = q * decay_q

        # Num: (B, L, H, E) * (B, L, H, E, E) -> (B, L, H, E)
        num = torch.einsum('blhe,blhef->blhf', q_decayed, memory_state)

        # Den: (B, L, H, E) * (B, L, H, E) -> (B, L, H)
        den = torch.einsum('blhe,blhe->blh', q_decayed, k_sum_decayed)
        den = den.unsqueeze(-1) + 1e-6  # Stability epsilon

        out = num / den

        # Final Projection
        out = out.reshape(B, L, D)
        out = self.out_proj(out)

        # Normalize + dropout in float32, then restore the caller's dtype.
        out = self.dropout(self.norm(out))
        return out.to(input_dtype)