---
license: mit
tags:
- protein-design
- protein-mpnn
- jax
- equinox
- biology
- structure-based-design
library_name: equinox
---

# PrxteinMPNN

A JAX/Equinox implementation of ProteinMPNN for inverse protein folding and sequence design.

## Model Description

PrxteinMPNN is a message-passing neural network that generates amino acid sequences given a protein backbone structure. This implementation uses JAX and Equinox for efficient computation and functional programming patterns.

**Key Features:**
- Fully modular Equinox implementation
- JAX-based for GPU acceleration and automatic differentiation
- Multiple pre-trained model variants (original and soluble)
- Multiple backbone-noise variants (0.02, 0.10, 0.20, 0.30 Å)

## Available Models

All models share the same architecture; the suffix encodes the number of nearest neighbors (48) and the Gaussian backbone noise, in Å, applied during training:

### Original Models
- `original_v_48_002` - 48 neighbors, 0.02 Å backbone noise
- `original_v_48_010` - 48 neighbors, 0.10 Å backbone noise
- `original_v_48_020` - 48 neighbors, 0.20 Å backbone noise (recommended)
- `original_v_48_030` - 48 neighbors, 0.30 Å backbone noise

### Soluble Models
- `soluble_v_48_002` - 0.02 Å backbone noise, trained on soluble proteins
- `soluble_v_48_010` - 0.10 Å backbone noise, trained on soluble proteins
- `soluble_v_48_020` - 0.20 Å backbone noise, trained on soluble proteins (recommended)
- `soluble_v_48_030` - 0.30 Å backbone noise, trained on soluble proteins

## Installation

```bash
pip install jax equinox huggingface_hub
```

## Usage

### Basic Usage

```python
import jax
import jax.numpy as jnp
import equinox as eqx
from huggingface_hub import hf_hub_download

# Download model from HuggingFace
model_path = hf_hub_download(
    repo_id="maraxen/prxteinmpnn",
    filename="eqx/original_v_48_020.eqx",
    repo_type="model",
)

# Create model structure (must match saved architecture)
from prxteinmpnn.eqx_new import PrxteinMPNN

key = jax.random.PRNGKey(0)
model = PrxteinMPNN(
    node_features=128,
    edge_features=128,
    hidden_features=512,
    num_encoder_layers=3,
    num_decoder_layers=3,
    vocab_size=21,
    k_neighbors=48,
    key=key,
)

# Load weights
model = eqx.tree_deserialise_leaves(model_path, model)

# Use model for inference
# ... (see full documentation for inference examples)
```

### Using the High-Level API

```python
from prxteinmpnn.io.weights import load_model

# Automatically downloads and loads the model
model = load_model(
    model_version="v_48_020",
    model_weights="original"
)
```

## Model Architecture

**Hyperparameters:**
- Node features: 128
- Edge features: 128
- Hidden features: 512
- Encoder layers: 3
- Decoder layers: 3
- K-nearest neighbors: 48
- Vocabulary size: 21 (20 amino acids + 1 unknown)
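
The `k_neighbors` hyperparameter determines how the structure graph is built: each residue is connected to its 48 spatially nearest residues. As an illustrative sketch (not the library's actual API), a k-nearest-neighbor graph can be computed from C-alpha coordinates like this:

```python
import jax.numpy as jnp

def knn_edges(coords, k):
    """Return the indices of the k nearest residues for each residue.

    coords: (n_residues, 3) array of C-alpha positions.
    """
    # Pairwise squared distances, shape (n, n)
    d2 = jnp.sum((coords[:, None, :] - coords[None, :, :]) ** 2, axis=-1)
    # Indices of the k closest residues per row (self is always included,
    # since its distance is zero)
    return jnp.argsort(d2, axis=-1)[:, :k]

coords = jnp.arange(15.0).reshape(5, 3)  # 5 toy residues
edges = knn_edges(coords, k=3)
print(edges.shape)  # (5, 3)
```

The real model builds edge features (e.g. distance encodings) on top of such a neighbor index array; `knn_edges` here is only a hypothetical stand-in for that step.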

**Architecture:**
- Message-passing encoder for structural features
- Autoregressive decoder for sequence generation
- Attention-based edge updates
- LayerNorm and residual connections

## Training Data

The models were trained on protein structures from the Protein Data Bank (PDB):
- **Original models:** Standard PDB training set
- **Soluble models:** Filtered for soluble, well-expressed proteins

## Performance

These models achieve state-of-the-art performance on:
- Native sequence recovery
- Structural compatibility (predicted structure vs. designed sequence)
- Expressibility and stability (for soluble models)

## Citation

If you use PrxteinMPNN in your research, please cite the original ProteinMPNN paper:

```bibtex
@article{dauparas2022robust,
  title={Robust deep learning--based protein sequence design using ProteinMPNN},
  author={Dauparas, Justas and Anishchenko, Ivan and Bennett, Nathaniel and Bai, Hua and Ragotte, Robert J and Milles, Lukas F and Wicky, Basile IM and Courbet, Alexis and de Haas, Rob J and Bethel, Neville and others},
  journal={Science},
  volume={378},
  number={6615},
  pages={49--56},
  year={2022},
  publisher={American Association for the Advancement of Science}
}
```

## License

MIT License - See LICENSE file for details.

## Links

- **GitHub Repository:** [maraxen/PrxteinMPNN](https://github.com/maraxen/PrxteinMPNN)
- **Original ProteinMPNN:** [dauparas/ProteinMPNN](https://github.com/dauparas/ProteinMPNN)
- **Documentation:** [Full documentation](https://github.com/maraxen/PrxteinMPNN/tree/main/docs)

## Technical Details

### File Format

Models are saved using Equinox's `tree_serialise_leaves` format (`.eqx` files), which:
- Preserves PyTree structure
- Ensures bit-perfect reproducibility
- Is compatible with JAX's functional programming paradigm
- Supports efficient serialization/deserialization

### Computational Requirements

- **Memory:** ~30 MB per model
- **Inference:** CPU-compatible, GPU-accelerated
- **Batch processing:** Supported via `jax.vmap`
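
The batching pattern with `jax.vmap` can be sketched as follows. `score_structure` is a hypothetical per-structure function standing in for a single forward pass (the model's real signature differs):

```python
import jax
import jax.numpy as jnp

def score_structure(coords):
    """Toy per-structure function; placeholder for one model forward pass."""
    return jnp.sum(coords ** 2)

# vmap maps the function over the leading (batch) axis
batched_score = jax.vmap(score_structure)

coords_batch = jnp.ones((4, 10, 3))  # 4 structures, 10 residues, xyz
scores = batched_score(coords_batch)
print(scores.shape)  # (4,)
```

The same pattern composes with `jax.jit` for compiled batched inference.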

## Updates

**Latest (v2.0):**
- Migrated to unified Equinox architecture
- All models now in `.eqx` format
- Improved modularity and type safety
- Full JAX compatibility with JIT, vmap, and grad

---

For more information, examples, and tutorials, visit the [GitHub repository](https://github.com/maraxen/PrxteinMPNN).