b2u committed on
Commit
5b6ee0c
·
1 Parent(s): 6ee281b

rolling back the settings

Browse files
Files changed (2) hide show
  1. docker-compose.yml +3 -3
  2. model.py +6 -6
docker-compose.yml CHANGED
@@ -15,12 +15,12 @@ services:
15
  - GENERATION_MAX_LENGTH=128
16
  - NUM_RETURN_SEQUENCES=1
17
  # LoRA settings
18
- - LORA_R=4
19
- - LORA_ALPHA=8
20
  - LORA_DROPOUT=0.1
21
  - LORA_TARGET_MODULES=q,v
22
  # Training settings
23
- - EPOCHS=6
24
  - LEARNING_RATE=1e-5
25
  - BATCH_SIZE=1
26
  - MAX_STEPS=100
 
15
  - GENERATION_MAX_LENGTH=128
16
  - NUM_RETURN_SEQUENCES=1
17
  # LoRA settings
18
+ - LORA_R=16
19
+ - LORA_ALPHA=16
20
  - LORA_DROPOUT=0.1
21
  - LORA_TARGET_MODULES=q,v
22
  # Training settings
23
+ - NUM_EPOCHS=16
24
  - LEARNING_RATE=1e-5
25
  - BATCH_SIZE=1
26
  - MAX_STEPS=100
model.py CHANGED
@@ -207,10 +207,10 @@ class T5Model(LabelStudioMLBase):
207
 
208
  # Configure LoRA
209
  lora_config = LoraConfig(
210
- r=int(os.getenv('LORA_R')),
211
- lora_alpha=int(os.getenv('LORA_ALPHA')),
212
- target_modules=os.getenv('LORA_TARGET_MODULES').split(','),
213
- lora_dropout=float(os.getenv('LORA_DROPOUT')),
214
  bias="none",
215
  task_type="SEQ_2_SEQ_LM"
216
  )
@@ -225,9 +225,9 @@ class T5Model(LabelStudioMLBase):
225
 
226
  # Training loop
227
  logger.info("Starting training loop...")
228
- optimizer = torch.optim.AdamW(model.parameters(), lr=float(os.getenv('LEARNING_RATE')))
229
 
230
- num_epochs = int(os.getenv('NUM_EPOCHS'))
231
 
232
 
233
  # Add LoRA settings logging here
 
207
 
208
  # Configure LoRA
209
  lora_config = LoraConfig(
210
+ r=int(os.getenv('LORA_R', '4')),
211
+ lora_alpha=int(os.getenv('LORA_ALPHA', '8')),
212
+ target_modules=os.getenv('LORA_TARGET_MODULES', 'q,v').split(','),
213
+ lora_dropout=float(os.getenv('LORA_DROPOUT', '0.1')),
214
  bias="none",
215
  task_type="SEQ_2_SEQ_LM"
216
  )
 
225
 
226
  # Training loop
227
  logger.info("Starting training loop...")
228
+ optimizer = torch.optim.AdamW(model.parameters(), lr=float(os.getenv('LEARNING_RATE', '1e-5')))
229
 
230
+ num_epochs = int(os.getenv('NUM_EPOCHS', '6'))
231
 
232
 
233
  # Add LoRA settings logging here