"""
app.py - ScribblBot inference application.
"""
import sys
from pathlib import Path
from typing import Optional
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
sys.path.insert(0, str(Path(__file__).parent))
from config import CLASSES, CLASS_EMOJIS, MODELS_DIR, NUM_CLASSES
from scripts.model import ScribblNet
def _load_model() -> tuple[ScribblNet, torch.device]:
    """Instantiate ScribblNet and restore its trained weights from disk.

    Returns:
        ``(model, device)``: the network switched to eval mode, and the CPU
        device its weights were mapped onto.

    Raises:
        FileNotFoundError: if ``MODELS_DIR / "deep_model.pth"`` does not exist.
    """
    cpu = torch.device("cpu")
    model_path = MODELS_DIR / "deep_model.pth"
    if not model_path.exists():
        raise FileNotFoundError(f"Weights not found at {model_path}. Run python setup.py first.")
    net = ScribblNet(num_classes=NUM_CLASSES)
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects. The loaded object is only handed to load_state_dict, so
    # weights_only=True is probably sufficient and safer — confirm against how
    # setup.py saves the file before switching.
    checkpoint = torch.load(model_path, map_location=cpu, weights_only=False)
    net.load_state_dict(checkpoint)
    net.eval()
    return net, cpu
# Load the network once at import time; predict() reads these module globals.
# Importing this module raises FileNotFoundError if the weights are missing.
MODEL, DEVICE = _load_model()
# Minimum (max - min) intensity spread of the 28x28 model input for it to count
# as a drawing. A blank canvas is uniform after resizing, so its spread is ~0.
# Ink strokes are near-black (#111111 brush in build_app), so even thin strokes
# should exceed this comfortably — tune if very faint marks get ignored.
_BLANK_CONTRAST = 0.02


def _preprocess(img_array: np.ndarray) -> np.ndarray:
    """Turn a canvas array into a 28x28 float32 image with ink high, paper low.

    Transparent pixels are flattened onto the canvas' off-white colour, the
    result is converted to greyscale, and it is resized to 28x28. Intensities
    are then inverted and scaled to [0, 1], so dark strokes map near 1.
    """
    img_pil = Image.fromarray(img_array.astype(np.uint8))
    if img_pil.mode == "RGBA":
        # Same off-white as the canvas background in CUSTOM_CSS (#f8f7f2).
        # NOTE(review): this leaves the "empty" background at ~0.03 rather
        # than 0 after inversion; if the model was trained on pure-zero
        # backgrounds, compositing onto (255, 255, 255, 255) may match better.
        paper = Image.new("RGBA", img_pil.size, (248, 247, 242, 255))
        img_pil = Image.alpha_composite(paper, img_pil)
    img_pil = img_pil.convert("L").resize((28, 28), Image.LANCZOS)
    arr = np.array(img_pil, dtype=np.float32)
    return (255.0 - arr) / 255.0


def predict(sketch: Optional[dict], _counter: int) -> tuple[str, int]:
    """Run inference on an ImageEditor drawing.

    Args:
        sketch: Value from ``gr.ImageEditor``. Its ``'composite'`` entry holds
            the flattened drawing as a numpy array. A bare array is also accepted.
        _counter: Click counter used to bust Gradio output caching.

    Returns:
        Tuple of (HTML results string, incremented counter). The empty-state
        panel is returned when there is nothing to classify (no value, or a
        blank canvas). An error panel is returned when preprocessing fails.
    """
    _counter += 1
    # A None sketch falls through to img_array = None as well.
    img_array = sketch.get("composite") if isinstance(sketch, dict) else sketch
    if img_array is None:
        return _empty_state_html(), _counter
    try:
        arr = _preprocess(img_array)
    except Exception as exc:  # UI boundary: report the failure in the panel
        return _error_html(str(exc)), _counter
    # Without this guard a blank canvas still yields a (meaningless) top-5.
    if float(arr.max() - arr.min()) < _BLANK_CONTRAST:
        return _empty_state_html(), _counter
    # (28, 28) -> (1, 1, 28, 28): add batch and channel dimensions.
    tensor = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        probs = F.softmax(MODEL(tensor), dim=1)[0].cpu().numpy()
    top = [(CLASSES[i], float(probs[i])) for i in np.argsort(probs)[::-1][:5]]
    return _results_html(top), _counter
def _results_html(top: list[tuple[str, float]]) -> str:
    """Render the ranked predictions as the HTML results panel.

    Args:
        top: Non-empty list of ``(class_name, probability)`` pairs in
            descending probability order, with probabilities in [0, 1]. The
            first pair is shown as the headline prediction, and every pair
            gets a bar.

    Returns:
        HTML styled by the ``.results-panel`` / ``.top-*`` / ``.bar-*`` rules
        in CUSTOM_CSS.
    """
    best_cls, best_prob = top[0]
    if best_prob > 0.7:
        label = "CONFIDENT"
    elif best_prob > 0.4:
        label = "LIKELY"
    else:
        label = "UNSURE"

    rows = []
    for rank, (cls, prob) in enumerate(top):
        pct = prob * 100
        # The staggered delay makes rows slide in one after another. The inline
        # width is the end state of the barGrow keyframes, which only define `from`.
        rows.append(
            f'<div class="bar-row" style="animation-delay:{rank * 0.06:.2f}s">'
            f'<span class="bar-emoji">{CLASS_EMOJIS.get(cls, "")}</span>'
            f'<span class="bar-label">{cls.upper()}</span>'
            f'<div class="bar-track"><div class="bar-fill" style="width:{pct:.1f}%"></div></div>'
            f'<span class="bar-pct">{pct:.1f}%</span>'
            "</div>"
        )
    bars = "".join(rows)

    # "\u00b7" is a middle dot, written as an escape so a wrong-codec re-save of
    # this file cannot garble it.
    return (
        '<div class="results-panel fade-in">'
        '<div class="result-tag">[ PREDICTION ]</div>'
        '<div class="top-result">'
        f'<div class="top-emoji">{CLASS_EMOJIS.get(best_cls, "?")}</div>'
        "<div>"
        f'<div class="top-label">{best_cls.upper()}</div>'
        f'<div class="top-conf">{best_prob * 100:.1f}% \u00b7 {label}</div>'
        "</div>"
        "</div>"
        '<div class="divider"></div>'
        f'<div class="section-label">TOP {len(top)} PROBABILITIES</div>'
        f"{bars}"
        "</div>"
    )
def _empty_state_html() -> str:
    """Render the placeholder panel shown when there is nothing to classify.

    The pill list is built from ``CLASSES`` and ``CLASS_EMOJIS``, so it stays
    in sync with the labels the model can predict instead of a hand-kept copy.

    Returns:
        HTML styled by the ``.empty-state`` / ``.pill`` rules in CUSTOM_CSS.
    """
    pills = "".join(
        f'<span class="pill">{CLASS_EMOJIS.get(cls, "")} {cls}</span>' for cls in CLASSES
    )
    # The pencil emoji is written as an escape ("\u270f\ufe0f") so re-saving the
    # file with the wrong text codec cannot turn it into mojibake.
    return (
        '<div class="results-panel">'
        '<div class="empty-state fade-in">'
        '<div class="empty-icon">\u270f\ufe0f</div>'
        '<div class="empty-title">DRAW SOMETHING</div>'
        '<div class="empty-sub">then hit ANALYZE</div>'
        f'<div class="class-pills">{pills}</div>'
        "</div>"
        "</div>"
    )
def _error_html(msg: str) -> str:
return f''
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=VT323&family=IBM+Plex+Mono:wght@400;500&display=swap');
:root {
--bg: #080808;
--surface: #111111;
--surface2: #1a1a1a;
--border: #2a2a2a;
--accent: #b8ff57;
--text: #e8e8e0;
--text-muted:#888880;
--red: #ff5f57;
--mono: 'IBM Plex Mono', monospace;
--display: 'VT323', monospace;
}
body, .gradio-container, #root {
background: var(--bg) !important;
font-family: var(--mono) !important;
color: var(--text) !important;
}
.gradio-container { max-width: 1100px !important; margin: 0 auto !important; }
footer { display: none !important; }
.block, .gr-box { background: transparent !important; border: none !important; box-shadow: none !important; }
.app-header { text-align: center; padding: 36px 20px 20px; border-bottom: 1px solid var(--border); margin-bottom: 28px; }
.app-title { font-family: var(--display); font-size: 72px; line-height: 1; color: var(--accent); letter-spacing: 6px; text-shadow: 0 0 30px rgba(184,255,87,0.3); margin: 0; }
.app-subtitle { font-size: 12px; color: var(--text-muted); letter-spacing: 4px; margin-top: 6px; }
/* ImageEditor styling */
/* Override Gradio's orange accent with our green */
.sketch-col { --color-accent: #b8ff57 !important; --color-accent-soft: rgba(184,255,87,0.15) !important; }
.sketch-col .image-editor { border: 1px solid var(--border) !important; border-radius: 4px !important; background: var(--surface) !important; }
/* Hide color picker and swatch - we only need pen and eraser */
.sketch-col [aria-label="Color"],
.sketch-col [title="Color"],
.sketch-col .image-editor .toolbar > button:nth-child(3),
.sketch-col .image-editor .toolbar > button:nth-child(4) { display: none !important; }
/* Toolbar background */
.sketch-col .image-editor > div { background: var(--surface2) !important; }
/* All buttons */
.sketch-col .image-editor button {
background: var(--surface2) !important;
border: 1px solid var(--border) !important;
border-radius: 3px !important;
margin: 2px !important;
color: var(--text) !important;
}
.sketch-col .image-editor button:hover {
background: var(--accent) !important;
border-color: var(--accent) !important;
color: #000 !important;
}
/* Active tool */
.sketch-col .image-editor button[aria-pressed="true"] {
border: 2px solid var(--accent) !important;
background: rgba(184,255,87,0.15) !important;
color: var(--accent) !important;
}
/* Force all SVG icons to white/text color */
.sketch-col .image-editor svg * { color: inherit !important; stroke: currentColor !important; }
.sketch-col [data-testid="layer-wrap"] { display: none !important; }
.sketch-col .layers-panel { display: none !important; }
/* White canvas */
.sketch-col .konvajs-content,
.sketch-col .konvajs-content canvas,
.sketch-col canvas { background: #f8f7f2 !important; background-color: #f8f7f2 !important; }
.sketch-col canvas { cursor: crosshair !important; }
.sketch-col * { cursor: auto; }
.sketch-col canvas { cursor: crosshair !important; }
.results-panel { background: var(--surface); border: 1px solid var(--border); border-radius: 4px; padding: 20px; min-height: 420px; font-family: var(--mono); }
.result-tag { font-size: 11px; color: var(--accent); letter-spacing: 3px; margin-bottom: 16px; }
.top-result { display: flex; align-items: center; gap: 18px; margin-bottom: 18px; }
.top-emoji { font-size: 56px; line-height: 1; }
.top-label { font-family: var(--display); font-size: 52px; color: var(--text); line-height: 1; letter-spacing: 3px; }
.top-conf { font-size: 13px; color: var(--accent); margin-top: 4px; }
.divider { height: 1px; background: var(--border); margin: 16px 0; }
.section-label { font-size: 10px; color: var(--text-muted); letter-spacing: 3px; margin-bottom: 12px; }
.bar-row { display: grid; grid-template-columns: 28px 90px 1fr 50px; align-items: center; gap: 8px; margin-bottom: 10px; opacity: 0; animation: slideIn 0.3s ease forwards; }
.bar-emoji { font-size: 16px; text-align: center; }
.bar-label { font-size: 11px; color: var(--text-muted); letter-spacing: 1px; }
.bar-track { height: 6px; background: var(--surface2); border-radius: 3px; overflow: hidden; }
.bar-fill { height: 100%; background: var(--accent); border-radius: 3px; width: 0; animation: barGrow 0.4s ease forwards; }
.bar-pct { font-size: 11px; color: var(--text); text-align: right; }
.empty-state { display: flex; flex-direction: column; align-items: center; justify-content: center; min-height: 360px; }
.empty-icon { font-size: 48px; margin-bottom: 12px; }
.empty-title { font-family: var(--display); font-size: 36px; color: var(--accent); letter-spacing: 3px; }
.empty-sub { font-size: 12px; color: var(--text-muted); margin: 4px 0 24px; letter-spacing: 2px; }
.class-pills { display: flex; flex-wrap: wrap; gap: 6px; justify-content: center; max-width: 340px; }
.pill { background: var(--surface2); border: 1px solid var(--border); padding: 3px 10px; border-radius: 20px; font-size: 11px; color: var(--text-muted); }
.error-state { display: flex; align-items: center; justify-content: center; min-height: 200px; }
.err-msg { font-size: 13px; color: var(--red); }
.analyze-row { padding: 12px 0 0 !important; }
.analyze-row button { width: 100% !important; background: rgba(184,255,87,0.06) !important; border: 2px solid var(--accent) !important; color: var(--accent) !important; font-family: var(--mono) !important; font-size: 15px !important; letter-spacing: 4px !important; padding: 14px !important; border-radius: 2px !important; cursor: pointer !important; transition: background 0.15s !important; }
.analyze-row button:hover { background: var(--accent) !important; color: #000 !important; }
.app-footer { text-align: center; padding: 18px; font-size: 11px; color: var(--text-muted); letter-spacing: 1px; border-top: 1px solid var(--border); margin-top: 16px; }
@keyframes slideIn { from { opacity: 0; transform: translateX(-8px); } to { opacity: 1; transform: translateX(0); } }
@keyframes barGrow { from { width: 0; } }
.fade-in { animation: fadeIn 0.25s ease; }
@keyframes fadeIn { from { opacity: 0; } to { opacity: 1; } }
"""
def build_app() -> gr.Blocks:
    """Construct and return the Gradio Blocks application.

    Layout: a title banner, then a row with the sketch canvas (left) and the
    results panel plus ANALYZE button (right), followed by a footer slot.
    Clicking ANALYZE calls ``predict`` with the canvas value and a click
    counter kept in ``gr.State``, and writes back the panel HTML and counter.
    """
    with gr.Blocks(css=CUSTOM_CSS, title="ScribblBot") as app:
        # Title banner, styled by .app-header / .app-title in CUSTOM_CSS. The
        # previous HTML body was empty, so nothing rendered here.
        gr.HTML('<div class="app-header"><h1 class="app-title">SCRIBBLBOT</h1></div>')
        click_counter = gr.State(0)
        # NOTE(review): original indentation was lost. The ANALYZE row is
        # assumed to sit under the results panel in the right column — confirm.
        with gr.Row():
            with gr.Column(elem_classes=["sketch-col"]):
                sketch_input = gr.ImageEditor(
                    type="numpy",
                    image_mode="RGBA",
                    canvas_size=(480, 480),
                    layers=False,
                    sources=[],
                    brush=gr.Brush(
                        colors=["#111111"],
                        default_size=14,
                        color_mode="fixed",
                    ),
                    eraser=gr.Eraser(default_size=20),
                    show_label=False,
                )
            with gr.Column():
                result_html = gr.HTML(_empty_state_html())
                with gr.Row(elem_classes=["analyze-row"]):
                    analyze_btn = gr.Button("ANALYZE")
        # Footer slot. NOTE(review): it is empty, so the .app-footer rule in
        # CUSTOM_CSS is currently unused.
        gr.HTML('')
        analyze_btn.click(
            fn=predict,
            inputs=[sketch_input, click_counter],
            outputs=[result_html, click_counter],
        )
    return app
if __name__ == "__main__":
demo = build_app()
demo.launch()