# DeepMFPP-hf / DeepMFPP / predictor.py
# (initial submission, commit 2d48951 — header reconstructed from the
# Hugging Face file-viewer residue so the module parses as Python)
from DeepMFPP.model import DeepMFPP
import torch
import torch.nn as nn
import numpy as np
from DeepMFPP.config import ArgsConfig
args = ArgsConfig()
args.embedding_size = 480
args.aa_dict = 'esm'
args.loss_fn_name = 'MLFDL'
args.weight_decay = 0
args.batch_size = 192
args.dropout = 0.62
args.scale_factor = 100
args.fldl_pos_weight = 0.4
sigmoid = nn.Sigmoid()
def predict(seqs:torch.Tensor,data:list,model_path:str, top_k:int=0,threshold:float=0.5, device=args.device):
torch.manual_seed(args.random_seed)
with torch.no_grad():
model = DeepMFPP(vocab_size=21,embedding_size=args.embedding_size, encoder_layer_num=1, fan_layer_num=1, num_heads=8,output_size=args.num_classes,
esm_path=args.ems_path,layer_idx=args.esm_layer_idx,dropout=args.dropout,Contrastive_Learning=args.ctl).to(args.device)
model.eval()
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict,strict=False)
model.to(args.device)
# print(device)
seqs.to(args.device)
_, logits = model(seqs)
prob = sigmoid(logits)
# logits = np.round(logits.cpu().numpy(),3)
# prob = np.round(prob.cpu().numpy(),3)
# logits = logits.cpu().numpy()
prob = prob.cpu().numpy()
# print(logits)
# print(prob)
categories = ['AAP', 'ABP', 'ACP', 'ACVP','ADP', 'AEP', 'AFP', 'AHIVP', 'AHP', 'AIP', 'AMRSAP',
'APP', 'ATP', 'AVP', 'BBP', 'BIP', 'CPP', 'DPPIP', 'QSP', 'SBP', 'THP']
final_out = []
for i, j, k in zip(data, logits, prob):
temp = [i[0], i[1]] # , f"logits:{j}", f"probability:{k}"
# 过滤概率值大于阈值的预测结果
result_dict = {}
for label, p in zip(categories, k):
# print(p)
if p > threshold:
result_dict[label] = round(float(p), 4)
# 返回概率值大于阈值的字典对
# 示例: {'AVP': 0.567, 'ATP': 0.678, ...}
if result_dict:
sorted_result = {k: v for k, v in sorted(result_dict.items(), key=lambda item: item[1], reverse=True)}
else:
sorted_result = {}
# print(sorted_result)
if top_k:
sorted_items_list = list(sorted_result.items())
top_k_result = dict(sorted_items_list[:top_k])
top_k_result_str = ", ".join(f"{key}: {value}" for key, value in top_k_result.items())
temp.extend([top_k_result_str])
else:
sorted_result_str = ", ".join(f"{key}: {value}" for key, value in sorted_result.items())
temp.extend([sorted_result_str])
final_out.append(temp)
return final_out