| """ |
| Document Forgery Detection - Gradio Interface for Hugging Face Spaces |
| |
| This app provides a web interface for detecting and classifying document forgeries. |
| """ |
|
|
| import gradio as gr |
| import torch |
| import cv2 |
| import numpy as np |
| from PIL import Image |
| import json |
| from pathlib import Path |
| import sys |
| from typing import Dict, List, Tuple |
| import plotly.graph_objects as go |
|
|
| |
| sys.path.insert(0, str(Path(__file__).parent)) |
|
|
| from src.models import get_model |
| from src.config import get_config |
| from src.data.preprocessing import DocumentPreprocessor |
| from src.data.augmentation import DatasetAwareAugmentation |
| from src.features.region_extraction import get_mask_refiner, get_region_extractor |
| from src.features.feature_extraction import get_feature_extractor |
| from src.training.classifier import ForgeryClassifier |
|
|
| |
# Classifier output id -> human-readable forgery type shown in the overlay
# labels and in the HTML report.
CLASS_NAMES = {0: 'Copy-Move', 1: 'Splicing', 2: 'Text Substitution'}
# Display color per class id, as (R, G, B) components: the HTML report formats
# the three components in this order as #rrggbb, and the same tuples are used
# to draw the bounding boxes on the overlay image.
CLASS_COLORS = {
    0: (217, 83, 79),    # red
    1: (92, 184, 92),    # green
    2: (65, 105, 225)    # blue
}
|
|
| |
# Static evaluation figures (fractions in [0, 1]) displayed in the
# "Model Performance" panel of the UI. They are constants; nothing in this
# file recomputes them at runtime.
MODEL_METRICS = {
    'segmentation': {
        'dice': 0.6212,
        'iou': 0.4506,
        'precision': 0.7077,
        'recall': 0.5536
    },
    'classification': {
        'overall_accuracy': 0.8897,
        # NOTE(review): the third key is 'generation' while the third class in
        # CLASS_NAMES is 'Text Substitution' - confirm which name is intended.
        # The per-class values are not read anywhere in this file.
        'per_class': {
            'copy_move': 0.92,
            'splicing': 0.85,
            'generation': 0.90
        }
    }
}
|
|
|
|
def create_gauge_chart(value: float, title: str, max_value: float = 1.0) -> go.Figure:
    """Create a subtle radial gauge chart showing ``value`` as a percentage.

    Args:
        value: Reading to display, on the same scale as ``max_value``.
        title: Caption rendered above the gauge.
        max_value: Reading that corresponds to a full (100%) gauge.

    Returns:
        A Plotly figure containing a single ``Indicator`` trace.

    Raises:
        ValueError: If ``max_value`` is not positive.
    """
    if max_value <= 0:
        raise ValueError(f"max_value must be positive, got {max_value!r}")

    # max_value used to be accepted but ignored. Scaling by it keeps the
    # default behaviour (value * 100) while making the parameter meaningful.
    percent = value / max_value * 100

    fig = go.Figure(go.Indicator(
        mode="gauge+number",
        value=percent,
        domain={'x': [0, 1], 'y': [0, 1]},
        title={'text': title, 'font': {'size': 14}},
        number={'suffix': '%', 'font': {'size': 24}},
        gauge={
            'axis': {'range': [0, 100], 'tickwidth': 1},
            'bar': {'color': '#4169E1', 'thickness': 0.7},
            'bgcolor': 'rgba(0,0,0,0)',
            'borderwidth': 0,
            # Faint red / amber / green bands behind the bar.
            'steps': [
                {'range': [0, 50], 'color': 'rgba(217, 83, 79, 0.1)'},
                {'range': [50, 75], 'color': 'rgba(240, 173, 78, 0.1)'},
                {'range': [75, 100], 'color': 'rgba(92, 184, 92, 0.1)'}
            ]
        }
    ))

    # Transparent backgrounds so the chart blends into the host page theme.
    fig.update_layout(
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        height=200,
        margin=dict(l=20, r=20, t=40, b=20)
    )

    return fig
|
|
|
|
def create_detection_metrics_gauge(avg_confidence: float, iou: float, precision: float, recall: float, num_detections: int) -> go.Figure:
    """Render four detection metrics as concentric rings on a polar bar chart.

    Every metric gets its own ring: a faint full-circle track plus a colored
    arc whose sweep is proportional to the metric (100% == 360 degrees).
    Confidence is shown as 0 when there are no detections.

    Args:
        avg_confidence: Mean classifier confidence as a fraction.
        iou: IoU as a fraction.
        precision: Precision as a fraction.
        recall: Recall as a fraction.
        num_detections: Number of reported regions.

    Returns:
        A Plotly figure with two ``Barpolar`` traces per metric.
    """
    confidence_pct = avg_confidence * 100 if num_detections > 0 else 0

    # (legend name, percentage, arc color, inner radius), outermost ring first.
    rings = (
        ('Confidence', confidence_pct, '#4169E1', 80),
        ('Precision', precision * 100, '#5cb85c', 60),
        ('Recall', recall * 100, '#f0ad4e', 40),
        ('IoU', iou * 100, '#d9534f', 20),
    )

    fig = go.Figure()

    for ring_name, percent, arc_color, inner_radius in rings:
        # Background track: one bar centered at 180 degrees spanning the circle.
        track = go.Barpolar(
            r=[15],
            theta=[180],
            width=[360],
            base=inner_radius,
            marker_color='rgba(128,128,128,0.1)',
            hoverinfo='none',
            showlegend=False
        )
        fig.add_trace(track)

        # Value arc: starts at 0 degrees, so it is centered at half its sweep.
        sweep = percent * 3.6
        arc = go.Barpolar(
            r=[15],
            theta=[sweep / 2],
            width=[sweep],
            base=inner_radius,
            name=f"{ring_name}: {percent:.1f}%",
            marker_color=arc_color,
            marker_line_width=0,
            hoverinfo='name'
        )
        fig.add_trace(arc)

    # Angular axis starts at the top and runs clockwise; ticks mark quarters.
    angular_axis = dict(
        rotation=90,
        direction='clockwise',
        gridcolor='rgba(128,128,128,0.2)',
        tickmode='array',
        tickvals=[0, 90, 180, 270],
        ticktext=['0%', '25%', '50%', '75%'],
        showticklabels=True,
        tickfont=dict(size=12, color='#888')
    )
    polar_layout = dict(
        hole=0.1,
        radialaxis=dict(visible=False, range=[0, 100]),
        angularaxis=angular_axis,
        bgcolor='rgba(0,0,0,0)'
    )
    # Legend sits to the right of the rings and carries the numeric values.
    legend_layout = dict(
        orientation="v",
        yanchor="middle",
        y=0.5,
        xanchor="left",
        x=1.1,
        font=dict(size=14, color='white'),
        itemwidth=30
    )

    fig.update_layout(
        polar=polar_layout,
        showlegend=True,
        legend=legend_layout,
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        height=450,
        margin=dict(l=60, r=180, t=40, b=40)
    )

    return fig
|
|
|
|
| class ForgeryDetector: |
| """Main forgery detection pipeline""" |
| |
| def __init__(self): |
| print("Loading models...") |
| |
| |
| self.config = get_config('config.yaml') |
| self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| |
| |
| self.model = get_model(self.config).to(self.device) |
| checkpoint = torch.load('models/best_doctamper.pth', map_location=self.device) |
| self.model.load_state_dict(checkpoint['model_state_dict']) |
| self.model.eval() |
| |
| |
| self.classifier = ForgeryClassifier(self.config) |
| self.classifier.load('models/classifier') |
| |
| |
| self.preprocessor = DocumentPreprocessor(self.config, 'doctamper') |
| self.augmentation = DatasetAwareAugmentation(self.config, 'doctamper', is_training=False) |
| self.mask_refiner = get_mask_refiner(self.config) |
| self.region_extractor = get_region_extractor(self.config) |
| self.feature_extractor = get_feature_extractor(self.config, is_text_document=True) |
| |
| print("✓ Models loaded successfully!") |
| |
| def detect(self, image): |
| """ |
| Detect forgeries in document image or PDF |
| |
| Returns: |
| original_image: Original uploaded image |
| overlay_image: Image with detection overlay |
| gauge_dice: Dice score gauge |
| gauge_accuracy: Accuracy gauge |
| results_html: Detection results as HTML |
| """ |
| |
| if isinstance(image, str): |
| if image.lower().endswith('.pdf'): |
| |
| import fitz |
| pdf_document = fitz.open(image) |
| page = pdf_document[0] |
| pix = page.get_pixmap(matrix=fitz.Matrix(2, 2)) |
| image = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n) |
| if pix.n == 4: |
| image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) |
| pdf_document.close() |
| else: |
| |
| image = Image.open(image) |
| image = np.array(image) |
| |
| |
| if isinstance(image, Image.Image): |
| image = np.array(image) |
| |
| |
| if len(image.shape) == 2: |
| image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) |
| elif image.shape[2] == 4: |
| image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) |
| |
| original_image = image.copy() |
| |
| |
| preprocessed, _ = self.preprocessor(image, None) |
| |
| |
| augmented = self.augmentation(preprocessed, None) |
| image_tensor = augmented['image'].unsqueeze(0).to(self.device) |
| |
| |
| with torch.no_grad(): |
| logits, decoder_features = self.model(image_tensor) |
| prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0] |
| |
| |
| binary_mask = (prob_map > 0.5).astype(np.uint8) |
| refined_mask = self.mask_refiner.refine(binary_mask, original_size=original_image.shape[:2]) |
| |
| |
| regions = self.region_extractor.extract(refined_mask, prob_map, original_image) |
| |
| |
| results = [] |
| for region in regions: |
| |
| features = self.feature_extractor.extract( |
| preprocessed, |
| region['region_mask'], |
| [f.cpu() for f in decoder_features] |
| ) |
| |
| |
| if features.ndim == 1: |
| features = features.reshape(1, -1) |
| |
| |
| expected_features = 526 |
| current_features = features.shape[1] |
| if current_features < expected_features: |
| padding = np.zeros((features.shape[0], expected_features - current_features)) |
| features = np.hstack([features, padding]) |
| elif current_features > expected_features: |
| features = features[:, :expected_features] |
| |
| |
| predictions, confidences = self.classifier.predict(features) |
| forgery_type = int(predictions[0]) |
| confidence = float(confidences[0]) |
| |
| if confidence > 0.6: |
| results.append({ |
| 'region_id': region['region_id'], |
| 'bounding_box': region['bounding_box'], |
| 'forgery_type': CLASS_NAMES[forgery_type], |
| 'confidence': confidence |
| }) |
| |
| |
| overlay = self._create_overlay(original_image, results) |
| |
| |
| num_detections = len(results) |
| avg_confidence = sum(r['confidence'] for r in results) / num_detections if num_detections > 0 else 0 |
| |
| |
| if num_detections > 0: |
| |
| high_conf_mask = (prob_map > 0.7).astype(np.uint8) |
| predicted_positive = np.sum(refined_mask > 0) |
| high_conf_positive = np.sum(high_conf_mask > 0) |
| |
| |
| intersection = np.sum((refined_mask > 0) & (high_conf_mask > 0)) |
| union = np.sum((refined_mask > 0) | (high_conf_mask > 0)) |
| |
| |
| iou = intersection / union if union > 0 else 0 |
| precision = intersection / predicted_positive if predicted_positive > 0 else 0 |
| recall = intersection / high_conf_positive if high_conf_positive > 0 else 0 |
| else: |
| |
| iou = 0 |
| precision = 0 |
| recall = 0 |
| |
| |
| metrics_gauge = create_detection_metrics_gauge(avg_confidence, iou, precision, recall, num_detections) |
| |
| |
| results_html = self._create_html_report(results) |
| |
| return overlay, metrics_gauge, results_html |
| |
| def _create_overlay(self, image, results): |
| """Create overlay visualization""" |
| overlay = image.copy() |
| |
| for result in results: |
| bbox = result['bounding_box'] |
| x, y, w, h = bbox |
| |
| forgery_type = result['forgery_type'] |
| confidence = result['confidence'] |
| |
| |
| forgery_id = [k for k, v in CLASS_NAMES.items() if v == forgery_type][0] |
| color = CLASS_COLORS[forgery_id] |
| |
| |
| cv2.rectangle(overlay, (x, y), (x+w, y+h), color, 2) |
| |
| |
| label = f"{forgery_type}: {confidence:.1%}" |
| font = cv2.FONT_HERSHEY_SIMPLEX |
| font_scale = 0.5 |
| thickness = 1 |
| (label_w, label_h), baseline = cv2.getTextSize(label, font, font_scale, thickness) |
| |
| cv2.rectangle(overlay, (x, y-label_h-8), (x+label_w+4, y), color, -1) |
| cv2.putText(overlay, label, (x+2, y-4), font, font_scale, (255, 255, 255), thickness) |
| |
| return overlay |
| |
| def _create_html_report(self, results): |
| """Create HTML report with detection results""" |
| num_detections = len(results) |
| |
| if num_detections == 0: |
| return """ |
| <div style='padding:12px; border:1px solid #5cb85c; border-radius:8px;'> |
| ✓ <b>No forgery detected.</b><br> |
| The document appears to be authentic. |
| </div> |
| """ |
| |
| |
| avg_confidence = sum(r['confidence'] for r in results) / num_detections |
| type_counts = {} |
| for r in results: |
| ft = r['forgery_type'] |
| type_counts[ft] = type_counts.get(ft, 0) + 1 |
| |
| html = f""" |
| <div style='padding:12px; border:1px solid #d9534f; border-radius:8px;'> |
| <b>⚠️ Forgery Detected</b><br><br> |
| |
| <b>Summary:</b><br> |
| • Regions detected: {num_detections}<br> |
| • Average confidence: {avg_confidence*100:.1f}%<br><br> |
| |
| <b>Detections:</b><br> |
| """ |
| |
| for i, result in enumerate(results, 1): |
| forgery_type = result['forgery_type'] |
| confidence = result['confidence'] |
| bbox = result['bounding_box'] |
| |
| forgery_id = [k for k, v in CLASS_NAMES.items() if v == forgery_type][0] |
| color_rgb = CLASS_COLORS[forgery_id] |
| color_hex = f"#{color_rgb[0]:02x}{color_rgb[1]:02x}{color_rgb[2]:02x}" |
| |
| html += f""" |
| <div style='margin:8px 0; padding:8px; border-left:3px solid {color_hex}; background:rgba(0,0,0,0.02);'> |
| <b>Region {i}:</b> {forgery_type} ({confidence*100:.1f}%)<br> |
| <small>Location: ({bbox[0]}, {bbox[1]}) | Size: {bbox[2]}×{bbox[3]}px</small> |
| </div> |
| """ |
| |
| html += """ |
| </div> |
| """ |
| |
| return html |
|
|
|
|
| |
# Single shared pipeline instance, created at import time so the models are
# loaded once and reused by every call to detect_forgery().
detector = ForgeryDetector()
|
|
|
|
def detect_forgery(file):
    """Gradio interface function: run the detector on an uploaded file.

    Args:
        file: Value supplied by the upload component, or ``None`` when
            nothing has been uploaded.

    Returns:
        ``(overlay, metrics_gauge, results_html)``. When there is no file, or
        when detection raises, the first two items are ``None`` and the third
        is an HTML message describing the problem.
    """
    import traceback
    from html import escape

    if file is None:
        empty_html = "<div style='padding:12px; border:1px solid #d9534f; border-radius:8px;'>❌ <b>No file uploaded.</b></div>"
        return None, None, empty_html

    try:
        overlay, metrics_gauge, results_html = detector.detect(file)
        return overlay, metrics_gauge, results_html

    except Exception as e:
        # Top-level UI boundary: log the full traceback and show the user a
        # short message rather than letting the request fail.
        error_details = traceback.format_exc()
        print(f"Error: {error_details}")
        # Exception text can contain markup characters (file names, reprs);
        # escape it so it is rendered as text, not interpreted as HTML.
        error_html = f"""
        <div style='padding:12px; border:1px solid #d9534f; border-radius:8px;'>
        ❌ <b>Error:</b> {escape(str(e))}
        </div>
        """
        return None, None, error_html
|
|
|
|
| |
# Button colors, attached through the elem_classes given to the two buttons.
custom_css = """
.predict-btn {
    background-color: #4169E1 !important;
    color: white !important;
}
.clear-btn {
    background-color: #6A89A7 !important;
    color: white !important;
}
"""

# Legend colors are derived from CLASS_COLORS so the "Forgery types" list
# always matches the box colors on the overlay and the borders in the report.
# (The legend previously showed Splicing in blue and Text Substitution in
# green, the reverse of what is actually drawn.)
_legend_colors = {
    CLASS_NAMES[class_id]: "#{:02x}{:02x}{:02x}".format(*rgb)
    for class_id, rgb in CLASS_COLORS.items()
}

# Shorthand for the figures repeated in the architecture panel, so the text
# there cannot drift from MODEL_METRICS.
_seg_metrics = MODEL_METRICS['segmentation']
_cls_metrics = MODEL_METRICS['classification']

with gr.Blocks(css=custom_css) as demo:

    # Markdown text is kept flush-left so it can never be read as a code block.
    gr.Markdown(
        """
# 📄 Document Forgery Detection
Upload a document image or PDF to detect and classify forgeries using deep learning. The system combines MobileNetV3-UNet for precise localization and LightGBM for classification, identifying Copy-Move, Splicing, and Text Substitution manipulations with detailed confidence scores and bounding boxes. Trained on 140K samples for robust performance.
"""
    )
    gr.Markdown("---")

    # Row 1: upload + controls | static information | detection overlay.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Upload Document")
            # NOTE(review): the label advertises PDF and detect() has a PDF
            # branch, but an Image component may reject non-image uploads.
            # Confirm, and consider gr.File if PDFs must be accepted.
            input_file = gr.Image(
                label="Document (Image or PDF)",
                type="filepath",
                sources=["upload"]
            )

            with gr.Row():
                clear_btn = gr.Button("🧹 Clear", elem_classes="clear-btn")
                analyze_btn = gr.Button("🔍 Analyze", elem_classes="predict-btn")

        with gr.Column(scale=1):
            gr.Markdown("### Information")
            gr.HTML(
                f"""
                <div style='padding:16px; border:1px solid #ccc; border-radius:8px; background:var(--background-fill-primary);'>
                <p style='margin-top:0;'><b>Supported formats:</b></p>
                <ul style='margin:8px 0; padding-left:20px;'>
                <li>Images: JPG, PNG, BMP, TIFF, WebP</li>
                <li>PDF: First page analyzed</li>
                </ul>

                <p style='margin-bottom:4px;'><b>Forgery types:</b></p>
                <ul style='margin:8px 0; padding-left:20px;'>
                <li style='color:{_legend_colors['Copy-Move']};'><b>Copy-Move:</b> <span style='color:inherit;'>Duplicated regions</span></li>
                <li style='color:{_legend_colors['Splicing']};'><b>Splicing:</b> <span style='color:inherit;'>Mixed sources</span></li>
                <li style='color:{_legend_colors['Text Substitution']};'><b>Text Substitution:</b> <span style='color:inherit;'>Modified text</span></li>
                </ul>
                </div>
                """
            )

        with gr.Column(scale=2):
            gr.Markdown("### Detection Results")
            output_image = gr.Image(label="Detected Forgeries", type="numpy")

    gr.Markdown("---")

    # Row 2: textual report | metrics gauge.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Analysis Report")
            output_html = gr.HTML(
                value="<i>No analysis yet. Upload a document and click Analyze.</i>"
            )

        with gr.Column(scale=1):
            gr.Markdown("### Detection Metrics")
            metrics_gauge = gr.Plot(label="Concentric Metrics Gauge")

    gr.Markdown("---")

    # Row 3: static model description | static performance bars.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Model Architecture")
            gr.HTML(
                f"""
                <div style='padding:12px; border:1px solid #444; border-radius:10px; background:var(--background-fill-primary);'>
                <p style="margin:0 0 0px 0; font-size:1.05em;"><b>Localization:</b> MobileNetV3-Small + UNet</p>
                <p style='margin:0 20px 5px 0; margin-left:0.5cm; font-size:0.9em; opacity:0.85;'>Dice: {_seg_metrics['dice']*100:.2f}% | IoU: {_seg_metrics['iou']*100:.2f}% | Precision: {_seg_metrics['precision']*100:.2f}% | Recall: {_seg_metrics['recall']*100:.2f}%</p>

                <p style="margin:0 0 0 0; font-size:1.05em;"><b>Classification:</b> LightGBM with 526 features</p>
                <p style="margin:0 20px 0 0; margin-left:0.5cm; font-size:0.9em; opacity:0.85;">Train Accuracy: 90.53% | Val Accuracy: {_cls_metrics['overall_accuracy']*100:.2f}%</p>

                <p style='margin-top:5px; margin-bottom:0; font-size:1.05em;'><b>Training:</b> 140K samples from DocTamper dataset</p>
                </div>
                """
            )

        with gr.Column(scale=1):
            gr.Markdown("### Model Performance")
            gr.HTML(
                f"""
                <div style='padding:12px; border:1px solid #444; border-radius:10px; background:var(--background-fill-primary);'>
                <p style='margin-top:0; margin-bottom:12px;'><b>Trained Model Performance:</b></p>

                <b>Segmentation Dice: {MODEL_METRICS['segmentation']['dice']*100:.2f}%</b>
                <div style='width:100%; background:#333; height:12px; border-radius:6px; margin-bottom:12px;'>
                <div style='width:{MODEL_METRICS['segmentation']['dice']*100:.1f}%; background:#4169E1; height:12px; border-radius:6px;'></div>
                </div>

                <b>Classification Accuracy: {MODEL_METRICS['classification']['overall_accuracy']*100:.2f}%</b>
                <div style='width:100%; background:#333; height:12px; border-radius:6px;'>
                <div style='width:{MODEL_METRICS['classification']['overall_accuracy']*100:.1f}%; background:#5cb85c; height:12px; border-radius:6px;'></div>
                </div>
                </div>
                """
            )

    # Analyze: run detection and fill the overlay, gauge and report.
    analyze_btn.click(
        fn=detect_forgery,
        inputs=[input_file],
        outputs=[output_image, metrics_gauge, output_html]
    )

    # Clear: reset the upload and every output to its initial state.
    clear_btn.click(
        fn=lambda: (None, None, None, "<i>No analysis yet. Upload a document and click Analyze.</i>"),
        inputs=None,
        outputs=[input_file, output_image, metrics_gauge, output_html]
    )
|
|
|
|
# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()
|
|