songs1's picture
add active filter selection box and other minor frontend tweaks
064444a
raw
history blame
26.5 kB
import json
import os
from collections import defaultdict
import gradio as gr
import requests
import spaces
import torch
import yaml
from gradio_rangeslider import RangeSlider
from guidance import json as gen_json
from guidance.models import Transformers
from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed
from schema import GDCCohortSchema # isort: skip
# Module-level configuration.
# When the DEBUG environment variable is present (any value), the model is
# never loaded and generate_filter returns a canned filter instead.
DEBUG = os.getenv("DEBUG") is not None

EXAMPLE_INPUTS = [
    "bam files for TCGA-BRCA",
    "kidney or adrenal gland cancers with alcohol history",
    "tumor samples from male patients with acute myeloid lymphoma",
]

GDC_CASES_API_ENDPOINT = "https://api.gdc.cancer.gov/cases"

MODEL_NAME = "uc-ctds/gdc-cohort-llm-gpt2-s1M"
TOKENIZER_NAME = MODEL_NAME
AUTH_TOKEN = os.getenv("HF_TOKEN", False)  # HF_TOKEN must be set to use auth

# config.yaml describes the UI: a list of tabs, each holding a list of cards
# with a display name, a GDC field and either a list of values (checkboxes)
# or a dict of meta options (range slider).
with open("config.yaml") as f:
    CONFIG = yaml.safe_load(f)

# Every card across all tabs, in config order.
_ALL_CARDS = [card for tab in CONFIG["tabs"] for card in tab["cards"]]

TAB_NAMES = [tab["name"] for tab in CONFIG["tabs"]]
CARD_NAMES = [card["name"] for card in _ALL_CARDS]
CARD_FIELDS = [card["field"] for card in _ALL_CARDS]
CARD_2_FIELD = dict(zip(CARD_NAMES, CARD_FIELDS))
CARD_2_VALUES = {card["name"]: card["values"] for card in _ALL_CARDS}

# Comma-separated facet list for the cases endpoint, with the "cases."
# prefix removed. Range cards (dict-valued) are skipped in bin counts.
FACETS_STR = ",".join(
    field.replace("cases.", "")
    for field, name in zip(CARD_FIELDS, CARD_NAMES)
    if not isinstance(CARD_2_VALUES[name], dict)
)

if not DEBUG:
    tok = AutoTokenizer.from_pretrained(TOKENIZER_NAME, token=AUTH_TOKEN)
    # for some reason, pre-invoking tokenizer prevents endless generation when using guidance
    # opened ticket here: https://github.com/guidance-ai/guidance/issues/1322
    tok("foobar")

    _device = "cuda" if torch.cuda.is_available() else "cpu"
    model = (
        GPT2LMHeadModel.from_pretrained(MODEL_NAME, token=AUTH_TOKEN)
        .to(_device)
        .eval()
    )
# Canned cohort filter (pretty-printed JSON string) used in place of a model
# generation: two `in` clauses plus a nested `and` holding a >=/<= pair on
# the same field.
_DUMMY_AGE_FIELD = "cases.diagnoses.age_at_diagnosis"
DUMMY_FILTER = json.dumps(
    dict(
        op="and",
        content=[
            dict(
                op="in",
                content=dict(
                    field="cases.project.project_id",
                    value=["TCGA-BRCA"],
                ),
            ),
            dict(
                op="in",
                content=dict(
                    field="cases.project.program.name",
                    value=["TCGA"],
                ),
            ),
            dict(
                op="and",
                content=[
                    dict(op=">=", content=dict(field=_DUMMY_AGE_FIELD, value=7305)),
                    dict(op="<=", content=dict(field=_DUMMY_AGE_FIELD, value=14610)),
                ],
            ),
        ],
    ),
    indent=4,
)
# Generate cohort filter JSON from free text
@spaces.GPU(duration=15)
def generate_filter(query):
    """Turn a free-text cohort description into a cohort filter JSON string.

    When DEBUG is set, the canned DUMMY_FILTER is returned and the model is
    not used. Otherwise the query is fed to the language model and the
    continuation is constrained to GDCCohortSchema (temperature 0, at most
    1024 tokens) with a fixed seed.

    Returns:
        The generated filter re-serialized as JSON with 4-space indentation.
    """
    if DEBUG:
        return DUMMY_FILTER

    # Fixed seed before every generation.
    set_seed(42)

    lm = Transformers(model=model, tokenizer=tok)
    lm += query
    lm += gen_json(
        name="cohort",
        schema=GDCCohortSchema,
        temperature=0,
        max_tokens=1024,
    )

    # Round-trip through json to normalize the formatting of the capture.
    parsed_filter = json.loads(lm["cohort"])
    return json.dumps(parsed_filter, indent=4)
# Transform query to filter to checkbox selections (and update json box)
def process_query(query):
    """Generate a cohort filter for `query` and map it onto the filter cards.

    Args:
        query: free-text cohort description passed to generate_filter.

    Returns:
        A list with one gr.update per card (in CARD_NAMES order) followed by
        a gr.update carrying the generated filter JSON string.

    Raises:
        ValueError: if the generated filter uses an unknown op, defines the
            same field/comparator twice, falls outside a range card's
            bounds, or a card has an unsupported `values` configuration.
    """
    cohort_filter_str = generate_filter(query)
    cohort_filter = json.loads(cohort_filter_str)

    generated_field_2_values = _collect_generated_values(
        _flatten_nested_ops(cohort_filter["content"])
    )

    # Need to update all cards so use all possible cards as ref
    card_updates = []
    for card_name, card_field in zip(CARD_NAMES, CARD_FIELDS):
        default_values = CARD_2_VALUES[card_name]
        if isinstance(default_values, list):
            update_obj = _checkbox_card_update(
                card_field, default_values, generated_field_2_values
            )
        elif isinstance(default_values, dict):
            # range-slider, maybe other options in the future?
            update_obj = _range_card_update(
                card_name, card_field, default_values, generated_field_2_values
            )
        else:
            raise ValueError(f"Unknown values for card {card_name}")
        card_updates.append(update_obj)

    # generated_field_2_values will have remaining, unmatched values
    # edit: updated json schema with enumerated fields prevents unmatched fields
    print(f"Unmatched values in model generation: {generated_field_2_values}")

    return card_updates + [gr.update(value=cohort_filter_str)]


def _flatten_nested_ops(ops):
    """Return `ops` with the contents of directly nested `and` ops inlined."""
    flattened = []
    for op in ops:
        # nested `and` can only be 1 deep based on schema
        if op["op"] == "and":
            flattened.extend(op["content"])
        else:
            flattened.append(op)
    return flattened


def _collect_generated_values(flattened_ops):
    """Map each generated field (or `field_<comparator>`) to its value.

    Strict comparators are rewritten as inclusive ones and `=` is expanded
    into a `<=`/`>=` pair so every comparison can be filled into a range card.
    """
    field_2_values = {}

    def store(key, value):
        if key in field_2_values:
            raise ValueError(f"{key} is ambiguously duplicated")
        field_2_values[key] = value

    for op in flattened_ops:
        op_name = op["op"]
        # Explicit raise rather than assert: this validates model output and
        # must not disappear when Python runs with -O.
        if op_name not in ("in", "=", "<", ">", "<=", ">="):
            raise ValueError(f"Unknown handling for op: {op}")

        content = op["content"]
        field, value = content["field"], content["value"]

        if op_name == "in":
            store(field, value)
        elif op_name == "=":
            # convert = to <=,>= ops so it can be filled into card
            store(field + "_<=", value)
            store(field + "_>=", value)
        elif op_name == "<":
            # comparators are ints so can convert to g/lte by add/sub 1
            store(field + "_<=", value - 1)
        elif op_name == ">":
            store(field + "_>=", value + 1)
        else:
            # comp ops will duplicate name, disambiguate by appending comp
            store(field + "_" + op_name, value)

    return field_2_values


def _checkbox_card_update(card_field, default_values, generated_field_2_values):
    """Build the update for a checkbox card, consuming its generated values.

    Generated values that are not among the card's configured values are put
    back into `generated_field_2_values` so the caller can report them.
    """
    selected = []
    if card_field in generated_field_2_values:
        allowed = set(default_values)
        unmatched = []
        for generated_value in generated_field_2_values.pop(card_field):
            if generated_value in allowed:
                selected.append(generated_value)
            else:
                # model hallucination?
                unmatched.append(generated_value)
        if unmatched:
            generated_field_2_values[card_field] = unmatched

    return gr.update(
        choices=default_values,  # reset choices to the configured values
        value=selected,  # will override existing selections
    )


def _range_card_update(card_name, card_field, meta, generated_field_2_values):
    """Build the update for a range card, consuming its generated bounds."""
    if meta["type"] != "range":
        raise ValueError(f"Expected range slider for card {card_name}")

    _min, _max = meta["min"], meta["max"]
    # A missing bound falls back to the card's configured limit.
    lo = generated_field_2_values.pop(card_field + "_>=", _min)
    hi = generated_field_2_values.pop(card_field + "_<=", _max)
    if lo < _min:
        raise ValueError(
            f"Generated lower bound ({lo}) less than minimum allowable value ({_min})"
        )
    if hi > _max:
        raise ValueError(
            f"Generated upper bound ({hi}) greater than maximum allowable value ({_max})"
        )
    return gr.update(value=(lo, hi))
# Update JSON based on checkbox selections
def update_json_from_cards(*selected_filters_per_card):
    """Serialize the current card selections into a cohort filter JSON update.

    Args:
        *selected_filters_per_card: one value per card, in CARD_NAMES order;
            a list of selected labels for checkbox cards, a (lo, hi) pair for
            range cards.

    Returns:
        A gr.update whose value is the filter JSON (4-space indented) with a
        top-level `and` over one `in` op per non-empty checkbox card and one
        nested `and` per range card that differs from its configured limits.

    Raises:
        ValueError: if a card's configured values are neither list nor dict.
    """
    filter_ops = []
    for card_name, selection in zip(CARD_NAMES, selected_filters_per_card):
        # The configured values tell us which kind of card this is.
        card_spec = CARD_2_VALUES[card_name]

        if isinstance(card_spec, list):
            # Checkbox card: emit an `in` op only when something is selected.
            if len(selection) > 0:
                raw_values = [get_base_value(label) for label in selection]
                filter_ops.append(
                    {
                        "op": "in",
                        "content": {
                            "field": CARD_2_FIELD[card_name],
                            "value": raw_values,
                        },
                    }
                )
        elif isinstance(card_spec, dict):
            # range-slider, maybe other options in the future?
            assert (
                card_spec["type"] == "range"
            ), f"Expected range slider for card {card_name}"
            lo, hi = selection
            bounds = (
                (lo, card_spec["min"], ">="),
                (hi, card_spec["max"], "<="),
            )
            # A bound sitting on its configured limit adds no constraint.
            range_ops = [
                {
                    "op": comparator,
                    "content": {
                        "field": CARD_2_FIELD[card_name],
                        "value": int(bound),
                    },
                }
                for bound, limit, comparator in bounds
                if bound != limit
            ]
            if range_ops:
                filter_ops.append({"op": "and", "content": range_ops})
        else:
            raise ValueError(f"Unknown values for card {card_name}")

    serialized = json.dumps({"op": "and", "content": filter_ops}, indent=4)
    return gr.update(value=serialized)
# Execute GDC API query and prepare checkbox + case counter updates
# Preserve prior selections
def update_cards_with_counts(cohort_filter: str, *selected_filters_per_card):
    """Query the GDC cases endpoint and refresh card choices with bin counts.

    Args:
        cohort_filter: cohort filter as a JSON string; an empty string runs
            the query without a filter.
        *selected_filters_per_card: current value of each card, in
            CARD_NAMES order.

    Returns:
        A list with one gr.update per card followed by a gr.update for the
        case counter text ("<total> Cases").

    Raises:
        Exception: if the API responds with a non-OK status.
        requests.exceptions.RequestException: on connection failure/timeout.
        ValueError: if a card's configured values are neither list nor dict.
    """
    card_2_selections = dict(zip(CARD_NAMES, selected_filters_per_card))

    # Execute GDC API query
    params = {
        "facets": FACETS_STR,
        "pretty": "false",
        "format": "JSON",
        "size": 0,
    }
    if cohort_filter:
        # patch for range selectors which use nested `and`
        # seems `facets` and nested `and` don't play well together
        # so flatten direct nested `and` for query execution only
        # this is equivalent since our top-level is always `and`
        # keeping nested `and` for presentation and model generations though
        parsed_filter = json.loads(cohort_filter)
        flat_ops = []
        for op in parsed_filter["content"]:
            # assumes no deeper than single level nesting
            if op["op"] == "and":
                flat_ops.extend(op["content"])
            else:
                flat_ops.append(op)
        parsed_filter["content"] = flat_ops
        params["filters"] = json.dumps(parsed_filter)

    # Without a timeout a stalled connection would block this handler forever.
    response = requests.get(GDC_CASES_API_ENDPOINT, params=params, timeout=30)
    if not response.ok:
        # Use the raw body: an error response is not guaranteed to be JSON and
        # decoding it here would mask the real failure.
        raise Exception(f"API error: {response.status_code}\n{response.text}")
    payload = response.json()

    # Update checkboxes with bin counts
    card_updates = []
    all_counts = payload["data"]["aggregations"]
    for card_name in CARD_NAMES:
        card_field = CARD_2_FIELD[card_name].replace("cases.", "")
        card_values = CARD_2_VALUES[card_name]

        if isinstance(card_values, list):
            # value checkboxes
            card_counts = {
                bucket["key"]: bucket["doc_count"]
                for bucket in all_counts[card_field]["buckets"]
            }

            # Only configured values that came back with a count are offered,
            # labelled "<value> [<count>]" and kept in config order.
            choice_mapping = {}
            updated_choices = []
            for value_name in card_values:
                if value_name in card_counts:
                    label = prepare_value_count(value_name, card_counts[value_name])
                    choice_mapping[value_name] = label
                    updated_choices.append(label)

            # Align prior selections with new choices
            updated_values = []
            for selected_value in card_2_selections[card_name] or []:
                base_value = get_base_value(selected_value)
                if base_value not in choice_mapping:
                    # Re-add choices which now presumably have 0 counts.
                    # The label must also be offered as a choice, otherwise
                    # the selected value would not be among the choices.
                    zero_label = prepare_value_count(base_value, 0)
                    choice_mapping[base_value] = zero_label
                    updated_choices.append(zero_label)
                updated_values.append(choice_mapping[base_value])

            update_obj = gr.update(
                choices=updated_choices,
                value=updated_values,
            )
        elif isinstance(card_values, dict):
            # range-slider, maybe other options in the future?
            assert (
                card_values["type"] == "range"
            ), f"Expected range slider for card {card_name}"
            # for range slider, nothing to actually do!
            update_obj = gr.update()
        else:
            raise ValueError(f"Unknown values for card {card_name}")

        card_updates.append(update_obj)

    case_count = payload["data"]["pagination"]["total"]
    return card_updates + [gr.update(value=f"{case_count} Cases")]
def update_active_selections(*selected_filters_per_card):
    """Flatten all card selections into labels for the active-filter box.

    Checkbox selections become "<CARD NAME>: <value>" (count suffix removed);
    a range card contributes "<CARD NAME>: <lo>-<hi>" only when it differs
    from its configured min/max.

    Returns:
        A gr.update that sets both the choices and the value to those labels.

    Raises:
        ValueError: if a card's configured values are neither list nor dict.
    """
    labels = []
    for card_name, selection in zip(CARD_NAMES, selected_filters_per_card):
        # The configured values tell us which kind of card this is.
        card_spec = CARD_2_VALUES[card_name]

        if isinstance(card_spec, list):
            # checkbox
            labels.extend(
                f"{card_name.upper()}: {get_base_value(picked)}"
                for picked in selection
            )
        elif isinstance(card_spec, dict):
            # range-slider, maybe other options in the future?
            assert (
                card_spec["type"] == "range"
            ), f"Expected range slider for card {card_name}"
            lo, hi = selection
            untouched = lo == card_spec["min"] and hi == card_spec["max"]
            if not untouched:
                # only add range filter if not default
                labels.append(f"{card_name.upper()}: {int(lo)}-{int(hi)}")
        else:
            raise ValueError(f"Unknown values for card {card_name}")

    return gr.update(choices=labels, value=labels)
def update_cards_from_active(current_selections, *selected_filters_per_card):
    """Push de-selections made in the active-filter box back onto the cards.

    Args:
        current_selections: labels still checked in the active-filter box,
            each of the form "<CARD NAME>: <value>".
        *selected_filters_per_card: current value of each card, in
            CARD_NAMES order.

    Returns:
        A list starting with the update for the active-filter box (choices
        narrowed to what is still selected) followed by one update per card.

    Raises:
        ValueError: if a card's configured values are neither list nor dict.
    """
    # active selector uses a flattened list so re-agg values under card groups
    still_active = defaultdict(set)
    for label in current_selections:
        split_at = label.find(": ")
        still_active[label[:split_at]].add(label[split_at + 2 :])

    per_card_updates = []
    for card_name, selection in zip(CARD_NAMES, selected_filters_per_card):
        # The configured values tell us which kind of card this is.
        card_spec = CARD_2_VALUES[card_name]

        if isinstance(card_spec, list):
            # checkbox: keep only the selections that are still active
            kept = [
                picked
                for picked in selection
                if get_base_value(picked) in still_active[card_name.upper()]
            ]
            card_update = gr.update(value=kept)
        elif isinstance(card_spec, dict):
            # range-slider, maybe other options in the future?
            assert (
                card_spec["type"] == "range"
            ), f"Expected range slider for card {card_name}"
            # the active selector cannot change range values
            # so if present as an active selection, no action is needed
            # otherwise, reset entire range selector
            if card_name.upper() not in still_active:
                card_update = gr.update(
                    value=(card_spec["min"], card_spec["max"])
                )
            else:
                card_update = gr.update()
        else:
            raise ValueError(f"Unknown values for card {card_name}")

        per_card_updates.append(card_update)

    # also remove unselected value as possible choice
    return [gr.update(choices=current_selections)] + per_card_updates
def prepare_value_count(value, count):
    """Return the display label "<value> [<count>]" for a checkbox choice."""
    return "{} [{}]".format(value, count)
def get_base_value(value):
    """Strip a trailing " [<count>]" suffix from a checkbox label.

    Only a suffix of the exact form produced by prepare_value_count (a space,
    an opening bracket, decimal digits, a closing bracket, at the very end) is
    removed. Labels without such a suffix are returned unchanged, including
    values that merely contain " [" somewhere in their text.
    """
    head, sep, tail = value.rpartition(" [")
    if sep and tail.endswith("]") and tail[:-1].isdecimal():
        return head
    return value
# Tab selection helper
def set_active_tab(selected_tab):
    """Build the updates that show `selected_tab` and highlight its button.

    Returns:
        One visibility update per tab (only the selected tab is visible)
        followed by one variant update per tab ("primary" for the selected
        tab, "secondary" otherwise), both in TAB_NAMES order.
    """
    visibility_updates = []
    variant_updates = []
    for tab in TAB_NAMES:
        is_selected = tab == selected_tab
        visibility_updates.append(gr.update(visible=is_selected))
        variant_updates.append(
            gr.update(variant="primary" if is_selected else "secondary")
        )
    return visibility_updates + variant_updates
# Client-side handler for the download button: fetches the case ids matching
# the given filter string from the GDC cases endpoint and saves them as a
# newline-separated text file. Braces are doubled because this is an f-string;
# "\\n" reaches the browser as the JavaScript escape "\n".
DOWNLOAD_CASES_JS = f"""
function download_cases(filter_str) {{
    const params = new URLSearchParams();
    params.set('fields', 'case_id');
    params.set('format', 'JSON');
    params.set('size', 100000);
    params.set('filters', filter_str);
    const url = "{GDC_CASES_API_ENDPOINT}?" + params.toString();

    const button = document.getElementById("download-btn");
    button.innerHTML = '<div class="spinner"></div>';
    button.disabled = true;

    fetch(url).then(resp => {{
        if (!resp.ok) throw new Error("Failed to fetch TSV.");
        return resp.json();
    }})
    .then(data => {{
        const ids = data.data.hits.map(item => item.id);
        const text = ids.join("\\n");
        const blob = new Blob([text], {{type: "text/plain"}});
        return blob;
    }})
    .then(blob => {{
        const url = URL.createObjectURL(blob);
        const a = document.createElement('a');
        a.href = url;
        a.download = "gdc_cohort_case_ids.tsv";
        document.body.appendChild(a);
        a.click();
        document.body.removeChild(a);
        URL.revokeObjectURL(url);
    }})
    .catch(error => {{
        alert("Download failed: " + error.message);
    }})
    .finally(() => {{
        // Restore the button on success and on failure alike, so a failed
        // download does not leave it disabled behind a spinner.
        button.innerHTML = 'Export to GDC';
        button.disabled = false;
    }});
}}
"""
# UI layout and event wiring.
# NOTE(review): the container nesting below is reconstructed from a copy of
# this file whose indentation was lost — confirm against the deployed layout.
with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown("# GDC Cohort Copilot")

    # Query box next to the case counter and download button
    with gr.Row(equal_height=True):
        with gr.Column(scale=7):
            description_hint = (
                "Only provide the cohort characteristics. "
                "Do not include extraneous text. "
                "For example, write 'patients with X' "
                "instead of 'I would like patients with X':"
            )
            text_input = gr.Textbox(
                label="Describe the cohort you're looking for:",
                info=description_hint,
                submit_btn="Generate Cohort",
                elem_id="description-input",
                placeholder="Enter a cohort description to begin...",
            )
        with gr.Column(scale=1, min_width=150):
            case_counter = gr.Text(
                show_label=False,
                interactive=False,
                container=False,
                elem_id="case-counter",
                min_width=150,
            )
            case_download = gr.Button(
                value="Export to GDC",
                min_width=150,
                elem_id="download-btn",
            )

    # Example queries next to the read-only filter JSON
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=250):
            gr.Examples(
                examples=EXAMPLE_INPUTS,
                inputs=text_input,
            )
        with gr.Column(scale=4):
            empty_filter = json.dumps({"op": "and", "content": []}, indent=4)
            json_output = gr.Code(
                label="Cohort Filter JSON",
                value=empty_filter,
                language="json",
                interactive=False,
                show_label=True,
                container=True,
                elem_id="json-output",
            )

    # Flattened view of everything currently selected
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=250):
            gr.Markdown("## Currently Selected Filters")
        with gr.Column(scale=4):
            active_selections = gr.CheckboxGroup(
                choices=[],
                show_label=False,
                interactive=True,
                elem_id="active-selections",
            )

    with gr.Row():
        gr.Markdown(
            "The generated cohort filter will autopopulate into the filter cards below. "
            "**GDC Cohort Copilot can make mistakes!** "
            "Refine your search using the interactive checkboxes. "
            "Note that many other options can be found by selecting the different tabs on the left."
        )

    with gr.Row():
        # Tab selectors
        tab_buttons = []
        with gr.Column(scale=1, min_width=250):
            for name in TAB_NAMES:
                is_first_tab = name == TAB_NAMES[0]
                tab_button = gr.Button(
                    value=name,
                    variant="primary" if is_first_tab else "secondary",
                )
                tab_buttons.append(tab_button)

        # Filter cards
        tab_containers = []
        filter_cards = []
        for tab in CONFIG["tabs"]:
            # only the first tab starts out visible
            visible = tab["name"] == TAB_NAMES[0]
            with gr.Column(scale=4, visible=visible) as tab_container:
                tab_containers.append(tab_container)
                with gr.Row(elem_classes=["card-group"]):
                    for card in tab["cards"]:
                        if isinstance(card["values"], list):
                            # Choices start empty; they are filled in by the
                            # count update wired to demo.load below.
                            filter_card = gr.CheckboxGroup(
                                choices=[],
                                label=card["name"],
                                interactive=True,
                                elem_classes=["filter-card"],
                            )
                        else:
                            # values is a dictionary and defines some meta options
                            metaopts = card["values"]
                            assert (
                                "type" in metaopts
                                and metaopts["type"] == "range"
                                and all(key in metaopts for key in ("min", "max"))
                            ), f"Unknown meta options for {card['name']}"
                            info = "Inclusive range"
                            if "unit" in metaopts:
                                info += f", units in {metaopts['unit']}"
                            filter_card = RangeSlider(
                                label=card["name"],
                                info=info,
                                minimum=metaopts["min"],
                                maximum=metaopts["max"],
                                step=1,  # assume integer
                                elem_classes=["filter-card", "filter-range"],
                            )
                        filter_cards.append(filter_card)

    # Assign tab buttons to toggle visibility
    for tab_button, name in zip(tab_buttons, TAB_NAMES):
        tab_button.click(
            fn=set_active_tab,
            inputs=gr.State(name),
            outputs=tab_containers + tab_buttons,
        )

    # Enable case download
    case_download.click(
        fn=None,  # apparently this isn't the same as not specifying it
        js=DOWNLOAD_CASES_JS,
        inputs=json_output,
    )

    # Load initial counts on startup
    demo.load(
        fn=update_cards_with_counts,
        inputs=[gr.State("")] + filter_cards,
        outputs=filter_cards + [case_counter],
    )

    # Update checkboxes on filter generation
    # Also update JSON based on checkboxes
    # - relying on checkbox update to do this fires multiple times
    # - also propagates new model selections after json is updated
    # Also this way it shows the model generated JSON
    text_input.submit(
        fn=process_query,
        inputs=text_input,
        outputs=filter_cards + [json_output],
    ).success(
        fn=update_active_selections,
        inputs=filter_cards,
        outputs=[active_selections],
    )

    # Update JSON based on cards
    # Keep user `input` event listener (vs `change`) otherwise will fire multiple times
    # Seems like otherwise it should be cyclical, Gradio must have some logic to prevent infinite loops
    for filter_card in filter_cards:
        # Range sliders are wired on `release`, every other card on `input`;
        # the handler chain is the same for both.
        if isinstance(filter_card, RangeSlider):
            register_listener = filter_card.release
        else:
            register_listener = filter_card.input
        register_listener(
            fn=update_json_from_cards,
            inputs=filter_cards,
            outputs=json_output,
        ).success(
            fn=update_active_selections,
            inputs=filter_cards,
            outputs=[active_selections],
        )

    # Enable functionality of the active filter selectors
    active_selections.input(
        fn=update_cards_from_active,
        inputs=[active_selections] + filter_cards,
        outputs=[active_selections] + filter_cards,
    ).success(
        fn=update_json_from_cards,
        inputs=filter_cards,
        outputs=json_output,
    )

    # Update checkboxes after executing filter query
    json_output.change(
        fn=update_cards_with_counts,
        inputs=[json_output] + filter_cards,
        outputs=filter_cards + [case_counter],
    )

if __name__ == "__main__":
    demo.launch()