# fastapi / main.py
# Author: SAUL19
# Last commit: a1195f3 ("Update main.py")
from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware
import torch
from torch.cuda.amp import autocast
from diffusers import DiffusionPipeline
from diffusers import StableDiffusionPipeline
from io import BytesIO
app = FastAPI()

# Allow browser clients on any origin to call the API.
# Credentials (cookies / HTTP auth) are deliberately disabled:
# - The endpoint in this file takes only a `prompt` query parameter, so it
#   needs no cookies or auth.
# - FastAPI's CORS documentation states that a wildcard `allow_origins=["*"]`
#   must not be combined with `allow_credentials=True`.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Model: Stable Diffusion v1.4; downloaded weights are cached under ./cache.
# Alternative previously tried here: "stabilityai/stable-diffusion-xl-base-1.0",
# loaded via DiffusionPipeline.from_pretrained(..., use_safetensors=True).
model_id = "CompVis/stable-diffusion-v1-4"

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Loaded once at import time and shared by every request handled below.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id, torch_dtype=torch.float32, cache_dir="./cache"
).to(device)
@app.get("/")
def generate(prompt: str):
    """Generate an image for ``prompt`` and return it as a PNG.

    Args:
        prompt: Text prompt, taken from the ``prompt`` query parameter.

    Returns:
        A ``Response`` whose body is the PNG-encoded image, with media type
        ``image/png``.
    """
    # torch.cuda.amp.autocast's first positional parameter is ``enabled``, not a
    # device, so ``autocast(device)`` passed the device string as a truthy flag.
    # torch.autocast takes the device type explicitly. Mixed precision is enabled
    # only on CUDA; on CPU autocast is turned off.
    with torch.autocast(device_type=device, enabled=(device == "cuda")):
        image = pipe(prompt, guidance_scale=8.5).images[0]

    # Encode in memory. getvalue() returns the whole buffer regardless of the
    # stream position, so no seek(0) is needed.
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return Response(content=buffer.getvalue(), media_type="image/png")