File size: 20,080 Bytes
17ba477
 
 
f3fff34
 
 
 
 
17ba477
 
 
 
 
 
 
 
 
 
 
8a52161
 
 
 
 
 
 
 
 
 
 
 
 
17ba477
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1c4be0
 
17ba477
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b2e77d
17ba477
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fef69a4
 
8b2e77d
17ba477
 
fef69a4
 
 
 
8b2e77d
 
 
 
 
17ba477
8b2e77d
 
 
17ba477
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3fff34
 
 
 
17ba477
349f5a8
 
17ba477
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349f5a8
 
17ba477
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
import random
import logging
import os
import gc

# Optimize CUDA memory allocation to reduce fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
import torch.nn as nn
import torch.optim as optim
from modeling_physics_rl import PhysicsModel, Config

class StratifiedReplayBuffer:
    """
    Replay memory bucketed by concept id.

    ``memory`` maps a concept id to a list of ``{"prompt", "answer"}`` records.
    Sampling draws from every bucket except the one currently being taught, so
    a replay batch mixes different concepts. A fixed ``"anchor"`` bucket of
    general question/answer pairs is inserted when the buffer is created.
    """

    def __init__(self):
        # concept id -> list of {"prompt": ..., "answer": ...} records
        self.memory = {}
        # Seed the buffer so there is always something to replay, even before
        # the first user-taught example arrives.
        self._add_anchor_memories()

    def _add_anchor_memories(self):
        """Store the built-in general-knowledge pairs under the "anchor" key."""
        anchor_pairs = (
            (
                "What is gravity?",
                "Gravity is a fundamental interaction which causes mutual attraction between all things with mass or energy.",
            ),
            (
                "Hello",
                "Hello! How can I help you today?",
            ),
            (
                "What is AI?",
                "Artificial Intelligence (AI) refers to the simulation of human intelligence in machines.",
            ),
            (
                "Define thermodynamics.",
                "Thermodynamics is a branch of physics that deals with heat, work, and temperature, and their relation to energy, entropy, and the physical properties of matter.",
            ),
            (
                "Who are you?",
                "I am a large language model, trained by Google.",
            ),
        )
        self.memory["anchor"] = [
            dict(prompt=question, answer=reply) for question, reply in anchor_pairs
        ]
        print(f"   βš“ Added {len(anchor_pairs)} General Knowledge Anchors to Replay Buffer.")

    def add(self, concept_id, prompt, answer):
        """Append one prompt/answer record to the bucket for ``concept_id``."""
        # setdefault creates the bucket on first use of a concept id.
        self.memory.setdefault(concept_id, []).append(dict(prompt=prompt, answer=answer))

    def sample_stratified(self, current_concept_id, n_per_concept=1):
        """
        Draw up to ``n_per_concept`` random records from every other bucket.

        The bucket named ``current_concept_id`` is skipped. Buckets are visited
        in insertion order; a bucket holding fewer than ``n_per_concept``
        records contributes all of them. Returns a flat list, empty when no
        other bucket exists.
        """
        drawn = []
        other_ids = (cid for cid in self.memory if cid != current_concept_id)
        for cid in other_ids:
            pool = self.memory[cid]
            take = min(len(pool), n_per_concept)
            drawn.extend(random.sample(pool, take))
        return drawn


class ContinuousLearningSession:
    def __init__(self):
        """
        Build the model, load saved weights, and prepare online training.

        In order: choose the device, construct ``PhysicsModel`` and move it to
        that device, call ``_load_weights``, freeze ``model.llm`` while marking
        the controller and flux-adapter parameters trainable, create one Adam
        optimizer (lr=1e-4) over exactly those trainable parameters, put the
        model in train mode, and finally create the replay buffer together
        with a snapshot of the controller weights that ``learn`` uses as its
        drift anchor.
        """
        print("🧠 Initializing Continuous Learning Session...")

        # 1. Load Model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"   πŸš€ Using Device: {self.device}")

        self.model = PhysicsModel()
        self.model.to(self.device)  # Force move to GPU

        # 2. Load Pre-trained Weights
        self._load_weights()

        # 3. Session Memory (Context Window)
        # Kept for compatibility: every use of these two attributes visible in
        # this file is commented out.
        self.session_context = []  # List of (input, modulation) pairs
        self.context_modulation = None  # Accumulated modulation bias

        # 4. Freeze the backbone first, then re-enable the trainable parts.
        # Order matters in case the adapters are registered inside the LLM.
        for p in self.model.llm.parameters():
            p.requires_grad = False

        print("   πŸ”§ Unfreezing Controller & Flux Adapters...")
        controller_params = list(self.model.controller.parameters())

        # flux_layers may be an nn.Module container (ModuleList/Sequential) or
        # a plain Python sequence of modules; collect parameters either way.
        flux_layers = self.model.flux_layers
        if isinstance(flux_layers, torch.nn.Module):
            adapter_params = list(flux_layers.parameters())
        else:
            adapter_params = [p for layer in flux_layers for p in layer.parameters()]

        trainable_params = controller_params + adapter_params
        for p in trainable_params:
            p.requires_grad = True

        # 5. Setup Online Optimizer
        # A single Adam optimizer over the controller and adapter parameters.
        # (Memory safety relies on the GC / allocator settings, not on the
        # optimizer choice.)
        self.optimizer = optim.Adam(trainable_params, lr=1e-4)

        self.model.train()  # Enable gradients for Controller/Adapters

        # 6. Initialize Replay Buffer & Drift Anchor
        self.replay_buffer = StratifiedReplayBuffer()
        # Cloned copy of the controller weights at session start; learn()
        # penalizes squared distance from this snapshot.
        self.initial_controller_state = {
            k: v.clone() for k, v in self.model.controller.state_dict().items()
        }

        print("   βœ… Ready for Interactive Continuous Learning (Powered by Replay Buffer)!")
        
    def _load_weights(self):
        """Load pre-trained weights from various possible locations."""
        search_paths = [
            ".",
            "/kaggle/input/worldmodels/physics_model",
            "/kaggle/working/physics_model"
        ]
        
        for path in search_paths:
            controller_path = os.path.join(path, "final_physics_controller.pt")
            if os.path.exists(controller_path):
                print(f"   Loading weights from {path}...")
                self.model.controller.load_state_dict(
                    torch.load(controller_path, map_location=self.device)
                )
                
                # Load WALT
                walt_path = os.path.join(path, "final_walt_head.pt")
                if os.path.exists(walt_path):
                    self.model.walt.load_state_dict(
                        torch.load(walt_path, map_location=self.device)
                    )
                
                # Load Adapters
                adapter_path = os.path.join(path, "final_liquid_adapters.pt")
                if os.path.exists(adapter_path):
                    adapter_states = torch.load(adapter_path, map_location=self.device)
                    for layer, state in zip(self.model.flux_layers, adapter_states):
                        layer.load_state_dict(state)
                    print("   βœ… Loaded Flux Adapters.")
                
                return
        
        print("   ⚠️ No pre-trained weights found. Using random initialization.")
    

    # def _get_context_modulation(self):
    #     """
    #     Compute a modulation bias from session history.
    #     This allows the model to "remember" previous physics context.
    #     """
    #     if not self.session_context:
    #         return None
        
    #     # Average the modulations from recent context (last 3 interactions)
    #     recent = self.session_context[-3:]
    #     mods = [m for _, m in recent if m is not None]
        
    #     if not mods:
    #         return None
            
    #     # Stack and average
    #     stacked = torch.stack(mods)
    #     return stacked.mean(dim=0)
    
    def predict(self, user_input: str):
        """
        Generate a response using the current Controller & Flux Adapters.

        Pure inference: no session history is used, only the current weights.
        The whole forward path runs under ``torch.no_grad()`` so that no
        autograd graph is built for the (trainable) controller, and the active
        modulation is always cleared afterwards, even if generation raises.

        Args:
            user_input: The raw user message; it is wrapped as
                ``"User: ...\\nModel:"`` before tokenization.

        Returns:
            A ``(response, modulation)`` tuple: the decoded text after the
            last ``"Model:"`` marker, stripped, and the detached modulation
            tensor produced by the controller for this input.
        """
        self.model.eval()

        full_prompt = f"User: {user_input}\nModel:"
        inputs = self.model.tokenizer(full_prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            # 1. Generate Modulation (based strictly on the CURRENT input)
            h_init = self.model.get_embeddings(inputs.input_ids).to(Config.DTYPE)
            modulation = self.model.controller(h_init)

            # 2. No context bias: behavior depends solely on the weight
            #    updates made by learn().

            # 3. Apply modulation and generate
            self.model.set_active_modulation(modulation)
            try:
                out_ids = self.model.llm.generate(
                    **inputs,
                    max_new_tokens=100,  # Increased for chat
                    do_sample=True,
                    temperature=0.7,
                    repetition_penalty=1.0,  # Default value; 1.2 caused empty replies
                    pad_token_id=self.model.tokenizer.eos_token_id
                )
            finally:
                # Never leave a stale modulation installed on the model.
                self.model.clear_modulation()

        response = self.model.tokenizer.decode(out_ids[0], skip_special_tokens=True)
        # The decoded text is split on the "Model:" marker and only the final
        # segment is kept.
        response_clean = response.split("Model:")[-1].strip()

        return response_clean, modulation.detach()
    
    def _generate_synthetic_data(self, question, answer, num_variations=3):
        """
        Ask the base LLM to paraphrase one question/answer pair.

        The modulation is cleared and the model is switched to eval mode, then
        the LLM is prompted to rewrite the pair in ``num_variations`` styles.
        The generated text is scanned line by line for ``Q...:`` / ``A...:``
        pairs. The original pair is always the first entry of the result; if
        nothing valid was parsed, the original is appended a second time so
        the list has at least two entries.

        Side effects visible here: the model is left in eval mode and
        ``torch.cuda.empty_cache()`` is called.

        Returns:
            A list of ``{"q": ..., "a": ...}`` dicts.
        """
        print("   ✨ Generating synthetic training data (Self-Distillation)...")

        # Generate with modulation cleared and the model in eval mode.
        self.model.clear_modulation()
        self.model.eval()

        prompt = "\n".join([
            f"Original Question: {question}",
            f"Original Answer: {answer}",
            "",
            f"Task: Rewrite the above Question and Answer pair in {num_variations} different styles (e.g. simple, formal, detailed). Keep the facts exactly the same.",
            "Output format:",
            "Q1: ...",
            "A1: ...",
            "Q2: ...",
            "A2: ...",
            "Start now:",
        ])

        inputs = self.model.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            out_ids = self.model.llm.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7
            )

        raw_text = self.model.tokenizer.decode(out_ids[0], skip_special_tokens=True)

        # Heuristic parse: a line starting with "Q" opens a pair, the next
        # line starting with "A" closes it. The original always comes first.
        variations = [{"q": question, "a": answer}]
        pending_q = None
        for raw_line in raw_text.split('\n'):
            text = raw_line.strip()
            if ":" not in text:
                continue
            payload = text.partition(":")[2].strip()
            if text.startswith("Q"):
                pending_q = payload
            elif text.startswith("A") and pending_q:
                # Reject empty answers and anything still containing the
                # "..." placeholder from the format example.
                if payload and "..." not in pending_q and "..." not in payload:
                    variations.append({"q": pending_q, "a": payload})
                pending_q = None

        # Release the generation tensors before training starts.
        del inputs, out_ids
        torch.cuda.empty_cache()

        # Fallback: nothing usable was parsed, so repeat the original pair.
        if len(variations) == 1:
            print("   ⚠️ Synthetic generation failed to produce valid format. Duplicating original.")
            variations.append({"q": question, "a": answer})

        print(f"   ✨ Generated {len(variations)-1} synthetic variations.")
        for idx, pair in enumerate(variations):
            print(f"      [{idx}] Q: {pair['q'][:30]}... A: {pair['a'][:30]}...")

        return variations

    def learn(self, user_input: str, correct_answer: str, concept_id: str = "general"):
        """
        Update the trainable weights from one corrected example.

        The example is first expanded into paraphrases by
        ``_generate_synthetic_data`` and recorded in the replay buffer under
        ``concept_id``. Then up to 20 optimizer steps are run. Each step sums:

        * the next-token cross-entropy on one randomly chosen paraphrase
          (with the tokenizer's EOS token appended),
        * the same loss on up to two replayed records from every *other*
          concept bucket (no EOS appended),
        * 10x the squared distance between the controller parameters and the
          snapshot taken at session start.

        The loop stops early once the combined loss falls below 0.005.

        NOTE(review): ``_generate_synthetic_data`` puts the model in eval mode
        and nothing here switches it back to train mode - confirm intended.

        Returns:
            The combined loss of the last executed step, as a Python float.
        """
        print("\n   🧠 Starting Robust Adaptation Loop...")

        # 0. Augment the example, 1. remember it for future replay.
        training_batch = self._generate_synthetic_data(user_input, correct_answer)
        self.replay_buffer.add(concept_id, user_input, correct_answer)

        # Free whatever we can before the training loop to reduce OOM risk.
        gc.collect()
        torch.cuda.empty_cache()

        def _lm_loss(text):
            """Next-token cross-entropy of ``text`` under controller modulation."""
            encoded = self.model.tokenizer(
                text, return_tensors="pt", max_length=Config.MAX_LENGTH, truncation=True
            ).to(self.device)
            hidden = self.model.get_embeddings(encoded.input_ids).to(Config.DTYPE)
            predicted_mod = self.model.controller(hidden)
            logits = self.model(encoded.input_ids, forced_modulation=predicted_mod)
            # Position t predicts token t+1.
            pred = logits[..., :-1, :].contiguous()
            target = encoded.input_ids[..., 1:].contiguous()
            return torch.nn.functional.cross_entropy(pred.view(-1, pred.size(-1)), target.view(-1))

        # 2. Training Loop (Micro-Epochs)
        max_steps = 20
        for step_idx in range(max_steps):
            self.optimizer.zero_grad()
            total_loss = 0

            # --- A. Current task: one random paraphrase, terminated by EOS ---
            chosen = random.choice(training_batch)
            eos = self.model.tokenizer.eos_token
            total_loss += _lm_loss(f"User: {chosen['q']}\nModel: {chosen['a']}{eos}") * 1.0

            # --- B. Replay (stability): weighted the same as the task ---
            for memory in self.replay_buffer.sample_stratified(concept_id, n_per_concept=2):
                total_loss += _lm_loss(f"User: {memory['prompt']}\nModel: {memory['answer']}") * 1.0

            # --- C. Anti-drift: stay close to the session-start controller ---
            drift_loss = 0
            for param_name, param in self.model.controller.named_parameters():
                anchor = self.initial_controller_state[param_name].to(self.device)
                drift_loss += torch.sum((param - anchor) ** 2)
            total_loss += drift_loss * 10.0

            total_loss.backward()

            # Controller gradient L2 norm, for the progress print only.
            squared_norms = sum(
                p.grad.data.norm(2).item() ** 2
                for p in self.model.controller.parameters()
                if p.grad is not None
            )
            total_norm = squared_norms ** 0.5

            self.optimizer.step()

            step_number = step_idx + 1
            loss_value = total_loss.item()
            if step_number % 10 == 0:
                print(f"      Step {step_number}: Loss {loss_value:.4f} | Grad Norm: {total_norm:.4f}")

            # Early stopping (prevent overfitting)
            if loss_value < 0.005:
                print(f"      βœ… Converged early at step {step_number} (Loss < 0.005)")
                break

        # 3. Context storage is disabled; just make sure no modulation lingers.
        self.model.clear_modulation()

        print("   βœ… Adaptation Complete. Weights Updated.")
        return total_loss.item()
    
    def save_weights(self, suffix="session"):
        """
        Write the controller and adapter weights to the working directory.

        Produces ``controller_<suffix>.pt`` (the controller state dict) and
        ``adapters_<suffix>.pt`` (a list with one state dict per flux layer).
        """
        controller_file = f"controller_{suffix}.pt"
        adapters_file = f"adapters_{suffix}.pt"

        print("   πŸ’Ύ Saving updated weights...")
        torch.save(self.model.controller.state_dict(), controller_file)

        per_layer_states = [layer.state_dict() for layer in self.model.flux_layers]
        torch.save(per_layer_states, adapters_file)

        print(f"   βœ… Saved to {controller_file} and {adapters_file}")
    
    def run(self):
        """
        Run the interactive question / feedback loop on stdin and stdout.

        Each turn reads a line, handles the ``exit``/``quit`` and ``save``
        commands, otherwise prints the model's answer and asks for feedback.
        Typing ``wrong`` prompts for the correct answer and calls ``learn``
        with the topic fixed to ``"general"``; any other feedback is treated
        as "correct". EOF or Ctrl-C at any prompt ends the loop. After the
        loop the user is offered a final save; a failure while saving is
        reported instead of being silently discarded.
        """
        print("\n" + "="*60)
        print(" πŸ§ͺ CONTINUOUS LEARNING LAB")
        print(" Commands:")
        print("   - Ask any physics question")
        print("   - Type 'wrong' if the answer is incorrect")
        print("   - Type 'save' to save updated weights")
        print("   - Type 'exit' to quit")
        print("="*60)

        while True:
            try:
                user_input = input("\nUSER: ").strip()
            except (EOFError, KeyboardInterrupt):
                break

            if not user_input:
                continue
            if user_input.lower() in ['exit', 'quit']:
                break
            if user_input.lower() == 'save':
                self.save_weights()
                continue

            # Generate prediction
            response, modulation = self.predict(user_input)
            mod_norm = modulation.norm().item()

            print(f"MODEL: {response}")
            print(f"   [Modulation Norm: {mod_norm:.2f}]")

            # Feedback loop
            try:
                feedback = input("   (Enter=correct, 'wrong'=teach): ").strip().lower()
            except (EOFError, KeyboardInterrupt):
                break

            if feedback == "wrong":
                try:
                    truth = input("   CORRECT ANSWER: ").strip()
                except (EOFError, KeyboardInterrupt):
                    break
                # Topic prompt is disabled; everything is filed under "general".
                topic = "general"

                if truth:
                    # Pass the topic to learn so the replay buffer indexes it.
                    self.learn(user_input, truth, topic)
            else:
                print("   πŸ‘ Perfect! (No update needed)")

        print("\nπŸ‘‹ Session ended.")

        # Offer to save. Only the prompt itself may be interrupted quietly;
        # a failed save must be visible to the user.
        try:
            save = input("   Save updated weights? (y/n): ").strip().lower()
        except (EOFError, KeyboardInterrupt):
            return

        if save == 'y':
            try:
                self.save_weights()
            except Exception as exc:
                print(f"   Failed to save weights: {exc}")


# Script entry point: constructing the session builds the model and loads
# weights, then run() starts the interactive loop. Nothing runs on import.
if __name__ == "__main__":
    session = ContinuousLearningSession()
    session.run()