""" Document Forgery Detection - Gradio Interface for Hugging Face Spaces This app provides a web interface for detecting and classifying document forgeries. """ import gradio as gr import torch import cv2 import numpy as np from PIL import Image import json from pathlib import Path import sys from typing import Dict, List, Tuple import plotly.graph_objects as go # Add src to path sys.path.insert(0, str(Path(__file__).parent)) from src.models import get_model from src.config import get_config from src.data.preprocessing import DocumentPreprocessor from src.data.augmentation import DatasetAwareAugmentation from src.features.region_extraction import get_mask_refiner, get_region_extractor from src.features.feature_extraction import get_feature_extractor from src.training.classifier import ForgeryClassifier # Class names CLASS_NAMES = {0: 'Copy-Move', 1: 'Splicing', 2: 'Text Substitution'} CLASS_COLORS = { 0: (217, 83, 79), # #d9534f - Muted red 1: (92, 184, 92), # #5cb85c - Muted green 2: (65, 105, 225) # #4169E1 - Royal blue } # Actual model performance metrics MODEL_METRICS = { 'segmentation': { 'dice': 0.6212, 'iou': 0.4506, 'precision': 0.7077, 'recall': 0.5536 }, 'classification': { 'overall_accuracy': 0.8897, 'per_class': { 'copy_move': 0.92, 'splicing': 0.85, 'generation': 0.90 } } } def create_gauge_chart(value: float, title: str, max_value: float = 1.0) -> go.Figure: """Create a subtle radial gauge chart""" fig = go.Figure(go.Indicator( mode="gauge+number", value=value * 100, domain={'x': [0, 1], 'y': [0, 1]}, title={'text': title, 'font': {'size': 14}}, number={'suffix': '%', 'font': {'size': 24}}, gauge={ 'axis': {'range': [0, 100], 'tickwidth': 1}, 'bar': {'color': '#4169E1', 'thickness': 0.7}, 'bgcolor': 'rgba(0,0,0,0)', 'borderwidth': 0, 'steps': [ {'range': [0, 50], 'color': 'rgba(217, 83, 79, 0.1)'}, {'range': [50, 75], 'color': 'rgba(240, 173, 78, 0.1)'}, {'range': [75, 100], 'color': 'rgba(92, 184, 92, 0.1)'} ] } )) fig.update_layout( paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)', height=200, margin=dict(l=20, r=20, t=40, b=20) ) return fig class ForgeryDetector: """Main forgery detection pipeline""" def __init__(self): print("Loading models...") # Load config self.config = get_config('config.yaml') self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Load segmentation model self.model = get_model(self.config).to(self.device) checkpoint = torch.load('models/best_doctamper.pth', map_location=self.device) self.model.load_state_dict(checkpoint['model_state_dict']) self.model.eval() # Load classifier self.classifier = ForgeryClassifier(self.config) self.classifier.load('models/classifier') # Initialize components self.preprocessor = DocumentPreprocessor(self.config, 'doctamper') self.augmentation = DatasetAwareAugmentation(self.config, 'doctamper', is_training=False) self.mask_refiner = get_mask_refiner(self.config) self.region_extractor = get_region_extractor(self.config) self.feature_extractor = get_feature_extractor(self.config, is_text_document=True) print("✓ Models loaded successfully!") def detect(self, image): """ Detect forgeries in document image or PDF Returns: original_image: Original uploaded image overlay_image: Image with detection overlay gauge_dice: Dice score gauge gauge_accuracy: Accuracy gauge results_html: Detection results as HTML """ # Handle PDF files if isinstance(image, str) and image.lower().endswith('.pdf'): import fitz # PyMuPDF pdf_document = fitz.open(image) page = pdf_document[0] pix = page.get_pixmap(matrix=fitz.Matrix(2, 2)) image = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n) if pix.n == 4: image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) pdf_document.close() # Convert PIL to numpy if isinstance(image, Image.Image): image = np.array(image) # Convert to RGB if len(image.shape) == 2: image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) elif image.shape[2] == 4: image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) original_image = image.copy() # Preprocess preprocessed, _ = self.preprocessor(image, None) # Augment augmented = self.augmentation(preprocessed, None) image_tensor = augmented['image'].unsqueeze(0).to(self.device) # Run localization with torch.no_grad(): logits, decoder_features = self.model(image_tensor) prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0] # Refine mask binary_mask = (prob_map > 0.5).astype(np.uint8) refined_mask = self.mask_refiner.refine(binary_mask, original_size=original_image.shape[:2]) # Extract regions regions = self.region_extractor.extract(refined_mask, prob_map, original_image) # Classify regions results = [] for region in regions: # Extract features features = self.feature_extractor.extract( preprocessed, region['region_mask'], [f.cpu() for f in decoder_features] ) # Reshape features to 2D array if features.ndim == 1: features = features.reshape(1, -1) # Pad/truncate features to match classifier expected_features = 526 current_features = features.shape[1] if current_features < expected_features: padding = np.zeros((features.shape[0], expected_features - current_features)) features = np.hstack([features, padding]) elif current_features > expected_features: features = features[:, :expected_features] # Classify predictions, confidences = self.classifier.predict(features) forgery_type = int(predictions[0]) confidence = float(confidences[0]) if confidence > 0.6: results.append({ 'region_id': region['region_id'], 'bounding_box': region['bounding_box'], 'forgery_type': CLASS_NAMES[forgery_type], 'confidence': confidence }) # Create visualization overlay = self._create_overlay(original_image, results) # Create gauge charts gauge_dice = create_gauge_chart(MODEL_METRICS['segmentation']['dice'], 'Segmentation Dice') gauge_accuracy = create_gauge_chart(MODEL_METRICS['classification']['overall_accuracy'], 'Classification Accuracy') # Create HTML response results_html = self._create_html_report(results) return original_image, overlay, gauge_dice, gauge_accuracy, results_html def _create_overlay(self, image, results): """Create overlay visualization""" overlay = image.copy() for result in results: bbox = result['bounding_box'] x, y, w, h = bbox forgery_type = result['forgery_type'] confidence = result['confidence'] # Get color forgery_id = [k for k, v in CLASS_NAMES.items() if v == forgery_type][0] color = CLASS_COLORS[forgery_id] # Draw rectangle cv2.rectangle(overlay, (x, y), (x+w, y+h), color, 2) # Draw label label = f"{forgery_type}: {confidence:.1%}" font = cv2.FONT_HERSHEY_SIMPLEX font_scale = 0.5 thickness = 1 (label_w, label_h), baseline = cv2.getTextSize(label, font, font_scale, thickness) cv2.rectangle(overlay, (x, y-label_h-8), (x+label_w+4, y), color, -1) cv2.putText(overlay, label, (x+2, y-4), font, font_scale, (255, 255, 255), thickness) return overlay def _create_html_report(self, results): """Create HTML report with detection results""" num_detections = len(results) if num_detections == 0: return """