maraxen committed on
Commit 6f18f43 · verified · 1 Parent(s): 89f25d0

Update README with .eqx model usage instructions

  1. README.md +183 -3
README.md CHANGED
@@ -1,3 +1,183 @@
- ---
- license: mit
- ---
+ ---
+ license: mit
+ tags:
+ - protein-design
+ - protein-mpnn
+ - jax
+ - equinox
+ - biology
+ - structure-based-design
+ library_name: equinox
+ ---
+
+ # PrxteinMPNN
+
+ A JAX/Equinox implementation of ProteinMPNN for inverse protein folding and sequence design.
+
+ ## Model Description
+
+ PrxteinMPNN is a message-passing neural network that generates amino acid sequences given a protein backbone structure. This implementation uses JAX and Equinox for efficient computation and functional programming patterns.
+
+ **Key Features:**
+ - Fully modular Equinox implementation
+ - JAX-based for GPU acceleration and automatic differentiation
+ - Multiple pre-trained model variants (original and soluble)
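+ - Multiple training noise levels (suffixes 002, 010, 020, 030, which in the upstream ProteinMPNN naming denote 0.02, 0.10, 0.20, and 0.30 Å of backbone noise, not epoch counts)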
+
+ ## Available Models
+
+ All models share the same architecture and differ only in training data and training noise level. Following the upstream ProteinMPNN naming, `v_48` denotes 48 nearest neighbors and the three-digit suffix is the level of Gaussian noise added to backbone coordinates during training (`020` = 0.20 Å), not an epoch count:
+
+ ### Original Models
+ - `original_v_48_002` - Trained with 0.02 Å backbone noise
+ - `original_v_48_010` - Trained with 0.10 Å backbone noise
+ - `original_v_48_020` - Trained with 0.20 Å backbone noise (recommended)
+ - `original_v_48_030` - Trained with 0.30 Å backbone noise
+
+ ### Soluble Models
+ - `soluble_v_48_002` - Trained with 0.02 Å backbone noise on soluble proteins
+ - `soluble_v_48_010` - Trained with 0.10 Å backbone noise on soluble proteins
+ - `soluble_v_48_020` - Trained with 0.20 Å backbone noise on soluble proteins (recommended)
+ - `soluble_v_48_030` - Trained with 0.30 Å backbone noise on soluble proteins
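The model names listed above follow a regular pattern, so download paths can be built programmatically. The following is a minimal sketch and not part of the PrxteinMPNN API: `parse_model_name` and `eqx_filename` are hypothetical helpers, and the `eqx/<name>.eqx` layout is assumed from the single filename that appears in the usage example (`eqx/original_v_48_020.eqx`).

```python
from typing import NamedTuple

VARIANTS = ("original", "soluble")
TAGS = ("002", "010", "020", "030")


class ModelName(NamedTuple):
    weights: str      # "original" or "soluble"
    k_neighbors: int  # the "48" in "v_48" (assumed to match k_neighbors=48)
    tag: str          # trailing three-digit tag, e.g. "020"


def parse_model_name(name: str) -> ModelName:
    """Split a name such as 'original_v_48_020' into its parts."""
    weights, v, k, tag = name.split("_")  # raises ValueError on a wrong part count
    if weights not in VARIANTS or v != "v" or tag not in TAGS:
        raise ValueError(f"unrecognised model name: {name!r}")
    return ModelName(weights, int(k), tag)


def eqx_filename(name: str) -> str:
    """Hypothetical repo path, assuming every model is stored as eqx/<name>.eqx."""
    parse_model_name(name)  # validate before building the path
    return f"eqx/{name}.eqx"


# All eight names listed in this section.
all_names = [f"{w}_v_48_{t}" for w in VARIANTS for t in TAGS]
```

The result of `eqx_filename(...)` could then be passed as the `filename` argument of `hf_hub_download`, provided the repository really uses that layout for every model.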
+
+ ## Installation
+
+ ```bash
+ pip install jax equinox huggingface_hub
+ ```
+
+ ## Usage
+
+ ### Basic Usage
+
+ ```python
+ import jax
+ import jax.numpy as jnp
+ import equinox as eqx
+ from huggingface_hub import hf_hub_download
+
+ # Download model from HuggingFace
+ model_path = hf_hub_download(
+     repo_id="maraxen/prxteinmpnn",
+     filename="eqx/original_v_48_020.eqx",
+     repo_type="model",
+ )
+
+ # Create model structure (must match saved architecture)
+ # Note: this import needs the prxteinmpnn package itself, which the pip
+ # command above does not install (see the GitHub repository).
+ from prxteinmpnn.eqx_new import PrxteinMPNN
+
+ key = jax.random.PRNGKey(0)
+ model = PrxteinMPNN(
+     node_features=128,
+     edge_features=128,
+     hidden_features=512,
+     num_encoder_layers=3,
+     num_decoder_layers=3,
+     vocab_size=21,
+     k_neighbors=48,
+     key=key,
+ )
+
+ # Load weights into the freshly initialised structure
+ model = eqx.tree_deserialise_leaves(model_path, model)
+
+ # Use model for inference
+ # ... (see full documentation for inference examples)
+ ```
+
+ ### Using the High-Level API
+
+ ```python
+ from prxteinmpnn.io.weights import load_model
+
+ # Automatically downloads and loads the model
+ model = load_model(
+     model_version="v_48_020",
+     model_weights="original",
+ )
+ ```
+
+ ## Model Architecture
+
+ **Hyperparameters:**
+ - Node features: 128
+ - Edge features: 128
+ - Hidden features: 512
+ - Encoder layers: 3
+ - Decoder layers: 3
+ - K-nearest neighbors: 48
+ - Vocabulary size: 21 (20 amino acids + 1 unknown)
+
+ **Architecture:**
+ - Message-passing encoder for structural features
+ - Autoregressive decoder for sequence generation
+ - Attention-based edge updates
+ - LayerNorm and residual connections
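To illustrate what the "K-nearest neighbors" hyperparameter controls, here is a minimal JAX sketch of building a k-nearest-neighbor graph from backbone coordinates. It is an illustration only, not the PrxteinMPNN featurisation code: `knn_graph` is a hypothetical function, it uses a single atom per residue, and it ignores padding masks that a real pipeline would likely need.

```python
import jax
import jax.numpy as jnp


def knn_graph(coords: jnp.ndarray, k: int):
    """Return (distances, indices) of the k nearest residues for each residue.

    coords: (num_residues, 3) array, e.g. one C-alpha position per residue.
    In this sketch each residue is its own nearest neighbour.
    """
    diff = coords[:, None, :] - coords[None, :, :]      # (N, N, 3)
    dist = jnp.sqrt(jnp.sum(diff**2, axis=-1) + 1e-6)   # (N, N), epsilon keeps sqrt stable
    k = min(k, coords.shape[0])                         # short chains have fewer than k residues
    neg_dist, idx = jax.lax.top_k(-dist, k)             # top_k selects the largest, so negate
    return -neg_dist, idx


# Toy backbone: 5 residues spaced 3.8 Å apart along the x-axis.
coords = jnp.stack([3.8 * jnp.arange(5.0), jnp.zeros(5), jnp.zeros(5)], axis=-1)
dists, neighbors = knn_graph(coords, k=3)
```

With `k=48`, each residue would exchange messages with at most 48 spatial neighbors, which keeps the cost of message passing linear in chain length rather than quadratic.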
+
+ ## Training Data
+
+ The models were trained on protein structures from the Protein Data Bank (PDB):
+ - **Original models:** Standard PDB training set
+ - **Soluble models:** Filtered for soluble, well-expressed proteins
+
+ ## Performance
+
+ These models achieve state-of-the-art performance on:
+ - Native sequence recovery
+ - Structural compatibility (predicted structure vs. designed sequence)
+ - Expressibility and stability (for soluble models)
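Native sequence recovery, the first metric above, is the fraction of positions at which a designed sequence matches the native one. The sketch below shows one way to compute it; `encode` and `sequence_recovery` are hypothetical helpers, and the alphabet ordering is an assumption rather than something stated in this README.

```python
import jax.numpy as jnp

# 20 amino acids + X for unknown (21 tokens); the ordering is assumed.
ALPHABET = "ACDEFGHIKLMNPQRSTVWYX"


def encode(seq: str) -> jnp.ndarray:
    """Map a one-letter amino acid string to integer token ids."""
    return jnp.array([ALPHABET.index(aa) for aa in seq])


def sequence_recovery(designed, native, mask=None) -> float:
    """Fraction of (unmasked) positions where designed == native."""
    matches = (designed == native).astype(jnp.float32)
    if mask is None:
        mask = jnp.ones_like(matches)
    return float(jnp.sum(matches * mask) / jnp.sum(mask))


native = encode("ACDEFGHIKL")
designed = encode("ACDEFGHIKV")  # differs from native at the last position only
recovery = sequence_recovery(designed, native)  # 9 of 10 positions match
```

The optional `mask` lets fixed or padded positions be excluded, so that recovery is measured only over the residues that were actually designed.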
+
+ ## Citation
+
+ If you use PrxteinMPNN in your research, please cite the original ProteinMPNN paper:
+
+ ```bibtex
+ @article{dauparas2022robust,
+   title={Robust deep learning--based protein sequence design using ProteinMPNN},
+   author={Dauparas, Justas and Anishchenko, Ivan and Bennett, Nathaniel and Bai, Hua and Ragotte, Robert J and Milles, Lukas F and Wicky, Basile IM and Courbet, Alexis and de Haas, Rob J and Bethel, Neville and others},
+   journal={Science},
+   volume={378},
+   number={6615},
+   pages={49--56},
+   year={2022},
+   publisher={American Association for the Advancement of Science}
+ }
+ ```
+
+ ## License
+
+ MIT License - See LICENSE file for details.
+
+ ## Links
+
+ - **GitHub Repository:** [maraxen/PrxteinMPNN](https://github.com/maraxen/PrxteinMPNN)
+ - **Original ProteinMPNN:** [dauparas/ProteinMPNN](https://github.com/dauparas/ProteinMPNN)
+ - **Documentation:** [Full documentation](https://github.com/maraxen/PrxteinMPNN/tree/main/docs)
+
+ ## Technical Details
+
+ ### File Format
+
+ Models are saved using Equinox's `tree_serialise_leaves` format (`.eqx` files), which:
+ - Preserves PyTree structure
+ - Ensures bit-perfect reproducibility
+ - Is compatible with JAX's functional programming paradigm
+ - Supports efficient serialization/deserialization
+
+ ### Computational Requirements
+
+ - **Memory:** ~30 MB per model
+ - **Inference:** CPU-compatible, GPU-accelerated
+ - **Batch processing:** Supported via `jax.vmap`
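The batch-processing point can be illustrated with a small `jax.vmap` sketch. Since this README does not show the model's inference call signature, a toy per-structure function (`mean_pairwise_distance`, hypothetical) stands in for the model; the same pattern would apply to any function written for a single structure.

```python
import jax
import jax.numpy as jnp


def mean_pairwise_distance(coords: jnp.ndarray) -> jnp.ndarray:
    """Toy per-structure function: mean pairwise distance for one (N, 3) backbone."""
    diff = coords[:, None, :] - coords[None, :, :]
    dist = jnp.sqrt(jnp.sum(diff**2, axis=-1) + 1e-6)
    return jnp.mean(dist)


# vmap lifts the single-structure function to a batch; jit compiles the batched version.
batched_fn = jax.jit(jax.vmap(mean_pairwise_distance))

key = jax.random.PRNGKey(0)
batch = jax.random.normal(key, (8, 50, 3))  # 8 structures, 50 residues each
per_structure = batched_fn(batch)           # one value per structure, shape (8,)
```

Note that `jax.vmap` needs every structure in the batch to have the same shape, so proteins of different lengths would typically be padded to a common length and accompanied by a mask.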
+
+ ## Updates
+
+ **Latest (v2.0):**
+ - Migrated to unified Equinox architecture
+ - All models now in `.eqx` format
+ - Improved modularity and type safety
+ - Full JAX compatibility with JIT, vmap, and grad
+
+ ---
+
+ For more information, examples, and tutorials, visit the [GitHub repository](https://github.com/maraxen/PrxteinMPNN).