faceless-void committed on
Commit
9677996
Β·
verified Β·
1 Parent(s): d0b8562

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +92 -98
app.py CHANGED
@@ -1,9 +1,9 @@
1
  #!/usr/bin/env python3
2
  """
3
- UNIStainNet Interactive Demo β€” Hugging Face Spaces (ZeroGPU)
4
 
5
  Virtual staining of H&E histopathology images to IHC (HER2, Ki67, ER, PR).
6
- Uses @spaces.GPU for on-demand GPU allocation on ZeroGPU.
7
  """
8
 
9
  import json
@@ -14,89 +14,82 @@ from pathlib import Path
14
  import gradio as gr
15
  import numpy as np
16
  import torch
 
 
 
 
 
17
 
18
- # ZeroGPU support: use @spaces.GPU if available, otherwise no-op
19
  try:
20
  import spaces
21
- GPU_AVAILABLE = torch.cuda.is_available()
22
  except ImportError:
23
  spaces = None
24
- GPU_AVAILABLE = torch.cuda.is_available()
25
 
 
26
 
27
  def _gpu_decorator(duration=60):
28
- """Apply @spaces.GPU if available, otherwise return identity decorator."""
29
- if spaces is not None and hasattr(spaces, "GPU"):
30
  return spaces.GPU(duration=duration)
31
  return lambda fn: fn
32
- import torch.nn.functional as F
33
- import torchvision.transforms as T
34
- import torchvision.transforms.functional as TF
35
- from PIL import Image
36
- from huggingface_hub import hf_hub_download
37
-
38
- from src.models.trainer import UNIStainNetTrainer
39
- from src.data.mist_dataset import STAIN_TO_LABEL, LABEL_TO_STAIN
40
 
41
  # ── Constants ────────────────────────────────────────────────────────
42
  STAIN_NAMES = ["HER2", "Ki67", "ER", "PR"]
43
  GALLERY_DIR = Path(__file__).parent / "gallery"
44
  TARGET_SIZE = 512
45
-
46
- # Model repo where checkpoint is stored (uploaded separately)
47
  MODEL_REPO = os.environ.get("MODEL_REPO", "faceless-void/UNIStainNet")
48
  CHECKPOINT_FILENAME = "mist_multistain_last.ckpt"
49
 
50
- # ── Global model cache (loaded lazily on GPU request) ────────────────
 
 
 
 
 
 
51
  _model_cache = {"model": None, "uni_model": None, "spatial_pool_size": 32}
52
 
53
 
54
  def _get_checkpoint_path():
55
- """Download checkpoint from HF Hub if not local."""
56
  local_path = Path(__file__).parent / "checkpoints" / CHECKPOINT_FILENAME
57
  if local_path.exists():
58
  return str(local_path)
59
- # Download from HF model repo
60
  return hf_hub_download(repo_id=MODEL_REPO, filename=CHECKPOINT_FILENAME)
61
 
62
 
63
  def _load_models():
64
- """Load UNIStainNet + UNI (called inside @spaces.GPU function)."""
65
- if _model_cache["model"] is None:
66
- import timm
67
 
 
68
  ckpt_path = _get_checkpoint_path()
69
  print(f"Loading UNIStainNet from {ckpt_path} ...")
70
  model = UNIStainNetTrainer.load_from_checkpoint(ckpt_path, strict=False)
71
  model = model.cuda().eval()
72
  _model_cache["model"] = model
73
- _model_cache["spatial_pool_size"] = getattr(
74
- model.hparams, "uni_spatial_size", 32
75
- )
76
- print(" Generator loaded")
77
 
78
  print("Loading UNI ViT-L/16 ...")
79
  uni_model = timm.create_model(
80
- "hf-hub:MahmoodLab/uni",
81
- pretrained=True,
82
- init_values=1e-5,
83
- dynamic_img_size=True,
84
  )
85
  uni_model = uni_model.cuda().eval()
86
  _model_cache["uni_model"] = uni_model
87
- print(" UNI loaded")
88
  else:
89
- # Models already loaded β€” move to current GPU device
90
  _model_cache["model"] = _model_cache["model"].cuda()
91
  _model_cache["uni_model"] = _model_cache["uni_model"].cuda()
92
 
93
  return _model_cache["model"], _model_cache["uni_model"], _model_cache["spatial_pool_size"]
94
 
95
 
96
- # ── Preprocessing helpers ────────────────────────────────────────────
97
 
98
  def preprocess_he(pil_image, target_size=TARGET_SIZE):
99
- """Center-crop and resize H&E to target_size x target_size."""
100
  w, h = pil_image.size
101
  short = min(w, h)
102
  left = (w - short) // 2
@@ -108,52 +101,40 @@ def preprocess_he(pil_image, target_size=TARGET_SIZE):
108
 
109
 
110
  def pil_to_tensor(pil_image):
111
- """PIL β†’ [1, 3, H, W] in [-1, 1]."""
112
  t = TF.to_tensor(pil_image)
113
  t = TF.normalize(t, [0.5] * 3, [0.5] * 3)
114
  return t.unsqueeze(0)
115
 
116
 
117
  def tensor_to_pil(tensor):
118
- """[1, 3, H, W] in [-1, 1] β†’ PIL."""
119
  t = ((tensor[0].cpu() + 1) / 2).clamp(0, 1)
120
  return TF.to_pil_image(t)
121
 
122
 
123
  def extract_uni_features(uni_model, he_tensor_01, spatial_pool_size=32):
124
- """Extract UNI spatial features from H&E crop ([1,3,H,W] in [0,1])."""
125
- uni_transform = T.Normalize(
126
- mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
127
- )
128
  B = he_tensor_01.shape[0]
129
- num_crops = 4
130
- patches_per_side = 14
131
  crop_h = he_tensor_01.shape[2] // num_crops
132
  crop_w = he_tensor_01.shape[3] // num_crops
133
 
134
  sub_crops = []
135
  for i in range(num_crops):
136
  for j in range(num_crops):
137
- sub = he_tensor_01[
138
- :, :, i * crop_h : (i + 1) * crop_h, j * crop_w : (j + 1) * crop_w
139
- ]
140
  sub = F.interpolate(sub, size=(224, 224), mode="bicubic", align_corners=False)
141
  sub = torch.stack([uni_transform(s) for s in sub])
142
  sub_crops.append(sub)
143
 
144
  all_crops = torch.stack(sub_crops, dim=1).reshape(B * 16, 3, 224, 224).cuda()
145
-
146
  with torch.no_grad():
147
  all_feats = uni_model.forward_features(all_crops)
148
  patch_tokens = all_feats[:, 1:, :]
149
 
150
- patch_tokens = patch_tokens.reshape(
151
- B, num_crops, num_crops, patches_per_side, patches_per_side, 1024
152
- )
153
  full_size = num_crops * patches_per_side
154
- full_grid = patch_tokens.permute(0, 1, 3, 2, 4, 5).reshape(
155
- B, full_size, full_size, 1024
156
- )
157
 
158
  S = spatial_pool_size
159
  if S < full_size:
@@ -162,74 +143,82 @@ def extract_uni_features(uni_model, he_tensor_01, spatial_pool_size=32):
162
  result = pooled.permute(0, 2, 3, 1)
163
  else:
164
  result = full_grid
165
-
166
  return result.reshape(B, S * S, 1024)
167
 
168
 
169
- # ── GPU-accelerated inference functions ──────────────────────────────
170
 
171
  @_gpu_decorator(duration=60)
172
- def generate_single_stain(image, stain, guidance_scale):
173
- """Generate a single IHC stain from an H&E image (GPU)."""
174
- if image is None:
175
- return None, "No image uploaded"
176
-
177
- t0 = time.time()
178
  model, uni_model, spatial_pool_size = _load_models()
179
 
180
  he_pil = preprocess_he(image)
181
  he_tensor = pil_to_tensor(he_pil).cuda()
182
  he_01 = ((he_tensor + 1) / 2).clamp(0, 1)
183
-
184
  uni_feats = extract_uni_features(uni_model, he_01, spatial_pool_size).cuda()
185
- label = STAIN_TO_LABEL[stain]
186
- labels = torch.tensor([label], device="cuda", dtype=torch.long)
187
 
188
  with torch.no_grad():
189
  gen = model.generate(he_tensor, uni_feats, labels, guidance_scale=guidance_scale)
190
-
191
- result = tensor_to_pil(gen)
192
- elapsed = time.time() - t0
193
- return result, f"{elapsed:.2f}s"
194
 
195
 
196
  @_gpu_decorator(duration=120)
197
- def generate_all_stains(image, guidance_scale):
198
- """Generate all 4 IHC stains from one H&E image (GPU)."""
199
- if image is None:
200
- return None, None, None, None, None, "No image uploaded"
201
-
202
- t0 = time.time()
203
  model, uni_model, spatial_pool_size = _load_models()
204
 
205
  he_pil = preprocess_he(image)
206
  he_tensor = pil_to_tensor(he_pil).cuda()
207
  he_01 = ((he_tensor + 1) / 2).clamp(0, 1)
208
-
209
  uni_feats = extract_uni_features(uni_model, he_01, spatial_pool_size).cuda()
210
 
211
  results = {}
212
  for stain in STAIN_NAMES:
213
- label = STAIN_TO_LABEL[stain]
214
- labels = torch.tensor([label], device="cuda", dtype=torch.long)
215
  with torch.no_grad():
216
- gen = model.generate(
217
- he_tensor, uni_feats, labels, guidance_scale=guidance_scale
218
- )
219
  results[stain] = tensor_to_pil(gen)
 
220
 
221
- elapsed = time.time() - t0
222
- return (
223
- he_pil,
224
- results["HER2"],
225
- results["Ki67"],
226
- results["ER"],
227
- results["PR"],
228
- f"{elapsed:.2f}s",
229
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
 
232
- # ── Gallery helpers ──────────────────────────────────────────────────
233
 
234
  def load_gallery():
235
  meta_path = GALLERY_DIR / "metadata.json"
@@ -270,18 +259,20 @@ with gr.Blocks(title="UNIStainNet β€” Virtual IHC Staining") as demo:
270
 
271
  # ── Tab 1: Single Stain ──────────────────────────────────────
272
  with gr.Tab("Virtual Staining"):
 
 
 
 
273
  with gr.Row():
274
  with gr.Column(scale=1):
275
  input_image = gr.Image(type="pil", label="Upload H&E Image", height=400)
276
- stain_choice = gr.Radio(
277
- choices=STAIN_NAMES, value="HER2", label="Target IHC Stain"
278
- )
279
  guidance_slider = gr.Slider(
280
  minimum=1.0, maximum=3.0, step=0.1, value=1.0,
281
  label="Guidance Scale (1.0 = no CFG)",
282
  )
283
  generate_btn = gr.Button("Generate", variant="primary")
284
- gen_time = gr.Textbox(label="Time", interactive=False)
285
  with gr.Column(scale=1):
286
  output_image = gr.Image(type="pil", label="Generated IHC", height=400)
287
 
@@ -293,6 +284,10 @@ with gr.Blocks(title="UNIStainNet β€” Virtual IHC Staining") as demo:
293
 
294
  # ── Tab 2: Cross-Stain ───────────────────────────────────────
295
  with gr.Tab("Cross-Stain Comparison"):
 
 
 
 
296
  gr.Markdown(
297
  "Generate **all 4 IHC stains** from a single H&E input. "
298
  "Demonstrates the unified multi-stain capability."
@@ -300,11 +295,10 @@ with gr.Blocks(title="UNIStainNet β€” Virtual IHC Staining") as demo:
300
  with gr.Row():
301
  cross_input = gr.Image(type="pil", label="Upload H&E Image", height=350)
302
  cross_guidance = gr.Slider(
303
- minimum=1.0, maximum=3.0, step=0.1, value=1.0,
304
- label="Guidance Scale",
305
  )
306
  cross_btn = gr.Button("Generate All Stains", variant="primary")
307
- cross_time = gr.Textbox(label="Time", interactive=False)
308
 
309
  with gr.Row():
310
  cross_he_out = gr.Image(type="pil", label="H&E Input", height=300)
@@ -326,7 +320,7 @@ with gr.Blocks(title="UNIStainNet β€” Virtual IHC Staining") as demo:
326
  else:
327
  gr.Markdown(
328
  "Pre-computed examples β€” no GPU required. "
329
- "Select an example to view the H&E input and generated IHC stains."
330
  )
331
  gallery_dropdown = gr.Dropdown(
332
  choices=gallery_names,
 
1
  #!/usr/bin/env python3
2
  """
3
+ UNIStainNet Interactive Demo β€” Hugging Face Spaces
4
 
5
  Virtual staining of H&E histopathology images to IHC (HER2, Ki67, ER, PR).
6
+ Supports ZeroGPU (HF Pro) for live inference, falls back to gallery-only on CPU.
7
  """
8
 
9
  import json
 
14
  import gradio as gr
15
  import numpy as np
16
  import torch
17
+ import torch.nn.functional as F
18
+ import torchvision.transforms as T
19
+ import torchvision.transforms.functional as TF
20
+ from PIL import Image
21
+ from huggingface_hub import hf_hub_download
22
 
23
+ # ── ZeroGPU support ──────────────────────────────────────────────────
24
  try:
25
  import spaces
26
+ HAS_SPACES = True
27
  except ImportError:
28
  spaces = None
29
+ HAS_SPACES = False
30
 
31
+ GPU_AVAILABLE = torch.cuda.is_available()
32
 
33
def _gpu_decorator(duration=60):
    """Return ``spaces.GPU(duration=...)`` when running on a ZeroGPU Space,
    otherwise an identity decorator so the app still works without `spaces`."""
    if not (HAS_SPACES and hasattr(spaces, "GPU")):
        return lambda fn: fn
    return spaces.GPU(duration=duration)
 
 
 
 
 
 
 
 
37
 
38
  # ── Constants ────────────────────────────────────────────────────────
39
  STAIN_NAMES = ["HER2", "Ki67", "ER", "PR"]
40
  GALLERY_DIR = Path(__file__).parent / "gallery"
41
  TARGET_SIZE = 512
 
 
42
  MODEL_REPO = os.environ.get("MODEL_REPO", "faceless-void/UNIStainNet")
43
  CHECKPOINT_FILENAME = "mist_multistain_last.ckpt"
44
 
45
+ NO_GPU_MSG = (
46
+ "GPU is not available on this Space (requires HF Pro for ZeroGPU). "
47
+ "Please use the **Gallery** tab to browse pre-computed results, "
48
+ "or run the app locally with a GPU: `python app.py`"
49
+ )
50
+
51
+ # ── Lazy imports (avoid crash if no GPU) ─────────────────────────────
52
  _model_cache = {"model": None, "uni_model": None, "spatial_pool_size": 32}
53
 
54
 
55
def _get_checkpoint_path():
    """Resolve the model checkpoint path.

    Prefers a checkpoint bundled next to the app under ``checkpoints/``;
    falls back to downloading it from the Hugging Face Hub model repo.
    """
    bundled = Path(__file__).parent / "checkpoints" / CHECKPOINT_FILENAME
    if not bundled.exists():
        return hf_hub_download(repo_id=MODEL_REPO, filename=CHECKPOINT_FILENAME)
    return str(bundled)
60
 
61
 
62
def _load_models():
    """Load models onto GPU. Only called when GPU is confirmed available.

    Populates the module-level ``_model_cache`` on first call (generator +
    UNI feature extractor); later calls reuse the cached models and only
    move them back onto the GPU.

    Returns:
        tuple: (generator model, UNI model, spatial pool size).
    """
    # Deferred imports: keeps module import cheap and avoids pulling in the
    # trainer/timm stack until a GPU request actually needs it (see the
    # "Lazy imports (avoid crash if no GPU)" note above the cache).
    from src.models.trainer import UNIStainNetTrainer
    import timm

    if _model_cache["model"] is None:
        # First call: fetch checkpoint (local file or HF Hub download).
        ckpt_path = _get_checkpoint_path()
        print(f"Loading UNIStainNet from {ckpt_path} ...")
        # strict=False tolerates extra/missing keys in the checkpoint.
        model = UNIStainNetTrainer.load_from_checkpoint(ckpt_path, strict=False)
        model = model.cuda().eval()
        _model_cache["model"] = model
        # Spatial pooling size comes from the training hparams; 32 is the fallback.
        _model_cache["spatial_pool_size"] = getattr(model.hparams, "uni_spatial_size", 32)

        print("Loading UNI ViT-L/16 ...")
        uni_model = timm.create_model(
            "hf-hub:MahmoodLab/uni", pretrained=True,
            init_values=1e-5, dynamic_img_size=True,
        )
        uni_model = uni_model.cuda().eval()
        _model_cache["uni_model"] = uni_model
        print(" Models loaded")
    else:
        # Models already loaded — move them back to the current GPU device
        # (on ZeroGPU the device may change between allocations).
        _model_cache["model"] = _model_cache["model"].cuda()
        _model_cache["uni_model"] = _model_cache["uni_model"].cuda()

    return _model_cache["model"], _model_cache["uni_model"], _model_cache["spatial_pool_size"]
88
 
89
 
90
+ # ── Preprocessing ────────────────────────────────────────────────────
91
 
92
  def preprocess_he(pil_image, target_size=TARGET_SIZE):
 
93
  w, h = pil_image.size
94
  short = min(w, h)
95
  left = (w - short) // 2
 
101
 
102
 
103
def pil_to_tensor(pil_image):
    """Convert a PIL image into a batched [1, 3, H, W] tensor in [-1, 1]."""
    mean = std = [0.5, 0.5, 0.5]
    tensor = TF.normalize(TF.to_tensor(pil_image), mean, std)
    return tensor[None]
107
 
108
 
109
def tensor_to_pil(tensor):
    """Convert a batched [1, 3, H, W] tensor in [-1, 1] back to a PIL image."""
    image = tensor[0].cpu().add(1).div(2).clamp(0, 1)
    return TF.to_pil_image(image)
112
 
113
 
114
  def extract_uni_features(uni_model, he_tensor_01, spatial_pool_size=32):
115
+ from src.data.mist_dataset import STAIN_TO_LABEL
116
+ uni_transform = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 
 
117
  B = he_tensor_01.shape[0]
118
+ num_crops, patches_per_side = 4, 14
 
119
  crop_h = he_tensor_01.shape[2] // num_crops
120
  crop_w = he_tensor_01.shape[3] // num_crops
121
 
122
  sub_crops = []
123
  for i in range(num_crops):
124
  for j in range(num_crops):
125
+ sub = he_tensor_01[:, :, i*crop_h:(i+1)*crop_h, j*crop_w:(j+1)*crop_w]
 
 
126
  sub = F.interpolate(sub, size=(224, 224), mode="bicubic", align_corners=False)
127
  sub = torch.stack([uni_transform(s) for s in sub])
128
  sub_crops.append(sub)
129
 
130
  all_crops = torch.stack(sub_crops, dim=1).reshape(B * 16, 3, 224, 224).cuda()
 
131
  with torch.no_grad():
132
  all_feats = uni_model.forward_features(all_crops)
133
  patch_tokens = all_feats[:, 1:, :]
134
 
135
+ patch_tokens = patch_tokens.reshape(B, num_crops, num_crops, patches_per_side, patches_per_side, 1024)
 
 
136
  full_size = num_crops * patches_per_side
137
+ full_grid = patch_tokens.permute(0, 1, 3, 2, 4, 5).reshape(B, full_size, full_size, 1024)
 
 
138
 
139
  S = spatial_pool_size
140
  if S < full_size:
 
143
  result = pooled.permute(0, 2, 3, 1)
144
  else:
145
  result = full_grid
 
146
  return result.reshape(B, S * S, 1024)
147
 
148
 
149
+ # ── Inference functions ──────────────────────────────────────────────
150
 
151
@_gpu_decorator(duration=60)
def _generate_single_gpu(image, stain, guidance_scale):
    """Generate one IHC stain image from an H&E input (runs on GPU)."""
    from src.data.mist_dataset import STAIN_TO_LABEL

    model, uni_model, spatial_pool_size = _load_models()

    # Prepare input: center-crop/resize, then normalize to [-1, 1] on GPU.
    crop = preprocess_he(image)
    he = pil_to_tensor(crop).cuda()
    he_unit = ((he + 1) / 2).clamp(0, 1)
    feats = extract_uni_features(uni_model, he_unit, spatial_pool_size).cuda()
    stain_label = torch.tensor([STAIN_TO_LABEL[stain]], device="cuda", dtype=torch.long)

    with torch.no_grad():
        output = model.generate(he, feats, stain_label, guidance_scale=guidance_scale)
    return tensor_to_pil(output)
 
 
 
166
 
167
 
168
@_gpu_decorator(duration=120)
def _generate_all_gpu(image, guidance_scale):
    """Generate every supported IHC stain from one H&E input (runs on GPU)."""
    from src.data.mist_dataset import STAIN_TO_LABEL

    model, uni_model, spatial_pool_size = _load_models()

    # Shared preprocessing/feature extraction — done once, reused per stain.
    crop = preprocess_he(image)
    he = pil_to_tensor(crop).cuda()
    he_unit = ((he + 1) / 2).clamp(0, 1)
    feats = extract_uni_features(uni_model, he_unit, spatial_pool_size).cuda()

    outputs = {}
    for name in STAIN_NAMES:
        stain_label = torch.tensor([STAIN_TO_LABEL[name]], device="cuda", dtype=torch.long)
        with torch.no_grad():
            generated = model.generate(he, feats, stain_label, guidance_scale=guidance_scale)
        outputs[name] = tensor_to_pil(generated)
    return crop, outputs
186
 
187
+
188
def generate_single_stain(image, stain, guidance_scale):
    """UI wrapper for single-stain generation.

    Validates the input and GPU availability, then runs the timed GPU call.
    Returns (image or None, status string).
    """
    if image is None:
        return None, "No image uploaded"
    if not GPU_AVAILABLE and not HAS_SPACES:
        return None, NO_GPU_MSG
    try:
        started = time.time()
        stained = _generate_single_gpu(image, stain, guidance_scale)
        return stained, f"{time.time() - started:.2f}s"
    except RuntimeError as err:
        # A missing-GPU failure becomes a friendly message; anything else propagates.
        message = str(err)
        if any(marker in message for marker in ("NVIDIA", "CUDA", "cuda")):
            return None, NO_GPU_MSG
        raise
202
+
203
+
204
def generate_all_stains(image, guidance_scale):
    """UI wrapper for all-stain generation.

    Validates the input and GPU availability, then runs the timed GPU call.
    Returns (H&E crop, HER2, Ki67, ER, PR, status string).
    """
    if image is None:
        return None, None, None, None, None, "No image uploaded"
    if not GPU_AVAILABLE and not HAS_SPACES:
        return None, None, None, None, None, NO_GPU_MSG
    try:
        started = time.time()
        he_crop, stains = _generate_all_gpu(image, guidance_scale)
        elapsed = f"{time.time() - started:.2f}s"
        return (
            he_crop,
            stains["HER2"],
            stains["Ki67"],
            stains["ER"],
            stains["PR"],
            elapsed,
        )
    except RuntimeError as err:
        # A missing-GPU failure becomes a friendly message; anything else propagates.
        message = str(err)
        if any(marker in message for marker in ("NVIDIA", "CUDA", "cuda")):
            return None, None, None, None, None, NO_GPU_MSG
        raise
219
 
220
 
221
+ # ── Gallery ──────────────────────────────────────────────────────────
222
 
223
  def load_gallery():
224
  meta_path = GALLERY_DIR / "metadata.json"
 
259
 
260
  # ── Tab 1: Single Stain ──────────────────────────────────────
261
  with gr.Tab("Virtual Staining"):
262
+ if not GPU_AVAILABLE and not HAS_SPACES:
263
+ gr.Markdown(
264
+ f"> **Note:** {NO_GPU_MSG}"
265
+ )
266
  with gr.Row():
267
  with gr.Column(scale=1):
268
  input_image = gr.Image(type="pil", label="Upload H&E Image", height=400)
269
+ stain_choice = gr.Radio(choices=STAIN_NAMES, value="HER2", label="Target IHC Stain")
 
 
270
  guidance_slider = gr.Slider(
271
  minimum=1.0, maximum=3.0, step=0.1, value=1.0,
272
  label="Guidance Scale (1.0 = no CFG)",
273
  )
274
  generate_btn = gr.Button("Generate", variant="primary")
275
+ gen_time = gr.Textbox(label="Status", interactive=False)
276
  with gr.Column(scale=1):
277
  output_image = gr.Image(type="pil", label="Generated IHC", height=400)
278
 
 
284
 
285
  # ── Tab 2: Cross-Stain ───────────────────────────────────────
286
  with gr.Tab("Cross-Stain Comparison"):
287
+ if not GPU_AVAILABLE and not HAS_SPACES:
288
+ gr.Markdown(
289
+ f"> **Note:** {NO_GPU_MSG}"
290
+ )
291
  gr.Markdown(
292
  "Generate **all 4 IHC stains** from a single H&E input. "
293
  "Demonstrates the unified multi-stain capability."
 
295
  with gr.Row():
296
  cross_input = gr.Image(type="pil", label="Upload H&E Image", height=350)
297
  cross_guidance = gr.Slider(
298
+ minimum=1.0, maximum=3.0, step=0.1, value=1.0, label="Guidance Scale",
 
299
  )
300
  cross_btn = gr.Button("Generate All Stains", variant="primary")
301
+ cross_time = gr.Textbox(label="Status", interactive=False)
302
 
303
  with gr.Row():
304
  cross_he_out = gr.Image(type="pil", label="H&E Input", height=300)
 
320
  else:
321
  gr.Markdown(
322
  "Pre-computed examples β€” no GPU required. "
323
+ "Select an example to view the H&E input, ground truth, and generated IHC stains."
324
  )
325
  gallery_dropdown = gr.Dropdown(
326
  choices=gallery_names,