Spaces:
Running on Zero
Running on Zero
Upload folder using huggingface_hub
Browse files- .gitignore +33 -0
- README.md +24 -7
- app.py +721 -0
- pyproject.toml +20 -0
- requirements.txt +7 -0
- scripts/models/qwen_image_edit_chexpert_lora/epoch-2.safetensors +3 -0
- scripts/models/qwen_image_edit_chexpert_lora/latest_checkpoint.json +1 -0
- static/sample_masks/sample_1.png +0 -0
- static/sample_masks/sample_2.png +0 -0
- static/sample_masks/sample_3.png +0 -0
- synthcxr/__init__.py +1 -0
- synthcxr/constants.py +54 -0
- synthcxr/mask_utils.py +93 -0
- synthcxr/pipeline.py +120 -0
- synthcxr/prompt.py +80 -0
- synthcxr/utils.py +27 -0
.gitignore
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Base model weights (downloaded at runtime from HF Hub)
|
| 2 |
+
scripts/models/Qwen/
|
| 3 |
+
|
| 4 |
+
# LoRA checkpoints — keep only epoch-2 for the demo
|
| 5 |
+
scripts/models/qwen_image_edit_chexpert_lora/epoch-0.safetensors
|
| 6 |
+
scripts/models/qwen_image_edit_chexpert_lora/epoch-1.safetensors
|
| 7 |
+
scripts/models/qwen_image_edit_chexpert_lora/epoch-3.safetensors
|
| 8 |
+
scripts/models/qwen_image_edit_chexpert_lora/epoch-4.safetensors
|
| 9 |
+
|
| 10 |
+
# Python
|
| 11 |
+
__pycache__/
|
| 12 |
+
*.py[cod]
|
| 13 |
+
*.egg-info/
|
| 14 |
+
dist/
|
| 15 |
+
build/
|
| 16 |
+
*.egg
|
| 17 |
+
|
| 18 |
+
# Environment
|
| 19 |
+
.env
|
| 20 |
+
.venv/
|
| 21 |
+
venv/
|
| 22 |
+
.cache/
|
| 23 |
+
|
| 24 |
+
# IDE
|
| 25 |
+
.vscode/
|
| 26 |
+
.idea/
|
| 27 |
+
|
| 28 |
+
# OS
|
| 29 |
+
.DS_Store
|
| 30 |
+
Thumbs.db
|
| 31 |
+
|
| 32 |
+
# Misc
|
| 33 |
+
*.log
|
README.md
CHANGED
|
@@ -1,14 +1,31 @@
|
|
| 1 |
---
|
| 2 |
title: SynthCXR
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 6.9.0
|
| 8 |
-
python_version: '3.12'
|
| 9 |
app_file: app.py
|
|
|
|
| 10 |
pinned: false
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
---
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
title: SynthCXR
|
| 3 |
+
emoji: 🫁
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: "6.9.0"
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
+
hardware: zero-a10g
|
| 10 |
pinned: false
|
| 11 |
+
tags:
|
| 12 |
+
- medical-imaging
|
| 13 |
+
- chest-x-ray
|
| 14 |
+
- diffusion
|
| 15 |
+
- lora
|
| 16 |
+
short_description: Controllable chest X-ray generation with anatomical masks
|
| 17 |
---
|
| 18 |
|
| 19 |
+
# 🫁 SynthCXR · Chest X-Ray Generator
|
| 20 |
+
|
| 21 |
+
Interactively resize anatomical mask components (heart, left lung, right lung) with sliders and generate realistic chest X-rays using a Qwen-Image-Edit model with LoRA fine-tuning on CheXpert.
|
| 22 |
+
|
| 23 |
+
> **Zero GPU** — This Space uses HuggingFace ZeroGPU for dynamic GPU allocation. A GPU is acquired only during image generation and released immediately after.
|
| 24 |
+
|
| 25 |
+
## Features
|
| 26 |
+
|
| 27 |
+
- **Mask Scaling Sliders** — Real-time preview of organ masks scaled from 0× to 2×
|
| 28 |
+
- **Condition Picker** — Select from 13 CheXpert pathologies with severity modifiers
|
| 29 |
+
- **Demographics** — Configure patient age, sex, and radiograph view (AP/PA)
|
| 30 |
+
- **CXR Generation** — Generate 512×512 chest X-rays conditioned on the modified mask
|
| 31 |
+
- **Progress Bar** — Real-time step-by-step progress during generation
|
app.py
ADDED
|
@@ -0,0 +1,721 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Gradio app for SynthCXR: interactive mask scaling and CXR generation."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import spaces
|
| 10 |
+
|
| 11 |
+
import gradio as gr
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
from PIL import Image
|
| 15 |
+
|
| 16 |
+
from synthcxr.constants import KNOWN_CONDITIONS
|
| 17 |
+
from synthcxr.mask_utils import resolve_overlaps, scale_mask_channel
|
| 18 |
+
from synthcxr.prompt import ConditionConfig, build_condition_prompt
|
| 19 |
+
|
| 20 |
+
# ---------------------------------------------------------------------------
|
| 21 |
+
# Paths
|
| 22 |
+
# ---------------------------------------------------------------------------
|
| 23 |
+
BASE_DIR = Path(__file__).resolve().parent
|
| 24 |
+
SAMPLE_MASKS_DIR = BASE_DIR / "static" / "sample_masks"
|
| 25 |
+
LORA_DIR = BASE_DIR / "scripts" / "models" / "qwen_image_edit_chexpert_lora"
|
| 26 |
+
|
| 27 |
+
# ---------------------------------------------------------------------------
|
| 28 |
+
# Condition / severity choices
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
CONDITION_CHOICES = [
|
| 31 |
+
"enlarged_cardiomediastinum",
|
| 32 |
+
"cardiomegaly",
|
| 33 |
+
"atelectasis",
|
| 34 |
+
"pneumothorax",
|
| 35 |
+
"pleural_effusion",
|
| 36 |
+
]
|
| 37 |
+
SEVERITY_CHOICES = ["(none)", "mild", "moderate", "severe"]
|
| 38 |
+
|
| 39 |
+
# ---------------------------------------------------------------------------
|
| 40 |
+
# Pipeline (lazy-loaded once)
|
| 41 |
+
# ---------------------------------------------------------------------------
|
| 42 |
+
_pipe = None
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_pipeline():
|
| 46 |
+
"""Load the diffusion pipeline + LoRA weights into GPU memory (once)."""
|
| 47 |
+
global _pipe
|
| 48 |
+
if _pipe is not None:
|
| 49 |
+
return _pipe
|
| 50 |
+
|
| 51 |
+
from synthcxr.pipeline import load_lora_weights, load_pipeline
|
| 52 |
+
|
| 53 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 54 |
+
dtype = torch.bfloat16
|
| 55 |
+
|
| 56 |
+
# VRAM_LIMIT (in GB): enables model offloading for memory-constrained GPUs
|
| 57 |
+
vram_limit_str = os.environ.get("VRAM_LIMIT", "")
|
| 58 |
+
vram_limit = float(vram_limit_str) if vram_limit_str else None
|
| 59 |
+
|
| 60 |
+
print(f"[INFO] Loading QwenImagePipeline (device={device}, dtype={dtype}, vram_limit={vram_limit}) …")
|
| 61 |
+
_pipe = load_pipeline(device, dtype, vram_limit=vram_limit)
|
| 62 |
+
|
| 63 |
+
# LORA_EPOCH env var: which epoch checkpoint to load (default: 2)
|
| 64 |
+
lora_epoch = os.environ.get("LORA_EPOCH", "2")
|
| 65 |
+
lora = LORA_DIR / f"epoch-{lora_epoch}.safetensors"
|
| 66 |
+
|
| 67 |
+
if not lora.exists():
|
| 68 |
+
# Try step-based checkpoints or any available .safetensors
|
| 69 |
+
candidates = sorted(LORA_DIR.glob("*.safetensors")) if LORA_DIR.exists() else []
|
| 70 |
+
if candidates:
|
| 71 |
+
lora = candidates[-1]
|
| 72 |
+
print(f"[WARN] epoch-{lora_epoch} not found, falling back to {lora.name}")
|
| 73 |
+
else:
|
| 74 |
+
print("[WARN] No LoRA checkpoint found – running base model only.")
|
| 75 |
+
return _pipe
|
| 76 |
+
|
| 77 |
+
print(f"[INFO] Loading LoRA from {lora}")
|
| 78 |
+
load_lora_weights(_pipe, lora)
|
| 79 |
+
print("[INFO] Pipeline ready.")
|
| 80 |
+
return _pipe
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# ---------------------------------------------------------------------------
|
| 84 |
+
# Sample masks
|
| 85 |
+
# ---------------------------------------------------------------------------
|
| 86 |
+
def get_sample_masks() -> list[str]:
|
| 87 |
+
"""Return paths of bundled sample masks."""
|
| 88 |
+
if not SAMPLE_MASKS_DIR.exists():
|
| 89 |
+
return []
|
| 90 |
+
return sorted(str(p) for p in SAMPLE_MASKS_DIR.glob("*.png"))
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# ---------------------------------------------------------------------------
|
| 94 |
+
# Core functions
|
| 95 |
+
# ---------------------------------------------------------------------------
|
| 96 |
+
|
| 97 |
+
def apply_mask_scaling(
|
| 98 |
+
mask_array: np.ndarray,
|
| 99 |
+
heart_scale: float,
|
| 100 |
+
left_lung_scale: float,
|
| 101 |
+
right_lung_scale: float,
|
| 102 |
+
) -> np.ndarray:
|
| 103 |
+
"""Scale mask channels and resolve overlaps."""
|
| 104 |
+
if heart_scale != 1.0:
|
| 105 |
+
mask_array = scale_mask_channel(mask_array, channel=2, scale_factor=heart_scale)
|
| 106 |
+
if left_lung_scale != 1.0:
|
| 107 |
+
mask_array = scale_mask_channel(mask_array, channel=0, scale_factor=left_lung_scale)
|
| 108 |
+
if right_lung_scale != 1.0:
|
| 109 |
+
mask_array = scale_mask_channel(mask_array, channel=1, scale_factor=right_lung_scale)
|
| 110 |
+
return resolve_overlaps(mask_array, priority=(2, 0, 1))
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def preview_mask(
|
| 114 |
+
mask_image: np.ndarray | None,
|
| 115 |
+
heart_scale: float,
|
| 116 |
+
left_lung_scale: float,
|
| 117 |
+
right_lung_scale: float,
|
| 118 |
+
) -> np.ndarray | None:
|
| 119 |
+
"""Live mask preview callback."""
|
| 120 |
+
if mask_image is None:
|
| 121 |
+
return None
|
| 122 |
+
mask = np.array(Image.fromarray(mask_image).convert("RGB"))
|
| 123 |
+
scaled = apply_mask_scaling(mask, heart_scale, left_lung_scale, right_lung_scale)
|
| 124 |
+
return scaled
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def build_prompt_preview(
|
| 128 |
+
conditions: list[str],
|
| 129 |
+
severity: str,
|
| 130 |
+
age: int,
|
| 131 |
+
sex: str,
|
| 132 |
+
view: str,
|
| 133 |
+
) -> str:
|
| 134 |
+
"""Build the prompt text for preview."""
|
| 135 |
+
cond = ConditionConfig(
|
| 136 |
+
name="preview",
|
| 137 |
+
conditions=conditions or [],
|
| 138 |
+
age=age,
|
| 139 |
+
sex=sex,
|
| 140 |
+
view=view,
|
| 141 |
+
severity=severity if severity != "(none)" else None,
|
| 142 |
+
)
|
| 143 |
+
return build_condition_prompt(cond)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@spaces.GPU(duration=120)
|
| 147 |
+
def generate_cxr(
|
| 148 |
+
mask_image: np.ndarray | None,
|
| 149 |
+
heart_scale: float,
|
| 150 |
+
left_lung_scale: float,
|
| 151 |
+
right_lung_scale: float,
|
| 152 |
+
conditions: list[str],
|
| 153 |
+
severity: str,
|
| 154 |
+
age: int,
|
| 155 |
+
sex: str,
|
| 156 |
+
view: str,
|
| 157 |
+
num_steps: int,
|
| 158 |
+
cfg_scale: float,
|
| 159 |
+
seed: int,
|
| 160 |
+
preview_every: int = 10,
|
| 161 |
+
progress=gr.Progress(),
|
| 162 |
+
):
|
| 163 |
+
"""Generate a CXR, yielding intermediate previews every N steps."""
|
| 164 |
+
if mask_image is None:
|
| 165 |
+
raise gr.Error("Please select or upload a mask first.")
|
| 166 |
+
|
| 167 |
+
pipe = get_pipeline()
|
| 168 |
+
if pipe is None:
|
| 169 |
+
raise gr.Error("Pipeline not loaded. GPU may be unavailable.")
|
| 170 |
+
|
| 171 |
+
# Prepare mask
|
| 172 |
+
mask = np.array(Image.fromarray(mask_image).convert("RGB"))
|
| 173 |
+
scaled = apply_mask_scaling(mask, heart_scale, left_lung_scale, right_lung_scale)
|
| 174 |
+
edit_image = Image.fromarray(scaled)
|
| 175 |
+
|
| 176 |
+
# Build prompt
|
| 177 |
+
cond = ConditionConfig(
|
| 178 |
+
name="web_ui",
|
| 179 |
+
conditions=conditions or [],
|
| 180 |
+
age=age,
|
| 181 |
+
sex=sex,
|
| 182 |
+
view=view,
|
| 183 |
+
severity=severity if severity != "(none)" else None,
|
| 184 |
+
)
|
| 185 |
+
prompt = build_condition_prompt(cond)
|
| 186 |
+
|
| 187 |
+
# Intermediate preview collector
|
| 188 |
+
previews: list[Image.Image] = []
|
| 189 |
+
|
| 190 |
+
class StepCallback:
|
| 191 |
+
"""Custom tqdm-like wrapper that decodes latents every N steps."""
|
| 192 |
+
def __init__(self, iterable):
|
| 193 |
+
self._iterable = iterable
|
| 194 |
+
self._step = 0
|
| 195 |
+
|
| 196 |
+
def __iter__(self):
|
| 197 |
+
for item in self._iterable:
|
| 198 |
+
progress(self._step / num_steps, desc="Generating CXR...")
|
| 199 |
+
yield item
|
| 200 |
+
self._step += 1
|
| 201 |
+
if (
|
| 202 |
+
preview_every > 0
|
| 203 |
+
and self._step % preview_every == 0
|
| 204 |
+
and self._step < num_steps
|
| 205 |
+
and "latents" in _shared_ref
|
| 206 |
+
):
|
| 207 |
+
try:
|
| 208 |
+
with torch.no_grad():
|
| 209 |
+
latents = _shared_ref["latents"]
|
| 210 |
+
decoded = pipe.vae.decode(
|
| 211 |
+
latents,
|
| 212 |
+
device=pipe.device,
|
| 213 |
+
tiled=False,
|
| 214 |
+
)
|
| 215 |
+
img = pipe.vae_output_to_image(decoded)
|
| 216 |
+
previews.append(img)
|
| 217 |
+
except Exception:
|
| 218 |
+
pass # skip preview on error
|
| 219 |
+
|
| 220 |
+
def __len__(self):
|
| 221 |
+
return len(self._iterable)
|
| 222 |
+
|
| 223 |
+
# We patch the pipeline's __call__ to capture inputs_shared reference.
|
| 224 |
+
# The pipeline stores latents in inputs_shared["latents"] during denoising.
|
| 225 |
+
_shared_ref: dict = {}
|
| 226 |
+
_orig_unit_runner = pipe.unit_runner.__class__.__call__
|
| 227 |
+
|
| 228 |
+
def _patched_runner(self_runner, unit, p, inputs_shared, inputs_posi, inputs_nega):
|
| 229 |
+
_shared_ref.update(inputs_shared)
|
| 230 |
+
return _orig_unit_runner(self_runner, unit, p, inputs_shared, inputs_posi, inputs_nega)
|
| 231 |
+
|
| 232 |
+
pipe.unit_runner.__class__.__call__ = _patched_runner
|
| 233 |
+
|
| 234 |
+
try:
|
| 235 |
+
image = pipe(
|
| 236 |
+
prompt=prompt,
|
| 237 |
+
edit_image=edit_image,
|
| 238 |
+
height=512,
|
| 239 |
+
width=512,
|
| 240 |
+
num_inference_steps=num_steps,
|
| 241 |
+
seed=seed,
|
| 242 |
+
rand_device=pipe.device,
|
| 243 |
+
cfg_scale=cfg_scale,
|
| 244 |
+
edit_image_auto_resize=True,
|
| 245 |
+
zero_cond_t=True,
|
| 246 |
+
progress_bar_cmd=StepCallback,
|
| 247 |
+
)
|
| 248 |
+
finally:
|
| 249 |
+
# Restore original runner
|
| 250 |
+
pipe.unit_runner.__class__.__call__ = _orig_unit_runner
|
| 251 |
+
|
| 252 |
+
# Yield all collected previews, then the final image
|
| 253 |
+
for preview in previews:
|
| 254 |
+
yield preview
|
| 255 |
+
yield image
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
# ---------------------------------------------------------------------------
|
| 259 |
+
# Gradio UI
|
| 260 |
+
# ---------------------------------------------------------------------------
|
| 261 |
+
|
| 262 |
+
CUSTOM_CSS = """
|
| 263 |
+
/* ── Layout ── */
|
| 264 |
+
.gradio-container {
|
| 265 |
+
max-width: 1280px !important;
|
| 266 |
+
margin: 0 auto !important;
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
/* ── Radial gradient background ── */
|
| 270 |
+
.main {
|
| 271 |
+
background:
|
| 272 |
+
radial-gradient(ellipse 80% 50% at 10% 20%, rgba(99,102,241,0.07), transparent),
|
| 273 |
+
radial-gradient(ellipse 60% 40% at 85% 75%, rgba(59,130,246,0.05), transparent) !important;
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
/* ── Header ── */
|
| 277 |
+
#component-0 h1 {
|
| 278 |
+
text-align: center;
|
| 279 |
+
font-size: 2.2rem !important;
|
| 280 |
+
font-weight: 800 !important;
|
| 281 |
+
letter-spacing: -0.5px;
|
| 282 |
+
background: linear-gradient(135deg, #818cf8, #60a5fa, #818cf8);
|
| 283 |
+
background-size: 200% 200%;
|
| 284 |
+
-webkit-background-clip: text;
|
| 285 |
+
-webkit-text-fill-color: transparent;
|
| 286 |
+
background-clip: text;
|
| 287 |
+
animation: gradientShift 4s ease-in-out infinite;
|
| 288 |
+
padding-bottom: 4px !important;
|
| 289 |
+
}
|
| 290 |
+
#component-0 p {
|
| 291 |
+
text-align: center;
|
| 292 |
+
color: #94a3b8 !important;
|
| 293 |
+
font-size: 0.95rem;
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
@keyframes gradientShift {
|
| 297 |
+
0%, 100% { background-position: 0% 50%; }
|
| 298 |
+
50% { background-position: 100% 50%; }
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
/* ── Glass panels ── */
|
| 302 |
+
.block {
|
| 303 |
+
border: 1px solid rgba(99,115,146,0.15) !important;
|
| 304 |
+
border-radius: 16px !important;
|
| 305 |
+
backdrop-filter: blur(12px);
|
| 306 |
+
transition: border-color 0.3s ease, box-shadow 0.3s ease !important;
|
| 307 |
+
}
|
| 308 |
+
.block:hover {
|
| 309 |
+
border-color: rgba(99,102,241,0.25) !important;
|
| 310 |
+
box-shadow: 0 0 20px rgba(99,102,241,0.06) !important;
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
/* ── Section headings ── */
|
| 314 |
+
.markdown h3 {
|
| 315 |
+
font-size: 0.78rem !important;
|
| 316 |
+
font-weight: 700 !important;
|
| 317 |
+
text-transform: uppercase;
|
| 318 |
+
letter-spacing: 1.2px;
|
| 319 |
+
color: #64748b !important;
|
| 320 |
+
border-bottom: 1px solid rgba(99,115,146,0.12);
|
| 321 |
+
padding-bottom: 8px !important;
|
| 322 |
+
margin-bottom: 12px !important;
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
/* ── Slider styling ── */
|
| 326 |
+
input[type="range"] {
|
| 327 |
+
height: 6px !important;
|
| 328 |
+
border-radius: 3px !important;
|
| 329 |
+
background: #1e293b !important;
|
| 330 |
+
}
|
| 331 |
+
input[type="range"]::-webkit-slider-thumb {
|
| 332 |
+
width: 18px !important;
|
| 333 |
+
height: 18px !important;
|
| 334 |
+
border-radius: 50% !important;
|
| 335 |
+
border: 2.5px solid #0a0e17 !important;
|
| 336 |
+
transition: transform 0.2s ease, box-shadow 0.2s ease !important;
|
| 337 |
+
}
|
| 338 |
+
input[type="range"]::-webkit-slider-thumb:hover {
|
| 339 |
+
transform: scale(1.2) !important;
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
/* Slider labels */
|
| 343 |
+
.block label span {
|
| 344 |
+
font-weight: 500 !important;
|
| 345 |
+
font-size: 0.88rem !important;
|
| 346 |
+
}
|
| 347 |
+
.block .rangeSlider_value {
|
| 348 |
+
font-variant-numeric: tabular-nums;
|
| 349 |
+
font-weight: 600 !important;
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
/* ── Image panels ── */
|
| 353 |
+
.image-frame img, .image-container img {
|
| 354 |
+
border-radius: 10px !important;
|
| 355 |
+
transition: opacity 0.3s ease !important;
|
| 356 |
+
}
|
| 357 |
+
.image-container {
|
| 358 |
+
background: rgba(0,0,0,0.2) !important;
|
| 359 |
+
border-radius: 12px !important;
|
| 360 |
+
min-height: 380px;
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
/* ── Generate button ── */
|
| 364 |
+
.primary {
|
| 365 |
+
background: linear-gradient(135deg, #6366f1, #4f46e5, #6366f1) !important;
|
| 366 |
+
background-size: 200% 200% !important;
|
| 367 |
+
border: none !important;
|
| 368 |
+
border-radius: 12px !important;
|
| 369 |
+
padding: 14px 24px !important;
|
| 370 |
+
font-weight: 700 !important;
|
| 371 |
+
font-size: 1rem !important;
|
| 372 |
+
letter-spacing: 0.3px;
|
| 373 |
+
transition: all 0.3s cubic-bezier(0.4,0,0.2,1) !important;
|
| 374 |
+
position: relative;
|
| 375 |
+
overflow: hidden;
|
| 376 |
+
}
|
| 377 |
+
.primary:hover {
|
| 378 |
+
transform: translateY(-2px) !important;
|
| 379 |
+
box-shadow: 0 8px 25px rgba(99,102,241,0.4) !important;
|
| 380 |
+
animation: btnShimmer 1.5s ease-in-out infinite !important;
|
| 381 |
+
}
|
| 382 |
+
.primary:active {
|
| 383 |
+
transform: translateY(0) !important;
|
| 384 |
+
}
|
| 385 |
+
@keyframes btnShimmer {
|
| 386 |
+
0%, 100% { background-position: 0% 50%; }
|
| 387 |
+
50% { background-position: 100% 50%; }
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
/* ── Secondary buttons ── */
|
| 391 |
+
.secondary {
|
| 392 |
+
border: 1px solid rgba(99,115,146,0.2) !important;
|
| 393 |
+
border-radius: 10px !important;
|
| 394 |
+
background: transparent !important;
|
| 395 |
+
color: #94a3b8 !important;
|
| 396 |
+
transition: all 0.25s ease !important;
|
| 397 |
+
}
|
| 398 |
+
.secondary:hover {
|
| 399 |
+
border-color: rgba(99,102,241,0.4) !important;
|
| 400 |
+
color: #e2e8f0 !important;
|
| 401 |
+
background: rgba(99,102,241,0.06) !important;
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
/* ── Prompt preview ── */
|
| 405 |
+
textarea[readonly], .prose {
|
| 406 |
+
font-family: 'JetBrains Mono', 'Fira Code', monospace !important;
|
| 407 |
+
font-size: 0.8rem !important;
|
| 408 |
+
line-height: 1.6 !important;
|
| 409 |
+
color: #64748b !important;
|
| 410 |
+
background: rgba(0,0,0,0.25) !important;
|
| 411 |
+
border-radius: 10px !important;
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
/* ── Checkboxes ── */
|
| 415 |
+
.checkbox-group label {
|
| 416 |
+
border-radius: 20px !important;
|
| 417 |
+
padding: 4px 12px !important;
|
| 418 |
+
font-size: 0.8rem !important;
|
| 419 |
+
transition: all 0.2s ease !important;
|
| 420 |
+
border: 1px solid rgba(99,115,146,0.15) !important;
|
| 421 |
+
color: #e2e8f0 !important;
|
| 422 |
+
background: rgba(17,24,39,0.75) !important;
|
| 423 |
+
}
|
| 424 |
+
.checkbox-group label span {
|
| 425 |
+
color: #e2e8f0 !important;
|
| 426 |
+
}
|
| 427 |
+
.checkbox-group label:hover {
|
| 428 |
+
border-color: rgba(99,102,241,0.35) !important;
|
| 429 |
+
background: rgba(30,41,59,0.9) !important;
|
| 430 |
+
}
|
| 431 |
+
.checkbox-group input:checked + label,
|
| 432 |
+
.checkbox-group label.selected {
|
| 433 |
+
background: rgba(99,102,241,0.15) !important;
|
| 434 |
+
border-color: rgba(99,102,241,0.4) !important;
|
| 435 |
+
color: #c7d2fe !important;
|
| 436 |
+
}
|
| 437 |
+
|
| 438 |
+
/* ── Dropdowns & inputs ── */
|
| 439 |
+
select, input[type="number"] {
|
| 440 |
+
border-radius: 10px !important;
|
| 441 |
+
border: 1px solid rgba(99,115,146,0.15) !important;
|
| 442 |
+
transition: border-color 0.25s ease !important;
|
| 443 |
+
font-size: 0.88rem !important;
|
| 444 |
+
}
|
| 445 |
+
select:focus, input[type="number"]:focus {
|
| 446 |
+
border-color: rgba(99,102,241,0.5) !important;
|
| 447 |
+
box-shadow: 0 0 0 2px rgba(99,102,241,0.1) !important;
|
| 448 |
+
}
|
| 449 |
+
|
| 450 |
+
/* ── Accordion ── */
|
| 451 |
+
.accordion {
|
| 452 |
+
border: 1px solid rgba(99,115,146,0.1) !important;
|
| 453 |
+
border-radius: 12px !important;
|
| 454 |
+
background: rgba(0,0,0,0.15) !important;
|
| 455 |
+
}
|
| 456 |
+
.accordion > .label-wrap {
|
| 457 |
+
font-size: 0.82rem !important;
|
| 458 |
+
color: #64748b !important;
|
| 459 |
+
font-weight: 500 !important;
|
| 460 |
+
}
|
| 461 |
+
|
| 462 |
+
/* ── Examples gallery ── */
|
| 463 |
+
.gallery-item {
|
| 464 |
+
border-radius: 10px !important;
|
| 465 |
+
border: 2px solid rgba(99,115,146,0.15) !important;
|
| 466 |
+
transition: all 0.25s ease !important;
|
| 467 |
+
overflow: hidden;
|
| 468 |
+
}
|
| 469 |
+
.gallery-item:hover {
|
| 470 |
+
border-color: rgba(99,102,241,0.4) !important;
|
| 471 |
+
transform: scale(1.04);
|
| 472 |
+
box-shadow: 0 4px 16px rgba(99,102,241,0.15) !important;
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
/* ── Scrollbar ── */
|
| 476 |
+
::-webkit-scrollbar { width: 6px; }
|
| 477 |
+
::-webkit-scrollbar-track { background: transparent; }
|
| 478 |
+
::-webkit-scrollbar-thumb {
|
| 479 |
+
background: rgba(99,115,146,0.25);
|
| 480 |
+
border-radius: 3px;
|
| 481 |
+
}
|
| 482 |
+
::-webkit-scrollbar-thumb:hover { background: rgba(99,115,146,0.4); }
|
| 483 |
+
|
| 484 |
+
/* ── Footer spacing ── */
|
| 485 |
+
.gradio-container > .main > .wrap:last-child { padding-bottom: 40px !important; }
|
| 486 |
+
"""
|
| 487 |
+
|
| 488 |
+
sample_paths = get_sample_masks()
|
| 489 |
+
|
| 490 |
+
THEME = gr.themes.Base(
|
| 491 |
+
primary_hue=gr.themes.colors.indigo,
|
| 492 |
+
secondary_hue=gr.themes.colors.slate,
|
| 493 |
+
neutral_hue=gr.themes.colors.slate,
|
| 494 |
+
font=gr.themes.GoogleFont("Inter"),
|
| 495 |
+
font_mono=gr.themes.GoogleFont("JetBrains Mono"),
|
| 496 |
+
radius_size=gr.themes.sizes.radius_lg,
|
| 497 |
+
spacing_size=gr.themes.sizes.spacing_md,
|
| 498 |
+
).set(
|
| 499 |
+
# Background
|
| 500 |
+
body_background_fill="#0a0e17",
|
| 501 |
+
body_background_fill_dark="#0a0e17",
|
| 502 |
+
# Panels
|
| 503 |
+
block_background_fill="rgba(17,24,39,0.75)",
|
| 504 |
+
block_background_fill_dark="rgba(17,24,39,0.75)",
|
| 505 |
+
block_border_color="rgba(99,115,146,0.15)",
|
| 506 |
+
block_border_color_dark="rgba(99,115,146,0.15)",
|
| 507 |
+
block_shadow="0 4px 24px rgba(0,0,0,0.2)",
|
| 508 |
+
block_shadow_dark="0 4px 24px rgba(0,0,0,0.2)",
|
| 509 |
+
# Inputs
|
| 510 |
+
input_background_fill="#131b2e",
|
| 511 |
+
input_background_fill_dark="#131b2e",
|
| 512 |
+
input_border_color="rgba(99,115,146,0.15)",
|
| 513 |
+
input_border_color_dark="rgba(99,115,146,0.15)",
|
| 514 |
+
# Buttons
|
| 515 |
+
button_primary_background_fill="linear-gradient(135deg, #6366f1, #4f46e5)",
|
| 516 |
+
button_primary_background_fill_dark="linear-gradient(135deg, #6366f1, #4f46e5)",
|
| 517 |
+
button_primary_text_color="white",
|
| 518 |
+
button_primary_text_color_dark="white",
|
| 519 |
+
button_primary_shadow="0 4px 14px rgba(99,102,241,0.25)",
|
| 520 |
+
button_primary_shadow_dark="0 4px 14px rgba(99,102,241,0.25)",
|
| 521 |
+
# Text
|
| 522 |
+
body_text_color="#e2e8f0",
|
| 523 |
+
body_text_color_dark="#e2e8f0",
|
| 524 |
+
body_text_color_subdued="#94a3b8",
|
| 525 |
+
body_text_color_subdued_dark="#94a3b8",
|
| 526 |
+
# Labels
|
| 527 |
+
block_label_text_color="#94a3b8",
|
| 528 |
+
block_label_text_color_dark="#94a3b8",
|
| 529 |
+
block_title_text_color="#cbd5e1",
|
| 530 |
+
block_title_text_color_dark="#cbd5e1",
|
| 531 |
+
# Borders
|
| 532 |
+
border_color_primary="rgba(99,102,241,0.4)",
|
| 533 |
+
border_color_primary_dark="rgba(99,102,241,0.4)",
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
with gr.Blocks(
|
| 537 |
+
title="SynthCXR · Chest X-Ray Generator",
|
| 538 |
+
) as demo:
|
| 539 |
+
|
| 540 |
+
gr.Markdown(
|
| 541 |
+
"# 🫁 SynthCXR\n"
|
| 542 |
+
"Interactively resize anatomical masks and generate realistic chest X-rays"
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
with gr.Row():
|
| 546 |
+
|
| 547 |
+
# ── Left column: Controls ──
|
| 548 |
+
with gr.Column(scale=1):
|
| 549 |
+
|
| 550 |
+
# Mask input
|
| 551 |
+
gr.Markdown("### Select Mask")
|
| 552 |
+
mask_input = gr.Image(
|
| 553 |
+
label="Conditioning Mask",
|
| 554 |
+
type="numpy",
|
| 555 |
+
sources=["upload"],
|
| 556 |
+
height=240,
|
| 557 |
+
)
|
| 558 |
+
|
| 559 |
+
# Sample mask gallery
|
| 560 |
+
if sample_paths:
|
| 561 |
+
sample_gallery = gr.Examples(
|
| 562 |
+
examples=sample_paths,
|
| 563 |
+
inputs=mask_input,
|
| 564 |
+
label="Sample Masks",
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
# Sliders
|
| 568 |
+
gr.Markdown("### Mask Scaling")
|
| 569 |
+
heart_slider = gr.Slider(
|
| 570 |
+
minimum=0.0, maximum=2.0, step=0.05, value=1.0,
|
| 571 |
+
label="💙 Heart Scale",
|
| 572 |
+
)
|
| 573 |
+
left_lung_slider = gr.Slider(
|
| 574 |
+
minimum=0.0, maximum=2.0, step=0.05, value=1.0,
|
| 575 |
+
label="🔴 Left Lung Scale",
|
| 576 |
+
)
|
| 577 |
+
right_lung_slider = gr.Slider(
|
| 578 |
+
minimum=0.0, maximum=2.0, step=0.05, value=1.0,
|
| 579 |
+
label="🟢 Right Lung Scale",
|
| 580 |
+
)
|
| 581 |
+
reset_btn = gr.Button("↺ Reset Scales", variant="secondary", size="sm")
|
| 582 |
+
|
| 583 |
+
# Conditions
|
| 584 |
+
gr.Markdown("### Conditions")
|
| 585 |
+
conditions_select = gr.CheckboxGroup(
|
| 586 |
+
choices=CONDITION_CHOICES,
|
| 587 |
+
label="Pathologies",
|
| 588 |
+
)
|
| 589 |
+
with gr.Row():
|
| 590 |
+
severity_select = gr.Radio(
|
| 591 |
+
choices=SEVERITY_CHOICES, value="(none)", label="Severity",
|
| 592 |
+
)
|
| 593 |
+
view_select = gr.Radio(
|
| 594 |
+
choices=["AP", "PA"], value="AP", label="View",
|
| 595 |
+
)
|
| 596 |
+
with gr.Row():
|
| 597 |
+
age_input = gr.Number(value=45, label="Age", minimum=0, maximum=120, precision=0)
|
| 598 |
+
sex_select = gr.Radio(
|
| 599 |
+
choices=["male", "female"], value="male", label="Sex",
|
| 600 |
+
)
|
| 601 |
+
|
| 602 |
+
# Advanced
|
| 603 |
+
with gr.Accordion("Advanced Settings", open=False):
|
| 604 |
+
with gr.Row():
|
| 605 |
+
steps_input = gr.Number(value=50, label="Steps", minimum=1, maximum=100, precision=0)
|
| 606 |
+
cfg_input = gr.Number(value=4.0, label="CFG Scale", minimum=1.0, maximum=20.0)
|
| 607 |
+
with gr.Row():
|
| 608 |
+
seed_input = gr.Number(value=42, label="Seed", minimum=0, precision=0)
|
| 609 |
+
preview_every_input = gr.Number(value=10, label="Preview Every N Steps", minimum=0, maximum=50, precision=0)
|
| 610 |
+
|
| 611 |
+
# ── Right column: Outputs ──
|
| 612 |
+
with gr.Column(scale=2):
|
| 613 |
+
|
| 614 |
+
with gr.Row():
|
| 615 |
+
mask_preview = gr.Image(
|
| 616 |
+
label="Scaled Mask Preview",
|
| 617 |
+
type="numpy",
|
| 618 |
+
interactive=False,
|
| 619 |
+
height=400,
|
| 620 |
+
)
|
| 621 |
+
cxr_output = gr.Image(
|
| 622 |
+
label="Generated Chest X-Ray",
|
| 623 |
+
type="pil",
|
| 624 |
+
interactive=False,
|
| 625 |
+
height=400,
|
| 626 |
+
)
|
| 627 |
+
|
| 628 |
+
# Prompt preview
|
| 629 |
+
prompt_preview = gr.Textbox(
|
| 630 |
+
label="Prompt Preview",
|
| 631 |
+
interactive=False,
|
| 632 |
+
lines=3,
|
| 633 |
+
)
|
| 634 |
+
|
| 635 |
+
generate_btn = gr.Button("⚡ Generate CXR", variant="primary", size="lg")
|
| 636 |
+
|
| 637 |
+
# ── Event wiring ──
|
| 638 |
+
|
| 639 |
+
# Live mask preview on any slider / mask change
|
| 640 |
+
slider_inputs = [mask_input, heart_slider, left_lung_slider, right_lung_slider]
|
| 641 |
+
|
| 642 |
+
mask_input.change(preview_mask, inputs=slider_inputs, outputs=mask_preview)
|
| 643 |
+
heart_slider.change(preview_mask, inputs=slider_inputs, outputs=mask_preview)
|
| 644 |
+
left_lung_slider.change(preview_mask, inputs=slider_inputs, outputs=mask_preview)
|
| 645 |
+
right_lung_slider.change(preview_mask, inputs=slider_inputs, outputs=mask_preview)
|
| 646 |
+
|
| 647 |
+
# Reset sliders
|
| 648 |
+
def reset_scales():
|
| 649 |
+
return 1.0, 1.0, 1.0
|
| 650 |
+
|
| 651 |
+
reset_btn.click(
|
| 652 |
+
reset_scales,
|
| 653 |
+
outputs=[heart_slider, left_lung_slider, right_lung_slider],
|
| 654 |
+
)
|
| 655 |
+
|
| 656 |
+
# Auto-adjust sliders when conditions change
|
| 657 |
+
_CONDITION_SCALE_MAP = {
|
| 658 |
+
# condition_key: (heart_delta, lung_delta)
|
| 659 |
+
"cardiomegaly": (+0.35, 0.0),
|
| 660 |
+
"enlarged_cardiomediastinum": (+0.25, 0.0),
|
| 661 |
+
"atelectasis": (0.0, -0.25),
|
| 662 |
+
"pneumothorax": (0.0, -0.30),
|
| 663 |
+
"pleural_effusion": (0.0, -0.20),
|
| 664 |
+
}
|
| 665 |
+
_SEVERITY_MULTIPLIER = {
|
| 666 |
+
"(none)": 1.0,
|
| 667 |
+
"mild": 0.6,
|
| 668 |
+
"moderate": 1.0,
|
| 669 |
+
"severe": 1.5,
|
| 670 |
+
}
|
| 671 |
+
|
| 672 |
+
def sync_sliders(conditions: list[str], severity: str):
    """Derive (heart, left-lung, right-lung) slider values from the UI state.

    Each selected condition contributes a fixed (heart, lung) delta from
    ``_CONDITION_SCALE_MAP``, scaled by the severity multiplier; results are
    clamped to the slider range [0.0, 2.0]. Both lungs share one value.
    """
    multiplier = _SEVERITY_MULTIPLIER.get(severity, 1.0)
    heart_value, lung_value = 1.0, 1.0
    for condition in conditions or []:
        heart_delta, lung_delta = _CONDITION_SCALE_MAP.get(condition, (0.0, 0.0))
        heart_value += heart_delta * multiplier
        lung_value += lung_delta * multiplier

    def clamp(value: float) -> float:
        # Keep within the slider range, rounded for display.
        return round(min(2.0, max(0.0, value)), 2)

    return clamp(heart_value), clamp(lung_value), clamp(lung_value)
|
| 685 |
+
|
| 686 |
+
conditions_select.change(
|
| 687 |
+
sync_sliders,
|
| 688 |
+
inputs=[conditions_select, severity_select],
|
| 689 |
+
outputs=[heart_slider, left_lung_slider, right_lung_slider],
|
| 690 |
+
)
|
| 691 |
+
severity_select.change(
|
| 692 |
+
sync_sliders,
|
| 693 |
+
inputs=[conditions_select, severity_select],
|
| 694 |
+
outputs=[heart_slider, left_lung_slider, right_lung_slider],
|
| 695 |
+
)
|
| 696 |
+
|
| 697 |
+
# Prompt preview on config change
|
| 698 |
+
prompt_inputs = [conditions_select, severity_select, age_input, sex_select, view_select]
|
| 699 |
+
|
| 700 |
+
for inp in prompt_inputs:
|
| 701 |
+
inp.change(build_prompt_preview, inputs=prompt_inputs, outputs=prompt_preview)
|
| 702 |
+
|
| 703 |
+
# Generate
|
| 704 |
+
generate_btn.click(
|
| 705 |
+
generate_cxr,
|
| 706 |
+
inputs=[
|
| 707 |
+
mask_input,
|
| 708 |
+
heart_slider, left_lung_slider, right_lung_slider,
|
| 709 |
+
conditions_select, severity_select,
|
| 710 |
+
age_input, sex_select, view_select,
|
| 711 |
+
steps_input, cfg_input, seed_input,
|
| 712 |
+
preview_every_input,
|
| 713 |
+
],
|
| 714 |
+
outputs=cxr_output,
|
| 715 |
+
)
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
# ---------------------------------------------------------------------------
|
| 719 |
+
# Launch (module-level for HuggingFace Spaces compatibility)
|
| 720 |
+
# ---------------------------------------------------------------------------
|
| 721 |
+
demo.launch(theme=THEME, css=CUSTOM_CSS)
|
pyproject.toml
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "synthcxr"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Chest X-ray generation via Qwen-Image-Edit LoRA fine-tuning"
|
| 5 |
+
requires-python = ">=3.10.1"
|
| 6 |
+
dependencies = [
|
| 7 |
+
"diffsynth>=2.0.4",
|
| 8 |
+
"fastapi[standard]>=0.135.1",
|
| 9 |
+
"gradio>=6.8.0",
|
| 10 |
+
"python-multipart>=0.0.22",
|
| 11 |
+
"scipy",
|
| 12 |
+
"uvicorn[standard]>=0.41.0",
|
| 13 |
+
]
|
| 14 |
+
|
| 15 |
+
[build-system]
|
| 16 |
+
requires = ["setuptools>=68"]
|
| 17 |
+
build-backend = "setuptools.build_meta"
|
| 18 |
+
|
| 19 |
+
[tool.setuptools.packages.find]
|
| 20 |
+
where = ["src"]
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=4.0
|
| 2 |
+
diffsynth>=2.0.4
|
| 3 |
+
spaces
|
| 4 |
+
scipy
|
| 5 |
+
Pillow
|
| 6 |
+
numpy
|
| 7 |
+
torch
|
scripts/models/qwen_image_edit_chexpert_lora/epoch-2.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fef90b53ae95c9628efe14b0919f7be7e291ec9f80677a3f2ed509ebccca1c05
|
| 3 |
+
size 472047184
|
scripts/models/qwen_image_edit_chexpert_lora/latest_checkpoint.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"path": "./models/qwen_image_edit_chexpert_lora/checkpoint-step233240", "epoch_id": 4, "global_step": 233240}
|
static/sample_masks/sample_1.png
ADDED
|
static/sample_masks/sample_2.png
ADDED
|
static/sample_masks/sample_3.png
ADDED
|
synthcxr/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""SynthCXR: Chest X-ray generation via Qwen-Image-Edit LoRA fine-tuning."""
|
synthcxr/constants.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared constants for SynthCXR: disease labels, condition maps, severity modifiers."""

from __future__ import annotations

# CheXpert label column names -> natural-language descriptions used in prompts.
# Used by both dataset preparation and inference scripts.
LABEL_TEXT: dict[str, str] = {
    "Enlarged Cardiomediastinum": "enlarged cardiomediastinum",
    "Cardiomegaly": "cardiomegaly",
    "Lung Opacity": "diffuse lung opacity",
    "Lung Lesion": "discrete lung lesion",
    "Edema": "pulmonary edema",
    "Consolidation": "parenchymal consolidation",
    "Pneumonia": "findings compatible with pneumonia",
    "Atelectasis": "atelectasis",
    "Pneumothorax": "pneumothorax",
    "Pleural Effusion": "pleural effusion",
    "Pleural Other": "other pleural abnormality",
    "Fracture": "possible fracture",
    "Support Devices": "support devices in place",
}

# Snake_case keys for config files -> natural-language descriptions.
# NOTE: the values mirror LABEL_TEXT; only the key style differs.
KNOWN_CONDITIONS: dict[str, str] = {
    "enlarged_cardiomediastinum": "enlarged cardiomediastinum",
    "cardiomegaly": "cardiomegaly",
    "lung_opacity": "diffuse lung opacity",
    "lung_lesion": "discrete lung lesion",
    "edema": "pulmonary edema",
    "consolidation": "parenchymal consolidation",
    "pneumonia": "findings compatible with pneumonia",
    "atelectasis": "atelectasis",
    "pneumothorax": "pneumothorax",
    "pleural_effusion": "pleural effusion",
    "pleural_other": "other pleural abnormality",
    "fracture": "possible fracture",
    "support_devices": "support devices in place",
}

# Allowed severity keys -> the adjective inserted into the prompt. Kept as a
# dict (even where key == value) so membership doubles as validation of the
# config value; multi-word keys are mapped to their space-separated form.
SEVERITY_MODIFIERS: dict[str, str] = {
    "mild": "mild",
    "moderate": "moderate",
    "severe": "severe",
    "small": "small",
    "large": "large",
    "very_small": "very small",
    "very_large": "very large",
    "minimal": "minimal",
    "significant": "significant",
}

# HF Hub repo providing the diffusion transformer (DiT) weights.
DEFAULT_MODEL_ID = "Qwen/Qwen-Image-Edit-2511"
# HF Hub repo providing the text encoder, VAE, and tokenizer.
TEXT_ENCODER_MODEL_ID = "Qwen/Qwen-Image"
# HF Hub repo providing the image processor config.
PROCESSOR_MODEL_ID = "Qwen/Qwen-Image-Edit"
|
synthcxr/mask_utils.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Mask manipulation: scaling organ regions and resolving overlaps."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from scipy import ndimage
|
| 10 |
+
from scipy.ndimage import map_coordinates
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def resolve_overlaps(
    mask: np.ndarray,
    priority: tuple[int, int, int] = (2, 0, 1),
    threshold: int = 10,
) -> np.ndarray:
    """Assign overlapping pixels to the highest-priority channel.

    At every pixel where more than one channel exceeds *threshold*, only the
    channel appearing earliest in *priority* keeps its value; the others are
    zeroed. Pixels with zero or one active channel are untouched.

    Args:
        mask: (H, W, 3) uint8 mask (red=left lung, green=right lung, blue=heart).
        priority: Channel indices ordered from highest to lowest priority.
            Default: heart (2) > left lung (0) > right lung (1).
        threshold: Intensity above which a channel counts as active.

    Returns:
        A resolved copy of *mask*; the input is never mutated.
    """
    result = mask.copy()
    active = mask > threshold
    overlap = active.sum(axis=2) > 1
    if not overlap.any():
        return result

    # Rank channels so that a lower rank means higher priority.
    rank = np.empty(3, dtype=np.int64)
    for pos, ch in enumerate(priority):
        rank[ch] = pos
    # Inactive channels get an effectively infinite rank so argmin always
    # selects an active channel at each pixel.
    ranked = np.where(active, rank[None, None, :], np.iinfo(np.int64).max)
    winner = ranked.argmin(axis=2)

    # Zero every active-but-losing channel at overlapping pixels. Fully
    # vectorized — replaces the original per-pixel Python loop, which was
    # O(#overlap pixels) of interpreted work.
    for ch in range(3):
        losers = overlap & active[:, :, ch] & (winner != ch)
        result[losers, ch] = 0
    return result
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def scale_mask_channel(
    mask: np.ndarray,
    channel: int,
    scale_factor: float,
    threshold: int = 10,
) -> np.ndarray:
    """Scale one channel's region about its own centroid.

    ``channel``: 0 = left lung (red), 1 = right lung (green), 2 = heart (blue).
    Values above *threshold* define the region whose center of mass acts as
    the fixed point of the scaling. The other channels are left untouched,
    and an empty channel is returned unchanged.
    """
    out = mask.copy()
    region = mask[:, :, channel] > threshold
    if not region.any():
        # Nothing to scale: the channel carries no region.
        return out

    center_y, center_x = ndimage.center_of_mass(region)
    height, width = mask.shape[:2]
    grid_y, grid_x = np.mgrid[0:height, 0:width]

    # Inverse mapping: each output pixel samples the source coordinate that a
    # forward scaling by `scale_factor` about the centroid would land on it.
    src_y = ((grid_y - center_y) / scale_factor + center_y).astype(np.float32)
    src_x = ((grid_x - center_x) / scale_factor + center_x).astype(np.float32)

    resampled = map_coordinates(
        mask[:, :, channel].astype(np.float32),
        [src_y, src_x],
        order=1,          # bilinear interpolation
        mode="constant",
        cval=0,           # coordinates outside the image read as background
    )
    out[:, :, channel] = np.clip(resampled, 0, 255).astype(np.uint8)
    return out
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def modify_mask(
    input_path: Path,
    output_path: Path,
    heart_scale: float = 1.0,
    left_lung_scale: float = 1.0,
    right_lung_scale: float = 1.0,
) -> None:
    """Load a conditioning mask, apply per-organ scale factors, and save.

    Channels: red (0) = left lung, green (1) = right lung, blue (2) = heart.
    Identity scales (exactly 1.0) are skipped. After scaling, overlapping
    pixels are resolved with heart > left lung > right lung priority, then
    the result is written to *output_path* (parent dirs created as needed).
    """
    with Image.open(input_path) as img:
        mask = np.array(img.convert("RGB"))

    # Apply only the non-identity scales, one channel at a time.
    for channel, factor in ((0, left_lung_scale), (1, right_lung_scale), (2, heart_scale)):
        if factor != 1.0:
            mask = scale_mask_channel(mask, channel=channel, scale_factor=factor)

    mask = resolve_overlaps(mask, priority=(2, 0, 1))

    output_path.parent.mkdir(parents=True, exist_ok=True)
    Image.fromarray(mask).save(output_path)
    print(f"[INFO] Saved modified mask to {output_path}")
|
synthcxr/pipeline.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pipeline loading, LoRA weight management, and image I/O helpers."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Sequence
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from PIL import Image
|
| 12 |
+
|
| 13 |
+
from diffsynth.core import ModelConfig
|
| 14 |
+
from diffsynth.pipelines.qwen_image import QwenImagePipeline
|
| 15 |
+
|
| 16 |
+
from .constants import DEFAULT_MODEL_ID, PROCESSOR_MODEL_ID, TEXT_ENCODER_MODEL_ID
|
| 17 |
+
from .mask_utils import resolve_overlaps, scale_mask_channel
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
class SampleSpec:
    """A single validation/inference sample."""

    # Text prompt passed to the generation pipeline.
    prompt: str
    # Conditioning mask image path(s) for this sample.
    mask_paths: list[Path]
    # Filesystem-safe name; used to name exported files (see export_original_images).
    identifier: str
    # Path to the ground-truth CXR image, or None when no reference exists.
    image_path: Path | None
    # The source prompt before any modification — kept for comparison/logging.
    original_prompt: str = ""
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def load_pipeline(
    device: str,
    torch_dtype: torch.dtype,
    model_id: str = DEFAULT_MODEL_ID,
    vram_limit: float | None = None,
) -> QwenImagePipeline:
    """Build a ``QwenImagePipeline``, downloading weights from the HF Hub.

    The DiT weights come from *model_id*; the text encoder, VAE, and
    tokenizer come from the base Qwen-Image repo; the processor config comes
    from the Qwen-Image-Edit repo.
    """
    # (repo, file pattern) for each weight bundle the pipeline needs.
    component_sources = [
        (model_id, "transformer/diffusion_pytorch_model*.safetensors"),
        (TEXT_ENCODER_MODEL_ID, "text_encoder/model*.safetensors"),
        (TEXT_ENCODER_MODEL_ID, "vae/diffusion_pytorch_model.safetensors"),
    ]
    model_configs = [
        ModelConfig(model_id=repo, origin_file_pattern=pattern)
        for repo, pattern in component_sources
    ]
    return QwenImagePipeline.from_pretrained(
        torch_dtype=torch_dtype,
        device=device,
        model_configs=model_configs,
        tokenizer_config=ModelConfig(
            model_id=TEXT_ENCODER_MODEL_ID,
            origin_file_pattern="tokenizer/",
        ),
        processor_config=ModelConfig(
            model_id=PROCESSOR_MODEL_ID,
            origin_file_pattern="processor/",
        ),
        vram_limit=vram_limit,
    )
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def load_lora_weights(pipe: QwenImagePipeline, checkpoint: Path) -> None:
    """Load a LoRA checkpoint into an existing pipeline.

    Any previously loaded LoRA is cleared first, so checkpoints can be
    hot-swapped on a live pipeline without reloading the base weights.
    The LoRA is applied to the pipeline's DiT (``pipe.dit``) only.
    """
    pipe.clear_lora()
    pipe.load_lora(pipe.dit, lora_config=str(checkpoint))
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def load_edit_images(
    paths: Sequence[Path],
    *,
    heart_scale: float = 1.0,
    left_lung_scale: float = 1.0,
    right_lung_scale: float = 1.0,
) -> Image.Image | list[Image.Image]:
    """Load conditioning mask image(s), optionally rescaling organ regions.

    Non-identity scale factors are applied per channel (2=heart, 0=left
    lung, 1=right lung) and overlaps are then resolved with heart > left
    lung > right lung priority. Returns a single image when exactly one
    path is given, otherwise a list.
    """
    rescale = any(
        factor != 1.0
        for factor in (heart_scale, left_lung_scale, right_lung_scale)
    )
    loaded: list[Image.Image] = []
    for mask_path in paths:
        with Image.open(mask_path) as img:
            rgb = img.convert("RGB")
        if rescale:
            arr = np.array(rgb)
            # Same order as the original: heart, then left lung, then right lung.
            for channel, factor in (
                (2, heart_scale),
                (0, left_lung_scale),
                (1, right_lung_scale),
            ):
                if factor != 1.0:
                    arr = scale_mask_channel(arr, channel=channel, scale_factor=factor)
            arr = resolve_overlaps(arr, priority=(2, 0, 1))
            loaded.append(Image.fromarray(arr))
        else:
            loaded.append(rgb)
    return loaded[0] if len(loaded) == 1 else loaded
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def export_original_images(
    samples: Sequence[SampleSpec], output_dir: Path
) -> None:
    """Copy each sample's original CXR into ``output_dir/original/`` for comparison.

    Samples without a source image are skipped, as are destinations that
    already exist from a previous run (making repeated calls cheap).
    """
    dest_dir = output_dir / "original"
    dest_dir.mkdir(parents=True, exist_ok=True)
    for sample in samples:
        source = sample.image_path
        if source is None:
            continue
        target = dest_dir / f"{sample.identifier}.png"
        if target.exists():
            continue
        with Image.open(source) as img:
            img.convert("RGB").save(target)
        print(f"[INFO] Saved original to {target}")
|
synthcxr/prompt.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Prompt builders for conditional inference."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
|
| 7 |
+
from .constants import KNOWN_CONDITIONS, SEVERITY_MODIFIERS
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
class ConditionConfig:
    """Configuration for a single inference run with specific conditions."""

    # Identifier for this scenario.
    name: str
    # Snake_case condition keys; looked up in KNOWN_CONDITIONS when prompting
    # (unknown keys are passed through verbatim).
    conditions: list[str] = field(default_factory=list)
    # Optional patient demographics woven into the prompt.
    age: int | None = None
    sex: str | None = None
    # Radiographic view; upper-cased in the prompt, falls back to "AP".
    view: str = "AP"
    # If set, used verbatim instead of the generated prompt.
    custom_prompt: str | None = None
    # Severity key; only takes effect when present in SEVERITY_MODIFIERS.
    severity: str | None = None
    # Per-organ mask scale factors (1.0 = leave the region unchanged).
    heart_scale: float = 1.0
    left_lung_scale: float = 1.0
    right_lung_scale: float = 1.0
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass
class InferenceConfig:
    """Top-level configuration for the condition-inference script."""

    # How many samples to generate.
    num_samples: int = 10
    # Diffusion sampling step count.
    num_steps: int = 50
    # Output resolution in pixels.
    height: int = 512
    width: int = 512
    # Classifier-free guidance scale.
    cfg_scale: float = 4.0
    # Base RNG seed for reproducibility.
    seed: int = 0
    # One entry per condition scenario to render.
    conditions: list[ConditionConfig] = field(default_factory=list)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def build_condition_prompt(condition: ConditionConfig) -> str:
    """Build a CheXpert-style prompt from a ``ConditionConfig``.

    A non-empty ``custom_prompt`` wins outright. Otherwise the prompt is
    assembled from the view, patient demographics, and the condition list;
    a recognized severity modifier is attached to the first condition only.
    An empty condition list yields "with no significant abnormality".
    """
    if condition.custom_prompt:
        return condition.custom_prompt

    view = (condition.view or "AP").upper()

    # Demographics: "a [<age>-year-old] [<sex>] patient".
    descriptors: list[str] = []
    if condition.age:
        descriptors.append(f"{condition.age}-year-old")
    if condition.sex:
        descriptors.append(condition.sex.lower())
    if descriptors:
        demographics = f"a {' '.join(descriptors)} patient"
    else:
        demographics = "a patient"

    # Severity adjective, only when the key is recognized.
    severity_word = None
    if condition.severity:
        severity_word = SEVERITY_MODIFIERS.get(condition.severity)

    findings: list[str] = []
    for index, key in enumerate(condition.conditions):
        text = KNOWN_CONDITIONS.get(key.lower(), key)
        if index == 0 and severity_word:
            # The severity modifier qualifies the first finding only.
            text = f"{severity_word} {text}"
        findings.append(text)

    if findings:
        pathology_str = f"with {', '.join(findings)}"
    else:
        pathology_str = "with no significant abnormality"

    return (
        f"frontal {view} chest radiograph of {demographics} {pathology_str}. "
        "The conditioning mask image provides three channels "
        "(red=left lung, green=right lung, blue=heart). "
        "Reconstruct a CheXpert-style chest X-ray that aligns "
        "with the segmentation and follows the described pathology."
    )
|
synthcxr/utils.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Small shared helpers used across scripts."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Sequence
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def resolve_path(base: Path, maybe_relative: str) -> Path:
    """Return *maybe_relative* as an absolute path, resolved against *base*.

    An already-absolute input (after ``~`` expansion) is returned unchanged;
    a relative input is joined onto *base* and fully resolved.
    """
    candidate = Path(maybe_relative).expanduser()
    if not candidate.is_absolute():
        candidate = (base / candidate).resolve()
    return candidate
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def build_identifier(
    record_path: str | None,
    fallback_paths: Sequence[str],
    sample_idx: int,
) -> str:
    """Build a filesystem-safe identifier from a metadata record.

    Prefers *record_path*, then the first fallback path, then a synthetic
    ``sample_<idx>`` name. The last (up to four) path components are joined
    with ``_``, with dots replaced by dashes, and prefixed with the
    zero-padded sample index.
    """
    if record_path:
        source = record_path
    elif fallback_paths:
        source = fallback_paths[0]
    else:
        source = f"sample_{sample_idx}"

    candidate = Path(source)
    sanitized = [segment.replace(".", "-") for segment in candidate.parts[-4:]]
    slug = "_".join(sanitized) if sanitized else candidate.stem
    return f"{sample_idx:03d}_{slug}"
|