SoluProtMutDemo / code /predictor.py
vvelda's picture
Initial commit
b140e2c verified
import torch
from code.data_preprocessing import process_pdb, mutate_seq
from code.data_loader import create_datapoint, collate_batch
class EnsemblePredictor:
    """Wrap a pretrained ``Ensemble`` model and score sets of point mutations.

    The model is built and its weights loaded once at construction;
    ``predict_change`` then scores a wild-type structure against the same
    structure with its sequence mutated.
    """

    def __init__(self,
                 weights: str = "model_weights.pth",  # trained model
                 version: "str | None" = None,  # model version: '21_1'|...
                 ):
        """Build the ensemble for *version* and load its state dict.

        Args:
            weights: Path to a file readable by ``torch.load`` that holds the
                model's ``state_dict``.
            version: ``None`` for the default model definition, or ``'21_1'``
                / ``'21_2'`` to select one of the older definitions.

        Raises:
            ValueError: If *version* is not one of the known versions.
        """
        # Resolve the model class first so an unknown version fails fast,
        # before the weights file is read.
        if version is None:  # default model
            from code.model_py import Ensemble
        elif version == '21_2':
            from code.model_py.v21_2 import Ensemble
        elif version == '21_1':
            from code.model_py.v21_1 import Ensemble
        else:
            # Raising a bare string is itself a TypeError in Python 3, which
            # hid this message; raise a real exception instead.
            raise ValueError(f"Non-existing version! ({version!r})")
        # map_location="cpu" lets weights saved from a GPU session load on a
        # CPU-only machine; the model below is not moved to another device
        # here, and load_state_dict copies values into its parameters.
        # NOTE(review): torch.load unpickles the file — only load trusted
        # weights (consider weights_only=True where torch >= 1.13 is assured).
        conf_dict = torch.load(weights, map_location="cpu")
        # NOTE(review): 5 is presumably the number of ensemble members —
        # confirm against the Ensemble definition.
        pred_model = Ensemble(5)
        pred_model.load_state_dict(conf_dict)
        pred_model.eval()  # inference mode for dropout/batch-norm layers
        self._model = pred_model

    def predict_change(self,
                       PDB_path: str,
                       chain: str,
                       aa_from: list,
                       locs: list,
                       aa_to: list
                       ):
        """Score the effect of mutating *chain* of the structure at *PDB_path*.

        ``aa_from[i]``, ``locs[i]`` and ``aa_to[i]`` together describe the
        i-th substitution; they are passed as-is to ``mutate_seq``, whose
        position convention is not visible here (TODO confirm 0- vs 1-based).

        Args:
            PDB_path: Path to the structure file, passed to ``process_pdb``.
            chain: Chain identifier to extract from the structure.
            aa_from: Original residues, one per mutation.
            locs: Mutation positions, one per mutation.
            aa_to: Replacement residues, one per mutation.

        Returns:
            ``(label, score)``, where ``score`` is the model output as a float.
            ``label`` is ``'N'`` if the score rounds to 0.50, ``'+'`` if it
            is above 0.5, and ``'-'`` otherwise.

        Raises:
            ValueError: If the three mutation lists differ in length.
        """
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if not (len(aa_from) == len(locs) == len(aa_to)):
            raise ValueError(
                "aa_from, locs and aa_to must have the same length"
            )
        print("Processing structure...")
        orig_seq, coords = process_pdb(PDB_path, chain)
        mut_seq = mutate_seq(orig_seq, aa_from, locs, aa_to)
        print("Calculating prediction...")
        protein = create_datapoint('wt', orig_seq, coords)
        # The mutant reuses the wild-type coordinates; only the sequence
        # differs.
        mutant = create_datapoint('mut', mut_seq, coords)
        # NOTE(review): the last two items of the collated batch are dropped
        # before feeding the model — confirm what they hold.
        data = collate_batch([protein, mutant])[:-2]
        # Pure inference: skip building an autograd graph.
        with torch.no_grad():
            prediction = float(self._model.feed(data)[0])
        return self._classify_change(prediction), prediction

    @staticmethod
    def _classify_change(prediction: float) -> str:
        """Map a score to 'N' (rounds to 0.50), '+' (> 0.5) or '-'."""
        if round(prediction, 2) == 0.5:
            return 'N'
        return '+' if prediction > 0.5 else '-'