|
|
|
|
|
import torch |
|
|
import os |
|
|
import json |
|
|
import logging |
|
|
import torch.optim as optim |
|
|
from config_physics import Config |
|
|
from modeling_physics_rl import PhysicsModel |
|
|
|
|
|
|
|
|
# Console logging at INFO with bare messages (no timestamps/levels).
logging.basicConfig(level=logging.INFO, format="%(message)s")


# Module-level logger. NOTE(review): currently unused — this script reports
# progress via print(); kept in case other code imports it.
logger = logging.getLogger(__name__)
|
|
|
|
|
def _find_checkpoint(search_dirs, filename):
    """Return the first existing ``dir/filename`` among *search_dirs*, else None."""
    for directory in search_dirs:
        candidate = os.path.join(directory, filename)
        if os.path.exists(candidate):
            return candidate
    return None


def run_auto_ttt():
    """Run automated Test-Time Training (TTT) over a fixed set of physics scenarios.

    For each scenario: generate the model's initial answer, fine-tune the
    controller and the flux-adapter modulation projections for 30 steps on the
    known-correct answer, then generate again with the adapted weights.
    Adapted weights are saved to ``ttt_physics_controller.pt`` and
    ``ttt_flux_adapters.pt`` in the working directory.
    """
    print("\n" + "=" * 50)
    print(" 🤖 DATA CENTER MODE: Automated TTT (Test-Time Training)")
    print("=" * 50)

    # --- Model setup -------------------------------------------------------
    print("⏳ Loading Physics Model...")
    model = PhysicsModel()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Checkpoints may live in the CWD or in the Kaggle input/working dirs.
    search_paths = [".", "/kaggle/input/worldmodels/physics_model", "/kaggle/working/physics_model"]

    # --- Load flux adapters (first match wins) -----------------------------
    fpath = _find_checkpoint(search_paths, "final_flux_adapters.pt")
    if fpath is not None:
        print(f"   Loading Flux Adapters from {fpath}...")
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoint files.
        adapter_states = torch.load(fpath, map_location=device)
        if isinstance(adapter_states, list):
            # One state dict per flux layer, in layer order.
            for layer, state in zip(model.flux_layers, adapter_states):
                layer.load_state_dict(state)
        elif isinstance(adapter_states, dict):
            # TODO(review): dict-format checkpoints are silently ignored here
            # (original behavior preserved) — confirm the expected on-disk format.
            pass

    # --- Load controller (first match wins) --------------------------------
    fpath = _find_checkpoint(search_paths, "final_physics_controller.pt")
    if fpath is not None:
        print(f"   Loading Controller from {fpath}...")
        model.controller.load_state_dict(torch.load(fpath, map_location=device))

    # --- Optimizer over the adaptable parameters only ----------------------
    # Only the controller and each flux layer's modulation projection are
    # trained at test time; the base LLM weights stay frozen.
    params = list(model.controller.parameters())
    for layer in model.flux_layers:
        params.extend(list(layer.modulation_proj.parameters()))
    optimizer = optim.AdamW(params, lr=1e-3)

    # Each case pairs a prompt with the physically correct answer that serves
    # as the TTT supervision target.
    test_cases = [
        {
            "scenario": "Zero Gravity",
            "prompt": "I drop a heavy hammer inside a space station. What happens?",
            "correct_answer": "The hammer floats in place. Inside a space station in orbit, objects are in freefall and appear weightless (microgravity). It does not fall to the floor."
        },
        {
            "scenario": "Moon Gravity",
            "prompt": "I drop a feather and a hammer on the Moon. Which hits the ground first?",
            "correct_answer": "They hit the ground at the same time. On the Moon, there is no air resistance, so gravity accelerates all objects at the same rate regardless of mass."
        },
        {
            "scenario": "Underwater",
            "prompt": "I release a helium balloon underwater. Which way does it go?",
            "correct_answer": "The balloon floats UP. The buoyant force from the water is greater than the weight of the balloon."
        }
    ]

    print(f"\n🚀 Starting Automation Loop ({len(test_cases)} scenarios)...")

    for i, case in enumerate(test_cases):
        print(f"\n--------------------------------------------------")
        print(f"📍 Scenario {i+1}: {case['scenario']}")
        print(f"   Question: \"{case['prompt']}\"")

        inputs = model.tokenizer(f"User: {case['prompt']}\nModel:", return_tensors="pt").to(device)

        # --- Pre-adaptation answer (pure inference) ------------------------
        # FIX: run under no_grad so the controller forward does not build an
        # autograd graph that is never backpropagated (now consistent with the
        # post-TTT eval pass below).
        with torch.no_grad():
            h_init = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
            modulation = model.controller(h_init)
            mod_norm = modulation.norm().item()
            model.set_active_modulation(modulation)
            out = model.llm.generate(**inputs, max_new_tokens=60, do_sample=False)
            model.clear_modulation()

        text_initial = model.tokenizer.decode(out[0], skip_special_tokens=True).split("Model:")[-1].strip()
        print(f"   🤖 Initial Answer: {text_initial}")
        print(f"   📊 Modulation Norm: {mod_norm:.4f}")

        print(f"   💡 Teaching: \"{case['correct_answer']}\"")

        # --- Supervision target: prompt + correct answer --------------------
        full_text_correct = f"User: {case['prompt']}\nModel: {case['correct_answer']}"
        inputs_correct = model.tokenizer(full_text_correct, return_tensors="pt").to(device)
        # NOTE(review): the loss below also covers the prompt tokens; masking
        # them with -100 is the usual alternative — confirm this is intentional.
        labels = inputs_correct.input_ids.clone()

        # --- 30 adaptation steps --------------------------------------------
        model.train()
        print(f"   🧠 Adapting Weights (30 steps)...")
        for step in range(30):
            optimizer.zero_grad()
            # Recompute the modulation from the (fixed) prompt each step so
            # the controller itself receives gradients.
            h_prompt = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
            mod_pred = model.controller(h_prompt)
            logits = model(inputs_correct.input_ids, forced_modulation=mod_pred)
            # Standard causal-LM shift: token t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = torch.nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )
            loss.backward()
            optimizer.step()

            if (step + 1) % 10 == 0:
                print(f"      Step {step+1}: Loss = {loss.item():.4f}")

        # --- Post-adaptation answer -----------------------------------------
        model.eval()
        with torch.no_grad():
            h_new = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
            mod_new = model.controller(h_new)
            model.set_active_modulation(mod_new)
            out_new = model.llm.generate(**inputs, max_new_tokens=60, do_sample=False)
            model.clear_modulation()

        text_new = model.tokenizer.decode(out_new[0], skip_special_tokens=True).split("Model:")[-1].strip()
        print(f"   📝 New Answer: {text_new}")
        print(f"   📊 New Mod Norm: {mod_new.norm().item():.4f}")

    # --- Persist adapted weights --------------------------------------------
    print("\n💾 Saving Adapted Weights...")
    torch.save(model.controller.state_dict(), "ttt_physics_controller.pt")
    adapter_states = [layer.state_dict() for layer in model.flux_layers]
    torch.save(adapter_states, "ttt_flux_adapters.pt")
    print("✅ Saved to 'ttt_physics_controller.pt' and 'ttt_flux_adapters.pt'")
|
|
|
|
|
if __name__ == "__main__": |
|
|
run_auto_ttt() |
|
|
|