Spaces:
Sleeping
Sleeping
Commit ·
d4a259c
1
Parent(s): 41e1156
Pseudo-color-3
Browse files
app.py
CHANGED
|
@@ -1,16 +1,12 @@
|
|
| 1 |
import torch
|
| 2 |
import numpy as np
|
| 3 |
import gradio as gr
|
| 4 |
-
from PIL import Image
|
| 5 |
import os
|
| 6 |
import json
|
| 7 |
-
import glob
|
| 8 |
-
import random
|
| 9 |
import tifffile
|
| 10 |
import re
|
| 11 |
-
import imageio
|
| 12 |
from torchvision import transforms, models
|
| 13 |
-
import accelerate
|
| 14 |
import shutil
|
| 15 |
import time
|
| 16 |
import torch.nn as nn
|
|
@@ -20,11 +16,12 @@ from collections import OrderedDict
|
|
| 20 |
import tempfile
|
| 21 |
import zipfile
|
| 22 |
import matplotlib.cm as cm
|
|
|
|
|
|
|
| 23 |
|
| 24 |
# --- Imports from both scripts ---
|
| 25 |
from diffusers import DDPMScheduler, DDIMScheduler
|
| 26 |
from transformers import CLIPTextModel, CLIPTokenizer
|
| 27 |
-
from accelerate.state import AcceleratorState
|
| 28 |
from transformers.utils import ContextManagers
|
| 29 |
|
| 30 |
# --- Custom Model Imports ---
|
|
@@ -34,45 +31,37 @@ from models.controlnet import ControlNetModel
|
|
| 34 |
from models.unet_2d_condition import UNet2DConditionModel
|
| 35 |
from models.pipeline_controlnet import DDPMControlnetPipeline
|
| 36 |
|
| 37 |
-
# ---
|
| 38 |
from cellpose import models as cellpose_models
|
| 39 |
-
from cellpose import plot as cellpose_plot
|
| 40 |
from huggingface_hub import snapshot_download
|
| 41 |
|
| 42 |
# --- 0. Configuration & Constants ---
|
| 43 |
hf_token = os.environ.get("HF_TOKEN")
|
| 44 |
-
# --- General ---
|
| 45 |
MODEL_TITLE = "🔬 FluoGen: AI-Powered Fluorescence Microscopy Suite"
|
| 46 |
MODEL_DESCRIPTION = """
|
| 47 |
**Paper**: *Generative AI empowering fluorescence microscopy imaging and analysis*
|
| 48 |
<br>
|
| 49 |
-
Select a task
|
|
|
|
| 50 |
"""
|
| 51 |
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 52 |
WEIGHT_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 53 |
LOGO_PATH = "utils/logo2_transparent.png"
|
| 54 |
-
|
| 55 |
-
# --- Global switch to control example saving ---
|
| 56 |
SAVE_EXAMPLES = False
|
| 57 |
|
| 58 |
# --- Base directory for all models ---
|
| 59 |
REPO_ID = "FluoGen-Group/FluoGen-demo-test-ckpts"
|
| 60 |
MODELS_ROOT_DIR = snapshot_download(repo_id=REPO_ID, token=hf_token)
|
| 61 |
|
| 62 |
-
# ---
|
| 63 |
M2I_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_M2I/checkpoint-30000"
|
| 64 |
M2I_EXAMPLE_IMG_DIR = "example_images_m2i"
|
| 65 |
-
|
| 66 |
-
# --- Tab 2: Text-to-Image Config ---
|
| 67 |
T2I_EXAMPLE_IMG_DIR = "example_images"
|
| 68 |
T2I_PRETRAINED_MODEL_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5"
|
| 69 |
T2I_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"
|
| 70 |
-
|
| 71 |
-
# --- Tab 3, 4: ControlNet-based Tasks Config ---
|
| 72 |
CONTROLNET_CLIP_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5"
|
| 73 |
CONTROLNET_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"
|
| 74 |
|
| 75 |
-
# Super-Resolution Models
|
| 76 |
SR_CONTROLNET_MODELS = {
|
| 77 |
"Checkpoint ER": f"{MODELS_ROOT_DIR}/ControlNet_SR/ER/checkpoint-30000",
|
| 78 |
"Checkpoint Microtubules": f"{MODELS_ROOT_DIR}/ControlNet_SR/Microtubules/checkpoint-72500",
|
|
@@ -80,13 +69,10 @@ SR_CONTROLNET_MODELS = {
|
|
| 80 |
"Checkpoint F-actin": f"{MODELS_ROOT_DIR}/ControlNet_SR/F-actin/checkpoint-35000",
|
| 81 |
}
|
| 82 |
SR_EXAMPLE_IMG_DIR = "example_images_sr"
|
| 83 |
-
|
| 84 |
-
# Denoising Model
|
| 85 |
DN_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_DN/checkpoint-10000"
|
| 86 |
DN_PROMPT_RULES = {'MICE': 'mouse brain tissues', 'FISH': 'zebrafish embryos', 'BPAE_B': 'nucleus of BPAE', 'BPAE_R': 'mitochondria of BPAE', 'BPAE_G': 'F-actin of BPAE'}
|
| 87 |
DN_EXAMPLE_IMG_DIR = "example_images_dn"
|
| 88 |
|
| 89 |
-
# --- Tab 5: Cell Segmentation Config ---
|
| 90 |
SEG_MODELS = {
|
| 91 |
"DynamicNet Model": f"{MODELS_ROOT_DIR}/Cellpose/DynamicNet_baseline/CP_dynamic_ten_epoch_0100",
|
| 92 |
"DynamicNet Model + FluoGen": f"{MODELS_ROOT_DIR}/Cellpose/DynamicNet_FluoGen/CP_dynamic_epoch_0300",
|
|
@@ -95,7 +81,6 @@ SEG_MODELS = {
|
|
| 95 |
}
|
| 96 |
SEG_EXAMPLE_IMG_DIR = "example_images_seg"
|
| 97 |
|
| 98 |
-
# --- Tab 6: Classification Config ---
|
| 99 |
CLS_MODEL_PATHS = OrderedDict({
|
| 100 |
"5shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot_re",
|
| 101 |
"5shot+FluoGen": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot_aug_re",
|
|
@@ -104,7 +89,7 @@ CLS_CLASS_NAMES = ['Nucleus', 'Endoplasmic Reticulum', 'Giantin', 'GPP130', 'Lys
|
|
| 104 |
CLS_EXAMPLE_IMG_DIR = "example_images_cls"
|
| 105 |
|
| 106 |
# --- Constants for Visualization ---
|
| 107 |
-
COLOR_MAPS = ["Grayscale", "Green (GFP)", "Red (RFP)", "Blue (DAPI)", "Magenta", "Cyan", "Yellow", "Fire", "Viridis", "Inferno"]
|
| 108 |
|
| 109 |
# --- Helper Functions ---
|
| 110 |
def sanitize_prompt_for_filename(prompt):
|
|
@@ -116,57 +101,73 @@ def min_max_norm(x):
|
|
| 116 |
if max_val - min_val < 1e-8: return np.zeros_like(x)
|
| 117 |
return (x - min_val) / (max_val - min_val)
|
| 118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
def apply_pseudocolor(image_np, color_name="Grayscale"):
|
| 120 |
"""
|
| 121 |
Applies a pseudocolor to a single channel numpy image.
|
| 122 |
-
image_np: Single channel numpy array (any bit depth).
|
| 123 |
Returns: PIL Image in RGB.
|
| 124 |
"""
|
|
|
|
|
|
|
| 125 |
# Normalize to 0-1 for processing
|
| 126 |
norm_img = min_max_norm(np.squeeze(image_np))
|
| 127 |
|
| 128 |
if color_name == "Grayscale":
|
| 129 |
-
# Just convert to uint8 L
|
| 130 |
return Image.fromarray((norm_img * 255).astype(np.uint8)).convert("RGB")
|
| 131 |
|
| 132 |
-
# Create RGB canvas
|
| 133 |
h, w = norm_img.shape
|
| 134 |
rgb = np.zeros((h, w, 3), dtype=np.float32)
|
| 135 |
|
| 136 |
-
if color_name == "Green (GFP)":
|
| 137 |
-
|
| 138 |
-
elif color_name == "
|
| 139 |
-
|
| 140 |
-
elif color_name == "
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
# Use matplotlib colormaps
|
| 153 |
-
cmap_map = {"Fire": "inferno", "Viridis": "viridis", "Inferno": "inferno"} # Fire looks like inferno usually
|
| 154 |
-
if color_name == "Fire": cmap = cm.get_cmap("gnuplot2") # Better fire approximation
|
| 155 |
-
else: cmap = cm.get_cmap(cmap_map[color_name])
|
| 156 |
-
|
| 157 |
-
colored = cmap(norm_img) # Returns RGBA 0-1
|
| 158 |
-
rgb = colored[..., :3] # Drop Alpha
|
| 159 |
|
| 160 |
return Image.fromarray((rgb * 255).astype(np.uint8))
|
| 161 |
|
| 162 |
def save_temp_tiff(image_np, prefix="output"):
|
| 163 |
-
"""Saves numpy array to a temp TIFF file and returns path."""
|
| 164 |
tfile = tempfile.NamedTemporaryFile(delete=False, suffix=".tif", prefix=f"{prefix}_")
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
save_data = image_np.astype(np.float32)
|
| 168 |
-
else:
|
| 169 |
-
save_data = image_np
|
| 170 |
tifffile.imwrite(tfile.name, save_data)
|
| 171 |
return tfile.name
|
| 172 |
|
|
@@ -176,8 +177,7 @@ def numpy_to_pil(image_np, target_mode="RGB"):
|
|
| 176 |
if target_mode == "L" and image_np.mode != "L": return image_np.convert("L")
|
| 177 |
return image_np
|
| 178 |
squeezed_np = np.squeeze(image_np);
|
| 179 |
-
if squeezed_np.dtype == np.uint8:
|
| 180 |
-
image_8bit = squeezed_np
|
| 181 |
else:
|
| 182 |
normalized_np = min_max_norm(squeezed_np)
|
| 183 |
image_8bit = (normalized_np * 255).astype(np.uint8)
|
|
@@ -211,8 +211,7 @@ def load_all_prompts():
|
|
| 211 |
if isinstance(data, list):
|
| 212 |
combined_prompts.extend(data)
|
| 213 |
for p in data: PROMPT_TO_MODEL_MAP[p] = cat["model"]
|
| 214 |
-
|
| 215 |
-
except Exception as e: print(f"✗ Error loading {cat['file']}: {e}")
|
| 216 |
if not combined_prompts: return ["F-actin of COS-7", "ER of COS-7"]
|
| 217 |
return combined_prompts
|
| 218 |
T2I_PROMPTS = load_all_prompts()
|
|
@@ -231,23 +230,22 @@ try:
|
|
| 231 |
current_t2i_unet_path = T2I_UNET_PATH
|
| 232 |
print("✓ Text-to-Image model loaded successfully!")
|
| 233 |
except Exception as e:
|
| 234 |
-
print(f"
|
| 235 |
|
| 236 |
try:
|
| 237 |
-
print("Loading shared ControlNet pipeline
|
| 238 |
controlnet_unet = UNet2DConditionModel.from_pretrained(CONTROLNET_UNET_PATH, subfolder="unet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
|
| 239 |
-
|
| 240 |
-
controlnet_controlnet = ControlNetModel.from_pretrained(default_controlnet_path, subfolder="controlnet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
|
| 241 |
controlnet_scheduler = DDIMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="v_prediction", rescale_betas_zero_snr=False, timestep_spacing="trailing")
|
| 242 |
controlnet_tokenizer = CLIPTokenizer.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="tokenizer")
|
| 243 |
with ContextManagers([]):
|
| 244 |
controlnet_text_encoder = CLIPTextModel.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="text_encoder").to(dtype=WEIGHT_DTYPE, device=DEVICE)
|
| 245 |
controlnet_pipe = DDPMControlnetPipeline(unet=controlnet_unet, controlnet=controlnet_controlnet, scheduler=controlnet_scheduler, text_encoder=controlnet_text_encoder, tokenizer=controlnet_tokenizer)
|
| 246 |
controlnet_pipe.to(dtype=WEIGHT_DTYPE, device=DEVICE)
|
| 247 |
-
controlnet_pipe.current_controlnet_path =
|
| 248 |
print("✓ Shared ControlNet pipeline loaded successfully!")
|
| 249 |
except Exception as e:
|
| 250 |
-
print(f"
|
| 251 |
|
| 252 |
# --- 2. Core Logic Functions ---
|
| 253 |
def swap_controlnet(pipe, target_path):
|
|
@@ -257,7 +255,7 @@ def swap_controlnet(pipe, target_path):
|
|
| 257 |
pipe.controlnet = ControlNetModel.from_pretrained(target_path, subfolder="controlnet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
|
| 258 |
pipe.current_controlnet_path = target_path
|
| 259 |
except Exception as e:
|
| 260 |
-
raise gr.Error(f"Failed to load ControlNet model
|
| 261 |
return pipe
|
| 262 |
|
| 263 |
def swap_t2i_unet(pipe, target_unet_path):
|
|
@@ -269,245 +267,210 @@ def swap_t2i_unet(pipe, target_unet_path):
|
|
| 269 |
new_unet = UNet2DModel.from_pretrained(target_unet_path, subfolder="unet").to(DEVICE)
|
| 270 |
pipe.unet = new_unet
|
| 271 |
current_t2i_unet_path = target_unet_path
|
| 272 |
-
print("✅ UNet swapped successfully.")
|
| 273 |
except Exception as e:
|
| 274 |
-
raise gr.Error(f"Failed to load UNet
|
| 275 |
return pipe
|
| 276 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
@spaces.GPU(duration=120)
|
| 278 |
-
def generate_t2i(prompt, num_inference_steps,
|
| 279 |
global t2i_pipe
|
| 280 |
if t2i_pipe is None: raise gr.Error("Text-to-Image model is not loaded.")
|
| 281 |
-
target_model_path = PROMPT_TO_MODEL_MAP.get(prompt)
|
| 282 |
-
if not target_model_path:
|
| 283 |
-
target_model_path = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/FULL-checkpoint-275000"
|
| 284 |
-
print(f"ℹ️ Prompt '{prompt}' not found in predefined list. Using Foundation (Full) model.")
|
| 285 |
t2i_pipe = swap_t2i_unet(t2i_pipe, target_model_path)
|
| 286 |
|
| 287 |
-
print(f"\n🚀 Task started... | Prompt: '{prompt}'
|
| 288 |
image_np = t2i_pipe(prompt.lower(), generator=None, num_inference_steps=int(num_inference_steps), output_type="np").images
|
| 289 |
|
| 290 |
-
# 1. Save Raw Data
|
| 291 |
raw_file_path = save_temp_tiff(image_np, prefix="t2i_raw")
|
|
|
|
|
|
|
| 292 |
|
| 293 |
-
# 2. Apply Pseudocolor for Display
|
| 294 |
-
display_image = apply_pseudocolor(image_np, colormap_choice)
|
| 295 |
-
|
| 296 |
-
print("✓ Image generated")
|
| 297 |
if SAVE_EXAMPLES:
|
| 298 |
example_filepath = os.path.join(T2I_EXAMPLE_IMG_DIR, sanitize_prompt_for_filename(prompt))
|
| 299 |
-
if not os.path.exists(example_filepath):
|
| 300 |
-
display_image.save(example_filepath)
|
| 301 |
|
| 302 |
-
# Return Display
|
| 303 |
-
return display_image, raw_file_path
|
| 304 |
|
| 305 |
@spaces.GPU(duration=120)
|
| 306 |
-
def run_mask_to_image_generation(mask_file_obj, cell_type, num_images, steps, seed,
|
| 307 |
if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
|
| 308 |
-
if mask_file_obj is None: raise gr.Error("Please upload a segmentation mask
|
| 309 |
-
if not cell_type or not cell_type.strip(): raise gr.Error("Please enter a cell type.")
|
| 310 |
|
| 311 |
pipe = swap_controlnet(controlnet_pipe, M2I_CONTROLNET_PATH)
|
| 312 |
-
try:
|
| 313 |
-
|
| 314 |
-
except Exception as e:
|
| 315 |
-
raise gr.Error(f"Failed to read the TIF file. Error: {e}")
|
| 316 |
|
| 317 |
-
|
| 318 |
mask_normalized = min_max_norm(mask_np)
|
| 319 |
-
image_tensor = torch.from_numpy(mask_normalized.astype(np.float32))
|
| 320 |
-
image_tensor = image_tensor.unsqueeze(0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
|
| 321 |
image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)
|
| 322 |
|
| 323 |
prompt = f"nuclei of {cell_type.strip()}"
|
| 324 |
-
print(f"\
|
| 325 |
|
|
|
|
| 326 |
generated_display_images = []
|
| 327 |
generated_raw_files = []
|
| 328 |
-
|
| 329 |
-
# Create a temp dir for the zip file
|
| 330 |
temp_dir = tempfile.mkdtemp()
|
| 331 |
|
| 332 |
for i in range(int(num_images)):
|
| 333 |
-
|
| 334 |
-
generator = torch.Generator(device=DEVICE).manual_seed(current_seed)
|
| 335 |
with torch.autocast("cuda"):
|
| 336 |
output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images
|
| 337 |
|
| 338 |
-
# Save
|
|
|
|
|
|
|
|
|
|
| 339 |
raw_name = f"m2i_sample_{i+1}.tif"
|
| 340 |
raw_path = os.path.join(temp_dir, raw_name)
|
| 341 |
-
|
| 342 |
-
# Ensure correct type saving
|
| 343 |
save_data = output_np.astype(np.float32) if output_np.dtype == np.float16 else output_np
|
| 344 |
tifffile.imwrite(raw_path, save_data)
|
| 345 |
generated_raw_files.append(raw_path)
|
| 346 |
|
| 347 |
-
#
|
| 348 |
-
|
| 349 |
-
generated_display_images.append(pil_image)
|
| 350 |
-
print(f"✓ Generated image {i+1}/{int(num_images)}")
|
| 351 |
|
| 352 |
-
#
|
| 353 |
zip_filename = os.path.join(temp_dir, "raw_output_images.zip")
|
| 354 |
with zipfile.ZipFile(zip_filename, 'w') as zipf:
|
| 355 |
-
for file in generated_raw_files:
|
| 356 |
-
zipf.write(file, os.path.basename(file))
|
| 357 |
|
| 358 |
-
|
|
|
|
| 359 |
|
| 360 |
@spaces.GPU(duration=120)
|
| 361 |
-
def run_super_resolution(low_res_file_obj, controlnet_model_name, prompt, steps, seed,
|
| 362 |
if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
|
| 363 |
-
if low_res_file_obj is None: raise gr.Error("Please upload a
|
| 364 |
|
| 365 |
target_path = SR_CONTROLNET_MODELS.get(controlnet_model_name)
|
| 366 |
-
if not target_path: raise gr.Error(f"ControlNet model '{controlnet_model_name}' not found.")
|
| 367 |
-
|
| 368 |
pipe = swap_controlnet(controlnet_pipe, target_path)
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
except Exception as e:
|
| 372 |
-
raise gr.Error(f"Failed to read the TIF file. Error: {e}")
|
| 373 |
|
| 374 |
if image_stack_np.ndim != 3 or image_stack_np.shape[-3] != 9:
|
| 375 |
-
raise gr.Error(f"Invalid TIF shape. Expected 9 channels
|
| 376 |
|
| 377 |
-
|
| 378 |
-
input_display_image = numpy_to_pil(average_projection_np, "L")
|
| 379 |
-
|
| 380 |
image_tensor = torch.from_numpy(image_stack_np.astype(np.float32) / 65535.0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
|
| 381 |
image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)
|
| 382 |
|
|
|
|
| 383 |
generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
|
| 384 |
with torch.autocast("cuda"):
|
| 385 |
output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images
|
| 386 |
|
| 387 |
-
# Save Raw
|
| 388 |
raw_file_path = save_temp_tiff(output_np, prefix="sr_raw")
|
| 389 |
-
|
| 390 |
-
|
| 391 |
|
| 392 |
-
return
|
| 393 |
|
| 394 |
@spaces.GPU(duration=120)
|
| 395 |
-
def run_denoising(noisy_image_np, image_type, steps, seed,
|
| 396 |
if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
|
| 397 |
-
if noisy_image_np is None: raise gr.Error("Please upload
|
| 398 |
|
| 399 |
pipe = swap_controlnet(controlnet_pipe, DN_CONTROLNET_PATH)
|
| 400 |
prompt = DN_PROMPT_RULES.get(image_type, 'microscopy image')
|
| 401 |
-
print(f"\nTask started... | Task: Denoising | Prompt: '{prompt}' | Steps: {steps}")
|
| 402 |
|
| 403 |
-
|
| 404 |
-
image_tensor =
|
| 405 |
image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)
|
| 406 |
|
| 407 |
generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
|
| 408 |
with torch.autocast("cuda"):
|
| 409 |
output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images
|
| 410 |
|
| 411 |
-
# Save Raw
|
| 412 |
raw_file_path = save_temp_tiff(output_np, prefix="dn_raw")
|
| 413 |
-
|
| 414 |
-
|
| 415 |
|
| 416 |
-
return numpy_to_pil(noisy_image_np, "L"), output_display, raw_file_path
|
| 417 |
|
|
|
|
| 418 |
@spaces.GPU(duration=120)
|
| 419 |
def run_segmentation(input_image_np, model_name, diameter, flow_threshold, cellprob_threshold):
|
| 420 |
-
|
| 421 |
-
if input_image_np is None: raise gr.Error("Please upload an image to segment.")
|
| 422 |
model_path = SEG_MODELS.get(model_name)
|
| 423 |
-
if not model_path: raise gr.Error(f"Segmentation model '{model_name}' not found.")
|
| 424 |
|
| 425 |
-
print(f"\
|
| 426 |
try:
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
except Exception as e:
|
| 430 |
-
raise gr.Error(f"Failed to load Cellpose model. Error: {e}")
|
| 431 |
-
|
| 432 |
-
diameter_to_use = model.diam_labels if diameter == 0 else float(diameter)
|
| 433 |
-
try:
|
| 434 |
-
masks, _, _ = model.eval([input_image_np], channels=[0, 0], diameter=diameter_to_use, flow_threshold=flow_threshold, cellprob_threshold=cellprob_threshold)
|
| 435 |
-
mask_output = masks[0]
|
| 436 |
-
except Exception as e:
|
| 437 |
-
raise gr.Error(f"Cellpose model evaluation failed. Error: {e}")
|
| 438 |
|
| 439 |
original_rgb = numpy_to_pil(input_image_np, "RGB")
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
blended_image_np = ((0.6 * original_rgb_np + 0.4 * red_mask_layer).astype(np.uint8))
|
| 444 |
-
|
| 445 |
-
return numpy_to_pil(input_image_np, "L"), numpy_to_pil(blended_image_np, "RGB")
|
| 446 |
|
| 447 |
@spaces.GPU(duration=120)
|
| 448 |
def run_classification(input_image_np, model_name):
|
| 449 |
-
if input_image_np is None: raise gr.Error("Please upload an image
|
| 450 |
-
|
| 451 |
-
if not model_dir: raise gr.Error(f"Classification model '{model_name}' not found.")
|
| 452 |
-
model_path = os.path.join(model_dir, "best_resnet50.pth")
|
| 453 |
|
| 454 |
-
print(f"\
|
| 455 |
try:
|
| 456 |
-
model = models.resnet50(weights=None)
|
| 457 |
-
model.fc = nn.Linear(model.fc.in_features, len(CLS_CLASS_NAMES))
|
| 458 |
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
|
| 459 |
model.to(DEVICE).eval()
|
| 460 |
-
except Exception as e:
|
| 461 |
-
raise gr.Error(f"Failed to load classification model. Error: {e}")
|
| 462 |
-
|
| 463 |
-
input_pil = numpy_to_pil(input_image_np, "RGB")
|
| 464 |
-
transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
|
| 465 |
-
input_tensor = transform_test(input_pil).unsqueeze(0).to(DEVICE)
|
| 466 |
|
|
|
|
| 467 |
with torch.no_grad():
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
confidences = {name: float(prob) for name, prob in zip(CLS_CLASS_NAMES, probabilities)}
|
| 472 |
-
return numpy_to_pil(input_image_np, "L"), confidences
|
| 473 |
-
|
| 474 |
|
| 475 |
# --- 3. Gradio UI Layout ---
|
| 476 |
print("Building Gradio interface...")
|
| 477 |
-
for d in [M2I_EXAMPLE_IMG_DIR, T2I_EXAMPLE_IMG_DIR, SR_EXAMPLE_IMG_DIR, DN_EXAMPLE_IMG_DIR, SEG_EXAMPLE_IMG_DIR, CLS_EXAMPLE_IMG_DIR]:
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
if not os.path.exists(example_dir): return examples
|
| 490 |
-
for f in sorted(os.listdir(example_dir)):
|
| 491 |
if f.lower().endswith(('.tif', '.tiff', '.png', '.jpg')):
|
| 492 |
-
|
| 493 |
try:
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
examples.append((numpy_to_pil(img_np, "L"), filepath))
|
| 498 |
except: pass
|
| 499 |
-
return
|
| 500 |
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
# --- Universal event handlers ---
|
| 508 |
-
def select_example_prompt(evt: gr.SelectData): return evt.value['caption']
|
| 509 |
-
def select_example_input_file(evt: gr.SelectData): return evt.value['caption']
|
| 510 |
|
|
|
|
| 511 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 512 |
with gr.Row():
|
| 513 |
gr.Image(value=LOGO_PATH, width=300, height=200, container=False, interactive=False, show_download_button=False, show_fullscreen_button=False)
|
|
@@ -516,122 +479,160 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
| 516 |
with gr.Tabs():
|
| 517 |
# --- TAB 1: Text-to-Image ---
|
| 518 |
with gr.Tab("Text-to-Image Generation", id="txt2img"):
|
|
|
|
|
|
|
|
|
|
| 519 |
with gr.Row(variant="panel"):
|
| 520 |
with gr.Column(scale=1, min_width=350):
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
t2i_color_input = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor (for Display)")
|
| 525 |
-
t2i_generate_button = gr.Button("Generate", variant="primary")
|
| 526 |
with gr.Column(scale=2):
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 531 |
|
| 532 |
# --- TAB 2: Super-Resolution ---
|
| 533 |
with gr.Tab("Super-Resolution", id="super_res"):
|
|
|
|
| 534 |
with gr.Row(variant="panel"):
|
| 535 |
with gr.Column(scale=1, min_width=350):
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
sr_generate_button = gr.Button("Generate Super-Resolution", variant="primary")
|
| 543 |
with gr.Column(scale=2):
|
| 544 |
with gr.Row():
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 549 |
|
| 550 |
# --- TAB 3: Denoising ---
|
| 551 |
with gr.Tab("Denoising", id="denoising"):
|
|
|
|
| 552 |
with gr.Row(variant="panel"):
|
| 553 |
with gr.Column(scale=1, min_width=350):
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
dn_generate_button = gr.Button("Denoise Image", variant="primary")
|
| 560 |
with gr.Column(scale=2):
|
| 561 |
with gr.Row():
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 566 |
|
| 567 |
# --- TAB 4: Mask-to-Image ---
|
| 568 |
with gr.Tab("Mask-to-Image", id="mask2img"):
|
|
|
|
| 569 |
with gr.Row(variant="panel"):
|
| 570 |
with gr.Column(scale=1, min_width=350):
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
m2i_generate_button = gr.Button("Generate Samples", variant="primary")
|
| 578 |
with gr.Column(scale=2):
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 583 |
|
| 584 |
# --- TAB 5: Cell Segmentation ---
|
| 585 |
with gr.Tab("Cell Segmentation", id="segmentation"):
|
| 586 |
with gr.Row(variant="panel"):
|
| 587 |
with gr.Column(scale=1, min_width=350):
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
with gr.Column(scale=2):
|
| 595 |
with gr.Row():
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
|
|
|
|
|
|
| 600 |
# --- TAB 6: Classification ---
|
| 601 |
with gr.Tab("Classification", id="classification"):
|
| 602 |
with gr.Row(variant="panel"):
|
| 603 |
with gr.Column(scale=1, min_width=350):
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
with gr.Column(scale=2):
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
# --- Event Handlers ---
|
| 614 |
-
m2i_generate_button.click(fn=run_mask_to_image_generation, inputs=[m2i_input_file, m2i_cell_type_input, m2i_num_images_slider, m2i_steps_slider, m2i_seed_input, m2i_color_input], outputs=[m2i_input_display, m2i_output_gallery, m2i_raw_download])
|
| 615 |
-
m2i_gallery.select(fn=select_example_input_file, outputs=m2i_input_file)
|
| 616 |
-
|
| 617 |
-
t2i_generate_button.click(fn=generate_t2i, inputs=[t2i_prompt_input, t2i_steps_slider, t2i_color_input], outputs=[t2i_generated_output, t2i_raw_download])
|
| 618 |
-
t2i_gallery.select(fn=select_example_prompt, outputs=t2i_prompt_input)
|
| 619 |
-
|
| 620 |
-
sr_model_selector.change(fn=update_sr_prompt, inputs=sr_model_selector, outputs=sr_prompt_input)
|
| 621 |
-
sr_generate_button.click(fn=run_super_resolution, inputs=[sr_input_file, sr_model_selector, sr_prompt_input, sr_steps_slider, sr_seed_input, sr_color_input], outputs=[sr_input_display, sr_output_image, sr_raw_download])
|
| 622 |
-
sr_gallery.select(fn=select_example_input_file, outputs=sr_input_file)
|
| 623 |
-
|
| 624 |
-
dn_generate_button.click(fn=run_denoising, inputs=[dn_input_image, dn_image_type_selector, dn_steps_slider, dn_seed_input, dn_color_input], outputs=[dn_original_display, dn_output_image, dn_raw_download])
|
| 625 |
-
dn_gallery.select(fn=select_example_input_file, outputs=dn_input_image)
|
| 626 |
-
|
| 627 |
-
seg_generate_button.click(fn=run_segmentation, inputs=[seg_input_image, seg_model_selector, seg_diameter_input, seg_flow_slider, seg_cellprob_slider], outputs=[seg_original_display, seg_output_image])
|
| 628 |
-
seg_gallery.select(fn=select_example_input_file, outputs=seg_input_image)
|
| 629 |
-
|
| 630 |
-
cls_generate_button.click(fn=run_classification, inputs=[cls_input_image, cls_model_selector], outputs=[cls_original_display, cls_output_label])
|
| 631 |
-
cls_gallery.select(fn=select_example_input_file, outputs=cls_input_image)
|
| 632 |
-
|
| 633 |
|
| 634 |
-
# --- 4. Launch Application ---
|
| 635 |
if __name__ == "__main__":
|
| 636 |
-
print("Interface built. Launching server...")
|
| 637 |
demo.launch()
|
|
|
|
| 1 |
import torch
|
| 2 |
import numpy as np
|
| 3 |
import gradio as gr
|
| 4 |
+
from PIL import Image
|
| 5 |
import os
|
| 6 |
import json
|
|
|
|
|
|
|
| 7 |
import tifffile
|
| 8 |
import re
|
|
|
|
| 9 |
from torchvision import transforms, models
|
|
|
|
| 10 |
import shutil
|
| 11 |
import time
|
| 12 |
import torch.nn as nn
|
|
|
|
| 16 |
import tempfile
|
| 17 |
import zipfile
|
| 18 |
import matplotlib.cm as cm
|
| 19 |
+
import matplotlib.pyplot as plt
|
| 20 |
+
import io
|
| 21 |
|
| 22 |
# --- Imports from both scripts ---
|
| 23 |
from diffusers import DDPMScheduler, DDIMScheduler
|
| 24 |
from transformers import CLIPTextModel, CLIPTokenizer
|
|
|
|
| 25 |
from transformers.utils import ContextManagers
|
| 26 |
|
| 27 |
# --- Custom Model Imports ---
|
|
|
|
| 31 |
from models.unet_2d_condition import UNet2DConditionModel
|
| 32 |
from models.pipeline_controlnet import DDPMControlnetPipeline
|
| 33 |
|
| 34 |
+
# --- Segmentation Imports ---
|
| 35 |
from cellpose import models as cellpose_models
|
|
|
|
| 36 |
from huggingface_hub import snapshot_download
|
| 37 |
|
| 38 |
# --- 0. Configuration & Constants ---
|
| 39 |
hf_token = os.environ.get("HF_TOKEN")
|
|
|
|
| 40 |
MODEL_TITLE = "🔬 FluoGen: AI-Powered Fluorescence Microscopy Suite"
|
| 41 |
MODEL_DESCRIPTION = """
|
| 42 |
**Paper**: *Generative AI empowering fluorescence microscopy imaging and analysis*
|
| 43 |
<br>
|
| 44 |
+
Select a task below.
|
| 45 |
+
**Note**: The "Pseudocolor" option can be adjusted *after* generation for instant visualization. Use the "Download Raw Output" button to get the scientific 16-bit/Float data.
|
| 46 |
"""
|
| 47 |
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 48 |
WEIGHT_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 49 |
LOGO_PATH = "utils/logo2_transparent.png"
|
|
|
|
|
|
|
| 50 |
SAVE_EXAMPLES = False
|
| 51 |
|
| 52 |
# --- Base directory for all models ---
|
| 53 |
REPO_ID = "FluoGen-Group/FluoGen-demo-test-ckpts"
|
| 54 |
MODELS_ROOT_DIR = snapshot_download(repo_id=REPO_ID, token=hf_token)
|
| 55 |
|
| 56 |
+
# --- Paths Config ---
|
| 57 |
M2I_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_M2I/checkpoint-30000"
|
| 58 |
M2I_EXAMPLE_IMG_DIR = "example_images_m2i"
|
|
|
|
|
|
|
| 59 |
T2I_EXAMPLE_IMG_DIR = "example_images"
|
| 60 |
T2I_PRETRAINED_MODEL_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5"
|
| 61 |
T2I_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"
|
|
|
|
|
|
|
| 62 |
CONTROLNET_CLIP_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5"
|
| 63 |
CONTROLNET_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"
|
| 64 |
|
|
|
|
| 65 |
SR_CONTROLNET_MODELS = {
|
| 66 |
"Checkpoint ER": f"{MODELS_ROOT_DIR}/ControlNet_SR/ER/checkpoint-30000",
|
| 67 |
"Checkpoint Microtubules": f"{MODELS_ROOT_DIR}/ControlNet_SR/Microtubules/checkpoint-72500",
|
|
|
|
| 69 |
"Checkpoint F-actin": f"{MODELS_ROOT_DIR}/ControlNet_SR/F-actin/checkpoint-35000",
|
| 70 |
}
|
| 71 |
SR_EXAMPLE_IMG_DIR = "example_images_sr"
|
|
|
|
|
|
|
| 72 |
DN_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_DN/checkpoint-10000"
|
| 73 |
DN_PROMPT_RULES = {'MICE': 'mouse brain tissues', 'FISH': 'zebrafish embryos', 'BPAE_B': 'nucleus of BPAE', 'BPAE_R': 'mitochondria of BPAE', 'BPAE_G': 'F-actin of BPAE'}
|
| 74 |
DN_EXAMPLE_IMG_DIR = "example_images_dn"
|
| 75 |
|
|
|
|
| 76 |
SEG_MODELS = {
|
| 77 |
"DynamicNet Model": f"{MODELS_ROOT_DIR}/Cellpose/DynamicNet_baseline/CP_dynamic_ten_epoch_0100",
|
| 78 |
"DynamicNet Model + FluoGen": f"{MODELS_ROOT_DIR}/Cellpose/DynamicNet_FluoGen/CP_dynamic_epoch_0300",
|
|
|
|
| 81 |
}
|
| 82 |
SEG_EXAMPLE_IMG_DIR = "example_images_seg"
|
| 83 |
|
|
|
|
| 84 |
CLS_MODEL_PATHS = OrderedDict({
|
| 85 |
"5shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot_re",
|
| 86 |
"5shot+FluoGen": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot_aug_re",
|
|
|
|
| 89 |
CLS_EXAMPLE_IMG_DIR = "example_images_cls"
|
| 90 |
|
| 91 |
# --- Constants for Visualization ---
|
| 92 |
+
COLOR_MAPS = ["Grayscale", "Green (GFP)", "Red (RFP)", "Blue (DAPI)", "Magenta", "Cyan", "Yellow", "Fire", "Viridis", "Inferno", "Magma", "Plasma"]
|
| 93 |
|
| 94 |
# --- Helper Functions ---
|
| 95 |
def sanitize_prompt_for_filename(prompt):
|
|
|
|
| 101 |
if max_val - min_val < 1e-8: return np.zeros_like(x)
|
| 102 |
return (x - min_val) / (max_val - min_val)
|
| 103 |
|
| 104 |
+
def generate_colorbar_preview(color_name):
|
| 105 |
+
"""Generates a small PIL image representing the colormap."""
|
| 106 |
+
if color_name == "Grayscale":
|
| 107 |
+
gradient = np.linspace(0, 1, 256).reshape(1, 256)
|
| 108 |
+
return Image.fromarray((gradient * 255).astype(np.uint8)).convert("RGB").resize((256, 30))
|
| 109 |
+
|
| 110 |
+
gradient = np.linspace(0, 1, 256).reshape(1, 256)
|
| 111 |
+
rgb = np.zeros((1, 256, 3))
|
| 112 |
+
|
| 113 |
+
if color_name == "Green (GFP)": rgb[..., 1] = gradient
|
| 114 |
+
elif color_name == "Red (RFP)": rgb[..., 0] = gradient
|
| 115 |
+
elif color_name == "Blue (DAPI)": rgb[..., 2] = gradient
|
| 116 |
+
elif color_name == "Magenta": rgb[..., 0] = gradient; rgb[..., 2] = gradient
|
| 117 |
+
elif color_name == "Cyan": rgb[..., 1] = gradient; rgb[..., 2] = gradient
|
| 118 |
+
elif color_name == "Yellow": rgb[..., 0] = gradient; rgb[..., 1] = gradient
|
| 119 |
+
else:
|
| 120 |
+
# Matplotlib maps
|
| 121 |
+
mpl_map_name = color_name.lower()
|
| 122 |
+
if color_name == "Fire": mpl_map_name = "gnuplot2"
|
| 123 |
+
try:
|
| 124 |
+
cmap = cm.get_cmap(mpl_map_name)
|
| 125 |
+
rgb = cmap(gradient)[..., :3]
|
| 126 |
+
except:
|
| 127 |
+
return generate_colorbar_preview("Grayscale") # Fallback
|
| 128 |
+
|
| 129 |
+
img_np = (rgb * 255).astype(np.uint8)
|
| 130 |
+
return Image.fromarray(img_np).resize((256, 30))
|
| 131 |
+
|
| 132 |
def apply_pseudocolor(image_np, color_name="Grayscale"):
|
| 133 |
"""
|
| 134 |
Applies a pseudocolor to a single channel numpy image.
|
|
|
|
| 135 |
Returns: PIL Image in RGB.
|
| 136 |
"""
|
| 137 |
+
if image_np is None: return None
|
| 138 |
+
|
| 139 |
# Normalize to 0-1 for processing
|
| 140 |
norm_img = min_max_norm(np.squeeze(image_np))
|
| 141 |
|
| 142 |
if color_name == "Grayscale":
|
|
|
|
| 143 |
return Image.fromarray((norm_img * 255).astype(np.uint8)).convert("RGB")
|
| 144 |
|
|
|
|
| 145 |
h, w = norm_img.shape
|
| 146 |
rgb = np.zeros((h, w, 3), dtype=np.float32)
|
| 147 |
|
| 148 |
+
if color_name == "Green (GFP)": rgb[..., 1] = norm_img
|
| 149 |
+
elif color_name == "Red (RFP)": rgb[..., 0] = norm_img
|
| 150 |
+
elif color_name == "Blue (DAPI)": rgb[..., 2] = norm_img
|
| 151 |
+
elif color_name == "Magenta": rgb[..., 0] = norm_img; rgb[..., 2] = norm_img
|
| 152 |
+
elif color_name == "Cyan": rgb[..., 1] = norm_img; rgb[..., 2] = norm_img
|
| 153 |
+
elif color_name == "Yellow": rgb[..., 0] = norm_img; rgb[..., 1] = norm_img
|
| 154 |
+
else:
|
| 155 |
+
# Matplotlib maps
|
| 156 |
+
mpl_map_name = color_name.lower()
|
| 157 |
+
if color_name == "Fire": mpl_map_name = "gnuplot2"
|
| 158 |
+
try:
|
| 159 |
+
cmap = cm.get_cmap(mpl_map_name)
|
| 160 |
+
colored = cmap(norm_img)
|
| 161 |
+
rgb = colored[..., :3]
|
| 162 |
+
except:
|
| 163 |
+
return apply_pseudocolor(image_np, "Grayscale")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
return Image.fromarray((rgb * 255).astype(np.uint8))
|
| 166 |
|
| 167 |
def save_temp_tiff(image_np, prefix="output"):
|
|
|
|
| 168 |
tfile = tempfile.NamedTemporaryFile(delete=False, suffix=".tif", prefix=f"{prefix}_")
|
| 169 |
+
if image_np.dtype == np.float16: save_data = image_np.astype(np.float32)
|
| 170 |
+
else: save_data = image_np
|
|
|
|
|
|
|
|
|
|
| 171 |
tifffile.imwrite(tfile.name, save_data)
|
| 172 |
return tfile.name
|
| 173 |
|
|
|
|
| 177 |
if target_mode == "L" and image_np.mode != "L": return image_np.convert("L")
|
| 178 |
return image_np
|
| 179 |
squeezed_np = np.squeeze(image_np);
|
| 180 |
+
if squeezed_np.dtype == np.uint8: image_8bit = squeezed_np
|
|
|
|
| 181 |
else:
|
| 182 |
normalized_np = min_max_norm(squeezed_np)
|
| 183 |
image_8bit = (normalized_np * 255).astype(np.uint8)
|
|
|
|
| 211 |
if isinstance(data, list):
|
| 212 |
combined_prompts.extend(data)
|
| 213 |
for p in data: PROMPT_TO_MODEL_MAP[p] = cat["model"]
|
| 214 |
+
except Exception: pass
|
|
|
|
| 215 |
if not combined_prompts: return ["F-actin of COS-7", "ER of COS-7"]
|
| 216 |
return combined_prompts
|
| 217 |
T2I_PROMPTS = load_all_prompts()
|
|
|
|
| 230 |
current_t2i_unet_path = T2I_UNET_PATH
|
| 231 |
print("✓ Text-to-Image model loaded successfully!")
|
| 232 |
except Exception as e:
|
| 233 |
+
print(f"FATAL: Text-to-Image Model Loading Failed: {e}")
|
| 234 |
|
| 235 |
try:
|
| 236 |
+
print("Loading shared ControlNet pipeline...")
|
| 237 |
controlnet_unet = UNet2DConditionModel.from_pretrained(CONTROLNET_UNET_PATH, subfolder="unet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
|
| 238 |
+
controlnet_controlnet = ControlNetModel.from_pretrained(M2I_CONTROLNET_PATH, subfolder="controlnet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
|
|
|
|
| 239 |
controlnet_scheduler = DDIMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="v_prediction", rescale_betas_zero_snr=False, timestep_spacing="trailing")
|
| 240 |
controlnet_tokenizer = CLIPTokenizer.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="tokenizer")
|
| 241 |
with ContextManagers([]):
|
| 242 |
controlnet_text_encoder = CLIPTextModel.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="text_encoder").to(dtype=WEIGHT_DTYPE, device=DEVICE)
|
| 243 |
controlnet_pipe = DDPMControlnetPipeline(unet=controlnet_unet, controlnet=controlnet_controlnet, scheduler=controlnet_scheduler, text_encoder=controlnet_text_encoder, tokenizer=controlnet_tokenizer)
|
| 244 |
controlnet_pipe.to(dtype=WEIGHT_DTYPE, device=DEVICE)
|
| 245 |
+
controlnet_pipe.current_controlnet_path = M2I_CONTROLNET_PATH
|
| 246 |
print("✓ Shared ControlNet pipeline loaded successfully!")
|
| 247 |
except Exception as e:
|
| 248 |
+
print(f"FATAL: ControlNet Pipeline Loading Failed: {e}")
|
| 249 |
|
| 250 |
# --- 2. Core Logic Functions ---
|
| 251 |
def swap_controlnet(pipe, target_path):
|
|
|
|
| 255 |
pipe.controlnet = ControlNetModel.from_pretrained(target_path, subfolder="controlnet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
|
| 256 |
pipe.current_controlnet_path = target_path
|
| 257 |
except Exception as e:
|
| 258 |
+
raise gr.Error(f"Failed to load ControlNet model. Error: {e}")
|
| 259 |
return pipe
|
| 260 |
|
| 261 |
def swap_t2i_unet(pipe, target_unet_path):
|
|
|
|
| 267 |
new_unet = UNet2DModel.from_pretrained(target_unet_path, subfolder="unet").to(DEVICE)
|
| 268 |
pipe.unet = new_unet
|
| 269 |
current_t2i_unet_path = target_unet_path
|
|
|
|
| 270 |
except Exception as e:
|
| 271 |
+
raise gr.Error(f"Failed to load UNet. Error: {e}")
|
| 272 |
return pipe
|
| 273 |
|
| 274 |
+
# --- Dynamic Color Update Functions ---
|
| 275 |
+
def update_single_image_color(raw_np_state, color_name):
|
| 276 |
+
if raw_np_state is None: return None, None
|
| 277 |
+
display_img = apply_pseudocolor(raw_np_state, color_name)
|
| 278 |
+
bar_img = generate_colorbar_preview(color_name)
|
| 279 |
+
return display_img, bar_img
|
| 280 |
+
|
| 281 |
+
def update_gallery_color(raw_list_state, color_name):
|
| 282 |
+
if raw_list_state is None: return None, None
|
| 283 |
+
new_gallery = []
|
| 284 |
+
for img_np in raw_list_state:
|
| 285 |
+
new_gallery.append(apply_pseudocolor(img_np, color_name))
|
| 286 |
+
bar_img = generate_colorbar_preview(color_name)
|
| 287 |
+
return new_gallery, bar_img
|
| 288 |
+
|
| 289 |
+
# --- Generation Functions ---
|
| 290 |
@spaces.GPU(duration=120)
|
| 291 |
+
def generate_t2i(prompt, num_inference_steps, current_color):
|
| 292 |
global t2i_pipe
|
| 293 |
if t2i_pipe is None: raise gr.Error("Text-to-Image model is not loaded.")
|
| 294 |
+
target_model_path = PROMPT_TO_MODEL_MAP.get(prompt, f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/FULL-checkpoint-275000")
|
|
|
|
|
|
|
|
|
|
| 295 |
t2i_pipe = swap_t2i_unet(t2i_pipe, target_model_path)
|
| 296 |
|
| 297 |
+
print(f"\n🚀 T2I Task started... | Prompt: '{prompt}'")
|
| 298 |
image_np = t2i_pipe(prompt.lower(), generator=None, num_inference_steps=int(num_inference_steps), output_type="np").images
|
| 299 |
|
|
|
|
| 300 |
raw_file_path = save_temp_tiff(image_np, prefix="t2i_raw")
|
| 301 |
+
display_image = apply_pseudocolor(image_np, current_color)
|
| 302 |
+
colorbar_img = generate_colorbar_preview(current_color)
|
| 303 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
if SAVE_EXAMPLES:
|
| 305 |
example_filepath = os.path.join(T2I_EXAMPLE_IMG_DIR, sanitize_prompt_for_filename(prompt))
|
| 306 |
+
if not os.path.exists(example_filepath): display_image.save(example_filepath)
|
|
|
|
| 307 |
|
| 308 |
+
# Return: Display, Raw Path, Raw State, Colorbar
|
| 309 |
+
return display_image, raw_file_path, image_np, colorbar_img
|
| 310 |
|
| 311 |
@spaces.GPU(duration=120)
|
| 312 |
+
def run_mask_to_image_generation(mask_file_obj, cell_type, num_images, steps, seed, current_color):
|
| 313 |
if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
|
| 314 |
+
if mask_file_obj is None: raise gr.Error("Please upload a segmentation mask.")
|
|
|
|
| 315 |
|
| 316 |
pipe = swap_controlnet(controlnet_pipe, M2I_CONTROLNET_PATH)
|
| 317 |
+
try: mask_np = tifffile.imread(mask_file_obj.name)
|
| 318 |
+
except Exception as e: raise gr.Error(f"Failed to read TIF. Error: {e}")
|
|
|
|
|
|
|
| 319 |
|
| 320 |
+
input_display = numpy_to_pil(mask_np, "L")
|
| 321 |
mask_normalized = min_max_norm(mask_np)
|
| 322 |
+
image_tensor = torch.from_numpy(mask_normalized.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
|
|
|
|
| 323 |
image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)
|
| 324 |
|
| 325 |
prompt = f"nuclei of {cell_type.strip()}"
|
| 326 |
+
print(f"\nM2I Task started... | Prompt: '{prompt}'")
|
| 327 |
|
| 328 |
+
generated_raw_list = []
|
| 329 |
generated_display_images = []
|
| 330 |
generated_raw_files = []
|
|
|
|
|
|
|
| 331 |
temp_dir = tempfile.mkdtemp()
|
| 332 |
|
| 333 |
for i in range(int(num_images)):
|
| 334 |
+
generator = torch.Generator(device=DEVICE).manual_seed(int(seed) + i)
|
|
|
|
| 335 |
with torch.autocast("cuda"):
|
| 336 |
output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images
|
| 337 |
|
| 338 |
+
# Save Raw State
|
| 339 |
+
generated_raw_list.append(output_np)
|
| 340 |
+
|
| 341 |
+
# Save Temp File
|
| 342 |
raw_name = f"m2i_sample_{i+1}.tif"
|
| 343 |
raw_path = os.path.join(temp_dir, raw_name)
|
|
|
|
|
|
|
| 344 |
save_data = output_np.astype(np.float32) if output_np.dtype == np.float16 else output_np
|
| 345 |
tifffile.imwrite(raw_path, save_data)
|
| 346 |
generated_raw_files.append(raw_path)
|
| 347 |
|
| 348 |
+
# Display
|
| 349 |
+
generated_display_images.append(apply_pseudocolor(output_np, current_color))
|
|
|
|
|
|
|
| 350 |
|
| 351 |
+
# ZIP
|
| 352 |
zip_filename = os.path.join(temp_dir, "raw_output_images.zip")
|
| 353 |
with zipfile.ZipFile(zip_filename, 'w') as zipf:
|
| 354 |
+
for file in generated_raw_files: zipf.write(file, os.path.basename(file))
|
|
|
|
| 355 |
|
| 356 |
+
colorbar_img = generate_colorbar_preview(current_color)
|
| 357 |
+
return input_display, generated_display_images, zip_filename, generated_raw_list, colorbar_img
|
| 358 |
|
| 359 |
@spaces.GPU(duration=120)
|
| 360 |
+
def run_super_resolution(low_res_file_obj, controlnet_model_name, prompt, steps, seed, current_color):
|
| 361 |
if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
|
| 362 |
+
if low_res_file_obj is None: raise gr.Error("Please upload a file.")
|
| 363 |
|
| 364 |
target_path = SR_CONTROLNET_MODELS.get(controlnet_model_name)
|
|
|
|
|
|
|
| 365 |
pipe = swap_controlnet(controlnet_pipe, target_path)
|
| 366 |
+
|
| 367 |
+
try: image_stack_np = tifffile.imread(low_res_file_obj.name)
|
| 368 |
+
except Exception as e: raise gr.Error(f"Failed to read TIF. Error: {e}")
|
|
|
|
| 369 |
|
| 370 |
if image_stack_np.ndim != 3 or image_stack_np.shape[-3] != 9:
|
| 371 |
+
raise gr.Error(f"Invalid TIF shape. Expected 9 channels, got {image_stack_np.shape}.")
|
| 372 |
|
| 373 |
+
input_display = numpy_to_pil(np.mean(image_stack_np, axis=0), "L")
|
|
|
|
|
|
|
| 374 |
image_tensor = torch.from_numpy(image_stack_np.astype(np.float32) / 65535.0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
|
| 375 |
image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)
|
| 376 |
|
| 377 |
+
print(f"\nSR Task started...")
|
| 378 |
generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
|
| 379 |
with torch.autocast("cuda"):
|
| 380 |
output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images
|
| 381 |
|
|
|
|
| 382 |
raw_file_path = save_temp_tiff(output_np, prefix="sr_raw")
|
| 383 |
+
output_display = apply_pseudocolor(output_np, current_color)
|
| 384 |
+
colorbar_img = generate_colorbar_preview(current_color)
|
| 385 |
|
| 386 |
+
return input_display, output_display, raw_file_path, output_np, colorbar_img
|
| 387 |
|
| 388 |
@spaces.GPU(duration=120)
|
| 389 |
+
def run_denoising(noisy_image_np, image_type, steps, seed, current_color):
|
| 390 |
if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
|
| 391 |
+
if noisy_image_np is None: raise gr.Error("Please upload an image.")
|
| 392 |
|
| 393 |
pipe = swap_controlnet(controlnet_pipe, DN_CONTROLNET_PATH)
|
| 394 |
prompt = DN_PROMPT_RULES.get(image_type, 'microscopy image')
|
|
|
|
| 395 |
|
| 396 |
+
print(f"\nDN Task started...")
|
| 397 |
+
image_tensor = torch.from_numpy(noisy_image_np.astype(np.float32) / 255.0).unsqueeze(0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
|
| 398 |
image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)
|
| 399 |
|
| 400 |
generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
|
| 401 |
with torch.autocast("cuda"):
|
| 402 |
output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images
|
| 403 |
|
|
|
|
| 404 |
raw_file_path = save_temp_tiff(output_np, prefix="dn_raw")
|
| 405 |
+
output_display = apply_pseudocolor(output_np, current_color)
|
| 406 |
+
colorbar_img = generate_colorbar_preview(current_color)
|
| 407 |
|
| 408 |
+
return numpy_to_pil(noisy_image_np, "L"), output_display, raw_file_path, output_np, colorbar_img
|
| 409 |
|
| 410 |
+
# --- Segmentation & Classification (Unchanged Logic, just wrapper) ---
|
| 411 |
@spaces.GPU(duration=120)
|
| 412 |
def run_segmentation(input_image_np, model_name, diameter, flow_threshold, cellprob_threshold):
|
| 413 |
+
if input_image_np is None: raise gr.Error("Please upload an image.")
|
|
|
|
| 414 |
model_path = SEG_MODELS.get(model_name)
|
|
|
|
| 415 |
|
| 416 |
+
print(f"\nSeg Task started...")
|
| 417 |
try:
|
| 418 |
+
model = cellpose_models.CellposeModel(gpu=torch.cuda.is_available(), pretrained_model=model_path)
|
| 419 |
+
masks, _, _ = model.eval([input_image_np], channels=[0, 0], diameter=(model.diam_labels if diameter==0 else diameter), flow_threshold=flow_threshold, cellprob_threshold=cellprob_threshold)
|
| 420 |
+
except Exception as e: raise gr.Error(f"Segmentation failed: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 421 |
|
| 422 |
original_rgb = numpy_to_pil(input_image_np, "RGB")
|
| 423 |
+
red_mask = np.zeros_like(np.array(original_rgb)); red_mask[masks[0] > 0] = [139, 0, 0]
|
| 424 |
+
blended = ((0.6 * np.array(original_rgb) + 0.4 * red_mask).astype(np.uint8))
|
| 425 |
+
return numpy_to_pil(input_image_np, "L"), numpy_to_pil(blended, "RGB")
|
|
|
|
|
|
|
|
|
|
| 426 |
|
| 427 |
@spaces.GPU(duration=120)
|
| 428 |
def run_classification(input_image_np, model_name):
|
| 429 |
+
if input_image_np is None: raise gr.Error("Please upload an image.")
|
| 430 |
+
model_path = os.path.join(CLS_MODEL_PATHS.get(model_name), "best_resnet50.pth")
|
|
|
|
|
|
|
| 431 |
|
| 432 |
+
print(f"\nCls Task started...")
|
| 433 |
try:
|
| 434 |
+
model = models.resnet50(weights=None); model.fc = nn.Linear(model.fc.in_features, len(CLS_CLASS_NAMES))
|
|
|
|
| 435 |
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
|
| 436 |
model.to(DEVICE).eval()
|
| 437 |
+
except Exception as e: raise gr.Error(f"Classification failed: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
|
| 439 |
+
input_tensor = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)])(numpy_to_pil(input_image_np, "RGB")).unsqueeze(0).to(DEVICE)
|
| 440 |
with torch.no_grad():
|
| 441 |
+
probs = F.softmax(model(input_tensor), dim=1).squeeze().cpu().numpy()
|
| 442 |
+
return numpy_to_pil(input_image_np, "L"), {name: float(p) for name, p in zip(CLS_CLASS_NAMES, probs)}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
|
| 444 |
# --- 3. Gradio UI Layout ---
|
| 445 |
print("Building Gradio interface...")
|
| 446 |
+
for d in [M2I_EXAMPLE_IMG_DIR, T2I_EXAMPLE_IMG_DIR, SR_EXAMPLE_IMG_DIR, DN_EXAMPLE_IMG_DIR, SEG_EXAMPLE_IMG_DIR, CLS_EXAMPLE_IMG_DIR]: os.makedirs(d, exist_ok=True)
|
| 447 |
+
|
| 448 |
+
# Load Examples
|
| 449 |
+
filename_to_prompt = { sanitize_prompt_for_filename(p): p for p in T2I_PROMPTS }
|
| 450 |
+
t2i_examples = []
|
| 451 |
+
for f in os.listdir(T2I_EXAMPLE_IMG_DIR):
|
| 452 |
+
if f in filename_to_prompt: t2i_examples.append((os.path.join(T2I_EXAMPLE_IMG_DIR, f), filename_to_prompt[f]))
|
| 453 |
+
|
| 454 |
+
def load_examples(d, stack=False):
|
| 455 |
+
ex = []
|
| 456 |
+
if not os.path.exists(d): return ex
|
| 457 |
+
for f in sorted(os.listdir(d)):
|
|
|
|
|
|
|
| 458 |
if f.lower().endswith(('.tif', '.tiff', '.png', '.jpg')):
|
| 459 |
+
fp = os.path.join(d, f)
|
| 460 |
try:
|
| 461 |
+
img = tifffile.imread(fp) if f.endswith(('.tif','.tiff')) else np.array(Image.open(fp).convert("L"))
|
| 462 |
+
if stack and img.ndim == 3: img = np.mean(img, axis=0)
|
| 463 |
+
ex.append((numpy_to_pil(img, "L"), fp))
|
|
|
|
| 464 |
except: pass
|
| 465 |
+
return ex
|
| 466 |
|
| 467 |
+
m2i_examples = load_examples(M2I_EXAMPLE_IMG_DIR)
|
| 468 |
+
sr_examples = load_examples(SR_EXAMPLE_IMG_DIR, stack=True)
|
| 469 |
+
dn_examples = load_examples(DN_EXAMPLE_IMG_DIR)
|
| 470 |
+
seg_examples = load_examples(SEG_EXAMPLE_IMG_DIR)
|
| 471 |
+
cls_examples = load_examples(CLS_EXAMPLE_IMG_DIR)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
|
| 473 |
+
# --- UI Builders ---
|
| 474 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 475 |
with gr.Row():
|
| 476 |
gr.Image(value=LOGO_PATH, width=300, height=200, container=False, interactive=False, show_download_button=False, show_fullscreen_button=False)
|
|
|
|
| 479 |
with gr.Tabs():
|
| 480 |
# --- TAB 1: Text-to-Image ---
|
| 481 |
with gr.Tab("Text-to-Image Generation", id="txt2img"):
|
| 482 |
+
# State to hold raw numpy data
|
| 483 |
+
t2i_raw_state = gr.State(None)
|
| 484 |
+
|
| 485 |
with gr.Row(variant="panel"):
|
| 486 |
with gr.Column(scale=1, min_width=350):
|
| 487 |
+
t2i_prompt = gr.Dropdown(choices=T2I_PROMPTS, value=T2I_PROMPTS[0], label="Search or Type a Prompt", filterable=True, allow_custom_value=True)
|
| 488 |
+
t2i_steps = gr.Slider(10, 200, 50, step=1, label="Inference Steps")
|
| 489 |
+
t2i_btn = gr.Button("Generate", variant="primary")
|
|
|
|
|
|
|
| 490 |
with gr.Column(scale=2):
|
| 491 |
+
# Image with Download Button DISABLED
|
| 492 |
+
t2i_out = gr.Image(label="Generated Image", type="pil", interactive=False, show_download_button=False)
|
| 493 |
+
|
| 494 |
+
with gr.Row(equal_height=True):
|
| 495 |
+
# Color Controls
|
| 496 |
+
with gr.Column(scale=2):
|
| 497 |
+
t2i_color = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor (Adjust after generation)")
|
| 498 |
+
with gr.Column(scale=2):
|
| 499 |
+
t2i_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False)
|
| 500 |
+
with gr.Column(scale=1):
|
| 501 |
+
t2i_dl = gr.DownloadButton(label="Download Raw (.tif)")
|
| 502 |
+
|
| 503 |
+
t2i_gal = gr.Gallery(value=t2i_examples, label="Examples", columns=6, height="auto")
|
| 504 |
+
|
| 505 |
+
# Events
|
| 506 |
+
t2i_btn.click(
|
| 507 |
+
fn=generate_t2i,
|
| 508 |
+
inputs=[t2i_prompt, t2i_steps, t2i_color],
|
| 509 |
+
outputs=[t2i_out, t2i_dl, t2i_raw_state, t2i_colorbar]
|
| 510 |
+
)
|
| 511 |
+
# Real-time color update using State
|
| 512 |
+
t2i_color.change(
|
| 513 |
+
fn=update_single_image_color,
|
| 514 |
+
inputs=[t2i_raw_state, t2i_color],
|
| 515 |
+
outputs=[t2i_out, t2i_colorbar]
|
| 516 |
+
)
|
| 517 |
+
t2i_gal.select(lambda e: e.value['caption'], None, t2i_prompt)
|
| 518 |
|
| 519 |
# --- TAB 2: Super-Resolution ---
|
| 520 |
with gr.Tab("Super-Resolution", id="super_res"):
|
| 521 |
+
sr_raw_state = gr.State(None)
|
| 522 |
with gr.Row(variant="panel"):
|
| 523 |
with gr.Column(scale=1, min_width=350):
|
| 524 |
+
sr_file = gr.File(label="Upload 9-Channel TIF Stack", file_types=['.tif', '.tiff'])
|
| 525 |
+
sr_model = gr.Dropdown(choices=list(SR_CONTROLNET_MODELS.keys()), value=list(SR_CONTROLNET_MODELS.keys())[-1], label="Model")
|
| 526 |
+
sr_prompt = gr.Textbox(label="Prompt", value="F-actin of COS-7", interactive=False)
|
| 527 |
+
sr_steps = gr.Slider(5, 50, 10, step=1, label="Steps")
|
| 528 |
+
sr_seed = gr.Number(label="Seed", value=42)
|
| 529 |
+
sr_btn = gr.Button("Generate", variant="primary")
|
|
|
|
| 530 |
with gr.Column(scale=2):
|
| 531 |
with gr.Row():
|
| 532 |
+
sr_in_disp = gr.Image(label="Input (Avg)", type="pil", interactive=False, show_download_button=False)
|
| 533 |
+
sr_out_disp = gr.Image(label="Output", type="pil", interactive=False, show_download_button=False)
|
| 534 |
+
with gr.Row(equal_height=True):
|
| 535 |
+
with gr.Column(scale=2):
|
| 536 |
+
sr_color = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor")
|
| 537 |
+
with gr.Column(scale=2):
|
| 538 |
+
sr_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False)
|
| 539 |
+
with gr.Column(scale=1):
|
| 540 |
+
sr_dl = gr.DownloadButton(label="Download Raw (.tif)")
|
| 541 |
+
|
| 542 |
+
sr_gal = gr.Gallery(value=sr_examples, label="Examples", columns=6, height="auto")
|
| 543 |
+
|
| 544 |
+
sr_model.change(update_sr_prompt, sr_model, sr_prompt)
|
| 545 |
+
sr_btn.click(run_super_resolution, [sr_file, sr_model, sr_prompt, sr_steps, sr_seed, sr_color], [sr_in_disp, sr_out_disp, sr_dl, sr_raw_state, sr_colorbar])
|
| 546 |
+
sr_color.change(update_single_image_color, [sr_raw_state, sr_color], [sr_out_disp, sr_colorbar])
|
| 547 |
+
sr_gal.select(lambda e: e.value['caption'], None, sr_file)
|
| 548 |
|
| 549 |
# --- TAB 3: Denoising ---
|
| 550 |
with gr.Tab("Denoising", id="denoising"):
|
| 551 |
+
dn_raw_state = gr.State(None)
|
| 552 |
with gr.Row(variant="panel"):
|
| 553 |
with gr.Column(scale=1, min_width=350):
|
| 554 |
+
dn_img = gr.Image(type="numpy", label="Upload Noisy Image", image_mode="L")
|
| 555 |
+
dn_type = gr.Dropdown(choices=list(DN_PROMPT_RULES.keys()), value='MICE', label="Image Type")
|
| 556 |
+
dn_steps = gr.Slider(5, 50, 10, step=1, label="Steps")
|
| 557 |
+
dn_seed = gr.Number(label="Seed", value=42)
|
| 558 |
+
dn_btn = gr.Button("Denoise", variant="primary")
|
|
|
|
| 559 |
with gr.Column(scale=2):
|
| 560 |
with gr.Row():
|
| 561 |
+
dn_orig = gr.Image(label="Original", type="pil", interactive=False, show_download_button=False)
|
| 562 |
+
dn_out = gr.Image(label="Denoised", type="pil", interactive=False, show_download_button=False)
|
| 563 |
+
with gr.Row(equal_height=True):
|
| 564 |
+
with gr.Column(scale=2):
|
| 565 |
+
dn_color = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor")
|
| 566 |
+
with gr.Column(scale=2):
|
| 567 |
+
dn_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False)
|
| 568 |
+
with gr.Column(scale=1):
|
| 569 |
+
dn_dl = gr.DownloadButton(label="Download Raw (.tif)")
|
| 570 |
+
|
| 571 |
+
dn_gal = gr.Gallery(value=dn_examples, label="Examples", columns=6, height="auto")
|
| 572 |
+
|
| 573 |
+
dn_btn.click(run_denoising, [dn_img, dn_type, dn_steps, dn_seed, dn_color], [dn_orig, dn_out, dn_dl, dn_raw_state, dn_colorbar])
|
| 574 |
+
dn_color.change(update_single_image_color, [dn_raw_state, dn_color], [dn_out, dn_colorbar])
|
| 575 |
+
dn_gal.select(lambda e: e.value['caption'], None, dn_img)
|
| 576 |
|
| 577 |
# --- TAB 4: Mask-to-Image ---
|
| 578 |
with gr.Tab("Mask-to-Image", id="mask2img"):
|
| 579 |
+
m2i_raw_state = gr.State(None) # Stores list of numpy arrays
|
| 580 |
with gr.Row(variant="panel"):
|
| 581 |
with gr.Column(scale=1, min_width=350):
|
| 582 |
+
m2i_file = gr.File(label="Upload Mask (.tif)", file_types=['.tif', '.tiff'])
|
| 583 |
+
m2i_type = gr.Textbox(label="Cell Type", placeholder="e.g., HeLa")
|
| 584 |
+
m2i_num = gr.Slider(1, 10, 5, step=1, label="Count")
|
| 585 |
+
m2i_steps = gr.Slider(5, 50, 10, step=1, label="Steps")
|
| 586 |
+
m2i_seed = gr.Number(label="Seed", value=42)
|
| 587 |
+
m2i_btn = gr.Button("Generate", variant="primary")
|
|
|
|
| 588 |
with gr.Column(scale=2):
|
| 589 |
+
m2i_gal_out = gr.Gallery(label="Generated Samples", columns=5, height="auto")
|
| 590 |
+
with gr.Row(equal_height=True):
|
| 591 |
+
with gr.Column(scale=2):
|
| 592 |
+
m2i_color = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor")
|
| 593 |
+
with gr.Column(scale=2):
|
| 594 |
+
m2i_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False)
|
| 595 |
+
with gr.Column(scale=1):
|
| 596 |
+
m2i_dl = gr.DownloadButton(label="Download ZIP")
|
| 597 |
+
m2i_in_disp = gr.Image(label="Input Mask", type="pil", interactive=False, show_download_button=False)
|
| 598 |
+
|
| 599 |
+
m2i_gal = gr.Gallery(value=m2i_examples, label="Examples", columns=6, height="auto")
|
| 600 |
+
|
| 601 |
+
m2i_btn.click(run_mask_to_image_generation, [m2i_file, m2i_type, m2i_num, m2i_steps, m2i_seed, m2i_color], [m2i_in_disp, m2i_gal_out, m2i_dl, m2i_raw_state, m2i_colorbar])
|
| 602 |
+
m2i_color.change(update_gallery_color, [m2i_raw_state, m2i_color], [m2i_gal_out, m2i_colorbar])
|
| 603 |
+
m2i_gal.select(lambda e: e.value['caption'], None, m2i_file)
|
| 604 |
|
| 605 |
# --- TAB 5: Cell Segmentation ---
|
| 606 |
with gr.Tab("Cell Segmentation", id="segmentation"):
|
| 607 |
with gr.Row(variant="panel"):
|
| 608 |
with gr.Column(scale=1, min_width=350):
|
| 609 |
+
seg_img = gr.Image(type="numpy", label="Upload Image", image_mode="L")
|
| 610 |
+
seg_model = gr.Dropdown(choices=list(SEG_MODELS.keys()), value=list(SEG_MODELS.keys())[0], label="Model")
|
| 611 |
+
seg_diam = gr.Number(label="Diameter (0=auto)", value=30)
|
| 612 |
+
seg_flow = gr.Slider(0.0, 3.0, 0.4, step=0.1, label="Flow Thresh")
|
| 613 |
+
seg_prob = gr.Slider(-6.0, 6.0, 0.0, step=0.5, label="Prob Thresh")
|
| 614 |
+
seg_btn = gr.Button("Segment", variant="primary")
|
| 615 |
with gr.Column(scale=2):
|
| 616 |
with gr.Row():
|
| 617 |
+
seg_orig = gr.Image(label="Original", type="pil", interactive=False, show_download_button=False)
|
| 618 |
+
seg_out = gr.Image(label="Overlay", type="pil", interactive=False, show_download_button=False)
|
| 619 |
+
seg_gal = gr.Gallery(value=seg_examples, label="Examples", columns=6, height="auto")
|
| 620 |
+
seg_btn.click(run_segmentation, [seg_img, seg_model, seg_diam, seg_flow, seg_prob], [seg_orig, seg_out])
|
| 621 |
+
seg_gal.select(lambda e: e.value['caption'], None, seg_img)
|
| 622 |
+
|
| 623 |
# --- TAB 6: Classification ---
|
| 624 |
with gr.Tab("Classification", id="classification"):
|
| 625 |
with gr.Row(variant="panel"):
|
| 626 |
with gr.Column(scale=1, min_width=350):
|
| 627 |
+
cls_img = gr.Image(type="numpy", label="Upload Image", image_mode="L")
|
| 628 |
+
cls_model = gr.Dropdown(choices=list(CLS_MODEL_PATHS.keys()), value=list(CLS_MODEL_PATHS.keys())[0], label="Model")
|
| 629 |
+
cls_btn = gr.Button("Classify", variant="primary")
|
| 630 |
with gr.Column(scale=2):
|
| 631 |
+
cls_orig = gr.Image(label="Input", type="pil", interactive=False, show_download_button=False)
|
| 632 |
+
cls_res = gr.Label(label="Results", num_top_classes=10)
|
| 633 |
+
cls_gal = gr.Gallery(value=cls_examples, label="Examples", columns=6, height="auto")
|
| 634 |
+
cls_btn.click(run_classification, [cls_img, cls_model], [cls_orig, cls_res])
|
| 635 |
+
cls_gal.select(lambda e: e.value['caption'], None, cls_img)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 636 |
|
|
|
|
| 637 |
if __name__ == "__main__":
|
|
|
|
| 638 |
demo.launch()
|