""" training/dataset_loader.py --------------------------- Dataset loading, preprocessing, augmentation and splitting. STATUS: COMPLETE Expected dataset structure: data/raw/ ├── real/ ← real camera images (.jpg, .png, .jpeg, .webp) └── fake/ ← AI-generated images Splits: 70% train / 15% val / 15% test (stratified) Supports: - TensorFlow tf.data pipeline (for CNN branch) - PyTorch DataLoader (for ViT branch) - Plain numpy arrays (for handcrafted branches) """ import os import numpy as np import cv2 from pathlib import Path from typing import Tuple, List, Dict from sklearn.model_selection import train_test_split # ───────────────────────────────────────────────────────────────── # Config # ───────────────────────────────────────────────────────────────── DATA_DIR = Path(__file__).parent.parent / "data" / "raw" IMAGE_SIZE = (224, 224) VALID_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"} RANDOM_SEED = 42 # ───────────────────────────────────────────────────────────────── # File Discovery # ───────────────────────────────────────────────────────────────── def discover_dataset(data_dir: Path = DATA_DIR) -> Tuple[List[str], List[int]]: """ Scan data/raw/real/ and data/raw/fake/ and return (paths, labels). Labels: 0 = real, 1 = fake. """ paths, labels = [], [] for label_name, label_val in [("real", 0), ("fake", 1)]: class_dir = data_dir / label_name if not class_dir.exists(): print(f"⚠ Directory not found: {class_dir}") continue for fp in sorted(class_dir.iterdir()): if fp.suffix.lower() in VALID_EXTS: paths.append(str(fp)) labels.append(label_val) print(f"Dataset discovered: {labels.count(0)} real + {labels.count(1)} fake images") return paths, labels def split_dataset( paths: List[str], labels: List[int], val_size: float = 0.15, test_size: float = 0.15 ) -> Dict[str, Tuple[List[str], List[int]]]: """ Stratified train/val/test split. Returns dict with keys 'train', 'val', 'test'. """ train_paths, temp_paths, train_labels, temp_labels = train_test_split( paths, labels, test_size=val_size + test_size, random_state=RANDOM_SEED, stratify=labels, ) relative_test = test_size / (val_size + test_size) val_paths, test_paths, val_labels, test_labels = train_test_split( temp_paths, temp_labels, test_size=relative_test, random_state=RANDOM_SEED, stratify=temp_labels, ) return { "train": (train_paths, train_labels), "val": (val_paths, val_labels), "test": (test_paths, test_labels), } # ───────────────────────────────────────────────────────────────── # TensorFlow Dataset # ───────────────────────────────────────────────────────────────── def make_tf_dataset( paths: List[str], labels: List[int], batch_size: int = 32, augment: bool = False, shuffle: bool = True, ): """ Build a tf.data.Dataset for CNN branch training. Preprocessing: - Decode image → resize 224×224 → normalize [0, 1] Augmentation (train only): - Random horizontal flip - Random brightness/contrast jitter - Random rotation ±15° """ import tensorflow as tf paths_t = tf.constant(paths) labels_t = tf.constant(labels, dtype=tf.float32) ds = tf.data.Dataset.from_tensor_slices((paths_t, labels_t)) if shuffle: ds = ds.shuffle(buffer_size=len(paths), seed=RANDOM_SEED) def preprocess(path, label): raw = tf.io.read_file(path) # Try JPEG first, fallback to PNG img = tf.io.decode_image(raw, channels=3, expand_animations=False) img = tf.cast(img, tf.float32) / 255.0 img = tf.image.resize(img, IMAGE_SIZE, method="area") return img, label def augment_fn(img, label): img = tf.image.random_flip_left_right(img) img = tf.image.random_brightness(img, max_delta=0.15) img = tf.image.random_contrast(img, lower=0.85, upper=1.15) img = tf.image.random_saturation(img, lower=0.8, upper=1.2) img = tf.clip_by_value(img, 0.0, 1.0) return img, label ds = ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE) if augment: ds = ds.map(augment_fn, num_parallel_calls=tf.data.AUTOTUNE) ds = ds.batch(batch_size).prefetch(tf.data.AUTOTUNE) return ds # ───────────────────────────────────────────────────────────────── # PyTorch DataLoader # ───────────────────────────────────────────────────────────────── def make_torch_dataloader( paths: List[str], labels: List[int], batch_size: int = 32, augment: bool = False, num_workers: int = 2, ): """ Build a PyTorch DataLoader for ViT branch training. Uses torchvision transforms with ImageNet normalization. """ import torch from torch.utils.data import Dataset, DataLoader import torchvision.transforms as T IMAGENET_MEAN = [0.485, 0.456, 0.406] IMAGENET_STD = [0.229, 0.224, 0.225] if augment: transform = T.Compose([ T.Resize((224, 224)), T.RandomHorizontalFlip(), T.RandomRotation(15), T.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.15), T.ToTensor(), T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), ]) else: transform = T.Compose([ T.Resize((224, 224)), T.ToTensor(), T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), ]) class ImageDataset(Dataset): def __init__(self, paths, labels, transform): self.paths = paths self.labels = labels self.transform = transform def __len__(self): return len(self.paths) def __getitem__(self, idx): from PIL import Image as PILImage img = PILImage.open(self.paths[idx]).convert("RGB") img = self.transform(img) label = torch.tensor(self.labels[idx], dtype=torch.long) return img, label dataset = ImageDataset(paths, labels, transform) loader = DataLoader( dataset, batch_size=batch_size, shuffle=(augment), # shuffle only for train num_workers=num_workers, pin_memory=True, ) return loader