JKrishnanandhaa commited on
Commit
51fdac5
·
verified ·
1 Parent(s): 8378f18

Upload 8 files

Browse files
scripts/evaluate.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model Evaluation Script
3
+
4
+ Evaluate trained model on validation/test sets with comprehensive metrics.
5
+
6
+ Usage:
7
+ python scripts/evaluate.py --model outputs/checkpoints/best_doctamper.pth --dataset doctamper
8
+ """
9
+
10
+ import argparse
11
+ import sys
12
+ from pathlib import Path
13
+ import json
14
+ import numpy as np
15
+ from tqdm import tqdm
16
+ import torch
17
+
18
+ # Add src to path
19
+ sys.path.insert(0, str(Path(__file__).parent.parent))
20
+
21
+ from src.config import get_config
22
+ from src.models import get_model
23
+ from src.data import get_dataset
24
+ from src.training.metrics import SegmentationMetrics
25
+ from src.utils import plot_training_curves
26
+
27
+
28
def parse_args():
    """Parse command-line options for the evaluation script.

    Returns:
        argparse.Namespace with model, dataset, split, output and config.
    """
    # Option table: (flag, add_argument keyword arguments).
    option_specs = [
        ('--model', dict(type=str, required=True,
                         help='Path to model checkpoint')),
        ('--dataset', dict(type=str, required=True,
                           choices=['doctamper', 'rtm', 'casia', 'receipts'],
                           help='Dataset to evaluate on')),
        ('--split', dict(type=str, default='val',
                         help='Data split (val/test)')),
        ('--output', dict(type=str, default='outputs/evaluation',
                          help='Output directory')),
        ('--config', dict(type=str, default='config.yaml',
                          help='Path to config file')),
    ]

    parser = argparse.ArgumentParser(description="Evaluate forgery detection model")
    for flag, kwargs in option_specs:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
48
+
49
+
50
def main():
    """Evaluate a trained forgery-detection model on one dataset split.

    Loads the checkpoint given by --model, iterates over the requested
    dataset/split sample by sample, accumulates pixel-level metrics, and
    writes a JSON summary into the --output directory.
    """
    args = parse_args()

    # Load config
    config = get_config(args.config)

    # Prefer GPU when available; evaluation also works on CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("\n" + "="*60)
    print("Model Evaluation")
    print("="*60)
    print(f"Model: {args.model}")
    print(f"Dataset: {args.dataset}")
    print(f"Split: {args.split}")
    print(f"Device: {device}")
    print("="*60)

    # Load model weights; accept either a full training checkpoint
    # ({'model_state_dict': ...}) or a bare state dict.
    model = get_model(config).to(device)
    checkpoint = torch.load(args.model, map_location=device)
    state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
    model.load_state_dict(state_dict)
    model.eval()
    print("Model loaded")

    # Load dataset
    dataset = get_dataset(config, args.dataset, split=args.split)
    print(f"Dataset loaded: {len(dataset)} samples")

    # Create output directory
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Evaluate
    metrics = SegmentationMetrics()
    has_pixel_mask = config.has_pixel_mask(args.dataset)

    print("\nEvaluating...")

    all_ious = []
    all_dices = []

    with torch.no_grad():
        for i in tqdm(range(len(dataset)), desc="Evaluating"):
            try:
                image, mask, metadata = dataset[i]

                # Move to device (add the batch dimension the model expects).
                image = image.unsqueeze(0).to(device)
                mask = mask.unsqueeze(0).to(device)

                # Forward pass
                logits, _ = model(image)
                probs = torch.sigmoid(logits)

                # Pixel-level metrics are only meaningful when the dataset
                # provides real ground-truth masks.
                if has_pixel_mask:
                    metrics.update(probs, mask, has_pixel_mask=True)

                    # Per-sample IoU/Dice at a fixed 0.5 threshold.
                    pred_binary = (probs > 0.5).float()
                    intersection = (pred_binary * mask).sum().item()
                    pred_area = pred_binary.sum().item()
                    gt_area = mask.sum().item()
                    union = pred_area + gt_area - intersection

                    # Epsilon keeps empty-mask samples from dividing by zero.
                    all_ious.append(intersection / (union + 1e-8))
                    all_dices.append((2 * intersection) / (pred_area + gt_area + 1e-8))

            except Exception as e:
                # Best-effort: one corrupt sample should not abort the run.
                print(f"Error on sample {i}: {e}")
                continue

    # Compute final metrics
    results = metrics.compute()

    # BUGFIX: numpy scalars (e.g. np.float32 from the metric accumulators)
    # are not JSON-serializable and do not pass the isinstance(value, float)
    # print filter below; normalize to builtin Python scalars first.
    results = {k: (v.item() if isinstance(v, np.generic) else v)
               for k, v in results.items()}

    # Add per-sample statistics (float() guards against np.float64 leaking
    # into the JSON output).
    if has_pixel_mask and all_ious:
        results['iou_mean'] = float(np.mean(all_ious))
        results['iou_std'] = float(np.std(all_ious))
        results['dice_mean'] = float(np.mean(all_dices))
        results['dice_std'] = float(np.std(all_dices))

    # Print results
    print("\n" + "="*60)
    print("Evaluation Results")
    print("="*60)

    for key, value in results.items():
        if isinstance(value, float):
            print(f" {key}: {value:.4f}")

    # Save results
    results_path = output_dir / f'{args.dataset}_{args.split}_results.json'
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=2)

    print(f"\nResults saved to: {results_path}")
    print("="*60)


if __name__ == '__main__':
    main()
scripts/evaluate_full_testingset.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Comprehensive Evaluation on Entire TestingSet
3
+
4
+ Evaluates the complete pipeline on all 30,000 samples from TestingSet
5
+ Calculates all metrics: Accuracy, Precision, Recall, F1, IoU, Dice, etc.
6
+ No visualizations - metrics only for speed
7
+
8
+ Usage:
9
+ python scripts/evaluate_full_testingset.py
10
+ """
11
+
12
+ import sys
13
+ from pathlib import Path
14
+ import numpy as np
15
+ import torch
16
+ from tqdm import tqdm
17
+ import json
18
+ from datetime import datetime
19
+
20
+ sys.path.insert(0, str(Path(__file__).parent.parent))
21
+
22
+ from src.config import get_config
23
+ from src.models import get_model
24
+ from src.data import get_dataset
25
+ from src.features import get_mask_refiner, get_region_extractor
26
+ from src.training.classifier import ForgeryClassifier
27
+ from src.data.preprocessing import DocumentPreprocessor
28
+ from src.data.augmentation import DatasetAwareAugmentation
29
+
30
# Class mapping: forgery-type id -> human-readable label.
# NOTE(review): not referenced anywhere in this script; kept for parity with
# the inference pipeline, which uses the same mapping.
CLASS_NAMES = {0: 'Copy-Move', 1: 'Splicing', 2: 'Generation'}
32
+
33
+
34
def calculate_metrics(pred_mask, gt_mask):
    """Compute pixel-level segmentation metrics for a prediction/ground-truth pair.

    Args:
        pred_mask: array-like binary prediction (any dtype castable to bool).
        gt_mask: array-like binary ground truth, same shape as pred_mask.

    Returns:
        dict with float metrics (iou, dice, precision, recall, f1, accuracy)
        and integer confusion counts (tp, fp, fn, tn). A small epsilon keeps
        every ratio finite when a denominator would be zero.
    """
    eps = 1e-8
    p = pred_mask.astype(bool)
    g = gt_mask.astype(bool)

    # Confusion-matrix counts over all pixels.
    tp = (p & g).sum()
    fp = (p & ~g).sum()
    fn = (~p & g).sum()
    tn = (~p & ~g).sum()
    union = (p | g).sum()

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)

    return {
        'iou': float(tp / (union + eps)),
        'dice': float((2 * tp) / (p.sum() + g.sum() + eps)),
        'precision': float(precision),
        'recall': float(recall),
        'f1': float(2 * precision * recall / (precision + recall + eps)),
        'accuracy': float((tp + tn) / (tp + tn + fp + fn + eps)),
        'tp': int(tp),
        'fp': int(fp),
        'fn': int(fn),
        'tn': int(tn),
    }
67
+
68
+
69
def main():
    """Metrics-only evaluation of the pipeline on the full DocTamper TestingSet.

    For every sample: run the localization model, threshold and refine the
    predicted mask, then accumulate image-level detection statistics and
    pixel-level segmentation metrics. Writes a JSON summary to
    outputs/evaluation/evaluation_summary.json.
    """
    print("="*80)
    print("COMPREHENSIVE EVALUATION ON ENTIRE TESTINGSET")
    print("="*80)
    print("Dataset: DocTamper TestingSet (30,000 samples)")
    print("Mode: Metrics only (no visualizations)")
    print("="*80)

    config = get_config('config.yaml')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load models
    print("\n1. Loading models...")

    # Localization model (per-pixel forgery probability).
    model = get_model(config).to(device)
    checkpoint = torch.load('outputs/checkpoints/best_doctamper.pth', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print(f" ✓ Localization model loaded (Dice: {checkpoint.get('best_metric', 0):.2%})")

    # Classifier. NOTE(review): the classifier, preprocessor, augmentation and
    # region extractor are never used in this metrics-only loop; they are kept
    # so the script fails fast when any pipeline artifact is missing.
    classifier = ForgeryClassifier(config)
    classifier.load('outputs/classifier')
    print(" ✓ Classifier loaded")

    # Components
    preprocessor = DocumentPreprocessor(config, 'doctamper')
    augmentation = DatasetAwareAugmentation(config, 'doctamper', is_training=False)
    mask_refiner = get_mask_refiner(config)
    region_extractor = get_region_extractor(config)

    # Load dataset
    print("\n2. Loading TestingSet...")
    dataset = get_dataset(config, 'doctamper', split='val')  # val = TestingSet
    total_samples = len(dataset)
    print(f" ✓ Loaded {total_samples} samples")

    # Initialize metrics storage
    all_metrics = []
    detection_stats = {
        'total': 0,
        'has_forgery': 0,
        'detected': 0,
        'missed': 0,
        'false_positives': 0,
        'true_negatives': 0
    }

    print("\n3. Running evaluation...")
    print("="*80)

    # Process all samples
    for idx in tqdm(range(total_samples), desc="Evaluating"):
        try:
            image_tensor, mask_tensor, metadata = dataset[idx]

            # Ground truth: binarize the mask channel at 0.5. bool() keeps
            # JSON-friendly Python types rather than np.bool_.
            gt_mask = mask_tensor.numpy()[0]
            gt_mask_binary = (gt_mask > 0.5).astype(np.uint8)
            has_forgery = bool(gt_mask_binary.sum() > 0)

            # Run localization
            with torch.no_grad():
                image_batch = image_tensor.unsqueeze(0).to(device)
                logits, _ = model(image_batch)
                prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]

            # Generate mask: fixed 0.5 threshold + morphological refinement.
            binary_mask = (prob_map > 0.5).astype(np.uint8)
            refined_mask = mask_refiner.refine(binary_mask)

            # Calculate per-sample pixel metrics
            metrics = calculate_metrics(refined_mask, gt_mask_binary)
            metrics['sample_idx'] = idx
            metrics['has_forgery'] = has_forgery
            metrics['prob_max'] = float(prob_map.max())

            # Image-level detection bookkeeping: "detected" means at least
            # one pixel survived refinement.
            detected = refined_mask.sum() > 0

            detection_stats['total'] += 1
            if has_forgery:
                detection_stats['has_forgery'] += 1
                detection_stats['detected' if detected else 'missed'] += 1
            else:
                detection_stats['false_positives' if detected else 'true_negatives'] += 1

            all_metrics.append(metrics)

        except Exception as e:
            # Best-effort: keep evaluating the remaining samples.
            print(f"\nError at sample {idx}: {str(e)[:100]}")
            continue

    # Calculate overall statistics
    print("\n" + "="*80)
    print("RESULTS")
    print("="*80)

    # Detection statistics
    print("\n📊 DETECTION STATISTICS:")
    print("-"*80)
    print(f"Total samples: {detection_stats['total']}")
    print(f"Samples with forgery: {detection_stats['has_forgery']}")
    print(f"Samples without forgery: {detection_stats['total'] - detection_stats['has_forgery']}")
    print()
    print(f"✅ Correctly detected: {detection_stats['detected']}")
    print(f"❌ Missed detections: {detection_stats['missed']}")
    print(f"⚠️ False positives: {detection_stats['false_positives']}")
    print(f"✓ True negatives: {detection_stats['true_negatives']}")
    print()

    # BUGFIX: compute the guarded ratios once. The original recomputed
    # detected/has_forgery unguarded in the final summary print, which raised
    # ZeroDivisionError whenever the split contained no forged samples.
    n_forged = detection_stats['has_forgery']
    n_flagged = detection_stats['detected'] + detection_stats['false_positives']
    detection_rate = detection_stats['detected'] / n_forged if n_forged > 0 else 0.0
    det_precision = detection_stats['detected'] / n_flagged if n_flagged > 0 else 0.0
    overall_accuracy = ((detection_stats['detected'] + detection_stats['true_negatives'])
                        / detection_stats['total']) if detection_stats['total'] > 0 else 0.0

    if n_forged > 0:
        miss_rate = detection_stats['missed'] / n_forged
        print(f"Detection Rate (Recall): {detection_rate:.2%} ⬆️ Higher is better")
        print(f"Miss Rate: {miss_rate:.2%} ⬇️ Lower is better")

    if n_flagged > 0:
        print(f"Precision: {det_precision:.2%} ⬆️ Higher is better")

    print(f"Overall Accuracy: {overall_accuracy:.2%} ⬆️ Higher is better")

    # Segmentation metrics averaged over forged samples only (on clean
    # samples IoU/Dice are degenerate).
    forgery_metrics = [m for m in all_metrics if m['has_forgery']]

    seg_summary = {'iou': 0.0, 'dice': 0.0, 'precision': 0.0,
                   'recall': 0.0, 'f1': 0.0, 'accuracy': 0.0}
    if forgery_metrics:
        print("\n📈 SEGMENTATION METRICS (on samples with forgery):")
        print("-"*80)

        for key in seg_summary:
            seg_summary[key] = float(np.mean([m[key] for m in forgery_metrics]))

        print(f"IoU (Intersection over Union): {seg_summary['iou']:.4f} ⬆️ Higher is better (0-1)")
        print(f"Dice Coefficient: {seg_summary['dice']:.4f} ⬆️ Higher is better (0-1)")
        print(f"Pixel Precision: {seg_summary['precision']:.4f} ⬆️ Higher is better (0-1)")
        print(f"Pixel Recall: {seg_summary['recall']:.4f} ⬆️ Higher is better (0-1)")
        print(f"Pixel F1-Score: {seg_summary['f1']:.4f} ⬆️ Higher is better (0-1)")
        print(f"Pixel Accuracy: {seg_summary['accuracy']:.4f} ⬆️ Higher is better (0-1)")

        # Probability statistics
        avg_prob = np.mean([m['prob_max'] for m in forgery_metrics])
        print(f"\nAverage Max Probability: {avg_prob:.4f}")

    # Save results
    print("\n" + "="*80)
    print("SAVING RESULTS")
    print("="*80)

    output_dir = Path('outputs/evaluation')
    output_dir.mkdir(parents=True, exist_ok=True)

    summary = {
        'timestamp': datetime.now().isoformat(),
        'total_samples': detection_stats['total'],
        'detection_statistics': detection_stats,
        'detection_rate': detection_rate,
        'precision': det_precision,
        'overall_accuracy': overall_accuracy,
        'segmentation_metrics': seg_summary
    }

    summary_path = output_dir / 'evaluation_summary.json'
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)

    print(f"✓ Summary saved to: {summary_path}")

    # Detailed per-sample metrics are intentionally not dumped (the file
    # would be very large); see all_metrics if that is ever needed.

    print("\n" + "="*80)
    print("✅ EVALUATION COMPLETE!")
    print("="*80)
    print("\nKey Metrics Summary:")
    print(f"  Detection Rate: {detection_rate:.2%}")
    print(f"  Overall Accuracy: {overall_accuracy:.2%}")
    if forgery_metrics:
        print(f"  Dice Score: {seg_summary['dice']:.4f}")
        print(f"  IoU: {seg_summary['iou']:.4f}")
    else:
        print("  Dice Score: N/A")
        print("  IoU: N/A")
    print("="*80 + "\n")


if __name__ == '__main__':
    main()
scripts/export_model.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model Export Script
3
+
4
+ Export trained model to ONNX format for deployment.
5
+
6
+ Usage:
7
+ python scripts/export_model.py --model outputs/checkpoints/best_doctamper.pth --format onnx
8
+ """
9
+
10
+ import argparse
11
+ import sys
12
+ from pathlib import Path
13
+
14
+ # Add src to path
15
+ sys.path.insert(0, str(Path(__file__).parent.parent))
16
+
17
+ import torch
18
+
19
+ from src.config import get_config
20
+ from src.models import get_model
21
+ from src.utils import export_to_onnx, export_to_torchscript
22
+
23
+
24
def parse_args():
    """Parse the command-line options for the export script.

    Returns:
        argparse.Namespace with model, format, output and config.
    """
    cli = argparse.ArgumentParser(description="Export model for deployment")

    cli.add_argument('--model', type=str, required=True,
                     help='Path to model checkpoint')
    cli.add_argument('--format', type=str, default='onnx',
                     choices=['onnx', 'torchscript', 'both'],
                     help='Export format')
    cli.add_argument('--output', type=str, default='outputs/exported',
                     help='Output directory')
    cli.add_argument('--config', type=str, default='config.yaml',
                     help='Path to config file')

    return cli.parse_args()
41
+
42
+
43
def main():
    """Export a trained checkpoint to ONNX and/or TorchScript.

    Loads the checkpoint given by --model, rebuilds the model from the YAML
    config, and writes the exported artifact(s) into the --output directory.
    """
    args = parse_args()

    # Resolve run configuration from the YAML file.
    config = get_config(args.config)

    banner = "=" * 60
    print("\n" + banner)
    print("Model Export")
    print(banner)
    print(f"Model: {args.model}")
    print(f"Format: {args.format}")
    print(banner)

    # Make sure the destination exists before doing any heavy work.
    out_dir = Path(args.output)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Rebuild the network and restore its weights.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = get_model(config).to(device)

    checkpoint = torch.load(args.model, map_location=device)
    # A training checkpoint stores weights under 'model_state_dict';
    # otherwise the file is assumed to be a bare state dict.
    state = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
    model.load_state_dict(state)

    model.eval()
    print("Model loaded")

    # Square input resolution used for the export trace.
    image_size = config.get('data.image_size', 384)

    if args.format in ('onnx', 'both'):
        export_to_onnx(model, str(out_dir / 'model.onnx'),
                       input_size=(image_size, image_size))

    if args.format in ('torchscript', 'both'):
        export_to_torchscript(model, str(out_dir / 'model.pt'),
                              input_size=(image_size, image_size))

    print("\n" + banner)
    print("Export Complete!")
    print(f"Output: {out_dir}")
    print(banner)


if __name__ == '__main__':
    main()
scripts/inference_pipeline.py ADDED
@@ -0,0 +1,556 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Complete Document Forgery Detection Pipeline
3
+ Implements Full Algorithm Steps 1-11
4
+
5
+ Features:
6
+ - ✅ Localization (WHERE is forgery?)
7
+ - ✅ Classification (WHAT type of forgery?)
8
+ - ✅ Confidence filtering
9
+ - ✅ Visualizations (heatmaps, overlays, bounding boxes)
10
+ - ✅ JSON output with detailed results
11
+ - ✅ Actual vs Predicted comparison (if ground truth available)
12
+
13
+ Usage:
14
+ python scripts/inference_pipeline.py --image path/to/document.jpg
15
+ python scripts/inference_pipeline.py --image path/to/document.jpg --ground_truth path/to/mask.png
16
+ """
17
+
18
+ import sys
19
+ from pathlib import Path
20
+ import argparse
21
+ import numpy as np
22
+ import cv2
23
+ import torch
24
+ import json
25
+ from datetime import datetime
26
+ import matplotlib.pyplot as plt
27
+ import matplotlib.patches as patches
28
+
29
+ sys.path.insert(0, str(Path(__file__).parent.parent))
30
+
31
+ from src.config import get_config
32
+ from src.models import get_model
33
+ from src.features import get_feature_extractor, get_mask_refiner, get_region_extractor
34
+ from src.training.classifier import ForgeryClassifier
35
+ from src.data.preprocessing import DocumentPreprocessor
36
+
37
# Class mapping: forgery-type id (classifier output) -> human-readable label.
CLASS_NAMES = {
    0: 'Copy-Move',
    1: 'Splicing',
    2: 'Generation'
}

# Per-class overlay colors as (R, G, B) tuples — applied to the RGB image
# used for the matplotlib/cv2.addWeighted visualizations.
CLASS_COLORS = {
    0: (255, 0, 0),  # Red for Copy-Move
    1: (0, 255, 0),  # Green for Splicing
    2: (0, 0, 255)   # Blue for Generation
}
49
+
50
+
51
+ class ForgeryDetectionPipeline:
52
+ """
53
+ Complete forgery detection pipeline
54
+ Implements Algorithm Steps 1-11
55
+ """
56
+
57
    def __init__(self, config_path='config.yaml'):
        """Load all pipeline models and helper components.

        Args:
            config_path: Path to the YAML configuration file.

        Side effects: reads model weights from
        'outputs/checkpoints/best_doctamper.pth' and the classifier from
        'outputs/classifier'; prints progress to stdout.
        """
        print("="*70)
        print("Initializing Forgery Detection Pipeline")
        print("="*70)

        self.config = get_config(config_path)
        # Prefer GPU when available; all tensors are moved to this device.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load localization model (Steps 1-6): produces the per-pixel
        # forgery-probability map.
        print("\n1. Loading localization model...")
        self.localization_model = get_model(self.config).to(self.device)
        checkpoint = torch.load('outputs/checkpoints/best_doctamper.pth',
                                map_location=self.device)
        self.localization_model.load_state_dict(checkpoint['model_state_dict'])
        self.localization_model.eval()
        # 'best_metric' is read with a 0 default, so checkpoints that lack it
        # still load and simply report 0.00%.
        print(f" ✓ Loaded (Val Dice: {checkpoint.get('best_metric', 0):.2%})")

        # Load classifier (Step 8): assigns a forgery type to each region.
        print("\n2. Loading forgery type classifier...")
        self.classifier = ForgeryClassifier(self.config)
        self.classifier.load('outputs/classifier')
        print(" ✓ Loaded")

        # Initialize components
        print("\n3. Initializing components...")
        self.preprocessor = DocumentPreprocessor(self.config, 'doctamper')

        # Initialize augmentation for inference (is_training=False —
        # presumably this selects the deterministic validation-time
        # transforms so inputs match the dataset path; confirm in
        # DatasetAwareAugmentation).
        from src.data.augmentation import DatasetAwareAugmentation
        self.augmentation = DatasetAwareAugmentation(self.config, 'doctamper', is_training=False)

        self.feature_extractor = get_feature_extractor(self.config, is_text_document=True)
        self.mask_refiner = get_mask_refiner(self.config)
        self.region_extractor = get_region_extractor(self.config)
        print(" ✓ Ready")

        print("\n" + "="*70)
        print("Pipeline Initialized Successfully!")
        print("="*70 + "\n")
97
+
98
    def detect(self, image_path, ground_truth_path=None, output_dir='outputs/inference'):
        """
        Run the complete detection pipeline (Algorithm Steps 1-11) on one image.

        Args:
            image_path: Path to input document image
            ground_truth_path: Optional path to ground truth mask (grayscale)
            output_dir: Directory to save visualizations and JSON output

        Returns:
            results: Dictionary with detection results (built by
            _create_json_output, or _create_clean_result when no region
            survives — both defined elsewhere in this class).

        Raises:
            ValueError: if the image file cannot be read.
        """
        print(f"\n{'='*70}")
        print(f"Processing: {image_path}")
        print(f"{'='*70}\n")

        # Create output directory
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Get base filename (used as the prefix of every output file).
        base_name = Path(image_path).stem

        # Step 1-2: Load and preprocess image (EXACTLY like dataset)
        print("Step 1-2: Loading and preprocessing...")
        image = cv2.imread(str(image_path))
        if image is None:
            raise ValueError(f"Could not load image: {image_path}")

        # OpenCV reads BGR; the rest of the pipeline works in RGB.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Create dummy mask for preprocessing (the preprocessor expects an
        # (image, mask) pair; at inference there is no real mask).
        dummy_mask = np.zeros(image_rgb.shape[:2], dtype=np.uint8)

        # Step 1: Preprocess (like dataset line: image, mask = self.preprocessor(image, mask))
        preprocessed_img, preprocessed_mask = self.preprocessor(image_rgb, dummy_mask)

        # Step 2: Augment (like dataset line: augmented = self.augmentation(image, mask))
        augmented = self.augmentation(preprocessed_img, preprocessed_mask)

        # Step 3: Extract tensor (like dataset line: image = augmented['image'])
        image_tensor = augmented['image']

        print(f" ✓ Image shape: {image_rgb.shape}")
        print(f" ✓ Preprocessed tensor shape: {image_tensor.shape}")
        print(f" ✓ Tensor range: [{image_tensor.min():.4f}, {image_tensor.max():.4f}]")

        # Load ground truth if provided (optional, only for comparison plots).
        ground_truth = None
        if ground_truth_path:
            ground_truth = cv2.imread(str(ground_truth_path), cv2.IMREAD_GRAYSCALE)
            if ground_truth is not None:
                # Resize to match preprocessed size; image_tensor is assumed
                # CHW here, so (shape[2], shape[1]) is the (W, H) cv2 expects.
                target_size = (image_tensor.shape[2], image_tensor.shape[1])  # (W, H)
                ground_truth = cv2.resize(ground_truth, target_size)
                print(f" ✓ Ground truth loaded")

        # Step 3-4: Localization (WHERE is forgery?)
        print("\nStep 3-4: Forgery localization...")
        image_batch = image_tensor.unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits, decoder_features = self.localization_model(image_batch)
            prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]

        print(f" ✓ Probability map generated")
        print(f" ✓ Prob map range: [{prob_map.min():.4f}, {prob_map.max():.4f}]")

        # Step 5: Binary mask generation (fixed 0.5 threshold, then
        # refinement by the configured mask refiner).
        print("\nStep 5: Generating binary mask...")
        binary_mask = (prob_map > 0.5).astype(np.uint8)
        refined_mask = self.mask_refiner.refine(binary_mask)
        print(f" ✓ Mask refined")

        # Step 6: Region extraction
        print("\nStep 6: Extracting forgery regions...")
        # Convert tensor to numpy for region extraction and feature extraction
        # (CHW -> HWC).
        preprocessed_numpy = image_tensor.permute(1, 2, 0).cpu().numpy()
        regions = self.region_extractor.extract(refined_mask, prob_map, preprocessed_numpy)
        print(f" ✓ Found {len(regions)} regions")

        # Early exit: nothing survived refinement — report a clean document.
        if len(regions) == 0:
            print("\n⚠ No forgery regions detected!")
            # Still create visualizations if ground truth exists
            if ground_truth is not None:
                print("\nCreating comparison with ground truth...")
                self._create_comparison_visualization(
                    image_rgb, prob_map, refined_mask, ground_truth,
                    base_name, output_path
                )
            return self._create_clean_result(image_rgb, base_name, output_path, ground_truth)

        # Step 7-8: Feature extraction and classification
        print("\nStep 7-8: Classifying forgery types...")
        region_results = []

        for i, region in enumerate(regions):
            # Extract features (Step 7). decoder_features are moved to CPU
            # because the extractor works on host memory.
            features = self.feature_extractor.extract(
                preprocessed_numpy,
                region['region_mask'],
                [f.cpu() for f in decoder_features]
            )

            # Ensure correct dimension (526) — pad or truncate so the vector
            # always matches the classifier's expected input width.
            expected_dim = 526
            if len(features) < expected_dim:
                features = np.pad(features, (0, expected_dim - len(features)))
            elif len(features) > expected_dim:
                features = features[:expected_dim]

            # Classifier expects a (1, n_features) batch.
            features = features.reshape(1, -1)

            # Classify (Step 8)
            predictions, confidences = self.classifier.predict(features)
            forgery_type = int(predictions[0])
            confidence = float(confidences[0])

            region_results.append({
                'region_id': i + 1,
                'bounding_box': region['bounding_box'],
                'area': int(region['area']),
                'forgery_type': CLASS_NAMES[forgery_type],
                'forgery_type_id': forgery_type,
                'confidence': confidence,
                # Mean model probability over the region's own pixels.
                'mask_probability_mean': float(prob_map[region['region_mask'] > 0].mean())
            })

            print(f" Region {i+1}: {CLASS_NAMES[forgery_type]} "
                  f"(confidence: {confidence:.2%})")

        # Step 9: False positive removal — drop regions whose classifier
        # confidence falls below the configured threshold (default 0.6).
        print("\nStep 9: Filtering low-confidence regions...")
        confidence_threshold = self.config.get('classification.confidence_threshold', 0.6)
        filtered_results = [r for r in region_results if r['confidence'] >= confidence_threshold]
        print(f" ✓ Kept {len(filtered_results)}/{len(region_results)} regions "
              f"(threshold: {confidence_threshold:.0%})")

        # Step 10-11: Generate outputs
        print("\nStep 10-11: Generating outputs...")

        # Calculate scale factors for coordinate conversion.
        # Bounding boxes are in preprocessed coordinates (the prob_map grid);
        # they must be scaled back to original image coordinates.
        orig_h, orig_w = image_rgb.shape[:2]
        prep_h, prep_w = prob_map.shape
        scale_x = orig_w / prep_w
        scale_y = orig_h / prep_h

        # Create visualizations
        self._create_visualizations(
            image_rgb, prob_map, refined_mask, filtered_results,
            ground_truth, base_name, output_path, scale_x, scale_y
        )

        # Create JSON output
        results = self._create_json_output(
            image_path, filtered_results, ground_truth, base_name, output_path
        )

        print(f"\n{'='*70}")
        print("✅ Detection Complete!")
        print(f"{'='*70}")
        print(f"Output directory: {output_path}")
        print(f"Detected {len(filtered_results)} forgery regions")
        print(f"{'='*70}\n")

        return results
266
+
267
+ def _create_visualizations(self, image, prob_map, mask, results,
268
+ ground_truth, base_name, output_path, scale_x, scale_y):
269
+ """Create all visualizations"""
270
+
271
+ # 1. Probability heatmap
272
+ plt.figure(figsize=(15, 5))
273
+
274
+ plt.subplot(1, 3, 1)
275
+ plt.imshow(image)
276
+ plt.title('Original Document')
277
+ plt.axis('off')
278
+
279
+ plt.subplot(1, 3, 2)
280
+ plt.imshow(prob_map, cmap='hot', vmin=0, vmax=1)
281
+ plt.colorbar(label='Forgery Probability')
282
+ plt.title('Probability Heatmap')
283
+ plt.axis('off')
284
+
285
+ plt.subplot(1, 3, 3)
286
+ plt.imshow(mask, cmap='gray')
287
+ plt.title('Binary Mask')
288
+ plt.axis('off')
289
+
290
+ plt.tight_layout()
291
+ plt.savefig(output_path / f'{base_name}_heatmap.png', dpi=150, bbox_inches='tight')
292
+ plt.close()
293
+ print(f" ✓ Saved heatmap")
294
+
295
+ # 2. Overlay with bounding boxes and labels
296
+ overlay = image.copy()
297
+ alpha = 0.4
298
+
299
+ # Create colored mask overlay (scale mask to original size)
300
+ mask_scaled = cv2.resize(mask, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
301
+ colored_mask = np.zeros_like(image)
302
+
303
+ for result in results:
304
+ bbox = result['bounding_box']
305
+ forgery_type = result['forgery_type_id']
306
+ color = CLASS_COLORS[forgery_type]
307
+
308
+ # Scale bounding box to original image coordinates
309
+ x, y, w, h = bbox
310
+ x_scaled = int(x * scale_x)
311
+ y_scaled = int(y * scale_y)
312
+ w_scaled = int(w * scale_x)
313
+ h_scaled = int(h * scale_y)
314
+
315
+ # Color the region
316
+ colored_mask[y_scaled:y_scaled+h_scaled, x_scaled:x_scaled+w_scaled] = color
317
+
318
+ # Blend with original
319
+ overlay = cv2.addWeighted(overlay, 1-alpha, colored_mask, alpha, 0)
320
+
321
+ # Draw bounding boxes and labels
322
+ fig, ax = plt.subplots(1, figsize=(12, 8))
323
+ ax.imshow(overlay)
324
+
325
+ for result in results:
326
+ bbox = result['bounding_box']
327
+ x, y, w, h = bbox # bbox is [x, y, w, h] in preprocessed coordinates
328
+
329
+ # Scale to original image coordinates
330
+ x_scaled = x * scale_x
331
+ y_scaled = y * scale_y
332
+ w_scaled = w * scale_x
333
+ h_scaled = h * scale_y
334
+
335
+ forgery_type = result['forgery_type']
336
+ confidence = result['confidence']
337
+ color_rgb = tuple(c/255 for c in CLASS_COLORS[result['forgery_type_id']])
338
+
339
+ # Draw rectangle
340
+ rect = patches.Rectangle((x_scaled, y_scaled), w_scaled, h_scaled,
341
+ linewidth=2, edgecolor=color_rgb,
342
+ facecolor='none')
343
+ ax.add_patch(rect)
344
+
345
+ # Add label
346
+ label = f"{forgery_type}\n{confidence:.1%}"
347
+ ax.text(x_scaled, y_scaled-10, label, color='white', fontsize=10,
348
+ bbox=dict(boxstyle='round', facecolor=color_rgb, alpha=0.8))
349
+
350
+ ax.axis('off')
351
+ ax.set_title('Forgery Detection Results', fontsize=14, fontweight='bold')
352
+ plt.tight_layout()
353
+ plt.savefig(output_path / f'{base_name}_overlay.png', dpi=150, bbox_inches='tight')
354
+ plt.close()
355
+ print(f" ✓ Saved overlay")
356
+
357
+ # 3. Comparison with ground truth (if available)
358
+ if ground_truth is not None:
359
+ fig, axes = plt.subplots(1, 3, figsize=(18, 6))
360
+
361
+ axes[0].imshow(image)
362
+ axes[0].set_title('Original Document', fontsize=12)
363
+ axes[0].axis('off')
364
+
365
+ axes[1].imshow(ground_truth, cmap='gray')
366
+ axes[1].set_title('Ground Truth', fontsize=12)
367
+ axes[1].axis('off')
368
+
369
+ axes[2].imshow(mask, cmap='gray')
370
+ axes[2].set_title('Predicted Mask', fontsize=12)
371
+ axes[2].axis('off')
372
+
373
+ # Calculate metrics
374
+ intersection = np.logical_and(ground_truth > 127, mask > 0).sum()
375
+ union = np.logical_or(ground_truth > 127, mask > 0).sum()
376
+ iou = intersection / (union + 1e-8)
377
+ dice = 2 * intersection / (ground_truth.sum() + mask.sum() + 1e-8)
378
+
379
+ fig.suptitle(f'Actual vs Predicted (IoU: {iou:.2%}, Dice: {dice:.2%})',
380
+ fontsize=14, fontweight='bold')
381
+
382
+ plt.tight_layout()
383
+ plt.savefig(output_path / f'{base_name}_comparison.png', dpi=150, bbox_inches='tight')
384
+ plt.close()
385
+ print(f" ✓ Saved comparison (IoU: {iou:.2%}, Dice: {dice:.2%})")
386
+
387
+ # 4. Per-region visualization
388
+ if len(results) > 0:
389
+ n_regions = len(results)
390
+ cols = min(4, n_regions)
391
+ rows = (n_regions + cols - 1) // cols
392
+
393
+ fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows))
394
+ if n_regions == 1:
395
+ axes = [axes]
396
+ else:
397
+ axes = axes.flatten()
398
+
399
+ for i, result in enumerate(results):
400
+ bbox = result['bounding_box']
401
+ x, y, w, h = bbox # bbox is [x, y, w, h] in preprocessed coordinates
402
+
403
+ # Scale to original image coordinates
404
+ x_scaled = int(x * scale_x)
405
+ y_scaled = int(y * scale_y)
406
+ w_scaled = int(w * scale_x)
407
+ h_scaled = int(h * scale_y)
408
+
409
+ region_img = image[y_scaled:y_scaled+h_scaled, x_scaled:x_scaled+w_scaled]
410
+
411
+ axes[i].imshow(region_img)
412
+ axes[i].set_title(f"Region {i+1}: {result['forgery_type']}\n"
413
+ f"Confidence: {result['confidence']:.1%}",
414
+ fontsize=10)
415
+ axes[i].axis('off')
416
+
417
+ # Hide unused subplots
418
+ for i in range(n_regions, len(axes)):
419
+ axes[i].axis('off')
420
+
421
+ plt.tight_layout()
422
+ plt.savefig(output_path / f'{base_name}_regions.png', dpi=150, bbox_inches='tight')
423
+ plt.close()
424
+ print(f" ✓ Saved region details")
425
+
426
+ def _create_json_output(self, image_path, results, ground_truth, base_name, output_path):
427
+ """Create JSON output with results"""
428
+
429
+ output = {
430
+ 'image_path': str(image_path),
431
+ 'timestamp': datetime.now().isoformat(),
432
+ 'num_regions_detected': len(results),
433
+ 'regions': results
434
+ }
435
+
436
+ # Add ground truth comparison if available
437
+ if ground_truth is not None:
438
+ output['has_ground_truth'] = True
439
+
440
+ # Save JSON
441
+ json_path = output_path / f'{base_name}_results.json'
442
+ with open(json_path, 'w') as f:
443
+ json.dump(output, f, indent=2)
444
+
445
+ print(f" ✓ Saved JSON results")
446
+
447
+ return output
448
+
449
    def _create_comparison_visualization(self, image, prob_map, mask, ground_truth,
                                         base_name, output_path):
        """Create comparison visualization between actual and predicted.

        Renders a 2x2 figure (original / ground truth / predicted mask /
        probability heatmap), computes pixel-level IoU, Dice, precision and
        recall against the ground truth, and saves the figure as
        '<base_name>_comparison.png' under output_path.

        Args:
            image: document image for display (presumably RGB HxWx3 — TODO confirm).
            prob_map: per-pixel forgery probabilities, displayed on a [0, 1] scale.
            mask: predicted mask; pixels > 0 are counted as forged.
            ground_truth: reference mask; pixels > 127 are counted as forged
                (i.e. assumed to be a 0-255 grayscale mask).
            base_name: filename stem for the saved PNG.
            output_path: Path-like directory the PNG is written into.
        """

        fig, axes = plt.subplots(2, 2, figsize=(16, 12))

        # Original image
        axes[0, 0].imshow(image)
        axes[0, 0].set_title('Original Document', fontsize=14, fontweight='bold')
        axes[0, 0].axis('off')

        # Ground truth
        axes[0, 1].imshow(ground_truth, cmap='gray')
        axes[0, 1].set_title('Ground Truth (Actual)', fontsize=14, fontweight='bold')
        axes[0, 1].axis('off')

        # Predicted mask
        axes[1, 0].imshow(mask, cmap='gray')
        axes[1, 0].set_title('Predicted Mask', fontsize=14, fontweight='bold')
        axes[1, 0].axis('off')

        # Probability heatmap (fixed [0, 1] color scale so runs are comparable)
        im = axes[1, 1].imshow(prob_map, cmap='hot', vmin=0, vmax=1)
        axes[1, 1].set_title('Probability Heatmap', fontsize=14, fontweight='bold')
        axes[1, 1].axis('off')
        plt.colorbar(im, ax=axes[1, 1], fraction=0.046, pad=0.04)

        # Calculate metrics: ground truth binarized at >127, prediction at >0.
        intersection = np.logical_and(ground_truth > 127, mask > 0).sum()
        union = np.logical_or(ground_truth > 127, mask > 0).sum()
        gt_sum = (ground_truth > 127).sum()
        pred_sum = (mask > 0).sum()

        # The 1e-8 epsilon guards against division by zero on empty masks.
        iou = intersection / (union + 1e-8)
        dice = 2 * intersection / (gt_sum + pred_sum + 1e-8)
        precision = intersection / (pred_sum + 1e-8) if pred_sum > 0 else 0
        recall = intersection / (gt_sum + 1e-8) if gt_sum > 0 else 0

        fig.suptitle(f'Actual vs Predicted Comparison\n'
                     f'IoU: {iou:.2%} | Dice: {dice:.2%} | '
                     f'Precision: {precision:.2%} | Recall: {recall:.2%}',
                     fontsize=16, fontweight='bold')

        plt.tight_layout()
        plt.savefig(output_path / f'{base_name}_comparison.png', dpi=150, bbox_inches='tight')
        plt.close()
        print(f" ✓ Saved comparison (IoU: {iou:.2%}, Dice: {dice:.2%})")
496
+
497
    def _create_clean_result(self, image, base_name, output_path, ground_truth=None):
        """Create result for clean (no forgery) document.

        Saves '<base_name>_clean.png' showing the untouched document and a
        minimal '<base_name>_results.json' (zero regions, status='clean');
        returns the JSON payload.

        Args:
            image: document image to display (presumably RGB — TODO confirm).
            base_name: filename stem for the saved artifacts.
            output_path: Path-like output directory.
            ground_truth: NOTE(review): accepted for signature symmetry with
                the other result builders but currently unused here.
        """

        # Save original image with a green "clean" banner
        plt.figure(figsize=(10, 8))
        plt.imshow(image)
        plt.title('No Forgery Detected', fontsize=14, fontweight='bold', color='green')
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(output_path / f'{base_name}_clean.png', dpi=150, bbox_inches='tight')
        plt.close()

        # Create JSON (no 'image_path' key here, unlike _create_json_output)
        output = {
            'timestamp': datetime.now().isoformat(),
            'num_regions_detected': 0,
            'regions': [],
            'status': 'clean'
        }

        json_path = output_path / f'{base_name}_results.json'
        with open(json_path, 'w') as f:
            json.dump(output, f, indent=2)

        return output
522
+
523
+
524
def main():
    """Command-line entry point: run detection on one image and print a summary."""
    parser = argparse.ArgumentParser(description='Document Forgery Detection Pipeline')
    parser.add_argument('--image', type=str, required=True,
                        help='Path to input document image')
    parser.add_argument('--ground_truth', type=str, default=None,
                        help='Path to ground truth mask (optional)')
    parser.add_argument('--output_dir', type=str, default='outputs/inference',
                        help='Output directory for results')
    parser.add_argument('--config', type=str, default='config.yaml',
                        help='Path to config file')
    opts = parser.parse_args()

    # Build the pipeline once, then run a single detection pass.
    detector = ForgeryDetectionPipeline(opts.config)
    results = detector.detect(
        opts.image,
        ground_truth_path=opts.ground_truth,
        output_dir=opts.output_dir
    )

    # Human-readable summary of what was found.
    print("\nDetection Summary:")
    print(f"  Regions detected: {results['num_regions_detected']}")
    if results['num_regions_detected'] > 0:
        for region in results['regions']:
            print(f"    - {region['forgery_type']}: {region['confidence']:.1%} confidence")


if __name__ == '__main__':
    main()
scripts/run_inference.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference Script for Document Forgery Detection
3
+
4
+ Run inference on single images or entire directories.
5
+
6
+ Usage:
7
+ python scripts/run_inference.py --input path/to/image.jpg --model outputs/checkpoints/best_doctamper.pth
8
+ python scripts/run_inference.py --input path/to/folder/ --model outputs/checkpoints/best_doctamper.pth
9
+ """
10
+
11
+ import argparse
12
+ import sys
13
+ from pathlib import Path
14
+ import json
15
+
16
+ # Add src to path
17
+ sys.path.insert(0, str(Path(__file__).parent.parent))
18
+
19
+ from src.config import get_config
20
+ from src.inference import get_pipeline
21
+
22
+
23
def parse_args():
    """Build the inference CLI parser and parse sys.argv."""
    ap = argparse.ArgumentParser(description="Run forgery detection inference")

    ap.add_argument('--input', type=str, required=True,
                    help='Input image or directory path')
    ap.add_argument('--model', type=str, required=True,
                    help='Path to localization model checkpoint')
    ap.add_argument('--classifier', type=str, default=None,
                    help='Path to classifier directory (optional)')
    ap.add_argument('--output', type=str, default='outputs/results',
                    help='Output directory')
    ap.add_argument('--is_text', action='store_true',
                    help='Enable OCR features for text documents')
    ap.add_argument('--config', type=str, default='config.yaml',
                    help='Path to config file')

    return ap.parse_args()
45
+
46
+
47
def process_file(pipeline, input_path: str, output_dir: str):
    """Run the pipeline on one file; return its result dict, or None on failure.

    Any exception from the pipeline is reported and swallowed so a batch run
    can continue with the remaining files.
    """
    try:
        return pipeline.run(input_path, output_dir)
    except Exception as e:
        print(f"Error processing {input_path}: {e}")
        return None
55
+
56
+
57
def main():
    """CLI entry point: run inference on a file or directory and write a summary.

    Loads the config, builds the detection pipeline, processes each input
    image, prints a per-file verdict, and saves an aggregate JSON summary
    ('inference_summary.json') into the output directory.
    """
    args = parse_args()

    # Load config
    config = get_config(args.config)

    print("\n" + "="*60)
    print("Hybrid Document Forgery Detection - Inference")
    print("="*60)
    print(f"Input: {args.input}")
    print(f"Model: {args.model}")
    print(f"Classifier: {args.classifier or 'None'}")
    print(f"Output: {args.output}")
    print("="*60)

    # Create pipeline
    pipeline = get_pipeline(
        config,
        model_path=args.model,
        classifier_path=args.classifier,
        is_text_document=args.is_text
    )

    # Create output directory
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Collect input files
    input_path = Path(args.input)

    if input_path.is_file():
        files = [input_path]
    elif input_path.is_dir():
        # FIX: '.tif' added alongside '.tiff' (both are common TIFF suffixes),
        # and the listing is sorted — Path.iterdir() order is filesystem-
        # dependent, so without sorting the processing order (and the order of
        # entries in the summary JSON) was nondeterministic.
        extensions = ['.jpg', '.jpeg', '.png', '.pdf', '.bmp', '.tif', '.tiff']
        files = sorted(f for f in input_path.iterdir()
                       if f.suffix.lower() in extensions)
    else:
        print(f"Invalid input path: {input_path}")
        return

    print(f"\nProcessing {len(files)} file(s)...")

    # Process files
    all_results = []

    for file_path in files:
        result = process_file(pipeline, str(file_path), str(output_dir))
        if result:
            all_results.append(result)

            # Print summary
            status = "TAMPERED" if result['is_tampered'] else "AUTHENTIC"
            print(f"\n {file_path.name}: {status}")
            if result['is_tampered']:
                print(f" Regions detected: {result['num_regions']}")
                for region in result['regions'][:3]:  # Show first 3
                    print(f" - {region['forgery_type']} (conf: {region['confidence']:.2f})")

    # Save summary
    summary_path = output_dir / 'inference_summary.json'
    summary = {
        'total_files': len(files),
        'processed': len(all_results),
        'tampered': sum(1 for r in all_results if r['is_tampered']),
        'authentic': sum(1 for r in all_results if not r['is_tampered']),
        'results': all_results
    }

    # default=str stringifies anything json can't encode (e.g. Path objects).
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2, default=str)

    print("\n" + "="*60)
    print("Inference Complete!")
    print(f"Total: {summary['total_files']}, "
          f"Tampered: {summary['tampered']}, "
          f"Authentic: {summary['authentic']}")
    print(f"Results saved to: {output_dir}")
    print("="*60)


if __name__ == '__main__':
    main()
+ main()
scripts/setup.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Setup Script
3
+
4
+ Creates output directories and verifies installation.
5
+
6
+ Usage:
7
+ python scripts/setup.py
8
+ """
9
+
10
+ import sys
11
+ from pathlib import Path
12
+
13
+ # Add src to path
14
+ sys.path.insert(0, str(Path(__file__).parent.parent))
15
+
16
+
17
def create_directories():
    """Create the 'outputs' directory tree used by training and inference."""
    base_dir = Path(__file__).parent.parent
    outputs = base_dir / 'outputs'

    # None stands for the root 'outputs' directory itself; order matches the
    # original explicit list.
    for sub in (None, 'checkpoints', 'logs', 'plots', 'results',
                'classifier', 'exported'):
        directory = outputs if sub is None else outputs / sub
        directory.mkdir(parents=True, exist_ok=True)
        print(f"Created: {directory}")
35
+
36
+
37
def verify_installation():
    """Import-check every required package; return display names of missing ones."""

    required_packages = [
        ('torch', 'PyTorch'),
        ('torchvision', 'TorchVision'),
        ('timm', 'TIMM'),
        ('lightgbm', 'LightGBM'),
        ('sklearn', 'Scikit-learn'),
        ('cv2', 'OpenCV'),
        ('PIL', 'Pillow'),
        ('numpy', 'NumPy'),
        ('pandas', 'Pandas'),
        ('matplotlib', 'Matplotlib'),
        ('seaborn', 'Seaborn'),
        ('albumentations', 'Albumentations'),
        ('tqdm', 'TQDM'),
        ('yaml', 'PyYAML'),
        ('pywt', 'PyWavelets'),
    ]

    print("\nVerifying installation...")
    print("-" * 40)

    missing = []

    for module_name, display_name in required_packages:
        try:
            __import__(module_name)
        except ImportError:
            print(f" ✗ {display_name} (MISSING)")
            missing.append(display_name)
        else:
            print(f" ✓ {display_name}")

    # GPU availability is informational only; failures here don't count as
    # missing packages.
    print("-" * 40)
    try:
        import torch
        if torch.cuda.is_available():
            print(f" ✓ CUDA Available: {torch.cuda.get_device_name(0)}")
        else:
            print(" ⚠ CUDA Not Available (CPU mode)")
    except Exception as e:
        print(f" ✗ CUDA Check Failed: {e}")

    return missing
83
+
84
+
85
def verify_datasets():
    """Print which of the expected dataset directories exist on disk."""
    base_dir = Path(__file__).parent.parent

    dataset_dirs = (
        ('DocTamper', base_dir / 'datasets' / 'DocTamper'),
        ('RTM', base_dir / 'datasets' / 'RealTextManipulation'),
        ('CASIA', base_dir / 'datasets' / 'CASIA 1.0 dataset'),
        ('Receipts', base_dir / 'datasets' / 'findit2'),
    )

    print("\nVerifying datasets...")
    print("-" * 40)

    for name, path in dataset_dirs:
        if path.exists():
            print(f" ✓ {name}: {path}")
        else:
            print(f" ✗ {name}: NOT FOUND ({path})")
+ print(f" ✗ {name}: NOT FOUND ({path})")
105
+
106
+
107
def main():
    """Run the full setup routine: create directories, check packages, check datasets.

    Prints a final summary; if any required package is missing, lists them
    and the pip command to install them.
    """
    print("\n" + "="*60)
    print("Hybrid Document Forgery Detection - Setup")
    print("="*60)

    # Create directories
    print("\nCreating directories...")
    print("-" * 40)
    create_directories()

    # Verify installation
    missing = verify_installation()

    # Verify datasets
    verify_datasets()

    # Summary
    print("\n" + "="*60)
    if missing:
        print("Setup complete with WARNINGS")
        print(f"Missing packages: {', '.join(missing)}")
        print("Run: pip install -r requirements.txt")
    else:
        print("Setup Complete! All checks passed.")
    # FIX: the closing banner was only printed on the success path; print it
    # unconditionally so the warning output is properly terminated too.
    print("="*60)


if __name__ == '__main__':
    main()
scripts/train_chunked.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chunked Training Script for Document Forgery Detection
3
+
4
+ Supports training on large datasets (DocTamper) in chunks to manage RAM constraints.
5
+ Usage:
6
+ python scripts/train_chunked.py --dataset doctamper --chunk 1
7
+ python scripts/train_chunked.py --dataset rtm
8
+ python scripts/train_chunked.py --dataset casia
9
+ python scripts/train_chunked.py --dataset receipts
10
+ """
11
+
12
+ import argparse
13
+ import os
14
+ import sys
15
+ from pathlib import Path
16
+
17
+ # Add src to path
18
+ sys.path.insert(0, str(Path(__file__).parent.parent))
19
+
20
+ import torch
21
+ import gc
22
+
23
+ from src.config import get_config
24
+ from src.training import get_trainer
25
+ from src.utils import plot_training_curves, plot_chunked_training_progress, generate_training_report
26
+
27
+
28
def parse_args():
    """Build the training CLI parser and parse sys.argv."""
    ap = argparse.ArgumentParser(description="Train forgery detection model")

    ap.add_argument('--dataset', type=str, default='doctamper',
                    choices=['doctamper', 'rtm', 'casia', 'receipts', 'fcd', 'scd'],
                    help='Dataset to train on')
    ap.add_argument('--chunk', type=int, default=None,
                    help='Chunk number (1-4) for DocTamper chunked training')
    ap.add_argument('--epochs', type=int, default=None,
                    help='Number of epochs (overrides config)')
    ap.add_argument('--resume', type=str, default=None,
                    help='Checkpoint to resume from')
    ap.add_argument('--config', type=str, default='config.yaml',
                    help='Path to config file')

    return ap.parse_args()
48
+
49
+
50
def train_chunk(config, dataset_name: str, chunk_id: int, epochs: int = None, resume: str = None):
    """Train a single chunk of a chunked dataset (DocTamper).

    Args:
        config: project config object; chunk boundaries are read from
            'data.chunked_training.chunks' (each entry carries 'start',
            'end' and 'name' keys, per the lookups below).
        dataset_name: dataset identifier passed to get_trainer.
        chunk_id: 1-based chunk index.
        epochs: optional epoch-count override (None = use config default).
        resume: optional checkpoint to load before training; when omitted
            and chunk_id > 1, the previous chunk's final checkpoint is
            auto-loaded if present.

    Returns:
        The training history returned by trainer.train().
    """

    # Calculate chunk boundaries
    chunks = config.get('data.chunked_training.chunks', [])

    if chunk_id > len(chunks):
        raise ValueError(f"Invalid chunk ID: {chunk_id}. Max: {len(chunks)}")

    # chunk_id is 1-based; the chunks list is 0-based.
    chunk_config = chunks[chunk_id - 1]
    chunk_start = chunk_config['start']
    chunk_end = chunk_config['end']
    chunk_name = chunk_config['name']

    print(f"\n{'='*60}")
    print(f"Training Chunk {chunk_id}: {chunk_name}")
    # start/end are fractions of the dataset (hence the *100 for display).
    print(f"Range: {chunk_start*100:.0f}% - {chunk_end*100:.0f}%")
    print(f"{'='*60}")

    # Create trainer
    trainer = get_trainer(config, dataset_name)

    # Resume from previous chunk if applicable
    if resume:
        # For chunked training, reset epoch counter to train full epochs on new data
        trainer.load_checkpoint(resume, reset_epoch=True)
    elif chunk_id > 1:
        # Auto-resume from previous chunk
        prev_checkpoint = f'{dataset_name}_chunk{chunk_id-1}_final.pth'
        if (Path(config.get('outputs.checkpoints')) / prev_checkpoint).exists():
            print(f"Auto-resuming from previous chunk: {prev_checkpoint}")
            trainer.load_checkpoint(prev_checkpoint, reset_epoch=True)

    # Train (resume_from stays None: any checkpoint was already loaded above)
    history = trainer.train(
        epochs=epochs,
        chunk_start=chunk_start,
        chunk_end=chunk_end,
        chunk_id=chunk_id,
        resume_from=None  # Already loaded above
    )

    # Plot training curves
    plot_dir = Path(config.get('outputs.plots', 'outputs/plots'))
    plot_dir.mkdir(parents=True, exist_ok=True)

    plot_path = plot_dir / f'{dataset_name}_chunk{chunk_id}_curves.png'
    plot_training_curves(
        history,
        str(plot_path),
        title=f"{dataset_name.upper()} Chunk {chunk_id} Training"
    )

    # Generate report
    report_path = plot_dir / f'{dataset_name}_chunk{chunk_id}_report.txt'
    generate_training_report(history, str(report_path), f"{dataset_name} Chunk {chunk_id}")

    # Clear memory so the next chunk starts with a clean GPU/heap.
    del trainer
    gc.collect()
    torch.cuda.empty_cache()

    return history
113
+
114
+
115
def train_full_dataset(config, dataset_name: str, epochs: int = None, resume: str = None):
    """Train on an entire (non-chunked) dataset.

    Used for the smaller datasets; optionally warm-starts from an existing
    checkpoint with the epoch counter reset. Returns the training history.
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"Training on: {dataset_name.upper()}")
    print(f"{banner}")

    # Build the trainer for this dataset.
    trainer = get_trainer(config, dataset_name)

    # Warm-start: load weights but restart epoch numbering for the new dataset.
    if resume:
        print(f"Loading weights from: {resume}")
        trainer.load_checkpoint(resume, reset_epoch=True)
        print("Epoch counter reset to 0 for new dataset training")

    history = trainer.train(
        epochs=epochs,
        chunk_id=0,
        resume_from=None  # Already loaded above
    )

    # Persist training curves and a text report next to the other plots.
    plot_dir = Path(config.get('outputs.plots', 'outputs/plots'))
    plot_dir.mkdir(parents=True, exist_ok=True)

    plot_training_curves(
        history,
        str(plot_dir / f'{dataset_name}_training_curves.png'),
        title=f"{dataset_name.upper()} Training"
    )

    generate_training_report(
        history,
        str(plot_dir / f'{dataset_name}_report.txt'),
        dataset_name
    )

    return history
154
+
155
+
156
def main():
    """CLI entry point for (optionally chunked) training.

    DocTamper can be trained one chunk at a time (--chunk N) or all four
    chunks sequentially; every other dataset is trained in one pass.
    """
    args = parse_args()

    # Load config
    config = get_config(args.config)

    print("\n" + "="*60)
    print("Hybrid Document Forgery Detection - Training")
    print("="*60)
    print(f"Dataset: {args.dataset}")
    print(f"Device: {config.get('system.device')}")
    print(f"CUDA Available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print("="*60)

    # DocTamper: single requested chunk
    if args.dataset == 'doctamper' and args.chunk is not None:
        train_chunk(
            config,
            args.dataset,
            args.chunk,
            epochs=args.epochs,
            resume=args.resume
        )

    # DocTamper: all chunks sequentially
    elif args.dataset == 'doctamper':
        print("Training DocTamper in 4 chunks...")

        all_histories = []
        for chunk_id in range(1, 5):
            # resume is always None here: train_chunk auto-resumes from the
            # previous chunk's final checkpoint when chunk_id > 1. (The old
            # `None if chunk_id == 1 else None` expression was dead code.)
            history = train_chunk(
                config,
                args.dataset,
                chunk_id,
                epochs=args.epochs,
                resume=None
            )
            all_histories.append(history)

        # Plot combined progress across all chunks
        plot_dir = Path(config.get('outputs.plots', 'outputs/plots'))
        combined_path = plot_dir / 'doctamper_all_chunks_progress.png'
        plot_chunked_training_progress(
            all_histories,
            str(combined_path),
            title="DocTamper Chunked Training Progress"
        )

    # Other datasets: full training
    else:
        train_full_dataset(
            config,
            args.dataset,
            epochs=args.epochs,
            resume=args.resume
        )

    print("\n" + "="*60)
    print("Training Complete!")
    print("="*60)


if __name__ == '__main__':
    main()
+ main()
scripts/train_classifier_doctamper_fixed.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LightGBM Classifier Training - DocTamper with Tampering Labels
3
+ FIXED VERSION with proper checkpointing and feature dimension handling
4
+
5
+ Implements Algorithm Steps 7-8:
6
+ 7. Hybrid Feature Extraction
7
+ 8. Region-wise Forgery Classification
8
+
9
+ Uses:
10
+ - Localization: best_doctamper.pth (Steps 1-6 complete)
11
+ - Training: DocTamper TrainingSet + tampering/DocTamperV1-TrainingSet.pk
12
+ - Testing: DocTamper TestingSet + tampering/DocTamperV1-TestingSet.pk
13
+ - Classes: Copy-Move (CM), Splicing (SP), Generation (GE)
14
+
15
+ Features:
16
+ - ✅ Checkpoint saving every 1000 samples
17
+ - ✅ Resume from checkpoint if interrupted
18
+ - ✅ Fixed feature dimension mismatch
19
+ - ✅ Robust error handling
20
+
21
+ Usage:
22
+ python scripts/train_classifier_doctamper_fixed.py
23
+ """
24
+
25
+ import sys
26
+ from pathlib import Path
27
+ import numpy as np
28
+ import pickle
29
+ import lmdb
30
+ import cv2
31
+ import torch
32
+ from tqdm import tqdm
33
+ import json
34
+
35
+ sys.path.insert(0, str(Path(__file__).parent.parent))
36
+
37
+ from src.config import get_config
38
+ from src.models import get_model
39
+ from src.features import get_feature_extractor
40
+ from src.training.classifier import get_classifier
41
+
42
# Configuration
MODEL_PATH = 'outputs/checkpoints/best_doctamper.pth'  # trained localization checkpoint (Steps 1-6)
OUTPUT_DIR = 'outputs/classifier'  # where feature checkpoints and final feature files are written
MAX_SAMPLES = 999999  # effectively unlimited: use all available samples

# Label mapping (Algorithm Step 8.2) - 3 classes
# Keys match the forgery-type codes stored in the DocTamper tampering .pk files.
LABEL_MAP = {
    'CM': 0,  # Copy-Move
    'SP': 1,  # Splicing
    'GE': 2,  # Generation (AI-generated, separate from Splicing)
}
53
+
54
+
55
def load_tampering_labels(label_file):
    """Read the pickled forgery-type label mapping from the tampering folder.

    NOTE(review): pickle.load can execute arbitrary code from the file —
    only use with trusted, locally generated label files.
    """
    with open(label_file, 'rb') as handle:
        label_map = pickle.load(handle)

    print(f"Loaded {len(label_map)} labels from {label_file}")
    return label_map
62
+
63
+
64
def load_sample_from_lmdb(lmdb_env, index):
    """Load one (image, mask) pair from an open LMDB environment.

    Args:
        lmdb_env: open lmdb environment (read-only transactions are used).
        index: integer sample index; keys are 'image-%09d' / 'label-%09d'.

    Returns:
        (image, mask) where image is an RGB uint8 array and mask a grayscale
        uint8 array, or (None, None) if either key is absent or the encoded
        data cannot be decoded.
    """
    txn = lmdb_env.begin()

    # Get image
    img_key = f'image-{index:09d}'.encode('utf-8')
    img_data = txn.get(img_key)
    if not img_data:
        return None, None

    # Get mask (DocTamper uses 'label-' not 'mask-')
    mask_key = f'label-{index:09d}'.encode('utf-8')
    mask_data = txn.get(mask_key)
    if not mask_data:
        return None, None

    # Decode. FIX: cv2.imdecode returns None on corrupt/unsupported data,
    # which previously crashed cv2.cvtColor — treat it as a missing sample.
    img_array = np.frombuffer(img_data, dtype=np.uint8)
    image = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    if image is None:
        return None, None
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    mask_array = np.frombuffer(mask_data, dtype=np.uint8)
    mask = cv2.imdecode(mask_array, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        return None, None

    return image, mask
89
+
90
+
91
def extract_features(config, model, lmdb_path, tampering_labels,
                     max_samples, device, split_name):
    """
    Extract hybrid features with checkpointing and resume capability.

    Iterates the LMDB samples that have a forgery-type label, runs the
    localization model to obtain decoder features, extracts hybrid features
    over the ground-truth mask, and periodically saves compressed .npz
    checkpoints so an interrupted run can resume.

    Args:
        config: project config passed to get_feature_extractor.
        model: localization model; returns (logits, decoder_features) when
            called on a 1x3xHxW float tensor (per the unpacking below).
        lmdb_path: path to the LMDB directory with 'image-'/'label-' keys.
        tampering_labels: mapping index -> forgery-type code ('CM'/'SP'/'GE').
        max_samples: upper bound on sample indices to process.
        device: torch device for the model forward pass.
        split_name: name used in checkpoint/feature filenames.

    Returns:
        (features, labels) as float32/int32 numpy arrays, or (None, None)
        if no sample produced features.
    """

    print(f"\n{'='*60}")
    print(f"Extracting features from {split_name}")
    print(f"{'='*60}")

    # Setup checkpoint directory
    checkpoint_dir = Path(OUTPUT_DIR)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Numeric sort key for checkpoint files: the trailing '_<count>' suffix.
    # Plain lexicographic sorting is wrong once counts cross a digit boundary
    # (e.g. '_100000' sorts before '_20000').
    def _checkpoint_count(p):
        return int(p.stem.split('_')[-1])

    # Check for existing checkpoint to resume
    checkpoints = list(checkpoint_dir.glob(f'checkpoint_{split_name}_*.npz'))
    if checkpoints:
        latest_checkpoint = max(checkpoints, key=_checkpoint_count)
        print(f"✓ Found checkpoint: {latest_checkpoint.name}")

        data = np.load(latest_checkpoint, allow_pickle=True)
        all_features = data['features'].tolist()
        all_labels = data['labels'].tolist()
        expected_dim = int(data['feature_dim'])
        start_idx = len(all_features)

        print(f"✓ Resuming from sample {start_idx}, feature_dim={expected_dim}")
    else:
        all_features = []
        all_labels = []
        expected_dim = None
        start_idx = 0

    # Open LMDB (read-only; lock disabled for concurrent readers)
    env = lmdb.open(lmdb_path, readonly=True, lock=False)

    # Initialize feature extractor
    feature_extractor = get_feature_extractor(config, is_text_document=True)

    # Process samples
    num_processed = start_idx
    dim_mismatch_count = 0

    for i in tqdm(range(start_idx, min(len(tampering_labels), max_samples)),
                  desc=f"Processing {split_name}", initial=start_idx,
                  total=min(len(tampering_labels), max_samples)):
        try:
            # Skip if no label
            if i not in tampering_labels:
                continue

            # Get forgery type label
            forgery_type = tampering_labels[i]
            if forgery_type not in LABEL_MAP:
                continue

            label = LABEL_MAP[forgery_type]

            # Load image and mask
            image, mask = load_sample_from_lmdb(env, i)
            if image is None or mask is None:
                continue

            # Skip if no forgery
            if mask.max() == 0:
                continue

            # Prepare for model: HWC uint8 -> 1x3xHxW float in [0, 1]
            image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
            image_tensor = image_tensor.unsqueeze(0).to(device)

            # Get deep features from localization model
            with torch.no_grad():
                logits, decoder_features = model(image_tensor)

            # Use ground truth mask (binarized at >127) for feature extraction
            mask_binary = (mask > 127).astype(np.uint8)

            # Extract hybrid features (decoder features moved to CPU first)
            features = feature_extractor.extract(
                image / 255.0,
                mask_binary,
                [f.cpu() for f in decoder_features]
            )

            # Set expected dimension from first valid sample
            if expected_dim is None:
                expected_dim = len(features)
                print(f"\n✓ Feature dimension set to: {expected_dim}")

            # Ensure consistent feature dimension (pad or truncate)
            if len(features) != expected_dim:
                if len(features) < expected_dim:
                    features = np.pad(features, (0, expected_dim - len(features)), mode='constant')
                else:
                    features = features[:expected_dim]
                dim_mismatch_count += 1

            all_features.append(features)
            all_labels.append(label)
            num_processed += 1

            # Save checkpoint every 10,000 samples
            if num_processed % 10000 == 0:
                checkpoint_path = checkpoint_dir / f'checkpoint_{split_name}_{num_processed}.npz'
                features_array = np.array(all_features, dtype=np.float32)
                labels_array = np.array(all_labels, dtype=np.int32)

                np.savez_compressed(checkpoint_path,
                                    features=features_array,
                                    labels=labels_array,
                                    feature_dim=expected_dim)
                print(f"\n✓ Checkpoint: {num_processed} samples (dim={expected_dim}, mismatches={dim_mismatch_count})")

                # Delete old checkpoints to save space (keep only last 2).
                # FIX: sort by the numeric sample count, not lexicographically,
                # so the newest checkpoints are the ones kept.
                old_checkpoints = sorted(checkpoint_dir.glob(f'checkpoint_{split_name}_*.npz'),
                                         key=_checkpoint_count)
                if len(old_checkpoints) > 2:
                    for old_cp in old_checkpoints[:-2]:
                        old_cp.unlink()
                        print(f"  Cleaned up: {old_cp.name}")

        except Exception as e:
            # Best-effort: report and keep going; one bad sample should not
            # abort a multi-hour extraction run.
            print(f"\n⚠ Error at sample {i}: {str(e)[:80]}")
            continue

    env.close()

    print(f"\n✓ Extracted {num_processed} samples")
    if dim_mismatch_count > 0:
        print(f"⚠ Fixed {dim_mismatch_count} dimension mismatches")

    # Save final features
    final_path = checkpoint_dir / f'features_{split_name}_final.npz'
    if len(all_features) > 0:
        features_array = np.array(all_features, dtype=np.float32)
        labels_array = np.array(all_labels, dtype=np.int32)

        np.savez_compressed(final_path,
                            features=features_array,
                            labels=labels_array,
                            feature_dim=expected_dim)
        print(f"✓ Final features saved: {final_path}")
        print(f"  Shape: features={features_array.shape}, labels={labels_array.shape}")

        return features_array, labels_array

    return None, None
238
+
239
+
240
def main():
    """Train the LightGBM forgery-type classifier on DocTamper (algorithm steps 7-8).

    Pipeline:
      1. Load the trained localization model from MODEL_PATH.
      2. Load per-image tampering-type labels from the DocTamper .pk files.
      3. Extract per-image feature vectors from the Training and Testing LMDB sets
         (testing capped at MAX_SAMPLES // 4).
      4. Train the classifier on the pooled features and save model, metrics,
         and the class-index mapping under OUTPUT_DIR.

    Relies on module-level constants MODEL_PATH, MAX_SAMPLES, OUTPUT_DIR and the
    helpers get_config/get_model/load_tampering_labels/extract_features/
    get_classifier/get_feature_extractor defined elsewhere in this file/project.
    """
    config = get_config('config.yaml')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Banner: run configuration and the 3-way label scheme used below.
    print("\n" + "="*60)
    print("LightGBM Classifier Training - DocTamper (FIXED)")
    print("Implements Algorithm Steps 7-8")
    print("="*60)
    print(f"Model: {MODEL_PATH}")
    print(f"Device: {device}")
    print(f"Max samples: {MAX_SAMPLES}")
    print("="*60)
    print("\nForgery Type Classes (Step 8.2):")
    print(" 0: Copy-Move (CM)")
    print(" 1: Splicing (SP)")
    print(" 2: Generation (GE)")
    print("="*60)

    # Load localization model (frozen feature provider for extract_features).
    print("\nLoading localization model...")
    model = get_model(config).to(device)
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # inference only; no gradient updates to the localizer here
    print(f"✓ Model loaded (Val Dice: {checkpoint.get('best_metric', 0):.4f})")

    # Per-image tampering-type labels (pickled dicts shipped with DocTamper).
    # NOTE(review): load_tampering_labels presumably unpickles these paths —
    # pickle on downloaded data is a trust decision; confirm the source is trusted.
    train_labels = load_tampering_labels(
        'datasets/DocTamper/tampering/DocTamperV1-TrainingSet.pk'
    )
    test_labels = load_tampering_labels(
        'datasets/DocTamper/tampering/DocTamperV1-TestingSet.pk'
    )

    # Extract features from TrainingSet (up to MAX_SAMPLES images).
    X_train, y_train = extract_features(
        config, model,
        'datasets/DocTamper/DocTamperV1-TrainingSet',
        train_labels,
        MAX_SAMPLES,
        device,
        'TrainingSet'
    )

    # Extract features from TestingSet (quarter-sized cap).
    X_test, y_test = extract_features(
        config, model,
        'datasets/DocTamper/DocTamperV1-TestingSet',
        test_labels,
        MAX_SAMPLES // 4,
        device,
        'TestingSet'
    )

    # extract_features returns (None, None) when nothing could be extracted.
    if X_train is None or X_test is None:
        print("\n❌ No features extracted!")
        return

    # Summary of the assembled dataset.
    print("\n" + "="*60)
    print("Dataset Summary")
    print("="*60)
    print(f"Training samples: {len(X_train):,}")
    print(f"Testing samples: {len(X_test):,}")
    print(f"Feature dimension: {X_train.shape[1]}")

    print(f"\nTraining class distribution:")
    train_counts = np.bincount(y_train)
    class_names = ['Copy-Move', 'Splicing', 'Generation']
    for i, count in enumerate(train_counts):
        # bincount may return fewer or more bins than 3; only print known classes.
        if i < len(class_names):
            print(f" {class_names[i]}: {count:,} ({count/len(y_train)*100:.1f}%)")

    print(f"\nTesting class distribution:")
    test_counts = np.bincount(y_test)
    for i, count in enumerate(test_counts):
        if i < len(class_names):
            print(f" {class_names[i]}: {count:,} ({count/len(y_test)*100:.1f}%)")

    # Train classifier (Step 8.1).
    print("\n" + "="*60)
    print("Training LightGBM Classifier (Step 8.1)")
    print("="*60)

    output_dir = Path(OUTPUT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)

    classifier = get_classifier(config)
    feature_names = get_feature_extractor(config, is_text_document=True).get_feature_names()

    # Combine train and test for sklearn train_test_split
    # NOTE(review): pooling the official DocTamper train and test splits and
    # re-splitting inside classifier.train() leaks test images into training,
    # so the reported "test" metrics are not comparable to the benchmark's
    # held-out evaluation — confirm this is intentional.
    X_combined = np.vstack([X_train, X_test])
    y_combined = np.concatenate([y_train, y_test])

    metrics = classifier.train(X_combined, y_combined, feature_names=feature_names)

    # Persist the trained classifier.
    classifier.save(str(output_dir))
    print(f"\n✓ Classifier saved to: {output_dir}")

    # Persist training metrics as JSON.
    metrics_path = output_dir / 'training_metrics.json'
    with open(metrics_path, 'w') as f:
        json.dump(metrics, f, indent=2)

    # Persist class-index -> name mapping.
    # NOTE: json.dump coerces these int keys to strings ("0", "1", "2");
    # consumers reloading this file must index with string keys.
    class_mapping = {
        0: 'Copy-Move',
        1: 'Splicing',
        2: 'Generation'
    }
    mapping_path = output_dir / 'class_mapping.json'
    with open(mapping_path, 'w') as f:
        json.dump(class_mapping, f, indent=2)

    print("\n" + "="*60)
    print("✅ Classifier Training Complete!")
    print("Algorithm Steps 7-8: DONE")
    print("="*60)
    print(f"\nResults:")
    # metrics keys depend on the project classifier; fall back to 'N/A' if absent.
    print(f" Test Accuracy: {metrics.get('test_accuracy', 'N/A')}")
    print(f" Test F1 Score: {metrics.get('test_f1', 'N/A')}")
    print(f"\nOutput: {output_dir}")
    print("\nNext: Implement Steps 9-11 in inference pipeline")
    print("="*60 + "\n")
365
+
366
+
367
# Script entry point: run the training pipeline when executed directly
# (no side effects on import).
if __name__ == '__main__':
    main()