# RL-Physics model: controller policy, dynamic LoRA-style flux adapters,
# WALT world-model head, and the full modulated-LLM wrapper.
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
from config_physics import Config
# ============================================================================
# 1. THE CONTROLLER (The "Brain's Brain")
# ============================================================================
class PhysicsController(nn.Module):
    """
    RL policy network (the "brain's brain").

    Observes a batch of LLM hidden states, mean-pools them over the
    sequence axis into a global context, and emits a modulation vector
    in [-1, 1] that the flux layers use to rescale their updates.
    """

    def __init__(self, input_dim, hidden_dim, action_dim):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh(),  # bound every action component to [-1, 1]
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """x: [batch, seq, dim] -> action: [batch, action_dim]."""
        # Global (sequence-pooled) context rather than per-token modulation:
        # one vector per batch element is more stable to train.
        context = x.mean(dim=1)
        return self.net(context)
# ============================================================================
# 2. DYNAMIC FLUX LAYER (Modulated FFN)
# ============================================================================
class FluxAdapter(nn.Module):
    """
    Wraps a frozen nn.Linear and adds a dynamically modulated low-rank
    (LoRA-style) update:

        out = base(x) + ((x @ A) * scale) @ B * scaling

    where `scale` is projected per-batch from a global modulation vector,
    making this effectively "dynamic LoRA".
    """

    def __init__(self, original_layer, modulation_dim):
        """
        Args:
            original_layer: the nn.Linear being wrapped (left untouched).
            modulation_dim: rank of the low-rank update, and width of the
                per-layer modulation projection output.
        """
        super().__init__()
        self.base_layer = original_layer
        self.in_features = original_layer.in_features
        self.out_features = original_layer.out_features
        # LoRA-style factors. B starts at zero so the adapter is an exact
        # no-op at init, regardless of the modulation. A's std was boosted
        # from 0.01 to 0.02 for a stronger initial learning signal.
        self.lora_A = nn.Parameter(torch.randn(self.in_features, modulation_dim) * 0.02)
        self.lora_B = nn.Parameter(torch.zeros(modulation_dim, self.out_features))
        self.scaling = 4.0  # LoRA scaling factor (reduced from an earlier 20)
        # Maps the global modulation vector to a per-rank scale for this layer.
        # Zero weight + bias of 1.0 => the scale starts at exactly 1 for every
        # rank, which keeps gradients flowing through the modulated path from
        # the very first step.
        self.modulation_proj = nn.Linear(Config.MODULATION_DIM, modulation_dim)
        nn.init.zeros_(self.modulation_proj.weight)
        nn.init.constant_(self.modulation_proj.bias, 1.0)

    def forward(self, x, modulation_vector=None):
        # 1. Base pass through the frozen layer, in its native dtype.
        out_base = self.base_layer(x)
        # Fall back to the state broadcast by PhysicsModel.set_active_modulation.
        if modulation_vector is None:
            modulation_vector = getattr(self, 'active_modulation', None)
            if modulation_vector is None:
                return out_base  # no modulation active: pure base layer
        # 2. Dynamic adapter pass. Cast only now that we know the adapter path
        # runs (adapter params may be a different dtype than the base layer).
        x = x.to(self.lora_A.dtype)
        layer_scale = self.modulation_proj(modulation_vector)  # [Batch, Rank]
        low_rank = x @ self.lora_A  # [Batch, Seq, In] @ [In, Rank] -> [Batch, Seq, Rank]
        # Broadcast the per-batch scale across the sequence dimension:
        # [Batch, Rank] -> [Batch, 1, Rank].
        if layer_scale.dim() == 2:
            layer_scale = layer_scale.unsqueeze(1)
        modulated_low_rank = low_rank * layer_scale
        # Apply the scaling factor (key for the learning signal).
        out_lora = (modulated_low_rank @ self.lora_B) * self.scaling
        return out_base + out_lora
# ============================================================================
# 3. WALT DYNAMICS (The World Model Head)
# ============================================================================
class WALTDynamics(nn.Module):
    """
    WALT world-model head.

    Normalizes LLM hidden states, projects the final token into a latent
    space, and takes one autoregressive GRU step to predict the next latent.
    """

    def __init__(self, hidden_dim, latent_dim):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_dim)  # stabilize the LLM stream
        self.projector = nn.Linear(hidden_dim, latent_dim)
        self.predictor = nn.GRUCell(latent_dim, latent_dim)  # simple dynamics

    def forward(self, h):
        """h: [batch, seq, dim] -> (z, z_next), each [batch, latent_dim]."""
        states = h.to(self.projector.weight.dtype)
        states = self.norm(states)
        last_token = states[:, -1, :]
        z = self.projector(last_token)
        # Autoregressive step: the current latent serves as both input and
        # hidden state of the GRU cell.
        z_next = self.predictor(z, z)
        return z, z_next
# ============================================================================
# 4. FULL RL-PHYSICS MODEL
# ============================================================================
class PhysicsModel(nn.Module):
    """
    Full RL-Physics model: a frozen causal LLM whose MLP projections are
    wrapped in FluxAdapters, a PhysicsController that emits a global
    modulation vector, and a WALTDynamics head for latent dynamics.
    """

    def __init__(self):
        super().__init__()
        # 1. Base LLM in Config.DTYPE. Quantization and device_map are both
        #    deliberately disabled for DataParallel/multi-GPU stability.
        print(f"Loading {Config.MODEL_ID}...")
        self.llm = AutoModelForCausalLM.from_pretrained(
            Config.MODEL_ID,
            torch_dtype=Config.DTYPE,
            # quantization_config=bnb_config,  # disabled for multi-GPU stability
            # device_map="auto"                # disabled for DataParallel
        )
        self.tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_ID)
        # Freeze the LLM: only adapters / controller / WALT head train.
        for p in self.llm.parameters():
            p.requires_grad = False
        # Keep lm_head in the same dtype as the hidden-state stream so the
        # final projection in forward() does not hit a dtype mismatch.
        if hasattr(self.llm, 'lm_head'):
            self.llm.lm_head.to(Config.DTYPE)
        # 2. Controller (produces the global modulation vector).
        hidden_size = self.llm.config.hidden_size
        print(f"Detected Hidden Size: {hidden_size}")
        self.controller = PhysicsController(
            hidden_size,
            Config.CONTROLLER_HIDDEN,
            Config.MODULATION_DIM
        ).to(Config.DTYPE).to(self.llm.device)
        # 3. WALT world-model head.
        self.walt = WALTDynamics(
            hidden_size,
            Config.LATENT_DIM
        ).to(Config.DTYPE).to(self.llm.device)
        # 4. Swap the LLM's MLP projections for FluxAdapters.
        self.flux_layers = []
        self._inject_flux_layers()

    def _inject_flux_layers(self):
        """Wrap every MLP gate/up/down projection of the LLM in a FluxAdapter."""
        print("Injecting Flux Adapters into MLP layers...")
        # NOTE(review): modules are replaced via setattr while iterating
        # named_modules(); appears to work in practice, but verify no
        # adapter is double-wrapped on the model in use.
        for name, module in self.llm.named_modules():
            # Target gate_proj, up_proj, down_proj in each MLP block.
            if name.endswith("gate_proj") or name.endswith("up_proj") or name.endswith("down_proj"):
                parent_name = ".".join(name.split(".")[:-1])
                child_name = name.split(".")[-1]
                parent = self.llm.get_submodule(parent_name)
                # Wrap the original Linear; keep the adapter on its device.
                original_layer = getattr(parent, child_name)
                flux_adapter = FluxAdapter(original_layer, Config.MODULATION_DIM).to(device=original_layer.weight.device, dtype=Config.DTYPE)
                setattr(parent, child_name, flux_adapter)
                self.flux_layers.append(flux_adapter)
        print(f"Injected {len(self.flux_layers)} Flux Adapters.")

    def set_active_modulation(self, modulation_vector):
        """
        Broadcast the modulation vector to all flux layers.

        modulation_vector: [Batch, Mod_Dim]. Stored as `active_modulation`
        on each adapter instance; FluxAdapter.forward reads it when no
        explicit vector is passed.
        """
        for layer in self.flux_layers:
            layer.active_modulation = modulation_vector

    def clear_modulation(self):
        """Remove the shared state so adapters fall back to the base layer only."""
        for layer in self.flux_layers:
            if hasattr(layer, 'active_modulation'):
                del layer.active_modulation

    def get_embeddings(self, input_ids):
        # Run the backbone only (`.model`) to bypass lm_head, which may be a
        # different dtype than the hidden-state stream. The FluxAdapters kick
        # in if set_active_modulation() was called beforehand.
        out = self.llm.model(input_ids, output_hidden_states=True)
        return out.last_hidden_state  # equivalently out.hidden_states[-1]

    def forward(self, input_ids, forced_modulation=None):
        """
        Two paths:
          A) forced_modulation given (language training): apply it for one
             modulated pass and return logits only.
          B) controller path (physics & dynamics training): unmodulated
             "perception" pass, controller picks a modulation, second
             modulated pass; returns (z, z_next_pred, modulation, logits).
        """
        # 1. Always start from a clean (unmodulated) state.
        self.clear_modulation()
        # --- PATH A: forced modulation (language training) ---
        if forced_modulation is not None:
            self.set_active_modulation(forced_modulation)
            h_modulated = self.get_embeddings(input_ids)
            h_modulated = h_modulated.to(Config.DTYPE)
            logits = self.llm.lm_head(h_modulated)
            self.clear_modulation()
            return logits
        # --- PATH B: standard controller (physics & dynamics training) ---
        # Unmodulated "perception" pass; no grad through the frozen LLM.
        with torch.no_grad():
            h_init = self.get_embeddings(input_ids)
        # 2. Controller decision from the pooled context. Gradients still
        #    reach the controller's own parameters even though h_init is
        #    produced under no_grad (it is just a constant input).
        modulation = self.controller(h_init.to(Config.DTYPE))  # [Batch, Mod_Dim]
        # 3. Apply the modulation ("flux state").
        self.set_active_modulation(modulation)
        # 4. Second, modulated pass ("understanding"). Expensive but required
        #    for the flux change to affect the hidden states.
        h_modulated = self.get_embeddings(input_ids)
        h_modulated = h_modulated.to(Config.DTYPE)
        # 5. Latent world-model step from the modulated state.
        z, z_next_pred = self.walt(h_modulated)
        # 6. LM logits for the KL-divergence term (language preservation).
        logits = self.llm.lm_head(h_modulated)
        return z, z_next_pred, modulation, logits