"""
Stable Diffusion Generator with Safetensors Support

Production-grade image generation with security and performance optimizations.
"""

import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
from diffusers import LCMScheduler, StableDiffusionXLPipeline

logger = logging.getLogger(__name__)


class SafeStableDiffusionGenerator:
    """
    Production-grade Stable Diffusion generator with safetensors support.
    Implements security, performance, and memory optimizations.
    """

    def __init__(
        self,
        model_id: str = "stabilityai/stable-diffusion-xl-base-1.0",
        lora_path: Optional[str] = None,
        use_lcm: bool = False,
        device: str = "auto"
    ):
        """
        Initialize the generator with proper security and performance settings.

        Args:
            model_id: Base model identifier
            lora_path: Path to LoRA weights (safetensors only)
            use_lcm: Use LCM scheduler for faster inference
            device: Device to use ('auto', 'cuda', 'cpu')
        """
        self.model_id = model_id
        self.lora_path = lora_path
        self.use_lcm = use_lcm
        self.device = device
        self.pipe = None
        self.vae = None

        logger.info("Initializing SafeStableDiffusionGenerator")
        logger.info(f"Model: {model_id}")
        logger.info(f"LoRA path: {lora_path}")
        logger.info(f"LCM enabled: {use_lcm}")

        self._setup_device()
        self._load_model()

    def _setup_device(self):
        """Set up device configuration."""
        if self.device == "auto":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        logger.info(f"Using device: {self.device}")

        if self.device == "cuda":
            # Allow cuDNN autotuning and TF32 matmuls for faster inference
            # at a small, generally acceptable precision cost.
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True

    def _load_model(self):
        """Load the model with safetensors and performance optimizations."""
        try:
            load_kwargs = {
                "torch_dtype": torch.float16 if self.device == "cuda" else torch.float32,
                "variant": "fp16" if self.device == "cuda" else None,
                "use_safetensors": True,
            }
            # Note: SDXL pipelines do not ship a safety checker, so the
            # safety_checker/requires_safety_checker kwargs used with SD 1.x
            # pipelines are omitted here.

            logger.info("Loading Stable Diffusion model with safetensors...")

            self.pipe = StableDiffusionXLPipeline.from_pretrained(
                self.model_id,
                **load_kwargs
            )

            if self.device == "cuda":
                # Device placement is handled by enable_model_cpu_offload()
                # inside _apply_memory_optimizations; diffusers pipelines do
                # not accept device_map="auto".
                self._apply_memory_optimizations()
            else:
                self.pipe.to(self.device)

            if self.lora_path:
                self._load_lora_weights()

            if self.use_lcm:
                self._setup_lcm_scheduler()

            logger.info("Model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def _apply_memory_optimizations(self):
        """Apply memory and performance optimizations.

        Each optimization is attempted independently so a missing optional
        dependency (e.g. xformers) does not prevent the others from running.
        """
        try:
            self.pipe.enable_xformers_memory_efficient_attention()
            logger.info("Enabled xFormers memory efficient attention")
        except Exception as e:
            # Fall back to attention slicing when xformers is unavailable;
            # the two are redundant, so only one is enabled.
            logger.warning(f"xFormers unavailable, using attention slicing: {e}")
            self.pipe.enable_attention_slicing()
            logger.info("Enabled attention slicing")

        try:
            self.pipe.enable_vae_slicing()
            logger.info("Enabled VAE slicing")
        except Exception as e:
            logger.warning(f"VAE slicing failed: {e}")

        try:
            self.pipe.enable_model_cpu_offload()
            logger.info("Enabled model CPU offload")
        except Exception as e:
            logger.warning(f"Model CPU offload failed: {e}")

    def _load_lora_weights(self):
        """Load LoRA weights from safetensors files."""
        if not self.lora_path or not os.path.exists(self.lora_path):
            logger.warning(f"LoRA path not found: {self.lora_path}")
            return

        try:
            safetensors_files = []
            if os.path.isdir(self.lora_path):
                safetensors_files = sorted(Path(self.lora_path).glob("*.safetensors"))
            elif self.lora_path.endswith(".safetensors"):
                # Wrap the single file in a Path so .parent/.name below work
                # for both the directory and single-file cases.
                safetensors_files = [Path(self.lora_path)]

            if not safetensors_files:
                logger.warning(f"No safetensors files found in {self.lora_path}")
                return

            logger.info(f"Loading LoRA weights from {len(safetensors_files)} files")

            for lora_file in safetensors_files:
                try:
                    self.pipe.load_lora_weights(
                        str(lora_file.parent),
                        weight_name=lora_file.name
                    )
                    logger.info(f"Loaded LoRA: {lora_file.name}")
                except Exception as e:
                    logger.warning(f"Failed to load LoRA {lora_file.name}: {e}")

        except Exception as e:
            logger.error(f"Failed to load LoRA weights: {e}")

    def _setup_lcm_scheduler(self):
        """Set up the LCM scheduler for faster inference.

        LCM schedules expect LCM-distilled weights (e.g. an LCM LoRA) and
        work best with few steps (4-8) and low guidance (~1.0-2.0).
        """
        try:
            self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
            logger.info("LCM scheduler configured")
        except Exception as e:
            logger.warning(f"Failed to setup LCM scheduler: {e}")

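    # A minimal sketch of pairing the LCM scheduler with a matching LCM LoRA
    # (assuming the public "latent-consistency/lcm-lora-sdxl" weights; step
    # count and guidance scale are tuned to LCM's expected ranges):
    #
    #     gen = SafeStableDiffusionGenerator(use_lcm=True)
    #     gen.pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
    #     images = gen.generate_frames(
    #         "a lighthouse at dusk", frames=1,
    #         num_inference_steps=6, guidance_scale=1.5,
    #     )
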
    def generate_frames(
        self,
        prompt: str,
        frames: int = 5,
        negative_prompt: Optional[str] = None,
        width: int = 1024,
        height: int = 1024,
        num_inference_steps: int = 25,
        guidance_scale: float = 7.5,
        seed: Optional[int] = None
    ) -> List[Any]:
        """
        Generate image frames using the diffusion pipeline.

        Args:
            prompt: Text prompt for generation
            frames: Number of frames to generate
            negative_prompt: Negative prompt for better results
            width: Image width
            height: Image height
            num_inference_steps: Number of diffusion steps
            guidance_scale: Classifier-free guidance scale
            seed: Random seed for reproducibility

        Returns:
            List of generated images
        """
        if not prompt.strip():
            logger.warning("Empty prompt provided to generator")
            return []

        try:
            logger.info(f"Generating {frames} frames for prompt: {prompt[:50]}...")

            images = []
            for i in range(frames):
                logger.debug(f"Generating frame {i + 1}/{frames}")

                # Offset the seed per frame so frames differ from one another
                # while the whole sequence stays reproducible.
                generator = None
                if seed is not None:
                    generator = torch.Generator(device=self.device).manual_seed(seed + i)

                with torch.inference_mode():
                    result = self.pipe(
                        prompt=prompt,
                        negative_prompt=negative_prompt or self._get_default_negative_prompt(),
                        width=width,
                        height=height,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        generator=generator,
                        num_images_per_prompt=1
                    )

                images.append(result.images[0])

            logger.info(f"Successfully generated {len(images)} frames")
            return images

        except Exception as e:
            logger.error(f"Frame generation failed: {e}")
            return []

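    # A minimal usage sketch: the pipeline returns PIL images, so frames can
    # be saved directly (the file names here are illustrative):
    #
    #     frames = gen.generate_frames("a red fox in snow", frames=3, seed=42)
    #     for i, img in enumerate(frames):
    #         img.save(f"frame_{i:03d}.png")
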
    def _get_default_negative_prompt(self) -> str:
        """Get the default negative prompt for better quality."""
        return "blurry, bad quality, worst quality, low quality, ugly, duplicate, watermark, signature"

    def save_model_info(self, output_path: str):
        """Save model information to a JSON file."""
        if not self.pipe:
            logger.warning("Model not loaded; nothing to save")
            return

        info = {
            "model_id": self.model_id,
            "device": self.device,
            "lora_path": self.lora_path,
            "use_lcm": self.use_lcm,
            "unet_parameters": sum(p.numel() for p in self.pipe.unet.parameters()),
            "vae_parameters": sum(p.numel() for p in self.pipe.vae.parameters()),
            # SDXL has two text encoders; count both.
            "text_encoder_parameters": (
                sum(p.numel() for p in self.pipe.text_encoder.parameters())
                + sum(p.numel() for p in self.pipe.text_encoder_2.parameters())
            )
        }

        with open(output_path, "w") as f:
            json.dump(info, f, indent=2)

        logger.info(f"Model info saved to {output_path}")

    def get_model_stats(self) -> Dict[str, Any]:
        """Get current model statistics."""
        if not self.pipe:
            return {"error": "Model not loaded"}

        return {
            "model_id": self.model_id,
            "device": self.device,
            "dtype": str(next(self.pipe.unet.parameters()).dtype),
            "memory_usage": self._get_memory_usage(),
            "lcm_enabled": self.use_lcm,
            "lora_loaded": self.lora_path is not None
        }

    def _get_memory_usage(self) -> Dict[str, float]:
        """Get current CUDA memory usage in GiB."""
        if self.device != "cuda":
            return {"cuda_memory": 0.0, "cuda_memory_reserved": 0.0}

        try:
            return {
                "cuda_memory": torch.cuda.memory_allocated() / 1024**3,
                "cuda_memory_reserved": torch.cuda.memory_reserved() / 1024**3
            }
        except Exception:
            return {"cuda_memory": 0.0, "cuda_memory_reserved": 0.0}


_generator_instance = None


def get_generator(
    model_id: str = "stabilityai/stable-diffusion-xl-base-1.0",
    lora_path: Optional[str] = None,
    use_lcm: bool = False
) -> SafeStableDiffusionGenerator:
    """Get or create a global generator instance."""
    global _generator_instance

    # Rebuild the instance if any configuration option changed, not just the
    # model id, so stale LoRA or LCM settings are never silently reused.
    if (
        _generator_instance is None
        or _generator_instance.model_id != model_id
        or _generator_instance.lora_path != lora_path
        or _generator_instance.use_lcm != use_lcm
    ):
        _generator_instance = SafeStableDiffusionGenerator(
            model_id=model_id,
            lora_path=lora_path,
            use_lcm=use_lcm
        )
    return _generator_instance


def generate_frames(
    prompt: str,
    frames: int = 5,
    **kwargs
) -> List[Any]:
    """Convenience function for frame generation."""
    generator = get_generator()
    return generator.generate_frames(prompt, frames, **kwargs)
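

if __name__ == "__main__":
    # A minimal smoke-test sketch, assuming a CUDA-capable machine and that
    # the SDXL base weights are available locally or via the Hugging Face Hub.
    logging.basicConfig(level=logging.INFO)

    images = generate_frames("a watercolor landscape at sunrise", frames=2, seed=7)
    for idx, image in enumerate(images):
        # The pipeline yields PIL images; the output names are illustrative.
        image.save(f"sample_{idx:02d}.png")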