import datetime
import os

import torch
import wandb
from datasets import load_from_disk
from sklearn.metrics import mean_absolute_error, mean_squared_error
from transformers import (
    EarlyStoppingCallback,
    PretrainedConfig,
    Trainer,
    TrainingArguments,
)

from methformer import (
    Methformer,
    MethformerCollator,
)

run_name = f"mf_{datetime.datetime.now().strftime('%Y-%m-%d_%H%M')}"
print(f"Run name: {run_name}")

out_dir = "/home/ubuntu/project/MethFormer/output/methformer_pretrained/"
os.makedirs(out_dir, exist_ok=True)

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

# Pre-binned pretraining dataset with train/validation splits on disk.
dataset = load_from_disk("/home/ubuntu/project/MethFormer/data/methformer_pretrain_binned")
train_dataset = dataset["train"].shuffle(seed=42)
eval_dataset = dataset["validation"]

data_collator = MethformerCollator()

config = PretrainedConfig(
    input_dim=2,
    hidden_dim=128,
    num_hidden_layers=12,
    num_attention_heads=8,
    hidden_dropout_prob=0.1,
)
model = Methformer(config)
model.to(device)

training_args = TrainingArguments(
    run_name=run_name,
    output_dir=os.path.join(out_dir, "checkpoints"),
    eval_on_start=True,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=256,
    gradient_accumulation_steps=1,
    max_grad_norm=1.0,
    learning_rate=1e-5,
    warmup_ratio=0.05,
    lr_scheduler_type="cosine",
    num_train_epochs=20,
    logging_dir=os.path.join(out_dir, "logs"),
    save_strategy="steps",
    save_total_limit=1,
    eval_strategy="steps",
    logging_steps=1000,
    eval_steps=1000,
    save_steps=5000,
    metric_for_best_model="masked_mse",
    greater_is_better=False,
    report_to="wandb",
    disable_tqdm=False,
    dataloader_num_workers=8,
    remove_unused_columns=False,
    # fp16 mixed precision requires CUDA; it is not supported on MPS or CPU.
    fp16=torch.cuda.is_available(),
    load_best_model_at_end=True,
    seed=42,
)


def compute_metrics(eval_preds):
    """Compute MSE/MAE over masked positions only; labels of -1.0 are ignored."""
    logits, labels = eval_preds
    logits = torch.tensor(logits)
    labels = torch.tensor(labels)
    mask = labels != -1.0
    masked_logits = logits[mask].cpu().numpy()
    masked_labels = labels[mask].cpu().numpy()
    mse = mean_squared_error(masked_labels, masked_logits)
    mae = mean_absolute_error(masked_labels, masked_logits)
    return {
        "masked_mse": mse,
        "masked_mae": mae,
    }


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
    # Stop if eval masked_mse fails to improve for 3 consecutive evaluations.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

print("Starting training...")
wandb.init(
    group="methformer_pretrain",
    job_type="pretrain_full",
    name=run_name,
    dir=out_dir,
    reinit="finish_previous",
    config=config.to_dict(),
)
trainer.train()

print("Training complete. Saving model...")
save_path = os.path.join(out_dir, "model")
os.makedirs(save_path, exist_ok=True)
trainer.save_model(save_path)
model.config.save_pretrained(save_path)
print(f"Model saved to {save_path}")
wandb.finish()
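
# Optional sanity check: rebuild the model from the saved config and reload
# the trained weights. A minimal sketch, assuming Trainer wrote its default
# safetensors state dict ("model.safetensors") to save_path; adjust the
# filename if your transformers version saves "pytorch_model.bin" instead.
from safetensors.torch import load_file

reloaded = Methformer(PretrainedConfig.from_pretrained(save_path))
reloaded.load_state_dict(load_file(os.path.join(save_path, "model.safetensors")))
print("Reload sanity check passed.")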