|
|
| from typing import Dict |
|
|
| from anndata import AnnData |
|
|
| from ..processing.seq_tokenizer import SeqTokenizer |
| from ..utils.file import load |
| from ..utils.utils import * |
|
|
|
|
def infer_tcga(cfg:Dict= None,path:str = None):
    """Run inference on the configured sequences and write the results to disk.

    The sequences are loaded from ``cfg['inference_settings']['sequences_path']``
    and passed through the model returned by ``get_model``. Results are written
    as CSV files into the ``inference_output`` directory (relative to the current
    working directory):

    * ``<model_name>_embedds.csv`` -- only when ``cfg['log_embedds']`` is truthy.
    * ``<model_name>_inference_results.csv`` -- always.

    Args:
        cfg: Configuration mapping. Although it defaults to ``None`` it is
            required in practice: ``cfg['tensorboard']`` is read before anything
            else. Keys read here are ``tensorboard``, ``inference_settings``,
            ``log_embedds`` and ``model_name``.
        path: Forwarded unchanged to ``get_model``.

    Returns:
        The predicted labels, i.e. the first element returned by
        ``infer_from_pd``.
    """
    output_dir = "inference_output"

    # The tensorboard writer is imported lazily so that the dependency is only
    # touched when logging is enabled in the incoming configuration.
    writer = None
    if cfg['tensorboard']:
        from ..callbacks.tbWriter import writer

    try:
        # get_model may hand back an updated configuration; use it from here on.
        cfg, net = get_model(cfg, path)

        inference_path = cfg['inference_settings']['sequences_path']
        original_infer_df = load(inference_path, index_col=0)
        if isinstance(original_infer_df, AnnData):
            # For AnnData inputs the ``.var`` table is what gets fed to inference.
            original_infer_df = original_infer_df.var

        predicted_labels, logits, _, _, all_data, max_len, net, infer_df = infer_from_pd(
            cfg, net, original_infer_df, SeqTokenizer
        )

        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(output_dir, exist_ok=True)

        if cfg['log_embedds']:
            embedds_pd = log_embedds(cfg, net, all_data['infere_rna_seq'])
            embedds_pd.to_csv(f"{output_dir}/{cfg['model_name']}_embedds.csv")

        prepare_inference_results_tcga(cfg, predicted_labels, logits, all_data, max_len)

        # A differing row count means infer_from_pd altered the set of input rows;
        # in that case the original sequences are mapped back onto the predictions.
        if original_infer_df.shape[0] != infer_df.shape[0]:
            all_data["infere_rna_seq"] = add_original_seqs_to_predictions(
                infer_df, all_data['infere_rna_seq']
            )

        all_data["infere_rna_seq"].to_csv(
            f"{output_dir}/{cfg['model_name']}_inference_results.csv"
        )

        return predicted_labels
    finally:
        # Close the writer whenever it was imported, including on failure, so it
        # is not left open if any inference step raises.
        if writer is not None:
            writer.close()