Spaces:
Sleeping
Sleeping
| from DeepMFPP.model import DeepMFPP | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from DeepMFPP.config import ArgsConfig | |
# Build the run configuration from the project defaults, then apply the
# inference-time overrides below (same attributes, values and order as
# individual `args.<name> = value` assignments).
args = ArgsConfig()

_ARG_OVERRIDES = {
    'embedding_size': 480,
    'aa_dict': 'esm',
    'loss_fn_name': 'MLFDL',
    'weight_decay': 0,
    'batch_size': 192,
    'dropout': 0.62,
    'scale_factor': 100,
    'fldl_pos_weight': 0.4,
}
for _attr_name, _attr_value in _ARG_OVERRIDES.items():
    setattr(args, _attr_name, _attr_value)

# Shared element-wise sigmoid module, kept at module level for callers that use it.
sigmoid = nn.Sigmoid()
# Output labels; label i is paired with column i of the model's probability output.
_CATEGORIES = ('AAP', 'ABP', 'ACP', 'ACVP', 'ADP', 'AEP', 'AFP', 'AHIVP', 'AHP', 'AIP', 'AMRSAP',
               'APP', 'ATP', 'AVP', 'BBP', 'BIP', 'CPP', 'DPPIP', 'QSP', 'SBP', 'THP')


def _load_model(model_path, device):
    """Build a DeepMFPP model from the global ``args``, load *model_path* weights, move to *device*."""
    # NOTE(review): `args.ems_path` may be a typo for `esm_path`; kept as-is -- confirm against ArgsConfig.
    model = DeepMFPP(vocab_size=21, embedding_size=args.embedding_size, encoder_layer_num=1,
                     fan_layer_num=1, num_heads=8, output_size=args.num_classes,
                     esm_path=args.ems_path, layer_idx=args.esm_layer_idx,
                     dropout=args.dropout, Contrastive_Learning=args.ctl)
    # NOTE(review): torch.load unpickles the checkpoint -- only load trusted model files.
    state_dict = torch.load(model_path, map_location=device)
    # strict=False (as before) silently ignores missing/unexpected keys in the checkpoint.
    model.load_state_dict(state_dict, strict=False)
    # Keep the model on the same device as the loaded weights and the input tensor.
    model.to(device)
    model.eval()
    return model


def _format_scores(scores, threshold, top_k):
    """Return "LABEL: p" pairs with p > *threshold*, highest first, joined by ", ".

    Probabilities are rounded to 4 decimals before sorting. A truthy *top_k*
    keeps only the first *top_k* pairs after sorting.
    """
    hits = [(label, round(float(p), 4))
            for label, p in zip(_CATEGORIES, scores)
            if p > threshold]
    # list.sort is stable, so ties keep category order.
    hits.sort(key=lambda item: item[1], reverse=True)
    if top_k:
        hits = hits[:top_k]
    return ", ".join(f"{label}: {p}" for label, p in hits)


def predict(seqs: torch.Tensor, data: list, model_path: str, top_k: int = 0,
            threshold: float = 0.5, device=args.device) -> list:
    """Run DeepMFPP on *seqs* and report the labels whose probability exceeds *threshold*.

    Re-seeds torch's global RNG with ``args.random_seed`` and rebuilds/reloads the
    model on every call.

    Args:
        seqs: Encoded input batch passed directly to the model
            (presumably (batch, seq_len) token ids -- TODO confirm).
        data: Per-sample records; the first two elements of each record are
            copied into the output row (presumably name and sequence -- confirm).
        model_path: Path to a saved ``state_dict`` checkpoint.
        top_k: If non-zero, keep only the ``top_k`` highest-probability labels.
        threshold: Labels with probability strictly greater than this are reported.
        device: Device used for the model weights, the model and the input tensor.

    Returns:
        One ``[record[0], record[1], "LABEL: p, ..."]`` list per sample.
        The string is empty when no label passes the threshold.
    """
    torch.manual_seed(args.random_seed)
    with torch.no_grad():
        model = _load_model(model_path, device)
        # Tensor.to is not in-place; the result must be rebound.
        seqs = seqs.to(device)
        _, logits = model(seqs)
        prob = torch.sigmoid(logits).cpu().numpy()
    return [[record[0], record[1], _format_scores(scores, threshold, top_k)]
            for record, scores in zip(data, prob)]