"""Load official ESM2 model from HuggingFace transformers for comparison.""" import torch import torch.nn as nn from transformers import EsmForMaskedLM, EsmTokenizer class _OfficialESM2ForwardWrapper(nn.Module): def __init__(self, model: EsmForMaskedLM): super().__init__() self.model = model self.tokenizer = EsmTokenizer.from_pretrained(model.config._name_or_path) def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs): outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, ) return outputs def load_official_model( reference_repo_id: str, device: torch.device, dtype: torch.dtype = torch.float32, ) -> tuple[nn.Module, EsmTokenizer]: """Load the official HuggingFace ESM2 model. Returns (wrapped_model, tokenizer). The wrapped model's forward returns standard HF outputs with hidden_states. """ model = EsmForMaskedLM.from_pretrained( reference_repo_id, device_map=device, dtype=dtype, attn_implementation="sdpa", position_embedding_type="rotary", ).eval() tokenizer = EsmTokenizer.from_pretrained(reference_repo_id) wrapped = _OfficialESM2ForwardWrapper(model) return wrapped, tokenizer if __name__ == "__main__": model, tokenizer = load_official_model(reference_repo_id="facebook/esm2_t6_8M_UR50D", device=torch.device("cuda"), dtype=torch.float32) print(model) print(tokenizer)