SynthCXR / app.py
gradientguild's picture
Upload folder using huggingface_hub
c106be7 verified
#!/usr/bin/env python3
"""Gradio app for SynthCXR: interactive mask scaling and CXR generation."""
from __future__ import annotations
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["DIFFSYNTH_DOWNLOAD_SOURCE"] = "huggingface"
from pathlib import Path
import spaces
import gradio as gr
import numpy as np
import torch
from PIL import Image
from synthcxr.constants import KNOWN_CONDITIONS
from synthcxr.mask_utils import resolve_overlaps, scale_mask_channel
from synthcxr.prompt import ConditionConfig, build_condition_prompt
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# Every bundled asset is located relative to the directory holding this file.
BASE_DIR = Path(__file__).resolve().parent
SAMPLE_MASKS_DIR = BASE_DIR.joinpath("static", "sample_masks")
LORA_DIR = BASE_DIR.joinpath("scripts", "models", "qwen_image_edit_chexpert_lora")
# ---------------------------------------------------------------------------
# Condition / severity choices
# ---------------------------------------------------------------------------
# Pathology keys offered in the "Pathologies" checkbox group.
CONDITION_CHOICES = [
    "enlarged_cardiomediastinum",
    "cardiomegaly",
    "atelectasis",
    "pneumothorax",
    "pleural_effusion",
]
# "(none)" is a UI-only sentinel; callers translate it to ``severity=None``.
SEVERITY_CHOICES = [
    "(none)",
    "mild",
    "moderate",
    "severe",
]
# ---------------------------------------------------------------------------
# Pipeline loading (fresh on each @spaces.GPU call; model files cached on disk)
# ---------------------------------------------------------------------------
def lora_epoch_sort_key(path: Path) -> tuple[int, int, str]:
    """Sort key ordering ``<prefix>-<N>.safetensors`` files by integer ``N``.

    Files whose stem does not end in ``-<digits>`` sort *before* numbered
    ones (alphabetically among themselves), so after sorting the last element
    is the highest-numbered epoch. Plain string sorting would wrongly put
    ``epoch-9`` after ``epoch-10``.
    """
    _, _, tail = path.stem.rpartition("-")
    # isascii() guards against Unicode digits such as "²" that int() rejects.
    if tail.isascii() and tail.isdigit():
        return (1, int(tail), path.name)
    return (0, 0, path.name)


def find_lora_checkpoint(lora_dir: Path, epoch: str) -> Path | None:
    """Pick the LoRA checkpoint to load from *lora_dir*.

    Args:
        lora_dir: Directory expected to contain ``*.safetensors`` checkpoints.
        epoch: Preferred epoch; ``epoch-<epoch>.safetensors`` is used if present.

    Returns:
        The preferred checkpoint if it exists. Otherwise the highest-numbered
        ``*.safetensors`` file in *lora_dir*, after printing a warning. ``None``
        if the directory is missing or holds no checkpoints.
    """
    preferred = lora_dir / f"epoch-{epoch}.safetensors"
    if preferred.exists():
        return preferred
    if not lora_dir.exists():
        return None
    candidates = sorted(lora_dir.glob("*.safetensors"), key=lora_epoch_sort_key)
    if not candidates:
        return None
    fallback = candidates[-1]
    print(f"[WARN] epoch-{epoch} not found, falling back to {fallback.name}")
    return fallback


def load_fresh_pipeline():
    """Load the pipeline + LoRA onto the *currently allocated* GPU.

    ZeroGPU deallocates GPU memory after each ``@spaces.GPU`` call, so we
    cannot cache tensors between calls. However, diffsynth caches the
    model files on disk (HF Hub cache), so only tensor loading happens
    here — not a full download.

    Environment variables:
        VRAM_LIMIT: optional limit in GB, forwarded to ``load_pipeline`` to
            enable model offloading on memory-constrained GPUs.
        LORA_EPOCH: which ``epoch-<N>.safetensors`` checkpoint to prefer
            (default ``"2"``). Falls back to the highest available epoch.

    Returns:
        The loaded pipeline. LoRA weights are applied when a checkpoint is found.

    Raises:
        ValueError: if ``VRAM_LIMIT`` is set but is not a number.
    """
    from synthcxr.pipeline import load_lora_weights, load_pipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16

    vram_limit_str = os.environ.get("VRAM_LIMIT", "").strip()
    try:
        vram_limit = float(vram_limit_str) if vram_limit_str else None
    except ValueError as err:
        raise ValueError(
            f"VRAM_LIMIT must be a number of GB, got {vram_limit_str!r}"
        ) from err
    print(f"[INFO] Loading QwenImagePipeline (device={device}, dtype={dtype}, vram_limit={vram_limit}) …")
    pipe = load_pipeline(device, dtype, vram_limit=vram_limit)

    lora_epoch = os.environ.get("LORA_EPOCH", "2")
    lora = find_lora_checkpoint(LORA_DIR, lora_epoch)
    if lora is None:
        print("[WARN] No LoRA checkpoint found – running base model only.")
        return pipe
    print(f"[INFO] Loading LoRA from {lora}")
    load_lora_weights(pipe, lora)
    print("[INFO] Pipeline ready.")
    return pipe
# ---------------------------------------------------------------------------
# Sample masks
# ---------------------------------------------------------------------------
def get_sample_masks() -> list[str]:
    """List the bundled sample mask PNGs as sorted path strings.

    Returns an empty list when the sample-mask directory does not exist.
    """
    if SAMPLE_MASKS_DIR.exists():
        return sorted(map(str, SAMPLE_MASKS_DIR.glob("*.png")))
    return []
# ---------------------------------------------------------------------------
# Core functions
# ---------------------------------------------------------------------------
def apply_mask_scaling(
    mask_array: np.ndarray,
    heart_scale: float,
    left_lung_scale: float,
    right_lung_scale: float,
) -> np.ndarray:
    """Rescale each anatomical channel of *mask_array*, then resolve overlaps.

    Channels are processed in the order heart (channel 2), left lung
    (channel 0), right lung (channel 1). A factor of exactly 1.0 skips that
    channel. The result is passed through ``resolve_overlaps`` with
    ``priority=(2, 0, 1)``.
    """
    result = mask_array
    for channel, factor in ((2, heart_scale), (0, left_lung_scale), (1, right_lung_scale)):
        if factor == 1.0:
            continue
        result = scale_mask_channel(result, channel=channel, scale_factor=factor)
    return resolve_overlaps(result, priority=(2, 0, 1))
def preview_mask(
    mask_image: np.ndarray | None,
    heart_scale: float,
    left_lung_scale: float,
    right_lung_scale: float,
) -> np.ndarray | None:
    """Gradio callback producing the live "Scaled Mask Preview" image.

    The input is coerced to RGB before scaling. Returns ``None``, which clears
    the preview, when no mask is loaded.
    """
    if mask_image is not None:
        rgb_mask = np.array(Image.fromarray(mask_image).convert("RGB"))
        return apply_mask_scaling(rgb_mask, heart_scale, left_lung_scale, right_lung_scale)
    return None
def build_prompt_preview(
    conditions: list[str],
    severity: str,
    age: int,
    sex: str,
    view: str,
) -> str:
    """Build the prompt text shown in the "Prompt Preview" box.

    The UI's "(none)" severity sentinel is mapped to ``None``. A missing
    ``conditions`` value is treated as an empty list.
    """
    chosen_severity = None if severity == "(none)" else severity
    config = ConditionConfig(
        name="preview",
        conditions=conditions or [],
        age=age,
        sex=sex,
        view=view,
        severity=chosen_severity,
    )
    return build_condition_prompt(config)
@spaces.GPU(duration=120)
def generate_cxr(
    mask_image: np.ndarray | None,
    heart_scale: float,
    left_lung_scale: float,
    right_lung_scale: float,
    conditions: list[str],
    severity: str,
    age: int,
    sex: str,
    view: str,
    num_steps: int,
    cfg_scale: float,
    seed: int,
    progress=gr.Progress(),
):
    """Generate a chest X-ray conditioned on a scaled mask and patient settings.

    Runs inside a ZeroGPU allocation (``@spaces.GPU``) and loads the pipeline
    afresh on every call. It is written as a generator and currently yields
    exactly one item: the final image returned by the pipeline. Per-step
    denoising progress is reported through *progress*.

    Raises:
        gr.Error: if no mask is provided or the pipeline failed to load.
    """
    if mask_image is None:
        raise gr.Error("Please select or upload a mask first.")
    pipe = load_fresh_pipeline()
    if pipe is None:
        raise gr.Error("Pipeline not loaded. GPU may be unavailable.")

    # Prepare the conditioning mask exactly as the live preview does.
    mask = np.array(Image.fromarray(mask_image).convert("RGB"))
    scaled = apply_mask_scaling(mask, heart_scale, left_lung_scale, right_lung_scale)
    edit_image = Image.fromarray(scaled)

    # Build the prompt; "(none)" is the UI sentinel for "no severity".
    cond = ConditionConfig(
        name="web_ui",
        conditions=conditions or [],
        age=age,
        sex=sex,
        view=view,
        severity=severity if severity != "(none)" else None,
    )
    prompt = build_condition_prompt(cond)

    class StepCallback:
        """tqdm-like iterable wrapper passed as ``progress_bar_cmd``.

        Reports each denoising step to the Gradio progress bar and forwards
        the wrapped items unchanged.
        """

        def __init__(self, iterable):
            self._iterable = iterable

        def __iter__(self):
            for step, item in enumerate(self._iterable):
                progress(step / num_steps, desc="Generating CXR...")
                yield item

        def __len__(self):
            return len(self._iterable)

    image = pipe(
        prompt=prompt,
        edit_image=edit_image,
        height=512,
        width=512,
        num_inference_steps=num_steps,
        seed=seed,
        rand_device=pipe.device,
        cfg_scale=cfg_scale,
        edit_image_auto_resize=True,
        zero_cond_t=True,
        progress_bar_cmd=StepCallback,
    )
    yield image
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Extra stylesheet for the dark "glass" look; passed as ``css=CUSTOM_CSS`` to
# ``demo.launch`` at the bottom of the file. Everything between the triple
# quotes is runtime CSS text, so it is intentionally left exactly as written.
CUSTOM_CSS = """
/* ── Layout ── */
.gradio-container {
max-width: 1280px !important;
margin: 0 auto !important;
}
/* ── Radial gradient background ── */
.main {
background:
radial-gradient(ellipse 80% 50% at 10% 20%, rgba(99,102,241,0.07), transparent),
radial-gradient(ellipse 60% 40% at 85% 75%, rgba(59,130,246,0.05), transparent) !important;
}
/* ── Header ── */
#component-0 h1 {
text-align: center;
font-size: 2.2rem !important;
font-weight: 800 !important;
letter-spacing: -0.5px;
background: linear-gradient(135deg, #818cf8, #60a5fa, #818cf8);
background-size: 200% 200%;
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
animation: gradientShift 4s ease-in-out infinite;
padding-bottom: 4px !important;
}
#component-0 p {
text-align: center;
color: #94a3b8 !important;
font-size: 0.95rem;
}
@keyframes gradientShift {
0%, 100% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
}
/* ── Glass panels ── */
.block {
border: 1px solid rgba(99,115,146,0.15) !important;
border-radius: 16px !important;
backdrop-filter: blur(12px);
transition: border-color 0.3s ease, box-shadow 0.3s ease !important;
}
.block:hover {
border-color: rgba(99,102,241,0.25) !important;
box-shadow: 0 0 20px rgba(99,102,241,0.06) !important;
}
/* ── Section headings ── */
.markdown h3 {
font-size: 0.78rem !important;
font-weight: 700 !important;
text-transform: uppercase;
letter-spacing: 1.2px;
color: #64748b !important;
border-bottom: 1px solid rgba(99,115,146,0.12);
padding-bottom: 8px !important;
margin-bottom: 12px !important;
}
/* ── Slider styling ── */
input[type="range"] {
height: 6px !important;
border-radius: 3px !important;
background: #1e293b !important;
}
input[type="range"]::-webkit-slider-thumb {
width: 18px !important;
height: 18px !important;
border-radius: 50% !important;
border: 2.5px solid #0a0e17 !important;
transition: transform 0.2s ease, box-shadow 0.2s ease !important;
}
input[type="range"]::-webkit-slider-thumb:hover {
transform: scale(1.2) !important;
}
/* Slider labels */
.block label span {
font-weight: 500 !important;
font-size: 0.88rem !important;
}
.block .rangeSlider_value {
font-variant-numeric: tabular-nums;
font-weight: 600 !important;
}
/* ── Image panels ── */
.image-frame img, .image-container img {
border-radius: 10px !important;
transition: opacity 0.3s ease !important;
}
.image-container {
background: rgba(0,0,0,0.2) !important;
border-radius: 12px !important;
min-height: 380px;
}
/* ── Generate button ── */
.primary {
background: linear-gradient(135deg, #6366f1, #4f46e5, #6366f1) !important;
background-size: 200% 200% !important;
border: none !important;
border-radius: 12px !important;
padding: 14px 24px !important;
font-weight: 700 !important;
font-size: 1rem !important;
letter-spacing: 0.3px;
transition: all 0.3s cubic-bezier(0.4,0,0.2,1) !important;
position: relative;
overflow: hidden;
}
.primary:hover {
transform: translateY(-2px) !important;
box-shadow: 0 8px 25px rgba(99,102,241,0.4) !important;
animation: btnShimmer 1.5s ease-in-out infinite !important;
}
.primary:active {
transform: translateY(0) !important;
}
@keyframes btnShimmer {
0%, 100% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
}
/* ── Secondary buttons ── */
.secondary {
border: 1px solid rgba(99,115,146,0.2) !important;
border-radius: 10px !important;
background: transparent !important;
color: #94a3b8 !important;
transition: all 0.25s ease !important;
}
.secondary:hover {
border-color: rgba(99,102,241,0.4) !important;
color: #e2e8f0 !important;
background: rgba(99,102,241,0.06) !important;
}
/* ── Prompt preview ── */
textarea[readonly], .prose {
font-family: 'JetBrains Mono', 'Fira Code', monospace !important;
font-size: 0.8rem !important;
line-height: 1.6 !important;
color: #64748b !important;
background: rgba(0,0,0,0.25) !important;
border-radius: 10px !important;
}
/* ── Checkboxes ── */
.checkbox-group label {
border-radius: 20px !important;
padding: 4px 12px !important;
font-size: 0.8rem !important;
transition: all 0.2s ease !important;
border: 1px solid rgba(99,115,146,0.15) !important;
color: #e2e8f0 !important;
background: rgba(17,24,39,0.75) !important;
}
.checkbox-group label span {
color: #e2e8f0 !important;
}
.checkbox-group label:hover {
border-color: rgba(99,102,241,0.35) !important;
background: rgba(30,41,59,0.9) !important;
}
.checkbox-group input:checked + label,
.checkbox-group label.selected {
background: rgba(99,102,241,0.15) !important;
border-color: rgba(99,102,241,0.4) !important;
color: #c7d2fe !important;
}
/* ── Dropdowns & inputs ── */
select, input[type="number"] {
border-radius: 10px !important;
border: 1px solid rgba(99,115,146,0.15) !important;
transition: border-color 0.25s ease !important;
font-size: 0.88rem !important;
}
select:focus, input[type="number"]:focus {
border-color: rgba(99,102,241,0.5) !important;
box-shadow: 0 0 0 2px rgba(99,102,241,0.1) !important;
}
/* ── Accordion ── */
.accordion {
border: 1px solid rgba(99,115,146,0.1) !important;
border-radius: 12px !important;
background: rgba(0,0,0,0.15) !important;
}
.accordion > .label-wrap {
font-size: 0.82rem !important;
color: #64748b !important;
font-weight: 500 !important;
}
/* ── Examples gallery ── */
.gallery-item {
border-radius: 10px !important;
border: 2px solid rgba(99,115,146,0.15) !important;
transition: all 0.25s ease !important;
overflow: hidden;
}
.gallery-item:hover {
border-color: rgba(99,102,241,0.4) !important;
transform: scale(1.04);
box-shadow: 0 4px 16px rgba(99,102,241,0.15) !important;
}
/* ── Scrollbar ── */
::-webkit-scrollbar { width: 6px; }
::-webkit-scrollbar-track { background: transparent; }
::-webkit-scrollbar-thumb {
background: rgba(99,115,146,0.25);
border-radius: 3px;
}
::-webkit-scrollbar-thumb:hover { background: rgba(99,115,146,0.4); }
/* ── Footer spacing ── */
.gradio-container > .main > .wrap:last-child { padding-bottom: 40px !important; }
"""
# Bundled sample masks, offered as gr.Examples in the UI when non-empty.
sample_paths = get_sample_masks()

# Colour overrides for the theme. Each entry is applied to both the base
# variable and its ``*_dark`` counterpart with the same value.
_THEME_OVERRIDES = {
    # Background
    "body_background_fill": "#0a0e17",
    # Panels
    "block_background_fill": "rgba(17,24,39,0.75)",
    "block_border_color": "rgba(99,115,146,0.15)",
    "block_shadow": "0 4px 24px rgba(0,0,0,0.2)",
    # Inputs
    "input_background_fill": "#131b2e",
    "input_border_color": "rgba(99,115,146,0.15)",
    # Buttons
    "button_primary_background_fill": "linear-gradient(135deg, #6366f1, #4f46e5)",
    "button_primary_text_color": "white",
    "button_primary_shadow": "0 4px 14px rgba(99,102,241,0.25)",
    # Text
    "body_text_color": "#e2e8f0",
    "body_text_color_subdued": "#94a3b8",
    # Labels
    "block_label_text_color": "#94a3b8",
    "block_title_text_color": "#cbd5e1",
    # Borders
    "border_color_primary": "rgba(99,102,241,0.4)",
}

THEME = gr.themes.Base(
    primary_hue=gr.themes.colors.indigo,
    secondary_hue=gr.themes.colors.slate,
    neutral_hue=gr.themes.colors.slate,
    font=gr.themes.GoogleFont("Inter"),
    font_mono=gr.themes.GoogleFont("JetBrains Mono"),
    radius_size=gr.themes.sizes.radius_lg,
    spacing_size=gr.themes.sizes.spacing_md,
).set(
    **{
        variable: value
        for base_name, value in _THEME_OVERRIDES.items()
        for variable in (base_name, f"{base_name}_dark")
    }
)
# ---------------------------------------------------------------------------
# Slider presets derived from the selected conditions
# ---------------------------------------------------------------------------
_CONDITION_SCALE_MAP = {
    # condition_key: (heart_delta, lung_delta), added to the neutral 1.0 scale
    "cardiomegaly": (+0.35, 0.0),
    "enlarged_cardiomediastinum": (+0.25, 0.0),
    "atelectasis": (0.0, -0.25),
    "pneumothorax": (0.0, -0.30),
    "pleural_effusion": (0.0, -0.20),
}
# Multiplier applied to every delta for the chosen severity.
_SEVERITY_MULTIPLIER = {
    "(none)": 1.0,
    "mild": 0.6,
    "moderate": 1.0,
    "severe": 1.5,
}


def reset_scales():
    """Return the neutral scale (1.0) for the heart, left- and right-lung sliders."""
    return 1.0, 1.0, 1.0


def sync_sliders(conditions: list[str], severity: str):
    """Derive slider values from the selected conditions and severity.

    Starts from 1.0 and adds each selected condition's (heart, lung) delta,
    multiplied by the severity multiplier. Unknown conditions contribute
    nothing and unknown severities use a multiplier of 1.0. Results are
    clamped to the slider range [0.0, 2.0] and rounded to two decimals.

    Returns:
        ``(heart, left_lung, right_lung)``; both lung values are identical.
    """
    heart = 1.0
    lung = 1.0
    mult = _SEVERITY_MULTIPLIER.get(severity, 1.0)
    for cond in (conditions or []):
        h_delta, l_delta = _CONDITION_SCALE_MAP.get(cond, (0.0, 0.0))
        heart += h_delta * mult
        lung += l_delta * mult
    # Clamp to slider range [0.0, 2.0]
    heart = round(max(0.0, min(2.0, heart)), 2)
    lung = round(max(0.0, min(2.0, lung)), 2)
    return heart, lung, lung


with gr.Blocks(
    title="SynthCXR · Chest X-Ray Generator",
) as demo:
    gr.Markdown(
        "# 🫁 SynthCXR\n"
        "Interactively resize anatomical masks and generate realistic chest X-rays"
    )
    with gr.Row():
        # ── Left column: Controls ──
        with gr.Column(scale=1):
            # Mask input
            gr.Markdown("### Select Mask")
            mask_input = gr.Image(
                label="Conditioning Mask",
                type="numpy",
                sources=["upload"],
                height=240,
            )
            # Sample mask gallery (selecting an example fills mask_input)
            if sample_paths:
                gr.Examples(
                    examples=sample_paths,
                    inputs=mask_input,
                    label="Sample Masks",
                )
            # Sliders
            gr.Markdown("### Mask Scaling")
            heart_slider = gr.Slider(
                minimum=0.0, maximum=2.0, step=0.05, value=1.0,
                label="💙 Heart Scale",
            )
            left_lung_slider = gr.Slider(
                minimum=0.0, maximum=2.0, step=0.05, value=1.0,
                label="🔴 Left Lung Scale",
            )
            right_lung_slider = gr.Slider(
                minimum=0.0, maximum=2.0, step=0.05, value=1.0,
                label="🟒 Right Lung Scale",
            )
            reset_btn = gr.Button("↺ Reset Scales", variant="secondary", size="sm")
            # Conditions
            gr.Markdown("### Conditions")
            conditions_select = gr.CheckboxGroup(
                choices=CONDITION_CHOICES,
                label="Pathologies",
            )
            with gr.Row():
                severity_select = gr.Radio(
                    choices=SEVERITY_CHOICES, value="(none)", label="Severity",
                )
                view_select = gr.Radio(
                    choices=["AP", "PA"], value="AP", label="View",
                )
            with gr.Row():
                age_input = gr.Number(value=45, label="Age", minimum=0, maximum=120, precision=0)
                sex_select = gr.Radio(
                    choices=["male", "female"], value="male", label="Sex",
                )
            # Advanced
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    steps_input = gr.Number(value=40, label="Steps", minimum=1, maximum=100, precision=0)
                    cfg_input = gr.Number(value=8.0, label="CFG Scale", minimum=1.0, maximum=20.0)
                with gr.Row():
                    seed_input = gr.Number(value=42, label="Seed", minimum=0, precision=0)
        # ── Right column: Outputs ──
        with gr.Column(scale=2):
            with gr.Row():
                mask_preview = gr.Image(
                    label="Scaled Mask Preview",
                    type="numpy",
                    interactive=False,
                    height=400,
                )
                cxr_output = gr.Image(
                    label="Generated Chest X-Ray",
                    type="pil",
                    interactive=False,
                    height=400,
                )
            # Prompt preview
            prompt_preview = gr.Textbox(
                label="Prompt Preview",
                interactive=False,
                lines=3,
            )
            generate_btn = gr.Button("⚑ Generate CXR", variant="primary", size="lg")

    # ── Event wiring ──
    # Live mask preview on any slider / mask change
    slider_inputs = [mask_input, heart_slider, left_lung_slider, right_lung_slider]
    for component in slider_inputs:
        component.change(preview_mask, inputs=slider_inputs, outputs=mask_preview)

    scale_sliders = [heart_slider, left_lung_slider, right_lung_slider]
    # Reset sliders
    reset_btn.click(reset_scales, outputs=scale_sliders)
    # Auto-adjust sliders when conditions or severity change
    for trigger in (conditions_select, severity_select):
        trigger.change(
            sync_sliders,
            inputs=[conditions_select, severity_select],
            outputs=scale_sliders,
        )

    # Prompt preview on config change, plus once on page load so the box
    # reflects the default settings instead of staying blank.
    prompt_inputs = [conditions_select, severity_select, age_input, sex_select, view_select]
    for inp in prompt_inputs:
        inp.change(build_prompt_preview, inputs=prompt_inputs, outputs=prompt_preview)
    demo.load(build_prompt_preview, inputs=prompt_inputs, outputs=prompt_preview)

    # Generate
    generate_btn.click(
        generate_cxr,
        inputs=[
            mask_input,
            heart_slider, left_lung_slider, right_lung_slider,
            conditions_select, severity_select,
            age_input, sex_select, view_select,
            steps_input, cfg_input, seed_input,
        ],
        outputs=cxr_output,
    )
# ---------------------------------------------------------------------------
# Launch (module-level for HuggingFace Spaces compatibility)
# ---------------------------------------------------------------------------
# NOTE(review): recent Gradio releases take ``theme``/``css`` on ``launch()``,
# while older ones expect them on ``gr.Blocks(...)`` — confirm against the
# pinned gradio version.
demo.launch(css=CUSTOM_CSS, theme=THEME)