---
license: mit
tags:
- protein-design
- protein-mpnn
- jax
- equinox
- biology
- structure-based-design
library_name: equinox
---
# PrxteinMPNN
A JAX/Equinox implementation of ProteinMPNN for inverse protein folding and sequence design.
## Model Description
PrxteinMPNN is a message-passing neural network that generates amino acid sequences given a protein backbone structure. This implementation uses JAX and Equinox for efficient computation and functional programming patterns.
**Key Features:**
- Fully modular Equinox implementation
- JAX-based for GPU acceleration and automatic differentiation
- Multiple pre-trained model variants (original and soluble)
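- Multiple backbone-noise levels (002, 010, 020, 030)
## Available Models
All models share the same architecture and differ only in how they were trained. The names follow the upstream ProteinMPNN convention: `48` is the number of nearest neighbors, and the trailing digits give the standard deviation of the Gaussian noise added to backbone coordinates during training (for example, `020` is 0.20 Å). They are not epoch counts.
### Original Models
- `original_v_48_002` - Trained with 0.02 Å backbone noise
- `original_v_48_010` - Trained with 0.10 Å backbone noise
- `original_v_48_020` - Trained with 0.20 Å backbone noise (recommended)
- `original_v_48_030` - Trained with 0.30 Å backbone noise
### Soluble Models
- `soluble_v_48_002` - Trained with 0.02 Å backbone noise on soluble proteins
- `soluble_v_48_010` - Trained with 0.10 Å backbone noise on soluble proteins
- `soluble_v_48_020` - Trained with 0.20 Å backbone noise on soluble proteins (recommended)
- `soluble_v_48_030` - Trained with 0.30 Å backbone noise on soluble proteins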
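Each model name combines a weights family with a version tag. The following is a minimal sketch of a helper that maps the two onto a file name in this repository. It assumes every checkpoint follows the `eqx/<family>_<version>.eqx` pattern of the `eqx/original_v_48_020.eqx` file used in the Basic Usage example; `checkpoint_filename` is a hypothetical helper, not part of the package.

```python
FAMILIES = ("original", "soluble")
VERSIONS = ("v_48_002", "v_48_010", "v_48_020", "v_48_030")


def checkpoint_filename(family: str, version: str = "v_48_020") -> str:
    """Return the assumed in-repo path for a checkpoint (hypothetical helper)."""
    if family not in FAMILIES:
        raise ValueError(f"unknown family {family!r}; expected one of {FAMILIES}")
    if version not in VERSIONS:
        raise ValueError(f"unknown version {version!r}; expected one of {VERSIONS}")
    # Assumption: all files share the naming pattern of the documented example.
    return f"eqx/{family}_{version}.eqx"


# Enumerate the eight checkpoints listed above.
all_files = [checkpoint_filename(f, v) for f in FAMILIES for v in VERSIONS]
```

The returned string can be passed as the `filename` argument of `hf_hub_download`.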
## Installation
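```bash
pip install jax equinox huggingface_hub
```
The usage examples below also import the `prxteinmpnn` package, which is developed in the [GitHub repository](https://github.com/maraxen/PrxteinMPNN) and must be installed from there as well.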
## Usage
### Basic Usage
```python
import equinox as eqx
import jax
import jax.numpy as jnp
from huggingface_hub import hf_hub_download

from prxteinmpnn.eqx_new import PrxteinMPNN

# Download the serialized weights from the Hugging Face Hub
model_path = hf_hub_download(
    repo_id="maraxen/prxteinmpnn",
    filename="eqx/original_v_48_020.eqx",
    repo_type="model",
)

# Build a template model; its hyperparameters must match the saved architecture
key = jax.random.PRNGKey(0)
model = PrxteinMPNN(
    node_features=128,
    edge_features=128,
    hidden_features=512,
    num_encoder_layers=3,
    num_decoder_layers=3,
    vocab_size=21,
    k_neighbors=48,
    key=key,
)

# Replace the template's randomly initialized weights with the saved ones
model = eqx.tree_deserialise_leaves(model_path, model)

# Use model for inference
# ... (see full documentation for inference examples)
```
### Using the High-Level API
```python
from prxteinmpnn.io.weights import load_model

# Automatically downloads and loads the model
model = load_model(
    model_version="v_48_020",
    model_weights="original",
)
```
## Model Architecture
**Hyperparameters:**
- Node features: 128
- Edge features: 128
- Hidden features: 512
- Encoder layers: 3
- Decoder layers: 3
- K-nearest neighbors: 48
- Vocabulary size: 21 (20 amino acids + 1 unknown)
**Architecture:**
- Message-passing encoder for structural features
- Autoregressive decoder for sequence generation
- Attention-based edge updates
- LayerNorm and residual connections
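To illustrate what the K-nearest-neighbors hyperparameter refers to, here is a minimal sketch of building a neighbor graph from one coordinate per residue. It is an illustrative stand-in rather than the package's own featurization: `knn_graph` is a hypothetical function, and counting each residue as its own nearest neighbor is an assumption of this sketch.

```python
import jax
import jax.numpy as jnp


def knn_graph(coords: jnp.ndarray, k: int):
    """Return (indices, distances), each of shape (L, k), for coords of shape (L, 3)."""
    diff = coords[:, None, :] - coords[None, :, :]
    # The small epsilon keeps the square root differentiable at zero distance.
    dist = jnp.sqrt(jnp.sum(diff**2, axis=-1) + 1e-8)
    k = min(k, coords.shape[0])  # short chains have fewer than k residues
    # top_k picks the largest values, so negate distances to get the smallest.
    neg_dist, idx = jax.lax.top_k(-dist, k)
    return idx, -neg_dist


# Four residues on a line; each row lists the residue itself, then its closest neighbor.
coords = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0], [7.0, 0.0, 0.0]])
idx, dist = knn_graph(coords, k=2)
```

With `k=48`, as in the hyperparameters above, each residue would exchange messages with at most 48 spatial neighbors.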
## Training Data
The models were trained on protein structures from the Protein Data Bank (PDB):
- **Original models:** Standard PDB training set
- **Soluble models:** Filtered for soluble, well-expressed proteins
## Performance
These models achieve state-of-the-art performance on:
- Native sequence recovery
- Structural compatibility (predicted structure vs. designed sequence)
- Expressibility and stability (for soluble models)
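Native sequence recovery is the fraction of positions at which a designed sequence matches the native one. The sketch below shows one way to compute it. The `encode` and `sequence_recovery` helpers are hypothetical, and the 21-letter alphabet ordering is an assumption that may differ from the tokenization used by the package.

```python
import jax.numpy as jnp

ALPHABET = "ACDEFGHIKLMNPQRSTVWYX"  # assumed ordering: 20 amino acids + X (unknown)


def encode(seq: str) -> jnp.ndarray:
    """Map a one-letter amino acid string to integer tokens."""
    return jnp.array([ALPHABET.index(c) for c in seq])


def sequence_recovery(designed, native, mask=None):
    """Fraction of unmasked positions where designed and native tokens agree."""
    designed, native = jnp.asarray(designed), jnp.asarray(native)
    if mask is None:
        mask = jnp.ones(native.shape, dtype=bool)
    mask = jnp.asarray(mask, dtype=bool)
    matches = (designed == native) & mask
    # Guard against division by zero when every position is masked out.
    return matches.sum() / jnp.maximum(mask.sum(), 1)


native = encode("ACDEFG")
designed = encode("ACDKFG")  # differs from the native sequence at one position
recovery = sequence_recovery(designed, native)
```

Passing a `mask` restricts the metric to designed positions, which matters when part of the sequence is held fixed.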
## Citation
If you use PrxteinMPNN in your research, please cite the original ProteinMPNN paper:
```bibtex
@article{dauparas2022robust,
title={Robust deep learning--based protein sequence design using ProteinMPNN},
author={Dauparas, Justas and Anishchenko, Ivan and Bennett, Nathaniel and Bai, Hua and Ragotte, Robert J and Milles, Lukas F and Wicky, Basile IM and Courbet, Alexis and de Haas, Rob J and Bethel, Neville and others},
journal={Science},
volume={378},
number={6615},
pages={49--56},
year={2022},
publisher={American Association for the Advancement of Science}
}
```
## License
MIT License - See LICENSE file for details.
## Links
- **GitHub Repository:** [maraxen/PrxteinMPNN](https://github.com/maraxen/PrxteinMPNN)
- **Original ProteinMPNN:** [dauparas/ProteinMPNN](https://github.com/dauparas/ProteinMPNN)
- **Documentation:** [Full documentation](https://github.com/maraxen/PrxteinMPNN/tree/main/docs)
## Technical Details
### File Format
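Models are saved with Equinox's `tree_serialise_leaves` (`.eqx` files). This format:
- Stores only the PyTree leaves (the weight arrays); the tree structure comes from the template model passed to `tree_deserialise_leaves`, which is why the template's hyperparameters must match the saved architecture
- Restores the saved arrays exactly, with no loss of precision
- Fits JAX's functional programming paradigm, since loading returns a new model instead of mutating one
- Supports efficient serialization/deserialization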
### Computational Requirements
- **Memory:** ~30 MB per model
- **Inference:** CPU-compatible, GPU-accelerated
- **Batch processing:** Supported via `jax.vmap`
## Updates
**Latest (v2.0):**
- Migrated to unified Equinox architecture
- All models now in `.eqx` format
- Improved modularity and type safety
- Full JAX compatibility with JIT, vmap, and grad
---
For more information, examples, and tutorials, visit the [GitHub repository](https://github.com/maraxen/PrxteinMPNN).