Spaces:
Sleeping
Sleeping
File size: 2,871 Bytes
603d88b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 27 10:43:40 2022
@author: ZNDX002
"""
import numpy as np
import torch
from torch.utils.data import Dataset
import ast
class ClrDataset(Dataset):
    """Contrastive Learning Representations Dataset.

    Pairs each molecule graph with its MS/MS spectrum, both looked up by the
    same row label.

    Args:
        file: DataFrame-like table with ``'Graph'`` and ``'MS2'`` columns,
            accessed via ``.loc[label, column]``.
        list_IDs: Sequence of row labels; dataset position ``idx`` maps to
            row label ``list_IDs[idx]``.
        transform: Accepted for API compatibility; currently not applied.
    """

    def __init__(self,
                 file,
                 list_IDs,
                 transform=None):
        self.clr_frame = file
        self.list_IDs = list_IDs
        # NOTE(review): stored but never applied in __getitem__.
        self.transform = transform

    def __len__(self):
        # Only positions that map through list_IDs can be served, so the
        # length is the number of IDs, not the size of the whole frame.
        return len(self.list_IDs)

    def __getitem__(self, idx):
        """Return ``(graph, mz, intensities, num_peaks)`` for position *idx*.

        ``mz`` is rounded to 4 decimal places; ``intensities`` is returned
        unchanged; ``num_peaks`` is the number of m/z values.
        """
        index = self.list_IDs[idx]
        v_d = self.clr_frame.loc[index, 'Graph']
        # Use the same row label as the graph. Looking up the raw position
        # ``idx`` would pair this graph with another row's spectrum whenever
        # list_IDs is not the identity mapping.
        spec = self.clr_frame.loc[index, 'MS2']
        # assumes spec exposes array-like ``.mz`` / ``.intensities``
        # attributes (e.g. a matchms Spectrum) -- TODO confirm
        spec_mz = np.around(spec.mz, decimals=4)
        spec_intens = spec.intensities
        num_peak = len(spec_mz)
        return v_d, spec_mz, spec_intens, num_peak
class re_train_dataset(Dataset):
    """Training dataset yielding ``(graph, spectrum_tensor)`` pairs.

    Args:
        file: DataFrame-like table with ``'Graph'`` and ``'MS2'`` columns,
            accessed via ``.loc[label, column]``.
        list_IDs: Sequence of row labels; dataset position ``idx`` maps to
            row label ``list_IDs[idx]``.
        transform: Accepted for API compatibility; currently not applied.
    """

    def __init__(self,
                 file,
                 list_IDs,
                 transform=None):
        self.clr_frame = file
        self.list_IDs = list_IDs
        # NOTE(review): stored but never applied in __getitem__.
        self.transform = transform

    def __len__(self):
        # Only positions that map through list_IDs can be served.
        return len(self.list_IDs)

    def __getitem__(self, idx):
        """Return ``(graph, spectrum)`` where spectrum is a float32 tensor."""
        index = self.list_IDs[idx]
        v_d = self.clr_frame.loc[index, 'Graph']
        spec = self.clr_frame.loc[index, 'MS2']
        # as_tensor matches from_numpy(...).to(float32) for NumPy arrays
        # (memory is shared when already float32). It also accepts plain
        # Python sequences.
        spec = torch.as_tensor(spec, dtype=torch.float32)
        return v_d, spec
class re_eval_dataset(Dataset):
    """Evaluation dataset yielding spectra, plus a pool of candidate graphs.

    Args:
        file: DataFrame-like table with ``'formula'``, ``'Graph'`` and
            ``'MS2'`` columns.
        list_IDs: Sequence of row labels; dataset position ``idx`` maps to
            row label ``list_IDs[idx]``.
        smiles_reference: Table with a ``'Graph'`` column. Its graphs are
            appended after the query graphs in ``structures``.
        transform: Accepted for API compatibility; currently not applied.

    Attributes:
        valid_formulas: ``file['formula']`` as a list.
        structures: Query graphs from ``file`` followed by reference graphs.
        spectra: ``file['MS2']`` as a list.
        spec2smi: Maps spectrum position ``i`` to ``[i]``. This is the index
            in ``structures`` of that spectrum's own graph, because query
            graphs come first.
    """

    def __init__(self,
                 file,
                 list_IDs,
                 smiles_reference,
                 transform=None):
        self.clr_frame = file
        self.list_IDs = list_IDs
        # NOTE(review): stored but never applied in __getitem__.
        self.transform = transform
        self.valid_formulas = list(self.clr_frame['formula'])
        self.smiles_reference = smiles_reference
        self.structures = (list(self.clr_frame['Graph'])
                           + list(self.smiles_reference['Graph']))
        self.spectra = list(self.clr_frame['MS2'])
        # Spectrum i corresponds to query graph i, which is entry i of
        # ``structures``.
        self.spec2smi = {
            spec_id: [spec_id] for spec_id in range(len(self.spectra))
        }

    def __len__(self):
        # Only positions that map through list_IDs can be served.
        return len(self.list_IDs)

    def __getitem__(self, idx):
        """Return the float32 spectrum tensor for dataset position *idx*."""
        index = self.list_IDs[idx]
        spec = self.clr_frame.loc[index, 'MS2']
        return torch.as_tensor(spec, dtype=torch.float32)