Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| from PIL import Image | |
| import os | |
| import json | |
| import tifffile | |
| import re | |
| from torchvision import transforms, models | |
| import shutil | |
| import time | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import spaces | |
| from collections import OrderedDict | |
| import tempfile | |
| import zipfile | |
| import matplotlib.cm as cm | |
| import matplotlib.pyplot as plt | |
| import io | |
| # --- Imports from both scripts --- | |
| from diffusers import DDPMScheduler, DDIMScheduler | |
| from transformers import CLIPTextModel, CLIPTokenizer | |
| from transformers.utils import ContextManagers | |
| # --- Custom Model Imports --- | |
| from models.pipeline_ddpm_text_encoder import DDPMPipeline | |
| from models.unet_2d import UNet2DModel | |
| from models.controlnet import ControlNetModel | |
| from models.unet_2d_condition import UNet2DConditionModel | |
| from models.pipeline_controlnet import DDPMControlnetPipeline | |
| # --- Segmentation Imports --- | |
| from cellpose import models as cellpose_models | |
| from huggingface_hub import snapshot_download | |
# --- 0. Configuration & Constants ---
# Optional Hugging Face token, passed to snapshot_download below
# (presumably only required if the checkpoint repo is private/gated).
hf_token = os.environ.get("HF_TOKEN")
MODEL_TITLE = "🔬 FluoGen: AI-Powered Fluorescence Microscopy Suite"
# Markdown rendered next to the logo at the top of the Gradio page.
MODEL_DESCRIPTION = """
**Paper**: [*FluoGen: An Open-Source Generative Foundation Model for Fluorescence Microscopy Image Enhancement and Analysis*](https://doi.org/10.21203/rs.3.rs-8334792/v1)
<br>
**Homepage**: [Homepage Website](https://fluogen-group.github.io/FluoGen-HomePage/)
<br>
**Code**: [GitHub Repository](https://github.com/FluoGen-Group/FluoGen)
<br>
Select a task below.
**Note**: The "Pseudocolor" option instantly applies to both **Input** and **Output** images for better comparison. Use the "Download Raw Output" button to get the scientific 16-bit/Float data.
"""
# Use the first GPU when available; half precision is only used on GPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
WEIGHT_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
LOGO_PATH = "utils/logo2_transparent.png"
# When True, generate_t2i saves the first image generated for a prompt into
# T2I_EXAMPLE_IMG_DIR (only if no example file exists for that prompt yet).
SAVE_EXAMPLES = False
# --- Global CSS: sans-serif font stack (Arial/Helvetica/YaHei) and hidden
# built-in download buttons/links ---
CUSTOM_CSS = """
.gradio-container, .gradio-container * {
font-family: 'Arial', 'Helvetica', 'Microsoft YaHei', '微软雅黑', sans-serif !important;
}
button[aria-label="Download"],
a[download] {
display: none !important;
}
"""
# --- Base directory for all models ---
# The whole checkpoint repository is fetched at import time (blocking network
# call); every *_PATH below points into this local snapshot.
REPO_ID = "FluoGen-Group/FluoGen-demo-test-ckpts"
MODELS_ROOT_DIR = snapshot_download(repo_id=REPO_ID, token=hf_token)
# --- Paths Config ---
# Mask-to-image ControlNet; the shared ControlNet pipeline is created with it.
M2I_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_M2I/checkpoint-30000"
M2I_EXAMPLE_IMG_DIR = "example_images_m2i"
T2I_EXAMPLE_IMG_DIR = "example_images"
# SD 1.5 snapshot: only its "tokenizer" and "text_encoder" sub-folders are loaded.
T2I_PRETRAINED_MODEL_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5"
# Initial T2I UNet; generate_t2i may swap it per prompt (see load_all_prompts).
T2I_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"
CONTROLNET_CLIP_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5"
CONTROLNET_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"
# Dropdown label -> ControlNet checkpoint for each super-resolution structure.
SR_CONTROLNET_MODELS = {
"Checkpoint ER": f"{MODELS_ROOT_DIR}/ControlNet_SR/ER/checkpoint-30000",
"Checkpoint Microtubules": f"{MODELS_ROOT_DIR}/ControlNet_SR/Microtubules/checkpoint-72500",
"Checkpoint CCPs": f"{MODELS_ROOT_DIR}/ControlNet_SR/CCPs/checkpoint-100000",
"Checkpoint F-actin": f"{MODELS_ROOT_DIR}/ControlNet_SR/F-actin/checkpoint-35000",
}
SR_EXAMPLE_IMG_DIR = "example_images_sr"
DN_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_DN/checkpoint-10000"
# Denoising image-type key -> text prompt; run_denoising uses
# 'microscopy image' for keys not listed here.
DN_PROMPT_RULES = {'MICE': 'mouse brain tissues', 'FISH': 'zebrafish embryos', 'BPAE_B': 'nucleus of BPAE', 'BPAE_R': 'mitochondria of BPAE', 'BPAE_G': 'F-actin of BPAE'}
DN_EXAMPLE_IMG_DIR = "example_images_dn"
# Dropdown label -> Cellpose pretrained model file.
# Commented-out entries are alternative checkpoints not offered in the UI.
SEG_MODELS = {
#"DynamicNet Model": f"{MODELS_ROOT_DIR}/Cellpose/DynamicNet_baseline/CP_dynamic_ten_epoch_0100",
#"DynamicNet Model + FluoGen": f"{MODELS_ROOT_DIR}/Cellpose/DynamicNet_FluoGen/CP_dynamic_epoch_0300",
#"DSB Model": f"{MODELS_ROOT_DIR}/Cellpose/DSB_baseline/CP_dsb_baseline_ratio_1_epoch_0135",
"Cellpose Augmented by FluoGen": f"{MODELS_ROOT_DIR}/Cellpose/DSB_FluoGen/CP_dsb_ten_epoch_0135",
}
SEG_EXAMPLE_IMG_DIR = "example_images_seg"
# Dropdown label -> directory that holds "best_resnet50.pth".
CLS_MODEL_PATHS = OrderedDict({
#"5shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot",
"ResNet-50 Augmented by FluoGen": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot_aug",
})
# Classifier output order: logit i is reported under CLS_CLASS_NAMES[i].
CLS_CLASS_NAMES = ['Nucleus', 'Endoplasmic Reticulum', 'Giantin', 'GPP130', 'Lysosomes', 'Mitochondria', 'Nucleolus', 'Actin', 'Endosomes', 'Microtubules']
CLS_EXAMPLE_IMG_DIR = "example_images_cls"
# --- Constants for Visualization ---
# Choices of the "Pseudocolor" dropdowns. apply_pseudocolor builds the
# single-colour and "... Hot" maps by hand; other names are looked up as
# matplotlib colormaps ("Fire" is mapped to matplotlib's "gnuplot2").
COLOR_MAPS = [
"Grayscale",
"Green (GFP)",
"Red (RFP)",
"Blue (DAPI)",
"Magenta",
"Cyan",
"Yellow",
"Fire",
"Viridis",
"Inferno",
"Magma",
"Plasma",
"Red Hot",
"Cyan Hot",
"Magenta Hot"
]
# Piecewise-linear "Cyan Hot" lookup table:
# (position in [0, 1], (R, G, B) with 0-255 channel values), positions increasing.
CYAN_HOT_POINTS = [
(0.00, (0, 0, 0)),
(0.20, (0, 69, 143)),
(0.50, (0, 183, 255)),
(0.80, (89, 255, 255)),
(1.00, (255, 255, 255)),
]
def get_rgb_interpolation(val_norm, rgb_points):
    """Map normalized intensities through a piecewise-linear RGB lookup table.

    Args:
        val_norm: Scalar or array of intensities, normally in [0, 1]. Values
            outside the table's positions are clamped to the end colours
            (np.interp behaviour).
        rgb_points: Sequence of ``(position, (R, G, B))`` entries with 0-255
            channel values and increasing positions.

    Returns:
        Tuple ``(r, g, b)``; each element has the shape of ``val_norm`` and
        lies in [0, 1].
    """
    positions = [point[0] for point in rgb_points]
    channels = []
    for channel_idx in range(3):
        # Rescale the 8-bit table entries to the 0-1 float range.
        levels = [point[1][channel_idx] / 255.0 for point in rgb_points]
        channels.append(np.interp(val_norm, positions, levels))
    return tuple(channels)
# --- Helper Functions ---
def sanitize_prompt_for_filename(prompt):
    """Derive the example-image PNG filename for a text prompt.

    The prompt is lower-cased, every whitespace-delimited "of" becomes "_",
    and any character outside ``[a-z0-9-_]`` is removed.

    >>> sanitize_prompt_for_filename("F-actin of COS-7")
    'f-actin_cos-7.png'
    """
    stem = re.sub(r'\s+of\s+', '_', prompt.lower())
    stem = re.sub(r'[^a-z0-9-_]+', '', stem)
    return stem + ".png"
def min_max_norm(x):
    """Linearly rescale an array to the [0, 1] range as float32.

    Arrays whose value range is below 1e-8 (effectively constant) map to an
    all-zero float32 array of the same shape.
    """
    data = x.astype(np.float32)
    lowest = np.min(data)
    span = np.max(data) - lowest
    if span < 1e-8:
        return np.zeros_like(data)
    return (data - lowest) / span
def generate_colorbar_preview(color_name):
    """Render a 256x30 RGB strip previewing the colormap ``color_name``.

    A left-to-right 0 -> 1 intensity ramp is coloured by ``apply_pseudocolor``
    itself. The preview therefore always matches the colouring applied to the
    displayed images, including apply_pseudocolor's grayscale fallback for
    names it cannot resolve. Previously this function kept its own copy of
    every colormap.

    Args:
        color_name: One of COLOR_MAPS (or any name apply_pseudocolor accepts).

    Returns:
        PIL.Image.Image in RGB mode, 256 px wide and 30 px tall.
    """
    # Two identical rows keep the ramp 2-D after the np.squeeze inside
    # apply_pseudocolor, which unpacks the array as (height, width).
    ramp = np.tile(np.linspace(0.0, 1.0, 256, dtype=np.float32), (2, 1))
    return apply_pseudocolor(ramp, color_name).resize((256, 30))
def _get_mpl_colormap(name):
    """Return the matplotlib colormap called ``name`` on old and new Matplotlib.

    ``matplotlib.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in
    3.9, so the ``matplotlib.colormaps`` registry (3.5+) is preferred. The old
    function is only used when the registry does not exist.

    Raises:
        KeyError or ValueError: if no colormap with that name is registered.
    """
    import matplotlib
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None:
        return registry[name]
    return cm.get_cmap(name)


def apply_pseudocolor(image_np, color_name="Grayscale"):
    """
    Applies a pseudocolor to a single channel numpy image.

    The image is squeezed and min-max stretched to [0, 1] first, so each image
    always uses its full display range.

    Args:
        image_np: Image array; after ``np.squeeze`` it must be 2-D
            (height, width). ``None`` is passed through.
        color_name: "Grayscale", one of the hand-built maps below, or a
            matplotlib colormap name ("Fire" maps to "gnuplot2"). Names that
            matplotlib does not know fall back to grayscale.

    Returns: PIL Image in RGB (or None when ``image_np`` is None).
    """
    if image_np is None:
        return None
    # Normalize to 0-1 for processing
    norm_img = min_max_norm(np.squeeze(image_np))
    if color_name == "Grayscale":
        return Image.fromarray((norm_img * 255).astype(np.uint8)).convert("RGB")
    h, w = norm_img.shape
    rgb = np.zeros((h, w, 3), dtype=np.float32)
    if "Hot" in color_name:
        if color_name == "Magenta Hot":
            # Red+blue ramp up over the lower half, green over the upper half:
            # black -> magenta -> white.
            low_half = np.clip(norm_img * 2, 0, 1)
            high_half = np.clip((norm_img - 0.5) * 2, 0, 1)
            rgb[..., 0] = low_half
            rgb[..., 1] = high_half
            rgb[..., 2] = low_half
        elif color_name == "Red Hot":
            # R, then G, then B saturate in thirds: black -> red -> yellow -> white.
            rgb[..., 0] = np.clip(norm_img * 3, 0, 1)
            rgb[..., 1] = np.clip((norm_img - 0.333) * 3, 0, 1)
            rgb[..., 2] = np.clip((norm_img - 0.666) * 3, 0, 1)
        elif color_name == "Cyan Hot":
            r, g, b = get_rgb_interpolation(norm_img, CYAN_HOT_POINTS)
            rgb = np.stack([r, g, b], axis=-1)
    elif color_name == "Green (GFP)": rgb[..., 1] = norm_img
    elif color_name == "Red (RFP)": rgb[..., 0] = norm_img
    elif color_name == "Blue (DAPI)": rgb[..., 2] = norm_img
    elif color_name == "Magenta": rgb[..., 0] = norm_img; rgb[..., 2] = norm_img
    elif color_name == "Cyan": rgb[..., 1] = norm_img; rgb[..., 2] = norm_img
    elif color_name == "Yellow": rgb[..., 0] = norm_img; rgb[..., 1] = norm_img
    else:
        # Matplotlib maps
        mpl_map_name = "gnuplot2" if color_name == "Fire" else color_name.lower()
        try:
            cmap = _get_mpl_colormap(mpl_map_name)
        except (KeyError, ValueError):
            # Unknown colormap name: show plain grayscale instead of failing.
            return apply_pseudocolor(image_np, "Grayscale")
        rgb = cmap(norm_img)[..., :3]
    return Image.fromarray((rgb * 255).astype(np.uint8))
def update_sr_settings(model_name):
    """Return the ``(prompt, pseudocolor)`` preset for a super-resolution checkpoint.

    - Microtubules -> Cyan Hot
    - F-actin -> Red Hot
    - CCPs -> Green (GFP)
    - ER -> Magenta Hot

    Any other name yields ``("", "Grayscale")``.
    """
    presets = (
        ("Checkpoint ER", ("ER of COS-7", "Magenta Hot")),
        ("Checkpoint Microtubules", ("Microtubules of COS-7", "Cyan Hot")),
        ("Checkpoint CCPs", ("CCPs of COS-7", "Green (GFP)")),
        ("Checkpoint F-actin", ("F-actin of COS-7", "Red Hot")),
    )
    for checkpoint_label, preset in presets:
        if model_name == checkpoint_label:
            return preset
    return "", "Grayscale"
def save_temp_tiff(image_np, prefix="output"):
    """Write ``image_np`` to a new, persistent temporary ``.tif`` file.

    float16 data is widened to float32 before saving (other dtypes are written
    unchanged). The file is not deleted automatically; it is meant to be
    handed to a download component.

    Args:
        image_np: Array to save.
        prefix: Filename prefix; the file is named ``<prefix>_<random>.tif``.

    Returns:
        Path of the written file.
    """
    # Reserve a unique name and close the handle right away: keeping it open
    # leaked a file descriptor per call, and on Windows an open
    # NamedTemporaryFile cannot be reopened by tifffile for writing.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".tif", prefix=f"{prefix}_") as tfile:
        path = tfile.name
    save_data = image_np.astype(np.float32) if image_np.dtype == np.float16 else image_np
    tifffile.imwrite(path, save_data)
    return path
def numpy_to_pil(image_np, target_mode="RGB"):
    """Convert an array (or an existing PIL image) to a PIL image in ``target_mode``.

    PIL inputs are only mode-converted. Arrays are squeezed first; uint8 data
    is used as-is, anything else is min-max rescaled to 0-255. Only the
    "RGB" and "L" target modes trigger a conversion; any other value leaves
    the image's mode unchanged.
    """
    if not isinstance(image_np, Image.Image):
        pixels = np.squeeze(image_np)
        if pixels.dtype != np.uint8:
            pixels = (min_max_norm(pixels) * 255).astype(np.uint8)
        image_np = Image.fromarray(pixels)
    if target_mode in ("RGB", "L") and image_np.mode != target_mode:
        return image_np.convert(target_mode)
    return image_np
def update_sr_prompt(model_name):
    """Return the default text prompt for a super-resolution checkpoint.

    Delegates to ``update_sr_settings`` so the checkpoint -> prompt table
    lives in one place (the two functions previously kept identical copies).
    Unknown names yield "".
    """
    prompt, _color = update_sr_settings(model_name)
    return prompt
# Prompt -> UNet checkpoint directory used by generate_t2i.
PROMPT_TO_MODEL_MAP = {}
# Normalized path of the UNet currently inside t2i_pipe (maintained by swap_t2i_unet).
current_t2i_unet_path = None


def load_all_prompts():
    """Collect the T2I prompts and record which UNet each prompt is served by.

    Each prompt file holds a JSON list of prompt strings. Every prompt from a
    file is routed to that file's UNet checkpoint via PROMPT_TO_MODEL_MAP; a
    prompt listed in several files keeps the checkpoint of the last file.
    Missing files are skipped silently. Unreadable or malformed files and
    non-string entries are reported with a warning and skipped.

    Returns:
        Prompts in file order with duplicates removed, or a two-prompt
        fallback list if no prompt could be loaded.
    """
    categories = [
        {"file": "prompts/basic_prompts.json", "model": f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"},
        {"file": "prompts/others_prompts.json", "model": f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/FULL-checkpoint-275000"},
        {"file": "prompts/hpa_prompts.json", "model": f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/HPA-checkpoint-40000"}
    ]
    combined_prompts = []
    for cat in categories:
        path = cat["file"]
        if not os.path.exists(path):
            continue
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:  # JSONDecodeError/UnicodeDecodeError are ValueErrors
            print(f"WARNING: could not read prompt file {path}: {e}")
            continue
        if not isinstance(data, list):
            print(f"WARNING: prompt file {path} does not contain a JSON list; skipped.")
            continue
        prompts = [p for p in data if isinstance(p, str)]
        if len(prompts) != len(data):
            print(f"WARNING: ignored {len(data) - len(prompts)} non-string entries in {path}.")
        combined_prompts.extend(prompts)
        for p in prompts:
            PROMPT_TO_MODEL_MAP[p] = cat["model"]
    if not combined_prompts:
        return ["F-actin of COS-7", "ER of COS-7"]
    # Keep first-seen order but list each prompt only once in the dropdown.
    return list(dict.fromkeys(combined_prompts))


T2I_PROMPTS = load_all_prompts()
# --- 1. Model Loading ---
print("--- Initializing FluoGen Application ---")
# Both pipelines stay None if loading fails before they are constructed; the
# task functions check for None and raise a gr.Error instead of running.
t2i_pipe, controlnet_pipe = None, None
try:
    print("Loading Text-to-Image model...")
    # v-prediction DDPM scheduler with zero-terminal-SNR beta rescaling and
    # "trailing" timestep spacing.
    # NOTE(review): presumably mirrors the T2I UNet's training config — confirm before changing.
    t2i_noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="v_prediction", rescale_betas_zero_snr=True, timestep_spacing="trailing")
    # Loaded without a WEIGHT_DTYPE cast (unlike the ControlNet pipeline below).
    t2i_unet = UNet2DModel.from_pretrained(T2I_UNET_PATH, subfolder="unet")
    # Only the CLIP text encoder and tokenizer come from the SD 1.5 snapshot.
    t2i_text_encoder = CLIPTextModel.from_pretrained(T2I_PRETRAINED_MODEL_PATH, subfolder="text_encoder").to(DEVICE)
    t2i_tokenizer = CLIPTokenizer.from_pretrained(T2I_PRETRAINED_MODEL_PATH, subfolder="tokenizer")
    t2i_pipe = DDPMPipeline(unet=t2i_unet, scheduler=t2i_noise_scheduler, text_encoder=t2i_text_encoder, tokenizer=t2i_tokenizer)
    t2i_pipe.to(DEVICE)
    # Record the live UNet so swap_t2i_unet can skip redundant reloads.
    current_t2i_unet_path = T2I_UNET_PATH
    print("✓ Text-to-Image model loaded successfully!")
except Exception as e:
    # Despite the "FATAL" label, start-up continues; the T2I task then reports
    # that the model is not loaded (if t2i_pipe was never assigned).
    print(f"FATAL: Text-to-Image Model Loading Failed: {e}")
try:
    print("Loading shared ControlNet pipeline...")
    # One pipeline serves mask-to-image, super-resolution and denoising;
    # swap_controlnet replaces only its ControlNet per task.
    controlnet_unet = UNet2DConditionModel.from_pretrained(CONTROLNET_UNET_PATH, subfolder="unet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
    # Starts with the mask-to-image ControlNet.
    controlnet_controlnet = ControlNetModel.from_pretrained(M2I_CONTROLNET_PATH, subfolder="controlnet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
    # DDIM, v-prediction; note rescale_betas_zero_snr=False here vs. True for T2I.
    controlnet_scheduler = DDIMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="v_prediction", rescale_betas_zero_snr=False, timestep_spacing="trailing")
    controlnet_tokenizer = CLIPTokenizer.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="tokenizer")
    # ContextManagers is given an empty list, so this wrapper should be a no-op.
    with ContextManagers([]):
        controlnet_text_encoder = CLIPTextModel.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="text_encoder").to(dtype=WEIGHT_DTYPE, device=DEVICE)
    controlnet_pipe = DDPMControlnetPipeline(unet=controlnet_unet, controlnet=controlnet_controlnet, scheduler=controlnet_scheduler, text_encoder=controlnet_text_encoder, tokenizer=controlnet_tokenizer)
    controlnet_pipe.to(dtype=WEIGHT_DTYPE, device=DEVICE)
    # Ad-hoc attribute read by swap_controlnet to detect whether a reload is needed.
    controlnet_pipe.current_controlnet_path = M2I_CONTROLNET_PATH
    print("✓ Shared ControlNet pipeline loaded successfully!")
except Exception as e:
    print(f"FATAL: ControlNet Pipeline Loading Failed: {e}")
# --- 2. Core Logic Functions ---
def swap_controlnet(pipe, target_path):
    """Make ``pipe`` use the ControlNet checkpoint at ``target_path``.

    The checkpoint is only loaded when it differs from the one recorded in
    ``pipe.current_controlnet_path``; paths are compared after
    ``os.path.normpath``. On a failed load the current ControlNet is kept.

    Args:
        pipe: The shared ControlNet pipeline (modified in place).
        target_path: Checkpoint directory containing a "controlnet" sub-folder.

    Returns:
        The same ``pipe`` object.

    Raises:
        gr.Error: if ``target_path`` is empty/None (e.g. the caller looked up an
            unknown model name) or the checkpoint cannot be loaded.
    """
    if not target_path:
        raise gr.Error("No ControlNet checkpoint is configured for the selected model.")
    # `or ''` also covers an attribute explicitly set to None.
    current_path = getattr(pipe, 'current_controlnet_path', '') or ''
    if os.path.normpath(current_path) == os.path.normpath(target_path):
        return pipe
    print(f"Swapping ControlNet model to: {target_path}")
    try:
        new_controlnet = ControlNetModel.from_pretrained(target_path, subfolder="controlnet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
    except Exception as e:
        raise gr.Error(f"Failed to load ControlNet model. Error: {e}") from e
    pipe.controlnet = new_controlnet
    pipe.current_controlnet_path = target_path
    return pipe
def swap_t2i_unet(pipe, target_unet_path):
    """Make the text-to-image pipeline use the UNet at ``target_unet_path``.

    The module-level ``current_t2i_unet_path`` records the active checkpoint
    (stored normalized). Nothing is loaded when it already matches. If the
    load fails, the pipeline and the recorded path stay unchanged and a
    gr.Error is raised.

    Returns:
        The same ``pipe`` object.
    """
    global current_t2i_unet_path
    wanted_path = os.path.normpath(target_unet_path)
    already_active = (
        current_t2i_unet_path is not None
        and os.path.normpath(current_t2i_unet_path) == wanted_path
    )
    if already_active:
        return pipe
    print(f"🔄 Swapping T2I UNet to: {wanted_path}")
    try:
        pipe.unet = UNet2DModel.from_pretrained(wanted_path, subfolder="unet").to(DEVICE)
        current_t2i_unet_path = wanted_path
    except Exception as e:
        raise gr.Error(f"Failed to load UNet. Error: {e}")
    return pipe
# --- Dynamic Color Update Functions ---
def update_single_image_color(raw_np_state, color_name):
    """Re-render one stored raw image and the colorbar in ``color_name``.

    Returns ``(None, None)`` while no image has been produced yet.
    """
    if raw_np_state is not None:
        recolored = apply_pseudocolor(raw_np_state, color_name)
        return recolored, generate_colorbar_preview(color_name)
    return None, None
def update_pair_color(input_np_state, output_np_state, color_name):
    """Updates both input and output images with the selected pseudocolor.

    Either state may still be None (nothing to show yet) and is then passed
    through as None; the colorbar is always regenerated.

    Returns:
        ``(input_image, output_image, colorbar_image)``
    """
    recolored = [
        None if state is None else apply_pseudocolor(state, color_name)
        for state in (input_np_state, output_np_state)
    ]
    return recolored[0], recolored[1], generate_colorbar_preview(color_name)
def update_gallery_color(raw_list_state, color_name):
    """Re-render every stored raw gallery image and the colorbar in ``color_name``.

    Returns ``(None, None)`` while nothing has been generated yet.
    """
    if raw_list_state is None:
        return None, None
    recolored = [apply_pseudocolor(raw_img, color_name) for raw_img in raw_list_state]
    return recolored, generate_colorbar_preview(color_name)
# --- Event Handler Helper ---
def get_gallery_selection(evt: gr.SelectData):
    """Return the caption of the clicked gallery item.

    Gradio uses the ``gr.SelectData`` annotation to inject the selection event,
    so it must stay. The example galleries caption each image with its prompt
    (T2I) or file name.
    """
    selected_item = evt.value
    return selected_item['caption']
# --- Generation Functions ---
def generate_t2i(prompt, num_inference_steps, num_images, current_color, height=512, width=512):
    """
    Generates multiple images for Text-to-Image and returns a gallery.

    The UNet is chosen per prompt through PROMPT_TO_MODEL_MAP; prompts not
    listed there use the FULL-checkpoint-275000 UNet. Every image is also
    written as a raw .tif into a fresh temporary directory, and all of them
    are bundled into one zip for download.

    Args:
        prompt: Text prompt (lower-cased before it is given to the pipeline).
        num_inference_steps: Denoising steps per image.
        num_images: Number of images; the pipeline is called once per image.
        current_color: Pseudocolor name used for the gallery previews.
        height, width: Output size in pixels.

    Returns:
        (gallery images, zip path, list of raw arrays, colorbar image)

    Raises:
        gr.Error: if the model is not loaded, the prompt is empty, or the
            requested UNet cannot be loaded.
    """
    global t2i_pipe
    if t2i_pipe is None: raise gr.Error("Text-to-Image model is not loaded.")
    # The dropdown accepts custom values, so it can be cleared to None/"";
    # fail with a readable message instead of an AttributeError on .lower().
    if prompt is None or not str(prompt).strip():
        raise gr.Error("Please select or enter a prompt.")
    target_model_path = PROMPT_TO_MODEL_MAP.get(prompt, f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/FULL-checkpoint-275000")
    t2i_pipe = swap_t2i_unet(t2i_pipe, target_model_path)
    print(f"\n🚀 T2I Task started... | Prompt: '{prompt}' | Count: {num_images} | Size: {height}x{width}")
    generated_raw_list = []
    generated_display_images = []
    generated_raw_files = []
    temp_dir = tempfile.mkdtemp()
    for i in range(int(num_images)):
        # generator=None: every image gets fresh, unseeded noise.
        image_np = t2i_pipe(
            prompt.lower(),
            generator=None,
            num_inference_steps=int(num_inference_steps),
            output_type="np",
            height=int(height),
            width=int(width)
        ).images
        generated_raw_list.append(image_np)
        # Keep the raw model output; float16 is widened for broad TIFF support.
        raw_path = os.path.join(temp_dir, f"t2i_sample_{i+1}.tif")
        save_data = image_np.astype(np.float32) if image_np.dtype == np.float16 else image_np
        tifffile.imwrite(raw_path, save_data)
        generated_raw_files.append(raw_path)
        generated_display_images.append(apply_pseudocolor(image_np, current_color))
        # Optionally keep the first image as the gallery example for this prompt.
        if SAVE_EXAMPLES and i == 0:
            example_filepath = os.path.join(T2I_EXAMPLE_IMG_DIR, sanitize_prompt_for_filename(prompt))
            if not os.path.exists(example_filepath): generated_display_images[0].save(example_filepath)
    zip_filename = os.path.join(temp_dir, "raw_output_images.zip")
    with zipfile.ZipFile(zip_filename, 'w') as zipf:
        for file in generated_raw_files: zipf.write(file, os.path.basename(file))
    colorbar_img = generate_colorbar_preview(current_color)
    # Return: Gallery List, Zip Path, Raw State List, Colorbar
    return generated_display_images, zip_filename, generated_raw_list, colorbar_img
def run_mask_to_image_generation(mask_file_obj, cell_type, num_images, steps, seed, current_color):
    """Generate nuclei images conditioned on an uploaded segmentation mask.

    Args:
        mask_file_obj: Uploaded .tif mask, either a file-like object exposing
            ``.name`` or a plain file-path string.
        cell_type: Cell line; the prompt becomes "nuclei of <cell_type>".
        num_images: Number of samples; sample i is seeded with ``seed + i``.
        steps: Inference steps per sample.
        seed: Base random seed.
        current_color: Pseudocolor name for the gallery previews.

    Returns:
        (mask preview, gallery images, zip path of raw .tif outputs,
         list of raw output arrays, colorbar image)

    Raises:
        gr.Error: if the pipeline is not loaded, an input is missing, the TIF
            cannot be read, or the mask is not 2-D after squeezing.
    """
    if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
    if mask_file_obj is None: raise gr.Error("Please upload a segmentation mask.")
    if cell_type is None or not str(cell_type).strip():
        raise gr.Error("Please enter a cell type.")
    pipe = swap_controlnet(controlnet_pipe, M2I_CONTROLNET_PATH)
    # Gradio may deliver a tempfile wrapper (with .name) or just a path string.
    mask_path = getattr(mask_file_obj, "name", mask_file_obj)
    try:
        mask_np = tifffile.imread(mask_path)
    except Exception as e:
        raise gr.Error(f"Failed to read TIF. Error: {e}") from e
    # Drop singleton axes such as (1, H, W); the conditioning tensor below is
    # built as (1, 1, H, W) and needs a 2-D mask.
    mask_np = np.squeeze(mask_np)
    if mask_np.ndim != 2:
        raise gr.Error(f"Invalid mask shape. Expected a single 2-D mask, got {mask_np.shape}.")
    input_display = numpy_to_pil(mask_np, "L")
    mask_normalized = min_max_norm(mask_np)
    image_tensor = torch.from_numpy(mask_normalized.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
    image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)
    prompt = f"nuclei of {str(cell_type).strip()}"
    print(f"\nM2I Task started... | Prompt: '{prompt}'")
    generated_raw_list = []
    generated_display_images = []
    generated_raw_files = []
    temp_dir = tempfile.mkdtemp()
    for i in range(int(num_images)):
        generator = torch.Generator(device=DEVICE).manual_seed(int(seed) + i)
        with torch.autocast("cuda"):
            output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images
        generated_raw_list.append(output_np)
        raw_path = os.path.join(temp_dir, f"m2i_sample_{i+1}.tif")
        # float16 is widened for broad TIFF reader support.
        save_data = output_np.astype(np.float32) if output_np.dtype == np.float16 else output_np
        tifffile.imwrite(raw_path, save_data)
        generated_raw_files.append(raw_path)
        generated_display_images.append(apply_pseudocolor(output_np, current_color))
    zip_filename = os.path.join(temp_dir, "raw_output_images.zip")
    with zipfile.ZipFile(zip_filename, 'w') as zipf:
        for file in generated_raw_files: zipf.write(file, os.path.basename(file))
    colorbar_img = generate_colorbar_preview(current_color)
    return input_display, generated_display_images, zip_filename, generated_raw_list, colorbar_img
def run_super_resolution(low_res_file_obj, controlnet_model_name, prompt, steps, seed, current_color):
    """Super-resolve a 9-frame low-resolution TIF stack with the chosen ControlNet.

    Args:
        low_res_file_obj: Uploaded .tif (file-like object with ``.name`` or a
            path string) holding an array of shape (9, H, W).
        controlnet_model_name: Key of SR_CONTROLNET_MODELS.
        prompt: Text prompt for the pipeline.
        steps: Inference steps.
        seed: Random seed.
        current_color: Pseudocolor name for the previews.

    Returns:
        (input preview, output preview, raw output .tif path,
         mean projection of the stack, raw output array, colorbar image)

    Raises:
        gr.Error: on missing pipeline/file, an unknown model name, an unreadable
            TIF, or a stack that is not (9, H, W).
    """
    if controlnet_pipe is None: raise gr.Error("ControlNet pipeline is not loaded.")
    if low_res_file_obj is None: raise gr.Error("Please upload a file.")
    target_path = SR_CONTROLNET_MODELS.get(controlnet_model_name)
    if target_path is None:
        # Without this, None reached swap_controlnet and crashed in normpath.
        raise gr.Error(f"Unknown super-resolution model: {controlnet_model_name}")
    pipe = swap_controlnet(controlnet_pipe, target_path)
    # Gradio may deliver a tempfile wrapper (with .name) or just a path string.
    stack_path = getattr(low_res_file_obj, "name", low_res_file_obj)
    try:
        image_stack_np = tifffile.imread(stack_path)
    except Exception as e:
        raise gr.Error(f"Failed to read TIF. Error: {e}") from e
    if image_stack_np.ndim != 3 or image_stack_np.shape[0] != 9:
        raise gr.Error(f"Invalid TIF shape. Expected 9 channels, got {image_stack_np.shape}.")
    # Calculate Average Projection for Input
    avg_projection_np = np.mean(image_stack_np, axis=0)
    # NOTE(review): scaling by 65535 assumes 16-bit input data — confirm for other dtypes.
    image_tensor = torch.from_numpy(image_stack_np.astype(np.float32) / 65535.0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
    image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)
    print(f"\nSR Task started...")
    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
    with torch.autocast("cuda"):
        output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images
    raw_file_path = save_temp_tiff(output_np, prefix="sr_raw")
    input_display = apply_pseudocolor(avg_projection_np, current_color)
    output_display = apply_pseudocolor(output_np, current_color)
    colorbar_img = generate_colorbar_preview(current_color)
    # Return: Input Disp, Output Disp, Download Path, Input State, Output State, Colorbar
    return input_display, output_display, raw_file_path, avg_projection_np, output_np, colorbar_img
def run_denoising(noisy_image_np, image_type, steps, seed, current_color):
    """Denoise a single image with the denoising ControlNet.

    The prompt comes from DN_PROMPT_RULES[image_type]; unknown types use
    'microscopy image'.

    Returns:
        (input preview, output preview, raw output .tif path,
         raw input array, raw output array, colorbar image)
    """
    if controlnet_pipe is None:
        raise gr.Error("ControlNet pipeline is not loaded.")
    if noisy_image_np is None:
        raise gr.Error("Please upload an image.")
    pipe = swap_controlnet(controlnet_pipe, DN_CONTROLNET_PATH)
    prompt = DN_PROMPT_RULES.get(image_type, 'microscopy image')
    print(f"\nDN Task started...")
    # Scale by 255 and add batch + channel axes.
    # NOTE(review): assumes a 2-D, 8-bit-range image — confirm against the input component.
    scaled = noisy_image_np.astype(np.float32) / 255.0
    cond = torch.from_numpy(scaled)[None, None]
    cond = cond.to(dtype=WEIGHT_DTYPE, device=DEVICE)
    cond = transforms.Resize((512, 512), antialias=True)(cond)
    rng = torch.Generator(device=DEVICE).manual_seed(int(seed))
    with torch.autocast("cuda"):
        result = pipe(prompt=prompt, image_cond=cond, num_inference_steps=int(steps), generator=rng, output_type="np")
    output_np = result.images
    raw_file_path = save_temp_tiff(output_np, prefix="dn_raw")
    shown_input = apply_pseudocolor(noisy_image_np, current_color)
    shown_output = apply_pseudocolor(output_np, current_color)
    colorbar_img = generate_colorbar_preview(current_color)
    return shown_input, shown_output, raw_file_path, noisy_image_np, output_np, colorbar_img
# --- Segmentation & Classification ---
def run_segmentation(input_image_np, model_name, diameter, flow_threshold, cellprob_threshold):
    """Segment cells with Cellpose and return the input plus a red mask overlay.

    A ``diameter`` of 0 falls back to the model's own ``diam_labels``.

    Returns:
        (grayscale input preview, RGB blend of the input (60%) and a dark-red
         mask layer (40%))
    """
    if input_image_np is None:
        raise gr.Error("Please upload an image.")
    model_path = SEG_MODELS.get(model_name)
    print(f"\nSeg Task started...")
    try:
        seg_model = cellpose_models.CellposeModel(gpu=torch.cuda.is_available(), pretrained_model=model_path)
        effective_diameter = seg_model.diam_labels if diameter == 0 else diameter
        masks, _, _ = seg_model.eval([input_image_np], channels=[0, 0], diameter=effective_diameter, flow_threshold=flow_threshold, cellprob_threshold=cellprob_threshold)
    except Exception as e:
        raise gr.Error(f"Segmentation failed: {e}")
    base_rgb = np.array(numpy_to_pil(input_image_np, "RGB"))
    overlay = np.zeros_like(base_rgb)
    overlay[masks[0] > 0] = [139, 0, 0]
    blended = (0.6 * base_rgb + 0.4 * overlay).astype(np.uint8)
    return numpy_to_pil(input_image_np, "L"), numpy_to_pil(blended, "RGB")
def run_classification(input_image_np, model_name):
    """Classify the subcellular structure in an image with a ResNet-50.

    The weights are loaded from ``<CLS_MODEL_PATHS[model_name]>/best_resnet50.pth``
    on every call.

    Returns:
        (grayscale input preview, {class name: softmax probability})

    Raises:
        gr.Error: if no image is given, the model name is unknown, or the
            model cannot be built/loaded.
    """
    if input_image_np is None: raise gr.Error("Please upload an image.")
    model_dir = CLS_MODEL_PATHS.get(model_name)
    if model_dir is None:
        # Previously os.path.join(None, ...) raised an unhandled TypeError here.
        raise gr.Error(f"Unknown classification model: {model_name}")
    model_path = os.path.join(model_dir, "best_resnet50.pth")
    print(f"\nCls Task started...")
    try:
        model = models.resnet50(weights=None)
        model.fc = nn.Linear(model.fc.in_features, len(CLS_CLASS_NAMES))
        # NOTE(security): torch.load unpickles the file; only load trusted
        # checkpoints (PyTorch >= 2.6 defaults to weights_only=True).
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        model.to(DEVICE).eval()
    except Exception as e:
        raise gr.Error(f"Classification failed: {e}") from e
    # Scale to [-1, 1] per channel on an RGB (min-max stretched) copy of the input.
    preprocess = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)])
    input_tensor = preprocess(numpy_to_pil(input_image_np, "RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        probs = F.softmax(model(input_tensor), dim=1).squeeze().cpu().numpy()
    return numpy_to_pil(input_image_np, "L"), {name: float(p) for name, p in zip(CLS_CLASS_NAMES, probs)}
# --- 3. Gradio UI Layout ---
print("Building Gradio interface...")
for d in [M2I_EXAMPLE_IMG_DIR, T2I_EXAMPLE_IMG_DIR, SR_EXAMPLE_IMG_DIR, DN_EXAMPLE_IMG_DIR, SEG_EXAMPLE_IMG_DIR, CLS_EXAMPLE_IMG_DIR]:
    os.makedirs(d, exist_ok=True)
# Load Examples
# Map each prompt's example filename back to its prompt so that clicking a T2I
# example restores the prompt (the gallery caption is the prompt).
filename_to_prompt = {sanitize_prompt_for_filename(p): p for p in T2I_PROMPTS}
t2i_examples = []
# sorted(): os.listdir order is arbitrary, which made the gallery order
# change between restarts.
for f in sorted(os.listdir(T2I_EXAMPLE_IMG_DIR)):
    if f in filename_to_prompt:
        t2i_examples.append((os.path.join(T2I_EXAMPLE_IMG_DIR, f), filename_to_prompt[f]))
def load_examples(d, stack=False):
    """Load the example images in directory ``d`` for a gallery.

    Accepted extensions (case-insensitive): .tif/.tiff, read with tifffile,
    and .png/.jpg/.jpeg, read with PIL and converted to grayscale.

    Args:
        d: Example directory; a missing directory yields an empty list.
        stack: If True, 3-D arrays (e.g. the 9-frame SR stacks) are shown as
            their mean over axis 0.

    Returns:
        List of ``(PIL.Image in mode "L", filename)`` tuples, sorted by
        filename. Files that fail to load are reported and skipped.
    """
    ex = []
    if not os.path.exists(d):
        return ex
    for f in sorted(os.listdir(d)):
        ext = os.path.splitext(f)[1].lower()
        if ext not in (".tif", ".tiff", ".png", ".jpg", ".jpeg"):
            continue
        fp = os.path.join(d, f)
        try:
            # Dispatch on the lower-cased extension: "X.TIF" previously passed
            # the filter but was then sent to PIL instead of tifffile.
            if ext in (".tif", ".tiff"):
                img = tifffile.imread(fp)
            else:
                with Image.open(fp) as pil_img:
                    img = np.array(pil_img.convert("L"))
            if stack and img.ndim == 3:
                img = np.mean(img, axis=0)
            ex.append((numpy_to_pil(img, "L"), f))
        except Exception as e:
            # Best effort: one broken example must not stop the app from starting.
            print(f"WARNING: skipping example {fp}: {e}")
    return ex
def create_path_selector(base_dir):
    """Build a gallery ``select`` callback that returns the clicked example's path.

    The callback joins ``base_dir`` with the caption of the selected item,
    which is the example's file name.
    """
    def handler(evt: gr.SelectData):
        # Join the directory path with the clicked file name (the caption).
        clicked_file = evt.value['caption']
        return os.path.join(base_dir, clicked_file)
    return handler
# Pre-render the example gallery of each task, in the original load order.
# SR examples are 3-D stacks and are displayed as their mean projection.
_example_sources = (
    (M2I_EXAMPLE_IMG_DIR, False),
    (SR_EXAMPLE_IMG_DIR, True),
    (DN_EXAMPLE_IMG_DIR, False),
    (SEG_EXAMPLE_IMG_DIR, False),
    (CLS_EXAMPLE_IMG_DIR, False),
)
m2i_examples, sr_examples, dn_examples, seg_examples, cls_examples = [
    load_examples(folder, stack=as_stack) for folder, as_stack in _example_sources
]
| # --- UI Builders --- | |
| with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo: | |
| with gr.Row(): | |
| gr.Image(value=LOGO_PATH, width=300, height=200, container=False, interactive=False, show_download_button=False, show_fullscreen_button=False) | |
| gr.Markdown(f"# {MODEL_TITLE}\n{MODEL_DESCRIPTION}") | |
| with gr.Tabs(): | |
| # --- TAB 1: Text-to-Image --- | |
| with gr.Tab("Text-to-Image Generation", id="txt2img"): | |
| t2i_raw_state = gr.State(None) # Stores list of arrays | |
| gr.Markdown(""" | |
| ### Instructions | |
| 1. Select a desired prompt from the dropdown menu. | |
| 2. Adjust the 'Inference Steps' slider to control generation quality. | |
| 3. Click the 'Generate' button to create a new image. | |
| 4. Explore the 'Examples' gallery; clicking an image will load its prompt. | |
| **Notice:** This model currently supports 3566 prompt categories. However, data for many cell structures and lines is still lacking. We welcome data source contributions to improve the model. | |
| """) | |
| with gr.Row(variant="panel"): | |
| with gr.Column(scale=1, min_width=350): | |
| t2i_prompt = gr.Dropdown(choices=T2I_PROMPTS, value=T2I_PROMPTS[0], label="Search or Type a Prompt", filterable=True, allow_custom_value=True) | |
| t2i_steps = gr.Slider(10, 200, 50, step=1, label="Inference Steps") | |
| # Added: Number of Images Slider | |
| t2i_num_images = gr.Slider(1, 9, 3, step=1, label="Number of Images") | |
| # with gr.Row(): | |
| # t2i_height = gr.Slider(512, 1024, value=512, step=64, label="Height") | |
| # t2i_width = gr.Slider(512, 1024, value=512, step=64, label="Width") | |
| t2i_btn = gr.Button("Generate", variant="primary") | |
| with gr.Column(scale=2): | |
| # Changed: Image to Gallery | |
| t2i_gallery_out = gr.Gallery(label="Generated Images", columns=3, height="auto", show_download_button=False, show_share_button=False) | |
| with gr.Row(equal_height=True): | |
| with gr.Column(scale=2): | |
| t2i_color = gr.Dropdown(choices=COLOR_MAPS, value="Green (GFP)", label="Pseudocolor (Adjust after generation)") | |
| with gr.Column(scale=2): | |
| t2i_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False) | |
| with gr.Column(scale=1): | |
| t2i_dl = gr.DownloadButton(label="Download All (.zip)") | |
| t2i_gal = gr.Gallery(value=t2i_examples, label="Examples", columns=6, height="auto") | |
| t2i_btn.click(generate_t2i, [t2i_prompt, t2i_steps, t2i_num_images, t2i_color], [t2i_gallery_out, t2i_dl, t2i_raw_state, t2i_colorbar]) | |
| # t2i_btn.click( | |
| # generate_t2i, | |
| # inputs=[t2i_prompt, t2i_steps, t2i_num_images, t2i_color, t2i_height, t2i_width], | |
| # outputs=[t2i_gallery_out, t2i_dl, t2i_raw_state, t2i_colorbar] | |
| # ) | |
| # Reuse update_gallery_color since state is now a list | |
| t2i_color.change(update_gallery_color, [t2i_raw_state, t2i_color], [t2i_gallery_out, t2i_colorbar]) | |
| t2i_gal.select(fn=get_gallery_selection, inputs=None, outputs=t2i_prompt) | |
| # --- TAB 2: Super-Resolution --- | |
| with gr.Tab("Super-Resolution", id="super_res"): | |
| # Stores: Input (Average Projection) and Output | |
| sr_input_state = gr.State(None) | |
| sr_raw_state = gr.State(None) | |
| gr.Markdown(""" | |
| ### Instructions | |
| 1. Upload a low-resolution 9-channel TIF stack, or select one from the examples. | |
| 2. Select a 'Super-Resolution Model' from the dropdown. | |
| 3. Enter a descriptive 'Prompt' related to the image content (e.g., 'CCPs of COS-7'). | |
| 4. Adjust 'Inference Steps' and 'Seed' as needed. | |
| 5. Click 'Generate Super-Resolution' to process the image. | |
| **Notice:** This model was trained on the **BioSR** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results. | |
| """) | |
| with gr.Row(variant="panel"): | |
| with gr.Column(scale=1, min_width=350): | |
| sr_file = gr.File(label="Upload 9-Channel TIF Stack", file_types=['.tif', '.tiff']) | |
| sr_model = gr.Dropdown(choices=list(SR_CONTROLNET_MODELS.keys()), value=list(SR_CONTROLNET_MODELS.keys())[-1], label="Model") | |
| sr_prompt = gr.Textbox(label="Prompt", value="F-actin of COS-7", interactive=False) | |
| sr_steps = gr.Slider(5, 50, 10, step=1, label="Steps") | |
| sr_seed = gr.Number(label="Seed", value=42) | |
| sr_btn = gr.Button("Generate", variant="primary") | |
| with gr.Column(scale=2): | |
| with gr.Row(): | |
| # Input Display now shows Pseudocolor | |
| sr_in_disp = gr.Image(label="Input (Avg Projection)", type="pil", interactive=False, show_download_button=False) | |
| sr_out_disp = gr.Image(label="Output", type="pil", interactive=False, show_download_button=False) | |
| with gr.Row(equal_height=True): | |
| with gr.Column(scale=2): | |
| sr_color = gr.Dropdown(choices=COLOR_MAPS, value="Red Hot", label="Pseudocolor") | |
| with gr.Column(scale=2): | |
| sr_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False) | |
| with gr.Column(scale=1): | |
| sr_dl = gr.DownloadButton(label="Download Raw Output (.tif)") | |
| sr_gal = gr.Gallery(value=sr_examples, label="Examples", columns=6, height="auto") | |
| # sr_model.change(update_sr_prompt, sr_model, sr_prompt) | |
| sr_model.change(update_sr_settings, inputs=sr_model, outputs=[sr_prompt, sr_color]) | |
| # Run returns both input/output states and displays | |
| sr_btn.click(run_super_resolution, [sr_file, sr_model, sr_prompt, sr_steps, sr_seed, sr_color], [sr_in_disp, sr_out_disp, sr_dl, sr_input_state, sr_raw_state, sr_colorbar]) | |
| # Change color updates both displays | |
| sr_color.change(update_pair_color, [sr_input_state, sr_raw_state, sr_color], [sr_in_disp, sr_out_disp, sr_colorbar]) | |
| sr_gal.select(fn=create_path_selector(SR_EXAMPLE_IMG_DIR), inputs=None, outputs=sr_file) | |
| # --- TAB 3: Denoising --- | |
| with gr.Tab("Denoising", id="denoising"): | |
| # Stores: Input (Noisy) and Output | |
| dn_input_state = gr.State(None) | |
| dn_raw_state = gr.State(None) | |
| gr.Markdown(""" | |
| ### Instructions | |
| 1. Upload a noisy single-channel image, or select one from the examples. | |
| 2. Select the 'Image Type' from the dropdown to provide context for the model. | |
| 3. Adjust 'Inference Steps' and 'Seed' as needed. | |
| 4. Click 'Denoise Image' to reduce the noise. | |
| **Notice:** This model was trained on the **FMD** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results. | |
| """) | |
| with gr.Row(variant="panel"): | |
| with gr.Column(scale=1, min_width=350): | |
| dn_img = gr.Image(type="numpy", label="Upload Noisy Image", image_mode="L") | |
| dn_type = gr.Dropdown(choices=list(DN_PROMPT_RULES.keys()), value='MICE', label="Image Type") | |
| dn_steps = gr.Slider(5, 50, 10, step=1, label="Steps") | |
| dn_seed = gr.Number(label="Seed", value=42) | |
| dn_btn = gr.Button("Denoise", variant="primary") | |
| with gr.Column(scale=2): | |
| with gr.Row(): | |
| # Input Display now shows Pseudocolor | |
| dn_orig = gr.Image(label="Original", type="pil", interactive=False, show_download_button=False) | |
| dn_out = gr.Image(label="Denoised", type="pil", interactive=False, show_download_button=False) | |
| with gr.Row(equal_height=True): | |
| with gr.Column(scale=2): | |
| dn_color = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor") | |
| with gr.Column(scale=2): | |
| dn_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False) | |
| with gr.Column(scale=1): | |
| dn_dl = gr.DownloadButton(label="Download Raw Output (.tif)") | |
| dn_gal = gr.Gallery(value=dn_examples, label="Examples", columns=6, height="auto") | |
| dn_btn.click(run_denoising, [dn_img, dn_type, dn_steps, dn_seed, dn_color], [dn_orig, dn_out, dn_dl, dn_input_state, dn_raw_state, dn_colorbar]) | |
| dn_color.change(update_pair_color, [dn_input_state, dn_raw_state, dn_color], [dn_orig, dn_out, dn_colorbar]) | |
| dn_gal.select(fn=create_path_selector(DN_EXAMPLE_IMG_DIR), inputs=None, outputs=dn_img) | |
| # --- TAB 4: Mask-to-Image --- | |
| with gr.Tab("Mask-to-Image", id="mask2img"): | |
| m2i_raw_state = gr.State(None) | |
| gr.Markdown(""" | |
| ### Instructions | |
| 1. Upload a single-channel segmentation mask (`.tif` file), or select one from the examples gallery below. | |
| 2. Enter the corresponding 'Cell Type' (e.g., 'CoNSS', 'HeLa') to create the prompt. | |
| 3. Select how many sample images you want to generate. | |
| 4. Adjust 'Inference Steps' and 'Seed' as needed. | |
| 5. Click 'Generate Training Samples' to start the process. | |
| 6. The 'Generated Samples' will appear in the main gallery, with the 'Input Mask' shown below for reference. | |
| **Notice:** This model was trained on the **2018 Data Science Bowl** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results. | |
| """) | |
| with gr.Row(variant="panel"): | |
| with gr.Column(scale=1, min_width=350): | |
| m2i_file = gr.File(label="Upload Mask (.tif)", file_types=['.tif', '.tiff']) | |
| # Changed: Default value to HeLa | |
| m2i_type = gr.Textbox(label="Cell Type", value="HeLa", placeholder="e.g., HeLa") | |
| m2i_num = gr.Slider(1, 10, 5, step=1, label="Count") | |
| m2i_steps = gr.Slider(5, 50, 10, step=1, label="Steps") | |
| m2i_seed = gr.Number(label="Seed", value=42) | |
| m2i_btn = gr.Button("Generate", variant="primary") | |
| with gr.Column(scale=2): | |
| m2i_gal_out = gr.Gallery(label="Generated Samples", columns=5, height="auto") | |
| with gr.Row(equal_height=True): | |
| with gr.Column(scale=2): | |
| m2i_color = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor") | |
| with gr.Column(scale=2): | |
| m2i_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False) | |
| with gr.Column(scale=1): | |
| m2i_dl = gr.DownloadButton(label="Download ZIP") | |
| m2i_in_disp = gr.Image(label="Input Mask", type="pil", interactive=False, show_download_button=False) | |
| m2i_gal = gr.Gallery(value=m2i_examples, label="Examples", columns=6, height="auto") | |
| m2i_btn.click(run_mask_to_image_generation, [m2i_file, m2i_type, m2i_num, m2i_steps, m2i_seed, m2i_color], [m2i_in_disp, m2i_gal_out, m2i_dl, m2i_raw_state, m2i_colorbar]) | |
| m2i_color.change(update_gallery_color, [m2i_raw_state, m2i_color], [m2i_gal_out, m2i_colorbar]) | |
| m2i_gal.select(fn=create_path_selector(M2I_EXAMPLE_IMG_DIR), inputs=None, outputs=m2i_file) | |
| # --- TAB 5: Cell Segmentation --- | |
| with gr.Tab("Cell Segmentation", id="segmentation"): | |
| gr.Markdown(""" | |
| ### Instructions | |
| 1. Upload a single-channel image for segmentation, or select one from the examples. | |
| 2. Select a 'Segmentation Model' from the dropdown menu. | |
| 3. Set the expected 'Diameter' of the cells in pixels. Set to 0 to let the model automatically estimate it. | |
| 4. Adjust 'Flow Threshold' and 'Cell Probability Threshold' for finer control. | |
| 5. Click 'Segment Cells'. The result will be shown as a dark red overlay on the original image. | |
| **Notice:** This model was trained on the **2018 Data Science Bowl** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results. | |
| """) | |
| with gr.Row(variant="panel"): | |
| with gr.Column(scale=1, min_width=350): | |
| seg_img = gr.Image(type="numpy", label="Upload Image", image_mode="L") | |
| seg_model = gr.Dropdown(choices=list(SEG_MODELS.keys()), value=list(SEG_MODELS.keys())[0], label="Model") | |
| seg_diam = gr.Number(label="Diameter (0=auto)", value=30) | |
| seg_flow = gr.Slider(0.0, 3.0, 0.4, step=0.1, label="Flow Thresh") | |
| seg_prob = gr.Slider(-6.0, 6.0, 0.0, step=0.5, label="Prob Thresh") | |
| seg_btn = gr.Button("Segment", variant="primary") | |
| with gr.Column(scale=2): | |
| with gr.Row(): | |
| seg_orig = gr.Image(label="Original", type="pil", interactive=False, show_download_button=False) | |
| seg_out = gr.Image(label="Overlay", type="pil", interactive=False, show_download_button=False) | |
| seg_gal = gr.Gallery(value=seg_examples, label="Examples", columns=6, height="auto") | |
| seg_btn.click(run_segmentation, [seg_img, seg_model, seg_diam, seg_flow, seg_prob], [seg_orig, seg_out]) | |
| seg_gal.select(fn=create_path_selector(SEG_EXAMPLE_IMG_DIR), inputs=None, outputs=seg_img) | |
| # --- TAB 6: Classification --- | |
| with gr.Tab("Classification", id="classification"): | |
| gr.Markdown(""" | |
| ### Instructions | |
| 1. Upload a single-channel image for classification, or select an example. | |
| 2. Select a pre-trained 'Classification Model' from the dropdown menu. | |
| 3. Click 'Classify Image' to view the prediction probabilities for each class. | |
| **Note:** The models provided are ResNet50 trained on the 2D HeLa dataset. | |
| """) | |
| with gr.Row(variant="panel"): | |
| with gr.Column(scale=1, min_width=350): | |
| cls_img = gr.Image(type="numpy", label="Upload Image", image_mode="L") | |
| cls_model = gr.Dropdown(choices=list(CLS_MODEL_PATHS.keys()), value=list(CLS_MODEL_PATHS.keys())[0], label="Model") | |
| cls_btn = gr.Button("Classify", variant="primary") | |
| with gr.Column(scale=2): | |
| cls_orig = gr.Image(label="Input", type="pil", interactive=False, show_download_button=False) | |
| cls_res = gr.Label(label="Results", num_top_classes=10) | |
| cls_gal = gr.Gallery(value=cls_examples, label="Examples", columns=6, height="auto") | |
| cls_btn.click(run_classification, [cls_img, cls_model], [cls_orig, cls_res]) | |
| cls_gal.select(fn=create_path_selector(CLS_EXAMPLE_IMG_DIR), inputs=None, outputs=cls_img) | |
| if __name__ == "__main__": | |
| demo.launch() |