MogensR committed on
Commit
30d0dd7
·
1 Parent(s): b1b89ac

Update processing/matting.py

Browse files
Files changed (1) hide show
  1. processing/matting.py +256 -304
processing/matting.py CHANGED
@@ -1,27 +1,26 @@
 
1
  """
2
  Advanced matting algorithms for BackgroundFX Pro.
3
  Implements multiple matting techniques with automatic fallback.
4
  """
5
 
 
 
 
 
 
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
9
- import numpy as np
10
- import cv2
11
- from typing import Dict, Tuple, Optional, List
12
- from dataclasses import dataclass
13
- import logging
14
 
15
  from utils.logger import get_logger
16
- logger = get_logger(__name__)
17
  from utils.hardware.device_manager import DeviceManager
18
- from utils.config import ConfigManager
19
- from core.models import ModelFactory, ModelType
20
  from core.quality import QualityAnalyzer
21
  from core.edge import EdgeRefinement
22
 
23
-
24
- logger = setup_logger(__name__)
25
 
26
 
27
  @dataclass
@@ -42,411 +41,364 @@ class MattingConfig:
42
 
43
  class AlphaMatting:
44
  """Advanced alpha matting using multiple techniques."""
45
-
46
  def __init__(self, config: Optional[MattingConfig] = None):
47
  self.config = config or MattingConfig()
48
  self.device_manager = DeviceManager()
49
  self.quality_analyzer = QualityAnalyzer()
50
  self.edge_refinement = EdgeRefinement()
51
-
52
- def create_trimap(self, mask: np.ndarray,
53
- dilation_size: int = None) -> np.ndarray:
54
  """
55
- Create trimap from binary mask.
56
-
57
  Args:
58
- mask: Binary mask (H, W)
59
- dilation_size: Size of uncertain region
60
-
61
  Returns:
62
- Trimap with 0 (background), 128 (unknown), 255 (foreground)
63
  """
64
  dilation_size = dilation_size or self.config.trimap_size
65
-
66
- # Ensure binary mask
67
  if mask.dtype != np.uint8:
68
  mask = (mask * 255).astype(np.uint8)
69
-
70
- # Create trimap
71
  trimap = np.copy(mask)
72
- kernel = cv2.getStructuringElement(
73
- cv2.MORPH_ELLIPSE,
74
- (dilation_size, dilation_size)
75
- )
76
-
77
- # Dilate and erode to create unknown region
78
  dilated = cv2.dilate(mask, kernel, iterations=1)
79
  eroded = cv2.erode(mask, kernel, iterations=1)
80
-
81
- # Set unknown region
82
- trimap[dilated == 255] = 128
83
  trimap[eroded == 255] = 255
84
-
 
 
85
  return trimap
86
-
87
- def guided_filter(self, image: np.ndarray,
88
- guide: np.ndarray,
89
- radius: int = None,
90
- eps: float = None) -> np.ndarray:
 
 
 
91
  """
92
  Apply guided filter for edge-preserving smoothing.
93
-
94
  Args:
95
- image: Input image to filter
96
- guide: Guide image (usually RGB image)
97
  radius: Filter radius
98
  eps: Regularization parameter
99
-
100
  Returns:
101
- Filtered image
102
  """
103
  radius = radius or self.config.guided_filter_radius
104
  eps = eps or self.config.guided_filter_eps
105
-
106
- if len(guide.shape) == 3:
107
- guide = cv2.cvtColor(guide, cv2.COLOR_BGR2GRAY)
108
-
109
- # Convert to float32
110
- guide = guide.astype(np.float32) / 255.0
111
- image = image.astype(np.float32) / 255.0
112
-
 
 
113
  # Box filter helper
114
  def box_filter(img, r):
115
  return cv2.boxFilter(img, -1, (r, r))
116
-
117
- # Guided filter implementation
118
- mean_I = box_filter(guide, radius)
119
- mean_p = box_filter(image, radius)
120
- mean_Ip = box_filter(guide * image, radius)
121
  cov_Ip = mean_Ip - mean_I * mean_p
122
-
123
- mean_II = box_filter(guide * guide, radius)
124
  var_I = mean_II - mean_I * mean_I
125
-
126
  a = cov_Ip / (var_I + eps)
127
  b = mean_p - a * mean_I
128
-
129
  mean_a = box_filter(a, radius)
130
  mean_b = box_filter(b, radius)
131
-
132
- output = mean_a * guide + mean_b
133
- return np.clip(output * 255, 0, 255).astype(np.uint8)
134
-
135
- def closed_form_matting(self, image: np.ndarray,
136
- trimap: np.ndarray) -> np.ndarray:
137
  """
138
- Closed-form matting using Laplacian matrix.
139
- Simplified version for real-time processing.
140
-
141
  Args:
142
- image: RGB image
143
- trimap: Trimap with known regions
144
-
145
  Returns:
146
- Alpha matte
147
  """
148
- h, w = trimap.shape
149
-
150
- # Initialize alpha with trimap
151
- alpha = np.copy(trimap).astype(np.float32) / 255.0
152
-
153
- # Known regions
154
  is_fg = trimap == 255
155
  is_bg = trimap == 0
156
  is_unknown = trimap == 128
157
-
158
  if not np.any(is_unknown):
159
- return alpha
160
-
161
- # Simple propagation from known to unknown regions
162
- # Using distance transform for efficiency
163
- dist_fg = cv2.distanceTransform(
164
- is_fg.astype(np.uint8),
165
- cv2.DIST_L2, 5
166
- )
167
- dist_bg = cv2.distanceTransform(
168
- is_bg.astype(np.uint8),
169
- cv2.DIST_L2, 5
170
- )
171
-
172
- # Normalize distances
173
- total_dist = dist_fg + dist_bg + 1e-10
174
- alpha_unknown = dist_fg / total_dist
175
-
176
- # Apply only to unknown regions
177
  alpha[is_unknown] = alpha_unknown[is_unknown]
178
-
179
- # Apply guided filter for smoothing
180
  if self.config.use_guided_filter:
181
- alpha = self.guided_filter(
182
- (alpha * 255).astype(np.uint8),
183
- image
184
- ) / 255.0
185
-
186
- return np.clip(alpha, 0, 1)
187
-
188
- def deep_matting(self, image: np.ndarray,
189
- mask: np.ndarray,
190
- model: Optional[nn.Module] = None) -> np.ndarray:
 
 
191
  """
192
  Apply deep learning-based matting refinement.
193
-
194
  Args:
195
- image: RGB image
196
- mask: Initial mask
197
- model: Optional pre-trained model
198
-
199
  Returns:
200
- Refined alpha matte
201
  """
202
  device = self.device_manager.get_device()
203
-
204
- # Prepare input
205
  h, w = image.shape[:2]
206
-
207
- # Resize for model input
208
  input_size = (512, 512)
209
- image_resized = cv2.resize(image, input_size)
210
- mask_resized = cv2.resize(mask, input_size)
211
-
212
- # Convert to tensor
213
- image_tensor = torch.from_numpy(
214
- image_resized.transpose(2, 0, 1)
215
- ).float().unsqueeze(0) / 255.0
216
-
217
- mask_tensor = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0) / 255.0
218
-
219
- # Move to device
220
- image_tensor = image_tensor.to(device)
221
- mask_tensor = mask_tensor.to(device)
222
-
223
- # If no model provided, use simple refinement
224
- if model is None:
225
- # Simple CNN-based refinement
226
- with torch.no_grad():
227
- # Concatenate image and mask
228
- x = torch.cat([image_tensor, mask_tensor], dim=1)
229
-
230
- # Simple refinement network simulation
231
  refined = self._simple_refine_network(x)
232
-
233
- # Convert back to numpy
234
- alpha = refined.squeeze().cpu().numpy()
235
- else:
236
- with torch.no_grad():
237
- alpha = model(image_tensor, mask_tensor)
238
- alpha = alpha.squeeze().cpu().numpy()
239
-
240
- # Resize back to original size
241
  alpha = cv2.resize(alpha, (w, h))
242
-
243
- return np.clip(alpha, 0, 1)
244
-
245
  def _simple_refine_network(self, x: torch.Tensor) -> torch.Tensor:
246
- """Simple refinement network for demonstration."""
247
- # Extract mask channel
248
  mask = x[:, 3:4, :, :]
249
-
250
- # Apply series of filters
251
  refined = mask
252
-
253
- # Edge-aware smoothing
254
  for _ in range(3):
255
  refined = F.avg_pool2d(refined, 3, stride=1, padding=1)
256
- refined = torch.sigmoid((refined - 0.5) * 10)
257
-
258
  return refined
259
-
260
  def morphological_refinement(self, alpha: np.ndarray) -> np.ndarray:
261
  """
262
- Apply morphological operations for refinement.
263
-
264
  Args:
265
- alpha: Alpha matte
266
-
267
  Returns:
268
- Refined alpha matte
269
  """
270
- # Convert to uint8
271
- alpha_uint8 = (alpha * 255).astype(np.uint8)
272
-
273
- # Morphological operations
274
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
275
-
276
- # Remove small holes
277
- alpha_uint8 = cv2.morphologyEx(
278
- alpha_uint8, cv2.MORPH_CLOSE, kernel,
279
- iterations=self.config.erode_iterations
280
  )
281
-
282
- # Remove small components
283
- alpha_uint8 = cv2.morphologyEx(
284
- alpha_uint8, cv2.MORPH_OPEN, kernel,
285
- iterations=self.config.dilate_iterations
286
  )
287
-
288
- # Smooth boundaries
289
  if self.config.blur_radius > 0:
290
- alpha_uint8 = cv2.GaussianBlur(
291
- alpha_uint8,
292
- (self.config.blur_radius * 2 + 1, self.config.blur_radius * 2 + 1),
293
- 0
294
- )
295
-
296
- return alpha_uint8.astype(np.float32) / 255.0
297
-
298
- def process(self, image: np.ndarray,
299
- mask: np.ndarray,
300
- method: str = 'auto') -> Dict[str, np.ndarray]:
301
  """
302
  Process image with selected matting method.
303
-
304
  Args:
305
- image: RGB image
306
- mask: Initial segmentation mask
307
- method: Matting method ('auto', 'trimap', 'deep', 'guided')
308
-
309
  Returns:
310
- Dictionary with alpha matte and confidence
311
  """
312
  try:
313
- # Analyze quality
314
  quality_metrics = self.quality_analyzer.analyze_frame(image)
315
-
316
- # Select method based on quality
317
- if method == 'auto':
318
- if quality_metrics['blur_score'] > 50:
319
- method = 'guided'
320
- elif quality_metrics['edge_clarity'] > 0.7:
321
- method = 'trimap'
 
 
 
322
  else:
323
- method = 'deep'
324
-
325
- logger.info(f"Using matting method: {method}")
326
-
327
- # Apply selected method
328
- if method == 'trimap':
329
  trimap = self.create_trimap(mask)
330
  alpha = self.closed_form_matting(image, trimap)
331
-
332
- elif method == 'deep':
333
  alpha = self.deep_matting(image, mask)
334
-
335
- elif method == 'guided':
336
- alpha = mask.astype(np.float32) / 255.0
 
337
  if self.config.use_guided_filter:
338
- alpha = self.guided_filter(
339
- (alpha * 255).astype(np.uint8),
340
- image
341
- ) / 255.0
342
  else:
343
- # Default to simple refinement
344
- alpha = mask.astype(np.float32) / 255.0
345
-
346
- # Apply morphological refinement
 
347
  alpha = self.morphological_refinement(alpha)
348
-
349
- # Edge refinement
350
  alpha = self.edge_refinement.refine_edges(
351
- image,
352
- (alpha * 255).astype(np.uint8)
353
- ) / 255.0
354
-
355
- # Calculate confidence
356
  confidence = self._calculate_confidence(alpha, quality_metrics)
357
-
358
  return {
359
- 'alpha': alpha,
360
- 'confidence': confidence,
361
- 'method_used': method,
362
- 'quality_metrics': quality_metrics
363
  }
364
-
365
  except Exception as e:
366
  logger.error(f"Matting processing failed: {e}")
367
- # Return original mask as fallback
 
 
368
  return {
369
- 'alpha': mask.astype(np.float32) / 255.0,
370
- 'confidence': 0.0,
371
- 'method_used': 'fallback',
372
- 'error': str(e)
373
  }
374
-
375
- def _calculate_confidence(self, alpha: np.ndarray,
376
- quality_metrics: Dict) -> float:
377
  """Calculate confidence score for the matting result."""
378
- # Base confidence from quality metrics
379
- confidence = quality_metrics.get('overall_quality', 0.5)
380
-
381
- # Adjust based on alpha distribution
382
- alpha_mean = np.mean(alpha)
383
- alpha_std = np.std(alpha)
384
-
385
- # Good matting should have clear separation
386
  if 0.3 < alpha_mean < 0.7 and alpha_std > 0.3:
387
  confidence *= 1.2
388
-
389
- # Check for edge clarity
390
- edges = cv2.Canny((alpha * 255).astype(np.uint8), 50, 150)
391
- edge_ratio = np.sum(edges > 0) / edges.size
392
-
393
- if edge_ratio < 0.1: # Clear boundaries
394
  confidence *= 1.1
395
-
396
- return np.clip(confidence, 0.0, 1.0)
397
 
398
 
399
  class CompositingEngine:
400
  """Handle alpha compositing and blending."""
401
-
402
  def __init__(self):
403
- self.logger = setup_logger(f"{__name__}.CompositingEngine")
404
-
405
- def composite(self, foreground: np.ndarray,
406
- background: np.ndarray,
407
- alpha: np.ndarray) -> np.ndarray:
408
  """
409
  Composite foreground over background using alpha.
410
-
411
  Args:
412
- foreground: Foreground image (H, W, 3)
413
- background: Background image (H, W, 3)
414
- alpha: Alpha matte (H, W) or (H, W, 1)
415
-
416
  Returns:
417
- Composited image
418
- self.logger = get_logger(f"{__name__}.CompositingEngine")
419
  # Ensure alpha is 3-channel
420
- if len(alpha.shape) == 2:
421
  alpha = np.expand_dims(alpha, axis=2)
422
  if alpha.shape[2] == 1:
423
  alpha = np.repeat(alpha, 3, axis=2)
424
-
425
- # Ensure float32
 
 
 
 
426
  fg = foreground.astype(np.float32) / 255.0
427
  bg = background.astype(np.float32) / 255.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
  a = alpha.astype(np.float32)
429
-
430
  if a.max() > 1.0:
431
  a = a / 255.0
432
-
433
- # Alpha blending
434
- result = fg * a + bg * (1 - a)
435
-
436
- # Convert back to uint8
437
- result = np.clip(result * 255, 0, 255).astype(np.uint8)
438
-
439
- return result
440
-
441
- def premultiply_alpha(self, image: np.ndarray,
442
- alpha: np.ndarray) -> np.ndarray:
443
- """Premultiply image by alpha channel."""
444
- if len(alpha.shape) == 2:
445
- alpha = np.expand_dims(alpha, axis=2)
446
-
447
- result = image.astype(np.float32) * alpha.astype(np.float32)
448
-
449
- if alpha.max() > 1.0:
450
- result = result / 255.0
451
-
452
- return np.clip(result, 0, 255).astype(np.uint8)
 
1
+ #!/usr/bin/env python3
2
  """
3
  Advanced matting algorithms for BackgroundFX Pro.
4
  Implements multiple matting techniques with automatic fallback.
5
  """
6
 
7
+ from dataclasses import dataclass
8
+ from typing import Dict, Optional
9
+
10
+ import cv2
11
+ import numpy as np
12
  import torch
13
  import torch.nn as nn
14
  import torch.nn.functional as F
 
 
 
 
 
15
 
16
  from utils.logger import get_logger
 
17
  from utils.hardware.device_manager import DeviceManager
18
+ from utils.config import ConfigManager # kept for forward compatibility / config hook
19
+ from core.models import ModelFactory, ModelType # not used directly here but kept for API consistency
20
  from core.quality import QualityAnalyzer
21
  from core.edge import EdgeRefinement
22
 
23
+ logger = get_logger(__name__)
 
24
 
25
 
26
  @dataclass
 
41
 
42
  class AlphaMatting:
43
  """Advanced alpha matting using multiple techniques."""
44
+
45
    def __init__(self, config: Optional[MattingConfig] = None):
        """Initialize the matting pipeline and its helper components.

        Args:
            config: Optional MattingConfig. When None (or any falsy value,
                due to the `or`), a default MattingConfig is constructed.
        """
        self.config = config or MattingConfig()
        self.device_manager = DeviceManager()
        self.quality_analyzer = QualityAnalyzer()
        self.edge_refinement = EdgeRefinement()
50
+
51
+ def create_trimap(self, mask: np.ndarray, dilation_size: Optional[int] = None) -> np.ndarray:
 
52
  """
53
+ Create trimap from a binary mask.
54
+
55
  Args:
56
+ mask: Binary mask (H, W) in {0, 255} or [0,1]
57
+ dilation_size: Size of uncertain region (pixels)
58
+
59
  Returns:
60
+ Trimap with values 0 (background), 128 (unknown), 255 (foreground)
61
  """
62
  dilation_size = dilation_size or self.config.trimap_size
63
+
64
+ # Ensure uint8 binary
65
  if mask.dtype != np.uint8:
66
  mask = (mask * 255).astype(np.uint8)
67
+ mask = (mask > 127).astype(np.uint8) * 255
68
+
69
  trimap = np.copy(mask)
70
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilation_size, dilation_size))
71
+
72
+ # Dilate/erode once to form unknown band
 
 
 
73
  dilated = cv2.dilate(mask, kernel, iterations=1)
74
  eroded = cv2.erode(mask, kernel, iterations=1)
75
+
76
+ # Unknown where dilation has expanded FG beyond eroded FG band
77
+ trimap[:] = 0
78
  trimap[eroded == 255] = 255
79
+ unknown = (dilated == 255) & (eroded == 0)
80
+ trimap[unknown] = 128
81
+
82
  return trimap
83
+
84
+ def guided_filter(
85
+ self,
86
+ image: np.ndarray,
87
+ guide: np.ndarray,
88
+ radius: Optional[int] = None,
89
+ eps: Optional[float] = None,
90
+ ) -> np.ndarray:
91
  """
92
  Apply guided filter for edge-preserving smoothing.
93
+
94
  Args:
95
+ image: Input image to filter (H, W) uint8
96
+ guide: Guide image (H, W, 3) or (H, W)
97
  radius: Filter radius
98
  eps: Regularization parameter
99
+
100
  Returns:
101
+ Filtered image (H, W) uint8
102
  """
103
  radius = radius or self.config.guided_filter_radius
104
  eps = eps or self.config.guided_filter_eps
105
+
106
+ if guide.ndim == 3:
107
+ guide_gray = cv2.cvtColor(guide, cv2.COLOR_BGR2GRAY)
108
+ else:
109
+ guide_gray = guide
110
+
111
+ # Convert to float32 in [0,1]
112
+ I = guide_gray.astype(np.float32) / 255.0
113
+ p = image.astype(np.float32) / 255.0
114
+
115
  # Box filter helper
116
  def box_filter(img, r):
117
  return cv2.boxFilter(img, -1, (r, r))
118
+
119
+ mean_I = box_filter(I, radius)
120
+ mean_p = box_filter(p, radius)
121
+ mean_Ip = box_filter(I * p, radius)
 
122
  cov_Ip = mean_Ip - mean_I * mean_p
123
+
124
+ mean_II = box_filter(I * I, radius)
125
  var_I = mean_II - mean_I * mean_I
126
+
127
  a = cov_Ip / (var_I + eps)
128
  b = mean_p - a * mean_I
129
+
130
  mean_a = box_filter(a, radius)
131
  mean_b = box_filter(b, radius)
132
+
133
+ q = mean_a * I + mean_b
134
+ return np.clip(q * 255.0, 0, 255).astype(np.uint8)
135
+
136
+ def closed_form_matting(self, image: np.ndarray, trimap: np.ndarray) -> np.ndarray:
 
137
  """
138
+ Closed-form-inspired fast matting using distance transforms + optional guided filtering.
139
+
 
140
  Args:
141
+ image: RGB image (H, W, 3) uint8
142
+ trimap: Trimap with values {0, 128, 255}
143
+
144
  Returns:
145
+ Alpha matte in [0,1] float32
146
  """
147
+ h, w = trimap.shape[:2]
148
+ alpha = (trimap.astype(np.float32) / 255.0)
149
+
 
 
 
150
  is_fg = trimap == 255
151
  is_bg = trimap == 0
152
  is_unknown = trimap == 128
153
+
154
  if not np.any(is_unknown):
155
+ return np.clip(alpha, 0.0, 1.0)
156
+
157
+ dist_fg = cv2.distanceTransform(is_fg.astype(np.uint8), cv2.DIST_L2, 5)
158
+ dist_bg = cv2.distanceTransform(is_bg.astype(np.uint8), cv2.DIST_L2, 5)
159
+
160
+ total = dist_fg + dist_bg + 1e-10
161
+ alpha_unknown = dist_fg / total
 
 
 
 
 
 
 
 
 
 
 
162
  alpha[is_unknown] = alpha_unknown[is_unknown]
163
+
 
164
  if self.config.use_guided_filter:
165
+ alpha_u8 = np.clip(alpha * 255.0, 0, 255).astype(np.uint8)
166
+ alpha_u8 = self.guided_filter(alpha_u8, image)
167
+ alpha = alpha_u8.astype(np.float32) / 255.0
168
+
169
+ return np.clip(alpha, 0.0, 1.0)
170
+
171
+ def deep_matting(
172
+ self,
173
+ image: np.ndarray,
174
+ mask: np.ndarray,
175
+ model: Optional[nn.Module] = None,
176
+ ) -> np.ndarray:
177
  """
178
  Apply deep learning-based matting refinement.
179
+
180
  Args:
181
+ image: RGB image (H, W, 3) uint8
182
+ mask: Initial mask (H, W) {0..255} or [0,1]
183
+ model: Optional pre-trained model taking (img, mask) → alpha
184
+
185
  Returns:
186
+ Refined alpha matte in [0,1] float32
187
  """
188
  device = self.device_manager.get_device()
189
+
 
190
  h, w = image.shape[:2]
 
 
191
  input_size = (512, 512)
192
+
193
+ img_rs = cv2.resize(image, input_size)
194
+ msk_rs = cv2.resize(mask, input_size)
195
+
196
+ img_t = torch.from_numpy(img_rs.transpose(2, 0, 1)).float().unsqueeze(0) / 255.0
197
+ msk_t = torch.from_numpy(msk_rs).float().unsqueeze(0).unsqueeze(0)
198
+ if msk_t.max() > 1.0:
199
+ msk_t = msk_t / 255.0
200
+
201
+ img_t = img_t.to(device)
202
+ msk_t = msk_t.to(device)
203
+
204
+ with torch.no_grad():
205
+ if model is None:
206
+ x = torch.cat([img_t, msk_t], dim=1)
 
 
 
 
 
 
 
207
  refined = self._simple_refine_network(x)
208
+ else:
209
+ refined = model(img_t, msk_t)
210
+ alpha = refined.squeeze().float().cpu().numpy()
211
+
 
 
 
 
 
212
  alpha = cv2.resize(alpha, (w, h))
213
+ return np.clip(alpha, 0.0, 1.0)
214
+
 
215
  def _simple_refine_network(self, x: torch.Tensor) -> torch.Tensor:
216
+ """Tiny non-learned refinement block (demo-quality)."""
217
+ # x: [B, 4, H, W] (RGB + mask)
218
  mask = x[:, 3:4, :, :]
219
+
 
220
  refined = mask
 
 
221
  for _ in range(3):
222
  refined = F.avg_pool2d(refined, 3, stride=1, padding=1)
223
+ refined = torch.sigmoid((refined - 0.5) * 10.0)
224
+
225
  return refined
226
+
227
  def morphological_refinement(self, alpha: np.ndarray) -> np.ndarray:
228
  """
229
+ Apply morphological operations and boundary smoothing.
230
+
231
  Args:
232
+ alpha: Alpha matte in [0,1] float32
233
+
234
  Returns:
235
+ Refined alpha in [0,1] float32
236
  """
237
+ alpha_u8 = np.clip(alpha * 255.0, 0, 255).astype(np.uint8)
 
 
 
238
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
239
+
240
+ # Close small holes in FG
241
+ alpha_u8 = cv2.morphologyEx(
242
+ alpha_u8, cv2.MORPH_CLOSE, kernel, iterations=self.config.erode_iterations
 
243
  )
244
+ # Remove small specks
245
+ alpha_u8 = cv2.morphologyEx(
246
+ alpha_u8, cv2.MORPH_OPEN, kernel, iterations=self.config.dilate_iterations
 
 
247
  )
248
+
 
249
  if self.config.blur_radius > 0:
250
+ r = self.config.blur_radius * 2 + 1
251
+ alpha_u8 = cv2.GaussianBlur(alpha_u8, (r, r), 0)
252
+
253
+ return alpha_u8.astype(np.float32) / 255.0
254
+
255
+ def process(self, image: np.ndarray, mask: np.ndarray, method: str = "auto") -> Dict[str, np.ndarray]:
 
 
 
 
 
256
  """
257
  Process image with selected matting method.
258
+
259
  Args:
260
+ image: RGB image (H, W, 3) uint8
261
+ mask: Initial segmentation mask (H, W)
262
+ method: 'auto' | 'trimap' | 'deep' | 'guided'
263
+
264
  Returns:
265
+ dict(alpha, confidence, method_used, quality_metrics[, error])
266
  """
267
  try:
 
268
  quality_metrics = self.quality_analyzer.analyze_frame(image)
269
+
270
+ chosen = method
271
+ if method == "auto":
272
+ # Heuristic selection
273
+ blur_score = quality_metrics.get("blur_score", 0.0)
274
+ edge_clarity = quality_metrics.get("edge_clarity", 0.0)
275
+ if blur_score > 50:
276
+ chosen = "guided"
277
+ elif edge_clarity > 0.7:
278
+ chosen = "trimap"
279
  else:
280
+ chosen = "deep"
281
+
282
+ logger.info(f"Using matting method: {chosen}")
283
+
284
+ if chosen == "trimap":
 
285
  trimap = self.create_trimap(mask)
286
  alpha = self.closed_form_matting(image, trimap)
287
+ elif chosen == "deep":
 
288
  alpha = self.deep_matting(image, mask)
289
+ elif chosen == "guided":
290
+ alpha = mask.astype(np.float32)
291
+ if alpha.max() > 1.0:
292
+ alpha = alpha / 255.0
293
  if self.config.use_guided_filter:
294
+ alpha_u8 = np.clip(alpha * 255.0, 0, 255).astype(np.uint8)
295
+ alpha = self.guided_filter(alpha_u8, image).astype(np.float32) / 255.0
 
 
296
  else:
297
+ alpha = mask.astype(np.float32)
298
+ if alpha.max() > 1.0:
299
+ alpha = alpha / 255.0
300
+
301
+ # Morphological + edge refinement
302
  alpha = self.morphological_refinement(alpha)
 
 
303
  alpha = self.edge_refinement.refine_edges(
304
+ image, np.clip(alpha * 255.0, 0, 255).astype(np.uint8)
305
+ ).astype(np.float32) / 255.0
306
+
 
 
307
  confidence = self._calculate_confidence(alpha, quality_metrics)
308
+
309
  return {
310
+ "alpha": np.clip(alpha, 0.0, 1.0),
311
+ "confidence": float(np.clip(confidence, 0.0, 1.0)),
312
+ "method_used": chosen,
313
+ "quality_metrics": quality_metrics,
314
  }
315
+
316
  except Exception as e:
317
  logger.error(f"Matting processing failed: {e}")
318
+ fallback = mask.astype(np.float32)
319
+ if fallback.max() > 1.0:
320
+ fallback = fallback / 255.0
321
  return {
322
+ "alpha": np.clip(fallback, 0.0, 1.0),
323
+ "confidence": 0.0,
324
+ "method_used": "fallback",
325
+ "error": str(e),
326
  }
327
+
328
+ def _calculate_confidence(self, alpha: np.ndarray, quality_metrics: Dict) -> float:
 
329
  """Calculate confidence score for the matting result."""
330
+ confidence = float(quality_metrics.get("overall_quality", 0.5))
331
+
332
+ alpha_mean = float(np.mean(alpha))
333
+ alpha_std = float(np.std(alpha))
334
+
335
+ # Prefer clear separation
 
 
336
  if 0.3 < alpha_mean < 0.7 and alpha_std > 0.3:
337
  confidence *= 1.2
338
+
339
+ edges = cv2.Canny(np.clip(alpha * 255.0, 0, 255).astype(np.uint8), 50, 150)
340
+ edge_ratio = float(np.sum(edges > 0) / edges.size)
341
+ if edge_ratio < 0.1:
 
 
342
  confidence *= 1.1
343
+
344
+ return float(np.clip(confidence, 0.0, 1.0))
345
 
346
 
347
class CompositingEngine:
    """Alpha compositing and blending utilities."""

    def __init__(self):
        self.logger = get_logger(f"{__name__}.CompositingEngine")

    @staticmethod
    def _as_unit_alpha(alpha: np.ndarray) -> np.ndarray:
        """Return alpha as float32 (H, W, 3) in [0, 1].

        Accepts (H, W) or (H, W, 1) input in either [0, 1] or [0, 255] range.
        """
        if alpha.ndim == 2:
            alpha = np.expand_dims(alpha, axis=2)
        if alpha.shape[2] == 1:
            alpha = np.repeat(alpha, 3, axis=2)
        unit = alpha.astype(np.float32)
        if unit.max() > 1.0:
            unit = unit / 255.0
        return unit

    def composite(self, foreground: np.ndarray, background: np.ndarray, alpha: np.ndarray) -> np.ndarray:
        """
        Composite foreground over background using alpha.

        Args:
            foreground: Foreground image (H, W, 3) uint8
            background: Background image (H, W, 3) uint8
            alpha: Alpha matte (H, W) or (H, W, 1) in [0..255] or [0..1]

        Returns:
            Composited image (H, W, 3) uint8
        """
        a = self._as_unit_alpha(alpha)

        fg = foreground.astype(np.float32) / 255.0
        bg = background.astype(np.float32) / 255.0

        blended = fg * a + bg * (1.0 - a)
        return np.clip(blended * 255.0, 0, 255).astype(np.uint8)

    def premultiply_alpha(self, image: np.ndarray, alpha: np.ndarray) -> np.ndarray:
        """
        Premultiply RGB image by alpha channel.

        Args:
            image: (H, W, 3) uint8
            alpha: (H, W) or (H, W, 1) in [0..255] or [0..1]

        Returns:
            Premultiplied (H, W, 3) uint8
        """
        a = self._as_unit_alpha(alpha)
        return np.clip(image.astype(np.float32) * a, 0.0, 255.0).astype(np.uint8)