| """ |
| CutoutAI - Enhanced Background Removal for Perfect T-Shirt Mockups |
| |
| Built on BiRefNet for flawless background removal with: |
| - Multi-element capture (bubbles, decorations, small details) |
| - Edge refinement for clean cutouts |
| - Optimized for Printify mockup preparation |
| """ |
|
|
| import io |
| import base64 |
| import time |
| import logging |
| from typing import Optional, Literal, Union |
| from pathlib import Path |
|
|
| import torch |
| import numpy as np |
| from PIL import Image, ImageFilter |
| from torchvision import transforms |
|
|
| |
| logging.basicConfig( |
| level=logging.INFO, |
| format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' |
| ) |
| logger = logging.getLogger("CutoutAI") |
|
|
| |
| MODEL_VARIANTS = { |
| "general": "ZhengPeng7/BiRefNet", |
| "matting": "ZhengPeng7/BiRefNet-matting", |
| "portrait": "ZhengPeng7/BiRefNet-portrait", |
| "lite": "ZhengPeng7/BiRefNet_lite", |
| "hr": "ZhengPeng7/BiRefNet_HR", |
| "dynamic": "ZhengPeng7/BiRefNet_dynamic", |
| } |
|
|
| |
| def get_transforms(size: int = 1024): |
| """Get preprocessing transforms for BiRefNet.""" |
| return transforms.Compose([ |
| transforms.Resize((size, size)), |
| transforms.ToTensor(), |
| transforms.Normalize( |
| mean=[0.485, 0.456, 0.406], |
| std=[0.229, 0.224, 0.225] |
| ) |
| ]) |
|
|
|
|
| def refine_foreground(image: Image.Image, mask: Image.Image) -> Image.Image: |
| """ |
| Apply mask to image with refined edges for flawless cutouts. |
| |
| This is critical for t-shirt mockups - ensures: |
| - No patchy faces or artifacts |
| - Clean edges on hair and fine details |
| - All small elements (bubbles, decorations) captured |
| """ |
| |
| image = image.convert("RGBA") |
| mask = mask.convert("L") |
| |
| |
| if mask.size != image.size: |
| mask = mask.resize(image.size, Image.LANCZOS) |
| |
| |
| result = Image.new("RGBA", image.size, (0, 0, 0, 0)) |
| result.paste(image, mask=mask) |
| |
| return result |
|
|
|
|
| def edge_smooth(mask: Image.Image, radius: int = 2, preserve_details: bool = True) -> Image.Image: |
| """ |
| Apply edge smoothing while preserving fine details. |
| |
| Args: |
| mask: Binary or grayscale mask |
| radius: Smoothing intensity (1-5 recommended) |
| preserve_details: If True, use morphological ops instead of blur |
| """ |
| if radius <= 0: |
| return mask |
|
|
| if preserve_details: |
| |
| |
| |
| size = 2 * radius + 1 |
| eroded = mask.filter(ImageFilter.MinFilter(size)) |
| |
| smoothed = eroded.filter(ImageFilter.MaxFilter(size)) |
|
|
| |
| if radius > 1: |
| smoothed = smoothed.filter(ImageFilter.MedianFilter(3)) |
| else: |
| |
| smoothed = mask.filter(ImageFilter.GaussianBlur(radius=radius)) |
|
|
| return smoothed |
|
|
|
|
| def remove_small_artifacts(mask: Image.Image, min_size: int = 100) -> Image.Image: |
| """ |
| Remove small isolated 'islands' from the mask that are likely artifacts. |
| |
| Args: |
| mask: Grayscale mask (PIL Image) |
| min_size: Minimum pixel area to keep |
| """ |
| import numpy as np |
| from scipy import ndimage |
|
|
| |
| mask_np = np.array(mask) > 128 |
|
|
| |
| label_im, nb_labels = ndimage.label(mask_np) |
|
|
| |
| sizes = ndimage.sum(mask_np, label_im, range(nb_labels + 1)) |
|
|
| |
| mask_size = sizes < min_size |
| remove_pixel = mask_size[label_im] |
|
|
| |
| mask_np[remove_pixel] = 0 |
|
|
| return Image.fromarray((mask_np * 255).astype(np.uint8)) |
|
|
|
|
| def calculate_adaptive_threshold(pred: np.ndarray, base_threshold: float = 0.2) -> float: |
| """ |
| Calculate an adaptive threshold based on the prediction distribution. |
| Useful for capturing small design elements without introducing too much noise. |
| """ |
| |
| |
| |
|
|
| |
| |
| p95 = np.percentile(pred, 95) |
| if p95 < 0.5: |
| return max(0.05, base_threshold * 0.5) |
|
|
| return base_threshold |
|
|
|
|
| def apply_threshold(pred: np.ndarray, threshold: float = 0.4, soft: bool = False) -> np.ndarray: |
| """ |
| Apply threshold to mask for cleaner binary edges. |
| |
| Args: |
| pred: Prediction array (0-1 range) |
| threshold: Cutoff value (pixels below become 0, above become 1) |
| soft: If True, use a soft threshold (keep low confidence regions as semi-transparent) |
| |
| Returns: |
| Thresholded array |
| """ |
| if soft: |
| |
| |
| |
| return 1.0 / (1.0 + np.exp(-15 * (pred - threshold))) |
|
|
| return np.where(pred > threshold, 1.0, 0.0) |
|
|
|
|
| class CutoutAI: |
| """ |
| Enhanced background removal optimized for t-shirt mockup preparation. |
| |
| Key features: |
| - Captures ALL elements including bubbles, small decorations |
| - Flawless edge quality with no artifacts |
| - Multiple model options for different use cases |
| """ |
| |
| def __init__( |
| self, |
| model_variant: Literal["general", "matting", "portrait", "lite", "hr", "dynamic"] = "matting", |
| device: Optional[str] = None, |
| resolution: int = 1024 |
| ): |
| """ |
| Initialize CutoutAI. |
| |
| Args: |
| model_variant: Which BiRefNet model to use |
| - "matting": Best for complex edges, hair, fine details (RECOMMENDED) |
| - "general": Standard background removal |
| - "portrait": Optimized for faces/people |
| - "lite": Faster processing, lower quality |
| - "hr": High resolution up to 2K |
| - "dynamic": Variable resolution support |
| device: "cuda", "cpu", or None for auto-detect |
| resolution: Processing resolution (1024 or 2048 for hr model) |
| """ |
| self.model_variant = model_variant |
| self.model_name = MODEL_VARIANTS[model_variant] |
| self.resolution = resolution |
| |
| |
| if device is None: |
| self.device = "cuda" if torch.cuda.is_available() else "cpu" |
| else: |
| self.device = device |
| |
| self.model = None |
| self.transforms = get_transforms(resolution) |
| |
| def load_model(self): |
| """Load the BiRefNet model from HuggingFace.""" |
| if self.model is not None: |
| return |
| |
| from transformers import AutoModelForImageSegmentation |
| |
| print(f"Loading {self.model_name}...") |
| self.model = AutoModelForImageSegmentation.from_pretrained( |
| self.model_name, |
| trust_remote_code=True |
| ) |
| self.model.to(self.device) |
| self.model.eval() |
| print(f"Model loaded on {self.device}") |
| |
| def process( |
| self, |
| image: Union[str, Path, Image.Image, bytes], |
| capture_all_elements: bool = True, |
| edge_refinement: bool = True, |
| edge_radius: int = 2, |
| threshold: Optional[float] = None, |
| soft_threshold: bool = False, |
| preserve_details: bool = True, |
| remove_artifacts: bool = True, |
| min_artifact_size: int = 40, |
| adaptive_threshold: bool = True, |
| return_mask: bool = False, |
| output_format: Literal["pil", "bytes", "base64"] = "pil" |
| ) -> Union[Image.Image, bytes, str, dict]: |
| """ |
| Remove background from image with enhanced quality. |
| |
| Args: |
| image: Input image (path, PIL Image, or bytes) |
| capture_all_elements: Use lower threshold to capture bubbles/small elements |
| edge_refinement: Apply edge smoothing for cleaner cutouts |
| edge_radius: Smoothing intensity (1-5, default 2) |
| threshold: Override mask threshold (0.0-1.0, None for auto) |
| soft_threshold: Use soft thresholding for smoother transitions |
| preserve_details: Use morphological ops instead of blur |
| remove_artifacts: Remove small isolated islands from mask |
| min_artifact_size: Minimum pixel area for islands to keep |
| adaptive_threshold: Calculate threshold based on image confidence |
| return_mask: If True, return a dict containing both result and mask |
| output_format: Return format ("pil", "bytes", "base64") |
| |
| Returns: |
| Processed image with transparent background (or dict if return_mask=True) |
| """ |
| start_time = time.time() |
| logger.info(f"Processing image with variant: {self.model_variant}") |
| self.load_model() |
|
|
| |
| try: |
| if isinstance(image, (str, Path)): |
| pil_image = Image.open(image).convert("RGB") |
| elif isinstance(image, bytes): |
| pil_image = Image.open(io.BytesIO(image)).convert("RGB") |
| else: |
| pil_image = image.convert("RGB") |
| except Exception as e: |
| logger.error(f"Failed to load image: {e}") |
| raise ValueError(f"Invalid image input: {e}") |
|
|
| original_size = pil_image.size |
| logger.info(f"Image size: {original_size}") |
|
|
| |
| input_tensor = self.transforms(pil_image).unsqueeze(0).to(self.device) |
|
|
| |
| with torch.no_grad(): |
| outputs = self.model(input_tensor) |
|
|
| |
| if isinstance(outputs, (list, tuple)): |
| pred = outputs[0] |
| else: |
| pred = outputs |
|
|
| |
| pred = pred.squeeze().cpu().numpy() |
|
|
| |
| |
| |
| if threshold is not None: |
| mask_threshold = threshold |
| elif capture_all_elements: |
| mask_threshold = 0.2 |
| if adaptive_threshold: |
| mask_threshold = calculate_adaptive_threshold(pred, mask_threshold) |
| else: |
| mask_threshold = 0.4 |
|
|
| logger.info(f"Using threshold: {mask_threshold:.4f} (soft: {soft_threshold})") |
| pred = apply_threshold(pred, mask_threshold, soft=soft_threshold) |
|
|
| |
| pred = (pred * 255).astype(np.uint8) |
| mask = Image.fromarray(pred).resize(original_size, Image.LANCZOS) |
|
|
| |
| if remove_artifacts: |
| logger.info(f"Removing small artifacts (min_size: {min_artifact_size})") |
| try: |
| mask = remove_small_artifacts(mask, min_size=min_artifact_size) |
| except ImportError: |
| logger.warning("Scipy not installed, skipping artifact removal") |
|
|
| |
| if edge_refinement: |
| logger.info(f"Applying edge refinement (radius: {edge_radius})") |
| mask = edge_smooth(mask, radius=edge_radius, preserve_details=preserve_details) |
|
|
| |
| result = refine_foreground(pil_image, mask) |
|
|
| |
| self._last_processing_time = time.time() - start_time |
| logger.info(f"Processing completed in {self._last_processing_time:.2f}s") |
|
|
| |
| if return_mask: |
| return { |
| "result": self._format_output(result, output_format), |
| "mask": self._format_output(mask, output_format), |
| "threshold_used": mask_threshold, |
| "processing_time": self._last_processing_time |
| } |
|
|
| return self._format_output(result, output_format) |
|
|
| def _format_output(self, image: Image.Image, output_format: str) -> Union[Image.Image, bytes, str]: |
| """Format PIL Image to requested output format.""" |
| if output_format == "pil": |
| return image |
| elif output_format == "bytes": |
| buffer = io.BytesIO() |
| image.save(buffer, format="PNG") |
| return buffer.getvalue() |
| elif output_format == "base64": |
| buffer = io.BytesIO() |
| image.save(buffer, format="PNG") |
| return base64.b64encode(buffer.getvalue()).decode() |
| return image |
|
|
| @property |
| def last_processing_time(self) -> float: |
| """Get the processing time of the last operation in seconds.""" |
| return getattr(self, '_last_processing_time', 0.0) |
| |
| def process_batch( |
| self, |
| images: list, |
| **kwargs |
| ) -> list: |
| """Process multiple images.""" |
| return [self.process(img, **kwargs) for img in images] |
|
|
|
|
| |
| def remove_background( |
| image: Union[str, Path, Image.Image, bytes], |
| model: str = "matting", |
| capture_all_elements: bool = True, |
| edge_refinement: bool = True, |
| **kwargs |
| ) -> Image.Image: |
| """ |
| Quick function to remove background from an image. |
| |
| Args: |
| image: Input image |
| model: Model variant ("matting" recommended for t-shirt designs) |
| capture_all_elements: Capture bubbles, small elements (uses lower threshold) |
| edge_refinement: Smooth edges for clean mockups |
| |
| Returns: |
| PIL Image with transparent background |
| """ |
| processor = CutoutAI(model_variant=model) |
| return processor.process( |
| image, |
| capture_all_elements=capture_all_elements, |
| edge_refinement=edge_refinement, |
| **kwargs |
| ) |
|
|
|
|
| if __name__ == "__main__": |
| import argparse |
| |
| parser = argparse.ArgumentParser(description="CutoutAI Background Remover") |
| parser.add_argument("input", help="Input image path") |
| parser.add_argument("-o", "--output", help="Output path", default=None) |
| parser.add_argument("-m", "--model", choices=list(MODEL_VARIANTS.keys()), |
| default="matting", help="Model variant") |
| parser.add_argument("--no-edge-refinement", action="store_true", |
| help="Disable edge refinement") |
| parser.add_argument("--threshold", type=float, default=None, |
| help="Mask threshold (0.0-1.0)") |
| parser.add_argument("--capture-all", action="store_true", default=True, |
| help="Use lower threshold to capture small elements") |
| |
| args = parser.parse_args() |
| |
| |
| result = remove_background( |
| args.input, |
| model=args.model, |
| edge_refinement=not args.no_edge_refinement, |
| capture_all_elements=args.capture_all, |
| threshold=args.threshold |
| ) |
| |
| |
| output_path = args.output or args.input.rsplit(".", 1)[0] + "_cutout.png" |
| result.save(output_path) |
| print(f"Saved to: {output_path}") |
|
|