background-remover-api / tests /test_remove_background.py
Keramo's picture
Upload 13 files
4ba81a9 verified
Raw
History Blame Contribute Delete
5.76 kB
"""
Tests for background removal endpoint.
"""
import pytest
import io
from unittest.mock import patch
from PIL import Image
from fastapi.testclient import TestClient
# Import the FastAPI application with the model hooks stubbed out, then build
# the shared test client used by every test in this module.
#
# NOTE(review): mock.patch resolves "app.<name>" by importing the `app` module
# itself, so the module is probably first imported when the outermost patch is
# entered rather than by the `from app import app` line below. All three
# patches are also reverted as soon as the with-blocks exit, i.e. before any
# test runs. Confirm this is the intended scope of the stubs.
with patch("app.load_model"):
    with patch("app.warm_model"):
        with patch("app.is_model_ready", return_value=True):
            from app import app

# Created without a `with` block, matching the original setup.
client = TestClient(app)
def create_test_image(
    width: int = 100,
    height: int = 100,
    format: str = "PNG",
    mode: str = "RGB",
) -> bytes:
    """Build a blank in-memory image and return it encoded as bytes.

    Args:
        width: Image width in pixels.
        height: Image height in pixels.
        format: Pillow format name used for encoding (e.g. "PNG", "JPEG",
            "WEBP"). The name shadows the builtin but is kept because callers
            pass it by keyword.
        mode: Pillow image mode handed to ``Image.new``.

    Returns:
        The complete encoded image file contents.
    """
    with io.BytesIO() as stream:
        Image.new(mode, (width, height)).save(stream, format=format)
        # getvalue() returns the whole buffer regardless of stream position.
        return stream.getvalue()
class TestRemoveBackgroundEndpoint:
    """Exercise POST /api/v1/remove-background through the shared test client."""

    # Route under test; every request in this class targets it.
    ENDPOINT = "/api/v1/remove-background"

    def _upload(self, filename: str, payload: bytes, mime_type: str):
        """POST ``payload`` as the multipart ``file`` field and return the response.

        Not collected by pytest because the name does not start with ``test``.
        """
        upload = (filename, io.BytesIO(payload), mime_type)
        return client.post(self.ENDPOINT, files={"file": upload})

    def test_remove_background_valid_png(self):
        """A well-formed PNG upload yields a non-empty PNG response."""
        png_bytes = create_test_image(format="PNG")
        result = self._upload("test.png", png_bytes, "image/png")

        assert result.status_code == 200
        assert result.headers["content-type"] == "image/png"
        assert len(result.content) > 0

    def test_remove_background_valid_jpeg(self):
        """A well-formed JPEG upload is accepted and answered with a PNG."""
        jpeg_bytes = create_test_image(format="JPEG", mode="RGB")
        result = self._upload("test.jpg", jpeg_bytes, "image/jpeg")

        assert result.status_code == 200
        assert result.headers["content-type"] == "image/png"

    def test_remove_background_valid_webp(self):
        """A well-formed WEBP upload is accepted and answered with a PNG."""
        webp_bytes = create_test_image(format="WEBP")
        result = self._upload("test.webp", webp_bytes, "image/webp")

        assert result.status_code == 200
        assert result.headers["content-type"] == "image/png"

    def test_remove_background_unsupported_mime_type(self):
        """A text/plain upload is refused with 415 and an explanatory detail."""
        result = self._upload("test.txt", b"not an image", "text/plain")

        assert result.status_code == 415
        body = result.json()
        assert "detail" in body
        assert "Unsupported" in body["detail"]

    def test_remove_background_pdf_rejected(self):
        """An application/pdf upload is refused with 415."""
        result = self._upload("test.pdf", b"%PDF-fake", "application/pdf")

        assert result.status_code == 415

    def test_remove_background_corrupted_image(self):
        """Bytes that claim to be a PNG but are not decodable produce a 400."""
        result = self._upload("test.png", b"corrupted data", "image/png")

        assert result.status_code == 400
        body = result.json()
        assert "detail" in body

    def test_remove_background_oversized_file(self):
        """A 6 MB payload is refused with 413 and a detail mentioning "5 MB"."""
        six_megabytes = 6 * 1024 * 1024
        result = self._upload("test.png", b"x" * six_megabytes, "image/png")

        assert result.status_code == 413
        body = result.json()
        assert "5 MB" in body["detail"]

    def test_remove_background_invalid_dimensions(self):
        """A 7000x100 image is refused with 400 and a dimension-related detail."""
        wide_image = create_test_image(width=7000, height=100)
        result = self._upload("test.png", wide_image, "image/png")

        assert result.status_code == 400
        body = result.json()
        detail_text = body["detail"].lower()
        # Either keyword is accepted in the error message.
        assert any(word in detail_text for word in ("dimensions", "exceed"))

    def test_remove_background_missing_file(self):
        """A request with no multipart body fails validation with 422."""
        result = client.post(self.ENDPOINT)

        assert result.status_code == 422

    @patch("app.remove_background")
    def test_remove_background_model_processing_error(self, mock_remove):
        """An HTTPException(500) raised by the patched remover surfaces as a 500."""
        from fastapi import HTTPException

        mock_remove.side_effect = HTTPException(status_code=500, detail="Processing failed")
        png_bytes = create_test_image(format="PNG")
        result = self._upload("test.png", png_bytes, "image/png")

        assert result.status_code == 500
        body = result.json()
        assert "detail" in body

    def test_response_headers_correct(self):
        """A successful response is an attachment named result.png with X-Engine set."""
        png_bytes = create_test_image(format="PNG")
        result = self._upload("test.png", png_bytes, "image/png")

        assert result.status_code == 200
        disposition = result.headers["content-disposition"]
        assert "attachment" in disposition
        assert "result.png" in disposition
        assert "X-Engine" in result.headers