cutoutai / test_cutout.py
Camelhamcanaan's picture
feat: enhance background removal quality and API robustness
7bf41e6
import os
import io
import base64
import numpy as np
from PIL import Image, ImageDraw
import cutoutai
import logging
# Configure the root logger at INFO level and create the named logger that
# every function in this script reports through.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("TestCutoutAI")
def create_test_image(path="test_input.png"):
    """Create a synthetic test image and save it to ``path``.

    The image is a 512x512 light-gray canvas containing a large blue circle
    (the "subject"), three small semi-transparent circles ("bubbles") and a
    one-pixel-wide vertical line ("fine detail").

    Args:
        path: File name the image is written to.

    Returns:
        The ``path`` argument, unchanged.
    """
    # Light-gray (not pure white) 512x512 background.
    img = Image.new("RGB", (512, 512), (240, 240, 240))

    # Draw in "RGBA" mode so the alpha component of the bubble fills is
    # blended into the RGB canvas. With the default draw mode the alpha
    # value is not applied and the bubbles come out fully opaque.
    draw = ImageDraw.Draw(img, "RGBA")

    # The "subject": a solid blue circle with a black outline.
    draw.ellipse([150, 150, 362, 362], fill=(0, 0, 255), outline=(0, 0, 0))

    # The "bubbles": small half-transparent circles with a gray outline.
    bubbles = (
        ([50, 50, 80, 80], (200, 200, 255, 128)),
        ([400, 100, 430, 130], (200, 200, 255, 128)),
        ([100, 400, 140, 440], (255, 200, 200, 128)),
    )
    for box, fill in bubbles:
        draw.ellipse(box, fill=fill, outline=(100, 100, 100))

    # "Fine detail": a thin black line from the top edge down to the subject.
    draw.line([256, 0, 256, 150], fill=(0, 0, 0), width=1)

    img.save(path)
    logger.info("Created test image: %s", path)
    return path
def test_processing():
    """Run the cutout pipeline end to end on a synthetic image.

    Generates the test input, processes it with the "lite" model variant,
    saves the result to ``test_output.png`` and checks that the output image
    is in RGBA mode.

    Any exception raised while building the processor, processing, or saving
    is treated as an environment problem: it is logged with its traceback and
    the function returns without failing. A wrong output mode is a genuine
    test failure and is raised.

    Raises:
        AssertionError: If the processed image is not in RGBA mode.
    """
    input_path = create_test_image()
    output_path = "test_output.png"

    # NOTE: per the original author's note, loading the model takes time and
    # needs network access plus torch, so this step is best-effort.
    try:
        # The "lite" variant is used to keep the run fast.
        processor = cutoutai.CutoutAI(model_variant="lite")
        logger.info("Running process()...")
        result = processor.process(
            input_path,
            capture_all_elements=True,
            edge_refinement=True,
            edge_radius=2,
            output_format="pil",
        )
        result.save(output_path)
        output_mode = result.mode
    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception("Error during processing: %s", e)
        logger.info("Note: This test requires torch and transformers to be installed and working.")
        return

    logger.info("Saved result to: %s", output_path)

    # The verdict is checked outside the try block so that a failed check
    # cannot be swallowed by the best-effort handler above.
    if output_mode != "RGBA":
        logger.error("FAILURE: Output mode is %s, expected RGBA.", output_mode)
        raise AssertionError(f"Output mode is {output_mode}, expected RGBA.")
    logger.info("SUCCESS: Output is in RGBA mode.")
# Run the processing check only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    test_processing()