""" src/training/datasets.py Shared dataset utilities used by scripts/ entrypoints. """ from __future__ import annotations import csv from pathlib import Path from typing import Optional import numpy as np import torch from PIL import Image from torch.utils.data import Dataset from torchvision import transforms IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp"} IMAGENET_MEAN = [0.485, 0.456, 0.406] IMAGENET_STD = [0.229, 0.224, 0.225] def get_train_transform(size: int = 224): return transforms.Compose([ transforms.RandomResizedCrop(size, scale=(0.8, 1.0)), transforms.RandomHorizontalFlip(), transforms.ColorJitter(0.2, 0.2, 0.2, 0.1), transforms.ToTensor(), transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD), ]) def get_val_transform(size: int = 224): return transforms.Compose([ transforms.Resize(int(size * 256 / 224)), transforms.CenterCrop(size), transforms.ToTensor(), transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD), ]) class ImageManifestDataset(Dataset): """ Generic image dataset driven by a manifest CSV. Manifest format: filepath, label (0=real, 1=fake), [generator (int)] """ def __init__( self, manifest_path: Path, transform=None, root_dir: Optional[Path] = None, ): self.transform = transform self.root_dir = Path(root_dir) if root_dir else None self.samples = [] with open(manifest_path) as f: reader = csv.DictReader(f) for row in reader: filepath = Path(row["filepath"]) if self.root_dir and not filepath.is_absolute(): filepath = self.root_dir / filepath label = int(row["label"]) generator = int(row.get("generator", 0)) self.samples.append((filepath, label, generator)) def __len__(self) -> int: return len(self.samples) def __getitem__(self, idx: int) -> dict: path, label, generator = self.samples[idx] img = Image.open(path).convert("RGB") if self.transform: img = self.transform(img) return { "image": img, "label": label, "generator": generator, "filepath": str(path), } def get_class_weights(self) -> torch.Tensor: labels = [s[1] for s in self.samples] n_real = labels.count(0) n_fake = labels.count(1) w_real = 1.0 / max(n_real, 1) w_fake = 1.0 / max(n_fake, 1) return torch.tensor([w_real, w_fake], dtype=torch.float32)