rayquaza384mega commited on
Commit
517a4aa
·
1 Parent(s): ad236df

Pseudo-color

Browse files
Files changed (1) hide show
  1. app.py +195 -333
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import torch
2
  import numpy as np
3
  import gradio as gr
4
- from PIL import Image
5
  import os
6
  import json
7
  import glob
@@ -17,7 +17,9 @@ import torch.nn as nn
17
  import torch.nn.functional as F
18
  import spaces
19
  from collections import OrderedDict
20
-
 
 
21
 
22
  # --- Imports from both scripts ---
23
  from diffusers import DDPMScheduler, DDIMScheduler
@@ -54,17 +56,14 @@ LOGO_PATH = "utils/logo2_transparent.png"
54
  SAVE_EXAMPLES = False
55
 
56
  # --- Base directory for all models ---
57
- # NOTE: All model paths are now relative.
58
- # Run the `copy_weights.py` script once to copy all necessary model files into this local directory.
59
  REPO_ID = "FluoGen-Group/FluoGen-demo-test-ckpts"
60
- MODELS_ROOT_DIR = snapshot_download(repo_id=REPO_ID, token=hf_token) #"models_collection"
61
 
62
- # --- Tab 1: Mask-to-Image Config (Formerly Segmentation-to-Image) ---
63
  M2I_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_M2I/checkpoint-30000"
64
  M2I_EXAMPLE_IMG_DIR = "example_images_m2i"
65
 
66
  # --- Tab 2: Text-to-Image Config ---
67
- # T2I_PROMPTS = ["F-actin of COS-7", "ER of COS-7", "Mitochondria of BPAE", "Nucleus of BPAE", "ER of HeLa", "Microtubules of HeLa"]
68
  T2I_EXAMPLE_IMG_DIR = "example_images"
69
  T2I_PRETRAINED_MODEL_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5"
70
  T2I_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"
@@ -96,22 +95,16 @@ SEG_MODELS = {
96
  }
97
  SEG_EXAMPLE_IMG_DIR = "example_images_seg"
98
 
99
-
100
  # --- Tab 6: Classification Config ---
101
  CLS_MODEL_PATHS = OrderedDict({
102
  "5shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot_re",
103
- #"10shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_10_shot_re",
104
- #"15shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_15_shot_re",
105
- #"20shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_20_shot_re",
106
  "5shot+FluoGen": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot_aug_re",
107
- #"10shot_aug": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_10_shot_aug_re",
108
- #"15shot_aug": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_15_shot_aug_re",
109
- #"20shot_aug": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_20_shot_aug_re",
110
  })
111
- #CLS_CLASS_NAMES = ['dap', 'erdak', 'giant', 'gpp130', 'h4b4', 'mc151', 'nucle', 'phal', 'tfr', 'tubul']
112
  CLS_CLASS_NAMES = ['Nucleus', 'Endoplasmic Reticulum', 'Giantin', 'GPP130', 'Lysosomes', 'Mitochondria', 'Nucleolus', 'Actin', 'Endosomes', 'Microtubules']
113
  CLS_EXAMPLE_IMG_DIR = "example_images_cls"
114
 
 
 
115
 
116
  # --- Helper Functions ---
117
  def sanitize_prompt_for_filename(prompt):
@@ -123,77 +116,104 @@ def min_max_norm(x):
123
  if max_val - min_val < 1e-8: return np.zeros_like(x)
124
  return (x - min_val) / (max_val - min_val)
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  def numpy_to_pil(image_np, target_mode="RGB"):
127
- # If the input is already a PIL image, just ensure mode and return
128
  if isinstance(image_np, Image.Image):
129
  if target_mode == "RGB" and image_np.mode != "RGB": return image_np.convert("RGB")
130
  if target_mode == "L" and image_np.mode != "L": return image_np.convert("L")
131
  return image_np
132
-
133
- # Handle numpy array conversion
134
  squeezed_np = np.squeeze(image_np);
135
  if squeezed_np.dtype == np.uint8:
136
- # If it's already uint8, it's likely in the 0-255 range.
137
  image_8bit = squeezed_np
138
  else:
139
- # Normalize and scale for other types
140
  normalized_np = min_max_norm(squeezed_np)
141
  image_8bit = (normalized_np * 255).astype(np.uint8)
142
-
143
  pil_image = Image.fromarray(image_8bit)
144
-
145
  if target_mode == "RGB" and pil_image.mode != "RGB": pil_image = pil_image.convert("RGB")
146
  elif target_mode == "L" and pil_image.mode != "L": pil_image = pil_image.convert("L")
147
  return pil_image
148
 
149
  def update_sr_prompt(model_name):
150
- if model_name == "Checkpoint ER":
151
- return "ER of COS-7"
152
- if model_name == "Checkpoint Microtubules":
153
- return "Microtubules of COS-7"
154
- if model_name == "Checkpoint CCPs":
155
- return "CCPs of COS-7"
156
- elif model_name == "Checkpoint F-actin":
157
- return "F-actin of COS-7"
158
- return "" # 或者返回一个默认值
159
 
160
  PROMPT_TO_MODEL_MAP = {}
161
  current_t2i_unet_path = None
162
  def load_all_prompts():
163
  global PROMPT_TO_MODEL_MAP
164
  categories = [
165
- {
166
- "file": "prompts/basic_prompts.json",
167
- "model": f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"
168
- },
169
- {
170
- "file": "prompts/others_prompts.json",
171
- "model": f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/FULL-checkpoint-275000"
172
- },
173
- {
174
- "file": "prompts/hpa_prompts.json",
175
- "model": f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/HPA-checkpoint-40000"
176
- }
177
  ]
178
-
179
  combined_prompts = []
180
  for cat in categories:
181
- file_path = cat["file"]
182
- model_path = cat["model"]
183
  try:
184
- if os.path.exists(file_path):
185
- with open(file_path, "r", encoding="utf-8") as f:
186
  data = json.load(f)
187
  if isinstance(data, list):
188
  combined_prompts.extend(data)
189
- for p in data:
190
- PROMPT_TO_MODEL_MAP[p] = model_path
191
- print(f" Loaded {len(data)} prompts from {file_path}")
192
- except Exception as e:
193
- print(f"✗ Error loading {file_path}: {e}")
194
-
195
- if not combined_prompts:
196
- return ["F-actin of COS-7", "ER of COS-7"]
197
  return combined_prompts
198
  T2I_PROMPTS = load_all_prompts()
199
 
@@ -216,7 +236,7 @@ except Exception as e:
216
  try:
217
  print("Loading shared ControlNet pipeline components...")
218
  controlnet_unet = UNet2DConditionModel.from_pretrained(CONTROLNET_UNET_PATH, subfolder="unet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
219
- default_controlnet_path = M2I_CONTROLNET_PATH # Start with the first tab's model
220
  controlnet_controlnet = ControlNetModel.from_pretrained(default_controlnet_path, subfolder="controlnet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
221
  controlnet_scheduler = DDIMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="v_prediction", rescale_betas_zero_snr=False, timestep_spacing="trailing")
222
  controlnet_tokenizer = CLIPTokenizer.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="tokenizer")
@@ -254,19 +274,8 @@ def swap_t2i_unet(pipe, target_unet_path):
254
  raise gr.Error(f"Failed to load UNet from {target_unet_path}. Error: {e}")
255
  return pipe
256
 
257
- # def generate_t2i(prompt, num_inference_steps):
258
- # if t2i_pipe is None: raise gr.Error("Text-to-Image model is not loaded.")
259
- # print(f"\nTask started... | Prompt: '{prompt}' | Steps: {num_inference_steps}")
260
- # image_np = t2i_pipe(prompt.lower(), generator=None, num_inference_steps=int(num_inference_steps), output_type="np").images
261
- # generated_image = numpy_to_pil(image_np)
262
- # print("✓ Image generated")
263
- # if SAVE_EXAMPLES:
264
- # example_filepath = os.path.join(T2I_EXAMPLE_IMG_DIR, sanitize_prompt_for_filename(prompt))0
265
- # if not os.path.exists(example_filepath):
266
- # generated_image.save(example_filepath); print(f"✓ New T2I example saved: {example_filepath}")
267
- # return generated_image
268
  @spaces.GPU(duration=120)
269
- def generate_t2i(prompt, num_inference_steps):
270
  global t2i_pipe
271
  if t2i_pipe is None: raise gr.Error("Text-to-Image model is not loaded.")
272
  target_model_path = PROMPT_TO_MODEL_MAP.get(prompt)
@@ -274,31 +283,31 @@ def generate_t2i(prompt, num_inference_steps):
274
  target_model_path = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/FULL-checkpoint-275000"
275
  print(f"ℹ️ Prompt '{prompt}' not found in predefined list. Using Foundation (Full) model.")
276
  t2i_pipe = swap_t2i_unet(t2i_pipe, target_model_path)
 
277
  print(f"\n🚀 Task started... | Prompt: '{prompt}' | Model: {current_t2i_unet_path}")
278
  image_np = t2i_pipe(prompt.lower(), generator=None, num_inference_steps=int(num_inference_steps), output_type="np").images
279
- generated_image = numpy_to_pil(image_np)
 
 
 
 
 
 
280
  print("✓ Image generated")
281
  if SAVE_EXAMPLES:
282
  example_filepath = os.path.join(T2I_EXAMPLE_IMG_DIR, sanitize_prompt_for_filename(prompt))
283
  if not os.path.exists(example_filepath):
284
- generated_image.save(example_filepath)
285
- print(f"✓ New T2I example saved: {example_filepath}")
286
- return generated_image
 
287
 
288
  @spaces.GPU(duration=120)
289
- def run_mask_to_image_generation(mask_file_obj, cell_type, num_images, steps, seed):
290
  if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
291
  if mask_file_obj is None: raise gr.Error("Please upload a segmentation mask TIF file.")
292
  if not cell_type or not cell_type.strip(): raise gr.Error("Please enter a cell type.")
293
 
294
- if SAVE_EXAMPLES:
295
- input_path = mask_file_obj.name
296
- filename = os.path.basename(input_path)
297
- dest_path = os.path.join(M2I_EXAMPLE_IMG_DIR, filename)
298
- if not os.path.exists(dest_path):
299
- shutil.copy(input_path, dest_path)
300
- print(f"✓ New Mask-to-Image example saved: {dest_path}")
301
-
302
  pipe = swap_controlnet(controlnet_pipe, M2I_CONTROLNET_PATH)
303
  try:
304
  mask_np = tifffile.imread(mask_file_obj.name)
@@ -314,30 +323,44 @@ def run_mask_to_image_generation(mask_file_obj, cell_type, num_images, steps, se
314
  prompt = f"nuclei of {cell_type.strip()}"
315
  print(f"\nTask started... | Task: Mask-to-Image | Prompt: '{prompt}' | Steps: {steps} | Images: {num_images}")
316
 
317
- generated_images_pil = []
 
 
 
 
 
318
  for i in range(int(num_images)):
319
  current_seed = int(seed) + i
320
  generator = torch.Generator(device=DEVICE).manual_seed(current_seed)
321
  with torch.autocast("cuda"):
322
  output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images
323
- pil_image = numpy_to_pil(output_np)
324
- generated_images_pil.append(pil_image)
325
- print(f"✓ Generated image {i+1}/{int(num_images)}")
326
 
327
- return input_display_image, generated_images_pil
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
 
329
  @spaces.GPU(duration=120)
330
- def run_super_resolution(low_res_file_obj, controlnet_model_name, prompt, steps, seed):
331
  if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
332
  if low_res_file_obj is None: raise gr.Error("Please upload a low-resolution TIF file.")
333
-
334
- if SAVE_EXAMPLES:
335
- input_path = low_res_file_obj.name
336
- filename = os.path.basename(input_path)
337
- dest_path = os.path.join(SR_EXAMPLE_IMG_DIR, filename)
338
- if not os.path.exists(dest_path):
339
- shutil.copy(input_path, dest_path)
340
- print(f"✓ New SR example saved: {dest_path}")
341
 
342
  target_path = SR_CONTROLNET_MODELS.get(controlnet_model_name)
343
  if not target_path: raise gr.Error(f"ControlNet model '{controlnet_model_name}' not found.")
@@ -360,25 +383,19 @@ def run_super_resolution(low_res_file_obj, controlnet_model_name, prompt, steps,
360
  generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
361
  with torch.autocast("cuda"):
362
  output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images
 
 
 
 
 
363
 
364
- return input_display_image, numpy_to_pil(output_np)
365
 
366
  @spaces.GPU(duration=120)
367
- def run_denoising(noisy_image_np, image_type, steps, seed):
368
  if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
369
  if noisy_image_np is None: raise gr.Error("Please upload a noisy image.")
370
 
371
- if SAVE_EXAMPLES:
372
- timestamp = int(time.time() * 1000)
373
- filename = f"dn_input_{image_type}_{timestamp}.tif"
374
- dest_path = os.path.join(DN_EXAMPLE_IMG_DIR, filename)
375
- try:
376
- img_to_save = noisy_image_np.astype(np.uint8) if noisy_image_np.dtype != np.uint8 else noisy_image_np
377
- tifffile.imwrite(dest_path, img_to_save)
378
- print(f"✓ New Denoising example saved: {dest_path}")
379
- except Exception as e:
380
- print(f"✗ Failed to save denoising example: {e}")
381
-
382
  pipe = swap_controlnet(controlnet_pipe, DN_CONTROLNET_PATH)
383
  prompt = DN_PROMPT_RULES.get(image_type, 'microscopy image')
384
  print(f"\nTask started... | Task: Denoising | Prompt: '{prompt}' | Steps: {steps}")
@@ -390,27 +407,22 @@ def run_denoising(noisy_image_np, image_type, steps, seed):
390
  generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
391
  with torch.autocast("cuda"):
392
  output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images
 
 
 
 
 
393
 
394
- return numpy_to_pil(noisy_image_np, "L"), numpy_to_pil(output_np)
395
 
396
  @spaces.GPU(duration=120)
397
  def run_segmentation(input_image_np, model_name, diameter, flow_threshold, cellprob_threshold):
398
- """
399
- Runs cell segmentation and creates a dark red overlay.
400
- """
401
- if input_image_np is None:
402
- raise gr.Error("Please upload an image to segment.")
403
-
404
  model_path = SEG_MODELS.get(model_name)
405
- if not model_path:
406
- raise gr.Error(f"Segmentation model '{model_name}' not found.")
407
 
408
- if not os.path.exists(model_path):
409
- raise gr.Error(f"Model file not found at path: {model_path}. Please check the configuration.")
410
-
411
  print(f"\nTask started... | Task: Cell Segmentation | Model: '{model_name}'")
412
-
413
- # 1. Load Cellpose Model
414
  try:
415
  use_gpu = torch.cuda.is_available()
416
  model = cellpose_models.CellposeModel(gpu=use_gpu, pretrained_model=model_path)
@@ -418,155 +430,72 @@ def run_segmentation(input_image_np, model_name, diameter, flow_threshold, cellp
418
  raise gr.Error(f"Failed to load Cellpose model. Error: {e}")
419
 
420
  diameter_to_use = model.diam_labels if diameter == 0 else float(diameter)
421
- print(f"Using Diameter: {diameter_to_use}")
422
-
423
- # 2. Run model evaluation
424
  try:
425
- masks, _, _ = model.eval(
426
- [input_image_np],
427
- channels=[0, 0],
428
- diameter=diameter_to_use,
429
- flow_threshold=flow_threshold,
430
- cellprob_threshold=cellprob_threshold
431
- )
432
  mask_output = masks[0]
433
  except Exception as e:
434
  raise gr.Error(f"Cellpose model evaluation failed. Error: {e}")
435
 
436
- # 3. Create custom dark red overlay
437
- # Ensure input image is uint8 and 3-channel for blending
438
  original_rgb = numpy_to_pil(input_image_np, "RGB")
439
  original_rgb_np = np.array(original_rgb)
440
-
441
- # Create a blank layer for the red mask
442
  red_mask_layer = np.zeros_like(original_rgb_np)
443
- dark_red_color = [139, 0, 0]
444
-
445
- # Apply the red color where the mask is present
446
- is_mask_pixels = mask_output > 0
447
- red_mask_layer[is_mask_pixels] = dark_red_color
448
-
449
- # Blend the original image with the red mask layer
450
- alpha = 0.4 # Opacity of the mask
451
- blended_image_np = ((1 - alpha) * original_rgb_np + alpha * red_mask_layer).astype(np.uint8)
452
-
453
- # 4. Save example if enabled
454
- if SAVE_EXAMPLES:
455
- timestamp = int(time.time() * 1000)
456
- filename = f"seg_input_{timestamp}.tif"
457
- dest_path = os.path.join(SEG_EXAMPLE_IMG_DIR, filename)
458
- try:
459
- img_to_save = input_image_np.astype(np.uint8) if input_image_np.dtype != np.uint8 else input_image_np
460
- tifffile.imwrite(dest_path, img_to_save)
461
- print(f"✓ New Segmentation example saved: {dest_path}")
462
- except Exception as e:
463
- print(f"✗ Failed to save segmentation example: {e}")
464
-
465
- print("✓ Segmentation complete")
466
 
467
  return numpy_to_pil(input_image_np, "L"), numpy_to_pil(blended_image_np, "RGB")
468
 
469
  @spaces.GPU(duration=120)
470
  def run_classification(input_image_np, model_name):
471
- """
472
- Runs classification on a single image using a pre-trained ResNet50 model.
473
- """
474
- if input_image_np is None:
475
- raise gr.Error("Please upload an image to classify.")
476
-
477
  model_dir = CLS_MODEL_PATHS.get(model_name)
478
- if not model_dir:
479
- raise gr.Error(f"Classification model '{model_name}' not found.")
480
-
481
  model_path = os.path.join(model_dir, "best_resnet50.pth")
482
- if not os.path.exists(model_path):
483
- raise gr.Error(f"Model file not found at {model_path}. Please check the configuration.")
484
-
485
  print(f"\nTask started... | Task: Classification | Model: '{model_name}'")
486
-
487
- # 1. Load Model
488
  try:
489
  model = models.resnet50(weights=None)
490
- num_features = model.fc.in_features
491
- model.fc = nn.Linear(num_features, len(CLS_CLASS_NAMES))
492
  model.load_state_dict(torch.load(model_path, map_location=DEVICE))
493
- model.to(DEVICE)
494
- model.eval()
495
  except Exception as e:
496
  raise gr.Error(f"Failed to load classification model. Error: {e}")
497
 
498
- # 2. Preprocess Image
499
- # Grayscale numpy -> RGB PIL -> transform -> tensor
500
  input_pil = numpy_to_pil(input_image_np, "RGB")
501
-
502
- transform_test = transforms.Compose([
503
- transforms.ToTensor(),
504
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # ResNet needs 3-channel norm
505
- ])
506
  input_tensor = transform_test(input_pil).unsqueeze(0).to(DEVICE)
507
 
508
- # 3. Perform Inference
509
  with torch.no_grad():
510
  outputs = model(input_tensor)
511
  probabilities = F.softmax(outputs, dim=1).squeeze().cpu().numpy()
512
 
513
- # 4. Format output for Gradio Label component
514
  confidences = {name: float(prob) for name, prob in zip(CLS_CLASS_NAMES, probabilities)}
515
-
516
- # 5. Save example
517
- if SAVE_EXAMPLES:
518
- timestamp = int(time.time() * 1000)
519
- filename = f"cls_input_{timestamp}.png" # Save as png for compatibility
520
- dest_path = os.path.join(CLS_EXAMPLE_IMG_DIR, filename)
521
- try:
522
- input_pil.save(dest_path)
523
- print(f"✓ New Classification example saved: {dest_path}")
524
- except Exception as e:
525
- print(f"✗ Failed to save classification example: {e}")
526
-
527
- print("✓ Classification complete")
528
-
529
  return numpy_to_pil(input_image_np, "L"), confidences
530
 
531
 
532
  # --- 3. Gradio UI Layout ---
533
  print("Building Gradio interface...")
534
- # Create directories for all example types
535
- os.makedirs(M2I_EXAMPLE_IMG_DIR, exist_ok=True)
536
- os.makedirs(T2I_EXAMPLE_IMG_DIR, exist_ok=True)
537
- os.makedirs(SR_EXAMPLE_IMG_DIR, exist_ok=True)
538
- os.makedirs(DN_EXAMPLE_IMG_DIR, exist_ok=True)
539
- os.makedirs(SEG_EXAMPLE_IMG_DIR, exist_ok=True)
540
- os.makedirs(CLS_EXAMPLE_IMG_DIR, exist_ok=True)
541
 
542
  # --- Load examples ---
543
  filename_to_prompt_map = { sanitize_prompt_for_filename(prompt): prompt for prompt in T2I_PROMPTS }
544
  t2i_gallery_examples = []
545
  for filename in os.listdir(T2I_EXAMPLE_IMG_DIR):
546
  if filename in filename_to_prompt_map:
547
- filepath = os.path.join(T2I_EXAMPLE_IMG_DIR, filename)
548
- prompt = filename_to_prompt_map[filename]
549
- t2i_gallery_examples.append((filepath, prompt))
550
 
551
  def load_image_examples(example_dir, is_stack=False):
552
  examples = []
553
  if not os.path.exists(example_dir): return examples
554
  for f in sorted(os.listdir(example_dir)):
555
- if f.lower().endswith(('.tif', '.tiff', '.png', '.jpg', '.jpeg')):
556
  filepath = os.path.join(example_dir, f)
557
  try:
558
- if f.lower().endswith(('.tif', '.tiff')):
559
- img_np = tifffile.imread(filepath)
560
- else:
561
- img_np = np.array(Image.open(filepath).convert("L"))
562
-
563
- if is_stack and img_np.ndim == 3:
564
- img_np = np.mean(img_np, axis=0)
565
-
566
- display_img = numpy_to_pil(img_np, "L")
567
- examples.append((display_img, filepath))
568
- except Exception as e:
569
- print(f"Warning: Could not load gallery image {filepath}. Error: {e}")
570
  return examples
571
 
572
  m2i_gallery_examples = load_image_examples(M2I_EXAMPLE_IMG_DIR)
@@ -576,11 +505,8 @@ seg_gallery_examples = load_image_examples(SEG_EXAMPLE_IMG_DIR)
576
  cls_gallery_examples = load_image_examples(CLS_EXAMPLE_IMG_DIR)
577
 
578
  # --- Universal event handlers ---
579
- def select_example_prompt(evt: gr.SelectData):
580
- return evt.value['caption']
581
-
582
- def select_example_input_file(evt: gr.SelectData):
583
- return evt.value['caption']
584
 
585
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
586
  with gr.Row():
@@ -590,175 +516,112 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
590
  with gr.Tabs():
591
  # --- TAB 1: Text-to-Image ---
592
  with gr.Tab("Text-to-Image Generation", id="txt2img"):
593
- gr.Markdown("""
594
- ### Instructions
595
- 1. Select a desired prompt from the dropdown menu.
596
- 2. Adjust the 'Inference Steps' slider to control generation quality.
597
- 3. Click the 'Generate' button to create a new image.
598
- 4. Explore the 'Examples' gallery; clicking an image will load its prompt.
599
-
600
- **Notice:** This model currently supports 3566 prompt categories. However, data for many cell structures and lines is still lacking. **We welcome data source contributions to improve the model.**
601
- """) # Content hidden for brevity
602
  with gr.Row(variant="panel"):
603
  with gr.Column(scale=1, min_width=350):
604
- # t2i_prompt_input = gr.Dropdown(choices=T2I_PROMPTS, value=T2I_PROMPTS[0], label="Select a Prompt")
605
- t2i_prompt_input = gr.Dropdown(
606
- choices=T2I_PROMPTS,
607
- value=T2I_PROMPTS[0],
608
- label="Search or Type a Prompt",
609
- filterable=True,
610
- allow_custom_value=True
611
- )
612
  t2i_steps_slider = gr.Slider(minimum=10, maximum=200, step=1, value=50, label="Inference Steps")
 
 
613
  t2i_generate_button = gr.Button("Generate", variant="primary")
614
  with gr.Column(scale=2):
615
- t2i_generated_output = gr.Image(label="Generated Image", type="pil", interactive=False)
616
- t2i_gallery = gr.Gallery(value=t2i_gallery_examples, label="Examples (Click an image to use its prompt)", columns=6, object_fit="contain", height="auto")
 
 
617
 
618
  # --- TAB 2: Super-Resolution ---
619
  with gr.Tab("Super-Resolution", id="super_res"):
620
- gr.Markdown("""
621
- ### Instructions
622
- 1. Upload a low-resolution 9-channel TIF stack, or select one from the examples.
623
- 2. Select a 'Super-Resolution Model' from the dropdown.
624
- 3. Enter a descriptive 'Prompt' related to the image content (e.g., 'CCPs of COS-7').
625
- 4. Adjust 'Inference Steps' and 'Seed' as needed.
626
- 5. Click 'Generate Super-Resolution' to process the image.
627
-
628
- **Notice:** This model was trained on the **BioSR** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results.
629
- """) # Content hidden for brevity
630
  with gr.Row(variant="panel"):
631
  with gr.Column(scale=1, min_width=350):
632
  sr_input_file = gr.File(label="Upload 9-Channel TIF Stack", file_types=['.tif', '.tiff'])
633
- sr_model_selector = gr.Dropdown(
634
- choices=list(SR_CONTROLNET_MODELS.keys()),
635
- value=list(SR_CONTROLNET_MODELS.keys())[-1],
636
- label="Select Super-Resolution Model"
637
- )
638
- # sr_prompt_input = gr.Textbox(label="Prompt (e.g., structure name)", value="CCPs of COS-7")
639
- sr_prompt_input = gr.Textbox(
640
- label="Prompt",
641
- value="F-actin of COS-7", # 初始值根据你的默认选择设定
642
- interactive=False
643
- )
644
  sr_steps_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Inference Steps")
645
  sr_seed_input = gr.Number(label="Seed", value=42)
 
646
  sr_generate_button = gr.Button("Generate Super-Resolution", variant="primary")
647
  with gr.Column(scale=2):
648
  with gr.Row():
649
- sr_input_display = gr.Image(label="Input (Average Projection)", type="pil", interactive=False)
650
  sr_output_image = gr.Image(label="Super-Resolved Image", type="pil", interactive=False)
651
- sr_gallery = gr.Gallery(value=sr_gallery_examples, label="Input Examples (Click an image to use it as input)", columns=6, object_fit="contain", height="auto")
 
652
 
653
  # --- TAB 3: Denoising ---
654
  with gr.Tab("Denoising", id="denoising"):
655
- gr.Markdown("""
656
- ### Instructions
657
- 1. Upload a noisy single-channel image, or select one from the examples.
658
- 2. Select the 'Image Type' from the dropdown to provide context for the model.
659
- 3. Adjust 'Inference Steps' and 'Seed' as needed.
660
- 4. Click 'Denoise Image' to reduce the noise.
661
-
662
- **Notice:** This model was trained on the **FMD** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results.
663
- """) # Content hidden for brevity
664
  with gr.Row(variant="panel"):
665
  with gr.Column(scale=1, min_width=350):
666
  dn_input_image = gr.Image(type="numpy", label="Upload Noisy Image", image_mode="L")
667
- dn_image_type_selector = gr.Dropdown(choices=list(DN_PROMPT_RULES.keys()), value='MICE', label="Select Image Type (for Prompt)")
668
  dn_steps_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Inference Steps")
669
  dn_seed_input = gr.Number(label="Seed", value=42)
 
670
  dn_generate_button = gr.Button("Denoise Image", variant="primary")
671
  with gr.Column(scale=2):
672
  with gr.Row():
673
- dn_original_display = gr.Image(label="Original Noisy Image", type="pil", interactive=False)
674
  dn_output_image = gr.Image(label="Denoised Image", type="pil", interactive=False)
675
- dn_gallery = gr.Gallery(value=dn_gallery_examples, label="Input Examples (Click an image to use it as input)", columns=6, object_fit="contain", height="auto")
 
676
 
677
  # --- TAB 4: Mask-to-Image ---
678
  with gr.Tab("Mask-to-Image", id="mask2img"):
679
- gr.Markdown("""
680
- ### Instructions
681
- 1. Upload a single-channel segmentation mask (`.tif` file), or select one from the examples gallery below.
682
- 2. Enter the corresponding 'Cell Type' (e.g., 'CoNSS', 'HeLa') to create the prompt.
683
- 3. Select how many sample images you want to generate.
684
- 4. Adjust 'Inference Steps' and 'Seed' as needed.
685
- 5. Click 'Generate Training Samples' to start the process.
686
- 6. The 'Generated Samples' will appear in the main gallery, with the 'Input Mask' shown below for reference.
687
- """) # Content hidden for brevity
688
  with gr.Row(variant="panel"):
689
  with gr.Column(scale=1, min_width=350):
690
  m2i_input_file = gr.File(label="Upload Segmentation Mask (.tif)", file_types=['.tif', '.tiff'])
691
- m2i_cell_type_input = gr.Textbox(label="Cell Type (for prompt)", placeholder="e.g., CoNSS, HeLa, MCF-7")
692
- m2i_num_images_slider = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of Images to Generate")
693
  m2i_steps_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Inference Steps")
694
  m2i_seed_input = gr.Number(label="Seed", value=42)
695
- m2i_generate_button = gr.Button("Generate Training Samples", variant="primary")
 
696
  with gr.Column(scale=2):
697
  m2i_output_gallery = gr.Gallery(label="Generated Samples", columns=5, object_fit="contain", height="auto")
 
698
  m2i_input_display = gr.Image(label="Input Mask", type="pil", interactive=False)
699
- m2i_gallery = gr.Gallery(value=m2i_gallery_examples, label="Input Examples (Click an image to use it as input)", columns=6, object_fit="contain", height="auto")
700
 
701
  # --- TAB 5: Cell Segmentation ---
702
  with gr.Tab("Cell Segmentation", id="segmentation"):
703
- gr.Markdown("""
704
- ### Instructions
705
- 1. Upload a single-channel image for segmentation, or select one from the examples.
706
- 2. Select a 'Segmentation Model' from the dropdown menu.
707
- 3. Set the expected 'Diameter' of the cells in pixels. Set to 0 to let the model automatically estimate it.
708
- 4. Adjust 'Flow Threshold' and 'Cell Probability Threshold' for finer control.
709
- 5. Click 'Segment Cells'. The result will be shown as a dark red overlay on the original image.
710
- """)
711
  with gr.Row(variant="panel"):
712
  with gr.Column(scale=1, min_width=350):
713
- gr.Markdown("### 1. Inputs & Controls")
714
- seg_input_image = gr.Image(type="numpy", label="Upload Image for Segmentation", image_mode="L")
715
- seg_model_selector = gr.Dropdown(choices=list(SEG_MODELS.keys()), value=list(SEG_MODELS.keys())[0], label="Select Segmentation Model")
716
- seg_diameter_input = gr.Number(label="Cell Diameter (pixels, 0=auto)", value=30)
717
  seg_flow_slider = gr.Slider(minimum=0.0, maximum=3.0, step=0.1, value=0.4, label="Flow Threshold")
718
  seg_cellprob_slider = gr.Slider(minimum=-6.0, maximum=6.0, step=0.5, value=0.0, label="Cell Probability Threshold")
719
  seg_generate_button = gr.Button("Segment Cells", variant="primary")
720
  with gr.Column(scale=2):
721
- gr.Markdown("### 2. Results")
722
  with gr.Row():
723
- seg_original_display = gr.Image(label="Original Image", type="pil", interactive=False)
724
- seg_output_image = gr.Image(label="Segmented Image (Overlay)", type="pil", interactive=False)
725
- seg_gallery = gr.Gallery(value=seg_gallery_examples, label="Input Examples (Click an image to use it as input)", columns=6, object_fit="contain", height="auto")
726
 
727
- # --- NEW TAB 6: Classification ---
728
  with gr.Tab("Classification", id="classification"):
729
- gr.Markdown("""
730
- ### Instructions
731
- 1. Upload a single-channel image for classification, or select an example.
732
- 2. Select a pre-trained 'Classification Model' from the dropdown menu.
733
- 3. Click 'Classify Image' to view the prediction probabilities for each class.
734
-
735
- **Note:** The models provided are ResNet50 trained on the 2D HeLa dataset.
736
- """)
737
  with gr.Row(variant="panel"):
738
  with gr.Column(scale=1, min_width=350):
739
- gr.Markdown("### 1. Inputs & Controls")
740
- cls_input_image = gr.Image(type="numpy", label="Upload Image for Classification", image_mode="L")
741
- cls_model_selector = gr.Dropdown(choices=list(CLS_MODEL_PATHS.keys()), value=list(CLS_MODEL_PATHS.keys())[0], label="Select Classification Model")
742
  cls_generate_button = gr.Button("Classify Image", variant="primary")
743
  with gr.Column(scale=2):
744
- gr.Markdown("### 2. Results")
745
  cls_original_display = gr.Image(label="Input Image", type="pil", interactive=False)
746
- cls_output_label = gr.Label(label="Classification Results", num_top_classes=len(CLS_CLASS_NAMES))
747
- cls_gallery = gr.Gallery(value=cls_gallery_examples, label="Input Examples (Click an image to use it as input)", columns=6, object_fit="contain", height="auto")
748
 
749
 
750
  # --- Event Handlers ---
751
- m2i_generate_button.click(fn=run_mask_to_image_generation, inputs=[m2i_input_file, m2i_cell_type_input, m2i_num_images_slider, m2i_steps_slider, m2i_seed_input], outputs=[m2i_input_display, m2i_output_gallery])
752
  m2i_gallery.select(fn=select_example_input_file, outputs=m2i_input_file)
753
 
754
- t2i_generate_button.click(fn=generate_t2i, inputs=[t2i_prompt_input, t2i_steps_slider], outputs=[t2i_generated_output])
755
  t2i_gallery.select(fn=select_example_prompt, outputs=t2i_prompt_input)
756
 
757
  sr_model_selector.change(fn=update_sr_prompt, inputs=sr_model_selector, outputs=sr_prompt_input)
758
- sr_generate_button.click(fn=run_super_resolution, inputs=[sr_input_file, sr_model_selector, sr_prompt_input, sr_steps_slider, sr_seed_input], outputs=[sr_input_display, sr_output_image])
759
  sr_gallery.select(fn=select_example_input_file, outputs=sr_input_file)
760
 
761
- dn_generate_button.click(fn=run_denoising, inputs=[dn_input_image, dn_image_type_selector, dn_steps_slider, dn_seed_input], outputs=[dn_original_display, dn_output_image])
762
  dn_gallery.select(fn=select_example_input_file, outputs=dn_input_image)
763
 
764
  seg_generate_button.click(fn=run_segmentation, inputs=[seg_input_image, seg_model_selector, seg_diameter_input, seg_flow_slider, seg_cellprob_slider], outputs=[seg_original_display, seg_output_image])
@@ -771,5 +634,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
771
  # --- 4. Launch Application ---
772
  if __name__ == "__main__":
773
  print("Interface built. Launching server...")
774
- demo.launch()
775
-
 
1
  import torch
2
  import numpy as np
3
  import gradio as gr
4
+ from PIL import Image, ImageOps
5
  import os
6
  import json
7
  import glob
 
17
  import torch.nn.functional as F
18
  import spaces
19
  from collections import OrderedDict
20
+ import tempfile
21
+ import zipfile
22
+ import matplotlib.cm as cm
23
 
24
  # --- Imports from both scripts ---
25
  from diffusers import DDPMScheduler, DDIMScheduler
 
56
  SAVE_EXAMPLES = False
57
 
58
  # --- Base directory for all models ---
 
 
59
  REPO_ID = "FluoGen-Group/FluoGen-demo-test-ckpts"
60
+ MODELS_ROOT_DIR = snapshot_download(repo_id=REPO_ID, token=hf_token)
61
 
62
+ # --- Tab 1: Mask-to-Image Config ---
63
  M2I_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_M2I/checkpoint-30000"
64
  M2I_EXAMPLE_IMG_DIR = "example_images_m2i"
65
 
66
  # --- Tab 2: Text-to-Image Config ---
 
67
  T2I_EXAMPLE_IMG_DIR = "example_images"
68
  T2I_PRETRAINED_MODEL_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5"
69
  T2I_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"
 
95
  }
96
  SEG_EXAMPLE_IMG_DIR = "example_images_seg"
97
 
 
98
  # --- Tab 6: Classification Config ---
99
  CLS_MODEL_PATHS = OrderedDict({
100
  "5shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot_re",
 
 
 
101
  "5shot+FluoGen": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot_aug_re",
 
 
 
102
  })
 
103
  CLS_CLASS_NAMES = ['Nucleus', 'Endoplasmic Reticulum', 'Giantin', 'GPP130', 'Lysosomes', 'Mitochondria', 'Nucleolus', 'Actin', 'Endosomes', 'Microtubules']
104
  CLS_EXAMPLE_IMG_DIR = "example_images_cls"
105
 
106
+ # --- Constants for Visualization ---
107
+ COLOR_MAPS = ["Grayscale", "Green (GFP)", "Red (RFP)", "Blue (DAPI)", "Magenta", "Cyan", "Yellow", "Fire", "Viridis", "Inferno"]
108
 
109
  # --- Helper Functions ---
110
  def sanitize_prompt_for_filename(prompt):
 
116
  if max_val - min_val < 1e-8: return np.zeros_like(x)
117
  return (x - min_val) / (max_val - min_val)
118
 
119
+ def apply_pseudocolor(image_np, color_name="Grayscale"):
120
+ """
121
+ Applies a pseudocolor to a single channel numpy image.
122
+ image_np: Single channel numpy array (any bit depth).
123
+ Returns: PIL Image in RGB.
124
+ """
125
+ # Normalize to 0-1 for processing
126
+ norm_img = min_max_norm(np.squeeze(image_np))
127
+
128
+ if color_name == "Grayscale":
129
+ # Just convert to uint8 L
130
+ return Image.fromarray((norm_img * 255).astype(np.uint8)).convert("RGB")
131
+
132
+ # Create RGB canvas
133
+ h, w = norm_img.shape
134
+ rgb = np.zeros((h, w, 3), dtype=np.float32)
135
+
136
+ if color_name == "Green (GFP)":
137
+ rgb[..., 1] = norm_img
138
+ elif color_name == "Red (RFP)":
139
+ rgb[..., 0] = norm_img
140
+ elif color_name == "Blue (DAPI)":
141
+ rgb[..., 2] = norm_img
142
+ elif color_name == "Magenta":
143
+ rgb[..., 0] = norm_img
144
+ rgb[..., 2] = norm_img
145
+ elif color_name == "Cyan":
146
+ rgb[..., 1] = norm_img
147
+ rgb[..., 2] = norm_img
148
+ elif color_name == "Yellow":
149
+ rgb[..., 0] = norm_img
150
+ rgb[..., 1] = norm_img
151
+ elif color_name in ["Fire", "Viridis", "Inferno"]:
152
+ # Use matplotlib colormaps
153
+ cmap_map = {"Fire": "inferno", "Viridis": "viridis", "Inferno": "inferno"} # Fire looks like inferno usually
154
+ if color_name == "Fire": cmap = cm.get_cmap("gnuplot2") # Better fire approximation
155
+ else: cmap = cm.get_cmap(cmap_map[color_name])
156
+
157
+ colored = cmap(norm_img) # Returns RGBA 0-1
158
+ rgb = colored[..., :3] # Drop Alpha
159
+
160
+ return Image.fromarray((rgb * 255).astype(np.uint8))
161
+
162
+ def save_temp_tiff(image_np, prefix="output"):
163
+ """Saves numpy array to a temp TIFF file and returns path."""
164
+ tfile = tempfile.NamedTemporaryFile(delete=False, suffix=".tif", prefix=f"{prefix}_")
165
+ # Ensure compatible type for Tiff (float32 or uint16 preferred for science)
166
+ if image_np.dtype == np.float16:
167
+ save_data = image_np.astype(np.float32)
168
+ else:
169
+ save_data = image_np
170
+ tifffile.imwrite(tfile.name, save_data)
171
+ return tfile.name
172
+
173
  def numpy_to_pil(image_np, target_mode="RGB"):
 
174
  if isinstance(image_np, Image.Image):
175
  if target_mode == "RGB" and image_np.mode != "RGB": return image_np.convert("RGB")
176
  if target_mode == "L" and image_np.mode != "L": return image_np.convert("L")
177
  return image_np
 
 
178
  squeezed_np = np.squeeze(image_np);
179
  if squeezed_np.dtype == np.uint8:
 
180
  image_8bit = squeezed_np
181
  else:
 
182
  normalized_np = min_max_norm(squeezed_np)
183
  image_8bit = (normalized_np * 255).astype(np.uint8)
 
184
  pil_image = Image.fromarray(image_8bit)
 
185
  if target_mode == "RGB" and pil_image.mode != "RGB": pil_image = pil_image.convert("RGB")
186
  elif target_mode == "L" and pil_image.mode != "L": pil_image = pil_image.convert("L")
187
  return pil_image
188
 
189
  def update_sr_prompt(model_name):
190
+ if model_name == "Checkpoint ER": return "ER of COS-7"
191
+ if model_name == "Checkpoint Microtubules": return "Microtubules of COS-7"
192
+ if model_name == "Checkpoint CCPs": return "CCPs of COS-7"
193
+ elif model_name == "Checkpoint F-actin": return "F-actin of COS-7"
194
+ return ""
 
 
 
 
195
 
196
  PROMPT_TO_MODEL_MAP = {}
197
  current_t2i_unet_path = None
198
  def load_all_prompts():
199
  global PROMPT_TO_MODEL_MAP
200
  categories = [
201
+ {"file": "prompts/basic_prompts.json", "model": f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"},
202
+ {"file": "prompts/others_prompts.json", "model": f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/FULL-checkpoint-275000"},
203
+ {"file": "prompts/hpa_prompts.json", "model": f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/HPA-checkpoint-40000"}
 
 
 
 
 
 
 
 
 
204
  ]
 
205
  combined_prompts = []
206
  for cat in categories:
 
 
207
  try:
208
+ if os.path.exists(cat["file"]):
209
+ with open(cat["file"], "r", encoding="utf-8") as f:
210
  data = json.load(f)
211
  if isinstance(data, list):
212
  combined_prompts.extend(data)
213
+ for p in data: PROMPT_TO_MODEL_MAP[p] = cat["model"]
214
+ print(f"✓ Loaded {len(data)} prompts from {cat['file']}")
215
+ except Exception as e: print(f" Error loading {cat['file']}: {e}")
216
+ if not combined_prompts: return ["F-actin of COS-7", "ER of COS-7"]
 
 
 
 
217
  return combined_prompts
218
  T2I_PROMPTS = load_all_prompts()
219
 
 
236
  try:
237
  print("Loading shared ControlNet pipeline components...")
238
  controlnet_unet = UNet2DConditionModel.from_pretrained(CONTROLNET_UNET_PATH, subfolder="unet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
239
+ default_controlnet_path = M2I_CONTROLNET_PATH
240
  controlnet_controlnet = ControlNetModel.from_pretrained(default_controlnet_path, subfolder="controlnet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
241
  controlnet_scheduler = DDIMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="v_prediction", rescale_betas_zero_snr=False, timestep_spacing="trailing")
242
  controlnet_tokenizer = CLIPTokenizer.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="tokenizer")
 
274
  raise gr.Error(f"Failed to load UNet from {target_unet_path}. Error: {e}")
275
  return pipe
276
 
 
 
 
 
 
 
 
 
 
 
 
277
  @spaces.GPU(duration=120)
278
+ def generate_t2i(prompt, num_inference_steps, colormap_choice):
279
  global t2i_pipe
280
  if t2i_pipe is None: raise gr.Error("Text-to-Image model is not loaded.")
281
  target_model_path = PROMPT_TO_MODEL_MAP.get(prompt)
 
283
  target_model_path = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/FULL-checkpoint-275000"
284
  print(f"ℹ️ Prompt '{prompt}' not found in predefined list. Using Foundation (Full) model.")
285
  t2i_pipe = swap_t2i_unet(t2i_pipe, target_model_path)
286
+
287
  print(f"\n🚀 Task started... | Prompt: '{prompt}' | Model: {current_t2i_unet_path}")
288
  image_np = t2i_pipe(prompt.lower(), generator=None, num_inference_steps=int(num_inference_steps), output_type="np").images
289
+
290
+ # 1. Save Raw Data
291
+ raw_file_path = save_temp_tiff(image_np, prefix="t2i_raw")
292
+
293
+ # 2. Apply Pseudocolor for Display
294
+ display_image = apply_pseudocolor(image_np, colormap_choice)
295
+
296
  print("✓ Image generated")
297
  if SAVE_EXAMPLES:
298
  example_filepath = os.path.join(T2I_EXAMPLE_IMG_DIR, sanitize_prompt_for_filename(prompt))
299
  if not os.path.exists(example_filepath):
300
+ display_image.save(example_filepath)
301
+
302
+ # Return Display Image AND Path to Raw File
303
+ return display_image, raw_file_path
304
 
305
  @spaces.GPU(duration=120)
306
+ def run_mask_to_image_generation(mask_file_obj, cell_type, num_images, steps, seed, colormap_choice):
307
  if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
308
  if mask_file_obj is None: raise gr.Error("Please upload a segmentation mask TIF file.")
309
  if not cell_type or not cell_type.strip(): raise gr.Error("Please enter a cell type.")
310
 
 
 
 
 
 
 
 
 
311
  pipe = swap_controlnet(controlnet_pipe, M2I_CONTROLNET_PATH)
312
  try:
313
  mask_np = tifffile.imread(mask_file_obj.name)
 
323
  prompt = f"nuclei of {cell_type.strip()}"
324
  print(f"\nTask started... | Task: Mask-to-Image | Prompt: '{prompt}' | Steps: {steps} | Images: {num_images}")
325
 
326
+ generated_display_images = []
327
+ generated_raw_files = []
328
+
329
+ # Create a temp dir for the zip file
330
+ temp_dir = tempfile.mkdtemp()
331
+
332
  for i in range(int(num_images)):
333
  current_seed = int(seed) + i
334
  generator = torch.Generator(device=DEVICE).manual_seed(current_seed)
335
  with torch.autocast("cuda"):
336
  output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images
 
 
 
337
 
338
+ # Save individual raw file
339
+ raw_name = f"m2i_sample_{i+1}.tif"
340
+ raw_path = os.path.join(temp_dir, raw_name)
341
+
342
+ # Ensure correct type saving
343
+ save_data = output_np.astype(np.float32) if output_np.dtype == np.float16 else output_np
344
+ tifffile.imwrite(raw_path, save_data)
345
+ generated_raw_files.append(raw_path)
346
+
347
+ # Create display image
348
+ pil_image = apply_pseudocolor(output_np, colormap_choice)
349
+ generated_display_images.append(pil_image)
350
+ print(f"✓ Generated image {i+1}/{int(num_images)}")
351
+
352
+ # Create ZIP file
353
+ zip_filename = os.path.join(temp_dir, "raw_output_images.zip")
354
+ with zipfile.ZipFile(zip_filename, 'w') as zipf:
355
+ for file in generated_raw_files:
356
+ zipf.write(file, os.path.basename(file))
357
+
358
+ return input_display_image, generated_display_images, zip_filename
359
 
360
  @spaces.GPU(duration=120)
361
+ def run_super_resolution(low_res_file_obj, controlnet_model_name, prompt, steps, seed, colormap_choice):
362
  if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
363
  if low_res_file_obj is None: raise gr.Error("Please upload a low-resolution TIF file.")
 
 
 
 
 
 
 
 
364
 
365
  target_path = SR_CONTROLNET_MODELS.get(controlnet_model_name)
366
  if not target_path: raise gr.Error(f"ControlNet model '{controlnet_model_name}' not found.")
 
383
  generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
384
  with torch.autocast("cuda"):
385
  output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images
386
+
387
+ # Save Raw
388
+ raw_file_path = save_temp_tiff(output_np, prefix="sr_raw")
389
+ # Display Color
390
+ output_display = apply_pseudocolor(output_np, colormap_choice)
391
 
392
+ return input_display_image, output_display, raw_file_path
393
 
394
  @spaces.GPU(duration=120)
395
+ def run_denoising(noisy_image_np, image_type, steps, seed, colormap_choice):
396
  if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
397
  if noisy_image_np is None: raise gr.Error("Please upload a noisy image.")
398
 
 
 
 
 
 
 
 
 
 
 
 
399
  pipe = swap_controlnet(controlnet_pipe, DN_CONTROLNET_PATH)
400
  prompt = DN_PROMPT_RULES.get(image_type, 'microscopy image')
401
  print(f"\nTask started... | Task: Denoising | Prompt: '{prompt}' | Steps: {steps}")
 
407
  generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
408
  with torch.autocast("cuda"):
409
  output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images
410
+
411
+ # Save Raw
412
+ raw_file_path = save_temp_tiff(output_np, prefix="dn_raw")
413
+ # Display Color
414
+ output_display = apply_pseudocolor(output_np, colormap_choice)
415
 
416
+ return numpy_to_pil(noisy_image_np, "L"), output_display, raw_file_path
417
 
418
  @spaces.GPU(duration=120)
419
  def run_segmentation(input_image_np, model_name, diameter, flow_threshold, cellprob_threshold):
420
+ # Segmentation remains mostly same, usually overlay is RGB anyway
421
+ if input_image_np is None: raise gr.Error("Please upload an image to segment.")
 
 
 
 
422
  model_path = SEG_MODELS.get(model_name)
423
+ if not model_path: raise gr.Error(f"Segmentation model '{model_name}' not found.")
 
424
 
 
 
 
425
  print(f"\nTask started... | Task: Cell Segmentation | Model: '{model_name}'")
 
 
426
  try:
427
  use_gpu = torch.cuda.is_available()
428
  model = cellpose_models.CellposeModel(gpu=use_gpu, pretrained_model=model_path)
 
430
  raise gr.Error(f"Failed to load Cellpose model. Error: {e}")
431
 
432
  diameter_to_use = model.diam_labels if diameter == 0 else float(diameter)
 
 
 
433
  try:
434
+ masks, _, _ = model.eval([input_image_np], channels=[0, 0], diameter=diameter_to_use, flow_threshold=flow_threshold, cellprob_threshold=cellprob_threshold)
 
 
 
 
 
 
435
  mask_output = masks[0]
436
  except Exception as e:
437
  raise gr.Error(f"Cellpose model evaluation failed. Error: {e}")
438
 
 
 
439
  original_rgb = numpy_to_pil(input_image_np, "RGB")
440
  original_rgb_np = np.array(original_rgb)
 
 
441
  red_mask_layer = np.zeros_like(original_rgb_np)
442
+ red_mask_layer[mask_output > 0] = [139, 0, 0]
443
+ blended_image_np = ((0.6 * original_rgb_np + 0.4 * red_mask_layer).astype(np.uint8))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
444
 
445
  return numpy_to_pil(input_image_np, "L"), numpy_to_pil(blended_image_np, "RGB")
446
 
447
  @spaces.GPU(duration=120)
448
  def run_classification(input_image_np, model_name):
449
+ if input_image_np is None: raise gr.Error("Please upload an image to classify.")
 
 
 
 
 
450
  model_dir = CLS_MODEL_PATHS.get(model_name)
451
+ if not model_dir: raise gr.Error(f"Classification model '{model_name}' not found.")
 
 
452
  model_path = os.path.join(model_dir, "best_resnet50.pth")
453
+
 
 
454
  print(f"\nTask started... | Task: Classification | Model: '{model_name}'")
 
 
455
  try:
456
  model = models.resnet50(weights=None)
457
+ model.fc = nn.Linear(model.fc.in_features, len(CLS_CLASS_NAMES))
 
458
  model.load_state_dict(torch.load(model_path, map_location=DEVICE))
459
+ model.to(DEVICE).eval()
 
460
  except Exception as e:
461
  raise gr.Error(f"Failed to load classification model. Error: {e}")
462
 
 
 
463
  input_pil = numpy_to_pil(input_image_np, "RGB")
464
+ transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
 
 
 
 
465
  input_tensor = transform_test(input_pil).unsqueeze(0).to(DEVICE)
466
 
 
467
  with torch.no_grad():
468
  outputs = model(input_tensor)
469
  probabilities = F.softmax(outputs, dim=1).squeeze().cpu().numpy()
470
 
 
471
  confidences = {name: float(prob) for name, prob in zip(CLS_CLASS_NAMES, probabilities)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
  return numpy_to_pil(input_image_np, "L"), confidences
473
 
474
 
475
  # --- 3. Gradio UI Layout ---
476
  print("Building Gradio interface...")
477
+ for d in [M2I_EXAMPLE_IMG_DIR, T2I_EXAMPLE_IMG_DIR, SR_EXAMPLE_IMG_DIR, DN_EXAMPLE_IMG_DIR, SEG_EXAMPLE_IMG_DIR, CLS_EXAMPLE_IMG_DIR]:
478
+ os.makedirs(d, exist_ok=True)
 
 
 
 
 
479
 
480
  # --- Load examples ---
481
  filename_to_prompt_map = { sanitize_prompt_for_filename(prompt): prompt for prompt in T2I_PROMPTS }
482
  t2i_gallery_examples = []
483
  for filename in os.listdir(T2I_EXAMPLE_IMG_DIR):
484
  if filename in filename_to_prompt_map:
485
+ t2i_gallery_examples.append((os.path.join(T2I_EXAMPLE_IMG_DIR, filename), filename_to_prompt_map[filename]))
 
 
486
 
487
  def load_image_examples(example_dir, is_stack=False):
488
  examples = []
489
  if not os.path.exists(example_dir): return examples
490
  for f in sorted(os.listdir(example_dir)):
491
+ if f.lower().endswith(('.tif', '.tiff', '.png', '.jpg')):
492
  filepath = os.path.join(example_dir, f)
493
  try:
494
+ if f.lower().endswith(('.tif', '.tiff')): img_np = tifffile.imread(filepath)
495
+ else: img_np = np.array(Image.open(filepath).convert("L"))
496
+ if is_stack and img_np.ndim == 3: img_np = np.mean(img_np, axis=0)
497
+ examples.append((numpy_to_pil(img_np, "L"), filepath))
498
+ except: pass
 
 
 
 
 
 
 
499
  return examples
500
 
501
  m2i_gallery_examples = load_image_examples(M2I_EXAMPLE_IMG_DIR)
 
505
  cls_gallery_examples = load_image_examples(CLS_EXAMPLE_IMG_DIR)
506
 
507
  # --- Universal event handlers ---
508
+ def select_example_prompt(evt: gr.SelectData): return evt.value['caption']
509
+ def select_example_input_file(evt: gr.SelectData): return evt.value['caption']
 
 
 
510
 
511
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
512
  with gr.Row():
 
516
  with gr.Tabs():
517
  # --- TAB 1: Text-to-Image ---
518
  with gr.Tab("Text-to-Image Generation", id="txt2img"):
 
 
 
 
 
 
 
 
 
519
  with gr.Row(variant="panel"):
520
  with gr.Column(scale=1, min_width=350):
521
+ t2i_prompt_input = gr.Dropdown(choices=T2I_PROMPTS, value=T2I_PROMPTS[0], label="Search or Type a Prompt", filterable=True, allow_custom_value=True)
 
 
 
 
 
 
 
522
  t2i_steps_slider = gr.Slider(minimum=10, maximum=200, step=1, value=50, label="Inference Steps")
523
+ # New: Color Selector
524
+ t2i_color_input = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor (for Display)")
525
  t2i_generate_button = gr.Button("Generate", variant="primary")
526
  with gr.Column(scale=2):
527
+ t2i_generated_output = gr.Image(label="Generated Image (Pseudocolor)", type="pil", interactive=False)
528
+ # New: Raw Download Button
529
+ t2i_raw_download = gr.DownloadButton(label="Download Raw Output (.tif)", visible=True)
530
+ t2i_gallery = gr.Gallery(value=t2i_gallery_examples, label="Examples", columns=6, object_fit="contain", height="auto")
531
 
532
  # --- TAB 2: Super-Resolution ---
533
  with gr.Tab("Super-Resolution", id="super_res"):
 
 
 
 
 
 
 
 
 
 
534
  with gr.Row(variant="panel"):
535
  with gr.Column(scale=1, min_width=350):
536
  sr_input_file = gr.File(label="Upload 9-Channel TIF Stack", file_types=['.tif', '.tiff'])
537
+ sr_model_selector = gr.Dropdown(choices=list(SR_CONTROLNET_MODELS.keys()), value=list(SR_CONTROLNET_MODELS.keys())[-1], label="Select Super-Resolution Model")
538
+ sr_prompt_input = gr.Textbox(label="Prompt", value="F-actin of COS-7", interactive=False)
 
 
 
 
 
 
 
 
 
539
  sr_steps_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Inference Steps")
540
  sr_seed_input = gr.Number(label="Seed", value=42)
541
+ sr_color_input = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor (for Display)")
542
  sr_generate_button = gr.Button("Generate Super-Resolution", variant="primary")
543
  with gr.Column(scale=2):
544
  with gr.Row():
545
+ sr_input_display = gr.Image(label="Input (Avg)", type="pil", interactive=False)
546
  sr_output_image = gr.Image(label="Super-Resolved Image", type="pil", interactive=False)
547
+ sr_raw_download = gr.DownloadButton(label="Download Raw Output (.tif)", visible=True)
548
+ sr_gallery = gr.Gallery(value=sr_gallery_examples, label="Examples", columns=6, object_fit="contain", height="auto")
549
 
550
  # --- TAB 3: Denoising ---
551
  with gr.Tab("Denoising", id="denoising"):
 
 
 
 
 
 
 
 
 
552
  with gr.Row(variant="panel"):
553
  with gr.Column(scale=1, min_width=350):
554
  dn_input_image = gr.Image(type="numpy", label="Upload Noisy Image", image_mode="L")
555
+ dn_image_type_selector = gr.Dropdown(choices=list(DN_PROMPT_RULES.keys()), value='MICE', label="Select Image Type")
556
  dn_steps_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Inference Steps")
557
  dn_seed_input = gr.Number(label="Seed", value=42)
558
+ dn_color_input = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor (for Display)")
559
  dn_generate_button = gr.Button("Denoise Image", variant="primary")
560
  with gr.Column(scale=2):
561
  with gr.Row():
562
+ dn_original_display = gr.Image(label="Original", type="pil", interactive=False)
563
  dn_output_image = gr.Image(label="Denoised Image", type="pil", interactive=False)
564
+ dn_raw_download = gr.DownloadButton(label="Download Raw Output (.tif)", visible=True)
565
+ dn_gallery = gr.Gallery(value=dn_gallery_examples, label="Examples", columns=6, object_fit="contain", height="auto")
566
 
567
  # --- TAB 4: Mask-to-Image ---
568
  with gr.Tab("Mask-to-Image", id="mask2img"):
 
 
 
 
 
 
 
 
 
569
  with gr.Row(variant="panel"):
570
  with gr.Column(scale=1, min_width=350):
571
  m2i_input_file = gr.File(label="Upload Segmentation Mask (.tif)", file_types=['.tif', '.tiff'])
572
+ m2i_cell_type_input = gr.Textbox(label="Cell Type", placeholder="e.g., CoNSS, HeLa")
573
+ m2i_num_images_slider = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of Images")
574
  m2i_steps_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Inference Steps")
575
  m2i_seed_input = gr.Number(label="Seed", value=42)
576
+ m2i_color_input = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor (for Display)")
577
+ m2i_generate_button = gr.Button("Generate Samples", variant="primary")
578
  with gr.Column(scale=2):
579
  m2i_output_gallery = gr.Gallery(label="Generated Samples", columns=5, object_fit="contain", height="auto")
580
+ m2i_raw_download = gr.DownloadButton(label="Download All Raw Samples (.zip)", visible=True)
581
  m2i_input_display = gr.Image(label="Input Mask", type="pil", interactive=False)
582
+ m2i_gallery = gr.Gallery(value=m2i_gallery_examples, label="Examples", columns=6, object_fit="contain", height="auto")
583
 
584
  # --- TAB 5: Cell Segmentation ---
585
  with gr.Tab("Cell Segmentation", id="segmentation"):
 
 
 
 
 
 
 
 
586
  with gr.Row(variant="panel"):
587
  with gr.Column(scale=1, min_width=350):
588
+ seg_input_image = gr.Image(type="numpy", label="Upload Image", image_mode="L")
589
+ seg_model_selector = gr.Dropdown(choices=list(SEG_MODELS.keys()), value=list(SEG_MODELS.keys())[0], label="Model")
590
+ seg_diameter_input = gr.Number(label="Cell Diameter (0=auto)", value=30)
 
591
  seg_flow_slider = gr.Slider(minimum=0.0, maximum=3.0, step=0.1, value=0.4, label="Flow Threshold")
592
  seg_cellprob_slider = gr.Slider(minimum=-6.0, maximum=6.0, step=0.5, value=0.0, label="Cell Probability Threshold")
593
  seg_generate_button = gr.Button("Segment Cells", variant="primary")
594
  with gr.Column(scale=2):
 
595
  with gr.Row():
596
+ seg_original_display = gr.Image(label="Original", type="pil", interactive=False)
597
+ seg_output_image = gr.Image(label="Segmented Overlay", type="pil", interactive=False)
598
+ seg_gallery = gr.Gallery(value=seg_gallery_examples, label="Examples", columns=6, object_fit="contain", height="auto")
599
 
600
+ # --- TAB 6: Classification ---
601
  with gr.Tab("Classification", id="classification"):
 
 
 
 
 
 
 
 
602
  with gr.Row(variant="panel"):
603
  with gr.Column(scale=1, min_width=350):
604
+ cls_input_image = gr.Image(type="numpy", label="Upload Image", image_mode="L")
605
+ cls_model_selector = gr.Dropdown(choices=list(CLS_MODEL_PATHS.keys()), value=list(CLS_MODEL_PATHS.keys())[0], label="Model")
 
606
  cls_generate_button = gr.Button("Classify Image", variant="primary")
607
  with gr.Column(scale=2):
 
608
  cls_original_display = gr.Image(label="Input Image", type="pil", interactive=False)
609
+ cls_output_label = gr.Label(label="Results", num_top_classes=len(CLS_CLASS_NAMES))
610
+ cls_gallery = gr.Gallery(value=cls_gallery_examples, label="Examples", columns=6, object_fit="contain", height="auto")
611
 
612
 
613
  # --- Event Handlers ---
614
+ m2i_generate_button.click(fn=run_mask_to_image_generation, inputs=[m2i_input_file, m2i_cell_type_input, m2i_num_images_slider, m2i_steps_slider, m2i_seed_input, m2i_color_input], outputs=[m2i_input_display, m2i_output_gallery, m2i_raw_download])
615
  m2i_gallery.select(fn=select_example_input_file, outputs=m2i_input_file)
616
 
617
+ t2i_generate_button.click(fn=generate_t2i, inputs=[t2i_prompt_input, t2i_steps_slider, t2i_color_input], outputs=[t2i_generated_output, t2i_raw_download])
618
  t2i_gallery.select(fn=select_example_prompt, outputs=t2i_prompt_input)
619
 
620
  sr_model_selector.change(fn=update_sr_prompt, inputs=sr_model_selector, outputs=sr_prompt_input)
621
+ sr_generate_button.click(fn=run_super_resolution, inputs=[sr_input_file, sr_model_selector, sr_prompt_input, sr_steps_slider, sr_seed_input, sr_color_input], outputs=[sr_input_display, sr_output_image, sr_raw_download])
622
  sr_gallery.select(fn=select_example_input_file, outputs=sr_input_file)
623
 
624
+ dn_generate_button.click(fn=run_denoising, inputs=[dn_input_image, dn_image_type_selector, dn_steps_slider, dn_seed_input, dn_color_input], outputs=[dn_original_display, dn_output_image, dn_raw_download])
625
  dn_gallery.select(fn=select_example_input_file, outputs=dn_input_image)
626
 
627
  seg_generate_button.click(fn=run_segmentation, inputs=[seg_input_image, seg_model_selector, seg_diameter_input, seg_flow_slider, seg_cellprob_slider], outputs=[seg_original_display, seg_output_image])
 
634
  # --- 4. Launch Application ---
635
  if __name__ == "__main__":
636
  print("Interface built. Launching server...")
637
+ demo.launch()