File size: 8,132 Bytes
8019be0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# A2D2 for Molecule Generation 🧪

This part of the code fine-tunes an **any-length masked diffusion model (MDM)** over molecules with **A2D2** (Fine-Tuning Any-Length Discrete Diffusion for Adaptive Decoding) to optimize drug-likeness rewards (QED, and optionally synthetic accessibility / SA).

A2D2 jointly fine-tunes the insertion and unmasking policies together with **insertion and unmasking quality predictors**, generating molecules via **Adaptive Joint Decoding (AJD)** that remasks low-quality tokens and drops low-quality insertions to sample from the reward-tilted distribution while preserving generation quality.

Molecules are represented as [SAFE](https://github.com/datamol-io/safe) strings and tokenized with the `datamol-io/safe-gpt` tokenizer.

The codebase is partially built upon [FlexMDM (Kim et.al, 2025)](https://github.com/brianlck/FlexMDM/tree/main) and [TR2-D2 (Tang et.al, 2025)](https://github.com/sophtang/TR2-D2/tree/main).

## Environment Installation
```
# from the repository root
conda env create -f environment.yml

conda activate a2d2
```
The molecule scripts share the `a2d2` environment with the peptide and language experiments. See the root [`environment.yml`](../environment.yml) for the `flash-attn` install step.

## Model Pretrained Weights

A2D2 fine-tunes a pretrained any-length insertion MDM trained on drug-like SAFE molecules. Download the base checkpoint and place it at:
```
A2D2/pretrained/anylength_mol.ckpt
```
```bash
# from the repository root
pip install gdown
mkdir -p pretrained
gdown 1I5EGiV1I5XZZpB9JAKABFLKVqfCyenxq -O pretrained/anylength_mol.ckpt
```
(Or download manually from https://drive.google.com/file/d/1I5EGiV1I5XZZpB9JAKABFLKVqfCyenxq/view?usp=drive_link — a plain `wget`/`curl` of the link saves Google's HTML warning page, not the checkpoint.)
This is the default `--checkpoint_path` (for fine-tuning) and `--pretrained_ckpt` (for evaluation) used throughout.

## Pretraining the Any-Length Model

If you only want to fine-tune with A2D2, download the released `anylength_mol.ckpt` above and skip this section. Follow these steps to reproduce the base checkpoint by pretraining the any-length insertion MDM from scratch.

### 1. The pretraining dataset

The model is pretrained on drug-like [SAFE](https://github.com/datamol-io/safe) molecules from the [`datamol-io/safe-gpt`](https://huggingface.co/datasets/datamol-io/safe-gpt) dataset (~1.1B molecules) on the Hugging Face Hub. **No manual download is required** — the dataset is loaded in streaming mode (`load_dataset(..., streaming=True)`) and tokenized on the fly with the `datamol-io/safe-gpt` tokenizer, both fetched automatically on first run.

The dataset is configured in [`config_mol.yaml`](config_mol.yaml):

```yaml
hf_dataset:
  name: "datamol-io/safe-gpt"
  smiles_column: "smiles"
```

To pretrain on a different Hugging Face SMILES/SAFE dataset, change `hf_dataset.name` (and `smiles_column` to match its column).

### 2. Configure

Pretraining is driven by [`config_mol.yaml`](config_mol.yaml). Key fields:

| Field | Default | Notes |
|-------|---------|-------|
| `hf_dataset.name` | `datamol-io/safe-gpt` | Streaming HF dataset (auto-downloaded). |
| `training.devices` | `2` | GPUs per node (DDP). |
| `training.batch_size` | `2048` | Global batch; gradient accumulation is derived automatically from `per_gpu_batch_size`. |
| `training.max_steps` | `500000` | Total optimizer steps. |
| `training.learning_rate` | `3e-4` | AdamW LR with `warmup_steps: 2000`. |
| `training.save_every_n_steps` | `1000` | Step-based checkpointing (used for streaming datasets). |
| `training.checkpoint_dir` | `checkpoints/pretrain_mol` | A timestamped subdirectory is created per run. |
| `interpolant.max_length` | `256` | Max token length. |

### 3. Pre-training Any-Length Molecule Model

Log in to Weights & Biases once (`wandb login`), or set `export WANDB_MODE=disabled` to skip logging. Then submit the SLURM job:

```bash
# from a2d2_mol/
sbatch train_mol.sh
```

`train_mol.sh` is a SLURM batch script that requests one `dgx-b200` node with 2 full B200 GPUs and launches DDP via `srun` (one task per GPU), running the equivalent of:

```bash
python train.py --task mol
```

It activates the conda env (`CONDA_ENV`, defaults to the `peptune` env) from `CONDA_ROOT` (defaults to the shared miniconda install) — the batch shell does not source `~/.bashrc`, so override these env vars if your install or env path differs. The GPU count is auto-detected from the SLURM allocation and passed to hydra as `training.devices`/`training.nodes`, so to scale just change `--gpus-per-node` and `--ntasks-per-node` together at the top of the script (they must match). `--task mol` makes `train.py` load `config_mol.yaml`.

Checkpoints are written to `checkpoints/pretrain_mol/<timestamp>/` (use `last.ckpt` / the best `train_loss` checkpoint as the `--checkpoint_path` / `--pretrained_ckpt` for fine-tuning and evaluation); the run log goes to `logs/<date>_a2d2-mol_<jobid>.log` and SLURM's catch-file to `logs/slurm/`. To resume, add a `training.resume_path: /path/to/last.ckpt` entry to the config.

## Fine-Tune with A2D2

The canonical run directory is the parent `a2d2/` package (`finetune_mol.py`, `inference_quality_mol.py`, `sampling.py`, and `mol_scoring/` here are the molecule-specific modules used from there). Before running:

1. Set `--base_path` to the location of `a2d2`. Results plots are written to `<base_path>/flexible/results/<run_name>/`.
2. Create the output directories: `a2d2/checkpoints/finetune_mol`, `a2d2/results`, and `a2d2/logs`.

### Single run

[`scripts/run_mol_finetune.slurm`](scripts/run_mol_finetune.slurm) runs a single `finetune_mol.py` experiment on one MIG GPU, then evaluates the resulting checkpoint. It bundles the full hyperparameter set used in the paper (replicates `R = 16`, pool size `1000`, buffer size `100`, sampling steps `N_steps = 90`, warmup `N_warmup = 20`, alternation frequency `N_alt = 5`, reward scaling `α = 0.01`, quality threshold `μ_min = 0.3`, `--qed_only`), so you don't have to pass them by hand.

The script resolves the repo root automatically — `$A2D2_ROOT` if set, else the `sbatch` submit directory, else the script's own location — so either submit from the repo root or export your clone path. Set `CONDA_ROOT` (your miniconda install) and, if needed, `CONDA_ENV` (defaults to `peptune`):
```bash
export A2D2_ROOT=/path/to/your/A2D2     # absolute path to your clone
export CONDA_ROOT=/path/to/miniconda3   # or just have `conda` on PATH
sbatch scripts/run_mol_finetune.slurm
```

Select which variant to run with `MODE_ID` (default `0`): `0` = A2D2 (full planner), `1` = `--disable_planner`, `2` = `--disable_insertion_planner`, `3` = `--disable_unmasking_planner`. Override at submit time:
```bash
sbatch --export=ALL,MODE_ID=2 scripts/run_mol_finetune.slurm
```
The pretrained base checkpoint is read from `$A2D2_ROOT/pretrained/anylength_mol.ckpt`. Outputs land in `checkpoints/finetune_mol/<job>_mol_<mode>/` and `results/mol_ablation/<mode>/`.

### Ablation flags
| Flag | Variant |
|------|---------|
| *(none)* | A2D2 w/ insertion + unmasking quality (alternation) |
| `--disable_planner` | A2D2 w/o quality (policy only, no remasking) |
| `--disable_insertion_planner` | A2D2 w/o insertion quality |
| `--disable_unmasking_planner` | A2D2 w/o unmasking/remasking quality |
| `--joint_training` | train policy + quality heads jointly (no alternation) |

## Evaluation

Evaluation runs automatically at the end of the SLURM job. To evaluate a checkpoint manually:
```
python evaluate_mol_table.py \
    --checkpoint_path /path/to/a2d2/checkpoints/finetune_mol/my_run/last.ckpt \
    --pretrained_ckpt /path/to/A2D2/pretrained/anylength_mol.ckpt \
    --output_dir /path/to/results \
    --num_samples 1000 --batch_size 50 \
    --max_length 256 --total_num_steps 256 \
    --num_remasking 2 --quality_threshold 0.3 --seed 42 --device cuda:0
```
This reports QED, SA, validity, uniqueness, diversity, and mean unmasking/insertion quality over the generated molecules and writes `eval_metrics_<mode>.csv`.