Spaces:
Runtime error
Runtime error
"""LBM relighting — shared by the Gradio app (`app.py`) and the REST API (`api_server.py`)."""
| from __future__ import annotations | |
| import os | |
| from copy import deepcopy | |
| from typing import TYPE_CHECKING, Tuple | |
| import numpy as np | |
| import torch | |
| import yaml | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image | |
| from safetensors.torch import load_file | |
| from torchvision.transforms import ToPILImage, ToTensor | |
| from transformers import AutoModelForImageSegmentation | |
| from utils import extract_object, get_model_from_config, resize_and_center_crop | |
| if TYPE_CHECKING: | |
| import PIL.Image | |
| def _vram_max_side() -> int | None: | |
| """Giới hạn cạnh dài nhất (px) theo VRAM hoặc LBM_MAX_SIDE để tránh OOM.""" | |
| raw = os.environ.get("LBM_MAX_SIDE", "").strip() | |
| if raw.isdigit(): | |
| return max(64, int(raw)) | |
| if not torch.cuda.is_available(): | |
| return None | |
| gb = torch.cuda.get_device_properties(0).total_memory / (1024.0**3) | |
| if gb < 4.0: | |
| return 512 | |
| if gb < 8.0: | |
| return 768 | |
| return None | |
| def _cap_wh(width: int, height: int, max_side: int) -> tuple[int, int]: | |
| if max(width, height) <= max_side: | |
| return width, height | |
| if width >= height: | |
| nw = max_side | |
| nh = max(8, round(height * max_side / width)) | |
| else: | |
| nh = max_side | |
| nw = max(8, round(width * max_side / height)) | |
| nw = max(8, (nw // 8) * 8) | |
| nh = max(8, (nh // 8) * 8) | |
| return nw, nh | |
# Supported working resolutions as (width, height), keyed by the string form of
# their width/height ratio. `relight` picks the key whose float value is
# closest to the foreground image's ratio.
ASPECT_RATIOS = {
    str(w / h): (w, h)
    for w, h in (
        (512, 2048),
        (1024, 1024),
        (2048, 512),
        (896, 1152),
        (1152, 896),
        (512, 1920),
        (640, 1536),
        (768, 1280),
        (1280, 768),
        (1536, 640),
        (1920, 512),
    )
}
def load_lbm_and_segmenter(hf_token: str | None = None):
    """Download and build the LBM relighting model and the BiRefNet segmenter.

    Intended to be called once at startup.

    Args:
        hf_token: Hugging Face access token. When ``None``, the
            ``HUGGINGFACE_TOKEN`` environment variable is used (which may itself
            be unset).

    Returns:
        ``(model, birefnet)``. ``model`` is the LBM model, moved to CUDA and cast
        to bfloat16. ``birefnet`` is the BiRefNet segmentation model, moved to
        CUDA with its dtype left as loaded. Both are switched to eval mode.
    """
    import logging

    token = hf_token if hf_token is not None else os.getenv("HUGGINGFACE_TOKEN")
    model_path = hf_hub_download(
        "jasperai/LBM_relighting", "model.safetensors", token=token
    )
    config_path = hf_hub_download(
        "jasperai/LBM_relighting", "config.yaml", token=token
    )
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    model = get_model_from_config(**config)
    state_dict = load_file(model_path)
    model.load_state_dict(state_dict, strict=True)
    # load_state_dict copies the weights into the model's own parameters.
    # Release the CPU copy now, so it is not held while BiRefNet is loaded.
    del state_dict
    model.to("cuda").to(torch.bfloat16)

    # Best-effort VAE memory savers (they lower peak VRAM on small GPUs).
    # A VAE that lacks one of them is logged and skipped, never fatal.
    vae = getattr(model.vae, "vae_model", None)
    if vae is not None:
        for method_name in ("enable_slicing", "enable_tiling"):
            try:
                getattr(vae, method_name)()
            except Exception as exc:
                logging.getLogger(__name__).warning(
                    "VAE %s() unavailable, skipping: %s", method_name, exc
                )

    birefnet = AutoModelForImageSegmentation.from_pretrained(
        "ZhengPeng7/BiRefNet", trust_remote_code=True
    ).cuda()
    model.eval()
    birefnet.eval()
    return model, birefnet
def relight(
    fg_image: "PIL.Image.Image",
    bg_image: "PIL.Image.Image",
    *,
    model,
    birefnet,
    num_sampling_steps: int = 1,
) -> Tuple[np.ndarray, np.ndarray]:
    """Relight a foreground subject so that it matches a new background.

    Steps:

    1. Segment the subject with ``birefnet``.
    2. Paste it onto ``bg_image`` at the supported working resolution whose
       aspect ratio is closest to the foreground's. That resolution is
       optionally capped via ``_vram_max_side``.
    3. Run the composite through the LBM model.

    Args:
        fg_image: Image containing the subject. It is converted to RGB first,
            so RGBA, palette and grayscale inputs are accepted.
        bg_image: Target background. It is also converted to RGB.
        model: LBM model returned by ``load_lbm_and_segmenter`` (CUDA, bf16).
        birefnet: BiRefNet segmenter returned by ``load_lbm_and_segmenter``.
        num_sampling_steps: Number of LBM sampling steps.

    Returns:
        ``(composite_rgb, relit_rgb)`` as ``numpy`` uint8 ``H x W x 3`` RGB
        arrays. ``composite_rgb`` is the plain cut-and-paste at the working
        resolution. ``relit_rgb`` is resized back to the foreground's original
        size.
    """
    # Image.composite() requires both images to share one mode, and ToTensor()
    # on RGBA would give a 4-channel tensor. Normalize both inputs to RGB.
    fg_image = fg_image.convert("RGB")
    bg_image = bg_image.convert("RGB")

    # PIL's .size is (width, height).
    orig_w, orig_h = fg_image.size
    aspect = orig_w / orig_h
    closest = min(ASPECT_RATIOS, key=lambda key: abs(float(key) - aspect))
    width, height = ASPECT_RATIOS[closest]
    max_side = _vram_max_side()
    if max_side is not None:
        width, height = _cap_wh(width, height, max_side)

    # Pure inference: do not record an autograd graph or keep activations
    # alive. This matters on the low-VRAM GPUs this module caps for.
    with torch.no_grad():
        # Segment a copy, in case extract_object modifies its input.
        _, fg_mask = extract_object(birefnet, fg_image.copy())
        fg_image = resize_and_center_crop(fg_image, width, height)
        fg_mask = resize_and_center_crop(fg_mask, width, height)
        bg_image = resize_and_center_crop(bg_image, width, height)
        img_pasted = Image.composite(fg_image, bg_image, fg_mask)

        # ToTensor() yields values in [0, 1]; rescale to [-1, 1] for the model.
        source = ToTensor()(img_pasted).unsqueeze(0) * 2 - 1
        batch = {"source_image": source.cuda().to(torch.bfloat16)}

        torch.cuda.empty_cache()
        z_source = model.vae.encode(batch[model.source_key])
        torch.cuda.empty_cache()
        output = model.sample(
            z=z_source,
            num_steps=num_sampling_steps,
            conditioner_inputs=batch,
            max_samples=1,
        ).clamp(-1, 1)

    # Map [-1, 1] back to [0, 1] for ToPILImage().
    relit = ToPILImage()((output[0].float().cpu() + 1) / 2)
    # Keep the original background pixels outside the subject mask.
    relit = Image.composite(relit, bg_image, fg_mask)
    relit = relit.resize((orig_w, orig_h), Image.LANCZOS)
    return np.array(img_pasted), np.array(relit)