|
|
| import logging |
| import warnings |
| from argparse import ArgumentParser |
| from contextlib import redirect_stdout |
| from datetime import datetime |
| from pathlib import Path |
| from typing import List |
|
|
| import numpy as np |
| import pandas as pd |
| from hydra.utils import instantiate |
| from omegaconf import OmegaConf |
| from sklearn.preprocessing import StandardScaler |
| from umap import UMAP |
|
|
| from ..novelty_prediction.id_vs_ood_nld_clf import get_closest_ngbr_per_split |
| from ..processing.seq_tokenizer import SeqTokenizer |
| from ..utils.file import load |
| from ..utils.tcga_post_analysis_utils import Results_Handler |
| from ..utils.utils import (get_model, infer_from_pd, |
| prepare_inference_results_tcga, |
| update_config_with_inference_params) |
|
|
| logger = logging.getLogger(__name__) |
|
|
| warnings.filterwarnings("ignore") |
|
|
def aggregate_ensemble_model(lev_dist_df:pd.DataFrame):
    '''
    Build the "Ensemble" prediction by picking, per query sequence, either the
    model with the lowest NLD (when that lowest NLD is below the Novelty
    Threshold, i.e. the sequence is familiar) or the model with the highest NLD
    (when even the closest match is above the threshold, i.e. the sequence is
    novel). Baseline rows are excluded from the vote and re-appended at the end.
    '''
    # Separate the Baseline rows from the votable model rows.
    is_baseline = lev_dist_df['Model'] == 'Baseline'
    baseline_df = lev_dist_df[is_baseline].reset_index(drop=True)
    models_df = lev_dist_df[~is_baseline].reset_index(drop=True)

    # Per-sequence row positions of the smallest / largest NLD.
    # (reset_index above makes the labels positional, so iloc is safe here.)
    nld_by_seq = models_df.groupby('Sequence')['NLD']
    closest_df = models_df.iloc[nld_by_seq.idxmin().values]
    farthest_df = models_df.iloc[nld_by_seq.idxmax().values]

    # A sequence is novel when even its best (lowest) NLD exceeds the threshold.
    is_novel = (closest_df['NLD'] > closest_df['Novelty Threshold']).values

    # Familiar sequences take the closest model's row, novel ones the farthest.
    ensemble_df = pd.concat([closest_df[~is_novel], farthest_df[is_novel]])
    ensemble_df['Model'] = 'Ensemble'

    combined_df = pd.concat([models_df, ensemble_df, baseline_df])
    return combined_df.reset_index(drop=True)
|
|
|
|
def read_inference_model_config(model:str,mc_or_sc,trained_on:str,path_to_models:str):
    '''
    Load the hyperparameter settings yaml (hp_settings.yaml) of the requested
    inference model and return it as an OmegaConf config object.
    '''
    # Models trained on the full dataset live under TransfoRNA_FULL,
    # everything else under TransfoRNA_ID.
    transforna_folder = "TransfoRNA_FULL" if trained_on == "full" else "TransfoRNA_ID"
    model_path = f"{path_to_models}/{transforna_folder}/{mc_or_sc}/{model}/meta/hp_settings.yaml"
    return OmegaConf.load(model_path)
|
|
def predict_transforna(sequences: List[str], model: str = "Seq-Rev", mc_or_sc:str='sub_class',\
    logits_flag:bool = False,attention_flag:bool = False,\
    similarity_flag:bool=False,n_sim:int=3,embedds_flag:bool = False, \
    umap_flag:bool = False,trained_on:str='full',path_to_models:str='') -> pd.DataFrame:
    '''
    This function predicts the major class or sub class of a list of sequences using the TransfoRNA model.
    Additionaly, it can return logits, attention scores, similarity scores, gene embeddings or umap embeddings.
    At most one of the output flags may be True; when none is set, a prediction
    table with novelty information (NLD vs. Novelty Threshold) is returned.

    Input:
        sequences: list of sequences to predict
        model: model to use for prediction
        mc_or_sc: models trained on major class or sub class
        logits_flag: whether to return logits
        attention_flag: whether to return attention scores (obtained from the self-attention layer)
        similarity_flag: whether to return explanatory/similar sequences in the training set
        n_sim: number of similar sequences to return
        embedds_flag: whether to return embeddings of the sequences
        umap_flag: whether to return umap embeddings
        trained_on: whether to use the model trained on the full dataset or the ID dataset
        path_to_models: root folder containing the TransfoRNA model directories
    Output:
        pd.DataFrame with the predictions
    '''
    # Only one output mode can be requested per call.
    assert sum([logits_flag,attention_flag,similarity_flag,embedds_flag,umap_flag]) <= 1, 'One option at most can be True'
    # Normalize the model name capitalization, e.g. "seq-rev" -> "Seq-Rev".
    model = "-".join([word.capitalize() for word in model.split("-")])
    cfg = read_inference_model_config(model,mc_or_sc,trained_on,path_to_models)
    cfg = update_config_with_inference_params(cfg,mc_or_sc,trained_on,path_to_models)
    root_dir = Path(__file__).parents[1].absolute()

    # Suppress stdout chatter produced during model construction and inference.
    with redirect_stdout(None):
        cfg, net = get_model(cfg, root_dir)
        infer_pd = pd.Series(sequences, name="Sequences").to_frame()
        predicted_labels, logits, gene_embedds_df,attn_scores_pd,all_data, max_len, net,_ = infer_from_pd(cfg, net, infer_pd, SeqTokenizer,attention_flag)

    # NOTE(review): for the 'Seq' model only the first half of the embedding
    # columns is kept — presumably the second half is redundant for this model;
    # confirm against the model definition.
    if model == 'Seq':
        gene_embedds_df = gene_embedds_df.iloc[:,:int(gene_embedds_df.shape[1]/2)]
    if logits_flag:
        cfg['log_logits'] = True
    prepare_inference_results_tcga(cfg, predicted_labels, logits, all_data, max_len)
    infer_pd = all_data["infere_rna_seq"]

    if logits_flag:
        # NOTE(review): this first assignment is overwritten two lines below, so
        # the "Sequence" column produced by reset_index() never reaches the
        # returned frame — confirm whether dropping it is intended.
        logits_df = infer_pd.rename_axis("Sequence").reset_index()
        logits_cols = [col for col in infer_pd.columns if "Logits" in col]
        logits_df = infer_pd[logits_cols]
        # Column names are ("Logits", <sub class>) tuples; keep only the
        # sub-class level so the frame is indexed by class name.
        logits_df.columns = pd.MultiIndex.from_tuples(logits_df.columns, names=["Logits", "Sub Class"])
        logits_df.columns = logits_df.columns.droplevel(0)
        return logits_df

    elif attention_flag:
        return attn_scores_pd

    elif embedds_flag:
        return gene_embedds_df

    else:
        # Compute novelty via normalized Levenshtein distance (NLD) to the
        # closest neighbours in the training split, then assemble the result.
        embedds_path = '/'.join(cfg['inference_settings']["model_path"].split('/')[:-2])+'/embedds'
        results:Results_Handler = Results_Handler(embedds_path=embedds_path,splits=['train'])
        results.get_knn_model()
        lv_threshold = load(results.analysis_path+"/novelty_model_coef")["Threshold"]
        logger.info(f'computing levenstein distance for the inference set')
        # Align inference embedding column names with the training embeddings.
        gene_embedds_df.columns = results.embedds_cols[:len(gene_embedds_df.columns)]
        gene_embedds_df[results.seq_col] = gene_embedds_df.index
        results.splits_df_dict['infer_df'] = gene_embedds_df

        # For every query sequence, fetch its n_sim nearest training neighbours;
        # the returned lists are flat with n_sim entries per query.
        _,_,top_n_seqs,top_n_labels,distances,lev_dist = get_closest_ngbr_per_split(results,'infer',num_neighbors=n_sim)

        if similarity_flag:
            # One row per (query sequence, explanatory sequence) pair.
            sim_df = pd.DataFrame()
            sequences = gene_embedds_df.index.tolist()
            # Repeat each query n_sim times to align with the flat neighbour lists.
            sequences_duplicated = [seq for seq in sequences for _ in range(n_sim)]
            sim_df['Sequence'] = sequences_duplicated
            sim_df[f'Explanatory Sequence'] = top_n_seqs
            sim_df['NLD'] = lev_dist
            sim_df['Explanatory Label'] = top_n_labels
            sim_df['Novelty Threshold'] = lv_threshold
            sim_df = sim_df.sort_values(by=['Sequence','NLD'],ascending=[False,True])
            return sim_df

        logger.info(f'num of hico based on entropy novelty prediction is {sum(infer_pd["Is Familiar?"])}')
        # Collapse the flat neighbour lists to the single closest neighbour per query.
        lv_dist_closest = [min(lev_dist[i:i+n_sim]) for i in range(0,len(lev_dist),n_sim)]
        top_n_labels_closest = [top_n_labels[i:i+n_sim][np.argmin(lev_dist[i:i+n_sim])] for i in range(0,len(lev_dist),n_sim)]
        top_n_seqs_closest = [top_n_seqs[i:i+n_sim][np.argmin(lev_dist[i:i+n_sim])] for i in range(0,len(lev_dist),n_sim)]
        # Re-derive familiarity from the Levenshtein-based novelty threshold
        # (overrides the entropy-based value logged above).
        infer_pd['Is Familiar?'] = [True if lv<lv_threshold else False for lv in lv_dist_closest]

        if umap_flag:
            logger.info(f'computing umap for the inference set')
            gene_embedds_df = gene_embedds_df.drop(results.seq_col,axis=1)
            umap = UMAP(n_components=2,random_state=42)
            # Standardize first so every embedding dimension contributes equally.
            scaled_embedds = StandardScaler().fit_transform(gene_embedds_df.values)
            gene_embedds_df = pd.DataFrame(umap.fit_transform(scaled_embedds),columns=['UMAP1','UMAP2'])
            gene_embedds_df['Net-Label'] = infer_pd['Net-Label'].values
            gene_embedds_df['Is Familiar?'] = infer_pd['Is Familiar?'].values
            gene_embedds_df['Explanatory Label'] = top_n_labels_closest
            gene_embedds_df['Explanatory Sequence'] = top_n_seqs_closest
            gene_embedds_df['Sequence'] = infer_pd.index
            return gene_embedds_df

        # Default output: prediction table with novelty columns.
        infer_pd['Novelty Threshold'] = lv_threshold
        infer_pd['NLD'] = lv_dist_closest
        infer_pd['Explanatory Label'] = top_n_labels_closest
        infer_pd['Explanatory Sequence'] = top_n_seqs_closest
        infer_pd = infer_pd.round({"NLD": 2, "Novelty Threshold": 2})
        logger.info(f'num of new hico based on levenstein distance is {np.sum(infer_pd["Is Familiar?"])}')
        return infer_pd.rename_axis("Sequence").reset_index()
|
|
def predict_transforna_all_models(sequences: List[str], mc_or_sc:str = 'sub_class',logits_flag: bool = False, attention_flag: bool = False,\
    similarity_flag: bool = False, n_sim:int = 3,
    embedds_flag:bool=False, umap_flag:bool = False, trained_on:str="full",path_to_models:str='') -> pd.DataFrame:
    """
    Predicts the labels of the sequences using all the models available in the transforna package.
    If none of the flags are true, it constructs and aggregates the output of the ensemble model.

    Input:
        sequences: list of sequences to predict
        mc_or_sc: models trained on major class or sub class
        logits_flag: whether to return logits
        attention_flag: whether to return attention scores (obtained from the self-attention layer)
        similarity_flag: whether to return explanatory/similar sequences in the training set
        n_sim: number of similar sequences to return
        embedds_flag: whether to return embeddings of the sequences
        umap_flag: whether to return umap embeddings
        trained_on: whether to use the model trained on the full dataset or the ID dataset
        path_to_models: root folder containing the TransfoRNA model directories
    Output:
        df: dataframe with the concatenated (and, by default, ensemble-aggregated) predictions
    """
    # Measure elapsed time with datetime objects directly: the previous
    # format-to-"%H:%M:%S"-and-reparse round trip lost sub-second precision
    # and produced a negative delta for runs crossing midnight.
    start_time = datetime.now()

    # The Baseline model has no self-attention layer, so it cannot provide
    # attention scores; all other modes use the full model set.
    if attention_flag:
        models = ["Seq", "Seq-Struct", "Seq-Rev"]
    else:
        models = ["Baseline", "Seq", "Seq-Seq", "Seq-Struct", "Seq-Rev"]

    df = None
    for model in models:
        logger.info(model)
        df_ = predict_transforna(sequences, model, mc_or_sc,logits_flag,attention_flag,similarity_flag,n_sim,embedds_flag,umap_flag,trained_on=trained_on,path_to_models = path_to_models)
        df_["Model"] = model
        # pd.concat(None, df_) simply keeps df_ on the first iteration.
        df = pd.concat([df, df_], axis=0)

    # Only the plain prediction table (no special output mode) is aggregated
    # into the ensemble prediction.
    if not any([logits_flag, attention_flag, similarity_flag, embedds_flag, umap_flag]):
        df = aggregate_ensemble_model(df)

    logger.info(f"Time taken: {datetime.now() - start_time}")

    return df
|
|
|
|
if __name__ == "__main__":
    # Minimal CLI: positional sequences plus optional output-mode flags, all
    # forwarded verbatim to predict_transforna_all_models.
    parser = ArgumentParser()
    parser.add_argument("sequences", nargs="+")
    parser.add_argument("--logits_flag", nargs="?", const = True,default=False)
    parser.add_argument("--attention_flag", nargs="?", const = True,default=False)
    parser.add_argument("--similarity_flag", nargs="?", const = True,default=False)
    # type=int: without it an explicit "--n_sim 5" arrives as the string "5",
    # which breaks the integer slicing arithmetic downstream.
    parser.add_argument("--n_sim", nargs="?", type=int, const = 3,default=3)
    parser.add_argument("--embedds_flag", nargs="?", const = True,default=False)
    parser.add_argument("--trained_on", nargs="?", const = True,default="full")
    predict_transforna_all_models(**vars(parser.parse_args()))
|
|