import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

from config_physics import Config


# ============================================================================
# 1. THE CONTROLLER (The "Brain's Brain")
# ============================================================================
class PhysicsController(nn.Module):
    """
    RL Policy Network.
    Observes the input state (hidden state from LLM) and outputs a
    'Modulation Vector' that adjusts the Flux Layers.
    """

    def __init__(self, input_dim, hidden_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh(),  # Actions are shifts in [-1, 1] or scales
        )

    def forward(self, x):
        """Map hidden states to one global action per sample.

        Args:
            x: [Batch, Seq, Dim] hidden states from the LLM.

        Returns:
            [Batch, action_dim] action in (-1, 1).

        Mean-pools over the sequence so the controller sees a single global
        context vector (per-token modulation is a possible extension; global
        context was chosen first for stability).
        """
        x_pooled = x.mean(dim=1)
        action = self.net(x_pooled)
        return action


# ============================================================================
# 2. DYNAMIC FLUX LAYER (Modulated FFN)
# ============================================================================
class FluxAdapter(nn.Module):
    """
    Injects into standard Linear layers.
    Weight = W_base + (Modulation * W_lora)
    This is effectively 'Dynamic LoRA'.
    """

    def __init__(self, original_layer, modulation_dim):
        super().__init__()
        self.base_layer = original_layer
        self.in_features = original_layer.in_features
        self.out_features = original_layer.out_features

        # LoRA-style adapters: A is small random (std 0.02, boosted from
        # 0.01), B starts at zero so the adapter is a no-op until trained.
        self.lora_A = nn.Parameter(torch.randn(self.in_features, modulation_dim) * 0.02)
        self.lora_B = nn.Parameter(torch.zeros(modulation_dim, self.out_features))
        self.scaling = 4.0  # Standard LoRA scaling factor (previously 20)

        # Projects the controller's global modulation vector down to a
        # per-adapter, per-rank gate. Zero weight + bias of 1.0 makes the
        # gate start at identity, so gradients can flow from step one.
        self.modulation_proj = nn.Linear(Config.MODULATION_DIM, modulation_dim)
        nn.init.zeros_(self.modulation_proj.weight)
        nn.init.constant_(self.modulation_proj.bias, 1.0)  # Enable Flow!
        print(f"✅ FluxAdapter Init: Bias Norm = {self.modulation_proj.bias.norm().item()} (Flow Enabled)")

    def forward(self, x, modulation_vector=None):
        """Base linear pass plus a modulation-gated low-rank correction.

        Args:
            x: [Batch, Seq, In] input activations.
            modulation_vector: optional [Batch, Config.MODULATION_DIM] gate
                source. When None, falls back to `self.active_modulation`
                (set by PhysicsModel.set_active_modulation); with neither,
                behaves exactly like the wrapped base layer.
        """
        # 1. Frozen base pass.
        out_base = self.base_layer(x)

        # Adapter math runs in the adapter's own dtype (Float32/Config.DTYPE).
        x = x.to(self.lora_A.dtype)

        # Check broadcast instance state if the argument is missing.
        if modulation_vector is None:
            if hasattr(self, 'active_modulation'):
                modulation_vector = self.active_modulation
            else:
                return out_base

        # 2. Dynamic adapter pass.
        # modulation_proj: [Global_Dim -> Local_Dim (rank)]
        layer_scale = self.modulation_proj(modulation_vector)  # [Batch, Rank]
        low_rank = x @ self.lora_A  # [Batch, Seq, Rank]

        # Broadcasting: layer_scale is [Batch, Rank]; unsqueeze so it
        # broadcasts across the Seq dimension.
        if layer_scale.dim() == 2:
            layer_scale = layer_scale.unsqueeze(1)
        modulated_low_rank = low_rank * layer_scale

        # Apply scaling factor (key for the learning signal!).
        out_lora = (modulated_low_rank @ self.lora_B) * self.scaling
        return out_base + out_lora


# ============================================================================
# 3. WALT DYNAMICS (The World Model Head)
# ============================================================================
class WALTDynamics(nn.Module):
    """World-model head: project last-token state to a latent and predict
    the next latent with a single GRU step."""

    def __init__(self, hidden_dim, latent_dim):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_dim)  # Stabilize input from LLM
        self.projector = nn.Linear(hidden_dim, latent_dim)
        self.predictor = nn.GRUCell(latent_dim, latent_dim)  # Simple dynamics

    def forward(self, h):
        """h: [Batch, Seq, Dim] -> (z, z_next) each [Batch, latent_dim].

        Uses only the last token's hidden state as the current latent.
        """
        h = h.to(self.projector.weight.dtype)
        h = self.norm(h)
        z = self.projector(h[:, -1, :])
        z_next = self.predictor(z, z)  # Auto-regressive step
        return z, z_next


# ============================================================================
# 4. FULL RL-PHYSICS MODEL
# ============================================================================
class PhysicsModel(nn.Module):
    """Frozen base LLM + controller + injected FluxAdapters + WALT head."""

    def __init__(self):
        super().__init__()
        # 1. Base LLM (FP16/BF16, no quantization — DataParallel stability;
        # Gemma 1B is small enough (~2GB) to fit fully on a T4 unquantized).
        print(f"Loading {Config.MODEL_ID}...")
        self.llm = AutoModelForCausalLM.from_pretrained(
            Config.MODEL_ID,
            torch_dtype=Config.DTYPE,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_ID)

        # Freeze the LLM; only adapters/controller/WALT train.
        for p in self.llm.parameters():
            p.requires_grad = False

        # Ensure lm_head matches our working dtype for inference.
        if hasattr(self.llm, 'lm_head'):
            self.llm.lm_head.to(Config.DTYPE)

        # 2. Controller
        hidden_size = self.llm.config.hidden_size
        print(f"Detected Hidden Size: {hidden_size}")
        self.controller = PhysicsController(
            hidden_size, Config.CONTROLLER_HIDDEN, Config.MODULATION_DIM
        ).to(Config.DTYPE).to(self.llm.device)

        # 3. WALT Head
        self.walt = WALTDynamics(
            hidden_size, Config.LATENT_DIM
        ).to(Config.DTYPE).to(self.llm.device)

        # 4. Inject Flux Adapters
        self.flux_layers = []
        self._inject_flux_layers()

    def _inject_flux_layers(self):
        """Wrap every MLP projection (gate/up/down) in a FluxAdapter."""
        print("Injecting Flux Adapters into MLP layers...")
        # Collect target names FIRST: replacing submodules while iterating
        # named_modules() would mutate the module tree mid-traversal.
        targets = [
            name for name, _ in self.llm.named_modules()
            if name.endswith(("gate_proj", "up_proj", "down_proj"))
        ]
        for name in targets:
            parent_name, _, child_name = name.rpartition(".")
            parent = self.llm.get_submodule(parent_name)
            original_layer = getattr(parent, child_name)
            flux_adapter = FluxAdapter(original_layer, Config.MODULATION_DIM).to(
                device=original_layer.weight.device, dtype=Config.DTYPE
            )
            setattr(parent, child_name, flux_adapter)
            self.flux_layers.append(flux_adapter)
        print(f"Injected {len(self.flux_layers)} Flux Adapters.")

    def set_active_modulation(self, modulation_vector):
        """
        Broadcasts the modulation vector to all Flux Layers.
        vector: [Batch, Mod_Dim]

        Each FluxAdapter.forward reads `self.active_modulation` when no
        explicit vector is passed.
        """
        for layer in self.flux_layers:
            layer.active_modulation = modulation_vector

    def clear_modulation(self):
        """Remove broadcast state so adapters fall back to the base layer."""
        for layer in self.flux_layers:
            if hasattr(layer, 'active_modulation'):
                del layer.active_modulation

    def get_embeddings(self, input_ids):
        """Run the backbone (FluxAdapters active if modulation is set) and
        return the final hidden states.

        Uses `.model` to bypass lm_head (which might be FP16 and crash with
        an FP32 stream). Only the last hidden state is needed, so we do not
        request the full per-layer hidden-state stack.
        """
        out = self.llm.model(input_ids)
        return out.last_hidden_state

    def forward(self, input_ids, forced_modulation=None):
        """Two paths:

        PATH A (forced_modulation given — language training): apply the
        given modulation and return logits only.

        PATH B (controller — physics & dynamics training): perceive
        unmodulated (no-grad), let the controller pick a modulation, re-run
        modulated, and return (z, z_next_pred, modulation, logits).
        """
        # Always start from a clean, unmodulated state.
        self.clear_modulation()

        # --- PATH A: FORCED MODULATION (Language Training) ---
        if forced_modulation is not None:
            self.set_active_modulation(forced_modulation)
            h_modulated = self.get_embeddings(input_ids)
            h_modulated = h_modulated.to(Config.DTYPE)
            logits = self.llm.lm_head(h_modulated)
            self.clear_modulation()
            return logits

        # --- PATH B: STANDARD CONTROLLER (Physics & Dynamics Training) ---
        # 1. Initial context (unmodulated "Perception"); no grad through the
        # frozen perception pass — gradients enter at the controller below.
        with torch.no_grad():
            h_init = self.get_embeddings(input_ids)

        # 2. Controller decision (cast to Config.DTYPE for the controller).
        modulation = self.controller(h_init.to(Config.DTYPE))  # [Batch, Mod_Dim]

        # 3. Apply modulation ("Flux State").
        self.set_active_modulation(modulation)

        # 4. New context (modulated "Understanding"). Running the LLM twice
        # is expensive but necessary for the "Flux" change to take effect.
        h_modulated = self.get_embeddings(input_ids)
        h_modulated = h_modulated.to(Config.DTYPE)

        # 5. Simulate world model (latent) from the modulated state.
        z, z_next_pred = self.walt(h_modulated)

        # 6. LM logits for KL divergence (language preservation).
        # NOTE(review): modulation is left active after this path returns;
        # the next forward() clears it, but external get_embeddings calls
        # in between would see the stale state — confirm intended.
        logits = self.llm.lm_head(h_modulated)

        return z, z_next_pred, modulation, logits