Bailan-Alex's picture
Upload folder using huggingface_hub
4f2b2f4 verified
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import os
import hydra
from omegaconf import DictConfig, OmegaConf
from datetime import datetime
import wandb
import data
from lightning_modules import (
MaskedDiffusionModule,
AutoregressiveModule,
AnyOrderInsertionFlowModule,
)
from pytorch_lightning.utilities import rank_zero_only
# Process-wide torch settings applied at import time.
# "high" lets float32 matmuls use faster reduced-internal-precision kernels
# (e.g. TF32) where the hardware supports them.
torch.set_float32_matmul_precision("high")
# Print tensors with up to 10000 elements in full instead of summarizing.
torch.set_printoptions(threshold=10000)
def _init_wandb_logger(config: DictConfig) -> "WandbLogger | None":
    """Start a W&B run on global rank 0 and wrap it in a ``WandbLogger``.

    Returns ``None`` when the config has no (or a null) ``wandb`` section,
    or when called on a non-zero rank.
    """
    wandb_config = config.get("wandb")
    if wandb_config is None or rank_zero_only.rank != 0:
        return None
    init_kwargs = dict(
        project="interpretable-flow",
        entity=wandb_config.entity,
        config=OmegaConf.to_container(config, resolve=True),
        name=wandb_config.name,
    )
    # Resume the wandb run when resuming from a checkpoint.
    # NOTE(review): resume="allow" only reattaches to an existing run when a
    # run id is supplied (e.g. via WANDB_RUN_ID); otherwise a new run starts.
    if config.training.get("resume_path"):
        init_kwargs["resume"] = "allow"
    wandb.init(**init_kwargs)
    return WandbLogger(
        project=wandb.run.project,
        name=wandb.run.name,
        log_model=True,
    )


def _module_class(trainer_name: str) -> type:
    """Map a ``config.trainer`` name to its LightningModule class.

    Raises:
        NotImplementedError: if ``trainer_name`` is not a supported trainer.
    """
    registry = {
        "mdm": MaskedDiffusionModule,
        "autoregressive": AutoregressiveModule,
        "any-order-flow": AnyOrderInsertionFlowModule,
    }
    try:
        return registry[trainer_name]
    except KeyError:
        raise NotImplementedError(
            f"Trainer {trainer_name} is not supported"
        ) from None


def _accumulate_grad_batches(config: DictConfig) -> int:
    """Number of micro-batches accumulated per optimizer step.

    One data-parallel step processes ``per_gpu_batch_size * nodes * devices``
    samples; accumulation scales that up to ``training.batch_size``.

    Raises:
        ValueError: if ``batch_size`` is smaller than one data-parallel step
            (which would otherwise give 0 accumulation steps).
    """
    import warnings

    per_step = (
        config.training.per_gpu_batch_size
        * config.training.nodes
        * config.training.devices
    )
    accumulate, remainder = divmod(config.training.batch_size, per_step)
    if accumulate < 1:
        raise ValueError(
            f"training.batch_size ({config.training.batch_size}) is smaller than "
            f"one data-parallel step (per_gpu_batch_size * nodes * devices = "
            f"{per_step})"
        )
    if remainder:
        # Integer division silently shrinks the effective batch; surface it.
        warnings.warn(
            f"training.batch_size ({config.training.batch_size}) is not a "
            f"multiple of per_gpu_batch_size * nodes * devices ({per_step}); "
            f"effective batch size will be {accumulate * per_step}"
        )
    return accumulate


def train(config: DictConfig):
    """Train the model selected by ``config.trainer`` with PyTorch Lightning.

    The function performs these steps in order:

    1. Seed the RNGs.
    2. Start a W&B run on rank 0 if a ``wandb`` section is present.
    3. Create a timestamped checkpoint directory.
    4. Build the data loaders.
    5. Derive the Trainer arguments and the step-based schedule values.
    6. Construct the module and fit it.
    7. Close the W&B run.

    ``config`` is mutated in place:

    - ``training.checkpoint_dir`` gains a timestamp sub-directory.
    - It is passed to ``data.setup_data_and_update_config``, which may update
      it further.
    - ``training.max_steps`` and ``training.warmup_steps`` are filled in when
      they are ``None``.

    Args:
        config: Hydra/OmegaConf config. Reads ``trainer``, ``training.*``,
            the optional ``seed`` (default 42), the optional ``wandb`` section,
            and the optional ``training.resume_path`` and
            ``training.save_every_n_train_steps`` (default 10000).

    Raises:
        NotImplementedError: if ``config.trainer`` is not supported.
        ValueError: if neither ``max_steps`` nor ``num_epochs`` is set, or if
            ``batch_size`` is smaller than one data-parallel step.
    """
    seed = config.get("seed")
    # seed_everything already calls torch.manual_seed, so no separate call.
    pl.seed_everything(42 if seed is None else seed)

    wandb_logger = _init_wandb_logger(config)

    time_string = datetime.now().strftime("%Y%m%d-%H%M%S")
    config.training.checkpoint_dir = os.path.join(
        config.training.checkpoint_dir, time_string
    )
    os.makedirs(config.training.checkpoint_dir, exist_ok=True)

    dataset_bundle = data.setup_data_and_update_config(config)
    module_cls = _module_class(config.trainer)

    accumulate_grad_batches = _accumulate_grad_batches(config)
    trainer_kwargs = dict(
        num_nodes=config.training.nodes,
        accelerator="gpu",
        devices=config.training.devices,
        strategy="ddp",
        accumulate_grad_batches=accumulate_grad_batches,
        log_every_n_steps=10,
        enable_checkpointing=True,
        default_root_dir=config.training.checkpoint_dir,
        gradient_clip_val=1.0,
    )

    # Only one of max_steps or max_epochs is handed to the Trainer.
    if config.training.max_steps is not None:
        trainer_kwargs["max_steps"] = config.training.max_steps
    elif config.training.num_epochs is not None:
        trainer_kwargs["max_epochs"] = config.training.num_epochs
        # max_steps counts optimizer steps (it is what the Trainer receives in
        # the branch above), so divide batches by the accumulation factor.
        # Lightning still steps on a trailing partial window, hence ceil.
        # NOTE(review): len(train_loader) is taken before Lightning injects a
        # DistributedSampler; if the loader is not already sharded per rank,
        # this still over-counts by the world size — confirm in `data`.
        batches_per_epoch = len(dataset_bundle.train_loader)
        steps_per_epoch = -(-batches_per_epoch // accumulate_grad_batches)
        config.training.max_steps = config.training.num_epochs * steps_per_epoch
    else:
        raise ValueError(
            "Either max_steps or num_epochs must be specified in the config"
        )
    if config.training.warmup_steps is None:
        config.training.warmup_steps = int(config.training.max_steps * 0.01)

    # Build the module only after the config is final, so its __init__ (and
    # any saved hyperparameters) see the resolved max_steps / warmup_steps.
    module = module_cls(config)

    # Keep the top-k checkpoints by lowest validation loss, plus last.ckpt.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        save_top_k=config.training.save_top_k,
        save_last=True,
        filename="epoch-{epoch:02d}-val_loss-{val_loss:.4f}",
        # The template already spells out the metric names; without this,
        # Lightning produces "epoch-epoch=03-val_loss-val_loss=...".
        auto_insert_metric_name=False,
        dirpath=config.training.checkpoint_dir,
        every_n_train_steps=config.training.get("save_every_n_train_steps", 10000),
    )
    trainer_kwargs["callbacks"] = [checkpoint_callback]
    if wandb_logger is not None:
        trainer_kwargs["logger"] = wandb_logger
    trainer = pl.Trainer(**trainer_kwargs)

    # A null/empty resume_path means "start fresh".
    ckpt_path = config.training.get("resume_path") or None
    trainer.fit(
        module,
        train_dataloaders=dataset_bundle.train_loader,
        val_dataloaders=dataset_bundle.val_loader,
        ckpt_path=ckpt_path,
    )

    # Only the process that started a run needs to close it.
    if wandb_logger is not None:
        wandb.finish()
@hydra.main(config_path=".", config_name="config")
def main(cfg: DictConfig):
    """Hydra entry point: load ``config`` from this directory and train on it."""
    return train(cfg)


if __name__ == "__main__":
    main()