---
license: apache-2.0
datasets:
- andrewdalpino/SwissProt-Gene-Ontology
tags:
- esmc
---

# ESMC ProtHash

A protein language model that outputs amino acid sequence embeddings for use in clustering, classification, locality-sensitive hashing, and more. Distilled from the [ESMC](https://www.evolutionaryscale.ai/blog/esm-cambrian) family of models, ProtHash produces contextual embeddings that align in vector space according to the sequences' underlying biological properties, such as structure and function. Trained on the [SwissProt](https://huggingface.co/datasets/andrewdalpino/SwissProt-Gene-Ontology) dataset to mimic the activations of its ESMC teacher model, ProtHash embeddings have near-perfect similarity to ESMC embeddings at a greatly reduced computational cost.
## Key Features

- **Blazing fast and efficient**: ProtHash uses less than 1.5% of its ESMC teacher's total parameters to achieve near-perfect cosine similarity between the two embedding spaces.

- **Biologically relevant**: Biologically similar proteins land near each other in the embedding space, enabling downstream tasks such as clustering, classification, and locality-sensitive hashing.

- **Compatible with ESMC**: ProtHash can output embeddings in its native or its ESMC teacher's dimensionality, allowing it to serve as either a faster drop-in approximation of ESMC embeddings or a more efficient compressed representation.

- **Quantization-ready**: With quantization-aware post-training, ProtHash lets you quantize the model weights while maintaining near-perfect similarity to the teacher's embedding space.
## Pretrained Models

| Name | Context Length | Position Embeddings | Embedding Dimensions | Attention Heads (Q/KV) | Encoder Layers | Total Params | Teacher Model | Teacher Dimensions | Library Version |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| [andrewdalpino/ProtHash-V2-384-Tiny](https://huggingface.co/andrewdalpino/ProtHash-V2-384-Tiny) | 2048 | Relative | 384 | 16/4 | 4 | 4.2M | esmc_300m | 960 | 0.2.x |
| [andrewdalpino/ProtHash-V2-384](https://huggingface.co/andrewdalpino/ProtHash-V2-384) | 2048 | Relative | 384 | 16/4 | 10 | 10M | esmc_300m | 960 | 0.2.x |
| [andrewdalpino/ProtHash-V2-512-Tiny](https://huggingface.co/andrewdalpino/ProtHash-V2-512-Tiny) | 2048 | Relative | 512 | 16/4 | 4 | 7.4M | esmc_600m | 1152 | 0.2.x |
| [andrewdalpino/ProtHash-V2-512](https://huggingface.co/andrewdalpino/ProtHash-V2-512) | 2048 | Relative | 512 | 16/4 | 10 | 18M | esmc_600m | 1152 | 0.2.x |
| [andrewdalpino/ProtHash-384-Tiny](https://huggingface.co/andrewdalpino/ProtHash-384-Tiny) | 2048 | Absolute | 384 | 16/4 | 4 | 5M | esmc_300m | 960 | 0.1.x |
| [andrewdalpino/ProtHash-384](https://huggingface.co/andrewdalpino/ProtHash-384) | 2048 | Absolute | 384 | 16/4 | 10 | 11M | esmc_300m | 960 | 0.1.x |
| [andrewdalpino/ProtHash-512-Tiny](https://huggingface.co/andrewdalpino/ProtHash-512-Tiny) | 2048 | Absolute | 512 | 16/4 | 4 | 8.5M | esmc_600m | 1152 | 0.1.x |
| [andrewdalpino/ProtHash-512](https://huggingface.co/andrewdalpino/ProtHash-512) | 2048 | Absolute | 512 | 16/4 | 10 | 19M | esmc_600m | 1152 | 0.1.x |
## Example

First, you'll need the `prothash` and `esm` packages installed in your environment. For ProtHash version 1, install library version `0.1.x`; for version 2, install library version `0.2.x`. We recommend using a virtual environment, such as Python's `venv` module, to prevent version conflicts with other packages.

### Version 1

```sh
pip install prothash~=0.1.0 esm
```

### Version 2

```sh
pip install prothash~=0.2.0 esm
```
Then, load the weights from Hugging Face Hub, tokenize a protein sequence, and pass it to the model. ProtHash adopts the ESM tokenizer as its amino acid tokenization scheme, which consists of a vocabulary of 33 amino acid and special tokens. The output is an embedding vector that can be used in downstream tasks such as comparing against other protein sequence embeddings, clustering, and near-duplicate detection.
```python
import torch

from esm.tokenization import EsmSequenceTokenizer

from prothash.model import ProtHash

tokenizer = EsmSequenceTokenizer()

model_name = "andrewdalpino/ProtHash-V2-512-Tiny"

model = ProtHash.from_pretrained(model_name)

# Optionally quantize the weights to int8.
model.quantize_weights()

sequence = input("Enter a sequence: ")

out = tokenizer(sequence, max_length=2048)

tokens = out["input_ids"]

# Input is a [1, T] tensor of token indices.
x = torch.tensor(tokens, dtype=torch.int64).unsqueeze(0)

# Output the sequence embedding in native dimensionality.
y_embed_native = model.embed_native(x).squeeze(0)

# Output a drop-in replacement for the teacher's embeddings.
y_embed_teacher = model.embed_teacher(x).squeeze(0)

print(y_embed_native.shape)
print(y_embed_teacher.shape)
```
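Once you have embeddings, comparing them with cosine similarity or hashing them into compact binary signatures covers most of the downstream tasks mentioned above. Below is a minimal sketch using random tensors as stand-ins for ProtHash embeddings (so it runs without the `prothash` package); the random-hyperplane hashing shown is a generic locality-sensitive hashing technique, not an API of the library.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(42)

# Stand-ins for two 512-dimensional native embeddings.
a = torch.randn(512)
b = torch.randn(512)

# Cosine similarity: embeddings of biologically similar proteins score closer to 1.
similarity = F.cosine_similarity(a, b, dim=0).item()

# Random-hyperplane LSH: project onto 64 random hyperplanes and keep the signs,
# producing a 64-bit binary signature for each embedding.
planes = torch.randn(64, 512)
sig_a = planes @ a > 0
sig_b = planes @ b > 0

# Hamming distance between signatures approximates angular distance,
# which makes the signatures useful for fast near-duplicate detection.
hamming = (sig_a != sig_b).sum().item()

print(f"cosine similarity: {similarity:.4f}")
print(f"hamming distance: {hamming}/64")
```

Nearby embeddings produce signatures with a small Hamming distance, so the signatures can be bucketed for sub-linear candidate lookup before an exact cosine comparison.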
## Training

If you want to train your own custom ProtHash model, follow the instructions below.
### Clone the project repo

We'll need the code from the project repository to train and/or fine-tune the model.

```sh
git clone https://github.com/andrewdalpino/ProtHash
```
### Install Project Dependencies

Project dependencies are specified in the `requirements.txt` file. You can install them with pip using the following commands from the project root. We recommend using a virtual environment such as `venv` to keep package dependencies on your system tidy.

```sh
python -m venv ./.venv

source ./.venv/bin/activate

pip install -r requirements.txt
```
### Distilling

ProtHash is trained to mimic the activations of its ESMC teacher model. To begin distillation with the default arguments, run the command below.

```sh
python train.py
```

You can override the default arguments like in the example below.

```sh
python train.py --teacher_name="esmc_300m" --max_steps=4200 --embedding_dimensions=768 --temperature=4.0
```
#### Training Dashboard

We use [TensorBoard](https://www.tensorflow.org/tensorboard) to capture and display training events such as loss and gradient norm updates. To launch the dashboard server, run the following command from the terminal.

```sh
tensorboard --logdir=./runs
```

Then navigate to the dashboard using your favorite web browser.
#### Training Arguments

| Argument | Default | Type | Description |
|---|---|---|---|
| --teacher_name | "esmc_600m" | str | The teacher model name. |
| --num_dataset_processes | 1 | int | The number of CPU processes used to preprocess the dataset. |
| --min_sequence_length | 1 | int | The minimum length of the input sequences. |
| --max_sequence_length | 2048 | int | The maximum length of the input sequences. |
| --quantization_aware_training | False | bool | Should we add fake quantized tensors to simulate quantized training? |
| --batch_size | 4 | int | The number of training samples to pass through the network at a time. |
| --gradient_accumulation_steps | 32 | int | The number of batches to pass through the network before updating the model weights. |
| --max_steps | 4000 | int | The number of steps to train for. |
| --learning_rate | 1e-4 | float | The learning rate of the AdamW optimizer. |
| --max_gradient_norm | 100.0 | float | Clip gradients above this threshold norm before stepping. |
| --temperature | 8.0 | float | The smoothing parameter for the activations - a higher temperature results in smoother activations. |
| --embedding_dimensions | 512 | int | The dimensionality of the native embeddings. |
| --q_heads | 16 | int | The number of query heads used in the self-attention layers. |
| --kv_heads | 4 | int | The number of key and value heads used in the self-attention layers. |
| --hidden_ratio | 2 | int (1, 2, or 4) | The ratio of hidden neurons to embedding dimensions in the feed-forward layers of the network. |
| --num_encoder_layers | 4 | int | The number of layers within the body of the encoder. |
| --dropout | 0.0 | float | The proportion of activations to zero out during training as regularization. |
| --activation_checkpointing | False | bool | Should we use activation checkpointing? This drastically reduces memory utilization during training at the cost of recomputing the forward pass. |
| --eval_interval | 100 | int | Evaluate the model on the testing set after this many steps. |
| --checkpoint_interval | 100 | int | Save a model checkpoint to disk every this many steps. |
| --checkpoint_path | "./checkpoints/checkpoint.pt" | str | The path to the base checkpoint file on disk. |
| --resume | False | bool | Should we resume training from the last checkpoint? |
| --run_dir_path | "./runs" | str | The path to the TensorBoard run directory for this training session. |
| --device | "cpu" | str | The device to run the computation on. |
| --seed | None | int | The seed for the random number generator. |
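The `--temperature` argument smooths the student and teacher activations before the distillation loss compares them. As a rough illustration only (not necessarily the repository's exact objective; the Kim et al. reference compares KL divergence and MSE losses for distillation), a temperature-scaled KL distillation loss can be sketched like this:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-ins for student and teacher activations over a small batch.
student_logits = torch.randn(4, 512)
teacher_logits = torch.randn(4, 512)

temperature = 8.0  # higher values yield smoother target distributions

# Temperature-scaled KL divergence, the classic knowledge-distillation
# objective: the student's log-probabilities are matched against the
# teacher's softened probabilities, rescaled by temperature squared so
# gradient magnitudes stay comparable across temperature settings.
loss = F.kl_div(
    F.log_softmax(student_logits / temperature, dim=-1),
    F.softmax(teacher_logits / temperature, dim=-1),
    reduction="batchmean",
) * temperature**2

print(f"distillation loss: {loss.item():.4f}")
```

At a higher temperature, the softened teacher distribution carries more information about the relative ordering of non-peak activations, which is what the student learns to reproduce.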
## References

>- The UniProt Consortium. UniProt: the Universal Protein Knowledgebase in 2025. Nucleic Acids Research, 2025, 53, D609–D617.
>- T. Hayes, et al. Simulating 500 million years of evolution with a language model, 2024.
>- B. Zhang, et al. Root Mean Square Layer Normalization. 33rd Conference on Neural Information Processing Systems, NeurIPS 2019.
>- J. Ainslie, et al. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints, Google Research, 2023.
>- T. Kim, et al. Comparing Kullback-Leibler Divergence and Mean Squared Error Loss in Knowledge Distillation, 2021.