---
license: mit
tags:
- protein-design
- protein-mpnn
- jax
- equinox
- biology
- structure-based-design
library_name: equinox
---
# PrxteinMPNN

A JAX/Equinox implementation of ProteinMPNN for inverse protein folding and sequence design.

## Model Description

PrxteinMPNN is a message-passing neural network that generates amino acid sequences given a protein backbone structure. This implementation uses JAX and Equinox for efficient computation and functional programming patterns.
**Key Features:**

- Fully modular Equinox implementation
- JAX-based for GPU acceleration and automatic differentiation
- Multiple pre-trained model variants (original and soluble)
- Multiple training checkpoints (2, 10, 20, and 30 epochs)

## Available Models

All models share the same architecture and differ only in how they were trained:

### Original Models

- `original_v_48_002` - Trained for 2 epochs
- `original_v_48_010` - Trained for 10 epochs
- `original_v_48_020` - Trained for 20 epochs (recommended)
- `original_v_48_030` - Trained for 30 epochs

### Soluble Models

- `soluble_v_48_002` - Trained for 2 epochs on soluble proteins
- `soluble_v_48_010` - Trained for 10 epochs on soluble proteins
- `soluble_v_48_020` - Trained for 20 epochs on soluble proteins (recommended)
- `soluble_v_48_030` - Trained for 30 epochs on soluble proteins
## Installation

```bash
pip install jax equinox huggingface_hub
```

## Usage

### Basic Usage

```python
import jax
import jax.numpy as jnp
import equinox as eqx
from huggingface_hub import hf_hub_download

# Download the pretrained weights from the Hugging Face Hub
model_path = hf_hub_download(
    repo_id="maraxen/prxteinmpnn",
    filename="eqx/original_v_48_020.eqx",
    repo_type="model",
)

# Create the model skeleton (must match the saved architecture)
from prxteinmpnn.eqx_new import PrxteinMPNN

key = jax.random.PRNGKey(0)
model = PrxteinMPNN(
    node_features=128,
    edge_features=128,
    hidden_features=512,
    num_encoder_layers=3,
    num_decoder_layers=3,
    vocab_size=21,
    k_neighbors=48,
    key=key,
)

# Load the pretrained weights into the skeleton
model = eqx.tree_deserialise_leaves(model_path, model)

# Use model for inference
# ... (see full documentation for inference examples)
```
### Using the High-Level API

```python
from prxteinmpnn.io.weights import load_model

# Automatically downloads and loads the model
model = load_model(
    model_version="v_48_020",
    model_weights="original",
)
```
## Model Architecture

**Hyperparameters:**

- Node features: 128
- Edge features: 128
- Hidden features: 512
- Encoder layers: 3
- Decoder layers: 3
- K-nearest neighbors: 48
- Vocabulary size: 21 (20 amino acids + 1 unknown)

**Architecture:**

- Message-passing encoder for structural features
- Autoregressive decoder for sequence generation
- Attention-based edge updates
- LayerNorm and residual connections
## Training Data

The models were trained on protein structures from the Protein Data Bank (PDB):

- **Original models:** Standard PDB training set
- **Soluble models:** Filtered for soluble, well-expressed proteins

## Performance

These models achieve state-of-the-art performance on:

- Native sequence recovery
- Structural compatibility (predicted structure vs. designed sequence)
- Expressibility and stability (for soluble models)
## Citation

If you use PrxteinMPNN in your research, please cite the original ProteinMPNN paper:

```bibtex
@article{dauparas2022robust,
  title={Robust deep learning--based protein sequence design using ProteinMPNN},
  author={Dauparas, Justas and Anishchenko, Ivan and Bennett, Nathaniel and Bai, Hua and Ragotte, Robert J and Milles, Lukas F and Wicky, Basile IM and Courbet, Alexis and de Haas, Rob J and Bethel, Neville and others},
  journal={Science},
  volume={378},
  number={6615},
  pages={49--56},
  year={2022},
  publisher={American Association for the Advancement of Science}
}
```
## License

MIT License - see the LICENSE file for details.

## Links

- **GitHub Repository:** [maraxen/PrxteinMPNN](https://github.com/maraxen/PrxteinMPNN)
- **Original ProteinMPNN:** [dauparas/ProteinMPNN](https://github.com/dauparas/ProteinMPNN)
- **Documentation:** [Full documentation](https://github.com/maraxen/PrxteinMPNN/tree/main/docs)
## Technical Details

### File Format

Models are saved in Equinox's `tree_serialise_leaves` format (`.eqx` files), which:

- Preserves PyTree structure
- Ensures bit-perfect reproducibility
- Is compatible with JAX's functional programming paradigm
- Supports efficient serialization/deserialization
### Computational Requirements

- **Memory:** ~30 MB per model
- **Inference:** CPU-compatible, GPU-accelerated
- **Batch processing:** Supported via `jax.vmap`
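
The batching pattern looks like this; `score_single` is a hypothetical stand-in for single-structure inference, not part of the PrxteinMPNN API:

```python
import jax
import jax.numpy as jnp

def score_single(coords):
    # Hypothetical stand-in for inference on one (num_residues, 3) backbone;
    # returns one value per residue.
    return jnp.sum(coords, axis=-1)

# vmap lifts the single-structure function over a leading batch axis.
score_batch = jax.vmap(score_single)

coords = jnp.zeros((8, 100, 3))  # batch of 8 structures, 100 residues each
scores = score_batch(coords)
print(scores.shape)  # (8, 100)
```

The same pattern applies to any per-structure function, provided all structures in the batch share the same shapes (e.g. after padding to a common length).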
## Updates

**Latest (v2.0):**

- Migrated to a unified Equinox architecture
- All models are now distributed in `.eqx` format
- Improved modularity and type safety
- Full JAX compatibility with JIT, vmap, and grad
---

For more information, examples, and tutorials, visit the [GitHub repository](https://github.com/maraxen/PrxteinMPNN).