"""LBM relighting — dùng chung cho Gradio (`app.py`) và REST API (`api_server.py`).""" from __future__ import annotations import os from copy import deepcopy from typing import TYPE_CHECKING, Tuple import numpy as np import torch import yaml from huggingface_hub import hf_hub_download from PIL import Image from safetensors.torch import load_file from torchvision.transforms import ToPILImage, ToTensor from transformers import AutoModelForImageSegmentation from utils import extract_object, get_model_from_config, resize_and_center_crop if TYPE_CHECKING: import PIL.Image def _vram_max_side() -> int | None: """Giới hạn cạnh dài nhất (px) theo VRAM hoặc LBM_MAX_SIDE để tránh OOM.""" raw = os.environ.get("LBM_MAX_SIDE", "").strip() if raw.isdigit(): return max(64, int(raw)) if not torch.cuda.is_available(): return None gb = torch.cuda.get_device_properties(0).total_memory / (1024.0**3) if gb < 4.0: return 512 if gb < 8.0: return 768 return None def _cap_wh(width: int, height: int, max_side: int) -> tuple[int, int]: if max(width, height) <= max_side: return width, height if width >= height: nw = max_side nh = max(8, round(height * max_side / width)) else: nh = max_side nw = max(8, round(width * max_side / height)) nw = max(8, (nw // 8) * 8) nh = max(8, (nh // 8) * 8) return nw, nh ASPECT_RATIOS = { str(512 / 2048): (512, 2048), str(1024 / 1024): (1024, 1024), str(2048 / 512): (2048, 512), str(896 / 1152): (896, 1152), str(1152 / 896): (1152, 896), str(512 / 1920): (512, 1920), str(640 / 1536): (640, 1536), str(768 / 1280): (768, 1280), str(1280 / 768): (1280, 768), str(1536 / 640): (1536, 640), str(1920 / 512): (1920, 512), } def load_lbm_and_segmenter(hf_token: str | None = None): """Tải LBM + BiRefNet một lần; trả về (model, birefnet) trên CUDA bf16.""" token = hf_token if hf_token is not None else os.getenv("HUGGINGFACE_TOKEN") model_path = hf_hub_download( "jasperai/LBM_relighting", "model.safetensors", token=token ) config_path = hf_hub_download( "jasperai/LBM_relighting", "config.yaml", token=token ) with open(config_path, "r", encoding="utf-8") as f: config = yaml.safe_load(f) model = get_model_from_config(**config) sd = load_file(model_path) model.load_state_dict(sd, strict=True) model.to("cuda").to(torch.bfloat16) if hasattr(model.vae, "vae_model"): vae = model.vae.vae_model try: vae.enable_slicing() except Exception: pass try: vae.enable_tiling() except Exception: pass birefnet = AutoModelForImageSegmentation.from_pretrained( "ZhengPeng7/BiRefNet", trust_remote_code=True ).cuda() model.eval() birefnet.eval() return model, birefnet @torch.inference_mode() def relight( fg_image: "PIL.Image.Image", bg_image: "PIL.Image.Image", *, model, birefnet, num_sampling_steps: int = 1, ) -> Tuple[np.ndarray, np.ndarray]: """ Trả về (composite_rgb, relit_rgb), numpy uint8 HxWx3 RGB. """ ori_h_bg, ori_w_bg = fg_image.size ar_bg = ori_h_bg / ori_w_bg closest_ar_bg = min(ASPECT_RATIOS, key=lambda x: abs(float(x) - ar_bg)) w_bg, h_bg = ASPECT_RATIOS[closest_ar_bg] max_side = _vram_max_side() if max_side is not None: w_bg, h_bg = _cap_wh(w_bg, h_bg, max_side) dimensions_bg = (w_bg, h_bg) _, fg_mask = extract_object(birefnet, deepcopy(fg_image)) fg_image = resize_and_center_crop(fg_image, dimensions_bg[0], dimensions_bg[1]) fg_mask = resize_and_center_crop(fg_mask, dimensions_bg[0], dimensions_bg[1]) bg_image = resize_and_center_crop(bg_image, dimensions_bg[0], dimensions_bg[1]) img_pasted = Image.composite(fg_image, bg_image, fg_mask) img_pasted_tensor = ToTensor()(img_pasted).unsqueeze(0) * 2 - 1 batch = {"source_image": img_pasted_tensor.cuda().to(torch.bfloat16)} torch.cuda.empty_cache() z_source = model.vae.encode(batch[model.source_key]) torch.cuda.empty_cache() output_image = model.sample( z=z_source, num_steps=num_sampling_steps, conditioner_inputs=batch, max_samples=1, ).clamp(-1, 1) output_image = (output_image[0].float().cpu() + 1) / 2 output_image = ToPILImage()(output_image) output_image = Image.composite(output_image, bg_image, fg_mask) output_image = output_image.resize((ori_h_bg, ori_w_bg), Image.LANCZOS) return np.array(img_pasted), np.array(output_image)