from time import time
from typing import Callable, Iterable, Tuple, Union

import torch


def _current_milli_time() -> float:
    """Return the current wall-clock time (``time.time()``) in milliseconds."""
    return time() * 1000.0


def feed(model, cur_batch):
    """Move a batch onto the model's device and run the model on it.

    Args:
        model: a ``torch.nn.Module`` with at least one parameter; the device
            of its first parameter is used as the target device.
        cur_batch: a 6-item sequence ``(chain_name, chain, chain_pos,
            chain_seq_pos, chain_axes, batch_ids)``. Every item except
            ``chain_name`` must provide a tensor-like ``.to(...)`` method.

    Returns:
        A pair ``(logits, time_data)``: whatever the model returned, and the
        number of milliseconds spent moving the batch (the forward pass itself
        is not included in the measurement).

    Raises:
        StopIteration: if ``model`` has no parameters.
    """
    start_time = _current_milli_time()

    # The chain name is part of the batch layout but is not passed to the model.
    _chain_name, chain, chain_pos, chain_seq_pos, chain_axes, batch_ids = cur_batch

    # Move the batch to the device the model lives on, casting to the dtypes
    # the model is fed with.
    dev = next(model.parameters()).device
    chain = chain.to(dev, torch.int64)
    chain_pos = chain_pos.to(dev, torch.float32)
    chain_seq_pos = chain_seq_pos.to(dev, torch.int64)
    chain_axes = chain_axes.to(dev, torch.float32)
    batch_ids = batch_ids.to(dev)

    time_data = _current_milli_time() - start_time

    # Per-residue inputs (comments carried over from the original author;
    # the question marks are theirs and still need confirming):
    logits = model(
        chain,          # amino-acid type (index)
        chain_pos,      # 3D coordinates (of the alpha carbon? -- TODO confirm)
        chain_seq_pos,  # sequential index
        chain_axes,     # orthonormal axes between consecutive residues (TODO confirm)
        batch_ids,      # chain ID mask
    )
    return logits, time_data


class Feeder:
    """Mixin that exposes ``feed`` as a method.

    Lets callers write ``model.feed(batch)`` instead of ``feed(model, batch)``,
    so they do not have to import the free function.
    """

    def feed(self, *args):
        """Call the module-level ``feed`` with this object as the model."""
        return feed(self, *args)


class _Ensemble(torch.nn.Module, Feeder):
    """A group of models whose outputs are averaged into a single value."""

    def __init__(
        self,
        paths_or_n: Union[int, Iterable],
        base_nn: Callable[..., torch.nn.Module],
        mut_nn: Callable[[torch.nn.Module], torch.nn.Module],
        *,
        default_base_args: Tuple = ('avg', True, False),
    ):
        """Build the ensemble from scratch or from saved checkpoints.

        Args:
            paths_or_n: either the number of freshly initialised members to
                create, or an iterable of checkpoint paths readable by
                ``torch.load``. Each checkpoint must be a mapping with the
                keys ``'gl_pooling'``, ``'embed_aa'`` and
                ``'state_dict_prot_enc'``.
            base_nn: factory called with three positional arguments and
                returning the base network.
            mut_nn: factory that wraps a base network and returns one
                ensemble member.
            default_base_args: the three positional arguments passed to
                ``base_nn`` when members are created from scratch. Defaults
                to the previously hard-coded ``('avg', True, False)``.
        """
        super().__init__()

        if isinstance(paths_or_n, int):
            models = [
                mut_nn(base_nn(*default_base_args))
                for _ in range(paths_or_n)
            ]
        else:
            models = []
            for path in paths_or_n:
                # Load onto the CPU so that checkpoints saved from a GPU can
                # also be restored on hosts without CUDA; load_state_dict
                # copies the values into the model's own parameters anyway.
                # NOTE: torch.load unpickles its input -- only pass
                # checkpoints that come from a trusted source.
                conf_dict = torch.load(path, map_location='cpu')
                model = mut_nn(
                    base_nn(
                        conf_dict['gl_pooling'],
                        bool(conf_dict['embed_aa']),
                        conf_dict['embed_aa'] == 'learn',
                    )
                )
                model.load_state_dict(conf_dict["state_dict_prot_enc"])
                # Only restored members are switched to eval mode; members
                # created from scratch keep the default (training) mode.
                model.eval()
                models.append(model)

        # All members have to be registered as sub-modules to be visible to
        # PyTorch. jit.trace would not work without this and fails with a
        # cryptic message:
        # RuntimeError: "Cannot insert a Tensor that requires grad as a constant"
        self.models = torch.nn.ModuleList(models)

    def forward(self, *args):
        """Run every member on ``args`` and return the mean of all outputs.

        The member outputs are concatenated along dim 0 and reduced with a
        global mean, so the result is a single 0-dim tensor.

        NOTE(review): because the mean runs over every element, outputs for
        different samples of a batch are averaged together as well. If a
        per-sample ensemble mean is wanted, this should presumably be
        ``torch.stack(preds).mean(dim=0)`` -- confirm against the callers
        before changing, as it alters the output shape.
        """
        preds = [m(*args) for m in self.models]
        preds = torch.cat(preds)
        return torch.mean(preds)