"""
Stencil Image Generator using Stable Diffusion

This module provides a simple interface to generate drawing stencil images
using pretrained Stable Diffusion models with prompt engineering.
"""

import contextlib
import os
from typing import List, Optional, Union

import numpy as np
import torch
from diffusers import (
    AutoencoderKL,
    DPMSolverMultistepScheduler,
    PNDMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from PIL import Image, ImageEnhance, ImageFilter, ImageOps
from scipy import ndimage
from transformers import CLIPTextModel, CLIPTokenizer


def _patch_clip_init():
    """
    Monkey-patch CLIPTextModel.__init__ to ignore the ``offload_state_dict``
    keyword argument.

    This fixes compatibility issues between mismatched ``diffusers`` /
    ``transformers`` versions, where one passes a kwarg the other does not
    accept.
    """
    try:
        from transformers import CLIPTextModel

        original_init = CLIPTextModel.__init__

        def patched_init(self, config, *args, **kwargs):
            # Drop the incompatible parameter if a newer diffusers passes it.
            kwargs.pop('offload_state_dict', None)
            return original_init(self, config, *args, **kwargs)

        CLIPTextModel.__init__ = patched_init
    except ImportError:
        pass  # transformers not installed yet


class StencilGenerator:
    """
    A class to generate drawing stencil images using Stable Diffusion.

    This generator automatically appends stencil-specific prompt decorations
    to guide the model toward producing black and white stencil-style images.
    """

    def __init__(
        self,
        model_id: str = "Manojb/stable-diffusion-2-1-base",
        # model_id: str = "runwayml/stable-diffusion-v1-5",
        checkpoint_path: Optional[str] = None,
        device: Optional[str] = None,
        use_fp16: bool = True
    ):
        """
        Initialize the Stencil Generator.

        Args:
            model_id: HuggingFace model ID for Stable Diffusion model
                (used if checkpoint_path is None)
            checkpoint_path: Path to fine-tuned checkpoint directory
                (e.g., "./checkpoint-1000"). If provided, loads the
                fine-tuned model instead of the pretrained model.
            device: Device to run on ('cuda', 'cpu', or None for auto-detect)
            use_fp16: Whether to use half precision (FP16) for faster
                inference. Only honored on CUDA devices.
        """
        self.model_id = model_id
        self.checkpoint_path = checkpoint_path
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # FP16 is only meaningful on CUDA; silently fall back to FP32 on CPU.
        self.use_fp16 = use_fp16 and self.device == "cuda"
        self.is_checkpoint_model = checkpoint_path is not None

        # Apply monkey-patch to fix transformers version compatibility
        _patch_clip_init()

        # Load model based on whether checkpoint is provided
        if self.is_checkpoint_model:
            self._load_from_checkpoint(checkpoint_path)
        else:
            self._load_from_pretrained(model_id)

        print("Model loaded successfully!")

        # Set prompt decoration based on model type
        if self.is_checkpoint_model:
            # Fine-tuned models use a simple "sketch of" prefix
            self.stencil_suffix = "Sketch of"
            self.default_negative_prompt = None
        else:
            # Standard SD 2.1 models use a detailed stencil suffix
            self.stencil_suffix = (
                "black silhouette, high contrast, simple stencil design, "
                "centered in frame, complete object visible, isolated subject"
            )
            self.default_negative_prompt = (
                "color, colorful, photograph, realistic, detailed, complex, "
            )

    def _load_from_pretrained(self, model_id: str):
        """
        Load a pretrained model from HuggingFace.

        Args:
            model_id: HuggingFace model ID
        """
        print(f"Loading pretrained model {model_id} on {self.device}...")

        # Load the pipeline with version-compatible parameters
        dtype = torch.float16 if self.use_fp16 else torch.float32
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
            safety_checker=None,  # Disable for faster loading
        )

        # Use DPM-Solver for faster generation
        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipe.scheduler.config
        )

        self.pipe = self.pipe.to(self.device)

        # Enable memory optimizations
        if self.device == "cuda":
            self.pipe.enable_attention_slicing()
            # Uncomment if you have limited VRAM
            # self.pipe.enable_vae_slicing()

    def _load_from_checkpoint(self, checkpoint_path: str):
        """
        Load a fine-tuned model from checkpoint directory or HuggingFace Hub.

        Args:
            checkpoint_path: Path to checkpoint directory containing UNet, or
                HuggingFace Hub model ID (e.g., "username/model-name")
        """
        print(f"Loading fine-tuned checkpoint from {checkpoint_path} on {self.device}...")

        # Base model for standard components (tokenizer/text encoder/VAE/scheduler)
        base_model = "runwayml/stable-diffusion-v1-5"

        print("Loading tokenizer...")
        tokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")

        print("Loading text encoder...")
        text_encoder = CLIPTextModel.from_pretrained(base_model, subfolder="text_encoder")

        print("Loading VAE...")
        vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae")

        print("Loading scheduler...")
        scheduler = PNDMScheduler.from_pretrained(base_model, subfolder="scheduler")

        # Load fine-tuned UNet from checkpoint.
        # Handles both local paths and HuggingFace Hub model IDs.
        if os.path.exists(checkpoint_path):
            # Local path - the UNet lives in the /unet subdirectory
            unet_path = os.path.join(checkpoint_path, "unet")
            print(f"Loading fine-tuned UNet from {unet_path}...")
            unet = UNet2DConditionModel.from_pretrained(unet_path)
        else:
            # Assume it's a HuggingFace Hub model ID
            print(f"Loading fine-tuned UNet from {checkpoint_path}...")
            unet = UNet2DConditionModel.from_pretrained(checkpoint_path, subfolder="unet")

        # Assemble pipeline from the individual components
        print("Assembling pipeline...")
        self.pipe = StableDiffusionPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False
        )

        # Move to device with FP16 if enabled
        if self.device == "cuda":
            if self.use_fp16:
                self.pipe.vae = self.pipe.vae.to(self.device, dtype=torch.float16)
                self.pipe.text_encoder = self.pipe.text_encoder.to(self.device, dtype=torch.float16)
                self.pipe.unet = self.pipe.unet.to(self.device, dtype=torch.float16)
            else:
                self.pipe = self.pipe.to(self.device)
        else:
            self.pipe = self.pipe.to(self.device)

    def _clean_stencil_image(
        self,
        image: Image.Image,
        binary_threshold: int = 128,
        invert_if_needed: bool = True,
        remove_small_objects: bool = True,
        min_object_size: int = 100
    ) -> Image.Image:
        """
        Aggressively convert any image to a clean binary stencil.

        This uses Otsu's method (when scikit-image is available) and
        morphological operations to force a clean black silhouette on a pure
        white background, regardless of what the model generated.

        Args:
            image: Input PIL Image
            binary_threshold: Fallback threshold for binarization (0-255),
                used only when scikit-image is unavailable; otherwise the
                threshold is computed automatically via Otsu's method.
            invert_if_needed: Auto-detect if we need to invert
                (black on white vs white on black)
            remove_small_objects: Remove small noise/artifacts
            min_object_size: Minimum pixel area to keep (removes noise)

        Returns:
            Pure black and white stencil image (RGB mode)
        """
        # Convert to grayscale first
        if image.mode != 'L':
            image = image.convert('L')

        # Convert to numpy array
        img_array = np.array(image)

        # Apply Otsu's method for automatic threshold detection.
        # This finds the optimal threshold to separate foreground/background.
        try:
            from skimage.filters import threshold_otsu
            binary_threshold = threshold_otsu(img_array)
        except ImportError:
            # Fall back to the caller-supplied threshold if skimage is
            # not available (fixes the previous behavior that ignored it).
            pass

        # Apply binary threshold - create stark black and white
        binary = img_array > binary_threshold

        # Decide if we need to invert (we want black subject on white background)
        if invert_if_needed:
            # Count pixels - if more white than black, we likely have black
            # subject on white (correct). If more black than white, we have
            # white subject on black (need to invert).
            white_pixels = np.sum(binary)
            total_pixels = binary.size
            if white_pixels < total_pixels / 2:
                # More black than white - invert
                binary = ~binary

        # Remove small objects (noise/artifacts)
        if remove_small_objects:
            try:
                from scipy.ndimage import label, sum as ndi_sum
                # Label connected components (invert to label the dark regions)
                labeled_array, num_features = label(~binary)
                # Calculate size of each component
                component_sizes = ndi_sum(~binary, labeled_array, range(num_features + 1))
                # Remove small components by painting them white (background)
                mask_size = component_sizes < min_object_size
                remove_pixel = mask_size[labeled_array]
                binary[remove_pixel] = True
            except ImportError:
                pass  # Skip if scipy not available

        # Apply slight morphological closing to fill small holes in the subject
        try:
            from scipy.ndimage import binary_closing
            binary = binary_closing(binary, structure=np.ones((3, 3)))
        except ImportError:
            pass

        # Convert boolean array to uint8 (True->255, False->0)
        result = (binary * 255).astype(np.uint8)

        # Convert back to PIL Image
        cleaned_image = Image.fromarray(result, mode='L').convert('RGB')

        return cleaned_image

    def generate(
        self,
        prompt: str,
        num_images: int = 1,
        negative_prompt: Optional[str] = None,
        num_inference_steps: int = 25,
        guidance_scale: float = 7.5,
        width: int = 512,
        height: int = 512,
        seed: Optional[int] = None,
        add_stencil_suffix: bool = True,
        clean_background: bool = True,
    ) -> Union[Image.Image, List[Image.Image]]:
        """
        Generate stencil images based on the prompt.

        Args:
            prompt: Base text prompt describing what to draw
            num_images: Number of images to generate
            negative_prompt: Things to avoid in the generation
            num_inference_steps: Number of denoising steps
                (higher = better quality, slower)
            guidance_scale: How strongly to follow the prompt (7-8 recommended)
            width: Image width in pixels (must be divisible by 8)
            height: Image height in pixels (must be divisible by 8)
            seed: Random seed for reproducibility (None for random)
            add_stencil_suffix: Whether to automatically add stencil styling
                to prompt
            clean_background: Whether to post-process into a pure binary
                stencil (highly recommended)

        Returns:
            Single PIL Image if num_images=1, otherwise list of PIL Images
        """
        # Construct full prompt based on model type
        full_prompt = prompt
        if self.is_checkpoint_model:
            # For fine-tuned checkpoints, add "sketch of" prefix
            if add_stencil_suffix and not prompt.lower().startswith("sketch of"):
                full_prompt = f"sketch of {prompt}"
        else:
            # For standard models, use stencil suffix
            if add_stencil_suffix:
                full_prompt = f"{prompt}, {self.stencil_suffix}"

        # Use default negative prompt only when none was provided; an
        # explicit empty string means "no negative prompt".
        if negative_prompt is not None:
            full_negative_prompt = negative_prompt
        else:
            full_negative_prompt = self.default_negative_prompt

        # Set seed if provided
        generator = None
        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)

        print(f"Generating {num_images} stencil image(s)...")
        print(f"Prompt: {full_prompt}")

        # Generate images. Always disable autograd for inference; layer
        # autocast on top when FP16 is enabled (the original code skipped
        # no_grad on the FP16 path).
        amp_ctx = torch.autocast(self.device) if self.use_fp16 else contextlib.nullcontext()
        with torch.no_grad(), amp_ctx:
            result = self.pipe(
                prompt=full_prompt,
                num_images_per_prompt=num_images,
                negative_prompt=full_negative_prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                width=width,
                height=height,
                generator=generator,
            )

        images = result.images

        # Apply post-processing to clean background if enabled
        if clean_background:
            print("Cleaning background...")
            images = [self._clean_stencil_image(img) for img in images]

        print("Generation complete!")

        # Return single image or list
        return images[0] if num_images == 1 else images

    def save_image(
        self,
        image: Image.Image,
        output_path: str,
        create_dirs: bool = True
    ):
        """
        Save a generated image to disk.

        Args:
            image: PIL Image to save
            output_path: Path where to save the image
            create_dirs: Whether to create parent directories if they
                don't exist
        """
        if create_dirs:
            os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

        image.save(output_path)
        print(f"Image saved to: {output_path}")

    def generate_and_save(
        self,
        prompt: str,
        output_path: str,
        num_images: int = 1,
        **kwargs
    ) -> Image.Image:
        """
        Generate stencil image(s) and save to disk in one call.

        Args:
            prompt: Base text prompt describing what to draw
            output_path: Path where to save the image. When num_images > 1,
                an ``_<index>`` suffix is inserted before the file extension.
            num_images: Number of images to generate
            **kwargs: Additional arguments passed to generate()

        Returns:
            The generated PIL Image (or list of Images if num_images > 1)
        """
        image = self.generate(prompt, num_images, **kwargs)

        # Save single or multiple images. For multiple images, insert an
        # index before the extension (works for any extension, not just .png,
        # which the previous str.replace-based approach silently mishandled).
        if num_images == 1:
            self.save_image(image, output_path)
        else:
            root, ext = os.path.splitext(output_path)
            for idx, img in enumerate(image):
                self.save_image(img, f"{root}_{idx + 1}{ext or '.png'}")

        return image