---
license: mit
tags:
- protein-design
- protein-mpnn
- jax
- equinox
- biology
- structure-based-design
library_name: equinox
---
# PrxteinMPNN
A JAX/Equinox implementation of ProteinMPNN for inverse protein folding and sequence design.
## Model Description
PrxteinMPNN is a message-passing neural network that generates amino acid sequences given a protein backbone structure. This implementation uses JAX and Equinox for efficient computation and functional programming patterns.
**Key Features:**
- Fully modular Equinox implementation
- JAX-based for GPU acceleration and automatic differentiation
- Multiple pre-trained model variants (original and soluble)
- Multiple training-noise variants (0.02, 0.10, 0.20, 0.30 Å backbone noise)
## Available Models
All models share the same architecture (48 nearest neighbors); the numeric suffix indicates the standard deviation of Gaussian backbone noise applied during training:
### Original Models
- `original_v_48_002` - Trained with 0.02 Å backbone noise
- `original_v_48_010` - Trained with 0.10 Å backbone noise
- `original_v_48_020` - Trained with 0.20 Å backbone noise (recommended)
- `original_v_48_030` - Trained with 0.30 Å backbone noise
### Soluble Models
- `soluble_v_48_002` - Trained on soluble proteins with 0.02 Å backbone noise
- `soluble_v_48_010` - Trained on soluble proteins with 0.10 Å backbone noise
- `soluble_v_48_020` - Trained on soluble proteins with 0.20 Å backbone noise (recommended)
- `soluble_v_48_030` - Trained on soluble proteins with 0.30 Å backbone noise
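With `huggingface_hub` installed (see Installation below), you can check which serialized weight files the repository currently provides before downloading one. The `.eqx` filter below simply reflects the file format described later in this card:
```python
from huggingface_hub import list_repo_files

# List every serialized Equinox weight file in the model repository
files = list_repo_files("maraxen/prxteinmpnn", repo_type="model")
eqx_files = [f for f in files if f.endswith(".eqx")]
print(eqx_files)
```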
## Installation
```bash
pip install jax equinox huggingface_hub
# The usage examples below also import the prxteinmpnn package itself;
# it is assumed to be installable from the GitHub repository:
pip install git+https://github.com/maraxen/PrxteinMPNN
```
## Usage
### Basic Usage
```python
import jax
import jax.numpy as jnp
import equinox as eqx
from huggingface_hub import hf_hub_download

from prxteinmpnn.eqx_new import PrxteinMPNN

# Download the model weights from the Hugging Face Hub
model_path = hf_hub_download(
    repo_id="maraxen/prxteinmpnn",
    filename="eqx/original_v_48_020.eqx",
    repo_type="model",
)

# Create the model skeleton (must match the saved architecture)
key = jax.random.PRNGKey(0)
model = PrxteinMPNN(
    node_features=128,
    edge_features=128,
    hidden_features=512,
    num_encoder_layers=3,
    num_decoder_layers=3,
    vocab_size=21,
    k_neighbors=48,
    key=key,
)

# Load the serialized weights into the skeleton
model = eqx.tree_deserialise_leaves(model_path, model)

# Use the model for inference
# ... (see full documentation for inference examples)
```
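As a quick sanity check that the weights deserialized into the skeleton, you can count the array leaves of the loaded model; the total should be roughly consistent with the ~30 MB footprint noted under Computational Requirements:
```python
import jax
import equinox as eqx

# Collect only the array leaves (the trainable parameters) of the model
params = eqx.filter(model, eqx.is_array)
num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(f"Loaded model with {num_params:,} parameters")
```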
### Using the High-Level API
```python
from prxteinmpnn.io.weights import load_model

# Automatically downloads and loads the model
model = load_model(
    model_version="v_48_020",
    model_weights="original",
)
```
## Model Architecture
**Hyperparameters:**
- Node features: 128
- Edge features: 128
- Hidden features: 512
- Encoder layers: 3
- Decoder layers: 3
- K-nearest neighbors: 48
- Vocabulary size: 21 (20 amino acids + 1 unknown)
**Architecture:**
- Message-passing encoder for structural features (see the neighbor-selection sketch below)
- Autoregressive decoder for sequence generation
- Attention-based edge updates
- LayerNorm and residual connections
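For intuition about the encoder's input graph: each residue is connected to its k nearest neighbors in 3D space (k = 48 here), and messages are passed along those edges. A minimal, illustrative neighbor-selection sketch in JAX (not part of the PrxteinMPNN API, and simplified to Cα coordinates only) could look like:
```python
import jax.numpy as jnp

def knn_neighbors(ca_coords: jnp.ndarray, k: int = 48) -> jnp.ndarray:
    """Indices of the k nearest residues for each residue.

    Illustrative only: `ca_coords` is an (N, 3) array of C-alpha coordinates.
    """
    # Pairwise Euclidean distances between residues, shape (N, N)
    diffs = ca_coords[:, None, :] - ca_coords[None, :, :]
    dists = jnp.linalg.norm(diffs, axis=-1)
    # For each residue, take the k smallest distances (the residue itself is included)
    return jnp.argsort(dists, axis=-1)[:, :k]
```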
## Training Data
The models were trained on protein structures from the Protein Data Bank (PDB):
- **Original models:** Standard PDB training set
- **Soluble models:** Filtered for soluble, well-expressed proteins
## Performance
These models achieve state-of-the-art performance on:
- Native sequence recovery
- Structural compatibility (predicted structure vs. designed sequence)
- Expressibility and stability (for soluble models)
## Citation
If you use PrxteinMPNN in your research, please cite the original ProteinMPNN paper:
```bibtex
@article{dauparas2022robust,
title={Robust deep learning--based protein sequence design using ProteinMPNN},
author={Dauparas, Justas and Anishchenko, Ivan and Bennett, Nathaniel and Bai, Hua and Ragotte, Robert J and Milles, Lukas F and Wicky, Basile IM and Courbet, Alexis and de Haas, Rob J and Bethel, Neville and others},
journal={Science},
volume={378},
number={6615},
pages={49--56},
year={2022},
publisher={American Association for the Advancement of Science}
}
```
## License
MIT License - See LICENSE file for details.
## Links
- **GitHub Repository:** [maraxen/PrxteinMPNN](https://github.com/maraxen/PrxteinMPNN)
- **Original ProteinMPNN:** [dauparas/ProteinMPNN](https://github.com/dauparas/ProteinMPNN)
- **Documentation:** [Full documentation](https://github.com/maraxen/PrxteinMPNN/tree/main/docs)
## Technical Details
### File Format
Models are saved using Equinox's `tree_serialise_leaves` format (`.eqx` files), which:
- Preserves PyTree structure
- Ensures bit-perfect reproducibility
- Is compatible with JAX's functional programming paradigm
- Supports efficient serialization/deserialization
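Because the `.eqx` file is just the serialized leaves of the PyTree, a loaded model can be written back out and reloaded bit-for-bit with the matching pair of Equinox calls:
```python
import equinox as eqx

# Serialize the leaves of `model` to disk, then load them back into a
# skeleton with the same PyTree structure (here, `model` itself).
eqx.tree_serialise_leaves("my_model.eqx", model)
restored = eqx.tree_deserialise_leaves("my_model.eqx", model)
```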
### Computational Requirements
- **Memory:** ~30 MB per model
- **Inference:** CPU-compatible, GPU-accelerated
- **Batch processing:** Supported via `jax.vmap`
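The batch-processing pattern is the standard `jax.vmap` transform applied to a single-structure function. The `design_one` wrapper and its call signature below are placeholders, not the package's actual inference API:
```python
import jax

def design_one(coords):
    # Placeholder: substitute the actual single-structure PrxteinMPNN
    # inference call here.
    return model(coords)

# Map over a leading batch axis, e.g. coords_batch of shape (batch, n_residues, ...)
design_batch = jax.vmap(design_one)
```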
## Updates
**Latest (v2.0):**
- Migrated to unified Equinox architecture
- All models now in `.eqx` format
- Improved modularity and type safety
- Full JAX compatibility with JIT, vmap, and grad
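For example, a loaded model composes directly with Equinox's filtered transforms; the call signature below is again a placeholder for the actual inference inputs:
```python
import equinox as eqx

@eqx.filter_jit
def run_model(model, coords):
    # Placeholder call signature; the first invocation is JIT-compiled and
    # the compiled function is reused afterwards.
    return model(coords)
```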
---
For more information, examples, and tutorials, visit the [GitHub repository](https://github.com/maraxen/PrxteinMPNN).