""" Document Forgery Detection - Gradio Interface for Hugging Face Spaces This app provides a web interface for detecting and classifying document forgeries. """ import gradio as gr import torch import cv2 import numpy as np from PIL import Image import json from pathlib import Path import sys from typing import Dict, List, Tuple import plotly.graph_objects as go # Add src to path sys.path.insert(0, str(Path(__file__).parent)) from src.models import get_model from src.config import get_config from src.data.preprocessing import DocumentPreprocessor from src.data.augmentation import DatasetAwareAugmentation from src.features.region_extraction import get_mask_refiner, get_region_extractor from src.features.feature_extraction import get_feature_extractor from src.training.classifier import ForgeryClassifier # Class names CLASS_NAMES = {0: 'Copy-Move', 1: 'Splicing', 2: 'Text Substitution'} CLASS_COLORS = { 0: (217, 83, 79), # #d9534f - Muted red 1: (92, 184, 92), # #5cb85c - Muted green 2: (65, 105, 225) # #4169E1 - Royal blue } # Actual model performance metrics MODEL_METRICS = { 'segmentation': { 'dice': 0.6212, 'iou': 0.4506, 'precision': 0.7077, 'recall': 0.5536 }, 'classification': { 'overall_accuracy': 0.8897, 'per_class': { 'copy_move': 0.92, 'splicing': 0.85, 'generation': 0.90 } } } def create_gauge_chart(value: float, title: str, max_value: float = 1.0) -> go.Figure: """Create a subtle radial gauge chart""" fig = go.Figure(go.Indicator( mode="gauge+number", value=value * 100, domain={'x': [0, 1], 'y': [0, 1]}, title={'text': title, 'font': {'size': 14}}, number={'suffix': '%', 'font': {'size': 24}}, gauge={ 'axis': {'range': [0, 100], 'tickwidth': 1}, 'bar': {'color': '#4169E1', 'thickness': 0.7}, 'bgcolor': 'rgba(0,0,0,0)', 'borderwidth': 0, 'steps': [ {'range': [0, 50], 'color': 'rgba(217, 83, 79, 0.1)'}, {'range': [50, 75], 'color': 'rgba(240, 173, 78, 0.1)'}, {'range': [75, 100], 'color': 'rgba(92, 184, 92, 0.1)'} ] } )) fig.update_layout( paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)', height=200, margin=dict(l=20, r=20, t=40, b=20) ) return fig def create_detection_metrics_gauge(avg_confidence: float, iou: float, precision: float, recall: float, num_detections: int) -> go.Figure: """Create a high-fidelity radial bar chart (concentric rings)""" # Calculate percentages (0-100) metrics = [ {'name': 'Confidence', 'val': avg_confidence * 100 if num_detections > 0 else 0, 'color': '#4169E1', 'base': 80}, {'name': 'Precision', 'val': precision * 100, 'color': '#5cb85c', 'base': 60}, {'name': 'Recall', 'val': recall * 100, 'color': '#f0ad4e', 'base': 40}, {'name': 'IoU', 'val': iou * 100, 'color': '#d9534f', 'base': 20} ] fig = go.Figure() for m in metrics: # 1. Add background track (faint gray ring) fig.add_trace(go.Barpolar( r=[15], theta=[180], width=[360], base=m['base'], marker_color='rgba(128,128,128,0.1)', hoverinfo='none', showlegend=False )) # 2. Add the actual metric bar (the colored arc) # 100% = 360 degrees angle_width = m['val'] * 3.6 fig.add_trace(go.Barpolar( r=[15], theta=[angle_width / 2], width=[angle_width], base=m['base'], name=f"{m['name']}: {m['val']:.1f}%", marker_color=m['color'], marker_line_width=0, hoverinfo='name' )) fig.update_layout( polar=dict( hole=0.1, radialaxis=dict(visible=False, range=[0, 100]), angularaxis=dict( rotation=90, # Start at 12 o'clock direction='clockwise', # Go clockwise gridcolor='rgba(128,128,128,0.2)', tickmode='array', tickvals=[0, 90, 180, 270], ticktext=['0%', '25%', '50%', '75%'], showticklabels=True, tickfont=dict(size=12, color='#888') ), bgcolor='rgba(0,0,0,0)' ), showlegend=True, legend=dict( orientation="v", yanchor="middle", y=0.5, xanchor="left", x=1.1, font=dict(size=14, color='white'), itemwidth=30 ), paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)', height=450, margin=dict(l=60, r=180, t=40, b=40) ) return fig class ForgeryDetector: """Main forgery detection pipeline""" def __init__(self): print("Loading models...") # Load config self.config = get_config('config.yaml') self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Load segmentation model self.model = get_model(self.config).to(self.device) checkpoint = torch.load('models/best_doctamper.pth', map_location=self.device) self.model.load_state_dict(checkpoint['model_state_dict']) self.model.eval() # Load classifier self.classifier = ForgeryClassifier(self.config) self.classifier.load('models/classifier') # Initialize components self.preprocessor = DocumentPreprocessor(self.config, 'doctamper') self.augmentation = DatasetAwareAugmentation(self.config, 'doctamper', is_training=False) self.mask_refiner = get_mask_refiner(self.config) self.region_extractor = get_region_extractor(self.config) self.feature_extractor = get_feature_extractor(self.config, is_text_document=True) print("✓ Models loaded successfully!") def detect(self, image): """ Detect forgeries in document image or PDF Returns: original_image: Original uploaded image overlay_image: Image with detection overlay gauge_dice: Dice score gauge gauge_accuracy: Accuracy gauge results_html: Detection results as HTML """ # Handle file path input (from gr.Image with type="filepath") if isinstance(image, str): if image.lower().endswith('.pdf'): # Handle PDF files import fitz # PyMuPDF pdf_document = fitz.open(image) page = pdf_document[0] pix = page.get_pixmap(matrix=fitz.Matrix(2, 2)) image = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n) if pix.n == 4: image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) pdf_document.close() else: # Load image file image = Image.open(image) image = np.array(image) # Convert PIL to numpy if isinstance(image, Image.Image): image = np.array(image) # Convert to RGB if len(image.shape) == 2: image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) elif image.shape[2] == 4: image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) original_image = image.copy() # Preprocess preprocessed, _ = self.preprocessor(image, None) # Augment augmented = self.augmentation(preprocessed, None) image_tensor = augmented['image'].unsqueeze(0).to(self.device) # Run localization with torch.no_grad(): logits, decoder_features = self.model(image_tensor) prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0] # Resize probability map to match original image size to avoid index mismatch errors prob_map_resized = cv2.resize( prob_map, (original_image.shape[1], original_image.shape[0]), interpolation=cv2.INTER_LINEAR ) # Refine mask binary_mask = (prob_map_resized > 0.5).astype(np.uint8) refined_mask = self.mask_refiner.refine(prob_map_resized, original_size=original_image.shape[:2]) # Extract regions regions = self.region_extractor.extract(refined_mask, prob_map_resized, original_image) # Classify regions results = [] for region in regions: # Extract features features = self.feature_extractor.extract( preprocessed, region['region_mask'], [f.cpu() for f in decoder_features] ) # Reshape features to 2D array if features.ndim == 1: features = features.reshape(1, -1) # Pad/truncate features to match classifier expected_features = 526 current_features = features.shape[1] if current_features < expected_features: padding = np.zeros((features.shape[0], expected_features - current_features)) features = np.hstack([features, padding]) elif current_features > expected_features: features = features[:, :expected_features] # Classify predictions, confidences = self.classifier.predict(features) forgery_type = int(predictions[0]) confidence = float(confidences[0]) if confidence > 0.6: results.append({ 'region_id': region['region_id'], 'bounding_box': region['bounding_box'], 'forgery_type': CLASS_NAMES[forgery_type], 'confidence': confidence }) # Create visualization overlay = self._create_overlay(original_image, results) # Calculate actual detection metrics from probability map and mask num_detections = len(results) avg_confidence = sum(r['confidence'] for r in results) / num_detections if num_detections > 0 else 0 # Calculate IoU, Precision, Recall from the refined mask and probability map if num_detections > 0: # Use resized prob_map to match refined_mask dimensions high_conf_mask = (prob_map_resized > 0.7).astype(np.uint8) predicted_positive = np.sum(refined_mask > 0) high_conf_positive = np.sum(high_conf_mask > 0) # Calculate intersection and union intersection = np.sum((refined_mask > 0) & (high_conf_mask > 0)) union = np.sum((refined_mask > 0) | (high_conf_mask > 0)) # Calculate metrics iou = intersection / union if union > 0 else 0 precision = intersection / predicted_positive if predicted_positive > 0 else 0 recall = intersection / high_conf_positive if high_conf_positive > 0 else 0 else: # No detections - use zeros iou = 0 precision = 0 recall = 0 # Create detection metrics gauge with actual values metrics_gauge = create_detection_metrics_gauge(avg_confidence, iou, precision, recall, num_detections) # Create HTML response results_html = self._create_html_report(results) return overlay, metrics_gauge, results_html def _create_overlay(self, image, results): """Create overlay visualization""" overlay = image.copy() for result in results: bbox = result['bounding_box'] x, y, w, h = bbox forgery_type = result['forgery_type'] confidence = result['confidence'] # Get color forgery_id = [k for k, v in CLASS_NAMES.items() if v == forgery_type][0] color = CLASS_COLORS[forgery_id] # Draw rectangle cv2.rectangle(overlay, (x, y), (x+w, y+h), color, 2) # Draw label label = f"{forgery_type}: {confidence:.1%}" font = cv2.FONT_HERSHEY_SIMPLEX font_scale = 0.5 thickness = 1 (label_w, label_h), baseline = cv2.getTextSize(label, font, font_scale, thickness) cv2.rectangle(overlay, (x, y-label_h-8), (x+label_w+4, y), color, -1) cv2.putText(overlay, label, (x+2, y-4), font, font_scale, (255, 255, 255), thickness) return overlay def _create_html_report(self, results): """Create HTML report with detection results""" num_detections = len(results) if num_detections == 0: return """
No forgery detected.
The document appears to be authentic.
""" # Calculate statistics avg_confidence = sum(r['confidence'] for r in results) / num_detections type_counts = {} for r in results: ft = r['forgery_type'] type_counts[ft] = type_counts.get(ft, 0) + 1 html = f"""
⚠️ Forgery Detected

Summary:
• Regions detected: {num_detections}
• Average confidence: {avg_confidence*100:.1f}%

Detections:
""" for i, result in enumerate(results, 1): forgery_type = result['forgery_type'] confidence = result['confidence'] bbox = result['bounding_box'] forgery_id = [k for k, v in CLASS_NAMES.items() if v == forgery_type][0] color_rgb = CLASS_COLORS[forgery_id] color_hex = f"#{color_rgb[0]:02x}{color_rgb[1]:02x}{color_rgb[2]:02x}" html += f"""
Region {i}: {forgery_type} ({confidence*100:.1f}%)
Location: ({bbox[0]}, {bbox[1]}) | Size: {bbox[2]}×{bbox[3]}px
""" html += """
""" return html # Initialize detector detector = ForgeryDetector() def detect_forgery(file): """Gradio interface function - handles image and PDF uploads""" try: if file is None: empty_html = "
No file uploaded.
" return None, None, empty_html # Detect forgeries overlay, metrics_gauge, results_html = detector.detect(file) return overlay, metrics_gauge, results_html except Exception as e: import traceback error_details = traceback.format_exc() print(f"Error: {error_details}") error_html = f"""
Error: {str(e)}
""" return None, None, error_html # Custom CSS - subtle styling custom_css = """ .predict-btn { background-color: #4169E1 !important; color: white !important; } .clear-btn { background-color: #6A89A7 !important; color: white !important; } """ # Create Gradio interface with gr.Blocks(css=custom_css) as demo: gr.Markdown( """ # 📄 Document Forgery Detection Upload a document image or PDF to detect and classify forgeries using deep learning. The system combines MobileNetV3-UNet for precise localization and LightGBM for classification, identifying Copy-Move, Splicing, and Text Substitution manipulations with detailed confidence scores and bounding boxes. Trained on 140K samples for robust performance. """ ) gr.Markdown("---") with gr.Row(): with gr.Column(scale=1): gr.Markdown("### Upload Document") input_file = gr.File( label="📤 Upload Image or PDF", file_types=["image", ".pdf"], type="filepath" ) with gr.Row(): clear_btn = gr.Button("🧹 Clear", elem_classes="clear-btn") analyze_btn = gr.Button("🔍 Analyze", elem_classes="predict-btn") with gr.Column(scale=1): gr.Markdown("### Information") gr.HTML( """

Supported formats:

Forgery types:

""" ) with gr.Column(scale=2): gr.Markdown("### Detection Results") output_image = gr.Image(label="Detected Forgeries", type="numpy") gr.Markdown("---") with gr.Row(): with gr.Column(scale=1): gr.Markdown("### Analysis Report") output_html = gr.HTML( value="No analysis yet. Upload a document and click Analyze." ) with gr.Column(scale=1): gr.Markdown("### Detection Metrics") metrics_gauge = gr.Plot(label="Concentric Metrics Gauge") gr.Markdown("---") with gr.Row(): with gr.Column(scale=1): gr.Markdown("### Model Architecture") gr.HTML( """

Localization: MobileNetV3-Small + UNet

Dice: 62.12% | IoU: 45.06% | Precision: 70.77% | Recall: 55.36%

Classification: LightGBM with 526 features

Train Accuracy: 90.53% | Val Accuracy: 88.97%

Training: 140K samples from DocTamper dataset

""" ) with gr.Column(scale=1): gr.Markdown("### Model Performance") gr.HTML( f"""

Trained Model Performance:

Segmentation Dice: {MODEL_METRICS['segmentation']['dice']*100:.2f}%
Classification Accuracy: {MODEL_METRICS['classification']['overall_accuracy']*100:.2f}%
""" ) # Event handlers analyze_btn.click( fn=detect_forgery, inputs=[input_file], outputs=[output_image, metrics_gauge, output_html] ) clear_btn.click( fn=lambda: (None, None, None, "No analysis yet. Upload a document and click Analyze."), inputs=None, outputs=[input_file, output_image, metrics_gauge, output_html] ) if __name__ == "__main__": demo.launch()