|
|
import random
|
|
|
import logging
|
|
|
import os
|
|
|
import gc
|
|
|
|
|
|
|
|
|
# Set before `import torch` so the CUDA caching allocator sees it when it
# initializes; expandable segments reduce fragmentation-driven OOMs
# (presumably important here because learn() repeatedly allocates
# variable-length batches — see torch CUDA memory docs).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.optim as optim
|
|
|
from modeling_physics_rl import PhysicsModel, Config
|
|
|
|
|
|
class StratifiedReplayBuffer:
    """
    Stores memories by Concept ID (or just generic 'user_taught') to ensure we sample DIVERSE history.
    """

    def __init__(self):
        # Maps concept_id -> list of {"prompt": ..., "answer": ...} dicts.
        self.memory = {}
        self._add_anchor_memories()

    def _add_anchor_memories(self):
        """Seed the buffer with fixed general-knowledge Q/A pairs under the 'anchor' concept."""
        anchors = [
            ("What is gravity?", "Gravity is a fundamental interaction which causes mutual attraction between all things with mass or energy."),
            ("Hello", "Hello! How can I help you today?"),
            ("What is AI?", "Artificial Intelligence (AI) refers to the simulation of human intelligence in machines."),
            ("Define thermodynamics.", "Thermodynamics is a branch of physics that deals with heat, work, and temperature, and their relation to energy, entropy, and the physical properties of matter."),
            ("Who are you?", "I am a large language model, trained by Google.")
        ]
        self.memory["anchor"] = [
            {"prompt": question, "answer": answer} for question, answer in anchors
        ]
        print(f" β Added {len(anchors)} General Knowledge Anchors to Replay Buffer.")

    def add(self, concept_id, prompt, answer):
        """Append one taught example under the given concept bucket."""
        bucket = self.memory.setdefault(concept_id, [])
        bucket.append({"prompt": prompt, "answer": answer})

    def sample_stratified(self, current_concept_id, n_per_concept=1):
        """
        Draw up to `n_per_concept` random examples from every concept
        EXCEPT the one currently being learned. Returns [] when there is
        no other concept to replay.
        """
        other_ids = [cid for cid in self.memory if cid != current_concept_id]
        if not other_ids:
            return []

        batch = []
        for cid in other_ids:
            pool = self.memory[cid]
            batch.extend(random.sample(pool, min(len(pool), n_per_concept)))
        return batch
|
|
|
|
|
|
|
|
|
class ContinuousLearningSession:
    """
    Interactive REPL that answers questions with a PhysicsModel and, on user
    correction, adapts only the Controller + Flux Adapter weights online
    (the base LLM stays frozen).

    Catastrophic forgetting is mitigated by three mechanisms in learn():
      * synthetic data augmentation of the taught example (self-distillation),
      * stratified replay of past concepts and general-knowledge anchors,
      * an L2 drift penalty anchoring the controller to its initial weights.
    """

    def __init__(self):
        print("π§ Initializing Continuous Learning Session...")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f" π Using Device: {self.device}")

        self.model = PhysicsModel()
        self.model.to(self.device)

        self._load_weights()

        # Reserved for future context-conditioned inference (currently unused).
        self.session_context = []
        self.context_modulation = None

        # Freeze the base LLM: only the controller and flux adapters learn.
        for p in self.model.llm.parameters():
            p.requires_grad = False

        print(" π§ Unfreezing Controller & Flux Adapters...")
        for p in self.model.controller.parameters():
            p.requires_grad = True

        # flux_layers may be a plain Python list or an nn.ModuleList/Sequential.
        if isinstance(self.model.flux_layers, list):
            for layer in self.model.flux_layers:
                for p in layer.parameters():
                    p.requires_grad = True
        else:
            for p in self.model.flux_layers.parameters():
                p.requires_grad = True

        # BUGFIX: the optimizer used to be built twice — an AdamW with
        # per-group learning rates (constructed before the requires_grad
        # flags were even set) that was immediately discarded and replaced
        # by this Adam. Build it exactly once, after unfreezing.
        controller_params = list(self.model.controller.parameters())
        if isinstance(self.model.flux_layers, (torch.nn.ModuleList, torch.nn.Sequential)):
            adapter_params = list(self.model.flux_layers.parameters())
        else:
            adapter_params = [p for layer in self.model.flux_layers for p in layer.parameters()]
        self.optimizer = optim.Adam(controller_params + adapter_params, lr=1e-4)

        self.model.train()

        self.replay_buffer = StratifiedReplayBuffer()
        # Pristine controller snapshot used by the drift penalty in learn().
        self.initial_controller_state = {k: v.clone() for k, v in self.model.controller.state_dict().items()}

        print(" β Ready for Interactive Continuous Learning (Powered by Replay Buffer)!")

    def _load_weights(self):
        """Load pre-trained weights from various possible locations.

        Stops at the first search path containing a controller checkpoint;
        the optional WALT head and flux-adapter checkpoints are loaded from
        the same path if present.
        """
        search_paths = [
            ".",
            "/kaggle/input/worldmodels/physics_model",
            "/kaggle/working/physics_model"
        ]

        for path in search_paths:
            controller_path = os.path.join(path, "final_physics_controller.pt")
            if os.path.exists(controller_path):
                print(f" Loading weights from {path}...")
                self.model.controller.load_state_dict(
                    torch.load(controller_path, map_location=self.device)
                )

                walt_path = os.path.join(path, "final_walt_head.pt")
                if os.path.exists(walt_path):
                    self.model.walt.load_state_dict(
                        torch.load(walt_path, map_location=self.device)
                    )

                adapter_path = os.path.join(path, "final_liquid_adapters.pt")
                if os.path.exists(adapter_path):
                    # Checkpoint stores one state_dict per flux layer, in order.
                    adapter_states = torch.load(adapter_path, map_location=self.device)
                    for layer, state in zip(self.model.flux_layers, adapter_states):
                        layer.load_state_dict(state)
                    print(" β Loaded Flux Adapters.")

                return

        print(" β οΈ No pre-trained weights found. Using random initialization.")

    def predict(self, user_input: str):
        """
        Generate a response using the current Controller & Flux Adapters.
        Pure Inference: No context history, just the current weights.

        Returns:
            tuple: (cleaned response text, detached modulation tensor).
        """
        self.model.eval()

        full_prompt = f"User: {user_input}\nModel:"
        inputs = self.model.tokenizer(full_prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            h_init = self.model.get_embeddings(inputs.input_ids).to(Config.DTYPE)

            # Controller maps the prompt embedding to a modulation signal
            # that conditions the flux layers during generation.
            modulation = self.model.controller(h_init)

            self.model.set_active_modulation(modulation)

            out_ids = self.model.llm.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=True,
                temperature=0.7,
                repetition_penalty=1.0,
                pad_token_id=self.model.tokenizer.eos_token_id
            )

        response = self.model.tokenizer.decode(out_ids[0], skip_special_tokens=True)
        # The decoded text echoes the prompt; keep only what follows "Model:".
        response_clean = response.split("Model:")[-1].strip()

        self.model.clear_modulation()

        return response_clean, modulation.detach()

    def _generate_synthetic_data(self, question, answer, num_variations=3):
        """
        Uses the frozen Base LLM to generate diverse variations of the training example.
        This turns One-Shot Learning into Few-Shot Learning (Synthetic Data Augmentation).

        Returns a list of {"q": ..., "a": ...} dicts; index 0 is always the
        original pair, so the caller always has at least one valid example.
        """
        print(" β¨ Generating synthetic training data (Self-Distillation)...")

        # Generate with the bare LLM: no modulation, eval mode.
        self.model.clear_modulation()
        self.model.eval()

        prompt = (
            f"Original Question: {question}\n"
            f"Original Answer: {answer}\n\n"
            f"Task: Rewrite the above Question and Answer pair in {num_variations} different styles (e.g. simple, formal, detailed). "
            f"Keep the facts exactly the same.\n"
            f"Output format:\n"
            f"Q1: ...\n"
            f"A1: ...\n"
            f"Q2: ...\n"
            f"A2: ...\n"
            f"Start now:"
        )

        inputs = self.model.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            out_ids = self.model.llm.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7
            )

        raw_text = self.model.tokenizer.decode(out_ids[0], skip_special_tokens=True)

        variations = [{"q": question, "a": answer}]

        # Parse "Qn: ... / An: ..." pairs; a Q line must precede its A line.
        current_q = None
        for line in raw_text.split('\n'):
            line = line.strip()
            if line.startswith("Q") and ":" in line:
                current_q = line.split(":", 1)[1].strip()
            elif line.startswith("A") and ":" in line and current_q:
                current_a = line.split(":", 1)[1].strip()
                # Reject placeholder output that just echoes the "..." template.
                if current_q and current_a and "..." not in current_q and "..." not in current_a:
                    variations.append({"q": current_q, "a": current_a})
                current_q = None

        del inputs, out_ids
        torch.cuda.empty_cache()

        if len(variations) == 1:
            print(" β οΈ Synthetic generation failed to produce valid format. Duplicating original.")
            variations.append({"q": question, "a": answer})

        print(f" β¨ Generated {len(variations)-1} synthetic variations.")
        for i, v in enumerate(variations):
            print(f" [{i}] Q: {v['q'][:30]}... A: {v['a'][:30]}...")

        return variations

    def learn(self, user_input: str, correct_answer: str, concept_id: str = "general"):
        """
        Robust Learning: Updates weights using the new example + Replay Buffer.
        Runs specific number of steps (plasticity) while anchoring to past (stability).

        Returns the final step's total loss as a float.
        """
        print("\n π§ Starting Robust Adaptation Loop...")

        # Augment the single correction into a small synthetic training set.
        training_batch = self._generate_synthetic_data(user_input, correct_answer)

        # Remember the raw correction for replay during future learn() calls.
        self.replay_buffer.add(concept_id, user_input, correct_answer)

        gc.collect()
        torch.cuda.empty_cache()

        # BUGFIX: _generate_synthetic_data (and predict) leave the model in
        # eval mode; restore train mode before taking gradient steps.
        self.model.train()

        steps = 20

        for step in range(steps):
            self.optimizer.zero_grad()
            total_loss = 0

            # --- Plasticity: next-token loss on one (possibly synthetic) example ---
            example = random.choice(training_batch)

            full_text = f"User: {example['q']}\nModel: {example['a']}{self.model.tokenizer.eos_token}"
            inputs_train = self.model.tokenizer(full_text, return_tensors="pt", max_length=Config.MAX_LENGTH, truncation=True).to(self.device)

            h_train = self.model.get_embeddings(inputs_train.input_ids).to(Config.DTYPE)
            mod_pred = self.model.controller(h_train)
            logits = self.model(inputs_train.input_ids, forced_modulation=mod_pred)

            # Standard causal-LM shift: predict token t+1 from tokens <= t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = inputs_train.input_ids[..., 1:].contiguous()
            task_loss = torch.nn.functional.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            total_loss += task_loss * 1.0

            # --- Stability: replay examples from other concepts + anchors ---
            past_memories = self.replay_buffer.sample_stratified(concept_id, n_per_concept=2)
            for mem in past_memories:
                full_replay = f"User: {mem['prompt']}\nModel: {mem['answer']}"
                inputs_replay = self.model.tokenizer(full_replay, return_tensors="pt", max_length=Config.MAX_LENGTH, truncation=True).to(self.device)

                h_rep = self.model.get_embeddings(inputs_replay.input_ids).to(Config.DTYPE)
                mod_rep = self.model.controller(h_rep)
                logits_rep = self.model(inputs_replay.input_ids, forced_modulation=mod_rep)

                s_log = logits_rep[..., :-1, :].contiguous()
                s_lab = inputs_replay.input_ids[..., 1:].contiguous()
                loss_rep = torch.nn.functional.cross_entropy(s_log.view(-1, s_log.size(-1)), s_lab.view(-1))

                total_loss += loss_rep * 1.0

            # --- Stability: L2 penalty on controller drift from its snapshot ---
            drift_loss = 0
            for name, param in self.model.controller.named_parameters():
                drift_loss += torch.sum((param - self.initial_controller_state[name].to(self.device)) ** 2)
            total_loss += drift_loss * 10.0

            total_loss.backward()

            # Diagnostic only: controller gradient norm (no clipping applied).
            total_norm = 0.0
            for p in self.model.controller.parameters():
                if p.grad is not None:
                    total_norm += p.grad.data.norm(2).item() ** 2
            total_norm = total_norm ** 0.5

            self.optimizer.step()

            if (step+1) % 10 == 0:
                print(f" Step {step+1}: Loss {total_loss.item():.4f} | Grad Norm: {total_norm:.4f}")

            if total_loss.item() < 0.005:
                print(f" β Converged early at step {step+1} (Loss < 0.005)")
                break

        self.model.clear_modulation()

        print(" β Adaptation Complete. Weights Updated.")
        return total_loss.item()

    def save_weights(self, suffix="session"):
        """Save the updated weights after a learning session."""
        print(" πΎ Saving updated weights...")
        torch.save(self.model.controller.state_dict(), f"controller_{suffix}.pt")

        # Stored as a plain list of state_dicts, mirroring _load_weights().
        adapter_states = [l.state_dict() for l in self.model.flux_layers]
        torch.save(adapter_states, f"adapters_{suffix}.pt")

        print(f" β Saved to controller_{suffix}.pt and adapters_{suffix}.pt")

    def run(self):
        """Main interactive loop."""
        print("\n" + "="*60)
        print(" π§ͺ CONTINUOUS LEARNING LAB")
        print(" Commands:")
        print(" - Ask any physics question")
        print(" - Type 'wrong' if the answer is incorrect")
        print(" - Type 'save' to save updated weights")
        print(" - Type 'exit' to quit")
        print("="*60)

        while True:
            try:
                user_input = input("\nUSER: ").strip()
            except (EOFError, KeyboardInterrupt):
                break

            if not user_input:
                continue
            if user_input.lower() in ['exit', 'quit']:
                break
            if user_input.lower() == 'save':
                self.save_weights()
                continue

            response, modulation = self.predict(user_input)
            mod_norm = modulation.norm().item()

            print(f"MODEL: {response}")
            print(f" [Modulation Norm: {mod_norm:.2f}]")

            try:
                feedback = input(" (Enter=correct, 'wrong'=teach): ").strip().lower()
            except (EOFError, KeyboardInterrupt):
                break

            if feedback == "wrong":
                try:
                    truth = input(" CORRECT ANSWER: ").strip()
                    topic = "general"
                except (EOFError, KeyboardInterrupt):
                    break

                if truth:
                    self.learn(user_input, truth, topic)
            else:
                print(" π Perfect! (No update needed)")

        print("\nπ Session ended.")

        # BUGFIX: was a bare `except:`, which also swallows SystemExit;
        # only interactive-input interruptions should be ignored here.
        try:
            save = input(" Save updated weights? (y/n): ").strip().lower()
            if save == 'y':
                self.save_weights()
        except (EOFError, KeyboardInterrupt):
            pass
|
|
|
|
|
|
|
|
|
# Script entry point: build a session and hand control to its REPL.
if __name__ == "__main__":
    ContinuousLearningSession().run()
|
|
|
|