File size: 7,202 Bytes
8019be0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import os
import sys
import argparse
import hydra
from omegaconf import OmegaConf
from datetime import datetime
# Directory containing this file and the config_*.yaml files (used by Hydra below).
CONFIG_DIR = os.path.dirname(os.path.abspath(__file__))
# Add the repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve.
sys.path.insert(0, os.path.dirname(CONFIG_DIR))

import wandb
from lightning_modules import AnyOrderInsertionFlowModule


# Process-wide torch settings, applied once at import time.
# Print tensors in full up to 10,000 elements before torch falls back to the
# summarized ("...") representation.
torch.set_printoptions(threshold=10_000)
# "high" lets float32 matmuls use faster reduced-precision kernels where the
# hardware provides them (see torch.set_float32_matmul_precision docs).
torch.set_float32_matmul_precision("high")

# Disable DDP optimizer due to incompatibility with flex_attention higher-order ops
torch._dynamo.config.optimize_ddp = False

def _is_local_rank_zero():
	"""Return True when LOCAL_RANK is 0 or unset in the environment."""
	return int(os.environ.get("LOCAL_RANK", 0)) == 0


def _build_checkpoint_callbacks(training_cfg):
	"""Build the "best" and "periodic" ModelCheckpoint callbacks.

	Args:
		training_cfg: the ``config.training`` node. Reads ``save_top_k``,
			``checkpoint_dir`` and the optional ``save_every_n_steps`` /
			``save_every_n_epochs`` keys.

	Returns:
		``[best_callback, periodic_callback]`` in that order.

	Raises:
		ValueError: if neither ``save_every_n_steps`` nor
			``save_every_n_epochs`` is set.
	"""
	# Keep the top-k checkpoints ranked by *training* loss (plus the last one).
	# The filename uses only the train metric, so it does not depend on a
	# validation loop having run.
	best_callback = ModelCheckpoint(
		monitor="train/total_loss",
		mode="min",
		save_top_k=training_cfg.save_top_k,
		save_last=True,
		filename="epoch-{epoch:02d}-train_loss-{train/total_loss:.4f}",
		dirpath=training_cfg.checkpoint_dir,
		auto_insert_metric_name=False,
	)

	# Separate callback for unconditional periodic saves. Step-based saves
	# (save_every_n_steps) take priority over epoch-based ones
	# (save_every_n_epochs) when both are present.
	save_every_n_steps = training_cfg.get('save_every_n_steps', None)
	save_every_n_epochs = training_cfg.get('save_every_n_epochs', None)
	if save_every_n_steps is not None:
		periodic_callback = ModelCheckpoint(
			save_top_k=-1,  # Save all periodic checkpoints
			filename="step-{step:08d}",
			dirpath=training_cfg.checkpoint_dir,
			every_n_train_steps=save_every_n_steps,
			auto_insert_metric_name=False,
		)
	elif save_every_n_epochs is not None:
		periodic_callback = ModelCheckpoint(
			save_top_k=-1,  # Save all periodic checkpoints
			filename="epoch-{epoch:02d}",
			dirpath=training_cfg.checkpoint_dir,
			every_n_epochs=save_every_n_epochs,
			auto_insert_metric_name=False,
		)
	else:
		raise ValueError(
			"Either save_every_n_steps or save_every_n_epochs must be specified in the config"
		)

	return [best_callback, periodic_callback]


def train(config):
	"""Run one training job described by ``config``.

	Seeds the RNGs, optionally starts a wandb run, builds the data module and
	``AnyOrderInsertionFlowModule``, configures a DDP ``pl.Trainer`` with
	checkpoint callbacks, and calls ``trainer.fit``.

	Side effects on ``config``: ``training.checkpoint_dir`` gets a timestamp
	sub-directory appended, and ``training.warmup_steps`` is filled in with 1%
	of ``training.max_steps`` when it is ``None``. The data setup helpers are
	named ``*_and_update_config`` and may mutate ``config`` as well.

	Args:
		config: OmegaConf config. Reads ``wandb.{project,name,path}``,
			``model.torch_dtype`` (optional), ``training.*`` and, when present,
			``hf_dataset.{name,smiles_column}``.

	Raises:
		ValueError: if neither ``max_steps`` nor ``num_epochs`` is set; if
			``warmup_steps`` is ``None`` while training by epochs (no
			``max_steps`` to derive it from); if the configured batch sizes
			give fewer than one accumulation step; or if no periodic
			checkpoint interval is configured.
	"""
	wandb_logger = None

	# Fixed seed for reproducibility. seed_everything seeds Python's random,
	# numpy and torch, so no separate torch.manual_seed call is needed.
	pl.seed_everything(42)

	# Start wandb only where LOCAL_RANK is 0 (or unset) to avoid duplicate runs.
	# NOTE(review): LOCAL_RANK is per-node, so with training.nodes > 1 every
	# node's first process would start its own run — confirm whether a global
	# rank check is wanted.
	if _is_local_rank_zero():
		wandb.init(
			project=config.wandb.project,
			name=config.wandb.name,
			config=OmegaConf.to_container(config, resolve=True),  # Convert to dict
			dir=config.wandb.path,
		)
		wandb_logger = WandbLogger(
			project=wandb.run.project,
			name=wandb.run.name,
			log_model=False,  # Disable checkpoint uploading to save disk space
		)

	# Append a timestamp so every launch writes to its own checkpoint directory.
	# NOTE(review): each process computes its own timestamp, so ranks started
	# at different seconds end up with different directories — verify.
	OmegaConf.set_struct(config, False)
	time_string = datetime.now().strftime("%Y%m%d-%H%M%S")
	config.training.checkpoint_dir = os.path.join(
		config.training.checkpoint_dir, time_string
	)
	OmegaConf.set_struct(config, True)

	os.makedirs(config.training.checkpoint_dir, exist_ok=True)

	# Setup data module - the presence of `hf_dataset` selects the HuggingFace path.
	if hasattr(config, 'hf_dataset'):
		# Imported lazily: the HF/SAFE path is only used by the molecule configs,
		# which keep mol_dataset.py (and its `safe` dependency) in a2d2_mol/.
		from mol_dataset import setup_hf_data_and_update_config
		print(f"Using HuggingFace dataset: {config.hf_dataset.name}")
		data_module = setup_hf_data_and_update_config(
			config,
			dataset_name=config.hf_dataset.name,
			smiles_column=config.hf_dataset.get('smiles_column', 'smiles')
		)
	else:
		# Imported lazily: the local (arrow) path is used by the peptide config,
		# which keeps dataloading_for_dynamic_batching.py in a2d2_pep/.
		from dataloading_for_dynamic_batching import setup_data_and_update_config
		print("Using local dataset")
		data_module = setup_data_and_update_config(config)

	module = AnyOrderInsertionFlowModule(config)

	# Map the model's torch_dtype string to a Lightning precision flag.
	# Unknown or missing values fall back to bf16 mixed precision.
	dtype_str = config.model.get('torch_dtype', 'bfloat16')
	precision_map = {
		'float32': '32-true',
		'float16': '16-mixed',
		'bfloat16': 'bf16-mixed'
	}
	precision = precision_map.get(dtype_str, 'bf16-mixed')

	# Gradient accumulation makes up the gap between the requested global batch
	# size and what one optimizer-free step processes across all devices.
	# Integer division: any remainder is silently dropped.
	samples_per_step = (
		config.training.per_gpu_batch_size
		* config.training.nodes
		* config.training.devices
	)
	accumulate_grad_batches = config.training.batch_size // samples_per_step
	if accumulate_grad_batches < 1:
		raise ValueError(
			"training.batch_size must be at least "
			"per_gpu_batch_size * nodes * devices "
			f"(got batch_size={config.training.batch_size}, "
			f"per-step batch={samples_per_step})"
		)

	trainer_kwargs = dict(
		num_nodes=config.training.nodes,
		accelerator="gpu",
		devices=config.training.devices,
		strategy="ddp",
		precision=precision,
		accumulate_grad_batches=accumulate_grad_batches,
		log_every_n_steps=10,
		enable_checkpointing=True,
		default_root_dir=config.training.checkpoint_dir,
		gradient_clip_val=1.0,
	)

	# Training length: max_steps wins when set, otherwise fall back to num_epochs.
	if config.training.max_steps is not None:
		trainer_kwargs["max_steps"] = config.training.max_steps
	elif config.training.num_epochs is not None:
		trainer_kwargs["max_epochs"] = config.training.num_epochs
	else:
		raise ValueError(
			"Either max_steps or num_epochs must be specified in the config"
		)

	# Default warmup is 1% of max_steps. When training by epochs there is no
	# max_steps to derive it from, so require an explicit value instead of
	# failing on `None * 0.01`.
	if config.training.warmup_steps is None:
		if config.training.max_steps is None:
			raise ValueError(
				"warmup_steps must be set explicitly when training by num_epochs: "
				"the default (1% of max_steps) cannot be computed without max_steps"
			)
		config.training.warmup_steps = int(config.training.max_steps * 0.01)

	trainer_kwargs["callbacks"] = _build_checkpoint_callbacks(config.training)

	# Without a wandb logger (LOCAL_RANK != 0) the Trainer keeps its default logger.
	if wandb_logger is not None:
		trainer_kwargs["logger"] = wandb_logger

	trainer = pl.Trainer(**trainer_kwargs)

	# Resume from a checkpoint when training.resume_path is present in the config.
	ckpt_path = None
	if "resume_path" in config.training:
		ckpt_path = config.training.resume_path

	trainer.fit(
		module,
		datamodule=data_module,
		ckpt_path=ckpt_path,
	)

	# Close the wandb run on the same processes that opened one.
	if _is_local_rank_zero():
		wandb.finish()


if __name__ == '__main__':
	# Script entry point: pick a Hydra config from --task / --config_name,
	# forward everything else to Hydra untouched, then run train().
	parser = argparse.ArgumentParser()
	parser.add_argument(
		'--config_name', type=str, default='config',
		help='Name of the config file to use',
	)
	parser.add_argument(
		'--task', type=str, default=None,
		help='Task name (uses config_{task}.yaml)',
	)

	# Parse known args (hydra will handle the rest)
	args, unknown = parser.parse_known_args()

	# --task takes priority over --config_name.
	if args.task:
		config_name = f'config_{args.task}'
	else:
		config_name = args.config_name

	# Hydra's own config selector may already be among the forwarded args,
	# either typed by the user or injected by an earlier run of this block
	# (sys.argv is rewritten below so the choice persists across DDP
	# subprocesses). Recognise both spellings in both "flag value" and
	# "flag=value" form, so no second, conflicting flag is added and no
	# misleading "Using config" line is printed.
	hydra_config_flags = ('--config-name', '-cn')
	hydra_config_prefixes = tuple(f'{flag}=' for flag in hydra_config_flags)
	config_already_selected = any(
		arg in hydra_config_flags or arg.startswith(hydra_config_prefixes)
		for arg in unknown
	)
	if not config_already_selected:
		print(f"Using config: {config_name}.yaml")
		unknown.insert(0, f'--config-name={config_name}')

	# Reconstruct sys.argv for hydra
	sys.argv = [sys.argv[0]] + unknown

	# config_name here is only Hydra's fallback; the --config-name flag placed
	# in sys.argv above is what actually selects the file.
	@hydra.main(
		version_base=None,
		config_path=CONFIG_DIR,
		config_name='config',
	)
	def main(config):
		"""Main entry point for training"""
		train(config)

	main()