# NOTE: removed non-Python page-header artifact ("Spaces / Running") left over
# from a web export; it would break the module at import time.
| """ | |
| src/training/datasets.py | |
| Shared dataset utilities used by scripts/ entrypoints. | |
| """ | |
| from __future__ import annotations | |
| import csv | |
| from pathlib import Path | |
| from typing import Optional | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
| from torchvision import transforms | |
| IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp"} | |
| IMAGENET_MEAN = [0.485, 0.456, 0.406] | |
| IMAGENET_STD = [0.229, 0.224, 0.225] | |
def get_train_transform(size: int = 224):
    """Build the training-time augmentation pipeline.

    Random resized crop, horizontal flip, and mild color jitter, followed by
    tensor conversion and ImageNet normalization.

    Args:
        size: side length (pixels) of the square output crop.

    Returns:
        A ``transforms.Compose`` callable mapping a PIL image to a tensor.
    """
    augmentations = [
        transforms.RandomResizedCrop(size, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
    return transforms.Compose(augmentations)
def get_val_transform(size: int = 224):
    """Build the deterministic eval-time preprocessing pipeline.

    Resize (preserving the conventional 256->224 ratio), center-crop, convert
    to tensor, and normalize with ImageNet statistics.

    Args:
        size: side length (pixels) of the square center crop.

    Returns:
        A ``transforms.Compose`` callable mapping a PIL image to a tensor.
    """
    # Scale the pre-crop resize with `size`, keeping the standard 256/224 ratio.
    resize_to = int(size * 256 / 224)
    steps = [
        transforms.Resize(resize_to),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
    return transforms.Compose(steps)
class ImageManifestDataset(Dataset):
    """
    Generic image dataset driven by a manifest CSV.
    Manifest format: filepath, label (0=real, 1=fake), [generator (int)]
    """

    def __init__(
        self,
        manifest_path: Path,
        transform=None,
        root_dir: Optional[Path] = None,
    ):
        """Load the whole manifest into memory as (path, label, generator) rows.

        Args:
            manifest_path: CSV file with a header row containing at least
                ``filepath`` and ``label`` columns; ``generator`` is optional.
            transform: optional callable applied to each PIL image in
                ``__getitem__``.
            root_dir: if given, relative manifest paths are resolved against
                this directory; absolute paths are kept as-is.

        Raises:
            KeyError: if a required column is missing from the manifest.
            ValueError: if a ``label``/``generator`` cell is not an integer.
        """
        self.transform = transform
        self.root_dir = Path(root_dir) if root_dir else None
        self.samples: list[tuple[Path, int, int]] = []
        # newline="" is required by the csv module docs; explicit utf-8 avoids
        # platform-dependent default encodings.
        with open(manifest_path, newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for row in reader:
                filepath = Path(row["filepath"])
                if self.root_dir and not filepath.is_absolute():
                    filepath = self.root_dir / filepath
                label = int(row["label"])
                # `or 0` also covers a present-but-empty generator cell,
                # which would otherwise crash on int("").
                generator = int(row.get("generator") or 0)
                self.samples.append((filepath, label, generator))

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> dict:
        """Return one sample as a dict of image tensor/PIL, ints, and path."""
        path, label, generator = self.samples[idx]
        img = Image.open(path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return {
            "image": img,
            "label": label,
            "generator": generator,
            "filepath": str(path),
        }

    def get_class_weights(self) -> torch.Tensor:
        """Inverse-frequency weights for the two classes (real=0, fake=1).

        ``max(n, 1)`` guards against division by zero when one class is
        absent from the manifest.
        """
        labels = [s[1] for s in self.samples]
        n_real = labels.count(0)
        n_fake = labels.count(1)
        w_real = 1.0 / max(n_real, 1)
        w_fake = 1.0 / max(n_fake, 1)
        return torch.tensor([w_real, w_fake], dtype=torch.float32)