#!/usr/bin/env python3
"""Various utility functions employed by the GDC-QAG pipeline."""

import json
import re
import time
import warnings
from functools import lru_cache, reduce, wraps

import numpy as np
import pandas as pd
import spacy
import torch
from guidance.models import Transformers
from guidance import gen as guidance_gen
from huggingface_hub import HfFolder, hf_hub_download
from datasets import load_dataset
from transformers import AutoTokenizer, BertTokenizer, AutoModelForCausalLM, BertForSequenceClassification

from methods import gdc_api_calls

# Silence a spaCy-related FutureWarning ("Possible set union at position").
# TODO: drop this filter once spaCy is upgraded to its next release.
warnings.filterwarnings(
    "ignore", message="Possible set union at position", category=FutureWarning
)


def load_llama_llm(AUTH_TOKEN):
    """Load the Llama-3.2-3B-Instruct causal LM and its tokenizer.

    Model card / background: https://huggingface.co/blog/llama32

    Args:
        AUTH_TOKEN: Hugging Face access token forwarded to ``from_pretrained``.

    Returns:
        ``(model, tokenizer)``. The model is loaded in float16, moved to CUDA
        when available (CPU otherwise) and switched to eval mode.
    """
    model_id = "meta-llama/Llama-3.2-3B-Instruct"
    tok = AutoTokenizer.from_pretrained(
        model_id, trust_remote_code=True, token=AUTH_TOKEN
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        token=AUTH_TOKEN,
    )
    # Device placement is done explicitly here rather than via device_map="auto".
    # NOTE(review): the weights stay float16 on the CPU fallback too; depending
    # on the torch version, half-precision CPU kernels may be slow or missing.
    model = model.to("cuda" if torch.cuda.is_available() else "cpu")
    model = model.eval()
    return model, tok


def load_gdc_genes_mutations_hf(AUTH_TOKEN):
    """Download and parse the GDC gene -> mutations lookup table.

    Args:
        AUTH_TOKEN: Hugging Face access token for the dataset repo.

    Returns:
        The decoded ``gdc_genes_mutations.json``. The rest of this module uses
        it as a mapping from gene symbol to an iterable of mutation names.
    """
    dataset_id = "uc-ctds/GDC-QAG-genes-mutations"
    filename = "gdc_genes_mutations.json"
    json_path = hf_hub_download(
        repo_id=dataset_id, filename=filename, repo_type="dataset", token=AUTH_TOKEN
    )
    # JSON is UTF-8 by specification; don't rely on the platform default.
    with open(json_path, "r", encoding="utf-8") as f:
        gdc_genes_mutations = json.load(f)
    return gdc_genes_mutations


def load_intent_model_hf(AUTH_TOKEN):
    """Load the BERT query-intent classifier and its tokenizer.

    Args:
        AUTH_TOKEN: Hugging Face access token forwarded to ``from_pretrained``.

    Returns:
        ``(model, tokenizer)`` for the ``uc-ctds/query_intent`` model.
    """
    model_id = "uc-ctds/query_intent"
    tok = AutoTokenizer.from_pretrained(
        model_id, trust_remote_code=True, token=AUTH_TOKEN
    )
    model = BertForSequenceClassification.from_pretrained(model_id, token=AUTH_TOKEN)
    return model, tok


def _percentage(numerator, denominator):
    """Return ``numerator / denominator`` as a percentage rounded to 2 places.

    A denominator of 0 (no cases available) yields 0.0 instead of raising
    ZeroDivisionError.
    """
    if denominator == 0:
        return 0.0
    return round(numerator / denominator * 100, 2)


def calculate_ssm_frequency(ssm_statistics, mutation_name, total_case_count,
                            cancer_entities, project_mappings):
    """Compute the per-project frequency (in %) of a single mutation.

    Args:
        ssm_statistics: per-project counts for one mutation, as returned by
            ``gdc_api_calls.get_ssm_counts``; each value must have an
            ``"ssm_counts"`` entry.
        mutation_name: ``"<gene>_<mutation>"`` label (currently unused).
        total_case_count: project id -> total case count used as denominator.
        cancer_entities: project ids requested for the query.
        project_mappings: currently unused; kept for interface compatibility.

    Returns:
        ``{project_id: {"frequency": percent}}`` for every project in
        ``ssm_statistics`` plus every requested project, the latter defaulting
        to 0.0 when the mutation was not observed there.
    """
    ssm_frequency = {
        project: {
            "frequency": _percentage(stats["ssm_counts"], total_case_count[project])
        }
        for project, stats in ssm_statistics.items()
    }
    # Projects without any cases carrying the mutation are absent from the
    # API response; report them explicitly as 0%.
    for c in cancer_entities:
        ssm_frequency.setdefault(c, {"frequency": 0.0})
    return ssm_frequency


def calculate_joint_ssm_frequency(ssm_statistics, total_case_count, mutation_list,
                                  cancer_entities):
    """Compute, per project, the % of cases carrying *all* listed mutations.

    Args:
        ssm_statistics: ``{mutation_name: {project_id: {"case_id_list": [...],
            ...}}}`` for every mutation in ``mutation_list``.
        total_case_count: project id -> total case count used as denominator.
        mutation_list: mutation names to intersect (must be non-empty).
        cancer_entities: project ids requested for the query.

    Returns:
        ``{project_id: {"joint_frequency": percent}}`` for each requested
        project; 0.0 where no case carries every mutation.
    """
    # Joint frequencies for projects that have at least one shared case.
    joint_ssm_frequency = {}
    # Result restricted to the requested projects, defaulting to 0%.
    joint_ssm_frequency_for_cancer = {
        c: {"joint_frequency": 0.0} for c in cancer_entities
    }

    # Only projects in which every mutation was observed can have shared cases.
    projects_with_mutation = [
        set(ssm_statistics[mutation].keys()) for mutation in mutation_list
    ]
    overlapping_projects_with_mutation = list(
        reduce(lambda x, y: x & y, projects_with_mutation)
    )

    for project in overlapping_projects_with_mutation:
        cases_with_mutation = [
            set(ssm_statistics[mutation][project]["case_id_list"])
            for mutation in mutation_list
        ]
        print('getting shared cases...')
        shared_cases = list(reduce(lambda x, y: x & y, cases_with_mutation))
        print('number of shared cases: {}'.format(len(shared_cases)))
        if shared_cases:
            joint_ssm_frequency[project] = {
                "joint_frequency": _percentage(
                    len(shared_cases), total_case_count[project]
                )
            }

    # Filter to the requested cancer types.
    for c in cancer_entities:
        if c in joint_ssm_frequency:
            joint_ssm_frequency_for_cancer[c]["joint_frequency"] = (
                joint_ssm_frequency[c]["joint_frequency"]
            )
    return joint_ssm_frequency_for_cancer


def flatten_ssm_results_to_text(result, result_type):
    """Render SSM frequency results as sentences for query augmentation.

    Args:
        result: ``{mutation_name: {project: {"frequency": pct}}, ...,
            "joint_frequency": ...}`` as built by ``get_ssm_frequency``.
        result_type: ``"joint_frequency"`` to render only the joint
            frequencies; any other value renders the per-mutation ones.

    Returns:
        A list of human-readable sentences (possibly empty).
    """
    result_text = []
    print('preparing a GDC Result for query augmentation...')
    if result_type == "joint_frequency":
        for project, stats in result.get("joint_frequency", {}).items():
            result_text.append(
                "joint frequency in {} is {}%".format(project, stats["joint_frequency"])
            )
    else:
        for mutation, per_project in result.items():
            if mutation == "joint_frequency":
                continue
            for project, stats in per_project.items():
                result_text.append(
                    "The frequency of {} in {} is {}%".format(
                        mutation, project, stats["frequency"]
                    )
                )
    # Log each prepared line; iterating the list also avoids referencing an
    # unbound variable when there is nothing to report.
    for gdc_result in result_text:
        print('prepared GDC Result: {}'.format(gdc_result))
    return result_text


def get_ssm_frequency(
    gene_entities, mutation_entities, cancer_entities, project_mappings
):
    """Query the GDC for mutation frequencies and render them as text.

    Args:
        gene_entities: gene symbols extracted from the query.
        mutation_entities: mutation names extracted from the query.
        cancer_entities: GDC project ids to report on.
        project_mappings: passed through to ``calculate_ssm_frequency``.

    Returns:
        ``(result_text, cancer_entities)``: a list of sentences describing
        per-mutation frequencies (one mutation) or joint frequencies (two or
        more mutations), and the unchanged ``cancer_entities``.
    """
    ssm_statistics = {}
    mutation_list = []
    result = {}
    total_case_count = {
        ce: gdc_api_calls.get_available_ssm_data_for_project(ce)
        for ce in cancer_entities
    }

    # Pair genes with mutations positionally; a single gene is repeated so it
    # pairs with every mutation.
    # NOTE(review): with several genes and more mutations than genes the
    # repetition pairs them cyclically (g1-m1, g2-m2, g1-m3, ...), which may not
    # reflect which gene each mutation belongs to.
    if len(mutation_entities) > len(gene_entities):
        gene_entities = gene_entities * len(mutation_entities)

    for gene, mutation in zip(gene_entities, mutation_entities):
        mutation_name = "_".join([gene, mutation])
        mutation_list.append(mutation_name)
        ssm_id = gdc_api_calls.get_ssm_id(gene, mutation)
        ssm_statistics[mutation_name] = gdc_api_calls.get_ssm_counts(
            ssm_id, cancer_entities
        )
        result[mutation_name] = calculate_ssm_frequency(
            ssm_statistics[mutation_name],
            mutation_name,
            total_case_count,
            cancer_entities,
            project_mappings,
        )

    print('\nStep 5: Query GDC and process results\n')
    for mut in mutation_list:
        print('mutation: {}'.format(mut))
        for ce in cancer_entities:
            # Projects with no mutated cases are missing from the counts.
            ssm_count = ssm_statistics[mut].get(ce, {}).get("ssm_counts", 0)
            print('number of cases with mutation: {}'.format(ssm_count))
            print('total case count: {}'.format(total_case_count[ce]))

    # A joint frequency is only meaningful for two or more mutations.
    if len(mutation_list) > 1:
        result["joint_frequency"] = calculate_joint_ssm_frequency(
            ssm_statistics, total_case_count, mutation_list, cancer_entities
        )
        result_text = flatten_ssm_results_to_text(result, result_type="joint_frequency")
    else:
        result["joint_frequency"] = 0
        result_text = flatten_ssm_results_to_text(
            result, result_type="single_frequency"
        )
    return result_text, cancer_entities


def decompose_mutation_and_cnv(query, match_term, gdc_genes_mutations):
    """Split a combined CNV + mutation query into its two genes.

    The query must mention the CNV gene first, followed by the mutated gene.

    Args:
        query: raw user query; tokens are split on single spaces.
        match_term: CNV change type found in the query (e.g. "amplification").
        gdc_genes_mutations: mapping whose keys are known gene symbols.

    Returns:
        ``{"cnv_and_ssm": True, "cnv_gene": ..., "mut_gene": ...,
        "cnv_change_type": match_term}``.

    Raises:
        IndexError: if fewer than two known gene symbols appear in the query.
    """
    genes = [g for g in query.split(" ") if g in gdc_genes_mutations]
    if len(genes) < 2:
        raise IndexError(
            "expected a CNV gene followed by a mutated gene in the query, "
            "found {} known gene(s): {}".format(len(genes), genes)
        )
    cnv_gene_name, mut_gene_name = genes[0], genes[1]
    return {
        "cnv_and_ssm": True,
        "cnv_gene": cnv_gene_name,
        "mut_gene": mut_gene_name,
        "cnv_change_type": match_term,
    }


def get_freq_of_cnv_and_ssms(
    query, cancer_entities, gene_entities, gdc_genes_mutations
):
    """Answer a CNV/SSM co-occurrence query via the GDC API.

    If the query names a CNV change type, the CNV gene and mutated gene are
    decomposed and passed to ``gdc_api_calls.run_cnv_ssm_api``; otherwise the
    top case counts per gene are fetched with
    ``gdc_api_calls.get_top_cases_counts_by_gene``.

    Returns:
        ``(result, cancer_entities)`` exactly as returned by the API helper.
    """
    lc_query = query.lower()
    cnv_terms = [
        "amplification",
        "deletion",
        "loss",
        "gain",
        "homozygous deletion",
        "heterozygous deletion",
    ]
    # The last matching term wins, so the more specific "homozygous deletion" /
    # "heterozygous deletion" override the plain "deletion" they contain.
    match_term = ""
    for term in cnv_terms:
        if term in lc_query:
            match_term = term

    if match_term:
        decompose_result = decompose_mutation_and_cnv(
            query, match_term, gdc_genes_mutations
        )
        result, cancer_entities = gdc_api_calls.run_cnv_ssm_api(
            decompose_result, cancer_entities, query
        )
    else:
        # No specific CNV term: return the frequency of CNVs + SSMs.
        result, cancer_entities = gdc_api_calls.get_top_cases_counts_by_gene(
            gene_entities, cancer_entities
        )
    return result, cancer_entities


@lru_cache(maxsize=4)
def _load_spacy_pipeline(model):
    """Load a spaCy pipeline once and reuse it for subsequent queries."""
    return spacy.load(model)


def return_initial_cancer_entities(query, model):
    """Extract DISEASE entities from ``query`` with a spaCy NER model.

    Args:
        query: raw user query.
        model: name or path accepted by ``spacy.load``; the loaded pipeline is
            cached so repeated calls don't reload it from disk.

    Returns:
        List of entity texts labelled ``"DISEASE"``.
    """
    doc = _load_spacy_pipeline(model)(query)
    return [e.text for e in doc.ents if e.label_ == "DISEASE"]


def infer_gene_entities_from_query(query, gdc_genes_mutations):
    """Return known gene symbols that appear as whole space-separated tokens.

    Simple dictionary-based gene recognition. Results follow the iteration
    order of ``gdc_genes_mutations``.
    """
    # Split once instead of once per known gene.
    query_tokens = set(query.split(" "))
    return [g for g in gdc_genes_mutations if g in query_tokens]


def check_if_project_id_in_query(project_list, query):
    """Return query tokens that are GDC project ids (e.g. ``TCGA-BRCA``)."""
    return [
        potential_ce for potential_ce in query.split(" ")
        if potential_ce in project_list
    ]


def proj_id_and_partial_match(query, project_mappings, initial_cancer_entities):
    """Map cancer mentions to GDC project ids by case-insensitive substring match.

    With initial entities, each entity is matched against the project
    descriptions (e.g. "ovarian serous cystadenocarcinoma" -> TCGA-OV).
    Without them, every non-empty query token is matched instead.

    Args:
        query: raw user query.
        project_mappings: project id -> iterable of description strings.
        initial_cancer_entities: entity strings from NER (may be empty).

    Returns:
        De-duplicated list of matching project ids, in first-match order.
    """
    if initial_cancer_entities:
        terms = [ic.lower() for ic in initial_cancer_entities]
    else:
        # Skip empty tokens produced by repeated spaces: "" is a substring of
        # every description and would match every project.
        # NOTE(review): short tokens such as "of"/"in" still match many
        # descriptions as substrings.
        terms = [t for t in query.lower().split(" ") if t]

    final_entities = [
        project_id
        for term in terms
        for project_id, descriptions in project_mappings.items()
        for description in descriptions
        if term in description.lower()
    ]
    # dict.fromkeys de-duplicates while keeping a deterministic order.
    return list(dict.fromkeys(final_entities))


def postprocess_cancer_entities(project_mappings, initial_cancer_entities, query):
    """Resolve the GDC project ids a query refers to.

    Resolution order:
      1. project ids mentioned verbatim in the query;
      2. GDC projects endpoint lookup of the NER entities (if any);
      3. substring matching against ``project_mappings`` descriptions.

    Returns:
        List of GDC project ids (possibly empty).
    """
    final_entities = check_if_project_id_in_query(project_mappings.keys(), query)
    if final_entities:
        return final_entities

    if initial_cancer_entities:
        gdc_project_match = gdc_api_calls.map_cancer_entities_to_project(
            initial_cancer_entities, project_mappings
        )
        final_entities = list(gdc_project_match.values())
        if final_entities:
            return final_entities

    # Fall back to matching the entities (or query terms) against descriptions.
    return proj_id_and_partial_match(query, project_mappings, initial_cancer_entities)


def infer_mutation_entities(gene_entities, query, gdc_genes_mutations):
    """Return known mutations of ``gene_entities`` that occur in ``query``.

    Matching is a plain substring test, in gene then mutation-list order.
    """
    return [
        m
        for g in gene_entities
        for m in gdc_genes_mutations[g]
        if m in query
    ]


def set_hf_token(token_path):
    """Read a Hugging Face token from ``token_path`` and save it via HfFolder."""
    with open(token_path, "r", encoding="utf-8") as hf_token_file:
        HF_TOKEN = hf_token_file.read().strip()
    HfFolder.save_token(HF_TOKEN)


def get_final_columns():
    """Return the column names for the pipeline's final output CSV."""
    final_columns = [
        "questions",
        "gene_entities",
        "mutation_entities",
        "cancer_entities",
        "intent",
        "llama_base_output",
        "llama_base_stat",
        "gdc_result",
        "gdc_qag_base_stat",
        "descriptive_prompt",
        "percentage_prompt",
        "final_gdc_qag_desc_response",
        "final_gdc_qag_percentage_response",
        "final_gdc_qag_response",
    ]
    return final_columns


def timeit(fn):
    """Decorator that prints how long each call of ``fn`` takes."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        end = time.perf_counter()
        print(f"{fn.__name__} took {end - start:.4f} seconds")
        return result
    return wrapper