"""
Forward synthesis model training script.
Trains T5 model to predict products from reactants.
"""
import os

import numpy as np
import sacrebleu
import torch
from transformers import (
AutoTokenizer,
T5ForConditionalGeneration,
Seq2SeqTrainingArguments,
Seq2SeqTrainer,
DataCollatorForSeq2Seq,
)
from config import (
TOKENIZER_NAME, FORWARD_MODEL_NAME, BATCH_SIZE,
GRADIENT_ACCUMULATION_STEPS, LEARNING_RATE, NUM_EPOCHS,
EVAL_STEPS, SAVE_STEPS, LOGGING_STEPS, BASE_MODEL
)
from data_utils import load_tokenized, get_tokenizer


def main():
"""Main training pipeline for forward synthesis."""
print("=" * 60)
print("Forward Synthesis Model Training")
print("=" * 60)
# Load datasets and tokenizer
print("\nLoading datasets...")
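    # load_tokenized returns a DatasetDict with pre-tokenized "train"/"validation"
    # (and optionally "test") splits; see data_utils for the input format.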
dataset = load_tokenized("forward")
tokenizer = get_tokenizer()
print(f"Train samples: {len(dataset['train'])}")
print(f"Validation samples: {len(dataset['validation'])}")
if "test" in dataset:
print(f"Test samples: {len(dataset['test'])}")
# Load model
print(f"\nLoading base model: {BASE_MODEL}")
model = T5ForConditionalGeneration.from_pretrained(BASE_MODEL)
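    # Resize embeddings in case get_tokenizer() added tokens beyond the base vocab.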
model.resize_token_embeddings(len(tokenizer))
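    # Prefer bf16 on compute capability >= 8.0 GPUs (Ampere/Ada, e.g. A100 or L4);
    # it avoids the fp16 loss-scaling instabilities T5 is known to hit.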
use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8
# Setup training arguments
print("\nSetting up training arguments...")
args = Seq2SeqTrainingArguments(
output_dir="./forward_model",
eval_strategy="steps",
save_strategy="steps",
logging_strategy="steps",
learning_rate=LEARNING_RATE,
per_device_train_batch_size=BATCH_SIZE,
per_device_eval_batch_size=BATCH_SIZE,
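        # Effective batch size = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS per device.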
gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
weight_decay=0.01,
save_total_limit=2,
num_train_epochs=NUM_EPOCHS,
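        # Run generate() during eval so compute_metrics receives decoded sequences.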
predict_with_generate=True,
logging_steps=LOGGING_STEPS,
eval_steps=EVAL_STEPS,
save_steps=SAVE_STEPS,
report_to=[],
        bf16=use_bf16,
        # Fall back to fp16 only when CUDA is present but bf16 is unsupported;
        # fp16=True on CPU would crash the run.
        fp16=torch.cuda.is_available() and not use_bf16,
dataloader_num_workers=4,
dataloader_pin_memory=True,
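        # Hub push requires HF_TOKEN in the environment with write access.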
push_to_hub=True,
hub_model_id=FORWARD_MODEL_NAME,
hub_strategy="every_save",
hub_token=os.environ.get("HF_TOKEN"),
)
    # Data collator: pads each batch dynamically to its longest sequence;
    # labels are padded with -100 so padded positions are ignored by the loss.
collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)
# Metrics
    def compute_metrics(eval_pred):
        """Corpus BLEU and exact-match accuracy over decoded output strings."""
        preds, labels = eval_pred
        if isinstance(preds, tuple):  # generate() can return extra outputs
            preds = preds[0]
        # Replace ignored positions (-100) with the pad id so decoding is clean.
        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_preds = [p.strip() for p in tokenizer.batch_decode(preds, skip_special_tokens=True)]
        decoded_labels = [r.strip() for r in tokenizer.batch_decode(labels, skip_special_tokens=True)]
        # sacrebleu expects a list of reference streams, hence the nesting.
        bleu = sacrebleu.corpus_bleu(decoded_preds, [decoded_labels])
        exact = np.mean([p == r for p, r in zip(decoded_preds, decoded_labels)])
        return {"bleu": bleu.score, "exact_match": float(exact)}
# Trainer
print("\nInitializing trainer...")
trainer = Seq2SeqTrainer(
model=model,
args=args,
train_dataset=dataset["train"],
eval_dataset=dataset["validation"],
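        # Note: newer transformers versions deprecate `tokenizer=` here in
        # favor of `processing_class=`.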
tokenizer=tokenizer,
data_collator=collator,
compute_metrics=compute_metrics,
)
# Train
print("\nStarting training...")
trainer.train()
# Evaluate on test set
if "test" in dataset:
print("\nEvaluating on test set...")
        test_results = trainer.evaluate(eval_dataset=dataset["test"], metric_key_prefix="test")
print(f"Test Results: {test_results}")
# Push to hub
print(f"\nPushing model to {FORWARD_MODEL_NAME}...")
trainer.push_to_hub()
print("\nForward model training complete!")
print(f"Model available at: https://huggingface.co/{FORWARD_MODEL_NAME}")
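

def smoke_test(reactants: str = "CCO.CC(=O)O") -> str:
    """Hedged sketch, not part of the training pipeline: load the pushed
    checkpoint and predict a product for one reactant string.

    Assumes the hub repo named by FORWARD_MODEL_NAME exists and that inputs
    are plain reactant SMILES; adjust to match how data_utils formats inputs.
    `num_beams` and `max_new_tokens` are illustrative choices.
    """
    tok = AutoTokenizer.from_pretrained(FORWARD_MODEL_NAME)
    mdl = T5ForConditionalGeneration.from_pretrained(FORWARD_MODEL_NAME)
    batch = tok(reactants, return_tensors="pt")
    out = mdl.generate(**batch, max_new_tokens=128, num_beams=5)
    return tok.batch_decode(out, skip_special_tokens=True)[0]
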
if __name__ == "__main__":
main()