# Author: jamesywu5
# Commit fb595e0: "updated and imported model from hf hub" (757 Bytes)
import os
import tensorflow as tf
from tensorflow import keras
import gradio as gr
import numpy as np
import cv2
from huggingface_hub import hf_hub_download
# Download the trained classifier file from the Hugging Face Hub;
# hf_hub_download hands back a local file path that Keras can load directly.
model_path = hf_hub_download(
    repo_id="jamesywu5/cifar10_model",
    filename="image_classifier.keras",
)
model = tf.keras.models.load_model(model_path)

# Display labels; position i names the model's i-th output score (CIFAR-10 order).
class_names = [
    'Airplane', 'Car', 'Bird', 'Cat', 'Deer',
    'Dog', 'Frog', 'Horse', 'Ship', 'Truck',
]
def predict_fn(img):
    """Classify an uploaded image into the labels in ``class_names``.

    Args:
        img: Image delivered by ``gr.Image()`` as a NumPy array of shape
            (H, W) or (H, W, C), or ``None`` when the user submits without
            uploading anything.

    Returns:
        Dict mapping each name in ``class_names`` to the model's score for
        that class, as plain Python floats (the shape ``gr.Label`` displays).

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of an OpenCV traceback.
    """
    if img is None:
        raise gr.Error("Please upload an image to classify.")

    target_size = (32, 32)  # (width, height) the model is fed, as in the original
    height, width = img.shape[:2]
    # INTER_AREA averages source pixels and avoids aliasing when shrinking a
    # photo down to 32x32; when enlarging it behaves like nearest-neighbour,
    # so keep bilinear for inputs already at or below the target size.
    interpolation = (
        cv2.INTER_AREA
        if width > target_size[0] or height > target_size[1]
        else cv2.INTER_LINEAR
    )
    img = cv2.resize(img, target_size, interpolation=interpolation)

    # Normalise to 3 channels: grayscale -> replicated RGB, RGBA -> drop alpha.
    # gr.Image() defaults to RGB, but this keeps the model input valid if the
    # component's image_mode is changed.
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    elif img.shape[-1] == 4:
        img = img[..., :3]

    # Scale to [0, 1] (same as the original /255.0) and add a batch axis.
    batch = np.expand_dims(img.astype(np.float32) / 255.0, axis=0)
    # verbose=0: don't print a Keras progress bar on every web request.
    preds = model.predict(batch, verbose=0)[0]
    # NOTE(review): assumes the model ends in softmax so these are
    # probabilities; if it emits logits, gr.Label will show raw scores.
    return {name: float(score) for name, score in zip(class_names, preds)}
# Build the web UI: an image upload box feeding predict_fn, with the returned
# {label: score} dict rendered by a Label component.
image_input = gr.Image()
label_output = gr.Label()
interface = gr.Interface(
    fn=predict_fn,
    inputs=image_input,
    outputs=label_output,
)

# NOTE(review): share=True asks Gradio for a public share link; confirm that
# is intended when running outside a hosted Space.
interface.launch(share=True, inline=False)