"""Test script for evaluating a trained DetectiveSAM forgery localizer model on a single image.""" from __future__ import annotations import argparse import json import logging import os import sys import tempfile import shutil from typing import Dict, Any, Tuple from pathlib import Path import cv2 import matplotlib.pyplot as plt import numpy as np import torch import torch.nn.functional as F from PIL import Image from model.forgerylocalizer import ForgeryLocalizer from utils.localforgerydataset import LocalForgeryDataset from utils.sam_utils import initialize_sam_hydra, get_sam_config_from_json # Initialize Hydra configuration for SAM2 initialize_sam_hydra() logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) def normalize_for_display(tensor): """Normalize tensor for display (0-1 range)""" tensor = tensor.clone() if tensor.dim() == 3: # [C, H, W] for c in range(tensor.shape[0]): channel = tensor[c] channel = (channel - channel.min()) / (channel.max() - channel.min() + 1e-8) tensor[c] = channel else: # [H, W] tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min() + 1e-8) return tensor.clamp(0, 1) def load_model_config(json_path: str) -> Dict[str, Any]: """Load model configuration from JSON file. Args: json_path: Path to the JSON configuration file Returns: Dictionary containing model configuration """ with open(json_path, 'r') as f: config = json.load(f) return config def get_max_streams(contrastive_blur: bool, perturbation_type: str) -> int: """Determine the maximum number of streams based on contrastive mode and perturbation type. 
Args: contrastive_blur: Whether contrastive mode is enabled perturbation_type: Type of perturbation being applied Returns: Maximum number of streams the model should expect """ if contrastive_blur: # Contrastive mode: always has streams (sharp/clean + perturbations) if perturbation_type == "none": return 0 # [clean, clean] (identical for consistency) elif perturbation_type in ["gaussian_blur", "jpeg_compression", "gaussian_noise"]: return 2 # [sharp/clean, perturbation] elif perturbation_type == "gaussian_blur/gaussian_noise": return 3 # [sharp, blur, noise] else: # Non-contrastive mode: only perturbations in streams if perturbation_type == "none": return 0 # Empty streams list - model will use orig as fallback elif perturbation_type in ["gaussian_blur", "jpeg_compression", "gaussian_noise"]: return 1 # Single perturbation elif perturbation_type == "gaussian_blur/gaussian_noise": return 2 # Two perturbations [blur, noise] return 0 def load_and_initialize_model( config: Dict[str, Any], checkpoint_path: str, device: torch.device ) -> ForgeryLocalizer: """Load and initialize the forgery localizer model. 
Args: config: Model configuration dictionary checkpoint_path: Path to the model checkpoint (.pth file) device: Device to load the model on Returns: Initialized ForgeryLocalizer model """ model_config = config['model_config'] sam_config_dict = config['sam_config'] data_config = config['data_config'] # Get max_streams from configuration contrastive_blur = model_config.get('contrastive_blur', False) perturbation_type = data_config.get('perturbation_type', 'none') max_streams = get_max_streams(contrastive_blur, perturbation_type) # Determine use_detection_probe based on authentic_ratio authentic_ratio = data_config.get('authentic_ratio', 0.0) use_detection_probe = authentic_ratio > 0.0 logger.info(f"Initializing model with max_streams={max_streams}, use_detection_probe={use_detection_probe}") # Get the directory of the test script to resolve relative paths script_dir = os.path.dirname(os.path.abspath(__file__)) # Resolve SAM config and checkpoint paths using utility function sam_config_file, sam_checkpoint = get_sam_config_from_json(config, script_dir) logger.info(f"SAM config: {sam_config_file}") logger.info(f"SAM checkpoint: {sam_checkpoint}") # Initialize model model = ForgeryLocalizer( sam_config=sam_config_file, sam_checkpoint=sam_checkpoint, prompt_dim=model_config['prompt_dim'], downscale=model_config['downscale'], train_sam_iou=model_config.get('train_sam_iou', True), dropout_rate=model_config['dropout_rate'], max_streams=max_streams, use_detection_probe=use_detection_probe, ).to(device) # Load checkpoint logger.info(f"Loading checkpoint from {checkpoint_path}") checkpoint_data = torch.load(checkpoint_path, map_location=device, weights_only=False) if "model" in checkpoint_data: missing_keys, unexpected_keys = model.load_state_dict(checkpoint_data["model"], strict=False) if missing_keys: logger.warning(f"Missing keys when loading checkpoint: {missing_keys}") if unexpected_keys: logger.warning(f"Unexpected keys when loading checkpoint: {unexpected_keys}") else: # 
Checkpoint might be just the state dict missing_keys, unexpected_keys = model.load_state_dict(checkpoint_data, strict=False) if missing_keys: logger.warning(f"Missing keys when loading checkpoint: {missing_keys}") if unexpected_keys: logger.warning(f"Unexpected keys when loading checkpoint: {unexpected_keys}") logger.info(f"Loaded checkpoint - epoch: {checkpoint_data.get('epoch', 'N/A')}, score: {checkpoint_data.get('score', 'N/A')}") # Set model to eval mode model.eval() model.encoder.eval() model.decoder.eval() model.sam_prompt_encoder.eval() return model def create_temp_dataset_structure( image_path: str, mask_path: str = None, source_path: str = None ) -> str: """Create a temporary dataset structure for LocalForgeryDataset. Args: image_path: Path to the input image (target) mask_path: Optional path to ground truth mask source_path: Optional path to source image (original unedited) Returns: Path to temporary dataset directory """ # Create temporary directory temp_dir = tempfile.mkdtemp(prefix="detective_sam_test_") # Create dataset structure target_dir = os.path.join(temp_dir, "target") mask_dir = os.path.join(temp_dir, "mask") source_dir = os.path.join(temp_dir, "source") os.makedirs(target_dir, exist_ok=True) os.makedirs(mask_dir, exist_ok=True) os.makedirs(source_dir, exist_ok=True) # Copy image to target directory with a standard name img_name = "test_image.png" shutil.copy(image_path, os.path.join(target_dir, img_name)) # Copy or create dummy mask if mask_path and os.path.exists(mask_path): shutil.copy(mask_path, os.path.join(mask_dir, img_name)) else: # Create a dummy mask (all zeros) if no mask provided img = Image.open(image_path) dummy_mask = Image.new('L', img.size, 0) dummy_mask.save(os.path.join(mask_dir, img_name)) # Copy source image or use target as source if source_path and os.path.exists(source_path): shutil.copy(source_path, os.path.join(source_dir, img_name)) else: # Use target image as source if no source provided shutil.copy(image_path, 
os.path.join(source_dir, img_name)) return temp_dir def load_sample_from_dataset( image_path: str, mask_path: str, source_path: str, img_size: int, perturbation_type: str, perturbation_intensity: float, contrastive_blur: bool ) -> Tuple[torch.Tensor, list, torch.Tensor, torch.Tensor, np.ndarray, np.ndarray, np.ndarray]: """Load and preprocess a sample using LocalForgeryDataset. Args: image_path: Path to input image (target) mask_path: Path to ground truth mask (optional) source_path: Path to source image (original unedited) img_size: Target image size perturbation_type: Type of perturbation to apply perturbation_intensity: Intensity of perturbation contrastive_blur: Whether to use contrastive blur Returns: Tuple of (orig_tensor, streams_list, mask_tensor, source_tensor, original_image_np, source_image_np, mask_np) """ # Create temporary dataset structure temp_dir = create_temp_dataset_structure(image_path, mask_path, source_path) try: # Create dataset dataset = LocalForgeryDataset( root_dir=temp_dir, img_size=img_size, allow_multiple_targets=False, contrastive_blur=contrastive_blur, is_training=False, # Use validation mode (center crop) perturbation_type=perturbation_type, perturbation_intensity=perturbation_intensity, authentic_ratio=0.0, authentic_source_dir=None, ) # Get the single sample if len(dataset) == 0: raise ValueError("Dataset is empty - check image paths") sample = dataset[0] # Extract data from sample orig_tensor = sample['orig'] # [3, H, W] - target streams = sample['streams'] # List of [3, H, W] tensors mask_tensor = sample['mask'] # [1, H, W] source_tensor = sample['source'] # [3, H, W] - source # Load original image for visualization orig_img = np.array(Image.open(image_path).convert('RGB')) # Convert source tensor to numpy for visualization using proper normalization source_normalized = normalize_for_display(source_tensor) source_img = source_normalized.permute(1, 2, 0).cpu().numpy() # Convert to 0-255 range for display source_img = (source_img * 
255).clip(0, 255).astype(np.uint8) # Load ground truth mask for visualization if mask_path and os.path.exists(mask_path): mask_img = np.array(Image.open(mask_path).convert('L')) else: mask_img = np.zeros(orig_img.shape[:2], dtype=np.uint8) return orig_tensor, streams, mask_tensor, source_tensor, orig_img, source_img, mask_img finally: # Clean up temporary directory shutil.rmtree(temp_dir, ignore_errors=True) def visualize_results( source_img: np.ndarray, target_img: np.ndarray, prediction: np.ndarray, ground_truth: np.ndarray = None, save_path: str = None, detection_prob: float = None ): """Visualize the prediction results. Args: source_img: Source image target_img: Target image prediction: Binary prediction mask ground_truth: Optional ground truth mask save_path: Path to save the visualization detection_prob: Optional detection probability """ fig, axes = plt.subplots(1, 4, figsize=(20, 5)) # Source image axes[0].imshow(source_img) axes[0].set_title('Source') axes[0].axis('off') # Target image axes[1].imshow(target_img) axes[1].set_title('Target') axes[1].axis('off') # Ground truth mask over target axes[2].imshow(target_img) if ground_truth is not None and ground_truth.max() > 0: gt_overlay = np.zeros((*ground_truth.shape, 4)) gt_overlay[ground_truth > 0] = [0, 1, 0, 0.5] # Green with 50% transparency axes[2].imshow(gt_overlay) axes[2].set_title('GT Mask over Target') axes[2].axis('off') # Prediction mask over target axes[3].imshow(target_img) if prediction.max() > 0: pred_overlay = np.zeros((*prediction.shape, 4)) pred_overlay[prediction > 0] = [1, 0, 0, 0.5] # Red with 50% transparency axes[3].imshow(pred_overlay) title = 'Prediction over Target' if detection_prob is not None: title += f'\n(Detection Prob: {detection_prob:.3f})' axes[3].set_title(title) axes[3].axis('off') plt.tight_layout() if save_path: plt.savefig(save_path, dpi=150, bbox_inches='tight') logger.info(f"Saved visualization to {save_path}") plt.show() def compute_metrics(prediction: np.ndarray, 
ground_truth: np.ndarray) -> Dict[str, float]: """Compute evaluation metrics. Args: prediction: Binary prediction mask ground_truth: Binary ground truth mask Returns: Dictionary of metrics """ pred_flat = prediction.flatten() gt_flat = ground_truth.flatten() # Compute IoU intersection = np.logical_and(pred_flat, gt_flat).sum() union = np.logical_or(pred_flat, gt_flat).sum() iou = intersection / (union + 1e-8) # Compute precision, recall, F1 tp = intersection fp = np.logical_and(pred_flat, ~gt_flat).sum() fn = np.logical_and(~pred_flat, gt_flat).sum() precision = tp / (tp + fp + 1e-8) recall = tp / (tp + fn + 1e-8) f1 = 2 * (precision * recall) / (precision + recall + 1e-8) return { 'iou': iou, 'precision': precision, 'recall': recall, 'f1': f1 } def main(): parser = argparse.ArgumentParser(description='Test DetectiveSAM forgery localizer on a single image') parser.add_argument('--image', type=str, required=True, help='Path to input image (target/edited image)') parser.add_argument('--source', type=str, default=None, help='Path to source image (original unedited image). 
If not provided, will use target image as source.') parser.add_argument('--model', type=str, required=True, help='Path to model checkpoint (.pth file)') parser.add_argument('--config', type=str, required=True, help='Path to model configuration JSON file') parser.add_argument('--mask', type=str, default=None, help='Optional path to ground truth mask for evaluation') parser.add_argument('--output', type=str, default='result.png', help='Path to save output visualization') parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to run inference on (cuda/cpu)') parser.add_argument('--threshold', type=float, default=0.5, help='Threshold for binary prediction (default: 0.5)') args = parser.parse_args() # Setup device device = torch.device(args.device) logger.info(f"Using device: {device}") # Load configuration logger.info(f"Loading configuration from {args.config}") config = load_model_config(args.config) # Load model model = load_and_initialize_model(config, args.model, device) # Get data configuration data_config = config['data_config'] training_config = config['training_config'] model_config = config['model_config'] img_size = training_config.get('img_size', 512) perturbation_type = data_config.get('perturbation_type', 'none') perturbation_intensity = data_config.get('perturbation_intensity', 0.5) contrastive_blur = model_config.get('contrastive_blur', False) logger.info(f"Configuration: img_size={img_size}, perturbation_type={perturbation_type}, " f"perturbation_intensity={perturbation_intensity}, contrastive_blur={contrastive_blur}") # Load and preprocess using LocalForgeryDataset for consistency with training logger.info(f"Loading image from {args.image} using LocalForgeryDataset") if args.source: logger.info(f"Using source image from {args.source}") orig_tensor, streams, mask_tensor, source_tensor, orig_img, source_img, mask_img = load_sample_from_dataset( image_path=args.image, mask_path=args.mask, 
source_path=args.source, img_size=img_size, perturbation_type=perturbation_type, perturbation_intensity=perturbation_intensity, contrastive_blur=contrastive_blur ) # Prepare batch orig_batch = orig_tensor.unsqueeze(0).to(device) # [1, 3, H, W] streams_batch = [s.unsqueeze(0).to(device) for s in streams] # List of [1, 3, H, W] logger.info(f"Running inference with {len(streams_batch)} stream(s)...") # Run inference with torch.no_grad(): with torch.amp.autocast(device_type=device.type): outputs = model(orig_batch, streams_batch, output_extras=True) if isinstance(outputs, tuple): logits, extras = outputs else: logits = outputs extras = {} # Get probability map probs = torch.sigmoid(logits) # Get detection probability if available detection_prob = None if 'detection_logit' in extras and extras['detection_logit'] is not None: detection_logit = extras['detection_logit'] detection_prob = torch.sigmoid(detection_logit).item() logger.info(f"Detection probability: {detection_prob:.4f}") # Convert to numpy probs_np = probs[0, 0].cpu().numpy() pred_binary = (probs_np > args.threshold).astype(np.uint8) logger.info(f"Prediction shape: {pred_binary.shape}") logger.info(f"Forgery coverage: {pred_binary.sum() / pred_binary.size * 100:.2f}%") # Process ground truth mask from dataset (already at correct size) ground_truth = None if args.mask: logger.info(f"Using ground truth mask from dataset") ground_truth = mask_tensor[0].cpu().numpy().astype(np.uint8) # [H, W] # Compute metrics on the 512x512 versions metrics = compute_metrics(pred_binary, ground_truth) logger.info(f"Metrics: IoU={metrics['iou']:.4f}, Precision={metrics['precision']:.4f}, " f"Recall={metrics['recall']:.4f}, F1={metrics['f1']:.4f}") # Resize prediction, GT, and source to original image size for visualization if orig_img.shape[:2] != pred_binary.shape: pred_resized = cv2.resize(pred_binary, (orig_img.shape[1], orig_img.shape[0]), interpolation=cv2.INTER_NEAREST) source_resized = cv2.resize(source_img, 
(orig_img.shape[1], orig_img.shape[0]), interpolation=cv2.INTER_LINEAR) if ground_truth is not None: gt_resized = cv2.resize(ground_truth, (orig_img.shape[1], orig_img.shape[0]), interpolation=cv2.INTER_NEAREST) else: gt_resized = None else: pred_resized = pred_binary source_resized = source_img gt_resized = ground_truth # Visualize results logger.info("Generating visualization...") visualize_results(source_resized, orig_img, pred_resized, gt_resized, args.output, detection_prob) logger.info("Done!") if __name__ == '__main__': main()
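
# Example invocation, kept as a comment so this file stays valid Python. The
# script name and all file paths below are hypothetical placeholders - adjust
# them to your own checkpoint/config layout:
#
#   python test_single_image.py \
#       --image samples/edited.png \
#       --source samples/original.png \
#       --mask samples/gt_mask.png \
#       --model checkpoints/best.pth \
#       --config checkpoints/config.json \
#       --output result.png \
#       --threshold 0.5
#
# --source and --mask are optional: without --source the target is reused as
# the source, and without --mask no metrics are computed.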