---
license: mit
tags:
- tattoo-segmentation
- image-segmentation
- pytorch
- unet
- edge-detection
---

# Deep Tattoo Segmentation v5/v7

Edge-Aware Attention U-Net for precise tattoo extraction with a transparent background.

## Models

- `edge_aware_v3_clahe_best.pth`: **v5 Model** - best overall performance (Val Dice: **90.50%**)
  - Edge-Aware Attention U-Net architecture
  - CLAHE preprocessing for lighting invariance
  - Trained on 24 manual labels + 165 auto-generated masks
  - Test-Time Augmentation (TTA) support
- `edge_aware_v7_samrefiner_best.pth`: **v7 Model with SAMRefiner** (Val Dice: 74.25%)
  - Same architecture as v5
  - Trained on 122 SAMRefiner-refined masks + 25 v2 fallback masks
  - SAMRefiner (ICLR 2025) used for mask refinement
  - Better edge quality, but a lower validation score because the validation set contains only v2 masks
- `edge_aware_improved_best.pth`: Base Model (Val Dice: 79.17%)
  - Foundation model trained before hybrid training
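
The Test-Time Augmentation (TTA) support mentioned above can be sketched as flip-averaging: predict on the original and a horizontally flipped input, then average the un-flipped predictions. This is a minimal illustration; `predict_with_tta` and the toy predictor below are illustrative names, not the repository's API.

```python
import numpy as np

def predict_with_tta(predict, image):
    """Average predictions over the original and horizontally flipped input.

    A minimal TTA sketch; the repository's TTA may use more augmentations.
    """
    pred = predict(image)
    pred_flipped = predict(image[:, ::-1])[:, ::-1]  # flip back before averaging
    return (pred + pred_flipped) / 2.0

# Toy stand-in "model": maps pixel intensity to a pseudo-probability
predict = lambda img: img.astype(np.float32) / 255.0

image = np.random.randint(0, 256, (8, 8), dtype=np.uint8)
tta_mask = predict_with_tta(predict, image)  # same shape as a single prediction
```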

## Key Features

- **High Accuracy**: 90.50% Dice coefficient on the validation set
- **Edge Detection**: Specialized edge-aware attention mechanism
- **Lighting Invariant**: CLAHE preprocessing handles varied lighting conditions
- **Transparent Output**: Extracts tattoos with an alpha channel for a transparent background
- **Production Ready**: Optimized for inference with TTA

## Usage

```python
import torch
import cv2
import numpy as np
from huggingface_hub import hf_hub_download

# Download the v5 model weights
model_path = hf_hub_download(
    repo_id="jun710/deep-tattoo",
    filename="edge_aware_v3_clahe_best.pth"
)

# Load the checkpoint; instantiate EdgeAwareAttentionUNet from the repository
# and load these weights into it with model.load_state_dict(...)
checkpoint = torch.load(model_path, map_location='cpu')

# Preprocess with CLAHE (applied here to the lightness channel, a common choice)
image = cv2.imread("tattoo.jpg")
lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
lab[..., 0] = clahe.apply(lab[..., 0])
image = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)

# ... apply the model and extract the tattoo
```
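
The transparent-background extraction can be sketched by stacking the predicted binary mask onto the image as an alpha channel. `extract_with_alpha` is a hypothetical helper for illustration, not part of the repository, and the image and mask below are synthetic stand-ins for real model output.

```python
import numpy as np

def extract_with_alpha(image_bgr, mask):
    """Attach a binary mask (255 = tattoo, 0 = background) as the alpha channel."""
    return np.dstack([image_bgr, mask])

# Synthetic stand-ins for a real photo and a predicted mask
image = np.full((64, 64, 3), 200, dtype=np.uint8)
mask = np.zeros((64, 64), dtype=np.uint8)
mask[16:48, 16:48] = 255  # pretend the tattoo occupies the center

bgra = extract_with_alpha(image, mask)
# Save with cv2.imwrite("tattoo_transparent.png", bgra); PNG preserves alpha
```

Pixels where the mask is 0 become fully transparent, so only the tattoo remains visible when the PNG is composited.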

## Model Architecture

- **Encoder**: EfficientNet-B3 backbone with an edge-detection branch
- **Decoder**: Attention-based skip connections
- **Output**: Binary segmentation mask (tattoo vs. background)
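
As a rough, single-channel illustration of the attention-gated skip connections: the decoder's gating signal modulates the encoder's skip features before they are merged. The real gates are learned convolutions; `attention_gate` below is a hand-wired stand-in.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def attention_gate(skip, gating):
    """Reweight encoder skip features by attention coefficients derived from
    the decoder's gating signal (simplified additive attention)."""
    attn = sigmoid(skip + gating)   # coefficients in (0, 1), one per pixel
    return skip * attn              # suppress irrelevant skip activations

skip = np.random.randn(16, 16)      # encoder feature map (toy, single channel)
gating = np.random.randn(16, 16)    # decoder gating signal at the same scale
gated = attention_gate(skip, gating)
```

Because the coefficients lie strictly between 0 and 1, the gate can only attenuate skip activations, which is what lets the decoder ignore background regions.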

## Training Details

- Image Size: 256×256
- Batch Size: 8
- Optimizer: Adam (lr = 1e-4)
- Loss: Boundary-Aware Loss (Dice + BCE + Edge)
- Augmentation: Strong geometric and color transformations
- Training Data: 189 images (24 manual + 165 auto-generated)
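
The boundary-aware loss can be sketched as a sum of Dice, BCE, and an edge term. The exact formulation and weights in the repository may differ; in particular, the gradient-difference edge term and the `edge_weight` parameter below are assumptions for illustration.

```python
import numpy as np

def dice_loss(pred, target, eps=1e-6):
    inter = (pred * target).sum()
    return 1.0 - (2.0 * inter + eps) / (pred.sum() + target.sum() + eps)

def bce_loss(pred, target, eps=1e-7):
    p = np.clip(pred, eps, 1.0 - eps)
    return -(target * np.log(p) + (1.0 - target) * np.log(1.0 - p)).mean()

def edge_loss(pred, target):
    # Penalize differences in mask gradients, emphasizing boundary accuracy
    dx = np.abs(np.diff(pred, axis=1) - np.diff(target, axis=1)).mean()
    dy = np.abs(np.diff(pred, axis=0) - np.diff(target, axis=0)).mean()
    return dx + dy

def boundary_aware_loss(pred, target, edge_weight=1.0):
    return (dice_loss(pred, target)
            + bce_loss(pred, target)
            + edge_weight * edge_loss(pred, target))

target = np.zeros((32, 32)); target[8:24, 8:24] = 1.0
good = boundary_aware_loss(target.copy(), target)  # near-perfect prediction
bad = boundary_aware_loss(1.0 - target, target)    # inverted prediction
```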

## Performance

- Validation Dice: 90.50%
- Test Coverage: 52 diverse images
- Success Rate: ~97% on typical tattoos
- Limitation: Fine details on very delicate outlines (~3% of cases)

## Repository

https://github.com/enjius/deep-tattoo