# FaceInsight_AI — src/data/dataset.py
# Uploaded by vaisagan via huggingface_hub (commit 659083c, verified).
"""
UTKFace PyTorch Dataset.
Filename format: [age]_[gender]_[race]_[datetime].jpg
age : 0-116
gender : 0=Male 1=Female
race : 0=White 1=Black 2=Asian 3=Indian 4=Others
"""
from __future__ import annotations
import os
import random
from pathlib import Path
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
# ── augmentation presets ───────────────────────────────────────────────────
def train_transforms(img_size: int = 224) -> transforms.Compose:
    """
    Build the training-time augmentation pipeline.

    The image is resized to a square of side ``img_size + 20`` and randomly
    cropped back to ``img_size``. It is then randomly flipped horizontally,
    colour-jittered and randomly rotated (``RandomRotation(10)``). Finally it
    is converted to a tensor normalised with the standard ImageNet mean/std.
    """
    padded = img_size + 20  # margin so RandomCrop has room to move
    steps = [
        transforms.Resize((padded, padded)),
        transforms.RandomCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.3, contrast=0.3,
                               saturation=0.2, hue=0.05),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)
def eval_transforms(img_size: int = 224) -> transforms.Compose:
    """
    Build the deterministic evaluation pipeline.

    The image is resized to ``img_size`` x ``img_size`` (no cropping or
    augmentation), converted to a tensor, and normalised with the same
    ImageNet mean/std used for training.
    """
    side = (img_size, img_size)
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    return transforms.Compose([
        transforms.Resize(side),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ])
# ── dataset class ──────────────────────────────────────────────────────────
class UTKFaceDataset(Dataset):
    """
    UTKFace images with gender and normalised-age targets.

    Each item is ``(image_tensor, gender_label, age_normalised)``:
        gender_label   : int64 scalar tensor, 0=Male 1=Female
        age_normalised : float32 scalar tensor, age / MAX_AGE

    Note: ages are always divided by the class constant ``MAX_AGE`` (90.0),
    not by the ``max_age`` filter argument. Passing ``max_age > 90``
    therefore produces targets above 1.0.

    The train / val / test partition depends only on the set of matching
    files, the filter settings, the ratios and ``seed``. It does not depend
    on directory-listing order, so three instances built with identical
    arguments (one per split) take disjoint slices of the same permutation.
    """

    MAX_AGE = 90.0
    _SPLITS = ("train", "val", "test")

    def __init__(
        self,
        root_dir: "Union[str, Path]",
        split: str = "train",
        target_races: Optional[List[int]] = None,
        min_age: int = 1,
        max_age: int = 90,
        train_ratio: float = 0.80,
        val_ratio: float = 0.10,
        img_size: int = 224,
        seed: int = 42,
    ) -> None:
        """
        Args:
            root_dir: directory holding the UTKFace ``*.jpg`` files
                (searched non-recursively).
            split: one of ``"train"``, ``"val"``, ``"test"``.
            target_races: race codes to keep; ``None`` or an empty list
                keeps every race.
            min_age: smallest age kept (inclusive).
            max_age: largest age kept (inclusive).
            train_ratio: fraction of filtered samples assigned to train.
            val_ratio: fraction assigned to val; the remainder is test.
            img_size: side length passed to the transform pipeline.
            seed: seed for the shuffle performed before splitting.

        Raises:
            ValueError: if ``split`` is unknown, or a ratio is negative,
                or the two ratios sum to more than 1.
        """
        if split not in self._SPLITS:
            raise ValueError(f"split must be one of {self._SPLITS}, got {split!r}")
        # Small tolerance so sums such as 0.9 + 0.1 are not rejected by
        # floating-point rounding.
        if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1.0 + 1e-9:
            raise ValueError(
                f"invalid split ratios: train_ratio={train_ratio}, val_ratio={val_ratio}"
            )

        self.root_dir = Path(root_dir)
        self.split = split
        self.target_races = set(target_races) if target_races else None
        self.min_age = min_age
        self.max_age = max_age
        self.img_size = img_size
        self.transform = train_transforms(img_size) if split == "train" else eval_transforms(img_size)

        samples = self._filter(self._scan())
        # A private RNG produces the same permutation as seeding the global
        # ``random`` module, without clobbering global random state.
        random.Random(seed).shuffle(samples)

        n = len(samples)
        n_train = int(n * train_ratio)
        n_val = int(n * val_ratio)
        if split == "train":
            self.samples = samples[:n_train]
        elif split == "val":
            self.samples = samples[n_train: n_train + n_val]
        else:  # test
            self.samples = samples[n_train + n_val:]

    # ── private helpers ────────────────────────────────────────────────────
    def _scan(self) -> List[Tuple[Path, int, int, int]]:
        """
        Return ``(path, age, gender, race)`` for every parsable ``*.jpg``.

        The filename is split on ``_``. Files with fewer than three fields,
        or whose first three fields are not integers, are skipped.
        NOTE: a name lacking the race field (``age_gender_datetime.jpg``)
        parses the datetime digits as ``race``. Such records only survive
        ``_filter`` when no race filter is set.
        """
        records: List[Tuple[Path, int, int, int]] = []
        # Sort so the order (and hence the seeded split) does not depend on
        # the filesystem's directory-listing order.
        for p in sorted(self.root_dir.glob("*.jpg")):
            parts = p.stem.split("_")
            if len(parts) < 3:
                continue
            try:
                age = int(parts[0])
                gender = int(parts[1])
                race = int(parts[2])
            except ValueError:
                continue
            records.append((p, age, gender, race))
        return records

    def _filter(self, records: List[Tuple[Path, int, int, int]]) -> List[Tuple[Path, int, int, int]]:
        """Keep records within the age range, with gender 0/1, and in ``target_races`` (if set)."""
        out = []
        for p, age, gender, race in records:
            if age < self.min_age or age > self.max_age:
                continue
            if gender not in (0, 1):
                continue
            if self.target_races and race not in self.target_races:
                continue
            out.append((p, age, gender, race))
        return out

    # ── public API ─────────────────────────────────────────────────────────
    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Load sample ``idx`` and return ``(image, gender, age / MAX_AGE)``."""
        path, age, gender, _ = self.samples[idx]
        # Close the file handle deterministically. ``convert`` returns a new,
        # fully loaded image, so it stays valid after the file is closed.
        with Image.open(path) as im:
            img = im.convert("RGB")
        img = self.transform(img)
        gender_t = torch.tensor(gender, dtype=torch.long)
        age_t = torch.tensor(age / self.MAX_AGE, dtype=torch.float32)
        return img, gender_t, age_t

    def class_weights(self) -> torch.Tensor:
        """
        Return balanced gender class weights ``[w_male, w_female]``.

        Each weight is ``total / (2 * count)``. A class with no samples in
        this split gets weight 0.0 instead of raising ``ZeroDivisionError``.
        """
        counts = [0, 0]
        for _, _, gender, _ in self.samples:
            counts[gender] += 1
        total = sum(counts)
        weights = torch.tensor(
            [total / (2 * c) if c else 0.0 for c in counts], dtype=torch.float32
        )
        return weights

    @staticmethod
    def denorm_age(age_norm: float, max_age: float = 90.0) -> int:
        """Map a normalised age back to whole years (rounded)."""
        return round(float(age_norm) * max_age)
def build_dataloaders(cfg) -> dict:
    """
    Build train / val / test DataLoaders from a config object.

    ``cfg`` must expose UTKFACE_DIR, TARGET_RACES, MIN_AGE, MAX_AGE,
    TRAIN_RATIO, VAL_RATIO, IMG_SIZE, SEED, BATCH_SIZE and NUM_WORKERS.
    All three datasets receive identical arguments (including SEED) and
    differ only in ``split``.

    Returns:
        dict mapping ``"train"`` / ``"val"`` / ``"test"`` to a DataLoader.
        Only the train loader shuffles and drops the last incomplete batch.
    """
    from torch.utils.data import DataLoader

    common = dict(
        root_dir=cfg.UTKFACE_DIR,
        target_races=cfg.TARGET_RACES,
        min_age=cfg.MIN_AGE,
        max_age=cfg.MAX_AGE,
        train_ratio=cfg.TRAIN_RATIO,
        val_ratio=cfg.VAL_RATIO,
        img_size=cfg.IMG_SIZE,
        seed=cfg.SEED,
    )
    # Pinned host memory only speeds up host-to-CUDA copies. Without CUDA,
    # DataLoader just warns about it, so enable pinning only when CUDA exists.
    pin = torch.cuda.is_available()

    loaders = {}
    for split in ("train", "val", "test"):
        ds = UTKFaceDataset(split=split, **common)
        is_train = split == "train"
        loaders[split] = DataLoader(
            ds,
            batch_size=cfg.BATCH_SIZE,
            shuffle=is_train,
            num_workers=cfg.NUM_WORKERS,
            pin_memory=pin,
            drop_last=is_train,
        )
        print(f"[dataset] {split:5s}: {len(ds):,} samples")
    return loaders