Mpavan45 commited on
Commit
f5acbfa
Β·
verified Β·
1 Parent(s): d3e054c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -60
app.py CHANGED
@@ -65,84 +65,96 @@
65
  import streamlit as st
66
  import cv2
67
  import numpy as np
68
- from streamlit_drawable_canvas import st_canvas
69
  from keras.models import load_model
 
70
 
71
- # === Load models with caching ===
72
  @st.cache_resource
73
  def load_single_digit_model():
74
- return load_model("mnist_model.keras")
75
 
76
  @st.cache_resource
77
  def load_multi_digit_model():
78
- return load_model("best_model.keras")
 
 
 
79
 
80
- model_single = load_single_digit_model()
81
- model_multi = load_multi_digit_model()
 
 
 
 
 
 
82
 
83
- # === Sidebar settings ===
84
  st.sidebar.title("Canvas Settings")
85
  drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw", "line", "rect", "circle", "transform"))
86
  stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 10)
87
- stroke_color = st.sidebar.color_picker("Stroke color:", "#000000")
88
- bg_color = st.sidebar.color_picker("Background color:", "#FFFFFF")
 
89
  realtime_update = st.sidebar.checkbox("Update in realtime", True)
90
 
91
- # === Streamlit layout ===
92
- st.title("🧠 Digit Recognition App")
93
- st.subheader("Draw a digit or number below πŸ‘‡")
94
-
95
- canvas_result = st_canvas(
96
- fill_color="rgba(255,165,0,0.3)",
97
- stroke_width=stroke_width,
98
- stroke_color=stroke_color,
99
- background_color=bg_color,
100
- update_streamlit=realtime_update,
101
- height=280,
102
- width=280,
103
- drawing_mode=drawing_mode,
104
- key="canvas",
105
- )
106
-
107
- # === Prediction logic ===
 
 
 
 
 
 
 
 
 
 
 
108
  if canvas_result.image_data is not None:
109
  st.markdown("---")
110
- st.subheader("Prediction Output")
111
-
112
- img = canvas_result.image_data.astype("uint8")
113
- img_gray = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
114
- img_gray = 255 - img_gray # Invert
115
- _, thresh = cv2.threshold(img_gray, 30, 255, cv2.THRESH_BINARY)
116
-
117
- # Detect bounding box of the drawing
118
- contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
119
- if contours:
120
- x, y, w, h = cv2.boundingRect(np.vstack(contours))
 
 
121
  else:
122
- w = 0 # fallback
123
-
124
- # Use width of the drawn content to decide
125
- if w < 50:
126
- st.info("✏️ Detected Single-Digit Input")
127
- resized = cv2.resize(thresh, (28, 28))
128
- input_img = resized.astype("float32") / 255.0
129
- input_img = input_img.reshape(1, 28, 28, 1)
130
-
131
- prediction = model_single.predict(input_img)
132
- predicted_digit = np.argmax(prediction)
133
-
134
- st.image(resized, caption="Preprocessed 28x28 Image", width=200)
135
- st.success(f"🧠 Predicted Digit: **{predicted_digit}**")
136
 
 
 
 
 
 
 
 
137
  else:
138
- st.info("πŸ”’ Detected Multi-Digit Input")
139
- resized = cv2.resize(thresh, (100, 28))
140
- input_img = resized.astype("float32") / 255.0
141
- input_img = input_img.reshape(1, 28, 100, 1)
142
 
143
- preds = model_multi.predict(input_img)
144
- predicted_digits = [np.argmax(p[0]) for p in preds]
145
- predicted_str = ''.join(str(d) for d in predicted_digits)
146
 
147
- st.image(resized, caption="Preprocessed 100x28 Image", width=250)
148
- st.success(f"🧠 Predicted Number: **{predicted_str}**")
 
65
  import streamlit as st
66
  import cv2
67
  import numpy as np
 
68
  from keras.models import load_model
69
+ from streamlit_drawable_canvas import st_canvas
70
 
71
+ # === Load models ===
72
  @st.cache_resource
73
  def load_single_digit_model():
74
+ return load_model("single_digit_model.keras")
75
 
76
  @st.cache_resource
77
  def load_multi_digit_model():
78
+ return load_model("best_model.keras") # multi-digit model
79
+
80
+ single_digit_model = load_single_digit_model()
81
+ multi_digit_model = load_multi_digit_model()
82
 
83
+ # === Helper function to clean prediction ===
84
+ def clean_prediction(predicted_digits):
85
+ """
86
+ Removes junk or padded digits like trailing 0s or 1s and keeps only valid 0–9 digits.
87
+ You can further tune this logic based on training patterns.
88
+ """
89
+ digits = [str(d) for d in predicted_digits if 0 <= d <= 9]
90
+ return ''.join(digits)
91
 
92
+ # === Sidebar controls ===
93
  st.sidebar.title("Canvas Settings")
94
  drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw", "line", "rect", "circle", "transform"))
95
  stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 10)
96
+ stroke_color = st.sidebar.color_picker("Stroke color hex: ", "#000000") # black
97
+ bg_color = st.sidebar.color_picker("Background color hex: ", "#FFFFFF") # white
98
+ bg_image = st.sidebar.file_uploader("Background image:", type=["png", "jpg"])
99
  realtime_update = st.sidebar.checkbox("Update in realtime", True)
100
 
101
+ # === Title ===
102
+ st.title("πŸ–ŒοΈ Multi-Digit and Single-Digit Drawing: Predict Instantly")
103
+
104
+ # === Create a two-column layout ===
105
+ col1, col2 = st.columns([1, 1])
106
+
107
+ # === Canvas for drawing ===
108
+ with col1:
109
+ st.subheader("Draw Here πŸ‘‡")
110
+ canvas_result = st_canvas(
111
+ fill_color="rgba(255, 165, 0, 0.3)",
112
+ stroke_width=stroke_width,
113
+ stroke_color=stroke_color,
114
+ background_color=bg_color,
115
+ update_streamlit=realtime_update,
116
+ height=280,
117
+ width=280,
118
+ drawing_mode=drawing_mode,
119
+ key="canvas",
120
+ )
121
+
122
+ # === Display original drawing ===
123
+ with col2:
124
+ if canvas_result.image_data is not None:
125
+ st.subheader("Original Drawing")
126
+ st.image(canvas_result.image_data, use_column_width=True)
127
+
128
+ # === Image preprocessing and prediction ===
129
  if canvas_result.image_data is not None:
130
  st.markdown("---")
131
+ st.subheader("Preprocessed Image & Prediction")
132
+
133
+ # Preprocess image
134
+ img = cv2.cvtColor(canvas_result.image_data.astype("uint8"), cv2.COLOR_RGBA2GRAY)
135
+ img = 255 - img # Invert colors
136
+ img_resized = cv2.resize(img, (28, 80)) # Resize to match multi-digit model input shape
137
+ img_normalized = img_resized / 255.0
138
+ final_img = img_normalized.reshape(1, 28, 80, 1)
139
+
140
+ # === Choose which model to use based on the image size ===
141
+ # If image is more likely to be a single digit (e.g., smaller width), use the single digit model
142
+ if img_resized.shape[1] < 50: # This is an arbitrary threshold for width
143
+ model_to_use = single_digit_model
144
  else:
145
+ model_to_use = multi_digit_model
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
+ # Predict using the selected model
148
+ preds = model_to_use.predict(final_img)
149
+
150
+ # For multi-digit model, decode and clean prediction
151
+ if model_to_use == multi_digit_model:
152
+ predicted_digits = [np.argmax(p[0]) for p in preds]
153
+ predicted_str = clean_prediction(predicted_digits)
154
  else:
155
+ # For single digit model, directly decode
156
+ predicted_str = str(np.argmax(preds))
 
 
157
 
158
+ # Show prediction result
159
+ st.markdown(f"### 🧠 Predicted Number: **{predicted_str}**")
 
160