"""
Stable Diffusion Generator with Safetensors Support
Production-grade image generation with security and performance optimizations
"""
import torch
import logging
from typing import List, Optional, Dict, Any
from diffusers import (
StableDiffusionXLPipeline,
DiffusionPipeline,
LCMScheduler
)
from diffusers.models import AutoencoderKL
from safetensors import safe_open
import os
from pathlib import Path
logger = logging.getLogger(__name__)

class SafeStableDiffusionGenerator:
    """
    Production-grade Stable Diffusion generator with safetensors support.
    Implements security, performance, and memory optimizations.
    """

    def __init__(
        self,
        model_id: str = "stabilityai/stable-diffusion-xl-base-1.0",
        lora_path: Optional[str] = None,
        use_lcm: bool = False,
        device: str = "auto"
    ):
        """
        Initialize the generator with proper security and performance settings.

        Args:
            model_id: Base model identifier
            lora_path: Path to LoRA weights (safetensors only)
            use_lcm: Use LCM scheduler for faster inference
            device: Device to use ('auto', 'cuda', 'cpu')
        """
        self.model_id = model_id
        self.lora_path = lora_path
        self.use_lcm = use_lcm
        self.device = device
        self.pipe = None
        self.vae = None
        logger.info("Initializing SafeStableDiffusionGenerator")
        logger.info(f"Model: {model_id}")
        logger.info(f"LoRA path: {lora_path}")
        logger.info(f"LCM enabled: {use_lcm}")
        self._setup_device()
        self._load_model()

    def _setup_device(self):
        """Set up device configuration."""
        if self.device == "auto":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")
        # Enable cudnn autotuning and TF32 matmuls: both trade a little
        # numeric precision for throughput on recent NVIDIA GPUs
        if self.device == "cuda":
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True

    def _load_model(self):
        """Load model with safetensors and optimizations."""
        try:
            # Configure pipeline loading. use_safetensors is mandatory for
            # security: safetensors files cannot execute arbitrary code on
            # load, unlike pickle-based .bin checkpoints.
            load_kwargs = {
                "torch_dtype": torch.float16 if self.device == "cuda" else torch.float32,
                "use_safetensors": True,
            }
            if self.device == "cuda":
                load_kwargs["variant"] = "fp16"
            logger.info("Loading Stable Diffusion model with safetensors...")
            # Load the main pipeline. SDXL ships without a safety checker,
            # so no safety_checker kwargs are needed; device placement is
            # handled by enable_model_cpu_offload() below rather than a
            # device_map, which pipelines do not accept as "auto".
            self.pipe = StableDiffusionXLPipeline.from_pretrained(
                self.model_id,
                **load_kwargs
            )
            # Apply memory optimizations (also moves the model to the GPU)
            if self.device == "cuda":
                self._apply_memory_optimizations()
            # Load LoRA weights if provided
            if self.lora_path:
                self._load_lora_weights()
            # Switch to the LCM scheduler if enabled
            if self.use_lcm:
                self._setup_lcm_scheduler()
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def _apply_memory_optimizations(self):
        """Apply memory and performance optimizations."""
        # xFormers is optional; skip it gracefully if it is not installed
        # so the remaining optimizations still run
        try:
            self.pipe.enable_xformers_memory_efficient_attention()
            logger.info("Enabled xFormers memory efficient attention")
        except Exception as e:
            logger.warning(f"xFormers unavailable, using default attention: {e}")
        try:
            # Compute attention in slices to lower peak VRAM
            self.pipe.enable_attention_slicing()
            logger.info("Enabled attention slicing")
            # Decode VAE outputs one image at a time
            self.pipe.enable_vae_slicing()
            logger.info("Enabled VAE slicing")
            # Offload idle submodules to CPU; this also moves the active
            # submodule to the GPU, so no explicit pipe.to("cuda") is needed
            self.pipe.enable_model_cpu_offload()
            logger.info("Enabled model CPU offload")
        except Exception as e:
            logger.warning(f"Some memory optimizations failed: {e}")
            # Fall back to plain GPU placement if offloading failed
            self.pipe.to(self.device)

    def _load_lora_weights(self):
        """Load LoRA weights from safetensors files."""
        if not self.lora_path or not os.path.exists(self.lora_path):
            logger.warning(f"LoRA path not found: {self.lora_path}")
            return
        try:
            # Collect safetensors files, normalized to Path objects so that
            # .parent and .name work below for both directory and file inputs
            safetensors_files = []
            if os.path.isdir(self.lora_path):
                safetensors_files = sorted(Path(self.lora_path).glob("*.safetensors"))
            elif self.lora_path.endswith(".safetensors"):
                safetensors_files = [Path(self.lora_path)]
            if not safetensors_files:
                logger.warning(f"No safetensors files found in {self.lora_path}")
                return
            logger.info(f"Loading LoRA weights from {len(safetensors_files)} files")
            # Load each safetensors file
            for lora_file in safetensors_files:
                try:
                    self.pipe.load_lora_weights(
                        str(lora_file.parent),
                        weight_name=lora_file.name
                    )
                    logger.info(f"Loaded LoRA: {lora_file.name}")
                except Exception as e:
                    logger.warning(f"Failed to load LoRA {lora_file.name}: {e}")
        except Exception as e:
            logger.error(f"Failed to load LoRA weights: {e}")

    def _setup_lcm_scheduler(self):
        """Set up the LCM scheduler for faster inference."""
        try:
            # Note: the LCM scheduler only produces good results when a
            # matching LCM LoRA (or an LCM-distilled model) is loaded;
            # on its own it merely swaps the sampling schedule
            self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
            logger.info("LCM scheduler configured")
        except Exception as e:
            logger.warning(f"Failed to setup LCM scheduler: {e}")

    def generate_frames(
        self,
        prompt: str,
        frames: int = 5,
        negative_prompt: Optional[str] = None,
        width: int = 1024,
        height: int = 1024,
        num_inference_steps: int = 25,
        guidance_scale: float = 7.5,
        seed: Optional[int] = None
    ) -> List[Any]:
        """
        Generate image frames using the diffusion pipeline.

        Args:
            prompt: Text prompt for generation
            frames: Number of frames to generate
            negative_prompt: Negative prompt for better results
            width: Image width
            height: Image height
            num_inference_steps: Number of diffusion steps
            guidance_scale: Classifier-free guidance scale
            seed: Random seed for reproducibility

        Returns:
            List of generated images, or an empty list on failure
        """
        if not prompt.strip():
            logger.warning("Empty prompt provided to generator")
            return []
        try:
            logger.info(f"Generating {frames} frames for prompt: {prompt[:50]}...")
            images = []
            for i in range(frames):
                logger.debug(f"Generating frame {i+1}/{frames}")
                # Seed each frame with seed + i: reproducible across runs,
                # while every frame still gets a distinct starting latent
                generator = None
                if seed is not None:
                    generator = torch.Generator(device=self.device).manual_seed(seed + i)
                # Generate image
                with torch.inference_mode():
                    result = self.pipe(
                        prompt=prompt,
                        negative_prompt=negative_prompt or self._get_default_negative_prompt(),
                        width=width,
                        height=height,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        generator=generator,
                        num_images_per_prompt=1
                    )
                images.append(result.images[0])
            logger.info(f"Successfully generated {len(images)} frames")
            return images
        except Exception as e:
            logger.error(f"Frame generation failed: {e}")
            return []

    def _get_default_negative_prompt(self) -> str:
        """Get default negative prompt for better quality."""
        return "blurry, bad quality, worst quality, low quality, ugly, duplicate, watermark, signature"

    def save_model_info(self, output_path: str):
        """Save model information to a JSON file."""
        info = {
            "model_id": self.model_id,
            "device": self.device,
            "lora_path": self.lora_path,
            "use_lcm": self.use_lcm,
            "model_parameters": sum(p.numel() for p in self.pipe.unet.parameters()),
            "vae_parameters": sum(p.numel() for p in self.pipe.vae.parameters()),
            "text_encoder_parameters": sum(p.numel() for p in self.pipe.text_encoder.parameters())
        }
        with open(output_path, 'w') as f:
            json.dump(info, f, indent=2)
        logger.info(f"Model info saved to {output_path}")

    def get_model_stats(self) -> Dict[str, Any]:
        """Get current model statistics."""
        if not self.pipe:
            return {"error": "Model not loaded"}
        return {
            "model_id": self.model_id,
            "device": self.device,
            "dtype": str(next(self.pipe.unet.parameters()).dtype),
            "memory_usage": self._get_memory_usage(),
            "lcm_enabled": self.use_lcm,
            "lora_loaded": self.lora_path is not None
        }

    def _get_memory_usage(self) -> Dict[str, float]:
        """Get current GPU memory usage in GB."""
        if self.device != "cuda":
            return {"cuda_memory": 0.0, "cuda_memory_reserved": 0.0}
        try:
            return {
                "cuda_memory": torch.cuda.memory_allocated() / 1024**3,
                "cuda_memory_reserved": torch.cuda.memory_reserved() / 1024**3
            }
        except Exception:
            return {"cuda_memory": 0.0, "cuda_memory_reserved": 0.0}


# Global generator instance (lazy singleton)
_generator_instance = None


def get_generator(
    model_id: str = "stabilityai/stable-diffusion-xl-base-1.0",
    lora_path: Optional[str] = None,
    use_lcm: bool = False
) -> SafeStableDiffusionGenerator:
    """Get or create a global generator instance.

    Note: only a change of model_id triggers a reload; lora_path and
    use_lcm take effect only when the instance is first created.
    """
    global _generator_instance
    if _generator_instance is None or _generator_instance.model_id != model_id:
        _generator_instance = SafeStableDiffusionGenerator(
            model_id=model_id,
            lora_path=lora_path,
            use_lcm=use_lcm
        )
    return _generator_instance


def generate_frames(
    prompt: str,
    frames: int = 5,
    **kwargs
) -> List[Any]:
    """Convenience function for frame generation."""
    generator = get_generator()
    return generator.generate_frames(prompt, frames, **kwargs)
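

if __name__ == "__main__":
    # Minimal usage sketch, assuming the SDXL weights can be fetched from
    # the Hugging Face Hub (or are cached locally); the prompt and output
    # filenames below are illustrative only.
    logging.basicConfig(level=logging.INFO)
    images = generate_frames("a lighthouse at dawn, cinematic lighting", frames=2, seed=42)
    for i, image in enumerate(images):
        image.save(f"frame_{i}.png")  # the pipeline returns PIL images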