"""Train a TensorFlow regression model to predict age from face images (UTKFace dataset).

Usage:
    - Put UTKFace images into a folder, e.g. data/UTKFace/
    - python train.py --dataset_dir data/UTKFace --epochs 30 --batch_size 32

The script extracts the age from the filename (before the first underscore).
"""

import os
import argparse
import random
import math
import zipfile
from pathlib import Path

import numpy as np
from tqdm import tqdm
import requests
import tensorflow as tf
from tensorflow import keras


def _str2bool(value):
    """Return True for "true", "1" or "yes" (case-insensitive), False otherwise."""
    return str(value).lower() in ("true", "1", "yes")


def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace with dataset_dir, img_size, batch_size, epochs,
        val_split, learning_rate, auto_download and fine_tune.
    """
    parser = argparse.ArgumentParser(
        description="Train an age regression model on UTKFace images")
    parser.add_argument("--dataset_dir", type=str, default="data/UTKFace",
                        help="Path to folder containing UTKFace images")
    parser.add_argument("--img_size", type=int, default=224,
                        help="Image size (square)")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--val_split", type=float, default=0.12,
                        help="Fraction to reserve for validation")
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    # Boolean flags take an explicit value, e.g. "--auto_download True".
    parser.add_argument("--auto_download", type=_str2bool, default=False,
                        help="Whether to attempt to download UTKFace archive "
                             "automatically if dataset folder is missing")
    parser.add_argument("--fine_tune", type=_str2bool, default=False,
                        help="Whether to unfreeze part of the backbone for fine-tuning")
    return parser.parse_args()


def attempt_download_utkface(dest_dir: Path):
    """Attempt to download a ZIP archive of the UTKFace repository and extract it.

    This may fail if the remote hosting changes. The function attempts a
    best-effort download from the repository URL commonly used to host UTKFace
    on GitHub. It never raises: on any failure it prints a message asking the
    user to download the data manually.

    Args:
        dest_dir: Directory to download into; created if missing. Extracted
            .jpg/.png files are moved directly into this directory.
    """
    dest_dir.mkdir(parents=True, exist_ok=True)
    github_zip = "https://github.com/susanqq/UTKFace/archive/refs/heads/master.zip"
    tmp_zip = dest_dir / "utkface_master.zip"
    print(f"Attempting to download UTKFace from {github_zip} ...")
    try:
        with requests.get(github_zip, stream=True, timeout=30) as r:
            r.raise_for_status()
            # content-length may be absent; tqdm shows a plain counter for total=None.
            total = int(r.headers.get('content-length', 0))
            with open(tmp_zip, 'wb') as f, tqdm(total=total or None, unit='B',
                                                unit_scale=True, desc='UTKFace') as bar:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        bar.update(len(chunk))
        print("Download complete. Extracting archive...")
        with zipfile.ZipFile(tmp_zip, 'r') as z:
            z.extractall(dest_dir)

        # Move images into dest_dir if they're inside a top-level folder
        extracted_root = next(
            (entry for entry in dest_dir.iterdir()
             if entry.is_dir() and entry.name.lower().startswith('utkface')),
            None,
        )
        if extracted_root:
            images = list(extracted_root.rglob('*.jpg')) + list(extracted_root.rglob('*.png'))
            for p in images:
                target = dest_dir / p.name
                try:
                    os.replace(p, target)
                except OSError as move_err:
                    # Best-effort: one unmovable file must not abort the rest.
                    print(f"Could not move {p} to {target}: {move_err}")
        print("UTKFace images should now be in:", dest_dir)
    except Exception as e:  # deliberate best-effort boundary: report, never raise
        print("Automatic download failed:", e)
        print("Please download the UTKFace archive manually and place images in the dataset directory.")
    finally:
        # Clean up the archive on failure too, so a partial download is not left
        # behind making the dataset directory look non-empty.
        try:
            os.remove(tmp_zip)
        except FileNotFoundError:
            pass
        except OSError as rm_err:
            print(f"Could not remove temporary archive {tmp_zip}: {rm_err}")


def collect_image_paths_and_labels(dataset_dir: Path):
    """Collect image files in ``dataset_dir`` together with their age labels.

    UTKFace filenames look like [age]_[gender]_[race]_[datetime].jpg; the age
    is the integer before the first underscore. Files whose leading part is not
    an integer are skipped. Only direct children of ``dataset_dir`` with a
    .jpg/.jpeg/.png suffix are considered.

    Returns:
        (img_paths, labels): parallel lists of path strings and integer ages,
        ordered by path.
    """
    img_paths = []
    labels = []
    supported_ext = ('.jpg', '.jpeg', '.png')
    # Sorted so the result does not depend on filesystem enumeration order.
    for p in sorted(dataset_dir.iterdir()):
        if not (p.is_file() and p.suffix.lower() in supported_ext):
            continue
        try:
            age = int(p.name.split('_')[0])
        except ValueError:
            continue  # not a UTKFace-style name
        img_paths.append(str(p))
        labels.append(age)
    return img_paths, labels


def make_dataset(paths, labels, img_size, batch_size, is_training=True):
    """Build a batched ``tf.data.Dataset`` of (image, age) pairs.

    Images are read from disk, decoded to 3 channels, resized to
    ``img_size`` x ``img_size`` and scaled to [0, 1]. For training datasets the
    elements are shuffled each epoch and passed through ``data_augmentation``.

    Args:
        paths: Sequence of image file paths.
        labels: Sequence of ages, same length as ``paths``.
        img_size: Side length of the square output images.
        batch_size: Number of examples per batch.
        is_training: Enables shuffling and augmentation.
    """
    n_items = len(paths)
    paths = tf.convert_to_tensor(paths)
    labels = tf.convert_to_tensor(labels, dtype=tf.float32)
    ds = tf.data.Dataset.from_tensor_slices((paths, labels))
    if is_training:
        # Only path strings are buffered here, so a buffer covering the whole
        # dataset is cheap and gives a uniform shuffle.
        ds = ds.shuffle(max(n_items, 1), reshuffle_each_iteration=True)

    def _load_image(path, label):
        raw = tf.io.read_file(path)
        # Per the TensorFlow docs decode_jpeg also decodes PNG input.
        img = tf.image.decode_jpeg(raw, channels=3)
        img = tf.image.resize(img, [img_size, img_size])
        img = img / 255.0  # normalize to [0,1]
        if is_training:
            img = data_augmentation(img)
        return img, label

    ds = ds.map(_load_image, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return ds


def data_augmentation(image):
    """Apply a simple random augmentation pipeline to one image.

    Steps: random horizontal flip, small brightness and contrast jitter, and,
    with probability 0.4, a random crop of 80-100% of each side resized back to
    the original size. The result is clipped to [0, 1].

    Args:
        image: Float image tensor of shape (height, width, 3) with values in
            [0, 1] (the crop size hard-codes 3 channels).

    Returns:
        Augmented image with the same shape as the input.
    """
    # Simple augmentation pipeline
    jittered = tf.image.random_flip_left_right(image)
    jittered = tf.image.random_brightness(jittered, max_delta=0.08)
    jittered = tf.image.random_contrast(jittered, 0.9, 1.1)
    static_shape = jittered.shape

    def _random_zoom():
        # random zoom by cropping a sub-window and resizing back
        crop_frac = tf.random.uniform((), 0.8, 1.0)
        shape = tf.shape(jittered)
        crop_h = tf.cast(tf.cast(shape[0], tf.float32) * crop_frac, tf.int32)
        crop_w = tf.cast(tf.cast(shape[1], tf.float32) * crop_frac, tf.int32)
        cropped = tf.image.random_crop(jittered, size=[crop_h, crop_w, 3])
        return tf.image.resize(cropped, [shape[0], shape[1]])

    # Explicit tf.cond so the branch works the same in eager and graph mode.
    result = tf.cond(tf.random.uniform(()) > 0.6, _random_zoom, lambda: jittered)
    # The zoom branch only knows its size dynamically; restore the static shape.
    result.set_shape(static_shape)
    # Brightness/contrast jitter can leave [0, 1]; keep the documented range.
    return tf.clip_by_value(result, 0.0, 1.0)


def build_model(img_size, fine_tune=False):
    """Build the age regressor: MobileNetV2 backbone plus a small dense head.

    The model takes images scaled to [0, 1] (as produced by ``make_dataset``)
    and outputs a single unbounded value named ``age``.

    Args:
        img_size: Side length of the square RGB input.
        fine_tune: If True, the last 30 backbone layers are left trainable;
            otherwise the whole backbone is frozen.

    Returns:
        An uncompiled ``keras.Model``.
    """
    inputs = keras.Input(shape=(img_size, img_size, 3))
    # The ImageNet weights of MobileNetV2 were trained on inputs in [-1, 1]
    # (see mobilenet_v2.preprocess_input); map the pipeline's [0, 1] onto that.
    scaled = keras.layers.Rescaling(scale=2.0, offset=-1.0)(inputs)
    base = keras.applications.MobileNetV2(include_top=False, input_tensor=scaled,
                                          weights='imagenet')
    base.trainable = False

    x = base.output
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dropout(0.2)(x)
    x = keras.layers.Dense(128, activation='relu')(x)
    x = keras.layers.Dense(64, activation='relu')(x)
    outputs = keras.layers.Dense(1, name='age')(x)  # regression output
    model = keras.Model(inputs=inputs, outputs=outputs)

    if fine_tune:
        # Unfreeze last blocks for fine-tuning
        base.trainable = True
        # Freeze earlier layers
        for layer in base.layers[:-30]:
            layer.trainable = False
    return model


def _save_model(model):
    """Export the trained model; every step is best-effort and only reports failures."""
    # Preferred: export to SavedModel directory for TFServing/TFLite
    try:
        model.export('saved_model_age_regressor')
        print('Exported SavedModel to ./saved_model_age_regressor')
    except Exception as e:
        print('SavedModel export failed:', e)
        # Fallback: save as Keras native single-file (.keras)
        try:
            model.save('saved_model_age_regressor.keras')
            print('Saved Keras model to ./saved_model_age_regressor.keras')
        except Exception as e2:
            print('Keras native save failed:', e2)

    # Also save an HDF5 copy for backward compatibility with tools that require .h5
    try:
        model.save('final_model.h5')
        print('Saved HDF5 model to ./final_model.h5')
    except Exception as e3:
        print('HDF5 save failed:', e3)


def _show_sample_predictions(model, val_paths, val_labels, img_size, count=12):
    """Predict on the first ``count`` validation images and plot them if matplotlib works."""
    sample_ds = make_dataset(val_paths[:count], val_labels[:count], img_size,
                             batch_size=count, is_training=False)
    imgs, labs = next(iter(sample_ds))
    preds = model.predict(imgs).flatten()

    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("Matplotlib not available or running headless; skipping sample visualization.")
        return

    try:
        pictures = imgs.numpy()
        true_ages = labs.numpy()
        n = len(preds)
        cols = 4
        rows = math.ceil(n / cols)
        plt.figure(figsize=(cols * 3, rows * 3))
        for i in range(n):
            plt.subplot(rows, cols, i + 1)
            plt.imshow(pictures[i])
            plt.axis('off')
            plt.title(f"True: {int(true_ages[i])}\nPred: {preds[i]:.1f}")
        plt.tight_layout()
        plt.show()
    except Exception as plot_err:  # visualization must not fail a finished run
        print("Sample visualization failed; skipping:", plot_err)


def main():
    """Run the full pipeline: load data, split, train, evaluate, save, visualize.

    Raises:
        RuntimeError: If the dataset directory is missing or empty, contains no
            UTKFace-style image files, or has too few images to split.
    """
    args = parse_args()
    dataset_dir = Path(args.dataset_dir)

    if (not dataset_dir.exists() or not any(dataset_dir.iterdir())) and args.auto_download:
        attempt_download_utkface(dataset_dir)

    if not dataset_dir.exists() or not any(dataset_dir.iterdir()):
        raise RuntimeError(
            f"No images found in {dataset_dir}. Place UTKFace images there or "
            "use --auto_download True to attempt download.")

    paths, labels = collect_image_paths_and_labels(dataset_dir)
    if not paths:
        raise RuntimeError(
            "No valid UTKFace images found in dataset directory. Ensure the files "
            "follow the naming convention '[age]_[gender]_[race]_[datetime].jpg'.")
    if len(paths) < 2:
        raise RuntimeError(
            "At least 2 images are required to build a training and a validation split.")

    # Convert to numpy arrays
    paths = np.array(paths)
    labels = np.array(labels, dtype=np.float32)

    # Shuffle and split
    indices = np.arange(len(paths))
    np.random.shuffle(indices)
    paths = paths[indices]
    labels = labels[indices]

    # At least one validation image, and always at least one training image.
    n_val = max(1, int(len(paths) * args.val_split))
    n_val = min(n_val, len(paths) - 1)
    val_paths = paths[:n_val].tolist()
    val_labels = labels[:n_val].tolist()
    train_paths = paths[n_val:].tolist()
    train_labels = labels[n_val:].tolist()

    print(f"Found {len(train_paths)} training images and {len(val_paths)} validation images.")

    train_ds = make_dataset(train_paths, train_labels, args.img_size, args.batch_size,
                            is_training=True)
    val_ds = make_dataset(val_paths, val_labels, args.img_size, args.batch_size,
                          is_training=False)

    model = build_model(args.img_size, fine_tune=args.fine_tune)
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=args.learning_rate),
                  loss='mse',
                  metrics=[keras.metrics.MeanAbsoluteError(name='mae')])
    model.summary()

    callbacks = [
        keras.callbacks.ModelCheckpoint('best_model.h5', save_best_only=True,
                                        monitor='val_loss'),
        keras.callbacks.EarlyStopping(monitor='val_loss', patience=8,
                                      restore_best_weights=True),
        keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=4,
                                          min_lr=1e-7),
    ]

    model.fit(train_ds, validation_data=val_ds, epochs=args.epochs, callbacks=callbacks)

    # Evaluate
    print("Evaluating on validation set:")
    # return_dict labels each value with its metric name; zipping metrics_names
    # with the result list can mislabel metrics on newer Keras versions.
    eval_res = model.evaluate(val_ds, return_dict=True)
    print(eval_res)

    # Save as SavedModel (preferred, Keras format as fallback) plus an HDF5 copy
    _save_model(model)

    # Show a few sample predictions
    _show_sample_predictions(model, val_paths, val_labels, args.img_size)


if __name__ == '__main__':
    main()