|
|
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
from transformers import Phi3Config, Phi3ForCausalLM
|
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
|
from typing import Optional, Dict, Tuple
|
|
|
from dataclasses import dataclass
|
|
|
|
|
|
@dataclass
class CausalLMOutputWithLTM(CausalLMOutputWithPast):
    """Causal-LM output record extended with a long-term-memory (LTM) tensor.

    Every field defaults to ``None``. The only field added on top of the
    usual causal-LM outputs is ``ltm_state``.
    """

    # Language-modelling loss; left as None when no loss was computed.
    loss: Optional[torch.FloatTensor] = None
    # Prediction scores produced by the LM head.
    logits: Optional[torch.FloatTensor] = None
    # Attention cache handed back by the base model.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Hidden states handed back by the base model.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Attention weights handed back by the base model.
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Long-term-memory tensor as it stands after the forward pass.
    ltm_state: Optional[torch.Tensor] = None
|
|
|
|
|
|
|
|
|
class VectorMemoryHead(nn.Module):
    """Compress a sequence into memory slots and reconstruct it from them.

    The input sequence is encoded by a one-layer transformer encoder, then a
    set of learned query vectors attends over the encoding to produce
    ``num_memory_slots`` compressed memory vectors. A decoder attention
    followed by a feed-forward network reconstructs the encoded sequence from
    that memory. When long-term memory (LTM) is enabled and supplied, the
    compressed memory additionally retrieves from the LTM and a gated update
    of the LTM is returned.

    Args:
        hidden_dim: Feature width of every vector handled by the head.
        num_memory_slots: Number of learned query vectors, i.e. the number of
            compressed memory vectors produced per batch element.
        num_heads: Head count for every attention module in the head.
        ff_dim: Inner width of the encoder and decoder feed-forward networks.
        num_long_term_memory_slots: LTM is enabled when this is greater than
            zero. Within this class the value is only used as that on/off
            switch; the actual slot count comes from the tensor passed to
            ``forward``.
        device: Device forwarded to every submodule and parameter.
        dtype: Dtype forwarded to every submodule and parameter.
        dropout: Dropout probability used by the encoder layer and by every
            attention module. Defaults to ``0.1``, the value that used to be
            hard-coded.
    """

    def __init__(self, hidden_dim: int, num_memory_slots: int, num_heads: int, ff_dim: int,
                 num_long_term_memory_slots: int = 0,
                 device=None, dtype=None, dropout: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_memory_slots = num_memory_slots
        self.num_long_term_memory_slots = num_long_term_memory_slots
        self.use_long_term_memory = self.num_long_term_memory_slots > 0
        self.dropout_p = dropout

        factory_kwargs = {"device": device, "dtype": dtype}

        def make_attention() -> nn.MultiheadAttention:
            # All attention modules in the head share one configuration.
            return nn.MultiheadAttention(
                embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout,
                batch_first=True, **factory_kwargs)

        # NOTE: submodules are created in a fixed order so that seeded
        # initialisation and parameter ordering stay stable.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, dim_feedforward=ff_dim,
            dropout=dropout, batch_first=True, **factory_kwargs)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)

        # Learned queries, shared across the batch (leading dim of 1 is
        # expanded to the batch size in forward).
        self.memory_queries = nn.Parameter(
            torch.randn(1, num_memory_slots, hidden_dim, **factory_kwargs))
        self.memory_attention = make_attention()
        self.memory_layernorm = nn.LayerNorm(hidden_dim, **factory_kwargs)

        self.decoder_attention = make_attention()
        self.decoder_layernorm = nn.LayerNorm(hidden_dim, **factory_kwargs)
        self.decoder_ffn = nn.Sequential(
            nn.Linear(hidden_dim, ff_dim, **factory_kwargs),
            nn.ReLU(),
            nn.Linear(ff_dim, hidden_dim, **factory_kwargs))

        # The LTM-specific modules only exist when LTM is enabled.
        if self.use_long_term_memory:
            self.memory_update_gate = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim, **factory_kwargs),
                nn.Sigmoid())
            self.ltm_retrieval_attention = make_attention()

    def forward(self, memory_input_sequence: torch.Tensor,
                long_term_memory: Optional[torch.Tensor] = None
                ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Run compression, optional LTM retrieval/update, and reconstruction.

        Args:
            memory_input_sequence: Batch-first input,
                ``(batch, seq_len, hidden_dim)``.
            long_term_memory: Optional batch-first LTM tensor,
                ``(batch, ltm_slots, hidden_dim)``. Ignored when the head was
                built with LTM disabled.

        Returns:
            A tuple ``(compressed_memory, reconstructed_vectors, new_ltm_state)``:
            the memory slots ``(batch, num_memory_slots, hidden_dim)`` *before*
            any LTM retrieval is mixed in, the reconstruction
            ``(batch, seq_len, hidden_dim)``, and the gated LTM update. When
            LTM is disabled or not supplied, ``new_ltm_state`` is the
            ``long_term_memory`` argument returned unchanged (possibly
            ``None``).
        """
        batch_size = memory_input_sequence.shape[0]
        queries = self.memory_queries.expand(batch_size, -1, -1)

        encoded_vectors = self.encoder(memory_input_sequence)
        compressed_memory, _ = self.memory_attention(
            query=queries, key=encoded_vectors, value=encoded_vectors)
        # Residual connection onto the learned queries.
        compressed_memory = self.memory_layernorm(compressed_memory + queries)

        final_memory_context = compressed_memory
        new_ltm_state = long_term_memory

        if self.use_long_term_memory and long_term_memory is not None:
            retrieved_ltm, _ = self.ltm_retrieval_attention(
                query=compressed_memory, key=long_term_memory, value=long_term_memory)

            # One summary vector per batch element; it and the gate broadcast
            # across all LTM slots in the convex combination below.
            l1_summary = compressed_memory.mean(dim=1, keepdim=True)
            update_gate = self.memory_update_gate(l1_summary)
            new_ltm_state = (update_gate * l1_summary) + ((1 - update_gate) * long_term_memory)

            # Only the decoder context sees the retrieved LTM; the returned
            # compressed_memory stays un-augmented.
            final_memory_context = final_memory_context + retrieved_ltm

        reconstructed, _ = self.decoder_attention(
            query=encoded_vectors, key=final_memory_context, value=final_memory_context)
        reconstructed_vectors = self.decoder_layernorm(reconstructed + encoded_vectors)
        reconstructed_vectors = self.decoder_ffn(reconstructed_vectors)

        return compressed_memory, reconstructed_vectors, new_ltm_state
|
|
|
|
|
|
|
|
|
class ReflectiveMemoryLayer(nn.Module):
    """Linear layer whose output is gated and shifted by a memory-derived correction.

    The wrapped linear layer computes the base output. A ``VectorMemoryHead``
    is fed, for every position, a two-element sequence made of the projected
    global embedding and the projected local input; the mean of its compressed
    memory slots forms a per-position "thought" vector. A correction head maps
    that thought to a gate and an additive value, and the layer returns
    ``base * sigmoid(gate) + value``.

    Communication with the surrounding model goes through the shared
    ``global_state_storage`` dict: the layer reads ``'embeds'`` and ``'ltm'``
    from it and writes the updated ``'ltm'`` back. If ``'embeds'`` is absent
    the layer behaves exactly like the wrapped linear layer.

    Args:
        original_layer: Linear layer being wrapped; new submodules are created
            on its weight's device and dtype.
        global_input_dim: Feature width of the tensors stored under
            ``'embeds'``.
        memory_dim: Working width of the memory head and thought vectors.
        num_memory_slots: Number of memory slots in the memory head.
        memory_num_heads: Head count for all attention modules.
        global_state_storage: Dict shared with the owning model.
    """

    def __init__(self, original_layer: nn.Linear, global_input_dim: int,
                 memory_dim: int, num_memory_slots: int, memory_num_heads: int,
                 global_state_storage: Dict):
        super().__init__()
        self.input_dim = original_layer.in_features
        self.output_dim = original_layer.out_features
        self.memory_dim = memory_dim
        self.global_state_storage = global_state_storage
        self.linear = original_layer
        # Number of no-grad refinement iterations applied in eval mode.
        self.refinement_passes: int = 2

        weight = self.linear.weight
        factory_kwargs = {"device": weight.device, "dtype": weight.dtype}

        self.local_state_proj = nn.Linear(self.input_dim, memory_dim, **factory_kwargs)
        self.global_state_proj = nn.Linear(global_input_dim, memory_dim, **factory_kwargs)
        self.memory_head = VectorMemoryHead(
            hidden_dim=memory_dim,
            num_memory_slots=num_memory_slots,
            num_heads=memory_num_heads,
            ff_dim=memory_dim * 2,
            num_long_term_memory_slots=32,
            **factory_kwargs)
        self.thought_critique_attention = nn.MultiheadAttention(
            embed_dim=memory_dim, num_heads=memory_num_heads, dropout=0.1,
            batch_first=True, **factory_kwargs)
        self.thought_layernorm = nn.LayerNorm(memory_dim, **factory_kwargs)
        # Emits gate and additive value side by side, hence 2 * output_dim.
        self.correction_head = nn.Linear(memory_dim, 2 * self.output_dim, **factory_kwargs)

        # Detached snapshots of the most recent training-mode forward pass.
        self.last_corrected_activation = None
        self.last_additive_correction = None
        self.last_memory_input = None
        self.last_reconstructed_from_memory = None

    def _refine_thought(self, thought: torch.Tensor, compressed_flat: torch.Tensor,
                        memory_flat: torch.Tensor) -> torch.Tensor:
        """Iteratively refine ``thought`` against memory, without gradients."""
        batch, seq_len, _ = thought.shape
        with torch.no_grad():
            for _pass in range(self.refinement_passes):
                query = thought.view(batch * seq_len, 1, self.memory_dim)
                # "Internal" reflection: attend over the compressed slots.
                internal_ref = self.memory_head.decoder_attention(
                    query=query, key=compressed_flat, value=compressed_flat)[0]
                # "External" critique: attend over the raw memory input.
                external_crit = self.thought_critique_attention(
                    query=query, key=memory_flat, value=memory_flat)[0]
                refined = (thought
                           + internal_ref.view(batch, seq_len, -1)
                           + external_crit.view(batch, seq_len, -1))
                thought = self.thought_layernorm(refined)
        return thought

    def forward(self, x: torch.Tensor):
        """Apply the linear layer and, when global embeddings are available, the correction.

        Args:
            x: Input unpacked as ``(batch, seq_len, features)``.

        Returns:
            ``linear(x)`` when ``'embeds'`` is missing from the shared storage,
            otherwise ``linear(x) * sigmoid(gate) + value``.
        """
        storage = self.global_state_storage
        base_output = self.linear(x)
        if 'embeds' not in storage:
            return base_output

        global_embeds = storage['embeds']
        if global_embeds.shape[1] != x.shape[1]:
            # Keep only the trailing positions so the sequence lengths line up.
            global_embeds = global_embeds[:, -x.shape[1]:, :]
        batch, seq_len, _ = x.shape
        flat = batch * seq_len

        # Each position becomes its own length-2 sequence: [global, local].
        local_feats = self.local_state_proj(x)
        global_feats = self.global_state_proj(global_embeds)
        memory_input = torch.stack([global_feats, local_feats], dim=2)
        memory_input_flat = memory_input.view(flat, 2, self.memory_dim)

        # The LTM is stored per batch element; replicate it for every position.
        stored_ltm = storage.get('ltm', None)
        ltm_per_position = None
        if stored_ltm is not None:
            ltm_per_position = stored_ltm.detach().repeat_interleave(seq_len, dim=0)

        compressed_flat, recon_flat, updated_ltm = self.memory_head(
            memory_input_flat, ltm_per_position)

        if updated_ltm is not None:
            # Fold the per-position updates back to one LTM per batch element
            # by averaging over the sequence axis.
            ltm_slots = updated_ltm.shape[1]
            per_position = updated_ltm.view(batch, seq_len, ltm_slots, self.memory_dim)
            storage['ltm'] = per_position.mean(dim=1).detach()

        # One thought vector per position: mean over the memory slots.
        thought = compressed_flat.mean(dim=1).view(batch, seq_len, self.memory_dim)
        if not self.training and self.refinement_passes > 0:
            # Refinement runs in eval mode only; training uses the raw thought.
            thought = self._refine_thought(thought, compressed_flat, memory_input_flat)

        gate, value = torch.chunk(self.correction_head(thought), 2, dim=-1)
        final_activation = base_output * torch.sigmoid(gate.to(x.dtype)) + value.to(x.dtype)

        if self.training:
            self.last_corrected_activation = final_activation.detach()
            self.last_additive_correction = value.detach()
            self.last_memory_input = memory_input.detach()
            self.last_reconstructed_from_memory = recon_flat.view(
                batch, seq_len, 2, self.memory_dim).detach()

        return final_activation
|
|
|
|
|
|
|
|
|
class Phi3WithReflectiveMemoryForCausalLM(Phi3ForCausalLM):
    """Phi-3 causal LM with one projection replaced by a ``ReflectiveMemoryLayer``.

    On construction the submodule at ``self.target_layer_path`` is swapped for
    a ``ReflectiveMemoryLayer`` wrapping the original layer. A forward hook on
    the token-embedding module publishes the detached embeddings under
    ``global_state_storage['embeds']``; the long-term memory (LTM) tensor is
    exchanged with the custom layer through ``global_state_storage['ltm']``.
    """

    def __init__(self, config):
        super().__init__(config)
        # Shared scratch dict read and written by the ReflectiveMemoryLayer.
        self.global_state_storage = {}
        self.target_layer_path = "model.layers.15.mlp.gate_up_proj"
        self.memory_dim, self.num_long_term_memory_slots = 256, 32

        def embedding_hook(module, inputs, output):
            # Publish the embeddings produced during the current forward pass.
            self.global_state_storage['embeds'] = output.detach()

        self.model.embed_tokens.register_forward_hook(embedding_hook)

        try:
            original_layer = self.get_submodule(self.target_layer_path)
            custom_layer = ReflectiveMemoryLayer(
                original_layer=original_layer, global_input_dim=config.hidden_size,
                memory_dim=self.memory_dim, num_memory_slots=32, memory_num_heads=16,
                global_state_storage=self.global_state_storage)
            parent_path, _, attr_name = self.target_layer_path.rpartition('.')
            setattr(self.get_submodule(parent_path), attr_name, custom_layer)
            print(f"Successfully replaced '{self.target_layer_path}' with ReflectiveMemoryLayer.")
        except AttributeError:
            # Best effort: the model stays a plain Phi-3 when the path is absent.
            print(f"Could not find target layer '{self.target_layer_path}'. Model remains unmodified.")

    def _init_ltm_state(self, batch_size, device, dtype):
        """Return an all-zero LTM of shape ``(batch_size, ltm_slots, memory_dim)``."""
        return torch.zeros(
            batch_size, self.num_long_term_memory_slots, self.memory_dim,
            device=device, dtype=dtype)

    def forward(self, input_ids: torch.LongTensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[list[torch.FloatTensor]] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                labels: Optional[torch.LongTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                cache_position: Optional[torch.LongTensor] = None,
                logits_to_keep: Optional[torch.LongTensor] = None,
                ltm_state: Optional[torch.Tensor] = None):
        """Run the language model and return its outputs together with the new LTM.

        Args:
            input_ids, attention_mask, position_ids, past_key_values,
            inputs_embeds, use_cache, output_attentions, output_hidden_states,
            cache_position: Forwarded unchanged to the base model.
            labels: Optional targets. When given, a shifted next-token
                cross-entropy loss is computed from the logits. The logits and
                labels must cover the same positions, so combining ``labels``
                with a restricting ``logits_to_keep`` is the caller's
                responsibility.
            return_dict: Whether to return a ``CausalLMOutputWithLTM`` instead
                of a tuple; defaults to ``config.use_return_dict``.
            logits_to_keep: ``None`` or ``0`` computes logits for every
                position. An ``int`` ``k`` computes them for the last ``k``
                positions only; a tensor is used as an index along the
                sequence axis.
            ltm_state: LTM carried over from a previous call. When ``None`` a
                zero tensor sized from the batch is used.

        Returns:
            ``CausalLMOutputWithLTM`` when ``return_dict`` is true, otherwise
            the tuple ``([loss,] logits, *base_outputs[1:], new_ltm_state)``.

        Raises:
            ValueError: If ``ltm_state`` is ``None`` and neither ``input_ids``
                nor ``inputs_embeds`` is given, so no batch size can be
                inferred.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Drop embeddings left over from a previous call so the memory layer
        # never pairs the current inputs with stale ones.
        # NOTE(review): presumably the embedding hook does not fire when the
        # caller supplies inputs_embeds, leaving the memory layer as a plain
        # linear layer for that call -- confirm against the base model.
        self.global_state_storage.pop('embeds', None)

        if ltm_state is None:
            if input_ids is not None:
                batch_size = input_ids.shape[0]
            elif inputs_embeds is not None:
                batch_size = inputs_embeds.shape[0]
            else:
                raise ValueError("You must specify either input_ids or inputs_embeds.")
            ltm_state = self._init_ltm_state(batch_size, self.device, self.dtype)

        self.global_state_storage['ltm'] = ltm_state.detach()

        # logits_to_keep is applied below at the LM head; it is not an
        # argument of the base model and is therefore not forwarded.
        outputs = self.model(
            input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids,
            past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            cache_position=cache_position, return_dict=return_dict)

        hidden_states = outputs[0]
        if logits_to_keep is not None:
            if isinstance(logits_to_keep, int):
                # slice(-0, None) == slice(0, None), so 0 keeps every position.
                keep = slice(-logits_to_keep, None)
            else:
                keep = logits_to_keep
            hidden_states = hidden_states[:, keep, :]
        logits = self.lm_head(hidden_states).float()

        loss = None
        if labels is not None:
            # Position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            shift_labels = labels[..., 1:].contiguous().view(-1)
            loss = nn.CrossEntropyLoss()(shift_logits, shift_labels)

        # The memory layer may have overwritten the entry during the pass.
        new_ltm_state = self.global_state_storage.get('ltm', None)
        if new_ltm_state is not None:
            new_ltm_state = new_ltm_state.detach()

        if not return_dict:
            output = (logits,) + outputs[1:] + (new_ltm_state,)
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithLTM(
            loss=loss, logits=logits, past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states, attentions=outputs.attentions,
            ltm_state=new_ltm_state)
|
|
|
|