# Source: Document_Forgery_Detection/src/features/region_extraction.py
# Author: JKrishnanandhaa
# Last commit: "Update src/features/region_extraction.py" (86140ca, verified)
"""
Mask refinement and region extraction
Implements Critical Fix #3: Adaptive Mask Refinement Thresholds
"""
import cv2
import numpy as np
from typing import List, Tuple, Dict, Optional
from scipy import ndimage
from skimage.measure import label, regionprops
class MaskRefiner:
    """
    Mask refinement with adaptive thresholds

    Implements Critical Fix #3: Dataset-specific minimum region areas.

    Converts a forgery probability map into a cleaned binary mask by
    thresholding, morphological closing and opening, and finally dropping
    connected components whose pixel count is below a dataset-specific
    fraction of the image area.
    """

    def __init__(self, config, dataset_name: str = 'default'):
        """
        Initialize mask refiner

        Args:
            config: Configuration object. It is queried through
                ``config.get(key, default)`` and
                ``config.get_min_region_area(dataset_name)``.
            dataset_name: Dataset name for adaptive thresholds
        """
        self.config = config
        self.dataset_name = dataset_name

        # Get mask refinement parameters
        self.threshold = config.get('mask_refinement.threshold', 0.5)
        self.closing_kernel = config.get('mask_refinement.morphology.closing_kernel', 5)
        self.opening_kernel = config.get('mask_refinement.morphology.opening_kernel', 3)

        # Critical Fix #3: Adaptive thresholds per dataset.
        # The value is a fraction of the image area (it is multiplied by the
        # pixel count in _remove_small_regions and printed as a percentage).
        self.min_region_area = config.get_min_region_area(dataset_name)

        print(f"MaskRefiner initialized for {dataset_name}")
        print(f"Min region area: {self.min_region_area * 100:.2f}%")

    def refine(self,
               probability_map: np.ndarray,
               original_size: Optional[Tuple[int, int]] = None) -> np.ndarray:
        """
        Refine probability map to binary mask

        Args:
            probability_map: Forgery probability map (H, W), values [0, 1]
            original_size: Optional (H, W) to resize mask back to original

        Returns:
            Refined binary mask (H, W) with values 0/1 (uint8)
        """
        # Threshold to binary
        binary_mask = (probability_map > self.threshold).astype(np.uint8)

        # Morphological closing (fill broken strokes)
        closing_kernel = cv2.getStructuringElement(
            cv2.MORPH_RECT,
            (self.closing_kernel, self.closing_kernel)
        )
        binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, closing_kernel)

        # Morphological opening (remove isolated noise)
        opening_kernel = cv2.getStructuringElement(
            cv2.MORPH_RECT,
            (self.opening_kernel, self.opening_kernel)
        )
        binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, opening_kernel)

        # Critical Fix #3: Remove small regions with adaptive threshold.
        # This runs before the optional resize, so the area fraction is
        # evaluated at the probability map's resolution.
        binary_mask = self._remove_small_regions(binary_mask)

        # Resize to original size if provided
        if original_size is not None:
            binary_mask = cv2.resize(
                binary_mask,
                (original_size[1], original_size[0]),  # cv2 uses (W, H)
                interpolation=cv2.INTER_NEAREST
            )

        return binary_mask

    def _remove_small_regions(self, mask: np.ndarray) -> np.ndarray:
        """
        Remove regions smaller than minimum area threshold

        Args:
            mask: Binary mask (H, W); any non-zero value counts as foreground

        Returns:
            Filtered mask with the same shape and dtype as ``mask``; kept
            regions are set to 1, everything else to 0
        """
        # Calculate minimum pixel count
        image_area = mask.shape[0] * mask.shape[1]
        min_pixels = int(image_area * self.min_region_area)

        # Label connected components. scipy's default structuring element is
        # used here (4-connectivity for a 2D mask).
        labeled_mask, num_features = ndimage.label(mask)
        if num_features == 0:
            return np.zeros_like(mask)

        # One pass over the label image gives every component's pixel count,
        # instead of re-scanning the whole image once per component.
        region_sizes = np.bincount(labeled_mask.ravel(), minlength=num_features + 1)
        keep_label = region_sizes >= min_pixels
        keep_label[0] = False  # label 0 is background and is never kept

        # Map each pixel's label to its keep/drop decision
        return keep_label[labeled_mask].astype(mask.dtype)
class RegionExtractor:
    """
    Extract individual regions from binary mask

    Implements Critical Fix #4: Region Confidence Aggregation.

    Every region is returned as a dictionary with the keys ``region_id``,
    ``bounding_box`` ([x, y, width, height]), ``area``, ``centroid``
    ((x, y)), ``region_mask``, ``region_mask_cropped``, ``region_image``,
    ``confidence`` and ``mask_probability_mean``.
    """

    def __init__(self, config, dataset_name: str = 'default'):
        """
        Initialize region extractor

        Args:
            config: Configuration object; must provide
                ``get_min_region_area(dataset_name)``
            dataset_name: Dataset name
        """
        self.config = config
        self.dataset_name = dataset_name
        self.min_region_area = config.get_min_region_area(dataset_name)

    def extract(self,
                binary_mask: np.ndarray,
                probability_map: np.ndarray,
                original_image: np.ndarray) -> List[Dict]:
        """
        Extract regions from binary mask

        Args:
            binary_mask: Refined binary mask (H, W)
            probability_map: Original probability map (H, W)
            original_image: Original image (H, W, 3)

        Returns:
            List of region dictionaries with bounding box, mask, image, confidence
        """
        # Safety check: Ensure probability_map and binary_mask have same
        # dimensions. Doing this once up front is sufficient: every per-region
        # mask below has the shape of binary_mask.
        if probability_map.shape != binary_mask.shape:
            probability_map = cv2.resize(
                probability_map,
                (binary_mask.shape[1], binary_mask.shape[0]),
                interpolation=cv2.INTER_LINEAR
            )

        # Connected component analysis (8-connectivity)
        labeled_mask = label(binary_mask, connectivity=2)

        regions = []
        for prop in regionprops(labeled_mask):
            region_id = prop.label

            # Bounding box (half-open: max indices are exclusive)
            y_min, x_min, y_max, x_max = prop.bbox

            # Full-size region mask, plus its crop to the bounding box
            region_mask = (labeled_mask == region_id).astype(np.uint8)
            region_mask_cropped = region_mask[y_min:y_max, x_min:x_max]

            # Cropped region image (copied so it is independent of the input)
            region_image = original_image[y_min:y_max, x_min:x_max].copy()

            # Critical Fix #4: Region-level confidence aggregation.
            # All pixels of the region lie inside its bounding box, so
            # indexing the cropped probability map selects the same values as
            # indexing the full map while touching far fewer pixels.
            region_probs = probability_map[y_min:y_max, x_min:x_max][region_mask_cropped > 0]
            region_confidence = float(np.mean(region_probs)) if region_probs.size > 0 else 0.0

            regions.append({
                'region_id': region_id,
                'bounding_box': [int(x_min), int(y_min),
                                 int(x_max - x_min), int(y_max - y_min)],
                # Plain int, consistent with extract_for_casia and with the
                # int casts applied to the bounding box and centroid.
                'area': int(prop.area),
                'centroid': (int(prop.centroid[1]), int(prop.centroid[0])),
                'region_mask': region_mask,
                'region_mask_cropped': region_mask_cropped,
                'region_image': region_image,
                'confidence': region_confidence,
                'mask_probability_mean': region_confidence
            })

        return regions

    def extract_for_casia(self,
                          binary_mask: np.ndarray,
                          probability_map: np.ndarray,
                          original_image: np.ndarray) -> List[Dict]:
        """
        Critical Fix #6: CASIA handling - treat entire image as one region

        Args:
            binary_mask: Binary mask (may be empty for authentic images);
                not read by this method
            probability_map: Probability map
            original_image: Original image

        Returns:
            Single region representing entire image
        """
        h, w = original_image.shape[:2]

        # Create single region covering entire image
        region_mask = np.ones((h, w), dtype=np.uint8)

        # Overall confidence from probability map; an empty map yields 0.0
        # instead of NaN (same fallback as extract()).
        if probability_map.size > 0:
            overall_confidence = float(np.mean(probability_map))
        else:
            overall_confidence = 0.0

        return [{
            'region_id': 1,
            'bounding_box': [0, 0, w, h],
            'area': h * w,
            'centroid': (w // 2, h // 2),
            'region_mask': region_mask,
            # Separate array so in-place edits of one mask do not silently
            # change the other.
            'region_mask_cropped': region_mask.copy(),
            # Copied, as in extract(), so the caller's image is not aliased.
            'region_image': original_image.copy(),
            'confidence': overall_confidence,
            'mask_probability_mean': overall_confidence
        }]
def get_mask_refiner(config, dataset_name: str = 'default') -> MaskRefiner:
    """
    Factory function for mask refiner

    Args:
        config: Configuration object forwarded to ``MaskRefiner``
        dataset_name: Dataset name forwarded to ``MaskRefiner``

    Returns:
        A newly constructed ``MaskRefiner``
    """
    refiner = MaskRefiner(config, dataset_name)
    return refiner
def get_region_extractor(config, dataset_name: str = 'default') -> RegionExtractor:
    """
    Factory function for region extractor

    Args:
        config: Configuration object forwarded to ``RegionExtractor``
        dataset_name: Dataset name forwarded to ``RegionExtractor``

    Returns:
        A newly constructed ``RegionExtractor``
    """
    extractor = RegionExtractor(config, dataset_name)
    return extractor