|
|
import torch |
|
|
from diffusers import StableDiffusionPipeline |
|
|
from fastapi import FastAPI, Response |
|
|
from fastapi.responses import JSONResponse |
|
|
from pydantic import BaseModel |
|
|
import io |
|
|
import base64 |
|
|
from typing import Optional |
|
|
import uvicorn |
|
|
import os |
|
|
|
|
|
|
|
|
# Redirect Hugging Face cache locations to /tmp, which is writable in
# restricted containers (e.g. HF Spaces or read-only root filesystems).
# `setdefault` (instead of unconditional assignment) respects any cache
# paths already provided by the deployment environment.
# NOTE(review): these are set AFTER `transformers`/`diffusers` were imported
# above; some HF libraries read cache env vars at import time, so these may
# not take effect unless moved before the imports — verify.
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/transformers_cache")
os.environ.setdefault("HF_HOME", "/tmp/hf_home")
os.environ.setdefault("HF_DATASETS_CACHE", "/tmp/hf_datasets_cache")
|
|
|
|
|
|
|
|
# FastAPI application instance; the title is shown in the auto-generated
# OpenAPI docs served at /docs.
app = FastAPI(title="Stable Diffusion API")
|
|
|
|
|
|
|
|
class TextToImageRequest(BaseModel):
    """Request body schema for POST /generate.

    Pydantic validates and coerces the incoming JSON. Only ``prompt`` is
    required; every other field falls back to the default shown below.
    """

    # Text description of the desired image (required).
    prompt: str

    # Concepts to steer the generation away from; None disables it.
    negative_prompt: Optional[str] = None

    # Number of denoising steps; more steps = higher quality, slower.
    num_inference_steps: Optional[int] = 50

    # Classifier-free guidance strength; higher follows the prompt more closely.
    guidance_scale: Optional[float] = 7.5

    # Output image dimensions in pixels.
    height: Optional[int] = 512

    width: Optional[int] = 512

    # RNG seed for reproducible generation; None means non-deterministic.
    seed: Optional[int] = None
|
|
|
|
|
|
|
|
# Hugging Face Hub repository id of the Stable Diffusion v1.4 checkpoint.
model_id = "CompVis/stable-diffusion-v1-4"

# Prefer GPU when available; CPU generation works but is very slow.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Download (on first run) and load the full text-to-image pipeline.
# `token` is read from the HF_TOKEN environment variable and may be None
# for anonymous access — presumably needed for gated model terms; verify.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    cache_dir="/tmp/diffusers_cache",
    token=os.environ.get("HF_TOKEN")
)
# Move all pipeline weights to the selected device.
pipe = pipe.to(device)
|
|
|
|
|
@app.get("/")
def read_root():
    """Health-check / landing endpoint that points clients at /generate."""
    info = {"message": "Stable Diffusion API is running. Use POST /generate endpoint."}
    return info
|
|
|
|
|
@app.post("/generate")
def generate_image(request: TextToImageRequest):
    """Generate an image from a text prompt and return it base64-encoded.

    Declared as a plain ``def`` (not ``async def``) deliberately: the
    diffusion pipeline is a long, blocking CPU/GPU computation, and FastAPI
    runs sync endpoints in a worker threadpool, keeping the event loop
    responsive. With ``async def`` the blocking ``pipe(...)`` call would
    stall every other request for the entire generation.

    Returns:
        JSONResponse with the PNG image as a base64 string plus the
        parameters that were used; on any failure, HTTP 500 with the
        error message.
    """
    try:
        # A seeded generator makes the output reproducible; otherwise let
        # the pipeline draw its own random initial noise.
        if request.seed is not None:
            generator = torch.Generator(device=device).manual_seed(request.seed)
        else:
            generator = None

        # Run the full diffusion loop; `.images` is a list of PIL images,
        # one per prompt — a single prompt yields a single image.
        image = pipe(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            height=request.height,
            width=request.width,
            generator=generator
        ).images[0]

        # Serialize the PIL image to PNG and base64-encode it so it can
        # travel inside a JSON payload.
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")

        return JSONResponse({
            "status": "success",
            "image": img_str,
            "parameters": {
                "prompt": request.prompt,
                "negative_prompt": request.negative_prompt,
                "steps": request.num_inference_steps,
                "guidance_scale": request.guidance_scale,
                "dimensions": f"{request.width}x{request.height}",
                "seed": request.seed
            }
        })

    except Exception as e:
        # Boundary handler: surface any failure as an HTTP 500 instead of
        # an unhandled server error. NOTE(review): `str(e)` can leak
        # internal details to clients — consider logging the exception and
        # returning a generic message in production.
        return JSONResponse(
            status_code=500,
            content={"status": "error", "message": str(e)}
        )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Standalone entry point. Host 0.0.0.0 exposes the server outside the
    # container; the port defaults to 7860 (the HF Spaces convention) but
    # can be overridden via the PORT environment variable.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))