Lower results?

#1
by GrimSqueaker - opened

Hi!

  1. What was changed in the recent model/FastPLMs updates? The use of normal attention rather than just flex attention? Anything else?

  2. In the past I made a bunch of hacks to run the models with the HF Trainer (AutoModel got conniptions when trying to load by default for finetuning on protein classification). I don't know if the changes in (1) mean I no longer need the custom hacks (last hidden layer etc.), but I found much lower eval results on the same tasks using the Synthyra model vs the bog-standard ESM2 model. I'm probably missing something, but are there any changes to be aware of? (I also have oddly bad results with ESMC, but I didn't run their baseline model to compare, due to the faff with their stupid package.)
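For context, the kind of hack I mean is pooling the last hidden state into a per-sequence feature for the classification head. A minimal sketch of masked mean-pooling (illustrative only, not my exact code; shapes are made up):

```python
import torch

def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings per sequence, ignoring padding positions."""
    mask = attention_mask.unsqueeze(-1).float()      # (batch, seq, 1)
    summed = (last_hidden_state * mask).sum(dim=1)   # zero out padding, sum over tokens
    counts = mask.sum(dim=1).clamp(min=1e-9)         # number of real tokens per sequence
    return summed / counts                           # (batch, hidden)

# Toy tensors standing in for a model's last hidden state + attention mask
hidden = torch.randn(2, 5, 8)
attn = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
pooled = mean_pool(hidden, attn)
print(pooled.shape)  # torch.Size([2, 8])
```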

Thanks!

Synthyra org

Hey @GrimSqueaker ,

Thanks for your message.

We are in the process of updating the FastPLMs package this week and are working through some bugs. I've confirmed the weights are correct but something is currently off with the implementation - the hidden states do not match.

I suspect this will be fixed today, and you will be able to use equivalent ESM2, ESMC, ANKH, DPLM, DPLM2, and Boltz2 models efficiently, with nice-to-have functions, very soon.

Best,
Logan

Synthyra org

Hey @GrimSqueaker ,

Turns out the official ESM2 implementation in Huggingface Transformers has been broken for a bit, see https://github.com/huggingface/transformers/issues/44162.
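From what I can tell, the practical effect is that the attention_mask gets dropped on the SDPA path, so padding tokens leak into the attention of real tokens whenever batches are padded. A toy illustration in plain torch (hypothetical shapes, not the transformers code itself):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# (batch, heads, seq, head_dim): a 4-token sequence where the last 2 are padding
q = torch.randn(1, 1, 4, 8)
k = torch.randn(1, 1, 4, 8)
v = torch.randn(1, 1, 4, 8)

# Boolean mask: True = attend. Broadcast to (batch, 1, 1, seq).
mask = torch.tensor([[True, True, False, False]])
attn_mask = mask[:, None, None, :]

with_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
without_mask = F.scaled_dot_product_attention(q, k, v)

# If the mask is dropped, the outputs at the *real* token positions change,
# because padding keys/values now contribute to the softmax.
print(torch.allclose(with_mask[..., :2, :], without_mask[..., :2, :]))
```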

The FastPLMs implementation of ESM2 should be correct, and the weights match perfectly. Here's a sanity check I ran in addition to our standard tests:

```python
import torch
import random
from torch.nn.functional import mse_loss
from tqdm import tqdm
from transformers import EsmForMaskedLM, EsmTokenizer, AutoModelForMaskedLM

CANONICAL_AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
TEST_NUMBER_BATCHES = 10
BATCH_SIZE = 4
MIN_SEQUENCE_LENGTH = 16
MAX_SEQUENCE_LENGTH = 64
OFFICIAL_MODEL_PATH = "facebook/esm2_t6_8M_UR50D"
FAST_MODEL_PATH = "Synthyra/ESM2-8M"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

official_model = EsmForMaskedLM.from_pretrained(
    OFFICIAL_MODEL_PATH,
    dtype=torch.float32,
    device_map=DEVICE,
    attn_implementation="sdpa",
    position_embedding_type="rotary",
    force_download=True
).eval()

fast_model = AutoModelForMaskedLM.from_pretrained(
    FAST_MODEL_PATH,
    dtype=torch.float32,
    device_map=DEVICE,
    trust_remote_code=True,
    force_download=True
).eval()
fast_model.attn_backend = "sdpa"

# Compare weights parameter by parameter; every MSE should be exactly 0.0
for (official_name, official_param), (fast_name, fast_param) in zip(official_model.state_dict().items(), fast_model.state_dict().items()):
    if official_name == fast_name:
        diff = mse_loss(official_param, fast_param).item()
        print(f"{official_name}: {diff}")
    else:
        print(f"Name mismatch: {official_name} != {fast_name}")

tokenizer = EsmTokenizer.from_pretrained(OFFICIAL_MODEL_PATH)

def generate_random_sequence(length: int) -> str:
    return 'M' + "".join(random.choices(CANONICAL_AMINO_ACIDS, k=length))

def generate_random_batch(batch_size: int, min_length: int, max_length: int) -> list[str]:
    return [generate_random_sequence(random.randint(min_length, max_length)) for _ in range(batch_size)]

cumulative_last_hidden_state_mse = 0.0
cumulative_logits_mse = 0.0
cumulative_preds_accuracy = 0.0

with torch.inference_mode():
    for _ in tqdm(range(TEST_NUMBER_BATCHES)):
        # Varying lengths per batch so padding (and thus the attention_mask) is exercised
        batch = generate_random_batch(BATCH_SIZE, MIN_SEQUENCE_LENGTH, MAX_SEQUENCE_LENGTH)
        tokenized = tokenizer(batch, return_tensors="pt", padding=True)
        tokenized = {k: v.to(DEVICE) for k, v in tokenized.items()}

        official_output = official_model(**tokenized, output_hidden_states=True)
        official_last_hidden_state = official_output.hidden_states[-1].detach().cpu()
        official_logits = official_output.logits.detach().cpu()
        official_preds = official_logits.argmax(dim=-1)

        fast_output = fast_model(**tokenized, output_hidden_states=True)
        fast_last_hidden_state = fast_output.hidden_states[-1].detach().cpu()
        fast_logits = fast_output.logits.detach().cpu()
        fast_preds = fast_logits.argmax(dim=-1)

        #assert torch.allclose(official_last_hidden_state, fast_last_hidden_state, atol=1e-3), "Last hidden state mismatch"
        #assert torch.allclose(official_logits, fast_logits, atol=1e-3), "Logits mismatch"
        #assert torch.equal(official_preds, fast_preds), "Preds mismatch"

        cumulative_last_hidden_state_mse += mse_loss(official_last_hidden_state, fast_last_hidden_state).item()
        cumulative_logits_mse += mse_loss(official_logits, fast_logits).item()
        cumulative_preds_accuracy += (official_preds == fast_preds).float().mean().item()

print(f"Average last hidden state MSE: {cumulative_last_hidden_state_mse / TEST_NUMBER_BATCHES}")
print(f"Average logits MSE: {cumulative_logits_mse / TEST_NUMBER_BATCHES}")
print(f"Average preds accuracy: {cumulative_preds_accuracy / TEST_NUMBER_BATCHES}")
```

With the attention_mask issue addressed:

```
esm.embeddings.word_embeddings.weight: 0.0
esm.encoder.layer.0.attention.self.query.weight: 0.0
esm.encoder.layer.0.attention.self.query.bias: 0.0
esm.encoder.layer.0.attention.self.key.weight: 0.0
esm.encoder.layer.0.attention.self.key.bias: 0.0
esm.encoder.layer.0.attention.self.value.weight: 0.0
esm.encoder.layer.0.attention.self.value.bias: 0.0
esm.encoder.layer.0.attention.self.rotary_embeddings.inv_freq: 0.0
esm.encoder.layer.0.attention.output.dense.weight: 0.0
esm.encoder.layer.0.attention.output.dense.bias: 0.0
etc...
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 10/10 [00:00<00:00, 29.43it/s]
Average last hidden state MSE: 0.0
Average logits MSE: 0.0
Average preds accuracy: 1.0
```

Compared to the current transformers version (5.2.0):

```
esm.embeddings.word_embeddings.weight: 0.0
esm.encoder.layer.0.attention.self.query.weight: 0.0
esm.encoder.layer.0.attention.self.query.bias: 0.0
esm.encoder.layer.0.attention.self.key.weight: 0.0
esm.encoder.layer.0.attention.self.key.bias: 0.0
esm.encoder.layer.0.attention.self.value.weight: 0.0
esm.encoder.layer.0.attention.self.value.bias: 0.0
esm.encoder.layer.0.attention.self.rotary_embeddings.inv_freq: 0.0
esm.encoder.layer.0.attention.output.dense.weight: 0.0
esm.encoder.layer.0.attention.output.dense.bias: 0.0
etc...
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 10/10 [00:00<00:00, 29.43it/s]
Average last hidden state MSE: 0.01390768587589264
Average logits MSE: 2.558210611343384
Average preds accuracy: 0.76673823595047
```

We will continue improving FastPLMs and contributing fixes upstream to transformers. Please let me know if you have any other questions or run into any issues.
Best,
Logan

Thanks!
(So I shouldn't wait for further changes in your versions?)
