File size: 1,315 Bytes
7968cb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import esm
import torch.nn as nn
import torch
from esm.inverse_folding.util import CoordBatchConverter

class PretrainESMIF_Model(nn.Module):
    """Wrap the pretrained ESM-IF1 encoder as a gradient-free feature extractor.

    The ``esm_if1_gvp4_t16_142M_UR50`` checkpoint is loaded from disk at
    construction time. ``forward`` runs only the model's encoder, in eval mode
    and under ``torch.no_grad()``, over a batch of coordinate tensors.
    """

    # Default checkpoint location (the path the original code hard-coded).
    # NOTE(review): /root/.cache/torch/hub/checkpoints was noted as another
    # place the weights may live — confirm which one deployments use.
    DEFAULT_CHECKPOINT = "./model_zoo/esmif/esm_if1_gvp4_t16_142M_UR50.pt"
    # Model name handed to esm's loader alongside the raw checkpoint data.
    MODEL_NAME = "esm_if1_gvp4_t16_142M_UR50"

    def __init__(self, checkpoint_path=DEFAULT_CHECKPOINT):
        """Load the ESM-IF1 model and alphabet.

        Args:
            checkpoint_path: Path to the ESM-IF1 ``.pt`` checkpoint. Defaults
                to ``DEFAULT_CHECKPOINT``.
        """
        super().__init__()
        model_data = self._load_checkpoint(checkpoint_path)
        self.model, self.alphabet = esm.pretrained.load_model_and_alphabet_core(
            self.MODEL_NAME, model_data, None
        )
        # The converter only depends on the (fixed) model dictionary, so build
        # it once instead of on every forward call.
        self.batch_converter = CoordBatchConverter(self.model.decoder.dictionary)

    @staticmethod
    def _load_checkpoint(checkpoint_path):
        """Load the raw checkpoint onto CPU so GPU-saved weights load anywhere."""
        try:
            # The checkpoint carries non-tensor metadata consumed by esm's
            # loader, which torch>=2.6's default weights_only=True rejects.
            # NOTE(review): this unpickles arbitrary objects — only point
            # checkpoint_path at trusted files.
            return torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        except TypeError:
            # torch < 1.13 has no ``weights_only`` keyword.
            return torch.load(checkpoint_path, map_location="cpu")

    def forward(self, coords_list):
        """Encode a batch of structures with the frozen ESM-IF1 encoder.

        Args:
            coords_list: Non-empty sequence of coordinate tensors, one per
                structure. All are batched onto the device of the first
                tensor. The demo script uses shape (L, 3, 3), presumably
                N/CA/C backbone atoms per residue — TODO confirm with callers.

        Returns:
            dict with key ``"feat"``: the encoder output permuted to
            batch-first, with the first and last positions of the padded
            sequence axis sliced off. Per the original shape note, a batch of
            lengths (1044, 500) yields (2, 1044, 512).
            NOTE(review): rows for shorter structures still contain padded
            positions past their own length; callers must track lengths.

        Raises:
            ValueError: If ``coords_list`` is empty.
        """
        if len(coords_list) == 0:
            raise ValueError("coords_list must contain at least one structure")
        # Re-assert eval mode every call so dropout stays off even if an
        # enclosing module was switched to train().
        self.model.eval()
        batch_coords, confidence, _, _, padding_mask = self.batch_converter(
            [(coords, None, None) for coords in coords_list],
            device=coords_list[0].device,
        )
        with torch.no_grad():
            encoder_out = self.model.encoder(batch_coords, padding_mask, confidence)

        # (seq, batch, dim) -> (batch, seq, dim), dropping the boundary slots.
        feat = encoder_out["encoder_out"][0].permute(1, 0, 2)[:, 1:-1]
        return {"feat": feat}

if __name__ == '__main__':
    # Smoke test: encode two random "structures" of different lengths.
    # The constructor takes no positional config; passing one (as the old
    # demo did with 0.1) raised TypeError before anything ran.
    model = PretrainESMIF_Model()
    coords1 = torch.rand(1044, 3, 3)  # N, CA, C
    coords2 = torch.rand(500, 3, 3)
    out = model([coords1, coords2])
    print(out["feat"].shape)