"""
scripts/data_setup.py

Utilities to populate ``data/<split>/real/`` with face images from several
public sources, and to generate morphed faces by pixel-averaging random pairs
of real images into ``data/<split>/morph/``.

Optional dependencies are imported defensively so that functions which do not
need them keep working when they are absent:

* ``requests``     -- every network fetcher
* ``Pillow``       -- ``generate_morphs`` and ``fetch_lfw``
* ``scikit-learn`` -- ``fetch_lfw``
* Kaggle CLI       -- ``fetch_celeba`` (invoked as a subprocess)
"""
import glob
import os
import random
import shutil
import subprocess
import tarfile
import tempfile
import uuid

import numpy as np

try:
    import requests
except ImportError:
    requests = None

try:
    from PIL import Image
except ImportError:
    Image = None

try:
    from sklearn.datasets import fetch_lfw_people
except ImportError:
    fetch_lfw_people = None

# Base data directory, resolved against the working directory at import time.
DATA_DIR = os.path.join(os.getcwd(), 'data')

# (width, height) of every image written by generate_morphs and fetch_lfw.
IMAGE_SIZE = (224, 224)

# File extensions treated as images when scanning data/<split>/real/.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

# Timeout in seconds for API calls and image downloads; without one a stalled
# server would block requests.get indefinitely.
REQUEST_TIMEOUT = 30

# Per-request page-size limits imposed by the respective APIs.
UNSPLASH_MAX_COUNT = 30
PEXELS_MAX_PER_PAGE = 80
PIXABAY_MIN_PER_PAGE = 3
PIXABAY_MAX_PER_PAGE = 200


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------

def _require_requests(message: str) -> None:
    """Raise ImportError(message) when the ``requests`` library is unavailable."""
    if requests is None:
        raise ImportError(message)


def _require_pil() -> None:
    """Raise ImportError when Pillow is unavailable."""
    if Image is None:
        raise ImportError('Pillow is required to process images')


def _real_dir(split: str) -> str:
    """Create (if needed) and return the path data/<split>/real/."""
    save_dir = os.path.join(DATA_DIR, split, 'real')
    os.makedirs(save_dir, exist_ok=True)
    return save_dir


def _get_json(url: str, source: str, **kwargs):
    """GET ``url`` (extra kwargs forwarded to requests.get) and return its JSON body.

    Raises:
        RuntimeError: if the response status is not 200; the message names
            ``source`` and includes the status code and response text.
    """
    resp = requests.get(url, timeout=REQUEST_TIMEOUT, **kwargs)
    if resp.status_code != 200:
        raise RuntimeError(f'{source} API error {resp.status_code}: {resp.text}')
    return resp.json()


def _download_images(urls, save_dir: str, limit: int) -> int:
    """Save up to ``limit`` images from ``urls`` into ``save_dir`` as <uuid4>.jpg.

    Empty/None URLs and non-200 responses are skipped, so an HTTP error page is
    never written to disk disguised as a JPEG. Returns the number of files saved.
    """
    saved = 0
    for img_url in urls:
        if saved >= limit:
            break
        if not img_url:
            continue
        resp = requests.get(img_url, timeout=REQUEST_TIMEOUT)
        if resp.status_code != 200:
            continue
        fname = os.path.join(save_dir, f"{uuid.uuid4()}.jpg")
        with open(fname, 'wb') as f:
            f.write(resp.content)
        saved += 1
    return saved


def _load_rgb(path: str) -> np.ndarray:
    """Load ``path`` as an IMAGE_SIZE RGB float32 array, closing the file handle."""
    with Image.open(path) as img:
        return np.asarray(img.convert('RGB').resize(IMAGE_SIZE), dtype=np.float32)


def _to_uint8(pixels) -> np.ndarray:
    """Convert a float image to uint8, accepting either [0, 1] or [0, 255] scaling.

    Values whose maximum is <= 1.0 are treated as [0, 1] and scaled by 255;
    anything else is assumed to already be on a 0-255 scale. The result is
    rounded and clipped, so out-of-range values cannot wrap around in the cast.
    NOTE(review): the range of ``fetch_lfw_people(...).images`` has differed
    between scikit-learn releases -- this heuristic covers both; an all-black
    0-255 image (max <= 1) would be over-brightened, which is assumed negligible.
    """
    arr = np.asarray(pixels, dtype=np.float32)
    if arr.size and arr.max() <= 1.0:
        arr = arr * 255.0
    return np.clip(np.rint(arr), 0, 255).astype(np.uint8)


def _extract_jpegs(archive_path: str, save_dir: str, limit: int) -> int:
    """Extract up to ``limit`` ``.jpg`` members of a .tar.gz into ``save_dir``.

    Members are written under their basename only, so archive paths cannot
    escape ``save_dir``. Returns the number of files extracted.
    """
    extracted = 0
    with tarfile.open(archive_path, 'r:gz') as tar:
        # Iterate lazily: stops decompressing as soon as `limit` is reached,
        # whereas getmembers() would first scan the entire archive.
        for member in tar:
            if extracted >= limit:
                break
            if not (member.isfile() and member.name.lower().endswith('.jpg')):
                continue
            src = tar.extractfile(member)
            if src is None:
                continue
            outpath = os.path.join(save_dir, os.path.basename(member.name))
            with src, open(outpath, 'wb') as out:
                shutil.copyfileobj(src, out)
            extracted += 1
    return extracted


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def fetch_real_faces(count: int, split: str = 'train') -> int:
    """
    Fetch random face images from Unsplash into data/<split>/real/.

    Requires environment variable UNSPLASH_ACCESS_KEY to be set. The Unsplash
    random-photo endpoint returns at most UNSPLASH_MAX_COUNT photos per call,
    so larger requests are split into several calls.

    Args:
        count: number of photos to request; ``<= 0`` returns 0 immediately.
        split: dataset split sub-directory (e.g. 'train').

    Returns:
        The number of images downloaded (may be below ``count`` when photos
        lack a 'small' URL or an image download does not return HTTP 200).

    Raises:
        ImportError: if ``requests`` is not installed.
        EnvironmentError: if UNSPLASH_ACCESS_KEY is not set.
        RuntimeError: if the Unsplash API returns a non-200 status.
    """
    if count <= 0:
        return 0
    _require_requests('requests library is required to fetch images')
    access_key = os.getenv('UNSPLASH_ACCESS_KEY')
    if not access_key:
        raise EnvironmentError('UNSPLASH_ACCESS_KEY environment variable not set')
    save_dir = _real_dir(split)

    # Unsplash random photo endpoint
    url = 'https://api.unsplash.com/photos/random'
    downloaded = 0
    remaining = count
    while remaining > 0:
        batch = min(remaining, UNSPLASH_MAX_COUNT)
        params = {
            'client_id': access_key,
            'query': 'face portrait',
            'count': batch,
        }
        items = _get_json(url, 'Unsplash', params=params)
        urls = [item.get('urls', {}).get('small') for item in items]
        downloaded += _download_images(urls, save_dir, batch)
        remaining -= batch
    return downloaded


def generate_morphs(count: int, split: str = 'train') -> int:
    """
    Generate morphed images by averaging random pairs from data/<split>/real/,
    saving them to data/<split>/morph/.

    Each morph is the per-pixel mean of two distinct real images, both
    converted to RGB and resized to IMAGE_SIZE first.

    Args:
        count: number of morphs to create; ``<= 0`` returns 0 immediately.
        split: dataset split sub-directory (e.g. 'train').

    Returns:
        The number of morphs written.

    Raises:
        ImportError: if Pillow is not installed.
        RuntimeError: if fewer than two real images are available (a missing
            real/ directory counts as having none).
    """
    if count <= 0:
        return 0
    _require_pil()
    real_dir = os.path.join(DATA_DIR, split, 'real')
    morph_dir = os.path.join(DATA_DIR, split, 'morph')
    os.makedirs(morph_dir, exist_ok=True)

    # Sorted so pair selection is reproducible under random.seed();
    # os.listdir returns entries in arbitrary order.
    files = []
    if os.path.isdir(real_dir):
        files = sorted(
            f for f in os.listdir(real_dir) if f.lower().endswith(IMAGE_EXTENSIONS)
        )
    if len(files) < 2:
        raise RuntimeError('Not enough real images to generate morphs')

    for _ in range(count):
        a, b = random.sample(files, 2)
        arr_a = _load_rgb(os.path.join(real_dir, a))
        arr_b = _load_rgb(os.path.join(real_dir, b))
        # The float mean stays within [0, 255]; astype truncates toward zero.
        morph = ((arr_a + arr_b) / 2.0).astype(np.uint8)
        fname = os.path.join(morph_dir, f"{uuid.uuid4()}.jpg")
        Image.fromarray(morph).save(fname)
    return count


def fetch_lfw(count: int, split: str = 'train') -> int:
    """
    Save images from the LFW (Labeled Faces in the Wild) dataset.

    Loads LFW via scikit-learn as grayscale at resize=0.5, stacks each image
    into three identical RGB channels, resizes to IMAGE_SIZE and saves up to
    ``count`` of them into data/<split>/real/ as lfw_<idx>_<hex>.jpg.

    Args:
        count: maximum number of images to save; ``<= 0`` returns 0 immediately.
        split: dataset split sub-directory (e.g. 'train').

    Returns:
        The number of images saved (``min(count, number of LFW images)``).

    Raises:
        ImportError: if scikit-learn or Pillow is not installed.
    """
    if count <= 0:
        return 0
    if fetch_lfw_people is None:
        # scikit-learn not installed: report error
        raise ImportError('scikit-learn is required to fetch LFW dataset')
    _require_pil()

    # fetch grayscale images
    lfw = fetch_lfw_people(min_faces_per_person=1, resize=0.5, color=False)
    images = lfw.images[:count]
    save_dir = _real_dir(split)

    for idx, img_gray in enumerate(images):
        gray = _to_uint8(img_gray)
        # convert to RGB by stacking the single gray channel three times
        rgb = np.stack([gray] * 3, axis=-1)
        pil_img = Image.fromarray(rgb).resize(IMAGE_SIZE)
        fname = os.path.join(save_dir, f"lfw_{idx}_{uuid.uuid4().hex}.jpg")
        pil_img.save(fname)
    return len(images)


def fetch_pexels(count: int, split: str = 'train') -> int:
    """
    Fetch face images from the Pexels search API into data/<split>/real/.

    Requires environment variable PEXELS_API_KEY to be set. Pexels serves at
    most PEXELS_MAX_PER_PAGE results per page, so further pages are requested
    (following the response's ``next_page`` marker) until ``count`` images are
    saved or the results run out.

    Args:
        count: number of images wanted; ``<= 0`` returns 0 immediately.
        split: dataset split sub-directory (e.g. 'train').

    Returns:
        The number of images downloaded.

    Raises:
        ImportError: if ``requests`` is not installed.
        EnvironmentError: if PEXELS_API_KEY is not set.
        RuntimeError: if the Pexels API returns a non-200 status.
    """
    if count <= 0:
        return 0
    _require_requests('requests library is required to fetch Pexels images')
    api_key = os.getenv('PEXELS_API_KEY')
    if not api_key:
        raise EnvironmentError('PEXELS_API_KEY environment variable not set')
    save_dir = _real_dir(split)

    url = 'https://api.pexels.com/v1/search'
    headers = {'Authorization': api_key}
    per_page = min(count, PEXELS_MAX_PER_PAGE)
    downloaded = 0
    page = 1
    while downloaded < count:
        params = {'query': 'face', 'per_page': per_page, 'page': page}
        data = _get_json(url, 'Pexels', headers=headers, params=params)
        photos = data.get('photos', [])
        if not photos:
            break
        urls = [photo.get('src', {}).get('medium') for photo in photos]
        downloaded += _download_images(urls, save_dir, count - downloaded)
        if not data.get('next_page'):
            break
        page += 1
    return downloaded


def fetch_pixabay(count: int, split: str = 'train') -> int:
    """
    Fetch face images from the Pixabay API into data/<split>/real/.

    Requires environment variable PIXABAY_API_KEY to be set. Pixabay only
    accepts ``per_page`` in [PIXABAY_MIN_PER_PAGE, PIXABAY_MAX_PER_PAGE], so
    the page size is clamped into that range and further pages are requested
    until ``count`` images are saved or ``totalHits`` is exhausted.

    Args:
        count: number of images wanted; ``<= 0`` returns 0 immediately.
        split: dataset split sub-directory (e.g. 'train').

    Returns:
        The number of images downloaded.

    Raises:
        ImportError: if ``requests`` is not installed.
        EnvironmentError: if PIXABAY_API_KEY is not set.
        RuntimeError: if the Pixabay API returns a non-200 status.
    """
    if count <= 0:
        return 0
    _require_requests('requests library is required to fetch Pixabay images')
    api_key = os.getenv('PIXABAY_API_KEY')
    if not api_key:
        raise EnvironmentError('PIXABAY_API_KEY environment variable not set')
    save_dir = _real_dir(split)

    url = 'https://pixabay.com/api/'
    per_page = max(PIXABAY_MIN_PER_PAGE, min(count, PIXABAY_MAX_PER_PAGE))
    downloaded = 0
    page = 1
    while downloaded < count:
        params = {
            'key': api_key,
            'q': 'face',
            'image_type': 'photo',
            'per_page': per_page,
            'page': page,
        }
        data = _get_json(url, 'Pixabay', params=params)
        hits = data.get('hits', [])
        if not hits:
            break
        urls = [hit.get('webformatURL') for hit in hits]
        downloaded += _download_images(urls, save_dir, count - downloaded)
        # Stop once every accessible hit has been paged through, rather than
        # requesting a page past the end of the result set.
        if page * per_page >= data.get('totalHits', 0):
            break
        page += 1
    return downloaded


def fetch_utkface(count: int, split: str = 'train') -> int:
    """
    Fetch the UTKFace dataset (faces only) from an S3 archive.

    Streams the .tar.gz to a temporary file, extracts up to ``count`` JPEGs
    into data/<split>/real/, then deletes the temporary archive. No API key
    required.

    Args:
        count: maximum number of images to extract; ``<= 0`` returns 0
            immediately without downloading anything.
        split: dataset split sub-directory (e.g. 'train').

    Returns:
        The number of images extracted.

    Raises:
        ImportError: if ``requests`` is not installed.
        RuntimeError: if the download answers HTTP 301 (URL moved) or any
            other non-200 status.
    """
    if count <= 0:
        return 0
    _require_requests('requests library is required to fetch UTKFace')
    url = 'https://s3-us-west-1.amazonaws.com/utkface/UTKFace.tar.gz'
    save_dir = _real_dir(split)

    fd, tmp_path = tempfile.mkstemp(suffix='.tar.gz')
    try:
        with os.fdopen(fd, 'wb') as tmp:
            # requests follows redirects that carry a Location header itself;
            # a 301 still seen here was not followable (outdated S3 endpoint).
            with requests.get(url, stream=True, timeout=REQUEST_TIMEOUT) as resp:
                if resp.status_code == 301:
                    raise RuntimeError(
                        'UTKFace download URL has moved (HTTP 301). '
                        'Please download the UTKFace dataset manually from https://susanqq.github.io/UTKFace/ '
                        f'and extract {count} images into data/{split}/real/'
                    )
                if resp.status_code != 200:
                    raise RuntimeError(f'UTKFace download error {resp.status_code}')
                for chunk in resp.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        tmp.write(chunk)
        return _extract_jpegs(tmp_path, save_dir, count)
    finally:
        # The archive is only an intermediate; remove it even when an error occurred.
        os.remove(tmp_path)


def fetch_tpdne(count: int, split: str = 'train') -> int:
    """
    Fetch GAN-generated faces from thispersondoesnotexist.com into
    data/<split>/real/ as tpdne_<hex>.jpg.

    No API key required. Each of the ``count`` attempts issues one request
    with a 5-second timeout; non-200 responses are skipped, while network
    exceptions propagate.

    NOTE(review): these images are GAN-generated yet stored under real/
    alongside genuine photographs -- confirm this labelling is intended.

    Args:
        count: number of download attempts; ``<= 0`` returns 0 immediately.
        split: dataset split sub-directory (e.g. 'train').

    Returns:
        The number of images downloaded.

    Raises:
        ImportError: if ``requests`` is not installed.
    """
    if count <= 0:
        return 0
    _require_requests('requests library is required to fetch GAN images')
    save_dir = _real_dir(split)

    downloaded = 0
    for _ in range(count):
        resp = requests.get('https://thispersondoesnotexist.com/image', timeout=5)
        if resp.status_code != 200:
            continue
        fname = os.path.join(save_dir, f"tpdne_{uuid.uuid4().hex}.jpg")
        with open(fname, 'wb') as f:
            f.write(resp.content)
        downloaded += 1
    return downloaded


def fetch_celeba(count: int, split: str = 'train') -> int:
    """
    Copy a random sample of the CelebA dataset into data/<split>/real/.

    Requires the Kaggle CLI installed and KAGGLE_USERNAME/KAGGLE_KEY set.
    If data/raw/celeba/ holds no .jpg files yet (including after an
    interrupted earlier download), the full dataset is downloaded and
    unzipped there first. Then up to ``count`` images are copied, skipping
    files that already exist at the destination.

    Args:
        count: maximum number of images to sample; ``<= 0`` returns 0
            immediately without downloading anything.
        split: dataset split sub-directory (e.g. 'train').

    Returns:
        The number of images newly copied.

    Raises:
        subprocess.CalledProcessError: if the Kaggle CLI download fails.
        RuntimeError: if no CelebA images are present after downloading.
    """
    if count <= 0:
        return 0
    raw_dir = os.path.join(DATA_DIR, 'raw', 'celeba')
    save_dir = _real_dir(split)
    pattern = os.path.join(raw_dir, '**', '*.jpg')

    # Sorted so sampling is reproducible under random.seed().
    img_paths = sorted(glob.glob(pattern, recursive=True))
    if not img_paths:
        # Decide on the presence of images, not of the directory: a failed
        # download can leave an empty raw_dir behind.
        os.makedirs(raw_dir, exist_ok=True)
        cmd = [
            'kaggle', 'datasets', 'download',
            '-d', 'jessicali9530/celeba-dataset',
            '-p', raw_dir, '--unzip',
        ]
        subprocess.run(cmd, check=True)
        img_paths = sorted(glob.glob(pattern, recursive=True))
    if not img_paths:
        raise RuntimeError('No CelebA images found in raw directory')

    # Randomly sample up to count
    chosen = random.sample(img_paths, min(count, len(img_paths)))
    copied = 0
    for src in chosen:
        dst = os.path.join(save_dir, os.path.basename(src))
        if not os.path.exists(dst):
            shutil.copy2(src, dst)
            copied += 1
    return copied


def fetch_vggface2(count: int, split: str = 'train') -> int:
    """
    Stub for the VGGFace2 dataset: manual download required.

    Download from https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/ and extract
    into data/<split>/real/vggface2/. This function does not automate the
    download; ``count`` is accepted only for signature parity with the other
    fetchers.

    Raises:
        EnvironmentError: always, with manual-download instructions.
    """
    raise EnvironmentError(
        'VGGFace2 requires manual download: fetch from https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/ '
        'and extract images into data/{}/real/vggface2/'.format(split)
    )