# genmake_bg_remove.py
"""Background removal with BiRefNet (via rembg), SAM2 (via Ultralytics), or both."""
import functools
import os

import cv2
import numpy as np
import torch
from PIL import Image
from rembg import new_session, remove  # rembg hosts the BiRefNet models
from ultralytics import SAM, YOLO  # SAM is the Ultralytics loader for SAM2 checkpoints


def _derived_output_path(input_image_path, suffix):
    """Return ``<input path without its extension><suffix>.png``.

    Only the final extension is replaced (``photo.jpg`` -> ``photo<suffix>.png``,
    ``photo.jpeg`` -> ``photo<suffix>.png``). For a non-empty suffix the result
    therefore never equals the input path, so the source image is not overwritten.
    """
    root, _ext = os.path.splitext(input_image_path)
    return f"{root}{suffix}.png"


@functools.lru_cache(maxsize=4)
def _rembg_session(model_name):
    """Create the rembg inference session for ``model_name`` once and reuse it."""
    return new_session(model_name)


@functools.lru_cache(maxsize=2)
def _sam_model(model_path):
    """Load the Ultralytics SAM model for ``model_path`` once and reuse it.

    Cached models stay resident (CPU/GPU memory) for the life of the process.
    """
    return SAM(model_path)


# 1. BiRefNet Background Removal (default option)
def remove_bg_birefnet(input_image_path, model_name="birefnet-general"):
    """Remove the background of an image with a BiRefNet model through rembg.

    Args:
        input_image_path: Path of the source image.
        model_name: rembg model name. Defaults to ``"birefnet-general"``;
            calling ``rembg.remove`` without a session would instead use
            rembg's default model (u2net), not BiRefNet.

    Returns:
        Path of the written PNG: ``<input path without extension>_no_bg.png``.
    """
    # The context manager releases the source file handle once rembg is done.
    with Image.open(input_image_path) as input_image:
        output_image = remove(input_image, session=_rembg_session(model_name))
    output_image_path = _derived_output_path(input_image_path, "_no_bg")
    output_image.save(output_image_path)
    return output_image_path


# 2. SAM2 Background Removal (pro option)
def remove_bg_sam2(input_image_path, model_path="sam2_b.pt", **prompts):
    """Mask an image with SAM2 and write the result as a BGRA PNG.

    Args:
        input_image_path: Path of the source image (read with ``cv2.imread``).
        model_path: SAM2 checkpoint passed to ``ultralytics.SAM``. The previous
            hard-coded ``"sam2.pt"`` is not an Ultralytics SAM2 asset name, so
            the default is ``"sam2_b.pt"``.
        **prompts: Optional Ultralytics SAM prompt keywords (e.g. ``bboxes=``,
            ``points=``, ``labels=``) forwarded to the model call.
            NOTE(review): without prompts SAM segments *everything*, so the
            union of all masks may include background regions; pass a box or
            point on the subject for true foreground extraction.

    Returns:
        Path of the written PNG: ``<input path without extension>_masked.png``.

    Raises:
        FileNotFoundError: If OpenCV cannot read the input image.
        ValueError: If the model returns no masks.
        OSError: If OpenCV fails to write the output image.
    """
    image = cv2.imread(input_image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {input_image_path}")

    results = _sam_model(model_path)(image, **prompts)
    masks = results[0].masks
    if masks is None or len(masks) == 0:
        raise ValueError(f"SAM2 returned no segmentation masks for {input_image_path}")

    # Masks.data is an (N, H, W) tensor, possibly on GPU; bring it to NumPy and
    # union all N masks into one 2-D uint8 mask (0 or 255) as cv2 requires.
    mask_stack = masks.cpu().numpy().data
    foreground = np.any(mask_stack > 0.5, axis=0).astype(np.uint8) * 255

    height, width = image.shape[:2]
    if foreground.shape != (height, width):
        # Masks may be at inference resolution; nearest keeps the mask binary.
        foreground = cv2.resize(foreground, (width, height), interpolation=cv2.INTER_NEAREST)

    # Zero the colour outside the mask and store the mask as alpha so the PNG
    # is transparent outside the subject, matching the BiRefNet output.
    cutout = cv2.bitwise_and(image, image, mask=foreground)
    bgra = np.dstack((cutout, foreground))

    output_image_path = _derived_output_path(input_image_path, "_masked")
    if not cv2.imwrite(output_image_path, bgra):
        raise OSError(f"Could not write image: {output_image_path}")
    return output_image_path
# 3. Hybrid Approach (BiRefNet first, then SAM2 for refinement)
def remove_bg_hybrid(input_image_path):
    """Run the BiRefNet pass, then run the SAM2 pass on its output file.

    Returns the path produced by ``remove_bg_sam2`` for the intermediate image.
    NOTE(review): the intermediate BiRefNet file is not deleted afterwards.
    """
    birefnet_output_path = remove_bg_birefnet(input_image_path)
    refined_output_path = remove_bg_sam2(birefnet_output_path)
    return refined_output_path


# Entry function to call based on the user's selection
def process_background(input_image_path, method="auto"):
    """Dispatch to a background-removal strategy chosen by name.

    Args:
        input_image_path: Path of the image to process.
        method: ``"auto"`` (BiRefNet, the default), ``"pro"`` (SAM2) or
            ``"hybrid"`` (BiRefNet followed by SAM2). Matching is exact and
            case-sensitive.

    Returns:
        Whatever path the selected strategy returns.

    Raises:
        ValueError: If ``method`` is not one of the three names above.

    Example::

        img_path = 'path_to_your_image.jpg'
        result_path = process_background(img_path, 'auto')  # 'auto', 'pro', or 'hybrid'
        print(f"Processed image saved at: {result_path}")
    """
    # Select first, call once; the strategy name is only looked up when chosen.
    if method == "auto":
        strategy = remove_bg_birefnet  # BiRefNet default
    elif method == "pro":
        strategy = remove_bg_sam2  # SAM2 pro
    elif method == "hybrid":
        strategy = remove_bg_hybrid  # Hybrid method
    else:
        raise ValueError("Invalid method: Choose from 'auto', 'pro', 'hybrid'")
    return strategy(input_image_path)