rayquaza384mega commited on
Commit
d4a259c
·
1 Parent(s): 41e1156

Pseudo-color-3

Browse files
Files changed (1) hide show
  1. app.py +285 -284
app.py CHANGED
@@ -1,16 +1,12 @@
1
  import torch
2
  import numpy as np
3
  import gradio as gr
4
- from PIL import Image, ImageOps
5
  import os
6
  import json
7
- import glob
8
- import random
9
  import tifffile
10
  import re
11
- import imageio
12
  from torchvision import transforms, models
13
- import accelerate
14
  import shutil
15
  import time
16
  import torch.nn as nn
@@ -20,11 +16,12 @@ from collections import OrderedDict
20
  import tempfile
21
  import zipfile
22
  import matplotlib.cm as cm
 
 
23
 
24
  # --- Imports from both scripts ---
25
  from diffusers import DDPMScheduler, DDIMScheduler
26
  from transformers import CLIPTextModel, CLIPTokenizer
27
- from accelerate.state import AcceleratorState
28
  from transformers.utils import ContextManagers
29
 
30
  # --- Custom Model Imports ---
@@ -34,45 +31,37 @@ from models.controlnet import ControlNetModel
34
  from models.unet_2d_condition import UNet2DConditionModel
35
  from models.pipeline_controlnet import DDPMControlnetPipeline
36
 
37
- # --- New Import for Segmentation ---
38
  from cellpose import models as cellpose_models
39
- from cellpose import plot as cellpose_plot
40
  from huggingface_hub import snapshot_download
41
 
42
  # --- 0. Configuration & Constants ---
43
  hf_token = os.environ.get("HF_TOKEN")
44
- # --- General ---
45
  MODEL_TITLE = "🔬 FluoGen: AI-Powered Fluorescence Microscopy Suite"
46
  MODEL_DESCRIPTION = """
47
  **Paper**: *Generative AI empowering fluorescence microscopy imaging and analysis*
48
  <br>
49
- Select a task from the tabs below: generate new images from text, enhance existing images using super-resolution, denoise them, generate training data from segmentation masks, perform cell segmentation, or classify cell images.
 
50
  """
51
  DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
52
  WEIGHT_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
53
  LOGO_PATH = "utils/logo2_transparent.png"
54
-
55
- # --- Global switch to control example saving ---
56
  SAVE_EXAMPLES = False
57
 
58
  # --- Base directory for all models ---
59
  REPO_ID = "FluoGen-Group/FluoGen-demo-test-ckpts"
60
  MODELS_ROOT_DIR = snapshot_download(repo_id=REPO_ID, token=hf_token)
61
 
62
- # --- Tab 1: Mask-to-Image Config ---
63
  M2I_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_M2I/checkpoint-30000"
64
  M2I_EXAMPLE_IMG_DIR = "example_images_m2i"
65
-
66
- # --- Tab 2: Text-to-Image Config ---
67
  T2I_EXAMPLE_IMG_DIR = "example_images"
68
  T2I_PRETRAINED_MODEL_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5"
69
  T2I_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"
70
-
71
- # --- Tab 3, 4: ControlNet-based Tasks Config ---
72
  CONTROLNET_CLIP_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5"
73
  CONTROLNET_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"
74
 
75
- # Super-Resolution Models
76
  SR_CONTROLNET_MODELS = {
77
  "Checkpoint ER": f"{MODELS_ROOT_DIR}/ControlNet_SR/ER/checkpoint-30000",
78
  "Checkpoint Microtubules": f"{MODELS_ROOT_DIR}/ControlNet_SR/Microtubules/checkpoint-72500",
@@ -80,13 +69,10 @@ SR_CONTROLNET_MODELS = {
80
  "Checkpoint F-actin": f"{MODELS_ROOT_DIR}/ControlNet_SR/F-actin/checkpoint-35000",
81
  }
82
  SR_EXAMPLE_IMG_DIR = "example_images_sr"
83
-
84
- # Denoising Model
85
  DN_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_DN/checkpoint-10000"
86
  DN_PROMPT_RULES = {'MICE': 'mouse brain tissues', 'FISH': 'zebrafish embryos', 'BPAE_B': 'nucleus of BPAE', 'BPAE_R': 'mitochondria of BPAE', 'BPAE_G': 'F-actin of BPAE'}
87
  DN_EXAMPLE_IMG_DIR = "example_images_dn"
88
 
89
- # --- Tab 5: Cell Segmentation Config ---
90
  SEG_MODELS = {
91
  "DynamicNet Model": f"{MODELS_ROOT_DIR}/Cellpose/DynamicNet_baseline/CP_dynamic_ten_epoch_0100",
92
  "DynamicNet Model + FluoGen": f"{MODELS_ROOT_DIR}/Cellpose/DynamicNet_FluoGen/CP_dynamic_epoch_0300",
@@ -95,7 +81,6 @@ SEG_MODELS = {
95
  }
96
  SEG_EXAMPLE_IMG_DIR = "example_images_seg"
97
 
98
- # --- Tab 6: Classification Config ---
99
  CLS_MODEL_PATHS = OrderedDict({
100
  "5shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot_re",
101
  "5shot+FluoGen": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot_aug_re",
@@ -104,7 +89,7 @@ CLS_CLASS_NAMES = ['Nucleus', 'Endoplasmic Reticulum', 'Giantin', 'GPP130', 'Lys
104
  CLS_EXAMPLE_IMG_DIR = "example_images_cls"
105
 
106
  # --- Constants for Visualization ---
107
- COLOR_MAPS = ["Grayscale", "Green (GFP)", "Red (RFP)", "Blue (DAPI)", "Magenta", "Cyan", "Yellow", "Fire", "Viridis", "Inferno"]
108
 
109
  # --- Helper Functions ---
110
  def sanitize_prompt_for_filename(prompt):
@@ -116,57 +101,73 @@ def min_max_norm(x):
116
  if max_val - min_val < 1e-8: return np.zeros_like(x)
117
  return (x - min_val) / (max_val - min_val)
118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
def apply_pseudocolor(image_np, color_name="Grayscale"):
    """Apply a pseudocolor to a single-channel numpy image for display.

    Args:
        image_np: Single-channel numpy array (any bit depth); extra
            singleton dimensions are squeezed away before processing.
        color_name: One of the COLOR_MAPS entries. Unknown names now fall
            back to grayscale instead of returning an all-black image.

    Returns:
        PIL Image in RGB mode, 8-bit per channel (display only — the raw
        scientific data should be downloaded separately).
    """
    # Normalize to 0-1 so every input bit depth maps onto the same scale.
    norm_img = min_max_norm(np.squeeze(image_np))

    # Simple additive colors: per-channel weights (R, G, B).
    channel_weights = {
        "Green (GFP)": (0.0, 1.0, 0.0),
        "Red (RFP)": (1.0, 0.0, 0.0),
        "Blue (DAPI)": (0.0, 0.0, 1.0),
        "Magenta": (1.0, 0.0, 1.0),
        "Cyan": (0.0, 1.0, 1.0),
        "Yellow": (1.0, 1.0, 0.0),
    }
    # LUT-style options mapped to matplotlib colormap names.
    # "Fire" has no exact matplotlib equivalent; gnuplot2 is the closest.
    mpl_cmaps = {"Fire": "gnuplot2", "Viridis": "viridis", "Inferno": "inferno"}

    if color_name in channel_weights:
        h, w = norm_img.shape
        rgb = np.zeros((h, w, 3), dtype=np.float32)
        for ch, weight in enumerate(channel_weights[color_name]):
            if weight:
                rgb[..., ch] = norm_img
        return Image.fromarray((rgb * 255).astype(np.uint8))

    if color_name in mpl_cmaps:
        # cm.get_cmap was deprecated in matplotlib 3.7 and removed in 3.9;
        # prefer the colormap registry, falling back for older versions.
        import matplotlib
        try:
            cmap = matplotlib.colormaps[mpl_cmaps[color_name]]
        except AttributeError:
            cmap = cm.get_cmap(mpl_cmaps[color_name])
        colored = cmap(norm_img)  # RGBA floats in [0, 1]
        return Image.fromarray((colored[..., :3] * 255).astype(np.uint8))

    # "Grayscale" and any unrecognized name: plain 8-bit grayscale as RGB.
    return Image.fromarray((norm_img * 255).astype(np.uint8)).convert("RGB")
161
 
162
def save_temp_tiff(image_np, prefix="output"):
    """Save a numpy array to a temporary TIFF file and return its path.

    float16 arrays are promoted to float32 because TIFF writers handle
    float32/uint16 far more portably than half precision.

    Args:
        image_np: Array to save (not modified).
        prefix: Filename prefix for the temp file.

    Returns:
        Path of the written .tif file. delete=False keeps it alive for
        Gradio's download button; the caller/OS owns eventual cleanup.
    """
    tfile = tempfile.NamedTemporaryFile(delete=False, suffix=".tif", prefix=f"{prefix}_")
    # Close our handle immediately so only tifffile touches the file —
    # the original leaked the fd and an open handle breaks the rewrite
    # on Windows.
    tfile.close()
    save_data = image_np.astype(np.float32) if image_np.dtype == np.float16 else image_np
    tifffile.imwrite(tfile.name, save_data)
    return tfile.name
172
 
@@ -176,8 +177,7 @@ def numpy_to_pil(image_np, target_mode="RGB"):
176
  if target_mode == "L" and image_np.mode != "L": return image_np.convert("L")
177
  return image_np
178
  squeezed_np = np.squeeze(image_np);
179
- if squeezed_np.dtype == np.uint8:
180
- image_8bit = squeezed_np
181
  else:
182
  normalized_np = min_max_norm(squeezed_np)
183
  image_8bit = (normalized_np * 255).astype(np.uint8)
@@ -211,8 +211,7 @@ def load_all_prompts():
211
  if isinstance(data, list):
212
  combined_prompts.extend(data)
213
  for p in data: PROMPT_TO_MODEL_MAP[p] = cat["model"]
214
- print(f"✓ Loaded {len(data)} prompts from {cat['file']}")
215
- except Exception as e: print(f"✗ Error loading {cat['file']}: {e}")
216
  if not combined_prompts: return ["F-actin of COS-7", "ER of COS-7"]
217
  return combined_prompts
218
  T2I_PROMPTS = load_all_prompts()
@@ -231,23 +230,22 @@ try:
231
  current_t2i_unet_path = T2I_UNET_PATH
232
  print("✓ Text-to-Image model loaded successfully!")
233
  except Exception as e:
234
- print(f"!!!!!! FATAL: Text-to-Image Model Loading Failed !!!!!!\nError: {e}")
235
 
236
  try:
237
- print("Loading shared ControlNet pipeline components...")
238
  controlnet_unet = UNet2DConditionModel.from_pretrained(CONTROLNET_UNET_PATH, subfolder="unet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
239
- default_controlnet_path = M2I_CONTROLNET_PATH
240
- controlnet_controlnet = ControlNetModel.from_pretrained(default_controlnet_path, subfolder="controlnet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
241
  controlnet_scheduler = DDIMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="v_prediction", rescale_betas_zero_snr=False, timestep_spacing="trailing")
242
  controlnet_tokenizer = CLIPTokenizer.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="tokenizer")
243
  with ContextManagers([]):
244
  controlnet_text_encoder = CLIPTextModel.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="text_encoder").to(dtype=WEIGHT_DTYPE, device=DEVICE)
245
  controlnet_pipe = DDPMControlnetPipeline(unet=controlnet_unet, controlnet=controlnet_controlnet, scheduler=controlnet_scheduler, text_encoder=controlnet_text_encoder, tokenizer=controlnet_tokenizer)
246
  controlnet_pipe.to(dtype=WEIGHT_DTYPE, device=DEVICE)
247
- controlnet_pipe.current_controlnet_path = default_controlnet_path
248
  print("✓ Shared ControlNet pipeline loaded successfully!")
249
  except Exception as e:
250
- print(f"!!!!!! FATAL: ControlNet Pipeline Loading Failed !!!!!!\nError: {e}")
251
 
252
  # --- 2. Core Logic Functions ---
253
  def swap_controlnet(pipe, target_path):
@@ -257,7 +255,7 @@ def swap_controlnet(pipe, target_path):
257
  pipe.controlnet = ControlNetModel.from_pretrained(target_path, subfolder="controlnet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
258
  pipe.current_controlnet_path = target_path
259
  except Exception as e:
260
- raise gr.Error(f"Failed to load ControlNet model '{target_path}'. Error: {e}")
261
  return pipe
262
 
263
  def swap_t2i_unet(pipe, target_unet_path):
@@ -269,245 +267,210 @@ def swap_t2i_unet(pipe, target_unet_path):
269
  new_unet = UNet2DModel.from_pretrained(target_unet_path, subfolder="unet").to(DEVICE)
270
  pipe.unet = new_unet
271
  current_t2i_unet_path = target_unet_path
272
- print("✅ UNet swapped successfully.")
273
  except Exception as e:
274
- raise gr.Error(f"Failed to load UNet from {target_unet_path}. Error: {e}")
275
  return pipe
276
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
@spaces.GPU(duration=120)
def generate_t2i(prompt, num_inference_steps, colormap_choice):
    """Generate a fluorescence image from a text prompt (Tab: Text-to-Image).

    Swaps in the UNet checkpoint mapped to the prompt (falling back to the
    foundation "FULL" checkpoint for unknown prompts), runs the diffusion
    pipeline, saves the raw output as a temp TIFF, and builds a
    pseudocolored preview.

    Returns:
        (display_image, raw_file_path): PIL preview and path to raw .tif.
    """
    global t2i_pipe
    if t2i_pipe is None: raise gr.Error("Text-to-Image model is not loaded.")
    # Prompt-specific checkpoint lookup; unknown prompts use the foundation model.
    target_model_path = PROMPT_TO_MODEL_MAP.get(prompt)
    if not target_model_path:
        target_model_path = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/FULL-checkpoint-275000"
        print(f"ℹ️ Prompt '{prompt}' not found in predefined list. Using Foundation (Full) model.")
    t2i_pipe = swap_t2i_unet(t2i_pipe, target_model_path)

    print(f"\n🚀 Task started... | Prompt: '{prompt}' | Model: {current_t2i_unet_path}")
    # NOTE(review): pipeline output appears to be a numpy array batch — confirm shape.
    image_np = t2i_pipe(prompt.lower(), generator=None, num_inference_steps=int(num_inference_steps), output_type="np").images

    # 1. Save raw data for the download button.
    raw_file_path = save_temp_tiff(image_np, prefix="t2i_raw")

    # 2. Apply pseudocolor for on-screen display only.
    display_image = apply_pseudocolor(image_np, colormap_choice)

    print("✓ Image generated")
    if SAVE_EXAMPLES:
        # Persist a gallery example once per prompt (opt-in via global flag).
        example_filepath = os.path.join(T2I_EXAMPLE_IMG_DIR, sanitize_prompt_for_filename(prompt))
        if not os.path.exists(example_filepath):
            display_image.save(example_filepath)

    # Return display image AND path to the raw file.
    return display_image, raw_file_path
304
 
305
@spaces.GPU(duration=120)
def run_mask_to_image_generation(mask_file_obj, cell_type, num_images, steps, seed, colormap_choice):
    """Generate synthetic images conditioned on a segmentation mask (Tab: Mask-to-Image).

    Reads a TIF mask, conditions the shared ControlNet pipeline on it, and
    generates `num_images` samples with sequential seeds. Raw samples are
    written to a temp dir and zipped for download; pseudocolored previews
    are returned for the gallery.

    Returns:
        (input_display_image, generated_display_images, zip_filename)
    """
    if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
    if mask_file_obj is None: raise gr.Error("Please upload a segmentation mask TIF file.")
    if not cell_type or not cell_type.strip(): raise gr.Error("Please enter a cell type.")

    # Ensure the Mask-to-Image ControlNet weights are active.
    pipe = swap_controlnet(controlnet_pipe, M2I_CONTROLNET_PATH)
    try:
        mask_np = tifffile.imread(mask_file_obj.name)
    except Exception as e:
        raise gr.Error(f"Failed to read the TIF file. Error: {e}")

    input_display_image = numpy_to_pil(mask_np, "L")
    # Normalize and shape to (1, 1, H, W) for the conditioning input.
    mask_normalized = min_max_norm(mask_np)
    image_tensor = torch.from_numpy(mask_normalized.astype(np.float32))
    image_tensor = image_tensor.unsqueeze(0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
    image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)

    prompt = f"nuclei of {cell_type.strip()}"
    print(f"\nTask started... | Task: Mask-to-Image | Prompt: '{prompt}' | Steps: {steps} | Images: {num_images}")

    generated_display_images = []
    generated_raw_files = []

    # Temp dir holds both individual raw TIFFs and the final zip.
    temp_dir = tempfile.mkdtemp()

    for i in range(int(num_images)):
        # Sequential seeds give distinct-but-reproducible samples.
        current_seed = int(seed) + i
        generator = torch.Generator(device=DEVICE).manual_seed(current_seed)
        with torch.autocast("cuda"):
            output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images

        # Save individual raw file.
        raw_name = f"m2i_sample_{i+1}.tif"
        raw_path = os.path.join(temp_dir, raw_name)

        # float16 is not portable in TIFF; promote to float32 before saving.
        save_data = output_np.astype(np.float32) if output_np.dtype == np.float16 else output_np
        tifffile.imwrite(raw_path, save_data)
        generated_raw_files.append(raw_path)

        # Pseudocolored preview for the gallery.
        pil_image = apply_pseudocolor(output_np, colormap_choice)
        generated_display_images.append(pil_image)
        print(f"✓ Generated image {i+1}/{int(num_images)}")

    # Bundle all raw samples into one downloadable zip.
    zip_filename = os.path.join(temp_dir, "raw_output_images.zip")
    with zipfile.ZipFile(zip_filename, 'w') as zipf:
        for file in generated_raw_files:
            zipf.write(file, os.path.basename(file))

    return input_display_image, generated_display_images, zip_filename
 
359
 
360
@spaces.GPU(duration=120)
def run_super_resolution(low_res_file_obj, controlnet_model_name, prompt, steps, seed, colormap_choice):
    """Super-resolve a 9-channel low-resolution TIF stack (Tab: Super-Resolution).

    Loads the SR ControlNet checkpoint selected by name, validates the
    input stack shape (9, H, W), and runs one diffusion pass conditioned
    on the stack.

    Returns:
        (input_display_image, output_display, raw_file_path)
    """
    if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
    if low_res_file_obj is None: raise gr.Error("Please upload a low-resolution TIF file.")

    target_path = SR_CONTROLNET_MODELS.get(controlnet_model_name)
    if not target_path: raise gr.Error(f"ControlNet model '{controlnet_model_name}' not found.")

    pipe = swap_controlnet(controlnet_pipe, target_path)
    try:
        image_stack_np = tifffile.imread(low_res_file_obj.name)
    except Exception as e:
        raise gr.Error(f"Failed to read the TIF file. Error: {e}")

    # The SR model expects exactly 9 input frames stacked channel-first.
    if image_stack_np.ndim != 3 or image_stack_np.shape[-3] != 9:
        raise gr.Error(f"Invalid TIF shape. Expected 9 channels (shape 9, H, W), but got {image_stack_np.shape}.")

    # Average projection across the 9 frames for the "Input (Avg)" preview.
    average_projection_np = np.mean(image_stack_np, axis=0)
    input_display_image = numpy_to_pil(average_projection_np, "L")

    # 65535.0 scaling assumes 16-bit input data — TODO confirm with data source.
    image_tensor = torch.from_numpy(image_stack_np.astype(np.float32) / 65535.0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
    image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)

    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
    with torch.autocast("cuda"):
        output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images

    # Save raw data for download.
    raw_file_path = save_temp_tiff(output_np, prefix="sr_raw")
    # Pseudocolor for display only.
    output_display = apply_pseudocolor(output_np, colormap_choice)

    return input_display_image, output_display, raw_file_path
393
 
394
@spaces.GPU(duration=120)
def run_denoising(noisy_image_np, image_type, steps, seed, colormap_choice):
    """Denoise a single noisy image with the DN ControlNet (Tab: Denoising).

    The `image_type` selects the text prompt via DN_PROMPT_RULES (falls
    back to a generic 'microscopy image' prompt).

    Returns:
        (original_display, output_display, raw_file_path)
    """
    if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
    if noisy_image_np is None: raise gr.Error("Please upload a noisy image.")

    pipe = swap_controlnet(controlnet_pipe, DN_CONTROLNET_PATH)
    prompt = DN_PROMPT_RULES.get(image_type, 'microscopy image')
    print(f"\nTask started... | Task: Denoising | Prompt: '{prompt}' | Steps: {steps}")

    # 255.0 scaling: Gradio delivers the upload as 8-bit grayscale.
    image_tensor = torch.from_numpy(noisy_image_np.astype(np.float32) / 255.0)
    image_tensor = image_tensor.unsqueeze(0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
    image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)

    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
    with torch.autocast("cuda"):
        output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images

    # Save raw data for download.
    raw_file_path = save_temp_tiff(output_np, prefix="dn_raw")
    # Pseudocolor for display only.
    output_display = apply_pseudocolor(output_np, colormap_choice)

    return numpy_to_pil(noisy_image_np, "L"), output_display, raw_file_path
417
 
 
418
@spaces.GPU(duration=120)
def run_segmentation(input_image_np, model_name, diameter, flow_threshold, cellprob_threshold):
    """Segment cells with a Cellpose model and return a red overlay (Tab: Cell Segmentation).

    Args:
        input_image_np: Uploaded grayscale image as a numpy array.
        model_name: Key into SEG_MODELS.
        diameter: Expected cell diameter; 0 means use the model's
            learned `diam_labels` auto-estimate.
        flow_threshold, cellprob_threshold: Cellpose eval thresholds.

    Returns:
        (original_as_L, overlay_as_RGB) PIL images.
    """
    # Segmentation output stays RGB; no pseudocolor option on this tab.
    if input_image_np is None: raise gr.Error("Please upload an image to segment.")
    model_path = SEG_MODELS.get(model_name)
    if not model_path: raise gr.Error(f"Segmentation model '{model_name}' not found.")

    print(f"\nTask started... | Task: Cell Segmentation | Model: '{model_name}'")
    try:
        use_gpu = torch.cuda.is_available()
        model = cellpose_models.CellposeModel(gpu=use_gpu, pretrained_model=model_path)
    except Exception as e:
        raise gr.Error(f"Failed to load Cellpose model. Error: {e}")

    # diameter == 0 is the UI convention for "auto".
    diameter_to_use = model.diam_labels if diameter == 0 else float(diameter)
    try:
        masks, _, _ = model.eval([input_image_np], channels=[0, 0], diameter=diameter_to_use, flow_threshold=flow_threshold, cellprob_threshold=cellprob_threshold)
        mask_output = masks[0]
    except Exception as e:
        raise gr.Error(f"Cellpose model evaluation failed. Error: {e}")

    # Blend a dark-red mask layer over the original (60/40 alpha blend).
    original_rgb = numpy_to_pil(input_image_np, "RGB")
    original_rgb_np = np.array(original_rgb)
    red_mask_layer = np.zeros_like(original_rgb_np)
    red_mask_layer[mask_output > 0] = [139, 0, 0]
    blended_image_np = ((0.6 * original_rgb_np + 0.4 * red_mask_layer).astype(np.uint8))

    return numpy_to_pil(input_image_np, "L"), numpy_to_pil(blended_image_np, "RGB")
446
 
447
@spaces.GPU(duration=120)
def run_classification(input_image_np, model_name):
    """Classify a cell image with a fine-tuned ResNet-50 (Tab: Classification).

    Loads `best_resnet50.pth` from the selected checkpoint directory,
    replaces the final FC layer to match CLS_CLASS_NAMES, and returns
    per-class softmax confidences for the gr.Label component.

    Returns:
        (input_as_L, confidences_dict)
    """
    if input_image_np is None: raise gr.Error("Please upload an image to classify.")
    model_dir = CLS_MODEL_PATHS.get(model_name)
    if not model_dir: raise gr.Error(f"Classification model '{model_name}' not found.")
    model_path = os.path.join(model_dir, "best_resnet50.pth")

    print(f"\nTask started... | Task: Classification | Model: '{model_name}'")
    try:
        # weights=None: architecture only; trained weights come from the checkpoint.
        model = models.resnet50(weights=None)
        model.fc = nn.Linear(model.fc.in_features, len(CLS_CLASS_NAMES))
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        model.to(DEVICE).eval()
    except Exception as e:
        raise gr.Error(f"Failed to load classification model. Error: {e}")

    # Normalization matches training-time preprocessing — TODO confirm against training script.
    input_pil = numpy_to_pil(input_image_np, "RGB")
    transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
    input_tensor = transform_test(input_pil).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        outputs = model(input_tensor)
        probabilities = F.softmax(outputs, dim=1).squeeze().cpu().numpy()

    confidences = {name: float(prob) for name, prob in zip(CLS_CLASS_NAMES, probabilities)}
    return numpy_to_pil(input_image_np, "L"), confidences
473
-
474
 
475
  # --- 3. Gradio UI Layout ---
476
  print("Building Gradio interface...")
477
- for d in [M2I_EXAMPLE_IMG_DIR, T2I_EXAMPLE_IMG_DIR, SR_EXAMPLE_IMG_DIR, DN_EXAMPLE_IMG_DIR, SEG_EXAMPLE_IMG_DIR, CLS_EXAMPLE_IMG_DIR]:
478
- os.makedirs(d, exist_ok=True)
479
-
480
- # --- Load examples ---
481
- filename_to_prompt_map = { sanitize_prompt_for_filename(prompt): prompt for prompt in T2I_PROMPTS }
482
- t2i_gallery_examples = []
483
- for filename in os.listdir(T2I_EXAMPLE_IMG_DIR):
484
- if filename in filename_to_prompt_map:
485
- t2i_gallery_examples.append((os.path.join(T2I_EXAMPLE_IMG_DIR, filename), filename_to_prompt_map[filename]))
486
-
487
- def load_image_examples(example_dir, is_stack=False):
488
- examples = []
489
- if not os.path.exists(example_dir): return examples
490
- for f in sorted(os.listdir(example_dir)):
491
  if f.lower().endswith(('.tif', '.tiff', '.png', '.jpg')):
492
- filepath = os.path.join(example_dir, f)
493
  try:
494
- if f.lower().endswith(('.tif', '.tiff')): img_np = tifffile.imread(filepath)
495
- else: img_np = np.array(Image.open(filepath).convert("L"))
496
- if is_stack and img_np.ndim == 3: img_np = np.mean(img_np, axis=0)
497
- examples.append((numpy_to_pil(img_np, "L"), filepath))
498
  except: pass
499
- return examples
500
 
501
- m2i_gallery_examples = load_image_examples(M2I_EXAMPLE_IMG_DIR)
502
- sr_gallery_examples = load_image_examples(SR_EXAMPLE_IMG_DIR, is_stack=True)
503
- dn_gallery_examples = load_image_examples(DN_EXAMPLE_IMG_DIR)
504
- seg_gallery_examples = load_image_examples(SEG_EXAMPLE_IMG_DIR)
505
- cls_gallery_examples = load_image_examples(CLS_EXAMPLE_IMG_DIR)
506
-
507
- # --- Universal event handlers ---
508
- def select_example_prompt(evt: gr.SelectData): return evt.value['caption']
509
- def select_example_input_file(evt: gr.SelectData): return evt.value['caption']
510
 
 
511
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
512
  with gr.Row():
513
  gr.Image(value=LOGO_PATH, width=300, height=200, container=False, interactive=False, show_download_button=False, show_fullscreen_button=False)
@@ -516,122 +479,160 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
516
  with gr.Tabs():
517
  # --- TAB 1: Text-to-Image ---
518
  with gr.Tab("Text-to-Image Generation", id="txt2img"):
 
 
 
519
  with gr.Row(variant="panel"):
520
  with gr.Column(scale=1, min_width=350):
521
- t2i_prompt_input = gr.Dropdown(choices=T2I_PROMPTS, value=T2I_PROMPTS[0], label="Search or Type a Prompt", filterable=True, allow_custom_value=True)
522
- t2i_steps_slider = gr.Slider(minimum=10, maximum=200, step=1, value=50, label="Inference Steps")
523
- # New: Color Selector
524
- t2i_color_input = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor (for Display)")
525
- t2i_generate_button = gr.Button("Generate", variant="primary")
526
  with gr.Column(scale=2):
527
- t2i_generated_output = gr.Image(label="Generated Image (Pseudocolor)", type="pil", interactive=False)
528
- # New: Raw Download Button
529
- t2i_raw_download = gr.DownloadButton(label="Download Raw Output (.tif)", visible=True)
530
- t2i_gallery = gr.Gallery(value=t2i_gallery_examples, label="Examples", columns=6, object_fit="contain", height="auto")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
531
 
532
  # --- TAB 2: Super-Resolution ---
533
  with gr.Tab("Super-Resolution", id="super_res"):
 
534
  with gr.Row(variant="panel"):
535
  with gr.Column(scale=1, min_width=350):
536
- sr_input_file = gr.File(label="Upload 9-Channel TIF Stack", file_types=['.tif', '.tiff'])
537
- sr_model_selector = gr.Dropdown(choices=list(SR_CONTROLNET_MODELS.keys()), value=list(SR_CONTROLNET_MODELS.keys())[-1], label="Select Super-Resolution Model")
538
- sr_prompt_input = gr.Textbox(label="Prompt", value="F-actin of COS-7", interactive=False)
539
- sr_steps_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Inference Steps")
540
- sr_seed_input = gr.Number(label="Seed", value=42)
541
- sr_color_input = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor (for Display)")
542
- sr_generate_button = gr.Button("Generate Super-Resolution", variant="primary")
543
  with gr.Column(scale=2):
544
  with gr.Row():
545
- sr_input_display = gr.Image(label="Input (Avg)", type="pil", interactive=False)
546
- sr_output_image = gr.Image(label="Super-Resolved Image", type="pil", interactive=False)
547
- sr_raw_download = gr.DownloadButton(label="Download Raw Output (.tif)", visible=True)
548
- sr_gallery = gr.Gallery(value=sr_gallery_examples, label="Examples", columns=6, object_fit="contain", height="auto")
 
 
 
 
 
 
 
 
 
 
 
 
549
 
550
  # --- TAB 3: Denoising ---
551
  with gr.Tab("Denoising", id="denoising"):
 
552
  with gr.Row(variant="panel"):
553
  with gr.Column(scale=1, min_width=350):
554
- dn_input_image = gr.Image(type="numpy", label="Upload Noisy Image", image_mode="L")
555
- dn_image_type_selector = gr.Dropdown(choices=list(DN_PROMPT_RULES.keys()), value='MICE', label="Select Image Type")
556
- dn_steps_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Inference Steps")
557
- dn_seed_input = gr.Number(label="Seed", value=42)
558
- dn_color_input = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor (for Display)")
559
- dn_generate_button = gr.Button("Denoise Image", variant="primary")
560
  with gr.Column(scale=2):
561
  with gr.Row():
562
- dn_original_display = gr.Image(label="Original", type="pil", interactive=False)
563
- dn_output_image = gr.Image(label="Denoised Image", type="pil", interactive=False)
564
- dn_raw_download = gr.DownloadButton(label="Download Raw Output (.tif)", visible=True)
565
- dn_gallery = gr.Gallery(value=dn_gallery_examples, label="Examples", columns=6, object_fit="contain", height="auto")
 
 
 
 
 
 
 
 
 
 
 
566
 
567
  # --- TAB 4: Mask-to-Image ---
568
  with gr.Tab("Mask-to-Image", id="mask2img"):
 
569
  with gr.Row(variant="panel"):
570
  with gr.Column(scale=1, min_width=350):
571
- m2i_input_file = gr.File(label="Upload Segmentation Mask (.tif)", file_types=['.tif', '.tiff'])
572
- m2i_cell_type_input = gr.Textbox(label="Cell Type", placeholder="e.g., CoNSS, HeLa")
573
- m2i_num_images_slider = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of Images")
574
- m2i_steps_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Inference Steps")
575
- m2i_seed_input = gr.Number(label="Seed", value=42)
576
- m2i_color_input = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor (for Display)")
577
- m2i_generate_button = gr.Button("Generate Samples", variant="primary")
578
  with gr.Column(scale=2):
579
- m2i_output_gallery = gr.Gallery(label="Generated Samples", columns=5, object_fit="contain", height="auto")
580
- m2i_raw_download = gr.DownloadButton(label="Download All Raw Samples (.zip)", visible=True)
581
- m2i_input_display = gr.Image(label="Input Mask", type="pil", interactive=False)
582
- m2i_gallery = gr.Gallery(value=m2i_gallery_examples, label="Examples", columns=6, object_fit="contain", height="auto")
 
 
 
 
 
 
 
 
 
 
 
583
 
584
  # --- TAB 5: Cell Segmentation ---
585
  with gr.Tab("Cell Segmentation", id="segmentation"):
586
  with gr.Row(variant="panel"):
587
  with gr.Column(scale=1, min_width=350):
588
- seg_input_image = gr.Image(type="numpy", label="Upload Image", image_mode="L")
589
- seg_model_selector = gr.Dropdown(choices=list(SEG_MODELS.keys()), value=list(SEG_MODELS.keys())[0], label="Model")
590
- seg_diameter_input = gr.Number(label="Cell Diameter (0=auto)", value=30)
591
- seg_flow_slider = gr.Slider(minimum=0.0, maximum=3.0, step=0.1, value=0.4, label="Flow Threshold")
592
- seg_cellprob_slider = gr.Slider(minimum=-6.0, maximum=6.0, step=0.5, value=0.0, label="Cell Probability Threshold")
593
- seg_generate_button = gr.Button("Segment Cells", variant="primary")
594
  with gr.Column(scale=2):
595
  with gr.Row():
596
- seg_original_display = gr.Image(label="Original", type="pil", interactive=False)
597
- seg_output_image = gr.Image(label="Segmented Overlay", type="pil", interactive=False)
598
- seg_gallery = gr.Gallery(value=seg_gallery_examples, label="Examples", columns=6, object_fit="contain", height="auto")
599
-
 
 
600
  # --- TAB 6: Classification ---
601
  with gr.Tab("Classification", id="classification"):
602
  with gr.Row(variant="panel"):
603
  with gr.Column(scale=1, min_width=350):
604
- cls_input_image = gr.Image(type="numpy", label="Upload Image", image_mode="L")
605
- cls_model_selector = gr.Dropdown(choices=list(CLS_MODEL_PATHS.keys()), value=list(CLS_MODEL_PATHS.keys())[0], label="Model")
606
- cls_generate_button = gr.Button("Classify Image", variant="primary")
607
  with gr.Column(scale=2):
608
- cls_original_display = gr.Image(label="Input Image", type="pil", interactive=False)
609
- cls_output_label = gr.Label(label="Results", num_top_classes=len(CLS_CLASS_NAMES))
610
- cls_gallery = gr.Gallery(value=cls_gallery_examples, label="Examples", columns=6, object_fit="contain", height="auto")
611
-
612
-
613
- # --- Event Handlers ---
614
- m2i_generate_button.click(fn=run_mask_to_image_generation, inputs=[m2i_input_file, m2i_cell_type_input, m2i_num_images_slider, m2i_steps_slider, m2i_seed_input, m2i_color_input], outputs=[m2i_input_display, m2i_output_gallery, m2i_raw_download])
615
- m2i_gallery.select(fn=select_example_input_file, outputs=m2i_input_file)
616
-
617
- t2i_generate_button.click(fn=generate_t2i, inputs=[t2i_prompt_input, t2i_steps_slider, t2i_color_input], outputs=[t2i_generated_output, t2i_raw_download])
618
- t2i_gallery.select(fn=select_example_prompt, outputs=t2i_prompt_input)
619
-
620
- sr_model_selector.change(fn=update_sr_prompt, inputs=sr_model_selector, outputs=sr_prompt_input)
621
- sr_generate_button.click(fn=run_super_resolution, inputs=[sr_input_file, sr_model_selector, sr_prompt_input, sr_steps_slider, sr_seed_input, sr_color_input], outputs=[sr_input_display, sr_output_image, sr_raw_download])
622
- sr_gallery.select(fn=select_example_input_file, outputs=sr_input_file)
623
-
624
- dn_generate_button.click(fn=run_denoising, inputs=[dn_input_image, dn_image_type_selector, dn_steps_slider, dn_seed_input, dn_color_input], outputs=[dn_original_display, dn_output_image, dn_raw_download])
625
- dn_gallery.select(fn=select_example_input_file, outputs=dn_input_image)
626
-
627
- seg_generate_button.click(fn=run_segmentation, inputs=[seg_input_image, seg_model_selector, seg_diameter_input, seg_flow_slider, seg_cellprob_slider], outputs=[seg_original_display, seg_output_image])
628
- seg_gallery.select(fn=select_example_input_file, outputs=seg_input_image)
629
-
630
- cls_generate_button.click(fn=run_classification, inputs=[cls_input_image, cls_model_selector], outputs=[cls_original_display, cls_output_label])
631
- cls_gallery.select(fn=select_example_input_file, outputs=cls_input_image)
632
-
633
 
634
- # --- 4. Launch Application ---
635
  if __name__ == "__main__":
636
- print("Interface built. Launching server...")
637
  demo.launch()
 
1
  import torch
2
  import numpy as np
3
  import gradio as gr
4
+ from PIL import Image
5
  import os
6
  import json
 
 
7
  import tifffile
8
  import re
 
9
  from torchvision import transforms, models
 
10
  import shutil
11
  import time
12
  import torch.nn as nn
 
16
  import tempfile
17
  import zipfile
18
  import matplotlib.cm as cm
19
+ import matplotlib.pyplot as plt
20
+ import io
21
 
22
  # --- Imports from both scripts ---
23
  from diffusers import DDPMScheduler, DDIMScheduler
24
  from transformers import CLIPTextModel, CLIPTokenizer
 
25
  from transformers.utils import ContextManagers
26
 
27
  # --- Custom Model Imports ---
 
31
  from models.unet_2d_condition import UNet2DConditionModel
32
  from models.pipeline_controlnet import DDPMControlnetPipeline
33
 
34
+ # --- Segmentation Imports ---
35
  from cellpose import models as cellpose_models
 
36
  from huggingface_hub import snapshot_download
37
 
38
  # --- 0. Configuration & Constants ---
39
  hf_token = os.environ.get("HF_TOKEN")
 
40
  MODEL_TITLE = "🔬 FluoGen: AI-Powered Fluorescence Microscopy Suite"
41
  MODEL_DESCRIPTION = """
42
  **Paper**: *Generative AI empowering fluorescence microscopy imaging and analysis*
43
  <br>
44
+ Select a task below.
45
+ **Note**: The "Pseudocolor" option can be adjusted *after* generation for instant visualization. Use the "Download Raw Output" button to get the scientific 16-bit/Float data.
46
  """
47
  DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
48
  WEIGHT_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
49
  LOGO_PATH = "utils/logo2_transparent.png"
 
 
50
  SAVE_EXAMPLES = False
51
 
52
  # --- Base directory for all models ---
53
  REPO_ID = "FluoGen-Group/FluoGen-demo-test-ckpts"
54
  MODELS_ROOT_DIR = snapshot_download(repo_id=REPO_ID, token=hf_token)
55
 
56
+ # --- Paths Config ---
57
  M2I_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_M2I/checkpoint-30000"
58
  M2I_EXAMPLE_IMG_DIR = "example_images_m2i"
 
 
59
  T2I_EXAMPLE_IMG_DIR = "example_images"
60
  T2I_PRETRAINED_MODEL_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5"
61
  T2I_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"
 
 
62
  CONTROLNET_CLIP_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5"
63
  CONTROLNET_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"
64
 
 
65
  SR_CONTROLNET_MODELS = {
66
  "Checkpoint ER": f"{MODELS_ROOT_DIR}/ControlNet_SR/ER/checkpoint-30000",
67
  "Checkpoint Microtubules": f"{MODELS_ROOT_DIR}/ControlNet_SR/Microtubules/checkpoint-72500",
 
69
  "Checkpoint F-actin": f"{MODELS_ROOT_DIR}/ControlNet_SR/F-actin/checkpoint-35000",
70
  }
71
  SR_EXAMPLE_IMG_DIR = "example_images_sr"
 
 
72
  DN_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_DN/checkpoint-10000"
73
  DN_PROMPT_RULES = {'MICE': 'mouse brain tissues', 'FISH': 'zebrafish embryos', 'BPAE_B': 'nucleus of BPAE', 'BPAE_R': 'mitochondria of BPAE', 'BPAE_G': 'F-actin of BPAE'}
74
  DN_EXAMPLE_IMG_DIR = "example_images_dn"
75
 
 
76
  SEG_MODELS = {
77
  "DynamicNet Model": f"{MODELS_ROOT_DIR}/Cellpose/DynamicNet_baseline/CP_dynamic_ten_epoch_0100",
78
  "DynamicNet Model + FluoGen": f"{MODELS_ROOT_DIR}/Cellpose/DynamicNet_FluoGen/CP_dynamic_epoch_0300",
 
81
  }
82
  SEG_EXAMPLE_IMG_DIR = "example_images_seg"
83
 
 
84
  CLS_MODEL_PATHS = OrderedDict({
85
  "5shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot_re",
86
  "5shot+FluoGen": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot_aug_re",
 
89
  CLS_EXAMPLE_IMG_DIR = "example_images_cls"
90
 
91
  # --- Constants for Visualization ---
92
+ COLOR_MAPS = ["Grayscale", "Green (GFP)", "Red (RFP)", "Blue (DAPI)", "Magenta", "Cyan", "Yellow", "Fire", "Viridis", "Inferno", "Magma", "Plasma"]
93
 
94
  # --- Helper Functions ---
95
  def sanitize_prompt_for_filename(prompt):
 
101
  if max_val - min_val < 1e-8: return np.zeros_like(x)
102
  return (x - min_val) / (max_val - min_val)
103
 
104
def generate_colorbar_preview(color_name):
    """Generate a 256x30 RGB PIL image previewing the selected colormap.

    Args:
        color_name: One of COLOR_MAPS — "Grayscale", a single/dual channel
            fluorescence color (e.g. "Green (GFP)", "Magenta"), or a
            matplotlib colormap name such as "Viridis" ("Fire" maps to
            matplotlib's "gnuplot2").

    Returns:
        PIL.Image.Image: a horizontal 0→max gradient bar rendered in the
        requested colormap; falls back to grayscale for unknown names.
    """
    # Horizontal 0..1 ramp, shape (1, 256): one row, 256 intensity steps.
    gradient = np.linspace(0, 1, 256).reshape(1, 256)

    if color_name == "Grayscale":
        return Image.fromarray((gradient * 255).astype(np.uint8)).convert("RGB").resize((256, 30))

    rgb = np.zeros((1, 256, 3))

    # Simple fluorescence pseudocolors are just the ramp copied into one
    # or two RGB channels.
    channel_map = {
        "Green (GFP)": (1,),
        "Red (RFP)": (0,),
        "Blue (DAPI)": (2,),
        "Magenta": (0, 2),
        "Cyan": (1, 2),
        "Yellow": (0, 1),
    }
    if color_name in channel_map:
        for ch in channel_map[color_name]:
            rgb[..., ch] = gradient
    else:
        # Anything else is treated as a matplotlib colormap name.
        mpl_map_name = "gnuplot2" if color_name == "Fire" else color_name.lower()
        try:
            cmap = cm.get_cmap(mpl_map_name)
            rgb = cmap(gradient)[..., :3]
        # ValueError: unknown colormap name. AttributeError: cm.get_cmap
        # was removed in matplotlib >= 3.9. Either way, fall back to
        # grayscale instead of crashing the UI callback. (The original
        # bare `except:` also swallowed KeyboardInterrupt/SystemExit.)
        except (ValueError, AttributeError):
            return generate_colorbar_preview("Grayscale")

    img_np = (rgb * 255).astype(np.uint8)
    return Image.fromarray(img_np).resize((256, 30))
132
def apply_pseudocolor(image_np, color_name="Grayscale"):
    """Apply a pseudocolor to a single-channel numpy image.

    The image is squeezed and min-max normalized to [0, 1] before coloring,
    so the display is always full contrast regardless of input dtype/range.

    Args:
        image_np: 2-D (or squeezable) numpy array, any numeric dtype;
            None is passed through as None (empty gr.State).
        color_name: entry from COLOR_MAPS; unknown matplotlib names fall
            back to grayscale.

    Returns:
        PIL Image in RGB, or None when image_np is None.
    """
    if image_np is None: return None

    # Normalize to 0-1 for processing.
    norm_img = min_max_norm(np.squeeze(image_np))

    if color_name == "Grayscale":
        return Image.fromarray((norm_img * 255).astype(np.uint8)).convert("RGB")

    h, w = norm_img.shape
    rgb = np.zeros((h, w, 3), dtype=np.float32)

    # Simple fluorescence pseudocolors: copy intensity into RGB channels.
    channel_map = {
        "Green (GFP)": (1,),
        "Red (RFP)": (0,),
        "Blue (DAPI)": (2,),
        "Magenta": (0, 2),
        "Cyan": (1, 2),
        "Yellow": (0, 1),
    }
    if color_name in channel_map:
        for ch in channel_map[color_name]:
            rgb[..., ch] = norm_img
    else:
        # Anything else is treated as a matplotlib colormap name.
        mpl_map_name = "gnuplot2" if color_name == "Fire" else color_name.lower()
        try:
            cmap = cm.get_cmap(mpl_map_name)
            colored = cmap(norm_img)
            rgb = colored[..., :3]
        # ValueError: unknown colormap; AttributeError: cm.get_cmap removed
        # in matplotlib >= 3.9. Fall back to grayscale instead of the
        # original bare `except:` which hid every error.
        except (ValueError, AttributeError):
            return apply_pseudocolor(image_np, "Grayscale")

    return Image.fromarray((rgb * 255).astype(np.uint8))
166
 
167
def save_temp_tiff(image_np, prefix="output"):
    """Write a numpy image to a named temporary .tif and return its path.

    float16 arrays are promoted to float32 because TIFF/tifffile has no
    half-precision support; all other dtypes are written as-is, so the
    downloadable file preserves the raw scientific data.

    Args:
        image_np: numpy array to persist.
        prefix: filename prefix for the temp file.

    Returns:
        str: path to the written .tif (delete=False keeps it on disk so
        Gradio can serve it for download).
    """
    tfile = tempfile.NamedTemporaryFile(delete=False, suffix=".tif", prefix=f"{prefix}_")
    # Close our handle before tifffile reopens the path by name: required
    # on Windows (open files can't be reopened), harmless elsewhere.
    tfile.close()
    save_data = image_np.astype(np.float32) if image_np.dtype == np.float16 else image_np
    tifffile.imwrite(tfile.name, save_data)
    return tfile.name
173
 
 
177
  if target_mode == "L" and image_np.mode != "L": return image_np.convert("L")
178
  return image_np
179
  squeezed_np = np.squeeze(image_np);
180
+ if squeezed_np.dtype == np.uint8: image_8bit = squeezed_np
 
181
  else:
182
  normalized_np = min_max_norm(squeezed_np)
183
  image_8bit = (normalized_np * 255).astype(np.uint8)
 
211
  if isinstance(data, list):
212
  combined_prompts.extend(data)
213
  for p in data: PROMPT_TO_MODEL_MAP[p] = cat["model"]
214
+ except Exception: pass
 
215
  if not combined_prompts: return ["F-actin of COS-7", "ER of COS-7"]
216
  return combined_prompts
217
  T2I_PROMPTS = load_all_prompts()
 
230
  current_t2i_unet_path = T2I_UNET_PATH
231
  print("✓ Text-to-Image model loaded successfully!")
232
  except Exception as e:
233
+ print(f"FATAL: Text-to-Image Model Loading Failed: {e}")
234
 
235
# Load the shared ControlNet pipeline once at startup. It is reused by the
# mask-to-image, super-resolution and denoising tabs; only the ControlNet
# weights are swapped per task (see swap_controlnet()).
try:
    print("Loading shared ControlNet pipeline...")
    # UNet backbone and the default (mask-to-image) ControlNet weights.
    controlnet_unet = UNet2DConditionModel.from_pretrained(CONTROLNET_UNET_PATH, subfolder="unet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
    controlnet_controlnet = ControlNetModel.from_pretrained(M2I_CONTROLNET_PATH, subfolder="controlnet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
    # DDIM with v-prediction and trailing timestep spacing — must match how
    # the checkpoints were trained; do not change independently.
    controlnet_scheduler = DDIMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="v_prediction", rescale_betas_zero_snr=False, timestep_spacing="trailing")
    controlnet_tokenizer = CLIPTokenizer.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="tokenizer")
    with ContextManagers([]):
        controlnet_text_encoder = CLIPTextModel.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="text_encoder").to(dtype=WEIGHT_DTYPE, device=DEVICE)
    controlnet_pipe = DDPMControlnetPipeline(unet=controlnet_unet, controlnet=controlnet_controlnet, scheduler=controlnet_scheduler, text_encoder=controlnet_text_encoder, tokenizer=controlnet_tokenizer)
    controlnet_pipe.to(dtype=WEIGHT_DTYPE, device=DEVICE)
    # Track which ControlNet is currently mounted so swap_controlnet() can
    # skip a reload when the same checkpoint is requested again.
    controlnet_pipe.current_controlnet_path = M2I_CONTROLNET_PATH
    print("✓ Shared ControlNet pipeline loaded successfully!")
except Exception as e:
    # Deliberately non-fatal: the app still starts so the other tabs work;
    # task handlers raise gr.Error when controlnet_pipe is None.
    print(f"FATAL: ControlNet Pipeline Loading Failed: {e}")
249
 
250
  # --- 2. Core Logic Functions ---
251
  def swap_controlnet(pipe, target_path):
 
255
  pipe.controlnet = ControlNetModel.from_pretrained(target_path, subfolder="controlnet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
256
  pipe.current_controlnet_path = target_path
257
  except Exception as e:
258
+ raise gr.Error(f"Failed to load ControlNet model. Error: {e}")
259
  return pipe
260
 
261
  def swap_t2i_unet(pipe, target_unet_path):
 
267
  new_unet = UNet2DModel.from_pretrained(target_unet_path, subfolder="unet").to(DEVICE)
268
  pipe.unet = new_unet
269
  current_t2i_unet_path = target_unet_path
 
270
  except Exception as e:
271
+ raise gr.Error(f"Failed to load UNet. Error: {e}")
272
  return pipe
273
 
274
+ # --- Dynamic Color Update Functions ---
275
def update_single_image_color(raw_np_state, color_name):
    """Re-render a cached raw image and its colorbar in a new colormap.

    Called on the Pseudocolor dropdown's .change event; reads the raw
    numpy array stored in gr.State so no regeneration is needed.
    """
    if raw_np_state is None:
        # Nothing has been generated yet — clear both outputs.
        return None, None
    return (
        apply_pseudocolor(raw_np_state, color_name),
        generate_colorbar_preview(color_name),
    )
280
+
281
def update_gallery_color(raw_list_state, color_name):
    """Re-render every cached raw image in a gallery with a new colormap.

    Args:
        raw_list_state: list of raw numpy arrays held in gr.State, or None
            if nothing has been generated yet.
        color_name: entry from COLOR_MAPS.

    Returns:
        (list of recolored PIL images or None, colorbar preview or None).
    """
    if raw_list_state is None: return None, None
    # Comprehension instead of the original append loop (same order/result).
    new_gallery = [apply_pseudocolor(img_np, color_name) for img_np in raw_list_state]
    return new_gallery, generate_colorbar_preview(color_name)
288
+
289
+ # --- Generation Functions ---
290
@spaces.GPU(duration=120)
def generate_t2i(prompt, num_inference_steps, current_color):
    """Generate a fluorescence image from a text prompt.

    Args:
        prompt: dropdown or free-text prompt; selects the UNet checkpoint
            via PROMPT_TO_MODEL_MAP and (lowercased) conditions generation.
        num_inference_steps: diffusion steps (coerced to int).
        current_color: pseudocolor name, applied to the on-screen preview
            only — the raw array is untouched.

    Returns:
        (display PIL image, path to raw .tif for download, raw numpy array
        for the gr.State cache, colorbar preview PIL image).

    Raises:
        gr.Error: if the T2I pipeline failed to load at startup.
    """
    global t2i_pipe
    if t2i_pipe is None: raise gr.Error("Text-to-Image model is not loaded.")
    # Prompt-specific checkpoint when known; otherwise the full model.
    target_model_path = PROMPT_TO_MODEL_MAP.get(prompt, f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/FULL-checkpoint-275000")
    t2i_pipe = swap_t2i_unet(t2i_pipe, target_model_path)

    print(f"\n🚀 T2I Task started... | Prompt: '{prompt}'")
    image_np = t2i_pipe(prompt.lower(), generator=None, num_inference_steps=int(num_inference_steps), output_type="np").images

    # Raw data is saved/cached as-is; only the preview gets pseudocolored.
    raw_file_path = save_temp_tiff(image_np, prefix="t2i_raw")
    display_image = apply_pseudocolor(image_np, current_color)
    colorbar_img = generate_colorbar_preview(current_color)

    if SAVE_EXAMPLES:
        # Dev-time convenience: persist the first preview per prompt as a
        # gallery example (never overwrites an existing one).
        example_filepath = os.path.join(T2I_EXAMPLE_IMG_DIR, sanitize_prompt_for_filename(prompt))
        if not os.path.exists(example_filepath): display_image.save(example_filepath)

    # Return: Display, Raw Path, Raw State, Colorbar
    return display_image, raw_file_path, image_np, colorbar_img
310
 
311
@spaces.GPU(duration=120)
def run_mask_to_image_generation(mask_file_obj, cell_type, num_images, steps, seed, current_color):
    """Generate `num_images` fluorescence images conditioned on a mask.

    The uploaded TIF mask is min-max normalized, resized to 512x512 and
    used as the ControlNet condition with prompt "nuclei of <cell_type>".
    Each sample uses seed + i, so runs are reproducible yet distinct.

    Returns:
        (input-mask preview PIL, list of preview PILs for the gallery,
        path to a ZIP of the raw per-sample .tif files, list of raw numpy
        arrays for gr.State, colorbar preview PIL).

    Raises:
        gr.Error: pipeline not loaded, no file uploaded, or unreadable TIF.
    """
    if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
    if mask_file_obj is None: raise gr.Error("Please upload a segmentation mask.")

    pipe = swap_controlnet(controlnet_pipe, M2I_CONTROLNET_PATH)
    try: mask_np = tifffile.imread(mask_file_obj.name)
    except Exception as e: raise gr.Error(f"Failed to read TIF. Error: {e}")

    input_display = numpy_to_pil(mask_np, "L")
    mask_normalized = min_max_norm(mask_np)
    # Shape (1, 1, H, W): batch and channel dims for the ControlNet condition.
    image_tensor = torch.from_numpy(mask_normalized.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
    image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)

    prompt = f"nuclei of {cell_type.strip()}"
    print(f"\nM2I Task started... | Prompt: '{prompt}'")

    generated_raw_list = []
    generated_display_images = []
    generated_raw_files = []
    temp_dir = tempfile.mkdtemp()

    for i in range(int(num_images)):
        # Distinct-but-deterministic seed per sample.
        generator = torch.Generator(device=DEVICE).manual_seed(int(seed) + i)
        with torch.autocast("cuda"):
            output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images

        # Save Raw State
        generated_raw_list.append(output_np)

        # Save Temp File
        raw_name = f"m2i_sample_{i+1}.tif"
        raw_path = os.path.join(temp_dir, raw_name)
        # float16 is not a TIFF dtype — promote to float32 before writing.
        save_data = output_np.astype(np.float32) if output_np.dtype == np.float16 else output_np
        tifffile.imwrite(raw_path, save_data)
        generated_raw_files.append(raw_path)

        # Display
        generated_display_images.append(apply_pseudocolor(output_np, current_color))

    # ZIP
    zip_filename = os.path.join(temp_dir, "raw_output_images.zip")
    with zipfile.ZipFile(zip_filename, 'w') as zipf:
        for file in generated_raw_files: zipf.write(file, os.path.basename(file))

    colorbar_img = generate_colorbar_preview(current_color)
    return input_display, generated_display_images, zip_filename, generated_raw_list, colorbar_img
358
 
359
@spaces.GPU(duration=120)
def run_super_resolution(low_res_file_obj, controlnet_model_name, prompt, steps, seed, current_color):
    """Super-resolve a 9-frame low-resolution TIF stack.

    The stack (shape (9, H, W)) conditions the SR ControlNet selected by
    `controlnet_model_name`; the mean projection of the stack is shown as
    the input preview.

    Returns:
        (input preview PIL, output preview PIL, path to raw .tif, raw
        numpy array for gr.State, colorbar preview PIL).

    Raises:
        gr.Error: pipeline not loaded, no/unreadable file, or a stack that
        is not 3-D with 9 frames on the leading axis.
    """
    if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
    if low_res_file_obj is None: raise gr.Error("Please upload a file.")

    target_path = SR_CONTROLNET_MODELS.get(controlnet_model_name)
    pipe = swap_controlnet(controlnet_pipe, target_path)

    try: image_stack_np = tifffile.imread(low_res_file_obj.name)
    except Exception as e: raise gr.Error(f"Failed to read TIF. Error: {e}")

    if image_stack_np.ndim != 3 or image_stack_np.shape[-3] != 9:
        raise gr.Error(f"Invalid TIF shape. Expected 9 channels, got {image_stack_np.shape}.")

    # Preview: average the 9 frames into one 2-D image.
    input_display = numpy_to_pil(np.mean(image_stack_np, axis=0), "L")

    # /65535.0 assumes the stack is 16-bit (uint16) data — TODO confirm.
    image_tensor = torch.from_numpy(image_stack_np.astype(np.float32) / 65535.0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
    image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)

    print(f"\nSR Task started...")
    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
    with torch.autocast("cuda"):
        output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images

    raw_file_path = save_temp_tiff(output_np, prefix="sr_raw")
    output_display = apply_pseudocolor(output_np, current_color)
    colorbar_img = generate_colorbar_preview(current_color)

    return input_display, output_display, raw_file_path, output_np, colorbar_img
387
 
388
@spaces.GPU(duration=120)
def run_denoising(noisy_image_np, image_type, steps, seed, current_color):
    """Denoise a single-channel image with the denoising ControlNet.

    Args:
        noisy_image_np: 2-D numpy image from the Gradio upload.
        image_type: key into DN_PROMPT_RULES selecting the conditioning
            prompt; unknown keys fall back to 'microscopy image'.
        steps, seed: diffusion steps and RNG seed (coerced to int).
        current_color: pseudocolor name for the preview only.

    Returns:
        (original preview PIL, denoised preview PIL, path to raw .tif,
        raw numpy array for gr.State, colorbar preview PIL).

    Raises:
        gr.Error: pipeline not loaded or no image uploaded.
    """
    if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
    if noisy_image_np is None: raise gr.Error("Please upload an image.")

    pipe = swap_controlnet(controlnet_pipe, DN_CONTROLNET_PATH)
    prompt = DN_PROMPT_RULES.get(image_type, 'microscopy image')

    print(f"\nDN Task started...")
    # /255.0 assumes an 8-bit grayscale upload (Gradio numpy 'L' image) —
    # TODO confirm for 16-bit inputs. Shape (1, 1, H, W) for the condition.
    image_tensor = torch.from_numpy(noisy_image_np.astype(np.float32) / 255.0).unsqueeze(0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
    image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)

    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
    with torch.autocast("cuda"):
        output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images

    raw_file_path = save_temp_tiff(output_np, prefix="dn_raw")
    output_display = apply_pseudocolor(output_np, current_color)
    colorbar_img = generate_colorbar_preview(current_color)

    return numpy_to_pil(noisy_image_np, "L"), output_display, raw_file_path, output_np, colorbar_img
409
 
410
+ # --- Segmentation & Classification (Unchanged Logic, just wrapper) ---
411
@spaces.GPU(duration=120)
def run_segmentation(input_image_np, model_name, diameter, flow_threshold, cellprob_threshold):
    """Segment cells with a Cellpose model and overlay the masks in red.

    Args:
        input_image_np: 2-D numpy image from the Gradio upload.
        model_name: key into SEG_MODELS selecting the checkpoint.
        diameter: expected cell diameter in pixels; 0 uses the diameter
            learned at training time (model.diam_labels).
        flow_threshold, cellprob_threshold: standard Cellpose eval knobs.

    Returns:
        (grayscale original PIL, red-overlay blend PIL).

    Raises:
        gr.Error: no image uploaded or Cellpose evaluation failed.
    """
    if input_image_np is None: raise gr.Error("Please upload an image.")
    model_path = SEG_MODELS.get(model_name)

    print(f"\nSeg Task started...")
    try:
        model = cellpose_models.CellposeModel(gpu=torch.cuda.is_available(), pretrained_model=model_path)
        # channels=[0, 0]: grayscale image, no secondary nuclear channel.
        masks, _, _ = model.eval([input_image_np], channels=[0, 0], diameter=(model.diam_labels if diameter==0 else diameter), flow_threshold=flow_threshold, cellprob_threshold=cellprob_threshold)
    except Exception as e: raise gr.Error(f"Segmentation failed: {e}")

    original_rgb = numpy_to_pil(input_image_np, "RGB")
    # Paint every segmented pixel dark red, then alpha-blend 60/40 with the input.
    red_mask = np.zeros_like(np.array(original_rgb)); red_mask[masks[0] > 0] = [139, 0, 0]
    blended = ((0.6 * np.array(original_rgb) + 0.4 * red_mask).astype(np.uint8))
    return numpy_to_pil(input_image_np, "L"), numpy_to_pil(blended, "RGB")
 
 
 
426
 
427
@spaces.GPU(duration=120)
def run_classification(input_image_np, model_name):
    """Classify a cell image with a fine-tuned ResNet-50.

    Args:
        input_image_np: 2-D numpy image from the Gradio upload.
        model_name: key into CLS_MODEL_PATHS; the checkpoint file
            "best_resnet50.pth" is loaded from that directory.

    Returns:
        (grayscale preview PIL, {class name: probability} dict for
        gr.Label). CLS_CLASS_NAMES is defined elsewhere in this module.

    Raises:
        gr.Error: no image uploaded or model construction/loading failed.
    """
    if input_image_np is None: raise gr.Error("Please upload an image.")
    model_path = os.path.join(CLS_MODEL_PATHS.get(model_name), "best_resnet50.pth")

    print(f"\nCls Task started...")
    try:
        # Rebuild the architecture (no pretrained weights), resize the head
        # to the class count, then load the fine-tuned state dict.
        model = models.resnet50(weights=None); model.fc = nn.Linear(model.fc.in_features, len(CLS_CLASS_NAMES))
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        model.to(DEVICE).eval()
    except Exception as e: raise gr.Error(f"Classification failed: {e}")

    # Grayscale input is expanded to RGB to match ResNet's 3-channel stem.
    input_tensor = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)])(numpy_to_pil(input_image_np, "RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        probs = F.softmax(model(input_tensor), dim=1).squeeze().cpu().numpy()
    return numpy_to_pil(input_image_np, "L"), {name: float(p) for name, p in zip(CLS_CLASS_NAMES, probs)}
 
 
 
 
443
 
444
  # --- 3. Gradio UI Layout ---
445
print("Building Gradio interface...")
# Make sure every example directory exists before the UI references them.
example_dirs = (M2I_EXAMPLE_IMG_DIR, T2I_EXAMPLE_IMG_DIR, SR_EXAMPLE_IMG_DIR, DN_EXAMPLE_IMG_DIR, SEG_EXAMPLE_IMG_DIR, CLS_EXAMPLE_IMG_DIR)
for example_dir in example_dirs:
    os.makedirs(example_dir, exist_ok=True)

# Map sanitized example filenames back to their original prompts so each
# saved thumbnail can be paired with the prompt that produced it.
filename_to_prompt = {sanitize_prompt_for_filename(p): p for p in T2I_PROMPTS}
t2i_examples = [
    (os.path.join(T2I_EXAMPLE_IMG_DIR, f), filename_to_prompt[f])
    for f in os.listdir(T2I_EXAMPLE_IMG_DIR)
    if f in filename_to_prompt
]
453
+
454
def load_examples(d, stack=False):
    """Load example images from directory `d` as (PIL grayscale, path) pairs.

    Args:
        d: directory to scan; files are processed in sorted name order.
        stack: if True, 3-D TIF stacks are averaged over the first axis so
            a multi-frame stack previews as a single 2-D image.

    Returns:
        list of (PIL.Image, filepath) tuples suitable for gr.Gallery;
        empty when the directory is missing or has no readable images.
    """
    ex = []
    if not os.path.exists(d): return ex
    for f in sorted(os.listdir(d)):
        if f.lower().endswith(('.tif', '.tiff', '.png', '.jpg')):
            fp = os.path.join(d, f)
            try:
                # BUG FIX: the extension routing must be case-insensitive to
                # match the filter above — the original used f.endswith()
                # here, so ".TIF"/".TIFF" files were wrongly opened via PIL.
                if f.lower().endswith(('.tif', '.tiff')):
                    img = tifffile.imread(fp)
                else:
                    img = np.array(Image.open(fp).convert("L"))
                if stack and img.ndim == 3: img = np.mean(img, axis=0)
                ex.append((numpy_to_pil(img, "L"), fp))
            except Exception:
                # Best-effort: skip unreadable files rather than break UI build.
                pass
    return ex
466
 
467
# Pre-load gallery examples for each tab: lists of (PIL image, filepath).
m2i_examples = load_examples(M2I_EXAMPLE_IMG_DIR)
# SR inputs are multi-frame stacks; stack=True previews the mean projection.
sr_examples = load_examples(SR_EXAMPLE_IMG_DIR, stack=True)
dn_examples = load_examples(DN_EXAMPLE_IMG_DIR)
seg_examples = load_examples(SEG_EXAMPLE_IMG_DIR)
cls_examples = load_examples(CLS_EXAMPLE_IMG_DIR)
 
 
 
 
472
 
473
+ # --- UI Builders ---
474
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
475
  with gr.Row():
476
  gr.Image(value=LOGO_PATH, width=300, height=200, container=False, interactive=False, show_download_button=False, show_fullscreen_button=False)
 
479
  with gr.Tabs():
480
  # --- TAB 1: Text-to-Image ---
481
  with gr.Tab("Text-to-Image Generation", id="txt2img"):
482
+ # State to hold raw numpy data
483
+ t2i_raw_state = gr.State(None)
484
+
485
  with gr.Row(variant="panel"):
486
  with gr.Column(scale=1, min_width=350):
487
+ t2i_prompt = gr.Dropdown(choices=T2I_PROMPTS, value=T2I_PROMPTS[0], label="Search or Type a Prompt", filterable=True, allow_custom_value=True)
488
+ t2i_steps = gr.Slider(10, 200, 50, step=1, label="Inference Steps")
489
+ t2i_btn = gr.Button("Generate", variant="primary")
 
 
490
  with gr.Column(scale=2):
491
+ # Image with Download Button DISABLED
492
+ t2i_out = gr.Image(label="Generated Image", type="pil", interactive=False, show_download_button=False)
493
+
494
+ with gr.Row(equal_height=True):
495
+ # Color Controls
496
+ with gr.Column(scale=2):
497
+ t2i_color = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor (Adjust after generation)")
498
+ with gr.Column(scale=2):
499
+ t2i_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False)
500
+ with gr.Column(scale=1):
501
+ t2i_dl = gr.DownloadButton(label="Download Raw (.tif)")
502
+
503
+ t2i_gal = gr.Gallery(value=t2i_examples, label="Examples", columns=6, height="auto")
504
+
505
+ # Events
506
+ t2i_btn.click(
507
+ fn=generate_t2i,
508
+ inputs=[t2i_prompt, t2i_steps, t2i_color],
509
+ outputs=[t2i_out, t2i_dl, t2i_raw_state, t2i_colorbar]
510
+ )
511
+ # Real-time color update using State
512
+ t2i_color.change(
513
+ fn=update_single_image_color,
514
+ inputs=[t2i_raw_state, t2i_color],
515
+ outputs=[t2i_out, t2i_colorbar]
516
+ )
517
+ t2i_gal.select(lambda e: e.value['caption'], None, t2i_prompt)
518
 
519
  # --- TAB 2: Super-Resolution ---
520
  with gr.Tab("Super-Resolution", id="super_res"):
521
+ sr_raw_state = gr.State(None)
522
  with gr.Row(variant="panel"):
523
  with gr.Column(scale=1, min_width=350):
524
+ sr_file = gr.File(label="Upload 9-Channel TIF Stack", file_types=['.tif', '.tiff'])
525
+ sr_model = gr.Dropdown(choices=list(SR_CONTROLNET_MODELS.keys()), value=list(SR_CONTROLNET_MODELS.keys())[-1], label="Model")
526
+ sr_prompt = gr.Textbox(label="Prompt", value="F-actin of COS-7", interactive=False)
527
+ sr_steps = gr.Slider(5, 50, 10, step=1, label="Steps")
528
+ sr_seed = gr.Number(label="Seed", value=42)
529
+ sr_btn = gr.Button("Generate", variant="primary")
 
530
  with gr.Column(scale=2):
531
  with gr.Row():
532
+ sr_in_disp = gr.Image(label="Input (Avg)", type="pil", interactive=False, show_download_button=False)
533
+ sr_out_disp = gr.Image(label="Output", type="pil", interactive=False, show_download_button=False)
534
+ with gr.Row(equal_height=True):
535
+ with gr.Column(scale=2):
536
+ sr_color = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor")
537
+ with gr.Column(scale=2):
538
+ sr_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False)
539
+ with gr.Column(scale=1):
540
+ sr_dl = gr.DownloadButton(label="Download Raw (.tif)")
541
+
542
+ sr_gal = gr.Gallery(value=sr_examples, label="Examples", columns=6, height="auto")
543
+
544
+ sr_model.change(update_sr_prompt, sr_model, sr_prompt)
545
+ sr_btn.click(run_super_resolution, [sr_file, sr_model, sr_prompt, sr_steps, sr_seed, sr_color], [sr_in_disp, sr_out_disp, sr_dl, sr_raw_state, sr_colorbar])
546
+ sr_color.change(update_single_image_color, [sr_raw_state, sr_color], [sr_out_disp, sr_colorbar])
547
+ sr_gal.select(lambda e: e.value['caption'], None, sr_file)
548
 
549
  # --- TAB 3: Denoising ---
550
  with gr.Tab("Denoising", id="denoising"):
551
+ dn_raw_state = gr.State(None)
552
  with gr.Row(variant="panel"):
553
  with gr.Column(scale=1, min_width=350):
554
+ dn_img = gr.Image(type="numpy", label="Upload Noisy Image", image_mode="L")
555
+ dn_type = gr.Dropdown(choices=list(DN_PROMPT_RULES.keys()), value='MICE', label="Image Type")
556
+ dn_steps = gr.Slider(5, 50, 10, step=1, label="Steps")
557
+ dn_seed = gr.Number(label="Seed", value=42)
558
+ dn_btn = gr.Button("Denoise", variant="primary")
 
559
  with gr.Column(scale=2):
560
  with gr.Row():
561
+ dn_orig = gr.Image(label="Original", type="pil", interactive=False, show_download_button=False)
562
+ dn_out = gr.Image(label="Denoised", type="pil", interactive=False, show_download_button=False)
563
+ with gr.Row(equal_height=True):
564
+ with gr.Column(scale=2):
565
+ dn_color = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor")
566
+ with gr.Column(scale=2):
567
+ dn_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False)
568
+ with gr.Column(scale=1):
569
+ dn_dl = gr.DownloadButton(label="Download Raw (.tif)")
570
+
571
+ dn_gal = gr.Gallery(value=dn_examples, label="Examples", columns=6, height="auto")
572
+
573
+ dn_btn.click(run_denoising, [dn_img, dn_type, dn_steps, dn_seed, dn_color], [dn_orig, dn_out, dn_dl, dn_raw_state, dn_colorbar])
574
+ dn_color.change(update_single_image_color, [dn_raw_state, dn_color], [dn_out, dn_colorbar])
575
+ dn_gal.select(lambda e: e.value['caption'], None, dn_img)
576
 
577
  # --- TAB 4: Mask-to-Image ---
578
  with gr.Tab("Mask-to-Image", id="mask2img"):
579
+ m2i_raw_state = gr.State(None) # Stores list of numpy arrays
580
  with gr.Row(variant="panel"):
581
  with gr.Column(scale=1, min_width=350):
582
+ m2i_file = gr.File(label="Upload Mask (.tif)", file_types=['.tif', '.tiff'])
583
+ m2i_type = gr.Textbox(label="Cell Type", placeholder="e.g., HeLa")
584
+ m2i_num = gr.Slider(1, 10, 5, step=1, label="Count")
585
+ m2i_steps = gr.Slider(5, 50, 10, step=1, label="Steps")
586
+ m2i_seed = gr.Number(label="Seed", value=42)
587
+ m2i_btn = gr.Button("Generate", variant="primary")
 
588
  with gr.Column(scale=2):
589
+ m2i_gal_out = gr.Gallery(label="Generated Samples", columns=5, height="auto")
590
+ with gr.Row(equal_height=True):
591
+ with gr.Column(scale=2):
592
+ m2i_color = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor")
593
+ with gr.Column(scale=2):
594
+ m2i_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False)
595
+ with gr.Column(scale=1):
596
+ m2i_dl = gr.DownloadButton(label="Download ZIP")
597
+ m2i_in_disp = gr.Image(label="Input Mask", type="pil", interactive=False, show_download_button=False)
598
+
599
+ m2i_gal = gr.Gallery(value=m2i_examples, label="Examples", columns=6, height="auto")
600
+
601
+ m2i_btn.click(run_mask_to_image_generation, [m2i_file, m2i_type, m2i_num, m2i_steps, m2i_seed, m2i_color], [m2i_in_disp, m2i_gal_out, m2i_dl, m2i_raw_state, m2i_colorbar])
602
+ m2i_color.change(update_gallery_color, [m2i_raw_state, m2i_color], [m2i_gal_out, m2i_colorbar])
603
+ m2i_gal.select(lambda e: e.value['caption'], None, m2i_file)
604
 
605
  # --- TAB 5: Cell Segmentation ---
606
  with gr.Tab("Cell Segmentation", id="segmentation"):
607
  with gr.Row(variant="panel"):
608
  with gr.Column(scale=1, min_width=350):
609
+ seg_img = gr.Image(type="numpy", label="Upload Image", image_mode="L")
610
+ seg_model = gr.Dropdown(choices=list(SEG_MODELS.keys()), value=list(SEG_MODELS.keys())[0], label="Model")
611
+ seg_diam = gr.Number(label="Diameter (0=auto)", value=30)
612
+ seg_flow = gr.Slider(0.0, 3.0, 0.4, step=0.1, label="Flow Thresh")
613
+ seg_prob = gr.Slider(-6.0, 6.0, 0.0, step=0.5, label="Prob Thresh")
614
+ seg_btn = gr.Button("Segment", variant="primary")
615
  with gr.Column(scale=2):
616
  with gr.Row():
617
+ seg_orig = gr.Image(label="Original", type="pil", interactive=False, show_download_button=False)
618
+ seg_out = gr.Image(label="Overlay", type="pil", interactive=False, show_download_button=False)
619
+ seg_gal = gr.Gallery(value=seg_examples, label="Examples", columns=6, height="auto")
620
+ seg_btn.click(run_segmentation, [seg_img, seg_model, seg_diam, seg_flow, seg_prob], [seg_orig, seg_out])
621
+ seg_gal.select(lambda e: e.value['caption'], None, seg_img)
622
+
623
  # --- TAB 6: Classification ---
624
  with gr.Tab("Classification", id="classification"):
625
  with gr.Row(variant="panel"):
626
  with gr.Column(scale=1, min_width=350):
627
+ cls_img = gr.Image(type="numpy", label="Upload Image", image_mode="L")
628
+ cls_model = gr.Dropdown(choices=list(CLS_MODEL_PATHS.keys()), value=list(CLS_MODEL_PATHS.keys())[0], label="Model")
629
+ cls_btn = gr.Button("Classify", variant="primary")
630
  with gr.Column(scale=2):
631
+ cls_orig = gr.Image(label="Input", type="pil", interactive=False, show_download_button=False)
632
+ cls_res = gr.Label(label="Results", num_top_classes=10)
633
+ cls_gal = gr.Gallery(value=cls_examples, label="Examples", columns=6, height="auto")
634
+ cls_btn.click(run_classification, [cls_img, cls_model], [cls_orig, cls_res])
635
+ cls_gal.select(lambda e: e.value['caption'], None, cls_img)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
636
 
 
637
# Script entry point: start the Gradio server (no share/queue options set).
if __name__ == "__main__":
    demo.launch()