MBG0903 commited on
Commit
c668f9a
·
verified ·
1 Parent(s): bf80002

Update src/app.py

Browse files
Files changed (1) hide show
  1. src/app.py +67 -103
src/app.py CHANGED
@@ -1,6 +1,4 @@
1
  import os
2
- from pathlib import Path
3
-
4
  import streamlit as st
5
  from streamlit_drawable_canvas import st_canvas
6
  import numpy as np
@@ -8,125 +6,91 @@ from tensorflow.keras.models import load_model
8
  from tensorflow.keras.preprocessing.image import img_to_array
9
  from PIL import Image
10
 
11
- # -------------------
12
- # Branding & Settings
13
- # -------------------
 
 
14
  PRIMARY_COLOR = "#0F2C59" # Procelevate navy
15
- ACCENT_COLOR = "#FFD700" # Gold highlight
16
- BG_COLOR = "#F5F5F5" # Light background
 
 
 
 
 
 
 
 
17
  CLASSES = ["cat", "dog", "house", "car", "tree", "airplane", "sun", "fish", "flower", "bird"]
18
 
 
 
 
 
 
 
 
 
19
  st.set_page_config(page_title="Guess the Doodle - Procelevate", page_icon="✏️", layout="wide")
20
 
21
- # Sidebar
22
- st.sidebar.image("src/procelevate_logo.png") # also works if copied to /app/src
23
  st.sidebar.markdown(
24
  f"<h3 style='color:{PRIMARY_COLOR};'>🚀 Procelevate Academy</h3>"
25
  "<p>Where AI meets fun & learning!</p>",
26
  unsafe_allow_html=True
27
  )
28
 
29
- # -------------------
30
- # Robust model loader
31
- # -------------------
32
- @st.cache_resource(show_spinner=False)
33
- def load_doodle_model():
34
- candidates = [
35
- "quickdraw_model.keras",
36
- "src/quickdraw_model.keras",
37
- "quickdraw_model.h5",
38
- "src/quickdraw_model.h5",
39
- ]
40
- for p in candidates:
41
- if Path(p).exists() and Path(p).is_file() and Path(p).stat().st_size > 0:
42
- m = load_model(p)
43
- return m, p
44
- return None, None
45
-
46
- model, model_path = load_doodle_model()
47
- if model_path:
48
- st.caption(f"Model loaded: `{model_path}`")
49
- else:
50
- st.warning("Model file not found. Running in demo mode with mock predictions.")
51
-
52
- # -------------------
53
- # Title + helper text
54
- # -------------------
55
  st.markdown(
56
  f"""
57
  <div style='text-align:center;'>
58
- <h1 style='color:{PRIMARY_COLOR}; margin-bottom:0;'>🎨 Guess the Doodle – AI Game</h1>
59
- <p style='font-size:18px; margin-top:6px;'>Draw something below, and let AI guess it in real-time!</p>
60
  </div>
61
  """,
62
  unsafe_allow_html=True
63
  )
64
 
65
- # -------------
66
- # UI: 2 columns
67
- # -------------
68
- left, right = st.columns([1,1.1])
69
-
70
- with left:
71
- st.subheader("🖌️ Draw here")
72
- canvas_result = st_canvas(
73
- fill_color=BG_COLOR,
74
- stroke_width=10,
75
- stroke_color="black",
76
- background_color="white",
77
- width=280,
78
- height=280,
79
- drawing_mode="freedraw",
80
- key="canvas",
81
- )
82
- col_a, col_b = st.columns(2)
83
- with col_a:
84
- if st.button("Clear", type="secondary"):
85
- st.experimental_rerun()
86
- with col_b:
87
- st.caption("Tip: Try a 🐱 cat, 🚗 car, 🌳 tree, ☀️ sun")
88
-
89
- with right:
90
  st.markdown(f"<h2 style='color:{ACCENT_COLOR};'>🤖 AI’s Best Guesses:</h2>", unsafe_allow_html=True)
91
 
92
- if canvas_result.image_data is not None:
93
- # Preprocess drawing to 28x28 grayscale
94
- img = Image.fromarray((canvas_result.image_data).astype("uint8")).convert("L")
95
- img = img.resize((28, 28))
96
- img_array = img_to_array(img) / 255.0
97
- img_array = np.expand_dims(img_array, axis=0)
98
-
99
- # Predict
100
- if model is not None:
101
- preds = model.predict(img_array, verbose=0)[0]
102
- else:
103
- # Mock predictions so the demo UI still works if no model
104
- preds = np.random.dirichlet(np.ones(len(CLASSES)), size=1)[0]
105
-
106
- top_indices = preds.argsort()[-3:][::-1]
107
-
108
- # Emoji map
109
- emoji_map = {
110
- "cat":"🐱","dog":"🐶","car":"���","tree":"🌳",
111
- "house":"🏠","sun":"☀️","airplane":"✈️",
112
- "fish":"🐟","flower":"🌸","bird":"🐦"
113
- }
114
-
115
- # Confetti / balloons if high confidence on top-1
116
- if preds[top_indices[0]] >= 0.80:
117
- st.balloons()
118
-
119
- # Show guesses
120
- for idx in top_indices:
121
- label = CLASSES[idx]
122
- score = float(preds[idx])
123
- emoji = emoji_map.get(label, "🎯")
124
-
125
- st.markdown(
126
- f"<h3 style='color:{PRIMARY_COLOR}; margin:6px 0 4px 0;'>{emoji} {label} – {score*100:.1f}%</h3>",
127
- unsafe_allow_html=True
128
- )
129
- st.progress(score)
130
-
131
- # Footer hint
132
- st.success("💡 Try: cat 🐱, car 🚗, tree 🌳, sun ☀️ — short, simple strokes work best!")
 
1
  import os
 
 
2
  import streamlit as st
3
  from streamlit_drawable_canvas import st_canvas
4
  import numpy as np
 
6
  from tensorflow.keras.preprocessing.image import img_to_array
7
  from PIL import Image
8
 
9
+ # Reduce TensorFlow warnings
10
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
11
+ os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
12
+
13
+ # Branding Colors
14
  PRIMARY_COLOR = "#0F2C59" # Procelevate navy
15
+ ACCENT_COLOR = "#FFD700" # Gold highlight
16
+ BG_COLOR = "#F5F5F5" # Light background
17
+
18
+ # Load Model
19
+ try:
20
+ model = load_model("src/quickdraw_model.keras")
21
+ except Exception as e:
22
+ st.error(f"❌ Could not load model: {e}")
23
+ st.stop()
24
+
25
  CLASSES = ["cat", "dog", "house", "car", "tree", "airplane", "sun", "fish", "flower", "bird"]
26
 
27
+ # Emoji map for results
28
+ emoji_map = {
29
+ "cat": "🐱", "dog": "🐶", "car": "🚗", "tree": "🌳",
30
+ "house": "🏠", "sun": "☀️", "airplane": "✈️",
31
+ "fish": "🐟", "flower": "🌸", "bird": "🐦"
32
+ }
33
+
34
+ # Page Config
35
  st.set_page_config(page_title="Guess the Doodle - Procelevate", page_icon="✏️", layout="wide")
36
 
37
+ # Sidebar with logo
38
+ st.sidebar.image("src/procelevate_logo.png")
39
  st.sidebar.markdown(
40
  f"<h3 style='color:{PRIMARY_COLOR};'>🚀 Procelevate Academy</h3>"
41
  "<p>Where AI meets fun & learning!</p>",
42
  unsafe_allow_html=True
43
  )
44
 
45
+ # Title Section
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  st.markdown(
47
  f"""
48
  <div style='text-align:center;'>
49
+ <h1 style='color:{PRIMARY_COLOR};'>🎨 Guess the Doodle – AI Game</h1>
50
+ <p style='font-size:18px;'>Draw something below, then hit <b>Predict</b> to see AI’s guesses!</p>
51
  </div>
52
  """,
53
  unsafe_allow_html=True
54
  )
55
 
56
+ # Drawing Canvas
57
+ canvas_result = st_canvas(
58
+ fill_color=BG_COLOR, stroke_width=10, stroke_color="black",
59
+ background_color="white", width=280, height=280,
60
+ drawing_mode="freedraw", key="canvas",
61
+ )
62
+
63
+ # Predict Button
64
+ do_pred = st.button("🔮 Predict")
65
+
66
+ # Prediction Block
67
+ if do_pred and canvas_result.image_data is not None:
68
+ img = Image.fromarray((canvas_result.image_data).astype("uint8")).convert("L")
69
+ img = img.resize((28, 28))
70
+ img_array = img_to_array(img) / 255.0
71
+ img_array = np.expand_dims(img_array, axis=0)
72
+
73
+ preds = model.predict(img_array)[0]
74
+ top_indices = preds.argsort()[-3:][::-1]
75
+
 
 
 
 
 
76
  st.markdown(f"<h2 style='color:{ACCENT_COLOR};'>🤖 AI’s Best Guesses:</h2>", unsafe_allow_html=True)
77
 
78
+ for idx in top_indices:
79
+ emoji = emoji_map.get(CLASSES[idx], "🎯")
80
+ st.markdown(
81
+ f"<h3 style='color:{PRIMARY_COLOR};'>{emoji} {CLASSES[idx]} – {preds[idx]*100:.1f}%</h3>",
82
+ unsafe_allow_html=True
83
+ )
84
+ st.progress(float(preds[idx]))
85
+
86
+ st.balloons()
87
+
88
+ # Class Legend
89
+ st.markdown(
90
+ "<hr><h4>📖 Supported Classes:</h4>" +
91
+ " ".join([f"{emoji_map[c]} {c}" for c in CLASSES]),
92
+ unsafe_allow_html=True
93
+ )
94
+
95
+ # Tip
96
+ st.success("💡 Tip: Try drawing a cat 🐱, car 🚗, tree 🌳, or sun ☀️!")