"""
Document Forgery Detection - Gradio Interface for Hugging Face Spaces

This app provides a web interface for detecting and classifying document forgeries.
"""

import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image
import json
from pathlib import Path
import sys
from typing import Dict, List, Tuple
import plotly.graph_objects as go

# Add src to path so the project-local ``src`` package resolves regardless of
# the working directory the app is launched from.
sys.path.insert(0, str(Path(__file__).parent))

from src.models import get_model
from src.config import get_config
from src.data.preprocessing import DocumentPreprocessor
from src.data.augmentation import DatasetAwareAugmentation
from src.features.region_extraction import get_mask_refiner, get_region_extractor
from src.features.feature_extraction import get_feature_extractor
from src.training.classifier import ForgeryClassifier

# Class names, indexed by the classifier's integer prediction.
CLASS_NAMES = {0: 'Copy-Move', 1: 'Splicing', 2: 'Text Substitution'}

# Box/label colours per class, as (R, G, B) tuples.
CLASS_COLORS = {
    0: (217, 83, 79),   # #d9534f - Muted red
    1: (92, 184, 92),   # #5cb85c - Muted green
    2: (65, 105, 225),  # #4169E1 - Royal blue
}

# Actual model performance metrics. These are static, offline-evaluated
# numbers displayed in the UI, not values computed per request.
# NOTE(review): the per_class key 'generation' does not match CLASS_NAMES[2]
# ('Text Substitution'); kept unchanged because other code may look it up by
# this key — confirm which name is intended.
MODEL_METRICS = {
    'segmentation': {
        'dice': 0.6212,
        'iou': 0.4506,
        'precision': 0.7077,
        'recall': 0.5536,
    },
    'classification': {
        'overall_accuracy': 0.8897,
        'per_class': {
            'copy_move': 0.92,
            'splicing': 0.85,
            'generation': 0.90,
        },
    },
}


def _fraction_to_percent(fraction: float) -> float:
    """Convert a fraction to a percentage clamped to the closed range [0, 100]."""
    # Clamping keeps ring arcs within a single revolution (<= 360 degrees).
    return min(max(float(fraction) * 100.0, 0.0), 100.0)


def create_gauge_chart(value: float, title: str, max_value: float = 1.0) -> go.Figure:
    """Create a subtle radial gauge showing ``value`` as a percentage of ``max_value``.

    Args:
        value: Metric value on the same scale as ``max_value``.
        title: Title rendered above the gauge.
        max_value: Value that corresponds to a full (100%) gauge. The default
            of 1.0 means ``value`` is a fraction in [0, 1].

    Returns:
        A 200 px tall Plotly figure with a transparent background. The gauge
        axis runs 0-100 and is banded red (<50), amber (50-75) and green (75-100).

    Raises:
        ValueError: If ``max_value`` is not positive.
    """
    if max_value <= 0:
        raise ValueError(f"max_value must be positive, got {max_value!r}")

    # Normalize against max_value so the fixed 0-100 gauge axis is correct for
    # any input scale (with the default 1.0 this is simply value * 100).
    percent = value / max_value * 100

    fig = go.Figure(go.Indicator(
        mode="gauge+number",
        value=percent,
        domain={'x': [0, 1], 'y': [0, 1]},
        title={'text': title, 'font': {'size': 14}},
        number={'suffix': '%', 'font': {'size': 24}},
        gauge={
            'axis': {'range': [0, 100], 'tickwidth': 1},
            'bar': {'color': '#4169E1', 'thickness': 0.7},
            'bgcolor': 'rgba(0,0,0,0)',
            'borderwidth': 0,
            'steps': [
                {'range': [0, 50], 'color': 'rgba(217, 83, 79, 0.1)'},
                {'range': [50, 75], 'color': 'rgba(240, 173, 78, 0.1)'},
                {'range': [75, 100], 'color': 'rgba(92, 184, 92, 0.1)'},
            ],
        },
    ))
    fig.update_layout(
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        height=200,
        margin=dict(l=20, r=20, t=40, b=20),
    )
    return fig


def create_detection_metrics_gauge(avg_confidence: float, iou: float, precision: float,
                                   recall: float, num_detections: int) -> go.Figure:
    """Create a high-fidelity radial bar chart (concentric rings) of detection metrics.

    Each metric is drawn as a coloured arc whose sweep is proportional to its
    value (a full circle = 100%), layered over a faint grey full-circle track.
    Rings from outermost to innermost are Confidence, Precision, Recall, IoU.

    Args:
        avg_confidence: Mean classifier confidence as a fraction in [0, 1].
            It is displayed as 0 when ``num_detections`` is 0.
        iou: Intersection-over-union as a fraction in [0, 1].
        precision: Precision as a fraction in [0, 1].
        recall: Recall as a fraction in [0, 1].
        num_detections: Number of reported detections.

    Returns:
        A 450 px tall Plotly figure with a transparent background and the
        legend (one entry per metric, labelled with its percentage) on the right.
        Values outside [0, 1] are clamped so no arc exceeds a full circle.
    """
    # Percentages are clamped to 0-100; `base` is the ring's inner radius.
    metrics = [
        {'name': 'Confidence',
         'val': _fraction_to_percent(avg_confidence) if num_detections > 0 else 0,
         'color': '#4169E1', 'base': 80},
        {'name': 'Precision', 'val': _fraction_to_percent(precision),
         'color': '#5cb85c', 'base': 60},
        {'name': 'Recall', 'val': _fraction_to_percent(recall),
         'color': '#f0ad4e', 'base': 40},
        {'name': 'IoU', 'val': _fraction_to_percent(iou),
         'color': '#d9534f', 'base': 20},
    ]

    fig = go.Figure()

    for m in metrics:
        # 1. Background track: a faint gray full ring (width 360 degrees).
        fig.add_trace(go.Barpolar(
            r=[15],
            theta=[180],
            width=[360],
            base=m['base'],
            marker_color='rgba(128,128,128,0.1)',
            hoverinfo='none',
            showlegend=False,
        ))

        # 2. The metric arc. 100% = 360 degrees; centring the bar at half its
        # width makes the arc start at angle 0 (12 o'clock after rotation).
        angle_width = m['val'] * 3.6
        fig.add_trace(go.Barpolar(
            r=[15],
            theta=[angle_width / 2],
            width=[angle_width],
            base=m['base'],
            name=f"{m['name']}: {m['val']:.1f}%",
            marker_color=m['color'],
            marker_line_width=0,
            hoverinfo='name',
        ))

    fig.update_layout(
        polar=dict(
            hole=0.1,
            radialaxis=dict(visible=False, range=[0, 100]),
            angularaxis=dict(
                rotation=90,            # Start at 12 o'clock
                direction='clockwise',  # Go clockwise
                gridcolor='rgba(128,128,128,0.2)',
                tickmode='array',
                tickvals=[0, 90, 180, 270],
                ticktext=['0%', '25%', '50%', '75%'],
                showticklabels=True,
                tickfont=dict(size=12, color='#888'),
            ),
            bgcolor='rgba(0,0,0,0)',
        ),
        showlegend=True,
        legend=dict(
            orientation="v",
            yanchor="middle",
            y=0.5,
            xanchor="left",
            x=1.1,
            font=dict(size=14, color='white'),
            itemwidth=30,
        ),
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        height=450,
        margin=dict(l=60, r=180, t=40, b=40),
    )

    return fig


class ForgeryDetector:
    """Main forgery detection pipeline"""

    def __init__(self):
        print("Loading models...")
# Load config self.config = get_config('config.yaml') self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Load segmentation model self.model = get_model(self.config).to(self.device) checkpoint = torch.load('models/best_doctamper.pth', map_location=self.device) self.model.load_state_dict(checkpoint['model_state_dict']) self.model.eval() # Load classifier self.classifier = ForgeryClassifier(self.config) self.classifier.load('models/classifier') # Initialize components self.preprocessor = DocumentPreprocessor(self.config, 'doctamper') self.augmentation = DatasetAwareAugmentation(self.config, 'doctamper', is_training=False) self.mask_refiner = get_mask_refiner(self.config) self.region_extractor = get_region_extractor(self.config) self.feature_extractor = get_feature_extractor(self.config, is_text_document=True) print("✓ Models loaded successfully!") def detect(self, image): """ Detect forgeries in document image or PDF Returns: original_image: Original uploaded image overlay_image: Image with detection overlay gauge_dice: Dice score gauge gauge_accuracy: Accuracy gauge results_html: Detection results as HTML """ # Handle file path input (from gr.Image with type="filepath") if isinstance(image, str): if image.lower().endswith('.pdf'): # Handle PDF files import fitz # PyMuPDF pdf_document = fitz.open(image) page = pdf_document[0] pix = page.get_pixmap(matrix=fitz.Matrix(2, 2)) image = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n) if pix.n == 4: image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) pdf_document.close() else: # Load image file image = Image.open(image) image = np.array(image) # Convert PIL to numpy if isinstance(image, Image.Image): image = np.array(image) # Convert to RGB if len(image.shape) == 2: image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) elif image.shape[2] == 4: image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) original_image = image.copy() # Preprocess preprocessed, _ = self.preprocessor(image, 
None) # Augment augmented = self.augmentation(preprocessed, None) image_tensor = augmented['image'].unsqueeze(0).to(self.device) # Run localization with torch.no_grad(): logits, decoder_features = self.model(image_tensor) prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0] # Resize probability map to match original image size to avoid index mismatch errors prob_map_resized = cv2.resize( prob_map, (original_image.shape[1], original_image.shape[0]), interpolation=cv2.INTER_LINEAR ) # Refine mask binary_mask = (prob_map_resized > 0.5).astype(np.uint8) refined_mask = self.mask_refiner.refine(prob_map_resized, original_size=original_image.shape[:2]) # Extract regions regions = self.region_extractor.extract(refined_mask, prob_map_resized, original_image) # Classify regions results = [] for region in regions: # Extract features features = self.feature_extractor.extract( preprocessed, region['region_mask'], [f.cpu() for f in decoder_features] ) # Reshape features to 2D array if features.ndim == 1: features = features.reshape(1, -1) # Pad/truncate features to match classifier expected_features = 526 current_features = features.shape[1] if current_features < expected_features: padding = np.zeros((features.shape[0], expected_features - current_features)) features = np.hstack([features, padding]) elif current_features > expected_features: features = features[:, :expected_features] # Classify predictions, confidences = self.classifier.predict(features) forgery_type = int(predictions[0]) confidence = float(confidences[0]) if confidence > 0.6: results.append({ 'region_id': region['region_id'], 'bounding_box': region['bounding_box'], 'forgery_type': CLASS_NAMES[forgery_type], 'confidence': confidence }) # Create visualization overlay = self._create_overlay(original_image, results) # Calculate actual detection metrics from probability map and mask num_detections = len(results) avg_confidence = sum(r['confidence'] for r in results) / num_detections if num_detections > 0 else 0 # 
Calculate IoU, Precision, Recall from the refined mask and probability map if num_detections > 0: # Use resized prob_map to match refined_mask dimensions high_conf_mask = (prob_map_resized > 0.7).astype(np.uint8) predicted_positive = np.sum(refined_mask > 0) high_conf_positive = np.sum(high_conf_mask > 0) # Calculate intersection and union intersection = np.sum((refined_mask > 0) & (high_conf_mask > 0)) union = np.sum((refined_mask > 0) | (high_conf_mask > 0)) # Calculate metrics iou = intersection / union if union > 0 else 0 precision = intersection / predicted_positive if predicted_positive > 0 else 0 recall = intersection / high_conf_positive if high_conf_positive > 0 else 0 else: # No detections - use zeros iou = 0 precision = 0 recall = 0 # Create detection metrics gauge with actual values metrics_gauge = create_detection_metrics_gauge(avg_confidence, iou, precision, recall, num_detections) # Create HTML response results_html = self._create_html_report(results) return overlay, metrics_gauge, results_html def _create_overlay(self, image, results): """Create overlay visualization""" overlay = image.copy() for result in results: bbox = result['bounding_box'] x, y, w, h = bbox forgery_type = result['forgery_type'] confidence = result['confidence'] # Get color forgery_id = [k for k, v in CLASS_NAMES.items() if v == forgery_type][0] color = CLASS_COLORS[forgery_id] # Draw rectangle cv2.rectangle(overlay, (x, y), (x+w, y+h), color, 2) # Draw label label = f"{forgery_type}: {confidence:.1%}" font = cv2.FONT_HERSHEY_SIMPLEX font_scale = 0.5 thickness = 1 (label_w, label_h), baseline = cv2.getTextSize(label, font, font_scale, thickness) cv2.rectangle(overlay, (x, y-label_h-8), (x+label_w+4, y), color, -1) cv2.putText(overlay, label, (x+2, y-4), font, font_scale, (255, 255, 255), thickness) return overlay def _create_html_report(self, results): """Create HTML report with detection results""" num_detections = len(results) if num_detections == 0: return """
Supported formats:
Forgery types:
Localization: MobileNetV3-Small + UNet
Dice: 62.12% | IoU: 45.06% | Precision: 70.77% | Recall: 55.36%
Classification: LightGBM with 526 features
Train Accuracy: 90.53% | Val Accuracy: 88.97%
Training: 140K samples from DocTamper dataset
Trained Model Performance:
Segmentation Dice: {MODEL_METRICS['segmentation']['dice']*100:.2f}%