Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| """ | |
| Created on Wed Apr 27 10:43:40 2022 | |
| @author: ZNDX002 | |
| """ | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import Dataset | |
| import ast | |
| class ClrDataset(Dataset): | |
| """Contrastive Learning Representations Dataset.""" | |
| def __init__(self, | |
| file, | |
| list_IDs, | |
| transform=None): | |
| self.clr_frame = file | |
| self.list_IDs = list_IDs | |
| def __len__(self): | |
| return len(self.clr_frame) | |
| def __getitem__(self, idx): | |
| index = self.list_IDs[idx] | |
| v_d = self.clr_frame.loc[index,'Graph'] | |
| spec = self.clr_frame.loc[idx,'MS2'] | |
| spec_mz = spec.mz | |
| spec_intens = spec.intensities | |
| spec_mz = np.around(spec_mz, decimals=4) | |
| #spec_mz = torch.from_numpy(spec_mz).float() | |
| #spec_intens = torch.from_numpy(spec_intens).float() | |
| num_peak = len(spec_mz) | |
| #spec_mz = np.pad(spec_mz, (0, 300 - len(spec_mz)), mode='constant', constant_values=0) | |
| #spec_intens = np.pad(spec_intens, (0, 300 - len(spec_intens)), mode='constant', constant_values=0) | |
| return v_d,spec_mz,spec_intens,num_peak | |
| #return {'graph':v_d,'mz':spec_mz,'inten':spec_intens} | |
| class re_train_dataset(Dataset): | |
| def __init__(self, | |
| file, | |
| list_IDs, | |
| transform=None): | |
| self.clr_frame = file | |
| self.list_IDs = list_IDs | |
| def __len__(self): | |
| return len(self.clr_frame) | |
| def __getitem__(self, idx): | |
| index = self.list_IDs[idx] | |
| v_d = self.clr_frame.loc[index,'Graph'] | |
| spec = self.clr_frame.loc[index,'MS2'] | |
| #spec = np.array(ast.literal_eval(spec)) | |
| spec = torch.from_numpy(spec).to(torch.float32) | |
| return v_d,spec | |
| class re_eval_dataset(Dataset): | |
| def __init__(self, | |
| file, | |
| list_IDs, | |
| smiles_reference, | |
| transform=None): | |
| self.clr_frame = file | |
| self.list_IDs = list_IDs | |
| self.valid_formulas = list(self.clr_frame['formula']) | |
| self.smiles_reference = smiles_reference | |
| self.structures = list(self.clr_frame['Graph']) + list(self.smiles_reference['Graph']) | |
| self.spectra = list(self.clr_frame['MS2']) | |
| self.spec2smi = {} | |
| smi_id = 0 | |
| for spec_id, ann in enumerate(self.spectra): | |
| self.spec2smi[spec_id] = [] | |
| self.spec2smi[spec_id].append(smi_id) | |
| smi_id += 1 | |
| def __len__(self): | |
| return len(self.clr_frame) | |
| def __getitem__(self, idx): | |
| index = self.list_IDs[idx] | |
| spec = self.clr_frame.loc[index,'MS2'] | |
| spec = torch.from_numpy(spec).to(torch.float32) | |
| formula = self.clr_frame.loc[index,'formula'] | |
| return spec |