saramneena commited on
Commit
41994e0
Β·
verified Β·
1 Parent(s): c7e669b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -50
app.py CHANGED
@@ -8,86 +8,104 @@ import numpy as np
8
  import cv2
9
  from PIL import Image, ImageOps
10
 
11
- # Page config
12
- st.set_page_config(page_title="Digit Recognizer", layout="centered")
13
- st.title("✍️ Handwritten Digit Recognizer")
14
-
15
- # Sidebar - Drawing options
16
- st.sidebar.title("πŸ›  Drawing Settings")
17
- mode = st.sidebar.selectbox("Drawing Tool", ("freedraw", "line"))
18
- stroke_width = st.sidebar.slider("Stroke width", 5, 25, 15)
19
- stroke_color = st.sidebar.color_picker("Stroke color", "#000000")
20
- bg_color = st.sidebar.color_picker("Background color", "#FFFFFF")
21
-
22
- # Load trained model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  @st.cache_resource
24
  def load_mnist_model():
25
  return load_model("final_model.keras")
26
 
27
  model = load_mnist_model()
28
 
29
- # Preprocessing Function
30
  def preprocess(img):
31
  img = ImageOps.grayscale(img)
32
  img = img.resize((200, 200))
33
  img = np.array(img)
34
-
35
- # Invert if background is light
36
  if np.mean(img) > 127:
37
  img = 255 - img
38
-
39
- # Thresholding
40
  _, img = cv2.threshold(img, 100, 255, cv2.THRESH_BINARY)
41
-
42
- # Find bounding box of the digit
43
  coords = cv2.findNonZero(img)
44
  if coords is not None:
45
  x, y, w, h = cv2.boundingRect(coords)
46
  digit = img[y:y+h, x:x+w]
47
  else:
48
  return np.zeros((1, 28, 28), dtype="float32")
49
-
50
- # Resize and pad
51
  digit = cv2.resize(digit, (20, 20), interpolation=cv2.INTER_AREA)
52
  digit = np.pad(digit, ((4, 4), (4, 4)), mode="constant", constant_values=0)
53
-
54
- # Normalize
55
  digit = digit.astype("float32") / 255.0
56
  digit = digit.reshape(1, 28, 28)
57
  return digit
58
 
59
- # Upload image
60
- st.subheader("πŸ“€ Upload an Image")
61
- uploaded_file = st.file_uploader("Choose a digit image...", type=["jpg", "png"])
 
 
 
62
 
63
- # Canvas drawing
64
- st.subheader("πŸ–ŒοΈ Draw a Digit")
65
- canvas_result = st_canvas(
66
- stroke_width=stroke_width,
67
- stroke_color=stroke_color,
68
- background_color=bg_color,
69
- height=200,
70
- width=200,
71
- drawing_mode=mode,
72
- key="canvas",
73
- )
74
 
75
- # Get input image from upload or canvas
76
  input_img = None
77
- if uploaded_file:
78
- input_img = Image.open(uploaded_file).convert("RGB")
79
- elif canvas_result.image_data is not None:
80
- input_img = Image.fromarray(canvas_result.image_data.astype("uint8"))
81
 
82
- # Display & Predict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  if input_img:
84
- st.image(input_img, caption="πŸ–Ό Input Image", width=150)
85
 
86
- if st.button("πŸ” Predict"):
87
  processed = preprocess(input_img)
88
  prediction = model.predict(processed, verbose=0)
89
- digit = np.argmax(prediction)
90
- confidence = np.max(prediction) * 100
91
-
92
- st.image(processed.reshape(28, 28), width=150, caption="πŸ§ͺ Processed Image")
93
- st.success(f"🧠 Predicted Digit: **{digit}** with **{confidence:.2f}%** confidence")
 
 
 
 
 
 
8
  import cv2
9
  from PIL import Image, ImageOps
10
 
11
+ # Custom styles
12
+ st.markdown("""
13
+ <style>
14
+ .big-font {
15
+ font-size:40px !important;
16
+ font-weight: bold;
17
+ color: #5A189A;
18
+ text-align: center;
19
+ }
20
+ .result-box {
21
+ background-color: #F0EBF8;
22
+ border-radius: 10px;
23
+ padding: 20px;
24
+ text-align: center;
25
+ font-size: 24px;
26
+ color: #3C096C;
27
+ font-weight: bold;
28
+ }
29
+ </style>
30
+ """, unsafe_allow_html=True)
31
+
32
+ # App title
33
+ st.markdown('<p class="big-font">✍️ Handwritten Digit Recognizer</p>', unsafe_allow_html=True)
34
+ st.markdown("### πŸ”’ Draw or Upload a digit and get it recognized by our ML model!")
35
+
36
+ # Load model
37
  @st.cache_resource
38
  def load_mnist_model():
39
  return load_model("final_model.keras")
40
 
41
  model = load_mnist_model()
42
 
43
+ # Preprocessing
44
  def preprocess(img):
45
  img = ImageOps.grayscale(img)
46
  img = img.resize((200, 200))
47
  img = np.array(img)
 
 
48
  if np.mean(img) > 127:
49
  img = 255 - img
 
 
50
  _, img = cv2.threshold(img, 100, 255, cv2.THRESH_BINARY)
 
 
51
  coords = cv2.findNonZero(img)
52
  if coords is not None:
53
  x, y, w, h = cv2.boundingRect(coords)
54
  digit = img[y:y+h, x:x+w]
55
  else:
56
  return np.zeros((1, 28, 28), dtype="float32")
 
 
57
  digit = cv2.resize(digit, (20, 20), interpolation=cv2.INTER_AREA)
58
  digit = np.pad(digit, ((4, 4), (4, 4)), mode="constant", constant_values=0)
 
 
59
  digit = digit.astype("float32") / 255.0
60
  digit = digit.reshape(1, 28, 28)
61
  return digit
62
 
63
+ # Sidebar settings
64
+ st.sidebar.title("🎨 Drawing Settings")
65
+ mode = st.sidebar.selectbox("Drawing Tool", ("freedraw", "line"))
66
+ stroke_width = st.sidebar.slider("Stroke Width", 5, 25, 15)
67
+ stroke_color = st.sidebar.color_picker("Stroke Color", "#000000")
68
+ bg_color = st.sidebar.color_picker("Background Color", "#FFFFFF")
69
 
70
+ # Tabs for draw vs upload
71
+ tab1, tab2 = st.tabs(["πŸ–ŒοΈ Draw Digit", "πŸ“€ Upload Image"])
 
 
 
 
 
 
 
 
 
72
 
 
73
  input_img = None
 
 
 
 
74
 
75
+ # Tab 1: Draw
76
+ with tab1:
77
+ canvas_result = st_canvas(
78
+ stroke_width=stroke_width,
79
+ stroke_color=stroke_color,
80
+ background_color=bg_color,
81
+ height=200,
82
+ width=200,
83
+ drawing_mode=mode,
84
+ key="canvas",
85
+ )
86
+ if canvas_result.image_data is not None:
87
+ input_img = Image.fromarray(canvas_result.image_data.astype("uint8"))
88
+
89
+ # Tab 2: Upload
90
+ with tab2:
91
+ uploaded_file = st.file_uploader("Upload a digit image...", type=["jpg", "png"])
92
+ if uploaded_file:
93
+ input_img = Image.open(uploaded_file).convert("RGB")
94
+
95
+ # Prediction
96
  if input_img:
97
+ st.image(input_img, caption="πŸ” Input Image", width=150)
98
 
99
+ if st.button("🎯 Predict"):
100
  processed = preprocess(input_img)
101
  prediction = model.predict(processed, verbose=0)
102
+ digit = int(np.argmax(prediction))
103
+ confidence = float(np.max(prediction)) * 100
104
+
105
+ st.image(processed.reshape(28, 28), width=150, caption="πŸ§ͺ Preprocessed Image")
106
+ st.markdown(f"""
107
+ <div class="result-box">
108
+ 🧠 Predicted Digit: <strong>{digit}</strong><br/>
109
+ πŸ”Ž Confidence: <strong>{confidence:.2f}%</strong>
110
+ </div>
111
+ """, unsafe_allow_html=True)