Spaces:
Running on Zero
Running on Zero
| #!/usr/bin/env python3 | |
| # various utility functions employed by the pipeline | |
| import json | |
| import re | |
| import time | |
| from functools import reduce, wraps | |
| import numpy as np | |
| import pandas as pd | |
| import spacy | |
| import torch | |
| import warnings | |
| from guidance.models import Transformers | |
| from guidance import gen as guidance_gen | |
| from huggingface_hub import HfFolder, hf_hub_download | |
| from datasets import load_dataset | |
| from transformers import AutoTokenizer, BertTokenizer, AutoModelForCausalLM, BertForSequenceClassification | |
| from methods import gdc_api_calls | |
# spacy warning, upgrade to latest spacy
# in the next release
# NOTE(review): silences a spaCy-emitted FutureWarning about a regex
# "Possible set union" pattern; remove this filter once spaCy is upgraded.
warnings.filterwarnings(
    "ignore",
    message="Possible set union at position",
    category=FutureWarning
)
def load_llama_llm(AUTH_TOKEN):
    """Load the Llama 3.2 3B Instruct model and tokenizer from Hugging Face.

    Args:
        AUTH_TOKEN: Hugging Face access token; required because the
            meta-llama repository is gated.

    Returns:
        (model, tok): the causal LM in eval mode (moved to GPU when one is
        available, otherwise CPU) and its tokenizer.
    """
    # hugging face model
    # https://huggingface.co/blog/llama32
    model_id = "meta-llama/Llama-3.2-3B-Instruct"
    tok = AutoTokenizer.from_pretrained(
        model_id, trust_remote_code=True,
        token=AUTH_TOKEN
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        # half precision to reduce the 3B model's memory footprint
        torch_dtype=torch.float16,
        trust_remote_code=True,
        # device_map="auto",
        token=AUTH_TOKEN
    )
    # inference only: place on GPU when present and disable training mode
    model = model.to("cuda" if torch.cuda.is_available() else "cpu")
    model = model.eval()
    return model, tok
def load_gdc_genes_mutations_hf(AUTH_TOKEN):
    """Download and parse the GDC gene/mutation lookup JSON from the HF Hub.

    Args:
        AUTH_TOKEN: Hugging Face access token for the dataset repository.

    Returns:
        The parsed JSON object; presumably a dict mapping gene symbols to
        lists of mutation names (see its use in infer_mutation_entities) —
        TODO confirm against the dataset file.
    """
    dataset_id = 'uc-ctds/GDC-QAG-genes-mutations'
    filename = 'gdc_genes_mutations.json'
    # hf_hub_download caches the file locally and returns the local path
    json_path = hf_hub_download(
        repo_id=dataset_id,
        filename=filename,
        repo_type="dataset",
        token=AUTH_TOKEN
    )
    # json_path = load_dataset(dataset_id, token=AUTH_TOKEN)
    with open(json_path, 'r') as f:
        gdc_genes_mutations = json.load(f)
    return gdc_genes_mutations
def load_intent_model_hf(AUTH_TOKEN):
    """Load the fine-tuned BERT query-intent classifier and its tokenizer.

    Args:
        AUTH_TOKEN: Hugging Face access token for the model repository.

    Returns:
        (model, tok): a BertForSequenceClassification model and tokenizer
        from the uc-ctds/query_intent repository.
    """
    model_id = 'uc-ctds/query_intent'
    tok = AutoTokenizer.from_pretrained(
        model_id, trust_remote_code=True,
        token=AUTH_TOKEN
    )
    model = BertForSequenceClassification.from_pretrained(
        model_id, token=AUTH_TOKEN)
    return model, tok
def calculate_ssm_frequency(ssm_statistics, mutation_name, total_case_count, cancer_entities, project_mappings):
    """Convert raw SSM case counts into per-project percentage frequencies.

    Args:
        ssm_statistics: dict keyed by project id, each value containing an
            "ssm_counts" entry (number of cases carrying the mutation).
        mutation_name: unused; kept for interface compatibility with callers.
        total_case_count: dict of project id -> total cases with SSM data.
        cancer_entities: project ids that must appear in the output even
            when no SSMs were observed for them.
        project_mappings: unused; kept for interface compatibility.

    Returns:
        dict of project id -> {"frequency": percent rounded to 2 decimals}.
    """
    ssm_frequency = {}
    for project, stats in ssm_statistics.items():
        total = total_case_count[project]
        # robustness fix: a project with zero eligible cases previously
        # raised ZeroDivisionError; report 0.0 instead
        freq = stats["ssm_counts"] / total if total else 0.0
        ssm_frequency[project] = {"frequency": round(freq * 100, 2)}
    # if there are no ssms for a requested cancer, set an explicit 0.0
    for project in cancer_entities:
        ssm_frequency.setdefault(project, {"frequency": 0.0})
    return ssm_frequency
def calculate_joint_ssm_frequency(ssm_statistics, total_case_count, mutation_list, cancer_entities):
    """Compute, per project, the percentage of cases carrying *all* mutations.

    Args:
        ssm_statistics: dict of mutation name -> project id ->
            {"case_id_list": [...]} (per-mutation case membership).
        total_case_count: dict of project id -> total cases with SSM data.
        mutation_list: mutation names whose co-occurrence is measured.
        cancer_entities: project ids to report; defaults to 0.0 when no
            case carries every mutation.

    Returns:
        dict of project id -> {"joint_frequency": percent, 2 decimals}.
    """
    # every requested cancer starts at 0.0 joint frequency
    per_cancer_result = {
        cancer: {"joint_frequency": 0.0} for cancer in cancer_entities
    }
    # only projects in which every mutation was observed can share cases
    project_sets = [set(ssm_statistics[mut].keys()) for mut in mutation_list]
    shared_projects = set.intersection(*project_sets)
    computed = {}
    for proj in shared_projects:
        print('getting shared cases...')
        case_sets = [
            set(ssm_statistics[mut][proj]["case_id_list"])
            for mut in mutation_list
        ]
        shared_cases = set.intersection(*case_sets)
        print('number of shared cases: {}'.format(len(shared_cases)))
        if shared_cases:
            pct = len(shared_cases) / total_case_count[proj] * 100
            computed[proj] = {"joint_frequency": round(pct, 2)}
    # copy computed values onto the requested cancers, keeping 0.0 otherwise
    for cancer in cancer_entities:
        if cancer in computed:
            per_cancer_result[cancer]["joint_frequency"] = computed[cancer][
                "joint_frequency"
            ]
    return per_cancer_result
def flatten_ssm_results_to_text(result, result_type):
    """Flatten nested SSM frequency dicts into human-readable sentences.

    Args:
        result: dict of mutation name (or the special "joint_frequency"
            key) -> project id -> frequency dict.
        result_type: "joint_frequency" to render only the joint entry;
            anything else renders every per-mutation entry.

    Returns:
        list of sentence strings, one per (mutation, project) pair.
    """
    result_text = []
    print('preparing a GDC Result for query augmentation...')
    if result_type == "joint_frequency":
        for k, v in result.items():
            if k == "joint_frequency":
                for k2, v2 in v.items():
                    gdc_result = "joint frequency in {} is {}%".format(k2, v2["joint_frequency"])
                    result_text.append(gdc_result)
    else:
        for k, v in result.items():
            # skip the joint entry (may be present as a 0 placeholder)
            if k != "joint_frequency":
                for k2, v2 in v.items():
                    gdc_result = "The frequency of {} in {} is {}%".format(k, k2, v2["frequency"])
                    result_text.append(gdc_result)
    # bug fix: the original printed `gdc_result` unconditionally, which
    # raised NameError when nothing was produced; guard and show the last.
    if result_text:
        print('prepared GDC Result: {}'.format(result_text[-1]))
    return result_text
def get_ssm_frequency(
    gene_entities, mutation_entities, cancer_entities, project_mappings
):
    """Query the GDC API for SSM frequencies of gene/mutation pairs.

    Pairs genes with mutations positionally, fetches per-project SSM
    counts via gdc_api_calls, and converts them to percentage sentences.
    When more than one mutation is requested, a joint (co-occurrence)
    frequency across all mutations is also computed.

    Args:
        gene_entities: gene symbols detected in the query.
        mutation_entities: mutation names detected in the query.
        cancer_entities: GDC project ids to report on.
        project_mappings: project id -> description mapping (passed through
            to calculate_ssm_frequency).

    Returns:
        (result_text, cancer_entities): flattened sentence list plus the
        (unchanged) cancer entity list.
    """
    ssm_statistics = {}
    mutation_list = []
    result = {}
    total_case_count = {}
    for ce in cancer_entities:
        total_case_count[ce] = gdc_api_calls.get_available_ssm_data_for_project(ce)
    # to match the genes with mutations
    # NOTE(review): repeating the gene list assumes all mutations belong to
    # the single detected gene when genes < mutations — TODO confirm
    if len(mutation_entities) > len(gene_entities):
        gene_entities = gene_entities * len(mutation_entities)
    for gene, mutation in zip(gene_entities, mutation_entities):
        mutation_name = "_".join([gene, mutation])
        mutation_list.append(mutation_name)
        ssm_id = gdc_api_calls.get_ssm_id(gene, mutation)
        ssm_counts_by_project = gdc_api_calls.get_ssm_counts(ssm_id, cancer_entities)
        ssm_statistics[mutation_name] = ssm_counts_by_project
        result[mutation_name] = calculate_ssm_frequency(
            ssm_statistics[mutation_name], mutation_name, total_case_count, cancer_entities, project_mappings
        )
    print('\nStep 5: Query GDC and process results\n')
    for mut in mutation_list:
        print('mutation: {}'.format(mut))
        for ce in cancer_entities:
            try:
                print('number of cases with mutation: {}'.format(
                    ssm_statistics[mut][ce]["ssm_counts"]))
            except Exception as e:
                # if no ssms, this will be empty
                print('number of cases with mutation: {}'.format(
                    ssm_statistics
                ))
            print('total case count: {}'.format(
                total_case_count[ce]))
    # only supporting for two mutations atm
    if len(mutation_list) > 1:
        result["joint_frequency"] = calculate_joint_ssm_frequency(
            ssm_statistics, total_case_count, mutation_list, cancer_entities
        )
        result_text = flatten_ssm_results_to_text(result, result_type="joint_frequency")
    else:
        # placeholder so downstream consumers always find the key
        result["joint_frequency"] = 0
        result_text = flatten_ssm_results_to_text(
            result, result_type="single_frequency"
        )
    return result_text, cancer_entities
def decompose_mutation_and_cnv(query, match_term, gdc_genes_mutations):
    """Split a combined CNV + mutation query into its two gene components.

    The query is expected to mention the CNV gene first, then the mutation
    gene (e.g. "EGFR amplification and TP53 R175H").

    Args:
        query: raw user query string.
        match_term: the CNV keyword found in the query (e.g. "amplification").
        gdc_genes_mutations: dict whose keys are the known gene symbols.

    Returns:
        dict with keys "cnv_and_ssm", "cnv_gene", "mut_gene",
        "cnv_change_type".

    Raises:
        ValueError: if fewer than two known genes appear in the query
            (previously this crashed with an uninformative IndexError).
    """
    genes = [tok for tok in query.split(" ") if tok in gdc_genes_mutations]
    if len(genes) < 2:
        raise ValueError(
            "expected a CNV gene and a mutation gene in query, found: {}".format(genes)
        )
    # query must have cnv first, followed by mutation
    cnv_gene_name, mut_gene_name = genes[0], genes[1]
    return {
        "cnv_and_ssm": True,
        "cnv_gene": cnv_gene_name,
        "mut_gene": mut_gene_name,
        "cnv_change_type": match_term,
    }
def get_freq_of_cnv_and_ssms(
    query, cancer_entities, gene_entities, gdc_genes_mutations
):
    """Route a CNV-related query to the appropriate GDC API helper.

    When the query names a specific copy-number change, it is decomposed
    into a CNV gene + mutation gene and sent to the CNV/SSM endpoint;
    otherwise a generic per-gene case-count lookup is performed.

    Args:
        query: raw user query string.
        cancer_entities: GDC project ids in scope.
        gene_entities: gene symbols detected in the query.
        gdc_genes_mutations: known gene -> mutations lookup.

    Returns:
        (result, cancer_entities) as produced by the chosen API helper.
    """
    lc_query = query.lower()
    cnv_terms = (
        "amplification",
        "deletion",
        "loss",
        "gain",
        "homozygous deletion",
        "heterozygous deletion",
    )
    # last match wins: the more specific deletion variants are listed
    # after the generic "deletion", so they override it when present
    match_term = ""
    for term in cnv_terms:
        if term in lc_query:
            match_term = term
    if match_term:
        decompose_result = decompose_mutation_and_cnv(
            query, match_term, gdc_genes_mutations
        )
        result, cancer_entities = gdc_api_calls.run_cnv_ssm_api(
            decompose_result, cancer_entities, query
        )
    else:
        # no specific match terms, return freq of cnvs + ssm
        result, cancer_entities = gdc_api_calls.get_top_cases_counts_by_gene(
            gene_entities, cancer_entities
        )
    return result, cancer_entities
def return_initial_cancer_entities(query, model):
    """Run a spaCy NER model over the query and keep DISEASE entities.

    Args:
        query: raw user query string.
        model: spaCy model name or path accepted by spacy.load.

    Returns:
        list of entity text spans labelled DISEASE.
    """
    pipeline = spacy.load(model)
    parsed = pipeline(query)
    return [ent.text for ent in parsed.ents if ent.label_ == "DISEASE"]
def infer_gene_entities_from_query(query, gdc_genes_mutations):
    """Detect known gene symbols mentioned as whole tokens in the query.

    Simple dictionary-based recognition: a gene matches only if it appears
    as a space-delimited token (so "EGFRX" does not match "EGFR").

    Args:
        query: raw user query string.
        gdc_genes_mutations: dict whose keys are the known gene symbols.

    Returns:
        list of matching gene symbols, in the lookup dict's key order.
    """
    # fix: the original re-split the query for every gene and also ran a
    # redundant substring test; tokenize once into a set for O(1) lookups
    query_tokens = set(query.split(" "))
    return [g for g in gdc_genes_mutations if g in query_tokens]
def check_if_project_id_in_query(project_list, query):
    """Return query tokens that are literal GDC project ids.

    e.g. a query containing "TCGA-BRCA" yields ["TCGA-BRCA"].

    Args:
        project_list: iterable of valid GDC project ids.
        query: raw user query string.

    Returns:
        list of tokens from the query that match a project id, in order.
    """
    final_entities = []
    for token in query.split(" "):
        if token in project_list:
            final_entities.append(token)
    return final_entities
def proj_id_and_partial_match(query, project_mappings, initial_cancer_entities):
    """Map cancer mentions to GDC project ids via substring matching.

    Prefers the NER-derived entities when present (e.g. matching
    "ovarian serous cystadenocarcinoma" to TCGA-OV); otherwise falls back
    to matching individual lowercased query terms against the project
    descriptions.

    Args:
        query: raw user query string.
        project_mappings: dict of project id -> list of description strings.
        initial_cancer_entities: NER-derived disease mentions (may be empty).

    Returns:
        de-duplicated list of matching project ids (unordered).
    """
    if initial_cancer_entities:
        # match each NER entity against every project description
        needles = initial_cancer_entities
    else:
        # no NER entities: match each query term instead
        needles = query.lower().split(" ")
    matched = set()
    for needle in needles:
        for project_id, descriptions in project_mappings.items():
            for description in descriptions:
                if needle in description.lower():
                    matched.add(project_id)
    return list(matched)
def postprocess_cancer_entities(project_mappings, initial_cancer_entities, query):
    """Resolve cancer mentions in a query to GDC project ids.

    Resolution order:
      1. literal project ids mentioned in the query (e.g. "TCGA-BRCA");
      2. GDC projects-endpoint mapping of the NER-derived entities;
      3. partial substring matching against project descriptions.

    Args:
        project_mappings: dict of project id -> description strings.
        initial_cancer_entities: NER-derived disease mentions (may be empty).
        query: raw user query string.

    Returns:
        list of resolved project ids (possibly empty).
    """
    # 1) a verbatim project id in the query wins outright
    final_entities = check_if_project_id_in_query(project_mappings.keys(), query)
    if final_entities:
        return final_entities
    if not initial_cancer_entities:
        # no NER entities: go straight to partial description matching
        return proj_id_and_partial_match(
            query, project_mappings, initial_cancer_entities
        )
    # 2) ask the GDC projects endpoint to map the NER entities
    gdc_project_match = gdc_api_calls.map_cancer_entities_to_project(
        initial_cancer_entities, project_mappings
    )
    if gdc_project_match.values():
        final_entities = list(gdc_project_match.values())
    if not final_entities:
        # 3) endpoint found nothing: fall back to partial matching
        final_entities = proj_id_and_partial_match(
            query, project_mappings, initial_cancer_entities
        )
    return final_entities
def infer_mutation_entities(gene_entities, query, gdc_genes_mutations):
    """Collect known mutations of the detected genes that appear in the query.

    Args:
        gene_entities: gene symbols already detected in the query.
        query: raw user query string.
        gdc_genes_mutations: dict of gene symbol -> list of mutation names.

    Returns:
        list of mutation names found as substrings of the query, in
        gene-then-mutation order.
    """
    return [
        mutation
        for gene in gene_entities
        for mutation in gdc_genes_mutations[gene]
        if mutation in query
    ]
def set_hf_token(token_path):
    """Read a Hugging Face token from a file and persist it via HfFolder.

    Args:
        token_path: path to a text file containing the token (surrounding
            whitespace is stripped).
    """
    with open(token_path, "r") as token_file:
        hf_token = token_file.read().strip()
    HfFolder.save_token(hf_token)
def get_final_columns():
    """Return the column names for the final output CSV, in write order."""
    return [
        "questions",
        "gene_entities",
        "mutation_entities",
        "cancer_entities",
        "intent",
        "llama_base_output",
        "llama_base_stat",
        "gdc_result",
        "gdc_qag_base_stat",
        "descriptive_prompt",
        "percentage_prompt",
        "final_gdc_qag_desc_response",
        "final_gdc_qag_percentage_response",
        "final_gdc_qag_response",
    ]
def timeit(fn):
    """Decorator that prints the wall-clock runtime of each call to `fn`.

    Bug fix: `wraps` was imported at the top of the file but never applied,
    so decorated functions lost their `__name__`/`__doc__`; apply it so
    introspection (and the printed name) stays correct.

    Args:
        fn: the callable to time.

    Returns:
        a wrapper that forwards all arguments and returns fn's result.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        end = time.perf_counter()
        print(f"{fn.__name__} took {end - start:.4f} seconds")
        return result
    return wrapper