import logging
import random
from contextlib import redirect_stdout
from pathlib import Path
from random import randint
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd

from ..novelty_prediction.id_vs_ood_nld_clf import get_closest_ngbr_per_split
from ..utils.energy import fold_sequences
from ..utils.file import load
from ..utils.tcga_post_analysis_utils import Results_Handler
from ..utils.utils import (get_model, infer_from_pd,
                           prepare_inference_results_tcga,
                           update_config_with_inference_params)
from .seq_tokenizer import SeqTokenizer
|
|
# Module-level logger used by every augmenter class in this file.
logger = logging.getLogger(__name__)
|
|
class IDModelAugmenter:
    '''
    Augments the dataset with the predictions of the ID models.

    The ID models first predict a label for every sequence of the
    `no_annotation` split. The levenshtein distance returned by
    `get_closest_ngbr_per_split` for each of those sequences is then compared
    to the stored novelty threshold: a sequence whose distance is below the
    threshold is considered familiar, and familiar sequences that are present
    in the dataset receive the predicted label.
    '''
    def __init__(self, df: pd.DataFrame, config: Dict):
        '''
        Args:
            df: dataset indexed by sequence. `Labels` must be a categorical
                column (its `.cat` accessor is used when predictions are merged).
            config: configuration providing `train_config.mapping_dict_path`,
                `model_config.clf_target` and `path_to_models`.
        '''
        self.df = df
        self.config = config
        self.mapping_dict = load(config['train_config']['mapping_dict_path'])

    def predict_transforna_na(self) -> pd.DataFrame:
        '''
        Runs the ID model on the `no_annotation` split and flags familiar sequences.

        Returns:
            The inference frame stored under `all_data["infere_rna_seq"]`, with
            `Is Familiar?` overwritten by the levenshtein based decision, a
            constant `Threshold` column, and its index exposed as a `Sequence`
            column.
        '''
        clf_target = self.config['model_config']['clf_target']
        mc_or_sc = 'major_class' if 'major_class' in clf_target else 'sub_class'
        inference_config = update_config_with_inference_params(
            self.config, mc_or_sc=mc_or_sc, path_to_models=self.config['path_to_models'])
        model_path = inference_config['inference_settings']["model_path"]
        logger.info(f"Augmenting hico sequences based on predictions from model at: {model_path}")

        # drop the last two components of the model path and append `embedds`
        embedds_path = '/'.join(model_path.split('/')[:-2]) + '/embedds'

        results: Results_Handler = Results_Handler(embedds_path=embedds_path, splits=['train', 'no_annotation'])
        results.get_knn_model()
        threshold = load(results.analysis_path + "/novelty_model_coef")["Threshold"]
        sequences = results.splits_df_dict['no_annotation_df'][results.seq_col].values[:, 0]

        # model loading and inference run with stdout suppressed
        # NOTE(review): the extent of this `with` block was reconstructed from a
        # de-indented source - confirm that inference belongs inside it.
        with redirect_stdout(None):
            root_dir = Path(__file__).parents[3].absolute()
            inference_config, net = get_model(inference_config, root_dir)

            original_infer_pd = pd.Series(sequences, name="Sequences").to_frame()
            logger.info('predicting sub classes for the NA set by the ID models')
            predicted_labels, logits, _, _, all_data, max_len, net, infer_pd = infer_from_pd(
                inference_config, net, original_infer_pd, SeqTokenizer)

        prepare_inference_results_tcga(inference_config, predicted_labels, logits, all_data, max_len)
        infer_pd = all_data["infere_rna_seq"]

        logger.info('computing levenstein distance for the NA set by the ID models')
        # assumes lev_dist is ordered like the rows of infer_pd - TODO confirm
        _, _, _, _, _, lev_dist = get_closest_ngbr_per_split(results, 'no_annotation')

        logger.info(f'num of hico based on entropy novelty prediction is {sum(infer_pd["Is Familiar?"])}')
        infer_pd['Is Familiar?'] = [bool(lv < threshold) for lv in lev_dist]
        infer_pd['Threshold'] = threshold
        logger.info(f'num of new hico based on levenstein distance is {np.sum(infer_pd["Is Familiar?"])}')
        return infer_pd.rename_axis("Sequence").reset_index()

    def include_id_model_predictions(self):
        '''
        Writes the predicted label of every familiar sequence into `self.df['Labels']`.

        Only familiar sequences that are part of `self.df.index` are relabelled.
        Labels are assigned row by row from the prediction frame, so each
        sequence receives its own prediction regardless of ordering.
        '''
        pred_df = self.predict_transforna_na()

        # register predicted labels that are not categories yet
        known_labels = set(self.df.Labels.cat.categories)
        new_labels = [label for label in pred_df['Net-Label'].dropna().unique()
                      if label not in known_labels]
        self.df['Labels'] = self.df['Labels'].cat.add_categories(new_labels)

        familiar = pred_df[pred_df['Is Familiar?'].eq(True)]
        familiar = familiar[familiar['Sequence'].isin(self.df.index)]
        # one prediction per sequence, so indexer and values have equal length
        familiar = familiar.drop_duplicates(subset='Sequence')
        if familiar.empty:
            return

        # a list indexer (not a set) keeps sequences and labels aligned and is
        # accepted by every pandas version
        self.df.loc[familiar['Sequence'].tolist(), 'Labels'] = familiar['Net-Label'].values

    def get_augmented_df(self):
        '''Applies the ID model predictions and returns the (in place) updated dataset.'''
        self.include_id_model_predictions()
        return self.df
| |
class RecombinedSeqAugmenter:
    '''
    Augments the dataset with recombined sequences.

    One sequence is sampled per label, the sampled sequences are shuffled and
    paired, and each pair is fused. A window around the fusion point is cut
    out of every fused sequence and labelled with a single artificial class.
    '''
    def __init__(self, df: pd.DataFrame, config: Dict):
        '''
        Args:
            df: dataset indexed by sequence with a `Labels` column.
            config: stored on the instance; not read by this class.
        '''
        self.df = df
        self.config = config

    def create_recombined_seqs(self, class_label: str = 'recombined'):
        '''
        Builds the recombined sequences.

        Args:
            class_label: label given to every recombined sequence.

        Returns:
            DataFrame indexed by the recombined sequences with a single
            `Labels` column; one row per pair, i.e. number of labels // 2 rows.
        '''
        # labels with at least one sample (a categorical column also reports
        # unused categories with a count of 0)
        label_counts = self.df.Labels.value_counts()
        unique_labels = label_counts[label_counts >= 1].index.tolist()

        # one randomly chosen sequence per label
        samples = [self.df[self.df['Labels'] == label].sample(1).index[0] for label in unique_labels]

        # an even number of sequences is needed to build pairs
        if len(samples) % 2 != 0:
            samples = samples[:-1]
        np.random.shuffle(samples)

        half = len(samples) // 2
        recombined_set = []
        for left_seq, right_seq in zip(samples[:half], samples[half:]):
            fused_seq = left_seq + right_seq

            # fusion point, jittered by up to 5 positions in either direction
            junction = len(left_seq) + randint(-5, 5)

            # half of a window length drawn from [18, 30]
            half_window = randint(18, 30) // 2

            window_start = max(0, junction - half_window)
            recombined_set.append(fused_seq[window_start:junction + half_window])

        return pd.DataFrame(index=recombined_set,
                            data=[f'{class_label}'] * len(recombined_set),
                            columns=['Labels'])

    def get_augmented_df(self):
        '''Returns the recombined sequences labelled with the default class label.'''
        return self.create_recombined_seqs()
| |
class RandomSeqAugmenter:
    '''
    Augments the dataset with random sequences over the alphabet A, C, G, T.

    Generated sequences are unique and never collide with a sequence that is
    already part of the dataset index.
    '''
    def __init__(self, df: pd.DataFrame, config: Dict,
                 num_seqs: int = 500, min_len: int = 18, max_len: int = 30):
        '''
        Args:
            df: dataset indexed by sequence.
            config: stored on the instance; not read by this class.
            num_seqs: number of random sequences to generate.
            min_len: minimum length of a random sequence (inclusive).
            max_len: maximum length of a random sequence (inclusive).
        '''
        self.df = df
        self.config = config
        self.num_seqs = num_seqs
        self.min_len = min_len
        self.max_len = max_len

    def get_random_seq(self):
        '''
        Generates `self.num_seqs` unique random sequences.

        Returns:
            DataFrame indexed by the random sequences with a single `Labels`
            column set to 'random'.

        Raises:
            ValueError: if the length range is invalid or more sequences are
                requested than exist in that range (the loop could not end).
        '''
        if self.min_len < 1 or self.max_len < self.min_len:
            raise ValueError(f'invalid length range [{self.min_len}, {self.max_len}]')
        capacity = sum(4 ** length for length in range(self.min_len, self.max_len + 1))
        if self.num_seqs > capacity:
            raise ValueError(f'cannot generate {self.num_seqs} unique sequences, only {capacity} exist')

        random_seqs = []
        seen = set()  # O(1) duplicate check; the list keeps generation order
        while len(random_seqs) < self.num_seqs:
            candidate = ''.join(random.choices(['A', 'C', 'G', 'T'],
                                               k=randint(self.min_len, self.max_len)))
            if candidate not in seen and candidate not in self.df.index:
                seen.add(candidate)
                random_seqs.append(candidate)

        return pd.DataFrame(index=random_seqs,
                            data=['random'] * len(random_seqs),
                            columns=['Labels'])

    def get_augmented_df(self):
        '''Returns the generated random sequences.'''
        return self.get_random_seq()
|
|
class PrecursorAugmenter:
    '''
    Augments binned sub classes (labels containing `_bin-`) with fragments
    sampled from the precursor sequence the sub class belongs to.

    A precursor is split into bins by `compute_dynamic_bin_size`. For a sub
    class `<precursor>_bin-<n>` new fragments are cut from the precursor such
    that they overlap bin `n` more than any neighbouring bin.
    '''
    def __init__(self, df: pd.DataFrame, config: Dict):
        '''
        Args:
            df: dataset indexed by sequence with a `Labels` column.
            config: configuration providing `train_config.mapping_dict_path`,
                `train_config.precursor_file_path` and `trained_on`.
        '''
        self.df = df
        self.config = config
        self.mapping_dict = load(config['train_config'].mapping_dict_path)
        # None when the precursor file could not be read
        self.precursor_df = self.load_precursor_file()
        self.trained_on = config.trained_on

        # a sub class is only augmented if it has at least this many samples
        self.min_num_samples_per_sc: int = 1
        if self.trained_on == 'id':
            self.min_num_samples_per_sc = 8

        self.min_bin_size = 20
        self.max_bin_size = 30
        self.min_seq_len = 18
        self.max_seq_len = 30

    def load_precursor_file(self):
        '''
        Reads the precursor file referenced by the config.

        Returns:
            The precursor DataFrame, or None if it could not be loaded
            (best effort: the failure is logged, not raised).
        '''
        try:
            return pd.read_csv(self.config['train_config'].precursor_file_path, index_col=0)
        except Exception:
            logger.info('Could not load precursor file')
            return None

    def compute_dynamic_bin_size(self, precursor_len: int, name: str = None) -> List[int]:
        '''
        Splits a precursor of length `precursor_len` into bin lengths.

        The precursor is first split into bins of size max_bin_size. While the
        last bin is smaller than min_bin_size the split is repeated with a bin
        size reduced by one. Once the bin size would drop below min_bin_size,
        the last two bins are merged and the search stops.

        Args:
            precursor_len: length of the precursor sequence.
            name: unused; kept for interface compatibility.

        Returns:
            List of bin lengths that sum up to `precursor_len`. A precursor
            shorter than min_bin_size yields a single bin.
        '''
        def split_precursor_to_bins(total_len, bin_size):
            '''Splits total_len into bins of bin_size; the last bin takes the remainder.'''
            bin_lens = []
            for bin_start in range(0, total_len, bin_size):
                if bin_start + bin_size < total_len:
                    bin_lens.append(bin_size)
                else:
                    bin_lens.append(total_len - bin_start)
            return bin_lens

        if precursor_len < self.min_bin_size:
            return [precursor_len]

        precursor_bin_lens = split_precursor_to_bins(precursor_len, self.max_bin_size)
        reduced_len = self.max_bin_size - 1
        while precursor_bin_lens[-1] < self.min_bin_size:
            precursor_bin_lens = split_precursor_to_bins(precursor_len, reduced_len)
            reduced_len -= 1
            # NOTE(review): the merge is triggered as soon as the next candidate
            # size is below min_bin_size, without re-checking the last bin of the
            # split just computed (e.g. precursor_len=40 yields [40], not [20, 20]).
            # Left unchanged because existing `_bin-<n>` labels presumably depend
            # on exactly this binning - confirm before altering.
            if reduced_len < self.min_bin_size:
                # add last bin to the previous bin
                precursor_bin_lens[-2] += precursor_bin_lens[-1]
                precursor_bin_lens = precursor_bin_lens[:-1]
                break

        return precursor_bin_lens

    def get_bin_with_max_overlap(self, precursor_len: int, start_frag_pos: int, frag_len: int, name) -> int:
        '''
        Returns the 1-based number of the bin a fragment overlaps the most.

        Args:
            precursor_len: length of the precursor sequence.
            start_frag_pos: 0-based start position of the fragment in the precursor.
            frag_len: length of the fragment.
            name: forwarded to `compute_dynamic_bin_size` (unused there).

        Returns:
            Number of the bin containing the fragment start if more than half
            of the fragment lies in it, otherwise the number of the next bin.
            If the start lies beyond all bins, 1 is returned.
        '''
        precursor_bin_lens = self.compute_dynamic_bin_size(precursor_len=precursor_len, name=name)
        bin_no = 0
        for i, bin_len in enumerate(precursor_bin_lens):
            if start_frag_pos < bin_len:
                # part of the fragment that lies in the bin holding its start
                overlap = min(bin_len - start_frag_pos, frag_len)
                bin_no = i if overlap > frag_len / 2 else i + 1
                break
            # make the start position relative to the next bin
            start_frag_pos -= bin_len
        return bin_no + 1

    def _rows_of_class(self, annotation: str) -> pd.DataFrame:
        '''Precursor rows with the given class annotation; `|` in the index replaced by `-`.'''
        rows = self.precursor_df.loc[self.precursor_df.small_RNA_class_annotation == annotation].copy()
        rows.index = rows.index.str.replace('|', '-', regex=False)
        return rows

    def get_precursor_info(self, mc: str, sc: str):
        '''
        Looks up the precursor sequence of a binned sub class.

        Args:
            mc: major class of the sub class.
            sc: sub class label of the form `<precursor>_bin-<n>`.

        Returns:
            Tuple `(precursor_sequence, precursor_name)`. For snoRNA, lncRNA,
            protein_coding and miscRNA the name is prefixed with the major
            class and matched by containment, also among the `pseudo_<mc>`
            rows; if nothing matches an empty DataFrame is returned instead
            of a tuple.
        '''
        xRNA_df = self._rows_of_class(mc)
        prec_name = sc.split('_bin-')[0]

        if mc in ['snoRNA', 'lncRNA', 'protein_coding', 'miscRNA']:
            prec_name = mc + '-' + prec_name
            prec_row_df = xRNA_df.iloc[xRNA_df.index.str.contains(prec_name)]
            if prec_row_df.empty:
                # fall back to the pseudo variant of the major class
                xRNA_df = self._rows_of_class('pseudo_' + mc)
                prec_row_df = xRNA_df.iloc[xRNA_df.index.str.contains(prec_name)]
                if prec_row_df.empty:
                    logger.info(f'precursor {prec_name} not found in HBDxBase')
                    return pd.DataFrame()
            # take the first matching precursor
            prec_row_df = prec_row_df.iloc[0]
        else:
            prec_row_df = xRNA_df.loc[f'{mc}-{prec_name}']

        precursor = prec_row_df.sequence
        return precursor, prec_name

    def populate_from_bin(self, sc: str, precursor: str, prec_name: str, existing_seqs: List[str]):
        '''
        Samples new fragments for the bin encoded in `sc`.

        The start index begins just after the middle of the previous bin and
        is incremented up to the end of the current bin. Each fragment length
        is drawn so that the current bin holds the largest part of it:
        1. fragments starting in the previous bin are long enough that less
           than half of them lies in the previous bin;
        2. fragments starting in the current bin are short enough that more
           than half of them lies in the current bin.
        For the first bin there is no previous bin, so sampling starts at
        position 1.

        Args:
            sc: sub class label of the form `<precursor>_bin-<n>` (n is 1-based).
            precursor: precursor sequence.
            prec_name: precursor name, forwarded to `compute_dynamic_bin_size`.
            existing_seqs: sequences already present for this sub class;
                fragments equal to one of them are skipped.

        Returns:
            DataFrame indexed by the sampled fragments with `Labels` set to
            `sc`; an empty DataFrame if the precursor has a single bin.
        '''
        bin_no = int(sc.split('_bin-')[1])
        bins = self.compute_dynamic_bin_size(len(precursor), prec_name)
        if len(bins) == 1:
            return pd.DataFrame()

        # bin numbers in labels are 1-based
        bin_no -= 1

        # The first bin has no previous bin. This needs an explicit check:
        # negative indices/slices wrap around instead of raising.
        if bin_no > 0:
            previous_bin_start = sum(bins[:bin_no - 1])
            previous_bin_size = bins[bin_no - 1]
        else:
            previous_bin_start = 0
            previous_bin_size = 0
        middle_bin_start = sum(bins[:bin_no])
        middle_bin_size = bins[bin_no]
        next_bin_start = sum(bins[:bin_no + 1])
        # the last bin has no next bin
        next_bin_size = bins[bin_no + 1] if bin_no + 1 < len(bins) else 0

        start_idx = previous_bin_start + previous_bin_size // 2 + 1
        sampled_seqs = []

        while start_idx < middle_bin_start + middle_bin_size:
            if start_idx < middle_bin_start:
                # fragment starts in the previous bin: it must be longer than
                # twice its overlap with that bin
                max_overlap_prev = middle_bin_start - start_idx
                end_idx = start_idx + randint(max(self.min_seq_len, max_overlap_prev * 2 + 1), self.max_seq_len)
            else:
                max_overlap_curr = next_bin_start - start_idx
                max_overlap_next = (start_idx + self.max_seq_len) - next_bin_start
                max_overlap_next = min(max_overlap_next, next_bin_size)
                # stop when at most 9 positions of the current bin remain
                # (presumably half of min_seq_len - hard-coded), or when there
                # is no next bin and the remainder is shorter than min_seq_len
                if max_overlap_curr <= 9 or (max_overlap_next == 0 and max_overlap_curr < self.min_seq_len):
                    break
                end_idx = start_idx + randint(
                    self.min_seq_len,
                    min(self.max_seq_len, self.max_seq_len - max_overlap_next + max_overlap_curr - 1))

            tmp_seq = precursor[start_idx:end_idx]
            assert len(tmp_seq) >= self.min_seq_len and len(tmp_seq) <= self.max_seq_len, f'length of tmp_seq is {len(tmp_seq)}'
            if tmp_seq not in existing_seqs:
                sampled_seqs.append(tmp_seq)
            start_idx += 1

        # sanity check: every sampled fragment maps back to the requested bin
        for frag in sampled_seqs:
            # first occurrence only; for repeated fragments it may lie in a far
            # away bin, in which case the check is skipped
            occ = precursor.find(frag)
            curr_bin_no = self.get_bin_with_max_overlap(len(precursor), occ, len(frag), ' ')
            if abs(curr_bin_no - (bin_no + 1)) > 1:
                continue
            assert curr_bin_no == bin_no + 1, f'curr_bin_no is {curr_bin_no} and bin_no is {bin_no+1}'

        return pd.DataFrame(index=sampled_seqs,
                            data=[sc] * len(sampled_seqs),
                            columns=['Labels'])

    @staticmethod
    def _infer_major_class(sc: str):
        '''Guesses the major class from substrings of the sub class name; None if nothing matches.'''
        # order matters: the first matching substring wins
        substring_to_mc = (
            ('miR', 'miRNA'),
            ('tRNA', 'tRNA'),
            ('rRNA', 'rRNA'),
            ('snRNA', 'snRNA'),
            ('snoRNA', 'snoRNA'),
            ('SNO', 'snoRNA'),
            ('RPL37A', 'protein_coding'),
            ('SNHG1', 'lncRNA'),
        )
        for substring, mc in substring_to_mc:
            if substring in sc:
                return mc
        return None

    def populate_scs_with_bins(self):
        '''
        Augments every binned sub class that has enough samples.

        Side effect: writes `frequency_per_sub_class_df.csv` (number of samples
        before and number of sampled fragments per sub class) to the current
        working directory.

        Returns:
            DataFrame indexed by the sampled fragments with a `Labels` column;
            empty if no precursor file is loaded or nothing was sampled.
        '''
        if self.precursor_df is None:
            logger.warning('No precursor file loaded, skipping precursor based augmentation')
            return pd.DataFrame(columns=['Labels'])

        label_counts = self.df.Labels.value_counts()
        unique_labels = label_counts[label_counts >= self.min_num_samples_per_sc].index.tolist()

        scs_list = []
        scs_before = []
        sc_after = []
        sampled_frames = []
        for sc in unique_labels:
            if not (isinstance(sc, str) and '_bin-' in sc):
                continue

            try:
                mc = self.mapping_dict[sc]
            except KeyError:
                mc = self._infer_major_class(sc)
                if mc is None:
                    logger.info(f'No mapping for {sc}')
                    continue

            existing_seqs = self.df[self.df['Labels'] == sc].index

            precursor_info = self.get_precursor_info(mc, sc)
            if isinstance(precursor_info, pd.DataFrame):
                # precursor not found (already logged): nothing to sample
                sc2_df = pd.DataFrame(columns=['Labels'])
            else:
                precursor, prec_name = precursor_info
                sc2_df = self.populate_from_bin(sc, precursor, prec_name, existing_seqs)

            # the three lists are filled together so they always stay aligned
            scs_list.append(sc)
            scs_before.append(len(existing_seqs))
            sc_after.append(len(sc2_df))
            if not sc2_df.empty:
                sampled_frames.append(sc2_df)

        scs_dict = {'sub_class': scs_list, 'Number of samples before': scs_before, 'Number of samples afrer': sc_after}
        scs_df = pd.DataFrame(scs_dict)
        scs_df.to_csv('frequency_per_sub_class_df.csv')

        if not sampled_frames:
            return pd.DataFrame(columns=['Labels'])
        # single concat instead of repeated DataFrame.append (removed in pandas 2.0)
        return pd.concat(sampled_frames)

    def get_augmented_df(self):
        '''Returns the fragments sampled for all eligible binned sub classes.'''
        return self.populate_scs_with_bins()
| |
class DataAugmenter:
    '''
    Sets the labels of the dataset according to clf_target and augments it.

    Labels are sub class labels at first and are converted to major class
    labels after augmentation if clf_target contains 'major_class':
        major class: miRNA, tRNA ...
        sub class: mir-192-3p, rRNA-bin-30 ...
    Every run adds:
        1. Random sequences
        2. Recombined sequences
        3. Sequences sampled from the precursor file (sub class targets only)
    If the models are trained on 'full', the sequences that were predicted
    to be familiar by the ID models additionally receive the predicted label
    before the steps above.
    '''
    def __init__(self, df: pd.DataFrame, config: Dict):
        '''
        Args:
            df: dataset indexed by sequence; the columns `subclass_name`,
                `hico` and `five_prime_adapter_filter` are read.
            config: configuration providing `train_config.mapping_dict_path`,
                `trained_on`, `model_config.clf_target` and
                `model_config.model_input`.
        '''
        self.df = df
        self.config = config
        self.mapping_dict = load(config['train_config'].mapping_dict_path)
        self.trained_on = config.trained_on
        self.clf_target = config['model_config'].clf_target
        logger.info(f'Augmenting the dataset for {self.clf_target}')
        self.set_labels()

        # all augmenters receive the same DataFrame object
        self.precursor_augmenter = PrecursorAugmenter(self.df, self.config)
        self.random_augmenter = RandomSeqAugmenter(self.df, self.config)
        self.recombined_augmenter = RecombinedSeqAugmenter(self.df, self.config)
        self.id_model_augmenter = IDModelAugmenter(self.df, self.config)

    def set_labels(self):
        '''
        Creates the categorical `Labels` column.

        Without 'hico' in clf_target the label is the first `;` separated entry
        of `subclass_name`; otherwise only rows with `hico == True` get their
        `subclass_name` as label and all other rows stay unlabelled.
        '''
        if 'hico' not in self.clf_target:
            self.df['Labels'] = self.df['subclass_name'].str.split(';', expand=True)[0]
        else:
            self.df['Labels'] = self.df['subclass_name'][self.df['hico'] == True]

        self.df['Labels'] = self.df['Labels'].astype('category')

    def convert_to_major_class_labels(self):
        '''
        Maps sub class labels to major class labels if clf_target asks for it
        and removes rows whose major class label contains `;`.
        '''
        # NOTE(review): the `;` filter is placed inside the major_class branch;
        # the nesting was reconstructed from a de-indented source - confirm.
        if 'major_class' in self.clf_target:
            self.df['Labels'] = self.df['Labels'].map(self.mapping_dict).astype('category')
            # na=False gives a boolean mask directly, so `~` is a logical
            # negation; unlabelled rows are kept
            is_multi_target = self.df['Labels'].str.contains(';', na=False, regex=False)
            self.df = self.df[~is_multi_target]

    def combine_df(self, new_var_df: pd.DataFrame):
        '''
        Prepends the augmented sequences to the dataset.

        Sequences that already exist in the dataset are dropped from
        `new_var_df`, the remaining rows are shuffled, dataset columns missing
        in `new_var_df` are filled with NaN, the index is upper-cased and
        `Labels` is converted to a categorical column.

        Args:
            new_var_df: augmented sequences (index) with a `Labels` column.

        Returns:
            The combined dataset, which is also stored in `self.df`.
        '''
        is_duplicate = new_var_df.index.isin(self.df.index)
        num_duplicates = int(is_duplicate.sum())
        if num_duplicates:
            logger.info(f'Number of duplicated sequences to be removed augmented data: {num_duplicates}')

        new_var_df = new_var_df[~is_duplicate].sample(frac=1)

        for col in self.df.columns:
            if col not in new_var_df.columns:
                new_var_df[col] = np.nan

        # pd.concat replaces DataFrame.append, which was removed in pandas 2.0
        self.df = pd.concat([new_var_df, self.df])
        self.df.index = self.df.index.str.upper()
        self.df.Labels = self.df.Labels.astype('category')
        return self.df

    def annotate_artificial_affix_seqs(self):
        '''Labels every row with `five_prime_adapter_filter == 0` as 'artificial_affix'.'''
        aa_seqs = self.df[self.df['five_prime_adapter_filter'] == 0].index.tolist()
        # add_categories raises if the category already exists
        if 'artificial_affix' not in self.df['Labels'].cat.categories:
            self.df['Labels'] = self.df['Labels'].cat.add_categories('artificial_affix')
        self.df.loc[aa_seqs, 'Labels'] = 'artificial_affix'

    def full_pipeline(self):
        '''Relabels sequences predicted to be familiar by the ID models.'''
        self.df = self.id_model_augmenter.get_augmented_df()

    def post_augmentation(self):
        '''
        Adds random, recombined and (for sub class targets) precursor based
        sequences, finalises the labels and adds the model input columns.

        Returns:
            The augmented dataset with a `Sequences` column and, if
            'struct' is part of the model input, a `Secondary` column.
        '''
        random_df = self.random_augmenter.get_augmented_df()
        if 'sub_class' in self.clf_target:
            precursor_df = self.precursor_augmenter.get_augmented_df()
        else:
            precursor_df = pd.DataFrame()
        recombined_df = self.recombined_augmenter.get_augmented_df()

        # same row order as the former chained DataFrame.append calls
        # (removed in pandas 2.0); empty frames are left out
        frames = [frame for frame in (precursor_df, recombined_df, random_df) if not frame.empty]
        augmented_df = pd.concat(frames) if frames else pd.DataFrame(columns=['Labels'])

        missing_categories = [label for label in ('random', 'recombined')
                              if label not in self.df['Labels'].cat.categories]
        if missing_categories:
            self.df['Labels'] = self.df['Labels'].cat.add_categories(missing_categories)
        self.combine_df(augmented_df)

        self.convert_to_major_class_labels()
        self.annotate_artificial_affix_seqs()
        self.df['Labels'] = self.df['Labels'].cat.remove_unused_categories()
        self.df['Sequences'] = self.df.index.tolist()

        if 'struct' in self.config['model_config'].model_input:
            self.df['Secondary'] = fold_sequences(self.df.index.tolist(), temperature=37)['structure_37'].values

        return self.df

    def get_augmented_df(self):
        '''Runs the ID model relabelling for 'full' training, then the common augmentation.'''
        if self.trained_on == 'full':
            self.full_pipeline()
        return self.post_augmentation()