win10 commited on
Commit
21c94ed
·
verified ·
1 Parent(s): f16aa26

Upload architecture.py

Browse files
Files changed (1) hide show
  1. architecture.py +245 -238
architecture.py CHANGED
@@ -1,238 +1,245 @@
1
- # --- START OF FILE architectureV3.py ---
2
-
3
- import torch
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
- from transformers import Phi3Config, Phi3ForCausalLM
7
- from transformers.modeling_outputs import CausalLMOutputWithPast
8
- from typing import Optional, Dict, Tuple
9
- from dataclasses import dataclass
10
-
11
- @dataclass
12
- class CausalLMOutputWithLTM(CausalLMOutputWithPast):
13
- loss: Optional[torch.FloatTensor] = None
14
- logits: torch.FloatTensor = None
15
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
16
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
17
- attentions: Optional[Tuple[torch.FloatTensor]] = None
18
- ltm_state: Optional[torch.Tensor] = None # The returned LTM state
19
-
20
- # --- BUILDING BLOCK 1: Hierarchical VectorMemoryHead (Stateless) ---
21
- class VectorMemoryHead(nn.Module):
22
- def __init__(self, hidden_dim: int, num_memory_slots: int, num_heads: int, ff_dim: int,
23
- num_long_term_memory_slots: int = 0,
24
- device=None, dtype=None):
25
- super().__init__()
26
- self.hidden_dim = hidden_dim
27
- self.num_memory_slots = num_memory_slots
28
- self.num_long_term_memory_slots = num_long_term_memory_slots
29
- self.use_long_term_memory = self.num_long_term_memory_slots > 0
30
-
31
- encoder_layer = nn.TransformerEncoderLayer(
32
- d_model=hidden_dim, nhead=num_heads, dim_feedforward=ff_dim, dropout=0.1, batch_first=True,
33
- device=device, dtype=dtype)
34
- self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
35
- self.memory_queries = nn.Parameter(torch.randn(1, num_memory_slots, hidden_dim, device=device, dtype=dtype))
36
- self.memory_attention = nn.MultiheadAttention(
37
- embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True, device=device, dtype=dtype)
38
- self.memory_layernorm = nn.LayerNorm(hidden_dim, device=device, dtype=dtype)
39
- self.decoder_attention = nn.MultiheadAttention(
40
- embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True, device=device, dtype=dtype)
41
- self.decoder_layernorm = nn.LayerNorm(hidden_dim, device=device, dtype=dtype)
42
- self.decoder_ffn = nn.Sequential(
43
- nn.Linear(hidden_dim, ff_dim, device=device, dtype=dtype), nn.ReLU(),
44
- nn.Linear(ff_dim, hidden_dim, device=device, dtype=dtype))
45
-
46
- if self.use_long_term_memory:
47
- self.memory_update_gate = nn.Sequential(
48
- nn.Linear(hidden_dim, hidden_dim, device=device, dtype=dtype), nn.Sigmoid())
49
- self.ltm_retrieval_attention = nn.MultiheadAttention(
50
- embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True, device=device, dtype=dtype)
51
-
52
- def forward(self, memory_input_sequence: torch.Tensor,
53
- long_term_memory: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
54
- batch_size = memory_input_sequence.shape[0]
55
- new_ltm_state = long_term_memory
56
- queries = self.memory_queries.expand(batch_size, -1, -1)
57
- encoded_vectors = self.encoder(memory_input_sequence)
58
- compressed_memory, _ = self.memory_attention(query=queries, key=encoded_vectors, value=encoded_vectors)
59
- compressed_memory = self.memory_layernorm(compressed_memory + queries)
60
- final_memory_context = compressed_memory
61
-
62
- if self.use_long_term_memory and long_term_memory is not None:
63
- retrieved_ltm, _ = self.ltm_retrieval_attention(
64
- query=compressed_memory, key=long_term_memory, value=long_term_memory)
65
- l1_summary = compressed_memory.mean(dim=1, keepdim=True)
66
- update_gate = self.memory_update_gate(l1_summary)
67
- new_ltm_state = (update_gate * l1_summary) + ((1 - update_gate) * long_term_memory)
68
- final_memory_context = final_memory_context + retrieved_ltm
69
-
70
- reconstructed, _ = self.decoder_attention(query=encoded_vectors, key=final_memory_context, value=final_memory_context)
71
- reconstructed_vectors = self.decoder_layernorm(reconstructed + encoded_vectors)
72
- reconstructed_vectors = self.decoder_ffn(reconstructed_vectors)
73
- return compressed_memory, reconstructed_vectors, new_ltm_state
74
-
75
- # --- BUILDING BLOCK 2: ReflectiveMemoryLayer ---
76
- class ReflectiveMemoryLayer(nn.Module):
77
- def __init__(self, original_layer: nn.Linear, global_input_dim: int,
78
- memory_dim: int, num_memory_slots: int, memory_num_heads: int,
79
- global_state_storage: Dict):
80
- super().__init__()
81
- self.input_dim, self.output_dim = original_layer.in_features, original_layer.out_features
82
- self.memory_dim, self.global_state_storage = memory_dim, global_state_storage
83
- self.linear = original_layer # Keep the original linear layer frozen
84
- self.refinement_passes: int = 2
85
- device, dtype = self.linear.weight.device, self.linear.weight.dtype
86
-
87
- self.local_state_proj = nn.Linear(self.input_dim, memory_dim, device=device, dtype=dtype)
88
- self.global_state_proj = nn.Linear(global_input_dim, memory_dim, device=device, dtype=dtype)
89
- self.memory_head = VectorMemoryHead(
90
- hidden_dim=memory_dim, num_memory_slots=num_memory_slots, num_heads=memory_num_heads,
91
- ff_dim=memory_dim * 2, num_long_term_memory_slots=32, device=device, dtype=dtype)
92
- self.thought_critique_attention = nn.MultiheadAttention(
93
- embed_dim=memory_dim, num_heads=memory_num_heads, dropout=0.1, batch_first=True, device=device, dtype=dtype)
94
- self.thought_layernorm = nn.LayerNorm(memory_dim, device=device, dtype=dtype)
95
- self.correction_head = nn.Linear(memory_dim, 2 * self.output_dim, device=device, dtype=dtype)
96
-
97
- self.last_corrected_activation, self.last_additive_correction = None, None
98
- self.last_memory_input, self.last_reconstructed_from_memory = None, None
99
-
100
- def forward(self, x: torch.Tensor):
101
- base_output = self.linear(x)
102
- if 'embeds' not in self.global_state_storage:
103
- return base_output
104
-
105
- global_embeds = self.global_state_storage['embeds']
106
- if global_embeds.shape[1] != x.shape[1]:
107
- global_embeds = global_embeds[:, -x.shape[1]:, :]
108
- B, S, _ = x.shape
109
-
110
- # CRITICAL FIX: Always detach LTM state to prevent backward through previous graphs
111
- ltm_state = self.global_state_storage.get('ltm', None)
112
- if ltm_state is not None:
113
- ltm_state = ltm_state.detach()
114
-
115
- proj_local = self.local_state_proj(x)
116
- proj_global = self.global_state_proj(global_embeds)
117
- memory_input = torch.stack([proj_global, proj_local], dim=2)
118
- memory_input_flat = memory_input.view(B * S, 2, self.memory_dim)
119
-
120
- # *** FIX: Expand LTM state to match the flattened token dimension (B*S) ***
121
- ltm_state_expanded = None
122
- if ltm_state is not None:
123
- ltm_state_expanded = ltm_state.repeat_interleave(S, dim=0)
124
-
125
- compressed_mem_flat, recon_flat, new_ltm_state_expanded = self.memory_head(memory_input_flat, ltm_state_expanded)
126
-
127
- # *** FIX: Condense updated LTM state back to batch dimension B ***
128
- if new_ltm_state_expanded is not None:
129
- num_ltm_slots = new_ltm_state_expanded.shape[1]
130
- new_ltm_condensed = new_ltm_state_expanded.view(B, S, num_ltm_slots, self.memory_dim).mean(dim=1)
131
- # CRITICAL FIX: Always detach when storing in global state
132
- self.global_state_storage['ltm'] = new_ltm_condensed.detach()
133
-
134
- initial_thought = compressed_mem_flat.mean(dim=1).view(B, S, self.memory_dim)
135
- current_thought = initial_thought
136
- if not self.training and self.refinement_passes > 0:
137
- with torch.no_grad():
138
- for _ in range(self.refinement_passes):
139
- current_thought_flat = current_thought.view(B * S, 1, self.memory_dim)
140
- internal_ref, _ = self.memory_head.decoder_attention(
141
- query=current_thought_flat, key=compressed_mem_flat, value=compressed_mem_flat)
142
- external_crit, _ = self.thought_critique_attention(
143
- query=current_thought_flat, key=memory_input_flat, value=memory_input_flat)
144
- refined_thought = current_thought + internal_ref.view(B,S,-1) + external_crit.view(B,S,-1)
145
- current_thought = self.thought_layernorm(refined_thought)
146
-
147
- thought_for_correction = current_thought if not self.training else initial_thought
148
- raw_correction = self.correction_head(thought_for_correction)
149
- gate, value = torch.chunk(raw_correction, 2, dim=-1)
150
- final_activation = base_output * torch.sigmoid(gate.to(x.dtype)) + value.to(x.dtype)
151
-
152
- if self.training:
153
- # CRITICAL FIX: Detach tensors stored for debugging/analysis
154
- self.last_corrected_activation = final_activation.detach()
155
- self.last_additive_correction = value.detach()
156
- self.last_memory_input = memory_input.detach()
157
- self.last_reconstructed_from_memory = recon_flat.view(B, S, 2, self.memory_dim).detach()
158
- return final_activation
159
-
160
- # --- BUILDING BLOCK 3: The Full Custom Model with State Management ---
161
- class Phi3WithReflectiveMemoryForCausalLM(Phi3ForCausalLM):
162
- def __init__(self, config):
163
- super().__init__(config)
164
- self.global_state_storage = {}
165
- self.target_layer_path = "model.layers.15.mlp.gate_up_proj"
166
- self.memory_dim, self.num_long_term_memory_slots = 256, 32
167
-
168
- # CRITICAL FIX: Ensure embeddings are detached when stored
169
- def embedding_hook(module, input, output):
170
- self.global_state_storage['embeds'] = output.detach()
171
-
172
- self.model.embed_tokens.register_forward_hook(embedding_hook)
173
-
174
- try:
175
- original_layer = self.get_submodule(self.target_layer_path)
176
- custom_layer = ReflectiveMemoryLayer(
177
- original_layer=original_layer, global_input_dim=config.hidden_size,
178
- memory_dim=self.memory_dim, num_memory_slots=32, memory_num_heads=16,
179
- global_state_storage=self.global_state_storage)
180
- parent_path = ".".join(self.target_layer_path.split('.')[:-1])
181
- setattr(self.get_submodule(parent_path), self.target_layer_path.split('.')[-1], custom_layer)
182
- print(f"Successfully replaced '{self.target_layer_path}' with ReflectiveMemoryLayer.")
183
- except AttributeError:
184
- print(f"Could not find target layer '{self.target_layer_path}'. Model remains unmodified.")
185
-
186
- def _init_ltm_state(self, batch_size, device, dtype):
187
- # *** FIX: Initialize LTM state per item in the batch (no hardcoded hack) ***
188
- return torch.zeros(
189
- batch_size, self.num_long_term_memory_slots, self.memory_dim, device=device, dtype=dtype)
190
-
191
- def forward(self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None,
192
- position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[list[torch.FloatTensor]] = None,
193
- inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None,
194
- use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None,
195
- output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None,
196
- ltm_state: Optional[torch.Tensor] = None):
197
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
198
-
199
- # CRITICAL FIX: Don't clear global state storage completely, just reset embeds
200
- # This prevents losing LTM state continuity
201
- if 'embeds' in self.global_state_storage:
202
- del self.global_state_storage['embeds']
203
-
204
- # *** FIX: Initialize LTM state if not provided, for both training and first step of inference ***
205
- if ltm_state is None:
206
- batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
207
- ltm_state = self._init_ltm_state(batch_size, self.device, self.dtype)
208
-
209
- # CRITICAL FIX: Ensure LTM state is detached when stored
210
- self.global_state_storage['ltm'] = ltm_state.detach() if ltm_state is not None else None
211
-
212
- outputs = self.model(
213
- input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids,
214
- past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache,
215
- output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
216
-
217
- hidden_states = outputs[0]
218
- logits = self.lm_head(hidden_states).float()
219
-
220
- loss = None
221
- if labels is not None:
222
- loss_fct = nn.CrossEntropyLoss()
223
- loss = loss_fct(logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size),
224
- labels[..., 1:].contiguous().view(-1))
225
- # Note: Auxiliary losses from main.py are calculated outside the model forward pass.
226
-
227
- # CRITICAL FIX: Ensure returned LTM state is detached
228
- new_ltm_state = self.global_state_storage.get('ltm', None)
229
- if new_ltm_state is not None:
230
- new_ltm_state = new_ltm_state.detach()
231
-
232
- if not return_dict:
233
- output = (logits,) + outputs[1:] + (new_ltm_state,)
234
- return (loss,) + output if loss is not None else output
235
-
236
- return CausalLMOutputWithLTM(
237
- loss=loss, logits=logits, past_key_values=outputs.past_key_values,
238
- hidden_states=outputs.hidden_states, attentions=outputs.attentions, ltm_state=new_ltm_state)
 
 
 
 
 
 
 
 
1
+ # --- START OF FILE architectureV3.py ---
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers import Phi3Config, Phi3ForCausalLM
7
+ from transformers.modeling_outputs import CausalLMOutputWithPast
8
+ from typing import Optional, Dict, Tuple
9
+ from dataclasses import dataclass
10
+
11
+ @dataclass
12
+ class CausalLMOutputWithLTM(CausalLMOutputWithPast):
13
+ loss: Optional[torch.FloatTensor] = None
14
+ logits: torch.FloatTensor = None
15
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
16
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
17
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
18
+ ltm_state: Optional[torch.Tensor] = None # The returned LTM state
19
+
20
+ # --- BUILDING BLOCK 1: Hierarchical VectorMemoryHead (Stateless) ---
21
+ class VectorMemoryHead(nn.Module):
22
+ def __init__(self, hidden_dim: int, num_memory_slots: int, num_heads: int, ff_dim: int,
23
+ num_long_term_memory_slots: int = 0,
24
+ device=None, dtype=None):
25
+ super().__init__()
26
+ self.hidden_dim = hidden_dim
27
+ self.num_memory_slots = num_memory_slots
28
+ self.num_long_term_memory_slots = num_long_term_memory_slots
29
+ self.use_long_term_memory = self.num_long_term_memory_slots > 0
30
+
31
+ encoder_layer = nn.TransformerEncoderLayer(
32
+ d_model=hidden_dim, nhead=num_heads, dim_feedforward=ff_dim, dropout=0.1, batch_first=True,
33
+ device=device, dtype=dtype)
34
+ self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
35
+ self.memory_queries = nn.Parameter(torch.randn(1, num_memory_slots, hidden_dim, device=device, dtype=dtype))
36
+ self.memory_attention = nn.MultiheadAttention(
37
+ embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True, device=device, dtype=dtype)
38
+ self.memory_layernorm = nn.LayerNorm(hidden_dim, device=device, dtype=dtype)
39
+ self.decoder_attention = nn.MultiheadAttention(
40
+ embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True, device=device, dtype=dtype)
41
+ self.decoder_layernorm = nn.LayerNorm(hidden_dim, device=device, dtype=dtype)
42
+ self.decoder_ffn = nn.Sequential(
43
+ nn.Linear(hidden_dim, ff_dim, device=device, dtype=dtype), nn.ReLU(),
44
+ nn.Linear(ff_dim, hidden_dim, device=device, dtype=dtype))
45
+
46
+ if self.use_long_term_memory:
47
+ self.memory_update_gate = nn.Sequential(
48
+ nn.Linear(hidden_dim, hidden_dim, device=device, dtype=dtype), nn.Sigmoid())
49
+ self.ltm_retrieval_attention = nn.MultiheadAttention(
50
+ embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True, device=device, dtype=dtype)
51
+
52
+ def forward(self, memory_input_sequence: torch.Tensor,
53
+ long_term_memory: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
54
+ batch_size = memory_input_sequence.shape[0]
55
+ new_ltm_state = long_term_memory
56
+ queries = self.memory_queries.expand(batch_size, -1, -1)
57
+ encoded_vectors = self.encoder(memory_input_sequence)
58
+ compressed_memory, _ = self.memory_attention(query=queries, key=encoded_vectors, value=encoded_vectors)
59
+ compressed_memory = self.memory_layernorm(compressed_memory + queries)
60
+ final_memory_context = compressed_memory
61
+
62
+ if self.use_long_term_memory and long_term_memory is not None:
63
+ retrieved_ltm, _ = self.ltm_retrieval_attention(
64
+ query=compressed_memory, key=long_term_memory, value=long_term_memory)
65
+ l1_summary = compressed_memory.mean(dim=1, keepdim=True)
66
+ update_gate = self.memory_update_gate(l1_summary)
67
+ new_ltm_state = (update_gate * l1_summary) + ((1 - update_gate) * long_term_memory)
68
+ final_memory_context = final_memory_context + retrieved_ltm
69
+
70
+ reconstructed, _ = self.decoder_attention(query=encoded_vectors, key=final_memory_context, value=final_memory_context)
71
+ reconstructed_vectors = self.decoder_layernorm(reconstructed + encoded_vectors)
72
+ reconstructed_vectors = self.decoder_ffn(reconstructed_vectors)
73
+ return compressed_memory, reconstructed_vectors, new_ltm_state
74
+
75
+ # --- BUILDING BLOCK 2: ReflectiveMemoryLayer ---
76
+ class ReflectiveMemoryLayer(nn.Module):
77
+ def __init__(self, original_layer: nn.Linear, global_input_dim: int,
78
+ memory_dim: int, num_memory_slots: int, memory_num_heads: int,
79
+ global_state_storage: Dict):
80
+ super().__init__()
81
+ self.input_dim, self.output_dim = original_layer.in_features, original_layer.out_features
82
+ self.memory_dim, self.global_state_storage = memory_dim, global_state_storage
83
+ self.linear = original_layer # Keep the original linear layer frozen
84
+ self.refinement_passes: int = 2
85
+ device, dtype = self.linear.weight.device, self.linear.weight.dtype
86
+
87
+ self.local_state_proj = nn.Linear(self.input_dim, memory_dim, device=device, dtype=dtype)
88
+ self.global_state_proj = nn.Linear(global_input_dim, memory_dim, device=device, dtype=dtype)
89
+ self.memory_head = VectorMemoryHead(
90
+ hidden_dim=memory_dim, num_memory_slots=num_memory_slots, num_heads=memory_num_heads,
91
+ ff_dim=memory_dim * 2, num_long_term_memory_slots=32, device=device, dtype=dtype)
92
+ self.thought_critique_attention = nn.MultiheadAttention(
93
+ embed_dim=memory_dim, num_heads=memory_num_heads, dropout=0.1, batch_first=True, device=device, dtype=dtype)
94
+ self.thought_layernorm = nn.LayerNorm(memory_dim, device=device, dtype=dtype)
95
+ self.correction_head = nn.Linear(memory_dim, 2 * self.output_dim, device=device, dtype=dtype)
96
+
97
+ self.last_corrected_activation, self.last_additive_correction = None, None
98
+ self.last_memory_input, self.last_reconstructed_from_memory = None, None
99
+
100
+ def forward(self, x: torch.Tensor):
101
+ base_output = self.linear(x)
102
+ if 'embeds' not in self.global_state_storage:
103
+ return base_output
104
+
105
+ global_embeds = self.global_state_storage['embeds']
106
+ if global_embeds.shape[1] != x.shape[1]:
107
+ global_embeds = global_embeds[:, -x.shape[1]:, :]
108
+ B, S, _ = x.shape
109
+
110
+ # CRITICAL FIX: Always detach LTM state to prevent backward through previous graphs
111
+ ltm_state = self.global_state_storage.get('ltm', None)
112
+ if ltm_state is not None:
113
+ ltm_state = ltm_state.detach()
114
+
115
+ proj_local = self.local_state_proj(x)
116
+ proj_global = self.global_state_proj(global_embeds)
117
+ memory_input = torch.stack([proj_global, proj_local], dim=2)
118
+ memory_input_flat = memory_input.view(B * S, 2, self.memory_dim)
119
+
120
+ # *** FIX: Expand LTM state to match the flattened token dimension (B*S) ***
121
+ ltm_state_expanded = None
122
+ if ltm_state is not None:
123
+ ltm_state_expanded = ltm_state.repeat_interleave(S, dim=0)
124
+
125
+ compressed_mem_flat, recon_flat, new_ltm_state_expanded = self.memory_head(memory_input_flat, ltm_state_expanded)
126
+
127
+ # *** FIX: Condense updated LTM state back to batch dimension B ***
128
+ if new_ltm_state_expanded is not None:
129
+ num_ltm_slots = new_ltm_state_expanded.shape[1]
130
+ new_ltm_condensed = new_ltm_state_expanded.view(B, S, num_ltm_slots, self.memory_dim).mean(dim=1)
131
+ # CRITICAL FIX: Always detach when storing in global state
132
+ self.global_state_storage['ltm'] = new_ltm_condensed.detach()
133
+
134
+ initial_thought = compressed_mem_flat.mean(dim=1).view(B, S, self.memory_dim)
135
+ current_thought = initial_thought
136
+ if not self.training and self.refinement_passes > 0:
137
+ with torch.no_grad():
138
+ for _ in range(self.refinement_passes):
139
+ current_thought_flat = current_thought.view(B * S, 1, self.memory_dim)
140
+ internal_ref, _ = self.memory_head.decoder_attention(
141
+ query=current_thought_flat, key=compressed_mem_flat, value=compressed_mem_flat)
142
+ external_crit, _ = self.thought_critique_attention(
143
+ query=current_thought_flat, key=memory_input_flat, value=memory_input_flat)
144
+ refined_thought = current_thought + internal_ref.view(B,S,-1) + external_crit.view(B,S,-1)
145
+ current_thought = self.thought_layernorm(refined_thought)
146
+
147
+ thought_for_correction = current_thought if not self.training else initial_thought
148
+ raw_correction = self.correction_head(thought_for_correction)
149
+ gate, value = torch.chunk(raw_correction, 2, dim=-1)
150
+ final_activation = base_output * torch.sigmoid(gate.to(x.dtype)) + value.to(x.dtype)
151
+
152
+ if self.training:
153
+ # CRITICAL FIX: Detach tensors stored for debugging/analysis
154
+ self.last_corrected_activation = final_activation.detach()
155
+ self.last_additive_correction = value.detach()
156
+ self.last_memory_input = memory_input.detach()
157
+ self.last_reconstructed_from_memory = recon_flat.view(B, S, 2, self.memory_dim).detach()
158
+ return final_activation
159
+
160
+ # --- BUILDING BLOCK 3: The Full Custom Model with State Management ---
161
+ class Phi3WithReflectiveMemoryForCausalLM(Phi3ForCausalLM):
162
+ def __init__(self, config):
163
+ super().__init__(config)
164
+ self.global_state_storage = {}
165
+ self.target_layer_path = "model.layers.15.mlp.gate_up_proj"
166
+ self.memory_dim, self.num_long_term_memory_slots = 256, 32
167
+
168
+ # CRITICAL FIX: Ensure embeddings are detached when stored
169
+ def embedding_hook(module, input, output):
170
+ self.global_state_storage['embeds'] = output.detach()
171
+
172
+ self.model.embed_tokens.register_forward_hook(embedding_hook)
173
+
174
+ try:
175
+ original_layer = self.get_submodule(self.target_layer_path)
176
+ custom_layer = ReflectiveMemoryLayer(
177
+ original_layer=original_layer, global_input_dim=config.hidden_size,
178
+ memory_dim=self.memory_dim, num_memory_slots=32, memory_num_heads=16,
179
+ global_state_storage=self.global_state_storage)
180
+ parent_path = ".".join(self.target_layer_path.split('.')[:-1])
181
+ setattr(self.get_submodule(parent_path), self.target_layer_path.split('.')[-1], custom_layer)
182
+ print(f"Successfully replaced '{self.target_layer_path}' with ReflectiveMemoryLayer.")
183
+ except AttributeError:
184
+ print(f"Could not find target layer '{self.target_layer_path}'. Model remains unmodified.")
185
+
186
+ def _init_ltm_state(self, batch_size, device, dtype):
187
+ # *** FIX: Initialize LTM state per item in the batch (no hardcoded hack) ***
188
+ return torch.zeros(
189
+ batch_size, self.num_long_term_memory_slots, self.memory_dim, device=device, dtype=dtype)
190
+
191
+ def forward(self, input_ids: torch.LongTensor = None,
192
+ attention_mask: Optional[torch.Tensor] = None,
193
+ position_ids: Optional[torch.LongTensor] = None,
194
+ past_key_values: Optional[list[torch.FloatTensor]] = None,
195
+ inputs_embeds: Optional[torch.FloatTensor] = None,
196
+ labels: Optional[torch.LongTensor] = None,
197
+ use_cache: Optional[bool] = None,
198
+ output_attentions: Optional[bool] = None,
199
+ output_hidden_states: Optional[bool] = None,
200
+ return_dict: Optional[bool] = None,
201
+ cache_position: Optional[torch.LongTensor] = None,
202
+ logits_to_keep: Optional[torch.LongTensor] = None,
203
+ ltm_state: Optional[torch.Tensor] = None):
204
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
205
+
206
+ # CRITICAL FIX: Don't clear global state storage completely, just reset embeds
207
+ # This prevents losing LTM state continuity
208
+ if 'embeds' in self.global_state_storage:
209
+ del self.global_state_storage['embeds']
210
+
211
+ # *** FIX: Initialize LTM state if not provided, for both training and first step of inference ***
212
+ if ltm_state is None:
213
+ batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
214
+ ltm_state = self._init_ltm_state(batch_size, self.device, self.dtype)
215
+
216
+ # CRITICAL FIX: Ensure LTM state is detached when stored
217
+ self.global_state_storage['ltm'] = ltm_state.detach() if ltm_state is not None else None
218
+
219
+ outputs = self.model(
220
+ input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids,
221
+ past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache,
222
+ output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, logits_to_keep=logits_to_keep, return_dict=return_dict)
223
+
224
+ hidden_states = outputs[0]
225
+ logits = self.lm_head(hidden_states).float()
226
+
227
+ loss = None
228
+ if labels is not None:
229
+ loss_fct = nn.CrossEntropyLoss()
230
+ loss = loss_fct(logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size),
231
+ labels[..., 1:].contiguous().view(-1))
232
+ # Note: Auxiliary losses from main.py are calculated outside the model forward pass.
233
+
234
+ # CRITICAL FIX: Ensure returned LTM state is detached
235
+ new_ltm_state = self.global_state_storage.get('ltm', None)
236
+ if new_ltm_state is not None:
237
+ new_ltm_state = new_ltm_state.detach()
238
+
239
+ if not return_dict:
240
+ output = (logits,) + outputs[1:] + (new_ltm_state,)
241
+ return (loss,) + output if loss is not None else output
242
+
243
+ return CausalLMOutputWithLTM(
244
+ loss=loss, logits=logits, past_key_values=outputs.past_key_values,
245
+ hidden_states=outputs.hidden_states, attentions=outputs.attentions, ltm_state=new_ltm_state)