papasega commited on
Commit
2b04118
·
verified ·
1 Parent(s): cb83b67

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -30
app.py CHANGED
@@ -3,52 +3,54 @@ import gradio as gr
3
  import numpy as np
4
  import cv2
5
 
6
- # Load model
7
  model = tf.keras.models.load_model('TP_MNIST_CNN_model.h5')
8
 
9
- def preprocess_and_predict(input_data):
 
 
 
 
 
 
 
 
10
  if input_data is None:
11
  return None
12
-
13
- # Gradio 4.x Sketchpad returns a dict usually: {'composite': array, 'layers': [...]}
14
- # We take the composite image
15
- if isinstance(input_data, dict):
16
- image = input_data['composite']
17
- else:
18
- image = input_data
19
-
20
- # 1. Resize to MNIST standard (28x28)
21
- # Interpolation AREA is better for shrinking images without losing thin lines
22
- image = cv2.resize(image, (28, 28), interpolation=cv2.INTER_AREA)
23
 
24
- # 2. Convert to Grayscale (if RGB/RGBA)
 
 
 
 
 
 
 
25
  if len(image.shape) == 3:
26
  image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
27
 
28
- # 3. Normalize & Invert Logic
29
- # MNIST is White digits on Black background.
30
- # If user draws Black on White, we must invert.
31
- # Check mean pixel intensity: if high (>127), background is likely white.
32
  if np.mean(image) > 127:
33
  image = 255 - image
34
-
35
- # 4. Normalize 0-1
36
- image = image / 255.0
37
 
38
- # 5. Reshape for Model (Batch, Height, Width, Channels)
 
 
 
39
  image = image.reshape(1, 28, 28, 1)
40
 
41
- # Prediction
42
- predictions = model.predict(image, verbose=0)
43
- return int(np.argmax(predictions))
44
 
45
- # Modern Gradio 4 Interface
46
  iface = gr.Interface(
47
- fn=preprocess_and_predict,
48
- inputs=gr.Sketchpad(label="Dessinez un chiffre", type="numpy"),
49
  outputs="label",
50
- title="MNIST Digit Recognition",
51
- description="Reconnaissance de chiffres manuscrits via CNN. Dessinez au centre.",
52
  allow_flagging="never"
53
  )
54
 
 
3
  import numpy as np
4
  import cv2
5
 
6
+ # 1. Load Model (Optimized load)
7
  model = tf.keras.models.load_model('TP_MNIST_CNN_model.h5')
8
 
9
+ def predict_digit(input_data):
10
+ """
11
+ Pipeline de prédiction robuste :
12
+ 1. Gestion du format d'entrée (Gradio 4 renvoie parfois un dict).
13
+ 2. Resize vers 28x28 (Interpolation AREA pour préserver les traits).
14
+ 3. Conversion Grayscale.
15
+ 4. Inversion des couleurs (Adaptation domaine Humain -> Machine).
16
+ 5. Normalisation et inférence.
17
+ """
18
  if input_data is None:
19
  return None
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ # Gradio 4 handle : input_data peut être un dictionnaire {'composite': ...}
22
+ image = input_data["composite"] if isinstance(input_data, dict) else input_data
23
+
24
+ # Pipeline OpenCV
25
+ # Resize vers 28x28
26
+ image = cv2.resize(image, (28, 28), interpolation=cv2.INTER_AREA)
27
+
28
+ # Convertir en niveaux de gris si nécessaire
29
  if len(image.shape) == 3:
30
  image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
31
 
32
+ # Inversion intelligente : Si l'image est majoritairement blanche (dessin noir sur fond blanc)
33
+ # on inverse car MNIST a été entraîné sur blanc sur fond noir.
 
 
34
  if np.mean(image) > 127:
35
  image = 255 - image
 
 
 
36
 
37
+ # Normalisation
38
+ image = image / 255.0
39
+
40
+ # Reshape (Batch, H, W, Channels)
41
  image = image.reshape(1, 28, 28, 1)
42
 
43
+ # Inférence
44
+ prediction = model.predict(image, verbose=0)
45
+ return int(np.argmax(prediction))
46
 
47
+ # Interface Gradio 4 Moderne
48
  iface = gr.Interface(
49
+ fn=predict_digit,
50
+ inputs=gr.Sketchpad(label="Dessinez un chiffre", type="numpy", crop_size=(28, 28)),
51
  outputs="label",
52
+ title="Reconnaissance MNIST - Production Grade",
53
+ description="CNN Model. Dessinez un chiffre au centre.",
54
  allow_flagging="never"
55
  )
56