#!/usr/bin/env python3
"""Gradio app for SynthCXR: interactive mask scaling and CXR generation."""

from __future__ import annotations

import os

# These are set before the third-party imports below (presumably so the hub /
# diffsynth libraries pick them up at import time — keep this ordering).
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["DIFFSYNTH_DOWNLOAD_SOURCE"] = "huggingface"

from pathlib import Path

import spaces
import gradio as gr
import numpy as np
import torch
from PIL import Image

from synthcxr.constants import KNOWN_CONDITIONS
from synthcxr.mask_utils import resolve_overlaps, scale_mask_channel
from synthcxr.prompt import ConditionConfig, build_condition_prompt

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
BASE_DIR = Path(__file__).resolve().parent
SAMPLE_MASKS_DIR = BASE_DIR / "static" / "sample_masks"
LORA_DIR = BASE_DIR / "scripts" / "models" / "qwen_image_edit_chexpert_lora"

# ---------------------------------------------------------------------------
# Condition / severity choices
# ---------------------------------------------------------------------------
CONDITION_CHOICES = [
    "enlarged_cardiomediastinum",
    "cardiomegaly",
    "atelectasis",
    "pneumothorax",
    "pleural_effusion",
]

SEVERITY_CHOICES = ["(none)", "mild", "moderate", "severe"]


# ---------------------------------------------------------------------------
# Pipeline loading (fresh on each @spaces.GPU call; model files cached on disk)
# ---------------------------------------------------------------------------
def _read_vram_limit() -> float | None:
    """Parse the optional ``VRAM_LIMIT`` env var (in GB).

    Returns ``None`` when the variable is unset, blank, or not a number. A
    malformed value is reported and ignored rather than aborting the request.
    """
    raw = os.environ.get("VRAM_LIMIT", "").strip()
    if not raw:
        return None
    try:
        return float(raw)
    except ValueError:
        print(f"[WARN] Ignoring invalid VRAM_LIMIT={raw!r}; expected a number (GB).")
        return None


def load_fresh_pipeline():
    """Load the pipeline + LoRA onto the *currently allocated* GPU.

    ZeroGPU deallocates GPU memory after each ``@spaces.GPU`` call, so we
    cannot cache tensors between calls. However, diffsynth caches the
    model files on disk (HF Hub cache), so only tensor loading happens
    here — not a full download.

    Environment variables:
        VRAM_LIMIT: optional limit in GB, forwarded to ``load_pipeline`` to
            enable model offloading on memory-constrained GPUs.
        LORA_EPOCH: which ``epoch-<N>.safetensors`` checkpoint to load
            (default ``"2"``).

    Returns:
        The loaded pipeline. If no LoRA checkpoint exists under ``LORA_DIR``
        the base pipeline is returned without LoRA weights.
    """
    # Imported lazily so the heavy pipeline module is only loaded on demand.
    from synthcxr.pipeline import load_lora_weights, load_pipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16

    vram_limit = _read_vram_limit()

    print(f"[INFO] Loading QwenImagePipeline (device={device}, dtype={dtype}, vram_limit={vram_limit}) …")
    pipe = load_pipeline(device, dtype, vram_limit=vram_limit)

    # LORA_EPOCH env var: which epoch checkpoint to load (default: 2)
    lora_epoch = os.environ.get("LORA_EPOCH", "2")
    lora = LORA_DIR / f"epoch-{lora_epoch}.safetensors"
    if not lora.exists():
        # Fall back to the last checkpoint in (lexicographic) filename order.
        candidates = sorted(LORA_DIR.glob("*.safetensors")) if LORA_DIR.exists() else []
        if candidates:
            lora = candidates[-1]
            print(f"[WARN] epoch-{lora_epoch} not found, falling back to {lora.name}")
        else:
            print("[WARN] No LoRA checkpoint found – running base model only.")
            return pipe

    print(f"[INFO] Loading LoRA from {lora}")
    load_lora_weights(pipe, lora)
    print("[INFO] Pipeline ready.")
    return pipe


# ---------------------------------------------------------------------------
# Sample masks
# ---------------------------------------------------------------------------
def get_sample_masks() -> list[str]:
    """Return the sorted paths of bundled sample masks (``*.png``).

    Returns an empty list when ``SAMPLE_MASKS_DIR`` does not exist.
    """
    if not SAMPLE_MASKS_DIR.exists():
        return []
    return sorted(str(p) for p in SAMPLE_MASKS_DIR.glob("*.png"))


# ---------------------------------------------------------------------------
# Core functions
# ---------------------------------------------------------------------------
def apply_mask_scaling(
    mask_array: np.ndarray,
    heart_scale: float,
    left_lung_scale: float,
    right_lung_scale: float,
) -> np.ndarray:
    """Scale mask channels and resolve overlaps.

    Channel 2 is scaled by ``heart_scale``, channel 0 by ``left_lung_scale``
    and channel 1 by ``right_lung_scale``; a factor of exactly ``1.0`` skips
    that channel. Overlaps are then resolved with priority ``(2, 0, 1)``.
    """
    if heart_scale != 1.0:
        mask_array = scale_mask_channel(mask_array, channel=2, scale_factor=heart_scale)
    if left_lung_scale != 1.0:
        mask_array = scale_mask_channel(mask_array, channel=0, scale_factor=left_lung_scale)
    if right_lung_scale != 1.0:
        mask_array = scale_mask_channel(mask_array, channel=1, scale_factor=right_lung_scale)
    return resolve_overlaps(mask_array, priority=(2, 0, 1))


def preview_mask(
    mask_image: np.ndarray | None,
    heart_scale: float,
    left_lung_scale: float,
    right_lung_scale: float,
) -> np.ndarray | None:
    """Live mask preview callback.

    Returns ``None`` when no mask is loaded; otherwise the mask converted to
    RGB with the three scale factors applied.
    """
    if mask_image is None:
        return None
    mask = np.array(Image.fromarray(mask_image).convert("RGB"))
    scaled = apply_mask_scaling(mask, heart_scale, left_lung_scale, right_lung_scale)
    return scaled


def _make_condition_config(
    name: str,
    conditions: list[str] | None,
    severity: str,
    age: int,
    sex: str,
    view: str,
) -> ConditionConfig:
    """Build a ``ConditionConfig`` from raw UI values.

    The ``"(none)"`` severity choice is mapped to ``None`` and a missing
    condition list to ``[]``.
    """
    return ConditionConfig(
        name=name,
        conditions=conditions or [],
        age=age,
        sex=sex,
        view=view,
        severity=severity if severity != "(none)" else None,
    )


def build_prompt_preview(
    conditions: list[str],
    severity: str,
    age: int,
    sex: str,
    view: str,
) -> str:
    """Build the prompt text for preview."""
    cond = _make_condition_config("preview", conditions, severity, age, sex, view)
    return build_condition_prompt(cond)


@spaces.GPU(duration=120)
def generate_cxr(
    mask_image: np.ndarray | None,
    heart_scale: float,
    left_lung_scale: float,
    right_lung_scale: float,
    conditions: list[str],
    severity: str,
    age: int,
    sex: str,
    view: str,
    num_steps: int,
    cfg_scale: float,
    seed: int,
    progress=gr.Progress(),
):
    """Generate a CXR from the scaled mask and the selected conditions.

    This is a generator (Gradio streams its outputs): it reports progress
    through ``progress`` while the pipeline iterates and yields the final
    image exactly once. No intermediate previews are decoded.

    Args:
        mask_image: Conditioning mask from the UI, or ``None`` if not set.
        heart_scale, left_lung_scale, right_lung_scale: Mask scale factors.
        conditions, severity, age, sex, view: Prompt configuration.
        num_steps: Number of inference steps.
        cfg_scale: Classifier-free guidance scale passed to the pipeline.
        seed: Random seed passed to the pipeline.
        progress: Gradio progress tracker (injected by Gradio).

    Raises:
        gr.Error: If no mask was provided or the pipeline could not be loaded.
    """
    if mask_image is None:
        raise gr.Error("Please select or upload a mask first.")

    pipe = load_fresh_pipeline()
    if pipe is None:
        raise gr.Error("Pipeline not loaded. GPU may be unavailable.")

    # Prepare mask
    mask = np.array(Image.fromarray(mask_image).convert("RGB"))
    scaled = apply_mask_scaling(mask, heart_scale, left_lung_scale, right_lung_scale)
    edit_image = Image.fromarray(scaled)

    # Build prompt
    cond = _make_condition_config("web_ui", conditions, severity, age, sex, view)
    prompt = build_condition_prompt(cond)

    class StepCallback:
        """tqdm-like wrapper that forwards each step to the Gradio progress bar."""

        def __init__(self, iterable):
            self._iterable = iterable
            self._step = 0

        def __iter__(self):
            for item in self._iterable:
                # Report the fraction completed *before* running this step.
                progress(self._step / num_steps, desc="Generating CXR...")
                yield item
                self._step += 1

        def __len__(self):
            return len(self._iterable)

    image = pipe(
        prompt=prompt,
        edit_image=edit_image,
        height=512,
        width=512,
        num_inference_steps=num_steps,
        seed=seed,
        rand_device=pipe.device,
        cfg_scale=cfg_scale,
        edit_image_auto_resize=True,
        zero_cond_t=True,
        progress_bar_cmd=StepCallback,
    )

    yield image


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
CUSTOM_CSS = """
/* ── Layout ── */
.gradio-container {
    max-width: 1280px !important;
    margin: 0 auto !important;
}

/* ── Radial gradient background ── */
.main {
    background:
        radial-gradient(ellipse 80% 50% at 10% 20%, rgba(99,102,241,0.07), transparent),
        radial-gradient(ellipse 60% 40% at 85% 75%, rgba(59,130,246,0.05), transparent) !important;
}

/* ── Header ── */
#component-0 h1 {
    text-align: center;
    font-size: 2.2rem !important;
    font-weight: 800 !important;
    letter-spacing: -0.5px;
    background: linear-gradient(135deg, #818cf8, #60a5fa, #818cf8);
    background-size: 200% 200%;
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    background-clip: text;
    animation: gradientShift 4s ease-in-out infinite;
    padding-bottom: 4px !important;
}
#component-0 p {
    text-align: center;
    color: #94a3b8 !important;
    font-size: 0.95rem;
}
@keyframes gradientShift {
    0%, 100% { background-position: 0% 50%; }
    50% { background-position: 100% 50%; }
}

/* ── Glass panels ── */
.block {
    border: 1px solid rgba(99,115,146,0.15) !important;
    border-radius: 16px !important;
    backdrop-filter: blur(12px);
    transition: border-color 0.3s ease, box-shadow 0.3s ease !important;
}
.block:hover {
    border-color: rgba(99,102,241,0.25) !important;
    box-shadow: 0 0 20px rgba(99,102,241,0.06) !important;
}

/* ── Section headings ── */
.markdown h3 {
    font-size: 0.78rem !important;
    font-weight: 700 !important;
    text-transform: uppercase;
    letter-spacing: 1.2px;
    color: #64748b !important;
    border-bottom: 1px solid rgba(99,115,146,0.12);
    padding-bottom: 8px !important;
    margin-bottom: 12px !important;
}

/* ── Slider styling ── */
input[type="range"] {
    height: 6px !important;
    border-radius: 3px !important;
    background: #1e293b !important;
}
input[type="range"]::-webkit-slider-thumb {
    width: 18px !important;
    height: 18px !important;
    border-radius: 50% !important;
    border: 2.5px solid #0a0e17 !important;
    transition: transform 0.2s ease, box-shadow 0.2s ease !important;
}
input[type="range"]::-webkit-slider-thumb:hover {
    transform: scale(1.2) !important;
}

/* Slider labels */
.block label span {
    font-weight: 500 !important;
    font-size: 0.88rem !important;
}
.block .rangeSlider_value {
    font-variant-numeric: tabular-nums;
    font-weight: 600 !important;
}

/* ── Image panels ── */
.image-frame img, .image-container img {
    border-radius: 10px !important;
    transition: opacity 0.3s ease !important;
}
.image-container {
    background: rgba(0,0,0,0.2) !important;
    border-radius: 12px !important;
    min-height: 380px;
}

/* ── Generate button ── */
.primary {
    background: linear-gradient(135deg, #6366f1, #4f46e5, #6366f1) !important;
    background-size: 200% 200% !important;
    border: none !important;
    border-radius: 12px !important;
    padding: 14px 24px !important;
    font-weight: 700 !important;
    font-size: 1rem !important;
    letter-spacing: 0.3px;
    transition: all 0.3s cubic-bezier(0.4,0,0.2,1) !important;
    position: relative;
    overflow: hidden;
}
.primary:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 8px 25px rgba(99,102,241,0.4) !important;
    animation: btnShimmer 1.5s ease-in-out infinite !important;
}
.primary:active {
    transform: translateY(0) !important;
}
@keyframes btnShimmer {
    0%, 100% { background-position: 0% 50%; }
    50% { background-position: 100% 50%; }
}

/* ── Secondary buttons ── */
.secondary {
    border: 1px solid rgba(99,115,146,0.2) !important;
    border-radius: 10px !important;
    background: transparent !important;
    color: #94a3b8 !important;
    transition: all 0.25s ease !important;
}
.secondary:hover {
    border-color: rgba(99,102,241,0.4) !important;
    color: #e2e8f0 !important;
    background: rgba(99,102,241,0.06) !important;
}

/* ── Prompt preview ── */
textarea[readonly], .prose {
    font-family: 'JetBrains Mono', 'Fira Code', monospace !important;
    font-size: 0.8rem !important;
    line-height: 1.6 !important;
    color: #64748b !important;
    background: rgba(0,0,0,0.25) !important;
    border-radius: 10px !important;
}

/* ── Checkboxes ── */
.checkbox-group label {
    border-radius: 20px !important;
    padding: 4px 12px !important;
    font-size: 0.8rem !important;
    transition: all 0.2s ease !important;
    border: 1px solid rgba(99,115,146,0.15) !important;
    color: #e2e8f0 !important;
    background: rgba(17,24,39,0.75) !important;
}
.checkbox-group label span {
    color: #e2e8f0 !important;
}
.checkbox-group label:hover {
    border-color: rgba(99,102,241,0.35) !important;
    background: rgba(30,41,59,0.9) !important;
}
.checkbox-group input:checked + label,
.checkbox-group label.selected {
    background: rgba(99,102,241,0.15) !important;
    border-color: rgba(99,102,241,0.4) !important;
    color: #c7d2fe !important;
}

/* ── Dropdowns & inputs ── */
select, input[type="number"] {
    border-radius: 10px !important;
    border: 1px solid rgba(99,115,146,0.15) !important;
    transition: border-color 0.25s ease !important;
    font-size: 0.88rem !important;
}
select:focus, input[type="number"]:focus {
    border-color: rgba(99,102,241,0.5) !important;
    box-shadow: 0 0 0 2px rgba(99,102,241,0.1) !important;
}

/* ── Accordion ── */
.accordion {
    border: 1px solid rgba(99,115,146,0.1) !important;
    border-radius: 12px !important;
    background: rgba(0,0,0,0.15) !important;
}
.accordion > .label-wrap {
    font-size: 0.82rem !important;
    color: #64748b !important;
    font-weight: 500 !important;
}

/* ── Examples gallery ── */
.gallery-item {
    border-radius: 10px !important;
    border: 2px solid rgba(99,115,146,0.15) !important;
    transition: all 0.25s ease !important;
    overflow: hidden;
}
.gallery-item:hover {
    border-color: rgba(99,102,241,0.4) !important;
    transform: scale(1.04);
    box-shadow: 0 4px 16px rgba(99,102,241,0.15) !important;
}

/* ── Scrollbar ── */
::-webkit-scrollbar { width: 6px; }
::-webkit-scrollbar-track { background: transparent; }
::-webkit-scrollbar-thumb {
    background: rgba(99,115,146,0.25);
    border-radius: 3px;
}
::-webkit-scrollbar-thumb:hover { background: rgba(99,115,146,0.4); }

/* ── Footer spacing ── */
.gradio-container > .main > .wrap:last-child {
    padding-bottom: 40px !important;
}
"""

sample_paths = get_sample_masks()

# The light and dark variants are set to the same values, so the app looks
# identical regardless of the visitor's colour-scheme preference.
THEME = gr.themes.Base(
    primary_hue=gr.themes.colors.indigo,
    secondary_hue=gr.themes.colors.slate,
    neutral_hue=gr.themes.colors.slate,
    font=gr.themes.GoogleFont("Inter"),
    font_mono=gr.themes.GoogleFont("JetBrains Mono"),
    radius_size=gr.themes.sizes.radius_lg,
    spacing_size=gr.themes.sizes.spacing_md,
).set(
    # Background
    body_background_fill="#0a0e17",
    body_background_fill_dark="#0a0e17",
    # Panels
    block_background_fill="rgba(17,24,39,0.75)",
    block_background_fill_dark="rgba(17,24,39,0.75)",
    block_border_color="rgba(99,115,146,0.15)",
    block_border_color_dark="rgba(99,115,146,0.15)",
    block_shadow="0 4px 24px rgba(0,0,0,0.2)",
    block_shadow_dark="0 4px 24px rgba(0,0,0,0.2)",
    # Inputs
    input_background_fill="#131b2e",
    input_background_fill_dark="#131b2e",
    input_border_color="rgba(99,115,146,0.15)",
    input_border_color_dark="rgba(99,115,146,0.15)",
    # Buttons
    button_primary_background_fill="linear-gradient(135deg, #6366f1, #4f46e5)",
    button_primary_background_fill_dark="linear-gradient(135deg, #6366f1, #4f46e5)",
    button_primary_text_color="white",
    button_primary_text_color_dark="white",
    button_primary_shadow="0 4px 14px rgba(99,102,241,0.25)",
    button_primary_shadow_dark="0 4px 14px rgba(99,102,241,0.25)",
    # Text
    body_text_color="#e2e8f0",
    body_text_color_dark="#e2e8f0",
    body_text_color_subdued="#94a3b8",
    body_text_color_subdued_dark="#94a3b8",
    # Labels
    block_label_text_color="#94a3b8",
    block_label_text_color_dark="#94a3b8",
    block_title_text_color="#cbd5e1",
    block_title_text_color_dark="#cbd5e1",
    # Borders
    border_color_primary="rgba(99,102,241,0.4)",
    border_color_primary_dark="rgba(99,102,241,0.4)",
)

with gr.Blocks(
    title="SynthCXR · Chest X-Ray Generator",
) as demo:
    gr.Markdown(
        "# 🫁 SynthCXR\n"
        "Interactively resize anatomical masks and generate realistic chest X-rays"
    )

    with gr.Row():
        # ── Left column: Controls ──
        with gr.Column(scale=1):
            # Mask input
            gr.Markdown("### Select Mask")
            mask_input = gr.Image(
                label="Conditioning Mask",
                type="numpy",
                sources=["upload"],
                height=240,
            )

            # Sample mask gallery
            if sample_paths:
                sample_gallery = gr.Examples(
                    examples=sample_paths,
                    inputs=mask_input,
                    label="Sample Masks",
                )

            # Sliders
            gr.Markdown("### Mask Scaling")
            heart_slider = gr.Slider(
                minimum=0.0, maximum=2.0, step=0.05, value=1.0,
                label="💙 Heart Scale",
            )
            left_lung_slider = gr.Slider(
                minimum=0.0, maximum=2.0, step=0.05, value=1.0,
                label="🔴 Left Lung Scale",
            )
            right_lung_slider = gr.Slider(
                minimum=0.0, maximum=2.0, step=0.05, value=1.0,
                label="🟢 Right Lung Scale",
            )
            reset_btn = gr.Button("↺ Reset Scales", variant="secondary", size="sm")

            # Conditions
            gr.Markdown("### Conditions")
            conditions_select = gr.CheckboxGroup(
                choices=CONDITION_CHOICES,
                label="Pathologies",
            )
            with gr.Row():
                severity_select = gr.Radio(
                    choices=SEVERITY_CHOICES,
                    value="(none)",
                    label="Severity",
                )
                view_select = gr.Radio(
                    choices=["AP", "PA"],
                    value="AP",
                    label="View",
                )
            with gr.Row():
                age_input = gr.Number(value=45, label="Age", minimum=0, maximum=120, precision=0)
                sex_select = gr.Radio(
                    choices=["male", "female"],
                    value="male",
                    label="Sex",
                )

            # Advanced
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    steps_input = gr.Number(value=40, label="Steps", minimum=1, maximum=100, precision=0)
                    cfg_input = gr.Number(value=8.0, label="CFG Scale", minimum=1.0, maximum=20.0)
                with gr.Row():
                    seed_input = gr.Number(value=42, label="Seed", minimum=0, precision=0)

        # ── Right column: Outputs ──
        with gr.Column(scale=2):
            with gr.Row():
                mask_preview = gr.Image(
                    label="Scaled Mask Preview",
                    type="numpy",
                    interactive=False,
                    height=400,
                )
                cxr_output = gr.Image(
                    label="Generated Chest X-Ray",
                    type="pil",
                    interactive=False,
                    height=400,
                )

            # Prompt preview
            prompt_preview = gr.Textbox(
                label="Prompt Preview",
                interactive=False,
                lines=3,
            )

            generate_btn = gr.Button("⚡ Generate CXR", variant="primary", size="lg")

    # ── Event wiring ──
    # Live mask preview on any slider / mask change
    slider_inputs = [mask_input, heart_slider, left_lung_slider, right_lung_slider]
    mask_input.change(preview_mask, inputs=slider_inputs, outputs=mask_preview)
    heart_slider.change(preview_mask, inputs=slider_inputs, outputs=mask_preview)
    left_lung_slider.change(preview_mask, inputs=slider_inputs, outputs=mask_preview)
    right_lung_slider.change(preview_mask, inputs=slider_inputs, outputs=mask_preview)

    # Reset sliders
    def reset_scales():
        """Return the neutral scale (1.0) for the heart, left and right lung sliders."""
        return 1.0, 1.0, 1.0

    reset_btn.click(
        reset_scales,
        outputs=[heart_slider, left_lung_slider, right_lung_slider],
    )

    # Auto-adjust sliders when conditions change
    _CONDITION_SCALE_MAP = {
        # condition_key: (heart_delta, lung_delta)
        "cardiomegaly": (+0.35, 0.0),
        "enlarged_cardiomediastinum": (+0.25, 0.0),
        "atelectasis": (0.0, -0.25),
        "pneumothorax": (0.0, -0.30),
        "pleural_effusion": (0.0, -0.20),
    }
    _SEVERITY_MULTIPLIER = {
        "(none)": 1.0,
        "mild": 0.6,
        "moderate": 1.0,
        "severe": 1.5,
    }

    def sync_sliders(conditions: list[str], severity: str):
        """Set slider values based on selected conditions + severity.

        Starts from 1.0 for heart and lungs, adds each selected condition's
        delta multiplied by the severity multiplier, and clamps the result to
        the slider range. Both lung sliders receive the same value.

        Returns:
            ``(heart, left_lung, right_lung)`` scale values.
        """
        heart = 1.0
        lung = 1.0
        mult = _SEVERITY_MULTIPLIER.get(severity, 1.0)
        for cond in (conditions or []):
            h_delta, l_delta = _CONDITION_SCALE_MAP.get(cond, (0.0, 0.0))
            heart += h_delta * mult
            lung += l_delta * mult
        # Clamp to slider range [0.0, 2.0]
        heart = round(max(0.0, min(2.0, heart)), 2)
        lung = round(max(0.0, min(2.0, lung)), 2)
        return heart, lung, lung

    conditions_select.change(
        sync_sliders,
        inputs=[conditions_select, severity_select],
        outputs=[heart_slider, left_lung_slider, right_lung_slider],
    )
    severity_select.change(
        sync_sliders,
        inputs=[conditions_select, severity_select],
        outputs=[heart_slider, left_lung_slider, right_lung_slider],
    )

    # Prompt preview on config change
    prompt_inputs = [conditions_select, severity_select, age_input, sex_select, view_select]
    for inp in prompt_inputs:
        inp.change(build_prompt_preview, inputs=prompt_inputs, outputs=prompt_preview)

    # Generate
    generate_btn.click(
        generate_cxr,
        inputs=[
            mask_input,
            heart_slider,
            left_lung_slider,
            right_lung_slider,
            conditions_select,
            severity_select,
            age_input,
            sex_select,
            view_select,
            steps_input,
            cfg_input,
            seed_input,
        ],
        outputs=cxr_output,
    )

# ---------------------------------------------------------------------------
# Launch (module-level for HuggingFace Spaces compatibility)
# ---------------------------------------------------------------------------
demo.launch(theme=THEME, css=CUSTOM_CSS)