---
license: mit
tags:
- protein-design
- protein-mpnn
- jax
- equinox
- biology
- structure-based-design
library_name: equinox
---

# PrxteinMPNN

A JAX/Equinox implementation of ProteinMPNN for inverse protein folding and sequence design.

## Model Description

PrxteinMPNN is a message-passing neural network that generates amino acid sequences given a protein backbone structure. This implementation uses JAX and Equinox for efficient computation and functional programming patterns.

**Key Features:**
- Fully modular Equinox implementation
- JAX-based for GPU acceleration and automatic differentiation
- Multiple pre-trained model variants (original and soluble)
- Multiple training noise levels (002, 010, 020, 030)

## Available Models

All models share the same architecture and differ only in how they were trained. Names follow the upstream ProteinMPNN convention `v_<neighbors>_<noise>`: `48` is the number of nearest neighbors, and the final field is the backbone noise level used during training (`002` = 0.02 Å through `030` = 0.30 Å), not an epoch count.

### Original Models
- `original_v_48_002` - Trained with 0.02 Å backbone noise
- `original_v_48_010` - Trained with 0.10 Å backbone noise
- `original_v_48_020` - Trained with 0.20 Å backbone noise (recommended)
- `original_v_48_030` - Trained with 0.30 Å backbone noise

### Soluble Models
- `soluble_v_48_002` - Trained with 0.02 Å backbone noise on soluble proteins
- `soluble_v_48_010` - Trained with 0.10 Å backbone noise on soluble proteins
- `soluble_v_48_020` - Trained with 0.20 Å backbone noise on soluble proteins (recommended)
- `soluble_v_48_030` - Trained with 0.30 Å backbone noise on soluble proteins

## Installation

```bash
pip install jax equinox huggingface_hub
```

## Usage

### Basic Usage

```python
import equinox as eqx
import jax
from huggingface_hub import hf_hub_download

from prxteinmpnn.eqx_new import PrxteinMPNN

# Download the serialized weights from the Hugging Face Hub
model_path = hf_hub_download(
    repo_id="maraxen/prxteinmpnn",
    filename="eqx/original_v_48_020.eqx",
    repo_type="model",
)

# Build a model skeleton; the hyperparameters must match the saved architecture
key = jax.random.PRNGKey(0)
model = PrxteinMPNN(
    node_features=128,
    edge_features=128,
    hidden_features=512,
    num_encoder_layers=3,
    num_decoder_layers=3,
    vocab_size=21,
    k_neighbors=48,
    key=key,
)

# Replace the randomly initialized leaves with the saved weights
model = eqx.tree_deserialise_leaves(model_path, model)

# Use model for inference
# ... (see full documentation for inference examples)
```

### Using the High-Level API

```python
from prxteinmpnn.io.weights import load_model

# Automatically downloads and loads the model
model = load_model(
    model_version="v_48_020",
    model_weights="original"
)
```

## Model Architecture

**Hyperparameters:**
- Node features: 128
- Edge features: 128
- Hidden features: 512
- Encoder layers: 3
- Decoder layers: 3
- K-nearest neighbors: 48
- Vocabulary size: 21 (20 amino acids + 1 unknown)

**Architecture:**
- Message-passing encoder for structural features
- Autoregressive decoder for sequence generation
- Attention-based edge updates
- LayerNorm and residual connections

## Training Data

The models were trained on protein structures from the Protein Data Bank (PDB):

- **Original models:** Standard PDB training set
- **Soluble models:** Filtered for soluble, well-expressed proteins

## Performance

These models achieve state-of-the-art performance on:
- Native sequence recovery
- Structural compatibility (how closely the structure predicted for a designed sequence matches the target backbone)
- Expressibility and stability (for soluble models)

## Citation

If you use PrxteinMPNN in your research, please cite the original ProteinMPNN paper:

```bibtex
@article{dauparas2022robust,
  title={Robust deep learning--based protein sequence design using ProteinMPNN},
  author={Dauparas, Justas and Anishchenko, Ivan and Bennett, Nathaniel and Bai, Hua and Ragotte, Robert J and Milles, Lukas F and Wicky, Basile IM and Courbet, Alexis and de Haas, Rob J and Bethel, Neville and others},
  journal={Science},
  volume={378},
  number={6615},
  pages={49--56},
  year={2022},
  publisher={American Association for the Advancement of Science}
}
```

## License

MIT License - See LICENSE file for details.

## Links

- **GitHub Repository:** [maraxen/PrxteinMPNN](https://github.com/maraxen/PrxteinMPNN)
- **Original ProteinMPNN:** [dauparas/ProteinMPNN](https://github.com/dauparas/ProteinMPNN)
- **Documentation:** [Full documentation](https://github.com/maraxen/PrxteinMPNN/tree/main/docs)

## Technical Details

### File Format

Models are saved with Equinox's `tree_serialise_leaves` (`.eqx` files). This format:
- Stores only the PyTree leaves (the weights); the structure comes from a model skeleton built in code, which must match the saved architecture
- Restores the saved weights bit-for-bit
- Is compatible with JAX's functional programming paradigm
- Supports efficient serialization/deserialization

### Computational Requirements

- **Memory:** ~30 MB per model
- **Inference:** Runs on CPU; GPU acceleration is available through JAX
- **Batch processing:** Supported via `jax.vmap`

## Updates

**Latest (v2.0):**
- Migrated to unified Equinox architecture
- All models now in `.eqx` format
- Improved modularity and type safety
- Full JAX compatibility with JIT, vmap, and grad

---

For more information, examples, and tutorials, visit the [GitHub repository](https://github.com/maraxen/PrxteinMPNN).