| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from otitans_core import OLoRALinear |
| |
|
class OTitansMemoryGate(nn.Module):
    """
    Phase 2: The OTITANS Memory Core.

    A recurrent associative memory behind the residual stream. At each
    timestep the module:

    1. retrieves from a ``(hidden, hidden)`` memory matrix via a query
       projection,
    2. computes a "surprise" signal — the value the memory failed to
       predict from the key,
    3. writes the surprise back as an outer-product update with momentum
       decay of the old contents.

    The retrieved memory is blended into the input through a learned
    sigmoid gate, and the final memory state persists across calls in a
    (non-trainable) buffer until :meth:`reset_memory` is called.

    NOTE(review): the class docstring mentions orthogonal LoRA projections
    and ``rank`` is accepted, but the visible projections are plain
    ``nn.Linear`` — ``rank`` is kept only for interface compatibility with
    an ``OLoRALinear``-based variant; it is unused here.
    """

    def __init__(self, hidden_size: int, rank: int = 8, memory_momentum: float = 0.9):
        super().__init__()
        self.hidden_size = hidden_size
        self.memory_momentum = memory_momentum

        # Query / key / value projections into the memory space.
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)

        # Per-feature gate deciding how much retrieved memory to admit.
        # Input is [hidden_states ; memory_out] concatenated on the last dim.
        self.gate = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size // 4),
            nn.SiLU(),
            nn.Linear(hidden_size // 4, hidden_size),
            nn.Sigmoid(),
        )

        # Persistent cross-call memory matrix. A buffer (not a Parameter):
        # it is updated by the forward pass itself, never by the optimizer.
        self.register_buffer("memory_state", torch.zeros(hidden_size, hidden_size))

    def reset_memory(self):
        """Wipes the recurrent memory clean for a new session."""
        self.memory_state.zero_()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Retrieve-then-update memory pass.

        Args:
            hidden_states: tensor of shape ``(batch, seq_len, hidden_size)``
                (the shape unpack and the projections require exactly this).

        Returns:
            Tensor of the same shape as ``hidden_states``:
            ``hidden_states + gate * memory_retrieval``.

        Side effects:
            Overwrites ``self.memory_state`` with the batch-averaged,
            detached final memory, so state carries across calls.
        """
        batch_size, seq_len, _ = hidden_states.shape

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # BUG FIX: the outer-product update below is batched (B, H, H), but
        # the original started from a 2-D (H, H) memory and finally tried
        # `memory_state.copy_(current_memory.detach())` — copying a 3-D
        # tensor into the 2-D buffer, which raises for *every* batch size.
        # Track a per-sample memory from the start (t=0 values are identical
        # to the original's broadcasted matmul) and reduce when persisting.
        current_memory = (
            self.memory_state.unsqueeze(0).expand(batch_size, -1, -1).clone()
        )

        memory_outputs = []
        for t in range(seq_len):
            q_t = q[:, t, :]
            k_t = k[:, t, :]
            v_t = v[:, t, :]

            # Retrieve: (B, 1, H) @ (B, H, H) -> (B, H).
            retrieval = torch.bmm(q_t.unsqueeze(1), current_memory).squeeze(1)
            memory_outputs.append(retrieval)

            # Surprise = value minus what the memory already predicts
            # from the key (TITANS-style associative error).
            memory_prediction = torch.bmm(k_t.unsqueeze(1), current_memory).squeeze(1)
            surprise = v_t - memory_prediction

            # Outer-product write (B, H, 1) x (B, 1, H) with momentum decay.
            update = torch.bmm(surprise.unsqueeze(2), k_t.unsqueeze(1))
            current_memory = (self.memory_momentum * current_memory) + update

        memory_out_tensor = torch.stack(memory_outputs, dim=1)

        # Persist across calls: detach from autograd and average over the
        # batch so the (H, H) buffer shape is preserved.
        self.memory_state.copy_(current_memory.detach().mean(dim=0))

        # Gate the retrieved memory into the residual stream.
        gate_input = torch.cat([hidden_states, memory_out_tensor], dim=-1)
        gate_value = self.gate(gate_input)

        return hidden_states + (gate_value * memory_out_tensor)
| |
|