JKrishnanandhaa commited on
Commit
70b84aa
·
verified ·
1 Parent(s): 5b33d5d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +148 -218
app.py CHANGED
@@ -1,7 +1,6 @@
1
  """
2
- Document Forgery Detection - Gradio Interface for Hugging Face Spaces
3
-
4
- This app provides a web interface for detecting and classifying document forgeries.
5
  """
6
 
7
  import gradio as gr
@@ -9,11 +8,14 @@ import torch
9
  import cv2
10
  import numpy as np
11
  from PIL import Image
12
- import json
13
  from pathlib import Path
14
  import sys
 
15
 
16
- # Add src to path
 
 
17
  sys.path.insert(0, str(Path(__file__).parent))
18
 
19
  from src.models import get_model
@@ -24,253 +26,181 @@ from src.features.region_extraction import get_mask_refiner, get_region_extracto
24
  from src.features.feature_extraction import get_feature_extractor
25
  from src.training.classifier import ForgeryClassifier
26
 
27
- # Class names
28
- CLASS_NAMES = {0: 'Copy-Move', 1: 'Splicing', 2: 'Generation'}
 
 
29
  CLASS_COLORS = {
30
- 0: (255, 0, 0), # Red for Copy-Move
31
- 1: (0, 255, 0), # Green for Splicing
32
- 2: (0, 0, 255) # Blue for Generation
33
  }
34
 
35
-
 
 
36
  class ForgeryDetector:
37
- """Main forgery detection pipeline"""
38
-
39
  def __init__(self):
40
  print("Loading models...")
41
-
42
- # Load config
43
- self.config = get_config('config.yaml')
44
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
45
-
46
- # Load segmentation model
47
  self.model = get_model(self.config).to(self.device)
48
- checkpoint = torch.load('models/best_doctamper.pth', map_location=self.device)
49
- self.model.load_state_dict(checkpoint['model_state_dict'])
50
  self.model.eval()
51
-
52
- # Load classifier
53
  self.classifier = ForgeryClassifier(self.config)
54
- self.classifier.load('models/classifier')
55
-
56
- # Initialize components
57
- self.preprocessor = DocumentPreprocessor(self.config, 'doctamper')
58
- self.augmentation = DatasetAwareAugmentation(self.config, 'doctamper', is_training=False)
59
  self.mask_refiner = get_mask_refiner(self.config)
60
  self.region_extractor = get_region_extractor(self.config)
61
  self.feature_extractor = get_feature_extractor(self.config, is_text_document=True)
62
-
63
- print("✓ Models loaded successfully!")
64
-
65
  def detect(self, image):
66
- """
67
- Detect forgeries in document image or PDF
68
-
69
- Args:
70
- image: PIL Image, numpy array, or path to PDF file
71
-
72
- Returns:
73
- overlay_image: Image with detection overlay
74
- results_json: Detection results as JSON
75
- """
76
- # Handle PDF files
77
- if isinstance(image, str) and image.lower().endswith('.pdf'):
78
- import fitz # PyMuPDF
79
- # Open PDF and convert first page to image
80
- pdf_document = fitz.open(image)
81
- page = pdf_document[0] # First page
82
- pix = page.get_pixmap(matrix=fitz.Matrix(2, 2)) # 2x scale for better quality
83
- image = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
84
- if pix.n == 4: # RGBA
85
- image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
86
- pdf_document.close()
87
-
88
- # Convert PIL to numpy
89
  if isinstance(image, Image.Image):
90
  image = np.array(image)
91
-
92
- # Convert to RGB
93
- if len(image.shape) == 2:
94
  image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
95
  elif image.shape[2] == 4:
96
  image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
97
-
98
- original_image = image.copy()
99
-
100
- # Preprocess
101
  preprocessed, _ = self.preprocessor(image, None)
102
-
103
- # Augment
104
  augmented = self.augmentation(preprocessed, None)
105
- image_tensor = augmented['image'].unsqueeze(0).to(self.device)
106
-
107
- # Run localization
108
  with torch.no_grad():
109
  logits, decoder_features = self.model(image_tensor)
110
  prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]
111
-
112
- # Refine mask
113
- binary_mask = (prob_map > 0.5).astype(np.uint8)
114
- refined_mask = self.mask_refiner.refine(binary_mask, original_size=original_image.shape[:2])
115
-
116
- # Extract regions
117
- regions = self.region_extractor.extract(refined_mask, prob_map, original_image)
118
-
119
- # Classify regions
120
  results = []
121
- for region in regions:
122
- # Extract features
123
  features = self.feature_extractor.extract(
124
- preprocessed,
125
- region['region_mask'],
126
- [f.cpu() for f in decoder_features]
127
  )
128
-
129
- # Reshape features to 2D array (1, n_features) for classifier
130
  if features.ndim == 1:
131
  features = features.reshape(1, -1)
132
-
133
- # TEMPORARY FIX: Pad features to match classifier's expected count
134
- expected_features = 526
135
- current_features = features.shape[1]
136
- if current_features < expected_features:
137
- # Pad with zeros
138
- padding = np.zeros((features.shape[0], expected_features - current_features))
139
- features = np.hstack([features, padding])
140
- print(f"Warning: Padded features from {current_features} to {expected_features}")
141
- elif current_features > expected_features:
142
- # Truncate
143
- features = features[:, :expected_features]
144
- print(f"Warning: Truncated features from {current_features} to {expected_features}")
145
-
146
- # Classify
147
- predictions, confidences = self.classifier.predict(features)
148
- forgery_type = int(predictions[0])
149
- confidence = float(confidences[0])
150
-
151
- if confidence > 0.6: # Confidence threshold
152
  results.append({
153
- 'region_id': region['region_id'],
154
- 'bounding_box': region['bounding_box'],
155
- 'forgery_type': CLASS_NAMES[forgery_type],
156
- 'confidence': confidence
157
  })
158
-
159
- # Create visualization
160
- overlay = self._create_overlay(original_image, results)
161
-
162
- # Create JSON response
163
- json_results = {
164
- 'num_detections': len(results),
165
- 'detections': results,
166
- 'model_info': {
167
- 'segmentation_dice': '75%',
168
- 'classifier_accuracy': '92%'
169
- }
170
  }
171
-
172
- return overlay, json_results
173
-
174
- def _create_overlay(self, image, results):
175
- """Create overlay visualization"""
176
- overlay = image.copy()
177
-
178
- # Draw bounding boxes and labels
179
- for result in results:
180
- bbox = result['bounding_box']
181
- x, y, w, h = bbox
182
-
183
- forgery_type = result['forgery_type']
184
- confidence = result['confidence']
185
-
186
- # Get color
187
- forgery_id = [k for k, v in CLASS_NAMES.items() if v == forgery_type][0]
188
- color = CLASS_COLORS[forgery_id]
189
-
190
- # Draw rectangle
191
- cv2.rectangle(overlay, (x, y), (x+w, y+h), color, 2)
192
-
193
- # Draw label
194
- label = f"{forgery_type}: {confidence:.1%}"
195
- label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
196
- cv2.rectangle(overlay, (x, y-label_size[1]-10), (x+label_size[0], y), color, -1)
197
- cv2.putText(overlay, label, (x, y-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
198
-
199
- # Add legend
200
- if len(results) > 0:
201
- legend_y = 30
202
- cv2.putText(overlay, f"Detected {len(results)} forgery region(s)",
203
- (10, legend_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
204
-
205
- return overlay
206
-
207
-
208
- # Initialize detector
209
- detector = ForgeryDetector()
210
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
- def detect_forgery(file):
213
- """Gradio interface function"""
214
- try:
215
- if file is None:
216
- return None, {"error": "No file uploaded"}
217
-
218
- # Get file path
219
- file_path = file.name if hasattr(file, 'name') else file
220
-
221
- # Check if PDF
222
- if file_path.lower().endswith('.pdf'):
223
- # Pass PDF path directly to detector
224
- overlay, results = detector.detect(file_path)
225
- else:
226
- # Load image and pass to detector
227
- image = Image.open(file_path)
228
- overlay, results = detector.detect(image)
229
-
230
- return overlay, results # Return dict directly, not json.dumps
231
- except Exception as e:
232
- import traceback
233
- error_details = traceback.format_exc()
234
- print(f"Error: {error_details}")
235
- return None, {"error": str(e), "details": error_details}
236
-
237
-
238
- # Create Gradio interface
239
- demo = gr.Interface(
240
- fn=detect_forgery,
241
- inputs=gr.File(label="Upload Document (Image or PDF)", file_types=["image", ".pdf"]),
242
- outputs=[
243
- gr.Image(type="numpy", label="Detection Result"),
244
- gr.JSON(label="Detection Details")
245
- ],
246
- title="📄 Document Forgery Detector",
247
- description="""
248
- Upload a document image or PDF to detect and classify forgeries.
249
-
250
- **Supported Formats:**
251
- - 📷 Images: JPG, PNG, BMP, TIFF, WebP
252
- - 📄 PDF: First page will be analyzed
253
-
254
- **Supported Forgery Types:**
255
- - 🔴 Copy-Move: Duplicated regions within the document
256
- - 🟢 Splicing: Content from different sources
257
- - 🔵 Generation: AI-generated or synthesized content
258
-
259
- **Model Performance:**
260
- - Localization: 75% Dice Score
261
- - Classification: 92% Accuracy
262
- """,
263
- article="""
264
- ### About
265
- This model uses a hybrid deep learning approach:
266
- 1. **Localization**: MobileNetV3-Small + UNet-Lite (detects WHERE)
267
- 2. **Classification**: LightGBM with hybrid features (detects WHAT)
268
-
269
- Trained on DocTamper dataset (140K samples).
270
- """
271
- )
272
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
  if __name__ == "__main__":
275
  demo.launch()
276
-
 
1
  """
2
+ Document Forgery Detection Professional Gradio Dashboard
3
+ Hugging Face Spaces Deployment
 
4
  """
5
 
6
  import gradio as gr
 
8
  import cv2
9
  import numpy as np
10
  from PIL import Image
11
+ import plotly.graph_objects as go
12
  from pathlib import Path
13
  import sys
14
+ import json
15
 
16
+ # -------------------------------------------------
17
+ # PATH SETUP
18
+ # -------------------------------------------------
19
  sys.path.insert(0, str(Path(__file__).parent))
20
 
21
  from src.models import get_model
 
26
  from src.features.feature_extraction import get_feature_extractor
27
  from src.training.classifier import ForgeryClassifier
28
 
29
+ # -------------------------------------------------
30
+ # CONSTANTS
31
+ # -------------------------------------------------
32
+ CLASS_NAMES = {0: "Copy-Move", 1: "Splicing", 2: "Generation"}
33
  CLASS_COLORS = {
34
+ 0: (255, 0, 0),
35
+ 1: (0, 255, 0),
36
+ 2: (0, 0, 255),
37
  }
38
 
39
+ # -------------------------------------------------
40
+ # FORGERY DETECTOR (UNCHANGED CORE LOGIC)
41
+ # -------------------------------------------------
42
  class ForgeryDetector:
 
 
43
  def __init__(self):
44
  print("Loading models...")
45
+
46
+ self.config = get_config("config.yaml")
47
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48
+
 
 
49
  self.model = get_model(self.config).to(self.device)
50
+ checkpoint = torch.load("models/best_doctamper.pth", map_location=self.device)
51
+ self.model.load_state_dict(checkpoint["model_state_dict"])
52
  self.model.eval()
53
+
 
54
  self.classifier = ForgeryClassifier(self.config)
55
+ self.classifier.load("models/classifier")
56
+
57
+ self.preprocessor = DocumentPreprocessor(self.config, "doctamper")
58
+ self.augmentation = DatasetAwareAugmentation(self.config, "doctamper", is_training=False)
 
59
  self.mask_refiner = get_mask_refiner(self.config)
60
  self.region_extractor = get_region_extractor(self.config)
61
  self.feature_extractor = get_feature_extractor(self.config, is_text_document=True)
62
+
63
+ print("✓ Models loaded")
64
+
65
  def detect(self, image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  if isinstance(image, Image.Image):
67
  image = np.array(image)
68
+
69
+ if image.ndim == 2:
 
70
  image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
71
  elif image.shape[2] == 4:
72
  image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
73
+
74
+ original = image.copy()
75
+
 
76
  preprocessed, _ = self.preprocessor(image, None)
 
 
77
  augmented = self.augmentation(preprocessed, None)
78
+ image_tensor = augmented["image"].unsqueeze(0).to(self.device)
79
+
 
80
  with torch.no_grad():
81
  logits, decoder_features = self.model(image_tensor)
82
  prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]
83
+
84
+ binary = (prob_map > 0.5).astype(np.uint8)
85
+ refined = self.mask_refiner.refine(binary, original_size=original.shape[:2])
86
+ regions = self.region_extractor.extract(refined, prob_map, original)
87
+
 
 
 
 
88
  results = []
89
+ for r in regions:
 
90
  features = self.feature_extractor.extract(
91
+ preprocessed, r["region_mask"], [f.cpu() for f in decoder_features]
 
 
92
  )
93
+
 
94
  if features.ndim == 1:
95
  features = features.reshape(1, -1)
96
+
97
+ if features.shape[1] != 526:
98
+ pad = max(0, 526 - features.shape[1])
99
+ features = np.pad(features, ((0, 0), (0, pad)))[:, :526]
100
+
101
+ pred, conf = self.classifier.predict(features)
102
+ if conf[0] > 0.6:
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  results.append({
104
+ "bounding_box": r["bounding_box"],
105
+ "forgery_type": CLASS_NAMES[int(pred[0])],
106
+ "confidence": float(conf[0]),
 
107
  })
108
+
109
+ overlay = self._draw_overlay(original, results)
110
+
111
+ return overlay, {
112
+ "num_detections": len(results),
113
+ "detections": results,
 
 
 
 
 
 
114
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
+ def _draw_overlay(self, image, results):
117
+ out = image.copy()
118
+ for r in results:
119
+ x, y, w, h = r["bounding_box"]
120
+ fid = [k for k, v in CLASS_NAMES.items() if v == r["forgery_type"]][0]
121
+ color = CLASS_COLORS[fid]
122
+
123
+ cv2.rectangle(out, (x, y), (x + w, y + h), color, 2)
124
+ label = f"{r['forgery_type']} ({r['confidence']*100:.1f}%)"
125
+ cv2.putText(out, label, (x, y - 6),
126
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
127
+ return out
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
+ detector = ForgeryDetector()
131
+
132
+ # -------------------------------------------------
133
+ # METRIC VISUALS
134
+ # -------------------------------------------------
135
+ def gauge(value, title):
136
+ fig = go.Figure(go.Indicator(
137
+ mode="gauge+number",
138
+ value=value,
139
+ title={"text": title},
140
+ gauge={"axis": {"range": [0, 100]}, "bar": {"color": "#2563eb"}}
141
+ ))
142
+ fig.update_layout(height=240, margin=dict(t=40, b=20))
143
+ return fig
144
+
145
+ # -------------------------------------------------
146
+ # GRADIO CALLBACK
147
+ # -------------------------------------------------
148
+ def run_detection(file):
149
+ image = Image.open(file.name)
150
+ overlay, result = detector.detect(image)
151
+
152
+ avg_conf = (
153
+ sum(d["confidence"] for d in result["detections"]) / max(1, result["num_detections"])
154
+ ) * 100
155
+
156
+ return (
157
+ overlay,
158
+ result,
159
+ gauge(75, "Localization Dice (%)"),
160
+ gauge(92, "Classifier Accuracy (%)"),
161
+ gauge(avg_conf, "Avg Detection Confidence (%)"),
162
+ )
163
+
164
+ # -------------------------------------------------
165
+ # UI
166
+ # -------------------------------------------------
167
+ with gr.Blocks(theme=gr.themes.Soft(), title="Document Forgery Detection") as demo:
168
+
169
+ gr.Markdown("# 📄 Document Forgery Detection System")
170
+
171
+ with gr.Row():
172
+ file_input = gr.File(label="Upload Document (Image/PDF)")
173
+ detect_btn = gr.Button("Run Detection", variant="primary")
174
+
175
+ output_img = gr.Image(label="Forgery Localization Result", type="numpy")
176
+
177
+ with gr.Tabs():
178
+ with gr.Tab("📊 Metrics"):
179
+ with gr.Row():
180
+ dice_plot = gr.Plot()
181
+ acc_plot = gr.Plot()
182
+ conf_plot = gr.Plot()
183
+
184
+ with gr.Tab("🧾 Details"):
185
+ json_out = gr.JSON()
186
+
187
+ with gr.Tab("👥 Team"):
188
+ gr.Markdown("""
189
+ **Document Forgery Detection Project**
190
+
191
+ - Krishnanandhaa — Model & Training
192
+ - Teammate 1 — Feature Engineering
193
+ - Teammate 2 — Evaluation
194
+ - Teammate 3 — Deployment
195
+
196
+ *Collaborators are added via Hugging Face Space settings.*
197
+ """)
198
+
199
+ detect_btn.click(
200
+ run_detection,
201
+ inputs=file_input,
202
+ outputs=[output_img, json_out, dice_plot, acc_plot, conf_plot]
203
+ )
204
 
205
  if __name__ == "__main__":
206
  demo.launch()