Spaces:
Sleeping
Sleeping
Elliot Sones
committed on
Commit
·
76aaddb
0
Parent(s):
Initial commit: Animal Doodle Classifier
Browse files- .gitattributes +1 -0
- README.md +28 -0
- app.py +265 -0
- requirements.txt +4 -0
- rnn_animals_best.pt +3 -0
.gitattributes
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Animal Doodle Classifier
|
| 3 |
+
emoji: 🎨
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: streamlit
|
| 7 |
+
sdk_version: "1.28.0"
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# 🎨 Animal Doodle Classifier
|
| 13 |
+
|
| 14 |
+
An RNN-based classifier that recognizes hand-drawn animal doodles in real-time!
|
| 15 |
+
|
| 16 |
+
## Supported Animals
|
| 17 |
+
- butterfly, cow, elephant, giraffe, monkey
|
| 18 |
+
- octopus, scorpion, shark, snake, spider
|
| 19 |
+
|
| 20 |
+
## Model
|
| 21 |
+
- **Architecture:** Bidirectional GRU
|
| 22 |
+
- **Accuracy:** 97.47% Top-1, 99.75% Top-3
|
| 23 |
+
- **Training Data:** Google Quick Draw dataset
|
| 24 |
+
|
| 25 |
+
## How It Works
|
| 26 |
+
1. Draw an animal on the canvas
|
| 27 |
+
2. Your strokes are captured and preprocessed to match Quick Draw format
|
| 28 |
+
3. The RNN model predicts which animal you drew
|
app.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RNN Animal Doodle Classifier - Hugging Face Spaces
|
| 3 |
+
Self-contained Streamlit app with embedded model class
|
| 4 |
+
"""
|
| 5 |
+
import ast
|
| 6 |
+
import json
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import numpy as np
|
| 9 |
+
import streamlit as st
|
| 10 |
+
from streamlit_drawable_canvas import st_canvas
|
| 11 |
+
import torch
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
# ============================================================================
|
| 15 |
+
# Model Definition (embedded from training-doodle.py)
|
| 16 |
+
# ============================================================================
|
| 17 |
+
|
| 18 |
+
class GRUClassifier(nn.Module):
    """Bidirectional GRU classifier for sequence classification.

    Runs a (optionally bidirectional) GRU over variable-length sequences,
    takes the final hidden state(s), layer-normalizes them, and projects
    to class logits.
    """

    def __init__(self, input_size: int, hidden_size: int, num_layers: int,
                 bidirectional: bool, dropout: float, num_classes: int, use_packing: bool = True):
        super().__init__()
        self.use_packing = use_packing
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            # Inter-layer dropout only applies to stacked GRUs.
            dropout=dropout if num_layers > 1 else 0.0,
        )
        feature_dim = hidden_size * (2 if bidirectional else 1)
        self.norm = nn.LayerNorm(feature_dim)
        self.fc = nn.Linear(feature_dim, num_classes)

    def forward(self, x: torch.Tensor, lengths: torch.Tensor):
        """Return class logits of shape (batch, num_classes).

        `lengths` gives each sequence's true (unpadded) length; it is only
        consulted when packing is enabled.
        """
        if self.use_packing:
            # Pack so padded timesteps never reach the GRU.
            packed_input = nn.utils.rnn.pack_padded_sequence(
                x, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            _, final_hidden = self.gru(packed_input)
        else:
            _, final_hidden = self.gru(x)

        if self.gru.bidirectional:
            # Concatenate the last layer's forward and backward states.
            features = torch.cat([final_hidden[-2], final_hidden[-1]], dim=1)
        else:
            features = final_hidden[-1]
        return self.fc(self.norm(features))
|
| 50 |
+
|
| 51 |
+
def parse_drawing_to_seq(drawing_str: str) -> np.ndarray:
    """Convert a QuickDraw-style stroke string to [dx, dy, pen_lift] rows.

    Deltas are normalized by 255 and clipped to [-1, 1]; pen_lift is 1.0
    on the last step of each stroke, 0.0 otherwise. Returns an empty
    (0, 3) float32 array when the string cannot be parsed or yields no
    usable strokes.
    """
    try:
        strokes = json.loads(drawing_str)
    except Exception:
        try:
            # Some sources serialize with Python repr instead of JSON.
            strokes = ast.literal_eval(drawing_str)
        except Exception:
            return np.zeros((0, 3), dtype=np.float32)

    pieces = []
    for stroke in strokes:
        if not (isinstance(stroke, (list, tuple)) and len(stroke) == 2):
            continue
        xs, ys = stroke
        count = min(len(xs), len(ys))
        if count < 2:
            # A single point yields no deltas.
            continue
        xs_arr = np.asarray(xs[:count], dtype=np.int16)
        ys_arr = np.asarray(ys[:count], dtype=np.int16)
        dx = np.diff(xs_arr).astype(np.float32) / 255.0
        dy = np.diff(ys_arr).astype(np.float32) / 255.0
        if dx.size == 0:
            continue
        pen_lift = np.zeros_like(dx, dtype=np.float32)
        pen_lift[-1] = 1.0  # mark end of stroke
        pieces.append(np.stack([dx, dy, pen_lift], axis=1))

    if not pieces:
        return np.zeros((0, 3), dtype=np.float32)

    seq = np.concatenate(pieces, axis=0)
    seq[:, :2] = np.clip(seq[:, :2], -1.0, 1.0)
    return seq.astype(np.float32)
|
| 85 |
+
|
| 86 |
+
# ============================================================================
|
| 87 |
+
# Constants
|
| 88 |
+
# ============================================================================
|
| 89 |
+
|
| 90 |
+
CANVAS_SIZE = 400
|
| 91 |
+
STROKE_WIDTH = 3
|
| 92 |
+
ANIMALS = ["butterfly", "cow", "elephant", "giraffe", "monkey",
|
| 93 |
+
"octopus", "scorpion", "shark", "snake", "spider"]
|
| 94 |
+
|
| 95 |
+
CALIB_TARGET_MEAN = 0.04
|
| 96 |
+
CALIB_MAX_GAIN = 12.0
|
| 97 |
+
CALIB_MIN_GAIN = 0.5
|
| 98 |
+
|
| 99 |
+
def _calibrate_seq(seq: np.ndarray) -> np.ndarray:
|
| 100 |
+
"""Scale (dx, dy) so the mean step magnitude matches training data."""
|
| 101 |
+
if seq is None or seq.ndim != 2 or seq.shape[1] < 2 or seq.shape[0] == 0:
|
| 102 |
+
return seq
|
| 103 |
+
steps = np.sqrt((seq[:, 0] ** 2) + (seq[:, 1] ** 2))
|
| 104 |
+
curr = float(steps.mean()) if steps.size else 0.0
|
| 105 |
+
if curr <= 1e-6:
|
| 106 |
+
return seq
|
| 107 |
+
gain = float(np.clip(CALIB_TARGET_MEAN / curr, CALIB_MIN_GAIN, CALIB_MAX_GAIN))
|
| 108 |
+
out = seq.astype(np.float32).copy()
|
| 109 |
+
out[:, 0:2] = np.clip(out[:, 0:2] * gain, -1.0, 1.0)
|
| 110 |
+
return out
|
| 111 |
+
|
| 112 |
+
# ============================================================================
|
| 113 |
+
# Model Loading
|
| 114 |
+
# ============================================================================
|
| 115 |
+
|
| 116 |
+
@st.cache_resource
def load_model():
    """Load the trained RNN checkpoint and rebuild the classifier.

    Returns (model, idx_to_class) on success or (None, None) when the
    weights file is missing. Cached by Streamlit so the checkpoint is
    read only once per process.
    """
    weights_file = Path(__file__).parent / "rnn_animals_best.pt"
    if not weights_file.exists():
        st.error(f"Model file not found: {weights_file}")
        return None, None

    # NOTE(review): weights_only=False unpickles arbitrary objects; the
    # checkpoint ships with this Space, so it is trusted input here.
    checkpoint = torch.load(weights_file, map_location="cpu", weights_only=False)
    cfg = checkpoint.get("config", {})

    model = GRUClassifier(
        input_size=3,
        hidden_size=cfg.get("hidden_size", 512),
        num_layers=cfg.get("num_layers", 2),
        bidirectional=cfg.get("bidirectional", True),
        dropout=cfg.get("dropout", 0.3),
        num_classes=len(ANIMALS),
        use_packing=True,
    )
    model.load_state_dict(checkpoint["model_state"])
    model.eval()

    # Fall back to the hard-coded label order if the checkpoint carries
    # no class mapping of its own.
    class_to_idx = checkpoint.get("class_to_idx", {name: i for i, name in enumerate(ANIMALS)})
    idx_to_class = {idx: name for name, idx in class_to_idx.items()}
    return model, idx_to_class
|
| 142 |
+
|
| 143 |
+
# ============================================================================
|
| 144 |
+
# Stroke Processing
|
| 145 |
+
# ============================================================================
|
| 146 |
+
|
| 147 |
+
def _extract_canvas_strokes(objects):
    """Pull (xs, ys) point lists out of fabric.js 'path' objects."""
    strokes = []
    for obj in objects:
        if obj.get("type") != "path":
            continue
        xs, ys = [], []
        for cmd in obj.get("path", []):
            if len(cmd) < 3:
                continue
            if cmd[0] == "M" or cmd[0] == "L":
                xs.append(float(cmd[1]))
                ys.append(float(cmd[2]))
            elif cmd[0] == "Q" and len(cmd) >= 5:
                # Quadratic curve: keep only the end point (cmd[1:3] is the
                # control point).
                xs.append(float(cmd[3]))
                ys.append(float(cmd[4]))
        if len(xs) >= 2:
            strokes.append((xs, ys))
    return strokes

def _downsample_stroke(xs, ys, max_points=25):
    """Thin dense strokes to roughly max_points samples."""
    if len(xs) > max_points:
        step = max(1, len(xs) // max_points)
        return xs[::step], ys[::step]
    return xs, ys

def _smooth_stroke(xs, ys):
    """3-point moving average; endpoints are kept as drawn."""
    if len(xs) < 3:
        return xs, ys
    xs_s = [xs[0]] + [(xs[i - 1] + xs[i] + xs[i + 1]) / 3 for i in range(1, len(xs) - 1)] + [xs[-1]]
    ys_s = [ys[0]] + [(ys[i - 1] + ys[i] + ys[i + 1]) / 3 for i in range(1, len(ys) - 1)] + [ys[-1]]
    return xs_s, ys_s

def canvas_strokes_to_quickdraw(canvas_json):
    """Convert canvas JSON to QuickDraw-format strokes.

    Pipeline: extract path points -> downsample -> smooth -> center and
    scale into the 0..255 QuickDraw box. Returns a list of [xs, ys]
    integer stroke pairs, or [] when there is nothing usable to convert.
    """
    if canvas_json is None:
        return []

    raw_strokes = _extract_canvas_strokes(canvas_json.get("objects", []))
    if not raw_strokes:
        return []

    smoothed = [_smooth_stroke(*_downsample_stroke(xs, ys)) for xs, ys in raw_strokes]

    # Fit the drawing's bounding box into ~235px, centered in the 256px
    # QuickDraw frame; max(1, ...) guards against a zero-size box.
    all_x = [x for xs, _ in smoothed for x in xs]
    all_y = [y for _, ys in smoothed for y in ys]
    min_x, max_x = min(all_x), max(all_x)
    min_y, max_y = min(all_y), max(all_y)

    scale = 235 / max(max(1, max_x - min_x), max(1, max_y - min_y))
    cx, cy = (min_x + max_x) / 2, (min_y + max_y) / 2
    ox, oy = 127.5 - cx * scale, 127.5 - cy * scale

    result = []
    for xs, ys in smoothed:
        xs_n = [int(np.clip(x * scale + ox, 0, 255)) for x in xs]
        ys_n = [int(np.clip(y * scale + oy, 0, 255)) for y in ys]
        result.append([xs_n, ys_n])
    return result
|
| 212 |
+
|
| 213 |
+
def predict(model, idx_to_class, strokes):
    """Run the classifier on QuickDraw-format strokes.

    Returns a list of (class_name, probability) pairs for the top-5
    classes, or None when there is no model, no strokes, or the drawing
    is too short to classify. Errors are reported to the UI and yield None.
    """
    if not strokes or model is None:
        return None
    try:
        seq = parse_drawing_to_seq(json.dumps(strokes))
        # Require a minimum number of steps so trivial marks are ignored.
        if seq is None or len(seq) < 6:
            return None
        seq = _calibrate_seq(seq)
        seq_t = torch.tensor(seq, dtype=torch.float32).unsqueeze(0)
        lengths = torch.tensor([seq.shape[0]], dtype=torch.long)
        with torch.no_grad():
            probs = torch.softmax(model(seq_t, lengths), dim=1)
            top_p, top_i = torch.topk(probs, k=5, dim=1)
        results = []
        for i in range(5):
            idx = top_i[0, i].item()
            # Fall back to a synthetic label so a mapping entry missing from
            # the checkpoint can't surface as None and crash the caller's
            # .upper() call.
            results.append((idx_to_class.get(idx, f"class_{idx}"), top_p[0, i].item()))
        return results
    except Exception as e:
        st.error(f"Error: {e}")
        return None
|
| 231 |
+
|
| 232 |
+
# ============================================================================
|
| 233 |
+
# Main App
|
| 234 |
+
# ============================================================================
|
| 235 |
+
|
| 236 |
+
def main():
    """Streamlit entry point: drawing canvas on the left, live predictions on the right."""
    st.set_page_config(page_title="Animal Doodle Classifier", page_icon="🎨", layout="wide")
    st.title("🎨 Animal Doodle Classifier")
    st.caption("Draw: butterfly, cow, elephant, giraffe, monkey, octopus, scorpion, shark, snake, spider")

    model, idx_to_class = load_model()
    if model is None:
        # load_model already surfaced the error in the UI.
        return

    left, right = st.columns([1, 1])

    with left:
        canvas_result = st_canvas(
            stroke_width=STROKE_WIDTH,
            stroke_color="#000000",
            background_color="#FFFFFF",
            height=CANVAS_SIZE,
            width=CANVAS_SIZE,
            drawing_mode="freedraw",
            key="canvas",
        )

    with right:
        st.subheader("Predictions")
        if canvas_result.json_data:
            strokes = canvas_strokes_to_quickdraw(canvas_result.json_data)
            if strokes:
                results = predict(model, idx_to_class, strokes)
                if results:
                    best_name, best_prob = results[0]
                    st.success(f"**{best_name.upper()}** ({best_prob*100:.1f}%)")
                    for name, prob in results:
                        st.progress(prob, text=f"{name}: {prob*100:.1f}%")

if __name__ == "__main__":
    main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit>=1.28.0
|
| 2 |
+
streamlit-drawable-canvas>=0.9.3
|
| 3 |
+
torch>=2.0.0
|
| 4 |
+
numpy>=1.24.0
|
rnn_animals_best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3045301cb82537f7ccdaa2271c0d0944470d6a5079cc633a1bfcafb2198ac895
|
| 3 |
+
size 44206972
|