# interior-ai-api / app.py
# (HuggingFace Space header — ebraam1's picture / "Update app.py" / commit e5559c2 verified)
from fastapi import FastAPI, UploadFile, File
from pydantic import BaseModel
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
from huggingface_hub import hf_hub_download
from fastapi.responses import StreamingResponse
from PIL import Image
import torch
import io
app = FastAPI()

# =========================
# Download model weights from the HuggingFace Hub
# =========================
MODEL_PATH = hf_hub_download(
    repo_id="ebraam1/interior-sd-models",
    filename="Interior.safetensors",
)
LORA_PATH = hf_hub_download(
    repo_id="ebraam1/interior-sd-models",
    filename="Interior_lora.safetensors",
)

# fp16 inference is only reliably supported on CUDA; running the pipelines
# with torch.float16 on CPU commonly fails or yields NaN/black images.
# Pick the dtype together with the device instead of hard-coding both.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

print("Loading base model...")
# =========================
# Load Stable Diffusion (shared checkpoint for both pipelines)
# =========================
txt2img = StableDiffusionPipeline.from_single_file(
    MODEL_PATH,
    torch_dtype=DTYPE,
    safety_checker=None,
).to(DEVICE)
img2img = StableDiffusionImg2ImgPipeline.from_single_file(
    MODEL_PATH,
    torch_dtype=DTYPE,
    safety_checker=None,
).to(DEVICE)

print("Loading LoRA...")
# Apply the interior-style LoRA to both pipelines.
txt2img.load_lora_weights(LORA_PATH)
img2img.load_lora_weights(LORA_PATH)
print("LoRA loaded 🔥")
# =========================
# API Schema
# =========================
class Prompt(BaseModel):
    """Request body for the /txt2img endpoint."""

    # Text prompt forwarded to the txt2img pipeline.
    prompt: str
# =========================
# Helper: PIL โ†’ Bytes
# =========================
def to_bytes(img):
    """Serialize *img* (a PIL image) to an in-memory PNG stream.

    Returns an ``io.BytesIO`` rewound to position 0, ready for streaming.
    """
    stream = io.BytesIO()
    img.save(stream, format="PNG")
    stream.seek(0)
    return stream
# =========================
# TEXT โ†’ IMAGE
# =========================
@app.post("/txt2img")
def generate(data: Prompt):
    """Generate an image from the text prompt and stream it back as PNG."""
    result = txt2img(
        data.prompt,
        cross_attention_kwargs={"scale": 0.8},  # LoRA influence weight
    )
    first_image = result.images[0]
    return StreamingResponse(to_bytes(first_image), media_type="image/png")
# =========================
# IMAGE โ†’ IMAGE
# =========================
@app.post("/img2img")
async def img2img_api(file: UploadFile = File(...), prompt: str = ""):
    """Restyle an uploaded image according to *prompt*; streams the result as PNG."""
    raw = await file.read()
    # Pipeline expects RGB at the SD base resolution of 512x512.
    source = Image.open(io.BytesIO(raw)).convert("RGB").resize((512, 512))
    result = img2img(
        prompt=prompt,
        image=source,
        cross_attention_kwargs={"scale": 0.8},  # LoRA influence weight
    )
    return StreamingResponse(to_bytes(result.images[0]), media_type="image/png")