# NOTE(review): removed non-code residue here ("Spaces: Sleeping Sleeping" —
# apparently Hugging Face Spaces status text captured during extraction).
| """ | |
| SDXL Image Generation Engine | |
| Handles model loading, generation, and optimization | |
| """ | |
| import torch | |
| from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler | |
| from PIL import Image | |
| import os | |
| from typing import Optional, Tuple | |
| import warnings | |
| from config import ( | |
| MODEL_ID, | |
| REFINER_ID, | |
| USE_REFINER, | |
| DEFAULT_WIDTH, | |
| DEFAULT_HEIGHT, | |
| DEFAULT_GUIDANCE_SCALE, | |
| DEFAULT_NUM_STEPS, | |
| DEFAULT_REFINER_STEPS | |
| ) | |
# Suppress warnings.
# NOTE(review): this silences ALL warnings process-wide at import time —
# consider narrowing to specific categories/modules so real problems
# (deprecations, numerical issues) are not hidden; verify nothing relies
# on seeing them.
warnings.filterwarnings("ignore")
class ImageGenerator:
    """
    SDXL-based image generation with an optional refiner pass.

    Wraps a diffusers ``DiffusionPipeline`` for the SDXL base model and,
    optionally, the SDXL refiner. Models are loaded lazily on the first
    ``generate()`` call (or explicitly via ``load_models()``) and can be
    released again with ``unload_models()``.
    """

    def __init__(self, use_refiner: bool = USE_REFINER):
        """
        Initialize the image generator.

        Args:
            use_refiner: Whether to also load the refiner model for a
                second, quality-improving denoising pass.
        """
        # Prefer GPU (fp16 weights); fall back to CPU with fp32.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.use_refiner = use_refiner
        self.base_pipe = None
        self.refiner_pipe = None
        self._initialized = False
        print(f"Using device: {self.device}")

    def _load_pipeline(self, model_id: str):
        """Load one diffusers pipeline onto self.device, with dtype/variant matched to the device."""
        pipe = DiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            use_safetensors=True,
            # fp16 weight variant only makes sense on the CUDA path.
            variant="fp16" if self.device == "cuda" else None,
        )
        pipe.to(self.device)
        if self.device == "cuda":
            # Trades a little speed for a lower peak-VRAM footprint.
            pipe.enable_attention_slicing()
        return pipe

    def load_models(self):
        """
        Load the SDXL base model and, if enabled, the refiner.

        Idempotent: returns immediately when models are already loaded.

        Raises:
            Exception: re-raises any error from model download/loading
                after printing a diagnostic.
        """
        if self._initialized:
            return

        print("Loading SDXL base model...")
        print("This may take a few minutes on first load...")
        try:
            # Base and refiner share the same loading recipe.
            self.base_pipe = self._load_pipeline(MODEL_ID)

            # DPM-Solver++ typically converges in fewer steps than the
            # pipeline's default scheduler.
            self.base_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                self.base_pipe.scheduler.config
            )

            if self.device == "cuda":
                # xformers is an optional speed/memory win; absence is fine.
                # (Narrowed from a bare `except:` so Ctrl-C/SystemExit still propagate.)
                try:
                    self.base_pipe.enable_xformers_memory_efficient_attention()
                except Exception:
                    pass

            print("β Base model loaded successfully!")

            if self.use_refiner:
                print("Loading refiner model...")
                self.refiner_pipe = self._load_pipeline(REFINER_ID)
                print("β Refiner model loaded successfully!")

            self._initialized = True
        except Exception as e:
            print(f"β Error loading models: {e}")
            raise

    def generate(
        self,
        prompt: str,
        negative_prompt: str = "",
        width: int = DEFAULT_WIDTH,
        height: int = DEFAULT_HEIGHT,
        guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
        num_inference_steps: int = DEFAULT_NUM_STEPS,
        seed: int = -1,
        use_refiner: Optional[bool] = None
    ) -> Tuple[Image.Image, dict]:
        """
        Generate an image from a text prompt.

        Args:
            prompt: Text prompt for generation.
            negative_prompt: Features to steer away from.
            width: Output image width in pixels.
            height: Output image height in pixels.
            guidance_scale: How closely to follow the prompt (7-9 recommended).
            num_inference_steps: Number of denoising steps (30-50 recommended).
            seed: Random seed for reproducibility; -1 picks a random seed.
            use_refiner: Per-call override of the instance's refiner setting.

        Returns:
            Tuple of (generated PIL image, metadata dict describing the run).

        Raises:
            Exception: re-raises any error from the underlying pipelines
                after printing a diagnostic.
        """
        # Lazy-load on first use.
        if not self._initialized:
            self.load_models()

        if seed == -1:
            seed = torch.randint(0, 2**32 - 1, (1,)).item()
        generator = torch.Generator(device=self.device).manual_seed(seed)

        want_refiner = use_refiner if use_refiner is not None else self.use_refiner
        # The refiner can only run if it was actually loaded. Track the real
        # outcome so metadata["refiner_used"] never claims a refiner pass
        # that did not happen (previously it recorded only the request).
        refiner_applied = want_refiner and self.refiner_pipe is not None

        try:
            print(f"Generating image with seed: {seed}")

            if refiner_applied:
                # Base emits latents; the refiner consumes them and decodes
                # the final image.
                image = self.base_pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    width=width,
                    height=height,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    generator=generator,
                    output_type="latent"
                ).images[0]

                print("Refining image...")
                image = self.refiner_pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    image=image,
                    num_inference_steps=DEFAULT_REFINER_STEPS,
                    generator=generator
                ).images[0]
            else:
                # Single-pass generation straight to a PIL image.
                image = self.base_pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    width=width,
                    height=height,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    generator=generator
                ).images[0]

            metadata = {
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "seed": seed,
                "width": width,
                "height": height,
                "guidance_scale": guidance_scale,
                "steps": num_inference_steps,
                "refiner_used": refiner_applied
            }
            print("β Image generated successfully!")
            return image, metadata
        except Exception as e:
            print(f"β Generation error: {e}")
            raise

    def unload_models(self):
        """
        Unload both pipelines and free as much memory as possible.

        Safe to call repeatedly; a subsequent ``generate()`` will reload.
        """
        if self.base_pipe is not None:
            del self.base_pipe
            self.base_pipe = None
        if self.refiner_pipe is not None:
            del self.refiner_pipe
            self.refiner_pipe = None
        if self.device == "cuda":
            # Return freed blocks to the driver so other processes can use them.
            torch.cuda.empty_cache()
        self._initialized = False
        print("Models unloaded")
# Manual smoke test: run this module directly to produce one sample image.
if __name__ == "__main__":
    print("=== Image Generator Test ===\n")

    # Refiner disabled to keep the smoke test light.
    image_gen = ImageGenerator(use_refiner=False)
    image_gen.load_models()

    demo_prompt = "A beautiful sunset over mountains, highly detailed, photorealistic"
    demo_negative = "blurry, low quality, distorted"

    print("\nGenerating test image...")
    print(f"Prompt: {demo_prompt}")

    result_image, result_meta = image_gen.generate(
        prompt=demo_prompt,
        negative_prompt=demo_negative,
        width=512,               # reduced resolution for a quicker run
        height=512,
        num_inference_steps=20,  # fewer steps for a quicker run
        seed=42,
    )

    # Persist the result next to the script.
    out_file = "test_output.png"
    result_image.save(out_file)
    print(f"\nβ Test image saved to: {out_file}")
    print(f"Metadata: {result_meta}")