Mpavan45 commited on
Commit
17868e1
·
verified ·
1 Parent(s): d9afd8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -48
app.py CHANGED
@@ -68,37 +68,32 @@ import numpy as np
68
  from streamlit_drawable_canvas import st_canvas
69
  from keras.models import load_model
70
 
71
- # Sidebar controls
72
- st.sidebar.title("Canvas Settings")
73
- drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw", "line", "rect", "circle", "transform"))
74
- stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 10)
75
- stroke_color = st.sidebar.color_picker("Stroke color hex: ", "#000000") # black
76
- bg_color = st.sidebar.color_picker("Background color hex: ", "#FFFFFF") # white
77
- bg_image = st.sidebar.file_uploader("Background image:", type=["png", "jpg"])
78
- realtime_update = st.sidebar.checkbox("Update in realtime", True)
79
-
80
- # Mode selection
81
- mode = st.sidebar.radio("Select Prediction Mode", ["Single Digit", "Multi Digit"])
82
-
83
- # === Load models ===
84
  @st.cache_resource
85
  def load_single_digit_model():
86
  return load_model("mnist_model.keras")
87
 
88
  @st.cache_resource
89
  def load_multi_digit_model():
90
- return load_model("best_model.keras") # Your multi-digit model
91
 
92
  model_single = load_single_digit_model()
93
  model_multi = load_multi_digit_model()
94
 
95
- # === Streamlit UI ===
 
 
 
 
 
 
 
 
96
  st.title("🧠 Digit Recognition App")
97
- st.subheader(f"✏️ Mode: {mode}")
98
 
99
- # Create drawing canvas
100
  canvas_result = st_canvas(
101
- fill_color="rgba(255, 165, 0, 0.3)",
102
  stroke_width=stroke_width,
103
  stroke_color=stroke_color,
104
  background_color=bg_color,
@@ -109,43 +104,45 @@ canvas_result = st_canvas(
109
  key="canvas",
110
  )
111
 
112
- # Prediction Section
113
  if canvas_result.image_data is not None:
114
  st.markdown("---")
115
- st.subheader("🧪 Prediction Results")
116
-
117
- # Preprocess drawing
118
- img = cv2.cvtColor(canvas_result.image_data.astype("uint8"), cv2.COLOR_RGBA2GRAY)
119
- img = 255 - img # Invert
120
- _, thresh = cv2.threshold(img, 30, 255, cv2.THRESH_BINARY)
121
 
122
- if mode == "Single Digit":
123
- resized = cv2.resize(thresh, (28, 28))
124
- norm = resized.astype("float32") / 255.0
125
- input_img = norm.reshape(1, 28, 28, 1)
 
126
 
127
- prediction = model_single.predict(input_img)
128
- digit = np.argmax(prediction)
 
129
 
130
- col1, col2 = st.columns(2)
131
- with col1:
132
- st.image(resized, width=200, caption="28x28 Preprocessed")
133
- with col2:
134
- st.success(f"🧠 Predicted Digit: **{digit}**")
135
 
136
- elif mode == "Multi Digit":
137
- resized = cv2.resize(thresh, (80, 28)) # Resize to match your model (width=80, height=28)
138
- norm = resized.astype("float32") / 255.0
139
- input_seq = norm.reshape(1, 28, 80, 1)
140
 
141
- preds = model_multi.predict(input_seq)
142
-
143
- # Decode predictions for each digit
144
  predicted_digits = [np.argmax(p[0]) for p in preds]
145
  predicted_str = ''.join(str(d) for d in predicted_digits)
146
 
147
- col1, col2 = st.columns(2)
148
- with col1:
149
- st.image(resized, width=300, caption="80x28 Multi-digit Input")
150
- with col2:
151
- st.success(f"🧠 Predicted Sequence: **{predicted_str}**")
 
 
 
 
 
 
 
 
 
 
68
  from streamlit_drawable_canvas import st_canvas
69
  from keras.models import load_model
70
 
71
+ # === Load models with caching ===
 
 
 
 
 
 
 
 
 
 
 
 
72
  @st.cache_resource
73
  def load_single_digit_model():
74
  return load_model("mnist_model.keras")
75
 
76
  @st.cache_resource
77
  def load_multi_digit_model():
78
+ return load_model("best_model.keras")
79
 
80
  model_single = load_single_digit_model()
81
  model_multi = load_multi_digit_model()
82
 
83
+ # === Sidebar settings ===
84
+ st.sidebar.title("Canvas Settings")
85
+ drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw", "line", "rect", "circle", "transform"))
86
+ stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 10)
87
+ stroke_color = st.sidebar.color_picker("Stroke color:", "#000000")
88
+ bg_color = st.sidebar.color_picker("Background color:", "#FFFFFF")
89
+ realtime_update = st.sidebar.checkbox("Update in realtime", True)
90
+
91
+ # === Streamlit layout ===
92
  st.title("🧠 Digit Recognition App")
93
+ st.subheader("Draw a digit or number below 👇")
94
 
 
95
  canvas_result = st_canvas(
96
+ fill_color="rgba(255,165,0,0.3)",
97
  stroke_width=stroke_width,
98
  stroke_color=stroke_color,
99
  background_color=bg_color,
 
104
  key="canvas",
105
  )
106
 
107
+ # === Prediction logic ===
108
  if canvas_result.image_data is not None:
109
  st.markdown("---")
110
+ st.subheader("Prediction Output")
 
 
 
 
 
111
 
112
+ # Preprocess the drawing
113
+ img = canvas_result.image_data.astype("uint8")
114
+ img_gray = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
115
+ img_gray = 255 - img_gray # Invert
116
+ _, thresh = cv2.threshold(img_gray, 30, 255, cv2.THRESH_BINARY)
117
 
118
+ # Resize to both possible formats
119
+ resized_single = cv2.resize(thresh, (28, 28))
120
+ resized_multi = cv2.resize(thresh, (100, 28)) # width x height
121
 
122
+ # Decide based on content width if it's likely multi-digit
123
+ nonzero_cols = np.count_nonzero(np.sum(thresh, axis=0) > 10)
124
+ is_multi = nonzero_cols > 40 # simple heuristic
 
 
125
 
126
+ if is_multi:
127
+ st.info("🔢 Detected Multi-Digit Input")
128
+ input_img = resized_multi.astype("float32") / 255.0
129
+ input_img = input_img.reshape(1, 28, 100, 1)
130
 
131
+ preds = model_multi.predict(input_img)
 
 
132
  predicted_digits = [np.argmax(p[0]) for p in preds]
133
  predicted_str = ''.join(str(d) for d in predicted_digits)
134
 
135
+ st.image(resized_multi, caption="Preprocessed 100x28 Image", width=250)
136
+ st.success(f"🧠 Predicted Number: **{predicted_str}**")
137
+
138
+ else:
139
+ st.info("✏️ Detected Single-Digit Input")
140
+ input_img = resized_single.astype("float32") / 255.0
141
+ input_img = input_img.reshape(1, 28, 28, 1)
142
+
143
+ prediction = model_single.predict(input_img)
144
+ predicted_digit = np.argmax(prediction)
145
+
146
+ st.image(resized_single, caption="Preprocessed 28x28 Image", width=200)
147
+ st.success(f"🧠 Predicted Digit: **{predicted_digit}**")
148
+