# A2D2 / a2d2_mol / finetune_mol.py
# Sophia
# initial commit
# 8019be0
# Raw
# History Blame Contribute Delete
# 35.1 kB
import argparse
from datetime import datetime
import numpy as np
import torch
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import wandb
import os
import sys
from tqdm import tqdm
import pandas as pd
# add repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# imports
from inference_quality_mol import sample_mol_buffer, sample_mol_eval
from mol_utils.utils import str2bool, set_seed
from mol_scoring.scoring_functions import MolScoringFunctions
from lightning_modules.any_length_remask import AnyOrderInsertionFlowModuleFT
from lightning_modules import AnyOrderInsertionFlowModule
from safe.tokenizer import SAFETokenizer
from tdc import Evaluator
# Repository root: the parent of the directory that holds this file
# (A2D2/a2d2_mol/finetune_mol.py -> A2D2/).
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
REPO_ROOT = os.path.dirname(_THIS_DIR)
def get_tokenizer():
    """Return the SAFE tokenizer extended with the bracket tokens.

    Loads ``datamol-io/safe-gpt`` through ``SAFETokenizer.from_pretrained``,
    takes the object returned by its ``get_pretrained()`` call, and registers
    ``<`` and ``>`` as additional tokens on it.
    """
    safe_wrapper = SAFETokenizer.from_pretrained('datamol-io/safe-gpt')
    tokenizer = safe_wrapper.get_pretrained()
    bracket_tokens = ['<', '>']  # required for bracket_safe
    tokenizer.add_tokens(bracket_tokens)
    return tokenizer
class MolFinetuner(pl.LightningModule):
    """Lightning module for distributed molecule finetuning.

    Each epoch the module (optionally) regenerates a buffer of sampled
    sequences with ``sample_mol_buffer`` and then runs gradient steps on
    random mini-batches drawn from that buffer. The dataloader batch passed
    to ``training_step`` is ignored.

    Depending on ``args`` the module trains:
      * only the policy backbone (``disable_planner``),
      * policy and planner together (``joint_training``), or
      * policy and planner in alternating blocks of
        ``alternation_frequency`` epochs.

    Parameters whose name starts with ``planner.`` inside ``policy_model``
    are treated as planner parameters; everything else is the policy backbone.
    """

    def __init__(
        self,
        args,
        policy_model,
        reward_model,
        tokenizer,
        pretrained=None,
        mcts=None,
        filename=None,
        eps=1e-5
    ):
        """Store the collaborators and initialise buffers and metric logs.

        Args:
            args: Parsed command-line namespace holding all hyperparameters.
            policy_model: Model being finetuned (exposes ``interpolant`` and
                the ``loss_*_flexible`` methods used in ``training_step``).
            reward_model: Scoring object forwarded to the sampling helpers.
            tokenizer: Tokenizer forwarded to the sampling helpers.
            pretrained: Reference model forwarded to ``sample_mol_buffer`` and
                used for the ELBO-based ``log_rnd`` estimate.
            mcts: Stored on the instance; not used by the methods of this class.
            filename: Suffix of the CSV log written in ``on_fit_end``.
            eps: Stored on the instance; not used by the methods of this class.
        """
        super().__init__()
        self.args = args
        self.policy_model = policy_model
        self.reward_model = reward_model
        self.tokenizer = tokenizer
        self.pretrained = pretrained
        self.mcts = mcts
        self.filename = filename
        self.eps = eps
        self.evaluator = Evaluator("diversity")

        # Save hyperparameters (large objects are excluded)
        self.save_hyperparameters(ignore=['policy_model', 'reward_model', 'tokenizer', 'pretrained', 'mcts'])

        # Buffer of sampled sequences and their per-sequence statistics
        self.x_saved = None
        self.log_rnd_saved = None
        self.final_rewards_saved = None

        # Per-evaluation metric histories (written to CSV in on_fit_end)
        self.valid_fraction_log = []
        self.diversity_log = []
        self.qed_log = []
        self.sa_log = []
        self.quality_log = []
        self.uniqueness_log = []

        # Alternating training between policy and planner
        self.train_policy = True  # Start by training policy
        self.alternation_frequency = getattr(args, 'alternation_frequency', 1)  # Alternate every N epochs

    # ------------------------------------------------------------------
    # Freezing helpers
    # ------------------------------------------------------------------
    def _set_backbone_requires_grad(self, flag):
        """Set ``requires_grad`` on every policy parameter outside ``planner.``."""
        for name, param in self.policy_model.named_parameters():
            if not name.startswith('planner.'):
                param.requires_grad = flag

    def _set_planner_requires_grad(self, flag):
        """Set ``requires_grad`` on the planner parameters, if a planner exists."""
        if hasattr(self.policy_model, 'planner'):
            for param in self.policy_model.planner.parameters():
                param.requires_grad = flag

    def freeze_policy_model(self):
        """Freeze policy model parameters (but not planner)."""
        self._set_backbone_requires_grad(False)

    def unfreeze_policy_model(self):
        """Unfreeze policy model parameters (but not planner)."""
        self._set_backbone_requires_grad(True)

    def freeze_planner_model(self):
        """Freeze planner parameters."""
        self._set_planner_requires_grad(False)

    def unfreeze_planner_model(self):
        """Unfreeze planner parameters."""
        self._set_planner_requires_grad(True)

    # ------------------------------------------------------------------
    # Optimizer
    # ------------------------------------------------------------------
    def configure_optimizers(self):
        """Build one AdamW optimizer with separate LR groups for backbone and planner.

        The planner group uses ``args.planner_learning_rate`` when present and
        falls back to ``args.learning_rate`` otherwise.
        """
        planner_lr = getattr(self.args, 'planner_learning_rate', self.args.learning_rate)

        planner_params = []
        policy_params = []
        for name, param in self.policy_model.named_parameters():
            if name.startswith('planner.'):
                planner_params.append(param)
            else:
                policy_params.append(param)

        param_groups = [
            {'params': policy_params, 'lr': self.args.learning_rate},
            {'params': planner_params, 'lr': planner_lr},
        ]
        return torch.optim.AdamW(param_groups)

    def _get_quality_mode(self):
        """Map ablation flags + warmup state to quality_mode string.

        Returns one of ``"none"``, ``"unmasking_only"``, ``"insertion_only"``
        or ``"both"``.
        """
        if self.args.disable_planner:
            return "none"
        if self.current_epoch < self.args.schedule_warmup_epochs:
            return "none"

        di = getattr(self.args, 'disable_insertion_planner', False)
        du = getattr(self.args, 'disable_unmasking_planner', False)
        if di and du:
            return "none"
        if di:
            return "unmasking_only"
        if du:
            return "insertion_only"
        return "both"

    # ------------------------------------------------------------------
    # Epoch start: choose what to train, refresh the buffer
    # ------------------------------------------------------------------
    def on_train_epoch_start(self):
        """Select the trainable sub-model for this epoch and refresh the buffer.

        The buffer is (re)generated when it does not exist yet or when
        ``current_epoch`` is a multiple of ``args.resample_every_n_step``.
        """
        if self.args.disable_planner:
            # disable_planner mode: only train policy (no alternation)
            self.train_policy = True
            self.unfreeze_policy_model()
            self.freeze_planner_model()
            if self.global_rank == 0 and self.current_epoch == 0:
                print("[FINETUNE_QUALITY] Training ONLY policy model (planner frozen, no remasking)")
        elif getattr(self.args, 'joint_training', False):
            # Joint mode: train policy + planner together every step (no alternation)
            self.train_policy = True  # marker; training_step adds planner loss when joint_training is set
            self.unfreeze_policy_model()
            self.unfreeze_planner_model()
            if self.global_rank == 0 and self.current_epoch == 0:
                print("[FINETUNE_QUALITY] JOINT TRAINING: policy + planner trained together (no alternation)")
        else:
            # Alternate between policy and planner in blocks of alternation_frequency epochs
            cycle_position = (self.current_epoch // self.alternation_frequency) % 2
            self.train_policy = (cycle_position == 0)
            if self.train_policy:
                self.unfreeze_policy_model()
                self.freeze_planner_model()
                if self.global_rank == 0:
                    print(f"[ALTERNATION] Epoch {self.current_epoch}: Training POLICY model (planner frozen)")
            else:
                self.freeze_policy_model()
                self.unfreeze_planner_model()
                if self.global_rank == 0:
                    print(f"[ALTERNATION] Epoch {self.current_epoch}: Training PLANNER model (policy frozen)")

        # Resample buffer if needed
        needs_resample = (
            self.x_saved is None
            or self.current_epoch % self.args.resample_every_n_step == 0
        )
        if not needs_resample:
            return

        if self.global_rank == 0:
            print(f"[BUFFER] Starting buffer generation for epoch {self.current_epoch}")
        self._generate_buffer()

        # Synchronize all ranks after buffer generation
        if self.trainer and self.trainer.world_size > 1:
            if self.global_rank == 0:
                print("[BUFFER] All ranks completed buffer generation, synchronizing...")
            # Only call the collective when a process group actually exists;
            # barrier() raises otherwise.
            if torch.distributed.is_available() and torch.distributed.is_initialized():
                torch.distributed.barrier()
            if self.global_rank == 0:
                print("[BUFFER] Synchronization complete!")

    def _generate_buffer(self):
        """Generate buffer of sequences for training.

        When pool_size > 0, maintains a persistent pool and refreshes a fraction
        each time instead of regenerating the entire buffer from scratch.

        Raises:
            RuntimeError: If no valid sequence was produced after all attempts.
        """
        import time  # local import: `time` is not imported at module level

        rank = self.global_rank if self.trainer else 0
        world_size = self.trainer.world_size if self.trainer else 1

        pool_size = getattr(self.args, 'pool_size', 0)
        is_pool = pool_size > 0
        is_init = self.x_saved is None

        # Determine how many molecules to sample this call
        if is_pool:
            refresh_frac = getattr(self.args, 'pool_refresh_fraction', 0.2)
            if is_init:
                samples_per_gpu = pool_size
            else:
                samples_per_gpu = max(1, int(pool_size * refresh_frac))
            if rank == 0:
                if is_init:
                    print(f"\n[POOL] Initializing pool with {pool_size} molecules at epoch {self.current_epoch}")
                else:
                    print(f"\n[POOL] Refreshing {samples_per_gpu}/{pool_size} molecules ({refresh_frac*100:.0f}%) at epoch {self.current_epoch}")
        else:
            samples_per_gpu = self.args.buffer_size // world_size
            if rank == 0:
                samples_per_gpu += self.args.buffer_size % world_size
            # Every rank needs a non-empty buffer for training_step. Without
            # this floor, ranks with a zero quota (buffer_size < world_size)
            # would skip the loop below and hit the "no valid sequences" error.
            samples_per_gpu = max(1, samples_per_gpu)
            if rank == 0:
                print(f"\n[BUFFER] Starting buffer generation at epoch {self.current_epoch}")

        accumulated_x = []
        accumulated_log_rnd = []
        accumulated_rewards = []
        total_accumulated = 0
        max_attempts = 100  # Prevent infinite loop
        attempts = 0

        # Depends only on args and current_epoch, both fixed during this call.
        quality_mode = self._get_quality_mode()

        while total_accumulated < samples_per_gpu and attempts < max_attempts:
            attempts += 1
            if rank == 0:
                print(f"[BUFFER] rank={rank} starting sampling attempt {attempts} at {time.strftime('%H:%M:%S')}")
            start_time = time.time()

            x_final, log_rnd, final_rewards, _trace = sample_mol_buffer(
                self.policy_model,
                self.pretrained,
                self.reward_model,
                self.tokenizer,
                steps=self.args.total_num_steps,
                mask=self.policy_model.interpolant.mask_token,
                pad=self.policy_model.interpolant.pad_token,
                batch_size=self.args.batch_size,
                max_length=self.args.max_length,
                quality_mode=quality_mode,
                alpha=self.args.alpha,
                num_remasking=self.args.num_remasking,
                quality_threshold=self.args.quality_threshold,
                use_quality_filter=self.args.use_quality_filter,
            )

            if self.args.elbo_rnd and x_final.shape[0] > 0:
                # Override trajectory log_rnd with forward ELBO estimate
                with torch.no_grad():
                    noised = self.policy_model.prepare_noised_sample(
                        x_final, num_samples=self.args.elbo_rnd_num_samples)
                    policy_loss = self.policy_model.compute_loss_from_noised(noised)
                    pretrained_loss = self.pretrained.compute_loss_from_noised(noised)
                    log_rnd = (pretrained_loss - policy_loss) + (final_rewards / self.args.alpha)

            elapsed = time.time() - start_time
            if rank == 0:
                print(f"[BUFFER] rank={rank} sampling took {elapsed:.1f}s")

            n_valid = x_final.shape[0]
            if n_valid > 0:
                accumulated_x.append(x_final)
                accumulated_log_rnd.append(log_rnd)
                accumulated_rewards.append(final_rewards)
                total_accumulated += n_valid

            if rank == 0:
                print(f"[BUFFER] rank={rank} epoch={self.current_epoch} quality_mode={quality_mode} accumulated={total_accumulated} / {samples_per_gpu} (batch yielded {n_valid} valid) attempt={attempts}")

        if total_accumulated == 0:
            raise RuntimeError(f"[BUFFER ERROR] Rank {rank}: No valid sequences generated after {attempts} attempts. Check sampling function and reward model.")
        if total_accumulated < samples_per_gpu:
            print(f"[BUFFER WARNING] Rank {rank}: Only generated {total_accumulated}/{samples_per_gpu} sequences after {attempts} attempts")

        # Concatenate the batches and trim to this rank's quota
        new_x = torch.cat(accumulated_x, dim=0)[:samples_per_gpu]
        new_log_rnd = torch.cat(accumulated_log_rnd, dim=0)[:samples_per_gpu]
        new_rewards = torch.cat(accumulated_rewards, dim=0)[:samples_per_gpu]

        del accumulated_x, accumulated_log_rnd, accumulated_rewards
        torch.cuda.empty_cache()

        # add to buffer: pool mode replaces a random subset, classic mode overwrites
        if is_pool and not is_init:
            actual_new = min(new_x.shape[0], self.x_saved.shape[0])
            indices = torch.randperm(self.x_saved.shape[0], device=self.x_saved.device)[:actual_new]
            self.x_saved[indices] = new_x[:actual_new]
            self.log_rnd_saved[indices] = new_log_rnd[:actual_new]
            self.final_rewards_saved[indices] = new_rewards[:actual_new]
            if rank == 0:
                print(f"[POOL] Replaced {actual_new}/{self.x_saved.shape[0]} molecules, reward mean={self.final_rewards_saved.mean():.4f}")
        else:
            self.x_saved = new_x
            self.log_rnd_saved = new_log_rnd
            self.final_rewards_saved = new_rewards

        if rank == 0:
            print(f"[BUFFER] After cleanup - GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

    # ------------------------------------------------------------------
    # Training step
    # ------------------------------------------------------------------
    def _planner_losses(self, log_rnd, x_final, loss_kwargs):
        """Return ``(unmask_loss, insert_loss, planner_loss)`` for the active ablation.

        A disabled component is reported as ``0.0`` and excluded from
        ``planner_loss`` (the value that is backpropagated).
        """
        if getattr(self.args, 'disable_insertion_planner', False):
            # Ablation: only train unmasking planner (no insertion head)
            unmask_only = self.policy_model.loss_planner_flexible(log_rnd, x_final, **loss_kwargs)
            return unmask_only, 0.0, unmask_only

        unmask_loss, insert_loss, combined = self.policy_model.loss_insert_planner_flexible(
            log_rnd, x_final, **loss_kwargs)
        if getattr(self.args, 'disable_unmasking_planner', False):
            # Only the insertion component is kept; the unmasking part is dropped
            return 0.0, insert_loss, insert_loss
        # Full planner: train both remasking + insertion
        return unmask_loss, insert_loss, combined

    def training_step(self, batch, batch_idx):
        """Training step - batch is ignored, we use saved buffer.

        Draws a random mini-batch from the buffer, computes the policy loss
        and/or planner loss according to the current mode, logs them, and
        returns the loss to optimise.
        """
        # Process buffer in mini-batches to avoid OOM
        mini_batch_size = getattr(self.args, 'training_mini_batch_size', 8)
        buffer_size = self.x_saved.shape[0]

        # Randomly sample a mini-batch from buffer
        indices = torch.randperm(buffer_size, device=self.x_saved.device)[:mini_batch_size]
        x_final = self.x_saved[indices]
        log_rnd = self.log_rnd_saved[indices]

        joint = getattr(self.args, 'joint_training', False)
        loss_kwargs = dict(
            num_replicates=self.args.wdce_num_replicates,
            centering=self.args.centering,
            centering_strength=self.args.centering_strength,
            softmax_temperature=getattr(self.args, 'softmax_temperature', 1.0),
        )
        log_kwargs = dict(on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)

        policy_loss = None
        planner_loss = None

        if self.train_policy:
            # Train policy with WDCE loss
            policy_loss = self.policy_model.loss_wdce_flexible(log_rnd, x_final, **loss_kwargs)
            self.log('train/policy_loss', policy_loss, **log_kwargs)

        if (not self.train_policy) or joint:
            # Train planner with appropriate loss based on ablation flags
            unmask_loss, insert_loss, planner_loss = self._planner_losses(log_rnd, x_final, loss_kwargs)
            self.log('train/planner_unmask_loss', unmask_loss, **log_kwargs)
            self.log('train/planner_insert_loss', insert_loss, **log_kwargs)
            self.log('train/planner_loss', planner_loss, **log_kwargs)

        # Combine losses depending on mode
        if joint:
            loss = policy_loss + planner_loss
            mode_value = 0.5
        elif self.train_policy:
            loss = policy_loss
            mode_value = 0.0
        else:
            loss = planner_loss
            mode_value = 1.0

        # Log overall loss and mode
        self.log('train_loss', loss, **log_kwargs)
        self.log('train/mode', mode_value, prog_bar=True, sync_dist=True)
        return loss

    # ------------------------------------------------------------------
    # Evaluation and final export
    # ------------------------------------------------------------------
    def on_train_epoch_end(self):
        """Called at the end of each training epoch - only rank 0 evaluates.

        Evaluation runs every ``args.eval_every_n_epochs`` epochs (default 5)
        and on the last epoch.
        """
        # Only evaluate every N epochs to save time
        eval_frequency = getattr(self.args, 'eval_every_n_epochs', 5)
        is_last_epoch = (self.trainer and self.current_epoch == self.trainer.max_epochs - 1)
        should_eval = self.current_epoch % eval_frequency == 0 or is_last_epoch
        if self.global_rank != 0 or not should_eval:
            return

        # Sample eval batch with updated policy.
        # NOTE(review): the unpack order below (valid_fraction last) is assumed
        # to match sample_mol_eval's return order - confirm in
        # inference_quality_mol.
        x_eval, qed, sa, uniqueness, diversity, quality, valid_fraction = \
            sample_mol_eval(
                self.policy_model, self.reward_model,
                self.tokenizer,
                steps=self.args.total_num_steps,
                mask=self.policy_model.interpolant.mask_token,
                pad=self.policy_model.interpolant.pad_token,
                batch_size=50,
                max_length=self.args.max_length,
                quality_mode=self._get_quality_mode(),
                num_remasking=self.args.num_remasking,
                evaluator=self.evaluator,
            )

        # Append to logs
        self.valid_fraction_log.append(valid_fraction)
        self.uniqueness_log.append(uniqueness)
        self.diversity_log.append(diversity)
        self.qed_log.append(qed)
        self.sa_log.append(sa)
        self.quality_log.append(quality)

        # Reward statistics of the current training buffer
        rewards = self.final_rewards_saved
        mean_reward = rewards.mean().item()
        min_reward = rewards.min().item()
        max_reward = rewards.max().item()
        median_reward = rewards.median().item()

        # Log metrics. This call only happens on rank 0, so tell Lightning
        # via rank_zero_only=True instead of letting it assume every process
        # logs these keys.
        self.log_dict({
            "eval/valid_fraction": valid_fraction,
            "eval/uniqueness": np.mean(uniqueness),
            "eval/diversity": np.mean(diversity),
            "eval/qed": np.mean(qed),
            "eval/sa": np.mean(sa),
            "eval/quality": np.mean(quality),
            "eval/mean_reward_search": mean_reward,
            "eval/min_reward_search": min_reward,
            "eval/max_reward_search": max_reward,
            "eval/median_reward_search": median_reward
        }, rank_zero_only=True)

        print(f"epoch {self.current_epoch} | validity {valid_fraction:.4f} | uniqueness {np.mean(uniqueness):.4f} | diversity {np.mean(diversity):.4f} | "
              f"QED {np.mean(qed):.4f} | SA {np.mean(sa):.4f} | quality {np.mean(quality):.4f} | ")

    def on_fit_end(self):
        """Called at the end of training - save results.

        On rank 0, writes the metric histories to
        ``<base_path>/results/<run_name>/log_<filename>.csv`` and a final
        generation table to ``mol_generation_results.csv`` in the same folder.
        """
        if self.global_rank != 0:
            return

        # Save logs
        base_path = self.args.base_path
        plot_path = f'{base_path}/results/{self.args.run_name}'
        os.makedirs(plot_path, exist_ok=True)
        output_log_path = f'{plot_path}/log_{self.filename}.csv'
        save_logs_to_file(self.valid_fraction_log, self.uniqueness_log,
                          self.diversity_log, self.qed_log, self.sa_log,
                          self.quality_log, output_log_path)

        # Final generation: only the trailing dataframe is used here.
        *_, df = sample_mol_eval(
            self.policy_model, self.reward_model,
            self.tokenizer,
            steps=self.args.total_num_steps,
            mask=self.policy_model.interpolant.mask_token,
            pad=self.policy_model.interpolant.pad_token,
            batch_size=50,
            max_length=self.args.max_length,
            quality_mode=self._get_quality_mode(),
            num_remasking=self.args.num_remasking,
            evaluator=self.evaluator,
            dataframe=True,
        )
        df.to_csv(f'{plot_path}/mol_generation_results.csv', index=False)
def save_logs_to_file(valid_fraction_log, uniqueness_log,
                      diversity_log, qed_log, sa_log,
                      quality_log, output_path):
    """Save the per-evaluation metric logs to a CSV file.

    Each argument ``*_log`` is a sequence with one entry per evaluation; all
    sequences must have the same length. An ``Iteration`` column numbered from
    1 is added in front of the metric columns.

    Args:
        valid_fraction_log: Values for the ``Valid Fraction`` column.
        uniqueness_log: Values for the ``Uniqueness`` column.
        diversity_log: Values for the ``Diversity`` column.
        qed_log: Values for the ``QED`` column.
        sa_log: Values for the ``Synthetic Accessibility`` column.
        quality_log: Values for the ``Quality`` column.
        output_path: Destination CSV path; missing parent directories are
            created.
    """
    # A bare filename has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError, so only create the directory when there is one.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    log_data = {
        "Iteration": list(range(1, len(valid_fraction_log) + 1)),
        "Valid Fraction": valid_fraction_log,
        "Uniqueness": uniqueness_log,
        "Diversity": diversity_log,
        "QED": qed_log,
        "Synthetic Accessibility": sa_log,
        "Quality": quality_log,
    }
    pd.DataFrame(log_data).to_csv(output_path, index=False)
class DummyDataset(torch.utils.data.Dataset):
    """Dummy dataset for Lightning trainer (we use buffer instead).

    Its only purpose is to fix the number of items per epoch; every item is
    the same placeholder tensor.
    """

    def __init__(self, size=100):
        """Create a dataset reporting ``size`` items."""
        self.size = size

    def __len__(self):
        """Return the number of placeholder items."""
        return self.size

    def __getitem__(self, idx):
        """Return ``torch.zeros(1)`` for any index within ``[-size, size)``.

        Raises:
            IndexError: If ``idx`` is out of range. Without this, plain
                iteration (``for x in ds`` / ``list(ds)``) would never stop,
                because Python's sequence protocol relies on ``IndexError``.
        """
        if not -self.size <= idx < self.size:
            raise IndexError(f"index {idx} is out of range for DummyDataset of size {self.size}")
        return torch.zeros(1)  # Dummy data
def main():
    """Main entry point for distributed training.

    Parses the command-line arguments, loads the frozen pretrained reference
    model and the trainable policy model from the same checkpoint, wraps them
    in ``MolFinetuner`` and runs ``pl.Trainer.fit`` on a dummy dataloader
    whose length sets the number of gradient steps per epoch.

    Raises:
        ValueError: If ``--joint_training`` is combined with
            ``--disable_planner``, or if the checkpoint contains no config.
    """
    # Disable DDP optimizer for higher-order ops like flex_attention
    import torch._dynamo
    torch._dynamo.config.optimize_ddp = False

    argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    argparser.add_argument('--base_path', type=str, default=REPO_ROOT)
    argparser.add_argument('--learning_rate', type=float, default=1e-4)
    argparser.add_argument('--num_epochs', type=int, default=100)
    argparser.add_argument('--num_accum_steps', type=int, default=4)
    argparser.add_argument('--truncate_steps', type=int, default=50)
    argparser.add_argument("--truncate_kl", type=str2bool, default=False)
    argparser.add_argument('--gumbel_temp', type=float, default=1.0)
    argparser.add_argument('--gradnorm_clip', type=float, default=1.0)
    argparser.add_argument('--batch_size', type=int, default=50)
    argparser.add_argument('--name', type=str, default='debug')
    argparser.add_argument('--total_num_steps', type=int, default=128)
    argparser.add_argument('--copy_flag_temp', type=float, default=None)
    argparser.add_argument('--save_every_n_epochs', type=int, default=10)
    argparser.add_argument('--eval_every_n_epochs', type=int, default=5, help='Evaluate only every N epochs to save time')
    argparser.add_argument('--alpha_schedule_warmup', type=int, default=0)
    argparser.add_argument("--seed", type=int, default=0)
    # new
    argparser.add_argument('--run_name', type=str, default='mol')
    argparser.add_argument("--save_path_dir", default="", type=str)
    # mcts
    argparser.add_argument('--num_sequences', type=int, default=10)
    argparser.add_argument('--max_length', type=int, default=1024)
    argparser.add_argument('--num_children', type=int, default=50)
    argparser.add_argument('--num_iter', type=int, default=30)  # iterations of mcts
    argparser.add_argument('--seq_length', type=int, default=1024)
    argparser.add_argument('--time_conditioning', action='store_true', default=False)
    argparser.add_argument('--mcts_sampling', type=int, default=0)  # for batched categorical sampling: '0' means gumbel noise
    argparser.add_argument('--buffer_size', type=int, default=100)
    argparser.add_argument('--wdce_num_replicates', type=int, default=16)
    argparser.add_argument('--noise_removal', action='store_true', default=False)
    argparser.add_argument('--grad_clip', action='store_true', default=False)
    argparser.add_argument('--resample_every_n_step', type=int, default=3)
    argparser.add_argument('--exploration', type=float, default=0.1)
    argparser.add_argument('--reset_every_n_step', type=int, default=100)
    argparser.add_argument('--alpha', type=float, default=0.01)
    argparser.add_argument('--scalarization', type=str, default='sum')
    argparser.add_argument('--no_mcts', action='store_true', default=False)
    argparser.add_argument("--centering", action='store_true', default=False)
    argparser.add_argument("--centering_strength", type=float, default=1.0)
    # adaptive schedule parameters
    # The original `action='store_true', default=True` could never yield False.
    # With nargs='?' the bare flag still means True, and an explicit value
    # (parsed by str2bool, as for --truncate_kl) can switch it off.
    argparser.add_argument('--use_adaptive_schedule', type=str2bool, nargs='?', const=True, default=True,
                           help='Enable the adaptive schedule; pass an explicit false value to disable it')
    argparser.add_argument('--schedule_hidden_dim', type=int, default=256)
    argparser.add_argument('--schedule_num_layers', type=int, default=2)
    argparser.add_argument('--schedule_loss_weight', type=float, default=0.1)
    argparser.add_argument('--adaptive_threshold', type=float, default=0.5)
    argparser.add_argument('--freeze_base_model', action='store_true', default=False)
    argparser.add_argument('--schedule_warmup_epochs', type=int, default=20, help='Number of initial epochs to train WITHOUT remasking in buffer generation')
    argparser.add_argument('--alternation_frequency', type=int, default=5, help='Number of epochs to train each model before alternating (1=alternate every epoch)')
    argparser.add_argument('--planner_learning_rate', type=float, default=None, help='Separate learning rate for planner heads (defaults to --learning_rate if not set)')
    # objectives
    argparser.add_argument('--num_obj', type=int, default=2)
    argparser.add_argument('--devices', type=int, default=-1)
    argparser.add_argument('--checkpoint_path', type=str, default=None)
    # ELBO-based log_rnd estimation
    argparser.add_argument('--elbo_rnd', action='store_true', default=False,
                           help='If set, compute log_rnd via forward ELBO instead of trajectory rollout')
    argparser.add_argument('--elbo_rnd_num_samples', type=int, default=4,
                           help='Number of noisy time samples per sequence for ELBO-based log_rnd estimation')
    # remasking
    argparser.add_argument('--num_remasking', type=int, default=5)
    argparser.add_argument('--quality_threshold', type=float, default=1)
    argparser.add_argument('--use_quality_filter', action='store_true', help='If set, filter buffer to only include molecules with QED>=0.6 and SA<=4')
    argparser.add_argument('--training_mini_batch_size', type=int, default=8, help='Mini-batch size for training step to avoid OOM')
    argparser.add_argument('--disable_planner', action='store_true', help='If set, disable remasking completely and only train policy (not planner) for quality optimization')
    argparser.add_argument('--disable_insertion_planner', action='store_true', help='Ablation: disable insertion quality filtering but keep unmasking/remasking planner')
    argparser.add_argument('--disable_unmasking_planner', action='store_true', help='Ablation: disable unmasking/remasking planner but keep insertion quality filtering')
    argparser.add_argument('--joint_training', action='store_true', help='Ablation: train policy and planner jointly each step (no alternation, both unfrozen, summed loss). Incompatible with --disable_planner.')
    argparser.add_argument('--qed_only', action='store_true', help='If set, optimize only for QED score (no SA)')
    argparser.add_argument('--softmax_temperature', type=float, default=1.0,
                           help='Temperature for softmax on importance weights (>1 smooths, prevents concentration)')
    argparser.add_argument('--pool_size', type=int, default=0,
                           help='If >0, maintain a persistent pool of this size and refresh a fraction each resample step (0=disabled, classic buffer)')
    argparser.add_argument('--pool_refresh_fraction', type=float, default=0.2,
                           help='Fraction of pool to replace each resample step (only used when pool_size>0)')
    argparser.add_argument('--num_training_steps_per_epoch', type=int, default=10,
                           help='Number of gradient updates per epoch (1=original, 10=recommended)')
    args = argparser.parse_args()

    # Default planner LR to policy LR if not specified
    if args.planner_learning_rate is None:
        args.planner_learning_rate = args.learning_rate

    # Set seed
    pl.seed_everything(args.seed)

    # Resolve the checkpoint to load
    checkpoint_path = args.checkpoint_path if args.checkpoint_path else \
        os.path.join(REPO_ROOT, 'pretrained', 'anylength_mol.ckpt')

    # NOTE: run_name is always rebuilt here, so a --run_name passed on the
    # command line is overwritten.
    curr_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    if args.no_mcts:
        args.run_name = f'mol_al_resample{args.resample_every_n_step}_no-mcts_{curr_time}'
    else:
        args.run_name = f'mol_al_resample{args.resample_every_n_step}_buffer{args.buffer_size}_numiter{args.num_iter}_children{args.num_children}_{curr_time}'

    # append ablation tags to run name for easy identification
    if args.disable_planner:
        args.run_name += '_no_planner'
    if args.disable_insertion_planner:
        args.run_name += '_no_insertion_planner'
    if args.disable_unmasking_planner:
        args.run_name += '_no_unmasking_planner'
    if args.joint_training:
        if args.disable_planner:
            raise ValueError("--joint_training is incompatible with --disable_planner (no planner to train)")
        args.run_name += '_joint_training'

    args.save_path = os.path.join(args.save_path_dir, args.run_name)
    os.makedirs(args.save_path, exist_ok=True)

    set_seed(args.seed, use_cuda=False)  # Don't init CUDA before Lightning spawns DDP workers

    # Initialize the model
    print("Loading models..")

    # Load pretrained model for reference (frozen)
    pretrained = AnyOrderInsertionFlowModule.load_from_checkpoint(checkpoint_path,
                                                                  map_location='cpu',
                                                                  weights_only=False)
    pretrained.eval()
    for param in pretrained.parameters():
        param.requires_grad = False

    # Load checkpoint to extract config
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    if 'hyper_parameters' in checkpoint:
        config = checkpoint['hyper_parameters']['config']
    elif 'config' in checkpoint:
        config = checkpoint['config']
    else:
        raise ValueError("Cannot find config in checkpoint")

    # Update config for adaptive schedule
    from omegaconf import OmegaConf
    if not OmegaConf.is_config(config):
        from omegaconf import DictConfig
        config = DictConfig(config)

    # Temporarily lift struct mode so new keys can be added
    OmegaConf.set_struct(config, False)
    config.training.use_adaptive_schedule = args.use_adaptive_schedule
    config.training.schedule_hidden_dim = args.schedule_hidden_dim
    config.training.schedule_num_layers = args.schedule_num_layers
    config.training.schedule_loss_weight = args.schedule_loss_weight
    config.training.freeze_base_model = args.freeze_base_model
    config.training.schedule_warmup_epochs = args.schedule_warmup_epochs
    config.training.use_bracket_safe = True
    OmegaConf.set_struct(config, True)

    # initialize policy model with adaptive schedule
    policy_model = AnyOrderInsertionFlowModuleFT(
        config=config,
        args=args,
        pretrained_checkpoint=checkpoint_path,
        insertion_planner=True,
    )

    # scoring objectives
    if args.qed_only:
        score_func_names = ['qed']
    else:
        score_func_names = ['qed', 'sa']

    tokenizer = get_tokenizer()
    filename = args.run_name

    # Device will be set by Lightning automatically in DDP
    reward_model = MolScoringFunctions(score_func_names, device='cpu')

    model = MolFinetuner(
        args=args,
        policy_model=policy_model,
        reward_model=reward_model,
        tokenizer=tokenizer,
        pretrained=pretrained,
        mcts=None,
        filename=filename,
    )

    checkpoint_callback = ModelCheckpoint(
        dirpath=args.save_path,
        filename='model-{epoch:02d}-{train_loss:.4f}',
        every_n_epochs=args.save_every_n_epochs,
        save_top_k=-1,  # Save all checkpoints
        save_last=True,  # Also save last.ckpt
        auto_insert_metric_name=False
    )

    # Defaults to your default wandb entity; override with the WANDB_ENTITY env var.
    wandb_logger = WandbLogger(entity=os.environ.get('WANDB_ENTITY'), project='a2d2-mol', name=args.run_name)

    # create dummy dataloader (its length sets the gradient steps per epoch)
    dataset = DummyDataset(size=args.num_training_steps_per_epoch)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)

    # setup trainer with DDP
    trainer = pl.Trainer(
        max_epochs=args.num_epochs,
        accelerator='gpu',
        devices=args.devices,
        strategy=DDPStrategy(find_unused_parameters=True) if args.devices != 1 else 'auto',
        gradient_clip_val=args.gradnorm_clip if args.grad_clip else None,
        logger=wandb_logger,
        callbacks=[checkpoint_callback],
        enable_progress_bar=True,
        log_every_n_steps=1
    )

    # Train
    trainer.fit(model, dataloader)
if __name__ == '__main__':
main()