|
|
""" |
|
|
Train a TensorFlow regression model to predict age from face images (UTKFace dataset). |
|
|
|
|
|
Usage: |
|
|
- Put UTKFace images into a folder, e.g. data/UTKFace/ |
|
|
- python train.py --dataset_dir data/UTKFace --epochs 30 --batch_size 32 |
|
|
|
|
|
The script extracts the age from the filename (before the first underscore). |
|
|
""" |
|
|
|
|
|
import os |
|
|
import argparse |
|
|
import random |
|
|
import math |
|
|
import zipfile |
|
|
from pathlib import Path |
|
|
|
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
import requests |
|
|
|
|
|
import tensorflow as tf |
|
|
from tensorflow import keras |
|
|
|
|
|
|
|
|
def parse_args(): |
|
|
parser = argparse.ArgumentParser(description="Train an age regression model on UTKFace images") |
|
|
parser.add_argument("--dataset_dir", type=str, default="data/UTKFace", help="Path to folder containing UTKFace images") |
|
|
parser.add_argument("--img_size", type=int, default=224, help="Image size (square)") |
|
|
parser.add_argument("--batch_size", type=int, default=32) |
|
|
parser.add_argument("--epochs", type=int, default=30) |
|
|
parser.add_argument("--val_split", type=float, default=0.12, help="Fraction to reserve for validation") |
|
|
parser.add_argument("--learning_rate", type=float, default=1e-4) |
|
|
parser.add_argument("--auto_download", type=lambda x: (str(x).lower() in ("true", "1", "yes")), default=False, |
|
|
help="Whether to attempt to download UTKFace archive automatically if dataset folder is missing") |
|
|
parser.add_argument("--fine_tune", type=lambda x: (str(x).lower() in ("true", "1", "yes")), default=False, |
|
|
help="Whether to unfreeze part of the backbone for fine-tuning") |
|
|
args = parser.parse_args() |
|
|
return args |
|
|
|
|
|
|
|
|
def attempt_download_utkface(dest_dir: Path): |
|
|
"""Attempt to download a ZIP archive of the UTKFace repository and extract it. |
|
|
|
|
|
This may fail if the remote hosting changes. The function attempts a best-effort download |
|
|
from the repository URL commonly used to host UTKFace on GitHub. |
|
|
""" |
|
|
dest_dir.mkdir(parents=True, exist_ok=True) |
|
|
github_zip = "https://github.com/susanqq/UTKFace/archive/refs/heads/master.zip" |
|
|
tmp_zip = dest_dir / "utkface_master.zip" |
|
|
print(f"Attempting to download UTKFace from {github_zip} ...") |
|
|
|
|
|
try: |
|
|
with requests.get(github_zip, stream=True, timeout=30) as r: |
|
|
r.raise_for_status() |
|
|
total = int(r.headers.get('content-length', 0)) |
|
|
with open(tmp_zip, 'wb') as f: |
|
|
for chunk in r.iter_content(chunk_size=8192): |
|
|
if chunk: |
|
|
f.write(chunk) |
|
|
|
|
|
print("Download complete. Extracting archive...") |
|
|
with zipfile.ZipFile(tmp_zip, 'r') as z: |
|
|
z.extractall(dest_dir) |
|
|
|
|
|
|
|
|
extracted_root = None |
|
|
for name in os.listdir(dest_dir): |
|
|
if name.lower().startswith('utkface') and os.path.isdir(dest_dir / name): |
|
|
extracted_root = dest_dir / name |
|
|
break |
|
|
if extracted_root: |
|
|
images = list(extracted_root.rglob('*.jpg')) + list(extracted_root.rglob('*.png')) |
|
|
for p in images: |
|
|
target = dest_dir / p.name |
|
|
try: |
|
|
os.replace(p, target) |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
try: |
|
|
os.remove(tmp_zip) |
|
|
except Exception: |
|
|
pass |
|
|
print("UTKFace images should now be in:", dest_dir) |
|
|
except Exception as e: |
|
|
print("Automatic download failed:", e) |
|
|
print("Please download the UTKFace archive manually and place images in the dataset directory.") |
|
|
|
|
|
|
|
|
def collect_image_paths_and_labels(dataset_dir: Path): |
|
|
|
|
|
img_paths = [] |
|
|
labels = [] |
|
|
supported_ext = ('.jpg', '.jpeg', '.png') |
|
|
for p in dataset_dir.iterdir(): |
|
|
if p.is_file() and p.suffix.lower() in supported_ext: |
|
|
|
|
|
parts = p.name.split('_') |
|
|
try: |
|
|
age = int(parts[0]) |
|
|
except Exception: |
|
|
continue |
|
|
img_paths.append(str(p)) |
|
|
labels.append(age) |
|
|
return img_paths, labels |
|
|
|
|
|
|
|
|
def make_dataset(paths, labels, img_size, batch_size, is_training=True): |
|
|
paths = tf.convert_to_tensor(paths) |
|
|
labels = tf.convert_to_tensor(labels, dtype=tf.float32) |
|
|
|
|
|
ds = tf.data.Dataset.from_tensor_slices((paths, labels)) |
|
|
if is_training: |
|
|
ds = ds.shuffle(10000, reshuffle_each_iteration=True) |
|
|
|
|
|
def _load_image(path, label): |
|
|
img = tf.io.read_file(path) |
|
|
img = tf.image.decode_jpeg(img, channels=3) |
|
|
img = tf.image.resize(img, [img_size, img_size]) |
|
|
img = img / 255.0 |
|
|
if is_training: |
|
|
img = data_augmentation(img) |
|
|
return img, label |
|
|
|
|
|
ds = ds.map(_load_image, num_parallel_calls=tf.data.AUTOTUNE) |
|
|
ds = ds.batch(batch_size).prefetch(tf.data.AUTOTUNE) |
|
|
return ds |
|
|
|
|
|
|
|
|
def data_augmentation(image): |
|
|
|
|
|
image = tf.image.random_flip_left_right(image) |
|
|
image = tf.image.random_brightness(image, max_delta=0.08) |
|
|
image = tf.image.random_contrast(image, 0.9, 1.1) |
|
|
|
|
|
if tf.random.uniform(()) > 0.6: |
|
|
crop_frac = tf.random.uniform((), 0.8, 1.0) |
|
|
shape = tf.shape(image) |
|
|
crop_h = tf.cast(tf.cast(shape[0], tf.float32) * crop_frac, tf.int32) |
|
|
crop_w = tf.cast(tf.cast(shape[1], tf.float32) * crop_frac, tf.int32) |
|
|
image = tf.image.random_crop(image, size=[crop_h, crop_w, 3]) |
|
|
image = tf.image.resize(image, [shape[0], shape[1]]) |
|
|
return image |
|
|
|
|
|
|
|
|
def build_model(img_size, fine_tune=False): |
|
|
inputs = keras.Input(shape=(img_size, img_size, 3)) |
|
|
base = keras.applications.MobileNetV2(include_top=False, input_tensor=inputs, weights='imagenet') |
|
|
base.trainable = False |
|
|
|
|
|
x = base.output |
|
|
x = keras.layers.GlobalAveragePooling2D()(x) |
|
|
x = keras.layers.Dropout(0.2)(x) |
|
|
x = keras.layers.Dense(128, activation='relu')(x) |
|
|
x = keras.layers.Dense(64, activation='relu')(x) |
|
|
outputs = keras.layers.Dense(1, name='age')(x) |
|
|
|
|
|
model = keras.Model(inputs=inputs, outputs=outputs) |
|
|
|
|
|
if fine_tune: |
|
|
|
|
|
base.trainable = True |
|
|
|
|
|
for layer in base.layers[:-30]: |
|
|
layer.trainable = False |
|
|
|
|
|
return model |
|
|
|
|
|
|
|
|
def main(): |
|
|
args = parse_args() |
|
|
dataset_dir = Path(args.dataset_dir) |
|
|
|
|
|
if (not dataset_dir.exists() or not any(dataset_dir.iterdir())) and args.auto_download: |
|
|
attempt_download_utkface(dataset_dir) |
|
|
|
|
|
if not dataset_dir.exists() or not any(dataset_dir.iterdir()): |
|
|
raise RuntimeError(f"No images found in {dataset_dir}. Place UTKFace images there or use --auto_download True to attempt download.") |
|
|
|
|
|
paths, labels = collect_image_paths_and_labels(dataset_dir) |
|
|
if len(paths) == 0: |
|
|
raise RuntimeError("No valid UTKFace images found in dataset directory. Ensure the files follow the naming convention '<age>_...'.") |
|
|
|
|
|
|
|
|
paths = np.array(paths) |
|
|
labels = np.array(labels, dtype=np.float32) |
|
|
|
|
|
|
|
|
indices = np.arange(len(paths)) |
|
|
np.random.shuffle(indices) |
|
|
paths = paths[indices] |
|
|
labels = labels[indices] |
|
|
|
|
|
n_val = max(1, int(len(paths) * args.val_split)) |
|
|
val_paths = paths[:n_val].tolist() |
|
|
val_labels = labels[:n_val].tolist() |
|
|
train_paths = paths[n_val:].tolist() |
|
|
train_labels = labels[n_val:].tolist() |
|
|
|
|
|
print(f"Found {len(train_paths)} training images and {len(val_paths)} validation images.") |
|
|
|
|
|
train_ds = make_dataset(train_paths, train_labels, args.img_size, args.batch_size, is_training=True) |
|
|
val_ds = make_dataset(val_paths, val_labels, args.img_size, args.batch_size, is_training=False) |
|
|
|
|
|
model = build_model(args.img_size, fine_tune=args.fine_tune) |
|
|
model.compile(optimizer=keras.optimizers.Adam(learning_rate=args.learning_rate), |
|
|
loss='mse', |
|
|
metrics=[keras.metrics.MeanAbsoluteError(name='mae')]) |
|
|
|
|
|
model.summary() |
|
|
|
|
|
callbacks = [ |
|
|
keras.callbacks.ModelCheckpoint('best_model.h5', save_best_only=True, monitor='val_loss'), |
|
|
keras.callbacks.EarlyStopping(monitor='val_loss', patience=8, restore_best_weights=True), |
|
|
keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=4, min_lr=1e-7) |
|
|
] |
|
|
|
|
|
history = model.fit(train_ds, validation_data=val_ds, epochs=args.epochs, callbacks=callbacks) |
|
|
|
|
|
|
|
|
print("Evaluating on validation set:") |
|
|
eval_res = model.evaluate(val_ds) |
|
|
print(dict(zip(model.metrics_names, eval_res))) |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
model.export('saved_model_age_regressor') |
|
|
print('Exported SavedModel to ./saved_model_age_regressor') |
|
|
except Exception as e: |
|
|
print('SavedModel export failed:', e) |
|
|
|
|
|
try: |
|
|
model.save('saved_model_age_regressor.keras') |
|
|
print('Saved Keras model to ./saved_model_age_regressor.keras') |
|
|
except Exception as e2: |
|
|
print('Keras native save failed:', e2) |
|
|
|
|
|
try: |
|
|
model.save('final_model.h5') |
|
|
print('Saved HDF5 model to ./final_model.h5') |
|
|
except Exception as e3: |
|
|
print('HDF5 save failed:', e3) |
|
|
|
|
|
|
|
|
sample_paths = val_paths[:12] |
|
|
sample_labels = val_labels[:12] |
|
|
|
|
|
sample_ds = make_dataset(sample_paths, sample_labels, args.img_size, batch_size=12, is_training=False) |
|
|
imgs, labs = next(iter(sample_ds)) |
|
|
preds = model.predict(imgs).flatten() |
|
|
|
|
|
try: |
|
|
import matplotlib.pyplot as plt |
|
|
n = len(preds) |
|
|
cols = 4 |
|
|
rows = math.ceil(n / cols) |
|
|
plt.figure(figsize=(cols * 3, rows * 3)) |
|
|
for i in range(n): |
|
|
ax = plt.subplot(rows, cols, i + 1) |
|
|
img = imgs[i].numpy() |
|
|
plt.imshow(img) |
|
|
plt.axis('off') |
|
|
plt.title(f"True: {int(labs[i])}\nPred: {preds[i]:.1f}") |
|
|
plt.tight_layout() |
|
|
plt.show() |
|
|
except Exception: |
|
|
print("Matplotlib not available or running headless; skipping sample visualization.") |
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
main() |
|
|
|