Spaces:
Running
Running
| """ | |
| UTKFace PyTorch Dataset. | |
| Filename format: [age]_[gender]_[race]_[datetime].jpg | |
| age : 0-116 | |
| gender : 0=Male 1=Female | |
| race : 0=White 1=Black 2=Asian 3=Indian 4=Others | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import random | |
| from pathlib import Path | |
| from typing import List, Optional, Tuple, Union | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
| from torchvision import transforms | |
# ── augmentation presets ───────────────────────────────────────────────────
def train_transforms(img_size: int = 224) -> transforms.Compose:
    """Build the randomised training pipeline for square ``img_size`` inputs.

    Steps, in order:
    1. Resize to an ``(img_size + 20)`` square and take a random ``img_size`` crop.
    2. Random horizontal flip.
    3. Colour jitter (brightness/contrast 0.3, saturation 0.2, hue 0.05).
    4. Random rotation of up to 10 degrees.
    5. Convert to a tensor and normalise per channel. The mean/std values are
       presumably the usual ImageNet statistics expected by pretrained
       backbones; confirm against the model in use.
    """
    padded = img_size + 20
    steps = [
        transforms.Resize((padded, padded)),
        transforms.RandomCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.05),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
    return transforms.Compose(steps)
def eval_transforms(img_size: int = 224) -> transforms.Compose:
    """Build the deterministic evaluation pipeline for square ``img_size`` inputs.

    Resizes straight to ``img_size`` x ``img_size`` with no crop or
    augmentation. It then converts to a tensor and applies the same per-channel
    mean/std normalisation constants as the training pipeline.

    NOTE(review): training resizes to ``img_size + 20`` before cropping, so
    faces appear about 9% larger during training than here. Consider
    ``Resize(img_size + 20)`` + ``CenterCrop(img_size)`` if the scale gap matters.
    """
    square = (img_size, img_size)
    pipeline = [
        transforms.Resize(square),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
    return transforms.Compose(pipeline)
# ── dataset class ──────────────────────────────────────────────────────────
class UTKFaceDataset(Dataset):
    """
    UTKFace images with gender and normalised-age targets.

    Each item is ``(image_tensor, gender_label, age_normalised)``:

    gender_label   : int64 tensor, 0=Male 1=Female
    age_normalised : float32 tensor equal to age / MAX_AGE. It lies in [0, 1]
                     only when ``max_age <= MAX_AGE``. A larger ``max_age``
                     filter yields values above 1, because ages are never clamped.

    The split is a seeded shuffle of the filtered, path-sorted file list.
    Constructing "train", "val" and "test" with otherwise identical arguments
    (and an unchanged directory) therefore yields disjoint subsets.
    """

    # Divisor for age normalisation. It matches denorm_age's default max_age.
    MAX_AGE = 90.0
    # Accepted values of the ``split`` argument.
    _SPLITS = ("train", "val", "test")

    def __init__(
        self,
        root_dir: "Union[str, Path]",
        split: str = "train",
        target_races: Optional[List[int]] = None,
        min_age: int = 1,
        max_age: int = 90,
        train_ratio: float = 0.80,
        val_ratio: float = 0.10,
        img_size: int = 224,
        seed: int = 42,
    ) -> None:
        """
        Scan ``root_dir``, filter the samples and keep the requested split.

        Args:
            root_dir: directory containing the ``*.jpg`` files.
            split: one of "train", "val", "test".
            target_races: race codes to keep. ``None`` or empty keeps all races.
            min_age, max_age: inclusive age range to keep.
            train_ratio, val_ratio: fractions of the kept samples assigned to
                train and val. The remainder goes to test.
            img_size: side length of the output images.
            seed: seed for the split shuffle.

        Raises:
            ValueError: unknown ``split`` or invalid ratios.
            FileNotFoundError: ``root_dir`` is not a directory.
        """
        # Validate before the (potentially slow) directory scan.
        self._check_split_args(split, train_ratio, val_ratio)
        self.root_dir = Path(root_dir)
        self.split = split
        self.target_races = set(target_races) if target_races else None
        self.min_age = min_age
        self.max_age = max_age
        self.img_size = img_size
        self.transform = train_transforms(img_size) if split == "train" else eval_transforms(img_size)

        samples = self._filter(self._scan())
        self.samples = self._select_split(samples, split, train_ratio, val_ratio, seed)

    # ── private helpers ────────────────────────────────────────────────────
    @classmethod
    def _check_split_args(cls, split: str, train_ratio: float, val_ratio: float) -> None:
        """Raise ValueError for an unknown split name or an impossible pair of ratios."""
        if split not in cls._SPLITS:
            raise ValueError(f"split must be one of {cls._SPLITS}, got {split!r}")
        # Small tolerance so float sums such as 0.9 + 0.1 are not rejected.
        if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
            raise ValueError(
                "need train_ratio >= 0, val_ratio >= 0 and train_ratio + val_ratio <= 1, "
                f"got {train_ratio} and {val_ratio}"
            )

    @staticmethod
    def _select_split(
        samples: List[Tuple[Path, int, int, int]],
        split: str,
        train_ratio: float,
        val_ratio: float,
        seed: int,
    ) -> List[Tuple[Path, int, int, int]]:
        """Return the ``split`` slice of a seeded shuffle of ``samples``. The input list is not modified."""
        shuffled = list(samples)
        # A private Random gives the same permutation as
        # random.seed(seed); random.shuffle(...) without resetting the
        # process-wide RNG that other code may depend on.
        random.Random(seed).shuffle(shuffled)
        n = len(shuffled)
        n_train = int(n * train_ratio)
        n_val = int(n * val_ratio)
        if split == "train":
            return shuffled[:n_train]
        if split == "val":
            return shuffled[n_train:n_train + n_val]
        return shuffled[n_train + n_val:]

    def _scan(self) -> List[Tuple[Path, int, int, int]]:
        """
        Return ``(path, age, gender, race)`` for every parseable ``*.jpg`` in root_dir.

        The first three ``_``-separated fields of the file stem must be
        integers. Other files are skipped. Paths are sorted so the seeded split
        does not depend on filesystem listing order.

        Raises:
            FileNotFoundError: root_dir is not an existing directory.
        """
        if not self.root_dir.is_dir():
            raise FileNotFoundError(f"UTKFace directory not found: {self.root_dir}")
        records: List[Tuple[Path, int, int, int]] = []
        for p in sorted(self.root_dir.glob("*.jpg")):
            parts = p.stem.split("_")
            if len(parts) < 3:
                continue
            try:
                age, gender, race = (int(field) for field in parts[:3])
            except ValueError:
                continue
            records.append((p, age, gender, race))
        return records

    def _filter(self, records: List[Tuple[Path, int, int, int]]) -> List[Tuple[Path, int, int, int]]:
        """Keep records within [min_age, max_age], with gender 0/1 and (if set) a target race."""
        return [
            rec
            for rec in records
            if self.min_age <= rec[1] <= self.max_age
            and rec[2] in (0, 1)
            and (not self.target_races or rec[3] in self.target_races)
        ]

    # ── public API ─────────────────────────────────────────────────────────
    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Load, convert to RGB and transform one image. Return it with its targets."""
        path, age, gender, _ = self.samples[idx]
        # The context manager guarantees the file handle is released.
        with Image.open(path) as im:
            img = im.convert("RGB")
        img = self.transform(img)
        gender_t = torch.tensor(gender, dtype=torch.long)
        age_t = torch.tensor(age / self.MAX_AGE, dtype=torch.float32)
        return img, gender_t, age_t

    def class_weights(self) -> torch.Tensor:
        """
        Return balanced class weights for gender (0=Male, 1=Female).

        Each weight is ``total / (2 * count)``. A class with no samples gets
        weight 0.0 instead of raising ZeroDivisionError.
        """
        counts = [0, 0]
        for _, _, gender, _ in self.samples:
            counts[gender] += 1
        total = sum(counts)
        weights = [total / (2 * c) if c else 0.0 for c in counts]
        return torch.tensor(weights, dtype=torch.float32)
def denorm_age(age_norm: float, max_age: float = 90.0) -> int:
    """Map a normalised age back to whole years.

    Undoes ``age / max_age`` scaling. Rounding follows Python's built-in
    ``round`` (half to even). ``age_norm`` may be anything ``float()``
    accepts, for example a one-element tensor.
    """
    years = max_age * float(age_norm)
    return round(years)
def build_dataloaders(cfg) -> dict:
    """
    Build train / val / test DataLoaders from a config object.

    ``cfg`` is read for UTKFACE_DIR, TARGET_RACES, MIN_AGE, MAX_AGE,
    TRAIN_RATIO, VAL_RATIO, IMG_SIZE, SEED, BATCH_SIZE and NUM_WORKERS, plus
    an optional PIN_MEMORY flag. Every split's dataset is built with the same
    arguments, including the seed.

    Returns:
        dict mapping "train", "val" and "test" to DataLoaders. Only the
        training loader shuffles and drops its last incomplete batch.
    """
    from torch.utils.data import DataLoader

    common = dict(
        root_dir=cfg.UTKFACE_DIR,
        target_races=cfg.TARGET_RACES,
        min_age=cfg.MIN_AGE,
        max_age=cfg.MAX_AGE,
        train_ratio=cfg.TRAIN_RATIO,
        val_ratio=cfg.VAL_RATIO,
        img_size=cfg.IMG_SIZE,
        seed=cfg.SEED,
    )
    # Page-locked host memory only speeds up copies to a CUDA device. Without
    # one, recent PyTorch releases warn that the flag is ignored, so default
    # to pinning only when CUDA is present (overridable via cfg.PIN_MEMORY).
    pin_memory = getattr(cfg, "PIN_MEMORY", torch.cuda.is_available())

    loaders = {}
    for split in ("train", "val", "test"):
        ds = UTKFaceDataset(split=split, **common)
        is_train = split == "train"
        loaders[split] = DataLoader(
            ds,
            batch_size=cfg.BATCH_SIZE,
            shuffle=is_train,
            num_workers=cfg.NUM_WORKERS,
            pin_memory=pin_memory,
            drop_last=is_train,
        )
        print(f"[dataset] {split:5s}: {len(ds):,} samples")
    return loaders