"""
AI Image Generation Module

Uses Stable Diffusion to generate images from text prompts.
"""

import gc
import os
from typing import Optional

import torch
from diffusers import StableDiffusionPipeline
from PIL import Image


class ImageGenerator:
    """Lazily-loaded wrapper around a diffusers ``StableDiffusionPipeline``.

    The pipeline is not created in ``__init__``. It is loaded on the first
    call to :meth:`generate` (or an explicit :meth:`load_model`) and can be
    released again with :meth:`unload_model`.
    """

    def __init__(self, model_id: str = "stabilityai/stable-diffusion-2-1-base"):
        """Configure the generator without loading any weights yet.

        Args:
            model_id: Model identifier (or local path) passed to
                ``StableDiffusionPipeline.from_pretrained``.
        """
        self.model_id = model_id
        # Populated by load_model(); reset to None by unload_model().
        self.pipe = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_model(self) -> None:
        """Load the Stable Diffusion model onto ``self.device`` if not loaded."""
        if self.pipe is not None:
            return

        print(f"Loading model on {self.device}...")
        self.pipe = StableDiffusionPipeline.from_pretrained(
            self.model_id,
            # float16 weights on GPU to save memory; float32 on CPU.
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            # NOTE: the safety checker is deliberately disabled here.
            safety_checker=None,
        )
        self.pipe = self.pipe.to(self.device)

        # Enable memory optimizations (GPU only).
        if self.device == "cuda":
            self.pipe.enable_attention_slicing()
        print("Model loaded successfully!")

    @staticmethod
    def _validate_dimensions(width: int, height: int) -> None:
        """Raise ValueError unless width and height are positive multiples of 8."""
        for name, value in (("width", width), ("height", height)):
            if value <= 0 or value % 8 != 0:
                raise ValueError(
                    f"{name} must be a positive multiple of 8, got {value}"
                )

    def generate(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 30,
        guidance_scale: float = 7.5,
        width: int = 512,
        height: int = 512,
        seed: Optional[int] = None,
    ) -> Image.Image:
        """Generate an image from a text prompt.

        Args:
            prompt: Text description of the desired image.
            negative_prompt: What to avoid in the image.
            num_inference_steps: Number of denoising steps (higher = better
                quality, slower). Must be at least 1.
            guidance_scale: How closely to follow the prompt (7-10 recommended).
            width: Image width in pixels; must be a positive multiple of 8.
            height: Image height in pixels; must be a positive multiple of 8.
            seed: Random seed for reproducibility; ``None`` means unseeded.

        Returns:
            The first generated PIL Image.

        Raises:
            ValueError: If ``width``/``height`` are not positive multiples of 8
                or ``num_inference_steps`` is less than 1.
        """
        # Validate cheap arguments *before* the slow, memory-heavy model load.
        self._validate_dimensions(width, height)
        if num_inference_steps < 1:
            raise ValueError(
                f"num_inference_steps must be >= 1, got {num_inference_steps}"
            )

        self.load_model()

        # A seeded torch.Generator makes the diffusion noise reproducible.
        # `is not None` (rather than truthiness) keeps seed=0 valid.
        generator = None
        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)

        with torch.inference_mode():
            result = self.pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                width=width,
                height=height,
                generator=generator,
            )

        return result.images[0]

    def unload_model(self) -> None:
        """Free up memory by unloading the model (no-op if not loaded)."""
        if self.pipe is None:
            return

        self.pipe = None
        # Collect any reference cycles still holding the pipeline so its
        # tensors are actually released before the CUDA cache is emptied.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("Model unloaded")


# Example usage
if __name__ == "__main__":
    image_generator = ImageGenerator()
    image = image_generator.generate(
        prompt="A fantasy landscape with mountains and a castle at sunset",
        seed=42,
    )
    image.save("test_generated.png")
    print("Image saved as test_generated.png")