# Hugging Face Spaces app (Gradio): UTKFace age estimator.
# (Removed non-code page residue that would break the Python parser.)
import os

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
# Resolve the Keras age-regression model at import time, trying three
# sources in order: local checkpoint files, a single .h5 download from the
# Hugging Face Hub, and finally a full repo snapshot. Failures at each
# stage are logged and the next strategy is attempted; only if all fail
# does the module raise.
model = None

# Try local files first (for development). "saved_model" is expected to be
# a TF SavedModel directory; the .h5 entries are Keras HDF5 checkpoints.
local_model_paths = ["saved_model", "best_model.h5", "final_model.h5"]
for path in local_model_paths:
    if os.path.exists(path):
        try:
            # compile=False: inference only, no optimizer/loss needed.
            model = tf.keras.models.load_model(path, compile=False)
            print(f"Loaded model from local path: {path}")
            break
        except Exception as e:
            # Best-effort: log and fall through to the next candidate path.
            print(f"Failed to load local model from {path}: {e}")

# If no local model, try to download from Hugging Face Hub.
if model is None:
    # HF_MODEL_ID env var overrides the default model repo.
    HF_MODEL_ID = os.environ.get("HF_MODEL_ID", "Sharris/age_detection_regression")
    try:
        from huggingface_hub import hf_hub_download
        # Try to download just the .h5 model file (cheaper than a snapshot).
        model_path = hf_hub_download(repo_id=HF_MODEL_ID, filename="best_model.h5")
        model = tf.keras.models.load_model(model_path, compile=False)
        print(f"Loaded model from HF Hub: {HF_MODEL_ID}/best_model.h5")
    except Exception as e:
        print(f"Failed to load model from HF Hub ({HF_MODEL_ID}): {e}")
        # Fallback: try to download the entire repo and load from there.
        # NOTE(review): if the snapshot succeeds but lacks best_model.h5,
        # this silently leaves model=None and the RuntimeError below fires.
        try:
            from huggingface_hub import snapshot_download
            repo_dir = snapshot_download(repo_id=HF_MODEL_ID)
            model_file = os.path.join(repo_dir, "best_model.h5")
            if os.path.exists(model_file):
                model = tf.keras.models.load_model(model_file, compile=False)
                print(f"Loaded model from downloaded repo: {model_file}")
        except Exception as e2:
            print(f"Fallback download also failed: {e2}")

# All strategies exhausted: fail fast with setup instructions.
if model is None:
    raise RuntimeError(
        "No model found. Ensure 'best_model.h5' exists locally or set HF_MODEL_ID env var to a Hugging Face model repo containing the model."
    )
# Input resolution expected by the MobileNetV2-based regressor.
INPUT_SIZE = (224, 224)


def predict_age(image: Image.Image):
    """Predict an age in years from a face image.

    Args:
        image: PIL image of a (preferably cropped) face; any mode is
            accepted and converted to RGB.

    Returns:
        A ``(predicted_age, raw_output)`` tuple of floats — the age rounded
        to 2 decimals and the unrounded model output.

    BUG FIX: the Gradio Interface declares TWO output components
    (two gr.Number fields), so Gradio expects a tuple/list mapped
    positionally onto them. The previous implementation returned a single
    dict, which Gradio cannot split across two Number outputs.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image = image.resize(INPUT_SIZE)
    arr = np.array(image).astype(np.float32)
    # MobileNetV2 preprocessing (per the imported preprocess_input).
    arr = preprocess_input(arr)
    arr = np.expand_dims(arr, 0)  # add batch dimension -> (1, H, W, 3)
    pred = model.predict(arr)[0]
    # Collapse any trailing singleton dims (e.g. shape (1,)) to a scalar.
    pred = float(np.asarray(pred).squeeze())
    return round(pred, 2), float(pred)
# Gradio UI: one image in, two numeric readouts out.
_age_outputs = [
    gr.Number(label='Predicted age (years)'),
    gr.Number(label='Raw model output'),
]

demo = gr.Interface(
    fn=predict_age,
    inputs=gr.Image(type='pil', label='Face image (crop to face for best results)'),
    outputs=_age_outputs,
    examples=[],
    title='UTKFace Age Estimator',
    description='Upload a cropped face image and the model will predict age in years. For Spaces, set the HF_MODEL_ID environment variable to your Hugging Face model repo if you want the app to download a SavedModel from the Hub.',
)
if __name__ == '__main__':
    # Bind to all interfaces so the container/Space port mapping works;
    # PORT env var overrides the default Gradio port.
    port = int(os.environ.get('PORT', 7860))
    demo.launch(server_name='0.0.0.0', server_port=port)