| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
| | from peft import LoraConfig, get_peft_model |
| | from config_physics import Config |
| |
|
| | |
| | |
| | |
class PhysicsController(nn.Module):
    """
    RL policy network.

    Observes the input state (hidden states from the LLM), mean-pools it
    over the sequence axis, and emits a bounded 'Modulation Vector'
    (Tanh output) that adjusts the Flux layers.
    """

    def __init__(self, input_dim, hidden_dim, action_dim):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # x: assumed [batch, seq, hidden] — pool away the sequence axis.
        pooled = x.mean(dim=1)
        return self.net(pooled)
| |
|
| | |
| | |
| | |
class FluxAdapter(nn.Module):
    """
    Injects into standard Linear layers.
    Weight = W_base + (Modulation * W_lora)
    This is effectively 'Dynamic LoRA'.

    The low-rank path (x @ lora_A @ lora_B) is gated per rank channel by a
    linear projection of the controller's modulation vector, so the
    effective adapter changes on every forward pass.
    """

    def __init__(self, original_layer, modulation_dim, scaling=4.0):
        """
        Args:
            original_layer: the (frozen) nn.Linear being wrapped; kept as
                the untouched base path.
            modulation_dim: LoRA rank / width of the gated bottleneck.
            scaling: constant multiplier applied to the LoRA output.
                Defaults to 4.0, matching the previously hard-coded value.
        """
        super().__init__()
        self.base_layer = original_layer
        self.in_features = original_layer.in_features
        self.out_features = original_layer.out_features

        # Standard LoRA init: A small random, B zero -> adapter starts as a no-op.
        self.lora_A = nn.Parameter(torch.randn(self.in_features, modulation_dim) * 0.02)
        self.lora_B = nn.Parameter(torch.zeros(modulation_dim, self.out_features))

        self.scaling = scaling

        # Maps the controller's modulation vector to a per-rank gate.
        # Zero weight + bias of 1.0 means the gate starts at exactly 1.0
        # (neutral flow) regardless of the incoming modulation vector.
        self.modulation_proj = nn.Linear(Config.MODULATION_DIM, modulation_dim)
        nn.init.zeros_(self.modulation_proj.weight)
        nn.init.constant_(self.modulation_proj.bias, 1.0)
        print(f"✅ FluxAdapter Init: Bias Norm = {self.modulation_proj.bias.norm().item()} (Flow Enabled)")

    def forward(self, x, modulation_vector=None):
        """
        Return base_layer(x) plus the modulated low-rank correction.

        If no modulation vector is passed explicitly, falls back to
        `self.active_modulation` (broadcast by PhysicsModel); with neither
        available, the base output is returned unchanged.
        """
        out_base = self.base_layer(x)

        # Align input dtype with the LoRA parameters (the base layer may
        # run in a different dtype than the adapter).
        x = x.to(self.lora_A.dtype)

        if modulation_vector is None:
            if hasattr(self, 'active_modulation'):
                modulation_vector = self.active_modulation
            else:
                # No modulation active anywhere -> pure base path.
                return out_base

        # Per-rank gate derived from the modulation vector: [batch, rank].
        layer_scale = self.modulation_proj(modulation_vector)

        # Down-projection into the rank space: [..., seq, rank].
        low_rank = x @ self.lora_A

        # Broadcast the [batch, rank] gate over the sequence dimension.
        if layer_scale.dim() == 2:
            layer_scale = layer_scale.unsqueeze(1)

        modulated_low_rank = low_rank * layer_scale

        # Up-projection back to the output space, with constant scaling.
        out_lora = (modulated_low_rank @ self.lora_B) * self.scaling

        return out_base + out_lora
| |
|
| | |
| | |
| | |
class WALTDynamics(nn.Module):
    """
    Latent dynamics head: normalizes the LLM hidden states, projects the
    final token's state into a latent space, and predicts the next latent
    with a single GRU-cell step.
    """

    def __init__(self, hidden_dim, latent_dim):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_dim)
        self.projector = nn.Linear(hidden_dim, latent_dim)
        self.predictor = nn.GRUCell(latent_dim, latent_dim)

    def forward(self, h):
        # Match this module's parameter dtype before normalizing.
        normed = self.norm(h.to(self.projector.weight.dtype))
        # Only the last token's hidden state defines the current latent.
        last_token = normed[:, -1, :]
        z = self.projector(last_token)
        # One-step rollout: the GRU cell receives z as both input and state.
        z_next = self.predictor(z, z)
        return z, z_next
| |
|
| | |
| | |
| | |
class PhysicsModel(nn.Module):
    """
    Frozen causal LLM wrapped with dynamic LoRA ('Flux') adapters.

    A small RL controller reads the LLM's (unmodulated) hidden states and
    emits a modulation vector; the vector is broadcast to every FluxAdapter,
    which reshapes the MLP projections on a second forward pass. A WALT
    head derives latent-dynamics outputs from the modulated hidden states.
    """

    def __init__(self):
        super().__init__()

        print(f"Loading {Config.MODEL_ID}...")
        self.llm = AutoModelForCausalLM.from_pretrained(
            Config.MODEL_ID,
            torch_dtype=Config.DTYPE,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_ID)

        # Freeze the base model: only adapters / controller / WALT train.
        for p in self.llm.parameters():
            p.requires_grad = False

        # Keep the output head in the working dtype so logits math matches
        # the hidden states we cast below.
        if hasattr(self.llm, 'lm_head'):
            self.llm.lm_head.to(Config.DTYPE)

        hidden_size = self.llm.config.hidden_size
        print(f"Detected Hidden Size: {hidden_size}")

        self.controller = PhysicsController(
            hidden_size,
            Config.CONTROLLER_HIDDEN,
            Config.MODULATION_DIM
        ).to(Config.DTYPE).to(self.llm.device)

        self.walt = WALTDynamics(
            hidden_size,
            Config.LATENT_DIM
        ).to(Config.DTYPE).to(self.llm.device)

        self.flux_layers = []
        self._inject_flux_layers()

    def _inject_flux_layers(self):
        """Replace every MLP projection (gate/up/down) with a FluxAdapter."""
        print("Injecting Flux Adapters into MLP layers...")

        # Collect target names first: mutating the module tree with setattr
        # while iterating named_modules() is fragile (the live traversal can
        # descend into the freshly inserted adapters).
        targets = [
            name for name, _ in self.llm.named_modules()
            if name.endswith(("gate_proj", "up_proj", "down_proj"))
        ]

        for name in targets:
            parent_name, _, child_name = name.rpartition(".")
            parent = self.llm.get_submodule(parent_name) if parent_name else self.llm

            original_layer = getattr(parent, child_name)
            flux_adapter = FluxAdapter(original_layer, Config.MODULATION_DIM).to(
                device=original_layer.weight.device, dtype=Config.DTYPE
            )

            setattr(parent, child_name, flux_adapter)
            self.flux_layers.append(flux_adapter)

        print(f"Injected {len(self.flux_layers)} Flux Adapters.")

    def set_active_modulation(self, modulation_vector):
        """
        Broadcasts the modulation vector to all Flux Layers.
        vector: [Batch, Mod_Dim]
        """
        for layer in self.flux_layers:
            layer.active_modulation = modulation_vector

    def clear_modulation(self):
        """Drop any broadcast modulation so adapters fall back to the base path."""
        for layer in self.flux_layers:
            if hasattr(layer, 'active_modulation'):
                del layer.active_modulation

    def get_embeddings(self, input_ids):
        """Run the transformer trunk (no LM head); return final hidden states."""
        out = self.llm.model(input_ids, output_hidden_states=True)
        return out.last_hidden_state

    def forward(self, input_ids, forced_modulation=None):
        """
        Two modes (NOTE: the return types intentionally differ):

        * forced_modulation given -> apply it, return logits only.
        * otherwise -> pass 1 (no_grad, unmodulated) feeds the controller;
          pass 2 runs through the modulated adapters. Returns
          (z, z_next_pred, modulation, logits).
        """
        self.clear_modulation()

        if forced_modulation is not None:
            self.set_active_modulation(forced_modulation)
            h_modulated = self.get_embeddings(input_ids)
            h_modulated = h_modulated.to(Config.DTYPE)
            logits = self.llm.lm_head(h_modulated)
            self.clear_modulation()
            return logits

        # Pass 1: observe the unmodulated state; no gradients into the trunk.
        with torch.no_grad():
            h_init = self.get_embeddings(input_ids)

        modulation = self.controller(h_init.to(Config.DTYPE))

        self.set_active_modulation(modulation)

        # Pass 2: same input, now flowing through the modulated adapters.
        h_modulated = self.get_embeddings(input_ids)
        h_modulated = h_modulated.to(Config.DTYPE)

        z, z_next_pred = self.walt(h_modulated)

        logits = self.llm.lm_head(h_modulated)

        return z, z_next_pred, modulation, logits
| |
|