# CSU-MS2-T2 / dataloader/dataset.py
# Uploaded by Tingxie ("Upload 2 files"), commit 603d88b
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 27 10:43:40 2022
@author: ZNDX002
"""
import numpy as np
import torch
from torch.utils.data import Dataset
import ast
class ClrDataset(Dataset):
    """Contrastive Learning Representations Dataset.

    Pairs a molecular graph with its MS/MS spectrum, one pair per row label
    in ``list_IDs``.

    Args:
        file: DataFrame-like table accessed via ``.loc[label, column]`` with a
            ``'Graph'`` column and an ``'MS2'`` column. Each ``'MS2'`` entry
            must expose ``.mz`` and ``.intensities`` (presumably a matchms
            ``Spectrum`` -- TODO confirm against the data pipeline).
        list_IDs: Sequence of row labels of ``file`` that make up this
            dataset; dataset position ``idx`` maps to label ``list_IDs[idx]``.
        transform: Accepted for interface compatibility; currently unused.
    """

    def __init__(self,
                 file,
                 list_IDs,
                 transform=None):
        self.clr_frame = file
        self.list_IDs = list_IDs

    def __len__(self):
        # The dataset is defined by list_IDs (e.g. a train/val split of the
        # frame). Using len(self.clr_frame) would let a sampler request
        # positions past the end of list_IDs.
        return len(self.list_IDs)

    def __getitem__(self, idx):
        """Return ``(graph, mz, intensities, num_peak)`` for position ``idx``.

        ``mz`` is rounded to 4 decimal places; ``num_peak`` is ``len(mz)``.
        ``intensities`` is returned exactly as stored on the spectrum.
        """
        index = self.list_IDs[idx]
        v_d = self.clr_frame.loc[index, 'Graph']
        # Fetch the spectrum by the same row label as the graph so the pair
        # always comes from a single row (previously the positional idx was
        # used here, which misaligns pairs for shuffled/subset list_IDs).
        spec = self.clr_frame.loc[index, 'MS2']
        spec_mz = np.around(spec.mz, decimals=4)
        spec_intens = spec.intensities
        num_peak = len(spec_mz)
        return v_d, spec_mz, spec_intens, num_peak
class re_train_dataset(Dataset):
    """Dataset of ``(graph, spectrum_tensor)`` pairs for retrieval training.

    Args:
        file: DataFrame-like table accessed via ``.loc[label, column]`` with
            ``'Graph'`` and ``'MS2'`` columns. Each ``'MS2'`` entry must be a
            NumPy array, since it is passed to ``torch.from_numpy``.
        list_IDs: Sequence of row labels of ``file`` that make up this
            dataset; dataset position ``idx`` maps to label ``list_IDs[idx]``.
        transform: Accepted for interface compatibility; currently unused.
    """

    def __init__(self,
                 file,
                 list_IDs,
                 transform=None):
        self.clr_frame = file
        self.list_IDs = list_IDs

    def __len__(self):
        # Size is the number of selected row labels, not the whole frame,
        # so a split passed via list_IDs never yields out-of-range positions.
        return len(self.list_IDs)

    def __getitem__(self, idx):
        """Return ``(graph, spectrum)`` for position ``idx``.

        ``spectrum`` is the stored ``'MS2'`` array converted to a
        ``torch.float32`` tensor.
        """
        index = self.list_IDs[idx]
        v_d = self.clr_frame.loc[index, 'Graph']
        spec = self.clr_frame.loc[index, 'MS2']
        spec = torch.from_numpy(spec).to(torch.float32)
        return v_d, spec
class re_eval_dataset(Dataset):
    """Spectrum-only dataset for retrieval evaluation.

    Besides serving spectra, it builds the candidate structure list and the
    spectrum -> structure ground-truth mapping used for retrieval.

    Args:
        file: DataFrame-like table with ``'formula'``, ``'Graph'`` and
            ``'MS2'`` columns. Each ``'MS2'`` entry must be a NumPy array,
            since it is passed to ``torch.from_numpy``.
        list_IDs: Sequence of row labels of ``file`` that make up this
            dataset; dataset position ``idx`` maps to label ``list_IDs[idx]``.
            ``spec2smi`` is keyed by frame row order, so for aligned
            evaluation this presumably should be ``file.index`` in order --
            TODO confirm with callers.
        smiles_reference: DataFrame-like table with a ``'Graph'`` column
            holding additional candidate structures.
        transform: Accepted for interface compatibility; currently unused.

    Attributes:
        structures: The frame's own graphs (positions ``0..len(file)-1``)
            followed by the reference graphs.
        spec2smi: Maps spectrum position ``i`` (frame row order) to the list
            of matching structure positions, here always ``[i]``.
    """

    def __init__(self,
                 file,
                 list_IDs,
                 smiles_reference,
                 transform=None):
        self.clr_frame = file
        self.list_IDs = list_IDs
        self.valid_formulas = list(self.clr_frame['formula'])
        self.smiles_reference = smiles_reference
        # The frame's own graphs come first, so the structure at position i
        # is the one annotating spectrum i.
        self.structures = (list(self.clr_frame['Graph'])
                           + list(self.smiles_reference['Graph']))
        self.spectra = list(self.clr_frame['MS2'])
        self.spec2smi = {spec_id: [spec_id]
                         for spec_id in range(len(self.spectra))}

    def __len__(self):
        # Size is the number of selected row labels, not the whole frame.
        return len(self.list_IDs)

    def __getitem__(self, idx):
        """Return the ``'MS2'`` array at position ``idx`` as a float32 tensor."""
        index = self.list_IDs[idx]
        spec = self.clr_frame.loc[index, 'MS2']
        return torch.from_numpy(spec).to(torch.float32)