GDC-QAG / app.py
aatu18's picture
update instructions
e365b79 verified
raw
history blame
13.4 kB
import os
from types import SimpleNamespace
import gradio as gr
import json
import pandas as pd
import spaces
import torch
from methods import gdc_api_calls, utilities
from transformers import AutoTokenizer, BertTokenizer, AutoModelForCausalLM, BertForSequenceClassification
from guidance import gen as guidance_gen
from guidance.models import Transformers
from transformers import set_seed
from methods import gdc_api_calls, utilities
# set up various tokens
# NOTE(review): os.environ.get(..., False) yields False (not None) when a
# secret is missing, so downstream loaders receive token=False in that case.
working_llama_token = os.environ.get("let_this_please_work", False)
hf_TOKEN = os.environ.get("fineTest", False)
intent_token = os.environ.get("query_intent_test", False)
# Example questions shown in the Gradio "Examples" widget below.
EXAMPLE_INPUTS = [
    "What is the co-occurence frequency of somatic homozygous deletions in CDKN2A and CDKN2B in the mesothelioma project TCGA-MESO in the genomic data commons?",
    "What is the co-occurence frequency of somatic heterozygous deletions in BRCA2 and NF1 in the Kidney Chromophobe TCGA-KICH project in the genomic data commons?",
    "What percentage of ovarian serous cystadenocarcinoma cases have a somatic heterozygous deletion in BRCA1 and simple somatic mutations in BRCA1 in the genomic data commons?",
    "What fraction of cases have simple somatic mutations or copy number variants in ALK in Uterine Carcinosarcoma TCGA-UCS project in the genomic data commons?",
    "How often is microsatellite instability observed in Stomach Adenocarcinoma TCGA-STAD project in the genomic data commons?",
    "How often is the BRAF V600E mutation found in Skin Cutaneous Melanoma TCGA-SKCM project in the genomic data commons?",
    "What is the co-occurence frequency of IDH1 R132H and TP53 R273C simple somatic mutations in the low grade glioma project TCGA-LGG in the genomic data commons?"
]
# Short display labels for the examples above; order matches EXAMPLE_INPUTS.
EXAMPLE_LABELS = [
    "combination homozygous deletions",
    "combination heterozygous deletions",
    "heterozygous deletion and somatic mutations",
    "copy number variants or somatic mutations",
    "microsatellite-instability",
    "simple somatic mutation",
    "combination somatic mutations"
]
# set up requirements: models and data
# Project id -> metadata mapping for the first 86 GDC projects; used for
# cancer-entity post-processing and as the fallback entity list.
print("getting gdc project information")
project_mappings = gdc_api_calls.get_gdc_project_ids(start=0, stop=86)
# Fine-tuned BERT classifier that maps a user query to one of the
# supported intents (see infer_user_intent).
print('loading intent model and tokenizer')
model_id = 'uc-ctds/query_intent'
intent_tok = AutoTokenizer.from_pretrained(
    model_id, trust_remote_code=True,
    token=intent_token
)
intent_model = BertForSequenceClassification.from_pretrained(
    model_id, token=intent_token)
intent_model = intent_model.to('cuda').eval()
# Lookup data of GDC genes and their known mutations, fetched from HF.
print("loading gdc genes and mutations")
gdc_genes_mutations = utilities.load_gdc_genes_mutations_hf(hf_TOKEN)
# Base generator model used for both the baseline and augmented answers.
print("loading llama-3B model and tokenizer")
model_id = "meta-llama/Llama-3.2-3B-Instruct"
tok = AutoTokenizer.from_pretrained(
    model_id, trust_remote_code=True,
    token=working_llama_token
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    token=working_llama_token
)
model = model.to('cuda').eval()
# execute_api_call
def execute_api_call(
    intent,
    gene_entities,
    mutation_entities,
    cancer_entities,
    query
):
    """Route an inferred intent to the matching GDC helper call.

    Returns a ``(result, cancer_entities)`` pair. Helpers that resolve
    their own cancer entities replace the list passed in; the remaining
    branches echo the caller's ``cancer_entities`` back unchanged.
    """
    if intent == "ssm_frequency":
        return utilities.get_ssm_frequency(
            gene_entities, mutation_entities, cancer_entities, project_mappings
        )
    if intent == "top_mutated_genes_by_project":
        top_genes = gdc_api_calls.get_top_mutated_genes_by_project(
            cancer_entities, top_k=10
        )
        return top_genes, cancer_entities
    if intent == "most_frequently_mutated_gene":
        # same helper as above, restricted to the single top gene
        top_gene = gdc_api_calls.get_top_mutated_genes_by_project(
            cancer_entities, top_k=1
        )
        return top_gene, cancer_entities
    if intent == "freq_cnv_loss_or_gain":
        return gdc_api_calls.get_freq_cnv_loss_or_gain(
            gene_entities, cancer_entities, query, cnv_and_ssm_flag=False
        )
    if intent == "msi_h_frequency":
        return gdc_api_calls.get_msi_frequency(cancer_entities)
    if intent == "cnv_and_ssm":
        return utilities.get_freq_of_cnv_and_ssms(
            query, cancer_entities, gene_entities, gdc_genes_mutations
        )
    if intent == "top_cases_counts_by_gene":
        return gdc_api_calls.get_top_cases_counts_by_gene(
            gene_entities, cancer_entities
        )
    if intent == "project_summary":
        summary = gdc_api_calls.get_project_summary(cancer_entities)
        return summary, cancer_entities
    # fall-through: classifier produced an intent we do not handle
    return "user intent not recognized, or use case not covered", cancer_entities
@spaces.GPU(duration=10)
def infer_user_intent(query):
    """Classify a user query into one of the supported GDC intents.

    Runs the fine-tuned BERT intent classifier on ``query`` and maps the
    argmax class index back to its intent name. Returns ``None`` when the
    predicted index has no known label.
    """
    index_to_intent = {
        0.0: "ssm_frequency",
        1.0: "msi_h_frequency",
        2.0: "freq_cnv_loss_or_gain",
        3.0: "top_cases_counts_by_gene",
        4.0: "cnv_and_ssm",
    }
    encoded = intent_tok(query, return_tensors="pt", truncation=True, padding=True)
    encoded = {name: tensor.to('cuda') for name, tensor in encoded.items()}
    logits = intent_model(**encoded).logits
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    predicted_index = torch.argmax(probabilities, dim=1).item()
    # int index hashes/compares equal to the float keys (0 == 0.0),
    # so a plain dict lookup reproduces the original reverse search.
    return index_to_intent.get(predicted_index)
# function to combine entities, intent and API call
def construct_and_execute_api_call(query):
    """Infer entities and intent from ``query``, then run the GDC call.

    Returns a SimpleNamespace with helper_output, cancer_entities,
    intent, gene_entities and mutation_entities.
    """
    print("query:\n{}\n".format(query))
    # Cancer entities: try the bc5cdr NER model first, fall back to the
    # general scientific model when it finds nothing.
    raw_cancer_entities = utilities.return_initial_cancer_entities(
        query, model="en_ner_bc5cdr_md"
    )
    if not raw_cancer_entities:
        try:
            raw_cancer_entities = utilities.return_initial_cancer_entities(
                query, model="en_core_sci_md"
            )
        except Exception as e:
            print("unable to guess cancer entities {}".format(str(e)))
            raw_cancer_entities = []
    cancer_entities = utilities.postprocess_cancer_entities(
        project_mappings, initial_cancer_entities=raw_cancer_entities, query=query
    )
    # if cancer entities is empty from above methods return all projects
    if not cancer_entities:
        cancer_entities = list(project_mappings.keys())
    gene_entities = utilities.infer_gene_entities_from_query(query, gdc_genes_mutations)
    mutation_entities = utilities.infer_mutation_entities(
        gene_entities=gene_entities,
        query=query,
        gdc_genes_mutations=gdc_genes_mutations,
    )
    print("gene entities {}".format(gene_entities))
    print("mutation entities {}".format(mutation_entities))
    print("cancer entities {}".format(cancer_entities))
    intent = infer_user_intent(query)
    print("user intent:\n{}\n".format(intent))
    # Execute the API call; on any failure fall back to empty results so
    # the caller can still build a (degraded) response.
    try:
        api_call_result, cancer_entities = execute_api_call(
            intent,
            gene_entities,
            mutation_entities,
            cancer_entities,
            query
        )
        print("api_call_result {}".format(api_call_result))
    except Exception as e:
        print("unable to process query {} {}".format(query, str(e)))
        api_call_result = []
        cancer_entities = []
    return SimpleNamespace(
        helper_output=api_call_result,
        cancer_entities=cancer_entities,
        intent=intent,
        gene_entities=gene_entities,
        mutation_entities=mutation_entities,
    )
# generate llama model response
@spaces.GPU(duration=30)
def generate_response(modified_query):
    """Generate a regex-constrained llama-3B completion for a prompt.

    Uses guidance's constrained decoding so the completion always ends
    with a sentence of the form "The final answer is: <float>%".

    Args:
        modified_query: fully-assembled prompt string.

    Returns:
        The generated completion text (str).
    """
    #set_seed(1042)
    # BUG FIX: raw string literal — "\d" in a plain string is an invalid
    # escape sequence (SyntaxWarning on Python 3.12+, error in future versions).
    regex = r"The final answer is: \d*\.\d*%"
    lm = Transformers(model=model, tokenizer=tok)
    lm += modified_query
    lm += guidance_gen(
        "gen_response",
        n=1,
        temperature=0,  # greedy decoding for reproducible answers
        max_tokens=1000,
        regex=regex
    )
    return lm["gen_response"]
def batch_test(query):
    """Run the full QAG pipeline for one question.

    Produces the baseline llama answer plus the entity/intent/API-call
    results as a pandas Series of six fields: llama_base_output,
    helper_output, cancer_entities, intent, gene_entities,
    mutation_entities.
    """
    modified_query = utilities.construct_modified_query_base_llm(query)
    print(f"modified_query is: {modified_query}")
    llama_base_output = generate_response(modified_query)
    print(f"llama_base_output: {llama_base_output}")
    try:
        result = construct_and_execute_api_call(query)
    except Exception as e:
        # BUG FIX: the original referenced `result` here before it was ever
        # assigned, so this handler raised NameError instead of degrading
        # gracefully. Build an empty result namespace instead.
        print("unable to process query {} {}".format(query, str(e)))
        result = SimpleNamespace(
            helper_output=[],
            cancer_entities=[],
            intent=None,
            gene_entities=[],
            mutation_entities=[],
        )
    # if there is not a helper output for each unique cancer entity
    # log error to inspect and reprocess query later
    try:
        # BUG FIX: the original computed this comparison and discarded the
        # result (a no-op expression); now a mismatch actually triggers the
        # reset documented above. TypeError covers un-sized helper outputs.
        counts_match = len(result.helper_output) == len(result.cancer_entities)
    except TypeError:
        counts_match = False
    if not counts_match:
        msg = (
            "there is not a unique helper output for each unique "
            "cancer entity in {}".format(query)
        )
        print("exception {}".format(msg))
        result.helper_output = []
        result.cancer_entities = []
    return pd.Series(
        [
            llama_base_output,
            result.helper_output,
            result.cancer_entities,
            result.intent,
            result.gene_entities,
            result.mutation_entities,
        ]
    )
def get_prefinal_response(row):
    """Build the augmented prompt for one exploded row and generate the
    pre-final llama answer.

    Args:
        row: mapping (pandas row) with "questions" and "helper_output" keys.

    Returns:
        pandas Series of (modified_query, pre-final response text).

    Raises:
        KeyError: when the row is missing a required column.
    """
    try:
        query = row["questions"]
        helper_output = row["helper_output"]
    except KeyError as e:
        # BUG FIX: the original handler printed `query` and `helper_output`,
        # which are unbound when the lookup fails, so the handler itself
        # raised NameError and then fell through to use `query` anyway.
        # Report the missing column and re-raise instead.
        print(f"unable to retrieve query or helper_output: missing key {e}")
        raise
    modified_query = utilities.construct_modified_query(query, helper_output)
    prefinal_llama_with_helper_output = generate_response(modified_query)
    return pd.Series([modified_query, prefinal_llama_with_helper_output])
def execute_pipeline(question: str):
    """End-to-end QAG pipeline for a single question.

    Runs baseline llama generation, entity/intent extraction, the GDC API
    call, augmented generation and post-processing, and returns a
    human-readable summary string for the Gradio output textbox.
    """
    df = pd.DataFrame({'questions': [question]})
    print(f'Question received: {question}')
    print("starting pipeline")
    print("CUDA available:", torch.cuda.is_available())
    print("CUDA device name:", torch.cuda.get_device_name(0))
    # queries input file
    print(f"running test on input {df}")
    df[
        [
            "llama_base_output",
            "helper_output",
            "cancer_entities",
            "intent",
            "gene_entities",
            "mutation_entities",
        ]
    ] = df["questions"].apply(lambda x: batch_test(x))
    # one row per (question, helper_output) pair
    df_exploded = df.explode("helper_output", ignore_index=True)
    df_exploded[["modified_prompt", "pre_final_llama_with_helper_output"]] = (
        df_exploded.apply(
            lambda x: get_prefinal_response(x), axis=1
        )
    )
    ### postprocess response
    print("postprocessing response")
    df_exploded[
        [
            "llama_base_stat",
            "delta_llama",
            "value_changed",
            "ground_truth_stat",
            "generated_stat_prefinal",
            "delta_prefinal",
            "generated_stat_final",
            "delta_final",
            "final_response",
        ]
    ] = df_exploded.apply(
        lambda x: utilities.postprocess_response(x), axis=1
    )
    final_columns = utilities.get_final_columns()
    # BUG FIX: the original called rename(..., inplace=True) on a
    # column-sliced frame, which operates on a potential view and triggers
    # pandas' SettingWithCopyWarning; use the non-mutating rename instead.
    result = df_exploded[final_columns].rename(columns={
        'llama_base_output': 'llama-3B baseline output',
        'modified_prompt': 'Query augmented prompt',
        'helper_output': 'Processed GDC API result',
        'ground_truth_stat': 'Ground truth frequency from GDC',
        'llama_base_stat': 'llama-3B baseline frequency',
        'delta_llama': 'llama-3B frequency - Ground truth frequency',
        'final_response': 'Query augmented generation',
        'intent': 'Intent',
        'cancer_entities': 'Cancer entities',
        'gene_entities': 'Gene entities',
        'mutation_entities': 'Mutation entities',
        'questions': 'Question'
    })
    result.index = ['QAG pipeline results'] * len(result)
    print('completed')
    print('writing result string now')
    result = result.T.to_dict()
    print('result {}'.format(result))
    row = result['QAG pipeline results']
    result_string = ""
    result_string += f"Question: {row['Question']}\n"
    # NOTE(review): the label says "baseline output" but the value shown is
    # the baseline *frequency* column — presumably intentional (it is
    # rendered with a trailing '%'); confirm against UI expectations.
    result_string += f"llama-3B baseline output: {row['llama-3B baseline frequency']}%\n"
    # BUG FIX: the original omitted the newline here, so the prompt and
    # generation sections ran together on one line.
    result_string += f"Query augmented prompt: {row['Query augmented prompt']}\n"
    result_string += f"Query augmented generation: {row['Query augmented generation']}"
    return result_string
def visible_component(input_text):
    """Gradio callback stub: ignore the input and set the component's
    value to the fixed placeholder string "WHATEVER"."""
    update = gr.update(value="WHATEVER")
    return update
# Create Gradio interface
# Top-level Blocks app; launched below as both a web UI and an MCP server.
with gr.Blocks(title="GDC QAG MCP server") as GDC_QAG_QUERY:
    gr.Markdown(
        """
        # GDC QAG Service
        """
    )
    with gr.Row():
        # Free-text question input; the Examples widget pre-fills this box.
        query_input = gr.Textbox(
            lines = 3,
            label="Search Query",
            placeholder='e.g. "What is the co-occurence frequency of somatic homozygous deletions in CDKN2A and CDKN2B in the mesothelioma project TCGA-MESO in the genomic data commons?"',
            info="Required: Enter your search query. Click on Examples to execute example queries. Please retry query if API is unavailable or connection aborts.",
        )
    # Clickable example queries; display labels come from EXAMPLE_LABELS.
    gr.Examples(
        examples=EXAMPLE_INPUTS,
        inputs=query_input,
        example_labels = EXAMPLE_LABELS
    )
    execute_button = gr.Button("Execute", variant="primary")
    # Plain-text result area populated by execute_pipeline's summary string.
    output = gr.Textbox(
        label="Query Result",
        lines=10,
        max_lines=25,
        info="The Result of the Query will appear here",
    )
    # Wire the button to the full QAG pipeline.
    execute_button.click(
        fn=execute_pipeline,
        inputs=[query_input],
        outputs=output,
    )
if __name__ == "__main__":
    # mcp_server=True additionally exposes the pipeline over MCP;
    # show_api=True keeps the auto-generated API page enabled.
    GDC_QAG_QUERY.launch(mcp_server=True, show_api=True)