"""
REST API: 2 ảnh vào (foreground + background) → ảnh relight ra.
Chạy local / Colab:
uvicorn api_server:app --host 0.0.0.0 --port 8000
POST /relight (multipart/form-data)
- foreground: file ảnh (jpg/png/…)
- background: file ảnh
- num_steps: int 1–4 (optional, default 1)
Trả về: PNG (ảnh đã relight). Query `?include_composite=true` → JSON base64 (composite + relit).
"""
from __future__ import annotations

import base64
import io
import os
import threading
from contextlib import asynccontextmanager
from typing import Any, Dict

# Must be set before huggingface_hub is imported (indirectly, via the
# third-party / local imports below).
os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")

import torch
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, Query, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse, Response
from PIL import Image

from relight_engine import load_lbm_and_segmenter, relight
# Model handles shared by all requests. They are populated once by ``lifespan``
# at startup; ``None`` means the models are not (or no longer) loaded.
_model = None
_birefnet = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load LBM + BiRefNet at startup and release them at shutdown.

    Raises:
        RuntimeError: if PyTorch cannot see a CUDA device. This is raised
            during startup, before any request is served.
    """
    global _model, _birefnet
    if not torch.cuda.is_available():
        raise RuntimeError(
            "Cần GPU NVIDIA và PyTorch đã build với CUDA (torch.cuda.is_available() phải True). "
            f"Hiện tại: torch {torch.__version__}, cuda={torch.version.cuda!s}, "
            f"cuda_available={torch.cuda.is_available()}. "
            "Gỡ torch CPU-only rồi cài bản CUDA từ https://pytorch.org (chọn đúng phiên bản CUDA với driver); "
            "trên Windows chạy `nvidia-smi` để kiểm tra GPU. Không có GPU → chạy API trên Colab."
        )
    print("Đang tải LBM + BiRefNet…")
    _model, _birefnet = load_lbm_and_segmenter()
    print("Sẵn sàng:", torch.cuda.get_device_name(0))
    try:
        yield
    finally:
        # Drop our references so the weights can be freed, then return the
        # allocator's unused cached blocks. Handlers that run afterwards see
        # None and answer 503.
        _model = None
        _birefnet = None
        torch.cuda.empty_cache()


app = FastAPI(title="LBM Relighting API", lifespan=lifespan)
async def _read_image(upload: UploadFile) -> Image.Image:
    """Read an uploaded file completely and decode it as an RGB PIL image.

    Raises:
        HTTPException: 400 when the upload contains no bytes. Decoding errors
            from PIL propagate unchanged so the caller can map them.
    """
    payload = await upload.read()
    if not payload:
        raise HTTPException(400, f"File rỗng: {upload.filename}")
    decoded = Image.open(io.BytesIO(payload))
    # convert() forces PIL's lazy decoder to actually read the pixel data.
    return decoded.convert("RGB")
@app.get("/health")
def health() -> Dict[str, Any]:
    """Liveness probe: always ``ok``, plus the name of CUDA device 0 if any."""
    status: Dict[str, Any] = {"ok": True, "device": None}
    if torch.cuda.is_available():
        status["device"] = torch.cuda.get_device_name(0)
    return status
# Inference runs in worker threads (see relight_endpoint). This lock keeps GPU
# use to one request at a time, as it was when the call blocked the event
# loop. That limits VRAM pressure, and the model is not known to be
# thread-safe.
_gpu_lock = threading.Lock()


def _png_bytes(array) -> bytes:
    """PNG-encode an image array (anything ``Image.fromarray`` accepts)."""
    buf = io.BytesIO()
    Image.fromarray(array).save(buf, format="PNG")
    return buf.getvalue()


def _relight_blocking(fg: Image.Image, bg: Image.Image, num_steps: int):
    """Run ``relight`` under the GPU lock. Call this off the event loop.

    Returns ``relight``'s ``(composite, relit)`` pair. On CUDA OOM, the
    allocator cache is emptied before the error is re-raised.
    """
    with _gpu_lock:
        try:
            return relight(
                fg,
                bg,
                model=_model,
                birefnet=_birefnet,
                num_sampling_steps=num_steps,
            )
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()
            raise


@app.post("/relight")
async def relight_endpoint(
    foreground: UploadFile = File(..., description="Ảnh chủ thể (portrait)"),
    background: UploadFile = File(..., description="Ảnh nền mục tiêu"),
    num_steps: int = Form(1, ge=1, le=4),
    include_composite: bool = Query(
        False, description="True → JSON base64 (composite + relit); False → chỉ PNG relit"
    ),
):
    """Relight ``foreground`` against ``background``.

    Returns the relit image as ``image/png``. When ``include_composite`` is
    true, it instead returns JSON with base64-encoded PNGs of the composite
    and relit images, plus ``num_steps``.

    Raises:
        HTTPException: 503 if the models are not loaded, 400 for empty or
            undecodable uploads, 507 when the GPU runs out of memory.
    """
    if _model is None or _birefnet is None:
        raise HTTPException(503, "Model chưa sẵn sàng")
    try:
        fg = await _read_image(foreground)
        bg = await _read_image(background)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(400, f"Không đọc được ảnh: {e}") from e

    # relight() is a blocking GPU call. Running it in the thread pool keeps
    # the event loop free (e.g. for /health) while inference is in progress.
    # torch.cuda.OutOfMemoryError exists on older PyTorch releases too, unlike
    # the newer device-agnostic torch.OutOfMemoryError alias.
    try:
        comp, out = await run_in_threadpool(_relight_blocking, fg, bg, num_steps)
    except torch.cuda.OutOfMemoryError as e:
        raise HTTPException(
            507,
            "Hết VRAM GPU. Thử đặt LBM_MAX_SIDE=384 (hoặc 512), đóng app khác dùng GPU, "
            "hoặc chạy trên máy/Colab có VRAM lớn hơn.",
        ) from e

    relit_png = await run_in_threadpool(_png_bytes, out)
    if include_composite:
        composite_png = await run_in_threadpool(_png_bytes, comp)
        return JSONResponse(
            {
                "composite_png_base64": base64.b64encode(composite_png).decode(),
                "relit_png_base64": base64.b64encode(relit_png).decode(),
                "num_steps": num_steps,
            }
        )
    return Response(content=relit_png, media_type="image/png")
if __name__ == "__main__":
    # Pass the app object, not the "api_server:app" import string. The string
    # form makes uvicorn import this file a second time under another module
    # name, and it fails outright if the file is not named api_server.py.
    # Uvicorn only needs an import string for reload/workers, both off here.
    # Port defaults to 8000 and can be overridden with $PORT.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.environ.get("PORT", "8000")),
        reload=False,
    )