| """ |
| Document Forgery Detection - Gradio Interface for Hugging Face Spaces |
| |
| This app provides a web interface for detecting and classifying document forgeries. |
| """ |
|
|
| import gradio as gr |
| import torch |
| import cv2 |
| import numpy as np |
| from PIL import Image |
| import json |
| from pathlib import Path |
| import sys |
|
|
| |
# Make the app directory importable so the local `src` package resolves
# regardless of the working directory the Space launches from.
sys.path.insert(0, str(Path(__file__).parent))
|
|
| from src.models import get_model |
| from src.config import get_config |
| from src.data.preprocessing import DocumentPreprocessor |
| from src.data.augmentation import DatasetAwareAugmentation |
| from src.features.region_extraction import get_mask_refiner, get_region_extractor |
| from src.features.feature_extraction import get_feature_extractor |
| from src.training.classifier import ForgeryClassifier |
|
|
| |
# Classifier output index -> human-readable forgery type.
CLASS_NAMES = dict(enumerate(('Copy-Move', 'Splicing', 'Generation')))
# Classifier output index -> RGB color used for that type's overlay boxes
# (red for Copy-Move, green for Splicing, blue for Generation).
CLASS_COLORS = dict(enumerate((
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
)))
|
|
|
|
class ForgeryDetector:
    """End-to-end document forgery detection pipeline.

    Combines a segmentation model (localizes WHERE the document was
    tampered) with a region-level classifier (decides WHAT kind of
    forgery each region contains), plus the shared preprocessing and
    mask-refinement components.
    """

    def __init__(self):
        """Load all models and processing components.

        Side effects: reads ``config.yaml``, the segmentation checkpoint
        ``models/best_doctamper.pth`` and the classifier files under
        ``models/classifier``. Selects CUDA when available.
        """
        print("Loading models...")

        self.config = get_config('config.yaml')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Segmentation network: produces per-pixel tampering logits.
        self.model = get_model(self.config).to(self.device)
        checkpoint = torch.load('models/best_doctamper.pth', map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # Region-level forgery-type classifier.
        self.classifier = ForgeryClassifier(self.config)
        self.classifier.load('models/classifier')

        # Pre/post-processing helpers configured for the DocTamper dataset.
        self.preprocessor = DocumentPreprocessor(self.config, 'doctamper')
        self.augmentation = DatasetAwareAugmentation(self.config, 'doctamper', is_training=False)
        self.mask_refiner = get_mask_refiner(self.config)
        self.region_extractor = get_region_extractor(self.config)
        self.feature_extractor = get_feature_extractor(self.config, is_text_document=True)

        print("โ Models loaded successfully!")

    def detect(self, image, confidence_threshold=0.6):
        """Detect and classify forgeries in a document image.

        Args:
            image: PIL Image or numpy array; grayscale and RGBA inputs
                are converted to RGB.
            confidence_threshold: Minimum classifier confidence for a
                region to be reported. Defaults to 0.6, matching the
                previously hard-coded cutoff.

        Returns:
            overlay_image: numpy RGB image with detections drawn on it.
            results_json: dict with ``num_detections``, ``detections``
                and static ``model_info`` entries.
        """
        if isinstance(image, Image.Image):
            image = np.array(image)

        # Normalize to 3-channel RGB so every downstream stage sees one format.
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

        original_image = image.copy()

        # Dataset-aware preprocessing + eval-mode (deterministic) transform.
        preprocessed, _ = self.preprocessor(image, None)
        augmented = self.augmentation(preprocessed, None)
        image_tensor = augmented['image'].unsqueeze(0).to(self.device)

        # Segmentation forward pass -> per-pixel tampering probability map.
        with torch.no_grad():
            logits, decoder_features = self.model(image_tensor)
            prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]

        # Binarize, then refine the mask back to the original resolution.
        binary_mask = (prob_map > 0.5).astype(np.uint8)
        refined_mask = self.mask_refiner.refine(binary_mask, original_size=original_image.shape[:2])

        # Split the refined mask into candidate forgery regions.
        regions = self.region_extractor.extract(refined_mask, prob_map, original_image)

        # Classify each region; keep only sufficiently confident detections.
        results = []
        for region in regions:
            features = self.feature_extractor.extract(
                preprocessed,
                region['region_mask'],
                [f.cpu() for f in decoder_features]
            )

            predictions, confidences = self.classifier.predict(features)
            forgery_type = int(predictions[0])
            confidence = float(confidences[0])

            if confidence > confidence_threshold:
                results.append({
                    'region_id': region['region_id'],
                    'bounding_box': region['bounding_box'],
                    'forgery_type': CLASS_NAMES[forgery_type],
                    'confidence': confidence
                })

        overlay = self._create_overlay(original_image, results)

        json_results = {
            'num_detections': len(results),
            'detections': results,
            # Static headline metrics shown alongside each result.
            'model_info': {
                'segmentation_dice': '75%',
                'classifier_accuracy': '92%'
            }
        }

        return overlay, json_results

    def _create_overlay(self, image, results):
        """Draw bounding boxes and labels for each detection.

        Args:
            image: Original RGB image (numpy array); not modified.
            results: Detection dicts as produced by ``detect``.

        Returns:
            A copy of ``image`` with colored boxes, per-region labels and
            a summary line drawn on it.
        """
        overlay = image.copy()

        # Invert CLASS_NAMES once instead of scanning it per detection.
        name_to_id = {name: class_id for class_id, name in CLASS_NAMES.items()}

        for result in results:
            x, y, w, h = result['bounding_box']
            forgery_type = result['forgery_type']
            confidence = result['confidence']
            color = CLASS_COLORS[name_to_id[forgery_type]]

            cv2.rectangle(overlay, (x, y), (x + w, y + h), color, 2)

            label = f"{forgery_type}: {confidence:.1%}"
            (label_w, label_h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
            # Clamp the label so boxes touching the top edge don't push the
            # text to negative (off-image) coordinates, as before this fix.
            y_top = max(y - label_h - 10, 0)
            cv2.rectangle(overlay, (x, y_top), (x + label_w, y_top + label_h + 10), color, -1)
            cv2.putText(overlay, label, (x, y_top + label_h + 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

        if len(results) > 0:
            legend_y = 30
            cv2.putText(overlay, f"Detected {len(results)} forgery region(s)",
                        (10, legend_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)

        return overlay
|
|
|
|
| |
# Instantiate once at import time so model loading happens during Space
# startup rather than on the first user request.
detector = ForgeryDetector()
|
|
|
|
def detect_forgery(image):
    """Gradio interface function wrapping the global detector.

    Args:
        image: Image uploaded by the user (PIL Image from gr.Image).

    Returns:
        (overlay, results) where ``overlay`` is a numpy image (or None on
        failure) and ``results`` is a dict for the ``gr.JSON`` output.
    """
    try:
        overlay, results = detector.detect(image)
        # gr.JSON renders dicts directly; returning the dict (instead of a
        # json.dumps string) keeps the success and error paths type-consistent.
        return overlay, results
    except Exception as e:
        # Bug fix: previously returned the plain string "Error: ...", which
        # the gr.JSON component cannot render as JSON. Surface the error as
        # a well-formed object instead.
        return None, {'error': str(e)}
|
|
|
|
| |
# Gradio UI definition. The description/article strings are rendered as
# Markdown; several emoji appear mis-encoded in this file (e.g. "๐") --
# NOTE(review): confirm the intended glyphs before release.
demo = gr.Interface(
    fn=detect_forgery,
    inputs=gr.Image(type="pil", label="Upload Document Image"),
    outputs=[
        gr.Image(type="numpy", label="Detection Result"),
        gr.JSON(label="Detection Details")
    ],
    title="๐ Document Forgery Detector",
    description="""
    Upload a document image to detect and classify forgeries.

    **Supported Forgery Types:**
    - ๐ด Copy-Move: Duplicated regions within the document
    - ๐ข Splicing: Content from different sources
    - ๐ต Generation: AI-generated or synthesized content

    **Model Performance:**
    - Localization: 75% Dice Score
    - Classification: 92% Accuracy
    """,
    # Example images shipped with the Space; paths are relative to app root.
    examples=[
        ["examples/sample1.jpg"],
        ["examples/sample2.jpg"],
    ],
    article="""
    ### About
    This model uses a hybrid deep learning approach:
    1. **Localization**: MobileNetV3-Small + UNet-Lite (detects WHERE)
    2. **Classification**: LightGBM with hybrid features (detects WHAT)

    Trained on DocTamper dataset (140K samples).
    """,
    theme=gr.themes.Soft(),
    # NOTE(review): allow_flagging was deprecated in Gradio 4.x in favor of
    # flagging_mode -- confirm against the pinned gradio version.
    allow_flagging="never"
)
|
|
|
|
| if __name__ == "__main__": |
| demo.launch() |
|
|