botverse committed on
Commit 8f20917 · Parent(s): 770285c

test is working

Files changed (1): models/inference_memory_wrapper.py (+219 -147)

models/inference_memory_wrapper.py CHANGED
@@ -4,43 +4,53 @@ import torch.nn.functional as F
  import math
  from transformers import LlamaForCausalLM, LlamaConfig, PreTrainedModel
  from transformers.modeling_outputs import CausalLMOutputWithPast
  from typing import Optional, List, Tuple, Union
  import os
  from pathlib import Path

- class InferenceMemoryWrapper(PreTrainedModel):
-     # Note: Inheriting PreTrainedModel helps with saving/loading config,
-     # but the core logic wraps an existing LlamaForCausalLM instance.
-     config_class = LlamaConfig # Use LlamaConfig
-
-     def __init__(self, llama_model: LlamaForCausalLM, memory_size: int = 512, num_retrieved: int = 1, update_alpha: float = 0.1, surprise_momentum: float = 0.9, surprise_lr: float = 0.01):
-         super().__init__(llama_model.config) # Initialize with the base model's config
-         self.llama = llama_model
-         self.memory_size = memory_size
-         self.num_retrieved = 1 # Using attention retrieval, effectively K=1 weighted sum
-         self.update_alpha = update_alpha # For EMA update (can be used as alternative)
-         self.dim = self.llama.config.hidden_size
-
-         # --- MODIFICATION: Memory buffer is a Parameter ---
-         self.memory_buffer = nn.Parameter(torch.zeros(memory_size, self.dim)) # (memory_size, C)
-         nn.init.normal_(self.memory_buffer, mean=0.0, std=1 / math.sqrt(self.dim))
-
-         # --- Surprise Update Parameters & State ---
          self.surprise_momentum_eta = surprise_momentum
          self.surprise_lr_theta = surprise_lr
-         self.register_buffer("surprise_state", torch.zeros_like(self.memory_buffer.data)) # (memory_size, C)
-
-         # --- Attention Retrieval Projection ---
-         # self.mac_query = nn.Linear(self.dim, self.dim, bias=False)
-         # Optional: Key/Value projections for memory buffer in attention
-         # self.mac_key = nn.Linear(self.dim, self.dim, bias=False)
-         # self.mac_value = nn.Linear(self.dim, self.dim, bias=False)

          # --- Freeze the underlying Llama model ---
          for param in self.llama.parameters():
              param.requires_grad = False
-         self.llama.eval() # Keep llama in eval mode permanently

      def get_input_embeddings(self):
          return self.llama.get_input_embeddings()

@@ -62,16 +72,19 @@ class InferenceMemoryWrapper(PreTrainedModel):
          Returns:
              torch.Tensor: Retrieved memory embedding (weighted sum). Shape (B, 1, C)
          """
-         # q = self.mac_query(query_input) # (B, C)
-         q = query_input # Use the input directly as the query (B, C)

          # Use memory_buffer directly as keys and values
          mem_keys = self.memory_buffer # (memory_size, C)
          mem_values = self.memory_buffer # (memory_size, C)

          attn_scores = torch.matmul(q, mem_keys.T) / math.sqrt(self.dim) # (B, memory_size)
          attn_weights = torch.softmax(attn_scores, dim=-1) # (B, memory_size)

          retrieved_mem = torch.matmul(attn_weights, mem_values) # (B, C)

          return retrieved_mem.unsqueeze(1) # (B, 1, C)

@@ -81,22 +94,16 @@ class InferenceMemoryWrapper(PreTrainedModel):
      def apply_surprise_update(self):
          """ Applies the TITANS-style surprise update rule using self.memory_buffer.grad """
          if self.memory_buffer.grad is None:
-             # This might happen in the first step or if the loss was zero
-             # print("Warning: apply_surprise_update called but memory_buffer has no gradient.")
              return

-         # Ensure surprise_state is on the same device
-         self.surprise_state = self.surprise_state.to(self.memory_buffer.device)

-         # S_t = η * S_{t-1} - θ * ∇_M L_assoc
-         # Note the minus sign for the gradient descent direction w.r.t. the loss
          surprise_update_val = -self.surprise_lr_theta * self.memory_buffer.grad.data
          self.surprise_state.mul_(self.surprise_momentum_eta).add_(surprise_update_val)

-         # M_t = M_{t-1} + S_t
          self.memory_buffer.data.add_(self.surprise_state)
-
-         # Zero the gradient *after* using it
          self.memory_buffer.grad.zero_()
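The removed comments above spell out the rule that apply_surprise_update implements: S_t = η·S_{t-1} − θ·∇_M L_assoc, then M_t = M_{t-1} + S_t, which is simply momentum SGD applied to the memory matrix. A minimal standalone sketch (not part of this file; names and sizes are illustrative):

# Standalone sketch: the surprise update is momentum SGD on the memory buffer.
import torch

eta, theta = 0.9, 0.01                       # surprise_momentum_eta, surprise_lr_theta
M = torch.zeros(4, 8, requires_grad=True)    # toy memory buffer (memory_size=4, dim=8)
S = torch.zeros_like(M)                      # surprise state

loss = ((M @ torch.randn(8, 1)) ** 2).mean() # any scalar loss that involves M
loss.backward()                              # fills M.grad

with torch.no_grad():
    S = eta * S - theta * M.grad             # S_t = η·S_{t-1} − θ·∇_M L
    M += S                                   # M_t = M_{t-1} + S_t
    M.grad.zero_()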
 
@@ -104,158 +111,226 @@ class InferenceMemoryWrapper(PreTrainedModel):
      @torch.no_grad()
      def update_memory_ema(self, new_context_embedding: torch.Tensor):
          """ Updates the memory buffer using EMA. """
-         if new_context_embedding.shape[0] > 1:
-             update_vec = new_context_embedding.mean(dim=0, keepdim=True) # (1, C)
-         else:
-             update_vec = new_context_embedding # (1, C)

          self.memory_buffer.data = self.memory_buffer.data.to(update_vec.device)
-         # Simple EMA on the whole buffer - might be better to replace slots
          self.memory_buffer.data.mul_(1 - self.update_alpha).add_(update_vec * self.update_alpha)

-     # --- Forward Pass (Pass-through) ---
-     def forward(self, *args, **kwargs):
-         # If used directly, ensure gradients can flow to memory_buffer if needed.
-         # For generate, we handle it explicitly.
-         return self.llama(*args, **kwargs)
-

      # --- MODIFIED Generate Method with Inline Backward Pass ---
      def generate(
          self,
          input_ids: torch.LongTensor,
          max_new_tokens: int = 20,
          num_beams: int = 1,
          use_memory: bool = True,
-         update_rule: str = 'ema', # Default to EMA for endpoints
          temperature: float = 0.7,
          top_p: float = 0.95,
          do_sample: bool = True,
          repetition_penalty: float = 1.0,
          eos_token_id: Optional[int] = None,
          pad_token_id: Optional[int] = None,
          **kwargs,
      ) -> torch.LongTensor:
-         """
-         Custom generate method incorporating memory retrieval and a potential INFERENCE-TIME update.
-         If update_rule='surprise', performs a backward pass and memory update in each step.
-         WARNING: Computationally expensive and experimental. KV caching is disabled for simplicity.
-         """
          if num_beams != 1:
              raise NotImplementedError("Beam search not implemented.")
          if update_rule == 'surprise' and not use_memory:
              print("Warning: update_rule='surprise' requires use_memory=True.")
              update_rule = 'none'

-         # No torch.no_grad() context here.
-
-         # self.train() if update_rule == 'surprise' else self.eval() # Llama is always eval, memory_buffer always requires grad
-         # train() would only be needed if other components (like potential future query layers) needed it.
-         # Since only memory_buffer needs grads, this line can be removed,
-         # or self.train() can be kept to be explicit that *something* might need grads.
          if update_rule == 'surprise':
-             # Ensure memory_buffer is treated as needing grads while other parts are frozen.
-             # This doesn't strictly change requires_grad, but it is good practice.
              self.memory_buffer.requires_grad_(True)

          bsz, seq_len_start = input_ids.shape
          device = input_ids.device
          generated_ids = input_ids.clone()
          current_seq_len = seq_len_start

          if eos_token_id is None: eos_token_id = self.config.eos_token_id
          if pad_token_id is None: pad_token_id = self.config.pad_token_id

-         for step in range(max_new_tokens):
-             # --- Prepare Inputs ---
-             current_input_ids = generated_ids
-
-             # 1. Embeddings
-             inputs_embeds = self.llama.model.embed_tokens(current_input_ids) # (B, T_cur, C)
-
-             # --- Memory Query Input ---
-             # Use the hidden state of the last token as the query basis.
-             # Getting that would require a preliminary forward pass, so for simplicity
-             # use the embedding of the last token directly.
-             query_basis = inputs_embeds[:, -1, :] # (B, C)

-             # 2. Retrieve Memory (differentiable if surprise update)
              retrieved_mem = None
              if use_memory:
                  retrieved_mem = self.retrieve_memory(query_basis) # (B, 1, C)

-             # 3. Prepend Memory
              if retrieved_mem is not None:
-                 combined_embeds = torch.cat([retrieved_mem, inputs_embeds], dim=1) # (B, 1 + T_cur, C)
-                 mem_len = retrieved_mem.shape[1] # Should be 1
-             else:
-                 combined_embeds = inputs_embeds # (B, T_cur, C)
-                 mem_len = 0

-             # 4. Position IDs and Attention Mask
-             combined_seq_len = combined_embeds.shape[1]
-             position_ids = torch.arange(combined_seq_len, device=device).unsqueeze(0).expand(bsz, -1)
-             attention_mask = torch.ones_like(position_ids) # Let Llama handle the causal mask
-
-             # --- Llama Forward Pass (with gradients enabled if surprise) ---
              outputs = self.llama(
-                 inputs_embeds=combined_embeds,
-                 attention_mask=attention_mask, # Or None
-                 position_ids=position_ids,
-                 past_key_values=None, # Disabled for simplicity
-                 use_cache=False, # Disabled for simplicity
                  output_hidden_states=True, # Needed for query/target/update
                  return_dict=True,
              )

              # --- Associative Loss Calculation (if surprise update) ---
              if update_rule == 'surprise' and use_memory and retrieved_mem is not None:
-                 # Target: final hidden state at the last *input* token position
-                 target_repr = outputs.hidden_states[-1][:, mem_len + current_seq_len - 1, :] # (B, C)
-                 # Prediction: the memory retrieved from the query_basis
                  pred_repr = retrieved_mem.squeeze(1) # (B, C)

-                 # Calculate MSE loss
-                 assoc_loss = F.mse_loss(pred_repr, target_repr.detach()) # Detach target!

-                 # --- Backward Pass & Update ---
-                 # Zero the previous gradient for the memory buffer
                  if self.memory_buffer.grad is not None:
                      self.memory_buffer.grad.zero_()

-                 # Compute the gradient of assoc_loss w.r.t. memory_buffer.
-                 # No need to retain the graph; this loss is self-contained.
-                 assoc_loss.backward()

-                 # Apply the surprise update rule
-                 self.apply_surprise_update() # Uses .grad and zeros it

-             # --- Standard Generation Logic ---
-             # 5. Get logits for the next token of the *original* sequence part
-             next_token_logits = outputs.logits[:, mem_len + current_seq_len - 1, :] # (B, V)

-             # 6. Sampling (apply penalties, temperature, top-p, sample)
-             # Apply repetition penalty
              if repetition_penalty != 1.0:
-                 # Create penalty mask efficiently
-                 penalties = torch.ones_like(next_token_logits)
-                 prev_output_tokens = generated_ids # (B, T_cur)
-                 # Expand generated_ids to match logits shape for scatter_
-                 expanded_prev_tokens = prev_output_tokens.unsqueeze(-1).expand(-1, -1, next_token_logits.size(-1)) # (B, T_cur, V)
-                 # Use scatter_ to apply the penalty only to previously generated tokens
-                 penalties.scatter_add_(1, prev_output_tokens, torch.full_like(prev_output_tokens, repetition_penalty - 1, dtype=penalties.dtype))
-                 next_token_logits /= penalties[:, -1, :] # Apply penalty based on last step's view? Needs care.
-                 # Simpler loop version (slower):
-                 # for i in range(bsz):
-                 #     for token_id in generated_ids[i]:
-                 #         next_token_logits[i, token_id] /= repetition_penalty
-
-             # Apply temperature
              if temperature > 0 and temperature != 1.0:
                  next_token_logits = next_token_logits / temperature
-             # Apply top-p
              if do_sample and top_p < 1.0:
                  sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                  cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                  sorted_indices_to_remove = cumulative_probs > top_p
@@ -263,34 +338,37 @@ class InferenceMemoryWrapper(PreTrainedModel):
                  sorted_indices_to_remove[..., 0] = 0
                  indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                  next_token_logits = next_token_logits.masked_fill(indices_to_remove, float('-inf'))
-             # Sample next token
              if do_sample:
                  probs = F.softmax(next_token_logits, dim=-1)
-                 next_token = torch.multinomial(probs, num_samples=1) # (B, 1)
              else:
-                 next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) # (B, 1)
-

              # --- Update State ---
              generated_ids = torch.cat([generated_ids, next_token], dim=1)
              current_seq_len += 1

-             # --- EMA Memory Update (if selected) ---
              if update_rule == 'ema' and use_memory and outputs.hidden_states is not None:
-                 # Use the hidden state corresponding to the newly generated token
-                 new_context_state = outputs.hidden_states[-1][:, mem_len + current_seq_len - 1, :] # (B, C)
-                 self.update_memory_ema(new_context_state.detach()) # Detach as EMA doesn't need grads

-             # Check stopping conditions
              if eos_token_id is not None and (next_token == eos_token_id).all():
                  break

-         # Restore eval mode if necessary
-         self.eval()

          return generated_ids

-     # --- Add Save/Load for Wrapper State ---
      def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
          """ Saves the wrapper's specific state (memory buffer, surprise state). """
          save_directory = Path(save_directory)
@@ -299,23 +377,17 @@ class InferenceMemoryWrapper(PreTrainedModel):
          # Save the base model's config (important for PreTrainedModel compatibility)
          self.config.save_pretrained(save_directory)

-         # Save the memory buffer and surprise state
-         torch.save(self.memory_buffer.state_dict(), save_directory / "memory_buffer.pt")
-         torch.save(self.surprise_state, save_directory / "surprise_state.pt") # Save buffer directly

          print(f"InferenceMemoryWrapper state saved to {save_directory}")
          # Note: Base Llama model weights are assumed to be saved separately or loaded from source.

-     @classmethod
-     def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs):
-         """ Loads the base model and the wrapper state. """
-         # TODO: This needs careful implementation.
-         # 1. Load the base Llama model (e.g., AutoModelForCausalLM.from_pretrained(...))
-         # 2. Load the config for the wrapper
-         # 3. Initialize the wrapper with the base model
-         # 4. Load the memory_buffer.pt and surprise_state.pt into the wrapper instance
-         raise NotImplementedError("Custom from_pretrained needs implementation for Inference Endpoints.")
-         # For handler.py, we will load manually instead of relying on this classmethod.
-
-     # Need to implement save/load methods if inheriting PreTrainedModel
-     # or provide a way to save/load the wrapper + base model + memory buffer.

  import math
  from transformers import LlamaForCausalLM, LlamaConfig, PreTrainedModel
  from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.cache_utils import Cache
  from typing import Optional, List, Tuple, Union
  import os
  from pathlib import Path

+ # Use the actual LlamaForCausalLM from the packaged 'models' dir if needed,
+ # but relying on the globally installed transformers version is usually fine.
+ # from .hf_llama.modeling_llama import LlamaForCausalLM, LlamaConfig

+ class InferenceMemoryWrapper(PreTrainedModel):
+     # config_class = LlamaConfig # Keep if needed for saving config

+     # --- REVERTED __init__ signature ---
+     def __init__(self, llama_model: LlamaForCausalLM, memory_size: int = 4096, num_retrieved: int = 1, update_alpha: float = 0.1, surprise_momentum: float = 0.9, surprise_lr: float = 0.01):
+         super().__init__(llama_model.config) # Use config from the passed model
+         self.llama = llama_model # Store the pre-loaded model

+         # --- Use passed parameters ---
+         self.memory_size = memory_size
+         self.num_retrieved = num_retrieved
+         self.update_alpha = update_alpha
          self.surprise_momentum_eta = surprise_momentum
          self.surprise_lr_theta = surprise_lr
+         self.dim = llama_model.config.hidden_size
+         self._target_dtype = llama_model.dtype # Get dtype from the base model (should be float16)
+
+         # --- Memory buffer is a Parameter ---
+         # Create tensor directly with correct dtype on CPU initially
+         init_buffer_data = torch.zeros(self.memory_size, self.dim, dtype=self._target_dtype)
+         # Initialize in place
+         nn.init.normal_(init_buffer_data, mean=0.0, std=1 / math.sqrt(self.dim))
+         # Wrap in Parameter (Parameter itself doesn't change dtype)
+         self.memory_buffer = nn.Parameter(init_buffer_data)
+
+         # --- Surprise Update State ---
+         # Create tensor directly with correct dtype on CPU initially
+         init_surprise_state = torch.zeros_like(self.memory_buffer.data, dtype=self._target_dtype) # Use buffer's shape/dtype
+         self.register_buffer("surprise_state", init_surprise_state)

          # --- Freeze the underlying Llama model ---
          for param in self.llama.parameters():
              param.requires_grad = False
+         self.llama.eval() # Keep llama in eval mode

+     # --- Keep existing methods (get_input_embeddings, set_input_embeddings, etc.) ---
      def get_input_embeddings(self):
          return self.llama.get_input_embeddings()
 
          Returns:
              torch.Tensor: Retrieved memory embedding (weighted sum). Shape (B, 1, C)
          """
+         # Ensure query is the correct dtype (should match memory buffer)
+         q = query_input.to(self.memory_buffer.dtype) # Still check against buffer's actual dtype

          # Use memory_buffer directly as keys and values
+         # self.memory_buffer should now consistently be self._target_dtype (float16)
          mem_keys = self.memory_buffer # (memory_size, C)
          mem_values = self.memory_buffer # (memory_size, C)

+         # Matmul should now work as dtypes match
          attn_scores = torch.matmul(q, mem_keys.T) / math.sqrt(self.dim) # (B, memory_size)
          attn_weights = torch.softmax(attn_scores, dim=-1) # (B, memory_size)
+
+         # Ensure retrieved mem is also the correct dtype before returning
          retrieved_mem = torch.matmul(attn_weights, mem_values) # (B, C)

          return retrieved_mem.unsqueeze(1) # (B, 1, C)
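The .to(dtype) casts added above exist because torch.matmul requires both operands to share a dtype, and the buffer is now created in the base model's dtype (typically float16). A standalone sketch of the mismatch and the fix (illustrative tensors, not from this file):

# Standalone sketch: why the query must be cast to the buffer's dtype.
import torch

mem = torch.randn(16, 8, dtype=torch.float16)  # memory buffer kept in the base model's dtype
q = torch.randn(2, 8)                          # a float32 query, e.g. from an fp32 embedding

# torch.matmul(q, mem.T) would typically raise a dtype-mismatch RuntimeError here,
# so the query is cast to the buffer's dtype first:
scores = torch.matmul(q.to(mem.dtype), mem.T) / (8 ** 0.5)  # (2, 16), float16
weights = torch.softmax(scores, dim=-1)
retrieved = weights @ mem                                   # (2, 8), float16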
 
      def apply_surprise_update(self):
          """ Applies the TITANS-style surprise update rule using self.memory_buffer.grad """
          if self.memory_buffer.grad is None:
              return

+         # Ensure surprise_state is on the same device and dtype
+         self.surprise_state = self.surprise_state.to(device=self.memory_buffer.device, dtype=self.memory_buffer.dtype)

+         # Grad should have the same dtype as the parameter
          surprise_update_val = -self.surprise_lr_theta * self.memory_buffer.grad.data
          self.surprise_state.mul_(self.surprise_momentum_eta).add_(surprise_update_val)

          self.memory_buffer.data.add_(self.surprise_state)
          self.memory_buffer.grad.zero_()

      @torch.no_grad()
      def update_memory_ema(self, new_context_embedding: torch.Tensor):
          """ Updates the memory buffer using EMA. """
+         # Ensure update vector is the correct dtype
+         update_vec_float = new_context_embedding.mean(dim=0, keepdim=True) if new_context_embedding.shape[0] > 1 else new_context_embedding # (1, C)
+         update_vec = update_vec_float.to(self.memory_buffer.dtype)

+         # Ensure buffer is on the correct device before update
          self.memory_buffer.data = self.memory_buffer.data.to(update_vec.device)
          self.memory_buffer.data.mul_(1 - self.update_alpha).add_(update_vec * self.update_alpha)
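update_memory_ema blends every memory slot toward the new context vector: buffer ← (1 − α)·buffer + α·update_vec. A one-step numeric sketch with illustrative values (not from this file):

# Standalone sketch: one EMA step with update_alpha = 0.1.
import torch

alpha = 0.1
buffer = torch.full((4, 8), 1.0)      # toy memory buffer
update_vec = torch.full((1, 8), 2.0)  # new context embedding, broadcast over all slots

buffer.mul_(1 - alpha).add_(update_vec * alpha)
# every entry moves from 1.0 to 0.9 * 1.0 + 0.1 * 2.0 = 1.1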
 
+     # --- Forward Pass (Pass-through to Llama) ---
+     # Overriding forward is needed if we want AutoModelForCausalLM(wrapper) to work directly.
+     # This now needs to call self.llama.forward.
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs, # Pass any extra kwargs
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         # Directly call the wrapped llama model's forward pass.
+         # Note: This basic forward doesn't include the memory prepending logic.
+         # That logic is currently only in the custom generate method.
+         # If you wanted to use model(input_ids) directly *with* memory,
+         # you'd need to replicate the generate logic here.
+         return self.llama(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             labels=labels,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             cache_position=cache_position,
+             **kwargs,
+         )
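Because this forward simply delegates to the wrapped model, the wrapper can be called like a plain LlamaForCausalLM when memory is not needed. A hedged usage sketch — `wrapper` is hypothetical, an InferenceMemoryWrapper built around an already loaded Llama model:

# Standalone sketch: the pass-through forward behaves like the underlying model (no memory).
import torch

input_ids = torch.randint(0, 100, (1, 12))  # dummy token ids
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    out = wrapper(input_ids=input_ids, attention_mask=attention_mask)

print(out.logits.shape)  # (1, 12, vocab_size), same as calling wrapper.llama(...) directly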
 
      # --- MODIFIED Generate Method with Inline Backward Pass ---
+     # (Generate method remains largely the same as before, but ensure it uses self.llama correctly)
      def generate(
          self,
          input_ids: torch.LongTensor,
          max_new_tokens: int = 20,
          num_beams: int = 1,
          use_memory: bool = True,
+         update_rule: str = 'ema',
          temperature: float = 0.7,
          top_p: float = 0.95,
          do_sample: bool = True,
          repetition_penalty: float = 1.0,
          eos_token_id: Optional[int] = None,
          pad_token_id: Optional[int] = None,
+         attention_mask: Optional[torch.Tensor] = None, # Added attention_mask parameter
          **kwargs,
      ) -> torch.LongTensor:
          if num_beams != 1:
              raise NotImplementedError("Beam search not implemented.")
          if update_rule == 'surprise' and not use_memory:
              print("Warning: update_rule='surprise' requires use_memory=True.")
              update_rule = 'none'

          if update_rule == 'surprise':
              self.memory_buffer.requires_grad_(True)
+         else:
+             # Ensure no grads are computed if not needed.
+             # Note: the Llama part is already frozen and in eval mode.
+             pass # No specific action needed if not surprise

          bsz, seq_len_start = input_ids.shape
          device = input_ids.device
          generated_ids = input_ids.clone()
          current_seq_len = seq_len_start
+         # Determine the expected dtype from the buffer
+         expected_dtype = self.memory_buffer.dtype # Use actual buffer dtype

          if eos_token_id is None: eos_token_id = self.config.eos_token_id
          if pad_token_id is None: pad_token_id = self.config.pad_token_id

+         past_key_values = None # Initialize KV cache

+         # Prepare initial attention mask if not provided
+         if attention_mask is None:
+             attention_mask = torch.ones_like(input_ids)
+
+         for step in range(max_new_tokens):
+             # --- Prepare Inputs for this step ---
+             # Use only the last token for generation if the KV cache is active
+             if past_key_values is not None:
+                 current_input_ids = generated_ids[:, -1:]
+                 # We need the hidden state/embedding of the *previous* token to query memory.
+                 # Get the full embeddings first, then select the query basis.
+                 full_embeds = self.llama.model.embed_tokens(generated_ids) # (B, T_cur, C)
+                 # Ensure query_basis has the expected dtype
+                 query_basis = full_embeds[:, -1, :].to(expected_dtype) # Query based on the last token generated *before* this step
+             else:
+                 current_input_ids = generated_ids
+                 inputs_embeds_full = self.llama.model.embed_tokens(current_input_ids) # (B, T_cur, C)
+                 # Ensure query_basis has the expected dtype
+                 query_basis = inputs_embeds_full[:, -1, :].to(expected_dtype) # Query based on the last token of the input prompt

+             # --- Memory Retrieval ---
              retrieved_mem = None
              if use_memory:
+                 # query_basis should now match memory_buffer dtype
                  retrieved_mem = self.retrieve_memory(query_basis) # (B, 1, C)

+             # --- Combine Embeddings and Prepare Model Inputs ---
+             # Manage attention mask and position IDs carefully
+             current_mask = None
+             mem_len = 0
              if retrieved_mem is not None:
+                 retrieved_mem_casted = retrieved_mem.to(self.llama.dtype) # (B, 1, C_llama)
+                 mem_len = retrieved_mem_casted.shape[1] # Should be 1
+
+             if past_key_values is None: # First step
+                 inputs_embeds_full_casted = inputs_embeds_full.to(self.llama.dtype) # (B, T_cur, C_llama)
+                 if retrieved_mem is not None:
+                     model_inputs_embeds = torch.cat([retrieved_mem_casted, inputs_embeds_full_casted], dim=1) # (B, 1 + T_cur, C)
+                     # Create mask for memory + original input mask
+                     mem_mask = torch.ones((bsz, mem_len), dtype=attention_mask.dtype, device=device)
+                     current_mask = torch.cat([mem_mask, attention_mask], dim=1) # (B, 1 + T_cur)
+                 else:
+                     model_inputs_embeds = inputs_embeds_full_casted # (B, T_cur, C)
+                     current_mask = attention_mask # Use original mask
+
+                 effective_seq_len = model_inputs_embeds.shape[1]
+                 position_ids = torch.arange(effective_seq_len, device=device).unsqueeze(0) # (1, P+K+T)
+                 cur_input_ids_for_llama = None # Using embeds
+             else: # Subsequent steps with KV cache
+                 current_input_embeds = self.llama.model.embed_tokens(current_input_ids).to(self.llama.dtype) # (B, 1, C_llama)
+                 if retrieved_mem is not None:
+                     model_inputs_embeds = torch.cat([retrieved_mem_casted, current_input_embeds], dim=1) # (B, 1 + 1, C)
+                     # Mask for memory + current token
+                     current_mask = torch.ones((bsz, mem_len + 1), dtype=attention_mask.dtype, device=device) # (B, 1 + 1)
+                 else:
+                     model_inputs_embeds = current_input_embeds # (B, 1, C)
+                     # Mask for current token only
+                     current_mask = torch.ones((bsz, 1), dtype=attention_mask.dtype, device=device) # (B, 1)
+
+                 # Position IDs for the new token(s) relative to KV cache length + memory length.
+                 # LlamaModel._update_causal_mask and cache handling expect position_ids to reflect the absolute position;
+                 # cache_position (passed internally by generate if use_cache) handles this. We construct it manually here.
+                 # The position id for the *new token* is the current sequence length (including memory if prepended this step).
+                 past_len = past_key_values.get_seq_length() # Length stored in cache
+                 # The position_id should reflect where this new token/memory would sit in the *full* sequence if no cache were used.
+                 # Use current_seq_len derived from generated_ids, which doesn't include memory.
+                 position_ids = torch.tensor([[current_seq_len - 1 + i + mem_len for i in range(model_inputs_embeds.shape[1])]], device=device) # (1, M+1) or (1, 1)
+
+                 cur_input_ids_for_llama = None # Using embeds
+
+             # --- Llama Forward Pass ---
+             # Use KV caching if possible (update_rule != 'surprise').
+             # We need past_key_values AND no surprise update AND the base model must support caching.
+             use_kv_cache_this_step = past_key_values is not None and update_rule != 'surprise' and self.llama.config.use_cache

              outputs = self.llama(
+                 input_ids=cur_input_ids_for_llama, # None if using embeds
+                 inputs_embeds=model_inputs_embeds,
+                 attention_mask=current_mask, # Pass the correctly shaped mask for this step
+                 position_ids=position_ids, # Pass adjusted position IDs
+                 past_key_values=past_key_values,
+                 use_cache=use_kv_cache_this_step,
                  output_hidden_states=True, # Needed for query/target/update
                  return_dict=True,
              )

              # --- Associative Loss Calculation (if surprise update) ---
              if update_rule == 'surprise' and use_memory and retrieved_mem is not None:
+                 # Target: final hidden state corresponding to the *last input token* before generation.
+                 # The index needs to account for the prepended memory.
+                 # If mem_len=1, the target state corresponds to index -1 in the output sequence.
+                 target_repr = outputs.hidden_states[-1][:, -1, :].to(self.memory_buffer.dtype) # (B, C)
+
+                 # pred_repr comes from retrieve_memory, so it should already match the buffer dtype
                  pred_repr = retrieved_mem.squeeze(1) # (B, C)

+                 assoc_loss = F.mse_loss(pred_repr, target_repr.detach())

                  if self.memory_buffer.grad is not None:
                      self.memory_buffer.grad.zero_()
+                 assoc_loss.backward() # Compute grads for memory_buffer
+                 self.apply_surprise_update() # Apply update and zero grad

+             # --- Standard Generation Logic ---
+             # Get logits for the very last position in the output sequence (corresponds to the token we just fed in)
+             next_token_logits = outputs.logits[:, -1, :] # (B, V)

+             # Update KV cache for the next step
+             if use_kv_cache_this_step:
+                 # The past_key_values returned by Llama should account for the memory prepended in this step
+                 past_key_values = outputs.past_key_values

+             # Sampling (same as before)
              if repetition_penalty != 1.0:
+                 # Simple loop for now:
+                 for i in range(bsz):
+                     # Penalize tokens in the *generated* sequence (excluding prompt if needed).
+                     # Use generated_ids, which tracks the full sequence.
+                     for token_id in generated_ids[i]:
+                         # Avoid penalizing the pad token if present
+                         if token_id != pad_token_id:
+                             next_token_logits[i, token_id] /= repetition_penalty
+
              if temperature > 0 and temperature != 1.0:
                  next_token_logits = next_token_logits / temperature
              if do_sample and top_p < 1.0:
+                 # Use Hugging Face's top_p implementation detail
                  sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                  cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                  sorted_indices_to_remove = cumulative_probs > top_p
                  sorted_indices_to_remove[..., 0] = 0
                  indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                  next_token_logits = next_token_logits.masked_fill(indices_to_remove, float('-inf'))
+
              if do_sample:
                  probs = F.softmax(next_token_logits, dim=-1)
+                 next_token = torch.multinomial(probs, num_samples=1)
              else:
+                 next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

              # --- Update State ---
              generated_ids = torch.cat([generated_ids, next_token], dim=1)
              current_seq_len += 1
+             # Update the attention mask for the next iteration by appending 1
+             attention_mask = torch.cat([attention_mask, torch.ones((bsz, 1), dtype=attention_mask.dtype, device=device)], dim=1)

+             # --- EMA Memory Update ---
              if update_rule == 'ema' and use_memory and outputs.hidden_states is not None:
+                 # Use the hidden state corresponding to the newly generated token position (index -1).
+                 # Cast the state to the buffer dtype before the update.
+                 new_context_state = outputs.hidden_states[-1][:, -1, :].to(self.memory_buffer.dtype) # (B, C)
+                 self.update_memory_ema(new_context_state.detach())

              if eos_token_id is not None and (next_token == eos_token_id).all():
                  break

+         # self.eval() # Already in eval mode if llama is frozen

          return generated_ids
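In the cached branch of generate above, position_ids are built by hand from current_seq_len and mem_len rather than taken from cache_position. A purely illustrative sketch of what that formula evaluates to (numbers are made up, not from this file):

# Standalone sketch: the manual position_ids arithmetic used in the cached branch.
import torch

current_seq_len = 5      # tokens in generated_ids so far (prompt + generated)
mem_len = 1              # one prepended memory embedding this step
step_len = mem_len + 1   # memory slot + the single new token fed to the model

position_ids = torch.tensor(
    [[current_seq_len - 1 + i + mem_len for i in range(step_len)]]
)
print(position_ids)  # tensor([[5, 6]]) -> positions assigned to the memory slot and the new token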
 
+     # --- Save/Load ---
+     # Keep the save_pretrained as is, it saves wrapper specific state.
      def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
          """ Saves the wrapper's specific state (memory buffer, surprise state). """
          save_directory = Path(save_directory)

          # Save the base model's config (important for PreTrainedModel compatibility)
          self.config.save_pretrained(save_directory)

+         # Save the memory buffer parameter directly.
+         # Ensure saving in float32 for broader compatibility; it can be cast back on load.
+         # Note: Saving the Parameter itself, not just its .data
+         torch.save(self.memory_buffer.float(), save_directory / "memory_buffer.pt")
+         # Save the surprise state buffer directly
+         torch.save(self.surprise_state.float(), save_directory / "surprise_state.pt")

          print(f"InferenceMemoryWrapper state saved to {save_directory}")
          # Note: Base Llama model weights are assumed to be saved separately or loaded from source.

+     # from_pretrained is complex with wrappers. For local testing/handler, load manually.
+     # @classmethod
+     # def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs):
+     #     raise NotImplementedError(...)
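save_pretrained writes memory_buffer.pt and surprise_state.pt next to the config, and from_pretrained is intentionally left unimplemented, so a handler would restore the state manually. A sketch of one way to do that — the checkpoint name and directory are hypothetical placeholders, and the wrapper's memory_size must match the saved buffer:

# Standalone sketch: manually restoring the wrapper state saved by save_pretrained.
import torch
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("some-llama-checkpoint", torch_dtype=torch.float16)  # hypothetical
wrapper = InferenceMemoryWrapper(base)  # memory_size must match what was saved

state_dir = "saved_wrapper_state"  # hypothetical directory produced by save_pretrained
mem = torch.load(f"{state_dir}/memory_buffer.pt", map_location="cpu")
surprise = torch.load(f"{state_dir}/surprise_state.pt", map_location="cpu")

with torch.no_grad():
    # Cast the float32 tensors back to the wrapper's working dtype before copying
    wrapper.memory_buffer.data.copy_(mem.to(wrapper.memory_buffer.dtype))
    wrapper.surprise_state.copy_(surprise.to(wrapper.surprise_state.dtype))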