"""Build an identity-wise AgeDB train split and a 1:1 verification pair list.

Images are read from a flat directory of ``*.jpg`` files, split per identity
into train and test rows, the train rows are linked (or copied) into one
folder per identity, and positive/negative test pairs are written to
``pairs.csv``.
"""
import argparse
import csv
import os
import random
import shutil
from collections import defaultdict
from pathlib import Path


def parse_agedb_name(path):
    """Split an AgeDB file name into ``(identity, age, gender)``.

    The stem must look like ``<prefix>_<identity>_<age>_<gender>``. The last
    two underscore-separated fields are the age and gender, the first field
    is dropped, and everything in between (which may itself contain
    underscores) is the identity.

    Args:
        path: A ``pathlib.Path`` (only ``path.stem`` is read).

    Returns:
        Tuple ``(identity: str, age: int, gender: str)``.

    Raises:
        ValueError: If the stem has too few fields or the age is not an
            integer.
    """
    try:
        left, age, gender = path.stem.rsplit('_', 2)
        _, identity = left.split('_', 1)
        return identity, int(age), gender
    except ValueError as exc:
        raise ValueError(
            f'Unexpected AgeDB file name {path.stem!r}: expected '
            "'<prefix>_<identity>_<age>_<gender>'"
        ) from exc


def link_or_copy(src, dst, copy_files):
    """Materialise ``src`` at ``dst`` as a relative symlink or as a copy.

    The parent directory of ``dst`` is created if needed. Nothing is done
    when ``dst`` already exists; the extra ``is_symlink`` check also covers
    a dangling symlink, for which ``exists()`` is false.

    Args:
        src: Source file path.
        dst: Destination ``pathlib.Path``.
        copy_files: When true, copy with ``shutil.copy2``; otherwise create
            a symlink whose target is relative to ``dst``'s directory.
    """
    dst.parent.mkdir(parents=True, exist_ok=True)
    if dst.exists() or dst.is_symlink():
        return
    if copy_files:
        shutil.copy2(src, dst)
    else:
        os.symlink(os.path.relpath(src, dst.parent), dst)


def build_positive_pairs(test_rows, max_pairs, min_age_gap=30):
    """Return same-identity pairs whose age gap is at least ``min_age_gap``.

    Candidates are ordered by descending age gap, ties broken by the file
    names of the left and right image, so the result is deterministic.

    Args:
        test_rows: Iterable of dicts with keys ``'identity'``, ``'age'`` and
            ``'path'`` (a ``pathlib.Path``).
        max_pairs: Slice bound applied to the ordered candidates; ``None``
            keeps all of them.
        min_age_gap: Minimum absolute age difference for a pair to qualify.

    Returns:
        List of ``(left_row, right_row)`` tuples.
    """
    by_identity = defaultdict(list)
    for row in test_rows:
        by_identity[row['identity']].append(row)

    candidates = []
    for rows in by_identity.values():
        rows = sorted(rows, key=lambda x: (x['age'], x['path'].name))
        for i, left in enumerate(rows):
            for right in rows[i + 1:]:
                age_gap = abs(left['age'] - right['age'])
                if age_gap >= min_age_gap:
                    candidates.append((age_gap, left, right))

    candidates.sort(key=lambda x: (-x[0], x[1]['path'].name, x[2]['path'].name))
    return [(left, right) for _, left, right in candidates[:max_pairs]]


def build_negative_pairs(test_rows, count, seed):
    """Randomly draw ``count`` distinct pairs of images of different identities.

    Two distinct identities are sampled, then one image from each. A pair is
    identified by its unordered file names so the same two images are never
    used twice. Sampling gives up after ``count * 100`` attempts.

    Args:
        test_rows: Iterable of dicts with keys ``'identity'`` and ``'path'``.
        count: Number of pairs to build; values ``<= 0`` yield ``[]``.
        seed: Seed for the private ``random.Random`` instance.

    Returns:
        List of ``(left_row, right_row)`` tuples.

    Raises:
        ValueError: If pairs are requested but fewer than two identities
            are available.
        RuntimeError: If fewer than ``count`` unique pairs were found within
            the attempt budget.
    """
    rng = random.Random(seed)
    # Sort first so the draw does not depend on the caller's row order.
    rows = sorted(test_rows, key=lambda x: x['path'].name)
    by_identity = defaultdict(list)
    for row in rows:
        by_identity[row['identity']].append(row)
    identities = sorted(by_identity)

    if count > 0 and len(identities) < 2:
        raise ValueError(
            f'Need at least two identities to build negative pairs, got {len(identities)}'
        )

    pairs = []
    used = set()
    attempts = 0
    while len(pairs) < count and attempts < count * 100:
        attempts += 1
        left_id, right_id = rng.sample(identities, 2)
        left = rng.choice(by_identity[left_id])
        right = rng.choice(by_identity[right_id])
        key = tuple(sorted([left['path'].name, right['path'].name]))
        if key in used:
            continue
        used.add(key)
        pairs.append((left, right))

    if len(pairs) < count:
        raise RuntimeError(f'Only built {len(pairs)} negative pairs, expected {count}')
    return pairs


def _parse_args():
    """Parse and validate the command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--source', default='AgeDB_aligned_224')
    parser.add_argument('--output', default='data/agedb_protocol')
    parser.add_argument('--train-ratio', type=float, default=0.8)
    parser.add_argument('--seed', type=int, default=2048)
    parser.add_argument('--max-pairs', type=int, default=3000)
    parser.add_argument('--copy-files', action='store_true')
    args = parser.parse_args()
    if not 0.0 <= args.train_ratio <= 1.0:
        parser.error('--train-ratio must be between 0 and 1')
    return args


def _scan_source(source):
    """Group the ``*.jpg`` files directly inside ``source`` by identity."""
    by_identity = defaultdict(list)
    for path in sorted(source.glob('*.jpg')):
        identity, age, gender = parse_agedb_name(path)
        by_identity[identity].append(
            {'path': path, 'identity': identity, 'age': age, 'gender': gender}
        )
    return by_identity


def _split_rows(by_identity, train_ratio, seed):
    """Split every identity's images into ``(train_rows, test_rows)``.

    Identities with a single image go entirely to train; all others
    contribute at least one test image.
    """
    rng = random.Random(seed)
    train_rows = []
    test_rows = []
    for _identity, rows in sorted(by_identity.items()):
        rows = sorted(rows, key=lambda x: x['path'].name)
        indices = list(range(len(rows)))
        rng.shuffle(indices)
        if len(rows) > 1:
            n_test = max(1, int(round(len(rows) * (1 - train_ratio))))
        else:
            n_test = 0
        test_indices = set(indices[:n_test])
        for idx, row in enumerate(rows):
            if idx in test_indices:
                test_rows.append(row)
            else:
                train_rows.append(row)
    return train_rows, test_rows


def _build_pair_rows(positive_pairs, negative_pairs):
    """Flatten pairs into CSV rows; pair ``k`` occupies indices ``2k``, ``2k+1``."""
    pair_rows = []
    pair_index = 0
    for is_same, pairs in [(True, positive_pairs), (False, negative_pairs)]:
        for left, right in pairs:
            pair_rows.append(
                {'path': str(left['path']), 'index': pair_index * 2, 'is_same': is_same}
            )
            pair_rows.append(
                {'path': str(right['path']), 'index': pair_index * 2 + 1, 'is_same': is_same}
            )
            pair_index += 1
    return pair_rows


def _write_pairs_csv(csv_path, pair_rows):
    """Write the pair rows to ``csv_path``, creating its directory if needed."""
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    # Explicit encoding keeps the file identical across platforms/locales.
    with open(csv_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=['path', 'index', 'is_same'])
        writer.writeheader()
        writer.writerows(pair_rows)


def main():
    """Create the train folder tree and ``pairs.csv`` from the CLI arguments.

    NOTE: existing destination files are never replaced or removed, so
    re-running into the same ``--output`` with a different seed or ratio can
    leave training links from the earlier run next to the new ones. Use a
    fresh output directory when changing the split.

    Raises:
        RuntimeError: If the source has no ``*.jpg`` files, if the test
            split yields no positive pairs, or if not enough negative pairs
            can be drawn.
    """
    args = _parse_args()

    source = Path(args.source).resolve()
    output = Path(args.output).resolve()
    train_dir = output / 'agedb_train_80'
    val_dir = output / 'facerec_val' / 'agedb_30_1to1'

    by_identity = _scan_source(source)
    if not by_identity:
        raise RuntimeError(f'No .jpg images found in {source}')

    train_rows, test_rows = _split_rows(by_identity, args.train_ratio, args.seed)

    for row in train_rows:
        dst = train_dir / row['identity'] / row['path'].name
        link_or_copy(row['path'], dst, args.copy_files)

    positive_pairs = build_positive_pairs(test_rows, args.max_pairs)
    if not positive_pairs:
        raise RuntimeError('No AgeDB-30 positive pairs found in the test split')
    # One negative pair per positive pair keeps the verification set balanced.
    negative_pairs = build_negative_pairs(test_rows, len(positive_pairs), args.seed)

    pair_rows = _build_pair_rows(positive_pairs, negative_pairs)
    _write_pairs_csv(val_dir / 'pairs.csv', pair_rows)

    print(f'identities: {len(by_identity)}')
    print(f'train images: {len(train_rows)}')
    print(f'test images: {len(test_rows)}')
    print(f'verification pairs: {len(pair_rows) // 2} ({len(positive_pairs)} positive, {len(negative_pairs)} negative)')
    print(f'train_dir: {train_dir}')
    print(f'val_dir: {val_dir}')


if __name__ == '__main__':
    main()