JKrishnanandhaa committed on
Commit
547247c
·
verified ·
1 Parent(s): 4316937

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +231 -231
app.py CHANGED
@@ -1,231 +1,231 @@
1
- """
2
- Document Forgery Detection - Gradio Interface for Hugging Face Spaces
3
-
4
- This app provides a web interface for detecting and classifying document forgeries.
5
- """
6
-
7
- import gradio as gr
8
- import torch
9
- import cv2
10
- import numpy as np
11
- from PIL import Image
12
- import json
13
- from pathlib import Path
14
- import sys
15
-
16
- # Add src to path
17
- sys.path.insert(0, str(Path(__file__).parent))
18
-
19
- from src.models import get_model
20
- from src.config import get_config
21
- from src.data.preprocessing import DocumentPreprocessor
22
- from src.data.augmentation import DatasetAwareAugmentation
23
- from src.features.region_extraction import get_mask_refiner, get_region_extractor
24
- from src.features.feature_extraction import get_feature_extractor
25
- from src.training.classifier import ForgeryClassifier
26
-
27
- # Class names
28
- CLASS_NAMES = {0: 'Copy-Move', 1: 'Splicing', 2: 'Generation'}
29
- CLASS_COLORS = {
30
- 0: (255, 0, 0), # Red for Copy-Move
31
- 1: (0, 255, 0), # Green for Splicing
32
- 2: (0, 0, 255) # Blue for Generation
33
- }
34
-
35
-
36
- class ForgeryDetector:
37
- """Main forgery detection pipeline"""
38
-
39
- def __init__(self):
40
- print("Loading models...")
41
-
42
- # Load config
43
- self.config = get_config('config.yaml')
44
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
45
-
46
- # Load segmentation model
47
- self.model = get_model(self.config).to(self.device)
48
- checkpoint = torch.load('models/segmentation_model.pth', map_location=self.device)
49
- self.model.load_state_dict(checkpoint['model_state_dict'])
50
- self.model.eval()
51
-
52
- # Load classifier
53
- self.classifier = ForgeryClassifier(self.config)
54
- self.classifier.load('models/classifier')
55
-
56
- # Initialize components
57
- self.preprocessor = DocumentPreprocessor(self.config, 'doctamper')
58
- self.augmentation = DatasetAwareAugmentation(self.config, 'doctamper', is_training=False)
59
- self.mask_refiner = get_mask_refiner(self.config)
60
- self.region_extractor = get_region_extractor(self.config)
61
- self.feature_extractor = get_feature_extractor(self.config, is_text_document=True)
62
-
63
- print("✓ Models loaded successfully!")
64
-
65
- def detect(self, image):
66
- """
67
- Detect forgeries in document image
68
-
69
- Args:
70
- image: PIL Image or numpy array
71
-
72
- Returns:
73
- overlay_image: Image with detection overlay
74
- results_json: Detection results as JSON
75
- """
76
- # Convert PIL to numpy
77
- if isinstance(image, Image.Image):
78
- image = np.array(image)
79
-
80
- # Convert to RGB
81
- if len(image.shape) == 2:
82
- image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
83
- elif image.shape[2] == 4:
84
- image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
85
-
86
- original_image = image.copy()
87
-
88
- # Preprocess
89
- preprocessed, _ = self.preprocessor(image, None)
90
-
91
- # Augment
92
- augmented = self.augmentation(preprocessed, None)
93
- image_tensor = augmented['image'].unsqueeze(0).to(self.device)
94
-
95
- # Run localization
96
- with torch.no_grad():
97
- logits, decoder_features = self.model(image_tensor)
98
- prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]
99
-
100
- # Refine mask
101
- binary_mask = (prob_map > 0.5).astype(np.uint8)
102
- refined_mask = self.mask_refiner.refine(binary_mask, original_size=original_image.shape[:2])
103
-
104
- # Extract regions
105
- regions = self.region_extractor.extract(refined_mask, prob_map, original_image)
106
-
107
- # Classify regions
108
- results = []
109
- for region in regions:
110
- # Extract features
111
- features = self.feature_extractor.extract(
112
- preprocessed,
113
- region['region_mask'],
114
- [f.cpu() for f in decoder_features]
115
- )
116
-
117
- # Classify
118
- predictions, confidences = self.classifier.predict(features)
119
- forgery_type = int(predictions[0])
120
- confidence = float(confidences[0])
121
-
122
- if confidence > 0.6: # Confidence threshold
123
- results.append({
124
- 'region_id': region['region_id'],
125
- 'bounding_box': region['bounding_box'],
126
- 'forgery_type': CLASS_NAMES[forgery_type],
127
- 'confidence': confidence
128
- })
129
-
130
- # Create visualization
131
- overlay = self._create_overlay(original_image, results)
132
-
133
- # Create JSON response
134
- json_results = {
135
- 'num_detections': len(results),
136
- 'detections': results,
137
- 'model_info': {
138
- 'segmentation_dice': '75%',
139
- 'classifier_accuracy': '92%'
140
- }
141
- }
142
-
143
- return overlay, json_results
144
-
145
- def _create_overlay(self, image, results):
146
- """Create overlay visualization"""
147
- overlay = image.copy()
148
-
149
- # Draw bounding boxes and labels
150
- for result in results:
151
- bbox = result['bounding_box']
152
- x, y, w, h = bbox
153
-
154
- forgery_type = result['forgery_type']
155
- confidence = result['confidence']
156
-
157
- # Get color
158
- forgery_id = [k for k, v in CLASS_NAMES.items() if v == forgery_type][0]
159
- color = CLASS_COLORS[forgery_id]
160
-
161
- # Draw rectangle
162
- cv2.rectangle(overlay, (x, y), (x+w, y+h), color, 2)
163
-
164
- # Draw label
165
- label = f"{forgery_type}: {confidence:.1%}"
166
- label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
167
- cv2.rectangle(overlay, (x, y-label_size[1]-10), (x+label_size[0], y), color, -1)
168
- cv2.putText(overlay, label, (x, y-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
169
-
170
- # Add legend
171
- if len(results) > 0:
172
- legend_y = 30
173
- cv2.putText(overlay, f"Detected {len(results)} forgery region(s)",
174
- (10, legend_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
175
-
176
- return overlay
177
-
178
-
179
- # Initialize detector
180
- detector = ForgeryDetector()
181
-
182
-
183
- def detect_forgery(image):
184
- """Gradio interface function"""
185
- try:
186
- overlay, results = detector.detect(image)
187
- return overlay, json.dumps(results, indent=2)
188
- except Exception as e:
189
- return None, f"Error: {str(e)}"
190
-
191
-
192
- # Create Gradio interface
193
- demo = gr.Interface(
194
- fn=detect_forgery,
195
- inputs=gr.Image(type="pil", label="Upload Document Image"),
196
- outputs=[
197
- gr.Image(type="numpy", label="Detection Result"),
198
- gr.JSON(label="Detection Details")
199
- ],
200
- title="📄 Document Forgery Detector",
201
- description="""
202
- Upload a document image to detect and classify forgeries.
203
-
204
- **Supported Forgery Types:**
205
- - 🔴 Copy-Move: Duplicated regions within the document
206
- - 🟢 Splicing: Content from different sources
207
- - 🔵 Generation: AI-generated or synthesized content
208
-
209
- **Model Performance:**
210
- - Localization: 75% Dice Score
211
- - Classification: 92% Accuracy
212
- """,
213
- examples=[
214
- ["examples/sample1.jpg"],
215
- ["examples/sample2.jpg"],
216
- ],
217
- article="""
218
- ### About
219
- This model uses a hybrid deep learning approach:
220
- 1. **Localization**: MobileNetV3-Small + UNet-Lite (detects WHERE)
221
- 2. **Classification**: LightGBM with hybrid features (detects WHAT)
222
-
223
- Trained on DocTamper dataset (140K samples).
224
- """,
225
- theme=gr.themes.Soft(),
226
- allow_flagging="never"
227
- )
228
-
229
-
230
- if __name__ == "__main__":
231
- demo.launch()
 
1
+ """
2
+ Document Forgery Detection - Gradio Interface for Hugging Face Spaces
3
+
4
+ This app provides a web interface for detecting and classifying document forgeries.
5
+ """
6
+
7
+ import gradio as gr
8
+ import torch
9
+ import cv2
10
+ import numpy as np
11
+ from PIL import Image
12
+ import json
13
+ from pathlib import Path
14
+ import sys
15
+
16
+ # Add src to path
17
+ sys.path.insert(0, str(Path(__file__).parent))
18
+
19
+ from src.models import get_model
20
+ from src.config import get_config
21
+ from src.data.preprocessing import DocumentPreprocessor
22
+ from src.data.augmentation import DatasetAwareAugmentation
23
+ from src.features.region_extraction import get_mask_refiner, get_region_extractor
24
+ from src.features.feature_extraction import get_feature_extractor
25
+ from src.training.classifier import ForgeryClassifier
26
+
27
# Mapping from classifier output index to a human-readable forgery label.
CLASS_NAMES = {0: 'Copy-Move', 1: 'Splicing', 2: 'Generation'}

# Overlay colour per class id, used when drawing detection boxes.
CLASS_COLORS = {
    0: (255, 0, 0),   # Copy-Move  -> red
    1: (0, 255, 0),   # Splicing   -> green
    2: (0, 0, 255)    # Generation -> blue
}
34
+
35
+
36
class ForgeryDetector:
    """End-to-end document forgery detection pipeline.

    Combines a segmentation model (localises tampered pixels) with a
    region-level classifier (labels each region as Copy-Move / Splicing /
    Generation), plus the pre/post-processing components both stages need.
    """

    # Probability above which a pixel in the localization map is treated as tampered.
    MASK_THRESHOLD = 0.5
    # Minimum classifier confidence for a region to be included in the results.
    CONFIDENCE_THRESHOLD = 0.6

    def __init__(self):
        print("Loading models...")

        # Load config and pick the compute device.
        self.config = get_config('config.yaml')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load segmentation model weights.
        # NOTE(review): torch.load unpickles the checkpoint; the file ships with
        # this Space so it is trusted input — do not point it at user uploads.
        self.model = get_model(self.config).to(self.device)
        checkpoint = torch.load('models/best_doctamper.pth', map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # Region-level forgery-type classifier.
        self.classifier = ForgeryClassifier(self.config)
        self.classifier.load('models/classifier')

        # Pre/post-processing components (inference-mode augmentation).
        self.preprocessor = DocumentPreprocessor(self.config, 'doctamper')
        self.augmentation = DatasetAwareAugmentation(self.config, 'doctamper', is_training=False)
        self.mask_refiner = get_mask_refiner(self.config)
        self.region_extractor = get_region_extractor(self.config)
        self.feature_extractor = get_feature_extractor(self.config, is_text_document=True)

        print("✓ Models loaded successfully!")

    def detect(self, image):
        """Detect forgeries in a document image.

        Args:
            image: PIL Image or numpy array (grayscale, RGB, or RGBA).

        Returns:
            overlay: numpy image with boxes/labels drawn on the original.
            json_results: dict with 'num_detections', 'detections' and
                'model_info' keys.
        """
        # Normalise input to an RGB numpy array.
        if isinstance(image, Image.Image):
            image = np.array(image)
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

        original_image = image.copy()

        # Preprocess, then apply the deterministic inference-time transform.
        preprocessed, _ = self.preprocessor(image, None)
        augmented = self.augmentation(preprocessed, None)
        image_tensor = augmented['image'].unsqueeze(0).to(self.device)

        # Localization: per-pixel tampering probability map.
        with torch.no_grad():
            logits, decoder_features = self.model(image_tensor)
            prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]

        # Threshold the probability map and refine back to the original size.
        binary_mask = (prob_map > self.MASK_THRESHOLD).astype(np.uint8)
        refined_mask = self.mask_refiner.refine(binary_mask, original_size=original_image.shape[:2])

        # Extract candidate tampered regions from the refined mask.
        regions = self.region_extractor.extract(refined_mask, prob_map, original_image)

        # Classify each region; keep only confident detections.
        results = []
        for region in regions:
            features = self.feature_extractor.extract(
                preprocessed,
                region['region_mask'],
                [f.cpu() for f in decoder_features]
            )
            predictions, confidences = self.classifier.predict(features)
            forgery_type = int(predictions[0])
            confidence = float(confidences[0])

            if confidence > self.CONFIDENCE_THRESHOLD:
                results.append({
                    'region_id': region['region_id'],
                    'bounding_box': region['bounding_box'],
                    'forgery_type': CLASS_NAMES[forgery_type],
                    'confidence': confidence
                })

        overlay = self._create_overlay(original_image, results)

        json_results = {
            'num_detections': len(results),
            'detections': results,
            'model_info': {
                'segmentation_dice': '75%',
                'classifier_accuracy': '92%'
            }
        }

        return overlay, json_results

    def _create_overlay(self, image, results):
        """Draw bounding boxes, labels and a summary line onto a copy of *image*."""
        overlay = image.copy()

        for result in results:
            x, y, w, h = result['bounding_box']
            forgery_type = result['forgery_type']
            confidence = result['confidence']

            # Reverse-map label -> class id to pick the overlay colour.
            forgery_id = next(k for k, v in CLASS_NAMES.items() if v == forgery_type)
            color = CLASS_COLORS[forgery_id]

            cv2.rectangle(overlay, (x, y), (x + w, y + h), color, 2)

            # Label with confidence. Clamp the label box so it stays on-canvas
            # when a region touches the top edge of the image (y near 0 would
            # otherwise produce negative coordinates).
            label = f"{forgery_type}: {confidence:.1%}"
            label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
            box_top = max(y - label_size[1] - 10, 0)
            cv2.rectangle(overlay, (x, box_top), (x + label_size[0], y), color, -1)
            cv2.putText(overlay, label, (x, max(y - 5, label_size[1])),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

        # Summary legend in the top-left corner.
        if len(results) > 0:
            cv2.putText(overlay, f"Detected {len(results)} forgery region(s)",
                        (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)

        return overlay
177
+
178
+
179
# Single module-level detector instance, loaded once at Space startup.
detector = ForgeryDetector()


def detect_forgery(image):
    """Gradio entry point: run detection and return (overlay, details).

    Args:
        image: PIL image uploaded by the user, or None when nothing was uploaded.

    Returns:
        Tuple of (overlay image or None, JSON string for the gr.JSON output).
    """
    # Guard: the Image input sends None when the user submits without an upload.
    if image is None:
        return None, json.dumps({'error': 'No image provided'}, indent=2)
    try:
        overlay, results = detector.detect(image)
        return overlay, json.dumps(results, indent=2)
    except Exception as e:
        # Surface the failure as valid JSON so the gr.JSON component can render
        # it (a bare "Error: ..." string is not parseable JSON).
        return None, json.dumps({'error': str(e)}, indent=2)
190
+
191
+
192
# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------

_DESCRIPTION = """
Upload a document image to detect and classify forgeries.

**Supported Forgery Types:**
- 🔴 Copy-Move: Duplicated regions within the document
- 🟢 Splicing: Content from different sources
- 🔵 Generation: AI-generated or synthesized content

**Model Performance:**
- Localization: 75% Dice Score
- Classification: 92% Accuracy
"""

_ARTICLE = """
### About
This model uses a hybrid deep learning approach:
1. **Localization**: MobileNetV3-Small + UNet-Lite (detects WHERE)
2. **Classification**: LightGBM with hybrid features (detects WHAT)

Trained on DocTamper dataset (140K samples).
"""

demo = gr.Interface(
    fn=detect_forgery,
    inputs=gr.Image(type="pil", label="Upload Document Image"),
    outputs=[
        gr.Image(type="numpy", label="Detection Result"),
        gr.JSON(label="Detection Details"),
    ],
    title="📄 Document Forgery Detector",
    description=_DESCRIPTION,
    article=_ARTICLE,
    examples=[
        ["examples/sample1.jpg"],
        ["examples/sample2.jpg"],
    ],
    theme=gr.themes.Soft(),
    # NOTE(review): allow_flagging is deprecated in Gradio 4.x in favour of
    # flagging_mode — confirm which gradio version this Space pins.
    allow_flagging="never",
)


if __name__ == "__main__":
    demo.launch()