# Uploaded by Robin7339 ("Upload 6 files", commit cce2b06, verified).
"""
AI Image Generator - Main Gradio Application
Professional interface for SDXL-based image generation with quality validation
"""
import gradio as gr
import torch
from PIL import Image
import os
from datetime import datetime
from generator import ImageGenerator
from prompt_optimizer import PromptOptimizer
from quality_validator import QualityValidator
from config import (
STYLE_PRESETS,
ASPECT_RATIOS,
DEFAULT_GUIDANCE_SCALE,
DEFAULT_NUM_STEPS,
MIN_QUALITY_SCORE,
MAX_RETRIES
)
class ImageGeneratorApp:
    """
    Main application class combining all components.

    Wires together the ``ImageGenerator``, the ``PromptOptimizer`` and the
    ``QualityValidator`` and exposes ``generate_image`` as the Gradio click
    handler. Generated images are written to ``self.output_dir``.
    """

    def __init__(self):
        self.generator = ImageGenerator(use_refiner=False)
        self.optimizer = PromptOptimizer()
        self.validator = QualityValidator()
        self.output_dir = "outputs"
        # Create the output directory up front so saving never fails on a missing folder.
        os.makedirs(self.output_dir, exist_ok=True)
        print("🚀 AI Image Generator initialized!")
        print(f"Device: {'CUDA (GPU)' if torch.cuda.is_available() else 'CPU'}")

    def _generate_best(
        self,
        enhanced_prompt,
        negative_prompt,
        width,
        height,
        guidance_scale,
        num_steps,
        seed,
        enable_quality_check,
        progress,
    ):
        """Generate one or more candidates and return the best-scoring one.

        With quality checking disabled exactly one image is generated and given
        a neutral score of 0.5. With it enabled, up to ``MAX_RETRIES + 1``
        images are generated. The loop stops early once a candidate reaches
        ``MIN_QUALITY_SCORE``.

        Returns:
            tuple: ``(image, score, metadata, attempts_used, max_attempts)``,
            where ``metadata`` belongs to the returned image, so the reported
            seed matches it.
        """
        max_attempts = MAX_RETRIES + 1 if enable_quality_check else 1
        best_image, best_score, best_metadata = None, None, None
        attempts_used = 0

        for attempt in range(max_attempts):
            attempts_used = attempt + 1
            progress(
                0.2 + (attempt * 0.6 / max_attempts),
                desc=f"Generating image (attempt {attempts_used}/{max_attempts})..."
            )
            # Only the first attempt honours the user's seed. Retries use -1
            # (random) so they can actually produce a different image.
            attempt_seed = seed if attempt == 0 else -1
            image, metadata = self.generator.generate(
                prompt=enhanced_prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=num_steps,
                seed=attempt_seed,
            )

            if not enable_quality_check:
                return image, 0.5, metadata, attempts_used, max_attempts  # neutral score

            progress(
                0.2 + (attempts_used * 0.6 / max_attempts),
                desc="Validating quality..."
            )
            score = self.validator.validate(image, enhanced_prompt)
            # Compare against None rather than 0 so a first candidate is always
            # kept, even if the validator's scores can be <= 0.
            if best_score is None or score > best_score:
                best_image, best_score, best_metadata = image, score, metadata
            if score >= MIN_QUALITY_SCORE:
                break

        return best_image, best_score, best_metadata, attempts_used, max_attempts

    def generate_image(
        self,
        prompt: str,
        style: str,
        aspect_ratio: str,
        guidance_scale: float,
        num_steps: int,
        seed: int,
        enable_quality_check: bool,
        progress=gr.Progress()  # Gradio injects a progress tracker through this default
    ):
        """
        Main generation pipeline with quality validation.

        Steps: optimize the prompt, lazily load models, generate (retrying when
        quality validation is enabled), save the chosen image as PNG under
        ``self.output_dir``, and build a Markdown summary.

        Args:
            prompt: User prompt text.
            style: Key passed to the prompt optimizer as ``style``.
            aspect_ratio: Key into ``ASPECT_RATIOS`` giving ``(width, height)``.
            guidance_scale: Classifier-free guidance scale for the generator.
            num_steps: Number of inference steps (coerced to ``int``).
            seed: Seed for the first attempt. ``-1`` or ``None`` means random.
            enable_quality_check: Score outputs with the validator and retry.
            progress: Gradio progress callback.

        Returns:
            tuple: ``(image, info_markdown)`` on success, or
            ``(None, error_markdown)`` on failure.
        """
        if not prompt or not prompt.strip():
            return None, "❌ Please enter a prompt describing the image you want to create."

        try:
            progress(0, desc="Optimizing prompt...")
            enhanced_prompt, negative_prompt = self.optimizer.enhance_prompt(
                prompt,
                style=style
            )

            width, height = ASPECT_RATIOS[aspect_ratio]

            # Gradio may deliver None (cleared Number field) or floats (Slider);
            # the generator expects integer steps and an integer/-1 seed.
            seed = -1 if seed is None else int(seed)
            num_steps = int(num_steps)

            progress(0.1, desc="Loading models...")
            if not self.generator._initialized:
                self.generator.load_models()

            best_image, best_score, best_metadata, attempts_used, max_attempts = (
                self._generate_best(
                    enhanced_prompt,
                    negative_prompt,
                    width,
                    height,
                    guidance_scale,
                    num_steps,
                    seed,
                    enable_quality_check,
                    progress,
                )
            )

            progress(0.9, desc="Saving image...")
            # Microseconds keep filenames unique when requests finish within the same second.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
            filepath = os.path.join(self.output_dir, f"generated_{timestamp}.png")
            best_image.save(filepath)

            quality_feedback = (
                self.validator.get_quality_feedback(best_score)
                if enable_quality_check
                else "Quality check disabled"
            )
            # Blank lines separate fields so Markdown renders them as separate paragraphs.
            info = "\n".join([
                "### Generation Info",
                "",
                f"**Prompt:** {prompt}",
                "",
                f"**Enhanced Prompt:** {enhanced_prompt}",
                "",
                f"**Negative Prompt:** {negative_prompt}",
                "",
                "**Settings:**",
                "",
                f"- Style: {style}",
                f"- Aspect Ratio: {aspect_ratio} ({width}x{height})",
                f"- Guidance Scale: {guidance_scale}",
                f"- Steps: {num_steps}",
                f"- Seed: {best_metadata['seed']}",
                "",
                f"**Quality Score:** {best_score:.4f} - {quality_feedback}",
                "",
                f"**Attempts:** {attempts_used}/{max_attempts}",
                "",
                f"**Saved to:** `{filepath}`",
            ])

            progress(1.0, desc="Complete!")
            return best_image, info
        except Exception as e:
            # UI boundary: log the full traceback for the operator and show the
            # user a short message instead of crashing the event handler.
            import traceback
            traceback.print_exc()
            error_msg = f"❌ Error during generation: {str(e)}\n\nPlease check your settings and try again."
            return None, error_msg
def create_ui():
    """
    Create the Gradio interface.

    Builds a two-column ``gr.Blocks`` layout. The left column holds the prompt,
    style, aspect ratio and advanced settings. The right column holds the
    generated image and its Markdown details. The generate button is wired to
    ``ImageGeneratorApp.generate_image``.

    Returns:
        gr.Blocks: The assembled (not yet launched) demo.
    """
    # Models are loaded lazily on the first generation (see generate_image),
    # so constructing the app here is cheap.
    app = ImageGeneratorApp()

    # Custom CSS for better aesthetics
    custom_css = """
    .gradio-container {
        font-family: 'Inter', sans-serif;
    }
    .main-header {
        text-align: center;
        margin-bottom: 2rem;
    }
    .generate-btn {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        border: none;
        color: white;
        font-weight: 600;
    }
    """

    with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
        # Markdown is built from explicit lines so rendering does not depend on
        # how the source is indented.
        gr.Markdown(
            "# 🎨 AI Image Generator\n"
            "### High-Accuracy SDXL with Intelligent Prompt Optimization\n"
            "\n"
            "Generate stunning images with advanced prompt enhancement and quality validation.",
            elem_classes="main-header"
        )

        with gr.Row():
            # Left column - Inputs
            with gr.Column(scale=1):
                prompt_input = gr.Textbox(
                    label="Your Prompt",
                    placeholder="Describe the image you want to create...",
                    lines=4,
                    value="A serene mountain landscape at sunset with a lake reflection"
                )

                with gr.Row():
                    style_dropdown = gr.Dropdown(
                        choices=list(STYLE_PRESETS.keys()),
                        value="Photorealistic",
                        label="Style Preset",
                        info="Select a style for automatic optimization"
                    )
                    aspect_ratio_dropdown = gr.Dropdown(
                        choices=list(ASPECT_RATIOS.keys()),
                        value="Square (1:1)",
                        label="Aspect Ratio",
                        info="Choose your desired dimensions"
                    )

                with gr.Accordion("⚙️ Advanced Settings", open=False):
                    guidance_scale_slider = gr.Slider(
                        minimum=1,
                        maximum=15,
                        value=DEFAULT_GUIDANCE_SCALE,
                        step=0.5,
                        label="Guidance Scale",
                        info="How closely to follow the prompt (7-9 recommended)"
                    )
                    num_steps_slider = gr.Slider(
                        minimum=20,
                        maximum=50,
                        value=DEFAULT_NUM_STEPS,
                        step=5,
                        label="Inference Steps",
                        info="More steps = better quality but slower (30-35 recommended)"
                    )
                    seed_input = gr.Number(
                        label="Seed",
                        value=-1,
                        precision=0,
                        info="Set to -1 for random, or use specific number for reproducibility"
                    )
                    quality_check = gr.Checkbox(
                        label="Enable Quality Validation",
                        value=True,
                        info="Use CLIP to validate output and retry if needed"
                    )

                generate_btn = gr.Button(
                    "🎨 Generate Image",
                    variant="primary",
                    size="lg",
                    elem_classes="generate-btn"
                )

            # Right column - Outputs
            with gr.Column(scale=1):
                output_image = gr.Image(
                    label="Generated Image",
                    type="pil",
                    show_label=True
                )
                output_info = gr.Markdown(label="Generation Details")

        # Examples only pre-fill the first three inputs (no fn is attached).
        gr.Examples(
            examples=[
                ["A futuristic cyberpunk city at night with neon lights", "Cinematic", "Landscape (4:3)"],
                ["Portrait of a wise old wizard with a long beard", "Digital Art", "Portrait (3:4)"],
                ["Cute anime girl with pink hair in a cherry blossom garden", "Anime", "Square (1:1)"],
                ["Photorealistic macro photograph of a dewdrop on a leaf", "Photorealistic", "Square (1:1)"],
                ["Epic dragon flying over ancient castle ruins", "Oil Painting", "Wide (16:9)"],
            ],
            inputs=[prompt_input, style_dropdown, aspect_ratio_dropdown],
            label="💡 Example Prompts"
        )

        # Event handler: input order must match generate_image's positional parameters.
        generate_btn.click(
            fn=app.generate_image,
            inputs=[
                prompt_input,
                style_dropdown,
                aspect_ratio_dropdown,
                guidance_scale_slider,
                num_steps_slider,
                seed_input,
                quality_check
            ],
            outputs=[output_image, output_info]
        )

        gr.Markdown(
            "---\n"
            "\n"
            "### 📚 Tips for Best Results\n"
            "\n"
            "- **Be specific**: Include details about subject, setting, lighting, and style\n"
            "- **Use style presets**: They automatically add professional quality enhancers\n"
            "- **Adjust guidance scale**: Higher values (8-10) follow prompts more strictly\n"
            "- **Quality validation**: Helps ensure good results but takes slightly longer\n"
            "- **Seed control**: Reuse a seed with the same settings to reproduce a result; "
            "change the prompt or settings to explore variations "
            "(retries triggered by quality validation always use a random seed)\n"
            "\n"
            "### 🛠️ Technical Stack\n"
            "\n"
            "- **Model**: Stable Diffusion XL (SDXL)\n"
            "- **Scheduler**: DPM++ Solver Multistep\n"
            "- **Validation**: CLIP-based quality scoring\n"
            "- **Optimization**: Intelligent prompt enhancement\n"
        )

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and serve it.
    print("=" * 60)
    print("🚀 Starting AI Image Generator...")
    print("=" * 60)

    demo = create_ui()
    # Enable the request queue: gr.Progress updates require it on Gradio 3.x
    # (it is on by default in Gradio 4). Its default concurrency also keeps
    # generations from running in parallel on the same model.
    demo.queue()
    demo.launch(
        share=False,  # Set to True to create a public link
        server_name="0.0.0.0",  # listen on all interfaces (e.g. inside Docker / HF Spaces)
        server_port=7860,
        show_error=True
    )