Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| # various utility functions employed by the pipeline | |
| import json | |
| import re | |
| import time | |
| from functools import reduce, wraps | |
| import numpy as np | |
| import pandas as pd | |
| import spacy | |
| import torch | |
| import warnings | |
| from guidance.models import Transformers | |
| from guidance import gen as guidance_gen | |
| from huggingface_hub import HfFolder, hf_hub_download | |
| from datasets import load_dataset | |
| from transformers import AutoTokenizer, BertTokenizer, AutoModelForCausalLM, BertForSequenceClassification | |
| from methods import gdc_api_calls | |
# Silence the FutureWarning ("Possible set union at position ...") raised via
# spaCy's regex patterns on newer Python versions; remove this filter once
# spaCy is upgraded in the next release.
warnings.filterwarnings(
    "ignore",
    message="Possible set union at position",
    category=FutureWarning
)
def load_llama_llm(AUTH_TOKEN):
    """Load the Llama 3.2 3B Instruct model and its tokenizer from the HF Hub.

    See https://huggingface.co/blog/llama32 for model details. Returns
    (model, tokenizer); the model is placed on GPU when available and set
    to eval mode.
    """
    model_id = "meta-llama/Llama-3.2-3B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
        token=AUTH_TOKEN,
    )
    llm = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        # device_map="auto",
        token=AUTH_TOKEN,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    llm = llm.to(device).eval()
    return llm, tokenizer
def load_gdc_genes_mutations_hf(AUTH_TOKEN):
    """Download and parse the GDC genes/mutations lookup JSON from the HF Hub."""
    repo = 'uc-ctds/GDC-QAG-genes-mutations'
    fname = 'gdc_genes_mutations.json'
    local_path = hf_hub_download(
        repo_id=repo,
        filename=fname,
        repo_type="dataset",
        token=AUTH_TOKEN,
    )
    # json_path = load_dataset(dataset_id, token=AUTH_TOKEN)
    with open(local_path, 'r') as fh:
        return json.load(fh)
def load_intent_model_hf(AUTH_TOKEN):
    """Load the BERT-based query-intent classifier and tokenizer from the HF Hub."""
    repo = 'uc-ctds/query_intent'
    intent_tokenizer = AutoTokenizer.from_pretrained(
        repo,
        trust_remote_code=True,
        token=AUTH_TOKEN,
    )
    intent_model = BertForSequenceClassification.from_pretrained(
        repo, token=AUTH_TOKEN
    )
    return intent_model, intent_tokenizer
def construct_modified_query_base_llm(query):
    """Append the base-LLM instruction suffix to *query*.

    Fix: the suffix previously concatenated directly onto the query with no
    separating space (producing e.g. "...?Only use results..."), garbling the
    prompt. A leading space now separates the two, matching the formatting
    used by construct_modified_query().
    """
    prompt_template = " Only use results from the genomic data commons in your response and provide frequencies as a percentage. Only report the final response."
    return query + prompt_template
def construct_modified_query(query, helper_output):
    """Augment *query* with GDC API results so the LLM must cite them."""
    # instruct the model to ground its answer exclusively in the API output
    suffix = (
        " Only report the final response. Ignore all prior knowledge. You must only respond with the following percentage frequencies in your response, no other response is allowed: \n"
        + helper_output
        + "\n"
    )
    return query + suffix
def calculate_ssm_frequency(ssm_statistics, mutation_name, total_case_count, cancer_entities, project_mappings):
    """Per-project mutation frequency: ssm_counts / total cases, as a rounded %."""
    frequencies = {
        project: {
            "frequency": round(
                stats["ssm_counts"] / total_case_count[project] * 100, 2
            )
        }
        for project, stats in ssm_statistics.items()
    }
    # projects with no observed ssms default to 0%
    for entity in cancer_entities:
        frequencies.setdefault(entity, {'frequency': 0.0})
    return frequencies
def calculate_joint_ssm_frequency(ssm_statistics, total_case_count, mutation_list, cancer_entities):
    """Compute the co-occurrence frequency (%) of several mutations per project.

    A case counts toward the joint frequency only when it carries every
    mutation in *mutation_list*; the result is restricted to the requested
    cancer entities, defaulting to 0.0 where no overlap exists.
    """
    # per-project joint frequencies, keyed by project id
    joint_ssm_frequency = {}
    # result restricted to the requested cancer entities
    joint_ssm_frequency_for_cancer = {
        entity: {"joint_frequency": 0.0} for entity in cancer_entities
    }
    # projects in which every mutation was observed at least once
    per_mutation_projects = [
        set(ssm_statistics[m].keys()) for m in mutation_list
    ]
    shared_projects = list(set.intersection(*per_mutation_projects))
    for proj in shared_projects:
        case_sets = [
            set(ssm_statistics[m][proj]["case_id_list"]) for m in mutation_list
        ]
        print('getting shared cases...')
        common_cases = list(set.intersection(*case_sets))
        print('number of shared cases: {}'.format(len(common_cases)))
        if common_cases:
            joint_ssm_frequency.setdefault(proj, {})
            fraction = len(common_cases) / total_case_count[proj]
            joint_ssm_frequency[proj]["joint_frequency"] = round(
                fraction * 100, 2
            )
    # filter for the requested cancer entities and return
    for entity in cancer_entities:
        if entity in joint_ssm_frequency:
            joint_ssm_frequency_for_cancer[entity]["joint_frequency"] = (
                joint_ssm_frequency[entity]["joint_frequency"]
            )
    return joint_ssm_frequency_for_cancer
def flatten_ssm_results_to_text(result, result_type):
    """Render frequency results as human-readable sentences for prompt augmentation."""
    sentences = []
    print('preparing a GDC Result for query augmentation...')
    if result_type == "joint_frequency":
        # only the "joint_frequency" entry carries per-project joint stats
        for key, per_project in result.items():
            if key != "joint_frequency":
                continue
            for proj, stats in per_project.items():
                sentences.append(
                    "joint frequency in {} is {}%".format(
                        proj, stats["joint_frequency"]
                    )
                )
    else:
        # every entry except "joint_frequency" holds single-mutation stats
        for key, per_project in result.items():
            if key == "joint_frequency":
                continue
            for proj, stats in per_project.items():
                gdc_result = "The frequency of {} in {} is {}%".format(
                    key, proj, stats["frequency"]
                )
                sentences.append(gdc_result)
        # NOTE(review): echoes the last prepared sentence; exact original
        # print placement was ambiguous in the mangled source — confirm
        print('prepared GDC Result: {}'.format(gdc_result))
    return sentences
def get_ssm_frequency(
    gene_entities, mutation_entities, cancer_entities, project_mappings
):
    """Query the GDC for per-project mutation frequencies.

    Returns (result_text, cancer_entities): a list of human-readable
    frequency sentences plus the cancer-entity list used for the queries.
    """
    ssm_statistics = {}
    mutation_list = []
    result = {}
    total_case_count = {}
    for entity in cancer_entities:
        total_case_count[entity] = gdc_api_calls.get_available_ssm_data_for_project(entity)
    # replicate the gene list so every mutation pairs with a gene in zip()
    if len(mutation_entities) > len(gene_entities):
        gene_entities = gene_entities * len(mutation_entities)
    for gene, mutation in zip(gene_entities, mutation_entities):
        key = "_".join([gene, mutation])
        mutation_list.append(key)
        ssm_id = gdc_api_calls.get_ssm_id(gene, mutation)
        ssm_statistics[key] = gdc_api_calls.get_ssm_counts(ssm_id, cancer_entities)
        result[key] = calculate_ssm_frequency(
            ssm_statistics[key], key, total_case_count, cancer_entities, project_mappings
        )
    print('\nStep 5: Query GDC and process results\n')
    for mut in mutation_list:
        print('mutation: {}'.format(mut))
        for entity in cancer_entities:
            print('number of cases with mutation: {}'.format(
                ssm_statistics[mut][entity]["ssm_counts"]))
            print('total case count: {}'.format(
                total_case_count[entity]))
    # joint statistics are currently computed only for multi-mutation queries
    if len(mutation_list) > 1:
        result["joint_frequency"] = calculate_joint_ssm_frequency(
            ssm_statistics, total_case_count, mutation_list, cancer_entities
        )
        result_text = flatten_ssm_results_to_text(result, result_type="joint_frequency")
    else:
        result["joint_frequency"] = 0
        result_text = flatten_ssm_results_to_text(
            result, result_type="single_frequency"
        )
    return result_text, cancer_entities
def decompose_mutation_and_cnv(query, match_term, gdc_genes_mutations):
    """Split a combined CNV + mutation question into its two gene parts.

    Assumes the query names the CNV gene before the mutated gene.
    """
    known_genes = [
        tok for tok in query.split(" ") if tok in gdc_genes_mutations.keys()
    ]
    # query must have cnv first, followed by mutation
    return {
        "cnv_and_ssm": True,
        "cnv_gene": known_genes[0],
        "mut_gene": known_genes[1],
        "cnv_change_type": match_term,
    }
def get_freq_of_cnv_and_ssms(
    query, cancer_entities, gene_entities, gdc_genes_mutations
):
    """Dispatch a CNV/SSM query to the appropriate GDC API helper."""
    lowered = query.lower()
    cnv_terms = [
        "amplification",
        "deletion",
        "loss",
        "gain",
        "homozygous deletion",
        "heterozygous deletion",
    ]
    # keep the LAST matching term so the more specific multi-word variants
    # (listed later) win over their substrings (e.g. "deletion")
    match_term = ""
    for candidate in cnv_terms:
        if candidate in lowered:
            match_term = candidate
    if match_term:
        decomposed = decompose_mutation_and_cnv(
            query, match_term, gdc_genes_mutations
        )
        result, cancer_entities = gdc_api_calls.run_cnv_ssm_api(
            decomposed, cancer_entities, query
        )
    else:
        # no explicit cnv wording: fall back to combined cnv + ssm counts
        result, cancer_entities = gdc_api_calls.get_top_cases_counts_by_gene(
            gene_entities, cancer_entities
        )
    return result, cancer_entities
def return_initial_cancer_entities(query, model):
    """Run spaCy NER over *query* and return all DISEASE entity strings."""
    pipeline = spacy.load(model)
    parsed = pipeline(query)
    return [ent.text for ent in parsed.ents if ent.label_ == "DISEASE"]
def infer_gene_entities_from_query(query, gdc_genes_mutations):
    """Return the known GDC gene symbols mentioned in *query*.

    A gene counts as mentioned only when it appears as a whole
    space-separated token (exactly the original behavior). The redundant
    substring pre-check (``g in query``) implied by token membership was
    dropped, and the query is tokenized once instead of once per gene.
    """
    query_tokens = set(query.split(" "))
    return [g for g in gdc_genes_mutations.keys() if g in query_tokens]
def check_if_project_id_in_query(project_list, query):
    """Return any GDC project ids (e.g. TCGA-BRCA) quoted verbatim in *query*."""
    hits = []
    for token in query.split(" "):
        if token in project_list:
            hits.append(token)
    return hits
def proj_id_and_partial_match(query, project_mappings, initial_cancer_entities):
    """Map cancer names to GDC project ids via substring match.

    When NER produced candidate entities, match those against the project
    descriptions (e.g. "ovarian serous cystadenocarcinoma" -> TCGA-OV);
    otherwise fall back to matching every lowercased query token. Returns
    de-duplicated project ids (order not guaranteed).
    """
    # choose what to match: NER entities if present, else raw query tokens
    if initial_cancer_entities:
        candidates = initial_cancer_entities
    else:
        candidates = query.lower().split(" ")
    matched = []
    for candidate in candidates:
        for project_id, descriptions in project_mappings.items():
            for description in descriptions:
                if candidate in description.lower():
                    matched.append(project_id)
    return list(set(matched))
def postprocess_cancer_entities(project_mappings, initial_cancer_entities, query):
    """Resolve NER cancer entities (or raw query terms) to GDC project ids.

    Resolution order:
      1. explicit project ids already present in the query (e.g. "TCGA-BRCA"),
      2. the GDC projects endpoint, when NER found entities,
      3. partial matching against the local project_mappings table.
    """
    # 1) an explicit project id in the query wins outright
    explicit = check_if_project_id_in_query(project_mappings.keys(), query)
    if explicit:
        return explicit
    if not initial_cancer_entities:
        # 3) no NER entities: match raw query terms against project_mappings
        return proj_id_and_partial_match(
            query, project_mappings, initial_cancer_entities
        )
    # 2) NER entities exist: try the GDC projects endpoint first
    gdc_project_match = gdc_api_calls.map_cancer_entities_to_project(
        initial_cancer_entities, project_mappings
    )
    final_entities = list(gdc_project_match.values())
    if not final_entities:
        # 3) endpoint came up empty: fall back to partial matching
        final_entities = proj_id_and_partial_match(
            query, project_mappings, initial_cancer_entities
        )
    return final_entities
def infer_mutation_entities(gene_entities, query, gdc_genes_mutations):
    """Return known mutations of the detected genes that appear in *query*."""
    return [
        mutation
        for gene in gene_entities
        for mutation in gdc_genes_mutations[gene]
        if mutation in query
    ]
def postprocess_response(row):
    """Compare LLM outputs against the GDC ground truth for one dataframe row.

    Extracts the first "<float>%" value from each response text, computes
    deltas versus the ground-truth statistic, and substitutes the ground
    truth into the final response whenever the model's number disagrees.

    Returns a 9-element pd.Series:
        [llama_base_stat, delta_llama, value_changed, ground_truth_stat,
         generated_stat_prefinal, delta_prefinal, generated_stat_final,
         delta_final, final_response]
    """
    value_changed = "no"
    # first percentage like "12.34%" anywhere in the text
    pattern = r".*?(\d*\.\d*)%.*?"
    delta_final = np.nan
    delta_prefinal = np.nan
    generated_stat_final = np.nan
    try:
        helper_output = row["helper_output"]
    except Exception:
        # No helper output at all: emit a full-width row of NaNs.
        # Fix: previously returned 8 *string* literals "np.nan" instead of
        # 9 NaN floats, misaligning columns versus the 9-element success
        # return below.
        return pd.Series([np.nan] * 9)
    pre_final_response = row["pre_final_llama_with_helper_output"]
    llama_base_output = row["llama_base_output"]
    try:
        llama_base_stat = float(re.search(pattern, llama_base_output).group(1))
    except Exception:
        # no parseable percentage in the base-LLM output
        llama_base_stat = np.nan
    try:
        generated_stat_prefinal = float(re.search(pattern, pre_final_response).group(1))
    except Exception:
        generated_stat_prefinal = np.nan
    try:
        ground_truth_stat = float(re.search(pattern, helper_output).group(1))
    except Exception:
        ground_truth_stat = np.nan
    try:
        delta_llama = llama_base_stat - ground_truth_stat
    except Exception:
        delta_llama = np.nan
    if not np.isnan(generated_stat_prefinal) and not np.isnan(ground_truth_stat):
        delta_prefinal = generated_stat_prefinal - ground_truth_stat
        if delta_prefinal != 0.0:
            # the model disagreed with the GDC value: force the ground truth
            final_response = "The final answer is: {}%".format(ground_truth_stat)
            value_changed = "yes"
        else:
            final_response = pre_final_response
        generated_stat_final = float(re.search(pattern, final_response).group(1))
        delta_final = generated_stat_final - ground_truth_stat
    else:
        final_response = "unable to postprocess, check generated or truth stat"
        value_changed = "na"
    return pd.Series(
        [
            llama_base_stat,
            delta_llama,
            value_changed,
            ground_truth_stat,
            generated_stat_prefinal,
            delta_prefinal,
            generated_stat_final,
            delta_final,
            final_response,
        ]
    )
def set_hf_token(token_path):
    """Read a Hugging Face token from *token_path* and persist it via HfFolder."""
    with open(token_path, "r") as fh:
        token = fh.read().strip()
        HfFolder.save_token(token)
def get_final_columns():
    """Return the column names, in order, for the pipeline's final output CSV."""
    return [
        "questions",
        "gene_entities",
        "mutation_entities",
        "cancer_entities",
        "intent",
        "llama_base_output",
        "llama_base_stat",
        "helper_output",
        "ground_truth_stat",
        "modified_prompt",
        "final_response",
        "delta_llama",
    ]
def timeit(fn):
    """Decorator that prints *fn*'s wall-clock runtime on every call.

    Fix: applies functools.wraps (already imported at the top of this
    module but unused here) so the wrapped function keeps its
    __name__/__doc__ for downstream logging and debugging.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        end = time.perf_counter()
        print(f"{fn.__name__} took {end - start:.4f} seconds")
        return result
    return wrapper