"""
Forward synthesis model training script.
Trains T5 model to predict products from reactants.
"""
import os

import numpy as np
import sacrebleu
import torch
from transformers import (
    AutoTokenizer,
    T5ForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)

from config import (
    TOKENIZER_NAME, FORWARD_MODEL_NAME, BATCH_SIZE,
    GRADIENT_ACCUMULATION_STEPS, LEARNING_RATE, NUM_EPOCHS,
    EVAL_STEPS, SAVE_STEPS, LOGGING_STEPS, BASE_MODEL
)
from data_utils import load_tokenized, get_tokenizer

def main():
    """Main training pipeline for forward synthesis."""
    print("=" * 60)
    print("Forward Synthesis Model Training")
    print("=" * 60)

    # Load datasets and tokenizer
    print("\nLoading datasets...")
    dataset = load_tokenized("forward")
    tokenizer = get_tokenizer()
    print(f"Train samples: {len(dataset['train'])}")
    print(f"Validation samples: {len(dataset['validation'])}")
    if "test" in dataset:
        print(f"Test samples: {len(dataset['test'])}")

    # Load model
    print(f"\nLoading base model: {BASE_MODEL}")
    model = T5ForConditionalGeneration.from_pretrained(BASE_MODEL)
    model.resize_token_embeddings(len(tokenizer))
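
    # Use bf16 on Ampere-or-newer GPUs (compute capability >= 8); otherwise
    # fall back to fp16 on older GPUs, or full precision on CPU.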
    use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8
    # Setup training arguments
    print("\nSetting up training arguments...")
    args = Seq2SeqTrainingArguments(
        output_dir="./forward_model",
        eval_strategy="steps",
        save_strategy="steps",
        logging_strategy="steps",
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        weight_decay=0.01,
        save_total_limit=2,
        num_train_epochs=NUM_EPOCHS,
        predict_with_generate=True,
        logging_steps=LOGGING_STEPS,
        eval_steps=EVAL_STEPS,
        save_steps=SAVE_STEPS,
        report_to=[],
        bf16=use_bf16,
        fp16=torch.cuda.is_available() and not use_bf16,
        dataloader_num_workers=4,
        dataloader_pin_memory=True,
        push_to_hub=True,
        hub_model_id=FORWARD_MODEL_NAME,
        hub_strategy="every_save",
        hub_token=os.environ.get("HF_TOKEN"),
    )
    # Data collator: dynamically pads inputs and labels to the longest
    # sequence in each batch.
    collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)
    # Metrics: with predict_with_generate=True, the Trainer passes generated
    # token IDs as predictions.
    def compute_metrics(eval_pred):
        preds, labels = eval_pred
        # Replace -100 (the ignore index used for padded label positions)
        # so the tokenizer can decode the sequences.
        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        decoded_preds = [p.strip() for p in decoded_preds]
        decoded_labels = [l.strip() for l in decoded_labels]
        bleu = sacrebleu.corpus_bleu(decoded_preds, [decoded_labels])
        exact = np.mean([p == l for p, l in zip(decoded_preds, decoded_labels)])
        return {"bleu": bleu.score, "exact_match": exact}
    # Trainer
    print("\nInitializing trainer...")
    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        tokenizer=tokenizer,
        data_collator=collator,
        compute_metrics=compute_metrics,
    )

    # Train
    print("\nStarting training...")
    trainer.train()
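
    # evaluate() reuses compute_metrics, so the test results below include BLEU
    # and exact-match scores under the Trainer's default "eval_" metric prefix.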
    # Evaluate on test set
    if "test" in dataset:
        print("\nEvaluating on test set...")
        test_results = trainer.evaluate(dataset["test"])
        print(f"Test Results: {test_results}")
    # Push to hub
    print(f"\nPushing model to {FORWARD_MODEL_NAME}...")
    trainer.push_to_hub()

    print("\nForward model training complete!")
    print(f"Model available at: https://huggingface.co/{FORWARD_MODEL_NAME}")


if __name__ == "__main__":
    main()