# mnist-ann-model / inference.py
# Uploaded by ramanat1968 via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 7a344ab (verified).
from pathlib import Path

import numpy as np
import tensorflow as tf
# Load the trained Keras model once, at import time, so every
# predict_digit() call reuses the same instance.
def _resolve_model_path(filename="mnist_ann_model.keras"):
    """Return the path the model should be loaded from.

    Prefer the file sitting next to this script, so the module works no
    matter which directory it is run or imported from. Otherwise fall back
    to the bare file name, which the OS resolves against the current working
    directory (what the original one-line loader did).
    """
    # __file__ is absent when the code is exec'd without a source file.
    script = globals().get("__file__")
    if script:
        candidate = Path(script).resolve().parent / filename
        if candidate.is_file():
            return str(candidate)
    return filename


model = tf.keras.models.load_model(_resolve_model_path())
def predict_digit(image_array, classifier=None):
    """Return the class index the model predicts for a single image.

    Args:
        image_array: One grayscale image with raw pixel intensities in
            [0, 255]; they are divided by 255 here. The expected shape is
            (28, 28). Any array-like is accepted (NumPy array, nested lists).
        classifier: Optional model to use instead of the module-level
            ``model``. It must provide a Keras-style
            ``predict(batch, verbose=...)`` method returning one row of
            scores per batch item.

    Returns:
        int: Index of the highest-scoring output for the image.
    """
    active_model = model if classifier is None else classifier
    # np.asarray accepts lists as well as arrays; float32 is Keras' default
    # float dtype, so this avoids an extra float64 -> float32 conversion.
    pixels = np.asarray(image_array, dtype=np.float32) / 255.0
    # Add a leading batch axis: (28, 28) -> (1, 28, 28).
    batch = np.expand_dims(pixels, axis=0)
    # verbose=0 suppresses the progress bar predict() otherwise prints per call.
    scores = active_model.predict(batch, verbose=0)
    # Only one image was submitted, so take the arg-max of its (first) row.
    return int(np.argmax(scores[0]))
# Example usage: classify one synthetic image.
if __name__ == "__main__":
    # predict_digit() expects raw 0-255 pixel intensities (it divides by 255),
    # so draw random uint8 pixels. np.random.rand() values in [0, 1) would be
    # rescaled to an almost all-black image.
    rng = np.random.default_rng()
    sample = rng.integers(0, 256, size=(28, 28), dtype=np.uint8)
    print("Predicted digit:", predict_digit(sample))