"""
Title     : tokenizer.py
Project   : minimind_RiboUTR
Created by: julse
Created on: 2025/2/12 16:40
des       : tokenizer and Dataset pipelines for mRNA sequences and ribosome-profiling data
"""
from typing import List

import argparse
import os
import pickle
import random
import re
import time
from collections import defaultdict
from itertools import chain
from random import shuffle

import math
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

import transformers
from copy import copy, deepcopy

from model.codon_attr import Codon

os.chdir('../../')

import sys

from utils.ernie_rna.dictionary import Dictionary
from utils.ernie_rna.position_prob_mask import calculate_mask_prob
from transformers import DebertaTokenizerFast
from model.codon_tables import CODON_TO_AA, AA_str, AA_TO_CODONS, reverse_dictionary, create_codon_mask

# NOTE: creatmat(...) is called below to build 2-D pairing-score matrices; it is
# assumed to be available at runtime (e.g. provided by the ERNIE-RNA utilities).

base_range_lst = [1]
lamda_lst = [0.8]
class BaseDataset(Dataset):
    """Common base class containing shared attributes and methods."""

    def __init__(
        self,
        tokenizer,
        region: int = 300,
        limit: int = -1,
        return_masked_tokens: bool = False,
        seed: int = 1,
        mask_prob: float = 0.15,
        leave_unmasked_prob: float = 0.1,
        random_token_prob: float = 0.1,
        freq_weighted_replacement: bool = False,
        two_dim_score: bool = False,
        two_dim_mask: int = -1,
        mask_whole_words: torch.Tensor = None,
    ):
        assert 0.0 < mask_prob < 1.0
        assert 0.0 <= random_token_prob <= 1.0
        assert 0.0 <= leave_unmasked_prob <= 1.0
        assert random_token_prob + leave_unmasked_prob <= 1.0

        self.tokenizer = tokenizer
        self.pad_idx = tokenizer.pad_index
        self.mask_idx = tokenizer.mask_index
        self.return_masked_tokens = return_masked_tokens
        self.seed = seed
        self.mask_prob = mask_prob
        self.leave_unmasked_prob = leave_unmasked_prob
        self.random_token_prob = random_token_prob
        self.two_dim_score = two_dim_score
        self.two_dim_mask = two_dim_mask
        self.mask_whole_words = mask_whole_words
        self.region = region
        self.limit = limit

        if random_token_prob > 0.0:
            weights = np.array(tokenizer.count) if freq_weighted_replacement else np.ones(len(tokenizer))
            weights[: tokenizer.nspecial] = 0
            self.weights = weights / weights.sum()

        # treat DNA 'T' as RNA 'U' so both alphabets map to the same token index
        self.tokenizer.indices['T'] = self.tokenizer.indices['U']
        # map each amino-acid token index to the token-index triplets of its synonymous codons
        self.amino_acid_to_codons = {}
        for aa, codons in AA_TO_CODONS.items():
            codons_num = []
            for codon in codons:
                codon_num = [self.tokenizer.indices[base] for base in codon]
                codons_num.append(codon_num)
            self.amino_acid_to_codons[self.tokenizer.indices[aa.lower()]] = codons_num
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
    @staticmethod
    def translate(nucleotide_seq, repeat=3):
        """Translate a nucleotide sequence codon by codon; each amino acid is
        repeated `repeat` times so the protein string can stay base-aligned."""
        amino_acid_list = []
        for i in range(0, len(nucleotide_seq), 3):
            codon = nucleotide_seq[i:i + 3]
            amino_acid_list.append(CODON_TO_AA.get(codon, '-') * repeat)
        amino_acid_seq = ''.join(amino_acid_list)
        return amino_acid_seq
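    # Example (sketch; assumes CODON_TO_AA maps DNA codons, e.g. 'ATG'->'M', 'GCC'->'A'):
    #   translate('ATGGCC')            -> 'MMMAAA'  (each residue repeated 3x, one per base)
    #   translate('ATGGCC', repeat=1)  -> 'MA'
    #   translate('ATGGC')             -> 'MMM---'  (incomplete/unknown codons become '-')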
    @staticmethod
    def prepare_input_for_ernierna(index, seq_len):
        if index.ndim == 2:
            index = np.squeeze(index)
        shorten_index = index[:seq_len]
        one_d = torch.from_numpy(shorten_index).long().reshape(1, -1)
        two_d = np.zeros((1, seq_len, seq_len))
        two_d[0, :, :] = creatmat(shorten_index.astype(int), base_range=1, lamda=0.8)

        two_d = two_d.transpose(1, 2, 0)
        two_d = torch.from_numpy(two_d).reshape(1, seq_len, seq_len, 1)
        return one_d, two_d
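    # Shape sketch: given an index array with at least seq_len entries, this returns
    #   one_d: LongTensor of shape (1, seq_len)          -- token indices
    #   two_d: tensor of shape (1, seq_len, seq_len, 1)  -- creatmat pairing scores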
    def generate_inputs(self, x):
        region = self.region

        utr5 = x["UTR5"]
        utr3 = x["UTR3"]
        cds = x["CDS"]
        seq = utr5 + cds + utr3
        cds_start = len(utr5)
        cds_stop = len(utr5) + len(cds)

        utr5_limit = 300 if region > 300 else region

        seq = self.process_sequence(seq, cds_start, cds_stop, region, '_', '<', '>', 'N', utr5_limit)
        return seq
    def process_sequence(self, seq, cds_start, cds_stop, region, pad_mark, bos, eos, link, utr5_limit):
        utr5 = seq[:cds_start]
        cds = seq[cds_start:cds_stop]
        utr3 = seq[cds_stop:]

        utr5 = self.process_utr(utr5, utr5_limit, 'pre', pad_mark=pad_mark, bos=bos, eos=eos)
        cds_h = self.process_utr(cds, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)
        cds_t = self.process_utr(cds, region, 'pre', pad_mark=pad_mark, bos=bos, eos=eos)
        utr3 = self.process_utr(utr3, region, 'behind', pad_mark=pad_mark, bos=bos, eos=eos)
        seq = utr5 + cds_h + cds_t + utr3
        seq = seq[:utr5_limit + region + 1] + link * 3 + seq[-region * 2 - 1:]

        if isinstance(seq, list):
            seq = np.array(seq)
        return seq
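    # Window layout sketch (for utr5_limit == region): the concatenated result is
    #   [padded 5'UTR + '<' : region+1] [CDS head : region] ['NNN' link]
    #   [CDS tail : region] ['>' + padded 3'UTR : region+1]
    # so the start codon begins at index region+1 and the stop boundary sits at
    # region*3+4 -- the same target_positions used later in generate_mask.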
    @staticmethod
    def process_utr(utr, input_len, pad_method, pad_mark='_', bos='<', eos='>'):
        if len(utr) < input_len:
            if pad_method == 'pre':
                padded_utr = pad_mark * (input_len - len(utr)) + bos + utr
            elif pad_method == 'behind':
                padded_utr = utr + eos + pad_mark * (input_len - len(utr))
            else:
                raise ValueError(f"unknown pad_method: {pad_method}")
        else:
            if pad_method == 'pre':
                padded_utr = bos + utr[-input_len:]
            elif pad_method == 'behind':
                padded_utr = utr[:input_len] + eos
            else:
                raise ValueError(f"unknown pad_method: {pad_method}")
        return padded_utr
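    # Examples (sketch): the output always has input_len + 1 characters.
    #   process_utr('AUGGCC', 10, 'pre')      -> '____<AUGGCC'  (left pad, BOS marker)
    #   process_utr('AUGGCC', 10, 'behind')   -> 'AUGGCC>____'  (EOS marker, right pad)
    #   process_utr('AUGGCCAUGGCC', 6, 'pre') -> '<AUGGCC'      (keeps the rightmost 6 bases)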
    @staticmethod
    def seq_to_rnaindex(seq, pad_idx=1, unk_idx=3):
        l = len(seq)
        X = np.ones((1, l))
        for j in range(l):
            if seq[j] in set('Aa'):
                X[0, j] = 5
            elif seq[j] in set('UuTt'):
                X[0, j] = 6
            elif seq[j] in set('Cc'):
                X[0, j] = 7
            elif seq[j] in set('Gg'):
                X[0, j] = 4
            elif seq[j] in set('_'):
                X[0, j] = pad_idx
            elif seq[j] in set('<'):
                X[0, j] = 0
            elif seq[j] in set('>'):
                X[0, j] = 2
            else:
                X[0, j] = unk_idx

        return X
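    # Index sketch (ERNIE-RNA-style vocabulary): '<'->0 (bos), '>'->2 (eos),
    # '_'->pad_idx, G->4, A->5, U/T->6, C->7, anything else -> unk_idx, e.g.
    #   seq_to_rnaindex('<AUG_>') -> [[0., 5., 6., 4., 1., 2.]]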
    def generate_mask(self, X, seq_len, mask=None, input_mask=True):
        """
        :param X: sequence as numeric indices
        :param seq_len: sequence length
        :param mask: 1-D boolean mask array; if None, generated by dual-center masking
        :param input_mask: if True, also mask the 2-D input
        :return: src_data, tgt_data, twod_data, mask
        """
        one_d, twod_d = self.prepare_input_for_ernierna(X, seq_len)

        '''generate src_data, tgt_data, twod_data'''
        item = one_d.view(-1)

        assert (
            self.mask_idx not in item
        ), "Dataset contains mask_idx (={}), this is not expected!".format(
            self.mask_idx,
        )

        if self.mask_whole_words is not None:
            word_begins_mask = self.mask_whole_words.gather(0, item)
            word_begins_idx = word_begins_mask.nonzero().view(-1)
            sz = len(word_begins_idx)
            words = np.split(word_begins_mask, word_begins_idx)[1:]
            assert len(words) == sz
            word_lens = list(map(len, words))

        sz = len(item)

        if mask is None:
            mask = np.full(sz, False)

            non_pad_indices = np.where(
                (item != self.tokenizer.pad_index) &
                (item != self.tokenizer.unk_index)
            )[0]

            num_non_pad = len(non_pad_indices)
            num_mask = int(
                self.mask_prob * num_non_pad + np.random.rand()
            )

            # dual-center masking: positions near the start/stop codon anchors
            # are sampled with higher probability (Gaussian weighting, sigma=90)
            target_positions = [self.region + 1, self.region * 3 + 4]
            sigma = 90
            probabilities = np.array([calculate_mask_prob(i, target_positions, sigma) for i in range(sz)])
            non_pad_probabilities = probabilities[non_pad_indices]
            non_pad_probabilities = non_pad_probabilities / non_pad_probabilities.sum()
            if num_non_pad >= 1:
                mask[non_pad_indices[np.random.choice(num_non_pad, num_mask, replace=False, p=non_pad_probabilities)]] = True
            # keep the start codon, the stop codon and (for region=300) the 'NNN'
            # link between the CDS head and tail unmasked
            mask[target_positions[0]:target_positions[0] + 3] = False
            mask[target_positions[1] - 3:target_positions[1]] = False
            mask[target_positions[0] + 300:target_positions[0] + 303] = False

        # decide which selected positions stay unchanged or receive a random token
        rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
        if rand_or_unmask_prob > 0.0:
            rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
            if self.random_token_prob == 0.0:
                unmask = rand_or_unmask
                rand_mask = None
            elif self.leave_unmasked_prob == 0.0:
                unmask = None
                rand_mask = rand_or_unmask
            else:
                unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
                decision = np.random.rand(sz) < unmask_prob
                unmask = rand_or_unmask & decision
                rand_mask = rand_or_unmask & (~decision)
        else:
            unmask = rand_mask = None

        if unmask is not None:
            mask = mask ^ unmask

        if input_mask:
            twod_data = self.get_twod_data(item, twod_d.detach(), mask)
        else:
            twod_data = self.get_twod_data(item, twod_d.detach(), np.zeros_like(mask))

        tgt_data = item

        if self.mask_whole_words is not None:
            mask = np.repeat(mask, word_lens)

        new_item = np.copy(item)
        new_item[mask] = self.mask_idx
        if rand_mask is not None:
            # replace a subset of the selected tokens with random (non-special)
            # tokens drawn from the frequency weights prepared in __init__
            num_rand = rand_mask.sum()
            if num_rand > 0:
                new_item[rand_mask] = np.random.choice(
                    len(self.tokenizer),
                    num_rand,
                    p=self.weights,
                )

        src_data = torch.from_numpy(new_item)

        return src_data, tgt_data, twod_data, mask
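    # Corruption-split sketch: with mask_prob=0.15, leave_unmasked_prob=0.1 and
    # random_token_prob=0.1, ~15% of non-pad tokens are selected around the two
    # codon centers; of those, ~80% become <mask>, ~10% keep their original token
    # and ~10% are replaced by a frequency-weighted random token (BERT-style 80/10/10).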
    def get_twod_data(self, item, twod_d, mask):
        two_dim_matrix = torch.squeeze(twod_d, dim=-1).numpy()

        padding_dim = 0
        for base_range in base_range_lst:
            for lamda in lamda_lst:
                new_matrix = creatmat(item.numpy(), base_range, lamda)
                # hide rows/columns of masked positions from the 2-D input
                new_matrix[mask == 1, :] = -1
                new_matrix[:, mask == 1] = -1
                two_dim_matrix[padding_dim, :, :] = new_matrix
                padding_dim += 1

        twod_data = torch.from_numpy(two_dim_matrix)
        return twod_data
    @staticmethod
    def read_text_file(file_path):
        try:
            with open(file_path, 'r') as file:
                return [line.strip() for line in file]
        except FileNotFoundError:
            print(f"Error: File '{file_path}' not found.")
            return []

    @staticmethod
    def create_base_prob(target_protein, ith_nn_prob, rna_alphabet, tokenizer):
        mask_nn_logits = torch.full(size=(len(target_protein) * 3, len(tokenizer)), fill_value=float("-inf"))
        for i, a in enumerate(target_protein):
            if a not in ith_nn_prob[0]:
                continue
            for j in range(3):
                for n in rna_alphabet:
                    mask_nn_logits[i * 3 + j, tokenizer.index(n)] = ith_nn_prob[j][a][n]
        return mask_nn_logits
    @staticmethod
    def create_codon_mask(target_protein, backbone_cds, amino_acid_to_codons,
                          tokenizer):
        seq_length = len(backbone_cds)
        vocab_size = len(tokenizer)

        mask = torch.full(size=(seq_length, vocab_size), fill_value=float("-inf"))
        for i, amino_acid in enumerate(target_protein):
            codon_start = i * 3
            codon_end = codon_start + 3

            if codon_end > seq_length:
                continue

            possible_codons = amino_acid_to_codons.get(amino_acid.item(), [])

            for pos in range(codon_start, codon_end):
                base_pos = pos % 3
                for codon in possible_codons:
                    # a codon is compatible if it matches every already-fixed
                    # (non-mask) base of the backbone at this codon position
                    flag = True
                    for j, nt in enumerate(backbone_cds[codon_start:codon_end]):
                        nt = nt.item()
                        if tokenizer.mask_index == nt:
                            continue
                        if codon[j] != nt:
                            flag = False

                    if flag:
                        base_idx = codon[base_pos]
                        mask[pos, base_idx] = 0

        return mask
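    # Sketch: the returned (seq_length, vocab_size) matrix is an additive logit mask,
    # 0 where a base is reachable through some synonymous codon consistent with the
    # already-fixed backbone bases, and -inf elsewhere; adding it to model logits
    # restricts sampling to bases that preserve the encoded protein.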
    def load_data(self, path, **kwargs):
        raise NotImplementedError("Subclasses must implement load_data")

    def __getitem__(self, idx):
        raise NotImplementedError("Subclasses must implement __getitem__")

class RNADataset(BaseDataset):
    """Dataset for processing RNA sequences."""

    def __init__(
        self,
        path,
        tokenizer,
        region: int = 300,
        limit: int = -1,
        return_masked_tokens: bool = False,
        seed: int = 1,
        mask_prob: float = 0.15,
        leave_unmasked_prob: float = 0.1,
        random_token_prob: float = 0.1,
        freq_weighted_replacement: bool = False,
        two_dim_score: bool = False,
        two_dim_mask: int = -1,
        mask_whole_words: torch.Tensor = None,
    ):
        super().__init__(
            tokenizer=tokenizer,
            region=region,
            limit=limit,
            return_masked_tokens=return_masked_tokens,
            seed=seed,
            mask_prob=mask_prob,
            leave_unmasked_prob=leave_unmasked_prob,
            random_token_prob=random_token_prob,
            freq_weighted_replacement=freq_weighted_replacement,
            two_dim_score=two_dim_score,
            two_dim_mask=two_dim_mask,
            mask_whole_words=mask_whole_words,
        )

        self.samples = self.load_data(path, region=self.region, limit=limit)

    def load_data(self, path, region=300, limit=-1):
        return self.read_fasta_file(path, region=region, limit=limit)
    @staticmethod
    def read_fasta_file(file_path, region=300, cds_min=100, limit=-1):
        '''
        input:
            file_path: str, path to the FASTA file of input seqs

        return:
            seqs_dicts: list[dict], one dict per sequence, e.g.

            {
                'ENST00000231420.11': {            # transcript identifier
                    'cds_start': 57,               # CDS start position (0-based)
                    'cds_stop': 1599,              # CDS stop position (exclusive, 0-based)
                    'full': 'AGTTAGAGCCCGGCCTCCAATCTGCTTCCATGGGGTTGGCTTTCTGAGTGGGAGAAATGACTCTAATCTGGAGACA...',  # full mRNA sequence
                    'start_context': '___GAAATGTCT',  # context around the CDS start, left-padded with '_', essential
                    'stop_context': 'AAGTAAGGG___'    # context around the CDS stop, right-padded with '_', essential
                }
            }
        '''
        try:
            with open(file_path) as fa:
                seqs_dicts = []
                cds_start = 0
                cds_stop = 0
                count = 0
                seq_name = ''

                for line in fa:
                    line = line.replace('\n', '')
                    if line.startswith('>'):
                        transcript_id, gene_id, cds_start, cds_stop = line[1:].split(' ')
                        cds_start = int(cds_start)
                        cds_stop = int(cds_stop)
                        if cds_stop - cds_start < cds_min:
                            continue
                        seq_name = transcript_id
                    else:
                        expand_mRNA = '_' * region + line + '_' * region
                        cds_start += region
                        cds_stop += region

                        start_context = expand_mRNA[cds_start - region:cds_start + region]
                        stop_context = expand_mRNA[cds_stop - region:cds_stop + region]
                        seqs_dicts.append(
                            {'_id': seq_name, 'start_context': start_context, 'stop_context': stop_context})
                        count += 1
                        if count > limit and limit != -1:
                            break
                return seqs_dicts
        except FileNotFoundError:
            print(f"Error: File '{file_path}' not found.")
            return []
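    # Expected FASTA header layout (sketch, inferred from the parser above; the
    # ids are illustrative):
    #   >ENST00000231420.11 ENSG00000XXXXX 57 1599
    # i.e. '>transcript_id gene_id cds_start cds_stop', with the mRNA sequence on
    # the following single line.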
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        '''
        Index mapping: G/A/U/C -> 4/5/6/7, unk -> 3.
        :param idx:
        :return:
        '''
        sample = self.samples[idx]
        seq = sample['start_context'] + 'NNN' + sample['stop_context']
        X = self.seq_to_rnaindex(seq, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index)
        if '_' in sample['start_context']:
            X[:, sample['start_context'].count('_')] = self.tokenizer.bos_index
        if '_' in sample['stop_context']:
            X[:, -sample['stop_context'].count('_') - 1] = self.tokenizer.eos_index

        '''generate src_data, tgt_data, twod_data'''
        src_data, tgt_data, twod_data, loss_mask = self.generate_mask(X, len(seq))
        return src_data, tgt_data, twod_data, loss_mask
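# Minimal usage sketch (assumes a fairseq-style Dictionary tokenizer exposing
# pad/mask/bos/eos indices; the dictionary and FASTA paths are illustrative):
#   tokenizer = Dictionary.load('utils/ernie_rna/dict.txt')
#   ds = RNADataset('dataset/pretraining/mRNA.fa', tokenizer, region=300)
#   src_data, tgt_data, twod_data, loss_mask = ds[0]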
class RiboDataPipeline:
    """
    Pipeline for the pretraining task: generates mRNA.fa; loads ribosome_density,
    ribo_counts and rna_counts; splits the data into TR, VL and TS.
    Loads from the original bw tracks.
    """

    def __init__(
        self,
        path,
        ribo_experiment, rna_experiment,
        seq_only=False,
        region: int = 300,
        cds_min: int = -1,
        limit: int = -1,
        env: int = 0,
        norm=True
    ):
        self.seq_only = seq_only
        self.cds_min = cds_min
        self.env = env

        self.reference_transcript_dict = {}

        self.samples = self.load_data(path, ribo_experiment=ribo_experiment, rna_experiment=rna_experiment,
                                      region=region, limit=limit, norm=norm)
    # Reuse the sequence-windowing helpers from BaseDataset: this pipeline does
    # not inherit from it, but merge_transcript_level calls them via self.
    process_sequence = BaseDataset.process_sequence
    process_utr = BaseDataset.process_utr

    def load_data(self, path, ribo_experiment=None, rna_experiment=None, region=300, limit=-1, norm=True):
        '''
        Load the data:
        1. Look up species, avg_len, total counts etc. from the meta table, keyed by ribo_experiment.
        2. Check whether the mRNA.fa file exists.
           If not:
               check whether mRNA.tsv exists;
               if not:
                   generate mRNA.gtf from genome.gtf (keep only mRNA-related rows/columns to scale down the gtf)
                   generate mRNA.fa from genome.fa and mRNA.gtf (including start/stop codon positions)
        3. Read the track files and build ribosome_density, ribo_counts, rna_counts.

        {'ENST00000303577.7': 'PCBP1',   # IRES
         'ENST00000309311.7': 'EEF2'}    # cap dependent

        :param path:
        :param region:
        :param limit:
        :return: samples

        path = ./dataset/pretraining/
        '''
        """1. input ribo_experiment, meta"""
        seq_only = self.seq_only
        cds_min = self.cds_min

        reference_path = os.path.join(path, 'reference')
        meta = self.read_meta_file(os.path.join(reference_path, 'experiment_meta.tsv'), ribo_experiment,
                                   rna_experiment, seq_only=seq_only)
        totalNumReads_RNA, totalNumReads_RPF, readsLength_RNA, readsLength_RPF, species = meta

        fribo_track = os.path.join(path, 'track', f'{ribo_experiment}.bw')
        frna_track = os.path.join(path, 'track', f'{rna_experiment}.bw')
        if not seq_only:
            if os.access(fribo_track, os.F_OK) and os.access(frna_track, os.F_OK):
                print(f'load {ribo_experiment} and {rna_experiment} tracks')
            else:
                print(f'Error: {fribo_track} or {frna_track} not found.')
                return None

        """2. check mRNA.fa, .pkl"""
        mrna_fa_path = os.path.join(reference_path, species, 'mRNA.fa')
        if region != -1:
            mrna_fa_path = mrna_fa_path.replace('.fa', f'_{region}.fa')
        mrna_fa_path = mrna_fa_path.replace('.fa', '.pkl')
        self.mrna_region_pkl_path = mrna_fa_path

        if seq_only and os.access(mrna_fa_path, os.F_OK):
            with open(mrna_fa_path, 'rb') as f:
                sample_dict = pickle.load(f)
            limited_sample_dict = {}
            for key in sample_dict.keys():
                if limit != -1:
                    limited_sample_dict[key] = sample_dict[key][:limit]
                else:
                    limited_sample_dict[key] = sample_dict[key]

            return limited_sample_dict
        mrna_tsv_path = os.path.join(reference_path, species, 'mRNA.tsv')
        if not os.access(mrna_tsv_path, os.F_OK):
            genome_gtf_path = os.path.join(reference_path, species, 'genome.gtf')
            genome_fa_path = os.path.join(reference_path, species, 'genome.fa')
            mrna_tsv = self.generate_mRNA_tsv(genome_gtf_path, genome_fa_path, mrna_tsv_path)
        else:
            mrna_tsv = pd.read_table(mrna_tsv_path)

        """3. read track files"""
        print(f'filter limit={limit},region={region}')
        print('load_data in Pipeline, before filter', mrna_tsv.shape)
        reference_transcript_ids = list(self.reference_transcript_dict.keys())
        keeping_transcript_ids = mrna_tsv[mrna_tsv['seqname'].isin(['chr10', 'chr15'])].transcript_id.unique().tolist()

        print('keeping transcript_ids', len(reference_transcript_ids + keeping_transcript_ids))

        if limit != -1:
            other_transcript_ids = mrna_tsv[~mrna_tsv['transcript_id'].isin(
                reference_transcript_ids + keeping_transcript_ids)].transcript_id.unique().tolist()
            shuffle(other_transcript_ids)
            shuffle(keeping_transcript_ids)

            mrna_tsv = mrna_tsv[mrna_tsv.transcript_id.isin(
                reference_transcript_ids + keeping_transcript_ids[:limit] + other_transcript_ids[:limit])]

        print('load_data in Pipeline, after filter', mrna_tsv.shape)

        if not seq_only:
            import pyBigWig
            ribo_bw, rna_bw = [pyBigWig.open(fribo_track), pyBigWig.open(frna_track)]
            print(f'meta of {ribo_experiment} and {rna_experiment} tracks loaded\n{ribo_bw.header()}\n{rna_bw.header()}')

            def iterfunc(x, bw):
                chrom, start, end = x['seqname'], x['start'], x['end']
                if chrom in bw.chroms():
                    return np.array(bw.values(chrom, start - 1, end))
                else:
                    return np.zeros(end - start)

            mrna_tsv['ribo_counts'] = mrna_tsv.apply(lambda x: iterfunc(x, ribo_bw), axis=1)
            mrna_tsv['rna_counts'] = mrna_tsv.apply(lambda x: iterfunc(x, rna_bw), axis=1)

            ribo_bw.close()
            rna_bw.close()
            del ribo_bw, rna_bw

        """split dataset (todo)"""
        if species == 'GRCh38.p14':
            mappingdict = {'chr10': 'VL', 'chr15': 'TS'}
            mrna_tsv['dataset'] = mrna_tsv['seqname'].apply(lambda x: mappingdict[x] if x in mappingdict else 'TR')
        else:
            mrna_tsv['dataset'] = 'TR'
        print(f'{int(mrna_tsv.transcript_id.nunique())} transcripts in mRNA.tsv')
        """4. filter mRNA.tsv by transcript_id"""
        sample_dict = defaultdict(list)
        count = 0
        total_counts_info = [totalNumReads_RNA, totalNumReads_RPF, readsLength_RNA, readsLength_RPF]

        '''ref_norm'''
        ref_norm = None
        if not seq_only:
            ref_norm = []
            for transcript_id, data in mrna_tsv[mrna_tsv.transcript_id.isin(reference_transcript_ids)].groupby('transcript_id'):
                ans = self.merge_transcript_level(data, total_counts_info=total_counts_info, seq_only=seq_only,
                                                  cds_min=cds_min, region=region)
                if ans is None:
                    continue
                (seq,
                 cds_start, cds_stop,
                 ribo_counts,
                 rna_counts,
                 ribosome_density,
                 te, self.env, cds_len, mRNA_len, junction_counts) = ans
                ref_norm.append((sum(ribo_counts) / cds_len / readsLength_RPF,
                                 sum(rna_counts) / mRNA_len / readsLength_RNA))
            if len(ref_norm) == 0 and norm:
                print('Error: no qualified reference (housekeeping) transcript when norm=True')
                return None
            ref_norm = np.mean(ref_norm, axis=0) if norm and len(ref_norm) > 0 else None
            print('ref_norm', ref_norm,
                  '(sum(ribo_counts)/cds_len/readsLength_RPF, sum(rna_counts)/mRNA_len/readsLength_RNA)')
        '''generate by norm'''
        for transcript_id, data in mrna_tsv.groupby('transcript_id'):
            tag = data['dataset'].iloc[0]
            ans = self.merge_transcript_level(data, total_counts_info=total_counts_info, seq_only=seq_only,
                                              cds_min=cds_min, region=region, ref_norm=ref_norm)
            if ans is None:
                continue
            sample_dict[tag].append([transcript_id] + ans)
            count += 1
            if limit == count:
                break

        self.samples = sample_dict
        if seq_only:
            mrna_fa_path = os.path.join(reference_path, species, 'mRNA.fa')
            if region != -1:
                mrna_fa_path = mrna_fa_path.replace('.fa', f'_{region}.fa')
            if not (os.access(mrna_fa_path, os.F_OK) and os.path.getsize(mrna_fa_path) > 0):
                print(f'generate {sum([len(a) for a in sample_dict.values()])} sequences to {mrna_fa_path} {os.path.abspath(mrna_fa_path)}')
                self.generate_mRNA_fa(mrna_fa_path, sample_dict, force_regenerate=True)
            mrna_fa_path = mrna_fa_path.replace('.fa', '.pkl')
            if not os.access(mrna_fa_path, os.F_OK):
                with open(mrna_fa_path, 'wb') as f:
                    pickle.dump(sample_dict, f)
            self.mrna_region_pkl_path = mrna_fa_path
        return sample_dict
    def utr5_limit(self, args, x, region):
        utr5_limit = 300 if args.region > 300 else args.region
        seq = list(x[region - utr5_limit:region + 1 + args.region]
                   + 'NNN' + x[3 * region + 4 - args.region:3 * region + 4 + args.region + 1])
        if seq[-1] not in {'_', '>'}:
            seq[-1] = '>'
        if seq[0] not in {'_', '<'}:
            seq[0] = '<'
        return seq
    def merge_transcript_level(self, data, total_counts_info=None, seq_only=False, cds_min=-1, region=300, ref_norm=None):
        ans = self.qualified_samples(data, seq_only=seq_only, cds_min=cds_min)
        junction_counts = len(data[data['feature'] == 'CDS'])
        if ans is not None:
            seq, cds_start, cds_stop, ribo_counts, rna_counts, anno, metadict = ans
            cds_len = cds_stop - cds_start
            mRNA_len = len(seq)
            if region != -1:
                utr5_limit = 300 if region > 300 else region
                seq = self.process_sequence(seq, cds_start, cds_stop, region, '_', '<', '>', 'N', utr5_limit)

                if metadict is not None:
                    totalNumReads_RNA, totalNumReads_RPF, readsLength_RNA, readsLength_RPF = total_counts_info
                    if metadict['ribo_recovery'] > 0.9 and metadict['rna_recovery'] > 0.9:
                        te = self.calculate_ribosome_density(metadict['ribo_avg_count'], metadict['rna_avg_count'],
                                                             totalNumReads_RNA, totalNumReads_RPF,
                                                             readsLength_RNA, readsLength_RPF)
                        te = float(te)
                    else:
                        te = -1

                    anno = self.process_sequence(anno, cds_start, cds_stop, region, '_', '<', '>', 'N', utr5_limit)
                    ribo_counts = self.process_sequence(ribo_counts, cds_start, cds_stop, region, [-1], [-1], [-1], [-1], utr5_limit)
                    rna_counts = self.process_sequence(rna_counts, cds_start, cds_stop, region, [-1], [-1], [-1], [-1], utr5_limit)

                    if sum(ribo_counts[ribo_counts != -1]) <= 100 or sum(
                            rna_counts[rna_counts != -1]) <= 100:
                        return None

                    '''
                    normalized by total counts
                    https://rcxqhxlmkf.feishu.cn/docx/MdEvd008poMIaexhX9Xc7EAEnth#share-SNGtdmaQ2oATE0xax1Nc6m3jnbp
                    '''
                    ribo_counts += 1
                    rna_counts += 1
                    ribosome_density = deepcopy(ribo_counts)
                    ribosome_density[ribosome_density != 0] = self.calculate_ribosome_density(
                        ribo_counts[ribo_counts != 0], rna_counts[rna_counts != 0], totalNumReads_RNA, totalNumReads_RPF,
                        readsLength_RNA, readsLength_RPF)
                    if ref_norm is not None:
                        ribo_counts, rna_counts = ribo_counts / (ref_norm[0] * readsLength_RPF), rna_counts / (
                            ref_norm[1] * readsLength_RNA)

                    cds_start, cds_stop = anno.index('|'), anno.rindex('|', 1) + 4
                    return [seq,
                            cds_start, cds_stop,
                            ribo_counts,
                            rna_counts,
                            ribosome_density,
                            te, self.env, cds_len, mRNA_len, junction_counts]
            return [seq, cds_start, cds_stop, cds_len, mRNA_len, junction_counts]
    def generate_mRNA_fa(self, mrna_fa_path, sample_dict, force_regenerate=False):
        '''for pretraining'''
        if force_regenerate:
            '''generate mRNA.fa'''
            print('generate mRNA.fa to', mrna_fa_path)
            with open(mrna_fa_path, 'w') as f:
                for tag, data in sample_dict.items():
                    for transcript_id, seq, cds_start, cds_stop, cds_len, mRNA_len, *_ in data:
                        f.write(f">{transcript_id}|cds_start={cds_start}|cds_stop={cds_stop}|cds_len={cds_len}|mRNA_len={mRNA_len}|dataset={tag}\n{re.sub(r'[^ACGT]', 'N', seq.replace('U', 'T'))}\n")
    @staticmethod
    def qualified_samples(data, seq_only=False, cds_min=-1):
        """
        Filter out unqualified samples.
        :param data: per-transcript rows of the mRNA table
        :return: (seq, cds_start, cds_stop, ribo_counts, rna_counts, anno, metadict) or None
        """
        """load elements"""
        strand = data['strand'].iloc[0]
        num_start = data[data.feature == 'start_codon'].shape[0]
        num_stop = data[data.feature == 'stop_codon'].shape[0]
        if num_start == 0 or num_stop == 0:
            return None

        data = data[(data.feature != 'start_codon') & (data.feature != 'stop_codon')]
        seq = ''.join(list(chain(*data['seq'])))
        anno = ''.join(list(chain(*data['anno'])))

        if not seq_only:
            ribo_counts = list(chain(*data['ribo_counts']))
            rna_counts = list(chain(*data['rna_counts']))
            if sum(ribo_counts) == 0 or sum(rna_counts) == 0:
                return None

        if strand == '-':
            from pyfaidx import complement
            seq = complement(seq[::-1])
            anno = anno[::-1]
            if not seq_only:
                ribo_counts = ribo_counts[::-1]
                rna_counts = rna_counts[::-1]

        cds_start = anno.index('|')
        cds_stop = anno.rindex('|') + 4
        if cds_min != -1:
            if cds_stop - cds_start < cds_min:
                return None
        # the CDS length (count of '|' marks) must be a multiple of 3
        if anno.count('|') % 3 != 0:
            return None

        if not seq_only:
            metadict = dict()
            counts = np.array([ribo_counts, rna_counts])
            t = counts[:, cds_start:cds_stop] > 0
            metadict['cds_len'] = cds_stop - cds_start
            metadict['ribo_recovery'], metadict['rna_recovery'] = t.sum(axis=1) / metadict['cds_len']
            metadict['ribo_avg_count'], metadict['rna_avg_count'] = counts.sum(axis=1) / metadict['cds_len']
        if seq_only:
            return seq, cds_start, cds_stop, None, None, anno, None

        return seq, cds_start, cds_stop, ribo_counts, rna_counts, anno, metadict
    @staticmethod
    def generate_mRNA_tsv(genome_gtf_path, genome_fa_path, mrna_tsv_path):
        from gtfparse import read_gtf
        from pyfaidx import Fasta
        import polars as pl
        gtf = read_gtf(genome_gtf_path)

        features_to_keep = 'CDS,UTR,start_codon,stop_codon,five_prime_utr,three_prime_utr'.split(',')
        columns_to_keep = ['seqname', 'gene_id', 'transcript_id', 'protein_id', 'transcript_type', 'start', 'end', 'feature', 'strand']
        gtf = gtf.filter(pl.col("feature").is_in(features_to_keep))
        gtf = gtf.select(columns_to_keep)
        gtf = gtf.to_pandas()
        gtf = gtf.sort_values(by=['seqname', 'start', 'end'])

        genome_fa = Fasta(genome_fa_path)
        gtf['seq'] = gtf.apply(lambda x: genome_fa[x['seqname']][x['start'] - 1:x['end']].seq, axis=1)
        gtf['anno'] = gtf.apply(lambda x: '-' * (x['end'] - x['start'] + 1) if x['feature'] in ['UTR', 'five_prime_utr', 'three_prime_utr'] else '|' * (x['end'] - x['start'] + 1), axis=1)
        gtf.to_csv(mrna_tsv_path, index=None, sep='\t')
        del genome_fa

        print(f"generate mRNA.tsv file: {mrna_tsv_path}\n{gtf.shape}\t{gtf[['seq', 'anno']].head()}")
        return gtf
    @staticmethod
    def calculate_ribosome_density(numReads_RPF, numReads_RNA, totalNumReads_RNA, totalNumReads_RPF, readsLength_RNA,
                                   readsLength_RPF):
        '''
        Compute the ribosome density (log2 translation-efficiency-like ratio).
        :param numReads_RPF:
        :param numReads_RNA:
        :param totalNumReads_RNA:
        :param totalNumReads_RPF:
        :param readsLength_RNA:
        :param readsLength_RPF:
        :return:

        example:
            # example values
            numReads_RPF = 1000
            numReads_RNA = 2000
            totalNumReads_RNA = 5000000
            totalNumReads_RPF = 3000000
            readsLength_RNA = 150
            readsLength_RPF = 100

            result = calculate_ribosome_density(numReads_RPF, numReads_RNA, totalNumReads_RNA, totalNumReads_RPF,
                                                readsLength_RNA, readsLength_RPF)
            print("Ribosome Density:", result)
        '''
        # RPF reads longer than 40 nt are clamped to an effective length of 30 nt
        readsLength_RPF = np.where(readsLength_RPF > 40, 30, readsLength_RPF)
        ratio_numReads = numReads_RPF / numReads_RNA
        ratio_totalNumReads = totalNumReads_RNA / totalNumReads_RPF
        ratio_readsLength = readsLength_RNA / readsLength_RPF

        ribosome_density = np.log2(ratio_numReads * ratio_totalNumReads * ratio_readsLength + 1)
        ribosome_density = np.where(numReads_RNA == -1, -1, ribosome_density)
        return ribosome_density
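    # Worked example (rough arithmetic for the docstring values):
    #   readsLength_RPF = 100 > 40, so it is clamped to 30;
    #   ratio_numReads = 1000/2000 = 0.5; ratio_totalNumReads = 5e6/3e6 ~ 1.667;
    #   ratio_readsLength = 150/30 = 5; product ~ 4.167;
    #   ribosome_density = log2(4.167 + 1) ~ 2.37.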
    def read_meta_file(self, file_path, ribo_experiment, rna_experiment, seq_only=False):
        df = pd.read_table(file_path)
        if seq_only:
            if ribo_experiment:
                species = df[df['ribo_experiment'] == ribo_experiment]['Ref'].iloc[0]
            elif rna_experiment:
                species = df[df['rna_experiment'] == rna_experiment]['Ref'].iloc[0]
            else:
                raise ValueError("ribo_experiment or rna_experiment should be provided")
            return None, None, None, None, species
        row = df[(df['ribo_experiment'] == ribo_experiment) & (df['rna_experiment'] == rna_experiment)].iloc[0]
        totalNumReads_RNA, totalNumReads_RPF, readsLength_RNA, readsLength_RPF, species = (
            row['totalNumReads_RNA'], row['totalNumReads_RPF'], row['readsLength_RNA'], row['readsLength_RPF'], row['Ref'])
        return totalNumReads_RNA, totalNumReads_RPF, readsLength_RNA, readsLength_RPF, species

class RiboBwDataPipeline(RiboDataPipeline):
    def __init__(self,
                 data_path,
                 ribo_experiment,
                 rna_experiment,
                 seq_only=False,
                 limit=-1,
                 ):
        # pass limit by keyword: the parent's next positional parameter is `region`
        super().__init__(data_path, ribo_experiment, rna_experiment, seq_only, limit=limit)
    def load_data(self, path, ribo_experiment=None, rna_experiment=None, region=300, limit=-1, norm=True):
        '''
        Load the data (bw-track variant of RiboDataPipeline.load_data):
        1. Look up species, avg_len, total counts etc. from the meta table, keyed by ribo_experiment.
        2. Check whether the mRNA.fa file exists.
           If not:
               check whether mRNA.tsv exists;
               if not:
                   generate mRNA.gtf from genome.gtf (keep only mRNA-related rows/columns to scale down the gtf)
                   generate mRNA.fa from genome.fa and mRNA.gtf (including start/stop codon positions)
        3. Read the track files and build ribosome_density, ribo_counts, rna_counts.

        {'ENST00000303577.7': 'PCBP1',   # IRES
         'ENST00000309311.7': 'EEF2'}    # cap dependent

        :param path:
        :param region:
        :param limit:
        :return: samples

        path = ./dataset/pretraining/
        '''
        """1. input ribo_experiment, meta"""
        seq_only = self.seq_only
        cds_min = self.cds_min

        reference_path = os.path.join(path, 'reference')
        meta = self.read_meta_file(os.path.join(reference_path, 'experiment_meta.tsv'), ribo_experiment,
                                   rna_experiment, seq_only=seq_only)
        totalNumReads_RNA, totalNumReads_RPF, readsLength_RNA, readsLength_RPF, species = meta
        """2. check mRNA.fa"""
        mrna_fa_path = os.path.join(reference_path, species, 'mRNA.fa')
        if region != -1:
            mrna_fa_path = mrna_fa_path.replace('.fa', f'_{region}.fa')
        mrna_fa_path = mrna_fa_path.replace('.fa', '.pkl')
        self.mrna_region_pkl_path = mrna_fa_path
        if seq_only and os.access(mrna_fa_path, os.F_OK):
            with open(mrna_fa_path, 'rb') as f:
                sample_dict = pickle.load(f)
            if limit != -1:
                limited_sample_dict = {}
                for key in sample_dict.keys():
                    limited_sample_dict[key] = {transcript_id: sample_dict[key][transcript_id]
                                                for transcript_id in list(sample_dict[key].keys())[:limit]}
                return limited_sample_dict
            return sample_dict
        mrna_tsv_path = os.path.join(reference_path, species, 'mRNA.tsv')
        if not os.access(mrna_tsv_path, os.F_OK):
            genome_gtf_path = os.path.join(reference_path, species, 'genome.gtf')
            genome_fa_path = os.path.join(reference_path, species, 'genome.fa')
            mrna_tsv = self.generate_mRNA_tsv(genome_gtf_path, genome_fa_path, mrna_tsv_path)
        else:
            mrna_tsv = pd.read_table(mrna_tsv_path)

        """3. read track files"""
        print(f'filter limit={limit},region={region}')
        print('load_data in Pipeline, before filter', mrna_tsv.shape)
        reference_transcript_ids = list(self.reference_transcript_dict.keys())

        if limit != -1:
            mrna_tsv = mrna_tsv[mrna_tsv.transcript_id.isin(
                reference_transcript_ids + list(mrna_tsv.transcript_id.unique()[:limit]))]
        print('load_data in Pipeline, after filter', mrna_tsv.shape)

        if not seq_only:
            import pyBigWig
            fribo_track = os.path.join(path, 'track', f'{ribo_experiment}.bw')
            frna_track = os.path.join(path, 'track', f'{rna_experiment}.bw')
            if os.access(fribo_track, os.F_OK) and os.access(frna_track, os.F_OK):
                print(f'load {ribo_experiment} and {rna_experiment} tracks')
            else:
                print(f'Error: {fribo_track} or {frna_track} not found.')
                return None
            ribo_bw, rna_bw = [pyBigWig.open(fribo_track), pyBigWig.open(frna_track)]
            print(f'meta of {ribo_experiment} and {rna_experiment} tracks loaded\n{ribo_bw.header()}\n{rna_bw.header()}')

            def iterfunc(x, bw):
                chrom, start, end = x['seqname'], x['start'], x['end']
                if chrom in bw.chroms():
                    return np.array(bw.values(chrom, start - 1, end))
                else:
                    return np.zeros(end - start)

            mrna_tsv['ribo_counts'] = mrna_tsv.apply(lambda x: iterfunc(x, ribo_bw), axis=1)
            mrna_tsv['rna_counts'] = mrna_tsv.apply(lambda x: iterfunc(x, rna_bw), axis=1)

            ribo_bw.close()
            rna_bw.close()
            del ribo_bw, rna_bw

        """split dataset (todo)"""
        if species == 'GRCh38.p14':
            mappingdict = {'chr10': 'VL', 'chr15': 'TS'}
            mrna_tsv['dataset'] = mrna_tsv['seqname'].apply(lambda x: mappingdict[x] if x in mappingdict else 'TR')
        else:
            mrna_tsv['dataset'] = 'TR'
        print(f'{int(mrna_tsv.transcript_id.nunique())} transcripts in mRNA.tsv')
        """4. filter mRNA.tsv by transcript_id"""
        sample_dict = defaultdict(list)
        count = 0
        total_counts_info = [totalNumReads_RNA, totalNumReads_RPF, readsLength_RNA, readsLength_RPF]

        '''ref_norm'''
        ref_norm = None
        if not seq_only:
            ref_norm = []
            for transcript_id, data in mrna_tsv[mrna_tsv.transcript_id.isin(reference_transcript_ids)].groupby('transcript_id'):
                ans = self.merge_transcript_level(data, total_counts_info=total_counts_info, seq_only=seq_only,
                                                  cds_min=cds_min, region=region)
                if ans is None:
                    continue
                (seq,
                 cds_start, cds_stop,
                 ribo_counts,
                 rna_counts,
                 ribosome_density,
                 te, self.env, cds_len, mRNA_len, junction_counts) = ans
                ref_norm.append((sum(ribo_counts) / cds_len / readsLength_RPF,
                                 sum(rna_counts) / mRNA_len / readsLength_RNA))
            if len(ref_norm) == 0 and norm:
                print('Error: no qualified reference (housekeeping) transcript when norm=True')
                return None
            ref_norm = np.mean(ref_norm, axis=0) if norm and len(ref_norm) > 0 else None
            print('ref_norm', ref_norm,
                  '(sum(ribo_counts)/cds_len/readsLength_RPF, sum(rna_counts)/mRNA_len/readsLength_RNA)')
        '''generate by norm'''
        for transcript_id, data in mrna_tsv.groupby('transcript_id'):
            tag = data['dataset'].iloc[0]
            ans = self.merge_transcript_level(data, total_counts_info=total_counts_info, seq_only=seq_only,
                                              cds_min=cds_min, region=region, ref_norm=ref_norm)
            if ans is None:
                continue
            sample_dict[tag].append([transcript_id] + ans)
            count += 1
            if limit == count:
                break

        self.samples = sample_dict
        if seq_only:
            mrna_fa_path = os.path.join(reference_path, species, 'mRNA.fa')
            if region != -1:
                mrna_fa_path = mrna_fa_path.replace('.fa', f'_{region}.fa')
            if not os.access(mrna_fa_path, os.F_OK) or os.path.getsize(mrna_fa_path) == 0:
                print(f'generate {sum([len(a) for a in sample_dict.values()])} sequences to {mrna_fa_path} {os.path.abspath(mrna_fa_path)}')
                self.generate_mRNA_fa(mrna_fa_path, sample_dict, force_regenerate=True)
            mrna_fa_path = mrna_fa_path.replace('.fa', '.pkl')
            if not os.access(mrna_fa_path, os.F_OK) or os.path.getsize(mrna_fa_path) == 0:
                with open(mrna_fa_path, 'wb') as f:
                    pickle.dump(sample_dict, f)
            self.mrna_region_pkl_path = mrna_fa_path
        return sample_dict

class RegionDataset(BaseDataset):
    """DST"""

    def __init__(
        self,
        samples,
        tokenizer,
        args,
        region: int = 300,
        limit: int = -1,
        return_masked_tokens: bool = False,
        seed: int = 1,
        mask_prob: float = 0.15,
        leave_unmasked_prob: float = 0.1,
        random_token_prob: float = 0.1,
        freq_weighted_replacement: bool = False,
        two_dim_score: bool = False,
        two_dim_mask: int = -1,
        mask_whole_words: torch.Tensor = None,
    ):
        super().__init__(
            tokenizer=tokenizer,
            region=region,
            limit=limit,
            return_masked_tokens=return_masked_tokens,
            seed=seed,
            mask_prob=mask_prob,
            leave_unmasked_prob=leave_unmasked_prob,
            random_token_prob=random_token_prob,
            freq_weighted_replacement=freq_weighted_replacement,
            two_dim_score=two_dim_score,
            two_dim_mask=two_dim_mask,
            mask_whole_words=mask_whole_words,
        )

        self.args = args
        self.samples = samples
        if limit != -1:
            self.samples = self.samples[:limit]
        self.teacher_tokenizer = DebertaTokenizerFast.from_pretrained("./src/mRNA2vec/tokenizer", use_fast=True)

        self.teacher_tokenizer.padding_side = "left"
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        _id, seq, cds_start, cds_stop, *_ = self.samples[idx]

        aa_seq = '-' + self.translate(re.sub(r'[^ACGT]', 'N', seq[1:-1].replace('U', 'T'))).lower() + '-'
        aa_idx = torch.tensor(np.array([self.tokenizer.indices.get(aa) for aa in aa_seq]), dtype=torch.long)

        X = self.seq_to_rnaindex(seq, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index)
        '''generate src_data, tgt_data, twod_data'''
        src_data, tgt_data, twod_data, loss_mask = self.generate_mask(X, len(seq))
        if "ernierna" in self.args.mlm_pretrained_model_path or 'teacher' in self.args.mlm_pretrained_model_path:
            teacher_input_ids = src_data
        elif "mrna2vec" in self.args.mlm_pretrained_model_path:
            teacher_encoder = self.teacher_tokenizer(seq[1:-1],
                                                     padding='max_length',
                                                     max_length=403,
                                                     truncation=True,
                                                     add_special_tokens=True,
                                                     return_tensors="pt",
                                                     )
            teacher_input_ids = teacher_encoder['input_ids'].squeeze(0)
        else:
            raise ValueError(f"unrecognized teacher model path: {self.args.mlm_pretrained_model_path}")

        return (src_data, teacher_input_ids, tgt_data, twod_data, aa_idx, loss_mask)
    @staticmethod
    def seq_to_rnaindex(seq, pad_idx=1, unk_idx=3):
        seq = seq.upper()
        if seq.count('<') > 1 or seq.count('/') > 0:
            seq = seq.replace('<PAD>', '_').replace('<BOS>', 'V').replace('<EOS>', '^').replace('/', 'NNN')
        l = len(seq)
        X = np.ones((1, l))

        for j in range(l):
            if seq[j] in set('Aa'):
                X[0, j] = 5
            elif seq[j] in set('UuTt'):
                X[0, j] = 6
            elif seq[j] in set('Cc'):
                X[0, j] = 7
            elif seq[j] in set('Gg'):
                X[0, j] = 4
            elif seq[j] in set('_'):
                X[0, j] = pad_idx
            elif seq[j] in set('<V'):
                X[0, j] = 0
            elif seq[j] in set('>^'):
                X[0, j] = 2
            else:
                X[0, j] = unk_idx
        return X

'''generate'''
class BackBoneDataset(RegionDataset):
    '''for distillation using ribo dataset'''
    def __init__(
        self,
        samples,
        tokenizer,
        args,
        region: int = 300,
        limit: int = -1,
        return_masked_tokens: bool = False,
        seed: int = 1,
        mask_prob: float = 0.15,
        leave_unmasked_prob: float = 0.1,
        random_token_prob: float = 0.1,
        freq_weighted_replacement: bool = False,
        two_dim_score: bool = False,
        two_dim_mask: int = -1,
        mask_whole_words: torch.Tensor = None,
        input_mask=True,
        Kozak_GS6H_Stop3='GCCACC,GGGAGCCACCACCACCATCACCAC,TGATAATAG'
    ):
        super().__init__(
            samples=samples,
            tokenizer=tokenizer,
            args=args,
            region=region,
            limit=limit,
            return_masked_tokens=return_masked_tokens,
            seed=seed,
            mask_prob=mask_prob,
            leave_unmasked_prob=leave_unmasked_prob,
            random_token_prob=random_token_prob,
            freq_weighted_replacement=freq_weighted_replacement,
            two_dim_score=two_dim_score,
            two_dim_mask=two_dim_mask,
            mask_whole_words=mask_whole_words,
        )
        self.input_mask = input_mask
        # fixed elements: Kozak context, GS-linker + 6xHis tag, and triple stop codons
        self.Kozak_GS6H_Stop3 = Kozak_GS6H_Stop3.upper()
    def __getitem__(self, idx):
        data = self.samples.iloc[idx]
        _id = data['_id']
        seq = data['sequence']
        seq = seq.replace('U', 'T')
        start, stop = self.region + 1, self.region * 3 + 4
        Kozak, GS6H, Stop3 = (self.Kozak_GS6H_Stop3.split(',')
                              if ',' in self.Kozak_GS6H_Stop3 else ('', '', ''))

        '''fixed nucleotides (not optimized): remove upstream ATGs, then splice in
        the Kozak context and the GS-6xHis / stop-codon cassette'''
        seq = seq[:start - len(Kozak)].replace('ATG', 'ATC') + Kozak + seq[start:stop - len(GS6H) - len(Stop3)] + GS6H + Stop3 + seq[stop:]

        '''whole mask'''
        aa_seq = '-' + self.translate(re.sub(r'[^ACGTU]', 'N', seq[1:-1])).lower() + '-'
        aa_idx = torch.tensor(np.array([self.tokenizer.indices.get(aa) for aa in aa_seq]), dtype=torch.long)

        X = self.seq_to_rnaindex(seq, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index)
        _, _, _, mask = self.generate_mask(X, len(seq))

        seq_length = len(seq)
        vocab_size = len(self.tokenizer)

        masked_logits = torch.full(size=(seq_length, vocab_size), fill_value=float("-inf"))
        masked_logits[np.arange(X.shape[1]), X.reshape(-1).astype(int)] = 0

        '''CDS mask'''
        X_CDS, masked_logits_CDS = self.CDS_mask(seq, start, stop - len(GS6H) - len(Stop3))
        mask[start:stop - len(GS6H) - len(Stop3)] = X_CDS == self.tokenizer.mask_index
        special = self.seq_to_rnaindex('ACGT').reshape(-1).astype(int)
        masked_logits[:start - len(Kozak), special] = 0
        masked_logits[stop:, special] = 0
        masked_logits[start:stop - len(GS6H) - len(Stop3)] = masked_logits_CDS

        for token in ['<s>', '<pad>', '</s>', '<unk>']:
            masked_logits[X.reshape(-1) == self.tokenizer.indices.get(token), :] = float("-inf")
            masked_logits[X.reshape(-1) == self.tokenizer.indices.get(token), self.tokenizer.indices.get(token)] = 0

        # never mask the fixed Kozak/start and GS6H/stop cassettes
        mask[start - len(Kozak):start + 3] = False
        mask[stop - len(GS6H) - len(Stop3):stop] = False
        src_data, tgt_data, twod_data, loss_mask = self.generate_mask(X, len(seq), mask=mask, input_mask=self.input_mask)

        loss_mask = torch.tensor(mask, dtype=torch.bool)
        src_data = torch.where(loss_mask, aa_idx, src_data)
        src_env = torch.tensor(self.args.env_id, dtype=torch.long)

        return (_id, src_data, tgt_data, twod_data, loss_mask, masked_logits, src_env)
    def CDS_mask(self, seq, start, stop):
        backbone_cds = re.sub(r'[^ACGT]', 'N', seq[start:stop])

        target_protein = self.translate(backbone_cds, repeat=1).upper()
        target_protein_idx = torch.tensor(np.array([self.tokenizer.indices.get(aa) for aa in target_protein.lower()]), dtype=torch.long)

        X = self.seq_to_rnaindex(backbone_cds, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index).reshape(-1)
        num_rows, num_cols = len(target_protein), 3
        cds_mask = np.zeros([num_rows, num_cols], dtype=int)

        # mask one random base in roughly 2*mask_prob of the codons
        rows_to_mask = int(num_rows * self.mask_prob * 2)
        masked_rows = random.sample(range(num_rows), rows_to_mask)
        masked_cols = np.random.randint(0, num_cols, size=rows_to_mask)
        cds_mask[masked_rows, masked_cols] = 1
        cds_mask = cds_mask.reshape(-1)
        X[cds_mask == 1] = self.tokenizer.mask_index

        masked_logits = self.create_codon_mask(target_protein_idx.numpy(), X, self.amino_acid_to_codons, self.tokenizer)

        return X, masked_logits
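    # Sketch of the CDS masking scheme: roughly 2*mask_prob of the codons get one
    # random base replaced by <mask>; create_codon_mask then only permits bases
    # that can still complete a synonymous codon for the target amino acid, so the
    # protein is preserved while the nucleotide choice remains learnable.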

class RiboDataset(RegionDataset):
    '''for distillation using ribo dataset'''
    def __init__(
        self,
        samples,
        tokenizer,
        args,
        region: int = 300,
        limit: int = -1,
        return_masked_tokens: bool = False,
        seed: int = 1,
        mask_prob: float = 0.15,
        leave_unmasked_prob: float = 0.1,
        random_token_prob: float = 0.1,
        freq_weighted_replacement: bool = False,
        two_dim_score: bool = False,
        two_dim_mask: int = -1,
        mask_whole_words: torch.Tensor = None,
    ):
        super().__init__(
            samples=samples,
            tokenizer=tokenizer,
            args=args,
            region=region,
            limit=limit,
            return_masked_tokens=return_masked_tokens,
            seed=seed,
            mask_prob=mask_prob,
            leave_unmasked_prob=leave_unmasked_prob,
            random_token_prob=random_token_prob,
            freq_weighted_replacement=freq_weighted_replacement,
            two_dim_score=two_dim_score,
            two_dim_mask=two_dim_mask,
            mask_whole_words=mask_whole_words,
        )
    def __getitem__(self, idx):
        (_id, seq, cds_start, cds_stop,
         ribo_counts, rna_counts,
         ribosome_density, te, env, cds_len, mRNA_len, junction_counts) = self.samples[idx]
        aa_seq = '-' + self.translate(re.sub(r'[^ACGT]', 'N', seq[1:-1])).lower() + '-'
        aa_idx = torch.tensor(np.array([self.tokenizer.indices.get(aa) for aa in aa_seq]), dtype=torch.long)

        X = self.seq_to_rnaindex(seq, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index)
        src_data, tgt_data, twod_data, mask = self.generate_mask(X, len(seq))
        loss_mask = torch.tensor(mask, dtype=torch.bool)
        # masked positions see the amino-acid token instead of the nucleotide
        src_data = torch.where(loss_mask, aa_idx, src_data)

        # smooth the per-position signals with a sliding-window average
        window = 31
        exp_one_d = np.stack([ribo_counts, rna_counts, ribosome_density], axis=1)
        tgt_exp_data = torch.from_numpy(exp_one_d).float()
        tgt_exp_data = tgt_exp_data.permute(1, 0)
        tgt_exp_data = F.avg_pool1d(tgt_exp_data, kernel_size=window, padding=window // 2, stride=1)
        tgt_exp_data = tgt_exp_data.permute(1, 0)
        tgt_exp_data[~loss_mask, :] = -1

        src_exp_data = torch.from_numpy(exp_one_d).float()
        # zero out any input signal within `window` of a masked position
        src_exp_mask = F.max_pool1d(loss_mask.unsqueeze(0).repeat(3, 1).float(), kernel_size=window, padding=window // 2, stride=1).permute(1, 0)
        src_exp_data = torch.where(src_exp_mask.bool(), torch.zeros_like(src_exp_mask), src_exp_data)

        src_env = torch.tensor(env, dtype=torch.long)
        src_feature = np.array([cds_len, mRNA_len, junction_counts])
        src_feature = torch.from_numpy(src_feature).float()
        src_feature = torch.log(src_feature + 1)
        tgt_te = torch.tensor(te, dtype=torch.float32)

        return src_data, src_exp_data, src_env, src_feature, tgt_data, tgt_exp_data, tgt_te, twod_data, loss_mask

class DownstreamDataset(RegionDataset):
    def __init__(
        self,
        samples,
        tokenizer,
        args,
        region: int = 300,
        limit: int = -1,
        return_masked_tokens: bool = False,
        seed: int = 1,
        mask_prob: float = 0.15,
        leave_unmasked_prob: float = 0.1,
        random_token_prob: float = 0.1,
        freq_weighted_replacement: bool = False,
        two_dim_score: bool = False,
        two_dim_mask: int = -1,
        mask_whole_words: torch.Tensor = None,
        seq_len: int = 174,
        pad_method: str = "pre",
        column: str = "sequence",
        cds_len: str = 'cds_len',
        mRNA_len: str = 'mRNA_len',
        label: str = "IRES_Activity",
    ):
        super().__init__(
            samples=samples,
            tokenizer=tokenizer,
            args=args,
            region=region,
            limit=limit,
            return_masked_tokens=return_masked_tokens,
            seed=seed,
            mask_prob=mask_prob,
            leave_unmasked_prob=leave_unmasked_prob,
            random_token_prob=random_token_prob,
            freq_weighted_replacement=freq_weighted_replacement,
            two_dim_score=two_dim_score,
            two_dim_mask=two_dim_mask,
            mask_whole_words=mask_whole_words,
        )

        self.label = label
        self.column = column
        self.seq_len = seq_len
        self.cds_len = cds_len
        self.mRNA_len = mRNA_len
        self.pad_method = pad_method
        if limit != -1:
            self.samples = self.samples.iloc[:limit]

    '''
    eGFP reference used for the default cds_len/mRNA_len fallbacks below:
    https://www.ncbi.nlm.nih.gov/nuccore/L29345.1
    >L29345.1 Aequorea victoria green-fluorescent protein (GFP) mRNA, complete cds| 26..742
    TACACACGAATAAAAGATAACAAAGATGAGTAAAGGAGAAGAACTTTTCACTGGAGTTGTCCCAATTCTTGTTGAATTAGATGGCGATGTTAATGGGCAAAAATTCTCTGTCAGTGGAGAGGGTGAAGGTGATGCAACATACGGAAAACTTACCCTTAAATTTATTTGCACTACTGGGAAGCTACCTGTTCCATGGCCAACACTTGTCACTACTTTCTCTTATGGTGTTCAATGCTTTTCAAGATACCCAGATCATATGAAACAGCATGACTTTTTCAAGAGTGCCATGCCCGAAGGTTATGTACAGGAAAGAACTATATTTTACAAAGATGACGGGAACTACAAGACACGTGCTGAAGTCAAGTTTGAAGGTGATACCCTTGTTAATAGAATCGAGTTAAAAGGTATTGATTTTAAAGAAGATGGAAACATTCTTGGACACAAAATGGAATACAACTATAACTCACATAATGTATACATCATGGCAGACAAACCAAAGAATGGAATCAAAGTTAACTTCAAAATTAGACACAACATTAAAGATGGAAGCGTTCAATTAGCAGACCATTATCAACAAAATACTCCAATTGGCGATGGCCCTGTCCTTTTACCAGACAACCATTACCTGTCCACACAATCTGCCCTTTCCAAAGATCCCAACGAAAAGAGAGATCACATGATCCTTCTTGAGTTTGTAACAGCTGCTGGGATTACACATGGCATGGATGAACTATACAAATAAATGTCCAGACTTCCAATTGACACTAAAGTGTCCGAACAATTACTAAATTCTCAGGGTTCCTGGTTAAATTCAGGCTGAGACTTTATTTATATATTTATAGATTCATTAAAATTTTATGAATAATTTATTGATGTTATTAATAGGGGCTATTTTCTTATTAAATAGGCTACTGGAGTGTAT
    '''
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        data = self.samples.iloc[idx]
        seq = data[self.column]
        target = data[self.label]
        X = self.seq_to_rnaindex(seq, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index)
        one_d, twod_d = self.prepare_input_for_ernierna(X, len(seq))
        one_d = one_d.view(-1)

        if not torch.is_tensor(one_d):
            src_data = torch.from_numpy(one_d)
        else:
            src_data = one_d
        if not torch.is_tensor(twod_d):
            twod_data = torch.from_numpy(twod_d.squeeze(dim=-1))
        else:
            twod_data = twod_d.squeeze(dim=-1)

        src_env = torch.tensor(self.args.env_id, dtype=torch.long)

        # fall back to the eGFP reference lengths (CDS 26..742 of a 922 nt mRNA)
        cds_len = data[self.cds_len] if self.cds_len in data else 742 - 26 + 1
        mRNA_len = data[self.mRNA_len] if self.mRNA_len in data else 922
        junction_counts = 0
        src_feature = np.array([cds_len, mRNA_len, junction_counts])
        src_feature = torch.from_numpy(src_feature).float()

        target = torch.tensor(target, dtype=torch.float32)
        return src_data, twod_data, src_env, src_feature, target

| | class RegressionDataset(BaseDataset): |
| | """处理回归任务的Dataset""" |
| | def __init__( |
| | self, |
| | path, |
| | tokenizer, |
| | args, |
| | region: int = 300, |
| | limit: int = -1, |
| | return_masked_tokens: bool = False, |
| | seed: int = 1, |
| | mask_prob: float = 0.15, |
| | leave_unmasked_prob: float = 0.1, |
| | random_token_prob: float = 0.1, |
| | freq_weighted_replacement: bool = False, |
| | two_dim_score: bool = False, |
| | two_dim_mask: int = -1, |
| | mask_whole_words: torch.Tensor = None, |
| | |
| | seq_len: int = 174, |
| | pad_method: str = "pre", |
| | column: str = "sequence", |
| | label: str = "IRES_Activity", |
| | returnid=None |
| | ): |
| | |
| | super().__init__( |
| | tokenizer=tokenizer, |
| | region=region, |
| | limit=limit, |
| | return_masked_tokens=return_masked_tokens, |
| | seed=seed, |
| | mask_prob=mask_prob, |
| | leave_unmasked_prob=leave_unmasked_prob, |
| | random_token_prob=random_token_prob, |
| | freq_weighted_replacement=freq_weighted_replacement, |
| | two_dim_score=two_dim_score, |
| | two_dim_mask=two_dim_mask, |
| | mask_whole_words=mask_whole_words, |
| | ) |
| |
|
| | |
| | self.label = label |
| | self.column = column |
| | self.seq_len = seq_len |
| | self.pad_method = pad_method |
| | self.args = args |
| | self.returnid = returnid |
| | |
| | self.samples = self.load_data( |
| | path, |
| | seq_len=seq_len, |
| | column=column, |
| | pad_method=pad_method |
| | ) |
| | if limit != -1:
| | self.samples = self.samples.iloc[:limit]
| |
|
| | def load_data(self, path, **kwargs): |
| |
|
| | return self.read_csv_file( |
| | path, |
| | seq_len=kwargs['seq_len'], |
| | column=kwargs['column'], |
| | pad_method=kwargs['pad_method'] |
| | ) |
| |
|
| | def read_csv_file(self, file_path, **kwargs):
| | |
| | try: |
| | column = kwargs['column'] |
| | data = pd.read_csv(file_path) |
| | if column not in data.columns: |
| | data[column] = data.apply(self.generate_inputs, axis=1) |
| | return pad_or_truncate_utr( |
| | data, |
| | pad_method=kwargs['pad_method'], |
| | column=kwargs['column'], |
| | input_len=kwargs['seq_len'] |
| | ) |
| | except FileNotFoundError:
| | print(f"Error: File '{file_path}' not found.")
| | return pd.DataFrame()  # empty frame keeps len() and .iloc on self.samples valid
| |
|
| | def __len__(self): |
| | return len(self.samples) |
| |
|
| | def __getitem__(self, idx):
| | data = self.samples.iloc[idx] |
| | seq = data[self.column] |
| | target = data[self.label] |
| | |
| | X = self.seq_to_rnaindex(seq, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index) |
| |
|
| | |
| | one_d, twod_d = self.prepare_input_for_ernierna(X, len(seq)) |
| | one_d = one_d.view(-1) |
| | |
| | # normalize possible numpy outputs to torch tensors (convert before
| | # squeezing: numpy's squeeze takes axis=, not dim=)
| | src_data = one_d if torch.is_tensor(one_d) else torch.from_numpy(one_d)
| | if torch.is_tensor(twod_d):
| | twod_data = twod_d.squeeze(dim=-1)
| | else:
| | twod_data = torch.from_numpy(twod_d).squeeze(dim=-1)
| | src_env = torch.tensor(self.args.env_id, dtype=torch.long)
| | # Fixed construct constants, matching the fallbacks used above:
| | # 742 - 26 + 1 = 717 nt CDS, 922 nt mRNA, no junctions.
| | cds_len = 742 - 26 + 1
| | mRNA_len = 922
| | junction_counts = 0
| | src_feature = np.array([cds_len, mRNA_len, junction_counts])
| | src_feature = torch.from_numpy(src_feature).float()
| | src_feature = torch.log(src_feature + 1)  # log1p to compress the length scales
| |
|
| | |
| | target = torch.tensor(target, dtype=torch.float32) |
| |
|
| | if self.returnid is None:
| | return src_data, twod_data, src_env, src_feature, target
| | return data[self.returnid], src_data, twod_data, src_env, src_feature, target
| |
|
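| | # Illustrative usage sketch (the _demo_* helper below is hypothetical, not
| | # part of the pipeline): iterate one RegressionDataset batch. It assumes a
| | # CSV with the default 'sequence' and 'IRES_Activity' columns and an args
| | # object carrying env_id; csv_path, tokenizer and args come from the caller.
| | def _demo_regression_loader(csv_path, tokenizer, args, batch_size=4):
| | dataset = RegressionDataset(csv_path, tokenizer, args, limit=32)
| | loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
| | for src_data, twod_data, src_env, src_feature, target in loader:
| | # exact shapes depend on prepare_input_for_ernierna; print to inspect
| | print(src_data.shape, twod_data.shape, src_env.shape, src_feature.shape, target.shape)
| | break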
| |
|
| | class MaotaoDataset(BaseDataset): |
| | """处理回归任务的Dataset""" |
| | def __init__( |
| | self, |
| | path, |
| | tokenizer, |
| | args, |
| | region: int = 300, |
| | limit: int = -1, |
| | return_masked_tokens: bool = False, |
| | seed: int = 1, |
| | mask_prob: float = 0.15, |
| | leave_unmasked_prob: float = 0.1, |
| | random_token_prob: float = 0.1, |
| | freq_weighted_replacement: bool = False, |
| | two_dim_score: bool = False, |
| | two_dim_mask: int = -1, |
| | mask_whole_words: torch.Tensor = None, |
| | |
| | seq_len: int = 1200, |
| | column: str = 'off_start,off_end,full_len,type,_id,species,maotao_id,truncated_aa,cai_best_nn', |
| | label: str = "truncated_nn,cai_nature", |
| | |
| | codon_table_path: str='maotao_file/codon_table/codon_usage_{species}.csv', |
| | species_list:str="""mouse,Ec,Sac,Pic,Human""", |
| | type_list:str="""full,head,tail,boundary,middle""", |
| | |
| | rna_alphabet_list:str="""GAUC""", |
| | returnid = None |
| | ): |
| | |
| | super().__init__( |
| | tokenizer=tokenizer, |
| | region=region, |
| | limit=limit, |
| | return_masked_tokens=return_masked_tokens, |
| | seed=seed, |
| | mask_prob=mask_prob, |
| | leave_unmasked_prob=leave_unmasked_prob, |
| | random_token_prob=random_token_prob, |
| | freq_weighted_replacement=freq_weighted_replacement, |
| | two_dim_score=two_dim_score, |
| | two_dim_mask=two_dim_mask, |
| | mask_whole_words=mask_whole_words, |
| | ) |
| |
|
| | |
| | # Map names to integer ids; also map each id to itself so rows that are
| | # already encoded (e.g. read back from the cached pickle) pass through.
| | self.species = {k: v for v, k in enumerate(species_list.split(','))}
| | self.species.update({v: v for v, k in enumerate(species_list.split(','))})
| | self.seq_types = {k: v for v, k in enumerate(type_list.split(','))}
| | self.seq_types.update({v: v for v, k in enumerate(type_list.split(','))})
| |
| | self.rna_alphabet = {k: v + 4 for v, k in enumerate(rna_alphabet_list)}  # G,A,U,C -> 4..7, as in paired()
| | self.label = label.split(',') |
| | self.column = column.split(',') |
| | self.seq_len = seq_len |
| | self.args = args |
| | |
| | self.samples = self.load_data(path) |
| | |
| | # One species-specific RNA codon-usage table per species id.
| | self.codon_instance_rna = {
| | self.species[species]: Codon(codon_table_path.format(species=species), rna=True)
| | for species in species_list.split(',')
| | }
| | if limit != -1:
| | self.samples = self.samples.iloc[:limit]
| |
|
| | def load_data(self, path, **kwargs):
| | # Cache the normalized dataframe next to the CSV so preprocessing runs once.
| | cache_path = path.replace('.csv', '_processed.pickle')
| | if os.access(cache_path, os.R_OK):
| | df = pd.read_pickle(cache_path)
| | else:
| | df = pd.read_csv(path)
| | # lowercase amino acids; anything outside the 20 AA letters, '*' and '_' becomes '_'
| | df['truncated_aa'] = df['truncated_aa'].apply(lambda x: re.sub(r'[^acdefghiklmnpqrstvwy*_]', '_', x.lower()))
| | df['cai_best_nn'] = df['cai_best_nn'].apply(lambda x: x.upper().replace('T', 'U'))  # DNA -> RNA
| | df['species'] = df['species'].apply(lambda x: self.species[x])
| | df['type'] = df['type'].apply(lambda x: self.seq_types[x])
| | df.to_csv(path.replace('.csv', '_processed.csv'), index=False)
| | with open(cache_path, 'wb') as f:
| | pickle.dump(df, f)
| | return df
| |
|
| | def __len__(self): |
| | return len(self.samples) |
| |
|
| | def __getitem__(self, idx):
| | data = self.samples.iloc[idx] |
| | maotao_id = data['maotao_id'] |
| | aa_index = np.array([self.tokenizer.index(x) for x in data['truncated_aa']]) |
| | |
| | aa_idx = torch.from_numpy(aa_index).long() |
| | seq = data['cai_best_nn'] |
| |
|
| | '''prepare 1D and 2D input data''' |
| | |
| | X = self.seq_to_rnaindex(seq, pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index) |
| | |
| | one_d, twod_d = self.prepare_input_for_ernierna(X, len(seq)) |
| | one_d = one_d.view(-1) |
| | |
| | # normalize to torch tensors as above (convert before squeezing; numpy uses axis=)
| | src_data = one_d if torch.is_tensor(one_d) else torch.from_numpy(one_d)
| | if torch.is_tensor(twod_d):
| | twod_data = twod_d.squeeze(dim=-1)
| | else:
| | twod_data = torch.from_numpy(twod_d).squeeze(dim=-1)
| |
|
| | continuous_features = np.array([data['off_start'], data['off_end'], data['full_len']])
| | # Shift by +3 (offsets can apparently be slightly negative), clamp at 0, then log1p.
| | continuous_features = np.log(np.maximum(continuous_features + 3, 0) + 1)
| | continuous_features = torch.from_numpy(continuous_features).float()
| |
| | species_features = torch.tensor(data['species'], dtype=torch.long)
| | truncated_features = torch.tensor(data['type'], dtype=torch.long)
| | # Per-frame base probabilities from the species' codon-usage table.
| | ith_nn_prob = self.codon_instance_rna[data['species']].frame_ith_aa_base_fraction
| | nn_prob = self.create_base_prob(data['truncated_aa'], ith_nn_prob, self.rna_alphabet, self.tokenizer)
| | '''output: natural codons and their CAI when the row is labelled, else the CAI-best codons with a zero target'''
| | if 'truncated_nn' in data: |
| | target_nn = self.seq_to_rnaindex(data['truncated_nn'], pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index).reshape(-1) |
| | |
| | target = torch.tensor(data['cai_nature'], dtype=torch.float32) |
| | else: |
| | target_nn = self.seq_to_rnaindex(data['cai_best_nn'], pad_idx=self.tokenizer.pad_index, unk_idx=self.tokenizer.unk_index).reshape(-1) |
| | target = torch.tensor(0, dtype=torch.float32) |
| | target_nn = torch.from_numpy(target_nn).long() |
| | # Mask each codon position in turn and collect, per frame, the mask of
| | # codons compatible with the amino-acid sequence.
| | frames = [1, 2, 3]
| | backbone_cds_list = self.modify_codon_by_frames(target_nn, frames=frames, masked_token=self.tokenizer.mask_index)
| | masked_logits_list = []
| | for backbone_cds, frame in zip(backbone_cds_list, frames):
| | masked_logits = self.create_codon_mask(aa_idx, backbone_cds, self.amino_acid_to_codons, self.tokenizer)
| | masked_logits_list.append(masked_logits.unsqueeze(0))
| | masked_logits_list = torch.cat(masked_logits_list, dim=0)  # one slice per frame
| |
|
| | return (src_data, twod_data, aa_idx, continuous_features, species_features,
| | truncated_features, target_nn, target, masked_logits_list[..., :10],
| | nn_prob[..., :10], maotao_id)
| |
| | @staticmethod
| | def modify_codon_by_frames(sequence, frames=(1, 2, 3), masked_token='_'):
| | """
| | Mask one codon position (frame) at a time and rebuild the sequence.
| |
| | Args:
| | sequence (str | torch.Tensor | np.ndarray): input sequence, length divisible by 3
| | frames (iterable[int]): codon positions to mask (1, 2 and/or 3)
| | masked_token: token written into the masked frame
| |
| | Returns:
| | list: one reconstructed sequence per masked frame, same type as the input
| | """
| | |
| | |
| | seq = sequence
| | # Split into the three codon positions: frame 1 = bases 0,3,6,..., etc.
| | frames_seq = [seq[0::3], seq[1::3], seq[2::3]]
| |
|
| | reconstructed_list = [] |
| | |
| | for ith, _ in enumerate(frames_seq):
| | if ith + 1 in frames:
| | tmp_seq = deepcopy(frames_seq)
| | tmp_seq[ith] = [masked_token] * len(frames_seq[ith])
| | # Re-interleave the three frames back into a single sequence.
| | if isinstance(seq, str):
| | reconstructed = ''.join(
| | tmp_seq[0][i] + tmp_seq[1][i] + tmp_seq[2][i]
| | for i in range(len(tmp_seq[0]))
| | )
| | elif isinstance(seq, torch.Tensor):
| | tmp_seq[ith] = torch.from_numpy(np.array(tmp_seq[ith]))
| | reconstructed = torch.stack(tmp_seq, dim=1).reshape(-1)
| | elif isinstance(seq, np.ndarray):
| | tmp_seq[ith] = np.array(tmp_seq[ith])
| | reconstructed = np.stack(tmp_seq, axis=1).reshape(-1)
| | else:
| | raise TypeError(f'unsupported sequence type: {type(seq)}')
| | reconstructed_list.append(deepcopy(reconstructed))
| |
|
| | return reconstructed_list |
| |
|
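| | # Illustrative example of the frame masking above (the _demo_* helper is
| | # hypothetical): masking frame 2 of 'AUGGCU' blanks the middle base of
| | # every codon.
| | def _demo_modify_codon_by_frames():
| | out = MaotaoDataset.modify_codon_by_frames('AUGGCU', frames=[2])
| | assert out == ['A_GG_U']
| |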
| | def gaussian(x):
| | """Unnormalized Gaussian kernel exp(-x^2 / 2)."""
| | return math.exp(-0.5 * (x * x))
| | def paired(x, y, lamda=0.8):
| | """Base-pairing score between token ids G=4, A=5, U=6, C=7 (see rna_alphabet)."""
| | if x == 5 and y == 6:
| | return 2  # A-U
| | elif x == 4 and y == 7:
| | return 3  # G-C
| | elif x == 4 and y == 6:
| | return lamda  # G-U wobble
| | elif x == 6 and y == 5:
| | return 2  # U-A
| | elif x == 7 and y == 4:
| | return 3  # C-G
| | elif x == 6 and y == 4:
| | return lamda  # U-G wobble
| | else:
| | return 0
| |
|
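| | # Quick sanity check of the pairing scores above (hypothetical _demo_* helper).
| | def _demo_paired():
| | assert paired(5, 6) == 2    # A-U
| | assert paired(4, 7) == 3    # G-C
| | assert paired(4, 6) == 0.8  # G-U wobble at the default lamda
| | assert paired(5, 5) == 0    # A-A does not pair
| |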
| | def pad_or_truncate_utr(data, input_len, pad_method, column='utr', pad_mark='_'):
| | """Pad short sequences with pad_mark ('pre' or 'behind'); keep the last input_len bases of long ones."""
| | def process_utr(utr):
| | if len(utr) < input_len:
| | if pad_method == 'pre':
| | return pad_mark * (input_len - len(utr)) + utr
| | elif pad_method == 'behind':
| | return utr + pad_mark * (input_len - len(utr))
| | raise ValueError(f'unknown pad_method: {pad_method}')
| | return utr[-input_len:]
| | data[column] = data[column].apply(process_utr)
| | return data
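| |
| | # Illustrative usage of pad_or_truncate_utr (hypothetical _demo_* helper):
| | def _demo_pad_or_truncate_utr():
| | df = pd.DataFrame({'utr': ['AUG', 'AUGGCUAAU']})
| | df = pad_or_truncate_utr(df, input_len=6, pad_method='pre')
| | assert df['utr'].tolist() == ['___AUG', 'GCUAAU']  # pre-pad the short one, keep the tail of the long one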
| |
| | def do_createmat(data, base_range=30, lamda=0.8):
| | """Build an ERNIE-RNA-style pairing-score matrix over the token ids in `data`."""
| | # Lookup table over all pairs of token ids; the fixed size 30 covers the
| | # vocabulary (NOTE: base_range is currently unused).
| | paird_map = np.array([[paired(i, j, lamda) for i in range(30)] for j in range(30)])
| | data_index = np.arange(0, len(data))
| | coefficient = np.zeros([len(data), len(data)])
| | score_mask = np.full((len(data), len(data)), True)
| | for add in [0, 300]:
| | data_index_x = data_index - add
| | data_index_y = data_index + add
| | score_mask = ((data_index_x >= 0)[:, None] & (data_index_y < len(data))[None, :]) & score_mask
| | data_index_x, data_index_y = np.meshgrid(data_index_x.clip(0, len(data) - 1),
| | data_index_y.clip(0, len(data) - 1), indexing='ij')
| | score = paird_map[data[data_index_x], data[data_index_y]]
| | score_mask = score_mask & (score != 0)
| | # gaussian(300) underflows to 0, so in practice only add=0 contributes.
| | coefficient = coefficient + score * score_mask * gaussian(add)
| | if not score_mask.any():
| | break
| | return coefficient
| |
|
| | def creatmat(data, base_range=30, lamda=0.8):
| | """Alias for do_createmat."""
| | return do_createmat(data, base_range=base_range, lamda=lamda)
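| |
| | # Illustrative check of the pairing matrix (hypothetical _demo_* helper);
| | # token ids follow paired(): G=4, A=5, U=6, C=7.
| | def _demo_creatmat():
| | ids = np.array([5, 4, 6, 7])  # A G U C
| | mat = creatmat(ids)
| | assert mat[0, 2] == 2    # A-U
| | assert mat[1, 3] == 3    # G-C
| | assert mat[1, 2] == 0.8  # G-U wobble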
| |
| | if __name__ == '__main__': |
| | print('start generating') |
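| | # Illustrative smoke tests for the hypothetical _demo_* helpers defined above.
| | _demo_paired()
| | _demo_pad_or_truncate_utr()
| | _demo_creatmat()
| | _demo_modify_codon_by_frames()
| | print('demos passed')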