gradientguild commited on
Commit
a4aa5c5
·
verified ·
1 Parent(s): 06770e6

Upload folder using huggingface_hub

Browse files
.gitignore ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Base model weights (downloaded at runtime from HF Hub)
2
+ scripts/models/Qwen/
3
+
4
+ # LoRA checkpoints — keep only epoch-2 for the demo
5
+ scripts/models/qwen_image_edit_chexpert_lora/epoch-0.safetensors
6
+ scripts/models/qwen_image_edit_chexpert_lora/epoch-1.safetensors
7
+ scripts/models/qwen_image_edit_chexpert_lora/epoch-3.safetensors
8
+ scripts/models/qwen_image_edit_chexpert_lora/epoch-4.safetensors
9
+
10
+ # Python
11
+ __pycache__/
12
+ *.py[cod]
13
+ *.egg-info/
14
+ dist/
15
+ build/
16
+ *.egg
17
+
18
+ # Environment
19
+ .env
20
+ .venv/
21
+ venv/
22
+ .cache/
23
+
24
+ # IDE
25
+ .vscode/
26
+ .idea/
27
+
28
+ # OS
29
+ .DS_Store
30
+ Thumbs.db
31
+
32
+ # Misc
33
+ *.log
README.md CHANGED
@@ -1,14 +1,31 @@
1
  ---
2
  title: SynthCXR
3
- emoji: 🌍
4
- colorFrom: yellow
5
- colorTo: green
6
  sdk: gradio
7
- sdk_version: 6.9.0
8
- python_version: '3.12'
9
  app_file: app.py
 
10
  pinned: false
11
- license: mit
 
 
 
 
 
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: SynthCXR
3
+ emoji: 🫁
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: "6.9.0"
 
8
  app_file: app.py
9
+ hardware: zero-a10g
10
  pinned: false
11
+ tags:
12
+ - medical-imaging
13
+ - chest-x-ray
14
+ - diffusion
15
+ - lora
16
+ short_description: Controllable chest X-ray generation with anatomical masks
17
  ---
18
 
19
+ # 🫁 SynthCXR · Chest X-Ray Generator
20
+
21
+ Interactively resize anatomical mask components (heart, left lung, right lung) with sliders and generate realistic chest X-rays using a Qwen-Image-Edit model with LoRA fine-tuning on CheXpert.
22
+
23
+ > **Zero GPU** — This Space uses HuggingFace ZeroGPU for dynamic GPU allocation. A GPU is acquired only during image generation and released immediately after.
24
+
25
+ ## Features
26
+
27
+ - **Mask Scaling Sliders** — Real-time preview of organ masks scaled from 0× to 2×
28
+ - **Condition Picker** — Select from 13 CheXpert pathologies with severity modifiers
29
+ - **Demographics** — Configure patient age, sex, and radiograph view (AP/PA)
30
+ - **CXR Generation** — Generate 512×512 chest X-rays conditioned on the modified mask
31
+ - **Progress Bar** — Real-time step-by-step progress during generation
app.py ADDED
@@ -0,0 +1,721 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Gradio app for SynthCXR: interactive mask scaling and CXR generation."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import os
7
+ from pathlib import Path
8
+
9
+ import spaces
10
+
11
+ import gradio as gr
12
+ import numpy as np
13
+ import torch
14
+ from PIL import Image
15
+
16
+ from synthcxr.constants import KNOWN_CONDITIONS
17
+ from synthcxr.mask_utils import resolve_overlaps, scale_mask_channel
18
+ from synthcxr.prompt import ConditionConfig, build_condition_prompt
19
+
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
BASE_DIR = Path(__file__).resolve().parent
# Bundled example masks shown in the "Sample Masks" gallery.
SAMPLE_MASKS_DIR = BASE_DIR / "static" / "sample_masks"
# LoRA checkpoints (epoch-N.safetensors) consumed by get_pipeline().
LORA_DIR = BASE_DIR / "scripts" / "models" / "qwen_image_edit_chexpert_lora"

# ---------------------------------------------------------------------------
# Condition / severity choices
# ---------------------------------------------------------------------------
# Pathologies exposed in the UI; each key exists in
# synthcxr.constants.KNOWN_CONDITIONS.
CONDITION_CHOICES = [
    "enlarged_cardiomediastinum",
    "cardiomegaly",
    "atelectasis",
    "pneumothorax",
    "pleural_effusion",
]
# "(none)" means no severity modifier is appended to the prompt.
SEVERITY_CHOICES = ["(none)", "mild", "moderate", "severe"]

# ---------------------------------------------------------------------------
# Pipeline (lazy-loaded once)
# ---------------------------------------------------------------------------
# Module-level cache for the diffusion pipeline; populated by get_pipeline().
_pipe = None
43
+
44
+
def get_pipeline():
    """Return the cached diffusion pipeline, loading it (plus LoRA) on first use.

    The pipeline is built once and memoized in the module-level ``_pipe``.
    Environment knobs:
        VRAM_LIMIT  — GPU memory budget in GB; enables model offloading.
        LORA_EPOCH  — which epoch checkpoint to load (default "2").
    """
    global _pipe
    if _pipe is not None:
        return _pipe

    from synthcxr.pipeline import load_lora_weights, load_pipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16

    # VRAM_LIMIT (in GB): enables model offloading for memory-constrained GPUs
    raw_limit = os.environ.get("VRAM_LIMIT", "")
    vram_limit = float(raw_limit) if raw_limit else None

    print(f"[INFO] Loading QwenImagePipeline (device={device}, dtype={dtype}, vram_limit={vram_limit}) …")
    _pipe = load_pipeline(device, dtype, vram_limit=vram_limit)

    # Pick the requested epoch checkpoint.
    lora_epoch = os.environ.get("LORA_EPOCH", "2")
    checkpoint = LORA_DIR / f"epoch-{lora_epoch}.safetensors"

    if not checkpoint.exists():
        # Fall back to any available .safetensors checkpoint.
        available = sorted(LORA_DIR.glob("*.safetensors")) if LORA_DIR.exists() else []
        if not available:
            print("[WARN] No LoRA checkpoint found – running base model only.")
            return _pipe
        checkpoint = available[-1]
        print(f"[WARN] epoch-{lora_epoch} not found, falling back to {checkpoint.name}")

    print(f"[INFO] Loading LoRA from {checkpoint}")
    load_lora_weights(_pipe, checkpoint)
    print("[INFO] Pipeline ready.")
    return _pipe
81
+
82
+
83
+ # ---------------------------------------------------------------------------
84
+ # Sample masks
85
+ # ---------------------------------------------------------------------------
def get_sample_masks() -> list[str]:
    """Return sorted paths of the bundled sample-mask PNGs (empty if none)."""
    if SAMPLE_MASKS_DIR.exists():
        return sorted(str(path) for path in SAMPLE_MASKS_DIR.glob("*.png"))
    return []
91
+
92
+
93
+ # ---------------------------------------------------------------------------
94
+ # Core functions
95
+ # ---------------------------------------------------------------------------
96
+
def apply_mask_scaling(
    mask_array: np.ndarray,
    heart_scale: float,
    left_lung_scale: float,
    right_lung_scale: float,
) -> np.ndarray:
    """Scale each organ channel by its slider factor, then resolve overlaps.

    Channels: 0 = left lung, 1 = right lung, 2 = heart. Channels at the
    neutral factor 1.0 are left untouched. Overlapping pixels are assigned
    heart-first (priority 2 > 0 > 1).
    """
    # Apply in the same order as the sliders: heart, left lung, right lung.
    for channel, factor in ((2, heart_scale), (0, left_lung_scale), (1, right_lung_scale)):
        if factor != 1.0:
            mask_array = scale_mask_channel(mask_array, channel=channel, scale_factor=factor)
    return resolve_overlaps(mask_array, priority=(2, 0, 1))
111
+
112
+
def preview_mask(
    mask_image: np.ndarray | None,
    heart_scale: float,
    left_lung_scale: float,
    right_lung_scale: float,
) -> np.ndarray | None:
    """Live-preview callback: rescale the uploaded mask, or pass through None."""
    if mask_image is None:
        return None
    rgb = np.array(Image.fromarray(mask_image).convert("RGB"))
    return apply_mask_scaling(rgb, heart_scale, left_lung_scale, right_lung_scale)
125
+
126
+
def build_prompt_preview(
    conditions: list[str],
    severity: str,
    age: int,
    sex: str,
    view: str,
) -> str:
    """Render the text prompt the current UI configuration would produce."""
    # The sentinel "(none)" in the severity radio means "no modifier".
    severity_value = None if severity == "(none)" else severity
    config = ConditionConfig(
        name="preview",
        conditions=conditions or [],
        age=age,
        sex=sex,
        view=view,
        severity=severity_value,
    )
    return build_condition_prompt(config)
144
+
145
+
@spaces.GPU(duration=120)  # ZeroGPU: hold a GPU for at most 120 s per call
def generate_cxr(
    mask_image: np.ndarray | None,
    heart_scale: float,
    left_lung_scale: float,
    right_lung_scale: float,
    conditions: list[str],
    severity: str,
    age: int,
    sex: str,
    view: str,
    num_steps: int,
    cfg_scale: float,
    seed: int,
    preview_every: int = 10,
    progress=gr.Progress(),  # injected by Gradio; default-instance is the idiomatic pattern
):
    """Generate a CXR, yielding intermediate previews every N steps.

    Yields every preview image collected during denoising, then the final
    image last. NOTE(review): previews are collected during generation but
    only yielded after the pipeline call returns, so the UI receives them in
    a burst at the end rather than as a live stream — confirm if streaming
    was intended.

    Raises:
        gr.Error: if no mask was provided or the pipeline failed to load.
    """
    if mask_image is None:
        raise gr.Error("Please select or upload a mask first.")

    pipe = get_pipeline()
    if pipe is None:
        raise gr.Error("Pipeline not loaded. GPU may be unavailable.")

    # Prepare mask: normalize to RGB, apply slider scaling, resolve overlaps.
    mask = np.array(Image.fromarray(mask_image).convert("RGB"))
    scaled = apply_mask_scaling(mask, heart_scale, left_lung_scale, right_lung_scale)
    edit_image = Image.fromarray(scaled)

    # Build prompt from the UI configuration ("(none)" severity -> no modifier).
    cond = ConditionConfig(
        name="web_ui",
        conditions=conditions or [],
        age=age,
        sex=sex,
        view=view,
        severity=severity if severity != "(none)" else None,
    )
    prompt = build_condition_prompt(cond)

    # Intermediate preview collector (filled by StepCallback during denoising).
    previews: list[Image.Image] = []

    class StepCallback:
        """Custom tqdm-like wrapper that decodes latents every N steps."""
        def __init__(self, iterable):
            self._iterable = iterable
            self._step = 0  # number of completed denoising steps

        def __iter__(self):
            for item in self._iterable:
                # Report fractional progress before each denoising step.
                progress(self._step / num_steps, desc="Generating CXR...")
                yield item
                self._step += 1
                # Decode a preview every `preview_every` completed steps,
                # skipping the last step (the pipeline returns the final image).
                if (
                    preview_every > 0
                    and self._step % preview_every == 0
                    and self._step < num_steps
                    and "latents" in _shared_ref
                ):
                    try:
                        with torch.no_grad():
                            latents = _shared_ref["latents"]
                            decoded = pipe.vae.decode(
                                latents,
                                device=pipe.device,
                                tiled=False,
                            )
                            img = pipe.vae_output_to_image(decoded)
                            previews.append(img)
                    except Exception:
                        pass  # skip preview on error; never abort generation

        def __len__(self):
            # Lets tqdm-style consumers query the total step count.
            return len(self._iterable)

    # We patch the pipeline's __call__ to capture inputs_shared reference.
    # The pipeline stores latents in inputs_shared["latents"] during denoising.
    # NOTE(review): this patches the runner *class*, so it affects every
    # instance and is not safe under concurrent generations — acceptable for a
    # single-user Space, but confirm before enabling concurrency.
    _shared_ref: dict = {}
    _orig_unit_runner = pipe.unit_runner.__class__.__call__

    def _patched_runner(self_runner, unit, p, inputs_shared, inputs_posi, inputs_nega):
        # Keep a live reference to the shared state so StepCallback can read latents.
        _shared_ref.update(inputs_shared)
        return _orig_unit_runner(self_runner, unit, p, inputs_shared, inputs_posi, inputs_nega)

    pipe.unit_runner.__class__.__call__ = _patched_runner

    try:
        image = pipe(
            prompt=prompt,
            edit_image=edit_image,
            height=512,
            width=512,
            num_inference_steps=num_steps,
            seed=seed,
            rand_device=pipe.device,
            cfg_scale=cfg_scale,
            edit_image_auto_resize=True,
            zero_cond_t=True,
            progress_bar_cmd=StepCallback,
        )
    finally:
        # Always restore the original runner, even if generation failed.
        pipe.unit_runner.__class__.__call__ = _orig_unit_runner

    # Yield all collected previews, then the final image
    for preview in previews:
        yield preview
    yield image
256
+
257
+
258
+ # ---------------------------------------------------------------------------
259
+ # Gradio UI
260
+ # ---------------------------------------------------------------------------
261
+
262
+ CUSTOM_CSS = """
263
+ /* ── Layout ── */
264
+ .gradio-container {
265
+ max-width: 1280px !important;
266
+ margin: 0 auto !important;
267
+ }
268
+
269
+ /* ── Radial gradient background ── */
270
+ .main {
271
+ background:
272
+ radial-gradient(ellipse 80% 50% at 10% 20%, rgba(99,102,241,0.07), transparent),
273
+ radial-gradient(ellipse 60% 40% at 85% 75%, rgba(59,130,246,0.05), transparent) !important;
274
+ }
275
+
276
+ /* ── Header ── */
277
+ #component-0 h1 {
278
+ text-align: center;
279
+ font-size: 2.2rem !important;
280
+ font-weight: 800 !important;
281
+ letter-spacing: -0.5px;
282
+ background: linear-gradient(135deg, #818cf8, #60a5fa, #818cf8);
283
+ background-size: 200% 200%;
284
+ -webkit-background-clip: text;
285
+ -webkit-text-fill-color: transparent;
286
+ background-clip: text;
287
+ animation: gradientShift 4s ease-in-out infinite;
288
+ padding-bottom: 4px !important;
289
+ }
290
+ #component-0 p {
291
+ text-align: center;
292
+ color: #94a3b8 !important;
293
+ font-size: 0.95rem;
294
+ }
295
+
296
+ @keyframes gradientShift {
297
+ 0%, 100% { background-position: 0% 50%; }
298
+ 50% { background-position: 100% 50%; }
299
+ }
300
+
301
+ /* ── Glass panels ── */
302
+ .block {
303
+ border: 1px solid rgba(99,115,146,0.15) !important;
304
+ border-radius: 16px !important;
305
+ backdrop-filter: blur(12px);
306
+ transition: border-color 0.3s ease, box-shadow 0.3s ease !important;
307
+ }
308
+ .block:hover {
309
+ border-color: rgba(99,102,241,0.25) !important;
310
+ box-shadow: 0 0 20px rgba(99,102,241,0.06) !important;
311
+ }
312
+
313
+ /* ── Section headings ── */
314
+ .markdown h3 {
315
+ font-size: 0.78rem !important;
316
+ font-weight: 700 !important;
317
+ text-transform: uppercase;
318
+ letter-spacing: 1.2px;
319
+ color: #64748b !important;
320
+ border-bottom: 1px solid rgba(99,115,146,0.12);
321
+ padding-bottom: 8px !important;
322
+ margin-bottom: 12px !important;
323
+ }
324
+
325
+ /* ── Slider styling ── */
326
+ input[type="range"] {
327
+ height: 6px !important;
328
+ border-radius: 3px !important;
329
+ background: #1e293b !important;
330
+ }
331
+ input[type="range"]::-webkit-slider-thumb {
332
+ width: 18px !important;
333
+ height: 18px !important;
334
+ border-radius: 50% !important;
335
+ border: 2.5px solid #0a0e17 !important;
336
+ transition: transform 0.2s ease, box-shadow 0.2s ease !important;
337
+ }
338
+ input[type="range"]::-webkit-slider-thumb:hover {
339
+ transform: scale(1.2) !important;
340
+ }
341
+
342
+ /* Slider labels */
343
+ .block label span {
344
+ font-weight: 500 !important;
345
+ font-size: 0.88rem !important;
346
+ }
347
+ .block .rangeSlider_value {
348
+ font-variant-numeric: tabular-nums;
349
+ font-weight: 600 !important;
350
+ }
351
+
352
+ /* ── Image panels ── */
353
+ .image-frame img, .image-container img {
354
+ border-radius: 10px !important;
355
+ transition: opacity 0.3s ease !important;
356
+ }
357
+ .image-container {
358
+ background: rgba(0,0,0,0.2) !important;
359
+ border-radius: 12px !important;
360
+ min-height: 380px;
361
+ }
362
+
363
+ /* ── Generate button ── */
364
+ .primary {
365
+ background: linear-gradient(135deg, #6366f1, #4f46e5, #6366f1) !important;
366
+ background-size: 200% 200% !important;
367
+ border: none !important;
368
+ border-radius: 12px !important;
369
+ padding: 14px 24px !important;
370
+ font-weight: 700 !important;
371
+ font-size: 1rem !important;
372
+ letter-spacing: 0.3px;
373
+ transition: all 0.3s cubic-bezier(0.4,0,0.2,1) !important;
374
+ position: relative;
375
+ overflow: hidden;
376
+ }
377
+ .primary:hover {
378
+ transform: translateY(-2px) !important;
379
+ box-shadow: 0 8px 25px rgba(99,102,241,0.4) !important;
380
+ animation: btnShimmer 1.5s ease-in-out infinite !important;
381
+ }
382
+ .primary:active {
383
+ transform: translateY(0) !important;
384
+ }
385
+ @keyframes btnShimmer {
386
+ 0%, 100% { background-position: 0% 50%; }
387
+ 50% { background-position: 100% 50%; }
388
+ }
389
+
390
+ /* ── Secondary buttons ── */
391
+ .secondary {
392
+ border: 1px solid rgba(99,115,146,0.2) !important;
393
+ border-radius: 10px !important;
394
+ background: transparent !important;
395
+ color: #94a3b8 !important;
396
+ transition: all 0.25s ease !important;
397
+ }
398
+ .secondary:hover {
399
+ border-color: rgba(99,102,241,0.4) !important;
400
+ color: #e2e8f0 !important;
401
+ background: rgba(99,102,241,0.06) !important;
402
+ }
403
+
404
+ /* ── Prompt preview ── */
405
+ textarea[readonly], .prose {
406
+ font-family: 'JetBrains Mono', 'Fira Code', monospace !important;
407
+ font-size: 0.8rem !important;
408
+ line-height: 1.6 !important;
409
+ color: #64748b !important;
410
+ background: rgba(0,0,0,0.25) !important;
411
+ border-radius: 10px !important;
412
+ }
413
+
414
+ /* ── Checkboxes ── */
415
+ .checkbox-group label {
416
+ border-radius: 20px !important;
417
+ padding: 4px 12px !important;
418
+ font-size: 0.8rem !important;
419
+ transition: all 0.2s ease !important;
420
+ border: 1px solid rgba(99,115,146,0.15) !important;
421
+ color: #e2e8f0 !important;
422
+ background: rgba(17,24,39,0.75) !important;
423
+ }
424
+ .checkbox-group label span {
425
+ color: #e2e8f0 !important;
426
+ }
427
+ .checkbox-group label:hover {
428
+ border-color: rgba(99,102,241,0.35) !important;
429
+ background: rgba(30,41,59,0.9) !important;
430
+ }
431
+ .checkbox-group input:checked + label,
432
+ .checkbox-group label.selected {
433
+ background: rgba(99,102,241,0.15) !important;
434
+ border-color: rgba(99,102,241,0.4) !important;
435
+ color: #c7d2fe !important;
436
+ }
437
+
438
+ /* ── Dropdowns & inputs ── */
439
+ select, input[type="number"] {
440
+ border-radius: 10px !important;
441
+ border: 1px solid rgba(99,115,146,0.15) !important;
442
+ transition: border-color 0.25s ease !important;
443
+ font-size: 0.88rem !important;
444
+ }
445
+ select:focus, input[type="number"]:focus {
446
+ border-color: rgba(99,102,241,0.5) !important;
447
+ box-shadow: 0 0 0 2px rgba(99,102,241,0.1) !important;
448
+ }
449
+
450
+ /* ── Accordion ── */
451
+ .accordion {
452
+ border: 1px solid rgba(99,115,146,0.1) !important;
453
+ border-radius: 12px !important;
454
+ background: rgba(0,0,0,0.15) !important;
455
+ }
456
+ .accordion > .label-wrap {
457
+ font-size: 0.82rem !important;
458
+ color: #64748b !important;
459
+ font-weight: 500 !important;
460
+ }
461
+
462
+ /* ── Examples gallery ── */
463
+ .gallery-item {
464
+ border-radius: 10px !important;
465
+ border: 2px solid rgba(99,115,146,0.15) !important;
466
+ transition: all 0.25s ease !important;
467
+ overflow: hidden;
468
+ }
469
+ .gallery-item:hover {
470
+ border-color: rgba(99,102,241,0.4) !important;
471
+ transform: scale(1.04);
472
+ box-shadow: 0 4px 16px rgba(99,102,241,0.15) !important;
473
+ }
474
+
475
+ /* ── Scrollbar ── */
476
+ ::-webkit-scrollbar { width: 6px; }
477
+ ::-webkit-scrollbar-track { background: transparent; }
478
+ ::-webkit-scrollbar-thumb {
479
+ background: rgba(99,115,146,0.25);
480
+ border-radius: 3px;
481
+ }
482
+ ::-webkit-scrollbar-thumb:hover { background: rgba(99,115,146,0.4); }
483
+
484
+ /* ── Footer spacing ── */
485
+ .gradio-container > .main > .wrap:last-child { padding-bottom: 40px !important; }
486
+ """
487
+
488
+ sample_paths = get_sample_masks()
489
+
490
+ THEME = gr.themes.Base(
491
+ primary_hue=gr.themes.colors.indigo,
492
+ secondary_hue=gr.themes.colors.slate,
493
+ neutral_hue=gr.themes.colors.slate,
494
+ font=gr.themes.GoogleFont("Inter"),
495
+ font_mono=gr.themes.GoogleFont("JetBrains Mono"),
496
+ radius_size=gr.themes.sizes.radius_lg,
497
+ spacing_size=gr.themes.sizes.spacing_md,
498
+ ).set(
499
+ # Background
500
+ body_background_fill="#0a0e17",
501
+ body_background_fill_dark="#0a0e17",
502
+ # Panels
503
+ block_background_fill="rgba(17,24,39,0.75)",
504
+ block_background_fill_dark="rgba(17,24,39,0.75)",
505
+ block_border_color="rgba(99,115,146,0.15)",
506
+ block_border_color_dark="rgba(99,115,146,0.15)",
507
+ block_shadow="0 4px 24px rgba(0,0,0,0.2)",
508
+ block_shadow_dark="0 4px 24px rgba(0,0,0,0.2)",
509
+ # Inputs
510
+ input_background_fill="#131b2e",
511
+ input_background_fill_dark="#131b2e",
512
+ input_border_color="rgba(99,115,146,0.15)",
513
+ input_border_color_dark="rgba(99,115,146,0.15)",
514
+ # Buttons
515
+ button_primary_background_fill="linear-gradient(135deg, #6366f1, #4f46e5)",
516
+ button_primary_background_fill_dark="linear-gradient(135deg, #6366f1, #4f46e5)",
517
+ button_primary_text_color="white",
518
+ button_primary_text_color_dark="white",
519
+ button_primary_shadow="0 4px 14px rgba(99,102,241,0.25)",
520
+ button_primary_shadow_dark="0 4px 14px rgba(99,102,241,0.25)",
521
+ # Text
522
+ body_text_color="#e2e8f0",
523
+ body_text_color_dark="#e2e8f0",
524
+ body_text_color_subdued="#94a3b8",
525
+ body_text_color_subdued_dark="#94a3b8",
526
+ # Labels
527
+ block_label_text_color="#94a3b8",
528
+ block_label_text_color_dark="#94a3b8",
529
+ block_title_text_color="#cbd5e1",
530
+ block_title_text_color_dark="#cbd5e1",
531
+ # Borders
532
+ border_color_primary="rgba(99,102,241,0.4)",
533
+ border_color_primary_dark="rgba(99,102,241,0.4)",
534
+ )
535
+
536
+ with gr.Blocks(
537
+ title="SynthCXR · Chest X-Ray Generator",
538
+ ) as demo:
539
+
540
+ gr.Markdown(
541
+ "# 🫁 SynthCXR\n"
542
+ "Interactively resize anatomical masks and generate realistic chest X-rays"
543
+ )
544
+
545
+ with gr.Row():
546
+
547
+ # ── Left column: Controls ──
548
+ with gr.Column(scale=1):
549
+
550
+ # Mask input
551
+ gr.Markdown("### Select Mask")
552
+ mask_input = gr.Image(
553
+ label="Conditioning Mask",
554
+ type="numpy",
555
+ sources=["upload"],
556
+ height=240,
557
+ )
558
+
559
+ # Sample mask gallery
560
+ if sample_paths:
561
+ sample_gallery = gr.Examples(
562
+ examples=sample_paths,
563
+ inputs=mask_input,
564
+ label="Sample Masks",
565
+ )
566
+
567
+ # Sliders
568
+ gr.Markdown("### Mask Scaling")
569
+ heart_slider = gr.Slider(
570
+ minimum=0.0, maximum=2.0, step=0.05, value=1.0,
571
+ label="💙 Heart Scale",
572
+ )
573
+ left_lung_slider = gr.Slider(
574
+ minimum=0.0, maximum=2.0, step=0.05, value=1.0,
575
+ label="🔴 Left Lung Scale",
576
+ )
577
+ right_lung_slider = gr.Slider(
578
+ minimum=0.0, maximum=2.0, step=0.05, value=1.0,
579
+ label="🟢 Right Lung Scale",
580
+ )
581
+ reset_btn = gr.Button("↺ Reset Scales", variant="secondary", size="sm")
582
+
583
+ # Conditions
584
+ gr.Markdown("### Conditions")
585
+ conditions_select = gr.CheckboxGroup(
586
+ choices=CONDITION_CHOICES,
587
+ label="Pathologies",
588
+ )
589
+ with gr.Row():
590
+ severity_select = gr.Radio(
591
+ choices=SEVERITY_CHOICES, value="(none)", label="Severity",
592
+ )
593
+ view_select = gr.Radio(
594
+ choices=["AP", "PA"], value="AP", label="View",
595
+ )
596
+ with gr.Row():
597
+ age_input = gr.Number(value=45, label="Age", minimum=0, maximum=120, precision=0)
598
+ sex_select = gr.Radio(
599
+ choices=["male", "female"], value="male", label="Sex",
600
+ )
601
+
602
+ # Advanced
603
+ with gr.Accordion("Advanced Settings", open=False):
604
+ with gr.Row():
605
+ steps_input = gr.Number(value=50, label="Steps", minimum=1, maximum=100, precision=0)
606
+ cfg_input = gr.Number(value=4.0, label="CFG Scale", minimum=1.0, maximum=20.0)
607
+ with gr.Row():
608
+ seed_input = gr.Number(value=42, label="Seed", minimum=0, precision=0)
609
+ preview_every_input = gr.Number(value=10, label="Preview Every N Steps", minimum=0, maximum=50, precision=0)
610
+
611
+ # ── Right column: Outputs ──
612
+ with gr.Column(scale=2):
613
+
614
+ with gr.Row():
615
+ mask_preview = gr.Image(
616
+ label="Scaled Mask Preview",
617
+ type="numpy",
618
+ interactive=False,
619
+ height=400,
620
+ )
621
+ cxr_output = gr.Image(
622
+ label="Generated Chest X-Ray",
623
+ type="pil",
624
+ interactive=False,
625
+ height=400,
626
+ )
627
+
628
+ # Prompt preview
629
+ prompt_preview = gr.Textbox(
630
+ label="Prompt Preview",
631
+ interactive=False,
632
+ lines=3,
633
+ )
634
+
635
+ generate_btn = gr.Button("⚡ Generate CXR", variant="primary", size="lg")
636
+
637
+ # ── Event wiring ──
638
+
639
+ # Live mask preview on any slider / mask change
640
+ slider_inputs = [mask_input, heart_slider, left_lung_slider, right_lung_slider]
641
+
642
+ mask_input.change(preview_mask, inputs=slider_inputs, outputs=mask_preview)
643
+ heart_slider.change(preview_mask, inputs=slider_inputs, outputs=mask_preview)
644
+ left_lung_slider.change(preview_mask, inputs=slider_inputs, outputs=mask_preview)
645
+ right_lung_slider.change(preview_mask, inputs=slider_inputs, outputs=mask_preview)
646
+
# Reset sliders
def reset_scales():
    """Return the neutral (1.0) value for all three organ-scale sliders."""
    return (1.0,) * 3
650
+
651
+ reset_btn.click(
652
+ reset_scales,
653
+ outputs=[heart_slider, left_lung_slider, right_lung_slider],
654
+ )
655
+
# Auto-adjust sliders when conditions change
_CONDITION_SCALE_MAP = {
    # condition_key: (heart_delta, lung_delta)
    "cardiomegaly": (+0.35, 0.0),
    "enlarged_cardiomediastinum": (+0.25, 0.0),
    "atelectasis": (0.0, -0.25),
    "pneumothorax": (0.0, -0.30),
    "pleural_effusion": (0.0, -0.20),
}
_SEVERITY_MULTIPLIER = {
    "(none)": 1.0,
    "mild": 0.6,
    "moderate": 1.0,
    "severe": 1.5,
}

def sync_sliders(conditions: list[str], severity: str):
    """Derive slider values from the selected conditions + severity.

    Both organs start at the neutral 1.0 scale; each selected condition
    contributes its (heart, lung) delta weighted by the severity multiplier.
    Results are clamped to the slider range [0.0, 2.0] and rounded to two
    decimals. Returns (heart, left_lung, right_lung) — both lungs share one
    value.
    """
    multiplier = _SEVERITY_MULTIPLIER.get(severity, 1.0)
    heart_total, lung_total = 1.0, 1.0
    for key in conditions or []:
        heart_delta, lung_delta = _CONDITION_SCALE_MAP.get(key, (0.0, 0.0))
        heart_total += heart_delta * multiplier
        lung_total += lung_delta * multiplier

    def _clamp(value: float) -> float:
        # Keep within the slider range and tidy to two decimals.
        return round(max(0.0, min(2.0, value)), 2)

    return _clamp(heart_total), _clamp(lung_total), _clamp(lung_total)
685
+
686
+ conditions_select.change(
687
+ sync_sliders,
688
+ inputs=[conditions_select, severity_select],
689
+ outputs=[heart_slider, left_lung_slider, right_lung_slider],
690
+ )
691
+ severity_select.change(
692
+ sync_sliders,
693
+ inputs=[conditions_select, severity_select],
694
+ outputs=[heart_slider, left_lung_slider, right_lung_slider],
695
+ )
696
+
697
+ # Prompt preview on config change
698
+ prompt_inputs = [conditions_select, severity_select, age_input, sex_select, view_select]
699
+
700
+ for inp in prompt_inputs:
701
+ inp.change(build_prompt_preview, inputs=prompt_inputs, outputs=prompt_preview)
702
+
703
+ # Generate
704
+ generate_btn.click(
705
+ generate_cxr,
706
+ inputs=[
707
+ mask_input,
708
+ heart_slider, left_lung_slider, right_lung_slider,
709
+ conditions_select, severity_select,
710
+ age_input, sex_select, view_select,
711
+ steps_input, cfg_input, seed_input,
712
+ preview_every_input,
713
+ ],
714
+ outputs=cxr_output,
715
+ )
716
+
717
+
# ---------------------------------------------------------------------------
# Launch (module-level for HuggingFace Spaces compatibility)
# ---------------------------------------------------------------------------
# Bug fix: Blocks.launch() takes no `theme=`/`css=` keyword arguments —
# passing them raises TypeError at startup. Styling belongs on the
# gr.Blocks(...) constructor instead.
demo.launch()
pyproject.toml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "synthcxr"
3
+ version = "0.1.0"
4
+ description = "Chest X-ray generation via Qwen-Image-Edit LoRA fine-tuning"
5
+ requires-python = ">=3.10.1"
6
+ dependencies = [
7
+ "diffsynth>=2.0.4",
8
+ "fastapi[standard]>=0.135.1",
9
+ "gradio>=6.8.0",
10
+ "python-multipart>=0.0.22",
11
+ "scipy",
12
+ "uvicorn[standard]>=0.41.0",
13
+ ]
14
+
15
+ [build-system]
16
+ requires = ["setuptools>=68"]
17
+ build-backend = "setuptools.build_meta"
18
+
19
+ [tool.setuptools.packages.find]
20
+ where = ["src"]
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio>=6.8.0
2
+ diffsynth>=2.0.4
3
+ spaces
4
+ scipy
5
+ Pillow
6
+ numpy
7
+ torch
scripts/models/qwen_image_edit_chexpert_lora/epoch-2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fef90b53ae95c9628efe14b0919f7be7e291ec9f80677a3f2ed509ebccca1c05
3
+ size 472047184
scripts/models/qwen_image_edit_chexpert_lora/latest_checkpoint.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"path": "./models/qwen_image_edit_chexpert_lora/checkpoint-step233240", "epoch_id": 4, "global_step": 233240}
static/sample_masks/sample_1.png ADDED
static/sample_masks/sample_2.png ADDED
static/sample_masks/sample_3.png ADDED
synthcxr/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """SynthCXR: Chest X-ray generation via Qwen-Image-Edit LoRA fine-tuning."""
synthcxr/constants.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Shared constants for SynthCXR: disease labels, condition maps, severity modifiers."""

from __future__ import annotations

# CheXpert label column names -> natural-language descriptions used in prompts.
# Used by both dataset preparation and inference scripts.
LABEL_TEXT: dict[str, str] = {
    "Enlarged Cardiomediastinum": "enlarged cardiomediastinum",
    "Cardiomegaly": "cardiomegaly",
    "Lung Opacity": "diffuse lung opacity",
    "Lung Lesion": "discrete lung lesion",
    "Edema": "pulmonary edema",
    "Consolidation": "parenchymal consolidation",
    "Pneumonia": "findings compatible with pneumonia",
    "Atelectasis": "atelectasis",
    "Pneumothorax": "pneumothorax",
    "Pleural Effusion": "pleural effusion",
    "Pleural Other": "other pleural abnormality",
    "Fracture": "possible fracture",
    "Support Devices": "support devices in place",
}

# Snake_case keys for config files -> natural-language descriptions.
# Mirrors LABEL_TEXT's values, keyed by the snake_case identifiers used in
# configs and the web UI (e.g. app.py's CONDITION_CHOICES).
KNOWN_CONDITIONS: dict[str, str] = {
    "enlarged_cardiomediastinum": "enlarged cardiomediastinum",
    "cardiomegaly": "cardiomegaly",
    "lung_opacity": "diffuse lung opacity",
    "lung_lesion": "discrete lung lesion",
    "edema": "pulmonary edema",
    "consolidation": "parenchymal consolidation",
    "pneumonia": "findings compatible with pneumonia",
    "atelectasis": "atelectasis",
    "pneumothorax": "pneumothorax",
    "pleural_effusion": "pleural effusion",
    "pleural_other": "other pleural abnormality",
    "fracture": "possible fracture",
    "support_devices": "support devices in place",
}

# Severity / size qualifiers -> the exact words inserted into prompt text.
# Keys use snake_case (e.g. "very_small"); values are the prompt wording.
SEVERITY_MODIFIERS: dict[str, str] = {
    "mild": "mild",
    "moderate": "moderate",
    "severe": "severe",
    "small": "small",
    "large": "large",
    "very_small": "very small",
    "very_large": "very large",
    "minimal": "minimal",
    "significant": "significant",
}

# HuggingFace Hub model identifiers (weights are downloaded at runtime).
DEFAULT_MODEL_ID = "Qwen/Qwen-Image-Edit-2511"
TEXT_ENCODER_MODEL_ID = "Qwen/Qwen-Image"
PROCESSOR_MODEL_ID = "Qwen/Qwen-Image-Edit"
synthcxr/mask_utils.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Mask manipulation: scaling organ regions and resolving overlaps."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ from PIL import Image
9
+ from scipy import ndimage
10
+ from scipy.ndimage import map_coordinates
11
+
12
+
13
def resolve_overlaps(
    mask: np.ndarray,
    priority: tuple[int, int, int] = (2, 0, 1),
    threshold: int = 10,
) -> np.ndarray:
    """Assign overlapping pixels to the highest-priority channel.

    Default priority: heart (2) > left lung (0) > right lung (1).

    Args:
        mask: (H, W, 3) array; a channel is "active" at a pixel when its
            value exceeds *threshold*.
        priority: Channel indices ordered from highest to lowest priority.
        threshold: Activation threshold per channel.

    Returns:
        A copy of *mask* where, at each pixel with more than one active
        channel, all active channels except the highest-priority one are
        zeroed.  Inactive (sub-threshold) channel values are left untouched.
    """
    result = mask.copy()
    active = mask > threshold
    overlap = active.sum(axis=2) > 1
    if not overlap.any():
        return result

    # Vectorized replacement for the per-pixel Python loop: rank[ch] is the
    # position of channel ch in the priority tuple (lower = higher priority).
    rank = np.empty(3, dtype=np.int64)
    for pos, ch in enumerate(priority):
        rank[ch] = pos
    # Inactive channels get an out-of-range rank (3) so argmin always picks
    # an active channel at overlap pixels.
    ranked = np.where(active, rank[None, None, :], 3)
    best = ranked.argmin(axis=2)

    # Zero every active, non-best channel at overlapping pixels only.
    channels = np.arange(3)[None, None, :]
    suppress = active & overlap[:, :, None] & (channels != best[:, :, None])
    result[suppress] = 0
    return result
35
+
36
+
37
def scale_mask_channel(
    mask: np.ndarray,
    channel: int,
    scale_factor: float,
    threshold: int = 10,
) -> np.ndarray:
    """Scale a single channel's region around its centroid.

    ``channel``: 0 = left lung (red), 1 = right lung (green), 2 = heart (blue).

    Args:
        mask: (H, W, 3) uint8 conditioning mask.
        channel: Index of the channel to rescale.
        scale_factor: > 1 enlarges the region, < 1 shrinks it.
        threshold: Intensities above this count as part of the region when
            locating the centroid.

    Returns:
        A copy of *mask* with the chosen channel rescaled; the other
        channels are untouched.  If the channel has no pixel above the
        threshold, the copy is returned unchanged.
    """
    result = mask.copy()
    channel_data = mask[:, :, channel]
    binary = channel_data > threshold
    if not binary.any():
        # Empty region: nothing to scale around.
        return result

    # Inverse mapping: for each *output* pixel, sample the input at the
    # coordinate that would land on it, growing/shrinking the region by
    # scale_factor around its centre of mass.
    cy, cx = ndimage.center_of_mass(binary)
    h, w = mask.shape[:2]
    y_coords, x_coords = np.mgrid[0:h, 0:w]
    y_t = ((y_coords - cy) / scale_factor + cy).astype(np.float32)
    x_t = ((x_coords - cx) / scale_factor + cx).astype(np.float32)

    scaled = map_coordinates(
        channel_data.astype(np.float32),
        [y_t, x_t],
        order=1,  # bilinear interpolation keeps soft mask edges
        mode="constant",
        cval=0,
    )
    # Assignment overwrites the whole channel, so no pre-zeroing is needed
    # (the original zeroed the channel first — a dead store).
    result[:, :, channel] = np.clip(scaled, 0, 255).astype(np.uint8)
    return result
69
+
70
+
71
def modify_mask(
    input_path: Path,
    output_path: Path,
    heart_scale: float = 1.0,
    left_lung_scale: float = 1.0,
    right_lung_scale: float = 1.0,
) -> None:
    """Load a conditioning mask, apply per-organ scale factors, and save.

    Channels: 0 = left lung (red), 1 = right lung (green), 2 = heart (blue).
    After scaling, overlapping pixels are resolved with heart > left lung >
    right lung priority.
    """
    with Image.open(input_path) as img:
        pixels = np.array(img.convert("RGB"))

    # Apply each requested rescale; unit factors are skipped entirely.
    for ch, factor in ((0, left_lung_scale), (1, right_lung_scale), (2, heart_scale)):
        if factor != 1.0:
            pixels = scale_mask_channel(pixels, channel=ch, scale_factor=factor)

    pixels = resolve_overlaps(pixels, priority=(2, 0, 1))

    output_path.parent.mkdir(parents=True, exist_ok=True)
    Image.fromarray(pixels).save(output_path)
    print(f"[INFO] Saved modified mask to {output_path}")
synthcxr/pipeline.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pipeline loading, LoRA weight management, and image I/O helpers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Sequence
8
+
9
+ import numpy as np
10
+ import torch
11
+ from PIL import Image
12
+
13
+ from diffsynth.core import ModelConfig
14
+ from diffsynth.pipelines.qwen_image import QwenImagePipeline
15
+
16
+ from .constants import DEFAULT_MODEL_ID, PROCESSOR_MODEL_ID, TEXT_ENCODER_MODEL_ID
17
+ from .mask_utils import resolve_overlaps, scale_mask_channel
18
+
19
+
20
@dataclass
class SampleSpec:
    """A single validation/inference sample."""

    # Text prompt passed to the diffusion pipeline.
    prompt: str
    # Conditioning mask image path(s) for this sample.
    mask_paths: list[Path]
    # Filesystem-safe name used when writing outputs (see export helpers).
    identifier: str
    # Ground-truth CXR image for side-by-side comparison; None when absent.
    image_path: Path | None
    # Presumably the prompt before any rewriting; empty if not recorded
    # — TODO confirm against the caller that populates it.
    original_prompt: str = ""
29
+
30
+
31
def load_pipeline(
    device: str,
    torch_dtype: torch.dtype,
    model_id: str = DEFAULT_MODEL_ID,
    vram_limit: float | None = None,
) -> QwenImagePipeline:
    """Instantiate a ``QwenImagePipeline``, downloading weights from HF Hub.

    The DiT weights come from *model_id*; the text encoder and VAE are
    pulled from the base Qwen-Image repo, and the processor from the
    Qwen-Image-Edit repo.
    """
    # (repo id, file pattern) pairs for the three weight groups.
    weight_sources = [
        (model_id, "transformer/diffusion_pytorch_model*.safetensors"),
        (TEXT_ENCODER_MODEL_ID, "text_encoder/model*.safetensors"),
        (TEXT_ENCODER_MODEL_ID, "vae/diffusion_pytorch_model.safetensors"),
    ]
    return QwenImagePipeline.from_pretrained(
        torch_dtype=torch_dtype,
        device=device,
        model_configs=[
            ModelConfig(model_id=repo, origin_file_pattern=pattern)
            for repo, pattern in weight_sources
        ],
        tokenizer_config=ModelConfig(
            model_id=TEXT_ENCODER_MODEL_ID,
            origin_file_pattern="tokenizer/",
        ),
        processor_config=ModelConfig(
            model_id=PROCESSOR_MODEL_ID,
            origin_file_pattern="processor/",
        ),
        vram_limit=vram_limit,
    )
69
+
70
+
71
def load_lora_weights(pipe: QwenImagePipeline, checkpoint: Path) -> None:
    """Load a LoRA checkpoint into an existing pipeline.

    Clears any previously loaded adapter first, so repeated calls swap
    checkpoints instead of stacking them.
    """
    pipe.clear_lora()
    # load_lora targets the DiT; the checkpoint path is passed as a string.
    pipe.load_lora(pipe.dit, lora_config=str(checkpoint))
75
+
76
+
77
def load_edit_images(
    paths: Sequence[Path],
    *,
    heart_scale: float = 1.0,
    left_lung_scale: float = 1.0,
    right_lung_scale: float = 1.0,
) -> Image.Image | list[Image.Image]:
    """Load conditioning mask image(s), optionally rescaling organ regions.

    Returns a single image when exactly one path is given, otherwise a list.
    Channels: 0 = left lung (red), 1 = right lung (green), 2 = heart (blue).
    """
    # Only take the (comparatively slow) numpy path when some scale differs
    # from 1.0; plain loading suffices otherwise.
    rescale = any(
        factor != 1.0
        for factor in (heart_scale, left_lung_scale, right_lung_scale)
    )
    loaded: list[Image.Image] = []
    for mask_path in paths:
        with Image.open(mask_path) as img:
            rgb = img.convert("RGB")
        if not rescale:
            loaded.append(rgb)
            continue
        arr = np.array(rgb)
        for ch, factor in (
            (2, heart_scale),
            (0, left_lung_scale),
            (1, right_lung_scale),
        ):
            if factor != 1.0:
                arr = scale_mask_channel(arr, channel=ch, scale_factor=factor)
        # Heart wins over left lung, which wins over right lung.
        arr = resolve_overlaps(arr, priority=(2, 0, 1))
        loaded.append(Image.fromarray(arr))
    return loaded[0] if len(loaded) == 1 else loaded
104
+
105
+
106
def export_original_images(
    samples: Sequence[SampleSpec], output_dir: Path
) -> None:
    """Copy original CXR images into *output_dir*/original/ for comparison."""
    target_dir = output_dir / "original"
    target_dir.mkdir(parents=True, exist_ok=True)
    for spec in samples:
        # Samples without a ground-truth image have nothing to export.
        if spec.image_path is None:
            continue
        destination = target_dir / f"{spec.identifier}.png"
        # Skip files already exported on a previous run.
        if destination.exists():
            continue
        with Image.open(spec.image_path) as img:
            img.convert("RGB").save(destination)
        print(f"[INFO] Saved original to {destination}")
synthcxr/prompt.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Prompt builders for conditional inference."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+
7
+ from .constants import KNOWN_CONDITIONS, SEVERITY_MODIFIERS
8
+
9
+
10
@dataclass
class ConditionConfig:
    """Configuration for a single inference run with specific conditions."""

    # Human-readable name for this run (used to label outputs).
    name: str
    # Snake_case condition keys (see KNOWN_CONDITIONS); unknown keys are
    # passed through verbatim by the prompt builder.
    conditions: list[str] = field(default_factory=list)
    # Optional patient demographics folded into the prompt.
    age: int | None = None
    sex: str | None = None
    # Radiographic view; upper-cased by the prompt builder, defaults to AP.
    view: str = "AP"
    # If set, used verbatim and all structured fields above are ignored.
    custom_prompt: str | None = None
    # Severity key (see SEVERITY_MODIFIERS); qualifies the first condition.
    severity: str | None = None
    # Per-organ mask scale factors (1.0 = unchanged).
    heart_scale: float = 1.0
    left_lung_scale: float = 1.0
    right_lung_scale: float = 1.0
24
+
25
+
26
@dataclass
class InferenceConfig:
    """Top-level configuration for the condition-inference script."""

    # Number of images to generate per condition config.
    num_samples: int = 10
    # Diffusion sampling steps per image.
    num_steps: int = 50
    # Output resolution in pixels.
    height: int = 512
    width: int = 512
    # Guidance scale — presumably classifier-free guidance; confirm against
    # the pipeline's sampler arguments.
    cfg_scale: float = 4.0
    # Base random seed for reproducible sampling.
    seed: int = 0
    # One entry per condition/run to generate.
    conditions: list[ConditionConfig] = field(default_factory=list)
37
+
38
+
39
def build_condition_prompt(condition: ConditionConfig) -> str:
    """Build a CheXpert-style prompt from a ``ConditionConfig``.

    A non-empty ``custom_prompt`` short-circuits everything else.  The
    severity modifier, when recognized, qualifies only the first pathology.
    """
    if condition.custom_prompt:
        return condition.custom_prompt

    view = condition.view.upper() if condition.view else "AP"

    # Demographics clause, degrading gracefully as fields are omitted:
    # "a 63-year-old female patient" -> "a 63-year-old patient" -> "a patient".
    descriptors: list[str] = []
    if condition.age:
        descriptors.append(f"{condition.age}-year-old")
    if condition.sex:
        descriptors.append(condition.sex.lower())
    demographics = " ".join(["a", *descriptors, "patient"])

    # Map condition keys to natural language; unknown keys pass through.
    findings: list[str] = []
    modifier = SEVERITY_MODIFIERS.get(condition.severity or "", "")
    for key in condition.conditions:
        text = KNOWN_CONDITIONS.get(key.lower(), key)
        if modifier and not findings:
            # Severity qualifies only the first finding.
            text = f"{modifier} {text}"
        findings.append(text)

    if findings:
        pathology_str = f"with {', '.join(findings)}"
    else:
        pathology_str = "with no significant abnormality"

    return (
        f"frontal {view} chest radiograph of {demographics} {pathology_str}. "
        "The conditioning mask image provides three channels "
        "(red=left lung, green=right lung, blue=heart). "
        "Reconstruct a CheXpert-style chest X-ray that aligns "
        "with the segmentation and follows the described pathology."
    )
synthcxr/utils.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Small shared helpers used across scripts."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Sequence
7
+
8
+
9
def resolve_path(base: Path, maybe_relative: str) -> Path:
    """Return *maybe_relative* as an absolute path, resolved against *base*.

    Absolute inputs (after ``~`` expansion) are returned as-is, without
    symlink resolution; relative inputs are joined to *base* and resolved.
    """
    candidate = Path(maybe_relative).expanduser()
    return candidate if candidate.is_absolute() else (base / candidate).resolve()
15
+
16
+
17
def build_identifier(
    record_path: str | None,
    fallback_paths: Sequence[str],
    sample_idx: int,
) -> str:
    """Build a filesystem-safe identifier from a metadata record.

    Prefers *record_path*, then the first fallback path, then a generic
    per-index name.  Up to the last four path components are joined with
    underscores, with dots replaced so the result is extension-free.
    """
    if record_path:
        source = record_path
    elif fallback_paths:
        source = fallback_paths[0]
    else:
        source = f"sample_{sample_idx}"
    path = Path(source)
    pieces = [piece.replace(".", "-") for piece in path.parts[-4:]]
    slug = "_".join(pieces) if pieces else path.stem
    return f"{sample_idx:03d}_{slug}"