
A2D2 for Molecule Generation 🧪

This part of the code fine-tunes an any-length masked diffusion model (MDM) over molecules with A2D2 (Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding) to optimize drug-likeness rewards (QED, and optionally synthetic accessibility / SA).

A2D2 jointly fine-tunes the insertion and unmasking policies together with insertion and unmasking quality predictors. Molecules are generated with Adaptive Joint Decoding (AJD), which remasks low-quality tokens and drops low-quality insertions so that samples follow the reward-tilted distribution while preserving generation quality.

Molecules are represented as SAFE strings and tokenized with the datamol-io/safe-gpt tokenizer.

The codebase is partially built upon FlexMDM (Kim et al., 2025) and TR2-D2 (Tang et al., 2025).

Environment Installation

# from the repository root
conda env create -f environment.yml

conda activate a2d2

The molecule scripts share the a2d2 environment with the peptide and language experiments. See the root environment.yml for the flash-attn install step.

Model Pretrained Weights

A2D2 fine-tunes a pretrained any-length insertion MDM trained on drug-like SAFE molecules. Download the base checkpoint and place it at:

A2D2/pretrained/anylength_mol.ckpt
# from the repository root
pip install gdown
mkdir -p pretrained
gdown 1I5EGiV1I5XZZpB9JAKABFLKVqfCyenxq -O pretrained/anylength_mol.ckpt

(Or download manually from https://drive.google.com/file/d/1I5EGiV1I5XZZpB9JAKABFLKVqfCyenxq/view?usp=drive_link — a plain wget/curl of the link saves Google's HTML warning page, not the checkpoint.) This is the default --checkpoint_path (for fine-tuning) and --pretrained_ckpt (for evaluation) used throughout.
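Because a failed download can leave an HTML page on disk under the checkpoint's name, a quick check of the first bytes can catch this before fine-tuning starts. The following is a minimal sketch, not part of the repository; the helper name `looks_like_html` and the 512-byte probe size are assumptions made for illustration.

```python
def looks_like_html(path, probe_bytes=512):
    """Return True if the file starts like an HTML page instead of a binary checkpoint."""
    with open(path, "rb") as fh:
        # Only the first few bytes are read, so this is cheap even for multi-GB files.
        head = fh.read(probe_bytes).lstrip().lower()
    return head.startswith((b"<!doctype html", b"<html"))


# Usage (path taken from the instructions above):
#   if looks_like_html("pretrained/anylength_mol.ckpt"):
#       raise SystemExit("Download saved an HTML page; re-run gdown or download manually.")
```

If the check fires, delete the file and repeat the gdown step.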

Pretraining the Any-Length Model

If you only want to fine-tune with A2D2, download the released anylength_mol.ckpt above and skip this section. Follow these steps to reproduce the base checkpoint by pretraining the any-length insertion MDM from scratch.

1. The pretraining dataset

The model is pretrained on drug-like SAFE molecules from the datamol-io/safe-gpt dataset (~1.1B molecules) on the Hugging Face Hub. No manual download is required — the dataset is loaded in streaming mode (load_dataset(..., streaming=True)) and tokenized on the fly with the datamol-io/safe-gpt tokenizer, both fetched automatically on first run.

The dataset is configured in config_mol.yaml:

hf_dataset:
  name: "datamol-io/safe-gpt"
  smiles_column: "smiles"

To pretrain on a different Hugging Face SMILES/SAFE dataset, change hf_dataset.name (and smiles_column to match its column).
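To make the roles of hf_dataset.name and smiles_column concrete, here is a rough sketch of how a streaming dataset could be opened and its molecule column read. It is not the repository's data loader: the helper names `open_stream`, `iter_molecules`, and `peek`, and the choice of the `train` split, are assumptions for illustration. Only `load_dataset(..., streaming=True)` comes from the text above.

```python
from itertools import islice


def open_stream(name="datamol-io/safe-gpt", split="train"):
    """Open a Hugging Face dataset lazily; nothing is downloaded up front."""
    # Imported here so the helpers below work without `datasets` installed.
    from datasets import load_dataset
    return load_dataset(name, split=split, streaming=True)


def iter_molecules(records, smiles_column="smiles"):
    """Yield the non-empty molecule strings found in `smiles_column`."""
    for record in records:
        text = record.get(smiles_column)
        if text:
            yield text


def peek(records, n=3, smiles_column="smiles"):
    """Return the first `n` molecule strings from any iterable of records."""
    return list(islice(iter_molecules(records, smiles_column), n))


# Usage (needs network access and the `datasets` package):
#   print(peek(open_stream(), n=3))
```

Pointing `smiles_column` at a column that does not exist yields no molecules here, which mirrors why the config field has to match the dataset's column name.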

2. Configure

Pretraining is driven by config_mol.yaml. Key fields:

| Field | Default | Notes |
|---|---|---|
| hf_dataset.name | datamol-io/safe-gpt | Streaming HF dataset (auto-downloaded). |
| training.devices | 2 | GPUs per node (DDP). |
| training.batch_size | 2048 | Global batch; gradient accumulation is derived automatically from per_gpu_batch_size. |
| training.max_steps | 500000 | Total optimizer steps. |
| training.learning_rate | 3e-4 | AdamW LR with warmup_steps: 2000. |
| training.save_every_n_steps | 1000 | Step-based checkpointing (used for streaming datasets). |
| training.checkpoint_dir | checkpoints/pretrain_mol | A timestamped subdirectory is created per run. |
| interpolant.max_length | 256 | Max token length. |
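The relationship between the global batch size and gradient accumulation can be illustrated with a small calculation. This sketch assumes the usual DDP convention (global batch = per-GPU batch × GPUs × nodes × accumulation steps); the function name and the per-GPU value of 128 are hypothetical, and the repository may derive the number differently.

```python
def accumulation_steps(global_batch, per_gpu_batch, devices, nodes=1):
    """Number of micro-batches to accumulate before each optimizer step."""
    per_step = per_gpu_batch * devices * nodes  # samples processed per forward/backward pass
    if global_batch % per_step != 0:
        raise ValueError(
            f"global batch {global_batch} is not divisible by {per_step} "
            "(per_gpu_batch * devices * nodes)"
        )
    return global_batch // per_step


# With the defaults above (batch_size 2048, devices 2) and a hypothetical
# per_gpu_batch_size of 128, each optimizer step accumulates 8 micro-batches.
steps = accumulation_steps(global_batch=2048, per_gpu_batch=128, devices=2)
```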

3. Pretrain the any-length molecule model

Log in to Weights & Biases once (wandb login), or set export WANDB_MODE=disabled to skip logging. Then submit the SLURM job:

# from a2d2_mol/
sbatch train_mol.sh

train_mol.sh is a SLURM batch script that requests one dgx-b200 node with 2 full B200 GPUs and launches DDP via srun (one task per GPU), running the equivalent of:

python train.py --task mol

It activates the conda env (CONDA_ENV, defaults to the peptune env) from CONDA_ROOT (defaults to the shared miniconda install). The batch shell does not source ~/.bashrc, so override these environment variables if your install or env path differs, for example CONDA_ENV=a2d2 if you created the environment as described above. The GPU count is auto-detected from the SLURM allocation and passed to Hydra as training.devices/training.nodes, so to scale, change --gpus-per-node and --ntasks-per-node together at the top of the script (they must match). The --task mol flag makes train.py load config_mol.yaml.

Checkpoints are written to checkpoints/pretrain_mol/<timestamp>/. Use last.ckpt or the checkpoint with the best train_loss as the --checkpoint_path / --pretrained_ckpt for fine-tuning and evaluation. The run log goes to logs/<date>_a2d2-mol_<jobid>.log, and SLURM's own catch-all output file goes to logs/slurm/. To resume, add a training.resume_path: /path/to/last.ckpt entry to the config.
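Since each run gets its own timestamped subdirectory, locating the most recent last.ckpt by hand can be tedious. The helper below is a sketch only: the name `latest_last_ckpt` is hypothetical, and it picks the newest file by modification time because the timestamp format of the directory names is not specified here.

```python
from pathlib import Path


def latest_last_ckpt(checkpoint_dir="checkpoints/pretrain_mol"):
    """Return the most recently modified <run>/last.ckpt, or None if there is none."""
    candidates = [p for p in Path(checkpoint_dir).glob("*/last.ckpt") if p.is_file()]
    if not candidates:
        return None
    # Modification time avoids guessing how the run directories are named.
    return max(candidates, key=lambda p: p.stat().st_mtime)


# Usage:
#   ckpt = latest_last_ckpt()
#   print(ckpt)  # pass this path as --checkpoint_path or training.resume_path
```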

Fine-Tune with A2D2

The canonical run directory is the parent a2d2/ package; finetune_mol.py, inference_quality_mol.py, sampling.py, and mol_scoring/ in this directory are the molecule-specific modules used from there. Before running:

  1. Set --base_path to the location of a2d2. Results plots are written to <base_path>/flexible/results/<run_name>/.
  2. Create the output directories: a2d2/checkpoints/finetune_mol, a2d2/results, and a2d2/logs.

Single run

scripts/run_mol_finetune.slurm runs a single finetune_mol.py experiment on one MIG GPU, then evaluates the resulting checkpoint. It bundles the full hyperparameter set used in the paper (replicates R = 16, pool size 1000, buffer size 100, sampling steps N_steps = 90, warmup N_warmup = 20, alternation frequency N_alt = 5, reward scaling α = 0.01, quality threshold μ_min = 0.3, --qed_only), so you don't have to pass them by hand.

The script resolves the repo root automatically — $A2D2_ROOT if set, else the sbatch submit directory, else the script's own location — so either submit from the repo root or export your clone path. Set CONDA_ROOT (your miniconda install) and, if needed, CONDA_ENV (defaults to peptune):

export A2D2_ROOT=/path/to/your/A2D2     # absolute path to your clone
export CONDA_ROOT=/path/to/miniconda3   # or just have `conda` on PATH
sbatch scripts/run_mol_finetune.slurm
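The fallback order used to find the repo root can be written out as a short function. This is an illustrative sketch, not the script's actual logic: `resolve_repo_root` is a hypothetical name, SLURM_SUBMIT_DIR is the standard SLURM variable holding the submit directory, and the last branch simply returns the directory containing the script, whereas the real script may walk up to its parent.

```python
import os
from pathlib import Path


def resolve_repo_root(script_path, env=None):
    """Pick the repo root: $A2D2_ROOT, else the sbatch submit dir, else the script's dir."""
    env = os.environ if env is None else env
    if env.get("A2D2_ROOT"):
        return Path(env["A2D2_ROOT"])
    if env.get("SLURM_SUBMIT_DIR"):
        # Set by SLURM to the directory from which sbatch was invoked.
        return Path(env["SLURM_SUBMIT_DIR"])
    # Assumption: fall back to the directory that holds the script itself.
    return Path(script_path).resolve().parent
```

Exporting A2D2_ROOT is the most predictable option, because it does not depend on where the job was submitted from.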

Select which variant to run with MODE_ID (default 0): 0 = A2D2 (full planner), 1 = --disable_planner, 2 = --disable_insertion_planner, 3 = --disable_unmasking_planner. Override at submit time:

sbatch --export=ALL,MODE_ID=2 scripts/run_mol_finetune.slurm
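The MODE_ID values map one-to-one onto the ablation flags listed above. As a sketch of that mapping (the names `MODE_FLAGS` and `ablation_flags` are made up for illustration, and the SLURM script may implement the selection differently):

```python
MODE_FLAGS = {
    0: [],                               # A2D2, full planner
    1: ["--disable_planner"],
    2: ["--disable_insertion_planner"],
    3: ["--disable_unmasking_planner"],
}


def ablation_flags(mode_id=0):
    """Translate a MODE_ID (int or string from the environment) into CLI flags."""
    try:
        return list(MODE_FLAGS[int(mode_id)])
    except (KeyError, ValueError, TypeError):
        raise ValueError(
            f"MODE_ID must be one of {sorted(MODE_FLAGS)}, got {mode_id!r}"
        ) from None
```

Environment variables arrive as strings, so `ablation_flags("2")` and `ablation_flags(2)` give the same result.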

The pretrained base checkpoint is read from $A2D2_ROOT/pretrained/anylength_mol.ckpt. Outputs land in checkpoints/finetune_mol/<job>_mol_<mode>/ and results/mol_ablation/<mode>/.

Ablation flags

| Flag | Variant |
|---|---|
| (none) | A2D2 w/ insertion + unmasking quality (alternation) |
| --disable_planner | A2D2 w/o quality (policy only, no remasking) |
| --disable_insertion_planner | A2D2 w/o insertion quality |
| --disable_unmasking_planner | A2D2 w/o unmasking/remasking quality |
| --joint_training | train policy + quality heads jointly (no alternation) |

Evaluation

Evaluation runs automatically at the end of the SLURM job. To evaluate a checkpoint manually:

python evaluate_mol_table.py \
    --checkpoint_path /path/to/a2d2/checkpoints/finetune_mol/my_run/last.ckpt \
    --pretrained_ckpt /path/to/A2D2/pretrained/anylength_mol.ckpt \
    --output_dir /path/to/results \
    --num_samples 1000 --batch_size 50 \
    --max_length 256 --total_num_steps 256 \
    --num_remasking 2 --quality_threshold 0.3 --seed 42 --device cuda:0

This reports QED, SA, validity, uniqueness, diversity, and mean unmasking/insertion quality over the generated molecules and writes eval_metrics_<mode>.csv.
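For readers unfamiliar with the validity and uniqueness metrics, the sketch below shows one common way to define them. It is an assumption, not the definition used by evaluate_mol_table.py: validity is taken as the fraction of samples accepted by a caller-supplied `is_valid` check, and uniqueness as the fraction of distinct strings among the valid ones. A real pipeline would typically decode SAFE to SMILES and compare canonical forms using a chemistry toolkit.

```python
def validity_and_uniqueness(samples, is_valid):
    """Return (validity, uniqueness) for a list of generated molecule strings."""
    total = len(samples)
    valid = [s for s in samples if is_valid(s)]
    validity = len(valid) / total if total else 0.0
    # Uniqueness is measured among valid samples only.
    uniqueness = len(set(valid)) / len(valid) if valid else 0.0
    return validity, uniqueness


# Toy usage with a placeholder validity check (non-empty string counts as valid):
validity, uniqueness = validity_and_uniqueness(["CCO", "CCO", "c1ccccc1", ""], bool)
```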