Upload architectureV3.py
#3
by
win10
- opened
- architectureV3.py +33 -13
architectureV3.py
CHANGED
|
@@ -99,13 +99,19 @@ class ReflectiveMemoryLayer(nn.Module):
|
|
| 99 |
|
| 100 |
def forward(self, x: torch.Tensor):
|
| 101 |
base_output = self.linear(x)
|
| 102 |
-
if 'embeds' not in self.global_state_storage:
|
|
|
|
| 103 |
|
| 104 |
global_embeds = self.global_state_storage['embeds']
|
| 105 |
-
if global_embeds.shape[1] != x.shape[1]:
|
|
|
|
| 106 |
B, S, _ = x.shape
|
| 107 |
|
|
|
|
| 108 |
ltm_state = self.global_state_storage.get('ltm', None)
|
|
|
|
|
|
|
|
|
|
| 109 |
proj_local = self.local_state_proj(x)
|
| 110 |
proj_global = self.global_state_proj(global_embeds)
|
| 111 |
memory_input = torch.stack([proj_global, proj_local], dim=2)
|
|
@@ -122,6 +128,7 @@ class ReflectiveMemoryLayer(nn.Module):
|
|
| 122 |
if new_ltm_state_expanded is not None:
|
| 123 |
num_ltm_slots = new_ltm_state_expanded.shape[1]
|
| 124 |
new_ltm_condensed = new_ltm_state_expanded.view(B, S, num_ltm_slots, self.memory_dim).mean(dim=1)
|
|
|
|
| 125 |
self.global_state_storage['ltm'] = new_ltm_condensed.detach()
|
| 126 |
|
| 127 |
initial_thought = compressed_mem_flat.mean(dim=1).view(B, S, self.memory_dim)
|
|
@@ -143,10 +150,11 @@ class ReflectiveMemoryLayer(nn.Module):
|
|
| 143 |
final_activation = base_output * torch.sigmoid(gate.to(x.dtype)) + value.to(x.dtype)
|
| 144 |
|
| 145 |
if self.training:
|
| 146 |
-
|
| 147 |
-
self.
|
| 148 |
-
self.
|
| 149 |
-
self.
|
|
|
|
| 150 |
return final_activation
|
| 151 |
|
| 152 |
# --- BUILDING BLOCK 3: The Full Custom Model with State Management ---
|
|
@@ -155,16 +163,19 @@ class Phi3WithReflectiveMemoryForCausalLM(Phi3ForCausalLM):
|
|
| 155 |
super().__init__(config)
|
| 156 |
self.global_state_storage = {}
|
| 157 |
self.target_layer_path = "model.layers.15.mlp.gate_up_proj"
|
| 158 |
-
self.memory_dim, self.num_long_term_memory_slots =
|
| 159 |
|
| 160 |
-
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
try:
|
| 164 |
original_layer = self.get_submodule(self.target_layer_path)
|
| 165 |
custom_layer = ReflectiveMemoryLayer(
|
| 166 |
original_layer=original_layer, global_input_dim=config.hidden_size,
|
| 167 |
-
memory_dim=self.memory_dim, num_memory_slots=
|
| 168 |
global_state_storage=self.global_state_storage)
|
| 169 |
parent_path = ".".join(self.target_layer_path.split('.')[:-1])
|
| 170 |
setattr(self.get_submodule(parent_path), self.target_layer_path.split('.')[-1], custom_layer)
|
|
@@ -184,13 +195,19 @@ class Phi3WithReflectiveMemoryForCausalLM(Phi3ForCausalLM):
|
|
| 184 |
output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None,
|
| 185 |
ltm_state: Optional[torch.Tensor] = None):
|
| 186 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
|
| 189 |
# *** FIX: Initialize LTM state if not provided, for both training and first step of inference ***
|
| 190 |
if ltm_state is None:
|
| 191 |
batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
|
| 192 |
ltm_state = self._init_ltm_state(batch_size, self.device, self.dtype)
|
| 193 |
-
|
|
|
|
|
|
|
| 194 |
|
| 195 |
outputs = self.model(
|
| 196 |
input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids,
|
|
@@ -207,7 +224,10 @@ class Phi3WithReflectiveMemoryForCausalLM(Phi3ForCausalLM):
|
|
| 207 |
labels[..., 1:].contiguous().view(-1))
|
| 208 |
# Note: Auxiliary losses from main.py are calculated outside the model forward pass.
|
| 209 |
|
|
|
|
| 210 |
new_ltm_state = self.global_state_storage.get('ltm', None)
|
|
|
|
|
|
|
| 211 |
|
| 212 |
if not return_dict:
|
| 213 |
output = (logits,) + outputs[1:] + (new_ltm_state,)
|
|
@@ -215,4 +235,4 @@ class Phi3WithReflectiveMemoryForCausalLM(Phi3ForCausalLM):
|
|
| 215 |
|
| 216 |
return CausalLMOutputWithLTM(
|
| 217 |
loss=loss, logits=logits, past_key_values=outputs.past_key_values,
|
| 218 |
-
hidden_states=outputs.hidden_states, attentions=outputs.attentions, ltm_state=new_ltm_state)
|
|
|
|
| 99 |
|
| 100 |
def forward(self, x: torch.Tensor):
|
| 101 |
base_output = self.linear(x)
|
| 102 |
+
if 'embeds' not in self.global_state_storage:
|
| 103 |
+
return base_output
|
| 104 |
|
| 105 |
global_embeds = self.global_state_storage['embeds']
|
| 106 |
+
if global_embeds.shape[1] != x.shape[1]:
|
| 107 |
+
global_embeds = global_embeds[:, -x.shape[1]:, :]
|
| 108 |
B, S, _ = x.shape
|
| 109 |
|
| 110 |
+
# CRITICAL FIX: Always detach LTM state to prevent backward through previous graphs
|
| 111 |
ltm_state = self.global_state_storage.get('ltm', None)
|
| 112 |
+
if ltm_state is not None:
|
| 113 |
+
ltm_state = ltm_state.detach()
|
| 114 |
+
|
| 115 |
proj_local = self.local_state_proj(x)
|
| 116 |
proj_global = self.global_state_proj(global_embeds)
|
| 117 |
memory_input = torch.stack([proj_global, proj_local], dim=2)
|
|
|
|
| 128 |
if new_ltm_state_expanded is not None:
|
| 129 |
num_ltm_slots = new_ltm_state_expanded.shape[1]
|
| 130 |
new_ltm_condensed = new_ltm_state_expanded.view(B, S, num_ltm_slots, self.memory_dim).mean(dim=1)
|
| 131 |
+
# CRITICAL FIX: Always detach when storing in global state
|
| 132 |
self.global_state_storage['ltm'] = new_ltm_condensed.detach()
|
| 133 |
|
| 134 |
initial_thought = compressed_mem_flat.mean(dim=1).view(B, S, self.memory_dim)
|
|
|
|
| 150 |
final_activation = base_output * torch.sigmoid(gate.to(x.dtype)) + value.to(x.dtype)
|
| 151 |
|
| 152 |
if self.training:
|
| 153 |
+
# CRITICAL FIX: Detach tensors stored for debugging/analysis
|
| 154 |
+
self.last_corrected_activation = final_activation.detach()
|
| 155 |
+
self.last_additive_correction = value.detach()
|
| 156 |
+
self.last_memory_input = memory_input.detach()
|
| 157 |
+
self.last_reconstructed_from_memory = recon_flat.view(B, S, 2, self.memory_dim).detach()
|
| 158 |
return final_activation
|
| 159 |
|
| 160 |
# --- BUILDING BLOCK 3: The Full Custom Model with State Management ---
|
|
|
|
| 163 |
super().__init__(config)
|
| 164 |
self.global_state_storage = {}
|
| 165 |
self.target_layer_path = "model.layers.15.mlp.gate_up_proj"
|
| 166 |
+
self.memory_dim, self.num_long_term_memory_slots = 128, 32
|
| 167 |
|
| 168 |
+
# CRITICAL FIX: Ensure embeddings are detached when stored
|
| 169 |
+
def embedding_hook(module, input, output):
|
| 170 |
+
self.global_state_storage['embeds'] = output.detach()
|
| 171 |
+
|
| 172 |
+
self.model.embed_tokens.register_forward_hook(embedding_hook)
|
| 173 |
|
| 174 |
try:
|
| 175 |
original_layer = self.get_submodule(self.target_layer_path)
|
| 176 |
custom_layer = ReflectiveMemoryLayer(
|
| 177 |
original_layer=original_layer, global_input_dim=config.hidden_size,
|
| 178 |
+
memory_dim=self.memory_dim, num_memory_slots=16, memory_num_heads=4,
|
| 179 |
global_state_storage=self.global_state_storage)
|
| 180 |
parent_path = ".".join(self.target_layer_path.split('.')[:-1])
|
| 181 |
setattr(self.get_submodule(parent_path), self.target_layer_path.split('.')[-1], custom_layer)
|
|
|
|
| 195 |
output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None,
|
| 196 |
ltm_state: Optional[torch.Tensor] = None):
|
| 197 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 198 |
+
|
| 199 |
+
# CRITICAL FIX: Don't clear global state storage completely, just reset embeds
|
| 200 |
+
# This prevents losing LTM state continuity
|
| 201 |
+
if 'embeds' in self.global_state_storage:
|
| 202 |
+
del self.global_state_storage['embeds']
|
| 203 |
|
| 204 |
# *** FIX: Initialize LTM state if not provided, for both training and first step of inference ***
|
| 205 |
if ltm_state is None:
|
| 206 |
batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
|
| 207 |
ltm_state = self._init_ltm_state(batch_size, self.device, self.dtype)
|
| 208 |
+
|
| 209 |
+
# CRITICAL FIX: Ensure LTM state is detached when stored
|
| 210 |
+
self.global_state_storage['ltm'] = ltm_state.detach() if ltm_state is not None else None
|
| 211 |
|
| 212 |
outputs = self.model(
|
| 213 |
input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids,
|
|
|
|
| 224 |
labels[..., 1:].contiguous().view(-1))
|
| 225 |
# Note: Auxiliary losses from main.py are calculated outside the model forward pass.
|
| 226 |
|
| 227 |
+
# CRITICAL FIX: Ensure returned LTM state is detached
|
| 228 |
new_ltm_state = self.global_state_storage.get('ltm', None)
|
| 229 |
+
if new_ltm_state is not None:
|
| 230 |
+
new_ltm_state = new_ltm_state.detach()
|
| 231 |
|
| 232 |
if not return_dict:
|
| 233 |
output = (logits,) + outputs[1:] + (new_ltm_state,)
|
|
|
|
| 235 |
|
| 236 |
return CausalLMOutputWithLTM(
|
| 237 |
loss=loss, logits=logits, past_key_values=outputs.past_key_values,
|
| 238 |
+
hidden_states=outputs.hidden_states, attentions=outputs.attentions, ltm_state=new_ltm_state)
|