# app.py: MNIST digit-recognition Gradio app.
# Copied from a Hugging Face file view: author papasega, last commit
# "Update app.py" (38b45ca, verified), file size 1.68 kB.
import tensorflow as tf
import gradio as gr
import numpy as np
import cv2
# Load the trained CNN once, at import time, so every request reuses it.
_MODEL_PATH = "TP_MNIST_CNN_model.h5"
model = tf.keras.models.load_model(_MODEL_PATH)
def _to_grayscale_on_white(image):
    """Flatten a sketch image to a 2-D float32 grayscale array.

    Accepts HxW, HxWx1, HxWx3 (RGB) or HxWx4 (RGBA) arrays. Any alpha channel
    is composited over a white background first. A Gradio 4 Sketchpad
    composite can have a transparent canvas (depending on version), and
    simply dropping alpha would turn those transparent pixels (presumably
    RGB 0,0,0) black, making black strokes indistinguishable from the
    background.
    """
    # NOTE(review): assumes 8-bit channel values (0-255), as Gradio's numpy
    # images normally are; confirm if another image source is wired in.
    image = np.asarray(image, dtype=np.float32)
    if image.ndim == 2:
        return image
    if image.shape[-1] == 1:
        return image[..., 0]
    if image.shape[-1] == 4:
        alpha = image[..., 3:4] / 255.0
        rgb = image[..., :3] * alpha + 255.0 * (1.0 - alpha)
    else:
        rgb = image[..., :3]
    # ITU-R BT.601 luma weights: the same ones OpenCV's RGB2GRAY uses.
    return rgb @ np.array([0.299, 0.587, 0.114], dtype=np.float32)


def preprocess_and_predict(input_data):
    """Classify a hand-drawn digit from the Gradio sketchpad.

    Args:
        input_data: Either the dict produced by a Gradio 4.x Sketchpad (its
            ``'composite'`` entry is used) or a bare numpy image.

    Returns:
        The predicted digit as an ``int``, or ``None`` when there is nothing
        to classify (no input, no composite, or a blank/uniform canvas).
    """
    if input_data is None:
        return None

    # Gradio 4.x Sketchpad usually returns {'composite': array, 'layers': [...]};
    # the composite is the flattened drawing. A missing key is treated as empty.
    image = input_data.get("composite") if isinstance(input_data, dict) else input_data
    if image is None:
        return None

    # 1. Grayscale, compositing any transparency over white.
    gray = _to_grayscale_on_white(image)

    # A uniform canvas has no strokes; predicting on it yields an arbitrary digit.
    if gray.size == 0 or gray.max() == gray.min():
        return None

    # 2. Resize to MNIST standard (28x28).
    # INTER_AREA is better for shrinking images without losing thin lines.
    gray = cv2.resize(gray, (28, 28), interpolation=cv2.INTER_AREA)

    # 3. MNIST is white digits on a black background. A high mean intensity
    #    (>127) means the background is light, so invert.
    if gray.mean() > 127:
        gray = 255.0 - gray

    # 4. Scale to [0, 1] and reshape to (batch, height, width, channels).
    batch = (gray / 255.0).reshape(1, 28, 28, 1).astype(np.float32)

    predictions = model.predict(batch, verbose=0)
    return int(np.argmax(predictions))
# Gradio 4 UI: a numpy sketchpad feeds the classifier, and the result is
# shown in a Label component. Flagging is disabled.
sketch_input = gr.Sketchpad(label="Dessinez un chiffre", type="numpy")

iface = gr.Interface(
    fn=preprocess_and_predict,
    inputs=sketch_input,
    outputs="label",
    title="MNIST Digit Recognition",
    description=(
        "Reconnaissance de chiffres manuscrits via CNN. Dessinez au centre."
    ),
    allow_flagging="never",
)

iface.launch()