# deepdetection/src/training/datasets.py — initial commit 4e75170 (akagtag)
"""
src/training/datasets.py
Shared dataset utilities used by scripts/ entrypoints.
"""
from __future__ import annotations
import csv
from pathlib import Path
from typing import Optional
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp"}
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
def get_train_transform(size: int = 224):
    """Training-time augmentation pipeline: random crop/flip/jitter,
    then tensor conversion and ImageNet normalization."""
    pipeline = [
        transforms.RandomResizedCrop(size, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
    return transforms.Compose(pipeline)
def get_val_transform(size: int = 224):
    """Deterministic eval pipeline: resize (keeping the standard 256/224
    ratio), center-crop, tensor conversion, ImageNet normalization."""
    resize_edge = int(size * 256 / 224)
    pipeline = [
        transforms.Resize(resize_edge),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
    return transforms.Compose(pipeline)
class ImageManifestDataset(Dataset):
    """
    Generic image dataset driven by a manifest CSV.

    Manifest format (header row required):
        filepath,label[,generator]
    where label is 0=real, 1=fake, and generator is an optional integer id
    (defaults to 0 when the column is missing or the cell is empty).

    Args:
        manifest_path: Path to the manifest CSV file.
        transform: Optional callable applied to each loaded PIL image.
        root_dir: Optional directory prepended to relative manifest paths.
    """

    def __init__(
        self,
        manifest_path: Path,
        transform=None,
        root_dir: Optional[Path] = None,
    ):
        self.transform = transform
        self.root_dir = Path(root_dir) if root_dir else None
        self.samples: list[tuple[Path, int, int]] = []
        # newline="" is required by the csv module so quoted fields containing
        # newlines are parsed correctly; encoding pinned for reproducibility.
        with open(manifest_path, newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for row in reader:
                filepath = Path(row["filepath"])
                if self.root_dir and not filepath.is_absolute():
                    filepath = self.root_dir / filepath
                label = int(row["label"])
                # A present-but-empty "generator" cell yields "" from
                # DictReader; `or 0` covers both missing and empty values
                # (int(row.get("generator", 0)) crashed on "" with ValueError).
                generator = int(row.get("generator") or 0)
                self.samples.append((filepath, label, generator))

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> dict:
        """Load sample `idx` as RGB, apply the transform if any, and return
        a dict with image, label, generator id, and source filepath."""
        path, label, generator = self.samples[idx]
        img = Image.open(path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return {
            "image": img,
            "label": label,
            "generator": generator,
            "filepath": str(path),
        }

    def get_class_weights(self) -> torch.Tensor:
        """Return inverse-frequency weights [w_real, w_fake] for a balanced
        loss (e.g. CrossEntropyLoss(weight=...))."""
        labels = [s[1] for s in self.samples]
        n_real = labels.count(0)
        n_fake = labels.count(1)
        # max(..., 1) guards against division by zero when a class is absent.
        w_real = 1.0 / max(n_real, 1)
        w_fake = 1.0 / max(n_fake, 1)
        return torch.tensor([w_real, w_fake], dtype=torch.float32)