# pipeline1copy / test_removal.py
# Committed by: Janeka
# Commit: "Create test_removal.py" (84f004b, verified)
import numpy as np
from PIL import Image
from transformers import pipeline
import matplotlib.pyplot as plt
# Models to compare: display name (used as the subplot title) -> Hugging Face
# Hub model ID. Insertion order is the order the panels are drawn in.
# NOTE(review): confirm each ID exists on the Hub and works with the
# "image-segmentation" pipeline.
MODELS = dict(
    [
        ("BRIA", "BRIA-AI/bria-rmbg"),
        ("INSPyReNet", "mattmdjaga/INSPyReNet"),
        ("U2Net", "silks-road/u2net"),
        ("U2Net-Human", "mattmdjaga/u2net-human-seg"),
        ("ISNet-General", "xuebinqin/ISNet-general-use"),
        ("ISNet-Anime", "skytnt/anime-seg"),
    ]
)
def _mask_to_image(mask, size):
    """Return *mask* as an 8-bit grayscale ("L") PIL image of *size*.

    The pipeline's ``mask`` entry may be a PIL image or an array-like; both
    are normalised so ``Image.composite`` accepts it (it only takes "1",
    "L" or "RGBA" masks) and so it matches the size of the image it is
    composited onto.
    """
    if isinstance(mask, Image.Image):
        mask_img = mask.convert("L")
    else:
        arr = np.squeeze(np.asarray(mask))
        if arr.dtype == np.bool_:
            arr = arr.astype(np.uint8) * 255
        elif arr.dtype != np.uint8:
            # NOTE(review): assumes non-uint8 masks hold scores in [0, 1];
            # confirm against the models being compared.
            arr = (np.clip(arr, 0.0, 1.0) * 255).astype(np.uint8)
        mask_img = Image.fromarray(arr).convert("L")
    if mask_img.size != size:
        mask_img = mask_img.resize(size)
    return mask_img


def test_removal(image_path, models=None):
    """Run each background-removal model on one image and plot the results.

    Each model is loaded as a ``transformers`` "image-segmentation"
    pipeline. The first returned mask is used to composite the original
    image over a solid red background, and the results are shown in a
    grid next to the original. A model that raises is reported on stdout
    and drawn as a "Failed" panel instead of aborting the whole run.

    Args:
        image_path: Path of the image to test (anything ``PIL.Image.open``
            accepts).
        models: Optional mapping of display name -> Hub model ID.
            Defaults to the module-level ``MODELS``.
    """
    if models is None:
        models = MODELS

    # Load image. The PIL image itself is handed to the pipelines:
    # transformers' image loader accepts PIL images, paths and URLs,
    # not NumPy arrays.
    original = Image.open(image_path).convert("RGB")

    # Size the grid to hold the original plus one panel per model.
    cols = 4
    n_panels = len(models) + 1
    rows = (n_panels + cols - 1) // cols  # ceiling division
    plt.figure(figsize=(15, 4 * rows))
    plt.subplot(rows, cols, 1)
    plt.imshow(original)
    plt.title("Original")
    plt.axis('off')

    # Test each model
    for i, (name, repo) in enumerate(models.items(), start=2):
        plt.subplot(rows, cols, i)
        try:
            print(f"Testing {name}...")
            pipe = pipeline("image-segmentation", model=repo)
            result = pipe(original)
            if isinstance(result, list):
                if not result:
                    raise ValueError("pipeline returned no segments")
                # NOTE(review): the first segment is taken as the foreground;
                # confirm this holds for multi-label segmentation models.
                result = result[0]
            mask = _mask_to_image(result['mask'], original.size)
            # Apply mask
            background = Image.new('RGB', original.size, (255, 0, 0))  # Red background for visibility
            result_img = Image.composite(original, background, mask)
            plt.imshow(result_img)
            plt.title(name)
            print(f"{name} worked!")
        except Exception as e:
            # Deliberately broad: one broken model must not stop the comparison.
            print(f"{name} failed: {e}")
            plt.text(0.5, 0.5, f"{name}\nFailed", ha='center', va='center')
        plt.axis('off')
    plt.tight_layout()
    plt.show()
if __name__ == "__main__":
    import sys

    # Test with your image: pass a path on the command line, or change the
    # default below. Guarded so importing this module (e.g. during pytest
    # collection) does not download models and open a plot window.
    test_removal(sys.argv[1] if len(sys.argv) > 1 else "your_image.jpg")