# lab1-cvlface-code/cvlface/research/recognition/code/run_v1/scripts/prepare_agedb_protocol.py
# Uploaded by lhx05 ("Upload CVLFace experiment code", commit fb24bef, verified).
import argparse
import csv
import os
import random
import shutil
from collections import defaultdict
from pathlib import Path
def parse_agedb_name(path):
    """Split an AgeDB image file name into ``(identity, age, gender)``.

    The stem is expected to look like ``<index>_<identity>_<age>_<gender>``,
    e.g. ``0_MariaCallas_35_f.jpg`` -> ``('MariaCallas', 35, 'f')``. Age and
    gender are split off from the right and the index from the left, so an
    identity that itself contains underscores is kept intact. The leading
    index is discarded.

    Args:
        path: ``pathlib.Path`` of the image file.

    Returns:
        Tuple ``(identity, age, gender)`` with ``age`` converted to ``int``.

    Raises:
        ValueError: if the file name does not follow the expected pattern
            (the message names the offending file).
    """
    expected = '<index>_<identity>_<age>_<gender>'
    parts = path.stem.rsplit('_', 2)
    if len(parts) != 3:
        raise ValueError(f'Unexpected AgeDB file name (expected {expected}): {path.name}')
    left, age, gender = parts
    index_and_identity = left.split('_', 1)
    if len(index_and_identity) != 2:
        raise ValueError(f'Unexpected AgeDB file name (expected {expected}): {path.name}')
    identity = index_and_identity[1]
    try:
        age_years = int(age)
    except ValueError as err:
        raise ValueError(f'Non-integer age {age!r} in AgeDB file name: {path.name}') from err
    return identity, age_years, gender
def link_or_copy(src, dst, copy_files):
    """Materialise ``src`` at ``dst`` either as a relative symlink or a copy.

    The parent directories of ``dst`` are created on demand. If anything is
    already present at ``dst`` (including a dangling symlink), it is left
    untouched, so re-running the script does not fail on existing entries.

    Args:
        src: Path of the source image.
        dst: Destination path to create.
        copy_files: When true, copy with ``shutil.copy2`` (which also copies
            file metadata); otherwise create a symlink whose target is
            expressed relative to ``dst``'s directory.
    """
    target_dir = dst.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    already_present = dst.is_symlink() or dst.exists()
    if not already_present:
        if copy_files:
            shutil.copy2(src, dst)
        else:
            relative_target = os.path.relpath(src, target_dir)
            os.symlink(relative_target, dst)
def build_positive_pairs(test_rows, max_pairs, min_age_gap=30):
    """Select same-identity (genuine) pairs with the largest age gaps.

    Every within-identity pair whose absolute age difference is at least
    ``min_age_gap`` years is a candidate. Candidates are ranked by age gap
    (largest first), with ties broken by the two file names so the result is
    deterministic, and the first ``max_pairs`` are returned.

    Args:
        test_rows: Iterable of row dicts with at least the keys
            ``'identity'``, ``'age'`` (int) and ``'path'`` (``pathlib.Path``).
        max_pairs: Maximum number of pairs to return (``None`` keeps all).
        min_age_gap: Minimum age difference in years for a pair to qualify.
            Defaults to 30, the threshold this function always used.

    Returns:
        List of ``(left, right)`` row tuples; ``left`` is the row with the
        lower (or equal) age.
    """
    by_identity = defaultdict(list)
    for row in test_rows:
        by_identity[row['identity']].append(row)
    candidates = []
    for rows in by_identity.values():
        # Sorting by age makes the younger image the left element of each pair.
        rows = sorted(rows, key=lambda x: (x['age'], x['path'].name))
        for i, left in enumerate(rows):
            for right in rows[i + 1:]:
                age_gap = abs(left['age'] - right['age'])
                if age_gap >= min_age_gap:
                    candidates.append((age_gap, left, right))
    candidates.sort(key=lambda x: (-x[0], x[1]['path'].name, x[2]['path'].name))
    return [(left, right) for _, left, right in candidates[:max_pairs]]
def build_negative_pairs(test_rows, count, seed):
    """Sample ``count`` distinct cross-identity (impostor) pairs.

    Each attempt draws two different identities, then one image of each,
    using a ``random.Random(seed)`` instance. Rows are sorted by file name
    before grouping, so the output depends only on the set of rows and the
    seed, not on the input order. A pair and its mirror image count as the
    same pair. Ages are not taken into account.

    Args:
        test_rows: Iterable of row dicts with at least the keys
            ``'identity'`` and ``'path'`` (``pathlib.Path``).
        count: Number of pairs to produce.
        seed: Seed for the private random generator.

    Returns:
        List of ``count`` ``(left, right)`` row tuples.

    Raises:
        RuntimeError: if pairs are requested but fewer than two identities
            are available, or if ``count`` unique pairs cannot be found
            within ``count * 100`` sampling attempts.
    """
    rng = random.Random(seed)
    rows = sorted(test_rows, key=lambda x: x['path'].name)
    by_identity = defaultdict(list)
    for row in rows:
        by_identity[row['identity']].append(row)
    identities = sorted(by_identity)
    # rng.sample(identities, 2) below would otherwise fail with an opaque
    # "Sample larger than population" error.
    if count > 0 and len(identities) < 2:
        raise RuntimeError(
            f'Need at least 2 identities to build negative pairs, found {len(identities)}'
        )
    pairs = []
    used = set()
    attempts = 0
    max_attempts = count * 100
    while len(pairs) < count and attempts < max_attempts:
        attempts += 1
        left_id, right_id = rng.sample(identities, 2)
        left = rng.choice(by_identity[left_id])
        right = rng.choice(by_identity[right_id])
        # Order-independent key so (a, b) and (b, a) are treated as duplicates.
        key = tuple(sorted([left['path'].name, right['path'].name]))
        if key in used:
            continue
        used.add(key)
        pairs.append((left, right))
    if len(pairs) < count:
        raise RuntimeError(f'Only built {len(pairs)} negative pairs, expected {count}')
    return pairs
def _load_rows(source):
    """Group every ``*.jpg`` directly under ``source`` into per-identity row dicts."""
    by_identity = defaultdict(list)
    for path in sorted(source.glob('*.jpg')):
        identity, age, gender = parse_agedb_name(path)
        by_identity[identity].append({'path': path, 'identity': identity, 'age': age, 'gender': gender})
    return by_identity


def _split_rows(by_identity, train_ratio, seed):
    """Split each identity's images into train/test rows using one seeded RNG."""
    rng = random.Random(seed)
    train_rows = []
    test_rows = []
    # Identities are visited in sorted order so the RNG stream, and therefore
    # the split, is reproducible for a given seed.
    for identity in sorted(by_identity):
        rows = sorted(by_identity[identity], key=lambda x: x['path'].name)
        indices = list(range(len(rows)))
        rng.shuffle(indices)
        # Single-image identities stay entirely in train; any identity with
        # more images holds out at least one image, even when train_ratio == 1.
        if len(rows) > 1:
            n_test = max(1, int(round(len(rows) * (1 - train_ratio))))
        else:
            n_test = 0
        test_indices = set(indices[:n_test])
        for idx, row in enumerate(rows):
            if idx in test_indices:
                test_rows.append(row)
            else:
                train_rows.append(row)
    return train_rows, test_rows


def _find_stale_train_entries(train_dir, train_rows):
    """Return ``<identity>/<file>`` entries in ``train_dir`` that are not in this train split."""
    if not train_dir.is_dir():
        return []
    expected = {train_dir / row['identity'] / row['path'].name for row in train_rows}
    return sorted(p for p in train_dir.glob('*/*') if p not in expected)


def _write_pairs_csv(csv_path, positive_pairs, negative_pairs):
    """Write two consecutive CSV rows (indices 2k, 2k+1) per pair and return the rows."""
    pair_rows = []
    pair_index = 0
    for is_same, pairs in [(True, positive_pairs), (False, negative_pairs)]:
        for left, right in pairs:
            pair_rows.append({'path': str(left['path']), 'index': pair_index * 2, 'is_same': is_same})
            pair_rows.append({'path': str(right['path']), 'index': pair_index * 2 + 1, 'is_same': is_same})
            pair_index += 1
    with open(csv_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=['path', 'index', 'is_same'])
        writer.writeheader()
        writer.writerows(pair_rows)
    return pair_rows


def main():
    """Build the AgeDB train split and an AgeDB-30 style 1:1 verification list.

    Reads ``*.jpg`` files from ``--source`` and splits each identity's images
    into train/test with a seeded shuffle. Train images are written under
    ``<output>/agedb_train_80/<identity>/`` as symlinks, or as copies with
    ``--copy-files``. The script then writes
    ``<output>/facerec_val/agedb_30_1to1/pairs.csv``, which holds equal
    numbers of positive and negative pairs drawn only from the test images.

    Raises:
        RuntimeError: if no images are found, if the train directory still
            holds entries from a different split (use ``--overwrite``), or if
            no positive pairs can be formed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--source', default='AgeDB_aligned_224')
    parser.add_argument('--output', default='data/agedb_protocol')
    parser.add_argument('--train-ratio', type=float, default=0.8)
    parser.add_argument('--seed', type=int, default=2048)
    parser.add_argument('--max-pairs', type=int, default=3000)
    parser.add_argument('--copy-files', action='store_true')
    parser.add_argument(
        '--overwrite',
        action='store_true',
        help='delete an existing train directory before writing the new split',
    )
    args = parser.parse_args()
    if not 0.0 <= args.train_ratio <= 1.0:
        parser.error(f'--train-ratio must be in [0, 1], got {args.train_ratio}')
    if args.max_pairs < 1:
        parser.error(f'--max-pairs must be a positive integer, got {args.max_pairs}')

    source = Path(args.source).resolve()
    output = Path(args.output).resolve()
    train_dir = output / 'agedb_train_80'
    val_dir = output / 'facerec_val' / 'agedb_30_1to1'

    by_identity = _load_rows(source)
    if not by_identity:
        raise RuntimeError(f'No *.jpg images found in {source}')
    train_rows, test_rows = _split_rows(by_identity, args.train_ratio, args.seed)

    # link_or_copy() keeps anything already present, so leftovers from a run
    # with a different seed/ratio would silently put test images into train.
    if args.overwrite and train_dir.exists():
        shutil.rmtree(train_dir)
    stale = _find_stale_train_entries(train_dir, train_rows)
    if stale:
        raise RuntimeError(
            f'{train_dir} already contains {len(stale)} entries that are not part of the '
            f'current train split (e.g. {stale[0]}); re-run with --overwrite or delete it'
        )
    for row in train_rows:
        dst = train_dir / row['identity'] / row['path'].name
        link_or_copy(row['path'], dst, args.copy_files)

    positive_pairs = build_positive_pairs(test_rows, args.max_pairs)
    if not positive_pairs:
        raise RuntimeError('No AgeDB-30 positive pairs found in the test split')
    negative_pairs = build_negative_pairs(test_rows, len(positive_pairs), args.seed)

    val_dir.mkdir(parents=True, exist_ok=True)
    pair_rows = _write_pairs_csv(val_dir / 'pairs.csv', positive_pairs, negative_pairs)

    print(f'identities: {len(by_identity)}')
    print(f'train images: {len(train_rows)}')
    print(f'test images: {len(test_rows)}')
    print(f'verification pairs: {len(pair_rows) // 2} ({len(positive_pairs)} positive, {len(negative_pairs)} negative)')
    print(f'train_dir: {train_dir}')
    print(f'val_dir: {val_dir}')


if __name__ == '__main__':
    main()