# (Hugging Face upload-page residue, kept as comments so the module parses)
# Sharris's picture
# Upload folder using huggingface_hub
# de3c81a verified
"""
Train a TensorFlow regression model to predict age from face images (UTKFace dataset).
Usage:
- Put UTKFace images into a folder, e.g. data/UTKFace/
- python train.py --dataset_dir data/UTKFace --epochs 30 --batch_size 32
The script extracts the age from the filename (before the first underscore).
"""
import os
import argparse
import random
import math
import zipfile
from pathlib import Path
import numpy as np
from tqdm import tqdm
import requests
import tensorflow as tf
from tensorflow import keras
def parse_args():
    """Parse and return command-line options for the UTKFace trainer."""

    def _as_bool(value):
        # Accept common truthy spellings so "--flag True" / "--flag 1" / "--flag yes" all work.
        return str(value).lower() in ("true", "1", "yes")

    parser = argparse.ArgumentParser(description="Train an age regression model on UTKFace images")
    parser.add_argument("--dataset_dir", type=str, default="data/UTKFace", help="Path to folder containing UTKFace images")
    parser.add_argument("--img_size", type=int, default=224, help="Image size (square)")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--val_split", type=float, default=0.12, help="Fraction to reserve for validation")
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--auto_download", type=_as_bool, default=False,
                        help="Whether to attempt to download UTKFace archive automatically if dataset folder is missing")
    parser.add_argument("--fine_tune", type=_as_bool, default=False,
                        help="Whether to unfreeze part of the backbone for fine-tuning")
    return parser.parse_args()
def attempt_download_utkface(dest_dir: Path) -> None:
    """Attempt to download a ZIP archive of the UTKFace repository and extract it.

    This may fail if the remote hosting changes. The function attempts a best-effort download
    from the repository URL commonly used to host UTKFace on GitHub.  All failures
    are caught and printed; the caller decides how to handle missing data.

    Args:
        dest_dir: Directory that should end up containing the UTKFace images;
            created (with parents) if missing.
    """
    dest_dir.mkdir(parents=True, exist_ok=True)
    github_zip = "https://github.com/susanqq/UTKFace/archive/refs/heads/master.zip"
    tmp_zip = dest_dir / "utkface_master.zip"
    print(f"Attempting to download UTKFace from {github_zip} ...")
    try:
        # Stream in 8 KiB chunks so the archive is never held fully in memory.
        with requests.get(github_zip, stream=True, timeout=30) as r:
            r.raise_for_status()
            total = int(r.headers.get('content-length', 0))  # NOTE(review): computed but currently unused
            with open(tmp_zip, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # filter out keep-alive chunks
                        f.write(chunk)
        print("Download complete. Extracting archive...")
        # NOTE(review): extractall() on a downloaded archive is exposed to
        # zip-slip path traversal if the archive were malicious — consider
        # validating member names before extraction.
        with zipfile.ZipFile(tmp_zip, 'r') as z:
            z.extractall(dest_dir)
        # Move images into dest_dir if they're inside a top-level folder
        # (GitHub archives unpack into e.g. "UTKFace-master/").
        extracted_root = None
        for name in os.listdir(dest_dir):
            if name.lower().startswith('utkface') and os.path.isdir(dest_dir / name):
                extracted_root = dest_dir / name
                break
        if extracted_root:
            images = list(extracted_root.rglob('*.jpg')) + list(extracted_root.rglob('*.png'))
            for p in images:
                target = dest_dir / p.name
                try:
                    # os.replace overwrites any existing file with the same name.
                    os.replace(p, target)
                except Exception:
                    # Best-effort: skip files that cannot be moved.
                    pass
        # clean up the temporary archive (best-effort)
        try:
            os.remove(tmp_zip)
        except Exception:
            pass
        print("UTKFace images should now be in:", dest_dir)
    except Exception as e:
        # Best-effort by design: report the failure and let the caller fall
        # back to a manual download.
        print("Automatic download failed:", e)
        print("Please download the UTKFace archive manually and place images in the dataset directory.")
def collect_image_paths_and_labels(dataset_dir: Path):
    """Scan *dataset_dir* for UTKFace images and parse the age from each filename.

    UTKFace filenames follow ``<age>_<gender>_<race>_<date&time>.jpg``; the age
    is the integer before the first underscore.  Files whose name does not
    start with an integer, and non-image files, are skipped silently.

    Args:
        dataset_dir: Directory containing the (flat) UTKFace image files.

    Returns:
        Tuple ``(img_paths, labels)``: parallel lists of path strings and int ages.
    """
    img_paths = []
    labels = []
    supported_ext = ('.jpg', '.jpeg', '.png')
    # sorted() makes scan order deterministic across filesystems, so the
    # downstream train/val split is reproducible for a fixed RNG state.
    for p in sorted(dataset_dir.iterdir()):
        if not (p.is_file() and p.suffix.lower() in supported_ext):
            continue
        try:
            # Only int() can fail here; catching ValueError (not Exception)
            # avoids masking unrelated bugs.
            age = int(p.name.split('_', 1)[0])
        except ValueError:
            # Filename does not follow the UTKFace convention; skip it.
            continue
        img_paths.append(str(p))
        labels.append(age)
    return img_paths, labels
def make_dataset(paths, labels, img_size, batch_size, is_training=True):
    """Build a batched, prefetched tf.data pipeline of (image, age) pairs.

    Args:
        paths: Sequence of image file paths (strings).
        labels: Sequence of ages; cast to float32 for regression.
        img_size: Side length images are resized to (square).
        batch_size: Batch size for the returned dataset.
        is_training: If True, shuffle each epoch and apply augmentation.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, labels)`` batches with
        images as float32 tensors in ``[0, 1]``.
    """
    paths = tf.convert_to_tensor(paths)
    labels = tf.convert_to_tensor(labels, dtype=tf.float32)
    ds = tf.data.Dataset.from_tensor_slices((paths, labels))
    if is_training:
        ds = ds.shuffle(10000, reshuffle_each_iteration=True)

    def _load_image(path, label):
        img = tf.io.read_file(path)
        # decode_image handles JPEG and PNG alike — the collector accepts both
        # extensions, so decode_jpeg alone would be wrong for .png files.
        # expand_animations=False guarantees a rank-3 tensor so resize() works.
        img = tf.io.decode_image(img, channels=3, expand_animations=False)
        img = tf.image.resize(img, [img_size, img_size])
        img = img / 255.0  # normalize to [0,1]
        if is_training:
            img = data_augmentation(img)
        return img, label

    ds = ds.map(_load_image, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return ds
def data_augmentation(image):
    """Apply light, label-preserving augmentation to a single float image.

    Random horizontal flip, small brightness/contrast jitter, and (with
    probability ~0.4) a random 80–100% crop resized back to the original
    spatial size.

    Args:
        image: Rank-3 float tensor (H, W, 3); assumed already scaled to [0,1]
            by the loader — TODO confirm if reused elsewhere.

    Returns:
        Augmented image tensor with the same spatial size as the input.
    """
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.08)
    image = tf.image.random_contrast(image, 0.9, 1.1)

    shape = tf.shape(image)

    def _random_zoom():
        # Crop a random 80-100% window, then resize back to the original size.
        crop_frac = tf.random.uniform((), 0.8, 1.0)
        crop_h = tf.cast(tf.cast(shape[0], tf.float32) * crop_frac, tf.int32)
        crop_w = tf.cast(tf.cast(shape[1], tf.float32) * crop_frac, tf.int32)
        cropped = tf.image.random_crop(image, size=[crop_h, crop_w, 3])
        return tf.image.resize(cropped, [shape[0], shape[1]])

    # Use tf.cond rather than a Python `if` on a tensor: this function runs
    # inside Dataset.map (graph mode), where a plain `if` would depend on
    # AutoGraph conversion to work at all.
    image = tf.cond(tf.random.uniform(()) > 0.6, _random_zoom, lambda: image)
    return image
def build_model(img_size, fine_tune=False):
    """Assemble a MobileNetV2-based age-regression model.

    A MobileNetV2 backbone pretrained on ImageNet (frozen by default) feeds a
    small dense head ending in a single linear unit — the predicted age.

    Args:
        img_size: Input image side length (square, 3 channels).
        fine_tune: If True, unfreeze the last 30 backbone layers for a
            fine-tuning phase while keeping earlier layers frozen.

    Returns:
        An uncompiled ``keras.Model`` mapping images to a scalar age.
    """
    inputs = keras.Input(shape=(img_size, img_size, 3))
    backbone = keras.applications.MobileNetV2(include_top=False, input_tensor=inputs, weights='imagenet')
    backbone.trainable = False

    # Regression head on top of pooled backbone features.
    features = keras.layers.GlobalAveragePooling2D()(backbone.output)
    features = keras.layers.Dropout(0.2)(features)
    features = keras.layers.Dense(128, activation='relu')(features)
    features = keras.layers.Dense(64, activation='relu')(features)
    outputs = keras.layers.Dense(1, name='age')(features)  # linear output: predicted age

    model = keras.Model(inputs=inputs, outputs=outputs)

    if fine_tune:
        # Unfreeze the backbone, then re-freeze everything except its tail.
        backbone.trainable = True
        # NOTE(review): this also unfreezes any BatchNorm layers in the last
        # 30 layers, which lets their statistics drift — confirm intended.
        for layer in backbone.layers[:-30]:
            layer.trainable = False
    return model
def main():
    """Entry point: locate data, train the age regressor, save it, and show samples."""
    args = parse_args()
    dataset_dir = Path(args.dataset_dir)
    # Optionally fetch the dataset when the folder is missing or empty.
    if (not dataset_dir.exists() or not any(dataset_dir.iterdir())) and args.auto_download:
        attempt_download_utkface(dataset_dir)
    if not dataset_dir.exists() or not any(dataset_dir.iterdir()):
        raise RuntimeError(f"No images found in {dataset_dir}. Place UTKFace images there or use --auto_download True to attempt download.")
    paths, labels = collect_image_paths_and_labels(dataset_dir)
    if len(paths) == 0:
        raise RuntimeError("No valid UTKFace images found in dataset directory. Ensure the files follow the naming convention '<age>_...'.")
    # Convert to numpy lists
    paths = np.array(paths)
    labels = np.array(labels, dtype=np.float32)
    # Shuffle and split
    # NOTE(review): the shuffle is unseeded, so the train/val split changes
    # between runs — confirm whether reproducibility matters here.
    indices = np.arange(len(paths))
    np.random.shuffle(indices)
    paths = paths[indices]
    labels = labels[indices]
    # max(1, ...) guarantees at least one validation sample even for tiny datasets.
    n_val = max(1, int(len(paths) * args.val_split))
    val_paths = paths[:n_val].tolist()
    val_labels = labels[:n_val].tolist()
    train_paths = paths[n_val:].tolist()
    train_labels = labels[n_val:].tolist()
    print(f"Found {len(train_paths)} training images and {len(val_paths)} validation images.")
    train_ds = make_dataset(train_paths, train_labels, args.img_size, args.batch_size, is_training=True)
    val_ds = make_dataset(val_paths, val_labels, args.img_size, args.batch_size, is_training=False)
    model = build_model(args.img_size, fine_tune=args.fine_tune)
    # MSE loss for regression; MAE is reported because it reads directly in years.
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=args.learning_rate),
                  loss='mse',
                  metrics=[keras.metrics.MeanAbsoluteError(name='mae')])
    model.summary()
    # Checkpoint the best weights, stop early on plateau, and decay the LR.
    callbacks = [
        keras.callbacks.ModelCheckpoint('best_model.h5', save_best_only=True, monitor='val_loss'),
        keras.callbacks.EarlyStopping(monitor='val_loss', patience=8, restore_best_weights=True),
        keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=4, min_lr=1e-7)
    ]
    history = model.fit(train_ds, validation_data=val_ds, epochs=args.epochs, callbacks=callbacks)
    # Evaluate
    print("Evaluating on validation set:")
    eval_res = model.evaluate(val_ds)
    print(dict(zip(model.metrics_names, eval_res)))
    # Save in both SavedModel (preferred) and Keras formats for compatibility
    try:
        # Preferred: export to SavedModel directory for TFServing/TFLite
        model.export('saved_model_age_regressor')
        print('Exported SavedModel to ./saved_model_age_regressor')
    except Exception as e:
        print('SavedModel export failed:', e)
        # Fallback: save as Keras native single-file (.keras)
        try:
            model.save('saved_model_age_regressor.keras')
            print('Saved Keras model to ./saved_model_age_regressor.keras')
        except Exception as e2:
            print('Keras native save failed:', e2)
    # Also save an HDF5 copy for backward compatibility with tools that require .h5
    try:
        model.save('final_model.h5')
        print('Saved HDF5 model to ./final_model.h5')
    except Exception as e3:
        print('HDF5 save failed:', e3)
    # Show a few sample predictions
    sample_paths = val_paths[:12]
    sample_labels = val_labels[:12]
    sample_ds = make_dataset(sample_paths, sample_labels, args.img_size, batch_size=12, is_training=False)
    imgs, labs = next(iter(sample_ds))
    preds = model.predict(imgs).flatten()
    # Visualization is optional: plotting failures (headless box, missing
    # matplotlib) must not crash a finished training run.
    try:
        import matplotlib.pyplot as plt
        n = len(preds)
        cols = 4
        rows = math.ceil(n / cols)
        plt.figure(figsize=(cols * 3, rows * 3))
        for i in range(n):
            ax = plt.subplot(rows, cols, i + 1)
            img = imgs[i].numpy()
            plt.imshow(img)
            plt.axis('off')
            plt.title(f"True: {int(labs[i])}\nPred: {preds[i]:.1f}")
        plt.tight_layout()
        plt.show()
    except Exception:
        print("Matplotlib not available or running headless; skipping sample visualization.")
# Run training only when executed as a script, so the module can be imported
# (e.g. for its helper functions) without side effects.
if __name__ == '__main__':
    main()