"""Gradio demo that pastes a foreground onto a new background and relights it.

The script loads the ``jasperai/LBM_relighting`` checkpoint and the BiRefNet
segmentation model at import time, then exposes a small Gradio UI around
``evaluate``.
"""

import glob
import os
from copy import deepcopy

import gradio as gr
import numpy as np
import PIL
import spaces
import torch
import yaml
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
from torchvision.transforms import ToPILImage, ToTensor
from transformers import AutoModelForImageSegmentation

from utils import extract_object, get_model_from_config, resize_and_center_crop

# Optional Hub token; ``os.getenv`` yields None when the variable is unset.
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

# Supported working resolutions, keyed by the string form of first/second.
# The tuple items are passed to ``resize_and_center_crop`` in this order
# (presumably (width, height) -- confirm against utils.resize_and_center_crop).
ASPECT_RATIOS = {
    str(512 / 2048): (512, 2048),
    str(1024 / 1024): (1024, 1024),
    str(2048 / 512): (2048, 512),
    str(896 / 1152): (896, 1152),
    str(1152 / 896): (1152, 896),
    str(512 / 1920): (512, 1920),
    str(640 / 1536): (640, 1536),
    str(768 / 1280): (768, 1280),
    str(1280 / 768): (1280, 768),
    str(1536 / 640): (1536, 640),
    str(1920 / 512): (1920, 512),
}

MODEL_PATH = hf_hub_download("jasperai/LBM_relighting", "model.safetensors", token=huggingface_token)
CONFIG_PATH = hf_hub_download("jasperai/LBM_relighting", "config.yaml", token=huggingface_token)

# Explicit encoding so config parsing does not depend on the host locale.
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

model = get_model_from_config(**config)
sd = load_file(MODEL_PATH)
model.load_state_dict(sd, strict=True)
model.to("cuda").to(torch.bfloat16)

# NOTE(review): trust_remote_code=True executes code shipped with the Hub
# repository; consider pinning a revision.
birefnet = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True).cuda()
image_size = (1024, 1024)


@spaces.GPU
def evaluate(fg_image: PIL.Image.Image, bg_image: PIL.Image.Image, num_sampling_steps: int = 4):
    """Composite the foreground onto the background and relight the result.

    Args:
        fg_image: Image containing the object to keep.
        bg_image: Image used as the new background.
        num_sampling_steps: Number of steps forwarded to ``model.sample``.

    Returns:
        A ``(before, after)`` tuple of numpy arrays: the plain composite and
        the model output, both scaled to the original size of ``fg_image``.

    Raises:
        gr.Error: If either image is missing.
    """
    if fg_image is None or bg_image is None:
        raise gr.Error("Ngarko imazhin kryesor dhe sfondin e ri para se të vazhdosh.")

    # PIL reports ``size`` as (width, height).
    orig_width, orig_height = fg_image.size
    aspect_ratio = orig_width / orig_height

    # Pick the supported resolution whose ratio is nearest to the input's.
    closest_key = min(ASPECT_RATIOS, key=lambda key: abs(float(key) - aspect_ratio))
    target_dims = ASPECT_RATIOS[closest_key]

    # The mask is computed on a copy so the caller's image is left untouched.
    _, fg_mask = extract_object(birefnet, deepcopy(fg_image))

    fg_image = resize_and_center_crop(fg_image, target_dims[0], target_dims[1])
    fg_mask = resize_and_center_crop(fg_mask, target_dims[0], target_dims[1])
    bg_image = resize_and_center_crop(bg_image, target_dims[0], target_dims[1])

    img_pasted = Image.composite(fg_image, bg_image, fg_mask)

    # Map pixel values from [0, 1] to [-1, 1] and add a batch dimension.
    img_pasted_tensor = ToTensor()(img_pasted).unsqueeze(0) * 2 - 1
    batch = {"source_image": img_pasted_tensor.cuda().to(torch.bfloat16)}

    z_source = model.vae.encode(batch[model.source_key])
    output_image = model.sample(
        z=z_source,
        num_steps=num_sampling_steps,
        conditioner_inputs=batch,
        max_samples=1,
    ).clamp(-1, 1)

    # Back to [0, 1] on the CPU, then to a PIL image.
    output_image = (output_image[0].float().cpu() + 1) / 2
    output_image = ToPILImage()(output_image)

    # Keep the untouched background pixels; only the masked region is relit.
    output_image = Image.composite(output_image, bg_image, fg_mask)

    # ``Image.resize`` returns a new image, so the result must be assigned.
    # Both images are scaled identically so the before/after slider stays
    # aligned. If the input ratio is not exactly one of the supported ones,
    # this scaling slightly stretches the center-cropped result.
    original_size = (orig_width, orig_height)
    img_pasted = img_pasted.resize(original_size)
    output_image = output_image.resize(original_size)

    return (np.array(img_pasted), np.array(output_image))


with gr.Blocks() as app:
    gr.HTML(""" """)
    gr.Markdown("# Ndrysho Sfondin")
    gr.Markdown("Zëvendëso sfondin e fotove me rindriçim të avancuar nga inteligjenca artificiale.")

    with gr.Row():
        with gr.Column():
            with gr.Row():
                fg_image = gr.Image(type="pil", label="Imazhi Kryesor", image_mode="RGB", height=360)
                bg_image = gr.Image(type="pil", label="Sfondi i Ri", image_mode="RGB", height=360)

            with gr.Row():
                submit_button = gr.Button("Rindriço")

            with gr.Row():
                # Hidden control: the step count stays at its default of 4.
                num_inference_steps = gr.Slider(minimum=1, maximum=4, value=4, step=1, visible=False)

            # Hidden gallery; selecting an item copies it into ``bg_image``.
            bg_gallery = gr.Gallery(object_fit="contain", visible=False)

        with gr.Column():
            output_slider = gr.ImageSlider(label="Para / Pas", type="numpy")
            output_slider.upload(
                fn=evaluate,
                inputs=[fg_image, bg_image, num_inference_steps],
                outputs=[output_slider],
            )

    submit_button.click(
        evaluate,
        inputs=[fg_image, bg_image, num_inference_steps],
        outputs=[output_slider],
        show_progress="full",
        show_api=False,
    )

    def bg_gallery_selected(gal, evt: gr.SelectData):
        """Return the first element of the gallery entry that was selected."""
        return gal[evt.index][0]

    bg_gallery.select(bg_gallery_selected, inputs=bg_gallery, outputs=bg_image)


if __name__ == "__main__":
    app.launch(share=True)