| import gradio as gr |
| import spaces |
| import torch |
| from loadimg import load_img |
| from torchvision import transforms |
| from transformers import AutoModelForImageSegmentation |
| from diffusers import FluxFillPipeline |
| from PIL import Image, ImageOps |
| from sam2.sam2_image_predictor import SAM2ImagePredictor |
| import numpy as np |
|
|
# "high" lets PyTorch pick faster, reduced-precision kernels for float32
# matrix multiplications where available ("highest" would keep full float32).
torch.set_float32_matmul_precision("high")
|
|
# BiRefNet segmentation model used by rmbg(). trust_remote_code is required
# because the checkpoint is loaded through code shipped with the Hub repo.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
# Move the weights to the GPU once at import time.
birefnet.to(device="cuda")
|
|
|
|
# Preprocessing applied to every image before it is fed to BiRefNet:
# resize to a fixed 1024x1024, convert to a tensor, then normalize each
# channel (the values are the commonly used ImageNet mean / std).
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
|
|
# FLUX fill pipeline shared by outpaint() and inpaint(); loaded in bfloat16
# and then moved to the GPU.
pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev",
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to("cuda")
|
|
|
|
def prepare_image_and_mask(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
):
    """Build the padded canvas and the matching fill mask for outpainting.

    Args:
        image: Anything ``load_img`` accepts; it is converted to RGB.
        padding_top, padding_bottom, padding_left, padding_right: Number of
            pixels to add on each side. ``None`` is treated as 0 and
            non-integer values are rounded, because the values come from
            ``gr.Number`` fields that may be empty or deliver floats.

    Returns:
        A ``(background, mask)`` tuple of same-sized RGB images. The
        background is the input surrounded by a white border; the mask is
        black over the original pixels and white over the added border.
    """

    def _to_pixels(value):
        # PIL needs integer sizes; an empty Number field arrives as None.
        return 0 if value is None else int(round(value))

    source = load_img(image).convert("RGB")

    # ImageOps.expand takes a 4-tuple border as (left, top, right, bottom).
    border = (
        _to_pixels(padding_left),
        _to_pixels(padding_top),
        _to_pixels(padding_right),
        _to_pixels(padding_bottom),
    )

    background = ImageOps.expand(source, border=border, fill="white")
    keep_region = Image.new("RGB", source.size, "black")
    mask = ImageOps.expand(keep_region, border=border, fill="white")
    return background, mask
|
|
|
|
def outpaint(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
    prompt="",
    num_inference_steps=28,
    guidance_scale=50,
):
    """Extend an image beyond its borders with the FLUX fill pipeline.

    Args:
        image: Input image, forwarded to ``prepare_image_and_mask``.
        padding_top, padding_bottom, padding_left, padding_right: Pixels to
            add on each side of the image.
        prompt: Text prompt passed to the pipeline.
        num_inference_steps: Number of denoising steps; coerced to ``int``
            because the UI number field may deliver a float.
        guidance_scale: Guidance scale passed to the pipeline unchanged.

    Returns:
        The generated image converted to RGBA, sized like the padded canvas.

    NOTE(review): the defaults here (28 steps, guidance 50) are the reverse
    of the values pre-filled in the UI (50 steps, guidance 28) — confirm
    which pairing is intended. The defaults are left untouched so existing
    callers keep their behaviour.
    """
    background, mask = prepare_image_and_mask(
        image, padding_top, padding_bottom, padding_left, padding_right
    )

    generated = pipe(
        prompt=prompt,
        height=background.height,
        width=background.width,
        image=background,
        mask_image=mask,
        # The pipeline needs an integer step count.
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
    ).images[0]

    return generated.convert("RGBA")
|
|
|
|
def inpaint(
    image,
    mask,
    prompt="",
    num_inference_steps=28,
    guidance_scale=50,
):
    """Repaint the masked region of an image with the FLUX fill pipeline.

    Args:
        image: PIL image to edit; converted to RGB.
        mask: PIL image marking the region to repaint; converted to
            single-channel ("L").
        prompt: Text prompt passed to the pipeline.
        num_inference_steps: Number of denoising steps; coerced to ``int``
            because the UI number field may deliver a float.
        guidance_scale: Guidance scale passed to the pipeline unchanged.

    Returns:
        The generated image converted to RGBA, sized like ``image``.

    Raises:
        ValueError: If ``image`` or ``mask`` is missing.

    NOTE(review): the defaults here (28 steps, guidance 50) are the reverse
    of the values pre-filled in the UI (50 steps, guidance 28) — confirm
    which pairing is intended.
    """
    # An empty image field arrives as None; fail with a clear message rather
    # than an AttributeError on ``None.convert``.
    if image is None or mask is None:
        raise ValueError("inpaint requires both an image and a mask")

    background = image.convert("RGB")
    mask = mask.convert("L")

    generated = pipe(
        prompt=prompt,
        height=background.height,
        width=background.width,
        image=background,
        mask_image=mask,
        # The pipeline needs an integer step count.
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
    ).images[0]

    return generated.convert("RGBA")
|
|
|
|
def rmbg(image=None, url=None):
    """Remove the background of an image using BiRefNet.

    Args:
        image: Image input; when it is ``None`` the ``url`` is used instead.
        url: Fallback image source handed to ``load_img``.

    Returns:
        The RGB input with the predicted mask attached as its alpha channel.
    """
    source = url if image is None else image
    picture = load_img(source).convert("RGB")

    # Model input: one preprocessed image as a batch of size 1 on the GPU.
    batch = transform_image(picture).unsqueeze(0).to("cuda")

    with torch.no_grad():
        outputs = birefnet(batch)
        probabilities = outputs[-1].sigmoid().cpu()

    # Turn the prediction into a PIL mask and scale it back to the input size.
    to_pil = transforms.ToPILImage()
    alpha = to_pil(probabilities[0].squeeze()).resize(picture.size)

    picture.putalpha(alpha)
    return picture
|
|
|
|
def mask_generation(image=None, d=None):
    """Generate SAM2 masks for an image from point prompts.

    Args:
        image: PIL image to segment.
        d: Prompt description, either a dict or its textual literal form,
            e.g. ``'{"input_points": [[500, 375]], "input_labels": [1]}'``.
            It must contain the keys ``"input_points"`` and
            ``"input_labels"``.

    Returns:
        A list of ``(image, caption)`` tuples, one per predicted mask,
        ordered from highest to lowest score and captioned ``"image <i>"``.

    Raises:
        ValueError, SyntaxError: If ``d`` is a string that is not a valid
            Python/JSON-style literal.
    """
    import ast  # function-scope import keeps this fix self-contained

    # SECURITY: ``d`` comes from a user-editable text box / public API, so it
    # must never go through ``eval`` (arbitrary code execution).
    # ``ast.literal_eval`` only accepts literals, which covers the documented
    # dict-of-lists format.
    prompts = ast.literal_eval(d) if isinstance(d, str) else d

    predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-tiny")
    predictor.set_image(image)

    masks, scores, _ = predictor.predict(
        point_coords=np.array(prompts["input_points"]),
        point_labels=np.array(prompts["input_labels"]),
        multimask_output=True,
    )

    # Best-scoring mask first.
    ranked_masks = masks[np.argsort(scores)[::-1]]

    gallery = []
    for rank, mask in enumerate(ranked_masks):
        alpha = Image.fromarray(mask * 255).convert("L")
        # Composite the input over the mask image, blended by the mask itself.
        gallery.append((Image.composite(image, alpha, alpha), f"image {rank}"))

    return gallery
|
|
|
|
@spaces.GPU
def main(*args):
    """Single GPU entry point that routes a request to the selected utility.

    The first positional argument is the api number (1 = rmbg, 2 = outpaint,
    3 = inpaint, 4 = mask_generation); the remaining arguments are forwarded
    unchanged to the selected function. Returns that function's result, or
    ``None`` when the api number matches none of them.
    """
    api_num, *forwarded = args

    routes = (
        (1, rmbg),
        (2, outpaint),
        (3, inpaint),
        (4, mask_generation),
    )
    for number, handler in routes:
        if api_num == number:
            return handler(*forwarded)
    return None
|
|
|
# "remove background" tab: sends (1, image, url) to main(), which routes api
# number 1 to rmbg().
rmbg_tab = gr.Interface(
    fn=main,
    api_name="rmbg",
    description="pass an image or a url of an image",
    inputs=[
        gr.Number(1, interactive=False),  # fixed api selector, not editable
        "image",
        gr.Text("", label="url"),
    ],
    outputs=["image"],
    examples=[[1, "./assets/Inpainting mask.png", ""]],
    cache_examples=False,
)
|
|
# "outpainting" tab: sends (2, image, paddings..., prompt, steps, guidance) to
# main(), which routes api number 2 to outpaint().
outpaint_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(2, interactive=False),  # fixed api selector, not editable
        gr.Image(label="image", type="pil"),
        # Paddings default to 0 and are rounded to integers (precision=0):
        # an untouched field would otherwise submit None and break the
        # image expansion, which needs integer pixel counts.
        gr.Number(value=0, precision=0, label="padding top"),
        gr.Number(value=0, precision=0, label="padding bottom"),
        gr.Number(value=0, precision=0, label="padding left"),
        gr.Number(value=0, precision=0, label="padding right"),
        gr.Text(label="prompt"),
        gr.Number(value=50, label="num_inference_steps"),
        gr.Number(value=28, label="guidance_scale"),
    ],
    outputs=["image"],
    api_name="outpainting",
    examples=[[2, "./assets/rocket.png", 100, 0, 0, 0, "", 50, 28]],
    cache_examples=False,
)
|
|
|
|
# "inpainting" tab: sends (3, image, mask, prompt, steps, guidance) to main(),
# which routes api number 3 to inpaint().
inpaint_tab = gr.Interface(
    fn=main,
    api_name="inpaint",
    description="it is recommended that you use https://github.com/la-voliere/react-mask-editor when creating an image mask in JS and then inverse it before sending it to this space",
    inputs=[
        gr.Number(3, interactive=False),  # fixed api selector, not editable
        gr.Image(label="image", type="pil"),
        gr.Image(label="mask", type="pil"),
        gr.Text(label="prompt"),
        gr.Number(value=50, label="num_inference_steps"),
        gr.Number(value=28, label="guidance_scale"),
    ],
    outputs=["image"],
    # The example only pre-fills the api selector, the image and the mask.
    examples=[[3, "./assets/rocket.png", "./assets/Inpainting mask.png"]],
    cache_examples=False,
)
|
|
|
|
# "sam2" tab: sends (4, image, prompt text) to main(), which routes api number
# 4 to mask_generation(); the resulting masks are shown in a gallery.
sam2_tab = gr.Interface(
    fn=main,
    api_name="sam2",
    inputs=[
        gr.Number(4, interactive=False),  # fixed api selector, not editable
        gr.Image(type="pil"),
        gr.Text(),  # textual dict with "input_points" and "input_labels"
    ],
    outputs=gr.Gallery(),
    examples=[
        [
            4,
            "./assets/truck.jpg",
            '{"input_points": [[500, 375], [1125, 625]], "input_labels": [1, 0]}',
        ],
    ],
    cache_examples=False,
)
|
|
# Combine the four utilities into one tabbed app; tab names are listed in the
# same order as the interfaces.
demo = gr.TabbedInterface(
    interface_list=[rmbg_tab, outpaint_tab, inpaint_tab, sam2_tab],
    tab_names=["remove background", "outpainting", "inpainting", "sam2"],
    title="Utilities that require GPU",
)

# Start serving the app.
demo.launch()
|
|