MBG0903 commited on
Commit
4054926
·
verified ·
1 Parent(s): c668f9a

Update src/app.py

Browse files
Files changed (1) hide show
  1. src/app.py +152 -63
src/app.py CHANGED
@@ -1,96 +1,185 @@
1
  import os
 
 
 
2
  import streamlit as st
3
  from streamlit_drawable_canvas import st_canvas
4
  import numpy as np
5
  from tensorflow.keras.models import load_model
6
  from tensorflow.keras.preprocessing.image import img_to_array
7
- from PIL import Image
8
 
9
- # Reduce TensorFlow warnings
10
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
11
  os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
12
 
13
- # Branding Colors
14
- PRIMARY_COLOR = "#0F2C59" # Procelevate navy
15
- ACCENT_COLOR = "#FFD700" # Gold highlight
16
- BG_COLOR = "#F5F5F5" # Light background
17
-
18
- # Load Model
19
- try:
20
- model = load_model("src/quickdraw_model.keras")
21
- except Exception as e:
22
- st.error(f"❌ Could not load model: {e}")
23
- st.stop()
24
-
25
- CLASSES = ["cat", "dog", "house", "car", "tree", "airplane", "sun", "fish", "flower", "bird"]
26
 
27
- # Emoji map for results
28
- emoji_map = {
29
- "cat": "🐱", "dog": "🐶", "car": "🚗", "tree": "🌳",
30
- "house": "🏠", "sun": "☀️", "airplane": "✈️",
31
- "fish": "🐟", "flower": "🌸", "bird": "🐦"
32
- }
33
-
34
- # Page Config
35
  st.set_page_config(page_title="Guess the Doodle - Procelevate", page_icon="✏️", layout="wide")
36
 
37
- # Sidebar with logo
38
- st.sidebar.image("src/procelevate_logo.png")
 
 
 
 
 
39
  st.sidebar.markdown(
40
  f"<h3 style='color:{PRIMARY_COLOR};'>🚀 Procelevate Academy</h3>"
41
  "<p>Where AI meets fun & learning!</p>",
42
  unsafe_allow_html=True
43
  )
44
 
45
- # Title Section
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  st.markdown(
47
  f"""
48
  <div style='text-align:center;'>
49
- <h1 style='color:{PRIMARY_COLOR};'>🎨 Guess the Doodle – AI Game</h1>
50
- <p style='font-size:18px;'>Draw something below, then hit <b>Predict</b> to see AI’s guesses!</p>
51
  </div>
52
  """,
53
  unsafe_allow_html=True
54
  )
55
 
56
- # Drawing Canvas
57
- canvas_result = st_canvas(
58
- fill_color=BG_COLOR, stroke_width=10, stroke_color="black",
59
- background_color="white", width=280, height=280,
60
- drawing_mode="freedraw", key="canvas",
61
- )
62
-
63
- # Predict Button
64
- do_pred = st.button("🔮 Predict")
65
-
66
- # Prediction Block
67
- if do_pred and canvas_result.image_data is not None:
68
- img = Image.fromarray((canvas_result.image_data).astype("uint8")).convert("L")
69
- img = img.resize((28, 28))
70
- img_array = img_to_array(img) / 255.0
71
- img_array = np.expand_dims(img_array, axis=0)
72
-
73
- preds = model.predict(img_array)[0]
74
- top_indices = preds.argsort()[-3:][::-1]
75
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  st.markdown(f"<h2 style='color:{ACCENT_COLOR};'>🤖 AI’s Best Guesses:</h2>", unsafe_allow_html=True)
77
 
78
- for idx in top_indices:
79
- emoji = emoji_map.get(CLASSES[idx], "🎯")
80
- st.markdown(
81
- f"<h3 style='color:{PRIMARY_COLOR};'>{emoji} {CLASSES[idx]} – {preds[idx]*100:.1f}%</h3>",
82
- unsafe_allow_html=True
83
- )
84
- st.progress(float(preds[idx]))
85
-
86
- st.balloons()
87
-
88
- # Class Legend
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  st.markdown(
90
  "<hr><h4>📖 Supported Classes:</h4>" +
91
- " ".join([f"{emoji_map[c]} {c}" for c in CLASSES]),
92
  unsafe_allow_html=True
93
  )
94
-
95
- # Tip
96
- st.success("💡 Tip: Try drawing a cat 🐱, car 🚗, tree 🌳, or sun ☀️!")
 
1
  import os
2
+ from pathlib import Path
3
+ import io
4
+
5
  import streamlit as st
6
  from streamlit_drawable_canvas import st_canvas
7
  import numpy as np
8
  from tensorflow.keras.models import load_model
9
  from tensorflow.keras.preprocessing.image import img_to_array
10
+ from PIL import Image, ImageOps, ImageFilter
11
 
12
+ # ----------------- Quiet TF logs (optional) -----------------
13
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
14
  os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
15
 
16
+ # ----------------- Branding -----------------
17
+ PRIMARY_COLOR = "#0F2C59"
18
+ ACCENT_COLOR = "#FFD700"
19
+ BG_COLOR = "#F5F5F5"
 
 
 
 
 
 
 
 
 
20
 
 
 
 
 
 
 
 
 
21
  st.set_page_config(page_title="Guess the Doodle - Procelevate", page_icon="✏️", layout="wide")
22
 
23
+ # ----------------- Sidebar / Branding -----------------
24
+ # logo path works whether you're copying root or /src
25
+ if Path("src/procelevate_logo.png").exists():
26
+ st.sidebar.image("src/procelevate_logo.png")
27
+ elif Path("procelevate_logo.png").exists():
28
+ st.sidebar.image("procelevate_logo.png")
29
+
30
  st.sidebar.markdown(
31
  f"<h3 style='color:{PRIMARY_COLOR};'>🚀 Procelevate Academy</h3>"
32
  "<p>Where AI meets fun & learning!</p>",
33
  unsafe_allow_html=True
34
  )
35
 
36
+ # ----------------- Load classes (guarantee order) -----------------
37
+ # If you saved a classes.txt from Colab with the exact CATEGORIES order, place it next to the model.
38
+ def load_classes():
39
+ for p in ["src/classes.txt", "classes.txt"]:
40
+ if Path(p).exists():
41
+ text = Path(p).read_text().strip()
42
+ arr = [x.strip() for x in text.splitlines() if x.strip()]
43
+ if arr:
44
+ return arr
45
+ # fallback to the training order we used in Colab
46
+ return ["cat","dog","car","tree","house","sun","airplane","fish","flower","bird"]
47
+
48
+ CLASSES = load_classes()
49
+
50
+ # ----------------- Load model robustly -----------------
51
+ @st.cache_resource(show_spinner=False)
52
+ def load_doodle_model():
53
+ for p in ["src/quickdraw_model.keras", "quickdraw_model.keras", "src/quickdraw_model.h5", "quickdraw_model.h5"]:
54
+ if Path(p).exists() and Path(p).stat().st_size > 0:
55
+ return load_model(p), p
56
+ return None, None
57
+
58
+ model, model_path = load_doodle_model()
59
+ if model_path:
60
+ st.caption(f"Model loaded: `{model_path}`")
61
+ else:
62
+ st.error("Model file not found. Upload `quickdraw_model.keras` (or .h5).")
63
+ st.stop()
64
+
65
+ # ----------------- Preprocessing controls -----------------
66
+ st.sidebar.markdown("### ✨ Preprocessing")
67
+ invert_colors = st.sidebar.checkbox("Invert colors", value=True,
68
+ help="QuickDraw bitmaps are often drawn as black-on-white but some pipelines invert; try this ON.")
69
+ binary_thresh = st.sidebar.slider("Binarize threshold", 0, 255, 180,
70
+ help="Higher = thinner lines preserved; lower = thicker.")
71
+ thicken_strokes = st.sidebar.checkbox("Thicken strokes", value=True,
72
+ help="Applies a slight dilation to strengthen faint lines.")
73
+ show_processed = st.sidebar.checkbox("Show processed 28×28 preview", value=True)
74
+
75
+ # ----------------- Title -----------------
76
  st.markdown(
77
  f"""
78
  <div style='text-align:center;'>
79
+ <h1 style='color:{PRIMARY_COLOR}; margin-bottom:0;'>🎨 Guess the Doodle – AI Game</h1>
80
+ <p style='font-size:18px; margin-top:6px;'>Draw something below, then hit <b>Predict</b>.</p>
81
  </div>
82
  """,
83
  unsafe_allow_html=True
84
  )
85
 
86
+ # ----------------- Canvas -----------------
87
+ left, right = st.columns([1,1.1])
88
+
89
+ with left:
90
+ st.subheader("🖌️ Draw here")
91
+ canvas_result = st_canvas(
92
+ fill_color=BG_COLOR, stroke_width=10, stroke_color="black",
93
+ background_color="white", width=280, height=280,
94
+ drawing_mode="freedraw", key="canvas",
95
+ )
96
+ c1, c2 = st.columns(2)
97
+ with c1:
98
+ do_pred = st.button("🔮 Predict")
99
+ with c2:
100
+ if st.button("♻️ Reset"):
101
+ st.experimental_rerun()
102
+
103
+ # ----------------- Preprocess helper -----------------
104
+ def preprocess_rgba_to_28x28(img_rgba: Image.Image) -> Image.Image:
105
+ # Convert RGBA->L
106
+ img = img_rgba.convert("L")
107
+
108
+ # Optional invert (try True first; many QuickDraw pipelines use inverted inputs)
109
+ if invert_colors:
110
+ img = ImageOps.invert(img)
111
+
112
+ # Binarize
113
+ img = img.point(lambda x: 255 if x >= binary_thresh else 0, mode="1").convert("L")
114
+
115
+ # Find bounding box of the drawing
116
+ np_img = np.array(img)
117
+ ys, xs = np.where(np_img < 255) # non-white pixels
118
+ if len(xs) == 0 or len(ys) == 0:
119
+ # empty canvas fallback
120
+ return img.resize((28, 28))
121
+
122
+ xmin, xmax = xs.min(), xs.max()
123
+ ymin, ymax = ys.min(), ys.max()
124
+
125
+ # Crop to content
126
+ img = img.crop((xmin, ymin, xmax+1, ymax+1))
127
+
128
+ # Make square (pad)
129
+ w, h = img.size
130
+ pad = abs(w - h) // 2
131
+ if w > h:
132
+ img = ImageOps.expand(img, border=(0, pad, 0, w - h - pad), fill=255)
133
+ elif h > w:
134
+ img = ImageOps.expand(img, border=(pad, 0, w - h - pad, 0), fill=255)
135
+
136
+ # Optional thicken
137
+ if thicken_strokes:
138
+ img = img.filter(ImageFilter.MaxFilter(3)) # light dilation
139
+
140
+ # Resize to 28x28
141
+ img = img.resize((28, 28), resample=Image.NEAREST)
142
+ return img
143
+
144
+ # ----------------- Prediction -----------------
145
+ with right:
146
  st.markdown(f"<h2 style='color:{ACCENT_COLOR};'>🤖 AI’s Best Guesses:</h2>", unsafe_allow_html=True)
147
 
148
+ if do_pred and canvas_result.image_data is not None:
149
+ # 1) Preprocess to 28x28
150
+ raw = Image.fromarray((canvas_result.image_data).astype("uint8"))
151
+ proc = preprocess_rgba_to_28x28(raw)
152
+
153
+ if show_processed:
154
+ st.caption("Model input preview (scaled up):")
155
+ st.image(proc.resize((140, 140), resample=Image.NEAREST))
156
+
157
+ # 2) Model predict
158
+ arr = img_to_array(proc) / 255.0 # (28,28,1)
159
+ arr = np.expand_dims(arr, axis=0) # (1,28,28,1)
160
+ preds = model.predict(arr, verbose=0)[0] # (10,)
161
+ top = preds.argsort()[-3:][::-1]
162
+
163
+ # 3) Show results
164
+ emoji_map = {"cat":"🐱","dog":"🐶","car":"🚗","tree":"🌳","house":"🏠",
165
+ "sun":"☀️","airplane":"✈️","fish":"🐟","flower":"🌸","bird":"🐦"}
166
+
167
+ if preds[top[0]] >= 0.80:
168
+ st.balloons()
169
+
170
+ for idx in top:
171
+ label = CLASSES[idx]
172
+ score = float(preds[idx])
173
+ st.markdown(
174
+ f"<h3 style='color:{PRIMARY_COLOR}; margin:6px 0 4px 0;'>{emoji_map.get(label,'🎯')} {label} – {score*100:.1f}%</h3>",
175
+ unsafe_allow_html=True
176
+ )
177
+ st.progress(score)
178
+
179
+ # ----------------- Legend -----------------
180
  st.markdown(
181
  "<hr><h4>📖 Supported Classes:</h4>" +
182
+ " ".join([f"🐱 cat", "🐶 dog", "🚗 car", "🌳 tree", "🏠 house", "☀️ sun", "✈️ airplane", "🐟 fish", "🌸 flower", "🐦 bird"]),
183
  unsafe_allow_html=True
184
  )
185
+ st.success("💡 Tips: draw big, single clear object; avoid tiny sketches; use continuous strokes.")