File size: 789 Bytes
38adcf4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import functools

import torch

from .model import EsmModel
from .utils import load_hub_workaround
# Location of the pretrained checkpoint; predict() passes it to
# load_hub_workaround() to obtain the state dict loaded into EsmModel.
MODEL_URL = "https://zenodo.org/record/7042286/files/model.pth"
@functools.lru_cache(maxsize=None)
def _load_model(device):
    """Build ``EsmModel``, load the ``MODEL_URL`` checkpoint, and move it to *device*.

    Cached per ``device`` (which must be hashable, e.g. ``'cpu'`` or a
    ``torch.device``) so repeated ``predict`` calls reuse one model instead of
    re-constructing it and calling ``load_hub_workaround`` every time.
    """
    model = EsmModel()
    model.load_state_dict(load_hub_workaround(MODEL_URL))
    model.eval()
    return model.to(device)


def predict(peptide_list, device='cpu'):
    """Run the pretrained ``EsmModel`` over a batch of peptides.

    Args:
        peptide_list: Iterable of records where ``record[0]`` is used as the
            output key and ``len(record[1])`` bounds the returned attention.
            Presumably ``(name, sequence)`` pairs as expected by
            ``EsmModel.forward`` -- TODO confirm against the model code.
        device: Device the model is moved to and which is also passed to the
            model's forward call (e.g. ``'cpu'``).

    Returns:
        dict mapping ``record[0]`` to ``[pred, attention]``:
        ``pred`` is the index of the maximum along the last dimension of the
        model's first output (presumably class probabilities), and
        ``attention`` is that record's slice of the model's second output as a
        NumPy array, with its second axis truncated to ``len(record[1])``.
        If two records share the same ``record[0]``, only the last one is kept.
        An empty input returns ``{}`` without loading the model.
    """
    # Materialize once: the records are iterated twice (model call and zip),
    # which would silently produce an empty result for a one-shot iterator.
    peptide_list = list(peptide_list)
    if not peptide_list:
        return {}

    model = _load_model(device)
    with torch.no_grad():
        prob, att = model(peptide_list, device)

    preds = torch.argmax(prob, dim=-1).cpu().tolist()
    att = att.cpu().numpy()
    return {
        record[0]: [pred, attn[:, :len(record[1])]]
        for record, pred, attn in zip(peptide_list, preds, att)
    }
|