---
license: apache-2.0
tags:
- chemistry
- molecular-generation
- smiles
- stamp
- drug-discovery
---

# STAMP Hybrid AO-GPT (31M, d=512/l=8, epoch 7)

Pretrained AO-GPT (any-order GPT) over **STAMP** molecular token sequences
with a **hybrid motif + character vocabulary**. Trained on 30M unique
filtered molecules. Achieves **79.00% GenMol quality**, within 0.64 points
of the 112M-parameter AR baseline (79.64%), with 28% of the parameters and
20% of the vocabulary size.

## Highlights

- **Small vocab (2481)**: 2387 high-frequency atomic motifs (freq ≥ 5000) +
  49 SMILES character tokens + ~45 STAMP structural tokens. Covers ~91%
  of motif occurrences as atomic tokens; rare motifs expand to chars.
- **Training-time char fallback** with log-interpolated probability
  (~2% at the most frequent motif, ~15% at the cutoff). The model sees
  each atomic motif in both atomic and char form, closing the
  train/inference gap for OOV motifs.
- **STAMP structural tokens** (`[J_*]`, `[B_*]`, `[S_*]`, `[END]`) act as
  natural motif boundaries, so no extra `[MS]`/`[ME]` markers are needed.
- **Drug-like outputs**: 100% validity, 100% uniqueness (at N=1024),
  79.00% pass the GenMol filter (QED ≥ 0.6 AND SA ≤ 4.0).

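The log-interpolated fallback schedule mentioned in the highlights can be sketched as follows. This is a minimal illustration, not the code behind `HybridVocab.fallback_prob`: the function name, its signature, the clamping, and the `f_max` value (a hypothetical frequency for the most common motif) are assumptions. Only the 5000 cutoff and the ~2% / ~15% endpoints come from this card.

```python
import math

def sketch_fallback_prob(freq, f_min=5000, f_max=1_000_000,
                         p_at_max=0.02, p_at_min=0.15):
    """Probability of expanding an atomic motif to chars during training.

    Linear in log(freq): p_at_max at f_max, p_at_min at f_min.
    f_max is a hypothetical placeholder for the top motif frequency.
    """
    # Clamp so frequencies outside [f_min, f_max] stay at the endpoints.
    f = min(max(freq, f_min), f_max)
    # t = 0 at the most frequent motif, t = 1 at the frequency cutoff.
    t = (math.log(f_max) - math.log(f)) / (math.log(f_max) - math.log(f_min))
    return p_at_max + t * (p_at_min - p_at_max)

for freq in (5_000, 50_000, 1_000_000):
    print(freq, round(sketch_fallback_prob(freq), 4))
```

Interpolating in log-frequency rather than raw frequency spreads the probabilities more evenly across a long-tailed motif distribution, so mid-frequency motifs are not all pushed toward the 15% end.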
## Files

| file | what it is |
|---|---|
| `model.pt` | torch checkpoint: `{model_state, cfg, epoch, representation, model_type}` |
| `hybrid_vocab.json` | full vocab with atomic motif map, frequencies, and char expansions |
| `motif_vocab.txt` | source motif-freq file (v3_cm_union format: `smiles\tn_heavy\tfreq`) |
| `hybrid_vocab.py` | self-contained `HybridVocab` class for decoding |
| `config.json` | architecture summary + default sampling + eval numbers |

## Evaluation (N=1024 at T=0.95, top_p=0.85)

| metric | value |
|---|---:|
| validity | 100.00% |
| uniqueness (raw SMILES) | 100.00% |
| quality over valid (QED ≥ 0.6 ∧ SA ≤ 4) | 79.16% |
| **GenMol score** | **79.00%** |
| QED mean | 0.727 |
| SA mean | 2.92 |
| diversity (1 − pairwise Tanimoto, 1024-bit Morgan r=2) | 0.860 |

**Reference (AR baseline, old 12573-token vocab, d=768/l=12, 112M params): 79.64%.**
The hybrid model matches within noise at 28% of the parameter count.

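For reference, the diversity row in the table (read here as one minus the mean pairwise Tanimoto similarity) can be computed along the lines of the sketch below. It assumes fingerprints were already computed elsewhere (for example as Morgan fingerprints) and are represented as Python sets of on-bit indices. The helper names and the toy fingerprints are illustrative and are not taken from the evaluation script.

```python
from itertools import combinations

def tanimoto(a, b):
    """Tanimoto similarity of two fingerprints given as sets of on-bit indices."""
    union = len(a | b)
    return len(a & b) / union if union else 0.0

def diversity(fps):
    """1 - mean Tanimoto similarity over all unordered pairs of fingerprints."""
    pairs = list(combinations(fps, 2))
    if not pairs:
        return 0.0
    mean_sim = sum(tanimoto(a, b) for a, b in pairs) / len(pairs)
    return 1.0 - mean_sim

# Toy fingerprints: only the first two share any bits (similarity 2/4 = 0.5).
toy_fps = [{1, 2, 3}, {2, 3, 4}, {7, 8}]
print(round(diversity(toy_fps), 3))  # 0.833
```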
## Usage

### 1. Load vocab

```python
from hybrid_vocab import HybridVocab
vocab = HybridVocab.load("hybrid_vocab.json")
# vocab.itos -> list of 2481 token strings
# vocab.atomic_motifs -> {smiles: id} for the 2387 motifs
# vocab.motif_freq -> {smiles: freq}
# vocab.motif_expansion -> {smiles: [char_id, ...]}
```

### 2. Load model

```python
import torch

# Option A: clone https://github.com/... (STAMP repo) to get `stamp.benchmark.lm`
from stamp.benchmark.lm import LMConfig, TinyDecoderLM

# weights_only=False unpickles arbitrary objects; only load checkpoints you trust.
ckpt = torch.load("model.pt", map_location="cpu", weights_only=False)
cfg = LMConfig(**ckpt["cfg"])
cfg.use_adaln = True  # AO-GPT arch
model = TinyDecoderLM(vocab_size=len(vocab.itos), cfg=cfg, bidirectional=False)

state = ckpt["model_state"]
# Strip torch.compile prefix if present.
if any(k.startswith("_orig_mod.") for k in state):
    state = {k.replace("_orig_mod.", "", 1): v for k, v in state.items()}
model.load_state_dict(state)
model.eval().cuda()
```

### 3. Sample (AR, top-p)

```python
import torch
from hybrid_vocab import is_stamp_structural

BOS, EOS = vocab.bos_id, vocab.eos_id
PAD, UNK, MASK = vocab.pad_id, vocab.unk_id, vocab.mask_id
# Ids of STAMP structural tokens; used by the decoder in step 4.
struct_ids = {vocab.stoi[t] for t in vocab.itos if is_stamp_structural(t)}
suppress = {PAD, BOS, MASK, UNK}  # never sample these
T, P = 0.95, 0.85  # temperature, nucleus threshold
max_new = 64

@torch.no_grad()
def sample(n_samples=64):
    x = torch.full((n_samples, 1), BOS, dtype=torch.long, device="cuda")
    finished = torch.zeros(n_samples, dtype=torch.bool, device="cuda")
    for step in range(max_new):
        # Identity order: plain left-to-right decoding.
        orders = torch.arange(x.size(1), device="cuda").unsqueeze(0).expand(x.size(0), -1)
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(x[:, -cfg.max_seq_len:], orders=orders)[:, -1, :].float()
        for sid in suppress:
            logits[:, sid] = float("-inf")
        # top-p: keep the smallest prefix of sorted tokens whose mass exceeds P
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = torch.softmax(sorted_logits / T, dim=-1)
        cum = torch.cumsum(sorted_probs, dim=-1)
        remove = cum > P
        # Shift right so the token that crosses P is kept; always keep the top token.
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        # Scatter the filtered logits back to vocabulary order.
        logits = torch.zeros_like(logits).scatter_(-1, sorted_idx, sorted_logits)
        probs = torch.softmax(logits / T, dim=-1)
        nxt = torch.multinomial(probs, 1).squeeze(-1)
        # Finished rows keep emitting EOS as padding.
        nxt = torch.where(finished, torch.full_like(nxt, EOS), nxt)
        x = torch.cat([x, nxt.unsqueeze(1)], dim=1)
        finished = finished | (nxt == EOS)
        if finished.all():
            break
    return x
```

### 4. Decode token stream → SMILES

```python
def decode_to_stamp_tokens(ids):
    """Flush character runs to motif SMILES at structural token boundaries."""
    special = {PAD, BOS, EOS, MASK, UNK}
    out, buf = [], []
    for i in ids:
        if i in special:
            continue
        tok = vocab.itos[i]
        if i in struct_ids:
            # A structural token closes the current motif.
            if buf:
                out.append("".join(buf))
                buf = []
            out.append(tok)
        else:
            buf.append(tok)
    if buf:
        out.append("".join(buf))
    return out

# Then run through the STAMP codec in the stamp repo:
# from stamp.benchmark.representations import build_representation
# rep = build_representation("stamp")
# text = rep.detokenize(stamp_tokens)
# mol = rep.codec.decode_stamp_to_mol(text)
```

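To make the boundary-flushing behaviour concrete, here is a self-contained toy run that needs neither the model nor `hybrid_vocab.json`. The vocabulary, the ids, the special-token names, and the `[J_0]` token are hypothetical stand-ins, and the sequence is not claimed to be a valid STAMP string. The sketch only shows how character runs between structural tokens are joined into one motif string.

```python
# Hypothetical toy vocabulary; real tokens and ids come from hybrid_vocab.json.
toy_itos = ["[PAD]", "[BOS]", "[EOS]", "[J_0]", "[END]", "C", "O", "c1ccccc1"]
toy_special = {0, 1, 2}  # ids dropped from the output
toy_struct = {3, 4}      # ids treated as structural boundaries

def flush_runs(ids, itos, special, struct):
    """Join runs of non-structural tokens; emit structural tokens as-is."""
    out, buf = [], []
    for i in ids:
        if i in special:
            continue
        if i in struct:
            if buf:
                out.append("".join(buf))
                buf = []
            out.append(itos[i])
        else:
            buf.append(itos[i])
    if buf:
        out.append("".join(buf))
    return out

# One atomic motif token, then a motif spelled out as characters C, C, O.
ids = [1, 7, 3, 5, 5, 6, 4, 2]
print(flush_runs(ids, toy_itos, toy_special, toy_struct))
# ['c1ccccc1', '[J_0]', 'CCO', '[END]']
```

An atomic motif token and a character-expanded motif end up in the same form after flushing, which is why the downstream codec does not need to know which path produced a motif.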
## Sample outputs

Ten representative draws from this checkpoint (all drug-like, QED ≥ 0.6 ∧ SA ≤ 4):

```
Cn1nc(CNCc2cc(Cl)ccc2Cl)n(C)c1=O QED=0.935 SA=2.46 MW=300
CCn1ncc(NC[C@@H]2CCCC[C@@H]2C)c(Br)c1=O QED=0.923 SA=3.31 MW=327
Cc1cccc(Cl)c1NC(=O)CN1CCO[C@@H](C(F)F)CC1 QED=0.921 SA=2.86 MW=332
CCN1CCN(CC(=O)Nc2cc(C(F)(F)F)ccc2Cl)CC1 QED=0.908 SA=1.94 MW=349
COc1ccc(F)c(CNC(=O)C2=CCCCC2)c1 QED=0.907 SA=2.16 MW=263
CN1CC[C@@H]2[C@@H](CCCN2C(=O)NCc2ccc(OC(F)F)cc2)C1 QED=0.905 SA=3.03 MW=353
CC[C@H](C(=O)NCc1c(F)cc(F)cc1F)N1CCCC1=O QED=0.905 SA=2.93 MW=314
Cc1ccc(C2CCN(C(=O)NCc3cccc(F)c3F)CC2)c(=O)n1C QED=0.897 SA=2.45 MW=375
CN(CC(=O)NCc1ccccc1)C(=O)C12CC3CC(CC(C3)C1)C2 QED=0.896 SA=3.37 MW=340
NCC1CCN(c2cc3c(cc2F)c(=O)c(C(=O)O)cn3C2CC2)C1 QED=0.884 SA=2.96 MW=345
```

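If you want to re-check the listed draws against the quality filter, a small parser along these lines is enough. It assumes the whitespace-separated `SMILES KEY=value ...` layout shown in the block above; the function names are illustrative and the QED/SA values are read from the text, not recomputed.

```python
QED_MIN, SA_MAX = 0.6, 4.0  # GenMol filter thresholds stated in this card

def parse_sample(line):
    """Split 'SMILES QED=... SA=... MW=...' into (smiles, {key: float})."""
    smiles, *fields = line.split()
    props = {k: float(v) for k, v in (f.split("=", 1) for f in fields)}
    return smiles, props

def passes_filter(props):
    """True when QED >= 0.6 and SA <= 4.0."""
    return props["QED"] >= QED_MIN and props["SA"] <= SA_MAX

lines = [
    "Cn1nc(CNCc2cc(Cl)ccc2Cl)n(C)c1=O QED=0.935 SA=2.46 MW=300",
    "CN(CC(=O)NCc1ccccc1)C(=O)C12CC3CC(CC(C3)C1)C2 QED=0.896 SA=3.37 MW=340",
]
for line in lines:
    smiles, props = parse_sample(line)
    print(smiles, passes_filter(props))
```

The SMILES string is taken as the first whitespace-separated field, so `=` characters inside it (double bonds) do not interfere with parsing the `KEY=value` fields.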
## Architecture notes

- **AO-GPT**: decoder-only transformer with causal attention over a
  shuffled token order per batch (random permutation of middle tokens,
  BOS/EOS pinned at ends). Target position is conditioned via AdaLN so
  the model learns "any-order" decoding.
- **Hybrid vocab**: structural tokens + SMILES char tokens + atomic motif
  tokens share a single id space. At training time, atomic motif tokens
  may be expanded to their SMILES char form with a
  log-frequency-weighted probability (`HybridVocab.fallback_prob`) so
  the model is not brittle at char-level decoding.
- **Decoder**: the STAMP structural tokens delimit motifs; consecutive
  character tokens between structural tokens concatenate to a single
  motif SMILES, which the STAMP codec parses to a molecule via a stack
  machine with safety fallbacks.

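The order shuffling described in the AO-GPT note can be illustrated with a short sketch. How the training code actually builds its orders is not shown in this card, so the function below is an assumption. It only demonstrates a permutation of positions with the first (BOS) and last (EOS) positions pinned, using the standard library.

```python
import random

def sample_order(seq_len, rng=random):
    """Return a visiting order over positions 0..seq_len-1.

    Position 0 (BOS) stays first and position seq_len-1 (EOS) stays last;
    the positions in between are shuffled uniformly at random.
    """
    if seq_len < 2:
        return list(range(seq_len))
    middle = list(range(1, seq_len - 1))
    rng.shuffle(middle)
    return [0] + middle + [seq_len - 1]

rng = random.Random(0)  # seeded for reproducibility
order = sample_order(8, rng)
print(order)  # always starts with 0 and ends with 7
```

The sampling code in the Usage section passes the identity order (`torch.arange`), which corresponds to ordinary left-to-right decoding.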
## License

Apache-2.0.

## Citation

Cite the STAMP representation paper and this repository. (Placeholder:
fill in with your actual citation info.)