test is working
models/inference_memory_wrapper.py  CHANGED  (+219 -147)
@@ -4,43 +4,53 @@ import torch.nn.functional as F
 import math
 from transformers import LlamaForCausalLM, LlamaConfig, PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from typing import Optional, List, Tuple, Union
 import os
 from pathlib import Path

-
-
-
-    config_class = LlamaConfig # Use LlamaConfig

-
-
-        self.llama = llama_model
-        self.memory_size = memory_size
-        self.num_retrieved = 1 # Using attention retrieval, effectively K=1 weighted sum
-        self.update_alpha = update_alpha # For EMA update (can be used as alternative)
-        self.dim = self.llama.config.hidden_size

-
-
-

-        # ---
         self.surprise_momentum_eta = surprise_momentum
         self.surprise_lr_theta = surprise_lr
-        self.

-        # --- Attention Retrieval Projection ---
-        # self.mac_query = nn.Linear(self.dim, self.dim, bias=False)
-        # Optional: Key/Value projections for memory buffer in attention
-        # self.mac_key = nn.Linear(self.dim, self.dim, bias=False)
-        # self.mac_value = nn.Linear(self.dim, self.dim, bias=False)

         # --- Freeze the underlying Llama model ---
         for param in self.llama.parameters():
             param.requires_grad = False
-        self.llama.eval() # Keep llama in eval mode

     def get_input_embeddings(self):
         return self.llama.get_input_embeddings()
@@ -62,16 +72,19 @@ class InferenceMemoryWrapper(PreTrainedModel):
         Returns:
             torch.Tensor: Retrieved memory embedding (weighted sum). Shape (B, 1, C)
         """
-
-
-        q = query_input # Use the input directly as the query (B, C)

         # Use memory_buffer directly as keys and values
         mem_keys = self.memory_buffer # (memory_size, C)
         mem_values = self.memory_buffer # (memory_size, C)

         attn_scores = torch.matmul(q, mem_keys.T) / math.sqrt(self.dim) # (B, memory_size)
         attn_weights = torch.softmax(attn_scores, dim=-1) # (B, memory_size)
         retrieved_mem = torch.matmul(attn_weights, mem_values) # (B, C)

         return retrieved_mem.unsqueeze(1) # (B, 1, C)
@@ -81,22 +94,16 @@ class InferenceMemoryWrapper(PreTrainedModel):
     def apply_surprise_update(self):
         """ Applies the TITANS-style surprise update rule using self.memory_buffer.grad """
         if self.memory_buffer.grad is None:
-            # This might happen in the first step or if loss was zero
-            # print("Warning: apply_surprise_update called but memory_buffer has no gradient.")
             return

-        # Ensure surprise_state is on the same device
-        self.surprise_state = self.surprise_state.to(self.memory_buffer.device)

-        #
-        # Note the minus sign for gradient descent direction w.r.t the loss
         surprise_update_val = -self.surprise_lr_theta * self.memory_buffer.grad.data
         self.surprise_state.mul_(self.surprise_momentum_eta).add_(surprise_update_val)

-        # M_t = M_{t-1} + S_t
         self.memory_buffer.data.add_(self.surprise_state)
-
-        # Zero the gradient *after* using it
         self.memory_buffer.grad.zero_()
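Spelled out, the rule this method applies is the one the removed `# M_t = M_{t-1} + S_t` comment summarized: with eta = surprise_momentum, theta = surprise_lr, and grad the gradient of the associative loss with respect to the buffer,

    S_t = eta * S_{t-1} - theta * grad
    M_t = M_{t-1} + S_t

so `surprise_state` accumulates a momentum of negative gradients and the memory buffer moves by that accumulated step on each call.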
@@ -104,158 +111,226 @@ class InferenceMemoryWrapper(PreTrainedModel):
     @torch.no_grad()
     def update_memory_ema(self, new_context_embedding: torch.Tensor):
         """ Updates the memory buffer using EMA. """
-
-
-
-        update_vec = new_context_embedding # (1, C)

         self.memory_buffer.data = self.memory_buffer.data.to(update_vec.device)
-        # Simple EMA on the whole buffer - might be better to replace slots
         self.memory_buffer.data.mul_(1 - self.update_alpha).add_(update_vec * self.update_alpha)


-    # --- Forward Pass (Pass-through) ---
-
-
-
-
-

     # --- MODIFIED Generate Method with Inline Backward Pass ---
     def generate(
         self,
         input_ids: torch.LongTensor,
         max_new_tokens: int = 20,
         num_beams: int = 1,
         use_memory: bool = True,
-        update_rule: str = 'ema',
         temperature: float = 0.7,
         top_p: float = 0.95,
         do_sample: bool = True,
         repetition_penalty: float = 1.0,
         eos_token_id: Optional[int] = None,
         pad_token_id: Optional[int] = None,
         **kwargs,
     ) -> torch.LongTensor:
-        """
-        Custom generate method incorporating memory retrieval and potential INFERENCE-TIME update.
-        If update_rule='surprise', performs backward pass and memory update in each step.
-        WARNING: Computationally expensive and experimental. KV Caching is disabled for simplicity.
-        """
         if num_beams != 1:
             raise NotImplementedError("Beam search not implemented.")
         if update_rule == 'surprise' and not use_memory:
             print("Warning: update_rule='surprise' requires use_memory=True.")
             update_rule = 'none'

-        # No torch.no_grad() context here.
-
-        # self.train() if update_rule == 'surprise' else self.eval() # Llama is always eval, memory_buffer always requires grad
-        # Only need train() if other components (like potential future query layers) needed it.
-        # Since only memory_buffer needs grads, we can potentially remove this line
-        # or just call self.train() to be explicit that *something* might need grads.
         if update_rule == 'surprise':
-            # Ensure memory_buffer is treated as needing grads if other parts are frozen
-            # This doesn't strictly change requires_grad, but good practice.
             self.memory_buffer.requires_grad_(True)

         bsz, seq_len_start = input_ids.shape
         device = input_ids.device
         generated_ids = input_ids.clone()
         current_seq_len = seq_len_start

         if eos_token_id is None: eos_token_id = self.config.eos_token_id
         if pad_token_id is None: pad_token_id = self.config.pad_token_id

-
-            # --- Prepare Inputs ---
-            current_input_ids = generated_ids

-
-

-            # --- Memory Query Input ---
-            # Use the hidden state of the last token as the query basis
-            # To get this, we might need a preliminary forward pass or use embeddings directly
-            # Let's use the embedding of the last token for simplicity first
-            query_basis = inputs_embeds[:, -1, :] # (B, C)

-            #
             retrieved_mem = None
             if use_memory:
                 retrieved_mem = self.retrieve_memory(query_basis) # (B, 1, C)

-            #
             if retrieved_mem is not None:
-
-
-
-
-

-            # 4. Position IDs and Attention Mask
-            combined_seq_len = combined_embeds.shape[1]
-            position_ids = torch.arange(combined_seq_len, device=device).unsqueeze(0).expand(bsz, -1)
-            attention_mask = torch.ones_like(position_ids) # Let Llama handle causal mask
-
-            # --- Llama Forward Pass (With Gradients Enabled if surprise) ---
             outputs = self.llama(
-
-
-
-
-
                 output_hidden_states=True, # Needed for query/target/update
                 return_dict=True,
             )

             # --- Associative Loss Calculation (if surprise update) ---
             if update_rule == 'surprise' and use_memory and retrieved_mem is not None:
-                # Target: Final hidden state
-
-                #
                 pred_repr = retrieved_mem.squeeze(1) # (B, C)

-
-                assoc_loss = F.mse_loss(pred_repr, target_repr.detach()) # Detach target!

-                # --- Backward Pass & Update ---
-                # Zero previous gradient for memory buffer
                 if self.memory_buffer.grad is not None:
                     self.memory_buffer.grad.zero_()

-
-
-

-
-

-            # --- Standard Generation Logic ---
-            # 5. Get Logits for the *original* sequence part's next token
-            next_token_logits = outputs.logits[:, mem_len + current_seq_len - 1, :] # (B, V)

-            #
-            # Apply repetition penalty
             if repetition_penalty != 1.0:
-
-
-
-
-
-
-
-
-
-                # for i in range(bsz):
-                #     for token_id in generated_ids[i]:
-                #         next_token_logits[i, token_id] /= repetition_penalty
-
-            # Apply temperature
             if temperature > 0 and temperature != 1.0:
                 next_token_logits = next_token_logits / temperature
-            # Apply top-p
             if do_sample and top_p < 1.0:
                 sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                 cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                 sorted_indices_to_remove = cumulative_probs > top_p
@@ -263,34 +338,37 @@ class InferenceMemoryWrapper(PreTrainedModel):
                 sorted_indices_to_remove[..., 0] = 0
                 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                 next_token_logits = next_token_logits.masked_fill(indices_to_remove, float('-inf'))
-
             if do_sample:
                 probs = F.softmax(next_token_logits, dim=-1)
-                next_token = torch.multinomial(probs, num_samples=1)
             else:
-                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
-

             # --- Update State ---
             generated_ids = torch.cat([generated_ids, next_token], dim=1)
             current_seq_len += 1

-            # --- EMA Memory Update
             if update_rule == 'ema' and use_memory and outputs.hidden_states is not None:
-
-
-

-            # Check stopping conditions
             if eos_token_id is not None and (next_token == eos_token_id).all():
                 break

-        #
-        self.eval()

         return generated_ids

-
     def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
         """ Saves the wrapper's specific state (memory buffer, surprise state). """
         save_directory = Path(save_directory)
@@ -299,23 +377,17 @@ class InferenceMemoryWrapper(PreTrainedModel):
         # Save the base model's config (important for PreTrainedModel compatibility)
         self.config.save_pretrained(save_directory)

-        # Save the memory buffer
-
-

         print(f"InferenceMemoryWrapper state saved to {save_directory}")
         # Note: Base Llama model weights are assumed to be saved separately or loaded from source.

-
-
-
-
-        # 1. Load the base Llama model (e.g., AutoModelForCausalLM.from_pretrained(...))
-        # 2. Load the config for the wrapper
-        # 3. Initialize the wrapper with the base model
-        # 4. Load the memory_buffer.pt and surprise_state.pt into the wrapper instance
-        raise NotImplementedError("Custom from_pretrained needs implementation for Inference Endpoints.")
-        # For handler.py, we will load manually instead of relying on this classmethod.
-
-        # Need to implement save/load methods if inheriting PreTrainedModel
-        # or provide a way to save/load the wrapper + base model + memory buffer.
@@ -4,43 +4,53 @@ import torch.nn.functional as F
 import math
 from transformers import LlamaForCausalLM, LlamaConfig, PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.cache_utils import Cache
 from typing import Optional, List, Tuple, Union
 import os
 from pathlib import Path

+# Use the actual LlamaForCausalLM from the packaged 'models' dir if needed,
+# but relying on the globally installed transformers version is usually fine.
+# from .hf_llama.modeling_llama import LlamaForCausalLM, LlamaConfig

+class InferenceMemoryWrapper(PreTrainedModel):
+    # config_class = LlamaConfig # Keep if needed for saving config

+    # --- REVERTED __init__ signature ---
+    def __init__(self, llama_model: LlamaForCausalLM, memory_size: int = 4096, num_retrieved: int = 1, update_alpha: float = 0.1, surprise_momentum: float = 0.9, surprise_lr: float = 0.01):
+        super().__init__(llama_model.config) # Use config from the passed model
+        self.llama = llama_model # Store the pre-loaded model

+        # --- Use passed parameters ---
+        self.memory_size = memory_size
+        self.num_retrieved = num_retrieved
+        self.update_alpha = update_alpha
         self.surprise_momentum_eta = surprise_momentum
         self.surprise_lr_theta = surprise_lr
+        self.dim = llama_model.config.hidden_size
+        self._target_dtype = llama_model.dtype # Get dtype from the base model (should be float16)
+
+        # --- Memory buffer is a Parameter ---
+        # Create tensor directly with correct dtype on CPU initially
+        init_buffer_data = torch.zeros(self.memory_size, self.dim, dtype=self._target_dtype)
+        # Initialize in place
+        nn.init.normal_(init_buffer_data, mean=0.0, std=1 / math.sqrt(self.dim))
+        # Wrap in Parameter (Parameter itself doesn't change dtype)
+        self.memory_buffer = nn.Parameter(init_buffer_data)
+
+        # --- Surprise Update State ---
+        # Create tensor directly with correct dtype on CPU initially
+        init_surprise_state = torch.zeros_like(self.memory_buffer.data, dtype=self._target_dtype) # Use buffer's shape/dtype
+        self.register_buffer("surprise_state", init_surprise_state)

         # --- Freeze the underlying Llama model ---
         for param in self.llama.parameters():
             param.requires_grad = False
+        self.llama.eval() # Keep llama in eval mode

+    # --- Keep existing methods (get_input_embeddings, set_input_embeddings, etc.) ---
     def get_input_embeddings(self):
         return self.llama.get_input_embeddings()
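As a rough usage sketch (not part of this file, and assuming the models/ directory is importable as a package), the wrapper expects an already-loaded Llama model; the checkpoint path below is a placeholder:

    import torch
    from transformers import LlamaForCausalLM
    from models.inference_memory_wrapper import InferenceMemoryWrapper

    # Placeholder checkpoint; any weights loadable by LlamaForCausalLM will do.
    base = LlamaForCausalLM.from_pretrained("path/to/llama", torch_dtype=torch.float16)
    wrapper = InferenceMemoryWrapper(base, memory_size=4096, update_alpha=0.1,
                                     surprise_momentum=0.9, surprise_lr=0.01)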
@@ -62,16 +72,19 @@ class InferenceMemoryWrapper(PreTrainedModel):
         Returns:
             torch.Tensor: Retrieved memory embedding (weighted sum). Shape (B, 1, C)
         """
+        # Ensure query is the correct dtype (should match memory buffer)
+        q = query_input.to(self.memory_buffer.dtype) # Still check against buffer's actual dtype

         # Use memory_buffer directly as keys and values
+        # self.memory_buffer should now consistently be self._target_dtype (float16)
         mem_keys = self.memory_buffer # (memory_size, C)
         mem_values = self.memory_buffer # (memory_size, C)

+        # Matmul should now work as dtypes match
         attn_scores = torch.matmul(q, mem_keys.T) / math.sqrt(self.dim) # (B, memory_size)
         attn_weights = torch.softmax(attn_scores, dim=-1) # (B, memory_size)
+
+        # Ensure retrieved mem is also the correct dtype before returning
         retrieved_mem = torch.matmul(attn_weights, mem_values) # (B, C)

         return retrieved_mem.unsqueeze(1) # (B, 1, C)
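In other words, retrieval is single-head scaled dot-product attention over the memory slots, with the buffer acting as both keys and values:

    retrieved = softmax(q @ M.T / sqrt(dim)) @ M    # q: (B, C), M: (memory_size, C)

so each query gets back a convex combination of memory rows (a weighted sum) rather than a hard top-K lookup.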
@@ -81,22 +94,16 @@ class InferenceMemoryWrapper(PreTrainedModel):
     def apply_surprise_update(self):
         """ Applies the TITANS-style surprise update rule using self.memory_buffer.grad """
         if self.memory_buffer.grad is None:
             return

+        # Ensure surprise_state is on the same device and dtype
+        self.surprise_state = self.surprise_state.to(device=self.memory_buffer.device, dtype=self.memory_buffer.dtype)

+        # Grad should have the same dtype as the parameter
         surprise_update_val = -self.surprise_lr_theta * self.memory_buffer.grad.data
         self.surprise_state.mul_(self.surprise_momentum_eta).add_(surprise_update_val)

         self.memory_buffer.data.add_(self.surprise_state)
         self.memory_buffer.grad.zero_()
@@ -104,158 +111,226 @@ class InferenceMemoryWrapper(PreTrainedModel):
     @torch.no_grad()
     def update_memory_ema(self, new_context_embedding: torch.Tensor):
         """ Updates the memory buffer using EMA. """
+        # Ensure update vector is the correct dtype
+        update_vec_float = new_context_embedding.mean(dim=0, keepdim=True) if new_context_embedding.shape[0] > 1 else new_context_embedding # (1, C)
+        update_vec = update_vec_float.to(self.memory_buffer.dtype)

+        # Ensure buffer is on the correct device before update
         self.memory_buffer.data = self.memory_buffer.data.to(update_vec.device)
         self.memory_buffer.data.mul_(1 - self.update_alpha).add_(update_vec * self.update_alpha)


+    # --- Forward Pass (Pass-through to Llama) ---
+    # Overriding forward is needed if we want AutoModelForCausalLM(wrapper) to work directly
+    # This now needs to call self.llama.forward
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs, # Pass any extra kwargs
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        # Directly call the wrapped llama model's forward pass
+        # Note: This basic forward doesn't include the memory prepending logic.
+        # That logic is currently only in the custom generate method.
+        # If you wanted to use model(input_ids) directly *with* memory,
+        # you'd need to replicate the generate logic here.
+        return self.llama(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+            **kwargs,
+        )
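Because forward is a pure pass-through, calling the wrapper directly behaves exactly like calling the frozen Llama model; memory is only consulted inside the custom generate. A minimal sketch, with `wrapper` as constructed above and `ids` any (B, T) tensor of token ids:

    out = wrapper(input_ids=ids, return_dict=True)  # CausalLMOutputWithPast from the wrapped Llama
    logits = out.logits                             # (B, T, vocab_size); no memory slot is prepended here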
     # --- MODIFIED Generate Method with Inline Backward Pass ---
+    # (Generate method remains largely the same as before, but ensure it uses self.llama correctly)
     def generate(
         self,
         input_ids: torch.LongTensor,
         max_new_tokens: int = 20,
         num_beams: int = 1,
         use_memory: bool = True,
+        update_rule: str = 'ema',
         temperature: float = 0.7,
         top_p: float = 0.95,
         do_sample: bool = True,
         repetition_penalty: float = 1.0,
         eos_token_id: Optional[int] = None,
         pad_token_id: Optional[int] = None,
+        attention_mask: Optional[torch.Tensor] = None, # Added attention_mask parameter
         **kwargs,
     ) -> torch.LongTensor:
         if num_beams != 1:
             raise NotImplementedError("Beam search not implemented.")
         if update_rule == 'surprise' and not use_memory:
             print("Warning: update_rule='surprise' requires use_memory=True.")
             update_rule = 'none'

         if update_rule == 'surprise':
             self.memory_buffer.requires_grad_(True)
+        else:
+            # Ensure no grads are computed if not needed
+            # Note: Llama part is already frozen and in eval mode
+            pass # No specific action needed if not surprise

         bsz, seq_len_start = input_ids.shape
         device = input_ids.device
         generated_ids = input_ids.clone()
         current_seq_len = seq_len_start
+        # Determine the expected dtype from the buffer
+        expected_dtype = self.memory_buffer.dtype # Use actual buffer dtype

         if eos_token_id is None: eos_token_id = self.config.eos_token_id
         if pad_token_id is None: pad_token_id = self.config.pad_token_id

+        past_key_values = None # Initialize KV cache

+        # Prepare initial attention mask if provided
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+
+        for step in range(max_new_tokens):
+            # --- Prepare Inputs for this step ---
+            # Use only the last token for generation if KV cache is active
+            if past_key_values is not None:
+                current_input_ids = generated_ids[:, -1:]
+                # We need the hidden state/embedding of the *previous* token to query memory
+                # Let's get the full embeddings first, then select the query basis
+                # Use the full sequence length processed so far for embeddings
+                full_embeds = self.llama.model.embed_tokens(generated_ids) # (B, T_cur, C)
+                # Ensure query_basis has the expected dtype
+                query_basis = full_embeds[:, -1, :].to(expected_dtype) # Query based on the last token generated *before* this step
+            else:
+                current_input_ids = generated_ids
+                inputs_embeds_full = self.llama.model.embed_tokens(current_input_ids) # (B, T_cur, C)
+                # Ensure query_basis has the expected dtype
+                query_basis = inputs_embeds_full[:, -1, :].to(expected_dtype) # Query based on last token of the input prompt

+            # --- Memory Retrieval ---
             retrieved_mem = None
             if use_memory:
+                # query_basis should now match memory_buffer dtype
                 retrieved_mem = self.retrieve_memory(query_basis) # (B, 1, C)

+            # --- Combine Embeddings and Prepare Model Inputs ---
+            # Manage attention mask and position IDs carefully
+            current_mask = None
+            mem_len = 0
             if retrieved_mem is not None:
+                retrieved_mem_casted = retrieved_mem.to(self.llama.dtype) # (B, 1, C_llama)
+                mem_len = retrieved_mem_casted.shape[1] # Should be 1
+
+            if past_key_values is None: # First step
+                inputs_embeds_full_casted = inputs_embeds_full.to(self.llama.dtype) # (B, T_cur, C_llama)
+                if retrieved_mem is not None:
+                    model_inputs_embeds = torch.cat([retrieved_mem_casted, inputs_embeds_full_casted], dim=1) # (B, 1 + T_cur, C)
+                    # Create mask for memory + original input mask
+                    mem_mask = torch.ones((bsz, mem_len), dtype=attention_mask.dtype, device=device)
+                    current_mask = torch.cat([mem_mask, attention_mask], dim=1) # (B, 1 + T_cur)
+                else:
+                    model_inputs_embeds = inputs_embeds_full_casted # (B, T_cur, C)
+                    current_mask = attention_mask # Use original mask
+
+                effective_seq_len = model_inputs_embeds.shape[1]
+                position_ids = torch.arange(effective_seq_len, device=device).unsqueeze(0) # (1, P+K+T)
+                cur_input_ids_for_llama = None # Using embeds
+            else: # Subsequent steps with KV cache
+                current_input_embeds = self.llama.model.embed_tokens(current_input_ids).to(self.llama.dtype) # (B, 1, C_llama)
+                if retrieved_mem is not None:
+                    model_inputs_embeds = torch.cat([retrieved_mem_casted, current_input_embeds], dim=1) # (B, 1 + 1, C)
+                    # Mask for memory + current token
+                    current_mask = torch.ones((bsz, mem_len + 1), dtype=attention_mask.dtype, device=device) # (B, 1 + 1)
+                else:
+                    model_inputs_embeds = current_input_embeds # (B, 1, C)
+                    # Mask for current token only
+                    current_mask = torch.ones((bsz, 1), dtype=attention_mask.dtype, device=device) # (B, 1)
+
+                # Position ID for the new token(s) relative to KV cache length + memory length
+                # LlamaModel._update_causal_mask and cache handling expect position_ids to reflect the absolute position
+                # cache_position (passed internally by generate if use_cache) handles this. We construct it manually here.
+                # The position id for the *new token* is the current sequence length (including memory if prepended this step)
+                past_len = past_key_values.get_seq_length() # Length stored in cache
+                # The position_id should reflect where this new token/memory would be in the *full* sequence if no cache was used
+                # Let's use current_seq_len derived from generated_ids, which doesn't include memory
+                position_ids = torch.tensor([[current_seq_len - 1 + i + mem_len for i in range(model_inputs_embeds.shape[1])]], device=device) # (1, M+1) or (1, 1)

+                cur_input_ids_for_llama = None # Using embeds

+            # --- Llama Forward Pass ---
+            # Use KV caching if possible (update_rule != 'surprise')
+            # We need past_key_values AND not be doing surprise update AND base model supports caching
+            use_kv_cache_this_step = past_key_values is not None and update_rule != 'surprise' and self.llama.config.use_cache
             outputs = self.llama(
+                input_ids=cur_input_ids_for_llama, # None if using embeds
+                inputs_embeds=model_inputs_embeds,
+                attention_mask=current_mask, # Pass the correctly shaped mask for this step
+                position_ids=position_ids, # Pass adjusted position IDs
+                past_key_values=past_key_values,
+                use_cache=use_kv_cache_this_step,
                 output_hidden_states=True, # Needed for query/target/update
                 return_dict=True,
             )

             # --- Associative Loss Calculation (if surprise update) ---
             if update_rule == 'surprise' and use_memory and retrieved_mem is not None:
+                # Target: Final hidden state corresponding to the *last input token* before generation
+                # The index needs to account for the prepended memory.
+                # If mem_len=1, the target state corresponds to index -1 in the output sequence
+                target_repr = outputs.hidden_states[-1][:, -1, :].to(self.memory_buffer.dtype) # (B, C)
+
+                # pred_repr comes from retrieve_memory, should already match buffer dtype
                 pred_repr = retrieved_mem.squeeze(1) # (B, C)

+                assoc_loss = F.mse_loss(pred_repr, target_repr.detach())

                 if self.memory_buffer.grad is not None:
                     self.memory_buffer.grad.zero_()
+                assoc_loss.backward() # Compute grads for memory_buffer
+                self.apply_surprise_update() # Apply update and zero grad

+            # --- Standard Generation Logic ---
+            # Get logits for the very last position in the output sequence (corresponds to the token we just fed in)
+            next_token_logits = outputs.logits[:, -1, :] # (B, V)

+            # Update KV cache for next step
+            if use_kv_cache_this_step:
+                # The past_key_values returned by Llama should account for the memory prepended in this step
+                past_key_values = outputs.past_key_values

+            # Sampling (same as before)
             if repetition_penalty != 1.0:
+                # Simple loop for now:
+                for i in range(bsz):
+                    # Penalize tokens in the *generated* sequence (excluding prompt if needed)
+                    # Use generated_ids which tracks the full sequence
+                    for token_id in generated_ids[i]:
+                        # Avoid penalizing pad token if present
+                        if token_id != pad_token_id:
+                            next_token_logits[i, token_id] /= repetition_penalty
+
             if temperature > 0 and temperature != 1.0:
                 next_token_logits = next_token_logits / temperature
             if do_sample and top_p < 1.0:
+                # Use Hugging Face's top_p implementation detail
                 sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                 cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                 sorted_indices_to_remove = cumulative_probs > top_p
@@ -263,34 +338,37 @@ class InferenceMemoryWrapper(PreTrainedModel):
                 sorted_indices_to_remove[..., 0] = 0
                 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                 next_token_logits = next_token_logits.masked_fill(indices_to_remove, float('-inf'))
+
             if do_sample:
                 probs = F.softmax(next_token_logits, dim=-1)
+                next_token = torch.multinomial(probs, num_samples=1)
             else:
+                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

             # --- Update State ---
             generated_ids = torch.cat([generated_ids, next_token], dim=1)
             current_seq_len += 1
+            # Update attention mask for the next iteration by appending 1
+            attention_mask = torch.cat([attention_mask, torch.ones((bsz, 1), dtype=attention_mask.dtype, device=device)], dim=1)
+

+            # --- EMA Memory Update ---
             if update_rule == 'ema' and use_memory and outputs.hidden_states is not None:
+                # Use hidden state corresponding to the newly generated token position (index -1)
+                # Cast state to buffer dtype before update
+                new_context_state = outputs.hidden_states[-1][:, -1, :].to(self.memory_buffer.dtype) # (B, C)
+                self.update_memory_ema(new_context_state.detach())

             if eos_token_id is not None and (next_token == eos_token_id).all():
                 break

+        # self.eval() # Already in eval mode if llama is frozen

         return generated_ids
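A minimal sketch of driving the custom generate loop (the tokenizer path is a placeholder and should match the wrapped Llama checkpoint):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("path/to/llama")  # placeholder path
    ids = tok("The capital of France is", return_tensors="pt").input_ids.to(wrapper.llama.device)

    # 'ema' is the cheap update; 'surprise' triggers the inline backward pass each step.
    out_ids = wrapper.generate(ids, max_new_tokens=20, use_memory=True, update_rule='ema')
    print(tok.decode(out_ids[0], skip_special_tokens=True))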
+
+    # --- Save/Load ---
+    # Keep the save_pretrained as is, it saves wrapper specific state.
     def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
         """ Saves the wrapper's specific state (memory buffer, surprise state). """
         save_directory = Path(save_directory)
@@ -299,23 +377,17 @@ class InferenceMemoryWrapper(PreTrainedModel):
         # Save the base model's config (important for PreTrainedModel compatibility)
         self.config.save_pretrained(save_directory)

+        # Save the memory buffer parameter directly
+        # Ensure saving in float32 for broader compatibility, can be cast back on load
+        # Note: Saving the Parameter itself, not just its .data
+        torch.save(self.memory_buffer.float(), save_directory / "memory_buffer.pt")
+        # Save the surprise state buffer directly
+        torch.save(self.surprise_state.float(), save_directory / "surprise_state.pt")

         print(f"InferenceMemoryWrapper state saved to {save_directory}")
         # Note: Base Llama model weights are assumed to be saved separately or loaded from source.

+    # from_pretrained is complex with wrappers. For local testing/handler, load manually.
+    # @classmethod
+    # def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs):
+    #     raise NotImplementedError(...)
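Since from_pretrained is intentionally left unimplemented, loading mirrors the manual route the comments describe for the handler; a sketch under that assumption, with the directory being whatever save_pretrained wrote:

    import torch
    from pathlib import Path

    save_dir = Path("path/to/saved/wrapper")            # placeholder directory
    buf = torch.load(save_dir / "memory_buffer.pt")     # stored as float32 by save_pretrained
    state = torch.load(save_dir / "surprise_state.pt")

    with torch.no_grad():
        wrapper.memory_buffer.copy_(buf.to(wrapper.memory_buffer.dtype))
        wrapper.surprise_state.copy_(state.to(wrapper.surprise_state.dtype))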