Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| import gradio as gr | |
| import numpy as np | |
| import requests | |
| import SimpleITK as sitk # noqa: N813 | |
| import spaces | |
| import torch | |
| from cinema import CineMA, ConvUNetR | |
| from cinema.examples.cine_cmr import plot_cmr_views | |
| from cinema.examples.inference.mae import plot_mae_reconstruction, reconstruct_images | |
| from cinema.examples.inference.segmentation_lax_4c import ( | |
| plot_segmentations as plot_segmentations_lax, | |
| ) | |
| from cinema.examples.inference.segmentation_lax_4c import ( | |
| plot_volume_changes as plot_volume_changes_lax, | |
| ) | |
| from cinema.examples.inference.segmentation_lax_4c import ( | |
| post_process as post_process_lax_segmentation, | |
| ) | |
| from cinema.examples.inference.segmentation_sax import ( | |
| plot_segmentations as plot_segmentations_sax, | |
| ) | |
| from cinema.examples.inference.segmentation_sax import ( | |
| plot_volume_changes as plot_volume_changes_sax, | |
| ) | |
| from huggingface_hub import hf_hub_download | |
| from monai.transforms import Compose, ScaleIntensityd, SpatialPadd | |
| from tqdm import tqdm | |
# Local directory used to cache downloaded example data and model files.
cache_dir = Path("/tmp/.cinema")
cache_dir.mkdir(parents=True, exist_ok=True)

# Pick the compute device and tensor dtype: CPU/float32 by default, CUDA when
# available, and bfloat16 only when the CUDA device reports support for it.
if torch.cuda.is_available():
    device = torch.device("cuda")
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else:
    device = torch.device("cpu")
    dtype = torch.float32

# Theme applied to the top-level Gradio Blocks.
theme = gr.themes.Ocean(primary_hue="red", secondary_hue="purple")
def load_nifti_from_github(name: str) -> sitk.Image:
    """Load an example NIfTI image, downloading it from GitHub on first use.

    The file is stored under ``cache_dir / name``; later calls read the cached
    copy instead of downloading again.

    Args:
        name: Path of the file relative to the ``cinema/examples/data``
            directory of the CineMA repository, e.g. ``"ukb/1/1_sax.nii.gz"``.

    Returns:
        The image read with SimpleITK.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection problems or timeout.
    """
    path = cache_dir / name
    if not path.exists():
        image_url = f"https://raw.githubusercontent.com/mathpluscode/CineMA/main/cinema/examples/data/{name}"
        # A timeout keeps a stalled connection from blocking the request forever.
        response = requests.get(image_url, timeout=60)
        # Fail loudly on 404/5xx: otherwise the error page would be written to
        # the cache as a ".nii.gz" file and every later call would reuse it.
        response.raise_for_status()
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_bytes(response.content)
    return sitk.ReadImage(path)
def cmr_tab():
    """Build the "Cine CMR Views" tab.

    The tab contains one plot plus two sliders: one selects the example image,
    the other selects the SAX slice to display.

    Returns:
        The ``gr.Blocks`` container holding the tab.
    """
    with gr.Blocks() as cmr_interface:
        gr.Markdown(
            """
            This page illustrates the spatial orientation of short-axis (SAX) and long-axis (LAX) views in 3D. Use the control panels on the right to select specific images and slices.
            """
        )
        with gr.Row():
            with gr.Column(scale=5):
                gr.Markdown("## Views")
                cmr_plot = gr.Plot(show_label=False)
            with gr.Column(scale=3):
                gr.Markdown("## Data Settings")
                image_id = gr.Slider(
                    label="Choose an image, ID is between 1 and 4",
                    minimum=1,
                    maximum=4,
                    step=1,
                    value=1,
                )
                # The upper bound here is only a starting value; it is replaced
                # with the real slice count whenever the image changes.
                slice_idx = gr.Slider(
                    label="SAX slice to visualize",
                    minimum=0,
                    maximum=8,
                    step=1,
                    value=0,
                )

        def update_slice_slider(image_id):
            """Reset the slice slider to slice 0 and fit its range to the image."""
            sax_image = load_nifti_from_github(f"ukb/{image_id}/{image_id}_sax.nii.gz")
            last_slice = sax_image.GetSize()[2] - 1
            return gr.update(maximum=last_slice, value=0, visible=True)

        def fn(image_id, slice_idx):
            """Load the four views of one example and plot them together."""
            # Order matches the positional parameters passed to plot_cmr_views.
            views = [
                load_nifti_from_github(f"ukb/{image_id}/{image_id}_{view}.nii.gz")
                for view in ("lax_2c", "lax_3c", "lax_4c", "sax")
            ]
            fig = plot_cmr_views(*views, t_to_show=4, depth_to_show=slice_idx)
            fig.update_layout(height=600)
            return fig

        # No explicit triggers are given: the image slider is the only input, and
        # this handler refreshes both the slice slider and the plot (slice 0).
        gr.on(
            fn=lambda image_id: [update_slice_slider(image_id), fn(image_id, 0)],
            inputs=[image_id],
            outputs=[slice_idx, cmr_plot],
        )

        # Moving the slice slider only redraws the plot.
        slice_idx.change(
            fn=fn,
            inputs=[image_id, slice_idx],
            outputs=[cmr_plot],
        )
    return cmr_interface
def mae_inference(
    batch: dict[str, torch.Tensor],
    transform: Compose,
    model: CineMA,
    mask_ratio: float,
) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray], dict[str, np.ndarray]]:
    """Run the masked autoencoder on one sample and collect arrays for plotting.

    Args:
        batch: Per-view input tensors; must contain the key ``"sax"``.
        transform: Preprocessing applied to ``batch`` before inference.
        model: The masked autoencoder; it is moved to the module-level ``device``.
        mask_ratio: Passed to the model as ``enc_mask_ratio``.

    Returns:
        A tuple ``(images, reconstructed, masks)``: the preprocessed inputs as
        NumPy arrays (SAX cropped back to its original number of slices along
        the last axis), followed by the two dictionaries returned by
        ``reconstruct_images``.
    """
    model.to(device)

    # Remember the SAX depth before the transform runs so it can be restored.
    sax_slices = batch["sax"].shape[-1]

    inputs = {}
    for key, value in transform(batch).items():
        # Add a leading batch axis and move to the inference device/dtype.
        inputs[key] = value[None, ...].to(device=device, dtype=dtype)

    use_cuda = torch.cuda.is_available()
    with torch.no_grad(), torch.autocast("cuda", dtype=dtype, enabled=use_cuda):
        _, pred_dict, enc_mask_dict, _ = model(inputs, enc_mask_ratio=mask_ratio)
        grid_size_dict = {}
        for key, encoder in model.enc_down_dict.items():
            grid_size_dict[key] = encoder.patch_embed.grid_size
        reconstructed_dict, masks_dict = reconstruct_images(
            inputs,
            pred_dict,
            enc_mask_dict,
            model.dec_patch_size_dict,
            grid_size_dict,
            sax_slices,
        )

    # Convert the inputs to float32 NumPy arrays, dropping the first two axes.
    images = {}
    for key, value in inputs.items():
        images[key] = value.detach().to(torch.float32).cpu().numpy()[0, 0]
    images["sax"] = images["sax"][..., :sax_slices]
    return images, reconstructed_dict, masks_dict
def mae(image_id, mask_ratio, progress=gr.Progress()):
    """Gradio callback: mask and reconstruct one example with the pretrained MAE.

    Args:
        image_id: ID of the example study to load.
        mask_ratio: Mask ratio forwarded to ``mae_inference``.
        progress: Gradio progress tracker used to report the current stage.

    Returns:
        The figure produced by ``plot_mae_reconstruction``.
    """
    frame = 4  # which time frame to use
    lax_views = ("lax_2c", "lax_3c", "lax_4c")

    progress(0, desc="Downloading model...")
    model = CineMA.from_pretrained()
    model.eval()

    progress(0, desc="Downloading data...")
    arrays = {}
    for view in (*lax_views, "sax"):
        image = load_nifti_from_github(f"ukb/{image_id}/{image_id}_{view}.nii.gz")
        # GetArrayFromImage returns the axes reversed; transpose flips them back.
        arrays[view] = np.transpose(sitk.GetArrayFromImage(image))

    transform = Compose(
        [
            ScaleIntensityd(keys=("sax", *lax_views)),
            SpatialPadd(keys="sax", spatial_size=(192, 192, 16), method="end"),
            SpatialPadd(keys=lax_views, spatial_size=(256, 256), method="end"),
        ]
    )

    # Select the chosen frame and prepend a leading axis; the LAX arrays are
    # additionally indexed at 0 on their second-to-last axis.
    batch = {"sax": torch.from_numpy(arrays["sax"][None, ..., frame])}
    for view in lax_views:
        batch[view] = torch.from_numpy(arrays[view][None, ..., 0, frame])

    progress(0.5, desc="Running inference...")
    batch, reconstructed_dict, masks_dict = mae_inference(
        batch, transform, model, mask_ratio
    )

    progress(1, desc="Plotting results...")
    return plot_mae_reconstruction(batch, reconstructed_dict, masks_dict)
def mae_tab():
    """Build the "Masked Autoencoder" tab.

    The tab has a plot on the left and, on the right, sliders for the image ID
    and mask ratio plus a button that runs ``mae``.

    Returns:
        The ``gr.Blocks`` container holding the tab.
    """
    with gr.Blocks() as mae_interface:
        gr.Markdown(
            """
            This page illustrates the masking and reconstruction process of the masked autoencoder. The model was trained with mask ratio 0.75 over 74,000 studies.
            """
        )
        with gr.Row():
            with gr.Column(scale=5):
                gr.Markdown("## Reconstruction")
                plot = gr.Plot(show_label=False)
            with gr.Column(scale=3):
                gr.Markdown("## Data Settings")
                image_id = gr.Slider(
                    label="Choose an image, ID is between 1 and 4",
                    minimum=1,
                    maximum=4,
                    step=1,
                    value=1,
                )
                mask_ratio = gr.Slider(
                    label="Mask ratio",
                    minimum=0.05,
                    maximum=1,
                    step=0.05,
                    value=0.75,
                )
                run_button = gr.Button("Run Masked Autoencoder", variant="primary")

        # Clicking the button runs the reconstruction and shows the figure.
        run_button.click(fn=mae, inputs=[image_id, mask_ratio], outputs=[plot])
    return mae_interface
def segmentation_sax_inference(
    images: np.ndarray,
    view: str,
    transform: Compose,
    model: ConvUNetR,
    progress=gr.Progress(),
) -> np.ndarray:
    """Segment every time frame of a SAX sequence.

    Args:
        images: NumPy array (it is passed to ``torch.from_numpy``) whose last
            two axes are slices and time frames.
        view: Key under which each frame is passed to ``transform`` and the
            model, and under which the logits are read from the model output.
        transform: Preprocessing applied to each frame.
        model: Segmentation model; it is moved to the module-level ``device``.
        progress: Gradio progress tracker, updated once per frame.

    Returns:
        Float32 array of per-frame label maps stacked along a new last axis,
        cropped to the original number of slices.
    """
    model.to(device)
    n_slices, n_frames = images.shape[-2:]
    use_cuda = torch.cuda.is_available()
    labels_list = []
    for t in tqdm(range(n_frames), total=n_frames):
        progress((t + 1) / n_frames, desc=f"Processing frame {t + 1} / {n_frames}...")
        # Take frame t and prepend a leading axis before preprocessing.
        batch = transform({view: torch.from_numpy(images[None, ..., t])})
        batch = {
            k: v[None, ...].to(device=device, dtype=torch.float32)
            for k, v in batch.items()
        }
        with torch.no_grad(), torch.autocast("cuda", dtype=dtype, enabled=use_cuda):
            logits = model(batch)[view]
        # Argmax over dim 1 gives the labels; drop any slices beyond the
        # original slice count.
        labels_list.append(torch.argmax(logits, dim=1)[0, ..., :n_slices])
    labels = torch.stack(labels_list, dim=-1).detach().to(torch.float32).cpu().numpy()
    return labels
def segmentation_sax(trained_dataset, seed, image_id, t_step, progress=gr.Progress()):
    """Gradio callback: segment one ACDC SAX study with a finetuned model.

    Args:
        trained_dataset: Display name of the finetuning dataset; one of
            ``"ACDC"``, ``"M&MS"`` or ``"M&MS2"``.
        seed: Finetuning seed used to select the checkpoint file.
        image_id: ACDC patient ID; IDs up to 100 are read from the ``train``
            folder, larger ones from ``test``.
        t_step: Only every ``t_step``-th time frame is segmented.
        progress: Gradio progress tracker used to report the current stage.

    Returns:
        A tuple ``(segmentation_figure, volume_figure)``.

    Raises:
        KeyError: If ``trained_dataset`` is not one of the supported names.
    """
    # Fixed parameters
    view = "sax"

    # Slider values may arrive as whole-number floats, but the ``:03d`` format,
    # the checkpoint file name and the slice step below all need real integers.
    seed, image_id, t_step = int(seed), int(image_id), int(t_step)

    split = "train" if image_id <= 100 else "test"
    trained_dataset = {
        "ACDC": "acdc",
        "M&MS": "mnms",
        "M&MS2": "mnms2",
    }[str(trained_dataset)]

    # Download data; the progress message now matches the step being executed.
    progress(0, desc="Downloading data...")
    image_path = hf_hub_download(
        repo_id="mathpluscode/ACDC",
        repo_type="dataset",
        filename=f"{split}/patient{image_id:03d}/patient{image_id:03d}_sax_t.nii.gz",
        cache_dir=cache_dir,
    )

    # Download and load model
    progress(0, desc="Downloading model...")
    model = ConvUNetR.from_finetuned(
        repo_id="mathpluscode/CineMA",
        model_filename=f"finetuned/segmentation/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
        config_filename=f"finetuned/segmentation/{trained_dataset}_{view}/config.yaml",
        cache_dir=cache_dir,
    )
    model.eval()

    # Inference
    transform = Compose(
        [
            ScaleIntensityd(keys=view),
            SpatialPadd(keys=view, spatial_size=(192, 192, 16), method="end"),
        ]
    )
    images = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(image_path)))
    # Keep every t_step-th frame along the last axis.
    images = images[..., ::t_step]
    labels = segmentation_sax_inference(images, view, transform, model, progress)

    progress(1, desc="Plotting results...")
    fig1 = plot_segmentations_sax(images, labels, t_step)
    fig2 = plot_volume_changes_sax(labels, t_step)
    return fig1, fig2
def segmentation_sax_tab():
    """Build the "Segmentation in SAX View" tab.

    The top row holds a description, the data settings (image ID, frame gap)
    and the model settings (finetuning dataset, seed) with the run button; the
    bottom row holds the two result plots filled by ``segmentation_sax``.

    Returns:
        The ``gr.Blocks`` container holding the tab.
    """
    with gr.Blocks() as sax_interface:
        gr.Markdown(
            """
            This page demonstrates the segmentation of cardiac structures in the Short-Axis (SAX) view.
            Please adjust the settings on the right panels and click the button to run the inference.
            """
        )
        with gr.Row():
            with gr.Column(scale=4):
                # Text fixes: restored the garbled multiplication signs and
                # corrected the spelling of "myocardium".
                gr.Markdown("""
                    ## Description
                    ### Data
                    The available data is from ACDC. All images have been resampled to 1 mm × 1 mm × 10 mm and centre-cropped to 192 mm × 192 mm for each SAX slice.
                    Image 101 - 150 are from the test set.
                    ### Model
                    The available models are finetuned on different datasets ([ACDC](https://www.creatis.insa-lyon.fr/Challenge/acdc/), [M&Ms](https://www.ub.edu/mnms/), and [M&Ms2](https://www.ub.edu/mnms-2/)). For each dataset, there are 3 models finetuned on different seeds: 0, 1, 2.
                    ### Visualization
                    The left figure shows the segmentation of ventricles and myocardium every n time steps across all SAX slices.
                    The right figure plots the ventricle and myocardium volumes across all inference time frames.
                    """)
            with gr.Column(scale=3):
                gr.Markdown("## Data Settings")
                image_id = gr.Slider(
                    minimum=101,
                    maximum=150,
                    step=1,
                    label="Choose an ACDC image, ID is between 101 and 150",
                    value=101,
                )
                t_step = gr.Slider(
                    minimum=1,
                    maximum=10,
                    step=1,
                    label="Choose the gap between time frames",
                    value=2,
                )
            with gr.Column(scale=3):
                gr.Markdown("## Model Setting")
                trained_dataset = gr.Dropdown(
                    choices=["ACDC", "M&MS", "M&MS2"],
                    label="Choose which dataset the segmentation model was finetuned on",
                    value="ACDC",
                )
                seed = gr.Slider(
                    minimum=0,
                    maximum=2,
                    step=1,
                    label="Choose which seed the finetuning used",
                    value=0,
                )
                run_button = gr.Button("Run SAX segmentation inference", variant="primary")
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Ventricle and Myocardium Segmentation")
                segmentation_plot = gr.Plot(show_label=False)
            with gr.Column():
                gr.Markdown("## Ejection Fraction Prediction")
                volume_plot = gr.Plot(show_label=False)

        # Input order must match the positional parameters of segmentation_sax.
        run_button.click(
            fn=segmentation_sax,
            inputs=[trained_dataset, seed, image_id, t_step],
            outputs=[segmentation_plot, volume_plot],
        )
    return sax_interface
def segmentation_lax_inference(
    images: np.ndarray,
    view: str,
    transform: Compose,
    model: ConvUNetR,
    progress=gr.Progress(),
) -> np.ndarray:
    """Segment every time frame of a LAX sequence.

    Args:
        images: NumPy array (it is passed to ``torch.from_numpy``) whose last
            axis is time; the second-to-last axis is indexed at 0.
        view: Key under which each frame is passed to ``transform`` and the
            model, and under which the logits are read from the model output.
        transform: Preprocessing applied to each frame.
        model: Segmentation model; it is moved to the module-level ``device``.
        progress: Gradio progress tracker, updated once per frame.

    Returns:
        Post-processed per-frame label maps stacked along a new last axis.
    """
    model.to(device)
    n_frames = images.shape[-1]
    use_cuda = torch.cuda.is_available()
    labels_list = []
    for t in tqdm(range(n_frames), total=n_frames):
        progress((t + 1) / n_frames, desc=f"Processing frame {t + 1} / {n_frames}...")
        # Take frame t (index 0 on the second-to-last axis) with a leading axis.
        batch = transform({view: torch.from_numpy(images[None, ..., 0, t])})
        batch = {
            k: v[None, ...].to(device=device, dtype=dtype) for k, v in batch.items()
        }
        with torch.no_grad(), torch.autocast("cuda", dtype=dtype, enabled=use_cuda):
            logits = model(batch)[view]  # (1, 4, x, y)
        labels = (
            torch.argmax(logits, dim=1)[0].detach().to(torch.float32).cpu().numpy()
        )  # (x, y)

        # the model seems to hallucinate an additional right ventricle and myocardium sometimes
        # find the connected component that is closest to left ventricle
        labels = post_process_lax_segmentation(labels)
        labels_list.append(labels)
    labels = np.stack(labels_list, axis=-1)  # (x, y, t)
    return labels
def segmentation_lax(seed, image_id, progress=gr.Progress()):
    """Gradio callback: segment one example LAX four-chamber study.

    Args:
        seed: Finetuning seed used to select the checkpoint file.
        image_id: ID of the example study to load.
        progress: Gradio progress tracker used to report the current stage.

    Returns:
        A tuple ``(segmentation_figure, volume_figure)``.
    """
    # Fixed parameters: only the M&Ms2 four-chamber model is used here.
    view = "lax_4c"
    trained_dataset = "mnms2"
    model_dir = f"finetuned/segmentation/{trained_dataset}_{view}"

    # Fetch the checkpoint and put the model in evaluation mode.
    progress(0, desc="Downloading model...")
    model = ConvUNetR.from_finetuned(
        repo_id="mathpluscode/CineMA",
        model_filename=f"{model_dir}/{trained_dataset}_{view}_{seed}.safetensors",
        config_filename=f"{model_dir}/config.yaml",
        cache_dir=cache_dir,
    )
    model.eval()

    # Fetch the image and run per-frame inference.
    progress(0, desc="Downloading data...")
    transform = ScaleIntensityd(keys=view)
    nifti = load_nifti_from_github(f"ukb/{image_id}/{image_id}_{view}.nii.gz")
    images = np.transpose(sitk.GetArrayFromImage(nifti))
    labels = segmentation_lax_inference(images, view, transform, model, progress)

    progress(1, desc="Plotting results...")
    return plot_segmentations_lax(images, labels), plot_volume_changes_lax(labels)
def segmentation_lax_tab():
    """Build the "Segmentation in LAX View" tab.

    The top row holds a description, the data settings (image ID) and the
    model settings (seed) with the run button; the bottom row holds the two
    result plots filled by ``segmentation_lax``.

    Returns:
        The ``gr.Blocks`` container holding the tab.
    """
    with gr.Blocks() as lax_interface:
        gr.Markdown(
            """
            This page demonstrates the segmentation of cardiac structures in the Long-Axis (LAX) view.
            Please adjust the settings on the right panels and click the button to run the inference.
            """
        )
        with gr.Row():
            with gr.Column(scale=4):
                # Text fixes: restored the garbled multiplication sign and
                # corrected the spelling of "myocardium".
                gr.Markdown("""
                    ## Description
                    ### Data
                    There are four example samples. All images have been resampled to 1 mm × 1 mm and centre-cropped.
                    ### Model
                    The available models are finetuned on [M&Ms2](https://www.ub.edu/mnms-2/). For each dataset, there are 3 models finetuned on different seeds: 0, 1, 2.
                    ### Visualization
                    The left figure shows the segmentation of ventricles and myocardium across all time frames.
                    The right figure plots the ventricle and myocardium volumes across all inference time frames.
                    """)
            with gr.Column(scale=3):
                gr.Markdown("## Data Settings")
                image_id = gr.Slider(
                    minimum=1,
                    maximum=4,
                    step=1,
                    label="Choose an image, ID is between 1 and 4",
                    value=4,
                )
            with gr.Column(scale=3):
                gr.Markdown("## Model Setting")
                seed = gr.Slider(
                    minimum=0,
                    maximum=2,
                    step=1,
                    label="Choose which seed the finetuning used",
                    value=0,
                )
                run_button = gr.Button("Run LAX segmentation inference", variant="primary")
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Ventricle and Myocardium Segmentation")
                segmentation_plot = gr.Plot(show_label=False)
            with gr.Column():
                gr.Markdown("## Ejection Fraction Prediction")
                volume_plot = gr.Plot(show_label=False)

        # Input order must match the positional parameters of segmentation_lax.
        run_button.click(
            fn=segmentation_lax,
            inputs=[seed, image_id],
            outputs=[segmentation_plot, volume_plot],
        )
    return lax_interface
# Assemble the top-level app: a title/introduction followed by one tab per demo.
with gr.Blocks(
    theme=theme, title="CineMA: A Foundation Model for Cine Cardiac MRI"
) as demo:
    # Text fixes: restored the garbled title emoji and split "checkout" into
    # the verb "check out".
    gr.Markdown(
        """
        # CineMA: A Foundation Model for Cine Cardiac MRI 🎥🫀
        This demo showcases the capabilities of CineMA in multiple tasks.
        For more details, check out our [GitHub](https://github.com/mathpluscode/CineMA).
        """
    )
    with gr.Tabs() as tabs:
        with gr.TabItem("Cine CMR Views"):
            cmr_tab()
        with gr.TabItem("Masked Autoencoder"):
            mae_tab()
        with gr.TabItem("Segmentation in SAX View"):
            segmentation_sax_tab()
        with gr.TabItem("Segmentation in LAX View"):
            segmentation_lax_tab()

# Start the Gradio server.
demo.launch()