|
|
import asyncio
import math
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import torch
import uvicorn
from diffusers import AutoPipelineForImage2Image
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import Response, StreamingResponse
from PIL import Image
|
|
|
|
|
app = FastAPI()

# Run on the GPU whenever torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load SDXL Turbo once at import time and move it to the chosen device.
# Half precision is only requested on CUDA; the CPU path keeps float32.
pipe = AutoPipelineForImage2Image.from_pretrained(
    "stabilityai/sdxl-turbo",
    torch_dtype=(torch.float16 if device == "cuda" else torch.float32),
).to(device)
|
|
|
|
|
# Fixed generation settings shared by every /cartoonize request.
PROMPT: str = "Transform the subject or image into cartoon style high quality"
STEPS: int = 2  # requested num_inference_steps (the endpoint may raise it)
STRENGTH: float = 0.65  # img2img strength passed to the pipeline
GUIDANCE: float = 1.0  # guidance_scale passed to the pipeline
OUTPUT_SIZE: int = 1024  # side length in pixels of the square pipeline input
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def resize_image(img: Image.Image, size=512):
    """Return a copy of *img* stretched to a ``size`` x ``size`` square.

    The aspect ratio is not preserved. Bicubic resampling is used.
    """
    target_dims = (size, size)
    return img.resize(target_dims, resample=Image.BICUBIC)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/")
def root():
    """Health check: report that the service is up."""
    status_message = "SDXL Turbo Cartoon API is running"
    return {"status": status_message}
|
|
|
|
|
# A single worker thread owns every pipeline call. This keeps inference off the
# event loop and queues requests instead of running them concurrently. The
# diffusers pipeline (and its scheduler) is one shared mutable object that is
# not documented as thread-safe, so calls must not overlap.
_PIPE_EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="sdxl-pipe")


def _decode_upload(image_bytes: bytes) -> Image.Image:
    """Decode uploaded bytes into an RGB image resized to an OUTPUT_SIZE square.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    try:
        # Image.open is lazy; convert() forces the full decode, so truncated
        # or corrupt data is rejected here as well.
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError (unknown format) is a subclass of OSError.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image"
        ) from err
    return resize_image(image, OUTPUT_SIZE)


def _effective_steps(steps: int, strength: float) -> int:
    """Return *steps*, raised to ceil(1 / strength) when int(steps * strength) < 1.

    SDXL Turbo img2img requires num_inference_steps * strength >= 1 so that
    at least one denoising step actually runs.
    """
    if int(steps * strength) < 1:
        return math.ceil(1 / strength)
    return steps


def _cartoonize_bytes(image_bytes: bytes) -> bytes:
    """Decode *image_bytes*, run the cartoon pipeline, and return PNG bytes.

    This call blocks, so it is meant to run on _PIPE_EXECUTOR only.
    """
    input_image = _decode_upload(image_bytes)
    steps = _effective_steps(STEPS, STRENGTH)

    # inference_mode is thread-local, so enter it on the worker thread that
    # actually calls the pipeline.
    with torch.inference_mode():
        result = pipe(
            PROMPT,
            image=input_image,
            strength=STRENGTH,
            guidance_scale=GUIDANCE,
            num_inference_steps=steps,
        ).images[0]

    buffer = BytesIO()
    result.save(buffer, format="PNG")
    return buffer.getvalue()


@app.post("/cartoonize")
async def cartoonize_image(file: UploadFile = File(...)):
    """Cartoonize an uploaded image with SDXL Turbo and return it as a PNG.

    The upload is decoded, converted to RGB and resized to an OUTPUT_SIZE square.
    It is then passed through the img2img pipeline with the fixed PROMPT,
    STRENGTH and GUIDANCE settings.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    image_bytes = await file.read()

    # Decoding and inference block. Running them on the dedicated executor
    # keeps the event loop free to serve other requests meanwhile.
    loop = asyncio.get_running_loop()
    png_bytes = await loop.run_in_executor(
        _PIPE_EXECUTOR, _cartoonize_bytes, image_bytes
    )

    # The PNG is already fully in memory, so send it in one Response (which
    # sets Content-Length) instead of streaming a BytesIO line by line.
    return Response(
        content=png_bytes,
        media_type="image/png",
        headers={"Content-Disposition": "inline; filename=cartoon.png"},
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # When executed as a script, serve the app on every interface at port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")