# Image-Forensics-Detect / training/dataset_loader.py
# Uploaded by dk2430098 via huggingface_hub (commit 928b74f, verified)
"""
training/dataset_loader.py
---------------------------
Dataset loading, preprocessing, augmentation and splitting.
STATUS: COMPLETE
Expected dataset structure:
data/raw/
β”œβ”€β”€ real/ ← real camera images (.jpg, .png, .jpeg, .webp)
└── fake/ ← AI-generated images
Splits: 70% train / 15% val / 15% test (stratified)
Supports:
- TensorFlow tf.data pipeline (for CNN branch)
- PyTorch DataLoader (for ViT branch)
- Plain numpy arrays (for handcrafted branches)
"""
import os
import numpy as np
import cv2
from pathlib import Path
from typing import Tuple, List, Dict
from sklearn.model_selection import train_test_split
# ─────────────────────────────────────────────────────────────────
# Config
# ─────────────────────────────────────────────────────────────────
DATA_DIR = Path(__file__).parent.parent / "data" / "raw"
IMAGE_SIZE = (224, 224)
VALID_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
RANDOM_SEED = 42
# ─────────────────────────────────────────────────────────────────
# File Discovery
# ─────────────────────────────────────────────────────────────────
def discover_dataset(data_dir: Path = DATA_DIR) -> Tuple[List[str], List[int]]:
    """
    Collect image paths and class labels from the raw dataset tree.

    Looks for ``data_dir/real`` and ``data_dir/fake``; files are gathered
    in sorted order so repeated runs yield an identical listing. A missing
    class folder is reported with a warning and skipped rather than raising.

    Returns:
        (paths, labels) where labels use 0 = real, 1 = fake.
    """
    paths: List[str] = []
    labels: List[int] = []
    class_map = {"real": 0, "fake": 1}
    for class_name, class_label in class_map.items():
        class_dir = data_dir / class_name
        if not class_dir.exists():
            print(f"⚠ Directory not found: {class_dir}")
            continue
        images = [
            entry for entry in sorted(class_dir.iterdir())
            if entry.suffix.lower() in VALID_EXTS
        ]
        paths.extend(str(entry) for entry in images)
        labels.extend([class_label] * len(images))
    print(f"Dataset discovered: {labels.count(0)} real + {labels.count(1)} fake images")
    return paths, labels
def split_dataset(
    paths: List[str], labels: List[int],
    val_size: float = 0.15, test_size: float = 0.15
) -> Dict[str, Tuple[List[str], List[int]]]:
    """
    Stratified train/val/test split in two stages.

    First carves off a held-out pool of ``val_size + test_size`` of the
    data, then divides that pool between val and test so the final
    fractions match the requested sizes. Both stages are seeded and
    stratified on the labels.

    Returns:
        Dict with keys 'train', 'val', 'test', each mapping to
        a (paths, labels) pair.
    """
    holdout_frac = val_size + test_size
    train_p, pool_p, train_l, pool_l = train_test_split(
        paths, labels,
        test_size=holdout_frac,
        random_state=RANDOM_SEED,
        stratify=labels,
    )
    # Fraction of the held-out pool that should become the test split.
    pool_test_frac = test_size / holdout_frac
    val_p, test_p, val_l, test_l = train_test_split(
        pool_p, pool_l,
        test_size=pool_test_frac,
        random_state=RANDOM_SEED,
        stratify=pool_l,
    )
    splits = {
        "train": (train_p, train_l),
        "val": (val_p, val_l),
        "test": (test_p, test_l),
    }
    return splits
# ─────────────────────────────────────────────────────────────────
# TensorFlow Dataset
# ─────────────────────────────────────────────────────────────────
def make_tf_dataset(
    paths: List[str],
    labels: List[int],
    batch_size: int = 32,
    augment: bool = False,
    shuffle: bool = True,
):
    """
    Build a tf.data.Dataset for CNN branch training.

    Preprocessing (always applied):
        - Read file → decode → scale to [0, 1] → area-resize to 224×224
    Augmentation (``augment=True``, train split only):
        - Random horizontal flip
        - Random brightness jitter (±0.15)
        - Random contrast jitter (0.85–1.15)
        - Random saturation jitter (0.8–1.2)
    NOTE(review): an earlier docstring promised ±15° rotation, but no
    rotation op is applied below — confirm whether rotation is intended.

    Args:
        paths: image file paths, one per sample.
        labels: integer class labels aligned with ``paths``; cast to
            float32 (suitable as binary-crossentropy targets).
        batch_size: samples per batch.
        augment: apply the train-time augmentation pipeline.
        shuffle: reshuffle samples (seeded for reproducibility).

    Returns:
        A batched, prefetched ``tf.data.Dataset`` of (image, label) pairs.
    """
    # Lazy import so the module stays importable in environments that
    # only need the PyTorch / numpy branches.
    import tensorflow as tf
    paths_t = tf.constant(paths)
    labels_t = tf.constant(labels, dtype=tf.float32)
    ds = tf.data.Dataset.from_tensor_slices((paths_t, labels_t))
    if shuffle:
        # Buffer spanning the whole dataset => a true uniform shuffle.
        ds = ds.shuffle(buffer_size=len(paths), seed=RANDOM_SEED)
    def preprocess(path, label):
        raw = tf.io.read_file(path)
        # decode_image handles JPEG/PNG/BMP/GIF uniformly;
        # expand_animations=False keeps animated GIFs as a single 3-D frame.
        img = tf.io.decode_image(raw, channels=3, expand_animations=False)
        img = tf.cast(img, tf.float32) / 255.0
        img = tf.image.resize(img, IMAGE_SIZE, method="area")
        return img, label
    def augment_fn(img, label):
        img = tf.image.random_flip_left_right(img)
        img = tf.image.random_brightness(img, max_delta=0.15)
        img = tf.image.random_contrast(img, lower=0.85, upper=1.15)
        img = tf.image.random_saturation(img, lower=0.8, upper=1.2)
        # Jitter can push pixel values outside [0, 1]; clamp back into range.
        img = tf.clip_by_value(img, 0.0, 1.0)
        return img, label
    ds = ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    if augment:
        ds = ds.map(augment_fn, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return ds
# ─────────────────────────────────────────────────────────────────
# PyTorch DataLoader
# ─────────────────────────────────────────────────────────────────
def make_torch_dataloader(
    paths: List[str],
    labels: List[int],
    batch_size: int = 32,
    augment: bool = False,
    num_workers: int = 2,
    shuffle: bool = None,
):
    """
    Build a PyTorch DataLoader for ViT branch training.

    Images are resized to 224×224 and normalized with ImageNet statistics
    (the convention expected by pretrained ViT backbones).

    Args:
        paths: image file paths, one per sample.
        labels: integer class labels aligned with ``paths``.
        batch_size: samples per batch.
        augment: apply train-time augmentation (horizontal flip,
            ±15° rotation, color jitter) on top of resize + normalize.
        num_workers: number of DataLoader worker processes.
        shuffle: shuffle samples each epoch. Defaults to ``None``, which
            preserves the historical behavior of shuffling exactly when
            ``augment`` is True; pass an explicit bool to decouple the two
            (e.g. a shuffled-but-unaugmented training run, or a fixed-order
            augmented pass).

    Returns:
        A ``torch.utils.data.DataLoader`` yielding (image_tensor, label) pairs.
    """
    # Lazy imports keep the module importable without the torch stack.
    import torch
    from torch.utils.data import Dataset, DataLoader
    import torchvision.transforms as T
    from PIL import Image as PILImage  # hoisted: was re-imported per sample

    IMAGENET_MEAN = [0.485, 0.456, 0.406]
    IMAGENET_STD = [0.229, 0.224, 0.225]

    if shuffle is None:
        # Backward compatible: the original implementation shuffled
        # if and only if augmentation was enabled.
        shuffle = augment

    # Assemble the transform: resize first, then (optionally) augment,
    # then convert to tensor and normalize.
    stages = [T.Resize((224, 224))]
    if augment:
        stages += [
            T.RandomHorizontalFlip(),
            T.RandomRotation(15),
            T.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.15),
        ]
    stages += [
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    transform = T.Compose(stages)

    class ImageDataset(Dataset):
        """Lazy (path, label) dataset: decodes each image on access."""
        def __init__(self, paths, labels, transform):
            self.paths = paths
            self.labels = labels
            self.transform = transform

        def __len__(self):
            return len(self.paths)

        def __getitem__(self, idx):
            # convert("RGB") normalizes grayscale / RGBA / palette inputs
            # to 3 channels before the tensor transform.
            img = PILImage.open(self.paths[idx]).convert("RGB")
            img = self.transform(img)
            label = torch.tensor(self.labels[idx], dtype=torch.long)
            return img, label

    dataset = ImageDataset(paths, labels, transform)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
    )