JKrishnanandhaa commited on
Commit
936d73b
·
verified ·
1 Parent(s): 1fd370b

Update src/features/region_extraction.py

Browse files
Files changed (1) hide show
  1. src/features/region_extraction.py +235 -226
src/features/region_extraction.py CHANGED
@@ -1,226 +1,235 @@
1
- """
2
- Mask refinement and region extraction
3
- Implements Critical Fix #3: Adaptive Mask Refinement Thresholds
4
- """
5
-
6
- import cv2
7
- import numpy as np
8
- from typing import List, Tuple, Dict, Optional
9
- from scipy import ndimage
10
- from skimage.measure import label, regionprops
11
-
12
-
13
- class MaskRefiner:
14
- """
15
- Mask refinement with adaptive thresholds
16
- Implements Critical Fix #3: Dataset-specific minimum region areas
17
- """
18
-
19
- def __init__(self, config, dataset_name: str = 'default'):
20
- """
21
- Initialize mask refiner
22
-
23
- Args:
24
- config: Configuration object
25
- dataset_name: Dataset name for adaptive thresholds
26
- """
27
- self.config = config
28
- self.dataset_name = dataset_name
29
-
30
- # Get mask refinement parameters
31
- self.threshold = config.get('mask_refinement.threshold', 0.5)
32
- self.closing_kernel = config.get('mask_refinement.morphology.closing_kernel', 5)
33
- self.opening_kernel = config.get('mask_refinement.morphology.opening_kernel', 3)
34
-
35
- # Critical Fix #3: Adaptive thresholds per dataset
36
- self.min_region_area = config.get_min_region_area(dataset_name)
37
-
38
- print(f"MaskRefiner initialized for {dataset_name}")
39
- print(f"Min region area: {self.min_region_area * 100:.2f}%")
40
-
41
- def refine(self,
42
- probability_map: np.ndarray,
43
- original_size: Tuple[int, int] = None) -> np.ndarray:
44
- """
45
- Refine probability map to binary mask
46
-
47
- Args:
48
- probability_map: Forgery probability map (H, W), values [0, 1]
49
- original_size: Optional (H, W) to resize mask back to original
50
-
51
- Returns:
52
- Refined binary mask (H, W)
53
- """
54
- # Threshold to binary
55
- binary_mask = (probability_map > self.threshold).astype(np.uint8)
56
-
57
- # Morphological closing (fill broken strokes)
58
- closing_kernel = cv2.getStructuringElement(
59
- cv2.MORPH_RECT,
60
- (self.closing_kernel, self.closing_kernel)
61
- )
62
- binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, closing_kernel)
63
-
64
- # Morphological opening (remove isolated noise)
65
- opening_kernel = cv2.getStructuringElement(
66
- cv2.MORPH_RECT,
67
- (self.opening_kernel, self.opening_kernel)
68
- )
69
- binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, opening_kernel)
70
-
71
- # Critical Fix #3: Remove small regions with adaptive threshold
72
- binary_mask = self._remove_small_regions(binary_mask)
73
-
74
- # Resize to original size if provided
75
- if original_size is not None:
76
- binary_mask = cv2.resize(
77
- binary_mask,
78
- (original_size[1], original_size[0]), # cv2 uses (W, H)
79
- interpolation=cv2.INTER_NEAREST
80
- )
81
-
82
- return binary_mask
83
-
84
- def _remove_small_regions(self, mask: np.ndarray) -> np.ndarray:
85
- """
86
- Remove regions smaller than minimum area threshold
87
-
88
- Args:
89
- mask: Binary mask (H, W)
90
-
91
- Returns:
92
- Filtered mask
93
- """
94
- # Calculate minimum pixel count
95
- image_area = mask.shape[0] * mask.shape[1]
96
- min_pixels = int(image_area * self.min_region_area)
97
-
98
- # Label connected components
99
- labeled_mask, num_features = ndimage.label(mask)
100
-
101
- # Keep only large enough regions
102
- filtered_mask = np.zeros_like(mask)
103
-
104
- for region_id in range(1, num_features + 1):
105
- region_mask = (labeled_mask == region_id)
106
- region_area = region_mask.sum()
107
-
108
- if region_area >= min_pixels:
109
- filtered_mask[region_mask] = 1
110
-
111
- return filtered_mask
112
-
113
-
114
- class RegionExtractor:
115
- """
116
- Extract individual regions from binary mask
117
- Implements Critical Fix #4: Region Confidence Aggregation
118
- """
119
-
120
- def __init__(self, config, dataset_name: str = 'default'):
121
- """
122
- Initialize region extractor
123
-
124
- Args:
125
- config: Configuration object
126
- dataset_name: Dataset name
127
- """
128
- self.config = config
129
- self.dataset_name = dataset_name
130
- self.min_region_area = config.get_min_region_area(dataset_name)
131
-
132
- def extract(self,
133
- binary_mask: np.ndarray,
134
- probability_map: np.ndarray,
135
- original_image: np.ndarray) -> List[Dict]:
136
- """
137
- Extract regions from binary mask
138
-
139
- Args:
140
- binary_mask: Refined binary mask (H, W)
141
- probability_map: Original probability map (H, W)
142
- original_image: Original image (H, W, 3)
143
-
144
- Returns:
145
- List of region dictionaries with bounding box, mask, image, confidence
146
- """
147
- regions = []
148
-
149
- # Connected component analysis (8-connectivity)
150
- labeled_mask = label(binary_mask, connectivity=2)
151
- props = regionprops(labeled_mask)
152
-
153
- for region_id, prop in enumerate(props, start=1):
154
- # Bounding box
155
- y_min, x_min, y_max, x_max = prop.bbox
156
-
157
- # Region mask
158
- region_mask = (labeled_mask == region_id).astype(np.uint8)
159
-
160
- # Cropped region image
161
- region_image = original_image[y_min:y_max, x_min:x_max].copy()
162
- region_mask_cropped = region_mask[y_min:y_max, x_min:x_max]
163
-
164
- # Critical Fix #4: Region-level confidence aggregation
165
- region_probs = probability_map[region_mask > 0]
166
- region_confidence = float(np.mean(region_probs)) if len(region_probs) > 0 else 0.0
167
-
168
- regions.append({
169
- 'region_id': region_id,
170
- 'bounding_box': [int(x_min), int(y_min),
171
- int(x_max - x_min), int(y_max - y_min)],
172
- 'area': prop.area,
173
- 'centroid': (int(prop.centroid[1]), int(prop.centroid[0])),
174
- 'region_mask': region_mask,
175
- 'region_mask_cropped': region_mask_cropped,
176
- 'region_image': region_image,
177
- 'confidence': region_confidence,
178
- 'mask_probability_mean': region_confidence
179
- })
180
-
181
- return regions
182
-
183
- def extract_for_casia(self,
184
- binary_mask: np.ndarray,
185
- probability_map: np.ndarray,
186
- original_image: np.ndarray) -> List[Dict]:
187
- """
188
- Critical Fix #6: CASIA handling - treat entire image as one region
189
-
190
- Args:
191
- binary_mask: Binary mask (may be empty for authentic images)
192
- probability_map: Probability map
193
- original_image: Original image
194
-
195
- Returns:
196
- Single region representing entire image
197
- """
198
- h, w = original_image.shape[:2]
199
-
200
- # Create single region covering entire image
201
- region_mask = np.ones((h, w), dtype=np.uint8)
202
-
203
- # Overall confidence from probability map
204
- overall_confidence = float(np.mean(probability_map))
205
-
206
- return [{
207
- 'region_id': 1,
208
- 'bounding_box': [0, 0, w, h],
209
- 'area': h * w,
210
- 'centroid': (w // 2, h // 2),
211
- 'region_mask': region_mask,
212
- 'region_mask_cropped': region_mask,
213
- 'region_image': original_image,
214
- 'confidence': overall_confidence,
215
- 'mask_probability_mean': overall_confidence
216
- }]
217
-
218
-
219
- def get_mask_refiner(config, dataset_name: str = 'default') -> MaskRefiner:
220
- """Factory function for mask refiner"""
221
- return MaskRefiner(config, dataset_name)
222
-
223
-
224
- def get_region_extractor(config, dataset_name: str = 'default') -> RegionExtractor:
225
- """Factory function for region extractor"""
226
- return RegionExtractor(config, dataset_name)
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mask refinement and region extraction
3
+ Implements Critical Fix #3: Adaptive Mask Refinement Thresholds
4
+ """
5
+
6
+ import cv2
7
+ import numpy as np
8
+ from typing import List, Tuple, Dict, Optional
9
+ from scipy import ndimage
10
+ from skimage.measure import label, regionprops
11
+
12
+
13
+ class MaskRefiner:
14
+ """
15
+ Mask refinement with adaptive thresholds
16
+ Implements Critical Fix #3: Dataset-specific minimum region areas
17
+ """
18
+
19
+ def __init__(self, config, dataset_name: str = 'default'):
20
+ """
21
+ Initialize mask refiner
22
+
23
+ Args:
24
+ config: Configuration object
25
+ dataset_name: Dataset name for adaptive thresholds
26
+ """
27
+ self.config = config
28
+ self.dataset_name = dataset_name
29
+
30
+ # Get mask refinement parameters
31
+ self.threshold = config.get('mask_refinement.threshold', 0.5)
32
+ self.closing_kernel = config.get('mask_refinement.morphology.closing_kernel', 5)
33
+ self.opening_kernel = config.get('mask_refinement.morphology.opening_kernel', 3)
34
+
35
+ # Critical Fix #3: Adaptive thresholds per dataset
36
+ self.min_region_area = config.get_min_region_area(dataset_name)
37
+
38
+ print(f"MaskRefiner initialized for {dataset_name}")
39
+ print(f"Min region area: {self.min_region_area * 100:.2f}%")
40
+
41
+ def refine(self,
42
+ probability_map: np.ndarray,
43
+ original_size: Tuple[int, int] = None) -> np.ndarray:
44
+ """
45
+ Refine probability map to binary mask
46
+
47
+ Args:
48
+ probability_map: Forgery probability map (H, W), values [0, 1]
49
+ original_size: Optional (H, W) to resize mask back to original
50
+
51
+ Returns:
52
+ Refined binary mask (H, W)
53
+ """
54
+ # Threshold to binary
55
+ binary_mask = (probability_map > self.threshold).astype(np.uint8)
56
+
57
+ # Morphological closing (fill broken strokes)
58
+ closing_kernel = cv2.getStructuringElement(
59
+ cv2.MORPH_RECT,
60
+ (self.closing_kernel, self.closing_kernel)
61
+ )
62
+ binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, closing_kernel)
63
+
64
+ # Morphological opening (remove isolated noise)
65
+ opening_kernel = cv2.getStructuringElement(
66
+ cv2.MORPH_RECT,
67
+ (self.opening_kernel, self.opening_kernel)
68
+ )
69
+ binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, opening_kernel)
70
+
71
+ # Critical Fix #3: Remove small regions with adaptive threshold
72
+ binary_mask = self._remove_small_regions(binary_mask)
73
+
74
+ # Resize to original size if provided
75
+ if original_size is not None:
76
+ binary_mask = cv2.resize(
77
+ binary_mask,
78
+ (original_size[1], original_size[0]), # cv2 uses (W, H)
79
+ interpolation=cv2.INTER_NEAREST
80
+ )
81
+
82
+ return binary_mask
83
+
84
+ def _remove_small_regions(self, mask: np.ndarray) -> np.ndarray:
85
+ """
86
+ Remove regions smaller than minimum area threshold
87
+
88
+ Args:
89
+ mask: Binary mask (H, W)
90
+
91
+ Returns:
92
+ Filtered mask
93
+ """
94
+ # Calculate minimum pixel count
95
+ image_area = mask.shape[0] * mask.shape[1]
96
+ min_pixels = int(image_area * self.min_region_area)
97
+
98
+ # Label connected components
99
+ labeled_mask, num_features = ndimage.label(mask)
100
+
101
+ # Keep only large enough regions
102
+ filtered_mask = np.zeros_like(mask)
103
+
104
+ for region_id in range(1, num_features + 1):
105
+ region_mask = (labeled_mask == region_id)
106
+ region_area = region_mask.sum()
107
+
108
+ if region_area >= min_pixels:
109
+ filtered_mask[region_mask] = 1
110
+
111
+ return filtered_mask
112
+
113
+
114
+ class RegionExtractor:
115
+ """
116
+ Extract individual regions from binary mask
117
+ Implements Critical Fix #4: Region Confidence Aggregation
118
+ """
119
+
120
+ def __init__(self, config, dataset_name: str = 'default'):
121
+ """
122
+ Initialize region extractor
123
+
124
+ Args:
125
+ config: Configuration object
126
+ dataset_name: Dataset name
127
+ """
128
+ self.config = config
129
+ self.dataset_name = dataset_name
130
+ self.min_region_area = config.get_min_region_area(dataset_name)
131
+
132
+ def extract(self,
133
+ binary_mask: np.ndarray,
134
+ probability_map: np.ndarray,
135
+ original_image: np.ndarray) -> List[Dict]:
136
+ """
137
+ Extract regions from binary mask
138
+
139
+ Args:
140
+ binary_mask: Refined binary mask (H, W)
141
+ probability_map: Original probability map (H, W)
142
+ original_image: Original image (H, W, 3)
143
+
144
+ Returns:
145
+ List of region dictionaries with bounding box, mask, image, confidence
146
+ """
147
+ regions = []
148
+
149
+ # Safety check: Ensure probability_map and binary_mask have same dimensions
150
+ if probability_map.shape != binary_mask.shape:
151
+ import cv2
152
+ probability_map = cv2.resize(
153
+ probability_map,
154
+ (binary_mask.shape[1], binary_mask.shape[0]),
155
+ interpolation=cv2.INTER_LINEAR
156
+ )
157
+
158
+ # Connected component analysis (8-connectivity)
159
+ labeled_mask = label(binary_mask, connectivity=2)
160
+ props = regionprops(labeled_mask)
161
+
162
+ for region_id, prop in enumerate(props, start=1):
163
+ # Bounding box
164
+ y_min, x_min, y_max, x_max = prop.bbox
165
+
166
+ # Region mask
167
+ region_mask = (labeled_mask == region_id).astype(np.uint8)
168
+
169
+ # Cropped region image
170
+ region_image = original_image[y_min:y_max, x_min:x_max].copy()
171
+ region_mask_cropped = region_mask[y_min:y_max, x_min:x_max]
172
+
173
+ # Critical Fix #4: Region-level confidence aggregation
174
+ region_probs = probability_map[region_mask > 0]
175
+ region_confidence = float(np.mean(region_probs)) if len(region_probs) > 0 else 0.0
176
+
177
+ regions.append({
178
+ 'region_id': region_id,
179
+ 'bounding_box': [int(x_min), int(y_min),
180
+ int(x_max - x_min), int(y_max - y_min)],
181
+ 'area': prop.area,
182
+ 'centroid': (int(prop.centroid[1]), int(prop.centroid[0])),
183
+ 'region_mask': region_mask,
184
+ 'region_mask_cropped': region_mask_cropped,
185
+ 'region_image': region_image,
186
+ 'confidence': region_confidence,
187
+ 'mask_probability_mean': region_confidence
188
+ })
189
+
190
+ return regions
191
+
192
+ def extract_for_casia(self,
193
+ binary_mask: np.ndarray,
194
+ probability_map: np.ndarray,
195
+ original_image: np.ndarray) -> List[Dict]:
196
+ """
197
+ Critical Fix #6: CASIA handling - treat entire image as one region
198
+
199
+ Args:
200
+ binary_mask: Binary mask (may be empty for authentic images)
201
+ probability_map: Probability map
202
+ original_image: Original image
203
+
204
+ Returns:
205
+ Single region representing entire image
206
+ """
207
+ h, w = original_image.shape[:2]
208
+
209
+ # Create single region covering entire image
210
+ region_mask = np.ones((h, w), dtype=np.uint8)
211
+
212
+ # Overall confidence from probability map
213
+ overall_confidence = float(np.mean(probability_map))
214
+
215
+ return [{
216
+ 'region_id': 1,
217
+ 'bounding_box': [0, 0, w, h],
218
+ 'area': h * w,
219
+ 'centroid': (w // 2, h // 2),
220
+ 'region_mask': region_mask,
221
+ 'region_mask_cropped': region_mask,
222
+ 'region_image': original_image,
223
+ 'confidence': overall_confidence,
224
+ 'mask_probability_mean': overall_confidence
225
+ }]
226
+
227
+
228
+ def get_mask_refiner(config, dataset_name: str = 'default') -> MaskRefiner:
229
+ """Factory function for mask refiner"""
230
+ return MaskRefiner(config, dataset_name)
231
+
232
+
233
+ def get_region_extractor(config, dataset_name: str = 'default') -> RegionExtractor:
234
+ """Factory function for region extractor"""
235
+ return RegionExtractor(config, dataset_name)