Elliot Sones committed on
Commit
76aaddb
·
0 Parent(s):

Initial commit: Animal Doodle Classifier

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -0
  2. README.md +28 -0
  3. app.py +265 -0
  4. requirements.txt +4 -0
  5. rnn_animals_best.pt +3 -0
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ *.pt filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Animal Doodle Classifier
3
+ emoji: 🎨
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: streamlit
7
+ sdk_version: "1.28.0"
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ # 🎨 Animal Doodle Classifier
13
+
14
+ An RNN-based classifier that recognizes hand-drawn animal doodles in real-time!
15
+
16
+ ## Supported Animals
17
+ - butterfly, cow, elephant, giraffe, monkey
18
+ - octopus, scorpion, shark, snake, spider
19
+
20
+ ## Model
21
+ - **Architecture:** Bidirectional GRU
22
+ - **Accuracy:** 97.47% Top-1, 99.75% Top-3
23
+ - **Training Data:** Google Quick Draw dataset
24
+
25
+ ## How It Works
26
+ 1. Draw an animal on the canvas
27
+ 2. Your strokes are captured and preprocessed to match Quick Draw format
28
+ 3. The RNN model predicts which animal you drew
app.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RNN Animal Doodle Classifier - Hugging Face Spaces
3
+ Self-contained Streamlit app with embedded model class
4
+ """
5
+ import ast
6
+ import json
7
+ from pathlib import Path
8
+ import numpy as np
9
+ import streamlit as st
10
+ from streamlit_drawable_canvas import st_canvas
11
+ import torch
12
+ from torch import nn
13
+
14
+ # ============================================================================
15
+ # Model Definition (embedded from training-doodle.py)
16
+ # ============================================================================
17
+
18
class GRUClassifier(nn.Module):
    """GRU sequence classifier (optionally bidirectional).

    The final hidden state of the GRU is layer-normalized and projected to
    ``num_classes`` logits by a single linear layer.

    Args:
        input_size: Number of features per sequence step.
        hidden_size: GRU hidden width per direction.
        num_layers: Number of stacked GRU layers.
        bidirectional: Run the GRU in both directions when True.
        dropout: Dropout between stacked GRU layers (ignored for one layer).
        num_classes: Number of output logits.
        use_packing: Pack the padded batch using ``lengths`` before the GRU.
    """

    def __init__(self, input_size: int, hidden_size: int, num_layers: int,
                 bidirectional: bool, dropout: float, num_classes: int,
                 use_packing: bool = True):
        super().__init__()
        self.use_packing = use_packing

        # Dropout is only passed through when there is more than one layer.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=inter_layer_dropout,
        )

        # A bidirectional GRU yields one hidden vector per direction.
        num_directions = 2 if bidirectional else 1
        feature_dim = hidden_size * num_directions
        self.norm = nn.LayerNorm(feature_dim)
        self.fc = nn.Linear(feature_dim, num_classes)

    def forward(self, x: torch.Tensor, lengths: torch.Tensor):
        """Return class logits for a batch-first batch ``x``.

        ``lengths`` is only consulted when ``use_packing`` is enabled.
        """
        rnn_input = x
        if self.use_packing:
            # Lengths are moved to CPU before packing.
            rnn_input = nn.utils.rnn.pack_padded_sequence(
                x, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
        _, hidden = self.gru(rnn_input)

        if self.gru.bidirectional:
            # Concatenate the last two entries of h_n (the two directions of
            # the top layer, per PyTorch's documented h_n layout).
            summary = torch.cat((hidden[-2], hidden[-1]), dim=1)
        else:
            summary = hidden[-1]

        return self.fc(self.norm(summary))
50
+
51
def parse_drawing_to_seq(drawing_str: str) -> np.ndarray:
    """Convert a serialized drawing into a sequence of ``[dx, dy, pen_lift]``.

    The drawing is a list of strokes, each stroke being a pair
    ``[xs, ys]`` of coordinate lists. Every stroke contributes one row per
    consecutive point pair: the coordinate deltas divided by 255 and a pen
    flag that is 1.0 on the stroke's last step and 0.0 elsewhere.

    Args:
        drawing_str: JSON text, or a Python literal as a fallback.

    Returns:
        A float32 array of shape ``(T, 3)`` with deltas clipped to
        ``[-1, 1]``. An empty ``(0, 3)`` array is returned when the text
        cannot be parsed or holds no usable stroke; malformed strokes are
        skipped rather than raising.
    """
    empty = np.zeros((0, 3), dtype=np.float32)
    # Failure modes of json.loads / ast.literal_eval on arbitrary input.
    parse_errors = (ValueError, TypeError, SyntaxError, MemoryError, RecursionError)

    try:
        strokes = json.loads(drawing_str)
    except parse_errors:
        try:
            strokes = ast.literal_eval(drawing_str)
        except parse_errors:
            return empty

    # Valid JSON can still decode to a scalar or null (e.g. "5", "null").
    try:
        stroke_iter = iter(strokes)
    except TypeError:
        return empty

    seq_parts = []
    for stroke in stroke_iter:
        if not isinstance(stroke, (list, tuple)) or len(stroke) != 2:
            continue
        xs, ys = stroke
        try:
            n = min(len(xs), len(ys))
            if n < 2:
                continue
            x = np.asarray(xs[:n], dtype=np.int16)
            y = np.asarray(ys[:n], dtype=np.int16)
        except (TypeError, ValueError, OverflowError, KeyError):
            # Coordinates are not numeric sequences (e.g. [1, 2]); skip.
            continue
        if x.ndim != 1 or y.ndim != 1:
            continue

        dx = np.diff(x).astype(np.float32) / 255.0
        dy = np.diff(y).astype(np.float32) / 255.0
        pen = np.zeros_like(dx, dtype=np.float32)
        pen[-1] = 1.0  # mark the end of this stroke
        seq_parts.append(np.stack([dx, dy, pen], axis=1))

    if not seq_parts:
        return empty

    seq = np.concatenate(seq_parts, axis=0)
    seq[:, :2] = np.clip(seq[:, :2], -1.0, 1.0)
    return seq
85
+
86
+ # ============================================================================
87
+ # Constants
88
+ # ============================================================================
89
+
90
# Drawing canvas: square side length and pen width passed to the canvas widget.
CANVAS_SIZE = 400
STROKE_WIDTH = 3

# Class names; also the fallback label order when a checkpoint has no mapping.
ANIMALS = [
    "butterfly",
    "cow",
    "elephant",
    "giraffe",
    "monkey",
    "octopus",
    "scorpion",
    "shark",
    "snake",
    "spider",
]

# Step-size calibration: target mean step magnitude and the allowed gain range.
CALIB_TARGET_MEAN = 0.04
CALIB_MAX_GAIN = 12.0
CALIB_MIN_GAIN = 0.5
98
+
99
def _calibrate_seq(seq: np.ndarray) -> np.ndarray:
    """Rescale the (dx, dy) columns toward a fixed mean step magnitude.

    The gain is ``CALIB_TARGET_MEAN / mean_step``, limited to
    ``[CALIB_MIN_GAIN, CALIB_MAX_GAIN]``. Scaled deltas are clipped to
    ``[-1, 1]``; any further columns are left untouched.

    Returns the input unchanged when it is ``None``, not 2-D, has fewer than
    two columns, has no rows, or has a (near-)zero mean step. Otherwise a
    new float32 array is returned.
    """
    unusable = (
        seq is None
        or seq.ndim != 2
        or seq.shape[1] < 2
        or seq.shape[0] == 0
    )
    if unusable:
        return seq

    step_lengths = np.sqrt(np.square(seq[:, 0]) + np.square(seq[:, 1]))
    mean_step = float(step_lengths.mean())
    if mean_step <= 1e-6:
        # Effectively no movement: nothing meaningful to scale.
        return seq

    gain = float(np.clip(CALIB_TARGET_MEAN / mean_step, CALIB_MIN_GAIN, CALIB_MAX_GAIN))
    scaled = np.array(seq, dtype=np.float32)  # fresh copy; input is not mutated
    scaled[:, :2] = np.clip(scaled[:, :2] * gain, -1.0, 1.0)
    return scaled
111
+
112
+ # ============================================================================
113
+ # Model Loading
114
+ # ============================================================================
115
+
116
@st.cache_resource
def load_model():
    """Load the trained classifier from ``rnn_animals_best.pt`` next to this file.

    The checkpoint is read as a dict with a required ``"model_state"`` entry
    and optional ``"config"`` and ``"class_to_idx"`` entries. Missing config
    values fall back to the defaults below; a missing label map falls back
    to the order of ``ANIMALS``.

    Returns:
        ``(model, idx_to_class)`` with the model in eval mode on CPU, or
        ``(None, None)`` after showing an error when the file is missing.
    """
    model_path = Path(__file__).parent / "rnn_animals_best.pt"
    if not model_path.exists():
        st.error(f"Model file not found: {model_path}")
        return None, None

    # SECURITY NOTE: weights_only=False performs a full unpickle, which can
    # execute arbitrary code. Only acceptable because the checkpoint ships
    # with the app; never point this at an untrusted file. Switch to
    # weights_only=True once the checkpoint is confirmed to hold only
    # tensors and plain containers.
    checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
    cfg = checkpoint.get("config", {})

    # Resolve the label map first so the output layer is sized from the
    # same mapping that is later used to decode predictions.
    class_to_idx = checkpoint.get("class_to_idx") or {a: i for i, a in enumerate(ANIMALS)}

    model = GRUClassifier(
        input_size=3,
        hidden_size=cfg.get("hidden_size", 512),
        num_layers=cfg.get("num_layers", 2),
        bidirectional=cfg.get("bidirectional", True),
        dropout=cfg.get("dropout", 0.3),
        num_classes=len(class_to_idx),
        use_packing=True,
    )
    model.load_state_dict(checkpoint["model_state"])
    model.eval()

    idx_to_class = {v: k for k, v in class_to_idx.items()}
    return model, idx_to_class
142
+
143
+ # ============================================================================
144
+ # Stroke Processing
145
+ # ============================================================================
146
+
147
def canvas_strokes_to_quickdraw(canvas_json):
    """Turn canvas JSON into a list of ``[xs, ys]`` integer strokes.

    Pipeline: collect the points of every ``"path"`` object, thin long
    strokes, apply a 3-point moving average, then uniformly scale and center
    the whole drawing so its longer side spans 235 units around 127.5, with
    coordinates clamped to ``0..255`` and truncated to ints.

    Returns an empty list when ``canvas_json`` is ``None`` or contains no
    stroke with at least two points.
    """
    if canvas_json is None:
        return []

    # --- Extract points from path commands ---------------------------------
    extracted = []
    for shape in canvas_json.get("objects", []):
        if shape.get("type") != "path":
            continue
        px, py = [], []
        for command in shape.get("path", []):
            if len(command) < 3:
                continue
            op = command[0]
            if op in ("M", "L"):
                raw_x, raw_y = command[1], command[2]
            elif op == "Q" and len(command) >= 5:
                # For "Q" the last coordinate pair is kept, not the first.
                raw_x, raw_y = command[3], command[4]
            else:
                continue
            px.append(float(raw_x))
            py.append(float(raw_y))
        if len(px) >= 2:
            extracted.append((px, py))

    if not extracted:
        return []

    # --- Thin strokes with more than 25 points -----------------------------
    thinned = []
    for px, py in extracted:
        count = len(px)
        if count > 25:
            stride = max(1, count // 25)
            px, py = px[::stride], py[::stride]
        thinned.append((px, py))

    # --- 3-point moving average, endpoints kept as-is ----------------------
    def _blur(values):
        inner = [(a + b + c) / 3 for a, b, c in zip(values, values[1:], values[2:])]
        return [values[0], *inner, values[-1]]

    smoothed = [
        (_blur(px), _blur(py)) if len(px) >= 3 else (px, py)
        for px, py in thinned
    ]

    # --- Fit the bounding box into the 0..255 grid -------------------------
    flat_x = [v for px, _ in smoothed for v in px]
    flat_y = [v for _, py in smoothed for v in py]
    left, right = min(flat_x), max(flat_x)
    top, bottom = min(flat_y), max(flat_y)

    # The floor of 1 avoids dividing by zero for degenerate drawings.
    extent = max(1, right - left, bottom - top)
    scale = 235 / extent
    shift_x = 127.5 - ((left + right) / 2) * scale
    shift_y = 127.5 - ((top + bottom) / 2) * scale

    return [
        [
            [int(np.clip(v * scale + shift_x, 0, 255)) for v in px],
            [int(np.clip(v * scale + shift_y, 0, 255)) for v in py],
        ]
        for px, py in smoothed
    ]
212
+
213
def predict(model, idx_to_class, strokes):
    """Classify a drawing and return its top-ranked classes.

    Args:
        model: Classifier called as ``model(batch, lengths)``; ``None``
            short-circuits to ``None``.
        idx_to_class: Mapping from class index to label.
        strokes: List of ``[xs, ys]`` strokes.

    Returns:
        Up to five ``(label, probability)`` tuples ordered by decreasing
        probability, or ``None`` when there are no strokes, the encoded
        sequence has fewer than 6 steps, or inference fails (the error is
        shown via ``st.error``). Indices missing from ``idx_to_class`` are
        labelled with the index as a string, so labels are always strings.
    """
    if not strokes or model is None:
        return None
    try:
        seq = parse_drawing_to_seq(json.dumps(strokes))
        if seq is None or len(seq) < 6:
            return None
        seq = _calibrate_seq(seq)
        seq_t = torch.tensor(seq, dtype=torch.float32).unsqueeze(0)
        lengths = torch.tensor([seq.shape[0]], dtype=torch.long)
        with torch.no_grad():
            probs = torch.softmax(model(seq_t, lengths), dim=1)
        # topk raises if k exceeds the number of classes, so cap it.
        k = min(5, probs.shape[1])
        top_p, top_i = torch.topk(probs, k=k, dim=1)
        return [
            (idx_to_class.get(idx, str(idx)), prob)
            for idx, prob in zip(top_i[0].tolist(), top_p[0].tolist())
        ]
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing.
        st.error(f"Error: {e}")
        return None
231
+
232
+ # ============================================================================
233
+ # Main App
234
+ # ============================================================================
235
+
236
def main():
    """Render the page: drawing canvas on the left, predictions on the right."""
    st.set_page_config(page_title="Animal Doodle Classifier", page_icon="🎨", layout="wide")
    st.title("🎨 Animal Doodle Classifier")
    st.caption("Draw: butterfly, cow, elephant, giraffe, monkey, octopus, scorpion, shark, snake, spider")

    model, idx_to_class = load_model()
    if model is None:
        # load_model has already reported the problem.
        return

    canvas_col, results_col = st.columns([1, 1])

    with canvas_col:
        canvas = st_canvas(
            stroke_width=STROKE_WIDTH,
            stroke_color="#000000",
            background_color="#FFFFFF",
            height=CANVAS_SIZE,
            width=CANVAS_SIZE,
            drawing_mode="freedraw",
            key="canvas",
        )

    with results_col:
        st.subheader("Predictions")
        drawing = canvas.json_data
        strokes = canvas_strokes_to_quickdraw(drawing) if drawing else []
        ranked = predict(model, idx_to_class, strokes) if strokes else None
        if ranked:
            best_label, best_prob = ranked[0]
            st.success(f"**{best_label.upper()}** ({best_prob*100:.1f}%)")
            for name, prob in ranked:
                st.progress(prob, text=f"{name}: {prob*100:.1f}%")


if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ streamlit>=1.28.0
2
+ streamlit-drawable-canvas>=0.9.3
3
+ torch>=2.0.0
4
+ numpy>=1.24.0
rnn_animals_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3045301cb82537f7ccdaa2271c0d0944470d6a5079cc633a1bfcafb2198ac895
3
+ size 44206972