import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
#from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets
from .dataset import ClrDataset, re_train_dataset, re_eval_dataset
from functools import partial
from rdkit import RDConfig
from rdkit import Chem
from rdkit.Chem import AllChem, rdMolDescriptors, Descriptors
from matchms.Fragments import Fragments
import matchms.filtering as msfilters
from matchms.importing import load_from_mgf
import warnings
from torch_geometric.data import Data, DataLoader, Batch
#from torch_geometric.data import Data
warnings.filterwarnings('ignore')
import json
import random
import ast
from rdkit.Chem.rdchem import BondType as BT
from rdkit import RDLogger
from toolz.sandbox import unzip

RDLogger.DisableLog('rdApp.*')

# Vocabulary tables: each atom/bond property is turned into the index of its
# value inside one of these lists, so the model can embed it as a category id.
ATOM_LIST = list(range(1, 119))  # atomic numbers 1..118
CHIRALITY_LIST = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.rdchem.ChiralType.CHI_OTHER,
]
HYBRID_TYPE = [
    Chem.rdchem.HybridizationType.SP,
    Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP2D,
    Chem.rdchem.HybridizationType.SP3,
    Chem.rdchem.HybridizationType.SP3D,
    Chem.rdchem.HybridizationType.SP3D2,
    Chem.rdchem.HybridizationType.UNSPECIFIED,
    Chem.rdchem.HybridizationType.S,
]
VALENCE_LIST = list(range(0, 8))  # total valence, clipped to 7 in MolToGraph
DRGREE_LIST = list(range(0, 5))   # atom degree, clipped to 4 in MolToGraph
BOND_LIST = [BT.SINGLE, BT.DOUBLE, BT.TRIPLE, BT.AROMATIC]
BONDDIR_LIST = [
    Chem.rdchem.BondDir.NONE,
    Chem.rdchem.BondDir.ENDUPRIGHT,
    Chem.rdchem.BondDir.ENDDOWNRIGHT,
]


def collate_func(input_list):
    """Collate (graph, mzs, intensities, num_peaks) samples into one batch.

    Args:
        input_list: iterable of 4-tuples — a torch_geometric ``Data`` graph,
            two 1-D numpy arrays (m/z values and intensities) and the peak
            count.

    Returns:
        Tuple of (Batch of graphs, padded m/z FloatTensor [B, Lmax],
        padded intensity FloatTensor [B, Lmax], LongTensor of peak counts).
        Spectra shorter than the longest in the batch are zero-padded.
    """
    graphs, mzs, intens, num_peaks = map(list, unzip(input_list))
    num_peaks = torch.LongTensor(num_peaks)
    mzs = [torch.from_numpy(spec_mz).float() for spec_mz in mzs]
    intens = [torch.from_numpy(spec_intens).float() for spec_intens in intens]
    mzs_tensors = torch.nn.utils.rnn.pad_sequence(
        mzs, batch_first=True, padding_value=0
    )
    intens_tensors = torch.nn.utils.rnn.pad_sequence(
        intens, batch_first=True, padding_value=0
    )
    batched_graphs = Batch.from_data_list(graphs)
    return batched_graphs, mzs_tensors, intens_tensors, num_peaks


'''def valid_collate_func(x):
    ms, formula = zip(*x)
    return ms, formula
'''


def valid_collate_func(x):
    """Transpose a list of sample tuples into per-field tuples.

    NOTE(review): this returns the (lazy) ``zip`` object itself, not an
    unpacked tuple as the commented-out variant above did — callers must
    iterate it exactly once. Kept as-is for backward compatibility.
    """
    ms = zip(*x)
    return ms


def MolToGraph(smiles):
    """Convert a SMILES string into a torch_geometric ``Data`` graph.

    Hydrogens are made explicit before featurization.

    Node features (``x``, shape [N, 5], long): atom-type index, chirality
    index, hybridization index, total valence (clipped to 7) and degree
    (clipped to 4), each indexed into the module-level vocabulary lists.
    Edges (``edge_index`` [2, 2M], ``edge_attr`` [2M, 2]): every bond is
    stored in both directions with (bond-type index, bond-direction index).

    Raises:
        ValueError: if RDKit cannot parse ``smiles`` (previously this
            crashed with an opaque error inside ``Chem.AddHs(None)``).
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError("invalid SMILES: %s" % smiles)
    mol = Chem.AddHs(mol)

    type_idx = []
    chirality_idx = []
    hybrid_type_idx = []
    valence_idx = []
    degree_idx = []
    for atom in mol.GetAtoms():
        type_idx.append(ATOM_LIST.index(atom.GetAtomicNum()))
        chirality = atom.GetChiralTag()
        # unknown chirality tags fall back to the CHI_OTHER bucket
        if chirality not in CHIRALITY_LIST:
            chirality = Chem.rdchem.ChiralType.CHI_OTHER
        chirality_idx.append(CHIRALITY_LIST.index(chirality))
        hybrid_type_idx.append(HYBRID_TYPE.index(atom.GetHybridization()))
        # rare high valences/degrees are clipped into the last bucket
        valence_idx.append(VALENCE_LIST.index(min(atom.GetTotalValence(), 7)))
        degree_idx.append(DRGREE_LIST.index(min(atom.GetDegree(), 4)))

    x = torch.tensor(
        [type_idx, chirality_idx, hybrid_type_idx, valence_idx, degree_idx],
        dtype=torch.long,
    ).t().contiguous()

    row, col, edge_feat = [], [], []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        row += [start, end]
        col += [end, start]
        feat = [
            BOND_LIST.index(bond.GetBondType()),
            BONDDIR_LIST.index(bond.GetBondDir()),
        ]
        # same features for both directions (was copy-pasted twice before)
        edge_feat += [feat, feat]

    edge_index = torch.tensor([row, col], dtype=torch.long)
    # reshape keeps a consistent [E, 2] shape even for bond-less molecules
    edge_attr = torch.tensor(
        np.array(edge_feat, dtype=np.int64).reshape(-1, 2), dtype=torch.long
    )
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
def remove_peaks(mz, peak_intensities, threshold, percentage):
    """Randomly zero out a fraction of the low-intensity peaks.

    A peak is a removal candidate when its intensity is below ``threshold``;
    ``percentage`` of the candidates (rounded down) get intensity 0.  Works
    on a copy — the original version mutated the caller's array in place,
    corrupting the spectrum it was augmenting.

    Args:
        mz: peak positions (returned unchanged).
        peak_intensities: sequence of intensities.
        threshold: intensity below which a peak may be removed.
        percentage: fraction of the candidates to zero out, in [0, 1].

    Returns:
        (mz, intensities) where intensities is a float numpy array copy with
        the selected entries set to 0.
    """
    intensities = np.array(peak_intensities, dtype=float)  # defensive copy
    low_indices = [
        i for i, intensity in enumerate(intensities) if intensity < threshold
    ]
    num_to_remove = int(len(low_indices) * percentage)
    for i in random.sample(low_indices, num_to_remove):
        intensities[i] = 0
    return mz, intensities


def enhance_peak_intensities(mz, peak_intensities, jitter_range):
    """Jitter every intensity by a random relative factor.

    Each intensity becomes ``intensity * (1 + u)`` with
    ``u ~ Uniform(-jitter_range, jitter_range)``.

    Returns:
        (mz, jittered intensities as a list); ``mz`` is returned unchanged.
    """
    enhanced_intensities = []
    for intensity in peak_intensities:
        jitter = random.uniform(-jitter_range, jitter_range)
        enhanced_intensity = intensity + (intensity * jitter)
        enhanced_intensities.append(enhanced_intensity)
    return mz, enhanced_intensities


def peak_addition(mz, peak_intensities, noise_max):
    """Inject up to ``noise_max - 1`` small noise peaks at free m/z positions.

    Noise positions are drawn from a 0.01-wide m/z grid spanning
    [min(mz), max(mz)) that carries no real peak; noise intensities are
    uniform in [0, 0.01).

    Fixes vs. the original: sampling is without replacement (no duplicated
    noise m/z), and the request is clamped to the number of available grid
    positions instead of crashing on narrow spectra.

    Returns:
        (mz, intensities) numpy arrays with the noise peaks appended
        (unsorted — callers re-sort).
    """
    n_noise_peaks = np.random.randint(0, noise_max)
    max_mz = int(max(mz) * 100)
    min_mz = int(min(mz) * 100)
    # grid positions (0.01 resolution) not already occupied by a real peak
    idx_no_peaks = np.setdiff1d([i / 100 for i in range(min_mz, max_mz)], mz)
    n_noise_peaks = min(n_noise_peaks, len(idx_no_peaks))
    idx_noise_peaks = np.random.choice(idx_no_peaks, n_noise_peaks, replace=False)
    mz = np.concatenate((mz, idx_noise_peaks))
    new_values = 0.01 * np.random.random(len(idx_noise_peaks))
    peak_intensities = np.concatenate((peak_intensities, new_values))
    return mz, peak_intensities


def data_augmentation(spectrum):
    """Augment a matchms spectrum: drop, jitter and add peaks, then normalize.

    Pipeline: remove_peaks -> enhance_peak_intensities -> peak_addition,
    then discard the peaks that were zeroed out, re-sort by m/z, update the
    'num_peaks' metadata and renormalize the intensities.

    Returns:
        The augmented spectrum (modified in place and also returned).
    """
    mz_initial = spectrum.mz
    intens_initial = spectrum.intensities
    mz_rp, peak_rp = remove_peaks(
        mz_initial, intens_initial, threshold=0.001, percentage=0.2
    )
    mz_enhance, peak_enhance = enhance_peak_intensities(
        mz_rp, peak_rp, jitter_range=0.4
    )
    mz_add, peak_add = peak_addition(mz_enhance, peak_enhance, noise_max=10)
    mz_add = np.asarray(mz_add, dtype=float)
    peak_add = np.asarray(peak_add, dtype=float)
    # BUG FIX: removed peaks are marked by zero *intensity* (see
    # remove_peaks); the original filtered on mz == 0, which never matched,
    # so zeroed peaks were silently kept.
    keep = peak_add != 0
    mz_f = mz_add[keep]
    peak_f = peak_add[keep]
    order = np.argsort(mz_f)
    mz_f = mz_f[order]
    peak_f = peak_f[order]
    spectrum.set('num_peaks', str(len(mz_f)))
    spectrum.peaks = Fragments(mz=mz_f, intensities=peak_f)
    spectrum = msfilters.normalize_intensities(spectrum)
    return spectrum
def graph_spec2vec_calculation(smiles, spectra):
    """Pair each SMILES with its MS/MS spectrum as (graph, spectrum) rows.

    Each spectrum is filtered to 3..300 peaks; samples whose graph
    construction or peak filtering fails are reported and skipped.

    Args:
        smiles: sequence of SMILES strings.
        spectra: parallel sequence of matchms spectra.

    Returns:
        DataFrame with columns ['Graph', 'MS2'] and a fresh RangeIndex.
    """
    print("calculating molecular graphs")
    rows = []  # collect rows first: per-row df.loc appends are O(n^2)
    for i in tqdm(range(len(smiles))):
        smi = smiles[i]  # bound before try so the except message can use it
        try:
            v_d = MolToGraph(smi)
            spectrum = msfilters.reduce_to_number_of_peaks(
                spectra[i], n_required=3, n_max=300
            )
            if spectrum is not None:
                rows.append([v_d, spectrum])
        except Exception:  # narrowed from bare except: keeps Ctrl-C working
            print("SMILES", smi, "calculation failure")
    df = pd.DataFrame(rows, columns=['Graph', 'MS2'])
    print("Calculated", len(df), "molecular graph-mass spectrometry pairs")
    return df


def graph_spec2vec_valid_calculation(smiles, spectra, formulas):
    """Pair each SMILES with its spectrum and molecular formula.

    Failures are skipped silently (best-effort validation set).

    Returns:
        DataFrame with columns ['Graph', 'MS2', 'formula'].
    """
    print("calculating molecular graphs")
    rows = []
    for i in tqdm(range(len(smiles))):
        try:
            v_d = MolToGraph(smiles[i])
            rows.append([v_d, spectra[i], formulas[i]])
        except Exception:  # best effort: drop samples that fail featurization
            pass
    df = pd.DataFrame(rows, columns=['Graph', 'MS2', 'formula'])
    print("Calculated", len(df), "molecular graph-mass spectrometry pairs")
    return df


def graph_calculation(smiles, formulas):
    """Build molecular graphs for (SMILES, formula) pairs; failures skipped.

    Returns:
        DataFrame with columns ['Graph', 'formula'].
    """
    print("calculating molecular graphs")
    rows = []
    for i in tqdm(range(len(smiles))):
        try:
            rows.append([MolToGraph(smiles[i]), formulas[i]])
        except Exception:  # best effort: drop samples that fail featurization
            pass
    df = pd.DataFrame(rows, columns=['Graph', 'formula'])
    print("Calculated", len(df), "molecular graphs")
    return df


class DataSetWrapper(object):
    """Builds DistributedSampler-backed train/valid loaders for DDP training."""

    def __init__(self, world_size, rank, batch_size, num_workers, valid_size,
                 s, ms2_file, smi_file):
        self.world_size = world_size  # number of DDP processes
        self.rank = rank              # rank of this process
        self.batch_size = batch_size
        self.num_workers = num_workers  # NOTE(review): not forwarded to the loaders below
        self.valid_size = valid_size    # fraction of samples held out for validation
        self.s = s
        self.ms2_file = ms2_file  # .mgf file with the MS/MS spectra
        self.smi_file = smi_file  # .npy file with the SMILES strings

    def get_data_loaders(self):
        """Load SMILES/spectra, random-split train/valid, return both loaders."""
        self.smiles = np.load(self.smi_file).tolist()
        self.ms2 = list(load_from_mgf(self.ms2_file))
        # obtain training indices that will be used for validation
        num_train = len(self.smiles)
        indices = list(range(num_train))
        np.random.shuffle(indices)
        split = int(np.floor(self.valid_size * num_train))
        train_idx, valid_idx = indices[split:], indices[:split]
        self.train_smiles = [self.smiles[i] for i in train_idx]
        self.train_ms2 = [self.ms2[i] for i in train_idx]
        self.valid_smiles = [self.smiles[i] for i in valid_idx]
        self.valid_ms2 = [self.ms2[i] for i in valid_idx]
        self.train_graph_file = graph_spec2vec_calculation(
            self.train_smiles, self.train_ms2
        )
        self.valid_graph_file = graph_spec2vec_calculation(
            self.valid_smiles, self.valid_ms2
        )
        train_dataset = ClrDataset(
            self.train_graph_file, self.train_graph_file.index.values
        )
        valid_dataset = ClrDataset(
            self.valid_graph_file, self.valid_graph_file.index.values
        )
        return self.get_train_validation_data_loaders(train_dataset, valid_dataset)

    def get_train_validation_data_loaders(self, train_dataset, valid_dataset):
        """Wrap the datasets in DataLoaders sharded by DistributedSampler.

        The sampler does the shuffling (train only), so shuffle=False on the
        loaders themselves.
        """
        train_sampler = DistributedSampler(
            train_dataset, num_replicas=self.world_size, rank=self.rank,
            shuffle=True,
        )
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=self.batch_size, sampler=train_sampler,
            shuffle=False, collate_fn=collate_func,
        )
        valid_sampler = DistributedSampler(
            valid_dataset, num_replicas=self.world_size, rank=self.rank,
            shuffle=False,
        )
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=self.batch_size, sampler=valid_sampler,
            shuffle=False, collate_fn=collate_func,
        )
        return train_loader, valid_loader


class DataSetWrapper_noddp(object):
    """Single-process variant of DataSetWrapper (no DistributedSampler)."""

    def __init__(self, batch_size, num_workers, valid_size, s, ms2_file,
                 smi_file):
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.valid_size = valid_size  # fraction of samples held out for validation
        self.s = s
        self.ms2_file = ms2_file  # .mgf file with the MS/MS spectra
        self.smi_file = smi_file  # .npy file with the SMILES strings

    def get_data_loaders(self):
        """Load SMILES/spectra, random-split train/valid, return both loaders."""
        self.smiles = np.load(self.smi_file).tolist()
        self.ms2 = list(load_from_mgf(self.ms2_file))
        # obtain training indices that will be used for validation
        num_train = len(self.smiles)
        indices = list(range(num_train))
        np.random.shuffle(indices)
        split = int(np.floor(self.valid_size * num_train))
        train_idx, valid_idx = indices[split:], indices[:split]
        self.train_smiles = [self.smiles[i] for i in train_idx]
        self.train_ms2 = [self.ms2[i] for i in train_idx]
        self.valid_smiles = [self.smiles[i] for i in valid_idx]
        self.valid_ms2 = [self.ms2[i] for i in valid_idx]
        self.train_graph_file = graph_spec2vec_calculation(
            self.train_smiles, self.train_ms2
        )
        self.valid_graph_file = graph_spec2vec_calculation(
            self.valid_smiles, self.valid_ms2
        )
        train_dataset = ClrDataset(
            self.train_graph_file, self.train_graph_file.index.values
        )
        valid_dataset = ClrDataset(
            self.valid_graph_file, self.valid_graph_file.index.values
        )
        return self.get_train_validation_data_loaders(train_dataset, valid_dataset)

    def get_train_validation_data_loaders(self, train_dataset, valid_dataset):
        """Plain DataLoaders; train drops the last partial batch.

        NOTE(review): shuffle=False means the train order is fixed across
        epochs (only the initial index split is random) — kept as-is to
        preserve behavior; confirm whether per-epoch shuffling was intended.
        """
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=self.batch_size,
            num_workers=self.num_workers, shuffle=False,
            collate_fn=collate_func, drop_last=True,
        )
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=self.batch_size,
            num_workers=self.num_workers, shuffle=False,
            collate_fn=collate_func, drop_last=False,
        )
        return train_loader, valid_loader