wnagleiofficial
First model version
38adcf4
raw
history blame
1.75 kB
import torch
import esm
from argparse import Namespace
import pathlib
import urllib
def length_to_mask(length, max_len=None, dtype=None):
    """Build a padding mask from a 1-D tensor of sequence lengths.

    Args:
        length: 1-D tensor of shape (B,) holding each sequence's length.
        max_len: number of mask columns. If None, ``length.max()`` is used
            (0 when ``length`` is empty). An explicit ``0`` is honoured and
            yields a (B, 0) mask.
        dtype: optional dtype for the returned mask. When None, the mask is
            the bool result of the comparison.

    Returns:
        Tensor of shape (B, max_len) on ``length.device`` whose entry
        ``[b, t]`` is True (or 1 after casting) iff ``t < length[b]``.
    """
    assert len(length.shape) == 1, 'Length shape should be 1 dimensional.'
    if max_len is None:
        # Tensor.max() raises on an empty tensor, so an empty batch gets width 0.
        max_len = int(length.max().item()) if length.numel() > 0 else 0
    positions = torch.arange(max_len, device=length.device, dtype=length.dtype)
    # Broadcast (1, max_len) against (B, 1) -> (B, max_len).
    mask = positions.unsqueeze(0) < length.unsqueeze(1)
    if dtype is not None:
        mask = mask.to(dtype=dtype)
    return mask
def load_model_and_alphabet_core(args_dict, regression_data=None):
    """Construct an ``esm.ProteinBertModel`` from the args stored in a checkpoint.

    Args:
        args_dict: anything ``torch.load`` accepts (a path or file-like
            object). The loaded object must map ``"args"`` to an object that
            has an ``arch`` attribute and works with ``vars()`` (presumably an
            ``argparse.Namespace``; confirm against the checkpoint producer).
        regression_data: accepted for signature compatibility; not used here.

    Returns:
        ``(model, alphabet)``: a freshly constructed model and the alphabet
        for ``args.arch``.

    NOTE(review): no weights from the checkpoint are loaded into the returned
    model. Presumably the caller loads a state dict afterwards; confirm.
    """
    # NOTE(review): torch.load unpickles arbitrary objects, so only pass
    # trusted checkpoints. On torch versions that default to
    # weights_only=True, loading a pickled args object may additionally
    # require weights_only=False.
    # map_location="cpu": only the stored args are used, so there is no reason
    # to restore saved tensors onto the device they were saved from (which
    # fails on CPU-only hosts). This matches load_hub_workaround.
    checkpoint = torch.load(args_dict, map_location="cpu")
    stored_args = checkpoint["args"]
    alphabet = esm.Alphabet.from_architecture(stored_args.arch)

    def _strip_decoder_prefix(name):
        # Keep only the text after "decoder_" (e.g. "decoder_layers" -> "layers").
        # Names without "decoder" pass through unchanged. A name containing
        # "decoder" but not "decoder_" maps to "".
        if "decoder" not in name:
            return name
        return "".join(name.split("decoder_")[1:])

    model_args = {
        _strip_decoder_prefix(key): value
        for key, value in vars(stored_args).items()
    }
    model = esm.ProteinBertModel(
        Namespace(**model_args),
        alphabet,
    )
    return model, alphabet
def load_hub_workaround(url):
    """Load a checkpoint from ``url`` via torch.hub, onto the CPU.

    If ``torch.hub.load_state_dict_from_url`` raises ``RuntimeError`` (see
    https://github.com/pytorch/pytorch/issues/43106), the copy already
    downloaded to ``<torch.hub.get_dir()>/checkpoints/<filename>`` is loaded
    directly instead.

    Args:
        url: URL of the checkpoint file.

    Returns:
        Whatever object the checkpoint deserializes to.

    Raises:
        Exception: if the download fails with an HTTP error; the original
            ``urllib.error.HTTPError`` is chained as ``__cause__``.
        FileNotFoundError: from the fallback path when no cached copy exists.
    """
    # Import the submodules explicitly: a bare `import urllib` does not
    # guarantee `urllib.error` / `urllib.parse` are loaded.
    import os
    import urllib.error
    import urllib.parse

    try:
        data = torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu")
    except RuntimeError:
        # Pytorch version issue - see https://github.com/pytorch/pytorch/issues/43106
        # Derive the cache filename the same way torch.hub does (basename of
        # the URL path, so query strings / fragments are ignored).
        fn = os.path.basename(urllib.parse.urlparse(url).path)
        cached_file = pathlib.Path(torch.hub.get_dir()) / "checkpoints" / fn
        data = torch.load(cached_file, map_location="cpu")
    except urllib.error.HTTPError as e:
        raise Exception(f"Could not load {url}, check your network!") from e
    return data