"""
MIRAS-inspired associative memory module.

Implements an associative memory that learns key-value mappings
through an attentional-bias objective at test time.
"""
|
|
|
import math

import torch
import torch.nn as nn
|
|
|
|
|
class MIRASMemory(nn.Module):
    """Associative memory module inspired by the MIRAS framework.

    The memory maps keys to values with a single bias-free linear
    projection (``key @ W``) and is meant to be updated at test time via
    gradient descent on :meth:`compute_loss`. A running average of the
    recorded losses is kept so that :meth:`retention_gate` can scale the
    learning rate by how surprising a new sample is.

    Args:
        memory_dim: Dimensionality of memory keys/values.
        init_scale: Scale for random weight initialization.
    """

    # Below this value the running average is treated as "no usable
    # history"; it is also the floor of the surprise-ratio denominator.
    _MIN_AVG_LOSS = 0.001
    # Bounds of the factor returned by ``retention_gate``.
    _MIN_RETENTION = 0.5
    _MAX_RETENTION = 2.0

    def __init__(self, memory_dim: int = 256, init_scale: float = 0.01):
        super().__init__()
        self.memory_dim = memory_dim

        # Square projection matrix, initialised with small Gaussian noise.
        self.W = nn.Parameter(
            torch.randn(memory_dim, memory_dim) * init_scale
        )

        # Registered as buffers rather than parameters: they are part of
        # the module state but are not trained.
        self.register_buffer('update_count', torch.tensor(0))
        self.register_buffer('total_loss', torch.tensor(0.0))

    def forward(self, key: torch.Tensor) -> torch.Tensor:
        """Query memory with a key.

        Args:
            key: (batch_size, memory_dim) tensor.

        Returns:
            predicted_value: (batch_size, memory_dim) tensor, ``key @ W``.
        """
        return key @ self.W

    def query(self, key: torch.Tensor) -> torch.Tensor:
        """Query memory without computing gradients (for generation).

        Args:
            key: (batch_size, memory_dim) tensor.

        Returns:
            memory_output: (batch_size, memory_dim) tensor that is not
            attached to the autograd graph.
        """
        with torch.no_grad():
            return self.forward(key)

    def compute_loss(self, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """Compute the attentional bias loss between predicted and true value.

        The loss is the mean squared error between ``forward(key)`` and
        ``value``.

        Args:
            key: (batch_size, memory_dim) tensor.
            value: (batch_size, memory_dim) tensor.

        Returns:
            loss: scalar tensor.
        """
        pred = self.forward(key)
        return ((pred - value) ** 2).mean()

    @staticmethod
    def _as_float(loss) -> float:
        """Convert a scalar tensor or a plain number to a Python float."""
        if isinstance(loss, torch.Tensor):
            return loss.item()
        return float(loss)

    def _average_loss(self) -> torch.Tensor:
        """Return the running mean loss as a 0-dim tensor (0.0 before any update)."""
        # Clamping the counter to at least 1 avoids a division by zero
        # before the first update without leaving tensor arithmetic.
        return self.total_loss / self.update_count.clamp(min=1)

    def retention_gate(self, loss) -> float:
        """Scale factor for the learning rate based on surprise.

        Higher loss relative to the running average means the sample is
        more surprising and should be learned more aggressively.

        Args:
            loss: scalar tensor (or plain number) holding the current loss.

        Returns:
            retention_factor: float in [0.5, 2.0]. A neutral 1.0 is
            returned when the ratio cannot be computed (NaN loss or NaN
            running average) so that a NaN never reaches the caller's
            learning rate.
        """
        # Read the loss once; each tensor -> float conversion may force a
        # host/device synchronisation.
        current = self._as_float(loss)

        avg_loss = self._average_loss().item()
        if avg_loss < self._MIN_AVG_LOSS:
            # No meaningful history yet: compare the loss with itself.
            avg_loss = current

        surprise_ratio = current / max(avg_loss, self._MIN_AVG_LOSS)
        if math.isnan(surprise_ratio):
            return 1.0

        return min(max(surprise_ratio, self._MIN_RETENTION), self._MAX_RETENTION)

    def update_stats(self, loss) -> None:
        """Track memory statistics.

        Args:
            loss: scalar tensor (or plain number) to add to the running
                total; the update counter is incremented by one.
        """
        self.update_count += 1
        self.total_loss += self._as_float(loss)

    def get_stats(self) -> dict:
        """Get memory statistics.

        Returns:
            dict with ``'updates'`` (number of recorded updates),
            ``'avg_loss'`` (mean recorded loss, 0.0 before any update) and
            ``'memory_size'`` (number of elements in ``W``).
        """
        return {
            'updates': self.update_count.item(),
            'avg_loss': self._average_loss().item(),
            'memory_size': self.W.numel(),
        }
|
|
|