Vaishnav14220
committed on
Commit
·
bd72f86
1
Parent(s):
fdbfba8
Tune training hyperparameters for L4 GPU
Browse files- src/config.py +2 -2
- src/train_forward.py +7 -1
- src/train_retro.py +7 -1
src/config.py
CHANGED
|
@@ -36,8 +36,8 @@ MAX_INPUT = 512
|
|
| 36 |
MAX_TARGET = 256
|
| 37 |
|
| 38 |
# Training Configuration
|
| 39 |
-
BATCH_SIZE =
|
| 40 |
-
GRADIENT_ACCUMULATION_STEPS =
|
| 41 |
LEARNING_RATE = 3e-4
|
| 42 |
NUM_EPOCHS = 5
|
| 43 |
EVAL_STEPS = 2000
|
|
|
|
| 36 |
MAX_TARGET = 256
|
| 37 |
|
| 38 |
# Training Configuration
|
| 39 |
+
BATCH_SIZE = 4
|
| 40 |
+
GRADIENT_ACCUMULATION_STEPS = 4
|
| 41 |
LEARNING_RATE = 3e-4
|
| 42 |
NUM_EPOCHS = 5
|
| 43 |
EVAL_STEPS = 2000
|
src/train_forward.py
CHANGED
|
@@ -4,6 +4,7 @@ Trains T5 model to predict products from reactants.
|
|
| 4 |
"""
|
| 5 |
import sacrebleu
|
| 6 |
import numpy as np
|
|
|
|
| 7 |
from transformers import (
|
| 8 |
AutoTokenizer,
|
| 9 |
T5ForConditionalGeneration,
|
|
@@ -41,6 +42,8 @@ def main():
|
|
| 41 |
model = T5ForConditionalGeneration.from_pretrained(BASE_MODEL)
|
| 42 |
model.resize_token_embeddings(len(tokenizer))
|
| 43 |
|
|
|
|
|
|
|
| 44 |
# Setup training arguments
|
| 45 |
print("\nSetting up training arguments...")
|
| 46 |
args = Seq2SeqTrainingArguments(
|
|
@@ -60,7 +63,10 @@ def main():
|
|
| 60 |
eval_steps=EVAL_STEPS,
|
| 61 |
save_steps=SAVE_STEPS,
|
| 62 |
report_to=[],
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
| 64 |
push_to_hub=True,
|
| 65 |
hub_model_id=FORWARD_MODEL_NAME,
|
| 66 |
hub_strategy="every_save",
|
|
|
|
| 4 |
"""
|
| 5 |
import sacrebleu
|
| 6 |
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
from transformers import (
|
| 9 |
AutoTokenizer,
|
| 10 |
T5ForConditionalGeneration,
|
|
|
|
| 42 |
model = T5ForConditionalGeneration.from_pretrained(BASE_MODEL)
|
| 43 |
model.resize_token_embeddings(len(tokenizer))
|
| 44 |
|
| 45 |
+
use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8
|
| 46 |
+
|
| 47 |
# Setup training arguments
|
| 48 |
print("\nSetting up training arguments...")
|
| 49 |
args = Seq2SeqTrainingArguments(
|
|
|
|
| 63 |
eval_steps=EVAL_STEPS,
|
| 64 |
save_steps=SAVE_STEPS,
|
| 65 |
report_to=[],
|
| 66 |
+
bf16=use_bf16,
|
| 67 |
+
fp16=not use_bf16,
|
| 68 |
+
dataloader_num_workers=4,
|
| 69 |
+
dataloader_pin_memory=True,
|
| 70 |
push_to_hub=True,
|
| 71 |
hub_model_id=FORWARD_MODEL_NAME,
|
| 72 |
hub_strategy="every_save",
|
src/train_retro.py
CHANGED
|
@@ -4,6 +4,7 @@ Trains T5 model to predict reactants from products.
|
|
| 4 |
"""
|
| 5 |
import sacrebleu
|
| 6 |
import numpy as np
|
|
|
|
| 7 |
from transformers import (
|
| 8 |
AutoTokenizer,
|
| 9 |
T5ForConditionalGeneration,
|
|
@@ -41,6 +42,8 @@ def main():
|
|
| 41 |
model = T5ForConditionalGeneration.from_pretrained(BASE_MODEL)
|
| 42 |
model.resize_token_embeddings(len(tokenizer))
|
| 43 |
|
|
|
|
|
|
|
| 44 |
# Setup training arguments
|
| 45 |
print("\nSetting up training arguments...")
|
| 46 |
args = Seq2SeqTrainingArguments(
|
|
@@ -60,7 +63,10 @@ def main():
|
|
| 60 |
eval_steps=EVAL_STEPS,
|
| 61 |
save_steps=SAVE_STEPS,
|
| 62 |
report_to=[],
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
| 64 |
push_to_hub=True,
|
| 65 |
hub_model_id=RETRO_MODEL_NAME,
|
| 66 |
hub_strategy="every_save",
|
|
|
|
| 4 |
"""
|
| 5 |
import sacrebleu
|
| 6 |
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
from transformers import (
|
| 9 |
AutoTokenizer,
|
| 10 |
T5ForConditionalGeneration,
|
|
|
|
| 42 |
model = T5ForConditionalGeneration.from_pretrained(BASE_MODEL)
|
| 43 |
model.resize_token_embeddings(len(tokenizer))
|
| 44 |
|
| 45 |
+
use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8
|
| 46 |
+
|
| 47 |
# Setup training arguments
|
| 48 |
print("\nSetting up training arguments...")
|
| 49 |
args = Seq2SeqTrainingArguments(
|
|
|
|
| 63 |
eval_steps=EVAL_STEPS,
|
| 64 |
save_steps=SAVE_STEPS,
|
| 65 |
report_to=[],
|
| 66 |
+
bf16=use_bf16,
|
| 67 |
+
fp16=not use_bf16,
|
| 68 |
+
dataloader_num_workers=4,
|
| 69 |
+
dataloader_pin_memory=True,
|
| 70 |
push_to_hub=True,
|
| 71 |
hub_model_id=RETRO_MODEL_NAME,
|
| 72 |
hub_strategy="every_save",
|