| | import os |
| | import cv2 |
| | import numpy as np |
| | import torch |
| | import gradio as gr |
| | import argparse |
| | from pathlib import Path |
| | from glob import glob |
| | from typing import Optional, Tuple, List |
| | from PIL import Image |
| | from transformers import AutoModelForImageSegmentation |
| | from torchvision import transforms |
| | import time |
| | import os |
| | import platform |
| |
|
def parse_args():
    """Parse command-line options for launching the app.

    Returns:
        argparse.Namespace with a boolean ``share`` attribute.
    """
    cli = argparse.ArgumentParser(description="Run the image segmentation app")
    cli.add_argument(
        "--share",
        action="store_true",
        help="Enable sharing of the Gradio interface",
    )
    return cli.parse_args()
| |
|
# Allow reduced-precision (TF32) matmul kernels for faster inference.
torch.set_float32_matmul_precision('high')
# Replace torch.jit.script with an identity decorator so any scripting the
# model's remote code attempts becomes a no-op.
# NOTE(review): the motivation isn't visible here — presumably a
# compatibility workaround for the trust_remote_code model; confirm before
# removing.
torch.jit.script = lambda f: f

# Ensure HOME is set to the expanded user directory (a no-op where HOME is
# already defined; presumably helps libraries that expect it on Windows —
# verify).
os.environ['HOME'] = os.path.expanduser('~')

# Select the inference device once at import time.
device = "cuda" if torch.cuda.is_available() else "cpu"
| |
|
def open_folder():
    """Open the local ``results`` folder in the OS file browser.

    Best-effort: platforms other than Windows, macOS and Linux are silently
    ignored. macOS support is new; the original only handled Windows/Linux.
    """
    import subprocess  # local import: only needed for this utility

    folder = os.path.abspath("results")
    system = platform.system()
    if system == "Windows":
        # os.startfile exists on Windows only.
        os.startfile(folder)
    elif system == "Darwin":
        subprocess.run(["open", folder], check=False)
    elif system == "Linux":
        # List-form subprocess avoids the shell-quoting pitfalls of the old
        # os.system(f'xdg-open "{path}"') call (paths containing quotes).
        subprocess.run(["xdg-open", folder], check=False)
| |
|
class ImagePreprocessor():
    """Turn a PIL image into a normalized tensor for the segmentation model.

    The ``resolution`` argument is accepted for interface compatibility but
    is not used here — resizing is performed by the caller before ``proc``.
    """

    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
        # ToTensor converts HWC uint8 PIL data to a CHW float tensor in [0, 1].
        self.transform_image = transforms.Compose([transforms.ToTensor()])
        # ImageNet channel statistics expected by the pretrained backbone.
        self.normalize = transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        )

    def proc(self, image: Image.Image) -> torch.Tensor:
        """Return the normalized CHW tensor for *image* (forced to RGB)."""
        rgb = image.convert('RGB')
        tensor = self.transform_image(rgb)
        return self.normalize(tensor)
| |
|
# Maps UI-facing model names to Hugging Face weight repository names
# (all resolved under the 'zhengpeng7' namespace when loading).
usage_to_weights_file = {
    'General': 'BiRefNet',
    'General-Lite': 'BiRefNet_T',
    'Portrait': 'BiRefNet-portrait',
    'DIS': 'BiRefNet-DIS5K',
    'HRSOD': 'BiRefNet-HRSOD',
    'COD': 'BiRefNet-COD',
    'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs'
}
| |
|
# Load the default 'General' weights at import time so the app is ready for
# the first request; predict() may later replace this global with another
# variant chosen in the UI.
birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
birefnet.to(device)
birefnet.eval()
| |
|
def process_single_image(image_path: str, resolution: str, output_folder: str) -> Tuple[str, str, float]:
    """Segment one image with the global ``birefnet`` model and save an RGBA result.

    Args:
        image_path: Path of the input image file.
        resolution: ``"WxH"`` string; an empty or malformed value falls back
            to the image's own size. Each dimension is snapped down to a
            multiple of 32 (minimum 32), which the model expects.
        output_folder: Existing directory the PNG result is written into.

    Returns:
        Tuple of (input path, written output path, elapsed seconds).
    """
    start_time = time.time()

    image = Image.open(image_path).convert('RGBA')

    # Parse 'WxH'. Fall back to the original size on empty/malformed input
    # instead of crashing (e.g. '1024' or '1024x' used to raise ValueError).
    try:
        dims = [int(v) for v in resolution.strip().split('x')]
        if len(dims) != 2:
            raise ValueError(resolution)
    except (ValueError, AttributeError):
        dims = [image.width, image.height]
    # Snap each side down to a multiple of 32 but never below 32 — a
    # zero-sized target (inputs < 32 px) previously broke resize().
    target = tuple(max(32, (d // 32) * 32) for d in dims)

    image_shape = image.size[::-1]  # (H, W) order for interpolate()
    image_pil = image.resize(target)

    image_preprocessor = ImagePreprocessor(resolution=target)
    # Add the batch dimension expected by the model.
    image_proc = image_preprocessor.proc(image_pil).unsqueeze(0)

    with torch.no_grad():
        # The last model output is the final mask logits; sigmoid -> [0, 1].
        scaled_pred_tensor = birefnet(image_proc.to(device))[-1].sigmoid()

    if device == 'cuda':
        scaled_pred_tensor = scaled_pred_tensor.cpu()

    # Upsample the mask back to the original image resolution.
    pred = torch.nn.functional.interpolate(
        scaled_pred_tensor, size=image_shape, mode='bilinear', align_corners=True
    ).squeeze().numpy()

    # Mask broadcast into all four channels: the multiply below scales the
    # RGB by the mask AND uses it as the alpha channel (original behavior).
    pred_rgba = np.zeros((*pred.shape, 4), dtype=np.uint8)
    pred_rgba[..., :3] = (pred[..., np.newaxis] * 255).astype(np.uint8)
    pred_rgba[..., 3] = (pred * 255).astype(np.uint8)

    image_array = np.array(image)
    image_pred = image_array * (pred_rgba / 255.0)
    output_image = Image.fromarray(image_pred.astype(np.uint8), 'RGBA')

    output_path = _unique_output_path(output_folder, image_path)
    output_image.save(output_path)

    processing_time = time.time() - start_time
    print(f"Processed {image_path} in {processing_time:.4f} seconds")
    return image_path, output_path, processing_time


def _unique_output_path(output_folder: str, image_path: str) -> str:
    """Return a non-clobbering '<stem>.png' (or '<stem>_NNNN.png') path."""
    base_filename = os.path.splitext(os.path.basename(image_path))[0]
    output_path = os.path.join(output_folder, f"{base_filename}.png")
    counter = 1
    while os.path.exists(output_path):
        output_path = os.path.join(output_folder, f"{base_filename}_{counter:04d}.png")
        counter += 1
    return output_path
| |
|
def predict(
    image: str,
    resolution: str,
    weights_file: Optional[str],
    batch_folder: Optional[str] = None,
    output_folder: Optional[str] = None,
    is_batch: bool = False
) -> Tuple[str, List[Tuple[str, str]]]:
    """Run segmentation on a single image or every file in a folder.

    Args:
        image: Input image path (single-image mode).
        resolution: 'WxH' string forwarded to process_single_image ('' = auto).
        weights_file: Key into usage_to_weights_file; None means 'General'.
        batch_folder: Folder globbed for inputs when is_batch is True.
        output_folder: Destination directory; defaults to 'results'.
        is_batch: Selects batch mode when True and batch_folder is set.

    Returns:
        (status message, list of (output path, "t.tttt seconds") pairs).
    """
    global birefnet
    usage = weights_file if weights_file is not None else 'General'
    _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[usage]))
    print('Using weights:', _weights_file)

    # Only (re)load when the requested weights differ from what is already
    # in memory — the original reloaded from Hugging Face on every call.
    if getattr(predict, '_loaded_weights', None) != _weights_file:
        birefnet = AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
        birefnet.to(device)
        birefnet.eval()
        predict._loaded_weights = _weights_file

    output_folder = output_folder or 'results'
    os.makedirs(output_folder, exist_ok=True)

    results: List[Tuple[str, str]] = []

    if is_batch and batch_folder:
        image_files = glob(os.path.join(batch_folder, '*'))
        total_images = len(image_files)
        processed_images = 0
        start_time = time.time()

        for img_path in image_files:
            try:
                _, output_path, proc_time = process_single_image(img_path, resolution, output_folder)
            except Exception as e:
                # Best-effort batch: log and keep going on per-file failures.
                print(f"Error processing {img_path}: {str(e)}")
                continue
            results.append((output_path, f"{proc_time:.4f} seconds"))
            processed_images += 1
            elapsed_time = time.time() - start_time
            avg_time_per_image = elapsed_time / processed_images
            estimated_time_left = avg_time_per_image * (total_images - processed_images)
            print(f"Processed {processed_images}/{total_images} images. Estimated time left: {estimated_time_left:.2f} seconds")

        return f"Batch processing complete. Processed {processed_images}/{total_images} images.", results

    _, output_path, proc_time = process_single_image(image, resolution, output_folder)
    results.append((output_path, f"{proc_time:.4f} seconds"))
    return "Single image processing complete.", results
| |
|
def create_interface():
    """Build and return the Gradio Blocks UI.

    Both buttons route to the same predict() handler; a hidden checkbox
    supplies the is_batch flag that distinguishes them.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## SECourses Improved BiRefNet V2 'Bilateral Reference for High-Resolution Dichotomous Image Segmentation' APP - SOTA Background Remover")
        gr.Markdown("## Most Advanced Latest Version On : https://www.patreon.com/posts/109913645")

        with gr.Row():
            input_image = gr.Image(type="filepath", label="Input Image", height=512)
            output_image = gr.Gallery(label="Output Image", elem_id="gallery", height=512)

        with gr.Row():
            resolution = gr.Textbox(label="Resolution", placeholder="1024x1024 - Optional - Don't enter to use original image resolution - Higher res uses more VRAM but still works perfect with shared VRAM so fast")
            weights_file = gr.Dropdown(choices=list(usage_to_weights_file.keys()), value="General", label="Weights File")
            btn_open_outputs = gr.Button("Open Results Folder")
            btn_open_outputs.click(fn=open_folder)

        with gr.Row():
            batch_folder = gr.Textbox(label="Batch Folder Path")
            output_folder = gr.Textbox(label="Output Folder Path", value="results")

        with gr.Row():
            submit_button = gr.Button("Single Image Process")
            batch_button = gr.Button("Batch Process Images in Given Folder")

        output_text = gr.Textbox(label="Processing Status")

        # Hidden flags select single vs. batch mode in predict().
        single_flag = gr.Checkbox(value=False, visible=False)
        batch_flag = gr.Checkbox(value=True, visible=False)
        common_inputs = [input_image, resolution, weights_file, batch_folder, output_folder]
        common_outputs = [output_text, output_image]

        submit_button.click(predict, inputs=common_inputs + [single_flag], outputs=common_outputs)
        batch_button.click(predict, inputs=common_inputs + [batch_flag], outputs=common_outputs)

    return demo
| |
|
if __name__ == "__main__":
    # Parse CLI flags, build the UI, and launch it in the default browser.
    cli_args = parse_args()
    app = create_interface()
    app.launch(inbrowser=True, share=cli_args.share)