# A2D2 for Molecule Generation 🧪

This part of the code fine-tunes an **any-length masked diffusion model (MDM)** over molecules with **A2D2** (Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding) to optimize drug-likeness rewards (QED, and optionally synthetic accessibility / SA). A2D2 jointly fine-tunes the insertion and unmasking policies together with **insertion and unmasking quality predictors**. It generates molecules with **Adaptive Joint Decoding (AJD)**, which remasks low-quality tokens and drops low-quality insertions so that sampling follows the reward-tilted distribution while preserving generation quality.

Molecules are represented as [SAFE](https://github.com/datamol-io/safe) strings and tokenized with the `datamol-io/safe-gpt` tokenizer.

The codebase is partially built upon [FlexMDM (Kim et al., 2025)](https://github.com/brianlck/FlexMDM/tree/main) and [TR2-D2 (Tang et al., 2025)](https://github.com/sophtang/TR2-D2/tree/main).

## Environment Installation

```
# from the repository root
conda env create -f environment.yml
conda activate a2d2
```

The molecule scripts share the `a2d2` environment with the peptide and language experiments. See the root [`environment.yml`](../environment.yml) for the `flash-attn` install step.

## Model Pretrained Weights

A2D2 fine-tunes a pretrained any-length insertion MDM trained on drug-like SAFE molecules. Download the base checkpoint and place it at:

```
A2D2/pretrained/anylength_mol.ckpt
```

```bash
# from the repository root
pip install gdown
mkdir -p pretrained
gdown 1I5EGiV1I5XZZpB9JAKABFLKVqfCyenxq -O pretrained/anylength_mol.ckpt
```

(Or download manually from [Google Drive](https://drive.google.com/file/d/1I5EGiV1I5XZZpB9JAKABFLKVqfCyenxq/view?usp=drive_link). A plain `wget`/`curl` of the link saves Google's HTML warning page, not the checkpoint.)

This path is the default for `--checkpoint_path` (fine-tuning) and `--pretrained_ckpt` (evaluation) used throughout.
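Because a failed download can silently leave an HTML page where the checkpoint should be, a quick sanity check before fine-tuning may save a confusing load error later. The sketch below is only an illustration: `looks_like_html` and `check_checkpoint` are hypothetical helpers, not part of the A2D2 repository, and the check only inspects the first bytes of the file rather than validating the checkpoint contents.

```python
from pathlib import Path

# Hypothetical helpers (not part of the A2D2 repo): detect an HTML page
# that was saved in place of the binary checkpoint.
HTML_MARKERS = (b"<!doctype html", b"<html")


def looks_like_html(path, probe_bytes=512):
    """Return True if the file at `path` starts like an HTML document."""
    with open(path, "rb") as fh:
        head = fh.read(probe_bytes).lstrip().lower()
    return head.startswith(HTML_MARKERS)


def check_checkpoint(path="pretrained/anylength_mol.ckpt"):
    """Return 'missing', 'html', or 'ok' for the downloaded checkpoint."""
    p = Path(path)
    if not p.is_file():
        return "missing"
    if looks_like_html(p):
        return "html"  # likely a warning page; re-download with gdown
    return "ok"


if __name__ == "__main__":
    # Run from the repository root so the default relative path resolves.
    print(check_checkpoint())
```

A result of `html` suggests the file should be deleted and fetched again with `gdown` as shown above.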
## Pretraining the Any-Length Model

If you only want to fine-tune with A2D2, download the released `anylength_mol.ckpt` above and skip this section. Follow these steps to reproduce the base checkpoint by pretraining the any-length insertion MDM from scratch.

### 1. The pretraining dataset

The model is pretrained on drug-like [SAFE](https://github.com/datamol-io/safe) molecules from the [`datamol-io/safe-gpt`](https://huggingface.co/datasets/datamol-io/safe-gpt) dataset (~1.1B molecules) on the Hugging Face Hub. **No manual download is required**: the dataset is loaded in streaming mode (`load_dataset(..., streaming=True)`) and tokenized on the fly with the `datamol-io/safe-gpt` tokenizer. Both are fetched automatically on first run.

The dataset is configured in [`config_mol.yaml`](config_mol.yaml):

```yaml
hf_dataset:
  name: "datamol-io/safe-gpt"
  smiles_column: "smiles"
```

To pretrain on a different Hugging Face SMILES/SAFE dataset, change `hf_dataset.name` (and `smiles_column` to match its column).

### 2. Configure

Pretraining is driven by [`config_mol.yaml`](config_mol.yaml). Key fields:

| Field | Default | Notes |
|-------|---------|-------|
| `hf_dataset.name` | `datamol-io/safe-gpt` | Streaming HF dataset (auto-downloaded). |
| `training.devices` | `2` | GPUs per node (DDP). |
| `training.batch_size` | `2048` | Global batch; gradient accumulation is derived automatically from `per_gpu_batch_size`. |
| `training.max_steps` | `500000` | Total optimizer steps. |
| `training.learning_rate` | `3e-4` | AdamW LR with `warmup_steps: 2000`. |
| `training.save_every_n_steps` | `1000` | Step-based checkpointing (used for streaming datasets). |
| `training.checkpoint_dir` | `checkpoints/pretrain_mol` | A timestamped subdirectory is created per run. |
| `interpolant.max_length` | `256` | Max token length. |

### 3. Pretraining the Any-Length Molecule Model

Log in to Weights & Biases once (`wandb login`), or set `export WANDB_MODE=disabled` to skip logging.
Then submit the SLURM job:

```bash
# from a2d2_mol/
sbatch train_mol.sh
```

`train_mol.sh` is a SLURM batch script that requests one `dgx-b200` node with 2 full B200 GPUs and launches DDP via `srun` (one task per GPU), running the equivalent of:

```bash
python train.py --task mol
```

It activates the conda env (`CONDA_ENV`, defaults to the `peptune` env) from `CONDA_ROOT` (defaults to the shared miniconda install). The batch shell does not source `~/.bashrc`, so override these environment variables if your install or env path differs. The GPU count is auto-detected from the SLURM allocation and passed to Hydra as `training.devices`/`training.nodes`, so to scale, change `--gpus-per-node` and `--ntasks-per-node` together at the top of the script (they must match).

`--task mol` makes `train.py` load `config_mol.yaml`. Checkpoints are written to `checkpoints/pretrain_mol//`; use `last.ckpt` or the best `train_loss` checkpoint as the `--checkpoint_path` / `--pretrained_ckpt` for fine-tuning and evaluation. The run log goes to `logs/_a2d2-mol_.log` and SLURM's catch-file to `logs/slurm/`. To resume, add a `training.resume_path: /path/to/last.ckpt` entry to the config.

## Fine-Tune with A2D2

The canonical run directory is the parent `a2d2/` package (`finetune_mol.py`, `inference_quality_mol.py`, `sampling.py`, and `mol_scoring/` here are the molecule-specific modules used from there). Before running:

1. Set `--base_path` to the location of `a2d2`. Result plots are written to `/flexible/results//`.
2. Create the output directories: `a2d2/checkpoints/finetune_mol`, `a2d2/results`, and `a2d2/logs`.

### Single run

[`scripts/run_mol_finetune.slurm`](scripts/run_mol_finetune.slurm) runs a single `finetune_mol.py` experiment on one MIG GPU, then evaluates the resulting checkpoint.
It bundles the full hyperparameter set used in the paper (replicates `R = 16`, pool size `1000`, buffer size `100`, sampling steps `N_steps = 90`, warmup `N_warmup = 20`, alternation frequency `N_alt = 5`, reward scaling `α = 0.01`, quality threshold `μ_min = 0.3`, `--qed_only`), so you don't have to pass them by hand.

The script resolves the repo root automatically, trying `$A2D2_ROOT` first, then the `sbatch` submit directory, then the script's own location, so either submit from the repo root or export your clone path. Set `CONDA_ROOT` (your miniconda install) and, if needed, `CONDA_ENV` (defaults to `peptune`):

```bash
export A2D2_ROOT=/path/to/your/A2D2      # absolute path to your clone
export CONDA_ROOT=/path/to/miniconda3    # or just have `conda` on PATH
sbatch scripts/run_mol_finetune.slurm
```

Select which variant to run with `MODE_ID` (default `0`): `0` = A2D2 (full planner), `1` = `--disable_planner`, `2` = `--disable_insertion_planner`, `3` = `--disable_unmasking_planner`. Override at submit time:

```bash
sbatch --export=ALL,MODE_ID=2 scripts/run_mol_finetune.slurm
```

The pretrained base checkpoint is read from `$A2D2_ROOT/pretrained/anylength_mol.ckpt`. Outputs land in `checkpoints/finetune_mol/_mol_/` and `results/mol_ablation//`.

### Ablation flags

| Flag | Variant |
|------|---------|
| *(none)* | A2D2 w/ insertion + unmasking quality (alternation) |
| `--disable_planner` | A2D2 w/o quality (policy only, no remasking) |
| `--disable_insertion_planner` | A2D2 w/o insertion quality |
| `--disable_unmasking_planner` | A2D2 w/o unmasking/remasking quality |
| `--joint_training` | train policy + quality heads jointly (no alternation) |

## Evaluation

Evaluation runs automatically at the end of the SLURM job.
To evaluate a checkpoint manually:

```
python evaluate_mol_table.py \
    --checkpoint_path /path/to/a2d2/checkpoints/finetune_mol/my_run/last.ckpt \
    --pretrained_ckpt /path/to/A2D2/pretrained/anylength_mol.ckpt \
    --output_dir /path/to/results \
    --num_samples 1000 --batch_size 50 \
    --max_length 256 --total_num_steps 256 \
    --num_remasking 2 --quality_threshold 0.3 --seed 42 --device cuda:0
```

This reports QED, SA, validity, uniqueness, diversity, and mean unmasking/insertion quality over the generated molecules and writes `eval_metrics_.csv`.
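When several fine-tuned checkpoints need to be evaluated with identical settings, it can help to assemble the manual command programmatically. The following is a minimal sketch, not a script shipped with the repository: `build_eval_command` is a hypothetical helper, and the flags and values are simply copied from the manual command above.

```python
import shlex


def build_eval_command(checkpoint_path, pretrained_ckpt, output_dir,
                       device="cuda:0"):
    """Assemble the evaluate_mol_table.py call as an argument list.

    Hypothetical helper; flags and values mirror the manual command.
    """
    return [
        "python", "evaluate_mol_table.py",
        "--checkpoint_path", str(checkpoint_path),
        "--pretrained_ckpt", str(pretrained_ckpt),
        "--output_dir", str(output_dir),
        "--num_samples", "1000", "--batch_size", "50",
        "--max_length", "256", "--total_num_steps", "256",
        "--num_remasking", "2", "--quality_threshold", "0.3",
        "--seed", "42", "--device", device,
    ]


if __name__ == "__main__":
    # Placeholder paths, as in the manual command; replace with your own.
    cmd = build_eval_command(
        "/path/to/a2d2/checkpoints/finetune_mol/my_run/last.ckpt",
        "/path/to/A2D2/pretrained/anylength_mol.ckpt",
        "/path/to/results",
    )
    # Print a shell-safe version of the command for inspection.
    print(shlex.join(cmd))
```

The returned list can be passed to `subprocess.run(cmd, check=True)` from the directory that contains `evaluate_mol_table.py`, looping over checkpoint paths as needed.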