""" SDXL Image Generation Engine Handles model loading, generation, and optimization """ import torch from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler from PIL import Image import os from typing import Optional, Tuple import warnings from config import ( MODEL_ID, REFINER_ID, USE_REFINER, DEFAULT_WIDTH, DEFAULT_HEIGHT, DEFAULT_GUIDANCE_SCALE, DEFAULT_NUM_STEPS, DEFAULT_REFINER_STEPS ) # Suppress warnings warnings.filterwarnings("ignore") class ImageGenerator: """ SDXL-based image generation with optional refiner """ def __init__(self, use_refiner: bool = USE_REFINER): """ Initialize the image generator Args: use_refiner: Whether to use the refiner model for better quality """ self.device = "cuda" if torch.cuda.is_available() else "cpu" self.use_refiner = use_refiner self.base_pipe = None self.refiner_pipe = None self._initialized = False print(f"Using device: {self.device}") def load_models(self): """ Load SDXL base and optional refiner models """ if self._initialized: return print("Loading SDXL base model...") print("This may take a few minutes on first load...") try: # Load base model self.base_pipe = DiffusionPipeline.from_pretrained( MODEL_ID, torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, use_safetensors=True, variant="fp16" if self.device == "cuda" else None ) # Optimize for memory self.base_pipe.to(self.device) # Use better scheduler self.base_pipe.scheduler = DPMSolverMultistepScheduler.from_config( self.base_pipe.scheduler.config ) # Enable memory optimizations if self.device == "cuda": self.base_pipe.enable_attention_slicing() # Try to enable xformers if available try: self.base_pipe.enable_xformers_memory_efficient_attention() except: pass print("✅ Base model loaded successfully!") # Load refiner if requested if self.use_refiner: print("Loading refiner model...") self.refiner_pipe = DiffusionPipeline.from_pretrained( REFINER_ID, torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, use_safetensors=True, variant="fp16" if self.device == "cuda" else None ) self.refiner_pipe.to(self.device) if self.device == "cuda": self.refiner_pipe.enable_attention_slicing() print("✅ Refiner model loaded successfully!") self._initialized = True except Exception as e: print(f"❌ Error loading models: {e}") raise def generate( self, prompt: str, negative_prompt: str = "", width: int = DEFAULT_WIDTH, height: int = DEFAULT_HEIGHT, guidance_scale: float = DEFAULT_GUIDANCE_SCALE, num_inference_steps: int = DEFAULT_NUM_STEPS, seed: int = -1, use_refiner: Optional[bool] = None ) -> Tuple[Image.Image, dict]: """ Generate an image from a prompt Args: prompt: Text prompt for generation negative_prompt: Negative prompt to avoid certain features width: Image width height: Image height guidance_scale: How closely to follow the prompt (7-9 recommended) num_inference_steps: Number of denoising steps (30-50 recommended) seed: Random seed for reproducibility (-1 for random) use_refiner: Override default refiner setting Returns: Tuple of (generated_image, metadata) """ # Ensure models are loaded if not self._initialized: self.load_models() # Handle seed if seed == -1: seed = torch.randint(0, 2**32 - 1, (1,)).item() generator = torch.Generator(device=self.device).manual_seed(seed) # Determine if we should use refiner use_refiner_now = use_refiner if use_refiner is not None else self.use_refiner try: print(f"Generating image with seed: {seed}") # Generate with base model if use_refiner_now and self.refiner_pipe is not None: # Generate latent with base, 
refine with refiner image = self.base_pipe( prompt=prompt, negative_prompt=negative_prompt, width=width, height=height, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, generator=generator, output_type="latent" ).images[0] # Refine the latent print("Refining image...") image = self.refiner_pipe( prompt=prompt, negative_prompt=negative_prompt, image=image, num_inference_steps=DEFAULT_REFINER_STEPS, generator=generator ).images[0] else: # Generate directly image = self.base_pipe( prompt=prompt, negative_prompt=negative_prompt, width=width, height=height, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, generator=generator ).images[0] # Metadata metadata = { "prompt": prompt, "negative_prompt": negative_prompt, "seed": seed, "width": width, "height": height, "guidance_scale": guidance_scale, "steps": num_inference_steps, "refiner_used": use_refiner_now } print("✅ Image generated successfully!") return image, metadata except Exception as e: print(f"❌ Generation error: {e}") raise def unload_models(self): """ Unload models to free memory """ if self.base_pipe is not None: del self.base_pipe self.base_pipe = None if self.refiner_pipe is not None: del self.refiner_pipe self.refiner_pipe = None if self.device == "cuda": torch.cuda.empty_cache() self._initialized = False print("Models unloaded") # Test function if __name__ == "__main__": print("=== Image Generator Test ===\n") generator = ImageGenerator(use_refiner=False) generator.load_models() test_prompt = "A beautiful sunset over mountains, highly detailed, photorealistic" test_negative = "blurry, low quality, distorted" print(f"\nGenerating test image...") print(f"Prompt: {test_prompt}") image, metadata = generator.generate( prompt=test_prompt, negative_prompt=test_negative, width=512, # Smaller for testing height=512, num_inference_steps=20, # Fewer steps for testing seed=42 ) # Save test image output_path = "test_output.png" image.save(output_path) print(f"\n✅ Test image saved to: {output_path}") print(f"Metadata: {metadata}")
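

# ---------------------------------------------------------------------------
# For reference: a minimal sketch of the `config.py` this module imports from,
# since that file is not shown here. The two model IDs are the standard SDXL
# checkpoints on the Hugging Face Hub; every other value is an assumed default
# (the real project's config may differ) chosen to match the ranges recommended
# in the `generate` docstring above.
#
#     MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
#     REFINER_ID = "stabilityai/stable-diffusion-xl-refiner-1.0"
#     USE_REFINER = True          # assumed default
#     DEFAULT_WIDTH = 1024        # SDXL's native resolution
#     DEFAULT_HEIGHT = 1024
#     DEFAULT_GUIDANCE_SCALE = 7.5   # within the 7-9 recommended range
#     DEFAULT_NUM_STEPS = 40         # within the 30-50 recommended range
#     DEFAULT_REFINER_STEPS = 20     # assumed default
# ---------------------------------------------------------------------------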