# flux-test-time-training / continuous_learning_auto.py
# Uploaded with huggingface_hub (commit 2243b52, verified) by convaiinnovations.
import torch
import os
import json
import logging
import torch.optim as optim
from config_physics import Config
from modeling_physics_rl import PhysicsModel
# Setup Logging
# Module-level logger configured to emit bare messages at INFO level.
# NOTE(review): the visible code below uses print() exclusively; `logger`
# is defined here but never used in this file — kept for compatibility in
# case other modules rely on the root-level basicConfig side effect.
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)
def run_auto_ttt():
print("\n" + "="*50)
print(" πŸ€– DATA CENTER MODE: Automated TTT (Test-Time Training)")
print("="*50)
# 1. Load Model
print("⏳ Loading Physics Model...")
model = PhysicsModel()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Load Adapters (Generic Path Logic)
search_paths = [".", "/kaggle/input/worldmodels/physics_model", "/kaggle/working/physics_model"]
for p in search_paths:
fpath = os.path.join(p, "final_flux_adapters.pt")
if os.path.exists(fpath):
print(f" Loading Flux Adapters from {fpath}...")
adapter_states = torch.load(fpath, map_location=device)
# Handle list vs dict (safe load)
if isinstance(adapter_states, dict):
# If it's a state_dict of the whole model (rare but possible)
pass
elif isinstance(adapter_states, list):
for layer, state in zip(model.flux_layers, adapter_states):
layer.load_state_dict(state)
break
# Load Controller
for p in search_paths:
fpath = os.path.join(p, "final_physics_controller.pt")
if os.path.exists(fpath):
print(f" Loading Controller from {fpath}...")
model.controller.load_state_dict(torch.load(fpath, map_location=device))
break
# 2. Setup Meta-Optimizer (AdamW)
# We optimize the Controller AND the Adapter Projections
params = list(model.controller.parameters())
for layer in model.flux_layers:
params.extend(list(layer.modulation_proj.parameters()))
optimizer = optim.AdamW(params, lr=1e-3) # High LR for fast adaptation
# 3. Define Test Cases (Scenario, Prompt, Correct Answer)
test_cases = [
{
"scenario": "Zero Gravity",
"prompt": "I drop a heavy hammer inside a space station. What happens?",
"correct_answer": "The hammer floats in place. Inside a space station in orbit, objects are in freefall and appear weightless (microgravity). It does not fall to the floor."
},
{
"scenario": "Moon Gravity",
"prompt": "I drop a feather and a hammer on the Moon. Which hits the ground first?",
"correct_answer": "They hit the ground at the same time. On the Moon, there is no air resistance, so gravity accelerates all objects at the same rate regardless of mass."
},
{
"scenario": "Underwater",
"prompt": "I release a helium balloon underwater. Which way does it go?",
"correct_answer": "The balloon floats UP. The buoyant force from the water is greater than the weight of the balloon."
}
]
print(f"\nπŸš€ Starting Automation Loop ({len(test_cases)} scenarios)...")
for i, case in enumerate(test_cases):
print(f"\n--------------------------------------------------")
print(f"πŸ“ Scenario {i+1}: {case['scenario']}")
print(f" Question: \"{case['prompt']}\"")
# --- Step A: Initial Inference ---
inputs = model.tokenizer(f"User: {case['prompt']}\nModel:", return_tensors="pt").to(device)
# Thinking (Dynamics Pass)
h_init = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
modulation = model.controller(h_init)
mod_norm = modulation.norm().item()
# Generate Text
model.set_active_modulation(modulation)
out = model.llm.generate(**inputs, max_new_tokens=60, do_sample=False)
model.clear_modulation()
text_initial = model.tokenizer.decode(out[0], skip_special_tokens=True).split("Model:")[-1].strip()
print(f" πŸ€– Initial Answer: {text_initial}")
print(f" πŸ“Š Modulation Norm: {mod_norm:.4f}")
# --- Step B: "User" Correction (Simulated) ---
print(f" πŸ’‘ Teaching: \"{case['correct_answer']}\"")
# Prepare Training Data
full_text_correct = f"User: {case['prompt']}\nModel: {case['correct_answer']}"
inputs_correct = model.tokenizer(full_text_correct, return_tensors="pt").to(device)
labels = inputs_correct.input_ids.clone()
# --- Step C: Test-Time Update (The Learning) ---
model.train()
print(f" 🧠 Adapting Weights (30 steps)...")
for step in range(30): # INCREASED STEPS AGAIN
optimizer.zero_grad()
# ... (Forward/Backward logic remains same) ...
# 1. Controller sees Prompt
h_prompt = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
mod_pred = model.controller(h_prompt)
# 2. LLM sees Full Sequence (forced by mod_pred)
logits = model(inputs_correct.input_ids, forced_modulation=mod_pred)
# 3. Loss
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = torch.nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1)
)
loss.backward()
optimizer.step()
# Logging convergence
if (step + 1) % 10 == 0:
print(f" Step {step+1}: Loss = {loss.item():.4f}")
# --- Step D: Verify Adaptation ---
model.eval()
with torch.no_grad():
h_new = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
mod_new = model.controller(h_new)
model.set_active_modulation(mod_new)
out_new = model.llm.generate(**inputs, max_new_tokens=60, do_sample=False)
model.clear_modulation()
text_new = model.tokenizer.decode(out_new[0], skip_special_tokens=True).split("Model:")[-1].strip()
print(f" πŸŽ“ New Answer: {text_new}")
print(f" πŸ“ˆ New Mod Norm: {mod_new.norm().item():.4f}")
# 4. Save TTT Weights
print("\nπŸ’Ύ Saving Adapted Weights...")
torch.save(model.controller.state_dict(), "ttt_physics_controller.pt")
# Save Adapters
adapter_states = [layer.state_dict() for layer in model.flux_layers]
torch.save(adapter_states, "ttt_flux_adapters.pt")
print("βœ… Saved to 'ttt_physics_controller.pt' and 'ttt_flux_adapters.pt'")
# Script entry point: run the automated TTT loop when executed directly.
if __name__ == "__main__":
    run_auto_ttt()