convaiinnovations committed on
Commit
d6535c0
·
verified ·
1 Parent(s): 2243b52

Upload benchmark_physics.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. benchmark_physics.py +222 -0
benchmark_physics.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import logging
4
+ import os
5
+ import glob
6
+ from config_physics import Config
7
+ from modeling_physics_rl import PhysicsModel
8
+
9
# Setup logging: suppress everything below ERROR level so benchmark output
# (prints below) is not drowned out by library log chatter.
logging.basicConfig(level=logging.ERROR)
11
+
12
def _find_final_weights_dir(search_paths):
    """Return the first directory containing 'final_physics_controller.pt', else None."""
    for p in search_paths:
        if os.path.exists(os.path.join(p, "final_physics_controller.pt")):
            return p
    return None


def _load_final_weights(model, final_path):
    """Load controller, WALT head and (optionally) flux adapters from *final_path*.

    Raises whatever ``torch.load`` / ``load_state_dict`` raise; the caller
    treats any failure as best-effort.
    """
    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects — only point this at checkpoints you trust.
    device = model.llm.device  # hoisted: used as map_location for every load
    print(f"   Loading Final Weights from {final_path}...")
    model.controller.load_state_dict(
        torch.load(os.path.join(final_path, "final_physics_controller.pt"), map_location=device))
    model.walt.load_state_dict(
        torch.load(os.path.join(final_path, "final_walt_head.pt"), map_location=device))

    # Adapters are optional: warn (original message was garbled as "Startled")
    # instead of failing, since the controller alone may still modulate.
    adapter_path = os.path.join(final_path, "final_flux_adapters.pt")
    if os.path.exists(adapter_path):
        print("   Loading Flux Adapters...")
        adapter_states = torch.load(adapter_path, map_location=device)
        for layer, state in zip(model.flux_layers, adapter_states):
            layer.load_state_dict(state)
    else:
        print("   ⚠️ Warning: Final adapters not found! Modulation might be dead.")


def _load_latest_checkpoint(model, search_paths):
    """Fall back to the newest 'checkpoint_epoch_*.pt'.

    Raises:
        FileNotFoundError: when no checkpoint exists in any search path.
    """
    checkpoints = []
    for p in search_paths:
        checkpoints.extend(glob.glob(os.path.join(p, "checkpoint_epoch_*.pt")))
    if not checkpoints:
        raise FileNotFoundError(
            "No 'final_physics_controller.pt' or 'checkpoint_epoch_*.pt' found.")

    # Newest by creation time, matching training's save cadence.
    latest_ckpt = max(checkpoints, key=os.path.getctime)
    print(f"   ⚠️ 'final' weights not found. Loading latest checkpoint: {latest_ckpt}")
    ckpt_data = torch.load(latest_ckpt, map_location=model.llm.device)

    # Checkpoints normally store component-specific keys, not one full
    # model_state_dict.
    if 'controller_state_dict' in ckpt_data:
        model.controller.load_state_dict(ckpt_data['controller_state_dict'])
        model.walt.load_state_dict(ckpt_data['walt_state_dict'])
        if 'adapters_state_dict' in ckpt_data:
            print("   Loading Flux Adapters from Checkpoint...")
            for layer, state in zip(model.flux_layers, ckpt_data['adapters_state_dict']):
                layer.load_state_dict(state)
    else:
        # Fallback if a previous run accidentally saved the full model instead.
        model.load_state_dict(ckpt_data['model_state_dict'], strict=False)


def load_models():
    """
    Load the physics model with its trained Controller & Flux Adapters.

    Weight resolution order:
      1. "final_*" weight files found in any of the known search paths.
      2. Otherwise, the most recently created "checkpoint_epoch_*.pt".

    Returns:
        The PhysicsModel in eval() mode, moved to CUDA when available. If no
        weights can be loaded, a warning is printed and the unloaded model is
        returned (best-effort, matching the benchmark's original behaviour).
    """
    print("⏳ Loading Physics Model...")
    model = PhysicsModel()

    # Move to GPU if available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"   Using Device: {device}")

    # Directories that may contain trained weights (local + Kaggle layouts).
    search_paths = [
        ".",
        "/kaggle/input/worldmodels/physics_model",
        "/kaggle/working/physics_model",
    ]

    try:
        final_dir = _find_final_weights_dir(search_paths)
        if final_dir:
            _load_final_weights(model, final_dir)
        else:
            _load_latest_checkpoint(model, search_paths)
        print("✅ Weights Loaded.")
    except Exception as e:
        # Best-effort: benchmarking an untrained model is still permitted.
        print(f"⚠️ Warning: Could not load weights: {e}")

    model.eval()
    return model
90
+
91
def _health_check(model):
    """Print the first adapter's LoRA_B norm to verify training actually updated it."""
    try:
        if hasattr(model.flux_layers[0], 'lora_B'):
            lb_norm = model.flux_layers[0].lora_B.norm().item()
            print(f"\n🔍 Health Check - First Adapter LoRA_B Norm: {lb_norm:.6f}")
            if lb_norm == 0:
                print("   ❌ WARNING: LoRA weights are ZERO. Training failed to update weights.")
            else:
                print("   ✅ Weights are LEARNED (Non-Zero).")
    except Exception:
        # Diagnostic only — was a bare `except:` that also swallowed
        # KeyboardInterrupt/SystemExit; narrowed to Exception.
        pass


def _debug_trace(model, inputs, modulation, steps=3):
    """Greedily decode a few tokens, printing base-vs-modulated hidden-state norms.

    Leaves the model's modulation state cleared/overwritten; callers should
    reset modulation before real generation.
    """
    try:
        print("\n   🔍 Generation Trace (First 3 Steps):")
        trace_input = inputs.input_ids.clone()
        # no_grad: this is pure inspection, never backpropagated.
        with torch.no_grad():
            for step in range(steps):
                # Hidden states without modulation.
                model.clear_modulation()
                base_states = model.llm.model(trace_input)
                base_norm = base_states.last_hidden_state[:, -1, :].norm().item()

                # Hidden states with modulation active.
                model.set_active_modulation(modulation)
                flux_states = model.llm.model(trace_input)
                flux_norm = flux_states.last_hidden_state[:, -1, :].norm().item()

                diff = flux_states.last_hidden_state[:, -1, :] - base_states.last_hidden_state[:, -1, :]
                diff_norm = diff.norm().item()
                ratio = (diff_norm / base_norm) * 100
                print(f"   Step {step}: Base={base_norm:.2f} | Flux={flux_norm:.2f} | Diff={diff_norm:.4f} ({ratio:.2f}%)")

                # Advance one step greedily via the internal lm_head.
                logits = model.llm.lm_head(flux_states.last_hidden_state[:, -1, :].unsqueeze(0))
                if logits.dim() == 3:
                    logits = logits[:, -1, :]
                next_token = torch.argmax(logits, dim=-1).unsqueeze(0)
                print(f"   Selected Token: '{model.tokenizer.decode(next_token[0])}'")
                if next_token.dim() == 1:
                    next_token = next_token.unsqueeze(0)
                trace_input = torch.cat([trace_input, next_token], dim=1)
    except Exception as e:
        print(f"   ⚠️ Debug Trace Failed: {e}")


def run_benchmark():
    """Compare Base (zero modulation) vs Flux (controller-modulated) generations.

    For each test prompt, generates once with modulation forced to zero and
    once with the controller's learned modulation, prints a short per-prompt
    summary, and writes the full transcript to benchmark_results.txt.
    """
    model = load_models()
    _health_check(model)

    test_cases = [
        # --- TYPE A: QUALITATIVE (Concept Checks) ---
        "I release a heavy steel marble from a height of one meter in a zero-gravity environment.",
        "I drop a plastic camping plate onto a marble floor from waist height.",
        "I shine a red laser beam through a glass prism.",

        # --- TYPE B: QUANTITATIVE (Math & Engineering) ---
        "A 2kg block slides down a frictionless ramp of height 5m. Calculate its velocity at the bottom. (g=9.8 m/s^2)",
        "A car accelerates from 0 to 20 m/s in 4 seconds. What is the average acceleration?",
        "A one-meter-long flexible cable lies at rest on a frictionless table, with 5 cm hanging over the edge. At what time will the cable completely slide off the table?",
        "If I mix 100g of ice at 0°C with 100g of water at 80°C, what is the final temperature? (Specific heat of water = 4.18 J/g°C)",
    ]

    results = []

    print("\n" + "=" * 50)
    print("   🧪 Physics Benchmark: Base vs Flux")
    print("=" * 50)

    for prompt in test_cases:
        full_prompt = f"User: {prompt}\nModel:"
        inputs = model.tokenizer(full_prompt, return_tensors="pt").to(model.llm.device)

        # --- Run 1: Base model --- modulation set to explicit zeros (rather
        # than merely cleared) so the adapters are provably inert.
        model.clear_modulation()
        zero_mod = torch.zeros(1, Config.MODULATION_DIM).to(model.llm.device).to(Config.DTYPE)
        model.set_active_modulation(zero_mod)
        out_base = model.llm.generate(**inputs, max_new_tokens=100, max_length=Config.MAX_LENGTH, do_sample=False)  # Greedy for base
        text_base = model.tokenizer.decode(out_base[0], skip_special_tokens=True).replace(full_prompt, "").strip()

        # --- Run 2: Flux model --- controller produces the modulation vector
        # from the prompt embeddings ("thinking step").
        model.clear_modulation()
        with torch.no_grad():
            h_init = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
            modulation = model.controller(h_init)
        mod_mag = modulation.norm().item()  # modulation strength for the report
        model.set_active_modulation(modulation)

        # Debug trace mutates modulation state internally (it was also
        # shadowing `out_base` when inlined here — now isolated in a helper).
        _debug_trace(model, inputs, modulation)

        # Reset, then re-apply modulation for the actual generation.
        model.clear_modulation()
        model.set_active_modulation(modulation)
        out_flux = model.llm.generate(**inputs, max_new_tokens=100, max_length=Config.MAX_LENGTH, do_sample=True, temperature=0.01)
        text_flux = model.tokenizer.decode(out_flux[0], skip_special_tokens=True).replace(full_prompt, "").strip()

        results.append({
            "Prompt": prompt,
            "Base": text_base,
            "Flux": text_flux,
            "Modulation_Norm": mod_mag,
        })

        print(f"\n📝 {prompt}")
        print(f"   🧊 Base: {text_base[:100]}...")
        print(f"   💧 Flux: {text_flux[:100]}... (Mod Norm: {mod_mag:.2f})")

    # Save detailed report. Explicit UTF-8: prompts/outputs contain non-ASCII
    # (°C, emoji) that would crash under a non-UTF-8 default locale encoding.
    with open("benchmark_results.txt", "w", encoding="utf-8") as f:
        for r in results:
            f.write(f"Prompt: {r['Prompt']}\n")
            f.write(f"Base Model: {r['Base']}\n")
            f.write(f"Flux Model: {r['Flux']}\n")
            f.write(f"Modulation Strength: {r['Modulation_Norm']:.4f}\n")
            f.write("-" * 30 + "\n")

    print("\n✅ Benchmark Complete. Saved to benchmark_results.txt")
220
+
221
# Script entry point: run the full Base-vs-Flux benchmark when executed
# directly (importing this module triggers nothing).
if __name__ == "__main__":
    run_benchmark()