# NOTE: removed stray artifact lines ("Spaces:" / "Runtime error") that were
# paste residue and made the file unparsable
| # -*- coding: utf-8 -*- | |
| """ | |
| Author: Philipp Seidl | |
| ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning | |
| Johannes Kepler University Linz | |
| Contact: seidl@ml.jku.at | |
| Model related functionality | |
| """ | |
| from .utils import top_k_accuracy | |
| from .plotutils import plot_loss, plot_topk, plot_nte | |
| from .molutils import convert_smiles_to_fp | |
| import os | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from collections import defaultdict | |
| from scipy import sparse | |
| import logging | |
| from tqdm import tqdm | |
| import wandb | |
| log = logging.getLogger(__name__) | |
class ChemRXNDataset(torch.utils.data.Dataset):
    """Torch Dataset for ChemRXN.

    Holds Xs (inputs: np.array of fingerprints, or a list of SMILES when
    ``is_smiles=True``), target (target molecules, or None), and ys (labels:
    dense np.array or scipy CSR matrix). In SMILES mode fingerprints are
    computed lazily per item in ``__getitem__``.
    """
    def __init__(self, Xs, target, ys, is_smiles=False, fp_size=2048, fingerprint_type='morgan'):
        self.is_smiles = is_smiles
        if is_smiles:
            # keep raw SMILES; fingerprints are computed on access
            self.Xs = Xs
            self.target = target
            self.fp_size = fp_size
            self.fingerprint_type = fingerprint_type
        else:
            self.Xs = Xs.astype(np.float32)
            self.target = target.astype(np.float32)
        self.ys = ys
        # labels may come as a sparse CSR matrix; determine this once
        # (BUGFIX: scipy's `sparse.csr.csr_matrix` submodule access is
        # deprecated/removed -- use the public `sparse.csr_matrix`)
        self.ys_is_sparse = isinstance(self.ys, sparse.csr_matrix)

    def __getitem__(self, k):
        """Return (molecule_fp, target_fp_or_None, label) for sample k."""
        mol_fp = self.Xs[k]
        if self.is_smiles:
            mol_fp = convert_smiles_to_fp(mol_fp, fp_size=self.fp_size, which=self.fingerprint_type).astype(np.float32)
        target = None if self.target is None else self.target[k]
        # BUGFIX: compare against None explicitly -- truthiness of the target
        # container is ambiguous (np arrays raise, empty sequences are falsy)
        if self.is_smiles and target is not None:
            target = convert_smiles_to_fp(target, fp_size=self.fp_size, which=self.fingerprint_type).astype(np.float32)
        label = self.ys[k]
        if self.ys_is_sparse:
            # densify the single CSR row to a flat vector
            label = label.toarray()[0]
        return (mol_fp, target, label)

    def __len__(self):
        return len(self.Xs)
class ModelConfig(object):
    """Holds all model hyper-parameters as attributes.

    Recognized options fall back to their defaults when not supplied.
    Experimental options (hopf_n_layers, mol_encoder_layers,
    temp_encoder_layers, encoder_af) are only set when explicitly given.
    Any remaining keyword arguments are attached to the instance verbatim.
    """
    def __init__(self, **kwargs):
        defaults = {
            'fingerprint_type': 'morgan',
            'template_fp_type': 'rdk',
            'num_templates': 401,
            'fp_size': 2048,
            'fp_radius': 4,
            'device': 'cuda' if torch.cuda.is_available() else 'cpu',
            'batch_size': 32,
            'pooling_operation_state_embedding': 'mean',
            'pooling_operation_head': 'max',
            'dropout': 0.0,
            'lr': 1e-4,
            'optimizer': 'Adam',
            'activation_function': 'ReLU',
            'verbose': False,  # set to True for debugging / additional warnings
            'hopf_input_size': 2048,
            'hopf_output_size': 768,
            'hopf_num_heads': 1,
            'hopf_asso_dim': 768,
            'hopf_association_activation': None,
            'hopf_beta': 0.125,  # 1/(hopf_asso_dim**(1/2)) a.k.a. 1/sqrt(d_k)
            'norm_input': False,
            'norm_asso': False,
        }
        for name, default in defaults.items():
            setattr(self, name, kwargs.pop(name, default))
        # additional experimental hyperparams -- attribute only exists if passed
        experimental = (('hopf_n_layers', 0), ('mol_encoder_layers', 1),
                        ('temp_encoder_layers', 1), ('encoder_af', 'ReLU'))
        for name, default in experimental:
            if name in kwargs:
                setattr(self, name, kwargs.pop(name, default))
        # additional kwargs are attached verbatim
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                log.error(f"Can't set {key} with value {value} for {self}")
                raise err
class Encoder(nn.Module):
    """Simple FFNN with optional LayerNorm on input and output.

    Layers are registered as attributes ``W_0 .. W_{n-1}`` (so they appear in
    state_dict under those names) and also collected in ``self.ws`` for the
    forward pass. The activation is applied after every layer except the last;
    dropout is applied after every layer.
    """

    def __init__(self, input_size: int = 2048, output_size: int = 1024,
                 num_layers: int = 1, dropout: float = 0.3, af_name: str = 'None',
                 norm_in: bool = False, norm_out: bool = False):
        super().__init__()
        self.ws = []
        self.setup_af(af_name)
        # identity pass-through when normalization is disabled
        self.norm_in = torch.nn.LayerNorm(input_size, elementwise_affine=False) if norm_in else (lambda k: k)
        self.norm_out = torch.nn.LayerNorm(output_size, elementwise_affine=False) if norm_out else (lambda k: k)
        self.setup_ff(input_size, output_size, num_layers)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor):
        x = self.norm_in(x)
        last = len(self.ws) - 1
        for i, w in enumerate(self.ws):
            h = w(x)
            # final layer gets no activation, only dropout
            x = self.dropout(h if i == last else self.af(h))
        return self.norm_out(x)

    def setup_ff(self, input_size: int, output_size: int, num_layers=1):
        """Create and register the linear layers of the feed-forward stack."""
        for n in range(num_layers):
            layer = nn.Linear(output_size if n else input_size, output_size)
            # kaiming fan_in with 'linear' nonlinearity is equivalent to LeCun init
            torch.nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='linear')
            setattr(self, f'W_{n}', layer)  # attribute assignment registers the submodule
            self.ws.append(getattr(self, f'W_{n}'))

    def setup_af(self, af_name: str):
        """Resolve the activation function by name from torch.nn ('None'/None -> identity)."""
        if af_name is None or af_name == 'None':
            self.af = lambda k: k
            return
        try:
            self.af = getattr(nn, af_name)()
        except AttributeError as err:
            log.error(f"Can't find activation-function {af_name} in torch.nn")
            raise err
class MoleculeEncoder(Encoder):
    """
    Class for Molecule encoder: can be any class mapping Smiles to a Vector (preferable differentiable ;)
    """
    # NOTE(review): abstract placeholder -- it deliberately does not call
    # super().__init__(), so none of the Encoder/nn.Module machinery is set up.
    # Concrete subclasses (see FPMolEncoder) provide the actual encoder.
    def __init__(self, config):
        # configuration object (see ModelConfig), kept for subclasses
        self.config = config
class FPMolEncoder(Encoder):
    """
    Fingerprint Based Molecular encoder: maps SMILES (or precomputed
    fingerprint tensors) to an embedding in association space.
    """
    def __init__(self, config):
        super().__init__(input_size = config.hopf_input_size*config.hopf_num_heads,
                         output_size = config.hopf_asso_dim*config.hopf_num_heads,
                         num_layers = config.mol_encoder_layers,
                         dropout = config.dropout,
                         af_name = config.encoder_af,
                         norm_in = config.norm_input,
                         norm_out = config.norm_asso,
                        )
        # number of layers = self.config.mol_encoder_layers
        # layer-dimension = self.config.hopf_asso_dim
        # activation-function = self.config.af
        self.config = config

    def forward_smiles(self, list_of_smiles: list):
        """Encode a list of SMILES strings; returns a [bs, asso_dim*heads] tensor."""
        fp_tensor = self.convert_smiles_to_tensor(list_of_smiles)
        return self.forward(fp_tensor)

    def convert_smiles_to_tensor(self, list_of_smiles):
        """Compute fingerprints for the SMILES and move them to the configured device."""
        fps = convert_smiles_to_fp(list_of_smiles, fp_size=self.config.fp_size,
                                   which=self.config.fingerprint_type, radius=self.config.fp_radius)
        # BUGFIX: `np.float` was removed in NumPy 1.24 and raised an
        # AttributeError at runtime; cast straight to float32, which is the
        # dtype the subsequent `.to(dtype=torch.float)` produced anyway
        fps_tensor = torch.from_numpy(fps.astype(np.float32)).to(self.config.device)
        return fps_tensor
class TemplateEncoder(Encoder):
    """
    Class for Template encoder: can be any class mapping a Smarts-Reaction to a Vector (preferable differentiable ;)
    """
    def __init__(self, config):
        # assemble the feed-forward configuration once, then hand it to Encoder
        enc_kwargs = dict(
            input_size=config.hopf_input_size * config.hopf_num_heads,
            output_size=config.hopf_asso_dim * config.hopf_num_heads,
            num_layers=config.temp_encoder_layers,
            dropout=config.dropout,
            af_name=config.encoder_af,
            norm_in=config.norm_input,
            norm_out=config.norm_asso,
        )
        super().__init__(**enc_kwargs)
        self.config = config
        # number of layers / template fingerprint type /
        # random template threshold / reactant pooling
        if config.temp_encoder_layers == 0:
            # zero layers => no key projection: the raw template fingerprints
            # act as static keys and must already match association space
            print('No Key-Projection = Static Key/Templates')
            assert self.config.hopf_asso_dim == self.config.fp_size
            self.wks = []
class MHN(nn.Module):
    """
    MHN - modern Hopfield Network -- for Template relevance prediction

    Molecules (queries) and reaction templates (keys) are embedded by separate
    encoders; the scaled dot-product association, softmaxed over templates,
    ranks all templates for each input molecule. Additional MHN layers can be
    stacked recursively via config.hopf_n_layers.
    """
    def __init__(self, config=None, layer2weight=0.05, use_template_encoder=True):
        # config: ModelConfig instance; a default one is created when None
        # layer2weight: interpolation weight given to the recursive layer's output
        # use_template_encoder: set False for subclasses that don't encode templates
        super().__init__()
        if config:
            self.config = config
        else:
            self.config = ModelConfig()
        self.beta = self.config.hopf_beta  # inverse temperature of the softmax
        # hopf_num_heads
        self.mol_encoder = FPMolEncoder(self.config)
        if use_template_encoder:
            self.template_encoder = TemplateEncoder(self.config)
        self.W_v = None  # value projection; only created when layers are stacked
        self.layer2weight = layer2weight
        # more MHN layers -- added recursively
        if hasattr(self.config, 'hopf_n_layers'):
            # NOTE(review): mutates the caller's config dict in place; every
            # recursion level decrements hopf_n_layers by one
            di = self.config.__dict__
            di['hopf_n_layers'] -= 1
            if di['hopf_n_layers']>0:
                conf_wo_hopf_nlayers = ModelConfig(**di)
                self.layer = MHN(conf_wo_hopf_nlayers)
                if di['hopf_n_layers']!=0:
                    self.W_v = nn.Linear(self.config.hopf_asso_dim, self.config.hopf_input_size)
                    torch.nn.init.kaiming_normal_(self.W_v.weight, mode='fan_in', nonlinearity='linear') # eqiv to LeCun init
        self.softmax = torch.nn.Softmax(dim=1)
        self.lossfunction = nn.CrossEntropyLoss(reduction='none')#, weight=class_weights)
        self.pretrain_lossfunction = nn.BCEWithLogitsLoss(reduction='none')#, weight=class_weights)
        self.lr = self.config.lr
        if self.config.hopf_association_activation is None or (self.config.hopf_association_activation.lower()=='none'):
            self.af = lambda k: k  # identity
        else:
            self.af = getattr(nn, self.config.hopf_association_activation)()
        self.pooling_operation_head = getattr(torch, self.config.pooling_operation_head)
        self.X = None # templates projected to Hopfield Layer
        self.optimizer = getattr(torch.optim, self.config.optimizer)(self.parameters(), lr=self.lr)
        self.steps = 0
        self.hist = defaultdict(list)  # metric name -> list of logged values
        self.to(self.config.device)

    def set_templates(self, template_list, which='rdk', fp_size=None, radius=2, learnable=False, njobs=1, only_templates_in_batch=False):
        """Compute and store the template-fingerprint memory for a list of SMARTS templates."""
        self.template_list = template_list.copy()
        if fp_size is None:
            fp_size = self.config.fp_size
        if len(template_list)>=100000:
            # very large template sets: compute fingerprints in chunks of 30k
            # to bound peak memory
            import math
            print('batch-wise template_calculation')
            bs = 30000
            final_temp_emb = torch.zeros((len(template_list), fp_size)).float().to(self.config.device)
            for b in range(math.ceil(len(template_list)//bs)+1):
                # NOTE(review): self.template_list is temporarily truncated to
                # the current chunk and left at the last chunk afterwards
                self.template_list = template_list[bs*b:min(bs*(b+1), len(template_list))]
                templ_emb = self.update_template_embedding(which=which, fp_size=fp_size, radius=radius, learnable=learnable, njobs=njobs, only_templates_in_batch=only_templates_in_batch)
                final_temp_emb[bs*b:min(bs*(b+1), len(template_list))] = torch.from_numpy(templ_emb)
            self.templates = final_temp_emb
        else:
            self.update_template_embedding(which=which, fp_size=fp_size, radius=radius, learnable=learnable, njobs=njobs, only_templates_in_batch=only_templates_in_batch)
        self.set_templates_recursively()

    def set_templates_recursively(self):
        """Propagate the computed template memory to the stacked MHN layers."""
        if 'hopf_n_layers' in self.config.__dict__.keys():
            if self.config.hopf_n_layers >0:
                self.layer.templates = self.templates
                self.layer.set_templates_recursively()

    def update_template_embedding(self,fp_size=2048, radius=4, which='rdk', learnable=False, njobs=1, only_templates_in_batch=False):
        """Compute the template representation: product-side fingerprint minus
        half the reactant-side fingerprint; stores it and returns the np array."""
        print('updating template-embedding; (just computing the template-fingerprint and using that)')
        bs = self.config.batch_size  # NOTE(review): unused in this method
        # product side of the reaction template ('>' split, index 0)
        split_template_list = [str(t).split('>')[0].split('.') for t in self.template_list]
        templates_np = convert_smiles_to_fp(split_template_list, is_smarts=True, fp_size=fp_size, radius=radius, which=which, njobs=njobs)
        # reactant side of the reaction template ('>' split, index -1)
        split_template_list = [str(t).split('>')[-1].split('.') for t in self.template_list]
        reactants_np = convert_smiles_to_fp(split_template_list, is_smarts=True, fp_size=fp_size, radius=radius, which=which, njobs=njobs)
        template_representation = templates_np-(reactants_np*0.5)
        if learnable:
            # templates become trainable parameters of the model
            self.templates = torch.nn.Parameter(torch.from_numpy(template_representation).float(), requires_grad=True).to(self.config.device)
            self.register_parameter(name='templates', param=self.templates)
        else:
            if only_templates_in_batch:
                # keep on CPU as numpy; evaluate() streams chunks to the device
                self.templates_np = template_representation
            else:
                self.templates = torch.from_numpy(template_representation).float().to(self.config.device)
        return template_representation

    def np_fp_to_tensor(self, np_fp):
        """Convert a numpy fingerprint array to a float tensor on the model device."""
        return torch.from_numpy(np_fp.astype(np.float64)).to(self.config.device).float()

    def masked_loss_fun(self, loss_fun, h_out, ys_batch):
        """Mean loss over entries whose label != -1 (-1 marks invalid/missing labels)."""
        if loss_fun == self.BCEWithLogitsLoss:
            # NOTE(review): self.BCEWithLogitsLoss is never defined on this
            # class, so reaching this line raises AttributeError; presumably
            # self.pretrain_lossfunction was meant -- confirm before use
            mask = (ys_batch != -1).float()
            ys_batch = ys_batch.float()
        else:
            mask = (ys_batch.long() != -1).long()
        mask_sum = int(mask.sum().cpu().numpy())
        if mask_sum == 0:
            # no valid labels in this batch
            return 0
        ys_batch = ys_batch * mask
        loss = (loss_fun(h_out, ys_batch * mask) * mask.float()).sum() / mask_sum # only mean from non -1
        return loss

    def compute_losses(self, out, ys_batch, head_loss_weight=None):
        """Cross-entropy on template indices; BCE when ys is a multi-hot matrix
        with num_templates columns (pretraining mode)."""
        if len(ys_batch.shape)==2:
            if ys_batch.shape[1]==self.config.num_templates: # it is in pretraining_mode
                loss = self.pretrain_lossfunction(out, ys_batch.float()).mean()
            else:
                # legacy from policyNN
                loss = self.lossfunction(out, ys_batch[:, 2]).mean() # WARNING: HEAD4 Reaction Template is ys[:,2]
        else:
            loss = self.lossfunction(out, ys_batch).mean()
        return loss

    def forward_smiles(self, list_of_smiles, templates=None):
        """Convenience wrapper: encode SMILES to fingerprints, then rank templates."""
        state_tensor = self.mol_encoder.convert_smiles_to_tensor(list_of_smiles)
        return self.forward(state_tensor, templates=templates)

    def forward(self, m, templates=None):
        """
        m: molecule in the form batch x fingerprint
        templates: None or newly given templates if not instanciated
        returns logits ranking the templates for each molecule
        """
        #states_emb = self.fcfe(state_fp)
        bs = m.shape[0] #batch_size
        #templates = self.temp_emb(torch.arange(0,2000).long())
        if (templates is None) and (self.X is None) and (self.templates is None):
            # NOTE(review): if set_templates() was never called, self.templates
            # does not exist at all -> AttributeError instead of this message
            raise Exception('Either pass in templates, or init templates by runnting clf.set_templates')
        n_temp = len(templates) if templates is not None else len(self.templates)
        if self.training or (templates is None) or (self.X is not None):
            templates = templates if templates is not None else self.templates
            X = self.template_encoder(templates)
        else:
            X = self.X # precomputed from last forward run
        Xi = self.mol_encoder(m)
        Xi = Xi.view(bs, self.config.hopf_num_heads, self.config.hopf_asso_dim) # [bs, H, A]
        X = X.view(1, n_temp, self.config.hopf_asso_dim, self.config.hopf_num_heads) #[1, T, A, H]
        # contract over association dim (and the singleton batch dim of X)
        XXi = torch.tensordot(Xi, X, dims=[(2,1), (2,0)]) # AxA -> [bs, T, H]
        # pooling over heads
        if self.config.hopf_num_heads<=1:
            #QKt_pooled = QKt
            XXi = XXi[:,:,0] #torch.squeeze(QKt, dim=2)
        else:
            XXi = self.pooling_operation_head(XXi, dim=2) # default is max pooling over H [bs, T]
            if (self.config.pooling_operation_head =='max') or (self.config.pooling_operation_head =='min'):
                XXi = XXi[0] #max and min also return the indices =S
        out = self.beta*XXi # [bs, T, H] # softmax over dim=1 #pooling_operation_head
        # Hopfield retrieval: weighted sum of template embeddings
        self.xinew = self.softmax(out)@X.view(n_temp, self.config.hopf_asso_dim) # [bs,T]@[T,emb] -> [bs,emb]
        if self.W_v:
            # call layers recursive
            hopfout = self.W_v(self.xinew) # [bs,emb]@[emb,hopf_inp] --> [bs, hopf_inp]
            # TODO check if using x_pooled or if not going through mol_encoder again
            hopfout = hopfout + m # skip-connection
            # give it to the next layer
            out2 = self.layer.forward(hopfout) #templates=self.W_v(self.K)
            # interpolate this layer's logits with the next layer's
            out = out*(1-self.layer2weight)+out2*self.layer2weight
        return out

    def train_from_np(self, Xs, targets, ys, is_smiles=False, epochs=2, lr=0.001, bs=32,
                      permute_batches=False, shuffle=True, optimizer=None,
                      use_dataloader=True, verbose=False,
                      wandb=None, scheduler=None, only_templates_in_batch=False):
        """
        Xs in the form sample x states
        targets
        ys in the form sample x [y_h1, y_h2, y_h3, y_h4]

        Trains in-place for `epochs` epochs and returns the optimizer used.
        NOTE: the `wandb` parameter shadows the module-level wandb import.
        """
        self.train()
        if optimizer is None:
            try:
                self.optimizer = getattr(torch.optim, self.config.optimizer)(self.parameters(), lr=self.lr if lr is None else lr)
            except AttributeError as err:
                # NOTE(review): `config` is undefined here (NameError in the
                # error path); presumably self.config.optimizer was meant
                log.error(f"Can't find optimizer {config.optimizer} in torch.optim")
                raise err
            optimizer = self.optimizer
        dataset = ChemRXNDataset(Xs, targets, ys, is_smiles=is_smiles,
                                 fp_size=self.config.fp_size, fingerprint_type=self.config.fingerprint_type)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=shuffle, sampler=None,
                                                 batch_sampler=None, num_workers=0, collate_fn=None,
                                                 pin_memory=False, drop_last=False, timeout=0,
                                                 worker_init_fn=None)
        for epoch in range(epochs): # loop over the dataset multiple times
            running_loss = 0.0
            running_loss_dict = defaultdict(int)
            # NOTE(review): batch_order is a leftover from a pre-DataLoader
            # implementation and is never used below
            batch_order = range(0, len(Xs), bs)
            if permute_batches:
                batch_order = np.random.permutation(batch_order)
            for step, s in tqdm(enumerate(dataloader),mininterval=2):
                batch = [b.to(self.config.device, non_blocking=True) for b in s]
                Xs_batch, target_batch, ys_batch = batch
                # zero the parameter gradients
                optimizer.zero_grad()
                # forward + backward + optimize
                out = self.forward(Xs_batch)
                total_loss = self.compute_losses(out, ys_batch)
                loss_dict = {'CE_loss': total_loss}
                total_loss.backward()
                optimizer.step()
                if scheduler:
                    scheduler.step()
                self.steps += 1
                # print statistics
                for k in loss_dict:
                    running_loss_dict[k] += loss_dict[k].item()
                try:
                    running_loss += total_loss.item()
                except:
                    running_loss += 0
                rs = min(100,len(Xs)//bs) # reporting/logging steps
                if step % rs == (rs-1): # print every 2000 mini-batches
                    if verbose: print('[%d, %5d] loss: %.3f' %
                          (epoch + 1, step + 1, running_loss / rs))
                    self.hist['step'].append(self.steps)
                    self.hist['loss'].append(running_loss/rs)
                    # NOTE(review): 'trianing_running_loss' is a typo but kept
                    # as-is -- it is a history/wandb key consumers may rely on
                    self.hist['trianing_running_loss'].append(running_loss/rs)
                    [self.hist[k].append(running_loss_dict[k]/rs) for k in running_loss_dict]
                    if wandb:
                        wandb.log({'trianing_running_loss': running_loss / rs})
                    running_loss = 0.0
                    running_loss_dict = defaultdict(int)
        if verbose: print('Finished Training')
        return optimizer

    def evaluate(self, Xs, targets, ys, split='test', is_smiles=False, bs = 32, shuffle=False, wandb=None, only_loss=False):
        """Run the model over a dataset without gradients; logs loss and top-k
        accuracies into self.hist and returns per-sample template scores
        (softmax, or sigmoid in pretraining mode) as a float16 array."""
        self.eval()
        y_preds = np.zeros( (ys.shape[0], self.config.num_templates), dtype=np.float16)
        loss_metrics = defaultdict(int)
        new_hist = defaultdict(float)
        with torch.no_grad():
            dataset = ChemRXNDataset(Xs, targets, ys, is_smiles=is_smiles,
                                     fp_size=self.config.fp_size, fingerprint_type=self.config.fingerprint_type)
            dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=shuffle, sampler=None,
                                                     batch_sampler=None, num_workers=0, collate_fn=None,
                                                     pin_memory=False, drop_last=False, timeout=0,
                                                     worker_init_fn=None)
            #for step, s in eoutputs = self.forward(batch[0], batchnumerate(range(0, len(Xs), bs)):
            for step, batch in enumerate(dataloader):#
                batch = [b.to(self.config.device, non_blocking=True) for b in batch]
                ys_batch = batch[2]
                if hasattr(self, 'templates_np'):
                    # templates kept as numpy on CPU (only_templates_in_batch):
                    # score them in 10 chunks to bound device memory
                    outputs = []
                    for ii in range(10):
                        tlen = len(self.templates_np)
                        i_tlen = tlen//10
                        templates = torch.from_numpy(self.templates_np[(i_tlen*ii):min(i_tlen*(ii+1), tlen)]).float().to(self.config.device)
                        outputs.append( self.forward(batch[0], templates = templates ) )
                    # NOTE(review): the chunks score disjoint template subsets,
                    # so concatenating along dim=0 (batch) looks suspicious --
                    # dim=1 (templates) seems intended; any remainder beyond
                    # 10*(tlen//10) templates is also skipped. Verify.
                    outputs = torch.cat(outputs, dim=0)
                else:
                    outputs = self.forward(batch[0])
                loss = self.compute_losses(outputs, ys_batch, None)
                # not quite right because in every batch there might be different number of valid samples
                weight = 1/len(batch[0])#len(Xs[s:min(s + bs, len(Xs))]) / len(Xs)
                loss_metrics['loss'] += (loss.item())
                if len(ys.shape)>1:
                    outputs = self.softmax(outputs) if not (ys.shape[1]==self.config.num_templates) else torch.sigmoid(outputs)
                else:
                    outputs = self.softmax(outputs)
                outputs_np = [None if o is None else o.to('cpu').numpy().astype(np.float16) for o in outputs]
                if not only_loss:
                    ks = [1, 2, 3, 4, 5, 10, 20, 30, 40, 50, 100]
                    topkacc, mrocc = top_k_accuracy(ys_batch, outputs, k=ks, ret_arocc=True, ret_mrocc=False)
                    # mrocc -- median rank of correct choice
                    # NOTE(review): ret_arocc=True / ret_mrocc=False, so the
                    # second return value is presumably the *average* rank,
                    # despite the variable name -- confirm in utils.top_k_accuracy
                    for k, tkacc in zip(ks, topkacc):
                        #iterative average update
                        new_hist[f't{k}_acc_{split}'] += (tkacc-new_hist[f't{k}_acc_{split}']) / (step+1)
                        # todo weight by batch-size
                    new_hist[f'meanrank_{split}'] = mrocc
                y_preds[step*bs : min((step+1)*bs,len(y_preds))] = outputs_np
            new_hist[f'steps_{split}'] = (self.steps)
            new_hist[f'loss_{split}'] = (loss_metrics['loss'] / (step+1))
            for k in new_hist:
                self.hist[k].append(new_hist[k])
            if wandb:
                wandb.log(new_hist)
        self.hist[f'loss_{split}'].append(loss_metrics[f'loss'] / (step+1))
        return y_preds

    def save_hist(self, prefix='', postfix=''):
        """Dump the training/eval history dict to data/hist/ and return the file name."""
        HIST_PATH = 'data/hist/'
        if not os.path.exists(HIST_PATH):
            os.mkdir(HIST_PATH)
        fn_hist = HIST_PATH+prefix+postfix+'.csv'
        with open(fn_hist, 'w') as fh:
            print(dict(self.hist), file=fh)
        return fn_hist

    def save_model(self, prefix='', postfix='', name_as_conf=False):
        """Save the state_dict to data/model/; optionally encode the config in the name."""
        MODEL_PATH = 'data/model/'
        if not os.path.exists(MODEL_PATH):
            os.mkdir(MODEL_PATH)
        if name_as_conf:
            confi_str = str(self.config.__dict__.values()).replace("'","").replace(': ','_').replace(', ',';')
        else:
            confi_str = ''
        model_name = prefix+confi_str+postfix+'.pt'
        torch.save(self.state_dict(), MODEL_PATH+model_name)
        return MODEL_PATH+model_name

    def plot_loss(self):
        """Plot the recorded training loss (delegates to plotutils)."""
        plot_loss(self.hist)

    def plot_topk(self, sets=['train', 'valid', 'test'], with_last = 2):
        """Plot top-k accuracy curves for the given splits (delegates to plotutils)."""
        plot_topk(self.hist, sets=sets, with_last = with_last)

    def plot_nte(self, last_cpt=1, dataset='Sm', include_bar=True):
        """Plot number-of-training-examples analysis (delegates to plotutils)."""
        plot_nte(self.hist, dataset=dataset, last_cpt=last_cpt, include_bar=include_bar)
class SeglerBaseline(MHN):
    """FFNN baseline -- only the Molecule Encoder plus an output projection.

    Ranks templates directly from the molecule embedding, without a template
    encoder or Hopfield association.
    """
    def __init__(self, config=None):
        # BUGFIX: config=None previously crashed immediately on attribute
        # access; build a usable default config instead (the encoder layers
        # below require mol_encoder_layers / encoder_af to be present)
        if config is None:
            config = ModelConfig(mol_encoder_layers=1, encoder_af='ReLU')
        config.template_fp_type = 'none'
        config.temp_encoder_layers = 0
        super().__init__(config, use_template_encoder=False)
        # projection from association space to one logit per template
        self.W_out = torch.nn.Linear(config.hopf_asso_dim, config.num_templates)
        self.optimizer = getattr(torch.optim, self.config.optimizer)(self.parameters(), lr=self.lr)
        self.steps = 0
        self.hist = defaultdict(list)
        self.to(self.config.device)

    def forward(self, m, templates=None):
        """
        m: molecule in the form batch x fingerprint
        templates: won't be used in this case
        returns logits ranking the templates for each molecule
        """
        Xi = self.mol_encoder(m)
        # the encoder skips the activation on its last layer; apply it here
        Xi = self.mol_encoder.af(Xi)
        out = self.W_out(Xi)  # [bs, T]
        return out
class StaticQK(MHN):
    """ Static QK baseline - beware to have the same fingerprint for mol_encoder as for the template_encoder (fp2048 r4 rdk by default)"""
    def __init__(self, config=None):
        if config:
            self.config = config
        else:
            self.config = ModelConfig()
        super().__init__(config)
        self.fp_size = 2048
        self.fingerprint_type = 'rdk'
        self.beta = 1  # no temperature scaling for the static baseline

    def update_template_embedding(self, which='rdk', fp_size=2048, radius=4, learnable=False, njobs=1, only_templates_in_batch=False):
        """Compute static template fingerprints (max-pooled over reactant parts).

        BUGFIX: accepts (and forwards) the njobs/only_templates_in_batch kwargs
        that MHN.set_templates passes, and returns the numpy representation so
        the batched path in set_templates works.
        """
        bs = self.config.batch_size
        split_template_list = [t.split('>>')[0].split('.') for t in self.template_list]
        templates_np = convert_smiles_to_fp(split_template_list,
                                            is_smarts=True, fp_size=fp_size,
                                            radius=radius, which=which, njobs=njobs).max(1)
        self.templates = torch.from_numpy(templates_np).float().to(self.config.device)
        return templates_np

    def forward(self, m, templates=None):
        """
        m: molecule fingerprints [bs, emb]; returns similarity logits [bs, T].
        """
        #states_emb = self.fcfe(state_fp)
        bs = m.shape[0] #batch_size
        Xi = m #[bs, emb]
        X = self.templates #[T, emb])
        XXi = Xi@X.T # [bs, T]
        # normalize by the number of bits set per template
        # BUGFIX: normalize with the stored templates; the `templates` argument
        # defaults to None and previously caused an AttributeError here
        t_sum = X.sum(1) #[T]
        t_sum = t_sum.view(1,-1).expand(bs, -1) #[bs, T]
        XXi = XXi / t_sum
        # no scaling needed because this baseline is not trained
        out = self.beta*XXi # [bs, T] # softmax over dim=1
        return out
class Retrosim(StaticQK):
    """ Retrosim-like baseline only for template relevance prediction """
    def fit_with_train(self, X_fp_train, y_train):
        """Store training fingerprints as the 'template' memory and build the
        sample-to-template averaging matrix from the training labels."""
        self.templates = torch.from_numpy(X_fp_train).float().to(self.config.device)
        # one-hot map of shape [train_samples, num_templates]
        onehot = torch.nn.functional.one_hot(torch.from_numpy(y_train), self.config.num_templates).float()
        col_norm = onehot.sum(0)
        col_norm[col_norm == 0] = 1  # guard against division by zero for unused templates
        # column-normalized so the dot product averages over each template's samples
        self.sample2acttemplate = (onehot / col_norm).to(self.config.device)

    def forward(self, m, templates=None):
        """Similarity to training samples, projected onto their ground-truth templates."""
        sample_scores = super().forward(m, templates=templates)  # [bs, train_samples]
        # map sample similarities to actual templates
        return sample_scores @ self.sample2acttemplate