File size: 15,168 Bytes
df4d2da
 
 
 
 
 
 
 
18a9cc5
 
 
 
 
 
 
 
df4d2da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18a9cc5
 
 
df4d2da
 
 
 
 
 
 
18a9cc5
 
 
df4d2da
 
 
 
18a9cc5
df4d2da
 
18a9cc5
df4d2da
 
 
 
18a9cc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df4d2da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18a9cc5
 
 
df4d2da
18a9cc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df4d2da
 
18a9cc5
 
 
 
 
 
 
 
 
 
df4d2da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18a9cc5
df4d2da
18a9cc5
 
 
 
 
 
 
 
 
 
df4d2da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
"""
Stencil Image Generator using Stable Diffusion

This module provides a simple interface to generate drawing stencil images
using pretrained Stable Diffusion models with prompt engineering.
"""

import torch
from diffusers import (
    StableDiffusionPipeline,
    DPMSolverMultistepScheduler,
    UNet2DConditionModel,
    AutoencoderKL,
    PNDMScheduler
)
from transformers import CLIPTextModel, CLIPTokenizer
from PIL import Image, ImageOps, ImageEnhance, ImageFilter
from typing import Optional, List, Union
import os
import numpy as np
from scipy import ndimage


def _patch_clip_init():
    """
    Monkey-patch CLIPTextModel.__init__ to ignore the offload_state_dict kwarg.

    The replacement constructor drops ``offload_state_dict`` from the keyword
    arguments and forwards everything else to the original ``__init__``.
    This works around compatibility issues between mismatched transformers
    versions.

    The patch is idempotent: the wrapper is tagged with a marker attribute and
    later calls return early when they find it, so calling this function once
    per generator instance does not stack wrappers on top of each other.

    Does nothing if transformers cannot be imported.
    """
    try:
        from transformers import CLIPTextModel
    except ImportError:
        return  # transformers not installed yet

    original_init = CLIPTextModel.__init__

    # Already patched by an earlier call - wrapping again would only add
    # another layer of indirection on every model construction.
    if getattr(original_init, "_ignores_offload_state_dict", False):
        return

    def patched_init(self, config, *args, **kwargs):
        # Remove the offload_state_dict parameter if it exists
        kwargs.pop('offload_state_dict', None)
        return original_init(self, config, *args, **kwargs)

    patched_init._ignores_offload_state_dict = True
    CLIPTextModel.__init__ = patched_init


class StencilGenerator:
    """
    A class to generate drawing stencil images using Stable Diffusion.

    This generator automatically appends stencil-specific prompt decorations
    to guide the model toward producing black and white stencil-style images,
    and can post-process the output into a pure black-on-white binary image.
    """

    # Repository that supplies the tokenizer, text encoder, VAE and scheduler
    # when a fine-tuned UNet checkpoint is loaded. Override on a subclass (or
    # assign on the class before instantiation) to use a different base.
    # NOTE(review): confirm this repo id still resolves on the HuggingFace Hub.
    CHECKPOINT_BASE_MODEL = "runwayml/stable-diffusion-v1-5"

    def __init__(
        self,
        model_id: str = "Manojb/stable-diffusion-2-1-base",
        checkpoint_path: Optional[str] = None,
        device: Optional[str] = None,
        use_fp16: bool = True
    ):
        """
        Initialize the Stencil Generator.

        Args:
            model_id: HuggingFace model ID for Stable Diffusion model (used if checkpoint_path is None)
            checkpoint_path: Path to fine-tuned checkpoint directory (e.g., "./checkpoint-1000")
                           If provided, loads fine-tuned model instead of pretrained model
            device: Device to run on ('cuda', 'cuda:0', 'cpu', or None for auto-detect)
            use_fp16: Whether to use half precision (FP16) for faster inference.
                      Only honoured on CUDA devices; forced off otherwise.
        """
        self.model_id = model_id
        self.checkpoint_path = checkpoint_path
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Prefix check (via _on_cuda) so indexed devices such as "cuda:1" keep FP16.
        self.use_fp16 = use_fp16 and self._on_cuda
        self.is_checkpoint_model = checkpoint_path is not None

        # Apply monkey-patch to fix transformers version compatibility
        _patch_clip_init()

        # Load model based on whether checkpoint is provided
        if self.is_checkpoint_model:
            self._load_from_checkpoint(checkpoint_path)
        else:
            self._load_from_pretrained(model_id)

        print("Model loaded successfully!")

        # Set prompt decoration based on model type
        if self.is_checkpoint_model:
            # Fine-tuned models use a simple "sketch of" prefix; generate()
            # applies the prefix itself, this attribute is informational.
            self.stencil_suffix = "Sketch of"
            self.default_negative_prompt = None
        else:
            # Standard SD 2.1 models use detailed stencil suffix
            self.stencil_suffix = (
                "black silhouette, high contrast, simple stencil design, "
                "centered in frame, complete object visible, isolated subject"
            )
            self.default_negative_prompt = (
                "color, colorful, photograph, realistic, detailed, complex, "
            )

    @property
    def _on_cuda(self) -> bool:
        """Whether ``self.device`` names a CUDA device ("cuda" or "cuda:N")."""
        return str(self.device).startswith("cuda")

    def _load_from_pretrained(self, model_id: str):
        """
        Load a pretrained model from HuggingFace.

        Builds ``self.pipe`` with the safety checker disabled, swaps in a
        DPM-Solver multistep scheduler, moves the pipeline to ``self.device``
        and enables attention slicing on CUDA.

        Args:
            model_id: HuggingFace model ID
        """
        print(f"Loading pretrained model {model_id} on {self.device}...")

        # Load the pipeline with version-compatible parameters
        dtype = torch.float16 if self.use_fp16 else torch.float32

        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
            safety_checker=None,  # Disable for faster loading
        )

        # Use DPM-Solver for faster generation
        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipe.scheduler.config
        )

        self.pipe = self.pipe.to(self.device)

        # Enable memory optimizations
        if self._on_cuda:
            self.pipe.enable_attention_slicing()
            # Uncomment if you have limited VRAM
            # self.pipe.enable_vae_slicing()

    def _load_from_checkpoint(self, checkpoint_path: str):
        """
        Load a fine-tuned model from checkpoint directory or HuggingFace Hub.

        Only the UNet comes from the checkpoint; tokenizer, text encoder, VAE
        and scheduler are loaded from ``CHECKPOINT_BASE_MODEL``. The assembled
        pipeline is stored in ``self.pipe`` and moved to ``self.device``.

        Args:
            checkpoint_path: Path to checkpoint directory containing UNet,
                           or HuggingFace Hub model ID (e.g., "username/model-name")
        """
        print(f"Loading fine-tuned checkpoint from {checkpoint_path} on {self.device}...")

        # Base model for standard components
        base_model = self.CHECKPOINT_BASE_MODEL

        print("Loading tokenizer...")
        tokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")

        print("Loading text encoder...")
        text_encoder = CLIPTextModel.from_pretrained(base_model, subfolder="text_encoder")

        print("Loading VAE...")
        vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae")

        print("Loading scheduler...")
        scheduler = PNDMScheduler.from_pretrained(base_model, subfolder="scheduler")

        # Load fine-tuned UNet from checkpoint.
        # An existing filesystem path is treated as a local checkpoint whose
        # UNet lives in its "unet" subdirectory; anything else is assumed to
        # be a HuggingFace Hub model ID with a "unet" subfolder.
        if os.path.exists(checkpoint_path):
            unet_path = os.path.join(checkpoint_path, "unet")
            print(f"Loading fine-tuned UNet from {unet_path}...")
            unet = UNet2DConditionModel.from_pretrained(unet_path)
        else:
            unet_path = checkpoint_path
            print(f"Loading fine-tuned UNet from {unet_path}...")
            unet = UNet2DConditionModel.from_pretrained(unet_path, subfolder="unet")

        # Assemble pipeline
        print("Assembling pipeline...")
        self.pipe = StableDiffusionPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False
        )

        # Move to device; use_fp16 is only ever True on CUDA, so in that case
        # the three model components are cast to half precision while moving.
        if self.use_fp16:
            self.pipe.vae = self.pipe.vae.to(self.device, dtype=torch.float16)
            self.pipe.text_encoder = self.pipe.text_encoder.to(self.device, dtype=torch.float16)
            self.pipe.unet = self.pipe.unet.to(self.device, dtype=torch.float16)
        else:
            self.pipe = self.pipe.to(self.device)

    @staticmethod
    def _binarize(
        img_array: np.ndarray,
        binary_threshold: int = 128,
        invert_if_needed: bool = True,
        remove_small_objects: bool = True,
        min_object_size: int = 100,
    ) -> np.ndarray:
        """
        Convert a 2-D grayscale array into a boolean stencil mask.

        Args:
            img_array: 2-D array of grayscale intensities
            binary_threshold: Threshold used when Otsu's method is unavailable
                              (scikit-image missing) or cannot be computed
            invert_if_needed: Flip the mask when dark pixels are the majority,
                              so the minority region ends up as the dark subject
            remove_small_objects: Drop dark connected components smaller than
                                  min_object_size
            min_object_size: Minimum pixel area for a dark component to be kept

        Returns:
            Boolean array of the same shape; True = white background,
            False = black subject.
        """
        # Prefer Otsu's automatic threshold; fall back to the caller's value.
        try:
            from skimage.filters import threshold_otsu
        except ImportError:
            threshold = binary_threshold
        else:
            try:
                threshold = threshold_otsu(img_array)
            except ValueError:
                # Presumably raised by some scikit-image releases for
                # single-valued images - use the explicit threshold instead.
                threshold = binary_threshold

        binary = img_array > threshold

        # We want a black subject on a white background: if white is the
        # minority, the image is white-on-black and must be inverted.
        if invert_if_needed and np.count_nonzero(binary) < binary.size / 2:
            binary = ~binary

        if remove_small_objects:
            # Label dark connected components and measure their pixel areas.
            labels, _ = ndimage.label(~binary)
            sizes = np.bincount(labels.ravel(), minlength=1)
            too_small = sizes < min_object_size
            too_small[0] = False  # label 0 is the background, not a component
            binary[too_small[labels]] = True  # paint specks white

        # Close small holes in the *subject* (dark) mask. The mask is padded by
        # edge replication first because scipy treats out-of-image pixels as 0
        # during the erosion step, which would otherwise eat the outer pixel
        # ring of anything touching the border.
        subject = np.pad(~binary, 1, mode="edge")
        subject = ndimage.binary_closing(subject, structure=np.ones((3, 3)))
        return ~subject[1:-1, 1:-1]

    def _clean_stencil_image(
        self,
        image: "Image.Image",
        binary_threshold: int = 128,
        invert_if_needed: bool = True,
        remove_small_objects: bool = True,
        min_object_size: int = 100
    ) -> "Image.Image":
        """
        Aggressively convert any image to a clean binary stencil.
        This uses Otsu's method and morphological operations to force
        a clean black silhouette on pure white background, regardless
        of what the model generated.

        Args:
            image: Input PIL Image
            binary_threshold: Fallback threshold for binarization (0-255), used
                              when Otsu's method is not available
            invert_if_needed: Auto-detect if we need to invert (black on white vs white on black)
            remove_small_objects: Remove small noise/artifacts
            min_object_size: Minimum pixel area to keep (removes noise)

        Returns:
            Pure black and white stencil image (RGB mode)
        """
        # Convert to grayscale first
        if image.mode != 'L':
            image = image.convert('L')

        mask = self._binarize(
            np.array(image),
            binary_threshold=binary_threshold,
            invert_if_needed=invert_if_needed,
            remove_small_objects=remove_small_objects,
            min_object_size=min_object_size,
        )

        # Boolean mask -> uint8 pixels (True->255, False->0). A 2-D uint8
        # array maps to mode 'L' without an explicit mode argument.
        pixels = mask.astype(np.uint8) * 255
        return Image.fromarray(pixels).convert('RGB')

    def _build_prompt(self, prompt: str, add_stencil_suffix: bool) -> str:
        """
        Decorate the user prompt according to the loaded model type.

        Fine-tuned checkpoints get a "sketch of" prefix (unless the prompt
        already starts with it, case-insensitively); standard models get
        ``self.stencil_suffix`` appended. With add_stencil_suffix=False the
        prompt is returned unchanged.
        """
        if not add_stencil_suffix:
            return prompt
        if self.is_checkpoint_model:
            if prompt.lower().startswith("sketch of"):
                return prompt
            return f"sketch of {prompt}"
        return f"{prompt}, {self.stencil_suffix}"

    def generate(
        self,
        prompt: str,
        num_images: int = 1,
        negative_prompt: Optional[str] = None,
        num_inference_steps: int = 25,
        guidance_scale: float = 7.5,
        width: int = 512,
        height: int = 512,
        seed: Optional[int] = None,
        add_stencil_suffix: bool = True,
        clean_background: bool = True,
    ) -> "Union[Image.Image, List[Image.Image]]":
        """
        Generate stencil images based on the prompt.

        Args:
            prompt: Base text prompt describing what to draw
            num_images: Number of images to generate
            negative_prompt: Things to avoid in the generation. None selects
                             the model-specific default; an empty string is
                             passed through as-is.
            num_inference_steps: Number of denoising steps (higher = better quality, slower)
            guidance_scale: How strongly to follow the prompt (7-8 recommended)
            width: Image width in pixels (must be divisible by 8)
            height: Image height in pixels (must be divisible by 8)
            seed: Random seed for reproducibility (None for random)
            add_stencil_suffix: Whether to automatically add stencil styling to prompt
            clean_background: Whether to post-process into pure binary stencil (highly recommended)

        Returns:
            Single PIL Image if num_images=1, otherwise list of PIL Images
        """
        # Construct full prompt based on model type
        full_prompt = self._build_prompt(prompt, add_stencil_suffix)

        # Only fall back to the default when the caller passed nothing at all,
        # so an explicit "" can be used to disable the default negative prompt.
        if negative_prompt is None:
            full_negative_prompt = self.default_negative_prompt
        else:
            full_negative_prompt = negative_prompt

        # Set seed if provided
        generator = None
        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)

        print(f"Generating {num_images} stencil image(s)...")
        print(f"Prompt: {full_prompt}")

        # use_fp16 implies a CUDA device; autocast takes the device *type*,
        # so an indexed device such as "cuda:0" must not be passed verbatim.
        inference_context = torch.autocast("cuda") if self.use_fp16 else torch.no_grad()

        # Generate images
        with inference_context:
            result = self.pipe(
                prompt=full_prompt,
                num_images_per_prompt=num_images,
                negative_prompt=full_negative_prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                width=width,
                height=height,
                generator=generator,
            )

        images = result.images

        # Apply post-processing to clean background if enabled
        if clean_background:
            print("Cleaning background...")
            images = [self._clean_stencil_image(img) for img in images]

        print("Generation complete!")

        # Return single image or list
        return images[0] if num_images == 1 else images

    def save_image(
        self,
        image: "Image.Image",
        output_path: str,
        create_dirs: bool = True
    ):
        """
        Save a generated image to disk.

        Args:
            image: PIL Image to save
            output_path: Path where to save the image
            create_dirs: Whether to create parent directories if they don't exist
        """
        if create_dirs:
            # dirname is "" for a bare filename; fall back to the current dir.
            os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

        image.save(output_path)
        print(f"Image saved to: {output_path}")

    @staticmethod
    def _indexed_output_path(output_path: str, index: int) -> str:
        """Insert ``_<index>`` before the file extension of output_path."""
        root, ext = os.path.splitext(output_path)
        return f"{root}_{index}{ext}"

    def generate_and_save(
        self,
        prompt: str,
        output_path: str,
        num_images: int = 1,
        **kwargs
    ) -> "Union[Image.Image, List[Image.Image]]":
        """
        Generate stencil image(s) and save them to disk in one call.

        With num_images == 1 the image is written to output_path. Otherwise
        each image is written to output_path with a 1-based ``_<n>`` suffix
        inserted before the extension (e.g. "cat.png" -> "cat_1.png"), for
        any file extension.

        Args:
            prompt: Base text prompt describing what to draw
            output_path: Path where to save the image
            num_images: Number of images to generate
            **kwargs: Additional arguments passed to generate()

        Returns:
            The generated PIL Image, or a list of PIL Images when num_images != 1
        """
        images = self.generate(prompt, num_images, **kwargs)

        if num_images == 1:
            self.save_image(images, output_path)
        else:
            for idx, img in enumerate(images, start=1):
                self.save_image(img, self._indexed_output_path(output_path, idx))
        return images