# memory.py
from collections import deque

import torch
import torch.nn.functional as F

import config


class ActiveMemory:
    """
    An active memory module that stores and retrieves examples to enhance reasoning.
    Supports both logging for analysis and retrieval for improved predictions.
    """

    def __init__(self, max_size=config.MEMORY_MAX_SIZE, retrieval_k=config.MEMORY_RETRIEVAL_K):
        self.max_size = max_size
        self.retrieval_k = retrieval_k
        self.memory = deque(maxlen=max_size)
        self.device = config.DEVICE
        print(f"Initialized ActiveMemory with max size {self.max_size}, retrieval_k={self.retrieval_k}")

    def add(self, input_data, hidden_states, output, reasoning_trace,
            final_hidden_states=None, final_output=None):
        """
        Adds a new entry to the memory.

        Args:
            input_data: The input to the model (tokenized IDs, attention masks, etc.)
            hidden_states (H0): Initial hidden states from the base model
            output (y0): Initial prediction from the model
            reasoning_trace (T): Reasoning trace (all hidden states)
            final_hidden_states (H1, optional): Final hidden states after retroactive update
            final_output (y1, optional): Final prediction after retroactive update
        """
        def to_cpu(tensor):
            return tensor.detach().cpu() if tensor is not None else None

        # Create a memory entry with detached tensors moved to CPU.
        entry = {
            'input_ids': to_cpu(input_data.get('input_ids')),
            'attention_mask': to_cpu(input_data.get('attention_mask')),
            'token_type_ids': to_cpu(input_data.get('token_type_ids')),
            'hidden_states': hidden_states.detach().cpu(),
            'output': ({k: v.detach().cpu() for k, v in output.items()}
                       if isinstance(output, dict) else output.detach().cpu()),
            'reasoning_trace': (tuple(h.detach().cpu() for h in reasoning_trace)
                                if isinstance(reasoning_trace, tuple)
                                else reasoning_trace.detach().cpu()),
        }

        # Add final states if provided.
        if final_hidden_states is not None:
            entry['final_hidden_states'] = final_hidden_states.detach().cpu()
        if final_output is not None:
            entry['final_output'] = ({k: v.detach().cpu() for k, v in final_output.items()}
                                     if isinstance(final_output, dict)
                                     else final_output.detach().cpu())

        # Compute and store a summary vector for efficient retrieval,
        # using mean pooling of the hidden states.
        if entry['attention_mask'] is not None:
            # Mean pooling weighted by the attention mask, so padding
            # positions do not contribute to the summary.
            mask = entry['attention_mask'].unsqueeze(-1).float()
            masked_embeddings = entry['hidden_states'] * mask
            sum_embeddings = torch.sum(masked_embeddings, dim=1)
            sum_mask = torch.clamp(torch.sum(mask, dim=1), min=1e-9)
            entry['summary_vector'] = (sum_embeddings / sum_mask).squeeze(0)
        else:
            # Fall back to a simple mean if no attention mask is available.
            entry['summary_vector'] = entry['hidden_states'].mean(dim=1).squeeze(0)

        self.memory.append(entry)

    def retrieve(self, query_hidden_states, query_attention_mask=None, k=None):
        """
        Retrieves the k most similar examples from memory based on hidden state similarity.
        Args:
            query_hidden_states: Hidden states to compare against memory
            query_attention_mask: Attention mask for the query
            k: Number of examples to retrieve (defaults to self.retrieval_k)

        Returns:
            List of retrieved memory entries, ordered by similarity (most similar first)
        """
        if len(self.memory) == 0:
            return []

        if k is None:
            k = self.retrieval_k
        k = min(k, len(self.memory))

        # Compute the query summary vector (mean pooling with attention mask).
        if query_attention_mask is not None:
            mask = query_attention_mask.unsqueeze(-1).float()
            masked_embeddings = query_hidden_states * mask
            sum_embeddings = torch.sum(masked_embeddings, dim=1)
            sum_mask = torch.clamp(torch.sum(mask, dim=1), min=1e-9)
            query_vector = (sum_embeddings / sum_mask).squeeze(0)
        else:
            query_vector = query_hidden_states.mean(dim=1).squeeze(0)

        # Move the query vector to CPU for comparison with the stored entries.
        query_vector = query_vector.detach().cpu()

        # Compute cosine similarity against every memory entry.
        similarities = []
        for i, entry in enumerate(self.memory):
            similarity = F.cosine_similarity(query_vector, entry['summary_vector'], dim=0)
            similarities.append((i, similarity.item()))

        # Sort by similarity (descending) and take the top k.
        similarities.sort(key=lambda x: x[1], reverse=True)
        top_k_indices = [idx for idx, _ in similarities[:k]]

        # Return shallow copies with the tensors we actually use moved to the
        # query's device. Copying avoids mutating the stored entries in place,
        # which would gradually migrate the memory onto the GPU and defeat
        # the CPU-side storage.
        device = query_hidden_states.device
        retrieved_entries = []
        for idx in top_k_indices:
            entry = dict(self.memory[idx])
            if entry.get('hidden_states') is not None:
                entry['hidden_states'] = entry['hidden_states'].to(device)
            if entry.get('final_hidden_states') is not None:
                entry['final_hidden_states'] = entry['final_hidden_states'].to(device)
            retrieved_entries.append(entry)

        return retrieved_entries

    def get_memory_context(self, query_hidden_states, query_attention_mask=None):
        """
        Retrieves and processes memory entries to create a context tensor for the model.
        Args:
            query_hidden_states: Hidden states to compare against memory
            query_attention_mask: Attention mask for the query

        Returns:
            memory_context: Tensor of shape (batch_size, seq_len, hidden_dim)
                containing processed memory information, or None if memory is empty
        """
        # Retrieve similar examples from memory.
        retrieved = self.retrieve(query_hidden_states, query_attention_mask)
        if not retrieved:
            return None

        # Use the device and shape of the query.
        device = query_hidden_states.device
        batch_size, seq_len, hidden_dim = query_hidden_states.shape

        # Process retrieved examples to create the memory context.
        # Strategy: average the hidden states of the retrieved examples,
        # preferring the final (post-update) states when available.
        memory_tensors = []
        for entry in retrieved:
            if entry.get('final_hidden_states') is not None:
                memory_tensors.append(entry['final_hidden_states'])
            elif entry.get('hidden_states') is not None:
                memory_tensors.append(entry['hidden_states'])

        if not memory_tensors:
            return None

        # Ensure all tensors share the query's sequence length by
        # zero-padding short ones and truncating long ones.
        padded_tensors = []
        for tensor in memory_tensors:
            if tensor.size(1) < seq_len:
                padding = torch.zeros(tensor.size(0), seq_len - tensor.size(1), hidden_dim,
                                      device=device, dtype=tensor.dtype)
                padded_tensors.append(torch.cat([tensor, padding], dim=1))
            elif tensor.size(1) > seq_len:
                padded_tensors.append(tensor[:, :seq_len, :])
            else:
                padded_tensors.append(tensor)

        # Stack and average across the retrieved examples.
        memory_context = torch.stack(padded_tensors).mean(dim=0)

        # Expand to match the query batch size if needed.
        if memory_context.size(0) == 1 and batch_size > 1:
            memory_context = memory_context.expand(batch_size, -1, -1)

        return memory_context

    def clear(self):
        """Clears all entries from memory."""
        self.memory.clear()

    def __len__(self):
        return len(self.memory)
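

if __name__ == "__main__":
    # Minimal smoke test / usage sketch for the add -> retrieve ->
    # get_memory_context flow. This assumes config defines MEMORY_MAX_SIZE,
    # MEMORY_RETRIEVAL_K, and DEVICE; the tensor shapes below
    # (batch=1, seq_len=8, hidden_dim=16) are illustrative only, as are the
    # dummy outputs and single-step reasoning traces.
    memory = ActiveMemory(max_size=4, retrieval_k=2)

    inputs = {
        'input_ids': torch.randint(0, 100, (1, 8)),
        'attention_mask': torch.ones(1, 8, dtype=torch.long),
    }
    for _ in range(3):
        hidden = torch.randn(1, 8, 16)
        memory.add(inputs, hidden, {'logits': torch.randn(1, 2)}, (hidden,))

    # Retrieve the entries most similar to a fresh query...
    query = torch.randn(1, 8, 16)
    retrieved = memory.retrieve(query, inputs['attention_mask'])
    print(f"Retrieved {len(retrieved)} of {len(memory)} stored entries")

    # ...and build the averaged memory context aligned to the query shape.
    context = memory.get_memory_context(query, inputs['attention_mask'])
    print("Context shape:", None if context is None else tuple(context.shape))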