File size: 10,660 Bytes
9469618
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
from config_physics import Config

# ============================================================================
# 1. THE CONTROLLER (The "Brain's Brain")
# ============================================================================
class PhysicsController(nn.Module):
    """
    RL policy network (the "brain's brain").

    Mean-pools the LLM hidden states over the sequence axis and maps the
    pooled vector to a modulation vector used to adjust the Flux layers.
    Outputs pass through Tanh, so every component lies in [-1, 1].
    """

    def __init__(self, input_dim, hidden_dim, action_dim):
        super().__init__()
        # Tanh bounds each action component in [-1, 1] (shift/scale actions).
        stages = [
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """x: [batch, seq, dim] -> action: [batch, action_dim].

        Global (per-sequence) context is used rather than per-token
        modulation, for stability.
        """
        pooled = x.mean(dim=1)
        return self.net(pooled)

# ============================================================================
# 2. DYNAMIC FLUX LAYER (Modulated FFN)
# ============================================================================
class FluxAdapter(nn.Module):
    """
    Wraps a frozen nn.Linear and adds a dynamically modulated low-rank
    delta ("Dynamic LoRA"):

        out = base(x) + ((x @ A) * gate(modulation)) @ B * scaling

    where gate(.) is a learned per-rank scale produced from the
    controller's global modulation vector.
    """

    def __init__(self, original_layer, modulation_dim):
        super().__init__()
        self.base_layer = original_layer
        self.in_features = original_layer.in_features
        self.out_features = original_layer.out_features

        # LoRA factors. A starts small-random (0.02, boosted from 0.01);
        # B starts at zero so the adapter's initial delta is exactly zero.
        self.lora_A = nn.Parameter(torch.randn(self.in_features, modulation_dim) * 0.02)
        self.lora_B = nn.Parameter(torch.zeros(modulation_dim, self.out_features))

        # Constant LoRA scaling factor (previously 20; 4.0 keeps a usable
        # learning signal without blowing up activations).
        self.scaling = 4.0

        # Projects the controller's global modulation vector to a per-rank
        # gate. Zero weight + unit bias means the gate starts at 1.0, so
        # gradients flow through the adapter from the first step.
        self.modulation_proj = nn.Linear(Config.MODULATION_DIM, modulation_dim)
        nn.init.zeros_(self.modulation_proj.weight)
        nn.init.constant_(self.modulation_proj.bias, 1.0) # Enable Flow!
        print(f"✅ FluxAdapter Init: Bias Norm = {self.modulation_proj.bias.norm().item()} (Flow Enabled)")

    def forward(self, x, modulation_vector=None):
        """
        Args:
            x: [batch, seq, in_features] activations.
            modulation_vector: optional [batch, Config.MODULATION_DIM]
                controller output. Falls back to `self.active_modulation`
                (set by PhysicsModel.set_active_modulation). With neither
                present, the wrapper is transparent: base layer only.

        Returns:
            [batch, seq, out_features] = base output + modulated LoRA delta.
        """
        # 1. Base pass (frozen weights, original dtype).
        out_base = self.base_layer(x)

        # Resolve modulation source; None (missing or cleared) => no-op.
        if modulation_vector is None:
            modulation_vector = getattr(self, 'active_modulation', None)
            if modulation_vector is None:
                return out_base

        # Cast only on the modulated path: adapter params may be a
        # different dtype than the frozen base weights.
        x = x.to(self.lora_A.dtype)

        # 2. Dynamic adapter pass.
        # Per-rank gate from the global modulation vector: [batch, rank].
        layer_scale = self.modulation_proj(modulation_vector)

        # Down-projection: [batch, seq, in] @ [in, rank] -> [batch, seq, rank].
        low_rank = x @ self.lora_A

        # Broadcast the gate across the sequence axis.
        if layer_scale.dim() == 2:
            layer_scale = layer_scale.unsqueeze(1)

        modulated_low_rank = low_rank * layer_scale

        # Up-projection plus the constant scaling factor (key for the
        # learning signal).
        out_lora = (modulated_low_rank @ self.lora_B) * self.scaling

        return out_base + out_lora

# ============================================================================
# 3. WALT DYNAMICS (The World Model Head)
# ============================================================================
class WALTDynamics(nn.Module):
    """
    World-model head: projects the last-token LLM hidden state into a
    latent space and predicts the next latent with one GRU-cell step.
    """

    def __init__(self, hidden_dim, latent_dim):
        super().__init__()
        # LayerNorm stabilizes raw hidden states coming out of the LLM.
        self.norm = nn.LayerNorm(hidden_dim)
        self.projector = nn.Linear(hidden_dim, latent_dim)
        # A single GRU cell is enough for a simple dynamics step.
        self.predictor = nn.GRUCell(latent_dim, latent_dim)

    def forward(self, h):
        """h: [batch, seq, dim] -> (z, z_next), each [batch, latent_dim]."""
        # Match the head's parameter dtype before normalizing.
        normed = self.norm(h.to(self.projector.weight.dtype))
        # Only the final token's state seeds the latent.
        z = self.projector(normed[:, -1, :])
        # Auto-regressive step: z serves as both input and hidden state.
        z_next = self.predictor(z, z)
        return z, z_next

# ============================================================================
# 4. FULL RL-PHYSICS MODEL
# ============================================================================
class PhysicsModel(nn.Module):
    """
    Full RL-Physics stack.

    Composition:
      * a frozen causal LLM (Config.MODEL_ID) whose MLP projections
        (gate_proj / up_proj / down_proj) are wrapped in FluxAdapters,
      * a PhysicsController that turns pooled hidden states into a global
        modulation vector,
      * a WALTDynamics head predicting latent world-model transitions.

    forward() has two paths: forced modulation (language training,
    returns logits only) and controller-driven (physics/dynamics
    training, returns (z, z_next_pred, modulation, logits)).
    """

    def __init__(self):
        super().__init__()

        # 1. Base LLM (Load in FP16/BF16 - No Quantization for DataParallel stability)
        # Gemma 1B is small enough (2GB) to fit fully on T4 without quantization.
        print(f"Loading {Config.MODEL_ID}...")
        self.llm = AutoModelForCausalLM.from_pretrained(
            Config.MODEL_ID,
            torch_dtype=Config.DTYPE,
            # quantization_config=bnb_config, # Disabled for Multi-GPU Stability
            # device_map="auto" # Disabled for DataParallel
        )
        self.tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_ID)

        # Freeze LLM: only adapters, controller and WALT head train.
        for p in self.llm.parameters():
            p.requires_grad = False

        # Ensure lm_head matches our FP32 stream for inference
        # NOTE(review): assumes Config.DTYPE is the dtype used by the
        # modulated hidden-state stream — confirm against config_physics.
        if hasattr(self.llm, 'lm_head'):
             self.llm.lm_head.to(Config.DTYPE)

        # 2. Controller (maps pooled hidden states -> modulation vector).
        hidden_size = self.llm.config.hidden_size
        print(f"Detected Hidden Size: {hidden_size}")

        self.controller = PhysicsController(
            hidden_size, 
            Config.CONTROLLER_HIDDEN, 
            Config.MODULATION_DIM
        ).to(Config.DTYPE).to(self.llm.device)

        # 3. WALT Head (world-model latent dynamics).
        self.walt = WALTDynamics(
            hidden_size, 
            Config.LATENT_DIM
        ).to(Config.DTYPE).to(self.llm.device)

        # 4. Inject Flux Adapters (kept in a flat list so modulation can
        # be broadcast to all of them at once).
        self.flux_layers = []
        self._inject_flux_layers()

    def _inject_flux_layers(self):
        """Wrap every MLP projection (gate/up/down) in a FluxAdapter, in place."""
        print("Injecting Flux Adapters into MLP layers...")
        # Recursive replacement: walk all submodules and swap the matching
        # Linear layers for FluxAdapter wrappers via setattr on the parent.
        for name, module in self.llm.named_modules():
            # Target gate_proj, up_proj, down_proj in MLP
            if name.endswith("gate_proj") or name.endswith("up_proj") or name.endswith("down_proj"):
                parent_name = ".".join(name.split(".")[:-1])
                child_name = name.split(".")[-1]
                parent = self.llm.get_submodule(parent_name)

                # Wrap, keeping device placement of the original layer.
                original_layer = getattr(parent, child_name)
                flux_adapter = FluxAdapter(original_layer, Config.MODULATION_DIM).to(device=original_layer.weight.device, dtype=Config.DTYPE)

                setattr(parent, child_name, flux_adapter)
                self.flux_layers.append(flux_adapter)

        print(f"Injected {len(self.flux_layers)} Flux Adapters.")

    def set_active_modulation(self, modulation_vector):
        """
        Broadcasts the modulation vector to all Flux Layers.
        vector: [Batch, Mod_Dim]

        FluxAdapter.forward picks up `active_modulation` when it is not
        passed an explicit modulation_vector argument.
        """
        # Shared state via an instance attribute on each adapter — the
        # adapters cannot receive extra forward() args through the LLM.
        for layer in self.flux_layers:
            layer.active_modulation = modulation_vector

    def clear_modulation(self):
        """Remove modulation state so adapters fall back to base-only passes."""
        for layer in self.flux_layers:
            if hasattr(layer, 'active_modulation'):
                del layer.active_modulation

    def get_embeddings(self, input_ids):
        """Run the LLM backbone and return last_hidden_state.

        Uses `self.llm.model` (not the full CausalLM) to bypass lm_head,
        which might be FP16 and crash with the FP32 stream. FluxAdapters
        kick in if set_active_modulation was called beforehand.
        """
        out = self.llm.model(input_ids, output_hidden_states=True)
        return out.last_hidden_state # or out.hidden_states[-1]

    def forward(self, input_ids, forced_modulation=None):
        """Two-path forward.

        Path A (forced_modulation given): apply the given modulation and
        return LM logits — used for language training.
        Path B (default): perceive unmodulated, let the controller pick a
        modulation, re-run modulated, and return
        (z, z_next_pred, modulation, logits).
        """
        # 1. Get Initial Context (Unmodulated "Perception")
        self.clear_modulation()

        # --- PATH A: FORCED MODULATION (Language Training) ---
        if forced_modulation is not None:
             self.set_active_modulation(forced_modulation)
             h_modulated = self.get_embeddings(input_ids)
             h_modulated = h_modulated.to(Config.DTYPE)
             logits = self.llm.lm_head(h_modulated)
             self.clear_modulation()
             return logits

        # --- PATH B: STANDARD CONTROLLER (Physics & Dynamics Training) ---
        # First pass is detached: the controller sees perception only;
        # its gradient arrives through the second (modulated) pass.
        with torch.no_grad():
            h_init = self.get_embeddings(input_ids)

        # 2. Controller decision on the pooled context.
        # h_init: [Batch, Seq, Dim]; cast to the controller's dtype.
        modulation = self.controller(h_init.to(Config.DTYPE)) # [Batch, Mod_Dim]

        # 3. Apply Modulation ("Flux State")
        self.set_active_modulation(modulation)

        # 4. Get New Context (Modulated "Understanding")
        # We run the LLM again. This is expensive but necessary for "Flux" change.
        h_modulated = self.get_embeddings(input_ids)
        h_modulated = h_modulated.to(Config.DTYPE)

        # 5. Simulate World Model (Latent) from Modulated State
        z, z_next_pred = self.walt(h_modulated)

        # 6. Get LM Logits for KL Divergence (Language Preservation)
        # We need to project h_modulated back to vocabulary
        logits = self.llm.lm_head(h_modulated)

        return z, z_next_pred, modulation, logits