File size: 10,549 Bytes
de3c81a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
"""
Train a TensorFlow regression model to predict age from face images (UTKFace dataset).

Usage:
  - Put UTKFace images into a folder, e.g. data/UTKFace/
  - python train.py --dataset_dir data/UTKFace --epochs 30 --batch_size 32

The script extracts the age from the filename (before the first underscore).
"""

import os
import argparse
import random
import math
import zipfile
from pathlib import Path

import numpy as np
from tqdm import tqdm
import requests

import tensorflow as tf
from tensorflow import keras


def parse_args():
    """Parse and return the command-line options for a training run.

    Boolean flags accept explicit values (e.g. ``--fine_tune True``); any of
    "true"/"1"/"yes" (case-insensitive) counts as True, everything else as False.
    """
    def _flag(raw):
        # Permissive string-to-bool conversion for argparse `type=`.
        return str(raw).lower() in ("true", "1", "yes")

    parser = argparse.ArgumentParser(description="Train an age regression model on UTKFace images")
    parser.add_argument("--dataset_dir", type=str, default="data/UTKFace",
                        help="Path to folder containing UTKFace images")
    parser.add_argument("--img_size", type=int, default=224, help="Image size (square)")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--val_split", type=float, default=0.12,
                        help="Fraction to reserve for validation")
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--auto_download", type=_flag, default=False,
                        help="Whether to attempt to download UTKFace archive automatically if dataset folder is missing")
    parser.add_argument("--fine_tune", type=_flag, default=False,
                        help="Whether to unfreeze part of the backbone for fine-tuning")
    return parser.parse_args()


def attempt_download_utkface(dest_dir: Path):
    """Attempt to download a ZIP archive of the UTKFace repository and extract it.

    This may fail if the remote hosting changes. The function attempts a best-effort download
    from the repository URL commonly used to host UTKFace on GitHub. Errors are
    printed, never raised: the caller re-checks the directory contents afterwards.

    Args:
        dest_dir: Directory that should end up holding the flat image files.
            Created (including parents) if it does not exist.
    """
    dest_dir.mkdir(parents=True, exist_ok=True)
    github_zip = "https://github.com/susanqq/UTKFace/archive/refs/heads/master.zip"
    tmp_zip = dest_dir / "utkface_master.zip"
    print(f"Attempting to download UTKFace from {github_zip} ...")

    try:
        # Stream to disk in chunks so the whole archive never sits in memory.
        with requests.get(github_zip, stream=True, timeout=30) as r:
            r.raise_for_status()
            with open(tmp_zip, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)

        print("Download complete. Extracting archive...")
        with zipfile.ZipFile(tmp_zip, 'r') as z:
            z.extractall(dest_dir)

        # GitHub archives unpack into a top-level '<repo>-<branch>' folder;
        # flatten any images found there directly into dest_dir.
        extracted_root = None
        for name in os.listdir(dest_dir):
            if name.lower().startswith('utkface') and os.path.isdir(dest_dir / name):
                extracted_root = dest_dir / name
                break
        if extracted_root:
            images = list(extracted_root.rglob('*.jpg')) + list(extracted_root.rglob('*.png'))
            for p in images:
                target = dest_dir / p.name
                try:
                    os.replace(p, target)
                except OSError:
                    # Best-effort move: leave files we cannot relocate in place.
                    pass
        # Remove the downloaded archive; failing to delete it is not fatal.
        try:
            os.remove(tmp_zip)
        except OSError:
            pass
        print("UTKFace images should now be in:", dest_dir)
    except Exception as e:
        # Any network/extraction problem ends the attempt; the caller will
        # raise if the dataset directory is still empty.
        print("Automatic download failed:", e)
        print("Please download the UTKFace archive manually and place images in the dataset directory.")


def collect_image_paths_and_labels(dataset_dir: Path):
    """Scan *dataset_dir* for UTKFace images and derive an age label for each.

    UTKFace filenames follow '<age>_<gender>_<race>_<date&time>.jpg'; the
    integer before the first underscore is the age. Files whose age field
    does not parse as an integer are silently skipped.

    Returns:
        (paths, ages): parallel lists of path strings and integer ages.
    """
    valid_suffixes = {'.jpg', '.jpeg', '.png'}
    found_paths = []
    ages = []
    for entry in dataset_dir.iterdir():
        if not entry.is_file() or entry.suffix.lower() not in valid_suffixes:
            continue
        age_field = entry.name.split('_')[0]
        try:
            age = int(age_field)
        except ValueError:
            # Filename does not follow the UTKFace convention; skip it.
            continue
        found_paths.append(str(entry))
        ages.append(age)
    return found_paths, ages


def make_dataset(paths, labels, img_size, batch_size, is_training=True):
    """Build a tf.data pipeline yielding (image, age) batches.

    Args:
        paths: sequence of image file path strings.
        labels: sequence of ages; converted to float32 for regression.
        img_size: images are resized to (img_size, img_size).
        batch_size: number of samples per batch.
        is_training: when True, shuffle the data and apply data_augmentation.

    Returns:
        A batched, prefetched tf.data.Dataset of (image, label) pairs with
        images normalized to [0, 1].
    """
    paths = tf.convert_to_tensor(paths)
    labels = tf.convert_to_tensor(labels, dtype=tf.float32)

    ds = tf.data.Dataset.from_tensor_slices((paths, labels))
    if is_training:
        ds = ds.shuffle(10000, reshuffle_each_iteration=True)

    def _load_image(path, label):
        img = tf.io.read_file(path)
        # decode_image handles both JPEG and PNG — the extensions accepted by
        # collect_image_paths_and_labels — whereas decode_jpeg errors on PNGs.
        # expand_animations=False guarantees a static 3-D tensor so resize works.
        img = tf.io.decode_image(img, channels=3, expand_animations=False)
        img = tf.image.resize(img, [img_size, img_size])
        img = img / 255.0  # normalize to [0,1]
        if is_training:
            img = data_augmentation(img)
        return img, label

    ds = ds.map(_load_image, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return ds


def data_augmentation(image):
    """Apply lightweight, label-preserving augmentations to a single image.

    Intended for a float image tensor already scaled to [0, 1] (H, W, 3), as
    produced by make_dataset; returns a tensor with the same height and width.
    """
    # Simple augmentation pipeline
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.08)
    image = tf.image.random_contrast(image, 0.9, 1.1)
    # random zoom by central crop/resizing — applied with probability ~0.4
    # NOTE(review): this Python `if` branches on a tensor; it presumably relies
    # on AutoGraph conversion when traced through Dataset.map — confirm before
    # calling this function from plain eager/graph code.
    if tf.random.uniform(()) > 0.6:
        crop_frac = tf.random.uniform((), 0.8, 1.0)
        shape = tf.shape(image)  # capture the pre-crop size so we can resize back
        crop_h = tf.cast(tf.cast(shape[0], tf.float32) * crop_frac, tf.int32)
        crop_w = tf.cast(tf.cast(shape[1], tf.float32) * crop_frac, tf.int32)
        image = tf.image.random_crop(image, size=[crop_h, crop_w, 3])
        image = tf.image.resize(image, [shape[0], shape[1]])
    return image


def build_model(img_size, fine_tune=False):
    """Create a MobileNetV2-based age regressor.

    An ImageNet-pretrained backbone (frozen by default) feeds a small dense
    head that emits a single unbounded age value. With fine_tune=True the
    last 30 backbone layers are unfrozen for continued training.
    """
    inputs = keras.Input(shape=(img_size, img_size, 3))
    base = keras.applications.MobileNetV2(include_top=False, input_tensor=inputs, weights='imagenet')
    base.trainable = False

    head = keras.layers.GlobalAveragePooling2D()(base.output)
    head = keras.layers.Dropout(0.2)(head)
    head = keras.layers.Dense(128, activation='relu')(head)
    head = keras.layers.Dense(64, activation='relu')(head)
    age_output = keras.layers.Dense(1, name='age')(head)  # linear regression output

    model = keras.Model(inputs=inputs, outputs=age_output)

    if fine_tune:
        # Unfreeze everything, then re-freeze all but the last 30 layers.
        base.trainable = True
        for frozen_layer in base.layers[:-30]:
            frozen_layer.trainable = False

    return model


def main():
    """Run the full pipeline: data prep, training, evaluation, model export,
    and a best-effort visualization of sample predictions."""
    args = parse_args()
    dataset_dir = Path(args.dataset_dir)

    # Fetch the dataset only when the folder is absent/empty AND the user opted in.
    if (not dataset_dir.exists() or not any(dataset_dir.iterdir())) and args.auto_download:
        attempt_download_utkface(dataset_dir)

    if not dataset_dir.exists() or not any(dataset_dir.iterdir()):
        raise RuntimeError(f"No images found in {dataset_dir}. Place UTKFace images there or use --auto_download True to attempt download.")

    paths, labels = collect_image_paths_and_labels(dataset_dir)
    if len(paths) == 0:
        raise RuntimeError("No valid UTKFace images found in dataset directory. Ensure the files follow the naming convention '<age>_...'.")

    # Convert to numpy lists
    paths = np.array(paths)
    labels = np.array(labels, dtype=np.float32)

    # Shuffle and split
    # NOTE(review): no RNG seed is set, so the train/val split differs between runs.
    indices = np.arange(len(paths))
    np.random.shuffle(indices)
    paths = paths[indices]
    labels = labels[indices]

    # Reserve at least one sample for validation, even on tiny datasets.
    n_val = max(1, int(len(paths) * args.val_split))
    val_paths = paths[:n_val].tolist()
    val_labels = labels[:n_val].tolist()
    train_paths = paths[n_val:].tolist()
    train_labels = labels[n_val:].tolist()

    print(f"Found {len(train_paths)} training images and {len(val_paths)} validation images.")

    train_ds = make_dataset(train_paths, train_labels, args.img_size, args.batch_size, is_training=True)
    val_ds = make_dataset(val_paths, val_labels, args.img_size, args.batch_size, is_training=False)

    # MSE loss for regression; MAE reported as the human-readable metric (years).
    model = build_model(args.img_size, fine_tune=args.fine_tune)
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=args.learning_rate),
                  loss='mse',
                  metrics=[keras.metrics.MeanAbsoluteError(name='mae')])

    model.summary()

    # Checkpoint the best weights; stop early and decay the LR on val-loss plateaus.
    callbacks = [
        keras.callbacks.ModelCheckpoint('best_model.h5', save_best_only=True, monitor='val_loss'),
        keras.callbacks.EarlyStopping(monitor='val_loss', patience=8, restore_best_weights=True),
        keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=4, min_lr=1e-7)
    ]

    history = model.fit(train_ds, validation_data=val_ds, epochs=args.epochs, callbacks=callbacks)

    # Evaluate
    print("Evaluating on validation set:")
    eval_res = model.evaluate(val_ds)
    print(dict(zip(model.metrics_names, eval_res)))

    # Save in both SavedModel (preferred) and Keras formats for compatibility
    try:
        # Preferred: export to SavedModel directory for TFServing/TFLite
        model.export('saved_model_age_regressor')
        print('Exported SavedModel to ./saved_model_age_regressor')
    except Exception as e:
        print('SavedModel export failed:', e)
        # Fallback: save as Keras native single-file (.keras)
        try:
            model.save('saved_model_age_regressor.keras')
            print('Saved Keras model to ./saved_model_age_regressor.keras')
        except Exception as e2:
            print('Keras native save failed:', e2)
    # Also save an HDF5 copy for backward compatibility with tools that require .h5
    try:
        model.save('final_model.h5')
        print('Saved HDF5 model to ./final_model.h5')
    except Exception as e3:
        print('HDF5 save failed:', e3)

    # Show a few sample predictions
    sample_paths = val_paths[:12]
    sample_labels = val_labels[:12]

    sample_ds = make_dataset(sample_paths, sample_labels, args.img_size, batch_size=12, is_training=False)
    imgs, labs = next(iter(sample_ds))
    preds = model.predict(imgs).flatten()

    # Visualization is optional: skip quietly when matplotlib is missing or headless.
    try:
        import matplotlib.pyplot as plt
        n = len(preds)
        cols = 4
        rows = math.ceil(n / cols)
        plt.figure(figsize=(cols * 3, rows * 3))
        for i in range(n):
            ax = plt.subplot(rows, cols, i + 1)
            img = imgs[i].numpy()
            plt.imshow(img)
            plt.axis('off')
            plt.title(f"True: {int(labs[i])}\nPred: {preds[i]:.1f}")
        plt.tight_layout()
        plt.show()
    except Exception:
        print("Matplotlib not available or running headless; skipping sample visualization.")


# Script entry point: run the training pipeline when invoked directly.
if __name__ == '__main__':
    main()