| """ |
| Hybrid feature extraction for forgery detection |
| Implements Critical Fix #5: Feature Group Gating |
| """ |
|
|
| import cv2 |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from typing import Dict, List, Optional, Tuple |
| from scipy import ndimage |
| from scipy.fftpack import dct |
| import pywt |
| from skimage.measure import regionprops, label |
| from skimage.filters import sobel |
|
|
|
|
class DeepFeatureExtractor:
    """Pool decoder feature maps into a per-region deep descriptor."""

    def __init__(self):
        """Initialize deep feature extractor (holds no state)."""

    def extract(self,
                decoder_features: List[torch.Tensor],
                region_mask: np.ndarray) -> np.ndarray:
        """
        Extract deep features using masked Global Average Pooling.

        For each feature map, the region mask is brought to the map's spatial
        resolution (bilinear resize, thresholded at 0.5) and every channel is
        averaged over the masked pixels. If the mask vanishes at that
        resolution, the whole map is averaged instead.

        Args:
            decoder_features: Decoder feature maps as torch tensors or numpy
                arrays, shaped (1, C, H, W), (C, H, W) or (H, W). Only the
                first batch item of a 4-D map is used; a 2-D map is treated
                as a single channel.
            region_mask: Binary region mask (H, W)

        Returns:
            float32 vector with one value per channel, concatenated over the
            feature maps in the order given (empty if no maps are passed).
        """
        features = []
        mask_f32 = region_mask.astype(np.float32)

        for feat in decoder_features:
            if isinstance(feat, torch.Tensor):
                feat = feat.detach().cpu().numpy()

            if feat.ndim == 4:
                # Only the first item of the batch is pooled.
                feat = feat[0]
            elif feat.ndim == 2:
                # A bare (H, W) map is pooled as one channel.
                feat = feat[np.newaxis]

            h, w = feat.shape[1:]
            if mask_f32.shape == (h, w):
                # Same resolution: a resize would only copy the mask.
                mask_resized = mask_f32 > 0.5
            else:
                mask_resized = cv2.resize(mask_f32, (w, h)) > 0.5

            if mask_resized.any():
                # feat[:, mask] is (C, N); average the N masked pixels per channel.
                features.extend(feat[:, mask_resized].mean(axis=1).tolist())
            else:
                # Region disappeared at this resolution: fall back to plain GAP.
                features.extend(feat.mean(axis=(1, 2)).tolist())

        return np.array(features, dtype=np.float32)
|
|
|
|
class StatisticalFeatureExtractor:
    """Extract statistical and shape features from regions"""

    # Returned for an empty mask, in output order:
    # area, perimeter, aspect_ratio, solidity, eccentricity, entropy.
    _EMPTY_FEATURES = (0.0, 0.0, 1.0, 0.0, 0.0, 0.0)

    def __init__(self):
        """Initialize statistical feature extractor (holds no state)."""

    @staticmethod
    def _histogram_entropy(pixels: np.ndarray) -> float:
        """
        Shannon entropy (bits) of the 256-bin intensity histogram of ``pixels``.

        Empty bins are skipped instead of being padded with an epsilon, so a
        constant region has entropy exactly 0 and the result is unbiased.
        """
        if pixels.size == 0:
            return 0.0
        hist, _ = np.histogram(pixels, bins=256, range=(0, 256))
        counts = hist[hist > 0]
        probs = counts / counts.sum()
        return float(-np.sum(probs * np.log2(probs)))

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray) -> np.ndarray:
        """
        Extract statistical and shape features

        All foreground (> 0) pixels of the mask are treated as one region, so
        the shape descriptors and the intensity entropy describe the same
        pixels.

        Args:
            image: Input image (H, W, 3) or (H, W) normalized [0, 1]
            region_mask: Binary region mask (H, W)

        Returns:
            float32 vector [area, perimeter, aspect_ratio (minor/major axis),
            solidity, eccentricity, entropy]; [0, 0, 1, 0, 0, 0] when the
            mask is empty.
        """
        foreground = np.asarray(region_mask) > 0
        if not foreground.any():
            return np.array(self._EMPTY_FEATURES, dtype=np.float32)

        # A single label covering every foreground pixel (rather than the
        # first connected component only).
        prop = regionprops(foreground.astype(np.uint8))[0]

        if prop.major_axis_length > 0:
            aspect_ratio = prop.minor_axis_length / prop.major_axis_length
        else:
            aspect_ratio = 1.0

        if image.ndim == 3:
            gray = cv2.cvtColor((image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
        else:
            gray = (image * 255).astype(np.uint8)

        if foreground.shape != gray.shape:
            # Nearest-neighbour keeps the resized mask binary.
            foreground = cv2.resize(
                foreground.astype(np.uint8),
                (gray.shape[1], gray.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            ) > 0

        entropy = self._histogram_entropy(gray[foreground])

        return np.array(
            [prop.area, prop.perimeter, aspect_ratio,
             prop.solidity, prop.eccentricity, entropy],
            dtype=np.float32,
        )
|
|
|
|
class FrequencyFeatureExtractor:
    """Extract frequency-domain features"""

    # 3 DCT statistics + 4 wavelet sub-band energies + 3 detail-band entropies.
    NUM_FEATURES = 10

    def __init__(self):
        """Initialize frequency feature extractor (holds no state)."""

    @staticmethod
    def _dct_features(region: np.ndarray) -> List[float]:
        """
        Return [mean |coef|, std, mean |coef| of the high-frequency quadrant]
        of the orthonormal 2-D DCT of ``region``; zeros if the DCT fails.
        """
        try:
            coeffs = dct(dct(region, axis=0, norm='ortho'), axis=1, norm='ortho')
            h, w = coeffs.shape
            high_freq = coeffs[h // 2:, w // 2:]
            return [
                float(np.mean(np.abs(coeffs))),
                float(np.std(coeffs)),
                float(np.sum(np.abs(high_freq)) / (high_freq.size + 1e-8)),
            ]
        except Exception:
            # Best effort: a failed transform yields neutral values.
            return [0.0, 0.0, 0.0]

    @staticmethod
    def _wavelet_features(region: np.ndarray) -> List[float]:
        """
        Return the mean energies of the single-level 'db1' DWT sub-bands
        (cA, cH, cV, cD) followed by the entropies of the normalized |cH|,
        |cV|, |cD| coefficients; zeros if the transform fails.
        """
        try:
            c_a, (c_h, c_v, c_d) = pywt.dwt2(region, 'db1')
            energies = [float(np.sum(band ** 2) / (band.size + 1e-8))
                        for band in (c_a, c_h, c_v, c_d)]
            entropies = []
            for band in (c_h, c_v, c_d):
                magnitude = np.abs(band.flatten())
                total = magnitude.sum()
                if total > 0:
                    probs = magnitude / total
                    entropies.append(float(-np.sum(probs * np.log2(probs + 1e-8))))
                else:
                    entropies.append(0.0)
            return energies + entropies
        except Exception:
            # Best effort: a failed transform yields neutral values.
            return [0.0] * 7

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray) -> np.ndarray:
        """
        Extract frequency-domain features (DCT, wavelet)

        Features are computed on the grayscale bounding box of the mask's
        foreground pixels. Each transform group is computed atomically, so
        the output always has NUM_FEATURES entries.

        Args:
            image: Input image (H, W, 3) or (H, W) normalized [0, 1]
            region_mask: Binary region mask (H, W); resized (nearest) to the
                image grid if its shape differs

        Returns:
            float32 vector of length NUM_FEATURES (all zeros for an empty mask)
        """
        if image.ndim == 3:
            gray = cv2.cvtColor((image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
        else:
            gray = (image * 255).astype(np.uint8)

        if region_mask.shape != gray.shape:
            # Bounding-box coordinates must refer to the image grid.
            region_mask = cv2.resize(
                (region_mask > 0).astype(np.uint8),
                (gray.shape[1], gray.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )

        rows, cols = np.nonzero(region_mask > 0)
        if rows.size == 0:
            return np.zeros(self.NUM_FEATURES, dtype=np.float32)

        region = gray[rows.min():rows.max() + 1,
                      cols.min():cols.max() + 1].astype(np.float32)

        features = self._dct_features(region) + self._wavelet_features(region)
        return np.array(features, dtype=np.float32)
|
|
|
|
class NoiseELAFeatureExtractor:
    """Extract noise and Error Level Analysis (ELA) features"""

    def __init__(self, quality: int = 90):
        """
        Initialize noise/ELA extractor

        Args:
            quality: JPEG quality used for the ELA re-compression
        """
        self.quality = quality

    def _error_level(self, img_uint8: np.ndarray) -> Optional[np.ndarray]:
        """
        Absolute per-pixel difference between ``img_uint8`` and its JPEG
        re-compression at ``self.quality``, or None if encoding/decoding fails.
        """
        ok, encoded = cv2.imencode('.jpg', img_uint8,
                                   [cv2.IMWRITE_JPEG_QUALITY, self.quality])
        if not ok:
            return None
        # IMREAD_UNCHANGED keeps the channel count of the encoded image, so a
        # grayscale input decodes to 2-D instead of being expanded to 3
        # channels (which made the subtraction below fail).
        recompressed = cv2.imdecode(encoded, cv2.IMREAD_UNCHANGED)
        if recompressed is None or recompressed.size != img_uint8.size:
            return None
        recompressed = recompressed.reshape(img_uint8.shape)
        return np.abs(img_uint8.astype(np.float32) - recompressed.astype(np.float32))

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray) -> np.ndarray:
        """
        Extract noise and ELA features

        Args:
            image: Input image (H, W, 3) or (H, W) normalized [0, 1]
            region_mask: Binary region mask (H, W); resized (nearest) to the
                image grid if its shape differs

        Returns:
            float32 vector [ela_mean, ela_var, ela_max,
            noise_residual_mean, noise_residual_var]; a group is zeros when
            the mask selects no pixels (or, for ELA, when JPEG fails).
        """
        img_uint8 = (image * 255).astype(np.uint8)
        if img_uint8.ndim == 3:
            gray = cv2.cvtColor(img_uint8, cv2.COLOR_RGB2GRAY)
        else:
            gray = img_uint8

        # The ELA map and the noise residual share the image's (H, W) grid,
        # so the mask is resized once for both.
        if region_mask.shape != gray.shape:
            mask = cv2.resize(
                (region_mask > 0).astype(np.uint8),
                (gray.shape[1], gray.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            ) > 0
        else:
            mask = region_mask > 0

        ela = self._error_level(img_uint8)
        ela_region = ela[mask] if ela is not None else np.empty(0, dtype=np.float32)
        if ela_region.size > 0:
            features = [np.mean(ela_region), np.var(ela_region), np.max(ela_region)]
        else:
            features = [0.0, 0.0, 0.0]

        # Noise residual: deviation from a 3x3 median-filtered version.
        median_filtered = cv2.medianBlur(gray, 3)
        noise_residual = np.abs(gray.astype(np.float32) - median_filtered.astype(np.float32))
        residual_region = noise_residual[mask]
        if residual_region.size > 0:
            features += [np.mean(residual_region), np.var(residual_region)]
        else:
            features += [0.0, 0.0]

        return np.array(features, dtype=np.float32)
|
|
|
|
class OCRFeatureExtractor:
    """
    Extract OCR-based consistency features
    Only for text documents (Feature Gating - Critical Fix #5)
    """

    # ocr_conf_mean, ocr_conf_std, spacing_irregularity, text_density,
    # stroke_mean, stroke_std.
    NUM_FEATURES = 6

    def __init__(self):
        """Initialize OCR feature extractor; disables OCR if EasyOCR cannot load."""
        self.ocr_available = False
        self.reader = None

        try:
            import easyocr
            self.reader = easyocr.Reader(['en'], gpu=True)
            self.ocr_available = True
        except Exception:
            print("Warning: EasyOCR not available, OCR features disabled")

    def _region_features(self, region: np.ndarray) -> List[float]:
        """
        Compute the six OCR features for an already-cropped uint8 region.

        May raise if OCR or edge filtering fails; the caller turns that into
        the all-zero vector.
        """
        results = self.reader.readtext(region)
        if not results:
            return [0.0] * self.NUM_FEATURES

        # Each result is indexed as (bbox, text, confidence).
        confidences = [r[2] for r in results]

        # Width from the first two bbox corners (x of point 1 minus x of point 0).
        bbox_widths = [abs(r[0][1][0] - r[0][0][0]) for r in results]
        if len(bbox_widths) > 1:
            spacing_irregularity = np.std(bbox_widths) / (np.mean(bbox_widths) + 1e-8)
        else:
            spacing_irregularity = 0.0

        text_density = len(results) / (region.shape[0] * region.shape[1] + 1e-8)

        # Grayscale regions are used as-is; only colour regions are converted.
        gray_region = cv2.cvtColor(region, cv2.COLOR_RGB2GRAY) if region.ndim == 3 else region
        edges = sobel(gray_region)

        return [
            np.mean(confidences), np.std(confidences), spacing_irregularity,
            text_density, np.mean(edges), np.std(edges),
        ]

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray) -> np.ndarray:
        """
        Extract OCR consistency features

        Args:
            image: Input image (H, W, 3) or (H, W) normalized [0, 1]
            region_mask: Binary region mask (H, W)

        Returns:
            float32 vector of length NUM_FEATURES; all zeros when OCR is
            unavailable, the mask is empty, no text is found, or any step
            fails.
        """
        neutral = np.zeros(self.NUM_FEATURES, dtype=np.float32)
        if not self.ocr_available:
            return neutral

        rows, cols = np.nonzero(region_mask > 0)
        if rows.size == 0:
            return neutral

        img_uint8 = (image * 255).astype(np.uint8)
        region = img_uint8[rows.min():rows.max() + 1, cols.min():cols.max() + 1]

        try:
            # Features are built in one go, so a failure part-way through can
            # no longer leave extra values behind (old code could emit 10).
            features = self._region_features(region)
        except Exception:
            # Best effort: OCR problems degrade to the neutral vector.
            return neutral

        return np.array(features, dtype=np.float32)
|
|
|
|
class HybridFeatureExtractor:
    """
    Complete hybrid feature extraction
    Implements Critical Fix #5: Feature Group Gating

    Each feature group is switched by ``config.get('features.<group>.enabled',
    True)``; the OCR group additionally requires ``is_text_document``.
    """

    def __init__(self, config, is_text_document: bool = True):
        """
        Initialize hybrid feature extractor

        Args:
            config: Configuration object exposing ``get(key, default)``
            is_text_document: Whether input is text document (for OCR gating)
        """
        self.config = config
        self.is_text_document = is_text_document

        self.deep_extractor = DeepFeatureExtractor()
        self.stat_extractor = StatisticalFeatureExtractor()
        self.freq_extractor = FrequencyFeatureExtractor()
        self.noise_extractor = NoiseELAFeatureExtractor()

        # OCR is only built for text documents with the group enabled.
        if is_text_document and config.get('features.ocr.enabled', True):
            self.ocr_extractor = OCRFeatureExtractor()
        else:
            self.ocr_extractor = None

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray,
                decoder_features: Optional[List[torch.Tensor]] = None) -> np.ndarray:
        """
        Extract all hybrid features for a region

        Groups are concatenated in the order deep, statistical, frequency,
        noise/ELA, OCR; disabled groups are omitted, and the deep group is
        also omitted when ``decoder_features`` is None.

        Args:
            image: Input image (H, W, 3) normalized [0, 1]
            region_mask: Binary region mask (H, W)
            decoder_features: Optional decoder features for deep feature extraction

        Returns:
            Concatenated feature vector with NaN/inf replaced by 0
        """
        all_features = []

        if decoder_features is not None and self.config.get('features.deep.enabled', True):
            all_features.append(self.deep_extractor.extract(decoder_features, region_mask))

        if self.config.get('features.statistical.enabled', True):
            all_features.append(self.stat_extractor.extract(image, region_mask))

        if self.config.get('features.frequency.enabled', True):
            all_features.append(self.freq_extractor.extract(image, region_mask))

        if self.config.get('features.noise.enabled', True):
            all_features.append(self.noise_extractor.extract(image, region_mask))

        if self.ocr_extractor is not None:
            all_features.append(self.ocr_extractor.extract(image, region_mask))

        if all_features:
            features = np.concatenate(all_features)
        else:
            features = np.array([], dtype=np.float32)

        # Degenerate regions can produce NaN/inf (e.g. empty statistics);
        # downstream models get a finite vector.
        return np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)

    def get_feature_names(self, num_deep_features: int = 256) -> List[str]:
        """
        Get list of feature names for interpretability

        Args:
            num_deep_features: How many ``deep_<i>`` names to emit when the
                deep group is enabled. The deep vector produced by extract()
                has one entry per decoder channel, so pass that total, or 0
                when extract() is called without decoder features. Defaults
                to 256 (the previously hard-coded value).

        Returns:
            Names in the same group order as extract()
        """
        names = []

        if self.config.get('features.deep.enabled', True):
            names.extend(f'deep_{i}' for i in range(num_deep_features))

        if self.config.get('features.statistical.enabled', True):
            names.extend(['area', 'perimeter', 'aspect_ratio',
                          'solidity', 'eccentricity', 'entropy'])

        if self.config.get('features.frequency.enabled', True):
            names.extend(['dct_mean', 'dct_std', 'high_freq_energy',
                          'wavelet_cA', 'wavelet_cH', 'wavelet_cV', 'wavelet_cD',
                          'wavelet_entropy_H', 'wavelet_entropy_V', 'wavelet_entropy_D'])

        if self.config.get('features.noise.enabled', True):
            names.extend(['ela_mean', 'ela_var', 'ela_max',
                          'noise_residual_mean', 'noise_residual_var'])

        if self.ocr_extractor is not None:
            names.extend(['ocr_conf_mean', 'ocr_conf_std', 'spacing_irregularity',
                          'text_density', 'stroke_mean', 'stroke_std'])

        return names
|
|
|
|
def get_feature_extractor(config, is_text_document: bool = True) -> HybridFeatureExtractor:
    """
    Build a HybridFeatureExtractor for the given configuration.

    Args:
        config: Configuration object, passed through unchanged
        is_text_document: Whether the input is a text document; the
            extractor uses it to decide whether the OCR group may be built

    Returns:
        A newly constructed HybridFeatureExtractor
    """
    extractor = HybridFeatureExtractor(config, is_text_document=is_text_document)
    return extractor
|
|