import html
import io
import os
from pathlib import Path

import numpy as np
import streamlit as st
from PIL import Image, ImageOps, ImageFilter
from streamlit_drawable_canvas import st_canvas
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array
# ----------------- Quiet TF logs (optional) -----------------
# setdefault keeps any value the operator already exported (e.g.
# TF_CPP_MIN_LOG_LEVEL=0 while debugging) instead of silently overwriting it.
# NOTE(review): this module imports tensorflow.keras before reaching this
# point, and TensorFlow generally reads these variables when it is first
# loaded, so they may have no effect here. To be reliable they should be set
# before any tensorflow import.
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")  # hide INFO and WARNING C++ logs
os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0")  # turn off oneDNN custom ops
# ----------------- Branding -----------------
# Brand palette as CSS hex colours, available to the page's HTML snippets.
PRIMARY_COLOR, ACCENT_COLOR, BG_COLOR = "#0F2C59", "#FFD700", "#F5F5F5"

# Page-level settings. Streamlit has historically required this to be the
# first st.* call in the script, so keep it ahead of any other UI code.
st.set_page_config(
    layout="wide",
    page_icon="✏️",
    page_title="Guess the Doodle - Procelevate",
)
# ----------------- Sidebar / Branding -----------------
# The logo may sit under src/ or at the repo root depending on how the app was
# copied; show the first copy found (if any), checking src/ first.
_logo_path = next(
    (p for p in ("src/procelevate_logo.png", "procelevate_logo.png") if Path(p).exists()),
    None,
)
if _logo_path is not None:
    st.sidebar.image(_logo_path)

# Heading + tagline, rendered as raw HTML so the brand colour can be applied.
st.sidebar.markdown(
    f"<h2 style='color:{PRIMARY_COLOR};'>🚀 Procelevate Academy</h2>"
    "<p>Where AI meets fun & learning!</p>",
    unsafe_allow_html=True,
)
# ----------------- Load classes (guarantee order) -----------------
# If you saved a classes.txt from Colab with the exact CATEGORIES order, place
# it next to the model (under src/ or at the repo root).
def load_classes(search_paths=("src/classes.txt", "classes.txt")):
    """Return the class names in the order the model was trained with.

    Each path in *search_paths* is tried in turn. The first regular file with
    at least one non-blank line wins. Lines are stripped and blank lines are
    skipped. Files are decoded as UTF-8, and a leading byte-order mark (as
    written by some Windows editors) is ignored rather than glued onto the
    first class name.

    Args:
        search_paths: Candidate ``classes.txt`` locations, in priority order.

    Returns:
        A fresh list of class-name strings. When no usable file is found, this
        is the fallback order used for training in Colab.
    """
    for candidate in map(Path, search_paths):
        # is_file() also skips directories that happen to share the name.
        if not candidate.is_file():
            continue
        text = candidate.read_text(encoding="utf-8-sig")
        names = [line.strip() for line in text.splitlines() if line.strip()]
        if names:
            return names
    # fallback to the training order we used in Colab
    return ["cat", "dog", "car", "tree", "house", "sun", "airplane", "fish", "flower", "bird"]


CLASSES = load_classes()
# ----------------- Load model robustly -----------------
@st.cache_resource(show_spinner=False)
def load_doodle_model():
    """Load the first non-empty saved model found on disk.

    Native ``.keras`` files are preferred over legacy ``.h5`` ones, and for
    each format the ``src/`` copy is tried before the root copy. Empty files
    are skipped. ``st.cache_resource`` caches the result across reruns.

    Returns:
        ``(model, path)`` for the first candidate that exists with a non-zero
        size, otherwise ``(None, None)``.
    """
    candidates = (
        "src/quickdraw_model.keras",
        "quickdraw_model.keras",
        "src/quickdraw_model.h5",
        "quickdraw_model.h5",
    )
    chosen = next(
        (c for c in candidates if Path(c).exists() and Path(c).stat().st_size > 0),
        None,
    )
    if chosen is None:
        return None, None
    return load_model(chosen), chosen


model, model_path = load_doodle_model()
if not model_path:
    st.error("Model file not found. Upload `quickdraw_model.keras` (or .h5).")
    st.stop()  # ends this script run; nothing below executes
st.caption(f"Model loaded: `{model_path}`")
# ----------------- Preprocessing controls -----------------
# Sidebar widgets whose values drive the canvas -> 28x28 preprocessing step.
st.sidebar.markdown("### ✨ Preprocessing")
invert_colors = st.sidebar.checkbox(
    "Invert colors",
    value=True,
    help=(
        "The canvas draws black strokes on white, while QuickDraw's 28×28 "
        "bitmaps are light strokes on a dark background; keep this ON unless "
        "your model was trained on dark-on-light images."
    ),
)
binary_thresh = st.sidebar.slider(
    "Binarize threshold",
    0,
    255,
    180,
    help=(
        "Pixels at or above this value (after any inversion) become white, "
        "the rest black. With 'Invert colors' ON, higher = thinner strokes; "
        "with it OFF, higher = thicker strokes."
    ),
)
thicken_strokes = st.sidebar.checkbox(
    "Thicken strokes",
    value=True,
    help="Applies a slight dilation to strengthen faint lines.",
)
show_processed = st.sidebar.checkbox("Show processed 28×28 preview", value=True)
# ----------------- Title -----------------
# Page heading and one-line instructions shown above the canvas.
_TITLE_HTML = """
🎨 Guess the Doodle – AI Game
Draw something below, then hit Predict.
"""
st.markdown(_TITLE_HTML, unsafe_allow_html=True)
# ----------------- Canvas -----------------
# st_canvas keeps its drawing across reruns for as long as its `key` is
# unchanged, so a bare rerun cannot clear it. "Reset" instead bumps a counter
# that is part of the key, which makes the next run create a fresh, empty
# canvas. (This also drops the deprecated st.experimental_rerun.)
if "canvas_generation" not in st.session_state:
    st.session_state["canvas_generation"] = 0


def _reset_canvas():
    """Button callback: give the canvas a new key so it is re-created blank."""
    st.session_state["canvas_generation"] += 1


left, right = st.columns([1, 1.1])
with left:
    st.subheader("🖌️ Draw here")
    canvas_result = st_canvas(
        fill_color=BG_COLOR,
        stroke_width=10,
        stroke_color="black",
        background_color="white",
        width=280,
        height=280,
        drawing_mode="freedraw",
        key=f"canvas_{st.session_state['canvas_generation']}",
    )
    c1, c2 = st.columns(2)
    with c1:
        do_pred = st.button("🔮 Predict")
    with c2:
        # on_click callbacks run before the rerun that the click triggers, so
        # the new key is already in place when st_canvas is called again.
        st.button("♻️ Reset", on_click=_reset_canvas)
# ----------------- Preprocess helper -----------------
def _pad_to_square(img: Image.Image, fill: int) -> Image.Image:
    """Centre *img* on a square of side ``max(width, height)``, padding with *fill*."""
    w, h = img.size
    side = max(w, h)
    left = (side - w) // 2
    top = (side - h) // 2
    # Any odd leftover pixel goes on the right/bottom edge.
    return ImageOps.expand(
        img, border=(left, top, side - w - left, side - h - top), fill=fill
    )


def preprocess_rgba_to_28x28(
    img_rgba: Image.Image, *, invert=None, thresh=None, thicken=None
) -> Image.Image:
    """Turn a canvas snapshot into a 28x28 binary grayscale model input.

    Pipeline: grayscale -> optional invert -> binarize -> crop to the strokes
    -> pad to a centred square -> optional stroke dilation -> resize.

    Args:
        img_rgba: Canvas image (any PIL mode convertible to "L").
        invert: Invert before binarizing. ``None`` uses the sidebar
            "Invert colors" checkbox.
        thresh: Binarization threshold. Pixels >= thresh become 255, the rest
            become 0. ``None`` uses the sidebar slider.
        thicken: Dilate the strokes with a 3x3 window. ``None`` uses the
            sidebar "Thicken strokes" checkbox.

    Returns:
        A 28x28 mode-"L" image whose pixels are all 0 or 255. For a blank
        canvas this is a uniform background image.
    """
    # Unspecified options fall back to the sidebar widgets.
    if invert is None:
        invert = invert_colors
    if thresh is None:
        thresh = binary_thresh
    if thicken is None:
        thicken = thicken_strokes

    # Convert RGBA->L (the alpha channel is ignored)
    img = img_rgba.convert("L")
    if invert:
        img = ImageOps.invert(img)
    # Binarize: >= thresh -> 255 (white), otherwise 0 (black).
    img = img.point(lambda x: 255 if x >= thresh else 0, mode="1").convert("L")

    # The canvas is created with a white background. Un-inverted ink is
    # therefore dark on white, and inverted ink is light on black. Locate,
    # pad and grow the ink relative to that background rather than always
    # assuming a white one.
    background = 0 if invert else 255
    ys, xs = np.nonzero(np.asarray(img) != background)
    if xs.size == 0:
        # empty canvas fallback: every pixel is background
        return img.resize((28, 28), resample=Image.NEAREST)

    # Crop to the ink's bounding box, then centre it on a square.
    img = img.crop((int(xs.min()), int(ys.min()), int(xs.max()) + 1, int(ys.max()) + 1))
    img = _pad_to_square(img, fill=background)

    if thicken:
        # MaxFilter grows white pixels and MinFilter grows black ones; pick the
        # one that dilates the ink instead of eroding it.
        grow = ImageFilter.MaxFilter(3) if background == 0 else ImageFilter.MinFilter(3)
        img = img.filter(grow)

    # NEAREST keeps the output strictly 0/255.
    return img.resize((28, 28), resample=Image.NEAREST)
# ----------------- Prediction -----------------
with right:
    st.markdown(
        f"<h3 style='color:{PRIMARY_COLOR};'>🤖 AI’s Best Guesses:</h3>",
        unsafe_allow_html=True,
    )
    if do_pred and canvas_result.image_data is not None:
        canvas_pixels = np.asarray(canvas_result.image_data)
        if np.ptp(canvas_pixels) == 0:
            # Every pixel is identical, so nothing has been drawn yet; the
            # model would only be guessing at a blank image.
            st.info("Draw something on the canvas first, then hit Predict.")
        else:
            # 1) Preprocess to 28x28
            raw = Image.fromarray(canvas_pixels.astype("uint8"))
            proc = preprocess_rgba_to_28x28(raw)
            if show_processed:
                st.caption("Model input preview (scaled up):")
                st.image(proc.resize((140, 140), resample=Image.NEAREST))

            # 2) Model predict
            arr = img_to_array(proc) / 255.0   # (28, 28, 1) for a mode-"L" image
            arr = np.expand_dims(arr, axis=0)  # add batch axis -> (1, 28, 28, 1)
            preds = model.predict(arr, verbose=0)[0]  # assumes output (1, n_classes)
            top = preds.argsort()[-3:][::-1]   # indices of the 3 highest scores

            if len(preds) != len(CLASSES):
                # A classes.txt that doesn't match the model would otherwise
                # mislabel guesses or raise IndexError below.
                st.warning(
                    f"The model outputs {len(preds)} classes but {len(CLASSES)} "
                    "class names were loaded; check classes.txt."
                )

            # 3) Show results
            emoji_map = {"cat": "🐱", "dog": "🐶", "car": "🚗", "tree": "🌳", "house": "🏠",
                         "sun": "☀️", "airplane": "✈️", "fish": "🐟", "flower": "🌸", "bird": "🐦"}
            if preds[top[0]] >= 0.80:
                st.balloons()
            for idx in top:
                label = CLASSES[idx] if idx < len(CLASSES) else f"class {idx}"
                score = float(preds[idx])
                # Labels can come from classes.txt and are rendered as raw
                # HTML, so escape them.
                st.markdown(
                    f"<p>{emoji_map.get(label, '🎯')} <b>{html.escape(label)}</b>"
                    f" – {score * 100:.1f}%</p>",
                    unsafe_allow_html=True,
                )
                # st.progress rejects floats outside [0.0, 1.0]; clamp in case
                # the model's last layer is not a softmax.
                st.progress(min(max(score, 0.0), 1.0))
# ----------------- Legend -----------------
# Built from CLASSES so the legend always lists the labels actually in use
# (including a custom classes.txt). Names without an emoji get the generic 🎯.
_LEGEND_EMOJI = {"cat": "🐱", "dog": "🐶", "car": "🚗", "tree": "🌳", "house": "🏠",
                 "sun": "☀️", "airplane": "✈️", "fish": "🐟", "flower": "🌸", "bird": "🐦"}
st.markdown(
    "<h4>📖 Supported Classes:</h4>"
    + " ".join(
        f"{_LEGEND_EMOJI.get(name, '🎯')} {html.escape(name)}" for name in CLASSES
    ),
    unsafe_allow_html=True,
)
st.success("💡 Tips: draw big, single clear object; avoid tiny sketches; use continuous strokes.")