""" REST API: 2 ảnh vào (foreground + background) → ảnh relight ra. Chạy local / Colab: uvicorn api_server:app --host 0.0.0.0 --port 8000 POST /relight (multipart/form-data) - foreground: file ảnh (jpg/png/…) - background: file ảnh - num_steps: int 1–4 (optional, default 1) Trả về: PNG (ảnh đã relight). Query `?include_composite=true` → JSON base64 (composite + relit). """ from __future__ import annotations import base64 import io import os os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1") from contextlib import asynccontextmanager from typing import Any, Dict import torch import uvicorn from fastapi import FastAPI, File, Form, HTTPException, Query, UploadFile from fastapi.responses import JSONResponse, Response from PIL import Image from relight_engine import load_lbm_and_segmenter, relight _model = None _birefnet = None @asynccontextmanager async def lifespan(app: FastAPI): global _model, _birefnet if not torch.cuda.is_available(): raise RuntimeError( "Cần GPU NVIDIA và PyTorch đã build với CUDA (torch.cuda.is_available() phải True). " f"Hiện tại: torch {torch.__version__}, cuda={torch.version.cuda!s}, " f"cuda_available={torch.cuda.is_available()}. " "Gỡ torch CPU-only rồi cài bản CUDA từ https://pytorch.org (chọn đúng phiên bản CUDA với driver); " "trên Windows chạy `nvidia-smi` để kiểm tra GPU. Không có GPU → chạy API trên Colab." ) print("Đang tải LBM + BiRefNet…") _model, _birefnet = load_lbm_and_segmenter() print("Sẵn sàng:", torch.cuda.get_device_name(0)) yield app = FastAPI(title="LBM Relighting API", lifespan=lifespan) async def _read_image(upload: UploadFile) -> Image.Image: raw = await upload.read() if not raw: raise HTTPException(400, f"File rỗng: {upload.filename}") return Image.open(io.BytesIO(raw)).convert("RGB") @app.get("/health") def health() -> Dict[str, Any]: return { "ok": True, "device": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None, } @app.post("/relight") async def relight_endpoint( foreground: UploadFile = File(..., description="Ảnh chủ thể (portrait)"), background: UploadFile = File(..., description="Ảnh nền mục tiêu"), num_steps: int = Form(1, ge=1, le=4), include_composite: bool = Query( False, description="True → JSON base64 (composite + relit); False → chỉ PNG relit" ), ): if _model is None or _birefnet is None: raise HTTPException(503, "Model chưa sẵn sàng") try: fg = await _read_image(foreground) bg = await _read_image(background) except HTTPException: raise except Exception as e: raise HTTPException(400, f"Không đọc được ảnh: {e}") from e try: comp, out = relight( fg, bg, model=_model, birefnet=_birefnet, num_sampling_steps=num_steps, ) except torch.OutOfMemoryError as e: torch.cuda.empty_cache() raise HTTPException( 507, "Hết VRAM GPU. Thử đặt LBM_MAX_SIDE=384 (hoặc 512), đóng app khác dùng GPU, " "hoặc chạy trên máy/Colab có VRAM lớn hơn.", ) from e if include_composite is True: buf_c = io.BytesIO() Image.fromarray(comp).save(buf_c, format="PNG") buf_o = io.BytesIO() Image.fromarray(out).save(buf_o, format="PNG") return JSONResponse( { "composite_png_base64": base64.b64encode(buf_c.getvalue()).decode(), "relit_png_base64": base64.b64encode(buf_o.getvalue()).decode(), "num_steps": num_steps, } ) buf = io.BytesIO() Image.fromarray(out).save(buf, format="PNG") return Response(content=buf.getvalue(), media_type="image/png") if __name__ == "__main__": uvicorn.run( "api_server:app", host="0.0.0.0", port=int(os.environ.get("PORT", "8000")), reload=False, )