Spaces:
Running
Running
| """ | |
| scripts/data_setup.py | |
| Utilities to fetch real face images from Unsplash and generate morphed images by averaging. | |
| """ | |
| import os | |
| import uuid | |
| import random | |
| try: | |
| import requests | |
| except ImportError: | |
| requests = None | |
| from PIL import Image | |
| import numpy as np | |
# Base data directory. Resolved once, at import time, relative to the current
# working directory; every helper below reads and writes <DATA_DIR>/<split>/...
DATA_DIR = os.path.join(os.getcwd(), 'data')
def fetch_real_faces(count: int, split: str = 'train') -> int:
    """
    Fetch random face images from Unsplash into data/<split>/real/.

    Requires the environment variable UNSPLASH_ACCESS_KEY to be set.

    Args:
        count: Number of photos to request. Values <= 0 are a no-op.
        split: Dataset split sub-directory (e.g. 'train', 'val').

    Returns:
        The number of images downloaded and written to disk.

    Raises:
        ImportError: if the ``requests`` library is not installed.
        EnvironmentError: if UNSPLASH_ACCESS_KEY is not set.
        RuntimeError: if the Unsplash API answers with a non-200 status.
    """
    if count <= 0:
        return 0
    if requests is None:
        raise ImportError('requests library is required to fetch images')
    access_key = os.getenv('UNSPLASH_ACCESS_KEY')
    if not access_key:
        raise EnvironmentError('UNSPLASH_ACCESS_KEY environment variable not set')
    save_dir = os.path.join(DATA_DIR, split, 'real')
    os.makedirs(save_dir, exist_ok=True)
    # Unsplash random photo endpoint. It returns at most 30 photos per call,
    # so larger requests are split into batches.
    url = 'https://api.unsplash.com/photos/random'
    max_per_request = 30
    timeout = 30  # seconds; without a timeout a stalled connection blocks forever
    downloaded = 0
    remaining = count
    while remaining > 0:
        batch_size = min(remaining, max_per_request)
        params = {
            'client_id': access_key,
            'query': 'face portrait',
            'count': batch_size,
        }
        resp = requests.get(url, params=params, timeout=timeout)
        if resp.status_code != 200:
            raise RuntimeError(f'Unsplash API error {resp.status_code}: {resp.text}')
        items = resp.json()
        if not items:
            break
        for item in items[:batch_size]:
            img_url = item.get('urls', {}).get('small')
            if not img_url:
                continue
            img_resp = requests.get(img_url, timeout=timeout)
            # Skip failed downloads instead of saving an error body as a .jpg.
            if img_resp.status_code != 200:
                continue
            fname = os.path.join(save_dir, f"{uuid.uuid4()}.jpg")
            with open(fname, 'wb') as f:
                f.write(img_resp.content)
            downloaded += 1
        remaining -= batch_size
    return downloaded
def generate_morphs(count: int, split: str = 'train', size: int = 224) -> int:
    """
    Generate morphed images by averaging random pairs from data/<split>/real/,
    saving them to data/<split>/morph/.

    Each morph is the pixel-wise mean of two distinct real images, both
    converted to RGB and resized to ``size`` x ``size``; the mean is
    truncated to uint8 before saving as JPEG.

    Args:
        count: Number of morphs to generate.
        split: Dataset split sub-directory (e.g. 'train', 'val').
        size: Side length in pixels of the square output (defaults to 224).

    Returns:
        The number of morph images written.

    Raises:
        RuntimeError: if fewer than two real images are available.
    """
    real_dir = os.path.join(DATA_DIR, split, 'real')
    morph_dir = os.path.join(DATA_DIR, split, 'morph')
    os.makedirs(morph_dir, exist_ok=True)
    # Collect real image files. Sorted so that a seeded `random` yields the
    # same pairs regardless of the arbitrary order os.listdir returns.
    files = sorted(
        f for f in os.listdir(real_dir)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    if len(files) < 2:
        raise RuntimeError('Not enough real images to generate morphs')

    def _load(name: str) -> np.ndarray:
        # Load one real image as a float32 RGB array; the context manager
        # closes the file handle that Image.open keeps open.
        with Image.open(os.path.join(real_dir, name)) as img:
            rgb = img.convert('RGB').resize((size, size))
        return np.asarray(rgb, dtype=np.float32)

    generated = 0
    for _ in range(count):
        a, b = random.sample(files, 2)
        arr = ((_load(a) + _load(b)) / 2.0).astype(np.uint8)
        fname = os.path.join(morph_dir, f"{uuid.uuid4()}.jpg")
        Image.fromarray(arr).save(fname)
        generated += 1
    return generated
# scikit-learn is an optional dependency used only by fetch_lfw(). Bind the
# name to None first; it is replaced by the real loader when the import works,
# and fetch_lfw() raises a descriptive ImportError when it stays None.
fetch_lfw_people = None
try:
    from sklearn.datasets import fetch_lfw_people
except ImportError:
    pass
def fetch_lfw(count: int, split: str = 'train') -> int:
    """
    Download images from the LFW (Labeled Faces in the Wild) dataset.

    Faces are loaded in grayscale through scikit-learn, replicated to three
    channels and saved as 224x224 RGB JPEGs into data/<split>/real/. The
    first ``count`` faces of the loaded set are used.

    Args:
        count: Maximum number of images to save. Values <= 0 are a no-op.
        split: Dataset split sub-directory (e.g. 'train', 'val').

    Returns:
        The number of images saved (never negative).

    Raises:
        ImportError: if scikit-learn is not installed.
    """
    if count <= 0:
        # Nothing to save: skip the dataset download and never report a
        # negative number of saved images.
        return 0
    if fetch_lfw_people is None:
        # scikit-learn not installed: report error
        raise ImportError('scikit-learn is required to fetch LFW dataset')
    # fetch grayscale images
    lfw = fetch_lfw_people(min_faces_per_person=1, resize=0.5, color=False)
    images = lfw.images
    num = min(count, len(images))
    save_dir = os.path.join(DATA_DIR, split, 'real')
    os.makedirs(save_dir, exist_ok=True)
    if num == 0:
        return 0
    # Pixels are expected as floats in [0, 1] and scaled up to 0-255.
    # NOTE(review): the value range returned by fetch_lfw_people may differ
    # between scikit-learn versions - confirm. The scale is decided once for
    # the selected images so values already on a 0-255 scale are not
    # multiplied again (which would overflow uint8 and wrap around).
    scale = 255.0 if float(np.max(images[:num])) <= 1.0 else 1.0
    for idx in range(num):
        gray = np.asarray(images[idx], dtype=np.float32) * scale
        # Round and clip rather than truncating / wrapping on the uint8 cast.
        gray_u8 = np.clip(np.rint(gray), 0, 255).astype(np.uint8)
        # convert to RGB by stacking the single channel three times
        arr3 = np.stack([gray_u8] * 3, axis=-1)
        pil_img = Image.fromarray(arr3).resize((224, 224))
        fname = os.path.join(save_dir, f"lfw_{idx}_{uuid.uuid4().hex}.jpg")
        pil_img.save(fname)
    return num
def fetch_pexels(count: int, split: str = 'train') -> int:
    """
    Fetch face images from the Pexels search API into data/<split>/real/.

    Requires the environment variable PEXELS_API_KEY to be set. Results are
    requested page by page, because Pexels returns at most 80 photos per page.

    Args:
        count: Number of search results to consume. Values <= 0 are a no-op.
        split: Dataset split sub-directory (e.g. 'train', 'val').

    Returns:
        The number of images downloaded. This can be less than ``count`` when
        results run out, lack a URL, or fail to download.

    Raises:
        ImportError: if the ``requests`` library is not installed.
        EnvironmentError: if PEXELS_API_KEY is not set.
        RuntimeError: if the Pexels API answers with a non-200 status.
    """
    if count <= 0:
        return 0
    try:
        import requests
    except ImportError as err:
        raise ImportError('requests library is required to fetch Pexels images') from err
    api_key = os.getenv('PEXELS_API_KEY')
    if not api_key:
        raise EnvironmentError('PEXELS_API_KEY environment variable not set')
    save_dir = os.path.join(DATA_DIR, split, 'real')
    os.makedirs(save_dir, exist_ok=True)
    url = 'https://api.pexels.com/v1/search'
    headers = {'Authorization': api_key}
    timeout = 30  # seconds; avoids hanging forever on a stalled connection
    # Keep the page size fixed across pages so page offsets do not overlap.
    per_page = min(count, 80)
    downloaded = 0
    seen = 0  # results consumed so far, downloaded or skipped
    page = 1
    while seen < count:
        params = {'query': 'face', 'per_page': per_page, 'page': page}
        resp = requests.get(url, headers=headers, params=params, timeout=timeout)
        if resp.status_code != 200:
            raise RuntimeError(f'Pexels API error {resp.status_code}: {resp.text}')
        data = resp.json()
        photos = data.get('photos', [])
        batch = photos[:count - seen]
        for photo in batch:
            img_url = photo.get('src', {}).get('medium')
            if not img_url:
                continue
            img_resp = requests.get(img_url, timeout=timeout)
            # Skip failed downloads instead of saving an error body as a .jpg.
            if img_resp.status_code != 200:
                continue
            fname = os.path.join(save_dir, f"{uuid.uuid4()}.jpg")
            with open(fname, 'wb') as f:
                f.write(img_resp.content)
            downloaded += 1
        seen += len(batch)
        # A short page or a missing next_page link means results are exhausted.
        if len(photos) < per_page or not data.get('next_page'):
            break
        page += 1
    return downloaded
def fetch_pixabay(count: int, split: str = 'train') -> int:
    """
    Fetch face photos from the Pixabay API into data/<split>/real/.

    Requires the environment variable PIXABAY_API_KEY to be set. Pixabay only
    accepts ``per_page`` values from 3 to 200, so the page size is clamped to
    that range and results are requested page by page.

    Args:
        count: Number of search hits to consume. Values <= 0 are a no-op.
        split: Dataset split sub-directory (e.g. 'train', 'val').

    Returns:
        The number of images downloaded. This can be less than ``count`` when
        hits run out, lack a URL, or fail to download.

    Raises:
        ImportError: if the ``requests`` library is not installed.
        EnvironmentError: if PIXABAY_API_KEY is not set.
        RuntimeError: if the Pixabay API answers with a non-200 status.
    """
    if count <= 0:
        return 0
    try:
        import requests
    except ImportError as err:
        raise ImportError('requests library is required to fetch Pixabay images') from err
    api_key = os.getenv('PIXABAY_API_KEY')
    if not api_key:
        raise EnvironmentError('PIXABAY_API_KEY environment variable not set')
    save_dir = os.path.join(DATA_DIR, split, 'real')
    os.makedirs(save_dir, exist_ok=True)
    url = 'https://pixabay.com/api/'
    timeout = 30  # seconds; avoids hanging forever on a stalled connection
    # Clamp to Pixabay's accepted per_page range, and keep it fixed across
    # pages so page offsets line up.
    per_page = min(max(count, 3), 200)
    downloaded = 0
    seen = 0  # hits consumed so far, downloaded or skipped
    page = 1
    while seen < count:
        params = {
            'key': api_key,
            'q': 'face',
            'image_type': 'photo',
            'per_page': per_page,
            'page': page,
        }
        resp = requests.get(url, params=params, timeout=timeout)
        if resp.status_code != 200:
            raise RuntimeError(f'Pixabay API error {resp.status_code}: {resp.text}')
        data = resp.json()
        hits = data.get('hits', [])
        batch = hits[:count - seen]
        for hit in batch:
            img_url = hit.get('webformatURL')
            if not img_url:
                continue
            img_resp = requests.get(img_url, timeout=timeout)
            # Skip failed downloads instead of saving an error body as a .jpg.
            if img_resp.status_code != 200:
                continue
            fname = os.path.join(save_dir, f"{uuid.uuid4()}.jpg")
            with open(fname, 'wb') as f:
                f.write(img_resp.content)
            downloaded += 1
        seen += len(batch)
        # Stop on a short page, or once every hit the API exposes (totalHits)
        # has been paged through.
        if len(hits) < per_page or page * per_page >= data.get('totalHits', 0):
            break
        page += 1
    return downloaded
def fetch_utkface(count: int, split: str = 'train') -> int:
    """
    Fetch the UTKFace dataset (faces only) from an S3 archive.

    Downloads the archive to a temporary file and extracts up to ``count``
    JPEGs into data/<split>/real/. The temporary archive is always deleted
    afterwards. No API key required.

    Args:
        count: Maximum number of JPEGs to extract. Values <= 0 are a no-op.
        split: Dataset split sub-directory (e.g. 'train', 'val').

    Returns:
        The number of images extracted.

    Raises:
        ImportError: if the ``requests`` library is not installed.
        RuntimeError: if the download answers with HTTP 301 (outdated URL)
            or any other non-200 status.
    """
    if count <= 0:
        return 0
    import shutil
    import tarfile
    import tempfile
    try:
        import requests
    except ImportError as err:
        raise ImportError('requests library is required to fetch UTKFace') from err
    url = 'https://s3-us-west-1.amazonaws.com/utkface/UTKFace.tar.gz'
    save_dir = os.path.join(DATA_DIR, split, 'real')
    os.makedirs(save_dir, exist_ok=True)
    tmp_path = None
    try:
        # requests follows redirects that carry a Location header itself, so a
        # 301 only reaches this code when it could not be followed.
        with requests.get(url, stream=True, timeout=60) as resp:
            if resp.status_code == 301:
                raise RuntimeError(
                    'UTKFace download URL has moved (HTTP 301). '
                    'Please download the UTKFace dataset manually from https://susanqq.github.io/UTKFace/ '
                    f'and extract {count} images into data/{split}/real/'
                )
            if resp.status_code != 200:
                raise RuntimeError(f'UTKFace download error {resp.status_code}')
            with tempfile.NamedTemporaryFile(delete=False, suffix='.tar.gz') as tmp:
                tmp_path = tmp.name
                for chunk in resp.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        tmp.write(chunk)
        # Extract JPEGs. Iterating the TarFile reads members lazily, so
        # decompression stops once enough images are out; getmembers() would
        # scan the whole archive first.
        extracted = 0
        with tarfile.open(tmp_path, 'r:gz') as tar:
            for member in tar:
                if extracted >= count:
                    break
                if not (member.isfile() and member.name.lower().endswith('.jpg')):
                    continue
                src = tar.extractfile(member)
                if src is None:
                    continue
                # basename() keeps archive paths from escaping save_dir.
                outpath = os.path.join(save_dir, os.path.basename(member.name))
                with src, open(outpath, 'wb') as out:
                    shutil.copyfileobj(src, out)
                extracted += 1
        return extracted
    finally:
        # The archive is created with delete=False, so remove it explicitly.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.remove(tmp_path)
def fetch_tpdne(count: int, split: str = 'train') -> int:
    """
    Fetch GAN-generated faces from thispersondoesnotexist.com
    into data/<split>/real/. No API key required.

    Downloading is best-effort: any request that errors out, returns a non-200
    status, or returns a text body is skipped rather than aborting the batch.
    NOTE(review): these synthetic faces are stored alongside the real images;
    confirm that labelling them "real" is intended.

    Args:
        count: Number of download attempts. Values <= 0 are a no-op.
        split: Dataset split sub-directory (e.g. 'train', 'val').

    Returns:
        The number of images downloaded.

    Raises:
        ImportError: if the ``requests`` library is not installed.
    """
    if count <= 0:
        return 0
    try:
        import requests
    except ImportError as err:
        raise ImportError('requests library is required to fetch GAN images') from err
    save_dir = os.path.join(DATA_DIR, split, 'real')
    os.makedirs(save_dir, exist_ok=True)
    downloaded = 0
    for _ in range(count):
        try:
            resp = requests.get('https://thispersondoesnotexist.com/image', timeout=5)
        except requests.RequestException:
            # Consistent with the non-200 skip below: one failed request
            # should not discard the rest of the batch.
            continue
        if resp.status_code != 200:
            continue
        # Do not save an HTML/text page under a .jpg name.
        if resp.headers.get('Content-Type', '').startswith('text/'):
            continue
        fname = os.path.join(save_dir, f"tpdne_{uuid.uuid4().hex}.jpg")
        with open(fname, 'wb') as f:
            f.write(resp.content)
        downloaded += 1
    return downloaded
def fetch_celeba(count: int, split: str = 'train') -> int:
    """
    Fetch a sample of the CelebA dataset via the Kaggle CLI.

    Requires the Kaggle CLI to be installed and KAGGLE_USERNAME/KAGGLE_KEY
    to be set. The full CelebA dataset is downloaded and unzipped into
    data/raw/celeba/ only if no images are present there yet. A random sample
    of up to ``count`` images is then copied into data/<split>/real/.

    Args:
        count: Maximum number of images to sample. Values <= 0 are a no-op.
        split: Dataset split sub-directory (e.g. 'train', 'val').

    Returns:
        The number of images newly copied. Sampled images whose file name
        already exists in the destination are skipped and not counted.

    Raises:
        FileNotFoundError: if the ``kaggle`` executable cannot be found.
        subprocess.CalledProcessError: if the Kaggle download fails.
        RuntimeError: if no images are found after downloading.
    """
    if count <= 0:
        # Nothing to copy: avoid the large download, and avoid the ValueError
        # random.sample raises for a negative sample size.
        return 0
    import glob
    import random
    import shutil
    import subprocess

    raw_dir = os.path.join(DATA_DIR, 'raw', 'celeba')
    save_dir = os.path.join(DATA_DIR, split, 'real')
    os.makedirs(save_dir, exist_ok=True)
    pattern = os.path.join(raw_dir, '**', '*.jpg')
    img_paths = glob.glob(pattern, recursive=True)
    # Download when no images are present, rather than when the directory is
    # missing. A failed earlier download leaves raw_dir behind, and it must
    # still be retried.
    if not img_paths:
        os.makedirs(raw_dir, exist_ok=True)
        cmd = [
            'kaggle', 'datasets', 'download', '-d', 'jessicali9530/celeba-dataset',
            '-p', raw_dir, '--unzip'
        ]
        try:
            subprocess.run(cmd, check=True)
        except FileNotFoundError as err:
            raise FileNotFoundError(
                "'kaggle' executable not found; install the Kaggle CLI "
                "(pip install kaggle) and make sure it is on PATH"
            ) from err
        img_paths = glob.glob(pattern, recursive=True)
    if not img_paths:
        raise RuntimeError('No CelebA images found in raw directory')
    # Sorted so a seeded `random` picks the same files regardless of glob order.
    chosen = random.sample(sorted(img_paths), min(count, len(img_paths)))
    copied = 0
    for src in chosen:
        dst = os.path.join(save_dir, os.path.basename(src))
        if not os.path.exists(dst):
            shutil.copy2(src, dst)
            copied += 1
    return copied
def fetch_vggface2(count: int, split: str = 'train') -> int:
    """
    Placeholder for the VGGFace2 dataset, which cannot be fetched automatically.

    The dataset must be downloaded by hand from
    https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/ and extracted into
    data/<split>/real/vggface2/. ``count`` mirrors the (count, split)
    signature of the other fetch_* functions and is not used.

    NOTE(review): generate_morphs() in this module only lists files directly
    inside data/<split>/real/, so images placed in the vggface2/ sub-folder
    are not picked up there; confirm the intended layout.

    Raises:
        EnvironmentError: always, with instructions for the manual download.
    """
    target_dir = f'data/{split}/real/vggface2/'
    message = (
        'VGGFace2 requires manual download: fetch from '
        'https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/ '
        f'and extract images into {target_dir}'
    )
    raise EnvironmentError(message)