File size: 2,749 Bytes
2d48951
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from DeepMFPP.model import DeepMFPP
import torch
import torch.nn as nn
import numpy as np
from DeepMFPP.config import ArgsConfig

# Module-wide configuration: start from ArgsConfig's defaults, then apply the
# overrides below (in this order). NOTE(review): presumably these mirror the
# hyperparameters the shipped checkpoint was trained with — confirm.
args = ArgsConfig()
_ARG_OVERRIDES = {
    'embedding_size': 480,
    'aa_dict': 'esm',
    'loss_fn_name': 'MLFDL',
    'weight_decay': 0,
    'batch_size': 192,
    'dropout': 0.62,
    'scale_factor': 100,
    'fldl_pos_weight': 0.4,
}
for _attr, _val in _ARG_OVERRIDES.items():
    setattr(args, _attr, _val)
del _attr, _val  # don't leak loop temporaries into the module namespace

# Shared sigmoid module used by predict() to map logits to probabilities.
sigmoid = nn.Sigmoid()
# Label names paired positionally with the model's output columns.
# NOTE(review): order must match the training label order — confirm.
_CATEGORIES = ('AAP', 'ABP', 'ACP', 'ACVP', 'ADP', 'AEP', 'AFP', 'AHIVP', 'AHP', 'AIP', 'AMRSAP',
               'APP', 'ATP', 'AVP', 'BBP', 'BIP', 'CPP', 'DPPIP', 'QSP', 'SBP', 'THP')


def predict(seqs: torch.Tensor, data: list, model_path: str, top_k: int = 0, threshold: float = 0.5, device=args.device):
    """Run DeepMFPP on a batch of encoded sequences and report labels above a threshold.

    A fresh model is built and its weights are loaded from ``model_path`` on every call.

    Args:
        seqs: Encoded input batch passed directly to the model (presumably token ids
            produced with ``args.aa_dict`` — TODO confirm).
        data: One entry per row of ``seqs``. The first two fields of each entry are
            copied into the output row (presumably (name, sequence) — verify against caller).
        model_path: Path to a state dict loadable with ``torch.load``.
        top_k: If non-zero, keep only the first ``top_k`` labels (by descending
            probability) among those that pass the threshold. ``0`` keeps all of them.
        threshold: A label is reported when its sigmoid probability is strictly greater
            than this value.
        device: Device on which the model, its weights, and the inputs are placed.

    Returns:
        One ``[data[i][0], data[i][1], labels]`` list per input row. ``labels`` is a string
        such as ``"AVP: 0.9123, ATP: 0.6781"`` in descending probability order, or ``""``
        when no label passes the threshold.
    """
    torch.manual_seed(args.random_seed)
    with torch.no_grad():
        # NOTE(review): attribute is spelled 'ems_path' here — confirm it matches ArgsConfig.
        model = DeepMFPP(vocab_size=21, embedding_size=args.embedding_size, encoder_layer_num=1,
                         fan_layer_num=1, num_heads=8, output_size=args.num_classes,
                         esm_path=args.ems_path, layer_idx=args.esm_layer_idx,
                         dropout=args.dropout, Contrastive_Learning=args.ctl).to(device)
        model.eval()
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(model_path, map_location=device)
        # strict=False: missing/unexpected checkpoint keys are ignored silently.
        model.load_state_dict(state_dict, strict=False)
        # Tensor.to is not in-place; rebind so the inputs actually live on `device`.
        seqs = seqs.to(device)
        _, logits = model(seqs)
        probs = sigmoid(logits).cpu().numpy()

    return [
        [row[0], row[1], _format_labels(row_probs, threshold, top_k)]
        for row, row_probs in zip(data, probs)
    ]


def _format_labels(row_probs, threshold, top_k):
    """Format one row of probabilities as ``"LABEL: prob, ..."`` for labels above threshold.

    Probabilities are rounded to 4 decimals before sorting in descending order. The sort is
    stable, so ties keep category order. When ``top_k`` is truthy, only the first ``top_k``
    entries are kept.
    """
    hits = [
        (label, round(float(p), 4))
        for label, p in zip(_CATEGORIES, row_probs)
        if p > threshold
    ]
    hits.sort(key=lambda item: item[1], reverse=True)
    if top_k:
        hits = hits[:top_k]
    return ", ".join(f"{label}: {p}" for label, p in hits)