import torch
import os
import json
import logging
import torch.optim as optim

from config_physics import Config
from modeling_physics_rl import PhysicsModel

# Setup Logging
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)


def run_auto_ttt():
    """Run an automated Test-Time Training (TTT) loop over canned physics scenarios.

    For each scenario:
      A. Generate an initial answer with the current controller modulation.
      B. Take the scripted "correct answer" as a simulated user correction.
      C. Fine-tune the controller and the flux-adapter modulation projections
         on the corrected transcript for 30 optimizer steps.
      D. Re-generate to verify the adaptation took hold.

    Side effects: loads checkpoints from a list of search paths, prints
    progress to stdout, and saves the adapted weights to
    'ttt_physics_controller.pt' and 'ttt_flux_adapters.pt' in the CWD.
    """
    print("\n" + "="*50)
    print(" šŸ¤– DATA CENTER MODE: Automated TTT (Test-Time Training)")
    print("="*50)

    # 1. Load Model
    print("ā³ Loading Physics Model...")
    model = PhysicsModel()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Load Adapters (Generic Path Logic): first matching path wins.
    search_paths = [".", "/kaggle/input/worldmodels/physics_model", "/kaggle/working/physics_model"]
    for p in search_paths:
        fpath = os.path.join(p, "final_flux_adapters.pt")
        if os.path.exists(fpath):
            print(f" Loading Flux Adapters from {fpath}...")
            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # checkpoints from trusted paths (consider weights_only=True on
            # torch >= 2.0 if the checkpoint is a plain state-dict list).
            adapter_states = torch.load(fpath, map_location=device)
            if isinstance(adapter_states, list):
                # Expected format: one state_dict per flux layer.
                for layer, state in zip(model.flux_layers, adapter_states):
                    layer.load_state_dict(state)
            elif isinstance(adapter_states, dict):
                # A whole-model state_dict (rare but possible) is not handled
                # here. The original silently skipped it; warn instead so a
                # bad checkpoint format is visible.
                print(f" āš ļø Unsupported dict-format checkpoint at {fpath}; adapters NOT loaded.")
            break

    # Load Controller
    for p in search_paths:
        fpath = os.path.join(p, "final_physics_controller.pt")
        if os.path.exists(fpath):
            print(f" Loading Controller from {fpath}...")
            model.controller.load_state_dict(torch.load(fpath, map_location=device))
            break

    # 2. Setup Meta-Optimizer (AdamW)
    # We optimize the Controller AND the Adapter Projections
    params = list(model.controller.parameters())
    for layer in model.flux_layers:
        params.extend(list(layer.modulation_proj.parameters()))
    optimizer = optim.AdamW(params, lr=1e-3)  # High LR for fast adaptation

    # 3. Define Test Cases (Scenario, Prompt, Correct Answer)
    test_cases = [
        {
            "scenario": "Zero Gravity",
            "prompt": "I drop a heavy hammer inside a space station. What happens?",
            "correct_answer": "The hammer floats in place. Inside a space station in orbit, objects are in freefall and appear weightless (microgravity). It does not fall to the floor."
        },
        {
            "scenario": "Moon Gravity",
            "prompt": "I drop a feather and a hammer on the Moon. Which hits the ground first?",
            "correct_answer": "They hit the ground at the same time. On the Moon, there is no air resistance, so gravity accelerates all objects at the same rate regardless of mass."
        },
        {
            "scenario": "Underwater",
            "prompt": "I release a helium balloon underwater. Which way does it go?",
            "correct_answer": "The balloon floats UP. The buoyant force from the water is greater than the weight of the balloon."
        }
    ]

    print(f"\nšŸš€ Starting Automation Loop ({len(test_cases)} scenarios)...")

    for i, case in enumerate(test_cases):
        print(f"\n--------------------------------------------------")
        print(f"šŸ“ Scenario {i+1}: {case['scenario']}")
        print(f" Question: \"{case['prompt']}\"")

        # --- Step A: Initial Inference ---
        inputs = model.tokenizer(f"User: {case['prompt']}\nModel:", return_tensors="pt").to(device)

        # FIX: run the initial pass in eval mode with gradients disabled.
        # The original tracked gradients through generate(), building a
        # throwaway autograd graph (wasted memory) and potentially running
        # dropout. Step D already used eval()+no_grad(); Step A now matches.
        model.eval()
        with torch.no_grad():
            # Thinking (Dynamics Pass)
            h_init = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
            modulation = model.controller(h_init)
            mod_norm = modulation.norm().item()

            # Generate Text
            model.set_active_modulation(modulation)
            out = model.llm.generate(**inputs, max_new_tokens=60, do_sample=False)
            model.clear_modulation()

        text_initial = model.tokenizer.decode(out[0], skip_special_tokens=True).split("Model:")[-1].strip()
        print(f" šŸ¤– Initial Answer: {text_initial}")
        print(f" šŸ“Š Modulation Norm: {mod_norm:.4f}")

        # --- Step B: "User" Correction (Simulated) ---
        print(f" šŸ’” Teaching: \"{case['correct_answer']}\"")

        # Prepare Training Data
        full_text_correct = f"User: {case['prompt']}\nModel: {case['correct_answer']}"
        inputs_correct = model.tokenizer(full_text_correct, return_tensors="pt").to(device)
        # NOTE(review): labels include the prompt tokens, so the prompt is
        # also trained on (no ignore_index masking) — confirm this is intended.
        labels = inputs_correct.input_ids.clone()

        # --- Step C: Test-Time Update (The Learning) ---
        model.train()
        print(f" 🧠 Adapting Weights (30 steps)...")
        for step in range(30):  # INCREASED STEPS AGAIN
            optimizer.zero_grad()

            # 1. Controller sees Prompt
            h_prompt = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
            mod_pred = model.controller(h_prompt)

            # 2. LLM sees Full Sequence (forced by mod_pred)
            logits = model(inputs_correct.input_ids, forced_modulation=mod_pred)

            # 3. Loss: standard next-token shift (predict token t+1 from t).
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = torch.nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

            loss.backward()
            optimizer.step()

            # Logging convergence
            if (step + 1) % 10 == 0:
                print(f" Step {step+1}: Loss = {loss.item():.4f}")

        # --- Step D: Verify Adaptation ---
        model.eval()
        with torch.no_grad():
            h_new = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
            mod_new = model.controller(h_new)

            model.set_active_modulation(mod_new)
            out_new = model.llm.generate(**inputs, max_new_tokens=60, do_sample=False)
            model.clear_modulation()

        text_new = model.tokenizer.decode(out_new[0], skip_special_tokens=True).split("Model:")[-1].strip()
        print(f" šŸŽ“ New Answer: {text_new}")
        print(f" šŸ“ˆ New Mod Norm: {mod_new.norm().item():.4f}")

    # 4. Save TTT Weights
    print("\nšŸ’¾ Saving Adapted Weights...")
    torch.save(model.controller.state_dict(), "ttt_physics_controller.pt")

    # Save Adapters
    adapter_states = [layer.state_dict() for layer in model.flux_layers]
    torch.save(adapter_states, "ttt_flux_adapters.pt")
    print("āœ… Saved to 'ttt_physics_controller.pt' and 'ttt_flux_adapters.pt'")


if __name__ == "__main__":
    run_auto_ttt()