| | import torch
|
| | import os
|
| | import random
|
| | import logging
|
| | import torch.optim as optim
|
| | from config_physics import Config
|
| | from modeling_physics_rl import PhysicsModel
|
| |
|
| |
|
# Plain-message logging (no timestamps/levels) so progress output reads cleanly in a console/notebook.
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
| |
|
class StratifiedReplayBuffer:
    """Replay buffer that buckets memories by concept ID.

    Keeping one bucket per concept lets callers sample a *diverse* slice of
    history — every previously seen concept is represented — instead of a
    uniformly random one.
    """

    def __init__(self):
        # concept_id -> list of {"prompt": ..., "answer": ...} dicts
        self.memory = {}

    def add(self, concept_id, prompt, answer):
        """Append one (prompt, answer) pair to its concept's bucket."""
        bucket = self.memory.setdefault(concept_id, [])
        bucket.append({"prompt": prompt, "answer": answer})

    def sample_stratified(self, current_concept_id, n_per_concept=1):
        """Sample up to ``n_per_concept`` items from every stored concept
        except the current one (which is trained on separately).

        This forces the model to face its 'confusing neighbors'. Returns an
        empty list when no other concept has been seen yet.
        """
        past_concepts = [cid for cid in self.memory if cid != current_concept_id]
        if not past_concepts:
            return []

        batch = []
        for cid in past_concepts:
            bucket = self.memory[cid]
            take = min(len(bucket), n_per_concept)
            batch.extend(random.sample(bucket, take))
        return batch
|
| |
|
def run_cumulative_ttt():
    """Cumulative test-time-training loop over a physics curriculum.

    For each concept: trains the controller + flux modulation projections on
    the current concept's variations while replaying a stratified sample of
    every past concept, with an L2 drift penalty anchoring the controller to
    its pre-TTT weights. Ends with a generation-based exam over all concepts
    and saves the adapted weights to disk.

    Assumes PhysicsModel exposes: flux_layers, controller, walt, tokenizer,
    llm, get_embeddings, set_active_modulation, clear_modulation — per
    modeling_physics_rl.
    """
    print("\n" + "="*60)
    print(" π§ CUMULATIVE TTT: Fixed LR + Stratified Replay")
    print("="*60)

    print("β³ Loading Physics Model...")
    model = PhysicsModel()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Checkpoint discovery: current dir first, then the usual Kaggle mounts.
    search_paths = [".", "/kaggle/input/worldmodels/physics_model", "/kaggle/working/physics_model"]
    loaded = False
    for p in search_paths:
        fpath = os.path.join(p, "final_flux_adapters.pt")
        if os.path.exists(fpath):
            print(f" Loading Adapters from {fpath}")
            states = torch.load(fpath, map_location=device)
            # Adapters are saved as a list of per-layer state dicts (see save
            # logic at the bottom of this function).
            if isinstance(states, list):
                for layer, state in zip(model.flux_layers, states):
                    layer.load_state_dict(state)
            loaded = True

            # NOTE(review): controller/WALT loading and the break are treated
            # as nested under the adapter hit, i.e. companion checkpoints are
            # only read from the same directory — confirm against original
            # (source indentation was mangled).
            c_path = os.path.join(p, "final_physics_controller.pt")
            if os.path.exists(c_path):
                print(f" Loading Controller from {c_path}")
                model.controller.load_state_dict(torch.load(c_path, map_location=device))

            w_path = os.path.join(p, "final_walt_head.pt")
            if os.path.exists(w_path):
                print(f" Loading WALT Head from {w_path}")
                model.walt.load_state_dict(torch.load(w_path, map_location=device))

            break

    if not loaded:
        print("β οΈ WARNING: No pre-trained adapters found. Starting from scratch.")

    # Only the controller and each flux layer's modulation projection are
    # optimized; everything else in the model stays frozen.
    params = list(model.controller.parameters())
    for layer in model.flux_layers:
        params.extend(list(layer.modulation_proj.parameters()))

    optimizer = optim.AdamW(params, lr=1e-3)

    replay_buffer = StratifiedReplayBuffer()

    # Curriculum: one entry per physics concept. "prompt"/"correction" is the
    # canonical teaching pair; "test_variations" are paraphrases used both for
    # training diversity and for the final exam. (Typos in the question
    # strings are part of the data as-is.)
    curriculum = [
        {
            "id": "scenario_1",
            "concept": "Zero Gravity Inertia",
            "prompt": "I release a heavy chrome dumbbell inside the International Space Station. What happens to it?",
            "correction": "The dumbbell floats in mid-air because it is weightless in the microgravity environment.",
            "test_variations": [
                {"q": "I gentle place a sandwich in the middle of the cabin air. Does it fall?", "a": "No, it hovers in place due to weightlessness."},
                {"q": "I squeeze water out of a pouch and it forms a sphere. Does the sphere drop to the deck?", "a": "No, the water sphere floats in the air."},
                {"q": "If I jump off the 'floor' of the ISS, do I come back down?", "a": "No, you continue floating until you hit the opposite wall."}
            ]
        },
        {
            "id": "scenario_2",
            "concept": "Vacuum Gravity (Galileo)",
            "prompt": "I drop a heavy anvil and a light goose feather in a vacuum chamber. Which lands first?",
            "correction": "They land at the exact same time. Without air resistance, gravity accelerates all masses equally.",
            "test_variations": [
                {"q": "Two spheres, one made of solid lead and one of styrofoam, are released in a vacuum tube. Result?", "a": "They fall and hit the bottom at the exact same instant."},
                {"q": "On a planet with no atmosphere, I drop a cannonball and a glass marble. Comparte their speed.", "a": "Their speed increases at the same rate; they remain side-by-side."},
                {"q": "A massive steel beam drops next to a tiny screw in vacuum. Race result?", "a": "Tie. They crash down at the same moment."}
            ]
        },
        {
            "id": "scenario_4",
            "concept": "Initial Velocity vs Gravity",
            "prompt": "If I throw a coin DOWNWARDS and drop a feather at the moon at the same time, which hits the floor first?",
            "correction": "The coin hits first. Physics: Distance = V_initial*t + 0.5*g*t^2. The coin has V_initial > 0, so it covers the distance faster than the feather (V_initial=0).",
            "test_variations": [
                {"q": "I slam a volleyball downwards and drop a bowling ball in vacuum. Winner?", "a": "The volleyball hits first. It starts with extra speed ($v_0 > 0$), unlike the dropped ball."},
                {"q": "Race: Bullet fired down vs Bullet dropped. Who wins?", "a": "The fired bullet wins. Muzzle velocity adds to gravity."},
                {"q": "If I toss a rock DOWN and drop a feather on Moon?", "a": "The rock hits first. Throwing it adds kinetic energy and speed right at the start."}
            ]
        },
        {
            "id": "scenario_3",
            "concept": "Buoyancy",
            "prompt": "I hold a large, air-filled beach ball at the bottom of a pool and release it. What is its motion?",
            "correction": "The ball accelerates rapidly upward due to the strong buoyant force of the water.",
            "test_variations": [
                {"q": "A submarine blows its ballast tanks while deep underwater. What is the immediate effect?", "a": "The submarine rises towards the surface."},
                {"q": "I force a block of white Styrofoam deep underwater and remove my hand. What happens?", "a": "It shoots upward to the surface immediately."},
                {"q": "What happens to the air bubbles exhaled by a scuba diver?", "a": "The bubbles float upward through the water column."}
            ]
        }
    ]

    print(f"\nπ Starting Cumulative Loop ({len(curriculum)} concepts)...")

    # Snapshot of the pre-TTT controller weights; used as the anchor for the
    # L2 drift penalty inside the training loop.
    initial_controller_state = {k: v.clone() for k, v in model.controller.state_dict().items()}

    for stage, task in enumerate(curriculum):
        print(f"\n" + "-"*50)
        print(f"π STAGE {stage+1}: {task['concept']}")

        # Register this concept's data in the replay buffer up front;
        # sample_stratified excludes the current id, so these entries only
        # get replayed in *later* stages.
        for v in task['test_variations']:
            replay_buffer.add(task['id'], v['q'], v['a'])
        replay_buffer.add(task['id'], task['prompt'], task['correction'])

        print(" π§ Robust Learning (Current + Stratified History)...")
        model.train()

        # Train on variations AND the canonical prompt of the current concept.
        main_prompt_dict = {"q": task['prompt'], "a": task['correction']}
        current_variations = task['test_variations'] + [main_prompt_dict]

        for step in range(300):
            optimizer.zero_grad()
            total_loss = 0

            # 2 random examples of the current concept per step.
            v_current = random.sample(current_variations, 2)
            for v in v_current:
                loss_new = calculate_loss(model, v['q'], v['a'], device)
                total_loss += loss_new * 1.0

            # Stratified replay: up to 2 examples from EVERY past concept,
            # weighted equally with the current-concept loss.
            past_memories = replay_buffer.sample_stratified(task['id'], n_per_concept=2)

            if past_memories:
                for mem in past_memories:
                    loss_replay = calculate_loss(model, mem['prompt'], mem['answer'], device)

                    total_loss += loss_replay * 1.0

            # L2 drift penalty: keep the controller near its pre-TTT weights
            # to limit catastrophic forgetting.
            drift_loss = 0
            for name, param in model.controller.named_parameters():
                drift_loss += torch.sum((param - initial_controller_state[name].to(device)) ** 2)

            total_loss += drift_loss * 0.01

            total_loss.backward()
            optimizer.step()

            if (step+1) % 50 == 0:
                # "Task" = total minus the drift component, i.e. pure LM loss.
                print(f" Step {step+1}: Total {total_loss.item():.4f} | Task {total_loss.item() - drift_loss.item()*0.01:.4f}")
                print(f" > Learning: \"{v_current[0]['q'][:50]}...\"")
                if past_memories:
                    print(f" > Replaying: \"{past_memories[0]['prompt'][:50]}...\"")

    print("\n" + "="*60)
    print(" π§ͺ FINAL EXAM: Testing Generalization & Retention")
    print("="*60)

    # Exam: generate an answer for every variation of every concept and grade
    # it with the keyword checker. Note do_sample=True makes this stochastic.
    model.eval()
    for task in curriculum:
        print(f"\nπ Concept: {task['concept']}")
        score = 0
        total = len(task['test_variations'])

        for item in task['test_variations']:
            q = item['q']
            target = item['a']  # NOTE(review): unused here; grading is keyword-based via check_answer.
            print(f" Q: \"{q}\"")
            inputs = model.tokenizer(f"User: {q}\nModel:", return_tensors="pt").to(device)
            with torch.no_grad():
                # Predict modulation from the prompt, apply it for generation,
                # then clear it so state doesn't leak into the next question.
                h_init = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
                mod = model.controller(h_init)
                model.set_active_modulation(mod)
                out = model.llm.generate(
                    **inputs,
                    max_new_tokens=60,
                    do_sample=True,
                    temperature=0.6,
                    top_p=0.9,
                    repetition_penalty=1.2
                )
                model.clear_modulation()
            ans = model.tokenizer.decode(out[0], skip_special_tokens=True).split("Model:")[-1].strip()
            print(f" π€ Ans: {ans}")

            if check_answer(task['id'], ans):
                print(" ✅ PASS")
                score += 1
            else:
                print(" β FAIL")

        print(f" π Score: {score}/{total}")
        if score == total: print(" π PERFECT")
        elif score == 0: print(" π FAILED")

    print("\nπΎ Saving Final TTT Model...")
    torch.save(model.controller.state_dict(), "final_physics_controller.pt")

    # flux_layers may be a plain Python list (save per-layer state dicts as a
    # list — matches the loader above) or a single module with its own
    # state_dict.
    if isinstance(model.flux_layers, list):
        adapter_states = [layer.state_dict() for layer in model.flux_layers]
        torch.save(adapter_states, "final_flux_adapters.pt")
    else:
        torch.save(model.flux_layers.state_dict(), "final_flux_adapters.pt")

    print("✅ Model Saved: final_physics_controller.pt, final_flux_adapters.pt")
|
| |
|
def calculate_loss(model, prompt, answer, device):
    """Next-token cross-entropy for one (prompt, answer) pair.

    Formats the pair as a "User: ... / Model: ..." transcript, predicts a
    modulation signal from the embedded sequence via the controller, runs the
    modulated forward pass, and scores it with a standard causal-LM shift.

    NOTE(review): the loss covers the full sequence (prompt tokens included),
    not only the answer span.
    """
    transcript = f"User: {prompt}\nModel: {answer}"
    encoded = model.tokenizer(transcript, return_tensors="pt").to(device)
    token_ids = encoded.input_ids

    # Controller predicts the modulation from the raw embeddings, then the
    # model forward pass is forced to use that modulation.
    embeddings = model.get_embeddings(token_ids).to(Config.DTYPE)
    modulation = model.controller(embeddings)
    logits = model(token_ids, forced_modulation=modulation)

    # Causal-LM shift: position t predicts token t+1.
    pred = logits[..., :-1, :].contiguous()
    gold = token_ids[..., 1:].contiguous()
    return torch.nn.functional.cross_entropy(pred.view(-1, pred.size(-1)), gold.view(-1))
|
| |
|
def check_answer(task_id, text):
    """Keyword-based grader for exam answers.

    Returns True when `text` (case-insensitive) contains any accepted
    substring for the given scenario id; False for unknown ids.

    NOTE(review): these are raw substring checks — e.g. "up" also matches
    "cup" — inherited from the original keyword design.
    """
    accepted = {
        "scenario_1": ("float", "hover", "drift"),
        "scenario_2": ("same time", "equal", "identical", "same rate", "neither", "instant", "side-by-side"),
        "scenario_3": ("up", "rise", "float"),
        "scenario_4": ("coin", "thrown", "initial", "toss", "rock", "bullet", "volleyball"),
    }
    lowered = text.lower()
    return any(keyword in lowered for keyword in accepted.get(task_id, ()))
|
| |
|
| | if __name__ == "__main__":
|
| | run_cumulative_ttt() |