# Provenance: Hugging Face upload ("Upload 54 files", commit ff0e79e) by JKrishnanandhaa
"""
Inference pipeline for document forgery detection
Complete pipeline: Image → Localization → Regions → Classification → Output
"""
import cv2
import numpy as np
import torch
from typing import Dict, List, Optional, Tuple
from pathlib import Path
import json
from PIL import Image
import fitz # PyMuPDF
from ..config import get_config
from ..models import get_model
from ..features import (
get_feature_extractor,
get_mask_refiner,
get_region_extractor
)
from ..training.classifier import get_classifier
class ForgeryDetectionPipeline:
    """
    Complete inference pipeline for document forgery detection.

    Pipeline stages:
        1. Input handling (PDF/Image)
        2. Preprocessing
        3. Deep localization
        4. Mask refinement
        5. Region extraction
        6. Feature extraction
        7. Classification
        8. Post-processing
        9. Output generation
    """

    def __init__(self,
                 config,
                 model_path: str,
                 classifier_path: Optional[str] = None,
                 is_text_document: bool = True):
        """
        Initialize pipeline.

        Args:
            config: Configuration object (must expose ``get(key, default)``)
            model_path: Path to localization model checkpoint
            classifier_path: Path to classifier (optional; without it all
                regions are labeled ``'unknown'``)
            is_text_document: Whether input is a text document (for OCR features)
        """
        self.config = config
        self.is_text_document = is_text_document

        # Use CUDA only when both the hardware and the config ask for it.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() and config.get('system.device') == 'cuda'
            else 'cpu'
        )
        print(f"Inference device: {self.device}")

        # Load localization model and freeze it in eval mode for inference.
        self.model = get_model(config).to(self.device)
        self._load_model(model_path)
        self.model.eval()

        # Post-localization components (project-level factories).
        self.mask_refiner = get_mask_refiner(config, 'default')
        self.region_extractor = get_region_extractor(config, 'default')
        self.feature_extractor = get_feature_extractor(config, is_text_document)

        # Classifier is optional.
        if classifier_path:
            self.classifier = get_classifier(config)
            self.classifier.load(classifier_path)
        else:
            self.classifier = None

        # Minimum region confidence kept by post-processing.
        self.confidence_threshold = config.get('classifier.confidence_threshold', 0.6)

        # Square input resolution expected by the localization model.
        self.image_size = config.get('data.image_size', 384)

        print("Inference pipeline initialized")

    def _load_model(self, model_path: str):
        """Load model weights from a checkpoint file.

        Accepts either a full training checkpoint (with a
        ``'model_state_dict'`` key) or a bare state dict.
        """
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources (consider weights_only=True on
        # torch versions that support it).
        checkpoint = torch.load(model_path, map_location=self.device)

        if 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.model.load_state_dict(checkpoint)

        print(f"Loaded model from {model_path}")

    def _load_image(self, input_path: str) -> np.ndarray:
        """
        Load an image from an image file or the first page of a PDF.

        Args:
            input_path: Path to image or PDF

        Returns:
            RGB image as numpy array (H, W, 3)

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        path = Path(input_path)

        if path.suffix.lower() == '.pdf':
            # Rasterize the first page at 300 DPI (PDF space is 72 DPI).
            doc = fitz.open(str(path))
            try:
                page = doc[0]
                mat = fitz.Matrix(300 / 72, 300 / 72)  # 300 DPI
                pix = page.get_pixmap(matrix=mat)
                image = np.frombuffer(pix.samples, dtype=np.uint8)
                image = image.reshape(pix.height, pix.width, pix.n)
                if pix.n == 4:
                    image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
                elif pix.n == 1:
                    # Grayscale page: expand to 3 channels so downstream
                    # code always sees (H, W, 3).
                    image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
            finally:
                doc.close()
        else:
            image = cv2.imread(str(path))
            if image is None:
                # cv2.imread returns None on unreadable/missing files
                # instead of raising — fail loudly here.
                raise FileNotFoundError(f"Could not read image: {input_path}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        return image

    def _preprocess(self, image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Preprocess image for inference.

        Args:
            image: Input image (H, W, 3)

        Returns:
            Tuple of (preprocessed image resized to the model resolution and
            scaled to [0, 1], untouched copy of the original image).
        """
        original = image.copy()

        # Resize to the model's square input resolution.
        preprocessed = cv2.resize(image, (self.image_size, self.image_size))

        # Normalize to [0, 1]
        preprocessed = preprocessed.astype(np.float32) / 255.0

        return preprocessed, original

    def _to_tensor(self, image: np.ndarray) -> torch.Tensor:
        """Convert an (H, W, C) float image to a (1, C, H, W) device tensor."""
        # (H, W, C) -> (C, H, W)
        tensor = torch.from_numpy(image.transpose(2, 0, 1))
        tensor = tensor.unsqueeze(0)  # Add batch dimension
        return tensor.to(self.device)

    def run(self,
            input_path: str,
            output_dir: Optional[str] = None) -> Dict:
        """
        Run the full inference pipeline on a single document.

        Args:
            input_path: Path to input image or PDF
            output_dir: Optional directory for mask/overlay/JSON artifacts

        Returns:
            Result dictionary; when ``output_dir`` is given it also carries
            ``mask_path``, ``overlay_path`` and ``json_path``.
        """
        print(f"\n{'='*60}")
        print(f"Processing: {input_path}")
        print(f"{'='*60}")

        # 1. Load image
        image = self._load_image(input_path)
        original_size = image.shape[:2]
        print(f"Input size: {original_size}")

        # 2. Preprocess
        preprocessed, original = self._preprocess(image)
        tensor = self._to_tensor(preprocessed)

        # 3. Deep localization (single forward pass, no gradients)
        with torch.no_grad():
            logits, decoder_features = self.model(tensor)
            probability_map = torch.sigmoid(logits).cpu().numpy()[0, 0]
        print(f"Localization complete. Max prob: {probability_map.max():.3f}")

        # 4. Mask refinement back to the original resolution
        binary_mask = self.mask_refiner.refine(probability_map, original_size)
        num_positive_pixels = int(binary_mask.sum())
        print(f"Mask refinement: {num_positive_pixels} positive pixels")

        # 5. Region extraction — resize the probability map to original
        # coordinates so per-region confidence is aggregated in image space.
        prob_resized = cv2.resize(probability_map, (original_size[1], original_size[0]))
        regions = self.region_extractor.extract(binary_mask, prob_resized, original)
        print(f"Regions extracted: {len(regions)}")

        # 6. Feature extraction & 7. Classification
        results = []
        for region in regions:
            # Extract features for this region at model resolution.
            features = self.feature_extractor.extract(
                preprocessed,
                cv2.resize(region['region_mask'], (self.image_size, self.image_size)),
                [f.cpu() for f in decoder_features]
            )

            if self.classifier is not None:
                predictions, confidences, valid_mask = self.classifier.predict_with_filtering(
                    features.reshape(1, -1)
                )
                if valid_mask[0]:
                    region['forgery_type'] = self.classifier.get_class_name(predictions[0])
                    region['classification_confidence'] = float(confidences[0])
                else:
                    # Low classifier confidence — discard this region.
                    continue
            else:
                # No classifier loaded: keep the region with a generic label.
                region['forgery_type'] = 'unknown'
                region['classification_confidence'] = region['confidence']

            # Keep only serializable fields, casting numpy scalars to
            # builtin types so json.dump below cannot fail on them.
            region_result = {
                'region_id': int(region['region_id']),
                'bounding_box': [int(v) for v in region['bounding_box']],
                'forgery_type': region['forgery_type'],
                'confidence': float(region['confidence']),
                'classification_confidence': float(region['classification_confidence']),
                'mask_probability_mean': float(region['mask_probability_mean']),
                'area': int(region['area'])
            }
            results.append(region_result)

        print(f"Valid regions after filtering: {len(results)}")

        # 8. Post-processing — confidence-based false positive removal
        results = self._post_process(results)

        # 9. Generate output
        output = {
            'input_path': str(input_path),
            'original_size': list(original_size),
            'num_regions': len(results),
            'regions': results,
            'is_tampered': len(results) > 0
        }

        # Save artifacts if an output directory was provided.
        if output_dir:
            output_path = Path(output_dir)
            output_path.mkdir(parents=True, exist_ok=True)
            input_name = Path(input_path).stem

            # Final binary mask as an 8-bit PNG.
            mask_path = output_path / f'{input_name}_mask.png'
            cv2.imwrite(str(mask_path), binary_mask * 255)

            # Overlay visualization (converted to BGR for OpenCV writing).
            overlay = self._create_overlay(original, binary_mask, results)
            overlay_path = output_path / f'{input_name}_overlay.png'
            cv2.imwrite(str(overlay_path), cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))

            # Machine-readable results.
            json_path = output_path / f'{input_name}_results.json'
            with open(json_path, 'w') as f:
                json.dump(output, f, indent=2)

            print(f"\nOutputs saved to: {output_path}")

            output['mask_path'] = str(mask_path)
            output['overlay_path'] = str(overlay_path)
            output['json_path'] = str(json_path)

        return output

    def _post_process(self, regions: List[Dict]) -> List[Dict]:
        """
        Post-process regions to remove false positives.

        Args:
            regions: List of region dictionaries (must carry 'confidence')

        Returns:
            Regions whose confidence meets ``self.confidence_threshold``.
        """
        return [
            region for region in regions
            if region['confidence'] >= self.confidence_threshold
        ]

    def _create_overlay(self,
                        image: np.ndarray,
                        mask: np.ndarray,
                        regions: List[Dict]) -> np.ndarray:
        """
        Create a visualization overlay: red tint on forged pixels plus green
        bounding boxes and labels for each detected region.

        Args:
            image: Original RGB image
            mask: Binary mask (any resolution; resized to the image here)
            regions: Detected regions (with 'bounding_box', 'forgery_type',
                'confidence')

        Returns:
            Overlay image (uint8, same size as ``image``).
        """
        alpha = self.config.get('outputs.visualization.overlay_alpha', 0.5)

        # Bring the mask to image resolution BEFORE indexing, so the colored
        # mask and the blend share the same geometry (indexing with an
        # unresized mask breaks whenever the sizes differ).
        mask_resized = cv2.resize(mask, (image.shape[1], image.shape[0]),
                                  interpolation=cv2.INTER_NEAREST)

        # Colored mask: red wherever the mask marks forgery.
        mask_colored = np.zeros_like(image)
        mask_colored[mask_resized > 0] = [255, 0, 0]

        # Alpha-blend the red tint onto forged pixels only.
        overlay = np.where(
            mask_resized[:, :, None] > 0,
            (1 - alpha) * image + alpha * mask_colored,
            image
        ).astype(np.uint8)

        # Draw bounding boxes and labels (cv2 drawing requires ints).
        for region in regions:
            x, y, w, h = (int(v) for v in region['bounding_box'])
            cv2.rectangle(overlay, (x, y), (x + w, y + h), (0, 255, 0), 2)
            label = f"{region['forgery_type']} ({region['confidence']:.2f})"
            cv2.putText(overlay, label, (x, y - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

        return overlay
def get_pipeline(config,
                 model_path: str,
                 classifier_path: Optional[str] = None,
                 is_text_document: bool = True) -> ForgeryDetectionPipeline:
    """Build and return a ready-to-use forgery detection pipeline."""
    return ForgeryDetectionPipeline(
        config=config,
        model_path=model_path,
        classifier_path=classifier_path,
        is_text_document=is_text_document,
    )