Gowthamvemula commited on
Commit
0aa6c34
·
verified ·
1 Parent(s): 38caf49

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -102
app.py CHANGED
@@ -2,118 +2,119 @@ import streamlit as st
2
  import cv2
3
  from streamlit_drawable_canvas import st_canvas
4
  from keras.models import load_model
5
- from keras.datasets import mnist
6
  import numpy as np
7
- import random
8
 
9
- # Apply custom CSS
 
 
 
 
 
 
 
 
 
 
10
  st.markdown("""
11
  <style>
12
- body {
13
- background-color: #f5f7fa;
 
 
 
14
  }
15
- .main {
16
- background: linear-gradient(to right, #f8f9fa, #e0f7fa);
17
- border-radius: 10px;
18
- padding: 20px;
19
  }
20
- .title {
 
 
 
 
21
  text-align: center;
22
- font-size: 36px;
23
- color: #0077b6;
24
- font-family: 'Segoe UI', sans-serif;
25
- margin-bottom: 20px;
26
  }
27
- .st-cf {
28
- background-color: white;
29
- border-radius: 15px;
30
- padding: 15px;
31
- box-shadow: 2px 2px 10px rgba(0, 0, 0, 0.1);
32
  }
33
  </style>
34
  """, unsafe_allow_html=True)
35
 
36
- # Load MNIST model
37
- @st.cache_resource
38
- def load_mnist_model():
39
- return load_model("mnist_model.keras")
40
-
41
- model = load_mnist_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- # Load MNIST test set for generator
44
- @st.cache_data
45
- def load_test_images():
46
- (_, _), (x_test, y_test) = mnist.load_data()
47
- return x_test, y_test
48
-
49
- x_test, y_test = load_test_images()
50
-
51
- # App title
52
- st.markdown('<div class="title">🧠 Mindist: Draw or Generate a Digit</div>', unsafe_allow_html=True)
53
-
54
- # Sidebar controls
55
- st.sidebar.title("🖌️ Canvas Settings")
56
- drawing_mode = st.sidebar.selectbox("✏️ Drawing Tool:", ("freedraw", "line", "rect", "circle", "transform"))
57
- stroke_width = st.sidebar.slider("🖍️ Stroke Width", 1, 25, 10)
58
- stroke_color = st.sidebar.color_picker("🎨 Stroke Color", "#000000")
59
- bg_color = st.sidebar.color_picker("🌄 Background Color", "#FFFFFF")
60
- bg_image = st.sidebar.file_uploader("🖼️ Background Image", type=["png", "jpg"])
61
- realtime_update = st.sidebar.checkbox("🔄 Update in Realtime", True)
62
-
63
- # Tabs for drawing or generating
64
- tab1, tab2 = st.tabs(["🎨 Draw Your Own", "🤖 Generate Random Digit"])
65
-
66
- with tab1:
67
- col1, col2 = st.columns([1, 1])
68
-
69
- with col1:
70
- st.subheader("🎨 Draw Here")
71
- canvas_result = st_canvas(
72
- fill_color="rgba(255, 165, 0, 0.3)",
73
- stroke_width=stroke_width,
74
- stroke_color=stroke_color,
75
- background_color=bg_color,
76
- update_streamlit=realtime_update,
77
- height=280,
78
- width=280,
79
- drawing_mode=drawing_mode,
80
- key="canvas",
81
- )
82
-
83
- with col2:
84
- if canvas_result.image_data is not None:
85
- st.subheader("🖼️ Your Drawing")
86
- st.image(canvas_result.image_data, use_column_width=True)
87
-
88
- if canvas_result.image_data is not None:
89
- st.markdown("---")
90
- st.subheader("📊 Prediction from Drawing")
91
-
92
- img = cv2.cvtColor(canvas_result.image_data.astype("uint8"), cv2.COLOR_RGBA2GRAY)
93
- img = 255 - img
94
- img_resized = cv2.resize(img, (28, 28))
95
- img_normalized = img_resized / 255.0
96
- final_img = img_normalized.reshape(1, 28, 28, 1)
97
-
98
- col3, col4 = st.columns([1, 1])
99
- with col3:
100
- st.image(img_resized, caption="🧼 Preprocessed (28x28)", clamp=True, channels="GRAY")
101
- with col4:
102
- prediction = model.predict(final_img)
103
- predicted_digit = int(np.argmax(prediction))
104
- st.markdown(f"<h3 style='color:#00796b;'>✅ Predicted Digit: <strong>{predicted_digit}</strong></h3>", unsafe_allow_html=True)
105
-
106
- with tab2:
107
- st.subheader("🎲 Random Digit Generator (MNIST)")
108
- if st.button("🔁 Generate Random Digit"):
109
- idx = random.randint(0, len(x_test) - 1)
110
- random_img = x_test[idx]
111
- true_label = y_test[idx]
112
-
113
- st.image(random_img, width=150, caption=f"🧾 True Label: {true_label}", clamp=True, channels="GRAY")
114
-
115
- input_img = random_img.reshape(1, 28, 28, 1) / 255.0
116
- pred = model.predict(input_img)
117
- pred_label = int(np.argmax(pred))
118
-
119
- st.markdown(f"<h3 style='color:#1e88e5;'>🤖 Predicted Digit: <strong>{pred_label}</strong></h3>", unsafe_allow_html=True)
 
2
  import cv2
3
  from streamlit_drawable_canvas import st_canvas
4
  from keras.models import load_model
 
5
  import numpy as np
 
6
 
7
+ # Page setup
8
+ st.set_page_config(page_title="Digit Recognizer", layout="centered")
9
+
10
+ # Load the trained MNIST model
11
+ @st.cache_resource
12
+ def load_mnist_model():
13
+ return load_model("mnist_model.keras") # Ensure this model is accurate (CNN preferred)
14
+
15
+ model = load_mnist_model()
16
+
17
+ # Styling
18
  st.markdown("""
19
  <style>
20
+ .main-title {
21
+ text-align: center;
22
+ font-size: 36px;
23
+ color: #2c3e50;
24
+ margin-bottom: 10px;
25
  }
26
+ .subtitle {
27
+ text-align: center;
28
+ font-size: 18px;
29
+ color: #555;
30
  }
31
+ .result-box {
32
+ background-color: #e8f5e9;
33
+ padding: 10px;
34
+ border-radius: 8px;
35
+ margin-top: 15px;
36
  text-align: center;
 
 
 
 
37
  }
38
+ .digit {
39
+ font-size: 28px;
40
+ color: #2e7d32;
41
+ font-weight: bold;
 
42
  }
43
  </style>
44
  """, unsafe_allow_html=True)
45
 
46
+ st.markdown('<div class="main-title">✏️ Draw a Digit</div>', unsafe_allow_html=True)
47
+ st.markdown('<div class="subtitle">Draw a digit (0-9) and get an accurate prediction</div>', unsafe_allow_html=True)
48
+
49
+ # Sidebar settings
50
+ st.sidebar.header("Canvas Settings")
51
+ stroke_width = st.sidebar.slider("Stroke Width", 5, 25, 15)
52
+ stroke_color = st.sidebar.color_picker("Stroke Color", "#000000")
53
+ bg_color = st.sidebar.color_picker("Background Color", "#FFFFFF")
54
+ realtime = st.sidebar.checkbox("Update in Realtime", True)
55
+
56
+ # Canvas
57
+ canvas_result = st_canvas(
58
+ fill_color="rgba(255, 165, 0, 0.3)",
59
+ stroke_width=stroke_width,
60
+ stroke_color=stroke_color,
61
+ background_color=bg_color,
62
+ update_streamlit=realtime,
63
+ height=280,
64
+ width=280,
65
+ drawing_mode="freedraw",
66
+ key="canvas",
67
+ )
68
+
69
+ # Preprocessing function
70
+ def preprocess_drawn_image(img_data):
71
+ img_gray = cv2.cvtColor(img_data.astype("uint8"), cv2.COLOR_RGBA2GRAY)
72
+ img_gray = 255 - img_gray # Invert for white digit on black
73
+
74
+ # Threshold to remove background noise
75
+ _, img_thresh = cv2.threshold(img_gray, 50, 255, cv2.THRESH_BINARY)
76
+
77
+ # Find contours to crop the digit
78
+ contours, _ = cv2.findContours(img_thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
79
+ if len(contours) == 0:
80
+ return None
81
+
82
+ x, y, w, h = cv2.boundingRect(contours[0])
83
+ digit_crop = img_thresh[y:y+h, x:x+w]
84
+
85
+ # Fit into square and resize to 20x20
86
+ max_side = max(w, h)
87
+ square_digit = np.zeros((max_side, max_side), dtype=np.uint8)
88
+ x_offset = (max_side - w) // 2
89
+ y_offset = (max_side - h) // 2
90
+ square_digit[y_offset:y_offset+h, x_offset:x_offset+w] = digit_crop
91
+ digit_resized = cv2.resize(square_digit, (20, 20))
92
+
93
+ # Place in center of 28x28 image
94
+ final_img = np.zeros((28, 28), dtype=np.uint8)
95
+ final_img[4:24, 4:24] = digit_resized
96
+
97
+ # Normalize
98
+ final_img = final_img / 255.0
99
+ return final_img.reshape(1, 28, 28, 1)
100
+
101
+ # Prediction
102
+ if canvas_result.image_data is not None:
103
+ processed_img = preprocess_drawn_image(canvas_result.image_data)
104
+
105
+ if processed_img is not None:
106
+ st.image(processed_img.reshape(28, 28), caption="🧼 Preprocessed Image", clamp=True, channels="GRAY")
107
+
108
+ prediction = model.predict(processed_img)
109
+ pred_digit = int(np.argmax(prediction))
110
+ confidence = float(np.max(prediction)) * 100
111
+
112
+ st.markdown(f"""
113
+ <div class='result-box'>
114
+ 🧠 Predicted Digit: <span class='digit'>{pred_digit}</span><br>
115
+ 📊 Confidence: <strong>{confidence:.2f}%</strong>
116
+ </div>
117
+ """, unsafe_allow_html=True)
118
+ else:
119
+ st.warning("Couldn't detect a digit. Please try drawing again.")
120