import torch
import os
import random
import logging
import torch.optim as optim
from config_physics import Config
from modeling_physics_rl import PhysicsModel

# Setup Logging (note: the script currently reports via print(); the logger is
# configured but unused — kept for compatibility with sibling scripts).
logging.basicConfig(level=logging.INFO, format="%(message)s")


class StratifiedReplayBuffer:
    """
    Stores memories by Concept ID to ensure we sample DIVERSE history,
    not just random history.
    """

    def __init__(self):
        # Maps concept_id -> list of {"prompt": str, "answer": str} dicts.
        self.memory = {}

    def add(self, concept_id, prompt, answer):
        """Append one (prompt, answer) example under its concept bucket."""
        if concept_id not in self.memory:
            self.memory[concept_id] = []
        self.memory[concept_id].append({"prompt": prompt, "answer": answer})

    def sample_stratified(self, current_concept_id, n_per_concept=1):
        """
        Returns a batch containing examples from ALL previous concepts
        EXCEPT the current one (which is handled separately).
        This forces the model to face its 'confusing neighbors'.
        """
        batch = []
        past_concepts = [cid for cid in self.memory.keys() if cid != current_concept_id]
        if not past_concepts:
            return []
        for cid in past_concepts:
            # Grab 'n' random examples from this specific concept
            samples = random.sample(
                self.memory[cid], min(len(self.memory[cid]), n_per_concept)
            )
            batch.extend(samples)
        return batch


def run_cumulative_ttt():
    """
    Cumulative test-time-training loop: learns a curriculum of physics
    concepts one stage at a time, replaying stratified samples of all
    earlier concepts to fight catastrophic forgetting, then runs a final
    generalization/retention exam and saves the adapted weights.
    """
    print("\n" + "=" * 60)
    print(" 🧠 CUMULATIVE TTT: Fixed LR + Stratified Replay")
    print("=" * 60)

    # 1. Load Model
    print("⏳ Loading Physics Model...")
    model = PhysicsModel()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Load Pre-trained Weights
    search_paths = [".", "/kaggle/input/worldmodels/physics_model", "/kaggle/working/physics_model"]
    loaded = False
    for p in search_paths:
        fpath = os.path.join(p, "final_flux_adapters.pt")
        if os.path.exists(fpath):
            print(f"   Loading Adapters from {fpath}")
            # NOTE(review): torch.load without weights_only=True will unpickle
            # arbitrary objects — fine for our own checkpoints, but confirm the
            # files are trusted before running on shared storage.
            states = torch.load(fpath, map_location=device)
            if isinstance(states, list):
                for layer, state in zip(model.flux_layers, states):
                    layer.load_state_dict(state)
            loaded = True
            # Also try to load Controller and WALT Head if they exist in the same dir
            c_path = os.path.join(p, "final_physics_controller.pt")
            if os.path.exists(c_path):
                print(f"   Loading Controller from {c_path}")
                model.controller.load_state_dict(torch.load(c_path, map_location=device))
            w_path = os.path.join(p, "final_walt_head.pt")
            if os.path.exists(w_path):
                print(f"   Loading WALT Head from {w_path}")
                model.walt.load_state_dict(torch.load(w_path, map_location=device))
            break
    if not loaded:
        print("⚠️ WARNING: No pre-trained adapters found. Starting from scratch.")

    # 2. Optimization Setup (FIXED)
    # Only the controller and each flux layer's modulation projection are
    # trained; the base LLM stays frozen.
    params = list(model.controller.parameters())
    for layer in model.flux_layers:
        params.extend(list(layer.modulation_proj.parameters()))

    # FIX 1: Restore High Learning Rate for TTT (Plasticity)
    # 5e-5 was too slow. We need aggressive updates.
    optimizer = optim.AdamW(params, lr=1e-3)
    replay_buffer = StratifiedReplayBuffer()

    # 3. Curriculum — each stage: one concept, one canonical correction, and
    # three held-out phrasings used for both training variation and the exam.
    curriculum = [
        {
            "id": "scenario_1",
            "concept": "Zero Gravity Inertia",
            "prompt": "I release a heavy chrome dumbbell inside the International Space Station. What happens to it?",
            "correction": "The dumbbell floats in mid-air because it is weightless in the microgravity environment.",
            "test_variations": [
                {"q": "I gentle place a sandwich in the middle of the cabin air. Does it fall?", "a": "No, it hovers in place due to weightlessness."},
                {"q": "I squeeze water out of a pouch and it forms a sphere. Does the sphere drop to the deck?", "a": "No, the water sphere floats in the air."},
                {"q": "If I jump off the 'floor' of the ISS, do I come back down?", "a": "No, you continue floating until you hit the opposite wall."},
            ],
        },
        {
            "id": "scenario_2",
            "concept": "Vacuum Gravity (Galileo)",
            "prompt": "I drop a heavy anvil and a light goose feather in a vacuum chamber. Which lands first?",
            "correction": "They land at the exact same time. Without air resistance, gravity accelerates all masses equally.",
            "test_variations": [
                {"q": "Two spheres, one made of solid lead and one of styrofoam, are released in a vacuum tube. Result?", "a": "They fall and hit the bottom at the exact same instant."},
                {"q": "On a planet with no atmosphere, I drop a cannonball and a glass marble. Comparte their speed.", "a": "Their speed increases at the same rate; they remain side-by-side."},
                {"q": "A massive steel beam drops next to a tiny screw in vacuum. Race result?", "a": "Tie. They crash down at the same moment."},
            ],
        },
        {
            "id": "scenario_4",
            "concept": "Initial Velocity vs Gravity",
            "prompt": "If I throw a coin DOWNWARDS and drop a feather at the moon at the same time, which hits the floor first?",
            "correction": "The coin hits first. Physics: Distance = V_initial*t + 0.5*g*t^2. The coin has V_initial > 0, so it covers the distance faster than the feather (V_initial=0).",
            "test_variations": [
                {"q": "I slam a volleyball downwards and drop a bowling ball in vacuum. Winner?", "a": "The volleyball hits first. It starts with extra speed ($v_0 > 0$), unlike the dropped ball."},
                {"q": "Race: Bullet fired down vs Bullet dropped. Who wins?", "a": "The fired bullet wins. Muzzle velocity adds to gravity."},
                {"q": "If I toss a rock DOWN and drop a feather on Moon?", "a": "The rock hits first. Throwing it adds kinetic energy and speed right at the start."},
            ],
        },
        {
            "id": "scenario_3",
            "concept": "Buoyancy",
            "prompt": "I hold a large, air-filled beach ball at the bottom of a pool and release it. What is its motion?",
            "correction": "The ball accelerates rapidly upward due to the strong buoyant force of the water.",
            "test_variations": [
                {"q": "A submarine blows its ballast tanks while deep underwater. What is the immediate effect?", "a": "The submarine rises towards the surface."},
                {"q": "I force a block of white Styrofoam deep underwater and remove my hand. What happens?", "a": "It shoots upward to the surface immediately."},
                {"q": "What happens to the air bubbles exhaled by a scuba diver?", "a": "The bubbles float upward through the water column."},
            ],
        },
    ]

    print(f"\n🚀 Starting Cumulative Loop ({len(curriculum)} concepts)...")

    # Keep copy of initial controller for drift regularization
    initial_controller_state = {k: v.clone() for k, v in model.controller.state_dict().items()}

    for stage, task in enumerate(curriculum):
        print(f"\n" + "-" * 50)
        print(f"📍 STAGE {stage+1}: {task['concept']}")

        # Add to Replay Buffer Categorically
        for v in task['test_variations']:
            replay_buffer.add(task['id'], v['q'], v['a'])
        replay_buffer.add(task['id'], task['prompt'], task['correction'])

        print("   🧠 Robust Learning (Current + Stratified History)...")
        model.train()

        # Helper to format main prompt as dict for unification
        main_prompt_dict = {"q": task['prompt'], "a": task['correction']}
        current_variations = task['test_variations'] + [main_prompt_dict]

        # 300 inner steps per stage (enough for the adapters to converge)
        for step in range(300):
            optimizer.zero_grad()
            total_loss = 0

            # --- A. Current Task (Plasticity) ---
            # Pick 2 random variations of the current concept
            v_current = random.sample(current_variations, 2)
            for v in v_current:
                loss_new = calculate_loss(model, v['q'], v['a'], device)
                total_loss += loss_new * 1.0  # Focus on learning new thing

            # --- B. Stratified Replay (Stability) ---
            # FIX 2: Force sample from EVERY past concept
            # This ensures "Vacuum" batch ALWAYS includes a "Zero-G" example
            past_memories = replay_buffer.sample_stratified(task['id'], n_per_concept=2)
            if past_memories:
                for mem in past_memories:
                    loss_replay = calculate_loss(model, mem['prompt'], mem['answer'], device)
                    # High weight on replay to fight forgetting
                    total_loss += loss_replay * 1.0

            # --- C. Anti-Drift Regularization ---
            # Keep controller weights close to initial sanity if possible
            drift_loss = 0
            for name, param in model.controller.named_parameters():
                drift_loss += torch.sum((param - initial_controller_state[name].to(device)) ** 2)
            # Reduced from 0.1 to 0.01
            total_loss += drift_loss * 0.01

            total_loss.backward()
            optimizer.step()

            if (step + 1) % 50 == 0:
                # Detailed Logging
                print(f"   Step {step+1}: Total {total_loss.item():.4f} | Task {total_loss.item() - drift_loss.item()*0.01:.4f}")
                print(f"     > Learning: \"{v_current[0]['q'][:50]}...\"")
                if past_memories:
                    print(f"     > Replaying: \"{past_memories[0]['prompt'][:50]}...\"")

    print("\n" + "=" * 60)
    print(" 🧪 FINAL EXAM: Testing Generalization & Retention")
    print("=" * 60)
    model.eval()

    for task in curriculum:
        print(f"\n🔍 Concept: {task['concept']}")
        score = 0
        total = len(task['test_variations'])
        for item in task['test_variations']:
            q = item['q']
            target = item['a']
            print(f"   Q: \"{q}\"")
            inputs = model.tokenizer(f"User: {q}\nModel:", return_tensors="pt").to(device)
            with torch.no_grad():
                # Condition generation on the controller's modulation for this
                # prompt, then clear it so state doesn't leak between questions.
                h_init = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
                mod = model.controller(h_init)
                model.set_active_modulation(mod)
                out = model.llm.generate(
                    **inputs,
                    max_new_tokens=60,
                    do_sample=True,          # Enable sampling to break loops
                    temperature=0.6,         # Low temp for precision
                    top_p=0.9,
                    repetition_penalty=1.2,  # CRITICAL: Penalizes "The balloon floats... The balloon floats..."
                )
                model.clear_modulation()
            ans = model.tokenizer.decode(out[0], skip_special_tokens=True).split("Model:")[-1].strip()
            print(f"   🤖 Ans: {ans}")
            # Simple heuristic check
            if check_answer(task['id'], ans):
                print("   ✅ PASS")
                score += 1
            else:
                print("   ❌ FAIL")
        print(f"   👉 Score: {score}/{total}")
        if score == total:
            print("   🌟 PERFECT")
        elif score == 0:
            print("   💀 FAILED")

    print("\n💾 Saving Final TTT Model...")
    torch.save(model.controller.state_dict(), "final_physics_controller.pt")
    # Handle Flux Layers (List or ModuleList): a plain list has no collective
    # state_dict, so save a list of per-layer state dicts for consistency with
    # the loader above.
    if isinstance(model.flux_layers, list):
        adapter_states = [layer.state_dict() for layer in model.flux_layers]
        torch.save(adapter_states, "final_flux_adapters.pt")
    else:
        torch.save(model.flux_layers.state_dict(), "final_flux_adapters.pt")
    print("✅ Model Saved: final_physics_controller.pt, final_flux_adapters.pt")


def calculate_loss(model, prompt, answer, device):
    """
    Causal-LM cross-entropy over the full "User: ...\\nModel: ..." text,
    with the controller's modulation forced into the forward pass.
    Note: loss covers the prompt tokens too, not just the answer.
    """
    full_text = f"User: {prompt}\nModel: {answer}"
    inputs = model.tokenizer(full_text, return_tensors="pt").to(device)
    # Forward Pass
    h_prompt = model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
    mod_pred = model.controller(h_prompt)
    logits = model(inputs.input_ids, forced_modulation=mod_pred)
    # Standard next-token shift: predict token t+1 from logits at position t.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = inputs.input_ids[..., 1:].contiguous()
    loss = torch.nn.functional.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
    )
    return loss


def check_answer(task_id, text):
    """Keyword heuristic: does the (lowercased) answer mention the expected physics?"""
    text = text.lower()
    if task_id == "scenario_1":
        return "float" in text or "hover" in text or "drift" in text
    if task_id == "scenario_2":
        return ("same time" in text or "equal" in text or "identical" in text
                or "same rate" in text or "neither" in text or "instant" in text
                or "side-by-side" in text)
    if task_id == "scenario_3":
        return "up" in text or "rise" in text or "float" in text
    if task_id == "scenario_4":
        return ("coin" in text or "thrown" in text or "initial" in text
                or "toss" in text or "rock" in text or "bullet" in text
                or "volleyball" in text)
    return False


if __name__ == "__main__":
    run_cumulative_ttt()