Sharris's picture
Upload app.py with huggingface_hub
bb9e8e0 verified
raw
history blame
3.17 kB
import os
import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
import gradio as gr
# Load the Keras age-regression model once at import time. Sources are tried
# in order: local files, a single-file download from the Hugging Face Hub, and
# finally a full snapshot of the Hub repo. The first successful load wins; if
# every source fails, the app refuses to start (RuntimeError below).
model = None

# Hub repo used when no local model loads; override with the HF_MODEL_ID
# environment variable. Defined unconditionally so the name always exists at
# module level, whichever source ends up providing the model.
HF_MODEL_ID = os.environ.get("HF_MODEL_ID", "Sharris/age_detection_regression")
# Weights file expected inside the Hub repo.
HF_MODEL_FILENAME = "best_model.h5"

# Try local files first (for development).
# NOTE(review): newer Keras 3 releases may refuse to load a TF SavedModel
# directory through load_model — confirm against the installed TF/Keras version.
local_model_paths = ["saved_model", "best_model.h5", "final_model.h5"]
for path in local_model_paths:
    if not os.path.exists(path):
        continue
    try:
        # compile=False: the app only runs inference, so the training
        # configuration (optimizer/loss) does not need to be restored.
        model = tf.keras.models.load_model(path, compile=False)
    except Exception as e:
        # Best effort: report and try the next candidate path.
        print(f"Failed to load local model from {path}: {e}")
        continue
    print(f"Loaded model from local path: {path}")
    break

# If no local model, try to download the weights file from the Hugging Face Hub.
if model is None:
    # Stays None unless hf_hub_download itself succeeds.
    model_path = None
    try:
        from huggingface_hub import hf_hub_download

        model_path = hf_hub_download(repo_id=HF_MODEL_ID, filename=HF_MODEL_FILENAME)
        model = tf.keras.models.load_model(model_path, compile=False)
        print(f"Loaded model from HF Hub: {HF_MODEL_ID}/{HF_MODEL_FILENAME}")
    except Exception as e:
        print(f"Failed to load model from HF Hub ({HF_MODEL_ID}): {e}")

    # Fallback: snapshot the whole repo and look for the file there. This only
    # helps when the single-file download failed. If the file was downloaded
    # but could not be loaded, the snapshot would provide the same repo file
    # and the load would fail again.
    if model is None and model_path is None:
        try:
            from huggingface_hub import snapshot_download

            repo_dir = snapshot_download(repo_id=HF_MODEL_ID)
            model_file = os.path.join(repo_dir, HF_MODEL_FILENAME)
            if os.path.exists(model_file):
                model = tf.keras.models.load_model(model_file, compile=False)
                print(f"Loaded model from downloaded repo: {model_file}")
            else:
                # Previously this case fell through silently.
                print(f"Downloaded repo {HF_MODEL_ID} does not contain {HF_MODEL_FILENAME}")
        except Exception as e2:
            print(f"Fallback download also failed: {e2}")

if model is None:
    raise RuntimeError(
        "No model found. Ensure 'best_model.h5' exists locally or set HF_MODEL_ID env var to a Hugging Face model repo containing the model."
    )
# (height, width) passed to PIL's Image.resize before inference.
INPUT_SIZE = (224, 224)


def predict_age(image: Image.Image) -> dict:
    """Predict the age (in years) of the face in ``image``.

    The image is converted to RGB, resized to ``INPUT_SIZE``, and run through
    MobileNetV2's ``preprocess_input``. It is then fed to the module-level
    ``model`` as a batch of one.

    Args:
        image: PIL image from the ``gr.Image(type='pil')`` input. Gradio
            passes ``None`` when the user submits without uploading an image.

    Returns:
        dict with ``"predicted_age"`` (model output rounded to 2 decimals) and
        ``"raw_output"`` (the unrounded model output as a float).

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload a face image.")
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image = image.resize(INPUT_SIZE)
    # np.array with an explicit dtype yields a fresh, writable float32 copy.
    arr = preprocess_input(np.array(image, dtype=np.float32))
    arr = np.expand_dims(arr, 0)
    # verbose=0 suppresses the progress bar Keras would print on every request.
    output = model.predict(arr, verbose=0)[0]
    # Assumes the model emits a single value per sample (regression head);
    # float() raises if more than one value remains after squeezing.
    pred = float(np.asarray(output).squeeze())
    return {
        "predicted_age": round(pred, 2),
        "raw_output": pred,
    }
def _predict_for_ui(image: Image.Image):
    """Adapt ``predict_age``'s dict result to the two ``gr.Number`` outputs.

    With a list of output components, ``gr.Interface`` expects the function
    to return one value per component, in order. A returned dict is only
    mapped when its keys are the component objects, so the string-keyed dict
    from ``predict_age`` is unpacked here.
    """
    result = predict_age(image)
    return result["predicted_age"], result["raw_output"]


# Gradio UI: one PIL image in, predicted age and raw model output out.
demo = gr.Interface(
    fn=_predict_for_ui,
    inputs=gr.Image(type='pil', label='Face image (crop to face for best results)'),
    outputs=[
        gr.Number(label='Predicted age (years)'),
        gr.Number(label='Raw model output'),
    ],
    examples=[],
    title='UTKFace Age Estimator',
    description='Upload a cropped face image and the model will predict age in years. For Spaces, set the HF_MODEL_ID environment variable to your Hugging Face model repo if you want the app to download the model (best_model.h5) from the Hub.',
)
if __name__ == '__main__':
    # Listen on all interfaces; the PORT environment variable overrides the
    # default port 7860.
    listen_port = int(os.environ.get('PORT', 7860))
    demo.launch(server_name='0.0.0.0', server_port=listen_port)