# Source: Hugging Face Space app.py by tmblack (commit f96b2c0, verified)
# ================================================================
# AI IMAGE GENERATOR - FINAL FIXED SOLUTION
# ================================================================
import logging
import torch
import gradio as gr
import numpy as np
import random
from diffusers import AutoPipelineForText2Image
from PIL import Image
# --- Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Generation defaults ---
MODEL_ID = "stabilityai/sdxl-turbo"
MAX_SEED = np.iinfo(np.int32).max  # upper bound for the seed slider (2**31 - 1)
DEFAULT_SIZE = 512
DEFAULT_STEPS = 1       # SDXL-Turbo is distilled for very few denoising steps
DEFAULT_GUIDANCE = 0.0  # turbo models run without classifier-free guidance
NUM_IMAGES = 4

# --- Device selection: half precision on CUDA, full precision on CPU ---
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32
logger.info("Using device: %s, dtype: %s", device, dtype)
# --- Model loading (runs once at import time) ---
# AutoPipelineForText2Image resolves the correct pipeline class for the model.
try:
    logger.info("⚙️ Loading model...")
    pipe = AutoPipelineForText2Image.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
        use_safetensors=True
    )
    pipe = pipe.to(device)
    logger.info("✅ Model loaded successfully!")
except Exception as e:
    logger.error(f"Model loading failed: {str(e)}")
    # Fail fast. The original code merely *constructed* gr.Error without
    # raising it, so execution continued with `pipe` undefined and the app
    # crashed later with a confusing NameError inside the click handler.
    raise RuntimeError(f"Model loading failed: {e}") from e
def generate_images(
    prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    num_images
):
    """Generate a batch of images from a text prompt.

    Args:
        prompt: Text description of the desired image (may be None or blank).
        seed: Base seed; image i uses seed + i when not randomizing.
        randomize_seed: If True, draw a fresh random seed for every image.
        width, height: Output dimensions in pixels.
        guidance_scale: Classifier-free guidance strength (0.0 for SDXL-Turbo).
        num_inference_steps: Number of denoising steps.
        num_images: How many images to generate.

    Returns:
        Tuple of (list of PIL images, comma-separated string of seeds used).
        On failure, returns colored placeholder images and an "ERROR: ..."
        message instead of raising, so the Gradio UI stays responsive.
    """
    # Gradio sliders can deliver floats; normalize numeric inputs once.
    # (The original code passed num_images straight to list multiplication,
    # which raises TypeError on the error paths when it arrives as a float.)
    num_images = max(1, int(num_images))
    width = int(width)
    height = int(height)

    # Guard against None as well as whitespace-only prompts.
    if not prompt or not prompt.strip():
        return [Image.new('RGB', (width, height), (255, 0, 0))] * num_images, "ERROR: Empty prompt"
    try:
        images = []
        seeds_used = []
        negative_prompt = ""  # SDXL-Turbo runs without guidance, so no negative prompt is needed
        for i in range(num_images):
            # Per-image seed: fresh random draw, or deterministic base + offset.
            if randomize_seed:
                current_seed = random.randint(0, MAX_SEED)
            else:
                current_seed = int(seed) + i
            generator = torch.Generator(device=device).manual_seed(current_seed)
            # Generate one image; .images[0] unwraps the pipeline's batch output.
            image = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=int(num_inference_steps),
                guidance_scale=guidance_scale,
                width=width,
                height=height,
                generator=generator,
            ).images[0]
            images.append(image)
            seeds_used.append(current_seed)
        return images, ", ".join(map(str, seeds_used))
    except torch.cuda.OutOfMemoryError:
        # Orange placeholders distinguish OOM from generic errors (red).
        # Placeholders now match the requested size instead of hard-coded 512.
        return [Image.new('RGB', (width, height), (255, 165, 0))] * num_images, "ERROR: Out of GPU memory"
    except Exception as e:
        logger.error(f"Generation error: {str(e)}")
        return [Image.new('RGB', (width, height), (255, 0, 0))] * num_images, f"ERROR: {str(e)}"
# UI Components
# Custom stylesheet passed to gr.Blocks below; currently only a placeholder.
# (The CSS comment inside the string is Bengali for "your CSS code remains
# unchanged" — runtime text, left as-is.)
css = """
/* ... আপনার CSS কোড অপরিবর্তিত থাকবে ... */
"""
# Create interface
# NOTE: construction order matters — gr.Blocks registers components as they
# are instantiated inside the context managers, and the click handler at the
# bottom wires them to generate_images by reference.
with gr.Blocks(css=css, title="AI Image Generator") as demo:
    with gr.Column(elem_id="main-container"):
        # Header
        gr.Markdown("# 🎨 AI Image Generator", elem_classes="header")
        gr.Markdown("Create stunning visuals from text prompts", elem_classes="header")
        # Inputs
        with gr.Row():
            prompt = gr.Textbox(label="Prompt", placeholder="Describe what you want to see...", lines=2)
        # Generate button
        generate_btn = gr.Button("✨ Generate Images", elem_classes="btn-generate")
        # Gallery output — receives the list of PIL images from generate_images
        gallery = gr.Gallery(label="Generated Images", columns=2, height="auto", elem_classes="gallery-grid")
        # Read-only display of the comma-separated seeds returned by generate_images
        seed_info = gr.Textbox(label="Seeds Used", interactive=False)
        # Advanced settings (collapsed by default)
        with gr.Accordion("⚙️ Advanced Settings", open=False):
            with gr.Row():
                num_images = gr.Slider(1, 4, value=2, step=1, label="Number of Images")
                width = gr.Slider(256, 768, value=DEFAULT_SIZE, step=32, label="Width")
                height = gr.Slider(256, 768, value=DEFAULT_SIZE, step=32, label="Height")
            with gr.Row():
                # SDXL-Turbo defaults: guidance 0.0 and a single inference step
                guidance_scale = gr.Slider(0.0, 10.0, value=DEFAULT_GUIDANCE, step=0.1, label="Guidance Scale")
                num_inference_steps = gr.Slider(1, 4, value=DEFAULT_STEPS, step=1, label="Steps")
            with gr.Row():
                seed = gr.Slider(0, MAX_SEED, value=0, step=1, label="Seed")
                randomize_seed = gr.Checkbox(value=True, label="Randomize Seed")
        # Examples — clicking one fills the prompt textbox
        gr.Markdown("### ✨ Example Prompts:")
        gr.Examples(
            examples=[
                "A red apple on a wooden table",
                "Mountain landscape at sunset",
                "Cyberpunk city street at night"
            ],
            inputs=prompt
        )
    # Event handling: input order must match generate_images' parameter order,
    # and outputs must match its (images, seed string) return tuple.
    generate_btn.click(
        fn=generate_images,
        inputs=[
            prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            num_images
        ],
        outputs=[gallery, seed_info]
    )
# Launch the app
if __name__ == "__main__":
    # 0.0.0.0 binds all interfaces (required inside containers such as
    # Hugging Face Spaces); 7860 is Gradio's conventional port.
    demo.launch(server_name="0.0.0.0", server_port=7860)