Spaces:
Sleeping
Sleeping
File size: 8,843 Bytes
936d73b 86140ca 936d73b 86140ca 936d73b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 | """
Mask refinement and region extraction
Implements Critical Fix #3: Adaptive Mask Refinement Thresholds
"""
import cv2
import numpy as np
from typing import List, Tuple, Dict, Optional
from scipy import ndimage
from skimage.measure import label, regionprops
class MaskRefiner:
    """
    Mask refinement with adaptive thresholds.

    Implements Critical Fix #3: dataset-specific minimum region areas.

    ``refine`` turns a probability map into a binary mask via:
    threshold -> morphological closing -> morphological opening ->
    small-region removal -> optional resize back to the original size.
    """

    def __init__(self, config, dataset_name: str = 'default'):
        """
        Initialize mask refiner.

        Args:
            config: Configuration object. Must provide ``get(key, default)``
                and ``get_min_region_area(dataset_name)``.
            dataset_name: Dataset name used to look up the adaptive
                minimum region area.
        """
        self.config = config
        self.dataset_name = dataset_name

        # Get mask refinement parameters
        self.threshold = config.get('mask_refinement.threshold', 0.5)
        self.closing_kernel = config.get('mask_refinement.morphology.closing_kernel', 5)
        self.opening_kernel = config.get('mask_refinement.morphology.opening_kernel', 3)

        # Critical Fix #3: Adaptive thresholds per dataset.
        # Treated as a fraction of the image area: it is multiplied by the
        # pixel count in _remove_small_regions and printed as a percentage.
        self.min_region_area = config.get_min_region_area(dataset_name)

        print(f"MaskRefiner initialized for {dataset_name}")
        print(f"Min region area: {self.min_region_area * 100:.2f}%")

    def refine(self,
               probability_map: np.ndarray,
               original_size: Optional[Tuple[int, int]] = None) -> np.ndarray:
        """
        Refine a probability map into a binary mask.

        Args:
            probability_map: Forgery probability map (H, W); values are
                compared against ``self.threshold`` (presumably in [0, 1]).
            original_size: Optional (H, W) to resize the mask back to.

        Returns:
            Refined uint8 binary mask (values 0/1), of shape (H, W) or
            ``original_size`` when given.
        """
        # Threshold to binary
        binary_mask = (probability_map > self.threshold).astype(np.uint8)

        # Morphological closing (fill broken strokes). Kernel sizes are
        # coerced to int so float values from config files are accepted.
        closing_size = int(self.closing_kernel)
        closing_kernel = cv2.getStructuringElement(
            cv2.MORPH_RECT,
            (closing_size, closing_size)
        )
        binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, closing_kernel)

        # Morphological opening (remove isolated noise)
        opening_size = int(self.opening_kernel)
        opening_kernel = cv2.getStructuringElement(
            cv2.MORPH_RECT,
            (opening_size, opening_size)
        )
        binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, opening_kernel)

        # Critical Fix #3: Remove small regions with adaptive threshold.
        # The threshold is relative to image area, so doing this before the
        # resize is resolution-independent.
        binary_mask = self._remove_small_regions(binary_mask)

        # Resize to original size if provided
        if original_size is not None:
            binary_mask = cv2.resize(
                binary_mask,
                (original_size[1], original_size[0]),  # cv2 uses (W, H)
                interpolation=cv2.INTER_NEAREST  # nearest keeps values binary
            )

        return binary_mask

    def _remove_small_regions(self, mask: np.ndarray) -> np.ndarray:
        """
        Remove connected regions smaller than the minimum area threshold.

        Components are found with 8-connectivity so that the filter agrees
        with RegionExtractor.extract (skimage ``label(..., connectivity=2)``).
        Otherwise a diagonally-connected region could be split into small
        fragments and deleted here even though extraction treats it as one
        sufficiently large region.

        Args:
            mask: Binary mask (H, W); any nonzero pixel is foreground.

        Returns:
            Mask of the same shape and dtype as ``mask``, with 1 on pixels of
            kept regions and 0 elsewhere.
        """
        # Calculate minimum pixel count
        image_area = mask.shape[0] * mask.shape[1]
        min_pixels = int(image_area * self.min_region_area)

        # Label connected components (8-connectivity)
        eight_connected = np.ones((3, 3), dtype=bool)
        labeled_mask, num_features = ndimage.label(mask, structure=eight_connected)
        if num_features == 0:
            return np.zeros_like(mask)

        # One pass over the image to count pixels per label, then a lookup
        # table maps each label to keep/drop (label 0 is background).
        region_areas = np.bincount(labeled_mask.ravel(), minlength=num_features + 1)
        keep = region_areas >= min_pixels
        keep[0] = False

        return keep[labeled_mask].astype(mask.dtype)
class RegionExtractor:
    """
    Extract individual regions from a binary mask.

    Implements Critical Fix #4: region confidence aggregation (each region's
    confidence is the mean forgery probability over its pixels).
    """

    def __init__(self, config, dataset_name: str = 'default'):
        """
        Initialize region extractor.

        Args:
            config: Configuration object; must provide
                ``get_min_region_area(dataset_name)``.
            dataset_name: Dataset name.
        """
        self.config = config
        self.dataset_name = dataset_name
        self.min_region_area = config.get_min_region_area(dataset_name)

    def extract(self,
                binary_mask: np.ndarray,
                probability_map: np.ndarray,
                original_image: np.ndarray) -> List[Dict]:
        """
        Extract 8-connected regions from a binary mask.

        Args:
            binary_mask: Refined binary mask (H, W); nonzero = foreground.
            probability_map: Probability map; resized to the mask's (H, W)
                if its shape differs.
            original_image: Original image (H, W, 3). NOTE(review): assumed
                to share the mask's (H, W), since crops use the mask's bbox
                coordinates. Confirm against callers.

        Returns:
            One dict per connected region with keys ``region_id``,
            ``bounding_box`` ([x, y, w, h]), ``area``, ``centroid`` ((x, y)),
            ``region_mask`` (full-size uint8), ``region_mask_cropped``
            (a view of ``region_mask`` over the bbox), ``region_image``
            (a copy of the bbox crop), ``confidence`` and
            ``mask_probability_mean`` (both the mean probability over the
            region's pixels).
        """
        regions = []

        # Safety check: align probability_map with binary_mask. After this
        # the two always share a shape, so per-region indexing is valid.
        if probability_map.shape != binary_mask.shape:
            probability_map = cv2.resize(
                probability_map,
                (binary_mask.shape[1], binary_mask.shape[0]),  # cv2 uses (W, H)
                interpolation=cv2.INTER_LINEAR
            )

        # Connected component analysis (8-connectivity)
        labeled_mask = label(binary_mask, connectivity=2)

        for prop in regionprops(labeled_mask):
            region_id = prop.label

            # Bounding box (skimage order: min_row, min_col, max_row, max_col)
            y_min, x_min, y_max, x_max = prop.bbox
            bbox = (slice(y_min, y_max), slice(x_min, x_max))
            # Boolean footprint of this component within its bbox only, so
            # we avoid scanning the whole labeled image once per region.
            footprint = prop.image

            # Full-size region mask
            region_mask = np.zeros(labeled_mask.shape, dtype=np.uint8)
            region_mask[bbox][footprint] = 1
            region_mask_cropped = region_mask[bbox]

            # Cropped region image
            region_image = original_image[y_min:y_max, x_min:x_max].copy()

            # Critical Fix #4: Region-level confidence aggregation
            region_probs = probability_map[bbox][footprint]
            region_confidence = float(np.mean(region_probs)) if region_probs.size > 0 else 0.0

            regions.append({
                'region_id': region_id,
                'bounding_box': [int(x_min), int(y_min),
                                 int(x_max - x_min), int(y_max - y_min)],
                # Cast: newer scikit-image returns a float for ``area``.
                'area': int(prop.area),
                'centroid': (int(prop.centroid[1]), int(prop.centroid[0])),
                'region_mask': region_mask,
                'region_mask_cropped': region_mask_cropped,
                'region_image': region_image,
                'confidence': region_confidence,
                'mask_probability_mean': region_confidence
            })

        return regions

    def extract_for_casia(self,
                          binary_mask: np.ndarray,
                          probability_map: np.ndarray,
                          original_image: np.ndarray) -> List[Dict]:
        """
        Critical Fix #6: CASIA handling - treat the entire image as one region.

        Args:
            binary_mask: Binary mask (may be empty for authentic images);
                not used by this method.
            probability_map: Probability map; its global mean becomes the
                region confidence (0.0 if it has no elements).
            original_image: Original image; only its first two dimensions
                (H, W) are used for geometry.

        Returns:
            A single-element list whose dict covers the entire image, with the
            same keys as ``extract``. ``region_image`` is a copy of the input
            image, as in ``extract``.
        """
        h, w = original_image.shape[:2]

        # Create single region covering entire image
        region_mask = np.ones((h, w), dtype=np.uint8)

        # Overall confidence from probability map; guard the empty case,
        # where np.mean would return nan with a RuntimeWarning.
        probs = np.asarray(probability_map)
        overall_confidence = float(np.mean(probs)) if probs.size > 0 else 0.0

        return [{
            'region_id': 1,
            'bounding_box': [0, 0, w, h],
            'area': h * w,
            'centroid': (w // 2, h // 2),
            'region_mask': region_mask,
            # The crop is the whole image, so it is the full mask itself
            # (extract likewise returns a view of region_mask).
            'region_mask_cropped': region_mask,
            'region_image': original_image.copy(),
            'confidence': overall_confidence,
            'mask_probability_mean': overall_confidence
        }]
def get_mask_refiner(config, dataset_name: str = 'default') -> MaskRefiner:
    """
    Build a MaskRefiner.

    Args:
        config: Configuration object forwarded to MaskRefiner.
        dataset_name: Dataset name forwarded to MaskRefiner.

    Returns:
        A newly constructed MaskRefiner.
    """
    refiner = MaskRefiner(config, dataset_name)
    return refiner
def get_region_extractor(config, dataset_name: str = 'default') -> RegionExtractor:
    """
    Build a RegionExtractor.

    Args:
        config: Configuration object forwarded to RegionExtractor.
        dataset_name: Dataset name forwarded to RegionExtractor.

    Returns:
        A newly constructed RegionExtractor.
    """
    extractor = RegionExtractor(config, dataset_name)
    return extractor
|