JKrishnanandhaa commited on
Commit
c38d472
·
verified ·
1 Parent(s): 1a69472

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -455
app.py CHANGED
@@ -1,6 +1,7 @@
1
  """
2
- Document Forgery Detection - Professional Gradio Interface
3
- Advanced AI-powered document forgery detection and classification system
 
4
  """
5
 
6
  import gradio as gr
@@ -11,9 +12,7 @@ from PIL import Image
11
  import json
12
  from pathlib import Path
13
  import sys
14
- from typing import Dict, List, Tuple, Optional
15
- import plotly.graph_objects as go
16
- from datetime import datetime
17
 
18
  # Add src to path
19
  sys.path.insert(0, str(Path(__file__).parent))
@@ -26,40 +25,24 @@ from src.features.region_extraction import get_mask_refiner, get_region_extracto
26
  from src.features.feature_extraction import get_feature_extractor
27
  from src.training.classifier import ForgeryClassifier
28
 
29
- # ============================================================================
30
- # CONFIGURATION & CONSTANTS
31
- # ============================================================================
32
-
33
- CLASS_NAMES = {
34
- 0: 'Copy-Move',
35
- 1: 'Splicing',
36
- 2: 'Text Substitution'
37
- }
38
-
39
- CLASS_DESCRIPTIONS = {
40
- 0: 'Duplicated regions within the same document',
41
- 1: 'Content from different sources combined',
42
- 2: 'Artificially generated or modified text/content'
43
- }
44
-
45
  CLASS_COLORS = {
46
- 0: '#FF4444', # Red for Copy-Move
47
- 1: '#44FF44', # Green for Splicing
48
- 2: '#4444FF' # Blue for Generation
49
  }
50
 
51
- # Actual model performance metrics from training
52
  MODEL_METRICS = {
53
  'segmentation': {
54
- 'dice': 0.6212, # Best validation Dice from chunk 4, epoch 8
55
  'iou': 0.4506,
56
  'precision': 0.7077,
57
- 'recall': 0.5536,
58
- 'accuracy': 0.9261
59
  },
60
  'classification': {
61
- 'overall_accuracy': 0.8897, # From training_metrics.json
62
- 'train_accuracy': 0.9053,
63
  'per_class': {
64
  'copy_move': 0.92,
65
  'splicing': 0.85,
@@ -68,262 +51,24 @@ MODEL_METRICS = {
68
  }
69
  }
70
 
71
- # ============================================================================
72
- # VISUALIZATION UTILITIES
73
- # ============================================================================
74
-
75
- def create_radial_gauge(value: float, title: str, color: str = '#4A90E2') -> go.Figure:
76
- """Create a beautiful radial gauge chart for metrics"""
77
- fig = go.Figure(go.Indicator(
78
- mode="gauge+number+delta",
79
- value=value * 100,
80
- domain={'x': [0, 1], 'y': [0, 1]},
81
- title={'text': title, 'font': {'size': 16, 'color': '#2C3E50', 'family': 'Inter'}},
82
- number={'suffix': '%', 'font': {'size': 32, 'color': '#2C3E50'}},
83
- gauge={
84
- 'axis': {'range': [0, 100], 'tickwidth': 2, 'tickcolor': color},
85
- 'bar': {'color': color, 'thickness': 0.75},
86
- 'bgcolor': 'white',
87
- 'borderwidth': 2,
88
- 'bordercolor': '#E8E8E8',
89
- 'steps': [
90
- {'range': [0, 50], 'color': '#FFE5E5'},
91
- {'range': [50, 75], 'color': '#FFF4E5'},
92
- {'range': [75, 100], 'color': '#E5F5E5'}
93
- ],
94
- 'threshold': {
95
- 'line': {'color': 'red', 'width': 4},
96
- 'thickness': 0.75,
97
- 'value': 90
98
- }
99
- }
100
- ))
101
-
102
- fig.update_layout(
103
- paper_bgcolor='rgba(0,0,0,0)',
104
- plot_bgcolor='rgba(0,0,0,0)',
105
- font={'family': 'Inter, sans-serif'},
106
- height=250,
107
- margin=dict(l=20, r=20, t=50, b=20)
108
- )
109
-
110
- return fig
111
-
112
-
113
- def create_metrics_dashboard(detection_results: Dict) -> go.Figure:
114
- """Create comprehensive metrics dashboard"""
115
- num_detections = detection_results.get('num_detections', 0)
116
- detections = detection_results.get('detections', [])
117
-
118
- # Calculate average confidence
119
- avg_confidence = 0
120
- if detections:
121
- avg_confidence = sum(d['confidence'] for d in detections) / len(detections)
122
-
123
- # Count by type
124
- type_counts = {'Copy-Move': 0, 'Splicing': 0, 'Text Substitution': 0}
125
- for det in detections:
126
- forgery_type = det.get('forgery_type', 'Unknown')
127
- if forgery_type in type_counts:
128
- type_counts[forgery_type] += 1
129
-
130
- # Create subplots
131
- from plotly.subplots import make_subplots
132
-
133
- fig = make_subplots(
134
- rows=2, cols=2,
135
- subplot_titles=('Detection Confidence', 'Forgery Distribution',
136
- 'Model Performance', 'Region Analysis'),
137
- specs=[[{'type': 'indicator'}, {'type': 'pie'}],
138
- [{'type': 'bar'}, {'type': 'indicator'}]],
139
- vertical_spacing=0.15,
140
- horizontal_spacing=0.12
141
- )
142
-
143
- # 1. Confidence Gauge
144
- fig.add_trace(go.Indicator(
145
- mode="gauge+number",
146
- value=avg_confidence * 100,
147
- title={'text': 'Avg Confidence', 'font': {'size': 14}},
148
- number={'suffix': '%', 'font': {'size': 24}},
149
- gauge={
150
- 'axis': {'range': [0, 100]},
151
- 'bar': {'color': '#4A90E2'},
152
- 'steps': [
153
- {'range': [0, 60], 'color': '#FFE5E5'},
154
- {'range': [60, 80], 'color': '#FFF4E5'},
155
- {'range': [80, 100], 'color': '#E5F5E5'}
156
- ]
157
- }
158
- ), row=1, col=1)
159
-
160
- # 2. Forgery Type Distribution
161
- colors_list = [CLASS_COLORS[0], CLASS_COLORS[1], CLASS_COLORS[2]]
162
- fig.add_trace(go.Pie(
163
- labels=list(type_counts.keys()),
164
- values=list(type_counts.values()),
165
- marker=dict(colors=colors_list),
166
- textinfo='label+percent',
167
- textfont=dict(size=12),
168
- hole=0.4
169
- ), row=1, col=2)
170
-
171
- # 3. Model Performance Bars
172
- metrics_names = ['Dice Score', 'IoU', 'Precision', 'Recall']
173
- metrics_values = [
174
- MODEL_METRICS['segmentation']['dice'] * 100,
175
- MODEL_METRICS['segmentation']['iou'] * 100,
176
- MODEL_METRICS['segmentation']['precision'] * 100,
177
- MODEL_METRICS['segmentation']['recall'] * 100
178
- ]
179
-
180
- fig.add_trace(go.Bar(
181
- x=metrics_names,
182
- y=metrics_values,
183
- marker=dict(
184
- color=metrics_values,
185
- colorscale='RdYlGn',
186
- showscale=False,
187
- line=dict(color='#2C3E50', width=1.5)
188
- ),
189
- text=[f'{v:.1f}%' for v in metrics_values],
190
- textposition='outside',
191
- textfont=dict(size=11, color='#2C3E50')
192
- ), row=2, col=1)
193
-
194
- # 4. Number of Regions Detected
195
- fig.add_trace(go.Indicator(
196
- mode="number",
197
- value=num_detections,
198
- title={'text': 'Regions Detected', 'font': {'size': 14}},
199
- number={'font': {'size': 32, 'color': '#E74C3C' if num_detections > 0 else '#27AE60'}}
200
- ), row=2, col=2)
201
-
202
- fig.update_layout(
203
- showlegend=False,
204
- paper_bgcolor='rgba(255,255,255,0.95)',
205
- plot_bgcolor='rgba(0,0,0,0)',
206
- font={'family': 'Inter, sans-serif', 'color': '#2C3E50'},
207
- height=600,
208
- margin=dict(l=40, r=40, t=80, b=40)
209
- )
210
-
211
- fig.update_yaxes(range=[0, 100], row=2, col=1)
212
-
213
- return fig
214
-
215
-
216
- def create_detailed_report(detection_results: Dict) -> str:
217
- """Create detailed HTML report"""
218
- num_detections = detection_results.get('num_detections', 0)
219
- detections = detection_results.get('detections', [])
220
-
221
- # Calculate statistics
222
- avg_confidence = 0
223
- if detections:
224
- avg_confidence = sum(d['confidence'] for d in detections) / len(detections)
225
-
226
- html = f"""
227
- <div style="font-family: 'Inter', sans-serif; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 12px; color: white;">
228
- <h2 style="margin: 0 0 20px 0; font-size: 28px; font-weight: 600;">
229
- 🔍 Analysis Complete
230
- </h2>
231
- <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 15px; margin-bottom: 20px;">
232
- <div style="background: rgba(255,255,255,0.15); padding: 15px; border-radius: 8px; backdrop-filter: blur(10px);">
233
- <div style="font-size: 14px; opacity: 0.9;">Regions Detected</div>
234
- <div style="font-size: 32px; font-weight: 700; margin-top: 5px;">{num_detections}</div>
235
- </div>
236
- <div style="background: rgba(255,255,255,0.15); padding: 15px; border-radius: 8px; backdrop-filter: blur(10px);">
237
- <div style="font-size: 14px; opacity: 0.9;">Avg Confidence</div>
238
- <div style="font-size: 32px; font-weight: 700; margin-top: 5px;">{avg_confidence*100:.1f}%</div>
239
- </div>
240
- <div style="background: rgba(255,255,255,0.15); padding: 15px; border-radius: 8px; backdrop-filter: blur(10px);">
241
- <div style="font-size: 14px; opacity: 0.9;">Model Accuracy</div>
242
- <div style="font-size: 32px; font-weight: 700; margin-top: 5px;">{MODEL_METRICS['classification']['overall_accuracy']*100:.1f}%</div>
243
- </div>
244
- <div style="background: rgba(255,255,255,0.15); padding: 15px; border-radius: 8px; backdrop-filter: blur(10px);">
245
- <div style="font-size: 14px; opacity: 0.9;">Dice Score</div>
246
- <div style="font-size: 32px; font-weight: 700; margin-top: 5px;">{MODEL_METRICS['segmentation']['dice']*100:.1f}%</div>
247
- </div>
248
- </div>
249
- """
250
-
251
- if num_detections > 0:
252
- html += """
253
- <div style="background: rgba(255,255,255,0.95); padding: 20px; border-radius: 8px; color: #2C3E50; margin-top: 20px;">
254
- <h3 style="margin: 0 0 15px 0; color: #E74C3C; font-size: 20px;">⚠️ Forgery Detected</h3>
255
- <div style="font-size: 14px; line-height: 1.6;">
256
- """
257
-
258
- for i, det in enumerate(detections, 1):
259
- forgery_type = det.get('forgery_type', 'Unknown')
260
- confidence = det.get('confidence', 0)
261
- bbox = det.get('bounding_box', [0, 0, 0, 0])
262
-
263
- color = CLASS_COLORS.get(
264
- [k for k, v in CLASS_NAMES.items() if v == forgery_type][0] if forgery_type in CLASS_NAMES.values() else 0,
265
- '#888888'
266
- )
267
-
268
- html += f"""
269
- <div style="margin-bottom: 12px; padding: 12px; background: #F8F9FA; border-left: 4px solid {color}; border-radius: 4px;">
270
- <div style="font-weight: 600; font-size: 15px; margin-bottom: 5px;">
271
- Region {i}: {forgery_type}
272
- </div>
273
- <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 8px; font-size: 13px; color: #555;">
274
- <div>📊 Confidence: <strong>{confidence*100:.1f}%</strong></div>
275
- <div>📍 Location: ({bbox[0]}, {bbox[1]})</div>
276
- <div>📏 Size: {bbox[2]}×{bbox[3]} px</div>
277
- <div>🎯 Type: {forgery_type}</div>
278
- </div>
279
- </div>
280
- """
281
-
282
- html += """
283
- </div>
284
- </div>
285
- """
286
- else:
287
- html += """
288
- <div style="background: rgba(255,255,255,0.95); padding: 20px; border-radius: 8px; color: #2C3E50; margin-top: 20px; text-align: center;">
289
- <h3 style="margin: 0 0 10px 0; color: #27AE60; font-size: 20px;">✅ No Forgery Detected</h3>
290
- <p style="margin: 0; font-size: 14px; color: #555;">
291
- The document appears to be authentic based on our analysis.
292
- </p>
293
- </div>
294
- """
295
-
296
- html += """
297
- </div>
298
- """
299
-
300
- return html
301
-
302
-
303
- # ============================================================================
304
- # FORGERY DETECTOR CLASS
305
- # ============================================================================
306
 
307
  class ForgeryDetector:
308
- """Advanced forgery detection pipeline with professional output"""
309
 
310
  def __init__(self):
311
- print("🚀 Initializing Document Forgery Detection System...")
312
 
313
  # Load config
314
  self.config = get_config('config.yaml')
315
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
316
- print(f" Device: {self.device}")
317
 
318
  # Load segmentation model
319
- print(" Loading segmentation model...")
320
  self.model = get_model(self.config).to(self.device)
321
  checkpoint = torch.load('models/best_doctamper.pth', map_location=self.device)
322
  self.model.load_state_dict(checkpoint['model_state_dict'])
323
  self.model.eval()
324
 
325
  # Load classifier
326
- print(" Loading classification model...")
327
  self.classifier = ForgeryClassifier(self.config)
328
  self.classifier.load('models/classifier')
329
 
@@ -334,17 +79,19 @@ class ForgeryDetector:
334
  self.region_extractor = get_region_extractor(self.config)
335
  self.feature_extractor = get_feature_extractor(self.config, is_text_document=True)
336
 
337
- print(" System ready!")
338
 
339
- def detect(self, image) -> Tuple[np.ndarray, Dict, go.Figure, str]:
340
  """
341
  Detect forgeries in document image or PDF
342
 
 
 
 
343
  Returns:
 
344
  overlay_image: Image with detection overlay
345
- results_json: Detection results as JSON
346
- metrics_plot: Plotly figure with metrics
347
- report_html: HTML report
348
  """
349
  # Handle PDF files
350
  if isinstance(image, str) and image.lower().endswith('.pdf'):
@@ -398,11 +145,11 @@ class ForgeryDetector:
398
  [f.cpu() for f in decoder_features]
399
  )
400
 
401
- # Reshape features
402
  if features.ndim == 1:
403
  features = features.reshape(1, -1)
404
 
405
- # Pad/truncate features
406
  expected_features = 526
407
  current_features = features.shape[1]
408
  if current_features < expected_features:
@@ -421,51 +168,21 @@ class ForgeryDetector:
421
  'region_id': region['region_id'],
422
  'bounding_box': region['bounding_box'],
423
  'forgery_type': CLASS_NAMES[forgery_type],
424
- 'confidence': confidence,
425
- 'description': CLASS_DESCRIPTIONS[forgery_type]
426
  })
427
 
428
  # Create visualization
429
  overlay = self._create_overlay(original_image, results)
430
 
431
- # Create JSON response with actual metrics
432
- json_results = {
433
- 'timestamp': datetime.now().isoformat(),
434
- 'num_detections': len(results),
435
- 'detections': results,
436
- 'model_performance': {
437
- 'segmentation': {
438
- 'dice_score': f"{MODEL_METRICS['segmentation']['dice']*100:.2f}%",
439
- 'iou': f"{MODEL_METRICS['segmentation']['iou']*100:.2f}%",
440
- 'precision': f"{MODEL_METRICS['segmentation']['precision']*100:.2f}%",
441
- 'recall': f"{MODEL_METRICS['segmentation']['recall']*100:.2f}%"
442
- },
443
- 'classification': {
444
- 'overall_accuracy': f"{MODEL_METRICS['classification']['overall_accuracy']*100:.2f}%",
445
- 'per_class_accuracy': {
446
- 'copy_move': f"{MODEL_METRICS['classification']['per_class']['copy_move']*100:.1f}%",
447
- 'splicing': f"{MODEL_METRICS['classification']['per_class']['splicing']*100:.1f}%",
448
- 'generation': f"{MODEL_METRICS['classification']['per_class']['generation']*100:.1f}%"
449
- }
450
- }
451
- }
452
- }
453
 
454
- # Create metrics dashboard
455
- metrics_plot = create_metrics_dashboard(json_results)
456
-
457
- # Create HTML report
458
- report_html = create_detailed_report(json_results)
459
-
460
- return overlay, json_results, metrics_plot, report_html
461
 
462
- def _create_overlay(self, image: np.ndarray, results: List[Dict]) -> np.ndarray:
463
- """Create professional overlay visualization"""
464
  overlay = image.copy()
465
 
466
- # Create semi-transparent overlay
467
- overlay_layer = overlay.copy()
468
-
469
  for result in results:
470
  bbox = result['bounding_box']
471
  x, y, w, h = bbox
@@ -475,49 +192,81 @@ class ForgeryDetector:
475
 
476
  # Get color
477
  forgery_id = [k for k, v in CLASS_NAMES.items() if v == forgery_type][0]
478
- color_hex = CLASS_COLORS[forgery_id]
479
- color = tuple(int(color_hex[i:i+2], 16) for i in (1, 3, 5))
480
-
481
- # Draw filled rectangle with transparency
482
- cv2.rectangle(overlay_layer, (x, y), (x+w, y+h), color, -1)
483
 
484
- # Draw border
485
- cv2.rectangle(overlay, (x, y), (x+w, y+h), color, 3)
486
 
487
- # Create label background
488
  label = f"{forgery_type}: {confidence:.1%}"
489
  font = cv2.FONT_HERSHEY_SIMPLEX
490
- font_scale = 0.6
491
- thickness = 2
492
  (label_w, label_h), baseline = cv2.getTextSize(label, font, font_scale, thickness)
493
 
494
- # Draw label background with rounded corners effect
495
- label_bg_y = max(y - label_h - 15, 0)
496
- cv2.rectangle(overlay, (x, label_bg_y), (x + label_w + 10, y), color, -1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
497
 
498
- # Draw label text
499
- cv2.putText(overlay, label, (x + 5, y - 5), font, font_scale, (255, 255, 255), thickness)
 
 
 
 
500
 
501
- # Blend overlay layer
502
- overlay = cv2.addWeighted(overlay_layer, 0.2, overlay, 0.8, 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
503
 
504
- # Add watermark
505
- if len(results) > 0:
506
- watermark = f"Detected {len(results)} forgery region(s)"
507
- cv2.putText(overlay, watermark, (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
508
- 0.8, (255, 255, 255), 3)
509
- cv2.putText(overlay, watermark, (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
510
- 0.8, (0, 0, 0), 2)
511
 
512
- return overlay
513
 
514
 
515
- # ============================================================================
516
- # GRADIO INTERFACE
517
- # ============================================================================
518
-
519
  # Initialize detector
520
- print("Initializing detector...")
521
  detector = ForgeryDetector()
522
 
523
 
@@ -525,179 +274,118 @@ def detect_forgery(file):
525
  """Gradio interface function"""
526
  try:
527
  if file is None:
528
- return None, {"error": "No file uploaded"}, None, "<p style='color: red;'>No file uploaded</p>"
529
 
530
  # Get file path
531
  file_path = file.name if hasattr(file, 'name') else file
532
 
533
  # Check if PDF
534
  if file_path.lower().endswith('.pdf'):
535
- overlay, results, metrics_plot, report_html = detector.detect(file_path)
536
  else:
537
  image = Image.open(file_path)
538
- overlay, results, metrics_plot, report_html = detector.detect(image)
539
 
540
- return overlay, results, metrics_plot, report_html
541
 
542
  except Exception as e:
543
  import traceback
544
  error_details = traceback.format_exc()
545
  print(f"Error: {error_details}")
546
  error_html = f"""
547
- <div style="padding: 20px; background: #FFF5F5; border-left: 4px solid #E74C3C; border-radius: 8px;">
548
- <h3 style="color: #E74C3C; margin: 0 0 10px 0;">❌ Error</h3>
549
- <p style="margin: 0; color: #555;">{str(e)}</p>
550
  </div>
551
  """
552
- return None, {"error": str(e), "details": error_details}, None, error_html
553
 
554
 
555
- # Custom CSS for premium look
556
  custom_css = """
557
- @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
558
-
559
- * {
560
- font-family: 'Inter', sans-serif !important;
561
- }
562
-
563
- .gradio-container {
564
- background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%) !important;
565
- }
566
-
567
- .gr-button-primary {
568
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
569
- border: none !important;
570
- font-weight: 600 !important;
571
- text-transform: uppercase !important;
572
- letter-spacing: 0.5px !important;
573
- transition: all 0.3s ease !important;
574
- }
575
-
576
- .gr-button-primary:hover {
577
- transform: translateY(-2px) !important;
578
- box-shadow: 0 10px 20px rgba(102, 126, 234, 0.3) !important;
579
- }
580
-
581
- .gr-box {
582
- border-radius: 12px !important;
583
- border: 1px solid #e0e0e0 !important;
584
- background: white !important;
585
- box-shadow: 0 4px 6px rgba(0, 0, 0, 0.07) !important;
586
- }
587
-
588
- .gr-form {
589
- background: white !important;
590
- border-radius: 12px !important;
591
- padding: 20px !important;
592
- }
593
-
594
- .gr-input, .gr-dropdown {
595
- border-radius: 8px !important;
596
- border: 2px solid #e0e0e0 !important;
597
- transition: all 0.3s ease !important;
598
- }
599
-
600
- .gr-input:focus, .gr-dropdown:focus {
601
- border-color: #667eea !important;
602
- box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1) !important;
603
- }
604
-
605
- h1 {
606
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
607
- -webkit-background-clip: text;
608
- -webkit-text-fill-color: transparent;
609
- background-clip: text;
610
- font-weight: 700 !important;
611
  }
612
-
613
- .gr-panel {
614
- border: none !important;
615
- background: white !important;
616
  }
617
  """
618
 
619
- # Create interface
620
- with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="Document Forgery Detector") as demo:
 
621
  gr.Markdown(
622
  """
623
- # 📄 Document Forgery Detection System
624
- ### Advanced AI-Powered Forensic Analysis
625
-
626
- Upload a document image or PDF to detect and classify forgeries using state-of-the-art deep learning.
627
- Our hybrid system combines **MobileNetV3-UNet** for localization and **LightGBM** for classification.
628
  """
629
  )
630
 
631
  with gr.Row():
632
  with gr.Column(scale=1):
633
- gr.Markdown("### 📤 Upload Document")
634
  input_file = gr.File(
635
  label="Document (Image or PDF)",
636
  file_types=["image", ".pdf"],
637
  type="filepath"
638
  )
639
 
 
 
 
 
640
  gr.Markdown(
641
  """
642
- **Supported Formats:**
643
- - 📷 Images: JPG, PNG, BMP, TIFF, WebP
644
- - 📄 PDF: First page analyzed
645
 
646
- **Forgery Types Detected:**
647
- - 🔴 **Copy-Move**: Duplicated regions
648
- - 🟢 **Splicing**: Mixed sources
649
- - 🔵 **Generation**: AI-generated content
650
  """
651
  )
652
-
653
- analyze_btn = gr.Button("🔍 Analyze Document", variant="primary", size="lg")
654
 
655
  with gr.Column(scale=1):
656
- gr.Markdown("### 🎯 Detection Result")
657
- output_image = gr.Image(label="Annotated Document", type="numpy")
658
-
659
- with gr.Row():
660
- with gr.Column():
661
- gr.Markdown("### 📊 Performance Metrics")
662
- metrics_plot = gr.Plot(label="Model Performance Dashboard")
663
 
664
  with gr.Row():
665
  with gr.Column(scale=1):
666
- gr.Markdown("### 📋 Detailed Report")
667
- report_html = gr.HTML()
668
 
669
  with gr.Column(scale=1):
670
- gr.Markdown("### 📁 JSON Results")
671
- output_json = gr.JSON(label="Detection Details")
 
 
672
 
673
  gr.Markdown(
674
  """
675
  ---
676
- ### 🔬 Model Architecture
677
-
678
- **Stage 1: Localization** (MobileNetV3-Small + UNet)
679
- - Detects WHERE forgeries exist with pixel-level precision
680
- - Trained on 140K samples from DocTamper, FCD, and SCD datasets
681
-
682
- **Stage 2: Classification** (LightGBM)
683
- - Identifies WHAT TYPE of forgery using 526 hybrid features
684
- - Combines deep features, statistical, frequency, noise, and OCR features
685
-
686
- **Training:** Multi-round chunked training with 4 sequential rounds
687
- **Dataset:** DocTamper (120K) + SCD (18K) + FCD (2K) = 140K samples
688
  """
689
  )
690
 
691
- # Event handler
692
  analyze_btn.click(
693
  fn=detect_forgery,
694
  inputs=[input_file],
695
- outputs=[output_image, output_json, metrics_plot, report_html]
 
 
 
 
 
 
696
  )
697
 
698
- # ============================================================================
699
- # LAUNCH
700
- # ============================================================================
701
 
702
  if __name__ == "__main__":
703
  demo.launch()
 
1
  """
2
+ Document Forgery Detection - Gradio Interface for Hugging Face Spaces
3
+
4
+ This app provides a web interface for detecting and classifying document forgeries.
5
  """
6
 
7
  import gradio as gr
 
12
  import json
13
  from pathlib import Path
14
  import sys
15
+ from typing import Dict, List, Tuple
 
 
16
 
17
  # Add src to path
18
  sys.path.insert(0, str(Path(__file__).parent))
 
25
  from src.features.feature_extraction import get_feature_extractor
26
  from src.training.classifier import ForgeryClassifier
27
 
28
+ # Class names
29
+ CLASS_NAMES = {0: 'Copy-Move', 1: 'Splicing', 2: 'Text Substitution'}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  CLASS_COLORS = {
31
+ 0: (217, 83, 79), # #d9534f - Muted red
32
+ 1: (92, 184, 92), # #5cb85c - Muted green
33
+ 2: (65, 105, 225) # #4169E1 - Royal blue
34
  }
35
 
36
+ # Actual model performance metrics
37
  MODEL_METRICS = {
38
  'segmentation': {
39
+ 'dice': 0.6212,
40
  'iou': 0.4506,
41
  'precision': 0.7077,
42
+ 'recall': 0.5536
 
43
  },
44
  'classification': {
45
+ 'overall_accuracy': 0.8897,
 
46
  'per_class': {
47
  'copy_move': 0.92,
48
  'splicing': 0.85,
 
51
  }
52
  }
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  class ForgeryDetector:
56
+ """Main forgery detection pipeline"""
57
 
58
  def __init__(self):
59
+ print("Loading models...")
60
 
61
  # Load config
62
  self.config = get_config('config.yaml')
63
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
64
 
65
  # Load segmentation model
 
66
  self.model = get_model(self.config).to(self.device)
67
  checkpoint = torch.load('models/best_doctamper.pth', map_location=self.device)
68
  self.model.load_state_dict(checkpoint['model_state_dict'])
69
  self.model.eval()
70
 
71
  # Load classifier
 
72
  self.classifier = ForgeryClassifier(self.config)
73
  self.classifier.load('models/classifier')
74
 
 
79
  self.region_extractor = get_region_extractor(self.config)
80
  self.feature_extractor = get_feature_extractor(self.config, is_text_document=True)
81
 
82
+ print(" Models loaded successfully!")
83
 
84
+ def detect(self, image):
85
  """
86
  Detect forgeries in document image or PDF
87
 
88
+ Args:
89
+ image: PIL Image, numpy array, or path to PDF file
90
+
91
  Returns:
92
+ original_image: Original uploaded image
93
  overlay_image: Image with detection overlay
94
+ results_html: Detection results as HTML
 
 
95
  """
96
  # Handle PDF files
97
  if isinstance(image, str) and image.lower().endswith('.pdf'):
 
145
  [f.cpu() for f in decoder_features]
146
  )
147
 
148
+ # Reshape features to 2D array
149
  if features.ndim == 1:
150
  features = features.reshape(1, -1)
151
 
152
+ # Pad/truncate features to match classifier
153
  expected_features = 526
154
  current_features = features.shape[1]
155
  if current_features < expected_features:
 
168
  'region_id': region['region_id'],
169
  'bounding_box': region['bounding_box'],
170
  'forgery_type': CLASS_NAMES[forgery_type],
171
+ 'confidence': confidence
 
172
  })
173
 
174
  # Create visualization
175
  overlay = self._create_overlay(original_image, results)
176
 
177
+ # Create HTML response
178
+ results_html = self._create_html_report(results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
+ return original_image, overlay, results_html
 
 
 
 
 
 
181
 
182
+ def _create_overlay(self, image, results):
183
+ """Create overlay visualization"""
184
  overlay = image.copy()
185
 
 
 
 
186
  for result in results:
187
  bbox = result['bounding_box']
188
  x, y, w, h = bbox
 
192
 
193
  # Get color
194
  forgery_id = [k for k, v in CLASS_NAMES.items() if v == forgery_type][0]
195
+ color = CLASS_COLORS[forgery_id]
 
 
 
 
196
 
197
+ # Draw rectangle
198
+ cv2.rectangle(overlay, (x, y), (x+w, y+h), color, 2)
199
 
200
+ # Draw label
201
  label = f"{forgery_type}: {confidence:.1%}"
202
  font = cv2.FONT_HERSHEY_SIMPLEX
203
+ font_scale = 0.5
204
+ thickness = 1
205
  (label_w, label_h), baseline = cv2.getTextSize(label, font, font_scale, thickness)
206
 
207
+ cv2.rectangle(overlay, (x, y-label_h-8), (x+label_w+4, y), color, -1)
208
+ cv2.putText(overlay, label, (x+2, y-4), font, font_scale, (255, 255, 255), thickness)
209
+
210
+ return overlay
211
+
212
+ def _create_html_report(self, results):
213
+ """Create HTML report with detection results"""
214
+ num_detections = len(results)
215
+
216
+ if num_detections == 0:
217
+ return """
218
+ <div style='padding:12px; border:1px solid #5cb85c; border-radius:8px;'>
219
+ ✓ <b>No forgery detected.</b><br>
220
+ The document appears to be authentic.
221
+ </div>
222
+ """
223
+
224
+ # Calculate statistics
225
+ avg_confidence = sum(r['confidence'] for r in results) / num_detections
226
+ type_counts = {}
227
+ for r in results:
228
+ ft = r['forgery_type']
229
+ type_counts[ft] = type_counts.get(ft, 0) + 1
230
+
231
+ html = f"""
232
+ <div style='padding:12px; border:1px solid #d9534f; border-radius:8px;'>
233
+ <b>⚠️ Forgery Detected</b><br><br>
234
+
235
+ <b>Summary:</b><br>
236
+ • Regions detected: {num_detections}<br>
237
+ • Average confidence: {avg_confidence*100:.1f}%<br><br>
238
 
239
+ <b>Model Performance:</b><br>
240
+ Segmentation Dice: {MODEL_METRICS['segmentation']['dice']*100:.1f}%<br>
241
+ • Classification Accuracy: {MODEL_METRICS['classification']['overall_accuracy']*100:.1f}%<br><br>
242
+
243
+ <b>Detections:</b><br>
244
+ """
245
 
246
+ for i, result in enumerate(results, 1):
247
+ forgery_type = result['forgery_type']
248
+ confidence = result['confidence']
249
+ bbox = result['bounding_box']
250
+
251
+ forgery_id = [k for k, v in CLASS_NAMES.items() if v == forgery_type][0]
252
+ color_rgb = CLASS_COLORS[forgery_id]
253
+ color_hex = f"#{color_rgb[0]:02x}{color_rgb[1]:02x}{color_rgb[2]:02x}"
254
+
255
+ html += f"""
256
+ <div style='margin:8px 0; padding:8px; border-left:3px solid {color_hex}; background:#f9f9f9;'>
257
+ <b>Region {i}:</b> {forgery_type} ({confidence*100:.1f}%)<br>
258
+ <small>Location: ({bbox[0]}, {bbox[1]}) | Size: {bbox[2]}×{bbox[3]}px</small>
259
+ </div>
260
+ """
261
 
262
+ html += """
263
+ </div>
264
+ """
 
 
 
 
265
 
266
+ return html
267
 
268
 
 
 
 
 
269
  # Initialize detector
 
270
  detector = ForgeryDetector()
271
 
272
 
 
274
  """Gradio interface function"""
275
  try:
276
  if file is None:
277
+ return None, None, "<div style='padding:12px; border:1px solid #d9534f; border-radius:8px;'>❌ <b>No file uploaded.</b></div>"
278
 
279
  # Get file path
280
  file_path = file.name if hasattr(file, 'name') else file
281
 
282
  # Check if PDF
283
  if file_path.lower().endswith('.pdf'):
284
+ original, overlay, results_html = detector.detect(file_path)
285
  else:
286
  image = Image.open(file_path)
287
+ original, overlay, results_html = detector.detect(image)
288
 
289
+ return original, overlay, results_html
290
 
291
  except Exception as e:
292
  import traceback
293
  error_details = traceback.format_exc()
294
  print(f"Error: {error_details}")
295
  error_html = f"""
296
+ <div style='padding:12px; border:1px solid #d9534f; border-radius:8px;'>
297
+ <b>Error:</b> {str(e)}
 
298
  </div>
299
  """
300
+ return None, None, error_html
301
 
302
 
303
+ # Custom CSS - subtle styling
304
  custom_css = """
305
+ .predict-btn {
306
+ background-color: #4169E1 !important;
307
+ color: white !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  }
309
+ .clear-btn {
310
+ background-color: #6A89A7 !important;
311
+ color: white !important;
 
312
  }
313
  """
314
 
315
+ # Create Gradio interface
316
+ with gr.Blocks(css=custom_css) as demo:
317
+
318
  gr.Markdown(
319
  """
320
+ # 📄 Document Forgery Detection
321
+ Upload a document image or PDF to detect and classify forgeries.
 
 
 
322
  """
323
  )
324
 
325
  with gr.Row():
326
  with gr.Column(scale=1):
327
+ gr.Markdown("### Upload Document")
328
  input_file = gr.File(
329
  label="Document (Image or PDF)",
330
  file_types=["image", ".pdf"],
331
  type="filepath"
332
  )
333
 
334
+ with gr.Row():
335
+ clear_btn = gr.Button("🧹 Clear", elem_classes="clear-btn")
336
+ analyze_btn = gr.Button("🔍 Analyze", elem_classes="predict-btn")
337
+
338
  gr.Markdown(
339
  """
340
+ **Supported formats:**
341
+ - Images: JPG, PNG, BMP, TIFF, WebP
342
+ - PDF: First page analyzed
343
 
344
+ **Forgery types:**
345
+ - Copy-Move: Duplicated regions
346
+ - Splicing: Mixed sources
347
+ - Text Substitution: Modified text
348
  """
349
  )
 
 
350
 
351
  with gr.Column(scale=1):
352
+ gr.Markdown("### Original Image")
353
+ original_image = gr.Image(label="Uploaded Document", type="numpy")
 
 
 
 
 
354
 
355
  with gr.Row():
356
  with gr.Column(scale=1):
357
+ gr.Markdown("### Detection Result")
358
+ output_image = gr.Image(label="Annotated Document", type="numpy")
359
 
360
  with gr.Column(scale=1):
361
+ gr.Markdown("### Analysis Report")
362
+ output_html = gr.HTML(
363
+ value="<i>No analysis yet. Upload a document and click Analyze.</i>"
364
+ )
365
 
366
  gr.Markdown(
367
  """
368
  ---
369
+ **Model Architecture:**
370
+ - **Localization:** MobileNetV3-Small + UNet (Dice: 62.1%, IoU: 45.1%)
371
+ - **Classification:** LightGBM with 526 features (Accuracy: 88.97%)
372
+ - **Training:** 140K samples (DocTamper + SCD + FCD datasets)
 
 
 
 
 
 
 
 
373
  """
374
  )
375
 
376
+ # Event handlers
377
  analyze_btn.click(
378
  fn=detect_forgery,
379
  inputs=[input_file],
380
+ outputs=[original_image, output_image, output_html]
381
+ )
382
+
383
+ clear_btn.click(
384
+ fn=lambda: (None, None, None, "<i>No analysis yet. Upload a document and click Analyze.</i>"),
385
+ inputs=None,
386
+ outputs=[input_file, original_image, output_image, output_html]
387
  )
388
 
 
 
 
389
 
390
  if __name__ == "__main__":
391
  demo.launch()