"""Gradio app that classifies an uploaded image as Cat, Dog or Panda using a Keras model from the HF Hub."""

import os

# Configuration via environment variables (with sensible defaults).
# Ensure HF caches are writable in Spaces by default.
# NOTE: these must be set BEFORE importing huggingface_hub (and gradio, which
# imports it): huggingface_hub resolves HF_HOME / HF_HUB_CACHE into module-level
# constants at import time, so setting them afterwards has no effect.
os.environ.setdefault("HF_HOME", "/tmp/hfhome")
os.environ.setdefault("HF_HUB_CACHE", "/tmp/hfhome/hub")

import pathlib  # noqa: E402

import gradio as gr  # noqa: E402
import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402
from huggingface_hub import hf_hub_download  # noqa: E402

HF_REPO_ID = os.environ.get("HF_REPO_ID", "warresnaet/masterclass-2025")
HF_MODEL_FILENAME = os.environ.get("HF_MODEL_FILENAME", "model.keras")
HF_REVISION = os.environ.get("HF_REVISION", "main")
LOCAL_MODEL_DIR = os.environ.get("LOCAL_MODEL_DIR", "/tmp/model")

# Labels used during training/inference; order must match the model's output units.
CLASS_NAMES = ["Cat", "Dog", "Panda"]

# Training input size as (height, width), the order tf.image.resize expects.
IMAGE_SIZE = (64, 64)


def _ensure_model() -> str:
    """
    Ensure the model file exists locally, downloading it from the Hub if needed.

    If ``LOCAL_MODEL_DIR/HF_MODEL_FILENAME`` already exists it is used as-is and
    the Hub is not contacted, so startup works offline. This also means that
    changing ``HF_REVISION`` does not refresh an existing local copy; delete the
    file to force a re-download.

    Returns:
        Absolute path to the model file.
    """
    os.makedirs(LOCAL_MODEL_DIR, exist_ok=True)
    local_model_path = os.path.join(LOCAL_MODEL_DIR, HF_MODEL_FILENAME)
    if os.path.isfile(local_model_path):
        return os.path.abspath(local_model_path)

    downloaded_path = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=HF_MODEL_FILENAME,
        repo_type="model",
        revision=HF_REVISION,
        local_dir=LOCAL_MODEL_DIR,
        # Pass the cache location explicitly so the writable default above is
        # honored even if the import order at the top of the file ever changes.
        cache_dir=os.environ["HF_HUB_CACHE"],
    )
    return os.path.abspath(downloaded_path)


def _load_model() -> tf.keras.Model:
    """
    Load the Keras model (model.keras) from its local path.

    Downloads it first via ``_ensure_model`` when no local copy exists.
    """
    model_path = _ensure_model()
    return tf.keras.models.load_model(model_path)


# Load the model at startup (Space container build/run).
MODEL: tf.keras.Model = _load_model()


def predict(image: np.ndarray) -> dict:
    """
    Gradio prediction function.

    - ``image`` is a numpy array (H, W, C) in RGB. The ``gr.Image`` input below
      is configured with ``type="numpy"`` and ``image_mode="RGB"`` to guarantee this.
    - Resize to ``IMAGE_SIZE`` (64, 64).
    - Run inference.
    - Return a ``{label: probability}`` dict.

    Raises:
        gr.Error: if no image was provided (Gradio passes ``None`` for an empty input).
        ValueError: if the number of model outputs differs from ``len(CLASS_NAMES)``.
    """
    if image is None:
        # Show a readable message in the UI instead of a TensorFlow shape error.
        raise gr.Error("Please upload an image first.")

    # Resize to the training input size.
    resized = tf.image.resize(image, IMAGE_SIZE)

    # Training used raw pixel values (no normalization to 0..1).
    batch = np.expand_dims(np.asarray(resized, dtype=np.float32), axis=0)

    # Call the model directly: Model.predict() sets up a batched input pipeline
    # on every call and is documented as slower for single small inputs.
    outputs = MODEL(batch, training=False)

    # Take the single row of the batch and flatten it to one score per class.
    # NOTE(review): assumes the final layer emits probabilities (e.g. softmax);
    # if it emits logits, apply tf.nn.softmax here. TODO confirm.
    probs = np.asarray(outputs, dtype=np.float32)[0].ravel()

    # zip() would silently drop entries on a mismatch; surface misconfiguration instead.
    if probs.shape[0] != len(CLASS_NAMES):
        raise ValueError(
            f"Model returned {probs.shape[0]} scores but "
            f"{len(CLASS_NAMES)} class names are configured: {CLASS_NAMES}"
        )

    return {label: float(p) for label, p in zip(CLASS_NAMES, probs)}


# Build the Gradio UI.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy", image_mode="RGB", label="Upload image"),
    outputs=gr.Label(num_top_classes=3, label="Top-3 predictions"),
    title="Animal Classifier",
    description=(
        "Upload an image of a cat, dog, or panda. The model will predict the class "
        "and return class probabilities."
    ),
    examples=None,  # You can set sample images here if you want
    api_name="predict",
)

if __name__ == "__main__":
    # Running locally (Hugging Face Spaces will call this file automatically).
    demo.launch()