# flux-test-time-training / continuous_learning_session.py
# Uploaded by convaiinnovations (commit 349f5a8, verified)
import random
import logging
import os
import gc
# Optimize CUDA memory allocation to reduce fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
import torch.nn as nn
import torch.optim as optim
from modeling_physics_rl import PhysicsModel, Config
class StratifiedReplayBuffer:
    """
    Replay buffer keyed by concept ID (or a generic 'user_taught' bucket) so
    that rehearsal batches draw a DIVERSE cross-section of past teaching
    rather than only the most recent topic.
    """

    def __init__(self):
        # Maps concept id -> list of {"prompt": ..., "answer": ...} records.
        self.memory = {}
        # Seed with general-knowledge anchors so the very first adaptation
        # step already has old material to rehearse (prevents cold-start
        # catastrophic forgetting).
        self._add_anchor_memories()

    def _add_anchor_memories(self):
        """Pre-load a handful of generic Q/A pairs under the 'anchor' key."""
        anchors = [
            ("What is gravity?", "Gravity is a fundamental interaction which causes mutual attraction between all things with mass or energy."),
            ("Hello", "Hello! How can I help you today?"),
            ("What is AI?", "Artificial Intelligence (AI) refers to the simulation of human intelligence in machines."),
            ("Define thermodynamics.", "Thermodynamics is a branch of physics that deals with heat, work, and temperature, and their relation to energy, entropy, and the physical properties of matter."),
            ("Who are you?", "I am a large language model, trained by Google.")
        ]
        self.memory["anchor"] = [dict(prompt=q, answer=a) for q, a in anchors]
        print(f" βš“ Added {len(anchors)} General Knowledge Anchors to Replay Buffer.")

    def add(self, concept_id, prompt, answer):
        """Record one taught (prompt, answer) pair under *concept_id*."""
        records = self.memory.setdefault(concept_id, [])
        records.append({"prompt": prompt, "answer": answer})

    def sample_stratified(self, current_concept_id, n_per_concept=1):
        """
        Draw up to *n_per_concept* random memories from every concept
        EXCEPT the current one, so each rehearsal batch is stratified
        across past topics.
        """
        batch = []
        for cid, records in self.memory.items():
            if cid == current_concept_id:
                continue
            take = min(len(records), n_per_concept)
            batch.extend(random.sample(records, take))
        return batch
class ContinuousLearningSession:
    """
    Interactive test-time-training (TTT) session.

    A frozen base LLM (``self.model.llm``) is steered by a trainable
    Controller plus per-layer Flux Adapters. The session:

    * answers user prompts (:meth:`predict`),
    * adapts the trainable modules online when the user supplies a
      correction (:meth:`learn`),
    * and fights catastrophic forgetting with synthetic data augmentation,
      a stratified replay buffer, and an L2 anti-drift anchor back to the
      controller's initial weights.
    """

    def __init__(self):
        print("🧠 Initializing Continuous Learning Session...")
        # 1. Load Model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f" πŸš€ Using Device: {self.device}")
        self.model = PhysicsModel()
        self.model.to(self.device)  # Force move to GPU
        # 2. Load Pre-trained Weights
        self._load_weights()
        # 3. Session Memory (Context Window).
        #    Kept for interface compatibility; context blending is currently
        #    disabled in predict()/learn().
        self.session_context = []       # List of (input, modulation) pairs
        self.context_modulation = None  # Accumulated modulation bias
        # 4. Freeze the backbone; Controller & Adapters stay TRAINABLE.
        for p in self.model.llm.parameters():
            p.requires_grad = False
        print(" πŸ”§ Unfreezing Controller & Flux Adapters...")
        for p in self.model.controller.parameters():
            p.requires_grad = True
        if isinstance(self.model.flux_layers, list):
            for layer in self.model.flux_layers:
                for p in layer.parameters():
                    p.requires_grad = True
        else:
            for p in self.model.flux_layers.parameters():
                p.requires_grad = True
        # 5. Online optimizer over Controller + Adapters.
        #    NOTE(review): the original constructed a per-group AdamW here and
        #    then immediately overwrote it with this Adam; the dead duplicate
        #    has been removed.
        controller_params = list(self.model.controller.parameters())
        if isinstance(self.model.flux_layers, (torch.nn.ModuleList, torch.nn.Sequential)):
            adapter_params = list(self.model.flux_layers.parameters())
        else:
            # flux_layers is a plain python list of modules
            adapter_params = [p for layer in self.model.flux_layers for p in layer.parameters()]
        # Adam chosen for convergence; memory safety relies on GC + the
        # expandable_segments allocator config set at import time.
        self.optimizer = optim.Adam(controller_params + adapter_params, lr=1e-4)
        self.model.train()  # Enable gradients for Controller/Adapters
        # 6. Replay buffer + drift anchor.
        #    Snapshot the pristine controller weights ONCE, already on
        #    self.device, so the drift loss below never re-transfers them.
        self.replay_buffer = StratifiedReplayBuffer()
        self.initial_controller_state = {
            k: v.clone().to(self.device)
            for k, v in self.model.controller.state_dict().items()
        }
        print(" βœ… Ready for Interactive Continuous Learning (Powered by Replay Buffer)!")

    def _load_weights(self):
        """Load pre-trained weights from the first search path that has them."""
        search_paths = [
            ".",
            "/kaggle/input/worldmodels/physics_model",
            "/kaggle/working/physics_model"
        ]
        for path in search_paths:
            controller_path = os.path.join(path, "final_physics_controller.pt")
            if os.path.exists(controller_path):
                print(f" Loading weights from {path}...")
                self.model.controller.load_state_dict(
                    torch.load(controller_path, map_location=self.device)
                )
                # Load WALT head (optional)
                walt_path = os.path.join(path, "final_walt_head.pt")
                if os.path.exists(walt_path):
                    self.model.walt.load_state_dict(
                        torch.load(walt_path, map_location=self.device)
                    )
                # Load Flux Adapters (optional, one state dict per layer)
                adapter_path = os.path.join(path, "final_liquid_adapters.pt")
                if os.path.exists(adapter_path):
                    adapter_states = torch.load(adapter_path, map_location=self.device)
                    for layer, state in zip(self.model.flux_layers, adapter_states):
                        layer.load_state_dict(state)
                    print(" βœ… Loaded Flux Adapters.")
                return
        print(" ⚠️ No pre-trained weights found. Using random initialization.")

    def predict(self, user_input: str):
        """
        Generate a response using the current Controller & Flux Adapters.

        Pure inference: no session-context blending, just the current
        weights (we rely solely on the weight updates from learn()).
        Returns ``(response_text, detached_modulation)``.
        """
        self.model.eval()
        full_prompt = f"User: {user_input}\nModel:"
        inputs = self.model.tokenizer(full_prompt, return_tensors="pt").to(self.device)
        # 1. Generate modulation based strictly on the CURRENT input.
        with torch.no_grad():
            h_init = self.model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
            modulation = self.model.controller(h_init)
        # 2. Apply modulation and generate.
        self.model.set_active_modulation(modulation)
        out_ids = self.model.llm.generate(
            **inputs,
            max_new_tokens=100,      # Increased for chat
            do_sample=True,
            temperature=0.7,
            repetition_penalty=1.0,  # Default; >1.0 previously caused empty replies
            pad_token_id=self.model.tokenizer.eos_token_id
        )
        response = self.model.tokenizer.decode(out_ids[0], skip_special_tokens=True)
        # Keep only the text after the final "Model:" marker.
        response_clean = response.split("Model:")[-1].strip()
        self.model.clear_modulation()
        return response_clean, modulation.detach()

    def _generate_synthetic_data(self, question, answer, num_variations=3):
        """
        Use the frozen base LLM to generate diverse variations of the
        training example, turning one-shot learning into few-shot learning
        (synthetic data augmentation / self-distillation).

        Always includes the original pair; falls back to duplicating it if
        the LLM output cannot be parsed.
        """
        print(" ✨ Generating synthetic training data (Self-Distillation)...")
        # 1. Disable adapters/modulation to get the clean base capability.
        self.model.clear_modulation()
        self.model.eval()
        prompt = (
            f"Original Question: {question}\n"
            f"Original Answer: {answer}\n\n"
            f"Task: Rewrite the above Question and Answer pair in {num_variations} different styles (e.g. simple, formal, detailed). "
            f"Keep the facts exactly the same.\n"
            f"Output format:\n"
            f"Q1: ...\n"
            f"A1: ...\n"
            f"Q2: ...\n"
            f"A2: ...\n"
            f"Start now:"
        )
        inputs = self.model.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            out_ids = self.model.llm.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7
            )
        raw_text = self.model.tokenizer.decode(out_ids[0], skip_special_tokens=True)
        # 2. Parse the output (simple "Qn:" / "An:" heuristic).
        variations = [{"q": question, "a": answer}]  # Always include original
        current_q = None
        for line in raw_text.split('\n'):
            line = line.strip()
            if line.startswith("Q") and ":" in line:
                current_q = line.split(":", 1)[1].strip()
            elif line.startswith("A") and ":" in line and current_q:
                current_a = line.split(":", 1)[1].strip()
                # Validation: skip empty or placeholder ("...") pairs.
                if current_q and current_a and "..." not in current_q and "..." not in current_a:
                    variations.append({"q": current_q, "a": current_a})
                current_q = None
        # 3. Cleanup memory (generation tensors can be large).
        del inputs, out_ids
        torch.cuda.empty_cache()
        # 4. Fallback: if synthetic generation failed, duplicate the original.
        if len(variations) == 1:
            print(" ⚠️ Synthetic generation failed to produce valid format. Duplicating original.")
            variations.append({"q": question, "a": answer})
        print(f" ✨ Generated {len(variations)-1} synthetic variations.")
        for i, v in enumerate(variations):
            print(f" [{i}] Q: {v['q'][:30]}... A: {v['a'][:30]}...")
        return variations

    def learn(self, user_input: str, correct_answer: str, concept_id: str = "general"):
        """
        Robust online adaptation on one corrected example.

        Each micro-step combines:
          A. task loss on a random synthetic variation (plasticity),
          B. replay loss over past concepts (stability),
          C. an L2 anti-drift penalty anchoring the controller to its
             initial weights (prevents model collapse).

        Returns the final step's total loss as a float.
        """
        print("\n 🧠 Starting Robust Adaptation Loop...")
        # 0. Augment data (synthetic variations)
        training_batch = self._generate_synthetic_data(user_input, correct_answer)
        # 1. Add the new knowledge to the replay buffer
        self.replay_buffer.add(concept_id, user_input, correct_answer)
        # Force cleanup before training to prevent OOM
        gc.collect()
        torch.cuda.empty_cache()
        # 2. Training loop (micro-epochs)
        steps = 20  # Safe limit when paired with strong replay
        for step in range(steps):
            self.optimizer.zero_grad()
            total_loss = 0
            # --- A. Current task: random sample from the synthetic batch ---
            example = random.choice(training_batch)
            # Append EOS so the model knows when to STOP talking
            full_text = f"User: {example['q']}\nModel: {example['a']}{self.model.tokenizer.eos_token}"
            inputs_train = self.model.tokenizer(full_text, return_tensors="pt", max_length=Config.MAX_LENGTH, truncation=True).to(self.device)
            h_train = self.model.get_embeddings(inputs_train.input_ids).to(Config.DTYPE)
            mod_pred = self.model.controller(h_train)
            logits = self.model(inputs_train.input_ids, forced_modulation=mod_pred)
            # Standard causal-LM loss: logits shifted against next tokens.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = inputs_train.input_ids[..., 1:].contiguous()
            task_loss = torch.nn.functional.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            total_loss += task_loss * 1.0
            # --- B. Replay (stability) ---
            past_memories = self.replay_buffer.sample_stratified(concept_id, n_per_concept=2)
            for mem in past_memories:
                full_replay = f"User: {mem['prompt']}\nModel: {mem['answer']}"
                inputs_replay = self.model.tokenizer(full_replay, return_tensors="pt", max_length=Config.MAX_LENGTH, truncation=True).to(self.device)
                h_rep = self.model.get_embeddings(inputs_replay.input_ids).to(Config.DTYPE)
                mod_rep = self.model.controller(h_rep)
                logits_rep = self.model(inputs_replay.input_ids, forced_modulation=mod_rep)
                s_log = logits_rep[..., :-1, :].contiguous()
                s_lab = inputs_replay.input_ids[..., 1:].contiguous()
                loss_rep = torch.nn.functional.cross_entropy(s_log.view(-1, s_log.size(-1)), s_lab.view(-1))
                # Weight replay EQUAL (1.0) to the task to enforce stability
                total_loss += loss_rep * 1.0
            # --- C. Anti-drift (crucial for TTT) ---
            # Penalize deviation from the original weights to prevent
            # "model collapse". Anchor tensors are pre-moved to self.device
            # in __init__, so no per-step transfer happens here.
            drift_loss = 0
            for name, param in self.model.controller.named_parameters():
                drift_loss += torch.sum((param - self.initial_controller_state[name]) ** 2)
            total_loss += drift_loss * 10.0  # Very strong anchor
            total_loss.backward()
            # Debug: controller gradient norm
            total_norm = 0.0
            for p in self.model.controller.parameters():
                if p.grad is not None:
                    total_norm += p.grad.data.norm(2).item() ** 2
            total_norm = total_norm ** 0.5
            self.optimizer.step()
            if (step+1) % 10 == 0:
                print(f" Step {step+1}: Loss {total_loss.item():.4f} | Grad Norm: {total_norm:.4f}")
            # Early stopping (prevent overfitting)
            if total_loss.item() < 0.005:
                print(f" βœ… Converged early at step {step+1} (Loss < 0.005)")
                break
        self.model.clear_modulation()
        print(" βœ… Adaptation Complete. Weights Updated.")
        return total_loss.item()

    def save_weights(self, suffix="session"):
        """Save the updated Controller + Adapter weights to the working dir."""
        print(" πŸ’Ύ Saving updated weights...")
        torch.save(self.model.controller.state_dict(), f"controller_{suffix}.pt")
        adapter_states = [l.state_dict() for l in self.model.flux_layers]
        torch.save(adapter_states, f"adapters_{suffix}.pt")
        print(f" βœ… Saved to controller_{suffix}.pt and adapters_{suffix}.pt")

    def run(self):
        """Main interactive loop: answer, collect feedback, adapt, save."""
        print("\n" + "="*60)
        print(" πŸ§ͺ CONTINUOUS LEARNING LAB")
        print(" Commands:")
        print(" - Ask any physics question")
        print(" - Type 'wrong' if the answer is incorrect")
        print(" - Type 'save' to save updated weights")
        print(" - Type 'exit' to quit")
        print("="*60)
        while True:
            try:
                user_input = input("\nUSER: ").strip()
            except (EOFError, KeyboardInterrupt):
                break
            if not user_input:
                continue
            if user_input.lower() in ['exit', 'quit']:
                break
            if user_input.lower() == 'save':
                self.save_weights()
                continue
            # Generate prediction
            response, modulation = self.predict(user_input)
            mod_norm = modulation.norm().item()
            print(f"MODEL: {response}")
            print(f" [Modulation Norm: {mod_norm:.2f}]")
            # Feedback loop
            try:
                feedback = input(" (Enter=correct, 'wrong'=teach): ").strip().lower()
            except (EOFError, KeyboardInterrupt):
                break
            if feedback == "wrong":
                try:
                    truth = input(" CORRECT ANSWER: ").strip()
                    topic = "general"  # Defaulting as requested
                except (EOFError, KeyboardInterrupt):
                    break
                if truth:
                    # Pass the topic to learn so it can index it correctly
                    self.learn(user_input, truth, topic)
            else:
                print(" πŸ‘ Perfect! (No update needed)")
        print("\nπŸ‘‹ Session ended.")
        # Offer to save (best-effort; stdin may already be closed)
        try:
            save = input(" Save updated weights? (y/n): ").strip().lower()
            if save == 'y':
                self.save_weights()
        except (EOFError, KeyboardInterrupt):
            pass
if __name__ == "__main__":
    # Script entry point: build a session and hand control to the REPL.
    ContinuousLearningSession().run()