# yaya zhang: fix predict issue (e79a28e)
import os, io, json
import numpy as np
from PIL import Image
import tensorflow as tf
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
# --- set HF env BEFORE importing huggingface_hub ---
# Each entry is only a default: a value already present in the process
# environment is left untouched (setdefault semantics).
for _env_name, _env_default in (
    ("HF_HUB_ENABLE_XET", "0"),
    ("HF_HUB_DISABLE_TELEMETRY", "1"),
    ("HF_HOME", "/tmp/hf"),
    ("HF_HUB_CACHE", "/tmp/hf/cache"),
    ("XDG_CACHE_HOME", "/tmp"),
):
    os.environ.setdefault(_env_name, _env_default)
from huggingface_hub import hf_hub_download
import gradio as gr
# ----- Settings -----
# Hub repository that holds config.json and model.weights.h5.
REPO_ID = "woyaya114/mlop-deployment"
# Optional access token; None when the variable is unset.
HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# Spatial size every image is resized to before inference.
INPUT_SIZE = (64, 64)
# Labels zipped against the model output, index by index.
# NOTE(review): presumably the same class order used at training time - confirm.
CLASS_NAMES = ["cat", "dog", "panda"]

# Directory the model files are downloaded into; created eagerly at import.
LOCAL_DIR = "/tmp/model"
os.makedirs(LOCAL_DIR, exist_ok=True)
# ----- Download model files (no symlinks) -----
# Both files come from the same repo with identical options; only the file
# name differs. The generator is consumed in order, so config.json is fetched
# before model.weights.h5, exactly as two sequential calls would do.
# NOTE(review): local_dir_use_symlinks appears to be deprecated in newer
# huggingface_hub releases - confirm against the pinned version.
config_path, weights_path = (
    hf_hub_download(
        repo_id=REPO_ID,
        filename=remote_name,
        repo_type="model",
        local_dir=LOCAL_DIR,
        local_dir_use_symlinks=False,
        token=HF_TOKEN,
    )
    for remote_name in ("config.json", "model.weights.h5")
)
# ----- Rebuild Keras model -----
# Parse the downloaded architecture config. The encoding is pinned to UTF-8
# (the JSON interchange encoding) so the result does not depend on the
# process locale.
with open(config_path, "r", encoding="utf-8") as f:
    cfg = json.load(f)
def build_model_from_config(cfg_dict):
    """Instantiate an (unweighted) Keras model from a parsed config dict.

    The Keras ``serialization_lib`` deserializer is tried first. If importing
    or running it fails for any reason, the function falls back to
    ``tf.keras`` ``from_config`` using the nested ``"config"`` entry, picking
    ``Sequential`` when ``"class_name"`` says so and ``Model`` otherwise.

    Args:
        cfg_dict: Mapping with ``"class_name"`` and ``"config"`` keys, as
            loaded from the model's config.json.

    Returns:
        The model object produced by whichever path succeeded.

    Raises:
        Exception: Whatever the fallback ``from_config`` call raises; the
            first failure stays attached as the exception context.
    """
    # Function-scope import: the module does not import logging at top level.
    import logging

    try:
        from keras.saving import serialization_lib  # Keras 3

        return serialization_lib.deserialize_keras_object(cfg_dict)
    except Exception as exc:
        # Deliberate best-effort fallback, but no longer a silent one: record
        # why the first path was abandoned so a bad config is not masked.
        logging.getLogger(__name__).warning(
            "Keras deserialization failed (%r); falling back to tf.keras from_config",
            exc,
        )
        inner = cfg_dict.get("config", {})
        if cfg_dict.get("class_name") == "Sequential":
            return tf.keras.Sequential.from_config(inner)
        return tf.keras.Model.from_config(inner)
# Build the architecture from the parsed config, then restore the trained
# weights from the downloaded file. Both run once, at import time.
model = build_model_from_config(cfg)
model.load_weights(weights_path)
# -------- FastAPI app --------
# The route decorators further down register their handlers on this instance.
app = FastAPI(title="Animal Classifier (Keras) - FastAPI + Gradio")
# Health-check route: always answers with a constant payload and never
# touches the model.
@app.get("/healthz")
def healthz():
    payload = {"status": "ok"}
    return payload
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """REST endpoint (multipart form-data: file=@image.jpg)

    Decodes the upload as RGB, resizes it to INPUT_SIZE, divides pixel values
    by 255 and runs the model on a batch of one. Responds with ``probs``
    (class name -> score) and ``top3`` (up to three ``[name, score]`` pairs,
    highest score first). An upload that cannot be decoded as an image yields
    a 400 response with a ``detail`` message.
    """
    img_bytes = await file.read()
    try:
        # PIL signals non-image or empty payloads with UnidentifiedImageError
        # (an OSError subclass) and corrupt data with OSError/ValueError.
        # These are client errors, not server faults.
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB").resize(INPUT_SIZE)
    except (OSError, ValueError):
        return JSONResponse(
            status_code=400,
            content={"detail": "Uploaded file is not a valid image."},
        )
    x = np.array(img, dtype=np.float32) / 255.0
    x = np.expand_dims(x, 0)  # prepend the batch axis
    probs = model.predict(x, verbose=0)[0].tolist()
    result = {name: float(p) for name, p in zip(CLASS_NAMES, probs)}
    top3 = sorted(result.items(), key=lambda kv: kv[1], reverse=True)[:3]
    return JSONResponse({"probs": result, "top3": top3})
# -------- Gradio UI (mounted into FastAPI) --------
def predict_gradio(image):
    """Gradio handler: classify one image and return class name -> score.

    Args:
        image: Image array supplied by the Gradio image component, or None
            when the form is submitted without an image.

    Returns:
        Dict mapping each entry of CLASS_NAMES to a float score.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Without this guard the tensor conversion below fails with an
        # opaque error; give the user an actionable message instead.
        raise gr.Error("Please upload an image first.")
    # NOTE(review): scaling happens before the TensorFlow resize here, while
    # the /predict route resizes with PIL before scaling; the two paths may
    # produce slightly different inputs - confirm which matches training.
    x = tf.convert_to_tensor(image, dtype=tf.float32) / 255.0
    x = tf.image.resize(x, INPUT_SIZE)
    x = tf.expand_dims(x, 0)  # prepend the batch axis
    probs = model.predict(x, verbose=0)[0]
    return {name: float(p) for name, p in zip(CLASS_NAMES, probs)}
# Browser UI: one image input wired to predict_gradio, with a label output
# showing at most three classes (fewer if CLASS_NAMES is shorter).
demo = gr.Interface(
    title="Animal Classifier (Keras)",
    fn=predict_gradio,
    inputs=gr.Image(label="Upload image", type="numpy"),
    outputs=gr.Label(
        label="Prediction",
        num_top_classes=min(len(CLASS_NAMES), 3),
    ),
)
# Mount Gradio at "/" (root). Swagger stays at /docs, REST at /predict.
app = gr.mount_gradio_app(app, demo, path="/")