import locale
import os
import sys
import tempfile

# Force a UTF-8 environment before the heavy imports below. Environment
# variables changed at runtime only affect child processes (the interpreter
# reads them at start-up), so the locale and the std streams of the current
# process are also set explicitly.
os.environ.setdefault("LANG", "C.UTF-8")
os.environ.setdefault("LC_ALL", "C.UTF-8")
os.environ.setdefault("PYTHONIOENCODING", "utf-8")
try:
    locale.setlocale(locale.LC_ALL, "C.UTF-8")
except locale.Error:
    # C.UTF-8 is not available on every system; keep the default locale.
    pass
for stream in (sys.stdout, sys.stderr):
    # The streams may be None or replaced by objects without reconfigure().
    if hasattr(stream, "reconfigure"):
        stream.reconfigure(encoding="utf-8", errors="replace")

import gradio as gr  # noqa: E402
from gradio_imageslider import ImageSlider  # noqa: E402
from loadimg import load_img  # noqa: E402
import spaces  # noqa: E402
from transformers import AutoModelForImageSegmentation  # noqa: E402
import torch  # noqa: E402
from torchvision import transforms  # noqa: E402

# NOTE(review): replaces Blocks.transpile_to_js with a no-op. This presumably
# works around Gradio's Python-to-JavaScript transpilation step; confirm it
# is still needed for the installed Gradio version.
try:
    gr.Blocks.transpile_to_js = lambda self, quiet=False: None
except Exception:
    # Best effort: an unexpected Gradio layout must not stop the app.
    pass

# Let PyTorch use faster reduced-precision kernels for float32 matmuls.
torch.set_float32_matmul_precision("high")

# Lazily-loaded model singleton; see get_birefnet().
_birefnet = None


def get_birefnet():
    """Return the RMBG-2.0 segmentation model, loading it on first use.

    The model is fetched from the Hub repository "briaai/RMBG-2.0" with
    ``trust_remote_code=True`` (so transformers runs the model code hosted in
    that repository), switched to eval mode, and cached in ``_birefnet`` so
    later calls reuse the same instance.
    """
    global _birefnet
    if _birefnet is None:
        _birefnet = AutoModelForImageSegmentation.from_pretrained(
            "briaai/RMBG-2.0", trust_remote_code=True
        ).eval()
    return _birefnet


# Model input pipeline: resize to 1024x1024, convert to a tensor, and
# normalize with the ImageNet channel means / standard deviations.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Root directory for PNGs produced by fn(); each request gets its own subdir.
output_folder = "output_images"
os.makedirs(output_folder, exist_ok=True)


def fn(image):
    """Remove the background from one image for the "image" and "url" tabs.

    Args:
        image: anything ``load_img`` accepts. Here that is either the value of
            the ``image`` gr.Image component or the URL string typed into the
            ``text`` gr.Textbox.

    Returns:
        ``((cutout, original), png_path)``: the pair shown in the ImageSlider
        and the path of the saved PNG offered for download.

    Raises:
        gr.Error: if no image (or an empty URL) was supplied.
    """
    if image is None or (isinstance(image, str) and not image.strip()):
        raise gr.Error("Please provide an image.")
    im = load_img(image, output_type="pil").convert("RGB")
    # process() adds the alpha channel to `im` in place, so keep a copy of
    # the untouched original for the before/after slider.
    origin = im.copy()
    image = process(im)
    # A fresh directory per request stops concurrent requests (for example
    # the image tab and the URL tab) from overwriting each other's download,
    # while the user-facing file name stays "no_bg_image.png".
    request_dir = tempfile.mkdtemp(dir=output_folder)
    image_path = os.path.join(request_dir, "no_bg_image.png")
    image.save(image_path)
    return (image, origin), image_path


# `spaces.GPU` is the Hugging Face Spaces GPU decorator; duration=120
# presumably caps each call at 120 seconds of GPU time.
@spaces.GPU(duration=120)
def process(image):
    """Cut out the foreground of ``image`` by attaching a predicted alpha mask.

    The model runs on a 1024x1024 resized copy. The predicted mask is resized
    back to the original size and attached with ``putalpha``.

    Note: ``image`` (a PIL image) is modified in place and also returned.
    """
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0)  # add a batch dimension
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_birefnet().to(device)
    try:
        with torch.no_grad():
            # NOTE(review): assumes the model returns a sequence of
            # predictions whose last entry holds the final mask logits.
            # Confirm against the RMBG-2.0 model code.
            preds = model(input_images.to(device))[-1].sigmoid().cpu()
    finally:
        # Hand the GPU back even if inference fails: move the weights to the
        # CPU and release cached CUDA memory.
        if device.type == "cuda":
            model.to("cpu")
            torch.cuda.empty_cache()
    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred).resize(image_size)
    image.putalpha(mask)
    return image


def process_file(f):
    """Remove the background from the image file at path ``f``.

    The result is written next to the input, with the extension replaced by
    ``.png``, and that path is returned. A ``.png`` input is therefore
    overwritten by its cutout.

    Raises:
        gr.Error: if no file path was supplied.
    """
    if not f:
        raise gr.Error("Please provide an image.")
    # splitext only strips an extension from the last path component.
    # The previous rsplit(".") also cut at dots in directory names.
    name_path = os.path.splitext(f)[0] + ".png"
    im = load_img(f, output_type="pil").convert("RGB")
    transparent = process(im)
    transparent.save(name_path)
    return name_path


slider1 = ImageSlider(label="RMBG-2.0", type="pil")
slider2 = ImageSlider(label="RMBG-2.0", type="pil")
image = gr.Image(label="Upload an image")
image2 = gr.Image(label="Upload an image", type="filepath")
text = gr.Textbox(label="Paste an image URL")
png_file = gr.File(label="output png file")

# Example inputs. NOTE: despite its name, `chameleon` holds giraffe.jpg.
chameleon = load_img("giraffe.jpg", output_type="pil")
url = "http://farm9.staticflickr.com/8488/8228323072_76eeddfea3_z.jpg"

tab1 = gr.Interface(
    fn,
    inputs=image,
    outputs=[slider1, gr.File(label="output png file")],
    examples=[chameleon],
    api_name="image",
)
tab2 = gr.Interface(
    fn,
    inputs=text,
    outputs=[slider2, gr.File(label="output png file")],
    examples=[url],
    api_name="text",
)
# NOTE(review): tab3 is built but not passed to the TabbedInterface below,
# so it is not part of `demo`. Confirm whether the file-path tab should be
# exposed.
tab3 = gr.Interface(
    process_file,
    inputs=image2,
    outputs=png_file,
    examples=["giraffe.jpg"],
    api_name="png",
)

demo = gr.TabbedInterface(
    [tab1, tab2],
    ["input image", "input url"],
    # NOTE(review): no link targets or markup for "BRIA.AI", "Model card",
    # "Blog", "Bria.ai", "fal.ai", "ComfyUI Node" or "here" survive in this
    # source. Restore them if the title is meant to contain links.
    title=(
        "RMBG-2.0 for background removal\n"
        "Background removal model developed by "
        "BRIA.AI, trained on a carefully selected dataset,\n"
        "and is available as an open-source model for non-commercial use.\n\n"
        " For testing upload your image and wait.\n"
        "Model card | "
        "Blog"
        "\n\n"
        "API Endpoint available on: "
        "Bria.ai, "
        "fal.ai\n"
        "ComfyUI node is available here: "
        "ComfyUI Node\n"
        "Purchase weights for commercial use: "
        "here"
        "\n"
    ),
)

if __name__ == "__main__":
    demo.launch(ssr_mode=False, show_error=True)