# lbm-lightning / relight_engine.py
# Uploaded by tranduy17023 ("Upload 7 files"), commit 2bbfba4.
"""LBM relighting — shared by the Gradio app (`app.py`) and the REST API (`api_server.py`)."""
from __future__ import annotations

import logging
import os
from copy import deepcopy
from typing import TYPE_CHECKING, Tuple

import numpy as np
import torch
import yaml
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
from torchvision.transforms import ToPILImage, ToTensor
from transformers import AutoModelForImageSegmentation

from utils import extract_object, get_model_from_config, resize_and_center_crop

if TYPE_CHECKING:
    import PIL.Image
def _vram_max_side() -> int | None:
    """Return a cap (px) for the longest processing side, or None for no cap.

    Used to avoid out-of-memory errors on small GPUs. Resolution order:

    1. ``LBM_MAX_SIDE`` environment variable, when it is a plain decimal
       integer (surrounding whitespace ignored); clamped to at least 64.
    2. Otherwise, when CUDA is available, a cap derived from the total memory
       of the current CUDA device: 512 below 4 GiB, 768 below 8 GiB.
    3. Otherwise None.
    """
    raw = os.environ.get("LBM_MAX_SIDE", "").strip()
    # isdecimal(), not isdigit(): isdigit() also accepts characters such as
    # "²" that int() rejects, which made a bad env value raise ValueError.
    if raw.isdecimal():
        return max(64, int(raw))
    if not torch.cuda.is_available():
        return None
    # Query the current device, i.e. the one `.to("cuda")` places models on.
    device = torch.cuda.current_device()
    total_gib = torch.cuda.get_device_properties(device).total_memory / (1024.0**3)
    if total_gib < 4.0:
        return 512
    if total_gib < 8.0:
        return 768
    return None
def _cap_wh(width: int, height: int, max_side: int) -> tuple[int, int]:
    """Scale (width, height) down so the longer side is at most ``max_side``.

    Sizes that already fit are returned unchanged. Otherwise the longer side
    becomes ``max_side`` and the shorter one keeps the aspect ratio; both are
    then rounded down to a multiple of 8, with a floor of 8.
    """
    if max(width, height) <= max_side:
        return width, height

    def _snap(value: int) -> int:
        # Round down to a multiple of 8, never going below 8.
        return max(8, value - value % 8)

    if width < height:
        scaled_width = max(8, round(width * max_side / height))
        return _snap(scaled_width), _snap(max_side)

    scaled_height = max(8, round(height * max_side / width))
    return _snap(max_side), _snap(scaled_height)
# Supported (width, height) processing sizes, keyed by str(width / height).
# relight() picks the entry whose ratio is closest to the input's ratio.
ASPECT_RATIOS = {
    str(w / h): (w, h)
    for w, h in (
        (512, 2048),
        (1024, 1024),
        (2048, 512),
        (896, 1152),
        (1152, 896),
        (512, 1920),
        (640, 1536),
        (768, 1280),
        (1280, 768),
        (1536, 640),
        (1920, 512),
    )
}
def load_lbm_and_segmenter(hf_token: str | None = None):
    """Load the LBM relighting model and the BiRefNet segmenter (call once).

    Args:
        hf_token: Hugging Face token used to download the LBM weights/config.
            When None, the ``HUGGINGFACE_TOKEN`` environment variable is used.

    Returns:
        ``(model, birefnet)``. ``model`` is moved to CUDA and cast to
        bfloat16. ``birefnet`` is moved to CUDA with its dtype left unchanged.
        Both are put in eval mode.
    """
    token = hf_token if hf_token is not None else os.getenv("HUGGINGFACE_TOKEN")
    model_path = hf_hub_download(
        "jasperai/LBM_relighting", "model.safetensors", token=token
    )
    config_path = hf_hub_download(
        "jasperai/LBM_relighting", "config.yaml", token=token
    )
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    model = get_model_from_config(**config)
    state_dict = load_file(model_path)
    model.load_state_dict(state_dict, strict=True)
    # load_state_dict copies the weights into the model; dropping the loaded
    # dict before the GPU move can lower peak host memory.
    del state_dict
    model.to("cuda").to(torch.bfloat16)

    if hasattr(model.vae, "vae_model"):
        vae = model.vae.vae_model
        # Optional memory savers (diffusers-style VAE API, presumably). A
        # backend without them must not abort loading, so failures are
        # logged and skipped instead of silently ignored.
        for method_name in ("enable_slicing", "enable_tiling"):
            try:
                getattr(vae, method_name)()
            except Exception as exc:  # deliberate best-effort
                logging.getLogger(__name__).debug(
                    "VAE %s() unavailable, skipping: %s", method_name, exc
                )

    birefnet = AutoModelForImageSegmentation.from_pretrained(
        "ZhengPeng7/BiRefNet", trust_remote_code=True
    ).cuda()
    model.eval()
    birefnet.eval()
    return model, birefnet
def _processing_size(width: int, height: int) -> tuple[int, int]:
    """Return the (width, height) at which a width x height image is relit.

    Picks the ASPECT_RATIOS entry whose ratio is closest to width / height,
    then shrinks it with _cap_wh when _vram_max_side() reports a limit.
    """
    ratio = width / height
    bucket_key = min(ASPECT_RATIOS, key=lambda key: abs(float(key) - ratio))
    proc_width, proc_height = ASPECT_RATIOS[bucket_key]
    max_side = _vram_max_side()
    if max_side is not None:
        proc_width, proc_height = _cap_wh(proc_width, proc_height, max_side)
    return proc_width, proc_height


@torch.inference_mode()
def relight(
    fg_image: "PIL.Image.Image",
    bg_image: "PIL.Image.Image",
    *,
    model,
    birefnet,
    num_sampling_steps: int = 1,
) -> Tuple[np.ndarray, np.ndarray]:
    """Paste the segmented foreground onto the background and relight it.

    Args:
        fg_image: Foreground image whose subject is segmented with BiRefNet.
            Any PIL mode is accepted; it is converted to RGB.
        bg_image: Background image; converted to RGB.
        model: LBM model as returned by ``load_lbm_and_segmenter``.
        birefnet: BiRefNet segmenter as returned by ``load_lbm_and_segmenter``.
        num_sampling_steps: Passed to ``model.sample`` as ``num_steps``.

    Returns:
        ``(composite_rgb, relit_rgb)``, both uint8 HxWx3 RGB numpy arrays.
        ``composite_rgb`` is the naive paste at the processing resolution
        (an ASPECT_RATIOS size, possibly capped by ``_vram_max_side``).
        ``relit_rgb`` is resized back to the original foreground size, so the
        two arrays may differ in shape.

    Raises:
        ValueError: if ``fg_image`` has zero width or height.
    """
    # Image.composite requires both images to share a mode, and ToTensor on
    # RGBA would give the VAE 4 channels. Uploads may be RGBA/P/L (e.g. PNGs
    # sent to the REST API), so normalise to RGB up front.
    fg_image = fg_image.convert("RGB")
    bg_image = bg_image.convert("RGB")

    # PIL's .size is (width, height).
    orig_width, orig_height = fg_image.size
    if orig_width <= 0 or orig_height <= 0:
        raise ValueError(f"fg_image must be non-empty, got size {fg_image.size}")
    proc_width, proc_height = _processing_size(orig_width, orig_height)

    # Segment a copy, so nothing extract_object does affects fg_image below.
    _, fg_mask = extract_object(birefnet, deepcopy(fg_image))
    fg_image = resize_and_center_crop(fg_image, proc_width, proc_height)
    fg_mask = resize_and_center_crop(fg_mask, proc_width, proc_height)
    bg_image = resize_and_center_crop(bg_image, proc_width, proc_height)

    # Naive paste of the subject onto the background: the LBM input.
    img_pasted = Image.composite(fg_image, bg_image, fg_mask)
    # Add a batch dimension and map pixel values from [0, 1] to [-1, 1].
    img_pasted_tensor = ToTensor()(img_pasted).unsqueeze(0) * 2 - 1
    # NOTE(review): the batch is keyed "source_image" but read back through
    # model.source_key; these presumably match for the LBM_relighting config.
    batch = {"source_image": img_pasted_tensor.cuda().to(torch.bfloat16)}

    torch.cuda.empty_cache()
    z_source = model.vae.encode(batch[model.source_key])
    torch.cuda.empty_cache()
    output = model.sample(
        z=z_source,
        num_steps=num_sampling_steps,
        conditioner_inputs=batch,
        max_samples=1,
    ).clamp(-1, 1)

    # Back from [-1, 1] to [0, 1] on the CPU, then to a PIL image.
    relit = ToPILImage()((output[0].float().cpu() + 1) / 2)
    # Keep the relit subject; outside the mask, use the resized background.
    relit = Image.composite(relit, bg_image, fg_mask)
    # NOTE(review): resize_and_center_crop presumably crops to the chosen
    # aspect ratio, so resizing back may slightly stretch the result; confirm.
    relit = relit.resize((orig_width, orig_height), Image.LANCZOS)
    return np.array(img_pasted), np.array(relit)