# =============================================================================
# FluoGen Gradio application: configuration, visualization helpers, model
# loading, and the task entry points (text-to-image, mask-to-image,
# super-resolution, denoising, segmentation, classification).
# =============================================================================
import torch
import numpy as np
import gradio as gr
from PIL import Image
import os
import json
import tifffile
import re
from torchvision import transforms, models
import shutil
import time
import torch.nn as nn
import torch.nn.functional as F
import spaces
from collections import OrderedDict
import tempfile
import zipfile
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import io

# --- Imports from both scripts ---
from diffusers import DDPMScheduler, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from transformers.utils import ContextManagers

# --- Custom Model Imports ---
from models.pipeline_ddpm_text_encoder import DDPMPipeline
from models.unet_2d import UNet2DModel
from models.controlnet import ControlNetModel
from models.unet_2d_condition import UNet2DConditionModel
from models.pipeline_controlnet import DDPMControlnetPipeline

# --- Segmentation Imports ---
from cellpose import models as cellpose_models
from huggingface_hub import snapshot_download

# --- 0. Configuration & Constants ---
hf_token = os.environ.get("HF_TOKEN")
MODEL_TITLE = "šŸ”¬ FluoGen: AI-Powered Fluorescence Microscopy Suite"
MODEL_DESCRIPTION = """
**Paper**: [*FluoGen: An Open-Source Generative Foundation Model for Fluorescence Microscopy Image Enhancement and Analysis*](https://doi.org/10.21203/rs.3.rs-8334792/v1)
**Homepage**: [Homepage Website](https://fluogen-group.github.io/FluoGen-HomePage/)
**Code**: [GitHub Repository](https://github.com/FluoGen-Group/FluoGen)
Select a task below. **Note**: The "Pseudocolor" option instantly applies to both **Input** and **Output** images for better comparison. Use the "Download Raw Output" button to get the scientific 16-bit/Float data.
"""
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
WEIGHT_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
LOGO_PATH = "utils/logo2_transparent.png"
# When True, the first generated T2I image is saved into the examples folder.
SAVE_EXAMPLES = False

# --- CSS for the app font and for hiding built-in download buttons ---
CUSTOM_CSS = """
.gradio-container, .gradio-container * { font-family: 'Arial', 'Helvetica', 'Microsoft YaHei', '微软雅黑', sans-serif !important; }
button[aria-label="Download"], a[download] { display: none !important; }
"""

# --- Base directory for all models ---
REPO_ID = "FluoGen-Group/FluoGen-demo-test-ckpts"
MODELS_ROOT_DIR = snapshot_download(repo_id=REPO_ID, token=hf_token)

# --- Paths Config ---
M2I_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_M2I/checkpoint-30000"
M2I_EXAMPLE_IMG_DIR = "example_images_m2i"
T2I_EXAMPLE_IMG_DIR = "example_images"
T2I_PRETRAINED_MODEL_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5"
T2I_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"
CONTROLNET_CLIP_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5"
CONTROLNET_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"
SR_CONTROLNET_MODELS = {
    "Checkpoint ER": f"{MODELS_ROOT_DIR}/ControlNet_SR/ER/checkpoint-30000",
    "Checkpoint Microtubules": f"{MODELS_ROOT_DIR}/ControlNet_SR/Microtubules/checkpoint-72500",
    "Checkpoint CCPs": f"{MODELS_ROOT_DIR}/ControlNet_SR/CCPs/checkpoint-100000",
    "Checkpoint F-actin": f"{MODELS_ROOT_DIR}/ControlNet_SR/F-actin/checkpoint-35000",
}
SR_EXAMPLE_IMG_DIR = "example_images_sr"
DN_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_DN/checkpoint-10000"
# Maps the UI "Image Type" choice to the denoising prompt text.
DN_PROMPT_RULES = {
    'MICE': 'mouse brain tissues',
    'FISH': 'zebrafish embryos',
    'BPAE_B': 'nucleus of BPAE',
    'BPAE_R': 'mitochondria of BPAE',
    'BPAE_G': 'F-actin of BPAE',
}
DN_EXAMPLE_IMG_DIR = "example_images_dn"
SEG_MODELS = {
    # "DynamicNet Model": f"{MODELS_ROOT_DIR}/Cellpose/DynamicNet_baseline/CP_dynamic_ten_epoch_0100",
    # "DynamicNet Model + FluoGen": f"{MODELS_ROOT_DIR}/Cellpose/DynamicNet_FluoGen/CP_dynamic_epoch_0300",
    # "DSB Model": f"{MODELS_ROOT_DIR}/Cellpose/DSB_baseline/CP_dsb_baseline_ratio_1_epoch_0135",
    "Cellpose Augmented by FluoGen": f"{MODELS_ROOT_DIR}/Cellpose/DSB_FluoGen/CP_dsb_ten_epoch_0135",
}
SEG_EXAMPLE_IMG_DIR = "example_images_seg"
CLS_MODEL_PATHS = OrderedDict({
    # "5shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot",
    "ResNet-50 Augmented by FluoGen": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot_aug",
})
CLS_CLASS_NAMES = ['Nucleus', 'Endoplasmic Reticulum', 'Giantin', 'GPP130', 'Lysosomes',
                   'Mitochondria', 'Nucleolus', 'Actin', 'Endosomes', 'Microtubules']
CLS_EXAMPLE_IMG_DIR = "example_images_cls"

# --- Constants for Visualization ---
COLOR_MAPS = [
    "Grayscale", "Green (GFP)", "Red (RFP)", "Blue (DAPI)", "Magenta", "Cyan", "Yellow",
    "Fire", "Viridis", "Inferno", "Magma", "Plasma", "Red Hot", "Cyan Hot", "Magenta Hot",
]

# Anchor points (position, (R, G, B)) defining the custom "Cyan Hot" LUT.
CYAN_HOT_POINTS = [
    (0.00, (0, 0, 0)),
    (0.20, (0, 69, 143)),
    (0.50, (0, 183, 255)),
    (0.80, (89, 255, 255)),
    (1.00, (255, 255, 255)),
]


def get_rgb_interpolation(val_norm, rgb_points):
    """Linearly interpolate R, G, B channels (each in [0, 1]) from LUT anchors.

    val_norm: scalar or ndarray of normalized intensities in [0, 1].
    rgb_points: list of (position, (R, G, B)) anchors with 8-bit color values.
    Returns (r, g, b), each the same shape as val_norm.
    """
    x_pts = [p[0] for p in rgb_points]
    r_pts = [p[1][0] / 255.0 for p in rgb_points]  # normalize 8-bit values to 0-1
    g_pts = [p[1][1] / 255.0 for p in rgb_points]
    b_pts = [p[1][2] / 255.0 for p in rgb_points]
    r = np.interp(val_norm, x_pts, r_pts)
    g = np.interp(val_norm, x_pts, g_pts)
    b = np.interp(val_norm, x_pts, b_pts)
    return r, g, b


# --- Helper Functions ---
def sanitize_prompt_for_filename(prompt):
    """Map a prompt (e.g. 'F-actin of COS-7') to a safe example filename."""
    prompt = prompt.lower()
    prompt = re.sub(r'\s+of\s+', '_', prompt)
    prompt = re.sub(r'[^a-z0-9-_]+', '', prompt)
    return f"{prompt}.png"


def min_max_norm(x):
    """Min-max normalize an array to [0, 1]; near-constant arrays map to zeros."""
    x = x.astype(np.float32)
    min_val, max_val = np.min(x), np.max(x)
    if max_val - min_val < 1e-8:
        return np.zeros_like(x)
    return (x - min_val) / (max_val - min_val)


def generate_colorbar_preview(color_name):
    """Generates a small PIL image (256x30) representing the colormap."""
    if color_name == "Grayscale":
        gradient = np.linspace(0, 1, 256).reshape(1, 256)
        return Image.fromarray((gradient * 255).astype(np.uint8)).convert("RGB").resize((256, 30))
    gradient = np.linspace(0, 1, 256).reshape(1, 256)
    rgb = np.zeros((1, 256, 3))
    if "Hot" in color_name:
        # "Hot" LUTs ramp channels in over different halves/thirds of the range.
        low_half = np.clip(gradient * 2, 0, 1)
        high_half = np.clip((gradient - 0.5) * 2, 0, 1)
        if color_name == "Magenta Hot":
            rgb[..., 0] = low_half
            rgb[..., 1] = high_half
            rgb[..., 2] = low_half
        else:
            step_1_red = np.clip(gradient * 3, 0, 1)
            step_2_red = np.clip((gradient - 0.333) * 3, 0, 1)
            step_3_red = np.clip((gradient - 0.666) * 3, 0, 1)
            if color_name == "Red Hot":
                rgb[..., 0] = step_1_red
                rgb[..., 1] = step_2_red
                rgb[..., 2] = step_3_red
            elif color_name == "Cyan Hot":
                r, g, b = get_rgb_interpolation(gradient, CYAN_HOT_POINTS)
                rgb = np.stack([r, g, b], axis=-1)
    elif color_name == "Green (GFP)":
        rgb[..., 1] = gradient
    elif color_name == "Red (RFP)":
        rgb[..., 0] = gradient
    elif color_name == "Blue (DAPI)":
        rgb[..., 2] = gradient
    elif color_name == "Magenta":
        rgb[..., 0] = gradient
        rgb[..., 2] = gradient
    elif color_name == "Cyan":
        rgb[..., 1] = gradient
        rgb[..., 2] = gradient
    elif color_name == "Yellow":
        rgb[..., 0] = gradient
        rgb[..., 1] = gradient
    else:  # Matplotlib maps
        mpl_map_name = color_name.lower()
        if color_name == "Fire":
            mpl_map_name = "gnuplot2"
        try:
            # NOTE: cm.get_cmap was removed in matplotlib 3.9; plt.get_cmap is
            # the supported equivalent on both old and new versions.
            cmap = plt.get_cmap(mpl_map_name)
            rgb = cmap(gradient)[..., :3]
        except Exception:
            return generate_colorbar_preview("Grayscale")  # Fallback
    img_np = (rgb * 255).astype(np.uint8)
    return Image.fromarray(img_np).resize((256, 30))


def apply_pseudocolor(image_np, color_name="Grayscale"):
    """
    Applies a pseudocolor to a single channel numpy image.
    Returns: PIL Image in RGB.
    """
    if image_np is None:
        return None
    # Normalize to 0-1 for processing
    norm_img = min_max_norm(np.squeeze(image_np))
    if color_name == "Grayscale":
        return Image.fromarray((norm_img * 255).astype(np.uint8)).convert("RGB")
    h, w = norm_img.shape
    rgb = np.zeros((h, w, 3), dtype=np.float32)
    if "Hot" in color_name:
        low_half = np.clip(norm_img * 2, 0, 1)
        high_half = np.clip((norm_img - 0.5) * 2, 0, 1)
        if color_name == "Magenta Hot":
            rgb[..., 0] = low_half
            rgb[..., 1] = high_half
            rgb[..., 2] = low_half
        else:
            step_1_red = np.clip(norm_img * 3, 0, 1)
            step_2_red = np.clip((norm_img - 0.333) * 3, 0, 1)
            step_3_red = np.clip((norm_img - 0.666) * 3, 0, 1)
            if color_name == "Red Hot":
                rgb[..., 0] = step_1_red
                rgb[..., 1] = step_2_red
                rgb[..., 2] = step_3_red
            elif color_name == "Cyan Hot":
                r, g, b = get_rgb_interpolation(norm_img, CYAN_HOT_POINTS)
                rgb = np.stack([r, g, b], axis=-1)
    elif color_name == "Green (GFP)":
        rgb[..., 1] = norm_img
    elif color_name == "Red (RFP)":
        rgb[..., 0] = norm_img
    elif color_name == "Blue (DAPI)":
        rgb[..., 2] = norm_img
    elif color_name == "Magenta":
        rgb[..., 0] = norm_img
        rgb[..., 2] = norm_img
    elif color_name == "Cyan":
        rgb[..., 1] = norm_img
        rgb[..., 2] = norm_img
    elif color_name == "Yellow":
        rgb[..., 0] = norm_img
        rgb[..., 1] = norm_img
    else:  # Matplotlib maps
        mpl_map_name = color_name.lower()
        if color_name == "Fire":
            mpl_map_name = "gnuplot2"
        try:
            cmap = plt.get_cmap(mpl_map_name)
            colored = cmap(norm_img)
            rgb = colored[..., :3]
        except Exception:
            return apply_pseudocolor(image_np, "Grayscale")
    return Image.fromarray((rgb * 255).astype(np.uint8))


def update_sr_settings(model_name):
    """Return (prompt, default pseudocolor) for an SR checkpoint choice.

    - Microtubules -> Cyan Hot
    - F-actin -> Red Hot
    - CCPs -> Green (GFP)
    - ER -> Magenta Hot
    """
    if model_name == "Checkpoint ER":
        return "ER of COS-7", "Magenta Hot"
    if model_name == "Checkpoint Microtubules":
        return "Microtubules of COS-7", "Cyan Hot"
    if model_name == "Checkpoint CCPs":
        return "CCPs of COS-7", "Green (GFP)"
    elif model_name == "Checkpoint F-actin":
        return "F-actin of COS-7", "Red Hot"
    return "", "Grayscale"


def save_temp_tiff(image_np, prefix="output"):
    """Write an array to a temp .tif and return its path (fp16 is upcast to fp32)."""
    tfile = tempfile.NamedTemporaryFile(delete=False, suffix=".tif", prefix=f"{prefix}_")
    if image_np.dtype == np.float16:
        save_data = image_np.astype(np.float32)  # tifffile cannot store float16 portably
    else:
        save_data = image_np
    tifffile.imwrite(tfile.name, save_data)
    return tfile.name


def numpy_to_pil(image_np, target_mode="RGB"):
    """Convert an ndarray (or pass through a PIL image) to PIL in the target mode.

    Non-uint8 arrays are min-max normalized to 8-bit before conversion.
    """
    if isinstance(image_np, Image.Image):
        if target_mode == "RGB" and image_np.mode != "RGB":
            return image_np.convert("RGB")
        if target_mode == "L" and image_np.mode != "L":
            return image_np.convert("L")
        return image_np
    squeezed_np = np.squeeze(image_np)
    if squeezed_np.dtype == np.uint8:
        image_8bit = squeezed_np
    else:
        normalized_np = min_max_norm(squeezed_np)
        image_8bit = (normalized_np * 255).astype(np.uint8)
    pil_image = Image.fromarray(image_8bit)
    if target_mode == "RGB" and pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")
    elif target_mode == "L" and pil_image.mode != "L":
        pil_image = pil_image.convert("L")
    return pil_image


def update_sr_prompt(model_name):
    """Return only the prompt text for an SR checkpoint choice."""
    if model_name == "Checkpoint ER":
        return "ER of COS-7"
    if model_name == "Checkpoint Microtubules":
        return "Microtubules of COS-7"
    if model_name == "Checkpoint CCPs":
        return "CCPs of COS-7"
    elif model_name == "Checkpoint F-actin":
        return "F-actin of COS-7"
    return ""


# prompt -> UNet checkpoint path, filled while loading the prompt lists.
PROMPT_TO_MODEL_MAP = {}
current_t2i_unet_path = None


def load_all_prompts():
    """Load all prompt lists from JSON and record which UNet serves each prompt.

    Missing/broken files are skipped; falls back to two default prompts if
    nothing could be loaded.
    """
    global PROMPT_TO_MODEL_MAP
    categories = [
        {"file": "prompts/basic_prompts.json", "model": f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"},
        {"file": "prompts/others_prompts.json", "model": f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/FULL-checkpoint-275000"},
        {"file": "prompts/hpa_prompts.json", "model": f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/HPA-checkpoint-40000"},
    ]
    combined_prompts = []
    for cat in categories:
        try:
            if os.path.exists(cat["file"]):
                with open(cat["file"], "r", encoding="utf-8") as f:
                    data = json.load(f)
                if isinstance(data, list):
                    combined_prompts.extend(data)
                    for p in data:
                        PROMPT_TO_MODEL_MAP[p] = cat["model"]
        except Exception:
            pass  # best-effort: a broken prompt file must not kill startup
    if not combined_prompts:
        return ["F-actin of COS-7", "ER of COS-7"]
    return combined_prompts


T2I_PROMPTS = load_all_prompts()

# --- 1. Model Loading ---
print("--- Initializing FluoGen Application ---")
t2i_pipe, controlnet_pipe = None, None
try:
    print("Loading Text-to-Image model...")
    t2i_noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear",
                                        prediction_type="v_prediction", rescale_betas_zero_snr=True,
                                        timestep_spacing="trailing")
    t2i_unet = UNet2DModel.from_pretrained(T2I_UNET_PATH, subfolder="unet")
    t2i_text_encoder = CLIPTextModel.from_pretrained(T2I_PRETRAINED_MODEL_PATH, subfolder="text_encoder").to(DEVICE)
    t2i_tokenizer = CLIPTokenizer.from_pretrained(T2I_PRETRAINED_MODEL_PATH, subfolder="tokenizer")
    t2i_pipe = DDPMPipeline(unet=t2i_unet, scheduler=t2i_noise_scheduler,
                            text_encoder=t2i_text_encoder, tokenizer=t2i_tokenizer)
    t2i_pipe.to(DEVICE)
    current_t2i_unet_path = T2I_UNET_PATH
    print("āœ“ Text-to-Image model loaded successfully!")
except Exception as e:
    print(f"FATAL: Text-to-Image Model Loading Failed: {e}")

try:
    print("Loading shared ControlNet pipeline...")
    controlnet_unet = UNet2DConditionModel.from_pretrained(CONTROLNET_UNET_PATH, subfolder="unet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
    controlnet_controlnet = ControlNetModel.from_pretrained(M2I_CONTROLNET_PATH, subfolder="controlnet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
    controlnet_scheduler = DDIMScheduler(num_train_timesteps=1000, beta_schedule="linear",
                                         prediction_type="v_prediction", rescale_betas_zero_snr=False,
                                         timestep_spacing="trailing")
    controlnet_tokenizer = CLIPTokenizer.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="tokenizer")
    with ContextManagers([]):
        controlnet_text_encoder = CLIPTextModel.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="text_encoder").to(dtype=WEIGHT_DTYPE, device=DEVICE)
    controlnet_pipe = DDPMControlnetPipeline(unet=controlnet_unet, controlnet=controlnet_controlnet,
                                             scheduler=controlnet_scheduler,
                                             text_encoder=controlnet_text_encoder,
                                             tokenizer=controlnet_tokenizer)
    controlnet_pipe.to(dtype=WEIGHT_DTYPE, device=DEVICE)
    # Remember which ControlNet is mounted so swap_controlnet can skip reloads.
    controlnet_pipe.current_controlnet_path = M2I_CONTROLNET_PATH
    print("āœ“ Shared ControlNet pipeline loaded successfully!")
except Exception as e:
    print(f"FATAL: ControlNet Pipeline Loading Failed: {e}")


# --- 2. Core Logic Functions ---
def swap_controlnet(pipe, target_path):
    """Hot-swap the pipeline's ControlNet weights if a different task needs them."""
    if os.path.normpath(getattr(pipe, 'current_controlnet_path', '')) != os.path.normpath(target_path):
        print(f"Swapping ControlNet model to: {target_path}")
        try:
            pipe.controlnet = ControlNetModel.from_pretrained(target_path, subfolder="controlnet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
            pipe.current_controlnet_path = target_path
        except Exception as e:
            raise gr.Error(f"Failed to load ControlNet model. Error: {e}")
    return pipe


def swap_t2i_unet(pipe, target_unet_path):
    """Hot-swap the T2I pipeline's UNet when the prompt maps to another checkpoint."""
    global current_t2i_unet_path
    target_unet_path = os.path.normpath(target_unet_path)
    if current_t2i_unet_path is None or os.path.normpath(current_t2i_unet_path) != target_unet_path:
        print(f"šŸ”„ Swapping T2I UNet to: {target_unet_path}")
        try:
            new_unet = UNet2DModel.from_pretrained(target_unet_path, subfolder="unet").to(DEVICE)
            pipe.unet = new_unet
            current_t2i_unet_path = target_unet_path
        except Exception as e:
            raise gr.Error(f"Failed to load UNet. Error: {e}")
    return pipe


# --- Dynamic Color Update Functions ---
def update_single_image_color(raw_np_state, color_name):
    """Re-colorize one stored raw array; returns (display image, colorbar)."""
    if raw_np_state is None:
        return None, None
    display_img = apply_pseudocolor(raw_np_state, color_name)
    bar_img = generate_colorbar_preview(color_name)
    return display_img, bar_img


def update_pair_color(input_np_state, output_np_state, color_name):
    """Updates both input and output images with the selected pseudocolor."""
    if input_np_state is None:
        in_img = None
    else:
        in_img = apply_pseudocolor(input_np_state, color_name)
    if output_np_state is None:
        out_img = None
    else:
        out_img = apply_pseudocolor(output_np_state, color_name)
    bar_img = generate_colorbar_preview(color_name)
    return in_img, out_img, bar_img


def update_gallery_color(raw_list_state, color_name):
    """Re-colorize every stored raw array for a gallery; returns (gallery, colorbar)."""
    if raw_list_state is None:
        return None, None
    new_gallery = []
    for img_np in raw_list_state:
        new_gallery.append(apply_pseudocolor(img_np, color_name))
    bar_img = generate_colorbar_preview(color_name)
    return new_gallery, bar_img


# --- Event Handler Helper ---
def get_gallery_selection(evt: gr.SelectData):
    """Return the caption of the clicked gallery item (used as the prompt)."""
    return evt.value['caption']


# --- Generation Functions ---
@spaces.GPU(duration=120)
def generate_t2i(prompt, num_inference_steps, num_images, current_color, height=512, width=512):
    """
    Generates multiple images for Text-to-Image and returns a gallery.

    Returns (display images, zip path of raw TIFFs, raw arrays, colorbar image).
    """
    global t2i_pipe
    if t2i_pipe is None:
        raise gr.Error("Text-to-Image model is not loaded.")
    # Each prompt family is served by a specific UNet checkpoint.
    target_model_path = PROMPT_TO_MODEL_MAP.get(prompt, f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/FULL-checkpoint-275000")
    t2i_pipe = swap_t2i_unet(t2i_pipe, target_model_path)
    print(f"\nšŸš€ T2I Task started... | Prompt: '{prompt}' | Count: {num_images} | Size: {height}x{width}")
    generated_raw_list = []
    generated_display_images = []
    generated_raw_files = []
    temp_dir = tempfile.mkdtemp()
    # Generate Batch
    for i in range(int(num_images)):
        # Generate single image (pipeline returns a numpy array in .images)
        image_np = t2i_pipe(
            prompt.lower(),
            generator=None,
            num_inference_steps=int(num_inference_steps),
            output_type="np",
            height=int(height),
            width=int(width),
        ).images
        generated_raw_list.append(image_np)
        # Save raw to temp
        raw_name = f"t2i_sample_{i+1}.tif"
        raw_path = os.path.join(temp_dir, raw_name)
        save_data = image_np.astype(np.float32) if image_np.dtype == np.float16 else image_np
        tifffile.imwrite(raw_path, save_data)
        generated_raw_files.append(raw_path)
        # Create display version
        generated_display_images.append(apply_pseudocolor(image_np, current_color))
        # Save first image to examples if needed
        if SAVE_EXAMPLES and i == 0:
            example_filepath = os.path.join(T2I_EXAMPLE_IMG_DIR, sanitize_prompt_for_filename(prompt))
            if not os.path.exists(example_filepath):
                generated_display_images[0].save(example_filepath)
    # Zip raw files
    zip_filename = os.path.join(temp_dir, "raw_output_images.zip")
    with zipfile.ZipFile(zip_filename, 'w') as zipf:
        for file in generated_raw_files:
            zipf.write(file, os.path.basename(file))
    colorbar_img = generate_colorbar_preview(current_color)
    # Return: Gallery List, Zip Path, Raw State List, Colorbar
    return generated_display_images, zip_filename, generated_raw_list, colorbar_img


@spaces.GPU(duration=120)
def run_mask_to_image_generation(mask_file_obj, cell_type, num_images, steps, seed, current_color):
    """Generate synthetic fluorescence images conditioned on a segmentation mask.

    Returns (input mask display, gallery, zip path, raw arrays, colorbar).
    """
    if controlnet_pipe is None:
        raise gr.Error("ControlNet pipeline is not loaded.")
    if mask_file_obj is None:
        raise gr.Error("Please upload a segmentation mask.")
    pipe = swap_controlnet(controlnet_pipe, M2I_CONTROLNET_PATH)
    try:
        mask_np = tifffile.imread(mask_file_obj.name)
    except Exception as e:
        raise gr.Error(f"Failed to read TIF. Error: {e}")
    input_display = numpy_to_pil(mask_np, "L")
    mask_normalized = min_max_norm(mask_np)
    image_tensor = torch.from_numpy(mask_normalized.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
    image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)
    prompt = f"nuclei of {cell_type.strip()}"
    print(f"\nM2I Task started... | Prompt: '{prompt}'")
    generated_raw_list = []
    generated_display_images = []
    generated_raw_files = []
    temp_dir = tempfile.mkdtemp()
    for i in range(int(num_images)):
        # Offset the seed per sample so each image in the batch differs.
        generator = torch.Generator(device=DEVICE).manual_seed(int(seed) + i)
        with torch.autocast("cuda"):
            output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps),
                             generator=generator, output_type="np").images
        generated_raw_list.append(output_np)
        raw_name = f"m2i_sample_{i+1}.tif"
        raw_path = os.path.join(temp_dir, raw_name)
        save_data = output_np.astype(np.float32) if output_np.dtype == np.float16 else output_np
        tifffile.imwrite(raw_path, save_data)
        generated_raw_files.append(raw_path)
        generated_display_images.append(apply_pseudocolor(output_np, current_color))
    zip_filename = os.path.join(temp_dir, "raw_output_images.zip")
    with zipfile.ZipFile(zip_filename, 'w') as zipf:
        for file in generated_raw_files:
            zipf.write(file, os.path.basename(file))
    colorbar_img = generate_colorbar_preview(current_color)
    return input_display, generated_display_images, zip_filename, generated_raw_list, colorbar_img


@spaces.GPU(duration=120)
def run_super_resolution(low_res_file_obj, controlnet_model_name, prompt, steps, seed, current_color):
    """Run SR on a 9-channel low-res TIF stack.

    Returns (input display, output display, raw tif path, input state,
    output state, colorbar).
    """
    if controlnet_pipe is None:
        raise gr.Error("ControlNet pipeline is not loaded.")
    if low_res_file_obj is None:
        raise gr.Error("Please upload a file.")
    target_path = SR_CONTROLNET_MODELS.get(controlnet_model_name)
    pipe = swap_controlnet(controlnet_pipe, target_path)
    try:
        image_stack_np = tifffile.imread(low_res_file_obj.name)
    except Exception as e:
        raise gr.Error(f"Failed to read TIF. Error: {e}")
    if image_stack_np.ndim != 3 or image_stack_np.shape[-3] != 9:
        raise gr.Error(f"Invalid TIF shape. Expected 9 channels, got {image_stack_np.shape}.")
    # Calculate Average Projection for Input
    avg_projection_np = np.mean(image_stack_np, axis=0)
    # Preprocess for model (input assumed 16-bit, hence the 65535 scaling)
    image_tensor = torch.from_numpy(image_stack_np.astype(np.float32) / 65535.0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
    image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)
    print(f"\nSR Task started...")
    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
    with torch.autocast("cuda"):
        output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps),
                         generator=generator, output_type="np").images
    raw_file_path = save_temp_tiff(output_np, prefix="sr_raw")
    # Generate displays with current color
    input_display = apply_pseudocolor(avg_projection_np, current_color)
    output_display = apply_pseudocolor(output_np, current_color)
    colorbar_img = generate_colorbar_preview(current_color)
    # Return: Input Disp, Output Disp, Download Path, Input State, Output State, Colorbar
    return input_display, output_display, raw_file_path, avg_projection_np, output_np, colorbar_img


@spaces.GPU(duration=120)
def run_denoising(noisy_image_np, image_type, steps, seed, current_color):
    """Denoise a single-channel 8-bit image with the DN ControlNet.

    Returns (input display, output display, raw tif path, input state,
    output state, colorbar).
    """
    if controlnet_pipe is None:
        raise gr.Error("ControlNet pipeline is not loaded.")
    if noisy_image_np is None:
        raise gr.Error("Please upload an image.")
    pipe = swap_controlnet(controlnet_pipe, DN_CONTROLNET_PATH)
    prompt = DN_PROMPT_RULES.get(image_type, 'microscopy image')
    print(f"\nDN Task started...")
    image_tensor = torch.from_numpy(noisy_image_np.astype(np.float32) / 255.0).unsqueeze(0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
    image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)
    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
    with torch.autocast("cuda"):
        output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps),
                         generator=generator, output_type="np").images
    raw_file_path = save_temp_tiff(output_np, prefix="dn_raw")
    # Generate displays with current color
    input_display = apply_pseudocolor(noisy_image_np, current_color)
    output_display = apply_pseudocolor(output_np, current_color)
    colorbar_img = generate_colorbar_preview(current_color)
    # Return: Input Disp, Output Disp, Download Path, Input State, Output State, Colorbar
    return input_display, output_display, raw_file_path, noisy_image_np, output_np, colorbar_img


# --- Segmentation & Classification ---
@spaces.GPU(duration=120)
def run_segmentation(input_image_np, model_name, diameter, flow_threshold, cellprob_threshold):
    """Segment cells with a Cellpose checkpoint and overlay the mask in dark red.

    diameter == 0 lets the model's trained diameter (diam_labels) be used.
    """
    if input_image_np is None:
        raise gr.Error("Please upload an image.")
    model_path = SEG_MODELS.get(model_name)
    print(f"\nSeg Task started...")
    try:
        model = cellpose_models.CellposeModel(gpu=torch.cuda.is_available(), pretrained_model=model_path)
        masks, _, _ = model.eval([input_image_np], channels=[0, 0],
                                 diameter=(model.diam_labels if diameter == 0 else diameter),
                                 flow_threshold=flow_threshold,
                                 cellprob_threshold=cellprob_threshold)
    except Exception as e:
        raise gr.Error(f"Segmentation failed: {e}")
    original_rgb = numpy_to_pil(input_image_np, "RGB")
    red_mask = np.zeros_like(np.array(original_rgb))
    red_mask[masks[0] > 0] = [139, 0, 0]  # dark red overlay on segmented pixels
    blended = ((0.6 * np.array(original_rgb) + 0.4 * red_mask).astype(np.uint8))
    return numpy_to_pil(input_image_np, "L"), numpy_to_pil(blended, "RGB")


@spaces.GPU(duration=120)
def run_classification(input_image_np, model_name):
    """Classify an image into CLS_CLASS_NAMES with a fine-tuned ResNet-50.

    Returns (grayscale display, {class name: probability}).
    """
    if input_image_np is None:
        raise gr.Error("Please upload an image.")
    model_path = os.path.join(CLS_MODEL_PATHS.get(model_name), "best_resnet50.pth")
    print(f"\nCls Task started...")
    try:
        model = models.resnet50(weights=None)
        model.fc = nn.Linear(model.fc.in_features, len(CLS_CLASS_NAMES))
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        model.to(DEVICE).eval()
    except Exception as e:
        raise gr.Error(f"Classification failed: {e}")
    input_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ])(numpy_to_pil(input_image_np, "RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        probs = F.softmax(model(input_tensor), dim=1).squeeze().cpu().numpy()
    return numpy_to_pil(input_image_np, "L"), {name: float(p) for name, p in zip(CLS_CLASS_NAMES, probs)}


# --- 3. Gradio UI Layout ---
print("Building Gradio interface...")
for d in [M2I_EXAMPLE_IMG_DIR, T2I_EXAMPLE_IMG_DIR, SR_EXAMPLE_IMG_DIR,
          DN_EXAMPLE_IMG_DIR, SEG_EXAMPLE_IMG_DIR, CLS_EXAMPLE_IMG_DIR]:
    os.makedirs(d, exist_ok=True)

# Load Examples
filename_to_prompt = {sanitize_prompt_for_filename(p): p for p in T2I_PROMPTS}
t2i_examples = []
for f in os.listdir(T2I_EXAMPLE_IMG_DIR):
    if f in filename_to_prompt:
        t2i_examples.append((os.path.join(T2I_EXAMPLE_IMG_DIR, f), filename_to_prompt[f]))


def load_examples(d, stack=False):
    """Load example images from a directory as (PIL image, filename) pairs.

    With stack=True, 3-D TIF stacks are average-projected to 2-D first.
    """
    ex = []
    if not os.path.exists(d):
        return ex
    for f in sorted(os.listdir(d)):
        if f.lower().endswith(('.tif', '.tiff', '.png', '.jpg')):
            fp = os.path.join(d, f)
            try:
                img = tifffile.imread(fp) if f.endswith(('.tif', '.tiff')) else np.array(Image.open(fp).convert("L"))
                if stack and img.ndim == 3:
                    img = np.mean(img, axis=0)
                ex.append((numpy_to_pil(img, "L"), f))
            except Exception:
                pass  # skip unreadable example files
    return ex


def create_path_selector(base_dir):
    """Build a gallery-select handler that maps the clicked caption to a file path."""
    def handler(evt: gr.SelectData):
        # Join the directory path with the clicked file name (caption).
        return os.path.join(base_dir, evt.value['caption'])
    return handler


m2i_examples = load_examples(M2I_EXAMPLE_IMG_DIR)
sr_examples = load_examples(SR_EXAMPLE_IMG_DIR, stack=True)
dn_examples = load_examples(DN_EXAMPLE_IMG_DIR)
seg_examples = load_examples(SEG_EXAMPLE_IMG_DIR)
cls_examples = load_examples(CLS_EXAMPLE_IMG_DIR)

# --- UI Builders ---
with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    with gr.Row():
        gr.Image(value=LOGO_PATH, width=300, height=200, container=False, interactive=False,
                 show_download_button=False, show_fullscreen_button=False)
gr.Markdown(f"# {MODEL_TITLE}\n{MODEL_DESCRIPTION}") with gr.Tabs(): # --- TAB 1: Text-to-Image --- with gr.Tab("Text-to-Image Generation", id="txt2img"): t2i_raw_state = gr.State(None) # Stores list of arrays gr.Markdown(""" ### Instructions 1. Select a desired prompt from the dropdown menu. 2. Adjust the 'Inference Steps' slider to control generation quality. 3. Click the 'Generate' button to create a new image. 4. Explore the 'Examples' gallery; clicking an image will load its prompt. **Notice:** This model currently supports 3566 prompt categories. However, data for many cell structures and lines is still lacking. We welcome data source contributions to improve the model. """) with gr.Row(variant="panel"): with gr.Column(scale=1, min_width=350): t2i_prompt = gr.Dropdown(choices=T2I_PROMPTS, value=T2I_PROMPTS[0], label="Search or Type a Prompt", filterable=True, allow_custom_value=True) t2i_steps = gr.Slider(10, 200, 50, step=1, label="Inference Steps") # Added: Number of Images Slider t2i_num_images = gr.Slider(1, 9, 3, step=1, label="Number of Images") # with gr.Row(): # t2i_height = gr.Slider(512, 1024, value=512, step=64, label="Height") # t2i_width = gr.Slider(512, 1024, value=512, step=64, label="Width") t2i_btn = gr.Button("Generate", variant="primary") with gr.Column(scale=2): # Changed: Image to Gallery t2i_gallery_out = gr.Gallery(label="Generated Images", columns=3, height="auto", show_download_button=False, show_share_button=False) with gr.Row(equal_height=True): with gr.Column(scale=2): t2i_color = gr.Dropdown(choices=COLOR_MAPS, value="Green (GFP)", label="Pseudocolor (Adjust after generation)") with gr.Column(scale=2): t2i_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False) with gr.Column(scale=1): t2i_dl = gr.DownloadButton(label="Download All (.zip)") t2i_gal = gr.Gallery(value=t2i_examples, label="Examples", columns=6, height="auto") t2i_btn.click(generate_t2i, 
[t2i_prompt, t2i_steps, t2i_num_images, t2i_color], [t2i_gallery_out, t2i_dl, t2i_raw_state, t2i_colorbar]) # t2i_btn.click( # generate_t2i, # inputs=[t2i_prompt, t2i_steps, t2i_num_images, t2i_color, t2i_height, t2i_width], # outputs=[t2i_gallery_out, t2i_dl, t2i_raw_state, t2i_colorbar] # ) # Reuse update_gallery_color since state is now a list t2i_color.change(update_gallery_color, [t2i_raw_state, t2i_color], [t2i_gallery_out, t2i_colorbar]) t2i_gal.select(fn=get_gallery_selection, inputs=None, outputs=t2i_prompt) # --- TAB 2: Super-Resolution --- with gr.Tab("Super-Resolution", id="super_res"): # Stores: Input (Average Projection) and Output sr_input_state = gr.State(None) sr_raw_state = gr.State(None) gr.Markdown(""" ### Instructions 1. Upload a low-resolution 9-channel TIF stack, or select one from the examples. 2. Select a 'Super-Resolution Model' from the dropdown. 3. Enter a descriptive 'Prompt' related to the image content (e.g., 'CCPs of COS-7'). 4. Adjust 'Inference Steps' and 'Seed' as needed. 5. Click 'Generate Super-Resolution' to process the image. **Notice:** This model was trained on the **BioSR** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results. 
""") with gr.Row(variant="panel"): with gr.Column(scale=1, min_width=350): sr_file = gr.File(label="Upload 9-Channel TIF Stack", file_types=['.tif', '.tiff']) sr_model = gr.Dropdown(choices=list(SR_CONTROLNET_MODELS.keys()), value=list(SR_CONTROLNET_MODELS.keys())[-1], label="Model") sr_prompt = gr.Textbox(label="Prompt", value="F-actin of COS-7", interactive=False) sr_steps = gr.Slider(5, 50, 10, step=1, label="Steps") sr_seed = gr.Number(label="Seed", value=42) sr_btn = gr.Button("Generate", variant="primary") with gr.Column(scale=2): with gr.Row(): # Input Display now shows Pseudocolor sr_in_disp = gr.Image(label="Input (Avg Projection)", type="pil", interactive=False, show_download_button=False) sr_out_disp = gr.Image(label="Output", type="pil", interactive=False, show_download_button=False) with gr.Row(equal_height=True): with gr.Column(scale=2): sr_color = gr.Dropdown(choices=COLOR_MAPS, value="Red Hot", label="Pseudocolor") with gr.Column(scale=2): sr_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False) with gr.Column(scale=1): sr_dl = gr.DownloadButton(label="Download Raw Output (.tif)") sr_gal = gr.Gallery(value=sr_examples, label="Examples", columns=6, height="auto") # sr_model.change(update_sr_prompt, sr_model, sr_prompt) sr_model.change(update_sr_settings, inputs=sr_model, outputs=[sr_prompt, sr_color]) # Run returns both input/output states and displays sr_btn.click(run_super_resolution, [sr_file, sr_model, sr_prompt, sr_steps, sr_seed, sr_color], [sr_in_disp, sr_out_disp, sr_dl, sr_input_state, sr_raw_state, sr_colorbar]) # Change color updates both displays sr_color.change(update_pair_color, [sr_input_state, sr_raw_state, sr_color], [sr_in_disp, sr_out_disp, sr_colorbar]) sr_gal.select(fn=create_path_selector(SR_EXAMPLE_IMG_DIR), inputs=None, outputs=sr_file) # --- TAB 3: Denoising --- with gr.Tab("Denoising", id="denoising"): # Stores: Input (Noisy) and Output 
# --- TAB 3 (continued): Denoising ---
# Hidden states keep the raw (scientific) input/output arrays so the
# pseudocolor dropdown can re-render both panes without re-running inference.
dn_input_state = gr.State(None)
dn_raw_state = gr.State(None)
gr.Markdown("""
### Instructions
1. Upload a noisy single-channel image, or select one from the examples.
2. Select the 'Image Type' from the dropdown to provide context for the model.
3. Adjust 'Inference Steps' and 'Seed' as needed.
4. Click 'Denoise Image' to reduce the noise.

**Notice:** This model was trained on the **FMD** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results.
""")
with gr.Row(variant="panel"):
    with gr.Column(scale=1, min_width=350):
        dn_img = gr.Image(type="numpy", label="Upload Noisy Image", image_mode="L")
        dn_type = gr.Dropdown(choices=list(DN_PROMPT_RULES.keys()), value='MICE', label="Image Type")
        dn_steps = gr.Slider(5, 50, 10, step=1, label="Steps")
        dn_seed = gr.Number(label="Seed", value=42)
        dn_btn = gr.Button("Denoise", variant="primary")
    with gr.Column(scale=2):
        with gr.Row():
            # Input display also shows the selected pseudocolor, for side-by-side comparison.
            dn_orig = gr.Image(label="Original", type="pil", interactive=False, show_download_button=False)
            dn_out = gr.Image(label="Denoised", type="pil", interactive=False, show_download_button=False)
        with gr.Row(equal_height=True):
            with gr.Column(scale=2):
                dn_color = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor")
            with gr.Column(scale=2):
                dn_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False)
            with gr.Column(scale=1):
                dn_dl = gr.DownloadButton(label="Download Raw Output (.tif)")
        dn_gal = gr.Gallery(value=dn_examples, label="Examples", columns=6, height="auto")
# Wiring: run inference, live recolor of both panes, and example selection.
dn_btn.click(run_denoising,
             [dn_img, dn_type, dn_steps, dn_seed, dn_color],
             [dn_orig, dn_out, dn_dl, dn_input_state, dn_raw_state, dn_colorbar])
dn_color.change(update_pair_color,
                [dn_input_state, dn_raw_state, dn_color],
                [dn_orig, dn_out, dn_colorbar])
dn_gal.select(fn=create_path_selector(DN_EXAMPLE_IMG_DIR), inputs=None, outputs=dn_img)

# --- TAB 4: Mask-to-Image ---
with gr.Tab("Mask-to-Image", id="mask2img"):
    # Raw generated samples are kept so the pseudocolor dropdown can
    # re-render the output gallery without regenerating.
    m2i_raw_state = gr.State(None)
    gr.Markdown("""
### Instructions
1. Upload a single-channel segmentation mask (`.tif` file), or select one from the examples gallery below.
2. Enter the corresponding 'Cell Type' (e.g., 'CoNSS', 'HeLa') to create the prompt.
3. Select how many sample images you want to generate.
4. Adjust 'Inference Steps' and 'Seed' as needed.
5. Click 'Generate Training Samples' to start the process.
6. The 'Generated Samples' will appear in the main gallery, with the 'Input Mask' shown below for reference.

**Notice:** This model was trained on the **2018 Data Science Bowl** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results.
""")
    with gr.Row(variant="panel"):
        with gr.Column(scale=1, min_width=350):
            m2i_file = gr.File(label="Upload Mask (.tif)", file_types=['.tif', '.tiff'])
            # Default cell type is HeLa.
            m2i_type = gr.Textbox(label="Cell Type", value="HeLa", placeholder="e.g., HeLa")
            m2i_num = gr.Slider(1, 10, 5, step=1, label="Count")
            m2i_steps = gr.Slider(5, 50, 10, step=1, label="Steps")
            m2i_seed = gr.Number(label="Seed", value=42)
            m2i_btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=2):
            m2i_gal_out = gr.Gallery(label="Generated Samples", columns=5, height="auto")
            with gr.Row(equal_height=True):
                with gr.Column(scale=2):
                    m2i_color = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor")
                with gr.Column(scale=2):
                    m2i_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False)
                with gr.Column(scale=1):
                    m2i_dl = gr.DownloadButton(label="Download ZIP")
            m2i_in_disp = gr.Image(label="Input Mask", type="pil", interactive=False, show_download_button=False)
            m2i_gal = gr.Gallery(value=m2i_examples, label="Examples", columns=6, height="auto")
    m2i_btn.click(run_mask_to_image_generation,
                  [m2i_file, m2i_type, m2i_num, m2i_steps, m2i_seed, m2i_color],
                  [m2i_in_disp, m2i_gal_out, m2i_dl, m2i_raw_state, m2i_colorbar])
    m2i_color.change(update_gallery_color,
                     [m2i_raw_state, m2i_color],
                     [m2i_gal_out, m2i_colorbar])
    m2i_gal.select(fn=create_path_selector(M2I_EXAMPLE_IMG_DIR), inputs=None, outputs=m2i_file)

# --- TAB 5: Cell Segmentation ---
with gr.Tab("Cell Segmentation", id="segmentation"):
    gr.Markdown("""
### Instructions
1. Upload a single-channel image for segmentation, or select one from the examples.
2. Select a 'Segmentation Model' from the dropdown menu.
3. Set the expected 'Diameter' of the cells in pixels. Set to 0 to let the model automatically estimate it.
4. Adjust 'Flow Threshold' and 'Cell Probability Threshold' for finer control.
5. Click 'Segment Cells'. The result will be shown as a dark red overlay on the original image.

**Notice:** This model was trained on the **2018 Data Science Bowl** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results.
""")
    with gr.Row(variant="panel"):
        with gr.Column(scale=1, min_width=350):
            seg_img = gr.Image(type="numpy", label="Upload Image", image_mode="L")
            seg_model = gr.Dropdown(choices=list(SEG_MODELS.keys()), value=list(SEG_MODELS.keys())[0], label="Model")
            # Diameter of 0 lets the segmentation backend auto-estimate cell size.
            seg_diam = gr.Number(label="Diameter (0=auto)", value=30)
            seg_flow = gr.Slider(0.0, 3.0, 0.4, step=0.1, label="Flow Thresh")
            seg_prob = gr.Slider(-6.0, 6.0, 0.0, step=0.5, label="Prob Thresh")
            seg_btn = gr.Button("Segment", variant="primary")
        with gr.Column(scale=2):
            with gr.Row():
                seg_orig = gr.Image(label="Original", type="pil", interactive=False, show_download_button=False)
                seg_out = gr.Image(label="Overlay", type="pil", interactive=False, show_download_button=False)
            seg_gal = gr.Gallery(value=seg_examples, label="Examples", columns=6, height="auto")
    seg_btn.click(run_segmentation,
                  [seg_img, seg_model, seg_diam, seg_flow, seg_prob],
                  [seg_orig, seg_out])
    seg_gal.select(fn=create_path_selector(SEG_EXAMPLE_IMG_DIR), inputs=None, outputs=seg_img)

# --- TAB 6: Classification ---
with gr.Tab("Classification", id="classification"):
    gr.Markdown("""
### Instructions
1. Upload a single-channel image for classification, or select an example.
2. Select a pre-trained 'Classification Model' from the dropdown menu.
3. Click 'Classify Image' to view the prediction probabilities for each class.

**Note:** The models provided are ResNet50 trained on the 2D HeLa dataset.
""")
    with gr.Row(variant="panel"):
        with gr.Column(scale=1, min_width=350):
            cls_img = gr.Image(type="numpy", label="Upload Image", image_mode="L")
            cls_model = gr.Dropdown(choices=list(CLS_MODEL_PATHS.keys()), value=list(CLS_MODEL_PATHS.keys())[0], label="Model")
            cls_btn = gr.Button("Classify", variant="primary")
        with gr.Column(scale=2):
            cls_orig = gr.Image(label="Input", type="pil", interactive=False, show_download_button=False)
            cls_res = gr.Label(label="Results", num_top_classes=10)
            cls_gal = gr.Gallery(value=cls_examples, label="Examples", columns=6, height="auto")
    cls_btn.click(run_classification, [cls_img, cls_model], [cls_orig, cls_res])
    cls_gal.select(fn=create_path_selector(CLS_EXAMPLE_IMG_DIR), inputs=None, outputs=cls_img)

# Script entry point: launch the Gradio app only when run directly.
if __name__ == "__main__":
    demo.launch()