Update AGIFORMER with Turkish benchmark
Browse files- src/models/reasoning.py +59 -0
src/models/reasoning.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Developer: inkbytefo
|
| 2 |
+
## Modified: 2025-11-22
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
class RecurrentReasoningBlock(nn.Module):
    """
    System 2 Thinking Module.

    Refines the latent representation through N steps of recurrence using a
    gated residual update:

        z_{t+1} = z_t + sigmoid(W_g LayerNorm(z_t)) * MLP(LayerNorm(z_t))

    (The gate — similar to an LSTM update gate — controls how much of each
    "thought" is accepted, which helps stability across steps. Note the gate
    term: a plain `z + MLP(norm(z))` would NOT match this implementation.)
    """

    def __init__(self, d_model, thinking_steps=3, dropout=0.1):
        """
        Args:
            d_model (int): Width of the latent representation.
            thinking_steps (int): Number of refinement iterations to unroll.
            dropout (float): Dropout probability inside the thinking MLP.
        """
        super().__init__()
        self.d_model = d_model
        self.thinking_steps = thinking_steps

        # The "Thinking" Core:
        # a dense MLP (4x expansion, GELU) that transforms the latent space.
        self.think_mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout),
        )

        self.norm = nn.LayerNorm(d_model)

        # Gate to control how much "thought" updates the state
        # (similar to LSTM update gate, helps stability).
        self.gate = nn.Linear(d_model, d_model)

    def forward(self, x):
        """
        Args:
            x: (Batch, Seq_Len, d_model) - Initial Latent (System 1 output)

        Returns:
            Refined latent of the same shape (System 2 output).
        """
        # Iterative Refinement: we unroll the loop for 'thinking_steps'.
        # With thinking_steps == 0 the input passes through unchanged.
        current_thought = x

        for _ in range(self.thinking_steps):
            # Pre-norm keeps the recurrence well-conditioned.
            normed = self.norm(current_thought)

            # Compute update candidate.
            update = self.think_mlp(normed)

            # Compute gate in (0, 1): decides how much of the new
            # thought to accept.
            g = torch.sigmoid(self.gate(normed))

            # Gated residual update: z_{t+1} = z_t + gate * update
            current_thought = current_thought + (g * update)

        return current_thought