"""
Forward synthesis model training script.
Trains T5 model to predict products from reactants.
"""

import os

import numpy as np
import sacrebleu
import torch
from transformers import (
    AutoTokenizer,
    T5ForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)

from config import (
    TOKENIZER_NAME,
    FORWARD_MODEL_NAME,
    BATCH_SIZE,
    GRADIENT_ACCUMULATION_STEPS,
    LEARNING_RATE,
    NUM_EPOCHS,
    EVAL_STEPS,
    SAVE_STEPS,
    LOGGING_STEPS,
    BASE_MODEL,
)
from data_utils import load_tokenized, get_tokenizer


def _max_label_length(dataset, splits=("validation", "test")):
    """Return the longest ``labels`` sequence across *splits*, or ``None``.

    Used to size generation during evaluation. Without an explicit
    ``generation_max_length``, ``predict_with_generate`` falls back to the
    model-agnostic default ``max_length`` of 20 tokens. That truncates
    predicted products and deflates BLEU / exact-match.

    Splits that are missing, or that expose no ``labels`` column, are
    skipped. ``None`` is returned when nothing was found, so the trainer
    keeps its default behaviour.
    """
    longest = 0
    for split in splits:
        if split not in dataset:
            continue
        ds = dataset[split]
        # NOTE(review): assumes a Hugging Face `datasets.Dataset` with a
        # "labels" column; other dataset types are skipped.
        column_names = getattr(ds, "column_names", None) or []
        if "labels" not in column_names:
            continue
        longest = max(longest, max((len(ids) for ids in ds["labels"]), default=0))
    return longest or None


def _build_training_args(generation_max_length=None):
    """Build the Seq2SeqTrainingArguments for forward-model training.

    Args:
        generation_max_length: Max decoder length used when generating
            predictions during evaluation; ``None`` keeps the library default.
    """
    has_cuda = torch.cuda.is_available()
    # bf16 needs compute capability >= 8 (Ampere or newer).
    use_bf16 = has_cuda and torch.cuda.get_device_capability(0)[0] >= 8
    # fp16 is only valid on a GPU; TrainingArguments rejects it on CPU.
    # NOTE(review): T5 checkpoints are known to be prone to NaN losses under
    # fp16; consider plain fp32 on pre-Ampere GPUs if training diverges.
    use_fp16 = has_cuda and not use_bf16

    return Seq2SeqTrainingArguments(
        output_dir="./forward_model",
        eval_strategy="steps",
        save_strategy="steps",
        logging_strategy="steps",
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        weight_decay=0.01,
        save_total_limit=2,
        num_train_epochs=NUM_EPOCHS,
        predict_with_generate=True,
        generation_max_length=generation_max_length,
        logging_steps=LOGGING_STEPS,
        eval_steps=EVAL_STEPS,
        save_steps=SAVE_STEPS,
        report_to=[],
        bf16=use_bf16,
        fp16=use_fp16,
        dataloader_num_workers=4,
        dataloader_pin_memory=True,
        push_to_hub=True,
        hub_model_id=FORWARD_MODEL_NAME,
        hub_strategy="every_save",
        hub_token=os.environ.get("HF_TOKEN"),
    )


def _build_compute_metrics(tokenizer):
    """Return a ``compute_metrics`` callable reporting BLEU and exact match."""

    def compute_metrics(eval_pred):
        preds, labels = eval_pred
        # Some models/trainer versions return (generated_ids, ...) tuples.
        if isinstance(preds, tuple):
            preds = preds[0]
        # -100 marks ignored/padded positions; map to pad so decoding works.
        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

        decoded_preds = [
            p.strip() for p in tokenizer.batch_decode(preds, skip_special_tokens=True)
        ]
        decoded_labels = [
            t.strip() for t in tokenizer.batch_decode(labels, skip_special_tokens=True)
        ]

        if not decoded_labels:
            return {"bleu": 0.0, "exact_match": 0.0}

        bleu = sacrebleu.corpus_bleu(decoded_preds, [decoded_labels])
        exact = float(np.mean([p == t for p, t in zip(decoded_preds, decoded_labels)]))
        return {"bleu": float(bleu.score), "exact_match": exact}

    return compute_metrics


def main():
    """Main training pipeline for forward synthesis."""
    print("=" * 60)
    print("Forward Synthesis Model Training")
    print("=" * 60)

    # Load datasets and tokenizer
    print("\nLoading datasets...")
    dataset = load_tokenized("forward")
    tokenizer = get_tokenizer()

    print(f"Train samples: {len(dataset['train'])}")
    print(f"Validation samples: {len(dataset['validation'])}")
    if "test" in dataset:
        print(f"Test samples: {len(dataset['test'])}")

    # Load model
    print(f"\nLoading base model: {BASE_MODEL}")
    model = T5ForConditionalGeneration.from_pretrained(BASE_MODEL)
    # Tokenizer may carry extra tokens beyond the base vocabulary.
    model.resize_token_embeddings(len(tokenizer))

    # Setup training arguments
    print("\nSetting up training arguments...")
    max_label_len = _max_label_length(dataset)
    # +1: for encoder-decoder models, generation `max_length` also counts the
    # decoder start token.
    generation_max_length = max_label_len + 1 if max_label_len else None
    args = _build_training_args(generation_max_length)

    # Data collator
    collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)

    # Trainer
    print("\nInitializing trainer...")
    # NOTE(review): `tokenizer=` is deprecated in favour of `processing_class=`
    # in newer transformers releases; kept for compatibility with older ones.
    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        tokenizer=tokenizer,
        data_collator=collator,
        compute_metrics=_build_compute_metrics(tokenizer),
    )

    # Train
    print("\nStarting training...")
    trainer.train()

    # Evaluate on test set (prefixed "test_" so it isn't confused with validation)
    if "test" in dataset:
        print("\nEvaluating on test set...")
        test_results = trainer.evaluate(
            eval_dataset=dataset["test"], metric_key_prefix="test"
        )
        print(f"Test Results: {test_results}")

    # Push to hub
    print(f"\nPushing model to {FORWARD_MODEL_NAME}...")
    trainer.push_to_hub()

    print("\nForward model training complete!")
    print(f"Model available at: https://huggingface.co/{FORWARD_MODEL_NAME}")


if __name__ == "__main__":
    main()