import torch
import logging
import os
import glob
from config_physics import Config
from modeling_physics_rl import PhysicsModel

# Keep third-party loggers quiet; this script prints its own status lines.
logging.basicConfig(level=logging.ERROR)

# Directories searched for trained weights, in priority order.
SEARCH_PATHS = [
    ".",
    "/kaggle/input/worldmodels/physics_model",
    "/kaggle/working/physics_model",
]


def _load_final_weights(model, final_path):
    """Load the 'final_*' controller / walt-head / adapter weights from final_path."""
    print(f" Loading Final Weights from {final_path}...")
    device = model.llm.device
    model.controller.load_state_dict(
        torch.load(os.path.join(final_path, "final_physics_controller.pt"),
                   map_location=device))
    model.walt.load_state_dict(
        torch.load(os.path.join(final_path, "final_walt_head.pt"),
                   map_location=device))

    # Adapters are optional on disk, but without them modulation has nothing
    # to act on -- warn loudly if they are missing.
    adapter_path = os.path.join(final_path, "final_flux_adapters.pt")
    if os.path.exists(adapter_path):
        print(" Loading Flux Adapters...")
        adapter_states = torch.load(adapter_path, map_location=device)
        for layer, state in zip(model.flux_layers, adapter_states):
            layer.load_state_dict(state)
    else:
        # BUGFIX: message previously said "Startled" instead of "Warning".
        print(" ⚠️ Warning: Final adapters not found! Modulation might be dead.")


def _load_latest_checkpoint(model):
    """Fallback: load the newest 'checkpoint_epoch_*.pt' found in SEARCH_PATHS.

    Raises:
        FileNotFoundError: if no checkpoint exists anywhere.
    """
    checkpoints = []
    for p in SEARCH_PATHS:
        checkpoints.extend(glob.glob(os.path.join(p, "checkpoint_epoch_*.pt")))
    if not checkpoints:
        raise FileNotFoundError(
            "No 'final_physics_controller.pt' or 'checkpoint_epoch_*.pt' found.")

    latest_ckpt = max(checkpoints, key=os.path.getctime)
    print(f" ⚠️ 'final' weights not found. Loading latest checkpoint: {latest_ckpt}")
    ckpt_data = torch.load(latest_ckpt, map_location=model.llm.device)

    # Checkpoints use specific keys, not a full model_state_dict.
    if 'controller_state_dict' in ckpt_data:
        model.controller.load_state_dict(ckpt_data['controller_state_dict'])
        model.walt.load_state_dict(ckpt_data['walt_state_dict'])
        if 'adapters_state_dict' in ckpt_data:
            print(" Loading Flux Adapters from Checkpoint...")
            for layer, state in zip(model.flux_layers,
                                    ckpt_data['adapters_state_dict']):
                layer.load_state_dict(state)
    else:
        # Fallback if we accidentally saved it differently in a previous run.
        model.load_state_dict(ckpt_data['model_state_dict'], strict=False)


def load_models():
    """
    Load the physics model and attach trained weights.

    One model object serves both benchmark conditions:
      1. Flux Model: trained Controller & Adapters active.
      2. Base Model: the same model with modulation forced to ZERO.

    Returns:
        PhysicsModel in eval mode, on CUDA when available. Weight loading is
        best-effort: on failure the (randomly initialized) model is returned.
    """
    print("⏳ Loading Physics Model...")
    model = PhysicsModel()

    # Move to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f" Using Device: {device}")

    # Locate the directory holding the final weights, if any.
    final_path = None
    for p in SEARCH_PATHS:
        if os.path.exists(os.path.join(p, "final_physics_controller.pt")):
            final_path = p
            break

    try:
        if final_path:
            _load_final_weights(model, final_path)
        else:
            _load_latest_checkpoint(model)
        print("✅ Weights Loaded.")
    except Exception as e:
        # Best-effort: the benchmark can still run (as a smoke test) even if
        # no trained weights could be loaded.
        print(f"⚠️ Warning: Could not load weights: {e}")

    model.eval()
    return model


def _health_check(model):
    """Print the first adapter's LoRA_B norm; an exact zero means training never updated it."""
    try:
        if hasattr(model.flux_layers[0], 'lora_B'):
            lb_norm = model.flux_layers[0].lora_B.norm().item()
            print(f"\n🔍 Health Check - First Adapter LoRA_B Norm: {lb_norm:.6f}")
            if lb_norm == 0:
                print(" ❌ WARNING: LoRA weights are ZERO. Training failed to update weights.")
            else:
                print(" ✅ Weights are LEARNED (Non-Zero).")
    except Exception:
        # Purely diagnostic -- never let the health check stop the benchmark.
        # (Was a bare `except:`; narrowed so KeyboardInterrupt etc. still propagate.)
        pass


def _debug_trace(model, inputs, modulation, steps=3):
    """Greedy-decode a few tokens, printing base vs. modulated hidden-state norms.

    Mutates the model's active-modulation state; callers should reset it afterwards.
    """
    print("\n 🔍 Generation Trace (First 3 Steps):")
    trace_input = inputs.input_ids.clone()
    for i in range(steps):
        # Base (no modulation)
        model.clear_modulation()
        out_plain = model.llm.model(trace_input)
        base_norm = out_plain.last_hidden_state[:, -1, :].norm().item()

        # Flux (modulated)
        model.set_active_modulation(modulation)
        out_liq = model.llm.model(trace_input)
        liq_norm = out_liq.last_hidden_state[:, -1, :].norm().item()

        # How much did modulation move the final hidden state?
        diff = out_liq.last_hidden_state[:, -1, :] - out_plain.last_hidden_state[:, -1, :]
        diff_norm = diff.norm().item()
        ratio = (diff_norm / base_norm) * 100
        print(f" Step {i}: Base={base_norm:.2f} | Flux={liq_norm:.2f} | Diff={diff_norm:.4f} ({ratio:.2f}%)")

        # Advance one step (Greedy) using the internal lm_head to get logits.
        logits = model.llm.lm_head(out_liq.last_hidden_state[:, -1, :].unsqueeze(0))
        # lm_head may return (1, 1, vocab) or (1, vocab) depending on the model.
        if logits.dim() == 3:
            logits = logits[:, -1, :]
        next_token = torch.argmax(logits, dim=-1).unsqueeze(0)
        token_str = model.tokenizer.decode(next_token[0])
        print(f" Selected Token: '{token_str}'")
        if next_token.dim() == 1:
            next_token = next_token.unsqueeze(0)
        trace_input = torch.cat([trace_input, next_token], dim=1)


def _write_report(results, path="benchmark_results.txt"):
    """Persist the full (untruncated) outputs for offline inspection."""
    with open(path, "w") as f:
        for r in results:
            f.write(f"Prompt: {r['Prompt']}\n")
            f.write(f"Base Model: {r['Base']}\n")
            f.write(f"Flux Model: {r['Flux']}\n")
            f.write(f"Modulation Strength: {r['Modulation_Norm']:.4f}\n")
            f.write("-" * 30 + "\n")


# Benchmark prompts, grouped by what they probe.
TEST_CASES = [
    # --- TYPE A: QUALITATIVE (Concept Checks) ---
    "I release a heavy steel marble from a height of one meter in a zero-gravity environment.",
    "I drop a plastic camping plate onto a marble floor from waist height.",
    "I shine a red laser beam through a glass prism.",
    # --- TYPE B: QUANTITATIVE (Math & Engineering) ---
    "A 2kg block slides down a frictionless ramp of height 5m. Calculate its velocity at the bottom. (g=9.8 m/s^2)",
    "A car accelerates from 0 to 20 m/s in 4 seconds. What is the average acceleration?",
    "A one-meter-long flexible cable lies at rest on a frictionless table, with 5 cm hanging over the edge. At what time will the cable completely slide off the table?",
    "If I mix 100g of ice at 0°C with 100g of water at 80°C, what is the final temperature? (Specific heat of water = 4.18 J/g°C)",
]


def run_benchmark():
    """Run every TEST_CASES prompt through Base (zero modulation) and Flux
    (controller-produced modulation) conditions, print a comparison, and save
    the full transcripts to benchmark_results.txt."""
    model = load_models()
    _health_check(model)

    results = []
    print("\n" + "=" * 50)
    print(" 🧪 Physics Benchmark: Base vs Flux")
    print("=" * 50)

    for prompt in TEST_CASES:
        full_prompt = f"User: {prompt}\nModel:"
        inputs = model.tokenizer(full_prompt, return_tensors="pt").to(model.llm.device)

        # --- Run 1: Base Model (No Modulation) ---
        model.clear_modulation()
        # We could simulate "Base" by simply not calling set_active_modulation,
        # but setting an explicit all-zeros vector makes the condition unambiguous.
        zero_mod = torch.zeros(1, Config.MODULATION_DIM).to(model.llm.device).to(Config.DTYPE)
        model.set_active_modulation(zero_mod)
        # NOTE(review): both max_new_tokens and max_length are passed; HF
        # generate prefers max_new_tokens and warns. Kept as-is for parity.
        out_base = model.llm.generate(**inputs, max_new_tokens=100,
                                      max_length=Config.MAX_LENGTH,
                                      do_sample=False)  # Greedy for base
        text_base = model.tokenizer.decode(
            out_base[0], skip_special_tokens=True).replace(full_prompt, "").strip()

        # --- Run 2: Flux Model (With RL Modulation) ---
        model.clear_modulation()
        # "Thinking" step: the controller maps prompt embeddings to a modulation vector.
        with torch.no_grad():
            h_init = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
            modulation = model.controller(h_init)
        mod_mag = modulation.norm().item()  # modulation strength, for the report
        model.set_active_modulation(modulation)

        # Debug trace is purely diagnostic; never let it kill the run.
        try:
            _debug_trace(model, inputs, modulation)
        except Exception as e:
            print(f" ⚠️ Debug Trace Failed: {e}")

        # Reset for actual generation (the trace left modulation active/dirty).
        model.clear_modulation()
        model.set_active_modulation(modulation)
        out_liquid = model.llm.generate(**inputs, max_new_tokens=100,
                                        max_length=Config.MAX_LENGTH,
                                        do_sample=True, temperature=0.01)
        text_liquid = model.tokenizer.decode(
            out_liquid[0], skip_special_tokens=True).replace(full_prompt, "").strip()

        results.append({
            "Prompt": prompt,
            "Base": text_base,
            "Flux": text_liquid,
            "Modulation_Norm": mod_mag,
        })

        print(f"\n📝 {prompt}")
        print(f" 🧊 Base: {text_base[:100]}...")
        print(f" 💧 Flux: {text_liquid[:100]}... (Mod Norm: {mod_mag:.2f})")

    # Save detailed report
    _write_report(results)
    print("\n✅ Benchmark Complete. Saved to benchmark_results.txt")


if __name__ == "__main__":
    run_benchmark()