
RetroGPT

RetroGPT is a transformer-based model for single-step retrosynthetic prediction. Given a target product molecule in SMILES format, the model predicts the required reactant molecules through sequence-to-sequence generation.

Model Details

Model Description

RetroGPT is a lightweight transformer architecture specifically designed for chemical reaction prediction. The model learns to "reverse" chemical reactions by predicting precursor reactants from product molecules.

  • Developed by: kssrikar4
  • Model type: Transformer-based Sequence-to-Sequence
  • Language: SMILES (Simplified Molecular Input Line Entry System)
  • License: MIT

Model Sources

  • Architecture: Custom transformer with RMSNorm, SwiGLU activation, and multi-head attention
  • Code: Training code available upon request
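The architecture line above names RMSNorm and SwiGLU but does not give their exact form. The plain-Python sketch below uses the standard definitions on a single vector. RMSNorm rescales activations by their root-mean-square and does not subtract the mean. SwiGLU gates one linear projection with the SiLU of another. The epsilon value, the absence of biases, and the toy identity weights are illustrative assumptions, not RetroGPT's actual parameters.

```python
import math


def rms_norm(x, gain, eps=1e-6):
    """RMSNorm: x / sqrt(mean(x^2) + eps), scaled elementwise by a learned gain."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for v, g in zip(x, gain)]


def silu(v):
    """SiLU (swish) activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def matvec(w, x):
    """Multiply a weight matrix (list of rows) by a vector."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def swiglu_ffn(x, w_gate, w_up, w_down):
    """SwiGLU feed-forward without biases: W_down (SiLU(W_gate x) * (W_up x))."""
    gate = matvec(w_gate, x)
    up = matvec(w_up, x)
    hidden = [silu(g) * u for g, u in zip(gate, up)]
    return matvec(w_down, hidden)


# After RMSNorm with unit gain, the output has root-mean-square close to 1.
normed = rms_norm([1.0, -2.0, 3.0, 0.5], gain=[1.0] * 4)

# With identity projections, SwiGLU reduces to SiLU(x) * x elementwise.
eye = [[1.0, 0.0], [0.0, 1.0]]
out = swiglu_ffn([1.0, 2.0], eye, eye, eye)
print(normed, out)
```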

Uses

Direct Use

The model is intended for retrosynthetic analysis in drug discovery and organic chemistry:

  • Predicting reactant molecules for a given target product
  • Assisting chemists in planning synthetic routes
  • Educational purposes for teaching retrosynthetic thinking
  • High-throughput virtual synthesis planning

Out-of-Scope Use

  • Multi-step synthesis planning (requires integration with search algorithms)
  • Predicting reaction conditions or yields
  • Stereochemistry-specific predictions without additional fine-tuning

How to Get Started with the Model

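```python
import sys

import torch
from IPython.display import display  # rendering assumes a Jupyter/IPython session
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from transformers import AutoModelForCausalLM


def get_reaction(product, model_id="kssrikar4/RetroGPT"):
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).eval()
    # The tokenizer class ships with the remote model code, so load it from the
    # module that defines the model class instead of through AutoTokenizer.
    tok_cls = getattr(sys.modules[model.__class__.__module__], "RetroGPTTokenizer")
    tok = tok_cls.from_pretrained(model_id)

    # Prompt format: <s>PRODUCT<sep>; the model continues with the reactant SMILES.
    ids = torch.tensor([tok.convert_tokens_to_ids(tok.tokenize(f"<s>{product}<sep>"))])

    out = model.generate(
        input_ids=ids,
        attention_mask=torch.ones_like(ids),
        max_length=256,
        num_beams=5,
        num_return_sequences=1,
    )

    # Decode only the tokens generated after the prompt. Splitting the full decode
    # on "<sep>" fails if skip_special_tokens=True has already removed it.
    reac = tok.decode(out[0, ids.shape[1]:].tolist(), skip_special_tokens=True).replace(" ", "")

    # Skip drawing when the prediction is empty or not parseable as SMILES.
    if not reac or Chem.MolFromSmiles(reac) is None:
        print(f"Predicted reactants are not valid SMILES: {reac!r}")
        return reac

    rxn = AllChem.ReactionFromSmarts(f"{reac}>>{product}", useSmiles=True)
    AllChem.Compute2DCoordsForReaction(rxn)
    display(Draw.ReactionToImage(rxn, subImgSize=(350, 350)))
    return reac


get_reaction("CC(=O)Oc1ccccc1C(=O)O")  # example target: aspirin
```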

Training

Training Data

The model was trained on the USPTO dataset (uspto.csv), a collection of chemical reactions extracted from US patents. The dataset includes:

  • Single-product, multi-reactant reactions
  • Canonicalized SMILES representations
  • Reactions spanning multiple decades (1976-present)

Dataset Statistics:

  • Training split: ~99% of data
  • Validation split: ~1% of data
  • Maximum sequence length: 256 tokens
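The card gives the split ratio and length cap but not the preprocessing code. The sketch below is minimal and rests on assumptions the source does not confirm. Each reaction becomes one training string in the same `<s>PRODUCT<sep>REACTANTS` layout that the inference example uses as its prompt, ending in a hypothetical `</s>` token. Length is measured with a caller-supplied tokenizer, and character splitting stands in for it here. The 99/1 split is a seeded random shuffle.

```python
import random


def format_example(product, reactants):
    # Mirrors the inference prompt "<s>{product}<sep>"; the trailing "</s>"
    # end-of-sequence marker is an assumption.
    return f"<s>{product}<sep>{reactants}</s>"


def prepare_splits(reactions, tokenize=list, max_len=256, val_frac=0.01, seed=0):
    """Format (product, reactants) pairs, drop over-length ones, split train/val.

    `tokenize` defaults to character-level splitting purely for illustration;
    in practice it would be the model's own tokenizer.
    """
    examples = [format_example(p, r) for p, r in reactions]
    examples = [e for e in examples if len(tokenize(e)) <= max_len]
    rng = random.Random(seed)
    rng.shuffle(examples)
    n_val = max(1, round(len(examples) * val_frac))
    return examples[n_val:], examples[:n_val]


# Toy data: one aspirin synthesis record (acetic anhydride + salicylic acid), repeated.
rxns = [("CC(=O)Oc1ccccc1C(=O)O", "CC(=O)OC(C)=O.O=C(O)c1ccccc1O")] * 200
train, val = prepare_splits(rxns)
print(len(train), len(val))
```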

Training Procedure

Architecture:

  • Transformer layers: 6
  • Hidden dimension (d_model): 512
  • Attention heads: 8
  • Vocabulary size: ~150 chemical tokens
  • Positional encoding: Learned embeddings

Hyperparameters:

  • Batch size: 64 (effective batch: 128 with gradient accumulation)
  • Learning rate: 3e-4 with AdamW optimizer
  • Weight decay: 0.01
  • Dropout: 0.1
  • Training epochs: 80
  • Gradient accumulation steps: 2
  • Learning rate scheduler: Cosine annealing
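These settings combine as follows. With gradient accumulation, gradients from 2 micro-batches of 64 are accumulated before each optimizer step, which gives 64 × 2 = 128 examples per update. The micro-batch loss is typically divided by 2 so that the update matches a single 128-example batch. For the cosine schedule, the sketch below uses the standard annealing formula. It assumes no warmup, a minimum learning rate of 0, and one scheduler step per epoch over the 80 epochs. The card does not specify any of these details.

```python
import math

BATCH_SIZE = 64
GRAD_ACCUM_STEPS = 2
EFFECTIVE_BATCH = BATCH_SIZE * GRAD_ACCUM_STEPS  # examples per optimizer step


def cosine_lr(step, total_steps, lr_max=3e-4, lr_min=0.0):
    """Cosine annealing from lr_max at step 0 down to lr_min at total_steps (no warmup)."""
    progress = min(step, total_steps) / total_steps
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))


for epoch in (0, 20, 40, 60, 80):
    print(epoch, f"{cosine_lr(epoch, total_steps=80):.2e}")
```

When stepped once per epoch, PyTorch's `torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=80, eta_min=0.0)` follows the same closed form.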

Optimization:

  • Mixed precision training (AMP)
  • Gradient clipping
  • Distributed training support (DDP)
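The card does not state which clipping rule or threshold is used. A common choice is global-norm clipping, which is what PyTorch's `torch.nn.utils.clip_grad_norm_` implements. If the L2 norm taken over all gradients exceeds `max_norm`, every gradient is scaled by `max_norm / total_norm`. The plain-Python sketch below shows that rule, with `max_norm=1.0` as an assumed value.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale all gradient vectors so their combined L2 norm is at most max_norm.

    `grads` is a list of per-parameter gradients, each a flat list of floats.
    Returns the clipped gradients and the total norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    coef = max_norm / (total_norm + eps)
    if coef >= 1.0:
        return [list(grad) for grad in grads], total_norm  # already within bound
    return [[g * coef for g in grad] for grad in grads], total_norm


clipped, norm = clip_by_global_norm([[3.0, 0.0], [0.0, 4.0]], max_norm=1.0)
print(norm, clipped)
```

When clipping is combined with AMP, PyTorch's documented pattern calls `scaler.unscale_(optimizer)` on the `GradScaler` before clipping. This ensures the threshold applies to the true gradient magnitudes rather than the loss-scaled ones.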

Training Hyperparameters

| Hyperparameter | Value |
|---|---|
| Transformer Layers | 6 |
| Hidden Size | 512 |
| Attention Heads | 8 |
| Max Sequence Length | 256 |
| Batch Size | 64 |
| Learning Rate | 3e-4 |
| Weight Decay | 0.01 |
| Dropout | 0.1 |
| Epochs | 80 |
| Optimizer | AdamW |

Evaluation

Evaluation Metrics

The model is evaluated using Top-k Accuracy based on exact canonical SMILES matching:

  • Top-1 Accuracy: The model's first (highest-scoring) prediction exactly matches the reference reactant set
  • Top-3 Accuracy: The reference reactant set appears among the top 3 beam search candidates
  • Top-5 Accuracy: The reference reactant set appears among the top 5 beam search candidates
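The metric can be sketched as follows. The sketch treats a reactant mixture as an unordered set by splitting on "." and sorting the components. That is an assumption about how "exact canonical SMILES matching" handles reactant order. The `canonicalize` argument is a stand-in that defaults to the identity. With RDKit installed, one could pass `lambda s: Chem.MolToSmiles(Chem.MolFromSmiles(s))` instead.

```python
def normalize(smiles, canonicalize=lambda s: s):
    """Canonicalize each dot-separated component and sort, so reactant order is ignored."""
    return ".".join(sorted(canonicalize(part) for part in smiles.split(".")))


def top_k_accuracy(predictions, targets, k, canonicalize=lambda s: s):
    """Fraction of targets whose normalized SMILES is among the first k predictions.

    predictions[i] is the ranked candidate list (best first) for targets[i].
    """
    hits = 0
    for candidates, target in zip(predictions, targets):
        wanted = normalize(target, canonicalize)
        if any(normalize(c, canonicalize) == wanted for c in candidates[:k]):
            hits += 1
    return hits / len(targets)


preds = [["CCO.CC(=O)O", "CCO"], ["CCN", "CCCl", "CCBr"]]
gold = ["CC(=O)O.CCO", "CCBr"]
for k in (1, 3):
    print(k, top_k_accuracy(preds, gold, k))
```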

Test Results

[Figure: Top-k accuracy]

Performance on Validation Set:

  • Top-1 Accuracy: 59.2%
  • Top-3 Accuracy: 73.0%
  • Top-5 Accuracy: 76.4%

Temporal Performance

[Figure: Accuracy by patent year]

Accuracy stays broadly consistent across patent years, suggesting that the model generalizes to reactions from different time periods.
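Per-year accuracy is the top-1 metric grouped by patent year. The sketch below is minimal and assumes each validation record carries its patent year and a boolean top-1 hit; that record layout is hypothetical. The gap between the best and worst year gives one simple number for how consistent performance is.

```python
from collections import defaultdict


def accuracy_by_year(records):
    """records: iterable of (year, is_top1_hit) pairs -> {year: accuracy}, sorted by year."""
    totals = defaultdict(int)
    hits = defaultdict(int)
    for year, hit in records:
        totals[year] += 1
        hits[year] += int(hit)
    return {year: hits[year] / totals[year] for year in sorted(totals)}


records = [(1995, True), (1995, False), (2010, True), (2010, True)]
acc = accuracy_by_year(records)
spread = max(acc.values()) - min(acc.values())  # best-year minus worst-year accuracy
print(acc, spread)
```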

Token-Level Analysis

[Figure: Token confusion matrix]

The confusion matrix reveals:

  • Strong diagonal dominance: The model correctly predicts chemical tokens with high accuracy
  • Aromatic vs. Aliphatic: Minor confusion between aromatic (c) and aliphatic (C) carbons
  • Ring closures: Some errors on numeric ring-closure markers (1, 2, etc.)
  • Branching syntax: Occasional misplacement of parentheses for molecular branches
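The error categories above are stated at the level of SMILES tokens. The sketch below shows one way to produce such a confusion count, and it rests on two assumptions. First, SMILES are split with a common regex-based, atom-level tokenizer. RetroGPT ships its own `RetroGPTTokenizer`, and its exact rules may differ. Second, predicted and reference token sequences are compared position by position, which only yields meaningful pairs when the lengths match.

```python
import re
from collections import Counter

# Atom-level SMILES regex widely used for reaction models: bracket atoms,
# two-letter halogens, aromatic atoms, bonds, branches, and ring-closure digits.
SMILES_TOKEN = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|/|:|~|@|\?|>|\*|\$|%[0-9]{2}|[0-9])"
)


def tokenize(smiles):
    tokens = SMILES_TOKEN.findall(smiles)
    # findall silently skips characters the pattern does not cover; catch that here.
    if "".join(tokens) != smiles:
        raise ValueError(f"untokenizable SMILES: {smiles!r}")
    return tokens


def token_confusions(pairs):
    """Count (reference_token, predicted_token) pairs over equal-length token sequences."""
    counts = Counter()
    for ref, pred in pairs:
        ref_toks, pred_toks = tokenize(ref), tokenize(pred)
        if len(ref_toks) != len(pred_toks):
            continue  # position-wise alignment is only meaningful for equal lengths
        counts.update(zip(ref_toks, pred_toks))
    return counts


pairs = [("c1ccccc1O", "c1ccccc1O"), ("CCO", "CcO")]
errors = {k: v for k, v in token_confusions(pairs).items() if k[0] != k[1]}
print(errors)  # off-diagonal entries only, e.g. aliphatic C predicted as aromatic c
```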

Qualitative Examples

[Figure: Reaction examples]

Disclaimer: This model is intended for research and educational purposes. Always verify predictions with chemical expertise and experimental validation before laboratory use.
