Commit 7bf41e6

feat: enhance background removal quality and API robustness

- Added adaptive thresholding and artifact removal
- Optimized Dockerfile for HuggingFace Spaces
- Enhanced API with debug mode and extra parameters

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
- .gitignore +15 -0
- @fix_plan.md +35 -0
- ASSETS.md +67 -0
- Dockerfile +46 -0
- PROMPT.md +79 -0
- README.md +68 -0
- api.py +388 -0
- cutoutai.py +441 -0
- requirements.txt +22 -0
- run-claude-analysis.bat +23 -0
- specs/requirements.md +134 -0
- test_cutout.py +70 -0
.gitignore
ADDED
@@ -0,0 +1,15 @@
```
# Python
__pycache__/
*.py[cod]
*$py.class
venv/
.env

# Images/Outputs
*_output.png
test_input.png
cache/

# IDEs
.vscode/
.idea/
```
@fix_plan.md
ADDED
@@ -0,0 +1,35 @@
```markdown
# CutoutAI - Task Priority List

## Completed
- [x] Create basic cutoutai.py with BiRefNet integration
- [x] Create api.py with webhook endpoint
- [x] Add edge smoothing function
- [x] Add requirements.txt
- [x] Create project documentation
- [x] Add mask thresholding (0.2 for capture_all, 0.4 standard)
- [x] Implement capture_all_elements with lower threshold
- [x] Replace blur with morphological edge processing (preserves details)
- [x] Add startup model preloading
- [x] Add processing time to responses
- [x] Use model parameter in webhook
- [x] Add input validation (10MB limit)
- [x] Add Dockerfile for HuggingFace Spaces deployment
- [x] Add processing time logging
- [x] Add optional debug mode with intermediate outputs (return_mask=True)
- [x] Add artifact removal (scipy ndimage)
- [x] Add adaptive thresholding

## In Progress
- [ ] Test with Gemini-generated images

## High Priority
- [ ] Test with various Gemini-generated design types

## Medium Priority
- [ ] Add batch processing optimizations

## Low Priority
- [ ] Support for BiRefNet_HR (2K resolution)
```
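The "artifact removal (scipy ndimage)" item above refers to dropping small isolated islands from the alpha mask before compositing. A minimal pure-Python sketch of the same idea follows; the actual cutoutai.py implementation reportedly uses `scipy.ndimage`, so the helper name, the 4-connectivity BFS, and the size cutoff here are illustrative assumptions:

```python
from collections import deque

def remove_small_islands(mask, min_size=3):
    """Zero out connected foreground regions smaller than min_size.

    mask: list of lists of 0/1; 4-connectivity, matching
    scipy.ndimage.label's default structure. Returns a new mask.
    """
    h, w = len(mask), len(mask[0])
    seen = [[False] * w for _ in range(h)]
    out = [row[:] for row in mask]
    for y in range(h):
        for x in range(w):
            if mask[y][x] and not seen[y][x]:
                # BFS to collect one connected component
                comp, q = [], deque([(y, x)])
                seen[y][x] = True
                while q:
                    cy, cx = q.popleft()
                    comp.append((cy, cx))
                    for ny, nx in ((cy - 1, cx), (cy + 1, cx), (cy, cx - 1), (cy, cx + 1)):
                        if 0 <= ny < h and 0 <= nx < w and mask[ny][nx] and not seen[ny][nx]:
                            seen[ny][nx] = True
                            q.append((ny, nx))
                if len(comp) < min_size:
                    for cy, cx in comp:
                        out[cy][cx] = 0  # too small: treat as an artifact
    return out

# The lone pixel is removed as an artifact; the 3-pixel strip survives
grid = [
    [1, 1, 1, 0],
    [0, 0, 0, 0],
    [0, 0, 1, 0],
]
cleaned = remove_small_islands(grid, min_size=3)
print(cleaned)  # → [[1, 1, 1, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
```

The `min_artifact_size` parameter exposed by the API (default 40) plays the role of `min_size` here, but measured in pixels at full mask resolution.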
ASSETS.md
ADDED
@@ -0,0 +1,67 @@
```markdown
# Background Removal Tool - Project Assets

> **Purpose**: Self-hosted background removal API for Etsy t-shirt workflow

---

## Core Files

| File | Description |
|------|-------------|
| `cutoutai.py` | Core BiRefNet processing (441 lines) |
| `api.py` | FastAPI server with webhooks (388 lines) |
| `Dockerfile` | Production container |
| `requirements.txt` | Python dependencies |
| `test_cutout.py` | Automated test script |

---

## Configuration

| File | Description |
|------|-------------|
| `PROMPT.md` | Ralph development instructions |
| `@fix_plan.md` | Task priority tracking |
| `specs/requirements.md` | Technical specifications |

---

## Test Outputs

| File | Description |
|------|-------------|
| `test_output.png` | Synthetic test result |
| `real_test_output.png` | cosmic_bloom.png result |
| `hard_test_output.png` | ChatGPT image result (3.6MB input) |

---

## Key Features

- **Models**: matting, general, portrait, lite, hr, dynamic
- **API**: REST + Webhook (n8n compatible)
- **Output**: PNG, base64
- **Thresholding**: 0.2 (capture_all) / 0.4 (standard)

---

## Deployment Status

| Target | Status |
|--------|--------|
| Local | ✅ Ready |
| Railway | ⬜ Not deployed |
| HuggingFace | ⬜ Not deployed |

---

## Related Projects

| Project | Relationship |
|---------|--------------|
| `etsy tshirt project` | Primary consumer of this API |
| `system-instructions` | CCR/Ralph configuration |

---

*Last Updated: Dec 28, 2025*
```
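The 0.2 / 0.4 thresholds listed under Key Features can be sketched as follows. The function names and the adaptive rule (scaling by the mean confidence of strong pixels) are illustrative assumptions, not the exact logic in cutoutai.py:

```python
def pick_threshold(alpha, capture_all=False, adaptive=False):
    """Choose a binarization cutoff for a soft alpha mask (values in 0..1).

    capture_all lowers the cutoff to 0.2 so faint elements (bubbles,
    sparkles) survive; the standard cutoff is 0.4. With adaptive=True
    the cutoff is nudged by the mask's overall confidence.
    """
    base = 0.2 if capture_all else 0.4
    if adaptive and alpha:
        confident = [a for a in alpha if a > 0.5]
        if confident:
            mean_conf = sum(confident) / len(confident)
            # Confident masks tolerate a stricter cutoff, and vice versa;
            # clamp to a sane range so nothing is lost or kept wholesale
            base = min(max(base * mean_conf / 0.8, 0.1), 0.6)
    return base

def binarize(alpha, threshold):
    return [1 if a >= threshold else 0 for a in alpha]

alpha = [0.05, 0.25, 0.45, 0.95]
print(binarize(alpha, pick_threshold(alpha, capture_all=True)))   # keeps the faint 0.25 pixel
print(binarize(alpha, pick_threshold(alpha, capture_all=False)))  # drops it
```

In the API these map to the `capture_all_elements`, `threshold`, and `adaptive_threshold` parameters; an explicit `threshold` overrides both defaults.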
Dockerfile
ADDED
@@ -0,0 +1,46 @@
```dockerfile
# Use NVIDIA CUDA base image if possible, otherwise standard python
FROM python:3.10-slim

# Set environment variables
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
ENV TRANSFORMERS_CACHE=/app/cache
ENV MPLCONFIGDIR=/app/cache
ENV HOME=/home/user

# Install system dependencies
# (libgl1 replaces libgl1-mesa-glx, which no longer exists on the
# Debian bookworm base that python:3.10-slim now uses)
RUN apt-get update && apt-get install -y \
    build-essential \
    libgl1 \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# Set up a new user named "user" with UID 1000
RUN useradd -m -u 1000 user

# Create app directory
WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Create cache directory with right permissions
RUN mkdir -p /app/cache && chmod 777 /app/cache

# Switch to the "user" user
USER user

# Set home to /home/user
ENV HOME=/home/user
ENV PATH=/home/user/.local/bin:$PATH

# Copy app code (owned by user)
COPY --chown=user . .

# Expose port (HuggingFace default is 7860)
EXPOSE 7860

# Start command: api.py's __main__ block runs uvicorn on the given port
CMD ["python", "api.py", "--port", "7860"]
```
PROMPT.md
ADDED
@@ -0,0 +1,79 @@
````markdown
# CutoutAI Background Remover - Ralph Development Instructions

## Project Goal
Create a flawless background removal tool for the Etsy t-shirt workflow. This tool must produce perfect cutouts suitable for Printify mockups.

## Current Workflow
```
Gemini Image Gen → Slack Approval → BACKGROUND REMOVAL → Printify Mockup → SEO → Etsy/Shopify
```

## Critical Requirements

### 1. FLAWLESS Quality (Non-Negotiable)
- NO patchy faces or artifacts
- NO edge bleeding or halos
- CLEAN edges on hair and fine details
- Must look perfect on t-shirt mockups

### 2. Multi-Element Capture
The tool MUST capture ALL design elements including:
- Main subject
- Bubbles and floating decorations
- Small text or symbols
- Scattered elements (stars, sparkles, etc.)

### 3. API Integration
Must provide:
- Webhook endpoint for n8n (POST /webhook)
- REST API (POST /api/v1/remove)
- Base64 input/output support
- Health check endpoint

## Files to Review and Improve

1. **cutoutai.py** - Core processing logic
   - Uses BiRefNet-matting model (correct choice)
   - Has edge_smooth function (may need enhancement)
   - Check if multi-element capture is working properly

2. **api.py** - FastAPI server
   - Webhook endpoint exists
   - Verify n8n compatibility
   - Add any missing error handling

3. **requirements.txt** - Dependencies
   - Verify all needed packages are listed

## Improvement Tasks

### Priority 1: Quality Enhancement
- [ ] Verify BiRefNet output quality
- [ ] Test edge refinement settings
- [ ] Add adaptive thresholding for multi-element capture
- [ ] Consider adding post-processing for artifact removal

### Priority 2: API Robustness
- [ ] Add proper error responses with details
- [ ] Add request validation
- [ ] Add timeout handling for large images
- [ ] Verify callback_url functionality

### Priority 3: Deployment Ready
- [ ] Add Dockerfile for HuggingFace Spaces
- [ ] Add startup preloading (reduce first-request latency)
- [ ] Add logging for debugging

## Success Criteria
- Process Gemini-generated images with ZERO visible artifacts
- Capture ALL design elements (test with bubble/sparkle designs)
- Return base64 that works in n8n HTTP Request node
- Health endpoint returns proper status

## Reference Documents
See specs/requirements.md for detailed technical specifications.

## Notes
- This will replace the current HuggingFace BiRefNet API in the Etsy workflow
- Priority is QUALITY over speed (mockups need to be perfect)
- Test with white AND non-white backgrounds (Gemini may vary)
````
README.md
ADDED
@@ -0,0 +1,68 @@
````markdown
# CutoutAI - Background Remover

An enhanced background removal tool built on BiRefNet for flawless t-shirt mockup preparation.

## Features

- **Flawless Removal**: No patchy faces, artifacts, or edge issues
- **Multi-Element Capture**: Captures bubbles, decorations, and all design elements
- **API Ready**: Webhook, HTTP API, terminal commands
- **Cloud Hosted**: Designed for n8n, Make, and cloud automation
- **Mockup Quality**: Optimized for Printify t-shirt mockups

## Quick Start

```python
from cutoutai import remove_background

# Basic usage
result = remove_background("design.png")
result.save("design_cutout.png")

# With enhanced settings for complex designs
result = remove_background(
    "design.png",
    capture_all_elements=True,  # Get bubbles, small elements
    edge_refinement=True,       # Smooth edges
    matting_mode="general"      # or "portrait" for faces
)
```

## API Endpoints

| Endpoint | Method | Description |
|----------|--------|-------------|
| `/api/v1/remove` | POST | Remove background from image |
| `/api/v1/batch` | POST | Process multiple images |
| `/health` | GET | Health check |
| `/webhook` | POST | n8n/Make webhook endpoint |

## Workflow Integration

### n8n Webhook
```
POST https://your-host/webhook
Content-Type: multipart/form-data

image: <file>
options: {"capture_all_elements": true}
```

### CLI
```bash
cutoutai process design.png --output cutout.png
cutoutai batch ./designs/ --output ./cutouts/
```

## Quality Settings

| Setting | Description | Use Case |
|---------|-------------|----------|
| `capture_all_elements` | Detect and preserve small elements (bubbles, decorations) | Complex designs |
| `edge_refinement` | Smooth and feather edges | All mockups |
| `matting_mode` | `general`, `portrait`, or `heavy` | Match content type |
| `output_resolution` | Preserve or scale output | Printify requirements |

## License

MIT License - Built on BiRefNet
````
api.py
ADDED
@@ -0,0 +1,388 @@
```python
"""
CutoutAI API Server

FastAPI server providing:
- REST API endpoints for background removal
- Webhook endpoint for n8n/Make integration
- Health check for monitoring
- Startup model preloading
"""

import io
import base64
import time
import logging
from typing import Optional, Literal, Union
from pathlib import Path
from contextlib import asynccontextmanager

from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
from fastapi.responses import Response, JSONResponse
from pydantic import BaseModel, Field

from cutoutai import CutoutAI, MODEL_VARIANTS, logger as cutout_logger

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("CutoutAI-API")

# Global model instances (by variant)
_models: dict[str, CutoutAI] = {}


def get_model(variant: str = "matting") -> CutoutAI:
    """Get or create a model instance for the specified variant."""
    global _models
    if variant not in _models:
        _models[variant] = CutoutAI(model_variant=variant)
        _models[variant].load_model()
    return _models[variant]


# Lifespan context for startup/shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: preload the default model
    print("Preloading matting model...")
    get_model("matting")
    print("Model preloaded and ready!")
    yield
    # Shutdown: cleanup
    _models.clear()


# Initialize FastAPI with lifespan
app = FastAPI(
    title="CutoutAI - Background Remover",
    description="Flawless background removal for t-shirt mockups and design workflows",
    version="1.1.0",
    lifespan=lifespan
)


# Request/Response models
class ProcessOptions(BaseModel):
    model: Literal["general", "matting", "portrait", "lite", "hr", "dynamic"] = "matting"
    capture_all_elements: bool = True
    edge_refinement: bool = True
    edge_radius: int = 2
    threshold: Optional[float] = None
    soft_threshold: bool = False
    remove_artifacts: bool = True
    min_artifact_size: int = 40
    adaptive_threshold: bool = True
    return_mask: bool = False
    output_format: Literal["png", "base64"] = "png"


class WebhookRequest(BaseModel):
    image_base64: Optional[str] = None
    image_url: Optional[str] = None
    options: Optional[ProcessOptions] = None


class HealthResponse(BaseModel):
    status: str
    version: str
    model_loaded: bool
    models_loaded: list[str]
    device: str


# Endpoints
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint for monitoring."""
    global _models
    loaded_models = list(_models.keys())
    device = _models["matting"].device if "matting" in _models else "not loaded"
    return HealthResponse(
        status="healthy",
        version="1.1.0",
        model_loaded=len(_models) > 0,
        models_loaded=loaded_models,
        device=device
    )


@app.get("/")
async def root():
    """Root endpoint with API info."""
    return {
        "name": "CutoutAI - Background Remover",
        "version": "1.1.0",
        "docs": "/docs",
        "health": "/health"
    }


@app.post("/api/v1/remove")
async def remove_bg(
    image: UploadFile = File(...),
    model: str = Form("matting"),
    edge_refinement: bool = Form(True),
    capture_all_elements: bool = Form(True),
    threshold: Optional[float] = Form(None),
    soft_threshold: bool = Form(False),
    remove_artifacts: bool = Form(True),
    adaptive_threshold: bool = Form(True),
    return_mask: bool = Form(False),
    output_format: str = Form("png")
):
    """
    Remove background from uploaded image.

    - **image**: Image file to process
    - **model**: Model variant (matting recommended for designs)
    - **edge_refinement**: Smooth edges for cleaner cutouts
    - **capture_all_elements**: Lower threshold to capture bubbles/small elements
    - **threshold**: Override mask threshold (0.0-1.0)
    - **soft_threshold**: Use soft thresholding
    - **remove_artifacts**: Remove small isolated islands from mask
    - **adaptive_threshold**: Calculate threshold based on image confidence
    - **return_mask**: Return a JSON object with both result and mask
    - **output_format**: "png" for file download, "base64" for JSON response
    """
    start_time = time.time()

    try:
        # Validate model
        if model not in MODEL_VARIANTS:
            raise HTTPException(status_code=400, detail=f"Invalid model: {model}. Available variants: {list(MODEL_VARIANTS.keys())}")

        # Read image
        contents = await image.read()

        # Validate file size (max 10MB)
        if len(contents) > 10 * 1024 * 1024:
            raise HTTPException(status_code=413, detail="Image too large (max 10MB)")

        # Process
        processor = get_model(model)
        result = processor.process(
            contents,
            edge_refinement=edge_refinement,
            capture_all_elements=capture_all_elements,
            threshold=threshold,
            soft_threshold=soft_threshold,
            remove_artifacts=remove_artifacts,
            adaptive_threshold=adaptive_threshold,
            return_mask=return_mask,
            output_format="bytes" if output_format == "png" and not return_mask else "base64"
        )

        processing_time = time.time() - start_time

        if return_mask:
            # result is a dict here
            return JSONResponse({
                "success": True,
                "result_base64": result["result"],
                "mask_base64": result["mask"],
                "threshold_used": round(result["threshold_used"], 4),
                "processing_time_seconds": round(processing_time, 2)
            })

        if output_format == "png":
            return Response(
                content=result,
                media_type="image/png",
                headers={
                    "Content-Disposition": f'attachment; filename="{image.filename}_cutout.png"',
                    "X-Processing-Time": f"{processing_time:.2f}s"
                }
            )
        else:
            return JSONResponse({
                "success": True,
                "image_base64": result,
                "processing_time_seconds": round(processing_time, 2)
            })

    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception("Error processing request")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")


@app.post("/api/v1/batch")
async def batch_remove(
    images: list[UploadFile] = File(...),
    model: str = Form("matting"),
    capture_all_elements: bool = Form(True)
):
    """Process multiple images in batch."""
    start_time = time.time()
    results = []
    processor = get_model(model)

    for img in images:
        contents = await img.read()
        result = processor.process(
            contents,
            capture_all_elements=capture_all_elements,
            output_format="base64"
        )
        results.append({
            "filename": img.filename,
            "image_base64": result
        })

    total_time = time.time() - start_time

    return JSONResponse({
        "success": True,
        "count": len(results),
        "results": results,
        "total_processing_time_seconds": round(total_time, 2)
    })


@app.post("/webhook")
async def webhook_handler(
    request: Request,
    image: Optional[UploadFile] = File(None),
    image_base64: Optional[str] = Form(None),
    image_url: Optional[str] = Form(None),
    model: str = Form("matting"),
    edge_refinement: bool = Form(True),
    capture_all_elements: bool = Form(True),
    edge_radius: int = Form(2),
    threshold: Optional[float] = Form(None),
    soft_threshold: bool = Form(False),
    callback_url: Optional[str] = Form(None)
):
    """
    Webhook endpoint for n8n/Make integration.

    Accepts image via:
    - File upload (image)
    - Base64 encoded string (image_base64)
    - URL to fetch (image_url)

    Returns base64 encoded result for easy workflow integration.
    """
    start_time = time.time()
    logger.info(f"Webhook request received from {request.client.host}")

    try:
        # Check if JSON body instead of form
        # (startswith tolerates a "; charset=..." suffix on the header)
        if request.headers.get("content-type", "").startswith("application/json"):
            try:
                body = await request.json()
                image_base64 = body.get("image_base64", image_base64)
                image_url = body.get("image_url", image_url)
                model = body.get("model", model)
                edge_refinement = body.get("edge_refinement", edge_refinement)
                capture_all_elements = body.get("capture_all_elements", capture_all_elements)
                edge_radius = body.get("edge_radius", edge_radius)
                threshold = body.get("threshold", threshold)
                soft_threshold = body.get("soft_threshold", soft_threshold)
                callback_url = body.get("callback_url", callback_url)
            except Exception as e:
                logger.warning(f"Failed to parse JSON body: {e}")

        # Validate model
        if model not in MODEL_VARIANTS:
            logger.error(f"Invalid model requested: {model}")
            return JSONResponse(
                {"success": False, "error": f"Invalid model: {model}. Available: {list(MODEL_VARIANTS.keys())}"},
                status_code=400
            )

        processor = get_model(model)

        # Get image from one of the sources
        img_data = None
        if image:
            img_data = await image.read()
            logger.info(f"Using uploaded file: {image.filename}")
        elif image_base64:
            try:
                # Handle potential data-URL header in base64
                if "," in image_base64:
                    image_base64 = image_base64.split(",")[1]
                img_data = base64.b64decode(image_base64)
                logger.info("Using base64 image data")
            except Exception as e:
                return JSONResponse({"success": False, "error": f"Invalid base64 data: {e}"}, status_code=400)
        elif image_url:
            import httpx
            logger.info(f"Fetching image from URL: {image_url}")
            async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
                try:
                    response = await client.get(image_url)
                    response.raise_for_status()
                    img_data = response.content
                except httpx.HTTPStatusError as e:
                    return JSONResponse({"success": False, "error": f"Failed to fetch image: {e.response.status_code}"}, status_code=400)
                except Exception as e:
                    return JSONResponse({"success": False, "error": f"Network error: {e}"}, status_code=500)
        else:
            return JSONResponse(
                {"success": False, "error": "No image provided. Use 'image', 'image_base64', or 'image_url'"},
                status_code=400
            )

        # Validate data
        if not img_data:
            return JSONResponse({"success": False, "error": "Empty image data"}, status_code=400)

        # Process
        result = processor.process(
            img_data,
            edge_refinement=edge_refinement,
            capture_all_elements=capture_all_elements,
            edge_radius=edge_radius,
            threshold=threshold,
            soft_threshold=soft_threshold,
            output_format="base64"
        )

        processing_time = time.time() - start_time

        response_data = {
            "success": True,
            "image_base64": result,
            "model_used": model,
            "processing_time_seconds": round(processing_time, 2)
        }

        # If callback URL provided, send result there too
        if callback_url:
            import httpx
            logger.info(f"Sending callback to: {callback_url}")
            async with httpx.AsyncClient(timeout=10.0) as client:
                try:
                    await client.post(callback_url, json=response_data)
                except Exception as e:
                    logger.error(f"Callback failed: {e}")
                    response_data["callback_error"] = str(e)

        return JSONResponse(response_data)

    except Exception as e:
        logger.exception("Unexpected error in webhook handler")
        return JSONResponse(
            {"success": False, "error": str(e)},
            status_code=500
        )


# CLI entry point
if __name__ == "__main__":
    import uvicorn
    import argparse
    import os

    parser = argparse.ArgumentParser(description="CutoutAI API Server")
    parser.add_argument("--host", default="0.0.0.0", help="Host address")
    parser.add_argument("--port", type=int, default=int(os.environ.get("PORT", 8000)), help="Port number")
    args = parser.parse_args()

    uvicorn.run(app, host=args.host, port=args.port)
```
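Calling the /webhook endpoint above from a workflow only requires base64-encoding the image and POSTing JSON. A minimal client sketch follows; the helper name is ours, but the field names match the handler's parameters:

```python
import base64
import json

def build_webhook_payload(image_bytes, model="matting", capture_all_elements=True):
    """Build the JSON body the /webhook handler accepts via its image_base64 path."""
    return {
        "image_base64": base64.b64encode(image_bytes).decode("ascii"),
        "model": model,
        "capture_all_elements": capture_all_elements,
    }

payload = build_webhook_payload(b"\x89PNG fake bytes", model="matting")
body = json.dumps(payload)

# To actually send it (requires the server running and the httpx package):
# import httpx
# r = httpx.post("http://localhost:8000/webhook", json=payload, timeout=120)
# print(r.json()["processing_time_seconds"])
```

The handler also strips a `data:image/png;base64,` prefix if present, so payloads coming straight from browser data URLs work unchanged.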
cutoutai.py
ADDED
|
@@ -0,0 +1,441 @@
"""
CutoutAI - Enhanced Background Removal for Perfect T-Shirt Mockups

Built on BiRefNet for flawless background removal with:
- Multi-element capture (bubbles, decorations, small details)
- Edge refinement for clean cutouts
- Optimized for Printify mockup preparation
"""

import io
import base64
import time
import logging
from typing import Optional, Literal, Union
from pathlib import Path

import torch
import numpy as np
from PIL import Image, ImageFilter
from torchvision import transforms

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("CutoutAI")

# Model variants available
MODEL_VARIANTS = {
    "general": "ZhengPeng7/BiRefNet",            # General use
    "matting": "ZhengPeng7/BiRefNet-matting",    # Best for complex edges
    "portrait": "ZhengPeng7/BiRefNet-portrait",  # Faces/people
    "lite": "ZhengPeng7/BiRefNet_lite",          # Faster, smaller
    "hr": "ZhengPeng7/BiRefNet_HR",              # High resolution (2K)
    "dynamic": "ZhengPeng7/BiRefNet_dynamic",    # Variable resolution
}


# Default image transforms
def get_transforms(size: int = 1024):
    """Get preprocessing transforms for BiRefNet."""
    return transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])


def refine_foreground(image: Image.Image, mask: Image.Image) -> Image.Image:
    """
    Apply the mask to the image with refined edges for flawless cutouts.

    This is critical for t-shirt mockups - it ensures:
    - No patchy faces or artifacts
    - Clean edges on hair and fine details
    - All small elements (bubbles, decorations) are captured
    """
    # Convert to RGBA
    image = image.convert("RGBA")
    mask = mask.convert("L")

    # Resize mask to match image if needed
    if mask.size != image.size:
        mask = mask.resize(image.size, Image.LANCZOS)

    # Apply mask as alpha channel
    result = Image.new("RGBA", image.size, (0, 0, 0, 0))
    result.paste(image, mask=mask)

    return result


def edge_smooth(mask: Image.Image, radius: int = 2, preserve_details: bool = True) -> Image.Image:
    """
    Apply edge smoothing while preserving fine details.

    Args:
        mask: Binary or grayscale mask
        radius: Smoothing intensity (1-5 recommended)
        preserve_details: If True, use morphological ops instead of blur
    """
    if radius <= 0:
        return mask

    if preserve_details:
        # Use morphological operations to clean edges without losing detail.
        # Erosion removes thin protrusions (noise); the kernel size must be odd.
        size = 2 * radius + 1
        eroded = mask.filter(ImageFilter.MinFilter(size))
        # Dilation restores the shape
        smoothed = eroded.filter(ImageFilter.MaxFilter(size))

        # Optional: slight median filter to remove salt-and-pepper noise
        if radius > 1:
            smoothed = smoothed.filter(ImageFilter.MedianFilter(3))
    else:
        # Fall back to Gaussian blur for softer edges
        smoothed = mask.filter(ImageFilter.GaussianBlur(radius=radius))

    return smoothed
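The erode-then-dilate pass above is a morphological opening. As a sanity sketch with arbitrary toy values (a 9x9 mask, one solid block, one stray speck), a `MinFilter` followed by a `MaxFilter` of the same odd kernel size drops the isolated speck while the solid block survives:

```python
import numpy as np
from PIL import Image, ImageFilter

# Toy mask: a solid 5x5 block plus a one-pixel speck of noise
mask = np.zeros((9, 9), dtype=np.uint8)
mask[2:7, 2:7] = 255   # solid block (should survive)
mask[0, 8] = 255       # isolated speck (should be removed)

img = Image.fromarray(mask)
size = 3  # radius=1 -> kernel size 2*1+1
opened = img.filter(ImageFilter.MinFilter(size)).filter(ImageFilter.MaxFilter(size))
out = np.array(opened)
# The speck is gone; the block's footprint is restored by the dilation.
```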
def remove_small_artifacts(mask: Image.Image, min_size: int = 100) -> Image.Image:
    """
    Remove small isolated 'islands' from the mask that are likely artifacts.

    Args:
        mask: Grayscale mask (PIL Image)
        min_size: Minimum pixel area to keep
    """
    # Imported locally so scipy stays an optional dependency
    from scipy import ndimage

    # Convert to binary
    mask_np = np.array(mask) > 128

    # Label connected components
    label_im, nb_labels = ndimage.label(mask_np)

    # Calculate sizes of components
    sizes = ndimage.sum(mask_np, label_im, range(nb_labels + 1))

    # Identify components that are too small
    mask_size = sizes < min_size
    remove_pixel = mask_size[label_im]

    # Remove small components
    mask_np[remove_pixel] = 0

    return Image.fromarray((mask_np * 255).astype(np.uint8))
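The labelling trick above can be sketched on a toy mask (the array shape, region sizes, and `min_size` below are arbitrary test values; SciPy is required, as in the function itself):

```python
import numpy as np
from scipy import ndimage

# Toy mask: one 16-pixel region worth keeping, one 1-pixel artifact
mask = np.zeros((8, 8), dtype=bool)
mask[1:5, 1:5] = True   # 16-pixel region (kept)
mask[6, 6] = True       # 1-pixel artifact (dropped)
min_size = 4

# Label connected regions, measure each, and zero out the small ones
labels, n = ndimage.label(mask)
sizes = ndimage.sum(mask, labels, range(n + 1))
mask[(sizes < min_size)[labels]] = False
# Only the 16-pixel block survives.
```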
def calculate_adaptive_threshold(pred: np.ndarray, base_threshold: float = 0.2) -> float:
    """
    Calculate an adaptive threshold based on the prediction distribution.

    Useful for capturing small design elements without introducing too much noise.
    """
    # Simple adaptive approach: if there are many low-confidence pixels,
    # we might be looking at a design with many small elements (bubbles, etc.).
    # A percentile-based approach or Otsu's method could be used if appropriate.

    # For now, use a simple heuristic: if the 95th percentile is low,
    # it's a very faint design, so lower the threshold further.
    p95 = np.percentile(pred, 95)
    if p95 < 0.5:
        return max(0.05, base_threshold * 0.5)

    return base_threshold


def apply_threshold(pred: np.ndarray, threshold: float = 0.4, soft: bool = False) -> np.ndarray:
    """
    Apply a threshold to the mask for cleaner binary edges.

    Args:
        pred: Prediction array (0-1 range)
        threshold: Cutoff value (pixels below become 0, above become 1)
        soft: If True, use a soft threshold (keep low-confidence regions semi-transparent)

    Returns:
        Thresholded array
    """
    if soft:
        # Sigmoid-like soft thresholding: regions near the threshold are
        # preserved but dimmed. A steepness of 15 balances sharp and soft.
        return 1.0 / (1.0 + np.exp(-15 * (pred - threshold)))

    return np.where(pred > threshold, 1.0, 0.0)
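The hard/soft distinction in `apply_threshold` can be sketched numerically. The prediction values below are toy numbers; the steepness of 15 mirrors the function above:

```python
import numpy as np

pred = np.array([0.1, 0.35, 0.4, 0.45, 0.9])
threshold = 0.4

# Hard: every pixel snaps to fully opaque or fully transparent
hard = np.where(pred > threshold, 1.0, 0.0)

# Soft: a sigmoid keeps pixels near the cutoff semi-transparent,
# with exactly 0.5 alpha at the threshold itself
soft = 1.0 / (1.0 + np.exp(-15 * (pred - threshold)))
```

Far from the cutoff the two agree (0.1 maps to ~0, 0.9 to ~1); near it, soft thresholding avoids the staircase edges that hard cutoffs produce on anti-aliased designs.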
class CutoutAI:
    """
    Enhanced background removal optimized for t-shirt mockup preparation.

    Key features:
    - Captures ALL elements including bubbles and small decorations
    - Flawless edge quality with no artifacts
    - Multiple model options for different use cases
    """

    def __init__(
        self,
        model_variant: Literal["general", "matting", "portrait", "lite", "hr", "dynamic"] = "matting",
        device: Optional[str] = None,
        resolution: int = 1024
    ):
        """
        Initialize CutoutAI.

        Args:
            model_variant: Which BiRefNet model to use
                - "matting": Best for complex edges, hair, fine details (RECOMMENDED)
                - "general": Standard background removal
                - "portrait": Optimized for faces/people
                - "lite": Faster processing, lower quality
                - "hr": High resolution up to 2K
                - "dynamic": Variable resolution support
            device: "cuda", "cpu", or None for auto-detect
            resolution: Processing resolution (1024, or 2048 for the hr model)
        """
        self.model_variant = model_variant
        self.model_name = MODEL_VARIANTS[model_variant]
        self.resolution = resolution

        # Auto-detect device
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.model = None
        self.transforms = get_transforms(resolution)

    def load_model(self):
        """Load the BiRefNet model from HuggingFace."""
        if self.model is not None:
            return

        from transformers import AutoModelForImageSegmentation

        logger.info(f"Loading {self.model_name}...")
        self.model = AutoModelForImageSegmentation.from_pretrained(
            self.model_name,
            trust_remote_code=True
        )
        self.model.to(self.device)
        self.model.eval()
        logger.info(f"Model loaded on {self.device}")

    def process(
        self,
        image: Union[str, Path, Image.Image, bytes],
        capture_all_elements: bool = True,
        edge_refinement: bool = True,
        edge_radius: int = 2,
        threshold: Optional[float] = None,
        soft_threshold: bool = False,
        preserve_details: bool = True,
        remove_artifacts: bool = True,
        min_artifact_size: int = 40,
        adaptive_threshold: bool = True,
        return_mask: bool = False,
        output_format: Literal["pil", "bytes", "base64"] = "pil"
    ) -> Union[Image.Image, bytes, str, dict]:
        """
        Remove the background from an image with enhanced quality.

        Args:
            image: Input image (path, PIL Image, or bytes)
            capture_all_elements: Use a lower threshold to capture bubbles/small elements
            edge_refinement: Apply edge smoothing for cleaner cutouts
            edge_radius: Smoothing intensity (1-5, default 2)
            threshold: Override mask threshold (0.0-1.0, None for auto)
            soft_threshold: Use soft thresholding for smoother transitions
            preserve_details: Use morphological ops instead of blur
            remove_artifacts: Remove small isolated islands from the mask
            min_artifact_size: Minimum pixel area for islands to keep
            adaptive_threshold: Calculate the threshold from prediction confidence
            return_mask: If True, return a dict containing both result and mask
            output_format: Return format ("pil", "bytes", "base64")

        Returns:
            Processed image with transparent background (or dict if return_mask=True)
        """
        start_time = time.time()
        logger.info(f"Processing image with variant: {self.model_variant}")
        self.load_model()

        # Load image
        try:
            if isinstance(image, (str, Path)):
                pil_image = Image.open(image).convert("RGB")
            elif isinstance(image, bytes):
                pil_image = Image.open(io.BytesIO(image)).convert("RGB")
            else:
                pil_image = image.convert("RGB")
        except Exception as e:
            logger.error(f"Failed to load image: {e}")
            raise ValueError(f"Invalid image input: {e}")

        original_size = pil_image.size
        logger.info(f"Image size: {original_size}")

        # Preprocess
        input_tensor = self.transforms(pil_image).unsqueeze(0).to(self.device)

        # Inference
        with torch.no_grad():
            outputs = self.model(input_tensor)

        # Get prediction mask
        if isinstance(outputs, (list, tuple)):
            pred = outputs[0]
        else:
            pred = outputs

        # Convert to numpy
        pred = pred.squeeze().cpu().numpy()

        # Apply thresholding for cleaner edges: a lower threshold captures
        # more (bubbles, small elements), a higher threshold is more selective.
        if threshold is not None:
            mask_threshold = threshold
        elif capture_all_elements:
            mask_threshold = 0.2  # Base low threshold
            if adaptive_threshold:
                mask_threshold = calculate_adaptive_threshold(pred, mask_threshold)
        else:
            mask_threshold = 0.4  # Standard threshold

        logger.info(f"Using threshold: {mask_threshold:.4f} (soft: {soft_threshold})")
        pred = apply_threshold(pred, mask_threshold, soft=soft_threshold)

        # Convert to PIL mask
        pred = (pred * 255).astype(np.uint8)
        mask = Image.fromarray(pred).resize(original_size, Image.LANCZOS)

        # Remove small artifacts if requested
        if remove_artifacts:
            logger.info(f"Removing small artifacts (min_size: {min_artifact_size})")
            try:
                mask = remove_small_artifacts(mask, min_size=min_artifact_size)
            except ImportError:
                logger.warning("scipy not installed, skipping artifact removal")

        # Edge refinement for cleaner cutouts
        if edge_refinement:
            logger.info(f"Applying edge refinement (radius: {edge_radius})")
            mask = edge_smooth(mask, radius=edge_radius, preserve_details=preserve_details)

        # Apply mask to get the final result
        result = refine_foreground(pil_image, mask)

        # Record processing time
        self._last_processing_time = time.time() - start_time
        logger.info(f"Processing completed in {self._last_processing_time:.2f}s")

        # Prepare outputs
        if return_mask:
            return {
                "result": self._format_output(result, output_format),
                "mask": self._format_output(mask, output_format),
                "threshold_used": mask_threshold,
                "processing_time": self._last_processing_time
            }

        return self._format_output(result, output_format)

    def _format_output(self, image: Image.Image, output_format: str) -> Union[Image.Image, bytes, str]:
        """Format a PIL Image to the requested output format."""
        if output_format == "pil":
            return image
        elif output_format == "bytes":
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            return buffer.getvalue()
        elif output_format == "base64":
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            return base64.b64encode(buffer.getvalue()).decode()
        return image

    @property
    def last_processing_time(self) -> float:
        """Get the processing time of the last operation in seconds."""
        return getattr(self, '_last_processing_time', 0.0)

    def process_batch(self, images: list, **kwargs) -> list:
        """Process multiple images."""
        return [self.process(img, **kwargs) for img in images]


# Convenience function
def remove_background(
    image: Union[str, Path, Image.Image, bytes],
    model: str = "matting",
    capture_all_elements: bool = True,
    edge_refinement: bool = True,
    **kwargs
) -> Image.Image:
    """
    Quick function to remove the background from an image.

    Args:
        image: Input image
        model: Model variant ("matting" recommended for t-shirt designs)
        capture_all_elements: Capture bubbles and small elements (uses a lower threshold)
        edge_refinement: Smooth edges for clean mockups

    Returns:
        PIL Image with transparent background
    """
    processor = CutoutAI(model_variant=model)
    return processor.process(
        image,
        capture_all_elements=capture_all_elements,
        edge_refinement=edge_refinement,
        **kwargs
    )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="CutoutAI Background Remover")
    parser.add_argument("input", help="Input image path")
    parser.add_argument("-o", "--output", help="Output path", default=None)
    parser.add_argument("-m", "--model", choices=list(MODEL_VARIANTS.keys()),
                        default="matting", help="Model variant")
    parser.add_argument("--no-edge-refinement", action="store_true",
                        help="Disable edge refinement")
    parser.add_argument("--threshold", type=float, default=None,
                        help="Mask threshold (0.0-1.0)")
    parser.add_argument("--capture-all", action="store_true", default=True,
                        help="Use a lower threshold to capture small elements")

    args = parser.parse_args()

    # Process
    result = remove_background(
        args.input,
        model=args.model,
        edge_refinement=not args.no_edge_refinement,
        capture_all_elements=args.capture_all,
        threshold=args.threshold
    )

    # Save
    output_path = args.output or args.input.rsplit(".", 1)[0] + "_cutout.png"
    result.save(output_path)
    print(f"Saved to: {output_path}")
requirements.txt
ADDED
@@ -0,0 +1,22 @@
# CutoutAI Dependencies

# Core ML
torch>=2.0.0
torchvision>=0.15.0
transformers>=4.35.0
timm>=0.9.0
kornia>=0.7.0

# Image processing
Pillow>=10.0.0
numpy>=1.24.0
scipy>=1.10.0

# API server
fastapi>=0.104.0
uvicorn[standard]>=0.24.0
python-multipart>=0.0.6
httpx>=0.25.0

# Optional: for HuggingFace model loading
huggingface-hub>=0.19.0
run-claude-analysis.bat
ADDED
@@ -0,0 +1,23 @@
@echo off
echo ============================================
echo CutoutAI Analysis via Claude Code + CCR
echo ============================================
echo.

:: Set environment for CCR
set CLAUDE_BASE_URL=http://127.0.0.1:3456
set ANTHROPIC_BASE_URL=http://127.0.0.1:3456

:: Change to project directory
cd /d "C:\Users\jonat_cau4\.gemini\antigravity\scratch\background removal tool"

echo Current directory: %CD%
echo CCR URL: %CLAUDE_BASE_URL%
echo.
echo Starting Claude Code with analysis prompt...
echo.

:: Run Claude Code with the analysis task
claude -p "Read and analyze all files in this project: PROMPT.md, cutoutai.py, api.py, @fix_plan.md, specs/requirements.md, and README.md. Provide a COMPREHENSIVE ANALYSIS including: 1) Code quality assessment, 2) Edge handling for t-shirt mockups - is thresholding correct?, 3) Multi-element capture (bubbles) - is the threshold low enough?, 4) API robustness for n8n/Make integration, 5) Startup preloading implementation, 6) Any bugs or issues found. Then provide SPECIFIC IMPROVEMENT RECOMMENDATIONS with code examples. After analysis, ask if I want you to implement the improvements."

pause
specs/requirements.md
ADDED
@@ -0,0 +1,134 @@
# CutoutAI Background Remover - Project Specifications

## Project Overview

**Name**: CutoutAI - Background Remover
**Purpose**: Flawless background removal for t-shirt mockup preparation in the Etsy workflow
**Core Tech**: BiRefNet (BiRefNet-matting model)
**Deployment**: Cloud-hosted (n8n/Make integration), webhook API, terminal CLI

## Current Workflow (Etsy T-Shirt Pipeline)

```
Gemini Image Gen → Slack Approval → [BACKGROUND REMOVAL] → Printify Mockup → SEO → Etsy/Shopify
                         ↓
             Feedback loop (re-prompt if needed)
```

## Requirements

### Functional Requirements

1. **Flawless Quality**
   - NO patchy faces or artifacts
   - Clean edges on hair and fine details
   - Capture ALL elements including:
     - Bubbles
     - Small decorations
     - Floating elements
     - Text overlays

2. **Input Handling**
   - Accept various image qualities from Gemini
   - Handle non-white backgrounds (prepare for anything)
   - Process images WITH multiple small elements

3. **API/Integration**
   - Webhook endpoint for n8n
   - Base64 input/output for easy workflow integration
   - REST API for batch processing
   - Terminal CLI for manual use

4. **Cloud Deployment**
   - Host on HuggingFace Spaces or Google Cloud Run
   - Zero (or minimal) cold-start penalty
   - Handle concurrent requests

### Non-Functional Requirements

1. **Performance**
   - Sub-10-second processing for standard images
   - Batch processing capability

2. **Reliability**
   - Health check endpoint
   - Error reporting to callback URLs

## Technical Specifications

### Recommended Model Settings

```python
# BiRefNet-matting is CRITICAL for edge quality
model_variant = "matting"  # NOT "general"

# Resolution considerations
# - 1024x1024 for standard processing
# - 2048x2048 for high-res (BiRefNet_HR)

# Edge refinement is REQUIRED for mockups
edge_refinement = True
edge_radius = 2  # Subtle smoothing
```

### Known Issues to Address

1. **Artifact Prevention**
   - Downsampling large images can cause artifacts
   - Solution: use an input resolution appropriate to the model
   - Consider super-resolution post-processing if needed

2. **Multi-Element Capture**
   - BiRefNet's bilateral reference should capture small elements
   - May need lower detection thresholds for bubbles/decorations

3. **Edge Quality**
   - The `refine_foreground` function is essential
   - The edge smoothing radius should be configurable

## API Specification

### Endpoints Required

```yaml
POST /api/v1/remove:
  input: multipart/form-data OR JSON with base64
  params:
    - model: string (matting|general|portrait|hr)
    - edge_refinement: boolean
    - edge_radius: int (1-5)
    - output_format: string (png|base64)
  output: PNG file OR JSON with base64

POST /webhook:
  input:
    - image: file upload OR
    - image_base64: string OR
    - image_url: string
  output: JSON with base64 image

GET /health:
  output: JSON status
```

### n8n Integration

The webhook must be compatible with the n8n HTTP Request node:
- Accept multipart/form-data
- Return JSON with an `image_base64` field
- Support a `callback_url` parameter for async notifications

## Files to Review

1. `cutoutai.py` - Core background removal logic
2. `api.py` - FastAPI server and endpoints
3. `requirements.txt` - Dependencies

## Success Criteria

- [ ] Process Gemini-generated designs without artifacts
- [ ] Capture bubbles and small decorative elements
- [ ] Clean edges suitable for Printify mockups
- [ ] Working webhook for n8n integration
- [ ] Base64 input/output for workflow compatibility
- [ ] Health check endpoint for monitoring
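To illustrate the base64 contract the spec requires, here is a minimal sketch of building the webhook's JSON body. The PNG bytes are a stand-in for a real image and the callback URL is a placeholder, not a real endpoint:

```python
import base64
import json

png_stub = b"\x89PNG\r\n\x1a\n"  # stand-in for real PNG bytes
payload = {
    "image_base64": base64.b64encode(png_stub).decode(),
    "model": "matting",
    "callback_url": "https://example.com/hook",  # hypothetical placeholder
}
body = json.dumps(payload)

# Round-trip: the consumer decodes the same field name on the way back
restored = base64.b64decode(json.loads(body)["image_base64"])
```

An n8n HTTP Request node would send `body` as JSON and read `image_base64` back out of the response the same way.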
test_cutout.py
ADDED
@@ -0,0 +1,70 @@
import logging

from PIL import Image, ImageDraw

import cutoutai

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("TestCutoutAI")


def create_test_image(path="test_input.png"):
    """Create a synthetic test image with bubbles and a central object."""
    # 512x512 light-gray background
    img = Image.new("RGB", (512, 512), (240, 240, 240))
    draw = ImageDraw.Draw(img)

    # Draw a "subject" (blue circle)
    draw.ellipse([150, 150, 362, 362], fill=(0, 0, 255), outline=(0, 0, 0))

    # Draw "bubbles" (small circles)
    draw.ellipse([50, 50, 80, 80], fill=(200, 200, 255), outline=(100, 100, 100))
    draw.ellipse([400, 100, 430, 130], fill=(200, 200, 255), outline=(100, 100, 100))
    draw.ellipse([100, 400, 140, 440], fill=(255, 200, 200), outline=(100, 100, 100))

    # Draw some "fine detail" (a thin line)
    draw.line([256, 0, 256, 150], fill=(0, 0, 0), width=1)

    img.save(path)
    logger.info(f"Created test image: {path}")
    return path


def test_processing():
    """Test the core processing logic."""
    input_path = create_test_image()

    # Use the 'lite' variant for faster testing. Note: loading the model
    # takes time and needs internet access plus a working torch install;
    # in a restricted environment this may fail.
    try:
        processor = cutoutai.CutoutAI(model_variant="lite")

        logger.info("Running process()...")
        result = processor.process(
            input_path,
            capture_all_elements=True,
            edge_refinement=True,
            edge_radius=2,
            output_format="pil"
        )

        output_path = "test_output.png"
        result.save(output_path)
        logger.info(f"Saved result to: {output_path}")

        # Check that the output is RGBA
        if result.mode == "RGBA":
            logger.info("SUCCESS: Output is in RGBA mode.")
        else:
            logger.error(f"FAILURE: Output mode is {result.mode}, expected RGBA.")

    except Exception as e:
        logger.error(f"Error during processing: {e}")
        logger.info("Note: This test requires torch and transformers to be installed and working.")


if __name__ == "__main__":
    test_processing()