""" DTD DocTamper Inference Module Simplified for Gradio deployment on Hugging Face Uses scipy DCT as alternative to jpegio (which has compilation issues) """ import os import sys import cv2 import torch import numpy as np import pickle import tempfile from PIL import Image from scipy.fftpack import dct import torchvision.transforms as transforms # Add models directory to path to avoid circular imports sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'models')) # Import compatibility fixes and model import fix_imports import patch_gelu import patch_droppath from dtd import seg_dtd class DTDPredictor: def __init__(self, checkpoint_path='checkpoints/dtd_doctamper.pth', device='cpu'): """ Initialize DTD model for inference Args: checkpoint_path: Path to model checkpoint device: Device to use ('cpu', 'cuda', or 'auto') """ # Auto-detect device if device == 'auto': if torch.cuda.is_available(): self.device = 'cuda' else: self.device = 'cpu' else: self.device = device print(f'Using device: {self.device}') # Load QT table with open('checkpoints/qt_table.pk', 'rb') as fpk: pks = pickle.load(fpk) self.pks = {} for k, v in pks.items(): self.pks[k] = torch.LongTensor(v) # Image transforms self.transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=(0.485, 0.455, 0.406), std=(0.229, 0.224, 0.225)) ]) # Load model self.model = seg_dtd('', 2, device=self.device) if self.device == 'cuda': self.model = self.model.cuda() self.model = self.model.to(self.device) # Load checkpoint ckpt = torch.load(checkpoint_path, map_location='cpu') state_dict = ckpt['state_dict'] # Remove 'module.' prefix if present new_state_dict = {} for k, v in state_dict.items(): if k.startswith('module.'): new_state_dict[k[7:]] = v else: new_state_dict[k] = v self.model.load_state_dict(new_state_dict) self.model.eval() print('Model loaded successfully!') def extract_dct(self, image_array, quality=90): """ Extract DCT coefficients using scipy (alternative to jpegio) Uses Y channel from YCbCr like actual JPEG compression NOTE: This extracts UNQUANTIZED DCT from raw image data, which differs from jpegio that reads quantized DCT from JPEG files. The model was trained on quantized coefficients, so results may vary. Args: image_array: RGB image as numpy array (must be 8x8 aligned) quality: JPEG quality (used for QT approximation) Returns: DCT coefficients and quantization table """ # Convert RGB to YCbCr and extract Y channel (luminance) # This matches what JPEG compression does im_ycbcr = cv2.cvtColor(image_array, cv2.COLOR_RGB2YCrCb) y_channel = im_ycbcr[:, :, 0].astype(np.float32) - 128 # Center around 0 # Image should already be 8x8 aligned h, w = y_channel.shape assert h % 8 == 0 and w % 8 == 0, f"Image must be 8x8 aligned, got {h}x{w}" # Compute DCT for each 8x8 block dct_coeffs = np.zeros((h, w), dtype=np.float32) for i in range(0, h, 8): for j in range(0, w, 8): block = y_channel[i:i+8, j:j+8] dct_block = dct(dct(block.T, norm='ortho').T, norm='ortho') dct_coeffs[i:i+8, j:j+8] = dct_block # Generate standard JPEG quantization table for given quality qt = self.jpeg_quantization_table(quality) return dct_coeffs, qt def jpeg_quantization_table(self, quality): """ Generate standard JPEG quantization table for given quality Args: quality: JPEG quality (1-100) Returns: Flattened quantization table (64 values) """ # Standard JPEG luminance quantization table (quality 50) base_table = np.array([ 16, 11, 10, 16, 24, 40, 51, 61, 12, 12, 14, 19, 26, 58, 60, 55, 14, 13, 16, 24, 40, 57, 69, 56, 14, 17, 22, 29, 51, 87, 80, 62, 18, 22, 37, 56, 68, 109, 103, 77, 24, 35, 55, 64, 81, 104, 113, 92, 49, 64, 78, 87, 103, 121, 120, 101, 72, 92, 95, 98, 112, 100, 103, 99 ]) # Scale based on quality if quality < 50: scale = 5000 / quality else: scale = 200 - 2 * quality qt = np.floor((base_table * scale + 50) / 100) qt = np.clip(qt, 1, 255) return qt.astype(int) def crop_image(self, img, dct, crop_size=512): """ Crop image and DCT into 512x512 patches WITHOUT resize Preserves Block Artifact Grids (BAGs) """ h, w = img.shape[:2] h_grids = h // crop_size w_grids = w // crop_size crops = [] positions = [] # Regular grid for h_idx in range(h_grids): for w_idx in range(w_grids): y1 = h_idx * crop_size x1 = w_idx * crop_size y2 = y1 + crop_size x2 = x1 + crop_size crops.append({ 'img': img[y1:y2, x1:x2], 'dct': dct[y1:y2, x1:x2] }) positions.append((y1, x1, y2, x2)) # Right edge (overlapping) if w % crop_size != 0: for h_idx in range(h_grids): y1 = h_idx * crop_size y2 = y1 + crop_size x1 = w - crop_size x2 = w crops.append({ 'img': img[y1:y2, x1:x2], 'dct': dct[y1:y2, x1:x2] }) positions.append((y1, x1, y2, x2)) # Bottom edge (overlapping) if h % crop_size != 0: for w_idx in range(w_grids): x1 = w_idx * crop_size x2 = x1 + crop_size y1 = h - crop_size y2 = h crops.append({ 'img': img[y1:y2, x1:x2], 'dct': dct[y1:y2, x1:x2] }) positions.append((y1, x1, y2, x2)) # Bottom-right corner (overlapping) if w % crop_size != 0 and h % crop_size != 0: crops.append({ 'img': img[h-crop_size:h, w-crop_size:w], 'dct': dct[h-crop_size:h, w-crop_size:w] }) positions.append((h-crop_size, w-crop_size, h, w)) return crops, positions, h_grids, w_grids @torch.no_grad() def predict(self, image_path, quality=90): """ Predict tampering mask for input image Uses patch-based processing WITHOUT resize to preserve BAGs Args: image_path: Path to input JPEG image quality: JPEG quality for DCT extraction Returns: Dictionary containing: - original: Original image (numpy array) - mask: Binary tampering mask (numpy array) - heatmap: Colorized heatmap overlay """ # Load original image and save original dimensions im_orig = Image.open(image_path).convert('RGB') im_orig_np = np.array(im_orig) true_orig_h, true_orig_w = im_orig_np.shape[:2] # Align to 8x8 grid to preserve JPEG block structure h_aligned = (true_orig_h // 8) * 8 w_aligned = (true_orig_w // 8) * 8 # Pad if too small (for processing only) img_to_process = im_orig_np.copy() if h_aligned < 512 or w_aligned < 512: pad_h = max(512 - h_aligned, 0) pad_w = max(512 - w_aligned, 0) img_to_process = np.pad(img_to_process, ((0, pad_h), (0, pad_w), (0, 0)), 'constant', constant_values=255) h_aligned = (img_to_process.shape[0] // 8) * 8 w_aligned = (img_to_process.shape[1] // 8) * 8 img_aligned = img_to_process[:h_aligned, :w_aligned] # Extract DCT from raw aligned image (NO JPEG save/load!) dct, qt = self.extract_dct(img_aligned, quality) # Ensure dimensions match dct_aligned = dct[:h_aligned, :w_aligned] # Crop into 512x512 patches crops, positions, h_grids, w_grids = self.crop_image(img_aligned, dct_aligned) # Process each patch full_mask = np.zeros((h_aligned, w_aligned), dtype=np.float32) count_map = np.zeros((h_aligned, w_aligned), dtype=np.int32) for crop, (y1, x1, y2, x2) in zip(crops, positions): # Prepare inputs img_pil = Image.fromarray(crop['img']) img_tensor = self.transform(img_pil).unsqueeze(0).to(self.device) dct_tensor = torch.from_numpy( np.clip(np.abs(crop['dct']), 0, 20).astype(np.int64) ).unsqueeze(0).long().to(self.device) qt_flat = qt.flatten()[:64] qt_tensor = torch.LongTensor(qt_flat).unsqueeze(0).to(self.device) qt_tensor = qt_tensor.reshape(1, 1, 8, 8) # Forward pass output = self.model(img_tensor, dct_tensor, qt_tensor) pred_mask = output.argmax(1).squeeze().cpu().numpy() # Accumulate predictions (averaging overlapping regions) full_mask[y1:y2, x1:x2] += pred_mask count_map[y1:y2, x1:x2] += 1 # Average overlapping regions full_mask = full_mask / np.maximum(count_map, 1) final_mask = (full_mask > 0.5).astype(np.uint8) # Pad mask back to true original size if it was 8x8 aligned smaller if final_mask.shape[0] < true_orig_h or final_mask.shape[1] < true_orig_w: padded_mask = np.zeros((true_orig_h, true_orig_w), dtype=np.uint8) padded_mask[:final_mask.shape[0], :final_mask.shape[1]] = final_mask final_mask = padded_mask else: # Crop if somehow larger (shouldn't happen) final_mask = final_mask[:true_orig_h, :true_orig_w] # Create heatmap overlay with original image (no padding) heatmap = self.create_heatmap(im_orig_np, final_mask) return { 'original': im_orig_np, 'mask': (final_mask * 255).astype(np.uint8), 'heatmap': heatmap } def create_heatmap(self, image, mask): """ Create colorized heatmap overlay Args: image: Original image (numpy array) mask: Binary mask (numpy array) Returns: Heatmap overlay (numpy array) """ # Create colored mask colored_mask = np.zeros_like(image) colored_mask[mask == 1] = [255, 0, 0] # Red for tampered regions # Blend with original image alpha = 0.5 heatmap = cv2.addWeighted(image, 1 - alpha, colored_mask, alpha, 0) return heatmap