# flux-test-time-training / modeling_physics_rl.py
# Uploaded with huggingface_hub by convaiinnovations (revision 9469618, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
from config_physics import Config
# ============================================================================
# 1. THE CONTROLLER (The "Brain's Brain")
# ============================================================================
class PhysicsController(nn.Module):
    """RL policy network ("the brain's brain").

    Mean-pools the LLM hidden states over the sequence dimension and maps
    the resulting global context vector to a modulation ("action") vector.
    The final Tanh bounds every action component to [-1, 1].
    """

    def __init__(self, input_dim, hidden_dim, action_dim):
        super().__init__()
        policy_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh(),  # actions are shifts/scales bounded to [-1, 1]
        ]
        self.net = nn.Sequential(*policy_layers)

    def forward(self, x):
        """x: [batch, seq, dim] -> action: [batch, action_dim].

        Global (pooled) context is used for stability; per-token modulation
        is a possible future extension.
        """
        pooled_context = x.mean(dim=1)
        return self.net(pooled_context)
# ============================================================================
# 2. DYNAMIC FLUX LAYER (Modulated FFN)
# ============================================================================
class FluxAdapter(nn.Module):
    """Wraps a frozen nn.Linear with a dynamically modulated low-rank branch.

    Effective transform:  y = W_base(x) + scaling * ((x @ A) * m) @ B
    where `m` is a per-layer scale derived from a global modulation vector —
    effectively a "Dynamic LoRA".
    """

    def __init__(self, original_layer, modulation_dim):
        """
        Args:
            original_layer: the frozen nn.Linear being wrapped (used as-is).
            modulation_dim: LoRA rank / width of the per-layer modulation.
        """
        super().__init__()
        self.base_layer = original_layer
        self.in_features = original_layer.in_features
        self.out_features = original_layer.out_features
        # LoRA-style factors: A small-random (std boosted from 0.01 to 0.02),
        # B zero-initialized so the adapter starts as an exact no-op.
        self.lora_A = nn.Parameter(torch.randn(self.in_features, modulation_dim) * 0.02)
        self.lora_B = nn.Parameter(torch.zeros(modulation_dim, self.out_features))
        self.scaling = 4.0  # LoRA scaling factor (previously 20)
        # Projects the controller's global modulation vector down to this
        # layer's rank. Zero weights + unit bias give a scale of exactly 1.0
        # at init, so gradients flow through the LoRA branch from step one.
        self.modulation_proj = nn.Linear(Config.MODULATION_DIM, modulation_dim)
        nn.init.zeros_(self.modulation_proj.weight)
        nn.init.constant_(self.modulation_proj.bias, 1.0)  # Enable Flow!
        print(f"✅ FluxAdapter Init: Bias Norm = {self.modulation_proj.bias.norm().item()} (Flow Enabled)")

    def forward(self, x, modulation_vector=None):
        """Base linear pass plus the modulated low-rank correction.

        Args:
            x: [batch, seq, in_features] activations.
            modulation_vector: [batch, Config.MODULATION_DIM]. If None, falls
                back to `self.active_modulation` (broadcast by the owner via
                set_active_modulation); if that is absent too, this behaves
                exactly like the wrapped base layer.
        """
        out_base = self.base_layer(x)
        # Run adapter math in the adapter's own dtype (e.g. float32).
        x = x.to(self.lora_A.dtype)
        if modulation_vector is None:
            modulation_vector = getattr(self, 'active_modulation', None)
            if modulation_vector is None:
                return out_base
        # Global modulation -> per-layer, per-rank scale: [batch, rank].
        layer_scale = self.modulation_proj(modulation_vector)
        low_rank = x @ self.lora_A  # [batch, seq, rank]
        # Broadcast [batch, rank] -> [batch, 1, rank] across the seq dim.
        if layer_scale.dim() == 2:
            layer_scale = layer_scale.unsqueeze(1)
        modulated_low_rank = low_rank * layer_scale
        # Scaling factor amplifies the learning signal of the adapter branch.
        out_lora = (modulated_low_rank @ self.lora_B) * self.scaling
        return out_base + out_lora
# ============================================================================
# 3. WALT DYNAMICS (The World Model Head)
# ============================================================================
class WALTDynamics(nn.Module):
    """World-model head: projects the final LLM hidden state into a latent
    space and predicts the next latent with a single GRU step."""

    def __init__(self, hidden_dim, latent_dim):
        super().__init__()
        # Normalize the raw LLM activations before projecting them.
        self.norm = nn.LayerNorm(hidden_dim)
        self.projector = nn.Linear(hidden_dim, latent_dim)
        # A lone GRU cell serves as a minimal latent dynamics model.
        self.predictor = nn.GRUCell(latent_dim, latent_dim)

    def forward(self, h):
        """h: [batch, seq, hidden] -> (z, z_next), each [batch, latent]."""
        h = h.to(self.projector.weight.dtype)
        stabilized = self.norm(h)
        # Latent state is derived from the last token's hidden state only.
        z = self.projector(stabilized[:, -1, :])
        # Auto-regressive one-step rollout: z is both input and hidden state.
        z_next = self.predictor(z, z)
        return z, z_next
# ============================================================================
# 4. FULL RL-PHYSICS MODEL
# ============================================================================
class PhysicsModel(nn.Module):
    """Full RL-Physics model.

    Composition (built in __init__):
      * a frozen causal LLM (its weights receive no gradients),
      * a PhysicsController mapping LLM hidden states to a global
        modulation vector,
      * a WALTDynamics head for latent world-model prediction,
      * FluxAdapters injected in place of every MLP gate/up/down projection,
        each rescaling its low-rank branch by the controller's vector.
    """

    def __init__(self):
        super().__init__()
        # 1. Base LLM (load in FP16/BF16 - no quantization for DataParallel
        # stability). Gemma 1B is small enough (2GB) to fit fully on a T4
        # without quantization.
        print(f"Loading {Config.MODEL_ID}...")
        self.llm = AutoModelForCausalLM.from_pretrained(
            Config.MODEL_ID,
            torch_dtype=Config.DTYPE,
            # quantization_config=bnb_config, # Disabled for Multi-GPU Stability
            # device_map="auto" # Disabled for DataParallel
        )
        self.tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_ID)
        # Freeze the LLM: only adapters / controller / WALT head are trained.
        for p in self.llm.parameters():
            p.requires_grad = False
        # Keep lm_head in the working dtype so logits projection matches the
        # hidden-state stream we hand it in forward().
        if hasattr(self.llm, 'lm_head'):
            self.llm.lm_head.to(Config.DTYPE)
        # 2. Controller (input width follows the detected LLM hidden size).
        hidden_size = self.llm.config.hidden_size
        print(f"Detected Hidden Size: {hidden_size}")
        self.controller = PhysicsController(
            hidden_size,
            Config.CONTROLLER_HIDDEN,
            Config.MODULATION_DIM
        ).to(Config.DTYPE).to(self.llm.device)
        # 3. WALT world-model head.
        self.walt = WALTDynamics(
            hidden_size,
            Config.LATENT_DIM
        ).to(Config.DTYPE).to(self.llm.device)
        # 4. Inject Flux Adapters (replaces MLP projections inside self.llm).
        self.flux_layers = []
        self._inject_flux_layers()

    def _inject_flux_layers(self):
        """Replace every MLP gate/up/down projection with a FluxAdapter.

        Adapters are installed via setattr on the parent module, so they
        become submodules of self.llm and their parameters are reachable
        through self.parameters(). self.flux_layers keeps direct handles
        for broadcasting modulation state.
        """
        print("Injecting Flux Adapters into MLP layers...")
        # Walk all submodules and swap the matching projections in place.
        for name, module in self.llm.named_modules():
            # Target gate_proj, up_proj, down_proj in each MLP block.
            if name.endswith("gate_proj") or name.endswith("up_proj") or name.endswith("down_proj"):
                parent_name = ".".join(name.split(".")[:-1])
                child_name = name.split(".")[-1]
                parent = self.llm.get_submodule(parent_name)
                # Wrap the original (frozen) linear layer; keep the adapter on
                # the same device as the layer it replaces.
                original_layer = getattr(parent, child_name)
                flux_adapter = FluxAdapter(original_layer, Config.MODULATION_DIM).to(device=original_layer.weight.device, dtype=Config.DTYPE)
                setattr(parent, child_name, flux_adapter)
                self.flux_layers.append(flux_adapter)
        print(f"Injected {len(self.flux_layers)} Flux Adapters.")

    def set_active_modulation(self, modulation_vector):
        """Broadcast the modulation vector to all Flux layers.

        Args:
            modulation_vector: [batch, Config.MODULATION_DIM] tensor.

        State is shared by stamping an `active_modulation` attribute onto
        each adapter instance; FluxAdapter.forward falls back to that
        attribute when it is not given an explicit vector.
        """
        for layer in self.flux_layers:
            layer.active_modulation = modulation_vector

    def clear_modulation(self):
        """Remove any broadcast modulation so the adapters behave like the
        plain base linear layers again (FluxAdapter returns out_base when
        `active_modulation` is absent)."""
        for layer in self.flux_layers:
            if hasattr(layer, 'active_modulation'):
                del layer.active_modulation

    def get_embeddings(self, input_ids):
        """Run the LLM trunk and return the final hidden states.

        Goes through self.llm.model (the trunk) to bypass lm_head — per the
        original note, the head may be FP16 and crash against the FP32
        adapter stream. The FluxAdapters apply themselves automatically if
        set_active_modulation was called beforehand.
        """
        out = self.llm.model(input_ids, output_hidden_states=True)
        return out.last_hidden_state  # equivalently out.hidden_states[-1]

    def forward(self, input_ids, forced_modulation=None):
        """Two execution paths.

        Path A (forced_modulation is not None): apply the supplied
        modulation, run one modulated pass, and return LM logits only
        (language training).

        Path B (default): unmodulated "perception" pass -> controller picks
        a modulation -> modulated "understanding" pass -> WALT latents and
        LM logits. Returns (z, z_next_pred, modulation, logits).
        """
        # Always start from a clean, unmodulated adapter state.
        self.clear_modulation()
        # --- PATH A: FORCED MODULATION (Language Training) ---
        if forced_modulation is not None:
            self.set_active_modulation(forced_modulation)
            h_modulated = self.get_embeddings(input_ids)
            h_modulated = h_modulated.to(Config.DTYPE)
            logits = self.llm.lm_head(h_modulated)
            # Leave no stale state behind for the next call.
            self.clear_modulation()
            return logits
        # --- PATH B: STANDARD CONTROLLER (Physics & Dynamics Training) ---
        # Perception pass under no_grad: h_init is detached, so controller
        # gradients flow only from `modulation` onward, not through this pass.
        with torch.no_grad():
            h_init = self.get_embeddings(input_ids)
        # Controller decision from the (dtype-cast) hidden states.
        modulation = self.controller(h_init.to(Config.DTYPE))  # [Batch, Mod_Dim]
        # Apply the modulation ("Flux State") to every adapter.
        self.set_active_modulation(modulation)
        # Second, modulated pass. Expensive (full LLM re-run) but required
        # for the flux change to take effect.
        h_modulated = self.get_embeddings(input_ids)
        h_modulated = h_modulated.to(Config.DTYPE)
        # World-model latents from the modulated state.
        z, z_next_pred = self.walt(h_modulated)
        # LM logits for the KL-divergence (language preservation) loss.
        logits = self.llm.lm_head(h_modulated)
        return z, z_next_pred, modulation, logits