DOMMETI commited on
Commit
07499c5
·
verified ·
1 Parent(s): 001401b

Update Home.py

Browse files
Files changed (1) hide show
  1. Home.py +89 -16
Home.py CHANGED
@@ -4,21 +4,78 @@ from keras.models import load_model
4
  import numpy as np
5
  import cv2
6
 
7
- drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw", "line", "rect", "circle", "transform"))
8
- stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 10)
9
- stroke_color = st.sidebar.color_picker("Stroke color hex: ", "#FFFFFF")
10
- bg_color = st.sidebar.color_picker("Background color hex: ", "#000000")
11
- bg_image = st.sidebar.file_uploader("Background image:", type=["png", "jpg"])
12
- realtime_update = st.sidebar.checkbox("Update in realtime", True)
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  @st.cache_resource
15
  def load_mnist_model():
16
  return load_model("digit_recognization.keras")
17
 
18
  model = load_mnist_model()
19
 
 
20
  canvas_result = st_canvas(
21
- fill_color="rgba(255, 165, 0, 0.3)",
22
  stroke_width=stroke_width,
23
  stroke_color=stroke_color,
24
  background_color=bg_color,
@@ -26,15 +83,31 @@ canvas_result = st_canvas(
26
  height=280,
27
  width=280,
28
  drawing_mode=drawing_mode,
29
- key="canvas",
30
  )
31
 
32
-
33
  if canvas_result.image_data is not None:
34
- st.image(canvas_result.image_data, caption="Original Drawing")
35
- img = cv2.cvtColor(canvas_result.image_data.astype("uint8"), cv2.COLOR_RGBA2GRAY)
36
- img_resized = cv2.resize(img, (28, 28))
37
- img_normalized = img_resized / 255.0
38
- img_resh=img_normalized.reshape((1,28,28))
39
- prediction = model.predict(img_resh)
40
- st.write("Prediction:", np.argmax(prediction))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import numpy as np
5
  import cv2
6
 
7
+ # --------- Page Setup ---------
8
+ st.set_page_config(page_title="🎨 Digit Recognizer", layout="centered")
 
 
 
 
9
 
10
+ # --------- Custom CSS Styling ---------
11
+ st.markdown("""
12
+ <style>
13
+ @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@500&family=Poppins:wght@300;400;600&display=swap');
14
+
15
+ html, body, [class*="css"] {
16
+ font-family: 'Poppins', sans-serif;
17
+ background: linear-gradient(to right, #0f2027, #203a43, #2c5364);
18
+ color: white;
19
+ }
20
+
21
+ h1 {
22
+ font-family: 'Orbitron', sans-serif;
23
+ font-size: 3em;
24
+ color: #00FFFF;
25
+ text-align: center;
26
+ text-shadow: 0px 0px 12px rgba(0, 255, 255, 0.8);
27
+ }
28
+
29
+ .section {
30
+ background: rgba(255, 255, 255, 0.07);
31
+ padding: 20px;
32
+ border-radius: 15px;
33
+ box-shadow: 0px 4px 20px rgba(0, 255, 255, 0.2);
34
+ margin-bottom: 30px;
35
+ }
36
+
37
+ .prediction-box {
38
+ font-size: 2em;
39
+ font-weight: bold;
40
+ text-align: center;
41
+ color: #00ffff;
42
+ background-color: rgba(255, 255, 255, 0.1);
43
+ padding: 15px;
44
+ border-radius: 10px;
45
+ margin-top: 20px;
46
+ box-shadow: 0px 0px 12px rgba(0, 255, 255, 0.4);
47
+ }
48
+
49
+ .emoji {
50
+ font-size: 3em;
51
+ text-align: center;
52
+ margin-top: 10px;
53
+ }
54
+ </style>
55
+ """, unsafe_allow_html=True)
56
+
57
+ # --------- Title ---------
58
+ st.markdown("<h1>Digit Recognizer 🔢✨</h1>", unsafe_allow_html=True)
59
+ st.markdown("<div style='text-align:center; font-size:18px;'>Draw a digit from 0–9 and let the AI predict it!</div>", unsafe_allow_html=True)
60
+
61
+ # --------- Sidebar: Drawing Settings ---------
62
+ st.sidebar.title("🛠️ Tool Settings")
63
+ drawing_mode = st.sidebar.selectbox("Drawing Tool", ("freedraw", "line", "rect", "circle", "transform"))
64
+ stroke_width = st.sidebar.slider("Stroke Width", 1, 25, 10)
65
+ stroke_color = st.sidebar.color_picker("Stroke Color", "#FFFFFF")
66
+ bg_color = st.sidebar.color_picker("Background Color", "#000000")
67
+ realtime_update = st.sidebar.checkbox("Update in Realtime", True)
68
+
69
+ # --------- Load Model ---------
70
  @st.cache_resource
71
  def load_mnist_model():
72
  return load_model("digit_recognization.keras")
73
 
74
  model = load_mnist_model()
75
 
76
+ # --------- Canvas ---------
77
  canvas_result = st_canvas(
78
+ fill_color="rgba(255, 255, 255, 0.3)",
79
  stroke_width=stroke_width,
80
  stroke_color=stroke_color,
81
  background_color=bg_color,
 
83
  height=280,
84
  width=280,
85
  drawing_mode=drawing_mode,
86
+ key="canvas"
87
  )
88
 
89
+ # --------- Prediction ---------
90
  if canvas_result.image_data is not None:
91
+ st.markdown("<div class='section'>", unsafe_allow_html=True)
92
+
93
+ col1, col2 = st.columns(2)
94
+
95
+ with col1:
96
+ st.image(canvas_result.image_data, caption="Your Drawing", use_column_width=True)
97
+
98
+ with col2:
99
+ # Preprocess
100
+ img = cv2.cvtColor(canvas_result.image_data.astype("uint8"), cv2.COLOR_RGBA2GRAY)
101
+ img_resized = cv2.resize(img, (28, 28))
102
+ img_normalized = img_resized / 255.0
103
+ img_reshaped = img_normalized.reshape((1, 28, 28))
104
+
105
+ # Predict
106
+ prediction = model.predict(img_reshaped)
107
+ predicted_digit = np.argmax(prediction)
108
+
109
+ st.markdown(f"<div class='prediction-box'>Prediction: {predicted_digit}</div>", unsafe_allow_html=True)
110
+ st.markdown(f"<div class='emoji'>{['0️⃣','1️⃣','2️⃣','3️⃣','4️⃣','5️⃣','6️⃣','7️⃣','8️⃣','9️⃣'][predicted_digit]}</div>", unsafe_allow_html=True)
111
+
112
+ st.markdown("</div>", unsafe_allow_html=True)
113
+