File size: 861 Bytes
2844088
 
 
 
 
cff8cd2
356ca95
 
 
cff8cd2
814e862
 
2844088
814e862
cff8cd2
 
2844088
 
 
356ca95
 
 
 
814e862
2844088
 
814e862
356ca95
2844088
356ca95
2844088
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import os
# Redirect all Hugging Face caches to a writable scratch directory. This has to
# happen before diffusers (and, through it, huggingface_hub/transformers) is
# imported below, because those libraries read these variables at import time.
# NOTE(review): recent transformers releases deprecate TRANSFORMERS_CACHE in
# favour of HF_HOME; it is kept here for older versions — consider dropping it.
os.environ.update(
    dict.fromkeys(
        ("HF_HOME", "TRANSFORMERS_CACHE", "HF_DATASETS_CACHE"),
        "/tmp/hf_cache",
    )
)

import base64
import threading
from io import BytesIO

import torch
from diffusers import StableDiffusionPipeline
from fastapi import FastAPI
from PIL import Image
from pydantic import BaseModel

app = FastAPI()

# Hugging Face Hub repository (or local directory) holding the checkpoint.
# The original "runwayml/stable-diffusion-v1-5" repository is no longer
# available on the Hub; the same v1.5 weights are published under
# "stable-diffusion-v1-5/stable-diffusion-v1-5". Set SD_MODEL_ID to load a
# different checkpoint without editing the code.
MODEL_ID = os.environ.get(
    "SD_MODEL_ID", "stable-diffusion-v1-5/stable-diffusion-v1-5"
)

# Loaded once at import time and shared by every request; inference runs on
# the CPU. The first start downloads the weights into the cache configured above.
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID)
pipe = pipe.to("cpu")

class Prompt(BaseModel):
    """JSON request body accepted by the ``POST /generate`` endpoint."""

    # Text prompt passed unchanged to the diffusion pipeline.
    text: str

@app.get("/")
def greet():
    """Return a static readiness message.

    The pipeline is built at module import time, so this handler can only be
    reached after loading has finished.
    """
    ready_payload = {"message": "Model ready"}
    return ready_payload

# Serialises use of the shared module-level pipeline. FastAPI runs plain
# ``def`` path operations in a worker threadpool, so without this lock
# concurrent requests would drive the same pipeline object (and its scheduler,
# which keeps per-run timestep state) from several threads at once.
_PIPE_LOCK = threading.Lock()


@app.post("/generate")
def generate(prompt: Prompt):
    """Generate one image for ``prompt.text`` and return it as a base64 PNG.

    Args:
        prompt: Request body; ``prompt.text`` is passed to the pipeline as-is.

    Returns:
        A dict with ``image_base64`` (the first generated image, PNG-encoded
        and then base64-encoded into a str) and ``prompt`` (the text used).

    NOTE(review): the pipeline output's NSFW flags are not reported; with the
    default safety checker a flagged image may come back blacked out.
    """
    # Only one request at a time may run the pipeline; see _PIPE_LOCK.
    with _PIPE_LOCK:
        image = pipe(prompt.text).images[0]

    # Encode outside the lock so other requests can start generating meanwhile.
    with BytesIO() as buffer:
        image.save(buffer, format="PNG")
        img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return {"image_base64": img_str, "prompt": prompt.text}