import torch
import os.path as osp
from src.models.pifold_model import PiFold_Model
import torch.nn.functional as F


class PretrainPiFold_Model(PiFold_Model):
    def __init__(self, args, **kwargs):
        """Graph labeling network initialized from a pretrained PiFold checkpoint."""
        PiFold_Model.__init__(self, args)
        if args.augment_eps > 0:
            # Checkpoints trained with augment_eps > 0 live in their own results directory.
            pretrain_pifold_path = osp.join(self.args.res_dir, self.args.dataset, f"PiFold_{args.augment_eps}", "checkpoint.pth")
        else:
            # pretrain_pifold_path = osp.join(self.args.res_dir, self.args.dataset, "PiFold", "checkpoint.pth")
            pretrain_pifold_path = osp.join('model_zoo', self.args.dataset, "PiFold", "checkpoint.pth")
        # map_location="cpu" lets the checkpoint load on machines without the device it was saved from.
        self.load_state_dict(torch.load(pretrain_pifold_path, map_location="cpu"))
    
    @torch.no_grad()
    def forward(self, batch):
        h_V, h_P, P_idx, batch_id = batch['h_V'], batch['h_E'], batch['E_idx'], batch['batch_id']
        device = h_V.device
        # Embed and project the node and edge features.
        h_V = self.W_v(self.norm_nodes(self.node_embedding(h_V)))
        h_P = self.W_e(self.norm_edges(self.edge_embedding(h_P)))
        
        h_V, h_P = self.encoder(h_V, h_P, P_idx, batch_id)
        log_probs, logits = self.decoder(h_V, batch_id)
        # Per-node confidence is the probability of the predicted token.
        probs = F.softmax(logits, dim=-1)
        conf, pred_id = probs.max(dim=-1)
        
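        # Length of the longest sequence in the batch, used as the padding target.
        # Converting to a Python int keeps the pad sizes below plain integers.
        _, counts = batch_id.unique(return_counts=True)
        maxL = int(counts.max())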
        
        # Split the flat per-node tensors into one entry per sequence in the batch.
        confs = []
        seqs = []
        embeds = []
        probs2 = []
        for b in batch_id.unique():
            mask = batch_id==b
            # Decode the predicted ids, then split on spaces to get one token per node.
            elements = self.tokenizer.decode(pred_id[mask]).split(" ")
            seqs.append(elements)
            confs.append(conf[mask])
            embeds.append(h_V[mask])
            probs2.append(probs[mask])
        
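        # Re-tokenize the predicted sequences as a padded batch, without special tokens.
        seqs = self.tokenizer(["".join(one) for one in seqs], padding=True, truncation=True, return_tensors='pt', add_special_tokens=False)
        # Right-pad every per-sequence tensor to maxL so they can be stacked.
        confs = torch.stack([F.pad(one, (0, maxL-len(one))) for one in confs])
        embeds = torch.stack([F.pad(one, (0,0, 0, maxL-len(one))) for one in embeds])
        # Padded positions get 1/33 per class (a uniform distribution if the vocabulary has 33 tokens).
        probs2 = torch.stack([F.pad(one, (0,0, 0, maxL-len(one)), value=1/33) for one in probs2])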
        
        ret = {
            "pred_ids": seqs['input_ids'].to(device),
            "confs": confs,
            "embeds": embeds,
            "probs": probs2,
            "attention_mask": seqs['attention_mask'].to(device),
            "E_idx": P_idx,
            "batch_id": batch_id,
            "h_E": h_P,
        }
        return ret