File size: 1,483 Bytes
b140e2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch

from code.data_preprocessing import process_pdb, mutate_seq
from code.data_loader import create_datapoint, collate_batch


class EnsemblePredictor:
    """Load a trained ``Ensemble`` model and score mutations of a PDB chain.

    The ``Ensemble`` class is imported lazily according to ``version``, the
    checkpoint stored in ``weights`` is loaded into it, and the model is put
    into evaluation mode. Predictions are obtained with
    :meth:`predict_change`.
    """

    # Argument handed to ``Ensemble(...)``.
    # NOTE(review): presumably the number of ensemble members; it has to be
    # compatible with the checkpoint being loaded -- confirm in code.model_py.
    _ENSEMBLE_ARG = 5

    def __init__(self,
        weights: str = "model_weights.pth",  # trained model
        version: str = None,                 # model version: '21_1'|...
    ) -> None:
        """Build the model for ``version`` and load ``weights`` into it.

        Args:
            weights: Path of the checkpoint file read with ``torch.load``.
            version: ``None`` for the default model, or ``'21_1'`` /
                ``'21_2'`` for the corresponding ``code.model_py`` submodule.

        Raises:
            ValueError: If ``version`` is not one of the supported values.
        """
        # Resolve the model class first so that an unsupported version is
        # reported immediately, before any file I/O on the checkpoint.
        if version is None:  # default model
            from code.model_py import Ensemble
        elif version == '21_2':
            from code.model_py.v21_2 import Ensemble
        elif version == '21_1':
            from code.model_py.v21_1 import Ensemble
        else:
            # Raising a plain string is itself a TypeError in Python 3;
            # raise a real exception that names the offending value.
            raise ValueError(
                f"Non-existing version! Got {version!r}; "
                "expected None, '21_1' or '21_2'."
            )

        # SECURITY: torch.load unpickles the file, so only trusted
        # checkpoints should be passed here.
        # map_location="cpu" lets a checkpoint saved on a GPU be loaded on a
        # CPU-only host; nothing in this class moves the model to another
        # device.
        conf_dict = torch.load(weights, map_location="cpu")

        pred_model = Ensemble(self._ENSEMBLE_ARG)
        pred_model.load_state_dict(conf_dict)
        pred_model.eval()

        self._model = pred_model

    def predict_change(self,
        PDB_path: str,
        chain: str,
        aa_from: list,
        locs: list,
        aa_to: list
    ) -> tuple:
        """Predict the effect of a set of substitutions on one chain.

        The structure is parsed with ``process_pdb``, the mutated sequence is
        produced by ``mutate_seq``, and both sequences (sharing the same
        coordinates) are fed to the model as a ``'wt'`` / ``'mut'`` pair.

        Args:
            PDB_path: Path of the structure file passed to ``process_pdb``.
            chain: Chain identifier passed to ``process_pdb``.
            aa_from: Original residues, one per mutation.
            locs: Positions of the mutations, parallel to ``aa_from``.
            aa_to: Replacement residues, parallel to ``aa_from``.

        Returns:
            ``(prediction_cl, prediction)`` where ``prediction`` is the first
            model output converted to ``float`` and ``prediction_cl`` is its
            label as computed by :meth:`_classify`.

        Raises:
            ValueError: If the three mutation lists differ in length.
        """
        # Explicit check instead of ``assert`` so that it still runs when
        # Python is started with -O.
        if not (len(aa_from) == len(locs) == len(aa_to)):
            raise ValueError(
                "aa_from, locs and aa_to must have the same length "
                f"(got {len(aa_from)}, {len(locs)} and {len(aa_to)})."
            )

        print("Processing structure...")
        orig_seq, coords = process_pdb(PDB_path, chain)
        mut_seq = mutate_seq(orig_seq, aa_from, locs, aa_to)

        print("Calculating prediction...")
        protein = create_datapoint('wt', orig_seq, coords)
        mutant = create_datapoint('mut', mut_seq, coords)

        # The last two items of the collated batch are not given to the model.
        # NOTE(review): presumably labels/metadata unused at inference time --
        # confirm against collate_batch.
        data = collate_batch([protein, mutant])[:-2]
        prediction = float(self._model.feed(data)[0])

        return self._classify(prediction), prediction

    @staticmethod
    def _classify(prediction: float) -> str:
        """Map a score to ``'N'``, ``'+'`` or ``'-'`` around the 0.5 cut-off.

        A score that rounds to 0.5 at two decimals is ``'N'``; otherwise
        scores above 0.5 are ``'+'`` and the rest are ``'-'``.
        """
        if round(prediction, 2) == 0.5:
            return 'N'
        return '+' if prediction > 0.5 else '-'