File size: 8,480 Bytes
fd82f1a
 
ec6e9f7
fd82f1a
 
 
ec6e9f7
 
fd82f1a
 
 
 
 
 
 
55c0756
fd82f1a
 
 
 
 
 
0cfc091
 
1b1ca95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd82f1a
 
 
 
1b1ca95
fd82f1a
 
 
 
0cfc091
 
fd82f1a
 
ec6e9f7
fd82f1a
b5f32bf
ec6e9f7
fd82f1a
 
 
ec6e9f7
fd82f1a
b5f32bf
fd82f1a
ec6e9f7
fd82f1a
 
 
ec6e9f7
fd82f1a
ec6e9f7
fd82f1a
ec6e9f7
fd82f1a
 
 
e9bf074
c4d8798
e9bf074
c4d8798
e9bf074
 
 
 
 
ec6e9f7
e9bf074
ec6e9f7
bd18676
 
b5f32bf
 
 
 
bd18676
 
 
 
 
 
 
b5f32bf
e9bf074
 
 
 
 
b5f32bf
fd82f1a
e9bf074
fd82f1a
 
b5f32bf
 
ec6e9f7
fd82f1a
 
0cfc091
 
 
ec6e9f7
1b1ca95
 
 
 
 
 
 
 
 
b5f32bf
 
1b1ca95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd18676
1b1ca95
bd18676
1b1ca95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9bf074
 
 
c4d8798
e9bf074
c4d8798
e9bf074
 
 
 
 
 
 
 
 
1b1ca95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9bf074
1b1ca95
 
b5f32bf
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
"""
PB Cell Generator - Synthetic Blood Cell Image Generation
With ZeroGPU support for fast inference.
"""

import os
import spaces
import torch
import gradio as gr

# Cell type configurations with V2 prompts.
# Maps each UI cell-type label to the text prompt sent to the diffusion
# pipeline. The keys double as the dropdown choices (via CELL_TYPE_LIST), and
# the "Neutrophil" entry is generate()'s fallback for unknown/empty labels.
# NOTE(review): the prompt wording (including "A Eosinophil" / "A Immature")
# is kept verbatim — presumably it mirrors the fine-tuning captions; confirm
# before "fixing" grammar, since it changes what the model is conditioned on.
CELL_TYPES = {
    "Neutrophil": "A Neutrophil cell with intermediate size, low nucleocytoplasmic ratio, segmented nucleus, condensed chromatin, nucleoli absent, wide azurophilic cytoplasm, azurophil granulation.",
    "Lymphocyte": "A Lymphocyte cell with small size, high nucleocytoplasmic ratio, round nucleus, condensed chromatin, nucleoli absent, scant basophilic cytoplasm.",
    "Monocyte": "A Monocyte cell with large size, moderate nucleocytoplasmic ratio, irregular kidney-shaped nucleus, open chromatin, nucleoli absent, wide grayish cytoplasm, fine granulation, vacuoles.",
    "Eosinophil": "A Eosinophil cell with intermediate size, low nucleocytoplasmic ratio, bilobed nucleus, condensed chromatin, nucleoli absent, wide eosinophilic cytoplasm, eosinophilic granulation.",
    "Basophil": "A Basophil cell with intermediate size, low nucleocytoplasmic ratio, segmented nucleus, condensed chromatin, nucleoli absent, wide basophilic cytoplasm, coarse basophilic granulation.",
    "Platelet": "A Platelet cell with small size, anucleate, light basophilic cytoplasm, fine azurophilic granulation.",
    "Erythroblast": "A single Erythroblast cell with small size, high nucleocytoplasmic ratio, round nucleus, condensed chromatin, nucleoli absent, scant basophilic cytoplasm. One cell only.",
    "Immature Granulocyte (IG)": "A Immature Granulocyte cell with large size, low nucleocytoplasmic ratio, round to oval nucleus, fine open chromatin, nucleoli present, wide basophilic cytoplasm, azurophil granulation.",
}

# Dropdown choices; dicts preserve insertion order, so the UI lists the types
# in the order declared above.
CELL_TYPE_LIST = list(CELL_TYPES.keys())

# Custom CSS for soft red theme.
# Passed to gr.Blocks(css=...): gives elements with the "primary-btn" class
# (the Generate button) a red gradient with a darker hover state, and caps the
# app container at 900px wide, centered.
custom_css = """
.primary-btn {
    background: linear-gradient(135deg, #e57373 0%, #d32f2f 100%) !important;
    border: none !important;
}
.primary-btn:hover {
    background: linear-gradient(135deg, #ef5350 0%, #c62828 100%) !important;
}
.gradio-container {
    max-width: 900px !important;
    margin: auto !important;
}
"""

# Global pipeline cache: None until load_pipeline() builds it on first use,
# then reused by every subsequent generate() call.
pipe = None


def load_pipeline():
    """Return the shared Stable Diffusion pipeline, building it on first use.

    On the first call this loads the base ``sd2-community/stable-diffusion-2-1``
    pipeline in float16. It then replaces that pipeline's UNet with the
    fine-tuned ``esab/pbcell-sd21-v2`` UNet. Next it swaps the scheduler for a
    ``DPMSolverMultistepScheduler`` built from the base scheduler's config.
    Finally it moves the pipeline to ``"cuda"`` explicitly. This function is
    called from inside the ``@spaces.GPU``-decorated ``generate``, which is
    where a GPU is allocated.

    The module-level ``pipe`` cache is populated only after every step has
    succeeded. A failure part-way through (e.g. the UNet download or the CUDA
    move raising) therefore does not leave a half-configured pipeline cached.
    Such a pipeline would still use the base model's UNet or sit on the CPU.
    The next call simply retries.

    NOTE(review): the first call downloads weights while holding the GPU slot
    (``duration=30`` on ``generate``); consider loading at module scope if
    cold starts time out.

    Returns:
        The ready-to-use pipeline on the ``cuda`` device.
    """
    global pipe
    if pipe is not None:
        return pipe

    # Deferred import: diffusers is only needed once the pipeline is built.
    from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler, UNet2DConditionModel

    # None when unset, in which case the Hub is accessed anonymously.
    hf_token = os.environ.get("HF_TOKEN")

    print("Loading base model...")
    pipeline = StableDiffusionPipeline.from_pretrained(
        "sd2-community/stable-diffusion-2-1",
        torch_dtype=torch.float16,
        token=hf_token,
    )

    print("Loading fine-tuned UNet...")
    pipeline.unet = UNet2DConditionModel.from_pretrained(
        "esab/pbcell-sd21-v2",
        subfolder="unet",
        torch_dtype=torch.float16,
        token=hf_token,
    )

    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to("cuda")

    # Publish to the cache only once the pipeline is fully configured.
    pipe = pipeline
    print("Pipeline ready!")
    return pipe


# Fallback negative prompt, used when "Use Negative Prompt" is enabled but the
# negative-prompt text box is left blank. It steers generation away from
# washed-out/white backgrounds, blur, and multi-cell or overlapping crops.
# Note: the CellaVision DM96/9600 training images have a pallid yellow
# background, which is desired — hence "white background" is discouraged here.
DEFAULT_NEGATIVE_PROMPT = ", ".join((
    "white background",
    "washed out",
    "overexposed",
    "low contrast",
    "blurry",
    "out of focus",
    "multiple cells",
    "overlapping cells",
    "artifacts",
    "noise",
    "low quality",
    "deformed",
))


@spaces.GPU(duration=30)
def generate(cell_type, custom_prompt, cfg, steps, seed, negative_prompt, use_negative_prompt):
    """Generate one 512x512 blood-cell image. GPU is allocated for this function.

    Args:
        cell_type: Key of ``CELL_TYPES``. Unknown or empty values fall back to
            the "Neutrophil" prompt.
        custom_prompt: Optional free-text prompt. When non-blank (after
            stripping) it replaces the cell-type prompt.
        cfg: Classifier-free guidance scale; cast to ``float``.
        steps: Number of inference steps; cast to ``int``.
        seed: Integer seed. A negative value, or ``None`` (the Gradio Number
            field was cleared), draws a fresh random seed in ``[0, 2**32 - 1]``.
        negative_prompt: Custom negative prompt. When blank,
            ``DEFAULT_NEGATIVE_PROMPT`` is used instead.
        use_negative_prompt: When falsy, no negative prompt is passed at all.

    Returns:
        The first image of the pipeline output (``result.images[0]``).
    """
    import random

    pipeline = load_pipeline()

    custom = (custom_prompt or "").strip()
    prompt = custom or CELL_TYPES.get(cell_type, CELL_TYPES["Neutrophil"])

    # A cleared Number field arrives as None; treat it like -1 (random)
    # instead of crashing in int(None).
    requested_seed = -1 if seed is None else int(seed)
    if requested_seed < 0:
        actual_seed = random.randint(0, 2**32 - 1)
    else:
        actual_seed = requested_seed
    # Log the seed actually used so random generations can be reproduced.
    print(f"Generating with seed {actual_seed}")

    generator = torch.Generator(device="cuda").manual_seed(actual_seed)

    neg_prompt = None
    if use_negative_prompt:
        neg_prompt = (negative_prompt or "").strip() or DEFAULT_NEGATIVE_PROMPT

    result = pipeline(
        prompt=prompt,
        negative_prompt=neg_prompt,
        height=512,
        width=512,
        num_inference_steps=int(steps),
        guidance_scale=float(cfg),
        generator=generator,
    )

    return result.images[0]


# Build interface.
# Layout: a header, then a two-column row (left: inputs, right: the generated
# image), then a reference table of supported cell types. The Generate button
# is wired to generate() at the bottom. Component creation order inside these
# context managers determines the on-screen layout.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="red")) as demo:
    gr.Markdown("""
    # PB Cell Generator

    Generate synthetic peripheral blood cell images using a fine-tuned Stable Diffusion 2.1 model
    trained on the PBC dataset with detailed morphological captions.

    **Model:** [esab/pbcell-sd21-v2](https://huggingface.co/esab/pbcell-sd21-v2) | **FID Score:** 79.39 | **Powered by ZeroGPU**
    """)

    with gr.Row():
        # Left column: all user inputs.
        with gr.Column(scale=1):
            # Choices are the CELL_TYPES keys; the default matches generate()'s fallback.
            cell_dropdown = gr.Dropdown(
                choices=CELL_TYPE_LIST,
                value="Neutrophil",
                label="Cell Type",
                info="Select the type of blood cell to generate"
            )

            # A non-blank value here replaces the cell-type prompt in generate().
            custom_box = gr.Textbox(
                label="Custom Prompt (optional)",
                placeholder="Leave empty to use the default morphological prompt for the selected cell type...",
                lines=2,
                info="Override the default prompt with your own description"
            )

            # precision=0 asks Gradio to round the value to an integer;
            # generate() draws a random seed for negative values.
            seed_box = gr.Number(
                value=-1,
                label="Seed",
                info="Random seed for reproducibility. Use -1 for random generation each time.",
                precision=0
            )

            # Collapsed by default: guidance, steps and negative-prompt controls.
            with gr.Accordion("Advanced Settings", open=False):
                cfg_slider = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=8.5,
                    step=0.5,
                    label="Guidance Scale (CFG)",
                    info="Controls how closely the image follows the prompt. Higher values = stronger adherence to prompt but may reduce quality. Recommended: 7-9."
                )
                steps_slider = gr.Slider(
                    minimum=10,
                    maximum=50,
                    value=20,
                    step=5,
                    label="Inference Steps",
                    info="Number of denoising steps. More steps = higher quality but slower generation. Recommended: 20-30."
                )

                gr.Markdown("---")

                # Off by default; generate() passes no negative prompt unless ticked.
                use_negative_checkbox = gr.Checkbox(
                    value=False,
                    label="Use Negative Prompt",
                    info="Enable to steer generation away from unwanted characteristics (e.g., white backgrounds, blur)"
                )
                # Pre-filled with the default; generate() also falls back to it when blank.
                negative_prompt_box = gr.Textbox(
                    value=DEFAULT_NEGATIVE_PROMPT,
                    label="Negative Prompt",
                    placeholder="Describe what you DON'T want in the image...",
                    lines=2,
                    info="The model will avoid these characteristics. Helps prevent pale/washed out backgrounds and blurry images."
                )

            # "primary-btn" picks up the gradient styling from custom_css.
            btn = gr.Button("Generate Cell Image", variant="primary", elem_classes=["primary-btn"])

        # Right column: the generated image.
        with gr.Column(scale=1):
            output_img = gr.Image(label="Generated Cell", show_label=True)

    gr.Markdown("""
    ---
    ### Supported Cell Types

    | Cell Type | Description |
    |-----------|-------------|
    | **Neutrophil** | Segmented nucleus, azurophilic granules |
    | **Lymphocyte** | Small cell, high N/C ratio, round nucleus |
    | **Monocyte** | Large cell, kidney-shaped nucleus, vacuoles |
    | **Eosinophil** | Bilobed nucleus, eosinophilic granules |
    | **Basophil** | Segmented nucleus, basophilic granules |
    | **Platelet** | Small anucleate cell fragments |
    | **Erythroblast** | Nucleated red blood cell precursor |
    | **Immature Granulocyte** | Large cell, fine chromatin, nucleoli present |

    *Note: Erythroblast images may occasionally show multiple cells due to training data characteristics.*
    """)

    # Inputs are passed positionally, so this order must match generate()'s
    # signature: (cell_type, custom_prompt, cfg, steps, seed, negative_prompt,
    # use_negative_prompt).
    btn.click(
        fn=generate,
        inputs=[cell_dropdown, custom_box, cfg_slider, steps_slider, seed_box, negative_prompt_box, use_negative_checkbox],
        outputs=output_img
    )

# Start the server only when executed as a script, so importing this module
# (e.g. from tests or tooling) just builds `demo` without launching it.
if __name__ == "__main__":
    demo.launch()