# memory.py
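"""
Active memory module.

Each stored entry keeps the initial hidden states (H0), the initial
prediction (y0), and the reasoning trace (T), plus the retroactively
updated hidden states (H1) and prediction (y1) when available. Retrieval
ranks entries by cosine similarity between mean-pooled summary vectors.
"""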
from collections import deque

import torch
import torch.nn.functional as F

import config
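
def masked_mean_pool(hidden_states, attention_mask=None):
    """
    Mean-pools hidden states over the sequence dimension, excluding padded
    positions when an attention mask is provided. This is the pooling that
    both ActiveMemory.add() and ActiveMemory.retrieve() use to build summary
    vectors; for batch size 1 it returns a tensor of shape (hidden_dim,).
    """
    if attention_mask is not None:
        mask = attention_mask.unsqueeze(-1).float()
        sum_embeddings = torch.sum(hidden_states * mask, dim=1)
        # Clamp to avoid division by zero for all-padding sequences
        sum_mask = torch.clamp(torch.sum(mask, dim=1), min=1e-9)
        return (sum_embeddings / sum_mask).squeeze(0)
    # Fall back to a plain mean when no attention mask is available
    return hidden_states.mean(dim=1).squeeze(0)
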
class ActiveMemory:
"""
An active memory module that stores and retrieves examples to enhance reasoning.
Supports both logging for analysis and retrieval for improved predictions.
"""
def __init__(self, max_size=config.MEMORY_MAX_SIZE, retrieval_k=config.MEMORY_RETRIEVAL_K):
self.max_size = max_size
self.retrieval_k = retrieval_k
self.memory = deque(maxlen=max_size)
self.device = config.DEVICE
print(f"Initialized ActiveMemory with max size {self.max_size}, retrieval_k={self.retrieval_k}")
def add(self, input_data, hidden_states, output, reasoning_trace, final_hidden_states=None, final_output=None):
"""
Adds a new entry to the memory.
Args:
input_data: The input to the model (tokenized IDs, attention masks, etc.)
hidden_states (H0): Initial hidden states from the base model
output (y0): Initial prediction from the model
reasoning_trace (T): Reasoning trace (all hidden states)
final_hidden_states (H1, optional): Final hidden states after retroactive update
final_output (y1, optional): Final prediction after retroactive update
"""
        # Create a memory entry with detached tensors moved to CPU
        def to_cpu(t):
            return t.detach().cpu() if t is not None else None
        entry = {
            'input_ids': to_cpu(input_data.get('input_ids')),
            'attention_mask': to_cpu(input_data.get('attention_mask')),
            'token_type_ids': to_cpu(input_data.get('token_type_ids')),
            'hidden_states': hidden_states.detach().cpu(),
            'output': ({k: v.detach().cpu() for k, v in output.items()}
                       if isinstance(output, dict) else output.detach().cpu()),
            'reasoning_trace': (tuple(h.detach().cpu() for h in reasoning_trace)
                                if isinstance(reasoning_trace, tuple)
                                else reasoning_trace.detach().cpu()),
        }
# Add final states if provided
if final_hidden_states is not None:
entry['final_hidden_states'] = final_hidden_states.cpu().detach()
if final_output is not None:
entry['final_output'] = {k: v.cpu().detach() for k, v in final_output.items()} if isinstance(final_output, dict) else final_output.cpu().detach()
        # Compute and store a summary vector for efficient retrieval:
        # mean pooling of the hidden states, masked when a mask is available
        entry['summary_vector'] = masked_mean_pool(entry['hidden_states'],
                                                   entry['attention_mask'])
        self.memory.append(entry)

def retrieve(self, query_hidden_states, query_attention_mask=None, k=None):
"""
Retrieves the k most similar examples from memory based on hidden state similarity.
Args:
query_hidden_states: Hidden states to compare against memory
query_attention_mask: Attention mask for the query
k: Number of examples to retrieve (defaults to self.retrieval_k)
Returns:
List of retrieved memory entries, ordered by similarity (most similar first)
"""
if len(self.memory) == 0:
return []
if k is None:
k = self.retrieval_k
k = min(k, len(self.memory))
        # Compute the query summary vector with the same pooling used in add(),
        # then move it to CPU for comparison with the CPU-resident memory
        query_vector = masked_mean_pool(query_hidden_states, query_attention_mask)
        query_vector = query_vector.detach().cpu()
        # Compute cosine similarities between the query and all stored
        # summary vectors in one batched call, then take the top k
        memory_vectors = torch.stack([entry['summary_vector'] for entry in self.memory])
        similarities = F.cosine_similarity(query_vector.unsqueeze(0), memory_vectors, dim=1)
        top_k_indices = torch.topk(similarities, k).indices.tolist()
        # Collect the top k entries, copying each dict before moving tensors:
        # moving in place would mutate the stored entries, so the memory would
        # silently accumulate device tensors across retrievals
        device = query_hidden_states.device
        retrieved_entries = []
        for idx in top_k_indices:
            entry = dict(self.memory[idx])  # shallow copy; stored entry stays on CPU
            # Only move the tensors we'll actually use (hidden_states and final_hidden_states)
            if 'hidden_states' in entry:
                entry['hidden_states'] = entry['hidden_states'].to(device)
            if 'final_hidden_states' in entry:
                entry['final_hidden_states'] = entry['final_hidden_states'].to(device)
            retrieved_entries.append(entry)
        return retrieved_entries

def get_memory_context(self, query_hidden_states, query_attention_mask=None):
"""
Retrieves and processes memory entries to create a context tensor for the model.
Args:
query_hidden_states: Hidden states to compare against memory
query_attention_mask: Attention mask for the query
Returns:
memory_context: Tensor of shape (batch_size, seq_len, hidden_dim) containing
processed memory information, or None if memory is empty
"""
# Retrieve similar examples from memory
retrieved = self.retrieve(query_hidden_states, query_attention_mask)
if not retrieved:
return None
# Use the device of the query
device = query_hidden_states.device
batch_size, seq_len, hidden_dim = query_hidden_states.shape
# Process retrieved examples to create memory context
# Strategy: Average the final hidden states of retrieved examples
memory_tensors = []
for entry in retrieved:
# Prefer final hidden states if available, otherwise use initial hidden states
if 'final_hidden_states' in entry and entry['final_hidden_states'] is not None:
memory_tensors.append(entry['final_hidden_states'])
elif 'hidden_states' in entry:
memory_tensors.append(entry['hidden_states'])
if not memory_tensors:
return None
# Average the memory tensors
# First ensure all tensors have the same sequence length by padding or truncating
padded_tensors = []
for tensor in memory_tensors:
if tensor.size(1) < seq_len:
                # Pad with zeros, matching the tensor's batch dim and dtype so
                # the cat/stack below don't fail or silently promote dtypes
                padding = torch.zeros(tensor.size(0), seq_len - tensor.size(1), hidden_dim,
                                      device=device, dtype=tensor.dtype)
                padded_tensor = torch.cat([tensor, padding], dim=1)
padded_tensors.append(padded_tensor)
elif tensor.size(1) > seq_len:
# Truncate
padded_tensors.append(tensor[:, :seq_len, :])
else:
padded_tensors.append(tensor)
# Stack and average
memory_context = torch.stack(padded_tensors).mean(dim=0)
# Expand to match batch size if needed
if memory_context.size(0) == 1 and batch_size > 1:
memory_context = memory_context.expand(batch_size, -1, -1)
        return memory_context

    def clear(self):
        """Clears all entries from memory."""
        self.memory.clear()

    def __len__(self):
        return len(self.memory)
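
# Minimal usage sketch, not part of the module's public surface. Assumes the
# imported config module defines MEMORY_MAX_SIZE, MEMORY_RETRIEVAL_K, and
# DEVICE; the tensor shapes below are illustrative dummies, not tied to any
# particular model.
if __name__ == "__main__":
    memory = ActiveMemory(max_size=8, retrieval_k=2)
    seq_len, hidden_dim = 10, 16
    for _ in range(4):
        input_data = {
            'input_ids': torch.randint(0, 100, (1, seq_len)),
            'attention_mask': torch.ones(1, seq_len, dtype=torch.long),
        }
        hidden_states = torch.randn(1, seq_len, hidden_dim)
        output = torch.randn(1, 2)      # stand-in for model logits (y0)
        trace = (hidden_states,)        # stand-in for the reasoning trace (T)
        memory.add(input_data, hidden_states, output, trace)
    query = torch.randn(1, seq_len, hidden_dim)
    mask = torch.ones(1, seq_len, dtype=torch.long)
    context = memory.get_memory_context(query, mask)
    print(f"memory entries: {len(memory)}")
    print(f"context shape: {None if context is None else tuple(context.shape)}")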