"""
Inference pipeline for document forgery detection

Complete pipeline: Image → Localization → Regions → Classification → Output
"""
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import cv2
import fitz  # PyMuPDF
import numpy as np
import torch
from PIL import Image

from ..config import get_config
from ..features import (
    get_feature_extractor,
    get_mask_refiner,
    get_region_extractor,
)
from ..models import get_model
from ..training.classifier import get_classifier
class ForgeryDetectionPipeline:
    """
    Complete inference pipeline for document forgery detection.

    Pipeline stages:
        1. Input handling (PDF/Image)
        2. Preprocessing
        3. Deep localization
        4. Mask refinement
        5. Region extraction
        6. Feature extraction
        7. Classification
        8. Post-processing
        9. Output generation
    """

    def __init__(self,
                 config,
                 model_path: str,
                 classifier_path: Optional[str] = None,
                 is_text_document: bool = True):
        """
        Initialize pipeline.

        Args:
            config: Configuration object
            model_path: Path to localization model checkpoint
            classifier_path: Path to classifier (optional)
            is_text_document: Whether input is text document (for OCR features)
        """
        self.config = config
        self.is_text_document = is_text_document

        # Use CUDA only when it is both available and requested in the config.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() and config.get('system.device') == 'cuda'
            else 'cpu'
        )
        print(f"Inference device: {self.device}")

        # Load localization model and freeze it in eval mode
        self.model = get_model(config).to(self.device)
        self._load_model(model_path)
        self.model.eval()

        # Post-localization helpers
        self.mask_refiner = get_mask_refiner(config, 'default')
        self.region_extractor = get_region_extractor(config, 'default')
        self.feature_extractor = get_feature_extractor(config, is_text_document)

        # Classifier is optional; without it regions are labelled 'unknown'.
        if classifier_path:
            self.classifier = get_classifier(config)
            self.classifier.load(classifier_path)
        else:
            self.classifier = None

        # Minimum region confidence kept during post-processing
        self.confidence_threshold = config.get('classifier.confidence_threshold', 0.6)

        # Square side length the localization model expects
        self.image_size = config.get('data.image_size', 384)

        print("Inference pipeline initialized")

    def _load_model(self, model_path: str):
        """
        Load model weights from a checkpoint file.

        Accepts either a raw state dict or a training checkpoint that wraps
        the weights under 'model_state_dict'.
        """
        # NOTE(review): torch.load unpickles arbitrary objects without
        # weights_only=True -- only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=self.device)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.model.load_state_dict(checkpoint)
        print(f"Loaded model from {model_path}")

    def _load_image(self, input_path: str) -> np.ndarray:
        """
        Load an image from an image file or the first page of a PDF.

        Args:
            input_path: Path to image or PDF

        Returns:
            RGB image as numpy array (H, W, 3)

        Raises:
            ValueError: If the image cannot be decoded or the PDF has no pages.
        """
        path = Path(input_path)

        if path.suffix.lower() == '.pdf':
            doc = fitz.open(str(path))
            try:
                if len(doc) == 0:
                    raise ValueError(f"PDF has no pages: {input_path}")
                page = doc[0]
                # Rasterize at 300 DPI (PDF native unit is 72 DPI)
                mat = fitz.Matrix(300 / 72, 300 / 72)
                pix = page.get_pixmap(matrix=mat)
                image = np.frombuffer(pix.samples, dtype=np.uint8)
                image = image.reshape(pix.height, pix.width, pix.n)
                if pix.n == 4:
                    image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
                else:
                    # np.frombuffer returns a read-only view; make it writable
                    image = image.copy()
            finally:
                doc.close()
        else:
            image = cv2.imread(str(path))
            if image is None:
                # cv2.imread signals failure by returning None, not raising
                raise ValueError(f"Could not read image: {input_path}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        return image

    def _preprocess(self, image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Preprocess image for inference.

        Args:
            image: Input RGB image (H, W, 3)

        Returns:
            Tuple of (resized float32 image normalized to [0, 1], original copy)
        """
        original = image.copy()
        preprocessed = cv2.resize(image, (self.image_size, self.image_size))
        preprocessed = preprocessed.astype(np.float32) / 255.0
        return preprocessed, original

    def _to_tensor(self, image: np.ndarray) -> torch.Tensor:
        """Convert an (H, W, C) float image to a (1, C, H, W) tensor on device."""
        tensor = torch.from_numpy(image.transpose(2, 0, 1))
        tensor = tensor.unsqueeze(0)  # add batch dimension
        return tensor.to(self.device)

    def run(self,
            input_path: str,
            output_dir: Optional[str] = None) -> Dict:
        """
        Run the full inference pipeline on one document.

        Args:
            input_path: Path to input image or PDF
            output_dir: Optional output directory for mask/overlay/JSON files

        Returns:
            Dictionary with detection results (and output paths if saved)
        """
        print(f"\n{'='*60}")
        print(f"Processing: {input_path}")
        print(f"{'='*60}")

        # 1. Load image
        image = self._load_image(input_path)
        original_size = image.shape[:2]
        print(f"Input size: {original_size}")

        # 2. Preprocess
        preprocessed, original = self._preprocess(image)
        tensor = self._to_tensor(preprocessed)

        # 3. Deep localization
        with torch.no_grad():
            logits, decoder_features = self.model(tensor)
            probability_map = torch.sigmoid(logits).cpu().numpy()[0, 0]
        print(f"Localization complete. Max prob: {probability_map.max():.3f}")

        # 4. Mask refinement
        binary_mask = self.mask_refiner.refine(probability_map, original_size)
        num_positive_pixels = binary_mask.sum()
        print(f"Mask refinement: {num_positive_pixels} positive pixels")

        # 5. Region extraction
        # Resize probability map to original size for confidence aggregation
        prob_resized = cv2.resize(probability_map, (original_size[1], original_size[0]))
        regions = self.region_extractor.extract(binary_mask, prob_resized, original)
        print(f"Regions extracted: {len(regions)}")

        # 6. Feature extraction & 7. Classification
        results = []
        for region in regions:
            features = self.feature_extractor.extract(
                preprocessed,
                cv2.resize(region['region_mask'], (self.image_size, self.image_size)),
                [f.cpu() for f in decoder_features]
            )

            if self.classifier is not None:
                predictions, confidences, valid_mask = self.classifier.predict_with_filtering(
                    features.reshape(1, -1)
                )
                if valid_mask[0]:
                    region['forgery_type'] = self.classifier.get_class_name(predictions[0])
                    region['classification_confidence'] = float(confidences[0])
                else:
                    # Low confidence - discard
                    continue
            else:
                region['forgery_type'] = 'unknown'
                region['classification_confidence'] = region['confidence']

            # Keep only the JSON-friendly fields (drops 'region_mask' etc.)
            results.append({
                'region_id': region['region_id'],
                'bounding_box': region['bounding_box'],
                'forgery_type': region['forgery_type'],
                'confidence': region['confidence'],
                'classification_confidence': region['classification_confidence'],
                'mask_probability_mean': region['mask_probability_mean'],
                'area': region['area']
            })
        print(f"Valid regions after filtering: {len(results)}")

        # 8. Post-processing - False positive removal
        results = self._post_process(results)

        # 9. Generate output (force plain ints so the JSON dump never chokes
        # on numpy dimensions)
        output = {
            'input_path': str(input_path),
            'original_size': tuple(int(v) for v in original_size),
            'num_regions': len(results),
            'regions': results,
            'is_tampered': len(results) > 0
        }

        if output_dir:
            self._save_outputs(output, output_dir, input_path,
                               original, binary_mask, results)

        return output

    def _save_outputs(self, output: Dict, output_dir: str, input_path: str,
                      original: np.ndarray, binary_mask: np.ndarray,
                      results: List[Dict]):
        """Write mask PNG, overlay PNG and results JSON; record paths in `output`."""
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        input_name = Path(input_path).stem

        # Save final mask -- cast so imwrite accepts bool/float refined masks
        mask_path = output_path / f'{input_name}_mask.png'
        cv2.imwrite(str(mask_path), (binary_mask > 0).astype(np.uint8) * 255)

        # Save overlay visualization
        overlay = self._create_overlay(original, binary_mask, results)
        overlay_path = output_path / f'{input_name}_overlay.png'
        cv2.imwrite(str(overlay_path), cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))

        # Save JSON (numpy-aware fallback for values from the region extractor)
        json_path = output_path / f'{input_name}_results.json'
        with open(json_path, 'w') as f:
            json.dump(output, f, indent=2, default=self._json_default)

        print(f"\nOutputs saved to: {output_path}")
        output['mask_path'] = str(mask_path)
        output['overlay_path'] = str(overlay_path)
        output['json_path'] = str(json_path)

    @staticmethod
    def _json_default(obj):
        """JSON fallback: convert numpy scalars/arrays to native Python types."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable")

    def _post_process(self, regions: List[Dict]) -> List[Dict]:
        """
        Post-process regions to remove false positives.

        Args:
            regions: List of region dictionaries

        Returns:
            Regions whose 'confidence' meets the configured threshold
        """
        return [r for r in regions if r['confidence'] >= self.confidence_threshold]

    def _create_overlay(self,
                        image: np.ndarray,
                        mask: np.ndarray,
                        regions: List[Dict]) -> np.ndarray:
        """
        Create visualization overlay.

        Args:
            image: Original RGB image
            mask: Binary forgery mask (any resolution)
            regions: Detected regions with bounding boxes and labels

        Returns:
            Overlay image with red forgery tint, green boxes and labels
        """
        alpha = self.config.get('outputs.visualization.overlay_alpha', 0.5)

        # FIX: resize the mask to image resolution BEFORE using it to index --
        # the original indexed with the unresized mask, which breaks when mask
        # and image sizes differ. INTER_NEAREST keeps the mask strictly binary
        # (bilinear would create gray edge pixels that count as > 0).
        mask_resized = cv2.resize(mask, (image.shape[1], image.shape[0]),
                                  interpolation=cv2.INTER_NEAREST)

        mask_colored = np.zeros_like(image)
        mask_colored[mask_resized > 0] = [255, 0, 0]  # red = forged area

        # Blend the red tint only where the mask is positive
        overlay = np.where(
            mask_resized[:, :, None] > 0,
            (1 - alpha) * image + alpha * mask_colored,
            image
        ).astype(np.uint8)

        # Draw bounding boxes and labels
        for region in regions:
            x, y, w, h = region['bounding_box']
            cv2.rectangle(overlay, (x, y), (x + w, y + h), (0, 255, 0), 2)
            label = f"{region['forgery_type']} ({region['confidence']:.2f})"
            cv2.putText(overlay, label, (x, y - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

        return overlay
def get_pipeline(config,
                 model_path: str,
                 classifier_path: Optional[str] = None,
                 is_text_document: bool = True) -> ForgeryDetectionPipeline:
    """Build a ForgeryDetectionPipeline from a config and checkpoint paths."""
    return ForgeryDetectionPipeline(
        config=config,
        model_path=model_path,
        classifier_path=classifier_path,
        is_text_document=is_text_document,
    )