JKrishnanandhaa commited on
Commit
ff0e79e
·
verified ·
1 Parent(s): 6bdcef9

Upload 54 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +231 -0
  2. config.yaml +297 -0
  3. models/best_doctamper.pth +3 -0
  4. models/classifier/classifier_metadata.json +821 -0
  5. models/classifier/lightgbm_model.txt +0 -0
  6. models/classifier/scaler.joblib +3 -0
  7. src/__init__.py +32 -0
  8. src/__pycache__/__init__.cpython-312.pyc +0 -0
  9. src/config/__init__.py +5 -0
  10. src/config/__pycache__/__init__.cpython-312.pyc +0 -0
  11. src/config/__pycache__/config_loader.cpython-312.pyc +0 -0
  12. src/config/config_loader.py +117 -0
  13. src/data/__init__.py +23 -0
  14. src/data/__pycache__/__init__.cpython-312.pyc +0 -0
  15. src/data/__pycache__/augmentation.cpython-312.pyc +0 -0
  16. src/data/__pycache__/datasets.cpython-312.pyc +0 -0
  17. src/data/__pycache__/preprocessing.cpython-312.pyc +0 -0
  18. src/data/augmentation.py +150 -0
  19. src/data/datasets.py +541 -0
  20. src/data/preprocessing.py +226 -0
  21. src/features/__init__.py +32 -0
  22. src/features/__pycache__/__init__.cpython-312.pyc +0 -0
  23. src/features/__pycache__/feature_extraction.cpython-312.pyc +0 -0
  24. src/features/__pycache__/region_extraction.cpython-312.pyc +0 -0
  25. src/features/feature_extraction.py +485 -0
  26. src/features/region_extraction.py +226 -0
  27. src/inference/__init__.py +5 -0
  28. src/inference/__pycache__/__init__.cpython-312.pyc +0 -0
  29. src/inference/__pycache__/pipeline.cpython-312.pyc +0 -0
  30. src/inference/pipeline.py +359 -0
  31. src/models/__init__.py +19 -0
  32. src/models/__pycache__/__init__.cpython-312.pyc +0 -0
  33. src/models/__pycache__/decoder.cpython-312.pyc +0 -0
  34. src/models/__pycache__/encoder.cpython-312.pyc +0 -0
  35. src/models/__pycache__/losses.cpython-312.pyc +0 -0
  36. src/models/__pycache__/network.cpython-312.pyc +0 -0
  37. src/models/decoder.py +186 -0
  38. src/models/encoder.py +75 -0
  39. src/models/losses.py +168 -0
  40. src/models/network.py +133 -0
  41. src/training/__init__.py +24 -0
  42. src/training/__pycache__/__init__.cpython-312.pyc +0 -0
  43. src/training/__pycache__/classifier.cpython-312.pyc +0 -0
  44. src/training/__pycache__/metrics.cpython-312.pyc +0 -0
  45. src/training/__pycache__/trainer.cpython-312.pyc +0 -0
  46. src/training/classifier.py +282 -0
  47. src/training/metrics.py +305 -0
  48. src/training/trainer.py +450 -0
  49. src/utils/__init__.py +28 -0
  50. src/utils/__pycache__/__init__.cpython-312.pyc +0 -0
app.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Document Forgery Detection - Gradio Interface for Hugging Face Spaces
3
+
4
+ This app provides a web interface for detecting and classifying document forgeries.
5
+ """
6
+
7
+ import gradio as gr
8
+ import torch
9
+ import cv2
10
+ import numpy as np
11
+ from PIL import Image
12
+ import json
13
+ from pathlib import Path
14
+ import sys
15
+
16
+ # Add src to path
17
+ sys.path.insert(0, str(Path(__file__).parent))
18
+
19
+ from src.models import get_model
20
+ from src.config import get_config
21
+ from src.data.preprocessing import DocumentPreprocessor
22
+ from src.data.augmentation import DatasetAwareAugmentation
23
+ from src.features.region_extraction import get_mask_refiner, get_region_extractor
24
+ from src.features.feature_extraction import get_feature_extractor
25
+ from src.training.classifier import ForgeryClassifier
26
+
27
# Mapping from numeric class id to a human-readable forgery-type label.
# NOTE(review): class 2 is labelled 'Generation' here, while
# classifier_metadata.json lists the third class as 'text_substitution' —
# confirm both refer to the same underlying class.
CLASS_NAMES = {0: 'Copy-Move', 1: 'Splicing', 2: 'Generation'}

# Mapping from class id to the RGB colour used for its detection overlay.
CLASS_COLORS = {
    0: (255, 0, 0),    # Red for Copy-Move
    1: (0, 255, 0),    # Green for Splicing
    2: (0, 0, 255),    # Blue for Generation
}
34
+
35
+
36
class ForgeryDetector:
    """End-to-end document forgery detection pipeline.

    Chains four stages:
      1. a segmentation network that localises tampered pixels,
      2. morphological mask refinement,
      3. region extraction + hybrid feature extraction per region,
      4. a LightGBM classifier that labels each region's forgery type.
    """

    def __init__(self):
        """Load all models and helper components once, at construction time."""
        print("Loading models...")

        # Load config
        self.config = get_config('config.yaml')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load segmentation model.
        # BUG FIX: the original loaded 'models/segmentation_model.pth', but the
        # only weights file shipped in this commit is 'models/best_doctamper.pth'
        # (see the uploaded file list), so startup failed with FileNotFoundError.
        self.model = get_model(self.config).to(self.device)
        checkpoint = torch.load('models/best_doctamper.pth', map_location=self.device)
        # Robustness: accept both a full training checkpoint
        # ({'model_state_dict': ...}) and a bare state dict.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        self.model.load_state_dict(state_dict)
        self.model.eval()

        # Load classifier
        self.classifier = ForgeryClassifier(self.config)
        self.classifier.load('models/classifier')

        # Initialize preprocessing / post-processing components
        self.preprocessor = DocumentPreprocessor(self.config, 'doctamper')
        self.augmentation = DatasetAwareAugmentation(self.config, 'doctamper', is_training=False)
        self.mask_refiner = get_mask_refiner(self.config)
        self.region_extractor = get_region_extractor(self.config)
        self.feature_extractor = get_feature_extractor(self.config, is_text_document=True)

        print("✓ Models loaded successfully!")

    def detect(self, image):
        """
        Detect forgeries in document image

        Args:
            image: PIL Image or numpy array

        Returns:
            overlay_image: Image with detection overlay
            results_json: Detection results as JSON-serialisable dict
        """
        # Convert PIL to numpy
        if isinstance(image, Image.Image):
            image = np.array(image)

        # Normalise channel layout to 3-channel RGB.
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

        original_image = image.copy()

        # Preprocess (no mask available at inference time, hence None).
        preprocessed, _ = self.preprocessor(image, None)

        # Augment (inference-mode transform: resize/normalise → tensor).
        augmented = self.augmentation(preprocessed, None)
        image_tensor = augmented['image'].unsqueeze(0).to(self.device)

        # Run localization
        with torch.no_grad():
            logits, decoder_features = self.model(image_tensor)
            prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]

        # Refine mask: binarise at 0.5, then morphological cleanup back at
        # the original image resolution.
        binary_mask = (prob_map > 0.5).astype(np.uint8)
        refined_mask = self.mask_refiner.refine(binary_mask, original_size=original_image.shape[:2])

        # Extract candidate forged regions from the refined mask.
        regions = self.region_extractor.extract(refined_mask, prob_map, original_image)

        # Classify each region and keep only confident detections.
        results = []
        for region in regions:
            # Extract hybrid (deep + handcrafted) features for this region.
            features = self.feature_extractor.extract(
                preprocessed,
                region['region_mask'],
                [f.cpu() for f in decoder_features]
            )

            # Classify
            predictions, confidences = self.classifier.predict(features)
            forgery_type = int(predictions[0])
            confidence = float(confidences[0])

            if confidence > 0.6:  # Confidence threshold (matches config.yaml)
                results.append({
                    'region_id': region['region_id'],
                    'bounding_box': region['bounding_box'],
                    'forgery_type': CLASS_NAMES[forgery_type],
                    'confidence': confidence
                })

        # Create visualization
        overlay = self._create_overlay(original_image, results)

        # Create JSON response
        json_results = {
            'num_detections': len(results),
            'detections': results,
            'model_info': {
                'segmentation_dice': '75%',
                'classifier_accuracy': '92%'
            }
        }

        return overlay, json_results

    def _create_overlay(self, image, results):
        """Draw bounding boxes, per-region labels and a summary legend."""
        overlay = image.copy()

        # Invert CLASS_NAMES once instead of scanning it per result.
        name_to_id = {v: k for k, v in CLASS_NAMES.items()}

        # Draw bounding boxes and labels
        for result in results:
            x, y, w, h = result['bounding_box']
            forgery_type = result['forgery_type']
            confidence = result['confidence']

            color = CLASS_COLORS[name_to_id[forgery_type]]

            # Draw rectangle
            cv2.rectangle(overlay, (x, y), (x + w, y + h), color, 2)

            # Draw label on a filled background bar above the box.
            label = f"{forgery_type}: {confidence:.1%}"
            label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
            cv2.rectangle(overlay, (x, y - label_size[1] - 10), (x + label_size[0], y), color, -1)
            cv2.putText(overlay, label, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

        # Add legend
        if len(results) > 0:
            legend_y = 30
            cv2.putText(overlay, f"Detected {len(results)} forgery region(s)",
                        (10, legend_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)

        return overlay
177
+
178
+
179
# Initialize detector once at import time (model loading is expensive,
# so it must not happen per-request).
detector = ForgeryDetector()


def detect_forgery(image):
    """Gradio interface function.

    Args:
        image: PIL image uploaded by the user (may be None).

    Returns:
        Tuple of (annotated image or None, JSON-serialisable detection dict).
    """
    try:
        overlay, results = detector.detect(image)
        # gr.JSON accepts dicts directly — no need to pre-serialise with
        # json.dumps (the original double-encoded the payload).
        return overlay, results
    except Exception as e:
        # BUG FIX: the original returned a plain "Error: ..." string, which the
        # gr.JSON output component cannot render as JSON. Return a dict so the
        # error is displayed instead of raising a secondary failure.
        return None, {"error": str(e)}
190
+
191
+
192
# Create Gradio interface
# Single-function app: document image in → (annotated image, JSON details) out.
demo = gr.Interface(
    fn=detect_forgery,
    inputs=gr.Image(type="pil", label="Upload Document Image"),
    outputs=[
        gr.Image(type="numpy", label="Detection Result"),
        gr.JSON(label="Detection Details")
    ],
    title="📄 Document Forgery Detector",
    description="""
    Upload a document image to detect and classify forgeries.

    **Supported Forgery Types:**
    - 🔴 Copy-Move: Duplicated regions within the document
    - 🟢 Splicing: Content from different sources
    - 🔵 Generation: AI-generated or synthesized content

    **Model Performance:**
    - Localization: 75% Dice Score
    - Classification: 92% Accuracy
    """,
    # NOTE(review): these example files are not in the visible upload list —
    # confirm examples/ exists in the Space, otherwise Gradio errors at startup.
    examples=[
        ["examples/sample1.jpg"],
        ["examples/sample2.jpg"],
    ],
    article="""
    ### About
    This model uses a hybrid deep learning approach:
    1. **Localization**: MobileNetV3-Small + UNet-Lite (detects WHERE)
    2. **Classification**: LightGBM with hybrid features (detects WHAT)

    Trained on DocTamper dataset (140K samples).
    """,
    theme=gr.themes.Soft(),
    # NOTE(review): allow_flagging is deprecated in Gradio 4.x in favour of
    # flagging_mode — confirm the pinned gradio version still accepts it.
    allow_flagging="never"
)


if __name__ == "__main__":
    # Launch the Gradio server (default host/port; Spaces overrides as needed).
    demo.launch()
config.yaml ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hybrid Document Forgery Detection - Configuration
2
+
3
+ # System Settings
4
+ system:
5
+ device: cuda # cuda or cpu
6
+ num_workers: 0 # Reduced to avoid multiprocessing errors
7
+ pin_memory: true
8
+ seed: 42
9
+
10
+ # Data Settings
11
+ data:
12
+ image_size: 384
13
+ batch_size: 8 # Reduced for 16GB RAM
14
+ num_classes: 3 # copy_move, splicing, text_substitution
15
+
16
+ # Dataset paths
17
+ datasets:
18
+ doctamper:
19
+ path: datasets/DocTamper
20
+ type: lmdb
21
+ has_pixel_mask: true
22
+ min_region_area: 0.001 # 0.1%
23
+
24
+ rtm:
25
+ path: datasets/RealTextManipulation
26
+ type: folder
27
+ has_pixel_mask: true
28
+ min_region_area: 0.0003 # 0.03%
29
+
30
+ casia:
31
+ path: datasets/CASIA 1.0 dataset
32
+ type: folder
33
+ has_pixel_mask: false
34
+ min_region_area: 0.001 # 0.1%
35
+ skip_deskew: true
36
+ skip_denoising: true
37
+
38
+ receipts:
39
+ path: datasets/findit2
40
+ type: folder
41
+ has_pixel_mask: true
42
+ min_region_area: 0.0005 # 0.05%
43
+
44
+ fcd:
45
+ path: datasets/DocTamper/DocTamperV1-FCD
46
+ type: lmdb
47
+ has_pixel_mask: true
48
+ min_region_area: 0.00035 # 0.035% (larger forgeries, keep 99%)
49
+
50
+ scd:
51
+ path: datasets/DocTamper/DocTamperV1-SCD
52
+ type: lmdb
53
+ has_pixel_mask: true
54
+ min_region_area: 0.00009 # 0.009% (small forgeries, keep 91.5%)
55
+
56
+ # Chunked training for DocTamper (RAM constraint)
57
+ chunked_training:
58
+ enabled: true
59
+ dataset: doctamper
60
+ chunks:
61
+ - {start: 0.0, end: 0.25, name: "chunk_1"}
62
+ - {start: 0.25, end: 0.5, name: "chunk_2"}
63
+ - {start: 0.5, end: 0.75, name: "chunk_3"}
64
+ - {start: 0.75, end: 1.0, name: "chunk_4"}
65
+
66
+ # Mixed dataset training (TrainingSet + FCD + SCD)
67
+ mixing_ratios:
68
+ doctamper: 0.70 # 70% TrainingSet (maintains baseline)
69
+ scd: 0.20 # 20% SCD (handles small forgeries, 0.88% avg)
70
+ fcd: 0.10 # 10% FCD (adds diversity, 3.55% avg)
71
+
72
+ # Preprocessing
73
+ preprocessing:
74
+ deskew: true
75
+ normalize: true
76
+ noise_threshold: 15.0 # Laplacian variance threshold
77
+ median_filter_size: 3
78
+ gaussian_sigma: 0.8
79
+
80
+ # Dataset-aware preprocessing
81
+ dataset_specific:
82
+ casia:
83
+ deskew: false
84
+ denoising: false
85
+
86
+ # Augmentation (Training only)
87
+ augmentation:
88
+ enabled: true
89
+
90
+ # Common augmentations
91
+ common:
92
+ - {type: "noise", prob: 0.3}
93
+ - {type: "motion_blur", prob: 0.2}
94
+ - {type: "jpeg_compression", prob: 0.3, quality: [60, 95]}
95
+ - {type: "lighting", prob: 0.3}
96
+ - {type: "perspective", prob: 0.2}
97
+
98
+ # Dataset-specific augmentations
99
+ receipts:
100
+ - {type: "stain", prob: 0.2}
101
+ - {type: "fold", prob: 0.15}
102
+
103
+ # Model Architecture
104
+ model:
105
+ # Encoder
106
+ encoder:
107
+ name: mobilenetv3_small_100
108
+ pretrained: true
109
+ features_only: true
110
+
111
+ # Decoder
112
+ decoder:
113
+ name: unet_lite
114
+ channels: [16, 24, 40, 48, 96] # MobileNetV3-Small feature channels
115
+ upsampling: bilinear
116
+ use_depthwise_separable: true
117
+
118
+ # Output
119
+ output_channels: 1 # Binary forgery mask
120
+
121
+ # Loss Function
122
+ loss:
123
+ # Dataset-aware loss
124
+ use_dice: true # Only for datasets with pixel masks
125
+ bce_weight: 1.0
126
+ dice_weight: 1.0
127
+
128
+ # Training
129
+ training:
130
+ epochs: 30 # Per chunk (increased for single-pass training)
131
+ learning_rate: 0.001 # Higher initial LR for faster convergence
132
+ weight_decay: 0.0001 # Slight increase for better regularization
133
+
134
+ # Optimizer
135
+ optimizer: adamw
136
+
137
+ # Scheduler
138
+ scheduler:
139
+ type: cosine_annealing_warm_restarts
140
+ T_0: 10 # Restart every 10 epochs
141
+ T_mult: 2 # Double restart period each time
142
+ warmup_epochs: 3 # Warmup for first 3 epochs
143
+ min_lr: 0.00001 # End at 1/100th of initial LR
144
+
145
+ # Early stopping
146
+ early_stopping:
147
+ enabled: true
148
+ patience: 10 # Increased to allow more exploration
149
+ min_delta: 0.0005 # Accept smaller improvements (0.05%)
150
+ restore_best_weights: true # Restore best model when stopping
151
+ monitor: val_dice
152
+ mode: max
153
+
154
+ # Checkpointing
155
+ checkpoint:
156
+ save_best: true
157
+ save_every: 5 # Save every 5 epochs
158
+ save_last: true # Also save last checkpoint
159
+ monitor: val_dice
160
+
161
+ # Mask Refinement
162
+ mask_refinement:
163
+ threshold: 0.5
164
+ morphology:
165
+ closing_kernel: 5
166
+ opening_kernel: 3
167
+
168
+ # Adaptive thresholds per dataset
169
+ min_region_area:
170
+ rtm: 0.0003
171
+ receipts: 0.0005
172
+ default: 0.001
173
+
174
+ # Feature Extraction
175
+ features:
176
+ # Deep features
177
+ deep:
178
+ enabled: true
179
+ pooling: gap # Global Average Pooling
180
+
181
+ # Statistical & Shape features
182
+ statistical:
183
+ enabled: true
184
+ features:
185
+ - area
186
+ - perimeter
187
+ - aspect_ratio
188
+ - solidity
189
+ - eccentricity
190
+ - entropy
191
+
192
+ # Frequency-domain features
193
+ frequency:
194
+ enabled: true
195
+ features:
196
+ - dct_coefficients
197
+ - high_frequency_energy
198
+ - wavelet_energy
199
+
200
+ # Noise & ELA features
201
+ noise:
202
+ enabled: true
203
+ features:
204
+ - ela_mean
205
+ - ela_variance
206
+ - noise_residual
207
+
208
+ # OCR-consistency features (text documents only)
209
+ ocr:
210
+ enabled: true
211
+ gated: true # Only for text documents
212
+ features:
213
+ - confidence_deviation
214
+ - spacing_irregularity
215
+ - stroke_width_variation
216
+
217
+ # Feature normalization
218
+ normalization:
219
+ method: standard_scaler
220
+ handle_missing: true
221
+
222
+ # LightGBM Classifier
223
+ classifier:
224
+ model: lightgbm
225
+ params:
226
+ objective: multiclass
227
+ num_class: 3
228
+ boosting_type: gbdt
229
+ num_leaves: 31
230
+ learning_rate: 0.05
231
+ n_estimators: 200
232
+ max_depth: 7
233
+ min_child_samples: 20
234
+ subsample: 0.8
235
+ colsample_bytree: 0.8
236
+ reg_alpha: 0.1
237
+ reg_lambda: 0.1
238
+ random_state: 42
239
+
240
+ # Confidence threshold
241
+ confidence_threshold: 0.6
242
+
243
+ # Metrics
244
+ metrics:
245
+ # Localization metrics (only for datasets with pixel masks)
246
+ localization:
247
+ - iou
248
+ - dice
249
+ - precision
250
+ - recall
251
+
252
+ # Classification metrics
253
+ classification:
254
+ - accuracy
255
+ - f1_score
256
+ - precision
257
+ - recall
258
+ - confusion_matrix
259
+
260
+ # Dataset-aware metric computation
261
+ compute_localization:
262
+ doctamper: true
263
+ rtm: true
264
+ casia: false
265
+ receipts: true
266
+
267
+ # Outputs
268
+ outputs:
269
+ base_dir: outputs
270
+
271
+ # Subdirectories
272
+ checkpoints: outputs/checkpoints
273
+ logs: outputs/logs
274
+ plots: outputs/plots
275
+ results: outputs/results
276
+
277
+ # Visualization
278
+ visualization:
279
+ save_mask: true
280
+ save_overlay: true
281
+ save_json: true
282
+ overlay_alpha: 0.5
283
+ colormap: jet
284
+
285
+ # Deployment
286
+ deployment:
287
+ export_onnx: true
288
+ onnx_path: outputs/model.onnx
289
+ quantization: false
290
+ opset_version: 14
291
+
292
+ # Logging
293
+ logging:
294
+ level: INFO
295
+ tensorboard: true
296
+ csv: true
297
+ console: true
models/best_doctamper.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d049ca9d4dc28c8d01519f8faab1ec131a05de877da9703ee5bb0e9322095ad2
3
+ size 14283981
models/classifier/classifier_metadata.json ADDED
@@ -0,0 +1,821 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "confidence_threshold": 0.6,
3
+ "class_names": [
4
+ "copy_move",
5
+ "splicing",
6
+ "text_substitution"
7
+ ],
8
+ "feature_names": [
9
+ "deep_0",
10
+ "deep_1",
11
+ "deep_2",
12
+ "deep_3",
13
+ "deep_4",
14
+ "deep_5",
15
+ "deep_6",
16
+ "deep_7",
17
+ "deep_8",
18
+ "deep_9",
19
+ "deep_10",
20
+ "deep_11",
21
+ "deep_12",
22
+ "deep_13",
23
+ "deep_14",
24
+ "deep_15",
25
+ "deep_16",
26
+ "deep_17",
27
+ "deep_18",
28
+ "deep_19",
29
+ "deep_20",
30
+ "deep_21",
31
+ "deep_22",
32
+ "deep_23",
33
+ "deep_24",
34
+ "deep_25",
35
+ "deep_26",
36
+ "deep_27",
37
+ "deep_28",
38
+ "deep_29",
39
+ "deep_30",
40
+ "deep_31",
41
+ "deep_32",
42
+ "deep_33",
43
+ "deep_34",
44
+ "deep_35",
45
+ "deep_36",
46
+ "deep_37",
47
+ "deep_38",
48
+ "deep_39",
49
+ "deep_40",
50
+ "deep_41",
51
+ "deep_42",
52
+ "deep_43",
53
+ "deep_44",
54
+ "deep_45",
55
+ "deep_46",
56
+ "deep_47",
57
+ "deep_48",
58
+ "deep_49",
59
+ "deep_50",
60
+ "deep_51",
61
+ "deep_52",
62
+ "deep_53",
63
+ "deep_54",
64
+ "deep_55",
65
+ "deep_56",
66
+ "deep_57",
67
+ "deep_58",
68
+ "deep_59",
69
+ "deep_60",
70
+ "deep_61",
71
+ "deep_62",
72
+ "deep_63",
73
+ "deep_64",
74
+ "deep_65",
75
+ "deep_66",
76
+ "deep_67",
77
+ "deep_68",
78
+ "deep_69",
79
+ "deep_70",
80
+ "deep_71",
81
+ "deep_72",
82
+ "deep_73",
83
+ "deep_74",
84
+ "deep_75",
85
+ "deep_76",
86
+ "deep_77",
87
+ "deep_78",
88
+ "deep_79",
89
+ "deep_80",
90
+ "deep_81",
91
+ "deep_82",
92
+ "deep_83",
93
+ "deep_84",
94
+ "deep_85",
95
+ "deep_86",
96
+ "deep_87",
97
+ "deep_88",
98
+ "deep_89",
99
+ "deep_90",
100
+ "deep_91",
101
+ "deep_92",
102
+ "deep_93",
103
+ "deep_94",
104
+ "deep_95",
105
+ "deep_96",
106
+ "deep_97",
107
+ "deep_98",
108
+ "deep_99",
109
+ "deep_100",
110
+ "deep_101",
111
+ "deep_102",
112
+ "deep_103",
113
+ "deep_104",
114
+ "deep_105",
115
+ "deep_106",
116
+ "deep_107",
117
+ "deep_108",
118
+ "deep_109",
119
+ "deep_110",
120
+ "deep_111",
121
+ "deep_112",
122
+ "deep_113",
123
+ "deep_114",
124
+ "deep_115",
125
+ "deep_116",
126
+ "deep_117",
127
+ "deep_118",
128
+ "deep_119",
129
+ "deep_120",
130
+ "deep_121",
131
+ "deep_122",
132
+ "deep_123",
133
+ "deep_124",
134
+ "deep_125",
135
+ "deep_126",
136
+ "deep_127",
137
+ "deep_128",
138
+ "deep_129",
139
+ "deep_130",
140
+ "deep_131",
141
+ "deep_132",
142
+ "deep_133",
143
+ "deep_134",
144
+ "deep_135",
145
+ "deep_136",
146
+ "deep_137",
147
+ "deep_138",
148
+ "deep_139",
149
+ "deep_140",
150
+ "deep_141",
151
+ "deep_142",
152
+ "deep_143",
153
+ "deep_144",
154
+ "deep_145",
155
+ "deep_146",
156
+ "deep_147",
157
+ "deep_148",
158
+ "deep_149",
159
+ "deep_150",
160
+ "deep_151",
161
+ "deep_152",
162
+ "deep_153",
163
+ "deep_154",
164
+ "deep_155",
165
+ "deep_156",
166
+ "deep_157",
167
+ "deep_158",
168
+ "deep_159",
169
+ "deep_160",
170
+ "deep_161",
171
+ "deep_162",
172
+ "deep_163",
173
+ "deep_164",
174
+ "deep_165",
175
+ "deep_166",
176
+ "deep_167",
177
+ "deep_168",
178
+ "deep_169",
179
+ "deep_170",
180
+ "deep_171",
181
+ "deep_172",
182
+ "deep_173",
183
+ "deep_174",
184
+ "deep_175",
185
+ "deep_176",
186
+ "deep_177",
187
+ "deep_178",
188
+ "deep_179",
189
+ "deep_180",
190
+ "deep_181",
191
+ "deep_182",
192
+ "deep_183",
193
+ "deep_184",
194
+ "deep_185",
195
+ "deep_186",
196
+ "deep_187",
197
+ "deep_188",
198
+ "deep_189",
199
+ "deep_190",
200
+ "deep_191",
201
+ "deep_192",
202
+ "deep_193",
203
+ "deep_194",
204
+ "deep_195",
205
+ "deep_196",
206
+ "deep_197",
207
+ "deep_198",
208
+ "deep_199",
209
+ "deep_200",
210
+ "deep_201",
211
+ "deep_202",
212
+ "deep_203",
213
+ "deep_204",
214
+ "deep_205",
215
+ "deep_206",
216
+ "deep_207",
217
+ "deep_208",
218
+ "deep_209",
219
+ "deep_210",
220
+ "deep_211",
221
+ "deep_212",
222
+ "deep_213",
223
+ "deep_214",
224
+ "deep_215",
225
+ "deep_216",
226
+ "deep_217",
227
+ "deep_218",
228
+ "deep_219",
229
+ "deep_220",
230
+ "deep_221",
231
+ "deep_222",
232
+ "deep_223",
233
+ "deep_224",
234
+ "deep_225",
235
+ "deep_226",
236
+ "deep_227",
237
+ "deep_228",
238
+ "deep_229",
239
+ "deep_230",
240
+ "deep_231",
241
+ "deep_232",
242
+ "deep_233",
243
+ "deep_234",
244
+ "deep_235",
245
+ "deep_236",
246
+ "deep_237",
247
+ "deep_238",
248
+ "deep_239",
249
+ "deep_240",
250
+ "deep_241",
251
+ "deep_242",
252
+ "deep_243",
253
+ "deep_244",
254
+ "deep_245",
255
+ "deep_246",
256
+ "deep_247",
257
+ "deep_248",
258
+ "deep_249",
259
+ "deep_250",
260
+ "deep_251",
261
+ "deep_252",
262
+ "deep_253",
263
+ "deep_254",
264
+ "deep_255",
265
+ "area",
266
+ "perimeter",
267
+ "aspect_ratio",
268
+ "solidity",
269
+ "eccentricity",
270
+ "entropy",
271
+ "dct_mean",
272
+ "dct_std",
273
+ "high_freq_energy",
274
+ "wavelet_cA",
275
+ "wavelet_cH",
276
+ "wavelet_cV",
277
+ "wavelet_cD",
278
+ "wavelet_entropy_H",
279
+ "wavelet_entropy_V",
280
+ "wavelet_entropy_D",
281
+ "ela_mean",
282
+ "ela_var",
283
+ "ela_max",
284
+ "noise_residual_mean",
285
+ "noise_residual_var",
286
+ "ocr_conf_mean",
287
+ "ocr_conf_std",
288
+ "spacing_irregularity",
289
+ "text_density",
290
+ "stroke_mean",
291
+ "stroke_std"
292
+ ],
293
+ "feature_importance": [
294
+ 151.5697784423828,
295
+ 8.955550193786621,
296
+ 32.9064998626709,
297
+ 151.0029697418213,
298
+ 19.174699783325195,
299
+ 157.97871017456055,
300
+ 45.12229919433594,
301
+ 19.72992992401123,
302
+ 105.08611106872559,
303
+ 0.0,
304
+ 148.97894096374512,
305
+ 35.71831035614014,
306
+ 50.15155029296875,
307
+ 71.74272060394287,
308
+ 43.958970069885254,
309
+ 129.9348111152649,
310
+ 27.99122953414917,
311
+ 61.592909812927246,
312
+ 295.4245676994324,
313
+ 61.00736045837402,
314
+ 28.548550128936768,
315
+ 0.0,
316
+ 54.50248908996582,
317
+ 93.74169921875,
318
+ 120.9488091468811,
319
+ 148.32109832763672,
320
+ 30.55735969543457,
321
+ 59.058170318603516,
322
+ 82.7595911026001,
323
+ 49.24997901916504,
324
+ 0.0,
325
+ 23.502280235290527,
326
+ 392.399715423584,
327
+ 551.6174192428589,
328
+ 0.0,
329
+ 50.8812894821167,
330
+ 60.7820405960083,
331
+ 78.98891925811768,
332
+ 0.0,
333
+ 9.173580169677734,
334
+ 631.6932668685913,
335
+ 42.097740173339844,
336
+ 305.0536642074585,
337
+ 416.94709300994873,
338
+ 92.70171976089478,
339
+ 66.76712036132812,
340
+ 1435.1315097808838,
341
+ 0.0,
342
+ 126.6096019744873,
343
+ 111.61981964111328,
344
+ 124.68002033233643,
345
+ 46.16030025482178,
346
+ 12.660099983215332,
347
+ 115.48313999176025,
348
+ 86.43069076538086,
349
+ 16.674290657043457,
350
+ 110.49228954315186,
351
+ 0.0,
352
+ 98.00746059417725,
353
+ 98.95538091659546,
354
+ 41.432090759277344,
355
+ 11.24590015411377,
356
+ 65.1699800491333,
357
+ 9.251449584960938,
358
+ 100.24416923522949,
359
+ 109.5842399597168,
360
+ 83.83185005187988,
361
+ 196.82151079177856,
362
+ 0.0,
363
+ 455.4096431732178,
364
+ 120.69411087036133,
365
+ 23.130990028381348,
366
+ 18.21858024597168,
367
+ 69.65920066833496,
368
+ 82.33455085754395,
369
+ 0.0,
370
+ 82.21379089355469,
371
+ 119.78182220458984,
372
+ 65.07565069198608,
373
+ 53.62262964248657,
374
+ 247.53085803985596,
375
+ 144.45191097259521,
376
+ 38.63272047042847,
377
+ 82.24878883361816,
378
+ 60.303489685058594,
379
+ 8.717499732971191,
380
+ 412.6672077178955,
381
+ 54.25755023956299,
382
+ 0.0,
383
+ 23.141600608825684,
384
+ 62.88635063171387,
385
+ 144.1060814857483,
386
+ 352.47050952911377,
387
+ 23.701799392700195,
388
+ 180.19217205047607,
389
+ 74.43132972717285,
390
+ 0.0,
391
+ 92.36961936950684,
392
+ 418.40467262268066,
393
+ 163.96015119552612,
394
+ 136.4917197227478,
395
+ 8.362039566040039,
396
+ 10.378399848937988,
397
+ 30.465800285339355,
398
+ 47.935009479522705,
399
+ 28.957390308380127,
400
+ 61.46374034881592,
401
+ 11.319199562072754,
402
+ 142.72890949249268,
403
+ 0.0,
404
+ 140.48277807235718,
405
+ 59.3709602355957,
406
+ 9.517510414123535,
407
+ 22.945700645446777,
408
+ 85.35987043380737,
409
+ 25.964330196380615,
410
+ 18.778900146484375,
411
+ 79.01968955993652,
412
+ 74.93959999084473,
413
+ 0.0,
414
+ 36.94928026199341,
415
+ 47.99788188934326,
416
+ 84.99461078643799,
417
+ 65.24014949798584,
418
+ 128.61994075775146,
419
+ 71.96449947357178,
420
+ 0.0,
421
+ 60.59358024597168,
422
+ 0.0,
423
+ 144.41107177734375,
424
+ 119.25859117507935,
425
+ 0.0,
426
+ 29.235299110412598,
427
+ 75.50409030914307,
428
+ 0.0,
429
+ 0.0,
430
+ 133.30608654022217,
431
+ 50.813700675964355,
432
+ 7.879730224609375,
433
+ 80.23723936080933,
434
+ 28.72357988357544,
435
+ 85.63543939590454,
436
+ 88.70749998092651,
437
+ 0.0,
438
+ 38.14083003997803,
439
+ 10.110199928283691,
440
+ 223.45562982559204,
441
+ 0.0,
442
+ 189.3048586845398,
443
+ 11.311699867248535,
444
+ 87.91403198242188,
445
+ 45.88195037841797,
446
+ 57.93142032623291,
447
+ 621.7998056411743,
448
+ 151.6710205078125,
449
+ 55.90662956237793,
450
+ 310.18284845352173,
451
+ 0.0,
452
+ 37.39265060424805,
453
+ 142.64961051940918,
454
+ 86.32072973251343,
455
+ 167.73473930358887,
456
+ 135.1251916885376,
457
+ 67.87245082855225,
458
+ 25.777999877929688,
459
+ 82.70090961456299,
460
+ 160.77113008499146,
461
+ 0.0,
462
+ 109.31087112426758,
463
+ 36.81955051422119,
464
+ 21.341699600219727,
465
+ 39.508570194244385,
466
+ 0.0,
467
+ 12.186599731445312,
468
+ 52.13583946228027,
469
+ 242.86930990219116,
470
+ 0.0,
471
+ 27.03380012512207,
472
+ 11.51550006866455,
473
+ 102.65280055999756,
474
+ 8.523859977722168,
475
+ 105.87909126281738,
476
+ 0.0,
477
+ 191.5287847518921,
478
+ 16.16029930114746,
479
+ 43.0986704826355,
480
+ 0.0,
481
+ 54.736299991607666,
482
+ 145.84991836547852,
483
+ 62.068660736083984,
484
+ 72.52587032318115,
485
+ 81.85652828216553,
486
+ 25.7001895904541,
487
+ 36.71660041809082,
488
+ 78.73716068267822,
489
+ 145.95945167541504,
490
+ 146.47522068023682,
491
+ 23.559300422668457,
492
+ 39.53977966308594,
493
+ 194.42743015289307,
494
+ 66.81133842468262,
495
+ 0.0,
496
+ 156.6984510421753,
497
+ 671.7460441589355,
498
+ 38.70531988143921,
499
+ 0.0,
500
+ 356.6153998374939,
501
+ 0.0,
502
+ 0.0,
503
+ 166.1197419166565,
504
+ 0.0,
505
+ 73.76784992218018,
506
+ 82.50808954238892,
507
+ 249.50656414031982,
508
+ 21.96009922027588,
509
+ 43.69997024536133,
510
+ 0.0,
511
+ 95.96379089355469,
512
+ 80.70125961303711,
513
+ 0.0,
514
+ 0.0,
515
+ 31.88983964920044,
516
+ 301.3817310333252,
517
+ 0.0,
518
+ 15.77073049545288,
519
+ 396.3671169281006,
520
+ 83.96024990081787,
521
+ 265.5281705856323,
522
+ 47.332489013671875,
523
+ 0.0,
524
+ 268.84939098358154,
525
+ 58.15328025817871,
526
+ 31.172239780426025,
527
+ 30.765819549560547,
528
+ 10.469799995422363,
529
+ 16.379559993743896,
530
+ 28.163670539855957,
531
+ 199.17678022384644,
532
+ 112.94913101196289,
533
+ 5.905869960784912,
534
+ 719.0067505836487,
535
+ 157.29250049591064,
536
+ 92.6033205986023,
537
+ 73.79398918151855,
538
+ 24.25756072998047,
539
+ 0.0,
540
+ 31.15705966949463,
541
+ 50.47894048690796,
542
+ 73.0004301071167,
543
+ 131.88961124420166,
544
+ 0.0,
545
+ 44.40921926498413,
546
+ 59.08494997024536,
547
+ 60.722700119018555,
548
+ 108.21477127075195,
549
+ 78.56892967224121,
550
+ 486.87088108062744,
551
+ 235.95975875854492,
552
+ 1809.188328742981,
553
+ 396.9979257583618,
554
+ 441.098051071167,
555
+ 218.83313035964966,
556
+ 265.3398394584656,
557
+ 595.3824620246887,
558
+ 6126.337133407593,
559
+ 3245.946928501129,
560
+ 170.21856021881104,
561
+ 262.3172616958618,
562
+ 98.2627010345459,
563
+ 146.45634078979492,
564
+ 135.70992946624756,
565
+ 34.09130001068115,
566
+ 14156.531812667847,
567
+ 227.55861043930054,
568
+ 121.6160798072815,
569
+ 409.0565061569214,
570
+ 282.5465121269226,
571
+ 481.5555577278137,
572
+ 291.560200214386,
573
+ 797.986575126648,
574
+ 246.7717628479004,
575
+ 6129.707794189453,
576
+ 957.9258012771606,
577
+ 4484.775461196899,
578
+ 5722.659900188446,
579
+ 393.6506414413452,
580
+ 882.6219139099121,
581
+ 264.54289960861206,
582
+ 79.82537126541138,
583
+ 228.20479917526245,
584
+ 155.19043970108032,
585
+ 319.6992588043213,
586
+ 391.5327887535095,
587
+ 2005.5544757843018,
588
+ 0.0,
589
+ 1028.816568851471,
590
+ 577.8704214096069,
591
+ 159.98183917999268,
592
+ 138.31745052337646,
593
+ 115.26242113113403,
594
+ 117.50687980651855,
595
+ 0.0,
596
+ 270.78229904174805,
597
+ 300.6347818374634,
598
+ 164.85750007629395,
599
+ 542.5208883285522,
600
+ 10002.710669994354,
601
+ 502.5058374404907,
602
+ 6619.406281471252,
603
+ 194.39686965942383,
604
+ 0.0,
605
+ 239.30037021636963,
606
+ 129.93587112426758,
607
+ 149.23295974731445,
608
+ 57.12141132354736,
609
+ 152.30589962005615,
610
+ 590.8979144096375,
611
+ 125.51728057861328,
612
+ 216.1852297782898,
613
+ 4445.603507041931,
614
+ 0.0,
615
+ 97.60689973831177,
616
+ 497.5633420944214,
617
+ 699.1335229873657,
618
+ 159.68335962295532,
619
+ 127.93899154663086,
620
+ 148.00423860549927,
621
+ 385.3561215400696,
622
+ 1255.3204145431519,
623
+ 170.33005905151367,
624
+ 564.577874660492,
625
+ 1513.99400806427,
626
+ 254.163161277771,
627
+ 782.5869626998901,
628
+ 166.38124132156372,
629
+ 4800.666547775269,
630
+ 271.63431215286255,
631
+ 225.10281944274902,
632
+ 674.5281610488892,
633
+ 198.04610967636108,
634
+ 4262.1786432266235,
635
+ 0.0,
636
+ 0.0,
637
+ 749.2932777404785,
638
+ 50.16440010070801,
639
+ 350.71588039398193,
640
+ 169.4644889831543,
641
+ 3843.8212938308716,
642
+ 0.0,
643
+ 0.0,
644
+ 1463.2607378959656,
645
+ 0.0,
646
+ 914.5419778823853,
647
+ 213.03434944152832,
648
+ 32.90106964111328,
649
+ 119.6264705657959,
650
+ 137.204270362854,
651
+ 359.72862100601196,
652
+ 75.62465047836304,
653
+ 446.62164974212646,
654
+ 105.61136054992676,
655
+ 2787.228641986847,
656
+ 311.6961917877197,
657
+ 156.06305074691772,
658
+ 1498.6027584075928,
659
+ 185.69973182678223,
660
+ 147.8509397506714,
661
+ 12.531700134277344,
662
+ 0.0,
663
+ 192.53613948822021,
664
+ 424.5432171821594,
665
+ 259.268039226532,
666
+ 175.13502979278564,
667
+ 281.5383825302124,
668
+ 299.1759967803955,
669
+ 227.893488407135,
670
+ 136.72871112823486,
671
+ 416.3120012283325,
672
+ 115.03175830841064,
673
+ 0.0,
674
+ 144.02852058410645,
675
+ 208.2749309539795,
676
+ 160.34006214141846,
677
+ 109.58282947540283,
678
+ 1500.150812625885,
679
+ 4945.450592041016,
680
+ 2852.855231285095,
681
+ 881.7318058013916,
682
+ 397.0553340911865,
683
+ 315.55763959884644,
684
+ 2086.7152404785156,
685
+ 1611.37087059021,
686
+ 2103.3109679222107,
687
+ 3135.3377957344055,
688
+ 2692.6771001815796,
689
+ 4584.85631608963,
690
+ 1700.0699429512024,
691
+ 883.6995916366577,
692
+ 33464.33708667755,
693
+ 574.8801603317261,
694
+ 2229.160650253296,
695
+ 379.5017247200012,
696
+ 905.5721397399902,
697
+ 493.963942527771,
698
+ 4049.96994638443,
699
+ 189.95257091522217,
700
+ 61.00449848175049,
701
+ 450.8264832496643,
702
+ 398.1711621284485,
703
+ 38847.667073726654,
704
+ 1835.184115409851,
705
+ 2697.096595287323,
706
+ 4710.6771783828735,
707
+ 5588.210665225983,
708
+ 1004.0054593086243,
709
+ 652.6680641174316,
710
+ 2031.7795896530151,
711
+ 367.2168278694153,
712
+ 2698.1613121032715,
713
+ 591.61465883255,
714
+ 448.26813650131226,
715
+ 849.9976563453674,
716
+ 8368.735646724701,
717
+ 414.3280692100525,
718
+ 3544.0216879844666,
719
+ 679.3534464836121,
720
+ 247.58060026168823,
721
+ 402.0281286239624,
722
+ 5822.276999950409,
723
+ 1743.6888279914856,
724
+ 2081.8095812797546,
725
+ 1696.2736263275146,
726
+ 197.28233861923218,
727
+ 3321.6009736061096,
728
+ 2298.3414697647095,
729
+ 2910.3161034584045,
730
+ 296.4575996398926,
731
+ 14755.747835159302,
732
+ 6977.302089691162,
733
+ 3608.7710394859314,
734
+ 289.08115005493164,
735
+ 2645.5259099006653,
736
+ 158.54701232910156,
737
+ 490.0809507369995,
738
+ 1880.1874709129333,
739
+ 1493.8953075408936,
740
+ 609.5897555351257,
741
+ 462.8165135383606,
742
+ 243.31624794006348,
743
+ 150.1076784133911,
744
+ 6197.5719475746155,
745
+ 1036.8616194725037,
746
+ 5302.397746086121,
747
+ 1388.753752708435,
748
+ 2091.038170814514,
749
+ 785.7442808151245,
750
+ 377.4342908859253,
751
+ 3640.3371028900146,
752
+ 1029.8467602729797,
753
+ 296.86861085891724,
754
+ 1221.5854263305664,
755
+ 535.2803363800049,
756
+ 2508.307864189148,
757
+ 3831.0581674575806,
758
+ 2263.3348484039307,
759
+ 926.5323433876038,
760
+ 8959.179275035858,
761
+ 309.04264068603516,
762
+ 1767.5786666870117,
763
+ 2107.6189522743225,
764
+ 155.21375036239624,
765
+ 378.6039876937866,
766
+ 2220.862048149109,
767
+ 1505.2828221321106,
768
+ 517.8384418487549,
769
+ 4313.928272247314,
770
+ 342.4098491668701,
771
+ 1310.0776271820068,
772
+ 434.5597867965698,
773
+ 2071.2271361351013,
774
+ 0.0,
775
+ 8595.476936340332,
776
+ 202.46072053909302,
777
+ 366.71736097335815,
778
+ 7074.809521198273,
779
+ 6.880340099334717,
780
+ 1959.3085498809814,
781
+ 636.0715098381042,
782
+ 9.84004020690918,
783
+ 386.9805417060852,
784
+ 2382.4822087287903,
785
+ 2317.9521684646606,
786
+ 2793.7392020225525,
787
+ 1188.6612939834595,
788
+ 933.1099715232849,
789
+ 4565.712460041046,
790
+ 14641.29742860794,
791
+ 15552.311092853546,
792
+ 56185.89445209503,
793
+ 97331.36661911011,
794
+ 87548.01149320602,
795
+ 521853.7248663902,
796
+ 2643.261353492737,
797
+ 20220.717566013336,
798
+ 79148.93348503113,
799
+ 17449.243332386017,
800
+ 13258.27445936203,
801
+ 6109.533164024353,
802
+ 6781.56981420517,
803
+ 3942.6140484809875,
804
+ 8469.07410955429,
805
+ 40318.94767665863,
806
+ 156345.23027658463,
807
+ 12197.998657226562,
808
+ 22888.345291614532,
809
+ 10946.28234910965,
810
+ 204263.674387455,
811
+ 229631.36437797546,
812
+ 1945.9702520370483,
813
+ 3069.6773653030396,
814
+ 6425.405041217804,
815
+ 508.55564069747925,
816
+ 8993.14672756195,
817
+ 0.0,
818
+ 0.0,
819
+ 0.0
820
+ ]
821
+ }
models/classifier/lightgbm_model.txt ADDED
The diff for this file is too large to render. See raw diff
 
models/classifier/scaler.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:347b85c4f3e4bcbda0599f607a1ad5194c01655baca73b6e2ee72a9ba50dcf84
3
+ size 13207
src/__init__.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hybrid Document Forgery Detection & Localization System
3
+
4
+ A robust hybrid (Deep Learning + Classical ML) system for multi-type
5
+ document forgery detection and localization.
6
+
7
+ Architecture:
8
+ - Deep Learning: MobileNetV3-Small + UNet-Lite for pixel-level localization
9
+ - Classical ML: LightGBM for interpretable forgery classification
10
+ """
11
+
12
+ __version__ = "1.0.0"
13
+
14
+ from .config import get_config
15
+ from .models import get_model, get_loss_function
16
+ from .data import get_dataset
17
+ from .features import get_feature_extractor, get_mask_refiner, get_region_extractor
18
+ from .training import get_trainer, get_metrics_tracker
19
+ from .inference import get_pipeline
20
+
21
+ __all__ = [
22
+ 'get_config',
23
+ 'get_model',
24
+ 'get_loss_function',
25
+ 'get_dataset',
26
+ 'get_feature_extractor',
27
+ 'get_mask_refiner',
28
+ 'get_region_extractor',
29
+ 'get_trainer',
30
+ 'get_metrics_tracker',
31
+ 'get_pipeline'
32
+ ]
src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (993 Bytes). View file
 
src/config/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Configuration module"""
2
+
3
+ from .config_loader import Config, get_config
4
+
5
+ __all__ = ['Config', 'get_config']
src/config/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (291 Bytes). View file
 
src/config/__pycache__/config_loader.cpython-312.pyc ADDED
Binary file (5.42 kB). View file
 
src/config/config_loader.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration loader for Hybrid Document Forgery Detection System
3
+ """
4
+
5
+ import yaml
6
+ from pathlib import Path
7
+ from typing import Dict, Any
8
+
9
+
10
class Config:
    """YAML-backed configuration manager with dot-notation lookup helpers."""

    def __init__(self, config_path: str = "config.yaml"):
        """
        Load configuration from YAML file

        Args:
            config_path: Path to configuration file
        """
        self.config_path = Path(config_path)
        self.config = self._load_config()

    def _load_config(self) -> Dict[str, Any]:
        """Load YAML configuration.

        Returns:
            Parsed configuration dictionary.

        Raises:
            FileNotFoundError: If the configuration file does not exist.
        """
        if not self.config_path.exists():
            raise FileNotFoundError(f"Config file not found: {self.config_path}")

        with open(self.config_path, 'r') as f:
            return yaml.safe_load(f)

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get configuration value using dot notation

        Args:
            key: Configuration key (e.g., 'model.encoder.name')
            default: Default value if key not found

        Returns:
            Configuration value, or ``default`` when any path segment
            is missing or a non-dict is encountered mid-path.
        """
        value = self.config
        for part in key.split('.'):
            if not (isinstance(value, dict) and part in value):
                return default
            value = value[part]
        return value

    def get_dataset_config(self, dataset_name: str) -> Dict[str, Any]:
        """
        Get dataset-specific configuration

        Args:
            dataset_name: Dataset name (doctamper, rtm, casia, receipts)

        Returns:
            Dataset configuration dictionary ({} when not configured)
        """
        # ROBUSTNESS FIX: the original indexed self.config['data']['datasets']
        # directly and raised KeyError when either section was absent; reuse
        # the tolerant dot-path lookup instead.
        datasets = self.get('data.datasets', {})
        if not isinstance(datasets, dict):
            return {}
        return datasets.get(dataset_name, {})

    def has_pixel_mask(self, dataset_name: str) -> bool:
        """Check if dataset has pixel-level masks"""
        return self.get_dataset_config(dataset_name).get('has_pixel_mask', False)

    def should_skip_deskew(self, dataset_name: str) -> bool:
        """Check if deskewing should be skipped for dataset"""
        return self.get_dataset_config(dataset_name).get('skip_deskew', False)

    def should_skip_denoising(self, dataset_name: str) -> bool:
        """Check if denoising should be skipped for dataset"""
        return self.get_dataset_config(dataset_name).get('skip_denoising', False)

    def get_min_region_area(self, dataset_name: str) -> float:
        """Get minimum region area threshold for dataset (default 0.001)"""
        return self.get_dataset_config(dataset_name).get('min_region_area', 0.001)

    def should_compute_localization_metrics(self, dataset_name: str) -> bool:
        """Check if localization metrics should be computed for dataset"""
        # Same tolerant lookup as above instead of self.config['metrics'][...],
        # which raised KeyError when the 'metrics' section was missing.
        compute_config = self.get('metrics.compute_localization', {})
        if not isinstance(compute_config, dict):
            return False
        return compute_config.get(dataset_name, False)

    def __getitem__(self, key: str) -> Any:
        """Allow dictionary-style access (dot notation supported)"""
        return self.get(key)

    def __repr__(self) -> str:
        return f"Config(path={self.config_path})"
98
+
99
+
100
# Lazily-created module-level configuration singleton
_config = None


def get_config(config_path: str = "config.yaml") -> Config:
    """
    Return the process-wide configuration singleton.

    The configuration file is parsed only once; every later call returns
    the cached instance (``config_path`` is ignored after the first call).

    Args:
        config_path: Path to configuration file

    Returns:
        Config instance
    """
    global _config
    if _config is not None:
        return _config
    _config = Config(config_path)
    return _config
src/data/__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Data module"""
2
+
3
+ from .preprocessing import DocumentPreprocessor, preprocess_image
4
+ from .augmentation import DatasetAwareAugmentation, get_augmentation
5
+ from .datasets import (
6
+ DocTamperDataset,
7
+ RTMDataset,
8
+ CASIADataset,
9
+ ReceiptsDataset,
10
+ get_dataset
11
+ )
12
+
13
+ __all__ = [
14
+ 'DocumentPreprocessor',
15
+ 'preprocess_image',
16
+ 'DatasetAwareAugmentation',
17
+ 'get_augmentation',
18
+ 'DocTamperDataset',
19
+ 'RTMDataset',
20
+ 'CASIADataset',
21
+ 'ReceiptsDataset',
22
+ 'get_dataset'
23
+ ]
src/data/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (572 Bytes). View file
 
src/data/__pycache__/augmentation.cpython-312.pyc ADDED
Binary file (5.94 kB). View file
 
src/data/__pycache__/datasets.cpython-312.pyc ADDED
Binary file (21.2 kB). View file
 
src/data/__pycache__/preprocessing.cpython-312.pyc ADDED
Binary file (9.38 kB). View file
 
src/data/augmentation.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dataset-aware augmentation for training
3
+ """
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import albumentations as A
8
+ from albumentations.pytorch import ToTensorV2
9
+ from typing import Dict, Any, Optional
10
+
11
+
12
class DatasetAwareAugmentation:
    """Dataset-aware augmentation pipeline built from the config file."""

    def __init__(self, config, dataset_name: str, is_training: bool = True):
        """
        Initialize augmentation pipeline

        Args:
            config: Configuration object
            dataset_name: Dataset name
            is_training: Whether in training mode
        """
        self.config = config
        self.dataset_name = dataset_name
        self.is_training = is_training
        # Compose the full pipeline once, up front
        self.transform = self._build_transform()

    @staticmethod
    def _common_aug(aug_config):
        """Translate one 'common' augmentation config entry into a transform.

        Returns None for unrecognized types (they are silently skipped).
        """
        aug_type = aug_config.get('type')
        prob = aug_config.get('prob', 0.5)

        if aug_type == 'noise':
            return A.GaussNoise(var_limit=(10.0, 50.0), p=prob)
        if aug_type == 'motion_blur':
            return A.MotionBlur(blur_limit=7, p=prob)
        if aug_type == 'jpeg_compression':
            quality_range = aug_config.get('quality', [60, 95])
            return A.ImageCompression(quality_lower=quality_range[0],
                                      quality_upper=quality_range[1],
                                      p=prob)
        if aug_type == 'lighting':
            return A.OneOf([
                A.RandomBrightnessContrast(p=1.0),
                A.RandomGamma(p=1.0),
                A.HueSaturationValue(p=1.0),
            ], p=prob)
        if aug_type == 'perspective':
            return A.Perspective(scale=(0.02, 0.05), p=prob)
        return None

    @staticmethod
    def _receipt_aug(aug_config):
        """Translate one receipts-specific config entry into a transform."""
        aug_type = aug_config.get('type')
        prob = aug_config.get('prob', 0.5)

        if aug_type == 'stain':
            # Simulate stains using random shadow blobs
            return A.RandomShadow(
                shadow_roi=(0, 0, 1, 1),
                num_shadows_lower=1,
                num_shadows_upper=3,
                shadow_dimension=5,
                p=prob
            )
        if aug_type == 'fold':
            # Simulate folds using grid distortion
            return A.GridDistortion(num_steps=5, distort_limit=0.1, p=prob)
        return None

    def _build_transform(self) -> A.Compose:
        """Assemble the albumentations pipeline from the configuration."""
        steps = []

        if self.is_training and self.config.get('augmentation.enabled', True):
            for entry in self.config.get('augmentation.common', []):
                transform = self._common_aug(entry)
                if transform is not None:
                    steps.append(transform)

            # Receipts get extra, dataset-specific degradations
            if self.dataset_name == 'receipts':
                for entry in self.config.get('augmentation.receipts', []):
                    transform = self._receipt_aug(entry)
                    if transform is not None:
                        steps.append(transform)

        # Tensor conversion always comes last
        steps.append(ToTensorV2())

        return A.Compose(steps, additional_targets={'mask': 'mask'})

    def __call__(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Dict[str, Any]:
        """
        Apply augmentation

        Args:
            image: Input image (H, W, 3), float32, [0, 1]
            mask: Optional mask (H, W), uint8, {0, 1}

        Returns:
            Dictionary with 'image' and optionally 'mask'
        """
        # albumentations operates on uint8 input
        image_u8 = (image * 255).astype(np.uint8)

        if mask is None:
            result = self.transform(image=image_u8)
            result['image'] = result['image'].float() / 255.0
            return result

        mask_u8 = (mask * 255).astype(np.uint8)
        result = self.transform(image=image_u8, mask=mask_u8)
        # Back to float in [0, 1]; mask gains a channel dim -> (1, H, W)
        result['image'] = result['image'].float() / 255.0
        result['mask'] = (result['mask'].float() / 255.0).unsqueeze(0)
        return result
136
+
137
+
138
def get_augmentation(config, dataset_name: str, is_training: bool = True) -> DatasetAwareAugmentation:
    """
    Build the augmentation pipeline for a dataset.

    Args:
        config: Configuration object
        dataset_name: Dataset name
        is_training: Whether in training mode

    Returns:
        Augmentation pipeline
    """
    return DatasetAwareAugmentation(config, dataset_name, is_training)
src/data/datasets.py ADDED
@@ -0,0 +1,541 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dataset loaders for document forgery detection
3
+ Implements Critical Fix #7: Image-level train/test splits
4
+ """
5
+
6
+ import os
7
+ import lmdb
8
+ import cv2
9
+ import numpy as np
10
+ import torch
11
+ from torch.utils.data import Dataset
12
+ from pathlib import Path
13
+ from typing import Tuple, Optional, List
14
+ import json
15
+ from PIL import Image
16
+
17
+ from .preprocessing import DocumentPreprocessor
18
+ from .augmentation import DatasetAwareAugmentation
19
+
20
+
21
class DocTamperDataset(Dataset):
    """
    DocTamper dataset loader (LMDB-based)

    Reads JPEG-encoded images and their tamper masks from an LMDB store
    keyed ``image-XXXXXXXXX`` / ``label-XXXXXXXXX``.  Supports chunked
    loading for RAM constraints and lazy LMDB initialization so instances
    can be pickled into DataLoader worker processes.
    """

    def __init__(self,
                 config,
                 split: str = 'train',
                 chunk_start: float = 0.0,
                 chunk_end: float = 1.0):
        """
        Initialize DocTamper dataset

        Args:
            config: Configuration object
            split: 'train' or 'val'
            chunk_start: Start ratio for chunked training (0.0 to 1.0)
            chunk_end: End ratio for chunked training (0.0 to 1.0)
        """
        self.config = config
        self.split = split
        self.dataset_name = 'doctamper'

        # Get dataset path
        dataset_config = config.get_dataset_config(self.dataset_name)
        self.data_path = Path(dataset_config['path'])

        # Map split to actual folder names (unknown splits fall back to train)
        if split in ('val', 'test'):
            lmdb_folder = 'DocTamperV1-TestingSet'
        else:
            lmdb_folder = 'DocTamperV1-TrainingSet'

        self.lmdb_path = str(self.data_path / lmdb_folder)

        if not Path(self.lmdb_path).exists():
            raise FileNotFoundError(f"LMDB folder not found: {self.lmdb_path}")

        # Probe the store once for the sample count; entries come in
        # image/label pairs, hence the // 2.  Closed in a finally block so
        # a failing stat() cannot leak the handle.
        temp_env = lmdb.open(self.lmdb_path, readonly=True, lock=False)
        try:
            with temp_env.begin() as txn:
                self.length = txn.stat()['entries'] // 2
        finally:
            temp_env.close()

        # LAZY INITIALIZATION: the real env is opened in the `env` property;
        # an open lmdb.Environment cannot be pickled into worker processes.
        self._env = None

        # Critical Fix #7: Image-level chunking (not region-level)
        self.chunk_start = int(self.length * chunk_start)
        self.chunk_end = int(self.length * chunk_end)
        self.chunk_length = self.chunk_end - self.chunk_start

        print(f"DocTamper {split}: Total={self.length}, "
              f"Chunk=[{self.chunk_start}:{self.chunk_end}], "
              f"Length={self.chunk_length}")

        # Initialize preprocessor and augmentation
        self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
        self.augmentation = DatasetAwareAugmentation(
            config,
            self.dataset_name,
            is_training=(split == 'train')
        )

    @property
    def env(self):
        """Lazy LMDB environment initialization for multiprocessing compatibility"""
        if self._env is None:
            self._env = lmdb.open(self.lmdb_path, readonly=True, lock=False,
                                  max_readers=32, readahead=False)
        return self._env

    def __len__(self) -> int:
        return self.chunk_length

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Get one (image, mask, metadata) sample.

        Missing or undecodable LMDB entries are skipped by probing up to
        ``max_attempts`` subsequent indices; if none can be loaded a black
        dummy sample is returned so training can continue.

        Args:
            idx: Index within chunk

        Returns:
            image: (3, H, W) tensor
            mask: (1, H, W) tensor
            metadata: Dictionary with additional info
        """
        max_attempts = 10
        original_idx = idx
        global_idx = original_idx
        image = None
        mask = None

        for _attempt in range(max_attempts):
            try:
                # Map chunk index to global index
                global_idx = self.chunk_start + idx

                # DocTamper key format: image-XXXXXXXXX / label-XXXXXXXXX
                # (9 digits, dash separator)
                with self.env.begin() as txn:
                    img_buf = txn.get(f'image-{global_idx:09d}'.encode())
                    label_buf = txn.get(f'label-{global_idx:09d}'.encode())

                if img_buf is None:
                    # Sample missing, try next index
                    idx = (idx + 1) % self.chunk_length
                    continue

                # Decode image
                img_array = np.frombuffer(img_buf, dtype=np.uint8)
                decoded = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
                if decoded is None:
                    # Failed to decode, try next index
                    idx = (idx + 1) % self.chunk_length
                    continue
                image = decoded

                # Decode label/mask
                if label_buf is not None:
                    label_array = np.frombuffer(label_buf, dtype=np.uint8)
                    mask = cv2.imdecode(label_array, cv2.IMREAD_GRAYSCALE)
                    if mask is None:
                        # Label might be raw bytes, create empty mask
                        mask = np.zeros(image.shape[:2], dtype=np.uint8)
                else:
                    # No mask found - create empty mask
                    mask = np.zeros(image.shape[:2], dtype=np.uint8)

                # Successfully loaded
                break

            except Exception:
                # Read/decode error: advance to the next index and retry
                idx = (idx + 1) % self.chunk_length

        if image is None:
            # BUG FIX: the original only built a dummy sample when the *last*
            # attempt raised an exception.  When every attempt hit the
            # "missing sample" / "failed decode" `continue` path instead, the
            # loop finished with `image` unbound and the method crashed with
            # NameError.  Build the dummy for every exhausted-retries case.
            print(f"Warning: Could not load sample at idx {original_idx}, creating dummy sample")
            image = np.zeros((384, 384, 3), dtype=np.float32)
            mask = np.zeros((384, 384), dtype=np.uint8)
            global_idx = original_idx

        # Preprocess
        image, mask = self.preprocessor(image, mask)

        # Augment (returns torch tensors)
        augmented = self.augmentation(image, mask)
        image = augmented['image']
        mask = augmented['mask']

        # Metadata
        metadata = {
            'dataset': self.dataset_name,
            'index': global_idx,
            'has_pixel_mask': True
        }

        return image, mask, metadata

    def __del__(self):
        """Close LMDB environment"""
        if hasattr(self, '_env') and self._env is not None:
            self._env.close()
191
+
192
+
193
+
194
class RTMDataset(Dataset):
    """Real Text Manipulation dataset loader (VOC-style directory layout)."""

    def __init__(self, config, split: str = 'train'):
        """
        Initialize RTM dataset

        Args:
            config: Configuration object
            split: 'train' or 'test'
        """
        self.config = config
        self.split = split
        self.dataset_name = 'rtm'

        # Get dataset path
        dataset_config = config.get_dataset_config(self.dataset_name)
        self.data_path = Path(dataset_config['path'])

        # Load split file: one image id per line.
        # ROBUSTNESS FIX: skip blank lines (a trailing newline previously
        # produced an empty image id that crashed at load time).
        split_file = self.data_path / f'{split}.txt'
        with open(split_file, 'r') as f:
            self.image_ids = [line.strip() for line in f if line.strip()]

        self.images_dir = self.data_path / 'JPEGImages'
        self.masks_dir = self.data_path / 'SegmentationClass'

        print(f"RTM {split}: {len(self.image_ids)} images")

        # Initialize preprocessor and augmentation
        self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
        self.augmentation = DatasetAwareAugmentation(
            config,
            self.dataset_name,
            is_training=(split == 'train')
        )

    def __len__(self) -> int:
        return len(self.image_ids)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """Get one (image, mask, metadata) sample."""
        image_id = self.image_ids[idx]

        # Load image
        img_path = self.images_dir / f'{image_id}.jpg'
        image = cv2.imread(str(img_path))
        if image is None:
            # ROBUSTNESS FIX: cv2.imread returns None on missing/corrupt
            # files; fail loudly instead of crashing later with an opaque
            # TypeError on `mask > 0`.
            raise FileNotFoundError(f"Could not read image: {img_path}")

        # Load mask
        mask_path = self.masks_dir / f'{image_id}.png'
        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        # Binarize mask (any non-zero class id counts as tampered)
        mask = (mask > 0).astype(np.uint8)

        # Preprocess
        image, mask = self.preprocessor(image, mask)

        # Augment
        augmented = self.augmentation(image, mask)
        image = augmented['image']
        mask = augmented['mask']

        # Metadata
        metadata = {
            'dataset': self.dataset_name,
            'image_id': image_id,
            'has_pixel_mask': True
        }

        return image, mask, metadata
265
+
266
+
267
class CASIADataset(Dataset):
    """
    CASIA v1.0 dataset loader

    Image-level labels only (no pixel masks).
    Implements Critical Fix #6: CASIA image-level handling.
    """

    def __init__(self, config, split: str = 'train'):
        """
        Initialize CASIA dataset

        Args:
            config: Configuration object
            split: 'train' or 'test'
        """
        self.config = config
        self.split = split
        self.dataset_name = 'casia'

        # Get dataset path
        dataset_config = config.get_dataset_config(self.dataset_name)
        self.data_path = Path(dataset_config['path'])

        # Authentic and tampered image folders
        self.authentic_dir = self.data_path / 'Au'
        self.tampered_dir = self.data_path / 'Tp'

        # REPRODUCIBILITY FIX: Path.glob yields files in arbitrary,
        # OS-dependent order.  Without sorting, the seeded permutation below
        # selects a *different* train/test split on each machine/filesystem,
        # silently leaking test images into training.
        authentic_images = sorted(
            list(self.authentic_dir.glob('*.jpg')) +
            list(self.authentic_dir.glob('*.png'))
        )
        tampered_images = sorted(
            list(self.tampered_dir.glob('*.jpg')) +
            list(self.tampered_dir.glob('*.png'))
        )

        # Pair every path with its image-level label
        self.samples = [(p, 0) for p in authentic_images]   # 0 = authentic
        self.samples += [(p, 1) for p in tampered_images]   # 1 = tampered

        # Critical Fix #7: Image-level split (80/20), deterministic seed
        np.random.seed(42)
        indices = np.random.permutation(len(self.samples))
        split_idx = int(len(self.samples) * 0.8)

        if split == 'train':
            indices = indices[:split_idx]
        else:
            indices = indices[split_idx:]

        self.samples = [self.samples[i] for i in indices]

        print(f"CASIA {split}: {len(self.samples)} images")

        # Initialize preprocessor and augmentation
        self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
        self.augmentation = DatasetAwareAugmentation(
            config,
            self.dataset_name,
            is_training=(split == 'train')
        )

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """Get one (image, mask, metadata) sample."""
        img_path, label = self.samples[idx]

        # Load image
        image = cv2.imread(str(img_path))
        if image is None:
            # ROBUSTNESS FIX: fail loudly on unreadable files instead of
            # crashing later with an opaque AttributeError.
            raise FileNotFoundError(f"Could not read image: {img_path}")

        # Critical Fix #6: image-level mask covering the entire image
        h, w = image.shape[:2]
        mask = np.full((h, w), label, dtype=np.uint8)

        # Preprocess
        image, mask = self.preprocessor(image, mask)

        # Augment
        augmented = self.augmentation(image, mask)
        image = augmented['image']
        mask = augmented['mask']

        # Metadata
        metadata = {
            'dataset': self.dataset_name,
            'image_path': str(img_path),
            'has_pixel_mask': False,  # Image-level only
            'label': label
        }

        return image, mask, metadata
360
+
361
+
362
class ReceiptsDataset(Dataset):
    """Find-It-Again receipts dataset loader (JSON bbox annotations)."""

    def __init__(self, config, split: str = 'train'):
        """
        Initialize receipts dataset

        Args:
            config: Configuration object
            split: 'train', 'val', or 'test'
        """
        self.config = config
        self.split = split
        self.dataset_name = 'receipts'

        # Get dataset path
        dataset_config = config.get_dataset_config(self.dataset_name)
        self.data_path = Path(dataset_config['path'])

        # Load split file (list of annotation dicts)
        split_file = self.data_path / f'{split}.json'
        with open(split_file, 'r') as f:
            self.annotations = json.load(f)

        print(f"Receipts {split}: {len(self.annotations)} images")

        # Initialize preprocessor and augmentation
        self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
        self.augmentation = DatasetAwareAugmentation(
            config,
            self.dataset_name,
            is_training=(split == 'train')
        )

    def __len__(self) -> int:
        return len(self.annotations)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """Get one (image, mask, metadata) sample."""
        ann = self.annotations[idx]

        # Load image
        img_path = self.data_path / ann['image_path']
        image = cv2.imread(str(img_path))
        if image is None:
            # ROBUSTNESS FIX: fail loudly on unreadable files instead of
            # crashing later with an opaque AttributeError.
            raise FileNotFoundError(f"Could not read image: {img_path}")

        # Rasterize bounding boxes into a binary mask
        h, w = image.shape[:2]
        mask = np.zeros((h, w), dtype=np.uint8)

        for bbox in ann.get('bboxes', []):
            x, y, w_box, h_box = bbox
            # ROBUSTNESS FIX: clamp the box to the image bounds.  With raw
            # coordinates, a negative x/y in the annotation wraps around via
            # numpy negative indexing and paints the wrong region.
            x0 = max(0, int(x))
            y0 = max(0, int(y))
            x1 = min(w, int(x) + int(w_box))
            y1 = min(h, int(y) + int(h_box))
            if x1 > x0 and y1 > y0:
                mask[y0:y1, x0:x1] = 1

        # Preprocess
        image, mask = self.preprocessor(image, mask)

        # Augment
        augmented = self.augmentation(image, mask)
        image = augmented['image']
        mask = augmented['mask']

        # Metadata
        metadata = {
            'dataset': self.dataset_name,
            'image_path': str(img_path),
            'has_pixel_mask': True
        }

        return image, mask, metadata
431
+
432
+
433
class FCDDataset(DocTamperDataset):
    """FCD (Forgery Classification Dataset) loader.

    Reuses DocTamperDataset's LMDB reading logic but points directly at a
    single LMDB directory and disables chunking (the dataset is small).
    """

    def __init__(self, config, split: str = 'train'):
        self.config = config
        self.split = split
        self.dataset_name = 'fcd'

        # Resolve the LMDB directory from the configuration
        dataset_config = config.get_dataset_config(self.dataset_name)
        self.data_path = Path(dataset_config['path'])
        self.lmdb_path = str(self.data_path)

        if not Path(self.lmdb_path).exists():
            raise FileNotFoundError(f"LMDB folder not found: {self.lmdb_path}")

        # Probe the store once for the sample count; entries come in
        # image/label pairs, hence the // 2.
        probe = lmdb.open(self.lmdb_path, readonly=True, lock=False)
        with probe.begin() as txn:
            self.length = txn.stat()['entries'] // 2
        probe.close()

        # Real environment is opened lazily by the inherited `env` property
        self._env = None

        # Whole dataset in a single chunk — FCD is small, no chunking needed
        self.chunk_start = 0
        self.chunk_end = self.length
        self.chunk_length = self.length

        print(f"FCD {split}: {self.length} samples")

        # Initialize preprocessor and augmentation
        self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
        self.augmentation = DatasetAwareAugmentation(
            config,
            self.dataset_name,
            is_training=(split == 'train')
        )
472
+
473
+
474
class SCDDataset(DocTamperDataset):
    """SCD (Splicing Classification Dataset) loader.

    Reuses DocTamperDataset's LMDB reading logic but points directly at a
    single LMDB directory and disables chunking (medium-sized dataset).
    """

    def __init__(self, config, split: str = 'train'):
        self.config = config
        self.split = split
        self.dataset_name = 'scd'

        # Resolve the LMDB directory from the configuration
        dataset_config = config.get_dataset_config(self.dataset_name)
        self.data_path = Path(dataset_config['path'])
        self.lmdb_path = str(self.data_path)

        if not Path(self.lmdb_path).exists():
            raise FileNotFoundError(f"LMDB folder not found: {self.lmdb_path}")

        # Probe the store once for the sample count; entries come in
        # image/label pairs, hence the // 2.
        probe = lmdb.open(self.lmdb_path, readonly=True, lock=False)
        with probe.begin() as txn:
            self.length = txn.stat()['entries'] // 2
        probe.close()

        # Real environment is opened lazily by the inherited `env` property
        self._env = None

        # Whole dataset in a single chunk — no chunking needed
        self.chunk_start = 0
        self.chunk_end = self.length
        self.chunk_length = self.length

        print(f"SCD {split}: {self.length} samples")

        # Initialize preprocessor and augmentation
        self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
        self.augmentation = DatasetAwareAugmentation(
            config,
            self.dataset_name,
            is_training=(split == 'train')
        )
512
+ )
513
+
514
+
515
def get_dataset(config, dataset_name: str, split: str = 'train', **kwargs) -> Dataset:
    """
    Factory function to get dataset

    Args:
        config: Configuration object
        dataset_name: Dataset name
        split: Data split
        **kwargs: Additional arguments (e.g., chunk_start, chunk_end)

    Returns:
        Dataset instance

    Raises:
        ValueError: If ``dataset_name`` is not a known dataset.
    """
    # Lazily-evaluated constructors: class names are only resolved when the
    # selected lambda is called, exactly like the equivalent if/elif chain.
    builders = {
        'doctamper': lambda: DocTamperDataset(config, split, **kwargs),
        'rtm': lambda: RTMDataset(config, split),
        'casia': lambda: CASIADataset(config, split),
        'receipts': lambda: ReceiptsDataset(config, split),
        'fcd': lambda: FCDDataset(config, split),
        'scd': lambda: SCDDataset(config, split),
    }

    try:
        builder = builders[dataset_name]
    except KeyError:
        raise ValueError(f"Unknown dataset: {dataset_name}") from None

    return builder()
src/data/preprocessing.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dataset-aware preprocessing for document forgery detection
3
+ Implements Critical Fix #1: Dataset-Aware Preprocessing
4
+ """
5
+
6
+ import cv2
7
+ import numpy as np
8
+ from typing import Tuple, Optional
9
+ import pywt
10
+ from scipy import ndimage
11
+
12
+
13
class DocumentPreprocessor:
    """Dataset-aware document preprocessing (Critical Fix #1).

    Pipeline: color conversion -> optional deskew -> resize -> normalize ->
    conditional denoising. Deskewing and denoising can be switched off per
    dataset through the configuration object.
    """

    def __init__(self, config, dataset_name: str):
        """
        Args:
            config: Configuration object.
            dataset_name: Dataset name driving the dataset-aware switches.
        """
        self.config = config
        self.dataset_name = dataset_name
        self.image_size = config.get('data.image_size', 384)
        self.noise_threshold = config.get('preprocessing.noise_threshold', 15.0)

        # Per-dataset toggles (Critical Fix #1)
        self.skip_deskew = config.should_skip_deskew(dataset_name)
        self.skip_denoising = config.should_skip_denoising(dataset_name)

    def __call__(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Run the full preprocessing pipeline on one image (and optional mask).

        Args:
            image: Input image (H, W[, C]); 3-channel input is assumed BGR
                   (cv2.imread convention) — TODO confirm against callers.
            mask: Optional ground-truth mask (H, W).

        Returns:
            Preprocessed float32 image in [0, 1] and the transformed mask.
        """
        # Bring every input into 3-channel RGB
        if image.ndim == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        elif image.shape[2] == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if not self.skip_deskew:
            image, mask = self._deskew(image, mask)

        image, mask = self._resize(image, mask)
        image = self._normalize(image)

        # Denoise only when the estimated noise exceeds the threshold
        if not self.skip_denoising and self._estimate_noise(image) > self.noise_threshold:
            image = self._denoise(image)

        return image, mask

    def _deskew(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Rotate the document upright using the median Hough-line angle.

        Args:
            image: RGB image.
            mask: Optional mask, rotated with the same matrix (nearest).

        Returns:
            Possibly rotated image and mask.
        """
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        edges = cv2.Canny(gray, 50, 150, apertureSize=3)
        lines = cv2.HoughLines(edges, 1, np.pi / 180, 200)

        if lines is None or len(lines) == 0:
            return image, mask

        # Median angle of all detected lines (degrees, relative to vertical)
        skew = np.median([(theta * 180 / np.pi) - 90 for _, theta in lines[:, 0]])

        # Skip negligible skews to avoid pointless interpolation
        if abs(skew) > 0.5:
            h, w = image.shape[:2]
            rot = cv2.getRotationMatrix2D((w // 2, h // 2), skew, 1.0)
            image = cv2.warpAffine(image, rot, (w, h),
                                   flags=cv2.INTER_CUBIC,
                                   borderMode=cv2.BORDER_REPLICATE)
            if mask is not None:
                # Nearest + zero border keeps the mask strictly binary
                mask = cv2.warpAffine(mask, rot, (w, h),
                                      flags=cv2.INTER_NEAREST,
                                      borderMode=cv2.BORDER_CONSTANT,
                                      borderValue=0)

        return image, mask

    def _resize(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Resize image (cubic) and mask (nearest) to the configured square size.

        Args:
            image: Input image.
            mask: Optional mask.

        Returns:
            Resized image and mask.
        """
        size = (self.image_size, self.image_size)
        image = cv2.resize(image, size, interpolation=cv2.INTER_CUBIC)
        if mask is not None:
            mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)
        return image, mask

    def _normalize(self, image: np.ndarray) -> np.ndarray:
        """Scale pixel values from [0, 255] to float32 [0, 1].

        Args:
            image: uint8 (or numeric) image.

        Returns:
            float32 image in [0, 1].
        """
        return image.astype(np.float32) / 255.0

    def _estimate_noise(self, image: np.ndarray) -> float:
        """Blend Laplacian-variance and wavelet (MAD) noise estimates.

        NOTE(review): the two estimates live on very different scales
        (Laplacian variance is dominated by image texture) — the average is
        a heuristic, confirm against the configured noise_threshold.

        Args:
            image: Normalized [0, 1] image.

        Returns:
            Scalar noise estimate.
        """
        if image.ndim == 3:
            gray = cv2.cvtColor((image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
        else:
            gray = (image * 255).astype(np.uint8)

        lap_var = cv2.Laplacian(gray, cv2.CV_64F).var()

        # Robust sigma from the diagonal detail band of a one-level DWT
        _, (_, _, diag) = pywt.dwt2(gray, 'db1')
        mad_sigma = np.median(np.abs(diag)) / 0.6745

        return (lap_var + mad_sigma) / 2.0

    def _denoise(self, image: np.ndarray) -> np.ndarray:
        """Median (3x3) then light Gaussian (sigma=0.8) smoothing.

        Args:
            image: Normalized [0, 1] image.

        Returns:
            Denoised float32 image in [0, 1].
        """
        as_uint8 = (image * 255).astype(np.uint8)
        smoothed = cv2.GaussianBlur(cv2.medianBlur(as_uint8, 3), (3, 3), 0.8)
        return smoothed.astype(np.float32) / 255.0
207
+
208
+
209
def preprocess_image(image: np.ndarray,
                     mask: Optional[np.ndarray] = None,
                     config = None,
                     dataset_name: str = 'default') -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """One-shot preprocessing helper.

    Builds a DocumentPreprocessor configured for *dataset_name* and applies
    it to a single image/mask pair.

    Args:
        image: Input image.
        mask: Optional mask.
        config: Configuration object.
        dataset_name: Dataset name.

    Returns:
        Preprocessed image and mask.
    """
    return DocumentPreprocessor(config, dataset_name)(image, mask)
src/features/__init__.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Features module"""
2
+
3
+ from .feature_extraction import (
4
+ DeepFeatureExtractor,
5
+ StatisticalFeatureExtractor,
6
+ FrequencyFeatureExtractor,
7
+ NoiseELAFeatureExtractor,
8
+ OCRFeatureExtractor,
9
+ HybridFeatureExtractor,
10
+ get_feature_extractor
11
+ )
12
+
13
+ from .region_extraction import (
14
+ MaskRefiner,
15
+ RegionExtractor,
16
+ get_mask_refiner,
17
+ get_region_extractor
18
+ )
19
+
20
+ __all__ = [
21
+ 'DeepFeatureExtractor',
22
+ 'StatisticalFeatureExtractor',
23
+ 'FrequencyFeatureExtractor',
24
+ 'NoiseELAFeatureExtractor',
25
+ 'OCRFeatureExtractor',
26
+ 'HybridFeatureExtractor',
27
+ 'get_feature_extractor',
28
+ 'MaskRefiner',
29
+ 'RegionExtractor',
30
+ 'get_mask_refiner',
31
+ 'get_region_extractor'
32
+ ]
src/features/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (691 Bytes). View file
 
src/features/__pycache__/feature_extraction.cpython-312.pyc ADDED
Binary file (22.6 kB). View file
 
src/features/__pycache__/region_extraction.cpython-312.pyc ADDED
Binary file (8.93 kB). View file
 
src/features/feature_extraction.py ADDED
@@ -0,0 +1,485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hybrid feature extraction for forgery detection
3
+ Implements Critical Fix #5: Feature Group Gating
4
+ """
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from typing import Dict, List, Optional, Tuple
11
+ from scipy import ndimage
12
+ from scipy.fftpack import dct
13
+ import pywt
14
+ from skimage.measure import regionprops, label
15
+ from skimage.filters import sobel
16
+
17
+
18
class DeepFeatureExtractor:
    """Pool decoder feature maps into a region-conditioned feature vector."""

    def __init__(self):
        """No state required; kept for interface symmetry with other extractors."""
        pass

    def extract(self,
                decoder_features: List[torch.Tensor],
                region_mask: np.ndarray) -> np.ndarray:
        """
        Masked global-average-pool each decoder feature map over the region.

        Args:
            decoder_features: List of decoder feature tensors, (B, C, H, W)
                              or (C, H, W) each.
            region_mask: Binary region mask (H, W).

        Returns:
            Concatenated per-channel means as a float32 vector.
        """
        pooled = []

        for fmap in decoder_features:
            # Work on numpy, detached from any graph
            if isinstance(fmap, torch.Tensor):
                fmap = fmap.detach().cpu().numpy()

            # Drop the batch dimension if present (first sample only)
            if fmap.ndim == 4:
                fmap = fmap[0]

            # Bring the mask to this map's spatial resolution
            fh, fw = fmap.shape[1:]
            scaled = cv2.resize(region_mask.astype(np.float32), (fw, fh)) > 0.5

            if scaled.sum() > 0:
                # Per-channel mean over the masked pixels only
                for ch in range(fmap.shape[0]):
                    pooled.append(fmap[ch][scaled].mean())
            else:
                # Degenerate mask at this scale: fall back to global averages
                pooled.extend(fmap.mean(axis=(1, 2)).tolist())

        return np.array(pooled, dtype=np.float32)
65
+
66
+
67
class StatisticalFeatureExtractor:
    """Extract statistical and shape features from regions.

    Output order: area, perimeter, aspect_ratio, solidity, eccentricity,
    entropy (6 features, matching HybridFeatureExtractor.get_feature_names).
    """

    def __init__(self):
        """Initialize statistical feature extractor (stateless)."""
        pass

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray) -> np.ndarray:
        """
        Extract statistical and shape features.

        Args:
            image: Input image (H, W, 3) normalized [0, 1]
            region_mask: Binary region mask (H, W)

        Returns:
            Statistical feature vector (length 6)
        """
        features = []

        # Label the mask; only the first component is measured
        # (masks are expected to hold a single region by construction)
        labeled_mask = label(region_mask)
        props = regionprops(labeled_mask)

        if len(props) > 0:
            prop = props[0]

            # Area and perimeter
            features.append(prop.area)
            features.append(prop.perimeter)

            # Aspect ratio in [0, 1] (minor/major); degenerate regions -> 1.0
            if prop.major_axis_length > 0:
                aspect_ratio = prop.minor_axis_length / prop.major_axis_length
            else:
                aspect_ratio = 1.0
            features.append(aspect_ratio)

            # Solidity
            features.append(prop.solidity)

            # Eccentricity
            features.append(prop.eccentricity)

            # Entropy of the region's grayscale intensity histogram
            if len(image.shape) == 3:
                gray = cv2.cvtColor((image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
            else:
                gray = (image * 255).astype(np.uint8)

            region_pixels = gray[region_mask > 0]
            if len(region_pixels) > 0:
                hist, _ = np.histogram(region_pixels, bins=256, range=(0, 256))
                # Fix: the epsilon belongs in the denominator as a guard
                # against an empty histogram; previously operator precedence
                # added 1e-8 to every bin AFTER the division instead.
                hist = hist / (hist.sum() + 1e-8)
                entropy = -np.sum(hist * np.log2(hist + 1e-8))
            else:
                entropy = 0.0
            features.append(entropy)
        else:
            # Empty mask: neutral defaults in the same 6-feature order
            features.extend([0, 0, 1.0, 0, 0, 0])

        return np.array(features, dtype=np.float32)
132
+
133
+
134
class FrequencyFeatureExtractor:
    """Extract frequency-domain (DCT + wavelet) features from a region."""

    # 3 DCT features + 4 wavelet sub-band energies + 3 wavelet entropies;
    # must match the 10 frequency names in HybridFeatureExtractor.get_feature_names
    NUM_FEATURES = 10

    def __init__(self):
        """Initialize frequency feature extractor (stateless)."""
        pass

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray) -> np.ndarray:
        """
        Extract frequency-domain features (DCT, wavelet).

        Args:
            image: Input image (H, W, 3) normalized [0, 1]
            region_mask: Binary region mask (H, W)

        Returns:
            Frequency feature vector of length NUM_FEATURES (10)
        """
        features = []

        # Convert to grayscale
        if len(image.shape) == 3:
            gray = cv2.cvtColor((image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
        else:
            gray = (image * 255).astype(np.uint8)

        # Get region bounding box
        coords = np.where(region_mask > 0)
        if len(coords[0]) == 0:
            # Fix: the empty-region fallback must match the real feature
            # count (was np.zeros(13), which misaligned concatenated
            # feature vectors downstream — only 10 features are produced).
            return np.zeros(self.NUM_FEATURES, dtype=np.float32)

        y_min, y_max = coords[0].min(), coords[0].max()
        x_min, x_max = coords[1].min(), coords[1].max()

        # Crop region
        region = gray[y_min:y_max+1, x_min:x_max+1].astype(np.float32)

        if region.size == 0:
            return np.zeros(self.NUM_FEATURES, dtype=np.float32)

        # 2-D DCT statistics (3 features)
        try:
            dct_coeffs = dct(dct(region, axis=0, norm='ortho'), axis=1, norm='ortho')

            # Mean and std of DCT coefficients
            features.append(np.mean(np.abs(dct_coeffs)))
            features.append(np.std(dct_coeffs))

            # High-frequency energy (bottom-right quadrant)
            h, w = dct_coeffs.shape
            high_freq = dct_coeffs[h//2:, w//2:]
            features.append(np.sum(np.abs(high_freq)) / (high_freq.size + 1e-8))
        except Exception:
            features.extend([0, 0, 0])

        # Wavelet sub-band energies and entropies (7 features)
        try:
            coeffs = pywt.dwt2(region, 'db1')
            cA, (cH, cV, cD) = coeffs

            # Energy in each sub-band
            features.append(np.sum(cA ** 2) / (cA.size + 1e-8))
            features.append(np.sum(cH ** 2) / (cH.size + 1e-8))
            features.append(np.sum(cV ** 2) / (cV.size + 1e-8))
            features.append(np.sum(cD ** 2) / (cD.size + 1e-8))

            # Wavelet entropy of the detail bands
            for coeff in [cH, cV, cD]:
                coeff_flat = np.abs(coeff.flatten())
                if coeff_flat.sum() > 0:
                    coeff_norm = coeff_flat / coeff_flat.sum()
                    entropy = -np.sum(coeff_norm * np.log2(coeff_norm + 1e-8))
                else:
                    entropy = 0.0
                features.append(entropy)
        except Exception:
            features.extend([0, 0, 0, 0, 0, 0, 0])

        return np.array(features, dtype=np.float32)
215
+
216
+
217
class NoiseELAFeatureExtractor:
    """Error-Level-Analysis and noise-residual statistics for one region."""

    def __init__(self, quality: int = 90):
        """
        Args:
            quality: JPEG quality used for the ELA recompression step.
        """
        self.quality = quality

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray) -> np.ndarray:
        """
        Compute 5 features: ELA mean/var/max and noise-residual mean/var.

        Args:
            image: Input image (H, W, 3) normalized [0, 1]
            region_mask: Binary region mask (H, W)

        Returns:
            Noise/ELA feature vector (length 5)
        """
        out = []
        as_uint8 = (image * 255).astype(np.uint8)

        # --- ELA: recompress as JPEG and measure the absolute difference ---
        params = [cv2.IMWRITE_JPEG_QUALITY, self.quality]
        _, buf = cv2.imencode('.jpg', as_uint8, params)
        rejpeg = cv2.imdecode(buf, cv2.IMREAD_COLOR)
        ela_map = np.abs(as_uint8.astype(np.float32) - rejpeg.astype(np.float32))

        selected = ela_map[region_mask > 0]
        if selected.size > 0:
            out += [np.mean(selected), np.var(selected), np.max(selected)]
        else:
            out += [0, 0, 0]

        # --- Noise residual: high-frequency content removed by a median filter ---
        if image.ndim == 3:
            gray = cv2.cvtColor(as_uint8, cv2.COLOR_RGB2GRAY)
        else:
            gray = as_uint8
        residual = np.abs(gray.astype(np.float32) - cv2.medianBlur(gray, 3).astype(np.float32))

        selected = residual[region_mask > 0]
        if selected.size > 0:
            out += [np.mean(selected), np.var(selected)]
        else:
            out += [0, 0]

        return np.array(out, dtype=np.float32)
281
+
282
+
283
class OCRFeatureExtractor:
    """
    Extract OCR-based consistency features.

    Only for text documents (Feature Gating - Critical Fix #5). Produces a
    fixed-length 6-feature vector; all zeros when OCR is unavailable, the
    region is empty, no text is found, or OCR raises.
    """

    def __init__(self):
        """Initialize OCR feature extractor.

        Lazily imports EasyOCR; if the import or reader construction fails,
        OCR features are disabled and extract() returns zeros.
        """
        self.ocr_available = False

        try:
            import easyocr
            # NOTE(review): gpu=True assumes a CUDA-capable environment is
            # present; failure here is caught and disables OCR — confirm
            # whether a CPU fallback (gpu=False) is preferred in deployment.
            self.reader = easyocr.Reader(['en'], gpu=True)
            self.ocr_available = True
        except Exception:
            print("Warning: EasyOCR not available, OCR features disabled")

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray) -> np.ndarray:
        """
        Extract OCR consistency features.

        Args:
            image: Input image (H, W, 3) normalized [0, 1]
            region_mask: Binary region mask (H, W)

        Returns:
            OCR feature vector of length 6:
            [conf_mean, conf_std, spacing_irregularity, text_density,
             stroke_mean, stroke_std] — or zeros when OCR cannot run.
        """
        features = []

        if not self.ocr_available:
            return np.zeros(6, dtype=np.float32)

        # Convert to uint8
        img_uint8 = (image * 255).astype(np.uint8)

        # Get region bounding box
        coords = np.where(region_mask > 0)
        if len(coords[0]) == 0:
            # Empty mask: nothing to OCR
            return np.zeros(6, dtype=np.float32)

        y_min, y_max = coords[0].min(), coords[0].max()
        x_min, x_max = coords[1].min(), coords[1].max()

        # Crop region (inclusive bounds)
        region = img_uint8[y_min:y_max+1, x_min:x_max+1]

        try:
            # OCR on region; each result is (bbox_quad, text, confidence)
            # — presumably EasyOCR's readtext format; verify against the
            # installed easyocr version.
            results = self.reader.readtext(region)

            if len(results) > 0:
                # Confidence deviation across detected text fragments
                confidences = [r[2] for r in results]
                features.append(np.mean(confidences))
                features.append(np.std(confidences))

                # Character spacing analysis: width = top-right x minus
                # top-left x of each detection's quad
                bbox_widths = [abs(r[0][1][0] - r[0][0][0]) for r in results]
                if len(bbox_widths) > 1:
                    # Coefficient of variation of detection widths
                    features.append(np.std(bbox_widths) / (np.mean(bbox_widths) + 1e-8))
                else:
                    features.append(0.0)

                # Text density: detections per pixel of the cropped region
                features.append(len(results) / (region.shape[0] * region.shape[1] + 1e-8))

                # Stroke width variation (using edge detection)
                gray_region = cv2.cvtColor(region, cv2.COLOR_RGB2GRAY)
                edges = sobel(gray_region)
                features.append(np.mean(edges))
                features.append(np.std(edges))
            else:
                # No text detected: zero-fill to keep the vector length fixed
                features.extend([0, 0, 0, 0, 0, 0])
        except Exception:
            # Any OCR failure degrades gracefully to zeros
            features.extend([0, 0, 0, 0, 0, 0])

        return np.array(features, dtype=np.float32)
363
+
364
+
365
class HybridFeatureExtractor:
    """
    Combine deep, statistical, frequency, noise and (optionally) OCR
    features for a single region.

    OCR features are gated to text documents only (Critical Fix #5).
    """

    def __init__(self, config, is_text_document: bool = True):
        """
        Args:
            config: Configuration object
            is_text_document: Whether the input is a text document
                              (controls the OCR feature gate)
        """
        self.config = config
        self.is_text_document = is_text_document

        self.deep_extractor = DeepFeatureExtractor()
        self.stat_extractor = StatisticalFeatureExtractor()
        self.freq_extractor = FrequencyFeatureExtractor()
        self.noise_extractor = NoiseELAFeatureExtractor()

        # Critical Fix #5: OCR is only meaningful on text documents
        wants_ocr = is_text_document and config.get('features.ocr.enabled', True)
        self.ocr_extractor = OCRFeatureExtractor() if wants_ocr else None

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray,
                decoder_features: Optional[List[torch.Tensor]] = None) -> np.ndarray:
        """
        Extract all enabled feature groups for one region.

        Args:
            image: Input image (H, W, 3) normalized [0, 1]
            region_mask: Binary region mask (H, W)
            decoder_features: Optional decoder tensors for deep features

        Returns:
            Concatenated feature vector (NaN/Inf replaced by 0)
        """
        parts = []

        # Deep features require decoder maps AND the config switch
        if decoder_features is not None and self.config.get('features.deep.enabled', True):
            parts.append(self.deep_extractor.extract(decoder_features, region_mask))

        if self.config.get('features.statistical.enabled', True):
            parts.append(self.stat_extractor.extract(image, region_mask))

        if self.config.get('features.frequency.enabled', True):
            parts.append(self.freq_extractor.extract(image, region_mask))

        if self.config.get('features.noise.enabled', True):
            parts.append(self.noise_extractor.extract(image, region_mask))

        # Critical Fix #5: OCR block is present only for text documents
        if self.ocr_extractor is not None:
            parts.append(self.ocr_extractor.extract(image, region_mask))

        vector = np.concatenate(parts) if parts else np.array([], dtype=np.float32)

        # Downstream classifiers cannot cope with NaN/Inf from degenerate regions
        return np.nan_to_num(vector, nan=0.0, posinf=0.0, neginf=0.0)

    def get_feature_names(self) -> List[str]:
        """Ordered feature names matching extract() output (deep count approximate)."""
        names = []

        if self.config.get('features.deep.enabled', True):
            names += [f'deep_{i}' for i in range(256)]  # Approximate

        if self.config.get('features.statistical.enabled', True):
            names += ['area', 'perimeter', 'aspect_ratio',
                      'solidity', 'eccentricity', 'entropy']

        if self.config.get('features.frequency.enabled', True):
            names += ['dct_mean', 'dct_std', 'high_freq_energy',
                      'wavelet_cA', 'wavelet_cH', 'wavelet_cV', 'wavelet_cD',
                      'wavelet_entropy_H', 'wavelet_entropy_V', 'wavelet_entropy_D']

        if self.config.get('features.noise.enabled', True):
            names += ['ela_mean', 'ela_var', 'ela_max',
                      'noise_residual_mean', 'noise_residual_var']

        if self.ocr_extractor is not None:
            names += ['ocr_conf_mean', 'ocr_conf_std', 'spacing_irregularity',
                      'text_density', 'stroke_mean', 'stroke_std']

        return names
472
+
473
+
474
def get_feature_extractor(config, is_text_document: bool = True) -> HybridFeatureExtractor:
    """
    Factory for the hybrid feature extractor.

    Args:
        config: Configuration object
        is_text_document: Whether the input is a text document
                          (gates OCR features — Critical Fix #5)

    Returns:
        A configured HybridFeatureExtractor
    """
    extractor = HybridFeatureExtractor(config, is_text_document)
    return extractor
src/features/region_extraction.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mask refinement and region extraction
3
+ Implements Critical Fix #3: Adaptive Mask Refinement Thresholds
4
+ """
5
+
6
+ import cv2
7
+ import numpy as np
8
+ from typing import List, Tuple, Dict, Optional
9
+ from scipy import ndimage
10
+ from skimage.measure import label, regionprops
11
+
12
+
13
class MaskRefiner:
    """
    Turn a forgery probability map into a clean binary mask.

    Steps: threshold -> morphological close -> morphological open ->
    drop components below a dataset-specific area fraction (Critical Fix #3)
    -> optional resize back to the original resolution.
    """

    def __init__(self, config, dataset_name: str = 'default'):
        """
        Args:
            config: Configuration object
            dataset_name: Dataset name used to pick the adaptive area threshold
        """
        self.config = config
        self.dataset_name = dataset_name

        self.threshold = config.get('mask_refinement.threshold', 0.5)
        self.closing_kernel = config.get('mask_refinement.morphology.closing_kernel', 5)
        self.opening_kernel = config.get('mask_refinement.morphology.opening_kernel', 3)

        # Critical Fix #3: per-dataset minimum region area (fraction of image)
        self.min_region_area = config.get_min_region_area(dataset_name)

        print(f"MaskRefiner initialized for {dataset_name}")
        print(f"Min region area: {self.min_region_area * 100:.2f}%")

    def refine(self,
               probability_map: np.ndarray,
               original_size: Tuple[int, int] = None) -> np.ndarray:
        """
        Refine a probability map into a binary mask.

        Args:
            probability_map: Forgery probabilities (H, W), values in [0, 1]
            original_size: Optional (H, W) to resize the mask back to

        Returns:
            Binary mask (H, W) with values {0, 1}
        """
        mask = (probability_map > self.threshold).astype(np.uint8)

        # Close first (bridge broken strokes), then open (remove speckle)
        for op, ksize in ((cv2.MORPH_CLOSE, self.closing_kernel),
                          (cv2.MORPH_OPEN, self.opening_kernel)):
            kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (ksize, ksize))
            mask = cv2.morphologyEx(mask, op, kernel)

        # Critical Fix #3: drop regions below the adaptive area threshold
        mask = self._remove_small_regions(mask)

        if original_size is not None:
            # cv2.resize expects (width, height)
            mask = cv2.resize(mask,
                              (original_size[1], original_size[0]),
                              interpolation=cv2.INTER_NEAREST)

        return mask

    def _remove_small_regions(self, mask: np.ndarray) -> np.ndarray:
        """
        Zero out connected components smaller than the minimum area.

        Args:
            mask: Binary mask (H, W)

        Returns:
            Mask keeping only sufficiently large components (values {0, 1})
        """
        min_pixels = int(mask.shape[0] * mask.shape[1] * self.min_region_area)

        labeled, _ = ndimage.label(mask)

        # Component sizes via bincount; label 0 is background and never kept
        sizes = np.bincount(labeled.ravel())
        survivors = np.flatnonzero(sizes >= min_pixels)
        survivors = survivors[survivors != 0]

        return np.isin(labeled, survivors).astype(mask.dtype)
+
113
+
114
class RegionExtractor:
    """
    Extract per-region crops and confidences from a binary mask.

    Implements Critical Fix #4: region-level confidence aggregation
    (mean probability over the region's pixels).
    """

    def __init__(self, config, dataset_name: str = 'default'):
        """
        Args:
            config: Configuration object
            dataset_name: Dataset name
        """
        self.config = config
        self.dataset_name = dataset_name
        self.min_region_area = config.get_min_region_area(dataset_name)

    def extract(self,
                binary_mask: np.ndarray,
                probability_map: np.ndarray,
                original_image: np.ndarray) -> List[Dict]:
        """
        Split the mask into connected components and describe each one.

        Args:
            binary_mask: Refined binary mask (H, W)
            probability_map: Probability map (H, W) used for confidence
            original_image: Original image (H, W, 3)

        Returns:
            One dict per region: id, bounding box (x, y, w, h), area,
            centroid (x, y), masks, cropped image and confidence.
        """
        # 8-connectivity component analysis
        labeled = label(binary_mask, connectivity=2)

        regions = []
        for region_id, prop in enumerate(regionprops(labeled), start=1):
            y0, x0, y1, x1 = prop.bbox

            component = (labeled == region_id).astype(np.uint8)

            # Critical Fix #4: confidence = mean probability over the region
            probs = probability_map[component > 0]
            confidence = float(np.mean(probs)) if len(probs) > 0 else 0.0

            regions.append({
                'region_id': region_id,
                'bounding_box': [int(x0), int(y0),
                                 int(x1 - x0), int(y1 - y0)],
                'area': prop.area,
                'centroid': (int(prop.centroid[1]), int(prop.centroid[0])),
                'region_mask': component,
                'region_mask_cropped': component[y0:y1, x0:x1],
                'region_image': original_image[y0:y1, x0:x1].copy(),
                'confidence': confidence,
                'mask_probability_mean': confidence
            })

        return regions

    def extract_for_casia(self,
                          binary_mask: np.ndarray,
                          probability_map: np.ndarray,
                          original_image: np.ndarray) -> List[Dict]:
        """
        Critical Fix #6: CASIA images are classified whole — return a single
        region spanning the full image, regardless of the mask content.

        Args:
            binary_mask: Binary mask (may be empty for authentic images)
            probability_map: Probability map
            original_image: Original image

        Returns:
            A one-element list describing the whole image as a region.
        """
        h, w = original_image.shape[:2]

        full_mask = np.ones((h, w), dtype=np.uint8)
        confidence = float(np.mean(probability_map))

        return [{
            'region_id': 1,
            'bounding_box': [0, 0, w, h],
            'area': h * w,
            'centroid': (w // 2, h // 2),
            'region_mask': full_mask,
            'region_mask_cropped': full_mask,
            'region_image': original_image,
            'confidence': confidence,
            'mask_probability_mean': confidence
        }]
217
+
218
+
219
+ def get_mask_refiner(config, dataset_name: str = 'default') -> MaskRefiner:
220
+ """Factory function for mask refiner"""
221
+ return MaskRefiner(config, dataset_name)
222
+
223
+
224
+ def get_region_extractor(config, dataset_name: str = 'default') -> RegionExtractor:
225
+ """Factory function for region extractor"""
226
+ return RegionExtractor(config, dataset_name)
src/inference/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Inference module"""
2
+
3
+ from .pipeline import ForgeryDetectionPipeline, get_pipeline
4
+
5
+ __all__ = ['ForgeryDetectionPipeline', 'get_pipeline']
src/inference/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (305 Bytes). View file
 
src/inference/__pycache__/pipeline.cpython-312.pyc ADDED
Binary file (14.5 kB). View file
 
src/inference/pipeline.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference pipeline for document forgery detection
3
+ Complete pipeline: Image → Localization → Regions → Classification → Output
4
+ """
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ from typing import Dict, List, Optional, Tuple
10
+ from pathlib import Path
11
+ import json
12
+ from PIL import Image
13
+ import fitz # PyMuPDF
14
+
15
+ from ..config import get_config
16
+ from ..models import get_model
17
+ from ..features import (
18
+ get_feature_extractor,
19
+ get_mask_refiner,
20
+ get_region_extractor
21
+ )
22
+ from ..training.classifier import get_classifier
23
+
24
+
25
class ForgeryDetectionPipeline:
    """
    Complete inference pipeline for document forgery detection

    Pipeline:
    1. Input handling (PDF/Image)
    2. Preprocessing
    3. Deep localization
    4. Mask refinement
    5. Region extraction
    6. Feature extraction
    7. Classification
    8. Post-processing
    9. Output generation
    """

    def __init__(self,
                 config,
                 model_path: str,
                 classifier_path: Optional[str] = None,
                 is_text_document: bool = True):
        """
        Initialize pipeline

        Args:
            config: Configuration object
            model_path: Path to localization model checkpoint
            classifier_path: Path to classifier (optional)
            is_text_document: Whether input is text document (for OCR features)
        """
        self.config = config
        self.is_text_document = is_text_document

        # Use CUDA only when both available and explicitly requested in config
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() and config.get('system.device') == 'cuda'
            else 'cpu'
        )
        print(f"Inference device: {self.device}")

        # Load localization model and freeze it in eval mode
        self.model = get_model(config).to(self.device)
        self._load_model(model_path)
        self.model.eval()

        # Mask refinement / region extraction / feature extraction helpers
        self.mask_refiner = get_mask_refiner(config, 'default')
        self.region_extractor = get_region_extractor(config, 'default')
        self.feature_extractor = get_feature_extractor(config, is_text_document)

        # Optional region-level classifier (pipeline still runs without one)
        if classifier_path:
            self.classifier = get_classifier(config)
            self.classifier.load(classifier_path)
        else:
            self.classifier = None

        # Confidence threshold used by post-processing
        self.confidence_threshold = config.get('classifier.confidence_threshold', 0.6)

        # Model input resolution (square)
        self.image_size = config.get('data.image_size', 384)

        print("Inference pipeline initialized")

    def _load_model(self, model_path: str):
        """Load localization weights from a checkpoint.

        Accepts either a full training checkpoint (containing
        'model_state_dict') or a bare state dict.

        Args:
            model_path: Path to the checkpoint file
        """
        checkpoint = torch.load(model_path, map_location=self.device)

        if 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.model.load_state_dict(checkpoint)

        print(f"Loaded model from {model_path}")

    def _load_image(self, input_path: str) -> np.ndarray:
        """
        Load image from file or PDF

        Args:
            input_path: Path to image or PDF

        Returns:
            Image as numpy array (H, W, 3), RGB channel order

        Raises:
            FileNotFoundError: If an image file cannot be read.
        """
        path = Path(input_path)

        if path.suffix.lower() == '.pdf':
            # Rasterize the first PDF page at 300 DPI (PDF native space is 72 DPI)
            doc = fitz.open(str(path))
            try:
                page = doc[0]
                mat = fitz.Matrix(300 / 72, 300 / 72)  # 300 DPI
                pix = page.get_pixmap(matrix=mat)
                image = np.frombuffer(pix.samples, dtype=np.uint8)
                image = image.reshape(pix.height, pix.width, pix.n)
                if pix.n == 4:
                    image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
            finally:
                # Release the document handle even if rasterization fails
                doc.close()
        else:
            image = cv2.imread(str(path))
            if image is None:
                # cv2.imread returns None (instead of raising) on missing or
                # unreadable files; fail loudly with a clear message
                raise FileNotFoundError(f"Could not read image: {input_path}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        return image

    def _preprocess(self, image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Preprocess image for inference

        Args:
            image: Input image (H, W, 3)

        Returns:
            Tuple of (resized float image in [0, 1], untouched copy of the original)
        """
        original = image.copy()

        # Resize to the model's square input resolution
        preprocessed = cv2.resize(image, (self.image_size, self.image_size))

        # Normalize to [0, 1]
        preprocessed = preprocessed.astype(np.float32) / 255.0

        return preprocessed, original

    def _to_tensor(self, image: np.ndarray) -> torch.Tensor:
        """Convert an (H, W, C) float image to a (1, C, H, W) tensor on self.device."""
        tensor = torch.from_numpy(image.transpose(2, 0, 1))  # (H, W, C) -> (C, H, W)
        tensor = tensor.unsqueeze(0)  # add batch dimension
        return tensor.to(self.device)

    def run(self,
            input_path: str,
            output_dir: Optional[str] = None) -> Dict:
        """
        Run full inference pipeline

        Args:
            input_path: Path to input image or PDF
            output_dir: Optional output directory for mask/overlay/JSON files

        Returns:
            Dictionary with results (regions, tamper flag, saved-file paths
            when output_dir is given)
        """
        print(f"\n{'='*60}")
        print(f"Processing: {input_path}")
        print(f"{'='*60}")

        # 1. Load image
        image = self._load_image(input_path)
        original_size = image.shape[:2]
        print(f"Input size: {original_size}")

        # 2. Preprocess
        preprocessed, original = self._preprocess(image)
        tensor = self._to_tensor(preprocessed)

        # 3. Deep localization
        with torch.no_grad():
            logits, decoder_features = self.model(tensor)
            probability_map = torch.sigmoid(logits).cpu().numpy()[0, 0]

        print(f"Localization complete. Max prob: {probability_map.max():.3f}")

        # 4. Mask refinement
        binary_mask = self.mask_refiner.refine(probability_map, original_size)
        num_positive_pixels = binary_mask.sum()
        print(f"Mask refinement: {num_positive_pixels} positive pixels")

        # 5. Region extraction
        # Resize probability map to original size for confidence aggregation
        prob_resized = cv2.resize(probability_map, (original_size[1], original_size[0]))

        regions = self.region_extractor.extract(binary_mask, prob_resized, original)
        print(f"Regions extracted: {len(regions)}")

        # 6. Feature extraction & 7. Classification
        results = []

        for region in regions:
            # Extract features at model resolution
            features = self.feature_extractor.extract(
                preprocessed,
                cv2.resize(region['region_mask'], (self.image_size, self.image_size)),
                [f.cpu() for f in decoder_features]
            )

            # Classify if classifier available
            if self.classifier is not None:
                predictions, confidences, valid_mask = self.classifier.predict_with_filtering(
                    features.reshape(1, -1)
                )

                if valid_mask[0]:
                    region['forgery_type'] = self.classifier.get_class_name(predictions[0])
                    region['classification_confidence'] = float(confidences[0])
                else:
                    # Low confidence - discard
                    continue
            else:
                region['forgery_type'] = 'unknown'
                region['classification_confidence'] = region['confidence']

            # Keep only JSON-serializable fields (drop masks/image crops)
            results.append({
                'region_id': region['region_id'],
                'bounding_box': region['bounding_box'],
                'forgery_type': region['forgery_type'],
                'confidence': region['confidence'],
                'classification_confidence': region['classification_confidence'],
                'mask_probability_mean': region['mask_probability_mean'],
                'area': region['area']
            })

        print(f"Valid regions after filtering: {len(results)}")

        # 8. Post-processing - False positive removal
        results = self._post_process(results)

        # 9. Generate output
        output = {
            'input_path': str(input_path),
            'original_size': original_size,
            'num_regions': len(results),
            'regions': results,
            'is_tampered': len(results) > 0
        }

        # Save outputs if directory provided
        if output_dir:
            output_path = Path(output_dir)
            output_path.mkdir(parents=True, exist_ok=True)

            input_name = Path(input_path).stem

            # Save final mask
            mask_path = output_path / f'{input_name}_mask.png'
            cv2.imwrite(str(mask_path), binary_mask * 255)

            # Save overlay visualization
            overlay = self._create_overlay(original, binary_mask, results)
            overlay_path = output_path / f'{input_name}_overlay.png'
            cv2.imwrite(str(overlay_path), cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))

            # Save JSON
            json_path = output_path / f'{input_name}_results.json'
            with open(json_path, 'w') as f:
                json.dump(output, f, indent=2)

            print(f"\nOutputs saved to: {output_path}")
            output['mask_path'] = str(mask_path)
            output['overlay_path'] = str(overlay_path)
            output['json_path'] = str(json_path)

        return output

    def _post_process(self, regions: List[Dict]) -> List[Dict]:
        """
        Post-process regions to remove false positives

        Args:
            regions: List of region dictionaries

        Returns:
            Regions whose 'confidence' meets the configured threshold
        """
        filtered = []

        for region in regions:
            # Confidence filtering
            if region['confidence'] < self.confidence_threshold:
                continue

            filtered.append(region)

        return filtered

    def _create_overlay(self,
                        image: np.ndarray,
                        mask: np.ndarray,
                        regions: List[Dict]) -> np.ndarray:
        """
        Create visualization overlay

        Args:
            image: Original image (H, W, 3)
            mask: Binary mask (any size; resized to the image if needed)
            regions: Detected regions

        Returns:
            Overlay image with the mask blended in red and labeled boxes
        """
        alpha = self.config.get('outputs.visualization.overlay_alpha', 0.5)

        # Bring the mask to image resolution FIRST so all indexing below is
        # consistent. The previous version indexed the colored mask with the
        # unresized mask, which crashed whenever mask size != image size.
        mask_resized = cv2.resize(mask, (image.shape[1], image.shape[0]))

        # Solid red wherever forgery was detected
        mask_colored = np.zeros_like(image)
        mask_colored[mask_resized > 0] = [255, 0, 0]  # Red for forgery

        # Alpha-blend the red mask onto the image
        overlay = np.where(
            mask_resized[:, :, None] > 0,
            (1 - alpha) * image + alpha * mask_colored,
            image
        ).astype(np.uint8)

        # Draw bounding boxes and labels
        for region in regions:
            x, y, w, h = region['bounding_box']

            # Draw rectangle
            cv2.rectangle(overlay, (x, y), (x + w, y + h), (0, 255, 0), 2)

            # Draw label
            label = f"{region['forgery_type']} ({region['confidence']:.2f})"
            cv2.putText(overlay, label, (x, y - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

        return overlay
352
+
353
+
354
def get_pipeline(config,
                 model_path: str,
                 classifier_path: Optional[str] = None,
                 is_text_document: bool = True) -> ForgeryDetectionPipeline:
    """Convenience constructor for ForgeryDetectionPipeline."""
    pipeline = ForgeryDetectionPipeline(
        config,
        model_path,
        classifier_path,
        is_text_document,
    )
    return pipeline
src/models/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Models module"""
2
+
3
+ from .encoder import MobileNetV3Encoder, get_encoder
4
+ from .decoder import UNetLiteDecoder, get_decoder
5
+ from .network import ForgeryLocalizationNetwork, get_model
6
+ from .losses import DiceLoss, CombinedLoss, DatasetAwareLoss, get_loss_function
7
+
8
+ __all__ = [
9
+ 'MobileNetV3Encoder',
10
+ 'get_encoder',
11
+ 'UNetLiteDecoder',
12
+ 'get_decoder',
13
+ 'ForgeryLocalizationNetwork',
14
+ 'get_model',
15
+ 'DiceLoss',
16
+ 'CombinedLoss',
17
+ 'DatasetAwareLoss',
18
+ 'get_loss_function'
19
+ ]
src/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (600 Bytes). View file
 
src/models/__pycache__/decoder.cpython-312.pyc ADDED
Binary file (7.65 kB). View file
 
src/models/__pycache__/encoder.cpython-312.pyc ADDED
Binary file (2.91 kB). View file
 
src/models/__pycache__/losses.cpython-312.pyc ADDED
Binary file (6.55 kB). View file
 
src/models/__pycache__/network.cpython-312.pyc ADDED
Binary file (5.84 kB). View file
 
src/models/decoder.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ UNet-Lite Decoder for forgery localization
3
+ Lightweight decoder with skip connections, depthwise separable convolutions
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from typing import List
10
+
11
+
12
class DepthwiseSeparableConv(nn.Module):
    """Depthwise separable convolution.

    Per-channel spatial convolution followed by a 1x1 pointwise channel mix,
    then BatchNorm and ReLU — cheaper than a full dense convolution.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()

        # Spatial filtering, one filter per input channel (groups == channels)
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,  # 'same' padding for odd kernels
            groups=in_channels,
            bias=False,
        )
        # 1x1 conv mixes channels up/down to the requested output width
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply depthwise conv, pointwise conv, BatchNorm, then ReLU."""
        return self.relu(self.bn(self.pointwise(self.depthwise(x))))
35
+
36
+
37
class DecoderBlock(nn.Module):
    """One decoder stage with a skip connection.

    Upsamples the incoming features to the skip's resolution, concatenates
    the encoder skip features, then refines with two depthwise-separable
    convolutions.
    """

    def __init__(self, in_channels: int, skip_channels: int, out_channels: int):
        """
        Args:
            in_channels: Channels from the previous decoder stage
            skip_channels: Channels of the encoder skip connection
            out_channels: Channels produced by this block
        """
        super().__init__()

        # After concatenation the conv sees both feature sets
        fused_channels = in_channels + skip_channels
        self.conv1 = DepthwiseSeparableConv(fused_channels, out_channels)
        self.conv2 = DepthwiseSeparableConv(out_channels, out_channels)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` to ``skip``'s spatial size, fuse, and refine.

        Args:
            x: Features from the previous decoder stage
            skip: Encoder skip-connection features

        Returns:
            Refined feature map at the skip's resolution.
        """
        upsampled = F.interpolate(x, size=skip.shape[2:], mode='bilinear',
                                  align_corners=False)
        fused = torch.cat([upsampled, skip], dim=1)
        return self.conv2(self.conv1(fused))
79
+
80
+
81
class UNetLiteDecoder(nn.Module):
    """
    UNet-Lite decoder for forgery localization

    Features:
    - Skip connections from encoder stages
    - Bilinear upsampling
    - Depthwise separable convolutions for efficiency
    """

    def __init__(self,
                 encoder_channels: List[int],
                 decoder_channels: List[int] = None,
                 output_channels: int = 1):
        """
        Initialize decoder

        Args:
            encoder_channels: List of encoder feature channels [stage0, ..., stageN]
            decoder_channels: List of decoder output channels
                (defaults to [256, 128, 64, 32, 16])
            output_channels: Number of output channels (1 for binary mask)
        """
        super().__init__()

        # Default decoder channels if not provided
        if decoder_channels is None:
            decoder_channels = [256, 128, 64, 32, 16]

        # Reverse encoder channels for bottom-up decoding (deepest first)
        encoder_channels = encoder_channels[::-1]

        # Project deepest encoder features into decoder space
        self.initial_conv = DepthwiseSeparableConv(encoder_channels[0], decoder_channels[0])

        # One decoder block per remaining encoder stage (skip connections)
        self.decoder_blocks = nn.ModuleList()
        for i in range(len(encoder_channels) - 1):
            in_ch = decoder_channels[i]
            skip_ch = encoder_channels[i + 1]
            # Clamp to the last configured width when there are more encoder
            # stages than decoder widths
            out_ch = decoder_channels[i + 1] if i + 1 < len(decoder_channels) else decoder_channels[-1]

            self.decoder_blocks.append(
                DecoderBlock(in_ch, skip_ch, out_ch)
            )

        # Final refinement + 1x1 projection to the output mask channels
        self.final_upsample = nn.Sequential(
            DepthwiseSeparableConv(decoder_channels[-1], decoder_channels[-1]),
            nn.Conv2d(decoder_channels[-1], output_channels, kernel_size=1)
        )

        # Store decoder feature channels for feature extraction
        self.decoder_channels = decoder_channels

        print(f"UNet-Lite decoder initialized")
        print(f"Encoder channels: {encoder_channels[::-1]}")
        print(f"Decoder channels: {decoder_channels}")

    def forward(self, encoder_features: List[torch.Tensor]) -> tuple:
        """
        Forward pass

        Args:
            encoder_features: List of encoder features [stage0, ..., stageN]

        Returns:
            output: Forgery logit map (B, 1, H, W)
            decoder_features: List of decoder features for hybrid extraction
        """
        # Reverse for bottom-up decoding
        features = encoder_features[::-1]

        # Initial convolution on the deepest feature map
        x = self.initial_conv(features[0])

        # Store decoder features for hybrid feature extraction
        decoder_features = [x]

        # Decoder blocks with skip connections
        for i, block in enumerate(self.decoder_blocks):
            x = block(x, features[i + 1])
            decoder_features.append(x)

        # Final upsampling back to the input resolution. The shallowest
        # encoder feature is at 1/2 scale, so doubling its spatial size
        # recovers the original resolution.
        # Fix: track height and width independently so non-square inputs
        # work (the original used shape[2] for both axes, assuming squares).
        target_h = encoder_features[0].shape[2] * 2
        target_w = encoder_features[0].shape[3] * 2
        x = F.interpolate(x, size=(target_h, target_w), mode='bilinear', align_corners=False)
        # Idiom fix: call the Sequential directly instead of indexing layers
        output = self.final_upsample(x)

        return output, decoder_features
172
+
173
+
174
def get_decoder(encoder_channels: List[int], config) -> UNetLiteDecoder:
    """
    Factory function to create decoder

    Args:
        encoder_channels: Encoder feature channels
        config: Configuration object

    Returns:
        Decoder instance
    """
    out_ch = config.get('model.output_channels', 1)
    return UNetLiteDecoder(encoder_channels, output_channels=out_ch)
src/models/encoder.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MobileNetV3-Small Encoder for forgery localization
3
+ ImageNet pretrained, feature extraction mode
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import timm
9
+ from typing import List
10
+
11
+
12
class MobileNetV3Encoder(nn.Module):
    """
    MobileNetV3-Small encoder for document forgery detection

    Chosen for:
    - Stroke-level and texture preservation
    - Robustness to compression and blur
    - Edge and CPU deployment efficiency
    """

    def __init__(self, pretrained: bool = True):
        """
        Initialize encoder

        Args:
            pretrained: Whether to use ImageNet pretrained weights
        """
        super().__init__()

        # timm feature-extraction mode: forward() returns one tensor per stage
        self.backbone = timm.create_model(
            'mobilenetv3_small_100',
            pretrained=pretrained,
            features_only=True,
            out_indices=(0, 1, 2, 3, 4)  # All feature stages
        )

        # Per-stage channel widths; MobileNetV3-Small: [16, 16, 24, 48, 576]
        self.feature_channels = self.backbone.feature_info.channels()

        print(f"MobileNetV3-Small encoder initialized")
        print(f"Feature channels: {self.feature_channels}")

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Extract multi-scale features

        Args:
            x: Input tensor (B, 3, H, W)

        Returns:
            List of feature tensors at different scales
        """
        return self.backbone(x)

    def get_feature_channels(self) -> List[int]:
        """Channel dimensions of each feature stage."""
        return self.feature_channels
62
+
63
+
64
def get_encoder(config) -> MobileNetV3Encoder:
    """
    Factory function to create encoder

    Args:
        config: Configuration object

    Returns:
        Encoder instance
    """
    use_pretrained = config.get('model.encoder.pretrained', True)
    return MobileNetV3Encoder(pretrained=use_pretrained)
src/models/losses.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dataset-aware loss functions
3
+ Implements Critical Fix #2: Dataset-Aware Loss Function
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from typing import Dict, Optional
10
+
11
+
12
class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation.

    Returns ``1 - Dice coefficient`` computed over the whole batch, with
    additive smoothing so an all-empty prediction/target pair is well-defined.
    """

    def __init__(self, smooth: float = 1.0):
        """
        Args:
            smooth: Additive smoothing factor to avoid division by zero
        """
        super().__init__()
        self.smooth = smooth

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute Dice loss

        Args:
            pred: Predicted logits (B, 1, H, W); sigmoid is applied internally
            target: Ground truth mask (B, 1, H, W)

        Returns:
            Scalar Dice loss value
        """
        # Logits -> probabilities, then flatten everything to one vector
        probs = torch.sigmoid(pred).view(-1)
        truth = target.view(-1)

        # Dice coefficient with smoothing
        overlap = torch.sum(probs * truth)
        denominator = torch.sum(probs) + torch.sum(truth) + self.smooth
        dice_coeff = (2. * overlap + self.smooth) / denominator

        return 1 - dice_coeff
+
50
+
51
class CombinedLoss(nn.Module):
    """
    Combined BCE + Dice loss for segmentation.

    Dataset-aware: Dice is only added when the dataset provides
    pixel-level masks.
    """

    def __init__(self,
                 bce_weight: float = 1.0,
                 dice_weight: float = 1.0):
        """
        Args:
            bce_weight: Weight for BCE loss
            dice_weight: Weight for Dice loss
        """
        super().__init__()

        self.bce_weight = bce_weight
        self.dice_weight = dice_weight

        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = DiceLoss()

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                has_pixel_mask: bool = True) -> Dict[str, torch.Tensor]:
        """
        Compute loss (dataset-aware)

        Critical Fix #2: Dice is meaningful only with pixel-level ground
        truth, so datasets without masks (e.g. CASIA) use BCE alone.

        Args:
            pred: Predicted logits (B, 1, H, W)
            target: Ground truth mask (B, 1, H, W)
            has_pixel_mask: Whether dataset has pixel-level masks

        Returns:
            Dictionary with 'total', 'bce', and optionally 'dice' losses
        """
        # BCE is always computed
        bce = self.bce_loss(pred, target)
        losses = {'bce': bce}

        if not has_pixel_mask:
            # Critical Fix #2: no pixel masks -> BCE only
            losses['total'] = self.bce_weight * bce
            return losses

        # Pixel masks available: add the Dice term
        dice = self.dice_loss(pred, target)
        losses['dice'] = dice
        losses['total'] = self.bce_weight * bce + self.dice_weight * dice
        return losses
109
+
110
+
111
class DatasetAwareLoss(nn.Module):
    """
    Dataset-aware loss function wrapper.

    Reads per-batch metadata to decide whether the Dice term applies and
    delegates the actual computation to CombinedLoss.
    """

    def __init__(self, config):
        """
        Args:
            config: Configuration object providing loss weights
        """
        super().__init__()

        self.config = config

        # Weights come from config with 1.0 fallbacks
        self.combined_loss = CombinedLoss(
            bce_weight=config.get('loss.bce_weight', 1.0),
            dice_weight=config.get('loss.dice_weight', 1.0)
        )

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                metadata: Dict) -> Dict[str, torch.Tensor]:
        """
        Compute loss with dataset awareness

        Args:
            pred: Predicted logits (B, 1, H, W)
            target: Ground truth mask (B, 1, H, W)
            metadata: Batch metadata containing 'has_pixel_mask' flags;
                either a list of per-sample dicts or a single dict

        Returns:
            Dictionary with loss components
        """
        # Dice is used only if EVERY sample in the batch has a pixel mask.
        # NOTE(review): if metadata arrives as a DataLoader-collated dict of
        # lists, .get() returns a list (truthy) rather than a bool — confirm
        # the collation format at the call site.
        if isinstance(metadata, list):
            has_pixel_mask = all(item.get('has_pixel_mask', True) for item in metadata)
        else:
            has_pixel_mask = metadata.get('has_pixel_mask', True)

        return self.combined_loss(pred, target, has_pixel_mask)
156
+
157
+
158
def get_loss_function(config) -> DatasetAwareLoss:
    """
    Factory function to create loss

    Args:
        config: Configuration object

    Returns:
        Loss function instance
    """
    loss = DatasetAwareLoss(config)
    return loss
src/models/network.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Complete Forgery Localization Network
3
+ MobileNetV3-Small Encoder + UNet-Lite Decoder
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from typing import Tuple, List, Optional
9
+
10
+ from .encoder import MobileNetV3Encoder
11
+ from .decoder import UNetLiteDecoder
12
+
13
+
14
class ForgeryLocalizationNetwork(nn.Module):
    """
    Complete network for forgery localization

    Architecture:
    - Encoder: MobileNetV3-Small (ImageNet pretrained)
    - Decoder: UNet-Lite with skip connections
    - Output: Single-channel forgery probability map
    """

    def __init__(self, config):
        """
        Initialize network

        Args:
            config: Configuration object
        """
        super().__init__()

        self.config = config

        # Encoder: pretrained flag comes from config
        use_pretrained = config.get('model.encoder.pretrained', True)
        self.encoder = MobileNetV3Encoder(pretrained=use_pretrained)

        # Decoder sized to the encoder's per-stage channel widths
        self.decoder = UNetLiteDecoder(
            encoder_channels=self.encoder.get_feature_channels(),
            output_channels=config.get('model.output_channels', 1)
        )

        print(f"ForgeryLocalizationNetwork initialized")
        print(f"Total parameters: {self.count_parameters():,}")

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Run encoder then decoder.

        Args:
            x: Input image tensor (B, 3, H, W)

        Returns:
            output: Forgery logit map (B, 1, H, W)
            decoder_features: Decoder features for hybrid feature extraction
        """
        # Decoder already returns (output, decoder_features)
        return self.decoder(self.encoder(x))

    def predict(self, x: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        """
        Predict a binarized forgery mask (no gradients).

        Args:
            x: Input image tensor (B, 3, H, W)
            threshold: Probability threshold for binarization

        Returns:
            Binary mask (B, 1, H, W), values in {0.0, 1.0}
        """
        with torch.no_grad():
            logits, _ = self.forward(x)
            return (torch.sigmoid(logits) > threshold).float()

    def get_probability_map(self, x: torch.Tensor) -> torch.Tensor:
        """
        Sigmoid probability map (no gradients).

        Args:
            x: Input image tensor (B, 3, H, W)

        Returns:
            Probability map (B, 1, H, W)
        """
        with torch.no_grad():
            logits, _ = self.forward(x)
            return torch.sigmoid(logits)

    def count_parameters(self) -> int:
        """Total number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def get_decoder_features(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Decoder feature maps only, for hybrid feature extraction.

        Args:
            x: Input image tensor (B, 3, H, W)

        Returns:
            List of decoder features
        """
        with torch.no_grad():
            _, features = self.forward(x)
            return features
+
122
+
123
def get_model(config) -> ForgeryLocalizationNetwork:
    """
    Factory function to create model

    Args:
        config: Configuration object

    Returns:
        Model instance
    """
    model = ForgeryLocalizationNetwork(config)
    return model
src/training/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Training module"""
2
+
3
+ from .metrics import (
4
+ SegmentationMetrics,
5
+ ClassificationMetrics,
6
+ MetricsTracker,
7
+ EarlyStopping,
8
+ get_metrics_tracker
9
+ )
10
+
11
+ from .trainer import Trainer, get_trainer
12
+ from .classifier import ForgeryClassifier, get_classifier
13
+
14
+ __all__ = [
15
+ 'SegmentationMetrics',
16
+ 'ClassificationMetrics',
17
+ 'MetricsTracker',
18
+ 'EarlyStopping',
19
+ 'get_metrics_tracker',
20
+ 'Trainer',
21
+ 'get_trainer',
22
+ 'ForgeryClassifier',
23
+ 'get_classifier'
24
+ ]
src/training/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (568 Bytes). View file
 
src/training/__pycache__/classifier.cpython-312.pyc ADDED
Binary file (11 kB). View file
 
src/training/__pycache__/metrics.cpython-312.pyc ADDED
Binary file (12.5 kB). View file
 
src/training/__pycache__/trainer.cpython-312.pyc ADDED
Binary file (18.8 kB). View file
 
src/training/classifier.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LightGBM classifier for forgery type classification
3
+ Implements Critical Fix #8: Configurable Confidence Threshold
4
+ """
5
+
6
+ import numpy as np
7
+ import lightgbm as lgb
8
+ from sklearn.preprocessing import StandardScaler
9
+ from sklearn.model_selection import train_test_split
10
+ from typing import Dict, List, Tuple, Optional
11
+ import joblib
12
+ from pathlib import Path
13
+ import json
14
+
15
+
16
class ForgeryClassifier:
    """
    LightGBM classifier for region-wise forgery classification.

    Target classes:
    - 0: copy_move
    - 1: splicing
    - 2: text_substitution

    Implements Critical Fix #8: predictions whose confidence falls below the
    configurable ``classifier.confidence_threshold`` can be filtered out via
    predict_with_filtering().
    """

    CLASS_NAMES = ['copy_move', 'splicing', 'text_substitution']

    def __init__(self, config):
        """
        Initialize classifier.

        Args:
            config: Configuration object with dotted-key ``get(key, default)`` access
        """
        self.config = config

        # LightGBM parameters (defaults used when not present in config)
        self.params = config.get('classifier.params', {
            'objective': 'multiclass',
            'num_class': 3,
            'boosting_type': 'gbdt',
            'num_leaves': 31,
            'learning_rate': 0.05,
            'n_estimators': 200,
            'max_depth': 7,
            'min_child_samples': 20,
            'subsample': 0.8,
            'colsample_bytree': 0.8,
            'reg_alpha': 0.1,
            'reg_lambda': 0.1,
            'random_state': 42,
            'verbose': -1
        })

        # Critical Fix #8: Configurable confidence threshold
        self.confidence_threshold = config.get('classifier.confidence_threshold', 0.6)

        # Model and feature scaler (populated by train() or load())
        self.model = None
        self.scaler = StandardScaler()

        # Feature importance bookkeeping (populated after training/loading)
        self.feature_importance = None
        self.feature_names = None

    def train(self,
              features: np.ndarray,
              labels: np.ndarray,
              feature_names: Optional[List[str]] = None,
              validation_split: float = 0.2) -> Dict:
        """
        Train classifier.

        Args:
            features: Feature matrix (N, D)
            labels: Class labels (N,)
            feature_names: Optional feature names
            validation_split: Validation split ratio

        Returns:
            Training metrics (train/val accuracy, data sizes, best iteration)
        """
        print(f"Training LightGBM classifier")
        print(f"Features shape: {features.shape}")
        print(f"Labels distribution: {np.bincount(labels)}")

        # Handle NaN/Inf so scaling/boosting never sees invalid values
        features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)

        # Normalize features
        features_scaled = self.scaler.fit_transform(features)

        # Split data (Critical Fix #7: Image-level splitting should be done upstream)
        X_train, X_val, y_train, y_val = train_test_split(
            features_scaled, labels,
            test_size=validation_split,
            random_state=42,
            stratify=labels
        )

        # Create LightGBM datasets
        train_data = lgb.Dataset(X_train, label=y_train)
        val_data = lgb.Dataset(X_val, label=y_val, reference=train_data)

        # Fix: 'n_estimators' is a LightGBM alias of num_boost_round; keeping
        # it in the params dict while also passing num_boost_round makes
        # LightGBM emit an override warning. Pass it once, explicitly.
        num_boost_round = self.params.get('n_estimators', 200)
        train_params = {k: v for k, v in self.params.items() if k != 'n_estimators'}

        # Train model with early stopping on the validation fold
        self.model = lgb.train(
            train_params,
            train_data,
            valid_sets=[train_data, val_data],
            valid_names=['train', 'val'],
            num_boost_round=num_boost_round,
            callbacks=[
                lgb.early_stopping(stopping_rounds=20),
                lgb.log_evaluation(period=10)
            ]
        )

        # Store feature importance (gain-based)
        self.feature_names = feature_names
        self.feature_importance = self.model.feature_importance(importance_type='gain')

        # Evaluate on both folds
        train_pred = self.model.predict(X_train)
        train_acc = (train_pred.argmax(axis=1) == y_train).mean()

        val_pred = self.model.predict(X_val)
        val_acc = (val_pred.argmax(axis=1) == y_val).mean()

        metrics = {
            'train_accuracy': train_acc,
            'val_accuracy': val_acc,
            'num_features': features.shape[1],
            'num_samples': len(labels),
            'best_iteration': self.model.best_iteration
        }

        print(f"Training complete!")
        print(f"Train accuracy: {train_acc:.4f}")
        print(f"Val accuracy: {val_acc:.4f}")

        return metrics

    def predict(self, features: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Predict forgery types.

        Args:
            features: Feature matrix (N, D)

        Returns:
            predictions: Predicted class indices (N,)
            confidences: Prediction confidences (N,), max class probability

        Raises:
            ValueError: If the model has not been trained or loaded
        """
        if self.model is None:
            raise ValueError("Model not trained. Call train() first.")

        # Handle NaN/Inf (same sanitization as at training time)
        features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)

        # Normalize features with the scaler fitted at training time
        features_scaled = self.scaler.transform(features)

        # Predict class probabilities (N, num_class)
        probabilities = self.model.predict(features_scaled)

        # Get predictions and confidences
        predictions = probabilities.argmax(axis=1)
        confidences = probabilities.max(axis=1)

        return predictions, confidences

    def predict_with_filtering(self,
                               features: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Predict with confidence filtering.

        Args:
            features: Feature matrix (N, D)

        Returns:
            predictions: Predicted class indices (N,)
            confidences: Prediction confidences (N,)
            valid_mask: Boolean mask for predictions at/above the threshold (N,)
        """
        predictions, confidences = self.predict(features)

        # Critical Fix #8: Apply confidence threshold
        valid_mask = confidences >= self.confidence_threshold

        return predictions, confidences, valid_mask

    def get_class_name(self, class_idx: int) -> str:
        """Get class name from index."""
        return self.CLASS_NAMES[class_idx]

    def get_feature_importance(self, top_k: int = 20) -> List[Tuple[str, float]]:
        """
        Get top-k most important features.

        Args:
            top_k: Number of features to return

        Returns:
            List of (feature_name, importance) tuples, highest importance first;
            empty list if no importances are available
        """
        # Guard against both "never trained" and "loaded without importances"
        if self.feature_importance is None or len(self.feature_importance) == 0:
            return []

        # Sort by importance, descending
        indices = np.argsort(self.feature_importance)[::-1][:top_k]

        result = []
        for idx in indices:
            name = self.feature_names[idx] if self.feature_names else f'feature_{idx}'
            importance = self.feature_importance[idx]
            result.append((name, importance))

        return result

    def save(self, save_dir: str):
        """
        Save model, scaler and metadata.

        Args:
            save_dir: Directory to save model (created if missing)

        Raises:
            ValueError: If the model has not been trained or loaded
        """
        # Fail with a clear message instead of an AttributeError below
        if self.model is None:
            raise ValueError("Model not trained. Call train() or load() first.")

        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

        # Save LightGBM model
        model_path = save_path / 'lightgbm_model.txt'
        self.model.save_model(str(model_path))

        # Save scaler
        scaler_path = save_path / 'scaler.joblib'
        joblib.dump(self.scaler, str(scaler_path))

        # Save metadata
        metadata = {
            'confidence_threshold': self.confidence_threshold,
            'class_names': self.CLASS_NAMES,
            'feature_names': self.feature_names,
            'feature_importance': self.feature_importance.tolist() if self.feature_importance is not None else None
        }
        metadata_path = save_path / 'classifier_metadata.json'
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)

        print(f"Classifier saved to {save_path}")

    def load(self, load_dir: str):
        """
        Load model, scaler and metadata previously written by save().

        Args:
            load_dir: Directory to load from
        """
        load_path = Path(load_dir)

        # Load LightGBM model
        model_path = load_path / 'lightgbm_model.txt'
        self.model = lgb.Booster(model_file=str(model_path))

        # Load scaler
        scaler_path = load_path / 'scaler.joblib'
        self.scaler = joblib.load(str(scaler_path))

        # Load metadata
        metadata_path = load_path / 'classifier_metadata.json'
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)

        self.confidence_threshold = metadata.get('confidence_threshold', 0.6)
        self.feature_names = metadata.get('feature_names')

        # Bug fix: np.array(None) produces a 0-d object array that is not None,
        # which later broke get_feature_importance(). Preserve None as None.
        raw_importance = metadata.get('feature_importance')
        self.feature_importance = (
            np.array(raw_importance) if raw_importance is not None else None
        )

        print(f"Classifier loaded from {load_path}")
278
+
279
+
280
def get_classifier(config) -> ForgeryClassifier:
    """
    Build a ForgeryClassifier from a configuration object.

    Args:
        config: Configuration object, passed straight to the classifier

    Returns:
        A freshly constructed ForgeryClassifier
    """
    return ForgeryClassifier(config)
src/training/metrics.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training utilities and metrics
3
+ Implements Critical Fix #9: Dataset-Aware Metric Computation
4
+ """
5
+
6
+ import torch
7
+ import numpy as np
8
+ from typing import Dict, List, Optional
9
+ from sklearn.metrics import (
10
+ accuracy_score, f1_score, precision_score, recall_score,
11
+ confusion_matrix
12
+ )
13
+
14
+
15
class SegmentationMetrics:
    """
    Dataset-level segmentation metrics (IoU, Dice, precision, recall).

    Running sums are accumulated over batches; batches from datasets without
    pixel masks are skipped entirely (Critical Fix #9), so weak-label data
    never pollutes the localization scores.
    """

    def __init__(self):
        """Initialize with all running sums at zero."""
        self.reset()

    def reset(self):
        """Zero every running sum for a new epoch."""
        self.intersection = 0
        self.union = 0
        self.pred_sum = 0
        self.target_sum = 0
        self.total_samples = 0

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               has_pixel_mask: bool = True):
        """
        Accumulate statistics for one batch.

        Args:
            pred: Predicted probabilities (B, 1, H, W)
            target: Ground truth masks (B, 1, H, W)
            has_pixel_mask: When False, this batch is ignored (Critical Fix #9)
        """
        if not has_pixel_mask:
            return

        # Threshold probabilities at 0.5 to get a hard mask
        hard_pred = (pred > 0.5).float()

        inter = (hard_pred * target).sum().item()
        pred_total = hard_pred.sum().item()
        target_total = target.sum().item()

        self.intersection += inter
        self.union += pred_total + target_total - inter
        self.pred_sum += pred_total
        self.target_sum += target_total
        self.total_samples += pred.shape[0]

    def compute(self) -> Dict[str, float]:
        """
        Compute final metrics from the running sums.

        Returns:
            Dictionary with keys 'iou', 'dice', 'precision', 'recall'
        """
        eps = 1e-8  # guards against division by zero on empty accumulators
        return {
            'iou': self.intersection / (self.union + eps),
            'dice': (2 * self.intersection) / (self.pred_sum + self.target_sum + eps),
            'precision': self.intersection / (self.pred_sum + eps),
            'recall': self.intersection / (self.target_sum + eps),
        }
86
+
87
+
88
class ClassificationMetrics:
    """Accumulates forgery-type predictions and computes classification scores."""

    def __init__(self, num_classes: int = 3):
        """
        Initialize metrics.

        Args:
            num_classes: Number of forgery types
        """
        self.num_classes = num_classes
        self.reset()

    def reset(self):
        """Drop everything accumulated so far."""
        self.predictions = []
        self.targets = []
        self.confidences = []

    def update(self,
               pred: np.ndarray,
               target: np.ndarray,
               confidence: Optional[np.ndarray] = None):
        """
        Accumulate one batch of predictions.

        Args:
            pred: Predicted class indices
            target: Ground truth class indices
            confidence: Optional prediction confidences
        """
        self.predictions += pred.tolist()
        self.targets += target.tolist()
        if confidence is not None:
            self.confidences += confidence.tolist()

    def compute(self) -> Dict[str, float]:
        """
        Compute accuracy, macro/weighted F1, precision, recall and the
        confusion matrix over everything accumulated since reset().

        Returns:
            Dictionary of metric values; all zeros when nothing was accumulated
        """
        if not self.predictions:
            return {name: 0.0 for name in
                    ('accuracy', 'f1_macro', 'f1_weighted', 'precision', 'recall')}

        y_pred = np.array(self.predictions)
        y_true = np.array(self.targets)

        # zero_division=0 keeps classes with no predicted samples from raising
        return {
            'accuracy': accuracy_score(y_true, y_pred),
            'f1_macro': f1_score(y_true, y_pred, average='macro', zero_division=0),
            'f1_weighted': f1_score(y_true, y_pred, average='weighted', zero_division=0),
            'precision': precision_score(y_true, y_pred, average='macro', zero_division=0),
            'recall': recall_score(y_true, y_pred, average='macro', zero_division=0),
            'confusion_matrix': confusion_matrix(
                y_true, y_pred, labels=range(self.num_classes)).tolist(),
        }
165
+
166
+
167
class MetricsTracker:
    """Aggregates segmentation and classification metrics across training."""

    # Metric names whose per-epoch values are kept in self.history
    _PHASES = ('train', 'val')
    _TRACKED = ('loss', 'iou', 'dice', 'precision', 'recall')

    def __init__(self, config):
        """
        Initialize metrics tracker.

        Args:
            config: Configuration object
        """
        self.config = config
        self.num_classes = config.get('data.num_classes', 3)

        self.seg_metrics = SegmentationMetrics()
        self.cls_metrics = ClassificationMetrics(self.num_classes)

        # One list per phase/metric pair, e.g. 'train_loss', 'val_iou', ...
        self.history = {f'{phase}_{name}': []
                        for phase in self._PHASES
                        for name in self._TRACKED}

    def reset(self):
        """Clear per-epoch state in both sub-trackers."""
        self.seg_metrics.reset()
        self.cls_metrics.reset()

    def update_segmentation(self,
                            pred: torch.Tensor,
                            target: torch.Tensor,
                            dataset_name: str):
        """Update segmentation metrics; skipped for datasets without pixel masks."""
        self.seg_metrics.update(
            pred, target,
            self.config.should_compute_localization_metrics(dataset_name)
        )

    def update_classification(self,
                              pred: np.ndarray,
                              target: np.ndarray,
                              confidence: Optional[np.ndarray] = None):
        """Accumulate classification predictions."""
        self.cls_metrics.update(pred, target, confidence)

    def compute_all(self) -> Dict[str, float]:
        """
        Compute all metrics.

        Returns:
            Segmentation metrics, merged with 'cls_'-prefixed classification
            metrics whenever any classification data was accumulated
        """
        combined = dict(self.seg_metrics.compute())

        if len(self.cls_metrics.predictions) > 0:
            # Prefix classification keys to avoid colliding with seg keys
            for key, value in self.cls_metrics.compute().items():
                combined[f'cls_{key}'] = value

        return combined

    def log_epoch(self, epoch: int, phase: str, loss: float, metrics: Dict):
        """Append one epoch's loss and metric values to the running history."""
        self.history[f'{phase}_loss'].append(loss)

        for name in ('iou', 'dice', 'precision', 'recall'):
            if name in metrics:
                self.history[f'{phase}_{name}'].append(metrics[name])

    def get_history(self) -> Dict:
        """Return the full accumulated training history."""
        return self.history
247
+
248
+
249
class EarlyStopping:
    """Signals when training should stop because the monitored value plateaued."""

    def __init__(self,
                 patience: int = 10,
                 min_delta: float = 0.001,
                 mode: str = 'max'):
        """
        Initialize early stopping.

        Args:
            patience: Consecutive non-improving epochs tolerated before stopping
            min_delta: Minimum change that counts as an improvement
            mode: 'max' when larger is better (metrics), 'min' for losses
        """
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode

        self.counter = 0
        self.best_value = None
        self.should_stop = False

    def __call__(self, value: float) -> bool:
        """
        Record the latest monitored value.

        Args:
            value: Current metric (or loss) value

        Returns:
            True once `patience` epochs have passed without improvement
        """
        # First observation simply becomes the baseline.
        if self.best_value is None:
            self.best_value = value
            return False

        if self.mode == 'max':
            has_improved = value > self.best_value + self.min_delta
        else:
            has_improved = value < self.best_value - self.min_delta

        if has_improved:
            self.best_value = value
            self.counter = 0
        else:
            self.counter += 1

        if self.counter >= self.patience:
            self.should_stop = True

        return self.should_stop
301
+
302
+
303
def get_metrics_tracker(config) -> MetricsTracker:
    """
    Build a MetricsTracker from a configuration object.

    Args:
        config: Configuration object, passed straight to the tracker

    Returns:
        A freshly constructed MetricsTracker
    """
    return MetricsTracker(config)
src/training/trainer.py ADDED
@@ -0,0 +1,450 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training loop for forgery localization network
3
+ Implements chunked training for RAM constraints
4
+ """
5
+
6
+ import os
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.optim as optim
10
+ from torch.utils.data import DataLoader
11
+ from torch.cuda.amp import autocast, GradScaler
12
+ from typing import Dict, Optional, Tuple
13
+ from pathlib import Path
14
+ from tqdm import tqdm
15
+ import json
16
+ import csv
17
+
18
+ from ..models import get_model, get_loss_function
19
+ from ..data import get_dataset
20
+ from .metrics import MetricsTracker, EarlyStopping
21
+
22
+
23
class Trainer:
    """
    Trainer for forgery localization network.

    Supports chunked training for large datasets (DocTamper): a chunk is a
    [chunk_start, chunk_end) fraction of the training split, so very large
    datasets can be trained in RAM-sized pieces with checkpoints between
    chunks. Uses mixed-precision (AMP) training, cosine LR scheduling,
    dataset-aware loss/metrics, early stopping, and CSV/JSON logging.
    """

    def __init__(self, config, dataset_name: str = 'doctamper'):
        """
        Initialize trainer

        Args:
            config: Configuration object
            dataset_name: Dataset to train on
        """
        self.config = config
        self.dataset_name = dataset_name

        # Device setup: CUDA only when both available and requested in config
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() and config.get('system.device') == 'cuda'
            else 'cpu'
        )
        print(f"Training on: {self.device}")

        # Initialize model
        self.model = get_model(config).to(self.device)

        # Loss function (dataset-aware)
        self.criterion = get_loss_function(config)

        # Optimizer
        lr = config.get('training.learning_rate', 0.001)
        weight_decay = config.get('training.weight_decay', 0.0001)
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=lr,
            weight_decay=weight_decay
        )

        # Learning rate scheduler
        # NOTE(review): warmup_epochs only shortens the cosine period (T_0);
        # no actual LR warmup phase is implemented here — confirm intended.
        epochs = config.get('training.epochs', 50)
        warmup_epochs = config.get('training.scheduler.warmup_epochs', 5)
        min_lr = config.get('training.scheduler.min_lr', 1e-5)

        self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer,
            T_0=epochs - warmup_epochs,
            T_mult=1,
            eta_min=min_lr
        )

        # Mixed precision training
        # NOTE(review): torch.cuda.amp.GradScaler degrades to a no-op (with a
        # warning) when CUDA is unavailable — verify acceptable for CPU runs.
        self.scaler = GradScaler()

        # Metrics
        self.metrics_tracker = MetricsTracker(config)

        # Early stopping (monitors val Dice, see train())
        patience = config.get('training.early_stopping.patience', 10)
        min_delta = config.get('training.early_stopping.min_delta', 0.001)
        self.early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)

        # Output directories (created eagerly so saves never fail on path)
        self.checkpoint_dir = Path(config.get('outputs.checkpoints', 'outputs/checkpoints'))
        self.log_dir = Path(config.get('outputs.logs', 'outputs/logs'))
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        # Training state
        self.current_epoch = 0
        self.best_metric = 0.0  # best validation Dice seen so far

    def create_dataloaders(self,
                           chunk_start: float = 0.0,
                           chunk_end: float = 1.0) -> Tuple[DataLoader, DataLoader]:
        """
        Create train and validation dataloaders

        Args:
            chunk_start: Start ratio for chunked training (DocTamper only)
            chunk_end: End ratio for chunked training (DocTamper only)

        Returns:
            Train and validation dataloaders
        """
        batch_size = self.config.get('data.batch_size', 8)
        num_workers = self.config.get('system.num_workers', 4)

        # Training dataset (with chunking for DocTamper; others load in full)
        if self.dataset_name == 'doctamper':
            train_dataset = get_dataset(
                self.config,
                self.dataset_name,
                split='train',
                chunk_start=chunk_start,
                chunk_end=chunk_end
            )
        else:
            train_dataset = get_dataset(
                self.config,
                self.dataset_name,
                split='train'
            )

        # Validation dataset (always full)
        # For FCD and SCD, validate on DocTamper TestingSet
        if self.dataset_name in ['fcd', 'scd']:
            val_dataset = get_dataset(
                self.config,
                'doctamper',  # Use DocTamper for validation
                split='val'
            )
        else:
            # Datasets without a dedicated val split fall back to 'test'
            val_dataset = get_dataset(
                self.config,
                self.dataset_name,
                split='val' if self.dataset_name in ['doctamper', 'receipts'] else 'test'
            )

        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=self.config.get('system.pin_memory', True),
            drop_last=True  # keep batch shape stable for BN/AMP
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True
        )

        return train_loader, val_loader

    def train_epoch(self, dataloader: DataLoader) -> Tuple[float, Dict]:
        """
        Train for one epoch

        Args:
            dataloader: Training dataloader

        Returns:
            Average loss and metrics dictionary for the epoch
        """
        self.model.train()
        self.metrics_tracker.reset()

        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(dataloader, desc=f"Epoch {self.current_epoch} [Train]")

        for batch_idx, (images, masks, metadata) in enumerate(pbar):
            images = images.to(self.device)
            masks = masks.to(self.device)

            # Forward pass with mixed precision
            self.optimizer.zero_grad()

            with autocast():
                outputs, _ = self.model(images)

                # Dataset-aware loss: datasets without pixel masks are handled
                # differently by the criterion (see combined_loss)
                has_pixel_mask = self.config.has_pixel_mask(self.dataset_name)
                losses = self.criterion.combined_loss(outputs, masks, has_pixel_mask)

            # Backward pass with gradient scaling (AMP)
            self.scaler.scale(losses['total']).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            # Update metrics (sigmoid converts logits to probabilities)
            with torch.no_grad():
                probs = torch.sigmoid(outputs)
                self.metrics_tracker.update_segmentation(
                    probs, masks, self.dataset_name
                )

            total_loss += losses['total'].item()
            num_batches += 1

            # Update progress bar
            pbar.set_postfix({
                'loss': f"{losses['total'].item():.4f}",
                'bce': f"{losses['bce'].item():.4f}"
            })

        avg_loss = total_loss / num_batches
        metrics = self.metrics_tracker.compute_all()

        return avg_loss, metrics

    def validate(self, dataloader: DataLoader) -> Tuple[float, Dict]:
        """
        Validate model

        Args:
            dataloader: Validation dataloader

        Returns:
            Average loss and metrics dictionary for the validation pass
        """
        self.model.eval()
        self.metrics_tracker.reset()

        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(dataloader, desc=f"Epoch {self.current_epoch} [Val]")

        with torch.no_grad():
            for images, masks, metadata in pbar:
                images = images.to(self.device)
                masks = masks.to(self.device)

                # Forward pass (no AMP autocast needed for evaluation here)
                outputs, _ = self.model(images)

                # Dataset-aware loss
                has_pixel_mask = self.config.has_pixel_mask(self.dataset_name)
                losses = self.criterion.combined_loss(outputs, masks, has_pixel_mask)

                # Update metrics
                probs = torch.sigmoid(outputs)
                self.metrics_tracker.update_segmentation(
                    probs, masks, self.dataset_name
                )

                total_loss += losses['total'].item()
                num_batches += 1

                pbar.set_postfix({
                    'loss': f"{losses['total'].item():.4f}"
                })

        avg_loss = total_loss / num_batches
        metrics = self.metrics_tracker.compute_all()

        return avg_loss, metrics

    def save_checkpoint(self,
                        filename: str,
                        is_best: bool = False,
                        chunk_id: Optional[int] = None):
        """
        Save model checkpoint (model/optimizer/scheduler state + metadata).

        When is_best is True, an additional copy is written to
        best_<dataset>.pth so the best model is always easy to locate.
        """
        checkpoint = {
            'epoch': self.current_epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_metric': self.best_metric,
            'dataset': self.dataset_name,
            'chunk_id': chunk_id
        }

        path = self.checkpoint_dir / filename
        torch.save(checkpoint, path)
        print(f"Saved checkpoint: {path}")

        if is_best:
            best_path = self.checkpoint_dir / f'best_{self.dataset_name}.pth'
            torch.save(checkpoint, best_path)
            print(f"Saved best model: {best_path}")

    def load_checkpoint(self, filename: str, reset_epoch: bool = False) -> bool:
        """
        Load model checkpoint

        Args:
            filename: Checkpoint filename
            reset_epoch: If True, reset epoch counter to 0 (useful for chunked training)

        Returns:
            True if the checkpoint was found and loaded, False otherwise
        """
        path = self.checkpoint_dir / filename

        if not path.exists():
            print(f"Checkpoint not found: {path}")
            return False

        checkpoint = torch.load(path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        if reset_epoch:
            # Chunked training: new chunk restarts its own epoch count
            self.current_epoch = 0
            print(f"Loaded checkpoint: {path} (epoch counter reset to 0)")
        else:
            self.current_epoch = checkpoint['epoch'] + 1  # Continue from next epoch
            print(f"Loaded checkpoint: {path} (resuming from epoch {self.current_epoch})")

        self.best_metric = checkpoint.get('best_metric', 0.0)

        return True

    def train(self,
              epochs: Optional[int] = None,
              chunk_start: float = 0.0,
              chunk_end: float = 1.0,
              chunk_id: Optional[int] = None,
              resume_from: Optional[str] = None):
        """
        Main training loop

        Args:
            epochs: Number of epochs (None uses config)
            chunk_start: Start ratio for chunked training
            chunk_end: End ratio for chunked training
            chunk_id: Chunk identifier for logging
            resume_from: Checkpoint to resume from

        Returns:
            The full training history dictionary from the metrics tracker
        """
        if epochs is None:
            epochs = self.config.get('training.epochs', 50)

        # Resume if specified (restores epoch counter and best metric)
        if resume_from:
            self.load_checkpoint(resume_from)

        # Create dataloaders
        train_loader, val_loader = self.create_dataloaders(chunk_start, chunk_end)

        print(f"\n{'='*60}")
        print(f"Training: {self.dataset_name}")
        if chunk_id is not None:
            print(f"Chunk: {chunk_id} [{chunk_start*100:.0f}% - {chunk_end*100:.0f}%]")
        print(f"Epochs: {epochs}")
        print(f"Train samples: {len(train_loader.dataset)}")
        print(f"Val samples: {len(val_loader.dataset)}")
        print(f"{'='*60}\n")

        # Training log file.
        # NOTE(review): opened with 'w', so resuming mid-chunk overwrites the
        # earlier rows of this chunk's CSV log — confirm acceptable.
        log_file = self.log_dir / f'{self.dataset_name}_chunk{chunk_id or 0}_log.csv'
        with open(log_file, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['epoch', 'train_loss', 'val_loss',
                             'train_iou', 'val_iou', 'train_dice', 'val_dice',
                             'train_precision', 'val_precision',
                             'train_recall', 'val_recall', 'lr'])

        for epoch in range(self.current_epoch, epochs):
            self.current_epoch = epoch

            # Train
            train_loss, train_metrics = self.train_epoch(train_loader)

            # Validate
            val_loss, val_metrics = self.validate(val_loader)

            # Update scheduler (stepped once per epoch)
            self.scheduler.step()
            current_lr = self.optimizer.param_groups[0]['lr']

            # Log metrics to the in-memory tracker history
            self.metrics_tracker.log_epoch(epoch, 'train', train_loss, train_metrics)
            self.metrics_tracker.log_epoch(epoch, 'val', val_loss, val_metrics)

            # Log to file (appended one row per epoch)
            with open(log_file, 'a', newline='') as f:
                writer = csv.writer(f)
                writer.writerow([
                    epoch,
                    f"{train_loss:.4f}",
                    f"{val_loss:.4f}",
                    f"{train_metrics.get('iou', 0):.4f}",
                    f"{val_metrics.get('iou', 0):.4f}",
                    f"{train_metrics.get('dice', 0):.4f}",
                    f"{val_metrics.get('dice', 0):.4f}",
                    f"{train_metrics.get('precision', 0):.4f}",
                    f"{val_metrics.get('precision', 0):.4f}",
                    f"{train_metrics.get('recall', 0):.4f}",
                    f"{val_metrics.get('recall', 0):.4f}",
                    f"{current_lr:.6f}"
                ])

            # Print summary
            print(f"\nEpoch {epoch}/{epochs-1}")
            print(f"  Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
            print(f"  Train IoU: {train_metrics.get('iou', 0):.4f} | Val IoU: {val_metrics.get('iou', 0):.4f}")
            print(f"  Train Dice: {train_metrics.get('dice', 0):.4f} | Val Dice: {val_metrics.get('dice', 0):.4f}")
            print(f"  LR: {current_lr:.6f}")

            # Save periodic checkpoints every 'save_every' epochs
            if self.config.get('training.checkpoint.save_every', 5) > 0:
                if (epoch + 1) % self.config.get('training.checkpoint.save_every', 5) == 0:
                    self.save_checkpoint(
                        f'{self.dataset_name}_chunk{chunk_id or 0}_epoch{epoch}.pth',
                        chunk_id=chunk_id
                    )

            # Check for best model (monitored on validation Dice)
            monitor_metric = val_metrics.get('dice', 0)
            if monitor_metric > self.best_metric:
                self.best_metric = monitor_metric
                self.save_checkpoint(
                    f'{self.dataset_name}_chunk{chunk_id or 0}_best.pth',
                    is_best=True,
                    chunk_id=chunk_id
                )

            # Early stopping
            if self.early_stopping(monitor_metric):
                print(f"\nEarly stopping triggered at epoch {epoch}")
                break

        # Save final checkpoint
        self.save_checkpoint(
            f'{self.dataset_name}_chunk{chunk_id or 0}_final.pth',
            chunk_id=chunk_id
        )

        # Save training history
        history_file = self.log_dir / f'{self.dataset_name}_chunk{chunk_id or 0}_history.json'
        with open(history_file, 'w') as f:
            json.dump(self.metrics_tracker.get_history(), f, indent=2)

        print(f"\nTraining complete!")
        print(f"Best Dice: {self.best_metric:.4f}")

        return self.metrics_tracker.get_history()
+
447
+
448
def get_trainer(config, dataset_name: str = 'doctamper') -> Trainer:
    """
    Build a Trainer for the requested dataset.

    Args:
        config: Configuration object
        dataset_name: Dataset identifier to train on

    Returns:
        A configured Trainer instance
    """
    return Trainer(config, dataset_name)
src/utils/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utilities module"""
2
+
3
+ from .plotting import (
4
+ plot_training_curves,
5
+ plot_confusion_matrix,
6
+ plot_feature_importance,
7
+ plot_dataset_comparison,
8
+ plot_chunked_training_progress,
9
+ generate_training_report
10
+ )
11
+
12
+ from .export import (
13
+ export_to_onnx,
14
+ export_to_torchscript,
15
+ quantize_model
16
+ )
17
+
18
+ __all__ = [
19
+ 'plot_training_curves',
20
+ 'plot_confusion_matrix',
21
+ 'plot_feature_importance',
22
+ 'plot_dataset_comparison',
23
+ 'plot_chunked_training_progress',
24
+ 'generate_training_report',
25
+ 'export_to_onnx',
26
+ 'export_to_torchscript',
27
+ 'quantize_model'
28
+ ]
src/utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (601 Bytes). View file