Upload continuous_learning_session.py
Browse files
continuous_learning_session.py
CHANGED
|
@@ -1,6 +1,11 @@
|
|
| 1 |
import random
|
| 2 |
import logging
|
| 3 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import torch
|
| 5 |
import torch.nn as nn
|
| 6 |
import torch.optim as optim
|
|
@@ -84,7 +89,8 @@ class ContinuousLearningSession:
|
|
| 84 |
# If it's a python list
|
| 85 |
adapter_params = [p for layer in self.model.flux_layers for p in layer.parameters()]
|
| 86 |
|
| 87 |
-
|
|
|
|
| 88 |
|
| 89 |
self.model.train() # Enable gradients for Controller/Adapters
|
| 90 |
|
|
@@ -270,6 +276,10 @@ class ContinuousLearningSession:
|
|
| 270 |
# 1. Add new knowledge to Buffer
|
| 271 |
self.replay_buffer.add(concept_id, user_input, correct_answer)
|
| 272 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
# 2. Training Loop (Micro-Epochs)
|
| 274 |
steps = 50 # Increase back to 50 since we have more data now!
|
| 275 |
|
|
|
|
| 1 |
import random
|
| 2 |
import logging
|
| 3 |
import os
|
| 4 |
+
import gc
|
| 5 |
+
|
| 6 |
+
# Optimize CUDA memory allocation to reduce fragmentation
|
| 7 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 8 |
+
|
| 9 |
import torch
|
| 10 |
import torch.nn as nn
|
| 11 |
import torch.optim as optim
|
|
|
|
| 89 |
# If it's a python list
|
| 90 |
adapter_params = [p for layer in self.model.flux_layers for p in layer.parameters()]
|
| 91 |
|
| 92 |
+
# Switch to SGD to save memory (Adam uses 2x states, causing OOM on T4)
|
| 93 |
+
self.optimizer = optim.SGD(controller_params + adapter_params, lr=1e-3, momentum=0.9)
|
| 94 |
|
| 95 |
self.model.train() # Enable gradients for Controller/Adapters
|
| 96 |
|
|
|
|
| 276 |
# 1. Add new knowledge to Buffer
|
| 277 |
self.replay_buffer.add(concept_id, user_input, correct_answer)
|
| 278 |
|
| 279 |
+
# Force cleanup before training to prevent OOM
|
| 280 |
+
gc.collect()
|
| 281 |
+
torch.cuda.empty_cache()
|
| 282 |
+
|
| 283 |
# 2. Training Loop (Micro-Epochs)
|
| 284 |
steps = 50 # Increase back to 50 since we have more data now!
|
| 285 |
|