|
|
import math
import os
from pathlib import Path
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LlamaConfig, LlamaForCausalLM, PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
class InferenceMemoryWrapper(PreTrainedModel):
    """Wraps a frozen LlamaForCausalLM with an external memory buffer.

    The memory is read through differentiable attention and written back either
    with an exponential moving average (EMA) or a TITANS-style surprise
    (gradient-based) update performed at inference time.
    """

    def __init__(
        self,
        llama_model: LlamaForCausalLM,
        memory_size: int = 4096,
        num_retrieved: int = 1,
        update_alpha: float = 0.1,
        surprise_momentum: float = 0.9,
        surprise_lr: float = 0.01,
    ):
        super().__init__(llama_model.config)
        self.llama = llama_model

        self.memory_size = memory_size
        self.num_retrieved = num_retrieved
        self.update_alpha = update_alpha
        self.surprise_momentum_eta = surprise_momentum
        self.surprise_lr_theta = surprise_lr
        self.dim = llama_model.config.hidden_size
        self._target_dtype = llama_model.dtype

        # Memory slots, initialised with small random values (std = 1/sqrt(dim)).
        init_buffer_data = torch.zeros(self.memory_size, self.dim, dtype=self._target_dtype)
        nn.init.normal_(init_buffer_data, mean=0.0, std=1 / math.sqrt(self.dim))
        self.memory_buffer = nn.Parameter(init_buffer_data)

        # Momentum state for the surprise update, kept as a non-trainable buffer.
        init_surprise_state = torch.zeros_like(self.memory_buffer.data, dtype=self._target_dtype)
        self.register_buffer("surprise_state", init_surprise_state)

        # Freeze the base model; only the memory buffer is ever modified.
        for param in self.llama.parameters():
            param.requires_grad = False
        self.llama.eval()
    def get_input_embeddings(self):
        return self.llama.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.llama.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.llama.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.llama.set_output_embeddings(new_embeddings)
    def retrieve_memory(self, query_input: torch.Tensor) -> torch.Tensor:
        """Retrieves memory via differentiable attention over the memory buffer.

        Args:
            query_input (torch.Tensor): Query tensor of shape (B, C).

        Returns:
            torch.Tensor: Retrieved memory embedding (attention-weighted sum) of shape (B, 1, C).
        """
        q = query_input.to(self.memory_buffer.dtype)

        # The memory buffer serves as both keys and values.
        mem_keys = self.memory_buffer
        mem_values = self.memory_buffer

        # Scaled dot-product attention over the memory slots.
        attn_scores = torch.matmul(q, mem_keys.T) / math.sqrt(self.dim)
        attn_weights = torch.softmax(attn_scores, dim=-1)

        retrieved_mem = torch.matmul(attn_weights, mem_values)

        return retrieved_mem.unsqueeze(1)
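    # Shape sketch: for a query of shape (B, hidden_size), attn_weights has shape
    # (B, memory_size); the weighted sum over the memory slots collapses to one
    # vector per example, returned as (B, 1, hidden_size) so it can be prepended
    # to the token embeddings as a single extra "memory token".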
|
|
|
|
|
|
|
|
    @torch.no_grad()
    def apply_surprise_update(self):
        """Applies the TITANS-style surprise update rule using self.memory_buffer.grad."""
        if self.memory_buffer.grad is None:
            print("DEBUG: apply_surprise_update called but memory_buffer.grad is None.")
            return

        # Keep the momentum state on the same device/dtype as the memory buffer.
        self.surprise_state = self.surprise_state.to(
            device=self.memory_buffer.device, dtype=self.memory_buffer.dtype
        )

        # Momentum-accumulate the negative gradient, then write it into memory.
        surprise_update_val = -self.surprise_lr_theta * self.memory_buffer.grad.data
        self.surprise_state.mul_(self.surprise_momentum_eta).add_(surprise_update_val)

        self.memory_buffer.data.add_(self.surprise_state)
        self.memory_buffer.grad.zero_()
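    # In equation form, with g_t the gradient of the associative loss w.r.t. the memory:
    #     S_t = eta * S_{t-1} - theta * g_t    (momentum-filtered "surprise")
    #     M_t = M_{t-1} + S_t                  (memory write)
    # where eta = surprise_momentum_eta and theta = surprise_lr_theta.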
|
|
|
|
|
|
|
|
|
|
|
    @torch.no_grad()
    def update_memory_ema(self, new_context_embedding: torch.Tensor):
        """Updates the memory buffer with an exponential moving average of the new context."""
        # Pool a multi-row context down to a single (1, C) vector before the update.
        update_vec_float = (
            new_context_embedding.mean(dim=0, keepdim=True)
            if new_context_embedding.shape[0] > 1
            else new_context_embedding
        )
        update_vec = update_vec_float.to(self.memory_buffer.dtype)

        self.memory_buffer.data = self.memory_buffer.data.to(update_vec.device)
        self.memory_buffer.data.mul_(1 - self.update_alpha).add_(update_vec * self.update_alpha)
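    # EMA write in equation form: M_t = (1 - alpha) * M_{t-1} + alpha * h_t, where
    # h_t is the pooled (1, C) context vector; the update broadcasts across all
    # memory slots, so every slot drifts toward the same context.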
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Delegates directly to the wrapped Llama model."""
        return self.llama(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )
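    # Note: forward() never consults the memory buffer; memory retrieval and the
    # EMA / surprise updates happen only inside generate() below.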
|
|
|
|
|
|
|
|
|
|
|
    def generate(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 20,
        num_beams: int = 1,
        use_memory: bool = True,
        update_rule: str = 'ema',
        temperature: float = 0.7,
        top_p: float = 0.95,
        do_sample: bool = True,
        repetition_penalty: float = 1.0,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.LongTensor:
        """Sampling/greedy generation loop with optional memory retrieval and updates.

        update_rule may be 'ema', 'surprise', or 'none'.
        """
        if num_beams != 1:
            raise NotImplementedError("Beam search not implemented.")
        if update_rule == 'surprise' and not use_memory:
            print("Warning: update_rule='surprise' requires use_memory=True.")
            update_rule = 'none'

        # The surprise rule needs gradients w.r.t. the memory buffer; the original
        # setting is restored after generation.
        original_requires_grad = self.memory_buffer.requires_grad
        if update_rule == 'surprise':
            self.memory_buffer.requires_grad_(True)
            print(f"DEBUG: Set memory_buffer.requires_grad = {self.memory_buffer.requires_grad}")
        else:
            self.memory_buffer.requires_grad_(False)

        bsz, seq_len_start = input_ids.shape
        device = input_ids.device
        generated_ids = input_ids.clone()
        current_seq_len = seq_len_start

        expected_dtype = self.memory_buffer.dtype

        if eos_token_id is None:
            eos_token_id = self.config.eos_token_id
        if pad_token_id is None:
            pad_token_id = self.config.pad_token_id

        past_key_values = None

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        for step in range(max_new_tokens):
            # Build the memory query from the embedding of the most recent token.
            if past_key_values is not None:
                current_input_ids = generated_ids[:, -1:]
                full_embeds = self.llama.model.embed_tokens(generated_ids)
                query_basis = full_embeds[:, -1, :].to(expected_dtype)
            else:
                current_input_ids = generated_ids
                inputs_embeds_full = self.llama.model.embed_tokens(current_input_ids)
                query_basis = inputs_embeds_full[:, -1, :].to(expected_dtype)

            retrieved_mem = None
            if use_memory:
                retrieved_mem = self.retrieve_memory(query_basis)
            # Prepend the retrieved memory as an extra "token" embedding and extend
            # the attention mask to cover it.
            current_mask = None
            mem_len = 0
            if retrieved_mem is not None:
                retrieved_mem_casted = retrieved_mem.to(self.llama.dtype)
                mem_len = retrieved_mem_casted.shape[1]

            if past_key_values is None:
                # No KV cache: feed the full sequence (plus the memory token) each step.
                inputs_embeds_full_casted = inputs_embeds_full.to(self.llama.dtype)
                if retrieved_mem is not None:
                    model_inputs_embeds = torch.cat([retrieved_mem_casted, inputs_embeds_full_casted], dim=1)
                    mem_mask = torch.ones((bsz, mem_len), dtype=attention_mask.dtype, device=device)
                    current_mask = torch.cat([mem_mask, attention_mask], dim=1)
                else:
                    model_inputs_embeds = inputs_embeds_full_casted
                    current_mask = attention_mask

                effective_seq_len = model_inputs_embeds.shape[1]
                position_ids = torch.arange(effective_seq_len, device=device).unsqueeze(0)
                cur_input_ids_for_llama = None
            else:
                # KV cache present: feed only the latest token (plus the memory token).
                current_input_embeds = self.llama.model.embed_tokens(current_input_ids).to(self.llama.dtype)
                if retrieved_mem is not None:
                    model_inputs_embeds = torch.cat([retrieved_mem_casted, current_input_embeds], dim=1)
                    current_mask = torch.ones((bsz, mem_len + 1), dtype=attention_mask.dtype, device=device)
                else:
                    model_inputs_embeds = current_input_embeds
                    current_mask = torch.ones((bsz, 1), dtype=attention_mask.dtype, device=device)

                past_len = past_key_values.get_seq_length()  # (currently unused)
                position_ids = torch.tensor(
                    [[current_seq_len - 1 + i + mem_len for i in range(model_inputs_embeds.shape[1])]],
                    device=device,
                )
                cur_input_ids_for_llama = None
            # KV caching is only used when a cache already exists and no gradient is
            # needed. Note: because past_key_values starts as None and is only stored
            # when this flag is True, the cache is never actually populated and the
            # full sequence is re-encoded on every step.
            use_kv_cache_this_step = (
                past_key_values is not None
                and update_rule != 'surprise'
                and self.llama.config.use_cache
            )

            # The surprise rule needs a graph through the memory retrieval, so enable
            # grad only in that case.
            context = torch.enable_grad() if update_rule == 'surprise' else torch.no_grad()
            with context:
                outputs = self.llama(
                    input_ids=cur_input_ids_for_llama,
                    inputs_embeds=model_inputs_embeds,
                    attention_mask=current_mask,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    use_cache=use_kv_cache_this_step,
                    output_hidden_states=True,
                    return_dict=True,
                )
            # Surprise update: treat the retrieved memory as a prediction of the
            # model's final hidden state and backprop an associative MSE loss into
            # the memory buffer.
            if update_rule == 'surprise' and use_memory and retrieved_mem is not None:
                target_repr = outputs.hidden_states[-1][:, -1, :].to(self.memory_buffer.dtype)
                pred_repr = retrieved_mem.squeeze(1)

                print(f"\n--- Surprise Update Debug (Step {step}) ---")
                print(f" memory_buffer requires_grad: {self.memory_buffer.requires_grad}")
                print(f" retrieved_mem requires_grad: {retrieved_mem.requires_grad}")
                print(f" pred_repr requires_grad: {pred_repr.requires_grad}")
                print(f" target_repr requires_grad: {target_repr.requires_grad}")

                assoc_loss = F.mse_loss(pred_repr, target_repr.detach())
                print(f" assoc_loss: {assoc_loss.item():.4f}, requires_grad: {assoc_loss.requires_grad}")

                if self.memory_buffer.grad is not None:
                    print(" Zeroing existing memory_buffer gradient.")
                    self.memory_buffer.grad.zero_()

                if assoc_loss.requires_grad:
                    print(" Calling assoc_loss.backward()")
                    assoc_loss.backward()
                    print(f" memory_buffer.grad is None after backward: {self.memory_buffer.grad is None}")
                    if self.memory_buffer.grad is not None:
                        print(f" memory_buffer.grad norm: {torch.norm(self.memory_buffer.grad).item():.4f}")
                    self.apply_surprise_update()
                else:
                    print(" ERROR: assoc_loss does not require grad! Skipping backward and update.")
                print("--- End Surprise Update Debug ---")
            next_token_logits = outputs.logits[:, -1, :]

            if use_kv_cache_this_step:
                past_key_values = outputs.past_key_values

            # Repetition penalty: penalise each token already present in the sequence
            # once (divide positive logits, multiply negative ones).
            if repetition_penalty != 1.0:
                for i in range(bsz):
                    for token_id in set(generated_ids[i].tolist()):
                        if token_id != pad_token_id:
                            if next_token_logits[i, token_id] > 0:
                                next_token_logits[i, token_id] /= repetition_penalty
                            else:
                                next_token_logits[i, token_id] *= repetition_penalty

            if temperature > 0 and temperature != 1.0:
                next_token_logits = next_token_logits / temperature

            # Nucleus (top-p) filtering.
            if do_sample and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                next_token_logits = next_token_logits.masked_fill(indices_to_remove, float('-inf'))
            # Sample (or greedily pick) the next token.
            if do_sample:
                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            generated_ids = torch.cat([generated_ids, next_token], dim=1)
            current_seq_len += 1

            attention_mask = torch.cat(
                [attention_mask, torch.ones((bsz, 1), dtype=attention_mask.dtype, device=device)],
                dim=1,
            )

            # EMA update: fold the latest final-layer hidden state into the memory.
            if update_rule == 'ema' and use_memory and outputs.hidden_states is not None:
                new_context_state = outputs.hidden_states[-1][:, -1, :].to(self.memory_buffer.dtype)
                self.update_memory_ema(new_context_state.detach())

            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        # Restore the original requires_grad setting on the memory buffer.
        self.memory_buffer.requires_grad_(original_requires_grad)

        return generated_ids
    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """Saves the wrapper's specific state (memory buffer, surprise state).

        The frozen base model is not saved here; only the config and the
        memory-related tensors are written out.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        self.config.save_pretrained(save_directory)

        # Store the memory tensors as plain float32 CPU tensors for portability.
        torch.save(self.memory_buffer.data.detach().float().cpu(), save_directory / "memory_buffer.pt")
        torch.save(self.surprise_state.detach().float().cpu(), save_directory / "surprise_state.pt")

        print(f"InferenceMemoryWrapper state saved to {save_directory}")
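# Minimal usage sketch (illustrative only, not part of the wrapper). The checkpoint
# name below is an assumption; substitute any LlamaForCausalLM-compatible model.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_name = "meta-llama/Llama-2-7b-hf"  # assumed checkpoint; replace as needed
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = LlamaForCausalLM.from_pretrained(model_name)

    wrapper = InferenceMemoryWrapper(base_model, memory_size=1024, update_alpha=0.1)

    inputs = tokenizer("The capital of France is", return_tensors="pt")
    # EMA memory updates during generation; pass update_rule='surprise' to use the
    # gradient-based rule instead.
    output_ids = wrapper.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=20,
        use_memory=True,
        update_rule="ema",
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

    # Persist the memory state and reload the raw tensor written by save_pretrained().
    wrapper.save_pretrained("./memory_wrapper_state")
    restored = torch.load("./memory_wrapper_state/memory_buffer.pt")
    wrapper.memory_buffer.data.copy_(restored.to(wrapper.memory_buffer.dtype))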
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|