Spaces:
Running on Zero
Running on Zero
| #!/usr/bin/env python3 | |
| """Gradio app for SynthCXR: interactive mask scaling and CXR generation.""" | |
| from __future__ import annotations | |
| import os | |
| os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" | |
| os.environ["DIFFSYNTH_DOWNLOAD_SOURCE"] = "huggingface" | |
| from pathlib import Path | |
| import spaces | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from synthcxr.constants import KNOWN_CONDITIONS | |
| from synthcxr.mask_utils import resolve_overlaps, scale_mask_channel | |
| from synthcxr.prompt import ConditionConfig, build_condition_prompt | |
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# Repository root (directory containing this file).
BASE_DIR = Path(__file__).resolve().parent
# Bundled example masks shown in the Examples gallery.
SAMPLE_MASKS_DIR = BASE_DIR / "static" / "sample_masks"
# Directory holding LoRA checkpoints (epoch-N.safetensors files).
LORA_DIR = BASE_DIR / "scripts" / "models" / "qwen_image_edit_chexpert_lora"
# ---------------------------------------------------------------------------
# Condition / severity choices
# ---------------------------------------------------------------------------
# Pathology labels selectable in the UI (CheXpert-style condition keys).
CONDITION_CHOICES = [
    "enlarged_cardiomediastinum",
    "cardiomegaly",
    "atelectasis",
    "pneumothorax",
    "pleural_effusion",
]
# "(none)" means no severity qualifier is added to the prompt.
SEVERITY_CHOICES = ["(none)", "mild", "moderate", "severe"]
| # --------------------------------------------------------------------------- | |
| # Pipeline loading (fresh on each @spaces.GPU call; model files cached on disk) | |
| # --------------------------------------------------------------------------- | |
def load_fresh_pipeline():
    """Build the pipeline (plus optional LoRA) on the currently allocated device.

    ZeroGPU frees GPU memory after every ``@spaces.GPU`` call, so tensors
    cannot be cached between invocations. The model files, however, live in
    the HF Hub disk cache, so only tensor loading (not a download) happens
    here.

    Returns:
        The loaded pipeline, with LoRA weights applied when a checkpoint
        is found on disk.
    """
    from synthcxr.pipeline import load_lora_weights, load_pipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16
    # VRAM_LIMIT (in GB) turns on model offloading for memory-constrained GPUs.
    raw_limit = os.environ.get("VRAM_LIMIT", "")
    vram_limit = float(raw_limit) if raw_limit else None
    print(f"[INFO] Loading QwenImagePipeline (device={device}, dtype={dtype}, vram_limit={vram_limit}) β¦")
    pipe = load_pipeline(device, dtype, vram_limit=vram_limit)

    # LORA_EPOCH selects which epoch checkpoint to load (default: 2).
    lora_epoch = os.environ.get("LORA_EPOCH", "2")
    ckpt = LORA_DIR / f"epoch-{lora_epoch}.safetensors"
    if not ckpt.exists():
        fallbacks = sorted(LORA_DIR.glob("*.safetensors")) if LORA_DIR.exists() else []
        if not fallbacks:
            # No checkpoint at all: run the base model without LoRA.
            print("[WARN] No LoRA checkpoint found β running base model only.")
            return pipe
        # Fall back to the latest checkpoint present on disk.
        ckpt = fallbacks[-1]
        print(f"[WARN] epoch-{lora_epoch} not found, falling back to {ckpt.name}")
    print(f"[INFO] Loading LoRA from {ckpt}")
    load_lora_weights(pipe, ckpt)
    print("[INFO] Pipeline ready.")
    return pipe
| # --------------------------------------------------------------------------- | |
| # Sample masks | |
| # --------------------------------------------------------------------------- | |
def get_sample_masks() -> list[str]:
    """Return sorted path strings of the bundled sample masks (empty if none)."""
    directory = SAMPLE_MASKS_DIR
    if directory.exists():
        return sorted(str(png) for png in directory.glob("*.png"))
    return []
| # --------------------------------------------------------------------------- | |
| # Core functions | |
| # --------------------------------------------------------------------------- | |
def apply_mask_scaling(
    mask_array: np.ndarray,
    heart_scale: float,
    left_lung_scale: float,
    right_lung_scale: float,
) -> np.ndarray:
    """Resize heart/lung mask channels by the given factors and fix overlaps.

    Channel layout (per the channel indices used below): 0 = left lung,
    1 = right lung, 2 = heart. After scaling, overlapping pixels are
    resolved with the heart taking priority, then left lung, then right.
    """
    # Apply in the same order as before: heart, left lung, right lung.
    # A factor of exactly 1.0 is a no-op and is skipped.
    for channel, factor in ((2, heart_scale), (0, left_lung_scale), (1, right_lung_scale)):
        if factor != 1.0:
            mask_array = scale_mask_channel(mask_array, channel=channel, scale_factor=factor)
    return resolve_overlaps(mask_array, priority=(2, 0, 1))
def preview_mask(
    mask_image: np.ndarray | None,
    heart_scale: float,
    left_lung_scale: float,
    right_lung_scale: float,
) -> np.ndarray | None:
    """Recompute the scaled-mask preview; a missing mask yields no preview."""
    if mask_image is None:
        return None
    # Normalize to a 3-channel RGB array before scaling.
    rgb = np.array(Image.fromarray(mask_image).convert("RGB"))
    return apply_mask_scaling(rgb, heart_scale, left_lung_scale, right_lung_scale)
def build_prompt_preview(
    conditions: list[str],
    severity: str,
    age: int,
    sex: str,
    view: str,
) -> str:
    """Render the generation prompt from the current UI configuration."""
    # The "(none)" sentinel from the radio group maps to "no severity".
    chosen_severity = None if severity == "(none)" else severity
    config = ConditionConfig(
        name="preview",
        conditions=conditions or [],
        age=age,
        sex=sex,
        view=view,
        severity=chosen_severity,
    )
    return build_condition_prompt(config)
@spaces.GPU
def generate_cxr(
    mask_image: np.ndarray | None,
    heart_scale: float,
    left_lung_scale: float,
    right_lung_scale: float,
    conditions: list[str],
    severity: str,
    age: int,
    sex: str,
    view: str,
    num_steps: int,
    cfg_scale: float,
    seed: int,
    progress=gr.Progress(),
):
    """Generate a chest X-ray from the scaled mask and condition prompt.

    Implemented as a generator (it yields the final image) so Gradio can
    stream the result to the output component. Decorated with
    ``@spaces.GPU`` so ZeroGPU allocates a GPU for the duration of the
    call — without it the Space runs on CPU only.

    Raises:
        gr.Error: if no mask is supplied or the pipeline fails to load.
    """
    if mask_image is None:
        raise gr.Error("Please select or upload a mask first.")
    pipe = load_fresh_pipeline()
    if pipe is None:
        raise gr.Error("Pipeline not loaded. GPU may be unavailable.")

    # Prepare the conditioning image: RGB mask with per-channel scaling applied.
    mask = np.array(Image.fromarray(mask_image).convert("RGB"))
    scaled = apply_mask_scaling(mask, heart_scale, left_lung_scale, right_lung_scale)
    edit_image = Image.fromarray(scaled)

    # Build the text prompt from the UI configuration.
    cond = ConditionConfig(
        name="web_ui",
        conditions=conditions or [],
        age=age,
        sex=sex,
        view=view,
        severity=severity if severity != "(none)" else None,
    )
    prompt = build_condition_prompt(cond)

    class StepCallback:
        """tqdm-compatible wrapper that forwards denoising progress to Gradio."""

        def __init__(self, iterable):
            self._iterable = iterable
            self._step = 0

        def __iter__(self):
            for item in self._iterable:
                progress(self._step / num_steps, desc="Generating CXR...")
                yield item
                self._step += 1

        def __len__(self):
            return len(self._iterable)

    # NOTE(review): an earlier version monkey-patched
    # pipe.unit_runner.__class__.__call__ to capture intermediate latents for
    # step previews, but the captured state was never read and the preview
    # list was never filled, so nothing was ever yielded from it. The dead
    # class-level patch (which was also unsafe under concurrent requests)
    # has been removed; observable behavior is unchanged.
    image = pipe(
        prompt=prompt,
        edit_image=edit_image,
        height=512,
        width=512,
        num_inference_steps=num_steps,
        seed=seed,
        rand_device=pipe.device,
        cfg_scale=cfg_scale,
        edit_image_auto_resize=True,
        zero_cond_t=True,
        progress_bar_cmd=StepCallback,
    )
    yield image
| # --------------------------------------------------------------------------- | |
| # Gradio UI | |
| # --------------------------------------------------------------------------- | |
# Dark "glassmorphism" styling layered on top of the base theme: gradient
# header, glass panels, styled sliders/buttons/checkbox pills, custom
# scrollbar. Served to the browser as-is.
CUSTOM_CSS = """
/* ββ Layout ββ */
.gradio-container {
    max-width: 1280px !important;
    margin: 0 auto !important;
}
/* ββ Radial gradient background ββ */
.main {
    background:
        radial-gradient(ellipse 80% 50% at 10% 20%, rgba(99,102,241,0.07), transparent),
        radial-gradient(ellipse 60% 40% at 85% 75%, rgba(59,130,246,0.05), transparent) !important;
}
/* ββ Header ββ */
#component-0 h1 {
    text-align: center;
    font-size: 2.2rem !important;
    font-weight: 800 !important;
    letter-spacing: -0.5px;
    background: linear-gradient(135deg, #818cf8, #60a5fa, #818cf8);
    background-size: 200% 200%;
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    background-clip: text;
    animation: gradientShift 4s ease-in-out infinite;
    padding-bottom: 4px !important;
}
#component-0 p {
    text-align: center;
    color: #94a3b8 !important;
    font-size: 0.95rem;
}
@keyframes gradientShift {
    0%, 100% { background-position: 0% 50%; }
    50% { background-position: 100% 50%; }
}
/* ββ Glass panels ββ */
.block {
    border: 1px solid rgba(99,115,146,0.15) !important;
    border-radius: 16px !important;
    backdrop-filter: blur(12px);
    transition: border-color 0.3s ease, box-shadow 0.3s ease !important;
}
.block:hover {
    border-color: rgba(99,102,241,0.25) !important;
    box-shadow: 0 0 20px rgba(99,102,241,0.06) !important;
}
/* ββ Section headings ββ */
.markdown h3 {
    font-size: 0.78rem !important;
    font-weight: 700 !important;
    text-transform: uppercase;
    letter-spacing: 1.2px;
    color: #64748b !important;
    border-bottom: 1px solid rgba(99,115,146,0.12);
    padding-bottom: 8px !important;
    margin-bottom: 12px !important;
}
/* ββ Slider styling ββ */
input[type="range"] {
    height: 6px !important;
    border-radius: 3px !important;
    background: #1e293b !important;
}
input[type="range"]::-webkit-slider-thumb {
    width: 18px !important;
    height: 18px !important;
    border-radius: 50% !important;
    border: 2.5px solid #0a0e17 !important;
    transition: transform 0.2s ease, box-shadow 0.2s ease !important;
}
input[type="range"]::-webkit-slider-thumb:hover {
    transform: scale(1.2) !important;
}
/* Slider labels */
.block label span {
    font-weight: 500 !important;
    font-size: 0.88rem !important;
}
.block .rangeSlider_value {
    font-variant-numeric: tabular-nums;
    font-weight: 600 !important;
}
/* ββ Image panels ββ */
.image-frame img, .image-container img {
    border-radius: 10px !important;
    transition: opacity 0.3s ease !important;
}
.image-container {
    background: rgba(0,0,0,0.2) !important;
    border-radius: 12px !important;
    min-height: 380px;
}
/* ββ Generate button ββ */
.primary {
    background: linear-gradient(135deg, #6366f1, #4f46e5, #6366f1) !important;
    background-size: 200% 200% !important;
    border: none !important;
    border-radius: 12px !important;
    padding: 14px 24px !important;
    font-weight: 700 !important;
    font-size: 1rem !important;
    letter-spacing: 0.3px;
    transition: all 0.3s cubic-bezier(0.4,0,0.2,1) !important;
    position: relative;
    overflow: hidden;
}
.primary:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 8px 25px rgba(99,102,241,0.4) !important;
    animation: btnShimmer 1.5s ease-in-out infinite !important;
}
.primary:active {
    transform: translateY(0) !important;
}
@keyframes btnShimmer {
    0%, 100% { background-position: 0% 50%; }
    50% { background-position: 100% 50%; }
}
/* ββ Secondary buttons ββ */
.secondary {
    border: 1px solid rgba(99,115,146,0.2) !important;
    border-radius: 10px !important;
    background: transparent !important;
    color: #94a3b8 !important;
    transition: all 0.25s ease !important;
}
.secondary:hover {
    border-color: rgba(99,102,241,0.4) !important;
    color: #e2e8f0 !important;
    background: rgba(99,102,241,0.06) !important;
}
/* ββ Prompt preview ββ */
textarea[readonly], .prose {
    font-family: 'JetBrains Mono', 'Fira Code', monospace !important;
    font-size: 0.8rem !important;
    line-height: 1.6 !important;
    color: #64748b !important;
    background: rgba(0,0,0,0.25) !important;
    border-radius: 10px !important;
}
/* ββ Checkboxes ββ */
.checkbox-group label {
    border-radius: 20px !important;
    padding: 4px 12px !important;
    font-size: 0.8rem !important;
    transition: all 0.2s ease !important;
    border: 1px solid rgba(99,115,146,0.15) !important;
    color: #e2e8f0 !important;
    background: rgba(17,24,39,0.75) !important;
}
.checkbox-group label span {
    color: #e2e8f0 !important;
}
.checkbox-group label:hover {
    border-color: rgba(99,102,241,0.35) !important;
    background: rgba(30,41,59,0.9) !important;
}
.checkbox-group input:checked + label,
.checkbox-group label.selected {
    background: rgba(99,102,241,0.15) !important;
    border-color: rgba(99,102,241,0.4) !important;
    color: #c7d2fe !important;
}
/* ββ Dropdowns & inputs ββ */
select, input[type="number"] {
    border-radius: 10px !important;
    border: 1px solid rgba(99,115,146,0.15) !important;
    transition: border-color 0.25s ease !important;
    font-size: 0.88rem !important;
}
select:focus, input[type="number"]:focus {
    border-color: rgba(99,102,241,0.5) !important;
    box-shadow: 0 0 0 2px rgba(99,102,241,0.1) !important;
}
/* ββ Accordion ββ */
.accordion {
    border: 1px solid rgba(99,115,146,0.1) !important;
    border-radius: 12px !important;
    background: rgba(0,0,0,0.15) !important;
}
.accordion > .label-wrap {
    font-size: 0.82rem !important;
    color: #64748b !important;
    font-weight: 500 !important;
}
/* ββ Examples gallery ββ */
.gallery-item {
    border-radius: 10px !important;
    border: 2px solid rgba(99,115,146,0.15) !important;
    transition: all 0.25s ease !important;
    overflow: hidden;
}
.gallery-item:hover {
    border-color: rgba(99,102,241,0.4) !important;
    transform: scale(1.04);
    box-shadow: 0 4px 16px rgba(99,102,241,0.15) !important;
}
/* ββ Scrollbar ββ */
::-webkit-scrollbar { width: 6px; }
::-webkit-scrollbar-track { background: transparent; }
::-webkit-scrollbar-thumb {
    background: rgba(99,115,146,0.25);
    border-radius: 3px;
}
::-webkit-scrollbar-thumb:hover { background: rgba(99,115,146,0.4); }
/* ββ Footer spacing ββ */
.gradio-container > .main > .wrap:last-child { padding-bottom: 40px !important; }
"""
# Discover bundled sample masks once at import time for the Examples gallery.
sample_paths = get_sample_masks()

# Dark indigo/slate theme built on gr.themes.Base. Light and dark variants
# are pinned to the same values so the app looks identical in both modes.
THEME = gr.themes.Base(
    primary_hue=gr.themes.colors.indigo,
    secondary_hue=gr.themes.colors.slate,
    neutral_hue=gr.themes.colors.slate,
    font=gr.themes.GoogleFont("Inter"),
    font_mono=gr.themes.GoogleFont("JetBrains Mono"),
    radius_size=gr.themes.sizes.radius_lg,
    spacing_size=gr.themes.sizes.spacing_md,
).set(
    # Background
    body_background_fill="#0a0e17",
    body_background_fill_dark="#0a0e17",
    # Panels
    block_background_fill="rgba(17,24,39,0.75)",
    block_background_fill_dark="rgba(17,24,39,0.75)",
    block_border_color="rgba(99,115,146,0.15)",
    block_border_color_dark="rgba(99,115,146,0.15)",
    block_shadow="0 4px 24px rgba(0,0,0,0.2)",
    block_shadow_dark="0 4px 24px rgba(0,0,0,0.2)",
    # Inputs
    input_background_fill="#131b2e",
    input_background_fill_dark="#131b2e",
    input_border_color="rgba(99,115,146,0.15)",
    input_border_color_dark="rgba(99,115,146,0.15)",
    # Buttons
    button_primary_background_fill="linear-gradient(135deg, #6366f1, #4f46e5)",
    button_primary_background_fill_dark="linear-gradient(135deg, #6366f1, #4f46e5)",
    button_primary_text_color="white",
    button_primary_text_color_dark="white",
    button_primary_shadow="0 4px 14px rgba(99,102,241,0.25)",
    button_primary_shadow_dark="0 4px 14px rgba(99,102,241,0.25)",
    # Text
    body_text_color="#e2e8f0",
    body_text_color_dark="#e2e8f0",
    body_text_color_subdued="#94a3b8",
    body_text_color_subdued_dark="#94a3b8",
    # Labels
    block_label_text_color="#94a3b8",
    block_label_text_color_dark="#94a3b8",
    block_title_text_color="#cbd5e1",
    block_title_text_color_dark="#cbd5e1",
    # Borders
    border_color_primary="rgba(99,102,241,0.4)",
    border_color_primary_dark="rgba(99,102,241,0.4)",
)
# Build the UI. NOTE: `theme` and `css` are gr.Blocks constructor options;
# Blocks.launch() does not accept them, so passing them to launch() raises
# a TypeError. They were moved here from the launch() call at the bottom.
with gr.Blocks(
    title="SynthCXR Β· Chest X-Ray Generator",
    theme=THEME,
    css=CUSTOM_CSS,
) as demo:
    gr.Markdown(
        "# π« SynthCXR\n"
        "Interactively resize anatomical masks and generate realistic chest X-rays"
    )
    with gr.Row():
        # Left column: all input controls.
        with gr.Column(scale=1):
            # Mask input
            gr.Markdown("### Select Mask")
            mask_input = gr.Image(
                label="Conditioning Mask",
                type="numpy",
                sources=["upload"],
                height=240,
            )
            # Clickable gallery of bundled sample masks (only if any shipped).
            if sample_paths:
                sample_gallery = gr.Examples(
                    examples=sample_paths,
                    inputs=mask_input,
                    label="Sample Masks",
                )
            # Per-structure scale sliders (1.0 = unchanged).
            gr.Markdown("### Mask Scaling")
            heart_slider = gr.Slider(
                minimum=0.0, maximum=2.0, step=0.05, value=1.0,
                label="π Heart Scale",
            )
            left_lung_slider = gr.Slider(
                minimum=0.0, maximum=2.0, step=0.05, value=1.0,
                label="π΄ Left Lung Scale",
            )
            right_lung_slider = gr.Slider(
                minimum=0.0, maximum=2.0, step=0.05, value=1.0,
                label="π’ Right Lung Scale",
            )
            reset_btn = gr.Button("βΊ Reset Scales", variant="secondary", size="sm")
            # Condition / patient metadata controls.
            gr.Markdown("### Conditions")
            conditions_select = gr.CheckboxGroup(
                choices=CONDITION_CHOICES,
                label="Pathologies",
            )
            with gr.Row():
                severity_select = gr.Radio(
                    choices=SEVERITY_CHOICES, value="(none)", label="Severity",
                )
                view_select = gr.Radio(
                    choices=["AP", "PA"], value="AP", label="View",
                )
            with gr.Row():
                age_input = gr.Number(value=45, label="Age", minimum=0, maximum=120, precision=0)
                sex_select = gr.Radio(
                    choices=["male", "female"], value="male", label="Sex",
                )
            # Sampler settings most users never need to touch.
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    steps_input = gr.Number(value=40, label="Steps", minimum=1, maximum=100, precision=0)
                    cfg_input = gr.Number(value=8.0, label="CFG Scale", minimum=1.0, maximum=20.0)
                with gr.Row():
                    seed_input = gr.Number(value=42, label="Seed", minimum=0, precision=0)
        # Right column: live mask preview and the generated image.
        with gr.Column(scale=2):
            with gr.Row():
                mask_preview = gr.Image(
                    label="Scaled Mask Preview",
                    type="numpy",
                    interactive=False,
                    height=400,
                )
                cxr_output = gr.Image(
                    label="Generated Chest X-Ray",
                    type="pil",
                    interactive=False,
                    height=400,
                )
            # Read-only view of the prompt that will be sent to the model.
            prompt_preview = gr.Textbox(
                label="Prompt Preview",
                interactive=False,
                lines=3,
            )
            generate_btn = gr.Button("β‘ Generate CXR", variant="primary", size="lg")

    # ---- Event wiring -----------------------------------------------------
    # Re-render the scaled-mask preview whenever the mask or a slider changes.
    slider_inputs = [mask_input, heart_slider, left_lung_slider, right_lung_slider]
    mask_input.change(preview_mask, inputs=slider_inputs, outputs=mask_preview)
    heart_slider.change(preview_mask, inputs=slider_inputs, outputs=mask_preview)
    left_lung_slider.change(preview_mask, inputs=slider_inputs, outputs=mask_preview)
    right_lung_slider.change(preview_mask, inputs=slider_inputs, outputs=mask_preview)

    def reset_scales():
        """Return all three scale sliders to their neutral value (1.0)."""
        return 1.0, 1.0, 1.0

    reset_btn.click(
        reset_scales,
        outputs=[heart_slider, left_lung_slider, right_lung_slider],
    )

    # Heuristic slider presets per condition: (heart_delta, lung_delta).
    _CONDITION_SCALE_MAP = {
        "cardiomegaly": (+0.35, 0.0),
        "enlarged_cardiomediastinum": (+0.25, 0.0),
        "atelectasis": (0.0, -0.25),
        "pneumothorax": (0.0, -0.30),
        "pleural_effusion": (0.0, -0.20),
    }
    # Severity scales the deltas above; "(none)" behaves like "moderate".
    _SEVERITY_MULTIPLIER = {
        "(none)": 1.0,
        "mild": 0.6,
        "moderate": 1.0,
        "severe": 1.5,
    }

    def sync_sliders(conditions: list[str], severity: str):
        """Set slider values based on selected conditions + severity."""
        heart = 1.0
        lung = 1.0
        mult = _SEVERITY_MULTIPLIER.get(severity, 1.0)
        for cond in (conditions or []):
            h_delta, l_delta = _CONDITION_SCALE_MAP.get(cond, (0.0, 0.0))
            heart += h_delta * mult
            lung += l_delta * mult
        # Clamp to the sliders' [0.0, 2.0] range.
        heart = round(max(0.0, min(2.0, heart)), 2)
        lung = round(max(0.0, min(2.0, lung)), 2)
        # Both lungs receive the same scale.
        return heart, lung, lung

    conditions_select.change(
        sync_sliders,
        inputs=[conditions_select, severity_select],
        outputs=[heart_slider, left_lung_slider, right_lung_slider],
    )
    severity_select.change(
        sync_sliders,
        inputs=[conditions_select, severity_select],
        outputs=[heart_slider, left_lung_slider, right_lung_slider],
    )

    # Keep the prompt preview in sync with every prompt-affecting control.
    prompt_inputs = [conditions_select, severity_select, age_input, sex_select, view_select]
    for inp in prompt_inputs:
        inp.change(build_prompt_preview, inputs=prompt_inputs, outputs=prompt_preview)

    # Generate
    generate_btn.click(
        generate_cxr,
        inputs=[
            mask_input,
            heart_slider, left_lung_slider, right_lung_slider,
            conditions_select, severity_select,
            age_input, sex_select, view_select,
            steps_input, cfg_input, seed_input,
        ],
        outputs=cxr_output,
    )

# ---------------------------------------------------------------------------
# Launch (module-level for HuggingFace Spaces compatibility)
# ---------------------------------------------------------------------------
# theme/css are applied via the gr.Blocks constructor above, not launch().
demo.launch()