Spaces:
Sleeping
Sleeping
import io
import os
from pathlib import Path

# ----------------- Quiet TF logs (optional) -----------------
# These variables must be set BEFORE the first ``tensorflow`` import:
# setting them after the import (as this file used to) has no effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

# The imports below intentionally follow the environment setup above.
import numpy as np  # noqa: E402
import streamlit as st  # noqa: E402
from PIL import Image, ImageFilter, ImageOps  # noqa: E402
from streamlit_drawable_canvas import st_canvas  # noqa: E402
from tensorflow.keras.models import load_model  # noqa: E402
from tensorflow.keras.preprocessing.image import img_to_array  # noqa: E402
# ----------------- Branding -----------------
# Palette shared by the HTML snippets rendered further down the page.
PRIMARY_COLOR = "#0F2C59"
ACCENT_COLOR = "#FFD700"
BG_COLOR = "#F5F5F5"

st.set_page_config(
    page_title="Guess the Doodle - Procelevate",
    page_icon="✏️",
    layout="wide",
)

# ----------------- Sidebar / Branding -----------------
# The logo may live under src/ or next to the script; show the first one found.
for _logo_path in ("src/procelevate_logo.png", "procelevate_logo.png"):
    if Path(_logo_path).exists():
        st.sidebar.image(_logo_path)
        break

_sidebar_banner = (
    f"<h3 style='color:{PRIMARY_COLOR};'>🚀 Procelevate Academy</h3>"
    "<p>Where AI meets fun & learning!</p>"
)
st.sidebar.markdown(_sidebar_banner, unsafe_allow_html=True)
# ----------------- Load classes (guarantee order) -----------------
# If you saved a classes.txt from Colab with the exact CATEGORIES order, place it next to the model.
def load_classes():
    """Return the class labels, in the order the model's outputs are indexed.

    ``src/classes.txt`` and then ``classes.txt`` are tried. A file is expected
    to hold one label per line; surrounding whitespace is stripped and blank
    lines are ignored. The first file that yields at least one label wins.

    Returns:
        list[str]: labels from the first usable file, or the built-in
        ten-label default when no file provides any.
    """
    for candidate in ("src/classes.txt", "classes.txt"):
        try:
            # "utf-8-sig" decodes the same on every host locale and drops a
            # leading BOM, which would otherwise be glued to the first label.
            text = Path(candidate).read_text(encoding="utf-8-sig")
        except (FileNotFoundError, NotADirectoryError):
            continue  # this candidate is absent; try the next location
        labels = [line.strip() for line in text.splitlines() if line.strip()]
        if labels:
            return labels
    # fallback to the training order we used in Colab
    return ["cat", "dog", "car", "tree", "house", "sun", "airplane", "fish", "flower", "bird"]


CLASSES = load_classes()
# ----------------- Load model robustly -----------------
def load_doodle_model():
    """Load the first usable model file and report which one was used.

    Candidate paths are checked in a fixed order (``.keras`` before ``.h5``,
    ``src/`` before the working directory). A candidate counts as usable when
    it exists and is not zero bytes long.

    Returns:
        tuple: ``(model, path)`` for the first usable candidate, or
        ``(None, None)`` when there is none.
    """
    candidates = (
        "src/quickdraw_model.keras",
        "quickdraw_model.keras",
        "src/quickdraw_model.h5",
        "quickdraw_model.h5",
    )
    usable = (
        name
        for name in candidates
        if Path(name).exists() and Path(name).stat().st_size > 0
    )
    chosen = next(usable, None)
    if chosen is None:
        return None, None
    return load_model(chosen), chosen
class _ModelFileMissing(Exception):
    """Signals that no model file was found, so that outcome is never cached."""


def _load_model_or_raise():
    """Return ``(model, path)`` from load_doodle_model(); raise if no file was found."""
    loaded, path = load_doodle_model()
    if path is None:
        raise _ModelFileMissing
    return loaded, path


# Streamlit re-runs this script on every widget interaction. Cache the loaded
# model so it is not read from disk again each time. Older Streamlit releases
# have no ``st.cache_resource``; there the loader is simply called directly.
_cache_resource = getattr(st, "cache_resource", None)
_get_model = _cache_resource(_load_model_or_raise) if _cache_resource else _load_model_or_raise

try:
    model, model_path = _get_model()
except _ModelFileMissing:
    model, model_path = None, None

if model_path:
    st.caption(f"Model loaded: `{model_path}`")
else:
    st.error("Model file not found. Upload `quickdraw_model.keras` (or .h5).")
    st.stop()
# ----------------- Preprocessing controls -----------------
# Sidebar widgets; their values are read as module globals by the
# preprocessing helper and the prediction section.
st.sidebar.markdown("### ✨ Preprocessing")

invert_colors = st.sidebar.checkbox(
    "Invert colors",
    value=True,
    help="QuickDraw bitmaps are often drawn as black-on-white but some pipelines invert; try this ON.",
)

binary_thresh = st.sidebar.slider(
    "Binarize threshold",
    min_value=0,
    max_value=255,
    value=180,
    help="Higher = thinner lines preserved; lower = thicker.",
)

thicken_strokes = st.sidebar.checkbox(
    "Thicken strokes",
    value=True,
    help="Applies a slight dilation to strengthen faint lines.",
)

show_processed = st.sidebar.checkbox("Show processed 28×28 preview", value=True)
# ----------------- Title -----------------
# Centered page heading plus a one-line instruction, rendered as raw HTML.
_title_html = f"""
<div style='text-align:center;'>
<h1 style='color:{PRIMARY_COLOR}; margin-bottom:0;'>🎨 Guess the Doodle – AI Game</h1>
<p style='font-size:18px; margin-top:6px;'>Draw something below, then hit <b>Predict</b>.</p>
</div>
"""
st.markdown(_title_html, unsafe_allow_html=True)
# ----------------- Canvas -----------------
left, right = st.columns([1, 1.1])
with left:
    st.subheader("🖌️ Draw here")

    # A component keeps its state for as long as its key is unchanged, so a
    # plain rerun would leave the drawing in place. Reset bumps this counter,
    # which gives the canvas a new key and therefore a blank surface.
    if "canvas_version" not in st.session_state:
        st.session_state["canvas_version"] = 0

    canvas_result = st_canvas(
        fill_color=BG_COLOR, stroke_width=10, stroke_color="black",
        background_color="white", width=280, height=280,
        drawing_mode="freedraw",
        key=f"canvas_{st.session_state['canvas_version']}",
    )

    c1, c2 = st.columns(2)
    with c1:
        do_pred = st.button("🔮 Predict")
    with c2:
        if st.button("♻️ Reset"):
            st.session_state["canvas_version"] += 1
            # st.rerun replaces st.experimental_rerun on current Streamlit;
            # fall back to the old name on releases that predate it.
            _rerun = getattr(st, "rerun", None) or st.experimental_rerun
            _rerun()
# ----------------- Preprocess helper -----------------
def preprocess_rgba_to_28x28(img_rgba: Image.Image) -> Image.Image:
    """Convert a canvas snapshot into the 28x28 grayscale image fed to the model.

    Pipeline: grayscale -> optional inversion -> binarize -> crop to the ink's
    bounding box -> pad to a square -> optional stroke thickening -> resize.

    The sidebar settings ``invert_colors``, ``binary_thresh`` and
    ``thicken_strokes`` are read from module scope.

    Args:
        img_rgba: The drawing as a PIL image. The caller passes the canvas
            pixels, configured as black strokes on a white background.

    Returns:
        A 28x28 mode-"L" image containing only the values 0 and 255.
    """
    # Convert RGBA->L
    img = img_rgba.convert("L")
    # Optional invert (try True first; many QuickDraw pipelines use inverted inputs)
    if invert_colors:
        img = ImageOps.invert(img)
    # Binarize
    img = img.point(lambda x: 255 if x >= binary_thresh else 0, mode="1").convert("L")

    # The strokes are drawn dark on a light canvas, so after binarization the
    # ink is 255 on a 0 background when inverted, and 0 on 255 otherwise.
    # Every later step must use the matching polarity.
    ink, paper = (255, 0) if invert_colors else (0, 255)

    # Find bounding box of the drawing (ink pixels only, never the background)
    ys, xs = np.where(np.array(img) == ink)
    if len(xs) == 0:
        # empty canvas fallback
        return img.resize((28, 28))

    # Crop to content
    img = img.crop((int(xs.min()), int(ys.min()), int(xs.max()) + 1, int(ys.max()) + 1))

    # Make square: split the missing pixels between the two short sides and
    # fill them with the background value.
    w, h = img.size
    missing = abs(w - h)
    before = missing // 2
    after = missing - before
    if w > h:
        img = ImageOps.expand(img, border=(0, before, 0, after), fill=paper)
    elif h > w:
        img = ImageOps.expand(img, border=(before, 0, after, 0), fill=paper)

    # Optional thicken: grow the ink, whichever value it has. MaxFilter
    # spreads bright pixels and MinFilter spreads dark ones.
    if thicken_strokes:
        grow = ImageFilter.MaxFilter(3) if ink == 255 else ImageFilter.MinFilter(3)
        img = img.filter(grow)

    # Resize to 28x28
    return img.resize((28, 28), resample=Image.NEAREST)
# ----------------- Prediction -----------------
# Runs only on the rerun triggered by the Predict button, and only when the
# canvas has returned pixel data.
with right:
    st.markdown(f"<h2 style='color:{ACCENT_COLOR};'>🤖 AI’s Best Guesses:</h2>", unsafe_allow_html=True)
    if do_pred and canvas_result.image_data is not None:
        # 1) Preprocess to 28x28
        raw = Image.fromarray((canvas_result.image_data).astype("uint8"))
        proc = preprocess_rgba_to_28x28(raw)
        if show_processed:
            st.caption("Model input preview (scaled up):")
            st.image(proc.resize((140, 140), resample=Image.NEAREST))

        # 2) Model predict
        arr = img_to_array(proc) / 255.0          # scale pixel values to [0, 1]
        arr = np.expand_dims(arr, axis=0)         # add the leading batch axis
        preds = model.predict(arr, verbose=0)[0]  # scores for the single input
        top = preds.argsort()[-3:][::-1]          # up to three best indices, best first

        # 3) Show results
        emoji_map = {"cat": "🐱", "dog": "🐶", "car": "🚗", "tree": "🌳", "house": "🏠",
                     "sun": "☀️", "airplane": "✈️", "fish": "🐟", "flower": "🌸", "bird": "🐦"}
        if preds[top[0]] >= 0.80:
            st.balloons()
        for idx in top:
            # A classes.txt with fewer labels than the model has outputs must
            # not crash the page; show a placeholder name instead.
            label = CLASSES[idx] if idx < len(CLASSES) else f"class {idx}"
            score = float(preds[idx])
            st.markdown(
                f"<h3 style='color:{PRIMARY_COLOR}; margin:6px 0 4px 0;'>{emoji_map.get(label,'🎯')} {label} – {score*100:.1f}%</h3>",
                unsafe_allow_html=True
            )
            # st.progress rejects floats outside [0.0, 1.0]; clamp in case
            # rounding pushes a score marginally past either end.
            st.progress(min(max(score, 0.0), 1.0))
# ----------------- Legend -----------------
# Built from CLASSES so the legend always lists the labels actually in use,
# including when a custom classes.txt replaces the defaults. Labels without a
# known emoji get the same 🎯 placeholder used in the prediction list.
_LEGEND_EMOJI = {
    "cat": "🐱", "dog": "🐶", "car": "🚗", "tree": "🌳", "house": "🏠",
    "sun": "☀️", "airplane": "✈️", "fish": "🐟", "flower": "🌸", "bird": "🐦",
}
st.markdown(
    "<hr><h4>📖 Supported Classes:</h4>"
    + " ".join(f"{_LEGEND_EMOJI.get(name, '🎯')} {name}" for name in CLASSES),
    unsafe_allow_html=True,
)
st.success("💡 Tips: draw big, single clear object; avoid tiny sketches; use continuous strokes.")