andrewdalpino committed on
Commit a4666ee · verified · 1 Parent(s): f025b65

Update README.md

Files changed (1)
  1. README.md +83 -7
README.md CHANGED
 
tags:
- esmc
---
# ESM ProtHash

A protein language model that outputs amino acid sequence embeddings for use in clustering, classification, locality-sensitive hashing, and more. Distilled from the [ESMC](https://www.evolutionaryscale.ai/blog/esm-cambrian) family of models with deep comprehension of protein structure, ProtHash produces contextual embeddings that align in vector space according to the sequences' atomic structure. Trained on the [SwissProt](https://huggingface.co/datasets/andrewdalpino/SwissProt-Gene-Ontology) dataset to mimic the activations of its ESMC teacher model, ProtHash embeddings have near-perfect similarity to ESMC embeddings but at a greatly reduced computational cost.

## Key Features
 
- **Structurally-relevant**: Structurally similar proteins land near one another in the embedding space, enabling downstream tasks such as clustering, classification, and locality-sensitive hashing based on atomic structure (see the similarity sketch that follows the code example below).

- **Compatible with ESMC**: ProtHash can output embeddings in its native or its ESMC teacher's dimensionality, allowing it to serve as either a faster drop-in approximation to ESMC embeddings or a more efficient compressed representation.

- **Quantization-ready**: With quantization-aware post-training, ProtHash allows you to quantize the weights of the model while maintaining similarity to the teacher's embedding space.

## Pretrained Models
 
| Name | Context Length | Embedding Dimensions | Attention Heads (Q/KV) | Encoder Layers | Total Params | Teacher Model | Teacher Dimensions |
|---|---|---|---|---|---|---|---|
| [andrewdalpino/ProtHash-384-Tiny](https://huggingface.co/andrewdalpino/ProtHash-384-Tiny) | 2048 | 384 | 16/4 | 4 | 5M | esmc_300m | 960 |
| [andrewdalpino/ProtHash-384](https://huggingface.co/andrewdalpino/ProtHash-384) | 2048 | 384 | 16/4 | 10 | 11M | esmc_300m | 960 |
 
```python
# ... (imports and tokenizer setup precede this excerpt)

model_name = "andrewdalpino/ProtHash-512-Tiny"

model = ProtHash.from_pretrained(model_name)

# Optionally quantize the weights.
model.quantize_weights()

sequence = input("Enter a sequence: ")

out = tokenizer(sequence, max_length=2048)

# ... (the token ids `tokens` are taken from the tokenizer output here)
x = torch.tensor(tokens, dtype=torch.int64).unsqueeze(0)

# Output the sequence embedding in native dimensionality.
y_embed_native = model.embed_native(x).squeeze(0)

# Output a drop-in replacement for the teacher's embeddings.
y_embed_teacher = model.embed_teacher(x).squeeze(0)

print(y_embed_native.shape)
print(y_embed_teacher.shape)
```
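
Because the embeddings are designed for clustering and locality-sensitive hashing, the sketch below shows one way to compare two sequences once their native embeddings have been computed as in the example above. These helper functions are illustrative only and are not part of the ProtHash API; the random-projection hash is a standard SimHash-style construction rather than anything shipped with the model.

```python
import torch
import torch.nn.functional as F

# Illustrative helpers; not part of the ProtHash package.
# `emb_a` and `emb_b` are native embeddings produced as shown above,
# e.g. emb_a = model.embed_native(x_a).squeeze(0).

def cosine_similarity(emb_a: torch.Tensor, emb_b: torch.Tensor) -> float:
    """Cosine similarity between two embedding vectors."""
    return F.cosine_similarity(emb_a.unsqueeze(0), emb_b.unsqueeze(0)).item()

def simhash_bucket(embedding: torch.Tensor, num_bits: int = 16, seed: int = 42) -> int:
    """SimHash-style bucket id taken from the signs of random projections."""
    generator = torch.Generator().manual_seed(seed)
    planes = torch.randn(num_bits, embedding.shape[-1], generator=generator)
    bits = (planes @ embedding > 0).tolist()
    code = 0
    for bit in bits:
        code = (code << 1) | int(bit)
    return code
```

Structurally similar proteins should score close to 1.0 under the cosine measure and will often land in the same hash bucket, which is the behavior the clustering and hashing claims above rely on.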

## Training

### Clone the Project Repo

We'll need the code from the project repository to train or fine-tune the model.

```sh
git clone https://github.com/andrewdalpino/ProtHash
cd ProtHash
```

### Install Project Dependencies

Project dependencies are specified in the `requirements.txt` file. You can install them with pip using the following commands from the project root. We recommend using a virtual environment such as `venv` to keep the package dependencies on your system tidy.

```sh
python -m venv ./.venv

source ./.venv/bin/activate

pip install -r requirements.txt
```

### Distilling

ProtHash is trained to mimic the activations of its ESMC teacher model. To begin distillation with the default arguments, see the example below.

```sh
python train.py
```

You can override the default arguments as in the example below.

```sh
python train.py --teacher_name="esmc_300m" --max_steps=3500 --embedding_dimensions=256
```

#### Training Dashboard

We use [TensorBoard](https://www.tensorflow.org/tensorboard) to capture and display training events such as loss and gradient norm updates. To launch the dashboard server, run the following command from the terminal.

```sh
tensorboard --logdir=./runs
```

Then navigate to the dashboard (served at http://localhost:6006 by default) in your favorite web browser.

#### Training Arguments

| Argument | Default | Type | Description |
|---|---|---|---|
| --teacher_name | 'esmc_600m' | str | The teacher model name. |
| --num_dataset_processes | 1 | int | The number of CPU processes used to preprocess the dataset. |
| --min_sequence_length | 1 | int | The minimum length of the input sequences. |
| --max_sequence_length | 2048 | int | The maximum length of the input sequences. |
| --quantization_aware_training | False | bool | Should we add fake-quantized tensors to simulate quantized training? |
| --batch_size | 4 | int | The number of training sequences to pass through the network at a time. |
| --gradient_accumulation_steps | 32 | int | The number of batches to pass through the network before updating the model weights. |
| --max_steps | 4200 | int | The number of steps to train for. |
| --learning_rate | 1e-4 | float | The learning rate of the AdamW optimizer. |
| --max_gradient_norm | 100.0 | float | Clip gradients above this threshold norm before stepping. |
| --temperature | 8.0 | float | The smoothing parameter of the activations; higher temperatures result in smoother activations. |
| --embedding_dimensions | 512 | int | The dimensionality of the native embeddings. |
| --q_heads | 16 | int | The number of query heads used in the self-attention layers. |
| --kv_heads | 4 | int | The number of key and value heads used in the self-attention layers. |
| --hidden_ratio | 2 | (1, 2, 4) | The ratio of hidden neurons to embedding dimensions in the feed-forward layers of the network. |
| --num_encoder_layers | 4 | int | The number of layers within the body of the encoder. |
| --dropout | 0.0 | float | The proportion of activations to set to zero during training as regularization. |
| --activation_checkpointing | False | bool | Should we use activation checkpointing? This will drastically reduce memory utilization during training at the cost of recomputing the forward pass. |
| --eval_interval | 100 | int | The number of epochs between evaluations on the testing set. |
| --checkpoint_interval | 100 | int | The number of epochs between model checkpoint saves to disk. |
| --checkpoint_path | "./checkpoints/checkpoint.pt" | str | The path to the base checkpoint file on disk. |
| --resume | False | bool | Should we resume training from the last checkpoint? |
| --run_dir_path | "./runs" | str | The path to the TensorBoard run directory for this training session. |
| --device | "cpu" | str | The device to run the computation on. |
| --seed | None | int | The seed for the random number generator. |
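
As a worked example combining several of these arguments, a GPU run with a custom checkpoint location and a fixed seed might be launched as follows. The flag names come from the table above, but the specific values are illustrative rather than recommendations.

```sh
python train.py --device="cuda" --batch_size=8 --gradient_accumulation_steps=16 \
    --checkpoint_path="./checkpoints/prothash-512.pt" --seed=42
```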

## References

>- The UniProt Consortium, UniProt: the Universal Protein Knowledgebase in 2025, Nucleic Acids Research, 2025, 53, D609–D617.