# memory-augmented-generation / miras_memory.py
# Uploaded by Pavantej via huggingface_hub (commit 6e79137, verified).
"""
MIRAS-inspired associative memory module.

Implements an associative memory that learns key-value mappings
through an attentional-bias objective at test time.
"""
import torch
import torch.nn as nn
class MIRASMemory(nn.Module):
    """
    Associative memory module inspired by the MIRAS framework.

    The memory maps keys to values with a single learnable linear projection,
    ``predicted_value = key @ W``. It is meant to be updated at test time by
    gradient descent on :meth:`compute_loss`; the optimizer step itself is
    performed by the caller, not by this class.

    Args:
        memory_dim: Dimensionality of memory keys and values.
        init_scale: Multiplier applied to the standard-normal initialization
            of ``W``.
    """

    def __init__(self, memory_dim=256, init_scale=0.01):
        super().__init__()
        self.memory_dim = memory_dim
        # Memory matrix W: (memory_dim, memory_dim), maps keys to values.
        self.W = nn.Parameter(torch.randn(memory_dim, memory_dim) * init_scale)
        # Running statistics for retention_gate() / get_stats(). Registered as
        # buffers so they are saved in state_dict() and move with .to(device).
        self.register_buffer('update_count', torch.tensor(0))
        self.register_buffer('total_loss', torch.tensor(0.0))

    def forward(self, key):
        """
        Query memory with a key (differentiable with respect to ``W``).

        Args:
            key: (batch_size, memory_dim) tensor. Because the projection is a
                plain matrix product, extra leading dimensions also work.
        Returns:
            predicted_value: tensor with the same shape as ``key``.
        """
        return key @ self.W

    def query(self, key):
        """
        Query memory without building an autograd graph (for generation).

        Args:
            key: (batch_size, memory_dim) tensor
        Returns:
            memory_output: (batch_size, memory_dim) tensor, not requiring grad
        """
        with torch.no_grad():
            return self.forward(key)

    def compute_loss(self, key, value):
        """
        Compute the attentional-bias loss (mean squared error) between the
        memory's prediction for ``key`` and the true ``value``.

        Args:
            key: (batch_size, memory_dim)
            value: (batch_size, memory_dim)
        Returns:
            loss: scalar tensor connected to ``W`` in the autograd graph
        """
        pred = self.forward(key)
        return ((pred - value) ** 2).mean()

    def _average_loss(self):
        """Return the mean loss recorded by update_stats() (0.0 before any update)."""
        return self.total_loss.item() / max(int(self.update_count.item()), 1)

    def retention_gate(self, loss, *, min_factor=0.5, max_factor=2.0):
        """
        Retention gate: a loss above the running average is more surprising,
        so it should be learned more aggressively.

        Returns a learning-rate scaling factor equal to
        ``current_loss / average_loss``, clamped to ``[min_factor, max_factor]``:

        * ratio 1.0 (typical)    -> 1.0
        * ratio 2.0 (surprising) -> 2.0
        * ratio 0.5 (familiar)   -> 0.5

        Before any update has been recorded there is no history. The current
        loss is then treated as the average, so the factor is neutral (1.0,
        subject to the clamp).

        Args:
            loss: scalar tensor (or Python number) for the current step.
            min_factor: Lower bound of the returned factor.
            max_factor: Upper bound of the returned factor.
        Returns:
            retention_factor: Python float in ``[min_factor, max_factor]``.
        """
        current = float(loss)
        # Detect "no history" from the update counter, not from the size of
        # the average. Otherwise a genuinely small average loss would be
        # mistaken for an empty history and the gate would stop working.
        if int(self.update_count.item()) == 0:
            surprise_ratio = 1.0
        else:
            avg_loss = self._average_loss()
            if avg_loss > 0.0:
                surprise_ratio = current / avg_loss
            else:
                # Every recorded loss was zero: any positive loss is maximally
                # surprising, and a zero loss matches the history exactly.
                surprise_ratio = float('inf') if current > 0.0 else 1.0
        return min(max(surprise_ratio, min_factor), max_factor)

    def update_stats(self, loss):
        """
        Record one update's loss in the running statistics.

        Args:
            loss: scalar tensor (or Python number). Only its value is stored
                (via float()), so no autograd graph is retained.
        """
        self.update_count += 1
        self.total_loss += float(loss)

    def get_stats(self):
        """
        Get memory statistics.

        Returns:
            dict with ``updates`` (int, number of update_stats() calls),
            ``avg_loss`` (float, mean recorded loss, 0.0 before any update)
            and ``memory_size`` (number of elements in ``W``).
        """
        return {
            'updates': int(self.update_count.item()),
            'avg_loss': self._average_loss(),
            'memory_size': self.W.numel(),
        }