File size: 3,172 Bytes
de3c81a
 
 
 
 
 
 
d549347
de3c81a
 
d549347
 
 
 
 
 
 
 
 
 
de3c81a
d549347
de3c81a
d549347
 
 
 
 
 
 
 
 
 
de3c81a
 
 
d549347
 
 
 
 
 
de3c81a
 
 
d549347
de3c81a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
import gradio as gr

# Model loading runs once at import time. Sources are tried in order (local
# files, then a single-file Hub download, then a full Hub snapshot); the first
# artifact that loads wins and failures are printed before moving on.
model = None

# Candidate artifacts, in order of preference. The same names are searched in
# the working directory and, as a last resort, inside a downloaded Hub snapshot.
local_model_paths = ["saved_model", "best_model.h5", "final_model.h5"]

# Try local files first (for development)
for path in local_model_paths:
    if os.path.exists(path):
        try:
            model = tf.keras.models.load_model(path, compile=False)
            print(f"Loaded model from local path: {path}")
            break
        except Exception as e:
            print(f"Failed to load local model from {path}: {e}")

# If no local model, try to download from Hugging Face Hub
if model is None:
    HF_MODEL_ID = os.environ.get("HF_MODEL_ID", "Sharris/age_detection_regression")
    try:
        from huggingface_hub import hf_hub_download
        # Fast path: fetch only the single .h5 file.
        model_path = hf_hub_download(repo_id=HF_MODEL_ID, filename="best_model.h5")
        model = tf.keras.models.load_model(model_path, compile=False)
        print(f"Loaded model from HF Hub: {HF_MODEL_ID}/best_model.h5")
    except Exception as e:
        print(f"Failed to load model from HF Hub ({HF_MODEL_ID}): {e}")
        # Fallback: download the whole repo and look for ANY known artifact.
        # Looking only for best_model.h5 here would just retry the file the
        # fast path already failed on, so a repo that ships a SavedModel
        # directory or final_model.h5 is also accepted.
        try:
            from huggingface_hub import snapshot_download
            repo_dir = snapshot_download(repo_id=HF_MODEL_ID)
            for candidate in local_model_paths:
                model_file = os.path.join(repo_dir, candidate)
                if not os.path.exists(model_file):
                    continue
                try:
                    model = tf.keras.models.load_model(model_file, compile=False)
                except Exception as e3:
                    # Best-effort: log and try the next candidate.
                    print(f"Failed to load model from downloaded repo file {model_file}: {e3}")
                    continue
                print(f"Loaded model from downloaded repo: {model_file}")
                break
            else:
                # Loop finished without `break`: nothing in the snapshot loaded.
                print(f"No loadable model artifact found in downloaded repo: {repo_dir}")
        except Exception as e2:
            print(f"Fallback download also failed: {e2}")

if model is None:
    raise RuntimeError(
        "No model found. Ensure 'best_model.h5' exists locally or set HF_MODEL_ID env var to a Hugging Face model repo containing the model."
    )

# Size passed to PIL's Image.resize, i.e. (width, height), before inference.
INPUT_SIZE = (224, 224)


def predict_age(image: Image.Image):
    """Predict the age in years of the face in *image*.

    The image is converted to RGB and resized to ``INPUT_SIZE``. It is then
    passed through MobileNetV2's ``preprocess_input`` and fed to the
    module-level ``model`` as a batch of one.

    Args:
        image: PIL image supplied by the ``gr.Image(type='pil')`` input.
            Gradio passes ``None`` when the user submits without an image.

    Returns:
        dict with ``"predicted_age"`` (the model output rounded to 2 decimals)
        and ``"raw_output"`` (the same output as an unrounded float).

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable message
            instead of an ``AttributeError`` traceback.
    """
    if image is None:
        raise gr.Error("Please upload a face image.")
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image = image.resize(INPUT_SIZE)
    # np.array (not asarray) guarantees a fresh, writable float32 copy.
    arr = preprocess_input(np.array(image, dtype=np.float32))
    batch = np.expand_dims(arr, 0)

    # verbose=0 stops Keras from printing a progress bar on every request.
    output = model.predict(batch, verbose=0)[0]
    # The first row may be a shape-(1,) array or a numpy scalar; squeezing
    # handles both and float() rejects anything that is not a single value.
    pred = float(np.squeeze(output))

    return {
        "predicted_age": round(pred, 2),
        "raw_output": pred,
    }


def _predict_for_interface(image):
    """Adapt ``predict_age`` to the two ``gr.Number`` outputs declared below.

    With a list of output components, Gradio expects one returned value per
    component, in the same order. A dict keyed by strings is not accepted, so
    the dict returned by ``predict_age`` is unpacked into a tuple here.
    """
    result = predict_age(image)
    return result["predicted_age"], result["raw_output"]


demo = gr.Interface(
    fn=_predict_for_interface,
    inputs=gr.Image(type='pil', label='Face image (crop to face for best results)'),
    outputs=[
        gr.Number(label='Predicted age (years)'),
        gr.Number(label='Raw model output')
    ],
    examples=[],
    title='UTKFace Age Estimator',
    description='Upload a cropped face image and the model will predict age in years. For Spaces, set the HF_MODEL_ID environment variable to your Hugging Face model repo if you want the app to download the model (best_model.h5) from the Hub.'
)

if __name__ == '__main__':
    # Listen on all interfaces; the PORT env var overrides the 7860 fallback.
    listen_port = int(os.environ.get('PORT', 7860))
    demo.launch(server_name='0.0.0.0', server_port=listen_port)