import random
import logging
import os
import gc

# Optimize CUDA memory allocation to reduce fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
import torch.nn as nn
import torch.optim as optim

from modeling_physics_rl import PhysicsModel, Config


class StratifiedReplayBuffer:
    """
    Stores memories bucketed by Concept ID (or just generic 'user_taught')
    to ensure we sample DIVERSE history during replay, instead of only the
    most recently taught concept.
    """

    def __init__(self):
        # { "concept_id": [ {"prompt": str, "answer": str}, ... ] }
        self.memory = {}
        # Pre-fill with Anchor Memories to prevent Cold-Start Catastrophic Forgetting
        self._add_anchor_memories()

    def _add_anchor_memories(self):
        """Seed the buffer with general-knowledge Q/A pairs used as stability anchors."""
        anchors = [
            ("What is gravity?", "Gravity is a fundamental interaction which causes mutual attraction between all things with mass or energy."),
            ("Hello", "Hello! How can I help you today?"),
            ("What is AI?", "Artificial Intelligence (AI) refers to the simulation of human intelligence in machines."),
            ("Define thermodynamics.", "Thermodynamics is a branch of physics that deals with heat, work, and temperature, and their relation to energy, entropy, and the physical properties of matter."),
            ("Who are you?", "I am a large language model, trained by Google.")
        ]
        self.memory["anchor"] = [{"prompt": q, "answer": a} for q, a in anchors]
        print(f" โš“ Added {len(anchors)} General Knowledge Anchors to Replay Buffer.")

    def add(self, concept_id, prompt, answer):
        """Store one (prompt, answer) pair under the given concept bucket."""
        self.memory.setdefault(concept_id, []).append({"prompt": prompt, "answer": answer})

    def sample_stratified(self, current_concept_id, n_per_concept=1):
        """
        Sample up to `n_per_concept` memories from EVERY concept except the one
        currently being learned.

        Returns a flat list of {"prompt", "answer"} dicts; [] when there is no
        past concept to replay from.
        """
        past_concepts = [cid for cid in self.memory if cid != current_concept_id]
        if not past_concepts:
            return []
        batch = []
        for cid in past_concepts:
            bucket = self.memory[cid]
            batch.extend(random.sample(bucket, min(len(bucket), n_per_concept)))
        return batch


class ContinuousLearningSession:
    """
    Interactive continuous-learning (test-time training) loop around a
    PhysicsModel: the frozen backbone LLM answers questions, while a small
    Controller + Flux Adapters are updated online from user corrections,
    stabilized by a stratified replay buffer and an anti-drift anchor.
    """

    def __init__(self):
        print("๐Ÿง  Initializing Continuous Learning Session...")

        # 1. Load Model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f" ๐Ÿš€ Using Device: {self.device}")
        self.model = PhysicsModel()
        self.model.to(self.device)  # Force move to GPU

        # 2. Load Pre-trained Weights
        self._load_weights()

        # 3. Session Memory (Context Window)
        # Kept for compatibility, but context-modulation blending is currently
        # disabled (see predict()); these are never read.
        self.session_context = []       # List of (input, modulation) pairs
        self.context_modulation = None  # Accumulated modulation bias

        # 4. Ensure backbone is frozen, but Controller & Adapters are TRAINABLE
        for p in self.model.llm.parameters():
            p.requires_grad = False
        print(" ๐Ÿ”ง Unfreezing Controller & Flux Adapters...")
        for p in self.model.controller.parameters():
            p.requires_grad = True
        if isinstance(self.model.flux_layers, list):
            for layer in self.model.flux_layers:
                for p in layer.parameters():
                    p.requires_grad = True
        else:
            for p in self.model.flux_layers.parameters():
                p.requires_grad = True

        # 5. Setup Online Optimizer
        # Update BOTH Controller AND Flux Adapters for true adaptation.
        # BUGFIX: the original built a throwaway AdamW here (before the
        # requires_grad flags above were even set) and then immediately
        # shadowed it with this Adam; only the effective optimizer is kept.
        controller_params = list(self.model.controller.parameters())
        if isinstance(self.model.flux_layers, (torch.nn.ModuleList, torch.nn.Sequential)):
            adapter_params = list(self.model.flux_layers.parameters())
        else:
            # If it's a plain python list of adapter layers
            adapter_params = [p for layer in self.model.flux_layers for p in layer.parameters()]
        # Adam (better convergence), relying on GC/env settings for memory safety
        self.optimizer = optim.Adam(controller_params + adapter_params, lr=1e-4)
        self.model.train()  # Enable training mode for Controller/Adapters

        # 6. Initialize Replay Buffer & Drift Anchor
        self.replay_buffer = StratifiedReplayBuffer()
        # Pristine snapshot of the controller, used by learn() as an
        # anti-drift L2 anchor to prevent model collapse.
        self.initial_controller_state = {k: v.clone() for k, v in self.model.controller.state_dict().items()}
        print(" โœ… Ready for Interactive Continuous Learning (Powered by Replay Buffer)!")

    def _load_weights(self):
        """Load pre-trained weights from various possible locations.

        Searches each candidate directory for the controller checkpoint; the
        first hit also triggers optional loading of the WALT head and the
        Flux Adapters, then returns. Falls through with a warning when no
        checkpoint is found anywhere.
        """
        search_paths = [
            ".",
            "/kaggle/input/worldmodels/physics_model",
            "/kaggle/working/physics_model"
        ]
        for path in search_paths:
            controller_path = os.path.join(path, "final_physics_controller.pt")
            if not os.path.exists(controller_path):
                continue
            print(f" Loading weights from {path}...")
            self.model.controller.load_state_dict(
                torch.load(controller_path, map_location=self.device)
            )
            # Load WALT (optional)
            walt_path = os.path.join(path, "final_walt_head.pt")
            if os.path.exists(walt_path):
                self.model.walt.load_state_dict(
                    torch.load(walt_path, map_location=self.device)
                )
            # Load Adapters (optional; one state dict per flux layer)
            adapter_path = os.path.join(path, "final_liquid_adapters.pt")
            if os.path.exists(adapter_path):
                adapter_states = torch.load(adapter_path, map_location=self.device)
                for layer, state in zip(self.model.flux_layers, adapter_states):
                    layer.load_state_dict(state)
                print(" โœ… Loaded Flux Adapters.")
            return
        print(" โš ๏ธ No pre-trained weights found. Using random initialization.")

    # NOTE: a _get_context_modulation() helper (averaging recent session
    # modulations) used to live here; the feature is disabled and the dead
    # commented-out implementation was removed.

    def predict(self, user_input: str):
        """
        Generate a response using the current Controller & Flux Adapters.
        Pure Inference: no context history, just the current weights.

        Returns (response_text, detached_modulation_tensor).
        """
        self.model.eval()
        full_prompt = f"User: {user_input}\nModel:"
        inputs = self.model.tokenizer(full_prompt, return_tensors="pt").to(self.device)

        # 1. Generate Modulation (based strictly on the CURRENT input)
        with torch.no_grad():
            h_init = self.model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
            modulation = self.model.controller(h_init)

        # 2. Context bias intentionally disabled: we rely solely on the
        #    weight updates performed by learn().

        # 3. Apply modulation and generate
        self.model.set_active_modulation(modulation)
        out_ids = self.model.llm.generate(
            **inputs,
            max_new_tokens=100,        # Increased for chat
            do_sample=True,
            temperature=0.7,
            repetition_penalty=1.0,    # Default; >1.0 previously caused empty replies
            pad_token_id=self.model.tokenizer.eos_token_id
        )
        response = self.model.tokenizer.decode(out_ids[0], skip_special_tokens=True)
        # Keep only the model's turn from the decoded prompt+completion
        response_clean = response.split("Model:")[-1].strip()
        self.model.clear_modulation()
        return response_clean, modulation.detach()

    def _generate_synthetic_data(self, question, answer, num_variations=3):
        """
        Uses the frozen Base LLM to generate diverse variations of the training
        example, turning One-Shot Learning into Few-Shot Learning (Synthetic
        Data Augmentation).

        Returns a list of {"q", "a"} dicts; the original pair is always first,
        and the original is duplicated as a fallback if parsing yields nothing.
        """
        print(" โœจ Generating synthetic training data (Self-Distillation)...")
        # 1. Disable adapters/modulation to get clean English capability
        self.model.clear_modulation()
        self.model.eval()

        prompt = (
            f"Original Question: {question}\n"
            f"Original Answer: {answer}\n\n"
            f"Task: Rewrite the above Question and Answer pair in {num_variations} different styles (e.g. simple, formal, detailed). "
            f"Keep the facts exactly the same.\n"
            f"Output format:\n"
            f"Q1: ...\n"
            f"A1: ...\n"
            f"Q2: ...\n"
            f"A2: ...\n"
            f"Start now:"
        )
        inputs = self.model.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            out_ids = self.model.llm.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7
            )
        raw_text = self.model.tokenizer.decode(out_ids[0], skip_special_tokens=True)

        # 2. Parse the output (simple "Qn: ... / An: ..." heuristic parsing)
        variations = [{"q": question, "a": answer}]  # Always include original
        current_q = None
        for line in raw_text.split('\n'):
            line = line.strip()
            if line.startswith("Q") and ":" in line:
                current_q = line.split(":", 1)[1].strip()
            elif line.startswith("A") and ":" in line and current_q:
                current_a = line.split(":", 1)[1].strip()
                # Validation: Ensure neither Q nor A is empty or placeholder garbage
                if current_q and current_a and "..." not in current_q and "..." not in current_a:
                    variations.append({"q": current_q, "a": current_a})
                current_q = None

        # Cleanup Memory
        del inputs, out_ids
        torch.cuda.empty_cache()

        # Fallback: If synthetic generation failed, duplicate original
        if len(variations) == 1:
            print(" โš ๏ธ Synthetic generation failed to produce valid format. Duplicating original.")
            variations.append({"q": question, "a": answer})

        print(f" โœจ Generated {len(variations)-1} synthetic variations.")
        for i, v in enumerate(variations):
            print(f" [{i}] Q: {v['q'][:30]}... A: {v['a'][:30]}...")
        return variations

    def learn(self, user_input: str, correct_answer: str, concept_id: str = "general"):
        """
        Robust Learning: Updates weights using the new example + Replay Buffer.
        Runs a fixed number of steps (plasticity) while anchoring to past
        knowledge (stability) via replay and an L2 anti-drift penalty.

        Returns the final total loss value (float).
        """
        print("\n ๐Ÿง  Starting Robust Adaptation Loop...")

        # 0. Augment Data (Synthetic Variations)
        training_batch = self._generate_synthetic_data(user_input, correct_answer)

        # 1. Add new knowledge to Buffer
        self.replay_buffer.add(concept_id, user_input, correct_answer)

        # Force cleanup before training to prevent OOM
        gc.collect()
        torch.cuda.empty_cache()

        # BUGFIX: predict() and _generate_synthetic_data() leave the model in
        # eval mode; restore train mode before computing gradients.
        self.model.train()

        # 2. Training Loop (Micro-Epochs)
        steps = 20  # Safe limit for strong replay
        for step in range(steps):
            self.optimizer.zero_grad()
            total_loss = 0

            # --- A. Current Task (random sample from the synthetic batch) ---
            example = random.choice(training_batch)
            # Append EOS so the model knows when to STOP talking
            full_text = f"User: {example['q']}\nModel: {example['a']}{self.model.tokenizer.eos_token}"
            inputs_train = self.model.tokenizer(full_text, return_tensors="pt", max_length=Config.MAX_LENGTH, truncation=True).to(self.device)
            h_train = self.model.get_embeddings(inputs_train.input_ids).to(Config.DTYPE)
            mod_pred = self.model.controller(h_train)
            logits = self.model(inputs_train.input_ids, forced_modulation=mod_pred)
            # Standard next-token LM loss: shift logits left, labels right
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = inputs_train.input_ids[..., 1:].contiguous()
            task_loss = torch.nn.functional.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            total_loss += task_loss * 1.0

            # --- B. Replay (Stability) ---
            past_memories = self.replay_buffer.sample_stratified(concept_id, n_per_concept=2)
            for mem in past_memories:
                full_replay = f"User: {mem['prompt']}\nModel: {mem['answer']}"
                inputs_replay = self.model.tokenizer(full_replay, return_tensors="pt", max_length=Config.MAX_LENGTH, truncation=True).to(self.device)
                h_rep = self.model.get_embeddings(inputs_replay.input_ids).to(Config.DTYPE)
                mod_rep = self.model.controller(h_rep)
                logits_rep = self.model(inputs_replay.input_ids, forced_modulation=mod_rep)
                s_log = logits_rep[..., :-1, :].contiguous()
                s_lab = inputs_replay.input_ids[..., 1:].contiguous()
                loss_rep = torch.nn.functional.cross_entropy(s_log.view(-1, s_log.size(-1)), s_lab.view(-1))
                # Weight Replay EQUAL (1.0) to task to enforce stability
                total_loss += loss_rep * 1.0

            # --- C. Anti-Drift (crucial for test-time training) ---
            # Penalize deviation from the original controller weights to
            # prevent "Model Collapse"
            drift_loss = 0
            for name, param in self.model.controller.named_parameters():
                drift_loss += torch.sum((param - self.initial_controller_state[name].to(self.device)) ** 2)
            total_loss += drift_loss * 10.0  # Very strong anchor

            total_loss.backward()

            # Debug: global L2 norm of controller gradients
            total_norm = 0.0
            for p in self.model.controller.parameters():
                if p.grad is not None:
                    total_norm += p.grad.data.norm(2).item() ** 2
            total_norm = total_norm ** 0.5

            self.optimizer.step()

            if (step + 1) % 10 == 0:
                print(f" Step {step+1}: Loss {total_loss.item():.4f} | Grad Norm: {total_norm:.4f}")

            # Early Stopping (prevent overfitting)
            if total_loss.item() < 0.005:
                print(f" โœ… Converged early at step {step+1} (Loss < 0.005)")
                break

        self.model.clear_modulation()
        print(" โœ… Adaptation Complete. Weights Updated.")
        return total_loss.item()

    def save_weights(self, suffix="session"):
        """Save the updated Controller and Flux Adapter weights to CWD."""
        print(" ๐Ÿ’พ Saving updated weights...")
        torch.save(self.model.controller.state_dict(), f"controller_{suffix}.pt")
        adapter_states = [l.state_dict() for l in self.model.flux_layers]
        torch.save(adapter_states, f"adapters_{suffix}.pt")
        print(f" โœ… Saved to controller_{suffix}.pt and adapters_{suffix}.pt")

    def run(self):
        """Main interactive loop: predict, collect feedback, learn on demand."""
        print("\n" + "="*60)
        print(" ๐Ÿงช CONTINUOUS LEARNING LAB")
        print(" Commands:")
        print(" - Ask any physics question")
        print(" - Type 'wrong' if the answer is incorrect")
        print(" - Type 'save' to save updated weights")
        print(" - Type 'exit' to quit")
        print("="*60)

        while True:
            try:
                user_input = input("\nUSER: ").strip()
            except (EOFError, KeyboardInterrupt):
                break
            if not user_input:
                continue
            if user_input.lower() in ['exit', 'quit']:
                break
            if user_input.lower() == 'save':
                self.save_weights()
                continue

            # Generate prediction
            response, modulation = self.predict(user_input)
            mod_norm = modulation.norm().item()
            print(f"MODEL: {response}")
            print(f" [Modulation Norm: {mod_norm:.2f}]")

            # Feedback loop
            try:
                feedback = input(" (Enter=correct, 'wrong'=teach): ").strip().lower()
            except (EOFError, KeyboardInterrupt):
                break

            if feedback == "wrong":
                try:
                    truth = input(" CORRECT ANSWER: ").strip()
                    topic = "general"  # Defaulting as requested
                except (EOFError, KeyboardInterrupt):
                    break
                if truth:
                    # Pass the topic to learn so it can index it correctly
                    self.learn(user_input, truth, topic)
            else:
                print(" ๐Ÿ‘ Perfect! (No update needed)")

        print("\n๐Ÿ‘‹ Session ended.")
        # Offer to save (best-effort; BUGFIX: was a bare `except: pass`)
        try:
            save = input(" Save updated weights? (y/n): ").strip().lower()
            if save == 'y':
                self.save_weights()
        except (EOFError, KeyboardInterrupt):
            pass


if __name__ == "__main__":
    session = ContinuousLearningSession()
    session.run()