import pickle
import random
from typing import Callable, Optional, Union, Tuple, List

import numpy as np
import pandas as pd
import torch
from torch.utils import data

import transformers
from transformers import DebertaTokenizerFast, T5EncoderModel, T5Config


def pkl_load(one_path):
    """Load and return a pickled object from one_path."""
    with open(one_path, 'rb') as f:
        return pickle.load(f)


class DataLoad(data.Dataset):
    """Pre-training dataset: cleans mRNA sequences and applies random masking for MLM."""

    def __init__(self, data,
                 tokenizer_path='./tokenizer/',
                 mode='train',
                 max_length=128,
                 mask_ratio=0.15,
                 ):
        self.mode = mode
        self.data = data
        self.data_x = []
        self.data_y = []
        self.data_ss = []

        self.samples = data
        for key, values in self.samples.items():
            seq, cds_start, cds_stop, cds_len, mRNA_len = values
            # Convert to the RNA alphabet and strip placeholder characters.
            seq = seq.replace("T", "U")
            seq = seq.replace("_", "")
            seq = seq.replace("<", "")
            self.data_x.append(seq)
            # Labels and secondary-structure strings are kept as empty placeholders here.
            self.data_y.append("")
            self.data_ss.append("")

        self.tokenizer = DebertaTokenizerFast.from_pretrained(tokenizer_path, use_fast=True)
        self.keys = list(self.samples.keys())
        self.mask_token_id = self.tokenizer.mask_token_id
        self.pad_token_id = self.tokenizer.pad_token_id
        self.mask_ratio = mask_ratio

        self.mask_fn = self.normal_mask
        self.max_length = max_length
        self.tokenizer.padding_side = "left"
        self.tokenize_ss = self._tokenize_ss
    def __getitem__(self, index):
        seq_rna = self.data_x[index]
        mfe_rna = ""
        ss_category = ""
        # Tokenize with left padding up to max_length.
        seq_rna = self.tokenizer(seq_rna,
                                 padding='max_length',
                                 max_length=self.max_length,
                                 truncation=True,
                                 add_special_tokens=True,
                                 return_tensors="pt",
                                 )

        # Unmasked token ids are the MLM labels; label_mask marks the masked positions.
        y = seq_rna.input_ids[0].numpy().copy()
        mask_rna, label_mask = self.mask_fn(seq_rna)
        x = mask_rna.input_ids[0].numpy()
        attention_mask = mask_rna.attention_mask[0].numpy()
        return x, mfe_rna, ss_category, attention_mask, y, label_mask

    def __len__(self):
        return len(self.data_x)
    def _tokenize_ss(self, ss_rna):
        # Encode a secondary-structure string into 126 category ids (max_length 128
        # minus two sentinel positions), reading 3-character chunks from the right.
        # Positions beyond the string are padded with 0; -2 and -1 mark the two ends.
        # Assumes self.dic_voc maps each 3-character structure chunk to an integer id;
        # it is not defined in this file and must be provided before calling this method.
        ss_categories = [-1]
        for i in range(0, 3 * 126, 3):
            if i + 3 > len(ss_rna):
                ss_categories.insert(0, 0)
            else:
                a1 = ss_rna[-1 - i]
                a2 = ss_rna[-2 - i]
                a3 = ss_rna[-3 - i]
                category = self.dic_voc[f'{a3}{a2}{a1}']
                ss_categories.insert(0, category)
        ss_categories.insert(0, -2)
        return torch.tensor(ss_categories)
    def normal_mask(self, one_ab):
        # Randomly mask tokens for MLM: about 80% of the masking budget is drawn from
        # an early window of non-padding positions (index_lst[15:45]) and the remaining
        # 20% from any non-padding position. Returns the masked encoding and a 0/1
        # label mask covering every masked position.
        input_ids = one_ab.input_ids.numpy()
        attention_mask_ids = one_ab.attention_mask[0].numpy()
        index_lst = np.where(attention_mask_ids == 1)[0].tolist()
        index_lst.pop(0)  # never mask the first special token

        mask_lst = np.zeros(input_ids[0].shape)

        # Cap the sample size at the pool size to avoid errors on short sequences.
        pool = index_lst[15:45]
        mask_index = random.sample(pool, min(len(pool), int(self.mask_ratio * 0.8 * len(index_lst))))
        input_ids[0][mask_index] = self.mask_token_id
        mask_lst[mask_index] = 1

        mask_index = random.sample(index_lst, int(self.mask_ratio * 0.2 * len(index_lst)))
        input_ids[0][mask_index] = self.mask_token_id
        mask_lst[mask_index] = 1  # record both masking rounds in the label mask

        one_ab.input_ids = torch.tensor(input_ids)
        return one_ab, torch.tensor(mask_lst)
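
# Usage sketch (illustrative only): the sample dict and DataLoader settings below are
# assumptions, and './tokenizer/' must contain a trained DebertaTokenizerFast vocabulary.
#
#     samples = {"tx1": ("ATGGCCGATTTAGCGTAA", 0, 18, 18, 18)}   # hypothetical record
#     dataset = DataLoad(samples, tokenizer_path='./tokenizer/', max_length=128)
#     loader = data.DataLoader(dataset, batch_size=4, shuffle=True)
#     x, mfe, ss, attention_mask, y, label_mask = next(iter(loader))
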
class DataLoad_downstream(data.Dataset):
    """Downstream dataset: reads a CSV of sequences and regression targets for fine-tuning."""

    def __init__(self, data_path: str,
                 tokenizer_path: str,
                 mode='train',
                 max_length=32,
                 ):
        self.mode = mode
        self.max_length = max_length
        te_data = pd.read_csv(data_path)

        # Keep sequences of at most 1024 nt, strip pad tokens, and convert to the RNA alphabet.
        te_data = te_data[te_data['Sequence'].str.len() <= 1024]
        seqs = te_data['Sequence'].values.tolist()
        seqs = [x.replace('<pad>', '') for x in seqs]
        seqs = [x.replace('T', 'U') for x in seqs]
        value = te_data['Value'].values.tolist()
        states = te_data['Split'].values.tolist()
        # Any split other than 'test' is treated as 'train'.
        states = [x if x == 'test' else 'train' for x in states]
        self.data = [[x, y] for x, y, z in zip(seqs, value, states) if z == mode]
        print(len(self.data))

        self.tokenizer = DebertaTokenizerFast.from_pretrained(tokenizer_path, use_fast=True)
        self.mask_token_id = self.tokenizer.mask_token_id
        self.pad_token_id = self.tokenizer.pad_token_id
        self.tokenizer.padding_side = "left"
    def __getitem__(self, index):
        seq, label = self.data[index]
        label = np.array(label, dtype=np.float32)
        seq = self.tokenizer(seq,
                             padding='max_length',
                             max_length=self.max_length,
                             truncation=True,
                             add_special_tokens=True,
                             return_tensors="pt",
                             )

        x = seq.input_ids[0].numpy()
        attention_mask = seq.attention_mask[0].numpy()

        return x, attention_mask, label

    def __len__(self):
        return len(self.data)
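

# Minimal smoke test, assuming a CSV with 'Sequence', 'Value', and 'Split' columns;
# the data and tokenizer paths below are placeholders, not paths shipped with this code.
if __name__ == "__main__":
    dataset = DataLoad_downstream(data_path='./data/te_data.csv',   # hypothetical path
                                  tokenizer_path='./tokenizer/',
                                  mode='train',
                                  max_length=32)
    loader = data.DataLoader(dataset, batch_size=4, shuffle=True)
    x, attention_mask, label = next(iter(loader))
    print(x.shape, attention_mask.shape, label.shape)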