# Space app by warresnaet
# Revert to Gradio SDK only; remove FastAPI /predict and Dockerfile
# (commit 9b08420, verified)
import os
import pathlib
import numpy as np
import gradio as gr
import tensorflow as tf
from huggingface_hub import hf_hub_download
# --- Runtime configuration (environment variables with sensible defaults) ---

# Point the Hugging Face caches at /tmp so they are writable on Spaces.
os.environ.setdefault("HF_HOME", "/tmp/hfhome")
os.environ.setdefault("HF_HUB_CACHE", "/tmp/hfhome/hub")

# Which model to pull from the Hub, and where to keep the local copy.
HF_REPO_ID = os.getenv("HF_REPO_ID", "warresnaet/masterclass-2025")
HF_MODEL_FILENAME = os.getenv("HF_MODEL_FILENAME", "model.keras")
HF_REVISION = os.getenv("HF_REVISION", "main")
LOCAL_MODEL_DIR = os.getenv("LOCAL_MODEL_DIR", "/tmp/model")

# Class labels, in the order used during training/inference.
CLASS_NAMES = ["Cat", "Dog", "Panda"]
def _ensure_model() -> str:
    """
    Make sure the model file exists on disk, downloading it from the Hub
    on first use.

    Returns:
        Absolute path to the local model file.
    """
    target_dir = pathlib.Path(LOCAL_MODEL_DIR)
    target_dir.mkdir(parents=True, exist_ok=True)

    candidate = target_dir / HF_MODEL_FILENAME
    if candidate.exists():
        # Already cached from a previous run; reuse it.
        return os.path.abspath(str(candidate))

    # Not cached yet: fetch it from the Hugging Face Hub into the local dir.
    fetched = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=HF_MODEL_FILENAME,
        repo_type="model",
        revision=HF_REVISION,
        local_dir=str(target_dir),
    )
    return os.path.abspath(fetched)
def _load_model() -> tf.keras.Model:
    """Download the model file if necessary and deserialize it with Keras."""
    return tf.keras.models.load_model(_ensure_model())
# Load the model once at import time (Space container build/run) so the first
# request does not pay the download/deserialization cost.
MODEL: tf.keras.Model = _load_model()
def predict(image: np.ndarray) -> dict:
    """
    Gradio prediction function.

    Args:
        image: RGB image as a numpy array of shape (H, W, C), or None when
            the user submits without uploading an image.

    Returns:
        Mapping of class label -> predicted probability; empty dict when no
        image was provided (gr.Label renders an empty dict as "no result").
    """
    # Gradio passes None when the image input is empty/cleared; without this
    # guard tf.image.resize(None, ...) raises and the UI shows an error.
    if image is None:
        return {}

    # Resize to the 64x64 input size used during training.
    resized = tf.image.resize(image, (64, 64))

    # Training used raw pixel values (no 0..1 normalization), so only cast
    # to float32 and add the batch dimension.
    batch = np.expand_dims(np.asarray(resized, dtype=np.float32), axis=0)

    # Single-image batch -> take the first (only) row of probabilities.
    probs = MODEL.predict(batch, verbose=0)[0]

    # Flatten to plain Python floats for the gr.Label output.
    probs = np.asarray(probs, dtype=np.float32).tolist()
    return {label: float(p) for label, p in zip(CLASS_NAMES, probs)}
# User-facing blurb shown under the app title.
_DESCRIPTION = (
    "Upload an image of a cat, dog, or panda. The model will predict the class "
    "and return class probabilities."
)

# Wire the prediction function into the Gradio UI.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(label="Upload image"),
    outputs=gr.Label(num_top_classes=3, label="Top-3 predictions"),
    title="Animal Classifier",
    description=_DESCRIPTION,
    examples=None,  # room for sample images later
    api_name="predict",
)
if __name__ == "__main__":
# Running locally (Hugging Face Spaces will call this file automatically)
demo.launch()