JKrishnanandhaa committed on
Commit
53014ca
·
verified ·
1 Parent(s): e003867

Update src/features/feature_extraction.py

Browse files
Files changed (1) hide show
  1. src/features/feature_extraction.py +515 -485
src/features/feature_extraction.py CHANGED
@@ -1,485 +1,515 @@
1
- """
2
- Hybrid feature extraction for forgery detection
3
- Implements Critical Fix #5: Feature Group Gating
4
- """
5
-
6
- import cv2
7
- import numpy as np
8
- import torch
9
- import torch.nn.functional as F
10
- from typing import Dict, List, Optional, Tuple
11
- from scipy import ndimage
12
- from scipy.fftpack import dct
13
- import pywt
14
- from skimage.measure import regionprops, label
15
- from skimage.filters import sobel
16
-
17
-
18
- class DeepFeatureExtractor:
19
- """Extract deep features from decoder feature maps"""
20
-
21
- def __init__(self):
22
- """Initialize deep feature extractor"""
23
- pass
24
-
25
- def extract(self,
26
- decoder_features: List[torch.Tensor],
27
- region_mask: np.ndarray) -> np.ndarray:
28
- """
29
- Extract deep features using Global Average Pooling
30
-
31
- Args:
32
- decoder_features: List of decoder feature tensors
33
- region_mask: Binary region mask (H, W)
34
-
35
- Returns:
36
- Deep feature vector
37
- """
38
- features = []
39
-
40
- for feat in decoder_features:
41
- # Ensure on CPU and numpy
42
- if isinstance(feat, torch.Tensor):
43
- feat = feat.detach().cpu().numpy()
44
-
45
- # feat shape: (B, C, H, W) or (C, H, W)
46
- if feat.ndim == 4:
47
- feat = feat[0] # Take first batch
48
-
49
- # Resize mask to feature size
50
- h, w = feat.shape[1:]
51
- mask_resized = cv2.resize(region_mask.astype(np.float32), (w, h))
52
- mask_resized = mask_resized > 0.5
53
-
54
- # Masked Global Average Pooling
55
- if mask_resized.sum() > 0:
56
- for c in range(feat.shape[0]):
57
- channel_feat = feat[c]
58
- masked_mean = channel_feat[mask_resized].mean()
59
- features.append(masked_mean)
60
- else:
61
- # Fallback: use global average
62
- features.extend(feat.mean(axis=(1, 2)).tolist())
63
-
64
- return np.array(features, dtype=np.float32)
65
-
66
-
67
- class StatisticalFeatureExtractor:
68
- """Extract statistical and shape features from regions"""
69
-
70
- def __init__(self):
71
- """Initialize statistical feature extractor"""
72
- pass
73
-
74
- def extract(self,
75
- image: np.ndarray,
76
- region_mask: np.ndarray) -> np.ndarray:
77
- """
78
- Extract statistical and shape features
79
-
80
- Args:
81
- image: Input image (H, W, 3) normalized [0, 1]
82
- region_mask: Binary region mask (H, W)
83
-
84
- Returns:
85
- Statistical feature vector
86
- """
87
- features = []
88
-
89
- # Label the mask
90
- labeled_mask = label(region_mask)
91
- props = regionprops(labeled_mask)
92
-
93
- if len(props) > 0:
94
- prop = props[0]
95
-
96
- # Area and perimeter
97
- features.append(prop.area)
98
- features.append(prop.perimeter)
99
-
100
- # Aspect ratio
101
- if prop.major_axis_length > 0:
102
- aspect_ratio = prop.minor_axis_length / prop.major_axis_length
103
- else:
104
- aspect_ratio = 1.0
105
- features.append(aspect_ratio)
106
-
107
- # Solidity
108
- features.append(prop.solidity)
109
-
110
- # Eccentricity
111
- features.append(prop.eccentricity)
112
-
113
- # Entropy (using intensity)
114
- if len(image.shape) == 3:
115
- gray = cv2.cvtColor((image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
116
- else:
117
- gray = (image * 255).astype(np.uint8)
118
-
119
- region_pixels = gray[region_mask > 0]
120
- if len(region_pixels) > 0:
121
- hist, _ = np.histogram(region_pixels, bins=256, range=(0, 256))
122
- hist = hist / hist.sum() + 1e-8
123
- entropy = -np.sum(hist * np.log2(hist + 1e-8))
124
- else:
125
- entropy = 0.0
126
- features.append(entropy)
127
- else:
128
- # Default values
129
- features.extend([0, 0, 1.0, 0, 0, 0])
130
-
131
- return np.array(features, dtype=np.float32)
132
-
133
-
134
- class FrequencyFeatureExtractor:
135
- """Extract frequency-domain features"""
136
-
137
- def __init__(self):
138
- """Initialize frequency feature extractor"""
139
- pass
140
-
141
- def extract(self,
142
- image: np.ndarray,
143
- region_mask: np.ndarray) -> np.ndarray:
144
- """
145
- Extract frequency-domain features (DCT, wavelet)
146
-
147
- Args:
148
- image: Input image (H, W, 3) normalized [0, 1]
149
- region_mask: Binary region mask (H, W)
150
-
151
- Returns:
152
- Frequency feature vector
153
- """
154
- features = []
155
-
156
- # Convert to grayscale
157
- if len(image.shape) == 3:
158
- gray = cv2.cvtColor((image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
159
- else:
160
- gray = (image * 255).astype(np.uint8)
161
-
162
- # Get region bounding box
163
- coords = np.where(region_mask > 0)
164
- if len(coords[0]) == 0:
165
- return np.zeros(13, dtype=np.float32)
166
-
167
- y_min, y_max = coords[0].min(), coords[0].max()
168
- x_min, x_max = coords[1].min(), coords[1].max()
169
-
170
- # Crop region
171
- region = gray[y_min:y_max+1, x_min:x_max+1].astype(np.float32)
172
-
173
- if region.size == 0:
174
- return np.zeros(13, dtype=np.float32)
175
-
176
- # DCT coefficients
177
- try:
178
- dct_coeffs = dct(dct(region, axis=0, norm='ortho'), axis=1, norm='ortho')
179
-
180
- # Mean and std of DCT coefficients
181
- features.append(np.mean(np.abs(dct_coeffs)))
182
- features.append(np.std(dct_coeffs))
183
-
184
- # High-frequency energy (bottom-right quadrant)
185
- h, w = dct_coeffs.shape
186
- high_freq = dct_coeffs[h//2:, w//2:]
187
- features.append(np.sum(np.abs(high_freq)) / (high_freq.size + 1e-8))
188
- except Exception:
189
- features.extend([0, 0, 0])
190
-
191
- # Wavelet features
192
- try:
193
- coeffs = pywt.dwt2(region, 'db1')
194
- cA, (cH, cV, cD) = coeffs
195
-
196
- # Energy in each sub-band
197
- features.append(np.sum(cA ** 2) / (cA.size + 1e-8))
198
- features.append(np.sum(cH ** 2) / (cH.size + 1e-8))
199
- features.append(np.sum(cV ** 2) / (cV.size + 1e-8))
200
- features.append(np.sum(cD ** 2) / (cD.size + 1e-8))
201
-
202
- # Wavelet entropy
203
- for coeff in [cH, cV, cD]:
204
- coeff_flat = np.abs(coeff.flatten())
205
- if coeff_flat.sum() > 0:
206
- coeff_norm = coeff_flat / coeff_flat.sum()
207
- entropy = -np.sum(coeff_norm * np.log2(coeff_norm + 1e-8))
208
- else:
209
- entropy = 0.0
210
- features.append(entropy)
211
- except Exception:
212
- features.extend([0, 0, 0, 0, 0, 0, 0])
213
-
214
- return np.array(features, dtype=np.float32)
215
-
216
-
217
- class NoiseELAFeatureExtractor:
218
- """Extract noise and Error Level Analysis features"""
219
-
220
- def __init__(self, quality: int = 90):
221
- """
222
- Initialize noise/ELA extractor
223
-
224
- Args:
225
- quality: JPEG quality for ELA
226
- """
227
- self.quality = quality
228
-
229
- def extract(self,
230
- image: np.ndarray,
231
- region_mask: np.ndarray) -> np.ndarray:
232
- """
233
- Extract noise and ELA features
234
-
235
- Args:
236
- image: Input image (H, W, 3) normalized [0, 1]
237
- region_mask: Binary region mask (H, W)
238
-
239
- Returns:
240
- Noise/ELA feature vector
241
- """
242
- features = []
243
-
244
- # Convert to uint8
245
- img_uint8 = (image * 255).astype(np.uint8)
246
-
247
- # Error Level Analysis
248
- # Compress and compute difference
249
- encode_param = [cv2.IMWRITE_JPEG_QUALITY, self.quality]
250
- _, encoded = cv2.imencode('.jpg', img_uint8, encode_param)
251
- recompressed = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
252
-
253
- ela = np.abs(img_uint8.astype(np.float32) - recompressed.astype(np.float32))
254
-
255
- # ELA features within region
256
- ela_region = ela[region_mask > 0]
257
- if len(ela_region) > 0:
258
- features.append(np.mean(ela_region)) # ELA mean
259
- features.append(np.var(ela_region)) # ELA variance
260
- features.append(np.max(ela_region)) # ELA max
261
- else:
262
- features.extend([0, 0, 0])
263
-
264
- # Noise residual (using median filter)
265
- if len(image.shape) == 3:
266
- gray = cv2.cvtColor(img_uint8, cv2.COLOR_RGB2GRAY)
267
- else:
268
- gray = img_uint8
269
-
270
- median_filtered = cv2.medianBlur(gray, 3)
271
- noise_residual = np.abs(gray.astype(np.float32) - median_filtered.astype(np.float32))
272
-
273
- residual_region = noise_residual[region_mask > 0]
274
- if len(residual_region) > 0:
275
- features.append(np.mean(residual_region))
276
- features.append(np.var(residual_region))
277
- else:
278
- features.extend([0, 0])
279
-
280
- return np.array(features, dtype=np.float32)
281
-
282
-
283
- class OCRFeatureExtractor:
284
- """
285
- Extract OCR-based consistency features
286
- Only for text documents (Feature Gating - Critical Fix #5)
287
- """
288
-
289
- def __init__(self):
290
- """Initialize OCR feature extractor"""
291
- self.ocr_available = False
292
-
293
- try:
294
- import easyocr
295
- self.reader = easyocr.Reader(['en'], gpu=True)
296
- self.ocr_available = True
297
- except Exception:
298
- print("Warning: EasyOCR not available, OCR features disabled")
299
-
300
- def extract(self,
301
- image: np.ndarray,
302
- region_mask: np.ndarray) -> np.ndarray:
303
- """
304
- Extract OCR consistency features
305
-
306
- Args:
307
- image: Input image (H, W, 3) normalized [0, 1]
308
- region_mask: Binary region mask (H, W)
309
-
310
- Returns:
311
- OCR feature vector (or zeros if not text document)
312
- """
313
- features = []
314
-
315
- if not self.ocr_available:
316
- return np.zeros(6, dtype=np.float32)
317
-
318
- # Convert to uint8
319
- img_uint8 = (image * 255).astype(np.uint8)
320
-
321
- # Get region bounding box
322
- coords = np.where(region_mask > 0)
323
- if len(coords[0]) == 0:
324
- return np.zeros(6, dtype=np.float32)
325
-
326
- y_min, y_max = coords[0].min(), coords[0].max()
327
- x_min, x_max = coords[1].min(), coords[1].max()
328
-
329
- # Crop region
330
- region = img_uint8[y_min:y_max+1, x_min:x_max+1]
331
-
332
- try:
333
- # OCR on region
334
- results = self.reader.readtext(region)
335
-
336
- if len(results) > 0:
337
- # Confidence deviation
338
- confidences = [r[2] for r in results]
339
- features.append(np.mean(confidences))
340
- features.append(np.std(confidences))
341
-
342
- # Character spacing analysis
343
- bbox_widths = [abs(r[0][1][0] - r[0][0][0]) for r in results]
344
- if len(bbox_widths) > 1:
345
- features.append(np.std(bbox_widths) / (np.mean(bbox_widths) + 1e-8))
346
- else:
347
- features.append(0.0)
348
-
349
- # Text density
350
- features.append(len(results) / (region.shape[0] * region.shape[1] + 1e-8))
351
-
352
- # Stroke width variation (using edge detection)
353
- gray_region = cv2.cvtColor(region, cv2.COLOR_RGB2GRAY)
354
- edges = sobel(gray_region)
355
- features.append(np.mean(edges))
356
- features.append(np.std(edges))
357
- else:
358
- features.extend([0, 0, 0, 0, 0, 0])
359
- except Exception:
360
- features.extend([0, 0, 0, 0, 0, 0])
361
-
362
- return np.array(features, dtype=np.float32)
363
-
364
-
365
- class HybridFeatureExtractor:
366
- """
367
- Complete hybrid feature extraction
368
- Implements Critical Fix #5: Feature Group Gating
369
- """
370
-
371
- def __init__(self, config, is_text_document: bool = True):
372
- """
373
- Initialize hybrid feature extractor
374
-
375
- Args:
376
- config: Configuration object
377
- is_text_document: Whether input is text document (for OCR gating)
378
- """
379
- self.config = config
380
- self.is_text_document = is_text_document
381
-
382
- # Initialize extractors
383
- self.deep_extractor = DeepFeatureExtractor()
384
- self.stat_extractor = StatisticalFeatureExtractor()
385
- self.freq_extractor = FrequencyFeatureExtractor()
386
- self.noise_extractor = NoiseELAFeatureExtractor()
387
-
388
- # Critical Fix #5: OCR only for text documents
389
- if is_text_document and config.get('features.ocr.enabled', True):
390
- self.ocr_extractor = OCRFeatureExtractor()
391
- else:
392
- self.ocr_extractor = None
393
-
394
- def extract(self,
395
- image: np.ndarray,
396
- region_mask: np.ndarray,
397
- decoder_features: Optional[List[torch.Tensor]] = None) -> np.ndarray:
398
- """
399
- Extract all hybrid features for a region
400
-
401
- Args:
402
- image: Input image (H, W, 3) normalized [0, 1]
403
- region_mask: Binary region mask (H, W)
404
- decoder_features: Optional decoder features for deep feature extraction
405
-
406
- Returns:
407
- Concatenated feature vector
408
- """
409
- all_features = []
410
-
411
- # Deep features (if available)
412
- if decoder_features is not None and self.config.get('features.deep.enabled', True):
413
- deep_feats = self.deep_extractor.extract(decoder_features, region_mask)
414
- all_features.append(deep_feats)
415
-
416
- # Statistical & shape features
417
- if self.config.get('features.statistical.enabled', True):
418
- stat_feats = self.stat_extractor.extract(image, region_mask)
419
- all_features.append(stat_feats)
420
-
421
- # Frequency-domain features
422
- if self.config.get('features.frequency.enabled', True):
423
- freq_feats = self.freq_extractor.extract(image, region_mask)
424
- all_features.append(freq_feats)
425
-
426
- # Noise & ELA features
427
- if self.config.get('features.noise.enabled', True):
428
- noise_feats = self.noise_extractor.extract(image, region_mask)
429
- all_features.append(noise_feats)
430
-
431
- # Critical Fix #5: OCR features only for text documents
432
- if self.ocr_extractor is not None:
433
- ocr_feats = self.ocr_extractor.extract(image, region_mask)
434
- all_features.append(ocr_feats)
435
-
436
- # Concatenate all features
437
- if len(all_features) > 0:
438
- features = np.concatenate(all_features)
439
- else:
440
- features = np.array([], dtype=np.float32)
441
-
442
- # Handle NaN/Inf
443
- features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
444
-
445
- return features
446
-
447
- def get_feature_names(self) -> List[str]:
448
- """Get list of feature names for interpretability"""
449
- names = []
450
-
451
- if self.config.get('features.deep.enabled', True):
452
- names.extend([f'deep_{i}' for i in range(256)]) # Approximate
453
-
454
- if self.config.get('features.statistical.enabled', True):
455
- names.extend(['area', 'perimeter', 'aspect_ratio',
456
- 'solidity', 'eccentricity', 'entropy'])
457
-
458
- if self.config.get('features.frequency.enabled', True):
459
- names.extend(['dct_mean', 'dct_std', 'high_freq_energy',
460
- 'wavelet_cA', 'wavelet_cH', 'wavelet_cV', 'wavelet_cD',
461
- 'wavelet_entropy_H', 'wavelet_entropy_V', 'wavelet_entropy_D'])
462
-
463
- if self.config.get('features.noise.enabled', True):
464
- names.extend(['ela_mean', 'ela_var', 'ela_max',
465
- 'noise_residual_mean', 'noise_residual_var'])
466
-
467
- if self.ocr_extractor is not None:
468
- names.extend(['ocr_conf_mean', 'ocr_conf_std', 'spacing_irregularity',
469
- 'text_density', 'stroke_mean', 'stroke_std'])
470
-
471
- return names
472
-
473
-
474
- def get_feature_extractor(config, is_text_document: bool = True) -> HybridFeatureExtractor:
475
- """
476
- Factory function to create feature extractor
477
-
478
- Args:
479
- config: Configuration object
480
- is_text_document: Whether input is text document
481
-
482
- Returns:
483
- HybridFeatureExtractor instance
484
- """
485
- return HybridFeatureExtractor(config, is_text_document)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hybrid feature extraction for forgery detection
3
+ Implements Critical Fix #5: Feature Group Gating
4
+ """
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from typing import Dict, List, Optional, Tuple
11
+ from scipy import ndimage
12
+ from scipy.fftpack import dct
13
+ import pywt
14
+ from skimage.measure import regionprops, label
15
+ from skimage.filters import sobel
16
+
17
+
18
class DeepFeatureExtractor:
    """Extract deep features from decoder feature maps via masked pooling."""

    def __init__(self):
        """Initialize deep feature extractor (stateless)."""
        pass

    def extract(self,
                decoder_features: List[torch.Tensor],
                region_mask: np.ndarray) -> np.ndarray:
        """
        Extract deep features using masked Global Average Pooling.

        Args:
            decoder_features: List of decoder feature tensors, each
                (B, C, H, W) or (C, H, W)
            region_mask: Binary region mask (H, W)

        Returns:
            Deep feature vector (one value per channel per feature map)
        """
        pooled = []

        for fmap in decoder_features:
            # Move to CPU numpy if it is still a tensor
            if isinstance(fmap, torch.Tensor):
                fmap = fmap.detach().cpu().numpy()

            # Drop the batch dimension if present: (B, C, H, W) -> (C, H, W)
            if fmap.ndim == 4:
                fmap = fmap[0]

            # Bring the mask down/up to this feature map's spatial resolution
            fh, fw = fmap.shape[1:]
            resized = cv2.resize(region_mask.astype(np.float32), (fw, fh)) > 0.5

            if resized.sum() > 0:
                # Masked GAP: per-channel mean over the masked pixels only
                pooled.extend(fmap[:, resized].mean(axis=1).tolist())
            else:
                # Mask vanished at this scale: fall back to full-map GAP
                pooled.extend(fmap.mean(axis=(1, 2)).tolist())

        return np.array(pooled, dtype=np.float32)
65
+
66
+
67
class StatisticalFeatureExtractor:
    """Extract statistical and shape features from regions"""

    def __init__(self):
        """Initialize statistical feature extractor (stateless)"""
        pass

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray) -> np.ndarray:
        """
        Extract statistical and shape features.

        Args:
            image: Input image (H, W, 3) normalized [0, 1]
            region_mask: Binary region mask (H, W)

        Returns:
            6-dim feature vector: [area, perimeter, aspect_ratio,
            solidity, eccentricity, entropy]
        """
        features = []

        # Label connected components and describe the region.
        # NOTE(review): props[0] is the first labeled component, not
        # necessarily the largest -- confirm masks are single-region.
        labeled_mask = label(region_mask)
        props = regionprops(labeled_mask)

        if len(props) > 0:
            prop = props[0]

            # Area and perimeter
            features.append(prop.area)
            features.append(prop.perimeter)

            # Aspect ratio (minor/major, guarded against a zero-length axis)
            if prop.major_axis_length > 0:
                aspect_ratio = prop.minor_axis_length / prop.major_axis_length
            else:
                aspect_ratio = 1.0
            features.append(aspect_ratio)

            # Solidity
            features.append(prop.solidity)

            # Eccentricity
            features.append(prop.eccentricity)

            # Intensity entropy inside the region
            if len(image.shape) == 3:
                gray = cv2.cvtColor((image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
            else:
                gray = (image * 255).astype(np.uint8)

            # Resize region_mask to match gray image dimensions
            if region_mask.shape != gray.shape:
                region_mask_resized = cv2.resize(
                    region_mask.astype(np.uint8),
                    (gray.shape[1], gray.shape[0]),
                    interpolation=cv2.INTER_NEAREST
                )
            else:
                region_mask_resized = region_mask

            region_pixels = gray[region_mask_resized > 0]
            if len(region_pixels) > 0:
                hist, _ = np.histogram(region_pixels, bins=256, range=(0, 256))
                # Fix: parenthesize the denominator so hist is a proper
                # probability distribution; the old code computed
                # `hist / hist.sum() + 1e-8`, adding the epsilon to every
                # bin instead of guarding the division.
                hist = hist / (hist.sum() + 1e-8)
                entropy = -np.sum(hist * np.log2(hist + 1e-8))
            else:
                entropy = 0.0
            features.append(entropy)
        else:
            # No region found: neutral defaults for all six features
            features.extend([0, 0, 1.0, 0, 0, 0])

        return np.array(features, dtype=np.float32)
142
+
143
+
144
class FrequencyFeatureExtractor:
    """Extract frequency-domain features"""

    def __init__(self):
        """Initialize frequency feature extractor (stateless)"""
        pass

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray) -> np.ndarray:
        """
        Extract frequency-domain features (DCT, wavelet).

        Args:
            image: Input image (H, W, 3) normalized [0, 1]
            region_mask: Binary region mask (H, W)

        Returns:
            10-dim feature vector: 3 DCT features (mean, std, high-frequency
            energy) + 7 wavelet features (4 sub-band energies, 3 entropies)
        """
        features = []

        # Convert to grayscale
        if len(image.shape) == 3:
            gray = cv2.cvtColor((image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
        else:
            gray = (image * 255).astype(np.uint8)

        # Get region bounding box
        coords = np.where(region_mask > 0)
        if len(coords[0]) == 0:
            # Fix: this method produces 3 + 7 = 10 features (matching the 10
            # frequency names in get_feature_names), but the old early return
            # emitted 13 zeros, breaking feature-dimension consistency.
            return np.zeros(10, dtype=np.float32)

        y_min, y_max = coords[0].min(), coords[0].max()
        x_min, x_max = coords[1].min(), coords[1].max()

        # Crop region
        region = gray[y_min:y_max+1, x_min:x_max+1].astype(np.float32)

        if region.size == 0:
            # Fix: same dimension bug as above (was np.zeros(13))
            return np.zeros(10, dtype=np.float32)

        # DCT coefficients (2-D type-II DCT via two 1-D passes)
        try:
            dct_coeffs = dct(dct(region, axis=0, norm='ortho'), axis=1, norm='ortho')

            # Mean and std of DCT coefficients
            features.append(np.mean(np.abs(dct_coeffs)))
            features.append(np.std(dct_coeffs))

            # High-frequency energy (bottom-right quadrant)
            h, w = dct_coeffs.shape
            high_freq = dct_coeffs[h//2:, w//2:]
            features.append(np.sum(np.abs(high_freq)) / (high_freq.size + 1e-8))
        except Exception:
            features.extend([0, 0, 0])

        # Wavelet features (single-level Haar decomposition)
        try:
            coeffs = pywt.dwt2(region, 'db1')
            cA, (cH, cV, cD) = coeffs

            # Energy in each sub-band
            features.append(np.sum(cA ** 2) / (cA.size + 1e-8))
            features.append(np.sum(cH ** 2) / (cH.size + 1e-8))
            features.append(np.sum(cV ** 2) / (cV.size + 1e-8))
            features.append(np.sum(cD ** 2) / (cD.size + 1e-8))

            # Wavelet entropy of each detail sub-band
            for coeff in [cH, cV, cD]:
                coeff_flat = np.abs(coeff.flatten())
                if coeff_flat.sum() > 0:
                    coeff_norm = coeff_flat / coeff_flat.sum()
                    entropy = -np.sum(coeff_norm * np.log2(coeff_norm + 1e-8))
                else:
                    entropy = 0.0
                features.append(entropy)
        except Exception:
            features.extend([0, 0, 0, 0, 0, 0, 0])

        return np.array(features, dtype=np.float32)
225
+
226
+
227
class NoiseELAFeatureExtractor:
    """Extract noise and Error Level Analysis features"""

    def __init__(self, quality: int = 90):
        """
        Initialize noise/ELA extractor

        Args:
            quality: JPEG quality (0-100) used for the ELA recompression
        """
        self.quality = quality

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray) -> np.ndarray:
        """
        Extract noise and ELA features.

        Args:
            image: Input image (H, W, 3) RGB normalized [0, 1]
                (a 2-D grayscale image is also accepted)
            region_mask: Binary region mask (H, W)

        Returns:
            5-dim feature vector: [ela_mean, ela_var, ela_max,
            noise_residual_mean, noise_residual_var]
        """
        features = []

        # Convert to uint8
        img_uint8 = (image * 255).astype(np.uint8)

        # Error Level Analysis: JPEG-recompress and measure the difference.
        # Fix: cv2.imencode/imdecode operate in BGR while this pipeline is
        # RGB (see COLOR_RGB2GRAY below), so convert for the round-trip;
        # the old code diffed an RGB image against a BGR decode, swapping
        # R/B channels and inflating ELA on colored regions.
        encode_param = [cv2.IMWRITE_JPEG_QUALITY, self.quality]
        if img_uint8.ndim == 3:
            _, encoded = cv2.imencode(
                '.jpg', cv2.cvtColor(img_uint8, cv2.COLOR_RGB2BGR), encode_param)
            recompressed = cv2.cvtColor(
                cv2.imdecode(encoded, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        else:
            # Fix: decode grayscale as grayscale; IMREAD_COLOR returned a
            # 3-channel array that could not be subtracted from a 2-D input
            _, encoded = cv2.imencode('.jpg', img_uint8, encode_param)
            recompressed = cv2.imdecode(encoded, cv2.IMREAD_GRAYSCALE)

        ela = np.abs(img_uint8.astype(np.float32) - recompressed.astype(np.float32))

        # ELA features within region
        # Resize region_mask to match ela dimensions
        if region_mask.shape[:2] != ela.shape[:2]:
            mask_resized = cv2.resize(
                region_mask.astype(np.uint8),
                (ela.shape[1], ela.shape[0]),
                interpolation=cv2.INTER_NEAREST
            )
        else:
            mask_resized = region_mask

        ela_region = ela[mask_resized > 0]
        if len(ela_region) > 0:
            features.append(np.mean(ela_region))  # ELA mean
            features.append(np.var(ela_region))   # ELA variance
            features.append(np.max(ela_region))   # ELA max
        else:
            features.extend([0, 0, 0])

        # Noise residual: deviation from a 3x3 median-filtered version
        if len(image.shape) == 3:
            gray = cv2.cvtColor(img_uint8, cv2.COLOR_RGB2GRAY)
        else:
            gray = img_uint8

        median_filtered = cv2.medianBlur(gray, 3)
        noise_residual = np.abs(gray.astype(np.float32) - median_filtered.astype(np.float32))

        # Resize region_mask to match noise_residual dimensions
        if region_mask.shape != noise_residual.shape:
            mask_resized = cv2.resize(
                region_mask.astype(np.uint8),
                (noise_residual.shape[1], noise_residual.shape[0]),
                interpolation=cv2.INTER_NEAREST
            )
        else:
            mask_resized = region_mask

        residual_region = noise_residual[mask_resized > 0]
        if len(residual_region) > 0:
            features.append(np.mean(residual_region))
            features.append(np.var(residual_region))
        else:
            features.extend([0, 0])

        return np.array(features, dtype=np.float32)
311
+
312
+
313
class OCRFeatureExtractor:
    """
    Extract OCR-based consistency features
    Only for text documents (Feature Gating - Critical Fix #5)
    """

    def __init__(self):
        """Initialize OCR feature extractor; OCR is disabled if EasyOCR is missing."""
        self.ocr_available = False

        try:
            import easyocr
            self.reader = easyocr.Reader(['en'], gpu=True)
            self.ocr_available = True
        except Exception:
            print("Warning: EasyOCR not available, OCR features disabled")

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray) -> np.ndarray:
        """
        Extract OCR consistency features.

        Args:
            image: Input image (H, W, 3) normalized [0, 1]
            region_mask: Binary region mask (H, W)

        Returns:
            6-dim feature vector: [ocr_conf_mean, ocr_conf_std,
            spacing_irregularity, text_density, stroke_mean, stroke_std]
            (all zeros when OCR is unavailable, the region is empty,
            no text is found, or OCR fails)
        """
        if not self.ocr_available:
            return np.zeros(6, dtype=np.float32)

        # Convert to uint8
        img_uint8 = (image * 255).astype(np.uint8)

        # Get region bounding box
        coords = np.where(region_mask > 0)
        if len(coords[0]) == 0:
            return np.zeros(6, dtype=np.float32)

        y_min, y_max = coords[0].min(), coords[0].max()
        x_min, x_max = coords[1].min(), coords[1].max()

        # Crop region
        region = img_uint8[y_min:y_max+1, x_min:x_max+1]

        try:
            # OCR on region
            results = self.reader.readtext(region)

            if len(results) == 0:
                return np.zeros(6, dtype=np.float32)

            features = []

            # Confidence deviation
            confidences = [r[2] for r in results]
            features.append(np.mean(confidences))
            features.append(np.std(confidences))

            # Character spacing analysis (coefficient of variation of widths)
            bbox_widths = [abs(r[0][1][0] - r[0][0][0]) for r in results]
            if len(bbox_widths) > 1:
                features.append(np.std(bbox_widths) / (np.mean(bbox_widths) + 1e-8))
            else:
                features.append(0.0)

            # Text density (detections per pixel)
            features.append(len(results) / (region.shape[0] * region.shape[1] + 1e-8))

            # Stroke width variation (using edge detection)
            gray_region = cv2.cvtColor(region, cv2.COLOR_RGB2GRAY)
            edges = sobel(gray_region)
            features.append(np.mean(edges))
            features.append(np.std(edges))
        except Exception:
            # Fix: the old handler extended six zeros onto a possibly
            # partially-filled list, producing a wrong-length vector when a
            # failure occurred mid-extraction; return a clean zero vector.
            return np.zeros(6, dtype=np.float32)

        return np.array(features, dtype=np.float32)
393
+
394
+
395
class HybridFeatureExtractor:
    """
    Complete hybrid feature extraction.
    Implements Critical Fix #5: Feature Group Gating.
    """

    def __init__(self, config, is_text_document: bool = True):
        """
        Initialize hybrid feature extractor.

        Args:
            config: Configuration object
            is_text_document: Whether input is text document (for OCR gating)
        """
        self.config = config
        self.is_text_document = is_text_document

        # One extractor per feature group
        self.deep_extractor = DeepFeatureExtractor()
        self.stat_extractor = StatisticalFeatureExtractor()
        self.freq_extractor = FrequencyFeatureExtractor()
        self.noise_extractor = NoiseELAFeatureExtractor()

        # Critical Fix #5: OCR only for text documents
        ocr_enabled = is_text_document and config.get('features.ocr.enabled', True)
        self.ocr_extractor = OCRFeatureExtractor() if ocr_enabled else None

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray,
                decoder_features: Optional[List[torch.Tensor]] = None) -> np.ndarray:
        """
        Extract all hybrid features for a region.

        Args:
            image: Input image (H, W, 3) normalized [0, 1]
            region_mask: Binary region mask (H, W)
            decoder_features: Optional decoder features for deep feature extraction

        Returns:
            Concatenated feature vector with NaN/Inf replaced by 0
        """
        groups = []

        # Deep features (only when decoder features were supplied)
        if decoder_features is not None and self.config.get('features.deep.enabled', True):
            groups.append(self.deep_extractor.extract(decoder_features, region_mask))

        # Statistical & shape features
        if self.config.get('features.statistical.enabled', True):
            groups.append(self.stat_extractor.extract(image, region_mask))

        # Frequency-domain features
        if self.config.get('features.frequency.enabled', True):
            groups.append(self.freq_extractor.extract(image, region_mask))

        # Noise & ELA features
        if self.config.get('features.noise.enabled', True):
            groups.append(self.noise_extractor.extract(image, region_mask))

        # Critical Fix #5: OCR features only for text documents
        if self.ocr_extractor is not None:
            groups.append(self.ocr_extractor.extract(image, region_mask))

        vector = np.concatenate(groups) if groups else np.array([], dtype=np.float32)

        # Sanitize: downstream classifiers cannot digest NaN/Inf
        return np.nan_to_num(vector, nan=0.0, posinf=0.0, neginf=0.0)

    def get_feature_names(self) -> List[str]:
        """Get list of feature names for interpretability"""
        names: List[str] = []

        if self.config.get('features.deep.enabled', True):
            names += [f'deep_{i}' for i in range(256)]  # Approximate

        if self.config.get('features.statistical.enabled', True):
            names += ['area', 'perimeter', 'aspect_ratio',
                      'solidity', 'eccentricity', 'entropy']

        if self.config.get('features.frequency.enabled', True):
            names += ['dct_mean', 'dct_std', 'high_freq_energy',
                      'wavelet_cA', 'wavelet_cH', 'wavelet_cV', 'wavelet_cD',
                      'wavelet_entropy_H', 'wavelet_entropy_V', 'wavelet_entropy_D']

        if self.config.get('features.noise.enabled', True):
            names += ['ela_mean', 'ela_var', 'ela_max',
                      'noise_residual_mean', 'noise_residual_var']

        if self.ocr_extractor is not None:
            names += ['ocr_conf_mean', 'ocr_conf_std', 'spacing_irregularity',
                      'text_density', 'stroke_mean', 'stroke_std']

        return names
502
+
503
+
504
def get_feature_extractor(config, is_text_document: bool = True) -> HybridFeatureExtractor:
    """
    Factory function to create a feature extractor.

    Args:
        config: Configuration object
        is_text_document: Whether input is text document

    Returns:
        HybridFeatureExtractor instance
    """
    extractor = HybridFeatureExtractor(config, is_text_document=is_text_document)
    return extractor