""" Stable Diffusion Generator with Safetensors Support Production-grade image generation with security and performance optimizations """ import torch import logging from typing import List, Optional, Dict, Any from diffusers import ( StableDiffusionXLPipeline, DiffusionPipeline, LCMScheduler ) from diffusers.models import AutoencoderKL from safetensors import safe_open import os from pathlib import Path logger = logging.getLogger(__name__) class SafeStableDiffusionGenerator: """ Production-grade Stable Diffusion generator with safetensors support. Implements security, performance, and memory optimizations. """ def __init__( self, model_id: str = "stabilityai/stable-diffusion-xl-base-1.0", lora_path: Optional[str] = None, use_lcm: bool = False, device: str = "auto" ): """ Initialize the generator with proper security and performance settings. Args: model_id: Base model identifier lora_path: Path to LoRA weights (safetensors only) use_lcm: Use LCM scheduler for faster inference device: Device to use ('auto', 'cuda', 'cpu') """ self.model_id = model_id self.lora_path = lora_path self.use_lcm = use_lcm self.device = device self.pipe = None self.vae = None logger.info(f"Initializing SafeStableDiffusionGenerator") logger.info(f"Model: {model_id}") logger.info(f"LoRA path: {lora_path}") logger.info(f"LCM enabled: {use_lcm}") self._setup_device() self._load_model() def _setup_device(self): """Setup device configuration.""" if self.device == "auto": self.device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"Using device: {self.device}") # Set memory optimization settings if self.device == "cuda": torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True def _load_model(self): """Load model with safetensors and optimizations.""" try: # Configure pipeline loading load_kwargs = { "torch_dtype": torch.float16 if self.device == "cuda" else torch.float32, "variant": "fp16" if self.device == "cuda" else None, "use_safetensors": True, # MANDATORY for security "safety_checker": None, # Disable for faster inference "requires_safety_checker": False } # Add device mapping for CUDA if self.device == "cuda": load_kwargs["device_map"] = "auto" logger.info("Loading Stable Diffusion model with safetensors...") # Load the main pipeline self.pipe = StableDiffusionXLPipeline.from_pretrained( self.model_id, **load_kwargs ) # Apply memory optimizations if self.device == "cuda": self._apply_memory_optimizations() # Load LoRA weights if provided if self.lora_path: self._load_lora_weights() # Load LCM scheduler if enabled if self.use_lcm: self._setup_lcm_scheduler() logger.info("Model loaded successfully") except Exception as e: logger.error(f"Failed to load model: {e}") raise def _apply_memory_optimizations(self): """Apply memory and performance optimizations.""" try: # Enable memory efficient attention self.pipe.enable_xformers_memory_efficient_attention() logger.info("Enabled xFormers memory efficient attention") # Enable attention slicing self.pipe.enable_attention_slicing() logger.info("Enabled attention slicing") # Enable VAE slicing self.pipe.enable_vae_slicing() logger.info("Enabled VAE slicing") # Enable CPU offload for memory optimization self.pipe.enable_model_cpu_offload() logger.info("Enabled model CPU offload") except Exception as e: logger.warning(f"Some memory optimizations failed: {e}") def _load_lora_weights(self): """Load LoRA weights from safetensors files.""" if not self.lora_path or not os.path.exists(self.lora_path): logger.warning(f"LoRA path not found: 
{self.lora_path}") return try: # Find safetensors files in the directory safetensors_files = [] if os.path.isdir(self.lora_path): safetensors_files = list(Path(self.lora_path).glob("*.safetensors")) elif self.lora_path.endswith(".safetensors"): safetensors_files = [self.lora_path] if not safetensors_files: logger.warning(f"No safetensors files found in {self.lora_path}") return logger.info(f"Loading LoRA weights from {len(safetensors_files)} files") # Load each safetensors file for lora_file in safetensors_files: try: self.pipe.load_lora_weights( str(lora_file.parent), weight_name=lora_file.name ) logger.info(f"Loaded LoRA: {lora_file.name}") except Exception as e: logger.warning(f"Failed to load LoRA {lora_file.name}: {e}") except Exception as e: logger.error(f"Failed to load LoRA weights: {e}") def _setup_lcm_scheduler(self): """Setup LCM scheduler for faster inference.""" try: # This would require the LCM LoRA to be loaded first # For now, we'll use a faster scheduler configuration self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config) logger.info("LCM scheduler configured") except Exception as e: logger.warning(f"Failed to setup LCM scheduler: {e}") def generate_frames( self, prompt: str, frames: int = 5, negative_prompt: Optional[str] = None, width: int = 1024, height: int = 1024, num_inference_steps: int = 25, guidance_scale: float = 7.5, seed: Optional[int] = None ) -> List[Any]: """ Generate image frames using the transformer pipeline. Args: prompt: Text prompt for generation frames: Number of frames to generate negative_prompt: Negative prompt for better results width: Image width height: Image height num_inference_steps: Number of diffusion steps guidance_scale: Classifier-free guidance scale seed: Random seed for reproducibility Returns: List of generated images """ if not prompt.strip(): logger.warning("Empty prompt provided to generator") return [] try: logger.info(f"Generating {frames} frames for prompt: {prompt[:50]}...") images = [] for i in range(frames): logger.debug(f"Generating frame {i+1}/{frames}") # Set seed for reproducibility if provided generator = None if seed is not None: generator = torch.Generator(device=self.device).manual_seed(seed + i) # Generate image with torch.inference_mode(): result = self.pipe( prompt=prompt, negative_prompt=negative_prompt or self._get_default_negative_prompt(), width=width, height=height, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator, num_images_per_prompt=1 ) images.append(result.images[0]) logger.info(f"Successfully generated {len(images)} frames") return images except Exception as e: logger.error(f"Frame generation failed: {e}") return [] def _get_default_negative_prompt(self) -> str: """Get default negative prompt for better quality.""" return "blurry, bad quality, worst quality, low quality, ugly, duplicate, watermark, signature" def save_model_info(self, output_path: str): """Save model information to file.""" info = { "model_id": self.model_id, "device": self.device, "lora_path": self.lora_path, "use_lcm": self.use_lcm, "model_parameters": sum(p.numel() for p in self.pipe.unet.parameters()), "vae_parameters": sum(p.numel() for p in self.pipe.vae.parameters()), "text_encoder_parameters": sum(p.numel() for p in self.pipe.text_encoder.parameters()) } with open(output_path, 'w') as f: import json json.dump(info, f, indent=2) logger.info(f"Model info saved to {output_path}") def get_model_stats(self) -> Dict[str, Any]: """Get current model statistics.""" if not 
            return {"error": "Model not loaded"}

        return {
            "model_id": self.model_id,
            "device": self.device,
            "dtype": str(next(self.pipe.unet.parameters()).dtype),
            "memory_usage": self._get_memory_usage(),
            "lcm_enabled": self.use_lcm,
            "lora_loaded": self.lora_path is not None,
        }

    def _get_memory_usage(self) -> Dict[str, float]:
        """Get current memory usage in GB."""
        if self.device != "cuda":
            return {"cuda_memory": 0.0, "cuda_memory_reserved": 0.0}

        try:
            return {
                "cuda_memory": torch.cuda.memory_allocated() / 1024**3,
                "cuda_memory_reserved": torch.cuda.memory_reserved() / 1024**3,
            }
        except Exception:
            return {"cuda_memory": 0.0, "cuda_memory_reserved": 0.0}


# Global generator instance
_generator_instance = None


def get_generator(
    model_id: str = "stabilityai/stable-diffusion-xl-base-1.0",
    lora_path: Optional[str] = None,
    use_lcm: bool = False,
) -> SafeStableDiffusionGenerator:
    """Get or create a global generator instance."""
    global _generator_instance

    if _generator_instance is None or _generator_instance.model_id != model_id:
        _generator_instance = SafeStableDiffusionGenerator(
            model_id=model_id,
            lora_path=lora_path,
            use_lcm=use_lcm,
        )

    return _generator_instance


def generate_frames(prompt: str, frames: int = 5, **kwargs) -> List[Any]:
    """Convenience function for frame generation."""
    generator = get_generator()
    return generator.generate_frames(prompt, frames, **kwargs)
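
# A minimal usage sketch, not part of the module's public surface. It assumes
# the SDXL base weights are reachable (Hugging Face Hub or a local cache) and
# that enough GPU/CPU memory is available; the prompt, seed, and output
# filenames below are illustrative placeholders.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Reuse the module-level singleton rather than constructing directly
    gen = get_generator()
    images = gen.generate_frames(
        prompt="a lighthouse on a rocky coast at sunset, cinematic lighting",
        frames=2,
        num_inference_steps=25,
        seed=42,
    )
    for i, image in enumerate(images):
        image.save(f"frame_{i}.png")  # diffusers pipelines return PIL images

    gen.save_model_info("model_info.json")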