File size: 6,953 Bytes
48a6c2d
4054926
 
 
4460dc8
4cafddd
 
 
 
4054926
4cafddd
4054926
c668f9a
 
 
4054926
 
 
 
4cafddd
 
 
4054926
 
 
 
 
 
 
4cafddd
 
 
 
 
 
4054926
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cafddd
 
 
4054926
 
4cafddd
 
 
 
 
4054926
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cafddd
 
4054926
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c668f9a
 
4054926
c668f9a
 
4054926
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import os
from pathlib import Path
import io

import streamlit as st
from streamlit_drawable_canvas import st_canvas
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array
from PIL import Image, ImageOps, ImageFilter

# ----------------- Quiet TF logs (optional) -----------------
# NOTE(review): tensorflow is imported above, before these are set, so options
# that TF reads at import time (e.g. TF_ENABLE_ONEDNN_OPTS) probably have no
# effect here; move these assignments above the tensorflow imports if they matter.
os.environ.update({"TF_CPP_MIN_LOG_LEVEL": "2", "TF_ENABLE_ONEDNN_OPTS": "0"})

# ----------------- Branding -----------------
PRIMARY_COLOR, ACCENT_COLOR, BG_COLOR = "#0F2C59", "#FFD700", "#F5F5F5"

st.set_page_config(page_title="Guess the Doodle - Procelevate", page_icon="✏️", layout="wide")

# ----------------- Sidebar / Branding -----------------
# Show the first logo file found (src/ layout first, then the repo root);
# nothing is shown when neither exists.
_logo_candidates = ("src/procelevate_logo.png", "procelevate_logo.png")
_logo = next((candidate for candidate in _logo_candidates if Path(candidate).exists()), None)
if _logo is not None:
    st.sidebar.image(_logo)

_sidebar_banner = (
    f"<h3 style='color:{PRIMARY_COLOR};'>🚀 Procelevate Academy</h3>"
    "<p>Where AI meets fun & learning!</p>"
)
st.sidebar.markdown(_sidebar_banner, unsafe_allow_html=True)

# ----------------- Load classes (guarantee order) -----------------
# If you saved a classes.txt from Colab with the exact CATEGORIES order, place it next to the model.
def load_classes():
    """Return the class names in the order used by the model's outputs.

    Looks for ``src/classes.txt`` first, then ``classes.txt`` (one name per
    line; surrounding whitespace and blank lines are ignored). The first file
    that yields at least one name wins; otherwise the hard-coded training order
    is returned.
    """
    for candidate in ("src/classes.txt", "classes.txt"):
        path = Path(candidate)
        if not path.exists():
            continue
        # "utf-8-sig" reads plain UTF-8 and also drops a leading byte-order mark,
        # which str.strip() would otherwise leave glued to the first class name.
        text = path.read_text(encoding="utf-8-sig")
        names = [line.strip() for line in text.splitlines() if line.strip()]
        if names:
            return names
    # fallback to the training order we used in Colab
    return ["cat", "dog", "car", "tree", "house", "sun", "airplane", "fish", "flower", "bird"]

CLASSES = load_classes()

# ----------------- Load model robustly -----------------
@st.cache_resource(show_spinner=False)
def load_doodle_model():
    """Load the first usable model file from the known locations.

    Candidates are tried in order (``src/`` before the repo root, ``.keras``
    before ``.h5``); a candidate is usable when it exists and is non-empty.
    Wrapped in ``st.cache_resource`` so the loaded model is reused across reruns.

    Returns:
        ``(model, path)`` for the first usable candidate, else ``(None, None)``.
    """
    candidates = (
        "src/quickdraw_model.keras",
        "quickdraw_model.keras",
        "src/quickdraw_model.h5",
        "quickdraw_model.h5",
    )
    for location in candidates:
        model_file = Path(location)
        if not model_file.exists() or model_file.stat().st_size == 0:
            continue
        return load_model(location), location
    return None, None

model, model_path = load_doodle_model()
if not model_path:
    st.error("Model file not found. Upload `quickdraw_model.keras` (or .h5).")
    st.stop()
else:
    st.caption(f"Model loaded: `{model_path}`")

# ----------------- Preprocessing controls -----------------
# Sidebar widgets; the preprocessing helper reads these module-level values.
st.sidebar.markdown("### ✨ Preprocessing")

invert_colors = st.sidebar.checkbox(
    "Invert colors",
    value=True,
    help=(
        "QuickDraw bitmaps are often drawn as black-on-white "
        "but some pipelines invert; try this ON."
    ),
)

# NOTE(review): the help text matches the inverted case only. With inversion
# off, pixels >= threshold become white background, so raising the threshold
# tends to thicken the (black) strokes rather than thin them.
binary_thresh = st.sidebar.slider(
    "Binarize threshold",
    min_value=0,
    max_value=255,
    value=180,
    help="Higher = thinner lines preserved; lower = thicker.",
)

thicken_strokes = st.sidebar.checkbox(
    "Thicken strokes",
    value=True,
    help="Applies a slight dilation to strengthen faint lines.",
)

show_processed = st.sidebar.checkbox("Show processed 28×28 preview", value=True)

# ----------------- Title -----------------
# Centered page heading; it carries brand colours, so raw HTML is enabled.
_title_html = f"""
    <div style='text-align:center;'>
        <h1 style='color:{PRIMARY_COLOR}; margin-bottom:0;'>🎨 Guess the Doodle – AI Game</h1>
        <p style='font-size:18px; margin-top:6px;'>Draw something below, then hit <b>Predict</b>.</p>
    </div>
    """
st.markdown(_title_html, unsafe_allow_html=True)

# ----------------- Canvas -----------------
# Streamlit keeps a widget's state across reruns while its `key` is unchanged,
# so a bare rerun does not clear the drawing. Reset bumps a counter that is
# part of the canvas key, which makes the next run mount a fresh, empty canvas.
if "canvas_nonce" not in st.session_state:
    st.session_state["canvas_nonce"] = 0

left, right = st.columns([1, 1.1])

with left:
    st.subheader("🖌️ Draw here")
    canvas_result = st_canvas(
        fill_color=BG_COLOR, stroke_width=10, stroke_color="black",
        background_color="white", width=280, height=280,
        drawing_mode="freedraw", key=f"canvas_{st.session_state['canvas_nonce']}",
    )
    c1, c2 = st.columns(2)
    with c1:
        do_pred = st.button("🔮 Predict")
    with c2:
        if st.button("♻️ Reset"):
            st.session_state["canvas_nonce"] += 1
            # `st.experimental_rerun` is deprecated in favour of `st.rerun`;
            # use the new API when this Streamlit version provides it.
            rerun = getattr(st, "rerun", None) or st.experimental_rerun
            rerun()

# ----------------- Preprocess helper -----------------
def preprocess_rgba_to_28x28(img_rgba: Image.Image, *, invert=None, thresh=None, thicken=None) -> Image.Image:
    """Turn a canvas snapshot into a tight, square, 28x28 grayscale model input.

    Pipeline: flatten onto white -> grayscale -> optional invert -> binarize ->
    crop to the drawing -> pad to a square -> optional stroke thickening ->
    nearest-neighbour resize to 28x28.

    Args:
        img_rgba: Canvas snapshot. Any transparency is composited onto white.
        invert: Invert before binarizing (ink becomes white on black).
            ``None`` uses the sidebar's ``invert_colors`` value.
        thresh: Binarization cutoff; pixels ``>= thresh`` become 255, the rest 0.
            ``None`` uses the sidebar's ``binary_thresh`` value.
        thicken: Grow the ink by one pixel before resizing.
            ``None`` uses the sidebar's ``thicken_strokes`` value.

    Returns:
        A 28x28 mode-"L" image. Ink is 255 on a 0 background when inverted,
        and 0 on a 255 background otherwise.
    """
    if invert is None:
        invert = invert_colors
    if thresh is None:
        thresh = binary_thresh
    if thicken is None:
        thicken = thicken_strokes

    # Flatten onto opaque white: a plain RGBA->L conversion ignores alpha, so any
    # fully transparent pixels (RGB 0) would otherwise read as black ink.
    rgba = img_rgba.convert("RGBA")
    paper = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    img = Image.alpha_composite(paper, rgba).convert("L")

    # Optional invert (try True first; many QuickDraw pipelines use inverted inputs)
    if invert:
        img = ImageOps.invert(img)

    # Binarize
    img = img.point(lambda x: 255 if x >= thresh else 0, mode="1").convert("L")

    # Inversion swaps which value is background; crop, pad and thicken all
    # depend on it, so derive it once here.
    background = 0 if invert else 255

    # Find bounding box of the drawing (every pixel that is not background)
    ys, xs = np.nonzero(np.asarray(img) != background)
    if xs.size == 0:
        # empty canvas fallback
        return img.resize((28, 28))

    # Crop to content
    img = img.crop((int(xs.min()), int(ys.min()), int(xs.max()) + 1, int(ys.max()) + 1))

    # Make square: centre the content, padding the short side with background
    w, h = img.size
    if w != h:
        side = max(w, h)
        left = (side - w) // 2
        top = (side - h) // 2
        img = ImageOps.expand(
            img, border=(left, top, side - w - left, side - h - top), fill=background
        )

    # Optional thicken: MaxFilter grows white ink, MinFilter grows black ink
    if thicken:
        grow = ImageFilter.MaxFilter(3) if invert else ImageFilter.MinFilter(3)
        img = img.filter(grow)

    # Resize to 28x28
    return img.resize((28, 28), resample=Image.NEAREST)

# ----------------- Prediction -----------------
with right:
    st.markdown(f"<h2 style='color:{ACCENT_COLOR};'>🤖 AI’s Best Guesses:</h2>", unsafe_allow_html=True)

    if do_pred and canvas_result.image_data is not None:
        # 1) Preprocess to 28x28
        raw = Image.fromarray((canvas_result.image_data).astype("uint8"))
        proc = preprocess_rgba_to_28x28(raw)

        if show_processed:
            st.caption("Model input preview (scaled up):")
            st.image(proc.resize((140, 140), resample=Image.NEAREST))

        # 2) Model predict: scale pixels to [0, 1] and add a batch axis
        arr = img_to_array(proc) / 255.0            # (28,28,1)
        arr = np.expand_dims(arr, axis=0)           # (1,28,28,1)
        preds = model.predict(arr, verbose=0)[0]    # one score per model output
        top = preds.argsort()[-3:][::-1]            # indices of the 3 best scores, best first

        # 3) Show results
        emoji_map = {"cat":"🐱","dog":"🐶","car":"🚗","tree":"🌳","house":"🏠",
                     "sun":"☀️","airplane":"✈️","fish":"🐟","flower":"🌸","bird":"🐦"}

        if preds[top[0]] >= 0.80:
            st.balloons()

        for idx in top:
            # classes.txt may list fewer names than the model has outputs;
            # show a generic label instead of raising IndexError.
            label = CLASSES[idx] if idx < len(CLASSES) else f"class {idx}"
            score = float(preds[idx])
            st.markdown(
                f"<h3 style='color:{PRIMARY_COLOR}; margin:6px 0 4px 0;'>{emoji_map.get(label,'🎯')} {label} – {score*100:.1f}%</h3>",
                unsafe_allow_html=True
            )
            # st.progress only accepts floats in [0, 1]; clamp in case the
            # model's final layer is not a softmax.
            st.progress(min(max(score, 0.0), 1.0))

# ----------------- Legend -----------------
# Built from CLASSES so the legend always matches the labels the app predicts
# (including a custom classes.txt); unknown names get a generic marker.
_LEGEND_EMOJI = {"cat": "🐱", "dog": "🐶", "car": "🚗", "tree": "🌳", "house": "🏠",
                 "sun": "☀️", "airplane": "✈️", "fish": "🐟", "flower": "🌸", "bird": "🐦"}
st.markdown(
    "<hr><h4>📖 Supported Classes:</h4>" +
    " ".join(f"{_LEGEND_EMOJI.get(name, '🎯')} {name}" for name in CLASSES),
    unsafe_allow_html=True
)
st.success("💡 Tips: draw big, single clear object; avoid tiny sketches; use continuous strokes.")