Upload 25 files
Browse files- README.md +49 -4
- app.py +956 -0
- cdl_smoothing.py +497 -0
- clothes_segmentation.py +292 -0
- color_matching.py +698 -0
- core.py +356 -0
- examples/beach.jpg +0 -0
- examples/field.jpg +0 -0
- examples/sky.jpg +0 -0
- face_comparison.py +246 -0
- folder_paths.py +22 -0
- human_parts_segmentation.py +322 -0
- models/RMBG/segformer_clothes/.cache/huggingface/.gitignore +1 -0
- models/RMBG/segformer_clothes/.cache/huggingface/download/config.json.lock +0 -0
- models/RMBG/segformer_clothes/.cache/huggingface/download/config.json.metadata +3 -0
- models/RMBG/segformer_clothes/.cache/huggingface/download/model.safetensors.lock +0 -0
- models/RMBG/segformer_clothes/.cache/huggingface/download/model.safetensors.metadata +3 -0
- models/RMBG/segformer_clothes/.cache/huggingface/download/preprocessor_config.json.lock +0 -0
- models/RMBG/segformer_clothes/.cache/huggingface/download/preprocessor_config.json.metadata +3 -0
- models/RMBG/segformer_clothes/config.json +110 -0
- models/RMBG/segformer_clothes/model.safetensors +3 -0
- models/RMBG/segformer_clothes/preprocessor_config.json +23 -0
- models/onnx/human-parts/deeplabv3p-resnet50-human.onnx +3 -0
- requirements.txt +30 -0
- spaces.py +12 -0
README.md
CHANGED
|
@@ -1,12 +1,57 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: gray
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 5.
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: LaDeco
|
| 3 |
+
emoji: 👀
|
| 4 |
colorFrom: gray
|
| 5 |
+
colorTo: blue
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 5.31.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
short_description: 'LaDeco: A tool to analyze visual landscape elements'
|
| 11 |
---
|
| 12 |
|
| 13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
| 14 |
+
|
| 15 |
+
# LaDeco - Landscape Environment Semantic Analysis Model
|
| 16 |
+
|
| 17 |
+
LaDeco is a tool that analyzes landscape images, performs semantic segmentation to identify different elements in the scene (sky, vegetation, buildings, etc.), and enables region-based color matching between images.
|
| 18 |
+
|
| 19 |
+
## Features
|
| 20 |
+
|
| 21 |
+
### Semantic Segmentation
|
| 22 |
+
- Analyzes landscape images and segments them into different semantic regions
|
| 23 |
+
- Provides area ratio analysis for each landscape element
|
| 24 |
+
|
| 25 |
+
### Region-Based Color Matching
|
| 26 |
+
- Matches colors between corresponding semantic regions of two images
|
| 27 |
+
- Shows visualization of which regions are being matched between images
|
| 28 |
+
- Offers multiple color matching algorithms:
|
| 29 |
+
- **adain**: Adaptive Instance Normalization - Matches mean and standard deviation of colors
|
| 30 |
+
- **mkl**: Monge-Kantorovich Linearization - Linear transformation of color statistics
|
| 31 |
+
- **reinhard**: Reinhard color transfer - Simple statistical approach that matches mean and standard deviation
|
| 32 |
+
- **mvgd**: Multi-Variate Gaussian Distribution - Uses color covariance matrices for more accurate matching
|
| 33 |
+
- **hm**: Histogram Matching - Matches the full color distribution histograms
|
| 34 |
+
- **hm-mvgd-hm**: Histogram + MVGD + Histogram compound method
|
| 35 |
+
- **hm-mkl-hm**: Histogram + MKL + Histogram compound method
|
| 36 |
+
|
| 37 |
+
## Installation
|
| 38 |
+
|
| 39 |
+
1. Clone this repository
|
| 40 |
+
2. Create a virtual environment: `python3 -m venv .venv`
|
| 41 |
+
3. Activate the virtual environment: `source .venv/bin/activate`
|
| 42 |
+
4. Install requirements: `pip install -r requirements.txt`
|
| 43 |
+
5. Run the application: `python app.py`
|
| 44 |
+
|
| 45 |
+
## Usage
|
| 46 |
+
|
| 47 |
+
1. Upload two landscape images - the first will be the color reference, the second will be color-matched to the first
|
| 48 |
+
2. Choose a color matching method from the dropdown menu
|
| 49 |
+
3. Click "Start Analysis" to process the images
|
| 50 |
+
4. View the results in the Segmentation and Color Matching tabs
|
| 51 |
+
- Segmentation tab shows the semantic segmentation and area ratios for both images
|
| 52 |
+
- Color Matching tab shows the matched regions visualization and the color matching result
|
| 53 |
+
|
| 54 |
+
## Reference
|
| 55 |
+
|
| 56 |
+
Li-Chih Ho (2023), LaDeco: A Tool to Analyze Visual Landscape Elements, Ecological Informatics, vol. 78.
|
| 57 |
+
https://www.sciencedirect.com/science/article/pii/S1574954123003187
|
app.py
ADDED
|
@@ -0,0 +1,956 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from core import Ladeco
|
| 3 |
+
from matplotlib.figure import Figure
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import matplotlib as mpl
|
| 6 |
+
import spaces
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import numpy as np
|
| 9 |
+
from color_matching import RegionColorMatcher, create_comparison_figure
|
| 10 |
+
from face_comparison import FaceComparison
|
| 11 |
+
from cdl_smoothing import cdl_edge_smoothing, get_smoothing_stats, cdl_edge_smoothing_apply_to_source
|
| 12 |
+
import tempfile
|
| 13 |
+
import os
|
| 14 |
+
import cv2
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
plt.rcParams['figure.facecolor'] = '#0b0f19'
|
| 18 |
+
plt.rcParams['text.color'] = '#aab6cc'
|
| 19 |
+
ladeco = Ladeco()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@spaces.GPU
def infer_two_images(img1: str, img2: str, method: str, enable_face_matching: bool, enable_edge_smoothing: bool) -> tuple[Figure, Figure, Figure, Figure, Figure, Figure, str, str, str]:
    """
    Segment two landscape images and color-match the second to the first.

    Clean 4-step approach:
    1. Segment both images identically
    2. Determine segment correspondences
    3. Match each segment pair in isolation
    4. Composite all matched segments
    (Step 5, optional: CDL-based edge smoothing of the composite.)

    Args:
        img1: Path to the color-reference image.
        img2: Path to the image that will be recolored.
        method: Color-matching algorithm name passed to RegionColorMatcher.
        enable_face_matching: If True, only match human/bio segments when the
            faces in both images are judged to belong to the same person.
        enable_edge_smoothing: If True, replace the composite with a global
            CDL transform of img2 fitted to the composite.

    Returns:
        (seg1, pie1, seg2, pie2, comparison, mask_vis, temp_path,
         face_log_text, cdl_display) — four figures for the segmentation tab,
        two figures for the matching tab, the saved result path, the run log,
        and the CDL parameter report.
    """
    cdl_display = ""  # Initialize CDL display string

    # STEP 1: SEGMENT BOTH IMAGES IDENTICALLY
    # This step is always identical regardless of face matching
    print("Step 1: Segmenting both images...")
    out1 = ladeco.predict(img1)
    out2 = ladeco.predict(img2)

    # Extract visualization and stats (unchanged)
    seg1 = out1.visualize(level=2)[0].image
    colormap1 = out1.color_map(level=2)
    area1 = out1.area()[0]

    seg2 = out2.visualize(level=2)[0].image
    colormap2 = out2.color_map(level=2)
    area2 = out2.area()[0]

    # Process areas for pie charts: keep only non-empty level-2 classes and
    # strip the "l2_" prefix for display.
    colors1, l2_area1 = [], {}
    for labelname, area_ratio in area1.items():
        if labelname.startswith("l2") and area_ratio > 0:
            colors1.append(colormap1[labelname])
            labelname = labelname.replace("l2_", "").capitalize()
            l2_area1[labelname] = area_ratio

    colors2, l2_area2 = [], {}
    for labelname, area_ratio in area2.items():
        if labelname.startswith("l2") and area_ratio > 0:
            colors2.append(colormap2[labelname])
            labelname = labelname.replace("l2_", "").capitalize()
            l2_area2[labelname] = area_ratio

    pie1 = plot_pie(l2_area1, colors=colors1)
    pie2 = plot_pie(l2_area2, colors=colors2)

    # Set plot sizes (256x256 px at 96 dpi)
    for fig in [seg1, seg2, pie1, pie2]:
        fig.set_dpi(96)
        fig.set_size_inches(256 / 96, 256 / 96)

    # Extract semantic masks - IDENTICAL for both images regardless of face matching
    masks1 = extract_semantic_masks(out1)
    masks2 = extract_semantic_masks(out2)

    print(f"Extracted {len(masks1)} masks from img1, {len(masks2)} masks from img2")

    # STEP 2: DETERMINE SEGMENT CORRESPONDENCES
    print("Step 2: Determining segment correspondences...")
    face_log = ["Step 2: Determining segment correspondences"]

    # Find common segments between both images
    common_segments = set(masks1.keys()).intersection(set(masks2.keys()))
    face_log.append(f"Found {len(common_segments)} common segments: {sorted(common_segments)}")

    # Determine which segments to match based on face matching logic
    segments_to_match = determine_segments_to_match(img1, img2, common_segments, enable_face_matching, face_log)

    face_log.append(f"Final segments to match: {sorted(segments_to_match)}")

    # STEP 3: MATCH EACH SEGMENT PAIR IN ISOLATION
    print("Step 3: Matching each segment pair in isolation...")
    face_log.append("\nStep 3: Color matching each segment independently")

    matched_regions = {}
    segment_masks = {}  # Store masks for all segments being matched

    for segment_name in segments_to_match:
        if segment_name in masks1 and segment_name in masks2:
            face_log.append(f" Processing {segment_name}...")

            # Match this segment in complete isolation
            matched_region, final_mask1, final_mask2 = match_single_segment(
                img1, img2,
                masks1[segment_name], masks2[segment_name],
                segment_name, method, face_log
            )

            if matched_region is not None:
                matched_regions[segment_name] = matched_region
                segment_masks[segment_name] = final_mask2  # Use mask from target image for compositing
                face_log.append(f" ✅ {segment_name} matched successfully")
            else:
                face_log.append(f" ❌ {segment_name} matching failed")
        elif segment_name.startswith('l4_'):
            # Handle fine-grained segments that need to be generated:
            # match_single_segment regenerates the masks internally when
            # given None masks for an l4_* name.
            face_log.append(f" Processing fine-grained {segment_name}...")

            matched_region, final_mask1, final_mask2 = match_single_segment(
                img1, img2, None, None, segment_name, method, face_log
            )

            if matched_region is not None:
                matched_regions[segment_name] = matched_region
                segment_masks[segment_name] = final_mask2  # Store the generated mask
                face_log.append(f" ✅ {segment_name} matched successfully")
            else:
                face_log.append(f" ❌ {segment_name} matching failed")

    # STEP 4: COMPOSITE ALL MATCHED SEGMENTS
    print("Step 4: Compositing all matched segments...")
    face_log.append(f"\nStep 4: Compositing {len(matched_regions)} matched segments")

    final_image = composite_matched_segments(img2, matched_regions, segment_masks, face_log)

    # STEP 5: OPTIONAL CDL-BASED EDGE SMOOTHING
    if enable_edge_smoothing:
        print("Step 5: Applying CDL-based edge smoothing...")
        face_log.append("\nStep 5: CDL edge smoothing - applying CDL transform to image 2 based on composited result")

        try:
            # Save the composited result temporarily for CDL calculation
            temp_dir = tempfile.gettempdir()
            temp_composite_path = os.path.join(temp_dir, "temp_composite_for_cdl.png")
            final_image.save(temp_composite_path, "PNG")

            # Calculate CDL parameters to transform image 2 → composited result
            cdl_stats = get_smoothing_stats(img2, temp_composite_path)

            # Log the CDL values
            slope = cdl_stats['cdl_slope']
            offset = cdl_stats['cdl_offset']
            power = cdl_stats['cdl_power']

            # Format CDL values for display
            cdl_display = f"""📊 CDL Parameters (Image 2 → Composited Result):

🔧 Method: Simple Mean/Std Matching (basic statistical approach)

🔸 Slope (Gain):
   Red: {slope[0]:.6f}
   Green: {slope[1]:.6f}
   Blue: {slope[2]:.6f}

🔸 Offset:
   Red: {offset[0]:.6f}
   Green: {offset[1]:.6f}
   Blue: {offset[2]:.6f}

🔸 Power (Gamma):
   Red: {power[0]:.6f}
   Green: {power[1]:.6f}
   Blue: {power[2]:.6f}

These CDL values represent the color transformation needed to convert Image 2 into the composited result.

The CDL calculation uses the simplest possible approach: matches the mean and standard deviation
of each color channel between the original and composited images, with simple gamma calculation
based on brightness relationships.
"""

            face_log.append(f"📊 CDL Parameters (image 2 → composited result):")
            face_log.append(f" Method: Simple mean/std matching")
            face_log.append(f" Slope (R,G,B): [{slope[0]:.4f}, {slope[1]:.4f}, {slope[2]:.4f}]")
            face_log.append(f" Offset (R,G,B): [{offset[0]:.4f}, {offset[1]:.4f}, {offset[2]:.4f}]")
            face_log.append(f" Power (R,G,B): [{power[0]:.4f}, {power[1]:.4f}, {power[2]:.4f}]")

            # Apply CDL transformation to image 2 to approximate the composited result
            final_image = cdl_edge_smoothing_apply_to_source(img2, temp_composite_path, factor=1.0)

            # Clean up temp file
            if os.path.exists(temp_composite_path):
                os.remove(temp_composite_path)

            face_log.append("✅ CDL edge smoothing completed - transformed image 2 using calculated CDL parameters")

        except Exception as e:
            face_log.append(f"❌ CDL edge smoothing failed: {e}")
            cdl_display = f"❌ CDL calculation failed: {e}"
    else:
        face_log.append("\nStep 5: CDL edge smoothing disabled")
        cdl_display = "CDL edge smoothing is disabled. Enable it to see CDL parameters."

    # Save result.
    # FIX: the basename of img2 was computed but never used — the output name
    # previously did not include it, so successive runs with the same method
    # silently overwrote each other's result file.
    temp_dir = tempfile.gettempdir()
    filename = os.path.basename(img2).split('.')[0]
    temp_filename = f"color_matched_{method}_{filename}.png"
    temp_path = os.path.join(temp_dir, temp_filename)
    final_image.save(temp_path, "PNG")

    # Create visualizations.
    # For visualization, collect the masks that were actually used; l4_* masks
    # are regenerated because match_single_segment produced them internally.
    vis_masks1 = {}
    vis_masks2 = {}

    for segment_name in segments_to_match:
        if segment_name in segment_masks:
            if segment_name.startswith('l4_'):
                # Fine-grained segments - regenerate for visualization
                part_name = segment_name.replace('l4_', '')
                if part_name in ['face', 'hair']:
                    from human_parts_segmentation import HumanPartsSegmentation
                    segmenter = HumanPartsSegmentation()
                    masks_dict1 = segmenter.segment_parts(img1, [part_name])
                    masks_dict2 = segmenter.segment_parts(img2, [part_name])
                    if part_name in masks_dict1 and part_name in masks_dict2:
                        vis_masks1[segment_name] = masks_dict1[part_name]
                        vis_masks2[segment_name] = masks_dict2[part_name]
                elif part_name == 'upper_clothes':
                    from clothes_segmentation import ClothesSegmentation
                    segmenter = ClothesSegmentation()
                    mask1 = segmenter.segment_clothes(img1, ["Upper-clothes"])
                    mask2 = segmenter.segment_clothes(img2, ["Upper-clothes"])
                    if mask1 is not None and mask2 is not None:
                        vis_masks1[segment_name] = mask1
                        vis_masks2[segment_name] = mask2
            else:
                # Regular segments - use original masks
                if segment_name in masks1 and segment_name in masks2:
                    vis_masks1[segment_name] = masks1[segment_name]
                    vis_masks2[segment_name] = masks2[segment_name]

    mask_vis = visualize_matching_masks(img1, img2, vis_masks1, vis_masks2)

    comparison = create_comparison_figure(Image.open(img2), final_image, f"Color Matching Result ({method})")

    face_log_text = "\n".join(face_log)

    return seg1, pie1, seg2, pie2, comparison, mask_vis, temp_path, face_log_text, cdl_display
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def determine_segments_to_match(img1: str, img2: str, common_segments: set, enable_face_matching: bool, log: list) -> set:
    """
    Decide which segment names should be color-matched between the two images.

    With face matching disabled, every common segment is matched. Otherwise a
    face comparison runs first: on a mismatch, human/bio segments are dropped;
    on a match, fine-grained human parts (l4_*) are added on top of the common
    segments when a bio region is present.

    Args:
        img1: Path to the reference image.
        img2: Path to the target image.
        common_segments: Segment names present in both images.
        enable_face_matching: Whether to gate human segments on face identity.
        log: Mutable list of log lines; entries are appended in place.

    Returns:
        The set of segment names to process.
    """
    if not enable_face_matching:
        log.append("Face matching disabled - matching all common segments")
        return common_segments

    log.append("Face matching enabled - checking faces...")

    # Run face comparison between the two images
    comparator = FaceComparison()
    faces_match, comparison_log = comparator.run_face_comparison(img1, img2)
    log.extend(comparison_log)

    if faces_match:
        # Faces match: keep everything and try to add fine-grained parts
        log.append("Faces match - including all segments + fine-grained")

        selected = common_segments.copy()

        # Fine-grained human parts are only relevant when a bio region exists
        if any('l2_bio' in name.lower() for name in common_segments):
            selected.update(add_fine_grained_segments(img1, img2, common_segments, log))

        return selected

    # Faces differ: strip human/bio segments from the match set
    log.append("No face match - excluding human/bio segments")
    keep = set()
    for name in common_segments:
        lowered = name.lower()
        if 'l3_human' in lowered or 'l2_bio' in lowered:
            log.append(f" Excluding human segment: {name}")
        else:
            keep.add(name)

    log.append(f"Matching {len(keep)} non-human segments")
    return keep
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def add_fine_grained_segments(img1: str, img2: str, common_segments: set, log: list) -> set:
    """
    Collect fine-grained human-part segment names (l4_*) available in BOTH images.

    Runs face/hair segmentation and upper-clothes segmentation on each image
    and keeps a part only when both images produced a non-empty mask for it.
    Any failure (missing models, segmentation error) is logged and yields
    whatever was gathered so far.

    Args:
        img1: Path to the first image.
        img2: Path to the second image.
        common_segments: Common coarse segments (unused here; kept for interface parity).
        log: Mutable list of log lines; entries are appended in place.

    Returns:
        Set of fine-grained segment names ('l4_face', 'l4_hair', 'l4_upper_clothes').
    """
    added = set()

    try:
        from human_parts_segmentation import HumanPartsSegmentation
        from clothes_segmentation import ClothesSegmentation

        log.append(" Adding fine-grained human parts...")

        # Face/hair masks for both images
        parts_model = HumanPartsSegmentation()
        parts_a = parts_model.segment_parts(img1, ['face', 'hair'])
        parts_b = parts_model.segment_parts(img2, ['face', 'hair'])

        # Upper-clothes masks for both images
        clothes_model = ClothesSegmentation()
        clothes_a = clothes_model.segment_clothes(img1, ["Upper-clothes"])
        clothes_b = clothes_model.segment_clothes(img2, ["Upper-clothes"])

        # Keep a face/hair part only when both images have non-empty masks
        for part, mask_a in parts_a.items():
            mask_b = parts_b.get(part)
            if mask_a is None or mask_b is None:
                continue
            if np.sum(mask_a > 0) > 0 and np.sum(mask_b > 0) > 0:
                added.add(f'l4_{part}')
                log.append(f" Added fine-grained: {part}")

        # Same non-empty requirement for clothes
        clothes_usable = (clothes_a is not None and clothes_b is not None
                          and np.sum(clothes_a > 0) > 0 and np.sum(clothes_b > 0) > 0)
        if clothes_usable:
            added.add('l4_upper_clothes')
            log.append(" Added fine-grained: upper_clothes")

    except Exception as e:
        log.append(f" Error adding fine-grained segments: {e}")

    return added
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def match_single_segment(img1_path: str, img2_path: str, mask1: np.ndarray, mask2: np.ndarray,
                         segment_name: str, method: str, log: list) -> tuple[Image.Image, np.ndarray, np.ndarray]:
    """
    Match colors of a single segment in complete isolation from other segments.

    Each segment is processed independently with no knowledge of other segments.
    For fine-grained 'l4_*' segments the masks are generated here (mask1/mask2
    may be passed as None); for regular segments the caller-supplied masks are
    used. Any exception is caught, logged, and reported as a failed match.

    Args:
        img1_path: Path to the color-reference image.
        img2_path: Path to the target image.
        mask1: Segment mask for image 1 (or None for 'l4_*' segments).
        mask2: Segment mask for image 2 (or None for 'l4_*' segments).
        segment_name: Segment label, e.g. 'l2_sky' or 'l4_face'.
        method: Color-matching algorithm name for RegionColorMatcher.
        log: Mutable list of log lines; entries are appended in place.

    Returns:
        (matched_image, final_mask1, final_mask2) on success,
        (None, None, None) on any failure or empty mask.
    """
    try:
        # Load images
        img1 = Image.open(img1_path).convert("RGB")
        img2 = Image.open(img2_path).convert("RGB")

        # Convert to numpy (used below for shape checks)
        img1_np = np.array(img1)
        img2_np = np.array(img2)

        # Handle fine-grained segments: regenerate masks from the dedicated
        # part/clothes segmenters instead of using the coarse scene masks.
        if segment_name.startswith('l4_'):
            part_name = segment_name.replace('l4_', '')
            if part_name in ['face', 'hair']:
                from human_parts_segmentation import HumanPartsSegmentation
                segmenter = HumanPartsSegmentation()
                masks_dict1 = segmenter.segment_parts(img1_path, [part_name])
                masks_dict2 = segmenter.segment_parts(img2_path, [part_name])

                if part_name in masks_dict1 and part_name in masks_dict2:
                    mask1 = masks_dict1[part_name]
                    mask2 = masks_dict2[part_name]
                else:
                    # One of the images has no mask for this part
                    return None, None, None

            elif part_name == 'upper_clothes':
                from clothes_segmentation import ClothesSegmentation
                segmenter = ClothesSegmentation()
                mask1 = segmenter.segment_clothes(img1_path, ["Upper-clothes"])
                mask2 = segmenter.segment_clothes(img2_path, ["Upper-clothes"])

        # Guard against missing masks (e.g. clothes segmentation returned None).
        # NOTE(review): placement reconstructed from a mangled render — assumed
        # to apply to all segment kinds, not only the 'l4_' branch; confirm.
        if mask1 is None or mask2 is None:
            return None, None, None

        # Ensure masks are same size as images (nearest-neighbor keeps them crisp)
        if mask1.shape != img1_np.shape[:2]:
            mask1 = cv2.resize(mask1.astype(np.float32), (img1_np.shape[1], img1_np.shape[0]),
                               interpolation=cv2.INTER_NEAREST)
        if mask2.shape != img2_np.shape[:2]:
            mask2 = cv2.resize(mask2.astype(np.float32), (img2_np.shape[1], img2_np.shape[0]),
                               interpolation=cv2.INTER_NEAREST)

        # Convert to binary masks (threshold at 0.5)
        mask1_binary = (mask1 > 0.5).astype(np.float32)
        mask2_binary = (mask2 > 0.5).astype(np.float32)

        # Check if masks have content
        pixels1 = np.sum(mask1_binary > 0)
        pixels2 = np.sum(mask2_binary > 0)

        if pixels1 == 0 or pixels2 == 0:
            log.append(f" No pixels in {segment_name}: img1={pixels1}, img2={pixels2}")
            return None, None, None

        log.append(f" {segment_name}: img1={pixels1} pixels, img2={pixels2} pixels")

        # Create single-segment masks dictionary for color matcher
        masks1_dict = {segment_name: mask1_binary}
        masks2_dict = {segment_name: mask2_binary}

        # Apply color matching to this segment only
        color_matcher = RegionColorMatcher(factor=0.8, preserve_colors=True,
                                           preserve_luminance=True, method=method)

        matched_img = color_matcher.match_regions(img1_path, img2_path, masks1_dict, masks2_dict)

        return matched_img, mask1_binary, mask2_binary

    except Exception as e:
        # Best-effort per segment: a failure here must not abort the whole run
        log.append(f" Error matching {segment_name}: {e}")
        return None, None, None
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def composite_matched_segments(base_img_path: str, matched_regions: dict, segment_masks: dict, log: list) -> Image.Image:
    """
    Composite all matched segments back together using simple alpha compositing.

    Each matched segment is completely independent: its pixels are copied onto
    the base image wherever its (hard, >0.5) mask is set. Masks and region
    images are resized to the base image's dimensions when necessary.

    Args:
        base_img_path: Path to the image used as the compositing base.
        matched_regions: {segment_name: color-matched PIL image}.
        segment_masks: {segment_name: mask array}; segments without a mask are skipped.
        log: Mutable list of log lines; entries are appended in place.

    Returns:
        The composited image as an RGB PIL Image.
    """
    # Base canvas as RGBA pixels
    canvas = np.array(Image.open(base_img_path).convert("RGBA"))
    height, width = canvas.shape[:2]

    log.append(f"Compositing {len(matched_regions)} segments onto base image")

    for name, region_img in matched_regions.items():
        if name not in segment_masks:
            continue

        # Resize the mask to the canvas if needed (nearest keeps edges hard)
        mask = segment_masks[name]
        if mask.shape != (height, width):
            mask = cv2.resize(mask.astype(np.float32), (width, height),
                              interpolation=cv2.INTER_NEAREST)

        # Matched region as RGB pixels, resized to the canvas if needed
        overlay = np.array(region_img.convert("RGB"))
        if overlay.shape[:2] != (height, width):
            overlay = np.array(Image.fromarray(overlay).resize((width, height), Image.LANCZOS))

        # Hard alpha from the binarized mask, broadcast over the color channels
        hard_mask = (mask > 0.5).astype(np.float32)
        alpha = hard_mask[:, :, np.newaxis]

        # canvas = canvas * (1 - alpha) + overlay * alpha  (RGB channels only)
        canvas[:, :, :3] = (canvas[:, :, :3] * (1 - alpha) + overlay * alpha).astype(np.uint8)

        pixels = np.sum(hard_mask > 0)
        log.append(f" Composited {name}: {pixels} pixels")

    return Image.fromarray(canvas).convert("RGB")
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
def visualize_matching_masks(img1_path, img2_path, masks1, masks2):
    """
    Create a visualization of the masks being matched between two images.

    Args:
        img1_path: Path to first image
        img2_path: Path to second image
        masks1: Dictionary of masks for first image {label: binary_mask}
        masks2: Dictionary of masks for second image {label: binary_mask}

    Returns:
        A matplotlib Figure showing the matched masks
    """
    # Load both images up-front as RGB numpy arrays.
    img1_np = np.array(Image.open(img1_path).convert("RGB"))
    img2_np = np.array(Image.open(img2_path).convert("RGB"))

    # Split image-1 masks: fine-grained human parts ('l4_*') vs regular regions.
    fine_grained_masks = {lbl: m for lbl, m in masks1.items() if lbl.startswith('l4_')}
    regular_masks = {lbl: m for lbl, m in masks1.items() if not lbl.startswith('l4_')}

    # Regular labels that appear in both images.
    common_regular = set(regular_masks) & set(masks2)

    # Fine-grained part names that appear in both images.
    common_fine_grained = {
        lbl.replace('l4_', '')
        for lbl in fine_grained_masks
        if lbl.startswith('l4_') and lbl in masks2
    }

    n_rows = len(common_regular) + len(common_fine_grained)

    if n_rows == 0:
        # Nothing matched: return a placeholder figure with a message.
        fig, ax = plt.subplots(1, 1, figsize=(10, 5))
        ax.text(0.5, 0.5, "No matching regions found between images",
                ha='center', va='center', fontsize=14, color='white')
        ax.axis('off')
        return fig

    fig, axes = plt.subplots(n_rows, 2, figsize=(12, 3 * n_rows))
    if n_rows == 1:
        # Keep 2-D (row, col) indexing valid when there is a single row.
        axes = np.array([axes])

    def _draw_row(row, left_img, right_img, left_title, right_title):
        # One grid row: image 1 overlay on the left, image 2 overlay on the right.
        axes[row, 0].imshow(left_img)
        axes[row, 0].set_title(left_title)
        axes[row, 0].axis('off')
        axes[row, 1].imshow(right_img)
        axes[row, 1].set_title(right_title)
        axes[row, 1].axis('off')

    row_idx = 0

    # Regular semantic regions, highlighted in red.
    for label in sorted(common_regular):
        display_name = label.replace("l2_", "").capitalize()
        overlay1, overlay2 = create_mask_overlay(
            img1_np, img2_np, regular_masks[label], masks2[label], [255, 0, 0])  # Red
        _draw_row(row_idx, overlay1, overlay2,
                  f"Image 1: {display_name}", f"Image 2: {display_name}")
        row_idx += 1

    # Fine-grained human parts (all currently use the same red as regular masks).
    part_colors = {
        'face': [255, 0, 0],          # Red (like other masks)
        'hair': [255, 0, 0],          # Red (like other masks)
        'upper_clothes': [255, 0, 0]  # Red (like other masks)
    }

    for part_name in sorted(common_fine_grained):
        label = f'l4_{part_name}'
        if label in fine_grained_masks and label in masks2:
            color = part_colors.get(part_name, [255, 0, 0])  # Default to red
            overlay1, overlay2 = create_mask_overlay(
                img1_np, img2_np, fine_grained_masks[label], masks2[label], color)
            display_name = part_name.replace('_', ' ').title()
            _draw_row(row_idx, overlay1, overlay2,
                      f"Image 1: {display_name} (Fine-grained)",
                      f"Image 2: {display_name} (Fine-grained)")
            row_idx += 1

    plt.suptitle("Matched Regions (highlighted with different colors)", fontsize=16, color='white')
    plt.tight_layout()

    return fig
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def create_mask_overlay(img1_np, img2_np, mask1, mask2, overlay_color):
    """
    Create mask overlays on images with the specified color.

    Args:
        img1_np: First image as numpy array
        img2_np: Second image as numpy array
        mask1: Mask for first image
        mask2: Mask for second image
        overlay_color: RGB color for overlay [R, G, B]

    Returns:
        Tuple of (masked_img1, masked_img2)
    """
    def _fit_mask(mask, image):
        # Resize mask (nearest-neighbor) if it doesn't match the image size.
        if mask.shape == image.shape[:2]:
            return mask
        mask_img = Image.fromarray((mask * 255).astype(np.uint8))
        mask_img = mask_img.resize((image.shape[1], image.shape[0]), Image.NEAREST)
        return np.array(mask_img).astype(np.float32) / 255.0

    mask1 = _fit_mask(mask1, img1_np)
    mask2 = _fit_mask(mask2, img2_np)

    # Work on copies; the inputs are left untouched.
    out1 = img1_np.copy()
    out2 = img2_np.copy()

    color = np.array(overlay_color, dtype=np.uint8)

    # Semi-transparent alpha derived from the masks (0.6 = overlay opacity).
    alpha1 = mask1 * 0.6
    alpha2 = mask2 * 0.6

    # Blend the overlay color into each channel of the masked regions.
    for ch in range(3):
        out1[:, :, ch] = out1[:, :, ch] * (1 - alpha1) + color[ch] * alpha1
        out2[:, :, ch] = out2[:, :, ch] * (1 - alpha2) + color[ch] * alpha2

    return out1, out2
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
def extract_semantic_masks(output):
    """
    Extract binary masks for each semantic region from the LadecoOutput.

    Args:
        output: LadecoOutput from Ladeco.predict()

    Returns:
        Dictionary mapping label names to binary masks
    """
    # Segmentation map as a numpy array (first image in the batch).
    seg_mask = output.masks[0].cpu().numpy()

    masks = {}

    # Only level-2 labels are visualized, so only those are extracted.
    for label, indices in output.ladeco2ade.items():
        if not label.startswith("l2_"):
            continue

        # 1.0 where the pixel's ADE index belongs to this label, else 0.0.
        region = np.isin(seg_mask, list(indices)).astype(np.float32)

        # Skip labels with no pixels in this image.
        if region.any():
            masks[label] = region

    return masks
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
def plot_pie(data: dict[str, float], colors=None) -> Figure:
    """Render *data* as a pie chart; keys become labels, values become sizes."""
    fig, ax = plt.subplots()

    # ax.pie returns (wedges, texts, autotexts) when autopct is given;
    # the percentage texts are the last element.
    pie_parts = ax.pie(list(data.values()),
                       labels=list(data.keys()),
                       autopct="%1.1f%%",
                       colors=colors)

    # Force the percentage labels to black for readability.
    for pct_text in pie_parts[-1]:
        pct_text.set_color("k")

    # Equal aspect ratio keeps the pie circular.
    ax.axis("equal")

    return fig
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
def choose_example(imgpath: str, target_component) -> gr.Image:
    """Load an example image, scale its longest side to 512 px, and return it
    wrapped in a gr.Image update for the input component."""
    img = Image.open(imgpath)
    w, h = img.size
    # Scale factor so that max(width, height) becomes 512.
    scale = 512 / max(w, h)
    resized = img.resize((int(w * scale), int(h * scale)))
    return gr.Image(value=resized, label="Input Image (SVG format not supported)", type="filepath")
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
# --- Custom CSS injected into the Gradio page (dark theme tweaks) ---
css = """
.reference {
    text-align: center;
    font-size: 1.2em;
    color: #d1d5db;
    margin-bottom: 20px;
}
.reference a {
    color: #FB923C;
    text-decoration: none;
}
.reference a:hover {
    text-decoration: underline;
    color: #FB923C;
}
.description {
    text-align: center;
    font-size: 1.1em;
    color: #d1d5db;
    margin-bottom: 25px;
}
.footer {
    text-align: center;
    margin-top: 30px;
    padding-top: 20px;
    border-top: 1px solid #ddd;
    color: #d1d5db;
    font-size: 14px;
}
.main-title {
    font-size: 24px;
    font-weight: bold;
    text-align: center;
    margin-bottom: 20px;
}
.selected-image {
    height: 756px;
}
.example-image {
    height: 220px;
    padding: 25px;
}
""".strip()
# --- Dark orange/cyan Gradio theme built from the Base theme ---
theme = gr.themes.Base(
    primary_hue="orange",
    secondary_hue="cyan",
    neutral_hue="gray",
).set(
    body_text_color='*neutral_100',
    body_text_color_subdued='*neutral_600',
    background_fill_primary='*neutral_950',
    background_fill_secondary='*neutral_600',
    border_color_accent='*secondary_800',
    color_accent='*primary_50',
    color_accent_soft='*secondary_800',
    code_background_fill='*neutral_700',
    block_background_fill_dark='*body_background_fill',
    block_info_text_color='#6b7280',
    block_label_text_color='*neutral_300',
    block_label_text_weight='700',
    block_title_text_color='*block_label_text_color',
    block_title_text_weight='300',
    panel_background_fill='*neutral_800',
    table_text_color_dark='*secondary_800',
    checkbox_background_color_selected='*primary_500',
    checkbox_label_background_fill='*neutral_500',
    checkbox_label_background_fill_hover='*neutral_700',
    checkbox_label_text_color='*neutral_200',
    input_background_fill='*neutral_700',
    input_background_fill_focus='*neutral_600',
    slider_color='*primary_500',
    table_even_background_fill='*neutral_700',
    table_odd_background_fill='*neutral_600',
    table_row_focus='*neutral_800'
)
# --- Application layout: components are created in visual order inside the
# --- Blocks context; names assigned here are wired to callbacks below.
with gr.Blocks(css=css, theme=theme) as demo:
    # Page header: title + one-line description.
    gr.HTML(
        """
        <div class="main-title">SegMatch – Zero Shot Segmentation-based color matching</div>
        <div class="description">
        Advanced region-based color matching using semantic segmentation and fine-grained human parts detection for precise, contextually-aware color transfer between images.
        </div>
        """.strip()
    )

    with gr.Row():
        # First image inputs
        with gr.Column():
            # Reference image: its colors are transferred onto the second image.
            img1 = gr.Image(
                label="First Input Image - Color Reference (SVG not supported)",
                type="filepath",
                height="256px",
            )
            gr.Label("Example Images for First Input", show_label=False)
            with gr.Row():
                # Clickable (non-editable) example thumbnails for img1.
                ex1_1 = gr.Image(
                    value="examples/beach.jpg",
                    show_label=False,
                    type="filepath",
                    elem_classes="example-image",
                    interactive=False,
                    show_download_button=False,
                    show_fullscreen_button=False,
                    show_share_button=False,
                )
                ex1_2 = gr.Image(
                    value="examples/field.jpg",
                    show_label=False,
                    type="filepath",
                    elem_classes="example-image",
                    interactive=False,
                    show_download_button=False,
                    show_fullscreen_button=False,
                    show_share_button=False,
                )

        # Second image inputs
        with gr.Column():
            # Target image: the one whose colors are modified.
            img2 = gr.Image(
                label="Second Input Image - To Be Color Matched (SVG not supported)",
                type="filepath",
                height="256px",
            )
            gr.Label("Example Images for Second Input", show_label=False)
            with gr.Row():
                # Clickable (non-editable) example thumbnails for img2.
                ex2_1 = gr.Image(
                    value="examples/field.jpg",
                    show_label=False,
                    type="filepath",
                    elem_classes="example-image",
                    interactive=False,
                    show_download_button=False,
                    show_fullscreen_button=False,
                    show_share_button=False,
                )
                ex2_2 = gr.Image(
                    value="examples/sky.jpg",
                    show_label=False,
                    type="filepath",
                    elem_classes="example-image",
                    interactive=False,
                    show_download_button=False,
                    show_fullscreen_button=False,
                    show_share_button=False,
                )

    with gr.Row():
        with gr.Column():
            # Color-transfer algorithm selector; values map to color_matching.py.
            method = gr.Dropdown(
                label="Color Matching Method",
                choices=["adain", "mkl", "hm", "reinhard", "mvgd", "hm-mvgd-hm", "hm-mkl-hm", "coral"],
                value="adain",
                info="Choose the algorithm for color matching between regions"
            )

        with gr.Column():
            # Gate human-region matching on face similarity (DeepFace).
            enable_face_matching = gr.Checkbox(
                label="Enable Face Matching for Human Regions",
                value=True,
                info="Only match human regions if faces are similar (requires DeepFace)"
            )

    with gr.Row():
        with gr.Column():
            # Optional CDL-based smoothing pass (off by default).
            enable_edge_smoothing = gr.Checkbox(
                label="Enable CDL Edge Smoothing",
                value=False,
                info="Apply CDL transform to original image using calculated parameters (see log for values)"
            )

    start = gr.Button("Start Analysis", variant="primary")

    # Download button positioned right after the start button
    download_btn = gr.File(
        label="📥 Download Color-Matched Image",
        visible=True,
        interactive=False
    )

    with gr.Tabs():
        with gr.TabItem("Segmentation Results"):
            with gr.Row():
                # First image results
                with gr.Column():
                    gr.Label("Results for First Image", show_label=True)
                    seg1 = gr.Plot(label="Semantic Segmentation")
                    pie1 = gr.Plot(label="Element Area Ratio")

                # Second image results
                with gr.Column():
                    gr.Label("Results for Second Image", show_label=True)
                    seg2 = gr.Plot(label="Semantic Segmentation")
                    pie2 = gr.Plot(label="Element Area Ratio")

        with gr.TabItem("Color Matching"):
            # Static explainer text for the color-matching tab.
            gr.Markdown("""
            ### Region-Based Color Matching

            This tab shows the result of matching the colors of the second image to the first image's colors,
            but only within corresponding semantic regions. For example, sky areas in the second image are
            matched to sky areas in the first image, while vegetation areas are matched separately.

            #### Face Matching Feature:
            When enabled, the system will detect faces within human/bio regions and only apply color matching
            to human regions where similar faces are found in both images. This ensures that color transfer
            only occurs between images of the same person.

            #### CDL Edge Smoothing Feature:
            When enabled, calculates Color Decision List (CDL) parameters to transform the original target image
            towards the segment-matched result, then applies those CDL parameters to the original image. This creates
            a "smoothed" version that maintains the original image's overall characteristics while incorporating the
            color relationships found through segment matching.

            The CDL calculation uses the simplest possible approach: matches the mean and standard deviation
            of each color channel between the original and composited images, with simple gamma calculation
            based on brightness relationships.

            #### Available Methods:
            - **adain**: Adaptive Instance Normalization - Matches mean and standard deviation of colors
            - **mkl**: Monge-Kantorovich Linearization - Linear transformation of color statistics
            - **reinhard**: Reinhard color transfer - Simple statistical approach that matches mean and standard deviation
            - **mvgd**: Multi-Variate Gaussian Distribution - Uses color covariance matrices for more accurate matching
            - **hm**: Histogram Matching - Matches the full color distribution histograms
            - **hm-mvgd-hm**: Histogram + MVGD + Histogram compound method
            - **hm-mkl-hm**: Histogram + MKL + Histogram compound method
            - **coral**: CORAL (Color Transfer using Correlated Color Temperature) - Advanced covariance-based method for natural color transfer
            """)

            # CDL Parameters Display
            cdl_display = gr.Textbox(
                label="📊 CDL Parameters",
                lines=15,
                max_lines=20,
                interactive=False,
                info="Color Decision List parameters calculated when CDL edge smoothing is enabled"
            )

            face_log = gr.Textbox(
                label="Face Matching Log",
                lines=8,
                max_lines=15,
                interactive=False,
                info="Shows details of face detection and matching process"
            )

            mask_vis = gr.Plot(label="Matched Regions Visualization")
            comparison = gr.Plot(label="Region-Based Color Matching Result")

    # Page footer.
    gr.HTML(
        """
        <div class="footer">
        © 2024 SegMatch All Rights Reserved<br>
        Developer: Stefan Allen
        </div>
        """.strip()
    )

    # Connect the inference function
    # NOTE: output order here must match infer_two_images' return order.
    start.click(
        fn=infer_two_images,
        inputs=[img1, img2, method, enable_face_matching, enable_edge_smoothing],
        outputs=[seg1, pie1, seg2, pie2, comparison, mask_vis, download_btn, face_log, cdl_display]
    )

    # Example image selection handlers
    # Clicking a thumbnail copies a resized version into the matching input.
    ex1_1.select(fn=lambda x: choose_example(x, img1), inputs=ex1_1, outputs=img1)
    ex1_2.select(fn=lambda x: choose_example(x, img1), inputs=ex1_2, outputs=img1)
    ex2_1.select(fn=lambda x: choose_example(x, img2), inputs=ex2_1, outputs=img2)
    ex2_2.select(fn=lambda x: choose_example(x, img2), inputs=ex2_2, outputs=img2)

if __name__ == "__main__":
    demo.launch()
|
cdl_smoothing.py
ADDED
|
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
CDL (Color Decision List) based edge smoothing for SegMatch
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
from typing import Tuple, Optional
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import cv2
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def calculate_cdl_params_face_only(source: np.ndarray, target: np.ndarray,
                                   source_face_mask: np.ndarray, target_face_mask: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Calculate CDL parameters using only face pixels for focused accuracy.

    Args:
        source (np.ndarray): Source image as numpy array (0-1 range)
        target (np.ndarray): Target image as numpy array (0-1 range)
        source_face_mask (np.ndarray): Binary mask of face in source image
        target_face_mask (np.ndarray): Binary mask of face in target image

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: (slope, offset, power)
    """
    EPS = 1e-6

    # Restrict the statistics to face pixels (N x 3 arrays).
    src_px = source[source_face_mask > 0.5]
    tgt_px = target[target_face_mask > 0.5]

    # Too little face coverage -> fall back to whole-image statistics.
    if min(len(src_px), len(tgt_px)) < 100:
        return calculate_cdl_params_simple(source, target)

    slope = np.empty(3)
    offset = np.empty(3)
    power = np.empty(3)

    for ch in range(3):
        s = src_px[:, ch]
        t = tgt_px[:, ch]

        # Robust 10th/50th/90th percentiles of the face pixels.
        s_lo, s_mid, s_hi = np.percentile(s, [10, 50, 90])
        t_lo, t_mid, t_hi = np.percentile(t, [10, 50, 90])

        # Slope from the robust (90th - 10th) ranges.
        gain = (t_hi - t_lo) / ((s_hi - s_lo) + EPS)
        slope[ch] = gain

        # Offset anchors the medians onto each other.
        offset[ch] = t_mid - s_mid * gain

        # Gamma from the brightness (mean) relationship, clamped to [0.3, 3].
        s_mean = np.mean(s)
        t_mean = np.mean(t)
        if s_mean > EPS:
            power[ch] = np.clip(np.log(t_mean + EPS) / np.log(s_mean + EPS), 0.3, 3.0)
        else:
            power[ch] = 1.0

    return slope, offset, power
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def calculate_cdl_params_simple(source: np.ndarray, target: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Simple CDL calculation as fallback method.

    Matches each RGB channel's mean/std between source and target; gamma is
    left neutral (1.0).

    Args:
        source (np.ndarray): Source image as numpy array (0-1 range)
        target (np.ndarray): Target image as numpy array (0-1 range)

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: (slope, offset, power)
    """
    EPS = 1e-6

    # Per-channel first/second moments over all pixels.
    src_mu = source.mean(axis=(0, 1))
    src_sigma = source.std(axis=(0, 1))
    tgt_mu = target.mean(axis=(0, 1))
    tgt_sigma = target.std(axis=(0, 1))

    # Gain scales the spread; offset re-centers the mean; power stays neutral.
    gain = tgt_sigma / (src_sigma + EPS)
    lift = tgt_mu - src_mu * gain

    return gain, lift, np.ones(3)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def calculate_cdl_params_histogram(source: np.ndarray, target: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Calculate CDL parameters using percentile (quartile) matching.

    Per channel: slope maps the source inter-quartile range onto the target's,
    offset aligns the medians, and power (gamma) is estimated from the ratio
    of log-means, clamped to a reasonable range.

    Note: a previous revision also computed per-channel histograms and CDFs via
    cv2.calcHist, but those results were never used; the dead computation has
    been removed. Returned values are unchanged.

    Args:
        source (np.ndarray): Source image as numpy array (0-1 range)
        target (np.ndarray): Target image as numpy array (0-1 range)

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: (slope, offset, power)
    """
    epsilon = 1e-6

    slopes = []
    offsets = []
    powers = []

    for channel in range(3):
        src = source[:, :, channel]
        tgt = target[:, :, channel]

        # Percentile mappings (25th / 50th / 75th) in one pass per image.
        p25_src, median_src, p75_src = np.percentile(src, [25, 50, 75])
        p25_tgt, median_tgt, p75_tgt = np.percentile(tgt, [25, 50, 75])

        # Calculate slope from the inter-quartile range mapping.
        slope = (p75_tgt - p25_tgt) / (p75_src - p25_src + epsilon)

        # Calculate offset so the medians line up under this slope.
        offset = median_tgt - (median_src * slope)

        # Estimate power/gamma from the brightness (mean) relationship.
        mean_src = np.mean(src)
        mean_tgt = np.mean(tgt)
        if mean_src > epsilon:
            power = np.log(mean_tgt + epsilon) / np.log(mean_src + epsilon)
            power = np.clip(power, 0.1, 10.0)  # Reasonable gamma range
        else:
            power = 1.0

        slopes.append(slope)
        offsets.append(offset)
        powers.append(power)

    return np.array(slopes), np.array(offsets), np.array(powers)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def calculate_cdl_params_mask_aware(source: np.ndarray, target: np.ndarray,
                                    changed_mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Calculate CDL parameters focusing only on changed regions.

    Args:
        source (np.ndarray): Source image as numpy array (0-1 range)
        target (np.ndarray): Target image as numpy array (0-1 range)
        changed_mask (np.ndarray, optional): Binary mask of changed regions

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: (slope, offset, power)
    """
    if changed_mask is not None:
        selected = changed_mask > 0.5
        # Require a minimum of changed pixels for stable statistics.
        if np.sum(selected) > 100:
            eps = 1e-6

            # Masked pixels as flat N x 3 arrays.
            src_sel = source[selected].reshape(-1, 3)
            tgt_sel = target[selected].reshape(-1, 3)

            # Mean/std matching restricted to the changed region.
            src_mu = src_sel.mean(axis=0)
            src_sd = src_sel.std(axis=0)
            tgt_mu = tgt_sel.mean(axis=0)
            tgt_sd = tgt_sel.std(axis=0)

            gain = tgt_sd / (src_sd + eps)
            return gain, tgt_mu - src_mu * gain, np.ones(3)

    # No mask (or too few changed pixels): whole-image fallback.
    return calculate_cdl_params_simple(source, target)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def calculate_cdl_params_lab(source: np.ndarray, target: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Calculate CDL parameters in LAB color space for better perceptual matching.

    Args:
        source (np.ndarray): Source image as numpy array (0-1 range)
        target (np.ndarray): Target image as numpy array (0-1 range)

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: (slope, offset, power)
    """
    # Convert to 8-bit LAB. NOTE: for 8-bit input OpenCV already rescales the
    # channels into 0-255 (L <- L*255/100, a <- a+128, b <- b+128), so all
    # three channels come back in the same 0-255 range.
    source_lab = cv2.cvtColor((source * 255).astype(np.uint8), cv2.COLOR_RGB2LAB).astype(np.float32)
    target_lab = cv2.cvtColor((target * 255).astype(np.uint8), cv2.COLOR_RGB2LAB).astype(np.float32)

    # Normalize all channels to 0-1. (Fix: previously L was divided by 100 and
    # a/b were offset by +128 a second time, which double-applied the scaling
    # that OpenCV performs for uint8 images.)
    source_lab /= 255.0
    target_lab /= 255.0

    # Calculate CDL in LAB space via per-channel mean/std matching.
    epsilon = 1e-6
    source_mean = np.mean(source_lab, axis=(0, 1))
    source_std = np.std(source_lab, axis=(0, 1))
    target_mean = np.mean(target_lab, axis=(0, 1))
    target_std = np.std(target_lab, axis=(0, 1))

    slope_lab = target_std / (source_std + epsilon)
    offset_lab = target_mean - (source_mean * slope_lab)

    # Convert back to RGB approximation.
    # This is a simplified conversion - for full accuracy we'd need to convert
    # each pixel back through LAB->RGB rather than mapping channel-wise.
    slope = np.array([slope_lab[0], slope_lab[1], slope_lab[2]])  # Rough mapping
    offset = np.array([offset_lab[0], offset_lab[1], offset_lab[2]])
    power = np.ones(3)

    return slope, offset, power
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def calculate_cdl_params(source: np.ndarray, target: np.ndarray,
                         source_path: str = None, target_path: str = None) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Calculate CDL parameters using simple mean/std matching - the most basic approach.

    Args:
        source (np.ndarray): Source image as numpy array (0-1 range)
        target (np.ndarray): Target image as numpy array (0-1 range)
        source_path (str, optional): Ignored - kept for compatibility
        target_path (str, optional): Ignored - kept for compatibility

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: (slope, offset, power)
    """
    eps = 1e-6

    # Per-channel first-order statistics over the whole frame.
    src_mean = np.mean(source, axis=(0, 1))
    src_std = np.std(source, axis=(0, 1))
    tgt_mean = np.mean(target, axis=(0, 1))
    tgt_std = np.std(target, axis=(0, 1))

    # Slope (gain) from the std ratio; offset from the residual mean shift.
    slope = tgt_std / (src_std + eps)
    offset = tgt_mean - (src_mean * slope)

    # Simple per-channel gamma derived from the brightness relationship,
    # clipped to a sane range; channels with ~zero mean get gamma 1.0.
    power = np.array([
        np.clip(np.log(tgt_mean[c] + eps) / np.log(src_mean[c] + eps), 0.1, 10.0)
        if src_mean[c] > eps else 1.0
        for c in range(3)
    ])

    return slope, offset, power
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def calculate_change_mask(original: np.ndarray, composited: np.ndarray, threshold: float = 0.05) -> np.ndarray:
    """Calculate a mask of significantly changed regions between original and composited images.

    Args:
        original (np.ndarray): Original image (0-1 range)
        composited (np.ndarray): Composited result (0-1 range)
        threshold (float): Threshold for detecting significant changes

    Returns:
        np.ndarray: Binary mask of changed regions
    """
    # Euclidean per-pixel color distance across the channel axis.
    delta = composited - original
    magnitude = np.sqrt((delta ** 2).sum(axis=2))

    # Binarize: a pixel counts as "changed" when it moved more than threshold.
    mask = (magnitude > threshold).astype(np.float32)

    # Morphological closing fills small holes and smooths the mask outline.
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    return cv2.morphologyEx(mask, cv2.MORPH_CLOSE, ellipse)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def calculate_channel_stats(array: np.ndarray) -> dict:
    """Calculate per-channel statistics for an image array.

    Args:
        array: Image array of shape (H, W, 3)

    Returns:
        dict: Dictionary containing mean, std, min, max for each channel
    """
    # All four reductions collapse the spatial axes, leaving one value per channel.
    reducers = {'mean': np.mean, 'std': np.std, 'min': np.min, 'max': np.max}
    return {name: fn(array, axis=(0, 1)) for name, fn in reducers.items()}
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def apply_cdl_transform(image: np.ndarray, slope: np.ndarray, offset: np.ndarray, power: np.ndarray,
                        factor: float = 0.3) -> np.ndarray:
    """Apply CDL transformation to an image.

    Args:
        image (np.ndarray): Input image (0-1 range)
        slope (np.ndarray): CDL slope parameters for each channel
        offset (np.ndarray): CDL offset parameters for each channel
        power (np.ndarray): CDL power parameters for each channel
        factor (float): Blending factor (0.0 = no change, 1.0 = full transform)

    Returns:
        np.ndarray: Transformed image
    """
    # Standard ASC-CDL formula: out = ((in * slope) + offset) ** power,
    # with negative intermediates clamped before the power is taken.
    graded = np.power(np.maximum(image * slope + offset, 0), power)
    graded = np.clip(graded, 0.0, 1.0)

    # Linear interpolation between the untouched image and the full grade.
    return (1 - factor) * image + factor * graded
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def cdl_edge_smoothing(composited_image_path: str, original_image_path: str, factor: float = 0.3) -> Image.Image:
    """Apply CDL-based edge smoothing between composited result and original image.

    Args:
        composited_image_path (str): Path to the composited result image
        original_image_path (str): Path to the original target image
        factor (float): Smoothing strength (0.0 = no smoothing, 1.0 = full smoothing)

    Returns:
        Image.Image: Smoothed result image
    """
    composited = Image.open(composited_image_path).convert("RGB")
    original = Image.open(original_image_path).convert("RGB")

    # Bring both frames to the same size before any per-pixel math.
    if composited.size != original.size:
        composited = composited.resize(original.size, Image.LANCZOS)

    # Work in float 0-1 space.
    comp_arr = np.array(composited).astype(np.float32) / 255.0
    orig_arr = np.array(original).astype(np.float32) / 255.0

    # Estimate the grade that pushes the composite toward the original.
    slope, offset, power = calculate_cdl_params(comp_arr, orig_arr)

    # Apply the grade, blended by `factor`.
    blended = apply_cdl_transform(comp_arr, slope, offset, power, factor)

    return Image.fromarray((blended * 255).astype(np.uint8))
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def get_smoothing_stats(original_image_path: str, composited_image_path: str) -> dict:
    """Get statistics about the CDL transformation for debugging.

    Args:
        original_image_path (str): Path to the original target image
        composited_image_path (str): Path to the composited result image

    Returns:
        dict: Statistics about the transformation
    """
    composited = Image.open(composited_image_path).convert("RGB")
    original = Image.open(original_image_path).convert("RGB")

    # Match dimensions before computing pixel statistics.
    if composited.size != original.size:
        composited = composited.resize(original.size, Image.LANCZOS)

    comp_arr = np.array(composited).astype(np.float32) / 255.0
    orig_arr = np.array(original).astype(np.float32) / 255.0

    # NOTE(review): the (original, composited) argument order here is the
    # reverse of what cdl_edge_smoothing() uses — confirm which direction
    # of the transform these debug parameters are meant to describe.
    slope, offset, power = calculate_cdl_params(orig_arr, comp_arr,
                                                original_image_path, composited_image_path)

    return {
        'composited_stats': calculate_channel_stats(comp_arr),
        'original_stats': calculate_channel_stats(orig_arr),
        'cdl_slope': slope,
        'cdl_offset': offset,
        'cdl_power': power
    }
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def cdl_edge_smoothing_apply_to_source(source_image_path: str, target_image_path: str, factor: float = 1.0) -> Image.Image:
    """Apply CDL transformation to source image using face-based parameters when possible.

    This function:
    1. Calculates CDL parameters to transform source to match target (using face pixels when available)
    2. Applies those CDL parameters to the entire source image
    3. Returns the transformed source image

    Args:
        source_image_path (str): Path to the source image (to be transformed)
        target_image_path (str): Path to the target image (reference for CDL calculation)
        factor (float): Transform strength (0.0 = no change, 1.0 = full transform)

    Returns:
        Image.Image: Source image with CDL transformation applied
    """
    source = Image.open(source_image_path).convert("RGB")
    target = Image.open(target_image_path).convert("RGB")

    # The target is resized to the source here, since the source is what we
    # keep — only the color statistics are taken from the target.
    if source.size != target.size:
        target = target.resize(source.size, Image.LANCZOS)

    src_arr = np.array(source).astype(np.float32) / 255.0
    tgt_arr = np.array(target).astype(np.float32) / 255.0

    # Calculate CDL parameters (face-based when the helper supports it).
    slope, offset, power = calculate_cdl_params(src_arr, tgt_arr,
                                                source_image_path, target_image_path)

    # Grade the whole source frame and convert back to 8-bit.
    graded = apply_cdl_transform(src_arr, slope, offset, power, factor)
    return Image.fromarray((graded * 255).astype(np.uint8))
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def extract_face_mask(image_path: str) -> Optional[np.ndarray]:
    """Extract face mask from an image using human parts segmentation.

    Args:
        image_path (str): Path to the image

    Returns:
        np.ndarray or None: Binary face mask, or None if no face found
    """
    try:
        # Imported lazily so this module loads even when the optional
        # segmentation dependency is unavailable.
        from human_parts_segmentation import HumanPartsSegmentation

        masks = HumanPartsSegmentation().segment_parts(image_path, ['face'])
        face_mask = masks.get('face') if 'face' in masks else None

        # Reject tiny detections: require at least 100 face pixels.
        if face_mask is not None and np.sum(face_mask > 0.5) > 100:
            return face_mask
        return None

    except Exception as e:
        print(f"Face extraction failed: {e}")
        return None
|
clothes_segmentation.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import numpy as np
|
| 5 |
+
from typing import Union, Tuple
|
| 6 |
+
from PIL import Image, ImageFilter
|
| 7 |
+
import cv2
|
| 8 |
+
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
|
| 9 |
+
from huggingface_hub import hf_hub_download
|
| 10 |
+
import shutil
|
| 11 |
+
|
| 12 |
+
# Device configuration: run on GPU when CUDA is available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model configuration: logical model name -> Hugging Face repository id.
AVAILABLE_MODELS = {
    "segformer_b2_clothes": "1038lab/segformer_clothes"
}

# Model paths: checkpoints are cached under "<this file's directory>/models".
current_dir = os.path.dirname(os.path.abspath(__file__))
models_dir = os.path.join(current_dir, "models")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def pil2tensor(image: Image.Image) -> torch.Tensor:
    """Convert a PIL Image to a float tensor in [0, 1] with a leading batch axis."""
    scaled = np.array(image).astype(np.float32) / 255.0
    return torch.from_numpy(scaled).unsqueeze(0)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def tensor2pil(image: torch.Tensor) -> Image.Image:
    """Convert a float tensor in [0, 1] back to an 8-bit PIL Image."""
    as_bytes = np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)
    return Image.fromarray(as_bytes)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def image2mask(image: Image.Image) -> torch.Tensor:
    """Convert an image (PIL or tensor) to a 2-D mask tensor from its first channel."""
    tensor = pil2tensor(image) if isinstance(image, Image.Image) else image
    return tensor.squeeze()[..., 0]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def mask2image(mask: torch.Tensor) -> Image.Image:
    """Convert a mask tensor to a PIL Image, adding a batch axis for 2-D input."""
    if len(mask.shape) == 2:
        mask = mask.unsqueeze(0)
    return tensor2pil(mask)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class ClothesSegmentation:
    """
    Standalone clothing segmentation using Segformer model.

    Wraps a Segformer semantic-segmentation checkpoint (downloaded from
    Hugging Face on first use) and exposes methods that return binary
    numpy masks for selected clothing classes.
    """

    def __init__(self):
        # Lazily loaded transformers objects; populated by load_model().
        self.processor = None
        self.model = None
        # Local cache directory for the downloaded checkpoint files.
        self.cache_dir = os.path.join(models_dir, "RMBG", "segformer_clothes")

        # Class mapping for segmentation - label name -> class index in the
        # model's output logits; consistent with latest repo.
        self.class_map = {
            "Background": 0, "Hat": 1, "Hair": 2, "Sunglasses": 3,
            "Upper-clothes": 4, "Skirt": 5, "Pants": 6, "Dress": 7,
            "Belt": 8, "Left-shoe": 9, "Right-shoe": 10, "Face": 11,
            "Left-leg": 12, "Right-leg": 13, "Left-arm": 14, "Right-arm": 15,
            "Bag": 16, "Scarf": 17
        }

    def check_model_cache(self):
        """Check if all required model files exist in the local cache.

        Returns:
            tuple: (cache_ok: bool, human-readable status message: str)
        """
        if not os.path.exists(self.cache_dir):
            return False, "Model directory not found"

        # All three files are needed for from_pretrained() to succeed.
        required_files = [
            'config.json',
            'model.safetensors',
            'preprocessor_config.json'
        ]

        missing_files = [f for f in required_files if not os.path.exists(os.path.join(self.cache_dir, f))]
        if missing_files:
            return False, f"Required model files missing: {', '.join(missing_files)}"
        return True, "Model cache verified"

    def clear_model(self):
        """Release the model and processor and free cached GPU memory."""
        if self.model is not None:
            # Move weights to CPU first so the CUDA allocator can reclaim them.
            self.model.cpu()
            del self.model
            self.model = None
            self.processor = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def download_model_files(self):
        """Download checkpoint files from Hugging Face into the cache directory.

        Returns:
            tuple: (success: bool, status message: str)
        """
        model_id = AVAILABLE_MODELS["segformer_b2_clothes"]
        # Mapping of local save name -> path of the file inside the HF repo.
        model_files = {
            'config.json': 'config.json',
            'model.safetensors': 'model.safetensors',
            'preprocessor_config.json': 'preprocessor_config.json'
        }

        os.makedirs(self.cache_dir, exist_ok=True)
        print(f"Downloading Clothes Segformer model files...")

        try:
            for save_name, repo_path in model_files.items():
                print(f"Downloading {save_name}...")
                downloaded_path = hf_hub_download(
                    repo_id=model_id,
                    filename=repo_path,
                    local_dir=self.cache_dir,
                    local_dir_use_symlinks=False
                )

                # hf_hub_download may place the file in a nested directory;
                # move it to the flat layout that load_model() expects.
                if os.path.dirname(downloaded_path) != self.cache_dir:
                    target_path = os.path.join(self.cache_dir, save_name)
                    shutil.move(downloaded_path, target_path)
            return True, "Model files downloaded successfully"
        except Exception as e:
            return False, f"Error downloading model files: {str(e)}"

    def load_model(self):
        """Ensure model files exist, then load processor and model onto `device`.

        Returns:
            bool: True on success, False on any failure (already logged).
        """
        try:
            # Check and download model if needed
            cache_status, message = self.check_model_cache()
            if not cache_status:
                print(f"Cache check: {message}")
                download_status, download_message = self.download_model_files()
                if not download_status:
                    print(f"❌ {download_message}")
                    return False

            # Load model if needed (processor is None until first load).
            if self.processor is None:
                print("Loading clothes segmentation model...")
                self.processor = SegformerImageProcessor.from_pretrained(self.cache_dir)
                self.model = AutoModelForSemanticSegmentation.from_pretrained(self.cache_dir)
                self.model.eval()
                # Inference-only usage: freeze all parameters.
                for param in self.model.parameters():
                    param.requires_grad = False
                self.model.to(device)
                print("✅ Clothes segmentation model loaded successfully")

            return True

        except Exception as e:
            print(f"❌ Error loading clothes model: {e}")
            self.clear_model()  # Cleanup on error
            return False

    def segment_clothes(self, image_path: str, target_classes: list = None, process_res: int = 512) -> np.ndarray:
        """
        Segment clothing from an image - improved version with process_res parameter.

        Args:
            image_path: Path to the image
            target_classes: List of clothing classes to segment (default: ["Upper-clothes"])
            process_res: Processing resolution (default: 512)

        Returns:
            Binary mask as numpy array (float32, values in {0.0, 1.0}),
            or None on any failure.
        """
        if target_classes is None:
            target_classes = ["Upper-clothes"]

        if not self.load_model():
            print("❌ Cannot load clothes segmentation model")
            return None

        try:
            # Load and preprocess image (cv2 loads BGR; convert to RGB).
            image = cv2.imread(image_path)
            if image is None:
                print(f"❌ Could not load image: {image_path}")
                return None

            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            # (height, width) — used later to upsample logits back to full size.
            original_size = image_rgb.shape[:2]

            # Preprocess image with custom resolution
            pil_image = Image.fromarray(image_rgb)

            # Resize for processing if needed (512 is the default working size).
            if process_res != 512:
                pil_image = pil_image.resize((process_res, process_res), Image.Resampling.LANCZOS)

            inputs = self.processor(images=pil_image, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Run inference
            with torch.no_grad():
                outputs = self.model(**inputs)
                logits = outputs.logits.cpu()

            # Resize logits to original image size before taking the argmax.
            upsampled_logits = nn.functional.interpolate(
                logits,
                size=original_size,
                mode="bilinear",
                align_corners=False,
            )
            pred_seg = upsampled_logits.argmax(dim=1)[0]

            # Combine selected class masks into a single binary mask.
            combined_mask = None
            for class_name in target_classes:
                if class_name in self.class_map:
                    mask = (pred_seg == self.class_map[class_name]).float()
                    if combined_mask is None:
                        combined_mask = mask
                    else:
                        # Union of classes, clamped so overlaps stay at 1.
                        combined_mask = torch.clamp(combined_mask + mask, 0, 1)
                else:
                    print(f"⚠️ Unknown class: {class_name}")

            if combined_mask is None:
                print(f"❌ No valid classes found in: {target_classes}")
                return None

            # Convert to numpy
            mask_np = combined_mask.numpy().astype(np.float32)

            return mask_np

        except Exception as e:
            print(f"❌ Error in clothes segmentation: {e}")
            return None
        finally:
            # Clean up model if not training (consistent with updated repo).
            # NOTE(review): this frees the model after EVERY call, so repeated
            # calls re-load the checkpoint each time — confirm this is intended.
            if self.model is not None and not self.model.training:
                self.clear_model()

    def segment_clothes_with_filters(self, image_path: str, target_classes: list = None,
                                     mask_blur: int = 0, mask_offset: int = 0,
                                     process_res: int = 512) -> np.ndarray:
        """
        Segment clothing with additional filtering options - new method from updated repo.

        Args:
            image_path: Path to the image
            target_classes: List of clothing classes to segment
            mask_blur: Blur amount for mask edges
            mask_offset: Expand/Shrink mask boundary (positive grows, negative shrinks)
            process_res: Processing resolution

        Returns:
            Filtered binary mask as numpy array, the unfiltered mask if
            filtering fails, or None if segmentation itself fails.
        """
        # Get initial mask
        mask = self.segment_clothes(image_path, target_classes, process_res)
        if mask is None:
            return None

        try:
            # Convert to PIL for filtering
            mask_image = Image.fromarray((mask * 255).astype(np.uint8))

            # Apply blur if specified
            if mask_blur > 0:
                mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=mask_blur))

            # Apply offset if specified: MaxFilter dilates, MinFilter erodes.
            if mask_offset != 0:
                if mask_offset > 0:
                    mask_image = mask_image.filter(ImageFilter.MaxFilter(size=mask_offset * 2 + 1))
                else:
                    mask_image = mask_image.filter(ImageFilter.MinFilter(size=-mask_offset * 2 + 1))

            # Convert back to numpy in the 0-1 float range.
            filtered_mask = np.array(mask_image).astype(np.float32) / 255.0
            return filtered_mask

        except Exception as e:
            print(f"❌ Error applying filters: {e}")
            return mask
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
# Standalone function for easy usage
|
| 281 |
+
def segment_upper_clothes(image_path: str) -> np.ndarray:
    """
    Convenience function to segment upper clothes from an image.

    Args:
        image_path: Path to the image

    Returns:
        Binary mask as numpy array or None if failed
    """
    return ClothesSegmentation().segment_clothes(image_path, ["Upper-clothes"])
|
color_matching.py
ADDED
|
@@ -0,0 +1,698 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import matplotlib.figure as figure
|
| 6 |
+
from matplotlib.figure import Figure
|
| 7 |
+
import numpy.typing as npt
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import tempfile
|
| 11 |
+
import time
|
| 12 |
+
|
| 13 |
+
class RegionColorMatcher:
|
| 14 |
+
def __init__(self, factor=1.0, preserve_colors=True, preserve_luminance=True, method="adain"):
|
| 15 |
+
"""
|
| 16 |
+
Initialize the RegionColorMatcher.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
factor: Strength of the color matching (0.0 to 1.0)
|
| 20 |
+
preserve_colors: If True, convert to YUV and preserve color relationships
|
| 21 |
+
preserve_luminance: If True, preserve the luminance when in YUV mode
|
| 22 |
+
method: The color matching method to use (adain, mkl, hm, reinhard, mvgd, hm-mvgd-hm, hm-mkl-hm)
|
| 23 |
+
"""
|
| 24 |
+
self.factor = factor
|
| 25 |
+
self.preserve_colors = preserve_colors
|
| 26 |
+
self.preserve_luminance = preserve_luminance
|
| 27 |
+
self.method = method
|
| 28 |
+
|
| 29 |
+
def match_regions(self, img1_path, img2_path, masks1, masks2):
    """
    Match colors between corresponding masked regions of two images.

    img2 is recolored region-by-region so that each labelled region takes on
    the color statistics of the same-labelled region in img1. Regions are
    processed sequentially, each one updating the shared result tensor.

    Args:
        img1_path: Path to first image (the color reference)
        img2_path: Path to second image (the image that gets recolored)
        masks1: Dictionary of masks for first image {label: binary_mask}
        masks2: Dictionary of masks for second image {label: binary_mask}

    Returns:
        A PIL Image with the color-matched result
    """
    print(f"🎨 Color matching with method: {self.method}")
    print(f"📊 Processing {len(masks1)} regions from img1 and {len(masks2)} regions from img2")

    # Load images
    img1 = Image.open(img1_path).convert("RGB")
    img2 = Image.open(img2_path).convert("RGB")

    # Convert to numpy arrays and normalize to [0,1]
    img1_np = np.array(img1).astype(np.float32) / 255.0
    img2_np = np.array(img2).astype(np.float32) / 255.0

    # Create a copy of the second image as our base for color matching
    # We want to make img2 look like img1's colors
    result_np = np.copy(img2_np)

    # Convert images to PyTorch tensors (H, W, 3 float32 in [0,1])
    img1_tensor = torch.from_numpy(img1_np)
    img2_tensor = torch.from_numpy(img2_np)
    result_tensor = torch.from_numpy(result_np)

    # Track coverage to ensure all regions are processed
    total_coverage = np.zeros(img2_np.shape[:2], dtype=np.float32)
    processed_regions = 0

    # Process each mask region; only labels present in BOTH mask dicts are used
    for label, mask1 in masks1.items():
        if label not in masks2:
            print(f"⚠️ Skipping {label} - not found in masks2")
            continue

        mask2 = masks2[label]

        # Resize masks to match image dimensions if needed
        # (masks are assumed to be 2-D numpy arrays — TODO confirm with callers)
        if mask1.shape != img1_np.shape[:2]:
            mask1 = self._resize_mask(mask1, img1_np.shape[:2])

        if mask2.shape != img2_np.shape[:2]:
            mask2 = self._resize_mask(mask2, img2_np.shape[:2])

        # Check mask coverage
        mask1_pixels = np.sum(mask1 > 0)
        mask2_pixels = np.sum(mask2 > 0)
        print(f"🔍 Processing {label}: {mask1_pixels} pixels (img1) → {mask2_pixels} pixels (img2)")

        if mask1_pixels == 0 or mask2_pixels == 0:
            print(f"⚠️ Skipping {label} - no pixels in mask")
            continue

        # Track coverage (overlapping masks add up past 1.0)
        total_coverage += (mask2 > 0).astype(np.float32)
        processed_regions += 1

        # Convert masks to torch tensors
        mask1_tensor = torch.from_numpy(mask1.astype(np.float32))
        mask2_tensor = torch.from_numpy(mask2.astype(np.float32))

        # Apply color matching for this region based on selected method.
        # Each call takes the running result tensor and returns an updated one,
        # so later regions see the effect of earlier ones.
        if self.method == "adain":
            result_tensor = self._apply_adain_to_region(
                img1_tensor,
                img2_tensor,
                result_tensor,
                mask1_tensor,
                mask2_tensor
            )
        else:
            result_tensor = self._apply_color_matcher_to_region(
                img1_tensor,
                img2_tensor,
                result_tensor,
                mask1_tensor,
                mask2_tensor,
                self.method
            )

        print(f"✅ Completed color matching for {label}")

    # Debug coverage
    total_pixels = img2_np.shape[0] * img2_np.shape[1]
    covered_pixels = np.sum(total_coverage > 0)
    overlap_pixels = np.sum(total_coverage > 1)

    print(f"📊 Coverage summary:")
    print(f"   Total image pixels: {total_pixels}")
    print(f"   Covered pixels: {covered_pixels} ({100*covered_pixels/total_pixels:.1f}%)")
    print(f"   Overlapping pixels: {overlap_pixels} ({100*overlap_pixels/total_pixels:.1f}%)")
    print(f"   Processed regions: {processed_regions}")

    # Convert back to numpy, scale to [0,255] and convert to uint8
    # (region helpers clamp to [0,1], so the cast cannot wrap around)
    result_np = (result_tensor.numpy() * 255.0).astype(np.uint8)

    # Convert to PIL Image
    result_img = Image.fromarray(result_np)

    return result_img
|
| 137 |
+
|
| 138 |
+
def _resize_mask(self, mask, target_shape):
    """
    Resize a [0,1]-valued mask to a new (height, width).

    Args:
        mask: Binary mask array
        target_shape: Target shape (height, width)

    Returns:
        Resized mask array as float32 in [0, 1]
    """
    height, width = target_shape

    # Round-trip through an 8-bit PIL image; NEAREST keeps the mask hard-edged.
    as_image = Image.fromarray((mask * 255).astype(np.uint8))
    scaled = as_image.resize((width, height), Image.NEAREST)

    # Back to a float mask normalized to [0, 1].
    return np.array(scaled).astype(np.float32) / 255.0
|
| 159 |
+
|
| 160 |
+
def _apply_adain_to_region(self, source_img, target_img, result_img, source_mask, target_mask):
    """
    Apply AdaIN to match the statistics of the masked region in source to the target.

    Depending on self.preserve_colors the matching is done either per YUV
    channel (optionally leaving luminance untouched) or per RGB channel.
    Only pixels where target_mask > 0.5 are modified in the result.

    Args:
        source_img: Source image tensor [H,W,3] (reference for color matching)
        target_img: Target image tensor [H,W,3] (to be color matched)
        result_img: Result image tensor to modify [H,W,3]
        source_mask: Binary mask for source image [H,W]
        target_mask: Binary mask for target image [H,W]

    Returns:
        Modified result tensor, clamped to [0, 1]
    """
    # Ensure masks are binary
    source_mask_binary = (source_mask > 0.5).float()
    target_mask_binary = (target_mask > 0.5).float()

    # If preserving colors, convert to YUV
    if self.preserve_colors:
        # RGB to YUV conversion matrix (standard BT.601-style coefficients)
        rgb_to_yuv = torch.tensor([
            [0.299, 0.587, 0.114],
            [-0.14713, -0.28886, 0.436],
            [0.615, -0.51499, -0.10001]
        ])

        # Convert to YUV
        source_yuv = torch.matmul(source_img, rgb_to_yuv.T)
        target_yuv = torch.matmul(target_img, rgb_to_yuv.T)
        result_yuv = torch.matmul(result_img, rgb_to_yuv.T)

        # Only normalize Y channel if preserving luminance is False
        channels_to_process = [0] if not self.preserve_luminance else []

        # Always process U and V channels (chroma)
        channels_to_process.extend([1, 2])

        # Process selected channels
        for c in channels_to_process:
            # NOTE: result_channel is a view of result_yuv; torch.where builds a
            # new tensor before the in-place slice assignment below, so the
            # aliasing is safe but the write order matters.
            result_channel = result_yuv[:,:,c]
            matched_channel = self._match_channel_statistics(
                source_yuv[:,:,c],
                target_yuv[:,:,c],
                result_channel,
                source_mask_binary,
                target_mask_binary
            )

            # Only update the masked region in the result
            # (the expand/slice round-trip is equivalent to target_mask_binary itself)
            mask_expanded = target_mask_binary.unsqueeze(-1).expand_as(result_yuv)[:,:,c]
            result_yuv[:,:,c] = torch.where(
                mask_expanded > 0.5,
                matched_channel,
                result_channel
            )

        # Convert back to RGB (inverse of the matrix above)
        yuv_to_rgb = torch.tensor([
            [1.0, 0.0, 1.13983],
            [1.0, -0.39465, -0.58060],
            [1.0, 2.03211, 0.0]
        ])

        result_rgb = torch.matmul(result_yuv, yuv_to_rgb.T)

        # Only update the masked region in the result; unmasked pixels keep the
        # original RGB values rather than the YUV round-trip.
        mask_expanded = target_mask_binary.unsqueeze(-1).expand_as(result_img)
        result_img = torch.where(
            mask_expanded > 0.5,
            result_rgb,
            result_img
        )

    else:
        # Process each RGB channel separately
        for c in range(3):
            # Apply the color matching only to the masked region in the result
            result_channel = result_img[:,:,c]
            matched_channel = self._match_channel_statistics(
                source_img[:,:,c],
                target_img[:,:,c],
                result_channel,
                source_mask_binary,
                target_mask_binary
            )

            # Only update the masked region in the result.
            # NOTE(review): this writes into result_img in place, mutating the
            # caller's tensor as well as returning it — confirm callers rely on
            # the return value only.
            mask_expanded = target_mask_binary.unsqueeze(-1).expand_as(result_img)[:,:,c]
            result_img[:,:,c] = torch.where(
                mask_expanded > 0.5,
                matched_channel,
                result_channel
            )

    # Ensure values are in valid range [0, 1]
    return torch.clamp(result_img, 0.0, 1.0)
|
| 258 |
+
|
| 259 |
+
def _apply_color_matcher_to_region(self, source_img, target_img, result_img, source_mask, target_mask, method):
    """
    Apply color-matcher library methods to match the statistics of the masked region in source to the target.

    Dispatches on *method*; any failure at any stage falls back to the in-house
    AdaIN implementation so the pipeline never hard-fails on a region.

    Args:
        source_img: Source image tensor [H,W,3] (reference for color matching)
        target_img: Target image tensor [H,W,3] (to be color matched)
        result_img: Result image tensor to modify [H,W,3]
        source_mask: Binary mask for source image [H,W]
        target_mask: Binary mask for target image [H,W]
        method: The color matching method to use
            (mkl, hm, reinhard, mvgd, coral, hm-mvgd-hm, hm-mkl-hm)

    Returns:
        Modified result tensor, clamped to [0, 1]
    """
    # Ensure masks are binary
    source_mask_binary = (source_mask > 0.5).float()
    target_mask_binary = (target_mask > 0.5).float()

    # Convert tensors to numpy arrays for the numpy/color-matcher code paths
    source_np = source_img.detach().cpu().numpy()
    target_np = target_img.detach().cpu().numpy()
    source_mask_np = source_mask_binary.detach().cpu().numpy()
    target_mask_np = target_mask_binary.detach().cpu().numpy()

    try:
        # Try to import the color_matcher library, installing it at runtime on
        # first use. NOTE(review): Normalizer is imported but never used here.
        try:
            from color_matcher import ColorMatcher
            from color_matcher.normalizer import Normalizer
        except ImportError:
            self._install_package("color-matcher")
            from color_matcher import ColorMatcher
            from color_matcher.normalizer import Normalizer

        # Extract only the masked pixels from both images
        source_coords = np.where(source_mask_np > 0.5)
        target_coords = np.where(target_mask_np > 0.5)

        if len(source_coords[0]) == 0 or len(target_coords[0]) == 0:
            # Nothing to match — return the result untouched (unclamped).
            return result_img

        # Extract pixel values from masked regions: shape (n_pixels, 3)
        source_pixels = source_np[source_coords]
        target_pixels = target_np[target_coords]

        # Initialize color matcher (only actually used by the "hm*" branch)
        cm = ColorMatcher()

        if method == "mkl":
            # NOTE(review): despite the name, this branch is plain per-channel
            # mean/std matching computed from the masked regions, not true MKL.
            source_mean = np.mean(source_pixels, axis=0)
            source_std = np.std(source_pixels, axis=0)
            target_mean = np.mean(target_pixels, axis=0)
            target_std = np.std(target_pixels, axis=0)

            # Apply the transformation
            result_np = np.copy(target_np)
            for c in range(3):
                # Normalize the target channel and scale by source statistics
                normalized = (target_np[:,:,c] - target_mean[c]) / (target_std[c] + 1e-8) * source_std[c] + source_mean[c]

                # Only apply to masked region
                result_np[:,:,c] = np.where(target_mask_np > 0.5, normalized, target_np[:,:,c])

            # Convert back to tensor
            result_tensor = torch.from_numpy(result_np).to(result_img.device)

            # Blend with original based on factor
            result_img = torch.lerp(result_img, result_tensor, self.factor)

        elif method == "reinhard":
            # Identical code to the "mkl" branch above (same simple statistics
            # matching); kept separate to mirror the method names.
            source_mean = np.mean(source_pixels, axis=0)
            source_std = np.std(source_pixels, axis=0)
            target_mean = np.mean(target_pixels, axis=0)
            target_std = np.std(target_pixels, axis=0)

            # Apply the transformation
            result_np = np.copy(target_np)
            for c in range(3):
                # Normalize the target channel and scale by source statistics
                normalized = (target_np[:,:,c] - target_mean[c]) / (target_std[c] + 1e-8) * source_std[c] + source_mean[c]

                # Only apply to masked region
                result_np[:,:,c] = np.where(target_mask_np > 0.5, normalized, target_np[:,:,c])

            # Convert back to tensor
            result_tensor = torch.from_numpy(result_np).to(result_img.device)

            # Blend with original based on factor
            result_img = torch.lerp(result_img, result_tensor, self.factor)

        elif method == "mvgd":
            # For MVGD, we need mean and covariance matrices
            source_mean = np.mean(source_pixels, axis=0)
            source_cov = np.cov(source_pixels, rowvar=False)
            target_mean = np.mean(target_pixels, axis=0)
            target_cov = np.cov(target_pixels, rowvar=False)

            # Check if covariance matrices are valid
            if np.isnan(source_cov).any() or np.isnan(target_cov).any():
                # Fallback to simple statistics matching
                source_std = np.std(source_pixels, axis=0)
                target_std = np.std(target_pixels, axis=0)

                result_np = np.copy(target_np)
                for c in range(3):
                    normalized = (target_np[:,:,c] - target_mean[c]) / (target_std[c] + 1e-8) * source_std[c] + source_mean[c]
                    result_np[:,:,c] = np.where(target_mask_np > 0.5, normalized, target_np[:,:,c])
            else:
                # Apply full MVGD transformation to masked pixels
                # (target_flat is computed but unused below)
                target_flat = target_np.reshape(-1, 3)
                result_np = np.copy(target_np)

                try:
                    # Try to compute the full MVGD transformation via Cholesky
                    source_cov_sqrt = np.linalg.cholesky(source_cov)
                    target_cov_sqrt = np.linalg.cholesky(target_cov)
                    target_cov_sqrt_inv = np.linalg.inv(target_cov_sqrt)

                    # Compute the transformation matrix mapping target cov to source cov
                    temp = target_cov_sqrt_inv @ source_cov @ target_cov_sqrt_inv.T
                    temp_sqrt_inv = np.linalg.inv(np.linalg.cholesky(temp))
                    A = target_cov_sqrt @ temp_sqrt_inv @ target_cov_sqrt_inv

                    # Apply the transformation to masked pixels.
                    # NOTE(review): per-pixel Python loop — O(H*W) interpreter
                    # iterations; slow on large images.
                    for i in range(target_np.shape[0]):
                        for j in range(target_np.shape[1]):
                            if target_mask_np[i, j] > 0.5:
                                # Only apply to masked pixels
                                pixel = target_np[i, j]
                                centered = pixel - target_mean
                                transformed = centered @ A.T + source_mean
                                result_np[i, j] = transformed
                except np.linalg.LinAlgError:
                    # Fallback to simple statistics matching
                    source_std = np.std(source_pixels, axis=0)
                    target_std = np.std(target_pixels, axis=0)

                    for c in range(3):
                        normalized = (target_np[:,:,c] - target_mean[c]) / (target_std[c] + 1e-8) * source_std[c] + source_mean[c]
                        result_np[:,:,c] = np.where(target_mask_np > 0.5, normalized, target_np[:,:,c])

            # Convert back to tensor
            result_tensor = torch.from_numpy(result_np).to(result_img.device)

            # Blend with original based on factor
            result_img = torch.lerp(result_img, result_tensor, self.factor)

        elif method in ["hm", "hm-mvgd-hm", "hm-mkl-hm"]:
            # For histogram-based methods, we'll create temporary cropped images with just the masked regions

            # Get the bounding box of the masked regions
            source_min_y, source_min_x = np.min(source_coords[0]), np.min(source_coords[1])
            source_max_y, source_max_x = np.max(source_coords[0]), np.max(source_coords[1])
            target_min_y, target_min_x = np.min(target_coords[0]), np.min(target_coords[1])
            target_max_y, target_max_x = np.max(target_coords[0]), np.max(target_coords[1])

            # Create cropped images with just the masked regions
            source_crop = source_np[source_min_y:source_max_y+1, source_min_x:source_max_x+1].copy()
            target_crop = target_np[target_min_y:target_max_y+1, target_min_x:target_max_x+1].copy()

            # Create cropped masks
            source_mask_crop = source_mask_np[source_min_y:source_max_y+1, source_min_x:source_max_x+1]
            target_mask_crop = target_mask_np[target_min_y:target_max_y+1, target_min_x:target_max_x+1]

            # Apply the mask to the cropped images
            # For non-masked areas, use the average color so the histogram is
            # not polluted by unrelated background pixels
            source_avg_color = np.mean(source_pixels, axis=0)
            target_avg_color = np.mean(target_pixels, axis=0)

            for c in range(3):
                source_crop[:, :, c] = np.where(source_mask_crop > 0.5, source_crop[:, :, c], source_avg_color[c])
                target_crop[:, :, c] = np.where(target_mask_crop > 0.5, target_crop[:, :, c], target_avg_color[c])

            try:
                # Use the color matcher directly on the masked regions
                matched_crop = cm.transfer(src=target_crop, ref=source_crop, method=method)

                # Apply the matched colors back to the original image, only in the masked region
                result_np = np.copy(target_np)

                # Create a mapping from crop coordinates to original image coordinates
                # NOTE(review): per-pixel Python loop — slow on large crops.
                for i in range(target_crop.shape[0]):
                    for j in range(target_crop.shape[1]):
                        orig_i = target_min_y + i
                        orig_j = target_min_x + j
                        if orig_i < target_np.shape[0] and orig_j < target_np.shape[1] and target_mask_np[orig_i, orig_j] > 0.5:
                            result_np[orig_i, orig_j] = matched_crop[i, j]

                # Convert back to tensor
                result_tensor = torch.from_numpy(result_np).to(result_img.device)

                # Blend with original based on factor
                result_img = torch.lerp(result_img, result_tensor, self.factor)

            except Exception as e:
                # Fallback to AdaIN if color matcher fails
                print(f"Color matcher failed for {method}, using fallback: {str(e)}")
                result_img = self._apply_adain_to_region(
                    source_img,
                    target_img,
                    result_img,
                    source_mask_binary,
                    target_mask_binary
                )

        elif method == "coral":
            # For CORAL method, extract masked regions and apply CORAL color transfer
            try:
                # Create masked versions of the images
                source_masked = source_np.copy()
                target_masked = target_np.copy()

                # Apply masks - set non-masked areas to average color
                source_avg_color = np.mean(source_pixels, axis=0)
                target_avg_color = np.mean(target_pixels, axis=0)

                for c in range(3):
                    source_masked[:, :, c] = np.where(source_mask_np > 0.5, source_masked[:, :, c], source_avg_color[c])
                    target_masked[:, :, c] = np.where(target_mask_np > 0.5, target_masked[:, :, c], target_avg_color[c])

                # Convert to torch tensors and rearrange to [C, H, W]
                source_tensor = torch.from_numpy(source_masked).permute(2, 0, 1).float()
                target_tensor = torch.from_numpy(target_masked).permute(2, 0, 1).float()

                # Apply CORAL color transfer (module-level coral() helper)
                matched_tensor = coral(target_tensor, source_tensor)  # target gets matched to source

                # Convert back to [H, W, C] format
                matched_np = matched_tensor.permute(1, 2, 0).numpy()

                # Apply the matched colors back to the original image, only in the masked region
                result_np = np.copy(target_np)
                for c in range(3):
                    result_np[:, :, c] = np.where(target_mask_np > 0.5, matched_np[:, :, c], target_np[:, :, c])

                # Convert back to tensor
                result_tensor = torch.from_numpy(result_np).to(result_img.device)

                # Blend with original based on factor
                result_img = torch.lerp(result_img, result_tensor, self.factor)

            except Exception as e:
                # Fallback to AdaIN if CORAL fails
                print(f"CORAL failed for {method}, using fallback: {str(e)}")
                result_img = self._apply_adain_to_region(
                    source_img,
                    target_img,
                    result_img,
                    source_mask_binary,
                    target_mask_binary
                )
        else:
            # Default to AdaIN for unsupported methods
            result_img = self._apply_adain_to_region(
                source_img,
                target_img,
                result_img,
                source_mask_binary,
                target_mask_binary
            )

    except Exception as e:
        # If all fails, fallback to AdaIN
        print(f"Error in color matching: {str(e)}, using AdaIN as fallback")
        result_img = self._apply_adain_to_region(
            source_img,
            target_img,
            result_img,
            source_mask_binary,
            target_mask_binary
        )

    return torch.clamp(result_img, 0.0, 1.0)
|
| 536 |
+
|
| 537 |
+
def _match_channel_statistics(self, source_channel, target_channel, result_channel, source_mask, target_mask):
|
| 538 |
+
"""
|
| 539 |
+
Match the statistics of a single channel.
|
| 540 |
+
|
| 541 |
+
Args:
|
| 542 |
+
source_channel: Source channel [H,W] (reference for color matching)
|
| 543 |
+
target_channel: Target channel [H,W] (to be color matched)
|
| 544 |
+
result_channel: Result channel to modify [H,W]
|
| 545 |
+
source_mask: Binary mask for source [H,W]
|
| 546 |
+
target_mask: Binary mask for target [H,W]
|
| 547 |
+
|
| 548 |
+
Returns:
|
| 549 |
+
Modified result channel
|
| 550 |
+
"""
|
| 551 |
+
# Count non-zero elements in masks
|
| 552 |
+
source_count = torch.sum(source_mask)
|
| 553 |
+
target_count = torch.sum(target_mask)
|
| 554 |
+
|
| 555 |
+
if source_count > 0 and target_count > 0:
|
| 556 |
+
# Calculate statistics only from masked regions
|
| 557 |
+
source_masked = source_channel * source_mask
|
| 558 |
+
target_masked = target_channel * target_mask
|
| 559 |
+
|
| 560 |
+
# Calculate mean
|
| 561 |
+
source_mean = torch.sum(source_masked) / source_count
|
| 562 |
+
target_mean = torch.sum(target_masked) / target_count
|
| 563 |
+
|
| 564 |
+
# Calculate variance
|
| 565 |
+
source_var = torch.sum(((source_channel - source_mean) * source_mask) ** 2) / source_count
|
| 566 |
+
target_var = torch.sum(((target_channel - target_mean) * target_mask) ** 2) / target_count
|
| 567 |
+
|
| 568 |
+
# Calculate std (add small epsilon to avoid division by zero)
|
| 569 |
+
source_std = torch.sqrt(source_var + 1e-8)
|
| 570 |
+
target_std = torch.sqrt(target_var + 1e-8)
|
| 571 |
+
|
| 572 |
+
# Apply AdaIN to the masked region
|
| 573 |
+
normalized = ((target_channel - target_mean) / target_std) * source_std + source_mean
|
| 574 |
+
|
| 575 |
+
# Blend with original based on factor
|
| 576 |
+
result = torch.lerp(target_channel, normalized, self.factor)
|
| 577 |
+
|
| 578 |
+
return result
|
| 579 |
+
|
| 580 |
+
return result_channel
|
| 581 |
+
|
| 582 |
+
def _install_package(self, package_name):
    """Install *package_name* into the running interpreter via pip."""
    import subprocess
    # Use this interpreter's pip so the install lands in the active environment.
    pip_cmd = [sys.executable, "-m", "pip", "install", package_name]
    subprocess.check_call(pip_cmd)
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
def create_comparison_figure(original_img, matched_img, title="Color Matching Comparison"):
    """
    Build a side-by-side matplotlib figure of the original and matched images.

    Args:
        original_img: Original PIL Image
        matched_img: Color-matched PIL Image
        title: Title for the figure

    Returns:
        matplotlib Figure
    """
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    # Left panel: original; right panel: color-matched result.
    panels = ((original_img, "Original"), (matched_img, "Color Matched"))
    for ax, (img, label) in zip(axes, panels):
        ax.imshow(img)
        ax.set_title(label)
        ax.axis('off')

    plt.suptitle(title)
    plt.tight_layout()

    return fig
|
| 614 |
+
|
| 615 |
+
def coral(source, target):
    """
    CORAL-style color transfer via Cholesky whitening/re-coloring.

    Re-colors *source* so that its per-channel means and channel covariance
    match those of *target*. Falls back to plain per-channel mean/std matching
    when the decomposition fails (degenerate covariance).

    Args:
        source: Source image tensor [C, H, W] (to be color matched)
        target: Target image tensor [C, H, W] (reference for color matching)

    Returns:
        Color-matched source image tensor [C, H, W], clamped to [0, 1]
    """
    src = source.float()
    ref = target.float()

    # Flatten spatial dims: [C, H*W]
    channels, height, width = src.shape
    src_flat = src.view(channels, -1)
    ref_flat = ref.view(channels, -1)

    # Per-channel means.
    src_mu = torch.mean(src_flat, dim=1, keepdim=True)
    ref_mu = torch.mean(ref_flat, dim=1, keepdim=True)

    # Zero-centered pixel matrices.
    src_zero = src_flat - src_mu
    ref_zero = ref_flat - ref_mu

    # Sample covariance matrices, [C, C].
    n_pixels = src_zero.shape[1]
    cov_src = torch.mm(src_zero, src_zero.t()) / (n_pixels - 1)
    cov_ref = torch.mm(ref_zero, ref_zero.t()) / (n_pixels - 1)

    # Regularize the diagonals for numerical stability of the factorization.
    eps = 1e-5
    cov_src += eps * torch.eye(channels, device=src.device)
    cov_ref += eps * torch.eye(channels, device=src.device)

    try:
        # Whiten with the source factor, re-color with the target factor:
        # A = L_ref @ L_src^{-1} maps cov_src onto cov_ref.
        chol_src = torch.linalg.cholesky(cov_src)
        chol_ref = torch.linalg.cholesky(cov_ref)
        recolor = torch.mm(chol_ref, torch.linalg.inv(chol_src))

        # Transform, restore the target mean, and reshape back to [C, H, W].
        matched = torch.mm(recolor, src_zero) + ref_mu
        return torch.clamp(matched.view(channels, height, width), 0.0, 1.0)

    except Exception as e:
        # Degenerate covariance: match mean/std channel-wise instead.
        print(f"CORAL Cholesky failed, using simple statistics matching: {e}")

        std_src = torch.clamp(torch.std(src_zero, dim=1, keepdim=True), min=eps)
        std_ref = torch.std(ref_zero, dim=1, keepdim=True)

        matched = (src_zero / std_src) * std_ref + ref_mu
        return torch.clamp(matched.view(channels, height, width), 0.0, 1.0)
|
core.py
ADDED
|
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Core part of LaDeco v2
|
| 2 |
+
|
| 3 |
+
Example usage:
|
| 4 |
+
>>> from core import Ladeco
|
| 5 |
+
>>> from PIL import Image
|
| 6 |
+
>>> from pathlib import Path
|
| 7 |
+
>>>
|
| 8 |
+
>>> # predict
|
| 9 |
+
>>> ldc = Ladeco()
|
| 10 |
+
>>> imgs = (thing for thing in Path("example").glob("*.jpg"))
|
| 11 |
+
>>> out = ldc.predict(imgs)
|
| 12 |
+
>>>
|
| 13 |
+
>>> # output - visualization
|
| 14 |
+
>>> segs = out.visualize(level=2)
|
| 15 |
+
>>> segs[0].image.show()
|
| 16 |
+
>>>
|
| 17 |
+
>>> # output - element area
|
| 18 |
+
>>> area = out.area()
|
| 19 |
+
>>> area[0]
|
| 20 |
+
{"fid": "example/.jpg", "l1_nature": 0.673, "l1_man_made": 0.241, ...}
|
| 21 |
+
"""
|
| 22 |
+
from matplotlib.patches import Rectangle
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from PIL import Image
|
| 25 |
+
from transformers import AutoModelForUniversalSegmentation, AutoProcessor
|
| 26 |
+
import math
|
| 27 |
+
import matplotlib as mpl
|
| 28 |
+
import matplotlib.pyplot as plt
|
| 29 |
+
import numpy as np
|
| 30 |
+
import torch
|
| 31 |
+
from functools import lru_cache
|
| 32 |
+
from matplotlib.figure import Figure
|
| 33 |
+
import numpy.typing as npt
|
| 34 |
+
from typing import Iterable, NamedTuple, Generator
|
| 35 |
+
from tqdm import tqdm
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class LadecoVisualization(NamedTuple):
    """One visualization result: the source image path plus its rendered figure."""
    # Path of the input image this visualization was produced from.
    filename: str
    # Matplotlib figure containing the colorized segmentation map.
    image: Figure
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class Ladeco:
    """Landscape-element segmenter built on a pretrained OneFormer (ADE20K) model.

    Maps the 150 ADE20K semantic classes onto the 4-level LaDeco label
    hierarchy (l1_* .. l4_*) produced by ``_get_ladeco_labels``.
    """

    def __init__(self,
        model_name: str = "shi-labs/oneformer_ade20k_swin_large",
        area_threshold: float = 0.01,
        device: str | None = None,
    ):
        # Auto-select GPU when available unless the caller pins a device.
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = AutoModelForUniversalSegmentation.from_pretrained(model_name).to(self.device)

        # Area ratios below this fraction of the image are reported as 0 (see LadecoOutput.iarea).
        self.area_threshold = area_threshold

        # ADE20K label name -> class index, taken from the model config.
        self.ade20k_labels = {
            name.strip(): int(idx)
            for name, idx in self.model.config.label2id.items()
        }
        # LaDeco label (e.g. "l2_vegetation") -> tuple of ADE20K class indices.
        self.ladeco2ade20k: dict[str, tuple[int]] = _get_ladeco_labels(self.ade20k_labels)

    def predict(
        self, image_paths: str | Path | Iterable[str | Path], show_progress: bool = False
    ) -> "LadecoOutput":
        """Segment one or many images.

        Args:
            image_paths: A single path or an iterable of image paths.
            show_progress: Show a tqdm progress bar while segmenting.

        Returns:
            A ``LadecoOutput`` holding one semantic mask per input image.
        """
        if isinstance(image_paths, (str, Path)):
            imgpaths = [image_paths]
        else:
            imgpaths = list(image_paths)

        # Lazily decode images so only one is held in memory at a time.
        images = (
            Image.open(img_path).convert("RGB")
            for img_path in imgpaths
        )

        # batch inference functionality of OneFormer is broken, so run images one at a time
        masks: list[torch.Tensor] = []
        for img in tqdm(images, total=len(imgpaths), desc="Segmenting", disable=not show_progress):
            samples = self.processor(
                images=img, task_inputs=["semantic"], return_tensors="pt"
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**samples)

            # post_process returns one (H, W) class-index tensor per image; batch size is 1 here.
            masks.append(
                self.processor.post_process_semantic_segmentation(outputs)[0]
            )

        return LadecoOutput(imgpaths, masks, self.ladeco2ade20k, self.area_threshold)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class LadecoOutput:
    """Holds per-image semantic masks and derives visualizations and area statistics."""

    def __init__(
        self,
        filenames: list[str | Path],
        masks: list[torch.Tensor],
        ladeco2ade: dict[str, tuple[int]],
        threshold: float,
    ):
        # One (H, W) ADE20K class-index tensor per input file, aligned with `filenames`.
        self.filenames = filenames
        self.masks = masks
        # LaDeco label -> tuple of ADE20K class indices.
        self.ladeco2ade: dict[str, tuple[int]] = ladeco2ade
        # Reverse lookup: ADE20K class index -> one LaDeco label.
        # NOTE(review): an ADE20K index appearing under several LaDeco labels keeps
        # only the last one encountered — confirm this is intended.
        self.ade2ladeco: dict[int, str] = {
            idx: label
            for label, indices in self.ladeco2ade.items()
            for idx in indices
        }
        # Area ratios strictly below this value are reported as 0 by iarea().
        self.threshold = threshold

    def visualize(self, level: int) -> list[LadecoVisualization]:
        """Eager wrapper around :meth:`ivisualize`."""
        return list(self.ivisualize(level))

    def ivisualize(self, level: int) -> Generator[LadecoVisualization, None, None]:
        """Yield one colorized segmentation figure per image for the given LaDeco level (1-4)."""
        colormaps = self.color_map(level)
        labelnames = [name for name in self.ladeco2ade if name.startswith(f"l{level}")]

        for fname, mask in zip(self.filenames, self.masks):
            size = mask.shape + (3,)  # (H, W, RGB)
            # Pixels whose class belongs to no level-`level` label stay black.
            vis = torch.zeros(size, dtype=torch.uint8)
            for name in labelnames:
                for idx in self.ladeco2ade[name]:
                    # colormaps values are floats in [0, 1]; scale to 0-255.
                    color = torch.tensor(colormaps[name] * 255, dtype=torch.uint8)
                    vis[mask == idx] = color

            # Resize the mask rendering back to the original image dimensions.
            with Image.open(fname) as img:
                target_size = img.size
            vis = Image.fromarray(vis.numpy(), mode="RGB").resize(target_size)

            fig, ax = plt.subplots()
            ax.imshow(vis)
            ax.axis('off')

            yield LadecoVisualization(filename=str(fname), image=fig)

    def area(self) -> list[dict[str, float | str]]:
        """Eager wrapper around :meth:`iarea`."""
        return list(self.iarea())

    def iarea(self) -> Generator[dict[str, float | str], None, None]:
        """Yield, per image, the area ratio of every LaDeco label plus derived metrics.

        Each dict contains the filename ("fid"), one entry per LaDeco label
        (ratio in [0, 1], thresholded to 0 when below ``self.threshold``),
        an "others" remainder, and "LC_NFI" (nature fraction index).
        """
        n_label_ADE20k = 150
        for filename, mask in zip(self.filenames, self.masks):
            # Fraction of pixels assigned to each ADE20K class.
            ade_ratios = torch.tensor([(mask == i).count_nonzero() / mask.numel() for i in range(n_label_ADE20k)])
            # Sum the per-class ratios belonging to each LaDeco label.
            ldc_ratios: dict[str, float] = {
                label: round(ade_ratios[list(ade_indices)].sum().item(), 4)
                for label, ade_indices in self.ladeco2ade.items()
            }
            # Suppress negligible areas below the configured threshold.
            ldc_ratios: dict[str, float] = {
                label: 0 if ratio < self.threshold else ratio
                for label, ratio in ldc_ratios.items()
            }
            # Pixels belonging to neither top-level category (after thresholding).
            others = round(1 - ldc_ratios["l1_nature"] - ldc_ratios["l1_man_made"], 4)
            # Nature fraction index; 1e-6 guards against division by zero.
            nfi = round(ldc_ratios["l1_nature"]/ (ldc_ratios["l1_nature"] + ldc_ratios.get("l1_man_made", 0) + 1e-6), 4)

            yield {
                "fid": str(filename), **ldc_ratios, "others": others, "LC_NFI": nfi,
            }

    def color_map(self, level: int) -> dict[str, npt.NDArray[np.float64]]:
        "returns {'label_name': (R, G, B), ...}, where (R, G, B) in range [0, 1]"
        labels = [
            name for name in self.ladeco2ade.keys() if name.startswith(f"l{level}")
        ]
        if len(labels) == 0:
            raise RuntimeError(
                f"LaDeco only has 4 levels in 1, 2, 3, 4. You assigned {level}."
            )
        # Sample evenly spaced viridis colors, one per label.
        colormap = mpl.colormaps["viridis"].resampled(len(labels)).colors[:, :-1]
        # [:, :-1]: discard alpha channel
        return {name: color for name, color in zip(labels, colormap)}

    def color_legend(self, level: int) -> Figure:
        """Render a swatch-and-name legend figure for the given LaDeco level."""
        colors = self.color_map(level)

        # Column count grows with the number of labels at each level.
        # (color_map has already raised for levels outside 1-4.)
        match level:
            case 1:
                ncols = 1
            case 2:
                ncols = 1
            case 3:
                ncols = 2
            case 4:
                ncols = 5

        # Layout constants in pixels (at the dpi below).
        cell_width = 212
        cell_height = 22
        swatch_width = 48
        margin = 12

        nrows = math.ceil(len(colors) / ncols)

        width = cell_width * ncols + 2 * margin
        height = cell_height * nrows + 2 * margin
        dpi = 72

        fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi)
        fig.subplots_adjust(margin/width, margin/height,
                            (width-margin)/width, (height-margin*2)/height)
        ax.set_xlim(0, cell_width * ncols)
        # Invert the y-axis so row 0 is drawn at the top.
        ax.set_ylim(cell_height * (nrows-0.5), -cell_height/2.)
        ax.yaxis.set_visible(False)
        ax.xaxis.set_visible(False)
        ax.set_axis_off()

        for i, name in enumerate(colors):
            # Fill column by column: consecutive labels go down, then wrap right.
            row = i % nrows
            col = i // nrows
            y = row * cell_height

            swatch_start_x = cell_width * col
            text_pos_x = cell_width * col + swatch_width + 7

            ax.text(text_pos_x, y, name, fontsize=14,
                    horizontalalignment='left',
                    verticalalignment='center')

            ax.add_patch(
                Rectangle(xy=(swatch_start_x, y-9), width=swatch_width,
                          height=18, facecolor=colors[name], edgecolor='0.7')
            )

        ax.set_title(f"LaDeco Color Legend - Level {level}")

        return fig
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def _get_ladeco_labels(ade20k: dict[str, int]) -> dict[str, tuple[int]]:
    """Build the 4-level LaDeco label hierarchy from ADE20K class indices.

    Args:
        ade20k: Mapping of ADE20K label name -> class index (from the model config).

    Returns:
        Mapping of LaDeco label (``l1_*`` .. ``l4_*``) -> tuple of ADE20K class
        indices belonging to that label.  Each higher level is the union
        (tuple concatenation) of its children's index tuples.
    """
    labels = {
        # level 4 labels
        # under l3_architecture
        "l4_hovel": (ade20k["hovel, hut, hutch, shack, shanty"],),
        "l4_building": (ade20k["building"], ade20k["house"]),
        "l4_skyscraper": (ade20k["skyscraper"],),
        "l4_tower": (ade20k["tower"],),
        # under l3_archi_parts
        "l4_step": (ade20k["step, stair"],),
        "l4_canopy": (ade20k["awning, sunshade, sunblind"], ade20k["canopy"]),
        "l4_arcade": (ade20k["arcade machine"],),
        "l4_door": (ade20k["door"],),
        "l4_window": (ade20k["window"],),
        "l4_wall": (ade20k["wall"],),
        # under l3_roadway
        "l4_stairway": (ade20k["stairway, staircase"],),
        "l4_sidewalk": (ade20k["sidewalk, pavement"],),
        "l4_road": (ade20k["road, route"],),
        # under l3_furniture
        "l4_sculpture": (ade20k["sculpture"],),
        "l4_flag": (ade20k["flag"],),
        "l4_can": (ade20k["trash can"],),
        "l4_chair": (ade20k["chair"],),
        "l4_pot": (ade20k["pot"],),
        "l4_booth": (ade20k["booth"],),
        "l4_streetlight": (ade20k["street lamp"],),
        "l4_bench": (ade20k["bench"],),
        "l4_fence": (ade20k["fence"],),
        "l4_table": (ade20k["table"],),
        # under l3_vehicle
        "l4_bike": (ade20k["bicycle"],),
        "l4_motorbike": (ade20k["minibike, motorbike"],),
        "l4_van": (ade20k["van"],),
        "l4_truck": (ade20k["truck"],),
        "l4_bus": (ade20k["bus"],),
        "l4_car": (ade20k["car"],),
        # under l3_sign
        "l4_traffic_sign": (ade20k["traffic light"],),
        "l4_poster": (ade20k["poster, posting, placard, notice, bill, card"],),
        "l4_signboard": (ade20k["signboard, sign"],),
        # under l3_vert_land
        "l4_rock": (ade20k["rock, stone"],),
        "l4_hill": (ade20k["hill"],),
        "l4_mountain": (ade20k["mountain, mount"],),
        # under l3_hori_land
        "l4_ground": (ade20k["earth, ground"], ade20k["land, ground, soil"]),
        "l4_field": (ade20k["field"],),
        "l4_sand": (ade20k["sand"],),
        "l4_dirt": (ade20k["dirt track"],),
        "l4_path": (ade20k["path"],),
        # under l3_flower
        "l4_flower": (ade20k["flower"],),
        # under l3_grass
        "l4_grass": (ade20k["grass"],),
        # under l3_shrub
        "l4_flora": (ade20k["plant"],),
        # under l3_arbor
        "l4_tree": (ade20k["tree"],),
        "l4_palm": (ade20k["palm, palm tree"],),
        # under l3_hori_water
        "l4_lake": (ade20k["lake"],),
        "l4_pool": (ade20k["pool"],),
        "l4_river": (ade20k["river"],),
        "l4_sea": (ade20k["sea"],),
        "l4_water": (ade20k["water"],),
        # under l3_vert_water
        "l4_fountain": (ade20k["fountain"],),
        "l4_waterfall": (ade20k["falls"],),
        # under l3_human
        "l4_person": (ade20k["person"],),
        # under l3_animal
        "l4_animal": (ade20k["animal"],),
        # under l3_sky
        "l4_sky": (ade20k["sky"],),
    }
    # Level 3 labels: unions of the level-4 tuples above.
    labels = labels | {
        # level 3 labels
        # under l2_landform
        "l3_hori_land": labels["l4_ground"] + labels["l4_field"] + labels["l4_sand"] + labels["l4_dirt"] + labels["l4_path"],
        "l3_vert_land": labels["l4_mountain"] + labels["l4_hill"] + labels["l4_rock"],
        # under l2_vegetation
        "l3_woody_plant": labels["l4_tree"] + labels["l4_palm"] + labels["l4_flora"],
        "l3_herb_plant": labels["l4_grass"],
        "l3_flower": labels["l4_flower"],
        # under l2_water
        "l3_hori_water": labels["l4_water"] + labels["l4_sea"] + labels["l4_river"] + labels["l4_pool"] + labels["l4_lake"],
        "l3_vert_water": labels["l4_fountain"] + labels["l4_waterfall"],
        # under l2_bio
        "l3_human": labels["l4_person"],
        "l3_animal": labels["l4_animal"],
        # under l2_sky
        "l3_sky": labels["l4_sky"],
        # under l2_archi
        "l3_architecture": labels["l4_building"] + labels["l4_hovel"] + labels["l4_tower"] + labels["l4_skyscraper"],
        "l3_archi_parts": labels["l4_wall"] + labels["l4_window"] + labels["l4_door"] + labels["l4_arcade"] + labels["l4_canopy"] + labels["l4_step"],
        # under l2_street
        "l3_roadway": labels["l4_road"] + labels["l4_sidewalk"] + labels["l4_stairway"],
        "l3_furniture": labels["l4_table"] + labels["l4_chair"] + labels["l4_fence"] + labels["l4_bench"] + labels["l4_streetlight"] + labels["l4_booth"] + labels["l4_pot"] + labels["l4_can"] + labels["l4_flag"] + labels["l4_sculpture"],
        "l3_vehicle": labels["l4_car"] + labels["l4_bus"] + labels["l4_truck"] + labels["l4_van"] + labels["l4_motorbike"] + labels["l4_bike"],
        "l3_sign": labels["l4_signboard"] + labels["l4_poster"] + labels["l4_traffic_sign"],
    }
    # Level 2 labels: unions of the level-3 tuples.
    labels = labels | {
        # level 2 labels
        # under l1_nature
        "l2_landform": labels["l3_hori_land"] + labels["l3_vert_land"],
        "l2_vegetation": labels["l3_woody_plant"] + labels["l3_herb_plant"] + labels["l3_flower"],
        "l2_water": labels["l3_hori_water"] + labels["l3_vert_water"],
        "l2_bio": labels["l3_human"] + labels["l3_animal"],
        "l2_sky": labels["l3_sky"],
        # under l1_man_made
        "l2_archi": labels["l3_architecture"] + labels["l3_archi_parts"],
        "l2_street": labels["l3_roadway"] + labels["l3_furniture"] + labels["l3_vehicle"] + labels["l3_sign"],
    }
    # Level 1 labels: the two top-level categories.
    labels = labels | {
        # level 1 labels
        "l1_nature": labels["l2_landform"] + labels["l2_vegetation"] + labels["l2_water"] + labels["l2_bio"] + labels["l2_sky"],
        "l1_man_made": labels["l2_archi"] + labels["l2_street"],
    }
    return labels
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
if __name__ == "__main__":
    # Smoke test: segment one of the bundled example images and print its
    # per-element area breakdown.  The previous demo pointed at
    # "images/canyon_3011_00002354.jpg", which is not shipped with this repo,
    # and silently discarded the prediction.
    ldc = Ladeco()
    image = Path("examples") / "beach.jpg"
    out = ldc.predict(image)
    print(out.area()[0])
|
examples/beach.jpg
ADDED
|
examples/field.jpg
ADDED
|
examples/sky.jpg
ADDED
|
face_comparison.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
import numpy as np
from PIL import Image
import tempfile
import os
import subprocess
import sys
import json
from typing import Dict, List, Tuple, Optional
import logging

# Set up logging to suppress DeepFace warnings
logging.getLogger('deepface').setLevel(logging.ERROR)

# DeepFace is an optional dependency: when it is missing, face comparison
# degrades gracefully (FaceComparison.available is False and checks are skipped).
try:
    from deepface import DeepFace
    DEEPFACE_AVAILABLE = True
except ImportError:
    DEEPFACE_AVAILABLE = False
    print("Warning: DeepFace not available. Face comparison will be disabled.")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def run_deepface_in_subprocess(img1_path: str, img2_path: str) -> dict:
    """
    Run DeepFace verification in a separate process to avoid TensorFlow conflicts.

    Args:
        img1_path: Path to the first face image.
        img2_path: Path to the second face image.

    Returns:
        The DeepFace.verify result dict, or {"error": "..."} on failure.
    """
    # Image paths are passed via sys.argv instead of being f-string-interpolated
    # into the generated source: a path containing quotes or backslashes
    # (e.g. Windows temp paths) would otherwise break the script or inject code.
    script_content = '''
import sys
import json
from deepface import DeepFace

try:
    result = DeepFace.verify(img1_path=sys.argv[1], img2_path=sys.argv[2])
    print(json.dumps(result))
except Exception as e:
    print(json.dumps({"error": str(e)}))
'''
    script_path = None
    try:
        # Write the script to a temporary file
        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as script_file:
            script_file.write(script_content)
            script_path = script_file.name

        # Run the script in a subprocess
        result = subprocess.run(
            [sys.executable, script_path, img1_path, img2_path],
            capture_output=True, text=True, timeout=30,
        )

        if result.returncode != 0:
            return {"error": f"Subprocess failed: {result.stderr}"}

        # TensorFlow/DeepFace may emit log lines on stdout before our JSON;
        # the payload is always the last non-empty line.
        lines = [ln for ln in result.stdout.strip().splitlines() if ln.strip()]
        if not lines:
            return {"error": "Subprocess produced no output"}
        return json.loads(lines[-1])

    except Exception as e:
        # Covers subprocess timeout, JSON decode errors, and file-system errors.
        return {"error": str(e)}
    finally:
        # Always remove the temporary script, even if subprocess.run raised
        # (the original leaked the file on that path).
        if script_path is not None:
            try:
                os.unlink(script_path)
            except OSError:
                pass
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class FaceComparison:
    """
    Handles face detection and comparison on full images.
    Only responsible for determining if faces match - does not handle segmentation.
    """

    def __init__(self):
        """
        Initialize face comparison using DeepFace's default verification threshold.
        """
        self.available = DEEPFACE_AVAILABLE
        # Result of the last run_face_comparison() call; None until it runs.
        self.face_match_result = None
        # Human-readable log lines produced by the last comparison.
        self.comparison_log = []

    def extract_faces(self, image_path: str) -> List[np.ndarray]:
        """
        Extract faces from the full image using DeepFace (exactly like the working script).

        Args:
            image_path: Path to the image

        Returns:
            List of face arrays (float arrays as returned by DeepFace.extract_faces;
            presumably RGB in [0, 1] — values are scaled by 255 before saving below)
        """
        if not self.available:
            return []

        try:
            faces = DeepFace.extract_faces(img_path=image_path, detector_backend='opencv')
            if len(faces) == 0:
                return []
            return [f['face'] for f in faces]

        except Exception as e:
            # Best-effort: detection failure is reported but never fatal.
            print(f"Error extracting faces from {image_path}: {str(e)}")
            return []

    def compare_all_faces(self, image1_path: str, image2_path: str) -> Tuple[bool, List[str]]:
        """
        Compare all faces between two images (exactly like the working script).

        Every face found in image 1 is verified against every face found in
        image 2; a single verified pair makes the overall result True.

        Args:
            image1_path: Path to first image
            image2_path: Path to second image

        Returns:
            Tuple of (match_found, log_messages)
        """
        if not self.available:
            return False, ["Face comparison not available - DeepFace not installed"]

        log_messages = []

        try:
            faces1 = self.extract_faces(image1_path)
            faces2 = self.extract_faces(image2_path)

            match_found = False

            log_messages.append(f"Found {len(faces1)} face(s) in Image 1 and {len(faces2)} face(s) in Image 2")

            if len(faces1) == 0 or len(faces2) == 0:
                log_messages.append("❌ No faces found in one or both images")
                return False, log_messages

            # All-pairs comparison: each crop is written to disk because
            # DeepFace.verify takes file paths, not arrays.
            for idx1, face1 in enumerate(faces1):
                for idx2, face2 in enumerate(faces2):
                    # Create temporary files instead of permanent ones (exactly like original)
                    with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as temp1, \
                         tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as temp2:

                        # Convert faces to uint8 and save temporarily (exactly like original)
                        face1_uint8 = (face1 * 255).astype(np.uint8)
                        face2_uint8 = (face2 * 255).astype(np.uint8)

                        # OpenCV expects BGR channel order when writing.
                        cv2.imwrite(temp1.name, cv2.cvtColor(face1_uint8, cv2.COLOR_RGB2BGR))
                        cv2.imwrite(temp2.name, cv2.cvtColor(face2_uint8, cv2.COLOR_RGB2BGR))

                        try:
                            # Try subprocess approach first to avoid TensorFlow conflicts
                            result = run_deepface_in_subprocess(temp1.name, temp2.name)

                            if "error" in result:
                                # If subprocess fails, try direct approach
                                result = DeepFace.verify(img1_path=temp1.name, img2_path=temp2.name)

                            similarity = 1 - result['distance']

                            log_messages.append(f"Comparing Face1-{idx1} to Face2-{idx2} | Similarity: {similarity:.3f}")

                            if result['verified']:
                                log_messages.append(f"✅ Match found between Face1-{idx1} and Face2-{idx2}")
                                match_found = True
                            else:
                                log_messages.append(f"❌ No match between Face1-{idx1} and Face2-{idx2}")

                        except Exception as e:
                            # Log and continue with the remaining pairs.
                            log_messages.append(f"❌ Error comparing Face1-{idx1} to Face2-{idx2}: {str(e)}")

                        # Clean up temporary files immediately (best-effort)
                        try:
                            os.unlink(temp1.name)
                            os.unlink(temp2.name)
                        except:
                            pass

            if not match_found:
                log_messages.append("❌ No matching faces found between the two images.")

            return match_found, log_messages

        except Exception as e:
            log_messages.append(f"Error in face comparison: {str(e)}")
            return False, log_messages

    def run_face_comparison(self, img1_path: str, img2_path: str) -> Tuple[bool, List[str]]:
        """
        Run face comparison and store results for later use.

        Args:
            img1_path: Path to first image
            img2_path: Path to second image

        Returns:
            Tuple of (faces_match, log_messages)
        """
        faces_match, log_messages = self.compare_all_faces(img1_path, img2_path)

        # Store results for later filtering
        self.face_match_result = faces_match
        self.comparison_log = log_messages

        return faces_match, log_messages

    def filter_human_regions_by_face_match(self, masks: Dict[str, np.ndarray]) -> Tuple[Dict[str, np.ndarray], List[str]]:
        """
        Filter human regions based on previously computed face comparison results.
        This only includes/excludes human regions - fine-grained segmentation happens elsewhere.

        Args:
            masks: Dictionary of semantic masks (keyed by LaDeco-style labels,
                   e.g. "l3_human", "l2_bio")

        Returns:
            Tuple of (filtered_masks, log_messages)
        """
        if not self.available:
            return masks, ["Face comparison not available - DeepFace not installed"]

        if self.face_match_result is None:
            return masks, ["No face comparison results available. Run face comparison first."]

        filtered_masks = {}
        log_messages = []

        # Look for human-specific regions (l3_human, not l2_bio which includes animals)
        human_labels = [label for label in masks.keys() if 'l3_human' in label.lower()]
        bio_labels = [label for label in masks.keys() if 'l2_bio' in label.lower()]

        log_messages.append(f"Found human labels: {human_labels}")
        log_messages.append(f"Found bio labels: {bio_labels}")

        # Include all non-human regions regardless of face matching
        for label, mask in masks.items():
            if not any(human_term in label.lower() for human_term in ['l3_human', 'l2_bio']):
                filtered_masks[label] = mask
                log_messages.append(f"✅ Including non-human region: {label}")
            else:
                log_messages.append(f"🔍 Found human/bio region: {label}")

        # Handle human regions based on face matching results
        if self.face_match_result:
            log_messages.append("✅ Faces matched! Including human regions in color matching.")
            # Include human regions since faces matched
            for label in human_labels + bio_labels:
                if label in masks:
                    filtered_masks[label] = masks[label]
                    log_messages.append(f"✅ Including human region (faces matched): {label}")
        else:
            log_messages.append("❌ No face match found. Excluding human regions from color matching.")
            # Don't include human regions since faces didn't match
            for label in human_labels + bio_labels:
                log_messages.append(f"❌ Excluding human region (no face match): {label}")

        log_messages.append(f"📊 Final filtered masks: {list(filtered_masks.keys())}")

        return filtered_masks, log_messages
|
folder_paths.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os

# Minimal stand-in for ComfyUI's ``folder_paths`` module.
current_dir = os.path.dirname(os.path.abspath(__file__))
models_dir = os.path.join(current_dir, "models")

# Registry mapping a model-category name to its directory on disk.
model_folder_paths = {}

def add_model_folder_path(name, path):
    """Register *path* as the folder for model category *name* and create it."""
    model_folder_paths[name] = path
    os.makedirs(path, exist_ok=True)

def get_full_path(dirname, filename):
    """Resolve *filename* inside the registered folder for *dirname*.

    Falls back to ``<models_dir>/<dirname>/<filename>`` when the category
    has not been registered via :func:`add_model_folder_path`.
    """
    registered = model_folder_paths.get(dirname)
    if registered is not None:
        return os.path.join(registered, filename)
    return os.path.join(models_dir, dirname, filename)

# Ensure the root models directory exists at import time.
os.makedirs(models_dir, exist_ok=True)
|
human_parts_segmentation.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import numpy as np
import torch
from PIL import Image, ImageFilter
import cv2
import requests
from typing import Dict, List, Tuple, Optional
import onnxruntime as ort

# Human parts labels based on CCIHP dataset - consistent with latest repo
# Mapping: class index -> (short machine name, human-readable description).
# Index 21 ("Others") intentionally has an empty machine name so it is never
# matched by get_class_index().
HUMAN_PARTS_LABELS = {
    0: ("background", "Background"),
    1: ("hat", "Hat: Hat, helmet, cap, hood, veil, headscarf, part covering the skull and hair of a hood/balaclava, crown…"),
    2: ("hair", "Hair"),
    3: ("glove", "Glove"),
    4: ("glasses", "Sunglasses/Glasses: Sunglasses, eyewear, protective glasses…"),
    5: ("upper_clothes", "UpperClothes: T-shirt, shirt, tank top, sweater under a coat, top of a dress…"),
    6: ("face_mask", "Face Mask: Protective mask, surgical mask, carnival mask, facial part of a balaclava, visor of a helmet…"),
    7: ("coat", "Coat: Coat, jacket worn without anything on it, vest with nothing on it, a sweater with nothing on it…"),
    8: ("socks", "Socks"),
    9: ("pants", "Pants: Pants, shorts, tights, leggings, swimsuit bottoms… (clothing with 2 legs)"),
    10: ("torso-skin", "Torso-skin"),
    11: ("scarf", "Scarf: Scarf, bow tie, tie…"),
    12: ("skirt", "Skirt: Skirt, kilt, bottom of a dress…"),
    13: ("face", "Face"),
    14: ("left-arm", "Left-arm (naked part)"),
    15: ("right-arm", "Right-arm (naked part)"),
    16: ("left-leg", "Left-leg (naked part)"),
    17: ("right-leg", "Right-leg (naked part)"),
    18: ("left-shoe", "Left-shoe"),
    19: ("right-shoe", "Right-shoe"),
    20: ("bag", "Bag: Backpack, shoulder bag, fanny pack… (bag carried on oneself"),
    21: ("", "Others: Jewelry, tags, bibs, belts, ribbons, pins, head decorations, headphones…"),
}

# Model configuration - updated paths consistent with new repos
current_dir = os.path.dirname(os.path.abspath(__file__))
models_dir = os.path.join(current_dir, "models")
models_dir_path = os.path.join(models_dir, "onnx", "human-parts")
# Pretrained DeepLabV3+ (ResNet-50) human-parts ONNX model hosted on Hugging Face.
model_url = "https://huggingface.co/Metal3d/deeplabv3p-resnet50-human/resolve/main/deeplabv3p-resnet50-human.onnx"
model_name = "deeplabv3p-resnet50-human.onnx"
model_path = os.path.join(models_dir_path, model_name)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_class_index(class_name: str) -> int:
    """Return the CCIHP label index for *class_name*, or -1 when unknown/empty."""
    # The empty name is reserved (maps to the catch-all "Others" entry),
    # so it is rejected before scanning the table.
    if class_name == "":
        return -1

    for index, (short_name, _description) in HUMAN_PARTS_LABELS.items():
        if short_name == class_name:
            return index
    return -1
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def download_model(model_url: str, model_path: str) -> bool:
    """Download the human parts segmentation model if not present.

    The file is streamed to ``<model_path>.part`` and only renamed to
    ``model_path`` once the download finished.  The original version wrote
    straight to the final path, so an interrupted download left a truncated
    file that passed the ``os.path.exists`` cache check on every later run.

    Args:
        model_url: Direct URL of the ONNX model file.
        model_path: Destination path for the model file.

    Returns:
        True when the model is already cached or was downloaded successfully,
        False on any error (the partial file is removed in that case).
    """
    if os.path.exists(model_path):
        return True

    temp_path = model_path + ".part"
    try:
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        print(f"Downloading human parts model to {model_path}...")

        # timeout prevents the whole app from hanging on a dead connection.
        response = requests.get(model_url, stream=True, timeout=30)
        response.raise_for_status()

        total_size = int(response.headers.get('content-length', 0))
        downloaded = 0

        with open(temp_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
                downloaded += len(chunk)
                if total_size > 0:
                    percent = (downloaded / total_size) * 100
                    print(f"\rDownload progress: {percent:.1f}%", end='', flush=True)

        # Atomically publish the fully-downloaded file.
        os.replace(temp_path, model_path)
        print("\n✅ Model download completed")
        return True

    except Exception as e:
        print(f"\n❌ Error downloading model: {e}")
        # Remove any partial download so the next attempt starts clean.
        try:
            if os.path.exists(temp_path):
                os.remove(temp_path)
        except OSError:
            pass
        return False
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def get_human_parts_mask(image: torch.Tensor, model: ort.InferenceSession, rotation: float = 0, **kwargs) -> Tuple[torch.Tensor, int]:
    """
    Generate human parts mask using the ONNX model - improved version.

    Args:
        image: Input image tensor; squeezed from a leading batch dim and fed
            to PIL as HWC uint8, so values are assumed to be floats in 0-1
            with shape (1, H, W, C) — TODO confirm against callers.
        model: ONNX inference session for the human-parts model.
        rotation: Rotation angle in degrees; the image is rotated before
            inference and the resulting mask is rotated back afterwards.
        **kwargs: Part-name -> bool flags selecting which classes to merge
            into the output mask (names resolved via get_class_index).

    Returns:
        Tuple of (mask_tensor, score): mask_tensor has shape (1, 1, H, W)
        with float values in 0-1; score accumulates mask pixel sums while
        classes are merged (see NOTE below).
    """
    image = image.squeeze(0)
    image_np = image.numpy() * 255

    pil_image = Image.fromarray(image_np.astype(np.uint8))
    original_size = pil_image.size

    # Resize to 512x512 as the model expects
    pil_image = pil_image.resize((512, 512))
    center = (256, 256)

    if rotation != 0:
        pil_image = pil_image.rotate(rotation, center=center)

    # Normalize the image to the [-1, 1] range before inference.
    image_np = np.array(pil_image).astype(np.float32) / 127.5 - 1
    image_np = np.expand_dims(image_np, axis=0)

    # Use the ONNX model to get the segmentation
    input_name = model.get_inputs()[0].name
    output_name = model.get_outputs()[0].name
    result = model.run([output_name], {input_name: image_np})
    # argmax over the class axis -> a (512, 512) map of class indices.
    result = np.array(result[0]).argmax(axis=3).squeeze(0)

    # Debug: Check what classes the model actually detected
    unique_classes = np.unique(result)

    score = 0
    mask = np.zeros_like(result)

    # Combine masks for enabled classes
    for class_name, enabled in kwargs.items():
        class_index = get_class_index(class_name)
        if enabled and class_index != -1:
            detected = result == class_index
            mask[detected] = 255
            # NOTE(review): this sums the *cumulative* 0/255 mask, so with
            # several enabled parts the earlier parts are counted again on
            # every iteration — confirm whether that weighting is intended
            # (callers in this file discard or only truthiness-check score).
            score += mask.sum()

    # Resize back to original size
    mask_image = Image.fromarray(mask.astype(np.uint8), mode="L")
    if rotation != 0:
        # Undo the pre-inference rotation so the mask aligns with the input.
        mask_image = mask_image.rotate(-rotation, center=center)

    mask_image = mask_image.resize(original_size)

    # Convert back to numpy - improved tensor handling
    mask = np.array(mask_image).astype(np.float32) / 255.0  # Normalize to 0-1 range

    # Add dimensions for torch tensor - consistent format
    mask = np.expand_dims(mask, axis=0)
    mask = np.expand_dims(mask, axis=0)

    return torch.from_numpy(mask), score
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def numpy_to_torch_tensor(image_np: np.ndarray) -> torch.Tensor:
    """Convert a 0-255 numpy image into a float tensor in the 0-1 range.

    A 3-D (H, W, C) array gains a leading batch dimension; arrays of any
    other rank are converted as-is.
    """
    scaled = torch.from_numpy(image_np.astype(np.float32) / 255.0)
    return scaled.unsqueeze(0) if image_np.ndim == 3 else scaled
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def torch_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Convert a mask tensor back to numpy, binarising 0-1 float masks.

    A 4-D input has its leading dimension squeezed away (only if it is of
    size 1).  float32 arrays whose values stay within [0, 1] are thresholded
    at 0.5 into a hard 0/1 float mask; anything else is returned unchanged.
    """
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)

    array = tensor.numpy()
    is_unit_float = array.dtype == np.float32 and array.max() <= 1.0
    return (array > 0.5).astype(np.float32) if is_unit_float else array
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class HumanPartsSegmentation:
    """
    Standalone human parts segmentation for face and hair using DeepLabV3+ ResNet50.

    Wraps the ONNX model referenced by the module-level ``model_path`` /
    ``model_url`` constants.  The session is created on demand inside
    ``segment_parts`` and released again in its ``finally`` block, so an
    instance holds no model state between calls.
    """

    def __init__(self) -> None:
        # Lazily-created onnxruntime InferenceSession; None until load_model().
        self.model = None

    def check_model_cache(self) -> Tuple[bool, str]:
        """Check if model file exists in cache - consistent with updated repos.

        Returns:
            (ok, message) — ok is True when the ONNX file is on disk.
        """
        if not os.path.exists(model_path):
            return False, "Model file not found"
        return True, "Model cache verified"

    def clear_model(self) -> None:
        """Clear model from memory - improved version."""
        if self.model is not None:
            del self.model
            self.model = None

    def load_model(self) -> bool:
        """Load the human parts segmentation model - improved version.

        Downloads the ONNX file first when it is missing.

        Returns:
            True when an inference session is ready, False when the download
            or session creation failed (the session is cleared on failure).
        """
        try:
            # Check and download model if needed
            cache_status, message = self.check_model_cache()
            if not cache_status:
                print(f"Cache check: {message}")
                if not download_model(model_url, model_path):
                    return False

            # Load model if needed
            if self.model is None:
                print("Loading human parts segmentation model...")
                self.model = ort.InferenceSession(model_path)
                print("✅ Human parts segmentation model loaded successfully")

            return True

        except Exception as e:
            print(f"❌ Error loading human parts model: {e}")
            self.clear_model()  # Cleanup on error
            return False

    def segment_parts(self, image_path: str, parts: List[str], mask_blur: int = 0, mask_offset: int = 0) -> Dict[str, np.ndarray]:
        """
        Segment specific human parts from an image - improved version with filtering.

        Args:
            image_path: Path to the image file
            parts: List of part names to segment (e.g., ['face', 'hair'])
            mask_blur: Blur amount for mask edges (Gaussian radius, pixels)
            mask_offset: Expand (>0) / Shrink (<0) mask boundary in pixels

        Returns:
            Dictionary mapping part names to float masks in the 0-1 range
            (empty dict on load/read errors).  Masks may be all-zero when a
            part was not detected.
        """
        if not self.load_model():
            print("❌ Cannot load human parts segmentation model")
            return {}

        try:
            # Load image
            image = cv2.imread(image_path)
            if image is None:
                print(f"❌ Could not load image: {image_path}")
                return {}

            # cv2 delivers BGR; the model pipeline expects RGB.
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # Convert to tensor format expected by the model
            image_tensor = numpy_to_torch_tensor(image_rgb)

            # Prepare kwargs for each part
            part_kwargs = {part: True for part in parts}

            # Get segmentation mask
            mask_tensor, score = get_human_parts_mask(image_tensor, self.model, **part_kwargs)

            # Convert back to numpy (drop batch/channel dims added upstream)
            if len(mask_tensor.shape) == 4:
                mask_tensor = mask_tensor.squeeze(0).squeeze(0)
            elif len(mask_tensor.shape) == 3:
                mask_tensor = mask_tensor.squeeze(0)

            # Get the combined mask for all requested parts
            combined_mask = mask_tensor.numpy()

            # Generate individual masks for each part if multiple parts requested
            result_masks = {}
            if len(parts) == 1:
                # Single part - return the combined mask
                part_name = parts[0]
                final_mask = self._apply_filters(combined_mask, mask_blur, mask_offset)
                # NOTE(review): both branches below assign the same value;
                # the if/else is redundant and kept only for clarity of intent.
                if np.sum(final_mask > 0) > 0:
                    result_masks[part_name] = final_mask
                else:
                    result_masks[part_name] = final_mask  # Return empty mask instead of None
            else:
                # Multiple parts - need to segment each individually.
                # NOTE(review): this re-runs the full model once per part;
                # the per-class maps could be derived from a single inference.
                for part in parts:
                    single_part_kwargs = {part: True}
                    single_mask_tensor, _ = get_human_parts_mask(image_tensor, self.model, **single_part_kwargs)

                    if len(single_mask_tensor.shape) == 4:
                        single_mask_tensor = single_mask_tensor.squeeze(0).squeeze(0)
                    elif len(single_mask_tensor.shape) == 3:
                        single_mask_tensor = single_mask_tensor.squeeze(0)

                    single_mask = single_mask_tensor.numpy()
                    final_mask = self._apply_filters(single_mask, mask_blur, mask_offset)

                    result_masks[part] = final_mask  # Always add mask, even if empty

            return result_masks

        except Exception as e:
            print(f"❌ Error in human parts segmentation: {e}")
            return {}
        finally:
            # Clean up model if not needed
            self.clear_model()

    def _apply_filters(self, mask: np.ndarray, mask_blur: int = 0, mask_offset: int = 0) -> np.ndarray:
        """Apply filtering to mask - new method from updated repo.

        Blurs and/or grows/shrinks the mask via PIL morphological filters;
        on any filtering error the original mask is returned unchanged.
        """
        if mask_blur == 0 and mask_offset == 0:
            return mask

        try:
            # Convert to PIL for filtering (0-1 float -> 0-255 uint8)
            mask_image = Image.fromarray((mask * 255).astype(np.uint8))

            # Apply blur if specified
            if mask_blur > 0:
                mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=mask_blur))

            # Apply offset if specified: MaxFilter dilates, MinFilter erodes.
            # Kernel size must be odd, hence the *2 + 1.
            if mask_offset != 0:
                if mask_offset > 0:
                    mask_image = mask_image.filter(ImageFilter.MaxFilter(size=mask_offset * 2 + 1))
                else:
                    mask_image = mask_image.filter(ImageFilter.MinFilter(size=-mask_offset * 2 + 1))

            # Convert back to numpy in the 0-1 float range
            filtered_mask = np.array(mask_image).astype(np.float32) / 255.0
            return filtered_mask

        except Exception as e:
            print(f"❌ Error applying filters: {e}")
            return mask
|
models/RMBG/segformer_clothes/.cache/huggingface/.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
*
|
models/RMBG/segformer_clothes/.cache/huggingface/download/config.json.lock
ADDED
|
File without changes
|
models/RMBG/segformer_clothes/.cache/huggingface/download/config.json.metadata
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2634bcc40712620e414ffb0efd5f5e4ea732ec5d
|
| 2 |
+
8352c4562bb0e1f72767dcb170ad6f3f56007836
|
| 3 |
+
1748821507.461211
|
models/RMBG/segformer_clothes/.cache/huggingface/download/model.safetensors.lock
ADDED
|
File without changes
|
models/RMBG/segformer_clothes/.cache/huggingface/download/model.safetensors.metadata
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2634bcc40712620e414ffb0efd5f5e4ea732ec5d
|
| 2 |
+
f70ae566c5773fb335796ebaa8acc924ac25eb97222c2b2967d44d2fc11568e6
|
| 3 |
+
1748821512.848557
|
models/RMBG/segformer_clothes/.cache/huggingface/download/preprocessor_config.json.lock
ADDED
|
File without changes
|
models/RMBG/segformer_clothes/.cache/huggingface/download/preprocessor_config.json.metadata
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2634bcc40712620e414ffb0efd5f5e4ea732ec5d
|
| 2 |
+
b2340cf4e53b37fda4f5b92d28f11c0f33c3d0fd
|
| 3 |
+
1748821513.065366
|
models/RMBG/segformer_clothes/config.json
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "nvidia/mit-b3",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"SegformerForSemanticSegmentation"
|
| 5 |
+
],
|
| 6 |
+
"attention_probs_dropout_prob": 0.0,
|
| 7 |
+
"classifier_dropout_prob": 0.1,
|
| 8 |
+
"decoder_hidden_size": 768,
|
| 9 |
+
"depths": [
|
| 10 |
+
3,
|
| 11 |
+
4,
|
| 12 |
+
18,
|
| 13 |
+
3
|
| 14 |
+
],
|
| 15 |
+
"downsampling_rates": [
|
| 16 |
+
1,
|
| 17 |
+
4,
|
| 18 |
+
8,
|
| 19 |
+
16
|
| 20 |
+
],
|
| 21 |
+
"drop_path_rate": 0.1,
|
| 22 |
+
"hidden_act": "gelu",
|
| 23 |
+
"hidden_dropout_prob": 0.0,
|
| 24 |
+
"hidden_sizes": [
|
| 25 |
+
64,
|
| 26 |
+
128,
|
| 27 |
+
320,
|
| 28 |
+
512
|
| 29 |
+
],
|
| 30 |
+
"id2label": {
|
| 31 |
+
"0": "Background",
|
| 32 |
+
"1": "Hat",
|
| 33 |
+
"10": "Right-shoe",
|
| 34 |
+
"11": "Face",
|
| 35 |
+
"12": "Left-leg",
|
| 36 |
+
"13": "Right-leg",
|
| 37 |
+
"14": "Left-arm",
|
| 38 |
+
"15": "Right-arm",
|
| 39 |
+
"16": "Bag",
|
| 40 |
+
"17": "Scarf",
|
| 41 |
+
"2": "Hair",
|
| 42 |
+
"3": "Sunglasses",
|
| 43 |
+
"4": "Upper-clothes",
|
| 44 |
+
"5": "Skirt",
|
| 45 |
+
"6": "Pants",
|
| 46 |
+
"7": "Dress",
|
| 47 |
+
"8": "Belt",
|
| 48 |
+
"9": "Left-shoe"
|
| 49 |
+
},
|
| 50 |
+
"image_size": 224,
|
| 51 |
+
"initializer_range": 0.02,
|
| 52 |
+
"label2id": {
|
| 53 |
+
"Background": "0",
|
| 54 |
+
"Bag": "16",
|
| 55 |
+
"Belt": "8",
|
| 56 |
+
"Dress": "7",
|
| 57 |
+
"Face": "11",
|
| 58 |
+
"Hair": "2",
|
| 59 |
+
"Hat": "1",
|
| 60 |
+
"Left-arm": "14",
|
| 61 |
+
"Left-leg": "12",
|
| 62 |
+
"Left-shoe": "9",
|
| 63 |
+
"Pants": "6",
|
| 64 |
+
"Right-arm": "15",
|
| 65 |
+
"Right-leg": "13",
|
| 66 |
+
"Right-shoe": "10",
|
| 67 |
+
"Scarf": "17",
|
| 68 |
+
"Skirt": "5",
|
| 69 |
+
"Sunglasses": "3",
|
| 70 |
+
"Upper-clothes": "4"
|
| 71 |
+
},
|
| 72 |
+
"layer_norm_eps": 1e-06,
|
| 73 |
+
"mlp_ratios": [
|
| 74 |
+
4,
|
| 75 |
+
4,
|
| 76 |
+
4,
|
| 77 |
+
4
|
| 78 |
+
],
|
| 79 |
+
"model_type": "segformer",
|
| 80 |
+
"num_attention_heads": [
|
| 81 |
+
1,
|
| 82 |
+
2,
|
| 83 |
+
5,
|
| 84 |
+
8
|
| 85 |
+
],
|
| 86 |
+
"num_channels": 3,
|
| 87 |
+
"num_encoder_blocks": 4,
|
| 88 |
+
"patch_sizes": [
|
| 89 |
+
7,
|
| 90 |
+
3,
|
| 91 |
+
3,
|
| 92 |
+
3
|
| 93 |
+
],
|
| 94 |
+
"reshape_last_stage": true,
|
| 95 |
+
"semantic_loss_ignore_index": 255,
|
| 96 |
+
"sr_ratios": [
|
| 97 |
+
8,
|
| 98 |
+
4,
|
| 99 |
+
2,
|
| 100 |
+
1
|
| 101 |
+
],
|
| 102 |
+
"strides": [
|
| 103 |
+
4,
|
| 104 |
+
2,
|
| 105 |
+
2,
|
| 106 |
+
2
|
| 107 |
+
],
|
| 108 |
+
"torch_dtype": "float32",
|
| 109 |
+
"transformers_version": "4.38.1"
|
| 110 |
+
}
|
models/RMBG/segformer_clothes/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f70ae566c5773fb335796ebaa8acc924ac25eb97222c2b2967d44d2fc11568e6
|
| 3 |
+
size 189029000
|
models/RMBG/segformer_clothes/preprocessor_config.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"do_normalize": true,
|
| 3 |
+
"do_reduce_labels": false,
|
| 4 |
+
"do_rescale": true,
|
| 5 |
+
"do_resize": true,
|
| 6 |
+
"image_mean": [
|
| 7 |
+
0.485,
|
| 8 |
+
0.456,
|
| 9 |
+
0.406
|
| 10 |
+
],
|
| 11 |
+
"image_processor_type": "SegformerImageProcessor",
|
| 12 |
+
"image_std": [
|
| 13 |
+
0.229,
|
| 14 |
+
0.224,
|
| 15 |
+
0.225
|
| 16 |
+
],
|
| 17 |
+
"resample": 2,
|
| 18 |
+
"rescale_factor": 0.00392156862745098,
|
| 19 |
+
"size": {
|
| 20 |
+
"height": 512,
|
| 21 |
+
"width": 512
|
| 22 |
+
}
|
| 23 |
+
}
|
models/onnx/human-parts/deeplabv3p-resnet50-human.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a6e823a82da10ba24c29adfb544130684568c46bfac865e215bbace3b4035a71
|
| 3 |
+
size 47210581
|
requirements.txt
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LaDeco requirements
|
| 2 |
+
torch==2.3.1
|
| 3 |
+
torchaudio
|
| 4 |
+
torchvision
|
| 5 |
+
tf-keras
|
| 6 |
+
transformers==4.42.4
|
| 7 |
+
diffusers
|
| 8 |
+
opencv-python
|
| 9 |
+
Pillow
|
| 10 |
+
numpy
|
| 11 |
+
matplotlib
|
| 12 |
+
scipy
|
| 13 |
+
scikit-learn
|
| 14 |
+
|
| 15 |
+
# For Gradio interface
|
| 16 |
+
gradio
|
| 17 |
+
|
| 18 |
+
# Face comparison
|
| 19 |
+
deepface
|
| 20 |
+
|
| 21 |
+
# Human parts segmentation
|
| 22 |
+
onnxruntime
|
| 23 |
+
|
| 24 |
+
# Clothing segmentation
|
| 25 |
+
huggingface-hub>=0.19.0
|
| 26 |
+
segment-anything>=1.0
|
| 27 |
+
|
| 28 |
+
# Color matching dependencies
|
| 29 |
+
color-matcher
|
| 30 |
+
spaces
|
spaces.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
def GPU(func):
    """
    A decorator to indicate that a function should use GPU acceleration if available.
    This is used specifically for Hugging Face Spaces.

    Locally it is a transparent pass-through: the wrapped callable is invoked
    unchanged and its metadata is preserved via functools.wraps.
    """
    @functools.wraps(func)
    def _passthrough(*call_args, **call_kwargs):
        return func(*call_args, **call_kwargs)
    return _passthrough
|