MBG0903 commited on
Commit
48a6c2d
·
verified ·
1 Parent(s): bacb675

Update src/app.py

Browse files
Files changed (1) hide show
  1. src/app.py +102 -43
src/app.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import streamlit as st
2
  from streamlit_drawable_canvas import st_canvas
3
  import numpy as np
@@ -5,69 +8,125 @@ from tensorflow.keras.models import load_model
5
  from tensorflow.keras.preprocessing.image import img_to_array
6
  from PIL import Image
7
 
8
- # Branding Colors
 
 
9
  PRIMARY_COLOR = "#0F2C59" # Procelevate navy
10
- ACCENT_COLOR = "#FFD700" # Gold highlight
11
- BG_COLOR = "#F5F5F5" # Light background
12
-
13
- # Load Model (direct from same folder as app.py)
14
- model = load_model("quickdraw_model.keras")
15
  CLASSES = ["cat", "dog", "house", "car", "tree", "airplane", "sun", "fish", "flower", "bird"]
16
 
17
- # Page Config
18
  st.set_page_config(page_title="Guess the Doodle - Procelevate", page_icon="✏️", layout="wide")
19
 
20
- # Sidebar with logo
21
- st.sidebar.image("procelevate_logo.png", use_container_width=True)
22
  st.sidebar.markdown(
23
  f"<h3 style='color:{PRIMARY_COLOR};'>🚀 Procelevate Academy</h3>"
24
  "<p>Where AI meets fun & learning!</p>",
25
  unsafe_allow_html=True
26
  )
27
 
28
- # Title Section
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  st.markdown(
30
  f"""
31
  <div style='text-align:center;'>
32
- <h1 style='color:{PRIMARY_COLOR};'>🎨 Guess the Doodle – AI Game</h1>
33
- <p style='font-size:18px;'>Draw something below, and let AI guess it in real-time!</p>
34
  </div>
35
  """,
36
  unsafe_allow_html=True
37
  )
38
 
39
- # Drawing Canvas
40
- canvas_result = st_canvas(
41
- fill_color=BG_COLOR, stroke_width=10, stroke_color="black",
42
- background_color="white", width=280, height=280,
43
- drawing_mode="freedraw", key="canvas",
44
- )
45
-
46
- # Prediction
47
- if canvas_result.image_data is not None:
48
- img = Image.fromarray((canvas_result.image_data).astype("uint8")).convert("L")
49
- img = img.resize((28,28))
50
- img_array = img_to_array(img) / 255.0
51
- img_array = np.expand_dims(img_array, axis=0)
52
 
53
- preds = model.predict(img_array)[0]
54
- top_indices = preds.argsort()[-3:][::-1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
 
56
  st.markdown(f"<h2 style='color:{ACCENT_COLOR};'>🤖 AI’s Best Guesses:</h2>", unsafe_allow_html=True)
57
 
58
- # Emoji mapping
59
- emoji_map = {
60
- "cat":"🐱","dog":"🐶","car":"🚗","tree":"🌳",
61
- "house":"🏠","sun":"☀️","airplane":"✈️",
62
- "fish":"🐟","flower":"🌸","bird":"🐦"
63
- }
64
-
65
- for idx in top_indices:
66
- emoji = emoji_map.get(CLASSES[idx], "🎯")
67
- st.markdown(
68
- f"<h3 style='color:{PRIMARY_COLOR};'>{emoji} {CLASSES[idx]} – {preds[idx]*100:.1f}%</h3>",
69
- unsafe_allow_html=True
70
- )
71
- st.progress(float(preds[idx]))
72
-
73
- st.success("💡 Tip: Try drawing a cat 🐱, car 🚗, tree 🌳, or sun ☀️!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+
4
  import streamlit as st
5
  from streamlit_drawable_canvas import st_canvas
6
  import numpy as np
 
8
  from tensorflow.keras.preprocessing.image import img_to_array
9
  from PIL import Image
10
 
11
+ # -------------------
12
+ # Branding & Settings
13
+ # -------------------
14
  PRIMARY_COLOR = "#0F2C59" # Procelevate navy
15
+ ACCENT_COLOR = "#FFD700" # Gold highlight
16
+ BG_COLOR = "#F5F5F5" # Light background
 
 
 
17
  CLASSES = ["cat", "dog", "house", "car", "tree", "airplane", "sun", "fish", "flower", "bird"]
18
 
 
19
  st.set_page_config(page_title="Guess the Doodle - Procelevate", page_icon="✏️", layout="wide")
20
 
21
+ # Sidebar
22
+ st.sidebar.image("src/procelevate_logo.png", use_container_width=True) # also works if copied to /app/src
23
  st.sidebar.markdown(
24
  f"<h3 style='color:{PRIMARY_COLOR};'>🚀 Procelevate Academy</h3>"
25
  "<p>Where AI meets fun & learning!</p>",
26
  unsafe_allow_html=True
27
  )
28
 
29
+ # -------------------
30
+ # Robust model loader
31
+ # -------------------
32
+ @st.cache_resource(show_spinner=False)
33
+ def load_doodle_model():
34
+ candidates = [
35
+ "quickdraw_model.keras",
36
+ "src/quickdraw_model.keras",
37
+ "quickdraw_model.h5",
38
+ "src/quickdraw_model.h5",
39
+ ]
40
+ for p in candidates:
41
+ if Path(p).exists() and Path(p).is_file() and Path(p).stat().st_size > 0:
42
+ m = load_model(p)
43
+ return m, p
44
+ return None, None
45
+
46
+ model, model_path = load_doodle_model()
47
+ if model_path:
48
+ st.caption(f"Model loaded: `{model_path}`")
49
+ else:
50
+ st.warning("Model file not found. Running in demo mode with mock predictions.")
51
+
52
+ # -------------------
53
+ # Title + helper text
54
+ # -------------------
55
  st.markdown(
56
  f"""
57
  <div style='text-align:center;'>
58
+ <h1 style='color:{PRIMARY_COLOR}; margin-bottom:0;'>🎨 Guess the Doodle – AI Game</h1>
59
+ <p style='font-size:18px; margin-top:6px;'>Draw something below, and let AI guess it in real-time!</p>
60
  </div>
61
  """,
62
  unsafe_allow_html=True
63
  )
64
 
65
+ # -------------
66
+ # UI: 2 columns
67
+ # -------------
68
+ left, right = st.columns([1,1.1])
 
 
 
 
 
 
 
 
 
69
 
70
+ with left:
71
+ st.subheader("🖌️ Draw here")
72
+ canvas_result = st_canvas(
73
+ fill_color=BG_COLOR,
74
+ stroke_width=10,
75
+ stroke_color="black",
76
+ background_color="white",
77
+ width=280,
78
+ height=280,
79
+ drawing_mode="freedraw",
80
+ key="canvas",
81
+ )
82
+ col_a, col_b = st.columns(2)
83
+ with col_a:
84
+ if st.button("Clear", type="secondary"):
85
+ st.experimental_rerun()
86
+ with col_b:
87
+ st.caption("Tip: Try a 🐱 cat, 🚗 car, 🌳 tree, ☀️ sun")
88
 
89
+ with right:
90
  st.markdown(f"<h2 style='color:{ACCENT_COLOR};'>🤖 AI’s Best Guesses:</h2>", unsafe_allow_html=True)
91
 
92
+ if canvas_result.image_data is not None:
93
+ # Preprocess drawing to 28x28 grayscale
94
+ img = Image.fromarray((canvas_result.image_data).astype("uint8")).convert("L")
95
+ img = img.resize((28, 28))
96
+ img_array = img_to_array(img) / 255.0
97
+ img_array = np.expand_dims(img_array, axis=0)
98
+
99
+ # Predict
100
+ if model is not None:
101
+ preds = model.predict(img_array, verbose=0)[0]
102
+ else:
103
+ # Mock predictions so the demo UI still works if no model
104
+ preds = np.random.dirichlet(np.ones(len(CLASSES)), size=1)[0]
105
+
106
+ top_indices = preds.argsort()[-3:][::-1]
107
+
108
+ # Emoji map
109
+ emoji_map = {
110
+ "cat":"🐱","dog":"🐶","car":"🚗","tree":"🌳",
111
+ "house":"🏠","sun":"☀️","airplane":"✈️",
112
+ "fish":"🐟","flower":"🌸","bird":"🐦"
113
+ }
114
+
115
+ # Confetti / balloons if high confidence on top-1
116
+ if preds[top_indices[0]] >= 0.80:
117
+ st.balloons()
118
+
119
+ # Show guesses
120
+ for idx in top_indices:
121
+ label = CLASSES[idx]
122
+ score = float(preds[idx])
123
+ emoji = emoji_map.get(label, "🎯")
124
+
125
+ st.markdown(
126
+ f"<h3 style='color:{PRIMARY_COLOR}; margin:6px 0 4px 0;'>{emoji} {label} – {score*100:.1f}%</h3>",
127
+ unsafe_allow_html=True
128
+ )
129
+ st.progress(score)
130
+
131
+ # Footer hint
132
+ st.success("💡 Try: cat 🐱, car 🚗, tree 🌳, sun ☀️ — short, simple strokes work best!")