| | import os |
| | import pathlib |
| | import numpy as np |
| | import gradio as gr |
| | import tensorflow as tf |
| | from huggingface_hub import hf_hub_download |
| |
|
| | |
| | |
# Default the Hugging Face cache locations to /tmp unless the environment
# already provides them (setdefault never overwrites an existing value).
# NOTE(review): huggingface_hub and gradio are imported above this point; if
# they read these variables at import time, these defaults will not reach
# them. Consider moving this before the third-party imports — confirm.
for _hf_var, _hf_default in (
    ("HF_HOME", "/tmp/hfhome"),
    ("HF_HUB_CACHE", "/tmp/hfhome/hub"),
):
    os.environ.setdefault(_hf_var, _hf_default)
del _hf_var, _hf_default
| |
|
# Where the model comes from and where it is stored locally. Each setting can
# be overridden through an environment variable of the same name. An override
# set to the empty string is treated as unset: passing "" through would give
# an empty repo id / filename / revision, or an empty directory that
# os.makedirs cannot create.
HF_REPO_ID = os.environ.get("HF_REPO_ID") or "warresnaet/masterclass-2025"
HF_MODEL_FILENAME = os.environ.get("HF_MODEL_FILENAME") or "model.keras"
HF_REVISION = os.environ.get("HF_REVISION") or "main"
LOCAL_MODEL_DIR = os.environ.get("LOCAL_MODEL_DIR") or "/tmp/model"
| |
|
| | |
# Human-readable label for each model output, in output-index order:
# predict() pairs the i-th score with the i-th name.
# NOTE(review): this order must match the class order used at training
# time — confirm against the training pipeline.
CLASS_NAMES = [
    "Cat",
    "Dog",
    "Panda",
]
| |
|
| |
|
def _ensure_model() -> str:
    """
    Make sure the model file is present under LOCAL_MODEL_DIR.

    If ``LOCAL_MODEL_DIR/HF_MODEL_FILENAME`` already exists it is reused as-is
    (no revision check). Otherwise the file is fetched from the Hugging Face
    Hub model repo HF_REPO_ID at HF_REVISION into LOCAL_MODEL_DIR.

    Returns:
        Absolute path to the model file.
    """
    os.makedirs(LOCAL_MODEL_DIR, exist_ok=True)
    local_model_path = os.path.join(LOCAL_MODEL_DIR, HF_MODEL_FILENAME)

    # Reuse a previously downloaded (or manually placed) copy.
    if os.path.exists(local_model_path):
        return os.path.abspath(local_model_path)

    downloaded_path = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=HF_MODEL_FILENAME,
        repo_type="model",
        revision=HF_REVISION,
        local_dir=LOCAL_MODEL_DIR,
        # huggingface_hub captures its default cache directory when it is
        # imported, and in this module that import happens before
        # HF_HUB_CACHE is defaulted. Pass the value explicitly so the
        # intended (writable) cache location is honoured.
        cache_dir=os.environ.get("HF_HUB_CACHE"),
    )
    return os.path.abspath(downloaded_path)
| |
|
| |
|
def _load_model() -> tf.keras.Model:
    """
    Build the in-memory Keras model.

    Resolves the model file through ``_ensure_model`` (which fetches it from
    the Hub when no local copy exists) and deserializes it with
    ``tf.keras.models.load_model``.
    """
    return tf.keras.models.load_model(_ensure_model())
| |
|
| |
|
| | |
# Loaded once when the module is imported (downloading the file first if it is
# missing locally, via _load_model), so every call to predict() reuses this
# single instance.
MODEL: tf.keras.Model
MODEL = _load_model()
| |
|
| |
|
def predict(image: np.ndarray) -> dict:
    """
    Gradio prediction function: classify one uploaded image with MODEL.

    Args:
        image: Array from the ``gr.Image`` input, or ``None`` when the user
            submits without uploading anything. Presumably (H, W, C) RGB as
            delivered by Gradio — TODO confirm for non-default image modes.

    Returns:
        A ``{label: probability}`` dict pairing each name in CLASS_NAMES with
        the model's score at the same output index, as ``gr.Label`` expects.

    Raises:
        gr.Error: If no image was provided.
        ValueError: If the model output is not a vector with one score per
            entry of CLASS_NAMES.
    """
    if image is None:
        # Gradio passes None for an empty input. Give the user a readable
        # message instead of an opaque TensorFlow error from resize.
        raise gr.Error("Please upload an image first.")

    # Resize to the 64x64 spatial size the model is fed with.
    resized = tf.image.resize(image, (64, 64))
    # NOTE(review): pixel values are not rescaled here; this assumes the model
    # performs any normalization itself — confirm.
    batch = np.expand_dims(np.asarray(resized, dtype=np.float32), axis=0)

    # Single-image batch: take the first (only) row of predictions.
    probs = np.asarray(MODEL.predict(batch, verbose=0)[0], dtype=np.float32)

    # zip() would silently drop or mislabel scores if the model's output size
    # differed from CLASS_NAMES, so check the shape explicitly.
    if probs.shape != (len(CLASS_NAMES),):
        raise ValueError(
            f"Model returned scores of shape {probs.shape}, expected "
            f"({len(CLASS_NAMES)},) to match CLASS_NAMES."
        )

    return {label: float(p) for label, p in zip(CLASS_NAMES, probs)}
| |
|
| |
|
| | |
# Gradio UI: one image input mapped through predict() to a label widget that
# shows the three highest-scoring classes. API clients reach this endpoint
# under the name "predict".
demo = gr.Interface(
    predict,
    gr.Image(label="Upload image"),
    gr.Label(num_top_classes=3, label="Top-3 predictions"),
    examples=None,
    title="Animal Classifier",
    description=(
        "Upload an image of a cat, dog, or panda. "
        "The model will predict the class and return class probabilities."
    ),
    api_name="predict",
)

if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    demo.launch()
| |
|