File size: 2,243 Bytes
b140e2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from time import time

import torch

def _current_milli_time():
	"""Return the current wall-clock time in milliseconds since the epoch (float)."""
	return time() * 1000.0

def feed(model, cur_batch):
	"""Feed one batch to the predictor.

	The batch tensors are moved to the device of the model's first parameter
	(cast to int64 / float32 where shown below) and the model is then called
	on them.

	Args:
		model: object exposing ``parameters()`` and callable with five
			positional tensors.
		cur_batch: 6-item sequence ``(chain_name, chain, chain_pos,
			chain_seq_pos, chain_axes, batch_ids)``. The name is unpacked
			but never handed to the model.

	Returns:
		``(logits, time_data)``: the model's output, and the time in
		milliseconds spent unpacking the batch and moving it to the model's
		device. The forward pass itself is not part of that timing.
	"""
	t_begin = _current_milli_time()
	_name, residues, coords, seq_index, axes, mask = cur_batch

	# the batch has to sit on whatever device the model's weights are on
	target = next(model.parameters()).device
	model_inputs = (  # [batch_size] of residues each:
		residues.to(target, torch.int64),   # AA type (index)
		coords.to(target, torch.float32),   # coordinates in 3D (of alpha-carbon? -- unconfirmed)
		seq_index.to(target, torch.int64),  # sequential index
		axes.to(target, torch.float32),     # axes (orthonormal?) between consecutive residues -- unconfirmed
		mask.to(target),                    # chain ID mask
	)

	# timing stops here on purpose: only data preparation is measured
	time_data = _current_milli_time() - t_begin

	return model(*model_inputs), time_data

class Feeder:
	"""Mixin giving a model an object-style ``model.feed(batch)`` entry point.

	It stands in for the imperative ``feed(model, batch)`` call, so users of
	the model do not have to import the free function themselves.
	"""

	def feed(self, *args):
		"""Forward ``args`` to the module-level ``feed`` with ``self`` as the model."""
		# inside a method body the bare name `feed` is looked up in the module
		# globals, so this reaches the free function rather than recursing
		result = feed(self, *args)
		return result


class _Ensemble(torch.nn.Module, Feeder):
	"""A set of ``mut_nn(base_nn(...))`` models evaluated together.

	The members are either created fresh (when an integer count is given) or
	restored from checkpoint files (when an iterable of paths is given).
	"""

	def __init__(self,
		paths_or_n,
		base_nn: torch.nn.Module,
		mut_nn:  torch.nn.Module,
	):
		"""Build the ensemble members.

		Args:
			paths_or_n: either an ``int`` (number of fresh members to create)
				or an iterable of checkpoint paths readable by ``torch.load``.
				Each checkpoint must provide the keys ``'gl_pooling'``,
				``'embed_aa'`` and ``'state_dict_prot_enc'``.
			base_nn: called as ``base_nn(gl_pooling, embed_flag, learn_flag)``
				to construct the inner network (i.e. it is used as a
				constructor/factory, not as an instance).
			mut_nn: called as ``mut_nn(base_model)`` to wrap the inner network.
		"""
		super().__init__()

		if isinstance(paths_or_n, int):
			# fresh members; they are left in their default (training) mode
			models = [
				mut_nn(base_nn('avg', True, False))  # todo: this should not be hardcoded
				for _ in range(paths_or_n)
			]
		else:
			models = []
			for path in paths_or_n:
				# NOTE(security): torch.load unpickles the file -- only feed it
				# checkpoints from a trusted source.
				# map_location='cpu' lets checkpoints saved from a GPU be read on
				# hosts without one; load_state_dict below copies the values into
				# the model's own parameters, so the result is unaffected.
				conf_dict = torch.load(path, map_location='cpu')
				embed_aa = conf_dict['embed_aa']
				model = mut_nn(
					base_nn(conf_dict['gl_pooling'], bool(embed_aa), embed_aa == 'learn'))
				model.load_state_dict(conf_dict["state_dict_prot_enc"])
				model.eval()  # restored members are used for inference
				models.append(model)

		# all modules have to be properly registered as modules to be visible for PyTorch;
		# jit.trace would not work without this, failing with a cryptic message:
		# RuntimeError: "Cannot insert a Tensor that requires grad as a constant"
		self.models = torch.nn.ModuleList(models)

	def forward(self, *args):
		"""Run every member on ``args`` and return the mean of all their outputs.

		The member outputs are concatenated along dim 0 and reduced with a
		plain ``torch.mean``, so the result is a single 0-dim tensor.
		"""
		# NOTE(review): the mean is taken over every element of every member's
		# output, not per sample across members -- confirm this is intended
		# when a member returns more than one value.
		preds = torch.cat([m(*args) for m in self.models])
		return torch.mean(preds)