Spaces:
Build error
Build error
File size: 1,483 Bytes
b140e2c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 | import torch
from code.data_preprocessing import process_pdb, mutate_seq
from code.data_loader import create_datapoint, collate_batch
class EnsemblePredictor:
    """Load a trained ``Ensemble`` model and score mutations of a PDB chain.

    The model class is chosen by ``version`` and its parameters are restored
    from the state dict stored in ``weights``. ``predict_change`` builds a
    wild-type and a mutant datapoint from the same structure and returns the
    model's score for the pair together with a coarse ``'+'/'-'/'N'`` label.
    """

    def __init__(self,
                 weights: str = "model_weights.pth",  # trained model
                 version: "str | None" = None,  # model version: '21_1'|'21_2'
                 ):
        """Build the ensemble for ``version`` and load its state dict.

        Args:
            weights: path to a file readable by ``torch.load`` that holds the
                model's ``state_dict``.
            version: ``None`` selects the default model; ``'21_1'`` and
                ``'21_2'`` select the corresponding versioned implementations.

        Raises:
            ValueError: if ``version`` is not one of the supported values.
        """
        # Resolve the model class first so an unknown version fails fast,
        # before the (potentially large) weights file is read.
        if version is None:  # default model
            from code.model_py import Ensemble
        elif version == '21_2':
            from code.model_py.v21_2 import Ensemble
        elif version == '21_1':
            from code.model_py.v21_1 import Ensemble
        else:
            # Raising a bare string is a TypeError in Python 3; use a real
            # exception so callers get the intended message.
            raise ValueError(f"Non-existing version: {version!r}")
        # NOTE(review): torch.load unpickles the file -- only load trusted
        # weights. map_location='cpu' lets weights saved on a GPU load on a
        # CPU-only host; load_state_dict copies the tensors into the model's
        # own parameters regardless of where they were loaded.
        conf_dict = torch.load(weights, map_location='cpu')
        # NOTE(review): meaning of the argument 5 is not visible here --
        # presumably the number of ensemble members; confirm in model_py.
        pred_model = Ensemble(5)
        pred_model.load_state_dict(conf_dict)
        pred_model.eval()
        self._model = pred_model

    def predict_change(self,
                       PDB_path: str,
                       chain: str,
                       aa_from: list,
                       locs: list,
                       aa_to: list
                       ):
        """Score substituting ``aa_from[i]`` at ``locs[i]`` with ``aa_to[i]``.

        Args:
            PDB_path: path to the structure file passed to ``process_pdb``.
            chain: chain identifier passed to ``process_pdb``.
            aa_from: original residues, one per mutated position.
            locs: mutated positions (the indexing convention is defined by
                ``mutate_seq``, which is not visible here).
            aa_to: replacement residues, one per mutated position.

        Returns:
            ``(label, score)``: ``score`` is the first element of the model's
            ``feed`` output as a float; ``label`` is ``'N'`` when the score
            rounds to 0.50 at two decimals, ``'+'`` when it is above 0.5 and
            ``'-'`` otherwise.

        Raises:
            ValueError: if ``aa_from``, ``locs`` and ``aa_to`` differ in length.
        """
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if not len(aa_from) == len(locs) == len(aa_to):
            raise ValueError(
                "aa_from, locs and aa_to must have the same length, got "
                f"{len(aa_from)}, {len(locs)} and {len(aa_to)}"
            )
        print("Processing structure...")
        orig_seq, coords = process_pdb(PDB_path, chain)
        mut_seq = mutate_seq(orig_seq, aa_from, locs, aa_to)
        print("Calculating prediction...")
        # Wild type and mutant share the same coordinates; only the
        # sequence differs between the two datapoints.
        protein = create_datapoint(
            'wt',
            orig_seq,
            coords
        )
        mutant = create_datapoint(
            'mut',
            mut_seq,
            coords
        )
        # NOTE(review): the last two items of the collated batch are dropped
        # before feeding the model -- presumably targets/metadata; confirm
        # against collate_batch.
        data = collate_batch([protein, mutant])[:-2]
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            prediction = float(self._model.feed(data)[0])
        prediction_cl = 'N' if round(prediction, 2) == 0.5 else \
            '+' if prediction > 0.5 else '-'
        return prediction_cl, prediction