|
|
import os |
|
|
os.environ["HF_HOME"] = "/tmp/hf_cache" |
|
|
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache" |
|
|
os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_cache" |
|
|
|
|
|
from fastapi import FastAPI |
|
|
from pydantic import BaseModel |
|
|
from diffusers import StableDiffusionPipeline |
|
|
import torch |
|
|
|
|
|
from io import BytesIO |
|
|
from PIL import Image |
|
|
import base64 |
|
|
|
|
|
# FastAPI application instance; the route handlers below register on it.
app = FastAPI()


# Load the Stable Diffusion v1.5 text-to-image pipeline once at import time,
# so the first request does not pay the model download/initialization cost.
# NOTE(review): this downloads multi-GB weights from the Hugging Face Hub on
# first run — requires network access and space under /tmp/hf_cache. The
# "runwayml/stable-diffusion-v1-5" repo may no longer resolve on the Hub
# (it was reportedly migrated) — confirm the model id still works.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")


# Run inference on CPU explicitly; no GPU is assumed to be available here.
pipe = pipe.to("cpu")
|
|
|
|
|
class Prompt(BaseModel):
    """Request body for POST /generate: the text prompt to render."""

    # Natural-language prompt forwarded verbatim to the diffusion pipeline.
    text: str
|
|
|
|
|
@app.get("/") |
|
|
def greet(): |
|
|
return {"message": "Model ready"} |
|
|
|
|
|
@app.post("/generate") |
|
|
def generate(prompt: Prompt): |
|
|
image = pipe(prompt.text).images[0] |
|
|
buffer = BytesIO() |
|
|
image.save(buffer, format="PNG") |
|
|
img_str = base64.b64encode(buffer.getvalue()).decode("utf-8") |
|
|
return {"image_base64": img_str, "prompt": prompt.text} |
|
|
|