import logging
import math
import os
import warnings
from random import randint

import numpy as np
import pandas as pd
from numpy.lib.stride_tricks import as_strided
from omegaconf import DictConfig, open_dict

from ..utils import energy
from ..utils.file import save

logger = logging.getLogger(__name__)


class SeqTokenizer:
    '''
    Base class holding the tokenization functionality that the data-specific
    classes inherit from.
    '''
    def __init__(self, seqs_dot_bracket_labels: pd.DataFrame, config: DictConfig):
        self.seqs_dot_bracket_labels = seqs_dot_bracket_labels.reset_index(drop=True)
        # Shuffle the samples during training so batches are not ordered.
        if not config["inference"]:
            self.seqs_dot_bracket_labels = self.seqs_dot_bracket_labels\
                .sample(frac=1)\
                .reset_index(drop=True)

        self.model_input = config["model_config"].model_input

        if config["train_config"].filter_seq_length:
            self.get_outlier_length_threshold()
            self.limit_seqs_to_range()
        else:
            self.max_length = self.seqs_dot_bracket_labels['Sequences'].str.len().max()
            self.min_length = 0

        with open_dict(config):
            config["model_config"]["max_length"] = np.int64(self.max_length).item()
            config["model_config"]["min_length"] = np.int64(self.min_length).item()

        self.window = config["model_config"].window
        # One token per window without overlap; one token per sliding-window
        # position with overlap.
        self.tokens_len = math.ceil(self.max_length / self.window)
        if config["model_config"].tokenizer in ["overlap", "overlap_multi_window"]:
            self.tokens_len = int(self.max_length - (config["model_config"].window - 1))
        self.tokenizer = config["model_config"].tokenizer
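        # Worked example (illustrative values): max_length = 7, window = 3
        #   no overlap: tokens_len = ceil(7 / 3) = 3  ("AACTAGA" -> AAC, TAG, A)
        #   overlap:    tokens_len = 7 - (3 - 1) = 5  ("AACTAGA" -> AAC, ACT, CTA, TAG, AGA)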

        self.seq_len_dist = self.seqs_dot_bracket_labels['Sequences'].str.len().value_counts()

        self.seq_tokens_ids_dict = {}
        self.second_input_tokens_ids_dict = {}

        config["model_config"].num_classes = len(self.seqs_dot_bracket_labels['Labels'].unique())

        self.set_class_attr()
    def get_outlier_length_threshold(self):
        '''
        Sets the maximum and minimum sequence lengths to the longest and
        shortest lengths that lie within two standard deviations of the mean.
        '''
        lengths_arr = self.seqs_dot_bracket_labels['Sequences'].str.len()
        mean = np.mean(lengths_arr)
        standard_deviation = np.std(lengths_arr)
        distance_from_mean = abs(lengths_arr - mean)
        in_distribution = distance_from_mean < 2 * standard_deviation

        inlier_lengths = np.sort(lengths_arr[in_distribution].unique())
        self.max_length = int(np.max(inlier_lengths))
        self.min_length = int(np.min(inlier_lengths))
        logger.info(f'Maximum and minimum sequence lengths are set to {self.max_length} and {self.min_length}')
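    # Illustrative effect: for twenty 20-nt sequences plus a single 100-nt
    # sequence, the 100-mer lies beyond mean + 2*std, so max_length is taken
    # from the remaining inlier lengths rather than from the outlier.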

    def limit_seqs_to_range(self):
        '''
        Trims sequences longer than the maximum length and deletes sequences
        shorter than the minimum length.
        '''
        df = self.seqs_dot_bracket_labels
        min_to_be_deleted = []

        num_longer_seqs = sum(df['Sequences'].str.len() > self.max_length)
        if num_longer_seqs:
            logger.info(f"Number of sequences to be trimmed: {num_longer_seqs}")

        for idx, seq in enumerate(df['Sequences']):
            if len(seq) > self.max_length:
                # Assign via .loc: chained .iloc assignment would write to a
                # copy and raise a SettingWithCopyWarning.
                df.loc[idx, 'Sequences'] = seq[:self.max_length]
            elif len(seq) < self.min_length:
                # The index is an integer RangeIndex after reset_index, so the
                # integer label itself (not its string form) must be dropped.
                min_to_be_deleted.append(idx)

        if len(min_to_be_deleted):
            df = df.drop(min_to_be_deleted).reset_index(drop=True)
            logger.info(f"Number of shorter sequences to be removed: {len(min_to_be_deleted)}")
        self.seqs_dot_bracket_labels = df
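    # Illustrative effect: with max_length = 5 and min_length = 3,
    # "AACTAGA" is trimmed to "AACTA" and a 2-nt sequence is dropped.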

    def get_secondary_structure(self, sequences):
        secondary = energy.fold_sequences(sequences.tolist())
        return secondary['structure_37'].values

    def chunkstring_overlap(self, string, window):
        return (
            string[0 + i : window + i] for i in range(0, len(string) - window + 1, 1)
        )

    def chunkstring_no_overlap(self, string, window):
        return (string[0 + i : window + i] for i in range(0, len(string), window))
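    # Illustrative behavior of the two chunkers (window = 3):
    #   list(self.chunkstring_overlap("AACTAGA", 3))    -> ['AAC', 'ACT', 'CTA', 'TAG', 'AGA']
    #   list(self.chunkstring_no_overlap("AACTAGA", 3)) -> ['AAC', 'TAG', 'A']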

    def tokenize_samples(self, window: int, sequences_to_be_tokenized: pd.DataFrame, inference: bool = False, tokenizer: str = "overlap") -> tuple:
        """
        Tokenizes the RNA sequences using the given window size, with or
        without overlap according to the current tokenizer option.

        With overlap:
            example: Sequence: AACTAGA, window: 3
            output: AAC,ACT,CTA,TAG,AGA

        Without overlap:
            example: Sequence: AACTAGA, window: 3
            output: AAC,TAG,A
        """
        if "overlap" in tokenizer:
            feature_tokens_gen = list(
                self.chunkstring_overlap(feature, window)
                for feature in sequences_to_be_tokenized
            )
        elif tokenizer == "no_overlap":
            feature_tokens_gen = list(
                self.chunkstring_no_overlap(feature, window) for feature in sequences_to_be_tokenized
            )
        else:
            raise ValueError(f"Unknown tokenizer option: {tokenizer}")

        samples_tokenized = []
        sample_token_ids = []
        if not self.seq_tokens_ids_dict:
            self.seq_tokens_ids_dict = {"pad": 0}

        for gen in feature_tokens_gen:
            sample_token_id = []
            sample_token = list(gen)
            sample_len = len(sample_token)

            # Pad every sample to the fixed token length.
            sample_token.extend(
                ["pad" for _ in range(int(self.tokens_len - sample_len))]
            )

            for token in sample_token:
                if token not in self.seq_tokens_ids_dict:
                    if not inference:
                        # Grow the vocabulary during training.
                        self.seq_tokens_ids_dict[token] = len(self.seq_tokens_ids_dict)
                    else:
                        # The vocabulary is frozen at inference time, so an
                        # unseen token is replaced by a random known token.
                        logger.warning(f"The sequence token: {token} was not seen previously by the model. Token will be replaced by a random token")
                        rand_idx = randint(1, len(self.seq_tokens_ids_dict) - 1)
                        token = list(self.seq_tokens_ids_dict.keys())[rand_idx]

                sample_token_id.append(self.seq_tokens_ids_dict[token])

            sample_token_ids.append(np.array(sample_token_id))
            samples_tokenized.append(np.array(sample_token))

        return (np.array(samples_tokenized), np.array(sample_token_ids))
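    # Shape sketch (illustrative): for two sequences and tokens_len = 5,
    # tokenize_samples returns
    #   samples_tokenized: (2, 5) array of string tokens, padded with 'pad'
    #   sample_token_ids:  (2, 5) array of integer ids, with 0 reserved for 'pad'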

    def tokenize_secondary_structure(self, window, sequences_to_be_tokenized, inference: bool = False, tokenizer: str = "overlap") -> tuple:
        """
        Tokenizes the dot-bracket structures using the given window size, with
        or without overlap according to the current tokenizer option.

        With overlap:
            example: Structure: ...()..., window: 3
            output: ...,..(,.(),().,)..,...

        Without overlap:
            example: Structure: ...()..., window: 3
            output: ...,().,..
        """
        samples_tokenized = []
        sample_token_ids = []
        if not self.second_input_tokens_ids_dict:
            self.second_input_tokens_ids_dict = {"pad": 0}

        if "overlap" in tokenizer:
            feature_tokens_gen = list(
                self.chunkstring_overlap(feature, window)
                for feature in sequences_to_be_tokenized
            )
        elif tokenizer == "no_overlap":
            feature_tokens_gen = list(
                self.chunkstring_no_overlap(feature, window) for feature in sequences_to_be_tokenized
            )
        else:
            raise ValueError(f"Unknown tokenizer option: {tokenizer}")

        for gen in feature_tokens_gen:
            sample_token_id = []
            sample_token = list(gen)

            for token in sample_token:
                if token not in self.second_input_tokens_ids_dict:
                    if not inference:
                        # Grow the vocabulary during training.
                        self.second_input_tokens_ids_dict[token] = len(self.second_input_tokens_ids_dict)
                    else:
                        # The vocabulary is frozen at inference time, so an
                        # unseen token is replaced by a random known token.
                        warnings.warn(f"The secondary structure token: {token} was not seen previously by the model. Token will be replaced by a random token")
                        rand_idx = randint(1, len(self.second_input_tokens_ids_dict) - 1)
                        token = list(self.second_input_tokens_ids_dict.keys())[rand_idx]

                sample_token_id.append(self.second_input_tokens_ids_dict[token])

            sample_token_ids.append(sample_token_id)
            samples_tokenized.append(sample_token)

        # Pad every sample to the fixed token length.
        self.second_input_token_len = self.tokens_len
        for seq_idx, token in enumerate(sample_token_ids):
            sample_len = len(token)
            sample_token_ids[seq_idx].extend(
                [self.second_input_tokens_ids_dict["pad"] for _ in range(int(self.second_input_token_len - sample_len))]
            )
            samples_tokenized[seq_idx].extend(
                ["pad" for _ in range(int(self.second_input_token_len - sample_len))]
            )
            sample_token_ids[seq_idx] = np.array(sample_token_ids[seq_idx])
            samples_tokenized[seq_idx] = np.array(samples_tokenized[seq_idx])

        return (np.array(samples_tokenized), np.array(sample_token_ids))
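    # Vocabulary sketch (illustrative): after tokenizing "...()..." with
    # window = 3 and the overlap option, second_input_tokens_ids_dict could be
    #   {'pad': 0, '...': 1, '..(': 2, '.()': 3, '().': 4, ')..': 5}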

    def set_class_attr(self):
        self.seq = self.seqs_dot_bracket_labels["Sequences"]
        if 'struct' in self.model_input:
            self.struct = self.seqs_dot_bracket_labels["Secondary"]

        self.labels = self.seqs_dot_bracket_labels['Labels']

    def prepare_multi_idx_pd(self, num_coln, pd_name, pd_value):
        iterables = [[pd_name], np.arange(num_coln)]
        index = pd.MultiIndex.from_product(iterables, names=["type of data", "indices"])
        return pd.DataFrame(columns=index, data=pd_value)
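    # Example (illustrative): prepare_multi_idx_pd(2, "tokens_id", ids) yields
    # a frame whose columns are the MultiIndex ("tokens_id", 0), ("tokens_id", 1).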

    def phase_sequence(self, sample_token_ids):
        '''
        Splits the token ids into tokens at even positions (phase0) and tokens
        at odd positions (phase1), zero-padding phase1 when the two phases
        differ in shape.
        '''
        phase0 = sample_token_ids[:, ::2]
        phase1 = sample_token_ids[:, 1::2]

        if phase0.shape != phase1.shape:
            phase1 = np.concatenate([phase1, np.zeros(phase1.shape[0])[..., np.newaxis]], axis=1)
        sample_token_ids = phase0

        return sample_token_ids, phase1
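    # Example (illustrative): [[1, 2, 3, 4, 5]] splits into
    #   phase0 = [[1, 3, 5]] and phase1 = [[2, 4]] -> zero-padded to [[2, 4, 0]]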

    def custom_roll(self, arr, n_shifts_per_row):
        '''
        Rolls each row of a 2D numpy array by its own shift (a per-row
        np.roll): positive shifts roll right, negative shifts roll left.
        '''
        m = np.asarray(n_shifts_per_row)
        # Duplicate the columns so every rotation is a contiguous slice, then
        # expose all n rotations per row as a strided view.
        arr_roll = arr[:, [*range(arr.shape[1]), *range(arr.shape[1] - 1)]].copy()
        strd_0, strd_1 = arr_roll.strides
        n = arr.shape[1]
        result = as_strided(arr_roll, (*arr.shape, n), (strd_0, strd_1, strd_1))

        return result[np.arange(arr.shape[0]), (n - m) % n]
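    # Example (illustrative):
    #   custom_roll(np.array([[1, 2, 3], [4, 5, 6]]), [1, -1]) -> [[3, 1, 2], [5, 6, 4]]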

    def save_token_dicts(self):
        save(data=self.second_input_tokens_ids_dict, path=os.path.join(os.getcwd(), 'second_input_tokens_ids_dict'))
        save(data=self.seq_tokens_ids_dict, path=os.path.join(os.getcwd(), 'seq_tokens_ids_dict'))
    def get_tokenized_data(self, inference: bool = False):
        samples_tokenized, sample_token_ids = self.tokenize_samples(self.window, self.seq, inference)

        logger.info(f'Vocab size for primary sequences: {len(self.seq_tokens_ids_dict.keys())}')
        logger.info(f'Vocab size for secondary structure: {len(self.second_input_tokens_ids_dict.keys())}')
        logger.info(f'Number of sequences used for tokenization: {samples_tokenized.shape[0]}')

        if "comp" in self.model_input:
            # The second input is the complement of each sequence
            # (A<->T, C<->G), built via temporary placeholders so the swaps
            # do not clobber each other.
            self.seq_comp = []
            for feature in self.seq:
                feature = feature.replace('A', '%temp%').replace('T', 'A')\
                    .replace('C', '%temp2%').replace('G', 'C')\
                    .replace('%temp%', 'T').replace('%temp2%', 'G')
                self.seq_comp.append(feature)
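            # Example (illustrative): "AACTG" -> "TTGAC"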

            # Tokenize the complements with a fresh vocabulary, then restore
            # the primary-sequence vocabulary.
            self.seq_tokens_ids_dict_temp = self.seq_tokens_ids_dict
            self.seq_tokens_ids_dict = None

            _, seq_comp_str_token_ids = self.tokenize_samples(self.window, self.seq_comp, inference)
            sec_input_value = seq_comp_str_token_ids

            self.second_input_tokens_ids_dict = self.seq_tokens_ids_dict
            self.seq_tokens_ids_dict = self.seq_tokens_ids_dict_temp

        if "struct" in self.model_input:
            # The second input is the tokenized secondary structure.
            _, sec_str_token_ids = self.tokenize_secondary_structure(self.window, self.struct, inference)
            sec_input_value = sec_str_token_ids

        if "seq-seq" in self.model_input:
            # The second input is the odd-phase half of the token ids.
            sample_token_ids, sec_input_value = self.phase_sequence(sample_token_ids)
            self.second_input_tokens_ids_dict = self.seq_tokens_ids_dict

        if "seq-rev" in self.model_input or "baseline" in self.model_input or self.model_input == 'seq':
            # The second input is the reversed sequence. Reversal moves the
            # padding zeros to the front, so each row is rolled left by its
            # number of zeros to push the padding back to the end.
            sample_token_ids_rev = sample_token_ids[:, ::-1]
            n_zeros = np.count_nonzero(sample_token_ids_rev == 0, axis=1)
            sec_input_value = self.custom_roll(sample_token_ids_rev, -n_zeros)
            self.second_input_tokens_ids_dict = self.seq_tokens_ids_dict

        # Number of non-pad tokens per sample.
        seqs_length = list(sum(sample_token_ids.T != 0))

        labels_df = self.prepare_multi_idx_pd(1, "Labels", self.labels.values)
        tokens_id_df = self.prepare_multi_idx_pd(sample_token_ids.shape[1], "tokens_id", sample_token_ids)
        tokens_df = self.prepare_multi_idx_pd(samples_tokenized.shape[1], "tokens", samples_tokenized)
        sec_input_df = self.prepare_multi_idx_pd(sec_input_value.shape[1], 'second_input', sec_input_value)
        seqs_length_df = self.prepare_multi_idx_pd(1, "seqs_length", seqs_length)

        all_df = labels_df.join(tokens_df).join(tokens_id_df).join(sec_input_df).join(seqs_length_df)

        self.save_token_dicts()
        return all_df
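
# Minimal usage sketch (illustrative; the field names mirror the config
# accesses above, but the concrete values are assumptions, not project
# defaults):
#
#   from omegaconf import OmegaConf
#   config = OmegaConf.create({
#       "inference": False,
#       "model_config": {"model_input": "seq", "window": 3, "tokenizer": "overlap"},
#       "train_config": {"filter_seq_length": False},
#   })
#   df = pd.DataFrame({"Sequences": ["AACTAGA", "GGCATTAC"], "Labels": ["a", "b"]})
#   all_df = SeqTokenizer(df, config).get_tokenized_data()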