# NOTE: the three lines that were here ("IDS75912's picture", "fix", "36dc52d")
# were scrape residue from the Hugging Face commit page, not Python code; they
# are preserved as this comment so the module parses.
import uvicorn
import fastapi
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi import File, UploadFile
import numpy as np
from PIL import Image
from typing import Any, Dict
import os
import pkgutil
from huggingface_hub import hf_hub_download
from huggingface_hub import hf_hub_url
import requests
import tempfile
import shutil
from typing import Any, Dict
import tensorflow as tf
import traceback
import logging
from tensorflow import keras
# FastAPI application for the animal-image classifier.
# Interactive API docs in the browser: http://localhost:8000/docs
app = FastAPI(title="1.3 - AI Model Deployment - HF Hub + FastAPI")
# Allow cross-origin requests from any frontend. This is fine for a demo
# deployment; tighten allow_origins for production use.
# (The duplicate `from fastapi.middleware.cors import CORSMiddleware` that was
# here has been removed — it is already imported at the top of the file.)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Class labels, in the exact order the model was trained on — predictions are
# argmax indices into this list.
ANIMALS = ['Cat', 'Dog', 'Panda']
# 1) download your SavedModel from the Hub into a writable directory (Spaces
# often take HF_MODEL_DIR or default model_dir).
repo_id = "IDS75912/masterclass-2025"
local_model_dir = os.environ.get('HF_MODEL_DIR', './model_dir')
# Ensure the directory exists and is writable. If creating fails, raise a clear error.
try:
    os.makedirs(local_model_dir, exist_ok=True)
except OSError as e:
    # Narrowed from `except Exception` — os.makedirs failures are OSError.
    # Chain with `from e` so the underlying cause stays in the traceback.
    raise RuntimeError(
        f"Cannot create model directory '{local_model_dir}'. Ensure the process has "
        f"write access or set HF_MODEL_DIR to a writable path., Error: {e}"
    ) from e
# Download the model files from the Hub into local_model_dir and load the model.
# Failures are captured in `model_load_error` instead of crashing the whole app
# at import time, so the API can still start and report the problem from its
# endpoints. (The original `try:` was commented out, which made the
# model-is-None fallback in the endpoints unreachable dead code.)
model = None
model_load_error = None
try:
    # Each call fetches one file from the repo into local_model_dir
    # (huggingface_hub caches downloads, so re-runs are cheap).
    hf_hub_download(repo_id, filename="config.json", repo_type="model", local_dir=local_model_dir)
    hf_hub_download(repo_id, filename="metadata.json", repo_type="model", local_dir=local_model_dir)
    hf_hub_download(repo_id, filename="model.weights.h5", repo_type="model", local_dir=local_model_dir)
    # 2) load it from the downloaded directory.
    model = tf.keras.models.load_model(local_model_dir)
    logging.info(f"Model loaded successfully from {local_model_dir}")
except Exception as e:
    # Broad catch is deliberate at this top-level boundary: keep the app
    # serving; the endpoints check `model is None` and surface this message.
    model_load_error = f"{type(e).__name__}: {e}"
    logging.error("Model load failed:\n%s", traceback.format_exc())
@app.post('/upload/image')
async def uploadImage(img: UploadFile = File(...)):
    """Classify an uploaded image as one of ANIMALS.

    Returns the predicted label string (e.g. "Dog"), or a 503 response when
    the model could not be loaded at startup.
    """
    if model is None:
        # Model isn't available — return a helpful error to the caller instead of crashing.
        return fastapi.Response(status_code=503, content=f"Model not loaded: {model_load_error}")
    original_image = Image.open(img.file)  # Decode the uploaded bytes as an image
    # Normalize ANY mode (RGBA, palette 'P', grayscale 'L', CMYK, ...) to
    # 3-channel RGB — the original only handled RGBA, so e.g. a grayscale
    # upload produced a (1, 64, 64) array and crashed model.predict.
    if original_image.mode != 'RGB':
        original_image = original_image.convert('RGB')
    resized_image = original_image.resize((64, 64))  # Model expects 64x64 inputs
    # The model predicts on a batch of images; add a leading batch dimension
    # so the input shape is (1, 64, 64, 3).
    images_to_predict = np.expand_dims(np.array(resized_image), axis=0)
    # Result is a batch of class scores in one-hot-like format, e.g. [[0 1 0]]
    predictions = model.predict(images_to_predict)
    classifications = predictions.argmax(axis=1)  # Index of the highest score, e.g. [1]
    return ANIMALS[classifications.tolist()[0]]  # Map index to label, e.g. "Dog"
@app.get("/")
def read_root() -> Dict[str, Any]:
    """Root endpoint: returns a static liveness greeting."""
    greeting = "Hello from FastAPI in 1.3 - AI Model Deployment - HF Hub + FastAPI"
    return {"message": greeting}
@app.get("/version")
def versions() -> Dict[str, Any]:
    """Return key package versions and whether TensorFlow is available."""
    # The docstring promised TensorFlow info, but only fastapi was reported;
    # include the tensorflow version as a backward-compatible extra key.
    return {
        "fastapi": fastapi.__version__,
        "tensorflow": tf.__version__,
    }
@app.get("/predict")
def predict_stub() -> Dict[str, Any]:
    """Stub prediction endpoint — no real inference is performed here."""
    if model is not None:
        return {"prediction": "stub, we're not doing a real prediction"}
    # Model never loaded; surface the captured load error to the caller.
    return {"prediction": "model not loaded", "error": model_load_error}
if __name__ == "__main__":
    # Local development entry point.
    # Equivalent CLI: conda run -n gradio uvicorn main:app --reload
    import uvicorn

    # NOTE: Hugging Face Spaces conventionally serve on port 7860; 8000 is
    # used here for local testing.
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)