File size: 8,912 Bytes
3451ca0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
# memory.py
from collections import deque
import torch
import torch.nn.functional as F
import config

class ActiveMemory:
    """Episodic memory of past model states for retrieval-augmented reasoning.

    Each stored entry keeps — detached and on CPU — the tokenized inputs, the
    initial hidden states (H0) and prediction (y0), the reasoning trace (T),
    and optionally the retroactively-updated states (H1, y1).  A mean-pooled
    summary vector per entry enables cheap cosine-similarity retrieval of the
    most relevant past examples.
    """

    def __init__(self, max_size=config.MEMORY_MAX_SIZE, retrieval_k=config.MEMORY_RETRIEVAL_K):
        self.max_size = max_size
        self.retrieval_k = retrieval_k
        # deque(maxlen=...) evicts the oldest entry automatically once full.
        self.memory = deque(maxlen=max_size)
        self.device = config.DEVICE
        print(f"Initialized ActiveMemory with max size {self.max_size}, retrieval_k={self.retrieval_k}")

    @staticmethod
    def _detach_cpu(tensor):
        """Return a detached CPU copy of *tensor*, or None if it is None."""
        # detach() first so the device transfer does not drag autograd state.
        return None if tensor is None else tensor.detach().cpu()

    @staticmethod
    def _mean_pool(hidden_states, attention_mask=None):
        """Mean-pool token embeddings into a single summary vector.

        Uses the attention mask to ignore padding tokens when available;
        falls back to a plain mean over the sequence dimension otherwise.
        Assumes hidden_states is (batch=1, seq_len, hidden_dim) — the
        trailing squeeze(0) drops the batch dim; TODO confirm with callers.
        """
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).float()
            summed = torch.sum(hidden_states * mask, dim=1)
            # Clamp avoids division by zero for an all-padding sequence.
            counts = torch.clamp(torch.sum(mask, dim=1), min=1e-9)
            return (summed / counts).squeeze(0)
        return hidden_states.mean(dim=1).squeeze(0)

    def add(self, input_data, hidden_states, output, reasoning_trace, final_hidden_states=None, final_output=None):
        """Add a new entry to the memory.

        All tensors are detached and moved to CPU so stored entries hold no
        autograd graph and no device memory.

        Args:
            input_data: Dict of model inputs (input_ids, attention_mask,
                token_type_ids); any key may be absent or None.
            hidden_states: Initial hidden states (H0) from the base model.
            output: Initial prediction (y0) — a tensor or a dict of tensors.
            reasoning_trace: Reasoning trace (T) — a tuple of hidden-state
                tensors, or a single tensor.
            final_hidden_states: Optional final hidden states (H1) after the
                retroactive update.
            final_output: Optional final prediction (y1) after the
                retroactive update — a tensor or a dict of tensors.
        """
        entry = {
            'input_ids': self._detach_cpu(input_data.get('input_ids')),
            'attention_mask': self._detach_cpu(input_data.get('attention_mask')),
            'token_type_ids': self._detach_cpu(input_data.get('token_type_ids')),
            'hidden_states': hidden_states.detach().cpu(),
            'output': ({k: v.detach().cpu() for k, v in output.items()}
                       if isinstance(output, dict) else output.detach().cpu()),
            'reasoning_trace': (tuple(h.detach().cpu() for h in reasoning_trace)
                                if isinstance(reasoning_trace, tuple)
                                else reasoning_trace.detach().cpu()),
        }

        # Optional retroactively-updated states.
        if final_hidden_states is not None:
            entry['final_hidden_states'] = final_hidden_states.detach().cpu()
        if final_output is not None:
            entry['final_output'] = ({k: v.detach().cpu() for k, v in final_output.items()}
                                     if isinstance(final_output, dict)
                                     else final_output.detach().cpu())

        # Precompute the summary vector once, at insertion time, so retrieval
        # never has to re-pool stored hidden states.
        entry['summary_vector'] = self._mean_pool(entry['hidden_states'],
                                                  entry['attention_mask'])

        self.memory.append(entry)

    def retrieve(self, query_hidden_states, query_attention_mask=None, k=None):
        """Retrieve the k most similar entries by summary-vector cosine similarity.

        Args:
            query_hidden_states: Hidden states to compare against memory.
            query_attention_mask: Optional attention mask for the query.
            k: Number of entries to retrieve (defaults to self.retrieval_k).

        Returns:
            List of entry dicts ordered most-similar first.  Each element is a
            shallow copy of the stored entry whose 'hidden_states' (and
            'final_hidden_states', if present) have been moved to the query's
            device.  Copying fixes a bug in the previous implementation, which
            moved the stored tensors to the device *in place* and thereby
            migrated the CPU-resident memory onto the GPU over time.
        """
        if not self.memory:
            return []

        k = self.retrieval_k if k is None else k
        k = min(k, len(self.memory))

        # Pool the query the same way entries were pooled at insertion.
        query_vector = self._mean_pool(query_hidden_states, query_attention_mask)
        query_vector = query_vector.detach().cpu()

        # Vectorized similarity against every stored summary vector, then
        # top-k — equivalent ranking to sorting all similarities descending.
        memory_matrix = torch.stack([e['summary_vector'] for e in self.memory])
        similarities = F.cosine_similarity(query_vector.unsqueeze(0), memory_matrix, dim=1)
        top_indices = torch.topk(similarities, k).indices.tolist()

        device = query_hidden_states.device
        retrieved = []
        for idx in top_indices:
            # Shallow copy: device-moved tensors go to the caller, the stored
            # entry stays untouched on CPU.
            entry = dict(self.memory[idx])
            entry['hidden_states'] = entry['hidden_states'].to(device)
            if 'final_hidden_states' in entry:
                entry['final_hidden_states'] = entry['final_hidden_states'].to(device)
            retrieved.append(entry)

        return retrieved

    def get_memory_context(self, query_hidden_states, query_attention_mask=None):
        """Build a single context tensor from retrieved memory entries.

        Retrieves similar entries, aligns their hidden states to the query's
        sequence length (right-pad with zeros or truncate), and averages them.

        Args:
            query_hidden_states: Hidden states to compare against memory,
                shape (batch_size, seq_len, hidden_dim).
            query_attention_mask: Optional attention mask for the query.

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_dim) with the
            averaged memory information, or None if memory is empty.
        """
        retrieved = self.retrieve(query_hidden_states, query_attention_mask)
        if not retrieved:
            return None

        batch_size, seq_len, hidden_dim = query_hidden_states.shape

        # Prefer the retroactively-updated states (H1) when an entry has them.
        memory_tensors = [
            entry['final_hidden_states']
            if entry.get('final_hidden_states') is not None
            else entry['hidden_states']
            for entry in retrieved
        ]
        if not memory_tensors:
            return None

        # Align every tensor to the query's sequence length.  F.pad pads the
        # sequence dim for any batch size (the previous zeros(1, ...) padding
        # hard-coded batch 1) and allocates on the tensor's own device.
        aligned = []
        for tensor in memory_tensors:
            if tensor.size(1) < seq_len:
                aligned.append(F.pad(tensor, (0, 0, 0, seq_len - tensor.size(1))))
            else:
                aligned.append(tensor[:, :seq_len, :])

        memory_context = torch.stack(aligned).mean(dim=0)

        # Broadcast a single-example context across the query batch if needed.
        if memory_context.size(0) == 1 and batch_size > 1:
            memory_context = memory_context.expand(batch_size, -1, -1)

        return memory_context

    def clear(self):
        """Clear all entries from memory."""
        self.memory.clear()

    def __len__(self):
        return len(self.memory)