JKrishnanandhaa committed on
Commit
1a69472
Β·
verified Β·
1 Parent(s): 70b84aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +636 -139
app.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
- Document Forgery Detection – Professional Gradio Dashboard
3
- Hugging Face Spaces Deployment
4
  """
5
 
6
  import gradio as gr
@@ -8,14 +8,14 @@ import torch
8
  import cv2
9
  import numpy as np
10
  from PIL import Image
11
- import plotly.graph_objects as go
12
  from pathlib import Path
13
  import sys
14
- import json
 
 
15
 
16
- # -------------------------------------------------
17
- # PATH SETUP
18
- # -------------------------------------------------
19
  sys.path.insert(0, str(Path(__file__).parent))
20
 
21
  from src.models import get_model
@@ -26,181 +26,678 @@ from src.features.region_extraction import get_mask_refiner, get_region_extracto
26
  from src.features.feature_extraction import get_feature_extractor
27
  from src.training.classifier import ForgeryClassifier
28
 
29
- # -------------------------------------------------
30
- # CONSTANTS
31
- # -------------------------------------------------
32
- CLASS_NAMES = {0: "Copy-Move", 1: "Splicing", 2: "Generation"}
 
 
 
 
 
 
 
 
 
 
 
 
33
  CLASS_COLORS = {
34
- 0: (255, 0, 0),
35
- 1: (0, 255, 0),
36
- 2: (0, 0, 255),
37
  }
38
 
39
- # -------------------------------------------------
40
- # FORGERY DETECTOR (UNCHANGED CORE LOGIC)
41
- # -------------------------------------------------
42
- class ForgeryDetector:
43
- def __init__(self):
44
- print("Loading models...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- self.config = get_config("config.yaml")
47
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  self.model = get_model(self.config).to(self.device)
50
- checkpoint = torch.load("models/best_doctamper.pth", map_location=self.device)
51
- self.model.load_state_dict(checkpoint["model_state_dict"])
52
  self.model.eval()
53
-
 
 
54
  self.classifier = ForgeryClassifier(self.config)
55
- self.classifier.load("models/classifier")
56
-
57
- self.preprocessor = DocumentPreprocessor(self.config, "doctamper")
58
- self.augmentation = DatasetAwareAugmentation(self.config, "doctamper", is_training=False)
 
59
  self.mask_refiner = get_mask_refiner(self.config)
60
  self.region_extractor = get_region_extractor(self.config)
61
  self.feature_extractor = get_feature_extractor(self.config, is_text_document=True)
62
-
63
- print("βœ“ Models loaded")
64
-
65
- def detect(self, image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  if isinstance(image, Image.Image):
67
  image = np.array(image)
68
-
69
- if image.ndim == 2:
 
70
  image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
71
  elif image.shape[2] == 4:
72
  image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
73
-
74
- original = image.copy()
75
-
 
76
  preprocessed, _ = self.preprocessor(image, None)
 
 
77
  augmented = self.augmentation(preprocessed, None)
78
- image_tensor = augmented["image"].unsqueeze(0).to(self.device)
79
-
 
80
  with torch.no_grad():
81
  logits, decoder_features = self.model(image_tensor)
82
  prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]
83
-
84
- binary = (prob_map > 0.5).astype(np.uint8)
85
- refined = self.mask_refiner.refine(binary, original_size=original.shape[:2])
86
- regions = self.region_extractor.extract(refined, prob_map, original)
87
-
 
 
 
 
88
  results = []
89
- for r in regions:
 
90
  features = self.feature_extractor.extract(
91
- preprocessed, r["region_mask"], [f.cpu() for f in decoder_features]
 
 
92
  )
93
-
 
94
  if features.ndim == 1:
95
  features = features.reshape(1, -1)
96
-
97
- if features.shape[1] != 526:
98
- pad = max(0, 526 - features.shape[1])
99
- features = np.pad(features, ((0, 0), (0, pad)))[:, :526]
100
-
101
- pred, conf = self.classifier.predict(features)
102
- if conf[0] > 0.6:
 
 
 
 
 
 
 
 
 
103
  results.append({
104
- "bounding_box": r["bounding_box"],
105
- "forgery_type": CLASS_NAMES[int(pred[0])],
106
- "confidence": float(conf[0]),
 
 
107
  })
108
-
109
- overlay = self._draw_overlay(original, results)
110
-
111
- return overlay, {
112
- "num_detections": len(results),
113
- "detections": results,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  }
115
-
116
- def _draw_overlay(self, image, results):
117
- out = image.copy()
118
- for r in results:
119
- x, y, w, h = r["bounding_box"]
120
- fid = [k for k, v in CLASS_NAMES.items() if v == r["forgery_type"]][0]
121
- color = CLASS_COLORS[fid]
122
-
123
- cv2.rectangle(out, (x, y), (x + w, y + h), color, 2)
124
- label = f"{r['forgery_type']} ({r['confidence']*100:.1f}%)"
125
- cv2.putText(out, label, (x, y - 6),
126
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
127
- return out
128
-
129
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  detector = ForgeryDetector()
131
 
132
- # -------------------------------------------------
133
- # METRIC VISUALS
134
- # -------------------------------------------------
135
- def gauge(value, title):
136
- fig = go.Figure(go.Indicator(
137
- mode="gauge+number",
138
- value=value,
139
- title={"text": title},
140
- gauge={"axis": {"range": [0, 100]}, "bar": {"color": "#2563eb"}}
141
- ))
142
- fig.update_layout(height=240, margin=dict(t=40, b=20))
143
- return fig
144
 
145
- # -------------------------------------------------
146
- # GRADIO CALLBACK
147
- # -------------------------------------------------
148
- def run_detection(file):
149
- image = Image.open(file.name)
150
- overlay, result = detector.detect(image)
151
-
152
- avg_conf = (
153
- sum(d["confidence"] for d in result["detections"]) / max(1, result["num_detections"])
154
- ) * 100
155
-
156
- return (
157
- overlay,
158
- result,
159
- gauge(75, "Localization Dice (%)"),
160
- gauge(92, "Classifier Accuracy (%)"),
161
- gauge(avg_conf, "Avg Detection Confidence (%)"),
162
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
- # -------------------------------------------------
165
- # UI
166
- # -------------------------------------------------
167
- with gr.Blocks(theme=gr.themes.Soft(), title="Document Forgery Detection") as demo:
168
 
169
- gr.Markdown("# πŸ“„ Document Forgery Detection System")
 
 
 
 
 
 
 
170
 
171
- with gr.Row():
172
- file_input = gr.File(label="Upload Document (Image/PDF)")
173
- detect_btn = gr.Button("Run Detection", variant="primary")
 
174
 
175
- output_img = gr.Image(label="Forgery Localization Result", type="numpy")
 
 
 
 
 
176
 
177
- with gr.Tabs():
178
- with gr.Tab("πŸ“Š Metrics"):
179
- with gr.Row():
180
- dice_plot = gr.Plot()
181
- acc_plot = gr.Plot()
182
- conf_plot = gr.Plot()
183
 
184
- with gr.Tab("🧾 Details"):
185
- json_out = gr.JSON()
 
 
 
186
 
187
- with gr.Tab("πŸ‘₯ Team"):
188
- gr.Markdown("""
189
- **Document Forgery Detection Project**
 
190
 
191
- - Krishnanandhaa β€” Model & Training
192
- - Teammate 1 β€” Feature Engineering
193
- - Teammate 2 β€” Evaluation
194
- - Teammate 3 β€” Deployment
 
 
 
195
 
196
- *Collaborators are added via Hugging Face Space settings.*
197
- """)
 
 
 
198
 
199
- detect_btn.click(
200
- run_detection,
201
- inputs=file_input,
202
- outputs=[output_img, json_out, dice_plot, acc_plot, conf_plot]
 
 
 
 
 
 
203
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
  if __name__ == "__main__":
206
  demo.launch()
 
1
  """
2
+ Document Forgery Detection - Professional Gradio Interface
3
+ Advanced AI-powered document forgery detection and classification system
4
  """
5
 
6
  import gradio as gr
 
8
  import cv2
9
  import numpy as np
10
  from PIL import Image
11
+ import json
12
  from pathlib import Path
13
  import sys
14
+ from typing import Dict, List, Tuple, Optional
15
+ import plotly.graph_objects as go
16
+ from datetime import datetime
17
 
18
+ # Add src to path
 
 
19
  sys.path.insert(0, str(Path(__file__).parent))
20
 
21
  from src.models import get_model
 
26
  from src.features.feature_extraction import get_feature_extractor
27
  from src.training.classifier import ForgeryClassifier
28
 
29
# ============================================================================
# CONFIGURATION & CONSTANTS
# ============================================================================

# Classifier output id -> human-readable forgery label.
CLASS_NAMES = {
    0: 'Copy-Move',
    1: 'Splicing',
    2: 'Text Substitution'
}

# One-line description per class id, surfaced in the JSON results and reports.
CLASS_DESCRIPTIONS = {
    0: 'Duplicated regions within the same document',
    1: 'Content from different sources combined',
    2: 'Artificially generated or modified text/content'
}

# Hex display color per class id, used for overlay boxes and plot traces.
CLASS_COLORS = {
    0: '#FF4444',  # Red for Copy-Move
    1: '#44FF44',  # Green for Splicing
    2: '#4444FF'   # Blue for Generation
}

# Actual model performance metrics from training (static — shown in the UI,
# not recomputed per request).
MODEL_METRICS = {
    'segmentation': {
        'dice': 0.6212,  # Best validation Dice from chunk 4, epoch 8
        'iou': 0.4506,
        'precision': 0.7077,
        'recall': 0.5536,
        'accuracy': 0.9261
    },
    'classification': {
        'overall_accuracy': 0.8897,  # From training_metrics.json
        'train_accuracy': 0.9053,
        'per_class': {
            'copy_move': 0.92,
            'splicing': 0.85,
            'generation': 0.90
        }
    }
}
70
+
71
# ============================================================================
# VISUALIZATION UTILITIES
# ============================================================================

def create_radial_gauge(value: float, title: str, color: str = '#4A90E2') -> go.Figure:
    """Render *value* (a 0–1 fraction) as a percentage on a styled radial gauge.

    Args:
        value: Metric as a fraction; displayed multiplied by 100.
        title: Gauge heading text.
        color: Hex color for the gauge bar and axis ticks.

    Returns:
        A transparent-background Plotly figure, 250px tall.
    """
    # Background bands: red-ish below 50%, amber 50–75%, green above 75%.
    band_steps = [
        {'range': [0, 50], 'color': '#FFE5E5'},
        {'range': [50, 75], 'color': '#FFF4E5'},
        {'range': [75, 100], 'color': '#E5F5E5'}
    ]
    # Red threshold line marks the 90% target.
    gauge_style = {
        'axis': {'range': [0, 100], 'tickwidth': 2, 'tickcolor': color},
        'bar': {'color': color, 'thickness': 0.75},
        'bgcolor': 'white',
        'borderwidth': 2,
        'bordercolor': '#E8E8E8',
        'steps': band_steps,
        'threshold': {
            'line': {'color': 'red', 'width': 4},
            'thickness': 0.75,
            'value': 90
        }
    }
    indicator = go.Indicator(
        mode="gauge+number+delta",
        value=value * 100,
        domain={'x': [0, 1], 'y': [0, 1]},
        title={'text': title, 'font': {'size': 16, 'color': '#2C3E50', 'family': 'Inter'}},
        number={'suffix': '%', 'font': {'size': 32, 'color': '#2C3E50'}},
        gauge=gauge_style
    )

    fig = go.Figure(indicator)
    fig.update_layout(
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        font={'family': 'Inter, sans-serif'},
        height=250,
        margin=dict(l=20, r=20, t=50, b=20)
    )

    return fig
111
+
112
+
113
def create_metrics_dashboard(detection_results: Dict) -> go.Figure:
    """Create comprehensive metrics dashboard.

    Builds a 2x2 subplot grid: average-confidence gauge, forgery-type pie,
    static segmentation-metric bars (from MODEL_METRICS), and a region counter.

    Args:
        detection_results: Dict with 'num_detections' and 'detections' keys,
            as produced by ForgeryDetector.detect.

    Returns:
        A 600px-tall Plotly figure.
    """
    num_detections = detection_results.get('num_detections', 0)
    detections = detection_results.get('detections', [])

    # Calculate average confidence (0 when nothing was detected).
    avg_confidence = 0
    if detections:
        avg_confidence = sum(d['confidence'] for d in detections) / len(detections)

    # Count by type; labels must match CLASS_NAMES values, others are ignored.
    type_counts = {'Copy-Move': 0, 'Splicing': 0, 'Text Substitution': 0}
    for det in detections:
        forgery_type = det.get('forgery_type', 'Unknown')
        if forgery_type in type_counts:
            type_counts[forgery_type] += 1

    # Create subplots (local import keeps module import time lean).
    from plotly.subplots import make_subplots

    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=('Detection Confidence', 'Forgery Distribution',
                        'Model Performance', 'Region Analysis'),
        specs=[[{'type': 'indicator'}, {'type': 'pie'}],
               [{'type': 'bar'}, {'type': 'indicator'}]],
        vertical_spacing=0.15,
        horizontal_spacing=0.12
    )

    # 1. Confidence Gauge
    fig.add_trace(go.Indicator(
        mode="gauge+number",
        value=avg_confidence * 100,
        title={'text': 'Avg Confidence', 'font': {'size': 14}},
        number={'suffix': '%', 'font': {'size': 24}},
        gauge={
            'axis': {'range': [0, 100]},
            'bar': {'color': '#4A90E2'},
            'steps': [
                {'range': [0, 60], 'color': '#FFE5E5'},
                {'range': [60, 80], 'color': '#FFF4E5'},
                {'range': [80, 100], 'color': '#E5F5E5'}
            ]
        }
    ), row=1, col=1)

    # 2. Forgery Type Distribution (donut; colors follow CLASS_COLORS order)
    colors_list = [CLASS_COLORS[0], CLASS_COLORS[1], CLASS_COLORS[2]]
    fig.add_trace(go.Pie(
        labels=list(type_counts.keys()),
        values=list(type_counts.values()),
        marker=dict(colors=colors_list),
        textinfo='label+percent',
        textfont=dict(size=12),
        hole=0.4
    ), row=1, col=2)

    # 3. Model Performance Bars (static training metrics, not per-request)
    metrics_names = ['Dice Score', 'IoU', 'Precision', 'Recall']
    metrics_values = [
        MODEL_METRICS['segmentation']['dice'] * 100,
        MODEL_METRICS['segmentation']['iou'] * 100,
        MODEL_METRICS['segmentation']['precision'] * 100,
        MODEL_METRICS['segmentation']['recall'] * 100
    ]

    fig.add_trace(go.Bar(
        x=metrics_names,
        y=metrics_values,
        marker=dict(
            color=metrics_values,
            colorscale='RdYlGn',
            showscale=False,
            line=dict(color='#2C3E50', width=1.5)
        ),
        text=[f'{v:.1f}%' for v in metrics_values],
        textposition='outside',
        textfont=dict(size=11, color='#2C3E50')
    ), row=2, col=1)

    # 4. Number of Regions Detected (red when forgeries found, green when clean)
    fig.add_trace(go.Indicator(
        mode="number",
        value=num_detections,
        title={'text': 'Regions Detected', 'font': {'size': 14}},
        number={'font': {'size': 32, 'color': '#E74C3C' if num_detections > 0 else '#27AE60'}}
    ), row=2, col=2)

    fig.update_layout(
        showlegend=False,
        paper_bgcolor='rgba(255,255,255,0.95)',
        plot_bgcolor='rgba(0,0,0,0)',
        font={'family': 'Inter, sans-serif', 'color': '#2C3E50'},
        height=600,
        margin=dict(l=40, r=40, t=80, b=40)
    )

    # Pin the bar chart's y-axis so 'outside' labels stay in frame.
    fig.update_yaxes(range=[0, 100], row=2, col=1)

    return fig
214
 
 
 
215
 
216
def create_detailed_report(detection_results: Dict) -> str:
    """Build the gradient-styled HTML summary card for one analysis run.

    Args:
        detection_results: Dict with 'num_detections' and 'detections' keys,
            as produced by ForgeryDetector.detect.

    Returns:
        An HTML fragment: a stats grid, then either a per-region forgery list
        or an "authentic" notice.
    """
    num_detections = detection_results.get('num_detections', 0)
    detections = detection_results.get('detections', [])

    # Average confidence across detected regions (0 when nothing was found).
    avg_confidence = 0
    if detections:
        avg_confidence = sum(d['confidence'] for d in detections) / len(detections)

    html = f"""
    <div style="font-family: 'Inter', sans-serif; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 12px; color: white;">
        <h2 style="margin: 0 0 20px 0; font-size: 28px; font-weight: 600;">
            🔍 Analysis Complete
        </h2>
        <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 15px; margin-bottom: 20px;">
            <div style="background: rgba(255,255,255,0.15); padding: 15px; border-radius: 8px; backdrop-filter: blur(10px);">
                <div style="font-size: 14px; opacity: 0.9;">Regions Detected</div>
                <div style="font-size: 32px; font-weight: 700; margin-top: 5px;">{num_detections}</div>
            </div>
            <div style="background: rgba(255,255,255,0.15); padding: 15px; border-radius: 8px; backdrop-filter: blur(10px);">
                <div style="font-size: 14px; opacity: 0.9;">Avg Confidence</div>
                <div style="font-size: 32px; font-weight: 700; margin-top: 5px;">{avg_confidence*100:.1f}%</div>
            </div>
            <div style="background: rgba(255,255,255,0.15); padding: 15px; border-radius: 8px; backdrop-filter: blur(10px);">
                <div style="font-size: 14px; opacity: 0.9;">Model Accuracy</div>
                <div style="font-size: 32px; font-weight: 700; margin-top: 5px;">{MODEL_METRICS['classification']['overall_accuracy']*100:.1f}%</div>
            </div>
            <div style="background: rgba(255,255,255,0.15); padding: 15px; border-radius: 8px; backdrop-filter: blur(10px);">
                <div style="font-size: 14px; opacity: 0.9;">Dice Score</div>
                <div style="font-size: 32px; font-weight: 700; margin-top: 5px;">{MODEL_METRICS['segmentation']['dice']*100:.1f}%</div>
            </div>
        </div>
    """

    if num_detections > 0:
        html += """
        <div style="background: rgba(255,255,255,0.95); padding: 20px; border-radius: 8px; color: #2C3E50; margin-top: 20px;">
            <h3 style="margin: 0 0 15px 0; color: #E74C3C; font-size: 20px;">⚠️ Forgery Detected</h3>
            <div style="font-size: 14px; line-height: 1.6;">
        """

        for i, det in enumerate(detections, 1):
            forgery_type = det.get('forgery_type', 'Unknown')
            confidence = det.get('confidence', 0)
            bbox = det.get('bounding_box', [0, 0, 0, 0])

            # BUGFIX: the previous lookup fell back to class id 0 for unknown
            # labels, so the '#888888' default was unreachable and unknown
            # types were painted as Copy-Move red. Use a -1 sentinel so the
            # gray fallback actually applies.
            class_id = next((k for k, v in CLASS_NAMES.items() if v == forgery_type), -1)
            color = CLASS_COLORS.get(class_id, '#888888')

            html += f"""
            <div style="margin-bottom: 12px; padding: 12px; background: #F8F9FA; border-left: 4px solid {color}; border-radius: 4px;">
                <div style="font-weight: 600; font-size: 15px; margin-bottom: 5px;">
                    Region {i}: {forgery_type}
                </div>
                <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 8px; font-size: 13px; color: #555;">
                    <div>📊 Confidence: <strong>{confidence*100:.1f}%</strong></div>
                    <div>📍 Location: ({bbox[0]}, {bbox[1]})</div>
                    <div>📏 Size: {bbox[2]}×{bbox[3]} px</div>
                    <div>🎯 Type: {forgery_type}</div>
                </div>
            </div>
            """

        html += """
            </div>
        </div>
        """
    else:
        html += """
        <div style="background: rgba(255,255,255,0.95); padding: 20px; border-radius: 8px; color: #2C3E50; margin-top: 20px; text-align: center;">
            <h3 style="margin: 0 0 10px 0; color: #27AE60; font-size: 20px;">✅ No Forgery Detected</h3>
            <p style="margin: 0; font-size: 14px; color: #555;">
                The document appears to be authentic based on our analysis.
            </p>
        </div>
        """

    html += """
    </div>
    """

    return html
301
+
302
+
303
# ============================================================================
# FORGERY DETECTOR CLASS
# ============================================================================

class ForgeryDetector:
    """Advanced forgery detection pipeline with professional output.

    Two-stage pipeline: a torch segmentation model localizes suspicious
    pixels, then a ForgeryClassifier labels each extracted region. All
    weights are loaded once in __init__, so construct a single instance
    per process.
    """

    def __init__(self):
        # Loads config, both models, and all preprocessing components.
        # NOTE(review): checkpoint paths are hard-coded relative to the CWD —
        # assumes the app is launched from the repo root.
        print("🚀 Initializing Document Forgery Detection System...")

        # Load config
        self.config = get_config('config.yaml')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"   Device: {self.device}")

        # Load segmentation model
        print("   Loading segmentation model...")
        self.model = get_model(self.config).to(self.device)
        checkpoint = torch.load('models/best_doctamper.pth', map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # Load classifier
        print("   Loading classification model...")
        self.classifier = ForgeryClassifier(self.config)
        self.classifier.load('models/classifier')

        # Initialize components (augmentation in eval mode: deterministic)
        self.preprocessor = DocumentPreprocessor(self.config, 'doctamper')
        self.augmentation = DatasetAwareAugmentation(self.config, 'doctamper', is_training=False)
        self.mask_refiner = get_mask_refiner(self.config)
        self.region_extractor = get_region_extractor(self.config)
        self.feature_extractor = get_feature_extractor(self.config, is_text_document=True)

        print("✅ System ready!")

    def detect(self, image) -> Tuple[np.ndarray, Dict, go.Figure, str]:
        """
        Detect forgeries in a document image or PDF.

        Args:
            image: A PIL.Image, an HxW / HxWx3 / HxWx4 numpy array, or a
                str path ending in '.pdf' (first page is rendered at 2x).

        Returns:
            overlay_image: Image with detection overlay
            results_json: Detection results as JSON-serializable dict
            metrics_plot: Plotly figure with metrics
            report_html: HTML report
        """
        # Handle PDF files: render page 0 at 2x zoom into an RGB array.
        if isinstance(image, str) and image.lower().endswith('.pdf'):
            import fitz  # PyMuPDF
            pdf_document = fitz.open(image)
            page = pdf_document[0]
            pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))
            image = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
            if pix.n == 4:
                image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
            pdf_document.close()

        # Convert PIL to numpy
        if isinstance(image, Image.Image):
            image = np.array(image)

        # Convert grayscale/RGBA to 3-channel RGB
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

        # Keep the untouched frame for overlay drawing and mask resizing.
        original_image = image.copy()

        # Preprocess (mask argument unused at inference time)
        preprocessed, _ = self.preprocessor(image, None)

        # Augment (eval-mode transform producing the model input tensor)
        augmented = self.augmentation(preprocessed, None)
        image_tensor = augmented['image'].unsqueeze(0).to(self.device)

        # Run localization
        with torch.no_grad():
            logits, decoder_features = self.model(image_tensor)
            prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]

        # Refine mask (0.5 probability threshold, resized to original frame)
        binary_mask = (prob_map > 0.5).astype(np.uint8)
        refined_mask = self.mask_refiner.refine(binary_mask, original_size=original_image.shape[:2])

        # Extract regions
        regions = self.region_extractor.extract(refined_mask, prob_map, original_image)

        # Classify regions
        results = []
        for region in regions:
            # Extract features
            features = self.feature_extractor.extract(
                preprocessed,
                region['region_mask'],
                [f.cpu() for f in decoder_features]
            )

            # Reshape to (1, n_features) for the classifier
            if features.ndim == 1:
                features = features.reshape(1, -1)

            # Pad/truncate to the classifier's fixed 526-feature width
            expected_features = 526
            current_features = features.shape[1]
            if current_features < expected_features:
                padding = np.zeros((features.shape[0], expected_features - current_features))
                features = np.hstack([features, padding])
            elif current_features > expected_features:
                features = features[:, :expected_features]

            # Classify
            predictions, confidences = self.classifier.predict(features)
            forgery_type = int(predictions[0])
            confidence = float(confidences[0])

            # Only report regions above the 0.6 confidence floor.
            if confidence > 0.6:
                results.append({
                    'region_id': region['region_id'],
                    'bounding_box': region['bounding_box'],
                    'forgery_type': CLASS_NAMES[forgery_type],
                    'confidence': confidence,
                    'description': CLASS_DESCRIPTIONS[forgery_type]
                })

        # Create visualization
        overlay = self._create_overlay(original_image, results)

        # Create JSON response with actual (static training) metrics
        json_results = {
            'timestamp': datetime.now().isoformat(),
            'num_detections': len(results),
            'detections': results,
            'model_performance': {
                'segmentation': {
                    'dice_score': f"{MODEL_METRICS['segmentation']['dice']*100:.2f}%",
                    'iou': f"{MODEL_METRICS['segmentation']['iou']*100:.2f}%",
                    'precision': f"{MODEL_METRICS['segmentation']['precision']*100:.2f}%",
                    'recall': f"{MODEL_METRICS['segmentation']['recall']*100:.2f}%"
                },
                'classification': {
                    'overall_accuracy': f"{MODEL_METRICS['classification']['overall_accuracy']*100:.2f}%",
                    'per_class_accuracy': {
                        'copy_move': f"{MODEL_METRICS['classification']['per_class']['copy_move']*100:.1f}%",
                        'splicing': f"{MODEL_METRICS['classification']['per_class']['splicing']*100:.1f}%",
                        'generation': f"{MODEL_METRICS['classification']['per_class']['generation']*100:.1f}%"
                    }
                }
            }
        }

        # Create metrics dashboard
        metrics_plot = create_metrics_dashboard(json_results)

        # Create HTML report
        report_html = create_detailed_report(json_results)

        return overlay, json_results, metrics_plot, report_html

    def _create_overlay(self, image: np.ndarray, results: List[Dict]) -> np.ndarray:
        """Create professional overlay visualization.

        Draws a translucent fill, a solid border, and a confidence label for
        every reported region, plus a count watermark when anything was found.
        """
        overlay = image.copy()

        # Separate layer holds the filled boxes so they can be alpha-blended.
        overlay_layer = overlay.copy()

        for result in results:
            bbox = result['bounding_box']
            x, y, w, h = bbox

            forgery_type = result['forgery_type']
            confidence = result['confidence']

            # Get color: hex '#RRGGBB' -> (r, g, b) int tuple
            forgery_id = [k for k, v in CLASS_NAMES.items() if v == forgery_type][0]
            color_hex = CLASS_COLORS[forgery_id]
            color = tuple(int(color_hex[i:i+2], 16) for i in (1, 3, 5))

            # Draw filled rectangle with transparency
            cv2.rectangle(overlay_layer, (x, y), (x+w, y+h), color, -1)

            # Draw border
            cv2.rectangle(overlay, (x, y), (x+w, y+h), color, 3)

            # Create label background
            label = f"{forgery_type}: {confidence:.1%}"
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = 0.6
            thickness = 2
            (label_w, label_h), baseline = cv2.getTextSize(label, font, font_scale, thickness)

            # Draw label background (clamped so it never goes above the frame)
            label_bg_y = max(y - label_h - 15, 0)
            cv2.rectangle(overlay, (x, label_bg_y), (x + label_w + 10, y), color, -1)

            # Draw label text
            cv2.putText(overlay, label, (x + 5, y - 5), font, font_scale, (255, 255, 255), thickness)

        # Blend overlay layer (20% fill opacity)
        overlay = cv2.addWeighted(overlay_layer, 0.2, overlay, 0.8, 0)

        # Add watermark: white text with a darker pass on top for contrast
        if len(results) > 0:
            watermark = f"Detected {len(results)} forgery region(s)"
            cv2.putText(overlay, watermark, (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
                        0.8, (255, 255, 255), 3)
            cv2.putText(overlay, watermark, (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
                        0.8, (0, 0, 0), 2)

        return overlay
513
+
514
+
515
# ============================================================================
# GRADIO INTERFACE
# ============================================================================

# Initialize detector at import time so model weights are loaded exactly once
# per process (Spaces restarts the process on redeploy).
print("Initializing detector...")
detector = ForgeryDetector()
522
 
 
 
 
 
 
 
 
 
 
 
 
 
523
 
524
def detect_forgery(file):
    """Gradio callback: run detection on an uploaded file.

    Returns the (overlay, json, plot, html) tuple the UI outputs expect;
    on failure, returns an error payload in the same four slots.
    """
    try:
        if file is None:
            return None, {"error": "No file uploaded"}, None, "<p style='color: red;'>No file uploaded</p>"

        # Gradio may hand us a tempfile wrapper or a bare path string.
        file_path = getattr(file, 'name', file)

        # PDFs are passed through as a path (the detector renders page 0);
        # everything else is decoded as an image first.
        if file_path.lower().endswith('.pdf'):
            source = file_path
        else:
            source = Image.open(file_path)

        overlay, results, metrics_plot, report_html = detector.detect(source)
        return overlay, results, metrics_plot, report_html

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error: {error_details}")
        error_html = f"""
        <div style="padding: 20px; background: #FFF5F5; border-left: 4px solid #E74C3C; border-radius: 8px;">
            <h3 style="color: #E74C3C; margin: 0 0 10px 0;">❌ Error</h3>
            <p style="margin: 0; color: #555;">{str(e)}</p>
        </div>
        """
        return None, {"error": str(e), "details": error_details}, None, error_html
553
+
554
+
555
# Custom CSS for premium look.
# Injected into gr.Blocks(css=...) below; targets Gradio's .gr-* classes,
# so selectors may need updating across major Gradio versions.
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');

* {
    font-family: 'Inter', sans-serif !important;
}

.gradio-container {
    background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%) !important;
}

.gr-button-primary {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    border: none !important;
    font-weight: 600 !important;
    text-transform: uppercase !important;
    letter-spacing: 0.5px !important;
    transition: all 0.3s ease !important;
}

.gr-button-primary:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 10px 20px rgba(102, 126, 234, 0.3) !important;
}

.gr-box {
    border-radius: 12px !important;
    border: 1px solid #e0e0e0 !important;
    background: white !important;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.07) !important;
}

.gr-form {
    background: white !important;
    border-radius: 12px !important;
    padding: 20px !important;
}

.gr-input, .gr-dropdown {
    border-radius: 8px !important;
    border: 2px solid #e0e0e0 !important;
    transition: all 0.3s ease !important;
}

.gr-input:focus, .gr-dropdown:focus {
    border-color: #667eea !important;
    box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1) !important;
}

h1 {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    background-clip: text;
    font-weight: 700 !important;
}

.gr-panel {
    border: none !important;
    background: white !important;
}
"""
618
 
619
# Create interface.
# Layout: header Markdown, then upload | result columns, then a metrics row,
# then report | JSON columns, then the architecture footnote.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="Document Forgery Detector") as demo:
    gr.Markdown(
        """
        # 📄 Document Forgery Detection System
        ### Advanced AI-Powered Forensic Analysis

        Upload a document image or PDF to detect and classify forgeries using state-of-the-art deep learning.
        Our hybrid system combines **MobileNetV3-UNet** for localization and **LightGBM** for classification.
        """
    )

    with gr.Row():
        # Left column: file upload, format help, and the trigger button.
        with gr.Column(scale=1):
            gr.Markdown("### 📤 Upload Document")
            input_file = gr.File(
                label="Document (Image or PDF)",
                file_types=["image", ".pdf"],
                type="filepath"
            )

            gr.Markdown(
                """
                **Supported Formats:**
                - 📷 Images: JPG, PNG, BMP, TIFF, WebP
                - 📄 PDF: First page analyzed

                **Forgery Types Detected:**
                - 🔴 **Copy-Move**: Duplicated regions
                - 🟢 **Splicing**: Mixed sources
                - 🔵 **Generation**: AI-generated content
                """
            )

            analyze_btn = gr.Button("🔍 Analyze Document", variant="primary", size="lg")

        # Right column: annotated output image.
        with gr.Column(scale=1):
            gr.Markdown("### 🎯 Detection Result")
            output_image = gr.Image(label="Annotated Document", type="numpy")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 📊 Performance Metrics")
            metrics_plot = gr.Plot(label="Model Performance Dashboard")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📋 Detailed Report")
            report_html = gr.HTML()

        with gr.Column(scale=1):
            gr.Markdown("### 📝 JSON Results")
            output_json = gr.JSON(label="Detection Details")

    gr.Markdown(
        """
        ---
        ### 🔬 Model Architecture

        **Stage 1: Localization** (MobileNetV3-Small + UNet)
        - Detects WHERE forgeries exist with pixel-level precision
        - Trained on 140K samples from DocTamper, FCD, and SCD datasets

        **Stage 2: Classification** (LightGBM)
        - Identifies WHAT TYPE of forgery using 526 hybrid features
        - Combines deep features, statistical, frequency, noise, and OCR features

        **Training:** Multi-round chunked training with 4 sequential rounds
        **Dataset:** DocTamper (120K) + SCD (18K) + FCD (2K) = 140K samples
        """
    )

    # Event handler: output order must match detect_forgery's return tuple.
    analyze_btn.click(
        fn=detect_forgery,
        inputs=[input_file],
        outputs=[output_image, output_json, metrics_plot, report_html]
    )
697
+
698
# ============================================================================
# LAUNCH
# ============================================================================

# Start the Gradio server when run as a script (Spaces invokes this module).
if __name__ == "__main__":
    demo.launch()