Spaces:
Sleeping
Sleeping
Upload 8 files
Browse files- scripts/evaluate.py +158 -0
- scripts/evaluate_full_testingset.py +277 -0
- scripts/export_model.py +92 -0
- scripts/inference_pipeline.py +556 -0
- scripts/run_inference.py +138 -0
- scripts/setup.py +135 -0
- scripts/train_chunked.py +222 -0
- scripts/train_classifier_doctamper_fixed.py +368 -0
scripts/evaluate.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model Evaluation Script
|
| 3 |
+
|
| 4 |
+
Evaluate trained model on validation/test sets with comprehensive metrics.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python scripts/evaluate.py --model outputs/checkpoints/best_doctamper.pth --dataset doctamper
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import sys
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
import json
|
| 14 |
+
import numpy as np
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
# Add src to path
|
| 19 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 20 |
+
|
| 21 |
+
from src.config import get_config
|
| 22 |
+
from src.models import get_model
|
| 23 |
+
from src.data import get_dataset
|
| 24 |
+
from src.training.metrics import SegmentationMetrics
|
| 25 |
+
from src.utils import plot_training_curves
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def parse_args(argv=None):
    """Parse command-line arguments for the evaluation script.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``,
            in which case argparse reads ``sys.argv[1:]`` — existing
            callers are therefore unaffected; passing a list makes the
            parser testable without touching the real command line.

    Returns:
        argparse.Namespace with ``model``, ``dataset``, ``split``,
        ``output`` and ``config`` attributes.
    """
    parser = argparse.ArgumentParser(description="Evaluate forgery detection model")

    parser.add_argument('--model', type=str, required=True,
                        help='Path to model checkpoint')

    parser.add_argument('--dataset', type=str, required=True,
                        choices=['doctamper', 'rtm', 'casia', 'receipts'],
                        help='Dataset to evaluate on')

    parser.add_argument('--split', type=str, default='val',
                        help='Data split (val/test)')

    parser.add_argument('--output', type=str, default='outputs/evaluation',
                        help='Output directory')

    parser.add_argument('--config', type=str, default='config.yaml',
                        help='Path to config file')

    return parser.parse_args(argv)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def main():
    """Evaluate a trained forgery-localization model on one dataset split.

    Loads the checkpoint given on the command line, runs the model over every
    sample of the requested split, accumulates segmentation metrics, prints a
    report and writes the results as JSON under the output directory.
    """
    args = parse_args()

    # Load config
    config = get_config(args.config)

    # Device: prefer GPU when available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("\n" + "="*60)
    print("Model Evaluation")
    print("="*60)
    print(f"Model: {args.model}")
    print(f"Dataset: {args.dataset}")
    print(f"Split: {args.split}")
    print(f"Device: {device}")
    print("="*60)

    # Load model weights; checkpoint may be either a full training
    # checkpoint (dict with 'model_state_dict') or a bare state dict.
    model = get_model(config).to(device)
    checkpoint = torch.load(args.model, map_location=device)

    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)

    model.eval()
    print("Model loaded")

    # Load dataset
    dataset = get_dataset(config, args.dataset, split=args.split)
    print(f"Dataset loaded: {len(dataset)} samples")

    # Create output directory
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Evaluate
    metrics = SegmentationMetrics()
    # Whether this dataset carries per-pixel ground-truth masks
    has_pixel_mask = config.has_pixel_mask(args.dataset)

    print(f"\nEvaluating...")

    all_ious = []
    all_dices = []

    with torch.no_grad():
        for i in tqdm(range(len(dataset)), desc="Evaluating"):
            try:
                # dataset[i] yields (image tensor, mask tensor, metadata dict)
                # — assumed from the unpacking here; confirm against the
                # dataset implementation.
                image, mask, metadata = dataset[i]

                # Move to device (add batch dimension of 1)
                image = image.unsqueeze(0).to(device)
                mask = mask.unsqueeze(0).to(device)

                # Forward pass: model returns (logits, auxiliary output)
                logits, _ = model(image)
                probs = torch.sigmoid(logits)

                # Update metrics — per-sample IoU/Dice are only meaningful
                # when the dataset provides pixel-level ground truth.
                if has_pixel_mask:
                    metrics.update(probs, mask, has_pixel_mask=True)

                    # Per-sample metrics (threshold at 0.5, smoothed with
                    # 1e-8 so empty masks never divide by zero)
                    pred_binary = (probs > 0.5).float()
                    intersection = (pred_binary * mask).sum().item()
                    union = pred_binary.sum().item() + mask.sum().item() - intersection

                    iou = intersection / (union + 1e-8)
                    dice = (2 * intersection) / (pred_binary.sum().item() + mask.sum().item() + 1e-8)

                    all_ious.append(iou)
                    all_dices.append(dice)

            except Exception as e:
                # Best-effort evaluation: report the failing sample and move on
                print(f"Error on sample {i}: {e}")
                continue

    # Compute final metrics
    results = metrics.compute()

    # Add per-sample statistics (mean/std over all successfully
    # evaluated samples)
    if has_pixel_mask and all_ious:
        results['iou_mean'] = np.mean(all_ious)
        results['iou_std'] = np.std(all_ious)
        results['dice_mean'] = np.mean(all_dices)
        results['dice_std'] = np.std(all_dices)

    # Print results
    print("\n" + "="*60)
    print("Evaluation Results")
    print("="*60)

    for key, value in results.items():
        if isinstance(value, float):
            print(f" {key}: {value:.4f}")

    # Save results
    results_path = output_dir / f'{args.dataset}_{args.split}_results.json'
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=2)

    print(f"\nResults saved to: {results_path}")
    print("="*60)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# Script entry point — only run the evaluation when executed directly.
if __name__ == '__main__':
    main()
|
scripts/evaluate_full_testingset.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Comprehensive Evaluation on Entire TestingSet
|
| 3 |
+
|
| 4 |
+
Evaluates the complete pipeline on all 30,000 samples from TestingSet
|
| 5 |
+
Calculates all metrics: Accuracy, Precision, Recall, F1, IoU, Dice, etc.
|
| 6 |
+
No visualizations - metrics only for speed
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
python scripts/evaluate_full_testingset.py
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import sys
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
import json
|
| 18 |
+
from datetime import datetime
|
| 19 |
+
|
| 20 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 21 |
+
|
| 22 |
+
from src.config import get_config
|
| 23 |
+
from src.models import get_model
|
| 24 |
+
from src.data import get_dataset
|
| 25 |
+
from src.features import get_mask_refiner, get_region_extractor
|
| 26 |
+
from src.training.classifier import ForgeryClassifier
|
| 27 |
+
from src.data.preprocessing import DocumentPreprocessor
|
| 28 |
+
from src.data.augmentation import DatasetAwareAugmentation
|
| 29 |
+
|
| 30 |
+
# Class mapping
|
| 31 |
+
# Class mapping: forgery-type class index -> human-readable label
CLASS_NAMES = {0: 'Copy-Move', 1: 'Splicing', 2: 'Generation'}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def calculate_metrics(pred_mask, gt_mask):
    """Compute pixel-level segmentation metrics for a predicted mask.

    Both inputs are cast to boolean arrays; every ratio is smoothed with a
    small epsilon so fully-empty masks never cause a division by zero.

    Args:
        pred_mask: Predicted binary mask (any numeric/bool numpy array).
        gt_mask: Ground-truth binary mask of the same shape.

    Returns:
        dict with float metrics (iou, dice, precision, recall, f1,
        accuracy) and int pixel counts (tp, fp, fn, tn).
    """
    predicted = pred_mask.astype(bool)
    truth = gt_mask.astype(bool)

    eps = 1e-8

    # Confusion-matrix pixel counts
    true_pos = (predicted & truth).sum()
    false_pos = (predicted & ~truth).sum()
    false_neg = (~predicted & truth).sum()
    true_neg = (~predicted & ~truth).sum()

    overlap = (predicted & truth).sum()
    combined = (predicted | truth).sum()

    prec = true_pos / (true_pos + false_pos + eps)
    rec = true_pos / (true_pos + false_neg + eps)

    return {
        'iou': float(overlap / (combined + eps)),
        'dice': float((2 * overlap) / (predicted.sum() + truth.sum() + eps)),
        'precision': float(prec),
        'recall': float(rec),
        'f1': float(2 * prec * rec / (prec + rec + eps)),
        'accuracy': float((true_pos + true_neg) / (true_pos + true_neg + false_pos + false_neg + eps)),
        'tp': int(true_pos),
        'fp': int(false_pos),
        'fn': int(false_neg),
        'tn': int(true_neg),
    }
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def main():
    """Run the full-pipeline evaluation over the entire DocTamper TestingSet.

    Loads the localization model and the forgery-type classifier, iterates
    every sample in the validation split (= TestingSet), accumulates
    pixel-level segmentation metrics and image-level detection statistics,
    prints a report and writes a JSON summary to
    ``outputs/evaluation/evaluation_summary.json``.

    Fix vs. original: the final "Key Metrics Summary" printed the detection
    rate without the zero guard applied everywhere else, raising
    ZeroDivisionError when no sample contains a forgery. It now prints "N/A"
    in that case.
    """
    print("="*80)
    print("COMPREHENSIVE EVALUATION ON ENTIRE TESTINGSET")
    print("="*80)
    print("Dataset: DocTamper TestingSet (30,000 samples)")
    print("Mode: Metrics only (no visualizations)")
    print("="*80)

    config = get_config('config.yaml')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load models
    print("\n1. Loading models...")

    # Localization model (checkpoint stores a training dict with
    # 'model_state_dict' and the best validation metric)
    model = get_model(config).to(device)
    checkpoint = torch.load('outputs/checkpoints/best_doctamper.pth', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print(f" ✓ Localization model loaded (Dice: {checkpoint.get('best_metric', 0):.2%})")

    # Classifier (loaded for pipeline parity; not invoked in the loop below)
    classifier = ForgeryClassifier(config)
    classifier.load('outputs/classifier')
    print(f" ✓ Classifier loaded")

    # Components — only mask_refiner is used in the evaluation loop;
    # the others are instantiated to mirror the inference pipeline setup.
    preprocessor = DocumentPreprocessor(config, 'doctamper')
    augmentation = DatasetAwareAugmentation(config, 'doctamper', is_training=False)
    mask_refiner = get_mask_refiner(config)
    region_extractor = get_region_extractor(config)

    # Load dataset
    print("\n2. Loading TestingSet...")
    dataset = get_dataset(config, 'doctamper', split='val')  # val = TestingSet
    total_samples = len(dataset)
    print(f" ✓ Loaded {total_samples} samples")

    # Initialize metrics storage
    all_metrics = []
    detection_stats = {
        'total': 0,
        'has_forgery': 0,
        'detected': 0,
        'missed': 0,
        'false_positives': 0,
        'true_negatives': 0
    }

    print("\n3. Running evaluation...")
    print("="*80)

    # Process all samples
    for idx in tqdm(range(total_samples), desc="Evaluating"):
        try:
            # Get sample from dataset
            image_tensor, mask_tensor, metadata = dataset[idx]

            # Ground truth: first channel of the mask tensor,
            # binarized at 0.5
            gt_mask = mask_tensor.numpy()[0]
            gt_mask_binary = (gt_mask > 0.5).astype(np.uint8)
            has_forgery = gt_mask_binary.sum() > 0

            # Run localization
            with torch.no_grad():
                image_batch = image_tensor.unsqueeze(0).to(device)
                logits, _ = model(image_batch)
                prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]

            # Generate the binary prediction and post-process it
            binary_mask = (prob_map > 0.5).astype(np.uint8)
            refined_mask = mask_refiner.refine(binary_mask)

            # Pixel-level metrics for this sample
            metrics = calculate_metrics(refined_mask, gt_mask_binary)
            metrics['sample_idx'] = idx
            metrics['has_forgery'] = has_forgery
            metrics['prob_max'] = float(prob_map.max())

            # Image-level detection bookkeeping: "detected" means at least
            # one predicted forged pixel survives refinement.
            detected = refined_mask.sum() > 0

            detection_stats['total'] += 1
            if has_forgery:
                detection_stats['has_forgery'] += 1
                if detected:
                    detection_stats['detected'] += 1
                else:
                    detection_stats['missed'] += 1
            else:
                if detected:
                    detection_stats['false_positives'] += 1
                else:
                    detection_stats['true_negatives'] += 1

            all_metrics.append(metrics)

        except Exception as e:
            # Best-effort: report (truncated) and continue with next sample
            print(f"\nError at sample {idx}: {str(e)[:100]}")
            continue

    # Calculate overall statistics
    print("\n" + "="*80)
    print("RESULTS")
    print("="*80)

    # Detection statistics
    print("\n📊 DETECTION STATISTICS:")
    print("-"*80)
    print(f"Total samples: {detection_stats['total']}")
    print(f"Samples with forgery: {detection_stats['has_forgery']}")
    print(f"Samples without forgery: {detection_stats['total'] - detection_stats['has_forgery']}")
    print()
    print(f"✅ Correctly detected: {detection_stats['detected']}")
    print(f"❌ Missed detections: {detection_stats['missed']}")
    print(f"⚠️ False positives: {detection_stats['false_positives']}")
    print(f"✓ True negatives: {detection_stats['true_negatives']}")
    print()

    # Detection rates (guarded against empty denominators)
    if detection_stats['has_forgery'] > 0:
        detection_rate = detection_stats['detected'] / detection_stats['has_forgery']
        miss_rate = detection_stats['missed'] / detection_stats['has_forgery']
        print(f"Detection Rate (Recall): {detection_rate:.2%} ⬆️ Higher is better")
        print(f"Miss Rate: {miss_rate:.2%} ⬇️ Lower is better")

    if detection_stats['detected'] + detection_stats['false_positives'] > 0:
        precision = detection_stats['detected'] / (detection_stats['detected'] + detection_stats['false_positives'])
        print(f"Precision: {precision:.2%} ⬆️ Higher is better")

    overall_accuracy = (detection_stats['detected'] + detection_stats['true_negatives']) / detection_stats['total']
    print(f"Overall Accuracy: {overall_accuracy:.2%} ⬆️ Higher is better")

    # Segmentation metrics (only for samples with forgery)
    forgery_metrics = [m for m in all_metrics if m['has_forgery']]

    if forgery_metrics:
        print("\n📈 SEGMENTATION METRICS (on samples with forgery):")
        print("-"*80)

        avg_iou = np.mean([m['iou'] for m in forgery_metrics])
        avg_dice = np.mean([m['dice'] for m in forgery_metrics])
        avg_precision = np.mean([m['precision'] for m in forgery_metrics])
        avg_recall = np.mean([m['recall'] for m in forgery_metrics])
        avg_f1 = np.mean([m['f1'] for m in forgery_metrics])
        avg_accuracy = np.mean([m['accuracy'] for m in forgery_metrics])

        print(f"IoU (Intersection over Union): {avg_iou:.4f} ⬆️ Higher is better (0-1)")
        print(f"Dice Coefficient: {avg_dice:.4f} ⬆️ Higher is better (0-1)")
        print(f"Pixel Precision: {avg_precision:.4f} ⬆️ Higher is better (0-1)")
        print(f"Pixel Recall: {avg_recall:.4f} ⬆️ Higher is better (0-1)")
        print(f"Pixel F1-Score: {avg_f1:.4f} ⬆️ Higher is better (0-1)")
        print(f"Pixel Accuracy: {avg_accuracy:.4f} ⬆️ Higher is better (0-1)")

        # Probability statistics
        avg_prob = np.mean([m['prob_max'] for m in forgery_metrics])
        print(f"\nAverage Max Probability: {avg_prob:.4f}")

    # Save results
    print("\n" + "="*80)
    print("SAVING RESULTS")
    print("="*80)

    output_dir = Path('outputs/evaluation')
    output_dir.mkdir(parents=True, exist_ok=True)

    # Summary — numpy scalars converted to plain floats for JSON
    summary = {
        'timestamp': datetime.now().isoformat(),
        'total_samples': detection_stats['total'],
        'detection_statistics': detection_stats,
        'detection_rate': detection_stats['detected'] / detection_stats['has_forgery'] if detection_stats['has_forgery'] > 0 else 0,
        'precision': detection_stats['detected'] / (detection_stats['detected'] + detection_stats['false_positives']) if (detection_stats['detected'] + detection_stats['false_positives']) > 0 else 0,
        'overall_accuracy': overall_accuracy,
        'segmentation_metrics': {
            'iou': float(avg_iou) if forgery_metrics else 0,
            'dice': float(avg_dice) if forgery_metrics else 0,
            'precision': float(avg_precision) if forgery_metrics else 0,
            'recall': float(avg_recall) if forgery_metrics else 0,
            'f1': float(avg_f1) if forgery_metrics else 0,
            'accuracy': float(avg_accuracy) if forgery_metrics else 0
        }
    }

    summary_path = output_dir / 'evaluation_summary.json'
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)

    print(f"✓ Summary saved to: {summary_path}")

    # Detailed metrics (optional - can be large)
    # detailed_path = output_dir / 'detailed_metrics.json'
    # with open(detailed_path, 'w') as f:
    #     json.dump(all_metrics, f, indent=2)
    # print(f"✓ Detailed metrics saved to: {detailed_path}")

    print("\n" + "="*80)
    print("✅ EVALUATION COMPLETE!")
    print("="*80)
    print(f"\nKey Metrics Summary:")
    # FIX: original divided unconditionally here and crashed when no sample
    # had a forgery; mirror the guard used above.
    if detection_stats['has_forgery'] > 0:
        print(f" Detection Rate: {detection_stats['detected'] / detection_stats['has_forgery']:.2%}")
    else:
        print(" Detection Rate: N/A")
    print(f" Overall Accuracy: {overall_accuracy:.2%}")
    print(f" Dice Score: {avg_dice:.4f}" if forgery_metrics else " Dice Score: N/A")
    print(f" IoU: {avg_iou:.4f}" if forgery_metrics else " IoU: N/A")
    print("="*80 + "\n")
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
# Script entry point — only run the evaluation when executed directly.
if __name__ == '__main__':
    main()
|
scripts/export_model.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model Export Script
|
| 3 |
+
|
| 4 |
+
Export trained model to ONNX format for deployment.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python scripts/export_model.py --model outputs/checkpoints/best_doctamper.pth --format onnx
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import sys
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
# Add src to path
|
| 15 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
from src.config import get_config
|
| 20 |
+
from src.models import get_model
|
| 21 |
+
from src.utils import export_to_onnx, export_to_torchscript
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def parse_args(argv=None):
    """Parse command-line arguments for the model export script.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``,
            in which case argparse reads ``sys.argv[1:]`` — existing
            callers are therefore unaffected; passing a list makes the
            parser testable without touching the real command line.

    Returns:
        argparse.Namespace with ``model``, ``format``, ``output`` and
        ``config`` attributes.
    """
    parser = argparse.ArgumentParser(description="Export model for deployment")

    parser.add_argument('--model', type=str, required=True,
                        help='Path to model checkpoint')

    parser.add_argument('--format', type=str, default='onnx',
                        choices=['onnx', 'torchscript', 'both'],
                        help='Export format')

    parser.add_argument('--output', type=str, default='outputs/exported',
                        help='Output directory')

    parser.add_argument('--config', type=str, default='config.yaml',
                        help='Path to config file')

    return parser.parse_args(argv)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def main():
    """Export a trained checkpoint to ONNX and/or TorchScript.

    Loads the model from the checkpoint given on the command line and
    writes the exported artifact(s) into the output directory.
    """
    args = parse_args()

    # Load config
    config = get_config(args.config)

    print("\n" + "="*60)
    print("Model Export")
    print("="*60)
    print(f"Model: {args.model}")
    print(f"Format: {args.format}")
    print("="*60)

    # Create output directory
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load model weights; checkpoint may be either a full training
    # checkpoint (dict with 'model_state_dict') or a bare state dict.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = get_model(config).to(device)

    checkpoint = torch.load(args.model, map_location=device)
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)

    model.eval()
    print("Model loaded")

    # Get image size (square input assumed; falls back to 384)
    image_size = config.get('data.image_size', 384)

    # Export in the requested format(s)
    if args.format in ['onnx', 'both']:
        onnx_path = output_dir / 'model.onnx'
        export_to_onnx(model, str(onnx_path), input_size=(image_size, image_size))

    if args.format in ['torchscript', 'both']:
        ts_path = output_dir / 'model.pt'
        export_to_torchscript(model, str(ts_path), input_size=(image_size, image_size))

    print("\n" + "="*60)
    print("Export Complete!")
    print(f"Output: {output_dir}")
    print("="*60)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# Script entry point — only run the export when executed directly.
if __name__ == '__main__':
    main()
|
scripts/inference_pipeline.py
ADDED
|
@@ -0,0 +1,556 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Complete Document Forgery Detection Pipeline
|
| 3 |
+
Implements Full Algorithm Steps 1-11
|
| 4 |
+
|
| 5 |
+
Features:
|
| 6 |
+
- ✅ Localization (WHERE is forgery?)
|
| 7 |
+
- ✅ Classification (WHAT type of forgery?)
|
| 8 |
+
- ✅ Confidence filtering
|
| 9 |
+
- ✅ Visualizations (heatmaps, overlays, bounding boxes)
|
| 10 |
+
- ✅ JSON output with detailed results
|
| 11 |
+
- ✅ Actual vs Predicted comparison (if ground truth available)
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
python scripts/inference_pipeline.py --image path/to/document.jpg
|
| 15 |
+
python scripts/inference_pipeline.py --image path/to/document.jpg --ground_truth path/to/mask.png
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import sys
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
import argparse
|
| 21 |
+
import numpy as np
|
| 22 |
+
import cv2
|
| 23 |
+
import torch
|
| 24 |
+
import json
|
| 25 |
+
from datetime import datetime
|
| 26 |
+
import matplotlib.pyplot as plt
|
| 27 |
+
import matplotlib.patches as patches
|
| 28 |
+
|
| 29 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 30 |
+
|
| 31 |
+
from src.config import get_config
|
| 32 |
+
from src.models import get_model
|
| 33 |
+
from src.features import get_feature_extractor, get_mask_refiner, get_region_extractor
|
| 34 |
+
from src.training.classifier import ForgeryClassifier
|
| 35 |
+
from src.data.preprocessing import DocumentPreprocessor
|
| 36 |
+
|
| 37 |
+
# Class mapping
|
| 38 |
+
# Class mapping: forgery-type class index -> human-readable label
CLASS_NAMES = {
    0: 'Copy-Move',
    1: 'Splicing',
    2: 'Generation'
}

# Per-class color tuples for visualization overlays.
# NOTE(review): presumably RGB order (images are converted to RGB upstream
# with cv2.cvtColor) — confirm against the drawing code before relying on it.
CLASS_COLORS = {
    0: (255, 0, 0),    # Red for Copy-Move
    1: (0, 255, 0),    # Green for Splicing
    2: (0, 0, 255)     # Blue for Generation
}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class ForgeryDetectionPipeline:
|
| 52 |
+
"""
|
| 53 |
+
Complete forgery detection pipeline
|
| 54 |
+
Implements Algorithm Steps 1-11
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
def __init__(self, config_path='config.yaml'):
|
| 58 |
+
"""Initialize pipeline with models"""
|
| 59 |
+
print("="*70)
|
| 60 |
+
print("Initializing Forgery Detection Pipeline")
|
| 61 |
+
print("="*70)
|
| 62 |
+
|
| 63 |
+
self.config = get_config(config_path)
|
| 64 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 65 |
+
|
| 66 |
+
# Load localization model (Steps 1-6)
|
| 67 |
+
print("\n1. Loading localization model...")
|
| 68 |
+
self.localization_model = get_model(self.config).to(self.device)
|
| 69 |
+
checkpoint = torch.load('outputs/checkpoints/best_doctamper.pth',
|
| 70 |
+
map_location=self.device)
|
| 71 |
+
self.localization_model.load_state_dict(checkpoint['model_state_dict'])
|
| 72 |
+
self.localization_model.eval()
|
| 73 |
+
print(f" ✓ Loaded (Val Dice: {checkpoint.get('best_metric', 0):.2%})")
|
| 74 |
+
|
| 75 |
+
# Load classifier (Step 8)
|
| 76 |
+
print("\n2. Loading forgery type classifier...")
|
| 77 |
+
self.classifier = ForgeryClassifier(self.config)
|
| 78 |
+
self.classifier.load('outputs/classifier')
|
| 79 |
+
print(" ✓ Loaded")
|
| 80 |
+
|
| 81 |
+
# Initialize components
|
| 82 |
+
print("\n3. Initializing components...")
|
| 83 |
+
self.preprocessor = DocumentPreprocessor(self.config, 'doctamper')
|
| 84 |
+
|
| 85 |
+
# Initialize augmentation for inference
|
| 86 |
+
from src.data.augmentation import DatasetAwareAugmentation
|
| 87 |
+
self.augmentation = DatasetAwareAugmentation(self.config, 'doctamper', is_training=False)
|
| 88 |
+
|
| 89 |
+
self.feature_extractor = get_feature_extractor(self.config, is_text_document=True)
|
| 90 |
+
self.mask_refiner = get_mask_refiner(self.config)
|
| 91 |
+
self.region_extractor = get_region_extractor(self.config)
|
| 92 |
+
print(" ✓ Ready")
|
| 93 |
+
|
| 94 |
+
print("\n" + "="*70)
|
| 95 |
+
print("Pipeline Initialized Successfully!")
|
| 96 |
+
print("="*70 + "\n")
|
| 97 |
+
|
| 98 |
+
def detect(self, image_path, ground_truth_path=None, output_dir='outputs/inference'):
    """
    Run complete detection pipeline

    Args:
        image_path: Path to input document image
        ground_truth_path: Optional path to ground truth mask
        output_dir: Directory to save outputs

    Returns:
        results: Dictionary with detection results
    """
    print(f"\n{'='*70}")
    print(f"Processing: {image_path}")
    print(f"{'='*70}\n")

    # Create output directory
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Get base filename
    base_name = Path(image_path).stem

    # Step 1-2: Load and preprocess image (EXACTLY like dataset)
    print("Step 1-2: Loading and preprocessing...")
    image = cv2.imread(str(image_path))
    if image is None:
        raise ValueError(f"Could not load image: {image_path}")

    # OpenCV loads BGR; the rest of the pipeline works in RGB.
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Create dummy mask for preprocessing: the preprocessor expects an
    # (image, mask) pair even at inference time.
    dummy_mask = np.zeros(image_rgb.shape[:2], dtype=np.uint8)

    # Step 1: Preprocess (like dataset line: image, mask = self.preprocessor(image, mask))
    preprocessed_img, preprocessed_mask = self.preprocessor(image_rgb, dummy_mask)

    # Step 2: Augment (like dataset line: augmented = self.augmentation(image, mask))
    augmented = self.augmentation(preprocessed_img, preprocessed_mask)

    # Step 3: Extract tensor (like dataset line: image = augmented['image'])
    image_tensor = augmented['image']

    print(f" ✓ Image shape: {image_rgb.shape}")
    print(f" ✓ Preprocessed tensor shape: {image_tensor.shape}")
    print(f" ✓ Tensor range: [{image_tensor.min():.4f}, {image_tensor.max():.4f}]")

    # Load ground truth if provided
    ground_truth = None
    if ground_truth_path:
        ground_truth = cv2.imread(str(ground_truth_path), cv2.IMREAD_GRAYSCALE)
        if ground_truth is not None:
            # Resize to match preprocessed size
            target_size = (image_tensor.shape[2], image_tensor.shape[1])  # (W, H)
            ground_truth = cv2.resize(ground_truth, target_size)
            print(f" ✓ Ground truth loaded")

    # Step 3-4: Localization (WHERE is forgery?)
    print("\nStep 3-4: Forgery localization...")
    image_batch = image_tensor.unsqueeze(0).to(self.device)

    with torch.no_grad():
        logits, decoder_features = self.localization_model(image_batch)
        # Single-image batch, single-channel logits -> 2-D probability map.
        prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]

    print(f" ✓ Probability map generated")
    print(f" ✓ Prob map range: [{prob_map.min():.4f}, {prob_map.max():.4f}]")

    # Step 5: Binary mask generation
    print("\nStep 5: Generating binary mask...")
    binary_mask = (prob_map > 0.5).astype(np.uint8)
    refined_mask = self.mask_refiner.refine(binary_mask)
    print(f" ✓ Mask refined")

    # Step 6: Region extraction
    print("\nStep 6: Extracting forgery regions...")
    # Convert tensor to numpy for region extraction and feature extraction
    preprocessed_numpy = image_tensor.permute(1, 2, 0).cpu().numpy()
    regions = self.region_extractor.extract(refined_mask, prob_map, preprocessed_numpy)
    print(f" ✓ Found {len(regions)} regions")

    if len(regions) == 0:
        print("\n⚠ No forgery regions detected!")
        # Still create visualizations if ground truth exists
        if ground_truth is not None:
            print("\nCreating comparison with ground truth...")
            self._create_comparison_visualization(
                image_rgb, prob_map, refined_mask, ground_truth,
                base_name, output_path
            )
        return self._create_clean_result(image_rgb, base_name, output_path, ground_truth)

    # Step 7-8: Feature extraction and classification
    print("\nStep 7-8: Classifying forgery types...")
    region_results = []

    for i, region in enumerate(regions):
        # Extract features (Step 7)
        features = self.feature_extractor.extract(
            preprocessed_numpy,
            region['region_mask'],
            [f.cpu() for f in decoder_features]
        )

        # Ensure correct dimension (526) — the classifier was trained on
        # fixed-length feature vectors, so pad or truncate to match.
        expected_dim = 526
        if len(features) < expected_dim:
            features = np.pad(features, (0, expected_dim - len(features)))
        elif len(features) > expected_dim:
            features = features[:expected_dim]

        features = features.reshape(1, -1)

        # Classify (Step 8)
        predictions, confidences = self.classifier.predict(features)
        forgery_type = int(predictions[0])
        confidence = float(confidences[0])

        region_results.append({
            'region_id': i + 1,
            'bounding_box': region['bounding_box'],
            'area': int(region['area']),
            'forgery_type': CLASS_NAMES[forgery_type],
            'forgery_type_id': forgery_type,
            'confidence': confidence,
            'mask_probability_mean': float(prob_map[region['region_mask'] > 0].mean())
        })

        print(f" Region {i+1}: {CLASS_NAMES[forgery_type]} "
              f"(confidence: {confidence:.2%})")

    # Step 9: False positive removal
    print("\nStep 9: Filtering low-confidence regions...")
    confidence_threshold = self.config.get('classification.confidence_threshold', 0.6)
    filtered_results = [r for r in region_results if r['confidence'] >= confidence_threshold]
    print(f" ✓ Kept {len(filtered_results)}/{len(region_results)} regions "
          f"(threshold: {confidence_threshold:.0%})")

    # Step 10-11: Generate outputs
    print("\nStep 10-11: Generating outputs...")

    # Calculate scale factors for coordinate conversion
    # Bounding boxes are in preprocessed coordinates (384x384)
    # Need to scale to original image coordinates
    orig_h, orig_w = image_rgb.shape[:2]
    prep_h, prep_w = prob_map.shape
    scale_x = orig_w / prep_w
    scale_y = orig_h / prep_h

    # Create visualizations
    self._create_visualizations(
        image_rgb, prob_map, refined_mask, filtered_results,
        ground_truth, base_name, output_path, scale_x, scale_y
    )

    # Create JSON output
    results = self._create_json_output(
        image_path, filtered_results, ground_truth, base_name, output_path
    )

    print(f"\n{'='*70}")
    print("✅ Detection Complete!")
    print(f"{'='*70}")
    print(f"Output directory: {output_path}")
    print(f"Detected {len(filtered_results)} forgery regions")
    print(f"{'='*70}\n")

    return results
+
def _create_visualizations(self, image, prob_map, mask, results,
                           ground_truth, base_name, output_path, scale_x, scale_y):
    """Create and save all visualization figures for one document.

    Saves up to four figures under ``output_path``:
      * ``{base_name}_heatmap.png``    — original / probability map / binary mask
      * ``{base_name}_overlay.png``    — detections drawn over the document
      * ``{base_name}_comparison.png`` — predicted vs ground-truth (if GT given)
      * ``{base_name}_regions.png``    — per-region crops with labels

    Args:
        image: Original RGB image (H, W, 3).
        prob_map: Per-pixel forgery probabilities at preprocessed resolution.
        mask: Refined binary (0/1) mask at preprocessed resolution.
        results: Filtered region dicts; bounding boxes in preprocessed coords.
        ground_truth: Optional grayscale GT mask resized to preprocessed
            resolution — assumes values in 0..255 (TODO confirm at call site).
        base_name: Output filename stem.
        output_path: Directory (Path) to write figures into.
        scale_x, scale_y: Preprocessed→original coordinate scale factors.
    """
    # 1. Probability heatmap
    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    plt.imshow(image)
    plt.title('Original Document')
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.imshow(prob_map, cmap='hot', vmin=0, vmax=1)
    plt.colorbar(label='Forgery Probability')
    plt.title('Probability Heatmap')
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.imshow(mask, cmap='gray')
    plt.title('Binary Mask')
    plt.axis('off')

    plt.tight_layout()
    plt.savefig(output_path / f'{base_name}_heatmap.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f" ✓ Saved heatmap")

    # 2. Overlay with bounding boxes and labels
    overlay = image.copy()
    alpha = 0.4

    # NOTE: an unused `mask_scaled = cv2.resize(...)` was removed here;
    # the colored overlay is built from the scaled bounding boxes only.
    colored_mask = np.zeros_like(image)

    for result in results:
        bbox = result['bounding_box']
        forgery_type = result['forgery_type_id']
        color = CLASS_COLORS[forgery_type]

        # Scale bounding box to original image coordinates
        x, y, w, h = bbox
        x_scaled = int(x * scale_x)
        y_scaled = int(y * scale_y)
        w_scaled = int(w * scale_x)
        h_scaled = int(h * scale_y)

        # Color the region
        colored_mask[y_scaled:y_scaled+h_scaled, x_scaled:x_scaled+w_scaled] = color

    # Blend with original
    overlay = cv2.addWeighted(overlay, 1-alpha, colored_mask, alpha, 0)

    # Draw bounding boxes and labels
    fig, ax = plt.subplots(1, figsize=(12, 8))
    ax.imshow(overlay)

    for result in results:
        bbox = result['bounding_box']
        x, y, w, h = bbox  # bbox is [x, y, w, h] in preprocessed coordinates

        # Scale to original image coordinates
        x_scaled = x * scale_x
        y_scaled = y * scale_y
        w_scaled = w * scale_x
        h_scaled = h * scale_y

        forgery_type = result['forgery_type']
        confidence = result['confidence']
        color_rgb = tuple(c/255 for c in CLASS_COLORS[result['forgery_type_id']])

        # Draw rectangle
        rect = patches.Rectangle((x_scaled, y_scaled), w_scaled, h_scaled,
                                 linewidth=2, edgecolor=color_rgb,
                                 facecolor='none')
        ax.add_patch(rect)

        # Add label
        label = f"{forgery_type}\n{confidence:.1%}"
        ax.text(x_scaled, y_scaled-10, label, color='white', fontsize=10,
                bbox=dict(boxstyle='round', facecolor=color_rgb, alpha=0.8))

    ax.axis('off')
    ax.set_title('Forgery Detection Results', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(output_path / f'{base_name}_overlay.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f" ✓ Saved overlay")

    # 3. Comparison with ground truth (if available)
    if ground_truth is not None:
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))

        axes[0].imshow(image)
        axes[0].set_title('Original Document', fontsize=12)
        axes[0].axis('off')

        axes[1].imshow(ground_truth, cmap='gray')
        axes[1].set_title('Ground Truth', fontsize=12)
        axes[1].axis('off')

        axes[2].imshow(mask, cmap='gray')
        axes[2].set_title('Predicted Mask', fontsize=12)
        axes[2].axis('off')

        # Calculate metrics on binarized masks.
        gt_fg = ground_truth > 127
        pred_fg = mask > 0
        intersection = np.logical_and(gt_fg, pred_fg).sum()
        union = np.logical_or(gt_fg, pred_fg).sum()
        iou = intersection / (union + 1e-8)
        # BUG FIX: Dice previously divided by ground_truth.sum() + mask.sum(),
        # i.e. raw 0-255 GT values against a 0/1 mask, deflating the score by
        # up to ~255x. Use binarized pixel counts instead, consistent with
        # _create_comparison_visualization.
        dice = 2 * intersection / (gt_fg.sum() + pred_fg.sum() + 1e-8)

        fig.suptitle(f'Actual vs Predicted (IoU: {iou:.2%}, Dice: {dice:.2%})',
                     fontsize=14, fontweight='bold')

        plt.tight_layout()
        plt.savefig(output_path / f'{base_name}_comparison.png', dpi=150, bbox_inches='tight')
        plt.close()
        print(f" ✓ Saved comparison (IoU: {iou:.2%}, Dice: {dice:.2%})")

    # 4. Per-region visualization
    if len(results) > 0:
        n_regions = len(results)
        cols = min(4, n_regions)
        rows = (n_regions + cols - 1) // cols

        fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows))
        if n_regions == 1:
            axes = [axes]
        else:
            axes = axes.flatten()

        for i, result in enumerate(results):
            bbox = result['bounding_box']
            x, y, w, h = bbox  # bbox is [x, y, w, h] in preprocessed coordinates

            # Scale to original image coordinates
            x_scaled = int(x * scale_x)
            y_scaled = int(y * scale_y)
            w_scaled = int(w * scale_x)
            h_scaled = int(h * scale_y)

            region_img = image[y_scaled:y_scaled+h_scaled, x_scaled:x_scaled+w_scaled]

            axes[i].imshow(region_img)
            axes[i].set_title(f"Region {i+1}: {result['forgery_type']}\n"
                              f"Confidence: {result['confidence']:.1%}",
                              fontsize=10)
            axes[i].axis('off')

        # Hide unused subplots
        for i in range(n_regions, len(axes)):
            axes[i].axis('off')

        plt.tight_layout()
        plt.savefig(output_path / f'{base_name}_regions.png', dpi=150, bbox_inches='tight')
        plt.close()
        print(f" ✓ Saved region details")
+
def _create_json_output(self, image_path, results, ground_truth, base_name, output_path):
    """Assemble the detection summary and persist it as a JSON file.

    Returns the summary dict that was written to disk.
    """
    payload = {
        'image_path': str(image_path),
        'timestamp': datetime.now().isoformat(),
        'num_regions_detected': len(results),
        'regions': results,
    }

    # Record whether a ground-truth mask was available for this run.
    if ground_truth is not None:
        payload['has_ground_truth'] = True

    destination = output_path / f'{base_name}_results.json'
    with open(destination, 'w') as fh:
        json.dump(payload, fh, indent=2)

    print(" ✓ Saved JSON results")
    return payload
+
def _create_comparison_visualization(self, image, prob_map, mask, ground_truth,
                                     base_name, output_path):
    """Create comparison visualization between actual and predicted.

    Used when no regions were detected but a ground-truth mask exists:
    saves a 2x2 figure (original, GT, predicted mask, probability heatmap)
    annotated with IoU/Dice/precision/recall computed on binarized masks.
    """
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # Original image
    axes[0, 0].imshow(image)
    axes[0, 0].set_title('Original Document', fontsize=14, fontweight='bold')
    axes[0, 0].axis('off')

    # Ground truth
    axes[0, 1].imshow(ground_truth, cmap='gray')
    axes[0, 1].set_title('Ground Truth (Actual)', fontsize=14, fontweight='bold')
    axes[0, 1].axis('off')

    # Predicted mask
    axes[1, 0].imshow(mask, cmap='gray')
    axes[1, 0].set_title('Predicted Mask', fontsize=14, fontweight='bold')
    axes[1, 0].axis('off')

    # Probability heatmap
    im = axes[1, 1].imshow(prob_map, cmap='hot', vmin=0, vmax=1)
    axes[1, 1].set_title('Probability Heatmap', fontsize=14, fontweight='bold')
    axes[1, 1].axis('off')
    plt.colorbar(im, ax=axes[1, 1], fraction=0.046, pad=0.04)

    # Calculate metrics: foreground = GT > 127 (assumes GT in 0..255 — TODO
    # confirm at call site) vs predicted mask > 0.
    intersection = np.logical_and(ground_truth > 127, mask > 0).sum()
    union = np.logical_or(ground_truth > 127, mask > 0).sum()
    gt_sum = (ground_truth > 127).sum()
    pred_sum = (mask > 0).sum()

    # Epsilon avoids division by zero when both masks are empty.
    iou = intersection / (union + 1e-8)
    dice = 2 * intersection / (gt_sum + pred_sum + 1e-8)
    precision = intersection / (pred_sum + 1e-8) if pred_sum > 0 else 0
    recall = intersection / (gt_sum + 1e-8) if gt_sum > 0 else 0

    fig.suptitle(f'Actual vs Predicted Comparison\n'
                 f'IoU: {iou:.2%} | Dice: {dice:.2%} | '
                 f'Precision: {precision:.2%} | Recall: {recall:.2%}',
                 fontsize=16, fontweight='bold')

    plt.tight_layout()
    plt.savefig(output_path / f'{base_name}_comparison.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f" ✓ Saved comparison (IoU: {iou:.2%}, Dice: {dice:.2%})")
+
def _create_clean_result(self, image, base_name, output_path, ground_truth=None):
    """Produce outputs for a document where no forgery regions were found.

    Saves a figure of the untouched document with a "clean" banner and a
    minimal JSON summary, then returns that summary dict.
    """
    # Figure: just the original document with a green "clean" title.
    plt.figure(figsize=(10, 8))
    plt.imshow(image)
    plt.title('No Forgery Detected', fontsize=14, fontweight='bold', color='green')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(output_path / f'{base_name}_clean.png', dpi=150, bbox_inches='tight')
    plt.close()

    # Minimal machine-readable result.
    summary = {
        'timestamp': datetime.now().isoformat(),
        'num_regions_detected': 0,
        'regions': [],
        'status': 'clean',
    }

    with open(output_path / f'{base_name}_results.json', 'w') as fh:
        json.dump(summary, fh, indent=2)

    return summary
+
|
| 524 |
+
def main():
    """CLI entry point: parse arguments, run the pipeline, print a summary."""
    cli = argparse.ArgumentParser(description='Document Forgery Detection Pipeline')
    cli.add_argument('--image', type=str, required=True,
                     help='Path to input document image')
    cli.add_argument('--ground_truth', type=str, default=None,
                     help='Path to ground truth mask (optional)')
    cli.add_argument('--output_dir', type=str, default='outputs/inference',
                     help='Output directory for results')
    cli.add_argument('--config', type=str, default='config.yaml',
                     help='Path to config file')
    opts = cli.parse_args()

    # Build the pipeline, then run detection on the requested image.
    detector = ForgeryDetectionPipeline(opts.config)
    results = detector.detect(
        opts.image,
        ground_truth_path=opts.ground_truth,
        output_dir=opts.output_dir,
    )

    # Human-readable summary of what was found.
    print("\nDetection Summary:")
    print(f" Regions detected: {results['num_regions_detected']}")
    if results['num_regions_detected'] > 0:
        for region in results['regions']:
            print(f" - {region['forgery_type']}: {region['confidence']:.1%} confidence")
|
| 555 |
+
if __name__ == '__main__':
|
| 556 |
+
main()
|
scripts/run_inference.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inference Script for Document Forgery Detection
|
| 3 |
+
|
| 4 |
+
Run inference on single images or entire directories.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python scripts/run_inference.py --input path/to/image.jpg --model outputs/checkpoints/best_doctamper.pth
|
| 8 |
+
python scripts/run_inference.py --input path/to/folder/ --model outputs/checkpoints/best_doctamper.pth
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import sys
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
import json
|
| 15 |
+
|
| 16 |
+
# Add src to path
|
| 17 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 18 |
+
|
| 19 |
+
from src.config import get_config
|
| 20 |
+
from src.inference import get_pipeline
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def parse_args():
    """Collect command-line options for the inference run."""
    cli = argparse.ArgumentParser(description="Run forgery detection inference")

    # Mandatory inputs.
    for flag, text in (('--input', 'Input image or directory path'),
                       ('--model', 'Path to localization model checkpoint')):
        cli.add_argument(flag, type=str, required=True, help=text)

    # Optional settings with defaults.
    cli.add_argument('--classifier', type=str, default=None,
                     help='Path to classifier directory (optional)')
    cli.add_argument('--output', type=str, default='outputs/results',
                     help='Output directory')
    cli.add_argument('--is_text', action='store_true',
                     help='Enable OCR features for text documents')
    cli.add_argument('--config', type=str, default='config.yaml',
                     help='Path to config file')

    return cli.parse_args()
|
| 47 |
+
def process_file(pipeline, input_path: str, output_dir: str):
    """Run the detection pipeline on one file.

    Returns the pipeline's result dict, or None when processing fails —
    the error is reported but deliberately not re-raised so a batch run
    can continue with the remaining files.
    """
    try:
        return pipeline.run(input_path, output_dir)
    except Exception as exc:
        print(f"Error processing {input_path}: {exc}")
        return None
|
| 57 |
+
def main():
    """CLI entry point: run the detection pipeline over a single file or a
    directory of documents and write per-file results plus a JSON summary.
    """
    args = parse_args()

    # Load config
    config = get_config(args.config)

    print("\n" + "="*60)
    print("Hybrid Document Forgery Detection - Inference")
    print("="*60)
    print(f"Input: {args.input}")
    print(f"Model: {args.model}")
    print(f"Classifier: {args.classifier or 'None'}")
    print(f"Output: {args.output}")
    print("="*60)

    # Create pipeline
    pipeline = get_pipeline(
        config,
        model_path=args.model,
        classifier_path=args.classifier,
        is_text_document=args.is_text
    )

    # Create output directory
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Get input files: accept a single file or every supported document
    # in a directory (non-recursive).
    input_path = Path(args.input)

    if input_path.is_file():
        files = [input_path]
    elif input_path.is_dir():
        extensions = ['.jpg', '.jpeg', '.png', '.pdf', '.bmp', '.tiff']
        files = [f for f in input_path.iterdir()
                 if f.suffix.lower() in extensions]
    else:
        print(f"Invalid input path: {input_path}")
        return

    print(f"\nProcessing {len(files)} file(s)...")

    # Process files; process_file returns None on failure, which is skipped.
    all_results = []

    for file_path in files:
        result = process_file(pipeline, str(file_path), str(output_dir))
        if result:
            all_results.append(result)

            # Print summary
            status = "TAMPERED" if result['is_tampered'] else "AUTHENTIC"
            print(f"\n {file_path.name}: {status}")
            if result['is_tampered']:
                print(f" Regions detected: {result['num_regions']}")
                for region in result['regions'][:3]:  # Show first 3
                    print(f" - {region['forgery_type']} (conf: {region['confidence']:.2f})")

    # Save summary
    summary_path = output_dir / 'inference_summary.json'
    summary = {
        'total_files': len(files),
        'processed': len(all_results),
        'tampered': sum(1 for r in all_results if r['is_tampered']),
        'authentic': sum(1 for r in all_results if not r['is_tampered']),
        'results': all_results
    }

    # default=str makes non-JSON-native values (e.g. Paths) serializable.
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2, default=str)

    print("\n" + "="*60)
    print("Inference Complete!")
    print(f"Total: {summary['total_files']}, "
          f"Tampered: {summary['tampered']}, "
          f"Authentic: {summary['authentic']}")
    print(f"Results saved to: {output_dir}")
    print("="*60)
|
| 137 |
+
if __name__ == '__main__':
|
| 138 |
+
main()
|
scripts/setup.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Setup Script
|
| 3 |
+
|
| 4 |
+
Creates output directories and verifies installation.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python scripts/setup.py
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import sys
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
# Add src to path
|
| 14 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def create_directories():
    """Create the output directory tree used by training and inference."""
    root = Path(__file__).parent.parent

    # 'outputs' itself plus its expected subdirectories.
    subdirs = ['checkpoints', 'logs', 'plots', 'results', 'classifier', 'exported']
    targets = [root / 'outputs'] + [root / 'outputs' / name for name in subdirs]

    for target in targets:
        target.mkdir(parents=True, exist_ok=True)
        print(f"Created: {target}")
|
| 37 |
+
def verify_installation():
    """Import-check every required third-party package.

    Prints one status line per package plus a CUDA availability note, and
    returns the list of human-readable names that failed to import.
    """
    checks = [
        ('torch', 'PyTorch'),
        ('torchvision', 'TorchVision'),
        ('timm', 'TIMM'),
        ('lightgbm', 'LightGBM'),
        ('sklearn', 'Scikit-learn'),
        ('cv2', 'OpenCV'),
        ('PIL', 'Pillow'),
        ('numpy', 'NumPy'),
        ('pandas', 'Pandas'),
        ('matplotlib', 'Matplotlib'),
        ('seaborn', 'Seaborn'),
        ('albumentations', 'Albumentations'),
        ('tqdm', 'TQDM'),
        ('yaml', 'PyYAML'),
        ('pywt', 'PyWavelets'),
    ]

    print("\nVerifying installation...")
    print("-" * 40)

    missing = []

    for module_name, display_name in checks:
        try:
            __import__(module_name)
        except ImportError:
            print(f" ✗ {display_name} (MISSING)")
            missing.append(display_name)
        else:
            print(f" ✓ {display_name}")

    # Check CUDA
    print("-" * 40)
    try:
        import torch
        if torch.cuda.is_available():
            print(f" ✓ CUDA Available: {torch.cuda.get_device_name(0)}")
        else:
            print(" ⚠ CUDA Not Available (CPU mode)")
    except Exception as e:
        print(f" ✗ CUDA Check Failed: {e}")

    return missing
+
|
| 85 |
+
def verify_datasets():
    """Report whether each expected dataset directory exists on disk."""
    base_dir = Path(__file__).parent.parent

    dataset_dirs = {
        'DocTamper': base_dir / 'datasets' / 'DocTamper',
        'RTM': base_dir / 'datasets' / 'RealTextManipulation',
        'CASIA': base_dir / 'datasets' / 'CASIA 1.0 dataset',
        'Receipts': base_dir / 'datasets' / 'findit2',
    }

    print("\nVerifying datasets...")
    print("-" * 40)

    for label, location in dataset_dirs.items():
        line = (f" ✓ {label}: {location}" if location.exists()
                else f" ✗ {label}: NOT FOUND ({location})")
        print(line)
|
| 107 |
+
def main():
    """Run the full setup: create directories, verify packages and datasets."""
    print("\n" + "="*60)
    print("Hybrid Document Forgery Detection - Setup")
    print("="*60)

    # Create directories
    print("\nCreating directories...")
    print("-" * 40)
    create_directories()

    # Verify installation
    missing = verify_installation()

    # Verify datasets
    verify_datasets()

    # Summary
    print("\n" + "="*60)
    if missing:
        print("Setup complete with WARNINGS")
        print(f"Missing packages: {', '.join(missing)}")
        print("Run: pip install -r requirements.txt")
    else:
        print("Setup Complete! All checks passed.")
    # BUG FIX: the closing rule was previously printed only in the success
    # branch, leaving the warning output visually unterminated.
    print("="*60)
+
|
| 134 |
+
if __name__ == '__main__':
|
| 135 |
+
main()
|
scripts/train_chunked.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chunked Training Script for Document Forgery Detection
|
| 3 |
+
|
| 4 |
+
Supports training on large datasets (DocTamper) in chunks to manage RAM constraints.
|
| 5 |
+
Usage:
|
| 6 |
+
python scripts/train_chunked.py --dataset doctamper --chunk 1
|
| 7 |
+
python scripts/train_chunked.py --dataset rtm
|
| 8 |
+
python scripts/train_chunked.py --dataset casia
|
| 9 |
+
python scripts/train_chunked.py --dataset receipts
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
# Add src to path
|
| 18 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import gc
|
| 22 |
+
|
| 23 |
+
from src.config import get_config
|
| 24 |
+
from src.training import get_trainer
|
| 25 |
+
from src.utils import plot_training_curves, plot_chunked_training_progress, generate_training_report
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def parse_args(argv=None):
    """Parse command-line arguments for training.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            preserves the original behavior of reading ``sys.argv[1:]``;
            passing an explicit list makes the function testable without
            mutating process state.

    Returns:
        argparse.Namespace with ``dataset``, ``chunk``, ``epochs``,
        ``resume`` and ``config`` attributes.
    """
    parser = argparse.ArgumentParser(description="Train forgery detection model")

    parser.add_argument('--dataset', type=str, default='doctamper',
                        choices=['doctamper', 'rtm', 'casia', 'receipts', 'fcd', 'scd'],
                        help='Dataset to train on')

    parser.add_argument('--chunk', type=int, default=None,
                        help='Chunk number (1-4) for DocTamper chunked training')

    parser.add_argument('--epochs', type=int, default=None,
                        help='Number of epochs (overrides config)')

    parser.add_argument('--resume', type=str, default=None,
                        help='Checkpoint to resume from')

    parser.add_argument('--config', type=str, default='config.yaml',
                        help='Path to config file')

    return parser.parse_args(argv)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def train_chunk(config, dataset_name: str, chunk_id: int, epochs: int = None, resume: str = None):
    """Run training for one chunk of a chunked dataset.

    Reads the chunk boundaries from config, optionally warm-starts from a
    checkpoint (an explicit ``resume`` path, or the previous chunk's final
    weights when they exist), trains, writes plots and a text report, then
    releases the trainer and cached GPU memory.

    Returns the training history returned by the trainer.
    """
    chunks = config.get('data.chunked_training.chunks', [])
    if chunk_id > len(chunks):
        raise ValueError(f"Invalid chunk ID: {chunk_id}. Max: {len(chunks)}")

    spec = chunks[chunk_id - 1]
    chunk_start = spec['start']
    chunk_end = spec['end']
    chunk_name = spec['name']

    banner = '=' * 60
    print(f"\n{banner}")
    print(f"Training Chunk {chunk_id}: {chunk_name}")
    print(f"Range: {chunk_start*100:.0f}% - {chunk_end*100:.0f}%")
    print(f"{banner}")

    trainer = get_trainer(config, dataset_name)

    # Resume priority: explicit checkpoint wins; otherwise chain from the
    # previous chunk's final checkpoint when one is on disk. The epoch
    # counter is reset so each chunk trains its full schedule on new data.
    if resume:
        trainer.load_checkpoint(resume, reset_epoch=True)
    elif chunk_id > 1:
        prev_checkpoint = f'{dataset_name}_chunk{chunk_id-1}_final.pth'
        if (Path(config.get('outputs.checkpoints')) / prev_checkpoint).exists():
            print(f"Auto-resuming from previous chunk: {prev_checkpoint}")
            trainer.load_checkpoint(prev_checkpoint, reset_epoch=True)

    history = trainer.train(
        epochs=epochs,
        chunk_start=chunk_start,
        chunk_end=chunk_end,
        chunk_id=chunk_id,
        resume_from=None,  # checkpoint (if any) already loaded above
    )

    plot_dir = Path(config.get('outputs.plots', 'outputs/plots'))
    plot_dir.mkdir(parents=True, exist_ok=True)

    plot_training_curves(
        history,
        str(plot_dir / f'{dataset_name}_chunk{chunk_id}_curves.png'),
        title=f"{dataset_name.upper()} Chunk {chunk_id} Training",
    )

    generate_training_report(
        history,
        str(plot_dir / f'{dataset_name}_chunk{chunk_id}_report.txt'),
        f"{dataset_name} Chunk {chunk_id}",
    )

    # Free the trainer and cached CUDA blocks before the next chunk loads.
    del trainer
    gc.collect()
    torch.cuda.empty_cache()

    return history
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def train_full_dataset(config, dataset_name: str, epochs: int = None, resume: str = None):
    """Train on the entire dataset in a single pass (non-chunked datasets).

    Optionally warm-starts from ``resume`` with the epoch counter reset so
    the new dataset gets a full training schedule. Writes a curves plot and
    a text report, and returns the trainer's history.
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"Training on: {dataset_name.upper()}")
    print(f"{banner}")

    trainer = get_trainer(config, dataset_name)

    # Warm-start: load weights only; reset epochs for the new dataset.
    if resume:
        print(f"Loading weights from: {resume}")
        trainer.load_checkpoint(resume, reset_epoch=True)
        print("Epoch counter reset to 0 for new dataset training")

    history = trainer.train(
        epochs=epochs,
        chunk_id=0,
        resume_from=None,  # checkpoint (if any) already loaded above
    )

    plot_dir = Path(config.get('outputs.plots', 'outputs/plots'))
    plot_dir.mkdir(parents=True, exist_ok=True)

    plot_training_curves(
        history,
        str(plot_dir / f'{dataset_name}_training_curves.png'),
        title=f"{dataset_name.upper()} Training",
    )

    generate_training_report(
        history,
        str(plot_dir / f'{dataset_name}_report.txt'),
        dataset_name,
    )

    return history
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def main():
    """Entry point: dispatch to chunked or full-dataset training.

    Routes:
      * doctamper + --chunk N  -> train that single chunk
      * doctamper (no chunk)   -> train all 4 chunks sequentially
      * any other dataset      -> one full-dataset training run
    """
    args = parse_args()

    config = get_config(args.config)

    print("\n" + "="*60)
    print("Hybrid Document Forgery Detection - Training")
    print("="*60)
    print(f"Dataset: {args.dataset}")
    print(f"Device: {config.get('system.device')}")
    print(f"CUDA Available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print("="*60)

    # DocTamper: single requested chunk.
    if args.dataset == 'doctamper' and args.chunk is not None:
        train_chunk(
            config,
            args.dataset,
            args.chunk,
            epochs=args.epochs,
            resume=args.resume,
        )

    # DocTamper: all chunks sequentially. Chaining from the previous
    # chunk's final checkpoint happens inside train_chunk(), so resume is
    # always None here. (Was `None if chunk_id == 1 else None` — a dead
    # conditional that always evaluated to None.)
    elif args.dataset == 'doctamper':
        print("Training DocTamper in 4 chunks...")

        all_histories = []
        for chunk_id in range(1, 5):
            history = train_chunk(
                config,
                args.dataset,
                chunk_id,
                epochs=args.epochs,
                resume=None,
            )
            all_histories.append(history)

        # Combined progress plot across all chunks.
        plot_dir = Path(config.get('outputs.plots', 'outputs/plots'))
        plot_chunked_training_progress(
            all_histories,
            str(plot_dir / 'doctamper_all_chunks_progress.png'),
            title="DocTamper Chunked Training Progress",
        )

    # Other datasets: single full-dataset run.
    else:
        train_full_dataset(
            config,
            args.dataset,
            epochs=args.epochs,
            resume=args.resume,
        )

    print("\n" + "="*60)
    print("Training Complete!")
    print("="*60)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
scripts/train_classifier_doctamper_fixed.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LightGBM Classifier Training - DocTamper with Tampering Labels
|
| 3 |
+
FIXED VERSION with proper checkpointing and feature dimension handling
|
| 4 |
+
|
| 5 |
+
Implements Algorithm Steps 7-8:
|
| 6 |
+
7. Hybrid Feature Extraction
|
| 7 |
+
8. Region-wise Forgery Classification
|
| 8 |
+
|
| 9 |
+
Uses:
|
| 10 |
+
- Localization: best_doctamper.pth (Steps 1-6 complete)
|
| 11 |
+
- Training: DocTamper TrainingSet + tampering/DocTamperV1-TrainingSet.pk
|
| 12 |
+
- Testing: DocTamper TestingSet + tampering/DocTamperV1-TestingSet.pk
|
| 13 |
+
- Classes: Copy-Move (CM), Splicing (SP), Generation (GE)
|
| 14 |
+
|
| 15 |
+
Features:
|
| 16 |
+
- ✅ Checkpoint saving every 1000 samples
|
| 17 |
+
- ✅ Resume from checkpoint if interrupted
|
| 18 |
+
- ✅ Fixed feature dimension mismatch
|
| 19 |
+
- ✅ Robust error handling
|
| 20 |
+
|
| 21 |
+
Usage:
|
| 22 |
+
python scripts/train_classifier_doctamper_fixed.py
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import sys
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
import numpy as np
|
| 28 |
+
import pickle
|
| 29 |
+
import lmdb
|
| 30 |
+
import cv2
|
| 31 |
+
import torch
|
| 32 |
+
from tqdm import tqdm
|
| 33 |
+
import json
|
| 34 |
+
|
| 35 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 36 |
+
|
| 37 |
+
from src.config import get_config
|
| 38 |
+
from src.models import get_model
|
| 39 |
+
from src.features import get_feature_extractor
|
| 40 |
+
from src.training.classifier import get_classifier
|
| 41 |
+
|
| 42 |
+
# --- Script configuration ---------------------------------------------------
# Localization model checkpoint (output of Steps 1-6).
MODEL_PATH = 'outputs/checkpoints/best_doctamper.pth'
# Where extracted features, checkpoints, and the trained classifier go.
OUTPUT_DIR = 'outputs/classifier'
# Effectively "no cap": process every available sample.
MAX_SAMPLES = 999999

# Forgery-type label encoding for Algorithm Step 8.2 (3 classes).
LABEL_MAP = {
    'CM': 0,  # Copy-Move
    'SP': 1,  # Splicing
    'GE': 2,  # Generation (AI-generated; distinct from Splicing)
}
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def load_tampering_labels(label_file):
    """Read the pickled forgery-type label mapping from ``label_file``.

    The DocTamper "tampering" files map a sample index to its forgery-type
    code (e.g. 'CM', 'SP', 'GE').
    """
    with open(label_file, 'rb') as fh:
        labels = pickle.load(fh)
    print(f"Loaded {len(labels)} labels from {label_file}")
    return labels
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def load_sample_from_lmdb(lmdb_env, index):
    """Fetch and decode the (image, mask) pair stored at ``index``.

    Returns ``(None, None)`` when either record is missing. The image is
    decoded to RGB; the mask to single-channel grayscale. DocTamper stores
    masks under 'label-' keys rather than 'mask-'.
    """
    txn = lmdb_env.begin()

    img_data = txn.get(f'image-{index:09d}'.encode('utf-8'))
    if not img_data:
        return None, None

    # Masks live under 'label-' keys in DocTamper's LMDB layout.
    mask_data = txn.get(f'label-{index:09d}'.encode('utf-8'))
    if not mask_data:
        return None, None

    image = cv2.imdecode(np.frombuffer(img_data, dtype=np.uint8), cv2.IMREAD_COLOR)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    mask = cv2.imdecode(np.frombuffer(mask_data, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)

    return image, mask
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def extract_features(config, model, lmdb_path, tampering_labels,
                     max_samples, device, split_name):
    """
    Extract hybrid (handcrafted + deep) features with checkpoint/resume.

    Fixes over the previous version:
      * Checkpoints now record the last processed LMDB index
        (``last_index``). Resume previously used ``len(all_features)`` as
        the next LMDB index, which is wrong whenever samples are skipped
        (missing label, unknown forgery type, empty mask): the feature
        count lags the LMDB index, so resuming re-processed samples and
        produced duplicates. Legacy checkpoints without ``last_index``
        still load, falling back to the old behavior.
      * Checkpoint cleanup now sorts numerically by sample count instead
        of lexicographically (where '100000' < '20000'), matching the key
        used to select the resume checkpoint.

    Returns:
        (features_array, labels_array) as float32/int32 numpy arrays, or
        (None, None) if nothing was extracted.
    """

    print(f"\n{'='*60}")
    print(f"Extracting features from {split_name}")
    print(f"{'='*60}")

    # Setup checkpoint directory
    checkpoint_dir = Path(OUTPUT_DIR)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Resume from the newest checkpoint when one exists.
    checkpoints = list(checkpoint_dir.glob(f'checkpoint_{split_name}_*.npz'))
    if checkpoints:
        latest_checkpoint = max(checkpoints, key=lambda p: int(p.stem.split('_')[-1]))
        print(f"✓ Found checkpoint: {latest_checkpoint.name}")

        data = np.load(latest_checkpoint, allow_pickle=True)
        all_features = data['features'].tolist()
        all_labels = data['labels'].tolist()
        expected_dim = int(data['feature_dim'])
        if 'last_index' in data.files:
            # Correct resume point: one past the last processed LMDB index.
            start_idx = int(data['last_index']) + 1
        else:
            # Legacy checkpoint: fall back to the old (index == count) guess.
            start_idx = len(all_features)

        print(f"✓ Resuming from sample {start_idx}, feature_dim={expected_dim}")
    else:
        all_features = []
        all_labels = []
        expected_dim = None
        start_idx = 0

    # Open LMDB read-only (no lock: concurrent readers are fine).
    env = lmdb.open(lmdb_path, readonly=True, lock=False)

    feature_extractor = get_feature_extractor(config, is_text_document=True)

    # num_processed counts accepted samples (== len(all_features)), which
    # can be smaller than the LMDB index because of skips.
    num_processed = len(all_features)
    dim_mismatch_count = 0

    for i in tqdm(range(start_idx, min(len(tampering_labels), max_samples)),
                  desc=f"Processing {split_name}", initial=start_idx,
                  total=min(len(tampering_labels), max_samples)):
        try:
            # Skip samples without a label or with an unknown forgery type.
            if i not in tampering_labels:
                continue
            forgery_type = tampering_labels[i]
            if forgery_type not in LABEL_MAP:
                continue
            label = LABEL_MAP[forgery_type]

            image, mask = load_sample_from_lmdb(env, i)
            if image is None or mask is None:
                continue

            # Skip authentic samples: no forged region to classify.
            if mask.max() == 0:
                continue

            # HWC uint8 -> NCHW float in [0, 1] for the model.
            image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
            image_tensor = image_tensor.unsqueeze(0).to(device)

            # Deep features from the frozen localization model.
            with torch.no_grad():
                logits, decoder_features = model(image_tensor)

            # Ground-truth mask drives region selection during training.
            mask_binary = (mask > 127).astype(np.uint8)

            features = feature_extractor.extract(
                image / 255.0,
                mask_binary,
                [f.cpu() for f in decoder_features]
            )

            # First valid sample fixes the expected feature width.
            if expected_dim is None:
                expected_dim = len(features)
                print(f"\n✓ Feature dimension set to: {expected_dim}")

            # Pad/truncate so every row has the same width.
            if len(features) != expected_dim:
                if len(features) < expected_dim:
                    features = np.pad(features, (0, expected_dim - len(features)), mode='constant')
                else:
                    features = features[:expected_dim]
                dim_mismatch_count += 1

            all_features.append(features)
            all_labels.append(label)
            num_processed += 1

            # Save a checkpoint every 10,000 accepted samples.
            if num_processed % 10000 == 0:
                checkpoint_path = checkpoint_dir / f'checkpoint_{split_name}_{num_processed}.npz'
                np.savez_compressed(checkpoint_path,
                                    features=np.array(all_features, dtype=np.float32),
                                    labels=np.array(all_labels, dtype=np.int32),
                                    feature_dim=expected_dim,
                                    last_index=i)  # enables correct resume
                print(f"\n✓ Checkpoint: {num_processed} samples (dim={expected_dim}, mismatches={dim_mismatch_count})")

                # Keep only the two newest checkpoints (numeric order).
                old_checkpoints = sorted(checkpoint_dir.glob(f'checkpoint_{split_name}_*.npz'),
                                         key=lambda p: int(p.stem.split('_')[-1]))
                if len(old_checkpoints) > 2:
                    for old_cp in old_checkpoints[:-2]:
                        old_cp.unlink()
                        print(f" Cleaned up: {old_cp.name}")

        except Exception as e:
            # Best-effort extraction: log and move on to the next sample.
            print(f"\n⚠ Error at sample {i}: {str(e)[:80]}")
            continue

    env.close()

    print(f"\n✓ Extracted {num_processed} samples")
    if dim_mismatch_count > 0:
        print(f"⚠ Fixed {dim_mismatch_count} dimension mismatches")

    # Save final features
    final_path = checkpoint_dir / f'features_{split_name}_final.npz'
    if len(all_features) > 0:
        features_array = np.array(all_features, dtype=np.float32)
        labels_array = np.array(all_labels, dtype=np.int32)

        np.savez_compressed(final_path,
                            features=features_array,
                            labels=labels_array,
                            feature_dim=expected_dim)
        print(f"✓ Final features saved: {final_path}")
        print(f" Shape: features={features_array.shape}, labels={labels_array.shape}")

        return features_array, labels_array

    return None, None
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def main():
    """End-to-end Steps 7-8: extract hybrid features, train LightGBM."""
    config = get_config('config.yaml')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    banner = "=" * 60
    print("\n" + banner)
    print("LightGBM Classifier Training - DocTamper (FIXED)")
    print("Implements Algorithm Steps 7-8")
    print(banner)
    print(f"Model: {MODEL_PATH}")
    print(f"Device: {device}")
    print(f"Max samples: {MAX_SAMPLES}")
    print(banner)
    print("\nForgery Type Classes (Step 8.2):")
    print(" 0: Copy-Move (CM)")
    print(" 1: Splicing (SP)")
    print(" 2: Generation (GE)")
    print(banner)

    # Load the trained localization model (Steps 1-6 output).
    print("\nLoading localization model...")
    model = get_model(config).to(device)
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print(f"✓ Model loaded (Val Dice: {checkpoint.get('best_metric', 0):.4f})")

    # Forgery-type labels for both splits.
    train_labels = load_tampering_labels(
        'datasets/DocTamper/tampering/DocTamperV1-TrainingSet.pk'
    )
    test_labels = load_tampering_labels(
        'datasets/DocTamper/tampering/DocTamperV1-TestingSet.pk'
    )

    # Hybrid feature extraction (Step 7) for each split.
    X_train, y_train = extract_features(
        config, model,
        'datasets/DocTamper/DocTamperV1-TrainingSet',
        train_labels,
        MAX_SAMPLES,
        device,
        'TrainingSet'
    )
    X_test, y_test = extract_features(
        config, model,
        'datasets/DocTamper/DocTamperV1-TestingSet',
        test_labels,
        MAX_SAMPLES // 4,
        device,
        'TestingSet'
    )

    if X_train is None or X_test is None:
        print("\n❌ No features extracted!")
        return

    print("\n" + banner)
    print("Dataset Summary")
    print(banner)
    print(f"Training samples: {len(X_train):,}")
    print(f"Testing samples: {len(X_test):,}")
    print(f"Feature dimension: {X_train.shape[1]}")

    class_names = ['Copy-Move', 'Splicing', 'Generation']

    print(f"\nTraining class distribution:")
    for cls_idx, count in enumerate(np.bincount(y_train)):
        if cls_idx < len(class_names):
            print(f" {class_names[cls_idx]}: {count:,} ({count/len(y_train)*100:.1f}%)")

    print(f"\nTesting class distribution:")
    for cls_idx, count in enumerate(np.bincount(y_test)):
        if cls_idx < len(class_names):
            print(f" {class_names[cls_idx]}: {count:,} ({count/len(y_test)*100:.1f}%)")

    # Train the classifier (Step 8.1).
    print("\n" + banner)
    print("Training LightGBM Classifier (Step 8.1)")
    print(banner)

    output_dir = Path(OUTPUT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)

    classifier = get_classifier(config)
    feature_names = get_feature_extractor(config, is_text_document=True).get_feature_names()

    # The classifier does its own train/test split internally, so hand it
    # the combined pool.
    X_combined = np.vstack([X_train, X_test])
    y_combined = np.concatenate([y_train, y_test])

    metrics = classifier.train(X_combined, y_combined, feature_names=feature_names)

    # Persist classifier, metrics, and the class-id -> name mapping.
    classifier.save(str(output_dir))
    print(f"\n✓ Classifier saved to: {output_dir}")

    with open(output_dir / 'training_metrics.json', 'w') as fh:
        json.dump(metrics, fh, indent=2)

    class_mapping = {
        0: 'Copy-Move',
        1: 'Splicing',
        2: 'Generation'
    }
    with open(output_dir / 'class_mapping.json', 'w') as fh:
        json.dump(class_mapping, fh, indent=2)

    print("\n" + banner)
    print("✅ Classifier Training Complete!")
    print("Algorithm Steps 7-8: DONE")
    print(banner)
    print(f"\nResults:")
    print(f" Test Accuracy: {metrics.get('test_accuracy', 'N/A')}")
    print(f" Test F1 Score: {metrics.get('test_f1', 'N/A')}")
    print(f"\nOutput: {output_dir}")
    print("\nNext: Implement Steps 9-11 in inference pipeline")
    print(banner + "\n")
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|