"""Forward synthesis model training script.

Trains a T5 model to predict reaction products from reactants.
"""

import os

import numpy as np
import sacrebleu
import torch
from transformers import (
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    T5ForConditionalGeneration,
)

from config import (
    BASE_MODEL,
    BATCH_SIZE,
    EVAL_STEPS,
    FORWARD_MODEL_NAME,
    GRADIENT_ACCUMULATION_STEPS,
    LEARNING_RATE,
    LOGGING_STEPS,
    NUM_EPOCHS,
    SAVE_STEPS,
)
from data_utils import get_tokenizer, load_tokenized


def main():
    """Main training pipeline for forward synthesis."""
    print("=" * 60)
    print("Forward Synthesis Model Training")
    print("=" * 60)

    print("\nLoading datasets...")
    dataset = load_tokenized("forward")
    tokenizer = get_tokenizer()

    print(f"Train samples: {len(dataset['train'])}")
    print(f"Validation samples: {len(dataset['validation'])}")
    if "test" in dataset:
        print(f"Test samples: {len(dataset['test'])}")

    print(f"\nLoading base model: {BASE_MODEL}")
    model = T5ForConditionalGeneration.from_pretrained(BASE_MODEL)
    # Resize embeddings in case tokens were added to the base tokenizer.
    model.resize_token_embeddings(len(tokenizer))

    # bf16 needs Ampere (compute capability >= 8) or newer hardware.
    use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8
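    # bf16 keeps fp32's exponent range, so it sidesteps the loss-scaling
    # instabilities fp16 can hit; fp16 is used below only as a CUDA fallback.
    # Note the capability check assumes GPU index 0 is representative.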

    print("\nSetting up training arguments...")
    args = Seq2SeqTrainingArguments(
        output_dir="./forward_model",
        eval_strategy="steps",
        save_strategy="steps",
        logging_strategy="steps",
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        weight_decay=0.01,
        save_total_limit=2,
        num_train_epochs=NUM_EPOCHS,
        predict_with_generate=True,
        logging_steps=LOGGING_STEPS,
        eval_steps=EVAL_STEPS,
        save_steps=SAVE_STEPS,
        report_to=[],
        bf16=use_bf16,
        # fp16 only when CUDA is available; enabling it on CPU raises an error.
        fp16=torch.cuda.is_available() and not use_bf16,
        dataloader_num_workers=4,
        dataloader_pin_memory=True,
        push_to_hub=True,
        hub_model_id=FORWARD_MODEL_NAME,
        hub_strategy="every_save",
        hub_token=os.environ.get("HF_TOKEN"),
    )
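
    # Added safeguard (not part of the original flow): hub_token silently
    # becomes None when HF_TOKEN is unset, and pushing then depends on
    # cached `huggingface-cli login` credentials. Surfacing that early is
    # an assumption about how the script is deployed.
    if args.push_to_hub and not os.environ.get("HF_TOKEN"):
        print("Warning: HF_TOKEN is unset; pushing will rely on cached credentials.")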

    collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)
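    # DataCollatorForSeq2Seq pads each batch dynamically and fills label
    # padding with -100 so the loss ignores it; compute_metrics below maps
    # -100 back to pad_token_id before decoding.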

    def compute_metrics(eval_pred):
        preds, labels = eval_pred
        # Generation may return a tuple of (sequences, scores).
        if isinstance(preds, tuple):
            preds = preds[0]
        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        decoded_preds = [p.strip() for p in decoded_preds]
        decoded_labels = [l.strip() for l in decoded_labels]

        bleu = sacrebleu.corpus_bleu(decoded_preds, [decoded_labels])
        exact = np.mean([p == l for p, l in zip(decoded_preds, decoded_labels)])

        return {"bleu": bleu.score, "exact_match": exact}

    print("\nInitializing trainer...")
    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        tokenizer=tokenizer,
        data_collator=collator,
        compute_metrics=compute_metrics,
    )

    print("\nStarting training...")
    trainer.train()
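    # To resume an interrupted run, trainer.train(resume_from_checkpoint=True)
    # picks up from the latest checkpoint in output_dir.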

    if "test" in dataset:
        print("\nEvaluating on test set...")
        # metric_key_prefix labels the returned metrics test_* instead of
        # the default eval_*.
        test_results = trainer.evaluate(dataset["test"], metric_key_prefix="test")
        print(f"Test Results: {test_results}")

    print(f"\nPushing model to {FORWARD_MODEL_NAME}...")
    trainer.push_to_hub()

    print("\nForward model training complete!")
    print(f"Model available at: https://huggingface.co/{FORWARD_MODEL_NAME}")


if __name__ == "__main__":
    main()