DeepPD-hf / DeepPD /predictor.py
xiaoleon's picture
initial submission
46b9840
from DeepPD.model import MyModel,DeepPD
import torch
import torch.nn as nn
from DeepPD.config import ArgsConfig
# Project configuration object; predict() reads model hyper-parameters and
# its default device from it.
args = ArgsConfig()
# Softmax over dimension 1, applied to the model output in predict().
softmax = nn.Softmax(1)
def predict(seqs, data, model_path, threshold=0.5, device=args.device):
    """Run a DeepPD checkpoint on ``seqs`` and label every record in ``data``.

    Args:
        seqs: Batch passed to the model; must support ``.to(device)``.
            Presumably the encoded form of the records in ``data`` and in
            the same order -- confirm against the caller.
        data: Iterable of indexable records. Items ``[0]`` and ``[1]`` of
            each record are copied into the corresponding output row.
        model_path: Checkpoint location handed to ``torch.load``.
        threshold: A record is labelled ``'Peptide'`` when its class-1
            probability is strictly greater than this value.
        device: Device the model and ``seqs`` are placed on. Defaults to
            the device from the project configuration.

    Returns:
        A list with one row per ``(record, probability)`` pair:
        ``[record[0], record[1], "<prob with 3 decimals>", label]`` where
        ``label`` is ``'Peptide'`` or ``'Non-Peptide'``. Pairing uses
        ``zip``, so the shorter of ``data`` and the model output decides
        the number of rows.
    """
    with torch.no_grad():
        # Place the model on the requested device straight away; going
        # through args.device first would ignore an explicit `device`
        # override (and can fail if the configured device is unavailable).
        model = DeepPD(
            vocab_size=21,
            embedding_size=args.embedding_size,
            esm_path=args.ems_path,
            layer_idx=args.esm_layer_idx,
            seq_len=args.max_len,
            dropout=args.dropout,
            fan_layer_num=1,
            num_heads=8,
            encoder_layer_num=1,
            Contrastive_Learning=False,
            info_bottleneck=args.info_bottleneck,
        ).to(device)
        model.eval()

        # NOTE(review): torch.load unpickles the file -- only load
        # checkpoints that come from a trusted source.
        state_dict = torch.load(model_path, map_location=device)
        # strict=False: checkpoint keys that do not match the model are
        # skipped instead of raising.
        model.load_state_dict(state_dict, strict=False)

        # The model returns three values; only the first is used here.
        out, _, _ = model(seqs.to(device))
        # Column 1 of the softmax is the score reported for each record.
        # Convert to plain Python floats once rather than formatting and
        # comparing 0-dim tensors one by one.
        probs = softmax(out)[:, 1].tolist()

    return [
        [
            record[0],
            record[1],
            f"{p:.3f}",
            'Peptide' if p > threshold else 'Non-Peptide',
        ]
        for record, p in zip(data, probs)
    ]