| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from torchvision.transforms.functional import normalize |
| | from huggingface_hub import hf_hub_download |
| | import gradio as gr |
| | from gradio_imageslider import ImageSlider |
| | from briarmbg import BriaRMBG |
| | import PIL |
| | from PIL import Image |
| | from typing import Tuple |
| |
|
| | import os |
| | import requests |
| | from moviepy.editor import VideoFileClip |
| | from moviepy.audio.AudioClip import AudioClip |
| |
|
def _pick_photo_url(photo):
    """Return the preferred image URL of one Pexels photo record, or None.

    Prefers the ``large2x`` rendition, then ``large``, then ``original``.
    Records without a ``src`` mapping (or with none of those keys) yield None.
    """
    sources = photo.get('src') or {}
    for rendition in ('large2x', 'large', 'original'):
        if rendition in sources:
            return sources[rendition]
    return None


def search_pexels_images(query):
    """Search the Pexels photo API and return image URLs for *query*.

    The API key is read from the ``API_KEY`` environment variable.

    Args:
        query: Free-text search string typed by the user.

    Returns:
        A list of image URLs (up to 80, one per matching photo that exposes a
        usable rendition). An empty or blank query returns ``[]`` without
        calling the API.

    Raises:
        requests.HTTPError: if Pexels answers with a non-2xx status
            (for example when ``API_KEY`` is missing or invalid).
        requests.RequestException: on network failure or timeout.
    """
    if not query or not query.strip():
        return []

    api_key = os.getenv("API_KEY")
    response = requests.get(
        "https://api.pexels.com/v1/search",
        # Let requests URL-encode the query; interpolating it into the URL
        # breaks on spaces, '&', '#', non-ASCII text, etc.
        params={"query": query, "per_page": 80},
        headers={"Authorization": api_key},
        # Without a timeout a stalled connection would hang the UI callback.
        timeout=10,
    )
    # Surface auth/quota errors instead of silently showing an empty gallery.
    response.raise_for_status()
    data = response.json()

    images_urls = []
    for photo in data.get('photos', []):
        url = _pick_photo_url(photo)
        if url is not None:
            images_urls.append(url)
    return images_urls
| |
|
| |
|
def show_search_results(query):
    """Gradio callback for the image-search button.

    Forwards the search text to ``search_pexels_images`` and returns its list
    of image URLs unchanged, for display in the results gallery.
    """
    return search_pexels_images(query)
| | |
| |
|
# Build the RMBG-1.4 network once at import time and load its pretrained
# weights from the Hugging Face Hub; `process` uses this global `net`.
net = BriaRMBG()

model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
# NOTE(review): torch.load unpickles the downloaded checkpoint, and unpickling
# can execute code embedded in a malicious file. Only load from a trusted repo;
# on PyTorch versions that support it, consider torch.load(..., weights_only=True).
if not torch.cuda.is_available():
    # No GPU: remap every stored tensor onto the CPU while loading.
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
else:
    net.load_state_dict(torch.load(model_path))
    net = net.cuda()
# Evaluation mode: the network is only ever used for inference here.
net.eval()
| |
|
| | |
def resize_image(image, size=(1024, 1024)):
    """Convert *image* to RGB and resize it to the network's input resolution.

    Args:
        image: A PIL image in any mode.
        size: Target ``(width, height)``. Defaults to 1024x1024, the input
            size ``process`` feeds to the RMBG network.

    Returns:
        A new RGB PIL image of exactly *size*. The aspect ratio is not
        preserved.
    """
    # Image.Resampling.* is the same constant family merge() uses (Pillow >= 9.1).
    return image.convert('RGB').resize(size, Image.Resampling.BILINEAR)
| |
|
| |
|
def _to_model_tensor(pil_image):
    """Resize and normalize a PIL image into a (1, 3, H, W) float tensor for `net`."""
    rgb = resize_image(pil_image)
    pixels = np.array(rgb)  # (H, W, 3) uint8, since resize_image converts to RGB
    tensor = torch.tensor(pixels, dtype=torch.float32).permute(2, 0, 1)
    tensor = torch.unsqueeze(tensor, 0) / 255.0
    # Shift [0, 1] pixels to [-0.5, 0.5] per channel (mean 0.5, std 1.0).
    tensor = normalize(tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        tensor = tensor.cuda()
    return tensor


def _mask_from_prediction(prediction, size):
    """Turn the network output into an 8-bit grayscale PIL mask of *size* (w, h)."""
    w, h = size
    # NOTE(review): assumes prediction[0][0] is a 4-D (1, 1, H, W) map, as
    # F.interpolate(mode='bilinear') requires; confirm against BriaRMBG.forward.
    mask = torch.squeeze(F.interpolate(prediction[0][0], size=(h, w), mode='bilinear'), 0)
    hi = torch.max(mask)
    lo = torch.min(mask)
    if hi > lo:
        # Min-max stretch to the full [0, 1] range.
        mask = (mask - lo) / (hi - lo)
    else:
        # Constant prediction: min-max scaling would divide by zero and the
        # resulting NaNs cast to uint8 are undefined. Clamp the raw values
        # instead (presumably probabilities in [0, 1] - TODO confirm).
        mask = torch.clamp(mask, 0.0, 1.0)
    mask_u8 = (mask * 255).detach().cpu().numpy().astype(np.uint8)
    return Image.fromarray(np.squeeze(mask_u8))


def process(image):
    """Remove the background from *image* using the RMBG network.

    Args:
        image: The input picture as a PIL image or a numpy array (arrays are
            converted with ``Image.fromarray``). ``None``, meaning nothing was
            uploaded, yields ``None``.

    Returns:
        An RGBA PIL image the same size as the input. Each pixel's opacity is
        the normalized network prediction, so predicted background becomes
        transparent.
    """
    if image is None:
        return None
    orig_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image

    input_tensor = _to_model_tensor(orig_image)
    # Inference only: without no_grad, autograd would keep every intermediate
    # activation of the 1024x1024 forward pass alive in memory.
    with torch.no_grad():
        prediction = net(input_tensor)
        mask = _mask_from_prediction(prediction, orig_image.size)

    cutout = Image.new("RGBA", mask.size, (0, 0, 0, 0))
    cutout.paste(orig_image, mask=mask)
    return cutout
| |
|
# Placement labels offered by the "Merge" tab, mapped to (horizontal, vertical)
# anchors: 0 = left/top edge, 1 = centred, 2 = right/bottom edge.
# NOTE(review): the labels appear to be mis-decoded Korean text (the nine
# entries form a top/middle/bottom x left/centre/right grid). They must stay
# byte-identical to the gr.Radio choices, so they are kept exactly as-is.
_POSITION_ANCHORS = {
    "์๋จ ์ข์ธก": (0, 0),
    "์๋จ ๊ฐ์ด๋ฐ": (1, 0),
    "์๋จ ์ฐ์ธก": (2, 0),
    "์ค์ ์ข์ธก": (0, 1),
    "์ค์ ๊ฐ์ด๋ฐ": (1, 1),
    "์ค์ ์ฐ์ธก": (2, 1),
    "ํ๋จ ์ข์ธก": (0, 2),
    "ํ๋จ ๊ฐ์ด๋ฐ": (1, 2),
    "ํ๋จ ์ฐ์ธก": (2, 2),
}


def calculate_position(org_size, add_size, position):
    """Return the top-left paste offset that places a foreground in a background.

    Args:
        org_size: Background ``(width, height)``.
        add_size: Foreground ``(width, height)``.
        position: One of the placement labels in ``_POSITION_ANCHORS``.

    Returns:
        ``(x, y)`` offset. Centred placement uses floor division. Offsets are
        negative when the foreground is larger than the background.

    Raises:
        ValueError: if *position* is not a known label. Previously this fell
            through and returned None, which Pillow silently treats as (0, 0).
    """
    try:
        anchor_x, anchor_y = _POSITION_ANCHORS[position]
    except KeyError:
        raise ValueError(f"Unknown position: {position!r}") from None
    free_w = org_size[0] - add_size[0]
    free_h = org_size[1] - add_size[1]
    # Pick 0, half, or all of the free space along each axis.
    return ((0, free_w // 2, free_w)[anchor_x], (0, free_h // 2, free_h)[anchor_y])
| |
|
| |
|
def _parse_display_size(display_size):
    """Parse a "WIDTHxHEIGHT" string (e.g. "1024x768", "1024 X 768") into two positive ints."""
    parts = display_size.strip().lower().split('x')
    if len(parts) != 2:
        raise ValueError(f"Display size must look like '1024x768', got {display_size!r}")
    width, height = (int(part) for part in parts)  # int() tolerates surrounding spaces
    if width <= 0 or height <= 0:
        raise ValueError(f"Display size must be positive, got {display_size!r}")
    return width, height


def merge(org_image, add_image, scale, position, display_size):
    """Composite *add_image* onto *org_image*, then resize the result for display.

    Args:
        org_image: Background PIL image.
        add_image: Foreground PIL image. Its own bands are used as the paste
            mask, so it must carry alpha (the Merge tab requests
            ``image_mode='RGBA'``).
        scale: Foreground scale in percent (100 keeps its original size).
        position: A placement label understood by ``calculate_position``.
        display_size: Output size as "WIDTHxHEIGHT", e.g. "1024x768".

    Returns:
        An RGBA PIL image of exactly the requested display size. The composite
        is stretched to that size, not letterboxed. Returns ``None`` when
        either image is missing.

    Raises:
        ValueError: if *display_size* cannot be parsed.
    """
    if org_image is None or add_image is None:
        return None

    display_width, display_height = _parse_display_size(display_size)

    factor = scale / 100.0
    # Clamp to 1 px: a small foreground at a low scale would otherwise round
    # down to 0 and make Image.resize fail.
    new_size = (
        max(1, int(add_image.width * factor)),
        max(1, int(add_image.height * factor)),
    )
    add_image = add_image.resize(new_size, Image.Resampling.LANCZOS)

    offset = calculate_position(org_image.size, add_image.size, position)
    merged_image = Image.new("RGBA", org_image.size)
    merged_image.paste(org_image, (0, 0))
    # Passing the foreground as its own mask lets its transparent pixels keep
    # the background visible.
    merged_image.paste(add_image, offset, add_image)

    return merged_image.resize((display_width, display_height), Image.Resampling.LANCZOS)
| |
|
| |
|
# Gradio front end with three tabs, built when the module is imported:
#   * "Background Removal" -> process()
#   * "Merge"              -> merge()
#   * "Image Search"       -> show_search_results()
# NOTE(review): the Korean-looking strings below appear to be mis-decoded text.
# The gr.Radio choices must stay byte-identical to the labels compared in
# calculate_position(), so fix both together or not at all.
with gr.Blocks() as demo:
    with gr.Tab("Background Removal"):
        with gr.Column():
            # Heading text (mis-decoded Korean; the "(Nuking)" suffix is intact).
            gr.Markdown("๋๋ผ๋ฐ๊ธฐ์ ์ '๋ํน'(Nuking)")
            gr.HTML('''
<p style="margin-bottom: 10px; font-size: 94%">
This is a demo for BRIA RMBG 1.4 that using
<a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
</p>
''')
            # type="pil" hands process() a PIL image; its RGBA cut-out is shown in output_image.
            input_image = gr.Image(type="pil")
            output_image = gr.Image()
            process_button = gr.Button("Remove Background")
            process_button.click(fn=process, inputs=input_image, outputs=output_image)

    with gr.Tab("Merge"):
        with gr.Column():
            # Both inputs are requested as RGBA; merge() pastes the foreground
            # using itself as the mask, so it needs an alpha channel.
            org_image = gr.Image(label="Background", type='pil', image_mode='RGBA', height=400)
            add_image = gr.Image(label="Foreground", type='pil', image_mode='RGBA', height=400)
            # Percentage, passed to merge() as `scale`.
            scale = gr.Slider(minimum=10, maximum=200, step=1, value=100, label="Scale of Foreground Image (%)")
            # Nine placement labels; the default is the one calculate_position() maps to the centred offset.
            position = gr.Radio(choices=["์ค์ ๊ฐ์ด๋ฐ", "์๋จ ์ข์ธก", "์๋จ ๊ฐ์ด๋ฐ", "์๋จ ์ฐ์ธก", "์ค์ ์ข์ธก", "์ค์ ์ฐ์ธก", "ํ๋จ ์ข์ธก", "ํ๋จ ๊ฐ์ด๋ฐ", "ํ๋จ ์ฐ์ธก"], value="์ค์ ๊ฐ์ด๋ฐ", label="Position of Foreground Image")
            # Parsed by merge() as "WIDTHxHEIGHT".
            display_size = gr.Textbox(value="1024x768", label="Display Size (Width x Height)")
            btn_merge = gr.Button("Merge Images")
            result_merge = gr.Image()

            btn_merge.click(
                fn=merge,
                inputs=[org_image, add_image, scale, position, display_size],
                outputs=result_merge,
            )

    # NOTE(review): gr.TabItem here while the other tabs use gr.Tab; presumably
    # equivalent in Gradio. Confirm and unify.
    with gr.TabItem("Image Search"):
        with gr.Column():
            gr.Markdown("### FREE Image Search")
            # The next three labels are mis-decoded Korean, presumably
            # "photo search", "search" and "search result images". TODO confirm.
            search_query = gr.Textbox(label="์ฌ์ง ๊ฒ์")
            search_btn = gr.Button("๊ฒ์")
            # The gallery displays the list of image URLs returned by show_search_results().
            images_output = gr.Gallery(label="๊ฒ์ ๊ฒฐ๊ณผ ์ด๋ฏธ์ง")
            search_btn.click(
                fn=show_search_results,
                inputs=search_query,
                outputs=images_output
            )

# NOTE(review): the server starts at import time. An `if __name__ == "__main__":`
# guard would let tooling import this module without launching it.
demo.launch()