
RetroGPT-lite

RetroGPT-lite is a transformer-based model for single-step retrosynthetic prediction. Given a target product molecule in SMILES format, the model predicts the required reactant molecules through sequence-to-sequence generation.

Model Details

Model Description

RetroGPT-lite is a lightweight transformer architecture specifically designed for chemical reaction prediction. The model learns to "reverse" chemical reactions by predicting precursor reactants from product molecules.

  • Developed by: kssrikar4
  • Model type: Transformer-based Sequence-to-Sequence
  • Language: SMILES (Simplified Molecular Input Line Entry System)
  • License: MIT

Model Sources

  • Architecture: Custom transformer with RMSNorm, SwiGLU activation, and multi-head attention
  • Code: Training code available upon request
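Because the training code is not public, the two named components can only be illustrated generically. The following is a minimal PyTorch sketch of RMSNorm and a SwiGLU feed-forward block as they are commonly defined; the class names, the `eps` value, and the feed-forward width of 2048 are assumptions for illustration, not details taken from RetroGPT-lite.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """Rescale features by their root mean square (no mean subtraction)."""

    def __init__(self, dim: int, eps: float = 1e-6):  # eps is an assumed default
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learned per-channel gain

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x / rms)


class SwiGLU(nn.Module):
    """Gated feed-forward block: down(SiLU(gate(x)) * up(x))."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.gate = nn.Linear(dim, hidden_dim, bias=False)
        self.up = nn.Linear(dim, hidden_dim, bias=False)
        self.down = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(F.silu(self.gate(x)) * self.up(x))


# Toy input: batch of 2 sequences, 16 tokens each, d_model = 512.
x = torch.randn(2, 16, 512)
norm = RMSNorm(512)
ffn = SwiGLU(512, hidden_dim=2048)  # hidden_dim is hypothetical
y = ffn(norm(x))
```

Both modules preserve the input shape, so they can replace LayerNorm and a standard MLP inside a transformer layer without other changes.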

Uses

Direct Use

The model is intended for retrosynthetic analysis in drug discovery and organic chemistry:

  • Predicting reactant molecules for a given target product
  • Assisting chemists in planning synthetic routes
  • Educational purposes for teaching retrosynthetic thinking
  • High-throughput virtual synthesis planning

Out-of-Scope Use

  • Multi-step synthesis planning (requires integration with search algorithms)
  • Predicting reaction conditions or yields
  • Stereochemistry-specific predictions without additional fine-tuning

How to Get Started with the Model

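import sys

import torch
from IPython.display import display
from rdkit.Chem import AllChem, Draw
from transformers import AutoModelForCausalLM


def get_reaction(product, model_id="kssrikar4/RetroGPT-lite"):
    """Predict reactants for a product SMILES and draw the reaction in a notebook."""
    # trust_remote_code=True executes the custom model code shipped with the repository.
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).eval()
    # The custom tokenizer class is defined in the same remote module as the model class.
    tok = getattr(sys.modules[model.__class__.__module__], "RetroGPTTokenizer").from_pretrained(model_id)
    # Prompt format: <s>{product}<sep>; the model continues with the reactants.
    ids = torch.tensor([tok.convert_tokens_to_ids(tok.tokenize(f"<s>{product}<sep>"))])

    out = model.generate(
        input_ids=ids,
        attention_mask=torch.ones_like(ids),
        max_length=256,
        num_beams=5,
        num_return_sequences=1,
    )

    # Keep only the text after the last <sep> and drop any spaces between tokens.
    reac = tok.decode(out[0].tolist(), skip_special_tokens=True).split("<sep>")[-1].replace(" ", "")
    rxn = AllChem.ReactionFromSmarts(f"{reac}>>{product}", useSmiles=True)

    if rxn:
        AllChem.Compute2DCoordsForReaction(rxn)
        display(Draw.ReactionToImage(rxn, subImgSize=(350, 350)))


# Replace the placeholder with the SMILES string of your target product.
get_reaction("your smiles")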

Training

Training Data

The model was trained on the USPTO dataset (uspto.csv), a collection of chemical reactions extracted from US patents. The dataset includes:

  • Single-product, multi-reactant reactions
  • Canonicalized SMILES representations
  • Reactions spanning multiple decades (1976-present)

Dataset Statistics:

  • Training split: ~99% of data
  • Validation split: ~1% of data
  • Maximum sequence length: 256 tokens
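The card gives the split ratio but not how the split was drawn. As a rough sketch, a seeded random 99/1 split could look like the following; the function name `split_reactions`, the seed, and the assumption that the split is random rather than temporal are all illustrative.

```python
import random


def split_reactions(rows, val_fraction=0.01, seed=0):
    """Shuffle rows reproducibly and hold out val_fraction of them for validation."""
    rows = list(rows)
    rng = random.Random(seed)  # fixed seed (assumed) so the split is repeatable
    rng.shuffle(rows)
    n_val = max(1, round(len(rows) * val_fraction))
    return rows[n_val:], rows[:n_val]  # (train, validation)


# Toy stand-in for the reaction rows of uspto.csv.
toy_rows = [f"reaction_{i}" for i in range(1000)]
train_rows, val_rows = split_reactions(toy_rows)
```

With 1000 toy rows this holds out 10 for validation and keeps 990 for training.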

Training Procedure

Architecture:

  • Transformer layers: 6
  • Hidden dimension (d_model): 512
  • Attention heads: 8
  • Vocabulary size: ~150 chemical tokens
  • Positional encoding: Learned embeddings
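A vocabulary of roughly 150 tokens is consistent with atom-level SMILES tokenization, although the card does not describe how `RetroGPTTokenizer` splits its input. The sketch below uses a regular expression that is widely used for SMILES in reaction-prediction work; treat it as an assumption about the general approach, not as the model's actual tokenizer.

```python
import re

# Atom-level SMILES pattern: bracket atoms, two-letter halogens, organic-subset
# atoms, bonds, branches, and ring-closure digits each become one token.
SMILES_TOKEN = re.compile(
    r"(\[[^\]]+\]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|/|:|~|@|\?|>|\*|\$|%[0-9]{2}|[0-9])"
)


def tokenize_smiles(smiles: str) -> list[str]:
    """Split a SMILES string into tokens; fail loudly if any character is skipped."""
    tokens = SMILES_TOKEN.findall(smiles)
    if "".join(tokens) != smiles:
        raise ValueError(f"Untokenizable characters in SMILES: {smiles!r}")
    return tokens


aspirin_tokens = tokenize_smiles("CC(=O)Oc1ccccc1C(=O)O")
salt_tokens = tokenize_smiles("[Na+].[Cl-]")
```

The round-trip check (`"".join(tokens) == smiles`) guards against silently dropping characters the pattern does not cover.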

Hyperparameters:

  • Batch size: 64 (effective batch: 128 with gradient accumulation)
  • Learning rate: 3e-4 with AdamW optimizer
  • Weight decay: 0.01
  • Dropout: 0.1
  • Training epochs: 80
  • Gradient accumulation steps: 2
  • Learning rate scheduler: Cosine annealing
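To make the numbers above concrete, here is a small sketch of the effective batch size and a cosine annealing curve. The minimum learning rate of 0, the per-epoch granularity, and the absence of warmup are assumptions; the card states only the base rate and the scheduler type.

```python
import math

BASE_LR = 3e-4
EPOCHS = 80
BATCH_SIZE = 64
ACCUM_STEPS = 2

# Gradients from ACCUM_STEPS mini-batches are summed before each optimizer step.
effective_batch = BATCH_SIZE * ACCUM_STEPS  # 128


def cosine_lr(step, total_steps, base_lr=BASE_LR, min_lr=0.0):
    """Cosine annealing from base_lr down to min_lr (min_lr=0.0 is assumed)."""
    progress = min(step, total_steps) / total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Learning rate at the start of each epoch under these assumptions.
schedule = [cosine_lr(epoch, EPOCHS) for epoch in range(EPOCHS + 1)]
```

Under these assumptions the rate starts at 3e-4, passes through 1.5e-4 at epoch 40, and reaches 0 at epoch 80.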

Optimization:

  • Mixed precision training (AMP)
  • Gradient clipping
  • Distributed training support (DDP)
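The listed techniques typically combine in a single training loop. The sketch below shows one plausible arrangement of mixed precision, gradient accumulation, and gradient clipping in PyTorch; the function name `train_steps`, the clipping threshold `max_norm=1.0`, and the toy linear model are assumptions, and DDP wrapping is omitted for brevity.

```python
import torch
import torch.nn as nn


def train_steps(model, batches, optimizer, accum_steps=2, max_norm=1.0):
    """Run batches with AMP (on CUDA only), accumulation, and clipping.

    Returns the number of optimizer updates performed.
    """
    device = next(model.parameters()).device
    use_amp = device.type == "cuda"  # AMP is disabled on CPU in this sketch
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    loss_fn = nn.CrossEntropyLoss()
    optimizer.zero_grad(set_to_none=True)
    updates = 0
    for i, (inputs, targets) in enumerate(batches, start=1):
        inputs, targets = inputs.to(device), targets.to(device)
        with torch.autocast(device_type=device.type, enabled=use_amp):
            # Divide so the accumulated gradient is an average over mini-batches.
            loss = loss_fn(model(inputs), targets) / accum_steps
        scaler.scale(loss).backward()
        if i % accum_steps == 0:
            scaler.unscale_(optimizer)  # clip on unscaled gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            updates += 1
    return updates


# Toy stand-in model and data; hyperparameters mirror the values listed above.
toy_model = nn.Linear(8, 4)
toy_optimizer = torch.optim.AdamW(toy_model.parameters(), lr=3e-4, weight_decay=0.01)
toy_batches = [(torch.randn(64, 8), torch.randint(0, 4, (64,))) for _ in range(4)]
n_updates = train_steps(toy_model, toy_batches, toy_optimizer)
```

Four mini-batches of 64 with two accumulation steps yield two optimizer updates, each over an effective batch of 128.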

Training Hyperparameters

| Hyperparameter | Value |
|---|---|
| Transformer Layers | 6 |
| Hidden Size | 512 |
| Attention Heads | 8 |
| Max Sequence Length | 256 |
| Batch Size | 64 |
| Learning Rate | 3e-4 |
| Weight Decay | 0.01 |
| Dropout | 0.1 |
| Epochs | 80 |
| Optimizer | AdamW |

Evaluation

Evaluation Metrics

The model is evaluated using Top-k Accuracy based on exact canonical SMILES matching:

  • Top-1 Accuracy: The exact correct reactant mixture is the model's first prediction
  • Top-3 Accuracy: The correct answer appears in the top 3 beam search candidates
  • Top-5 Accuracy: The correct answer appears in the top 5 candidates
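The metric described above can be sketched as follows. The helper names are hypothetical, and the choice to compare reactant mixtures independent of component order is an assumption about what "exact canonical SMILES matching" means here. By default the sketch canonicalizes with RDKit; a custom `canonicalize` callable can be passed instead.

```python
def normalize_reactants(smiles, canonicalize=None):
    """Canonicalize each dot-separated component and sort them; None if invalid."""
    if canonicalize is None:
        from rdkit import Chem  # imported lazily; only needed for the default

        def canonicalize(s):
            mol = Chem.MolFromSmiles(s)
            return Chem.MolToSmiles(mol) if mol is not None else None

    parts = [canonicalize(p) for p in smiles.split(".")]
    if any(p is None for p in parts):
        return None
    return ".".join(sorted(parts))


def top_k_accuracy(predictions, targets, k, canonicalize=None):
    """predictions[i] is the ranked candidate list for targets[i]."""
    hits = 0
    for candidates, target in zip(predictions, targets):
        gold = normalize_reactants(target, canonicalize)
        top = [normalize_reactants(c, canonicalize) for c in candidates[:k]]
        # Invalid (None) predictions never count as a match.
        hits += any(c is not None and c == gold for c in top)
    return hits / len(targets)


# Toy example with an identity canonicalizer, so RDKit is not required here.
toy_preds = [["CC.O", "CCO"], ["N", "C"]]
toy_targets = ["O.CC", "C"]
top1 = top_k_accuracy(toy_preds, toy_targets, k=1, canonicalize=lambda s: s)  # 0.5
top2 = top_k_accuracy(toy_preds, toy_targets, k=2, canonicalize=lambda s: s)  # 1.0
```

In the toy data the first target is matched at rank 1 despite the different component order, and the second is matched only at rank 2.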

Test Results

[Figure: Top-k Accuracy]

Performance on Validation Set:

  • Top-1 Accuracy: 59.2%
  • Top-3 Accuracy: 73.0%
  • Top-5 Accuracy: 76.4%

Temporal Performance

[Figure: Accuracy by Patent Year]

The model maintains consistent performance across different patent years, demonstrating robust generalization to reactions from different time periods.
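A per-year breakdown of this kind can be computed by grouping prediction outcomes by patent year. The sketch below assumes each evaluated reaction is available as a `(year, is_correct)` pair; the function name and the toy records are illustrative and are not results from the model.

```python
from collections import defaultdict


def accuracy_by_year(records):
    """records: iterable of (patent_year, is_correct) pairs."""
    totals = defaultdict(int)
    hits = defaultdict(int)
    for year, is_correct in records:
        totals[year] += 1
        hits[year] += int(is_correct)
    return {year: hits[year] / totals[year] for year in sorted(totals)}


# Made-up toy records, used only to show the shape of the output.
toy_records = [(1990, True), (1990, False), (2005, True)]
per_year = accuracy_by_year(toy_records)
```

Reporting the number of reactions per year alongside each accuracy helps distinguish genuine drift from noise in sparsely populated years.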

Token-Level Analysis

[Figure: Token Confusion Matrix]

The confusion matrix reveals:

  • Strong diagonal dominance: The model correctly predicts chemical tokens with high accuracy
  • Aromatic vs. Aliphatic: Minor confusion between aromatic (c) and aliphatic (C) carbons
  • Ring closures: Some challenges with numeric ring closure markers (1, 2, etc.)
  • Branching syntax: Occasional misplacement of parentheses for molecular branches
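The card does not say how the confusion matrix was built. One common approach, assumed here, is to compare predicted and reference tokens position by position (for example under teacher forcing, where both sequences have the same length). The function name and toy sequences are hypothetical.

```python
from collections import Counter


def token_confusion(pairs):
    """Count (true_token, predicted_token) pairs over position-aligned sequences."""
    counts = Counter()
    for predicted_tokens, true_tokens in pairs:
        for predicted, true in zip(predicted_tokens, true_tokens):
            counts[(true, predicted)] += 1
    return counts


# Toy example: one aromatic carbon (c) is predicted as aliphatic (C).
toy_pairs = [(["C", "c", "1"], ["c", "c", "1"])]
confusion = token_confusion(toy_pairs)
off_diagonal = {pair: n for pair, n in confusion.items() if pair[0] != pair[1]}
```

Off-diagonal entries such as `("c", "C")` correspond to the aromatic versus aliphatic confusion noted above.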

Qualitative Examples

[Figure: Reaction Examples]

Disclaimer: This model is intended for research and educational purposes. Always verify predictions with chemical expertise and experimental validation before laboratory use.
