# lab1-cvlface-code/cvlface/research/recognition/code/run_v1/scripts/prepare_agedb_protocol.py
| import argparse | |
| import csv | |
| import os | |
| import random | |
| import shutil | |
| from collections import defaultdict | |
| from pathlib import Path | |
def parse_agedb_name(path):
    """Parse identity, age and gender from an AgeDB image file name.

    The stem is expected to look like ``<prefix>_<identity>_<age>_<gender>``:
    the last two ``_``-separated fields are age and gender, the first field is
    discarded, and everything in between (which may itself contain ``_``) is
    the identity.

    Args:
        path: ``pathlib.Path`` of the image; only ``path.stem`` is inspected.

    Returns:
        ``(identity, age, gender)`` with ``age`` converted to ``int``.

    Raises:
        ValueError: if the stem does not have the expected shape or the age
            field is not an integer. The message names the offending file so a
            stray file in the source directory is easy to locate.
    """
    fields = path.stem.rsplit('_', 2)
    # Need three fields, and the leading one must still hold "<prefix>_<identity>".
    if len(fields) != 3 or '_' not in fields[0]:
        raise ValueError(
            f'Unexpected AgeDB file name {path.name!r}; '
            'expected <prefix>_<identity>_<age>_<gender>'
        )
    left, age, gender = fields
    _, identity = left.split('_', 1)
    try:
        age_value = int(age)
    except ValueError:
        raise ValueError(
            f'Non-integer age field {age!r} in AgeDB file name {path.name!r}'
        ) from None
    return identity, age_value, gender
def link_or_copy(src, dst, copy_files):
    """Place ``src`` at ``dst`` either as a file copy or as a relative symlink.

    Parent directories of ``dst`` are created as needed. If something already
    occupies ``dst`` (a regular file, a directory, or even a dangling symlink)
    nothing is done, which makes repeated runs cheap.

    Args:
        src: existing source file.
        dst: destination ``pathlib.Path``.
        copy_files: when true, copy with ``shutil.copy2`` (preserving
            metadata); otherwise create a symlink whose target is ``src``
            expressed relative to ``dst``'s directory.
    """
    parent = dst.parent
    parent.mkdir(parents=True, exist_ok=True)

    # is_symlink() also catches broken links, for which exists() is False.
    already_present = dst.is_symlink() or dst.exists()
    if already_present:
        return

    if not copy_files:
        relative_target = os.path.relpath(src, parent)
        os.symlink(relative_target, dst)
        return

    shutil.copy2(src, dst)
def build_positive_pairs(test_rows, max_pairs, min_age_gap=30):
    """Select same-identity pairs with a large age gap, widest gaps first.

    Every unordered pair of images of the same identity whose absolute
    difference in ``age`` is at least ``min_age_gap`` is a candidate.
    Candidates are ranked by descending age gap, ties broken by the left and
    then the right file name, and the first ``max_pairs`` are returned. Within
    a pair the left row has the smaller-or-equal age (rows are ordered by
    ``(age, file name)`` before pairing).

    Args:
        test_rows: row dicts with at least ``identity``, ``age`` (int) and
            ``path`` (a ``pathlib.Path``).
        max_pairs: maximum number of pairs to return.
        min_age_gap: minimum age difference for a pair to qualify; the default
            of 30 matches the AgeDB-30 setting this script targets.

    Returns:
        List of ``(left_row, right_row)`` tuples.

    Raises:
        ValueError: if ``max_pairs`` is negative (a negative slice bound would
            otherwise silently drop pairs from the end instead).
    """
    if max_pairs < 0:
        raise ValueError(f'max_pairs must be non-negative, got {max_pairs}')

    by_identity = defaultdict(list)
    for row in test_rows:
        by_identity[row['identity']].append(row)

    candidates = []
    for rows in by_identity.values():
        rows = sorted(rows, key=lambda x: (x['age'], x['path'].name))
        for i, left in enumerate(rows):
            for right in rows[i + 1:]:
                age_gap = abs(left['age'] - right['age'])
                if age_gap >= min_age_gap:
                    candidates.append((age_gap, left, right))

    # Deterministic ranking: widest gap first, then by file names.
    candidates.sort(key=lambda x: (-x[0], x[1]['path'].name, x[2]['path'].name))
    return [(left, right) for _, left, right in candidates[:max_pairs]]
def build_negative_pairs(test_rows, count, seed):
    """Sample ``count`` distinct different-identity pairs from ``test_rows``.

    Sampling is deterministic for a given ``seed``. Rows are first ordered by
    file name and identities sorted. Each attempt then draws two distinct
    identities and one image from each. A pair whose two file names were
    already used (in either order) is rejected and sampling continues.

    Args:
        test_rows: row dicts with at least ``identity`` and ``path`` (a
            ``pathlib.Path``).
        count: number of pairs to return; ``count <= 0`` returns ``[]``.
        seed: seed for a private ``random.Random`` instance.

    Returns:
        List of ``(left_row, right_row)`` tuples whose identities differ.

    Raises:
        RuntimeError: if fewer than two identities are available, or if
            ``count`` distinct pairs cannot be found within ``count * 100``
            sampling attempts.
    """
    if count <= 0:
        return []

    rng = random.Random(seed)
    rows = sorted(test_rows, key=lambda x: x['path'].name)
    by_identity = defaultdict(list)
    for row in rows:
        by_identity[row['identity']].append(row)
    identities = sorted(by_identity)
    # rng.sample(identities, 2) would otherwise fail with an opaque ValueError.
    if len(identities) < 2:
        raise RuntimeError(
            f'Need at least 2 identities to build negative pairs, found {len(identities)}'
        )

    pairs = []
    used = set()
    attempts = 0
    # Bounded rejection sampling: give up after count * 100 draws.
    while len(pairs) < count and attempts < count * 100:
        attempts += 1
        left_id, right_id = rng.sample(identities, 2)
        left = rng.choice(by_identity[left_id])
        right = rng.choice(by_identity[right_id])
        # Order-independent key so (a, b) and (b, a) count as the same pair.
        key = tuple(sorted([left['path'].name, right['path'].name]))
        if key in used:
            continue
        used.add(key)
        pairs.append((left, right))
    if len(pairs) < count:
        raise RuntimeError(f'Only built {len(pairs)} negative pairs, expected {count}')
    return pairs
def _load_by_identity(source):
    """Group the ``*.jpg`` files directly inside ``source`` by parsed identity.

    Returns:
        ``defaultdict(list)`` mapping identity -> list of row dicts with keys
        ``path``, ``identity``, ``age`` and ``gender``, appended in sorted
        path order.
    """
    by_identity = defaultdict(list)
    for path in sorted(source.glob('*.jpg')):
        identity, age, gender = parse_agedb_name(path)
        by_identity[identity].append({'path': path, 'identity': identity, 'age': age, 'gender': gender})
    return by_identity


def _split_train_test(by_identity, train_ratio, seed):
    """Split each identity's rows into train and test, reproducibly for ``seed``.

    Identities with a single image go entirely to train. Otherwise about
    ``1 - train_ratio`` of the identity's images (at least one) go to test.

    Returns:
        ``(train_rows, test_rows)``.
    """
    rng = random.Random(seed)
    train_rows = []
    test_rows = []
    # Identity order and per-identity row order are fixed before shuffling so
    # the same seed always yields the same split.
    for identity in sorted(by_identity):
        rows = sorted(by_identity[identity], key=lambda x: x['path'].name)
        indices = list(range(len(rows)))
        rng.shuffle(indices)
        if len(rows) > 1:
            n_test = max(1, int(round(len(rows) * (1 - train_ratio))))
        else:
            n_test = 0
        test_indices = set(indices[:n_test])
        for idx, row in enumerate(rows):
            if idx in test_indices:
                test_rows.append(row)
            else:
                train_rows.append(row)
    return train_rows, test_rows


def _find_stale_train_files(train_dir, expected):
    """Return sorted entries under ``train_dir/*/`` that are not in ``expected``.

    ``link_or_copy`` skips destinations that already exist. Entries left there
    by an earlier run with a different split would otherwise persist silently.
    """
    if not train_dir.is_dir():
        return []
    return sorted(p for p in train_dir.glob('*/*') if p not in expected)


def _write_pairs_csv(csv_path, positive_pairs, negative_pairs):
    """Write verification pairs to ``csv_path`` as two consecutive rows per pair.

    Pair ``k`` is stored as rows with ``index`` ``2k`` (left image) and
    ``2k + 1`` (right image), both carrying the pair's ``is_same`` label. All
    positive pairs are written before the negative ones.

    Returns:
        Number of pairs written.
    """
    labelled = [(True, pair) for pair in positive_pairs] + [(False, pair) for pair in negative_pairs]
    pair_rows = []
    for pair_index, (is_same, (left, right)) in enumerate(labelled):
        pair_rows.append({'path': str(left['path']), 'index': pair_index * 2, 'is_same': is_same})
        pair_rows.append({'path': str(right['path']), 'index': pair_index * 2 + 1, 'is_same': is_same})
    # newline='' is required by the csv module; explicit utf-8 keeps the file
    # independent of the platform's default encoding.
    with open(csv_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=['path', 'index', 'is_same'])
        writer.writeheader()
        writer.writerows(pair_rows)
    return len(labelled)


def main():
    """Build an AgeDB train folder and a verification ``pairs.csv`` from aligned images.

    Images in ``--source`` are split per identity into train and test rows.
    Train images are symlinked, or copied with ``--copy-files``, to
    ``<output>/agedb_train_80/<identity>/``. Test images become same-identity
    pairs with an age difference of at least 30, plus an equal number of
    different-identity pairs. These are written to
    ``<output>/facerec_val/agedb_30_1to1/pairs.csv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--source', default='AgeDB_aligned_224')
    parser.add_argument('--output', default='data/agedb_protocol')
    parser.add_argument('--train-ratio', type=float, default=0.8)
    parser.add_argument('--seed', type=int, default=2048)
    parser.add_argument('--max-pairs', type=int, default=3000)
    parser.add_argument('--copy-files', action='store_true')
    args = parser.parse_args()
    if not 0.0 <= args.train_ratio <= 1.0:
        parser.error('--train-ratio must be between 0 and 1')
    if args.max_pairs < 1:
        parser.error('--max-pairs must be at least 1')

    source = Path(args.source).resolve()
    output = Path(args.output).resolve()
    if not source.is_dir():
        parser.error(f'--source is not a directory: {source}')
    train_dir = output / 'agedb_train_80'
    val_dir = output / 'facerec_val' / 'agedb_30_1to1'

    by_identity = _load_by_identity(source)
    if not by_identity:
        raise RuntimeError(f'No *.jpg images found in {source}')
    train_rows, test_rows = _split_train_test(by_identity, args.train_ratio, args.seed)

    # Refuse to reuse a train folder populated by a different split: stale
    # entries could be images that are now in the test split (train/test leakage).
    expected_train = {train_dir / row['identity'] / row['path'].name for row in train_rows}
    stale = _find_stale_train_files(train_dir, expected_train)
    if stale:
        raise RuntimeError(
            f'{train_dir} contains {len(stale)} entries not in the current train split '
            f'(e.g. {stale[0]}); they may be test images from a run with a different '
            '--seed/--train-ratio. Remove the directory or choose another --output.'
        )

    # Build pairs first so a pair-building failure happens before anything is written.
    positive_pairs = build_positive_pairs(test_rows, args.max_pairs)
    if not positive_pairs:
        raise RuntimeError('No AgeDB-30 positive pairs found in the test split')
    negative_pairs = build_negative_pairs(test_rows, len(positive_pairs), args.seed)

    for row in train_rows:
        dst = train_dir / row['identity'] / row['path'].name
        link_or_copy(row['path'], dst, args.copy_files)

    val_dir.mkdir(parents=True, exist_ok=True)
    n_pairs = _write_pairs_csv(val_dir / 'pairs.csv', positive_pairs, negative_pairs)

    print(f'identities: {len(by_identity)}')
    print(f'train images: {len(train_rows)}')
    print(f'test images: {len(test_rows)}')
    print(f'verification pairs: {n_pairs} ({len(positive_pairs)} positive, {len(negative_pairs)} negative)')
    print(f'train_dir: {train_dir}')
    print(f'val_dir: {val_dir}')
if __name__ == '__main__':
    # Run the preparation pipeline only when executed as a script, not on import.
    main()