JKrishnanandhaa commited on
Commit
6bdcef9
·
verified ·
1 Parent(s): 770b89a

Delete deployment

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. deployment/app.py +0 -231
  2. deployment/config.yaml +0 -297
  3. deployment/models/best_doctamper.pth +0 -3
  4. deployment/models/classifier/classifier_metadata.json +0 -821
  5. deployment/models/classifier/lightgbm_model.txt +0 -0
  6. deployment/models/classifier/scaler.joblib +0 -3
  7. deployment/src/__init__.py +0 -32
  8. deployment/src/__pycache__/__init__.cpython-312.pyc +0 -0
  9. deployment/src/config/__init__.py +0 -5
  10. deployment/src/config/__pycache__/__init__.cpython-312.pyc +0 -0
  11. deployment/src/config/__pycache__/config_loader.cpython-312.pyc +0 -0
  12. deployment/src/config/config_loader.py +0 -117
  13. deployment/src/data/__init__.py +0 -23
  14. deployment/src/data/__pycache__/__init__.cpython-312.pyc +0 -0
  15. deployment/src/data/__pycache__/augmentation.cpython-312.pyc +0 -0
  16. deployment/src/data/__pycache__/datasets.cpython-312.pyc +0 -0
  17. deployment/src/data/__pycache__/preprocessing.cpython-312.pyc +0 -0
  18. deployment/src/data/augmentation.py +0 -150
  19. deployment/src/data/datasets.py +0 -541
  20. deployment/src/data/preprocessing.py +0 -226
  21. deployment/src/features/__init__.py +0 -32
  22. deployment/src/features/__pycache__/__init__.cpython-312.pyc +0 -0
  23. deployment/src/features/__pycache__/feature_extraction.cpython-312.pyc +0 -0
  24. deployment/src/features/__pycache__/region_extraction.cpython-312.pyc +0 -0
  25. deployment/src/features/feature_extraction.py +0 -485
  26. deployment/src/features/region_extraction.py +0 -226
  27. deployment/src/inference/__init__.py +0 -5
  28. deployment/src/inference/__pycache__/__init__.cpython-312.pyc +0 -0
  29. deployment/src/inference/__pycache__/pipeline.cpython-312.pyc +0 -0
  30. deployment/src/inference/pipeline.py +0 -359
  31. deployment/src/models/__init__.py +0 -19
  32. deployment/src/models/__pycache__/__init__.cpython-312.pyc +0 -0
  33. deployment/src/models/__pycache__/decoder.cpython-312.pyc +0 -0
  34. deployment/src/models/__pycache__/encoder.cpython-312.pyc +0 -0
  35. deployment/src/models/__pycache__/losses.cpython-312.pyc +0 -0
  36. deployment/src/models/__pycache__/network.cpython-312.pyc +0 -0
  37. deployment/src/models/decoder.py +0 -186
  38. deployment/src/models/encoder.py +0 -75
  39. deployment/src/models/losses.py +0 -168
  40. deployment/src/models/network.py +0 -133
  41. deployment/src/training/__init__.py +0 -24
  42. deployment/src/training/__pycache__/__init__.cpython-312.pyc +0 -0
  43. deployment/src/training/__pycache__/classifier.cpython-312.pyc +0 -0
  44. deployment/src/training/__pycache__/metrics.cpython-312.pyc +0 -0
  45. deployment/src/training/__pycache__/trainer.cpython-312.pyc +0 -0
  46. deployment/src/training/classifier.py +0 -282
  47. deployment/src/training/metrics.py +0 -305
  48. deployment/src/training/trainer.py +0 -450
  49. deployment/src/utils/__init__.py +0 -28
  50. deployment/src/utils/__pycache__/__init__.cpython-312.pyc +0 -0
deployment/app.py DELETED
@@ -1,231 +0,0 @@
1
- """
2
- Document Forgery Detection - Gradio Interface for Hugging Face Spaces
3
-
4
- This app provides a web interface for detecting and classifying document forgeries.
5
- """
6
-
7
- import gradio as gr
8
- import torch
9
- import cv2
10
- import numpy as np
11
- from PIL import Image
12
- import json
13
- from pathlib import Path
14
- import sys
15
-
16
- # Add src to path
17
- sys.path.insert(0, str(Path(__file__).parent))
18
-
19
- from src.models import get_model
20
- from src.config import get_config
21
- from src.data.preprocessing import DocumentPreprocessor
22
- from src.data.augmentation import DatasetAwareAugmentation
23
- from src.features.region_extraction import get_mask_refiner, get_region_extractor
24
- from src.features.feature_extraction import get_feature_extractor
25
- from src.training.classifier import ForgeryClassifier
26
-
27
- # Class names
28
- CLASS_NAMES = {0: 'Copy-Move', 1: 'Splicing', 2: 'Generation'}
29
- CLASS_COLORS = {
30
- 0: (255, 0, 0), # Red for Copy-Move
31
- 1: (0, 255, 0), # Green for Splicing
32
- 2: (0, 0, 255) # Blue for Generation
33
- }
34
-
35
-
36
- class ForgeryDetector:
37
- """Main forgery detection pipeline"""
38
-
39
- def __init__(self):
40
- print("Loading models...")
41
-
42
- # Load config
43
- self.config = get_config('config.yaml')
44
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
45
-
46
- # Load segmentation model
47
- self.model = get_model(self.config).to(self.device)
48
- checkpoint = torch.load('models/segmentation_model.pth', map_location=self.device)
49
- self.model.load_state_dict(checkpoint['model_state_dict'])
50
- self.model.eval()
51
-
52
- # Load classifier
53
- self.classifier = ForgeryClassifier(self.config)
54
- self.classifier.load('models/classifier')
55
-
56
- # Initialize components
57
- self.preprocessor = DocumentPreprocessor(self.config, 'doctamper')
58
- self.augmentation = DatasetAwareAugmentation(self.config, 'doctamper', is_training=False)
59
- self.mask_refiner = get_mask_refiner(self.config)
60
- self.region_extractor = get_region_extractor(self.config)
61
- self.feature_extractor = get_feature_extractor(self.config, is_text_document=True)
62
-
63
- print("✓ Models loaded successfully!")
64
-
65
- def detect(self, image):
66
- """
67
- Detect forgeries in document image
68
-
69
- Args:
70
- image: PIL Image or numpy array
71
-
72
- Returns:
73
- overlay_image: Image with detection overlay
74
- results_json: Detection results as JSON
75
- """
76
- # Convert PIL to numpy
77
- if isinstance(image, Image.Image):
78
- image = np.array(image)
79
-
80
- # Convert to RGB
81
- if len(image.shape) == 2:
82
- image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
83
- elif image.shape[2] == 4:
84
- image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
85
-
86
- original_image = image.copy()
87
-
88
- # Preprocess
89
- preprocessed, _ = self.preprocessor(image, None)
90
-
91
- # Augment
92
- augmented = self.augmentation(preprocessed, None)
93
- image_tensor = augmented['image'].unsqueeze(0).to(self.device)
94
-
95
- # Run localization
96
- with torch.no_grad():
97
- logits, decoder_features = self.model(image_tensor)
98
- prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]
99
-
100
- # Refine mask
101
- binary_mask = (prob_map > 0.5).astype(np.uint8)
102
- refined_mask = self.mask_refiner.refine(binary_mask, original_size=original_image.shape[:2])
103
-
104
- # Extract regions
105
- regions = self.region_extractor.extract(refined_mask, prob_map, original_image)
106
-
107
- # Classify regions
108
- results = []
109
- for region in regions:
110
- # Extract features
111
- features = self.feature_extractor.extract(
112
- preprocessed,
113
- region['region_mask'],
114
- [f.cpu() for f in decoder_features]
115
- )
116
-
117
- # Classify
118
- predictions, confidences = self.classifier.predict(features)
119
- forgery_type = int(predictions[0])
120
- confidence = float(confidences[0])
121
-
122
- if confidence > 0.6: # Confidence threshold
123
- results.append({
124
- 'region_id': region['region_id'],
125
- 'bounding_box': region['bounding_box'],
126
- 'forgery_type': CLASS_NAMES[forgery_type],
127
- 'confidence': confidence
128
- })
129
-
130
- # Create visualization
131
- overlay = self._create_overlay(original_image, results)
132
-
133
- # Create JSON response
134
- json_results = {
135
- 'num_detections': len(results),
136
- 'detections': results,
137
- 'model_info': {
138
- 'segmentation_dice': '75%',
139
- 'classifier_accuracy': '92%'
140
- }
141
- }
142
-
143
- return overlay, json_results
144
-
145
- def _create_overlay(self, image, results):
146
- """Create overlay visualization"""
147
- overlay = image.copy()
148
-
149
- # Draw bounding boxes and labels
150
- for result in results:
151
- bbox = result['bounding_box']
152
- x, y, w, h = bbox
153
-
154
- forgery_type = result['forgery_type']
155
- confidence = result['confidence']
156
-
157
- # Get color
158
- forgery_id = [k for k, v in CLASS_NAMES.items() if v == forgery_type][0]
159
- color = CLASS_COLORS[forgery_id]
160
-
161
- # Draw rectangle
162
- cv2.rectangle(overlay, (x, y), (x+w, y+h), color, 2)
163
-
164
- # Draw label
165
- label = f"{forgery_type}: {confidence:.1%}"
166
- label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
167
- cv2.rectangle(overlay, (x, y-label_size[1]-10), (x+label_size[0], y), color, -1)
168
- cv2.putText(overlay, label, (x, y-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
169
-
170
- # Add legend
171
- if len(results) > 0:
172
- legend_y = 30
173
- cv2.putText(overlay, f"Detected {len(results)} forgery region(s)",
174
- (10, legend_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
175
-
176
- return overlay
177
-
178
-
179
- # Initialize detector
180
- detector = ForgeryDetector()
181
-
182
-
183
- def detect_forgery(image):
184
- """Gradio interface function"""
185
- try:
186
- overlay, results = detector.detect(image)
187
- return overlay, json.dumps(results, indent=2)
188
- except Exception as e:
189
- return None, f"Error: {str(e)}"
190
-
191
-
192
- # Create Gradio interface
193
- demo = gr.Interface(
194
- fn=detect_forgery,
195
- inputs=gr.Image(type="pil", label="Upload Document Image"),
196
- outputs=[
197
- gr.Image(type="numpy", label="Detection Result"),
198
- gr.JSON(label="Detection Details")
199
- ],
200
- title="📄 Document Forgery Detector",
201
- description="""
202
- Upload a document image to detect and classify forgeries.
203
-
204
- **Supported Forgery Types:**
205
- - 🔴 Copy-Move: Duplicated regions within the document
206
- - 🟢 Splicing: Content from different sources
207
- - 🔵 Generation: AI-generated or synthesized content
208
-
209
- **Model Performance:**
210
- - Localization: 75% Dice Score
211
- - Classification: 92% Accuracy
212
- """,
213
- examples=[
214
- ["examples/sample1.jpg"],
215
- ["examples/sample2.jpg"],
216
- ],
217
- article="""
218
- ### About
219
- This model uses a hybrid deep learning approach:
220
- 1. **Localization**: MobileNetV3-Small + UNet-Lite (detects WHERE)
221
- 2. **Classification**: LightGBM with hybrid features (detects WHAT)
222
-
223
- Trained on DocTamper dataset (140K samples).
224
- """,
225
- theme=gr.themes.Soft(),
226
- allow_flagging="never"
227
- )
228
-
229
-
230
- if __name__ == "__main__":
231
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deployment/config.yaml DELETED
@@ -1,297 +0,0 @@
1
- # Hybrid Document Forgery Detection - Configuration
2
-
3
- # System Settings
4
- system:
5
- device: cuda # cuda or cpu
6
- num_workers: 0 # Reduced to avoid multiprocessing errors
7
- pin_memory: true
8
- seed: 42
9
-
10
- # Data Settings
11
- data:
12
- image_size: 384
13
- batch_size: 8 # Reduced for 16GB RAM
14
- num_classes: 3 # copy_move, splicing, text_substitution
15
-
16
- # Dataset paths
17
- datasets:
18
- doctamper:
19
- path: datasets/DocTamper
20
- type: lmdb
21
- has_pixel_mask: true
22
- min_region_area: 0.001 # 0.1%
23
-
24
- rtm:
25
- path: datasets/RealTextManipulation
26
- type: folder
27
- has_pixel_mask: true
28
- min_region_area: 0.0003 # 0.03%
29
-
30
- casia:
31
- path: datasets/CASIA 1.0 dataset
32
- type: folder
33
- has_pixel_mask: false
34
- min_region_area: 0.001 # 0.1%
35
- skip_deskew: true
36
- skip_denoising: true
37
-
38
- receipts:
39
- path: datasets/findit2
40
- type: folder
41
- has_pixel_mask: true
42
- min_region_area: 0.0005 # 0.05%
43
-
44
- fcd:
45
- path: datasets/DocTamper/DocTamperV1-FCD
46
- type: lmdb
47
- has_pixel_mask: true
48
- min_region_area: 0.00035 # 0.035% (larger forgeries, keep 99%)
49
-
50
- scd:
51
- path: datasets/DocTamper/DocTamperV1-SCD
52
- type: lmdb
53
- has_pixel_mask: true
54
- min_region_area: 0.00009 # 0.009% (small forgeries, keep 91.5%)
55
-
56
- # Chunked training for DocTamper (RAM constraint)
57
- chunked_training:
58
- enabled: true
59
- dataset: doctamper
60
- chunks:
61
- - {start: 0.0, end: 0.25, name: "chunk_1"}
62
- - {start: 0.25, end: 0.5, name: "chunk_2"}
63
- - {start: 0.5, end: 0.75, name: "chunk_3"}
64
- - {start: 0.75, end: 1.0, name: "chunk_4"}
65
-
66
- # Mixed dataset training (TrainingSet + FCD + SCD)
67
- mixing_ratios:
68
- doctamper: 0.70 # 70% TrainingSet (maintains baseline)
69
- scd: 0.20 # 20% SCD (handles small forgeries, 0.88% avg)
70
- fcd: 0.10 # 10% FCD (adds diversity, 3.55% avg)
71
-
72
- # Preprocessing
73
- preprocessing:
74
- deskew: true
75
- normalize: true
76
- noise_threshold: 15.0 # Laplacian variance threshold
77
- median_filter_size: 3
78
- gaussian_sigma: 0.8
79
-
80
- # Dataset-aware preprocessing
81
- dataset_specific:
82
- casia:
83
- deskew: false
84
- denoising: false
85
-
86
- # Augmentation (Training only)
87
- augmentation:
88
- enabled: true
89
-
90
- # Common augmentations
91
- common:
92
- - {type: "noise", prob: 0.3}
93
- - {type: "motion_blur", prob: 0.2}
94
- - {type: "jpeg_compression", prob: 0.3, quality: [60, 95]}
95
- - {type: "lighting", prob: 0.3}
96
- - {type: "perspective", prob: 0.2}
97
-
98
- # Dataset-specific augmentations
99
- receipts:
100
- - {type: "stain", prob: 0.2}
101
- - {type: "fold", prob: 0.15}
102
-
103
- # Model Architecture
104
- model:
105
- # Encoder
106
- encoder:
107
- name: mobilenetv3_small_100
108
- pretrained: true
109
- features_only: true
110
-
111
- # Decoder
112
- decoder:
113
- name: unet_lite
114
- channels: [16, 24, 40, 48, 96] # MobileNetV3-Small feature channels
115
- upsampling: bilinear
116
- use_depthwise_separable: true
117
-
118
- # Output
119
- output_channels: 1 # Binary forgery mask
120
-
121
- # Loss Function
122
- loss:
123
- # Dataset-aware loss
124
- use_dice: true # Only for datasets with pixel masks
125
- bce_weight: 1.0
126
- dice_weight: 1.0
127
-
128
- # Training
129
- training:
130
- epochs: 30 # Per chunk (increased for single-pass training)
131
- learning_rate: 0.001 # Higher initial LR for faster convergence
132
- weight_decay: 0.0001 # Slight increase for better regularization
133
-
134
- # Optimizer
135
- optimizer: adamw
136
-
137
- # Scheduler
138
- scheduler:
139
- type: cosine_annealing_warm_restarts
140
- T_0: 10 # Restart every 10 epochs
141
- T_mult: 2 # Double restart period each time
142
- warmup_epochs: 3 # Warmup for first 3 epochs
143
- min_lr: 0.00001 # End at 1/100th of initial LR
144
-
145
- # Early stopping
146
- early_stopping:
147
- enabled: true
148
- patience: 10 # Increased to allow more exploration
149
- min_delta: 0.0005 # Accept smaller improvements (0.05%)
150
- restore_best_weights: true # Restore best model when stopping
151
- monitor: val_dice
152
- mode: max
153
-
154
- # Checkpointing
155
- checkpoint:
156
- save_best: true
157
- save_every: 5 # Save every 5 epochs
158
- save_last: true # Also save last checkpoint
159
- monitor: val_dice
160
-
161
- # Mask Refinement
162
- mask_refinement:
163
- threshold: 0.5
164
- morphology:
165
- closing_kernel: 5
166
- opening_kernel: 3
167
-
168
- # Adaptive thresholds per dataset
169
- min_region_area:
170
- rtm: 0.0003
171
- receipts: 0.0005
172
- default: 0.001
173
-
174
- # Feature Extraction
175
- features:
176
- # Deep features
177
- deep:
178
- enabled: true
179
- pooling: gap # Global Average Pooling
180
-
181
- # Statistical & Shape features
182
- statistical:
183
- enabled: true
184
- features:
185
- - area
186
- - perimeter
187
- - aspect_ratio
188
- - solidity
189
- - eccentricity
190
- - entropy
191
-
192
- # Frequency-domain features
193
- frequency:
194
- enabled: true
195
- features:
196
- - dct_coefficients
197
- - high_frequency_energy
198
- - wavelet_energy
199
-
200
- # Noise & ELA features
201
- noise:
202
- enabled: true
203
- features:
204
- - ela_mean
205
- - ela_variance
206
- - noise_residual
207
-
208
- # OCR-consistency features (text documents only)
209
- ocr:
210
- enabled: true
211
- gated: true # Only for text documents
212
- features:
213
- - confidence_deviation
214
- - spacing_irregularity
215
- - stroke_width_variation
216
-
217
- # Feature normalization
218
- normalization:
219
- method: standard_scaler
220
- handle_missing: true
221
-
222
- # LightGBM Classifier
223
- classifier:
224
- model: lightgbm
225
- params:
226
- objective: multiclass
227
- num_class: 3
228
- boosting_type: gbdt
229
- num_leaves: 31
230
- learning_rate: 0.05
231
- n_estimators: 200
232
- max_depth: 7
233
- min_child_samples: 20
234
- subsample: 0.8
235
- colsample_bytree: 0.8
236
- reg_alpha: 0.1
237
- reg_lambda: 0.1
238
- random_state: 42
239
-
240
- # Confidence threshold
241
- confidence_threshold: 0.6
242
-
243
- # Metrics
244
- metrics:
245
- # Localization metrics (only for datasets with pixel masks)
246
- localization:
247
- - iou
248
- - dice
249
- - precision
250
- - recall
251
-
252
- # Classification metrics
253
- classification:
254
- - accuracy
255
- - f1_score
256
- - precision
257
- - recall
258
- - confusion_matrix
259
-
260
- # Dataset-aware metric computation
261
- compute_localization:
262
- doctamper: true
263
- rtm: true
264
- casia: false
265
- receipts: true
266
-
267
- # Outputs
268
- outputs:
269
- base_dir: outputs
270
-
271
- # Subdirectories
272
- checkpoints: outputs/checkpoints
273
- logs: outputs/logs
274
- plots: outputs/plots
275
- results: outputs/results
276
-
277
- # Visualization
278
- visualization:
279
- save_mask: true
280
- save_overlay: true
281
- save_json: true
282
- overlay_alpha: 0.5
283
- colormap: jet
284
-
285
- # Deployment
286
- deployment:
287
- export_onnx: true
288
- onnx_path: outputs/model.onnx
289
- quantization: false
290
- opset_version: 14
291
-
292
- # Logging
293
- logging:
294
- level: INFO
295
- tensorboard: true
296
- csv: true
297
- console: true
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deployment/models/best_doctamper.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:d049ca9d4dc28c8d01519f8faab1ec131a05de877da9703ee5bb0e9322095ad2
3
- size 14283981
 
 
 
 
deployment/models/classifier/classifier_metadata.json DELETED
@@ -1,821 +0,0 @@
1
- {
2
- "confidence_threshold": 0.6,
3
- "class_names": [
4
- "copy_move",
5
- "splicing",
6
- "text_substitution"
7
- ],
8
- "feature_names": [
9
- "deep_0",
10
- "deep_1",
11
- "deep_2",
12
- "deep_3",
13
- "deep_4",
14
- "deep_5",
15
- "deep_6",
16
- "deep_7",
17
- "deep_8",
18
- "deep_9",
19
- "deep_10",
20
- "deep_11",
21
- "deep_12",
22
- "deep_13",
23
- "deep_14",
24
- "deep_15",
25
- "deep_16",
26
- "deep_17",
27
- "deep_18",
28
- "deep_19",
29
- "deep_20",
30
- "deep_21",
31
- "deep_22",
32
- "deep_23",
33
- "deep_24",
34
- "deep_25",
35
- "deep_26",
36
- "deep_27",
37
- "deep_28",
38
- "deep_29",
39
- "deep_30",
40
- "deep_31",
41
- "deep_32",
42
- "deep_33",
43
- "deep_34",
44
- "deep_35",
45
- "deep_36",
46
- "deep_37",
47
- "deep_38",
48
- "deep_39",
49
- "deep_40",
50
- "deep_41",
51
- "deep_42",
52
- "deep_43",
53
- "deep_44",
54
- "deep_45",
55
- "deep_46",
56
- "deep_47",
57
- "deep_48",
58
- "deep_49",
59
- "deep_50",
60
- "deep_51",
61
- "deep_52",
62
- "deep_53",
63
- "deep_54",
64
- "deep_55",
65
- "deep_56",
66
- "deep_57",
67
- "deep_58",
68
- "deep_59",
69
- "deep_60",
70
- "deep_61",
71
- "deep_62",
72
- "deep_63",
73
- "deep_64",
74
- "deep_65",
75
- "deep_66",
76
- "deep_67",
77
- "deep_68",
78
- "deep_69",
79
- "deep_70",
80
- "deep_71",
81
- "deep_72",
82
- "deep_73",
83
- "deep_74",
84
- "deep_75",
85
- "deep_76",
86
- "deep_77",
87
- "deep_78",
88
- "deep_79",
89
- "deep_80",
90
- "deep_81",
91
- "deep_82",
92
- "deep_83",
93
- "deep_84",
94
- "deep_85",
95
- "deep_86",
96
- "deep_87",
97
- "deep_88",
98
- "deep_89",
99
- "deep_90",
100
- "deep_91",
101
- "deep_92",
102
- "deep_93",
103
- "deep_94",
104
- "deep_95",
105
- "deep_96",
106
- "deep_97",
107
- "deep_98",
108
- "deep_99",
109
- "deep_100",
110
- "deep_101",
111
- "deep_102",
112
- "deep_103",
113
- "deep_104",
114
- "deep_105",
115
- "deep_106",
116
- "deep_107",
117
- "deep_108",
118
- "deep_109",
119
- "deep_110",
120
- "deep_111",
121
- "deep_112",
122
- "deep_113",
123
- "deep_114",
124
- "deep_115",
125
- "deep_116",
126
- "deep_117",
127
- "deep_118",
128
- "deep_119",
129
- "deep_120",
130
- "deep_121",
131
- "deep_122",
132
- "deep_123",
133
- "deep_124",
134
- "deep_125",
135
- "deep_126",
136
- "deep_127",
137
- "deep_128",
138
- "deep_129",
139
- "deep_130",
140
- "deep_131",
141
- "deep_132",
142
- "deep_133",
143
- "deep_134",
144
- "deep_135",
145
- "deep_136",
146
- "deep_137",
147
- "deep_138",
148
- "deep_139",
149
- "deep_140",
150
- "deep_141",
151
- "deep_142",
152
- "deep_143",
153
- "deep_144",
154
- "deep_145",
155
- "deep_146",
156
- "deep_147",
157
- "deep_148",
158
- "deep_149",
159
- "deep_150",
160
- "deep_151",
161
- "deep_152",
162
- "deep_153",
163
- "deep_154",
164
- "deep_155",
165
- "deep_156",
166
- "deep_157",
167
- "deep_158",
168
- "deep_159",
169
- "deep_160",
170
- "deep_161",
171
- "deep_162",
172
- "deep_163",
173
- "deep_164",
174
- "deep_165",
175
- "deep_166",
176
- "deep_167",
177
- "deep_168",
178
- "deep_169",
179
- "deep_170",
180
- "deep_171",
181
- "deep_172",
182
- "deep_173",
183
- "deep_174",
184
- "deep_175",
185
- "deep_176",
186
- "deep_177",
187
- "deep_178",
188
- "deep_179",
189
- "deep_180",
190
- "deep_181",
191
- "deep_182",
192
- "deep_183",
193
- "deep_184",
194
- "deep_185",
195
- "deep_186",
196
- "deep_187",
197
- "deep_188",
198
- "deep_189",
199
- "deep_190",
200
- "deep_191",
201
- "deep_192",
202
- "deep_193",
203
- "deep_194",
204
- "deep_195",
205
- "deep_196",
206
- "deep_197",
207
- "deep_198",
208
- "deep_199",
209
- "deep_200",
210
- "deep_201",
211
- "deep_202",
212
- "deep_203",
213
- "deep_204",
214
- "deep_205",
215
- "deep_206",
216
- "deep_207",
217
- "deep_208",
218
- "deep_209",
219
- "deep_210",
220
- "deep_211",
221
- "deep_212",
222
- "deep_213",
223
- "deep_214",
224
- "deep_215",
225
- "deep_216",
226
- "deep_217",
227
- "deep_218",
228
- "deep_219",
229
- "deep_220",
230
- "deep_221",
231
- "deep_222",
232
- "deep_223",
233
- "deep_224",
234
- "deep_225",
235
- "deep_226",
236
- "deep_227",
237
- "deep_228",
238
- "deep_229",
239
- "deep_230",
240
- "deep_231",
241
- "deep_232",
242
- "deep_233",
243
- "deep_234",
244
- "deep_235",
245
- "deep_236",
246
- "deep_237",
247
- "deep_238",
248
- "deep_239",
249
- "deep_240",
250
- "deep_241",
251
- "deep_242",
252
- "deep_243",
253
- "deep_244",
254
- "deep_245",
255
- "deep_246",
256
- "deep_247",
257
- "deep_248",
258
- "deep_249",
259
- "deep_250",
260
- "deep_251",
261
- "deep_252",
262
- "deep_253",
263
- "deep_254",
264
- "deep_255",
265
- "area",
266
- "perimeter",
267
- "aspect_ratio",
268
- "solidity",
269
- "eccentricity",
270
- "entropy",
271
- "dct_mean",
272
- "dct_std",
273
- "high_freq_energy",
274
- "wavelet_cA",
275
- "wavelet_cH",
276
- "wavelet_cV",
277
- "wavelet_cD",
278
- "wavelet_entropy_H",
279
- "wavelet_entropy_V",
280
- "wavelet_entropy_D",
281
- "ela_mean",
282
- "ela_var",
283
- "ela_max",
284
- "noise_residual_mean",
285
- "noise_residual_var",
286
- "ocr_conf_mean",
287
- "ocr_conf_std",
288
- "spacing_irregularity",
289
- "text_density",
290
- "stroke_mean",
291
- "stroke_std"
292
- ],
293
- "feature_importance": [
294
- 151.5697784423828,
295
- 8.955550193786621,
296
- 32.9064998626709,
297
- 151.0029697418213,
298
- 19.174699783325195,
299
- 157.97871017456055,
300
- 45.12229919433594,
301
- 19.72992992401123,
302
- 105.08611106872559,
303
- 0.0,
304
- 148.97894096374512,
305
- 35.71831035614014,
306
- 50.15155029296875,
307
- 71.74272060394287,
308
- 43.958970069885254,
309
- 129.9348111152649,
310
- 27.99122953414917,
311
- 61.592909812927246,
312
- 295.4245676994324,
313
- 61.00736045837402,
314
- 28.548550128936768,
315
- 0.0,
316
- 54.50248908996582,
317
- 93.74169921875,
318
- 120.9488091468811,
319
- 148.32109832763672,
320
- 30.55735969543457,
321
- 59.058170318603516,
322
- 82.7595911026001,
323
- 49.24997901916504,
324
- 0.0,
325
- 23.502280235290527,
326
- 392.399715423584,
327
- 551.6174192428589,
328
- 0.0,
329
- 50.8812894821167,
330
- 60.7820405960083,
331
- 78.98891925811768,
332
- 0.0,
333
- 9.173580169677734,
334
- 631.6932668685913,
335
- 42.097740173339844,
336
- 305.0536642074585,
337
- 416.94709300994873,
338
- 92.70171976089478,
339
- 66.76712036132812,
340
- 1435.1315097808838,
341
- 0.0,
342
- 126.6096019744873,
343
- 111.61981964111328,
344
- 124.68002033233643,
345
- 46.16030025482178,
346
- 12.660099983215332,
347
- 115.48313999176025,
348
- 86.43069076538086,
349
- 16.674290657043457,
350
- 110.49228954315186,
351
- 0.0,
352
- 98.00746059417725,
353
- 98.95538091659546,
354
- 41.432090759277344,
355
- 11.24590015411377,
356
- 65.1699800491333,
357
- 9.251449584960938,
358
- 100.24416923522949,
359
- 109.5842399597168,
360
- 83.83185005187988,
361
- 196.82151079177856,
362
- 0.0,
363
- 455.4096431732178,
364
- 120.69411087036133,
365
- 23.130990028381348,
366
- 18.21858024597168,
367
- 69.65920066833496,
368
- 82.33455085754395,
369
- 0.0,
370
- 82.21379089355469,
371
- 119.78182220458984,
372
- 65.07565069198608,
373
- 53.62262964248657,
374
- 247.53085803985596,
375
- 144.45191097259521,
376
- 38.63272047042847,
377
- 82.24878883361816,
378
- 60.303489685058594,
379
- 8.717499732971191,
380
- 412.6672077178955,
381
- 54.25755023956299,
382
- 0.0,
383
- 23.141600608825684,
384
- 62.88635063171387,
385
- 144.1060814857483,
386
- 352.47050952911377,
387
- 23.701799392700195,
388
- 180.19217205047607,
389
- 74.43132972717285,
390
- 0.0,
391
- 92.36961936950684,
392
- 418.40467262268066,
393
- 163.96015119552612,
394
- 136.4917197227478,
395
- 8.362039566040039,
396
- 10.378399848937988,
397
- 30.465800285339355,
398
- 47.935009479522705,
399
- 28.957390308380127,
400
- 61.46374034881592,
401
- 11.319199562072754,
402
- 142.72890949249268,
403
- 0.0,
404
- 140.48277807235718,
405
- 59.3709602355957,
406
- 9.517510414123535,
407
- 22.945700645446777,
408
- 85.35987043380737,
409
- 25.964330196380615,
410
- 18.778900146484375,
411
- 79.01968955993652,
412
- 74.93959999084473,
413
- 0.0,
414
- 36.94928026199341,
415
- 47.99788188934326,
416
- 84.99461078643799,
417
- 65.24014949798584,
418
- 128.61994075775146,
419
- 71.96449947357178,
420
- 0.0,
421
- 60.59358024597168,
422
- 0.0,
423
- 144.41107177734375,
424
- 119.25859117507935,
425
- 0.0,
426
- 29.235299110412598,
427
- 75.50409030914307,
428
- 0.0,
429
- 0.0,
430
- 133.30608654022217,
431
- 50.813700675964355,
432
- 7.879730224609375,
433
- 80.23723936080933,
434
- 28.72357988357544,
435
- 85.63543939590454,
436
- 88.70749998092651,
437
- 0.0,
438
- 38.14083003997803,
439
- 10.110199928283691,
440
- 223.45562982559204,
441
- 0.0,
442
- 189.3048586845398,
443
- 11.311699867248535,
444
- 87.91403198242188,
445
- 45.88195037841797,
446
- 57.93142032623291,
447
- 621.7998056411743,
448
- 151.6710205078125,
449
- 55.90662956237793,
450
- 310.18284845352173,
451
- 0.0,
452
- 37.39265060424805,
453
- 142.64961051940918,
454
- 86.32072973251343,
455
- 167.73473930358887,
456
- 135.1251916885376,
457
- 67.87245082855225,
458
- 25.777999877929688,
459
- 82.70090961456299,
460
- 160.77113008499146,
461
- 0.0,
462
- 109.31087112426758,
463
- 36.81955051422119,
464
- 21.341699600219727,
465
- 39.508570194244385,
466
- 0.0,
467
- 12.186599731445312,
468
- 52.13583946228027,
469
- 242.86930990219116,
470
- 0.0,
471
- 27.03380012512207,
472
- 11.51550006866455,
473
- 102.65280055999756,
474
- 8.523859977722168,
475
- 105.87909126281738,
476
- 0.0,
477
- 191.5287847518921,
478
- 16.16029930114746,
479
- 43.0986704826355,
480
- 0.0,
481
- 54.736299991607666,
482
- 145.84991836547852,
483
- 62.068660736083984,
484
- 72.52587032318115,
485
- 81.85652828216553,
486
- 25.7001895904541,
487
- 36.71660041809082,
488
- 78.73716068267822,
489
- 145.95945167541504,
490
- 146.47522068023682,
491
- 23.559300422668457,
492
- 39.53977966308594,
493
- 194.42743015289307,
494
- 66.81133842468262,
495
- 0.0,
496
- 156.6984510421753,
497
- 671.7460441589355,
498
- 38.70531988143921,
499
- 0.0,
500
- 356.6153998374939,
501
- 0.0,
502
- 0.0,
503
- 166.1197419166565,
504
- 0.0,
505
- 73.76784992218018,
506
- 82.50808954238892,
507
- 249.50656414031982,
508
- 21.96009922027588,
509
- 43.69997024536133,
510
- 0.0,
511
- 95.96379089355469,
512
- 80.70125961303711,
513
- 0.0,
514
- 0.0,
515
- 31.88983964920044,
516
- 301.3817310333252,
517
- 0.0,
518
- 15.77073049545288,
519
- 396.3671169281006,
520
- 83.96024990081787,
521
- 265.5281705856323,
522
- 47.332489013671875,
523
- 0.0,
524
- 268.84939098358154,
525
- 58.15328025817871,
526
- 31.172239780426025,
527
- 30.765819549560547,
528
- 10.469799995422363,
529
- 16.379559993743896,
530
- 28.163670539855957,
531
- 199.17678022384644,
532
- 112.94913101196289,
533
- 5.905869960784912,
534
- 719.0067505836487,
535
- 157.29250049591064,
536
- 92.6033205986023,
537
- 73.79398918151855,
538
- 24.25756072998047,
539
- 0.0,
540
- 31.15705966949463,
541
- 50.47894048690796,
542
- 73.0004301071167,
543
- 131.88961124420166,
544
- 0.0,
545
- 44.40921926498413,
546
- 59.08494997024536,
547
- 60.722700119018555,
548
- 108.21477127075195,
549
- 78.56892967224121,
550
- 486.87088108062744,
551
- 235.95975875854492,
552
- 1809.188328742981,
553
- 396.9979257583618,
554
- 441.098051071167,
555
- 218.83313035964966,
556
- 265.3398394584656,
557
- 595.3824620246887,
558
- 6126.337133407593,
559
- 3245.946928501129,
560
- 170.21856021881104,
561
- 262.3172616958618,
562
- 98.2627010345459,
563
- 146.45634078979492,
564
- 135.70992946624756,
565
- 34.09130001068115,
566
- 14156.531812667847,
567
- 227.55861043930054,
568
- 121.6160798072815,
569
- 409.0565061569214,
570
- 282.5465121269226,
571
- 481.5555577278137,
572
- 291.560200214386,
573
- 797.986575126648,
574
- 246.7717628479004,
575
- 6129.707794189453,
576
- 957.9258012771606,
577
- 4484.775461196899,
578
- 5722.659900188446,
579
- 393.6506414413452,
580
- 882.6219139099121,
581
- 264.54289960861206,
582
- 79.82537126541138,
583
- 228.20479917526245,
584
- 155.19043970108032,
585
- 319.6992588043213,
586
- 391.5327887535095,
587
- 2005.5544757843018,
588
- 0.0,
589
- 1028.816568851471,
590
- 577.8704214096069,
591
- 159.98183917999268,
592
- 138.31745052337646,
593
- 115.26242113113403,
594
- 117.50687980651855,
595
- 0.0,
596
- 270.78229904174805,
597
- 300.6347818374634,
598
- 164.85750007629395,
599
- 542.5208883285522,
600
- 10002.710669994354,
601
- 502.5058374404907,
602
- 6619.406281471252,
603
- 194.39686965942383,
604
- 0.0,
605
- 239.30037021636963,
606
- 129.93587112426758,
607
- 149.23295974731445,
608
- 57.12141132354736,
609
- 152.30589962005615,
610
- 590.8979144096375,
611
- 125.51728057861328,
612
- 216.1852297782898,
613
- 4445.603507041931,
614
- 0.0,
615
- 97.60689973831177,
616
- 497.5633420944214,
617
- 699.1335229873657,
618
- 159.68335962295532,
619
- 127.93899154663086,
620
- 148.00423860549927,
621
- 385.3561215400696,
622
- 1255.3204145431519,
623
- 170.33005905151367,
624
- 564.577874660492,
625
- 1513.99400806427,
626
- 254.163161277771,
627
- 782.5869626998901,
628
- 166.38124132156372,
629
- 4800.666547775269,
630
- 271.63431215286255,
631
- 225.10281944274902,
632
- 674.5281610488892,
633
- 198.04610967636108,
634
- 4262.1786432266235,
635
- 0.0,
636
- 0.0,
637
- 749.2932777404785,
638
- 50.16440010070801,
639
- 350.71588039398193,
640
- 169.4644889831543,
641
- 3843.8212938308716,
642
- 0.0,
643
- 0.0,
644
- 1463.2607378959656,
645
- 0.0,
646
- 914.5419778823853,
647
- 213.03434944152832,
648
- 32.90106964111328,
649
- 119.6264705657959,
650
- 137.204270362854,
651
- 359.72862100601196,
652
- 75.62465047836304,
653
- 446.62164974212646,
654
- 105.61136054992676,
655
- 2787.228641986847,
656
- 311.6961917877197,
657
- 156.06305074691772,
658
- 1498.6027584075928,
659
- 185.69973182678223,
660
- 147.8509397506714,
661
- 12.531700134277344,
662
- 0.0,
663
- 192.53613948822021,
664
- 424.5432171821594,
665
- 259.268039226532,
666
- 175.13502979278564,
667
- 281.5383825302124,
668
- 299.1759967803955,
669
- 227.893488407135,
670
- 136.72871112823486,
671
- 416.3120012283325,
672
- 115.03175830841064,
673
- 0.0,
674
- 144.02852058410645,
675
- 208.2749309539795,
676
- 160.34006214141846,
677
- 109.58282947540283,
678
- 1500.150812625885,
679
- 4945.450592041016,
680
- 2852.855231285095,
681
- 881.7318058013916,
682
- 397.0553340911865,
683
- 315.55763959884644,
684
- 2086.7152404785156,
685
- 1611.37087059021,
686
- 2103.3109679222107,
687
- 3135.3377957344055,
688
- 2692.6771001815796,
689
- 4584.85631608963,
690
- 1700.0699429512024,
691
- 883.6995916366577,
692
- 33464.33708667755,
693
- 574.8801603317261,
694
- 2229.160650253296,
695
- 379.5017247200012,
696
- 905.5721397399902,
697
- 493.963942527771,
698
- 4049.96994638443,
699
- 189.95257091522217,
700
- 61.00449848175049,
701
- 450.8264832496643,
702
- 398.1711621284485,
703
- 38847.667073726654,
704
- 1835.184115409851,
705
- 2697.096595287323,
706
- 4710.6771783828735,
707
- 5588.210665225983,
708
- 1004.0054593086243,
709
- 652.6680641174316,
710
- 2031.7795896530151,
711
- 367.2168278694153,
712
- 2698.1613121032715,
713
- 591.61465883255,
714
- 448.26813650131226,
715
- 849.9976563453674,
716
- 8368.735646724701,
717
- 414.3280692100525,
718
- 3544.0216879844666,
719
- 679.3534464836121,
720
- 247.58060026168823,
721
- 402.0281286239624,
722
- 5822.276999950409,
723
- 1743.6888279914856,
724
- 2081.8095812797546,
725
- 1696.2736263275146,
726
- 197.28233861923218,
727
- 3321.6009736061096,
728
- 2298.3414697647095,
729
- 2910.3161034584045,
730
- 296.4575996398926,
731
- 14755.747835159302,
732
- 6977.302089691162,
733
- 3608.7710394859314,
734
- 289.08115005493164,
735
- 2645.5259099006653,
736
- 158.54701232910156,
737
- 490.0809507369995,
738
- 1880.1874709129333,
739
- 1493.8953075408936,
740
- 609.5897555351257,
741
- 462.8165135383606,
742
- 243.31624794006348,
743
- 150.1076784133911,
744
- 6197.5719475746155,
745
- 1036.8616194725037,
746
- 5302.397746086121,
747
- 1388.753752708435,
748
- 2091.038170814514,
749
- 785.7442808151245,
750
- 377.4342908859253,
751
- 3640.3371028900146,
752
- 1029.8467602729797,
753
- 296.86861085891724,
754
- 1221.5854263305664,
755
- 535.2803363800049,
756
- 2508.307864189148,
757
- 3831.0581674575806,
758
- 2263.3348484039307,
759
- 926.5323433876038,
760
- 8959.179275035858,
761
- 309.04264068603516,
762
- 1767.5786666870117,
763
- 2107.6189522743225,
764
- 155.21375036239624,
765
- 378.6039876937866,
766
- 2220.862048149109,
767
- 1505.2828221321106,
768
- 517.8384418487549,
769
- 4313.928272247314,
770
- 342.4098491668701,
771
- 1310.0776271820068,
772
- 434.5597867965698,
773
- 2071.2271361351013,
774
- 0.0,
775
- 8595.476936340332,
776
- 202.46072053909302,
777
- 366.71736097335815,
778
- 7074.809521198273,
779
- 6.880340099334717,
780
- 1959.3085498809814,
781
- 636.0715098381042,
782
- 9.84004020690918,
783
- 386.9805417060852,
784
- 2382.4822087287903,
785
- 2317.9521684646606,
786
- 2793.7392020225525,
787
- 1188.6612939834595,
788
- 933.1099715232849,
789
- 4565.712460041046,
790
- 14641.29742860794,
791
- 15552.311092853546,
792
- 56185.89445209503,
793
- 97331.36661911011,
794
- 87548.01149320602,
795
- 521853.7248663902,
796
- 2643.261353492737,
797
- 20220.717566013336,
798
- 79148.93348503113,
799
- 17449.243332386017,
800
- 13258.27445936203,
801
- 6109.533164024353,
802
- 6781.56981420517,
803
- 3942.6140484809875,
804
- 8469.07410955429,
805
- 40318.94767665863,
806
- 156345.23027658463,
807
- 12197.998657226562,
808
- 22888.345291614532,
809
- 10946.28234910965,
810
- 204263.674387455,
811
- 229631.36437797546,
812
- 1945.9702520370483,
813
- 3069.6773653030396,
814
- 6425.405041217804,
815
- 508.55564069747925,
816
- 8993.14672756195,
817
- 0.0,
818
- 0.0,
819
- 0.0
820
- ]
821
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deployment/models/classifier/lightgbm_model.txt DELETED
The diff for this file is too large to render. See raw diff
 
deployment/models/classifier/scaler.joblib DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:347b85c4f3e4bcbda0599f607a1ad5194c01655baca73b6e2ee72a9ba50dcf84
3
- size 13207
 
 
 
 
deployment/src/__init__.py DELETED
@@ -1,32 +0,0 @@
1
- """
2
- Hybrid Document Forgery Detection & Localization System
3
-
4
- A robust hybrid (Deep Learning + Classical ML) system for multi-type
5
- document forgery detection and localization.
6
-
7
- Architecture:
8
- - Deep Learning: MobileNetV3-Small + UNet-Lite for pixel-level localization
9
- - Classical ML: LightGBM for interpretable forgery classification
10
- """
11
-
12
- __version__ = "1.0.0"
13
-
14
- from .config import get_config
15
- from .models import get_model, get_loss_function
16
- from .data import get_dataset
17
- from .features import get_feature_extractor, get_mask_refiner, get_region_extractor
18
- from .training import get_trainer, get_metrics_tracker
19
- from .inference import get_pipeline
20
-
21
- __all__ = [
22
- 'get_config',
23
- 'get_model',
24
- 'get_loss_function',
25
- 'get_dataset',
26
- 'get_feature_extractor',
27
- 'get_mask_refiner',
28
- 'get_region_extractor',
29
- 'get_trainer',
30
- 'get_metrics_tracker',
31
- 'get_pipeline'
32
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deployment/src/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (993 Bytes)
 
deployment/src/config/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- """Configuration module"""
2
-
3
- from .config_loader import Config, get_config
4
-
5
- __all__ = ['Config', 'get_config']
 
 
 
 
 
 
deployment/src/config/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (291 Bytes)
 
deployment/src/config/__pycache__/config_loader.cpython-312.pyc DELETED
Binary file (5.42 kB)
 
deployment/src/config/config_loader.py DELETED
@@ -1,117 +0,0 @@
1
- """
2
- Configuration loader for Hybrid Document Forgery Detection System
3
- """
4
-
5
- import yaml
6
- from pathlib import Path
7
- from typing import Dict, Any
8
-
9
-
10
- class Config:
11
- """Configuration manager"""
12
-
13
- def __init__(self, config_path: str = "config.yaml"):
14
- """
15
- Load configuration from YAML file
16
-
17
- Args:
18
- config_path: Path to configuration file
19
- """
20
- self.config_path = Path(config_path)
21
- self.config = self._load_config()
22
-
23
- def _load_config(self) -> Dict[str, Any]:
24
- """Load YAML configuration"""
25
- if not self.config_path.exists():
26
- raise FileNotFoundError(f"Config file not found: {self.config_path}")
27
-
28
- with open(self.config_path, 'r') as f:
29
- config = yaml.safe_load(f)
30
-
31
- return config
32
-
33
- def get(self, key: str, default: Any = None) -> Any:
34
- """
35
- Get configuration value using dot notation
36
-
37
- Args:
38
- key: Configuration key (e.g., 'model.encoder.name')
39
- default: Default value if key not found
40
-
41
- Returns:
42
- Configuration value
43
- """
44
- keys = key.split('.')
45
- value = self.config
46
-
47
- for k in keys:
48
- if isinstance(value, dict) and k in value:
49
- value = value[k]
50
- else:
51
- return default
52
-
53
- return value
54
-
55
- def get_dataset_config(self, dataset_name: str) -> Dict[str, Any]:
56
- """
57
- Get dataset-specific configuration
58
-
59
- Args:
60
- dataset_name: Dataset name (doctamper, rtm, casia, receipts)
61
-
62
- Returns:
63
- Dataset configuration dictionary
64
- """
65
- return self.config['data']['datasets'].get(dataset_name, {})
66
-
67
- def has_pixel_mask(self, dataset_name: str) -> bool:
68
- """Check if dataset has pixel-level masks"""
69
- dataset_config = self.get_dataset_config(dataset_name)
70
- return dataset_config.get('has_pixel_mask', False)
71
-
72
- def should_skip_deskew(self, dataset_name: str) -> bool:
73
- """Check if deskewing should be skipped for dataset"""
74
- dataset_config = self.get_dataset_config(dataset_name)
75
- return dataset_config.get('skip_deskew', False)
76
-
77
- def should_skip_denoising(self, dataset_name: str) -> bool:
78
- """Check if denoising should be skipped for dataset"""
79
- dataset_config = self.get_dataset_config(dataset_name)
80
- return dataset_config.get('skip_denoising', False)
81
-
82
- def get_min_region_area(self, dataset_name: str) -> float:
83
- """Get minimum region area threshold for dataset"""
84
- dataset_config = self.get_dataset_config(dataset_name)
85
- return dataset_config.get('min_region_area', 0.001)
86
-
87
- def should_compute_localization_metrics(self, dataset_name: str) -> bool:
88
- """Check if localization metrics should be computed for dataset"""
89
- compute_config = self.config['metrics'].get('compute_localization', {})
90
- return compute_config.get(dataset_name, False)
91
-
92
- def __getitem__(self, key: str) -> Any:
93
- """Allow dictionary-style access"""
94
- return self.get(key)
95
-
96
- def __repr__(self) -> str:
97
- return f"Config(path={self.config_path})"
98
-
99
-
100
- # Global config instance
101
- _config = None
102
-
103
-
104
- def get_config(config_path: str = "config.yaml") -> Config:
105
- """
106
- Get global configuration instance
107
-
108
- Args:
109
- config_path: Path to configuration file
110
-
111
- Returns:
112
- Config instance
113
- """
114
- global _config
115
- if _config is None:
116
- _config = Config(config_path)
117
- return _config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deployment/src/data/__init__.py DELETED
@@ -1,23 +0,0 @@
1
- """Data module"""
2
-
3
- from .preprocessing import DocumentPreprocessor, preprocess_image
4
- from .augmentation import DatasetAwareAugmentation, get_augmentation
5
- from .datasets import (
6
- DocTamperDataset,
7
- RTMDataset,
8
- CASIADataset,
9
- ReceiptsDataset,
10
- get_dataset
11
- )
12
-
13
- __all__ = [
14
- 'DocumentPreprocessor',
15
- 'preprocess_image',
16
- 'DatasetAwareAugmentation',
17
- 'get_augmentation',
18
- 'DocTamperDataset',
19
- 'RTMDataset',
20
- 'CASIADataset',
21
- 'ReceiptsDataset',
22
- 'get_dataset'
23
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deployment/src/data/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (572 Bytes)
 
deployment/src/data/__pycache__/augmentation.cpython-312.pyc DELETED
Binary file (5.94 kB)
 
deployment/src/data/__pycache__/datasets.cpython-312.pyc DELETED
Binary file (21.2 kB)
 
deployment/src/data/__pycache__/preprocessing.cpython-312.pyc DELETED
Binary file (9.38 kB)
 
deployment/src/data/augmentation.py DELETED
@@ -1,150 +0,0 @@
1
- """
2
- Dataset-aware augmentation for training
3
- """
4
-
5
- import cv2
6
- import numpy as np
7
- import albumentations as A
8
- from albumentations.pytorch import ToTensorV2
9
- from typing import Dict, Any, Optional
10
-
11
-
12
- class DatasetAwareAugmentation:
13
- """Dataset-aware augmentation pipeline"""
14
-
15
- def __init__(self, config, dataset_name: str, is_training: bool = True):
16
- """
17
- Initialize augmentation pipeline
18
-
19
- Args:
20
- config: Configuration object
21
- dataset_name: Dataset name
22
- is_training: Whether in training mode
23
- """
24
- self.config = config
25
- self.dataset_name = dataset_name
26
- self.is_training = is_training
27
-
28
- # Build augmentation pipeline
29
- self.transform = self._build_transform()
30
-
31
- def _build_transform(self) -> A.Compose:
32
- """Build albumentations transform pipeline"""
33
-
34
- transforms = []
35
-
36
- if self.is_training and self.config.get('augmentation.enabled', True):
37
- # Common augmentations
38
- common_augs = self.config.get('augmentation.common', [])
39
-
40
- for aug_config in common_augs:
41
- aug_type = aug_config.get('type')
42
- prob = aug_config.get('prob', 0.5)
43
-
44
- if aug_type == 'noise':
45
- transforms.append(
46
- A.GaussNoise(var_limit=(10.0, 50.0), p=prob)
47
- )
48
-
49
- elif aug_type == 'motion_blur':
50
- transforms.append(
51
- A.MotionBlur(blur_limit=7, p=prob)
52
- )
53
-
54
- elif aug_type == 'jpeg_compression':
55
- quality_range = aug_config.get('quality', [60, 95])
56
- transforms.append(
57
- A.ImageCompression(quality_lower=quality_range[0],
58
- quality_upper=quality_range[1],
59
- p=prob)
60
- )
61
-
62
- elif aug_type == 'lighting':
63
- transforms.append(
64
- A.OneOf([
65
- A.RandomBrightnessContrast(p=1.0),
66
- A.RandomGamma(p=1.0),
67
- A.HueSaturationValue(p=1.0),
68
- ], p=prob)
69
- )
70
-
71
- elif aug_type == 'perspective':
72
- transforms.append(
73
- A.Perspective(scale=(0.02, 0.05), p=prob)
74
- )
75
-
76
- # Dataset-specific augmentations
77
- if self.dataset_name == 'receipts':
78
- receipt_augs = self.config.get('augmentation.receipts', [])
79
-
80
- for aug_config in receipt_augs:
81
- aug_type = aug_config.get('type')
82
- prob = aug_config.get('prob', 0.5)
83
-
84
- if aug_type == 'stain':
85
- # Simulate stains using random blobs
86
- transforms.append(
87
- A.RandomShadow(
88
- shadow_roi=(0, 0, 1, 1),
89
- num_shadows_lower=1,
90
- num_shadows_upper=3,
91
- shadow_dimension=5,
92
- p=prob
93
- )
94
- )
95
-
96
- elif aug_type == 'fold':
97
- # Simulate folds using grid distortion
98
- transforms.append(
99
- A.GridDistortion(num_steps=5, distort_limit=0.1, p=prob)
100
- )
101
-
102
- # Always convert to tensor
103
- transforms.append(ToTensorV2())
104
-
105
- return A.Compose(
106
- transforms,
107
- additional_targets={'mask': 'mask'}
108
- )
109
-
110
- def __call__(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Dict[str, Any]:
111
- """
112
- Apply augmentation
113
-
114
- Args:
115
- image: Input image (H, W, 3), float32, [0, 1]
116
- mask: Optional mask (H, W), uint8, {0, 1}
117
-
118
- Returns:
119
- Dictionary with 'image' and optionally 'mask'
120
- """
121
- # Convert to uint8 for albumentations
122
- image_uint8 = (image * 255).astype(np.uint8)
123
-
124
- if mask is not None:
125
- mask_uint8 = (mask * 255).astype(np.uint8)
126
- augmented = self.transform(image=image_uint8, mask=mask_uint8)
127
-
128
- # Convert back to float32
129
- augmented['image'] = augmented['image'].float() / 255.0
130
- augmented['mask'] = (augmented['mask'].float() / 255.0).unsqueeze(0)
131
- else:
132
- augmented = self.transform(image=image_uint8)
133
- augmented['image'] = augmented['image'].float() / 255.0
134
-
135
- return augmented
136
-
137
-
138
- def get_augmentation(config, dataset_name: str, is_training: bool = True) -> DatasetAwareAugmentation:
139
- """
140
- Get augmentation pipeline
141
-
142
- Args:
143
- config: Configuration object
144
- dataset_name: Dataset name
145
- is_training: Whether in training mode
146
-
147
- Returns:
148
- Augmentation pipeline
149
- """
150
- return DatasetAwareAugmentation(config, dataset_name, is_training)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deployment/src/data/datasets.py DELETED
@@ -1,541 +0,0 @@
1
- """
2
- Dataset loaders for document forgery detection
3
- Implements Critical Fix #7: Image-level train/test splits
4
- """
5
-
6
- import os
7
- import lmdb
8
- import cv2
9
- import numpy as np
10
- import torch
11
- from torch.utils.data import Dataset
12
- from pathlib import Path
13
- from typing import Tuple, Optional, List
14
- import json
15
- from PIL import Image
16
-
17
- from .preprocessing import DocumentPreprocessor
18
- from .augmentation import DatasetAwareAugmentation
19
-
20
-
21
- class DocTamperDataset(Dataset):
22
- """
23
- DocTamper dataset loader (LMDB-based)
24
- Implements chunked loading for RAM constraints
25
- Uses lazy LMDB initialization for multiprocessing compatibility
26
- """
27
-
28
- def __init__(self,
29
- config,
30
- split: str = 'train',
31
- chunk_start: float = 0.0,
32
- chunk_end: float = 1.0):
33
- """
34
- Initialize DocTamper dataset
35
-
36
- Args:
37
- config: Configuration object
38
- split: 'train' or 'val'
39
- chunk_start: Start ratio for chunked training (0.0 to 1.0)
40
- chunk_end: End ratio for chunked training (0.0 to 1.0)
41
- """
42
- self.config = config
43
- self.split = split
44
- self.dataset_name = 'doctamper'
45
-
46
- # Get dataset path
47
- dataset_config = config.get_dataset_config(self.dataset_name)
48
- self.data_path = Path(dataset_config['path'])
49
-
50
- # Map split to actual folder names
51
- if split == 'train':
52
- lmdb_folder = 'DocTamperV1-TrainingSet'
53
- elif split == 'val' or split == 'test':
54
- lmdb_folder = 'DocTamperV1-TestingSet'
55
- else:
56
- lmdb_folder = 'DocTamperV1-TrainingSet'
57
-
58
- self.lmdb_path = str(self.data_path / lmdb_folder)
59
-
60
- if not Path(self.lmdb_path).exists():
61
- raise FileNotFoundError(f"LMDB folder not found: {self.lmdb_path}")
62
-
63
- # LAZY INITIALIZATION: Don't open LMDB here (pickle issue with multiprocessing)
64
- # Just get the count by temporarily opening
65
- temp_env = lmdb.open(self.lmdb_path, readonly=True, lock=False)
66
- with temp_env.begin() as txn:
67
- stat = txn.stat()
68
- self.length = stat['entries'] // 2
69
- temp_env.close()
70
-
71
- # LMDB env will be opened lazily in __getitem__
72
- self._env = None
73
-
74
- # Critical Fix #7: Image-level chunking (not region-level)
75
- self.chunk_start = int(self.length * chunk_start)
76
- self.chunk_end = int(self.length * chunk_end)
77
- self.chunk_length = self.chunk_end - self.chunk_start
78
-
79
- print(f"DocTamper {split}: Total={self.length}, "
80
- f"Chunk=[{self.chunk_start}:{self.chunk_end}], "
81
- f"Length={self.chunk_length}")
82
-
83
- # Initialize preprocessor and augmentation
84
- self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
85
- self.augmentation = DatasetAwareAugmentation(
86
- config,
87
- self.dataset_name,
88
- is_training=(split == 'train')
89
- )
90
-
91
- @property
92
- def env(self):
93
- """Lazy LMDB environment initialization for multiprocessing compatibility"""
94
- if self._env is None:
95
- self._env = lmdb.open(self.lmdb_path, readonly=True, lock=False,
96
- max_readers=32, readahead=False)
97
- return self._env
98
-
99
- def __len__(self) -> int:
100
- return self.chunk_length
101
-
102
- def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, dict]:
103
- """
104
- Get item from dataset
105
-
106
- Args:
107
- idx: Index within chunk
108
-
109
- Returns:
110
- image: (3, H, W) tensor
111
- mask: (1, H, W) tensor
112
- metadata: Dictionary with additional info
113
- """
114
- # Try to get the requested sample, skip to next if missing
115
- max_attempts = 10
116
- original_idx = idx
117
-
118
- for attempt in range(max_attempts):
119
- try:
120
- # Map chunk index to global index
121
- global_idx = self.chunk_start + idx
122
-
123
- # Read from LMDB
124
- with self.env.begin() as txn:
125
- # DocTamper format: image-XXXXXXXXX, label-XXXXXXXXX (9 digits, dash separator)
126
- img_key = f'image-{global_idx:09d}'.encode()
127
- label_key = f'label-{global_idx:09d}'.encode()
128
-
129
- img_buf = txn.get(img_key)
130
- label_buf = txn.get(label_key)
131
-
132
- if img_buf is None:
133
- # Sample missing, try next index
134
- idx = (idx + 1) % self.chunk_length
135
- continue
136
-
137
- # Decode image
138
- img_array = np.frombuffer(img_buf, dtype=np.uint8)
139
- image = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
140
-
141
- if image is None:
142
- # Failed to decode, try next index
143
- idx = (idx + 1) % self.chunk_length
144
- continue
145
-
146
- # Decode label/mask
147
- if label_buf is not None:
148
- label_array = np.frombuffer(label_buf, dtype=np.uint8)
149
- mask = cv2.imdecode(label_array, cv2.IMREAD_GRAYSCALE)
150
- if mask is None:
151
- # Label might be raw bytes, create empty mask
152
- mask = np.zeros(image.shape[:2], dtype=np.uint8)
153
- else:
154
- # No mask found - create empty mask
155
- mask = np.zeros(image.shape[:2], dtype=np.uint8)
156
-
157
- # Successfully loaded - break out of retry loop
158
- break
159
-
160
- except Exception as e:
161
- # Something went wrong, try next index
162
- idx = (idx + 1) % self.chunk_length
163
- if attempt == max_attempts - 1:
164
- # Last attempt failed, create a dummy sample
165
- print(f"Warning: Could not load sample at idx {original_idx}, creating dummy sample")
166
- image = np.zeros((384, 384, 3), dtype=np.float32)
167
- mask = np.zeros((384, 384), dtype=np.uint8)
168
- global_idx = original_idx
169
-
170
- # Preprocess
171
- image, mask = self.preprocessor(image, mask)
172
-
173
- # Augment
174
- augmented = self.augmentation(image, mask)
175
- image = augmented['image']
176
- mask = augmented['mask']
177
-
178
- # Metadata
179
- metadata = {
180
- 'dataset': self.dataset_name,
181
- 'index': global_idx,
182
- 'has_pixel_mask': True
183
- }
184
-
185
- return image, mask, metadata
186
-
187
- def __del__(self):
188
- """Close LMDB environment"""
189
- if hasattr(self, '_env') and self._env is not None:
190
- self._env.close()
191
-
192
-
193
-
194
- class RTMDataset(Dataset):
195
- """Real Text Manipulation dataset loader"""
196
-
197
- def __init__(self, config, split: str = 'train'):
198
- """
199
- Initialize RTM dataset
200
-
201
- Args:
202
- config: Configuration object
203
- split: 'train' or 'test'
204
- """
205
- self.config = config
206
- self.split = split
207
- self.dataset_name = 'rtm'
208
-
209
- # Get dataset path
210
- dataset_config = config.get_dataset_config(self.dataset_name)
211
- self.data_path = Path(dataset_config['path'])
212
-
213
- # Load split file
214
- split_file = self.data_path / f'{split}.txt'
215
- with open(split_file, 'r') as f:
216
- self.image_ids = [line.strip() for line in f.readlines()]
217
-
218
- self.images_dir = self.data_path / 'JPEGImages'
219
- self.masks_dir = self.data_path / 'SegmentationClass'
220
-
221
- print(f"RTM {split}: {len(self.image_ids)} images")
222
-
223
- # Initialize preprocessor and augmentation
224
- self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
225
- self.augmentation = DatasetAwareAugmentation(
226
- config,
227
- self.dataset_name,
228
- is_training=(split == 'train')
229
- )
230
-
231
- def __len__(self) -> int:
232
- return len(self.image_ids)
233
-
234
- def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, dict]:
235
- """Get item from dataset"""
236
- image_id = self.image_ids[idx]
237
-
238
- # Load image
239
- img_path = self.images_dir / f'{image_id}.jpg'
240
- image = cv2.imread(str(img_path))
241
-
242
- # Load mask
243
- mask_path = self.masks_dir / f'{image_id}.png'
244
- mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
245
-
246
- # Binarize mask
247
- mask = (mask > 0).astype(np.uint8)
248
-
249
- # Preprocess
250
- image, mask = self.preprocessor(image, mask)
251
-
252
- # Augment
253
- augmented = self.augmentation(image, mask)
254
- image = augmented['image']
255
- mask = augmented['mask']
256
-
257
- # Metadata
258
- metadata = {
259
- 'dataset': self.dataset_name,
260
- 'image_id': image_id,
261
- 'has_pixel_mask': True
262
- }
263
-
264
- return image, mask, metadata
265
-
266
-
267
- class CASIADataset(Dataset):
268
- """
269
- CASIA v1.0 dataset loader
270
- Image-level labels only (no pixel masks)
271
- Implements Critical Fix #6: CASIA image-level handling
272
- """
273
-
274
- def __init__(self, config, split: str = 'train'):
275
- """
276
- Initialize CASIA dataset
277
-
278
- Args:
279
- config: Configuration object
280
- split: 'train' or 'test'
281
- """
282
- self.config = config
283
- self.split = split
284
- self.dataset_name = 'casia'
285
-
286
- # Get dataset path
287
- dataset_config = config.get_dataset_config(self.dataset_name)
288
- self.data_path = Path(dataset_config['path'])
289
-
290
- # Load authentic and tampered images
291
- self.authentic_dir = self.data_path / 'Au'
292
- self.tampered_dir = self.data_path / 'Tp'
293
-
294
- # Get all image paths
295
- authentic_images = list(self.authentic_dir.glob('*.jpg')) + \
296
- list(self.authentic_dir.glob('*.png'))
297
- tampered_images = list(self.tampered_dir.glob('*.jpg')) + \
298
- list(self.tampered_dir.glob('*.png'))
299
-
300
- # Create image list with labels
301
- self.samples = []
302
- for img_path in authentic_images:
303
- self.samples.append((img_path, 0)) # 0 = authentic
304
- for img_path in tampered_images:
305
- self.samples.append((img_path, 1)) # 1 = tampered
306
-
307
- # Critical Fix #7: Image-level split (80/20)
308
- np.random.seed(42)
309
- indices = np.random.permutation(len(self.samples))
310
- split_idx = int(len(self.samples) * 0.8)
311
-
312
- if split == 'train':
313
- indices = indices[:split_idx]
314
- else:
315
- indices = indices[split_idx:]
316
-
317
- self.samples = [self.samples[i] for i in indices]
318
-
319
- print(f"CASIA {split}: {len(self.samples)} images")
320
-
321
- # Initialize preprocessor and augmentation
322
- self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
323
- self.augmentation = DatasetAwareAugmentation(
324
- config,
325
- self.dataset_name,
326
- is_training=(split == 'train')
327
- )
328
-
329
- def __len__(self) -> int:
330
- return len(self.samples)
331
-
332
- def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, dict]:
333
- """Get item from dataset"""
334
- img_path, label = self.samples[idx]
335
-
336
- # Load image
337
- image = cv2.imread(str(img_path))
338
-
339
- # Critical Fix #6: Create image-level mask (entire image)
340
- h, w = image.shape[:2]
341
- mask = np.ones((h, w), dtype=np.uint8) * label
342
-
343
- # Preprocess
344
- image, mask = self.preprocessor(image, mask)
345
-
346
- # Augment
347
- augmented = self.augmentation(image, mask)
348
- image = augmented['image']
349
- mask = augmented['mask']
350
-
351
- # Metadata
352
- metadata = {
353
- 'dataset': self.dataset_name,
354
- 'image_path': str(img_path),
355
- 'has_pixel_mask': False, # Image-level only
356
- 'label': label
357
- }
358
-
359
- return image, mask, metadata
360
-
361
-
362
- class ReceiptsDataset(Dataset):
363
- """Find-It-Again receipts dataset loader"""
364
-
365
- def __init__(self, config, split: str = 'train'):
366
- """
367
- Initialize receipts dataset
368
-
369
- Args:
370
- config: Configuration object
371
- split: 'train', 'val', or 'test'
372
- """
373
- self.config = config
374
- self.split = split
375
- self.dataset_name = 'receipts'
376
-
377
- # Get dataset path
378
- dataset_config = config.get_dataset_config(self.dataset_name)
379
- self.data_path = Path(dataset_config['path'])
380
-
381
- # Load split file
382
- split_file = self.data_path / f'{split}.json'
383
- with open(split_file, 'r') as f:
384
- self.annotations = json.load(f)
385
-
386
- print(f"Receipts {split}: {len(self.annotations)} images")
387
-
388
- # Initialize preprocessor and augmentation
389
- self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
390
- self.augmentation = DatasetAwareAugmentation(
391
- config,
392
- self.dataset_name,
393
- is_training=(split == 'train')
394
- )
395
-
396
- def __len__(self) -> int:
397
- return len(self.annotations)
398
-
399
- def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, dict]:
400
- """Get item from dataset"""
401
- ann = self.annotations[idx]
402
-
403
- # Load image
404
- img_path = self.data_path / ann['image_path']
405
- image = cv2.imread(str(img_path))
406
-
407
- # Create mask from bounding boxes
408
- h, w = image.shape[:2]
409
- mask = np.zeros((h, w), dtype=np.uint8)
410
-
411
- for bbox in ann.get('bboxes', []):
412
- x, y, w_box, h_box = bbox
413
- mask[y:y+h_box, x:x+w_box] = 1
414
-
415
- # Preprocess
416
- image, mask = self.preprocessor(image, mask)
417
-
418
- # Augment
419
- augmented = self.augmentation(image, mask)
420
- image = augmented['image']
421
- mask = augmented['mask']
422
-
423
- # Metadata
424
- metadata = {
425
- 'dataset': self.dataset_name,
426
- 'image_path': str(img_path),
427
- 'has_pixel_mask': True
428
- }
429
-
430
- return image, mask, metadata
431
-
432
-
433
class FCDDataset(DocTamperDataset):
    """FCD (Forgery Classification Dataset) loader - inherits from DocTamperDataset"""

    def __init__(self, config, split: str = 'train'):
        """
        Args:
            config: Configuration object exposing get_dataset_config()
            split: Data split ('train', 'val', 'test')

        Raises:
            FileNotFoundError: if the configured LMDB folder does not exist
        """
        self.config = config
        self.split = split
        self.dataset_name = 'fcd'

        # Resolve the LMDB folder for this dataset from config
        dataset_config = config.get_dataset_config(self.dataset_name)
        self.data_path = Path(dataset_config['path'])
        self.lmdb_path = str(self.data_path)

        if not Path(self.lmdb_path).exists():
            raise FileNotFoundError(f"LMDB folder not found: {self.lmdb_path}")

        # Count samples. try/finally guarantees the temporary environment is
        # closed even if stat() raises (the original leaked the handle on
        # that path).
        temp_env = lmdb.open(self.lmdb_path, readonly=True, lock=False)
        try:
            with temp_env.begin() as txn:
                stat = txn.stat()
                self.length = stat['entries'] // 2  # Half are images, half are labels
        finally:
            temp_env.close()

        # Per-worker environment is opened lazily by the parent class
        self._env = None

        # FCD is small, no chunking needed
        self.chunk_start = 0
        self.chunk_end = self.length
        self.chunk_length = self.length

        print(f"FCD {split}: {self.length} samples")

        # Initialize preprocessor and augmentation
        self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
        self.augmentation = DatasetAwareAugmentation(
            config,
            self.dataset_name,
            is_training=(split == 'train')
        )
472
-
473
-
474
class SCDDataset(DocTamperDataset):
    """SCD (Splicing Classification Dataset) loader - inherits from DocTamperDataset"""

    def __init__(self, config, split: str = 'train'):
        """
        Args:
            config: Configuration object exposing get_dataset_config()
            split: Data split ('train', 'val', 'test')

        Raises:
            FileNotFoundError: if the configured LMDB folder does not exist
        """
        self.config = config
        self.split = split
        self.dataset_name = 'scd'

        # Resolve the LMDB folder for this dataset from config
        dataset_config = config.get_dataset_config(self.dataset_name)
        self.data_path = Path(dataset_config['path'])
        self.lmdb_path = str(self.data_path)

        if not Path(self.lmdb_path).exists():
            raise FileNotFoundError(f"LMDB folder not found: {self.lmdb_path}")

        # Count samples. try/finally guarantees the temporary environment is
        # closed even if stat() raises (the original leaked the handle on
        # that path).
        temp_env = lmdb.open(self.lmdb_path, readonly=True, lock=False)
        try:
            with temp_env.begin() as txn:
                stat = txn.stat()
                self.length = stat['entries'] // 2  # Half are images, half are labels
        finally:
            temp_env.close()

        # Per-worker environment is opened lazily by the parent class
        self._env = None

        # SCD is medium-sized, no chunking needed
        self.chunk_start = 0
        self.chunk_end = self.length
        self.chunk_length = self.length

        print(f"SCD {split}: {self.length} samples")

        # Initialize preprocessor and augmentation
        self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
        self.augmentation = DatasetAwareAugmentation(
            config,
            self.dataset_name,
            is_training=(split == 'train')
        )
513
-
514
-
515
def get_dataset(config, dataset_name: str, split: str = 'train', **kwargs) -> Dataset:
    """
    Factory for dataset instances.

    Args:
        config: Configuration object
        dataset_name: One of 'doctamper', 'rtm', 'casia', 'receipts',
            'fcd', 'scd'
        split: Data split
        **kwargs: Extra arguments, forwarded only to DocTamperDataset
            (e.g. chunk_start, chunk_end)

    Returns:
        Dataset instance for the requested dataset

    Raises:
        ValueError: if dataset_name is not recognized
    """
    # DocTamper is the only loader that accepts chunking kwargs
    if dataset_name == 'doctamper':
        return DocTamperDataset(config, split, **kwargs)
    if dataset_name == 'rtm':
        return RTMDataset(config, split)
    if dataset_name == 'casia':
        return CASIADataset(config, split)
    if dataset_name == 'receipts':
        return ReceiptsDataset(config, split)
    if dataset_name == 'fcd':
        return FCDDataset(config, split)
    if dataset_name == 'scd':
        return SCDDataset(config, split)
    raise ValueError(f"Unknown dataset: {dataset_name}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deployment/src/data/preprocessing.py DELETED
@@ -1,226 +0,0 @@
1
- """
2
- Dataset-aware preprocessing for document forgery detection
3
- Implements Critical Fix #1: Dataset-Aware Preprocessing
4
- """
5
-
6
- import cv2
7
- import numpy as np
8
- from typing import Tuple, Optional
9
- import pywt
10
- from scipy import ndimage
11
-
12
-
13
class DocumentPreprocessor:
    """Dataset-aware document preprocessing.

    Pipeline: color conversion to RGB, optional deskew, resize to a square
    target, [0, 1] normalization, and optional conditional denoising.
    Deskew/denoise are skipped per dataset (Critical Fix #1) according to
    config.should_skip_deskew() / config.should_skip_denoising().
    """

    def __init__(self, config, dataset_name: str):
        """
        Initialize preprocessor

        Args:
            config: Configuration object
            dataset_name: Name of dataset (for dataset-aware processing)
        """
        self.config = config
        self.dataset_name = dataset_name
        # Side length of the square output image
        self.image_size = config.get('data.image_size', 384)
        # Denoising runs only when the estimated noise exceeds this value
        self.noise_threshold = config.get('preprocessing.noise_threshold', 15.0)

        # Dataset-aware flags (Critical Fix #1)
        self.skip_deskew = config.should_skip_deskew(dataset_name)
        self.skip_denoising = config.should_skip_denoising(dataset_name)

    def __call__(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Apply preprocessing pipeline

        Args:
            image: Input image (H, W, 3) — 3-channel input is treated as BGR
                (cv2.imread convention); grayscale and BGRA are handled too
            mask: Optional ground truth mask (H, W)

        Returns:
            Preprocessed image (float32, RGB, values in [0, 1]) and mask
        """
        # 1. Convert to RGB (handles gray, BGRA and BGR inputs)
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        elif image.shape[2] == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # 2. Deskew (dataset-aware)
        if not self.skip_deskew:
            image, mask = self._deskew(image, mask)

        # 3. Resize
        image, mask = self._resize(image, mask)

        # 4. Normalize
        image = self._normalize(image)

        # 5. Conditional denoising (dataset-aware): only when the noise
        # estimate exceeds the configured threshold
        if not self.skip_denoising:
            noise_level = self._estimate_noise(image)
            if noise_level > self.noise_threshold:
                image = self._denoise(image)

        return image, mask

    def _deskew(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Deskew document image by the median Hough-line angle.

        Args:
            image: Input RGB image
            mask: Optional mask, rotated with the image (nearest-neighbor,
                zero border) so labels stay binary

        Returns:
            Deskewed image and mask (unchanged if no rotation was applied)
        """
        # Convert to grayscale for angle detection
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

        # Detect edges
        edges = cv2.Canny(gray, 50, 150, apertureSize=3)

        # Detect lines using Hough transform (accumulator threshold 200)
        lines = cv2.HoughLines(edges, 1, np.pi / 180, 200)

        if lines is not None and len(lines) > 0:
            # Calculate dominant angle; lines[:, 0] yields (rho, theta) pairs
            angles = []
            for rho, theta in lines[:, 0]:
                # theta is the line normal's angle; shift by 90° to get the
                # line's deviation from horizontal, in degrees
                angle = (theta * 180 / np.pi) - 90
                angles.append(angle)

            # Use median angle (robust to outlier lines)
            angle = np.median(angles)

            # Only deskew if angle is significant (> 0.5 degrees)
            if abs(angle) > 0.5:
                # Get rotation matrix around the image center
                h, w = image.shape[:2]
                center = (w // 2, h // 2)
                M = cv2.getRotationMatrix2D(center, angle, 1.0)

                # Rotate image; replicate border avoids black wedges
                image = cv2.warpAffine(image, M, (w, h),
                                      flags=cv2.INTER_CUBIC,
                                      borderMode=cv2.BORDER_REPLICATE)

                # Rotate mask if provided (nearest keeps it binary)
                if mask is not None:
                    mask = cv2.warpAffine(mask, M, (w, h),
                                        flags=cv2.INTER_NEAREST,
                                        borderMode=cv2.BORDER_CONSTANT,
                                        borderValue=0)

        return image, mask

    def _resize(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Resize image and mask to the square target size.

        Args:
            image: Input image
            mask: Optional mask (nearest-neighbor, preserves binary values)

        Returns:
            Resized image and mask
        """
        target_size = (self.image_size, self.image_size)

        # Resize image (cubic interpolation for quality)
        image = cv2.resize(image, target_size, interpolation=cv2.INTER_CUBIC)

        # Resize mask if provided
        if mask is not None:
            mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)

        return image, mask

    def _normalize(self, image: np.ndarray) -> np.ndarray:
        """
        Normalize pixel values to [0, 1]

        Args:
            image: Input image (uint8 range assumed)

        Returns:
            Normalized float32 image
        """
        return image.astype(np.float32) / 255.0

    def _estimate_noise(self, image: np.ndarray) -> float:
        """
        Estimate noise level using Laplacian variance and wavelet-based estimation

        NOTE(review): the two estimators live on very different numeric
        scales (Laplacian variance vs. a wavelet sigma); averaging them is a
        heuristic that effectively makes the Laplacian term dominate —
        confirm the intended weighting against self.noise_threshold.

        Args:
            image: Input image (normalized to [0, 1])

        Returns:
            Estimated noise level (mean of the two estimates)
        """
        # Convert to grayscale for noise estimation
        if len(image.shape) == 3:
            gray = cv2.cvtColor((image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
        else:
            gray = (image * 255).astype(np.uint8)

        # Method 1: Laplacian variance (edge/noise energy)
        laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var()

        # Method 2: Wavelet-based noise estimation — robust sigma via
        # median absolute deviation of the diagonal detail coefficients
        coeffs = pywt.dwt2(gray, 'db1')
        _, (cH, cV, cD) = coeffs
        sigma = np.median(np.abs(cD)) / 0.6745

        # Combine both estimates
        noise_level = (laplacian_var + sigma) / 2.0

        return noise_level

    def _denoise(self, image: np.ndarray) -> np.ndarray:
        """
        Apply conditional denoising: 3x3 median filter followed by a mild
        Gaussian blur.

        Args:
            image: Input image (normalized to [0, 1])

        Returns:
            Denoised image (float32, [0, 1])
        """
        # Convert to uint8 for filtering
        image_uint8 = (image * 255).astype(np.uint8)

        # Apply median filter (3x3) — removes salt-and-pepper noise
        median_filtered = cv2.medianBlur(image_uint8, 3)

        # Apply Gaussian filter (σ ≤ 0.8) — mild smoothing only
        gaussian_filtered = cv2.GaussianBlur(median_filtered, (3, 3), 0.8)

        # Convert back to float32
        denoised = gaussian_filtered.astype(np.float32) / 255.0

        return denoised
207
-
208
-
209
def preprocess_image(image: np.ndarray,
                    mask: Optional[np.ndarray] = None,
                    config = None,
                    dataset_name: str = 'default') -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    One-shot preprocessing helper.

    Builds a DocumentPreprocessor for *dataset_name* and applies it once to
    the given image (and mask, when provided).

    Args:
        image: Input image
        mask: Optional ground-truth mask
        config: Configuration object
        dataset_name: Dataset name used for dataset-aware behavior

    Returns:
        Preprocessed image and mask
    """
    return DocumentPreprocessor(config, dataset_name)(image, mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deployment/src/features/__init__.py DELETED
@@ -1,32 +0,0 @@
1
- """Features module"""
2
-
3
- from .feature_extraction import (
4
- DeepFeatureExtractor,
5
- StatisticalFeatureExtractor,
6
- FrequencyFeatureExtractor,
7
- NoiseELAFeatureExtractor,
8
- OCRFeatureExtractor,
9
- HybridFeatureExtractor,
10
- get_feature_extractor
11
- )
12
-
13
- from .region_extraction import (
14
- MaskRefiner,
15
- RegionExtractor,
16
- get_mask_refiner,
17
- get_region_extractor
18
- )
19
-
20
- __all__ = [
21
- 'DeepFeatureExtractor',
22
- 'StatisticalFeatureExtractor',
23
- 'FrequencyFeatureExtractor',
24
- 'NoiseELAFeatureExtractor',
25
- 'OCRFeatureExtractor',
26
- 'HybridFeatureExtractor',
27
- 'get_feature_extractor',
28
- 'MaskRefiner',
29
- 'RegionExtractor',
30
- 'get_mask_refiner',
31
- 'get_region_extractor'
32
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deployment/src/features/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (691 Bytes)
 
deployment/src/features/__pycache__/feature_extraction.cpython-312.pyc DELETED
Binary file (22.6 kB)
 
deployment/src/features/__pycache__/region_extraction.cpython-312.pyc DELETED
Binary file (8.93 kB)
 
deployment/src/features/feature_extraction.py DELETED
@@ -1,485 +0,0 @@
1
- """
2
- Hybrid feature extraction for forgery detection
3
- Implements Critical Fix #5: Feature Group Gating
4
- """
5
-
6
- import cv2
7
- import numpy as np
8
- import torch
9
- import torch.nn.functional as F
10
- from typing import Dict, List, Optional, Tuple
11
- from scipy import ndimage
12
- from scipy.fftpack import dct
13
- import pywt
14
- from skimage.measure import regionprops, label
15
- from skimage.filters import sobel
16
-
17
-
18
class DeepFeatureExtractor:
    """Extract deep features from decoder feature maps"""

    def __init__(self):
        """Initialize deep feature extractor (stateless)"""
        pass

    def extract(self,
                decoder_features: List[torch.Tensor],
                region_mask: np.ndarray) -> np.ndarray:
        """
        Extract deep features using masked Global Average Pooling.

        Args:
            decoder_features: List of decoder feature tensors, each
                (B, C, H, W) or (C, H, W); only the first batch item is used
            region_mask: Binary region mask (H, W)

        Returns:
            1-D float32 vector with one value per channel per feature map
        """
        features = []

        for feat in decoder_features:
            # Move to CPU numpy regardless of the tensor's device
            if isinstance(feat, torch.Tensor):
                feat = feat.detach().cpu().numpy()

            # Drop the batch dimension if present
            if feat.ndim == 4:
                feat = feat[0]

            # Downscale the mask to this feature map's resolution
            # (cv2.resize takes (width, height))
            h, w = feat.shape[1:]
            mask_resized = cv2.resize(region_mask.astype(np.float32), (w, h))
            mask_resized = mask_resized > 0.5

            if mask_resized.sum() > 0:
                # Masked GAP, vectorized over channels: feat[:, mask] gathers
                # the in-region pixels of every channel at once (same values
                # as the previous per-channel Python loop, much faster)
                features.extend(feat[:, mask_resized].mean(axis=1).tolist())
            else:
                # Region vanished at this resolution: fall back to global GAP
                features.extend(feat.mean(axis=(1, 2)).tolist())

        return np.array(features, dtype=np.float32)
65
-
66
-
67
class StatisticalFeatureExtractor:
    """Extract statistical and shape features from regions"""

    def __init__(self):
        """Initialize statistical feature extractor (stateless)"""
        pass

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray) -> np.ndarray:
        """
        Extract statistical and shape features

        Args:
            image: Input image (H, W, 3) or (H, W) normalized [0, 1]
            region_mask: Binary region mask (H, W)

        Returns:
            6-element float32 vector:
            [area, perimeter, aspect_ratio, solidity, eccentricity, entropy]
        """
        features = []

        # Connected-component analysis; only the first component's shape
        # properties are used (masks are expected to hold a single region)
        labeled_mask = label(region_mask)
        props = regionprops(labeled_mask)

        if len(props) > 0:
            prop = props[0]

            # Area and perimeter
            features.append(prop.area)
            features.append(prop.perimeter)

            # Aspect ratio (minor/major axis); degenerate regions get 1.0
            if prop.major_axis_length > 0:
                aspect_ratio = prop.minor_axis_length / prop.major_axis_length
            else:
                aspect_ratio = 1.0
            features.append(aspect_ratio)

            # Solidity (area / convex-hull area)
            features.append(prop.solidity)

            # Eccentricity
            features.append(prop.eccentricity)

            # Shannon entropy of the intensity histogram inside the region
            if len(image.shape) == 3:
                gray = cv2.cvtColor((image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
            else:
                gray = (image * 255).astype(np.uint8)

            region_pixels = gray[region_mask > 0]
            if len(region_pixels) > 0:
                hist, _ = np.histogram(region_pixels, bins=256, range=(0, 256))
                # Normalize to a probability distribution. The epsilon belongs
                # in the denominator; the original `hist / hist.sum() + 1e-8`
                # added it to every bin instead due to operator precedence.
                hist = hist / (hist.sum() + 1e-8)
                entropy = -np.sum(hist * np.log2(hist + 1e-8))
            else:
                entropy = 0.0
            features.append(entropy)
        else:
            # No region found: neutral defaults matching the 6 features above
            features.extend([0, 0, 1.0, 0, 0, 0])

        return np.array(features, dtype=np.float32)
132
-
133
-
134
class FrequencyFeatureExtractor:
    """Extract frequency-domain features"""

    def __init__(self):
        """Initialize frequency feature extractor (stateless)"""
        pass

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray) -> np.ndarray:
        """
        Extract frequency-domain features (DCT, wavelet)

        Args:
            image: Input image (H, W, 3) or (H, W) normalized [0, 1]
            region_mask: Binary region mask (H, W)

        Returns:
            10-element float32 vector:
            [dct_mean, dct_std, high_freq_energy,
             wavelet_cA, wavelet_cH, wavelet_cV, wavelet_cD,
             wavelet_entropy_H, wavelet_entropy_V, wavelet_entropy_D]
        """
        features = []

        # Convert to grayscale
        if len(image.shape) == 3:
            gray = cv2.cvtColor((image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
        else:
            gray = (image * 255).astype(np.uint8)

        # Bounding box of the region; an empty mask yields the zero vector.
        # The fallback length must match the 10 features produced below —
        # the original returned np.zeros(13) here, which broke feature-vector
        # alignment against the 3+4+3 features of the populated path.
        coords = np.where(region_mask > 0)
        if len(coords[0]) == 0:
            return np.zeros(10, dtype=np.float32)

        y_min, y_max = coords[0].min(), coords[0].max()
        x_min, x_max = coords[1].min(), coords[1].max()

        # Crop region
        region = gray[y_min:y_max+1, x_min:x_max+1].astype(np.float32)

        if region.size == 0:
            return np.zeros(10, dtype=np.float32)

        # --- DCT features (3) ---
        try:
            # 2-D DCT via separable 1-D transforms
            dct_coeffs = dct(dct(region, axis=0, norm='ortho'), axis=1, norm='ortho')

            # Mean and std of DCT coefficients
            features.append(np.mean(np.abs(dct_coeffs)))
            features.append(np.std(dct_coeffs))

            # High-frequency energy (bottom-right quadrant of the spectrum)
            h, w = dct_coeffs.shape
            high_freq = dct_coeffs[h//2:, w//2:]
            features.append(np.sum(np.abs(high_freq)) / (high_freq.size + 1e-8))
        except Exception:
            features.extend([0, 0, 0])

        # --- Wavelet features (4 energies + 3 entropies) ---
        try:
            coeffs = pywt.dwt2(region, 'db1')
            cA, (cH, cV, cD) = coeffs

            # Energy in each sub-band
            features.append(np.sum(cA ** 2) / (cA.size + 1e-8))
            features.append(np.sum(cH ** 2) / (cH.size + 1e-8))
            features.append(np.sum(cV ** 2) / (cV.size + 1e-8))
            features.append(np.sum(cD ** 2) / (cD.size + 1e-8))

            # Entropy of each detail sub-band
            for coeff in [cH, cV, cD]:
                coeff_flat = np.abs(coeff.flatten())
                if coeff_flat.sum() > 0:
                    coeff_norm = coeff_flat / coeff_flat.sum()
                    entropy = -np.sum(coeff_norm * np.log2(coeff_norm + 1e-8))
                else:
                    entropy = 0.0
                features.append(entropy)
        except Exception:
            features.extend([0, 0, 0, 0, 0, 0, 0])

        return np.array(features, dtype=np.float32)
215
-
216
-
217
class NoiseELAFeatureExtractor:
    """Extract noise and Error Level Analysis (ELA) features"""

    def __init__(self, quality: int = 90):
        """
        Initialize noise/ELA extractor

        Args:
            quality: JPEG quality used for the ELA recompression pass
        """
        self.quality = quality

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray) -> np.ndarray:
        """
        Extract noise and ELA features.

        ELA recompresses the image as JPEG and measures the per-pixel error;
        tampered regions tend to show a different error level.

        Args:
            image: Input image (H, W, 3), RGB (pipeline convention),
                normalized [0, 1]
            region_mask: Binary region mask (H, W)

        Returns:
            5-element float32 vector:
            [ela_mean, ela_var, ela_max, noise_residual_mean, noise_residual_var]
        """
        features = []

        # Convert to uint8
        img_uint8 = (image * 255).astype(np.uint8)

        # --- Error Level Analysis ---
        # OpenCV's JPEG codec assumes BGR channel order while the pipeline
        # feeds RGB images: convert before encoding and back after decoding.
        # Without this, the ELA map compared an RGB image against a BGR
        # decode and was dominated by the R/B channel swap rather than by
        # compression error.
        encode_param = [cv2.IMWRITE_JPEG_QUALITY, self.quality]
        bgr = cv2.cvtColor(img_uint8, cv2.COLOR_RGB2BGR)
        _, encoded = cv2.imencode('.jpg', bgr, encode_param)
        recompressed = cv2.cvtColor(cv2.imdecode(encoded, cv2.IMREAD_COLOR),
                                    cv2.COLOR_BGR2RGB)

        ela = np.abs(img_uint8.astype(np.float32) - recompressed.astype(np.float32))

        # ELA statistics restricted to the candidate region
        ela_region = ela[region_mask > 0]
        if len(ela_region) > 0:
            features.append(np.mean(ela_region))  # ELA mean
            features.append(np.var(ela_region))   # ELA variance
            features.append(np.max(ela_region))   # ELA max
        else:
            features.extend([0, 0, 0])

        # --- Noise residual: high-frequency content left by a median filter ---
        if len(image.shape) == 3:
            gray = cv2.cvtColor(img_uint8, cv2.COLOR_RGB2GRAY)
        else:
            gray = img_uint8

        median_filtered = cv2.medianBlur(gray, 3)
        noise_residual = np.abs(gray.astype(np.float32) - median_filtered.astype(np.float32))

        residual_region = noise_residual[region_mask > 0]
        if len(residual_region) > 0:
            features.append(np.mean(residual_region))
            features.append(np.var(residual_region))
        else:
            features.extend([0, 0])

        return np.array(features, dtype=np.float32)
281
-
282
-
283
class OCRFeatureExtractor:
    """
    Extract OCR-based consistency features
    Only for text documents (Feature Gating - Critical Fix #5)
    """

    def __init__(self):
        """Set up the EasyOCR backend if installed; otherwise disable OCR."""
        self.ocr_available = False

        try:
            import easyocr
            self.reader = easyocr.Reader(['en'], gpu=True)
            self.ocr_available = True
        except Exception:
            print("Warning: EasyOCR not available, OCR features disabled")

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray) -> np.ndarray:
        """
        Compute OCR consistency features for one region.

        Args:
            image: Input image (H, W, 3) normalized [0, 1]
            region_mask: Binary region mask (H, W)

        Returns:
            6-element vector [conf_mean, conf_std, spacing_irregularity,
            text_density, stroke_mean, stroke_std]; zeros when OCR is
            unavailable or the mask is empty.
        """
        # Guard clauses: missing backend or empty region -> zero vector
        if not self.ocr_available:
            return np.zeros(6, dtype=np.float32)

        img_uint8 = (image * 255).astype(np.uint8)

        ys, xs = np.where(region_mask > 0)
        if len(ys) == 0:
            return np.zeros(6, dtype=np.float32)

        # Crop the region's bounding box and run OCR on it
        crop = img_uint8[ys.min():ys.max() + 1, xs.min():xs.max() + 1]

        feats = []
        try:
            detections = self.reader.readtext(crop)

            if len(detections) > 0:
                # Recognition-confidence statistics
                confidences = [det[2] for det in detections]
                feats.append(np.mean(confidences))
                feats.append(np.std(confidences))

                # Character-spacing irregularity (box-width variation)
                widths = [abs(det[0][1][0] - det[0][0][0]) for det in detections]
                feats.append(np.std(widths) / (np.mean(widths) + 1e-8)
                             if len(widths) > 1 else 0.0)

                # Text density: detections per pixel of the crop
                feats.append(len(detections) / (crop.shape[0] * crop.shape[1] + 1e-8))

                # Stroke-width variation proxy via edge strength
                gray_crop = cv2.cvtColor(crop, cv2.COLOR_RGB2GRAY)
                edge_map = sobel(gray_crop)
                feats.append(np.mean(edge_map))
                feats.append(np.std(edge_map))
            else:
                feats.extend([0, 0, 0, 0, 0, 0])
        except Exception:
            feats.extend([0, 0, 0, 0, 0, 0])

        return np.array(feats, dtype=np.float32)
363
-
364
-
365
class HybridFeatureExtractor:
    """
    Complete hybrid feature extraction
    Implements Critical Fix #5: Feature Group Gating

    Concatenates deep, statistical, frequency, noise/ELA and (for text
    documents only) OCR feature groups; each group can be toggled via config.
    """

    def __init__(self, config, is_text_document: bool = True):
        """
        Initialize hybrid feature extractor

        Args:
            config: Configuration object
            is_text_document: Whether input is text document (for OCR gating)
        """
        self.config = config
        self.is_text_document = is_text_document

        # Initialize extractors
        self.deep_extractor = DeepFeatureExtractor()
        self.stat_extractor = StatisticalFeatureExtractor()
        self.freq_extractor = FrequencyFeatureExtractor()
        self.noise_extractor = NoiseELAFeatureExtractor()

        # Critical Fix #5: OCR only for text documents (and only when
        # enabled in config)
        if is_text_document and config.get('features.ocr.enabled', True):
            self.ocr_extractor = OCRFeatureExtractor()
        else:
            self.ocr_extractor = None

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray,
                decoder_features: Optional[List[torch.Tensor]] = None) -> np.ndarray:
        """
        Extract all hybrid features for a region

        Args:
            image: Input image (H, W, 3) normalized [0, 1]
            region_mask: Binary region mask (H, W)
            decoder_features: Optional decoder features for deep feature extraction

        Returns:
            Concatenated feature vector (NaN/Inf values replaced with 0)
        """
        all_features = []

        # Deep features (only when decoder maps were supplied and enabled)
        if decoder_features is not None and self.config.get('features.deep.enabled', True):
            deep_feats = self.deep_extractor.extract(decoder_features, region_mask)
            all_features.append(deep_feats)

        # Statistical & shape features
        if self.config.get('features.statistical.enabled', True):
            stat_feats = self.stat_extractor.extract(image, region_mask)
            all_features.append(stat_feats)

        # Frequency-domain features
        if self.config.get('features.frequency.enabled', True):
            freq_feats = self.freq_extractor.extract(image, region_mask)
            all_features.append(freq_feats)

        # Noise & ELA features
        if self.config.get('features.noise.enabled', True):
            noise_feats = self.noise_extractor.extract(image, region_mask)
            all_features.append(noise_feats)

        # Critical Fix #5: OCR features only for text documents
        if self.ocr_extractor is not None:
            ocr_feats = self.ocr_extractor.extract(image, region_mask)
            all_features.append(ocr_feats)

        # Concatenate all features
        if len(all_features) > 0:
            features = np.concatenate(all_features)
        else:
            features = np.array([], dtype=np.float32)

        # Handle NaN/Inf so the downstream classifier never sees invalid values
        features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)

        return features

    def get_feature_names(self) -> List[str]:
        """Get list of feature names for interpretability.

        NOTE(review): the deep-feature count (256) is an approximation; the
        true length depends on the decoder's channel widths — confirm
        name/feature alignment against the actual decoder before relying
        on it.
        """
        names = []

        if self.config.get('features.deep.enabled', True):
            names.extend([f'deep_{i}' for i in range(256)])  # Approximate

        if self.config.get('features.statistical.enabled', True):
            names.extend(['area', 'perimeter', 'aspect_ratio',
                          'solidity', 'eccentricity', 'entropy'])

        if self.config.get('features.frequency.enabled', True):
            names.extend(['dct_mean', 'dct_std', 'high_freq_energy',
                          'wavelet_cA', 'wavelet_cH', 'wavelet_cV', 'wavelet_cD',
                          'wavelet_entropy_H', 'wavelet_entropy_V', 'wavelet_entropy_D'])

        if self.config.get('features.noise.enabled', True):
            names.extend(['ela_mean', 'ela_var', 'ela_max',
                          'noise_residual_mean', 'noise_residual_var'])

        if self.ocr_extractor is not None:
            names.extend(['ocr_conf_mean', 'ocr_conf_std', 'spacing_irregularity',
                          'text_density', 'stroke_mean', 'stroke_std'])

        return names
472
-
473
-
474
def get_feature_extractor(config, is_text_document: bool = True) -> HybridFeatureExtractor:
    """
    Build a HybridFeatureExtractor.

    Args:
        config: Configuration object
        is_text_document: Whether the input is a text document
            (gates the OCR feature group)

    Returns:
        A configured HybridFeatureExtractor instance
    """
    extractor = HybridFeatureExtractor(config, is_text_document)
    return extractor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deployment/src/features/region_extraction.py DELETED
@@ -1,226 +0,0 @@
1
- """
2
- Mask refinement and region extraction
3
- Implements Critical Fix #3: Adaptive Mask Refinement Thresholds
4
- """
5
-
6
- import cv2
7
- import numpy as np
8
- from typing import List, Tuple, Dict, Optional
9
- from scipy import ndimage
10
- from skimage.measure import label, regionprops
11
-
12
-
13
- class MaskRefiner:
14
- """
15
- Mask refinement with adaptive thresholds
16
- Implements Critical Fix #3: Dataset-specific minimum region areas
17
- """
18
-
19
- def __init__(self, config, dataset_name: str = 'default'):
20
- """
21
- Initialize mask refiner
22
-
23
- Args:
24
- config: Configuration object
25
- dataset_name: Dataset name for adaptive thresholds
26
- """
27
- self.config = config
28
- self.dataset_name = dataset_name
29
-
30
- # Get mask refinement parameters
31
- self.threshold = config.get('mask_refinement.threshold', 0.5)
32
- self.closing_kernel = config.get('mask_refinement.morphology.closing_kernel', 5)
33
- self.opening_kernel = config.get('mask_refinement.morphology.opening_kernel', 3)
34
-
35
- # Critical Fix #3: Adaptive thresholds per dataset
36
- self.min_region_area = config.get_min_region_area(dataset_name)
37
-
38
- print(f"MaskRefiner initialized for {dataset_name}")
39
- print(f"Min region area: {self.min_region_area * 100:.2f}%")
40
-
41
- def refine(self,
42
- probability_map: np.ndarray,
43
- original_size: Tuple[int, int] = None) -> np.ndarray:
44
- """
45
- Refine probability map to binary mask
46
-
47
- Args:
48
- probability_map: Forgery probability map (H, W), values [0, 1]
49
- original_size: Optional (H, W) to resize mask back to original
50
-
51
- Returns:
52
- Refined binary mask (H, W)
53
- """
54
- # Threshold to binary
55
- binary_mask = (probability_map > self.threshold).astype(np.uint8)
56
-
57
- # Morphological closing (fill broken strokes)
58
- closing_kernel = cv2.getStructuringElement(
59
- cv2.MORPH_RECT,
60
- (self.closing_kernel, self.closing_kernel)
61
- )
62
- binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, closing_kernel)
63
-
64
- # Morphological opening (remove isolated noise)
65
- opening_kernel = cv2.getStructuringElement(
66
- cv2.MORPH_RECT,
67
- (self.opening_kernel, self.opening_kernel)
68
- )
69
- binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, opening_kernel)
70
-
71
- # Critical Fix #3: Remove small regions with adaptive threshold
72
- binary_mask = self._remove_small_regions(binary_mask)
73
-
74
- # Resize to original size if provided
75
- if original_size is not None:
76
- binary_mask = cv2.resize(
77
- binary_mask,
78
- (original_size[1], original_size[0]), # cv2 uses (W, H)
79
- interpolation=cv2.INTER_NEAREST
80
- )
81
-
82
- return binary_mask
83
-
84
- def _remove_small_regions(self, mask: np.ndarray) -> np.ndarray:
85
- """
86
- Remove regions smaller than minimum area threshold
87
-
88
- Args:
89
- mask: Binary mask (H, W)
90
-
91
- Returns:
92
- Filtered mask
93
- """
94
- # Calculate minimum pixel count
95
- image_area = mask.shape[0] * mask.shape[1]
96
- min_pixels = int(image_area * self.min_region_area)
97
-
98
- # Label connected components
99
- labeled_mask, num_features = ndimage.label(mask)
100
-
101
- # Keep only large enough regions
102
- filtered_mask = np.zeros_like(mask)
103
-
104
- for region_id in range(1, num_features + 1):
105
- region_mask = (labeled_mask == region_id)
106
- region_area = region_mask.sum()
107
-
108
- if region_area >= min_pixels:
109
- filtered_mask[region_mask] = 1
110
-
111
- return filtered_mask
112
-
113
-
114
- class RegionExtractor:
115
- """
116
- Extract individual regions from binary mask
117
- Implements Critical Fix #4: Region Confidence Aggregation
118
- """
119
-
120
- def __init__(self, config, dataset_name: str = 'default'):
121
- """
122
- Initialize region extractor
123
-
124
- Args:
125
- config: Configuration object
126
- dataset_name: Dataset name
127
- """
128
- self.config = config
129
- self.dataset_name = dataset_name
130
- self.min_region_area = config.get_min_region_area(dataset_name)
131
-
132
- def extract(self,
133
- binary_mask: np.ndarray,
134
- probability_map: np.ndarray,
135
- original_image: np.ndarray) -> List[Dict]:
136
- """
137
- Extract regions from binary mask
138
-
139
- Args:
140
- binary_mask: Refined binary mask (H, W)
141
- probability_map: Original probability map (H, W)
142
- original_image: Original image (H, W, 3)
143
-
144
- Returns:
145
- List of region dictionaries with bounding box, mask, image, confidence
146
- """
147
- regions = []
148
-
149
- # Connected component analysis (8-connectivity)
150
- labeled_mask = label(binary_mask, connectivity=2)
151
- props = regionprops(labeled_mask)
152
-
153
- for region_id, prop in enumerate(props, start=1):
154
- # Bounding box
155
- y_min, x_min, y_max, x_max = prop.bbox
156
-
157
- # Region mask
158
- region_mask = (labeled_mask == region_id).astype(np.uint8)
159
-
160
- # Cropped region image
161
- region_image = original_image[y_min:y_max, x_min:x_max].copy()
162
- region_mask_cropped = region_mask[y_min:y_max, x_min:x_max]
163
-
164
- # Critical Fix #4: Region-level confidence aggregation
165
- region_probs = probability_map[region_mask > 0]
166
- region_confidence = float(np.mean(region_probs)) if len(region_probs) > 0 else 0.0
167
-
168
- regions.append({
169
- 'region_id': region_id,
170
- 'bounding_box': [int(x_min), int(y_min),
171
- int(x_max - x_min), int(y_max - y_min)],
172
- 'area': prop.area,
173
- 'centroid': (int(prop.centroid[1]), int(prop.centroid[0])),
174
- 'region_mask': region_mask,
175
- 'region_mask_cropped': region_mask_cropped,
176
- 'region_image': region_image,
177
- 'confidence': region_confidence,
178
- 'mask_probability_mean': region_confidence
179
- })
180
-
181
- return regions
182
-
183
- def extract_for_casia(self,
184
- binary_mask: np.ndarray,
185
- probability_map: np.ndarray,
186
- original_image: np.ndarray) -> List[Dict]:
187
- """
188
- Critical Fix #6: CASIA handling - treat entire image as one region
189
-
190
- Args:
191
- binary_mask: Binary mask (may be empty for authentic images)
192
- probability_map: Probability map
193
- original_image: Original image
194
-
195
- Returns:
196
- Single region representing entire image
197
- """
198
- h, w = original_image.shape[:2]
199
-
200
- # Create single region covering entire image
201
- region_mask = np.ones((h, w), dtype=np.uint8)
202
-
203
- # Overall confidence from probability map
204
- overall_confidence = float(np.mean(probability_map))
205
-
206
- return [{
207
- 'region_id': 1,
208
- 'bounding_box': [0, 0, w, h],
209
- 'area': h * w,
210
- 'centroid': (w // 2, h // 2),
211
- 'region_mask': region_mask,
212
- 'region_mask_cropped': region_mask,
213
- 'region_image': original_image,
214
- 'confidence': overall_confidence,
215
- 'mask_probability_mean': overall_confidence
216
- }]
217
-
218
-
219
- def get_mask_refiner(config, dataset_name: str = 'default') -> MaskRefiner:
220
- """Factory function for mask refiner"""
221
- return MaskRefiner(config, dataset_name)
222
-
223
-
224
- def get_region_extractor(config, dataset_name: str = 'default') -> RegionExtractor:
225
- """Factory function for region extractor"""
226
- return RegionExtractor(config, dataset_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deployment/src/inference/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- """Inference module"""
2
-
3
- from .pipeline import ForgeryDetectionPipeline, get_pipeline
4
-
5
- __all__ = ['ForgeryDetectionPipeline', 'get_pipeline']
 
 
 
 
 
 
deployment/src/inference/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (305 Bytes)
 
deployment/src/inference/__pycache__/pipeline.cpython-312.pyc DELETED
Binary file (14.5 kB)
 
deployment/src/inference/pipeline.py DELETED
@@ -1,359 +0,0 @@
1
- """
2
- Inference pipeline for document forgery detection
3
- Complete pipeline: Image → Localization → Regions → Classification → Output
4
- """
5
-
6
- import cv2
7
- import numpy as np
8
- import torch
9
- from typing import Dict, List, Optional, Tuple
10
- from pathlib import Path
11
- import json
12
- from PIL import Image
13
- import fitz # PyMuPDF
14
-
15
- from ..config import get_config
16
- from ..models import get_model
17
- from ..features import (
18
- get_feature_extractor,
19
- get_mask_refiner,
20
- get_region_extractor
21
- )
22
- from ..training.classifier import get_classifier
23
-
24
-
25
- class ForgeryDetectionPipeline:
26
- """
27
- Complete inference pipeline for document forgery detection
28
-
29
- Pipeline:
30
- 1. Input handling (PDF/Image)
31
- 2. Preprocessing
32
- 3. Deep localization
33
- 4. Mask refinement
34
- 5. Region extraction
35
- 6. Feature extraction
36
- 7. Classification
37
- 8. Post-processing
38
- 9. Output generation
39
- """
40
-
41
- def __init__(self,
42
- config,
43
- model_path: str,
44
- classifier_path: Optional[str] = None,
45
- is_text_document: bool = True):
46
- """
47
- Initialize pipeline
48
-
49
- Args:
50
- config: Configuration object
51
- model_path: Path to localization model checkpoint
52
- classifier_path: Path to classifier (optional)
53
- is_text_document: Whether input is text document (for OCR features)
54
- """
55
- self.config = config
56
- self.is_text_document = is_text_document
57
-
58
- # Device
59
- self.device = torch.device(
60
- 'cuda' if torch.cuda.is_available() and config.get('system.device') == 'cuda'
61
- else 'cpu'
62
- )
63
- print(f"Inference device: {self.device}")
64
-
65
- # Load localization model
66
- self.model = get_model(config).to(self.device)
67
- self._load_model(model_path)
68
- self.model.eval()
69
-
70
- # Initialize mask refiner
71
- self.mask_refiner = get_mask_refiner(config, 'default')
72
-
73
- # Initialize region extractor
74
- self.region_extractor = get_region_extractor(config, 'default')
75
-
76
- # Initialize feature extractor
77
- self.feature_extractor = get_feature_extractor(config, is_text_document)
78
-
79
- # Load classifier if provided
80
- if classifier_path:
81
- self.classifier = get_classifier(config)
82
- self.classifier.load(classifier_path)
83
- else:
84
- self.classifier = None
85
-
86
- # Confidence threshold
87
- self.confidence_threshold = config.get('classifier.confidence_threshold', 0.6)
88
-
89
- # Image size
90
- self.image_size = config.get('data.image_size', 384)
91
-
92
- print("Inference pipeline initialized")
93
-
94
- def _load_model(self, model_path: str):
95
- """Load model checkpoint"""
96
- checkpoint = torch.load(model_path, map_location=self.device)
97
-
98
- if 'model_state_dict' in checkpoint:
99
- self.model.load_state_dict(checkpoint['model_state_dict'])
100
- else:
101
- self.model.load_state_dict(checkpoint)
102
-
103
- print(f"Loaded model from {model_path}")
104
-
105
- def _load_image(self, input_path: str) -> np.ndarray:
106
- """
107
- Load image from file or PDF
108
-
109
- Args:
110
- input_path: Path to image or PDF
111
-
112
- Returns:
113
- Image as numpy array (H, W, 3)
114
- """
115
- path = Path(input_path)
116
-
117
- if path.suffix.lower() == '.pdf':
118
- # Rasterize PDF at 300 DPI
119
- doc = fitz.open(str(path))
120
- page = doc[0]
121
- mat = fitz.Matrix(300/72, 300/72) # 300 DPI
122
- pix = page.get_pixmap(matrix=mat)
123
- image = np.frombuffer(pix.samples, dtype=np.uint8)
124
- image = image.reshape(pix.height, pix.width, pix.n)
125
- if pix.n == 4:
126
- image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
127
- doc.close()
128
- else:
129
- # Load image
130
- image = cv2.imread(str(path))
131
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
132
-
133
- return image
134
-
135
- def _preprocess(self, image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
136
- """
137
- Preprocess image for inference
138
-
139
- Args:
140
- image: Input image (H, W, 3)
141
-
142
- Returns:
143
- Preprocessed image and original image
144
- """
145
- original = image.copy()
146
-
147
- # Resize
148
- preprocessed = cv2.resize(image, (self.image_size, self.image_size))
149
-
150
- # Normalize to [0, 1]
151
- preprocessed = preprocessed.astype(np.float32) / 255.0
152
-
153
- return preprocessed, original
154
-
155
- def _to_tensor(self, image: np.ndarray) -> torch.Tensor:
156
- """Convert image to tensor"""
157
- # (H, W, C) -> (C, H, W)
158
- tensor = torch.from_numpy(image.transpose(2, 0, 1))
159
- tensor = tensor.unsqueeze(0) # Add batch dimension
160
- return tensor.to(self.device)
161
-
162
- def run(self,
163
- input_path: str,
164
- output_dir: Optional[str] = None) -> Dict:
165
- """
166
- Run full inference pipeline
167
-
168
- Args:
169
- input_path: Path to input image or PDF
170
- output_dir: Optional output directory
171
-
172
- Returns:
173
- Dictionary with results
174
- """
175
- print(f"\n{'='*60}")
176
- print(f"Processing: {input_path}")
177
- print(f"{'='*60}")
178
-
179
- # 1. Load image
180
- image = self._load_image(input_path)
181
- original_size = image.shape[:2]
182
- print(f"Input size: {original_size}")
183
-
184
- # 2. Preprocess
185
- preprocessed, original = self._preprocess(image)
186
- tensor = self._to_tensor(preprocessed)
187
-
188
- # 3. Deep localization
189
- with torch.no_grad():
190
- logits, decoder_features = self.model(tensor)
191
- probability_map = torch.sigmoid(logits).cpu().numpy()[0, 0]
192
-
193
- print(f"Localization complete. Max prob: {probability_map.max():.3f}")
194
-
195
- # 4. Mask refinement
196
- binary_mask = self.mask_refiner.refine(probability_map, original_size)
197
- num_positive_pixels = binary_mask.sum()
198
- print(f"Mask refinement: {num_positive_pixels} positive pixels")
199
-
200
- # 5. Region extraction
201
- # Resize probability map to original size for confidence aggregation
202
- prob_resized = cv2.resize(probability_map, (original_size[1], original_size[0]))
203
-
204
- regions = self.region_extractor.extract(binary_mask, prob_resized, original)
205
- print(f"Regions extracted: {len(regions)}")
206
-
207
- # 6. Feature extraction & 7. Classification
208
- results = []
209
-
210
- for region in regions:
211
- # Extract features
212
- features = self.feature_extractor.extract(
213
- preprocessed,
214
- cv2.resize(region['region_mask'], (self.image_size, self.image_size)),
215
- [f.cpu() for f in decoder_features]
216
- )
217
-
218
- # Classify if classifier available
219
- if self.classifier is not None:
220
- predictions, confidences, valid_mask = self.classifier.predict_with_filtering(
221
- features.reshape(1, -1)
222
- )
223
-
224
- if valid_mask[0]:
225
- region['forgery_type'] = self.classifier.get_class_name(predictions[0])
226
- region['classification_confidence'] = float(confidences[0])
227
- else:
228
- # Low confidence - discard
229
- continue
230
- else:
231
- region['forgery_type'] = 'unknown'
232
- region['classification_confidence'] = region['confidence']
233
-
234
- # Clean up non-serializable fields
235
- region_result = {
236
- 'region_id': region['region_id'],
237
- 'bounding_box': region['bounding_box'],
238
- 'forgery_type': region['forgery_type'],
239
- 'confidence': region['confidence'],
240
- 'classification_confidence': region['classification_confidence'],
241
- 'mask_probability_mean': region['mask_probability_mean'],
242
- 'area': region['area']
243
- }
244
- results.append(region_result)
245
-
246
- print(f"Valid regions after filtering: {len(results)}")
247
-
248
- # 8. Post-processing - False positive removal
249
- results = self._post_process(results)
250
-
251
- # 9. Generate output
252
- output = {
253
- 'input_path': str(input_path),
254
- 'original_size': original_size,
255
- 'num_regions': len(results),
256
- 'regions': results,
257
- 'is_tampered': len(results) > 0
258
- }
259
-
260
- # Save outputs if directory provided
261
- if output_dir:
262
- output_path = Path(output_dir)
263
- output_path.mkdir(parents=True, exist_ok=True)
264
-
265
- input_name = Path(input_path).stem
266
-
267
- # Save final mask
268
- mask_path = output_path / f'{input_name}_mask.png'
269
- cv2.imwrite(str(mask_path), binary_mask * 255)
270
-
271
- # Save overlay visualization
272
- overlay = self._create_overlay(original, binary_mask, results)
273
- overlay_path = output_path / f'{input_name}_overlay.png'
274
- cv2.imwrite(str(overlay_path), cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
275
-
276
- # Save JSON
277
- json_path = output_path / f'{input_name}_results.json'
278
- with open(json_path, 'w') as f:
279
- json.dump(output, f, indent=2)
280
-
281
- print(f"\nOutputs saved to: {output_path}")
282
- output['mask_path'] = str(mask_path)
283
- output['overlay_path'] = str(overlay_path)
284
- output['json_path'] = str(json_path)
285
-
286
- return output
287
-
288
- def _post_process(self, regions: List[Dict]) -> List[Dict]:
289
- """
290
- Post-process regions to remove false positives
291
-
292
- Args:
293
- regions: List of region dictionaries
294
-
295
- Returns:
296
- Filtered regions
297
- """
298
- filtered = []
299
-
300
- for region in regions:
301
- # Confidence filtering
302
- if region['confidence'] < self.confidence_threshold:
303
- continue
304
-
305
- filtered.append(region)
306
-
307
- return filtered
308
-
309
- def _create_overlay(self,
310
- image: np.ndarray,
311
- mask: np.ndarray,
312
- regions: List[Dict]) -> np.ndarray:
313
- """
314
- Create visualization overlay
315
-
316
- Args:
317
- image: Original image
318
- mask: Binary mask
319
- regions: Detected regions
320
-
321
- Returns:
322
- Overlay image
323
- """
324
- overlay = image.copy()
325
- alpha = self.config.get('outputs.visualization.overlay_alpha', 0.5)
326
-
327
- # Create colored mask
328
- mask_colored = np.zeros_like(image)
329
- mask_colored[mask > 0] = [255, 0, 0] # Red for forgery
330
-
331
- # Blend
332
- mask_resized = cv2.resize(mask, (image.shape[1], image.shape[0]))
333
- overlay = np.where(
334
- mask_resized[:, :, None] > 0,
335
- (1 - alpha) * image + alpha * mask_colored,
336
- image
337
- ).astype(np.uint8)
338
-
339
- # Draw bounding boxes and labels
340
- for region in regions:
341
- x, y, w, h = region['bounding_box']
342
-
343
- # Draw rectangle
344
- cv2.rectangle(overlay, (x, y), (x + w, y + h), (0, 255, 0), 2)
345
-
346
- # Draw label
347
- label = f"{region['forgery_type']} ({region['confidence']:.2f})"
348
- cv2.putText(overlay, label, (x, y - 10),
349
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
350
-
351
- return overlay
352
-
353
-
354
- def get_pipeline(config,
355
- model_path: str,
356
- classifier_path: Optional[str] = None,
357
- is_text_document: bool = True) -> ForgeryDetectionPipeline:
358
- """Factory function for pipeline"""
359
- return ForgeryDetectionPipeline(config, model_path, classifier_path, is_text_document)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deployment/src/models/__init__.py DELETED
@@ -1,19 +0,0 @@
1
- """Models module"""
2
-
3
- from .encoder import MobileNetV3Encoder, get_encoder
4
- from .decoder import UNetLiteDecoder, get_decoder
5
- from .network import ForgeryLocalizationNetwork, get_model
6
- from .losses import DiceLoss, CombinedLoss, DatasetAwareLoss, get_loss_function
7
-
8
- __all__ = [
9
- 'MobileNetV3Encoder',
10
- 'get_encoder',
11
- 'UNetLiteDecoder',
12
- 'get_decoder',
13
- 'ForgeryLocalizationNetwork',
14
- 'get_model',
15
- 'DiceLoss',
16
- 'CombinedLoss',
17
- 'DatasetAwareLoss',
18
- 'get_loss_function'
19
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deployment/src/models/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (600 Bytes)
 
deployment/src/models/__pycache__/decoder.cpython-312.pyc DELETED
Binary file (7.65 kB)
 
deployment/src/models/__pycache__/encoder.cpython-312.pyc DELETED
Binary file (2.91 kB)
 
deployment/src/models/__pycache__/losses.cpython-312.pyc DELETED
Binary file (6.55 kB)
 
deployment/src/models/__pycache__/network.cpython-312.pyc DELETED
Binary file (5.84 kB)
 
deployment/src/models/decoder.py DELETED
@@ -1,186 +0,0 @@
1
- """
2
- UNet-Lite Decoder for forgery localization
3
- Lightweight decoder with skip connections, depthwise separable convolutions
4
- """
5
-
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
- from typing import List
10
-
11
-
12
- class DepthwiseSeparableConv(nn.Module):
13
- """Depthwise separable convolution for efficiency"""
14
-
15
- def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
16
- super().__init__()
17
-
18
- self.depthwise = nn.Conv2d(
19
- in_channels, in_channels,
20
- kernel_size=kernel_size,
21
- padding=kernel_size // 2,
22
- groups=in_channels,
23
- bias=False
24
- )
25
- self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
26
- self.bn = nn.BatchNorm2d(out_channels)
27
- self.relu = nn.ReLU(inplace=True)
28
-
29
- def forward(self, x: torch.Tensor) -> torch.Tensor:
30
- x = self.depthwise(x)
31
- x = self.pointwise(x)
32
- x = self.bn(x)
33
- x = self.relu(x)
34
- return x
35
-
36
-
37
- class DecoderBlock(nn.Module):
38
- """Single decoder block with skip connection"""
39
-
40
- def __init__(self, in_channels: int, skip_channels: int, out_channels: int):
41
- """
42
- Initialize decoder block
43
-
44
- Args:
45
- in_channels: Input channels from previous decoder stage
46
- skip_channels: Channels from encoder skip connection
47
- out_channels: Output channels
48
- """
49
- super().__init__()
50
-
51
- # Combine upsampled features with skip connection
52
- combined_channels = in_channels + skip_channels
53
-
54
- self.conv1 = DepthwiseSeparableConv(combined_channels, out_channels)
55
- self.conv2 = DepthwiseSeparableConv(out_channels, out_channels)
56
-
57
- def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
58
- """
59
- Forward pass
60
-
61
- Args:
62
- x: Input from previous decoder stage
63
- skip: Skip connection from encoder
64
-
65
- Returns:
66
- Decoded features
67
- """
68
- # Bilinear upsampling
69
- x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
70
-
71
- # Concatenate with skip connection
72
- x = torch.cat([x, skip], dim=1)
73
-
74
- # Convolutions
75
- x = self.conv1(x)
76
- x = self.conv2(x)
77
-
78
- return x
79
-
80
-
81
- class UNetLiteDecoder(nn.Module):
82
- """
83
- UNet-Lite decoder for forgery localization
84
-
85
- Features:
86
- - Skip connections from encoder stages
87
- - Bilinear upsampling
88
- - Depthwise separable convolutions for efficiency
89
- """
90
-
91
- def __init__(self,
92
- encoder_channels: List[int],
93
- decoder_channels: List[int] = None,
94
- output_channels: int = 1):
95
- """
96
- Initialize decoder
97
-
98
- Args:
99
- encoder_channels: List of encoder feature channels [stage0, ..., stageN]
100
- decoder_channels: List of decoder output channels
101
- output_channels: Number of output channels (1 for binary mask)
102
- """
103
- super().__init__()
104
-
105
- # Default decoder channels if not provided
106
- if decoder_channels is None:
107
- decoder_channels = [256, 128, 64, 32, 16]
108
-
109
- # Reverse encoder channels for decoder (bottom to top)
110
- encoder_channels = encoder_channels[::-1]
111
-
112
- # Initial convolution from deepest encoder features
113
- self.initial_conv = DepthwiseSeparableConv(encoder_channels[0], decoder_channels[0])
114
-
115
- # Decoder blocks
116
- self.decoder_blocks = nn.ModuleList()
117
-
118
- for i in range(len(encoder_channels) - 1):
119
- in_ch = decoder_channels[i]
120
- skip_ch = encoder_channels[i + 1]
121
- out_ch = decoder_channels[i + 1] if i + 1 < len(decoder_channels) else decoder_channels[-1]
122
-
123
- self.decoder_blocks.append(
124
- DecoderBlock(in_ch, skip_ch, out_ch)
125
- )
126
-
127
- # Final upsampling to original resolution
128
- self.final_upsample = nn.Sequential(
129
- DepthwiseSeparableConv(decoder_channels[-1], decoder_channels[-1]),
130
- nn.Conv2d(decoder_channels[-1], output_channels, kernel_size=1)
131
- )
132
-
133
- # Store decoder feature channels for feature extraction
134
- self.decoder_channels = decoder_channels
135
-
136
- print(f"UNet-Lite decoder initialized")
137
- print(f"Encoder channels: {encoder_channels[::-1]}")
138
- print(f"Decoder channels: {decoder_channels}")
139
-
140
- def forward(self, encoder_features: List[torch.Tensor]) -> tuple:
141
- """
142
- Forward pass
143
-
144
- Args:
145
- encoder_features: List of encoder features [stage0, ..., stageN]
146
-
147
- Returns:
148
- output: Forgery probability map (B, 1, H, W)
149
- decoder_features: List of decoder features for hybrid extraction
150
- """
151
- # Reverse for bottom-up decoding
152
- features = encoder_features[::-1]
153
-
154
- # Initial convolution
155
- x = self.initial_conv(features[0])
156
-
157
- # Store decoder features for hybrid feature extraction
158
- decoder_features = [x]
159
-
160
- # Decoder blocks with skip connections
161
- for i, block in enumerate(self.decoder_blocks):
162
- x = block(x, features[i + 1])
163
- decoder_features.append(x)
164
-
165
- # Final upsampling to original resolution
166
- # Assume input was 384x384, final feature map should match
167
- target_size = encoder_features[0].shape[2] * 2 # First encoder feature is at 1/2 scale
168
- x = F.interpolate(x, size=(target_size, target_size), mode='bilinear', align_corners=False)
169
- output = self.final_upsample[1](self.final_upsample[0](x))
170
-
171
- return output, decoder_features
172
-
173
-
174
- def get_decoder(encoder_channels: List[int], config) -> UNetLiteDecoder:
175
- """
176
- Factory function to create decoder
177
-
178
- Args:
179
- encoder_channels: Encoder feature channels
180
- config: Configuration object
181
-
182
- Returns:
183
- Decoder instance
184
- """
185
- output_channels = config.get('model.output_channels', 1)
186
- return UNetLiteDecoder(encoder_channels, output_channels=output_channels)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deployment/src/models/encoder.py DELETED
@@ -1,75 +0,0 @@
1
- """
2
- MobileNetV3-Small Encoder for forgery localization
3
- ImageNet pretrained, feature extraction mode
4
- """
5
-
6
- import torch
7
- import torch.nn as nn
8
- import timm
9
- from typing import List
10
-
11
-
12
- class MobileNetV3Encoder(nn.Module):
13
- """
14
- MobileNetV3-Small encoder for document forgery detection
15
-
16
- Chosen for:
17
- - Stroke-level and texture preservation
18
- - Robustness to compression and blur
19
- - Edge and CPU deployment efficiency
20
- """
21
-
22
- def __init__(self, pretrained: bool = True):
23
- """
24
- Initialize encoder
25
-
26
- Args:
27
- pretrained: Whether to use ImageNet pretrained weights
28
- """
29
- super().__init__()
30
-
31
- # Load MobileNetV3-Small with feature extraction
32
- self.backbone = timm.create_model(
33
- 'mobilenetv3_small_100',
34
- pretrained=pretrained,
35
- features_only=True,
36
- out_indices=(0, 1, 2, 3, 4) # All feature stages
37
- )
38
-
39
- # Get feature channel dimensions
40
- # MobileNetV3-Small: [16, 16, 24, 48, 576]
41
- self.feature_channels = self.backbone.feature_info.channels()
42
-
43
- print(f"MobileNetV3-Small encoder initialized")
44
- print(f"Feature channels: {self.feature_channels}")
45
-
46
- def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
47
- """
48
- Extract multi-scale features
49
-
50
- Args:
51
- x: Input tensor (B, 3, H, W)
52
-
53
- Returns:
54
- List of feature tensors at different scales
55
- """
56
- features = self.backbone(x)
57
- return features
58
-
59
- def get_feature_channels(self) -> List[int]:
60
- """Get feature channel dimensions for each stage"""
61
- return self.feature_channels
62
-
63
-
64
- def get_encoder(config) -> MobileNetV3Encoder:
65
- """
66
- Factory function to create encoder
67
-
68
- Args:
69
- config: Configuration object
70
-
71
- Returns:
72
- Encoder instance
73
- """
74
- pretrained = config.get('model.encoder.pretrained', True)
75
- return MobileNetV3Encoder(pretrained=pretrained)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deployment/src/models/losses.py DELETED
@@ -1,168 +0,0 @@
1
- """
2
- Dataset-aware loss functions
3
- Implements Critical Fix #2: Dataset-Aware Loss Function
4
- """
5
-
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
- from typing import Dict, Optional
10
-
11
-
12
- class DiceLoss(nn.Module):
13
- """Dice loss for segmentation"""
14
-
15
- def __init__(self, smooth: float = 1.0):
16
- """
17
- Initialize Dice loss
18
-
19
- Args:
20
- smooth: Smoothing factor to avoid division by zero
21
- """
22
- super().__init__()
23
- self.smooth = smooth
24
-
25
- def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
26
- """
27
- Compute Dice loss
28
-
29
- Args:
30
- pred: Predicted probabilities (B, 1, H, W)
31
- target: Ground truth mask (B, 1, H, W)
32
-
33
- Returns:
34
- Dice loss value
35
- """
36
- pred = torch.sigmoid(pred)
37
-
38
- # Flatten
39
- pred_flat = pred.view(-1)
40
- target_flat = target.view(-1)
41
-
42
- # Dice coefficient
43
- intersection = (pred_flat * target_flat).sum()
44
- dice = (2. * intersection + self.smooth) / (
45
- pred_flat.sum() + target_flat.sum() + self.smooth
46
- )
47
-
48
- return 1 - dice
49
-
50
-
51
- class CombinedLoss(nn.Module):
52
- """
53
- Combined BCE + Dice loss for segmentation
54
- Dataset-aware: Only uses Dice when pixel masks are available
55
- """
56
-
57
- def __init__(self,
58
- bce_weight: float = 1.0,
59
- dice_weight: float = 1.0):
60
- """
61
- Initialize combined loss
62
-
63
- Args:
64
- bce_weight: Weight for BCE loss
65
- dice_weight: Weight for Dice loss
66
- """
67
- super().__init__()
68
-
69
- self.bce_weight = bce_weight
70
- self.dice_weight = dice_weight
71
-
72
- self.bce_loss = nn.BCEWithLogitsLoss()
73
- self.dice_loss = DiceLoss()
74
-
75
- def forward(self,
76
- pred: torch.Tensor,
77
- target: torch.Tensor,
78
- has_pixel_mask: bool = True) -> Dict[str, torch.Tensor]:
79
- """
80
- Compute loss (dataset-aware)
81
-
82
- Critical Fix #2: Only use Dice loss for datasets with pixel masks
83
-
84
- Args:
85
- pred: Predicted logits (B, 1, H, W)
86
- target: Ground truth mask (B, 1, H, W)
87
- has_pixel_mask: Whether dataset has pixel-level masks
88
-
89
- Returns:
90
- Dictionary with 'total', 'bce', and optionally 'dice' losses
91
- """
92
- # BCE loss (always used)
93
- bce = self.bce_loss(pred, target)
94
-
95
- losses = {
96
- 'bce': bce
97
- }
98
-
99
- if has_pixel_mask:
100
- # Use Dice loss only for datasets with pixel masks
101
- dice = self.dice_loss(pred, target)
102
- losses['dice'] = dice
103
- losses['total'] = self.bce_weight * bce + self.dice_weight * dice
104
- else:
105
- # Critical Fix #2: CASIA only uses BCE
106
- losses['total'] = self.bce_weight * bce
107
-
108
- return losses
109
-
110
-
111
- class DatasetAwareLoss(nn.Module):
112
- """
113
- Dataset-aware loss function wrapper
114
- Automatically determines appropriate loss based on dataset metadata
115
- """
116
-
117
- def __init__(self, config):
118
- """
119
- Initialize dataset-aware loss
120
-
121
- Args:
122
- config: Configuration object
123
- """
124
- super().__init__()
125
-
126
- self.config = config
127
-
128
- bce_weight = config.get('loss.bce_weight', 1.0)
129
- dice_weight = config.get('loss.dice_weight', 1.0)
130
-
131
- self.combined_loss = CombinedLoss(
132
- bce_weight=bce_weight,
133
- dice_weight=dice_weight
134
- )
135
-
136
- def forward(self,
137
- pred: torch.Tensor,
138
- target: torch.Tensor,
139
- metadata: Dict) -> Dict[str, torch.Tensor]:
140
- """
141
- Compute loss with dataset awareness
142
-
143
- Args:
144
- pred: Predicted logits (B, 1, H, W)
145
- target: Ground truth mask (B, 1, H, W)
146
- metadata: Batch metadata containing 'has_pixel_mask' flags
147
-
148
- Returns:
149
- Dictionary with loss components
150
- """
151
- # Check if batch has pixel masks
152
- has_pixel_mask = all(m.get('has_pixel_mask', True) for m in metadata) \
153
- if isinstance(metadata, list) else metadata.get('has_pixel_mask', True)
154
-
155
- return self.combined_loss(pred, target, has_pixel_mask)
156
-
157
-
158
- def get_loss_function(config) -> DatasetAwareLoss:
159
- """
160
- Factory function to create loss
161
-
162
- Args:
163
- config: Configuration object
164
-
165
- Returns:
166
- Loss function instance
167
- """
168
- return DatasetAwareLoss(config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deployment/src/models/network.py DELETED
@@ -1,133 +0,0 @@
1
- """
2
- Complete Forgery Localization Network
3
- MobileNetV3-Small Encoder + UNet-Lite Decoder
4
- """
5
-
6
- import torch
7
- import torch.nn as nn
8
- from typing import Tuple, List, Optional
9
-
10
- from .encoder import MobileNetV3Encoder
11
- from .decoder import UNetLiteDecoder
12
-
13
-
14
- class ForgeryLocalizationNetwork(nn.Module):
15
- """
16
- Complete network for forgery localization
17
-
18
- Architecture:
19
- - Encoder: MobileNetV3-Small (ImageNet pretrained)
20
- - Decoder: UNet-Lite with skip connections
21
- - Output: Single-channel forgery probability map
22
- """
23
-
24
- def __init__(self, config):
25
- """
26
- Initialize network
27
-
28
- Args:
29
- config: Configuration object
30
- """
31
- super().__init__()
32
-
33
- self.config = config
34
-
35
- # Initialize encoder
36
- pretrained = config.get('model.encoder.pretrained', True)
37
- self.encoder = MobileNetV3Encoder(pretrained=pretrained)
38
-
39
- # Initialize decoder
40
- encoder_channels = self.encoder.get_feature_channels()
41
- output_channels = config.get('model.output_channels', 1)
42
- self.decoder = UNetLiteDecoder(
43
- encoder_channels=encoder_channels,
44
- output_channels=output_channels
45
- )
46
-
47
- print(f"ForgeryLocalizationNetwork initialized")
48
- print(f"Total parameters: {self.count_parameters():,}")
49
-
50
- def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
51
- """
52
- Forward pass
53
-
54
- Args:
55
- x: Input image tensor (B, 3, H, W)
56
-
57
- Returns:
58
- output: Forgery probability map (B, 1, H, W) - logits
59
- decoder_features: Decoder features for hybrid feature extraction
60
- """
61
- # Encode
62
- encoder_features = self.encoder(x)
63
-
64
- # Decode
65
- output, decoder_features = self.decoder(encoder_features)
66
-
67
- return output, decoder_features
68
-
69
- def predict(self, x: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
70
- """
71
- Predict binary mask
72
-
73
- Args:
74
- x: Input image tensor (B, 3, H, W)
75
- threshold: Probability threshold for binarization
76
-
77
- Returns:
78
- Binary mask (B, 1, H, W)
79
- """
80
- with torch.no_grad():
81
- logits, _ = self.forward(x)
82
- probs = torch.sigmoid(logits)
83
- mask = (probs > threshold).float()
84
-
85
- return mask
86
-
87
- def get_probability_map(self, x: torch.Tensor) -> torch.Tensor:
88
- """
89
- Get probability map
90
-
91
- Args:
92
- x: Input image tensor (B, 3, H, W)
93
-
94
- Returns:
95
- Probability map (B, 1, H, W)
96
- """
97
- with torch.no_grad():
98
- logits, _ = self.forward(x)
99
- probs = torch.sigmoid(logits)
100
-
101
- return probs
102
-
103
- def count_parameters(self) -> int:
104
- """Count total trainable parameters"""
105
- return sum(p.numel() for p in self.parameters() if p.requires_grad)
106
-
107
- def get_decoder_features(self, x: torch.Tensor) -> List[torch.Tensor]:
108
- """
109
- Get decoder features for hybrid feature extraction
110
-
111
- Args:
112
- x: Input image tensor (B, 3, H, W)
113
-
114
- Returns:
115
- List of decoder features
116
- """
117
- with torch.no_grad():
118
- _, decoder_features = self.forward(x)
119
-
120
- return decoder_features
121
-
122
-
123
- def get_model(config) -> ForgeryLocalizationNetwork:
124
- """
125
- Factory function to create model
126
-
127
- Args:
128
- config: Configuration object
129
-
130
- Returns:
131
- Model instance
132
- """
133
- return ForgeryLocalizationNetwork(config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deployment/src/training/__init__.py DELETED
@@ -1,24 +0,0 @@
1
- """Training module"""
2
-
3
- from .metrics import (
4
- SegmentationMetrics,
5
- ClassificationMetrics,
6
- MetricsTracker,
7
- EarlyStopping,
8
- get_metrics_tracker
9
- )
10
-
11
- from .trainer import Trainer, get_trainer
12
- from .classifier import ForgeryClassifier, get_classifier
13
-
14
- __all__ = [
15
- 'SegmentationMetrics',
16
- 'ClassificationMetrics',
17
- 'MetricsTracker',
18
- 'EarlyStopping',
19
- 'get_metrics_tracker',
20
- 'Trainer',
21
- 'get_trainer',
22
- 'ForgeryClassifier',
23
- 'get_classifier'
24
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deployment/src/training/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (568 Bytes)
 
deployment/src/training/__pycache__/classifier.cpython-312.pyc DELETED
Binary file (11 kB)
 
deployment/src/training/__pycache__/metrics.cpython-312.pyc DELETED
Binary file (12.5 kB)
 
deployment/src/training/__pycache__/trainer.cpython-312.pyc DELETED
Binary file (18.8 kB)
 
deployment/src/training/classifier.py DELETED
@@ -1,282 +0,0 @@
1
- """
2
- LightGBM classifier for forgery type classification
3
- Implements Critical Fix #8: Configurable Confidence Threshold
4
- """
5
-
6
- import numpy as np
7
- import lightgbm as lgb
8
- from sklearn.preprocessing import StandardScaler
9
- from sklearn.model_selection import train_test_split
10
- from typing import Dict, List, Tuple, Optional
11
- import joblib
12
- from pathlib import Path
13
- import json
14
-
15
-
16
- class ForgeryClassifier:
17
- """
18
- LightGBM classifier for region-wise forgery classification
19
-
20
- Target classes:
21
- - 0: copy_move
22
- - 1: splicing
23
- - 2: text_substitution
24
- """
25
-
26
- CLASS_NAMES = ['copy_move', 'splicing', 'text_substitution']
27
-
28
- def __init__(self, config):
29
- """
30
- Initialize classifier
31
-
32
- Args:
33
- config: Configuration object
34
- """
35
- self.config = config
36
-
37
- # LightGBM parameters
38
- self.params = config.get('classifier.params', {
39
- 'objective': 'multiclass',
40
- 'num_class': 3,
41
- 'boosting_type': 'gbdt',
42
- 'num_leaves': 31,
43
- 'learning_rate': 0.05,
44
- 'n_estimators': 200,
45
- 'max_depth': 7,
46
- 'min_child_samples': 20,
47
- 'subsample': 0.8,
48
- 'colsample_bytree': 0.8,
49
- 'reg_alpha': 0.1,
50
- 'reg_lambda': 0.1,
51
- 'random_state': 42,
52
- 'verbose': -1
53
- })
54
-
55
- # Critical Fix #8: Configurable confidence threshold
56
- self.confidence_threshold = config.get('classifier.confidence_threshold', 0.6)
57
-
58
- # Initialize model and scaler
59
- self.model = None
60
- self.scaler = StandardScaler()
61
-
62
- # Feature importance
63
- self.feature_importance = None
64
- self.feature_names = None
65
-
66
- def train(self,
67
- features: np.ndarray,
68
- labels: np.ndarray,
69
- feature_names: Optional[List[str]] = None,
70
- validation_split: float = 0.2) -> Dict:
71
- """
72
- Train classifier
73
-
74
- Args:
75
- features: Feature matrix (N, D)
76
- labels: Class labels (N,)
77
- feature_names: Optional feature names
78
- validation_split: Validation split ratio
79
-
80
- Returns:
81
- Training metrics
82
- """
83
- print(f"Training LightGBM classifier")
84
- print(f"Features shape: {features.shape}")
85
- print(f"Labels distribution: {np.bincount(labels)}")
86
-
87
- # Handle NaN/Inf
88
- features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
89
-
90
- # Normalize features
91
- features_scaled = self.scaler.fit_transform(features)
92
-
93
- # Split data (Critical Fix #7: Image-level splitting should be done upstream)
94
- X_train, X_val, y_train, y_val = train_test_split(
95
- features_scaled, labels,
96
- test_size=validation_split,
97
- random_state=42,
98
- stratify=labels
99
- )
100
-
101
- # Create LightGBM datasets
102
- train_data = lgb.Dataset(X_train, label=y_train)
103
- val_data = lgb.Dataset(X_val, label=y_val, reference=train_data)
104
-
105
- # Train model
106
- self.model = lgb.train(
107
- self.params,
108
- train_data,
109
- valid_sets=[train_data, val_data],
110
- valid_names=['train', 'val'],
111
- num_boost_round=self.params.get('n_estimators', 200),
112
- callbacks=[
113
- lgb.early_stopping(stopping_rounds=20),
114
- lgb.log_evaluation(period=10)
115
- ]
116
- )
117
-
118
- # Store feature importance
119
- self.feature_names = feature_names
120
- self.feature_importance = self.model.feature_importance(importance_type='gain')
121
-
122
- # Evaluate
123
- train_pred = self.model.predict(X_train)
124
- train_acc = (train_pred.argmax(axis=1) == y_train).mean()
125
-
126
- val_pred = self.model.predict(X_val)
127
- val_acc = (val_pred.argmax(axis=1) == y_val).mean()
128
-
129
- metrics = {
130
- 'train_accuracy': train_acc,
131
- 'val_accuracy': val_acc,
132
- 'num_features': features.shape[1],
133
- 'num_samples': len(labels),
134
- 'best_iteration': self.model.best_iteration
135
- }
136
-
137
- print(f"Training complete!")
138
- print(f"Train accuracy: {train_acc:.4f}")
139
- print(f"Val accuracy: {val_acc:.4f}")
140
-
141
- return metrics
142
-
143
- def predict(self, features: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
144
- """
145
- Predict forgery types
146
-
147
- Args:
148
- features: Feature matrix (N, D)
149
-
150
- Returns:
151
- predictions: Predicted class indices (N,)
152
- confidences: Prediction confidences (N,)
153
- """
154
- if self.model is None:
155
- raise ValueError("Model not trained. Call train() first.")
156
-
157
- # Handle NaN/Inf
158
- features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
159
-
160
- # Normalize features
161
- features_scaled = self.scaler.transform(features)
162
-
163
- # Predict probabilities
164
- probabilities = self.model.predict(features_scaled)
165
-
166
- # Get predictions and confidences
167
- predictions = probabilities.argmax(axis=1)
168
- confidences = probabilities.max(axis=1)
169
-
170
- return predictions, confidences
171
-
172
- def predict_with_filtering(self,
173
- features: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
174
- """
175
- Predict with confidence filtering
176
-
177
- Args:
178
- features: Feature matrix (N, D)
179
-
180
- Returns:
181
- predictions: Predicted class indices (N,)
182
- confidences: Prediction confidences (N,)
183
- valid_mask: Boolean mask for valid predictions (N,)
184
- """
185
- predictions, confidences = self.predict(features)
186
-
187
- # Critical Fix #8: Apply confidence threshold
188
- valid_mask = confidences >= self.confidence_threshold
189
-
190
- return predictions, confidences, valid_mask
191
-
192
- def get_class_name(self, class_idx: int) -> str:
193
- """Get class name from index"""
194
- return self.CLASS_NAMES[class_idx]
195
-
196
- def get_feature_importance(self, top_k: int = 20) -> List[Tuple[str, float]]:
197
- """
198
- Get top-k most important features
199
-
200
- Args:
201
- top_k: Number of features to return
202
-
203
- Returns:
204
- List of (feature_name, importance) tuples
205
- """
206
- if self.feature_importance is None:
207
- return []
208
-
209
- # Sort by importance
210
- indices = np.argsort(self.feature_importance)[::-1][:top_k]
211
-
212
- result = []
213
- for idx in indices:
214
- name = self.feature_names[idx] if self.feature_names else f'feature_{idx}'
215
- importance = self.feature_importance[idx]
216
- result.append((name, importance))
217
-
218
- return result
219
-
220
- def save(self, save_dir: str):
221
- """
222
- Save model and scaler
223
-
224
- Args:
225
- save_dir: Directory to save model
226
- """
227
- save_path = Path(save_dir)
228
- save_path.mkdir(parents=True, exist_ok=True)
229
-
230
- # Save LightGBM model
231
- model_path = save_path / 'lightgbm_model.txt'
232
- self.model.save_model(str(model_path))
233
-
234
- # Save scaler
235
- scaler_path = save_path / 'scaler.joblib'
236
- joblib.dump(self.scaler, str(scaler_path))
237
-
238
- # Save metadata
239
- metadata = {
240
- 'confidence_threshold': self.confidence_threshold,
241
- 'class_names': self.CLASS_NAMES,
242
- 'feature_names': self.feature_names,
243
- 'feature_importance': self.feature_importance.tolist() if self.feature_importance is not None else None
244
- }
245
- metadata_path = save_path / 'classifier_metadata.json'
246
- with open(metadata_path, 'w') as f:
247
- json.dump(metadata, f, indent=2)
248
-
249
- print(f"Classifier saved to {save_path}")
250
-
251
- def load(self, load_dir: str):
252
- """
253
- Load model and scaler
254
-
255
- Args:
256
- load_dir: Directory to load from
257
- """
258
- load_path = Path(load_dir)
259
-
260
- # Load LightGBM model
261
- model_path = load_path / 'lightgbm_model.txt'
262
- self.model = lgb.Booster(model_file=str(model_path))
263
-
264
- # Load scaler
265
- scaler_path = load_path / 'scaler.joblib'
266
- self.scaler = joblib.load(str(scaler_path))
267
-
268
- # Load metadata
269
- metadata_path = load_path / 'classifier_metadata.json'
270
- with open(metadata_path, 'r') as f:
271
- metadata = json.load(f)
272
-
273
- self.confidence_threshold = metadata.get('confidence_threshold', 0.6)
274
- self.feature_names = metadata.get('feature_names')
275
- self.feature_importance = np.array(metadata.get('feature_importance', []))
276
-
277
- print(f"Classifier loaded from {load_path}")
278
-
279
-
280
- def get_classifier(config) -> ForgeryClassifier:
281
- """Factory function for classifier"""
282
- return ForgeryClassifier(config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deployment/src/training/metrics.py DELETED
@@ -1,305 +0,0 @@
1
- """
2
- Training utilities and metrics
3
- Implements Critical Fix #9: Dataset-Aware Metric Computation
4
- """
5
-
6
- import torch
7
- import numpy as np
8
- from typing import Dict, List, Optional
9
- from sklearn.metrics import (
10
- accuracy_score, f1_score, precision_score, recall_score,
11
- confusion_matrix
12
- )
13
-
14
-
15
- class SegmentationMetrics:
16
- """
17
- Segmentation metrics (IoU, Dice)
18
- Only computed for datasets with pixel masks (Critical Fix #9)
19
- """
20
-
21
- def __init__(self):
22
- """Initialize metrics"""
23
- self.reset()
24
-
25
- def reset(self):
26
- """Reset all metrics"""
27
- self.intersection = 0
28
- self.union = 0
29
- self.pred_sum = 0
30
- self.target_sum = 0
31
- self.total_samples = 0
32
-
33
- def update(self,
34
- pred: torch.Tensor,
35
- target: torch.Tensor,
36
- has_pixel_mask: bool = True):
37
- """
38
- Update metrics with batch
39
-
40
- Args:
41
- pred: Predicted probabilities (B, 1, H, W)
42
- target: Ground truth masks (B, 1, H, W)
43
- has_pixel_mask: Whether to compute metrics (Critical Fix #9)
44
- """
45
- if not has_pixel_mask:
46
- return
47
-
48
- # Binarize predictions
49
- pred_binary = (pred > 0.5).float()
50
-
51
- # Compute intersection and union
52
- intersection = (pred_binary * target).sum().item()
53
- union = pred_binary.sum().item() + target.sum().item() - intersection
54
-
55
- self.intersection += intersection
56
- self.union += union
57
- self.pred_sum += pred_binary.sum().item()
58
- self.target_sum += target.sum().item()
59
- self.total_samples += pred.shape[0]
60
-
61
- def compute(self) -> Dict[str, float]:
62
- """
63
- Compute final metrics
64
-
65
- Returns:
66
- Dictionary with IoU, Dice, Precision, Recall
67
- """
68
- # IoU (Jaccard)
69
- iou = self.intersection / (self.union + 1e-8)
70
-
71
- # Dice (F1)
72
- dice = (2 * self.intersection) / (self.pred_sum + self.target_sum + 1e-8)
73
-
74
- # Precision
75
- precision = self.intersection / (self.pred_sum + 1e-8)
76
-
77
- # Recall
78
- recall = self.intersection / (self.target_sum + 1e-8)
79
-
80
- return {
81
- 'iou': iou,
82
- 'dice': dice,
83
- 'precision': precision,
84
- 'recall': recall
85
- }
86
-
87
-
88
- class ClassificationMetrics:
89
- """Classification metrics for forgery type classification"""
90
-
91
- def __init__(self, num_classes: int = 3):
92
- """
93
- Initialize metrics
94
-
95
- Args:
96
- num_classes: Number of forgery types
97
- """
98
- self.num_classes = num_classes
99
- self.reset()
100
-
101
- def reset(self):
102
- """Reset all metrics"""
103
- self.predictions = []
104
- self.targets = []
105
- self.confidences = []
106
-
107
- def update(self,
108
- pred: np.ndarray,
109
- target: np.ndarray,
110
- confidence: Optional[np.ndarray] = None):
111
- """
112
- Update metrics with predictions
113
-
114
- Args:
115
- pred: Predicted class indices
116
- target: Ground truth class indices
117
- confidence: Optional prediction confidences
118
- """
119
- self.predictions.extend(pred.tolist())
120
- self.targets.extend(target.tolist())
121
- if confidence is not None:
122
- self.confidences.extend(confidence.tolist())
123
-
124
- def compute(self) -> Dict[str, float]:
125
- """
126
- Compute final metrics
127
-
128
- Returns:
129
- Dictionary with Accuracy, F1, Precision, Recall
130
- """
131
- if len(self.predictions) == 0:
132
- return {
133
- 'accuracy': 0.0,
134
- 'f1_macro': 0.0,
135
- 'f1_weighted': 0.0,
136
- 'precision': 0.0,
137
- 'recall': 0.0
138
- }
139
-
140
- preds = np.array(self.predictions)
141
- targets = np.array(self.targets)
142
-
143
- # Accuracy
144
- accuracy = accuracy_score(targets, preds)
145
-
146
- # F1 score (macro and weighted)
147
- f1_macro = f1_score(targets, preds, average='macro', zero_division=0)
148
- f1_weighted = f1_score(targets, preds, average='weighted', zero_division=0)
149
-
150
- # Precision and Recall
151
- precision = precision_score(targets, preds, average='macro', zero_division=0)
152
- recall = recall_score(targets, preds, average='macro', zero_division=0)
153
-
154
- # Confusion matrix
155
- cm = confusion_matrix(targets, preds, labels=range(self.num_classes))
156
-
157
- return {
158
- 'accuracy': accuracy,
159
- 'f1_macro': f1_macro,
160
- 'f1_weighted': f1_weighted,
161
- 'precision': precision,
162
- 'recall': recall,
163
- 'confusion_matrix': cm.tolist()
164
- }
165
-
166
-
167
- class MetricsTracker:
168
- """Track all metrics during training"""
169
-
170
- def __init__(self, config):
171
- """
172
- Initialize metrics tracker
173
-
174
- Args:
175
- config: Configuration object
176
- """
177
- self.config = config
178
- self.num_classes = config.get('data.num_classes', 3)
179
-
180
- self.seg_metrics = SegmentationMetrics()
181
- self.cls_metrics = ClassificationMetrics(self.num_classes)
182
-
183
- self.history = {
184
- 'train_loss': [],
185
- 'val_loss': [],
186
- 'train_iou': [],
187
- 'val_iou': [],
188
- 'train_dice': [],
189
- 'val_dice': [],
190
- 'train_precision': [],
191
- 'val_precision': [],
192
- 'train_recall': [],
193
- 'val_recall': []
194
- }
195
-
196
- def reset(self):
197
- """Reset metrics for new epoch"""
198
- self.seg_metrics.reset()
199
- self.cls_metrics.reset()
200
-
201
- def update_segmentation(self,
202
- pred: torch.Tensor,
203
- target: torch.Tensor,
204
- dataset_name: str):
205
- """Update segmentation metrics (dataset-aware)"""
206
- has_pixel_mask = self.config.should_compute_localization_metrics(dataset_name)
207
- self.seg_metrics.update(pred, target, has_pixel_mask)
208
-
209
- def update_classification(self,
210
- pred: np.ndarray,
211
- target: np.ndarray,
212
- confidence: Optional[np.ndarray] = None):
213
- """Update classification metrics"""
214
- self.cls_metrics.update(pred, target, confidence)
215
-
216
- def compute_all(self) -> Dict[str, float]:
217
- """Compute all metrics"""
218
- seg = self.seg_metrics.compute()
219
-
220
- # Only include classification metrics if they have data
221
- if len(self.cls_metrics.predictions) > 0:
222
- cls = self.cls_metrics.compute()
223
- # Prefix classification metrics to avoid collision
224
- cls_prefixed = {f'cls_{k}': v for k, v in cls.items()}
225
- return {**seg, **cls_prefixed}
226
-
227
- return seg
228
-
229
- def log_epoch(self, epoch: int, phase: str, loss: float, metrics: Dict):
230
- """Log metrics for epoch"""
231
- prefix = f'{phase}_'
232
-
233
- self.history[f'{phase}_loss'].append(loss)
234
-
235
- if 'iou' in metrics:
236
- self.history[f'{phase}_iou'].append(metrics['iou'])
237
- if 'dice' in metrics:
238
- self.history[f'{phase}_dice'].append(metrics['dice'])
239
- if 'precision' in metrics:
240
- self.history[f'{phase}_precision'].append(metrics['precision'])
241
- if 'recall' in metrics:
242
- self.history[f'{phase}_recall'].append(metrics['recall'])
243
-
244
- def get_history(self) -> Dict:
245
- """Get full training history"""
246
- return self.history
247
-
248
-
249
- class EarlyStopping:
250
- """Early stopping to prevent overfitting"""
251
-
252
- def __init__(self,
253
- patience: int = 10,
254
- min_delta: float = 0.001,
255
- mode: str = 'max'):
256
- """
257
- Initialize early stopping
258
-
259
- Args:
260
- patience: Number of epochs to wait
261
- min_delta: Minimum improvement required
262
- mode: 'min' for loss, 'max' for metrics
263
- """
264
- self.patience = patience
265
- self.min_delta = min_delta
266
- self.mode = mode
267
-
268
- self.counter = 0
269
- self.best_value = None
270
- self.should_stop = False
271
-
272
- def __call__(self, value: float) -> bool:
273
- """
274
- Check if training should stop
275
-
276
- Args:
277
- value: Current metric value
278
-
279
- Returns:
280
- True if should stop
281
- """
282
- if self.best_value is None:
283
- self.best_value = value
284
- return False
285
-
286
- if self.mode == 'max':
287
- improved = value > self.best_value + self.min_delta
288
- else:
289
- improved = value < self.best_value - self.min_delta
290
-
291
- if improved:
292
- self.best_value = value
293
- self.counter = 0
294
- else:
295
- self.counter += 1
296
-
297
- if self.counter >= self.patience:
298
- self.should_stop = True
299
-
300
- return self.should_stop
301
-
302
-
303
- def get_metrics_tracker(config) -> MetricsTracker:
304
- """Factory function for metrics tracker"""
305
- return MetricsTracker(config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deployment/src/training/trainer.py DELETED
@@ -1,450 +0,0 @@
1
- """
2
- Training loop for forgery localization network
3
- Implements chunked training for RAM constraints
4
- """
5
-
6
- import os
7
- import torch
8
- import torch.nn as nn
9
- import torch.optim as optim
10
- from torch.utils.data import DataLoader
11
- from torch.cuda.amp import autocast, GradScaler
12
- from typing import Dict, Optional, Tuple
13
- from pathlib import Path
14
- from tqdm import tqdm
15
- import json
16
- import csv
17
-
18
- from ..models import get_model, get_loss_function
19
- from ..data import get_dataset
20
- from .metrics import MetricsTracker, EarlyStopping
21
-
22
-
23
- class Trainer:
24
- """
25
- Trainer for forgery localization network
26
- Supports chunked training for large datasets (DocTamper)
27
- """
28
-
29
- def __init__(self, config, dataset_name: str = 'doctamper'):
30
- """
31
- Initialize trainer
32
-
33
- Args:
34
- config: Configuration object
35
- dataset_name: Dataset to train on
36
- """
37
- self.config = config
38
- self.dataset_name = dataset_name
39
-
40
- # Device setup
41
- self.device = torch.device(
42
- 'cuda' if torch.cuda.is_available() and config.get('system.device') == 'cuda'
43
- else 'cpu'
44
- )
45
- print(f"Training on: {self.device}")
46
-
47
- # Initialize model
48
- self.model = get_model(config).to(self.device)
49
-
50
- # Loss function (dataset-aware)
51
- self.criterion = get_loss_function(config)
52
-
53
- # Optimizer
54
- lr = config.get('training.learning_rate', 0.001)
55
- weight_decay = config.get('training.weight_decay', 0.0001)
56
- self.optimizer = optim.AdamW(
57
- self.model.parameters(),
58
- lr=lr,
59
- weight_decay=weight_decay
60
- )
61
-
62
- # Learning rate scheduler
63
- epochs = config.get('training.epochs', 50)
64
- warmup_epochs = config.get('training.scheduler.warmup_epochs', 5)
65
- min_lr = config.get('training.scheduler.min_lr', 1e-5)
66
-
67
- self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
68
- self.optimizer,
69
- T_0=epochs - warmup_epochs,
70
- T_mult=1,
71
- eta_min=min_lr
72
- )
73
-
74
- # Mixed precision training
75
- self.scaler = GradScaler()
76
-
77
- # Metrics
78
- self.metrics_tracker = MetricsTracker(config)
79
-
80
- # Early stopping
81
- patience = config.get('training.early_stopping.patience', 10)
82
- min_delta = config.get('training.early_stopping.min_delta', 0.001)
83
- self.early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
84
-
85
- # Output directories
86
- self.checkpoint_dir = Path(config.get('outputs.checkpoints', 'outputs/checkpoints'))
87
- self.log_dir = Path(config.get('outputs.logs', 'outputs/logs'))
88
- self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
89
- self.log_dir.mkdir(parents=True, exist_ok=True)
90
-
91
- # Training state
92
- self.current_epoch = 0
93
- self.best_metric = 0.0
94
-
95
- def create_dataloaders(self,
96
- chunk_start: float = 0.0,
97
- chunk_end: float = 1.0) -> Tuple[DataLoader, DataLoader]:
98
- """
99
- Create train and validation dataloaders
100
-
101
- Args:
102
- chunk_start: Start ratio for chunked training
103
- chunk_end: End ratio for chunked training
104
-
105
- Returns:
106
- Train and validation dataloaders
107
- """
108
- batch_size = self.config.get('data.batch_size', 8)
109
- num_workers = self.config.get('system.num_workers', 4)
110
-
111
- # Training dataset (with chunking for DocTamper)
112
- if self.dataset_name == 'doctamper':
113
- train_dataset = get_dataset(
114
- self.config,
115
- self.dataset_name,
116
- split='train',
117
- chunk_start=chunk_start,
118
- chunk_end=chunk_end
119
- )
120
- else:
121
- train_dataset = get_dataset(
122
- self.config,
123
- self.dataset_name,
124
- split='train'
125
- )
126
-
127
- # Validation dataset (always full)
128
- # For FCD and SCD, validate on DocTamper TestingSet
129
- if self.dataset_name in ['fcd', 'scd']:
130
- val_dataset = get_dataset(
131
- self.config,
132
- 'doctamper', # Use DocTamper for validation
133
- split='val'
134
- )
135
- else:
136
- val_dataset = get_dataset(
137
- self.config,
138
- self.dataset_name,
139
- split='val' if self.dataset_name in ['doctamper', 'receipts'] else 'test'
140
- )
141
-
142
- train_loader = DataLoader(
143
- train_dataset,
144
- batch_size=batch_size,
145
- shuffle=True,
146
- num_workers=num_workers,
147
- pin_memory=self.config.get('system.pin_memory', True),
148
- drop_last=True
149
- )
150
-
151
- val_loader = DataLoader(
152
- val_dataset,
153
- batch_size=batch_size,
154
- shuffle=False,
155
- num_workers=num_workers,
156
- pin_memory=True
157
- )
158
-
159
- return train_loader, val_loader
160
-
161
- def train_epoch(self, dataloader: DataLoader) -> Tuple[float, Dict]:
162
- """
163
- Train for one epoch
164
-
165
- Args:
166
- dataloader: Training dataloader
167
-
168
- Returns:
169
- Average loss and metrics
170
- """
171
- self.model.train()
172
- self.metrics_tracker.reset()
173
-
174
- total_loss = 0.0
175
- num_batches = 0
176
-
177
- pbar = tqdm(dataloader, desc=f"Epoch {self.current_epoch} [Train]")
178
-
179
- for batch_idx, (images, masks, metadata) in enumerate(pbar):
180
- images = images.to(self.device)
181
- masks = masks.to(self.device)
182
-
183
- # Forward pass with mixed precision
184
- self.optimizer.zero_grad()
185
-
186
- with autocast():
187
- outputs, _ = self.model(images)
188
-
189
- # Dataset-aware loss
190
- has_pixel_mask = self.config.has_pixel_mask(self.dataset_name)
191
- losses = self.criterion.combined_loss(outputs, masks, has_pixel_mask)
192
-
193
- # Backward pass with gradient scaling
194
- self.scaler.scale(losses['total']).backward()
195
- self.scaler.step(self.optimizer)
196
- self.scaler.update()
197
-
198
- # Update metrics
199
- with torch.no_grad():
200
- probs = torch.sigmoid(outputs)
201
- self.metrics_tracker.update_segmentation(
202
- probs, masks, self.dataset_name
203
- )
204
-
205
- total_loss += losses['total'].item()
206
- num_batches += 1
207
-
208
- # Update progress bar
209
- pbar.set_postfix({
210
- 'loss': f"{losses['total'].item():.4f}",
211
- 'bce': f"{losses['bce'].item():.4f}"
212
- })
213
-
214
- avg_loss = total_loss / num_batches
215
- metrics = self.metrics_tracker.compute_all()
216
-
217
- return avg_loss, metrics
218
-
219
- def validate(self, dataloader: DataLoader) -> Tuple[float, Dict]:
220
- """
221
- Validate model
222
-
223
- Args:
224
- dataloader: Validation dataloader
225
-
226
- Returns:
227
- Average loss and metrics
228
- """
229
- self.model.eval()
230
- self.metrics_tracker.reset()
231
-
232
- total_loss = 0.0
233
- num_batches = 0
234
-
235
- pbar = tqdm(dataloader, desc=f"Epoch {self.current_epoch} [Val]")
236
-
237
- with torch.no_grad():
238
- for images, masks, metadata in pbar:
239
- images = images.to(self.device)
240
- masks = masks.to(self.device)
241
-
242
- # Forward pass
243
- outputs, _ = self.model(images)
244
-
245
- # Dataset-aware loss
246
- has_pixel_mask = self.config.has_pixel_mask(self.dataset_name)
247
- losses = self.criterion.combined_loss(outputs, masks, has_pixel_mask)
248
-
249
- # Update metrics
250
- probs = torch.sigmoid(outputs)
251
- self.metrics_tracker.update_segmentation(
252
- probs, masks, self.dataset_name
253
- )
254
-
255
- total_loss += losses['total'].item()
256
- num_batches += 1
257
-
258
- pbar.set_postfix({
259
- 'loss': f"{losses['total'].item():.4f}"
260
- })
261
-
262
- avg_loss = total_loss / num_batches
263
- metrics = self.metrics_tracker.compute_all()
264
-
265
- return avg_loss, metrics
266
-
267
- def save_checkpoint(self,
268
- filename: str,
269
- is_best: bool = False,
270
- chunk_id: Optional[int] = None):
271
- """Save model checkpoint"""
272
- checkpoint = {
273
- 'epoch': self.current_epoch,
274
- 'model_state_dict': self.model.state_dict(),
275
- 'optimizer_state_dict': self.optimizer.state_dict(),
276
- 'scheduler_state_dict': self.scheduler.state_dict(),
277
- 'best_metric': self.best_metric,
278
- 'dataset': self.dataset_name,
279
- 'chunk_id': chunk_id
280
- }
281
-
282
- path = self.checkpoint_dir / filename
283
- torch.save(checkpoint, path)
284
- print(f"Saved checkpoint: {path}")
285
-
286
- if is_best:
287
- best_path = self.checkpoint_dir / f'best_{self.dataset_name}.pth'
288
- torch.save(checkpoint, best_path)
289
- print(f"Saved best model: {best_path}")
290
-
291
- def load_checkpoint(self, filename: str, reset_epoch: bool = False):
292
- """
293
- Load model checkpoint
294
-
295
- Args:
296
- filename: Checkpoint filename
297
- reset_epoch: If True, reset epoch counter to 0 (useful for chunked training)
298
- """
299
- path = self.checkpoint_dir / filename
300
-
301
- if not path.exists():
302
- print(f"Checkpoint not found: {path}")
303
- return False
304
-
305
- checkpoint = torch.load(path, map_location=self.device)
306
-
307
- self.model.load_state_dict(checkpoint['model_state_dict'])
308
- self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
309
- self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
310
-
311
- if reset_epoch:
312
- self.current_epoch = 0
313
- print(f"Loaded checkpoint: {path} (epoch counter reset to 0)")
314
- else:
315
- self.current_epoch = checkpoint['epoch'] + 1 # Continue from next epoch
316
- print(f"Loaded checkpoint: {path} (resuming from epoch {self.current_epoch})")
317
-
318
- self.best_metric = checkpoint.get('best_metric', 0.0)
319
-
320
- return True
321
-
322
- def train(self,
323
- epochs: Optional[int] = None,
324
- chunk_start: float = 0.0,
325
- chunk_end: float = 1.0,
326
- chunk_id: Optional[int] = None,
327
- resume_from: Optional[str] = None):
328
- """
329
- Main training loop
330
-
331
- Args:
332
- epochs: Number of epochs (None uses config)
333
- chunk_start: Start ratio for chunked training
334
- chunk_end: End ratio for chunked training
335
- chunk_id: Chunk identifier for logging
336
- resume_from: Checkpoint to resume from
337
- """
338
- if epochs is None:
339
- epochs = self.config.get('training.epochs', 50)
340
-
341
- # Resume if specified
342
- if resume_from:
343
- self.load_checkpoint(resume_from)
344
-
345
- # Create dataloaders
346
- train_loader, val_loader = self.create_dataloaders(chunk_start, chunk_end)
347
-
348
- print(f"\n{'='*60}")
349
- print(f"Training: {self.dataset_name}")
350
- if chunk_id is not None:
351
- print(f"Chunk: {chunk_id} [{chunk_start*100:.0f}% - {chunk_end*100:.0f}%]")
352
- print(f"Epochs: {epochs}")
353
- print(f"Train samples: {len(train_loader.dataset)}")
354
- print(f"Val samples: {len(val_loader.dataset)}")
355
- print(f"{'='*60}\n")
356
-
357
- # Training log file
358
- log_file = self.log_dir / f'{self.dataset_name}_chunk{chunk_id or 0}_log.csv'
359
- with open(log_file, 'w', newline='') as f:
360
- writer = csv.writer(f)
361
- writer.writerow(['epoch', 'train_loss', 'val_loss',
362
- 'train_iou', 'val_iou', 'train_dice', 'val_dice',
363
- 'train_precision', 'val_precision',
364
- 'train_recall', 'val_recall', 'lr'])
365
-
366
- for epoch in range(self.current_epoch, epochs):
367
- self.current_epoch = epoch
368
-
369
- # Train
370
- train_loss, train_metrics = self.train_epoch(train_loader)
371
-
372
- # Validate
373
- val_loss, val_metrics = self.validate(val_loader)
374
-
375
- # Update scheduler
376
- self.scheduler.step()
377
- current_lr = self.optimizer.param_groups[0]['lr']
378
-
379
- # Log metrics
380
- self.metrics_tracker.log_epoch(epoch, 'train', train_loss, train_metrics)
381
- self.metrics_tracker.log_epoch(epoch, 'val', val_loss, val_metrics)
382
-
383
- # Log to file
384
- with open(log_file, 'a', newline='') as f:
385
- writer = csv.writer(f)
386
- writer.writerow([
387
- epoch,
388
- f"{train_loss:.4f}",
389
- f"{val_loss:.4f}",
390
- f"{train_metrics.get('iou', 0):.4f}",
391
- f"{val_metrics.get('iou', 0):.4f}",
392
- f"{train_metrics.get('dice', 0):.4f}",
393
- f"{val_metrics.get('dice', 0):.4f}",
394
- f"{train_metrics.get('precision', 0):.4f}",
395
- f"{val_metrics.get('precision', 0):.4f}",
396
- f"{train_metrics.get('recall', 0):.4f}",
397
- f"{val_metrics.get('recall', 0):.4f}",
398
- f"{current_lr:.6f}"
399
- ])
400
-
401
- # Print summary
402
- print(f"\nEpoch {epoch}/{epochs-1}")
403
- print(f" Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
404
- print(f" Train IoU: {train_metrics.get('iou', 0):.4f} | Val IoU: {val_metrics.get('iou', 0):.4f}")
405
- print(f" Train Dice: {train_metrics.get('dice', 0):.4f} | Val Dice: {val_metrics.get('dice', 0):.4f}")
406
- print(f" LR: {current_lr:.6f}")
407
-
408
- # Save checkpoints
409
- if self.config.get('training.checkpoint.save_every', 5) > 0:
410
- if (epoch + 1) % self.config.get('training.checkpoint.save_every', 5) == 0:
411
- self.save_checkpoint(
412
- f'{self.dataset_name}_chunk{chunk_id or 0}_epoch{epoch}.pth',
413
- chunk_id=chunk_id
414
- )
415
-
416
- # Check for best model
417
- monitor_metric = val_metrics.get('dice', 0)
418
- if monitor_metric > self.best_metric:
419
- self.best_metric = monitor_metric
420
- self.save_checkpoint(
421
- f'{self.dataset_name}_chunk{chunk_id or 0}_best.pth',
422
- is_best=True,
423
- chunk_id=chunk_id
424
- )
425
-
426
- # Early stopping
427
- if self.early_stopping(monitor_metric):
428
- print(f"\nEarly stopping triggered at epoch {epoch}")
429
- break
430
-
431
- # Save final checkpoint
432
- self.save_checkpoint(
433
- f'{self.dataset_name}_chunk{chunk_id or 0}_final.pth',
434
- chunk_id=chunk_id
435
- )
436
-
437
- # Save training history
438
- history_file = self.log_dir / f'{self.dataset_name}_chunk{chunk_id or 0}_history.json'
439
- with open(history_file, 'w') as f:
440
- json.dump(self.metrics_tracker.get_history(), f, indent=2)
441
-
442
- print(f"\nTraining complete!")
443
- print(f"Best Dice: {self.best_metric:.4f}")
444
-
445
- return self.metrics_tracker.get_history()
446
-
447
-
448
- def get_trainer(config, dataset_name: str = 'doctamper') -> Trainer:
449
- """Factory function for trainer"""
450
- return Trainer(config, dataset_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deployment/src/utils/__init__.py DELETED
@@ -1,28 +0,0 @@
1
- """Utilities module"""
2
-
3
- from .plotting import (
4
- plot_training_curves,
5
- plot_confusion_matrix,
6
- plot_feature_importance,
7
- plot_dataset_comparison,
8
- plot_chunked_training_progress,
9
- generate_training_report
10
- )
11
-
12
- from .export import (
13
- export_to_onnx,
14
- export_to_torchscript,
15
- quantize_model
16
- )
17
-
18
- __all__ = [
19
- 'plot_training_curves',
20
- 'plot_confusion_matrix',
21
- 'plot_feature_importance',
22
- 'plot_dataset_comparison',
23
- 'plot_chunked_training_progress',
24
- 'generate_training_report',
25
- 'export_to_onnx',
26
- 'export_to_torchscript',
27
- 'quantize_model'
28
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deployment/src/utils/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (601 Bytes)