|
|
|
|
|
from collections import deque
|
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
import config
|
|
|
|
|
|
class ActiveMemory:
    """
    An active memory module that stores and retrieves examples to enhance reasoning.

    Supports both logging for analysis and retrieval for improved predictions.
    Entries are detached from the autograd graph and kept on CPU; retrieval
    returns shallow copies whose tensors are moved to the query's device,
    leaving the stored entries themselves untouched.
    """

    def __init__(self, max_size=config.MEMORY_MAX_SIZE, retrieval_k=config.MEMORY_RETRIEVAL_K):
        """
        Args:
            max_size: Maximum number of entries; once full, the oldest is evicted.
            retrieval_k: Default number of neighbours returned by retrieve().
        """
        self.max_size = max_size
        self.retrieval_k = retrieval_k
        # deque(maxlen=...) gives O(1) FIFO eviction when capacity is reached.
        self.memory = deque(maxlen=max_size)
        self.device = config.DEVICE
        print(f"Initialized ActiveMemory with max size {self.max_size}, retrieval_k={self.retrieval_k}")

    @staticmethod
    def _mean_pool(hidden_states, attention_mask=None):
        """
        Collapse (batch, seq, dim) hidden states into a single summary vector.

        When an attention mask is provided, padding positions are excluded so
        they do not dilute the mean. The leading batch dim is squeezed away,
        which assumes batch size 1 — TODO confirm against callers.
        """
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).float()
            sum_embeddings = torch.sum(hidden_states * mask, dim=1)
            # Clamp guards against division by zero for an all-zero mask.
            sum_mask = torch.clamp(torch.sum(mask, dim=1), min=1e-9)
            return (sum_embeddings / sum_mask).squeeze(0)
        return hidden_states.mean(dim=1).squeeze(0)

    @staticmethod
    def _detach_cpu(value):
        """
        Detach *value* from the autograd graph and move it to CPU.

        Accepts a tensor, a dict of tensors, a tuple of tensors, or None
        (passed through) — the shapes `add()` receives for its arguments.
        """
        if value is None:
            return None
        if isinstance(value, dict):
            return {k: v.cpu().detach() for k, v in value.items()}
        if isinstance(value, tuple):
            return tuple(v.cpu().detach() for v in value)
        return value.cpu().detach()

    def add(self, input_data, hidden_states, output, reasoning_trace, final_hidden_states=None, final_output=None):
        """
        Adds a new entry to the memory.

        Args:
            input_data: The input to the model (tokenized IDs, attention masks, etc.)
            hidden_states (H0): Initial hidden states from the base model
            output (y0): Initial prediction from the model
            reasoning_trace (T): Reasoning trace (all hidden states)
            final_hidden_states (H1, optional): Final hidden states after retroactive update
            final_output (y1, optional): Final prediction after retroactive update
        """
        entry = {
            'input_ids': self._detach_cpu(input_data.get('input_ids', None)),
            'attention_mask': self._detach_cpu(input_data.get('attention_mask', None)),
            'token_type_ids': self._detach_cpu(input_data.get('token_type_ids', None)),
            'hidden_states': self._detach_cpu(hidden_states),
            'output': self._detach_cpu(output),
            'reasoning_trace': self._detach_cpu(reasoning_trace),
        }

        if final_hidden_states is not None:
            entry['final_hidden_states'] = final_hidden_states.cpu().detach()
        if final_output is not None:
            entry['final_output'] = self._detach_cpu(final_output)

        # Precompute the summary vector once at insertion so retrieval only
        # needs a cosine similarity per stored entry.
        entry['summary_vector'] = self._mean_pool(entry['hidden_states'], entry['attention_mask'])

        self.memory.append(entry)

    def retrieve(self, query_hidden_states, query_attention_mask=None, k=None):
        """
        Retrieves the k most similar examples from memory based on hidden state similarity.

        Args:
            query_hidden_states: Hidden states to compare against memory
            query_attention_mask: Attention mask for the query
            k: Number of examples to retrieve (defaults to self.retrieval_k)

        Returns:
            List of retrieved entries (shallow copies), ordered by similarity
            (most similar first), with their hidden-state tensors moved to the
            query's device. Stored entries remain on CPU.
        """
        if len(self.memory) == 0:
            return []

        if k is None:
            k = self.retrieval_k
        k = min(k, len(self.memory))

        # Summarize the query exactly the way entries were summarized in add(),
        # then compare on CPU where the stored summary vectors live.
        query_vector = self._mean_pool(query_hidden_states, query_attention_mask)
        query_vector = query_vector.cpu().detach()

        similarities = [
            (i, F.cosine_similarity(query_vector, entry['summary_vector'], dim=0).item())
            for i, entry in enumerate(self.memory)
        ]
        similarities.sort(key=lambda x: x[1], reverse=True)
        top_k_indices = [idx for idx, _ in similarities[:k]]

        # BUG FIX: the original mutated stored entries in place, permanently
        # moving their tensors to the query device (defeating CPU storage and
        # accumulating device memory across retrievals). Return shallow copies
        # with device-resident tensors instead.
        device = query_hidden_states.device
        retrieved_entries = []
        for idx in top_k_indices:
            entry = dict(self.memory[idx])
            if entry.get('hidden_states') is not None:
                entry['hidden_states'] = entry['hidden_states'].to(device)
            if entry.get('final_hidden_states') is not None:
                entry['final_hidden_states'] = entry['final_hidden_states'].to(device)
            retrieved_entries.append(entry)

        return retrieved_entries

    def get_memory_context(self, query_hidden_states, query_attention_mask=None):
        """
        Retrieves and processes memory entries to create a context tensor for the model.

        Args:
            query_hidden_states: Hidden states to compare against memory
            query_attention_mask: Attention mask for the query

        Returns:
            memory_context: Tensor of shape (batch_size, seq_len, hidden_dim) containing
                processed memory information, or None if memory is empty
        """
        retrieved = self.retrieve(query_hidden_states, query_attention_mask)
        if not retrieved:
            return None

        device = query_hidden_states.device
        batch_size, seq_len, hidden_dim = query_hidden_states.shape

        # Prefer the post-update hidden states (H1) when present; fall back to H0.
        memory_tensors = []
        for entry in retrieved:
            if entry.get('final_hidden_states') is not None:
                memory_tensors.append(entry['final_hidden_states'])
            elif 'hidden_states' in entry:
                memory_tensors.append(entry['hidden_states'])

        if not memory_tensors:
            return None

        # Pad (with zeros) or truncate each stored sequence to the query's
        # seq_len so the tensors can be stacked.
        padded_tensors = []
        for tensor in memory_tensors:
            if tensor.size(1) < seq_len:
                # BUG FIX: pad with the tensor's own batch size rather than a
                # hard-coded 1, so multi-example entries don't break the cat.
                padding = torch.zeros(tensor.size(0), seq_len - tensor.size(1), hidden_dim, device=device)
                padded_tensors.append(torch.cat([tensor, padding], dim=1))
            elif tensor.size(1) > seq_len:
                padded_tensors.append(tensor[:, :seq_len, :])
            else:
                padded_tensors.append(tensor)

        # Average the retrieved examples into a single context tensor.
        memory_context = torch.stack(padded_tensors).mean(dim=0)

        # Broadcast a single-example context across the query batch
        # (expand returns a view — no copy).
        if memory_context.size(0) == 1 and batch_size > 1:
            memory_context = memory_context.expand(batch_size, -1, -1)

        return memory_context

    def clear(self):
        """Clears all entries from memory."""
        self.memory.clear()

    def __len__(self):
        """Number of entries currently stored."""
        return len(self.memory)