"""
Document Forgery Detection - Gradio Interface for Hugging Face Spaces

This app provides a web interface for detecting and classifying document forgeries.
"""

import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image
import json
from pathlib import Path
import sys
import traceback
from html import escape
from typing import Dict, List, Tuple
import plotly.graph_objects as go

# Add src to path
sys.path.insert(0, str(Path(__file__).parent))

from src.models import get_model
from src.config import get_config
from src.data.preprocessing import DocumentPreprocessor
from src.data.augmentation import DatasetAwareAugmentation
from src.features.region_extraction import get_mask_refiner, get_region_extractor
from src.features.feature_extraction import get_feature_extractor
from src.training.classifier import ForgeryClassifier

# Class names
CLASS_NAMES = {0: 'Copy-Move', 1: 'Splicing', 2: 'Text Substitution'}
# Reverse lookup: display name -> class id (used to pick a region's color).
CLASS_IDS = {name: class_id for class_id, name in CLASS_NAMES.items()}
# (R, G, B) tuples: drawn directly onto RGB arrays and converted to hex for the report.
CLASS_COLORS = {
    0: (217, 83, 79),   # #d9534f - Muted red
    1: (92, 184, 92),   # #5cb85c - Muted green
    2: (65, 105, 225),  # #4169E1 - Royal blue
}

# Actual model performance metrics (static values shown in the UI gauges).
# NOTE(review): the per-class key 'generation' does not match the 'Text Substitution'
# class name above; it is not displayed anywhere in this file.
MODEL_METRICS = {
    'segmentation': {
        'dice': 0.6212,
        'iou': 0.4506,
        'precision': 0.7077,
        'recall': 0.5536,
    },
    'classification': {
        'overall_accuracy': 0.8897,
        'per_class': {
            'copy_move': 0.92,
            'splicing': 0.85,
            'generation': 0.90,
        },
    },
}


def create_gauge_chart(value: float, title: str, max_value: float = 1.0) -> go.Figure:
    """Create a subtle radial gauge chart.

    Args:
        value: Metric value, on the scale ``[0, max_value]``.
        title: Title shown above the gauge.
        max_value: Value that corresponds to a full (100%) gauge.

    Returns:
        A Plotly figure displaying ``value / max_value`` as a percentage.

    Raises:
        ValueError: If ``max_value`` is not positive.
    """
    if max_value <= 0:
        raise ValueError("max_value must be positive")
    percent = value / max_value * 100

    fig = go.Figure(go.Indicator(
        mode="gauge+number",
        value=percent,
        domain={'x': [0, 1], 'y': [0, 1]},
        title={'text': title, 'font': {'size': 14}},
        number={'suffix': '%', 'font': {'size': 24}},
        gauge={
            'axis': {'range': [0, 100], 'tickwidth': 1},
            'bar': {'color': '#4169E1', 'thickness': 0.7},
            'bgcolor': 'rgba(0,0,0,0)',
            'borderwidth': 0,
            'steps': [
                {'range': [0, 50], 'color': 'rgba(217, 83, 79, 0.1)'},
                {'range': [50, 75], 'color': 'rgba(240, 173, 78, 0.1)'},
                {'range': [75, 100], 'color': 'rgba(92, 184, 92, 0.1)'},
            ],
        },
    ))
    fig.update_layout(
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        height=200,
        margin=dict(l=20, r=20, t=40, b=20),
    )
    return fig


class ForgeryDetector:
    """Main forgery detection pipeline: localization model + region classifier."""

    # Probability above which a pixel of the localization map counts as forged.
    mask_threshold = 0.5
    # A region is reported only if the classifier confidence is strictly above this.
    confidence_threshold = 0.6
    # Feature width fed to the classifier; shorter vectors are zero-padded, longer truncated.
    expected_features = 526
    # Zoom factor passed to fitz.Matrix when rasterizing the first PDF page.
    pdf_render_scale = 2

    def __init__(self):
        print("Loading models...")

        # Load config
        self.config = get_config('config.yaml')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load segmentation model
        self.model = get_model(self.config).to(self.device)
        checkpoint = torch.load('models/best_doctamper.pth', map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # Load classifier
        self.classifier = ForgeryClassifier(self.config)
        self.classifier.load('models/classifier')

        # Initialize components
        self.preprocessor = DocumentPreprocessor(self.config, 'doctamper')
        self.augmentation = DatasetAwareAugmentation(self.config, 'doctamper', is_training=False)
        self.mask_refiner = get_mask_refiner(self.config)
        self.region_extractor = get_region_extractor(self.config)
        self.feature_extractor = get_feature_extractor(self.config, is_text_document=True)

        print("✓ Models loaded successfully!")

    # ------------------------------------------------------------------ input

    @staticmethod
    def _render_pdf_first_page(pdf_path, scale=2):
        """Rasterize the first page of a PDF into a uint8 array of shape (H, W, n)."""
        import fitz  # PyMuPDF; imported lazily so image-only use does not require it

        with fitz.open(pdf_path) as pdf_document:
            if pdf_document.page_count == 0:
                raise ValueError("The uploaded PDF has no pages.")
            pix = pdf_document[0].get_pixmap(matrix=fitz.Matrix(scale, scale))
            # .copy(): frombuffer yields a read-only view; downstream code gets a normal array.
            return np.frombuffer(pix.samples, dtype=np.uint8).reshape(
                pix.height, pix.width, pix.n
            ).copy()

    @staticmethod
    def _to_rgb(image):
        """Convert a grayscale / single-channel / RGBA array to 3-channel RGB."""
        if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.shape[2] == 4:
            return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        return image

    def _load_image(self, image):
        """Turn a file path (image or PDF), PIL image, or array into an RGB numpy array."""
        if isinstance(image, (str, Path)):
            path = Path(image)
            if path.suffix.lower() == '.pdf':
                image = self._render_pdf_first_page(str(path), self.pdf_render_scale)
            else:
                # Gradio hands over a path; decode it here instead of treating it as an array.
                with Image.open(path) as pil_image:
                    image = np.array(pil_image.convert('RGB'))
        elif isinstance(image, Image.Image):
            # convert('RGB') also handles palette / LA / other modes correctly.
            image = np.array(image.convert('RGB'))

        if not isinstance(image, np.ndarray):
            raise TypeError(f"Unsupported input type: {type(image).__name__}")
        return self._to_rgb(image)

    # -------------------------------------------------------------- inference

    def _fit_feature_width(self, features):
        """Reshape features to 2D and pad/truncate columns to ``expected_features``."""
        features = np.asarray(features)
        if features.ndim == 1:
            features = features.reshape(1, -1)

        expected = self.expected_features
        current = features.shape[1]
        if current < expected:
            padding = np.zeros((features.shape[0], expected - current))
            features = np.hstack([features, padding])
        elif current > expected:
            features = features[:, :expected]
        return features

    def detect(self, image):
        """
        Detect forgeries in document image or PDF.

        Args:
            image: Path to an image or PDF file (only the first PDF page is analyzed),
                a PIL image, or a numpy array.

        Returns:
            original_image: Original uploaded image (RGB numpy array)
            overlay_image: Image with detection overlay
            gauge_dice: Dice score gauge
            gauge_accuracy: Accuracy gauge
            results_html: Detection results as HTML
        """
        image = self._load_image(image)
        original_image = image.copy()

        # Preprocess
        preprocessed, _ = self.preprocessor(image, None)

        # Augment
        augmented = self.augmentation(preprocessed, None)
        image_tensor = augmented['image'].unsqueeze(0).to(self.device)

        # Run localization
        with torch.no_grad():
            logits, decoder_features = self.model(image_tensor)
            prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]

        # Refine mask
        binary_mask = (prob_map > self.mask_threshold).astype(np.uint8)
        refined_mask = self.mask_refiner.refine(binary_mask, original_size=original_image.shape[:2])

        # Extract regions
        regions = self.region_extractor.extract(refined_mask, prob_map, original_image)

        # Move decoder features to CPU once rather than once per region.
        decoder_features_cpu = [f.cpu() for f in decoder_features] if regions else []

        # Classify regions
        results = []
        for region in regions:
            features = self.feature_extractor.extract(
                preprocessed, region['region_mask'], decoder_features_cpu
            )
            features = self._fit_feature_width(features)

            predictions, confidences = self.classifier.predict(features)
            forgery_type = int(predictions[0])
            confidence = float(confidences[0])

            if confidence <= self.confidence_threshold:
                continue
            results.append({
                'region_id': region['region_id'],
                'bounding_box': region['bounding_box'],
                'forgery_type': CLASS_NAMES[forgery_type],
                'confidence': confidence,
            })

        # Create visualization
        overlay = self._create_overlay(original_image, results)

        # Create gauge charts
        gauge_dice = create_gauge_chart(MODEL_METRICS['segmentation']['dice'], 'Segmentation Dice')
        gauge_accuracy = create_gauge_chart(
            MODEL_METRICS['classification']['overall_accuracy'], 'Classification Accuracy'
        )

        # Create HTML response
        results_html = self._create_html_report(results)

        return original_image, overlay, gauge_dice, gauge_accuracy, results_html

    # ---------------------------------------------------------------- output

    def _create_overlay(self, image, results):
        """Draw a colored box and a confidence label for each detection on a copy of ``image``."""
        overlay = image.copy()
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        thickness = 1

        for result in results:
            # cv2 drawing functions need plain ints.
            x, y, w, h = (int(v) for v in result['bounding_box'])
            forgery_type = result['forgery_type']
            color = CLASS_COLORS[CLASS_IDS[forgery_type]]

            # Draw rectangle
            cv2.rectangle(overlay, (x, y), (x + w, y + h), color, 2)

            # Draw label above the box; if that would fall off the top edge, draw it
            # just inside the box instead so it stays visible.
            label = f"{forgery_type}: {result['confidence']:.1%}"
            (label_w, label_h), _ = cv2.getTextSize(label, font, font_scale, thickness)
            label_top = y - label_h - 8
            if label_top < 0:
                label_top = y
            label_bottom = label_top + label_h + 8
            cv2.rectangle(overlay, (x, label_top), (x + label_w + 4, label_bottom), color, -1)
            cv2.putText(overlay, label, (x + 2, label_bottom - 4), font, font_scale,
                        (255, 255, 255), thickness)

        return overlay

    def _create_html_report(self, results):
        """Create an HTML report (summary + one entry per detected region)."""
        num_detections = len(results)

        if num_detections == 0:
            return (
                "<div>"
                "<p><strong>No forgery detected.</strong></p>"
                "<p>The document appears to be authentic.</p>"
                "</div>"
            )

        # Calculate statistics
        avg_confidence = sum(r['confidence'] for r in results) / num_detections

        parts = [
            "<div>",
            "<h3>⚠️ Forgery Detected</h3>",
            "<p><strong>Summary:</strong><br>",
            f"• Regions detected: {num_detections}<br>",
            f"• Average confidence: {avg_confidence*100:.1f}%</p>",
            "<p><strong>Detections:</strong></p>",
        ]

        for i, result in enumerate(results, 1):
            forgery_type = result['forgery_type']
            red, green, blue = CLASS_COLORS[CLASS_IDS[forgery_type]]
            color_hex = f"#{red:02x}{green:02x}{blue:02x}"
            x, y, w, h = result['bounding_box']
            parts.append(
                f'<div style="border-left: 4px solid {color_hex}; padding-left: 8px; margin-bottom: 8px;">'
                f"<strong>Region {i}: {escape(forgery_type)} ({result['confidence']*100:.1f}%)</strong><br>"
                f"Location: ({x}, {y}) | Size: {w}×{h}px"
                "</div>"
            )

        parts.append("</div>")
        return "".join(parts)


# Initialize detector
detector = ForgeryDetector()


def detect_forgery(file):
    """Gradio interface function.

    Args:
        file: Uploaded file as a path string (Gradio ``type="filepath"``); objects
            exposing a ``.name`` path (older Gradio temp files) are also accepted.

    Returns:
        5-tuple matching the click handler outputs; image/plot slots are None on error.
    """
    if file is None:
        return None, None, None, None, "<p>No file uploaded.</p>"

    # Get file path
    file_path = file if isinstance(file, (str, Path)) else getattr(file, 'name', file)

    try:
        # Detect forgeries
        return detector.detect(file_path)
    except Exception as e:  # top-level UI boundary: log and show the error to the user
        print(f"Error: {traceback.format_exc()}")
        return None, None, None, None, f"<p><strong>Error:</strong> {escape(str(e))}</p>"


# Custom CSS - subtle styling
custom_css = """
.predict-btn {
    background-color: #4169E1 !important;
    color: white !important;
}
.clear-btn {
    background-color: #6A89A7 !important;
    color: white !important;
}
"""

# Create Gradio interface
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown(
        """
        # 📄 Document Forgery Detection
        Upload a document image or PDF to detect and classify forgeries.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Upload Document")
            # gr.File (not gr.Image) so PDFs can actually be uploaded.
            input_file = gr.File(
                label="Document (Image or PDF)",
                type="filepath",
                file_types=["image", ".pdf"],
            )
            with gr.Row():
                clear_btn = gr.Button("🧹 Clear", elem_classes="clear-btn")
                analyze_btn = gr.Button("🔍 Analyze", elem_classes="predict-btn")

            gr.Markdown(
                """
                **Supported formats:**
                - Images: JPG, PNG, BMP, TIFF, WebP
                - PDF: First page analyzed

                **Forgery types:**
                - Copy-Move: Duplicated regions
                - Splicing: Mixed sources
                - Text Substitution: Modified text
                """
            )

        with gr.Column(scale=2):
            gr.Markdown("### Detection Results")
            with gr.Row():
                original_image = gr.Image(label="Original Document", type="numpy")
                output_image = gr.Image(label="Detected Forgeries", type="numpy")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Model Performance")
            gauge_dice = gr.Plot(label="Segmentation Dice Score")
            gauge_accuracy = gr.Plot(label="Classification Accuracy")

        with gr.Column(scale=1):
            gr.Markdown("### Analysis Report")
            output_html = gr.HTML(
                value="No analysis yet. Upload a document and click Analyze."
            )

    gr.Markdown(
        """
        ---
        **Model Architecture:**
        - **Localization:** MobileNetV3-Small + UNet (Dice: 62.1%, IoU: 45.1%)
        - **Classification:** LightGBM with 526 features (Accuracy: 88.97%)
        - **Training:** 140K samples (DocTamper + SCD + FCD datasets)
        """
    )

    # Event handlers
    analyze_btn.click(
        fn=detect_forgery,
        inputs=[input_file],
        outputs=[original_image, output_image, gauge_dice, gauge_accuracy, output_html],
    )

    clear_btn.click(
        fn=lambda: (None, None, None, None, None,
                    "No analysis yet. Upload a document and click Analyze."),
        inputs=None,
        outputs=[input_file, original_image, output_image, gauge_dice, gauge_accuracy, output_html],
    )


if __name__ == "__main__":
    demo.launch()