# ai-image-editor / image_generator.py
# Author: official.ghost.logic
# Initial commit: Add project structure and core modules (3aa90e6)
"""
AI Image Generation Module
Uses Stable Diffusion to generate images from text prompts
"""
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
import os
class ImageGenerator:
def __init__(self, model_id="stabilityai/stable-diffusion-2-1-base"):
"""Initialize the Stable Diffusion pipeline"""
self.model_id = model_id
self.pipe = None
self.device = "cuda" if torch.cuda.is_available() else "cpu"
def load_model(self):
"""Load the Stable Diffusion model"""
if self.pipe is None:
print(f"Loading model on {self.device}...")
self.pipe = StableDiffusionPipeline.from_pretrained(
self.model_id,
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
safety_checker=None
)
self.pipe = self.pipe.to(self.device)
# Enable memory optimizations
if self.device == "cuda":
self.pipe.enable_attention_slicing()
print("Model loaded successfully!")
def generate(
self,
prompt: str,
negative_prompt: str = "",
num_inference_steps: int = 30,
guidance_scale: float = 7.5,
width: int = 512,
height: int = 512,
seed: int = None
) -> Image.Image:
"""
Generate an image from a text prompt
Args:
prompt: Text description of desired image
negative_prompt: What to avoid in the image
num_inference_steps: Number of denoising steps (higher = better quality, slower)
guidance_scale: How closely to follow the prompt (7-10 recommended)
width: Image width (must be multiple of 8)
height: Image height (must be multiple of 8)
seed: Random seed for reproducibility
Returns:
PIL Image
"""
self.load_model()
# Set seed for reproducibility
generator = None
if seed is not None:
generator = torch.Generator(device=self.device).manual_seed(seed)
# Generate image
with torch.inference_mode():
result = self.pipe(
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
width=width,
height=height,
generator=generator
)
return result.images[0]
def unload_model(self):
"""Free up memory by unloading the model"""
if self.pipe is not None:
del self.pipe
self.pipe = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
print("Model unloaded")
# Demo entry point: render a single seeded sample image when run directly.
if __name__ == "__main__":
    demo_generator = ImageGenerator()
    demo_image = demo_generator.generate(
        prompt="A fantasy landscape with mountains and a castle at sunset",
        seed=42,
    )
    demo_image.save("test_generated.png")
    print("Image saved as test_generated.png")