""" Mask refinement and region extraction Implements Critical Fix #3: Adaptive Mask Refinement Thresholds """ import cv2 import numpy as np from typing import List, Tuple, Dict, Optional from scipy import ndimage from skimage.measure import label, regionprops class MaskRefiner: """ Mask refinement with adaptive thresholds Implements Critical Fix #3: Dataset-specific minimum region areas """ def __init__(self, config, dataset_name: str = 'default'): """ Initialize mask refiner Args: config: Configuration object dataset_name: Dataset name for adaptive thresholds """ self.config = config self.dataset_name = dataset_name # Get mask refinement parameters self.threshold = config.get('mask_refinement.threshold', 0.5) self.closing_kernel = config.get('mask_refinement.morphology.closing_kernel', 5) self.opening_kernel = config.get('mask_refinement.morphology.opening_kernel', 3) # Critical Fix #3: Adaptive thresholds per dataset self.min_region_area = config.get_min_region_area(dataset_name) print(f"MaskRefiner initialized for {dataset_name}") print(f"Min region area: {self.min_region_area * 100:.2f}%") def refine(self, probability_map: np.ndarray, original_size: Tuple[int, int] = None) -> np.ndarray: """ Refine probability map to binary mask Args: probability_map: Forgery probability map (H, W), values [0, 1] original_size: Optional (H, W) to resize mask back to original Returns: Refined binary mask (H, W) """ # Threshold to binary binary_mask = (probability_map > self.threshold).astype(np.uint8) # Morphological closing (fill broken strokes) closing_kernel = cv2.getStructuringElement( cv2.MORPH_RECT, (self.closing_kernel, self.closing_kernel) ) binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, closing_kernel) # Morphological opening (remove isolated noise) opening_kernel = cv2.getStructuringElement( cv2.MORPH_RECT, (self.opening_kernel, self.opening_kernel) ) binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, opening_kernel) # Critical Fix #3: Remove small regions with adaptive threshold binary_mask = self._remove_small_regions(binary_mask) # Resize to original size if provided if original_size is not None: binary_mask = cv2.resize( binary_mask, (original_size[1], original_size[0]), # cv2 uses (W, H) interpolation=cv2.INTER_NEAREST ) return binary_mask def _remove_small_regions(self, mask: np.ndarray) -> np.ndarray: """ Remove regions smaller than minimum area threshold Args: mask: Binary mask (H, W) Returns: Filtered mask """ # Calculate minimum pixel count image_area = mask.shape[0] * mask.shape[1] min_pixels = int(image_area * self.min_region_area) # Label connected components labeled_mask, num_features = ndimage.label(mask) # Keep only large enough regions filtered_mask = np.zeros_like(mask) for region_id in range(1, num_features + 1): region_mask = (labeled_mask == region_id) region_area = region_mask.sum() if region_area >= min_pixels: filtered_mask[region_mask] = 1 return filtered_mask class RegionExtractor: """ Extract individual regions from binary mask Implements Critical Fix #4: Region Confidence Aggregation """ def __init__(self, config, dataset_name: str = 'default'): """ Initialize region extractor Args: config: Configuration object dataset_name: Dataset name """ self.config = config self.dataset_name = dataset_name self.min_region_area = config.get_min_region_area(dataset_name) def extract(self, binary_mask: np.ndarray, probability_map: np.ndarray, original_image: np.ndarray) -> List[Dict]: """ Extract regions from binary mask Args: binary_mask: Refined binary mask (H, W) probability_map: Original probability map (H, W) original_image: Original image (H, W, 3) Returns: List of region dictionaries with bounding box, mask, image, confidence """ regions = [] # Safety check: Ensure probability_map and binary_mask have same dimensions if probability_map.shape != binary_mask.shape: import cv2 probability_map = cv2.resize( probability_map, (binary_mask.shape[1], binary_mask.shape[0]), interpolation=cv2.INTER_LINEAR ) # Connected component analysis (8-connectivity) labeled_mask = label(binary_mask, connectivity=2) props = regionprops(labeled_mask) for region_id, prop in enumerate(props, start=1): # Bounding box y_min, x_min, y_max, x_max = prop.bbox # Region mask region_mask = (labeled_mask == region_id).astype(np.uint8) # Cropped region image region_image = original_image[y_min:y_max, x_min:x_max].copy() region_mask_cropped = region_mask[y_min:y_max, x_min:x_max] # Critical Fix #4: Region-level confidence aggregation # Ensure region_mask and probability_map have same shape if region_mask.shape != probability_map.shape: import cv2 # Resize probability_map to match region_mask probability_map = cv2.resize( probability_map, (region_mask.shape[1], region_mask.shape[0]), interpolation=cv2.INTER_LINEAR ) region_probs = probability_map[region_mask > 0] region_confidence = float(np.mean(region_probs)) if len(region_probs) > 0 else 0.0 regions.append({ 'region_id': region_id, 'bounding_box': [int(x_min), int(y_min), int(x_max - x_min), int(y_max - y_min)], 'area': prop.area, 'centroid': (int(prop.centroid[1]), int(prop.centroid[0])), 'region_mask': region_mask, 'region_mask_cropped': region_mask_cropped, 'region_image': region_image, 'confidence': region_confidence, 'mask_probability_mean': region_confidence }) return regions def extract_for_casia(self, binary_mask: np.ndarray, probability_map: np.ndarray, original_image: np.ndarray) -> List[Dict]: """ Critical Fix #6: CASIA handling - treat entire image as one region Args: binary_mask: Binary mask (may be empty for authentic images) probability_map: Probability map original_image: Original image Returns: Single region representing entire image """ h, w = original_image.shape[:2] # Create single region covering entire image region_mask = np.ones((h, w), dtype=np.uint8) # Overall confidence from probability map overall_confidence = float(np.mean(probability_map)) return [{ 'region_id': 1, 'bounding_box': [0, 0, w, h], 'area': h * w, 'centroid': (w // 2, h // 2), 'region_mask': region_mask, 'region_mask_cropped': region_mask, 'region_image': original_image, 'confidence': overall_confidence, 'mask_probability_mean': overall_confidence }] def get_mask_refiner(config, dataset_name: str = 'default') -> MaskRefiner: """Factory function for mask refiner""" return MaskRefiner(config, dataset_name) def get_region_extractor(config, dataset_name: str = 'default') -> RegionExtractor: """Factory function for region extractor""" return RegionExtractor(config, dataset_name)