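# Pretraining script for Methformer: trains on the binned pretraining dataset with the
# Hugging Face Trainer, evaluates with masked MSE/MAE, and logs to Weights & Biases.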
import datetime
import os
import torch
import wandb
from datasets import load_from_disk
from sklearn.metrics import mean_absolute_error, mean_squared_error
from transformers import (
EarlyStoppingCallback,
PretrainedConfig,
Trainer,
TrainingArguments,
)
from methformer import (
Methformer,
MethformerCollator,
)
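# Timestamped run name; checkpoints, logs, and the final model all land under out_dir.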
run_name = f"mf_{datetime.datetime.now().strftime('%Y-%m-%d_%H%M')}"
print(f"Run name: {run_name}")
out_dir = "/home/ubuntu/project/MethFormer/output/methformer_pretrained/"
os.makedirs(out_dir, exist_ok=True)
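# Pick the best available device: CUDA, then Apple MPS, then CPU.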
device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
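# Load the pre-binned dataset saved with datasets.save_to_disk; only the train split is shuffled.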
dataset = load_from_disk("/home/ubuntu/project/MethFormer/data/methformer_pretrain_binned")
train_dataset = dataset["train"].shuffle(seed=42)
eval_dataset = dataset["validation"]
data_collator = MethformerCollator()
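# Model configuration: 2 input features per position, 128-dim hidden states,
# 12 transformer layers with 8 attention heads.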
config = PretrainedConfig(
input_dim=2,
hidden_dim=128,
num_hidden_layers=12,
num_attention_heads=8,
hidden_dropout_prob=0.1,
)
model = Methformer(config)
model.to(device)
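# Trainer hyperparameters: cosine schedule with 5% warmup, step-based eval/checkpointing,
# and best-model selection by masked MSE (lower is better). fp16 is disabled on MPS.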
training_args = TrainingArguments(
run_name=run_name,
output_dir=os.path.join(out_dir, "checkpoints"),
eval_on_start=True,
per_device_train_batch_size=128,
per_device_eval_batch_size=256,
gradient_accumulation_steps=1,
max_grad_norm=1.0,
learning_rate=1e-5,
warmup_ratio=0.05,
lr_scheduler_type="cosine",
num_train_epochs=20,
logging_dir=os.path.join(out_dir, "logs"),
save_strategy="steps",
save_total_limit=1,
eval_strategy="steps",
logging_steps=1000,
eval_steps=1000,
save_steps=5000,
metric_for_best_model="masked_mse",
greater_is_better=False,
report_to="wandb",
disable_tqdm=False,
dataloader_num_workers=8,
remove_unused_columns=False,
fp16=not torch.backends.mps.is_available(),
load_best_model_at_end=True,
seed=42,
)
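# Compute MSE/MAE only over positions with valid labels (-1.0 marks positions to ignore).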
def compute_metrics(eval_preds):
logits, labels = eval_preds
logits = torch.tensor(logits)
labels = torch.tensor(labels)
mask = labels != -1.0
    masked_logits = logits[mask].cpu().numpy()
    masked_labels = labels[mask].cpu().numpy()
mse = mean_squared_error(masked_labels, masked_logits)
mae = mean_absolute_error(masked_labels, masked_logits)
return {
"masked_mse": mse,
"masked_mae": mae,
}
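# Trainer with early stopping after 3 evaluations without improvement on masked MSE.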
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
compute_metrics=compute_metrics,
data_collator=data_collator,
callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)
print("Starting training...")
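# Start the wandb run explicitly so the group, job type, and model config are recorded;
# the Trainer's wandb integration reuses this active run.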
wandb.init(
group="methformer_pretrain",
job_type="pretrain_full",
name=run_name,
dir=out_dir,
reinit="finish_previous",
config=config.to_dict(),
)
trainer.train()
print("Training complete. Saving model...")
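# Save the final model weights (the best checkpoint, since load_best_model_at_end=True) and its config.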
save_path = os.path.join(out_dir, "model")
os.makedirs(save_path, exist_ok=True)
trainer.save_model(save_path)
model.config.save_pretrained(save_path)
print(f"Model saved to {save_path}")
wandb.finish()