# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import pytest
from pathlib import Path
import torch
import esm
# Names of the pretrained entry points, mirroring the list in hubconf.py.
# Each name is looked up on ``esm.pretrained`` by the tests below and called
# with no arguments to obtain a ``(model, alphabet)`` pair.
model_names = [
    "esm1_t6_43M_UR50S",
    "esm1_t12_85M_UR50S",
    "esm1_t34_670M_UR50S",
    "esm1_t34_670M_UR50D",
    "esm1_t34_670M_UR100",
    "esm1b_t33_650M_UR50S",
    "esm_msa1_t12_100M_UR50S",
    "esm_msa1b_t12_100M_UR50S",
    "esm1v_t33_650M_UR90S",
    "esm1v_t33_650M_UR90S_1",
    "esm1v_t33_650M_UR90S_2",
    "esm1v_t33_650M_UR90S_3",
    "esm1v_t33_650M_UR90S_4",
    "esm1v_t33_650M_UR90S_5",
    "esm_if1_gvp4_t16_142M_UR50",
    "esm2_t6_8M_UR50D",
    "esm2_t12_35M_UR50D",
    "esm2_t30_150M_UR50D",
    "esm2_t33_650M_UR50D",
    "esm2_t36_3B_UR50D",
    "esm2_t48_15B_UR50D",
]
@pytest.mark.parametrize("model_name", model_names)
def test_load_hub_fwd_model(model_name: str) -> None:
    """Load a hub model and check that a forward pass yields per-token logits.

    The model is built through its ``esm.pretrained`` entry point, fed a tiny
    token batch, and the ``"logits"`` entry of the output dict is checked to
    have shape ``(2, 3, len(alphabet))``.

    Args:
        model_name: Name of a zero-argument constructor on ``esm.pretrained``.
    """
    # Always load, even for models whose forward pass is skipped below: this
    # exercises the hub entry point and populates the checkpoint cache that
    # ``test_load_local`` reads from.
    model, alphabet = getattr(esm.pretrained, model_name)()
    if model_name.startswith("esm_if"):
        # The inverse-folding model is conditioned on backbone structure; its
        # forward() takes coordinates/padding mask/confidence/previous output
        # tokens rather than a bare token tensor, so the token-only check
        # below does not apply.
        pytest.skip("inverse-folding model does not accept a token-only input")
    model.eval()  # disable dropout; only the output shape is checked
    # batch_size = 2, seq_len = 3, tokens within vocab
    dummy_inp = torch.tensor([[0, 1, 2], [3, 4, 5]])
    if "esm_msa" in model_name:
        # MSA models take an extra leading dimension; the two rows become one
        # alignment of depth 2, and the leading dim is squeezed off again below.
        dummy_inp = dummy_inp.unsqueeze(0)
    with torch.no_grad():  # no autograd graph needed for a shape check
        output = model(dummy_inp)  # dict
    logits = output["logits"].squeeze(0)
    assert logits.shape == (2, 3, len(alphabet))
@pytest.mark.parametrize("model_name", model_names)
def test_load_local(model_name: str) -> None:
    """Reload a model from its cached checkpoint file via the local loader.

    Assumes every checkpoint has already been downloaded and cached (e.g. by
    ``test_load_hub_fwd_model``); a missing file makes this test fail.

    Args:
        model_name: Entry-point name; the checkpoint is expected at
            ``<torch hub dir>/checkpoints/<model_name>.pt``.
    """
    if model_name == "esm1v_t33_650M_UR90S":
        # This entry point needs to be rerouted to a specific numbered
        # instance (_1 … _5), so it is not loaded from a file of its own name.
        # Report it as skipped rather than silently passing.
        pytest.skip("alias entry point; needs to get rerouted to specific instance")
    # torch.hub.get_dir() honours $TORCH_HOME / $XDG_CACHE_HOME and defaults to
    # ~/.cache/torch/hub, the location previously hard-coded here.
    local_path = Path(torch.hub.get_dir()) / "checkpoints" / f"{model_name}.pt"
    model, alphabet = esm.pretrained.load_model_and_alphabet_local(local_path)
    # Minimal sanity checks so the loaded objects are actually exercised.
    assert isinstance(model, torch.nn.Module)
    assert len(alphabet) > 0
|