| |
|
| | import torch |
| | import logging |
| | import os |
| | import glob |
| | from config_physics import Config |
| | from modeling_physics_rl import PhysicsModel |
| |
|
| | |
# Keep third-party library log output quiet: only ERROR and above is shown.
logging.basicConfig(level=logging.ERROR)
| |
|
def load_models():
    """
    Build the PhysicsModel and load its trained weights.

    Search order for weights:
      1. "Final" weight files (controller, WALT head, optional flux adapters)
         in any of the known search paths.
      2. Otherwise, the newest ``checkpoint_epoch_*.pt`` checkpoint found
         across the same paths.

    Loading is deliberately best-effort: any failure (including nothing
    found at all) is reported as a warning and the model is returned with
    whatever weights it currently has — possibly untrained.

    Returns:
        PhysicsModel in eval mode, moved to CUDA when available.
    """
    print("⏳ Loading Physics Model...")
    model = PhysicsModel()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f" Using Device: {device}")

    # Directories that may hold weights (local dir, Kaggle input/working).
    search_paths = [
        ".",
        "/kaggle/input/worldmodels/physics_model",
        "/kaggle/working/physics_model",
    ]

    # Locate the directory containing the final controller weights, if any.
    final_path = None
    for p in search_paths:
        if os.path.exists(os.path.join(p, "final_physics_controller.pt")):
            final_path = p
            break

    try:
        if final_path:
            print(f" Loading Final Weights from {final_path}...")

            def _load(fname):
                # Map tensors straight onto the device the model was moved
                # to above (the original mixed `device` and `model.llm.device`;
                # after model.to(device) they are the same target).
                return torch.load(os.path.join(final_path, fname), map_location=device)

            model.controller.load_state_dict(_load("final_physics_controller.pt"))
            model.walt.load_state_dict(_load("final_walt_head.pt"))

            # Adapters are optional; warn loudly if missing because the
            # modulation path depends on them.
            adapter_path = os.path.join(final_path, "final_flux_adapters.pt")
            if os.path.exists(adapter_path):
                print(" Loading Flux Adapters...")
                adapter_states = torch.load(adapter_path, map_location=device)
                for layer, state in zip(model.flux_layers, adapter_states):
                    layer.load_state_dict(state)
            else:
                print(" ⚠️ Startled: Final adapters not found! Modulation might be dead.")
        else:
            # Fall back to the most recently created epoch checkpoint.
            checkpoints = []
            for p in search_paths:
                checkpoints.extend(glob.glob(os.path.join(p, "checkpoint_epoch_*.pt")))

            if checkpoints:
                latest_ckpt = max(checkpoints, key=os.path.getctime)
                print(f" ⚠️ 'final' weights not found. Loading latest checkpoint: {latest_ckpt}")
                ckpt_data = torch.load(latest_ckpt, map_location=device)

                if 'controller_state_dict' in ckpt_data:
                    model.controller.load_state_dict(ckpt_data['controller_state_dict'])
                    model.walt.load_state_dict(ckpt_data['walt_state_dict'])

                    if 'adapters_state_dict' in ckpt_data:
                        print(" Loading Flux Adapters from Checkpoint...")
                        for layer, state in zip(model.flux_layers, ckpt_data['adapters_state_dict']):
                            layer.load_state_dict(state)
                else:
                    # Legacy layout: a single flat state dict for the whole
                    # model; strict=False tolerates missing/extra keys.
                    model.load_state_dict(ckpt_data['model_state_dict'], strict=False)
            else:
                # Caught by the handler below — keeps loading best-effort.
                raise FileNotFoundError("No 'final_physics_controller.pt' or 'checkpoint_epoch_*.pt' found.")

        print("✅ Weights Loaded.")
    except Exception as e:
        # Deliberate best-effort: the benchmark can still run, albeit with
        # whatever weights the model was constructed with.
        print(f"⚠️ Warning: Could not load weights: {e}")

    model.eval()
    return model
| |
|
def _health_check(model):
    """Warn if the first flux adapter's LoRA_B matrix is all zeros, which
    would mean training never updated the adapter weights.  Diagnostic only —
    never allowed to break the benchmark."""
    try:
        if hasattr(model.flux_layers[0], 'lora_B'):
            lb_norm = model.flux_layers[0].lora_B.norm().item()
            print(f"\n🔍 Health Check - First Adapter LoRA_B Norm: {lb_norm:.6f}")
            if lb_norm == 0:
                print(" ❌ WARNING: LoRA weights are ZERO. Training failed to update weights.")
            else:
                print(" ✅ Weights are LEARNED (Non-Zero).")
    except Exception:
        # Narrowed from a bare `except:` — still swallowed on purpose, but
        # no longer traps SystemExit/KeyboardInterrupt.
        pass


def _debug_trace(model, inputs, modulation):
    """Print last-hidden-state norms for base vs. modulated forward passes
    over the first 3 greedily selected tokens.

    NOTE(review): this clears/sets the model's active modulation as a side
    effect; the caller must re-arm modulation afterwards before generating.
    Failures are reported but never abort the benchmark.
    """
    try:
        print("\n 🔍 Generation Trace (First 3 Steps):")
        trace_input = inputs.input_ids.clone()
        for i in range(3):
            # Base forward pass: modulation off.
            model.clear_modulation()
            out_base = model.llm.model(trace_input)
            base_norm = out_base.last_hidden_state[:, -1, :].norm().item()

            # Modulated forward pass over the exact same tokens.
            model.set_active_modulation(modulation)
            out_liq = model.llm.model(trace_input)
            liq_norm = out_liq.last_hidden_state[:, -1, :].norm().item()

            # How far the modulation moved the final hidden state.
            diff = out_liq.last_hidden_state[:, -1, :] - out_base.last_hidden_state[:, -1, :]
            diff_norm = diff.norm().item()
            ratio = (diff_norm / base_norm) * 100

            print(f" Step {i}: Base={base_norm:.2f} | Flux={liq_norm:.2f} | Diff={diff_norm:.4f} ({ratio:.2f}%)")

            # Greedy-pick the next token from the modulated logits.
            logits = model.llm.lm_head(out_liq.last_hidden_state[:, -1, :].unsqueeze(0))
            if logits.dim() == 3:
                logits = logits[:, -1, :]

            next_token = torch.argmax(logits, dim=-1).unsqueeze(0)
            token_str = model.tokenizer.decode(next_token[0])
            print(f" Selected Token: '{token_str}'")

            # Defensive: ensure (1, 1) shape before concatenating.
            if next_token.dim() == 1:
                next_token = next_token.unsqueeze(0)
            trace_input = torch.cat([trace_input, next_token], dim=1)
    except Exception as e:
        print(f" ⚠️ Debug Trace Failed: {e}")


def _save_results(results):
    """Write one record per prompt to benchmark_results.txt (overwritten)."""
    # Explicit encoding: results contain emoji / non-ASCII physics symbols.
    with open("benchmark_results.txt", "w", encoding="utf-8") as f:
        for r in results:
            f.write(f"Prompt: {r['Prompt']}\n")
            f.write(f"Base Model: {r['Base']}\n")
            f.write(f"Flux Model: {r['Flux']}\n")
            f.write(f"Modulation Strength: {r['Modulation_Norm']:.4f}\n")
            f.write("-" * 30 + "\n")


def run_benchmark():
    """
    Compare the Base model (modulation forced to zero) against the Flux
    model (controller-driven modulation) on a small physics prompt suite,
    printing a per-prompt summary and saving full outputs to
    benchmark_results.txt.
    """
    model = load_models()

    _health_check(model)

    test_cases = [
        # Qualitative / intuitive-physics prompts.
        "I release a heavy steel marble from a height of one meter in a zero-gravity environment.",
        "I drop a plastic camping plate onto a marble floor from waist height.",
        "I shine a red laser beam through a glass prism.",

        # Quantitative problems with a computable answer.
        "A 2kg block slides down a frictionless ramp of height 5m. Calculate its velocity at the bottom. (g=9.8 m/s^2)",
        "A car accelerates from 0 to 20 m/s in 4 seconds. What is the average acceleration?",
        "A one-meter-long flexible cable lies at rest on a frictionless table, with 5 cm hanging over the edge. At what time will the cable completely slide off the table?",
        "If I mix 100g of ice at 0°C with 100g of water at 80°C, what is the final temperature? (Specific heat of water = 4.18 J/g°C)",
    ]

    results = []

    print("\n" + "=" * 50)
    print(" 🧪 Physics Benchmark: Base vs Flux")
    print("=" * 50)

    for prompt in test_cases:
        full_prompt = f"User: {prompt}\nModel:"
        inputs = model.tokenizer(full_prompt, return_tensors="pt").to(model.llm.device)

        # --- Base pass: modulation explicitly zeroed. ---
        model.clear_modulation()
        zero_mod = torch.zeros(1, Config.MODULATION_DIM).to(model.llm.device).to(Config.DTYPE)
        model.set_active_modulation(zero_mod)

        out_base = model.llm.generate(**inputs, max_new_tokens=100, max_length=Config.MAX_LENGTH, do_sample=False)
        text_base = model.tokenizer.decode(out_base[0], skip_special_tokens=True).replace(full_prompt, "").strip()

        # --- Flux pass: modulation derived from the prompt embedding. ---
        model.clear_modulation()
        with torch.no_grad():
            h_init = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
            modulation = model.controller(h_init)

        mod_mag = modulation.norm().item()
        model.set_active_modulation(modulation)

        _debug_trace(model, inputs, modulation)

        # Re-arm modulation — the trace above cleared/reset it.
        model.clear_modulation()
        model.set_active_modulation(modulation)

        # temperature=0.01 with sampling is effectively near-greedy decoding.
        out_liquid = model.llm.generate(**inputs, max_new_tokens=100, max_length=Config.MAX_LENGTH, do_sample=True, temperature=0.01)
        text_liquid = model.tokenizer.decode(out_liquid[0], skip_special_tokens=True).replace(full_prompt, "").strip()

        results.append({
            "Prompt": prompt,
            "Base": text_base,
            "Flux": text_liquid,
            "Modulation_Norm": mod_mag,
        })

        print(f"\n📝 {prompt}")
        print(f" 🧊 Base: {text_base[:100]}...")
        print(f" 💧 Flux: {text_liquid[:100]}... (Mod Norm: {mod_mag:.2f})")

    _save_results(results)

    print("\n✅ Benchmark Complete. Saved to benchmark_results.txt")
| |
|
| | if __name__ == "__main__": |
| | run_benchmark() |
| |
|