anamjafar6 commited on
Commit
ac3c8f5
Β·
verified Β·
1 Parent(s): b856d6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -59
app.py CHANGED
@@ -1,8 +1,3 @@
1
- # -------------------------
2
- # Handwritten Digit Recognition App (robust preprocessing)
3
- # Built by Anam Jafar
4
- # -------------------------
5
-
6
  import streamlit as st
7
  import numpy as np
8
  import cv2
@@ -10,10 +5,14 @@ from PIL import Image
10
  from tensorflow.keras.models import load_model
11
  from streamlit_drawable_canvas import st_canvas
12
 
13
- # Page config
14
- st.set_page_config(page_title="Digit Recognition App", page_icon="πŸ”’", layout="wide")
 
 
 
 
15
 
16
- # Background (professional)
17
  st.markdown(
18
  """
19
  <style>
@@ -25,55 +24,36 @@ st.markdown(
25
  unsafe_allow_html=True
26
  )
27
 
28
- # Load model (cached)
29
  @st.cache_resource
30
  def load_cnn_model():
31
  return load_model("mnist_cnn.h5")
32
 
33
  model = load_cnn_model()
34
 
35
- # ---------------------
36
- # Helper: preprocess PIL file uploads
37
- # ---------------------
38
  def preprocess_pil_file(file_or_pil_image):
39
- """
40
- Accept either a file-like object from file_uploader or a PIL.Image.
41
- Returns: preprocessed array shape (1,28,28,1), and a display PIL image (28x28)
42
- """
43
  if not isinstance(file_or_pil_image, Image.Image):
44
  img = Image.open(file_or_pil_image)
45
  else:
46
  img = file_or_pil_image
47
 
48
- # convert to grayscale and resize
49
  img = img.convert('L').resize((28, 28))
50
- arr = np.array(img).astype('float32') / 255.0 # 0..1
51
 
52
- # auto-invert if background is white and strokes are dark (we expect digit bright on dark background)
53
- if arr.mean() > 0.5:
54
  arr = 1.0 - arr
55
 
56
- # ensure shape (1,28,28,1)
57
  arr = arr.reshape(1, 28, 28, 1).astype('float32')
58
- return arr, Image.fromarray((arr[0,:,:,0]*255).astype('uint8'))
59
 
60
- # ---------------------
61
- # Helper: preprocess canvas image (RGBA or RGB)
62
- # ---------------------
63
  def preprocess_canvas_image(image_data):
64
- """
65
- image_data: HxWx4 (RGBA) or HxWx3 (RGB) numpy array from st_canvas.
66
- Returns preprocessed array shape (1,28,28,1) and display PIL image.
67
- """
68
  if image_data is None:
69
  return None, None
70
 
71
- # If values are float [0..255] -> convert to uint8
72
  img_uint8 = image_data.astype('uint8')
73
 
74
- # If has alpha channel (4), drop or composite with white background
75
- if img_uint8.shape[2] == 4:
76
- # composite alpha over white background
77
  alpha = img_uint8[..., 3] / 255.0
78
  rgb = img_uint8[..., :3].astype('float32')
79
  white = np.ones_like(rgb) * 255.0
@@ -82,10 +62,8 @@ def preprocess_canvas_image(image_data):
82
  else:
83
  gray = cv2.cvtColor(img_uint8, cv2.COLOR_RGB2GRAY)
84
 
85
- # Resize to 28x28, normalize
86
  small = cv2.resize(gray, (28, 28), interpolation=cv2.INTER_AREA).astype('float32') / 255.0
87
 
88
- # auto-invert heuristic
89
  if small.mean() > 0.5:
90
  small = 1.0 - small
91
 
@@ -93,13 +71,12 @@ def preprocess_canvas_image(image_data):
93
  display_img = Image.fromarray((small * 255).astype('uint8'))
94
  return arr, display_img
95
 
96
- # ---------------------
97
- # UI: header & sidebar
98
- # ---------------------
99
  st.markdown("<h1 style='text-align:center;color:#0D47A1;'>πŸ”’ Handwritten Digit Recognizer</h1>", unsafe_allow_html=True)
100
  st.write("Upload or draw a digit (0–9). The app will preprocess the image and predict the digit.")
101
  st.markdown("---")
102
 
 
103
  st.sidebar.header("πŸ“Œ Instructions")
104
  st.sidebar.info(
105
  "β€’ Upload PNG/JPG or draw a digit. \n"
@@ -110,9 +87,7 @@ st.sidebar.markdown("---")
110
  st.sidebar.write("πŸ‘©β€πŸ’» **About**: Built with ❀️ by **Anam Jafar**")
111
  st.sidebar.write("[πŸ”— LinkedIn](https://www.linkedin.com/in/anam-jafar)")
112
 
113
- # ---------------------
114
- # FILE UPLOAD (multiple)
115
- # ---------------------
116
  uploaded_files = st.file_uploader(
117
  "πŸ“‚ Upload digit images (single or multiple):",
118
  type=["png", "jpg", "jpeg"],
@@ -122,29 +97,22 @@ uploaded_files = st.file_uploader(
122
  if uploaded_files:
123
  st.subheader("πŸ“· Uploaded Images & Predictions")
124
 
125
- # display in rows of up to 4 columns
126
  max_cols = 4
127
  for i in range(0, len(uploaded_files), max_cols):
128
  row_files = uploaded_files[i:i+max_cols]
129
  cols = st.columns(len(row_files))
130
  for j, file in enumerate(row_files):
131
  arr, display_img = preprocess_pil_file(file)
132
- # Debug info (remove in production)
133
- st.experimental_show({"shape": arr.shape, "min": float(arr.min()), "max": float(arr.max())}) # optional
134
- # Predict
135
- with st.spinner("Predicting..."):
136
- pred = model.predict(arr)
137
  probs = pred[0]
138
  label = int(np.argmax(probs))
139
  conf = float(np.max(probs))
140
 
141
  with cols[j]:
142
  st.image(display_img, caption=f"Pred: {label} ({conf*100:.1f}%)", use_column_width=True)
143
- st.bar_chart(probs) # show probability distribution
144
 
145
- # ---------------------
146
- # DRAWING PAD
147
- # ---------------------
148
  st.subheader("πŸ–ŒοΈ Draw your digit here:")
149
  canvas_result = st_canvas(
150
  stroke_width=12,
@@ -159,10 +127,7 @@ canvas_result = st_canvas(
159
  if canvas_result is not None and canvas_result.image_data is not None:
160
  arr, display_img = preprocess_canvas_image(canvas_result.image_data)
161
  if arr is not None:
162
- # Debug info (remove in production)
163
- st.experimental_show({"canvas_shape": arr.shape, "min": float(arr.min()), "max": float(arr.max())})
164
- with st.spinner("Predicting..."):
165
- pred = model.predict(arr)
166
  probs = pred[0]
167
  label = int(np.argmax(probs))
168
  conf = float(np.max(probs))
@@ -179,8 +144,9 @@ if canvas_result is not None and canvas_result.image_data is not None:
179
  st.image(display_img, caption="Preprocessed (28Γ—28) view", width=120)
180
  st.bar_chart(probs)
181
 
182
- # ---------------------
183
- # Footer
184
- # ---------------------
185
  st.markdown("---")
186
- st.markdown("<p style='text-align:center;'>Built with ❀️ using Streamlit & TensorFlow | By <b>Anam Jafar</b></p>", unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import numpy as np
3
  import cv2
 
5
  from tensorflow.keras.models import load_model
6
  from streamlit_drawable_canvas import st_canvas
7
 
8
+ # ---- Page Config ----
9
+ st.set_page_config(
10
+ page_title="Digit Recognition App",
11
+ page_icon="πŸ”’",
12
+ layout="wide"
13
+ )
14
 
15
+ # ---- Custom Background ----
16
  st.markdown(
17
  """
18
  <style>
 
24
  unsafe_allow_html=True
25
  )
26
 
27
+ # ---- Load Model ----
28
  @st.cache_resource
29
  def load_cnn_model():
30
  return load_model("mnist_cnn.h5")
31
 
32
  model = load_cnn_model()
33
 
34
+ # ---- Preprocessing Helpers ----
 
 
35
  def preprocess_pil_file(file_or_pil_image):
 
 
 
 
36
  if not isinstance(file_or_pil_image, Image.Image):
37
  img = Image.open(file_or_pil_image)
38
  else:
39
  img = file_or_pil_image
40
 
 
41
  img = img.convert('L').resize((28, 28))
42
+ arr = np.array(img).astype('float32') / 255.0
43
 
44
+ if arr.mean() > 0.5: # invert if background is white
 
45
  arr = 1.0 - arr
46
 
 
47
  arr = arr.reshape(1, 28, 28, 1).astype('float32')
48
+ return arr, Image.fromarray((arr[0, :, :, 0] * 255).astype('uint8'))
49
 
 
 
 
50
  def preprocess_canvas_image(image_data):
 
 
 
 
51
  if image_data is None:
52
  return None, None
53
 
 
54
  img_uint8 = image_data.astype('uint8')
55
 
56
+ if img_uint8.shape[2] == 4: # RGBA
 
 
57
  alpha = img_uint8[..., 3] / 255.0
58
  rgb = img_uint8[..., :3].astype('float32')
59
  white = np.ones_like(rgb) * 255.0
 
62
  else:
63
  gray = cv2.cvtColor(img_uint8, cv2.COLOR_RGB2GRAY)
64
 
 
65
  small = cv2.resize(gray, (28, 28), interpolation=cv2.INTER_AREA).astype('float32') / 255.0
66
 
 
67
  if small.mean() > 0.5:
68
  small = 1.0 - small
69
 
 
71
  display_img = Image.fromarray((small * 255).astype('uint8'))
72
  return arr, display_img
73
 
74
+ # ---- Header ----
 
 
75
  st.markdown("<h1 style='text-align:center;color:#0D47A1;'>πŸ”’ Handwritten Digit Recognizer</h1>", unsafe_allow_html=True)
76
  st.write("Upload or draw a digit (0–9). The app will preprocess the image and predict the digit.")
77
  st.markdown("---")
78
 
79
+ # ---- Sidebar ----
80
  st.sidebar.header("πŸ“Œ Instructions")
81
  st.sidebar.info(
82
  "β€’ Upload PNG/JPG or draw a digit. \n"
 
87
  st.sidebar.write("πŸ‘©β€πŸ’» **About**: Built with ❀️ by **Anam Jafar**")
88
  st.sidebar.write("[πŸ”— LinkedIn](https://www.linkedin.com/in/anam-jafar)")
89
 
90
+ # ---- File Upload ----
 
 
91
  uploaded_files = st.file_uploader(
92
  "πŸ“‚ Upload digit images (single or multiple):",
93
  type=["png", "jpg", "jpeg"],
 
97
  if uploaded_files:
98
  st.subheader("πŸ“· Uploaded Images & Predictions")
99
 
 
100
  max_cols = 4
101
  for i in range(0, len(uploaded_files), max_cols):
102
  row_files = uploaded_files[i:i+max_cols]
103
  cols = st.columns(len(row_files))
104
  for j, file in enumerate(row_files):
105
  arr, display_img = preprocess_pil_file(file)
106
+ pred = model.predict(arr)
 
 
 
 
107
  probs = pred[0]
108
  label = int(np.argmax(probs))
109
  conf = float(np.max(probs))
110
 
111
  with cols[j]:
112
  st.image(display_img, caption=f"Pred: {label} ({conf*100:.1f}%)", use_column_width=True)
113
+ st.bar_chart(probs)
114
 
115
+ # ---- Drawing Pad ----
 
 
116
  st.subheader("πŸ–ŒοΈ Draw your digit here:")
117
  canvas_result = st_canvas(
118
  stroke_width=12,
 
127
  if canvas_result is not None and canvas_result.image_data is not None:
128
  arr, display_img = preprocess_canvas_image(canvas_result.image_data)
129
  if arr is not None:
130
+ pred = model.predict(arr)
 
 
 
131
  probs = pred[0]
132
  label = int(np.argmax(probs))
133
  conf = float(np.max(probs))
 
144
  st.image(display_img, caption="Preprocessed (28Γ—28) view", width=120)
145
  st.bar_chart(probs)
146
 
147
+ # ---- Footer ----
 
 
148
  st.markdown("---")
149
+ st.markdown(
150
+ "<p style='text-align:center;'>Built with ❀️ using Streamlit & TensorFlow | By <b>Anam Jafar</b></p>",
151
+ unsafe_allow_html=True
152
+ )