# GenMake-Crystal-Engine / genmake_bg_remove.py
# Uploaded by mhtbhatia in commit cb789b0 ("Upload 12 files").
# genmake_bg_remove.py
import cv2
import numpy as np
from rembg import remove # BiRefNet for background removal
import torch
from ultralytics import YOLO # For SAM2 (Ultralytics version)
from PIL import Image
import os
# 1. BiRefNet Background Removal (default option)
def remove_bg_birefnet(input_image_path, session=None):
    """Remove the background of an image with ``rembg`` and save it as a PNG.

    The result is written next to the input as ``<stem>_no_bg.png``.

    Args:
        input_image_path: Path of the source image.
        session: Optional rembg session forwarded to ``rembg.remove``.
            ``None`` (the default) keeps the previous call exactly as it was.
            NOTE(review): without a session, ``rembg.remove`` uses rembg's
            default model (``u2net`` per the rembg docs), not BiRefNet; pass
            e.g. ``rembg.new_session("birefnet-general")`` if BiRefNet is the
            intended model — confirm.

    Returns:
        Path of the written PNG.
    """
    # The context manager closes the underlying file handle once rembg is done.
    with Image.open(input_image_path) as input_image:
        output_image = remove(input_image, session=session)

    # splitext swaps only the real extension. The old chained str.replace
    # turned "a.jpg" into "a_no_bg_no_bg.png", matched ".jpg"/".png" anywhere
    # in the path, and left other extensions (".jpeg", ".JPG", ".webp")
    # untouched, so the input file itself was overwritten.
    output_image_path = os.path.splitext(input_image_path)[0] + "_no_bg.png"
    output_image.save(output_image_path)
    return output_image_path
# 2. SAM2 Background Removal (pro option)
def remove_bg_sam2(input_image_path, model_path="sam2.pt"):
    """Black out everything outside the segmented region(s) of an image.

    Runs an Ultralytics model on the image and merges every predicted mask
    into one foreground mask. The masked image is written next to the input
    as ``<stem>_masked.png``. Background pixels become black and the output
    has no alpha channel.

    Args:
        input_image_path: Path of the source image. It is read with
            ``cv2.imread``, so any alpha channel is dropped.
        model_path: Weights file handed to the Ultralytics loader. Defaults
            to the previously hard-coded ``"sam2.pt"``.

    Returns:
        Path of the written PNG.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``input_image_path``.
        ValueError: If the model produces no masks.
        OSError: If OpenCV fails to write the output image.
    """
    # NOTE(review): the Ultralytics docs load SAM 2 checkpoints (e.g.
    # "sam2_b.pt") through ``ultralytics.SAM`` rather than ``YOLO``; confirm
    # that ``YOLO(model_path)`` really loads these weights.
    model = YOLO(model_path)

    image = cv2.imread(input_image_path)
    if image is None:  # cv2.imread reports failure by returning None
        raise FileNotFoundError(f"Could not read image: {input_image_path}")

    results = model(image)
    masks = results[0].masks
    if masks is None or len(masks) == 0:
        raise ValueError(f"No segmentation masks produced for: {input_image_path}")

    # ``masks`` is an Ultralytics ``Masks`` wrapper rather than a plain
    # ndarray, so it has no ``astype``. Move it to host memory and take the
    # raw (N, H, W) array. Then union all N instances into the single 2-D
    # uint8 mask that cv2.bitwise_and requires.
    # NOTE(review): without prompts SAM segments "everything", so the union
    # may also cover background regions — confirm this is the intended
    # foreground.
    mask_stack = masks.cpu().numpy().data
    mask = (mask_stack > 0.5).any(axis=0).astype(np.uint8) * 255

    height, width = image.shape[:2]
    if mask.shape != (height, width):
        # Some models return masks at inference resolution. Nearest-neighbour
        # resizing keeps the mask binary.
        # NOTE(review): this ignores any letterbox padding applied during
        # inference — verify alignment for such models.
        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

    masked_image = cv2.bitwise_and(image, image, mask=mask)

    # splitext swaps only the real extension. The old chained str.replace
    # produced "a_masked_masked.png" for "a.jpg" and overwrote inputs with
    # other extensions such as ".jpeg".
    output_image_path = os.path.splitext(input_image_path)[0] + "_masked.png"
    if not cv2.imwrite(output_image_path, masked_image):
        raise OSError(f"Could not write image: {output_image_path}")
    return output_image_path
# 3. Hybrid Approach (BiRefNet first, then SAM2 for refinement)
def remove_bg_hybrid(input_image_path):
    """Chain the two removal stages: remove_bg_birefnet, then remove_bg_sam2.

    The PNG produced by ``remove_bg_birefnet`` becomes the input of
    ``remove_bg_sam2``. That intermediate file is left on disk; only the
    path produced by the SAM2 stage is returned.
    """
    first_stage_output = remove_bg_birefnet(input_image_path)
    final_output = remove_bg_sam2(first_stage_output)
    return final_output
# Entry function to call based on the user's selection
def process_background(input_image_path, method="auto"):
    """Dispatch to a background-removal strategy chosen by *method*.

    Args:
        input_image_path: Path of the image to process.
        method: ``"auto"`` runs remove_bg_birefnet (the default), ``"pro"``
            runs remove_bg_sam2, and ``"hybrid"`` runs remove_bg_hybrid.

    Returns:
        Whatever output path the selected strategy returns.

    Raises:
        ValueError: If *method* is not one of the three names above.
    """
    # Reject unknown names before touching any strategy function.
    if method not in ("auto", "pro", "hybrid"):
        raise ValueError("Invalid method: Choose from 'auto', 'pro', 'hybrid'")

    strategies = {
        "auto": remove_bg_birefnet,  # BiRefNet default
        "pro": remove_bg_sam2,  # SAM2 pro
        "hybrid": remove_bg_hybrid,  # BiRefNet followed by SAM2
    }
    return strategies[method](input_image_path)
# Example usage:
# img_path = 'path_to_your_image.jpg'
# result_path = process_background(img_path, 'auto') # 'auto', 'pro', or 'hybrid'
# print(f"Processed image saved at: {result_path}")