# rrn-qa / code / memory.py
# Uploaded with huggingface_hub by will4381 (commit 3451ca0, verified).
# memory.py
from collections import deque
import torch
import torch.nn.functional as F
import config
class ActiveMemory:
    """
    An active memory module that stores and retrieves examples to enhance reasoning.
    Supports both logging for analysis and retrieval for improved predictions.

    All stored tensors are detached from the autograd graph and kept on CPU;
    they are moved to the query's device only at retrieval time, and only on
    copies, so the backing store never pins accelerator memory.
    """

    def __init__(self, max_size=config.MEMORY_MAX_SIZE, retrieval_k=config.MEMORY_RETRIEVAL_K):
        self.max_size = max_size
        self.retrieval_k = retrieval_k
        # deque(maxlen=...) silently evicts the oldest entry once full (FIFO).
        self.memory = deque(maxlen=max_size)
        self.device = config.DEVICE
        print(f"Initialized ActiveMemory with max size {self.max_size}, retrieval_k={self.retrieval_k}")

    @staticmethod
    def _to_cpu(tensor):
        """Detach *tensor* from the autograd graph and move it to CPU (None passes through)."""
        return None if tensor is None else tensor.detach().cpu()

    @staticmethod
    def _pool(hidden_states, attention_mask):
        """
        Mean-pool hidden states into a single summary vector.

        Uses attention-mask-weighted mean pooling over the sequence dimension
        when a mask is available; otherwise a plain mean.
        NOTE(review): the trailing squeeze(0) assumes batch size 1 — confirm
        callers never pass batched inputs here.
        """
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).float()
            summed = torch.sum(hidden_states * mask, dim=1)
            # Clamp avoids division by zero for an all-padding sequence.
            counts = torch.clamp(torch.sum(mask, dim=1), min=1e-9)
            return (summed / counts).squeeze(0)
        return hidden_states.mean(dim=1).squeeze(0)

    def add(self, input_data, hidden_states, output, reasoning_trace, final_hidden_states=None, final_output=None):
        """
        Adds a new entry to the memory.

        Args:
            input_data: The input to the model (tokenized IDs, attention masks, etc.)
            hidden_states (H0): Initial hidden states from the base model
            output (y0): Initial prediction from the model
            reasoning_trace (T): Reasoning trace (all hidden states)
            final_hidden_states (H1, optional): Final hidden states after retroactive update
            final_output (y1, optional): Final prediction after retroactive update
        """
        # Everything is stored detached and on CPU so the memory never keeps
        # autograd graphs alive or holds accelerator buffers.
        entry = {
            'input_ids': self._to_cpu(input_data.get('input_ids')),
            'attention_mask': self._to_cpu(input_data.get('attention_mask')),
            'token_type_ids': self._to_cpu(input_data.get('token_type_ids')),
            'hidden_states': self._to_cpu(hidden_states),
            'output': {k: self._to_cpu(v) for k, v in output.items()} if isinstance(output, dict) else self._to_cpu(output),
            'reasoning_trace': tuple(self._to_cpu(h) for h in reasoning_trace) if isinstance(reasoning_trace, tuple) else self._to_cpu(reasoning_trace),
        }
        # Final (post-retroactive-update) states are optional.
        if final_hidden_states is not None:
            entry['final_hidden_states'] = self._to_cpu(final_hidden_states)
        if final_output is not None:
            entry['final_output'] = {k: self._to_cpu(v) for k, v in final_output.items()} if isinstance(final_output, dict) else self._to_cpu(final_output)
        # Precompute a summary vector so retrieval-time similarity is cheap.
        if entry['hidden_states'] is not None and entry['attention_mask'] is not None:
            entry['summary_vector'] = self._pool(entry['hidden_states'], entry['attention_mask'])
        else:
            # Fallback: plain mean when no attention mask is available.
            entry['summary_vector'] = self._pool(entry['hidden_states'], None)
        self.memory.append(entry)

    def retrieve(self, query_hidden_states, query_attention_mask=None, k=None):
        """
        Retrieves the k most similar examples from memory based on hidden state similarity.

        Args:
            query_hidden_states: Hidden states to compare against memory
            query_attention_mask: Attention mask for the query
            k: Number of examples to retrieve (defaults to self.retrieval_k)

        Returns:
            List of retrieved memory entries, ordered by similarity (most similar
            first). Each returned entry is a shallow copy; its hidden-state
            tensors are moved to the query's device, while the stored entries
            stay untouched on CPU.
        """
        if len(self.memory) == 0:
            return []
        if k is None:
            k = self.retrieval_k
        k = min(k, len(self.memory))
        # Summarize the query the same way stored entries were summarized,
        # then move to CPU to compare against the CPU-resident memory.
        query_vector = self._pool(query_hidden_states, query_attention_mask).detach().cpu()
        # Rank every entry by cosine similarity of summary vectors.
        similarities = []
        for i, entry in enumerate(self.memory):
            sim = F.cosine_similarity(query_vector, entry['summary_vector'], dim=0)
            similarities.append((i, sim.item()))
        similarities.sort(key=lambda x: x[1], reverse=True)
        top_k_indices = [idx for idx, _ in similarities[:k]]
        # BUGFIX: return shallow copies with device-moved tensors. The old code
        # re-assigned entry['hidden_states'] = ...to(device) on the stored
        # entry itself, permanently migrating memory onto the query device and
        # defeating the CPU-offload design (accelerator memory grew with every
        # retrieval).
        device = query_hidden_states.device
        retrieved_entries = []
        for idx in top_k_indices:
            entry = dict(self.memory[idx])  # shallow copy; stored entry stays on CPU
            if entry.get('hidden_states') is not None:
                entry['hidden_states'] = entry['hidden_states'].to(device)
            if entry.get('final_hidden_states') is not None:
                entry['final_hidden_states'] = entry['final_hidden_states'].to(device)
            retrieved_entries.append(entry)
        return retrieved_entries

    def get_memory_context(self, query_hidden_states, query_attention_mask=None):
        """
        Retrieves and processes memory entries to create a context tensor for the model.

        Args:
            query_hidden_states: Hidden states to compare against memory
            query_attention_mask: Attention mask for the query

        Returns:
            memory_context: Tensor of shape (batch_size, seq_len, hidden_dim)
                containing the averaged hidden states of the retrieved
                examples, or None if memory is empty
        """
        retrieved = self.retrieve(query_hidden_states, query_attention_mask)
        if not retrieved:
            return None
        batch_size, seq_len, hidden_dim = query_hidden_states.shape
        # Prefer the post-update (final) hidden states when an entry has them.
        memory_tensors = []
        for entry in retrieved:
            if entry.get('final_hidden_states') is not None:
                memory_tensors.append(entry['final_hidden_states'])
            elif entry.get('hidden_states') is not None:
                memory_tensors.append(entry['hidden_states'])
        if not memory_tensors:
            return None
        # Align each tensor's sequence length to the query (zero-pad or
        # truncate) so they can be stacked. BUGFIX: new_zeros keeps dtype,
        # device, and batch size in sync with each stored tensor — the old
        # code allocated float32 zeros with a hard-coded batch of 1, which
        # broke under mixed precision or batched stored states.
        padded_tensors = []
        for tensor in memory_tensors:
            cur_len = tensor.size(1)
            if cur_len < seq_len:
                padding = tensor.new_zeros(tensor.size(0), seq_len - cur_len, hidden_dim)
                padded_tensors.append(torch.cat([tensor, padding], dim=1))
            elif cur_len > seq_len:
                padded_tensors.append(tensor[:, :seq_len, :])
            else:
                padded_tensors.append(tensor)
        # Average the aligned tensors into a single context.
        memory_context = torch.stack(padded_tensors).mean(dim=0)
        # Broadcast a single-example context across the query batch.
        if memory_context.size(0) == 1 and batch_size > 1:
            memory_context = memory_context.expand(batch_size, -1, -1)
        return memory_context

    def clear(self):
        """Clears all entries from memory."""
        self.memory.clear()

    def __len__(self):
        return len(self.memory)