| import gradio as gr |
| import torch |
| from diffusers import StableDiffusionPipeline |
| import gc |
| import os |
| from PIL import Image |
| import numpy as np |
| from dataclasses import dataclass |
| from typing import Optional, Dict, Any |
| import json |
| import time |
|
|
@dataclass
class GenerationParams:
    """User-supplied settings for a single image generation request.

    Numeric fields are normalised in ``__post_init__`` because UI widgets may
    deliver floats (or ``None`` for a cleared number field) where integers
    are expected downstream.
    """

    prompt: str                   # free-text description of the desired image
    style: str = "realistic"      # style preset name
    steps: int = 20               # requested number of inference steps
    guidance_scale: float = 7.0   # classifier-free guidance strength
    seed: int = -1                # -1 means "do not fix the seed"
    quality: str = "balanced"     # quality preset name

    def __post_init__(self):
        """Coerce numeric fields to their declared types.

        A ``None`` seed is mapped to -1 so it is treated as "random" instead
        of reaching the random number generator as an invalid value.
        """
        self.seed = -1 if self.seed is None else int(self.seed)
        self.steps = int(self.steps)
        self.guidance_scale = float(self.guidance_scale)
| |
class GenerartSystem:
    """Stable Diffusion text-to-image backend with style presets and run statistics.

    The pipeline is loaded lazily on the first generation request and is kept
    on the CPU. Every call to ``generate_image`` is recorded in
    ``performance_stats``.
    """

    def __init__(self):
        """Set up presets and empty statistics; the model itself is not loaded here."""
        self.model = None
        # Each style contributes a prompt prefix and a negative prompt.
        # NOTE: the per-style "params" values are not read by generate_image;
        # the guidance scale and step count come from the request instead.
        self.styles = {
            "realistic": {
                "prompt_prefix": "professional photography, highly detailed, photorealistic quality",
                "negative_prompt": "cartoon, anime, illustration, painting, drawing, blurry, low quality",
                "params": {"guidance_scale": 7.5, "steps": 20},
            },
            "artistic": {
                "prompt_prefix": "artistic painting, impressionist style, vibrant colors",
                "negative_prompt": "photorealistic, digital art, 3d render, low quality",
                "params": {"guidance_scale": 6.5, "steps": 25},
            },
            "modern": {
                "prompt_prefix": "modern art, contemporary style, abstract qualities",
                "negative_prompt": "traditional, classic, photorealistic, low quality",
                "params": {"guidance_scale": 8.0, "steps": 15},
            },
        }
        # Multiplier applied to the requested step count.
        self.quality_presets = {
            "speed": {"steps_multiplier": 0.8},
            "balanced": {"steps_multiplier": 1.0},
            "quality": {"steps_multiplier": 1.2},
        }
        self.performance_stats = {
            "total_generations": 0,
            "successful_generations": 0,
            "average_time": 0,
            "success_rate": 100,
            "last_error": None,
        }

    def initialize_model(self):
        """Load the Stable Diffusion pipeline on the CPU with memory optimizations.

        Does nothing when a model is already loaded. Any loading error is
        printed and re-raised.
        """
        if self.model is not None:
            return

        self.cleanup()

        try:
            pipeline = StableDiffusionPipeline.from_pretrained(
                "CompVis/stable-diffusion-v1-4",
                torch_dtype=torch.float32,
                safety_checker=None,
                requires_safety_checker=False,
            )
            pipeline.enable_attention_slicing()
            pipeline.enable_vae_slicing()
            pipeline = pipeline.to("cpu")
        except Exception as e:
            print(f"Error initializing model: {str(e)}")
            raise

        # Publish the pipeline only once it is fully configured, so a failure
        # part-way through setup cannot leave a half-initialised model that
        # later calls would treat as ready.
        self.model = pipeline

    def cleanup(self):
        """Run the garbage collector and, when CUDA is available, empty its cache."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def update_performance_stats(self, generation_time: float, success: bool = True, error: Optional[str] = None):
        """Record the outcome of one generation attempt.

        Args:
            generation_time: Seconds spent generating; ignored for failures.
            success: Whether the attempt produced an image.
            error: Error message, stored as ``last_error`` on failure.
        """
        stats = self.performance_stats
        stats["total_generations"] += 1
        total = stats["total_generations"]

        # Running mean of the success percentage. It is updated for successes
        # as well as failures so the rate can recover after an error.
        outcome = 100 if success else 0
        stats["success_rate"] = (stats["success_rate"] * (total - 1) + outcome) / total

        if success:
            # Average only over successful runs: failures are reported with a
            # time of 0 and would otherwise drag the mean down.
            succeeded = stats.get("successful_generations", 0) + 1
            stats["successful_generations"] = succeeded
            stats["average_time"] += (generation_time - stats["average_time"]) / succeeded
        else:
            stats["last_error"] = error

    def get_system_stats(self):
        """Return a rounded snapshot of the statistics plus a memory usage string."""
        if torch.cuda.is_available():
            memory_usage = f"{torch.cuda.memory_allocated()/1024**2:.1f}MB"
        else:
            memory_usage = "CPU Mode"
        return {
            "total_generations": self.performance_stats["total_generations"],
            "average_time": round(self.performance_stats["average_time"], 2),
            "success_rate": round(self.performance_stats["success_rate"], 1),
            "memory_usage": memory_usage,
        }

    def generate_image(self, params: "GenerationParams") -> "Image.Image":
        """Generate one 512x512 image for the given parameters.

        Args:
            params: Request settings; ``style`` and ``quality`` must be keys
                of ``self.styles`` and ``self.quality_presets``.

        Returns:
            The first image produced by the pipeline.

        Raises:
            RuntimeError: On any failure; the original exception is chained
                as ``__cause__``.
        """
        try:
            if self.model is None:
                self.initialize_model()

            style_config = self.styles[params.style]
            quality_config = self.quality_presets[params.quality]

            full_prompt = f"{style_config['prompt_prefix']}, {params.prompt}"

            # Scale the requested steps by the quality preset, capped at 25
            # and never below one step.
            scaled_steps = params.steps * quality_config["steps_multiplier"]
            final_steps = max(1, int(min(25, scaled_steps)))

            if params.seed is None or params.seed == -1:
                generator = None
            else:
                # A dedicated generator makes the run reproducible without
                # reseeding the process-wide RNG.
                generator = torch.Generator(device="cpu").manual_seed(int(params.seed))

            start_time = time.perf_counter()

            with torch.no_grad():
                image = self.model(
                    prompt=full_prompt,
                    negative_prompt=style_config["negative_prompt"],
                    num_inference_steps=final_steps,
                    guidance_scale=params.guidance_scale,
                    generator=generator,
                    width=512,
                    height=512,
                ).images[0]

            generation_time = time.perf_counter() - start_time
            self.update_performance_stats(generation_time, success=True)

            return image

        except Exception as e:
            self.update_performance_stats(0, success=False, error=str(e))
            raise RuntimeError(f"Generation error: {str(e)}") from e

        finally:
            self.cleanup()
|
|
class GenerartInterface:
    """Gradio front end that collects generation settings and displays the result."""

    def __init__(self):
        """Create the backing generation system."""
        self.system = GenerartSystem()

    @staticmethod
    def _normalize_seed(seed) -> int:
        """Return an integer seed, mapping a cleared (None) field to -1."""
        return -1 if seed is None else int(seed)

    def create_interface(self):
        """Build and return the Gradio Blocks application.

        The left column holds the inputs and a statistics panel; the right
        column shows the generated image. Clicking the button runs a
        generation and refreshes both outputs.
        """
        with gr.Blocks(theme=gr.themes.Soft()) as demo:

            gr.Markdown("# 🎨 Generart Beta")

            with gr.Row():

                with gr.Column(scale=1):
                    prompt = gr.Textbox(label="Description", placeholder="Décrivez l'image souhaitée...")

                    style = gr.Dropdown(
                        choices=list(self.system.styles.keys()),
                        value="realistic",
                        label="Style Artistique",
                    )

                    with gr.Group():
                        steps = gr.Slider(
                            minimum=15,
                            maximum=25,
                            value=20,
                            step=1,
                            label="Nombre d'étapes",
                        )

                        guidance = gr.Slider(
                            minimum=6.0,
                            maximum=8.0,
                            value=7.0,
                            step=0.1,
                            label="Guide Scale",
                        )

                        quality = gr.Dropdown(
                            choices=list(self.system.quality_presets.keys()),
                            value="balanced",
                            label="Qualité",
                        )

                        seed = gr.Number(
                            value=-1,
                            label="Seed (-1 pour aléatoire)",
                            precision=0,
                        )

                    generate_btn = gr.Button("Générer", variant="primary")

                    with gr.Group():
                        gr.Markdown("### 📊 Statistiques Système")
                        stats_output = gr.JSON(value=self.system.get_system_stats())

                with gr.Column(scale=1):
                    image_output = gr.Image(label="Image Générée", type="pil")

            def generate(prompt, style, steps, guidance_scale, quality, seed):
                """Click handler: run one generation and return [image, stats]."""
                params = GenerationParams(
                    prompt=prompt,
                    style=style,
                    steps=int(steps),
                    guidance_scale=guidance_scale,
                    quality=quality,
                    # A cleared number field arrives as None; treat it as random.
                    seed=self._normalize_seed(seed),
                )

                try:
                    image = self.system.generate_image(params)
                except RuntimeError as exc:
                    # Show the backend's message in the UI instead of a
                    # generic failure.
                    raise gr.Error(str(exc)) from exc
                return [image, self.system.get_system_stats()]

            generate_btn.click(
                fn=generate,
                inputs=[prompt, style, steps, guidance, quality, seed],
                outputs=[image_output, stats_output],
            )

        return demo
|
|
| |
if __name__ == "__main__":
    # Script entry point: build the interface and start the Gradio app.
    # The module-level names `interface` and `demo` are kept as-is.
    interface = GenerartInterface()
    demo = interface.create_interface()
    demo.launch()