Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,9 +14,6 @@ from inference import (
|
|
| 14 |
)
|
| 15 |
|
| 16 |
|
| 17 |
-
# ---------------------------------------------------------------------------
|
| 18 |
-
# Choice lists
|
| 19 |
-
# ---------------------------------------------------------------------------
|
| 20 |
|
| 21 |
AGEGPFF_CHOICES = ["<=10", "11-17", "18-29", "30-49", ">=50"]
|
| 22 |
SEX_CHOICES = ["Male", "Female"]
|
|
@@ -56,9 +53,7 @@ SCATXRSN_CHOICES = [
|
|
| 56 |
]
|
| 57 |
|
| 58 |
|
| 59 |
-
|
| 60 |
-
# Grouped published-regimen dropdown
|
| 61 |
-
# ---------------------------------------------------------------------------
|
| 62 |
|
| 63 |
GROUPED_REGIMEN_CHOICES = [
|
| 64 |
("ββ HLA IDENTICAL ββ", "__header_hla_identical__"),
|
|
@@ -137,9 +132,7 @@ PUBLISHED_PRESETS = {
|
|
| 137 |
}
|
| 138 |
|
| 139 |
|
| 140 |
-
|
| 141 |
-
# Feature groupings
|
| 142 |
-
# ---------------------------------------------------------------------------
|
| 143 |
|
| 144 |
PATIENT_FEATURES = ["AGE", "AGEGPFF", "SEX", "KPS", "RCMVPR"]
|
| 145 |
DONOR_FEATURES = ["DONORF", "GRAFTYPE", "HLA_FINAL",
|
|
@@ -148,9 +141,7 @@ DISEASE_FEATURES = ["NACS2YR", "EXCHTFPR", "VOC2YPR", "VOCFRQPR", "SCATXRSN"]
|
|
| 148 |
ALL_FEATURES = PATIENT_FEATURES + DONOR_FEATURES + DISEASE_FEATURES
|
| 149 |
|
| 150 |
|
| 151 |
-
|
| 152 |
-
# Utility callbacks
|
| 153 |
-
# ---------------------------------------------------------------------------
|
| 154 |
|
| 155 |
def get_age_group(age):
|
| 156 |
if age is None or age == "":
|
|
@@ -188,7 +179,7 @@ def apply_grouped_preset(selected_value):
|
|
| 188 |
return [gr.update()] * 7
|
| 189 |
|
| 190 |
return [
|
| 191 |
-
gr.update(),
|
| 192 |
gr.update(value=preset["DONORF"]),
|
| 193 |
gr.update(value=preset["CONDGRPF"]),
|
| 194 |
gr.update(value=preset["CONDGRP_FINAL"]),
|
|
@@ -198,9 +189,7 @@ def apply_grouped_preset(selected_value):
|
|
| 198 |
]
|
| 199 |
|
| 200 |
|
| 201 |
-
|
| 202 |
-
# Component factory
|
| 203 |
-
# ---------------------------------------------------------------------------
|
| 204 |
|
| 205 |
def make_component(name: str):
|
| 206 |
if name == "AGE":
|
|
@@ -247,9 +236,7 @@ def make_component(name: str):
|
|
| 247 |
return gr.Textbox(label=name)
|
| 248 |
|
| 249 |
|
| 250 |
-
|
| 251 |
-
# Prediction callback
|
| 252 |
-
# ---------------------------------------------------------------------------
|
| 253 |
|
| 254 |
def predict_gradio(*values):
|
| 255 |
try:
|
|
@@ -284,7 +271,7 @@ def predict_gradio(*values):
|
|
| 284 |
|
| 285 |
return (
|
| 286 |
df,
|
| 287 |
-
icon_arrays["__grid__"],
|
| 288 |
shap_plots["DEAD"],
|
| 289 |
shap_plots["GF"],
|
| 290 |
shap_plots["AGVHD"],
|
|
@@ -304,9 +291,7 @@ def predict_gradio(*values):
|
|
| 304 |
raise gr.Error(f"{type(e).__name__}: {str(e)}\n\nSee terminal for full traceback.")
|
| 305 |
|
| 306 |
|
| 307 |
-
|
| 308 |
-
# CSS
|
| 309 |
-
# ---------------------------------------------------------------------------
|
| 310 |
|
| 311 |
custom_css = """
|
| 312 |
.predict-button {
|
|
@@ -322,9 +307,6 @@ custom_css = """
|
|
| 322 |
}
|
| 323 |
"""
|
| 324 |
|
| 325 |
-
# ---------------------------------------------------------------------------
|
| 326 |
-
# Gradio UI
|
| 327 |
-
# ---------------------------------------------------------------------------
|
| 328 |
|
| 329 |
with gr.Blocks(title="HCT Outcome Prediction Model", css=custom_css) as demo:
|
| 330 |
gr.Markdown(
|
|
@@ -338,13 +320,13 @@ with gr.Blocks(title="HCT Outcome Prediction Model", css=custom_css) as demo:
|
|
| 338 |
inputs_dict = {}
|
| 339 |
|
| 340 |
with gr.Row():
|
| 341 |
-
|
| 342 |
with gr.Column(scale=1):
|
| 343 |
gr.Markdown("### Patient Characteristics")
|
| 344 |
for f in PATIENT_FEATURES:
|
| 345 |
inputs_dict[f] = make_component(f)
|
| 346 |
|
| 347 |
-
|
| 348 |
with gr.Column(scale=1):
|
| 349 |
gr.Markdown("### Transplant Characteristics")
|
| 350 |
|
|
@@ -364,13 +346,13 @@ with gr.Blocks(title="HCT Outcome Prediction Model", css=custom_css) as demo:
|
|
| 364 |
gvhd_final = inputs_dict["GVHD_FINAL"] = make_component("GVHD_FINAL")
|
| 365 |
hla_final = inputs_dict["HLA_FINAL"] = make_component("HLA_FINAL")
|
| 366 |
|
| 367 |
-
|
| 368 |
with gr.Column(scale=1):
|
| 369 |
gr.Markdown("### Disease Characteristics")
|
| 370 |
for f in DISEASE_FEATURES:
|
| 371 |
inputs_dict[f] = make_component(f)
|
| 372 |
|
| 373 |
-
|
| 374 |
inputs_dict["AGE"].change(
|
| 375 |
fn=get_age_group,
|
| 376 |
inputs=inputs_dict["AGE"],
|
|
@@ -383,7 +365,7 @@ with gr.Blocks(title="HCT Outcome Prediction Model", css=custom_css) as demo:
|
|
| 383 |
outputs=inputs_dict["VOCFRQPR"],
|
| 384 |
)
|
| 385 |
|
| 386 |
-
|
| 387 |
grouped_regimen_dropdown.change(
|
| 388 |
fn=apply_grouped_preset,
|
| 389 |
inputs=grouped_regimen_dropdown,
|
|
@@ -412,7 +394,7 @@ with gr.Blocks(title="HCT Outcome Prediction Model", css=custom_css) as demo:
|
|
| 412 |
gr.Markdown("---")
|
| 413 |
gr.Markdown("## Outcome Probability β Icon Arrays")
|
| 414 |
|
| 415 |
-
|
| 416 |
icon_array_grid = gr.HTML(label="")
|
| 417 |
|
| 418 |
gr.Markdown("---")
|
|
|
|
| 14 |
)
|
| 15 |
|
| 16 |
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
AGEGPFF_CHOICES = ["<=10", "11-17", "18-29", "30-49", ">=50"]
|
| 19 |
SEX_CHOICES = ["Male", "Female"]
|
|
|
|
| 53 |
]
|
| 54 |
|
| 55 |
|
| 56 |
+
|
|
|
|
|
|
|
| 57 |
|
| 58 |
GROUPED_REGIMEN_CHOICES = [
|
| 59 |
("ββ HLA IDENTICAL ββ", "__header_hla_identical__"),
|
|
|
|
| 132 |
}
|
| 133 |
|
| 134 |
|
| 135 |
+
|
|
|
|
|
|
|
| 136 |
|
| 137 |
PATIENT_FEATURES = ["AGE", "AGEGPFF", "SEX", "KPS", "RCMVPR"]
|
| 138 |
DONOR_FEATURES = ["DONORF", "GRAFTYPE", "HLA_FINAL",
|
|
|
|
| 141 |
ALL_FEATURES = PATIENT_FEATURES + DONOR_FEATURES + DISEASE_FEATURES
|
| 142 |
|
| 143 |
|
| 144 |
+
|
|
|
|
|
|
|
| 145 |
|
| 146 |
def get_age_group(age):
|
| 147 |
if age is None or age == "":
|
|
|
|
| 179 |
return [gr.update()] * 7
|
| 180 |
|
| 181 |
return [
|
| 182 |
+
gr.update(),
|
| 183 |
gr.update(value=preset["DONORF"]),
|
| 184 |
gr.update(value=preset["CONDGRPF"]),
|
| 185 |
gr.update(value=preset["CONDGRP_FINAL"]),
|
|
|
|
| 189 |
]
|
| 190 |
|
| 191 |
|
| 192 |
+
|
|
|
|
|
|
|
| 193 |
|
| 194 |
def make_component(name: str):
|
| 195 |
if name == "AGE":
|
|
|
|
| 236 |
return gr.Textbox(label=name)
|
| 237 |
|
| 238 |
|
| 239 |
+
|
|
|
|
|
|
|
| 240 |
|
| 241 |
def predict_gradio(*values):
|
| 242 |
try:
|
|
|
|
| 271 |
|
| 272 |
return (
|
| 273 |
df,
|
| 274 |
+
icon_arrays["__grid__"],
|
| 275 |
shap_plots["DEAD"],
|
| 276 |
shap_plots["GF"],
|
| 277 |
shap_plots["AGVHD"],
|
|
|
|
| 291 |
raise gr.Error(f"{type(e).__name__}: {str(e)}\n\nSee terminal for full traceback.")
|
| 292 |
|
| 293 |
|
| 294 |
+
|
|
|
|
|
|
|
| 295 |
|
| 296 |
custom_css = """
|
| 297 |
.predict-button {
|
|
|
|
| 307 |
}
|
| 308 |
"""
|
| 309 |
|
|
|
|
|
|
|
|
|
|
| 310 |
|
| 311 |
with gr.Blocks(title="HCT Outcome Prediction Model", css=custom_css) as demo:
|
| 312 |
gr.Markdown(
|
|
|
|
| 320 |
inputs_dict = {}
|
| 321 |
|
| 322 |
with gr.Row():
|
| 323 |
+
|
| 324 |
with gr.Column(scale=1):
|
| 325 |
gr.Markdown("### Patient Characteristics")
|
| 326 |
for f in PATIENT_FEATURES:
|
| 327 |
inputs_dict[f] = make_component(f)
|
| 328 |
|
| 329 |
+
|
| 330 |
with gr.Column(scale=1):
|
| 331 |
gr.Markdown("### Transplant Characteristics")
|
| 332 |
|
|
|
|
| 346 |
gvhd_final = inputs_dict["GVHD_FINAL"] = make_component("GVHD_FINAL")
|
| 347 |
hla_final = inputs_dict["HLA_FINAL"] = make_component("HLA_FINAL")
|
| 348 |
|
| 349 |
+
|
| 350 |
with gr.Column(scale=1):
|
| 351 |
gr.Markdown("### Disease Characteristics")
|
| 352 |
for f in DISEASE_FEATURES:
|
| 353 |
inputs_dict[f] = make_component(f)
|
| 354 |
|
| 355 |
+
|
| 356 |
inputs_dict["AGE"].change(
|
| 357 |
fn=get_age_group,
|
| 358 |
inputs=inputs_dict["AGE"],
|
|
|
|
| 365 |
outputs=inputs_dict["VOCFRQPR"],
|
| 366 |
)
|
| 367 |
|
| 368 |
+
|
| 369 |
grouped_regimen_dropdown.change(
|
| 370 |
fn=apply_grouped_preset,
|
| 371 |
inputs=grouped_regimen_dropdown,
|
|
|
|
| 394 |
gr.Markdown("---")
|
| 395 |
gr.Markdown("## Outcome Probability β Icon Arrays")
|
| 396 |
|
| 397 |
+
|
| 398 |
icon_array_grid = gr.HTML(label="")
|
| 399 |
|
| 400 |
gr.Markdown("---")
|