| """
|
| Dataset loaders for document forgery detection
|
| Implements Critical Fix #7: Image-level train/test splits
|
| """
|
|
|
| import os
|
| import lmdb
|
| import cv2
|
| import numpy as np
|
| import torch
|
| from torch.utils.data import Dataset
|
| from pathlib import Path
|
| from typing import Tuple, Optional, List
|
| import json
|
| from PIL import Image
|
|
|
| from .preprocessing import DocumentPreprocessor
|
| from .augmentation import DatasetAwareAugmentation
|
|
|
|
|
class DocTamperDataset(Dataset):
    """
    DocTamper dataset loader (LMDB-based).

    Samples are read from an LMDB store under the keys ``image-<idx>`` and
    ``label-<idx>`` (index zero-padded to 9 digits). An instance can expose
    only a contiguous slice ("chunk") of the store, which implements chunked
    loading for RAM constraints. The LMDB environment is opened lazily for
    multiprocessing compatibility.
    """

    # Number of consecutive indices tried before a request falls back to a
    # blank placeholder sample.
    _MAX_LOAD_ATTEMPTS = 10

    # Height and width (in pixels) of the blank placeholder sample.
    _PLACEHOLDER_SIZE = 384

    def __init__(self,
                 config,
                 split: str = 'train',
                 chunk_start: float = 0.0,
                 chunk_end: float = 1.0):
        """
        Initialize DocTamper dataset.

        Args:
            config: Configuration object; must provide
                ``get_dataset_config(name)`` returning a mapping with a
                ``'path'`` entry.
            split: 'train', 'val' or 'test'. 'val' and 'test' both read the
                testing LMDB; any other value reads the training LMDB.
            chunk_start: Start ratio for chunked training (0.0 to 1.0)
            chunk_end: End ratio for chunked training (0.0 to 1.0)

        Raises:
            FileNotFoundError: If the LMDB folder for the split is missing.
        """
        self.config = config
        self.split = split
        self.dataset_name = 'doctamper'

        dataset_config = config.get_dataset_config(self.dataset_name)
        self.data_path = Path(dataset_config['path'])

        # Unknown split names deliberately fall back to the training set,
        # exactly like 'train' itself.
        if split in ('val', 'test'):
            lmdb_folder = 'DocTamperV1-TestingSet'
        else:
            lmdb_folder = 'DocTamperV1-TrainingSet'

        self.lmdb_path = str(self.data_path / lmdb_folder)

        if not Path(self.lmdb_path).exists():
            raise FileNotFoundError(f"LMDB folder not found: {self.lmdb_path}")

        # Count the samples with a short-lived environment. It is closed even
        # if reading the statistics fails, so no handle is leaked.
        temp_env = lmdb.open(self.lmdb_path, readonly=True, lock=False)
        try:
            with temp_env.begin() as txn:
                # Samples are stored as image-/label- key pairs, hence the
                # division by two.
                self.length = txn.stat()['entries'] // 2
        finally:
            temp_env.close()

        # The long-lived environment is created on first use (see ``env``).
        self._env = None

        # Translate the chunk ratios into absolute sample indices.
        self.chunk_start = int(self.length * chunk_start)
        self.chunk_end = int(self.length * chunk_end)
        self.chunk_length = self.chunk_end - self.chunk_start

        print(f"DocTamper {split}: Total={self.length}, "
              f"Chunk=[{self.chunk_start}:{self.chunk_end}], "
              f"Length={self.chunk_length}")

        self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
        self.augmentation = DatasetAwareAugmentation(
            config,
            self.dataset_name,
            is_training=(split == 'train')
        )

    @property
    def env(self):
        """Lazy LMDB environment initialization for multiprocessing compatibility"""
        if self._env is None:
            self._env = lmdb.open(self.lmdb_path, readonly=True, lock=False,
                                  max_readers=32, readahead=False)
        return self._env

    def __len__(self) -> int:
        """Return the number of samples in the configured chunk."""
        return self.chunk_length

    def _read_sample(self, global_idx: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
        """Decode one (image, mask) pair, or return None if the image is missing or undecodable."""
        with self.env.begin() as txn:
            img_buf = txn.get(f'image-{global_idx:09d}'.encode())
            label_buf = txn.get(f'label-{global_idx:09d}'.encode())

            if img_buf is None:
                return None

            image = cv2.imdecode(np.frombuffer(img_buf, dtype=np.uint8),
                                 cv2.IMREAD_COLOR)
            if image is None:
                return None

            mask = None
            if label_buf is not None:
                mask = cv2.imdecode(np.frombuffer(label_buf, dtype=np.uint8),
                                    cv2.IMREAD_GRAYSCALE)

            # A missing or undecodable label is replaced by an all-zero mask.
            if mask is None:
                mask = np.zeros(image.shape[:2], dtype=np.uint8)

        return image, mask

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Get item from dataset.

        If the requested sample cannot be loaded, the following indices of the
        chunk (wrapping around) are tried instead, up to
        ``_MAX_LOAD_ATTEMPTS`` attempts in total. If all attempts fail, a
        blank placeholder sample is returned rather than raising.

        Args:
            idx: Index within chunk

        Returns:
            image: (3, H, W) tensor
            mask: (1, H, W) tensor
            metadata: Dictionary with additional info; ``'index'`` is the
                global LMDB index of the sample that was actually loaded.
        """
        original_idx = idx
        sample = None
        last_error = None

        for _ in range(self._MAX_LOAD_ATTEMPTS):
            global_idx = self.chunk_start + idx
            try:
                sample = self._read_sample(global_idx)
            except Exception as exc:
                # Best effort: remember the failure and try the next sample.
                last_error = exc

            if sample is not None:
                break

            idx = (idx + 1) % self.chunk_length

        if sample is None:
            # Reached whenever all attempts failed, whether through missing
            # keys, undecodable images or exceptions.
            reason = f" (last error: {last_error!r})" if last_error is not None else ""
            print(f"Warning: Could not load sample at idx {original_idx}, "
                  f"creating dummy sample{reason}")
            size = self._PLACEHOLDER_SIZE
            # NOTE(review): the placeholder image is float32 while decoded
            # images are uint8 - confirm the preprocessor accepts both.
            image = np.zeros((size, size, 3), dtype=np.float32)
            mask = np.zeros((size, size), dtype=np.uint8)
            # Report a global index, consistent with successfully loaded samples.
            global_idx = self.chunk_start + original_idx
        else:
            image, mask = sample

        image, mask = self.preprocessor(image, mask)

        augmented = self.augmentation(image, mask)
        image = augmented['image']
        mask = augmented['mask']

        metadata = {
            'dataset': self.dataset_name,
            'index': global_idx,
            'has_pixel_mask': True
        }

        return image, mask, metadata

    def __del__(self):
        """Close LMDB environment"""
        if hasattr(self, '_env') and self._env is not None:
            self._env.close()
            self._env = None
|
|
|
|
|
|
|
class RTMDataset(Dataset):
    """
    Real Text Manipulation dataset loader.

    Image ids are listed one per line in ``<split>.txt``; images are read from
    ``JPEGImages/<id>.jpg`` and masks from ``SegmentationClass/<id>.png``
    below the configured dataset path.
    """

    def __init__(self, config, split: str = 'train'):
        """
        Initialize RTM dataset.

        Args:
            config: Configuration object; must provide
                ``get_dataset_config(name)`` returning a mapping with a
                ``'path'`` entry.
            split: 'train' or 'test'
        """
        self.config = config
        self.split = split
        self.dataset_name = 'rtm'

        dataset_config = config.get_dataset_config(self.dataset_name)
        self.data_path = Path(dataset_config['path'])

        split_file = self.data_path / f'{split}.txt'
        self.image_ids = self._read_image_ids(split_file)

        self.images_dir = self.data_path / 'JPEGImages'
        self.masks_dir = self.data_path / 'SegmentationClass'

        print(f"RTM {split}: {len(self.image_ids)} images")

        self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
        self.augmentation = DatasetAwareAugmentation(
            config,
            self.dataset_name,
            is_training=(split == 'train')
        )

    @staticmethod
    def _read_image_ids(split_file) -> List[str]:
        """Read the image ids from a split file, skipping blank lines."""
        with open(split_file, 'r') as f:
            stripped = (line.strip() for line in f)
            # Blank lines (e.g. a trailing newline) would otherwise become
            # empty ids that point at non-existent files.
            return [image_id for image_id in stripped if image_id]

    def __len__(self) -> int:
        """Return the number of image ids in the split."""
        return len(self.image_ids)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Get item from dataset.

        Args:
            idx: Position in the split's id list.

        Returns:
            The augmented image and mask plus a metadata dictionary.

        Raises:
            FileNotFoundError: If the image or its mask cannot be read.
        """
        image_id = self.image_ids[idx]

        img_path = self.images_dir / f'{image_id}.jpg'
        image = cv2.imread(str(img_path))
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f"RTM image could not be read: {img_path}")

        mask_path = self.masks_dir / f'{image_id}.png'
        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"RTM mask could not be read: {mask_path}")

        # Binarize: every non-zero pixel becomes 1.
        mask = (mask > 0).astype(np.uint8)

        image, mask = self.preprocessor(image, mask)

        augmented = self.augmentation(image, mask)
        image = augmented['image']
        mask = augmented['mask']

        metadata = {
            'dataset': self.dataset_name,
            'image_id': image_id,
            'has_pixel_mask': True
        }

        return image, mask, metadata
|
|
|
|
|
class CASIADataset(Dataset):
    """
    CASIA v1.0 dataset loader.

    Image-level labels only (no pixel masks): images under ``Au`` are labelled
    0 and images under ``Tp`` are labelled 1, and the returned mask is that
    label broadcast over the whole image.
    Implements Critical Fix #6: CASIA image-level handling
    """

    # Seed of the private RNG used for the train/test permutation.
    _SPLIT_SEED = 42

    # Fraction of the shuffled samples assigned to the 'train' split.
    _TRAIN_FRACTION = 0.8

    def __init__(self, config, split: str = 'train'):
        """
        Initialize CASIA dataset.

        Args:
            config: Configuration object; must provide
                ``get_dataset_config(name)`` returning a mapping with a
                ``'path'`` entry.
            split: 'train' selects the first 80% of the shuffled samples; any
                other value selects the remaining 20%.
        """
        self.config = config
        self.split = split
        self.dataset_name = 'casia'

        dataset_config = config.get_dataset_config(self.dataset_name)
        self.data_path = Path(dataset_config['path'])

        self.authentic_dir = self.data_path / 'Au'
        self.tampered_dir = self.data_path / 'Tp'

        # glob() order is not guaranteed, so the paths are sorted to make the
        # split below reproducible across runs and machines.
        # NOTE(review): split membership may differ from earlier runs that
        # relied on the unsorted directory order.
        authentic_images = sorted(
            list(self.authentic_dir.glob('*.jpg')) +
            list(self.authentic_dir.glob('*.png'))
        )
        tampered_images = sorted(
            list(self.tampered_dir.glob('*.jpg')) +
            list(self.tampered_dir.glob('*.png'))
        )

        # Label 0 = authentic, 1 = tampered.
        all_samples = [(img_path, 0) for img_path in authentic_images]
        all_samples += [(img_path, 1) for img_path in tampered_images]

        self.samples = self._select_split(all_samples, split)

        print(f"CASIA {split}: {len(self.samples)} images")

        self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
        self.augmentation = DatasetAwareAugmentation(
            config,
            self.dataset_name,
            is_training=(split == 'train')
        )

    @classmethod
    def _select_split(cls, samples: list, split: str) -> list:
        """Return the samples of one split using a fixed-seed permutation."""
        # A private RandomState yields the same permutation as seeding the
        # global NumPy RNG, without resetting that global state for the rest
        # of the process.
        order = np.random.RandomState(cls._SPLIT_SEED).permutation(len(samples))
        cut = int(len(samples) * cls._TRAIN_FRACTION)
        chosen = order[:cut] if split == 'train' else order[cut:]
        return [samples[i] for i in chosen]

    def __len__(self) -> int:
        """Return the number of samples in the split."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Get item from dataset.

        Args:
            idx: Position in the split's sample list.

        Returns:
            The augmented image and mask plus a metadata dictionary that
            carries the image-level ``'label'``.

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        img_path, label = self.samples[idx]

        image = cv2.imread(str(img_path))
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f"CASIA image could not be read: {img_path}")

        # No pixel annotation exists: fill the whole mask with the label.
        h, w = image.shape[:2]
        mask = np.full((h, w), label, dtype=np.uint8)

        image, mask = self.preprocessor(image, mask)

        augmented = self.augmentation(image, mask)
        image = augmented['image']
        mask = augmented['mask']

        metadata = {
            'dataset': self.dataset_name,
            'image_path': str(img_path),
            'has_pixel_mask': False,
            'label': label
        }

        return image, mask, metadata
|
|
|
|
|
class ReceiptsDataset(Dataset):
    """
    Find-It-Again receipts dataset loader.

    Annotations are read from ``<split>.json``; each entry provides an
    ``'image_path'`` relative to the dataset path and optionally a list of
    ``'bboxes'`` that are rasterized into the mask.
    """

    def __init__(self, config, split: str = 'train'):
        """
        Initialize receipts dataset.

        Args:
            config: Configuration object; must provide
                ``get_dataset_config(name)`` returning a mapping with a
                ``'path'`` entry.
            split: 'train', 'val', or 'test'
        """
        self.config = config
        self.split = split
        self.dataset_name = 'receipts'

        dataset_config = config.get_dataset_config(self.dataset_name)
        self.data_path = Path(dataset_config['path'])

        split_file = self.data_path / f'{split}.json'
        with open(split_file, 'r') as f:
            self.annotations = json.load(f)

        print(f"Receipts {split}: {len(self.annotations)} images")

        self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
        self.augmentation = DatasetAwareAugmentation(
            config,
            self.dataset_name,
            is_training=(split == 'train')
        )

    @staticmethod
    def _boxes_to_mask(height: int, width: int, bboxes) -> np.ndarray:
        """Rasterize ``(x, y, w, h)`` boxes into a uint8 mask of the given size."""
        mask = np.zeros((height, width), dtype=np.uint8)
        for x, y, w_box, h_box in bboxes:
            # JSON may deliver floats, which cannot be used as slice bounds,
            # and negative coordinates would wrap around as negative indices.
            # Round to integers and clamp to the image instead.
            x0 = min(max(int(round(x)), 0), width)
            y0 = min(max(int(round(y)), 0), height)
            x1 = min(max(int(round(x + w_box)), 0), width)
            y1 = min(max(int(round(y + h_box)), 0), height)
            mask[y0:y1, x0:x1] = 1
        return mask

    def __len__(self) -> int:
        """Return the number of annotated images in the split."""
        return len(self.annotations)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Get item from dataset.

        Args:
            idx: Position in the split's annotation list.

        Returns:
            The augmented image and mask plus a metadata dictionary.

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        ann = self.annotations[idx]

        img_path = self.data_path / ann['image_path']
        image = cv2.imread(str(img_path))
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f"Receipt image could not be read: {img_path}")

        h, w = image.shape[:2]
        # A missing or null 'bboxes' entry yields an all-zero mask.
        mask = self._boxes_to_mask(h, w, ann.get('bboxes') or [])

        image, mask = self.preprocessor(image, mask)

        augmented = self.augmentation(image, mask)
        image = augmented['image']
        mask = augmented['mask']

        metadata = {
            'dataset': self.dataset_name,
            'image_path': str(img_path),
            'has_pixel_mask': True
        }

        return image, mask, metadata
|
|
|
|
|
class FCDDataset(DocTamperDataset):
    """
    FCD (Forgery Classification Dataset) loader - inherits from DocTamperDataset.

    Only the constructor differs: the configured dataset path is itself the
    LMDB folder and the whole store is exposed as a single chunk. Item access
    and LMDB handling are inherited.
    """

    def __init__(self, config, split: str = 'train'):
        """
        Initialize FCD dataset.

        Args:
            config: Configuration object; must provide
                ``get_dataset_config(name)`` returning a mapping with a
                ``'path'`` entry.
            split: Data split name.
                NOTE(review): the same LMDB is opened for every split; the
                value only affects the log line and whether training
                augmentation is enabled - confirm this is intended.

        Raises:
            FileNotFoundError: If the LMDB folder is missing.
        """
        self.config = config
        self.split = split
        self.dataset_name = 'fcd'

        dataset_config = config.get_dataset_config(self.dataset_name)
        self.data_path = Path(dataset_config['path'])
        self.lmdb_path = str(self.data_path)

        if not Path(self.lmdb_path).exists():
            raise FileNotFoundError(f"LMDB folder not found: {self.lmdb_path}")

        # Count the samples with a short-lived environment. It is closed even
        # if reading the statistics fails, so no handle is leaked.
        temp_env = lmdb.open(self.lmdb_path, readonly=True, lock=False)
        try:
            with temp_env.begin() as txn:
                # The sample count is taken as half the number of entries.
                self.length = txn.stat()['entries'] // 2
        finally:
            temp_env.close()

        # The long-lived environment is opened lazily by the inherited
        # ``env`` property.
        self._env = None

        # No chunking: the chunk spans the entire store.
        self.chunk_start = 0
        self.chunk_end = self.length
        self.chunk_length = self.length

        print(f"FCD {split}: {self.length} samples")

        self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
        self.augmentation = DatasetAwareAugmentation(
            config,
            self.dataset_name,
            is_training=(split == 'train')
        )
|
|
|
|
|
class SCDDataset(DocTamperDataset):
    """
    SCD (Splicing Classification Dataset) loader - inherits from DocTamperDataset.

    Only the constructor differs: the configured dataset path is itself the
    LMDB folder and the whole store is exposed as a single chunk. Item access
    and LMDB handling are inherited.
    """

    def __init__(self, config, split: str = 'train'):
        """
        Initialize SCD dataset.

        Args:
            config: Configuration object; must provide
                ``get_dataset_config(name)`` returning a mapping with a
                ``'path'`` entry.
            split: Data split name.
                NOTE(review): the same LMDB is opened for every split; the
                value only affects the log line and whether training
                augmentation is enabled - confirm this is intended.

        Raises:
            FileNotFoundError: If the LMDB folder is missing.
        """
        self.config = config
        self.split = split
        self.dataset_name = 'scd'

        dataset_config = config.get_dataset_config(self.dataset_name)
        self.data_path = Path(dataset_config['path'])
        self.lmdb_path = str(self.data_path)

        if not Path(self.lmdb_path).exists():
            raise FileNotFoundError(f"LMDB folder not found: {self.lmdb_path}")

        # Count the samples with a short-lived environment. It is closed even
        # if reading the statistics fails, so no handle is leaked.
        temp_env = lmdb.open(self.lmdb_path, readonly=True, lock=False)
        try:
            with temp_env.begin() as txn:
                # The sample count is taken as half the number of entries.
                self.length = txn.stat()['entries'] // 2
        finally:
            temp_env.close()

        # The long-lived environment is opened lazily by the inherited
        # ``env`` property.
        self._env = None

        # No chunking: the chunk spans the entire store.
        self.chunk_start = 0
        self.chunk_end = self.length
        self.chunk_length = self.length

        print(f"SCD {split}: {self.length} samples")

        self.preprocessor = DocumentPreprocessor(config, self.dataset_name)
        self.augmentation = DatasetAwareAugmentation(
            config,
            self.dataset_name,
            is_training=(split == 'train')
        )
|
|
|
|
|
def get_dataset(config, dataset_name: str, split: str = 'train', **kwargs) -> Dataset:
    """
    Build the dataset loader registered under ``dataset_name``.

    Args:
        config: Configuration object handed to the loader.
        dataset_name: One of 'doctamper', 'rtm', 'casia', 'receipts', 'fcd'
            or 'scd'.
        split: Data split handed to the loader.
        **kwargs: Extra arguments (e.g., chunk_start, chunk_end). They are
            forwarded to the DocTamper loader only and ignored otherwise.

    Returns:
        Dataset instance

    Raises:
        ValueError: If ``dataset_name`` is not a known dataset.
    """
    supported = ('doctamper', 'rtm', 'casia', 'receipts', 'fcd', 'scd')
    if dataset_name not in supported:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    # DocTamper is the only loader that takes the chunking arguments.
    if dataset_name == 'doctamper':
        return DocTamperDataset(config, split, **kwargs)

    plain_loaders = {
        'rtm': RTMDataset,
        'casia': CASIADataset,
        'receipts': ReceiptsDataset,
        'fcd': FCDDataset,
        'scd': SCDDataset,
    }
    loader_cls = plain_loaders[dataset_name]
    return loader_cls(config, split)
|
|
|