File size: 3,777 Bytes
9659b2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch
import torch.nn as nn
import torch.nn.functional as F
from otitans_core import OLoRALinear

class OTitansMemoryGate(nn.Module):
    """
    Phase 2: The OTITANS Memory Core.
    A recurrent associative memory that is read and written with a delta rule,
    then blended into the residual stream through a learned sigmoid gate.

    Args:
        hidden_size: Width of the incoming hidden states. The memory matrix is
            ``hidden_size x hidden_size``.
        rank: Accepted for interface compatibility; this module does not use it
            (the projections below are plain ``nn.Linear`` placeholders).
        memory_momentum: Multiplicative decay applied to the previous memory at
            every time step before the new update is added.

    Memory layout:
        Writes add ``surprise ⊗ key``, so ``memory_state[i, j]`` links value
        feature ``i`` to key feature ``j``. Reads are therefore ``memory @ key``
        (and ``memory @ query``), which contracts over the key axis.
    """
    def __init__(self, hidden_size: int, rank: int = 8, memory_momentum: float = 0.9):
        super().__init__()
        self.hidden_size = hidden_size
        self.memory_momentum = memory_momentum

        # 1. The Orthogonal Projections
        # We use standard nn.Linear here as placeholders, but in the actual injection script,
        # we will map these directly to Gemma's layers wrapped in our OLoRALinear class.
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)

        # 2. The Memory Gate
        # A small MLP over [hidden_states, memory_retrieval] that outputs a per-feature
        # value in (0, 1): how much of the memory retrieval to add to the hidden states.
        self.gate = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size // 4),
            nn.SiLU(),
            nn.Linear(hidden_size // 4, hidden_size),
            nn.Sigmoid()
        )

        # 3. The Persistent State
        # This is where Nyxxie's continuous memory lives.
        # A single (hidden_size, hidden_size) matrix shared by every batch element;
        # it is carried across forward() calls until reset_memory() is called.
        self.register_buffer("memory_state", torch.zeros(hidden_size, hidden_size))

    def reset_memory(self):
        """Wipes the recurrent memory clean for a new session."""
        self.memory_state.zero_()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Run the recurrent memory over a sequence and return the gated result.

        Args:
            hidden_states: Tensor of shape ``(batch, seq_len, hidden_size)``.

        Returns:
            ``hidden_states + gate * memory_retrieval``, same shape as the input.

        Side effects:
            Overwrites ``self.memory_state`` with the final memory, detached from
            the autograd graph. Every batch element starts from the same stored
            matrix; the per-sample memories are averaged over the batch before
            being stored (exact when ``batch == 1``).
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Nothing to read or write. Without this guard torch.stack([]) raises for an
        # empty sequence and the batch mean below would store NaNs for an empty batch.
        if batch_size == 0 or seq_len == 0:
            return hidden_states

        # Generate Queries, Keys, and Values through the orthogonal pathways
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        memory_outputs = []

        # The Recurrent Engine (Autoregressive Delta Rule Update)
        # Note: In training, we will parallelize this. For inference, it processes step-by-step.
        # Give every batch element its own working copy of the shared state so the
        # memory is (batch, hidden, hidden) from the first step onward.
        current_memory = self.memory_state.unsqueeze(0).repeat(batch_size, 1, 1)

        for t in range(seq_len):
            q_t = q[:, t, :].unsqueeze(2)  # Current query as a column: (batch, hidden, 1)
            k_t = k[:, t, :].unsqueeze(2)  # Current key as a column:   (batch, hidden, 1)
            v_t = v[:, t, :]               # Current value:             (batch, hidden)

            # Read from the current memory state (before this step's write)
            # Retrieval = Memory · Q
            retrieval = torch.bmm(current_memory, q_t).squeeze(2)
            memory_outputs.append(retrieval)

            # Update the memory state using the Surprise / Delta mechanism
            # How much does the new Key/Value differ from what we already know?
            # The prediction uses the same orientation as the write below (Memory · Key),
            # so re-presenting a stored key reproduces its value and the surprise shrinks.
            memory_prediction = torch.bmm(current_memory, k_t).squeeze(2)
            surprise = v_t - memory_prediction

            # Update: M_t = momentum * M_{t-1} + (Surprise ⊗ Key)
            update = torch.bmm(surprise.unsqueeze(2), k_t.transpose(1, 2))
            current_memory = (self.memory_momentum * current_memory) + update

        # Stack the memory retrievals back into the sequence shape
        memory_out_tensor = torch.stack(memory_outputs, dim=1)

        # Save the updated memory state for the next generation step.
        # The buffer is (hidden, hidden) while current_memory is (batch, hidden, hidden),
        # so collapse the batch axis; for batch == 1 this is the exact final memory.
        self.memory_state.copy_(current_memory.detach().mean(dim=0))

        # Calculate the Gating mechanism: How much should we blend memory with standard logic?
        # We concatenate the base hidden states with the memory retrieval to decide.
        gate_input = torch.cat([hidden_states, memory_out_tensor], dim=-1)
        gate_value = self.gate(gate_input)

        # Return the gated memory logic
        return hidden_states + (gate_value * memory_out_tensor)