File size: 4,206 Bytes
78f372e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""
REST API: 2 ảnh vào (foreground + background) → ảnh relight ra.

Chạy local / Colab:
  uvicorn api_server:app --host 0.0.0.0 --port 8000

POST /relight  (multipart/form-data)
  - foreground: file ảnh (jpg/png/…)
  - background: file ảnh
  - num_steps: int 1–4 (optional, default 1)

Trả về: PNG (ảnh đã relight). Query `?include_composite=true` → JSON base64 (composite + relit).
"""

from __future__ import annotations

import base64
import io
import os

os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")

from contextlib import asynccontextmanager
from typing import Any, Dict

import torch
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, Query, UploadFile
from fastapi.responses import JSONResponse, Response
from PIL import Image

from relight_engine import load_lbm_and_segmenter, relight

_model = None
_birefnet = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    global _model, _birefnet
    if not torch.cuda.is_available():
        raise RuntimeError(
            "Cần GPU NVIDIA và PyTorch đã build với CUDA (torch.cuda.is_available() phải True). "
            f"Hiện tại: torch {torch.__version__}, cuda={torch.version.cuda!s}, "
            f"cuda_available={torch.cuda.is_available()}. "
            "Gỡ torch CPU-only rồi cài bản CUDA từ https://pytorch.org (chọn đúng phiên bản CUDA với driver); "
            "trên Windows chạy `nvidia-smi` để kiểm tra GPU. Không có GPU → chạy API trên Colab."
        )
    print("Đang tải LBM + BiRefNet…")
    _model, _birefnet = load_lbm_and_segmenter()
    print("Sẵn sàng:", torch.cuda.get_device_name(0))
    yield


app = FastAPI(title="LBM Relighting API", lifespan=lifespan)


async def _read_image(upload: UploadFile) -> Image.Image:
    raw = await upload.read()
    if not raw:
        raise HTTPException(400, f"File rỗng: {upload.filename}")
    return Image.open(io.BytesIO(raw)).convert("RGB")


@app.get("/health")
def health() -> Dict[str, Any]:
    return {
        "ok": True,
        "device": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
    }


@app.post("/relight")
async def relight_endpoint(
    foreground: UploadFile = File(..., description="Ảnh chủ thể (portrait)"),
    background: UploadFile = File(..., description="Ảnh nền mục tiêu"),
    num_steps: int = Form(1, ge=1, le=4),
    include_composite: bool = Query(
        False, description="True → JSON base64 (composite + relit); False → chỉ PNG relit"
    ),
):
    if _model is None or _birefnet is None:
        raise HTTPException(503, "Model chưa sẵn sàng")

    try:
        fg = await _read_image(foreground)
        bg = await _read_image(background)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(400, f"Không đọc được ảnh: {e}") from e

    try:
        comp, out = relight(
            fg,
            bg,
            model=_model,
            birefnet=_birefnet,
            num_sampling_steps=num_steps,
        )
    except torch.OutOfMemoryError as e:
        torch.cuda.empty_cache()
        raise HTTPException(
            507,
            "Hết VRAM GPU. Thử đặt LBM_MAX_SIDE=384 (hoặc 512), đóng app khác dùng GPU, "
            "hoặc chạy trên máy/Colab có VRAM lớn hơn.",
        ) from e

    if include_composite is True:
        buf_c = io.BytesIO()
        Image.fromarray(comp).save(buf_c, format="PNG")
        buf_o = io.BytesIO()
        Image.fromarray(out).save(buf_o, format="PNG")
        return JSONResponse(
            {
                "composite_png_base64": base64.b64encode(buf_c.getvalue()).decode(),
                "relit_png_base64": base64.b64encode(buf_o.getvalue()).decode(),
                "num_steps": num_steps,
            }
        )

    buf = io.BytesIO()
    Image.fromarray(out).save(buf, format="PNG")
    return Response(content=buf.getvalue(), media_type="image/png")


if __name__ == "__main__":
    uvicorn.run(
        "api_server:app",
        host="0.0.0.0",
        port=int(os.environ.get("PORT", "8000")),
        reload=False,
    )