Spaces:
Build error
Build error
| import glob | |
| import os | |
| from copy import deepcopy | |
| import gradio as gr | |
| import numpy as np | |
| import PIL | |
| import spaces | |
| import torch | |
| import yaml | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image | |
| from safetensors.torch import load_file | |
| from torchvision.transforms import ToPILImage, ToTensor | |
| from transformers import AutoModelForImageSegmentation | |
| from utils import extract_object, get_model_from_config, resize_and_center_crop | |
# Optional Hugging Face Hub access token, read from the environment
# (None when the variable is not set).
huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")

# Resolution buckets the pipeline works at. Each pair is passed positionally
# to resize_and_center_crop by evaluate().
_RESOLUTION_BUCKETS = (
    (512, 2048),
    (1024, 1024),
    (2048, 512),
    (896, 1152),
    (1152, 896),
    (512, 1920),
    (640, 1536),
    (768, 1280),
    (1280, 768),
    (1536, 640),
    (1920, 512),
)

# Maps str(first / second) -> (first, second). evaluate() converts the keys
# back to float to find the bucket closest to an input image's ratio.
ASPECT_RATIOS = {
    str(first / second): (first, second)
    for first, second in _RESOLUTION_BUCKETS
}
# Download the relighting checkpoint and its config from the Hugging Face Hub.
# huggingface_token may be None when HUGGINGFACE_TOKEN is not set.
MODEL_PATH = hf_hub_download("jasperai/LBM_relighting", "model.safetensors", token=huggingface_token)
CONFIG_PATH = hf_hub_download("jasperai/LBM_relighting", "config.yaml", token=huggingface_token)

# Decode the YAML explicitly as UTF-8 instead of relying on the platform's
# default locale encoding.
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Build the model from the config, then load the weights. strict=True makes
# load_state_dict fail on any missing or unexpected key.
model = get_model_from_config(**config)
sd = load_file(MODEL_PATH)
model.load_state_dict(sd, strict=True)
# The return value is discarded, so this relies on .to() converting the model
# in place (true for torch.nn.Module; the model's class is not visible here).
model.to("cuda").to(torch.bfloat16)

# Segmentation model handed to extract_object() in evaluate().
# NOTE(review): trust_remote_code=True runs Python code shipped in the Hub
# repository; consider pinning a revision so that code cannot change silently.
birefnet = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True).cuda()

# NOTE(review): not referenced anywhere in the visible part of this file.
image_size = (1024, 1024)
# NOTE(review): `spaces` is imported at the top of the file but never used in
# the visible code. If this Space runs on ZeroGPU hardware this function
# presumably needs the `@spaces.GPU` decorator - confirm before deploying.
def evaluate(fg_image: PIL.Image.Image, bg_image: PIL.Image.Image, num_sampling_steps: int = 4):
    """Paste the foreground subject onto a new background and relight it.

    The foreground is segmented with ``extract_object(birefnet, ...)``, both
    images are brought to the ``ASPECT_RATIOS`` bucket closest to the
    foreground's ratio, composited, and passed through ``model``. Only the
    masked subject is taken from the model output; the rest of the result is
    the unmodified background.

    Args:
        fg_image: PIL image containing the subject.
        bg_image: PIL image used as the new background.
        num_sampling_steps: Forwarded to ``model.sample`` as ``num_steps``.

    Returns:
        A ``(composite, relit)`` tuple of numpy arrays: the plain composite
        and the relit result. Both stay at the bucket resolution so the
        before/after pair lines up pixel for pixel.

    Raises:
        gr.Error: If either input image is missing.
    """
    # Gradio passes None for an empty image component; report that to the user
    # instead of failing with an AttributeError further down.
    if fg_image is None or bg_image is None:
        raise gr.Error("Ngarko imazhin kryesor dhe sfondin e ri para se të vazhdosh.")

    # PIL's Image.size is (width, height).
    fg_width, fg_height = fg_image.size
    fg_ratio = fg_width / fg_height
    closest_key = min(ASPECT_RATIOS, key=lambda key: abs(float(key) - fg_ratio))
    target_dims = ASPECT_RATIOS[closest_key]

    # The mask is computed on a copy, so the caller's image is not touched
    # whatever extract_object does with its argument.
    _, fg_mask = extract_object(birefnet, deepcopy(fg_image))

    # Bring foreground, mask and background to the same bucket so they can be
    # composited.
    fg_image = resize_and_center_crop(fg_image, target_dims[0], target_dims[1])
    fg_mask = resize_and_center_crop(fg_mask, target_dims[0], target_dims[1])
    bg_image = resize_and_center_crop(bg_image, target_dims[0], target_dims[1])

    img_pasted = Image.composite(fg_image, bg_image, fg_mask)

    # ToTensor yields values in [0, 1]; rescale to [-1, 1] and add a batch axis.
    img_pasted_tensor = ToTensor()(img_pasted).unsqueeze(0) * 2 - 1
    batch = {"source_image": img_pasted_tensor.cuda().to(torch.bfloat16)}

    # Pure inference: do not record an autograd graph.
    with torch.no_grad():
        z_source = model.vae.encode(batch[model.source_key])
        output_image = model.sample(
            z=z_source,
            num_steps=num_sampling_steps,
            conditioner_inputs=batch,
            max_samples=1,
        ).clamp(-1, 1)

    # Map [-1, 1] back to [0, 1] in float32 on the CPU before converting to PIL.
    output_image = (output_image[0].float().cpu() + 1) / 2
    output_image = ToPILImage()(output_image)

    # Keep relit pixels only where the subject mask is set; everywhere else
    # use the original background.
    output_image = Image.composite(output_image, bg_image, fg_mask)

    # The previous code ended with `output_image.resize(...)` and discarded the
    # result. Image.resize returns a new image, so that call never changed the
    # output and only cost a full resample. It is dropped here; the returned
    # images are the same as before.
    return (np.array(img_pasted), np.array(output_image))
# Gradio UI: two image inputs and a button on the left, a before/after slider
# on the right. The user-facing strings are in Albanian.
with gr.Blocks() as app:
    # Inline stylesheet: reserves a 320px band at the top of the page, hides
    # the "Fullscreen" and "Share" buttons, and scales the "Download" button 3x.
    gr.HTML("""
<style>
body::before {
content: "";
display: block;
height: 320px;
background-color: var(--body-background-fill);
}
button[aria-label="Fullscreen"], button[aria-label="Fullscreen"]:hover {
display: none !important;
visibility: hidden !important;
opacity: 0 !important;
pointer-events: none !important;
}
button[aria-label="Share"], button[aria-label="Share"]:hover {
display: none !important;
}
button[aria-label="Download"] {
transform: scale(3);
transform-origin: top right;
margin: 0 !important;
padding: 6px !important;
}
</style>
""")
    # Page heading and one-line description.
    gr.Markdown("# Ndrysho Sfondin")
    gr.Markdown("Zëvendëso sfondin e fotove me rindriçim të avancuar nga inteligjenca artificiale.")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                # Both inputs reach evaluate() as RGB PIL images.
                fg_image = gr.Image(type="pil", label="Imazhi Kryesor", image_mode="RGB", height=360)
                bg_image = gr.Image(type="pil", label="Sfondi i Ri", image_mode="RGB", height=360)
            with gr.Row():
                submit_button = gr.Button("Rindriço")
            with gr.Row():
                # Hidden control: the UI always sends its value (4) as the step count.
                num_inference_steps = gr.Slider(minimum=1, maximum=4, value=4, step=1, visible=False)
            # Hidden gallery with no images assigned in the visible code.
            bg_gallery = gr.Gallery(object_fit="contain", visible=False)
        with gr.Column():
            # Receives the (composite, relit) pair returned by evaluate().
            output_slider = gr.ImageSlider(label="Para / Pas", type="numpy")
            # Also runs evaluate() when the slider's upload event fires.
            output_slider.upload(fn=evaluate, inputs=[fg_image, bg_image, num_inference_steps], outputs=[output_slider])
    submit_button.click(evaluate, inputs=[fg_image, bg_image, num_inference_steps], outputs=[output_slider], show_progress="full", show_api=False)

    def bg_gallery_selected(gal, evt: gr.SelectData):
        # Return the first element of the selected gallery entry
        # (presumably the image of an (image, caption) pair - TODO confirm).
        return gal[evt.index][0]

    # Selecting a gallery entry copies it into the background input.
    bg_gallery.select(bg_gallery_selected, inputs=bg_gallery, outputs=bg_image)
# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    # share=True asks Gradio for a public share link (see the launch() docs).
    app.launch(share=True)