Spaces:
Sleeping
Sleeping
Commit
Β·
7ca35bb
1
Parent(s):
515dbe1
Display gifs
Browse files- app.py +101 -57
- requirements.txt +1 -1
app.py
CHANGED
|
@@ -213,15 +213,18 @@ def mae(image_id, mask_ratio, progress=gr.Progress()):
|
|
| 213 |
batch, reconstructed_dict, masks_dict = mae_inference(
|
| 214 |
batch, transform, model, mask_ratio
|
| 215 |
)
|
| 216 |
-
progress(1, desc="Plotting
|
| 217 |
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
|
|
|
|
|
|
|
|
|
| 225 |
|
| 226 |
|
| 227 |
def mae_tab():
|
|
@@ -229,14 +232,16 @@ def mae_tab():
|
|
| 229 |
gr.Markdown(
|
| 230 |
"""
|
| 231 |
This page demonstrates the masking and reconstruction process of the masked autoencoder. The model was trained with a mask ratio of 0.75 over 74,000 studies.
|
| 232 |
-
|
| 233 |
-
Visualisation may take a few seconds as we download the model weights, process the data, and render the plots.
|
| 234 |
"""
|
| 235 |
)
|
| 236 |
with gr.Row():
|
| 237 |
with gr.Column(scale=5):
|
| 238 |
gr.Markdown("## Reconstruction")
|
| 239 |
-
plot = gr.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
with gr.Column(scale=3):
|
| 241 |
gr.Markdown("## Data Settings")
|
| 242 |
image_id = gr.Slider(
|
|
@@ -254,6 +259,7 @@ def mae_tab():
|
|
| 254 |
value=0.75,
|
| 255 |
)
|
| 256 |
run_button = gr.Button("Run masked autoencoder", variant="primary")
|
|
|
|
| 257 |
run_button.click(
|
| 258 |
fn=mae,
|
| 259 |
inputs=[image_id, mask_ratio],
|
|
@@ -331,13 +337,27 @@ def segmentation_sax(trained_dataset, seed, image_id, t_step, progress=gr.Progre
|
|
| 331 |
images = images[..., ::t_step]
|
| 332 |
labels = segmentation_sax_inference(images, view, transform, model, progress)
|
| 333 |
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
|
| 342 |
|
| 343 |
def segmentation_sax_tab():
|
|
@@ -345,8 +365,6 @@ def segmentation_sax_tab():
|
|
| 345 |
gr.Markdown(
|
| 346 |
"""
|
| 347 |
This page demonstrates the segmentation of cardiac structures in the short-axis (SAX) view.
|
| 348 |
-
|
| 349 |
-
Visualisation may take dozens of seconds to update as we download model checkpoints, process multiple time frames sequentially, and generate the final plots.
|
| 350 |
"""
|
| 351 |
)
|
| 352 |
|
|
@@ -381,7 +399,7 @@ def segmentation_sax_tab():
|
|
| 381 |
maximum=10,
|
| 382 |
step=1,
|
| 383 |
label="Choose the gap between time frames",
|
| 384 |
-
value=
|
| 385 |
)
|
| 386 |
with gr.Column(scale=3):
|
| 387 |
gr.Markdown("## Model Settings")
|
|
@@ -401,16 +419,22 @@ def segmentation_sax_tab():
|
|
| 401 |
|
| 402 |
with gr.Row():
|
| 403 |
with gr.Column():
|
| 404 |
-
gr.
|
| 405 |
-
|
|
|
|
|
|
|
|
|
|
| 406 |
with gr.Column():
|
| 407 |
-
gr.
|
| 408 |
-
|
|
|
|
|
|
|
|
|
|
| 409 |
|
| 410 |
run_button.click(
|
| 411 |
fn=segmentation_sax,
|
| 412 |
inputs=[trained_dataset, seed, image_id, t_step],
|
| 413 |
-
outputs=[
|
| 414 |
)
|
| 415 |
return sax_interface
|
| 416 |
|
|
@@ -475,13 +499,17 @@ def segmentation_lax(seed, image_id, progress=gr.Progress()):
|
|
| 475 |
)
|
| 476 |
labels = segmentation_lax_inference(images, view, transform, model, progress)
|
| 477 |
|
| 478 |
-
progress(1, desc="Plotting
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 485 |
|
| 486 |
|
| 487 |
def segmentation_lax_tab():
|
|
@@ -489,8 +517,6 @@ def segmentation_lax_tab():
|
|
| 489 |
gr.Markdown(
|
| 490 |
"""
|
| 491 |
This page demonstrates the segmentation of cardiac structures in the long-axis (LAX) four-chamber (4C) view.
|
| 492 |
-
|
| 493 |
-
Visualisation may take a few seconds to update as we download model checkpoints, process multiple time frames, and generate the final plots.
|
| 494 |
"""
|
| 495 |
)
|
| 496 |
|
|
@@ -533,16 +559,22 @@ def segmentation_lax_tab():
|
|
| 533 |
|
| 534 |
with gr.Row():
|
| 535 |
with gr.Column():
|
| 536 |
-
gr.
|
| 537 |
-
|
|
|
|
|
|
|
|
|
|
| 538 |
with gr.Column():
|
| 539 |
-
gr.
|
| 540 |
-
|
|
|
|
|
|
|
|
|
|
| 541 |
|
| 542 |
run_button.click(
|
| 543 |
fn=segmentation_lax,
|
| 544 |
inputs=[seed, image_id],
|
| 545 |
-
outputs=[
|
| 546 |
)
|
| 547 |
return lax_interface
|
| 548 |
|
|
@@ -651,12 +683,19 @@ def landmark(image_id, view, method, seed, progress=gr.Progress()):
|
|
| 651 |
else:
|
| 652 |
raise ValueError(f"Invalid method: {method}")
|
| 653 |
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 660 |
|
| 661 |
|
| 662 |
def landmark_tab():
|
|
@@ -664,8 +703,6 @@ def landmark_tab():
|
|
| 664 |
gr.Markdown(
|
| 665 |
"""
|
| 666 |
This page demonstrates landmark localisation in the long-axis (LAX) two-chamber (2C) and four-chamber (4C) views.
|
| 667 |
-
|
| 668 |
-
Visualisation may take a few seconds to update as we download model checkpoints, process multiple time frames, and generate the final plots.
|
| 669 |
"""
|
| 670 |
)
|
| 671 |
|
|
@@ -679,7 +716,7 @@ def landmark_tab():
|
|
| 679 |
|
| 680 |
### Model
|
| 681 |
|
| 682 |
-
The available models are finetuned on data from [Xue et al.](https://pubs.rsna.org/doi/10.1148/ryai.2021200197)
|
| 683 |
There are two types of landmark localisation models:
|
| 684 |
|
| 685 |
- **Heatmap**: predicts dense probability maps of landmarks
|
|
@@ -726,16 +763,22 @@ def landmark_tab():
|
|
| 726 |
|
| 727 |
with gr.Row():
|
| 728 |
with gr.Column():
|
| 729 |
-
gr.
|
| 730 |
-
|
|
|
|
|
|
|
|
|
|
| 731 |
with gr.Column():
|
| 732 |
-
gr.
|
| 733 |
-
|
|
|
|
|
|
|
|
|
|
| 734 |
|
| 735 |
run_button.click(
|
| 736 |
fn=landmark,
|
| 737 |
inputs=[image_id, view, method, seed],
|
| 738 |
-
outputs=[
|
| 739 |
)
|
| 740 |
return landmark_interface
|
| 741 |
|
|
@@ -747,8 +790,9 @@ with gr.Blocks(
|
|
| 747 |
"""
|
| 748 |
# CineMA: A Foundation Model for Cine Cardiac MRI π₯π«
|
| 749 |
|
| 750 |
-
The following demos showcase the capabilities of CineMA in multiple tasks
|
| 751 |
-
|
|
|
|
| 752 |
"""
|
| 753 |
)
|
| 754 |
|
|
@@ -763,4 +807,4 @@ with gr.Blocks(
|
|
| 763 |
segmentation_lax_tab()
|
| 764 |
with gr.TabItem("π Landmark Localisation in LAX 2C/4C View", id="landmark"):
|
| 765 |
landmark_tab()
|
| 766 |
-
demo.launch()
|
|
|
|
| 213 |
batch, reconstructed_dict, masks_dict = mae_inference(
|
| 214 |
batch, transform, model, mask_ratio
|
| 215 |
)
|
| 216 |
+
progress(1, desc="Inference finished. Plotting ...")
|
| 217 |
|
| 218 |
+
# (y, x, z) -> (x, y, z)
|
| 219 |
+
batch["sax"] = np.transpose(batch["sax"], (1, 0, 2))
|
| 220 |
+
reconstructed_dict["sax"] = np.transpose(reconstructed_dict["sax"], (1, 0, 2))
|
| 221 |
+
masks_dict["sax"] = np.transpose(masks_dict["sax"], (1, 0, 2))
|
| 222 |
+
|
| 223 |
+
# Plot MAE reconstruction and save to file
|
| 224 |
+
mae_path = cache_dir / f"mae_image{image_id}_mask{mask_ratio:.2f}.png"
|
| 225 |
+
plot_mae_reconstruction(batch, reconstructed_dict, masks_dict, mae_path)
|
| 226 |
+
|
| 227 |
+
return str(mae_path)
|
| 228 |
|
| 229 |
|
| 230 |
def mae_tab():
|
|
|
|
| 232 |
gr.Markdown(
|
| 233 |
"""
|
| 234 |
This page demonstrates the masking and reconstruction process of the masked autoencoder. The model was trained with a mask ratio of 0.75 over 74,000 studies.
|
|
|
|
|
|
|
| 235 |
"""
|
| 236 |
)
|
| 237 |
with gr.Row():
|
| 238 |
with gr.Column(scale=5):
|
| 239 |
gr.Markdown("## Reconstruction")
|
| 240 |
+
plot = gr.Image(
|
| 241 |
+
show_label=False,
|
| 242 |
+
type="filepath",
|
| 243 |
+
label="Masked Autoencoder Reconstruction",
|
| 244 |
+
)
|
| 245 |
with gr.Column(scale=3):
|
| 246 |
gr.Markdown("## Data Settings")
|
| 247 |
image_id = gr.Slider(
|
|
|
|
| 259 |
value=0.75,
|
| 260 |
)
|
| 261 |
run_button = gr.Button("Run masked autoencoder", variant="primary")
|
| 262 |
+
|
| 263 |
run_button.click(
|
| 264 |
fn=mae,
|
| 265 |
inputs=[image_id, mask_ratio],
|
|
|
|
| 337 |
images = images[..., ::t_step]
|
| 338 |
labels = segmentation_sax_inference(images, view, transform, model, progress)
|
| 339 |
|
| 340 |
+
# (y, x, z, t) -> (x, y, z, t)
|
| 341 |
+
images = np.transpose(images, (1, 0, 2, 3))
|
| 342 |
+
labels = np.transpose(labels, (1, 0, 2, 3))
|
| 343 |
+
|
| 344 |
+
progress(1, desc="Inference finished. Plotting ...")
|
| 345 |
+
|
| 346 |
+
# Create file paths for saving plots
|
| 347 |
+
seg_path = (
|
| 348 |
+
cache_dir
|
| 349 |
+
/ f"sax_segmentation_{trained_dataset}_image{image_id}_seed{seed}_tstep{t_step}.gif"
|
| 350 |
+
)
|
| 351 |
+
vol_path = (
|
| 352 |
+
cache_dir
|
| 353 |
+
/ f"sax_volume_{trained_dataset}_image{image_id}_seed{seed}_tstep{t_step}.png"
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
# Plot segmentations and volume changes with file paths
|
| 357 |
+
plot_segmentations_sax(images, labels, seg_path)
|
| 358 |
+
plot_volume_changes_sax(labels, t_step, vol_path)
|
| 359 |
+
|
| 360 |
+
return (str(seg_path), str(vol_path))
|
| 361 |
|
| 362 |
|
| 363 |
def segmentation_sax_tab():
|
|
|
|
| 365 |
gr.Markdown(
|
| 366 |
"""
|
| 367 |
This page demonstrates the segmentation of cardiac structures in the short-axis (SAX) view.
|
|
|
|
|
|
|
| 368 |
"""
|
| 369 |
)
|
| 370 |
|
|
|
|
| 399 |
maximum=10,
|
| 400 |
step=1,
|
| 401 |
label="Choose the gap between time frames",
|
| 402 |
+
value=3,
|
| 403 |
)
|
| 404 |
with gr.Column(scale=3):
|
| 405 |
gr.Markdown("## Model Settings")
|
|
|
|
| 419 |
|
| 420 |
with gr.Row():
|
| 421 |
with gr.Column():
|
| 422 |
+
segmentation_gif = gr.Image(
|
| 423 |
+
show_label=True,
|
| 424 |
+
type="filepath",
|
| 425 |
+
label="Ventricle and Myocardium Segmentation",
|
| 426 |
+
)
|
| 427 |
with gr.Column():
|
| 428 |
+
volume_plot = gr.Image(
|
| 429 |
+
show_label=True,
|
| 430 |
+
type="filepath",
|
| 431 |
+
label="Ejection Fraction Estimation",
|
| 432 |
+
)
|
| 433 |
|
| 434 |
run_button.click(
|
| 435 |
fn=segmentation_sax,
|
| 436 |
inputs=[trained_dataset, seed, image_id, t_step],
|
| 437 |
+
outputs=[segmentation_gif, volume_plot],
|
| 438 |
)
|
| 439 |
return sax_interface
|
| 440 |
|
|
|
|
| 499 |
)
|
| 500 |
labels = segmentation_lax_inference(images, view, transform, model, progress)
|
| 501 |
|
| 502 |
+
progress(1, desc="Inference finished. Plotting ...")
|
| 503 |
+
|
| 504 |
+
# Plot segmentations and save as GIF
|
| 505 |
+
seg_path = cache_dir / f"lax_segmentation_image{image_id}_seed{seed}.gif"
|
| 506 |
+
plot_segmentations_lax(images, labels, seg_path)
|
| 507 |
+
|
| 508 |
+
# Plot volume changes and save as figure
|
| 509 |
+
vol_path = cache_dir / f"lax_volume_image{image_id}_seed{seed}.png"
|
| 510 |
+
plot_volume_changes_lax(labels, vol_path)
|
| 511 |
+
|
| 512 |
+
return (str(seg_path), str(vol_path))
|
| 513 |
|
| 514 |
|
| 515 |
def segmentation_lax_tab():
|
|
|
|
| 517 |
gr.Markdown(
|
| 518 |
"""
|
| 519 |
This page demonstrates the segmentation of cardiac structures in the long-axis (LAX) four-chamber (4C) view.
|
|
|
|
|
|
|
| 520 |
"""
|
| 521 |
)
|
| 522 |
|
|
|
|
| 559 |
|
| 560 |
with gr.Row():
|
| 561 |
with gr.Column():
|
| 562 |
+
segmentation_gif = gr.Image(
|
| 563 |
+
show_label=True,
|
| 564 |
+
type="filepath",
|
| 565 |
+
label="Ventricle and Myocardium Segmentation",
|
| 566 |
+
)
|
| 567 |
with gr.Column():
|
| 568 |
+
volume_plot = gr.Image(
|
| 569 |
+
show_label=True,
|
| 570 |
+
type="filepath",
|
| 571 |
+
label="Ejection Fraction Prediction",
|
| 572 |
+
)
|
| 573 |
|
| 574 |
run_button.click(
|
| 575 |
fn=segmentation_lax,
|
| 576 |
inputs=[seed, image_id],
|
| 577 |
+
outputs=[segmentation_gif, volume_plot],
|
| 578 |
)
|
| 579 |
return lax_interface
|
| 580 |
|
|
|
|
| 683 |
else:
|
| 684 |
raise ValueError(f"Invalid method: {method}")
|
| 685 |
|
| 686 |
+
progress(1, desc="Inference finished. Plotting ...")
|
| 687 |
+
|
| 688 |
+
# Plot landmarks in GIF
|
| 689 |
+
landmark_path = (
|
| 690 |
+
cache_dir / f"landmark_{view}_image{image_id}_{method}_seed{seed}.gif"
|
| 691 |
+
)
|
| 692 |
+
plot_landmarks(images, coords, landmark_path)
|
| 693 |
+
|
| 694 |
+
# Plot LV change in PNG
|
| 695 |
+
lv_path = cache_dir / f"lv_{view}_image{image_id}_{method}_seed{seed}.png"
|
| 696 |
+
plot_lv(coords, lv_path)
|
| 697 |
+
|
| 698 |
+
return (str(landmark_path), str(lv_path))
|
| 699 |
|
| 700 |
|
| 701 |
def landmark_tab():
|
|
|
|
| 703 |
gr.Markdown(
|
| 704 |
"""
|
| 705 |
This page demonstrates landmark localisation in the long-axis (LAX) two-chamber (2C) and four-chamber (4C) views.
|
|
|
|
|
|
|
| 706 |
"""
|
| 707 |
)
|
| 708 |
|
|
|
|
| 716 |
|
| 717 |
### Model
|
| 718 |
|
| 719 |
+
The available models are finetuned on data from [Xue et al.](https://pubs.rsna.org/doi/10.1148/ryai.2021200197)
|
| 720 |
There are two types of landmark localisation models:
|
| 721 |
|
| 722 |
- **Heatmap**: predicts dense probability maps of landmarks
|
|
|
|
| 763 |
|
| 764 |
with gr.Row():
|
| 765 |
with gr.Column():
|
| 766 |
+
landmark_gif = gr.Image(
|
| 767 |
+
show_label=True,
|
| 768 |
+
type="filepath",
|
| 769 |
+
label="Landmark Localisation",
|
| 770 |
+
)
|
| 771 |
with gr.Column():
|
| 772 |
+
lv_plot = gr.Image(
|
| 773 |
+
show_label=True,
|
| 774 |
+
type="filepath",
|
| 775 |
+
label="Left Ventricle Length Estimation",
|
| 776 |
+
)
|
| 777 |
|
| 778 |
run_button.click(
|
| 779 |
fn=landmark,
|
| 780 |
inputs=[image_id, view, method, seed],
|
| 781 |
+
outputs=[landmark_gif, lv_plot],
|
| 782 |
)
|
| 783 |
return landmark_interface
|
| 784 |
|
|
|
|
| 790 |
"""
|
| 791 |
# CineMA: A Foundation Model for Cine Cardiac MRI π₯π«
|
| 792 |
|
| 793 |
+
π The following demos showcase the capabilities of CineMA in multiple tasks.<br>
|
| 794 |
+
β±οΈ The examples may take 10-60 seconds to download data and model, perform inference, and render plots.<br>
|
| 795 |
+
π For more details, check out our [GitHub](https://github.com/mathpluscode/CineMA).
|
| 796 |
"""
|
| 797 |
)
|
| 798 |
|
|
|
|
| 807 |
segmentation_lax_tab()
|
| 808 |
with gr.TabItem("π Landmark Localisation in LAX 2C/4C View", id="landmark"):
|
| 809 |
landmark_tab()
|
| 810 |
+
demo.launch(allowed_paths=[cache_dir])
|
requirements.txt
CHANGED
|
@@ -17,6 +17,6 @@ scikit-learn==1.6.1
|
|
| 17 |
scipy==1.15.2
|
| 18 |
spaces==0.36.0
|
| 19 |
timm==1.0.15
|
| 20 |
-
git+https://github.com/mathpluscode/CineMA@
|
| 21 |
--extra-index-url https://download.pytorch.org/whl/cu113
|
| 22 |
torch==2.5.1
|
|
|
|
| 17 |
scipy==1.15.2
|
| 18 |
spaces==0.36.0
|
| 19 |
timm==1.0.15
|
| 20 |
+
git+https://github.com/mathpluscode/CineMA@7e86ffc7ddf06ad7283915ee143ed808c0f59576#egg=cinema
|
| 21 |
--extra-index-url https://download.pytorch.org/whl/cu113
|
| 22 |
torch==2.5.1
|