""" UTKFace PyTorch Dataset. Filename format: [age]_[gender]_[race]_[datetime].jpg age : 0-116 gender : 0=Male 1=Female race : 0=White 1=Black 2=Asian 3=Indian 4=Others """ from __future__ import annotations import os import random from pathlib import Path from typing import List, Optional, Tuple, Union import numpy as np import torch from PIL import Image from torch.utils.data import Dataset from torchvision import transforms # ── augmentation presets ─────────────────────────────────────────────────── def train_transforms(img_size: int = 224) -> transforms.Compose: return transforms.Compose([ transforms.Resize((img_size + 20, img_size + 20)), transforms.RandomCrop(img_size), transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.05), transforms.RandomRotation(10), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) def eval_transforms(img_size: int = 224) -> transforms.Compose: return transforms.Compose([ transforms.Resize((img_size, img_size)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) # ── dataset class ────────────────────────────────────────────────────────── class UTKFaceDataset(Dataset): """ Returns (image_tensor, gender_label, age_normalised) gender_label : int 0=Male 1=Female age_normalised : float in [0, 1] (age / MAX_AGE) """ MAX_AGE = 90.0 def __init__( self, root_dir: "Union[str, Path]", split: str = "train", target_races: Optional[List[int]] = None, min_age: int = 1, max_age: int = 90, train_ratio: float = 0.80, val_ratio: float = 0.10, img_size: int = 224, seed: int = 42, ) -> None: self.root_dir = Path(root_dir) self.split = split self.target_races = set(target_races) if target_races else None self.min_age = min_age self.max_age = max_age self.img_size = img_size self.transform = train_transforms(img_size) if split == "train" else eval_transforms(img_size) samples = self._scan() samples = self._filter(samples) random.seed(seed) random.shuffle(samples) n = len(samples) n_train = int(n * train_ratio) n_val = int(n * val_ratio) if split == "train": self.samples = samples[:n_train] elif split == "val": self.samples = samples[n_train: n_train + n_val] else: # test self.samples = samples[n_train + n_val:] # ── private helpers ──────────────────────────────────────────────────── def _scan(self) -> List[Tuple[Path, int, int, int]]: """Return list of (path, age, gender, race).""" records: List[Tuple[Path, int, int, int]] = [] for p in self.root_dir.glob("*.jpg"): parts = p.stem.split("_") if len(parts) < 3: continue try: age = int(parts[0]) gender = int(parts[1]) race = int(parts[2]) except ValueError: continue records.append((p, age, gender, race)) return records def _filter(self, records: List[Tuple[Path, int, int, int]]) -> List[Tuple[Path, int, int, int]]: out = [] for p, age, gender, race in records: if age < self.min_age or age > self.max_age: continue if gender not in (0, 1): continue if self.target_races and race not in self.target_races: continue out.append((p, age, gender, race)) return out # ── public API ───────────────────────────────────────────────────────── def __len__(self) -> int: return len(self.samples) def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: path, age, gender, _ = self.samples[idx] img = Image.open(path).convert("RGB") img = self.transform(img) gender_t = torch.tensor(gender, dtype=torch.long) age_t = torch.tensor(age / self.MAX_AGE, dtype=torch.float32) return img, gender_t, age_t def class_weights(self) -> torch.Tensor: """Return balanced class weights for gender (0=Male, 1=Female).""" counts = [0, 0] for _, _, gender, _ in self.samples: counts[gender] += 1 total = sum(counts) weights = torch.tensor([total / (2 * c) for c in counts], dtype=torch.float32) return weights @staticmethod def denorm_age(age_norm: float, max_age: float = 90.0) -> int: return round(float(age_norm) * max_age) def build_dataloaders(cfg) -> dict: """Build train / val / test DataLoaders from config.""" from torch.utils.data import DataLoader common = dict( root_dir = cfg.UTKFACE_DIR, target_races = cfg.TARGET_RACES, min_age = cfg.MIN_AGE, max_age = cfg.MAX_AGE, train_ratio = cfg.TRAIN_RATIO, val_ratio = cfg.VAL_RATIO, img_size = cfg.IMG_SIZE, seed = cfg.SEED, ) loaders = {} for split in ("train", "val", "test"): ds = UTKFaceDataset(split=split, **common) loaders[split] = DataLoader( ds, batch_size = cfg.BATCH_SIZE, shuffle = (split == "train"), num_workers = cfg.NUM_WORKERS, pin_memory = True, drop_last = (split == "train"), ) print(f"[dataset] {split:5s}: {len(ds):,} samples") return loaders