---
license: apache-2.0
tags:
- chemistry
- molecular-generation
- smiles
- stamp
- drug-discovery
---
# STAMP Hybrid AO-GPT (31M, d=512/l=8, epoch 7)
Pretrained AO-GPT (any-order GPT) over **STAMP** molecular token sequences
with a **hybrid motif + character vocabulary**, trained on 30M unique
filtered molecules. It achieves **79.00% GenMol quality**, within 0.64 points of the
112M-parameter AR baseline (79.64%) with 28% of the parameters and 20% of
the vocabulary size.
## Highlights
- **Small vocab (2481)**: 2387 high-frequency atomic motifs (freq ≥ 5000)
+ 49 SMILES character tokens + ~45 STAMP structural tokens. Covers ~91%
of motif occurrences as atomic tokens; rare motifs expand to chars.
- **Training-time char fallback** with log-interpolated probability
(~2% at the most frequent motif, ~15% at the cutoff). The model sees
each atomic motif in both atomic and char form, closing the
train/inference gap for OOV motifs.
- **STAMP structural tokens** (`[J_*]`, `[B_*]`, `[S_*]`, `[END]`) act as
natural motif boundaries — no extra `[MS]`/`[ME]` markers needed.
- **Drug-like outputs**: 100% validity, 100% uniqueness (at N=1024),
79.00% pass the GenMol filter (QED ≥ 0.6 AND SA ≤ 4.0).
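The fallback schedule quoted above can be sketched as a linear interpolation in log-frequency between the two endpoints (~15% at the freq ≥ 5000 cutoff, ~2% at the most frequent motif). This is an assumption about the curve's exact form: `char_fallback_prob` and `maybe_expand` are hypothetical names, and the shipped `HybridVocab.fallback_prob` may use a different formula or signature.

```python
import math
import random

def char_fallback_prob(freq, f_max, f_min=5000.0, p_at_min=0.15, p_at_max=0.02):
    """Hypothetical log-interpolated probability of expanding a motif to chars.

    p(f_min) = p_at_min (motifs at the cutoff), p(f_max) = p_at_max
    (the most frequent motif), linear in log(freq) in between.
    """
    freq = min(max(freq, f_min), f_max)  # clamp into the atomic-motif range
    t = (math.log(freq) - math.log(f_min)) / (math.log(f_max) - math.log(f_min))
    return p_at_min + t * (p_at_max - p_at_min)

def maybe_expand(motif_id, char_ids, freq, f_max, rng=random):
    """Training-time fallback: the atomic id, or its char expansion with prob p(freq)."""
    if rng.random() < char_fallback_prob(freq, f_max):
        return list(char_ids)
    return [motif_id]
```

In practice `f_max` would be `max(vocab.motif_freq.values())` and `char_ids` would come from `vocab.motif_expansion[smiles]`. Interpolating in log space rather than linearly keeps the probability from collapsing to the minimum for all but the handful of most frequent motifs.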
## Files
| file | what it is |
|---|---|
| `model.pt` | torch checkpoint: `{model_state, cfg, epoch, representation, model_type}` |
| `hybrid_vocab.json` | full vocab with atomic motif map, frequencies, and char expansions |
| `motif_vocab.txt` | source motif-freq file (v3_cm_union format: `smiles\tn_heavy\tfreq`) |
| `hybrid_vocab.py` | self-contained `HybridVocab` class for decoding |
| `config.json` | architecture summary + default sampling + eval numbers |
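The atomic/char split (freq ≥ 5000, covering ~91% of motif occurrences) can be recomputed from `motif_vocab.txt`. The sketch below assumes every non-empty line is `smiles\tn_heavy\tfreq` with no header row; `select_atomic_motifs` is a hypothetical helper, not part of `hybrid_vocab.py`.

```python
from typing import Dict, Iterable, Tuple

def select_atomic_motifs(lines: Iterable[str], min_freq: int = 5000) -> Tuple[Dict[str, int], float]:
    """Split motifs into atomic (freq >= min_freq) and char-fallback sets.

    Returns ({smiles: freq} for atomic motifs, fraction of all motif
    occurrences covered by the atomic set).
    """
    atomic: Dict[str, int] = {}
    total = kept = 0
    for line in lines:
        line = line.strip()
        if not line:
            continue
        smiles, _n_heavy, freq_str = line.split("\t")
        freq = int(freq_str)
        total += freq
        if freq >= min_freq:
            atomic[smiles] = freq
            kept += freq
    coverage = kept / total if total else 0.0
    return atomic, coverage
```

Usage: `with open("motif_vocab.txt") as f: atomic, coverage = select_atomic_motifs(f)`.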
## Evaluation (N=1024 at T=0.95, top_p=0.85)
| metric | value |
|---|---:|
| validity | 100.00% |
| uniqueness (raw SMILES) | 100.00% |
| quality over valid (QED ≥ 0.6 ∧ SA ≤ 4) | 79.16% |
| **GenMol score** | **79.00%** |
| QED mean | 0.727 |
| SA mean | 2.92 |
| diversity (1 − pairwise Tanimoto, 1024-bit Morgan r=2) | 0.860 |
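The table's quality and diversity metrics can be sketched in plain Python. This assumes QED and SA are already computed per molecule (e.g. with RDKit's `QED` module and the Contrib `sascorer`) and that Morgan fingerprints are given as sets of on-bit indices. The treatment of invalid and duplicate samples is also an assumption: "quality over valid" divides by the valid count, while the GenMol-style score counts valid, unique, passing molecules over all N samples. All function names are hypothetical.

```python
from itertools import combinations
from typing import Optional, Sequence, Set, Tuple

def passes_genmol(qed: float, sa: float) -> bool:
    """GenMol drug-likeness filter: QED >= 0.6 and SA <= 4.0."""
    return qed >= 0.6 and sa <= 4.0

def quality_metrics(samples: Sequence[Tuple[Optional[str], float, float]]) -> Tuple[float, float]:
    """samples: (canonical SMILES or None if invalid, QED, SA) per generated molecule.

    Returns (quality over valid, GenMol-style score over all N samples).
    """
    n = len(samples)
    valid = [s for s in samples if s[0] is not None]
    q_valid = sum(passes_genmol(q, sa) for _, q, sa in valid) / max(len(valid), 1)
    seen, good_unique = set(), 0
    for smi, q, sa in valid:
        if smi in seen:  # duplicates count once
            continue
        seen.add(smi)
        good_unique += passes_genmol(q, sa)
    return q_valid, good_unique / max(n, 1)

def tanimoto(a: Set[int], b: Set[int]) -> float:
    """Tanimoto similarity of two fingerprints given as on-bit sets."""
    union = len(a | b)
    return len(a & b) / union if union else 1.0

def diversity(fps: Sequence[Set[int]]) -> float:
    """1 - mean pairwise Tanimoto similarity over all unordered pairs."""
    pairs = list(combinations(fps, 2))
    if not pairs:
        return 0.0
    return 1.0 - sum(tanimoto(a, b) for a, b in pairs) / len(pairs)
```

At N=1024 the pairwise loop visits ~524k pairs, which is fine in pure Python; RDKit's bulk Tanimoto routines are the faster route for larger N.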
**Reference (AR baseline, old 12573-token vocab, d=768/l=12, 112M params): 79.64%.**
The hybrid model scores 0.64 points lower, a gap within sampling noise at N=1024, with 28% of the parameter count.
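To put "within noise" in numbers: a pass rate p estimated from N independent samples has binomial standard error sqrt(p(1 - p)/N). The sketch below assumes the baseline's 79.64% was also measured on N=1024 independent samples, which the card does not state; the helper names are illustrative.

```python
import math

def binomial_se(p: float, n: int) -> float:
    """Standard error of a proportion p estimated from n independent samples."""
    return math.sqrt(p * (1.0 - p) / n)

def diff_z(p1: float, n1: int, p2: float, n2: int) -> float:
    """z-score of p1 - p2 under independent binomial sampling (unpooled SE)."""
    se = math.sqrt(binomial_se(p1, n1) ** 2 + binomial_se(p2, n2) ** 2)
    return (p1 - p2) / se

se = binomial_se(0.7900, 1024)          # ~0.0127, i.e. about 1.3 points
z = diff_z(0.7964, 1024, 0.7900, 1024)  # ~0.36, far below the usual ~2 threshold
```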
## Usage
### 1. Load vocab
```python
from hybrid_vocab import HybridVocab
vocab = HybridVocab.load("hybrid_vocab.json")
# vocab.itos -> list of 2481 token strings
# vocab.atomic_motifs -> {smiles: id} for the 2387 motifs
# vocab.motif_freq -> {smiles: freq}
# vocab.motif_expansion -> {smiles: [char_id, ...]}
```
### 2. Load model
```python
import torch
# Clone the STAMP repo (https://github.com/...) to get `stamp.benchmark.lm`.
from stamp.benchmark.lm import LMConfig, TinyDecoderLM

# weights_only=False unpickles arbitrary objects: only load checkpoints you trust.
ckpt = torch.load("model.pt", map_location="cpu", weights_only=False)
cfg = LMConfig(**ckpt["cfg"])
cfg.use_adaln = True  # AO-GPT arch: target position is conditioned via AdaLN
model = TinyDecoderLM(vocab_size=len(vocab.itos), cfg=cfg, bidirectional=False)
state = ckpt["model_state"]
# Strip the torch.compile prefix if present.
if any(k.startswith("_orig_mod.") for k in state):
    state = {k.replace("_orig_mod.", "", 1): v for k, v in state.items()}
model.load_state_dict(state)
model.eval().cuda()
```
### 3. Sample (AR, top-p)
```python
import torch
from hybrid_vocab import is_stamp_structural

BOS, EOS = vocab.bos_id, vocab.eos_id
PAD, UNK, MASK = vocab.pad_id, vocab.unk_id, vocab.mask_id
struct_ids = {vocab.stoi[t] for t in vocab.itos if is_stamp_structural(t)}
suppress = [PAD, BOS, MASK, UNK]  # never sampled
T, P = 0.95, 0.85                 # temperature, nucleus (top-p) mass
max_new = 64                      # max generated tokens per sample

@torch.no_grad()
def sample(n_samples=64):
    x = torch.full((n_samples, 1), BOS, dtype=torch.long, device="cuda")
    finished = torch.zeros(n_samples, dtype=torch.bool, device="cuda")
    for _ in range(max_new):
        ctx = x[:, -cfg.max_seq_len:]
        # Identity order = plain left-to-right decoding; sized to match ctx.
        orders = torch.arange(ctx.size(1), device="cuda").unsqueeze(0).expand(ctx.size(0), -1)
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(ctx, orders=orders)[:, -1, :].float()
        logits[:, suppress] = float("-inf")
        # Top-p: keep the smallest prefix of tokens whose tempered mass exceeds P.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum = torch.cumsum(torch.softmax(sorted_logits / T, dim=-1), dim=-1)
        remove = cum > P
        remove[..., 1:] = remove[..., :-1].clone()  # keep the token that crosses P
        remove[..., 0] = False                      # always keep the top token
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        logits = torch.zeros_like(logits).scatter_(-1, sorted_idx, sorted_logits)
        probs = torch.softmax(logits / T, dim=-1)
        nxt = torch.multinomial(probs, 1).squeeze(-1)
        # Sequences that already emitted EOS keep emitting EOS (padding).
        nxt = torch.where(finished, torch.full_like(nxt, EOS), nxt)
        x = torch.cat([x, nxt.unsqueeze(1)], dim=1)
        finished = finished | (nxt == EOS)
        if finished.all():
            break
    return x
```
### 4. Decode token stream → SMILES
```python
def decode_to_stamp_tokens(ids):
    """Flush character runs to motif SMILES at structural token boundaries.

    `ids` is a list of ints for one sample, e.g. `sample()[k].tolist()`.
    """
    special = {PAD, BOS, EOS, MASK, UNK}
    out, buf = [], []
    for i in ids:
        if i in special:
            continue
        tok = vocab.itos[i]
        if i in struct_ids:
            if buf:  # close the current motif before the structural token
                out.append("".join(buf))
                buf = []
            out.append(tok)
        else:  # atomic motif or SMILES char: part of the current motif
            buf.append(tok)
    if buf:
        out.append("".join(buf))
    return out

# Then run the tokens through the STAMP codec in the stamp repo:
# from stamp.benchmark.representations import build_representation
# rep = build_representation("stamp")
# stamp_tokens = decode_to_stamp_tokens(x[0].tolist())
# text = rep.detokenize(stamp_tokens)
# mol = rep.codec.decode_stamp_to_mol(text)
```
## Sample outputs
Ten draws from this checkpoint that pass the GenMol filter (QED ≥ 0.6 ∧ SA ≤ 4), listed by descending QED:
```
Cn1nc(CNCc2cc(Cl)ccc2Cl)n(C)c1=O QED=0.935 SA=2.46 MW=300
CCn1ncc(NC[C@@H]2CCCC[C@@H]2C)c(Br)c1=O QED=0.923 SA=3.31 MW=327
Cc1cccc(Cl)c1NC(=O)CN1CCO[C@@H](C(F)F)CC1 QED=0.921 SA=2.86 MW=332
CCN1CCN(CC(=O)Nc2cc(C(F)(F)F)ccc2Cl)CC1 QED=0.908 SA=1.94 MW=349
COc1ccc(F)c(CNC(=O)C2=CCCCC2)c1 QED=0.907 SA=2.16 MW=263
CN1CC[C@@H]2[C@@H](CCCN2C(=O)NCc2ccc(OC(F)F)cc2)C1 QED=0.905 SA=3.03 MW=353
CC[C@H](C(=O)NCc1c(F)cc(F)cc1F)N1CCCC1=O QED=0.905 SA=2.93 MW=314
Cc1ccc(C2CCN(C(=O)NCc3cccc(F)c3F)CC2)c(=O)n1C QED=0.897 SA=2.45 MW=375
CN(CC(=O)NCc1ccccc1)C(=O)C12CC3CC(CC(C3)C1)C2 QED=0.896 SA=3.37 MW=340
NCC1CCN(c2cc3c(cc2F)c(=O)c(C(=O)O)cn3C2CC2)C1 QED=0.884 SA=2.96 MW=345
```
## Architecture notes
- **AO-GPT**: decoder-only transformer with causal attention over a
shuffled token order per batch (random permutation of middle tokens,
BOS/EOS pinned at ends). Target position is conditioned via AdaLN so
the model learns "any-order" decoding.
- **Hybrid vocab**: structural tokens + SMILES char tokens + atomic motif
tokens share a single id space. At training time, atomic motif tokens
may be expanded to their SMILES char form with a
log-frequency-weighted probability (`HybridVocab.fallback_prob`) so
the model is not brittle at char-level decoding.
- **Decoder**: the STAMP structural tokens delimit motifs; consecutive
character tokens between structural tokens concatenate to a single
motif SMILES, which the STAMP codec parses to a molecule via a stack
machine with safety fallbacks.
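The AO-GPT training order can be sketched as follows, assuming a sequence `[BOS, t1..tk, EOS]`, one permutation per sequence, and that at step i the model has causally read the tokens at `order[0..i]` and is told (via AdaLN) the next position `order[i + 1]` to fill. `sample_order` and `ao_training_pairs` are hypothetical names; the repository may permute per batch or build targets differently.

```python
import random
from typing import List, Sequence, Tuple

def sample_order(seq_len: int, rng: random.Random) -> List[int]:
    """Random permutation of positions with BOS (0) and EOS (seq_len - 1) pinned."""
    middle = list(range(1, seq_len - 1))
    rng.shuffle(middle)
    return [0] + middle + [seq_len - 1]

def ao_training_pairs(tokens: Sequence[int], order: Sequence[int]) -> List[Tuple[int, int, int]]:
    """(input token, target position, target token) for each step of the permuted sequence."""
    return [(tokens[order[i]], order[i + 1], tokens[order[i + 1]])
            for i in range(len(order) - 1)]
```

With the identity order this reduces to ordinary left-to-right next-token prediction, which is the mode the sampling code in the Usage section runs.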
## License
Apache-2.0.
## Citation
Cite the STAMP representation paper and this repository. (Placeholder —
fill in with your actual citation info.)