Spaces:
Sleeping
Sleeping
| """ | |
| AI Image Generator - Main Gradio Application | |
| Professional interface for SDXL-based image generation with quality validation | |
| """ | |
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| import os | |
| from datetime import datetime | |
| from generator import ImageGenerator | |
| from prompt_optimizer import PromptOptimizer | |
| from quality_validator import QualityValidator | |
| from config import ( | |
| STYLE_PRESETS, | |
| ASPECT_RATIOS, | |
| DEFAULT_GUIDANCE_SCALE, | |
| DEFAULT_NUM_STEPS, | |
| MIN_QUALITY_SCORE, | |
| MAX_RETRIES | |
| ) | |
class ImageGeneratorApp:
    """
    Main application class combining all components.

    Wires together the prompt optimizer, the image generator and the quality
    validator, and exposes ``generate_image`` as the Gradio event handler.
    """

    def __init__(self):
        """Construct the pipeline components and ensure the output directory exists."""
        self.generator = ImageGenerator(use_refiner=False)
        self.optimizer = PromptOptimizer()
        self.validator = QualityValidator()
        self.output_dir = "outputs"

        # Create output directory (no-op if it already exists)
        os.makedirs(self.output_dir, exist_ok=True)

        print("AI Image Generator initialized!")
        print(f"Device: {'CUDA (GPU)' if torch.cuda.is_available() else 'CPU'}")

    def generate_image(
        self,
        prompt: str,
        style: str,
        aspect_ratio: str,
        guidance_scale: float,
        num_steps: int,
        seed: int,
        enable_quality_check: bool,
        progress=gr.Progress()
    ):
        """
        Main generation pipeline with optional quality validation.

        Enhances the prompt and generates an image. When quality checking is
        enabled, it retries with random seeds up to ``MAX_RETRIES`` extra times
        while the validator score stays below ``MIN_QUALITY_SCORE``. The best
        image is saved to ``self.output_dir`` and returned with a Markdown
        summary.

        Args:
            prompt: User prompt describing the image.
            style: Style preset name, forwarded to the prompt optimizer.
            aspect_ratio: Key into ``ASPECT_RATIOS`` giving ``(width, height)``.
            guidance_scale: How closely the generator should follow the prompt.
            num_steps: Number of inference steps.
            seed: Seed for the first attempt; -1 (or empty) means random.
            enable_quality_check: Score outputs with the validator and retry.
            progress: Gradio progress tracker (injected by Gradio).

        Returns:
            ``(image, info_markdown)`` on success, or ``(None, message)`` on
            failure. Errors are reported to the UI rather than raised.
        """
        if not prompt or not prompt.strip():
            return None, "Please enter a prompt describing the image you want to create."

        try:
            # Gradio may deliver slider/number values as floats, or None for an
            # empty Number field; normalise before handing them on.
            num_steps = int(num_steps)
            seed = -1 if seed is None else int(seed)

            progress(0, desc="Optimizing prompt...")
            enhanced_prompt, negative_prompt = self.optimizer.enhance_prompt(
                prompt,
                style=style
            )

            width, height = ASPECT_RATIOS[aspect_ratio]

            progress(0.1, desc="Loading models...")
            if not self.generator._initialized:
                self.generator.load_models()

            # max(1, ...) guarantees at least one generation even if
            # MAX_RETRIES is misconfigured as negative.
            max_attempts = max(1, MAX_RETRIES + 1) if enable_quality_check else 1
            best_image, best_score, best_metadata, attempts_used = self._generate_best(
                enhanced_prompt,
                negative_prompt,
                width,
                height,
                guidance_scale,
                num_steps,
                seed,
                enable_quality_check,
                max_attempts,
                progress,
            )

            progress(0.9, desc="Saving image...")
            # Microseconds keep two images finished within the same second
            # from overwriting each other.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
            filepath = os.path.join(self.output_dir, f"generated_{timestamp}.png")
            best_image.save(filepath)

            if enable_quality_check:
                quality_feedback = self.validator.get_quality_feedback(best_score)
            else:
                quality_feedback = "Quality check disabled"

            info = "\n".join([
                "### Generation Info",
                "",
                f"**Prompt:** {prompt}",
                "",
                f"**Enhanced Prompt:** {enhanced_prompt}",
                "",
                f"**Negative Prompt:** {negative_prompt}",
                "",
                "**Settings:**",
                "",
                f"- Style: {style}",
                f"- Aspect Ratio: {aspect_ratio} ({width}x{height})",
                f"- Guidance Scale: {guidance_scale}",
                f"- Steps: {num_steps}",
                # Seed of the generation that produced the returned image.
                f"- Seed: {best_metadata['seed']}",
                "",
                f"**Quality Score:** {best_score:.4f} - {quality_feedback}",
                "",
                f"**Attempts:** {attempts_used}/{max_attempts}",
                "",
                f"**Saved to:** `{filepath}`",
            ])

            progress(1.0, desc="Complete!")
            return best_image, info

        except Exception as e:
            # Top-level UI boundary: report the failure to the user instead of
            # crashing the handler, but keep the traceback in the server log.
            import traceback
            traceback.print_exc()
            error_msg = f"❌ Error during generation: {str(e)}\n\nPlease check your settings and try again."
            return None, error_msg

    def _generate_best(
        self,
        enhanced_prompt,
        negative_prompt,
        width,
        height,
        guidance_scale,
        num_steps,
        seed,
        enable_quality_check,
        max_attempts,
        progress,
    ):
        """
        Run the generate -> validate -> retry loop.

        Returns:
            ``(best_image, best_score, best_metadata, attempts_used)``.
            ``best_metadata`` belongs to the generation that produced
            ``best_image``, so the reported seed matches the returned image.
            With quality checking disabled, the first image is returned with a
            neutral score of 0.5.
        """
        best_image = None
        best_score = None
        best_metadata = None
        attempts_used = 0

        for attempt in range(max_attempts):
            attempts_used = attempt + 1
            progress(
                0.2 + attempt * 0.6 / max_attempts,
                desc=f"Generating image (attempt {attempts_used}/{max_attempts})..."
            )

            # Only the first attempt honours the caller's seed; retries pass -1
            # (random) so they can produce a different image.
            image, metadata = self.generator.generate(
                prompt=enhanced_prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=num_steps,
                seed=seed if attempt == 0 else -1,
            )

            if not enable_quality_check:
                return image, 0.5, metadata, attempts_used  # Neutral score

            # Monotonic progress: validation of the last attempt lands at 0.8.
            progress(0.2 + attempts_used * 0.6 / max_attempts, desc="Validating quality...")
            score = self.validator.validate(image, enhanced_prompt)

            # `best_image is None` keeps the first image even if the validator
            # returns a score <= 0 (or NaN).
            if best_image is None or score > best_score:
                best_image, best_score, best_metadata = image, score, metadata

            if score >= MIN_QUALITY_SCORE:
                break

        return best_image, best_score, best_metadata, attempts_used
def create_ui():
    """
    Build the Gradio interface for the image generator.

    Instantiates ``ImageGeneratorApp``, which constructs the generator, the
    prompt optimizer and the quality validator. Then lays out the prompt and
    settings controls, the output panel, example prompts and usage tips, and
    wires the Generate button to ``ImageGeneratorApp.generate_image``.

    Returns:
        The ``gr.Blocks`` demo (not yet launched).
    """
    app = ImageGeneratorApp()

    # Custom CSS for better aesthetics
    custom_css = """
    .gradio-container {
        font-family: 'Inter', sans-serif;
    }
    .main-header {
        text-align: center;
        margin-bottom: 2rem;
    }
    .generate-btn {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        border: none;
        color: white;
        font-weight: 600;
    }
    """

    with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
            # 🎨 AI Image Generator
            ### High-Accuracy SDXL with Intelligent Prompt Optimization

            Generate stunning images with advanced prompt enhancement and quality validation.
            """,
            elem_classes="main-header"
        )

        with gr.Row():
            # Left column - Inputs
            with gr.Column(scale=1):
                prompt_input = gr.Textbox(
                    label="Your Prompt",
                    placeholder="Describe the image you want to create...",
                    lines=4,
                    value="A serene mountain landscape at sunset with a lake reflection"
                )

                with gr.Row():
                    style_dropdown = gr.Dropdown(
                        choices=list(STYLE_PRESETS.keys()),
                        value="Photorealistic",
                        label="Style Preset",
                        info="Select a style for automatic optimization"
                    )
                    aspect_ratio_dropdown = gr.Dropdown(
                        choices=list(ASPECT_RATIOS.keys()),
                        value="Square (1:1)",
                        label="Aspect Ratio",
                        info="Choose your desired dimensions"
                    )

                with gr.Accordion("⚙️ Advanced Settings", open=False):
                    guidance_scale_slider = gr.Slider(
                        minimum=1,
                        maximum=15,
                        value=DEFAULT_GUIDANCE_SCALE,
                        step=0.5,
                        label="Guidance Scale",
                        info="How closely to follow the prompt (7-9 recommended)"
                    )
                    num_steps_slider = gr.Slider(
                        minimum=20,
                        maximum=50,
                        value=DEFAULT_NUM_STEPS,
                        step=5,
                        label="Inference Steps",
                        info="More steps = better quality but slower (30-35 recommended)"
                    )
                    seed_input = gr.Number(
                        label="Seed",
                        value=-1,
                        precision=0,
                        info="Set to -1 for random, or use specific number for reproducibility"
                    )
                    quality_check = gr.Checkbox(
                        label="Enable Quality Validation",
                        value=True,
                        info="Use CLIP to validate output and retry if needed"
                    )

                generate_btn = gr.Button(
                    "🎨 Generate Image",
                    variant="primary",
                    size="lg",
                    elem_classes="generate-btn"
                )

            # Right column - Outputs
            with gr.Column(scale=1):
                output_image = gr.Image(
                    label="Generated Image",
                    type="pil",
                    show_label=True
                )
                output_info = gr.Markdown(label="Generation Details")

        # Examples: fill only prompt, style and aspect ratio; advanced
        # settings keep their current values.
        gr.Examples(
            examples=[
                ["A futuristic cyberpunk city at night with neon lights", "Cinematic", "Landscape (4:3)"],
                ["Portrait of a wise old wizard with a long beard", "Digital Art", "Portrait (3:4)"],
                ["Cute anime girl with pink hair in a cherry blossom garden", "Anime", "Square (1:1)"],
                ["Photorealistic macro photograph of a dewdrop on a leaf", "Photorealistic", "Square (1:1)"],
                ["Epic dragon flying over ancient castle ruins", "Oil Painting", "Wide (16:9)"],
            ],
            inputs=[prompt_input, style_dropdown, aspect_ratio_dropdown],
            label="💡 Example Prompts"
        )

        # Event handler: input order must match generate_image's parameters.
        generate_btn.click(
            fn=app.generate_image,
            inputs=[
                prompt_input,
                style_dropdown,
                aspect_ratio_dropdown,
                guidance_scale_slider,
                num_steps_slider,
                seed_input,
                quality_check
            ],
            outputs=[output_image, output_info]
        )

        gr.Markdown(
            """
            ---
            ### Tips for Best Results
            - **Be specific**: Include details about subject, setting, lighting, and style
            - **Use style presets**: They automatically add professional quality enhancers
            - **Adjust guidance scale**: Higher values (8-10) follow prompts more strictly
            - **Quality validation**: Helps ensure good results but takes slightly longer
            - **Seed control**: Use the same seed to reproduce results with variations

            ### 🛠️ Technical Stack
            - **Model**: Stable Diffusion XL (SDXL)
            - **Scheduler**: DPM++ Solver Multistep
            - **Validation**: CLIP-based quality scoring
            - **Optimization**: Intelligent prompt enhancement
            """
        )

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and start the web server.
    print("=" * 60)
    print("Starting AI Image Generator...")
    print("=" * 60)

    demo = create_ui()
    demo.launch(
        share=False,  # Set to True to create a public link
        server_name="0.0.0.0",  # bind on all interfaces so it is reachable from outside the host
        server_port=7860,
        show_error=True
    )