b2u commited on
Commit
6ee281b
·
1 Parent(s): 334eca1

moving LoRA settings under .yml

Browse files
Files changed (2) hide show
  1. docker-compose.yml +4 -4
  2. model.py +6 -6
docker-compose.yml CHANGED
@@ -15,13 +15,13 @@ services:
15
  - GENERATION_MAX_LENGTH=128
16
  - NUM_RETURN_SEQUENCES=1
17
  # LoRA settings
18
- - LORA_R=8
19
- - LORA_ALPHA=32
20
  - LORA_DROPOUT=0.1
21
  - LORA_TARGET_MODULES=q,v
22
  # Training settings
23
- - EPOCHS=3
24
- - LEARNING_RATE=1e-4
25
  - BATCH_SIZE=1
26
  - MAX_STEPS=100
27
  - SAVE_STEPS=50
 
15
  - GENERATION_MAX_LENGTH=128
16
  - NUM_RETURN_SEQUENCES=1
17
  # LoRA settings
18
+ - LORA_R=4
19
+ - LORA_ALPHA=8
20
  - LORA_DROPOUT=0.1
21
  - LORA_TARGET_MODULES=q,v
22
  # Training settings
23
+ - EPOCHS=6
24
+ - LEARNING_RATE=1e-5
25
  - BATCH_SIZE=1
26
  - MAX_STEPS=100
27
  - SAVE_STEPS=50
model.py CHANGED
@@ -207,10 +207,10 @@ class T5Model(LabelStudioMLBase):
207
 
208
  # Configure LoRA
209
  lora_config = LoraConfig(
210
- r=int(os.getenv('LORA_R', '4')),
211
- lora_alpha=int(os.getenv('LORA_ALPHA', '8')),
212
- target_modules=os.getenv('LORA_TARGET_MODULES', 'q,v').split(','),
213
- lora_dropout=float(os.getenv('LORA_DROPOUT', '0.1')),
214
  bias="none",
215
  task_type="SEQ_2_SEQ_LM"
216
  )
@@ -225,9 +225,9 @@ class T5Model(LabelStudioMLBase):
225
 
226
  # Training loop
227
  logger.info("Starting training loop...")
228
- optimizer = torch.optim.AdamW(model.parameters(), lr=float(os.getenv('LEARNING_RATE', '1e-5')))
229
 
230
- num_epochs = int(os.getenv('NUM_EPOCHS', '6'))
231
 
232
 
233
  # Add LoRA settings logging here
 
207
 
208
  # Configure LoRA
209
  lora_config = LoraConfig(
210
+ r=int(os.getenv('LORA_R')),
211
+ lora_alpha=int(os.getenv('LORA_ALPHA')),
212
+ target_modules=os.getenv('LORA_TARGET_MODULES').split(','),
213
+ lora_dropout=float(os.getenv('LORA_DROPOUT')),
214
  bias="none",
215
  task_type="SEQ_2_SEQ_LM"
216
  )
 
225
 
226
  # Training loop
227
  logger.info("Starting training loop...")
228
+ optimizer = torch.optim.AdamW(model.parameters(), lr=float(os.getenv('LEARNING_RATE')))
229
 
230
+ num_epochs = int(os.getenv('NUM_EPOCHS'))
231
 
232
 
233
  # Add LoRA settings logging here