# Methformer/scripts/pretrain_methformer.py
# (Hugging Face Hub upload artifact — uploaded by CChahrour via huggingface_hub,
# commit 9a0f27c verified)
import datetime
import os
import torch
import wandb
from datasets import load_from_disk
from sklearn.metrics import mean_absolute_error, mean_squared_error
from transformers import (
EarlyStoppingCallback,
PretrainedConfig,
Trainer,
TrainingArguments,
)
from methformer import (
Methformer,
MethformerCollator,
)
# Tag this run with a timestamp so successive runs never collide.
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H%M")
run_name = "mf_" + timestamp
print(f"Run name: {run_name}")

# All checkpoints, logs, and the final model land under this directory.
out_dir = "/home/ubuntu/project/MethFormer/output/methformer_pretrained/"
os.makedirs(out_dir, exist_ok=True)
# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
# Pre-binned methylation dataset produced by the preprocessing pipeline.
dataset = load_from_disk(
    "/home/ubuntu/project/MethFormer/data/methformer_pretrain_binned"
)

# Fixed shuffle seed keeps the training-example order reproducible.
train_dataset = dataset["train"].shuffle(seed=42)
eval_dataset = dataset["validation"]

# Collator assembles batches for the masked-prediction objective.
data_collator = MethformerCollator()
# Small transformer configuration; Methformer reads these fields off a
# generic PretrainedConfig rather than a model-specific subclass.
# NOTE(review): presumably input_dim=2 is two per-bin channels — confirm
# against MethformerCollator's output.
architecture = dict(
    input_dim=2,
    hidden_dim=128,
    num_hidden_layers=12,
    num_attention_heads=8,
    hidden_dropout_prob=0.1,
)
config = PretrainedConfig(**architecture)

# Build the model and move it onto the selected accelerator.
model = Methformer(config).to(device)
# Hugging Face Trainer hyperparameters for the pretraining run.
training_args = TrainingArguments(
    run_name=run_name,
    output_dir=os.path.join(out_dir, "checkpoints"),
    eval_on_start=True,  # baseline evaluation before any optimizer step
    per_device_train_batch_size=128,
    per_device_eval_batch_size=256,
    gradient_accumulation_steps=1,
    max_grad_norm=1.0,  # gradient clipping
    learning_rate=1e-5,
    warmup_ratio=0.05,  # 5% of total steps spent on LR warmup
    lr_scheduler_type="cosine",
    num_train_epochs=20,
    logging_dir=os.path.join(out_dir, "logs"),
    save_strategy="steps",
    save_total_limit=1,  # keep only the most recent checkpoint on disk
    eval_strategy="steps",
    logging_steps=1000,
    eval_steps=1000,
    # NOTE: with load_best_model_at_end=True, HF requires save_steps to be a
    # round multiple of eval_steps — 5000 = 5 * 1000 satisfies that.
    save_steps=5000,
    metric_for_best_model="masked_mse",  # produced by compute_metrics below
    greater_is_better=False,  # lower MSE is better
    report_to="wandb",
    disable_tqdm=False,
    dataloader_num_workers=8,
    # Presumably the collator consumes columns not in the model signature,
    # which the Trainer would otherwise drop — TODO confirm.
    remove_unused_columns=False,
    fp16=not torch.backends.mps.is_available(),  # fp16 AMP unsupported on MPS
    load_best_model_at_end=True,
    seed=42,
)
def compute_metrics(eval_preds):
    """Compute MSE/MAE over masked positions only.

    Args:
        eval_preds: ``(logits, labels)`` pair supplied by the Trainer
            (numpy arrays or nested sequences). Positions whose label is
            ``-1.0`` are treated as unmasked/padding and excluded.

    Returns:
        dict: ``{"masked_mse": float, "masked_mae": float}``.
    """
    logits, labels = eval_preds
    logits = torch.as_tensor(logits)
    labels = torch.as_tensor(labels)
    # -1.0 marks positions not selected for prediction (collator convention —
    # confirm against MethformerCollator).
    mask = labels != -1.0
    # BUG FIX: the original called `.cpu.numpy()` (missing call parentheses on
    # `.cpu`), which raised AttributeError at the first evaluation. Compute
    # the metrics directly in torch instead of round-tripping through sklearn.
    diff = logits[mask].float() - labels[mask].float()
    return {
        "masked_mse": diff.pow(2).mean().item(),
        "masked_mae": diff.abs().mean().item(),
    }
# Wire the model, data, and metric function into the HF Trainer. Early
# stopping halts training after 3 consecutive evals with no improvement in
# the tracked metric (masked_mse).
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)
print("Starting training...")
# Initialize the W&B run explicitly so group/job_type metadata is set;
# with report_to="wandb" the Trainer logs into this active run.
# NOTE(review): reinit="finish_previous" requires a recent wandb release —
# older versions expect a bool; confirm the installed wandb supports it.
wandb.init(
    group="methformer_pretrain",
    job_type="pretrain_full",
    name=run_name,
    dir=out_dir,
    reinit="finish_previous",
    config=config.to_dict(),
)
trainer.train()
print("Training complete. Saving model...")
# out_dir ends with "/" so this literal produces a double slash; harmless on
# POSIX, but os.path.join(out_dir, "model") would be cleaner.
save_path = f"{out_dir}/model"
os.makedirs(save_path, exist_ok=True)
# Saves the best checkpoint's weights (load_best_model_at_end=True), then
# writes the config alongside them.
trainer.save_model(save_path)
model.config.save_pretrained(save_path)
print(f"Model saved to {save_path}")
wandb.finish()