| | import os |
| |
|
| | os.environ["KERAS_BACKEND"] = "jax" |
| |
|
| | import gradio as gr |
| | import jax |
| | import keras |
| | import numpy as np |
| | import spaces |
| | from PIL import Image |
| | from zea import init_device |
| |
|
| | from main import Config, init, run |
| | from utils import load_image |
| | import torch |
| | import subprocess |
| |
|
# Locations of the YAML configuration files and of the example images.
CONFIG_PATH = "configs/semantic_dps.yaml"
SLIDER_CONFIG_PATH = "configs/slider_params.yaml"
ASSETS_DIR = "assets"

# Device handle; stays None until load_model_event() assigns it.
DEVICE = None

# Inline CSS for the status banner shown while the model is still loading.
STATUS_STYLE_LOAD = (
    "display:flex;align-items:center;justify-content:center;"
    "padding:40px 10px 18px 10px;border-radius:8px;"
    "font-weight:bold;font-size:1.15em;line-height:1.5;"
    "align-items:center;"
)

# Inline CSS for every other status banner.
STATUS_STYLE = (
    "display:flex;align-items:center;justify-content:center;"
    "padding:18px 18px 18px 10px;border-radius:8px;"
    "font-weight:bold;font-size:1.15em;line-height:1.1;"
    "align-items:center;"
)

# Markdown rendered at the top of the page.
description = """
# Cardiac Ultrasound Dehazing with Semantic Diffusion

Select an example image below to see the dehazing algorithm in action. The algorithm was tuned for the DehazingEcho2025 challenge dataset, so be wary of using it on other datasets.

Tip: Adjust "Omega (Ventricle)" and "Eta (haze prior)" to control the dehazing effect.
"""

# Shared model state: filled in by initialize_model(), read by process_image().
config = None
diffusion_model = None
model_loaded = False
| |
|
| |
|
def initialize_model():
    """Build the config and diffusion model once, then mark the model as ready.

    On the first call (when either global is still ``None``) the config is
    read from ``CONFIG_PATH``, the model is created with ``init`` and a single
    one-step ``run`` on an all-zero image is executed as a warm-up
    (presumably to trigger compilation before the first real request).
    Later calls reuse the existing globals.

    Returns:
        The ``(config, diffusion_model)`` pair stored in the module globals.
    """
    global config, diffusion_model, model_loaded

    already_built = config is not None and diffusion_model is not None
    if not already_built:
        config = Config.from_yaml(CONFIG_PATH)
        diffusion_model = init(config)

        # Warm-up input: one blank image with the model's spatial size.
        height, width = diffusion_model.input_shape[:2]
        blank_batch = np.zeros((1, height, width), dtype=np.float32)

        settings = config.params
        guidance_cfg = settings["guidance_kwargs"]
        warmup_guidance = dict(
            omega=guidance_cfg["omega"],
            omega_vent=guidance_cfg.get("omega_vent", 1.0),
            omega_sept=guidance_cfg.get("omega_sept", 1.0),
            eta=guidance_cfg.get("eta", 1.0),
            smooth_l1_beta=guidance_cfg["smooth_l1_beta"],
        )

        run(
            hazy_images=blank_batch,
            diffusion_model=diffusion_model,
            seed=jax.random.PRNGKey(config.seed),
            guidance_kwargs=warmup_guidance,
            mask_params=settings["mask_params"],
            fixed_mask_params=settings["fixed_mask_params"],
            skeleton_params=settings["skeleton_params"],
            batch_size=1,
            diffusion_steps=1,
            verbose=False,
        )

    model_loaded = True
    return config, diffusion_model
| |
|
| |
|
def _status_update(background, color, message):
    """Return a ``gr.update`` that renders ``message`` as a coloured banner.

    ``message`` is inserted into the markup verbatim, so callers must escape
    any text that is not a trusted literal.
    """
    return gr.update(
        value=f'<div style="background:{background};{STATUS_STYLE}color:{color};">{message}</div>'
    )


def _prepare_image(image, target_hw):
    """Convert a PIL image into the float32 batch handed to ``run``.

    Args:
        image: PIL image selected or uploaded by the user.
        target_hw: ``(height, width)`` the image is resized to.

    Returns:
        ``(batch, resized, orig_shape)``: the image as a float32 array with a
        leading axis of size 1, whether it had to be resized, and the
        ``(height, width)`` of the image before resizing.
    """
    if image.mode != "L":
        image = image.convert("L")
    orig_shape = image.size[::-1]  # PIL reports (width, height)
    h, w = target_hw
    resized = image.size != (w, h)
    if resized:
        image = image.resize((w, h), Image.BILINEAR)
    batch = np.array(image).astype(np.float32)[None, ...]
    return batch, resized, orig_shape


@spaces.GPU(duration=10)
def process_image(input_img, diffusion_steps, omega, omega_vent, omega_sept, eta):
    """Dehaze ``input_img`` and stream status updates to the UI.

    This is a generator: every ``yield`` is a ``(status_update, slider_value)``
    pair, where ``slider_value`` is ``None`` until the result is ready and
    ``(input_img, dehazed_image)`` on success.

    Args:
        input_img: PIL image to dehaze, or ``None`` if nothing was selected.
        diffusion_steps: Number of diffusion steps passed to ``run``.
        omega, omega_vent, omega_sept, eta: Guidance values passed to ``run``
            through ``guidance_kwargs``.

    Failures while preparing the image or running the algorithm are reported
    through the status banner instead of being raised.
    """
    import html  # local import: only needed to escape exception text

    if not model_loaded:
        yield (
            _status_update(
                "#ffeeba", "#856404", "⏳ Model is still loading. Please wait..."
            ),
            None,
        )
        return

    if input_img is None:
        yield (
            _status_update(
                "#ffeeba",
                "#856404",
                "⚠️ No input image was provided. Please select or upload an image before running.",
            ),
            None,
        )
        return

    params = config.params

    try:
        image, resized, orig_shape = _prepare_image(
            input_img, diffusion_model.input_shape[:2]
        )
    except Exception as e:
        # Exception text is arbitrary; escape it so "<", ">" or "&" cannot
        # break (or inject into) the banner markup.
        yield (
            _status_update(
                "#f8d7da",
                "#721c24",
                f"❌ Error preparing input image: {html.escape(str(e))}",
            ),
            None,
        )
        return

    guidance_kwargs = {
        "omega": omega,
        "omega_vent": omega_vent,
        "omega_sept": omega_sept,
        "eta": eta,
        "smooth_l1_beta": params["guidance_kwargs"]["smooth_l1_beta"],
    }

    # The seed always comes from the config, so it is the same on every run.
    seed = jax.random.PRNGKey(config.seed)

    try:
        yield (
            _status_update("#cce5ff", "#004085", "🌀 Running dehazing algorithm..."),
            None,
        )
        _, pred_tissue_images, *_ = run(
            hazy_images=image,
            diffusion_model=diffusion_model,
            seed=seed,
            guidance_kwargs=guidance_kwargs,
            mask_params=params["mask_params"],
            fixed_mask_params=params["fixed_mask_params"],
            skeleton_params=params["skeleton_params"],
            batch_size=1,
            diffusion_steps=diffusion_steps,
            threshold_output_quantile=params.get("threshold_output_quantile", None),
            preserve_bottom_percent=params.get("preserve_bottom_percent", 30.0),
            bottom_transition_width=params.get("bottom_transition_width", 10.0),
            verbose=False,
        )
    except Exception as e:
        yield (
            _status_update(
                "#f8d7da",
                "#721c24",
                f"❌ The algorithm failed to process the image: {html.escape(str(e))}",
            ),
            None,
        )
        return

    out_img = np.squeeze(pred_tissue_images[0])
    out_img = np.clip(out_img, 0, 255).astype(np.uint8)
    out_pil = Image.fromarray(out_img)

    # Scale the result back to the size of the image the user supplied.
    orig_size = (orig_shape[1], orig_shape[0])
    if resized and out_pil.size != orig_size:
        out_pil = out_pil.resize(orig_size, Image.BILINEAR)

    yield (
        _status_update("#d4edda", "#155724", "✅ Done!"),
        (input_img, out_pil),
    )
| |
|
| |
|
# Slider settings (default / min / max / step) are read from a YAML file.
slider_params = Config.from_yaml(SLIDER_CONFIG_PATH)


def _slider_range(name):
    """Return ``(default, min, max, step)`` for the slider entry ``name``."""
    spec = slider_params[name]
    return spec["default"], spec["min"], spec["max"], spec["step"]


(
    diffusion_steps_default,
    diffusion_steps_min,
    diffusion_steps_max,
    diffusion_steps_step,
) = _slider_range("diffusion_steps")

omega_default, omega_min, omega_max, omega_step = _slider_range("omega")

(
    omega_vent_default,
    omega_vent_min,
    omega_vent_max,
    omega_vent_step,
) = _slider_range("omega_vent")

(
    omega_sept_default,
    omega_sept_min,
    omega_sept_max,
    omega_sept_step,
) = _slider_range("omega_sept")

eta_default, eta_min, eta_max, eta_step = _slider_range("eta")
| |
|
| |
|
# Collect the PNG examples shipped in ASSETS_DIR. os.listdir() returns names
# in arbitrary order, so the paths are sorted to keep the example gallery
# (and the image preselected from example_images[0]) stable across hosts.
example_image_paths = sorted(
    os.path.join(ASSETS_DIR, name)
    for name in os.listdir(ASSETS_DIR)
    if name.lower().endswith(".png")
)
example_images = [load_image(path) for path in example_image_paths]
# gr.Examples expects one row per example, with one value per input component.
examples = [[img] for img in example_images]
| |
|
| |
|
# Build the Gradio UI: status banner, input image + examples, result slider,
# parameter sliders and the Run button, plus the event wiring.
# NOTE(review): the nesting of rows/columns below was reconstructed from a
# copy of this file that had lost its indentation — confirm the layout.
with gr.Blocks() as demo:
    gr.Markdown(description)
    # Status banner; it is the first output of both event handlers below.
    status = gr.Markdown(
        f'<div style="background:#ffeeba;{STATUS_STYLE_LOAD}color:#856404;">⏳ Loading model...</div>',
        visible=True,
    )
    with gr.Row():
        with gr.Column():
            img1 = gr.Image(
                label="Input Image",
                type="pil",
                webcam_options=False,
                value=example_images[0] if example_images else None,
            )
            gr.Examples(examples=examples, inputs=[img1])
        with gr.Column():
            img2 = gr.ImageSlider(label="Dehazed Image", type="pil")
    with gr.Row():
        diffusion_steps_slider = gr.Slider(
            minimum=diffusion_steps_min,
            maximum=diffusion_steps_max,
            step=diffusion_steps_step,
            value=diffusion_steps_default,
            label="Diffusion Steps",
        )
        omega_slider = gr.Slider(
            minimum=omega_min,
            maximum=omega_max,
            step=omega_step,
            value=omega_default,
            label="Omega (background)",
        )
        omega_vent_slider = gr.Slider(
            minimum=omega_vent_min,
            maximum=omega_vent_max,
            step=omega_vent_step,
            value=omega_vent_default,
            label="Omega Ventricle",
        )
        omega_sept_slider = gr.Slider(
            minimum=omega_sept_min,
            maximum=omega_sept_max,
            step=omega_sept_step,
            value=omega_sept_default,
            label="Omega Septum",
        )
        eta_slider = gr.Slider(
            minimum=eta_min,
            maximum=eta_max,
            step=eta_step,
            value=eta_default,
            label="Eta (haze prior)",
        )
    # Disabled until load_model_event() reports that the model is ready.
    run_btn = gr.Button("Run", interactive=False)

    run_btn.click(
        process_image,
        inputs=[
            img1,
            diffusion_steps_slider,
            omega_slider,
            omega_vent_slider,
            omega_sept_slider,
            eta_slider,
        ],
        outputs=[status, img2],
        queue=True,
    )

    def _log_environment():
        """Print library versions and GPU diagnostics (best effort)."""
        print(f"KERAS version: {keras.__version__}")
        try:
            print(f"JAX version: {jax.__version__}")
            print(f"JAX devices: {jax.devices()}")
        except Exception as e:
            print(f"Could not get JAX info: {e}")

        try:
            device_names = [
                torch.cuda.get_device_name(i)
                for i in range(torch.cuda.device_count())
            ]
            print(f"PyTorch version: {torch.__version__}")
            print(f"PyTorch CUDA available: {torch.cuda.is_available()}")
            print(f"PyTorch CUDA device count: {torch.cuda.device_count()}")
            print(f"PyTorch devices: {device_names}")
            print(f"PyTorch CUDA version: {torch.version.cuda}")
            print(f"PyTorch cuDNN version: {torch.backends.cudnn.version()}")
        except Exception as e:
            print(f"Could not get PyTorch info: {e}")

        try:
            cuda_version = subprocess.getoutput("nvcc --version")
            print(f"nvcc version:\n{cuda_version}")
            nvidia_smi = subprocess.getoutput("nvidia-smi")
            print(f"nvidia-smi output:\n{nvidia_smi}")
        except Exception as e:
            print(f"Could not get CUDA/nvidia-smi info: {e}")

    def load_model_event():
        """Initialise the device and model, then enable the Run button.

        Returns:
            ``(status_update, run_button_update)``. On success the banner
            says the model is loaded and the button becomes interactive; on
            any failure the banner shows the error and the button stays
            disabled.
        """
        global config, diffusion_model, model_loaded, DEVICE
        import html  # local import: only needed to escape exception text

        try:
            if DEVICE is None:
                try:
                    DEVICE = init_device()
                except Exception as e:
                    # Best effort: carry on without a device handle, but keep
                    # the reason in the log. Catching Exception (not a bare
                    # except) lets KeyboardInterrupt/SystemExit propagate.
                    print(
                        f"Could not initialize device using `zea.init_device()`: {e}"
                    )
                # NOTE(review): diagnostics are kept under the DEVICE check so
                # they are skipped once a device handle exists — confirm this
                # matches the intended nesting.
                _log_environment()

            config, diffusion_model = initialize_model()
            ready_msg = gr.update(
                value=f'<div style="background:#d4edda;{STATUS_STYLE}color:#155724;">✅ Model loaded! You can now press Run.</div>'
            )
            return ready_msg, gr.update(interactive=True)
        except Exception as e:
            # Escape the exception text so "<", ">" or "&" cannot break the
            # banner markup.
            return gr.update(
                value=f'<div style="background:#f8d7da;{STATUS_STYLE}color:#721c24;">❌ Error loading model: {html.escape(str(e))}</div>'
            ), gr.update(interactive=False)

    # Load the model via the Blocks load event.
    demo.load(
        load_model_event,
        inputs=None,
        outputs=[status, run_btn],
    )
| |
|
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
| |
|