import torch
import numpy as np
import gradio as gr
from PIL import Image
import os
import json
import tifffile
import re
from torchvision import transforms, models
import shutil
import time
import torch.nn as nn
import torch.nn.functional as F
import spaces
from collections import OrderedDict
import tempfile
import zipfile
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import io
# --- Imports from both scripts ---
from diffusers import DDPMScheduler, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from transformers.utils import ContextManagers
# --- Custom Model Imports ---
from models.pipeline_ddpm_text_encoder import DDPMPipeline
from models.unet_2d import UNet2DModel
from models.controlnet import ControlNetModel
from models.unet_2d_condition import UNet2DConditionModel
from models.pipeline_controlnet import DDPMControlnetPipeline
# --- Segmentation Imports ---
from cellpose import models as cellpose_models
from huggingface_hub import snapshot_download
# --- 0. Configuration & Constants ---
# Hugging Face token for downloading gated/private checkpoints (may be None).
hf_token = os.environ.get("HF_TOKEN")
# NOTE(review): the title prefix looks mojibake-encoded (probably an emoji) —
# verify the file's encoding before release.
MODEL_TITLE = "š¬ FluoGen: AI-Powered Fluorescence Microscopy Suite"
MODEL_DESCRIPTION = """
**Paper**: [*FluoGen: An Open-Source Generative Foundation Model for Fluorescence Microscopy Image Enhancement and Analysis*](https://doi.org/10.21203/rs.3.rs-8334792/v1)
**Homepage**: [Homepage Website](https://fluogen-group.github.io/FluoGen-HomePage/)
**Code**: [GitHub Repository](https://github.com/FluoGen-Group/FluoGen)
Select a task below.
**Note**: The "Pseudocolor" option instantly applies to both **Input** and **Output** images for better comparison. Use the "Download Raw Output" button to get the scientific 16-bit/Float data.
"""
# First CUDA device when available; CPU otherwise.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# Half precision on GPU for memory/speed; float32 on CPU.
WEIGHT_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
LOGO_PATH = "utils/logo2_transparent.png"
# When True, the first generated T2I image is persisted as a gallery example.
SAVE_EXAMPLES = False
# --- CSS for Times New Roman ---
# NOTE(review): the font list below contains a mojibake-damaged name (likely a
# Chinese font name for "Microsoft YaHei") — confirm and re-encode.
# The second rule hides Gradio's built-in download buttons so users go
# through the "Download Raw Output" path instead.
CUSTOM_CSS = """
.gradio-container, .gradio-container * {
font-family: 'Arial', 'Helvetica', 'Microsoft YaHei', '微软é
é»', sans-serif !important;
}
button[aria-label="Download"],
a[download] {
display: none !important;
}
"""
# --- Base directory for all models ---
REPO_ID = "FluoGen-Group/FluoGen-demo-test-ckpts"
# Downloads (or reuses the cached copy of) the whole checkpoint repository
# and returns the local snapshot directory.
MODELS_ROOT_DIR = snapshot_download(repo_id=REPO_ID, token=hf_token)
# --- Paths Config ---
# Mask-to-Image ControlNet checkpoint and its example directory.
M2I_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_M2I/checkpoint-30000"
M2I_EXAMPLE_IMG_DIR = "example_images_m2i"
T2I_EXAMPLE_IMG_DIR = "example_images"
# Stable Diffusion v1.5 provides the tokenizer/text encoder for both T2I and
# ControlNet pipelines.
T2I_PRETRAINED_MODEL_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5"
T2I_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"
CONTROLNET_CLIP_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5"
CONTROLNET_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"
# One super-resolution ControlNet per biological structure.
SR_CONTROLNET_MODELS = {
    "Checkpoint ER": f"{MODELS_ROOT_DIR}/ControlNet_SR/ER/checkpoint-30000",
    "Checkpoint Microtubules": f"{MODELS_ROOT_DIR}/ControlNet_SR/Microtubules/checkpoint-72500",
    "Checkpoint CCPs": f"{MODELS_ROOT_DIR}/ControlNet_SR/CCPs/checkpoint-100000",
    "Checkpoint F-actin": f"{MODELS_ROOT_DIR}/ControlNet_SR/F-actin/checkpoint-35000",
}
SR_EXAMPLE_IMG_DIR = "example_images_sr"
DN_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_DN/checkpoint-10000"
# Maps denoising image-type keys to the text prompt used by the DN pipeline.
DN_PROMPT_RULES = {'MICE': 'mouse brain tissues', 'FISH': 'zebrafish embryos', 'BPAE_B': 'nucleus of BPAE', 'BPAE_R': 'mitochondria of BPAE', 'BPAE_G': 'F-actin of BPAE'}
DN_EXAMPLE_IMG_DIR = "example_images_dn"
# Cellpose segmentation checkpoints (alternatives kept commented for reference).
SEG_MODELS = {
    #"DynamicNet Model": f"{MODELS_ROOT_DIR}/Cellpose/DynamicNet_baseline/CP_dynamic_ten_epoch_0100",
    #"DynamicNet Model + FluoGen": f"{MODELS_ROOT_DIR}/Cellpose/DynamicNet_FluoGen/CP_dynamic_epoch_0300",
    #"DSB Model": f"{MODELS_ROOT_DIR}/Cellpose/DSB_baseline/CP_dsb_baseline_ratio_1_epoch_0135",
    "Cellpose Augmented by FluoGen": f"{MODELS_ROOT_DIR}/Cellpose/DSB_FluoGen/CP_dsb_ten_epoch_0135",
}
SEG_EXAMPLE_IMG_DIR = "example_images_seg"
# ResNet-50 classification checkpoints; OrderedDict preserves dropdown order.
CLS_MODEL_PATHS = OrderedDict({
    #"5shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot",
    "ResNet-50 Augmented by FluoGen": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot_aug",
})
# Class index order must match the training label order of the checkpoint.
CLS_CLASS_NAMES = ['Nucleus', 'Endoplasmic Reticulum', 'Giantin', 'GPP130', 'Lysosomes', 'Mitochondria', 'Nucleolus', 'Actin', 'Endosomes', 'Microtubules']
CLS_EXAMPLE_IMG_DIR = "example_images_cls"
# --- Constants for Visualization ---
# Pseudocolor choices offered in the UI; names not handled explicitly by
# apply_pseudocolor are treated as matplotlib colormap names.
COLOR_MAPS = [
    "Grayscale",
    "Green (GFP)",
    "Red (RFP)",
    "Blue (DAPI)",
    "Magenta",
    "Cyan",
    "Yellow",
    "Fire",
    "Viridis",
    "Inferno",
    "Magma",
    "Plasma",
    "Red Hot",
    "Cyan Hot",
    "Magenta Hot"
]
# Piecewise-linear colour stops for the "Cyan Hot" ramp:
# (normalized position, (R, G, B) in 0-255).
CYAN_HOT_POINTS = [
    (0.00, (0, 0, 0)),
    (0.20, (0, 69, 143)),
    (0.50, (0, 183, 255)),
    (0.80, (89, 255, 255)),
    (1.00, (255, 255, 255)),
]
def get_rgb_interpolation(val_norm, rgb_points):
    """Piecewise-linearly interpolate an RGB colour at normalized position(s).

    val_norm: scalar or array of positions in [0, 1].
    rgb_points: sequence of (position, (R, G, B)) stops with 0-255 channels.
    Returns an (r, g, b) tuple with each channel rescaled to [0, 1].
    """
    positions = [stop[0] for stop in rgb_points]
    # Split the 0-255 colour stops into three per-channel lists on a 0-1 scale.
    per_channel = [
        [stop[1][channel] / 255.0 for stop in rgb_points]
        for channel in range(3)
    ]
    red = np.interp(val_norm, positions, per_channel[0])
    green = np.interp(val_norm, positions, per_channel[1])
    blue = np.interp(val_norm, positions, per_channel[2])
    return red, green, blue
# --- Helper Functions ---
def sanitize_prompt_for_filename(prompt):
    """Turn a free-text prompt into a safe lowercase .png filename."""
    cleaned = prompt.lower()
    # Collapse " of " connectors into underscores, then drop every character
    # outside [a-z0-9-_].
    cleaned = re.sub(r'\s+of\s+', '_', cleaned)
    cleaned = re.sub(r'[^a-z0-9-_]+', '', cleaned)
    return f"{cleaned}.png"
def min_max_norm(x):
    """Rescale an array to [0, 1] as float32; near-constant inputs become all zeros."""
    arr = x.astype(np.float32)
    lo = np.min(arr)
    hi = np.max(arr)
    span = hi - lo
    if span < 1e-8:
        # Avoid dividing by (almost) zero for flat images.
        return np.zeros_like(arr)
    return (arr - lo) / span
def generate_colorbar_preview(color_name):
    """Generates a small 256x30 PIL image previewing the named colormap.

    Custom "Hot" ramps and single-channel tints are computed directly; any
    other name is resolved as a matplotlib colormap ("Fire" aliases
    gnuplot2). Unknown names fall back to the grayscale preview.
    """
    gradient = np.linspace(0, 1, 256).reshape(1, 256)
    if color_name == "Grayscale":
        return Image.fromarray((gradient * 255).astype(np.uint8)).convert("RGB").resize((256, 30))
    rgb = np.zeros((1, 256, 3))
    if "Hot" in color_name:
        # "Hot" ramps: channels rise in staggered halves ("Magenta Hot") or thirds.
        low_half = np.clip(gradient * 2, 0, 1)
        high_half = np.clip((gradient - 0.5) * 2, 0, 1)
        if color_name == "Magenta Hot":
            rgb[..., 0] = low_half
            rgb[..., 1] = high_half
            rgb[..., 2] = low_half
        else:
            step_1_red = np.clip(gradient * 3, 0, 1)
            step_2_red = np.clip((gradient - 0.333) * 3, 0, 1)
            step_3_red = np.clip((gradient - 0.666) * 3, 0, 1)
            if color_name == "Red Hot":
                rgb[..., 0] = step_1_red
                rgb[..., 1] = step_2_red
                rgb[..., 2] = step_3_red
            elif color_name == "Cyan Hot":
                r, g, b = get_rgb_interpolation(gradient, CYAN_HOT_POINTS)
                rgb = np.stack([r, g, b], axis=-1)
    elif color_name == "Green (GFP)": rgb[..., 1] = gradient
    elif color_name == "Red (RFP)": rgb[..., 0] = gradient
    elif color_name == "Blue (DAPI)": rgb[..., 2] = gradient
    elif color_name == "Magenta": rgb[..., 0] = gradient; rgb[..., 2] = gradient
    elif color_name == "Cyan": rgb[..., 1] = gradient; rgb[..., 2] = gradient
    elif color_name == "Yellow": rgb[..., 0] = gradient; rgb[..., 1] = gradient
    else:
        # Matplotlib maps
        mpl_map_name = color_name.lower()
        if color_name == "Fire": mpl_map_name = "gnuplot2"
        try:
            # Fix: cm.get_cmap() was removed in matplotlib 3.9; use the
            # supported pyplot accessor instead.
            cmap = plt.get_cmap(mpl_map_name)
            rgb = cmap(gradient)[..., :3]
        except ValueError:
            # Unknown colormap name (was a bare except) -> grayscale fallback.
            return generate_colorbar_preview("Grayscale")
    img_np = (rgb * 255).astype(np.uint8)
    return Image.fromarray(img_np).resize((256, 30))
def apply_pseudocolor(image_np, color_name="Grayscale"):
    """
    Applies a pseudocolor to a single channel numpy image.

    The input is min-max normalized to [0, 1]; custom "Hot" ramps and
    single-channel tints are computed directly, any other name is resolved
    as a matplotlib colormap ("Fire" aliases gnuplot2). Unknown names fall
    back to grayscale.
    Returns: PIL Image in RGB, or None when image_np is None.
    """
    if image_np is None: return None
    # Normalize to 0-1 for processing
    norm_img = min_max_norm(np.squeeze(image_np))
    if color_name == "Grayscale":
        return Image.fromarray((norm_img * 255).astype(np.uint8)).convert("RGB")
    h, w = norm_img.shape
    rgb = np.zeros((h, w, 3), dtype=np.float32)
    if "Hot" in color_name:
        # "Hot" ramps: channels rise in staggered halves ("Magenta Hot") or thirds.
        low_half = np.clip(norm_img * 2, 0, 1)
        high_half = np.clip((norm_img - 0.5) * 2, 0, 1)
        if color_name == "Magenta Hot":
            rgb[..., 0] = low_half
            rgb[..., 1] = high_half
            rgb[..., 2] = low_half
        else:
            step_1_red = np.clip(norm_img * 3, 0, 1)
            step_2_red = np.clip((norm_img - 0.333) * 3, 0, 1)
            step_3_red = np.clip((norm_img - 0.666) * 3, 0, 1)
            if color_name == "Red Hot":
                rgb[..., 0] = step_1_red
                rgb[..., 1] = step_2_red
                rgb[..., 2] = step_3_red
            elif color_name == "Cyan Hot":
                r, g, b = get_rgb_interpolation(norm_img, CYAN_HOT_POINTS)
                rgb = np.stack([r, g, b], axis=-1)
    elif color_name == "Green (GFP)": rgb[..., 1] = norm_img
    elif color_name == "Red (RFP)": rgb[..., 0] = norm_img
    elif color_name == "Blue (DAPI)": rgb[..., 2] = norm_img
    elif color_name == "Magenta": rgb[..., 0] = norm_img; rgb[..., 2] = norm_img
    elif color_name == "Cyan": rgb[..., 1] = norm_img; rgb[..., 2] = norm_img
    elif color_name == "Yellow": rgb[..., 0] = norm_img; rgb[..., 1] = norm_img
    else:
        # Matplotlib maps
        mpl_map_name = color_name.lower()
        if color_name == "Fire": mpl_map_name = "gnuplot2"
        try:
            # Fix: cm.get_cmap() was removed in matplotlib 3.9; use the
            # supported pyplot accessor instead.
            cmap = plt.get_cmap(mpl_map_name)
            colored = cmap(norm_img)
            rgb = colored[..., :3]
        except ValueError:
            # Unknown colormap name (was a bare except) -> grayscale fallback.
            return apply_pseudocolor(image_np, "Grayscale")
    return Image.fromarray((rgb * 255).astype(np.uint8))
def update_sr_settings(model_name):
    """Return the (default prompt, default colormap) pair for an SR checkpoint.

    Mapping:
      - ER -> Magenta Hot
      - Microtubules -> Cyan Hot
      - CCPs -> Green (GFP)
      - F-actin -> Red Hot
    Unknown names yield an empty prompt with Grayscale.
    """
    settings = {
        "Checkpoint ER": ("ER of COS-7", "Magenta Hot"),
        "Checkpoint Microtubules": ("Microtubules of COS-7", "Cyan Hot"),
        "Checkpoint CCPs": ("CCPs of COS-7", "Green (GFP)"),
        "Checkpoint F-actin": ("F-actin of COS-7", "Red Hot"),
    }
    return settings.get(model_name, ("", "Grayscale"))
def save_temp_tiff(image_np, prefix="output"):
    """Write an array to a named temporary .tif file and return its path.

    float16 data is promoted to float32 before writing — presumably because
    downstream TIFF readers handle half precision poorly (verify).
    """
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".tif", prefix=f"{prefix}_")
    payload = image_np.astype(np.float32) if image_np.dtype == np.float16 else image_np
    tifffile.imwrite(tmp.name, payload)
    return tmp.name
def numpy_to_pil(image_np, target_mode="RGB"):
    """Coerce a numpy array (or PIL image) into a PIL image in target_mode.

    Non-uint8 arrays are min-max normalized to 8-bit first; singleton
    dimensions are squeezed away. target_mode is "RGB" or "L".
    """
    if isinstance(image_np, Image.Image):
        # Already a PIL image: only convert when the mode differs.
        if target_mode in ("RGB", "L") and image_np.mode != target_mode:
            return image_np.convert(target_mode)
        return image_np
    arr = np.squeeze(image_np)
    if arr.dtype == np.uint8:
        eight_bit = arr
    else:
        eight_bit = (min_max_norm(arr) * 255).astype(np.uint8)
    pil_image = Image.fromarray(eight_bit)
    if target_mode in ("RGB", "L") and pil_image.mode != target_mode:
        pil_image = pil_image.convert(target_mode)
    return pil_image
def update_sr_prompt(model_name):
    """Return the default text prompt for a given SR checkpoint name ('' if unknown)."""
    prompts = {
        "Checkpoint ER": "ER of COS-7",
        "Checkpoint Microtubules": "Microtubules of COS-7",
        "Checkpoint CCPs": "CCPs of COS-7",
        "Checkpoint F-actin": "F-actin of COS-7",
    }
    return prompts.get(model_name, "")
# Maps every known prompt to the UNet checkpoint that should serve it.
PROMPT_TO_MODEL_MAP = {}
# Path of the UNet currently mounted in the T2I pipeline (for hot-swapping).
current_t2i_unet_path = None
def load_all_prompts():
    """Load prompt lists from the JSON category files.

    Populates PROMPT_TO_MODEL_MAP (prompt -> UNet checkpoint path) as a side
    effect and returns the combined prompt list. When no prompt file can be
    read, a minimal two-prompt fallback is returned.
    """
    global PROMPT_TO_MODEL_MAP
    categories = [
        {"file": "prompts/basic_prompts.json", "model": f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"},
        {"file": "prompts/others_prompts.json", "model": f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/FULL-checkpoint-275000"},
        {"file": "prompts/hpa_prompts.json", "model": f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/HPA-checkpoint-40000"}
    ]
    combined_prompts = []
    for cat in categories:
        try:
            if os.path.exists(cat["file"]):
                with open(cat["file"], "r", encoding="utf-8") as f:
                    data = json.load(f)
                if isinstance(data, list):
                    combined_prompts.extend(data)
                    for p in data:
                        PROMPT_TO_MODEL_MAP[p] = cat["model"]
        except Exception as e:
            # Best-effort: a malformed prompt file must not kill startup,
            # but log it instead of swallowing the error silently.
            print(f"Warning: failed to load prompts from {cat['file']}: {e}")
    if not combined_prompts:
        return ["F-actin of COS-7", "ER of COS-7"]
    return combined_prompts
T2I_PROMPTS = load_all_prompts()
# --- 1. Model Loading ---
# Both pipelines are loaded eagerly at startup. Failures are logged but not
# fatal: each task handler checks for None and raises a gr.Error instead.
print("--- Initializing FluoGen Application ---")
t2i_pipe, controlnet_pipe = None, None
try:
    print("Loading Text-to-Image model...")
    # v-prediction DDPM scheduler with zero-SNR rescaling and trailing spacing.
    t2i_noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="v_prediction", rescale_betas_zero_snr=True, timestep_spacing="trailing")
    t2i_unet = UNet2DModel.from_pretrained(T2I_UNET_PATH, subfolder="unet")
    # Text encoder/tokenizer come from the bundled SD v1.5 snapshot.
    t2i_text_encoder = CLIPTextModel.from_pretrained(T2I_PRETRAINED_MODEL_PATH, subfolder="text_encoder").to(DEVICE)
    t2i_tokenizer = CLIPTokenizer.from_pretrained(T2I_PRETRAINED_MODEL_PATH, subfolder="tokenizer")
    t2i_pipe = DDPMPipeline(unet=t2i_unet, scheduler=t2i_noise_scheduler, text_encoder=t2i_text_encoder, tokenizer=t2i_tokenizer)
    t2i_pipe.to(DEVICE)
    # Remember which UNet checkpoint is mounted so swap_t2i_unet can no-op.
    current_t2i_unet_path = T2I_UNET_PATH
    print("ā Text-to-Image model loaded successfully!")
except Exception as e:
    print(f"FATAL: Text-to-Image Model Loading Failed: {e}")
try:
    print("Loading shared ControlNet pipeline...")
    # A single ControlNet pipeline is shared by M2I / SR / DN; only the
    # controlnet weights are swapped per task (see swap_controlnet).
    controlnet_unet = UNet2DConditionModel.from_pretrained(CONTROLNET_UNET_PATH, subfolder="unet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
    controlnet_controlnet = ControlNetModel.from_pretrained(M2I_CONTROLNET_PATH, subfolder="controlnet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
    # DDIM here (vs DDPM above); zero-SNR rescaling disabled for this pipeline.
    controlnet_scheduler = DDIMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="v_prediction", rescale_betas_zero_snr=False, timestep_spacing="trailing")
    controlnet_tokenizer = CLIPTokenizer.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="tokenizer")
    with ContextManagers([]):
        controlnet_text_encoder = CLIPTextModel.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="text_encoder").to(dtype=WEIGHT_DTYPE, device=DEVICE)
    controlnet_pipe = DDPMControlnetPipeline(unet=controlnet_unet, controlnet=controlnet_controlnet, scheduler=controlnet_scheduler, text_encoder=controlnet_text_encoder, tokenizer=controlnet_tokenizer)
    controlnet_pipe.to(dtype=WEIGHT_DTYPE, device=DEVICE)
    # Track the mounted controlnet path for swap_controlnet's no-op check.
    controlnet_pipe.current_controlnet_path = M2I_CONTROLNET_PATH
    print("ā Shared ControlNet pipeline loaded successfully!")
except Exception as e:
    print(f"FATAL: ControlNet Pipeline Loading Failed: {e}")
# --- 2. Core Logic Functions ---
def swap_controlnet(pipe, target_path):
    """Hot-swap the pipeline's ControlNet weights when a different path is requested.

    The active checkpoint is tracked on the pipe itself
    (pipe.current_controlnet_path); matching (normalized) paths are a no-op.
    Raises gr.Error if the new weights cannot be loaded.
    """
    mounted = getattr(pipe, 'current_controlnet_path', '')
    if os.path.normpath(mounted) != os.path.normpath(target_path):
        print(f"Swapping ControlNet model to: {target_path}")
        try:
            replacement = ControlNetModel.from_pretrained(target_path, subfolder="controlnet")
            pipe.controlnet = replacement.to(dtype=WEIGHT_DTYPE, device=DEVICE)
            pipe.current_controlnet_path = target_path
        except Exception as e:
            raise gr.Error(f"Failed to load ControlNet model. Error: {e}")
    return pipe
def swap_t2i_unet(pipe, target_unet_path):
    """Swap the T2I pipeline's UNet when a different checkpoint is requested.

    Tracks the active checkpoint in the module-global current_t2i_unet_path;
    a no-op when the normalized paths already match. Raises gr.Error when
    the new UNet cannot be loaded.
    """
    global current_t2i_unet_path
    target_unet_path = os.path.normpath(target_unet_path)
    already_mounted = (
        current_t2i_unet_path is not None
        and os.path.normpath(current_t2i_unet_path) == target_unet_path
    )
    if not already_mounted:
        print(f"š Swapping T2I UNet to: {target_unet_path}")
        try:
            pipe.unet = UNet2DModel.from_pretrained(target_unet_path, subfolder="unet").to(DEVICE)
            current_t2i_unet_path = target_unet_path
        except Exception as e:
            raise gr.Error(f"Failed to load UNet. Error: {e}")
    return pipe
# --- Dynamic Color Update Functions ---
def update_single_image_color(raw_np_state, color_name):
    """Re-render one cached raw image and its colorbar in a new pseudocolor."""
    if raw_np_state is None:
        return None, None
    return apply_pseudocolor(raw_np_state, color_name), generate_colorbar_preview(color_name)
def update_pair_color(input_np_state, output_np_state, color_name):
    """Updates both input and output images with the selected pseudocolor.

    Either state may be None (not yet generated); Nones pass through.
    Returns (input display, output display, colorbar image).
    """
    in_img = None if input_np_state is None else apply_pseudocolor(input_np_state, color_name)
    out_img = None if output_np_state is None else apply_pseudocolor(output_np_state, color_name)
    return in_img, out_img, generate_colorbar_preview(color_name)
def update_gallery_color(raw_list_state, color_name):
    """Re-render every cached gallery image and the colorbar in a new pseudocolor."""
    if raw_list_state is None:
        return None, None
    recolored = [apply_pseudocolor(arr, color_name) for arr in raw_list_state]
    return recolored, generate_colorbar_preview(color_name)
# --- Event Handler Helper ---
def get_gallery_selection(evt: gr.SelectData):
    """Return the caption (filename/prompt) of the gallery item the user clicked."""
    return evt.value['caption']
# --- Generation Functions ---
@spaces.GPU(duration=120)
def generate_t2i(prompt, num_inference_steps, num_images, current_color, height=512, width=512):
    """
    Generates multiple images for Text-to-Image and returns a gallery.

    The prompt also selects which UNet checkpoint to mount via
    PROMPT_TO_MODEL_MAP (unmapped prompts use the FULL checkpoint).

    Returns:
        (display image list, path to zip of raw TIFFs, raw numpy list,
        colorbar preview image).
    """
    global t2i_pipe
    if t2i_pipe is None: raise gr.Error("Text-to-Image model is not loaded.")
    target_model_path = PROMPT_TO_MODEL_MAP.get(prompt, f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/FULL-checkpoint-275000")
    t2i_pipe = swap_t2i_unet(t2i_pipe, target_model_path)
    print(f"\nš T2I Task started... | Prompt: '{prompt}' | Count: {num_images} | Size: {height}x{width}")
    generated_raw_list = []  # raw numpy outputs, kept in gr.State for recoloring
    generated_display_images = []  # pseudocolored PIL copies shown in the gallery
    generated_raw_files = []  # per-image TIFF paths bundled into the zip
    temp_dir = tempfile.mkdtemp()
    # Generate Batch
    for i in range(int(num_images)):
        # Generate single image (no explicit generator -> non-deterministic seed)
        image_np = t2i_pipe(
            prompt.lower(),
            generator=None,
            num_inference_steps=int(num_inference_steps),
            output_type="np",
            height=int(height),
            width=int(width)
        ).images
        generated_raw_list.append(image_np)
        # Save raw to temp (float16 promoted to float32 for TIFF writing)
        raw_name = f"t2i_sample_{i+1}.tif"
        raw_path = os.path.join(temp_dir, raw_name)
        save_data = image_np.astype(np.float32) if image_np.dtype == np.float16 else image_np
        tifffile.imwrite(raw_path, save_data)
        generated_raw_files.append(raw_path)
        # Create display version
        generated_display_images.append(apply_pseudocolor(image_np, current_color))
        # Save first image to examples if needed
        if SAVE_EXAMPLES and i == 0:
            example_filepath = os.path.join(T2I_EXAMPLE_IMG_DIR, sanitize_prompt_for_filename(prompt))
            if not os.path.exists(example_filepath): generated_display_images[0].save(example_filepath)
    # Zip raw files
    zip_filename = os.path.join(temp_dir, "raw_output_images.zip")
    with zipfile.ZipFile(zip_filename, 'w') as zipf:
        for file in generated_raw_files: zipf.write(file, os.path.basename(file))
    colorbar_img = generate_colorbar_preview(current_color)
    # Return: Gallery List, Zip Path, Raw State List, Colorbar
    return generated_display_images, zip_filename, generated_raw_list, colorbar_img
@spaces.GPU(duration=120)
def run_mask_to_image_generation(mask_file_obj, cell_type, num_images, steps, seed, current_color):
    """Generate fluorescence images conditioned on an uploaded segmentation mask.

    The mask TIFF is min-max normalized, resized to 512x512 and fed to the
    shared ControlNet pipeline with the prompt "nuclei of <cell_type>".
    Image i uses seed+i, so the batch is reproducible but varied.

    Returns (input display, display image list, zip of raw TIFFs,
    raw numpy list, colorbar image).
    """
    if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
    if mask_file_obj is None: raise gr.Error("Please upload a segmentation mask.")
    # Ensure the M2I ControlNet weights are the mounted ones.
    pipe = swap_controlnet(controlnet_pipe, M2I_CONTROLNET_PATH)
    try: mask_np = tifffile.imread(mask_file_obj.name)
    except Exception as e: raise gr.Error(f"Failed to read TIF. Error: {e}")
    input_display = numpy_to_pil(mask_np, "L")
    mask_normalized = min_max_norm(mask_np)
    # Build a (1, 1, H, W) tensor in the pipeline's dtype/device.
    image_tensor = torch.from_numpy(mask_normalized.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
    image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)
    prompt = f"nuclei of {cell_type.strip()}"
    print(f"\nM2I Task started... | Prompt: '{prompt}'")
    generated_raw_list = []
    generated_display_images = []
    generated_raw_files = []
    temp_dir = tempfile.mkdtemp()
    for i in range(int(num_images)):
        # Deterministic per-image seed: base seed + index.
        generator = torch.Generator(device=DEVICE).manual_seed(int(seed) + i)
        with torch.autocast("cuda"):
            output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images
        generated_raw_list.append(output_np)
        raw_name = f"m2i_sample_{i+1}.tif"
        raw_path = os.path.join(temp_dir, raw_name)
        # float16 is promoted to float32 for TIFF writing.
        save_data = output_np.astype(np.float32) if output_np.dtype == np.float16 else output_np
        tifffile.imwrite(raw_path, save_data)
        generated_raw_files.append(raw_path)
        generated_display_images.append(apply_pseudocolor(output_np, current_color))
    zip_filename = os.path.join(temp_dir, "raw_output_images.zip")
    with zipfile.ZipFile(zip_filename, 'w') as zipf:
        for file in generated_raw_files: zipf.write(file, os.path.basename(file))
    colorbar_img = generate_colorbar_preview(current_color)
    return input_display, generated_display_images, zip_filename, generated_raw_list, colorbar_img
@spaces.GPU(duration=120)
def run_super_resolution(low_res_file_obj, controlnet_model_name, prompt, steps, seed, current_color):
    """Super-resolve a 9-frame TIFF stack with a per-structure ControlNet.

    The 9 frames are the conditioning input; the displayed "input" is their
    average projection. Returns (input display, output display, raw TIFF
    path, input state array, output state array, colorbar image).
    """
    if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
    if low_res_file_obj is None: raise gr.Error("Please upload a file.")
    target_path = SR_CONTROLNET_MODELS.get(controlnet_model_name)
    pipe = swap_controlnet(controlnet_pipe, target_path)
    try: image_stack_np = tifffile.imread(low_res_file_obj.name)
    except Exception as e: raise gr.Error(f"Failed to read TIF. Error: {e}")
    # Require a (9, H, W) stack.
    if image_stack_np.ndim != 3 or image_stack_np.shape[-3] != 9:
        raise gr.Error(f"Invalid TIF shape. Expected 9 channels, got {image_stack_np.shape}.")
    # Calculate Average Projection for Input
    avg_projection_np = np.mean(image_stack_np, axis=0)
    # Preprocess for model: divide by 65535 (assumes 16-bit input data), add a
    # batch dim and resize to the 512x512 model resolution.
    image_tensor = torch.from_numpy(image_stack_np.astype(np.float32) / 65535.0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
    image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)
    print(f"\nSR Task started...")
    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
    with torch.autocast("cuda"):
        output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images
    raw_file_path = save_temp_tiff(output_np, prefix="sr_raw")
    # Generate displays with current color
    input_display = apply_pseudocolor(avg_projection_np, current_color)
    output_display = apply_pseudocolor(output_np, current_color)
    colorbar_img = generate_colorbar_preview(current_color)
    # Return: Input Disp, Output Disp, Download Path, Input State, Output State, Colorbar
    return input_display, output_display, raw_file_path, avg_projection_np, output_np, colorbar_img
@spaces.GPU(duration=120)
def run_denoising(noisy_image_np, image_type, steps, seed, current_color):
    """Denoise a microscopy image with the DN ControlNet.

    image_type selects the text prompt via DN_PROMPT_RULES (unknown types
    fall back to 'microscopy image'). Returns (input display, output
    display, raw TIFF path, input state array, output state array,
    colorbar image).
    """
    if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
    if noisy_image_np is None: raise gr.Error("Please upload an image.")
    pipe = swap_controlnet(controlnet_pipe, DN_CONTROLNET_PATH)
    prompt = DN_PROMPT_RULES.get(image_type, 'microscopy image')
    print(f"\nDN Task started...")
    # Divide by 255 (assumes 8-bit input); shape becomes (1, 1, H, W) at 512x512.
    image_tensor = torch.from_numpy(noisy_image_np.astype(np.float32) / 255.0).unsqueeze(0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
    image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)
    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
    with torch.autocast("cuda"):
        output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images
    raw_file_path = save_temp_tiff(output_np, prefix="dn_raw")
    # Generate displays with current color
    input_display = apply_pseudocolor(noisy_image_np, current_color)
    output_display = apply_pseudocolor(output_np, current_color)
    colorbar_img = generate_colorbar_preview(current_color)
    # Return: Input Disp, Output Disp, Download Path, Input State, Output State, Colorbar
    return input_display, output_display, raw_file_path, noisy_image_np, output_np, colorbar_img
# --- Segmentation & Classification ---
@spaces.GPU(duration=120)
def run_segmentation(input_image_np, model_name, diameter, flow_threshold, cellprob_threshold):
    """Segment cells with a Cellpose model and overlay the mask in dark red.

    diameter == 0 falls back to the model's trained diameter
    (model.diam_labels). Returns (grayscale input, blended RGB overlay).
    """
    if input_image_np is None: raise gr.Error("Please upload an image.")
    model_path = SEG_MODELS.get(model_name)
    print(f"\nSeg Task started...")
    try:
        model = cellpose_models.CellposeModel(gpu=torch.cuda.is_available(), pretrained_model=model_path)
        # channels=[0, 0]: grayscale segmentation, no secondary channel.
        masks, _, _ = model.eval([input_image_np], channels=[0, 0], diameter=(model.diam_labels if diameter==0 else diameter), flow_threshold=flow_threshold, cellprob_threshold=cellprob_threshold)
    except Exception as e: raise gr.Error(f"Segmentation failed: {e}")
    original_rgb = numpy_to_pil(input_image_np, "RGB")
    # Dark-red mask blended 60/40 with the original image.
    red_mask = np.zeros_like(np.array(original_rgb)); red_mask[masks[0] > 0] = [139, 0, 0]
    blended = ((0.6 * np.array(original_rgb) + 0.4 * red_mask).astype(np.uint8))
    return numpy_to_pil(input_image_np, "L"), numpy_to_pil(blended, "RGB")
@spaces.GPU(duration=120)
def run_classification(input_image_np, model_name):
    """Classify a fluorescence image into one of CLS_CLASS_NAMES with ResNet-50.

    Loads best_resnet50.pth from the selected checkpoint directory.
    Returns (grayscale input, {class name: probability}).
    """
    if input_image_np is None: raise gr.Error("Please upload an image.")
    model_path = os.path.join(CLS_MODEL_PATHS.get(model_name), "best_resnet50.pth")
    print(f"\nCls Task started...")
    try:
        model = models.resnet50(weights=None); model.fc = nn.Linear(model.fc.in_features, len(CLS_CLASS_NAMES))
        # NOTE(review): torch.load without weights_only=True unpickles arbitrary
        # objects — fine for our own checkpoints, but confirm the source is trusted.
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        model.to(DEVICE).eval()
    except Exception as e: raise gr.Error(f"Classification failed: {e}")
    # Per-channel 0.5 mean/std normalization — presumably matches the
    # checkpoint's training preprocessing; verify.
    input_tensor = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)])(numpy_to_pil(input_image_np, "RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        probs = F.softmax(model(input_tensor), dim=1).squeeze().cpu().numpy()
    return numpy_to_pil(input_image_np, "L"), {name: float(p) for name, p in zip(CLS_CLASS_NAMES, probs)}
# --- 3. Gradio UI Layout ---
print("Building Gradio interface...")
# Ensure every example directory exists before listing it.
for d in [M2I_EXAMPLE_IMG_DIR, T2I_EXAMPLE_IMG_DIR, SR_EXAMPLE_IMG_DIR, DN_EXAMPLE_IMG_DIR, SEG_EXAMPLE_IMG_DIR, CLS_EXAMPLE_IMG_DIR]: os.makedirs(d, exist_ok=True)
# Load Examples
# Map sanitized filenames back to their originating prompts so clicking an
# example image can restore its prompt in the dropdown.
filename_to_prompt = { sanitize_prompt_for_filename(p): p for p in T2I_PROMPTS }
t2i_examples = []
for f in os.listdir(T2I_EXAMPLE_IMG_DIR):
    if f in filename_to_prompt: t2i_examples.append((os.path.join(T2I_EXAMPLE_IMG_DIR, f), filename_to_prompt[f]))
def load_examples(d, stack=False):
    """Collect (PIL image, filename) example pairs from a directory.

    .tif/.tiff files are read via tifffile, .png/.jpg via PIL (as
    grayscale). When stack=True, 3-D stacks are reduced to their mean
    projection. Unreadable files are skipped with a warning.
    """
    ex = []
    if not os.path.exists(d):
        return ex
    for f in sorted(os.listdir(d)):
        lower_name = f.lower()
        if not lower_name.endswith(('.tif', '.tiff', '.png', '.jpg')):
            continue
        fp = os.path.join(d, f)
        try:
            # Fix: compare the lowercased name so ".TIF"/".TIFF" files are
            # routed through tifffile instead of the PIL fallback.
            if lower_name.endswith(('.tif', '.tiff')):
                img = tifffile.imread(fp)
            else:
                img = np.array(Image.open(fp).convert("L"))
            if stack and img.ndim == 3:
                img = np.mean(img, axis=0)
            ex.append((numpy_to_pil(img, "L"), f))
        except Exception as e:
            # Best-effort: skip unreadable example files, but say so
            # (was a bare `except: pass`).
            print(f"Warning: skipping example {fp}: {e}")
    return ex
def create_path_selector(base_dir):
    """Build a gallery select-handler bound to a specific example directory.

    The returned handler joins base_dir with the caption (filename) of the
    clicked gallery item to produce the full file path.
    """
    def handler(evt: gr.SelectData):
        return os.path.join(base_dir, evt.value['caption'])
    return handler
# Pre-load the example galleries for every task tab (SR examples are
# 9-channel stacks, so they get an average projection for display).
m2i_examples, sr_examples, dn_examples, seg_examples, cls_examples = (
    load_examples(M2I_EXAMPLE_IMG_DIR),
    load_examples(SR_EXAMPLE_IMG_DIR, stack=True),
    load_examples(DN_EXAMPLE_IMG_DIR),
    load_examples(SEG_EXAMPLE_IMG_DIR),
    load_examples(CLS_EXAMPLE_IMG_DIR),
)
# --- UI Builders ---
# Every task tab follows the same pattern: a controls column (left), a display
# column (right), an optional pseudocolor row, an examples gallery, and event
# wiring at the bottom of the tab. Component creation order defines layout.
with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    # Header: logo next to the title/description markdown.
    with gr.Row():
        gr.Image(value=LOGO_PATH, width=300, height=200, container=False, interactive=False, show_download_button=False, show_fullscreen_button=False)
        gr.Markdown(f"# {MODEL_TITLE}\n{MODEL_DESCRIPTION}")
    with gr.Tabs():
        # --- TAB 1: Text-to-Image ---
        with gr.Tab("Text-to-Image Generation", id="txt2img"):
            t2i_raw_state = gr.State(None) # Stores list of arrays
            gr.Markdown("""
### Instructions
1. Select a desired prompt from the dropdown menu.
2. Adjust the 'Inference Steps' slider to control generation quality.
3. Click the 'Generate' button to create a new image.
4. Explore the 'Examples' gallery; clicking an image will load its prompt.
**Notice:** This model currently supports 3566 prompt categories. However, data for many cell structures and lines is still lacking. We welcome data source contributions to improve the model.
""")
            with gr.Row(variant="panel"):
                with gr.Column(scale=1, min_width=350):
                    t2i_prompt = gr.Dropdown(choices=T2I_PROMPTS, value=T2I_PROMPTS[0], label="Search or Type a Prompt", filterable=True, allow_custom_value=True)
                    t2i_steps = gr.Slider(10, 200, 50, step=1, label="Inference Steps")
                    # Added: Number of Images Slider
                    t2i_num_images = gr.Slider(1, 9, 3, step=1, label="Number of Images")
                    # with gr.Row():
                    #     t2i_height = gr.Slider(512, 1024, value=512, step=64, label="Height")
                    #     t2i_width = gr.Slider(512, 1024, value=512, step=64, label="Width")
                    t2i_btn = gr.Button("Generate", variant="primary")
                with gr.Column(scale=2):
                    # Changed: Image to Gallery
                    t2i_gallery_out = gr.Gallery(label="Generated Images", columns=3, height="auto", show_download_button=False, show_share_button=False)
                    # Pseudocolor selector + colorbar preview + bulk download.
                    with gr.Row(equal_height=True):
                        with gr.Column(scale=2):
                            t2i_color = gr.Dropdown(choices=COLOR_MAPS, value="Green (GFP)", label="Pseudocolor (Adjust after generation)")
                        with gr.Column(scale=2):
                            t2i_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False)
                        with gr.Column(scale=1):
                            t2i_dl = gr.DownloadButton(label="Download All (.zip)")
            t2i_gal = gr.Gallery(value=t2i_examples, label="Examples", columns=6, height="auto")
            t2i_btn.click(generate_t2i, [t2i_prompt, t2i_steps, t2i_num_images, t2i_color], [t2i_gallery_out, t2i_dl, t2i_raw_state, t2i_colorbar])
            # t2i_btn.click(
            #     generate_t2i,
            #     inputs=[t2i_prompt, t2i_steps, t2i_num_images, t2i_color, t2i_height, t2i_width],
            #     outputs=[t2i_gallery_out, t2i_dl, t2i_raw_state, t2i_colorbar]
            # )
            # Reuse update_gallery_color since state is now a list
            t2i_color.change(update_gallery_color, [t2i_raw_state, t2i_color], [t2i_gallery_out, t2i_colorbar])
            # Clicking a t2i example restores its generating prompt (not a path).
            t2i_gal.select(fn=get_gallery_selection, inputs=None, outputs=t2i_prompt)
        # --- TAB 2: Super-Resolution ---
        with gr.Tab("Super-Resolution", id="super_res"):
            # Stores: Input (Average Projection) and Output
            sr_input_state = gr.State(None)
            sr_raw_state = gr.State(None)
            gr.Markdown("""
### Instructions
1. Upload a low-resolution 9-channel TIF stack, or select one from the examples.
2. Select a 'Super-Resolution Model' from the dropdown.
3. Enter a descriptive 'Prompt' related to the image content (e.g., 'CCPs of COS-7').
4. Adjust 'Inference Steps' and 'Seed' as needed.
5. Click 'Generate Super-Resolution' to process the image.
**Notice:** This model was trained on the **BioSR** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results.
""")
            with gr.Row(variant="panel"):
                with gr.Column(scale=1, min_width=350):
                    sr_file = gr.File(label="Upload 9-Channel TIF Stack", file_types=['.tif', '.tiff'])
                    sr_model = gr.Dropdown(choices=list(SR_CONTROLNET_MODELS.keys()), value=list(SR_CONTROLNET_MODELS.keys())[-1], label="Model")
                    sr_prompt = gr.Textbox(label="Prompt", value="F-actin of COS-7", interactive=False)
                    sr_steps = gr.Slider(5, 50, 10, step=1, label="Steps")
                    sr_seed = gr.Number(label="Seed", value=42)
                    sr_btn = gr.Button("Generate", variant="primary")
                with gr.Column(scale=2):
                    with gr.Row():
                        # Input Display now shows Pseudocolor
                        sr_in_disp = gr.Image(label="Input (Avg Projection)", type="pil", interactive=False, show_download_button=False)
                        sr_out_disp = gr.Image(label="Output", type="pil", interactive=False, show_download_button=False)
                    with gr.Row(equal_height=True):
                        with gr.Column(scale=2):
                            sr_color = gr.Dropdown(choices=COLOR_MAPS, value="Red Hot", label="Pseudocolor")
                        with gr.Column(scale=2):
                            sr_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False)
                        with gr.Column(scale=1):
                            sr_dl = gr.DownloadButton(label="Download Raw Output (.tif)")
            sr_gal = gr.Gallery(value=sr_examples, label="Examples", columns=6, height="auto")
            # sr_model.change(update_sr_prompt, sr_model, sr_prompt)
            # Switching model also updates the prompt text and default colormap.
            sr_model.change(update_sr_settings, inputs=sr_model, outputs=[sr_prompt, sr_color])
            # Run returns both input/output states and displays
            sr_btn.click(run_super_resolution, [sr_file, sr_model, sr_prompt, sr_steps, sr_seed, sr_color], [sr_in_disp, sr_out_disp, sr_dl, sr_input_state, sr_raw_state, sr_colorbar])
            # Change color updates both displays
            sr_color.change(update_pair_color, [sr_input_state, sr_raw_state, sr_color], [sr_in_disp, sr_out_disp, sr_colorbar])
            sr_gal.select(fn=create_path_selector(SR_EXAMPLE_IMG_DIR), inputs=None, outputs=sr_file)
        # --- TAB 3: Denoising ---
        with gr.Tab("Denoising", id="denoising"):
            # Stores: Input (Noisy) and Output
            dn_input_state = gr.State(None)
            dn_raw_state = gr.State(None)
            gr.Markdown("""
### Instructions
1. Upload a noisy single-channel image, or select one from the examples.
2. Select the 'Image Type' from the dropdown to provide context for the model.
3. Adjust 'Inference Steps' and 'Seed' as needed.
4. Click 'Denoise Image' to reduce the noise.
**Notice:** This model was trained on the **FMD** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results.
""")
            with gr.Row(variant="panel"):
                with gr.Column(scale=1, min_width=350):
                    dn_img = gr.Image(type="numpy", label="Upload Noisy Image", image_mode="L")
                    dn_type = gr.Dropdown(choices=list(DN_PROMPT_RULES.keys()), value='MICE', label="Image Type")
                    dn_steps = gr.Slider(5, 50, 10, step=1, label="Steps")
                    dn_seed = gr.Number(label="Seed", value=42)
                    dn_btn = gr.Button("Denoise", variant="primary")
                with gr.Column(scale=2):
                    with gr.Row():
                        # Input Display now shows Pseudocolor
                        dn_orig = gr.Image(label="Original", type="pil", interactive=False, show_download_button=False)
                        dn_out = gr.Image(label="Denoised", type="pil", interactive=False, show_download_button=False)
                    with gr.Row(equal_height=True):
                        with gr.Column(scale=2):
                            dn_color = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor")
                        with gr.Column(scale=2):
                            dn_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False)
                        with gr.Column(scale=1):
                            dn_dl = gr.DownloadButton(label="Download Raw Output (.tif)")
            dn_gal = gr.Gallery(value=dn_examples, label="Examples", columns=6, height="auto")
            dn_btn.click(run_denoising, [dn_img, dn_type, dn_steps, dn_seed, dn_color], [dn_orig, dn_out, dn_dl, dn_input_state, dn_raw_state, dn_colorbar])
            dn_color.change(update_pair_color, [dn_input_state, dn_raw_state, dn_color], [dn_orig, dn_out, dn_colorbar])
            dn_gal.select(fn=create_path_selector(DN_EXAMPLE_IMG_DIR), inputs=None, outputs=dn_img)
        # --- TAB 4: Mask-to-Image ---
        with gr.Tab("Mask-to-Image", id="mask2img"):
            m2i_raw_state = gr.State(None)
            gr.Markdown("""
### Instructions
1. Upload a single-channel segmentation mask (`.tif` file), or select one from the examples gallery below.
2. Enter the corresponding 'Cell Type' (e.g., 'CoNSS', 'HeLa') to create the prompt.
3. Select how many sample images you want to generate.
4. Adjust 'Inference Steps' and 'Seed' as needed.
5. Click 'Generate Training Samples' to start the process.
6. The 'Generated Samples' will appear in the main gallery, with the 'Input Mask' shown below for reference.
**Notice:** This model was trained on the **2018 Data Science Bowl** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results.
""")
            with gr.Row(variant="panel"):
                with gr.Column(scale=1, min_width=350):
                    m2i_file = gr.File(label="Upload Mask (.tif)", file_types=['.tif', '.tiff'])
                    # Changed: Default value to HeLa
                    m2i_type = gr.Textbox(label="Cell Type", value="HeLa", placeholder="e.g., HeLa")
                    m2i_num = gr.Slider(1, 10, 5, step=1, label="Count")
                    m2i_steps = gr.Slider(5, 50, 10, step=1, label="Steps")
                    m2i_seed = gr.Number(label="Seed", value=42)
                    m2i_btn = gr.Button("Generate", variant="primary")
                with gr.Column(scale=2):
                    m2i_gal_out = gr.Gallery(label="Generated Samples", columns=5, height="auto")
                    with gr.Row(equal_height=True):
                        with gr.Column(scale=2):
                            m2i_color = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor")
                        with gr.Column(scale=2):
                            m2i_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False)
                        with gr.Column(scale=1):
                            m2i_dl = gr.DownloadButton(label="Download ZIP")
            # Reference view of the uploaded mask, shown beneath the panel.
            m2i_in_disp = gr.Image(label="Input Mask", type="pil", interactive=False, show_download_button=False)
            m2i_gal = gr.Gallery(value=m2i_examples, label="Examples", columns=6, height="auto")
            m2i_btn.click(run_mask_to_image_generation, [m2i_file, m2i_type, m2i_num, m2i_steps, m2i_seed, m2i_color], [m2i_in_disp, m2i_gal_out, m2i_dl, m2i_raw_state, m2i_colorbar])
            m2i_color.change(update_gallery_color, [m2i_raw_state, m2i_color], [m2i_gal_out, m2i_colorbar])
            m2i_gal.select(fn=create_path_selector(M2I_EXAMPLE_IMG_DIR), inputs=None, outputs=m2i_file)
        # --- TAB 5: Cell Segmentation ---
        with gr.Tab("Cell Segmentation", id="segmentation"):
            gr.Markdown("""
### Instructions
1. Upload a single-channel image for segmentation, or select one from the examples.
2. Select a 'Segmentation Model' from the dropdown menu.
3. Set the expected 'Diameter' of the cells in pixels. Set to 0 to let the model automatically estimate it.
4. Adjust 'Flow Threshold' and 'Cell Probability Threshold' for finer control.
5. Click 'Segment Cells'. The result will be shown as a dark red overlay on the original image.
**Notice:** This model was trained on the **2018 Data Science Bowl** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results.
""")
            with gr.Row(variant="panel"):
                with gr.Column(scale=1, min_width=350):
                    seg_img = gr.Image(type="numpy", label="Upload Image", image_mode="L")
                    seg_model = gr.Dropdown(choices=list(SEG_MODELS.keys()), value=list(SEG_MODELS.keys())[0], label="Model")
                    seg_diam = gr.Number(label="Diameter (0=auto)", value=30)
                    seg_flow = gr.Slider(0.0, 3.0, 0.4, step=0.1, label="Flow Thresh")
                    seg_prob = gr.Slider(-6.0, 6.0, 0.0, step=0.5, label="Prob Thresh")
                    seg_btn = gr.Button("Segment", variant="primary")
                with gr.Column(scale=2):
                    with gr.Row():
                        seg_orig = gr.Image(label="Original", type="pil", interactive=False, show_download_button=False)
                        seg_out = gr.Image(label="Overlay", type="pil", interactive=False, show_download_button=False)
            seg_gal = gr.Gallery(value=seg_examples, label="Examples", columns=6, height="auto")
            seg_btn.click(run_segmentation, [seg_img, seg_model, seg_diam, seg_flow, seg_prob], [seg_orig, seg_out])
            seg_gal.select(fn=create_path_selector(SEG_EXAMPLE_IMG_DIR), inputs=None, outputs=seg_img)
        # --- TAB 6: Classification ---
        with gr.Tab("Classification", id="classification"):
            gr.Markdown("""
### Instructions
1. Upload a single-channel image for classification, or select an example.
2. Select a pre-trained 'Classification Model' from the dropdown menu.
3. Click 'Classify Image' to view the prediction probabilities for each class.
**Note:** The models provided are ResNet50 trained on the 2D HeLa dataset.
""")
            with gr.Row(variant="panel"):
                with gr.Column(scale=1, min_width=350):
                    cls_img = gr.Image(type="numpy", label="Upload Image", image_mode="L")
                    cls_model = gr.Dropdown(choices=list(CLS_MODEL_PATHS.keys()), value=list(CLS_MODEL_PATHS.keys())[0], label="Model")
                    cls_btn = gr.Button("Classify", variant="primary")
                with gr.Column(scale=2):
                    cls_orig = gr.Image(label="Input", type="pil", interactive=False, show_download_button=False)
                    cls_res = gr.Label(label="Results", num_top_classes=10)
            cls_gal = gr.Gallery(value=cls_examples, label="Examples", columns=6, height="auto")
            cls_btn.click(run_classification, [cls_img, cls_model], [cls_orig, cls_res])
            cls_gal.select(fn=create_path_selector(CLS_EXAMPLE_IMG_DIR), inputs=None, outputs=cls_img)
# Launch the Gradio app only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()