Gowthamvemula commited on
Commit
bea8775
·
verified ·
1 Parent(s): 0aa6c34

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -43
app.py CHANGED
@@ -4,58 +4,59 @@ from streamlit_drawable_canvas import st_canvas
4
  from keras.models import load_model
5
  import numpy as np
6
 
7
- # Page setup
8
  st.set_page_config(page_title="Digit Recognizer", layout="centered")
9
 
10
- # Load the trained MNIST model
11
  @st.cache_resource
12
  def load_mnist_model():
13
- return load_model("mnist_model.keras") # Ensure this model is accurate (CNN preferred)
14
 
15
  model = load_mnist_model()
16
 
17
- # Styling
18
  st.markdown("""
19
  <style>
20
  .main-title {
21
  text-align: center;
22
- font-size: 36px;
 
23
  color: #2c3e50;
24
- margin-bottom: 10px;
25
  }
26
  .subtitle {
27
  text-align: center;
28
  font-size: 18px;
29
  color: #555;
 
30
  }
31
  .result-box {
32
- background-color: #e8f5e9;
33
- padding: 10px;
34
- border-radius: 8px;
35
- margin-top: 15px;
36
  text-align: center;
37
  }
38
  .digit {
39
- font-size: 28px;
40
- color: #2e7d32;
41
  font-weight: bold;
42
  }
43
  </style>
44
  """, unsafe_allow_html=True)
45
 
46
- st.markdown('<div class="main-title">✏️ Draw a Digit</div>', unsafe_allow_html=True)
47
- st.markdown('<div class="subtitle">Draw a digit (0-9) and get an accurate prediction</div>', unsafe_allow_html=True)
48
 
49
- # Sidebar settings
50
- st.sidebar.header("Canvas Settings")
51
  stroke_width = st.sidebar.slider("Stroke Width", 5, 25, 15)
52
  stroke_color = st.sidebar.color_picker("Stroke Color", "#000000")
53
  bg_color = st.sidebar.color_picker("Background Color", "#FFFFFF")
54
  realtime = st.sidebar.checkbox("Update in Realtime", True)
55
 
56
- # Canvas
57
  canvas_result = st_canvas(
58
- fill_color="rgba(255, 165, 0, 0.3)",
59
  stroke_width=stroke_width,
60
  stroke_color=stroke_color,
61
  background_color=bg_color,
@@ -66,39 +67,35 @@ canvas_result = st_canvas(
66
  key="canvas",
67
  )
68
 
69
- # Preprocessing function
70
  def preprocess_drawn_image(img_data):
71
- img_gray = cv2.cvtColor(img_data.astype("uint8"), cv2.COLOR_RGBA2GRAY)
72
- img_gray = 255 - img_gray # Invert for white digit on black
 
73
 
74
- # Threshold to remove background noise
75
- _, img_thresh = cv2.threshold(img_gray, 50, 255, cv2.THRESH_BINARY)
76
-
77
- # Find contours to crop the digit
78
- contours, _ = cv2.findContours(img_thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
79
- if len(contours) == 0:
80
  return None
81
 
82
  x, y, w, h = cv2.boundingRect(contours[0])
83
- digit_crop = img_thresh[y:y+h, x:x+w]
84
 
85
- # Fit into square and resize to 20x20
86
- max_side = max(w, h)
87
- square_digit = np.zeros((max_side, max_side), dtype=np.uint8)
88
- x_offset = (max_side - w) // 2
89
- y_offset = (max_side - h) // 2
90
- square_digit[y_offset:y_offset+h, x_offset:x_offset+w] = digit_crop
91
- digit_resized = cv2.resize(square_digit, (20, 20))
92
 
93
- # Place in center of 28x28 image
94
- final_img = np.zeros((28, 28), dtype=np.uint8)
95
- final_img[4:24, 4:24] = digit_resized
 
96
 
97
- # Normalize
98
- final_img = final_img / 255.0
99
- return final_img.reshape(1, 28, 28, 1)
100
 
101
- # Prediction
102
  if canvas_result.image_data is not None:
103
  processed_img = preprocess_drawn_image(canvas_result.image_data)
104
 
@@ -117,4 +114,3 @@ if canvas_result.image_data is not None:
117
  """, unsafe_allow_html=True)
118
  else:
119
  st.warning("Couldn't detect a digit. Please try drawing again.")
120
-
 
4
  from keras.models import load_model
5
  import numpy as np
6
 
7
+ # Page configuration
8
  st.set_page_config(page_title="Digit Recognizer", layout="centered")
9
 
10
+ # Load trained model (preferably CNN-based on MNIST)
11
  @st.cache_resource
12
  def load_mnist_model():
13
+ return load_model("mnist_model.keras")
14
 
15
  model = load_mnist_model()
16
 
17
+ # Custom CSS Styling
18
  st.markdown("""
19
  <style>
20
  .main-title {
21
  text-align: center;
22
+ font-size: 40px;
23
+ font-weight: 700;
24
  color: #2c3e50;
 
25
  }
26
  .subtitle {
27
  text-align: center;
28
  font-size: 18px;
29
  color: #555;
30
+ margin-bottom: 20px;
31
  }
32
  .result-box {
33
+ background-color: #f0f9ff;
34
+ border: 2px solid #3498db;
35
+ border-radius: 10px;
36
+ padding: 15px;
37
  text-align: center;
38
  }
39
  .digit {
40
+ font-size: 36px;
41
+ color: #2c3e50;
42
  font-weight: bold;
43
  }
44
  </style>
45
  """, unsafe_allow_html=True)
46
 
47
+ st.markdown('<div class="main-title">✍️ Digit Recognizer</div>', unsafe_allow_html=True)
48
+ st.markdown('<div class="subtitle">Draw any digit (0-9) below and let the model predict it</div>', unsafe_allow_html=True)
49
 
50
+ # Sidebar controls
51
+ st.sidebar.header("🛠️ Canvas Settings")
52
  stroke_width = st.sidebar.slider("Stroke Width", 5, 25, 15)
53
  stroke_color = st.sidebar.color_picker("Stroke Color", "#000000")
54
  bg_color = st.sidebar.color_picker("Background Color", "#FFFFFF")
55
  realtime = st.sidebar.checkbox("Update in Realtime", True)
56
 
57
+ # Drawing canvas
58
  canvas_result = st_canvas(
59
+ fill_color="rgba(255, 165, 0, 0.3)", # Transparent fill
60
  stroke_width=stroke_width,
61
  stroke_color=stroke_color,
62
  background_color=bg_color,
 
67
  key="canvas",
68
  )
69
 
70
+ # Preprocess drawing like MNIST
71
  def preprocess_drawn_image(img_data):
72
+ gray = cv2.cvtColor(img_data.astype("uint8"), cv2.COLOR_RGBA2GRAY)
73
+ gray = 255 - gray # Invert to white digit on black
74
+ _, thresh = cv2.threshold(gray, 50, 255, cv2.THRESH_BINARY)
75
 
76
+ contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
77
+ if not contours:
 
 
 
 
78
  return None
79
 
80
  x, y, w, h = cv2.boundingRect(contours[0])
81
+ digit = thresh[y:y+h, x:x+w]
82
 
83
+ # Center the digit in a square image
84
+ max_dim = max(w, h)
85
+ square = np.zeros((max_dim, max_dim), dtype=np.uint8)
86
+ x_offset = (max_dim - w) // 2
87
+ y_offset = (max_dim - h) // 2
88
+ square[y_offset:y_offset+h, x_offset:x_offset+w] = digit
 
89
 
90
+ # Resize to 20x20, then embed in 28x28
91
+ resized = cv2.resize(square, (20, 20))
92
+ final = np.zeros((28, 28), dtype=np.uint8)
93
+ final[4:24, 4:24] = resized
94
 
95
+ final = final / 255.0
96
+ return final.reshape(1, 28, 28, 1)
 
97
 
98
+ # Predict and display result
99
  if canvas_result.image_data is not None:
100
  processed_img = preprocess_drawn_image(canvas_result.image_data)
101
 
 
114
  """, unsafe_allow_html=True)
115
  else:
116
  st.warning("Couldn't detect a digit. Please try drawing again.")