# Hugging Face Space file header (web-page residue converted to a comment):
# rajux75 — "Create app.py" — commit 43e487b (verified)
# app.py
import os
import io
import base64
from PIL import Image
import gradio as gr
from diffusers import StableDiffusionXLPipeline
import torch
# Select the compute device: prefer CUDA when a GPU is present, fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Initialize the SDXL base pipeline. On GPU we load the fp16 weights variant
# and run in half precision; on CPU we keep float32 (fp16 CPU inference is
# unsupported/slow) and load the default weights (variant=None).
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    use_safetensors=True,
    variant="fp16" if device == "cuda" else None
)
pipe = pipe.to(device)
# Enable attention slicing on CUDA: computes attention in chunks to reduce
# peak VRAM usage at a small speed cost.
if device == "cuda":
    pipe.enable_attention_slicing()
def generate_image(prompt, negative_prompt="", height=512, width=512, num_inference_steps=30, guidance_scale=7.5):
    """Generate an image from a text prompt with the SDXL pipeline.

    Args:
        prompt: Text description of the desired image.
        negative_prompt: Concepts to steer the generation away from.
        height: Output height in pixels; must be divisible by 8.
        width: Output width in pixels; must be divisible by 8.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance strength.

    Returns:
        Tuple of (PIL.Image, PNG data-URI string with the image base64-encoded).

    Raises:
        ValueError: If height or width is not divisible by 8.
    """
    # Gradio sliders and JSON payloads may deliver floats (e.g. 512.0);
    # the diffusers pipeline requires integer dimensions and step count.
    height = int(height)
    width = int(width)
    num_inference_steps = int(num_inference_steps)
    guidance_scale = float(guidance_scale)

    # Validate inputs before touching the (expensive) pipeline.
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError("Height and width must be divisible by 8")

    # Run the diffusion pipeline; .images is a list, take the first result.
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]

    # Encode the PIL image as a base64 PNG so it can travel over JSON.
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return image, f"data:image/png;base64,{img_str}"
# Define the Gradio interface: input controls on the left column, the
# generated image and its base64 data URI on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Text-to-Image Generator API")
    with gr.Row():
        with gr.Column():
            # Generation inputs: prompt text plus hyperparameters.
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Things you don't want in the image...")
            with gr.Row():
                # step=8 keeps slider values divisible by 8, matching the
                # pipeline's dimension requirement enforced in generate_image.
                height = gr.Slider(minimum=256, maximum=1024, step=8, value=512, label="Height")
                width = gr.Slider(minimum=256, maximum=1024, step=8, value=512, label="Width")
            with gr.Row():
                steps = gr.Slider(minimum=10, maximum=50, step=1, value=30, label="Inference Steps")
                guidance = gr.Slider(minimum=1, maximum=15, step=0.1, value=7.5, label="Guidance Scale")
            generate_btn = gr.Button("Generate")
        with gr.Column():
            # Outputs: rendered image preview and a copyable base64 data URI.
            output_image = gr.Image(label="Generated Image")
            output_json = gr.Textbox(label="Image Base64", show_copy_button=True)
    # Wire the button: component values map positionally onto
    # generate_image's parameters; its two return values fill the outputs.
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, negative_prompt, height, width, steps, guidance],
        outputs=[output_image, output_json]
    )
    # Usage instructions rendered on the page (string content is user-facing
    # and kept verbatim).
    gr.Markdown("""
## API Usage
You can use this as an API with this curl command:
```bash
curl -X POST "https://your-username-text-to-image-api.hf.space/api/predict" \\
-H "Content-Type: application/json" \\
-d '{
"data": [
"A beautiful sunset over mountains",
"",
512,
512,
30,
7.5
]
}'
```
""")
# Create FastAPI app for direct API usage (imports kept here, mid-file,
# as in the original layout). List/Union are imported but currently unused.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, List, Union
import nest_asyncio
# Patch asyncio so the event loop can be re-entered — presumably needed
# because Gradio already runs a loop in this process. TODO: confirm.
nest_asyncio.apply()
app = FastAPI()
class ImageRequest(BaseModel):
    """Request body for POST /generate.

    NOTE(review): the Optional fields accept an explicit JSON ``null``,
    which would be forwarded as None to the generation call — consider
    plain ``int``/``float``/``str`` types if null should be rejected.
    """
    prompt: str
    negative_prompt: Optional[str] = ""
    height: Optional[int] = 512
    width: Optional[int] = 512
    num_inference_steps: Optional[int] = 30
    guidance_scale: Optional[float] = 7.5
class ImageResponse(BaseModel):
    """Response body for POST /generate: the PNG as a base64 data URI
    (``data:image/png;base64,...``), as produced by generate_image."""
    image_base64: str
@app.post("/generate", response_model=ImageResponse)
async def generate_image_api(request: ImageRequest):
    """Generate an image from a prompt and return it base64-encoded.

    Maps input-validation failures (e.g. dimensions not divisible by 8)
    to HTTP 400, and unexpected generation failures to HTTP 500 —
    the original returned 500 for both, mislabeling client errors.
    """
    try:
        # The PIL image is discarded; the API returns only the data URI.
        _, base64_string = generate_image(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            height=request.height,
            width=request.width,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale
        )
    except ValueError as e:
        # Bad input is the client's fault — signal it as such.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return ImageResponse(image_base64=base64_string)
# NOTE(review): despite the comment below, the FastAPI `app` is never
# actually mounted onto the Gradio server here — only the Gradio demo is
# launched, so the /generate route is unreachable unless `app` is served
# separately (e.g. via gr.mount_gradio_app or uvicorn). TODO: confirm intent.
# Mount the FastAPI app to the Gradio app
# queue() enables request queuing; share=True opens a public share link.
demo.queue().launch(share=True)