Spaces:
Running
Running
Professional UI overhaul: custom CSS, badges, human-readable gallery, auto-load, section headings
Browse files
app.py
CHANGED
|
@@ -48,6 +48,90 @@ NO_GPU_MSG = (
|
|
| 48 |
"or run the app locally with a GPU: `python app.py`"
|
| 49 |
)
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
# ββ Lazy imports (avoid crash if no GPU) βββββββββββββββββββββββββββββ
|
| 52 |
_model_cache = {"model": None, "uni_model": None, "spatial_pool_size": 32}
|
| 53 |
|
|
@@ -188,13 +272,13 @@ def _generate_all_gpu(image, guidance_scale):
|
|
| 188 |
def generate_single_stain(image, stain, guidance_scale):
|
| 189 |
"""Wrapper with GPU availability check."""
|
| 190 |
if image is None:
|
| 191 |
-
return None, "
|
| 192 |
if not GPU_AVAILABLE and not HAS_SPACES:
|
| 193 |
return None, NO_GPU_MSG
|
| 194 |
try:
|
| 195 |
t0 = time.time()
|
| 196 |
result = _generate_single_gpu(image, stain, guidance_scale)
|
| 197 |
-
return result, f"{time.time() - t0:.2f}s"
|
| 198 |
except RuntimeError as e:
|
| 199 |
if "NVIDIA" in str(e) or "CUDA" in str(e) or "cuda" in str(e):
|
| 200 |
return None, NO_GPU_MSG
|
|
@@ -204,13 +288,13 @@ def generate_single_stain(image, stain, guidance_scale):
|
|
| 204 |
def generate_all_stains(image, guidance_scale):
|
| 205 |
"""Wrapper with GPU availability check."""
|
| 206 |
if image is None:
|
| 207 |
-
return None, None, None, None, None, "
|
| 208 |
if not GPU_AVAILABLE and not HAS_SPACES:
|
| 209 |
return None, None, None, None, None, NO_GPU_MSG
|
| 210 |
try:
|
| 211 |
t0 = time.time()
|
| 212 |
he_pil, results = _generate_all_gpu(image, guidance_scale)
|
| 213 |
-
elapsed = f"{time.time() - t0:.2f}s"
|
| 214 |
return he_pil, results["HER2"], results["Ki67"], results["ER"], results["PR"], elapsed
|
| 215 |
except RuntimeError as e:
|
| 216 |
if "NVIDIA" in str(e) or "CUDA" in str(e) or "cuda" in str(e):
|
|
@@ -228,10 +312,30 @@ def load_gallery():
|
|
| 228 |
return json.load(f)
|
| 229 |
|
| 230 |
|
| 231 |
-
def
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
base = GALLERY_DIR / "images"
|
| 236 |
he = Image.open(base / entry["he"]).convert("RGB") if "he" in entry else None
|
| 237 |
gt = Image.open(base / entry["gt"]).convert("RGB") if "gt" in entry else None
|
|
@@ -239,42 +343,121 @@ def show_gallery(name, gallery):
|
|
| 239 |
gen_ki67 = Image.open(base / entry["gen_ki67"]).convert("RGB") if "gen_ki67" in entry else None
|
| 240 |
gen_er = Image.open(base / entry["gen_er"]).convert("RGB") if "gen_er" in entry else None
|
| 241 |
gen_pr = Image.open(base / entry["gen_pr"]).convert("RGB") if "gen_pr" in entry else None
|
| 242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
|
| 244 |
|
| 245 |
# ββ Build Gradio App βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 246 |
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
with gr.Row():
|
| 267 |
with gr.Column(scale=1):
|
| 268 |
-
input_image = gr.Image(type="pil", label="Upload H&E Image", height=
|
| 269 |
-
stain_choice = gr.Radio(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
guidance_slider = gr.Slider(
|
| 271 |
minimum=1.0, maximum=3.0, step=0.1, value=1.0,
|
| 272 |
-
label="Guidance Scale
|
|
|
|
| 273 |
)
|
| 274 |
-
generate_btn = gr.Button("Generate", variant="primary")
|
| 275 |
gen_time = gr.Textbox(label="Status", interactive=False)
|
| 276 |
with gr.Column(scale=1):
|
| 277 |
-
output_image = gr.Image(type="pil", label="Generated IHC", height=
|
| 278 |
|
| 279 |
generate_btn.click(
|
| 280 |
fn=generate_single_stain,
|
|
@@ -282,30 +465,36 @@ with gr.Blocks(title="UNIStainNet β Virtual IHC Staining") as demo:
|
|
| 282 |
outputs=[output_image, gen_time],
|
| 283 |
)
|
| 284 |
|
| 285 |
-
# ββ Tab
|
| 286 |
-
with gr.Tab("Cross-Stain Comparison"):
|
| 287 |
if not GPU_AVAILABLE and not HAS_SPACES:
|
| 288 |
-
gr.
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
|
|
|
| 295 |
with gr.Row():
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
|
|
|
|
| 303 |
with gr.Row():
|
| 304 |
-
cross_he_out = gr.Image(type="pil", label="H&E Input", height=
|
| 305 |
-
cross_her2 = gr.Image(type="pil", label="HER2", height=
|
| 306 |
-
cross_ki67 = gr.Image(type="pil", label="Ki67", height=
|
| 307 |
-
cross_er = gr.Image(type="pil", label="ER", height=
|
| 308 |
-
cross_pr = gr.Image(type="pil", label="PR", height=
|
| 309 |
|
| 310 |
cross_btn.click(
|
| 311 |
fn=generate_all_stains,
|
|
@@ -313,69 +502,76 @@ with gr.Blocks(title="UNIStainNet β Virtual IHC Staining") as demo:
|
|
| 313 |
outputs=[cross_he_out, cross_her2, cross_ki67, cross_er, cross_pr, cross_time],
|
| 314 |
)
|
| 315 |
|
| 316 |
-
# ββ Tab 3: Gallery βββββββββββββββββββββββββββββββββββββββββββ
|
| 317 |
-
with gr.Tab("Gallery"):
|
| 318 |
-
if not gallery_names:
|
| 319 |
-
gr.Markdown("No pre-computed gallery available.")
|
| 320 |
-
else:
|
| 321 |
-
gr.Markdown(
|
| 322 |
-
"Pre-computed examples β no GPU required. "
|
| 323 |
-
"Select an example to view the H&E input, ground truth, and generated IHC stains."
|
| 324 |
-
)
|
| 325 |
-
gallery_dropdown = gr.Dropdown(
|
| 326 |
-
choices=gallery_names,
|
| 327 |
-
value=gallery_names[0] if gallery_names else None,
|
| 328 |
-
label="Select Example",
|
| 329 |
-
)
|
| 330 |
-
with gr.Row():
|
| 331 |
-
gal_he = gr.Image(type="pil", label="H&E Input", height=300)
|
| 332 |
-
gal_gt = gr.Image(type="pil", label="Ground Truth IHC", height=300)
|
| 333 |
-
with gr.Row():
|
| 334 |
-
gal_her2 = gr.Image(type="pil", label="Generated HER2", height=300)
|
| 335 |
-
gal_ki67 = gr.Image(type="pil", label="Generated Ki67", height=300)
|
| 336 |
-
gal_er = gr.Image(type="pil", label="Generated ER", height=300)
|
| 337 |
-
gal_pr = gr.Image(type="pil", label="Generated PR", height=300)
|
| 338 |
-
|
| 339 |
-
gallery_dropdown.change(
|
| 340 |
-
fn=lambda name: show_gallery(name, gallery),
|
| 341 |
-
inputs=[gallery_dropdown],
|
| 342 |
-
outputs=[gal_he, gal_gt, gal_her2, gal_ki67, gal_er, gal_pr],
|
| 343 |
-
)
|
| 344 |
-
|
| 345 |
# ββ Tab 4: About βββββββββββββββββββββββββββββββββββββββββββββ
|
| 346 |
-
with gr.Tab("About"):
|
| 347 |
gr.Markdown(
|
| 348 |
"""
|
| 349 |
-
|
|
|
|
|
|
|
| 350 |
|
| 351 |
-
|
| 352 |
-
|
|
|
|
|
|
|
| 353 |
|
| 354 |
-
|
| 355 |
-
- Dense UNI spatial conditioning (32x32 = 1,024 tokens)
|
| 356 |
-
- Misalignment-aware loss suite for consecutive-section training pairs
|
| 357 |
-
- Single unified model serves 4 IHC markers (HER2, Ki67, ER, PR)
|
| 358 |
-
- 42M generator parameters, single forward pass inference
|
| 359 |
|
| 360 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
|
| 362 |
| Component | Details |
|
| 363 |
|-----------|---------|
|
| 364 |
-
| Generator | SPADE-UNet with UNI spatial conditioning + FiLM stain embeddings |
|
| 365 |
-
|
|
| 366 |
-
|
|
| 367 |
-
| Parameters | 42M
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
|
| 373 |
-
|
|
| 374 |
-
|
|
| 375 |
-
|
|
| 376 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
"""
|
| 378 |
)
|
| 379 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
if __name__ == "__main__":
|
| 381 |
demo.launch(theme=gr.themes.Soft())
|
|
|
|
| 48 |
"or run the app locally with a GPU: `python app.py`"
|
| 49 |
)
|
| 50 |
|
| 51 |
+
# ββ Custom CSS βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 52 |
+
CUSTOM_CSS = """
|
| 53 |
+
.header-container {
|
| 54 |
+
text-align: center;
|
| 55 |
+
padding: 1.5rem 1rem 0.5rem 1rem;
|
| 56 |
+
}
|
| 57 |
+
.header-container h1 {
|
| 58 |
+
font-size: 1.8rem;
|
| 59 |
+
font-weight: 700;
|
| 60 |
+
margin-bottom: 0.3rem;
|
| 61 |
+
}
|
| 62 |
+
.header-subtitle {
|
| 63 |
+
text-align: center;
|
| 64 |
+
color: #555;
|
| 65 |
+
font-size: 0.95rem;
|
| 66 |
+
margin-bottom: 0.8rem;
|
| 67 |
+
}
|
| 68 |
+
.badge-row {
|
| 69 |
+
display: flex;
|
| 70 |
+
justify-content: center;
|
| 71 |
+
gap: 0.6rem;
|
| 72 |
+
flex-wrap: wrap;
|
| 73 |
+
margin-bottom: 1rem;
|
| 74 |
+
}
|
| 75 |
+
.badge {
|
| 76 |
+
display: inline-block;
|
| 77 |
+
padding: 0.25rem 0.75rem;
|
| 78 |
+
border-radius: 999px;
|
| 79 |
+
font-size: 0.8rem;
|
| 80 |
+
font-weight: 600;
|
| 81 |
+
background: #e8eaf6;
|
| 82 |
+
color: #3949ab;
|
| 83 |
+
}
|
| 84 |
+
.badge-green {
|
| 85 |
+
background: #e8f5e9;
|
| 86 |
+
color: #2e7d32;
|
| 87 |
+
}
|
| 88 |
+
.badge-purple {
|
| 89 |
+
background: #f3e5f5;
|
| 90 |
+
color: #7b1fa2;
|
| 91 |
+
}
|
| 92 |
+
.badge-orange {
|
| 93 |
+
background: #fff3e0;
|
| 94 |
+
color: #e65100;
|
| 95 |
+
}
|
| 96 |
+
.gpu-notice {
|
| 97 |
+
background: #fff8e1;
|
| 98 |
+
border: 1px solid #ffe082;
|
| 99 |
+
border-radius: 8px;
|
| 100 |
+
padding: 0.75rem 1rem;
|
| 101 |
+
margin-bottom: 1rem;
|
| 102 |
+
font-size: 0.9rem;
|
| 103 |
+
color: #6d4c00;
|
| 104 |
+
}
|
| 105 |
+
.section-heading {
|
| 106 |
+
font-size: 1.05rem;
|
| 107 |
+
font-weight: 600;
|
| 108 |
+
color: #333;
|
| 109 |
+
margin-bottom: 0.5rem;
|
| 110 |
+
border-bottom: 2px solid #e0e0e0;
|
| 111 |
+
padding-bottom: 0.3rem;
|
| 112 |
+
}
|
| 113 |
+
.gallery-info {
|
| 114 |
+
background: #f5f5f5;
|
| 115 |
+
border-radius: 8px;
|
| 116 |
+
padding: 0.6rem 1rem;
|
| 117 |
+
margin-bottom: 0.8rem;
|
| 118 |
+
font-size: 0.88rem;
|
| 119 |
+
color: #555;
|
| 120 |
+
}
|
| 121 |
+
.about-section {
|
| 122 |
+
max-width: 800px;
|
| 123 |
+
margin: 0 auto;
|
| 124 |
+
padding: 1rem;
|
| 125 |
+
}
|
| 126 |
+
footer {
|
| 127 |
+
text-align: center;
|
| 128 |
+
padding: 1rem;
|
| 129 |
+
color: #999;
|
| 130 |
+
font-size: 0.8rem;
|
| 131 |
+
}
|
| 132 |
+
"""
|
| 133 |
+
|
| 134 |
+
|
| 135 |
# ββ Lazy imports (avoid crash if no GPU) βββββββββββββββββββββββββββββ
|
| 136 |
_model_cache = {"model": None, "uni_model": None, "spatial_pool_size": 32}
|
| 137 |
|
|
|
|
| 272 |
def generate_single_stain(image, stain, guidance_scale):
|
| 273 |
"""Wrapper with GPU availability check."""
|
| 274 |
if image is None:
|
| 275 |
+
return None, "Please upload an H&E image first."
|
| 276 |
if not GPU_AVAILABLE and not HAS_SPACES:
|
| 277 |
return None, NO_GPU_MSG
|
| 278 |
try:
|
| 279 |
t0 = time.time()
|
| 280 |
result = _generate_single_gpu(image, stain, guidance_scale)
|
| 281 |
+
return result, f"Generated in {time.time() - t0:.2f}s"
|
| 282 |
except RuntimeError as e:
|
| 283 |
if "NVIDIA" in str(e) or "CUDA" in str(e) or "cuda" in str(e):
|
| 284 |
return None, NO_GPU_MSG
|
|
|
|
| 288 |
def generate_all_stains(image, guidance_scale):
|
| 289 |
"""Wrapper with GPU availability check."""
|
| 290 |
if image is None:
|
| 291 |
+
return None, None, None, None, None, "Please upload an H&E image first."
|
| 292 |
if not GPU_AVAILABLE and not HAS_SPACES:
|
| 293 |
return None, None, None, None, None, NO_GPU_MSG
|
| 294 |
try:
|
| 295 |
t0 = time.time()
|
| 296 |
he_pil, results = _generate_all_gpu(image, guidance_scale)
|
| 297 |
+
elapsed = f"Generated all 4 stains in {time.time() - t0:.2f}s"
|
| 298 |
return he_pil, results["HER2"], results["Ki67"], results["ER"], results["PR"], elapsed
|
| 299 |
except RuntimeError as e:
|
| 300 |
if "NVIDIA" in str(e) or "CUDA" in str(e) or "cuda" in str(e):
|
|
|
|
| 312 |
return json.load(f)
|
| 313 |
|
| 314 |
|
| 315 |
+
def _make_gallery_label(key, entry):
|
| 316 |
+
"""Create a human-readable label for a gallery entry."""
|
| 317 |
+
source = entry.get("source", "")
|
| 318 |
+
gt_stain = entry.get("gt_stain", "")
|
| 319 |
+
# Extract a short sample ID from the key
|
| 320 |
+
parts = key.split("_")
|
| 321 |
+
if source == "BCI":
|
| 322 |
+
# e.g. BCI_HER2_3+_00277_test_3+ -> "BCI - HER2 3+ (#00277)"
|
| 323 |
+
her2_class = parts[2] if len(parts) > 2 else ""
|
| 324 |
+
sample_id = parts[3] if len(parts) > 3 else ""
|
| 325 |
+
return f"BCI - HER2 {her2_class} (#{sample_id})"
|
| 326 |
+
else:
|
| 327 |
+
# e.g. MIST_Ki67_10M2102916_10_20 -> "MIST - Ki67 (10M2102916)"
|
| 328 |
+
stain = parts[1] if len(parts) > 1 else ""
|
| 329 |
+
sample_id = parts[2] if len(parts) > 2 else ""
|
| 330 |
+
return f"MIST - {stain} ({sample_id})"
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def show_gallery(display_name, gallery, name_map):
|
| 334 |
+
"""Show a gallery example by its display name."""
|
| 335 |
+
key = name_map.get(display_name)
|
| 336 |
+
if not key or not gallery or key not in gallery:
|
| 337 |
+
return None, None, None, None, None, None, ""
|
| 338 |
+
entry = gallery[key]
|
| 339 |
base = GALLERY_DIR / "images"
|
| 340 |
he = Image.open(base / entry["he"]).convert("RGB") if "he" in entry else None
|
| 341 |
gt = Image.open(base / entry["gt"]).convert("RGB") if "gt" in entry else None
|
|
|
|
| 343 |
gen_ki67 = Image.open(base / entry["gen_ki67"]).convert("RGB") if "gen_ki67" in entry else None
|
| 344 |
gen_er = Image.open(base / entry["gen_er"]).convert("RGB") if "gen_er" in entry else None
|
| 345 |
gen_pr = Image.open(base / entry["gen_pr"]).convert("RGB") if "gen_pr" in entry else None
|
| 346 |
+
|
| 347 |
+
source = entry.get("source", "Unknown")
|
| 348 |
+
gt_stain = entry.get("gt_stain", "Unknown")
|
| 349 |
+
info = f"**Dataset:** {source} | **Ground truth stain:** {gt_stain}"
|
| 350 |
+
return he, gt, gen_her2, gen_ki67, gen_er, gen_pr, info
|
| 351 |
|
| 352 |
|
| 353 |
# ββ Build Gradio App βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 354 |
|
| 355 |
+
gallery_data = load_gallery()
|
| 356 |
+
gallery_name_map = {} # display_name -> key
|
| 357 |
+
gallery_display_names = []
|
| 358 |
+
if gallery_data:
|
| 359 |
+
for key, entry in gallery_data.items():
|
| 360 |
+
label = _make_gallery_label(key, entry)
|
| 361 |
+
gallery_name_map[label] = key
|
| 362 |
+
gallery_display_names.append(label)
|
| 363 |
+
|
| 364 |
+
with gr.Blocks(title="UNIStainNet β Virtual IHC Staining", css=CUSTOM_CSS) as demo:
|
| 365 |
+
|
| 366 |
+
# ββ Header ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 367 |
+
gr.HTML("""
|
| 368 |
+
<div class="header-container">
|
| 369 |
+
<h1>UNIStainNet</h1>
|
| 370 |
+
<p style="font-size:1.1rem; color:#555; margin-top:0.2rem;">
|
| 371 |
+
Foundation-Model-Guided Virtual Staining of H&E to IHC
|
| 372 |
+
</p>
|
| 373 |
+
</div>
|
| 374 |
+
<div class="header-subtitle">
|
| 375 |
+
Translate H&E histopathology images into immunohistochemistry (IHC) stains
|
| 376 |
+
for breast cancer biomarkers using a single unified deep learning model.
|
| 377 |
+
</div>
|
| 378 |
+
<div class="badge-row">
|
| 379 |
+
<span class="badge">42M Parameters</span>
|
| 380 |
+
<span class="badge-green badge">4 IHC Stains</span>
|
| 381 |
+
<span class="badge-purple badge">UNI Foundation Model</span>
|
| 382 |
+
<span class="badge-orange badge">Single Forward Pass</span>
|
| 383 |
+
</div>
|
| 384 |
+
""")
|
| 385 |
+
|
| 386 |
+
# ββ Tab 1: Gallery (default β works without GPU) ββββββββββββ
|
| 387 |
+
with gr.Tab("Gallery", id="gallery"):
|
| 388 |
+
if not gallery_display_names:
|
| 389 |
+
gr.Markdown("No pre-computed gallery available.")
|
| 390 |
+
else:
|
| 391 |
+
gr.HTML("""
|
| 392 |
+
<div class="gallery-info">
|
| 393 |
+
Browse pre-computed virtual staining results β <strong>no GPU required</strong>.
|
| 394 |
+
Each example shows the H&E input, ground truth IHC, and all 4 generated stains from our unified model.
|
| 395 |
+
</div>
|
| 396 |
+
""")
|
| 397 |
+
with gr.Row():
|
| 398 |
+
gallery_dropdown = gr.Dropdown(
|
| 399 |
+
choices=gallery_display_names,
|
| 400 |
+
value=gallery_display_names[0] if gallery_display_names else None,
|
| 401 |
+
label="Select Example",
|
| 402 |
+
scale=3,
|
| 403 |
+
)
|
| 404 |
+
gallery_info_box = gr.Markdown(value="", scale=2)
|
| 405 |
+
|
| 406 |
+
gr.HTML('<p class="section-heading">Input & Ground Truth</p>')
|
| 407 |
+
with gr.Row():
|
| 408 |
+
gal_he = gr.Image(type="pil", label="H&E Input", height=280, show_download_button=False)
|
| 409 |
+
gal_gt = gr.Image(type="pil", label="Ground Truth IHC", height=280, show_download_button=False)
|
| 410 |
+
|
| 411 |
+
gr.HTML('<p class="section-heading">Generated IHC Stains (all from the same H&E)</p>')
|
| 412 |
+
with gr.Row():
|
| 413 |
+
gal_her2 = gr.Image(type="pil", label="Generated HER2", height=280, show_download_button=False)
|
| 414 |
+
gal_ki67 = gr.Image(type="pil", label="Generated Ki67", height=280, show_download_button=False)
|
| 415 |
+
gal_er = gr.Image(type="pil", label="Generated ER", height=280, show_download_button=False)
|
| 416 |
+
gal_pr = gr.Image(type="pil", label="Generated PR", height=280, show_download_button=False)
|
| 417 |
+
|
| 418 |
+
def _show_gallery_wrapper(display_name):
|
| 419 |
+
return show_gallery(display_name, gallery_data, gallery_name_map)
|
| 420 |
+
|
| 421 |
+
gallery_dropdown.change(
|
| 422 |
+
fn=_show_gallery_wrapper,
|
| 423 |
+
inputs=[gallery_dropdown],
|
| 424 |
+
outputs=[gal_he, gal_gt, gal_her2, gal_ki67, gal_er, gal_pr, gallery_info_box],
|
| 425 |
)
|
| 426 |
+
|
| 427 |
+
# Auto-load first example
|
| 428 |
+
demo.load(
|
| 429 |
+
fn=lambda: _show_gallery_wrapper(gallery_display_names[0]) if gallery_display_names else (None,) * 7,
|
| 430 |
+
outputs=[gal_he, gal_gt, gal_her2, gal_ki67, gal_er, gal_pr, gallery_info_box],
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
# ββ Tab 2: Single Stain ββββββββββββββββββββββββββββββββββββββ
|
| 434 |
+
with gr.Tab("Virtual Staining", id="inference"):
|
| 435 |
+
if not GPU_AVAILABLE and not HAS_SPACES:
|
| 436 |
+
gr.HTML(f'<div class="gpu-notice">{NO_GPU_MSG}</div>')
|
| 437 |
+
else:
|
| 438 |
+
gr.HTML("""
|
| 439 |
+
<div class="gallery-info">
|
| 440 |
+
Upload an H&E image and select a target IHC stain. The model generates
|
| 441 |
+
the virtual stain in a single forward pass (~1 second on GPU).
|
| 442 |
+
</div>
|
| 443 |
+
""")
|
| 444 |
with gr.Row():
|
| 445 |
with gr.Column(scale=1):
|
| 446 |
+
input_image = gr.Image(type="pil", label="Upload H&E Image", height=380)
|
| 447 |
+
stain_choice = gr.Radio(
|
| 448 |
+
choices=STAIN_NAMES, value="HER2",
|
| 449 |
+
label="Target IHC Stain",
|
| 450 |
+
info="Select which immunohistochemistry marker to generate",
|
| 451 |
+
)
|
| 452 |
guidance_slider = gr.Slider(
|
| 453 |
minimum=1.0, maximum=3.0, step=0.1, value=1.0,
|
| 454 |
+
label="Guidance Scale",
|
| 455 |
+
info="1.0 = standard generation, higher = stronger stain signal (CFG)",
|
| 456 |
)
|
| 457 |
+
generate_btn = gr.Button("Generate", variant="primary", size="lg")
|
| 458 |
gen_time = gr.Textbox(label="Status", interactive=False)
|
| 459 |
with gr.Column(scale=1):
|
| 460 |
+
output_image = gr.Image(type="pil", label="Generated IHC", height=380)
|
| 461 |
|
| 462 |
generate_btn.click(
|
| 463 |
fn=generate_single_stain,
|
|
|
|
| 465 |
outputs=[output_image, gen_time],
|
| 466 |
)
|
| 467 |
|
| 468 |
+
# ββ Tab 3: Cross-Stain βββββββββββββββββββββββββββββββββββββββ
|
| 469 |
+
with gr.Tab("Cross-Stain Comparison", id="cross-stain"):
|
| 470 |
if not GPU_AVAILABLE and not HAS_SPACES:
|
| 471 |
+
gr.HTML(f'<div class="gpu-notice">{NO_GPU_MSG}</div>')
|
| 472 |
+
else:
|
| 473 |
+
gr.HTML("""
|
| 474 |
+
<div class="gallery-info">
|
| 475 |
+
Generate <strong>all 4 IHC stains</strong> from a single H&E input.
|
| 476 |
+
This demonstrates the unified multi-stain capability of UNIStainNet.
|
| 477 |
+
</div>
|
| 478 |
+
""")
|
| 479 |
with gr.Row():
|
| 480 |
+
with gr.Column(scale=1):
|
| 481 |
+
cross_input = gr.Image(type="pil", label="Upload H&E Image", height=300)
|
| 482 |
+
with gr.Column(scale=1):
|
| 483 |
+
cross_guidance = gr.Slider(
|
| 484 |
+
minimum=1.0, maximum=3.0, step=0.1, value=1.0,
|
| 485 |
+
label="Guidance Scale",
|
| 486 |
+
info="1.0 = standard generation, higher = stronger stain signal",
|
| 487 |
+
)
|
| 488 |
+
cross_btn = gr.Button("Generate All 4 Stains", variant="primary", size="lg")
|
| 489 |
+
cross_time = gr.Textbox(label="Status", interactive=False)
|
| 490 |
|
| 491 |
+
gr.HTML('<p class="section-heading">Results</p>')
|
| 492 |
with gr.Row():
|
| 493 |
+
cross_he_out = gr.Image(type="pil", label="H&E Input", height=250)
|
| 494 |
+
cross_her2 = gr.Image(type="pil", label="HER2", height=250)
|
| 495 |
+
cross_ki67 = gr.Image(type="pil", label="Ki67", height=250)
|
| 496 |
+
cross_er = gr.Image(type="pil", label="ER", height=250)
|
| 497 |
+
cross_pr = gr.Image(type="pil", label="PR", height=250)
|
| 498 |
|
| 499 |
cross_btn.click(
|
| 500 |
fn=generate_all_stains,
|
|
|
|
| 502 |
outputs=[cross_he_out, cross_her2, cross_ki67, cross_er, cross_pr, cross_time],
|
| 503 |
)
|
| 504 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 505 |
# ββ Tab 4: About βββββββββββββββββββββββββββββββββββββββββββββ
|
| 506 |
+
with gr.Tab("About", id="about"):
|
| 507 |
gr.Markdown(
|
| 508 |
"""
|
| 509 |
+
<div class="about-section">
|
| 510 |
+
|
| 511 |
+
## UNIStainNet: Foundation-Model-Guided Virtual Staining
|
| 512 |
|
| 513 |
+
UNIStainNet is a deep learning model for **virtual immunohistochemistry (IHC) staining**
|
| 514 |
+
from standard hematoxylin & eosin (H&E) histopathology images. It translates routine H&E
|
| 515 |
+
slides into IHC stains for four clinically important breast cancer biomarkers:
|
| 516 |
+
**HER2**, **Ki67**, **ER**, and **PR**.
|
| 517 |
|
| 518 |
+
### Why Virtual Staining?
|
|
|
|
|
|
|
|
|
|
|
|
|
| 519 |
|
| 520 |
+
- **Tissue conservation** β eliminates the need for additional serial sections
|
| 521 |
+
- **Faster turnaround** β results in seconds instead of hours/days
|
| 522 |
+
- **Cost reduction** β one H&E slide replaces multiple IHC tests for screening
|
| 523 |
+
- **Consistency** β no batch-to-batch staining variability
|
| 524 |
+
|
| 525 |
+
### How It Works
|
| 526 |
+
|
| 527 |
+
The model uses a **SPADE-UNet generator** conditioned on dense spatial features from a
|
| 528 |
+
frozen [UNI](https://github.com/mahmoodlab/UNI) pathology foundation model (ViT-L/16,
|
| 529 |
+
pretrained on 100M+ histopathology patches). A FiLM-based stain embedding allows a
|
| 530 |
+
**single unified model** to generate all 4 IHC stains.
|
| 531 |
|
| 532 |
| Component | Details |
|
| 533 |
|-----------|---------|
|
| 534 |
+
| **Generator** | SPADE-UNet with UNI spatial conditioning + FiLM stain embeddings |
|
| 535 |
+
| **Foundation Model** | UNI ViT-L/16 (frozen, 303M parameters) |
|
| 536 |
+
| **Spatial Tokens** | 4x4 sub-crop tiling of H&E input, yielding 32x32 = 1,024 tokens |
|
| 537 |
+
| **Generator Parameters** | 42M |
|
| 538 |
+
| **Inference** | Single forward pass (~1 second on GPU) |
|
| 539 |
+
|
| 540 |
+
### Quantitative Results (MIST Dataset, Unified Model)
|
| 541 |
+
|
| 542 |
+
| Stain | FID | KID x1k | Pearson-R | DAB KL |
|
| 543 |
+
|-------|-----|---------|-----------|--------|
|
| 544 |
+
| HER2 | 34.5 | 2.2 | 0.929 | 0.166 |
|
| 545 |
+
| Ki67 | 27.2 | 1.8 | 0.927 | 0.119 |
|
| 546 |
+
| ER | 29.2 | 1.8 | 0.949 | 0.182 |
|
| 547 |
+
| PR | 29.0 | 1.1 | 0.943 | 0.171 |
|
| 548 |
+
|
| 549 |
+
### Key Innovations
|
| 550 |
+
|
| 551 |
+
- **Dense UNI spatial conditioning**: Unlike prior methods that use global image features,
|
| 552 |
+
UNIStainNet extracts spatially-resolved features at 32x32 resolution, enabling the generator
|
| 553 |
+
to leverage fine-grained morphological context from the pathology foundation model.
|
| 554 |
+
- **Misalignment-aware training**: Because H&E and IHC are cut from consecutive tissue sections
|
| 555 |
+
(not the same section), there are inherent spatial shifts. Our loss suite (perceptual loss,
|
| 556 |
+
DAB intensity supervision, unconditional discriminator) is designed to handle this misalignment.
|
| 557 |
+
- **Classifier-free guidance (CFG)**: 10% class dropout and 10% UNI dropout during training
|
| 558 |
+
enables tunable generation strength at inference time.
|
| 559 |
+
|
| 560 |
+
### Disclaimer
|
| 561 |
+
|
| 562 |
+
This is a **research tool** for exploratory analysis. It is not intended for clinical diagnosis
|
| 563 |
+
and has not undergone regulatory validation. Generated stains should not be used for treatment decisions.
|
| 564 |
+
|
| 565 |
+
</div>
|
| 566 |
"""
|
| 567 |
)
|
| 568 |
|
| 569 |
+
# ββ Footer βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 570 |
+
gr.HTML("""
|
| 571 |
+
<footer>
|
| 572 |
+
UNIStainNet | Built with Gradio
|
| 573 |
+
</footer>
|
| 574 |
+
""")
|
| 575 |
+
|
| 576 |
if __name__ == "__main__":
|
| 577 |
demo.launch(theme=gr.themes.Soft())
|