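The snippet below wraps the pretrained ESM-IF1 inverse-folding model (`esm_if1_gvp4_t16_142M_UR50`) so that a batch of variable-length backbone coordinate tensors can be encoded into per-residue features. `CoordBatchConverter` pads the structures into a single batch, and only the GVP-Transformer encoder is run, yielding a 512-dimensional embedding per residue.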
```python
import esm
import torch
import torch.nn as nn
from esm.inverse_folding.util import CoordBatchConverter


class PretrainESMIF_Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Load a local copy of the ESM-IF1 weights (the torch hub default
        # cache would be /root/.cache/torch/hub/checkpoints).
        model_data = torch.load(
            "./model_zoo/esmif/esm_if1_gvp4_t16_142M_UR50.pt", map_location="cpu"
        )
        self.model, self.alphabet = esm.pretrained.load_model_and_alphabet_core(
            "esm_if1_gvp4_t16_142M_UR50", model_data, None
        )

    def forward(self, coords_list):
        self.model.eval()
        batch_converter = CoordBatchConverter(self.model.decoder.dictionary)
        # Each raw item is (coords, confidence, sequence); confidence and
        # sequence stay None, so only the backbone geometry is encoded.
        batch_coords, confidence, _, _, padding_mask = batch_converter(
            [(coord, None, None) for coord in coords_list],
            device=coords_list[0].device,
        )
        with torch.no_grad():
            encoder_out = self.model.encoder(batch_coords, padding_mask, confidence)
        # (seq_len, batch, 512) -> (batch, seq_len, 512); [:, 1:-1] drops the
        # BOS/EOS positions, e.g. (2, 1046, 512) -> (2, 1044, 512) below.
        feat = encoder_out["encoder_out"][0].permute(1, 0, 2)[:, 1:-1]
        # True where a real residue is present, False at padding.
        attention_mask = ~encoder_out["encoder_padding_mask"][0][:, 1:-1]
        return {"feat": feat, "attention_mask": attention_mask}


if __name__ == "__main__":
    model = PretrainESMIF_Model()  # __init__ takes no arguments
    coords1 = torch.rand(1044, 3, 3)  # backbone atoms: N, CA, C
    coords2 = torch.rand(500, 3, 3)
    out = model([coords1, coords2])
    print(out["feat"].shape)  # torch.Size([2, 1044, 512])
```
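Since `forward` returns both the per-residue features and a boolean residue mask, a natural next step is masked mean pooling to obtain one fixed-size embedding per structure. Below is a minimal sketch; `masked_mean_pool` is a hypothetical helper, not part of the ESM library.

```python
import torch


def masked_mean_pool(feat, mask):
    """Average per-residue features over valid positions only.

    feat: (batch, seq_len, dim) encoder features
    mask: (batch, seq_len) bool, True where a real residue is present
    """
    mask = mask.unsqueeze(-1).to(feat.dtype)  # (batch, seq_len, 1)
    summed = (feat * mask).sum(dim=1)         # (batch, dim)
    counts = mask.sum(dim=1).clamp(min=1.0)   # avoid division by zero
    return summed / counts                    # (batch, dim)


# With the model above:
# out = model([coords1, coords2])
# protein_emb = masked_mean_pool(out["feat"], out["attention_mask"])  # (2, 512)
```

Pooling with the mask rather than a plain `.mean(dim=1)` matters here because the two structures have different lengths, so the shorter one is padded and an unmasked mean would average in padding positions.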