File size: 3,799 Bytes
de3c81a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""
Load a trained age regression model and run a prediction on a single image.

Usage: python predict.py --model_path saved_model_age_regressor --image_path some_image.jpg
"""
import argparse
from pathlib import Path

import numpy as np
from PIL import Image
import tensorflow as tf


def _positive_int(text):
    """Argparse ``type`` callback: parse *text* as an int and require it to be > 0."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value <= 0:
        # A zero/negative size would only fail later inside the image resize.
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def parse_args(argv=None):
    """Parse the command-line options for single-image age prediction.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before; passing a list lets
            tests or other scripts drive the parser directly.

    Returns:
        An ``argparse.Namespace`` with ``model_path`` (str), ``image_path``
        (str), ``img_size`` (positive int) and ``output_key`` (str or None).
    """
    parser = argparse.ArgumentParser(
        description='Load a trained age regression model and predict the age for a single image.')
    parser.add_argument('--model_path', type=str, default='saved_model_age_regressor',
                        help='Keras .h5/.keras file or model directory (default: %(default)s).')
    parser.add_argument('--image_path', type=str, required=True,
                        help='Path to the input image.')
    parser.add_argument('--img_size', type=_positive_int, default=224,
                        help='Square side length the image is resized to (default: %(default)s).')
    parser.add_argument('--output_key', type=str, default=None,
                        help='If the model returns a dict, select this key for the numeric prediction. If omitted the first numeric output will be used.')
    return parser.parse_args(argv)


def load_image(path, img_size):
    """Load an image file as a normalized float32 RGB array.

    Args:
        path: Path (str or ``Path``) to an image readable by Pillow.
        img_size: Side length of the square output; the image is resized to
            ``(img_size, img_size)`` without preserving aspect ratio.

    Returns:
        ``np.ndarray`` of shape ``(img_size, img_size, 3)``, dtype float32,
        with pixel values scaled from ``[0, 255]`` to ``[0, 1]``.
    """
    # Use Image.open as a context manager: Pillow keeps the underlying file
    # open for lazily-loaded (e.g. multi-frame) images, so close it explicitly.
    with Image.open(path) as img:
        rgb = img.convert('RGB')  # convert() loads pixels into a new, file-independent image
    resized = rgb.resize((img_size, img_size))
    return np.array(resized, dtype=np.float32) / 255.0


def _load_model(model_path, img_size):
    """Load a Keras model file, a Keras-loadable directory, or a raw SavedModel.

    Args:
        model_path: ``Path`` to a ``.h5``/``.keras`` file or a model directory.
        img_size: Square input side length; only used when a raw SavedModel
            must be wrapped in a ``TFSMLayer`` and needs an explicit ``Input``.

    Returns:
        A Keras model exposing ``predict``.

    Raises:
        RuntimeError: If a directory can be loaded neither with
            ``load_model`` nor as a ``TFSMLayer``. Both underlying errors are
            reported in the message.
    """
    if model_path.is_file() and model_path.suffix.lower() in ('.h5', '.keras'):
        model = tf.keras.models.load_model(str(model_path), compile=False)
        print(f"Loaded Keras model file: {model_path}")
        return model

    if model_path.is_dir():
        # Some SavedModel directories are not loadable with tf.keras.load_model in Keras 3;
        # try load_model first (covers .keras saved dirs), otherwise wrap with TFSMLayer.
        try:
            model = tf.keras.models.load_model(str(model_path), compile=False)
        except Exception as load_err:
            try:
                tf_layer = tf.keras.layers.TFSMLayer(str(model_path), call_endpoint='serving_default')
                model = tf.keras.Sequential([
                    tf.keras.Input(shape=(img_size, img_size, 3)),
                    tf_layer,
                ])
            except Exception as wrap_err:
                # Keep the load_model error too: it is often the more informative one.
                raise RuntimeError(
                    f"Failed to load or wrap SavedModel directory '{model_path}': {wrap_err} "
                    f"(load_model error: {load_err})"
                ) from wrap_err
            print(f"Wrapped TensorFlow SavedModel at {model_path} with TFSMLayer (serving_default).")
            return model
        print(f"Loaded Keras-compatible model from directory: {model_path}")
        return model

    # Unknown path type (missing path or other suffix): let load_model raise its own error.
    model = tf.keras.models.load_model(str(model_path), compile=False)
    print(f"Loaded model from path: {model_path}")
    return model


def _select_prediction(pred, output_key=None):
    """Reduce a raw ``model.predict`` result to a single numeric ndarray.

    Args:
        pred: The value returned by ``model.predict``: an array, a list/tuple
            of arrays (multi-output model), or a dict of arrays (typical for a
            wrapped SavedModel serving signature).
        output_key: Dict key to select. If falsy, the first dict entry whose
            dtype is numeric is used.

    Returns:
        ``np.ndarray`` holding the chosen output.

    Raises:
        KeyError: If ``output_key`` is given but missing from a dict result.
        ValueError: If a dict result has no numeric entry, or a list/tuple
            result is empty.
    """
    if isinstance(pred, dict):
        if output_key:
            if output_key not in pred:
                raise KeyError(f"Requested output key '{output_key}' not found. Available keys: {list(pred.keys())}")
            return np.asarray(pred[output_key])
        # Honor the --output_key contract: skip non-numeric outputs (e.g. string tensors).
        for key, value in pred.items():
            candidate = np.asarray(value)
            if np.issubdtype(candidate.dtype, np.number):
                print(f"No --output_key provided; using first output key: '{key}'")
                return candidate
        raise ValueError(f"Model returned no numeric outputs. Available keys: {list(pred.keys())}")

    if isinstance(pred, (list, tuple)):
        # Multi-output models return one array per head; stacking heads with
        # different shapes would fail, so take the first head directly.
        if not pred:
            raise ValueError("Model returned an empty prediction.")
        return np.asarray(pred[0])

    return np.asarray(pred)


def main():
    """Run the CLI: load the model, preprocess one image, print the predicted age.

    Raises:
        FileNotFoundError: If ``--image_path`` does not exist.
        ValueError: If the model produces an empty or non-numeric prediction.
    """
    args = parse_args()

    # Validate the cheap input first so a typo does not cost a full model load.
    image_path = Path(args.image_path)
    if not image_path.exists():
        raise FileNotFoundError(f"Image not found: {image_path}")

    model = _load_model(Path(args.model_path), args.img_size)

    # Add the batch dimension expected by predict().
    x = np.expand_dims(load_image(image_path, args.img_size), axis=0)
    arr = _select_prediction(model.predict(x), args.output_key)

    if arr.size == 0:
        raise ValueError("Model returned an empty prediction.")
    age_pred = float(arr.flatten()[0])
    print(f"Predicted age: {age_pred:.2f} years")


if __name__ == "__main__":
    # Execute the CLI only when run as a script, not when imported.
    main()