Sharris commited on
Commit
bb9e8e0
·
verified ·
1 Parent(s): ad9496b

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +86 -0
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from PIL import Image
4
+ import tensorflow as tf
5
+ from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
6
+ import gradio as gr
7
+
8
+ # Try to load model from various sources
9
+ model = None
10
+
11
+ # Try local files first (for development)
12
+ local_model_paths = ["saved_model", "best_model.h5", "final_model.h5"]
13
+ for path in local_model_paths:
14
+ if os.path.exists(path):
15
+ try:
16
+ model = tf.keras.models.load_model(path, compile=False)
17
+ print(f"Loaded model from local path: {path}")
18
+ break
19
+ except Exception as e:
20
+ print(f"Failed to load local model from {path}: {e}")
21
+
22
+ # If no local model, try to download from Hugging Face Hub
23
+ if model is None:
24
+ HF_MODEL_ID = os.environ.get("HF_MODEL_ID", "Sharris/age_detection_regression")
25
+ try:
26
+ from huggingface_hub import hf_hub_download
27
+ # Try to download the .h5 model file
28
+ model_path = hf_hub_download(repo_id=HF_MODEL_ID, filename="best_model.h5")
29
+ model = tf.keras.models.load_model(model_path, compile=False)
30
+ print(f"Loaded model from HF Hub: {HF_MODEL_ID}/best_model.h5")
31
+ except Exception as e:
32
+ print(f"Failed to load model from HF Hub ({HF_MODEL_ID}): {e}")
33
+ # Fallback: try to download entire repo and load from there
34
+ try:
35
+ from huggingface_hub import snapshot_download
36
+ repo_dir = snapshot_download(repo_id=HF_MODEL_ID)
37
+ model_file = os.path.join(repo_dir, "best_model.h5")
38
+ if os.path.exists(model_file):
39
+ model = tf.keras.models.load_model(model_file, compile=False)
40
+ print(f"Loaded model from downloaded repo: {model_file}")
41
+ except Exception as e2:
42
+ print(f"Fallback download also failed: {e2}")
43
+
44
+ if model is None:
45
+ raise RuntimeError(
46
+ "No model found. Ensure 'best_model.h5' exists locally or set HF_MODEL_ID env var to a Hugging Face model repo containing the model."
47
+ )
48
+
49
+ INPUT_SIZE = (224, 224)
50
+
51
+
52
+ def predict_age(image: Image.Image):
53
+ if image.mode != 'RGB':
54
+ image = image.convert('RGB')
55
+ image = image.resize(INPUT_SIZE)
56
+ arr = np.array(image).astype(np.float32)
57
+ arr = preprocess_input(arr)
58
+ arr = np.expand_dims(arr, 0)
59
+
60
+ pred = model.predict(arr)[0]
61
+ # Ensure scalar
62
+ if hasattr(pred, '__len__'):
63
+ pred = float(np.asarray(pred).squeeze())
64
+ else:
65
+ pred = float(pred)
66
+
67
+ return {
68
+ "predicted_age": round(pred, 2),
69
+ "raw_output": float(pred)
70
+ }
71
+
72
+
73
+ demo = gr.Interface(
74
+ fn=predict_age,
75
+ inputs=gr.Image(type='pil', label='Face image (crop to face for best results)'),
76
+ outputs=[
77
+ gr.Number(label='Predicted age (years)'),
78
+ gr.Number(label='Raw model output')
79
+ ],
80
+ examples=[],
81
+ title='UTKFace Age Estimator',
82
+ description='Upload a cropped face image and the model will predict age in years. For Spaces, set the HF_MODEL_ID environment variable to your Hugging Face model repo if you want the app to download a SavedModel from the Hub.'
83
+ )
84
+
85
+ if __name__ == '__main__':
86
+ demo.launch(server_name='0.0.0.0', server_port=int(os.environ.get('PORT', 7860)))