# File size: 7,202 Bytes
import argparse
import os
import sys
import warnings
from datetime import datetime

import hydra
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
# Hydra reads the config_*.yaml files from the directory that holds this script.
CONFIG_DIR = os.path.split(os.path.abspath(__file__))[0]
# The repo root (A2D2/) is the parent of CONFIG_DIR; prepend it to sys.path so
# top-level packages such as lightning_modules can be imported below.
sys.path[:0] = [os.path.split(CONFIG_DIR)[0]]
import wandb
from lightning_modules import AnyOrderInsertionFlowModule
torch.set_printoptions(threshold=10_000)
torch.set_float32_matmul_precision("high")
# Disable DDP optimizer due to incompatibility with flex_attention higher-order ops
torch._dynamo.config.optimize_ddp = False
# Lightning ``precision`` strings for each supported ``config.model.torch_dtype``.
_PRECISION_BY_DTYPE = {
    'float32': '32-true',
    'float16': '16-mixed',
    'bfloat16': 'bf16-mixed',
}
_DEFAULT_PRECISION = 'bf16-mixed'


def _is_global_rank_zero():
    """Return True for the single process that should own the wandb run.

    Uses the global ``RANK`` when the launcher exports it. Otherwise it requires
    both ``NODE_RANK`` and ``LOCAL_RANK`` to be 0, with unset counting as 0.
    Checking ``LOCAL_RANK`` alone is not enough, because it is 0 once per node
    in a multi-node run.
    NOTE(review): confirm your cluster launcher exports one of these variables.
    """
    if "RANK" in os.environ:
        return int(os.environ["RANK"]) == 0
    return (
        int(os.environ.get("NODE_RANK", 0)) == 0
        and int(os.environ.get("LOCAL_RANK", 0)) == 0
    )


def _init_wandb_logger(config):
    """Start a wandb run from ``config.wandb`` and wrap it in a ``WandbLogger``."""
    wandb.init(
        project=config.wandb.project,
        name=config.wandb.name,
        config=OmegaConf.to_container(config, resolve=True),  # Convert to dict
        dir=config.wandb.path,
    )
    return WandbLogger(
        project=wandb.run.project,
        name=wandb.run.name,
        log_model=False,  # Disable checkpoint uploading to save disk space
    )


def _timestamp_checkpoint_dir(config):
    """Append a ``%Y%m%d-%H%M%S`` subdirectory to ``training.checkpoint_dir`` and create it.

    NOTE(review): every process that runs ``train`` computes its own timestamp.
    Processes that start in a different second therefore create different
    directories; verify the checkpoints land where expected.
    """
    time_string = datetime.now().strftime("%Y%m%d-%H%M%S")
    OmegaConf.set_struct(config, False)
    config.training.checkpoint_dir = os.path.join(
        config.training.checkpoint_dir, time_string
    )
    OmegaConf.set_struct(config, True)
    os.makedirs(config.training.checkpoint_dir, exist_ok=True)


def _setup_data_module(config):
    """Build the data module: HuggingFace when ``hf_dataset`` is configured, else local."""
    if hasattr(config, 'hf_dataset'):
        # Imported lazily: the HF/SAFE path is only used by the molecule configs,
        # which keep mol_dataset.py (and its `safe` dependency) in a2d2_mol/.
        from mol_dataset import setup_hf_data_and_update_config
        print(f"Using HuggingFace dataset: {config.hf_dataset.name}")
        return setup_hf_data_and_update_config(
            config,
            dataset_name=config.hf_dataset.name,
            smiles_column=config.hf_dataset.get('smiles_column', 'smiles'),
        )
    # Imported lazily: the local (arrow) path is used by the peptide config,
    # which keeps dataloading_for_dynamic_batching.py in a2d2_pep/.
    from dataloading_for_dynamic_batching import setup_data_and_update_config
    print("Using local dataset")
    return setup_data_and_update_config(config)


def _lightning_precision(config):
    """Map ``model.torch_dtype`` (default 'bfloat16') to a Lightning precision string.

    Unknown values fall back to 'bf16-mixed' as before, but now emit a warning
    so a typo such as 'fp32' no longer goes unnoticed.
    """
    dtype_str = config.model.get('torch_dtype', 'bfloat16')
    precision = _PRECISION_BY_DTYPE.get(dtype_str)
    if precision is None:
        warnings.warn(
            f"Unknown model.torch_dtype {dtype_str!r}; "
            f"falling back to precision {_DEFAULT_PRECISION!r}",
            stacklevel=2,
        )
        precision = _DEFAULT_PRECISION
    return precision


def _grad_accumulation_steps(config):
    """Return the number of micro-batches to accumulate to reach ``training.batch_size``.

    Each of ``nodes * devices`` processes runs ``per_gpu_batch_size`` samples
    per micro-batch. Assumes ``training.devices`` is an int.

    Raises:
        ValueError: if the per-micro-step sample count is not positive, or if it
            exceeds ``batch_size``. The second case previously produced 0
            accumulation steps.
    """
    training = config.training
    samples_per_micro_step = (
        training.per_gpu_batch_size * training.nodes * training.devices
    )
    if samples_per_micro_step <= 0:
        raise ValueError(
            "per_gpu_batch_size * nodes * devices must be positive, got "
            f"{samples_per_micro_step}"
        )
    steps, remainder = divmod(training.batch_size, samples_per_micro_step)
    if steps < 1:
        raise ValueError(
            f"training.batch_size={training.batch_size} is smaller than "
            f"per_gpu_batch_size * nodes * devices = {samples_per_micro_step}"
        )
    if remainder:
        # Integer division silently shrinks the effective batch; make that visible.
        warnings.warn(
            f"training.batch_size={training.batch_size} is not a multiple of "
            f"per_gpu_batch_size * nodes * devices = {samples_per_micro_step}; "
            f"effective batch size will be {steps * samples_per_micro_step}",
            stacklevel=2,
        )
    return steps


def _resolve_schedule(config):
    """Return the Trainer's stopping kwargs and fill in a default ``warmup_steps``.

    ``max_steps`` takes precedence over ``num_epochs``, and only one of them is
    passed to the Trainer. A null ``warmup_steps`` defaults to 1% of
    ``max_steps``, which requires ``max_steps`` to be set.

    Raises:
        ValueError: if neither ``max_steps`` nor ``num_epochs`` is set, or if
            ``warmup_steps`` is null while training by epochs. The second case
            previously crashed with ``TypeError`` on ``None * 0.01``.
    """
    if config.training.max_steps is not None:
        schedule = {"max_steps": config.training.max_steps}
    elif config.training.num_epochs is not None:
        schedule = {"max_epochs": config.training.num_epochs}
    else:
        raise ValueError(
            "Either max_steps or num_epochs must be specified in the config"
        )
    if config.training.warmup_steps is None:
        if config.training.max_steps is None:
            raise ValueError(
                "warmup_steps must be set explicitly when training with "
                "num_epochs (max_steps is null, so the 1% default cannot be computed)"
            )
        config.training.warmup_steps = int(config.training.max_steps * 0.01)
    return schedule


def _build_checkpoint_callbacks(config):
    """Return ``[best-train-loss checkpoint, periodic checkpoint]`` callbacks.

    Raises:
        ValueError: if neither ``save_every_n_steps`` nor ``save_every_n_epochs``
            is configured.
    """
    checkpoint_dir = config.training.checkpoint_dir
    # Keep the save_top_k checkpoints with the lowest *training* loss
    # (train/total_loss), plus a rolling "last" checkpoint.
    best_callback = ModelCheckpoint(
        monitor="train/total_loss",
        mode="min",
        save_top_k=config.training.save_top_k,
        save_last=True,
        filename="epoch-{epoch:02d}-train_loss-{train/total_loss:.4f}",
        dirpath=checkpoint_dir,
        # Metric names contain '/', so format them explicitly in `filename`.
        auto_insert_metric_name=False,
    )
    # Periodic saves keep every checkpoint (save_top_k=-1). They are step-based
    # for streaming datasets (save_every_n_steps) and epoch-based otherwise
    # (save_every_n_epochs); steps win if both are set.
    save_every_n_steps = config.training.get('save_every_n_steps', None)
    save_every_n_epochs = config.training.get('save_every_n_epochs', None)
    if save_every_n_steps is not None:
        periodic_callback = ModelCheckpoint(
            save_top_k=-1,
            filename="step-{step:08d}",
            dirpath=checkpoint_dir,
            every_n_train_steps=save_every_n_steps,
            auto_insert_metric_name=False,
        )
    elif save_every_n_epochs is not None:
        periodic_callback = ModelCheckpoint(
            save_top_k=-1,
            filename="epoch-{epoch:02d}",
            dirpath=checkpoint_dir,
            every_n_epochs=save_every_n_epochs,
            auto_insert_metric_name=False,
        )
    else:
        raise ValueError(
            "Either save_every_n_steps or save_every_n_epochs must be specified in the config"
        )
    return [best_callback, periodic_callback]


def train(config):
    """Train an ``AnyOrderInsertionFlowModule`` with PyTorch Lightning (DDP on GPUs).

    Args:
        config: Hydra/OmegaConf config. Reads the ``wandb``, ``training``,
            ``model`` and optional ``hf_dataset`` sections. Mutates
            ``training.checkpoint_dir`` (adds a timestamp suffix) and, when it is
            null, ``training.warmup_steps``. The data-setup helper may update
            the config as well.

    Raises:
        ValueError: for inconsistent training settings (stopping criterion,
            warmup, batch sizes, checkpoint cadence).
    """
    pl.seed_everything(42)
    torch.manual_seed(42)

    is_rank_zero = _is_global_rank_zero()
    # Only one process starts a wandb run; the others leave `logger` unset.
    wandb_logger = _init_wandb_logger(config) if is_rank_zero else None

    _timestamp_checkpoint_dir(config)
    data_module = _setup_data_module(config)

    # Resolve the stopping criterion and the default warmup *before* building
    # the module. That way the module sees the final warmup_steps whether it
    # reads the config in __init__ or later.
    schedule_kwargs = _resolve_schedule(config)
    module = AnyOrderInsertionFlowModule(config)

    trainer_kwargs = dict(
        num_nodes=config.training.nodes,
        accelerator="gpu",
        devices=config.training.devices,
        strategy="ddp",
        precision=_lightning_precision(config),
        accumulate_grad_batches=_grad_accumulation_steps(config),
        log_every_n_steps=10,
        enable_checkpointing=True,
        default_root_dir=config.training.checkpoint_dir,
        gradient_clip_val=1.0,
        callbacks=_build_checkpoint_callbacks(config),
        **schedule_kwargs,
    )
    if wandb_logger is not None:
        trainer_kwargs["logger"] = wandb_logger
    trainer = pl.Trainer(**trainer_kwargs)

    # A missing, null or empty resume_path means "start from scratch".
    ckpt_path = config.training.get("resume_path") or None
    trainer.fit(module, datamodule=data_module, ckpt_path=ckpt_path)

    # Close the run on the same process that opened it.
    if is_rank_zero:
        wandb.finish()
def has_config_name_override(argv):
    """Return True if ``argv`` already selects a Hydra config via ``--config-name``/``-cn``.

    Recognizes both the separate-value form (``--config-name foo``) and the
    ``=`` form (``--config-name=foo``, ``-cn=foo``).
    """
    return any(
        arg in ("--config-name", "-cn")
        or arg.startswith(("--config-name=", "-cn="))
        for arg in argv
    )


if __name__ == '__main__':
    # Our own flags pick the config; everything else is left for Hydra.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_name', type=str, default='config',
                        help='Name of the config file to use')
    parser.add_argument('--task', type=str, default=None,
                        help='Task name (uses config_{task}.yaml)')
    args, unknown = parser.parse_known_args()

    # --task takes precedence over --config_name.
    config_name = f'config_{args.task}' if args.task else args.config_name

    if has_config_name_override(unknown):
        # An explicit Hydra --config-name/-cn is already on the command line.
        # This also covers DDP subprocesses, which re-run the script with the
        # rewritten argv below. Don't inject a second, conflicting flag.
        print("Using config from Hydra --config-name override")
    else:
        print(f"Using config: {config_name}.yaml")
        # Passing the choice as a Hydra override makes it persist across DDP subprocesses.
        unknown.insert(0, f'--config-name={config_name}')

    # Hydra parses sys.argv itself, so hand it only the arguments meant for it.
    sys.argv = [sys.argv[0]] + unknown

    # 'config' is only the fallback; the --config-name override above selects the real file.
    @hydra.main(version_base=None,
                config_path=CONFIG_DIR,
                config_name='config')
    def main(config):
        """Main entry point for training"""
        train(config)

    main()