"""
PB Cell Generator - Synthetic Blood Cell Image Generation
With ZeroGPU support for fast inference.
"""

import os
import random

# `spaces` is imported before torch, as in the original file (ZeroGPU integration).
import spaces
import torch
import gradio as gr

# Cell type configurations with V2 prompts.
# NOTE(review): prompts are kept verbatim (including "A Eosinophil" / "A Immature
# Granulocyte"); presumably they mirror the fine-tuning captions - confirm against
# the training data before editing their wording.
CELL_TYPES = {
    "Neutrophil": "A Neutrophil cell with intermediate size, low nucleocytoplasmic ratio, segmented nucleus, condensed chromatin, nucleoli absent, wide azurophilic cytoplasm, azurophil granulation.",
    "Lymphocyte": "A Lymphocyte cell with small size, high nucleocytoplasmic ratio, round nucleus, condensed chromatin, nucleoli absent, scant basophilic cytoplasm.",
    "Monocyte": "A Monocyte cell with large size, moderate nucleocytoplasmic ratio, irregular kidney-shaped nucleus, open chromatin, nucleoli absent, wide grayish cytoplasm, fine granulation, vacuoles.",
    "Eosinophil": "A Eosinophil cell with intermediate size, low nucleocytoplasmic ratio, bilobed nucleus, condensed chromatin, nucleoli absent, wide eosinophilic cytoplasm, eosinophilic granulation.",
    "Basophil": "A Basophil cell with intermediate size, low nucleocytoplasmic ratio, segmented nucleus, condensed chromatin, nucleoli absent, wide basophilic cytoplasm, coarse basophilic granulation.",
    "Platelet": "A Platelet cell with small size, anucleate, light basophilic cytoplasm, fine azurophilic granulation.",
    "Erythroblast": "A single Erythroblast cell with small size, high nucleocytoplasmic ratio, round nucleus, condensed chromatin, nucleoli absent, scant basophilic cytoplasm. One cell only.",
    "Immature Granulocyte (IG)": "A Immature Granulocyte cell with large size, low nucleocytoplasmic ratio, round to oval nucleus, fine open chromatin, nucleoli present, wide basophilic cytoplasm, azurophil granulation.",
}
CELL_TYPE_LIST = list(CELL_TYPES.keys())

# Custom CSS for soft red theme
custom_css = """
.primary-btn {
    background: linear-gradient(135deg, #e57373 0%, #d32f2f 100%) !important;
    border: none !important;
}
.primary-btn:hover {
    background: linear-gradient(135deg, #ef5350 0%, #c62828 100%) !important;
}
.gradio-container {
    max-width: 900px !important;
    margin: auto !important;
}
"""

# Global pipeline cache; only ever holds a fully configured pipeline.
pipe = None


def load_pipeline():
    """Return the cached Stable Diffusion pipeline, building it on first use.

    Loads the SD 2.1 base pipeline, replaces its UNet with the fine-tuned one
    from ``esab/pbcell-sd21-v2``, switches to a DPM-Solver multistep scheduler
    and moves the pipeline to CUDA. Reads the optional ``HF_TOKEN`` environment
    variable for authenticated downloads.

    The module-level ``pipe`` is assigned only after every step succeeded, so a
    failure part-way through (e.g. the UNet download) does not leave the plain
    base model cached and silently reused by later calls.

    NOTE(review): this runs inside the @spaces.GPU-decorated ``generate`` call,
    so the first request spends its GPU time budget on downloads. ZeroGPU's
    documented pattern is to load at module level - confirm whether this cache
    actually survives between GPU calls on ZeroGPU.
    """
    global pipe
    if pipe is not None:
        return pipe

    from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline, UNet2DConditionModel

    hf_token = os.environ.get("HF_TOKEN")

    print("Loading base model...")
    pipeline = StableDiffusionPipeline.from_pretrained(
        "sd2-community/stable-diffusion-2-1",
        torch_dtype=torch.float16,
        token=hf_token,
    )

    print("Loading fine-tuned UNet...")
    pipeline.unet = UNet2DConditionModel.from_pretrained(
        "esab/pbcell-sd21-v2",
        subfolder="unet",
        torch_dtype=torch.float16,
        token=hf_token,
    )
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to("cuda")

    # Publish to the cache only once the pipeline is complete.
    pipe = pipeline
    print("Pipeline ready!")
    return pipe


# Default negative prompt to avoid common issues.
# Note: training data from CellaVision DM96/9600 has a pallid yellow background,
# which is desired - so the negative prompt targets white / washed-out
# backgrounds rather than the yellow tint.
DEFAULT_NEGATIVE_PROMPT = (
    "white background, washed out, overexposed, "
    "low contrast, blurry, out of focus, multiple cells, overlapping cells, "
    "artifacts, noise, low quality, deformed"
)


def _resolve_prompt(cell_type, custom_prompt):
    """Return the stripped custom prompt if non-blank, else the default prompt for
    ``cell_type`` (falling back to the Neutrophil prompt for unknown types)."""
    custom = (custom_prompt or "").strip()
    return custom or CELL_TYPES.get(cell_type, CELL_TYPES["Neutrophil"])


def _resolve_seed(seed):
    """Return a concrete seed: ``int(seed)`` when >= 0, otherwise a fresh random one.

    ``None`` (the user cleared the Seed box) is treated like -1 instead of
    raising ``TypeError`` from ``int(None)``.
    """
    if seed is None or int(seed) < 0:
        return random.randint(0, 2**32 - 1)
    return int(seed)


def _resolve_negative_prompt(negative_prompt, use_negative_prompt):
    """Return None when disabled; otherwise the stripped user text, or
    DEFAULT_NEGATIVE_PROMPT when that text is blank."""
    if not use_negative_prompt:
        return None
    return (negative_prompt or "").strip() or DEFAULT_NEGATIVE_PROMPT


@spaces.GPU(duration=30)
def generate(cell_type, custom_prompt, cfg, steps, seed, negative_prompt, use_negative_prompt):
    """Generate one 512x512 blood cell image.

    GPU is allocated for this function via @spaces.GPU.

    Args:
        cell_type: Key of CELL_TYPES; unknown keys fall back to "Neutrophil".
        custom_prompt: Optional prompt override, used only when non-blank.
        cfg: Classifier-free guidance scale.
        steps: Number of inference (denoising) steps.
        seed: RNG seed; a negative or empty value means a new random seed.
        negative_prompt: Negative prompt text; DEFAULT_NEGATIVE_PROMPT is used
            when it is blank.
        use_negative_prompt: Whether to pass a negative prompt at all.

    Returns:
        The first image of the pipeline result.
    """
    pipeline = load_pipeline()

    prompt = _resolve_prompt(cell_type, custom_prompt)
    generator = torch.Generator(device="cuda").manual_seed(_resolve_seed(seed))
    neg_prompt = _resolve_negative_prompt(negative_prompt, use_negative_prompt)

    result = pipeline(
        prompt=prompt,
        negative_prompt=neg_prompt,
        height=512,
        width=512,
        num_inference_steps=int(steps),
        guidance_scale=float(cfg),
        generator=generator,
    )
    return result.images[0]


# Build interface
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="red")) as demo:
    gr.Markdown(
        """
# PB Cell Generator

Generate synthetic peripheral blood cell images using a fine-tuned Stable Diffusion 2.1 model
trained on the PBC dataset with detailed morphological captions.

**Model:** [esab/pbcell-sd21-v2](https://huggingface.co/esab/pbcell-sd21-v2) | **FID Score:** 79.39 | **Powered by ZeroGPU**
"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            cell_dropdown = gr.Dropdown(
                choices=CELL_TYPE_LIST,
                value="Neutrophil",
                label="Cell Type",
                info="Select the type of blood cell to generate",
            )
            custom_box = gr.Textbox(
                label="Custom Prompt (optional)",
                placeholder="Leave empty to use the default morphological prompt for the selected cell type...",
                lines=2,
                info="Override the default prompt with your own description",
            )
            seed_box = gr.Number(
                value=-1,
                label="Seed",
                info="Random seed for reproducibility. Use -1 for random generation each time.",
                precision=0,
            )

            with gr.Accordion("Advanced Settings", open=False):
                cfg_slider = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=8.5,
                    step=0.5,
                    label="Guidance Scale (CFG)",
                    info="Controls how closely the image follows the prompt. Higher values = stronger adherence to prompt but may reduce quality. Recommended: 7-9.",
                )
                steps_slider = gr.Slider(
                    minimum=10,
                    maximum=50,
                    value=20,
                    step=5,
                    label="Inference Steps",
                    info="Number of denoising steps. More steps = higher quality but slower generation. Recommended: 20-30.",
                )
                gr.Markdown("---")
                use_negative_checkbox = gr.Checkbox(
                    value=False,
                    label="Use Negative Prompt",
                    info="Enable to steer generation away from unwanted characteristics (e.g., white backgrounds, blur)",
                )
                negative_prompt_box = gr.Textbox(
                    value=DEFAULT_NEGATIVE_PROMPT,
                    label="Negative Prompt",
                    placeholder="Describe what you DON'T want in the image...",
                    lines=2,
                    info="The model will avoid these characteristics. Helps prevent pale/washed out backgrounds and blurry images.",
                )

            btn = gr.Button("Generate Cell Image", variant="primary", elem_classes=["primary-btn"])

        with gr.Column(scale=1):
            output_img = gr.Image(label="Generated Cell", show_label=True)

    gr.Markdown(
        """
---
### Supported Cell Types

| Cell Type | Description |
|-----------|-------------|
| **Neutrophil** | Segmented nucleus, azurophilic granules |
| **Lymphocyte** | Small cell, high N/C ratio, round nucleus |
| **Monocyte** | Large cell, kidney-shaped nucleus, vacuoles |
| **Eosinophil** | Bilobed nucleus, eosinophilic granules |
| **Basophil** | Segmented nucleus, basophilic granules |
| **Platelet** | Small anucleate cell fragments |
| **Erythroblast** | Nucleated red blood cell precursor |
| **Immature Granulocyte** | Large cell, fine chromatin, nucleoli present |

*Note: Erythroblast images may occasionally show multiple cells due to training data characteristics.*
"""
    )

    btn.click(
        fn=generate,
        inputs=[
            cell_dropdown,
            custom_box,
            cfg_slider,
            steps_slider,
            seed_box,
            negative_prompt_box,
            use_negative_checkbox,
        ],
        outputs=output_img,
    )

demo.launch()