# Hugging Face Spaces page header ("Spaces: Sleeping") — scrape residue, not code.
| import os | |
| from types import SimpleNamespace | |
| import gradio as gr | |
| import json | |
| import pandas as pd | |
| import spaces | |
| import torch | |
| from methods import gdc_api_calls, utilities | |
| from transformers import AutoTokenizer, BertTokenizer, AutoModelForCausalLM, BertForSequenceClassification | |
| from guidance import gen as guidance_gen | |
| from guidance.models import Transformers | |
| from transformers import set_seed | |
| from methods import gdc_api_calls, utilities | |
# set up various tokens
# NOTE(review): os.environ.get falls back to False (a bool) rather than None
# when a variable is unset, so downstream from_pretrained calls receive
# token=False, i.e. "no auth token" — confirm this is intentional.
working_llama_token = os.environ.get("let_this_please_work", False)
hf_TOKEN = os.environ.get("fineTest", False)
intent_token = os.environ.get("query_intent_test", False)
# Example questions surfaced in the Gradio "Examples" widget below.
EXAMPLE_INPUTS = [
    "What is the co-occurence frequency of somatic homozygous deletions in CDKN2A and CDKN2B in the mesothelioma project TCGA-MESO in the genomic data commons?",
    "What is the co-occurence frequency of somatic heterozygous deletions in BRCA2 and NF1 in the Kidney Chromophobe TCGA-KICH project in the genomic data commons?",
    "What percentage of ovarian serous cystadenocarcinoma cases have a somatic heterozygous deletion in BRCA1 and simple somatic mutations in BRCA1 in the genomic data commons?",
    "What fraction of cases have simple somatic mutations or copy number variants in ALK in Uterine Carcinosarcoma TCGA-UCS project in the genomic data commons?",
    "How often is microsatellite instability observed in Stomach Adenocarcinoma TCGA-STAD project in the genomic data commons?",
    "How often is the BRAF V600E mutation found in Skin Cutaneous Melanoma TCGA-SKCM project in the genomic data commons?",
    "What is the co-occurence frequency of IDH1 R132H and TP53 R273C simple somatic mutations in the low grade glioma project TCGA-LGG in the genomic data commons?"
]
# Short display labels for the examples above; order-aligned with
# EXAMPLE_INPUTS (passed as example_labels to gr.Examples).
EXAMPLE_LABELS = [
    "combination homozygous deletions",
    "combination heterozygous deletions",
    "heterozygous deletion and somatic mutations",
    "copy number variants or somatic mutations",
    "microsatellite-instability",
    "simple somatic mutation",
    "combination somatic mutations"
]
# set up requirements: models and data
print("getting gdc project information")
# Fetched once at startup; used as a dict later (list(project_mappings.keys())
# supplies the all-projects fallback) — presumably project-name -> id mapping.
project_mappings = gdc_api_calls.get_gdc_project_ids(start=0, stop=86)
print('loading intent model and tokenizer')
model_id = 'uc-ctds/query_intent'
# BERT sequence classifier that maps a free-text question to an intent class
# (see infer_user_intent for the label mapping).
intent_tok = AutoTokenizer.from_pretrained(
    model_id, trust_remote_code=True,
    token=intent_token
)
intent_model = BertForSequenceClassification.from_pretrained(
    model_id, token=intent_token)
intent_model = intent_model.to('cuda').eval()  # inference only; requires a GPU
print("loading gdc genes and mutations")
gdc_genes_mutations = utilities.load_gdc_genes_mutations_hf(hf_TOKEN)
print("loading llama-3B model and tokenizer")
# `model_id` is deliberately reused here for the generation model.
model_id = "meta-llama/Llama-3.2-3B-Instruct"
tok = AutoTokenizer.from_pretrained(
    model_id, trust_remote_code=True,
    token=working_llama_token
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  # half precision to reduce GPU memory
    trust_remote_code=True,
    token=working_llama_token
)
model = model.to('cuda').eval()
# execute_api_call
def execute_api_call(
    intent,
    gene_entities,
    mutation_entities,
    cancer_entities,
    query
):
    """Route an inferred intent to the matching GDC helper.

    Returns a ``(result, cancer_entities)`` pair. Helpers that resolve or
    narrow the project list return their own updated ``cancer_entities``;
    for the remaining intents the caller's list is passed back unchanged.
    Unrecognized intents yield an explanatory message string as the result.
    """
    if intent == "ssm_frequency":
        return utilities.get_ssm_frequency(
            gene_entities, mutation_entities, cancer_entities, project_mappings
        )
    if intent == "top_mutated_genes_by_project":
        top_genes = gdc_api_calls.get_top_mutated_genes_by_project(
            cancer_entities, top_k=10
        )
        return top_genes, cancer_entities
    if intent == "most_frequently_mutated_gene":
        # Same helper as above, restricted to the single top gene.
        top_gene = gdc_api_calls.get_top_mutated_genes_by_project(
            cancer_entities, top_k=1
        )
        return top_gene, cancer_entities
    if intent == "freq_cnv_loss_or_gain":
        return gdc_api_calls.get_freq_cnv_loss_or_gain(
            gene_entities, cancer_entities, query, cnv_and_ssm_flag=False
        )
    if intent == "msi_h_frequency":
        return gdc_api_calls.get_msi_frequency(cancer_entities)
    if intent == "cnv_and_ssm":
        return utilities.get_freq_of_cnv_and_ssms(
            query, cancer_entities, gene_entities, gdc_genes_mutations
        )
    if intent == "top_cases_counts_by_gene":
        return gdc_api_calls.get_top_cases_counts_by_gene(
            gene_entities, cancer_entities
        )
    if intent == "project_summary":
        summary = gdc_api_calls.get_project_summary(cancer_entities)
        return summary, cancer_entities
    # Fallback: intent not recognized or not yet supported.
    return "user intent not recognized, or use case not covered", cancer_entities
def infer_user_intent(query):
    """Classify ``query`` into one of the supported GDC intents.

    Runs the BERT intent classifier (module-level ``intent_model`` /
    ``intent_tok``) and maps the argmax class index to an intent name.

    Returns the intent name (str), or ``None`` when the predicted class
    index has no mapping — the caller's dispatch treats ``None`` as
    "intent not recognized".
    """
    # Direct index -> name map replaces the original linear scan over a
    # name -> float-index dict.
    index_to_intent = {
        0: "ssm_frequency",
        1: "msi_h_frequency",
        2: "freq_cnv_loss_or_gain",
        3: "top_cases_counts_by_gene",
        4: "cnv_and_ssm",
    }
    inputs = intent_tok(query, return_tensors="pt", truncation=True, padding=True)
    inputs = {k: v.to('cuda') for k, v in inputs.items()}
    # Inference only: avoid building an autograd graph.
    with torch.no_grad():
        outputs = intent_model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=1)
    predicted_label = torch.argmax(probs, dim=1).item()
    # Explicit None (previously an implicit fall-through) for unmapped classes.
    return index_to_intent.get(predicted_label)
# function to combine entities, intent and API call
def construct_and_execute_api_call(query):
    """Infer entities and intent from ``query``, execute the matching GDC
    API call, and return a SimpleNamespace with fields ``helper_output``,
    ``cancer_entities``, ``intent``, ``gene_entities`` and
    ``mutation_entities``. On API failure the output lists are empty.
    """
    print("query:\n{}\n".format(query))
    # Infer entities
    initial_cancer_entities = utilities.return_initial_cancer_entities(
        query, model="en_ner_bc5cdr_md"
    )
    # Fall back to a second NER model when the first finds nothing.
    if not initial_cancer_entities:
        try:
            initial_cancer_entities = utilities.return_initial_cancer_entities(
                query, model="en_core_sci_md"
            )
        except Exception as e:
            print("unable to guess cancer entities {}".format(str(e)))
            initial_cancer_entities = []
    cancer_entities = utilities.postprocess_cancer_entities(
        project_mappings, initial_cancer_entities=initial_cancer_entities, query=query
    )
    # if cancer entities is empty from above methods return all projects
    if not cancer_entities:
        cancer_entities = list(project_mappings.keys())
    gene_entities = utilities.infer_gene_entities_from_query(query, gdc_genes_mutations)
    mutation_entities = utilities.infer_mutation_entities(
        gene_entities=gene_entities,
        query=query,
        gdc_genes_mutations=gdc_genes_mutations,
    )
    print("gene entities {}".format(gene_entities))
    print("mutation entities {}".format(mutation_entities))
    print("cancer entities {}".format(cancer_entities))
    # infer user intent
    intent = infer_user_intent(query)
    print("user intent:\n{}\n".format(intent))
    try:
        # execute_api_call may also refine cancer_entities.
        api_call_result, cancer_entities = execute_api_call(
            intent,
            gene_entities,
            mutation_entities,
            cancer_entities,
            query
        )
        print("api_call_result {}".format(api_call_result))
    except Exception as e:
        # Best effort: empty results signal failure to the caller.
        print("unable to process query {} {}".format(query, str(e)))
        api_call_result = []
        cancer_entities = []
    return SimpleNamespace(
        helper_output=api_call_result,
        cancer_entities=cancer_entities,
        intent=intent,
        gene_entities=gene_entities,
        mutation_entities=mutation_entities,
    )
# generate llama model response
def generate_response(modified_query):
    """Generate a llama response for ``modified_query`` with guidance.

    Generation is constrained by a regex so the output ends with an exact
    percentage ("The final answer is: X.Y%"). Uses the module-level
    ``model``/``tok`` pair; temperature 0 gives deterministic decoding.

    Returns the generated text captured under the "gen_response" key.
    """
    # Raw string: the original non-raw "\d" relies on Python keeping the
    # unrecognized escape (now a SyntaxWarning); value is identical.
    regex = r"The final answer is: \d*\.\d*%"
    lm = Transformers(model=model, tokenizer=tok)
    lm += modified_query
    lm += guidance_gen(
        "gen_response",
        n=1,
        temperature=0,
        max_tokens=1000,
        regex=regex
    )
    return lm["gen_response"]
def batch_test(query):
    """Run the baseline llama generation and the GDC retrieval pipeline
    for one ``query``.

    Returns a pd.Series of six fields, order-aligned with the column list
    assigned in execute_pipeline: llama_base_output, helper_output,
    cancer_entities, intent, gene_entities, mutation_entities.
    """
    modified_query = utilities.construct_modified_query_base_llm(query)
    print(f"modified_query is: {modified_query}")
    llama_base_output = generate_response(modified_query)
    print(f"llama_base_output: {llama_base_output}")
    try:
        result = construct_and_execute_api_call(query)
    except Exception as e:
        # BUG FIX: the previous handler assigned result.helper_output while
        # `result` was unbound (NameError) and left intent/gene/mutation
        # fields missing for the pd.Series below. Build a full empty
        # placeholder instead.
        print("unable to process query {} {}".format(query, str(e)))
        result = SimpleNamespace(
            helper_output=[],
            cancer_entities=[],
            intent=None,
            gene_entities=[],
            mutation_entities=[],
        )
    # if there is not a helper output for each unique cancer entity
    # log error to inspect and reprocess query later
    try:
        # BUG FIX: the original bare comparison `len(a) == len(b)` was a
        # no-op; mismatches were silently ignored. Enforce the documented
        # intent: on mismatch (or un-sized outputs) log and reset both.
        mismatch = len(result.helper_output) != len(result.cancer_entities)
    except Exception:
        mismatch = True
    if mismatch:
        msg = "there is not a unique helper output for each unique cancer entity in {}".format(
            query
        )
        print("exception {}".format(msg))
        result.helper_output = []
        result.cancer_entities = []
    return pd.Series(
        [
            llama_base_output,
            result.helper_output,
            result.cancer_entities,
            result.intent,
            result.gene_entities,
            result.mutation_entities,
        ]
    )
def get_prefinal_response(row):
    """Build the helper-augmented prompt for one exploded DataFrame row and
    generate the pre-final llama response.

    Expects ``row`` to carry "questions" and "helper_output" keys.
    Returns a pd.Series of (modified_query, prefinal generation).

    Raises KeyError when either key is missing — continuing without them
    cannot succeed.
    """
    try:
        query = row["questions"]
        helper_output = row["helper_output"]
    except KeyError as e:
        # BUG FIX: the previous handler printed `query`/`helper_output`,
        # which are unbound exactly when the lookup fails (NameError inside
        # the handler), then fell through to use them anyway. Report the
        # missing key and re-raise.
        print(f"unable to retrieve query or helper_output: missing key {e}")
        raise
    modified_query = utilities.construct_modified_query(query, helper_output)
    prefinal_llama_with_helper_output = generate_response(modified_query)
    return pd.Series([modified_query, prefinal_llama_with_helper_output])
def execute_pipeline(question: str):
    """Run the full query-augmented-generation (QAG) pipeline for one question.

    Steps: baseline llama generation + GDC retrieval (batch_test), one
    augmented generation per helper output (get_prefinal_response),
    postprocessing (utilities.postprocess_response), then formatting of a
    human-readable result string for the Gradio textbox.
    """
    df = pd.DataFrame({'questions' : [question]})
    print(f'Question received: {question}')
    print("starting pipeline")
    print("CUDA available:", torch.cuda.is_available())
    # BUG FIX: get_device_name(0) raises on CPU-only hosts; guard it.
    if torch.cuda.is_available():
        print("CUDA device name:", torch.cuda.get_device_name(0))
    # queries input file
    print(f"running test on input {df}")
    df[
        [
            "llama_base_output",
            "helper_output",
            "cancer_entities",
            "intent",
            "gene_entities",
            "mutation_entities",
        ]
    ] = df["questions"].apply(lambda x: batch_test(x))
    # One row per (question, helper_output) element.
    df_exploded = df.explode("helper_output", ignore_index=True)
    df_exploded[["modified_prompt", "pre_final_llama_with_helper_output"]] = (
        df_exploded.apply(
            lambda x: get_prefinal_response(x), axis=1
        )
    )
    ### postprocess response
    print("postprocessing response")
    df_exploded[
        [
            "llama_base_stat",
            "delta_llama",
            "value_changed",
            "ground_truth_stat",
            "generated_stat_prefinal",
            "delta_prefinal",
            "generated_stat_final",
            "delta_final",
            "final_response",
        ]
    ] = df_exploded.apply(
        lambda x: utilities.postprocess_response(x), axis=1
    )
    final_columns = utilities.get_final_columns()
    # BUG FIX: rename on a fresh frame instead of inplace on a column
    # selection, avoiding pandas SettingWithCopy behavior.
    result = df_exploded[final_columns].rename(columns={
        'llama_base_output': 'llama-3B baseline output',
        'modified_prompt': 'Query augmented prompt',
        'helper_output': 'Processed GDC API result',
        'ground_truth_stat': 'Ground truth frequency from GDC',
        'llama_base_stat': 'llama-3B baseline frequency',
        'delta_llama': 'llama-3B frequency - Ground truth frequency',
        'final_response': 'Query augmented generation',
        'intent': 'Intent',
        'cancer_entities': 'Cancer entities',
        'gene_entities': 'Gene entities',
        'mutation_entities': 'Mutation entities',
        'questions' : 'Question'
    })
    result.index = ['QAG pipeline results'] * len(result)
    print('completed')
    print('writing result string now')
    # NOTE(review): with several exploded rows the duplicate index labels
    # collapse in to_dict(), so only one row survives — confirm single-row
    # output is the intended behavior.
    result = result.T.to_dict()
    print('result {}'.format(result))
    result_string = ""
    result_string += f"Question: {result['QAG pipeline results']['Question']}\n"
    result_string += f"llama-3B baseline output: {result['QAG pipeline results']['llama-3B baseline frequency']}%\n"
    result_string += f"Query augmented prompt: {result['QAG pipeline results']['Query augmented prompt']}"
    result_string += f"Query augmented generation: {result['QAG pipeline results']['Query augmented generation']}"
    return result_string
def visible_component(input_text):
    """Return a Gradio update setting a component's value to "WHATEVER".

    NOTE(review): not wired to any event in this file — appears unused.
    """
    new_value = "WHATEVER"
    return gr.update(value=new_value)
# Create Gradio interface
with gr.Blocks(title="GDC QAG MCP server") as GDC_QAG_QUERY:
    gr.Markdown(
        """
    # GDC QAG Service
    """
    )
    with gr.Row():
        # Free-text question input; the Examples widget below fills it.
        query_input = gr.Textbox(
            lines = 3,
            label="Search Query",
            placeholder='e.g. "What is the co-occurence frequency of somatic homozygous deletions in CDKN2A and CDKN2B in the mesothelioma project TCGA-MESO in the genomic data commons?"',
            info="Required: Enter your search query. Click on Examples to execute example queries. Please retry query if API is unavailable or connection aborts.",
        )
    # Clickable example queries with short labels (order-aligned lists).
    gr.Examples(
        examples=EXAMPLE_INPUTS,
        inputs=query_input,
        example_labels = EXAMPLE_LABELS
    )
    execute_button = gr.Button("Execute", variant="primary")
    output = gr.Textbox(
        label="Query Result",
        lines=10,
        max_lines=25,
        info="The Result of the Query will appear here",
    )
    # Wire the button to the full QAG pipeline.
    execute_button.click(
        fn=execute_pipeline,
        inputs=[query_input],
        outputs=output,
    )

if __name__ == "__main__":
    # mcp_server=True exposes the app as an MCP tool endpoint as well.
    GDC_QAG_QUERY.launch(mcp_server=True, show_api=True)