Spaces:
Running
on
Zero
Running
on
Zero
File size: 45,228 Bytes
9060565 d4a259c 9060565 d7b90f6 9060565 517a4aa d4a259c 9060565 d4a259c 9060565 31b3106 9060565 068458e 9060565 3511099 5aafe11 11ac284 5aafe11 3511099 9060565 d4a259c 9778ac9 9060565 68b0f40 9060565 4ab5cd7 3511099 4ab5cd7 43a0eda 4ab5cd7 9060565 4cfafe1 517a4aa 9060565 d4a259c 9060565 f0616d4 9060565 cb29951 d7f65f5 9060565 3310a64 9060565 3310a64 9060565 ad236df 9060565 517a4aa 75be3d6 61add18 3fe0e76 61add18 9060565 d4a259c 88153a5 d309237 88153a5 d309237 0210b98 d309237 0210b98 d309237 61add18 d4a259c 9355932 d4a259c 517a4aa d4a259c 517a4aa 88153a5 d309237 88153a5 d309237 0210b98 d309237 0210b98 d309237 61add18 517a4aa 9355932 d4a259c 517a4aa 88153a5 517a4aa d4a259c 517a4aa 9060565 d4a259c 9060565 f0616d4 517a4aa f0616d4 bc82a4c 4610725 bc82a4c 517a4aa f0616d4 bc82a4c f0616d4 517a4aa f0616d4 517a4aa d4a259c 517a4aa f0616d4 7f3f28e 4610725 9060565 bc82a4c 9060565 d4a259c 9060565 d4a259c 9060565 d4a259c 9060565 d4a259c 9060565 d4a259c 9060565 d4a259c 9060565 bc82a4c d4a259c bc82a4c d4a259c 9778ac9 d4a259c 9778ac9 be57e23 d4a259c d7b90f6 1834127 4ab5cd7 bc82a4c 9060565 d4a259c 6e0ed52 517a4aa fd7a328 517a4aa 4ab5cd7 517a4aa 4ab5cd7 fd7a328 4ab5cd7 517a4aa 4ab5cd7 9060565 d7b90f6 d4a259c 9060565 d4a259c 9060565 d4a259c 9060565 d4a259c 9060565 d4a259c 9060565 d4a259c 9060565 d4a259c 517a4aa 9060565 d4a259c 9060565 d4a259c 517a4aa d4a259c 517a4aa d4a259c 517a4aa d4a259c 9060565 d7b90f6 d4a259c 9060565 d4a259c 9060565 d4a259c 9060565 d4a259c 9060565 9778ac9 9060565 d4a259c 9060565 517a4aa 9778ac9 d4a259c 9778ac9 9060565 d7b90f6 d4a259c 9060565 d4a259c 9060565 d4a259c 9060565 517a4aa 9778ac9 d4a259c 9778ac9 9060565 9778ac9 d7b90f6 9060565 d4a259c 9060565 d4a259c 9060565 d4a259c 9060565 d4a259c 9060565 d7b90f6 9060565 d4a259c 517a4aa d4a259c 9060565 d4a259c 9060565 517a4aa d4a259c 9060565 d4a259c 9060565 d4a259c 9060565 d4a259c 517a4aa d4a259c 9060565 d4a259c d277748 517a4aa d4a259c 9060565 7c564f4 d4a259c 9060565 d4a259c 4ab5cd7 9060565 a711ad5 9060565 4ab5cd7 75be3d6 
9060565 d4a259c 4ab5cd7 09ec9be 1834127 d4a259c 9060565 4ab5cd7 0210b98 d4a259c 2cbc1e6 d4a259c 4ab5cd7 d4a259c 1834127 4ab5cd7 be57e23 9060565 a711ad5 9060565 9778ac9 d4a259c 75be3d6 9778ac9 9060565 d4a259c 9060565 9778ac9 d4a259c d7f65f5 d4a259c 9778ac9 d4a259c 88153a5 9778ac9 7c564f4 9060565 a711ad5 9060565 9778ac9 d4a259c 75be3d6 9778ac9 9060565 d4a259c 9060565 9778ac9 d4a259c 9778ac9 d4a259c 9778ac9 7c564f4 9060565 a711ad5 9778ac9 75be3d6 f6455a5 75be3d6 a711ad5 d4a259c 4ab5cd7 d4a259c a711ad5 2cbc1e6 d4a259c 7c564f4 a711ad5 9060565 75be3d6 f6455a5 75be3d6 9060565 d4a259c 9060565 d4a259c 7c564f4 d4a259c 517a4aa 9060565 75be3d6 9060565 d4a259c 9060565 d4a259c 7c564f4 9060565 517a4aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 |
import torch
import numpy as np
import gradio as gr
from PIL import Image
import os
import json
import tifffile
import re
from torchvision import transforms, models
import shutil
import time
import torch.nn as nn
import torch.nn.functional as F
import spaces
from collections import OrderedDict
import tempfile
import zipfile
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import io
# --- Imports from both scripts ---
from diffusers import DDPMScheduler, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from transformers.utils import ContextManagers
# --- Custom Model Imports ---
from models.pipeline_ddpm_text_encoder import DDPMPipeline
from models.unet_2d import UNet2DModel
from models.controlnet import ControlNetModel
from models.unet_2d_condition import UNet2DConditionModel
from models.pipeline_controlnet import DDPMControlnetPipeline
# --- Segmentation Imports ---
from cellpose import models as cellpose_models
from huggingface_hub import snapshot_download
# --- 0. Configuration & Constants ---
# Hugging Face access token taken from the environment; None when HF_TOKEN is unset.
hf_token = os.environ.get("HF_TOKEN")
# Title and Markdown description rendered together at the top of the Gradio page.
MODEL_TITLE = "🔬 FluoGen: AI-Powered Fluorescence Microscopy Suite"
MODEL_DESCRIPTION = """
**Paper**: [*FluoGen: An Open-Source Generative Foundation Model for Fluorescence Microscopy Image Enhancement and Analysis*](https://doi.org/10.21203/rs.3.rs-8334792/v1)
<br>
**Homepage**: [Homepage Website](https://fluogen-group.github.io/FluoGen-HomePage/)
<br>
**Code**: [GitHub Repository](https://github.com/FluoGen-Group/FluoGen)
<br>
Select a task below.
**Note**: The "Pseudocolor" option instantly applies to both **Input** and **Output** images for better comparison. Use the "Download Raw Output" button to get the scientific 16-bit/Float data.
"""
# Use the first CUDA device when one is available, otherwise the CPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# Half precision is selected only together with CUDA; CPU runs stay in float32.
WEIGHT_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
# Logo image displayed in the page header.
LOGO_PATH = "utils/logo2_transparent.png"
# When True, generate_t2i saves the first image of a run into T2I_EXAMPLE_IMG_DIR
# (only if no file for that prompt exists yet).
SAVE_EXAMPLES = False
# --- Custom CSS: Arial-first sans-serif font stack; hides built-in download buttons/links ---
CUSTOM_CSS = """
.gradio-container, .gradio-container * {
font-family: 'Arial', 'Helvetica', 'Microsoft YaHei', '微软雅黑', sans-serif !important;
}
button[aria-label="Download"],
a[download] {
display: none !important;
}
"""
# --- Base directory for all models ---
# snapshot_download runs at import time (authenticated with hf_token); the
# directory it returns is the root every checkpoint path below is built from.
REPO_ID = "FluoGen-Group/FluoGen-demo-test-ckpts"
MODELS_ROOT_DIR = snapshot_download(repo_id=REPO_ID, token=hf_token)
# --- Paths Config ---
# Mask-to-Image: ControlNet checkpoint (also the one loaded at startup) and example folder.
M2I_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_M2I/checkpoint-30000"
M2I_EXAMPLE_IMG_DIR = "example_images_m2i"
# Text-to-Image: example folder, base model folder (text encoder / tokenizer
# subfolders are read from it) and the default UNet checkpoint.
T2I_EXAMPLE_IMG_DIR = "example_images"
T2I_PRETRAINED_MODEL_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5"
T2I_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"
# The shared ControlNet pipeline points at the same CLIP folder and UNet
# checkpoint as the Text-to-Image settings above.
CONTROLNET_CLIP_PATH = f"{MODELS_ROOT_DIR}/stable-diffusion-v1-5"
CONTROLNET_UNET_PATH = f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"
# Super-resolution: display name -> ControlNet checkpoint directory.
SR_CONTROLNET_MODELS = {
"Checkpoint ER": f"{MODELS_ROOT_DIR}/ControlNet_SR/ER/checkpoint-30000",
"Checkpoint Microtubules": f"{MODELS_ROOT_DIR}/ControlNet_SR/Microtubules/checkpoint-72500",
"Checkpoint CCPs": f"{MODELS_ROOT_DIR}/ControlNet_SR/CCPs/checkpoint-100000",
"Checkpoint F-actin": f"{MODELS_ROOT_DIR}/ControlNet_SR/F-actin/checkpoint-35000",
}
SR_EXAMPLE_IMG_DIR = "example_images_sr"
# Denoising: ControlNet checkpoint plus image-type key -> text prompt.
# run_denoising falls back to 'microscopy image' for keys missing here.
DN_CONTROLNET_PATH = f"{MODELS_ROOT_DIR}/ControlNet_DN/checkpoint-10000"
DN_PROMPT_RULES = {'MICE': 'mouse brain tissues', 'FISH': 'zebrafish embryos', 'BPAE_B': 'nucleus of BPAE', 'BPAE_R': 'mitochondria of BPAE', 'BPAE_G': 'F-actin of BPAE'}
DN_EXAMPLE_IMG_DIR = "example_images_dn"
# Segmentation: display name -> Cellpose weights file. The commented-out
# entries are alternative checkpoints that are currently disabled.
SEG_MODELS = {
#"DynamicNet Model": f"{MODELS_ROOT_DIR}/Cellpose/DynamicNet_baseline/CP_dynamic_ten_epoch_0100",
#"DynamicNet Model + FluoGen": f"{MODELS_ROOT_DIR}/Cellpose/DynamicNet_FluoGen/CP_dynamic_epoch_0300",
#"DSB Model": f"{MODELS_ROOT_DIR}/Cellpose/DSB_baseline/CP_dsb_baseline_ratio_1_epoch_0135",
"Cellpose Augmented by FluoGen": f"{MODELS_ROOT_DIR}/Cellpose/DSB_FluoGen/CP_dsb_ten_epoch_0135",
}
SEG_EXAMPLE_IMG_DIR = "example_images_seg"
# Classification: display name -> directory that holds "best_resnet50.pth".
CLS_MODEL_PATHS = OrderedDict({
#"5shot": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot",
"ResNet-50 Augmented by FluoGen": f"{MODELS_ROOT_DIR}/Classification/resnet50_hela_5_shot_aug",
})
# Class labels; run_classification pairs them positionally with the softmax
# outputs, so the order has to match the classifier head.
CLS_CLASS_NAMES = ['Nucleus', 'Endoplasmic Reticulum', 'Giantin', 'GPP130', 'Lysosomes', 'Mitochondria', 'Nucleolus', 'Actin', 'Endosomes', 'Microtubules']
CLS_EXAMPLE_IMG_DIR = "example_images_cls"
# --- Constants for Visualization ---
# Colormap names handled by apply_pseudocolor and generate_colorbar_preview.
# "Fire" is looked up as Matplotlib's "gnuplot2"; names without a dedicated
# branch in those functions are lower-cased and looked up as Matplotlib maps.
COLOR_MAPS = [
"Grayscale",
"Green (GFP)",
"Red (RFP)",
"Blue (DAPI)",
"Magenta",
"Cyan",
"Yellow",
"Fire",
"Viridis",
"Inferno",
"Magma",
"Plasma",
"Red Hot",
"Cyan Hot",
"Magenta Hot"
]
# Control points of the "Cyan Hot" map: (position, (R, G, B)) with positions
# running from 0.00 to 1.00 and 0-255 channel values. get_rgb_interpolation
# interpolates linearly between neighbouring points.
CYAN_HOT_POINTS = [
(0.00, (0, 0, 0)),
(0.20, (0, 69, 143)),
(0.50, (0, 183, 255)),
(0.80, (89, 255, 255)),
(1.00, (255, 255, 255)),
]
def get_rgb_interpolation(val_norm, rgb_points):
    """Interpolate a piecewise-linear RGB colormap at ``val_norm``.

    Args:
        val_norm: Scalar or array of positions to evaluate (same scale as the
            positions stored in ``rgb_points``).
        rgb_points: Sequence of ``(position, (R, G, B))`` control points with
            0-255 channel values; positions must be increasing for
            ``np.interp`` to give meaningful results.

    Returns:
        Tuple ``(r, g, b)`` of interpolated channels scaled to 0-1, each shaped
        like ``val_norm``.
    """
    positions = [point[0] for point in rgb_points]
    channels = []
    for channel_index in range(3):
        # Scale the 8-bit channel values to 0-1 before interpolating.
        levels = [point[1][channel_index] / 255.0 for point in rgb_points]
        channels.append(np.interp(val_norm, positions, levels))
    red, green, blue = channels
    return red, green, blue
# --- Helper Functions ---
def sanitize_prompt_for_filename(prompt):
    """Turn a prompt into a lowercase ``.png`` file name.

    The prompt is lower-cased, every whitespace-delimited " of " becomes "_",
    and any character other than a-z, 0-9, "-" and "_" is dropped.
    """
    stem = prompt.lower()
    substitutions = (
        (r'\s+of\s+', '_'),
        (r'[^a-z0-9-_]+', ''),
    )
    for pattern, replacement in substitutions:
        stem = re.sub(pattern, replacement, stem)
    return f"{stem}.png"
def min_max_norm(x):
    """Rescale an array linearly to the 0-1 range as float32.

    An input whose value range is below 1e-8 (effectively constant) yields an
    all-zero array of the same shape instead of dividing by ~0.
    """
    values = x.astype(np.float32)
    lowest = np.min(values)
    highest = np.max(values)
    if highest - lowest < 1e-8:
        return np.zeros_like(values)
    return (values - lowest) / (highest - lowest)
def generate_colorbar_preview(color_name):
    """Render a 256x30 RGB strip previewing the colormap named ``color_name``.

    Handles the hand-built maps ("Grayscale", the single/dual channel maps and
    the "... Hot" maps) directly; any other name is lower-cased and looked up
    as a Matplotlib colormap ("Fire" is looked up as "gnuplot2").

    Args:
        color_name: One of the names in COLOR_MAPS, or a Matplotlib colormap
            name.

    Returns:
        PIL image of size 256x30. A name Matplotlib does not know falls back
        to the grayscale strip.
    """
    gradient = np.linspace(0, 1, 256).reshape(1, 256)
    if color_name == "Grayscale":
        return Image.fromarray((gradient * 255).astype(np.uint8)).convert("RGB").resize((256, 30))

    rgb = np.zeros((1, 256, 3))
    if "Hot" in color_name:
        if color_name == "Magenta Hot":
            low_half = np.clip(gradient * 2, 0, 1)
            high_half = np.clip((gradient - 0.5) * 2, 0, 1)
            rgb[..., 0] = low_half
            rgb[..., 1] = high_half
            rgb[..., 2] = low_half
        elif color_name == "Red Hot":
            # Red, green and blue ramp up one after another over thirds of the range.
            rgb[..., 0] = np.clip(gradient * 3, 0, 1)
            rgb[..., 1] = np.clip((gradient - 0.333) * 3, 0, 1)
            rgb[..., 2] = np.clip((gradient - 0.666) * 3, 0, 1)
        elif color_name == "Cyan Hot":
            r, g, b = get_rgb_interpolation(gradient, CYAN_HOT_POINTS)
            rgb = np.stack([r, g, b], axis=-1)
        # Any other "... Hot" name keeps the all-black strip, as before.
    elif color_name == "Green (GFP)":
        rgb[..., 1] = gradient
    elif color_name == "Red (RFP)":
        rgb[..., 0] = gradient
    elif color_name == "Blue (DAPI)":
        rgb[..., 2] = gradient
    elif color_name == "Magenta":
        rgb[..., 0] = gradient
        rgb[..., 2] = gradient
    elif color_name == "Cyan":
        rgb[..., 1] = gradient
        rgb[..., 2] = gradient
    elif color_name == "Yellow":
        rgb[..., 0] = gradient
        rgb[..., 1] = gradient
    else:
        # Matplotlib maps
        mpl_map_name = "gnuplot2" if color_name == "Fire" else color_name.lower()
        try:
            # plt.get_cmap is used because matplotlib.cm.get_cmap was removed
            # in Matplotlib 3.9; with the old bare `except:` that removal made
            # every Matplotlib map silently render as grayscale.
            cmap = plt.get_cmap(mpl_map_name)
        except (ValueError, KeyError):
            return generate_colorbar_preview("Grayscale")  # Fallback: unknown map name
        rgb = cmap(gradient)[..., :3]  # drop the alpha channel

    img_np = (rgb * 255).astype(np.uint8)
    return Image.fromarray(img_np).resize((256, 30))
def apply_pseudocolor(image_np, color_name="Grayscale"):
    """Apply a pseudocolor to a single-channel numpy image.

    The image is squeezed and min-max normalised to 0-1 before colouring, so
    the display always uses the full range of the chosen map.

    Args:
        image_np: Array that squeezes to 2-D (height, width), or None.
        color_name: One of the names in COLOR_MAPS, or a Matplotlib colormap
            name ("Fire" is looked up as "gnuplot2").

    Returns:
        PIL Image in RGB, or None when ``image_np`` is None. A name Matplotlib
        does not know falls back to grayscale.

    Raises:
        ValueError: For non-grayscale maps when the squeezed image is not 2-D.
    """
    if image_np is None:
        return None

    # Normalize to 0-1 for processing
    norm_img = min_max_norm(np.squeeze(image_np))
    if color_name == "Grayscale":
        return Image.fromarray((norm_img * 255).astype(np.uint8)).convert("RGB")

    if norm_img.ndim != 2:
        # Same exception type the tuple-unpacking of .shape used to raise,
        # with a message that names the actual problem.
        raise ValueError(f"Expected a single-channel 2-D image, got shape {norm_img.shape}")
    h, w = norm_img.shape
    rgb = np.zeros((h, w, 3), dtype=np.float32)

    if "Hot" in color_name:
        if color_name == "Magenta Hot":
            low_half = np.clip(norm_img * 2, 0, 1)
            high_half = np.clip((norm_img - 0.5) * 2, 0, 1)
            rgb[..., 0] = low_half
            rgb[..., 1] = high_half
            rgb[..., 2] = low_half
        elif color_name == "Red Hot":
            # Red, green and blue ramp up one after another over thirds of the range.
            rgb[..., 0] = np.clip(norm_img * 3, 0, 1)
            rgb[..., 1] = np.clip((norm_img - 0.333) * 3, 0, 1)
            rgb[..., 2] = np.clip((norm_img - 0.666) * 3, 0, 1)
        elif color_name == "Cyan Hot":
            r, g, b = get_rgb_interpolation(norm_img, CYAN_HOT_POINTS)
            rgb = np.stack([r, g, b], axis=-1)
        # Any other "... Hot" name keeps the all-black image, as before.
    elif color_name == "Green (GFP)":
        rgb[..., 1] = norm_img
    elif color_name == "Red (RFP)":
        rgb[..., 0] = norm_img
    elif color_name == "Blue (DAPI)":
        rgb[..., 2] = norm_img
    elif color_name == "Magenta":
        rgb[..., 0] = norm_img
        rgb[..., 2] = norm_img
    elif color_name == "Cyan":
        rgb[..., 1] = norm_img
        rgb[..., 2] = norm_img
    elif color_name == "Yellow":
        rgb[..., 0] = norm_img
        rgb[..., 1] = norm_img
    else:
        # Matplotlib maps
        mpl_map_name = "gnuplot2" if color_name == "Fire" else color_name.lower()
        try:
            # plt.get_cmap is used because matplotlib.cm.get_cmap was removed
            # in Matplotlib 3.9; with the old bare `except:` that removal made
            # every Matplotlib map silently render as grayscale.
            cmap = plt.get_cmap(mpl_map_name)
        except (ValueError, KeyError):
            return apply_pseudocolor(image_np, "Grayscale")
        rgb = cmap(norm_img)[..., :3]  # drop the alpha channel

    return Image.fromarray((rgb * 255).astype(np.uint8))
def update_sr_settings(model_name):
    """Return the default ``(prompt, colormap)`` pair for an SR checkpoint.

    - Microtubules -> Cyan Hot
    - F-actin -> Red Hot
    - CCPs -> Green (GFP)
    - ER -> Magenta Hot

    Unknown names give an empty prompt and the "Grayscale" colormap.
    """
    presets = (
        ("Checkpoint ER", "ER of COS-7", "Magenta Hot"),
        ("Checkpoint Microtubules", "Microtubules of COS-7", "Cyan Hot"),
        ("Checkpoint CCPs", "CCPs of COS-7", "Green (GFP)"),
        ("Checkpoint F-actin", "F-actin of COS-7", "Red Hot"),
    )
    for checkpoint, prompt, colormap in presets:
        if model_name == checkpoint:
            return prompt, colormap
    return "", "Grayscale"
def save_temp_tiff(image_np, prefix="output"):
    """Write ``image_np`` to a new temporary ``.tif`` file and return its path.

    float16 data is stored as float32; every other dtype is written unchanged.
    The file is created with ``delete=False`` and is not removed here, so it
    stays on disk for the caller to hand out (e.g. as a download).

    Args:
        image_np: Array to save.
        prefix: Leading part of the temporary file name (an underscore is
            appended).

    Returns:
        Path of the written TIFF file.
    """
    # Only reserve the name here: the handle is closed on leaving the `with`
    # block, so tifffile is the sole writer and no open file object is left
    # behind for the garbage collector to clean up.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".tif", prefix=f"{prefix}_") as tfile:
        tiff_path = tfile.name
    save_data = image_np.astype(np.float32) if image_np.dtype == np.float16 else image_np
    tifffile.imwrite(tiff_path, save_data)
    return tiff_path
def numpy_to_pil(image_np, target_mode="RGB"):
    """Convert an array (or PIL image) into a PIL image in ``target_mode``.

    PIL inputs are only mode-converted. Arrays are squeezed; uint8 data is
    used as is, anything else is min-max normalised and scaled to 0-255.
    Only the modes "RGB" and "L" trigger a conversion; other values of
    ``target_mode`` leave the image mode untouched.
    """
    def _to_target_mode(img):
        # Convert only when a supported mode is requested and it differs.
        if target_mode in ("RGB", "L") and img.mode != target_mode:
            return img.convert(target_mode)
        return img

    if isinstance(image_np, Image.Image):
        return _to_target_mode(image_np)

    pixels = np.squeeze(image_np)
    if pixels.dtype != np.uint8:
        pixels = (min_max_norm(pixels) * 255).astype(np.uint8)
    return _to_target_mode(Image.fromarray(pixels))
def update_sr_prompt(model_name):
    """Return the default prompt for an SR checkpoint name ("" if unknown)."""
    prompt_by_checkpoint = (
        ("Checkpoint ER", "ER of COS-7"),
        ("Checkpoint Microtubules", "Microtubules of COS-7"),
        ("Checkpoint CCPs", "CCPs of COS-7"),
        ("Checkpoint F-actin", "F-actin of COS-7"),
    )
    for checkpoint, prompt in prompt_by_checkpoint:
        if model_name == checkpoint:
            return prompt
    return ""
# Prompt text -> UNet checkpoint directory that should generate it
# (filled by load_all_prompts, read by generate_t2i).
PROMPT_TO_MODEL_MAP = {}
# Path of the UNet currently installed in the Text-to-Image pipeline.
current_t2i_unet_path = None


def load_all_prompts():
    """Load every prompt list and record which UNet checkpoint serves it.

    Each category file is expected to hold a JSON list of prompts. Missing
    files and files whose top-level value is not a list are skipped. A file
    that cannot be read or parsed is reported on stdout and skipped, so one
    broken file does not prevent the app from starting.

    Side effects:
        Adds one entry per loaded prompt to PROMPT_TO_MODEL_MAP; when a prompt
        appears in several files the last file wins.

    Returns:
        All loaded prompts in file order, or two built-in default prompts when
        nothing could be loaded.
    """
    global PROMPT_TO_MODEL_MAP
    categories = [
        {"file": "prompts/basic_prompts.json", "model": f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"},
        {"file": "prompts/others_prompts.json", "model": f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/FULL-checkpoint-275000"},
        {"file": "prompts/hpa_prompts.json", "model": f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/HPA-checkpoint-40000"},
    ]
    combined_prompts = []
    for cat in categories:
        try:
            if os.path.exists(cat["file"]):
                with open(cat["file"], "r", encoding="utf-8") as f:
                    data = json.load(f)
                if isinstance(data, list):
                    combined_prompts.extend(data)
                    for p in data:
                        PROMPT_TO_MODEL_MAP[p] = cat["model"]
        except Exception as e:
            # Still best-effort, but no longer silent: a malformed prompt file
            # used to vanish without any hint in the logs.
            print(f"Warning: could not load prompts from {cat['file']}: {e}")
    if not combined_prompts:
        return ["F-actin of COS-7", "ER of COS-7"]
    return combined_prompts


T2I_PROMPTS = load_all_prompts()
# --- 1. Model Loading ---
# Both pipelines are built once at import time. A loading failure is only
# printed: a pipeline that was never constructed stays None, and the handlers
# below turn that into a gr.Error when the user triggers the task.
print("--- Initializing FluoGen Application ---")
t2i_pipe, controlnet_pipe = None, None
try:
    print("Loading Text-to-Image model...")
    t2i_noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="v_prediction", rescale_betas_zero_snr=True, timestep_spacing="trailing")
    t2i_unet = UNet2DModel.from_pretrained(T2I_UNET_PATH, subfolder="unet")
    t2i_text_encoder = CLIPTextModel.from_pretrained(T2I_PRETRAINED_MODEL_PATH, subfolder="text_encoder").to(DEVICE)
    t2i_tokenizer = CLIPTokenizer.from_pretrained(T2I_PRETRAINED_MODEL_PATH, subfolder="tokenizer")
    t2i_pipe = DDPMPipeline(unet=t2i_unet, scheduler=t2i_noise_scheduler, text_encoder=t2i_text_encoder, tokenizer=t2i_tokenizer)
    # No dtype is passed here, unlike the ControlNet pipeline below which is cast to WEIGHT_DTYPE.
    t2i_pipe.to(DEVICE)
    # Remember which UNet is installed so swap_t2i_unet can skip redundant reloads.
    current_t2i_unet_path = T2I_UNET_PATH
    print("✓ Text-to-Image model loaded successfully!")
except Exception as e:
    print(f"FATAL: Text-to-Image Model Loading Failed: {e}")
try:
    print("Loading shared ControlNet pipeline...")
    controlnet_unet = UNet2DConditionModel.from_pretrained(CONTROLNET_UNET_PATH, subfolder="unet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
    # The Mask-to-Image ControlNet is the initial one; the task handlers swap in others on demand.
    controlnet_controlnet = ControlNetModel.from_pretrained(M2I_CONTROLNET_PATH, subfolder="controlnet").to(dtype=WEIGHT_DTYPE, device=DEVICE)
    controlnet_scheduler = DDIMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="v_prediction", rescale_betas_zero_snr=False, timestep_spacing="trailing")
    controlnet_tokenizer = CLIPTokenizer.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="tokenizer")
    # NOTE(review): ContextManagers is given an empty list, so this `with` presumably has no effect - confirm.
    with ContextManagers([]):
        controlnet_text_encoder = CLIPTextModel.from_pretrained(CONTROLNET_CLIP_PATH, subfolder="text_encoder").to(dtype=WEIGHT_DTYPE, device=DEVICE)
    controlnet_pipe = DDPMControlnetPipeline(unet=controlnet_unet, controlnet=controlnet_controlnet, scheduler=controlnet_scheduler, text_encoder=controlnet_text_encoder, tokenizer=controlnet_tokenizer)
    controlnet_pipe.to(dtype=WEIGHT_DTYPE, device=DEVICE)
    # Bookkeeping attribute read by swap_controlnet to detect whether a reload is needed.
    controlnet_pipe.current_controlnet_path = M2I_CONTROLNET_PATH
    print("✓ Shared ControlNet pipeline loaded successfully!")
except Exception as e:
    print(f"FATAL: ControlNet Pipeline Loading Failed: {e}")
# --- 2. Core Logic Functions ---
def swap_controlnet(pipe, target_path):
    """Make sure ``pipe`` carries the ControlNet stored at ``target_path``.

    The path recorded in ``pipe.current_controlnet_path`` is compared (after
    normalisation) with ``target_path``; the model is reloaded only when they
    differ.

    Returns:
        The same ``pipe`` object.

    Raises:
        gr.Error: If the ControlNet cannot be loaded.
    """
    loaded_path = os.path.normpath(getattr(pipe, 'current_controlnet_path', ''))
    if loaded_path == os.path.normpath(target_path):
        return pipe

    print(f"Swapping ControlNet model to: {target_path}")
    try:
        replacement = ControlNetModel.from_pretrained(target_path, subfolder="controlnet")
        pipe.controlnet = replacement.to(dtype=WEIGHT_DTYPE, device=DEVICE)
        pipe.current_controlnet_path = target_path
    except Exception as e:
        raise gr.Error(f"Failed to load ControlNet model. Error: {e}")
    return pipe
def swap_t2i_unet(pipe, target_unet_path):
    """Make sure the Text-to-Image ``pipe`` uses the UNet at ``target_unet_path``.

    The module-level ``current_t2i_unet_path`` records which UNet is
    installed; a reload happens only when it is None or names a different
    (normalised) path, and it is updated after a successful load.

    Returns:
        The same ``pipe`` object.

    Raises:
        gr.Error: If the UNet cannot be loaded.
    """
    global current_t2i_unet_path
    target_unet_path = os.path.normpath(target_unet_path)
    already_loaded = (
        current_t2i_unet_path is not None
        and os.path.normpath(current_t2i_unet_path) == target_unet_path
    )
    if already_loaded:
        return pipe

    print(f"🔄 Swapping T2I UNet to: {target_unet_path}")
    try:
        pipe.unet = UNet2DModel.from_pretrained(target_unet_path, subfolder="unet").to(DEVICE)
        current_t2i_unet_path = target_unet_path
    except Exception as e:
        raise gr.Error(f"Failed to load UNet. Error: {e}")
    return pipe
# --- Dynamic Color Update Functions ---
def update_single_image_color(raw_np_state, color_name):
    """Re-colour one stored raw image and rebuild the colorbar preview.

    Returns:
        ``(display_image, colorbar_image)``, or ``(None, None)`` when no raw
        image is stored yet.
    """
    if raw_np_state is None:
        return None, None
    recoloured = apply_pseudocolor(raw_np_state, color_name)
    colorbar = generate_colorbar_preview(color_name)
    return recoloured, colorbar
def update_pair_color(input_np_state, output_np_state, color_name):
    """Re-colour the stored input and output images with one pseudocolor.

    Either state may be None, in which case the matching result is None. The
    colorbar preview is always produced.

    Returns:
        ``(input_image, output_image, colorbar_image)``.
    """
    in_img = None if input_np_state is None else apply_pseudocolor(input_np_state, color_name)
    out_img = None if output_np_state is None else apply_pseudocolor(output_np_state, color_name)
    return in_img, out_img, generate_colorbar_preview(color_name)
def update_gallery_color(raw_list_state, color_name):
    """Re-colour every stored raw image of a gallery.

    Returns:
        ``(list_of_display_images, colorbar_image)``, or ``(None, None)`` when
        no raw images are stored yet.
    """
    if raw_list_state is None:
        return None, None
    recoloured = [apply_pseudocolor(raw_img, color_name) for raw_img in raw_list_state]
    return recoloured, generate_colorbar_preview(color_name)
# --- Event Handler Helper ---
def get_gallery_selection(evt: gr.SelectData):
    """Return the caption of the gallery item described by the select event."""
    selected_item = evt.value
    return selected_item['caption']
# --- Generation Functions ---
@spaces.GPU(duration=120)
def generate_t2i(prompt, num_inference_steps, num_images, current_color, height=512, width=512):
    """Generate several Text-to-Image samples and package them for the UI.

    The UNet checkpoint is chosen from PROMPT_TO_MODEL_MAP; prompts missing
    from that map use the FULL checkpoint. Each sample is written as a TIFF
    into a fresh temporary directory and all TIFFs are zipped there.

    Returns:
        ``(display_images, zip_path, raw_arrays, colorbar_image)``.

    Raises:
        gr.Error: If the Text-to-Image pipeline is not loaded or the UNet
            swap fails.
    """
    global t2i_pipe
    if t2i_pipe is None:
        raise gr.Error("Text-to-Image model is not loaded.")

    target_model_path = PROMPT_TO_MODEL_MAP.get(prompt, f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/FULL-checkpoint-275000")
    t2i_pipe = swap_t2i_unet(t2i_pipe, target_model_path)
    print(f"\n🚀 T2I Task started... | Prompt: '{prompt}' | Count: {num_images} | Size: {height}x{width}")

    raw_arrays, display_images, tiff_paths = [], [], []
    temp_dir = tempfile.mkdtemp()

    for i in range(int(num_images)):
        # One pipeline call per sample.
        result = t2i_pipe(
            prompt.lower(),
            generator=None,
            num_inference_steps=int(num_inference_steps),
            output_type="np",
            height=int(height),
            width=int(width),
        )
        image_np = result.images
        raw_arrays.append(image_np)

        # Raw data goes to disk as TIFF; float16 is stored as float32.
        raw_path = os.path.join(temp_dir, f"t2i_sample_{i+1}.tif")
        if image_np.dtype == np.float16:
            tifffile.imwrite(raw_path, image_np.astype(np.float32))
        else:
            tifffile.imwrite(raw_path, image_np)
        tiff_paths.append(raw_path)

        # Pseudocoloured copy for the gallery.
        display_images.append(apply_pseudocolor(image_np, current_color))

        # Optionally keep the first sample as the example image for this prompt.
        if SAVE_EXAMPLES and i == 0:
            example_filepath = os.path.join(T2I_EXAMPLE_IMG_DIR, sanitize_prompt_for_filename(prompt))
            if not os.path.exists(example_filepath):
                display_images[0].save(example_filepath)

    # Bundle all raw TIFFs into one archive for download.
    zip_filename = os.path.join(temp_dir, "raw_output_images.zip")
    with zipfile.ZipFile(zip_filename, 'w') as archive:
        for tiff_path in tiff_paths:
            archive.write(tiff_path, os.path.basename(tiff_path))

    # Return: Gallery List, Zip Path, Raw State List, Colorbar
    return display_images, zip_filename, raw_arrays, generate_colorbar_preview(current_color)
@spaces.GPU(duration=120)
def run_mask_to_image_generation(mask_file_obj, cell_type, num_images, steps, seed, current_color):
    """Generate images conditioned on an uploaded segmentation mask (TIFF).

    The mask is min-max normalised, resized to 512x512 and fed to the shared
    ControlNet pipeline with the prompt "nuclei of <cell_type>". Sample ``i``
    uses the seed ``seed + i``. Raw outputs are written as TIFFs into a fresh
    temporary directory and zipped there.

    Returns:
        ``(mask_preview, display_images, zip_path, raw_arrays, colorbar_image)``.

    Raises:
        gr.Error: If the pipeline is not loaded, no mask was uploaded, the
            TIFF cannot be read, or the ControlNet swap fails.
    """
    if controlnet_pipe is None:
        raise gr.Error("ControlNet pipeline is not loaded.")
    if mask_file_obj is None:
        raise gr.Error("Please upload a segmentation mask.")

    pipe = swap_controlnet(controlnet_pipe, M2I_CONTROLNET_PATH)
    try:
        mask_np = tifffile.imread(mask_file_obj.name)
    except Exception as e:
        raise gr.Error(f"Failed to read TIF. Error: {e}")

    input_display = numpy_to_pil(mask_np, "L")

    # Build the conditioning tensor: two leading axes are added, then resize to 512x512.
    mask_normalized = min_max_norm(mask_np)
    cond = torch.from_numpy(mask_normalized.astype(np.float32))
    cond = cond.unsqueeze(0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
    image_tensor = transforms.Resize((512, 512), antialias=True)(cond)

    prompt = f"nuclei of {cell_type.strip()}"
    print(f"\nM2I Task started... | Prompt: '{prompt}'")

    raw_arrays, display_images, tiff_paths = [], [], []
    temp_dir = tempfile.mkdtemp()

    for i in range(int(num_images)):
        # A distinct, reproducible seed per sample.
        generator = torch.Generator(device=DEVICE).manual_seed(int(seed) + i)
        with torch.autocast("cuda"):
            result = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np")
            output_np = result.images
        raw_arrays.append(output_np)

        raw_path = os.path.join(temp_dir, f"m2i_sample_{i+1}.tif")
        if output_np.dtype == np.float16:
            tifffile.imwrite(raw_path, output_np.astype(np.float32))
        else:
            tifffile.imwrite(raw_path, output_np)
        tiff_paths.append(raw_path)

        display_images.append(apply_pseudocolor(output_np, current_color))

    zip_filename = os.path.join(temp_dir, "raw_output_images.zip")
    with zipfile.ZipFile(zip_filename, 'w') as archive:
        for tiff_path in tiff_paths:
            archive.write(tiff_path, os.path.basename(tiff_path))

    colorbar_img = generate_colorbar_preview(current_color)
    return input_display, display_images, zip_filename, raw_arrays, colorbar_img
@spaces.GPU(duration=120)
def run_super_resolution(low_res_file_obj, controlnet_model_name, prompt, steps, seed, current_color):
    """Run super-resolution on an uploaded 9-frame TIFF stack.

    Args:
        low_res_file_obj: Uploaded file object; its ``.name`` is read with
            tifffile. The stack must be 3-D with 9 frames on the first axis.
        controlnet_model_name: Key of SR_CONTROLNET_MODELS.
        prompt: Text prompt passed to the pipeline.
        steps: Number of inference steps.
        seed: Seed for the torch generator.
        current_color: Pseudocolor name used for both display images.

    Returns:
        ``(input_display, output_display, raw_tiff_path, input_state,
        output_state, colorbar_image)`` where the input is the mean projection
        of the stack.

    Raises:
        gr.Error: If the pipeline is not loaded, no file was uploaded, the
            model name is unknown, the TIFF cannot be read, or its shape is
            not a 9-frame stack.
    """
    if controlnet_pipe is None:
        raise gr.Error("ControlNet pipeline is not loaded.")
    if low_res_file_obj is None:
        raise gr.Error("Please upload a file.")

    target_path = SR_CONTROLNET_MODELS.get(controlnet_model_name)
    if target_path is None:
        # Without this check the None reached os.path.normpath inside
        # swap_controlnet and surfaced as an opaque TypeError.
        raise gr.Error(f"Unknown super-resolution model: {controlnet_model_name!r}.")
    pipe = swap_controlnet(controlnet_pipe, target_path)

    try:
        image_stack_np = tifffile.imread(low_res_file_obj.name)
    except Exception as e:
        raise gr.Error(f"Failed to read TIF. Error: {e}")
    if image_stack_np.ndim != 3 or image_stack_np.shape[-3] != 9:
        raise gr.Error(f"Invalid TIF shape. Expected 9 channels, got {image_stack_np.shape}.")

    # Calculate Average Projection for Input
    avg_projection_np = np.mean(image_stack_np, axis=0)

    # Preprocess for model
    # NOTE(review): dividing by 65535 assumes 16-bit unsigned input - confirm against the expected uploads.
    image_tensor = torch.from_numpy(image_stack_np.astype(np.float32) / 65535.0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
    image_tensor = transforms.Resize((512, 512), antialias=True)(image_tensor)

    print(f"\nSR Task started...")
    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
    with torch.autocast("cuda"):
        output_np = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np").images

    raw_file_path = save_temp_tiff(output_np, prefix="sr_raw")

    # Generate displays with current color
    input_display = apply_pseudocolor(avg_projection_np, current_color)
    output_display = apply_pseudocolor(output_np, current_color)
    colorbar_img = generate_colorbar_preview(current_color)

    # Return: Input Disp, Output Disp, Download Path, Input State, Output State, Colorbar
    return input_display, output_display, raw_file_path, avg_projection_np, output_np, colorbar_img
@spaces.GPU(duration=120)
def run_denoising(noisy_image_np, image_type, steps, seed, current_color):
    """Denoise an image with the denoising ControlNet.

    The prompt is looked up in DN_PROMPT_RULES by ``image_type`` and defaults
    to 'microscopy image'. The input is divided by 255, given two leading axes
    and resized to 512x512 before it is passed to the pipeline.
    NOTE(review): this preprocessing presumably expects a 2-D 8-bit image -
    confirm against the UI component that feeds it.

    Returns:
        ``(input_display, output_display, raw_tiff_path, input_state,
        output_state, colorbar_image)``.

    Raises:
        gr.Error: If the pipeline is not loaded, no image was supplied, or the
            ControlNet swap fails.
    """
    if controlnet_pipe is None:
        raise gr.Error("ControlNet pipeline is not loaded.")
    if noisy_image_np is None:
        raise gr.Error("Please upload an image.")

    pipe = swap_controlnet(controlnet_pipe, DN_CONTROLNET_PATH)
    prompt = DN_PROMPT_RULES.get(image_type, 'microscopy image')
    print(f"\nDN Task started...")

    scaled = torch.from_numpy(noisy_image_np.astype(np.float32) / 255.0)
    scaled = scaled.unsqueeze(0).unsqueeze(0).to(dtype=WEIGHT_DTYPE, device=DEVICE)
    image_tensor = transforms.Resize((512, 512), antialias=True)(scaled)

    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
    with torch.autocast("cuda"):
        result = pipe(prompt=prompt, image_cond=image_tensor, num_inference_steps=int(steps), generator=generator, output_type="np")
        output_np = result.images

    raw_file_path = save_temp_tiff(output_np, prefix="dn_raw")

    # Input and output share the currently selected pseudocolor.
    input_display = apply_pseudocolor(noisy_image_np, current_color)
    output_display = apply_pseudocolor(output_np, current_color)
    colorbar_img = generate_colorbar_preview(current_color)

    # Return: Input Disp, Output Disp, Download Path, Input State, Output State, Colorbar
    return input_display, output_display, raw_file_path, noisy_image_np, output_np, colorbar_img
# --- Segmentation & Classification ---
@spaces.GPU(duration=120)
def run_segmentation(input_image_np, model_name, diameter, flow_threshold, cellprob_threshold):
    """Segment an image with a Cellpose model and overlay the mask in dark red.

    A ``diameter`` of 0 selects the model's own ``diam_labels`` value. Pixels
    with a mask label above 0 are tinted with (139, 0, 0), blended 60 % image
    / 40 % overlay.

    Returns:
        ``(grayscale_input, blended_overlay)`` as PIL images.

    Raises:
        gr.Error: If no image was supplied or model loading/evaluation fails.
    """
    if input_image_np is None:
        raise gr.Error("Please upload an image.")

    model_path = SEG_MODELS.get(model_name)
    print(f"\nSeg Task started...")
    try:
        model = cellpose_models.CellposeModel(gpu=torch.cuda.is_available(), pretrained_model=model_path)
        effective_diameter = model.diam_labels if diameter == 0 else diameter
        masks, _, _ = model.eval(
            [input_image_np],
            channels=[0, 0],
            diameter=effective_diameter,
            flow_threshold=flow_threshold,
            cellprob_threshold=cellprob_threshold,
        )
    except Exception as e:
        raise gr.Error(f"Segmentation failed: {e}")

    base_rgb = np.array(numpy_to_pil(input_image_np, "RGB"))
    overlay = np.zeros_like(base_rgb)
    overlay[masks[0] > 0] = [139, 0, 0]
    blended = (0.6 * base_rgb + 0.4 * overlay).astype(np.uint8)
    return numpy_to_pil(input_image_np, "L"), numpy_to_pil(blended, "RGB")
@spaces.GPU(duration=120)
def run_classification(input_image_np, model_name):
    """Classify an image with a ResNet-50 whose head matches CLS_CLASS_NAMES.

    Args:
        input_image_np: Image array to classify.
        model_name: Key of CLS_MODEL_PATHS; the weights are read from
            "best_resnet50.pth" inside the mapped directory.

    Returns:
        ``(grayscale_input, {class_name: probability})`` where the
        probabilities are the softmax outputs paired positionally with
        CLS_CLASS_NAMES.

    Raises:
        gr.Error: If no image was supplied, the model name is unknown, or the
            weights cannot be loaded.
    """
    if input_image_np is None:
        raise gr.Error("Please upload an image.")

    model_dir = CLS_MODEL_PATHS.get(model_name)
    if model_dir is None:
        # os.path.join(None, ...) used to raise a bare TypeError here, outside
        # the try block, so the user only saw a generic failure.
        raise gr.Error(f"Unknown classification model: {model_name!r}.")
    model_path = os.path.join(model_dir, "best_resnet50.pth")

    print(f"\nCls Task started...")
    try:
        model = models.resnet50(weights=None)
        model.fc = nn.Linear(model.fc.in_features, len(CLS_CLASS_NAMES))
        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        model.to(DEVICE).eval()
    except Exception as e:
        raise gr.Error(f"Classification failed: {e}")

    preprocess = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)])
    input_tensor = preprocess(numpy_to_pil(input_image_np, "RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        probs = F.softmax(model(input_tensor), dim=1).squeeze().cpu().numpy()
    return numpy_to_pil(input_image_np, "L"), {name: float(p) for name, p in zip(CLS_CLASS_NAMES, probs)}
# --- 3. Gradio UI Layout ---
print("Building Gradio interface...")
# Make sure every example folder exists so the directory listings below cannot fail on a missing path.
for d in [M2I_EXAMPLE_IMG_DIR, T2I_EXAMPLE_IMG_DIR, SR_EXAMPLE_IMG_DIR, DN_EXAMPLE_IMG_DIR, SEG_EXAMPLE_IMG_DIR, CLS_EXAMPLE_IMG_DIR]: os.makedirs(d, exist_ok=True)
# Load Examples
# Reverse lookup: sanitized file name -> original prompt text.
filename_to_prompt = { sanitize_prompt_for_filename(p): p for p in T2I_PROMPTS }
# (image path, prompt) pairs for every example file whose name maps back to a known prompt.
t2i_examples = []
for f in os.listdir(T2I_EXAMPLE_IMG_DIR):
    if f in filename_to_prompt: t2i_examples.append((os.path.join(T2I_EXAMPLE_IMG_DIR, f), filename_to_prompt[f]))
def load_examples(d, stack=False):
    """Load the example images of directory ``d`` as ``(PIL image, file name)`` pairs.

    Files are visited in sorted order and selected by extension
    (.tif/.tiff/.png/.jpg, case-insensitive). TIFFs are read with tifffile,
    everything else with PIL as 8-bit grayscale. A file that cannot be loaded
    is reported on stdout and skipped.

    Args:
        d: Directory to scan. A missing path or a path that is not a
            directory yields an empty list.
        stack: When True, 3-D images are averaged over their first axis
            before conversion.

    Returns:
        List of ``(grayscale PIL image, file name)`` tuples.
    """
    ex = []
    # isdir also covers "exists but is a regular file", which used to crash in os.listdir.
    if not os.path.isdir(d):
        return ex
    for f in sorted(os.listdir(d)):
        lowered = f.lower()
        if not lowered.endswith(('.tif', '.tiff', '.png', '.jpg')):
            continue
        fp = os.path.join(d, f)
        try:
            # Dispatch on the lower-cased name: the filter above is
            # case-insensitive, so "X.TIF" must take the tifffile route too.
            if lowered.endswith(('.tif', '.tiff')):
                img = tifffile.imread(fp)
            else:
                with Image.open(fp) as handle:
                    img = np.array(handle.convert("L"))
            if stack and img.ndim == 3:
                img = np.mean(img, axis=0)
            ex.append((numpy_to_pil(img, "L"), f))
        except Exception as e:
            # Best-effort: one unreadable example must not break the UI, but
            # it should leave a trace in the logs.
            print(f"Warning: skipping example '{fp}': {e}")
    return ex
def create_path_selector(base_dir):
    """Build a gallery-select callback that resolves the clicked item to a path under ``base_dir``."""
    def handler(evt: gr.SelectData):
        # Join the directory path with the caption of the clicked item
        # (the caption carries the file name).
        clicked_caption = evt.value['caption']
        return os.path.join(base_dir, clicked_caption)
    return handler
# Pre-load the example galleries for each image-input tab. Only the
# super-resolution examples are loaded with stack=True.
m2i_examples, sr_examples, dn_examples, seg_examples, cls_examples = (
    load_examples(M2I_EXAMPLE_IMG_DIR),
    load_examples(SR_EXAMPLE_IMG_DIR, stack=True),
    load_examples(DN_EXAMPLE_IMG_DIR),
    load_examples(SEG_EXAMPLE_IMG_DIR),
    load_examples(CLS_EXAMPLE_IMG_DIR),
)
# --- UI Builders ---
# Top-level Gradio app: a header row (logo + title) followed by six task tabs.
# Each tab declares its components first and then wires its events.
# NOTE(review): this file had lost its indentation; the container nesting below
# (which Row/Column each component sits in) was reconstructed from statement
# order - confirm against the rendered layout.
# The Markdown string bodies are intentionally left at column 0 so the string
# literals stay byte-identical.
with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    with gr.Row():
        gr.Image(value=LOGO_PATH, width=300, height=200, container=False, interactive=False, show_download_button=False, show_fullscreen_button=False)
        gr.Markdown(f"# {MODEL_TITLE}\n{MODEL_DESCRIPTION}")
    with gr.Tabs():
        # --- TAB 1: Text-to-Image ---
        with gr.Tab("Text-to-Image Generation", id="txt2img"):
            # Third output of generate_t2i; fed back to update_gallery_color on recolor.
            t2i_raw_state = gr.State(None) # Stores list of arrays
            gr.Markdown("""
### Instructions
1. Select a desired prompt from the dropdown menu.
2. Adjust the 'Inference Steps' slider to control generation quality.
3. Click the 'Generate' button to create a new image.
4. Explore the 'Examples' gallery; clicking an image will load its prompt.
**Notice:** This model currently supports 3566 prompt categories. However, data for many cell structures and lines is still lacking. We welcome data source contributions to improve the model.
""")
            with gr.Row(variant="panel"):
                # Left column: generation inputs.
                with gr.Column(scale=1, min_width=350):
                    t2i_prompt = gr.Dropdown(choices=T2I_PROMPTS, value=T2I_PROMPTS[0], label="Search or Type a Prompt", filterable=True, allow_custom_value=True)
                    t2i_steps = gr.Slider(10, 200, 50, step=1, label="Inference Steps")
                    # Added: Number of Images Slider
                    t2i_num_images = gr.Slider(1, 9, 3, step=1, label="Number of Images")
                    # Height/width controls are currently disabled:
                    # with gr.Row():
                    # t2i_height = gr.Slider(512, 1024, value=512, step=64, label="Height")
                    # t2i_width = gr.Slider(512, 1024, value=512, step=64, label="Width")
                    t2i_btn = gr.Button("Generate", variant="primary")
                # Right column: generated gallery plus pseudocolor/download controls.
                with gr.Column(scale=2):
                    # Changed: Image to Gallery
                    t2i_gallery_out = gr.Gallery(label="Generated Images", columns=3, height="auto", show_download_button=False, show_share_button=False)
                    with gr.Row(equal_height=True):
                        with gr.Column(scale=2):
                            t2i_color = gr.Dropdown(choices=COLOR_MAPS, value="Green (GFP)", label="Pseudocolor (Adjust after generation)")
                        with gr.Column(scale=2):
                            t2i_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False)
                        with gr.Column(scale=1):
                            t2i_dl = gr.DownloadButton(label="Download All (.zip)")
            t2i_gal = gr.Gallery(value=t2i_examples, label="Examples", columns=6, height="auto")
            t2i_btn.click(generate_t2i, [t2i_prompt, t2i_steps, t2i_num_images, t2i_color], [t2i_gallery_out, t2i_dl, t2i_raw_state, t2i_colorbar])
            # Disabled variant of the click handler that also passes height/width:
            # t2i_btn.click(
            # generate_t2i,
            # inputs=[t2i_prompt, t2i_steps, t2i_num_images, t2i_color, t2i_height, t2i_width],
            # outputs=[t2i_gallery_out, t2i_dl, t2i_raw_state, t2i_colorbar]
            # )
            # Reuse update_gallery_color since state is now a list
            t2i_color.change(update_gallery_color, [t2i_raw_state, t2i_color], [t2i_gallery_out, t2i_colorbar])
            # Selecting an example writes get_gallery_selection's result into the prompt dropdown.
            t2i_gal.select(fn=get_gallery_selection, inputs=None, outputs=t2i_prompt)
        # --- TAB 2: Super-Resolution ---
        with gr.Tab("Super-Resolution", id="super_res"):
            # Stores: Input (Average Projection) and Output
            sr_input_state = gr.State(None)
            sr_raw_state = gr.State(None)
            gr.Markdown("""
### Instructions
1. Upload a low-resolution 9-channel TIF stack, or select one from the examples.
2. Select a 'Super-Resolution Model' from the dropdown.
3. Enter a descriptive 'Prompt' related to the image content (e.g., 'CCPs of COS-7').
4. Adjust 'Inference Steps' and 'Seed' as needed.
5. Click 'Generate Super-Resolution' to process the image.
**Notice:** This model was trained on the **BioSR** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results.
""")
            with gr.Row(variant="panel"):
                with gr.Column(scale=1, min_width=350):
                    sr_file = gr.File(label="Upload 9-Channel TIF Stack", file_types=['.tif', '.tiff'])
                    # Defaults to the last key of SR_CONTROLNET_MODELS.
                    sr_model = gr.Dropdown(choices=list(SR_CONTROLNET_MODELS.keys()), value=list(SR_CONTROLNET_MODELS.keys())[-1], label="Model")
                    # Read-only for the user; written by update_sr_settings when the model changes.
                    sr_prompt = gr.Textbox(label="Prompt", value="F-actin of COS-7", interactive=False)
                    sr_steps = gr.Slider(5, 50, 10, step=1, label="Steps")
                    sr_seed = gr.Number(label="Seed", value=42)
                    sr_btn = gr.Button("Generate", variant="primary")
                with gr.Column(scale=2):
                    with gr.Row():
                        # Input Display now shows Pseudocolor
                        sr_in_disp = gr.Image(label="Input (Avg Projection)", type="pil", interactive=False, show_download_button=False)
                        sr_out_disp = gr.Image(label="Output", type="pil", interactive=False, show_download_button=False)
                    with gr.Row(equal_height=True):
                        with gr.Column(scale=2):
                            sr_color = gr.Dropdown(choices=COLOR_MAPS, value="Red Hot", label="Pseudocolor")
                        with gr.Column(scale=2):
                            sr_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False)
                        with gr.Column(scale=1):
                            sr_dl = gr.DownloadButton(label="Download Raw Output (.tif)")
            sr_gal = gr.Gallery(value=sr_examples, label="Examples", columns=6, height="auto")
            # Superseded by update_sr_settings, which also sets the pseudocolor:
            # sr_model.change(update_sr_prompt, sr_model, sr_prompt)
            sr_model.change(update_sr_settings, inputs=sr_model, outputs=[sr_prompt, sr_color])
            # Run returns both input/output states and displays
            sr_btn.click(run_super_resolution, [sr_file, sr_model, sr_prompt, sr_steps, sr_seed, sr_color], [sr_in_disp, sr_out_disp, sr_dl, sr_input_state, sr_raw_state, sr_colorbar])
            # Change color updates both displays
            sr_color.change(update_pair_color, [sr_input_state, sr_raw_state, sr_color], [sr_in_disp, sr_out_disp, sr_colorbar])
            # Selecting an example puts its path (example dir + caption) into the file input.
            sr_gal.select(fn=create_path_selector(SR_EXAMPLE_IMG_DIR), inputs=None, outputs=sr_file)
        # --- TAB 3: Denoising ---
        with gr.Tab("Denoising", id="denoising"):
            # Stores: Input (Noisy) and Output
            dn_input_state = gr.State(None)
            dn_raw_state = gr.State(None)
            gr.Markdown("""
### Instructions
1. Upload a noisy single-channel image, or select one from the examples.
2. Select the 'Image Type' from the dropdown to provide context for the model.
3. Adjust 'Inference Steps' and 'Seed' as needed.
4. Click 'Denoise Image' to reduce the noise.
**Notice:** This model was trained on the **FMD** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results.
""")
            with gr.Row(variant="panel"):
                with gr.Column(scale=1, min_width=350):
                    dn_img = gr.Image(type="numpy", label="Upload Noisy Image", image_mode="L")
                    dn_type = gr.Dropdown(choices=list(DN_PROMPT_RULES.keys()), value='MICE', label="Image Type")
                    dn_steps = gr.Slider(5, 50, 10, step=1, label="Steps")
                    dn_seed = gr.Number(label="Seed", value=42)
                    dn_btn = gr.Button("Denoise", variant="primary")
                with gr.Column(scale=2):
                    with gr.Row():
                        # Input Display now shows Pseudocolor
                        dn_orig = gr.Image(label="Original", type="pil", interactive=False, show_download_button=False)
                        dn_out = gr.Image(label="Denoised", type="pil", interactive=False, show_download_button=False)
                    with gr.Row(equal_height=True):
                        with gr.Column(scale=2):
                            dn_color = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor")
                        with gr.Column(scale=2):
                            dn_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False)
                        with gr.Column(scale=1):
                            dn_dl = gr.DownloadButton(label="Download Raw Output (.tif)")
            dn_gal = gr.Gallery(value=dn_examples, label="Examples", columns=6, height="auto")
            dn_btn.click(run_denoising, [dn_img, dn_type, dn_steps, dn_seed, dn_color], [dn_orig, dn_out, dn_dl, dn_input_state, dn_raw_state, dn_colorbar])
            # Recoloring re-renders both the input and output displays from the stored states.
            dn_color.change(update_pair_color, [dn_input_state, dn_raw_state, dn_color], [dn_orig, dn_out, dn_colorbar])
            # Selecting an example sends its path (example dir + caption) to the image input.
            dn_gal.select(fn=create_path_selector(DN_EXAMPLE_IMG_DIR), inputs=None, outputs=dn_img)
        # --- TAB 4: Mask-to-Image ---
        with gr.Tab("Mask-to-Image", id="mask2img"):
            # Fourth output of run_mask_to_image_generation; fed to update_gallery_color on recolor.
            m2i_raw_state = gr.State(None)
            gr.Markdown("""
### Instructions
1. Upload a single-channel segmentation mask (`.tif` file), or select one from the examples gallery below.
2. Enter the corresponding 'Cell Type' (e.g., 'CoNSS', 'HeLa') to create the prompt.
3. Select how many sample images you want to generate.
4. Adjust 'Inference Steps' and 'Seed' as needed.
5. Click 'Generate Training Samples' to start the process.
6. The 'Generated Samples' will appear in the main gallery, with the 'Input Mask' shown below for reference.
**Notice:** This model was trained on the **2018 Data Science Bowl** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results.
""")
            with gr.Row(variant="panel"):
                with gr.Column(scale=1, min_width=350):
                    m2i_file = gr.File(label="Upload Mask (.tif)", file_types=['.tif', '.tiff'])
                    # Changed: Default value to HeLa
                    m2i_type = gr.Textbox(label="Cell Type", value="HeLa", placeholder="e.g., HeLa")
                    m2i_num = gr.Slider(1, 10, 5, step=1, label="Count")
                    m2i_steps = gr.Slider(5, 50, 10, step=1, label="Steps")
                    m2i_seed = gr.Number(label="Seed", value=42)
                    m2i_btn = gr.Button("Generate", variant="primary")
                with gr.Column(scale=2):
                    m2i_gal_out = gr.Gallery(label="Generated Samples", columns=5, height="auto")
                    with gr.Row(equal_height=True):
                        with gr.Column(scale=2):
                            m2i_color = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor")
                        with gr.Column(scale=2):
                            m2i_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False)
                        with gr.Column(scale=1):
                            m2i_dl = gr.DownloadButton(label="Download ZIP")
                    m2i_in_disp = gr.Image(label="Input Mask", type="pil", interactive=False, show_download_button=False)
            m2i_gal = gr.Gallery(value=m2i_examples, label="Examples", columns=6, height="auto")
            m2i_btn.click(run_mask_to_image_generation, [m2i_file, m2i_type, m2i_num, m2i_steps, m2i_seed, m2i_color], [m2i_in_disp, m2i_gal_out, m2i_dl, m2i_raw_state, m2i_colorbar])
            m2i_color.change(update_gallery_color, [m2i_raw_state, m2i_color], [m2i_gal_out, m2i_colorbar])
            # Selecting an example puts its path (example dir + caption) into the file input.
            m2i_gal.select(fn=create_path_selector(M2I_EXAMPLE_IMG_DIR), inputs=None, outputs=m2i_file)
        # --- TAB 5: Cell Segmentation ---
        with gr.Tab("Cell Segmentation", id="segmentation"):
            gr.Markdown("""
### Instructions
1. Upload a single-channel image for segmentation, or select one from the examples.
2. Select a 'Segmentation Model' from the dropdown menu.
3. Set the expected 'Diameter' of the cells in pixels. Set to 0 to let the model automatically estimate it.
4. Adjust 'Flow Threshold' and 'Cell Probability Threshold' for finer control.
5. Click 'Segment Cells'. The result will be shown as a dark red overlay on the original image.
**Notice:** This model was trained on the **2018 Data Science Bowl** dataset. If your data's characteristics differ significantly, please consider fine-tuning the model using our project on GitHub for optimal results.
""")
            with gr.Row(variant="panel"):
                with gr.Column(scale=1, min_width=350):
                    seg_img = gr.Image(type="numpy", label="Upload Image", image_mode="L")
                    # Defaults to the first key of SEG_MODELS.
                    seg_model = gr.Dropdown(choices=list(SEG_MODELS.keys()), value=list(SEG_MODELS.keys())[0], label="Model")
                    seg_diam = gr.Number(label="Diameter (0=auto)", value=30)
                    seg_flow = gr.Slider(0.0, 3.0, 0.4, step=0.1, label="Flow Thresh")
                    seg_prob = gr.Slider(-6.0, 6.0, 0.0, step=0.5, label="Prob Thresh")
                    seg_btn = gr.Button("Segment", variant="primary")
                with gr.Column(scale=2):
                    with gr.Row():
                        seg_orig = gr.Image(label="Original", type="pil", interactive=False, show_download_button=False)
                        seg_out = gr.Image(label="Overlay", type="pil", interactive=False, show_download_button=False)
            seg_gal = gr.Gallery(value=seg_examples, label="Examples", columns=6, height="auto")
            seg_btn.click(run_segmentation, [seg_img, seg_model, seg_diam, seg_flow, seg_prob], [seg_orig, seg_out])
            # Selecting an example sends its path (example dir + caption) to the image input.
            seg_gal.select(fn=create_path_selector(SEG_EXAMPLE_IMG_DIR), inputs=None, outputs=seg_img)
        # --- TAB 6: Classification ---
        with gr.Tab("Classification", id="classification"):
            gr.Markdown("""
### Instructions
1. Upload a single-channel image for classification, or select an example.
2. Select a pre-trained 'Classification Model' from the dropdown menu.
3. Click 'Classify Image' to view the prediction probabilities for each class.
**Note:** The models provided are ResNet50 trained on the 2D HeLa dataset.
""")
            with gr.Row(variant="panel"):
                with gr.Column(scale=1, min_width=350):
                    cls_img = gr.Image(type="numpy", label="Upload Image", image_mode="L")
                    # Defaults to the first key of CLS_MODEL_PATHS.
                    cls_model = gr.Dropdown(choices=list(CLS_MODEL_PATHS.keys()), value=list(CLS_MODEL_PATHS.keys())[0], label="Model")
                    cls_btn = gr.Button("Classify", variant="primary")
                with gr.Column(scale=2):
                    cls_orig = gr.Image(label="Input", type="pil", interactive=False, show_download_button=False)
                    cls_res = gr.Label(label="Results", num_top_classes=10)
            cls_gal = gr.Gallery(value=cls_examples, label="Examples", columns=6, height="auto")
            cls_btn.click(run_classification, [cls_img, cls_model], [cls_orig, cls_res])
            # Selecting an example sends its path (example dir + caption) to the image input.
            cls_gal.select(fn=create_path_selector(CLS_EXAMPLE_IMG_DIR), inputs=None, outputs=cls_img)
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()