MogensR committed on
Commit
94a9c1b
·
1 Parent(s): e9f947a

Create core/hair_segmentation.py

Browse files
Files changed (1) hide show
  1. core/hair_segmentation.py +795 -0
core/hair_segmentation.py ADDED
@@ -0,0 +1,795 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Advanced hair segmentation pipeline for BackgroundFX Pro.
3
+ Specialized module for accurate hair detection and segmentation.
4
+ """
5
+
6
+ import numpy as np
7
+ import cv2
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from typing import Dict, List, Optional, Tuple, Any
12
+ from dataclasses import dataclass
13
+ import logging
14
+ from scipy import ndimage
15
+ from skimage import morphology, filters
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
@dataclass
class HairConfig:
    """Configuration for hair segmentation.

    All float thresholds are in [0, 1] unless noted otherwise.
    """
    min_hair_confidence: float = 0.6    # probability cutoff for accepting hair pixels
    edge_sensitivity: float = 0.8       # scales the hair-edge detection threshold
    strand_detection: bool = True       # detect fine individual strands
    strand_thickness: int = 2           # line thickness (px) used when drawing strands
    asymmetry_correction: bool = True   # mirror-correct lopsided masks
    max_asymmetry_ratio: float = 1.5    # left/right area ratio above which correction triggers
    # Asymmetry score above which the mask counts as asymmetric.
    # BUG FIX: AsymmetryDetector.detect() reads this attribute; it was missing
    # from the original config and raised AttributeError at runtime.
    symmetry_threshold: float = 0.3
    use_deep_features: bool = False     # run the optional HairNet model
    refinement_iterations: int = 3      # passes of morphological mask refinement
    alpha_matting: bool = True          # guided-filter alpha matting pass
    preserve_details: bool = True       # bilateral (edge-preserving) final smoothing
    smoothing_sigma: float = 1.0        # sigma for the Gaussian fallback smoothing
34
+
35
+
36
class HairSegmentationPipeline:
    """Complete hair segmentation pipeline."""

    def __init__(self, config: Optional[HairConfig] = None):
        """Initialize the pipeline and its helper components.

        Args:
            config: Optional configuration; defaults to ``HairConfig()``.
        """
        self.config = config or HairConfig()
        # BUG FIX: the raw ``config`` argument (possibly None) was forwarded
        # to the helpers; they must receive the resolved config object,
        # otherwise attribute access inside them fails when config is None.
        self.mask_refiner = HairMaskRefiner(self.config)
        self.asymmetry_detector = AsymmetryDetector(self.config)
        self.edge_enhancer = HairEdgeEnhancer(self.config)

        # Optional deep learning model, only built when requested.
        self.deep_model = None
        if self.config.use_deep_features:
            self.deep_model = HairNet()
49
+
50
    def segment(self, image: np.ndarray,
                initial_mask: Optional[np.ndarray] = None,
                prompts: Optional[Dict] = None) -> Dict[str, np.ndarray]:
        """
        Perform complete hair segmentation.

        Args:
            image: Input frame, H x W x 3 (BGR assumed — downstream helpers
                use cv2.COLOR_BGR2* conversions; confirm at call sites).
            initial_mask: Optional float prior mask in [0, 1]; detection is
                constrained to its dilated support and confidence is boosted
                where the final mask agrees with it.
            prompts: Currently unused; reserved for prompt-guided segmentation.

        Returns:
            Dictionary containing:
            - 'mask': Final hair mask
            - 'confidence': Confidence map
            - 'strands': Fine hair strands mask (None when strand detection is off)
            - 'edges': Hair edge map
        """
        # NOTE(review): h and w are currently unused.
        h, w = image.shape[:2]

        # 1. Initial hair detection (color + texture cues)
        hair_regions = self._detect_hair_regions(image, initial_mask)

        # 2. Deep feature extraction (if enabled)
        if self.deep_model and self.config.use_deep_features:
            deep_features = self._extract_deep_features(image)
            hair_regions = self._enhance_with_deep_features(hair_regions, deep_features)

        # 3. Detect and correct left/right asymmetry
        if self.config.asymmetry_correction:
            asymmetry_info = self.asymmetry_detector.detect(hair_regions, image)
            if asymmetry_info['is_asymmetric']:
                logger.info(f"Correcting hair asymmetry: {asymmetry_info['score']:.3f}")
                hair_regions = self.asymmetry_detector.correct(
                    hair_regions, asymmetry_info
                )

        # 4. Detect fine hair strands
        strands_mask = None
        if self.config.strand_detection:
            strands_mask = self._detect_hair_strands(image, hair_regions)
            # Integrate strands into main mask
            hair_regions = self._integrate_strands(hair_regions, strands_mask)

        # 5. Refine mask (iterative morphological clean-up)
        refined_mask = self.mask_refiner.refine(image, hair_regions)

        # 6. Edge enhancement
        edges = self.edge_enhancer.enhance(refined_mask, image)
        refined_mask = self._apply_edge_enhancement(refined_mask, edges)

        # 7. Alpha matting via guided filter (if enabled)
        if self.config.alpha_matting:
            refined_mask = self._apply_alpha_matting(image, refined_mask)

        # 8. Final smoothing (edge-preserving when preserve_details is set)
        final_mask = self._final_smoothing(refined_mask)

        # 9. Confidence: distance from 0.5 plus agreement with the prior mask
        confidence = self._compute_confidence(final_mask, initial_mask)

        return {
            'mask': final_mask,
            'confidence': confidence,
            'strands': strands_mask,
            'edges': edges
        }
112
+
113
+ def _detect_hair_regions(self, image: np.ndarray,
114
+ initial_mask: Optional[np.ndarray]) -> np.ndarray:
115
+ """Detect hair regions using multiple cues."""
116
+ # Color-based detection
117
+ color_mask = self._detect_by_color(image)
118
+
119
+ # Texture-based detection
120
+ texture_mask = self._detect_by_texture(image)
121
+
122
+ # Combine cues
123
+ hair_probability = 0.6 * color_mask + 0.4 * texture_mask
124
+
125
+ # If initial mask provided, constrain to it
126
+ if initial_mask is not None:
127
+ # Dilate initial mask slightly to catch hair edges
128
+ kernel = np.ones((15, 15), np.uint8)
129
+ dilated_initial = cv2.dilate(initial_mask, kernel, iterations=2)
130
+ hair_probability *= dilated_initial
131
+
132
+ # Threshold
133
+ hair_mask = (hair_probability > self.config.min_hair_confidence).astype(np.float32)
134
+
135
+ # Clean up small regions
136
+ hair_mask = self._remove_small_regions(hair_mask)
137
+
138
+ return hair_mask
139
+
140
+ def _detect_by_color(self, image: np.ndarray) -> np.ndarray:
141
+ """Detect hair by color characteristics."""
142
+ # Convert to multiple color spaces
143
+ hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
144
+ lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
145
+ ycrcb = cv2.cvtColor(image, cv2.COLOR_BGR2YCrCb)
146
+
147
+ masks = []
148
+
149
+ # Black hair detection
150
+ black_mask = cv2.inRange(hsv, (0, 0, 0), (180, 255, 30))
151
+ masks.append(black_mask)
152
+
153
+ # Brown hair detection
154
+ brown_mask = cv2.inRange(hsv, (10, 20, 20), (20, 255, 100))
155
+ masks.append(brown_mask)
156
+
157
+ # Blonde hair detection
158
+ blonde_mask = cv2.inRange(hsv, (15, 30, 50), (25, 255, 200))
159
+ masks.append(blonde_mask)
160
+
161
+ # Red/Auburn hair detection
162
+ red_mask = cv2.inRange(hsv, (0, 50, 50), (10, 255, 150))
163
+ auburn_mask = cv2.inRange(hsv, (160, 50, 50), (180, 255, 150))
164
+ masks.append(cv2.bitwise_or(red_mask, auburn_mask))
165
+
166
+ # Gray/White hair detection
167
+ gray_mask = cv2.inRange(hsv, (0, 0, 50), (180, 30, 200))
168
+ masks.append(gray_mask)
169
+
170
+ # Combine all masks
171
+ combined = np.zeros_like(masks[0], dtype=np.float32)
172
+ for mask in masks:
173
+ combined = np.maximum(combined, mask.astype(np.float32) / 255.0)
174
+
175
+ # Smooth the result
176
+ combined = cv2.GaussianBlur(combined, (7, 7), 2.0)
177
+
178
+ return combined
179
+
180
+ def _detect_by_texture(self, image: np.ndarray) -> np.ndarray:
181
+ """Detect hair by texture characteristics."""
182
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image
183
+
184
+ # Multi-scale texture analysis
185
+ texture_responses = []
186
+
187
+ # Gabor filters for different orientations and scales
188
+ for scale in [3, 5, 7]:
189
+ for angle in [0, 30, 60, 90, 120, 150]:
190
+ theta = np.deg2rad(angle)
191
+ kernel = cv2.getGaborKernel(
192
+ (21, 21), scale, theta, 10.0, 0.5, 0, ktype=cv2.CV_32F
193
+ )
194
+ response = cv2.filter2D(gray, cv2.CV_32F, kernel)
195
+ texture_responses.append(np.abs(response))
196
+
197
+ # Combine responses
198
+ texture_map = np.mean(texture_responses, axis=0)
199
+
200
+ # Normalize
201
+ texture_map = (texture_map - np.min(texture_map)) / (np.max(texture_map) - np.min(texture_map) + 1e-6)
202
+
203
+ # Hair tends to have consistent directional texture
204
+ # Compute local coherence
205
+ coherence = self._compute_texture_coherence(texture_responses)
206
+
207
+ # Combine texture magnitude and coherence
208
+ hair_texture = texture_map * coherence
209
+
210
+ return hair_texture
211
+
212
+ def _compute_texture_coherence(self, responses: List[np.ndarray]) -> np.ndarray:
213
+ """Compute texture coherence (consistency of orientation)."""
214
+ if len(responses) < 2:
215
+ return np.ones_like(responses[0])
216
+
217
+ # Compute variance across orientations
218
+ response_stack = np.stack(responses, axis=0)
219
+ variance = np.var(response_stack, axis=0)
220
+ mean = np.mean(response_stack, axis=0) + 1e-6
221
+
222
+ # Low variance relative to mean = high coherence
223
+ coherence = 1.0 - np.minimum(variance / mean, 1.0)
224
+
225
+ return coherence
226
+
227
+ def _detect_hair_strands(self, image: np.ndarray,
228
+ hair_mask: np.ndarray) -> np.ndarray:
229
+ """Detect fine hair strands."""
230
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image
231
+
232
+ # Edge detection with low threshold for fine details
233
+ edges = cv2.Canny(gray, 10, 30)
234
+
235
+ # Line detection using Hough transform
236
+ lines = cv2.HoughLinesP(
237
+ edges, 1, np.pi/180, threshold=20,
238
+ minLineLength=10, maxLineGap=5
239
+ )
240
+
241
+ # Create strand mask
242
+ strand_mask = np.zeros_like(gray, dtype=np.float32)
243
+
244
+ if lines is not None:
245
+ for line in lines:
246
+ x1, y1, x2, y2 = line[0]
247
+
248
+ # Check if line is near hair region
249
+ mid_x, mid_y = (x1 + x2) // 2, (y1 + y2) // 2
250
+
251
+ # Dilated hair mask for proximity check
252
+ kernel = np.ones((15, 15), np.uint8)
253
+ dilated_hair = cv2.dilate(hair_mask, kernel, iterations=1)
254
+
255
+ if dilated_hair[mid_y, mid_x] > 0:
256
+ # Draw line as potential hair strand
257
+ cv2.line(strand_mask, (x1, y1), (x2, y2), 1.0, self.config.strand_thickness)
258
+
259
+ # Ridge detection for curved strands
260
+ ridges = filters.frangi(gray, sigmas=range(1, 4))
261
+ ridges = (ridges - np.min(ridges)) / (np.max(ridges) - np.min(ridges) + 1e-6)
262
+
263
+ # Combine with line detection
264
+ strand_mask = np.maximum(strand_mask, ridges * dilated_hair)
265
+
266
+ # Threshold and clean
267
+ strand_mask = (strand_mask > 0.3).astype(np.float32)
268
+ strand_mask = cv2.morphologyEx(strand_mask, cv2.MORPH_CLOSE, np.ones((3, 3)))
269
+
270
+ return strand_mask
271
+
272
+ def _integrate_strands(self, hair_mask: np.ndarray,
273
+ strands_mask: np.ndarray) -> np.ndarray:
274
+ """Integrate detected strands into main hair mask."""
275
+ if strands_mask is None:
276
+ return hair_mask
277
+
278
+ # Add strands to hair mask
279
+ integrated = np.maximum(hair_mask, strands_mask * 0.8)
280
+
281
+ # Smooth the integration
282
+ integrated = cv2.GaussianBlur(integrated, (5, 5), 1.0)
283
+
284
+ return np.clip(integrated, 0, 1)
285
+
286
+ def _extract_deep_features(self, image: np.ndarray) -> torch.Tensor:
287
+ """Extract deep features using neural network."""
288
+ if not self.deep_model:
289
+ return None
290
+
291
+ # Prepare input
292
+ input_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float() / 255.0
293
+
294
+ # Extract features
295
+ with torch.no_grad():
296
+ features = self.deep_model.extract_features(input_tensor)
297
+
298
+ return features
299
+
300
+ def _enhance_with_deep_features(self, mask: np.ndarray,
301
+ features: torch.Tensor) -> np.ndarray:
302
+ """Enhance mask using deep features."""
303
+ if features is None:
304
+ return mask
305
+
306
+ # Process features to get hair probability
307
+ hair_prob = self.deep_model.process_features(features)
308
+ hair_prob = hair_prob.squeeze().cpu().numpy()
309
+
310
+ # Resize to match mask
311
+ hair_prob = cv2.resize(hair_prob, (mask.shape[1], mask.shape[0]))
312
+
313
+ # Combine with existing mask
314
+ enhanced = 0.7 * mask + 0.3 * hair_prob
315
+
316
+ return np.clip(enhanced, 0, 1)
317
+
318
+ def _apply_alpha_matting(self, image: np.ndarray,
319
+ mask: np.ndarray) -> np.ndarray:
320
+ """Apply alpha matting for refined transparency."""
321
+ # Simple alpha matting using guided filter
322
+ # For production, consider using more advanced methods like Deep Image Matting
323
+
324
+ # Convert image to grayscale for guidance
325
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image
326
+ gray = gray.astype(np.float32) / 255.0
327
+
328
+ # Guided filter for alpha matting
329
+ radius = 20
330
+ epsilon = 0.01
331
+
332
+ alpha = self._guided_filter(mask, gray, radius, epsilon)
333
+
334
+ return np.clip(alpha, 0, 1)
335
+
336
+ def _guided_filter(self, p: np.ndarray, I: np.ndarray,
337
+ radius: int, epsilon: float) -> np.ndarray:
338
+ """Guided filter implementation."""
339
+ mean_I = cv2.boxFilter(I, cv2.CV_32F, (radius, radius))
340
+ mean_p = cv2.boxFilter(p, cv2.CV_32F, (radius, radius))
341
+ mean_Ip = cv2.boxFilter(I * p, cv2.CV_32F, (radius, radius))
342
+ cov_Ip = mean_Ip - mean_I * mean_p
343
+
344
+ mean_II = cv2.boxFilter(I * I, cv2.CV_32F, (radius, radius))
345
+ var_I = mean_II - mean_I * mean_I
346
+
347
+ a = cov_Ip / (var_I + epsilon)
348
+ b = mean_p - a * mean_I
349
+
350
+ mean_a = cv2.boxFilter(a, cv2.CV_32F, (radius, radius))
351
+ mean_b = cv2.boxFilter(b, cv2.CV_32F, (radius, radius))
352
+
353
+ q = mean_a * I + mean_b
354
+
355
+ return q
356
+
357
+ def _apply_edge_enhancement(self, mask: np.ndarray,
358
+ edges: np.ndarray) -> np.ndarray:
359
+ """Apply edge enhancement to mask."""
360
+ # Strengthen mask at detected edges
361
+ edge_weight = 0.3
362
+ enhanced = mask + edge_weight * edges
363
+
364
+ return np.clip(enhanced, 0, 1)
365
+
366
+ def _final_smoothing(self, mask: np.ndarray) -> np.ndarray:
367
+ """Apply final smoothing while preserving details."""
368
+ if self.config.preserve_details:
369
+ # Edge-preserving smoothing
370
+ smoothed = cv2.bilateralFilter(
371
+ (mask * 255).astype(np.uint8), 9, 75, 75
372
+ ) / 255.0
373
+ else:
374
+ # Simple Gaussian smoothing
375
+ smoothed = cv2.GaussianBlur(
376
+ mask, (5, 5), self.config.smoothing_sigma
377
+ )
378
+
379
+ return smoothed
380
+
381
+ def _compute_confidence(self, mask: np.ndarray,
382
+ initial_mask: Optional[np.ndarray]) -> np.ndarray:
383
+ """Compute confidence map for the segmentation."""
384
+ # Base confidence from mask values
385
+ # Values close to 0 or 1 are more confident
386
+ distance_from_middle = np.abs(mask - 0.5) * 2
387
+ confidence = distance_from_middle
388
+
389
+ # If initial mask provided, boost confidence in agreement areas
390
+ if initial_mask is not None:
391
+ agreement = 1 - np.abs(mask - initial_mask)
392
+ confidence = 0.7 * confidence + 0.3 * agreement
393
+
394
+ return np.clip(confidence, 0, 1)
395
+
396
+ def _remove_small_regions(self, mask: np.ndarray,
397
+ min_size: int = 100) -> np.ndarray:
398
+ """Remove small disconnected regions."""
399
+ # Convert to binary
400
+ binary = (mask > 0.5).astype(np.uint8)
401
+
402
+ # Find connected components
403
+ num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(binary)
404
+
405
+ # Remove small components
406
+ cleaned = np.zeros_like(mask)
407
+ for i in range(1, num_labels):
408
+ if stats[i, cv2.CC_STAT_AREA] >= min_size:
409
+ cleaned[labels == i] = mask[labels == i]
410
+
411
+ return cleaned
412
+
413
+
414
class HairMaskRefiner:
    """Refines hair masks for better quality via iterative morphology."""

    def __init__(self, config: HairConfig):
        self.config = config

    def refine(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Refine the hair mask through ``refinement_iterations`` passes.

        Each pass uses a progressively smaller structuring element, so
        early passes fix coarse defects and later passes fine ones.
        """
        refined = mask.copy()
        for iteration in range(self.config.refinement_iterations):
            refined = self._refine_iteration(image, refined, iteration)
        return refined

    def _refine_iteration(self, image: np.ndarray, mask: np.ndarray,
                          iteration: int) -> np.ndarray:
        """Single refinement pass: close gaps, open away noise, blur edges."""
        # BUG FIX: clamp the shrinking kernel at 1 px — with more than 4
        # iterations the original `5 - iteration` went non-positive and
        # cv2.getStructuringElement raised.
        kernel_size = max(1, 5 - iteration)
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)
        )

        # Close small holes, then remove small speckles.
        refined = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        refined = cv2.morphologyEx(refined, cv2.MORPH_OPEN, kernel)

        # Light blur to soften morphology staircase artifacts.
        refined = cv2.GaussianBlur(refined, (3, 3), 0.5)

        return refined
449
+
450
+
451
class AsymmetryDetector:
    """Detects and corrects left/right asymmetry in hair masks."""

    def __init__(self, config: HairConfig):
        # Shared pipeline configuration; detect() reads the asymmetry
        # thresholds (e.g. max_asymmetry_ratio) from it.
        self.config = config
456
+
457
+ def detect(self, mask: np.ndarray, image: np.ndarray) -> Dict[str, Any]:
458
+ """Detect asymmetry in hair mask."""
459
+ h, w = mask.shape[:2]
460
+
461
+ # Find vertical center line
462
+ center_x = self._find_center_line(mask)
463
+
464
+ # Split into left and right
465
+ left_mask = mask[:, :center_x]
466
+ right_mask = mask[:, center_x:]
467
+
468
+ # Make same width for comparison
469
+ min_width = min(left_mask.shape[1], right_mask.shape[1])
470
+ left_mask = left_mask[:, -min_width:] if left_mask.shape[1] > min_width else left_mask
471
+ right_mask = right_mask[:, :min_width] if right_mask.shape[1] > min_width else right_mask
472
+
473
+ # Flip right for comparison
474
+ right_flipped = np.fliplr(right_mask)
475
+
476
+ # Compute asymmetry metrics
477
+ pixel_diff = np.mean(np.abs(left_mask - right_flipped))
478
+
479
+ # Area comparison
480
+ left_area = np.sum(left_mask > 0.5)
481
+ right_area = np.sum(right_mask > 0.5)
482
+ area_ratio = max(left_area, right_area) / (min(left_area, right_area) + 1e-6)
483
+
484
+ # Edge comparison
485
+ left_edges = cv2.Canny((left_mask * 255).astype(np.uint8), 50, 150)
486
+ right_edges = cv2.Canny((right_mask * 255).astype(np.uint8), 50, 150)
487
+ right_edges_flipped = np.fliplr(right_edges)
488
+ edge_diff = np.mean(np.abs(left_edges - right_edges_flipped)) / 255.0
489
+
490
+ # Overall asymmetry score
491
+ asymmetry_score = 0.4 * pixel_diff + 0.3 * (area_ratio - 1.0) / 2.0 + 0.3 * edge_diff
492
+
493
+ is_asymmetric = (asymmetry_score > self.config.symmetry_threshold or
494
+ area_ratio > self.config.max_asymmetry_ratio)
495
+
496
+ return {
497
+ 'is_asymmetric': is_asymmetric,
498
+ 'score': asymmetry_score,
499
+ 'center_x': center_x,
500
+ 'area_ratio': area_ratio,
501
+ 'pixel_diff': pixel_diff,
502
+ 'edge_diff': edge_diff
503
+ }
504
+
505
+ def correct(self, mask: np.ndarray, asymmetry_info: Dict[str, Any]) -> np.ndarray:
506
+ """Correct detected asymmetry."""
507
+ center_x = asymmetry_info['center_x']
508
+ h, w = mask.shape[:2]
509
+
510
+ # Split mask
511
+ left_mask = mask[:, :center_x]
512
+ right_mask = mask[:, center_x:]
513
+
514
+ # Determine which side is more reliable
515
+ left_density = np.mean(left_mask > 0.5)
516
+ right_density = np.mean(right_mask > 0.5)
517
+
518
+ # Use denser side as reference (usually more complete)
519
+ if left_density > right_density:
520
+ # Mirror left to right
521
+ reference = left_mask
522
+ mirrored = np.fliplr(reference)
523
+
524
+ # Blend with original right
525
+ corrected_right = 0.7 * mirrored[:, :right_mask.shape[1]] + 0.3 * right_mask
526
+
527
+ # Reconstruct
528
+ corrected = np.zeros_like(mask)
529
+ corrected[:, :center_x] = left_mask
530
+ corrected[:, center_x:center_x + corrected_right.shape[1]] = corrected_right
531
+ else:
532
+ # Mirror right to left
533
+ reference = right_mask
534
+ mirrored = np.fliplr(reference)
535
+
536
+ # Blend with original left
537
+ corrected_left = 0.7 * mirrored[:, -left_mask.shape[1]:] + 0.3 * left_mask
538
+
539
+ # Reconstruct
540
+ corrected = np.zeros_like(mask)
541
+ corrected[:, :center_x] = corrected_left
542
+ corrected[:, center_x:] = right_mask
543
+
544
+ # Smooth the center seam
545
+ seam_width = 10
546
+ seam_start = max(0, center_x - seam_width)
547
+ seam_end = min(w, center_x + seam_width)
548
+ corrected[:, seam_start:seam_end] = cv2.GaussianBlur(
549
+ corrected[:, seam_start:seam_end], (7, 1), 2.0
550
+ )
551
+
552
+ return corrected
553
+
554
+ def _find_center_line(self, mask: np.ndarray) -> int:
555
+ """Find the vertical center line of the object."""
556
+ # Use center of mass
557
+ mask_binary = (mask > 0.5).astype(np.uint8)
558
+ moments = cv2.moments(mask_binary)
559
+
560
+ if moments['m00'] > 0:
561
+ cx = int(moments['m10'] / moments['m00'])
562
+ else:
563
+ # Fallback to image center
564
+ cx = mask.shape[1] // 2
565
+
566
+ return cx
567
+
568
+
569
class HairEdgeEnhancer:
    """Enhances edges in hair masks."""

    def __init__(self, config: HairConfig):
        # Shared pipeline configuration; enhance() reads edge_sensitivity
        # from it when thresholding hair-specific edges.
        self.config = config
574
+
575
+ def enhance(self, mask: np.ndarray, image: np.ndarray) -> np.ndarray:
576
+ """Enhance hair edges for better quality."""
577
+ # Detect edges in image
578
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image
579
+
580
+ # Multi-scale edge detection
581
+ edges = self._multi_scale_edges(gray)
582
+
583
+ # Detect edges in mask
584
+ mask_edges = cv2.Canny((mask * 255).astype(np.uint8), 30, 100) / 255.0
585
+
586
+ # Find hair-specific edges
587
+ hair_edges = self._detect_hair_edges(gray, mask)
588
+
589
+ # Combine all edge information
590
+ combined_edges = np.maximum(edges * 0.3, np.maximum(mask_edges * 0.3, hair_edges * 0.4))
591
+
592
+ # Apply non-maximum suppression
593
+ combined_edges = self._non_max_suppression(combined_edges)
594
+
595
+ return combined_edges
596
+
597
+ def _multi_scale_edges(self, gray: np.ndarray) -> np.ndarray:
598
+ """Detect edges at multiple scales."""
599
+ edges_list = []
600
+
601
+ for scale in [1, 2, 3]:
602
+ # Resize image
603
+ if scale > 1:
604
+ scaled = cv2.resize(gray, None, fx=1/scale, fy=1/scale)
605
+ else:
606
+ scaled = gray
607
+
608
+ # Detect edges
609
+ edges = cv2.Canny(scaled, 30 * scale, 80 * scale)
610
+
611
+ # Resize back
612
+ if scale > 1:
613
+ edges = cv2.resize(edges, (gray.shape[1], gray.shape[0]))
614
+
615
+ edges_list.append(edges / 255.0)
616
+
617
+ # Combine scales
618
+ combined = np.mean(edges_list, axis=0)
619
+
620
+ return combined
621
+
622
+ def _detect_hair_edges(self, gray: np.ndarray, mask: np.ndarray) -> np.ndarray:
623
+ """Detect edges specific to hair texture."""
624
+ # Use Gabor filters to detect hair-like textures
625
+ hair_edges = np.zeros_like(gray, dtype=np.float32)
626
+
627
+ # Multiple orientations
628
+ for angle in range(0, 180, 30):
629
+ theta = np.deg2rad(angle)
630
+ kernel = cv2.getGaborKernel(
631
+ (11, 11), 3.0, theta, 8.0, 0.5, 0, ktype=cv2.CV_32F
632
+ )
633
+
634
+ filtered = cv2.filter2D(gray, cv2.CV_32F, kernel)
635
+ hair_edges = np.maximum(hair_edges, np.abs(filtered))
636
+
637
+ # Normalize
638
+ hair_edges = hair_edges / (np.max(hair_edges) + 1e-6)
639
+
640
+ # Mask to hair regions
641
+ hair_edges *= mask
642
+
643
+ # Threshold
644
+ hair_edges = (hair_edges > self.config.edge_sensitivity * 0.5).astype(np.float32)
645
+
646
+ return hair_edges
647
+
648
    def _non_max_suppression(self, edges: np.ndarray) -> np.ndarray:
        """Thin the edge map by keeping only local maxima along the gradient.

        Classic Canny-style NMS: the gradient direction at each interior
        pixel is quantized to one of four orientations, and the pixel is
        kept only if its magnitude is >= both neighbors along that
        orientation. Border pixels are left at zero and the output is
        normalized to [0, 1].

        NOTE(review): the pure-Python double loop has per-pixel interpreter
        overhead; fine for small maps, a vectorization candidate for large
        frames.
        """
        # Compute gradients of the edge map itself
        dx = cv2.Sobel(edges, cv2.CV_32F, 1, 0, ksize=3)
        dy = cv2.Sobel(edges, cv2.CV_32F, 0, 1, ksize=3)

        # Gradient magnitude and direction
        magnitude = np.sqrt(dx**2 + dy**2)
        direction = np.arctan2(dy, dx)

        # Quantize directions to 4 main orientations (degrees folded into [0, 180])
        direction = np.rad2deg(direction)
        direction[direction < 0] += 180

        # Non-maximum suppression (borders stay zero)
        suppressed = np.zeros_like(magnitude)

        for i in range(1, magnitude.shape[0] - 1):
            for j in range(1, magnitude.shape[1] - 1):
                angle = direction[i, j]
                mag = magnitude[i, j]

                # Determine neighbors based on gradient direction
                if (0 <= angle < 22.5) or (157.5 <= angle <= 180):
                    # Horizontal
                    neighbors = [magnitude[i, j-1], magnitude[i, j+1]]
                elif 22.5 <= angle < 67.5:
                    # Diagonal /
                    neighbors = [magnitude[i-1, j+1], magnitude[i+1, j-1]]
                elif 67.5 <= angle < 112.5:
                    # Vertical
                    neighbors = [magnitude[i-1, j], magnitude[i+1, j]]
                else:
                    # Diagonal \
                    neighbors = [magnitude[i-1, j-1], magnitude[i+1, j+1]]

                # Keep only if local maximum
                if mag >= max(neighbors):
                    suppressed[i, j] = mag

        # Normalize to [0, 1]
        suppressed = suppressed / (np.max(suppressed) + 1e-6)

        return suppressed
692
+
693
+
694
class HairNet(nn.Module):
    """Lightweight encoder/decoder placeholder for hair feature extraction.

    The encoder downsamples 4x (3 -> 32 -> 64 -> 128 channels); the decoder
    upsamples back to full resolution and emits a single sigmoid channel, so
    ``forward`` maps (B, 3, H, W) to a (B, 1, H, W) probability map. Swap in
    a trained model for real deployments.
    """

    def __init__(self):
        super().__init__()
        # Encoder: two 2x max-pool stages, widening the channel count.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
        )
        # Decoder: two 2x upsample stages down to one sigmoid channel.
        self.decoder = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1), nn.ReLU(), nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 32, 3, padding=1), nn.ReLU(), nn.Upsample(scale_factor=2),
            nn.Conv2d(32, 1, 3, padding=1), nn.Sigmoid(),
        )

    def extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """Encode an image batch into (B, 128, H/4, W/4) features."""
        return self.encoder(x)

    def process_features(self, features: torch.Tensor) -> torch.Tensor:
        """Decode features into a (B, 1, H, W) hair-probability map."""
        return self.decoder(features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Full pass: image batch -> per-pixel hair probability in [0, 1]."""
        return self.process_features(self.extract_features(x))
735
+
736
+
737
+ # Utility functions
738
def visualize_hair_segmentation(image: np.ndarray,
                                results: Dict[str, np.ndarray],
                                save_path: Optional[str] = None) -> np.ndarray:
    """Render a 2x2 diagnostic grid of hair segmentation results.

    Panels: original (top-left), green mask overlay (top-right), JET-colored
    confidence map (bottom-left), and red edges / blue strands (bottom-right).
    The grid is optionally written to ``save_path`` and always returned.
    """
    h, w = image.shape[:2]
    canvas = np.zeros((h * 2, w * 2, 3), dtype=np.uint8)

    # Top-left: the untouched input.
    canvas[:h, :w] = image

    # Top-right: mask blended into the image on the green channel.
    green_layer = np.zeros_like(image)
    green_layer[:, :, 1] = (results['mask'] * 255).astype(np.uint8)
    canvas[:h, w:] = cv2.addWeighted(image, 0.7, green_layer, 0.3, 0)

    # Bottom-left: confidence as a JET heat map (if present).
    if 'confidence' in results:
        canvas[h:, :w] = cv2.applyColorMap(
            (results['confidence'] * 255).astype(np.uint8),
            cv2.COLORMAP_JET
        )

    # Bottom-right: edges in red, strands (when detected) in blue.
    if 'edges' in results and 'strands' in results:
        detail_panel = np.zeros_like(image)
        detail_panel[:, :, 2] = (results['edges'] * 255).astype(np.uint8)
        if results['strands'] is not None:
            detail_panel[:, :, 0] = (results['strands'] * 255).astype(np.uint8)
        canvas[h:, w:] = detail_panel

    # Panel captions.
    captions = (
        ("Original", (10, 30)),
        ("Hair Mask", (w + 10, 30)),
        ("Confidence", (10, h + 30)),
        ("Edges/Strands", (w + 10, h + 30)),
    )
    for text, anchor in captions:
        cv2.putText(canvas, text, anchor, cv2.FONT_HERSHEY_SIMPLEX,
                    1, (255, 255, 255), 2)

    if save_path:
        cv2.imwrite(save_path, canvas)

    return canvas
784
+
785
+
786
# Public API of this module: the pipeline, its configuration and helper
# classes, the placeholder network, and the visualization utility.
__all__ = [
    'HairSegmentationPipeline',
    'HairConfig',
    'HairMaskRefiner',
    'AsymmetryDetector',
    'HairEdgeEnhancer',
    'HairNet',
    'visualize_hair_segmentation'
]