# Hugging Face Space: Animal Classifier — FastAPI + Gradio demo app
# (removed HF Spaces page-scrape residue: "Spaces: Sleeping Sleeping")
import os, io, json
import numpy as np
from PIL import Image
import tensorflow as tf
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse

# --- set HF env BEFORE importing huggingface_hub ---
# huggingface_hub reads these at import time, so they must be exported first.
# Paths point under /tmp — presumably the only writable location in the
# Spaces container; TODO confirm for this deployment target.
os.environ.setdefault("HF_HUB_ENABLE_XET", "0")          # disable Xet download backend
os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")   # opt out of usage telemetry
os.environ.setdefault("HF_HOME", "/tmp/hf")              # writable HF home directory
os.environ.setdefault("HF_HUB_CACHE", "/tmp/hf/cache")   # writable download cache
os.environ.setdefault("XDG_CACHE_HOME", "/tmp")          # writable generic cache root
from huggingface_hub import hf_hub_download
import gradio as gr
# ----- Settings -----
REPO_ID = "woyaya114/mlop-deployment"   # HF model repo holding config.json + weights
INPUT_SIZE = (64, 64)                   # spatial size fed to the model (square, so PIL/TF order agree)
CLASS_NAMES = ["cat", "dog", "panda"]   # output index -> label; presumably matches training order — verify
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")  # None => anonymous download (public repo)
LOCAL_DIR = "/tmp/model"                # writable target directory for downloaded files
os.makedirs(LOCAL_DIR, exist_ok=True)
# ----- Download model files (no symlinks) -----
# Real file copies (not cache symlinks) are requested so the paths under
# LOCAL_DIR are directly readable by Keras.
# NOTE(review): local_dir_use_symlinks is deprecated and ignored by recent
# huggingface_hub releases — confirm the pinned hub version still honors it.
config_path = hf_hub_download(
    repo_id=REPO_ID, filename="config.json", repo_type="model",
    local_dir=LOCAL_DIR, local_dir_use_symlinks=False, token=HF_TOKEN
)
weights_path = hf_hub_download(
    repo_id=REPO_ID, filename="model.weights.h5", repo_type="model",
    local_dir=LOCAL_DIR, local_dir_use_symlinks=False, token=HF_TOKEN
)
# ----- Rebuild Keras model -----
# config.json holds the serialized Keras architecture. Decode it explicitly
# as UTF-8 so parsing does not depend on the container's default locale.
with open(config_path, "r", encoding="utf-8") as f:
    cfg = json.load(f)
def build_model_from_config(cfg_dict):
    """Reconstruct a Keras model from its serialized config dict.

    First attempts the Keras 3 serialization API; if that import or the
    deserialization fails for any reason, falls back to the classic
    ``from_config`` path, picking ``Sequential`` vs. functional ``Model``
    from the config's ``class_name``.
    """
    try:
        from keras.saving import serialization_lib  # Keras 3 (private module)
        return serialization_lib.deserialize_keras_object(cfg_dict)
    except Exception:
        # Legacy path: the layer list lives under the "config" key.
        layer_cfg = cfg_dict.get("config", {})
        is_sequential = cfg_dict.get("class_name") == "Sequential"
        builder = tf.keras.Sequential if is_sequential else tf.keras.Model
        return builder.from_config(layer_cfg)
# Instantiate the architecture from the downloaded config, then restore the
# trained weights from the .h5 file.
model = build_model_from_config(cfg)
model.load_weights(weights_path)
# -------- FastAPI app --------
app = FastAPI(title="Animal Classifier (Keras) - FastAPI + Gradio")

@app.get("/healthz")
def healthz():
    """Liveness probe: returns a static OK payload.

    NOTE(review): the route decorator was missing, so /healthz was never
    registered with the app; restored so the endpoint is reachable.
    """
    return {"status": "ok"}
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """REST endpoint (multipart form-data: file=@image.jpg).

    Decodes the upload as RGB, resizes to INPUT_SIZE, scales pixels to
    [0, 1], and returns per-class probabilities plus the top-3
    (name, prob) pairs sorted by descending probability.

    NOTE(review): the route decorator was missing even though the trailing
    file comment says "REST at /predict"; restored here so the endpoint is
    actually registered.
    """
    img_bytes = await file.read()
    # PIL sniffs the image format; convert("RGB") normalizes grayscale/palette
    # inputs to 3 channels so the array shape matches the model input.
    img = Image.open(io.BytesIO(img_bytes)).convert("RGB").resize(INPUT_SIZE)
    x = np.expand_dims(np.array(img, dtype=np.float32) / 255.0, 0)  # (1, H, W, 3)
    probs = model.predict(x, verbose=0)[0].tolist()
    result = {name: float(p) for name, p in zip(CLASS_NAMES, probs)}
    top3 = sorted(result.items(), key=lambda kv: kv[1], reverse=True)[:3]
    return JSONResponse({"probs": result, "top3": top3})
# -------- Gradio UI (mounted into FastAPI) --------
def predict_gradio(image):
    """Gradio handler: numpy HxWx3 image -> {class_name: probability}.

    Returns an empty dict when the image is ``None`` (Gradio passes None
    when the input is cleared / nothing is uploaded) instead of crashing
    the callback in ``tf.convert_to_tensor``.
    """
    if image is None:
        return {}
    x = tf.convert_to_tensor(image, dtype=tf.float32) / 255.0  # scale to [0, 1]
    x = tf.image.resize(x, INPUT_SIZE)
    x = tf.expand_dims(x, 0)  # add batch dimension
    probs = model.predict(x, verbose=0)[0]
    return {name: float(p) for name, p in zip(CLASS_NAMES, probs)}
# Build the Gradio front-end around the same model used by the REST path.
demo = gr.Interface(
    fn=predict_gradio,
    inputs=gr.Image(type="numpy", label="Upload image"),
    outputs=gr.Label(num_top_classes=min(3, len(CLASS_NAMES)), label="Prediction"),
    title="Animal Classifier (Keras)"
)
# Mount Gradio at "/" (root). Swagger stays at /docs, REST at /predict.
# mount_gradio_app returns the FastAPI app with the Gradio routes attached;
# rebinding keeps `app` as the ASGI entry point the server imports.
app = gr.mount_gradio_app(app, demo, path="/")