# app.py — Stable Diffusion text-to-image Gradio app
# (Hugging Face Space "mipooo/app.py", revision e070f17)
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from PIL import Image
import os
# Run on GPU when available; model weights and sampling all use this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Primary checkpoint; load_model() falls back to SD v1.4 if this one fails.
model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Lazily-initialized pipeline singleton, populated on first load_model() call.
pipe = None
def _build_pipeline(checkpoint_id):
    """Instantiate a StableDiffusionPipeline for *checkpoint_id* on `device`.

    Uses fp16 weights on GPU (fp32 on CPU), swaps in the DPM-Solver++
    scheduler, and enables memory-saving slicing modes.
    """
    pipeline = StableDiffusionPipeline.from_pretrained(
        checkpoint_id,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        safety_checker=None,  # NSFW checker disabled to save memory/compute
        requires_safety_checker=False,
    )
    # DPM-Solver++ reaches good quality in fewer steps than the default scheduler.
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
        pipeline.scheduler.config
    )
    pipeline = pipeline.to(device)
    if device == "cpu":
        # Trade speed for a much smaller peak-memory footprint on CPU.
        pipeline.enable_attention_slicing()
        pipeline.enable_vae_slicing()
    else:
        pipeline.enable_attention_slicing(1)
    return pipeline


def load_model():
    """Return the global pipeline, loading it on first use.

    Tries ``model_id`` first, then falls back to CompVis/stable-diffusion-v1-4.

    Returns:
        The loaded StableDiffusionPipeline (cached in the module-level ``pipe``).

    Raises:
        RuntimeError: if neither checkpoint can be loaded.
    """
    global pipe
    if pipe is not None:
        return pipe
    for checkpoint in (model_id, "CompVis/stable-diffusion-v1-4"):
        print(f"Loading model: {checkpoint}")
        try:
            pipe = _build_pipeline(checkpoint)
            print("Model loaded successfully!")
            return pipe
        except Exception as exc:
            # Keep trying the fallback checkpoint; surface the cause on stdout.
            print(f"Error loading model: {exc}")
            print("Trying alternative model...")
    raise RuntimeError("Cannot load model. Please check internet connection.")
def generate_image(
    prompt: str,
    negative_prompt: str = "",
    num_inference_steps: int = 25,
    guidance_scale: float = 7.5,
    width: int = 512,
    height: int = 512,
    seed: int = -1,
):
    """Generate one image from *prompt* with Stable Diffusion.

    Args:
        prompt: Text description of the desired image; must be non-blank.
        negative_prompt: Concepts to steer away from ("" disables it).
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance strength.
        width, height: Output resolution in pixels.
        seed: RNG seed for reproducibility; -1 picks a random seed.

    Returns:
        (image, status): a PIL image and a success message, or
        (None, error message) on blank input or any generation failure —
        errors are reported via the status string so the Gradio UI stays up.
    """
    # Reject empty / whitespace-only prompts before touching the model.
    if not prompt or not prompt.strip():
        return None, "Please enter a prompt!"
    try:
        pipeline = load_model()
        # A fixed seed makes the run reproducible; -1 means "let torch pick".
        generator = None
        if int(seed) != -1:
            generator = torch.Generator(device=device).manual_seed(int(seed))
        print(f"Generating: {prompt[:50]}...")
        with torch.inference_mode():  # no autograd bookkeeping during sampling
            result = pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt if negative_prompt else None,
                # Gradio sliders can deliver floats; coerce to the types the
                # pipeline expects.
                num_inference_steps=int(num_inference_steps),
                guidance_scale=float(guidance_scale),
                width=int(width),
                height=int(height),
                generator=generator,
            )
        return result.images[0], "Image generated successfully!"
    except Exception as e:
        # Broad catch is deliberate: any failure becomes a UI status message.
        error_msg = f"Error: {str(e)}"
        print(error_msg)
        return None, error_msg
# --- Gradio UI --------------------------------------------------------------
# Layout: prompt + controls on the left, generated image + status on the right.
with gr.Blocks(title="Text to Image", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Text to Image Generator")
    gr.Markdown("Generate images from text using Stable Diffusion")
    with gr.Row():
        # Left column: prompt entry and advanced generation settings.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="A beautiful sunset over mountains, digital art",
                lines=4
            )
            negative_prompt_input = gr.Textbox(
                label="Negative Prompt (optional)",
                placeholder="blurry, low quality, distorted",
                lines=2
            )
            # Collapsed by default; casual users only need the prompt boxes.
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    # Step of 64 keeps dimensions SD-friendly multiples.
                    width_slider = gr.Slider(
                        minimum=256,
                        maximum=768,
                        step=64,
                        value=512,
                        label="Width"
                    )
                    height_slider = gr.Slider(
                        minimum=256,
                        maximum=768,
                        step=64,
                        value=512,
                        label="Height"
                    )
                steps_slider = gr.Slider(
                    minimum=15,
                    maximum=50,
                    step=5,
                    value=25,
                    label="Steps"
                )
                guidance_slider = gr.Slider(
                    minimum=1.0,
                    maximum=15.0,
                    step=0.5,
                    value=7.5,
                    label="Guidance Scale"
                )
                # precision=0 forces an integer value; -1 requests a random seed.
                seed_input = gr.Number(
                    label="Seed (-1 for random)",
                    value=-1,
                    precision=0
                )
            generate_btn = gr.Button("Generate Image", variant="primary", size="lg")
        # Right column: output image and a status/error message box.
        with gr.Column(scale=1):
            output_image = gr.Image(
                label="Generated Image",
                type="pil",
                height=512
            )
            output_message = gr.Textbox(
                label="Status",
                interactive=False,
                lines=2
            )
    # Each example row must match the `inputs` order below:
    # [prompt, negative, steps, guidance, width, height, seed].
    gr.Examples(
        examples=[
            ["A serene landscape with mountains and a lake at sunset, digital art", "blurry, low quality", 25, 7.5, 512, 512, 42],
            ["A futuristic city with flying cars, cyberpunk style, neon lights", "ugly, distorted", 25, 7.5, 512, 512, 123],
            ["A cute cat wearing sunglasses, cartoon style", "", 25, 7.5, 512, 512, 456],
        ],
        inputs=[
            prompt_input,
            negative_prompt_input,
            steps_slider,
            guidance_slider,
            width_slider,
            height_slider,
            seed_input
        ]
    )
    # The inputs order must match generate_image's positional parameters.
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt_input,
            negative_prompt_input,
            steps_slider,
            guidance_slider,
            width_slider,
            height_slider,
            seed_input
        ],
        outputs=[output_image, output_message]
    )
if __name__ == "__main__":
    # Queue requests so long generations don't block concurrent users, then
    # serve on all interfaces at the standard Hugging Face Spaces port.
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)