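# NOTE(review): the original first lines were page-scrape residue
# ("Spaces: Build error") -- presumably a hosting status banner, not part of
# the program. Kept here as a comment so the file remains valid Python.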
import random
import shutil
import uuid
from pathlib import Path

import cv2
import gradio as gr
import mediapy
import mlflow.pytorch
import nibabel as nib
import numpy as np
import torch
from skimage import img_as_ubyte

from models.ddim import DDIMSampler
# Point mediapy at the ffmpeg binary found on PATH. When none is found,
# shutil.which returns None; in that case leave mediapy's configuration
# untouched instead of overwriting it with None.
ffmpeg_path = shutil.which("ffmpeg")
if ffmpeg_path is not None:
    mediapy.set_ffmpeg(ffmpeg_path)

# Loading model
# Inference runs on CPU. map_location is forwarded by mlflow to torch.load so
# that checkpoints saved from another device can still be deserialised here.
device = torch.device("cpu")

vqvae = mlflow.pytorch.load_model(
    "./trained_models/vae/final_model",
    map_location=device,
)
vqvae.eval()

diffusion = mlflow.pytorch.load_model(
    "./trained_models/ddpm/final_model",
    map_location=device,
)
diffusion.eval()

diffusion = diffusion.to(device)
vqvae = vqvae.to(device)
def sample_fn(
    gender_radio,
    age_slider,
    ventricular_slider,
    brain_slider,
):
    """Sample one synthetic brain volume conditioned on four covariates.

    The covariates are packed into a single conditioning vector that is fed
    to the diffusion model both as a cross-attention context and as extra
    channels broadcast over the latent grid. The sampled latent is then
    decoded with ``vqvae.reconstruct_ldm_outputs``.

    Args:
        gender_radio: Numeric gender value placed first in the conditioning
            vector.
        age_slider: Age in years; rescaled with ``(age - 44) / (82 - 44)``
            before being used.
        ventricular_slider: Ventricular volume value, used as given.
        brain_slider: Brain volume value, used as given.

    Returns:
        The decoded volume as a NumPy array (callers in this file index it
        as a 5-D array -- presumably (batch, channel, x, y, z); confirm).
    """
    print("Sampling brain!")
    print(f"Gender: {gender_radio}")
    print(f"Age: {age_slider}")
    print(f"Ventricular volume: {ventricular_slider}")
    print(f"Brain volume: {brain_slider}")

    # Map the age range 44-82 onto 0-1; values outside it extrapolate.
    age_min, age_max = 44, 82
    age_slider = (age_slider - age_min) / (age_max - age_min)

    # Shape (1, 4): one sample with four conditioning values.
    cond = torch.Tensor([[gender_radio, age_slider, ventricular_slider, brain_slider]])
    latent_shape = [1, 3, 20, 28, 20]

    # Cross-attention context: (1, 1, 4).
    cond_crossatten = cond.unsqueeze(1)
    # Concatenation context: each value broadcast over the latent grid,
    # giving (1, 4, 20, 28, 20).
    cond_concat = cond.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    cond_concat = cond_concat.expand(list(cond.shape[0:2]) + list(latent_shape[2:]))

    conditioning = {
        "c_concat": [cond_concat.float().to(device)],
        "c_crossattn": [cond_crossatten.float().to(device)],
    }

    ddim = DDIMSampler(diffusion)
    num_timesteps = 50

    # Pure inference: keep both sampling and decoding out of autograd so no
    # graph is retained across the sampling steps.
    with torch.no_grad():
        latent_vectors, _ = ddim.sample(
            num_timesteps,
            conditioning=conditioning,
            batch_size=1,
            shape=list(latent_shape[1:]),
            eta=1.0,
        )
        x_hat = vqvae.reconstruct_ldm_outputs(latent_vectors).cpu()

    return x_hat.numpy()
def _write_slices_to_video(video_path, slices):
    """Encode 2-D uint8 slices as consecutive RGB frames of one video file.

    Args:
        video_path: Destination path of the video.
        slices: Iterable of equally shaped 2-D ``uint8`` arrays, one per frame.

    Raises:
        ValueError: If ``slices`` is empty, because the frame size cannot be
            determined.
    """
    slices = list(slices)
    if not slices:
        raise ValueError(f"No frames to write to {video_path}")

    # Take the frame size from the data rather than hard-coding it, so the
    # writer always agrees with the frames it receives.
    frame_shape = tuple(slices[0].shape[:2])
    with mediapy.VideoWriter(
        str(video_path), shape=frame_shape, fps=12, crf=18
    ) as writer:
        for gray in slices:
            rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
            writer.add_image(img_as_ubyte(rgb))


def create_videos_and_file(
    gender_radio,
    age_slider,
    ventricular_slider,
    brain_slider,
):
    """Generate a brain volume and export it as three videos and a NIfTI file.

    A volume is sampled with ``sample_fn``, cropped, rescaled to 0-255 and
    written to a fresh, uniquely named output directory.

    Args:
        gender_radio: Forwarded to ``sample_fn``.
        age_slider: Forwarded to ``sample_fn``.
        ventricular_slider: Forwarded to ``sample_fn``.
        brain_slider: Forwarded to ``sample_fn``.

    Returns:
        A 4-tuple of string paths: axial video, sagittal video, coronal
        video, and the ``.nii.gz`` file.
    """
    # NOTE(review): absolute, machine-specific output root -- confirm it is
    # valid on the deployment host.
    output_dir = Path(
        f"/media/walter/Storage/Projects/gradio_medical_ldm/outputs/{str(uuid.uuid4())}"
    )
    # parents=True: the per-request directory sits several levels deep and the
    # intermediate directories may not exist yet.
    output_dir.mkdir(parents=True, exist_ok=True)

    image_data = sample_fn(
        gender_radio,
        age_slider,
        ventricular_slider,
        brain_slider,
    )
    # Keep the first element of the two leading axes and trim the borders of
    # the three remaining axes.
    image_data = image_data[0, 0, 5:-5, 5:-5, :-15]

    # Rescale intensities to 0-255. A constant volume has zero range; map it
    # to all zeros instead of dividing by zero.
    lowest = image_data.min()
    value_range = image_data.max() - lowest
    if value_range > 0:
        image_data = (image_data - lowest) / value_range
    else:
        image_data = np.zeros_like(image_data)
    image_data = (image_data * 255).astype(np.uint8)

    axial_path = output_dir / "brain_axial.mp4"
    sagittal_path = output_dir / "brain_sagittal.mp4"
    coronal_path = output_dir / "brain_coronal.mp4"
    nifti_path = output_dir / "my_brain.nii.gz"

    # Write frames to video: one video per volume axis.
    _write_slices_to_video(
        axial_path,
        [image_data[:, :, idx] for idx in range(image_data.shape[2])],
    )
    _write_slices_to_video(
        sagittal_path,
        [np.rot90(image_data[idx, :, :]) for idx in range(image_data.shape[0])],
    )
    # The flip does not depend on the slice index, so do it once.
    flipped = np.flip(image_data, axis=1)
    _write_slices_to_video(
        coronal_path,
        [np.rot90(flipped[:, idx, :]) for idx in range(image_data.shape[1])],
    )

    # Create file
    affine = np.array(
        [
            [-1.0, 0.0, 0.0, 96.48149872],
            [0.0, 1.0, 0.0, -141.47715759],
            [0.0, 0.0, 1.0, -156.55375671],
            [0.0, 0.0, 0.0, 1.0],
        ]
    )
    empty_header = nib.Nifti1Header()
    sample_nii = nib.Nifti1Image(image_data, affine, empty_header)
    nib.save(sample_nii, str(nifti_path))

    return (
        str(axial_path),
        str(sagittal_path),
        str(coronal_path),
        str(nifti_path),
    )
def randomise():
    """Draw random values for the restricted "Inputs" tab.

    Returns:
        A 4-tuple ``(gender, age, ventricular, brain)`` where ``gender`` is
        ``"Female"`` or ``"Male"``, ``age`` is uniform in 44-82 and the two
        volumes are uniform in 0-1, all numbers rounded to two decimals.
    """
    random_age = round(random.uniform(44.0, 82.0), 2)
    return (
        random.choice(["Female", "Male"]),
        random_age,
        round(random.uniform(0, 1.0), 2),
        round(random.uniform(0, 1.0), 2),
    )
def unrest_randomise():
    """Draw random values for the "Unrestricted Inputs" tab.

    Returns:
        A 4-tuple ``(gender, age, ventricular, brain)`` where ``gender`` is
        ``0`` or ``1``, ``age`` is uniform in 18-100 and the two volumes are
        uniform in -1 to 2, all floats rounded to two decimals.
    """
    random_age = round(random.uniform(18.0, 100.0), 2)
    return (
        random.choice([1, 0]),
        random_age,
        round(random.uniform(-1.0, 2.0), 2),
        round(random.uniform(-1.0, 2.0), 2),
    )
# TEXT
# User-facing strings rendered through gr.Markdown in the UI.
title = "Generating Brain Imaging with Diffusion Models"
description = """
<center><b>WORK IN PROGRESS. DO NOT SHARE.</b></center>
<center><a href="https://arxiv.org/">[PAPER]</a> <a href="https://academictorrents.com/details/63aeb864bbe2115ded0aa0d7d36334c026f0660b">[DATASET]</a></center>
<details>
<summary>Instructions</summary>
With this app, you can generate synthetic brain images with one click!<br />You have two ways to set how your generated brain will look like:<br />- Using the "Inputs" tab that creates well-behaved brains using the same value ranges that our models learned as described in paper linked above<br />- Or using the "Unrestricted Inputs" tab to generate the wildest brains!<br />After customisation, just hit "Generate" and wait a few seconds.<br />Note: if you are having problems with the videos, try our app using chrome. <b>Enjoy!</b>
</details>
"""
article = """
Checkout our dataset with [100K synthetic brain](https://academictorrents.com/details/63aeb864bbe2115ded0aa0d7d36334c026f0660b)! 🧠🧠🧠
App made by [Walter Hugo Lopez Pinaya](https://twitter.com/warvito) from [AMIGO](https://amigos.ai/)
<center><img src="https://amigos.ai/assets/images/logo_dark_rect.png" alt="amigos.ai" width=300px></center>
"""
# Gradio UI: input controls on the left, generated videos and file on the right.
# NOTE(review): the nesting of the layout below was reconstructed from the
# order of the statements -- confirm against the intended layout.
demo = gr.Blocks()

with demo:
    gr.Markdown(
        "<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>"
    )
    gr.Markdown(description)

    with gr.Row():
        # Left column: conditioning inputs.
        with gr.Column():
            with gr.Box():
                with gr.Tabs():
                    with gr.TabItem("Inputs"):
                        with gr.Row():
                            # type="index" passes the position of the chosen
                            # option (0 or 1) to the callback, not its label.
                            gender_radio = gr.Radio(
                                choices=["Female", "Male"],
                                value="Female",
                                type="index",
                                label="Gender",
                                interactive=True,
                            )
                            age_slider = gr.Slider(
                                minimum=44,
                                maximum=82,
                                value=63,
                                label="Age [years]",
                                interactive=True,
                            )
                        with gr.Row():
                            ventricular_slider = gr.Slider(
                                minimum=0,
                                maximum=1,
                                value=0.5,
                                label="Volume of ventricular cerebrospinal fluid",
                                interactive=True,
                            )
                            brain_slider = gr.Slider(
                                minimum=0,
                                maximum=1,
                                value=0.5,
                                label="Volume of brain",
                                interactive=True,
                            )
                        with gr.Row():
                            submit_btn = gr.Button("Generate", variant="primary")
                            randomize_btn = gr.Button("I'm Feeling Lucky")

                    with gr.TabItem("Unrestricted Inputs"):
                        with gr.Row():
                            unrest_gender_number = gr.Number(
                                value=1.0,
                                precision=1,
                                label="Gender [Female=0, Male=1]",
                                interactive=True,
                            )
                            unrest_age_number = gr.Number(
                                value=63,
                                precision=1,
                                label="Age [years]",
                                interactive=True,
                            )
                        with gr.Row():
                            unrest_ventricular_number = gr.Number(
                                value=0.5,
                                precision=2,
                                label="Volume of ventricular cerebrospinal fluid",
                                interactive=True,
                            )
                            unrest_brain_number = gr.Number(
                                value=0.5,
                                precision=2,
                                label="Volume of brain",
                                interactive=True,
                            )
                        with gr.Row():
                            unrest_submit_btn = gr.Button("Generate", variant="primary")
                            unrest_randomize_btn = gr.Button("I'm Feeling Lucky")

                        # Preset rows for the unrestricted inputs, in the order
                        # gender, age, ventricular volume, brain volume.
                        gr.Examples(
                            examples=[
                                [1, 63, 1.3, 0.5],
                                [0, 63, 1.9, 0.5],
                                [1, 63, -0.5, 0.5],
                                [0, 63, 0.5, -0.3],
                            ],
                            inputs=[
                                unrest_gender_number,
                                unrest_age_number,
                                unrest_ventricular_number,
                                unrest_brain_number,
                            ],
                        )

        # Right column: generated outputs.
        with gr.Column():
            with gr.Box():
                with gr.Tabs():
                    with gr.TabItem("Axial View"):
                        axial_sample_plot = gr.Video(show_label=False)
                    with gr.TabItem("Sagittal View"):
                        sagittal_sample_plot = gr.Video(show_label=False)
                    with gr.TabItem("Coronal View"):
                        coronal_sample_plot = gr.Video(show_label=False)
            sample_file = gr.File(label="My Brain")

    gr.Markdown(article)

    # Both "Generate" buttons run the same pipeline and fill the same outputs;
    # they differ only in which tab's controls supply the inputs.
    submit_btn.click(
        create_videos_and_file,
        [
            gender_radio,
            age_slider,
            ventricular_slider,
            brain_slider,
        ],
        [axial_sample_plot, sagittal_sample_plot, coronal_sample_plot, sample_file],
    )
    unrest_submit_btn.click(
        create_videos_and_file,
        [
            unrest_gender_number,
            unrest_age_number,
            unrest_ventricular_number,
            unrest_brain_number,
        ],
        [axial_sample_plot, sagittal_sample_plot, coronal_sample_plot, sample_file],
    )

    # The "I'm Feeling Lucky" buttons only overwrite the input controls.
    randomize_btn.click(
        fn=randomise,
        inputs=[],
        queue=False,
        outputs=[gender_radio, age_slider, ventricular_slider, brain_slider],
    )
    unrest_randomize_btn.click(
        fn=unrest_randomise,
        inputs=[],
        queue=False,
        outputs=[
            unrest_gender_number,
            unrest_age_number,
            unrest_ventricular_number,
            unrest_brain_number,
        ],
    )

demo.launch(enable_queue=True)