File size: 3,606 Bytes
bb9e8e0
 
 
 
 
 
 
cfce11e
bb9e8e0
cfce11e
bb9e8e0
cfce11e
bb9e8e0
cfce11e
 
 
 
 
 
 
 
 
 
 
 
 
 
bb9e8e0
cfce11e
 
 
 
 
 
 
 
 
bb9e8e0
cfce11e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb9e8e0
 
 
cfce11e
bb9e8e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36ac5cc
bb9e8e0
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
import gradio as gr

# --- Model loading -----------------------------------------------------------
# Download the trained Keras model from the Hugging Face Hub at import time.
# HF_MODEL_ID may be overridden via environment variable to point the app at a
# different model repository.
model = None
HF_MODEL_ID = os.environ.get("HF_MODEL_ID", "Sharris/age_detection_regression")

print(f"Attempting to load model from: {HF_MODEL_ID}")

try:
    from huggingface_hub import hf_hub_download
    print("Downloading best_model.h5...")
    model_path = hf_hub_download(repo_id=HF_MODEL_ID, filename="best_model.h5")
    print(f"Model downloaded to: {model_path}")

    print("Loading model with TensorFlow...")
    # compile=False: inference only, so skip restoring optimizer/loss state.
    model = tf.keras.models.load_model(model_path, compile=False)
    print(f"βœ… Successfully loaded model from {HF_MODEL_ID}")

except Exception as e:
    # Broad catch is intentional: any Hub/network/deserialization failure
    # should fall through to the repository-snapshot fallback below.
    print(f"❌ Failed to download best_model.h5: {e}")

    # Fallback: snapshot the entire repo and probe for known model filenames.
    try:
        print("Trying fallback: downloading entire repository...")
        from huggingface_hub import snapshot_download
        repo_dir = snapshot_download(repo_id=HF_MODEL_ID)
        print(f"Repository downloaded to: {repo_dir}")

        # Probe candidate filenames in priority order; first loadable wins.
        possible_files = ["best_model.h5", "final_model.h5", "model.h5"]
        for filename in possible_files:
            model_file = os.path.join(repo_dir, filename)
            if os.path.exists(model_file):
                print(f"Found model file: {model_file}")
                try:
                    model = tf.keras.models.load_model(model_file, compile=False)
                    print(f"βœ… Successfully loaded model from {model_file}")
                    break
                except Exception as load_error:
                    print(f"Failed to load {model_file}: {load_error}")
                    continue

        if model is None:
            # List every file in the snapshot to aid debugging when no known
            # filename was found. (os is already imported at the top of the
            # file; the original redundantly re-imported it here.)
            print("Files in downloaded repository:")
            for root, dirs, files in os.walk(repo_dir):
                for file in files:
                    print(f"  {os.path.join(root, file)}")

    except Exception as e2:
        print(f"❌ Fallback download also failed: {e2}")

# Fail fast at startup rather than serving an app with no model behind it.
if model is None:
    raise RuntimeError(
        f"❌ Could not load model from {HF_MODEL_ID}. Please ensure the repository contains a valid model file (best_model.h5, final_model.h5, or model.h5)."
    )

# All inputs are resized to this resolution before inference (MobileNetV2's
# standard 224x224 input — TODO confirm against the trained model's config).
INPUT_SIZE = (224, 224)


def predict_age(image: Image.Image):
    """Predict age in years from a single face image.

    The image is converted to RGB, resized to INPUT_SIZE, run through the
    MobileNetV2 preprocessing, and fed to the regression model.

    Args:
        image: Input face image (any PIL mode; converted to RGB as needed).

    Returns:
        A (rounded, raw) tuple of floats: the predicted age rounded to two
        decimals, and the unrounded model output.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image = image.resize(INPUT_SIZE)
    arr = preprocess_input(np.array(image, dtype=np.float32))
    arr = np.expand_dims(arr, 0)  # add batch dimension -> (1, H, W, 3)

    pred = model.predict(arr)[0]
    # squeeze() collapses any trailing singleton dims, so this handles both
    # scalar and (1,)-shaped outputs without a separate hasattr branch.
    age = float(np.asarray(pred).squeeze())

    return round(age, 2), age
# Gradio UI: a single image input mapped to two numeric outputs
# (rounded age and the raw model output) via predict_age.
demo = gr.Interface(
    fn=predict_age,
    inputs=gr.Image(type='pil', label='Face image (crop to face for best results)'),
    outputs=[
        gr.Number(label='Predicted age (years)'),
        gr.Number(label='Raw model output')
    ],
    examples=[],  # no bundled example images
    title='UTKFace Age Estimator',
    description='Upload a cropped face image and the model will predict age in years. For Spaces, set the HF_MODEL_ID environment variable to your Hugging Face model repo if you want the app to download a SavedModel from the Hub.'
)

if __name__ == '__main__':
    # Bind on all interfaces so the app is reachable from inside a container;
    # honor the PORT environment variable, falling back to Gradio's default.
    port = int(os.environ.get('PORT', 7860))
    demo.launch(server_name='0.0.0.0', server_port=port)