Mpavan45 commited on
Commit
738dd45
·
verified ·
1 Parent(s): 0fdd3c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -57
app.py CHANGED
@@ -66,28 +66,37 @@ import streamlit as st
66
  import cv2
67
  import numpy as np
68
  from streamlit_drawable_canvas import st_canvas
69
- from tensorflow.keras.models import load_model
70
 
71
- # === Load model (trained on single digits) ===
 
 
 
 
 
 
 
 
 
 
 
 
72
  @st.cache_resource
73
- def load_digit_model():
74
  return load_model("mnist_model.keras")
75
 
76
- model = load_digit_model()
 
 
77
 
78
- # === Sidebar Controls ===
79
- st.sidebar.title("🛠️ Settings")
80
- mode = st.sidebar.radio("Choose Prediction Mode", ("Single Digit", "Multi Digit"))
81
- stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 10)
82
- stroke_color = st.sidebar.color_picker("Stroke color: ", "#000000")
83
- bg_color = st.sidebar.color_picker("Background color: ", "#FFFFFF")
84
- realtime_update = st.sidebar.checkbox("Update in realtime", True)
85
 
86
- # === Title ===
87
- st.title("🧠 MNIST Digit Recognizer")
88
- st.caption("Draw digits and let the model predict them! Choose **Single** or **Multi** mode from the sidebar.")
89
 
90
- # === Drawing Canvas ===
91
  canvas_result = st_canvas(
92
  fill_color="rgba(255, 165, 0, 0.3)",
93
  stroke_width=stroke_width,
@@ -96,55 +105,47 @@ canvas_result = st_canvas(
96
  update_streamlit=realtime_update,
97
  height=280,
98
  width=280,
99
- drawing_mode="freedraw",
100
  key="canvas",
101
  )
102
 
103
- # === Prediction Logic ===
104
  if canvas_result.image_data is not None:
105
- st.subheader("🖼️ Original Drawing")
106
- st.image(canvas_result.image_data, use_column_width=False, width=280)
107
 
108
- img = canvas_result.image_data.astype('uint8')
109
- gray = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
110
- gray = 255 - gray # Invert for white digits on black
111
- _, thresh = cv2.threshold(gray, 30, 255, cv2.THRESH_BINARY)
112
 
113
  if mode == "Single Digit":
114
- st.subheader("🔢 Single Digit Mode")
115
  resized = cv2.resize(thresh, (28, 28))
116
- normalized = resized.astype("float32") / 255.0
117
- input_tensor = normalized.reshape(1, 28, 28, 1)
118
- pred = model.predict(input_tensor, verbose=0)
119
- digit = np.argmax(pred)
120
- st.image(resized, width=100, caption="🧼 Cleaned & Resized")
121
- st.success(f"✅ Predicted Digit: **{digit}**")
 
 
 
 
 
122
 
123
  elif mode == "Multi Digit":
124
- st.subheader("🔢 Multi Digit Mode")
125
-
126
- # Detect contours
127
- contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
128
- boxes = [cv2.boundingRect(c) for c in contours if cv2.contourArea(c) > 50]
129
- boxes = sorted(boxes, key=lambda b: b[0]) # sort by x
130
-
131
- predictions = []
132
- for x, y, w, h in boxes:
133
- digit_img = thresh[y:y+h, x:x+w]
134
- digit_img = cv2.resize(digit_img, (28, 28))
135
- normalized = digit_img.astype("float32") / 255.0
136
- input_tensor = normalized.reshape(1, 28, 28, 1)
137
- pred = model.predict(input_tensor, verbose=0)
138
- digit = np.argmax(pred)
139
- predictions.append((digit_img, digit))
140
-
141
- if predictions:
142
- st.markdown("### ✂️ Segmented Digits & Predictions")
143
- cols = st.columns(len(predictions))
144
- for i, (img, digit) in enumerate(predictions):
145
- with cols[i]:
146
- st.image(img, width=64, caption=f"➡️ {digit}")
147
- full_number = ''.join(str(d) for (_, d) in predictions)
148
- st.success(f"📌 Final Multi-Digit Prediction: **{full_number}**")
149
- else:
150
- st.warning("⚠️ Couldn't detect any digits. Try writing more clearly.")
 
66
  import cv2
67
  import numpy as np
68
  from streamlit_drawable_canvas import st_canvas
69
+ from keras.models import load_model
70
 
71
+ # Sidebar controls
72
+ st.sidebar.title("Canvas Settings")
73
+ drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw", "line", "rect", "circle", "transform"))
74
+ stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 10)
75
+ stroke_color = st.sidebar.color_picker("Stroke color hex: ", "#000000") # black
76
+ bg_color = st.sidebar.color_picker("Background color hex: ", "#FFFFFF") # white
77
+ bg_image = st.sidebar.file_uploader("Background image:", type=["png", "jpg"])
78
+ realtime_update = st.sidebar.checkbox("Update in realtime", True)
79
+
80
+ # Mode selection
81
+ mode = st.sidebar.radio("Select Prediction Mode", ["Single Digit", "Multi Digit"])
82
+
83
+ # === Load models ===
84
  @st.cache_resource
85
+ def load_single_digit_model():
86
  return load_model("mnist_model.keras")
87
 
88
+ @st.cache_resource
89
+ def load_multi_digit_model():
90
+ return load_model("best_model.keras") # Your multi-digit model
91
 
92
+ model_single = load_single_digit_model()
93
+ model_multi = load_multi_digit_model()
 
 
 
 
 
94
 
95
+ # === Streamlit UI ===
96
+ st.title("🧠 Digit Recognition App")
97
+ st.subheader(f"✏️ Mode: {mode}")
98
 
99
+ # Create drawing canvas
100
  canvas_result = st_canvas(
101
  fill_color="rgba(255, 165, 0, 0.3)",
102
  stroke_width=stroke_width,
 
105
  update_streamlit=realtime_update,
106
  height=280,
107
  width=280,
108
+ drawing_mode=drawing_mode,
109
  key="canvas",
110
  )
111
 
112
+ # Prediction Section
113
  if canvas_result.image_data is not None:
114
+ st.markdown("---")
115
+ st.subheader("🧪 Prediction Results")
116
 
117
+ # Preprocess drawing
118
+ img = cv2.cvtColor(canvas_result.image_data.astype("uint8"), cv2.COLOR_RGBA2GRAY)
119
+ img = 255 - img # Invert
120
+ _, thresh = cv2.threshold(img, 30, 255, cv2.THRESH_BINARY)
121
 
122
  if mode == "Single Digit":
 
123
  resized = cv2.resize(thresh, (28, 28))
124
+ norm = resized.astype("float32") / 255.0
125
+ input_img = norm.reshape(1, 28, 28, 1)
126
+
127
+ prediction = model_single.predict(input_img)
128
+ digit = np.argmax(prediction)
129
+
130
+ col1, col2 = st.columns(2)
131
+ with col1:
132
+ st.image(resized, width=200, caption="28x28 Preprocessed")
133
+ with col2:
134
+ st.success(f"🧠 Predicted Digit: **{digit}**")
135
 
136
  elif mode == "Multi Digit":
137
+ resized = cv2.resize(thresh, (80, 28)) # Resize to match your model (width=80, height=28)
138
+ norm = resized.astype("float32") / 255.0
139
+ input_seq = norm.reshape(1, 28, 80, 1)
140
+
141
+ preds = model_multi.predict(input_seq)
142
+
143
+ # Decode predictions for each digit
144
+ predicted_digits = [np.argmax(p[0]) for p in preds]
145
+ predicted_str = ''.join(str(d) for d in predicted_digits)
146
+
147
+ col1, col2 = st.columns(2)
148
+ with col1:
149
+ st.image(resized, width=300, caption="80x28 Multi-digit Input")
150
+ with col2:
151
+ st.success(f"🧠 Predicted Sequence: **{predicted_str}**")