| import gradio as gr |
| import torch |
| from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler |
| from PIL import Image |
| import os |
|
|
# Pick the compute device: prefer the GPU when CUDA is available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")


# Hugging Face repo id of the primary checkpoint to load.
model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
# Module-level cache for the pipeline; populated lazily by load_model().
pipe = None
|
|
def _build_pipeline(repo_id):
    """Construct and fully configure a StableDiffusionPipeline for `repo_id`.

    Downloads (or reads from cache) the checkpoint, swaps in the
    DPM-Solver++ multistep scheduler, moves the pipeline to `device`,
    and applies memory-saving options. Either returns a ready-to-use
    pipeline or raises — it never returns a half-configured object.
    """
    pipeline = StableDiffusionPipeline.from_pretrained(
        repo_id,
        # fp16 halves memory on GPU; CPU inference requires fp32.
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        safety_checker=None,  # NSFW filter intentionally disabled
        requires_safety_checker=False,
    )
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
        pipeline.scheduler.config
    )
    pipeline = pipeline.to(device)

    if device == "cpu":
        # Trade speed for a smaller memory footprint on CPU.
        pipeline.enable_attention_slicing()
        pipeline.enable_vae_slicing()
    else:
        pipeline.enable_attention_slicing(1)
    return pipeline


def load_model():
    """Lazily load and cache the Stable Diffusion pipeline.

    Tries the primary `model_id` first and falls back to
    "CompVis/stable-diffusion-v1-4" if that fails. The global `pipe`
    cache is only assigned after a pipeline is fully configured, so a
    failed partial load is never cached and a later call retries.

    Returns:
        The cached StableDiffusionPipeline.

    Raises:
        RuntimeError: if neither the primary nor the fallback model
            can be loaded (callers catching Exception still work).
    """
    global pipe
    if pipe is not None:
        return pipe

    print(f"Loading model: {model_id}")
    try:
        pipe = _build_pipeline(model_id)
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        print("Trying alternative model...")
        try:
            pipe = _build_pipeline("CompVis/stable-diffusion-v1-4")
            print(f"Alternative model loaded successfully!")
        except Exception as e2:
            print(f"Error loading alternative model: {str(e2)}")
            raise RuntimeError(
                "Cannot load model. Please check internet connection."
            ) from e2

    return pipe
|
|
def generate_image(
    prompt: str,
    negative_prompt: str = "",
    num_inference_steps: int = 25,
    guidance_scale: float = 7.5,
    width: int = 512,
    height: int = 512,
    seed: int = -1
):
    """Generate one image from a text prompt with Stable Diffusion.

    Args:
        prompt: Text description of the desired image; blank input is rejected.
        negative_prompt: Concepts to steer away from ("" disables it).
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance strength.
        width: Output width in pixels (model expects multiples of 8).
        height: Output height in pixels.
        seed: RNG seed for reproducibility; -1 means non-deterministic.

    Returns:
        (PIL.Image, success message) on success, or (None, error message)
        on empty prompt / any generation failure — matching the two Gradio
        outputs (image, status textbox).
    """
    if not prompt or len(prompt.strip()) == 0:
        return None, "Please enter a prompt!"

    try:
        pipeline = load_model()

        # Gradio Slider/Number widgets may deliver floats; the pipeline
        # requires integral steps and pixel dimensions.
        num_inference_steps = int(num_inference_steps)
        width = int(width)
        height = int(height)

        generator = None
        if int(seed) != -1:
            # Fixed seed → reproducible output for the same settings.
            generator = torch.Generator(device=device).manual_seed(int(seed))

        print(f"Generating: {prompt[:50]}...")

        # inference_mode disables autograd bookkeeping for speed/memory.
        with torch.inference_mode():
            result = pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt if negative_prompt else None,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                width=width,
                height=height,
                generator=generator
            )

        image = result.images[0]

        return image, "Image generated successfully!"

    except Exception as e:
        # UI boundary: report the failure in the status box instead of
        # crashing the Gradio worker.
        error_msg = f"Error: {str(e)}"
        print(error_msg)
        return None, error_msg
|
|
# --- Gradio UI (built at import time; component creation order defines layout) ---
with gr.Blocks(title="Text to Image", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Text to Image Generator")
    gr.Markdown("Generate images from text using Stable Diffusion")
    
    with gr.Row():
        # Left column: prompt fields + generation controls.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="A beautiful sunset over mountains, digital art",
                lines=4
            )
            
            negative_prompt_input = gr.Textbox(
                label="Negative Prompt (optional)",
                placeholder="blurry, low quality, distorted",
                lines=2
            )
            
            # Collapsed by default; exposes the sampler/size knobs.
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    # 64-px steps keep dimensions model-friendly.
                    width_slider = gr.Slider(
                        minimum=256,
                        maximum=768,
                        step=64,
                        value=512,
                        label="Width"
                    )
                    
                    height_slider = gr.Slider(
                        minimum=256,
                        maximum=768,
                        step=64,
                        value=512,
                        label="Height"
                    )
                
                steps_slider = gr.Slider(
                    minimum=15,
                    maximum=50,
                    step=5,
                    value=25,
                    label="Steps"
                )
                
                guidance_slider = gr.Slider(
                    minimum=1.0,
                    maximum=15.0,
                    step=0.5,
                    value=7.5,
                    label="Guidance Scale"
                )
                
                # precision=0 forces an integer value in the UI.
                seed_input = gr.Number(
                    label="Seed (-1 for random)",
                    value=-1,
                    precision=0
                )
            
            generate_btn = gr.Button("Generate Image", variant="primary", size="lg")
        
        # Right column: generated image + status/error text.
        with gr.Column(scale=1):
            output_image = gr.Image(
                label="Generated Image",
                type="pil",
                height=512
            )
            
            output_message = gr.Textbox(
                label="Status",
                interactive=False,
                lines=2
            )
    
    # Clickable example rows; values are listed in the same order as
    # `inputs` below (prompt, negative, steps, guidance, width, height, seed).
    gr.Examples(
        examples=[
            ["A serene landscape with mountains and a lake at sunset, digital art", "blurry, low quality", 25, 7.5, 512, 512, 42],
            ["A futuristic city with flying cars, cyberpunk style, neon lights", "ugly, distorted", 25, 7.5, 512, 512, 123],
            ["A cute cat wearing sunglasses, cartoon style", "", 25, 7.5, 512, 512, 456],
        ],
        inputs=[
            prompt_input,
            negative_prompt_input,
            steps_slider,
            guidance_slider,
            width_slider,
            height_slider,
            seed_input
        ]
    )
    
    # Wire the button to generate_image; the inputs list matches the
    # function's positional parameter order exactly.
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt_input,
            negative_prompt_input,
            steps_slider,
            guidance_slider,
            width_slider,
            height_slider,
            seed_input
        ],
        outputs=[output_image, output_message]
    )
|
|
if __name__ == "__main__":
    # Queue requests so simultaneous generations don't fight over the GPU.
    demo.queue()
    # Listen on all interfaces, standard Gradio port, no public share link.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)