File size: 6,888 Bytes
2243b52 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import torch
import os
import json
import logging
import torch.optim as optim
from config_physics import Config
from modeling_physics_rl import PhysicsModel
# Logging: emit bare messages (no level/name prefix) at INFO and above.
logging.basicConfig(format="%(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)
def _find_checkpoint(filename, search_paths):
    """Return the first existing path to *filename* among *search_paths*, else None."""
    for root in search_paths:
        candidate = os.path.join(root, filename)
        if os.path.exists(candidate):
            return candidate
    return None


def run_auto_ttt():
    """Run an automated Test-Time Training (TTT) demo over canned physics scenarios.

    For each scenario the model first answers the prompt as-is, is then
    fine-tuned for 30 AdamW steps on the scenario's correct answer
    (optimizing only the controller and the flux-layer modulation
    projections), and answers again so the adaptation can be inspected.
    Adapted weights are written to 'ttt_physics_controller.pt' and
    'ttt_flux_adapters.pt' in the working directory.
    """
    print("\n" + "="*50)
    print(" π€ DATA CENTER MODE: Automated TTT (Test-Time Training)")
    print("="*50)

    # 1. Load Model
    print("β³ Loading Physics Model...")
    model = PhysicsModel()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Checkpoints may live locally or on the Kaggle input/working mounts.
    search_paths = [".", "/kaggle/input/worldmodels/physics_model", "/kaggle/working/physics_model"]

    # Load Flux Adapters (first match on the search paths wins).
    adapter_path = _find_checkpoint("final_flux_adapters.pt", search_paths)
    if adapter_path is not None:
        print(f" Loading Flux Adapters from {adapter_path}...")
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoint files (consider weights_only=True if the format allows).
        adapter_states = torch.load(adapter_path, map_location=device)
        # Expected format: a list with one state_dict per flux layer.
        # A plain dict (a whole-model state_dict) is deliberately skipped,
        # matching the original "safe load" handling.
        if isinstance(adapter_states, list):
            for layer, state in zip(model.flux_layers, adapter_states):
                layer.load_state_dict(state)

    # Load Controller.
    controller_path = _find_checkpoint("final_physics_controller.pt", search_paths)
    if controller_path is not None:
        print(f" Loading Controller from {controller_path}...")
        model.controller.load_state_dict(torch.load(controller_path, map_location=device))

    # 2. Setup Meta-Optimizer (AdamW)
    # We optimize the Controller AND the Adapter Projections.
    params = list(model.controller.parameters())
    for layer in model.flux_layers:
        params.extend(layer.modulation_proj.parameters())
    optimizer = optim.AdamW(params, lr=1e-3)  # High LR for fast adaptation

    # 3. Define Test Cases (Scenario, Prompt, Correct Answer)
    test_cases = [
        {
            "scenario": "Zero Gravity",
            "prompt": "I drop a heavy hammer inside a space station. What happens?",
            "correct_answer": "The hammer floats in place. Inside a space station in orbit, objects are in freefall and appear weightless (microgravity). It does not fall to the floor."
        },
        {
            "scenario": "Moon Gravity",
            "prompt": "I drop a feather and a hammer on the Moon. Which hits the ground first?",
            "correct_answer": "They hit the ground at the same time. On the Moon, there is no air resistance, so gravity accelerates all objects at the same rate regardless of mass."
        },
        {
            "scenario": "Underwater",
            "prompt": "I release a helium balloon underwater. Which way does it go?",
            "correct_answer": "The balloon floats UP. The buoyant force from the water is greater than the weight of the balloon."
        }
    ]

    print(f"\nπ Starting Automation Loop ({len(test_cases)} scenarios)...")
    for i, case in enumerate(test_cases):
        print(f"\n--------------------------------------------------")
        print(f"π Scenario {i+1}: {case['scenario']}")
        print(f" Question: \"{case['prompt']}\"")

        # --- Step A: Initial Inference ---
        inputs = model.tokenizer(f"User: {case['prompt']}\nModel:", return_tensors="pt").to(device)
        # FIX: run the pre-adaptation pass under eval()/no_grad(); the original
        # built an autograd graph through the controller and generation here,
        # wasting memory (Step D already did this correctly).
        model.eval()
        with torch.no_grad():
            # Thinking (Dynamics Pass)
            h_init = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
            modulation = model.controller(h_init)
            mod_norm = modulation.norm().item()
            # Generate Text
            model.set_active_modulation(modulation)
            out = model.llm.generate(**inputs, max_new_tokens=60, do_sample=False)
            model.clear_modulation()
        text_initial = model.tokenizer.decode(out[0], skip_special_tokens=True).split("Model:")[-1].strip()
        print(f" π€ Initial Answer: {text_initial}")
        print(f" π Modulation Norm: {mod_norm:.4f}")

        # --- Step B: "User" Correction (Simulated) ---
        print(f" π‘ Teaching: \"{case['correct_answer']}\"")
        # Prepare Training Data
        full_text_correct = f"User: {case['prompt']}\nModel: {case['correct_answer']}"
        inputs_correct = model.tokenizer(full_text_correct, return_tensors="pt").to(device)
        # NOTE(review): labels include the prompt tokens, so the loss also trains
        # the model to reproduce the prompt; mask those positions with -100
        # (cross_entropy's ignore_index) if that is unintended — confirm.
        labels = inputs_correct.input_ids.clone()

        # --- Step C: Test-Time Update (The Learning) ---
        model.train()
        print(f" π§ Adapting Weights (30 steps)...")
        for step in range(30):
            optimizer.zero_grad()
            # 1. Controller sees the prompt only.
            h_prompt = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
            mod_pred = model.controller(h_prompt)
            # 2. LLM sees the full sequence, forced to use the predicted modulation.
            logits = model(inputs_correct.input_ids, forced_modulation=mod_pred)
            # 3. Next-token cross-entropy: shift logits/labels by one position.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = torch.nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )
            loss.backward()
            optimizer.step()
            # Logging convergence
            if (step + 1) % 10 == 0:
                print(f" Step {step+1}: Loss = {loss.item():.4f}")

        # --- Step D: Verify Adaptation ---
        model.eval()
        with torch.no_grad():
            h_new = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
            mod_new = model.controller(h_new)
            model.set_active_modulation(mod_new)
            out_new = model.llm.generate(**inputs, max_new_tokens=60, do_sample=False)
            model.clear_modulation()
        text_new = model.tokenizer.decode(out_new[0], skip_special_tokens=True).split("Model:")[-1].strip()
        print(f" π New Answer: {text_new}")
        print(f" π New Mod Norm: {mod_new.norm().item():.4f}")

    # 4. Save TTT Weights
    print("\nπΎ Saving Adapted Weights...")
    torch.save(model.controller.state_dict(), "ttt_physics_controller.pt")
    # Save Adapters
    adapter_states = [layer.state_dict() for layer in model.flux_layers]
    torch.save(adapter_states, "ttt_flux_adapters.pt")
    # FIX: this string literal was broken across two physical lines in the
    # original (a SyntaxError); rejoined onto one line.
    print("β Saved to 'ttt_physics_controller.pt' and 'ttt_flux_adapters.pt'")
# Script entry point: run the automated TTT loop when executed directly.
if __name__ == "__main__":
run_auto_ttt()
|