""" MIRAS-inspired Associative Memory Module Implements an associative memory that learns key-value mappings through attentional bias objective during test time. """ import torch import torch.nn as nn class MIRASMemory(nn.Module): """ Associative memory module inspired by MIRAS framework. The memory learns to map keys to values using a simple linear projection and updates itself during test time via gradient descent. Args: memory_dim: Dimensionality of memory keys/values init_scale: Scale for random weight initialization """ def __init__(self, memory_dim=256, init_scale=0.01): super().__init__() self.memory_dim = memory_dim # Memory matrix: maps keys to values # W: (memory_dim, memory_dim) self.W = nn.Parameter( torch.randn(memory_dim, memory_dim) * init_scale ) # Track number of updates for retention gate self.register_buffer('update_count', torch.tensor(0)) self.register_buffer('total_loss', torch.tensor(0.0)) def forward(self, key): """ Query memory with a key. Args: key: (batch_size, memory_dim) tensor Returns: predicted_value: (batch_size, memory_dim) tensor """ # Simple linear mapping: pred_v = k @ W predicted_value = key @ self.W return predicted_value def query(self, key): """ Query memory without computing gradients (for generation). Args: key: (batch_size, memory_dim) tensor Returns: memory_output: (batch_size, memory_dim) tensor """ with torch.no_grad(): return self.forward(key) def compute_loss(self, key, value): """ Compute attentional bias loss between predicted and true value. Args: key: (batch_size, memory_dim) value: (batch_size, memory_dim) Returns: loss: scalar tensor """ pred = self.forward(key) loss = ((pred - value) ** 2).mean() return loss def retention_gate(self, loss): """ Retention gate: higher loss relative to average = more surprising = more memorable. Returns a scaling factor for the learning rate based on surprise. If current loss is higher than average, learn more aggressively. Args: loss: scalar tensor Returns: retention_factor: scalar in range [0.5, 2.0] """ # Get running average loss (or use current if first update) avg_loss = (self.total_loss / max(self.update_count, 1)).item() if avg_loss < 0.001: # First few updates avg_loss = loss.item() # Compute surprise ratio: how much higher is current loss vs average? surprise_ratio = loss.item() / max(avg_loss, 0.001) # Map surprise ratio to retention factor # ratio = 1.0 (average) -> retention = 1.0 # ratio = 2.0 (2x surprise) -> retention = 2.0 # ratio = 0.5 (familiar) -> retention = 0.5 retention_factor = torch.clamp(torch.tensor(surprise_ratio), 0.5, 2.0) return retention_factor.item() def update_stats(self, loss): """Track memory statistics.""" self.update_count += 1 self.total_loss += loss.item() def get_stats(self): """Get memory statistics.""" avg_loss = self.total_loss / max(self.update_count, 1) return { 'updates': self.update_count.item(), 'avg_loss': avg_loss.item(), 'memory_size': self.W.numel() }