"""
SDXL Image Generation Engine
Handles model loading, generation, and optimization
"""
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from PIL import Image
from typing import Optional, Tuple
import warnings
from config import (
MODEL_ID,
REFINER_ID,
USE_REFINER,
DEFAULT_WIDTH,
DEFAULT_HEIGHT,
DEFAULT_GUIDANCE_SCALE,
DEFAULT_NUM_STEPS,
DEFAULT_REFINER_STEPS
)
# Silence noisy third-party warnings (coarse: this hides every warning)
warnings.filterwarnings("ignore")
class ImageGenerator:
"""
SDXL-based image generation with optional refiner
"""
def __init__(self, use_refiner: bool = USE_REFINER):
"""
Initialize the image generator
Args:
use_refiner: Whether to use the refiner model for better quality
"""
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.use_refiner = use_refiner
self.base_pipe = None
self.refiner_pipe = None
self._initialized = False
print(f"Using device: {self.device}")
def load_models(self):
"""
Load SDXL base and optional refiner models
"""
if self._initialized:
return
print("Loading SDXL base model...")
print("This may take a few minutes on first load...")
try:
# Load base model
self.base_pipe = DiffusionPipeline.from_pretrained(
MODEL_ID,
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
use_safetensors=True,
variant="fp16" if self.device == "cuda" else None
)
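            # torch_dtype sets the compute precision; variant="fp16" additionally
            # selects the half-precision weight files on the Hub, shrinking the download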
            # Move the pipeline to the target device
self.base_pipe.to(self.device)
            # Swap in the DPM-Solver++ multistep scheduler, which typically reaches
            # comparable quality in fewer denoising steps than the default
self.base_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
self.base_pipe.scheduler.config
)
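            # With DPM-Solver++, 25-30 denoising steps are usually enough for SDXL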
            # Enable memory optimizations: attention slicing trades a little
            # speed for a lower peak VRAM footprint
            if self.device == "cuda":
                self.base_pipe.enable_attention_slicing()
                # Try to enable xformers if installed; on PyTorch 2.x diffusers
                # already uses scaled-dot-product attention, so this is optional
                try:
                    self.base_pipe.enable_xformers_memory_efficient_attention()
                except Exception:
                    pass
print("βœ… Base model loaded successfully!")
# Load refiner if requested
if self.use_refiner:
print("Loading refiner model...")
self.refiner_pipe = DiffusionPipeline.from_pretrained(
REFINER_ID,
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
use_safetensors=True,
variant="fp16" if self.device == "cuda" else None
)
self.refiner_pipe.to(self.device)
if self.device == "cuda":
self.refiner_pipe.enable_attention_slicing()
print("βœ… Refiner model loaded successfully!")
self._initialized = True
except Exception as e:
print(f"❌ Error loading models: {e}")
raise
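    # For GPUs with limited VRAM, diffusers' model CPU offload (requires the
    # accelerate package) is an alternative to moving the whole pipeline onto
    # the device. A minimal sketch, in place of `pipe.to(self.device)`:
    #
    #     pipe = DiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
    #     pipe.enable_model_cpu_offload()  # submodules move to the GPU only while in use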
def generate(
self,
prompt: str,
negative_prompt: str = "",
width: int = DEFAULT_WIDTH,
height: int = DEFAULT_HEIGHT,
guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
num_inference_steps: int = DEFAULT_NUM_STEPS,
seed: int = -1,
use_refiner: Optional[bool] = None
) -> Tuple[Image.Image, dict]:
"""
Generate an image from a prompt
Args:
prompt: Text prompt for generation
negative_prompt: Negative prompt to avoid certain features
width: Image width
height: Image height
guidance_scale: How closely to follow the prompt (7-9 recommended)
num_inference_steps: Number of denoising steps (30-50 recommended)
seed: Random seed for reproducibility (-1 for random)
use_refiner: Override default refiner setting
Returns:
Tuple of (generated_image, metadata)
"""
# Ensure models are loaded
if not self._initialized:
self.load_models()
        # Resolve the seed: -1 means draw a random one, which is still recorded
        # in the metadata so any result can be reproduced
if seed == -1:
seed = torch.randint(0, 2**32 - 1, (1,)).item()
generator = torch.Generator(device=self.device).manual_seed(seed)
        # Decide whether to run the refiner pass; an override only takes effect
        # if the refiner model was actually loaded
        use_refiner_now = use_refiner if use_refiner is not None else self.use_refiner
        refiner_active = use_refiner_now and self.refiner_pipe is not None
try:
print(f"Generating image with seed: {seed}")
# Generate with base model
            if refiner_active:
# Generate latent with base, refine with refiner
image = self.base_pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator=generator,
output_type="latent"
).images[0]
# Refine the latent
print("Refining image...")
image = self.refiner_pipe(
prompt=prompt,
negative_prompt=negative_prompt,
image=image,
num_inference_steps=DEFAULT_REFINER_STEPS,
generator=generator
).images[0]
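                # Alternative (not used here): diffusers also documents an
                # "ensemble of experts" handoff, where the base pipe stops early
                # (denoising_end=0.8, output_type="latent") and the refiner resumes
                # from that point (denoising_start=0.8) instead of refining
                # fully denoised latents.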
else:
# Generate directly
image = self.base_pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator=generator
).images[0]
# Metadata
metadata = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"seed": seed,
"width": width,
"height": height,
"guidance_scale": guidance_scale,
"steps": num_inference_steps,
"refiner_used": use_refiner_now
}
print("βœ… Image generated successfully!")
return image, metadata
except Exception as e:
print(f"❌ Generation error: {e}")
raise
def unload_models(self):
"""
Unload models to free memory
"""
if self.base_pipe is not None:
del self.base_pipe
self.base_pipe = None
if self.refiner_pipe is not None:
del self.refiner_pipe
self.refiner_pipe = None
        if self.device == "cuda":
            # empty_cache() returns cached blocks to the CUDA driver so the
            # freed VRAM becomes visible to other processes
            torch.cuda.empty_cache()
self._initialized = False
print("Models unloaded")
# Test function
if __name__ == "__main__":
print("=== Image Generator Test ===\n")
generator = ImageGenerator(use_refiner=False)
generator.load_models()
test_prompt = "A beautiful sunset over mountains, highly detailed, photorealistic"
test_negative = "blurry, low quality, distorted"
print(f"\nGenerating test image...")
print(f"Prompt: {test_prompt}")
image, metadata = generator.generate(
prompt=test_prompt,
negative_prompt=test_negative,
        width=512,  # below SDXL's native 1024x1024 for a quicker smoke test
        height=512,
        num_inference_steps=20,  # fewer steps for testing
seed=42
)
# Save test image
output_path = "test_output.png"
image.save(output_path)
print(f"\nβœ… Test image saved to: {output_path}")
print(f"Metadata: {metadata}")