| """
|
| Document Forgery Detection - Gradio Interface for Hugging Face Spaces
|
|
|
| This app provides a web interface for detecting and classifying document forgeries.
|
| """
|
|
|
import json
import sys
import traceback
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
|
|
|
|
|
| sys.path.insert(0, str(Path(__file__).parent))
|
|
|
| from src.models import get_model
|
| from src.config import get_config
|
| from src.data.preprocessing import DocumentPreprocessor
|
| from src.data.augmentation import DatasetAwareAugmentation
|
| from src.features.region_extraction import get_mask_refiner, get_region_extractor
|
| from src.features.feature_extraction import get_feature_extractor
|
| from src.training.classifier import ForgeryClassifier
|
|
|
|
|
# Mapping from classifier output index to human-readable forgery type.
CLASS_NAMES = {0: 'Copy-Move', 1: 'Splicing', 2: 'Generation'}

# RGB color per class id used for the detection overlay
# (red = Copy-Move, green = Splicing, blue = Generation).
CLASS_COLORS = {
    0: (255, 0, 0),
    1: (0, 255, 0),
    2: (0, 0, 255)
}
|
|
|
|
|
class ForgeryDetector:
    """End-to-end forgery detection pipeline.

    Combines a segmentation model (localizes WHERE the document was
    tampered) with a region classifier (decides WHAT kind of tampering)
    behind a single ``detect`` entry point.
    """

    # Minimum classifier confidence for a region to be reported.
    CONFIDENCE_THRESHOLD = 0.6

    def __init__(self):
        """Load all models and preprocessing components once at startup."""
        print("Loading models...")

        self.config = get_config('config.yaml')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Segmentation model: per-pixel tampering probability map.
        self.model = get_model(self.config).to(self.device)
        # NOTE(review): torch.load without weights_only=True unpickles
        # arbitrary objects; acceptable only because the checkpoint ships
        # with the Space and is not user-supplied.
        checkpoint = torch.load('models/segmentation_model.pth', map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # Region-level forgery-type classifier.
        self.classifier = ForgeryClassifier(self.config)
        self.classifier.load('models/classifier')

        # Preprocessing / feature pipeline, configured for the DocTamper domain.
        self.preprocessor = DocumentPreprocessor(self.config, 'doctamper')
        self.augmentation = DatasetAwareAugmentation(self.config, 'doctamper', is_training=False)
        self.mask_refiner = get_mask_refiner(self.config)
        self.region_extractor = get_region_extractor(self.config)
        self.feature_extractor = get_feature_extractor(self.config, is_text_document=True)

        # Original printed a mojibake character ("โ") here; restored to a checkmark.
        print("✓ Models loaded successfully!")

    def detect(self, image):
        """
        Detect forgeries in a document image.

        Args:
            image: PIL Image or numpy array (grayscale, RGB or RGBA).

        Returns:
            overlay_image: numpy RGB image with detection boxes drawn on it.
            results_json: dict with ``num_detections``, per-region
                ``detections`` and static ``model_info``.
        """
        if isinstance(image, Image.Image):
            image = np.array(image)

        # Normalize to 3-channel RGB regardless of input mode.
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

        original_image = image.copy()

        # Second argument is the (absent) ground-truth mask.
        preprocessed, _ = self.preprocessor(image, None)

        augmented = self.augmentation(preprocessed, None)
        image_tensor = augmented['image'].unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits, decoder_features = self.model(image_tensor)
            prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]

        # Threshold, refine to the original resolution, then split into regions.
        binary_mask = (prob_map > 0.5).astype(np.uint8)
        refined_mask = self.mask_refiner.refine(binary_mask, original_size=original_image.shape[:2])
        regions = self.region_extractor.extract(refined_mask, prob_map, original_image)

        results = []
        for region in regions:
            features = self.feature_extractor.extract(
                preprocessed,
                region['region_mask'],
                [f.cpu() for f in decoder_features]
            )

            predictions, confidences = self.classifier.predict(features)
            forgery_type = int(predictions[0])
            confidence = float(confidences[0])

            # Suppress low-confidence regions to keep the output clean.
            if confidence > self.CONFIDENCE_THRESHOLD:
                results.append({
                    'region_id': region['region_id'],
                    'bounding_box': region['bounding_box'],
                    'forgery_type': CLASS_NAMES[forgery_type],
                    'confidence': confidence
                })

        overlay = self._create_overlay(original_image, results)

        json_results = {
            'num_detections': len(results),
            'detections': results,
            'model_info': {
                'segmentation_dice': '75%',
                'classifier_accuracy': '92%'
            }
        }

        return overlay, json_results

    def _create_overlay(self, image, results):
        """Draw bounding boxes and confidence labels on a copy of *image*."""
        overlay = image.copy()

        # Build the name->id reverse map once instead of a list scan per region.
        name_to_id = {name: idx for idx, name in CLASS_NAMES.items()}

        for result in results:
            x, y, w, h = result['bounding_box']
            forgery_type = result['forgery_type']
            confidence = result['confidence']
            color = CLASS_COLORS[name_to_id[forgery_type]]

            cv2.rectangle(overlay, (x, y), (x + w, y + h), color, 2)

            label = f"{forgery_type}: {confidence:.1%}"
            (text_w, text_h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
            # Clamp so the label stays inside the frame when a box touches the top edge
            # (original drew at y - text_h - 10, which can be negative).
            top = max(y - text_h - 10, 0)
            cv2.rectangle(overlay, (x, top), (x + text_w, top + text_h + 10), color, -1)
            cv2.putText(overlay, label, (x, top + text_h + 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

        if results:
            cv2.putText(overlay, f"Detected {len(results)} forgery region(s)",
                        (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)

        return overlay
|
|
|
|
|
|
|
# Instantiated at import time so the (slow) model loading happens once,
# before the first request arrives.
detector = ForgeryDetector()
|
|
|
|
|
def detect_forgery(image):
    """Gradio entry point: run detection and format the two outputs.

    Args:
        image: PIL Image uploaded by the user (may be None).

    Returns:
        (overlay, details): overlay numpy image (or None on failure) and a
        dict for the gr.JSON component. On error the dict is {'error': msg}
        — the original returned a bare "Error: ..." string, which gr.JSON
        cannot render as JSON.
    """
    try:
        overlay, results = detector.detect(image)
        # gr.JSON accepts dicts directly; no need to json.dumps first.
        return overlay, results
    except Exception as e:
        # Log the full traceback server-side, surface a structured error to the UI.
        traceback.print_exc()
        return None, {'error': str(e)}
|
|
|
|
|
|
|
# Gradio UI definition. NOTE(review): the emoji in the title/description were
# mojibake in the original ("๐", "๐ด", "๐ข", "๐ต"); restored to match the
# CLASS_COLORS mapping (red = Copy-Move, green = Splicing, blue = Generation).
demo = gr.Interface(
    fn=detect_forgery,
    inputs=gr.Image(type="pil", label="Upload Document Image"),
    outputs=[
        gr.Image(type="numpy", label="Detection Result"),
        gr.JSON(label="Detection Details"),
    ],
    title="🔍 Document Forgery Detector",
    description="""
    Upload a document image to detect and classify forgeries.

    **Supported Forgery Types:**
    - 🔴 Copy-Move: Duplicated regions within the document
    - 🟢 Splicing: Content from different sources
    - 🔵 Generation: AI-generated or synthesized content

    **Model Performance:**
    - Localization: 75% Dice Score
    - Classification: 92% Accuracy
    """,
    examples=[
        ["examples/sample1.jpg"],
        ["examples/sample2.jpg"],
    ],
    article="""
    ### About
    This model uses a hybrid deep learning approach:
    1. **Localization**: MobileNetV3-Small + UNet-Lite (detects WHERE)
    2. **Classification**: LightGBM with hybrid features (detects WHAT)

    Trained on DocTamper dataset (140K samples).
    """,
    theme=gr.themes.Soft(),
    # Deprecated in Gradio 4.x in favor of flagging_mode; kept for compatibility
    # with the pinned version — TODO confirm the Space's gradio version.
    allow_flagging="never"
)
|
|
|
|
|
| if __name__ == "__main__":
|
| demo.launch()
|
|
|