"""
Stable Diffusion Generator with Safetensors Support
Production-grade image generation with security and performance optimizations
"""

import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
from diffusers import LCMScheduler, StableDiffusionXLPipeline
from safetensors import safe_open

logger = logging.getLogger(__name__)

class SafeStableDiffusionGenerator:
    """
    Production-grade Stable Diffusion generator with safetensors support.
    Implements security, performance, and memory optimizations.
    """
    
    def __init__(
        self,
        model_id: str = "stabilityai/stable-diffusion-xl-base-1.0",
        lora_path: Optional[str] = None,
        use_lcm: bool = False,
        device: str = "auto"
    ):
        """
        Initialize the generator with proper security and performance settings.
        
        Args:
            model_id: Base model identifier
            lora_path: Path to LoRA weights (safetensors only)
            use_lcm: Use LCM scheduler for faster inference
            device: Device to use ('auto', 'cuda', 'cpu')
        """
        self.model_id = model_id
        self.lora_path = lora_path
        self.use_lcm = use_lcm
        self.device = device
        self.pipe = None
        
        logger.info(f"Initializing SafeStableDiffusionGenerator")
        logger.info(f"Model: {model_id}")
        logger.info(f"LoRA path: {lora_path}")
        logger.info(f"LCM enabled: {use_lcm}")
        
        self._setup_device()
        self._load_model()
    
    def _setup_device(self):
        """Setup device configuration."""
        if self.device == "auto":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        
        logger.info(f"Using device: {self.device}")
        
        # Set memory optimization settings
        if self.device == "cuda":
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True
    
    def _load_model(self):
        """Load model with safetensors and optimizations."""
        try:
            # Configure pipeline loading. SDXL ships without a safety-checker
            # module, so the SD 1.x safety_checker kwargs are unnecessary here.
            load_kwargs = {
                "torch_dtype": torch.float16 if self.device == "cuda" else torch.float32,
                "use_safetensors": True,  # MANDATORY for security
            }
            
            # The fp16 weight variant only exists for half precision on GPU.
            # Note: diffusers pipelines do not support device_map="auto";
            # device placement is handled by enable_model_cpu_offload() below.
            if self.device == "cuda":
                load_kwargs["variant"] = "fp16"
            
            logger.info("Loading Stable Diffusion model with safetensors...")
            
            # Load the main pipeline
            self.pipe = StableDiffusionXLPipeline.from_pretrained(
                self.model_id,
                **load_kwargs
            )
            
            # Apply memory optimizations
            if self.device == "cuda":
                self._apply_memory_optimizations()
            
            # Load LoRA weights if provided
            if self.lora_path:
                self._load_lora_weights()
            
            # Load LCM scheduler if enabled
            if self.use_lcm:
                self._setup_lcm_scheduler()
            
            logger.info("Model loaded successfully")
            
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise
    
    def _apply_memory_optimizations(self):
        """Apply memory and performance optimizations."""
        # Each optimization is attempted independently so that one missing
        # dependency (e.g. xformers) does not disable the others.
        try:
            self.pipe.enable_xformers_memory_efficient_attention()
            logger.info("Enabled xFormers memory efficient attention")
        except Exception as e:
            # Attention slicing is a slower fallback covering the same memory
            # pressure; it is redundant when xFormers is active
            logger.warning(f"xFormers unavailable, using attention slicing: {e}")
            self.pipe.enable_attention_slicing()
        
        try:
            self.pipe.enable_vae_slicing()
            logger.info("Enabled VAE slicing")
        except Exception as e:
            logger.warning(f"VAE slicing failed: {e}")
        
        try:
            # CPU offload moves submodules to the GPU on demand and handles
            # device placement, so no explicit .to("cuda") is needed here
            self.pipe.enable_model_cpu_offload()
            logger.info("Enabled model CPU offload")
        except Exception as e:
            logger.warning(f"Model CPU offload failed, moving pipeline to GPU: {e}")
            self.pipe.to(self.device)
    
    def _load_lora_weights(self):
        """Load LoRA weights from safetensors files."""
        if not self.lora_path or not os.path.exists(self.lora_path):
            logger.warning(f"LoRA path not found: {self.lora_path}")
            return
        
        try:
            # Find safetensors files in the directory
            safetensors_files = []
            if os.path.isdir(self.lora_path):
                safetensors_files = list(Path(self.lora_path).glob("*.safetensors"))
            elif self.lora_path.endswith(".safetensors"):
                # Wrap in Path so .parent/.name below work for both branches
                safetensors_files = [Path(self.lora_path)]
            
            if not safetensors_files:
                logger.warning(f"No safetensors files found in {self.lora_path}")
                return
            
            logger.info(f"Loading LoRA weights from {len(safetensors_files)} files")
            
            # Load each safetensors file
            for lora_file in safetensors_files:
                try:
                    self.pipe.load_lora_weights(
                        str(lora_file.parent),
                        weight_name=lora_file.name
                    )
                    logger.info(f"Loaded LoRA: {lora_file.name}")
                except Exception as e:
                    logger.warning(f"Failed to load LoRA {lora_file.name}: {e}")
            
        except Exception as e:
            logger.error(f"Failed to load LoRA weights: {e}")
    
    def _setup_lcm_scheduler(self):
        """Setup LCM scheduler for faster inference."""
        try:
            # The LCM scheduler only pays off when matching LCM (LoRA) weights
            # are loaded; on its own it degrades quality at the low step
            # counts LCM is meant to enable
            self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
            logger.info("LCM scheduler configured (pair with LCM LoRA weights)")
        except Exception as e:
            logger.warning(f"Failed to setup LCM scheduler: {e}")
    
    def generate_frames(
        self,
        prompt: str,
        frames: int = 5,
        negative_prompt: Optional[str] = None,
        width: int = 1024,
        height: int = 1024,
        num_inference_steps: int = 25,
        guidance_scale: float = 7.5,
        seed: Optional[int] = None
    ) -> List[Any]:
        """
        Generate image frames using the transformer pipeline.
        
        Args:
            prompt: Text prompt for generation
            frames: Number of frames to generate
            negative_prompt: Negative prompt for better results
            width: Image width
            height: Image height
            num_inference_steps: Number of diffusion steps
            guidance_scale: Classifier-free guidance scale
            seed: Random seed for reproducibility
            
        Returns:
            List of generated images
        """
        if not prompt or not prompt.strip():
            logger.warning("Empty prompt provided to generator")
            return []
        
        try:
            logger.info(f"Generating {frames} frames for prompt: {prompt[:50]}...")
            
            images = []
            for i in range(frames):
                logger.debug(f"Generating frame {i+1}/{frames}")
                
                # Set seed for reproducibility if provided
                generator = None
                if seed is not None:
                    generator = torch.Generator(device=self.device).manual_seed(seed + i)
                
                # Generate image
                with torch.inference_mode():
                    result = self.pipe(
                        prompt=prompt,
                        negative_prompt=negative_prompt or self._get_default_negative_prompt(),
                        width=width,
                        height=height,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        generator=generator,
                        num_images_per_prompt=1
                    )
                
                images.append(result.images[0])
            
            logger.info(f"Successfully generated {len(images)} frames")
            return images
            
        except Exception as e:
            logger.error(f"Frame generation failed: {e}")
            return []
    
    def _get_default_negative_prompt(self) -> str:
        """Get default negative prompt for better quality."""
        return "blurry, bad quality, worst quality, low quality, ugly, duplicate, watermark, signature"
    
    def save_model_info(self, output_path: str):
        """Save model information to file."""
        info = {
            "model_id": self.model_id,
            "device": self.device,
            "lora_path": self.lora_path,
            "use_lcm": self.use_lcm,
            "model_parameters": sum(p.numel() for p in self.pipe.unet.parameters()),
            "vae_parameters": sum(p.numel() for p in self.pipe.vae.parameters()),
            "text_encoder_parameters": sum(p.numel() for p in self.pipe.text_encoder.parameters())
        }
        
        with open(output_path, 'w') as f:
            json.dump(info, f, indent=2)
        
        logger.info(f"Model info saved to {output_path}")
    
    def get_model_stats(self) -> Dict[str, Any]:
        """Get current model statistics."""
        if not self.pipe:
            return {"error": "Model not loaded"}
        
        return {
            "model_id": self.model_id,
            "device": self.device,
            "dtype": str(next(self.pipe.unet.parameters()).dtype),
            "memory_usage": self._get_memory_usage(),
            "lcm_enabled": self.use_lcm,
            "lora_loaded": self.lora_path is not None
        }
    
    def _get_memory_usage(self) -> Dict[str, float]:
        """Get current memory usage."""
        if self.device != "cuda":
            return {"cuda_memory": 0.0, "cuda_memory_reserved": 0.0}
        
        try:
            return {
                "cuda_memory": torch.cuda.memory_allocated() / 1024**3,  # GB
                "cuda_memory_reserved": torch.cuda.memory_reserved() / 1024**3  # GB
            }
        except Exception:
            return {"cuda_memory": 0.0, "cuda_memory_reserved": 0.0}

# Global generator instance
_generator_instance = None

def get_generator(
    model_id: str = "stabilityai/stable-diffusion-xl-base-1.0",
    lora_path: Optional[str] = None,
    use_lcm: bool = False
) -> SafeStableDiffusionGenerator:
    """Get or create a global generator instance."""
    global _generator_instance
    
    # Recreate the instance if any configuration option changed, not just the
    # model id; otherwise stale LoRA/LCM settings would be silently reused
    if (
        _generator_instance is None
        or _generator_instance.model_id != model_id
        or _generator_instance.lora_path != lora_path
        or _generator_instance.use_lcm != use_lcm
    ):
        _generator_instance = SafeStableDiffusionGenerator(
            model_id=model_id,
            lora_path=lora_path,
            use_lcm=use_lcm
        )
    return _generator_instance

def generate_frames(
    prompt: str,
    frames: int = 5,
    **kwargs
) -> List[Any]:
    """Convenience function for frame generation."""
    generator = get_generator()
    return generator.generate_frames(prompt, frames, **kwargs)
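
# Minimal usage sketch, assuming the default SDXL checkpoint can be
# downloaded; the prompt and output filename pattern are our own choices.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    
    demo_frames = generate_frames(
        prompt="a lighthouse on a cliff at sunset, cinematic lighting",
        frames=2,
        seed=42
    )
    for i, image in enumerate(demo_frames):
        image.save(f"frame_{i:03d}.png")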