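This Space packs everything into a single `app.py`: it loads Stable Diffusion XL, wraps it in a Gradio UI, and exposes a JSON API on top. The code pulls in diffusers, Gradio, and FastAPI, so the Space also needs a `requirements.txt`; a plausible sketch (the exact dependency set is an assumption, not part of the original listing):

```text
torch
diffusers
transformers
accelerate
safetensors
gradio
fastapi
pydantic
uvicorn
nest_asyncio
```

The first part of `app.py` picks a device and loads the SDXL pipeline: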
```python
# app.py
import base64
import io

import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline

# Check if CUDA is available, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize the SDXL pipeline
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    use_safetensors=True,
    variant="fp16" if device == "cuda" else None,
)
pipe = pipe.to(device)

# Reduce memory usage via attention slicing when running on CUDA
if device == "cuda":
    pipe.enable_attention_slicing()
```
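The generation function validates the requested size, runs the pipeline, and returns the image twice: as a PIL object for the UI and as a base64 data URI for API consumers: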
```python
def generate_image(prompt, negative_prompt="", height=512, width=512,
                   num_inference_steps=30, guidance_scale=7.5):
    """Generate an image from a text prompt."""
    # Gradio sliders may deliver these as floats; the pipeline expects ints
    height, width = int(height), int(width)
    num_inference_steps = int(num_inference_steps)

    # SDXL requires dimensions divisible by 8
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError("Height and width must be divisible by 8")

    # Run the diffusion pipeline
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]

    # Encode the PIL image as a base64 PNG data URI
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return image, f"data:image/png;base64,{img_str}"
```
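A Gradio Blocks interface then wires sliders and textboxes to that function and documents how to call the Space over HTTP: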
````python
# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Text-to-Image Generator API")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Things you don't want in the image...")
            with gr.Row():
                height = gr.Slider(minimum=256, maximum=1024, step=8, value=512, label="Height")
                width = gr.Slider(minimum=256, maximum=1024, step=8, value=512, label="Width")
            with gr.Row():
                steps = gr.Slider(minimum=10, maximum=50, step=1, value=30, label="Inference Steps")
                guidance = gr.Slider(minimum=1, maximum=15, step=0.1, value=7.5, label="Guidance Scale")
            generate_btn = gr.Button("Generate")
        with gr.Column():
            output_image = gr.Image(label="Generated Image")
            output_json = gr.Textbox(label="Image Base64", show_copy_button=True)

    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, negative_prompt, height, width, steps, guidance],
        outputs=[output_image, output_json],
    )

    gr.Markdown("""
## API Usage
You can use this as an API with this curl command:
```bash
curl -X POST "https://your-username-text-to-image-api.hf.space/api/predict" \\
  -H "Content-Type: application/json" \\
  -d '{
    "data": [
      "A beautiful sunset over mountains",
      "",
      512,
      512,
      30,
      7.5
    ]
  }'
```
""")
````
```python
# Create a FastAPI app for direct API usage
from typing import Optional

import nest_asyncio
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

nest_asyncio.apply()  # allow re-entrant event loops (e.g. when run from a notebook)

app = FastAPI()

class ImageRequest(BaseModel):
    prompt: str
    negative_prompt: Optional[str] = ""
    height: Optional[int] = 512
    width: Optional[int] = 512
    num_inference_steps: Optional[int] = 30
    guidance_scale: Optional[float] = 7.5

class ImageResponse(BaseModel):
    image_base64: str

@app.post("/generate", response_model=ImageResponse)
async def generate_image_api(request: ImageRequest):
    try:
        _, base64_string = generate_image(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            height=request.height,
            width=request.width,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
        )
        return ImageResponse(image_base64=base64_string)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Mount the Gradio app onto the FastAPI app so one server handles both
demo.queue()
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
```
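To round things off, here is a minimal client sketch, assuming the `/generate` route defined above and the same placeholder hostname as the curl example; `requests` is the only dependency:

```python
# client.py - minimal sketch; assumes the /generate route defined above
import base64

import requests

SPACE_URL = "https://your-username-text-to-image-api.hf.space"  # placeholder

resp = requests.post(
    f"{SPACE_URL}/generate",
    json={
        "prompt": "A beautiful sunset over mountains",
        "height": 512,
        "width": 512,
        "num_inference_steps": 30,
        "guidance_scale": 7.5,
    },
    timeout=300,  # SDXL generation can take minutes, especially on CPU
)
resp.raise_for_status()

# The field holds a data URI; the base64 payload starts after the first comma
data_uri = resp.json()["image_base64"]
payload = data_uri.split(",", 1)[1]
with open("output.png", "wb") as f:
    f.write(base64.b64decode(payload))
```

Note the generous timeout: besides generation time, the first request to a sleeping Space also pays a cold-start penalty while the model loads.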