Confusion about ESM2-650M's Output Layer: Independent Decoder vs. Weight Tying

#1
by acow - opened

Hello,

I have a question regarding the architecture of ESM2-650M when fine-tuning for logits output.

When I load the model using model_path = 'Synthyra/ESM2-650M', I can see a dedicated layer, lm_head.decoder.weight, with shape torch.Size([33, 1280]), which projects the 1280-dimensional hidden state to 33-dimensional logits.

However, when I load the official model_path = 'facebook/esm2_t33_650M_UR50D', there is no such dedicated layer for this projection. My research indicates that for the official Facebook version, the weights for the final logits output are tied to the input vocabulary embeddings (esm.embeddings.word_embeddings.weight).

This seems contradictory. Could you clarify this? For the final logits projection, is there a separate, independent decoder, or does it reuse the weights from the initial vocabulary embedding layer?
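
For reference, a minimal sketch of this comparison (assumptions: both checkpoints load via AutoModelForMaskedLM, and trust_remote_code=True is only needed for the Synthyra version):

```python
from transformers import AutoModelForMaskedLM

for path in ["Synthyra/ESM2-650M", "facebook/esm2_t33_650M_UR50D"]:
    model = AutoModelForMaskedLM.from_pretrained(path, trust_remote_code=True)
    # List every parameter whose name mentions the LM head or word embeddings.
    # Note: named_parameters() deduplicates shared tensors, so a tied decoder
    # weight shows up only once, under the first name encountered (usually
    # the embedding's name).
    for name, p in model.named_parameters():
        if "lm_head" in name or "word_embeddings" in name:
            print(path, name, tuple(p.shape))
```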

Synthyra org

Hello @acow ,

You are correct: ESM2-650M, whether the original or our version, has tied weights, so the embedding layer and the final projection share the same tensor during training. This is called weight tying, and it is very common when training BERT-style models. There are various advantages and disadvantages to doing this; there is some literature on the topic. The original release of ESM2 probably used native PyTorch weight saving, which would store the tied weight under separate tensor names. Nowadays, Hugging Face uses safetensors, which avoids keeping duplicate weights.
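
You can see this deduplication directly in the serialized checkpoint. A quick sketch, assuming the repo ships a model.safetensors file:

```python
from huggingface_hub import hf_hub_download
from safetensors import safe_open

path = hf_hub_download("facebook/esm2_t33_650M_UR50D", "model.safetensors")
with safe_open(path, framework="pt") as f:
    # Only one copy of the tied weight should appear among the keys.
    for key in f.keys():
        if "lm_head" in key or "word_embeddings" in key:
            print(key)
```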

So, for clarity, there is a decoder that maps the last hidden state to the logits. That decoder happens to share weights with the token embedding layer.
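
A quick way to confirm the tie at runtime (a sketch using the standard transformers accessors):

```python
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t33_650M_UR50D")
emb = model.get_input_embeddings().weight   # token embedding matrix
dec = model.get_output_embeddings().weight  # final logits projection
# Same storage pointer means the two modules share one tensor, i.e. tied.
print(emb.data_ptr() == dec.data_ptr())  # expected: True
```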

Hope this helps.

  • Logan
