Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import pytest | |
| from pathlib import Path | |
| import torch | |
| import esm | |
| # Directly from hubconf.py | |
| model_names = """ | |
| esm1_t6_43M_UR50S, | |
| esm1_t12_85M_UR50S, | |
| esm1_t34_670M_UR50S, | |
| esm1_t34_670M_UR50D, | |
| esm1_t34_670M_UR100, | |
| esm1b_t33_650M_UR50S, | |
| esm_msa1_t12_100M_UR50S, | |
| esm_msa1b_t12_100M_UR50S, | |
| esm1v_t33_650M_UR90S, | |
| esm1v_t33_650M_UR90S_1, | |
| esm1v_t33_650M_UR90S_2, | |
| esm1v_t33_650M_UR90S_3, | |
| esm1v_t33_650M_UR90S_4, | |
| esm1v_t33_650M_UR90S_5, | |
| esm_if1_gvp4_t16_142M_UR50, | |
| esm2_t6_8M_UR50D, | |
| esm2_t12_35M_UR50D, | |
| esm2_t30_150M_UR50D, | |
| esm2_t33_650M_UR50D, | |
| esm2_t36_3B_UR50D, | |
| esm2_t48_15B_UR50D | |
| """ | |
| model_names = [mn.strip() for mn in model_names.strip(" ,\n").split(",")] | |
| def test_load_hub_fwd_model(model_name: str) -> None: | |
| model, alphabet = getattr(esm.pretrained, model_name)() | |
| # batch_size = 2, seq_len = 3, tokens within vocab | |
| dummy_inp = torch.tensor([[0, 1, 2], [3, 4, 5]]) | |
| if "esm_msa" in model_name: | |
| dummy_inp = dummy_inp.unsqueeze(0) | |
| output = model(dummy_inp) # dict | |
| logits = output["logits"].squeeze(0) | |
| assert logits.shape == (2, 3, len(alphabet)) | |
| def test_load_local(model_name: str) -> None: | |
| # Assumes everything has already been loaded & cached. | |
| local_path = Path.home() / ".cache/torch/hub/checkpoints" / (model_name + ".pt") | |
| if model_name.endswith("esm1v_t33_650M_UR90S"): | |
| return # skip; needs to get rerouted to specific instance | |
| model, alphabet = esm.pretrained.load_model_and_alphabet_local(local_path) | |