# MorphGuard / scripts /data_setup.py
# juanquy's picture
# Initial clean commit of modular MorphGuard
# 2978bba
# Raw
# History Blame Contribute Delete
# 11.3 kB
"""
scripts/data_setup.py
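Utilities to fetch face images from several sources (Unsplash, Pexels, Pixabay, LFW,
UTKFace, thispersondoesnotexist.com, CelebA) and to generate morphed images by
pixel-wise averaging of image pairs.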
"""
import os
import uuid
import random
try:
import requests
except ImportError:
requests = None
from PIL import Image
import numpy as np
# Base data directory: "<current working directory>/data".
# The working directory is read once, when this module is imported.
DATA_DIR = os.path.join(os.getcwd(), 'data')
def fetch_real_faces(count: int, split: str = 'train') -> int:
    """
    Fetch random face images from Unsplash into data/<split>/real/.

    Requires environment variable UNSPLASH_ACCESS_KEY to be set.

    Args:
        count: Number of random photos requested from the API.
        split: Dataset split whose `real` sub-directory receives the files.

    Returns:
        The number of images downloaded and written to disk. This can be lower
        than `count` when an item has no usable URL or its download fails.

    Raises:
        ImportError: If the requests library is not installed.
        EnvironmentError: If UNSPLASH_ACCESS_KEY is not set.
        RuntimeError: If the Unsplash API answers with a non-200 status.
    """
    if requests is None:
        raise ImportError('requests library is required to fetch images')
    access_key = os.getenv('UNSPLASH_ACCESS_KEY')
    if not access_key:
        raise EnvironmentError('UNSPLASH_ACCESS_KEY environment variable not set')
    save_dir = os.path.join(DATA_DIR, split, 'real')
    os.makedirs(save_dir, exist_ok=True)

    # Without a timeout a stalled connection would block this call forever.
    request_timeout = 30  # seconds

    # Unsplash random photo endpoint
    url = 'https://api.unsplash.com/photos/random'
    params = {
        'client_id': access_key,
        'query': 'face portrait',
        'count': count,
    }
    resp = requests.get(url, params=params, timeout=request_timeout)
    if resp.status_code != 200:
        raise RuntimeError(f'Unsplash API error {resp.status_code}: {resp.text}')

    downloaded = 0
    for item in resp.json():
        img_url = item.get('urls', {}).get('small')
        if not img_url:
            continue
        img_resp = requests.get(img_url, timeout=request_timeout)
        if img_resp.status_code != 200:
            # Do not store an HTTP error body on disk as if it were a JPEG.
            continue
        fname = os.path.join(save_dir, f"{uuid.uuid4()}.jpg")
        with open(fname, 'wb') as f:
            f.write(img_resp.content)
        downloaded += 1
    return downloaded
def generate_morphs(count: int, split: str = 'train') -> int:
    """
    Build morphed faces by averaging randomly chosen pairs of real images.

    Reads image files (.png, .jpg, .jpeg) from data/<split>/real/, averages two
    distinct files pixel by pixel after converting both to 224x224 RGB, and
    writes each result as a new .jpg into data/<split>/morph/.

    Returns the number of morphs written.
    Raises RuntimeError when fewer than two real images are available.
    """
    source_dir = os.path.join(DATA_DIR, split, 'real')
    target_dir = os.path.join(DATA_DIR, split, 'morph')
    os.makedirs(target_dir, exist_ok=True)

    # Gather the candidate image files by extension (case-insensitive).
    image_suffixes = ('.png', '.jpg', '.jpeg')
    candidates = []
    for entry in os.listdir(source_dir):
        if entry.lower().endswith(image_suffixes):
            candidates.append(entry)
    if not len(candidates) >= 2:
        raise RuntimeError('Not enough real images to generate morphs')

    def _load_pixels(name):
        # Open one real image and return it as a 224x224 RGB float32 array.
        picture = Image.open(os.path.join(source_dir, name))
        picture = picture.convert('RGB').resize((224, 224))
        return np.array(picture).astype(np.float32)

    produced = 0
    for _ in range(count):
        first, second = random.sample(candidates, 2)
        pixels_first = _load_pixels(first)
        pixels_second = _load_pixels(second)
        # Average in float32, then cast back to 8-bit for saving.
        blended = ((pixels_first + pixels_second) / 2.0).astype(np.uint8)
        out_path = os.path.join(target_dir, f"{uuid.uuid4()}.jpg")
        Image.fromarray(blended).save(out_path)
        produced += 1
    return produced
try:
from sklearn.datasets import fetch_lfw_people
except ImportError:
fetch_lfw_people = None
def fetch_lfw(count: int, split: str = 'train') -> int:
    """
    Save images from the LFW (Labeled Faces in the Wild) dataset.

    Loads the dataset through scikit-learn's `fetch_lfw_people` (grayscale,
    resize factor 0.5), expands each image to three identical channels, resizes
    it to 224x224 and writes it as a .jpg into data/<split>/real/.

    Args:
        count: Maximum number of images to save.
        split: Dataset split whose `real` sub-directory receives the files.

    Returns:
        The number of images saved (never negative).

    Raises:
        ImportError: If scikit-learn is not installed.
    """
    if fetch_lfw_people is None:
        # scikit-learn not installed: report error
        raise ImportError('scikit-learn is required to fetch LFW dataset')
    # fetch grayscale images
    lfw = fetch_lfw_people(min_faces_per_person=1, resize=0.5, color=False)
    images = lfw.images
    # Clamp at zero so a negative `count` reports 0 instead of a negative total.
    num = max(0, min(count, len(images)))
    save_dir = os.path.join(DATA_DIR, split, 'real')
    os.makedirs(save_dir, exist_ok=True)
    if num == 0:
        return 0

    batch = np.asarray(images[:num], dtype=np.float32)
    # NOTE(review): the pixel range returned by fetch_lfw_people appears to
    # differ between scikit-learn releases ([0, 1] vs [0, 255]) - confirm.
    # Scale only when the data looks normalised, so values already in 0-255
    # are not multiplied again and wrapped around by the uint8 cast.
    scale = 255.0 if float(batch.max()) <= 1.0 else 1.0
    for idx, img_gray in enumerate(batch):
        pixels = np.clip(img_gray * scale, 0, 255).astype(np.uint8)
        # convert to RGB by stacking the single channel three times
        arr3 = np.stack([pixels] * 3, axis=-1)
        pil_img = Image.fromarray(arr3).resize((224, 224))
        fname = os.path.join(save_dir, f"lfw_{idx}_{uuid.uuid4().hex}.jpg")
        pil_img.save(fname)
    return num
def fetch_pexels(count: int, split: str = 'train') -> int:
    """
    Fetch face images from the Pexels search API into data/<split>/real/.

    Requires environment variable PEXELS_API_KEY to be set.

    Args:
        count: Number of results requested from the search endpoint
            (sent as the `per_page` parameter).
        split: Dataset split whose `real` sub-directory receives the files.

    Returns:
        The number of images downloaded and written to disk. This can be lower
        than `count` when a photo has no usable URL or its download fails.

    Raises:
        ImportError: If the requests library is not installed.
        EnvironmentError: If PEXELS_API_KEY is not set.
        RuntimeError: If the search request answers with a non-200 status.
    """
    try:
        import requests
    except ImportError as err:
        raise ImportError('requests library is required to fetch Pexels images') from err
    api_key = os.getenv('PEXELS_API_KEY')
    if not api_key:
        raise EnvironmentError('PEXELS_API_KEY environment variable not set')
    save_dir = os.path.join(DATA_DIR, split, 'real')
    os.makedirs(save_dir, exist_ok=True)

    # Without a timeout a stalled connection would block this call forever.
    request_timeout = 30  # seconds

    url = 'https://api.pexels.com/v1/search'
    headers = {'Authorization': api_key}
    params = {'query': 'face', 'per_page': count}
    resp = requests.get(url, headers=headers, params=params, timeout=request_timeout)
    if resp.status_code != 200:
        raise RuntimeError(f'Pexels API error {resp.status_code}: {resp.text}')

    downloaded = 0
    for photo in resp.json().get('photos', []):
        img_url = photo.get('src', {}).get('medium')
        if not img_url:
            continue
        img_resp = requests.get(img_url, timeout=request_timeout)
        if img_resp.status_code != 200:
            # Do not store an HTTP error body on disk as if it were a JPEG.
            continue
        fname = os.path.join(save_dir, f"{uuid.uuid4()}.jpg")
        with open(fname, 'wb') as f:
            f.write(img_resp.content)
        downloaded += 1
    return downloaded
def fetch_pixabay(count: int, split: str = 'train') -> int:
    """
    Fetch face photos from the Pixabay search API into data/<split>/real/.

    Requires environment variable PIXABAY_API_KEY to be set.

    Args:
        count: Maximum number of images to download. A single request is made,
            so at most 200 images (the API's page-size limit) are returned.
        split: Dataset split whose `real` sub-directory receives the files.

    Returns:
        The number of images downloaded and written to disk. This can be lower
        than `count` when a hit has no usable URL or its download fails.

    Raises:
        ImportError: If the requests library is not installed.
        EnvironmentError: If PIXABAY_API_KEY is not set.
        RuntimeError: If the search request answers with a non-200 status.
    """
    try:
        import requests
    except ImportError as err:
        raise ImportError('requests library is required to fetch Pixabay images') from err
    api_key = os.getenv('PIXABAY_API_KEY')
    if not api_key:
        raise EnvironmentError('PIXABAY_API_KEY environment variable not set')
    save_dir = os.path.join(DATA_DIR, split, 'real')
    os.makedirs(save_dir, exist_ok=True)
    if count <= 0:
        # Nothing requested: avoid spending an API call.
        return 0

    # Without a timeout a stalled connection would block this call forever.
    request_timeout = 30  # seconds

    # Pixabay documents per_page as accepting 3-200; sending the raw `count`
    # makes small or large requests fail. Clamp the page size and trim the
    # result list to `count` below.
    per_page = max(3, min(count, 200))

    url = 'https://pixabay.com/api/'
    params = {'key': api_key, 'q': 'face', 'image_type': 'photo', 'per_page': per_page}
    resp = requests.get(url, params=params, timeout=request_timeout)
    if resp.status_code != 200:
        raise RuntimeError(f'Pixabay API error {resp.status_code}: {resp.text}')

    downloaded = 0
    for hit in resp.json().get('hits', [])[:count]:
        img_url = hit.get('webformatURL')
        if not img_url:
            continue
        img_resp = requests.get(img_url, timeout=request_timeout)
        if img_resp.status_code != 200:
            # Do not store an HTTP error body on disk as if it were a JPEG.
            continue
        fname = os.path.join(save_dir, f"{uuid.uuid4()}.jpg")
        with open(fname, 'wb') as f:
            f.write(img_resp.content)
        downloaded += 1
    return downloaded
def fetch_utkface(count: int, split: str = 'train') -> int:
    """
    Fetch the UTKFace dataset (faces only) from an S3 archive.

    Downloads the .tar.gz archive to a temporary file, extracts up to `count`
    JPEG members into data/<split>/real/ (flattened to their base file names),
    and removes the temporary file afterwards. No API key required.

    Args:
        count: Maximum number of JPEGs to extract.
        split: Dataset split whose `real` sub-directory receives the files.

    Returns:
        The number of images extracted.

    Raises:
        ImportError: If the requests library is not installed.
        RuntimeError: If the download answers with HTTP 301 or any other
            non-200 status.
    """
    import contextlib
    import tarfile
    import tempfile

    import requests

    url = 'https://s3-us-west-1.amazonaws.com/utkface/UTKFace.tar.gz'
    save_dir = os.path.join(DATA_DIR, split, 'real')
    os.makedirs(save_dir, exist_ok=True)
    if count <= 0:
        # Nothing requested: skip downloading the archive altogether.
        return 0

    tmp_path = None
    try:
        # Stream the archive to disk. The timeout keeps a stalled connection
        # from hanging forever, and closing() releases the connection.
        with contextlib.closing(requests.get(url, stream=True, timeout=30)) as resp:
            # A 301 surfacing here was not followed as a redirect; point the
            # user at the manual download instead.
            if resp.status_code == 301:
                raise RuntimeError(
                    'UTKFace download URL has moved (HTTP 301). '
                    'Please download the UTKFace dataset manually from https://susanqq.github.io/UTKFace/ '
                    f'and extract {count} images into data/{split}/real/'
                )
            if resp.status_code != 200:
                raise RuntimeError(f'UTKFace download error {resp.status_code}')
            with tempfile.NamedTemporaryFile(delete=False, suffix='.tar.gz') as tmp:
                tmp_path = tmp.name
                for chunk in resp.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        tmp.write(chunk)

        # Extract JPEGs. Iterating the TarFile reads members lazily, so the
        # loop can stop early without scanning the whole archive first.
        extracted = 0
        with tarfile.open(tmp_path, 'r:gz') as tar:
            for member in tar:
                if extracted >= count:
                    break
                if not (member.isfile() and member.name.lower().endswith('.jpg')):
                    continue
                src = tar.extractfile(member)
                if not src:
                    continue
                # basename() drops any directory part stored in the archive.
                outpath = os.path.join(save_dir, os.path.basename(member.name))
                with src, open(outpath, 'wb') as out:
                    out.write(src.read())
                extracted += 1
        return extracted
    finally:
        # The temp file is created with delete=False, so remove it explicitly,
        # including when the download or extraction fails.
        if tmp_path is not None:
            with contextlib.suppress(OSError):
                os.remove(tmp_path)
def fetch_tpdne(count: int, split: str = 'train') -> int:
    """
    Download GAN-generated faces from thispersondoesnotexist.com.

    Makes `count` requests and stores every successful (HTTP 200) response as
    a .jpg in data/<split>/real/. No API key required.
    Returns how many images were written.
    """
    try:
        import requests
    except ImportError:
        raise ImportError('requests library is required to fetch GAN images')

    target_dir = os.path.join(DATA_DIR, split, 'real')
    os.makedirs(target_dir, exist_ok=True)

    saved = 0
    for _ in range(count):
        response = requests.get('https://thispersondoesnotexist.com/image', timeout=5)
        if response.status_code == 200:
            # Only successful responses are stored; others are skipped.
            out_path = os.path.join(target_dir, f"tpdne_{uuid.uuid4().hex}.jpg")
            with open(out_path, 'wb') as handle:
                handle.write(response.content)
            saved += 1
    return saved
def fetch_celeba(count: int, split: str = 'train') -> int:
    """
    Fetch a sample of the CelebA dataset via Kaggle CLI.

    Requires Kaggle CLI installed and KAGGLE_USERNAME/KAGGLE_KEY set.
    Downloads and unzips full CelebA into data/raw/celeba/ when no images are
    present there yet, then copies up to `count` randomly chosen images into
    data/<split>/real/.

    Args:
        count: Maximum number of images to copy.
        split: Dataset split whose `real` sub-directory receives the files.

    Returns:
        The number of files newly copied. Images whose file name already
        exists in the destination are skipped and not counted.

    Raises:
        subprocess.CalledProcessError: If the Kaggle CLI exits with an error.
        RuntimeError: If no .jpg files are found after the download step.
    """
    import glob
    import random
    import shutil
    import subprocess

    raw_dir = os.path.join(DATA_DIR, 'raw', 'celeba')
    save_dir = os.path.join(DATA_DIR, split, 'real')
    os.makedirs(save_dir, exist_ok=True)

    pattern = os.path.join(raw_dir, '**', '*.jpg')
    img_paths = glob.glob(pattern, recursive=True)
    if not img_paths:
        # Decide on the presence of images rather than of the directory: a
        # failed earlier download leaves an empty raw directory behind, which
        # must not block a retry.
        os.makedirs(raw_dir, exist_ok=True)
        cmd = [
            'kaggle', 'datasets', 'download', '-d', 'jessicali9530/celeba-dataset',
            '-p', raw_dir, '--unzip'
        ]
        subprocess.run(cmd, check=True)
        img_paths = glob.glob(pattern, recursive=True)
    if not img_paths:
        raise RuntimeError('No CelebA images found in raw directory')

    # Randomly sample up to count; clamp at zero because random.sample
    # rejects a negative sample size.
    chosen = random.sample(img_paths, max(0, min(count, len(img_paths))))
    copied = 0
    for src in chosen:
        dst = os.path.join(save_dir, os.path.basename(src))
        if not os.path.exists(dst):
            shutil.copy2(src, dst)
            copied += 1
    return copied
def fetch_vggface2(count: int, split: str = 'train') -> int:
    """
    Placeholder for the VGGFace2 dataset, which this module does not download.

    The data has to be obtained manually from
    https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/ and extracted into
    data/<split>/real/vggface2/. Calling this function always raises
    EnvironmentError carrying those instructions; `count` is not used.
    """
    instructions = (
        'VGGFace2 requires manual download: fetch from https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/ '
        f'and extract images into data/{split}/real/vggface2/'
    )
    raise EnvironmentError(instructions)