maotao / src /mRNA2vec /LoadData.py
julse's picture
upload AA2CDS
4707555 verified
import transformers
from transformers import DebertaTokenizerFast, T5EncoderModel, T5Config
from typing import Callable, Optional, Union, Tuple, List
from torch.utils import data
import pickle
import torch
import numpy as np
import random
import pandas as pd
def pkl_load(one_path):
    """Read the file at ``one_path`` in binary mode and return the unpickled object.

    NOTE: unpickling can execute arbitrary code, so only call this on files
    that come from a trusted source.
    """
    with open(one_path, 'rb') as stream:
        payload = pickle.load(stream)
    return payload
class DataLoad(data.Dataset):
    """Dataset that tokenizes RNA sequences and applies random token masking.

    Each item is a tuple
    ``(x, mfe_rna, ss_category, attention_mask, y, label_mask)`` where ``x``
    holds the masked token ids, ``y`` the unmasked token ids, and
    ``label_mask`` is 1 at every position of ``x`` that was replaced by the
    mask token.  ``mfe_rna`` and ``ss_category`` are currently always empty
    strings.
    """

    # Slice bounds applied to the list of maskable positions; the "windowed"
    # share of the masks is drawn only from this slice.
    HARD_MASK_WINDOW = (15, 45)
    # Fractions of ``mask_ratio`` spent inside the window / over all positions.
    HARD_MASK_SHARE = 0.8
    RANDOM_MASK_SHARE = 0.2
    # Number of 3-symbol groups emitted by ``_tokenize_ss`` (excluding the two
    # sentinel values).
    SS_NUM_TRIPLETS = 126
    # Symbols from which the structure-triplet vocabulary is built.
    SS_SYMBOLS = ('(', '.', ')')

    def __init__(self, data,
                 tokenizer_path='./tokenizer/',
                 mode='train',
                 max_length=128,
                 mask_ratio=0.15,
                 ):
        """Clean the sequences in ``data`` and load the tokenizer.

        Args:
            data: mapping whose values unpack into exactly five items
                ``(seq, cds_start, cds_stop, cds_len, mRNA_len)``; only
                ``seq`` is used here.
            tokenizer_path: location passed to
                ``DebertaTokenizerFast.from_pretrained``.
            mode: stored on the instance; not otherwise used by this class.
            max_length: padded/truncated token length of every item.
            mask_ratio: overall fraction of maskable tokens to mask.
        """
        self.mode = mode
        # Vocabulary used by ``_tokenize_ss``; without it that method fails
        # on any non-empty structure string.
        self.dic_voc = self._build_ss_vocab()
        self.data = data
        self.data_x = []
        self.data_y = []
        self.data_ss = []
        self.samples = data
        for values in self.samples.values():
            # Keep the strict 5-way unpack so malformed records still fail loudly.
            seq, _cds_start, _cds_stop, _cds_len, _mrna_len = values
            self.data_x.append(self._clean_sequence(seq))
            self.data_y.append("")
            self.data_ss.append("")
        self.tokenizer = DebertaTokenizerFast.from_pretrained(tokenizer_path, use_fast=True)
        self.keys = list(self.samples.keys())
        self.mask_token_id = self.tokenizer.mask_token_id
        self.pad_token_id = self.tokenizer.pad_token_id
        self.mask_ratio = mask_ratio
        self.mask_fn = self.normal_mask
        self.max_length = max_length
        self.tokenizer.padding_side = "left"
        self.tokenize_ss = self._tokenize_ss

    @staticmethod
    def _clean_sequence(seq):
        """Return ``seq`` with ``T`` replaced by ``U`` and ``_`` / ``<`` removed."""
        return seq.replace("T", "U").replace("_", "").replace("<", "")

    @classmethod
    def _build_ss_vocab(cls):
        """Map ``'PAD'`` to 0 and each 3-symbol combination of ``SS_SYMBOLS`` to 1..27."""
        symbols = cls.SS_SYMBOLS
        vocab = ['PAD']
        vocab.extend(f'{a1}{a2}{a3}'
                     for a1 in symbols
                     for a2 in symbols
                     for a3 in symbols)
        return dict(zip(vocab, range(len(vocab))))

    def __getitem__(self, index):
        """Tokenize sequence ``index`` and return its masked training tuple."""
        mfe_rna = ""
        ss_category = ""
        encoded = self.tokenizer(self.data_x[index],
                                 padding='max_length',
                                 max_length=self.max_length,
                                 truncation=True,
                                 add_special_tokens=True,
                                 return_tensors="pt",
                                 )
        # Copy before masking: the mask function edits the id buffer in place.
        y = encoded.input_ids[0].numpy().copy()
        mask_rna, label_mask = self.mask_fn(encoded)
        x = mask_rna.input_ids[0].numpy()
        attention_mask = mask_rna.attention_mask[0].numpy()
        return x, mfe_rna, ss_category, attention_mask, y, label_mask

    def __len__(self):
        """Return the number of sequences."""
        return len(self.data_x)

    def _tokenize_ss(self, ss_rna):
        """Encode a structure string as a tensor of triplet category ids.

        The string is consumed right-to-left in groups of three symbols, each
        looked up in ``self.dic_voc``.  Exactly ``SS_NUM_TRIPLETS`` groups are
        emitted; missing groups on the left are filled with 0.  The result is
        framed by ``-2`` at the front and ``-1`` at the end.

        Raises:
            KeyError: if a 3-symbol group is not in ``self.dic_voc``.
        """
        # Build back-to-front with O(1) appends, then reverse once.
        reversed_ids = [-1]
        total = len(ss_rna)
        for offset in range(0, 3 * self.SS_NUM_TRIPLETS, 3):
            if offset + 3 > total:
                reversed_ids.append(0)
            else:
                stop = total - offset
                triplet = ''.join(ss_rna[stop - 3:stop])
                reversed_ids.append(self.dic_voc[triplet])
        reversed_ids.append(-2)
        reversed_ids.reverse()
        return torch.tensor(reversed_ids)

    def normal_mask(self, one_ab):
        """Mask a random subset of the attended tokens of ``one_ab``.

        Maskable positions are all positions with ``attention_mask == 1``
        except the first one (presumably the leading special token — TODO
        confirm against the tokenizer).  About ``mask_ratio * 0.8`` of them
        are drawn from the ``HARD_MASK_WINDOW`` slice and about
        ``mask_ratio * 0.2`` from all maskable positions.

        Both sample sizes are clamped to the size of the pool they are drawn
        from, so short sequences or large ``max_length`` values no longer
        raise ``ValueError``.  The returned label mask covers every masked
        position (both draws), not just the second draw.

        Returns:
            ``(one_ab, label_mask)``: the encoding with ``input_ids`` replaced
            by the masked ids, and a float tensor that is 1 at masked
            positions and 0 elsewhere.
        """
        input_ids = one_ab.input_ids.numpy()
        attended = np.where(one_ab.attention_mask[0].numpy() == 1)[0].tolist()
        candidates = attended[1:]

        window_start, window_stop = self.HARD_MASK_WINDOW
        window = candidates[window_start:window_stop]
        n_window = min(int(self.mask_ratio * self.HARD_MASK_SHARE * len(candidates)),
                       len(window))
        n_random = min(int(self.mask_ratio * self.RANDOM_MASK_SHARE * len(candidates)),
                       len(candidates))

        chosen = random.sample(window, n_window)
        chosen.extend(random.sample(candidates, n_random))
        # Explicit integer dtype keeps empty selections valid as an index.
        masked_positions = np.asarray(chosen, dtype=np.intp)

        input_ids[0][masked_positions] = self.mask_token_id
        mask_lst = np.zeros(input_ids[0].shape)
        mask_lst[masked_positions] = 1
        one_ab.input_ids = torch.tensor(input_ids)
        return one_ab, torch.tensor(mask_lst)
class DataLoad_downstream(data.Dataset):
    """Dataset of ``(token ids, attention mask, label)`` read from a CSV file.

    The CSV must provide the columns ``Sequence``, ``Value`` and ``Split``.
    Rows whose ``Split`` is exactly ``'test'`` form the test split; every
    other row is treated as ``'train'``.
    """

    # Rows whose raw ``Sequence`` string is longer than this are dropped.
    MAX_SEQUENCE_CHARS = 1024

    def __init__(self, data_path: str,
                 tokenizer_path: str,
                 mode='train',
                 max_length=32,
                 ):
        """Load the rows of the requested split and the tokenizer.

        Args:
            data_path: CSV file handed to ``pandas.read_csv``.
            tokenizer_path: location passed to
                ``DebertaTokenizerFast.from_pretrained``.
            mode: split to keep; rows are labelled ``'test'`` or ``'train'``,
                so any other value yields an empty dataset.
            max_length: padded/truncated token length of every item.
        """
        self.mode = mode
        self.max_length = max_length
        self.data = self._read_split(data_path, mode)
        print(len(self.data))
        self.tokenizer = DebertaTokenizerFast.from_pretrained(tokenizer_path, use_fast=True)
        self.mask_token_id = self.tokenizer.mask_token_id
        self.pad_token_id = self.tokenizer.pad_token_id
        self.tokenizer.padding_side = "left"

    @classmethod
    def _read_split(cls, data_path, mode):
        """Return ``[sequence, value]`` pairs of the rows belonging to ``mode``.

        The length filter is applied to the raw ``Sequence`` text; afterwards
        ``'<pad>'`` is removed and ``T`` is replaced by ``U``.
        """
        table = pd.read_csv(data_path)
        table = table[table['Sequence'].str.len() <= cls.MAX_SEQUENCE_CHARS]
        sequences = [raw.replace('<pad>', '').replace('T', 'U')
                     for raw in table['Sequence'].values.tolist()]
        values = table['Value'].values.tolist()
        splits = ['test' if split == 'test' else 'train'
                  for split in table['Split'].values.tolist()]
        return [[seq, val]
                for seq, val, split in zip(sequences, values, splits)
                if split == mode]

    def __getitem__(self, index):
        """Return ``(input_ids, attention_mask, label)`` for row ``index``.

        ``label`` is a zero-dimensional ``float32`` NumPy array.
        """
        seq, label = self.data[index]
        label = np.array(label, dtype=np.float32)
        encoded = self.tokenizer(seq,
                                 padding='max_length',
                                 max_length=self.max_length,
                                 truncation=True,
                                 add_special_tokens=True,
                                 return_tensors="pt",
                                 )
        x = encoded.input_ids[0].numpy()
        attention_mask = encoded.attention_mask[0].numpy()
        return x, attention_mask, label

    def __len__(self):
        """Return the number of rows in the selected split."""
        return len(self.data)