tags:
- esmc
---

# ESM ProtHash

A protein language model that outputs amino acid sequence embeddings for use in clustering, classification, locality-sensitive hashing, and more. Distilled from the [ESMC](https://www.evolutionaryscale.ai/blog/esm-cambrian) family of models with deep comprehension of protein structure, ProtHash produces contextual embeddings that align in vector space according to the sequences' atomic structure. Trained on the [SwissProt](https://huggingface.co/datasets/andrewdalpino/SwissProt-Gene-Ontology) dataset to mimic the activations of its ESMC teacher model, ProtHash embeddings have near-perfect similarity to ESMC embeddings but at a greatly reduced computational cost.

## Key Features

- **Structurally-relevant**: Structurally similar proteins show up nearby in the embedding space, enabling downstream tasks such as clustering, classification, and locality-sensitive hashing based on atomic structure (see the sketch after this list).

- **Compatible with ESMC**: ProtHash can output embeddings in either its native dimensionality or its ESMC teacher's, allowing it to serve as either a faster drop-in approximation of ESMC embeddings or a more efficient compressed representation.

- **Quantization-ready**: With quantization-aware post-training, ProtHash allows you to quantize the weights of the model while maintaining similarity to the teacher's embedding space.
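
For a concrete picture of the locality-sensitive hashing use case, here is a minimal sketch that buckets an embedding using random hyperplane signatures. The `lsh_signature` function is illustrative only and not part of the ProtHash library; it assumes a 1-D embedding tensor like the `embed_native` output shown in the usage example below.

```python
import torch


def lsh_signature(embedding: torch.Tensor, num_bits: int = 16, seed: int = 42) -> int:
    """Hash an embedding into an integer bucket ID via random hyperplane signs."""
    generator = torch.Generator().manual_seed(seed)

    # One random hyperplane per output bit, matching the embedding dimensionality.
    hyperplanes = torch.randn(num_bits, embedding.shape[-1], generator=generator)

    # Each bit records which side of its hyperplane the embedding falls on.
    bits = (hyperplanes @ embedding > 0).long()

    # Pack the sign bits into a single integer bucket ID.
    return int((bits * (2 ** torch.arange(num_bits))).sum())
```

Structurally similar sequences agree on most hyperplane signs, so they tend to land in the same bucket; increasing `num_bits` makes the buckets finer.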

## Pretrained Models

| Name | Context Length | Embedding Dimensions | Attention Heads (Q/KV) | Encoder Layers | Total Params | Teacher Model | Teacher Dimensions |
|---|---|---|---|---|---|---|---|
| [andrewdalpino/ProtHash-384-Tiny](https://huggingface.co/andrewdalpino/ProtHash-384-Tiny) | 2048 | 384 | 16/4 | 4 | 5M | esmc_300m | 960 |
| [andrewdalpino/ProtHash-384](https://huggingface.co/andrewdalpino/ProtHash-384) | 2048 | 384 | 16/4 | 10 | 11M | esmc_300m | 960 |

```python
import torch

# Note: the ProtHash import and tokenizer setup are elided in this excerpt.
model_name = "andrewdalpino/ProtHash-512-Tiny"

model = ProtHash.from_pretrained(model_name)

# Optionally quantize the weights.
model.quantize_weights()

sequence = input("Enter a sequence: ")

out = tokenizer(sequence, max_length=2048)

tokens = out["input_ids"]  # Assumes the tokenizer returns a mapping with input IDs.

x = torch.tensor(tokens, dtype=torch.int64).unsqueeze(0)

# Output the sequence embedding in native dimensionality.
y_embed_native = model.embed_native(x).squeeze(0)

# Output a drop-in replacement for the teacher's embeddings.
y_embed_teacher = model.embed_teacher(x).squeeze(0)

print(y_embed_native.shape)
print(y_embed_teacher.shape)
```
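
Because the embeddings are aligned by structure, you can compare them directly for clustering or nearest-neighbor search. The snippet below is a sketch: `y_embed_a` and `y_embed_b` stand in for `embed_native` outputs computed for two different sequences.

```python
import torch.nn.functional as F

# Hypothetical tensors: embed_native outputs for two different sequences.
similarity = F.cosine_similarity(y_embed_a, y_embed_b, dim=0)

print(f"Cosine similarity: {similarity.item():.4f}")
```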

## Training

### Clone the project repo

We'll need the code from the project repository to train and/or fine-tune the model.

```sh
git clone https://github.com/andrewdalpino/ProtHash
```

### Install Project Dependencies

Project dependencies are specified in the `requirements.txt` file. You can install them with pip using the commands below from the project root. We recommend using a virtual environment such as `venv` to keep package dependencies on your system tidy.

```sh
python -m venv ./.venv

source ./.venv/bin/activate

pip install -r requirements.txt
```

### Distilling

ProtHash is trained to mimic the activations of its ESMC teacher model. To begin distillation with the default arguments, run the command below.

```sh
python train.py
```

You can change the default arguments like in the example below.

```sh
python train.py --teacher_name="esmc_300m" --max_steps=3500 --embedding_dimensions=256
```
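
For intuition, below is a minimal sketch of a temperature-smoothed activation-matching objective in the spirit of the `--temperature` argument documented later. The actual loss computed by `train.py` may differ, so treat this as an illustration under stated assumptions rather than the project's implementation.

```python
import torch
import torch.nn.functional as F


def distillation_loss(
    student_activations: torch.Tensor,
    teacher_activations: torch.Tensor,
    temperature: float = 8.0,
) -> torch.Tensor:
    """Classic knowledge distillation objective (Hinton et al., 2015)."""
    # A higher temperature flattens both distributions before matching.
    student_log_probs = F.log_softmax(student_activations / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_activations / temperature, dim=-1)

    # KL divergence between the smoothed distributions, scaled by T^2 so that
    # gradient magnitudes stay comparable across temperatures.
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * temperature**2
```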

#### Training Dashboard

We use [TensorBoard](https://www.tensorflow.org/tensorboard) to capture and display training events such as loss and gradient norm updates. To launch the dashboard server, run the following command from the terminal.

```sh
tensorboard --logdir=./runs
```

Then navigate to the dashboard using your favorite web browser.

#### Training Arguments

| Argument | Default | Type | Description |
|---|---|---|---|
| --teacher_name | "esmc_600m" | str | The teacher model name. |
| --num_dataset_processes | 1 | int | The number of CPU processes to use to preprocess the dataset. |
| --min_sequence_length | 1 | int | The minimum length of the input sequences. |
| --max_sequence_length | 2048 | int | The maximum length of the input sequences. |
| --quantization_aware_training | False | bool | Should we add fake quantized tensors to simulate quantized training? |
| --batch_size | 4 | int | The number of training sequences to pass through the network at a time. |
| --gradient_accumulation_steps | 32 | int | The number of batches to pass through the network before updating the model weights. |
| --max_steps | 4200 | int | The number of steps to train for. |
| --learning_rate | 1e-4 | float | The learning rate of the AdamW optimizer. |
| --max_gradient_norm | 100.0 | float | Clip gradients above this threshold norm before stepping. |
| --temperature | 8.0 | float | The smoothing parameter of the activations - a higher temperature results in smoother activations. |
| --embedding_dimensions | 512 | int | The dimensionality of the native embeddings. |
| --q_heads | 16 | int | The number of query heads used in the self-attention layers. |
| --kv_heads | 4 | int | The number of key and value heads used in the self-attention layers. |
| --hidden_ratio | 2 | (1, 2, 4) | The ratio of hidden neurons to embedding dimensions in the feed-forward layers of the network. |
| --num_encoder_layers | 4 | int | The number of layers within the body of the encoder. |
| --dropout | 0.0 | float | The proportion of activations to send to zero during training as regularization. |
| --activation_checkpointing | False | bool | Should we use activation checkpointing? This will drastically reduce memory utilization during training at the cost of recomputing the forward pass. |
| --eval_interval | 100 | int | The interval, in training steps, at which to evaluate the model on the testing set. |
| --checkpoint_interval | 100 | int | The interval, in training steps, at which to save the model checkpoint to disk. |
| --checkpoint_path | "./checkpoints/checkpoint.pt" | str | The path to the base checkpoint file on disk. |
| --resume | False | bool | Should we resume training from the last checkpoint? |
| --run_dir_path | "./runs" | str | The path to the TensorBoard run directory for this training session. |
| --device | "cpu" | str | The device to run the computation on. |
| --seed | None | int | The seed for the random number generator. |
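
Putting a few of these together, a quantization-aware run on a GPU might look like the command below. The boolean flag syntax shown here is an assumption; consult `python train.py --help` for the exact interface.

```sh
python train.py --teacher_name="esmc_300m" --quantization_aware_training=True --device="cuda"
```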

## References

>- The UniProt Consortium, UniProt: the Universal Protein Knowledgebase in 2025, Nucleic Acids Research, 2025, 53, D609–D617.
|