File size: 2,765 Bytes
f6484bc
98c160d
f6484bc
 
 
 
98c160d
f6484bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a878980
 
 
f6484bc
a878980
f6484bc
 
a878980
 
 
 
 
 
 
 
f6484bc
a878980
 
 
 
f6484bc
 
 
 
9472f1c
 
 
f6484bc
 
 
 
9472f1c
f6484bc
9472f1c
98c160d
 
9472f1c
f6484bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98c160d
f6484bc
98c160d
f6484bc
9472f1c
 
f6484bc
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
## Developer: inkbytefo
## Modified: 2025-11-23

import torch
import torch.nn as nn
import torch.nn.functional as F
from .memory import HebbianMemory

class SlidingWindowAttention(nn.Module):
    """
    Multi-head causal self-attention restricted to a sliding window.

    Each query position attends only to itself and the previous
    ``window_size - 1`` positions, so the receptive field is local
    while the cost per head stays at a banded score matrix.
    """

    def __init__(self, d_model: int, num_heads: int, window_size: int):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.window_size = window_size
        self.head_dim = d_model // num_heads

        # Fused projection producing queries, keys and values in one matmul.
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply windowed causal attention to ``x`` of shape (B, L, D)."""
        batch, seq_len, _ = x.shape
        heads = self.num_heads
        head_dim = self.head_dim
        scale = 1.0 / (head_dim ** 0.5)

        # Project once, then split into per-head q/k/v: each (B, H, L, E).
        projected = self.qkv(x).reshape(batch, seq_len, 3, heads, head_dim)
        q, k, v = projected.permute(2, 0, 3, 1, 4).unbind(0)

        # Manual matmul + softmax, kept explicit for numerical control.
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale

        # Relative offset (j - i): positive means a future token, and
        # <= -window_size means a token too far in the past; ban both.
        positions = torch.arange(seq_len, device=x.device)
        offsets = positions[None, :] - positions[:, None]
        banned = (offsets > 0) | (offsets <= -self.window_size)

        # -1e4 (rather than -inf) keeps the fill representable in fp16.
        scores = scores.masked_fill(banned, -1e4)
        weights = F.softmax(scores, dim=-1)
        context = torch.matmul(weights, v)

        context = context.transpose(1, 2).reshape(batch, seq_len, self.d_model)
        return self.proj(context)

class HybridBlock(nn.Module):
    """
    Transformer-style block fusing two parallel token-mixing paths:
    sliding-window attention (local) and Hebbian memory (global),
    followed by a pre-norm residual MLP.
    """

    def __init__(self, d_model, num_heads, window_size, dropout):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)

        # Branch 1: local precision via windowed causal attention.
        self.attn = SlidingWindowAttention(d_model, num_heads, window_size)

        # Branch 2: global context via dynamic fast weights
        # (replaces the earlier static LinearAttention).
        self.memory = HebbianMemory(d_model, num_heads, dropout)

        self.out_proj = nn.Linear(d_model, d_model)

        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout)
        )
        self.norm_mlp = nn.LayerNorm(d_model)

    def forward(self, x):
        """Mix tokens through both branches, then apply the MLP residual."""
        shortcut = x
        normed = self.norm1(x)

        # Both mixers consume the same normalized input in parallel.
        local_out = self.attn(normed)
        global_out = self.memory(normed)

        # Fuse by summation, project, and add the residual connection.
        x = shortcut + self.out_proj(local_out + global_out)

        # Position-wise feed-forward with its own pre-norm residual.
        return x + self.mlp(self.norm_mlp(x))