from DeepMFPP.model import DeepMFPP
import torch
import torch.nn as nn
import numpy as np
from DeepMFPP.config import ArgsConfig

# Inference-time hyperparameters layered on top of ArgsConfig defaults.
# NOTE(review): presumably these mirror the configuration the checkpoint was
# trained with -- confirm before changing any of them.
args = ArgsConfig()
args.embedding_size = 480
args.aa_dict = 'esm'
args.loss_fn_name = 'MLFDL'
args.weight_decay = 0
args.batch_size = 192
args.dropout = 0.62
args.scale_factor = 100
args.fldl_pos_weight = 0.4

sigmoid = nn.Sigmoid()

# Label names, one per model output column, in column order.
# NOTE(review): assumed to have length args.num_classes -- TODO confirm.
_CATEGORIES = ['AAP', 'ABP', 'ACP', 'ACVP', 'ADP', 'AEP', 'AFP', 'AHIVP', 'AHP',
               'AIP', 'AMRSAP', 'APP', 'ATP', 'AVP', 'BBP', 'BIP', 'CPP',
               'DPPIP', 'QSP', 'SBP', 'THP']


def _format_predictions(data, prob, categories, threshold, top_k):
    """Build one output row per input record from a probability matrix.

    Each row is ``[record[0], record[1], labels_str]``. ``labels_str`` lists
    every label whose probability is strictly greater than ``threshold``, as
    ``"LABEL: p"`` pairs joined by ``", "``. Probabilities are rounded to 4
    decimals and the pairs are sorted by probability, highest first. When
    ``top_k`` is truthy, only the first ``top_k`` pairs are kept.
    """
    rows = []
    for record, row_prob in zip(data, prob):
        hits = {label: round(float(p), 4)
                for label, p in zip(categories, row_prob) if p > threshold}
        # sorted() is stable, so equal probabilities keep category order.
        ranked = sorted(hits.items(), key=lambda kv: kv[1], reverse=True)
        if top_k:
            ranked = ranked[:top_k]
        labels_str = ", ".join(f"{label}: {value}" for label, value in ranked)
        rows.append([record[0], record[1], labels_str])
    return rows


def predict(seqs: torch.Tensor, data: list, model_path: str,
            top_k: int = 0, threshold: float = 0.5, device=args.device) -> list:
    """Load a DeepMFPP checkpoint and predict multi-label peptide functions.

    Args:
        seqs: Encoded input batch passed directly to the model (presumably
            token ids of shape (batch, seq_len) -- confirm against the caller).
        data: One record per row of ``seqs``. Only ``record[0]`` and
            ``record[1]`` are read (presumably an id and the raw sequence).
        model_path: Path to a saved ``state_dict`` for DeepMFPP.
        top_k: If non-zero, keep only the ``top_k`` highest-probability labels
            per record. ``0`` keeps every label above ``threshold``.
        threshold: A label is reported only if its sigmoid probability is
            strictly greater than this value.
        device: Device that both the model and ``seqs`` are placed on.

    Returns:
        A list of ``[record[0], record[1], "LABEL: p, LABEL: p, ..."]`` rows,
        in the same order as ``data``.
    """
    torch.manual_seed(args.random_seed)
    with torch.no_grad():
        # NOTE(review): the attribute is spelled `ems_path` here; kept as-is
        # because ArgsConfig is not visible -- verify it is not a typo.
        model = DeepMFPP(vocab_size=21, embedding_size=args.embedding_size,
                         encoder_layer_num=1, fan_layer_num=1, num_heads=8,
                         output_size=args.num_classes, esm_path=args.ems_path,
                         layer_idx=args.esm_layer_idx, dropout=args.dropout,
                         Contrastive_Learning=args.ctl)
        # torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(model_path, map_location=device)
        # strict=False silently ignores missing/unexpected keys, presumably
        # intentionally -- confirm against how checkpoints are saved.
        model.load_state_dict(state_dict, strict=False)
        model.to(device)
        model.eval()

        # Tensor.to is not in-place: rebind so the input is on the model's device.
        seqs = seqs.to(device)
        _, logits = model(seqs)
        prob = sigmoid(logits).cpu().numpy()

    return _format_predictions(data, prob, _CATEGORIES, threshold, top_k)