| """ |
| Mask refinement and region extraction |
| Implements Critical Fix #3: Adaptive Mask Refinement Thresholds |
| """ |
|
|
| import cv2 |
| import numpy as np |
| from typing import List, Tuple, Dict, Optional |
| from scipy import ndimage |
| from skimage.measure import label, regionprops |
|
|
|
|
class MaskRefiner:
    """
    Mask refinement with adaptive thresholds
    Implements Critical Fix #3: Dataset-specific minimum region areas

    The refinement pipeline (see ``refine``) is: threshold ->
    morphological closing -> morphological opening -> small-region
    removal -> optional resize.
    """

    def __init__(self, config, dataset_name: str = 'default'):
        """
        Initialize mask refiner

        Args:
            config: Configuration object; must provide ``get(key, default)``
                and ``get_min_region_area(dataset_name)``
            dataset_name: Dataset name for adaptive thresholds
        """
        self.config = config
        self.dataset_name = dataset_name

        # Binarisation threshold and square morphology kernel side lengths.
        self.threshold = config.get('mask_refinement.threshold', 0.5)
        self.closing_kernel = config.get('mask_refinement.morphology.closing_kernel', 5)
        self.opening_kernel = config.get('mask_refinement.morphology.opening_kernel', 3)

        # Minimum region size expressed as a fraction of the image area
        # (it is multiplied by the pixel count in _remove_small_regions).
        self.min_region_area = config.get_min_region_area(dataset_name)

        print(f"MaskRefiner initialized for {dataset_name}")
        print(f"Min region area: {self.min_region_area * 100:.2f}%")

    def refine(self,
               probability_map: np.ndarray,
               original_size: Optional[Tuple[int, int]] = None) -> np.ndarray:
        """
        Refine probability map to binary mask

        Args:
            probability_map: Forgery probability map (H, W), values [0, 1]
            original_size: Optional (H, W) to resize mask back to original

        Returns:
            Refined binary mask (H, W) of dtype uint8 with values 0/1
        """
        # Strictly-greater comparison: pixels exactly at the threshold are
        # treated as background.
        binary_mask = (probability_map > self.threshold).astype(np.uint8)

        # Closing first (fill small holes / bridge narrow gaps) ...
        closing_kernel = cv2.getStructuringElement(
            cv2.MORPH_RECT,
            (self.closing_kernel, self.closing_kernel)
        )
        binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, closing_kernel)

        # ... then opening (remove small isolated speckles).
        opening_kernel = cv2.getStructuringElement(
            cv2.MORPH_RECT,
            (self.opening_kernel, self.opening_kernel)
        )
        binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, opening_kernel)

        # Drop connected components below the dataset-specific area floor.
        binary_mask = self._remove_small_regions(binary_mask)

        if original_size is not None:
            # cv2.resize expects (width, height), hence the swapped order.
            # Nearest-neighbour keeps the mask strictly binary.
            binary_mask = cv2.resize(
                binary_mask,
                (original_size[1], original_size[0]),
                interpolation=cv2.INTER_NEAREST
            )

        return binary_mask

    def _remove_small_regions(self, mask: np.ndarray) -> np.ndarray:
        """
        Remove regions smaller than minimum area threshold

        The threshold in pixels is ``int(H * W * self.min_region_area)``;
        a region is kept when its pixel count is >= that value.

        Args:
            mask: Binary mask (H, W)

        Returns:
            Filtered mask with the same shape and dtype as ``mask``; kept
            regions are set to 1, everything else to 0
        """
        image_area = mask.shape[0] * mask.shape[1]
        min_pixels = int(image_area * self.min_region_area)

        # ndimage.label's default structuring element gives 4-connectivity
        # (diagonal neighbours are NOT joined).
        # NOTE(review): RegionExtractor.extract labels with connectivity=2
        # (8-connectivity), so the two stages can disagree on what counts
        # as one region - confirm this is intended.
        labeled_mask, num_features = ndimage.label(mask)

        # Count the pixels of every label in one pass instead of rescanning
        # the whole image once per region.
        areas = np.bincount(labeled_mask.ravel(), minlength=num_features + 1)
        keep = areas >= min_pixels
        keep[0] = False  # label 0 is background and is never kept

        # Look up each pixel's keep-flag through its label.
        return keep[labeled_mask].astype(mask.dtype)
|
|
|
|
class RegionExtractor:
    """
    Extract individual regions from binary mask
    Implements Critical Fix #4: Region Confidence Aggregation
    """

    def __init__(self, config, dataset_name: str = 'default'):
        """
        Initialize region extractor

        Args:
            config: Configuration object; must provide
                ``get_min_region_area(dataset_name)``
            dataset_name: Dataset name
        """
        self.config = config
        self.dataset_name = dataset_name
        # Stored for callers; no method of this class reads it.
        self.min_region_area = config.get_min_region_area(dataset_name)

    def extract(self,
                binary_mask: np.ndarray,
                probability_map: np.ndarray,
                original_image: np.ndarray) -> List[Dict]:
        """
        Extract regions from binary mask

        Each connected component (8-connectivity) becomes one region whose
        confidence is the mean of the probability map over its pixels.

        Args:
            binary_mask: Refined binary mask (H, W)
            probability_map: Original probability map (H, W); resized to the
                mask's size with bilinear interpolation if shapes differ
            original_image: Original image (H, W, 3)

        Returns:
            List of region dictionaries with keys:
                region_id: Component label (starts at 1)
                bounding_box: [x, y, width, height]
                area: Pixel count of the region
                centroid: (x, y), truncated to int
                region_mask: Full-size uint8 mask of this region only
                region_mask_cropped: ``region_mask`` cropped to the bbox
                region_image: Copy of ``original_image`` cropped to the bbox
                confidence / mask_probability_mean: Mean probability inside
                    the region (0.0 if no pixels are selected)
        """
        # Align the probability map with the mask once, up front. Doing it
        # again per region would be a no-op, since the target size is
        # identical.
        if probability_map.shape != binary_mask.shape:
            probability_map = cv2.resize(
                probability_map,
                (binary_mask.shape[1], binary_mask.shape[0]),
                interpolation=cv2.INTER_LINEAR
            )

        labeled_mask = label(binary_mask, connectivity=2)

        regions = []
        for prop in regionprops(labeled_mask):
            # Use the label carried by the region itself rather than a
            # separate counter, so the mask lookup can never drift.
            region_id = prop.label
            y_min, x_min, y_max, x_max = prop.bbox

            region_mask = (labeled_mask == region_id).astype(np.uint8)

            # NOTE(review): the crop uses mask-space coordinates, so this
            # assumes original_image has the same height/width as
            # binary_mask - confirm against the caller.
            region_image = original_image[y_min:y_max, x_min:x_max].copy()
            region_mask_cropped = region_mask[y_min:y_max, x_min:x_max]

            region_probs = probability_map[region_mask > 0]
            region_confidence = float(np.mean(region_probs)) if region_probs.size > 0 else 0.0

            regions.append({
                'region_id': region_id,
                'bounding_box': [int(x_min), int(y_min),
                                 int(x_max - x_min), int(y_max - y_min)],
                # Plain int, like the other numeric fields and like
                # extract_for_casia, rather than a numpy scalar.
                'area': int(prop.area),
                # prop.centroid is (row, col); emit it as (x, y).
                'centroid': (int(prop.centroid[1]), int(prop.centroid[0])),
                'region_mask': region_mask,
                'region_mask_cropped': region_mask_cropped,
                'region_image': region_image,
                'confidence': region_confidence,
                'mask_probability_mean': region_confidence
            })

        return regions

    def extract_for_casia(self,
                          binary_mask: np.ndarray,
                          probability_map: np.ndarray,
                          original_image: np.ndarray) -> List[Dict]:
        """
        Critical Fix #6: CASIA handling - treat entire image as one region

        Args:
            binary_mask: Binary mask (may be empty for authentic images);
                not used by this method
            probability_map: Probability map; its global mean becomes the
                region confidence (0.0 if the map has no elements)
            original_image: Original image

        Returns:
            Single region representing entire image. ``region_mask`` and
            ``region_mask_cropped`` are the same all-ones array, and
            ``region_image`` is ``original_image`` itself (not a copy).
        """
        h, w = original_image.shape[:2]

        # The whole image is the region.
        region_mask = np.ones((h, w), dtype=np.uint8)

        # Guard against an empty map: np.mean would emit a warning and
        # return NaN. Fall back to 0.0, as extract() does for empty regions.
        if probability_map.size > 0:
            overall_confidence = float(np.mean(probability_map))
        else:
            overall_confidence = 0.0

        return [{
            'region_id': 1,
            'bounding_box': [0, 0, w, h],
            'area': h * w,
            'centroid': (w // 2, h // 2),
            'region_mask': region_mask,
            'region_mask_cropped': region_mask,
            'region_image': original_image,
            'confidence': overall_confidence,
            'mask_probability_mean': overall_confidence
        }]
|
|
|
|
def get_mask_refiner(config, dataset_name: str = 'default') -> MaskRefiner:
    """
    Build a MaskRefiner for the given configuration and dataset.

    Args:
        config: Configuration object passed through to MaskRefiner
        dataset_name: Dataset name passed through to MaskRefiner

    Returns:
        A newly constructed MaskRefiner
    """
    refiner = MaskRefiner(config, dataset_name=dataset_name)
    return refiner
|
|
|
|
def get_region_extractor(config, dataset_name: str = 'default') -> RegionExtractor:
    """
    Build a RegionExtractor for the given configuration and dataset.

    Args:
        config: Configuration object passed through to RegionExtractor
        dataset_name: Dataset name passed through to RegionExtractor

    Returns:
        A newly constructed RegionExtractor
    """
    extractor = RegionExtractor(config, dataset_name=dataset_name)
    return extractor
|
|