File size: 5,005 Bytes
fb24bef | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 | import argparse
import csv
import os
import random
import shutil
from collections import defaultdict
from pathlib import Path
def parse_agedb_name(path):
    """Split an AgeDB file name into ``(identity, age, gender)``.

    The stem is read as ``<prefix>_<identity>_<age>_<gender>``:
    - the last two underscore-separated fields are the age and the gender;
    - the first field is discarded;
    - everything in between, including any underscores, is the identity.

    Args:
        path: ``pathlib.Path`` of the image; only its stem is inspected.

    Returns:
        ``(identity, age, gender)`` with ``age`` converted to ``int``.

    Raises:
        ValueError: If the stem has too few underscore-separated fields or
            the age field is not an integer.
    """
    stem = path.stem
    head, age_field, gender = stem.rsplit('_', 2)
    _prefix, identity = head.split('_', 1)
    return identity, int(age_field), gender
def link_or_copy(src, dst, copy_files):
    """Materialize ``src`` at ``dst`` as a relative symlink or as a copy.

    Missing parent directories of ``dst`` are created first. If ``dst``
    already exists (including as a dangling symlink), it is left untouched
    and nothing is overwritten.

    Args:
        src: Path of the source file.
        dst: Destination path.
        copy_files: When true, copy with ``shutil.copy2``; otherwise create
            a symlink whose target is ``src`` relative to ``dst``'s directory.
    """
    target_dir = dst.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    already_present = dst.is_symlink() or dst.exists()
    if already_present:
        return
    if not copy_files:
        os.symlink(os.path.relpath(src, target_dir), dst)
        return
    shutil.copy2(src, dst)
def build_positive_pairs(test_rows, max_pairs, min_age_gap=30):
    """Select same-identity (genuine) pairs with a large age gap.

    Candidate selection and ranking:
    - Every unordered pair of one identity's rows whose ages differ by at
      least ``min_age_gap`` years is a candidate.
    - Candidates are ranked by descending age gap.
    - Ties are broken by the two file names, which makes the result
      deterministic.
    - The first ``max_pairs`` candidates are returned.

    Args:
        test_rows: Row dicts with at least ``'identity'``, ``'age'`` (int)
            and ``'path'`` (``pathlib.Path``).
        max_pairs: Maximum number of pairs to return; ``None`` returns all
            candidates.
        min_age_gap: Minimum absolute age difference for a candidate pair.
            Defaults to 30, the value this script has always used.

    Returns:
        List of ``(left_row, right_row)`` tuples where ``left_row`` has the
        lower (or equal) age.

    Raises:
        ValueError: If ``max_pairs`` is negative. Slicing with a negative
            bound would silently drop candidates from the end instead.
    """
    from itertools import combinations

    if max_pairs is not None and max_pairs < 0:
        raise ValueError(f'max_pairs must be non-negative or None, got {max_pairs}')
    by_identity = defaultdict(list)
    for row in test_rows:
        by_identity[row['identity']].append(row)
    candidates = []
    for rows in by_identity.values():
        # Sort by age so the left element of every pair is the younger (or same-age) image.
        ordered = sorted(rows, key=lambda x: (x['age'], x['path'].name))
        for left, right in combinations(ordered, 2):
            age_gap = abs(left['age'] - right['age'])
            if age_gap >= min_age_gap:
                candidates.append((age_gap, left, right))
    candidates.sort(key=lambda x: (-x[0], x[1]['path'].name, x[2]['path'].name))
    return [(left, right) for _, left, right in candidates[:max_pairs]]
def build_negative_pairs(test_rows, count, seed):
    """Sample ``count`` distinct cross-identity (impostor) pairs from ``test_rows``.

    Sampling procedure:
    - Each attempt draws two different identities, then one image of each,
      from a private ``random.Random(seed)`` stream.
    - Rows are first sorted by file name, so the result depends only on the
      rows' content and the seed, not on the order of ``test_rows``.
    - A draw is rejected if the same two images (in either order) were
      already used.

    Args:
        test_rows: Row dicts with at least ``'identity'`` and ``'path'``
            (``pathlib.Path``).
        count: Number of pairs to return.
        seed: Seed for the private random generator.

    Returns:
        A list of exactly ``count`` ``(left_row, right_row)`` tuples.

    Raises:
        RuntimeError: If ``count > 0`` but fewer than two identities are
            present, or if ``count`` unique pairs could not be drawn within
            ``count * 100`` attempts.
    """
    rng = random.Random(seed)
    rows = sorted(test_rows, key=lambda x: x['path'].name)
    by_identity = defaultdict(list)
    for row in rows:
        by_identity[row['identity']].append(row)
    identities = sorted(by_identity)
    # Fail explicitly here: rng.sample(identities, 2) would otherwise raise an opaque
    # "Sample larger than population" ValueError.
    if count > 0 and len(identities) < 2:
        raise RuntimeError(
            f'Need at least 2 identities to build negative pairs, found {len(identities)}'
        )
    pairs = []
    used = set()
    max_attempts = count * 100
    attempts = 0
    while len(pairs) < count and attempts < max_attempts:
        attempts += 1
        left_id, right_id = rng.sample(identities, 2)
        left = rng.choice(by_identity[left_id])
        right = rng.choice(by_identity[right_id])
        # Order-independent key so (a, b) and (b, a) count as the same pair.
        key = tuple(sorted([left['path'].name, right['path'].name]))
        if key in used:
            continue
        used.add(key)
        pairs.append((left, right))
    if len(pairs) < count:
        raise RuntimeError(f'Only built {len(pairs)} negative pairs, expected {count}')
    return pairs
def _split_train_test(by_identity, train_ratio, seed):
    """Split each identity's rows into train and test lists, reproducibly for a given seed.

    How the split is made:
    - Identities are processed in sorted order and share one
      ``random.Random(seed)`` stream.
    - Each identity's rows are sorted by file name, then their indices are
      shuffled.
    - Identities with a single image go entirely to train.
    - Otherwise ``round(n * (1 - train_ratio))`` images, and at least one,
      go to test.
    """
    rng = random.Random(seed)
    train_rows = []
    test_rows = []
    for _identity, rows in sorted(by_identity.items()):
        rows = sorted(rows, key=lambda x: x['path'].name)
        indices = list(range(len(rows)))
        rng.shuffle(indices)
        if len(rows) > 1:
            n_test = max(1, int(round(len(rows) * (1 - train_ratio))))
        else:
            n_test = 0
        test_indices = set(indices[:n_test])
        for idx, row in enumerate(rows):
            if idx in test_indices:
                test_rows.append(row)
            else:
                train_rows.append(row)
    return train_rows, test_rows


def _find_stale_train_files(train_dir, expected):
    """Return, sorted, the files or symlinks under ``train_dir`` that are not in ``expected``."""
    if not train_dir.exists():
        return []
    return sorted(
        entry for entry in train_dir.rglob('*')
        if (entry.is_symlink() or entry.is_file()) and entry not in expected
    )


def _write_pairs_csv(csv_path, positive_pairs, negative_pairs):
    """Write verification pairs to ``csv_path`` and return the number of pairs written.

    Pair ``k`` occupies two consecutive rows with indices ``2k`` and ``2k + 1``.
    Positive pairs come first with ``is_same=True``, followed by negative
    pairs with ``is_same=False``.
    """
    labelled = [(True, pair) for pair in positive_pairs] + [(False, pair) for pair in negative_pairs]
    pair_rows = []
    for pair_index, (is_same, (left, right)) in enumerate(labelled):
        pair_rows.append({'path': str(left['path']), 'index': pair_index * 2, 'is_same': is_same})
        pair_rows.append({'path': str(right['path']), 'index': pair_index * 2 + 1, 'is_same': is_same})
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    # Use explicit UTF-8 so non-ASCII identity names in paths do not depend on the locale encoding.
    with open(csv_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=['path', 'index', 'is_same'])
        writer.writeheader()
        writer.writerows(pair_rows)
    return len(labelled)


def main():
    """Build the AgeDB train split and an AgeDB-30 1:1 verification pairs CSV.

    Steps:
    - Images ``--source/*.jpg`` named ``<prefix>_<identity>_<age>_<gender>.jpg``
      are split per identity into train and test sets.
    - Train images are symlinked (or copied with ``--copy-files``) into
      ``<output>/agedb_train_80/<identity>/``.
    - Test images yield positive pairs (same identity, age gap of at least
      30 years) and an equal number of negative pairs.
    - The pairs are written to
      ``<output>/facerec_val/agedb_30_1to1/pairs.csv``.

    Raises:
        RuntimeError: If no positive pairs can be formed, or if the train
            directory holds files that are not part of the current split.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--source', default='AgeDB_aligned_224')
    parser.add_argument('--output', default='data/agedb_protocol')
    parser.add_argument('--train-ratio', type=float, default=0.8)
    parser.add_argument('--seed', type=int, default=2048)
    parser.add_argument('--max-pairs', type=int, default=3000)
    parser.add_argument('--copy-files', action='store_true')
    args = parser.parse_args()
    if not 0.0 < args.train_ratio < 1.0:
        parser.error(f'--train-ratio must be strictly between 0 and 1, got {args.train_ratio}')
    if args.max_pairs < 1:
        parser.error(f'--max-pairs must be positive, got {args.max_pairs}')

    source = Path(args.source).resolve()
    output = Path(args.output).resolve()
    if not source.is_dir():
        parser.error(f'--source is not a directory: {source}')
    train_dir = output / 'agedb_train_80'
    val_dir = output / 'facerec_val' / 'agedb_30_1to1'

    by_identity = defaultdict(list)
    for path in sorted(source.glob('*.jpg')):
        identity, age, gender = parse_agedb_name(path)
        by_identity[identity].append({'path': path, 'identity': identity, 'age': age, 'gender': gender})
    if not by_identity:
        parser.error(f'no *.jpg images found in {source}')

    train_rows, test_rows = _split_train_test(by_identity, args.train_ratio, args.seed)

    # Build all pairs before touching the output tree, so a pairing failure leaves no partial output.
    positive_pairs = build_positive_pairs(test_rows, args.max_pairs)
    if not positive_pairs:
        raise RuntimeError('No AgeDB-30 positive pairs found in the test split')
    negative_pairs = build_negative_pairs(test_rows, len(positive_pairs), args.seed)

    # link_or_copy skips existing destinations and never deletes anything. Without this check,
    # files from an earlier run with a different --seed/--train-ratio would silently stay in
    # the train set, and could include images that are now in the test split.
    expected = {train_dir / row['identity'] / row['path'].name for row in train_rows}
    stale = _find_stale_train_files(train_dir, expected)
    if stale:
        raise RuntimeError(
            f'{len(stale)} file(s) in {train_dir} are not part of the current split '
            f'(e.g. {stale[0]}); they are probably left over from a run with a different '
            '--seed or --train-ratio and could leak test images into training. '
            'Remove that directory and re-run.'
        )
    for row in train_rows:
        dst = train_dir / row['identity'] / row['path'].name
        link_or_copy(row['path'], dst, args.copy_files)

    n_pairs = _write_pairs_csv(val_dir / 'pairs.csv', positive_pairs, negative_pairs)

    print(f'identities: {len(by_identity)}')
    print(f'train images: {len(train_rows)}')
    print(f'test images: {len(test_rows)}')
    print(f'verification pairs: {n_pairs} ({len(positive_pairs)} positive, {len(negative_pairs)} negative)')
    print(f'train_dir: {train_dir}')
    print(f'val_dir: {val_dir}')


if __name__ == '__main__':
    main()
|