| import gradio as gr |
| import spaces |
| import torch |
| from loadimg import load_img |
| from torchvision import transforms |
| from transformers import AutoModelForImageSegmentation |
| from diffusers import FluxFillPipeline |
| from PIL import Image, ImageOps |
|
|
# Select the "high" float32 matmul precision mode for the whole process.
torch.set_float32_matmul_precision("high")


# Segmentation model used by the background-removal tab. The checkpoint is
# loaded with trust_remote_code=True, so code from the model repository runs
# locally. The model is placed on the GPU right after loading.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
).to("cuda")
|
|
# Preprocessing applied to every image before it is fed to ``birefnet``:
# resize to a fixed 1024x1024, convert to a tensor, then normalize each
# channel with the commonly used ImageNet mean / standard deviation.
transform_image = transforms.Compose(
    [
        transforms.Resize(size=(1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
# Fill (inpainting/outpainting) pipeline, loaded with bfloat16 weights and
# then moved to the GPU.
pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev",
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to("cuda")
|
|
|
|
def prepare_image_and_mask(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
):
    """Pad an image for outpainting and build the matching mask.

    Args:
        image: Anything ``load_img`` accepts; it is converted to RGB.
        padding_top: Pixels to add above the image.
        padding_bottom: Pixels to add below the image.
        padding_left: Pixels to add to the left of the image.
        padding_right: Pixels to add to the right of the image.

        Each padding may be ``None`` (treated as 0) or a float (truncated to
        a whole number of pixels).

    Returns:
        A ``(background, mask)`` tuple of PIL images with identical size:
        ``background`` is the input surrounded by a white border, and
        ``mask`` is black over the original pixels and white over the
        added border.
    """
    # Pillow computes the new canvas size by adding the border widths to the
    # image size, so they must be integers. Number inputs can deliver None
    # (empty field) or floats; normalize them before use.
    left, top, right, bottom = (
        int(value or 0)
        for value in (padding_left, padding_top, padding_right, padding_bottom)
    )
    border = (left, top, right, bottom)

    image = load_img(image).convert("RGB")

    background = ImageOps.expand(image, border=border, fill="white")
    mask = ImageOps.expand(
        Image.new("RGB", image.size, "black"),
        border=border,
        fill="white",
    )
    return background, mask
|
|
|
|
def inpaint(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
    prompt="",
    num_inference_steps=28,
    guidance_scale=50,
):
    """Outpaint ``image`` into the requested padding using the fill pipeline.

    Args:
        image: Source image, forwarded to ``prepare_image_and_mask``.
        padding_top: Pixels to generate above the image.
        padding_bottom: Pixels to generate below the image.
        padding_left: Pixels to generate to the left of the image.
        padding_right: Pixels to generate to the right of the image.
        prompt: Text prompt passed to the pipeline.
        num_inference_steps: Number of denoising steps; coerced to ``int``.
        guidance_scale: Guidance scale passed to the pipeline unchanged.

    Returns:
        The first generated image, converted to RGBA.

    NOTE(review): the outpainting tab supplies 50 steps and a guidance scale
    of 28, the reverse of the defaults declared here; confirm which pairing
    is intended.
    """
    background, mask = prepare_image_and_mask(
        image, padding_top, padding_bottom, padding_left, padding_right
    )

    result = pipe(
        prompt=prompt,
        # Generate at the padded canvas size so the output lines up with
        # the mask.
        height=background.height,
        width=background.width,
        image=background,
        mask_image=mask,
        # A step count must be a whole number, but a UI number field may
        # hand over a float such as 50.0.
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
    ).images[0]

    return result.convert("RGBA")
|
|
|
|
def rmbg(image=None, url=None):
    """Remove the background of an image with the ``birefnet`` model.

    Args:
        image: Image input; used whenever it is not ``None``.
        url: Fallback source handed to ``load_img`` when ``image`` is
            ``None``.

    Returns:
        The loaded RGB image with the predicted mask installed as its alpha
        channel.
    """
    source = url if image is None else image
    picture = load_img(source).convert("RGB")
    original_size = picture.size

    batch = transform_image(picture).unsqueeze(0).to("cuda")
    with torch.no_grad():
        # The last element of the model output is taken as the mask logits.
        probabilities = birefnet(batch)[-1].sigmoid().cpu()

    # Take the first item of the batch, drop singleton axes, and scale the
    # resulting mask back to the size of the input image.
    matte = transforms.ToPILImage()(probabilities[0].squeeze())
    picture.putalpha(matte.resize(original_size))
    return picture
|
|
|
|
@spaces.GPU
def main(*args, progress=gr.Progress(track_tqdm=True)):
    """Single GPU-decorated entry point shared by both interface tabs.

    Args:
        *args: The first value selects the operation (1 for ``rmbg``, 2 for
            ``inpaint``); the remaining values are forwarded positionally.
        progress: Gradio progress object created with ``track_tqdm=True``;
            it is not referenced in the body.

    Returns:
        The result of the selected operation, or ``None`` when the selector
        is neither 1 nor 2.
    """
    selector = args[0]
    payload = args[1:]

    if selector == 1:
        return rmbg(*payload)
    if selector == 2:
        return inpaint(*payload)
    return None
|
|
|
|
# Background-removal tab. The hidden Number pins the selector to 1 so that
# ``main`` dispatches to ``rmbg``; the image and text inputs are forwarded
# as its ``image`` and ``url`` arguments.
rmbg_tab = gr.Interface(
    api_name="rmbg",
    fn=main,
    inputs=[gr.Number(value=1, visible=False), "image", "text"],
    outputs=["image"],
    cache_examples=False,
    examples=[[1, "./assets/Inpainting mask.png", ""]],
)
|
|
# Outpainting tab. The hidden Number pins the selector to 2 so that ``main``
# dispatches to ``inpaint``; the remaining inputs follow the positional
# order of ``inpaint``'s parameters.
#
# The padding fields default to 0 and use precision=0 so an untouched field
# submits an integer instead of None, and the step count is likewise
# rounded to an integer.
outpaint_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(2, visible=False),
        "image",
        gr.Number(value=0, precision=0, label="padding top"),
        gr.Number(value=0, precision=0, label="padding bottom"),
        gr.Number(value=0, precision=0, label="padding left"),
        gr.Number(value=0, precision=0, label="padding right"),
        gr.Text(label="prompt"),
        gr.Number(value=50, precision=0, label="num_inference_steps"),
        gr.Number(value=28, label="guidance_scale"),
    ],
    outputs=["image"],
    api_name="outpainting",
)
|
|
# Combine both interfaces into a single tabbed app and start serving it.
demo = gr.TabbedInterface(
    interface_list=[rmbg_tab, outpaint_tab],
    tab_names=["remove background", "outpainting"],
    title="Utilities that require GPU",
)


demo.launch()
|
|