| """
|
| Mask refinement and region extraction
|
| Implements Critical Fix #3: Adaptive Mask Refinement Thresholds
|
| """
|
|
|
| import cv2
|
| import numpy as np
|
| from typing import List, Tuple, Dict, Optional
|
| from scipy import ndimage
|
| from skimage.measure import label, regionprops
|
|
|
|
|
class MaskRefiner:
    """
    Mask refinement with adaptive thresholds.

    Implements Critical Fix #3: dataset-specific minimum region areas.

    ``refine`` applies, in order: thresholding of the probability map,
    morphological closing, morphological opening, removal of connected
    components smaller than ``min_region_area`` (a fraction of the image
    area), and an optional nearest-neighbour resize to the original size.
    """

    def __init__(self, config, dataset_name: str = 'default'):
        """
        Initialize mask refiner.

        Args:
            config: Configuration object. Must provide ``get(key, default)``
                and ``get_min_region_area(dataset_name)``.
            dataset_name: Dataset name for adaptive thresholds.
        """
        self.config = config
        self.dataset_name = dataset_name

        # Binarization threshold applied to the probability map.
        self.threshold = config.get('mask_refinement.threshold', 0.5)
        # Side lengths (pixels) of the square structuring elements.
        self.closing_kernel = config.get('mask_refinement.morphology.closing_kernel', 5)
        self.opening_kernel = config.get('mask_refinement.morphology.opening_kernel', 3)

        # Minimum component size as a fraction of the image area
        # (printed below as a percentage).
        self.min_region_area = config.get_min_region_area(dataset_name)

        print(f"MaskRefiner initialized for {dataset_name}")
        print(f"Min region area: {self.min_region_area * 100:.2f}%")

    def refine(self,
               probability_map: np.ndarray,
               original_size: Optional[Tuple[int, int]] = None) -> np.ndarray:
        """
        Refine a probability map into a binary mask.

        Args:
            probability_map: Forgery probability map (H, W), values [0, 1].
            original_size: Optional (H, W) to resize the mask back to.

        Returns:
            Refined uint8 mask with values {0, 1}. Its shape is (H, W), or
            ``original_size`` when that argument is given.
        """
        # Strict ">": pixels exactly at the threshold count as background.
        binary_mask = (probability_map > self.threshold).astype(np.uint8)

        # Closing first (fill small gaps/holes), then opening (drop specks).
        binary_mask = self._apply_morphology(binary_mask, cv2.MORPH_CLOSE, self.closing_kernel)
        binary_mask = self._apply_morphology(binary_mask, cv2.MORPH_OPEN, self.opening_kernel)

        binary_mask = self._remove_small_regions(binary_mask)

        if original_size is not None:
            # cv2.resize expects (width, height) as ints; nearest-neighbour
            # interpolation keeps the mask strictly binary.
            binary_mask = cv2.resize(
                binary_mask,
                (int(original_size[1]), int(original_size[0])),
                interpolation=cv2.INTER_NEAREST,
            )

        return binary_mask

    @staticmethod
    def _apply_morphology(mask: np.ndarray, operation: int, kernel_size) -> np.ndarray:
        """
        Apply one morphological operation with a square rectangular kernel.

        Args:
            mask: Binary mask (H, W).
            operation: OpenCV morphology flag (e.g. ``cv2.MORPH_CLOSE``).
            kernel_size: Kernel side length. It is cast to int, so config
                values such as ``5.0`` are accepted.

        Returns:
            The processed mask. For sizes <= 1 the input is returned
            unchanged.
        """
        size = int(kernel_size)
        if size <= 1:
            # A 1x1 kernel is the identity. A size of 0 or less would make
            # cv2.getStructuringElement fail, so treat it as "disabled".
            return mask
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (size, size))
        return cv2.morphologyEx(mask, operation, kernel)

    def _remove_small_regions(self, mask: np.ndarray) -> np.ndarray:
        """
        Remove connected components smaller than the minimum area threshold.

        Components are found with 8-connectivity. This matches the
        ``label(binary_mask, connectivity=2)`` call in
        ``RegionExtractor.extract``, so a blob joined only diagonally is
        judged by its full size rather than by its 4-connected pieces.

        Args:
            mask: Binary mask (H, W); any non-zero pixel is foreground.

        Returns:
            Mask with the same shape and dtype as ``mask``: 1 on kept
            components, 0 everywhere else.
        """
        image_area = mask.shape[0] * mask.shape[1]
        # min_region_area is a fraction of the image, not a pixel count.
        min_pixels = int(image_area * self.min_region_area)

        # A 3x3 structure of ones gives 8-connectivity (ndimage's default is 4).
        labeled_mask, num_features = ndimage.label(mask, structure=np.ones((3, 3), dtype=int))
        if num_features == 0:
            return np.zeros_like(mask)

        # One pass computes all component sizes, instead of a full-image
        # comparison per component.
        component_sizes = np.bincount(labeled_mask.ravel(), minlength=num_features + 1)
        keep = component_sizes >= min_pixels
        keep[0] = False  # label 0 is the background
        return keep[labeled_mask].astype(mask.dtype)
|
|
|
|
|
class RegionExtractor:
    """
    Extract individual regions from a binary mask.

    Implements Critical Fix #4: Region Confidence Aggregation. Each region's
    confidence is the mean of the probability map over that region's pixels.
    """

    def __init__(self, config, dataset_name: str = 'default'):
        """
        Initialize region extractor.

        Args:
            config: Configuration object providing
                ``get_min_region_area(dataset_name)``.
            dataset_name: Dataset name.
        """
        self.config = config
        self.dataset_name = dataset_name
        # Stored for reference; neither extraction method filters by area.
        self.min_region_area = config.get_min_region_area(dataset_name)

    def extract(self,
                binary_mask: np.ndarray,
                probability_map: np.ndarray,
                original_image: np.ndarray) -> List[Dict]:
        """
        Extract regions from a binary mask.

        Args:
            binary_mask: Refined binary mask (H, W).
            probability_map: Original probability map (H, W).
            original_image: Original image (H, W, 3).

        Returns:
            One dict per 8-connected component with keys ``region_id``,
            ``bounding_box`` ([x, y, w, h]), ``area``, ``centroid`` ((x, y)),
            ``region_mask`` (full-size uint8), ``region_mask_cropped``,
            ``region_image`` (a copy of the bbox crop), ``confidence`` and
            ``mask_probability_mean`` (both the mean probability inside
            the region).
        """
        regions = []

        # connectivity=2 means 8-connected components for a 2-D mask.
        labeled_mask = label(binary_mask, connectivity=2)

        for prop in regionprops(labeled_mask):
            # Use the component's own label rather than its position in the
            # list, so ids stay correct even if labels are not consecutive.
            region_id = int(prop.label)

            # bbox is (min_row, min_col, max_row, max_col); max is exclusive.
            y_min, x_min, y_max, x_max = prop.bbox

            region_mask = (labeled_mask == region_id).astype(np.uint8)

            region_image = original_image[y_min:y_max, x_min:x_max].copy()
            region_mask_cropped = region_mask[y_min:y_max, x_min:x_max]

            region_probs = probability_map[region_mask > 0]
            region_confidence = float(np.mean(region_probs)) if region_probs.size > 0 else 0.0

            centroid_y, centroid_x = prop.centroid
            regions.append({
                'region_id': region_id,
                'bounding_box': [int(x_min), int(y_min),
                                 int(x_max - x_min), int(y_max - y_min)],
                # Plain int, like the other numeric fields.
                'area': int(prop.area),
                'centroid': (int(centroid_x), int(centroid_y)),
                'region_mask': region_mask,
                'region_mask_cropped': region_mask_cropped,
                'region_image': region_image,
                'confidence': region_confidence,
                'mask_probability_mean': region_confidence
            })

        return regions

    def extract_for_casia(self,
                          binary_mask: np.ndarray,
                          probability_map: np.ndarray,
                          original_image: np.ndarray) -> List[Dict]:
        """
        Critical Fix #6: CASIA handling - treat the entire image as one region.

        Args:
            binary_mask: Binary mask (may be empty for authentic images);
                not used here.
            probability_map: Probability map.
            original_image: Original image.

        Returns:
            A single-element list describing the entire image. Its
            confidence is the mean of ``probability_map``, or 0.0 when the
            map is empty.
        """
        h, w = original_image.shape[:2]

        region_mask = np.ones((h, w), dtype=np.uint8)

        # Same empty-input fallback as extract(): avoids NaN plus a warning.
        overall_confidence = float(np.mean(probability_map)) if probability_map.size > 0 else 0.0

        return [{
            'region_id': 1,
            'bounding_box': [0, 0, w, h],
            'area': h * w,
            'centroid': (w // 2, h // 2),
            'region_mask': region_mask,
            'region_mask_cropped': region_mask,
            # Copy, as extract() does, so downstream in-place edits cannot
            # modify the caller's image.
            'region_image': original_image.copy(),
            'confidence': overall_confidence,
            'mask_probability_mean': overall_confidence
        }]
|
|
|
|
|
def get_mask_refiner(config, dataset_name: str = 'default') -> MaskRefiner:
    """
    Build a MaskRefiner configured for one dataset.

    Args:
        config: Configuration object forwarded unchanged to MaskRefiner.
        dataset_name: Dataset whose adaptive thresholds should be used.

    Returns:
        A newly constructed MaskRefiner.
    """
    refiner = MaskRefiner(config, dataset_name=dataset_name)
    return refiner
|
|
|
|
|
def get_region_extractor(config, dataset_name: str = 'default') -> RegionExtractor:
    """
    Build a RegionExtractor configured for one dataset.

    Args:
        config: Configuration object forwarded unchanged to RegionExtractor.
        dataset_name: Dataset name forwarded to RegionExtractor.

    Returns:
        A newly constructed RegionExtractor.
    """
    extractor = RegionExtractor(config, dataset_name=dataset_name)
    return extractor
|
|
|