convaiinnovations committed on
Commit
f3fff34
·
verified ·
1 Parent(s): fef69a4

Upload continuous_learning_session.py

Browse files
Files changed (1) hide show
  1. continuous_learning_session.py +11 -1
continuous_learning_session.py CHANGED
@@ -1,6 +1,11 @@
1
  import random
2
  import logging
3
  import os
 
 
 
 
 
4
  import torch
5
  import torch.nn as nn
6
  import torch.optim as optim
@@ -84,7 +89,8 @@ class ContinuousLearningSession:
84
  # If it's a python list
85
  adapter_params = [p for layer in self.model.flux_layers for p in layer.parameters()]
86
 
87
- self.optimizer = optim.Adam(controller_params + adapter_params, lr=1e-4) # Reduced from 1e-3 for stability
 
88
 
89
  self.model.train() # Enable gradients for Controller/Adapters
90
 
@@ -270,6 +276,10 @@ class ContinuousLearningSession:
270
  # 1. Add new knowledge to Buffer
271
  self.replay_buffer.add(concept_id, user_input, correct_answer)
272
 
 
 
 
 
273
  # 2. Training Loop (Micro-Epochs)
274
  steps = 50 # Increase back to 50 since we have more data now!
275
 
 
1
  import random
2
  import logging
3
  import os
4
+ import gc
5
+
6
+ # Optimize CUDA memory allocation to reduce fragmentation
7
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
8
+
9
  import torch
10
  import torch.nn as nn
11
  import torch.optim as optim
 
89
  # If it's a python list
90
  adapter_params = [p for layer in self.model.flux_layers for p in layer.parameters()]
91
 
92
+ # Switch to SGD to save memory (Adam uses 2x states, causing OOM on T4)
93
+ self.optimizer = optim.SGD(controller_params + adapter_params, lr=1e-3, momentum=0.9)
94
 
95
  self.model.train() # Enable gradients for Controller/Adapters
96
 
 
276
  # 1. Add new knowledge to Buffer
277
  self.replay_buffer.add(concept_id, user_input, correct_answer)
278
 
279
+ # Force cleanup before training to prevent OOM
280
+ gc.collect()
281
+ torch.cuda.empty_cache()
282
+
283
  # 2. Training Loop (Micro-Epochs)
284
  steps = 50 # Increase back to 50 since we have more data now!
285