# flux-test-time-training / benchmark_physics.py
# Uploaded by convaiinnovations via huggingface_hub
# ("Upload benchmark_physics.py with huggingface_hub", commit d6535c0).
import torch
import logging
import os
import glob
from config_physics import Config
from modeling_physics_rl import PhysicsModel
# Root logger: only ERROR-and-above records are emitted through it
# (presumably to quiet library chatter). This script reports its own
# progress with print(), not logging.
logging.basicConfig(level="ERROR")
def _find_final_weights_dir(search_paths):
    """Return the first directory in *search_paths* holding final_physics_controller.pt, or None."""
    for path in search_paths:
        if os.path.exists(os.path.join(path, "final_physics_controller.pt")):
            return path
    return None


def _load_adapter_states(model, adapter_states):
    """Load saved per-layer adapter states into ``model.flux_layers``, pairing them in order.

    A count mismatch is reported explicitly; zip() alone would silently load
    only the overlapping prefix and leave the remaining layers untouched.
    """
    # NOTE(review): assumes adapter_states is a sequence of state dicts, one per
    # flux layer, in the same order as model.flux_layers -- confirm with the trainer.
    adapter_states = list(adapter_states)
    n_saved, n_layers = len(adapter_states), len(model.flux_layers)
    if n_saved != n_layers:
        print(
            f" ⚠️ Adapter count mismatch: {n_saved} saved vs {n_layers} layers; "
            f"loading the first {min(n_saved, n_layers)} only."
        )
    for layer, state in zip(model.flux_layers, adapter_states):
        layer.load_state_dict(state)


def _load_final_weights(model, final_path, map_location):
    """Load controller, walt head and (optional) Flux adapters from the final_*.pt files in *final_path*."""
    # NOTE(security): torch.load unpickles its input -- only use trusted checkpoint files.
    print(f" Loading Final Weights from {final_path}...")
    model.controller.load_state_dict(
        torch.load(os.path.join(final_path, "final_physics_controller.pt"), map_location=map_location)
    )
    model.walt.load_state_dict(
        torch.load(os.path.join(final_path, "final_walt_head.pt"), map_location=map_location)
    )

    adapter_path = os.path.join(final_path, "final_flux_adapters.pt")
    if os.path.exists(adapter_path):
        print(" Loading Flux Adapters...")
        _load_adapter_states(model, torch.load(adapter_path, map_location=map_location))
    else:
        print(" ⚠️ Startled: Final adapters not found! Modulation might be dead.")


def _load_latest_checkpoint(model, search_paths, map_location):
    """Load the newest ``checkpoint_epoch_*.pt`` found in any of *search_paths*.

    Raises:
        FileNotFoundError: if no checkpoint exists in any search path.
        KeyError: if the checkpoint lacks the expected state-dict keys.
    """
    checkpoints = [
        ckpt
        for path in search_paths
        for ckpt in glob.glob(os.path.join(path, "checkpoint_epoch_*.pt"))
    ]
    if not checkpoints:
        raise FileNotFoundError("No 'final_physics_controller.pt' or 'checkpoint_epoch_*.pt' found.")

    # "Latest" = newest ctime (inode-change time on Unix, creation time on Windows).
    latest_ckpt = max(checkpoints, key=os.path.getctime)
    print(f" ⚠️ 'final' weights not found. Loading latest checkpoint: {latest_ckpt}")
    # NOTE(security): torch.load unpickles its input -- only use trusted checkpoint files.
    ckpt_data = torch.load(latest_ckpt, map_location=map_location)

    if 'controller_state_dict' in ckpt_data:
        # Per-component format: separate state dicts for each trainable part.
        model.controller.load_state_dict(ckpt_data['controller_state_dict'])
        model.walt.load_state_dict(ckpt_data['walt_state_dict'])
        if 'adapters_state_dict' in ckpt_data:
            print(" Loading Flux Adapters from Checkpoint...")
            _load_adapter_states(model, ckpt_data['adapters_state_dict'])
    else:
        # Fallback for a checkpoint saved as one full-model state dict.
        # strict=False tolerates absent/extra keys, so surface them instead of
        # hiding a partial load.
        result = model.load_state_dict(ckpt_data['model_state_dict'], strict=False)
        missing = getattr(result, "missing_keys", [])
        unexpected = getattr(result, "unexpected_keys", [])
        if missing or unexpected:
            print(
                f" ⚠️ Partial checkpoint load: {len(missing)} missing, "
                f"{len(unexpected)} unexpected keys."
            )


def load_models():
    """
    Build the PhysicsModel, move it to GPU if available, and load trained weights.

    Despite the plural name, a single model is returned. The "Base" and "Flux"
    variants are not separate models: they come from this one model by changing
    its active modulation at inference time.

    Weights are searched for in ".", "/kaggle/input/worldmodels/physics_model"
    and "/kaggle/working/physics_model":

    1. Final weights. The first directory containing ``final_physics_controller.pt``
       supplies the controller, the ``walt`` head (``final_walt_head.pt``) and,
       if present, the Flux adapters (``final_flux_adapters.pt``).
    2. Otherwise, the ``checkpoint_epoch_*.pt`` with the newest ctime across
       all directories is loaded.

    Loading is best-effort. Any error is printed as a warning, and the model is
    returned with whatever weights it holds at that point (possibly untrained).

    Returns:
        The PhysicsModel, switched to eval mode.
    """
    print("⏳ Loading Physics Model...")
    model = PhysicsModel()

    # Move to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f" Using Device: {device}")

    search_paths = [
        ".",
        "/kaggle/input/worldmodels/physics_model",
        "/kaggle/working/physics_model",
    ]

    try:
        # Map loaded tensors straight onto the device the LLM lives on.
        map_location = model.llm.device
        final_path = _find_final_weights_dir(search_paths)
        if final_path:
            _load_final_weights(model, final_path, map_location)
        else:
            _load_latest_checkpoint(model, search_paths, map_location)
        print("✅ Weights Loaded.")
    except Exception as e:
        # Deliberate best-effort boundary: report the failure and keep going so
        # the benchmark can still run.
        print(f"⚠️ Warning: Could not load weights: {e}")

    model.eval()
    return model
def _report_adapter_health(model):
    """Print the norm of the first Flux adapter's ``lora_B`` to show whether training changed it.

    Best-effort diagnostic: failures are reported and skipped rather than
    aborting the benchmark (or being silently swallowed).
    """
    try:
        first_adapter = model.flux_layers[0]
        if hasattr(first_adapter, 'lora_B'):
            with torch.no_grad():
                lb_norm = first_adapter.lora_B.norm().item()
            print(f"\n🔍 Health Check - First Adapter LoRA_B Norm: {lb_norm:.6f}")
            if lb_norm == 0:
                print(" ❌ WARNING: LoRA weights are ZERO. Training failed to update weights.")
            else:
                print(" ✅ Weights are LEARNED (Non-Zero).")
    except Exception as e:
        print(f" ⚠️ Health Check skipped: {e}")


def _decode_new_tokens(model, output_ids, prompt_len):
    """Decode only the tokens generated after the prompt (first sequence of *output_ids*).

    Slicing off the prompt ids avoids string-replacing the prompt text. That
    replacement fails whenever decode(encode(prompt)) does not reproduce the
    prompt exactly, and it would also strip prompt text that reappears in the output.
    """
    # NOTE(review): assumes model.llm is decoder-only, so generate() echoes the
    # prompt ids before the new tokens (consistent with llm.model/llm.lm_head use).
    return model.tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True).strip()


def _trace_first_steps(model, input_ids, modulation, steps=3):
    """Greedily decode *steps* tokens, printing how far the modulation moves the last hidden state.

    At each step the last-position hidden state is computed with modulation
    cleared ("Base") and with *modulation* active ("Flux"). The next token is
    then picked from the Flux hidden state. On return, *modulation* is active.
    """
    print(f"\n 🔍 Generation Trace (First {steps} Steps):")
    trace_ids = input_ids.clone()
    # Diagnostics only: don't build autograd graphs for these forward passes.
    with torch.no_grad():
        for step in range(steps):
            model.clear_modulation()
            base_hidden = model.llm.model(trace_ids).last_hidden_state[:, -1, :]
            base_norm = base_hidden.norm().item()

            model.set_active_modulation(modulation)
            flux_hidden = model.llm.model(trace_ids).last_hidden_state[:, -1, :]
            flux_norm = flux_hidden.norm().item()

            diff_norm = (flux_hidden - base_hidden).norm().item()
            ratio = (diff_norm / base_norm) * 100 if base_norm else float("nan")
            print(f" Step {step}: Base={base_norm:.2f} | Flux={flux_norm:.2f} | Diff={diff_norm:.4f} ({ratio:.2f}%)")

            # Advance one step greedily from the modulated hidden state.
            logits = model.llm.lm_head(flux_hidden)
            next_token = logits.argmax(dim=-1, keepdim=True)  # (batch, 1)
            token_str = model.tokenizer.decode(next_token[0])
            print(f" Selected Token: '{token_str}'")
            trace_ids = torch.cat([trace_ids, next_token], dim=1)


def _write_report(results, path="benchmark_results.txt"):
    """Write one section per benchmark result to *path* as UTF-8 text.

    An explicit encoding matters because the prompts (e.g. "°C") and model
    outputs are not ASCII, and the platform default encoding may not handle them.
    """
    with open(path, "w", encoding="utf-8") as f:
        for r in results:
            f.write(f"Prompt: {r['Prompt']}\n")
            f.write(f"Base Model: {r['Base']}\n")
            f.write(f"Flux Model: {r['Flux']}\n")
            f.write(f"Modulation Strength: {r['Modulation_Norm']:.4f}\n")
            f.write("-" * 30 + "\n")


def run_benchmark():
    """
    Compare "Base" and "Flux" generations of the physics model on a fixed prompt set.

    For each prompt, the model from load_models() is decoded twice with greedy
    decoding (max_new_tokens=100):

    * Base: active modulation set to an all-zero vector.
    * Flux: active modulation produced by ``model.controller`` from the prompt
      embeddings.

    Both runs use greedy decoding so the comparison is deterministic and
    differs only in the modulation. A short diagnostic trace of the first
    tokens is printed between the two runs. Outputs and the modulation norm are
    printed and written to ``benchmark_results.txt`` in the working directory.
    """
    model = load_models()
    _report_adapter_health(model)

    test_cases = [
        # --- TYPE A: QUALITATIVE (Concept Checks) ---
        "I release a heavy steel marble from a height of one meter in a zero-gravity environment.",
        "I drop a plastic camping plate onto a marble floor from waist height.",
        "I shine a red laser beam through a glass prism.",
        # --- TYPE B: QUANTITATIVE (Math & Engineering) ---
        "A 2kg block slides down a frictionless ramp of height 5m. Calculate its velocity at the bottom. (g=9.8 m/s^2)",
        "A car accelerates from 0 to 20 m/s in 4 seconds. What is the average acceleration?",
        "A one-meter-long flexible cable lies at rest on a frictionless table, with 5 cm hanging over the edge. At what time will the cable completely slide off the table?",
        "If I mix 100g of ice at 0°C with 100g of water at 80°C, what is the final temperature? (Specific heat of water = 4.18 J/g°C)",
    ]
    results = []

    print("\n" + "="*50)
    print(" 🧪 Physics Benchmark: Base vs Flux")
    print("="*50)

    for prompt in test_cases:
        full_prompt = f"User: {prompt}\nModel:"
        inputs = model.tokenizer(full_prompt, return_tensors="pt").to(model.llm.device)
        prompt_len = inputs.input_ids.shape[1]

        # --- Run 1: Base Model (explicit all-zero modulation) ---
        model.clear_modulation()
        zero_mod = torch.zeros(1, Config.MODULATION_DIM, device=model.llm.device, dtype=Config.DTYPE)
        model.set_active_modulation(zero_mod)
        out_base = model.llm.generate(**inputs, max_new_tokens=100, max_length=Config.MAX_LENGTH, do_sample=False)
        text_base = _decode_new_tokens(model, out_base, prompt_len)

        # --- Run 2: Flux Model (controller-produced modulation) ---
        model.clear_modulation()
        with torch.no_grad():
            h_init = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
            modulation = model.controller(h_init)
        mod_mag = modulation.norm().item()

        try:
            _trace_first_steps(model, inputs.input_ids, modulation)
        except Exception as e:
            print(f" ⚠️ Debug Trace Failed: {e}")

        # Reset for actual generation (the trace may have left any state behind).
        model.clear_modulation()
        model.set_active_modulation(modulation)
        # Greedy like the base run, so the outputs differ only by the modulation.
        out_flux = model.llm.generate(**inputs, max_new_tokens=100, max_length=Config.MAX_LENGTH, do_sample=False)
        text_flux = _decode_new_tokens(model, out_flux, prompt_len)

        results.append({
            "Prompt": prompt,
            "Base": text_base,
            "Flux": text_flux,
            "Modulation_Norm": mod_mag,
        })
        print(f"\n📝 {prompt}")
        print(f" 🧊 Base: {text_base[:100]}...")
        print(f" 💧 Flux: {text_flux[:100]}... (Mod Norm: {mod_mag:.2f})")

    _write_report(results)
    print("\n✅ Benchmark Complete. Saved to benchmark_results.txt")
# Script entry point: run the benchmark only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    run_benchmark()