Spaces:
Sleeping
Sleeping
Upload 54 files
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- app.py +231 -0
- config.yaml +297 -0
- models/best_doctamper.pth +3 -0
- models/classifier/classifier_metadata.json +821 -0
- models/classifier/lightgbm_model.txt +0 -0
- models/classifier/scaler.joblib +3 -0
- src/__init__.py +32 -0
- src/__pycache__/__init__.cpython-312.pyc +0 -0
- src/config/__init__.py +5 -0
- src/config/__pycache__/__init__.cpython-312.pyc +0 -0
- src/config/__pycache__/config_loader.cpython-312.pyc +0 -0
- src/config/config_loader.py +117 -0
- src/data/__init__.py +23 -0
- src/data/__pycache__/__init__.cpython-312.pyc +0 -0
- src/data/__pycache__/augmentation.cpython-312.pyc +0 -0
- src/data/__pycache__/datasets.cpython-312.pyc +0 -0
- src/data/__pycache__/preprocessing.cpython-312.pyc +0 -0
- src/data/augmentation.py +150 -0
- src/data/datasets.py +541 -0
- src/data/preprocessing.py +226 -0
- src/features/__init__.py +32 -0
- src/features/__pycache__/__init__.cpython-312.pyc +0 -0
- src/features/__pycache__/feature_extraction.cpython-312.pyc +0 -0
- src/features/__pycache__/region_extraction.cpython-312.pyc +0 -0
- src/features/feature_extraction.py +485 -0
- src/features/region_extraction.py +226 -0
- src/inference/__init__.py +5 -0
- src/inference/__pycache__/__init__.cpython-312.pyc +0 -0
- src/inference/__pycache__/pipeline.cpython-312.pyc +0 -0
- src/inference/pipeline.py +359 -0
- src/models/__init__.py +19 -0
- src/models/__pycache__/__init__.cpython-312.pyc +0 -0
- src/models/__pycache__/decoder.cpython-312.pyc +0 -0
- src/models/__pycache__/encoder.cpython-312.pyc +0 -0
- src/models/__pycache__/losses.cpython-312.pyc +0 -0
- src/models/__pycache__/network.cpython-312.pyc +0 -0
- src/models/decoder.py +186 -0
- src/models/encoder.py +75 -0
- src/models/losses.py +168 -0
- src/models/network.py +133 -0
- src/training/__init__.py +24 -0
- src/training/__pycache__/__init__.cpython-312.pyc +0 -0
- src/training/__pycache__/classifier.cpython-312.pyc +0 -0
- src/training/__pycache__/metrics.cpython-312.pyc +0 -0
- src/training/__pycache__/trainer.cpython-312.pyc +0 -0
- src/training/classifier.py +282 -0
- src/training/metrics.py +305 -0
- src/training/trainer.py +450 -0
- src/utils/__init__.py +28 -0
- src/utils/__pycache__/__init__.cpython-312.pyc +0 -0
app.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Document Forgery Detection - Gradio Interface for Hugging Face Spaces
|
| 3 |
+
|
| 4 |
+
This app provides a web interface for detecting and classifying document forgeries.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import torch
|
| 9 |
+
import cv2
|
| 10 |
+
import numpy as np
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import json
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
import sys
|
| 15 |
+
|
| 16 |
+
# Add src to path
|
| 17 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 18 |
+
|
| 19 |
+
from src.models import get_model
|
| 20 |
+
from src.config import get_config
|
| 21 |
+
from src.data.preprocessing import DocumentPreprocessor
|
| 22 |
+
from src.data.augmentation import DatasetAwareAugmentation
|
| 23 |
+
from src.features.region_extraction import get_mask_refiner, get_region_extractor
|
| 24 |
+
from src.features.feature_extraction import get_feature_extractor
|
| 25 |
+
from src.training.classifier import ForgeryClassifier
|
| 26 |
+
|
| 27 |
+
# Class names
|
| 28 |
+
CLASS_NAMES = {0: 'Copy-Move', 1: 'Splicing', 2: 'Generation'}
|
| 29 |
+
CLASS_COLORS = {
|
| 30 |
+
0: (255, 0, 0), # Red for Copy-Move
|
| 31 |
+
1: (0, 255, 0), # Green for Splicing
|
| 32 |
+
2: (0, 0, 255) # Blue for Generation
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class ForgeryDetector:
|
| 37 |
+
"""Main forgery detection pipeline"""
|
| 38 |
+
|
| 39 |
+
def __init__(self):
|
| 40 |
+
print("Loading models...")
|
| 41 |
+
|
| 42 |
+
# Load config
|
| 43 |
+
self.config = get_config('config.yaml')
|
| 44 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 45 |
+
|
| 46 |
+
# Load segmentation model
|
| 47 |
+
self.model = get_model(self.config).to(self.device)
|
| 48 |
+
checkpoint = torch.load('models/segmentation_model.pth', map_location=self.device)
|
| 49 |
+
self.model.load_state_dict(checkpoint['model_state_dict'])
|
| 50 |
+
self.model.eval()
|
| 51 |
+
|
| 52 |
+
# Load classifier
|
| 53 |
+
self.classifier = ForgeryClassifier(self.config)
|
| 54 |
+
self.classifier.load('models/classifier')
|
| 55 |
+
|
| 56 |
+
# Initialize components
|
| 57 |
+
self.preprocessor = DocumentPreprocessor(self.config, 'doctamper')
|
| 58 |
+
self.augmentation = DatasetAwareAugmentation(self.config, 'doctamper', is_training=False)
|
| 59 |
+
self.mask_refiner = get_mask_refiner(self.config)
|
| 60 |
+
self.region_extractor = get_region_extractor(self.config)
|
| 61 |
+
self.feature_extractor = get_feature_extractor(self.config, is_text_document=True)
|
| 62 |
+
|
| 63 |
+
print("✓ Models loaded successfully!")
|
| 64 |
+
|
| 65 |
+
def detect(self, image):
|
| 66 |
+
"""
|
| 67 |
+
Detect forgeries in document image
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
image: PIL Image or numpy array
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
overlay_image: Image with detection overlay
|
| 74 |
+
results_json: Detection results as JSON
|
| 75 |
+
"""
|
| 76 |
+
# Convert PIL to numpy
|
| 77 |
+
if isinstance(image, Image.Image):
|
| 78 |
+
image = np.array(image)
|
| 79 |
+
|
| 80 |
+
# Convert to RGB
|
| 81 |
+
if len(image.shape) == 2:
|
| 82 |
+
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
| 83 |
+
elif image.shape[2] == 4:
|
| 84 |
+
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
|
| 85 |
+
|
| 86 |
+
original_image = image.copy()
|
| 87 |
+
|
| 88 |
+
# Preprocess
|
| 89 |
+
preprocessed, _ = self.preprocessor(image, None)
|
| 90 |
+
|
| 91 |
+
# Augment
|
| 92 |
+
augmented = self.augmentation(preprocessed, None)
|
| 93 |
+
image_tensor = augmented['image'].unsqueeze(0).to(self.device)
|
| 94 |
+
|
| 95 |
+
# Run localization
|
| 96 |
+
with torch.no_grad():
|
| 97 |
+
logits, decoder_features = self.model(image_tensor)
|
| 98 |
+
prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]
|
| 99 |
+
|
| 100 |
+
# Refine mask
|
| 101 |
+
binary_mask = (prob_map > 0.5).astype(np.uint8)
|
| 102 |
+
refined_mask = self.mask_refiner.refine(binary_mask, original_size=original_image.shape[:2])
|
| 103 |
+
|
| 104 |
+
# Extract regions
|
| 105 |
+
regions = self.region_extractor.extract(refined_mask, prob_map, original_image)
|
| 106 |
+
|
| 107 |
+
# Classify regions
|
| 108 |
+
results = []
|
| 109 |
+
for region in regions:
|
| 110 |
+
# Extract features
|
| 111 |
+
features = self.feature_extractor.extract(
|
| 112 |
+
preprocessed,
|
| 113 |
+
region['region_mask'],
|
| 114 |
+
[f.cpu() for f in decoder_features]
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
# Classify
|
| 118 |
+
predictions, confidences = self.classifier.predict(features)
|
| 119 |
+
forgery_type = int(predictions[0])
|
| 120 |
+
confidence = float(confidences[0])
|
| 121 |
+
|
| 122 |
+
if confidence > 0.6: # Confidence threshold
|
| 123 |
+
results.append({
|
| 124 |
+
'region_id': region['region_id'],
|
| 125 |
+
'bounding_box': region['bounding_box'],
|
| 126 |
+
'forgery_type': CLASS_NAMES[forgery_type],
|
| 127 |
+
'confidence': confidence
|
| 128 |
+
})
|
| 129 |
+
|
| 130 |
+
# Create visualization
|
| 131 |
+
overlay = self._create_overlay(original_image, results)
|
| 132 |
+
|
| 133 |
+
# Create JSON response
|
| 134 |
+
json_results = {
|
| 135 |
+
'num_detections': len(results),
|
| 136 |
+
'detections': results,
|
| 137 |
+
'model_info': {
|
| 138 |
+
'segmentation_dice': '75%',
|
| 139 |
+
'classifier_accuracy': '92%'
|
| 140 |
+
}
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
return overlay, json_results
|
| 144 |
+
|
| 145 |
+
def _create_overlay(self, image, results):
|
| 146 |
+
"""Create overlay visualization"""
|
| 147 |
+
overlay = image.copy()
|
| 148 |
+
|
| 149 |
+
# Draw bounding boxes and labels
|
| 150 |
+
for result in results:
|
| 151 |
+
bbox = result['bounding_box']
|
| 152 |
+
x, y, w, h = bbox
|
| 153 |
+
|
| 154 |
+
forgery_type = result['forgery_type']
|
| 155 |
+
confidence = result['confidence']
|
| 156 |
+
|
| 157 |
+
# Get color
|
| 158 |
+
forgery_id = [k for k, v in CLASS_NAMES.items() if v == forgery_type][0]
|
| 159 |
+
color = CLASS_COLORS[forgery_id]
|
| 160 |
+
|
| 161 |
+
# Draw rectangle
|
| 162 |
+
cv2.rectangle(overlay, (x, y), (x+w, y+h), color, 2)
|
| 163 |
+
|
| 164 |
+
# Draw label
|
| 165 |
+
label = f"{forgery_type}: {confidence:.1%}"
|
| 166 |
+
label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
|
| 167 |
+
cv2.rectangle(overlay, (x, y-label_size[1]-10), (x+label_size[0], y), color, -1)
|
| 168 |
+
cv2.putText(overlay, label, (x, y-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
|
| 169 |
+
|
| 170 |
+
# Add legend
|
| 171 |
+
if len(results) > 0:
|
| 172 |
+
legend_y = 30
|
| 173 |
+
cv2.putText(overlay, f"Detected {len(results)} forgery region(s)",
|
| 174 |
+
(10, legend_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
|
| 175 |
+
|
| 176 |
+
return overlay
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# Initialize detector
|
| 180 |
+
detector = ForgeryDetector()
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def detect_forgery(image):
|
| 184 |
+
"""Gradio interface function"""
|
| 185 |
+
try:
|
| 186 |
+
overlay, results = detector.detect(image)
|
| 187 |
+
return overlay, json.dumps(results, indent=2)
|
| 188 |
+
except Exception as e:
|
| 189 |
+
return None, f"Error: {str(e)}"
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# Create Gradio interface
|
| 193 |
+
demo = gr.Interface(
|
| 194 |
+
fn=detect_forgery,
|
| 195 |
+
inputs=gr.Image(type="pil", label="Upload Document Image"),
|
| 196 |
+
outputs=[
|
| 197 |
+
gr.Image(type="numpy", label="Detection Result"),
|
| 198 |
+
gr.JSON(label="Detection Details")
|
| 199 |
+
],
|
| 200 |
+
title="📄 Document Forgery Detector",
|
| 201 |
+
description="""
|
| 202 |
+
Upload a document image to detect and classify forgeries.
|
| 203 |
+
|
| 204 |
+
**Supported Forgery Types:**
|
| 205 |
+
- 🔴 Copy-Move: Duplicated regions within the document
|
| 206 |
+
- 🟢 Splicing: Content from different sources
|
| 207 |
+
- 🔵 Generation: AI-generated or synthesized content
|
| 208 |
+
|
| 209 |
+
**Model Performance:**
|
| 210 |
+
- Localization: 75% Dice Score
|
| 211 |
+
- Classification: 92% Accuracy
|
| 212 |
+
""",
|
| 213 |
+
examples=[
|
| 214 |
+
["examples/sample1.jpg"],
|
| 215 |
+
["examples/sample2.jpg"],
|
| 216 |
+
],
|
| 217 |
+
article="""
|
| 218 |
+
### About
|
| 219 |
+
This model uses a hybrid deep learning approach:
|
| 220 |
+
1. **Localization**: MobileNetV3-Small + UNet-Lite (detects WHERE)
|
| 221 |
+
2. **Classification**: LightGBM with hybrid features (detects WHAT)
|
| 222 |
+
|
| 223 |
+
Trained on DocTamper dataset (140K samples).
|
| 224 |
+
""",
|
| 225 |
+
theme=gr.themes.Soft(),
|
| 226 |
+
allow_flagging="never"
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
if __name__ == "__main__":
|
| 231 |
+
demo.launch()
|
config.yaml
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hybrid Document Forgery Detection - Configuration
|
| 2 |
+
|
| 3 |
+
# System Settings
|
| 4 |
+
system:
|
| 5 |
+
device: cuda # cuda or cpu
|
| 6 |
+
num_workers: 0 # Reduced to avoid multiprocessing errors
|
| 7 |
+
pin_memory: true
|
| 8 |
+
seed: 42
|
| 9 |
+
|
| 10 |
+
# Data Settings
|
| 11 |
+
data:
|
| 12 |
+
image_size: 384
|
| 13 |
+
batch_size: 8 # Reduced for 16GB RAM
|
| 14 |
+
num_classes: 3 # copy_move, splicing, text_substitution
|
| 15 |
+
|
| 16 |
+
# Dataset paths
|
| 17 |
+
datasets:
|
| 18 |
+
doctamper:
|
| 19 |
+
path: datasets/DocTamper
|
| 20 |
+
type: lmdb
|
| 21 |
+
has_pixel_mask: true
|
| 22 |
+
min_region_area: 0.001 # 0.1%
|
| 23 |
+
|
| 24 |
+
rtm:
|
| 25 |
+
path: datasets/RealTextManipulation
|
| 26 |
+
type: folder
|
| 27 |
+
has_pixel_mask: true
|
| 28 |
+
min_region_area: 0.0003 # 0.03%
|
| 29 |
+
|
| 30 |
+
casia:
|
| 31 |
+
path: datasets/CASIA 1.0 dataset
|
| 32 |
+
type: folder
|
| 33 |
+
has_pixel_mask: false
|
| 34 |
+
min_region_area: 0.001 # 0.1%
|
| 35 |
+
skip_deskew: true
|
| 36 |
+
skip_denoising: true
|
| 37 |
+
|
| 38 |
+
receipts:
|
| 39 |
+
path: datasets/findit2
|
| 40 |
+
type: folder
|
| 41 |
+
has_pixel_mask: true
|
| 42 |
+
min_region_area: 0.0005 # 0.05%
|
| 43 |
+
|
| 44 |
+
fcd:
|
| 45 |
+
path: datasets/DocTamper/DocTamperV1-FCD
|
| 46 |
+
type: lmdb
|
| 47 |
+
has_pixel_mask: true
|
| 48 |
+
min_region_area: 0.00035 # 0.035% (larger forgeries, keep 99%)
|
| 49 |
+
|
| 50 |
+
scd:
|
| 51 |
+
path: datasets/DocTamper/DocTamperV1-SCD
|
| 52 |
+
type: lmdb
|
| 53 |
+
has_pixel_mask: true
|
| 54 |
+
min_region_area: 0.00009 # 0.009% (small forgeries, keep 91.5%)
|
| 55 |
+
|
| 56 |
+
# Chunked training for DocTamper (RAM constraint)
|
| 57 |
+
chunked_training:
|
| 58 |
+
enabled: true
|
| 59 |
+
dataset: doctamper
|
| 60 |
+
chunks:
|
| 61 |
+
- {start: 0.0, end: 0.25, name: "chunk_1"}
|
| 62 |
+
- {start: 0.25, end: 0.5, name: "chunk_2"}
|
| 63 |
+
- {start: 0.5, end: 0.75, name: "chunk_3"}
|
| 64 |
+
- {start: 0.75, end: 1.0, name: "chunk_4"}
|
| 65 |
+
|
| 66 |
+
# Mixed dataset training (TrainingSet + FCD + SCD)
|
| 67 |
+
mixing_ratios:
|
| 68 |
+
doctamper: 0.70 # 70% TrainingSet (maintains baseline)
|
| 69 |
+
scd: 0.20 # 20% SCD (handles small forgeries, 0.88% avg)
|
| 70 |
+
fcd: 0.10 # 10% FCD (adds diversity, 3.55% avg)
|
| 71 |
+
|
| 72 |
+
# Preprocessing
|
| 73 |
+
preprocessing:
|
| 74 |
+
deskew: true
|
| 75 |
+
normalize: true
|
| 76 |
+
noise_threshold: 15.0 # Laplacian variance threshold
|
| 77 |
+
median_filter_size: 3
|
| 78 |
+
gaussian_sigma: 0.8
|
| 79 |
+
|
| 80 |
+
# Dataset-aware preprocessing
|
| 81 |
+
dataset_specific:
|
| 82 |
+
casia:
|
| 83 |
+
deskew: false
|
| 84 |
+
denoising: false
|
| 85 |
+
|
| 86 |
+
# Augmentation (Training only)
|
| 87 |
+
augmentation:
|
| 88 |
+
enabled: true
|
| 89 |
+
|
| 90 |
+
# Common augmentations
|
| 91 |
+
common:
|
| 92 |
+
- {type: "noise", prob: 0.3}
|
| 93 |
+
- {type: "motion_blur", prob: 0.2}
|
| 94 |
+
- {type: "jpeg_compression", prob: 0.3, quality: [60, 95]}
|
| 95 |
+
- {type: "lighting", prob: 0.3}
|
| 96 |
+
- {type: "perspective", prob: 0.2}
|
| 97 |
+
|
| 98 |
+
# Dataset-specific augmentations
|
| 99 |
+
receipts:
|
| 100 |
+
- {type: "stain", prob: 0.2}
|
| 101 |
+
- {type: "fold", prob: 0.15}
|
| 102 |
+
|
| 103 |
+
# Model Architecture
|
| 104 |
+
model:
|
| 105 |
+
# Encoder
|
| 106 |
+
encoder:
|
| 107 |
+
name: mobilenetv3_small_100
|
| 108 |
+
pretrained: true
|
| 109 |
+
features_only: true
|
| 110 |
+
|
| 111 |
+
# Decoder
|
| 112 |
+
decoder:
|
| 113 |
+
name: unet_lite
|
| 114 |
+
channels: [16, 24, 40, 48, 96] # MobileNetV3-Small feature channels
|
| 115 |
+
upsampling: bilinear
|
| 116 |
+
use_depthwise_separable: true
|
| 117 |
+
|
| 118 |
+
# Output
|
| 119 |
+
output_channels: 1 # Binary forgery mask
|
| 120 |
+
|
| 121 |
+
# Loss Function
|
| 122 |
+
loss:
|
| 123 |
+
# Dataset-aware loss
|
| 124 |
+
use_dice: true # Only for datasets with pixel masks
|
| 125 |
+
bce_weight: 1.0
|
| 126 |
+
dice_weight: 1.0
|
| 127 |
+
|
| 128 |
+
# Training
|
| 129 |
+
training:
|
| 130 |
+
epochs: 30 # Per chunk (increased for single-pass training)
|
| 131 |
+
learning_rate: 0.001 # Higher initial LR for faster convergence
|
| 132 |
+
weight_decay: 0.0001 # Slight increase for better regularization
|
| 133 |
+
|
| 134 |
+
# Optimizer
|
| 135 |
+
optimizer: adamw
|
| 136 |
+
|
| 137 |
+
# Scheduler
|
| 138 |
+
scheduler:
|
| 139 |
+
type: cosine_annealing_warm_restarts
|
| 140 |
+
T_0: 10 # Restart every 10 epochs
|
| 141 |
+
T_mult: 2 # Double restart period each time
|
| 142 |
+
warmup_epochs: 3 # Warmup for first 3 epochs
|
| 143 |
+
min_lr: 0.00001 # End at 1/100th of initial LR
|
| 144 |
+
|
| 145 |
+
# Early stopping
|
| 146 |
+
early_stopping:
|
| 147 |
+
enabled: true
|
| 148 |
+
patience: 10 # Increased to allow more exploration
|
| 149 |
+
min_delta: 0.0005 # Accept smaller improvements (0.05%)
|
| 150 |
+
restore_best_weights: true # Restore best model when stopping
|
| 151 |
+
monitor: val_dice
|
| 152 |
+
mode: max
|
| 153 |
+
|
| 154 |
+
# Checkpointing
|
| 155 |
+
checkpoint:
|
| 156 |
+
save_best: true
|
| 157 |
+
save_every: 5 # Save every 5 epochs
|
| 158 |
+
save_last: true # Also save last checkpoint
|
| 159 |
+
monitor: val_dice
|
| 160 |
+
|
| 161 |
+
# Mask Refinement
|
| 162 |
+
mask_refinement:
|
| 163 |
+
threshold: 0.5
|
| 164 |
+
morphology:
|
| 165 |
+
closing_kernel: 5
|
| 166 |
+
opening_kernel: 3
|
| 167 |
+
|
| 168 |
+
# Adaptive thresholds per dataset
|
| 169 |
+
min_region_area:
|
| 170 |
+
rtm: 0.0003
|
| 171 |
+
receipts: 0.0005
|
| 172 |
+
default: 0.001
|
| 173 |
+
|
| 174 |
+
# Feature Extraction
|
| 175 |
+
features:
|
| 176 |
+
# Deep features
|
| 177 |
+
deep:
|
| 178 |
+
enabled: true
|
| 179 |
+
pooling: gap # Global Average Pooling
|
| 180 |
+
|
| 181 |
+
# Statistical & Shape features
|
| 182 |
+
statistical:
|
| 183 |
+
enabled: true
|
| 184 |
+
features:
|
| 185 |
+
- area
|
| 186 |
+
- perimeter
|
| 187 |
+
- aspect_ratio
|
| 188 |
+
- solidity
|
| 189 |
+
- eccentricity
|
| 190 |
+
- entropy
|
| 191 |
+
|
| 192 |
+
# Frequency-domain features
|
| 193 |
+
frequency:
|
| 194 |
+
enabled: true
|
| 195 |
+
features:
|
| 196 |
+
- dct_coefficients
|
| 197 |
+
- high_frequency_energy
|
| 198 |
+
- wavelet_energy
|
| 199 |
+
|
| 200 |
+
# Noise & ELA features
|
| 201 |
+
noise:
|
| 202 |
+
enabled: true
|
| 203 |
+
features:
|
| 204 |
+
- ela_mean
|
| 205 |
+
- ela_variance
|
| 206 |
+
- noise_residual
|
| 207 |
+
|
| 208 |
+
# OCR-consistency features (text documents only)
|
| 209 |
+
ocr:
|
| 210 |
+
enabled: true
|
| 211 |
+
gated: true # Only for text documents
|
| 212 |
+
features:
|
| 213 |
+
- confidence_deviation
|
| 214 |
+
- spacing_irregularity
|
| 215 |
+
- stroke_width_variation
|
| 216 |
+
|
| 217 |
+
# Feature normalization
|
| 218 |
+
normalization:
|
| 219 |
+
method: standard_scaler
|
| 220 |
+
handle_missing: true
|
| 221 |
+
|
| 222 |
+
# LightGBM Classifier
|
| 223 |
+
classifier:
|
| 224 |
+
model: lightgbm
|
| 225 |
+
params:
|
| 226 |
+
objective: multiclass
|
| 227 |
+
num_class: 3
|
| 228 |
+
boosting_type: gbdt
|
| 229 |
+
num_leaves: 31
|
| 230 |
+
learning_rate: 0.05
|
| 231 |
+
n_estimators: 200
|
| 232 |
+
max_depth: 7
|
| 233 |
+
min_child_samples: 20
|
| 234 |
+
subsample: 0.8
|
| 235 |
+
colsample_bytree: 0.8
|
| 236 |
+
reg_alpha: 0.1
|
| 237 |
+
reg_lambda: 0.1
|
| 238 |
+
random_state: 42
|
| 239 |
+
|
| 240 |
+
# Confidence threshold
|
| 241 |
+
confidence_threshold: 0.6
|
| 242 |
+
|
| 243 |
+
# Metrics
|
| 244 |
+
metrics:
|
| 245 |
+
# Localization metrics (only for datasets with pixel masks)
|
| 246 |
+
localization:
|
| 247 |
+
- iou
|
| 248 |
+
- dice
|
| 249 |
+
- precision
|
| 250 |
+
- recall
|
| 251 |
+
|
| 252 |
+
# Classification metrics
|
| 253 |
+
classification:
|
| 254 |
+
- accuracy
|
| 255 |
+
- f1_score
|
| 256 |
+
- precision
|
| 257 |
+
- recall
|
| 258 |
+
- confusion_matrix
|
| 259 |
+
|
| 260 |
+
# Dataset-aware metric computation
|
| 261 |
+
compute_localization:
|
| 262 |
+
doctamper: true
|
| 263 |
+
rtm: true
|
| 264 |
+
casia: false
|
| 265 |
+
receipts: true
|
| 266 |
+
|
| 267 |
+
# Outputs
|
| 268 |
+
outputs:
|
| 269 |
+
base_dir: outputs
|
| 270 |
+
|
| 271 |
+
# Subdirectories
|
| 272 |
+
checkpoints: outputs/checkpoints
|
| 273 |
+
logs: outputs/logs
|
| 274 |
+
plots: outputs/plots
|
| 275 |
+
results: outputs/results
|
| 276 |
+
|
| 277 |
+
# Visualization
|
| 278 |
+
visualization:
|
| 279 |
+
save_mask: true
|
| 280 |
+
save_overlay: true
|
| 281 |
+
save_json: true
|
| 282 |
+
overlay_alpha: 0.5
|
| 283 |
+
colormap: jet
|
| 284 |
+
|
| 285 |
+
# Deployment
|
| 286 |
+
deployment:
|
| 287 |
+
export_onnx: true
|
| 288 |
+
onnx_path: outputs/model.onnx
|
| 289 |
+
quantization: false
|
| 290 |
+
opset_version: 14
|
| 291 |
+
|
| 292 |
+
# Logging
|
| 293 |
+
logging:
|
| 294 |
+
level: INFO
|
| 295 |
+
tensorboard: true
|
| 296 |
+
csv: true
|
| 297 |
+
console: true
|
models/best_doctamper.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d049ca9d4dc28c8d01519f8faab1ec131a05de877da9703ee5bb0e9322095ad2
|
| 3 |
+
size 14283981
|
models/classifier/classifier_metadata.json
ADDED
|
@@ -0,0 +1,821 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"confidence_threshold": 0.6,
|
| 3 |
+
"class_names": [
|
| 4 |
+
"copy_move",
|
| 5 |
+
"splicing",
|
| 6 |
+
"text_substitution"
|
| 7 |
+
],
|
| 8 |
+
"feature_names": [
|
| 9 |
+
"deep_0",
|
| 10 |
+
"deep_1",
|
| 11 |
+
"deep_2",
|
| 12 |
+
"deep_3",
|
| 13 |
+
"deep_4",
|
| 14 |
+
"deep_5",
|
| 15 |
+
"deep_6",
|
| 16 |
+
"deep_7",
|
| 17 |
+
"deep_8",
|
| 18 |
+
"deep_9",
|
| 19 |
+
"deep_10",
|
| 20 |
+
"deep_11",
|
| 21 |
+
"deep_12",
|
| 22 |
+
"deep_13",
|
| 23 |
+
"deep_14",
|
| 24 |
+
"deep_15",
|
| 25 |
+
"deep_16",
|
| 26 |
+
"deep_17",
|
| 27 |
+
"deep_18",
|
| 28 |
+
"deep_19",
|
| 29 |
+
"deep_20",
|
| 30 |
+
"deep_21",
|
| 31 |
+
"deep_22",
|
| 32 |
+
"deep_23",
|
| 33 |
+
"deep_24",
|
| 34 |
+
"deep_25",
|
| 35 |
+
"deep_26",
|
| 36 |
+
"deep_27",
|
| 37 |
+
"deep_28",
|
| 38 |
+
"deep_29",
|
| 39 |
+
"deep_30",
|
| 40 |
+
"deep_31",
|
| 41 |
+
"deep_32",
|
| 42 |
+
"deep_33",
|
| 43 |
+
"deep_34",
|
| 44 |
+
"deep_35",
|
| 45 |
+
"deep_36",
|
| 46 |
+
"deep_37",
|
| 47 |
+
"deep_38",
|
| 48 |
+
"deep_39",
|
| 49 |
+
"deep_40",
|
| 50 |
+
"deep_41",
|
| 51 |
+
"deep_42",
|
| 52 |
+
"deep_43",
|
| 53 |
+
"deep_44",
|
| 54 |
+
"deep_45",
|
| 55 |
+
"deep_46",
|
| 56 |
+
"deep_47",
|
| 57 |
+
"deep_48",
|
| 58 |
+
"deep_49",
|
| 59 |
+
"deep_50",
|
| 60 |
+
"deep_51",
|
| 61 |
+
"deep_52",
|
| 62 |
+
"deep_53",
|
| 63 |
+
"deep_54",
|
| 64 |
+
"deep_55",
|
| 65 |
+
"deep_56",
|
| 66 |
+
"deep_57",
|
| 67 |
+
"deep_58",
|
| 68 |
+
"deep_59",
|
| 69 |
+
"deep_60",
|
| 70 |
+
"deep_61",
|
| 71 |
+
"deep_62",
|
| 72 |
+
"deep_63",
|
| 73 |
+
"deep_64",
|
| 74 |
+
"deep_65",
|
| 75 |
+
"deep_66",
|
| 76 |
+
"deep_67",
|
| 77 |
+
"deep_68",
|
| 78 |
+
"deep_69",
|
| 79 |
+
"deep_70",
|
| 80 |
+
"deep_71",
|
| 81 |
+
"deep_72",
|
| 82 |
+
"deep_73",
|
| 83 |
+
"deep_74",
|
| 84 |
+
"deep_75",
|
| 85 |
+
"deep_76",
|
| 86 |
+
"deep_77",
|
| 87 |
+
"deep_78",
|
| 88 |
+
"deep_79",
|
| 89 |
+
"deep_80",
|
| 90 |
+
"deep_81",
|
| 91 |
+
"deep_82",
|
| 92 |
+
"deep_83",
|
| 93 |
+
"deep_84",
|
| 94 |
+
"deep_85",
|
| 95 |
+
"deep_86",
|
| 96 |
+
"deep_87",
|
| 97 |
+
"deep_88",
|
| 98 |
+
"deep_89",
|
| 99 |
+
"deep_90",
|
| 100 |
+
"deep_91",
|
| 101 |
+
"deep_92",
|
| 102 |
+
"deep_93",
|
| 103 |
+
"deep_94",
|
| 104 |
+
"deep_95",
|
| 105 |
+
"deep_96",
|
| 106 |
+
"deep_97",
|
| 107 |
+
"deep_98",
|
| 108 |
+
"deep_99",
|
| 109 |
+
"deep_100",
|
| 110 |
+
"deep_101",
|
| 111 |
+
"deep_102",
|
| 112 |
+
"deep_103",
|
| 113 |
+
"deep_104",
|
| 114 |
+
"deep_105",
|
| 115 |
+
"deep_106",
|
| 116 |
+
"deep_107",
|
| 117 |
+
"deep_108",
|
| 118 |
+
"deep_109",
|
| 119 |
+
"deep_110",
|
| 120 |
+
"deep_111",
|
| 121 |
+
"deep_112",
|
| 122 |
+
"deep_113",
|
| 123 |
+
"deep_114",
|
| 124 |
+
"deep_115",
|
| 125 |
+
"deep_116",
|
| 126 |
+
"deep_117",
|
| 127 |
+
"deep_118",
|
| 128 |
+
"deep_119",
|
| 129 |
+
"deep_120",
|
| 130 |
+
"deep_121",
|
| 131 |
+
"deep_122",
|
| 132 |
+
"deep_123",
|
| 133 |
+
"deep_124",
|
| 134 |
+
"deep_125",
|
| 135 |
+
"deep_126",
|
| 136 |
+
"deep_127",
|
| 137 |
+
"deep_128",
|
| 138 |
+
"deep_129",
|
| 139 |
+
"deep_130",
|
| 140 |
+
"deep_131",
|
| 141 |
+
"deep_132",
|
| 142 |
+
"deep_133",
|
| 143 |
+
"deep_134",
|
| 144 |
+
"deep_135",
|
| 145 |
+
"deep_136",
|
| 146 |
+
"deep_137",
|
| 147 |
+
"deep_138",
|
| 148 |
+
"deep_139",
|
| 149 |
+
"deep_140",
|
| 150 |
+
"deep_141",
|
| 151 |
+
"deep_142",
|
| 152 |
+
"deep_143",
|
| 153 |
+
"deep_144",
|
| 154 |
+
"deep_145",
|
| 155 |
+
"deep_146",
|
| 156 |
+
"deep_147",
|
| 157 |
+
"deep_148",
|
| 158 |
+
"deep_149",
|
| 159 |
+
"deep_150",
|
| 160 |
+
"deep_151",
|
| 161 |
+
"deep_152",
|
| 162 |
+
"deep_153",
|
| 163 |
+
"deep_154",
|
| 164 |
+
"deep_155",
|
| 165 |
+
"deep_156",
|
| 166 |
+
"deep_157",
|
| 167 |
+
"deep_158",
|
| 168 |
+
"deep_159",
|
| 169 |
+
"deep_160",
|
| 170 |
+
"deep_161",
|
| 171 |
+
"deep_162",
|
| 172 |
+
"deep_163",
|
| 173 |
+
"deep_164",
|
| 174 |
+
"deep_165",
|
| 175 |
+
"deep_166",
|
| 176 |
+
"deep_167",
|
| 177 |
+
"deep_168",
|
| 178 |
+
"deep_169",
|
| 179 |
+
"deep_170",
|
| 180 |
+
"deep_171",
|
| 181 |
+
"deep_172",
|
| 182 |
+
"deep_173",
|
| 183 |
+
"deep_174",
|
| 184 |
+
"deep_175",
|
| 185 |
+
"deep_176",
|
| 186 |
+
"deep_177",
|
| 187 |
+
"deep_178",
|
| 188 |
+
"deep_179",
|
| 189 |
+
"deep_180",
|
| 190 |
+
"deep_181",
|
| 191 |
+
"deep_182",
|
| 192 |
+
"deep_183",
|
| 193 |
+
"deep_184",
|
| 194 |
+
"deep_185",
|
| 195 |
+
"deep_186",
|
| 196 |
+
"deep_187",
|
| 197 |
+
"deep_188",
|
| 198 |
+
"deep_189",
|
| 199 |
+
"deep_190",
|
| 200 |
+
"deep_191",
|
| 201 |
+
"deep_192",
|
| 202 |
+
"deep_193",
|
| 203 |
+
"deep_194",
|
| 204 |
+
"deep_195",
|
| 205 |
+
"deep_196",
|
| 206 |
+
"deep_197",
|
| 207 |
+
"deep_198",
|
| 208 |
+
"deep_199",
|
| 209 |
+
"deep_200",
|
| 210 |
+
"deep_201",
|
| 211 |
+
"deep_202",
|
| 212 |
+
"deep_203",
|
| 213 |
+
"deep_204",
|
| 214 |
+
"deep_205",
|
| 215 |
+
"deep_206",
|
| 216 |
+
"deep_207",
|
| 217 |
+
"deep_208",
|
| 218 |
+
"deep_209",
|
| 219 |
+
"deep_210",
|
| 220 |
+
"deep_211",
|
| 221 |
+
"deep_212",
|
| 222 |
+
"deep_213",
|
| 223 |
+
"deep_214",
|
| 224 |
+
"deep_215",
|
| 225 |
+
"deep_216",
|
| 226 |
+
"deep_217",
|
| 227 |
+
"deep_218",
|
| 228 |
+
"deep_219",
|
| 229 |
+
"deep_220",
|
| 230 |
+
"deep_221",
|
| 231 |
+
"deep_222",
|
| 232 |
+
"deep_223",
|
| 233 |
+
"deep_224",
|
| 234 |
+
"deep_225",
|
| 235 |
+
"deep_226",
|
| 236 |
+
"deep_227",
|
| 237 |
+
"deep_228",
|
| 238 |
+
"deep_229",
|
| 239 |
+
"deep_230",
|
| 240 |
+
"deep_231",
|
| 241 |
+
"deep_232",
|
| 242 |
+
"deep_233",
|
| 243 |
+
"deep_234",
|
| 244 |
+
"deep_235",
|
| 245 |
+
"deep_236",
|
| 246 |
+
"deep_237",
|
| 247 |
+
"deep_238",
|
| 248 |
+
"deep_239",
|
| 249 |
+
"deep_240",
|
| 250 |
+
"deep_241",
|
| 251 |
+
"deep_242",
|
| 252 |
+
"deep_243",
|
| 253 |
+
"deep_244",
|
| 254 |
+
"deep_245",
|
| 255 |
+
"deep_246",
|
| 256 |
+
"deep_247",
|
| 257 |
+
"deep_248",
|
| 258 |
+
"deep_249",
|
| 259 |
+
"deep_250",
|
| 260 |
+
"deep_251",
|
| 261 |
+
"deep_252",
|
| 262 |
+
"deep_253",
|
| 263 |
+
"deep_254",
|
| 264 |
+
"deep_255",
|
| 265 |
+
"area",
|
| 266 |
+
"perimeter",
|
| 267 |
+
"aspect_ratio",
|
| 268 |
+
"solidity",
|
| 269 |
+
"eccentricity",
|
| 270 |
+
"entropy",
|
| 271 |
+
"dct_mean",
|
| 272 |
+
"dct_std",
|
| 273 |
+
"high_freq_energy",
|
| 274 |
+
"wavelet_cA",
|
| 275 |
+
"wavelet_cH",
|
| 276 |
+
"wavelet_cV",
|
| 277 |
+
"wavelet_cD",
|
| 278 |
+
"wavelet_entropy_H",
|
| 279 |
+
"wavelet_entropy_V",
|
| 280 |
+
"wavelet_entropy_D",
|
| 281 |
+
"ela_mean",
|
| 282 |
+
"ela_var",
|
| 283 |
+
"ela_max",
|
| 284 |
+
"noise_residual_mean",
|
| 285 |
+
"noise_residual_var",
|
| 286 |
+
"ocr_conf_mean",
|
| 287 |
+
"ocr_conf_std",
|
| 288 |
+
"spacing_irregularity",
|
| 289 |
+
"text_density",
|
| 290 |
+
"stroke_mean",
|
| 291 |
+
"stroke_std"
|
| 292 |
+
],
|
| 293 |
+
"feature_importance": [
|
| 294 |
+
151.5697784423828,
|
| 295 |
+
8.955550193786621,
|
| 296 |
+
32.9064998626709,
|
| 297 |
+
151.0029697418213,
|
| 298 |
+
19.174699783325195,
|
| 299 |
+
157.97871017456055,
|
| 300 |
+
45.12229919433594,
|
| 301 |
+
19.72992992401123,
|
| 302 |
+
105.08611106872559,
|
| 303 |
+
0.0,
|
| 304 |
+
148.97894096374512,
|
| 305 |
+
35.71831035614014,
|
| 306 |
+
50.15155029296875,
|
| 307 |
+
71.74272060394287,
|
| 308 |
+
43.958970069885254,
|
| 309 |
+
129.9348111152649,
|
| 310 |
+
27.99122953414917,
|
| 311 |
+
61.592909812927246,
|
| 312 |
+
295.4245676994324,
|
| 313 |
+
61.00736045837402,
|
| 314 |
+
28.548550128936768,
|
| 315 |
+
0.0,
|
| 316 |
+
54.50248908996582,
|
| 317 |
+
93.74169921875,
|
| 318 |
+
120.9488091468811,
|
| 319 |
+
148.32109832763672,
|
| 320 |
+
30.55735969543457,
|
| 321 |
+
59.058170318603516,
|
| 322 |
+
82.7595911026001,
|
| 323 |
+
49.24997901916504,
|
| 324 |
+
0.0,
|
| 325 |
+
23.502280235290527,
|
| 326 |
+
392.399715423584,
|
| 327 |
+
551.6174192428589,
|
| 328 |
+
0.0,
|
| 329 |
+
50.8812894821167,
|
| 330 |
+
60.7820405960083,
|
| 331 |
+
78.98891925811768,
|
| 332 |
+
0.0,
|
| 333 |
+
9.173580169677734,
|
| 334 |
+
631.6932668685913,
|
| 335 |
+
42.097740173339844,
|
| 336 |
+
305.0536642074585,
|
| 337 |
+
416.94709300994873,
|
| 338 |
+
92.70171976089478,
|
| 339 |
+
66.76712036132812,
|
| 340 |
+
1435.1315097808838,
|
| 341 |
+
0.0,
|
| 342 |
+
126.6096019744873,
|
| 343 |
+
111.61981964111328,
|
| 344 |
+
124.68002033233643,
|
| 345 |
+
46.16030025482178,
|
| 346 |
+
12.660099983215332,
|
| 347 |
+
115.48313999176025,
|
| 348 |
+
86.43069076538086,
|
| 349 |
+
16.674290657043457,
|
| 350 |
+
110.49228954315186,
|
| 351 |
+
0.0,
|
| 352 |
+
98.00746059417725,
|
| 353 |
+
98.95538091659546,
|
| 354 |
+
41.432090759277344,
|
| 355 |
+
11.24590015411377,
|
| 356 |
+
65.1699800491333,
|
| 357 |
+
9.251449584960938,
|
| 358 |
+
100.24416923522949,
|
| 359 |
+
109.5842399597168,
|
| 360 |
+
83.83185005187988,
|
| 361 |
+
196.82151079177856,
|
| 362 |
+
0.0,
|
| 363 |
+
455.4096431732178,
|
| 364 |
+
120.69411087036133,
|
| 365 |
+
23.130990028381348,
|
| 366 |
+
18.21858024597168,
|
| 367 |
+
69.65920066833496,
|
| 368 |
+
82.33455085754395,
|
| 369 |
+
0.0,
|
| 370 |
+
82.21379089355469,
|
| 371 |
+
119.78182220458984,
|
| 372 |
+
65.07565069198608,
|
| 373 |
+
53.62262964248657,
|
| 374 |
+
247.53085803985596,
|
| 375 |
+
144.45191097259521,
|
| 376 |
+
38.63272047042847,
|
| 377 |
+
82.24878883361816,
|
| 378 |
+
60.303489685058594,
|
| 379 |
+
8.717499732971191,
|
| 380 |
+
412.6672077178955,
|
| 381 |
+
54.25755023956299,
|
| 382 |
+
0.0,
|
| 383 |
+
23.141600608825684,
|
| 384 |
+
62.88635063171387,
|
| 385 |
+
144.1060814857483,
|
| 386 |
+
352.47050952911377,
|
| 387 |
+
23.701799392700195,
|
| 388 |
+
180.19217205047607,
|
| 389 |
+
74.43132972717285,
|
| 390 |
+
0.0,
|
| 391 |
+
92.36961936950684,
|
| 392 |
+
418.40467262268066,
|
| 393 |
+
163.96015119552612,
|
| 394 |
+
136.4917197227478,
|
| 395 |
+
8.362039566040039,
|
| 396 |
+
10.378399848937988,
|
| 397 |
+
30.465800285339355,
|
| 398 |
+
47.935009479522705,
|
| 399 |
+
28.957390308380127,
|
| 400 |
+
61.46374034881592,
|
| 401 |
+
11.319199562072754,
|
| 402 |
+
142.72890949249268,
|
| 403 |
+
0.0,
|
| 404 |
+
140.48277807235718,
|
| 405 |
+
59.3709602355957,
|
| 406 |
+
9.517510414123535,
|
| 407 |
+
22.945700645446777,
|
| 408 |
+
85.35987043380737,
|
| 409 |
+
25.964330196380615,
|
| 410 |
+
18.778900146484375,
|
| 411 |
+
79.01968955993652,
|
| 412 |
+
74.93959999084473,
|
| 413 |
+
0.0,
|
| 414 |
+
36.94928026199341,
|
| 415 |
+
47.99788188934326,
|
| 416 |
+
84.99461078643799,
|
| 417 |
+
65.24014949798584,
|
| 418 |
+
128.61994075775146,
|
| 419 |
+
71.96449947357178,
|
| 420 |
+
0.0,
|
| 421 |
+
60.59358024597168,
|
| 422 |
+
0.0,
|
| 423 |
+
144.41107177734375,
|
| 424 |
+
119.25859117507935,
|
| 425 |
+
0.0,
|
| 426 |
+
29.235299110412598,
|
| 427 |
+
75.50409030914307,
|
| 428 |
+
0.0,
|
| 429 |
+
0.0,
|
| 430 |
+
133.30608654022217,
|
| 431 |
+
50.813700675964355,
|
| 432 |
+
7.879730224609375,
|
| 433 |
+
80.23723936080933,
|
| 434 |
+
28.72357988357544,
|
| 435 |
+
85.63543939590454,
|
| 436 |
+
88.70749998092651,
|
| 437 |
+
0.0,
|
| 438 |
+
38.14083003997803,
|
| 439 |
+
10.110199928283691,
|
| 440 |
+
223.45562982559204,
|
| 441 |
+
0.0,
|
| 442 |
+
189.3048586845398,
|
| 443 |
+
11.311699867248535,
|
| 444 |
+
87.91403198242188,
|
| 445 |
+
45.88195037841797,
|
| 446 |
+
57.93142032623291,
|
| 447 |
+
621.7998056411743,
|
| 448 |
+
151.6710205078125,
|
| 449 |
+
55.90662956237793,
|
| 450 |
+
310.18284845352173,
|
| 451 |
+
0.0,
|
| 452 |
+
37.39265060424805,
|
| 453 |
+
142.64961051940918,
|
| 454 |
+
86.32072973251343,
|
| 455 |
+
167.73473930358887,
|
| 456 |
+
135.1251916885376,
|
| 457 |
+
67.87245082855225,
|
| 458 |
+
25.777999877929688,
|
| 459 |
+
82.70090961456299,
|
| 460 |
+
160.77113008499146,
|
| 461 |
+
0.0,
|
| 462 |
+
109.31087112426758,
|
| 463 |
+
36.81955051422119,
|
| 464 |
+
21.341699600219727,
|
| 465 |
+
39.508570194244385,
|
| 466 |
+
0.0,
|
| 467 |
+
12.186599731445312,
|
| 468 |
+
52.13583946228027,
|
| 469 |
+
242.86930990219116,
|
| 470 |
+
0.0,
|
| 471 |
+
27.03380012512207,
|
| 472 |
+
11.51550006866455,
|
| 473 |
+
102.65280055999756,
|
| 474 |
+
8.523859977722168,
|
| 475 |
+
105.87909126281738,
|
| 476 |
+
0.0,
|
| 477 |
+
191.5287847518921,
|
| 478 |
+
16.16029930114746,
|
| 479 |
+
43.0986704826355,
|
| 480 |
+
0.0,
|
| 481 |
+
54.736299991607666,
|
| 482 |
+
145.84991836547852,
|
| 483 |
+
62.068660736083984,
|
| 484 |
+
72.52587032318115,
|
| 485 |
+
81.85652828216553,
|
| 486 |
+
25.7001895904541,
|
| 487 |
+
36.71660041809082,
|
| 488 |
+
78.73716068267822,
|
| 489 |
+
145.95945167541504,
|
| 490 |
+
146.47522068023682,
|
| 491 |
+
23.559300422668457,
|
| 492 |
+
39.53977966308594,
|
| 493 |
+
194.42743015289307,
|
| 494 |
+
66.81133842468262,
|
| 495 |
+
0.0,
|
| 496 |
+
156.6984510421753,
|
| 497 |
+
671.7460441589355,
|
| 498 |
+
38.70531988143921,
|
| 499 |
+
0.0,
|
| 500 |
+
356.6153998374939,
|
| 501 |
+
0.0,
|
| 502 |
+
0.0,
|
| 503 |
+
166.1197419166565,
|
| 504 |
+
0.0,
|
| 505 |
+
73.76784992218018,
|
| 506 |
+
82.50808954238892,
|
| 507 |
+
249.50656414031982,
|
| 508 |
+
21.96009922027588,
|
| 509 |
+
43.69997024536133,
|
| 510 |
+
0.0,
|
| 511 |
+
95.96379089355469,
|
| 512 |
+
80.70125961303711,
|
| 513 |
+
0.0,
|
| 514 |
+
0.0,
|
| 515 |
+
31.88983964920044,
|
| 516 |
+
301.3817310333252,
|
| 517 |
+
0.0,
|
| 518 |
+
15.77073049545288,
|
| 519 |
+
396.3671169281006,
|
| 520 |
+
83.96024990081787,
|
| 521 |
+
265.5281705856323,
|
| 522 |
+
47.332489013671875,
|
| 523 |
+
0.0,
|
| 524 |
+
268.84939098358154,
|
| 525 |
+
58.15328025817871,
|
| 526 |
+
31.172239780426025,
|
| 527 |
+
30.765819549560547,
|
| 528 |
+
10.469799995422363,
|
| 529 |
+
16.379559993743896,
|
| 530 |
+
28.163670539855957,
|
| 531 |
+
199.17678022384644,
|
| 532 |
+
112.94913101196289,
|
| 533 |
+
5.905869960784912,
|
| 534 |
+
719.0067505836487,
|
| 535 |
+
157.29250049591064,
|
| 536 |
+
92.6033205986023,
|
| 537 |
+
73.79398918151855,
|
| 538 |
+
24.25756072998047,
|
| 539 |
+
0.0,
|
| 540 |
+
31.15705966949463,
|
| 541 |
+
50.47894048690796,
|
| 542 |
+
73.0004301071167,
|
| 543 |
+
131.88961124420166,
|
| 544 |
+
0.0,
|
| 545 |
+
44.40921926498413,
|
| 546 |
+
59.08494997024536,
|
| 547 |
+
60.722700119018555,
|
| 548 |
+
108.21477127075195,
|
| 549 |
+
78.56892967224121,
|
| 550 |
+
486.87088108062744,
|
| 551 |
+
235.95975875854492,
|
| 552 |
+
1809.188328742981,
|
| 553 |
+
396.9979257583618,
|
| 554 |
+
441.098051071167,
|
| 555 |
+
218.83313035964966,
|
| 556 |
+
265.3398394584656,
|
| 557 |
+
595.3824620246887,
|
| 558 |
+
6126.337133407593,
|
| 559 |
+
3245.946928501129,
|
| 560 |
+
170.21856021881104,
|
| 561 |
+
262.3172616958618,
|
| 562 |
+
98.2627010345459,
|
| 563 |
+
146.45634078979492,
|
| 564 |
+
135.70992946624756,
|
| 565 |
+
34.09130001068115,
|
| 566 |
+
14156.531812667847,
|
| 567 |
+
227.55861043930054,
|
| 568 |
+
121.6160798072815,
|
| 569 |
+
409.0565061569214,
|
| 570 |
+
282.5465121269226,
|
| 571 |
+
481.5555577278137,
|
| 572 |
+
291.560200214386,
|
| 573 |
+
797.986575126648,
|
| 574 |
+
246.7717628479004,
|
| 575 |
+
6129.707794189453,
|
| 576 |
+
957.9258012771606,
|
| 577 |
+
4484.775461196899,
|
| 578 |
+
5722.659900188446,
|
| 579 |
+
393.6506414413452,
|
| 580 |
+
882.6219139099121,
|
| 581 |
+
264.54289960861206,
|
| 582 |
+
79.82537126541138,
|
| 583 |
+
228.20479917526245,
|
| 584 |
+
155.19043970108032,
|
| 585 |
+
319.6992588043213,
|
| 586 |
+
391.5327887535095,
|
| 587 |
+
2005.5544757843018,
|
| 588 |
+
0.0,
|
| 589 |
+
1028.816568851471,
|
| 590 |
+
577.8704214096069,
|
| 591 |
+
159.98183917999268,
|
| 592 |
+
138.31745052337646,
|
| 593 |
+
115.26242113113403,
|
| 594 |
+
117.50687980651855,
|
| 595 |
+
0.0,
|
| 596 |
+
270.78229904174805,
|
| 597 |
+
300.6347818374634,
|
| 598 |
+
164.85750007629395,
|
| 599 |
+
542.5208883285522,
|
| 600 |
+
10002.710669994354,
|
| 601 |
+
502.5058374404907,
|
| 602 |
+
6619.406281471252,
|
| 603 |
+
194.39686965942383,
|
| 604 |
+
0.0,
|
| 605 |
+
239.30037021636963,
|
| 606 |
+
129.93587112426758,
|
| 607 |
+
149.23295974731445,
|
| 608 |
+
57.12141132354736,
|
| 609 |
+
152.30589962005615,
|
| 610 |
+
590.8979144096375,
|
| 611 |
+
125.51728057861328,
|
| 612 |
+
216.1852297782898,
|
| 613 |
+
4445.603507041931,
|
| 614 |
+
0.0,
|
| 615 |
+
97.60689973831177,
|
| 616 |
+
497.5633420944214,
|
| 617 |
+
699.1335229873657,
|
| 618 |
+
159.68335962295532,
|
| 619 |
+
127.93899154663086,
|
| 620 |
+
148.00423860549927,
|
| 621 |
+
385.3561215400696,
|
| 622 |
+
1255.3204145431519,
|
| 623 |
+
170.33005905151367,
|
| 624 |
+
564.577874660492,
|
| 625 |
+
1513.99400806427,
|
| 626 |
+
254.163161277771,
|
| 627 |
+
782.5869626998901,
|
| 628 |
+
166.38124132156372,
|
| 629 |
+
4800.666547775269,
|
| 630 |
+
271.63431215286255,
|
| 631 |
+
225.10281944274902,
|
| 632 |
+
674.5281610488892,
|
| 633 |
+
198.04610967636108,
|
| 634 |
+
4262.1786432266235,
|
| 635 |
+
0.0,
|
| 636 |
+
0.0,
|
| 637 |
+
749.2932777404785,
|
| 638 |
+
50.16440010070801,
|
| 639 |
+
350.71588039398193,
|
| 640 |
+
169.4644889831543,
|
| 641 |
+
3843.8212938308716,
|
| 642 |
+
0.0,
|
| 643 |
+
0.0,
|
| 644 |
+
1463.2607378959656,
|
| 645 |
+
0.0,
|
| 646 |
+
914.5419778823853,
|
| 647 |
+
213.03434944152832,
|
| 648 |
+
32.90106964111328,
|
| 649 |
+
119.6264705657959,
|
| 650 |
+
137.204270362854,
|
| 651 |
+
359.72862100601196,
|
| 652 |
+
75.62465047836304,
|
| 653 |
+
446.62164974212646,
|
| 654 |
+
105.61136054992676,
|
| 655 |
+
2787.228641986847,
|
| 656 |
+
311.6961917877197,
|
| 657 |
+
156.06305074691772,
|
| 658 |
+
1498.6027584075928,
|
| 659 |
+
185.69973182678223,
|
| 660 |
+
147.8509397506714,
|
| 661 |
+
12.531700134277344,
|
| 662 |
+
0.0,
|
| 663 |
+
192.53613948822021,
|
| 664 |
+
424.5432171821594,
|
| 665 |
+
259.268039226532,
|
| 666 |
+
175.13502979278564,
|
| 667 |
+
281.5383825302124,
|
| 668 |
+
299.1759967803955,
|
| 669 |
+
227.893488407135,
|
| 670 |
+
136.72871112823486,
|
| 671 |
+
416.3120012283325,
|
| 672 |
+
115.03175830841064,
|
| 673 |
+
0.0,
|
| 674 |
+
144.02852058410645,
|
| 675 |
+
208.2749309539795,
|
| 676 |
+
160.34006214141846,
|
| 677 |
+
109.58282947540283,
|
| 678 |
+
1500.150812625885,
|
| 679 |
+
4945.450592041016,
|
| 680 |
+
2852.855231285095,
|
| 681 |
+
881.7318058013916,
|
| 682 |
+
397.0553340911865,
|
| 683 |
+
315.55763959884644,
|
| 684 |
+
2086.7152404785156,
|
| 685 |
+
1611.37087059021,
|
| 686 |
+
2103.3109679222107,
|
| 687 |
+
3135.3377957344055,
|
| 688 |
+
2692.6771001815796,
|
| 689 |
+
4584.85631608963,
|
| 690 |
+
1700.0699429512024,
|
| 691 |
+
883.6995916366577,
|
| 692 |
+
33464.33708667755,
|
| 693 |
+
574.8801603317261,
|
| 694 |
+
2229.160650253296,
|
| 695 |
+
379.5017247200012,
|
| 696 |
+
905.5721397399902,
|
| 697 |
+
493.963942527771,
|
| 698 |
+
4049.96994638443,
|
| 699 |
+
189.95257091522217,
|
| 700 |
+
61.00449848175049,
|
| 701 |
+
450.8264832496643,
|
| 702 |
+
398.1711621284485,
|
| 703 |
+
38847.667073726654,
|
| 704 |
+
1835.184115409851,
|
| 705 |
+
2697.096595287323,
|
| 706 |
+
4710.6771783828735,
|
| 707 |
+
5588.210665225983,
|
| 708 |
+
1004.0054593086243,
|
| 709 |
+
652.6680641174316,
|
| 710 |
+
2031.7795896530151,
|
| 711 |
+
367.2168278694153,
|
| 712 |
+
2698.1613121032715,
|
| 713 |
+
591.61465883255,
|
| 714 |
+
448.26813650131226,
|
| 715 |
+
849.9976563453674,
|
| 716 |
+
8368.735646724701,
|
| 717 |
+
414.3280692100525,
|
| 718 |
+
3544.0216879844666,
|
| 719 |
+
679.3534464836121,
|
| 720 |
+
247.58060026168823,
|
| 721 |
+
402.0281286239624,
|
| 722 |
+
5822.276999950409,
|
| 723 |
+
1743.6888279914856,
|
| 724 |
+
2081.8095812797546,
|
| 725 |
+
1696.2736263275146,
|
| 726 |
+
197.28233861923218,
|
| 727 |
+
3321.6009736061096,
|
| 728 |
+
2298.3414697647095,
|
| 729 |
+
2910.3161034584045,
|
| 730 |
+
296.4575996398926,
|
| 731 |
+
14755.747835159302,
|
| 732 |
+
6977.302089691162,
|
| 733 |
+
3608.7710394859314,
|
| 734 |
+
289.08115005493164,
|
| 735 |
+
2645.5259099006653,
|
| 736 |
+
158.54701232910156,
|
| 737 |
+
490.0809507369995,
|
| 738 |
+
1880.1874709129333,
|
| 739 |
+
1493.8953075408936,
|
| 740 |
+
609.5897555351257,
|
| 741 |
+
462.8165135383606,
|
| 742 |
+
243.31624794006348,
|
| 743 |
+
150.1076784133911,
|
| 744 |
+
6197.5719475746155,
|
| 745 |
+
1036.8616194725037,
|
| 746 |
+
5302.397746086121,
|
| 747 |
+
1388.753752708435,
|
| 748 |
+
2091.038170814514,
|
| 749 |
+
785.7442808151245,
|
| 750 |
+
377.4342908859253,
|
| 751 |
+
3640.3371028900146,
|
| 752 |
+
1029.8467602729797,
|
| 753 |
+
296.86861085891724,
|
| 754 |
+
1221.5854263305664,
|
| 755 |
+
535.2803363800049,
|
| 756 |
+
2508.307864189148,
|
| 757 |
+
3831.0581674575806,
|
| 758 |
+
2263.3348484039307,
|
| 759 |
+
926.5323433876038,
|
| 760 |
+
8959.179275035858,
|
| 761 |
+
309.04264068603516,
|
| 762 |
+
1767.5786666870117,
|
| 763 |
+
2107.6189522743225,
|
| 764 |
+
155.21375036239624,
|
| 765 |
+
378.6039876937866,
|
| 766 |
+
2220.862048149109,
|
| 767 |
+
1505.2828221321106,
|
| 768 |
+
517.8384418487549,
|
| 769 |
+
4313.928272247314,
|
| 770 |
+
342.4098491668701,
|
| 771 |
+
1310.0776271820068,
|
| 772 |
+
434.5597867965698,
|
| 773 |
+
2071.2271361351013,
|
| 774 |
+
0.0,
|
| 775 |
+
8595.476936340332,
|
| 776 |
+
202.46072053909302,
|
| 777 |
+
366.71736097335815,
|
| 778 |
+
7074.809521198273,
|
| 779 |
+
6.880340099334717,
|
| 780 |
+
1959.3085498809814,
|
| 781 |
+
636.0715098381042,
|
| 782 |
+
9.84004020690918,
|
| 783 |
+
386.9805417060852,
|
| 784 |
+
2382.4822087287903,
|
| 785 |
+
2317.9521684646606,
|
| 786 |
+
2793.7392020225525,
|
| 787 |
+
1188.6612939834595,
|
| 788 |
+
933.1099715232849,
|
| 789 |
+
4565.712460041046,
|
| 790 |
+
14641.29742860794,
|
| 791 |
+
15552.311092853546,
|
| 792 |
+
56185.89445209503,
|
| 793 |
+
97331.36661911011,
|
| 794 |
+
87548.01149320602,
|
| 795 |
+
521853.7248663902,
|
| 796 |
+
2643.261353492737,
|
| 797 |
+
20220.717566013336,
|
| 798 |
+
79148.93348503113,
|
| 799 |
+
17449.243332386017,
|
| 800 |
+
13258.27445936203,
|
| 801 |
+
6109.533164024353,
|
| 802 |
+
6781.56981420517,
|
| 803 |
+
3942.6140484809875,
|
| 804 |
+
8469.07410955429,
|
| 805 |
+
40318.94767665863,
|
| 806 |
+
156345.23027658463,
|
| 807 |
+
12197.998657226562,
|
| 808 |
+
22888.345291614532,
|
| 809 |
+
10946.28234910965,
|
| 810 |
+
204263.674387455,
|
| 811 |
+
229631.36437797546,
|
| 812 |
+
1945.9702520370483,
|
| 813 |
+
3069.6773653030396,
|
| 814 |
+
6425.405041217804,
|
| 815 |
+
508.55564069747925,
|
| 816 |
+
8993.14672756195,
|
| 817 |
+
0.0,
|
| 818 |
+
0.0,
|
| 819 |
+
0.0
|
| 820 |
+
]
|
| 821 |
+
}
|
models/classifier/lightgbm_model.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/classifier/scaler.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:347b85c4f3e4bcbda0599f607a1ad5194c01655baca73b6e2ee72a9ba50dcf84
|
| 3 |
+
size 13207
|
src/__init__.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hybrid Document Forgery Detection & Localization System
|
| 3 |
+
|
| 4 |
+
A robust hybrid (Deep Learning + Classical ML) system for multi-type
|
| 5 |
+
document forgery detection and localization.
|
| 6 |
+
|
| 7 |
+
Architecture:
|
| 8 |
+
- Deep Learning: MobileNetV3-Small + UNet-Lite for pixel-level localization
|
| 9 |
+
- Classical ML: LightGBM for interpretable forgery classification
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
__version__ = "1.0.0"
|
| 13 |
+
|
| 14 |
+
from .config import get_config
|
| 15 |
+
from .models import get_model, get_loss_function
|
| 16 |
+
from .data import get_dataset
|
| 17 |
+
from .features import get_feature_extractor, get_mask_refiner, get_region_extractor
|
| 18 |
+
from .training import get_trainer, get_metrics_tracker
|
| 19 |
+
from .inference import get_pipeline
|
| 20 |
+
|
| 21 |
+
__all__ = [
|
| 22 |
+
'get_config',
|
| 23 |
+
'get_model',
|
| 24 |
+
'get_loss_function',
|
| 25 |
+
'get_dataset',
|
| 26 |
+
'get_feature_extractor',
|
| 27 |
+
'get_mask_refiner',
|
| 28 |
+
'get_region_extractor',
|
| 29 |
+
'get_trainer',
|
| 30 |
+
'get_metrics_tracker',
|
| 31 |
+
'get_pipeline'
|
| 32 |
+
]
|
src/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (993 Bytes). View file
|
|
|
src/config/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration module"""
|
| 2 |
+
|
| 3 |
+
from .config_loader import Config, get_config
|
| 4 |
+
|
| 5 |
+
__all__ = ['Config', 'get_config']
|
src/config/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (291 Bytes). View file
|
|
|
src/config/__pycache__/config_loader.cpython-312.pyc
ADDED
|
Binary file (5.42 kB). View file
|
|
|
src/config/config_loader.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration loader for Hybrid Document Forgery Detection System
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import yaml
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Dict, Any
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Config:
    """Loads a YAML configuration file and exposes typed convenience accessors.

    Values are retrieved with dot-notation keys (e.g. ``model.encoder.name``);
    the per-dataset helpers wrap the ``data.datasets.<name>`` section.
    """

    def __init__(self, config_path: str = "config.yaml"):
        """
        Load configuration from YAML file

        Args:
            config_path: Path to configuration file

        Raises:
            FileNotFoundError: If the configuration file does not exist.
        """
        self.config_path = Path(config_path)
        self.config = self._load_config()

    def _load_config(self) -> Dict[str, Any]:
        """Parse the YAML configuration file into a plain dictionary."""
        if not self.config_path.exists():
            raise FileNotFoundError(f"Config file not found: {self.config_path}")

        # safe_load avoids executing arbitrary YAML tags from the config file.
        with open(self.config_path, 'r', encoding='utf-8') as f:
            return yaml.safe_load(f)

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get configuration value using dot notation

        Args:
            key: Configuration key (e.g., 'model.encoder.name')
            default: Default value if key not found

        Returns:
            Configuration value, or ``default`` when any path segment
            is missing or the intermediate value is not a dict.
        """
        value = self.config
        for part in key.split('.'):
            if isinstance(value, dict) and part in value:
                value = value[part]
            else:
                return default
        return value

    def get_dataset_config(self, dataset_name: str) -> Dict[str, Any]:
        """
        Get dataset-specific configuration

        Args:
            dataset_name: Dataset name (doctamper, rtm, casia, receipts)

        Returns:
            Dataset configuration dictionary; {} when the dataset (or the
            whole 'data.datasets' section) is absent.
        """
        # Use the dot-notation get() so a config missing 'data'/'datasets'
        # yields {} instead of raising KeyError (previous code indexed
        # self.config['data']['datasets'] directly).
        datasets = self.get('data.datasets', {}) or {}
        return datasets.get(dataset_name, {})

    def has_pixel_mask(self, dataset_name: str) -> bool:
        """Check if dataset has pixel-level masks"""
        return self.get_dataset_config(dataset_name).get('has_pixel_mask', False)

    def should_skip_deskew(self, dataset_name: str) -> bool:
        """Check if deskewing should be skipped for dataset"""
        return self.get_dataset_config(dataset_name).get('skip_deskew', False)

    def should_skip_denoising(self, dataset_name: str) -> bool:
        """Check if denoising should be skipped for dataset"""
        return self.get_dataset_config(dataset_name).get('skip_denoising', False)

    def get_min_region_area(self, dataset_name: str) -> float:
        """Get minimum region area threshold for dataset (defaults to 0.001)."""
        return self.get_dataset_config(dataset_name).get('min_region_area', 0.001)

    def should_compute_localization_metrics(self, dataset_name: str) -> bool:
        """Check if localization metrics should be computed for dataset"""
        # Robust to a config without a 'metrics' section (previously a
        # direct self.config['metrics'] lookup that could raise KeyError).
        compute_config = self.get('metrics.compute_localization', {}) or {}
        return compute_config.get(dataset_name, False)

    def __getitem__(self, key: str) -> Any:
        """Dictionary-style access; returns None (not KeyError) for missing keys."""
        return self.get(key)

    def __repr__(self) -> str:
        return f"Config(path={self.config_path})"
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# Module-level cache holding the process-wide Config singleton.
_config = None


def get_config(config_path: str = "config.yaml") -> Config:
    """
    Return the shared global :class:`Config` instance, creating it on first use.

    Args:
        config_path: Path to configuration file. NOTE(review): only honored
            on the very first call — once the singleton is cached, later
            calls return it regardless of the path passed; confirm callers
            never rely on switching config files at runtime.

    Returns:
        The cached Config instance.
    """
    global _config
    if _config is not None:
        return _config
    _config = Config(config_path)
    return _config
|
src/data/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data module"""
|
| 2 |
+
|
| 3 |
+
from .preprocessing import DocumentPreprocessor, preprocess_image
|
| 4 |
+
from .augmentation import DatasetAwareAugmentation, get_augmentation
|
| 5 |
+
from .datasets import (
|
| 6 |
+
DocTamperDataset,
|
| 7 |
+
RTMDataset,
|
| 8 |
+
CASIADataset,
|
| 9 |
+
ReceiptsDataset,
|
| 10 |
+
get_dataset
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
__all__ = [
|
| 14 |
+
'DocumentPreprocessor',
|
| 15 |
+
'preprocess_image',
|
| 16 |
+
'DatasetAwareAugmentation',
|
| 17 |
+
'get_augmentation',
|
| 18 |
+
'DocTamperDataset',
|
| 19 |
+
'RTMDataset',
|
| 20 |
+
'CASIADataset',
|
| 21 |
+
'ReceiptsDataset',
|
| 22 |
+
'get_dataset'
|
| 23 |
+
]
|
src/data/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (572 Bytes). View file
|
|
|
src/data/__pycache__/augmentation.cpython-312.pyc
ADDED
|
Binary file (5.94 kB). View file
|
|
|
src/data/__pycache__/datasets.cpython-312.pyc
ADDED
|
Binary file (21.2 kB). View file
|
|
|
src/data/__pycache__/preprocessing.cpython-312.pyc
ADDED
|
Binary file (9.38 kB). View file
|
|
|
src/data/augmentation.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dataset-aware augmentation for training
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
import albumentations as A
|
| 8 |
+
from albumentations.pytorch import ToTensorV2
|
| 9 |
+
from typing import Dict, Any, Optional
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DatasetAwareAugmentation:
|
| 13 |
+
"""Dataset-aware augmentation pipeline"""
|
| 14 |
+
|
| 15 |
+
def __init__(self, config, dataset_name: str, is_training: bool = True):
|
| 16 |
+
"""
|
| 17 |
+
Initialize augmentation pipeline
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
config: Configuration object
|
| 21 |
+
dataset_name: Dataset name
|
| 22 |
+
is_training: Whether in training mode
|
| 23 |
+
"""
|
| 24 |
+
self.config = config
|
| 25 |
+
self.dataset_name = dataset_name
|
| 26 |
+
self.is_training = is_training
|
| 27 |
+
|
| 28 |
+
# Build augmentation pipeline
|
| 29 |
+
self.transform = self._build_transform()
|
| 30 |
+
|
| 31 |
+
def _build_transform(self) -> A.Compose:
|
| 32 |
+
"""Build albumentations transform pipeline"""
|
| 33 |
+
|
| 34 |
+
transforms = []
|
| 35 |
+
|
| 36 |
+
if self.is_training and self.config.get('augmentation.enabled', True):
|
| 37 |
+
# Common augmentations
|
| 38 |
+
common_augs = self.config.get('augmentation.common', [])
|
| 39 |
+
|
| 40 |
+
for aug_config in common_augs:
|
| 41 |
+
aug_type = aug_config.get('type')
|
| 42 |
+
prob = aug_config.get('prob', 0.5)
|
| 43 |
+
|
| 44 |
+
if aug_type == 'noise':
|
| 45 |
+
transforms.append(
|
| 46 |
+
A.GaussNoise(var_limit=(10.0, 50.0), p=prob)
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
elif aug_type == 'motion_blur':
|
| 50 |
+
transforms.append(
|
| 51 |
+
A.MotionBlur(blur_limit=7, p=prob)
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
elif aug_type == 'jpeg_compression':
|
| 55 |
+
quality_range = aug_config.get('quality', [60, 95])
|
| 56 |
+
transforms.append(
|
| 57 |
+
A.ImageCompression(quality_lower=quality_range[0],
|
| 58 |
+
quality_upper=quality_range[1],
|
| 59 |
+
p=prob)
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
elif aug_type == 'lighting':
|
| 63 |
+
transforms.append(
|
| 64 |
+
A.OneOf([
|
| 65 |
+
A.RandomBrightnessContrast(p=1.0),
|
| 66 |
+
A.RandomGamma(p=1.0),
|
| 67 |
+
A.HueSaturationValue(p=1.0),
|
| 68 |
+
], p=prob)
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
elif aug_type == 'perspective':
|
| 72 |
+
transforms.append(
|
| 73 |
+
A.Perspective(scale=(0.02, 0.05), p=prob)
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
# Dataset-specific augmentations
|
| 77 |
+
if self.dataset_name == 'receipts':
|
| 78 |
+
receipt_augs = self.config.get('augmentation.receipts', [])
|
| 79 |
+
|
| 80 |
+
for aug_config in receipt_augs:
|
| 81 |
+
aug_type = aug_config.get('type')
|
| 82 |
+
prob = aug_config.get('prob', 0.5)
|
| 83 |
+
|
| 84 |
+
if aug_type == 'stain':
|
| 85 |
+
# Simulate stains using random blobs
|
| 86 |
+
transforms.append(
|
| 87 |
+
A.RandomShadow(
|
| 88 |
+
shadow_roi=(0, 0, 1, 1),
|
| 89 |
+
num_shadows_lower=1,
|
| 90 |
+
num_shadows_upper=3,
|
| 91 |
+
shadow_dimension=5,
|
| 92 |
+
p=prob
|
| 93 |
+
)
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
elif aug_type == 'fold':
|
| 97 |
+
# Simulate folds using grid distortion
|
| 98 |
+
transforms.append(
|
| 99 |
+
A.GridDistortion(num_steps=5, distort_limit=0.1, p=prob)
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# Always convert to tensor
|
| 103 |
+
transforms.append(ToTensorV2())
|
| 104 |
+
|
| 105 |
+
return A.Compose(
|
| 106 |
+
transforms,
|
| 107 |
+
additional_targets={'mask': 'mask'}
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
def __call__(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Dict[str, Any]:
    """
    Apply the augmentation pipeline to an image (and optional mask).

    Args:
        image: Input image (H, W, 3), float32 in [0, 1]
        mask: Optional mask (H, W), uint8 in {0, 1}

    Returns:
        Dictionary with 'image' (float tensor in [0, 1]) and, when a mask
        was given, 'mask' as a (1, H, W) float tensor.
    """
    # Albumentations works on uint8 arrays, so scale up before transforming.
    img_u8 = (image * 255).astype(np.uint8)

    if mask is None:
        # Image-only path: transform, then scale back to [0, 1].
        result = self.transform(image=img_u8)
        result['image'] = result['image'].float() / 255.0
        return result

    # Joint image+mask path: the mask is scaled to {0, 255} so spatial
    # transforms treat it like an image, then rescaled and given a
    # leading channel dimension on the way out.
    msk_u8 = (mask * 255).astype(np.uint8)
    result = self.transform(image=img_u8, mask=msk_u8)
    result['image'] = result['image'].float() / 255.0
    result['mask'] = (result['mask'].float() / 255.0).unsqueeze(0)
    return result
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def get_augmentation(config, dataset_name: str, is_training: bool = True) -> DatasetAwareAugmentation:
    """
    Build an augmentation pipeline for the given dataset and split.

    Args:
        config: Configuration object
        dataset_name: Dataset name
        is_training: Whether in training mode

    Returns:
        A ready-to-use :class:`DatasetAwareAugmentation` pipeline.
    """
    pipeline = DatasetAwareAugmentation(config, dataset_name, is_training)
    return pipeline
|
src/data/datasets.py
ADDED
|
@@ -0,0 +1,541 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dataset loaders for document forgery detection
|
| 3 |
+
Implements Critical Fix #7: Image-level train/test splits
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import lmdb
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from torch.utils.data import Dataset
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Tuple, Optional, List
|
| 14 |
+
import json
|
| 15 |
+
from PIL import Image
|
| 16 |
+
|
| 17 |
+
from .preprocessing import DocumentPreprocessor
|
| 18 |
+
from .augmentation import DatasetAwareAugmentation
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class DocTamperDataset(Dataset):
    """
    DocTamper dataset loader (LMDB-based).

    Implements chunked loading for RAM constraints (Critical Fix #7:
    image-level chunk boundaries) and lazy LMDB initialization so the
    dataset object can be pickled into DataLoader worker processes.
    """

    def __init__(self,
                 config,
                 split: str = 'train',
                 chunk_start: float = 0.0,
                 chunk_end: float = 1.0):
        """
        Initialize DocTamper dataset.

        Args:
            config: Configuration object
            split: 'train', 'val', or 'test'
            chunk_start: Start ratio for chunked training (0.0 to 1.0)
            chunk_end: End ratio for chunked training (0.0 to 1.0)
        """
        self.config = config
        self.split = split
        self.dataset_name = 'doctamper'

        # Resolve the dataset root from the project configuration.
        dataset_config = config.get_dataset_config(self.dataset_name)
        self.data_path = Path(dataset_config['path'])

        # Map split to the on-disk folder names; 'val' and 'test' both use
        # the official testing set, anything else falls back to training.
        if split in ('val', 'test'):
            lmdb_folder = 'DocTamperV1-TestingSet'
        else:
            lmdb_folder = 'DocTamperV1-TrainingSet'

        self.lmdb_path = str(self.data_path / lmdb_folder)

        if not Path(self.lmdb_path).exists():
            raise FileNotFoundError(f"LMDB folder not found: {self.lmdb_path}")

        # LAZY INITIALIZATION: do not keep an LMDB env open here — env
        # handles are not picklable, which breaks DataLoader workers.
        # Open temporarily just to count entries (image + label per sample).
        temp_env = lmdb.open(self.lmdb_path, readonly=True, lock=False)
        try:
            with temp_env.begin() as txn:
                self.length = txn.stat()['entries'] // 2
        finally:
            temp_env.close()

        # The real env is opened lazily on first __getitem__ (see `env`).
        self._env = None

        # Critical Fix #7: chunk at the image level (not region level) so a
        # sample never straddles two chunks.
        self.chunk_start = int(self.length * chunk_start)
        self.chunk_end = int(self.length * chunk_end)
        self.chunk_length = self.chunk_end - self.chunk_start

        print(f"DocTamper {split}: Total={self.length}, "
              f"Chunk=[{self.chunk_start}:{self.chunk_end}], "
              f"Length={self.chunk_length}")

        # Preprocessing / augmentation are dataset-aware.
        self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
        self.augmentation = DatasetAwareAugmentation(
            config,
            self.dataset_name,
            is_training=(split == 'train')
        )

    @property
    def env(self):
        """Lazy LMDB environment initialization for multiprocessing compatibility."""
        if self._env is None:
            self._env = lmdb.open(self.lmdb_path, readonly=True, lock=False,
                                  max_readers=32, readahead=False)
        return self._env

    def __len__(self) -> int:
        return self.chunk_length

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Get item from dataset.

        Tries the requested sample and skips forward (wrapping within the
        chunk) when a key is missing or undecodable; after `max_attempts`
        failures a zero-valued dummy sample is returned instead of raising.

        Args:
            idx: Index within the chunk

        Returns:
            image: (3, H, W) tensor
            mask: (1, H, W) tensor
            metadata: Dictionary with additional info
        """
        max_attempts = 10
        original_idx = idx

        # Initialize so the fallback below works no matter which retry path
        # (missing key, decode failure, or exception) exhausted the attempts.
        # BUG FIX: previously `image`/`mask` were only bound on success or on
        # an exception during the *last* attempt; ten consecutive missing
        # keys left them unbound and raised NameError.
        image = None
        mask = None
        global_idx = self.chunk_start + original_idx

        for attempt in range(max_attempts):
            try:
                # Map chunk-relative index to the global LMDB index.
                global_idx = self.chunk_start + idx

                with self.env.begin() as txn:
                    # DocTamper key format: image-XXXXXXXXX / label-XXXXXXXXX
                    # (9 digits, dash separator).
                    img_key = f'image-{global_idx:09d}'.encode()
                    label_key = f'label-{global_idx:09d}'.encode()

                    img_buf = txn.get(img_key)
                    label_buf = txn.get(label_key)

                if img_buf is None:
                    # Sample missing — try the next index within the chunk.
                    idx = (idx + 1) % self.chunk_length
                    continue

                # Decode image (BGR, as produced by cv2.imdecode).
                img_array = np.frombuffer(img_buf, dtype=np.uint8)
                image = cv2.imdecode(img_array, cv2.IMREAD_COLOR)

                if image is None:
                    # Corrupt/undecodable image — try the next index.
                    idx = (idx + 1) % self.chunk_length
                    continue

                # Decode label/mask; fall back to an all-authentic mask when
                # the label is absent or not a decodable image.
                if label_buf is not None:
                    label_array = np.frombuffer(label_buf, dtype=np.uint8)
                    mask = cv2.imdecode(label_array, cv2.IMREAD_GRAYSCALE)
                    if mask is None:
                        mask = np.zeros(image.shape[:2], dtype=np.uint8)
                else:
                    mask = np.zeros(image.shape[:2], dtype=np.uint8)

                # Successfully loaded — stop retrying.
                break

            except Exception:
                # Read/decode blew up — advance and retry.
                idx = (idx + 1) % self.chunk_length

        if image is None:
            # Every attempt failed: emit a dummy sample rather than crash a
            # whole training epoch. uint8 zeros match the normal decode path.
            print(f"Warning: Could not load sample at idx {original_idx}, creating dummy sample")
            image = np.zeros((384, 384, 3), dtype=np.uint8)
            mask = np.zeros((384, 384), dtype=np.uint8)
            global_idx = self.chunk_start + original_idx

        # Preprocess (deskew/resize/normalize, dataset-aware).
        image, mask = self.preprocessor(image, mask)

        # Augment (train-time transforms + tensor conversion).
        augmented = self.augmentation(image, mask)
        image = augmented['image']
        mask = augmented['mask']

        metadata = {
            'dataset': self.dataset_name,
            'index': global_idx,
            'has_pixel_mask': True
        }

        return image, mask, metadata

    def __del__(self):
        """Close the LMDB environment if it was ever opened."""
        if hasattr(self, '_env') and self._env is not None:
            self._env.close()
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class RTMDataset(Dataset):
    """Real Text Manipulation dataset loader (VOC-style folder layout)."""

    def __init__(self, config, split: str = 'train'):
        """
        Initialize RTM dataset.

        Args:
            config: Configuration object
            split: 'train' or 'test'
        """
        self.config = config
        self.split = split
        self.dataset_name = 'rtm'

        # Resolve the dataset root from the project configuration.
        dataset_config = config.get_dataset_config(self.dataset_name)
        self.data_path = Path(dataset_config['path'])

        # Load split file; skip blank lines so a trailing newline does not
        # produce an empty image id (which would later fail to load).
        split_file = self.data_path / f'{split}.txt'
        with open(split_file, 'r') as f:
            self.image_ids = [line.strip() for line in f if line.strip()]

        self.images_dir = self.data_path / 'JPEGImages'
        self.masks_dir = self.data_path / 'SegmentationClass'

        print(f"RTM {split}: {len(self.image_ids)} images")

        # Preprocessing / augmentation are dataset-aware.
        self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
        self.augmentation = DatasetAwareAugmentation(
            config,
            self.dataset_name,
            is_training=(split == 'train')
        )

    def __len__(self) -> int:
        return len(self.image_ids)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """Get item from dataset."""
        image_id = self.image_ids[idx]

        # Load image; cv2.imread returns None (not an exception) on failure,
        # so fail fast with a clear message instead of a cryptic TypeError.
        img_path = self.images_dir / f'{image_id}.jpg'
        image = cv2.imread(str(img_path))
        if image is None:
            raise FileNotFoundError(f"Failed to read image: {img_path}")

        # Load mask (same failure mode as above).
        mask_path = self.masks_dir / f'{image_id}.png'
        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Failed to read mask: {mask_path}")

        # Binarize mask: any non-zero class id counts as tampered.
        mask = (mask > 0).astype(np.uint8)

        # Preprocess (deskew/resize/normalize, dataset-aware).
        image, mask = self.preprocessor(image, mask)

        # Augment (train-time transforms + tensor conversion).
        augmented = self.augmentation(image, mask)
        image = augmented['image']
        mask = augmented['mask']

        metadata = {
            'dataset': self.dataset_name,
            'image_id': image_id,
            'has_pixel_mask': True
        }

        return image, mask, metadata
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class CASIADataset(Dataset):
    """
    CASIA v1.0 dataset loader.

    Image-level labels only (no pixel masks).
    Implements Critical Fix #6: CASIA image-level handling.
    """

    def __init__(self, config, split: str = 'train'):
        """
        Initialize CASIA dataset.

        Args:
            config: Configuration object
            split: 'train' or 'test'
        """
        self.config = config
        self.split = split
        self.dataset_name = 'casia'

        # Resolve the dataset root from the project configuration.
        dataset_config = config.get_dataset_config(self.dataset_name)
        self.data_path = Path(dataset_config['path'])

        # Authentic (Au) vs tampered (Tp) image folders.
        self.authentic_dir = self.data_path / 'Au'
        self.tampered_dir = self.data_path / 'Tp'

        # Gather image paths. BUG FIX: Path.glob order is filesystem-
        # dependent, so without sorting the seeded permutation below would
        # produce a DIFFERENT train/test split on different machines
        # (silent train/test leakage across environments).
        authentic_images = sorted(self.authentic_dir.glob('*.jpg')) + \
                           sorted(self.authentic_dir.glob('*.png'))
        tampered_images = sorted(self.tampered_dir.glob('*.jpg')) + \
                          sorted(self.tampered_dir.glob('*.png'))

        # Build (path, label) pairs: 0 = authentic, 1 = tampered.
        self.samples = []
        for img_path in authentic_images:
            self.samples.append((img_path, 0))
        for img_path in tampered_images:
            self.samples.append((img_path, 1))

        # Critical Fix #7: image-level split (80/20), deterministic now that
        # the sample list order is stable.
        np.random.seed(42)
        indices = np.random.permutation(len(self.samples))
        split_idx = int(len(self.samples) * 0.8)

        if split == 'train':
            indices = indices[:split_idx]
        else:
            indices = indices[split_idx:]

        self.samples = [self.samples[i] for i in indices]

        print(f"CASIA {split}: {len(self.samples)} images")

        # Preprocessing / augmentation are dataset-aware.
        self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
        self.augmentation = DatasetAwareAugmentation(
            config,
            self.dataset_name,
            is_training=(split == 'train')
        )

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """Get item from dataset."""
        img_path, label = self.samples[idx]

        # Load image; fail fast with a clear message on unreadable files
        # (cv2.imread returns None rather than raising).
        image = cv2.imread(str(img_path))
        if image is None:
            raise FileNotFoundError(f"Failed to read image: {img_path}")

        # Critical Fix #6: image-level mask — the whole frame carries the
        # image label (all zeros for authentic, all ones for tampered).
        h, w = image.shape[:2]
        mask = np.ones((h, w), dtype=np.uint8) * label

        # Preprocess (deskew/resize/normalize, dataset-aware).
        image, mask = self.preprocessor(image, mask)

        # Augment (train-time transforms + tensor conversion).
        augmented = self.augmentation(image, mask)
        image = augmented['image']
        mask = augmented['mask']

        metadata = {
            'dataset': self.dataset_name,
            'image_path': str(img_path),
            'has_pixel_mask': False,  # Image-level only
            'label': label
        }

        return image, mask, metadata
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
class ReceiptsDataset(Dataset):
    """Find-It-Again receipts dataset loader (JSON annotations + bboxes)."""

    def __init__(self, config, split: str = 'train'):
        """
        Initialize receipts dataset.

        Args:
            config: Configuration object
            split: 'train', 'val', or 'test'
        """
        self.config = config
        self.split = split
        self.dataset_name = 'receipts'

        # Resolve the dataset root from the project configuration.
        dataset_config = config.get_dataset_config(self.dataset_name)
        self.data_path = Path(dataset_config['path'])

        # Load per-split annotation file (list of records with
        # 'image_path' and optional 'bboxes' entries).
        split_file = self.data_path / f'{split}.json'
        with open(split_file, 'r') as f:
            self.annotations = json.load(f)

        print(f"Receipts {split}: {len(self.annotations)} images")

        # Preprocessing / augmentation are dataset-aware.
        self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
        self.augmentation = DatasetAwareAugmentation(
            config,
            self.dataset_name,
            is_training=(split == 'train')
        )

    def __len__(self) -> int:
        return len(self.annotations)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """Get item from dataset."""
        ann = self.annotations[idx]

        # Load image; fail fast with a clear message on unreadable files.
        img_path = self.data_path / ann['image_path']
        image = cv2.imread(str(img_path))
        if image is None:
            raise FileNotFoundError(f"Failed to read image: {img_path}")

        # Rasterize bounding boxes (x, y, w, h) into a binary mask.
        h, w = image.shape[:2]
        mask = np.zeros((h, w), dtype=np.uint8)

        for bbox in ann.get('bboxes', []):
            x, y, w_box, h_box = (int(v) for v in bbox)
            # BUG FIX: clamp to image bounds. Negative x/y would previously
            # wrap around via numpy negative indexing and paint the wrong
            # region (or nothing at all).
            x0, y0 = max(x, 0), max(y, 0)
            x1, y1 = min(x + w_box, w), min(y + h_box, h)
            if x1 > x0 and y1 > y0:
                mask[y0:y1, x0:x1] = 1

        # Preprocess (deskew/resize/normalize, dataset-aware).
        image, mask = self.preprocessor(image, mask)

        # Augment (train-time transforms + tensor conversion).
        augmented = self.augmentation(image, mask)
        image = augmented['image']
        mask = augmented['mask']

        metadata = {
            'dataset': self.dataset_name,
            'image_path': str(img_path),
            'has_pixel_mask': True
        }

        return image, mask, metadata
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
class FCDDataset(DocTamperDataset):
    """FCD (Forgery Classification Dataset) loader - inherits from DocTamperDataset"""

    def __init__(self, config, split: str = 'train'):
        # Deliberately does NOT call DocTamperDataset.__init__: FCD keeps its
        # LMDB directly at the dataset root and needs no split-folder mapping.
        self.config = config
        self.split = split
        self.dataset_name = 'fcd'

        # Resolve the LMDB location from the project configuration.
        ds_cfg = config.get_dataset_config(self.dataset_name)
        self.data_path = Path(ds_cfg['path'])
        self.lmdb_path = str(self.data_path)

        if not Path(self.lmdb_path).exists():
            raise FileNotFoundError(f"LMDB folder not found: {self.lmdb_path}")

        # Count samples: entries come in image/label pairs.
        counter_env = lmdb.open(self.lmdb_path, readonly=True, lock=False)
        try:
            with counter_env.begin() as txn:
                self.length = txn.stat()['entries'] // 2
        finally:
            counter_env.close()

        # Real env is opened lazily by the inherited `env` property.
        self._env = None

        # FCD is small, so the whole set is a single chunk.
        self.chunk_start, self.chunk_end = 0, self.length
        self.chunk_length = self.length

        print(f"FCD {split}: {self.length} samples")

        # Preprocessing / augmentation are dataset-aware.
        self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
        self.augmentation = DatasetAwareAugmentation(
            config,
            self.dataset_name,
            is_training=(split == 'train')
        )
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
class SCDDataset(DocTamperDataset):
    """SCD (Splicing Classification Dataset) loader - inherits from DocTamperDataset"""

    def __init__(self, config, split: str = 'train'):
        # Deliberately does NOT call DocTamperDataset.__init__: SCD keeps its
        # LMDB directly at the dataset root and needs no split-folder mapping.
        self.config = config
        self.split = split
        self.dataset_name = 'scd'

        # Resolve the LMDB location from the project configuration.
        ds_cfg = config.get_dataset_config(self.dataset_name)
        self.data_path = Path(ds_cfg['path'])
        self.lmdb_path = str(self.data_path)

        if not Path(self.lmdb_path).exists():
            raise FileNotFoundError(f"LMDB folder not found: {self.lmdb_path}")

        # Count samples: entries come in image/label pairs.
        counter_env = lmdb.open(self.lmdb_path, readonly=True, lock=False)
        try:
            with counter_env.begin() as txn:
                self.length = txn.stat()['entries'] // 2
        finally:
            counter_env.close()

        # Real env is opened lazily by the inherited `env` property.
        self._env = None

        # SCD is medium-sized, so the whole set is a single chunk.
        self.chunk_start, self.chunk_end = 0, self.length
        self.chunk_length = self.length

        print(f"SCD {split}: {self.length} samples")

        # Preprocessing / augmentation are dataset-aware.
        self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
        self.augmentation = DatasetAwareAugmentation(
            config,
            self.dataset_name,
            is_training=(split == 'train')
        )
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
def get_dataset(config, dataset_name: str, split: str = 'train', **kwargs) -> Dataset:
    """
    Factory function to get a dataset instance.

    Args:
        config: Configuration object
        dataset_name: Dataset name
        split: Data split
        **kwargs: Additional arguments, only honoured by 'doctamper'
                  (e.g. chunk_start, chunk_end)

    Returns:
        Dataset instance

    Raises:
        ValueError: If the dataset name is not recognized.
    """
    # DocTamper is the only loader that accepts extra keyword arguments.
    if dataset_name == 'doctamper':
        return DocTamperDataset(config, split, **kwargs)

    registry = {
        'rtm': RTMDataset,
        'casia': CASIADataset,
        'receipts': ReceiptsDataset,
        'fcd': FCDDataset,
        'scd': SCDDataset,
    }

    dataset_cls = registry.get(dataset_name)
    if dataset_cls is None:
        raise ValueError(f"Unknown dataset: {dataset_name}")
    return dataset_cls(config, split)
|
src/data/preprocessing.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dataset-aware preprocessing for document forgery detection
|
| 3 |
+
Implements Critical Fix #1: Dataset-Aware Preprocessing
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
from typing import Tuple, Optional
|
| 9 |
+
import pywt
|
| 10 |
+
from scipy import ndimage
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class DocumentPreprocessor:
|
| 14 |
+
"""Dataset-aware document preprocessing"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, config, dataset_name: str):
|
| 17 |
+
"""
|
| 18 |
+
Initialize preprocessor
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
config: Configuration object
|
| 22 |
+
dataset_name: Name of dataset (for dataset-aware processing)
|
| 23 |
+
"""
|
| 24 |
+
self.config = config
|
| 25 |
+
self.dataset_name = dataset_name
|
| 26 |
+
self.image_size = config.get('data.image_size', 384)
|
| 27 |
+
self.noise_threshold = config.get('preprocessing.noise_threshold', 15.0)
|
| 28 |
+
|
| 29 |
+
# Dataset-aware flags (Critical Fix #1)
|
| 30 |
+
self.skip_deskew = config.should_skip_deskew(dataset_name)
|
| 31 |
+
self.skip_denoising = config.should_skip_denoising(dataset_name)
|
| 32 |
+
|
| 33 |
+
def __call__(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
| 34 |
+
"""
|
| 35 |
+
Apply preprocessing pipeline
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
image: Input image (H, W, 3)
|
| 39 |
+
mask: Optional ground truth mask (H, W)
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
Preprocessed image and mask
|
| 43 |
+
"""
|
| 44 |
+
# 1. Convert to RGB
|
| 45 |
+
if len(image.shape) == 2:
|
| 46 |
+
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
| 47 |
+
elif image.shape[2] == 4:
|
| 48 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
|
| 49 |
+
elif image.shape[2] == 3:
|
| 50 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 51 |
+
|
| 52 |
+
# 2. Deskew (dataset-aware)
|
| 53 |
+
if not self.skip_deskew:
|
| 54 |
+
image, mask = self._deskew(image, mask)
|
| 55 |
+
|
| 56 |
+
# 3. Resize
|
| 57 |
+
image, mask = self._resize(image, mask)
|
| 58 |
+
|
| 59 |
+
# 4. Normalize
|
| 60 |
+
image = self._normalize(image)
|
| 61 |
+
|
| 62 |
+
# 5. Conditional denoising (dataset-aware)
|
| 63 |
+
if not self.skip_denoising:
|
| 64 |
+
noise_level = self._estimate_noise(image)
|
| 65 |
+
if noise_level > self.noise_threshold:
|
| 66 |
+
image = self._denoise(image)
|
| 67 |
+
|
| 68 |
+
return image, mask
|
| 69 |
+
|
| 70 |
+
def _deskew(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
| 71 |
+
"""
|
| 72 |
+
Deskew document image
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
image: Input image
|
| 76 |
+
mask: Optional mask
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
Deskewed image and mask
|
| 80 |
+
"""
|
| 81 |
+
# Convert to grayscale for angle detection
|
| 82 |
+
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
| 83 |
+
|
| 84 |
+
# Detect edges
|
| 85 |
+
edges = cv2.Canny(gray, 50, 150, apertureSize=3)
|
| 86 |
+
|
| 87 |
+
# Detect lines using Hough transform
|
| 88 |
+
lines = cv2.HoughLines(edges, 1, np.pi / 180, 200)
|
| 89 |
+
|
| 90 |
+
if lines is not None and len(lines) > 0:
|
| 91 |
+
# Calculate dominant angle
|
| 92 |
+
angles = []
|
| 93 |
+
for rho, theta in lines[:, 0]:
|
| 94 |
+
angle = (theta * 180 / np.pi) - 90
|
| 95 |
+
angles.append(angle)
|
| 96 |
+
|
| 97 |
+
# Use median angle
|
| 98 |
+
angle = np.median(angles)
|
| 99 |
+
|
| 100 |
+
# Only deskew if angle is significant (> 0.5 degrees)
|
| 101 |
+
if abs(angle) > 0.5:
|
| 102 |
+
# Get rotation matrix
|
| 103 |
+
h, w = image.shape[:2]
|
| 104 |
+
center = (w // 2, h // 2)
|
| 105 |
+
M = cv2.getRotationMatrix2D(center, angle, 1.0)
|
| 106 |
+
|
| 107 |
+
# Rotate image
|
| 108 |
+
image = cv2.warpAffine(image, M, (w, h),
|
| 109 |
+
flags=cv2.INTER_CUBIC,
|
| 110 |
+
borderMode=cv2.BORDER_REPLICATE)
|
| 111 |
+
|
| 112 |
+
# Rotate mask if provided
|
| 113 |
+
if mask is not None:
|
| 114 |
+
mask = cv2.warpAffine(mask, M, (w, h),
|
| 115 |
+
flags=cv2.INTER_NEAREST,
|
| 116 |
+
borderMode=cv2.BORDER_CONSTANT,
|
| 117 |
+
borderValue=0)
|
| 118 |
+
|
| 119 |
+
return image, mask
|
| 120 |
+
|
| 121 |
+
def _resize(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
| 122 |
+
"""
|
| 123 |
+
Resize image and mask to target size
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
image: Input image
|
| 127 |
+
mask: Optional mask
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
Resized image and mask
|
| 131 |
+
"""
|
| 132 |
+
target_size = (self.image_size, self.image_size)
|
| 133 |
+
|
| 134 |
+
# Resize image
|
| 135 |
+
image = cv2.resize(image, target_size, interpolation=cv2.INTER_CUBIC)
|
| 136 |
+
|
| 137 |
+
# Resize mask if provided
|
| 138 |
+
if mask is not None:
|
| 139 |
+
mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)
|
| 140 |
+
|
| 141 |
+
return image, mask
|
| 142 |
+
|
| 143 |
+
def _normalize(self, image: np.ndarray) -> np.ndarray:
|
| 144 |
+
"""
|
| 145 |
+
Normalize pixel values to [0, 1]
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
image: Input image
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
Normalized image
|
| 152 |
+
"""
|
| 153 |
+
return image.astype(np.float32) / 255.0
|
| 154 |
+
|
| 155 |
+
def _estimate_noise(self, image: np.ndarray) -> float:
    """
    Estimate noise level using Laplacian variance and wavelet-based estimation

    Args:
        image: Input image (normalized to [0, 1])

    Returns:
        Estimated noise level (scalar; arithmetic mean of the two estimates)
    """
    # Convert to grayscale for noise estimation (noise estimators below
    # operate on a single channel)
    if len(image.shape) == 3:
        gray = cv2.cvtColor((image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
    else:
        gray = (image * 255).astype(np.uint8)

    # Method 1: Laplacian variance — variance of the second-derivative
    # response; responds to both noise and genuine texture/edges
    laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var()

    # Method 2: Wavelet-based noise estimation on the diagonal-detail
    # sub-band; 0.6745 is the standard MAD-to-sigma scaling constant
    coeffs = pywt.dwt2(gray, 'db1')
    _, (cH, cV, cD) = coeffs
    sigma = np.median(np.abs(cD)) / 0.6745

    # Combine both estimates
    # NOTE(review): laplacian_var (a variance, often in the hundreds) and
    # sigma (a std-dev, typically single digits) are on very different
    # scales, so this average is dominated by the Laplacian term — confirm
    # this weighting is intended
    noise_level = (laplacian_var + sigma) / 2.0

    return noise_level
|
| 183 |
+
|
| 184 |
+
def _denoise(self, image: np.ndarray) -> np.ndarray:
    """
    Apply conditional denoising: a 3x3 median pass followed by a light
    Gaussian blur (sigma 0.8).

    Args:
        image: Input image (normalized to [0, 1])

    Returns:
        Denoised image as float32 in [0, 1]
    """
    # OpenCV's smoothing filters work on 8-bit data, so round-trip
    # through uint8.
    as_bytes = (image * 255).astype(np.uint8)

    # Median first (suppresses impulse noise), then a gentle Gaussian
    # smooth with sigma <= 0.8.
    smoothed = cv2.GaussianBlur(cv2.medianBlur(as_bytes, 3), (3, 3), 0.8)

    # Return to the normalized float32 domain.
    return smoothed.astype(np.float32) / 255.0
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def preprocess_image(image: np.ndarray,
                     mask: Optional[np.ndarray] = None,
                     config = None,
                     dataset_name: str = 'default') -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Convenience wrapper: build a one-shot DocumentPreprocessor and run it.

    Args:
        image: Input image
        mask: Optional mask
        config: Configuration object
        dataset_name: Dataset name

    Returns:
        Preprocessed image and mask
    """
    return DocumentPreprocessor(config, dataset_name)(image, mask)
|
src/features/__init__.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Features module"""
|
| 2 |
+
|
| 3 |
+
from .feature_extraction import (
|
| 4 |
+
DeepFeatureExtractor,
|
| 5 |
+
StatisticalFeatureExtractor,
|
| 6 |
+
FrequencyFeatureExtractor,
|
| 7 |
+
NoiseELAFeatureExtractor,
|
| 8 |
+
OCRFeatureExtractor,
|
| 9 |
+
HybridFeatureExtractor,
|
| 10 |
+
get_feature_extractor
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
from .region_extraction import (
|
| 14 |
+
MaskRefiner,
|
| 15 |
+
RegionExtractor,
|
| 16 |
+
get_mask_refiner,
|
| 17 |
+
get_region_extractor
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
__all__ = [
|
| 21 |
+
'DeepFeatureExtractor',
|
| 22 |
+
'StatisticalFeatureExtractor',
|
| 23 |
+
'FrequencyFeatureExtractor',
|
| 24 |
+
'NoiseELAFeatureExtractor',
|
| 25 |
+
'OCRFeatureExtractor',
|
| 26 |
+
'HybridFeatureExtractor',
|
| 27 |
+
'get_feature_extractor',
|
| 28 |
+
'MaskRefiner',
|
| 29 |
+
'RegionExtractor',
|
| 30 |
+
'get_mask_refiner',
|
| 31 |
+
'get_region_extractor'
|
| 32 |
+
]
|
src/features/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (691 Bytes). View file
|
|
|
src/features/__pycache__/feature_extraction.cpython-312.pyc
ADDED
|
Binary file (22.6 kB). View file
|
|
|
src/features/__pycache__/region_extraction.cpython-312.pyc
ADDED
|
Binary file (8.93 kB). View file
|
|
|
src/features/feature_extraction.py
ADDED
|
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hybrid feature extraction for forgery detection
|
| 3 |
+
Implements Critical Fix #5: Feature Group Gating
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from typing import Dict, List, Optional, Tuple
|
| 11 |
+
from scipy import ndimage
|
| 12 |
+
from scipy.fftpack import dct
|
| 13 |
+
import pywt
|
| 14 |
+
from skimage.measure import regionprops, label
|
| 15 |
+
from skimage.filters import sobel
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class DeepFeatureExtractor:
    """Extract deep features from decoder feature maps"""

    def __init__(self):
        """Initialize deep feature extractor"""
        pass

    def extract(self,
                decoder_features: List[torch.Tensor],
                region_mask: np.ndarray) -> np.ndarray:
        """
        Extract deep features via masked Global Average Pooling over each
        decoder feature map's channels.

        Args:
            decoder_features: List of decoder feature tensors
            region_mask: Binary region mask (H, W)

        Returns:
            Deep feature vector (one value per channel per feature map)
        """
        pooled = []

        for fmap in decoder_features:
            # Detach from autograd and move to host memory before numpy work.
            if isinstance(fmap, torch.Tensor):
                fmap = fmap.detach().cpu().numpy()

            # Drop batch dimension: (B, C, H, W) -> (C, H, W).
            if fmap.ndim == 4:
                fmap = fmap[0]

            # Scale the mask down to this feature map's spatial size.
            # cv2.resize takes (width, height).
            fh, fw = fmap.shape[1:]
            scaled_mask = cv2.resize(region_mask.astype(np.float32), (fw, fh)) > 0.5

            if scaled_mask.any():
                # Masked GAP: mean of each channel restricted to the region.
                pooled.extend(fmap[c][scaled_mask].mean() for c in range(fmap.shape[0]))
            else:
                # Region vanished at this resolution — fall back to full GAP.
                pooled.extend(fmap.mean(axis=(1, 2)).tolist())

        return np.array(pooled, dtype=np.float32)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class StatisticalFeatureExtractor:
    """Extract statistical and shape features from regions"""

    def __init__(self):
        """Initialize statistical feature extractor"""
        pass

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray) -> np.ndarray:
        """
        Extract statistical and shape features.

        Produces a fixed 6-element vector:
        [area, perimeter, aspect_ratio, solidity, eccentricity, entropy].
        Only the first labelled component of the mask contributes shape
        properties (callers pass single-region masks), while entropy is
        computed over all masked pixels.

        Args:
            image: Input image (H, W, 3) normalized [0, 1]
            region_mask: Binary region mask (H, W)

        Returns:
            Statistical feature vector (6 values)
        """
        features = []

        # Label the mask and read shape properties of the first component.
        labeled_mask = label(region_mask)
        props = regionprops(labeled_mask)

        if len(props) > 0:
            prop = props[0]

            # Area and perimeter
            features.append(prop.area)
            features.append(prop.perimeter)

            # Aspect ratio (minor/major axis, 1.0 for degenerate regions)
            if prop.major_axis_length > 0:
                aspect_ratio = prop.minor_axis_length / prop.major_axis_length
            else:
                aspect_ratio = 1.0
            features.append(aspect_ratio)

            # Solidity (area / convex-hull area)
            features.append(prop.solidity)

            # Eccentricity
            features.append(prop.eccentricity)

            # Shannon entropy of the intensity histogram inside the mask
            if len(image.shape) == 3:
                gray = cv2.cvtColor((image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
            else:
                gray = (image * 255).astype(np.uint8)

            region_pixels = gray[region_mask > 0]
            if len(region_pixels) > 0:
                hist, _ = np.histogram(region_pixels, bins=256, range=(0, 256))
                # Fix: guard the DIVISOR with epsilon. The previous code
                # `hist / hist.sum() + 1e-8` added epsilon to every bin
                # after division (operator precedence), skewing the
                # probability mass instead of protecting against a zero sum.
                hist = hist / (hist.sum() + 1e-8)
                entropy = -np.sum(hist * np.log2(hist + 1e-8))
            else:
                entropy = 0.0
            features.append(entropy)
        else:
            # Default values for an empty mask (keeps vector length at 6)
            features.extend([0, 0, 1.0, 0, 0, 0])

        return np.array(features, dtype=np.float32)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class FrequencyFeatureExtractor:
    """Extract frequency-domain features"""

    # Fixed output length: 3 DCT features + 4 wavelet energies + 3 wavelet
    # entropies. The empty-region fallbacks below must match this so the
    # concatenated feature vector keeps a constant width for the classifier.
    NUM_FEATURES = 10

    def __init__(self):
        """Initialize frequency feature extractor"""
        pass

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray) -> np.ndarray:
        """
        Extract frequency-domain features (DCT, wavelet).

        Args:
            image: Input image (H, W, 3) normalized [0, 1]
            region_mask: Binary region mask (H, W)

        Returns:
            Frequency feature vector (10 values):
            [dct_mean, dct_std, high_freq_energy,
             cA/cH/cV/cD energies, cH/cV/cD entropies]
        """
        features = []

        # Convert to grayscale
        if len(image.shape) == 3:
            gray = cv2.cvtColor((image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
        else:
            gray = (image * 255).astype(np.uint8)

        # Get region bounding box
        coords = np.where(region_mask > 0)
        if len(coords[0]) == 0:
            # Fix: was np.zeros(13) — the populated path always returns 10
            # features, so the empty fallback must match that length.
            return np.zeros(self.NUM_FEATURES, dtype=np.float32)

        y_min, y_max = coords[0].min(), coords[0].max()
        x_min, x_max = coords[1].min(), coords[1].max()

        # Crop region
        region = gray[y_min:y_max+1, x_min:x_max+1].astype(np.float32)

        if region.size == 0:
            # Fix: same length correction as above (was np.zeros(13)).
            return np.zeros(self.NUM_FEATURES, dtype=np.float32)

        # DCT coefficients (2-D separable DCT-II, orthonormal)
        try:
            dct_coeffs = dct(dct(region, axis=0, norm='ortho'), axis=1, norm='ortho')

            # Mean and std of DCT coefficients
            features.append(np.mean(np.abs(dct_coeffs)))
            features.append(np.std(dct_coeffs))

            # High-frequency energy (bottom-right quadrant)
            h, w = dct_coeffs.shape
            high_freq = dct_coeffs[h//2:, w//2:]
            features.append(np.sum(np.abs(high_freq)) / (high_freq.size + 1e-8))
        except Exception:
            features.extend([0, 0, 0])

        # Wavelet features (single-level Haar decomposition)
        try:
            coeffs = pywt.dwt2(region, 'db1')
            cA, (cH, cV, cD) = coeffs

            # Energy in each sub-band, normalized by sub-band size
            features.append(np.sum(cA ** 2) / (cA.size + 1e-8))
            features.append(np.sum(cH ** 2) / (cH.size + 1e-8))
            features.append(np.sum(cV ** 2) / (cV.size + 1e-8))
            features.append(np.sum(cD ** 2) / (cD.size + 1e-8))

            # Wavelet entropy of each detail sub-band
            for coeff in [cH, cV, cD]:
                coeff_flat = np.abs(coeff.flatten())
                if coeff_flat.sum() > 0:
                    coeff_norm = coeff_flat / coeff_flat.sum()
                    entropy = -np.sum(coeff_norm * np.log2(coeff_norm + 1e-8))
                else:
                    entropy = 0.0
                features.append(entropy)
        except Exception:
            features.extend([0, 0, 0, 0, 0, 0, 0])

        return np.array(features, dtype=np.float32)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class NoiseELAFeatureExtractor:
    """Extract noise and Error Level Analysis features"""

    def __init__(self, quality: int = 90):
        """
        Initialize noise/ELA extractor

        Args:
            quality: JPEG quality for ELA
        """
        # JPEG re-compression quality used for the ELA round-trip
        self.quality = quality

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray) -> np.ndarray:
        """
        Extract noise and ELA features.

        Produces a fixed 5-element vector:
        [ela_mean, ela_var, ela_max, noise_residual_mean, noise_residual_var].

        Args:
            image: Input image (H, W, 3) normalized [0, 1]
            region_mask: Binary region mask (H, W)

        Returns:
            Noise/ELA feature vector (5 values)
        """
        features = []

        # Convert to uint8
        img_uint8 = (image * 255).astype(np.uint8)

        # Error Level Analysis
        # Compress and compute difference: regions that were previously
        # compressed differently leave a distinct residual
        encode_param = [cv2.IMWRITE_JPEG_QUALITY, self.quality]
        _, encoded = cv2.imencode('.jpg', img_uint8, encode_param)
        recompressed = cv2.imdecode(encoded, cv2.IMREAD_COLOR)

        # Per-pixel absolute difference between the original and its
        # JPEG round-trip (H, W, 3)
        ela = np.abs(img_uint8.astype(np.float32) - recompressed.astype(np.float32))

        # ELA features within region (boolean mask selects (n, 3) values)
        ela_region = ela[region_mask > 0]
        if len(ela_region) > 0:
            features.append(np.mean(ela_region))  # ELA mean
            features.append(np.var(ela_region))  # ELA variance
            features.append(np.max(ela_region))  # ELA max
        else:
            features.extend([0, 0, 0])

        # Noise residual (using median filter): difference between the
        # grayscale image and its 3x3 median approximates the noise field
        if len(image.shape) == 3:
            gray = cv2.cvtColor(img_uint8, cv2.COLOR_RGB2GRAY)
        else:
            gray = img_uint8

        median_filtered = cv2.medianBlur(gray, 3)
        noise_residual = np.abs(gray.astype(np.float32) - median_filtered.astype(np.float32))

        residual_region = noise_residual[region_mask > 0]
        if len(residual_region) > 0:
            features.append(np.mean(residual_region))
            features.append(np.var(residual_region))
        else:
            features.extend([0, 0])

        return np.array(features, dtype=np.float32)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
class OCRFeatureExtractor:
    """
    Extract OCR-based consistency features
    Only for text documents (Feature Gating - Critical Fix #5)
    """

    def __init__(self):
        """Initialize OCR feature extractor; degrades gracefully when
        EasyOCR is not installed (extract() then returns zeros)."""
        # Flag consulted by extract(); stays False if the import fails
        self.ocr_available = False

        try:
            import easyocr
            # NOTE(review): gpu=True assumes CUDA availability — EasyOCR
            # may fall back or fail on CPU-only hosts; confirm deployment
            self.reader = easyocr.Reader(['en'], gpu=True)
            self.ocr_available = True
        except Exception:
            print("Warning: EasyOCR not available, OCR features disabled")

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray) -> np.ndarray:
        """
        Extract OCR consistency features.

        Produces a fixed 6-element vector:
        [conf_mean, conf_std, spacing_irregularity, text_density,
         stroke_mean, stroke_std], or all zeros when OCR is unavailable,
        the mask is empty, or OCR fails.

        Args:
            image: Input image (H, W, 3) normalized [0, 1]
            region_mask: Binary region mask (H, W)

        Returns:
            OCR feature vector (or zeros if not text document)
        """
        features = []

        if not self.ocr_available:
            return np.zeros(6, dtype=np.float32)

        # Convert to uint8
        img_uint8 = (image * 255).astype(np.uint8)

        # Get region bounding box
        coords = np.where(region_mask > 0)
        if len(coords[0]) == 0:
            return np.zeros(6, dtype=np.float32)

        y_min, y_max = coords[0].min(), coords[0].max()
        x_min, x_max = coords[1].min(), coords[1].max()

        # Crop region
        region = img_uint8[y_min:y_max+1, x_min:x_max+1]

        try:
            # OCR on region; each result is (bbox_points, text, confidence)
            results = self.reader.readtext(region)

            if len(results) > 0:
                # Confidence statistics across detected text fragments
                confidences = [r[2] for r in results]
                features.append(np.mean(confidences))
                features.append(np.std(confidences))

                # Character spacing analysis: coefficient of variation of
                # bbox widths (top-right x minus top-left x per detection)
                bbox_widths = [abs(r[0][1][0] - r[0][0][0]) for r in results]
                if len(bbox_widths) > 1:
                    features.append(np.std(bbox_widths) / (np.mean(bbox_widths) + 1e-8))
                else:
                    features.append(0.0)

                # Text density: detections per region pixel
                features.append(len(results) / (region.shape[0] * region.shape[1] + 1e-8))

                # Stroke width variation (using edge detection)
                gray_region = cv2.cvtColor(region, cv2.COLOR_RGB2GRAY)
                edges = sobel(gray_region)
                features.append(np.mean(edges))
                features.append(np.std(edges))
            else:
                features.extend([0, 0, 0, 0, 0, 0])
        except Exception:
            # Any OCR failure yields the neutral zero vector
            features.extend([0, 0, 0, 0, 0, 0])

        return np.array(features, dtype=np.float32)
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
class HybridFeatureExtractor:
    """
    Complete hybrid feature extraction
    Implements Critical Fix #5: Feature Group Gating

    Concatenates deep, statistical, frequency, noise/ELA and (for text
    documents only) OCR feature groups, each individually toggleable via
    the `features.*.enabled` config keys.
    """

    def __init__(self, config, is_text_document: bool = True):
        """
        Initialize hybrid feature extractor

        Args:
            config: Configuration object
            is_text_document: Whether input is text document (for OCR gating)
        """
        self.config = config
        self.is_text_document = is_text_document

        # Initialize extractors
        self.deep_extractor = DeepFeatureExtractor()
        self.stat_extractor = StatisticalFeatureExtractor()
        self.freq_extractor = FrequencyFeatureExtractor()
        self.noise_extractor = NoiseELAFeatureExtractor()

        # Critical Fix #5: OCR only for text documents
        if is_text_document and config.get('features.ocr.enabled', True):
            self.ocr_extractor = OCRFeatureExtractor()
        else:
            # None disables the OCR group in extract()/get_feature_names()
            self.ocr_extractor = None

    def extract(self,
                image: np.ndarray,
                region_mask: np.ndarray,
                decoder_features: Optional[List[torch.Tensor]] = None) -> np.ndarray:
        """
        Extract all hybrid features for a region.

        Group order is fixed: deep, statistical, frequency, noise, OCR —
        get_feature_names() must stay in sync with this order.

        Args:
            image: Input image (H, W, 3) normalized [0, 1]
            region_mask: Binary region mask (H, W)
            decoder_features: Optional decoder features for deep feature extraction

        Returns:
            Concatenated feature vector (NaN/Inf replaced by 0)
        """
        all_features = []

        # Deep features (if available)
        if decoder_features is not None and self.config.get('features.deep.enabled', True):
            deep_feats = self.deep_extractor.extract(decoder_features, region_mask)
            all_features.append(deep_feats)

        # Statistical & shape features
        if self.config.get('features.statistical.enabled', True):
            stat_feats = self.stat_extractor.extract(image, region_mask)
            all_features.append(stat_feats)

        # Frequency-domain features
        if self.config.get('features.frequency.enabled', True):
            freq_feats = self.freq_extractor.extract(image, region_mask)
            all_features.append(freq_feats)

        # Noise & ELA features
        if self.config.get('features.noise.enabled', True):
            noise_feats = self.noise_extractor.extract(image, region_mask)
            all_features.append(noise_feats)

        # Critical Fix #5: OCR features only for text documents
        if self.ocr_extractor is not None:
            ocr_feats = self.ocr_extractor.extract(image, region_mask)
            all_features.append(ocr_feats)

        # Concatenate all features
        if len(all_features) > 0:
            features = np.concatenate(all_features)
        else:
            features = np.array([], dtype=np.float32)

        # Handle NaN/Inf so downstream classifiers never see non-finite values
        features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)

        return features

    def get_feature_names(self) -> List[str]:
        """Get list of feature names for interpretability.

        NOTE(review): the deep-feature count (256) is declared approximate
        — the true count depends on decoder channel widths, so these names
        may not align 1:1 with extract()'s output when deep features are
        present; confirm before using for attribution.
        """
        names = []

        if self.config.get('features.deep.enabled', True):
            names.extend([f'deep_{i}' for i in range(256)])  # Approximate

        if self.config.get('features.statistical.enabled', True):
            names.extend(['area', 'perimeter', 'aspect_ratio',
                          'solidity', 'eccentricity', 'entropy'])

        if self.config.get('features.frequency.enabled', True):
            names.extend(['dct_mean', 'dct_std', 'high_freq_energy',
                          'wavelet_cA', 'wavelet_cH', 'wavelet_cV', 'wavelet_cD',
                          'wavelet_entropy_H', 'wavelet_entropy_V', 'wavelet_entropy_D'])

        if self.config.get('features.noise.enabled', True):
            names.extend(['ela_mean', 'ela_var', 'ela_max',
                          'noise_residual_mean', 'noise_residual_var'])

        if self.ocr_extractor is not None:
            names.extend(['ocr_conf_mean', 'ocr_conf_std', 'spacing_irregularity',
                          'text_density', 'stroke_mean', 'stroke_std'])

        return names
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def get_feature_extractor(config, is_text_document: bool = True) -> HybridFeatureExtractor:
    """
    Build a HybridFeatureExtractor from the given configuration.

    Args:
        config: Configuration object
        is_text_document: Whether input is text document

    Returns:
        HybridFeatureExtractor instance
    """
    extractor = HybridFeatureExtractor(config, is_text_document)
    return extractor
|
src/features/region_extraction.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Mask refinement and region extraction
|
| 3 |
+
Implements Critical Fix #3: Adaptive Mask Refinement Thresholds
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
from typing import List, Tuple, Dict, Optional
|
| 9 |
+
from scipy import ndimage
|
| 10 |
+
from skimage.measure import label, regionprops
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class MaskRefiner:
    """
    Mask refinement with adaptive thresholds
    Implements Critical Fix #3: Dataset-specific minimum region areas

    Pipeline: probability threshold -> morphological close -> open ->
    small-region removal -> optional resize back to the original size.
    """

    def __init__(self, config, dataset_name: str = 'default'):
        """
        Initialize mask refiner

        Args:
            config: Configuration object
            dataset_name: Dataset name for adaptive thresholds
        """
        self.config = config
        self.dataset_name = dataset_name

        # Get mask refinement parameters
        # Probability cutoff for binarization
        self.threshold = config.get('mask_refinement.threshold', 0.5)
        # Kernel side lengths (pixels) for the morphology passes
        self.closing_kernel = config.get('mask_refinement.morphology.closing_kernel', 5)
        self.opening_kernel = config.get('mask_refinement.morphology.opening_kernel', 3)

        # Critical Fix #3: Adaptive thresholds per dataset
        # (fraction of total image area, not pixel count)
        self.min_region_area = config.get_min_region_area(dataset_name)

        print(f"MaskRefiner initialized for {dataset_name}")
        print(f"Min region area: {self.min_region_area * 100:.2f}%")

    def refine(self,
               probability_map: np.ndarray,
               original_size: Tuple[int, int] = None) -> np.ndarray:
        """
        Refine probability map to binary mask

        Args:
            probability_map: Forgery probability map (H, W), values [0, 1]
            original_size: Optional (H, W) to resize mask back to original

        Returns:
            Refined binary mask (H, W), values {0, 1}
        """
        # Threshold to binary
        binary_mask = (probability_map > self.threshold).astype(np.uint8)

        # Morphological closing (fill broken strokes)
        closing_kernel = cv2.getStructuringElement(
            cv2.MORPH_RECT,
            (self.closing_kernel, self.closing_kernel)
        )
        binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, closing_kernel)

        # Morphological opening (remove isolated noise)
        opening_kernel = cv2.getStructuringElement(
            cv2.MORPH_RECT,
            (self.opening_kernel, self.opening_kernel)
        )
        binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, opening_kernel)

        # Critical Fix #3: Remove small regions with adaptive threshold
        binary_mask = self._remove_small_regions(binary_mask)

        # Resize to original size if provided; nearest-neighbour keeps the
        # mask strictly binary
        if original_size is not None:
            binary_mask = cv2.resize(
                binary_mask,
                (original_size[1], original_size[0]),  # cv2 uses (W, H)
                interpolation=cv2.INTER_NEAREST
            )

        return binary_mask

    def _remove_small_regions(self, mask: np.ndarray) -> np.ndarray:
        """
        Remove regions smaller than minimum area threshold

        Args:
            mask: Binary mask (H, W)

        Returns:
            Filtered mask (same shape; only sufficiently large components kept)
        """
        # Calculate minimum pixel count from the fractional area threshold
        image_area = mask.shape[0] * mask.shape[1]
        min_pixels = int(image_area * self.min_region_area)

        # Label connected components
        labeled_mask, num_features = ndimage.label(mask)

        # Keep only large enough regions
        filtered_mask = np.zeros_like(mask)

        for region_id in range(1, num_features + 1):
            region_mask = (labeled_mask == region_id)
            region_area = region_mask.sum()

            if region_area >= min_pixels:
                filtered_mask[region_mask] = 1

        return filtered_mask
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class RegionExtractor:
    """
    Extract individual regions from a binary mask.

    Implements Critical Fix #4: Region Confidence Aggregation — each region's
    confidence is the mean predicted probability over that region's own
    pixels, not a global statistic.
    """

    def __init__(self, config, dataset_name: str = 'default'):
        """
        Initialize region extractor.

        Args:
            config: Configuration object (must expose ``get_min_region_area``)
            dataset_name: Dataset name used for per-dataset settings
        """
        self.config = config
        self.dataset_name = dataset_name
        self.min_region_area = config.get_min_region_area(dataset_name)

    def extract(self,
                binary_mask: np.ndarray,
                probability_map: np.ndarray,
                original_image: np.ndarray) -> List[Dict]:
        """
        Extract regions from a binary mask via connected-component analysis.

        Args:
            binary_mask: Refined binary mask (H, W)
            probability_map: Original probability map (H, W), same size as mask
            original_image: Original image (H, W, 3)

        Returns:
            List of region dictionaries with bounding box, mask, image,
            confidence
        """
        regions = []

        # Connected component analysis (8-connectivity)
        labeled_mask = label(binary_mask, connectivity=2)
        props = regionprops(labeled_mask)

        for prop in props:
            # Use the component's actual label rather than assuming the
            # regionprops iteration order matches consecutive label ids.
            region_id = int(prop.label)

            # Bounding box (skimage order: min_row, min_col, max_row, max_col)
            y_min, x_min, y_max, x_max = prop.bbox

            # Full-size mask for this single component
            region_mask = (labeled_mask == region_id).astype(np.uint8)

            # Cropped region image and mask
            region_image = original_image[y_min:y_max, x_min:x_max].copy()
            region_mask_cropped = region_mask[y_min:y_max, x_min:x_max]

            # Critical Fix #4: Region-level confidence aggregation
            region_probs = probability_map[region_mask > 0]
            region_confidence = float(np.mean(region_probs)) if len(region_probs) > 0 else 0.0

            regions.append({
                'region_id': region_id,
                'bounding_box': [int(x_min), int(y_min),
                                 int(x_max - x_min), int(y_max - y_min)],
                # Cast to a plain int: skimage may return numpy scalar types,
                # which are not JSON-serializable downstream.
                'area': int(prop.area),
                'centroid': (int(prop.centroid[1]), int(prop.centroid[0])),
                'region_mask': region_mask,
                'region_mask_cropped': region_mask_cropped,
                'region_image': region_image,
                'confidence': region_confidence,
                'mask_probability_mean': region_confidence
            })

        return regions

    def extract_for_casia(self,
                          binary_mask: np.ndarray,
                          probability_map: np.ndarray,
                          original_image: np.ndarray) -> List[Dict]:
        """
        Critical Fix #6: CASIA handling — treat the entire image as one region.

        Args:
            binary_mask: Binary mask (may be empty for authentic images)
            probability_map: Probability map
            original_image: Original image

        Returns:
            Single region covering the entire image
        """
        h, w = original_image.shape[:2]

        # Single region covering the whole image
        region_mask = np.ones((h, w), dtype=np.uint8)

        # Overall confidence = mean of the full probability map
        overall_confidence = float(np.mean(probability_map))

        return [{
            'region_id': 1,
            'bounding_box': [0, 0, w, h],
            'area': h * w,
            'centroid': (w // 2, h // 2),
            'region_mask': region_mask,
            'region_mask_cropped': region_mask,
            'region_image': original_image,
            'confidence': overall_confidence,
            'mask_probability_mean': overall_confidence
        }]
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def get_mask_refiner(config, dataset_name: str = 'default') -> MaskRefiner:
    """Build a MaskRefiner configured for the given dataset."""
    refiner = MaskRefiner(config, dataset_name)
    return refiner
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def get_region_extractor(config, dataset_name: str = 'default') -> RegionExtractor:
    """Build a RegionExtractor configured for the given dataset."""
    extractor = RegionExtractor(config, dataset_name)
    return extractor
|
src/inference/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Inference module"""
|
| 2 |
+
|
| 3 |
+
from .pipeline import ForgeryDetectionPipeline, get_pipeline
|
| 4 |
+
|
| 5 |
+
__all__ = ['ForgeryDetectionPipeline', 'get_pipeline']
|
src/inference/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (305 Bytes). View file
|
|
|
src/inference/__pycache__/pipeline.cpython-312.pyc
ADDED
|
Binary file (14.5 kB). View file
|
|
|
src/inference/pipeline.py
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inference pipeline for document forgery detection
|
| 3 |
+
Complete pipeline: Image → Localization → Regions → Classification → Output
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from typing import Dict, List, Optional, Tuple
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import json
|
| 12 |
+
from PIL import Image
|
| 13 |
+
import fitz # PyMuPDF
|
| 14 |
+
|
| 15 |
+
from ..config import get_config
|
| 16 |
+
from ..models import get_model
|
| 17 |
+
from ..features import (
|
| 18 |
+
get_feature_extractor,
|
| 19 |
+
get_mask_refiner,
|
| 20 |
+
get_region_extractor
|
| 21 |
+
)
|
| 22 |
+
from ..training.classifier import get_classifier
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ForgeryDetectionPipeline:
    """
    Complete inference pipeline for document forgery detection

    Pipeline:
    1. Input handling (PDF/Image)
    2. Preprocessing
    3. Deep localization
    4. Mask refinement
    5. Region extraction
    6. Feature extraction
    7. Classification
    8. Post-processing
    9. Output generation
    """

    def __init__(self,
                 config,
                 model_path: str,
                 classifier_path: Optional[str] = None,
                 is_text_document: bool = True):
        """
        Initialize pipeline

        Args:
            config: Configuration object
            model_path: Path to localization model checkpoint
            classifier_path: Path to classifier (optional)
            is_text_document: Whether input is text document (for OCR features)
        """
        self.config = config
        self.is_text_document = is_text_document

        # Device: CUDA only when it is both available and requested in config
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() and config.get('system.device') == 'cuda'
            else 'cpu'
        )
        print(f"Inference device: {self.device}")

        # Load localization model
        self.model = get_model(config).to(self.device)
        self._load_model(model_path)
        self.model.eval()

        # Initialize mask refiner
        self.mask_refiner = get_mask_refiner(config, 'default')

        # Initialize region extractor
        self.region_extractor = get_region_extractor(config, 'default')

        # Initialize feature extractor
        self.feature_extractor = get_feature_extractor(config, is_text_document)

        # Load classifier if provided
        if classifier_path:
            self.classifier = get_classifier(config)
            self.classifier.load(classifier_path)
        else:
            self.classifier = None

        # Confidence threshold used by the post-processing filter
        self.confidence_threshold = config.get('classifier.confidence_threshold', 0.6)

        # Square input size expected by the localization model
        self.image_size = config.get('data.image_size', 384)

        print("Inference pipeline initialized")

    def _load_model(self, model_path: str):
        """Load model checkpoint (full training checkpoint or bare state dict)."""
        checkpoint = torch.load(model_path, map_location=self.device)

        if 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.model.load_state_dict(checkpoint)

        print(f"Loaded model from {model_path}")

    def _load_image(self, input_path: str) -> np.ndarray:
        """
        Load image from file or PDF (first page only for PDFs)

        Args:
            input_path: Path to image or PDF

        Returns:
            Image as numpy array (H, W, 3), RGB channel order

        Raises:
            FileNotFoundError: If a non-PDF image cannot be read
        """
        path = Path(input_path)

        if path.suffix.lower() == '.pdf':
            # Rasterize first page at 300 DPI (PDF native resolution is 72)
            doc = fitz.open(str(path))
            try:
                page = doc[0]
                mat = fitz.Matrix(300 / 72, 300 / 72)  # 300 DPI
                pix = page.get_pixmap(matrix=mat)
                image = np.frombuffer(pix.samples, dtype=np.uint8)
                image = image.reshape(pix.height, pix.width, pix.n)
                if pix.n == 4:
                    image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
            finally:
                # Always release the document handle, even on rasterize errors
                doc.close()
        else:
            image = cv2.imread(str(path))
            if image is None:
                # cv2.imread returns None (no exception) when the file is
                # missing or unreadable; fail loudly instead of crashing
                # later in cvtColor with an opaque error.
                raise FileNotFoundError(f"Could not read image: {input_path}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        return image

    def _preprocess(self, image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Preprocess image for inference

        Args:
            image: Input image (H, W, 3)

        Returns:
            Preprocessed image (resized, float32 in [0, 1]) and original image
        """
        original = image.copy()

        # Resize to the model's square input size
        preprocessed = cv2.resize(image, (self.image_size, self.image_size))

        # Normalize to [0, 1]
        preprocessed = preprocessed.astype(np.float32) / 255.0

        return preprocessed, original

    def _to_tensor(self, image: np.ndarray) -> torch.Tensor:
        """Convert an (H, W, C) image to a (1, C, H, W) tensor on self.device."""
        # (H, W, C) -> (C, H, W)
        tensor = torch.from_numpy(image.transpose(2, 0, 1))
        tensor = tensor.unsqueeze(0)  # Add batch dimension
        return tensor.to(self.device)

    def run(self,
            input_path: str,
            output_dir: Optional[str] = None) -> Dict:
        """
        Run full inference pipeline

        Args:
            input_path: Path to input image or PDF
            output_dir: Optional output directory for mask/overlay/JSON files

        Returns:
            Dictionary with results (regions, tamper flag, optional file paths)
        """
        print(f"\n{'='*60}")
        print(f"Processing: {input_path}")
        print(f"{'='*60}")

        # 1. Load image
        image = self._load_image(input_path)
        original_size = image.shape[:2]
        print(f"Input size: {original_size}")

        # 2. Preprocess
        preprocessed, original = self._preprocess(image)
        tensor = self._to_tensor(preprocessed)

        # 3. Deep localization
        with torch.no_grad():
            logits, decoder_features = self.model(tensor)
            probability_map = torch.sigmoid(logits).cpu().numpy()[0, 0]

        print(f"Localization complete. Max prob: {probability_map.max():.3f}")

        # 4. Mask refinement (refiner also resizes back to original_size)
        binary_mask = self.mask_refiner.refine(probability_map, original_size)
        num_positive_pixels = binary_mask.sum()
        print(f"Mask refinement: {num_positive_pixels} positive pixels")

        # 5. Region extraction
        # Resize probability map to original size for confidence aggregation
        prob_resized = cv2.resize(probability_map, (original_size[1], original_size[0]))

        regions = self.region_extractor.extract(binary_mask, prob_resized, original)
        print(f"Regions extracted: {len(regions)}")

        # 6. Feature extraction & 7. Classification
        results = []

        for region in regions:
            # Extract hybrid features for this region (mask resized to model scale)
            features = self.feature_extractor.extract(
                preprocessed,
                cv2.resize(region['region_mask'], (self.image_size, self.image_size)),
                [f.cpu() for f in decoder_features]
            )

            # Classify if classifier available
            if self.classifier is not None:
                predictions, confidences, valid_mask = self.classifier.predict_with_filtering(
                    features.reshape(1, -1)
                )

                if valid_mask[0]:
                    region['forgery_type'] = self.classifier.get_class_name(predictions[0])
                    region['classification_confidence'] = float(confidences[0])
                else:
                    # Low classifier confidence - discard the region entirely
                    continue
            else:
                region['forgery_type'] = 'unknown'
                region['classification_confidence'] = region['confidence']

            # Keep only JSON-serializable fields (drop masks/image arrays)
            region_result = {
                'region_id': region['region_id'],
                'bounding_box': region['bounding_box'],
                'forgery_type': region['forgery_type'],
                'confidence': region['confidence'],
                'classification_confidence': region['classification_confidence'],
                'mask_probability_mean': region['mask_probability_mean'],
                # Cast defensively: upstream may hand back a numpy integer,
                # which json.dump cannot serialize.
                'area': int(region['area'])
            }
            results.append(region_result)

        print(f"Valid regions after filtering: {len(results)}")

        # 8. Post-processing - False positive removal
        results = self._post_process(results)

        # 9. Generate output
        output = {
            'input_path': str(input_path),
            'original_size': original_size,
            'num_regions': len(results),
            'regions': results,
            'is_tampered': len(results) > 0
        }

        # Save outputs if directory provided
        if output_dir:
            output_path = Path(output_dir)
            output_path.mkdir(parents=True, exist_ok=True)

            input_name = Path(input_path).stem

            # Save final mask (0/255 grayscale PNG)
            mask_path = output_path / f'{input_name}_mask.png'
            cv2.imwrite(str(mask_path), binary_mask * 255)

            # Save overlay visualization
            overlay = self._create_overlay(original, binary_mask, results)
            overlay_path = output_path / f'{input_name}_overlay.png'
            cv2.imwrite(str(overlay_path), cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))

            # Save JSON
            json_path = output_path / f'{input_name}_results.json'
            with open(json_path, 'w') as f:
                json.dump(output, f, indent=2)

            print(f"\nOutputs saved to: {output_path}")
            output['mask_path'] = str(mask_path)
            output['overlay_path'] = str(overlay_path)
            output['json_path'] = str(json_path)

        return output

    def _post_process(self, regions: List[Dict]) -> List[Dict]:
        """
        Post-process regions to remove false positives

        Args:
            regions: List of region dictionaries

        Returns:
            Regions whose confidence meets the configured threshold
        """
        filtered = []

        for region in regions:
            # Confidence filtering
            if region['confidence'] < self.confidence_threshold:
                continue

            filtered.append(region)

        return filtered

    def _create_overlay(self,
                        image: np.ndarray,
                        mask: np.ndarray,
                        regions: List[Dict]) -> np.ndarray:
        """
        Create visualization overlay

        Args:
            image: Original image
            mask: Binary mask
            regions: Detected regions

        Returns:
            Overlay image with tampered pixels tinted red and regions boxed
        """
        alpha = self.config.get('outputs.visualization.overlay_alpha', 0.5)

        # Resize the mask to the image's resolution FIRST, so both the
        # colored mask and the blend use the same shape (indexing an
        # image-sized array with an unresized mask breaks when sizes
        # differ). Nearest interpolation keeps the mask strictly binary.
        mask_resized = cv2.resize(mask, (image.shape[1], image.shape[0]),
                                  interpolation=cv2.INTER_NEAREST)

        # Create colored mask
        mask_colored = np.zeros_like(image)
        mask_colored[mask_resized > 0] = [255, 0, 0]  # Red for forgery

        # Blend tampered pixels with the red tint
        overlay = np.where(
            mask_resized[:, :, None] > 0,
            (1 - alpha) * image + alpha * mask_colored,
            image
        ).astype(np.uint8)

        # Draw bounding boxes and labels
        for region in regions:
            x, y, w, h = region['bounding_box']

            # Draw rectangle
            cv2.rectangle(overlay, (x, y), (x + w, y + h), (0, 255, 0), 2)

            # Draw label
            label = f"{region['forgery_type']} ({region['confidence']:.2f})"
            cv2.putText(overlay, label, (x, y - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

        return overlay
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def get_pipeline(config,
                 model_path: str,
                 classifier_path: Optional[str] = None,
                 is_text_document: bool = True) -> ForgeryDetectionPipeline:
    """Build a ForgeryDetectionPipeline from the given config and model paths."""
    return ForgeryDetectionPipeline(
        config,
        model_path,
        classifier_path=classifier_path,
        is_text_document=is_text_document,
    )
|
src/models/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Models module"""
|
| 2 |
+
|
| 3 |
+
from .encoder import MobileNetV3Encoder, get_encoder
|
| 4 |
+
from .decoder import UNetLiteDecoder, get_decoder
|
| 5 |
+
from .network import ForgeryLocalizationNetwork, get_model
|
| 6 |
+
from .losses import DiceLoss, CombinedLoss, DatasetAwareLoss, get_loss_function
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
'MobileNetV3Encoder',
|
| 10 |
+
'get_encoder',
|
| 11 |
+
'UNetLiteDecoder',
|
| 12 |
+
'get_decoder',
|
| 13 |
+
'ForgeryLocalizationNetwork',
|
| 14 |
+
'get_model',
|
| 15 |
+
'DiceLoss',
|
| 16 |
+
'CombinedLoss',
|
| 17 |
+
'DatasetAwareLoss',
|
| 18 |
+
'get_loss_function'
|
| 19 |
+
]
|
src/models/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (600 Bytes). View file
|
|
|
src/models/__pycache__/decoder.cpython-312.pyc
ADDED
|
Binary file (7.65 kB). View file
|
|
|
src/models/__pycache__/encoder.cpython-312.pyc
ADDED
|
Binary file (2.91 kB). View file
|
|
|
src/models/__pycache__/losses.cpython-312.pyc
ADDED
|
Binary file (6.55 kB). View file
|
|
|
src/models/__pycache__/network.cpython-312.pyc
ADDED
|
Binary file (5.84 kB). View file
|
|
|
src/models/decoder.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
UNet-Lite Decoder for forgery localization
|
| 3 |
+
Lightweight decoder with skip connections, depthwise separable convolutions
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from typing import List
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DepthwiseSeparableConv(nn.Module):
    """Depthwise separable convolution: a per-channel spatial convolution
    followed by a 1x1 pointwise mix, then BatchNorm and ReLU."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()

        # Depthwise stage: one spatial filter per input channel
        # (groups == in_channels), 'same' padding, no bias (BN follows).
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=in_channels,
            bias=False,
        )
        # Pointwise stage: 1x1 conv mixes channels to the target width
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply depthwise conv, pointwise conv, batch norm, then ReLU."""
        return self.relu(self.bn(self.pointwise(self.depthwise(x))))
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class DecoderBlock(nn.Module):
    """Single decoder stage: upsample, fuse the encoder skip, then refine."""

    def __init__(self, in_channels: int, skip_channels: int, out_channels: int):
        """
        Build one decoder block.

        Args:
            in_channels: Channels arriving from the previous decoder stage
            skip_channels: Channels of the encoder skip connection
            out_channels: Channels produced by this block
        """
        super().__init__()

        # After concatenation the tensor carries both feature streams
        fused_channels = in_channels + skip_channels

        self.conv1 = DepthwiseSeparableConv(fused_channels, out_channels)
        self.conv2 = DepthwiseSeparableConv(out_channels, out_channels)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """
        Upsample ``x`` to the skip's spatial size, concatenate along the
        channel axis, and refine with two depthwise-separable convolutions.

        Args:
            x: Features from the previous decoder stage
            skip: Skip-connection features from the encoder

        Returns:
            Decoded features at the skip connection's resolution
        """
        upsampled = F.interpolate(
            x, size=skip.shape[2:], mode='bilinear', align_corners=False
        )
        fused = torch.cat([upsampled, skip], dim=1)
        return self.conv2(self.conv1(fused))
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class UNetLiteDecoder(nn.Module):
    """
    UNet-Lite decoder for forgery localization

    Features:
    - Skip connections from encoder stages
    - Bilinear upsampling
    - Depthwise separable convolutions for efficiency
    """

    def __init__(self,
                 encoder_channels: List[int],
                 decoder_channels: List[int] = None,
                 output_channels: int = 1):
        """
        Initialize decoder

        Args:
            encoder_channels: List of encoder feature channels [stage0, ..., stageN]
            decoder_channels: List of decoder output channels (defaults to
                [256, 128, 64, 32, 16] when None)
            output_channels: Number of output channels (1 for binary mask)
        """
        super().__init__()

        # Default decoder channels if not provided
        if decoder_channels is None:
            decoder_channels = [256, 128, 64, 32, 16]

        # Reverse encoder channels for decoder (bottom to top)
        encoder_channels = encoder_channels[::-1]

        # Initial convolution from deepest encoder features
        self.initial_conv = DepthwiseSeparableConv(encoder_channels[0], decoder_channels[0])

        # Decoder blocks, one per remaining encoder stage
        self.decoder_blocks = nn.ModuleList()

        for i in range(len(encoder_channels) - 1):
            in_ch = decoder_channels[i]
            skip_ch = encoder_channels[i + 1]
            # Clamp to the last decoder width if more encoder stages exist
            out_ch = decoder_channels[i + 1] if i + 1 < len(decoder_channels) else decoder_channels[-1]

            self.decoder_blocks.append(
                DecoderBlock(in_ch, skip_ch, out_ch)
            )

        # Final refinement + 1x1 projection to the output channel count
        self.final_upsample = nn.Sequential(
            DepthwiseSeparableConv(decoder_channels[-1], decoder_channels[-1]),
            nn.Conv2d(decoder_channels[-1], output_channels, kernel_size=1)
        )

        # Store decoder feature channels for feature extraction
        self.decoder_channels = decoder_channels

        print(f"UNet-Lite decoder initialized")
        print(f"Encoder channels: {encoder_channels[::-1]}")
        print(f"Decoder channels: {decoder_channels}")

    def forward(self, encoder_features: List[torch.Tensor]) -> tuple:
        """
        Forward pass

        Args:
            encoder_features: List of encoder features [stage0, ..., stageN]

        Returns:
            output: Forgery probability map logits (B, 1, H, W)
            decoder_features: List of decoder features for hybrid extraction
        """
        # Reverse for bottom-up decoding (deepest features first)
        features = encoder_features[::-1]

        # Initial convolution
        x = self.initial_conv(features[0])

        # Store decoder features for hybrid feature extraction
        decoder_features = [x]

        # Decoder blocks with skip connections
        for i, block in enumerate(self.decoder_blocks):
            x = block(x, features[i + 1])
            decoder_features.append(x)

        # Final upsampling back to the input resolution. The first encoder
        # stage is at 1/2 scale, so the target is twice its spatial size.
        # H and W are taken separately so non-square inputs are supported
        # (the previous code forced a square (H*2, H*2) target).
        target_h = encoder_features[0].shape[2] * 2
        target_w = encoder_features[0].shape[3] * 2
        x = F.interpolate(x, size=(target_h, target_w), mode='bilinear', align_corners=False)
        output = self.final_upsample(x)

        return output, decoder_features
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def get_decoder(encoder_channels: List[int], config) -> UNetLiteDecoder:
    """
    Build a UNet-Lite decoder matched to the given encoder.

    Args:
        encoder_channels: Encoder feature channels
        config: Configuration object

    Returns:
        Decoder instance
    """
    n_out = config.get('model.output_channels', 1)
    return UNetLiteDecoder(encoder_channels, output_channels=n_out)
|
src/models/encoder.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MobileNetV3-Small Encoder for forgery localization
|
| 3 |
+
ImageNet pretrained, feature extraction mode
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import timm
|
| 9 |
+
from typing import List
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class MobileNetV3Encoder(nn.Module):
    """
    MobileNetV3-Small encoder for document forgery detection

    Chosen for:
    - Stroke-level and texture preservation
    - Robustness to compression and blur
    - Edge and CPU deployment efficiency
    """

    def __init__(self, pretrained: bool = True):
        """
        Build the encoder backbone.

        Args:
            pretrained: Whether to use ImageNet pretrained weights
        """
        super().__init__()

        # timm feature-extraction mode: forward() returns one tensor per
        # requested stage instead of classification logits.
        self.backbone = timm.create_model(
            'mobilenetv3_small_100',
            pretrained=pretrained,
            features_only=True,
            out_indices=(0, 1, 2, 3, 4)  # All feature stages
        )

        # Per-stage channel counts reported by timm
        # (MobileNetV3-Small: [16, 16, 24, 48, 576])
        self.feature_channels = self.backbone.feature_info.channels()

        print(f"MobileNetV3-Small encoder initialized")
        print(f"Feature channels: {self.feature_channels}")

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Extract multi-scale features.

        Args:
            x: Input tensor (B, 3, H, W)

        Returns:
            List of feature tensors, one per stage
        """
        return self.backbone(x)

    def get_feature_channels(self) -> List[int]:
        """Channel dimension of each feature stage, shallow to deep."""
        return self.feature_channels
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def get_encoder(config) -> MobileNetV3Encoder:
    """
    Build the encoder, honoring model.encoder.pretrained (default True).

    Args:
        config: Configuration object

    Returns:
        Encoder instance
    """
    use_pretrained = config.get('model.encoder.pretrained', True)
    return MobileNetV3Encoder(pretrained=use_pretrained)
|
src/models/losses.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dataset-aware loss functions
|
| 3 |
+
Implements Critical Fix #2: Dataset-Aware Loss Function
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from typing import Dict, Optional
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation.

    Applies a sigmoid to the incoming logits and returns one minus the
    (smoothed) Dice coefficient computed over the whole batch.
    """

    def __init__(self, smooth: float = 1.0):
        """
        Args:
            smooth: Additive smoothing term that keeps the ratio finite
                when both prediction and target are empty.
        """
        super().__init__()
        self.smooth = smooth

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the Dice loss.

        Args:
            pred: Raw logits of shape (B, 1, H, W).
            target: Binary ground-truth mask of shape (B, 1, H, W).

        Returns:
            Scalar loss tensor (1 - Dice coefficient).
        """
        probs = torch.sigmoid(pred).reshape(-1)
        truth = target.reshape(-1)

        overlap = torch.sum(probs * truth)
        denom = torch.sum(probs) + torch.sum(truth) + self.smooth
        dice_coeff = (2.0 * overlap + self.smooth) / denom

        return 1 - dice_coeff
+
|
| 51 |
+
class CombinedLoss(nn.Module):
    """Weighted BCE + Dice loss for segmentation.

    Dataset-aware: the Dice term is only added when the batch comes from
    a dataset that provides pixel-level masks; otherwise only BCE is used.
    """

    def __init__(self,
                 bce_weight: float = 1.0,
                 dice_weight: float = 1.0):
        """
        Args:
            bce_weight: Multiplier for the BCE term.
            dice_weight: Multiplier for the Dice term.
        """
        super().__init__()

        self.bce_weight = bce_weight
        self.dice_weight = dice_weight

        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = DiceLoss()

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                has_pixel_mask: bool = True) -> Dict[str, torch.Tensor]:
        """Compute the (dataset-aware) combined loss.

        Critical Fix #2: the Dice term is skipped for datasets without
        pixel-level masks (e.g. CASIA), which contribute BCE only.

        Args:
            pred: Raw logits of shape (B, 1, H, W).
            target: Ground-truth mask of shape (B, 1, H, W).
            has_pixel_mask: True when the batch carries pixel annotations.

        Returns:
            Dict with 'total', 'bce' and, when applicable, 'dice' entries.
        """
        bce_term = self.bce_loss(pred, target)
        losses = {'bce': bce_term}

        total = self.bce_weight * bce_term
        if has_pixel_mask:
            dice_term = self.dice_loss(pred, target)
            losses['dice'] = dice_term
            total = total + self.dice_weight * dice_term
        losses['total'] = total

        return losses
+
|
| 111 |
+
class DatasetAwareLoss(nn.Module):
    """Loss wrapper that derives the appropriate terms from batch metadata."""

    def __init__(self, config):
        """
        Args:
            config: Configuration object providing ``get(key, default)``;
                reads 'loss.bce_weight' and 'loss.dice_weight'.
        """
        super().__init__()

        self.config = config
        self.combined_loss = CombinedLoss(
            bce_weight=config.get('loss.bce_weight', 1.0),
            dice_weight=config.get('loss.dice_weight', 1.0),
        )

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                metadata: Dict) -> Dict[str, torch.Tensor]:
        """Compute the combined loss, honoring per-sample mask availability.

        Args:
            pred: Raw logits (B, 1, H, W).
            target: Ground-truth mask (B, 1, H, W).
            metadata: Either a dict or a list of per-sample dicts; each may
                carry a boolean 'has_pixel_mask' flag (defaults to True).

        Returns:
            Dict with the individual loss components and their total.
        """
        if isinstance(metadata, list):
            # Dice is only valid when every sample in the batch has a mask.
            mask_available = all(item.get('has_pixel_mask', True) for item in metadata)
        else:
            mask_available = metadata.get('has_pixel_mask', True)

        return self.combined_loss(pred, target, mask_available)
| 157 |
+
|
| 158 |
+
def get_loss_function(config) -> DatasetAwareLoss:
    """Factory: build a ``DatasetAwareLoss`` from the given configuration.

    Args:
        config: Configuration object forwarded to the loss constructor.

    Returns:
        A ready-to-use loss module.
    """
    return DatasetAwareLoss(config)
src/models/network.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Complete Forgery Localization Network
|
| 3 |
+
MobileNetV3-Small Encoder + UNet-Lite Decoder
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from typing import Tuple, List, Optional
|
| 9 |
+
|
| 10 |
+
from .encoder import MobileNetV3Encoder
|
| 11 |
+
from .decoder import UNetLiteDecoder
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ForgeryLocalizationNetwork(nn.Module):
    """End-to-end forgery localization model.

    Architecture:
        - Encoder: MobileNetV3-Small (optionally ImageNet pretrained)
        - Decoder: UNet-Lite with skip connections
        - Output: single-channel forgery probability map (returned as logits)
    """

    def __init__(self, config):
        """
        Args:
            config: Configuration object providing ``get(key, default)``.
        """
        super().__init__()

        self.config = config

        # Backbone feature extractor.
        self.encoder = MobileNetV3Encoder(
            pretrained=config.get('model.encoder.pretrained', True)
        )

        # Decoder consumes the encoder's per-stage channel layout.
        self.decoder = UNetLiteDecoder(
            encoder_channels=self.encoder.get_feature_channels(),
            output_channels=config.get('model.output_channels', 1),
        )

        print("ForgeryLocalizationNetwork initialized")
        print(f"Total parameters: {self.count_parameters():,}")

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Encode then decode.

        Args:
            x: Input image tensor (B, 3, H, W).

        Returns:
            Tuple of (forgery logit map of shape (B, 1, H, W),
            list of decoder features for hybrid feature extraction).
        """
        return self.decoder(self.encoder(x))

    def predict(self, x: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        """Binarize the predicted probability map (no gradients).

        Args:
            x: Input image tensor (B, 3, H, W).
            threshold: Cutoff applied to the sigmoid probabilities.

        Returns:
            Binary float mask (B, 1, H, W).
        """
        with torch.no_grad():
            logits, _ = self.forward(x)
            return (torch.sigmoid(logits) > threshold).float()

    def get_probability_map(self, x: torch.Tensor) -> torch.Tensor:
        """Return sigmoid probabilities for the input batch (no gradients).

        Args:
            x: Input image tensor (B, 3, H, W).

        Returns:
            Probability map (B, 1, H, W).
        """
        with torch.no_grad():
            logits, _ = self.forward(x)
            return torch.sigmoid(logits)

    def count_parameters(self) -> int:
        """Total number of trainable parameters in the whole network."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def get_decoder_features(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Run the network and return only the decoder features (no gradients).

        Args:
            x: Input image tensor (B, 3, H, W).

        Returns:
            List of decoder feature tensors.
        """
        with torch.no_grad():
            _, decoder_feats = self.forward(x)
        return decoder_feats
+
|
| 123 |
+
def get_model(config) -> ForgeryLocalizationNetwork:
    """Factory: build the forgery localization network from configuration.

    Args:
        config: Configuration object forwarded to the model constructor.

    Returns:
        A ready-to-use model instance.
    """
    return ForgeryLocalizationNetwork(config)
src/training/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Training module"""
|
| 2 |
+
|
| 3 |
+
from .metrics import (
|
| 4 |
+
SegmentationMetrics,
|
| 5 |
+
ClassificationMetrics,
|
| 6 |
+
MetricsTracker,
|
| 7 |
+
EarlyStopping,
|
| 8 |
+
get_metrics_tracker
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
from .trainer import Trainer, get_trainer
|
| 12 |
+
from .classifier import ForgeryClassifier, get_classifier
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
'SegmentationMetrics',
|
| 16 |
+
'ClassificationMetrics',
|
| 17 |
+
'MetricsTracker',
|
| 18 |
+
'EarlyStopping',
|
| 19 |
+
'get_metrics_tracker',
|
| 20 |
+
'Trainer',
|
| 21 |
+
'get_trainer',
|
| 22 |
+
'ForgeryClassifier',
|
| 23 |
+
'get_classifier'
|
| 24 |
+
]
|
src/training/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (568 Bytes). View file
|
|
|
src/training/__pycache__/classifier.cpython-312.pyc
ADDED
|
Binary file (11 kB). View file
|
|
|
src/training/__pycache__/metrics.cpython-312.pyc
ADDED
|
Binary file (12.5 kB). View file
|
|
|
src/training/__pycache__/trainer.cpython-312.pyc
ADDED
|
Binary file (18.8 kB). View file
|
|
|
src/training/classifier.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LightGBM classifier for forgery type classification
|
| 3 |
+
Implements Critical Fix #8: Configurable Confidence Threshold
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import lightgbm as lgb
|
| 8 |
+
from sklearn.preprocessing import StandardScaler
|
| 9 |
+
from sklearn.model_selection import train_test_split
|
| 10 |
+
from typing import Dict, List, Tuple, Optional
|
| 11 |
+
import joblib
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
import json
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ForgeryClassifier:
    """
    LightGBM classifier for region-wise forgery type classification.

    Implements Critical Fix #8: configurable confidence threshold.

    Target classes:
        - 0: copy_move
        - 1: splicing
        - 2: text_substitution
    """

    CLASS_NAMES = ['copy_move', 'splicing', 'text_substitution']

    def __init__(self, config):
        """
        Initialize classifier.

        Args:
            config: Configuration object providing ``get(key, default)``.
        """
        self.config = config

        # LightGBM parameters (overridable via 'classifier.params').
        self.params = config.get('classifier.params', {
            'objective': 'multiclass',
            'num_class': 3,
            'boosting_type': 'gbdt',
            'num_leaves': 31,
            'learning_rate': 0.05,
            'n_estimators': 200,
            'max_depth': 7,
            'min_child_samples': 20,
            'subsample': 0.8,
            'colsample_bytree': 0.8,
            'reg_alpha': 0.1,
            'reg_lambda': 0.1,
            'random_state': 42,
            'verbose': -1
        })

        # Critical Fix #8: configurable confidence threshold.
        self.confidence_threshold = config.get('classifier.confidence_threshold', 0.6)

        # Populated by train() or load().
        self.model = None
        self.scaler = StandardScaler()

        # Feature-importance bookkeeping (filled in after training/loading).
        self.feature_importance = None
        self.feature_names = None

    @staticmethod
    def _sanitize(features: np.ndarray) -> np.ndarray:
        """Replace NaN/Inf so the scaler and booster never see them."""
        return np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)

    def train(self,
              features: np.ndarray,
              labels: np.ndarray,
              feature_names: Optional[List[str]] = None,
              validation_split: float = 0.2) -> Dict:
        """
        Train the classifier.

        Args:
            features: Feature matrix (N, D).
            labels: Class labels (N,).
            feature_names: Optional feature names (used for importances).
            validation_split: Fraction of samples held out for validation.

        Returns:
            Dict of training metrics (accuracies, sizes, best iteration).
        """
        print("Training LightGBM classifier")
        print(f"Features shape: {features.shape}")
        print(f"Labels distribution: {np.bincount(labels)}")

        # Sanitize and normalize features.
        features_scaled = self.scaler.fit_transform(self._sanitize(features))

        # Split data (Critical Fix #7: image-level splitting should be done
        # upstream; this is a stratified row-level split).
        X_train, X_val, y_train, y_val = train_test_split(
            features_scaled, labels,
            test_size=validation_split,
            random_state=42,
            stratify=labels
        )

        train_data = lgb.Dataset(X_train, label=y_train)
        val_data = lgb.Dataset(X_val, label=y_val, reference=train_data)

        self.model = lgb.train(
            self.params,
            train_data,
            valid_sets=[train_data, val_data],
            valid_names=['train', 'val'],
            num_boost_round=self.params.get('n_estimators', 200),
            callbacks=[
                lgb.early_stopping(stopping_rounds=20),
                lgb.log_evaluation(period=10)
            ]
        )

        # Store feature importance for later inspection.
        self.feature_names = feature_names
        self.feature_importance = self.model.feature_importance(importance_type='gain')

        # Evaluate on both splits.
        train_acc = (self.model.predict(X_train).argmax(axis=1) == y_train).mean()
        val_acc = (self.model.predict(X_val).argmax(axis=1) == y_val).mean()

        metrics = {
            'train_accuracy': train_acc,
            'val_accuracy': val_acc,
            'num_features': features.shape[1],
            'num_samples': len(labels),
            'best_iteration': self.model.best_iteration
        }

        print("Training complete!")
        print(f"Train accuracy: {train_acc:.4f}")
        print(f"Val accuracy: {val_acc:.4f}")

        return metrics

    def predict(self, features: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Predict forgery types.

        Args:
            features: Feature matrix (N, D).

        Returns:
            predictions: Predicted class indices (N,).
            confidences: Max class probability per sample (N,).

        Raises:
            ValueError: If the model has not been trained or loaded.
        """
        if self.model is None:
            raise ValueError("Model not trained. Call train() first.")

        features_scaled = self.scaler.transform(self._sanitize(features))
        probabilities = self.model.predict(features_scaled)

        predictions = probabilities.argmax(axis=1)
        confidences = probabilities.max(axis=1)

        return predictions, confidences

    def predict_with_filtering(self,
                               features: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Predict with confidence filtering.

        Args:
            features: Feature matrix (N, D).

        Returns:
            predictions: Predicted class indices (N,).
            confidences: Prediction confidences (N,).
            valid_mask: Boolean mask of predictions meeting the threshold (N,).
        """
        predictions, confidences = self.predict(features)

        # Critical Fix #8: apply configurable confidence threshold.
        valid_mask = confidences >= self.confidence_threshold

        return predictions, confidences, valid_mask

    def get_class_name(self, class_idx: int) -> str:
        """Get the human-readable class name for a class index."""
        return self.CLASS_NAMES[class_idx]

    def get_feature_importance(self, top_k: int = 20) -> List[Tuple[str, float]]:
        """
        Get the top-k most important features by gain.

        Args:
            top_k: Number of features to return.

        Returns:
            List of (feature_name, importance) tuples, empty when no
            importance information is available.
        """
        # Guard both "never set" and "loaded without importances".
        if self.feature_importance is None or len(self.feature_importance) == 0:
            return []

        ranked = np.argsort(self.feature_importance)[::-1][:top_k]

        result = []
        for idx in ranked:
            name = self.feature_names[idx] if self.feature_names else f'feature_{idx}'
            result.append((name, self.feature_importance[idx]))

        return result

    def save(self, save_dir: str):
        """
        Save model, scaler, and metadata.

        Args:
            save_dir: Directory to save into (created if missing).

        Raises:
            ValueError: If the model has not been trained yet.
        """
        # Fail with a clear message instead of AttributeError on save_model.
        if self.model is None:
            raise ValueError("Model not trained. Call train() first.")

        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

        # Save LightGBM model.
        self.model.save_model(str(save_path / 'lightgbm_model.txt'))

        # Save fitted scaler.
        joblib.dump(self.scaler, str(save_path / 'scaler.joblib'))

        # Save metadata alongside the model.
        metadata = {
            'confidence_threshold': self.confidence_threshold,
            'class_names': self.CLASS_NAMES,
            'feature_names': self.feature_names,
            'feature_importance': (
                self.feature_importance.tolist()
                if self.feature_importance is not None else None
            )
        }
        with open(save_path / 'classifier_metadata.json', 'w') as f:
            json.dump(metadata, f, indent=2)

        print(f"Classifier saved to {save_path}")

    def load(self, load_dir: str):
        """
        Load model, scaler, and metadata.

        Args:
            load_dir: Directory to load from.
        """
        load_path = Path(load_dir)

        # Load LightGBM model.
        self.model = lgb.Booster(model_file=str(load_path / 'lightgbm_model.txt'))

        # Load fitted scaler.
        self.scaler = joblib.load(str(load_path / 'scaler.joblib'))

        # Load metadata.
        with open(load_path / 'classifier_metadata.json', 'r') as f:
            metadata = json.load(f)

        self.confidence_threshold = metadata.get('confidence_threshold', 0.6)
        self.feature_names = metadata.get('feature_names')

        # Fix: a saved null must map back to None, not np.array(None),
        # which is a 0-d object array that defeats the `is None` guard
        # in get_feature_importance and crashes argsort.
        raw_importance = metadata.get('feature_importance')
        self.feature_importance = (
            np.array(raw_importance) if raw_importance is not None else None
        )

        print(f"Classifier loaded from {load_path}")
| 279 |
+
|
| 280 |
+
def get_classifier(config) -> ForgeryClassifier:
    """Factory: build a ``ForgeryClassifier`` from the given configuration."""
    return ForgeryClassifier(config)
src/training/metrics.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training utilities and metrics
|
| 3 |
+
Implements Critical Fix #9: Dataset-Aware Metric Computation
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
from typing import Dict, List, Optional
|
| 9 |
+
from sklearn.metrics import (
|
| 10 |
+
accuracy_score, f1_score, precision_score, recall_score,
|
| 11 |
+
confusion_matrix
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SegmentationMetrics:
    """
    Running IoU / Dice / precision / recall for binary segmentation.

    Only updated for datasets that ship pixel-level masks
    (Critical Fix #9: dataset-aware metric computation).
    """

    def __init__(self):
        """Start with empty accumulators."""
        self.reset()

    def reset(self):
        """Clear all running totals."""
        self.intersection = 0
        self.union = 0
        self.pred_sum = 0
        self.target_sum = 0
        self.total_samples = 0

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               has_pixel_mask: bool = True):
        """Accumulate statistics from one batch.

        Args:
            pred: Predicted probabilities (B, 1, H, W).
            target: Ground-truth masks (B, 1, H, W).
            has_pixel_mask: When False the batch is skipped entirely
                (Critical Fix #9).
        """
        if not has_pixel_mask:
            return

        # Hard predictions at the 0.5 operating point.
        hard_pred = (pred > 0.5).float()

        overlap = float((hard_pred * target).sum())
        pred_total = float(hard_pred.sum())
        target_total = float(target.sum())

        self.intersection += overlap
        self.union += pred_total + target_total - overlap
        self.pred_sum += pred_total
        self.target_sum += target_total
        self.total_samples += pred.shape[0]

    def compute(self) -> Dict[str, float]:
        """Compute final metrics from the accumulated totals.

        Returns:
            Dict with 'iou', 'dice', 'precision', 'recall'.
        """
        eps = 1e-8  # avoids division by zero on empty accumulators
        return {
            'iou': self.intersection / (self.union + eps),
            'dice': (2 * self.intersection) / (self.pred_sum + self.target_sum + eps),
            'precision': self.intersection / (self.pred_sum + eps),
            'recall': self.intersection / (self.target_sum + eps),
        }
| 87 |
+
|
| 88 |
+
class ClassificationMetrics:
    """Accumulates forgery-type predictions and scores them at the end."""

    def __init__(self, num_classes: int = 3):
        """
        Args:
            num_classes: Number of forgery types.
        """
        self.num_classes = num_classes
        self.reset()

    def reset(self):
        """Drop all accumulated predictions."""
        self.predictions = []
        self.targets = []
        self.confidences = []

    def update(self,
               pred: np.ndarray,
               target: np.ndarray,
               confidence: Optional[np.ndarray] = None):
        """Append a batch of predictions.

        Args:
            pred: Predicted class indices.
            target: Ground-truth class indices.
            confidence: Optional per-prediction confidences.
        """
        self.predictions += list(pred.tolist())
        self.targets += list(target.tolist())
        if confidence is not None:
            self.confidences += list(confidence.tolist())

    def compute(self) -> Dict[str, float]:
        """Compute final classification metrics.

        Returns:
            Dict with accuracy, macro/weighted F1, macro precision/recall,
            and (when data exists) the confusion matrix as nested lists.
        """
        if not self.predictions:
            # Nothing recorded yet: report zeros rather than failing.
            return {
                'accuracy': 0.0,
                'f1_macro': 0.0,
                'f1_weighted': 0.0,
                'precision': 0.0,
                'recall': 0.0
            }

        y_pred = np.array(self.predictions)
        y_true = np.array(self.targets)

        return {
            'accuracy': accuracy_score(y_true, y_pred),
            'f1_macro': f1_score(y_true, y_pred, average='macro', zero_division=0),
            'f1_weighted': f1_score(y_true, y_pred, average='weighted', zero_division=0),
            'precision': precision_score(y_true, y_pred, average='macro', zero_division=0),
            'recall': recall_score(y_true, y_pred, average='macro', zero_division=0),
            'confusion_matrix': confusion_matrix(
                y_true, y_pred, labels=range(self.num_classes)
            ).tolist(),
        }
| 166 |
+
|
| 167 |
+
class MetricsTracker:
    """Track all metrics during training.

    Wraps a SegmentationMetrics and a ClassificationMetrics instance and
    keeps a per-epoch history of losses and segmentation metrics, split by
    phase ('train' / 'val').
    """

    # Segmentation metric keys mirrored per-phase into self.history.
    _TRACKED_KEYS = ('iou', 'dice', 'precision', 'recall')

    def __init__(self, config):
        """
        Initialize metrics tracker.

        Args:
            config: Configuration object exposing .get(key, default) and
                .should_compute_localization_metrics(dataset_name)
        """
        self.config = config
        self.num_classes = config.get('data.num_classes', 3)

        self.seg_metrics = SegmentationMetrics()
        self.cls_metrics = ClassificationMetrics(self.num_classes)

        # Epoch-indexed history; one list per phase/metric combination.
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'train_iou': [],
            'val_iou': [],
            'train_dice': [],
            'val_dice': [],
            'train_precision': [],
            'val_precision': [],
            'train_recall': [],
            'val_recall': []
        }

    def reset(self):
        """Reset accumulated metrics for a new epoch (history is kept)."""
        self.seg_metrics.reset()
        self.cls_metrics.reset()

    def update_segmentation(self,
                            pred: torch.Tensor,
                            target: torch.Tensor,
                            dataset_name: str):
        """Update segmentation metrics (dataset-aware).

        Asks the config whether this dataset has pixel-level ground truth,
        so localization metrics are only computed where they are meaningful.
        """
        has_pixel_mask = self.config.should_compute_localization_metrics(dataset_name)
        self.seg_metrics.update(pred, target, has_pixel_mask)

    def update_classification(self,
                              pred: np.ndarray,
                              target: np.ndarray,
                              confidence: Optional[np.ndarray] = None):
        """Update classification metrics with a batch of predictions."""
        self.cls_metrics.update(pred, target, confidence)

    def compute_all(self) -> Dict[str, float]:
        """Compute all metrics accumulated since the last reset().

        Classification metrics are included only when data was collected,
        prefixed with 'cls_' to avoid key collisions with segmentation keys.
        """
        seg = self.seg_metrics.compute()

        if len(self.cls_metrics.predictions) > 0:
            cls = self.cls_metrics.compute()
            cls_prefixed = {f'cls_{k}': v for k, v in cls.items()}
            return {**seg, **cls_prefixed}

        return seg

    def log_epoch(self, epoch: int, phase: str, loss: float, metrics: Dict):
        """Append this epoch's loss and metrics to the history.

        Args:
            epoch: Epoch index (not stored; history is positional)
            phase: 'train' or 'val'
            loss: Average loss for the epoch
            metrics: Metric dict from compute_all(); missing keys are skipped
        """
        # FIX: removed unused local `prefix`; replaced the repetitive
        # if-chain with a loop over the tracked metric keys.
        self.history[f'{phase}_loss'].append(loss)

        for key in self._TRACKED_KEYS:
            if key in metrics:
                self.history[f'{phase}_{key}'].append(metrics[key])

    def get_history(self) -> Dict:
        """Get the full training history dict (live reference, not a copy)."""
        return self.history
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class EarlyStopping:
    """Early stopping to prevent overfitting.

    Call the instance once per epoch with the monitored value; it returns
    True once `patience` consecutive epochs pass without an improvement of
    more than `min_delta`.
    """

    def __init__(self,
                 patience: int = 10,
                 min_delta: float = 0.001,
                 mode: str = 'max'):
        """
        Initialize early stopping.

        Args:
            patience: Number of epochs to wait without improvement
            min_delta: Minimum improvement required to reset the counter
            mode: 'min' for loss-like values, 'max' for score-like metrics
        """
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode

        self.counter = 0
        self.best_value = None
        self.should_stop = False

    def __call__(self, value: float) -> bool:
        """
        Check if training should stop given the newest metric value.

        Args:
            value: Current metric value

        Returns:
            True if training should stop (sticky once triggered)
        """
        # First observation only seeds the baseline; never stop here.
        if self.best_value is None:
            self.best_value = value
            return False

        # Improvement must exceed min_delta in the monitored direction.
        if self.mode == 'max':
            improved = (value - self.best_value) > self.min_delta
        else:
            improved = (self.best_value - value) > self.min_delta

        if improved:
            self.best_value = value
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True

        return self.should_stop
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def get_metrics_tracker(config) -> MetricsTracker:
    """Factory: build a MetricsTracker from the given configuration object."""
    tracker = MetricsTracker(config)
    return tracker
|
src/training/trainer.py
ADDED
|
@@ -0,0 +1,450 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training loop for forgery localization network
|
| 3 |
+
Implements chunked training for RAM constraints
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.optim as optim
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 12 |
+
from typing import Dict, Optional, Tuple
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
import json
|
| 16 |
+
import csv
|
| 17 |
+
|
| 18 |
+
from ..models import get_model, get_loss_function
|
| 19 |
+
from ..data import get_dataset
|
| 20 |
+
from .metrics import MetricsTracker, EarlyStopping
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class Trainer:
    """
    Trainer for forgery localization network.

    Supports chunked training for large datasets (DocTamper): a chunk is a
    [chunk_start, chunk_end) ratio slice of the training set, so very large
    datasets can be trained piecewise under RAM constraints.
    """

    def __init__(self, config, dataset_name: str = 'doctamper'):
        """
        Initialize trainer.

        Args:
            config: Configuration object (dotted-key .get access)
            dataset_name: Dataset to train on ('doctamper', 'fcd', 'scd', ...)
        """
        self.config = config
        self.dataset_name = dataset_name

        # Device setup: use CUDA only if available AND requested by config.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() and config.get('system.device') == 'cuda'
            else 'cpu'
        )
        print(f"Training on: {self.device}")

        # Initialize model
        self.model = get_model(config).to(self.device)

        # Loss function (dataset-aware)
        self.criterion = get_loss_function(config)

        # Optimizer
        lr = config.get('training.learning_rate', 0.001)
        weight_decay = config.get('training.weight_decay', 0.0001)
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=lr,
            weight_decay=weight_decay
        )

        # Learning rate scheduler.
        # NOTE(review): warmup_epochs only shortens the cosine restart period
        # here; no actual LR warmup phase is implemented.
        epochs = config.get('training.epochs', 50)
        warmup_epochs = config.get('training.scheduler.warmup_epochs', 5)
        min_lr = config.get('training.scheduler.min_lr', 1e-5)

        self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer,
            T_0=epochs - warmup_epochs,
            T_mult=1,
            eta_min=min_lr
        )

        # Mixed precision training.
        # FIX: enable AMP only on CUDA -- GradScaler.scale() requires the
        # loss tensor to live on GPU, so an always-enabled scaler crashes
        # CPU training runs.
        self.use_amp = self.device.type == 'cuda'
        self.scaler = GradScaler(enabled=self.use_amp)

        # Metrics
        self.metrics_tracker = MetricsTracker(config)

        # Early stopping (monitors val Dice in train(); mode defaults to 'max')
        patience = config.get('training.early_stopping.patience', 10)
        min_delta = config.get('training.early_stopping.min_delta', 0.001)
        self.early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)

        # Output directories
        self.checkpoint_dir = Path(config.get('outputs.checkpoints', 'outputs/checkpoints'))
        self.log_dir = Path(config.get('outputs.logs', 'outputs/logs'))
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        # Training state
        self.current_epoch = 0
        self.best_metric = 0.0

    def create_dataloaders(self,
                           chunk_start: float = 0.0,
                           chunk_end: float = 1.0) -> Tuple[DataLoader, DataLoader]:
        """
        Create train and validation dataloaders.

        Args:
            chunk_start: Start ratio for chunked training (DocTamper only)
            chunk_end: End ratio for chunked training (DocTamper only)

        Returns:
            Train and validation dataloaders
        """
        batch_size = self.config.get('data.batch_size', 8)
        num_workers = self.config.get('system.num_workers', 4)

        # Training dataset (with chunking for DocTamper)
        if self.dataset_name == 'doctamper':
            train_dataset = get_dataset(
                self.config,
                self.dataset_name,
                split='train',
                chunk_start=chunk_start,
                chunk_end=chunk_end
            )
        else:
            train_dataset = get_dataset(
                self.config,
                self.dataset_name,
                split='train'
            )

        # Validation dataset (always full)
        # For FCD and SCD, validate on DocTamper TestingSet
        if self.dataset_name in ['fcd', 'scd']:
            val_dataset = get_dataset(
                self.config,
                'doctamper',  # Use DocTamper for validation
                split='val'
            )
        else:
            val_dataset = get_dataset(
                self.config,
                self.dataset_name,
                split='val' if self.dataset_name in ['doctamper', 'receipts'] else 'test'
            )

        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=self.config.get('system.pin_memory', True),
            drop_last=True  # keep batch stats stable; avoids size-1 batches
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True
        )

        return train_loader, val_loader

    def train_epoch(self, dataloader: DataLoader) -> Tuple[float, Dict]:
        """
        Train for one epoch.

        Args:
            dataloader: Training dataloader yielding (images, masks, metadata)

        Returns:
            Average loss and metrics dict
        """
        self.model.train()
        self.metrics_tracker.reset()

        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(dataloader, desc=f"Epoch {self.current_epoch} [Train]")

        for batch_idx, (images, masks, metadata) in enumerate(pbar):
            images = images.to(self.device)
            masks = masks.to(self.device)

            # Forward pass with mixed precision (disabled on CPU -- see __init__)
            self.optimizer.zero_grad()

            with autocast(enabled=self.use_amp):
                outputs, _ = self.model(images)

                # Dataset-aware loss: datasets without pixel masks get a
                # different loss composition inside combined_loss.
                has_pixel_mask = self.config.has_pixel_mask(self.dataset_name)
                losses = self.criterion.combined_loss(outputs, masks, has_pixel_mask)

            # Backward pass with gradient scaling (no-op scaler when AMP off)
            self.scaler.scale(losses['total']).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            # Update metrics
            with torch.no_grad():
                probs = torch.sigmoid(outputs)
                self.metrics_tracker.update_segmentation(
                    probs, masks, self.dataset_name
                )

            total_loss += losses['total'].item()
            num_batches += 1

            # Update progress bar
            pbar.set_postfix({
                'loss': f"{losses['total'].item():.4f}",
                'bce': f"{losses['bce'].item():.4f}"
            })

        avg_loss = total_loss / num_batches
        metrics = self.metrics_tracker.compute_all()

        return avg_loss, metrics

    def validate(self, dataloader: DataLoader) -> Tuple[float, Dict]:
        """
        Validate model.

        Args:
            dataloader: Validation dataloader

        Returns:
            Average loss and metrics dict
        """
        self.model.eval()
        self.metrics_tracker.reset()

        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(dataloader, desc=f"Epoch {self.current_epoch} [Val]")

        with torch.no_grad():
            for images, masks, metadata in pbar:
                images = images.to(self.device)
                masks = masks.to(self.device)

                # Forward pass
                outputs, _ = self.model(images)

                # Dataset-aware loss
                has_pixel_mask = self.config.has_pixel_mask(self.dataset_name)
                losses = self.criterion.combined_loss(outputs, masks, has_pixel_mask)

                # Update metrics
                probs = torch.sigmoid(outputs)
                self.metrics_tracker.update_segmentation(
                    probs, masks, self.dataset_name
                )

                total_loss += losses['total'].item()
                num_batches += 1

                pbar.set_postfix({
                    'loss': f"{losses['total'].item():.4f}"
                })

        avg_loss = total_loss / num_batches
        metrics = self.metrics_tracker.compute_all()

        return avg_loss, metrics

    def save_checkpoint(self,
                        filename: str,
                        is_best: bool = False,
                        chunk_id: Optional[int] = None):
        """Save model checkpoint (and a 'best_<dataset>.pth' copy if is_best)."""
        checkpoint = {
            'epoch': self.current_epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_metric': self.best_metric,
            'dataset': self.dataset_name,
            'chunk_id': chunk_id
        }

        path = self.checkpoint_dir / filename
        torch.save(checkpoint, path)
        print(f"Saved checkpoint: {path}")

        if is_best:
            best_path = self.checkpoint_dir / f'best_{self.dataset_name}.pth'
            torch.save(checkpoint, best_path)
            print(f"Saved best model: {best_path}")

    def load_checkpoint(self, filename: str, reset_epoch: bool = False):
        """
        Load model checkpoint.

        Args:
            filename: Checkpoint filename (relative to checkpoint_dir)
            reset_epoch: If True, reset epoch counter to 0 (useful for chunked training)

        Returns:
            True on success, False when the checkpoint does not exist
        """
        path = self.checkpoint_dir / filename

        if not path.exists():
            print(f"Checkpoint not found: {path}")
            return False

        checkpoint = torch.load(path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        if reset_epoch:
            self.current_epoch = 0
            print(f"Loaded checkpoint: {path} (epoch counter reset to 0)")
        else:
            self.current_epoch = checkpoint['epoch'] + 1  # Continue from next epoch
            print(f"Loaded checkpoint: {path} (resuming from epoch {self.current_epoch})")

        self.best_metric = checkpoint.get('best_metric', 0.0)

        return True

    def train(self,
              epochs: Optional[int] = None,
              chunk_start: float = 0.0,
              chunk_end: float = 1.0,
              chunk_id: Optional[int] = None,
              resume_from: Optional[str] = None):
        """
        Main training loop.

        Args:
            epochs: Number of epochs (None uses config)
            chunk_start: Start ratio for chunked training
            chunk_end: End ratio for chunked training
            chunk_id: Chunk identifier for logging/checkpoint names
            resume_from: Checkpoint to resume from

        Returns:
            Training history dict from the metrics tracker
        """
        if epochs is None:
            epochs = self.config.get('training.epochs', 50)

        # Resume if specified
        if resume_from:
            self.load_checkpoint(resume_from)

        # Create dataloaders
        train_loader, val_loader = self.create_dataloaders(chunk_start, chunk_end)

        print(f"\n{'='*60}")
        print(f"Training: {self.dataset_name}")
        if chunk_id is not None:
            print(f"Chunk: {chunk_id} [{chunk_start*100:.0f}% - {chunk_end*100:.0f}%]")
        print(f"Epochs: {epochs}")
        print(f"Train samples: {len(train_loader.dataset)}")
        print(f"Val samples: {len(val_loader.dataset)}")
        print(f"{'='*60}\n")

        # Training log file (header written fresh; epochs appended below)
        log_file = self.log_dir / f'{self.dataset_name}_chunk{chunk_id or 0}_log.csv'
        with open(log_file, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['epoch', 'train_loss', 'val_loss',
                             'train_iou', 'val_iou', 'train_dice', 'val_dice',
                             'train_precision', 'val_precision',
                             'train_recall', 'val_recall', 'lr'])

        for epoch in range(self.current_epoch, epochs):
            self.current_epoch = epoch

            # Train
            train_loss, train_metrics = self.train_epoch(train_loader)

            # Validate
            val_loss, val_metrics = self.validate(val_loader)

            # Update scheduler
            self.scheduler.step()
            current_lr = self.optimizer.param_groups[0]['lr']

            # Log metrics
            self.metrics_tracker.log_epoch(epoch, 'train', train_loss, train_metrics)
            self.metrics_tracker.log_epoch(epoch, 'val', val_loss, val_metrics)

            # Log to file
            with open(log_file, 'a', newline='') as f:
                writer = csv.writer(f)
                writer.writerow([
                    epoch,
                    f"{train_loss:.4f}",
                    f"{val_loss:.4f}",
                    f"{train_metrics.get('iou', 0):.4f}",
                    f"{val_metrics.get('iou', 0):.4f}",
                    f"{train_metrics.get('dice', 0):.4f}",
                    f"{val_metrics.get('dice', 0):.4f}",
                    f"{train_metrics.get('precision', 0):.4f}",
                    f"{val_metrics.get('precision', 0):.4f}",
                    f"{train_metrics.get('recall', 0):.4f}",
                    f"{val_metrics.get('recall', 0):.4f}",
                    f"{current_lr:.6f}"
                ])

            # Print summary
            print(f"\nEpoch {epoch}/{epochs-1}")
            print(f"  Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
            print(f"  Train IoU: {train_metrics.get('iou', 0):.4f} | Val IoU: {val_metrics.get('iou', 0):.4f}")
            print(f"  Train Dice: {train_metrics.get('dice', 0):.4f} | Val Dice: {val_metrics.get('dice', 0):.4f}")
            print(f"  LR: {current_lr:.6f}")

            # Save periodic checkpoints
            if self.config.get('training.checkpoint.save_every', 5) > 0:
                if (epoch + 1) % self.config.get('training.checkpoint.save_every', 5) == 0:
                    self.save_checkpoint(
                        f'{self.dataset_name}_chunk{chunk_id or 0}_epoch{epoch}.pth',
                        chunk_id=chunk_id
                    )

            # Check for best model (monitors validation Dice)
            monitor_metric = val_metrics.get('dice', 0)
            if monitor_metric > self.best_metric:
                self.best_metric = monitor_metric
                self.save_checkpoint(
                    f'{self.dataset_name}_chunk{chunk_id or 0}_best.pth',
                    is_best=True,
                    chunk_id=chunk_id
                )

            # Early stopping
            if self.early_stopping(monitor_metric):
                print(f"\nEarly stopping triggered at epoch {epoch}")
                break

        # Save final checkpoint
        self.save_checkpoint(
            f'{self.dataset_name}_chunk{chunk_id or 0}_final.pth',
            chunk_id=chunk_id
        )

        # Save training history
        history_file = self.log_dir / f'{self.dataset_name}_chunk{chunk_id or 0}_history.json'
        with open(history_file, 'w') as f:
            json.dump(self.metrics_tracker.get_history(), f, indent=2)

        print(f"\nTraining complete!")
        print(f"Best Dice: {self.best_metric:.4f}")

        return self.metrics_tracker.get_history()
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def get_trainer(config, dataset_name: str = 'doctamper') -> Trainer:
    """Factory: build a Trainer for the given config and dataset name."""
    trainer = Trainer(config, dataset_name)
    return trainer
|
src/utils/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utilities module"""
|
| 2 |
+
|
| 3 |
+
from .plotting import (
|
| 4 |
+
plot_training_curves,
|
| 5 |
+
plot_confusion_matrix,
|
| 6 |
+
plot_feature_importance,
|
| 7 |
+
plot_dataset_comparison,
|
| 8 |
+
plot_chunked_training_progress,
|
| 9 |
+
generate_training_report
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
from .export import (
|
| 13 |
+
export_to_onnx,
|
| 14 |
+
export_to_torchscript,
|
| 15 |
+
quantize_model
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
'plot_training_curves',
|
| 20 |
+
'plot_confusion_matrix',
|
| 21 |
+
'plot_feature_importance',
|
| 22 |
+
'plot_dataset_comparison',
|
| 23 |
+
'plot_chunked_training_progress',
|
| 24 |
+
'generate_training_report',
|
| 25 |
+
'export_to_onnx',
|
| 26 |
+
'export_to_torchscript',
|
| 27 |
+
'quantize_model'
|
| 28 |
+
]
|
src/utils/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (601 Bytes). View file
|
|
|