esab's picture
Upload folder using huggingface_hub
55c0756 verified
"""
PB Cell Generator - Synthetic Blood Cell Image Generation
With ZeroGPU support for fast inference.
"""
import os
import spaces
import torch
import gradio as gr
# Default ("V2") text prompt for every selectable cell type, keyed by the
# label shown in the UI dropdown. Each prompt spells out the cell's
# morphology: size, nucleocytoplasmic ratio, nucleus, chromatin, nucleoli,
# cytoplasm and granulation.
# NOTE: the wording is kept verbatim (including "A Eosinophil" / "A Immature
# Granulocyte"); presumably it mirrors the fine-tuning captions, so do not
# "fix" the grammar without checking. The Erythroblast prompt adds
# "single ... One cell only." -- presumably to discourage multi-cell output.
CELL_TYPES = {
    "Neutrophil": "A Neutrophil cell with intermediate size, low nucleocytoplasmic ratio, segmented nucleus, condensed chromatin, nucleoli absent, wide azurophilic cytoplasm, azurophil granulation.",
    "Lymphocyte": "A Lymphocyte cell with small size, high nucleocytoplasmic ratio, round nucleus, condensed chromatin, nucleoli absent, scant basophilic cytoplasm.",
    "Monocyte": "A Monocyte cell with large size, moderate nucleocytoplasmic ratio, irregular kidney-shaped nucleus, open chromatin, nucleoli absent, wide grayish cytoplasm, fine granulation, vacuoles.",
    "Eosinophil": "A Eosinophil cell with intermediate size, low nucleocytoplasmic ratio, bilobed nucleus, condensed chromatin, nucleoli absent, wide eosinophilic cytoplasm, eosinophilic granulation.",
    "Basophil": "A Basophil cell with intermediate size, low nucleocytoplasmic ratio, segmented nucleus, condensed chromatin, nucleoli absent, wide basophilic cytoplasm, coarse basophilic granulation.",
    "Platelet": "A Platelet cell with small size, anucleate, light basophilic cytoplasm, fine azurophilic granulation.",
    "Erythroblast": "A single Erythroblast cell with small size, high nucleocytoplasmic ratio, round nucleus, condensed chromatin, nucleoli absent, scant basophilic cytoplasm. One cell only.",
    "Immature Granulocyte (IG)": "A Immature Granulocyte cell with large size, low nucleocytoplasmic ratio, round to oval nucleus, fine open chromatin, nucleoli present, wide basophilic cytoplasm, azurophil granulation.",
}

# Dropdown choices, in the dict's insertion order.
CELL_TYPE_LIST = [*CELL_TYPES]
# Custom CSS for the soft red theme: elements carrying the "primary-btn"
# class get a red gradient background (a darker gradient on hover) with no
# border, and the Gradio container is capped at 900px wide and centred.
custom_css = """
.primary-btn {
background: linear-gradient(135deg, #e57373 0%, #d32f2f 100%) !important;
border: none !important;
}
.primary-btn:hover {
background: linear-gradient(135deg, #ef5350 0%, #c62828 100%) !important;
}
.gradio-container {
max-width: 900px !important;
margin: auto !important;
}
"""
# Process-wide pipeline cache; stays None until load_pipeline() has fully
# succeeded once.
pipe = None


def load_pipeline():
    """Return the cached Stable Diffusion pipeline, building it on first use.

    On the first call this loads the fine-tuned UNet from
    ``esab/pbcell-sd21-v2`` (subfolder ``unet``). It then assembles it with
    the remaining components of ``sd2-community/stable-diffusion-2-1``, all
    in float16. It swaps the scheduler for ``DPMSolverMultistepScheduler``
    (built from the original scheduler config) and moves the pipeline to
    ``"cuda"`` explicitly. It is meant to be called from inside a
    ``@spaces.GPU`` function (as ``generate`` does), where CUDA is available.

    The module-level ``pipe`` is assigned only after every step has
    succeeded. A failure part-way through (e.g. a download error) therefore
    leaves the cache empty and the next call retries, instead of caching a
    half-configured pipeline.

    The optional ``HF_TOKEN`` environment variable is forwarded to the Hub
    downloads.

    Returns:
        The ready-to-use ``StableDiffusionPipeline`` on CUDA.
    """
    global pipe
    if pipe is not None:
        return pipe

    # Imported lazily so the heavy dependency is only loaded when needed.
    from diffusers import (
        DPMSolverMultistepScheduler,
        StableDiffusionPipeline,
        UNet2DConditionModel,
    )

    hf_token = os.environ.get("HF_TOKEN")

    print("Loading fine-tuned UNet...")
    unet = UNet2DConditionModel.from_pretrained(
        "esab/pbcell-sd21-v2",
        subfolder="unet",
        torch_dtype=torch.float16,
        token=hf_token,
    )

    print("Loading base model...")
    # Passing ``unet=`` overrides that component, so the base checkpoint's
    # own UNet is never loaded just to be replaced.
    pipeline = StableDiffusionPipeline.from_pretrained(
        "sd2-community/stable-diffusion-2-1",
        unet=unet,
        torch_dtype=torch.float16,
        token=hf_token,
    )
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to("cuda")

    # Publish to the cache only once the pipeline is fully configured.
    pipe = pipeline
    print("Pipeline ready!")
    return pipe
# Fallback negative prompt, used when negative prompting is enabled but the
# user's negative-prompt text is empty. It discourages common failure modes:
# white or washed-out backgrounds, low contrast, blur, several/overlapping
# cells and generic artifacts.
# Note: the CellaVision DM96/9600 training images have a pallid yellow
# background, which is the desired look; only white/washed-out backgrounds
# are steered away from here.
DEFAULT_NEGATIVE_PROMPT = ", ".join((
    "white background",
    "washed out",
    "overexposed",
    "low contrast",
    "blurry",
    "out of focus",
    "multiple cells",
    "overlapping cells",
    "artifacts",
    "noise",
    "low quality",
    "deformed",
))
def _resolve_seed(seed):
    """Turn the UI seed value into a concrete non-negative integer seed.

    Args:
        seed: Value from the "Seed" number box. ``None`` (the box may deliver
            this when the user clears it) or any negative number requests a
            fresh random seed; any other value is converted with ``int()``.

    Returns:
        The seed for ``torch.Generator.manual_seed``. Random draws lie in
        ``[0, 2**32 - 1]``.
    """
    import random

    if seed is None or int(seed) < 0:
        return random.randint(0, 2**32 - 1)
    return int(seed)


@spaces.GPU(duration=30)
def generate(cell_type, custom_prompt, cfg, steps, seed, negative_prompt, use_negative_prompt):
    """Generate one 512x512 blood-cell image. A GPU is allocated for this call.

    Args:
        cell_type: Key into ``CELL_TYPES``; unknown keys fall back to the
            Neutrophil prompt.
        custom_prompt: Optional free-text prompt; if non-blank it replaces the
            cell type's default prompt (after stripping whitespace).
        cfg: Classifier-free guidance scale, passed as ``guidance_scale``.
        steps: Number of denoising steps, passed as ``num_inference_steps``.
        seed: Generator seed; ``None`` or negative means a fresh random seed.
            The seed actually used is printed to the logs so a random result
            can be reproduced.
        negative_prompt: Negative prompt text; if blank,
            ``DEFAULT_NEGATIVE_PROMPT`` is used instead.
        use_negative_prompt: When falsy, no negative prompt is sent at all.

    Returns:
        The first image of the pipeline output.
    """
    # NOTE(review): on the first call this also downloads/loads the model
    # weights inside the 30 s GPU window -- confirm that fits, or preload the
    # pipeline at import time as ZeroGPU recommends.
    pipeline = load_pipeline()

    if custom_prompt and custom_prompt.strip():
        prompt = custom_prompt.strip()
    else:
        prompt = CELL_TYPES.get(cell_type, CELL_TYPES["Neutrophil"])

    actual_seed = _resolve_seed(seed)
    # Log the seed: with "-1 = random" the user otherwise has no way to
    # reproduce a result they liked.
    print(f"Generating with seed {actual_seed}")
    generator = torch.Generator(device="cuda").manual_seed(actual_seed)

    neg_prompt = None
    if use_negative_prompt:
        if negative_prompt and negative_prompt.strip():
            neg_prompt = negative_prompt.strip()
        else:
            neg_prompt = DEFAULT_NEGATIVE_PROMPT

    result = pipeline(
        prompt=prompt,
        negative_prompt=neg_prompt,
        height=512,
        width=512,
        num_inference_steps=int(steps),
        guidance_scale=float(cfg),
        generator=generator,
    )
    return result.images[0]
# Build interface: a two-column layout with the controls on the left and
# the generated image on the right, followed by a reference table of the
# supported cell types.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="red")) as demo:
    # Blank line before "**Model:**" keeps it out of the description paragraph.
    gr.Markdown("""
# PB Cell Generator
Generate synthetic peripheral blood cell images using a fine-tuned Stable Diffusion 2.1 model
trained on the PBC dataset with detailed morphological captions.

**Model:** [esab/pbcell-sd21-v2](https://huggingface.co/esab/pbcell-sd21-v2) | **FID Score:** 79.39 | **Powered by ZeroGPU**
""")
    with gr.Row():
        # Left column: inputs and the Generate button.
        with gr.Column(scale=1):
            cell_dropdown = gr.Dropdown(
                choices=CELL_TYPE_LIST,
                value="Neutrophil",
                label="Cell Type",
                info="Select the type of blood cell to generate"
            )
            custom_box = gr.Textbox(
                label="Custom Prompt (optional)",
                placeholder="Leave empty to use the default morphological prompt for the selected cell type...",
                lines=2,
                info="Override the default prompt with your own description"
            )
            seed_box = gr.Number(
                value=-1,
                label="Seed",
                info="Random seed for reproducibility. Use -1 for random generation each time.",
                precision=0
            )
            with gr.Accordion("Advanced Settings", open=False):
                cfg_slider = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=8.5,
                    step=0.5,
                    label="Guidance Scale (CFG)",
                    info="Controls how closely the image follows the prompt. Higher values = stronger adherence to prompt but may reduce quality. Recommended: 7-9."
                )
                steps_slider = gr.Slider(
                    minimum=10,
                    maximum=50,
                    value=20,
                    step=5,
                    label="Inference Steps",
                    info="Number of denoising steps. More steps = higher quality but slower generation. Recommended: 20-30."
                )
                gr.Markdown("---")
                use_negative_checkbox = gr.Checkbox(
                    value=False,
                    label="Use Negative Prompt",
                    info="Enable to steer generation away from unwanted characteristics (e.g., white backgrounds, blur)"
                )
                # The pallid (pale yellow) background is the desired look, so
                # the help text names white/washed-out backgrounds, matching
                # what DEFAULT_NEGATIVE_PROMPT actually targets.
                negative_prompt_box = gr.Textbox(
                    value=DEFAULT_NEGATIVE_PROMPT,
                    label="Negative Prompt",
                    placeholder="Describe what you DON'T want in the image...",
                    lines=2,
                    info="The model will avoid these characteristics. Helps prevent white/washed out backgrounds and blurry images."
                )
            btn = gr.Button("Generate Cell Image", variant="primary", elem_classes=["primary-btn"])
        # Right column: the generated image.
        with gr.Column(scale=1):
            output_img = gr.Image(label="Generated Cell", show_label=True)
    # A blank line must separate the table from the note; otherwise GFM
    # renders the note as an extra table row.
    gr.Markdown("""
---
### Supported Cell Types
| Cell Type | Description |
|-----------|-------------|
| **Neutrophil** | Segmented nucleus, azurophilic granules |
| **Lymphocyte** | Small cell, high N/C ratio, round nucleus |
| **Monocyte** | Large cell, kidney-shaped nucleus, vacuoles |
| **Eosinophil** | Bilobed nucleus, eosinophilic granules |
| **Basophil** | Segmented nucleus, basophilic granules |
| **Platelet** | Small anucleate cell fragments |
| **Erythroblast** | Nucleated red blood cell precursor |
| **Immature Granulocyte** | Large cell, fine chromatin, nucleoli present |

*Note: Erythroblast images may occasionally show multiple cells due to training data characteristics.*
""")
    # Input order must match generate()'s positional parameters.
    btn.click(
        fn=generate,
        inputs=[cell_dropdown, custom_box, cfg_slider, steps_slider, seed_box, negative_prompt_box, use_negative_checkbox],
        outputs=output_img
    )

# Launch only when run as a script (as Spaces does with `python app.py`),
# so importing this module does not start a server.
if __name__ == "__main__":
    demo.launch()