---
license: apache-2.0
tags:
- chemistry
- molecular-generation
- smiles
- stamp
- drug-discovery
---

# STAMP Hybrid AO-GPT (31M, d=512/l=8, epoch 7)

Pretrained AO-GPT (any-order GPT) over **STAMP** molecular token sequences with a **hybrid motif + character vocabulary**. Trained on 30M unique filtered molecules. Achieves **79.00% GenMol quality**, matching the 112M-parameter AR baseline (79.64%) with 28% of the parameters and 20% of the vocabulary size.

## Highlights

- **Small vocab (2481)**: 2387 high-frequency atomic motifs (freq ≥ 5000) + 49 SMILES character tokens + ~45 STAMP structural tokens. Covers ~91% of motif occurrences as atomic tokens; rare motifs expand to chars.
- **Training-time char fallback** with log-interpolated probability (~2% at the most frequent motif, ~15% at the cutoff). The model sees each atomic motif in both atomic and char form, closing the train/inference gap for OOV motifs.
- **STAMP structural tokens** (`[J_*]`, `[B_*]`, `[S_*]`, `[END]`) act as natural motif boundaries, so no extra `[MS]`/`[ME]` markers are needed.
- **Drug-like outputs**: 100% validity, 100% uniqueness (at N=1024), 79.00% pass the GenMol filter (QED ≥ 0.6 AND SA ≤ 4.0).

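
The training-time char fallback from the highlights can be pictured with a small sketch. It assumes "log-interpolated" means linear interpolation in log-frequency between the two quoted endpoints (15% at the cutoff, 2% at the most frequent motif). The function name `sketch_fallback_prob` and the `max_freq` default are hypothetical placeholders; the shipped `HybridVocab.fallback_prob` may be defined differently.

```python
import math


def sketch_fallback_prob(freq, min_freq=5000, max_freq=1_000_000,
                         p_at_min=0.15, p_at_max=0.02):
    """Probability of expanding an atomic motif into SMILES chars.

    Assumption: linear interpolation in log(freq) between
    p_at_min (at the vocabulary cutoff) and p_at_max (at the most
    frequent motif). max_freq is a placeholder, not a value from
    the released vocabulary.
    """
    if max_freq <= min_freq:
        raise ValueError("max_freq must exceed min_freq")
    if freq < min_freq:
        # Below the cutoff a motif is not atomic, so it is always chars.
        return 1.0
    freq = min(freq, max_freq)
    # t runs from 0 at the cutoff to 1 at the most frequent motif.
    t = (math.log(freq) - math.log(min_freq)) / (
        math.log(max_freq) - math.log(min_freq))
    return p_at_min + t * (p_at_max - p_at_min)


if __name__ == "__main__":
    for f in (5000, 50_000, 1_000_000):
        print(f, round(sketch_fallback_prob(f), 4))
```

Interpolating in log-frequency rather than raw frequency spreads the probability more evenly across the long tail of motif counts, so mid-frequency motifs still get a meaningful share of char-form exposure.
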
## Files

| file | what it is |
|---|---|
| `model.pt` | torch checkpoint: `{model_state, cfg, epoch, representation, model_type}` |
| `hybrid_vocab.json` | full vocab with atomic motif map, frequencies, and char expansions |
| `motif_vocab.txt` | source motif-freq file (v3_cm_union format: `smiles\tn_heavy\tfreq`) |
| `hybrid_vocab.py` | self-contained `HybridVocab` class for decoding |
| `config.json` | architecture summary + default sampling + eval numbers |

## Evaluation (N=1024 at T=0.95, top_p=0.85)

| metric | value |
|---|---:|
| validity | 100.00% |
| uniqueness (raw SMILES) | 100.00% |
| quality over valid (QED ≥ 0.6 ∧ SA ≤ 4) | 79.16% |
| **GenMol score** | **79.00%** |
| QED mean | 0.727 |
| SA mean | 2.92 |
| diversity (1 − pairwise Tanimoto, 1024-bit Morgan r=2) | 0.860 |

**Reference (AR baseline, old 12573-token vocab, d=768/l=12, 112M params): 79.64%.** The hybrid model matches within noise at 28% of the parameter count: the 0.64-point gap is smaller than the binomial standard error of a pass rate near 79% at N=1024 (about 1.3 points).

## Usage

### 1. Load vocab

```python
from hybrid_vocab import HybridVocab

vocab = HybridVocab.load("hybrid_vocab.json")
# vocab.itos            -> list of 2481 token strings
# vocab.atomic_motifs   -> {smiles: id} for the 2387 motifs
# vocab.motif_freq      -> {smiles: freq}
# vocab.motif_expansion -> {smiles: [char_id, ...]}
```

### 2. Load model

```python
import torch

# Clone https://github.com/... (STAMP repo) to get `stamp.benchmark.lm`.
from stamp.benchmark.lm import LMConfig, TinyDecoderLM

# weights_only=False unpickles arbitrary Python objects,
# so only load checkpoints you trust.
ckpt = torch.load("model.pt", map_location="cpu", weights_only=False)
cfg = LMConfig(**ckpt["cfg"])
cfg.use_adaln = True  # AO-GPT arch
model = TinyDecoderLM(vocab_size=len(vocab.itos), cfg=cfg, bidirectional=False)

state = ckpt["model_state"]
# Strip torch.compile prefix if present.
if any(k.startswith("_orig_mod.") for k in state):
    state = {k.replace("_orig_mod.", "", 1): v for k, v in state.items()}
model.load_state_dict(state)
model.eval().cuda()
```

### 3. Sample (AR, top-p)

```python
import torch
from hybrid_vocab import is_stamp_structural

BOS, EOS = vocab.bos_id, vocab.eos_id
PAD, UNK, MASK = vocab.pad_id, vocab.unk_id, vocab.mask_id
struct_ids = {vocab.stoi[t] for t in vocab.itos if is_stamp_structural(t)}
suppress = [PAD, BOS, MASK, UNK]  # tokens that must never be sampled

T, P = 0.95, 0.85    # temperature, nucleus (top-p) mass
n, max_new = 64, 64  # batch size, max generated tokens

@torch.no_grad()
def sample(n_samples=n):
    x = torch.full((n_samples, 1), BOS, dtype=torch.long, device="cuda")
    finished = torch.zeros(n_samples, dtype=torch.bool, device="cuda")
    for _ in range(max_new):
        # Crop first so that `orders` always matches the model input length.
        ctx = x[:, -cfg.max_seq_len:]
        # Identity order, i.e. plain left-to-right decoding.
        orders = torch.arange(ctx.size(1), device="cuda").unsqueeze(0).expand(ctx.size(0), -1)
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(ctx, orders=orders)[:, -1, :].float()
        logits[:, suppress] = float("-inf")
        # Top-p: drop the tail once cumulative probability exceeds P.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = torch.softmax(sorted_logits / T, dim=-1)
        cum = torch.cumsum(sorted_probs, dim=-1)
        remove = cum > P
        # Shift right so the token that crosses P is kept; always keep the top token.
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        logits = torch.zeros_like(logits).scatter_(-1, sorted_idx, sorted_logits)
        probs = torch.softmax(logits / T, dim=-1)
        nxt = torch.multinomial(probs, 1).squeeze(-1)
        # Finished rows keep emitting EOS as padding.
        nxt = torch.where(finished, torch.full_like(nxt, EOS), nxt)
        x = torch.cat([x, nxt.unsqueeze(1)], dim=1)
        finished = finished | (nxt == EOS)
        if finished.all():
            break
    return x
```

### 4. Decode token stream → SMILES

```python
def decode_to_stamp_tokens(ids):
    """Flush character runs to motif SMILES at structural token boundaries."""
    special = {PAD, BOS, EOS, MASK, UNK}
    out, buf = [], []
    for i in ids:
        i = int(i)  # tensor elements hash by identity, so cast before set lookups
        if i in special:
            continue
        tok = vocab.itos[i]
        if i in struct_ids:
            if buf:
                out.append("".join(buf))
                buf = []
            out.append(tok)
        else:
            buf.append(tok)
    if buf:
        out.append("".join(buf))
    return out

# Then run through the STAMP codec in the stamp repo:
#   from stamp.benchmark.representations import build_representation
#   rep = build_representation("stamp")
#   stamp_tokens = decode_to_stamp_tokens(sample()[0].tolist())
#   text = rep.detokenize(stamp_tokens)
#   mol = rep.codec.decode_stamp_to_mol(text)
```

## Sample outputs

Ten representative draws from this checkpoint (all drug-like, QED ≥ 0.6 ∧ SA ≤ 4):

```
Cn1nc(CNCc2cc(Cl)ccc2Cl)n(C)c1=O                     QED=0.935  SA=2.46  MW=300
CCn1ncc(NC[C@@H]2CCCC[C@@H]2C)c(Br)c1=O              QED=0.923  SA=3.31  MW=327
Cc1cccc(Cl)c1NC(=O)CN1CCO[C@@H](C(F)F)CC1            QED=0.921  SA=2.86  MW=332
CCN1CCN(CC(=O)Nc2cc(C(F)(F)F)ccc2Cl)CC1              QED=0.908  SA=1.94  MW=349
COc1ccc(F)c(CNC(=O)C2=CCCCC2)c1                      QED=0.907  SA=2.16  MW=263
CN1CC[C@@H]2[C@@H](CCCN2C(=O)NCc2ccc(OC(F)F)cc2)C1   QED=0.905  SA=3.03  MW=353
CC[C@H](C(=O)NCc1c(F)cc(F)cc1F)N1CCCC1=O             QED=0.905  SA=2.93  MW=314
Cc1ccc(C2CCN(C(=O)NCc3cccc(F)c3F)CC2)c(=O)n1C        QED=0.897  SA=2.45  MW=375
CN(CC(=O)NCc1ccccc1)C(=O)C12CC3CC(CC(C3)C1)C2        QED=0.896  SA=3.37  MW=340
NCC1CCN(c2cc3c(cc2F)c(=O)c(C(=O)O)cn3C2CC2)C1        QED=0.884  SA=2.96  MW=345
```

## Architecture notes

- **AO-GPT**: decoder-only transformer with causal attention over a shuffled token order per batch (random permutation of middle tokens, BOS/EOS pinned at ends). Target position is conditioned via AdaLN so the model learns "any-order" decoding.
- **Hybrid vocab**: structural tokens + SMILES char tokens + atomic motif tokens share a single id space. At training time, atomic motif tokens may be expanded to their SMILES char form with a log-frequency-weighted probability (`HybridVocab.fallback_prob`) so the model is not brittle at char-level decoding.
- **Decoder**: the STAMP structural tokens delimit motifs; consecutive character tokens between structural tokens concatenate to a single motif SMILES, which the STAMP codec parses to a molecule via a stack machine with safety fallbacks.

## License

Apache-2.0.

## Citation

Cite the STAMP representation paper and this repository. (Placeholder: fill in with your actual citation info.)
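
As a supplement to the architecture notes, the any-order training scheme (a random permutation of the middle tokens with BOS and EOS pinned at the ends) can be illustrated with a short sketch. The helper names `sample_order` and `permute` and the placeholder token strings are hypothetical; the actual training code lives in the STAMP repo and may build its orders differently.

```python
import random


def sample_order(seq_len, rng=None):
    """Return a permutation of positions 0..seq_len-1.

    Position 0 (BOS) stays first and position seq_len-1 (EOS) stays
    last; the positions in between are shuffled uniformly at random.
    """
    if seq_len < 2:
        raise ValueError("a sequence needs at least BOS and EOS")
    rng = rng or random.Random()
    middle = list(range(1, seq_len - 1))
    rng.shuffle(middle)
    return [0] + middle + [seq_len - 1]


def permute(tokens, order):
    """Reorder tokens so that the model visits them in `order`."""
    return [tokens[i] for i in order]


if __name__ == "__main__":
    # Placeholder tokens, not real vocabulary entries.
    tokens = ["<bos>", "t1", "t2", "t3", "t4", "<eos>"]
    order = sample_order(len(tokens), random.Random(0))
    print(order)
    print(permute(tokens, order))
```

The sampling snippet in the Usage section passes the identity order (`torch.arange`), which corresponds to ordinary left-to-right decoding; a shuffled order like the one above is what the model is described as seeing during training.
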