Elliot Sones committed on
Commit
5fcc2e6
·
1 Parent(s): 76aaddb

Switch to Gradio with custom canvas for HF Spaces

Browse files
Files changed (3) hide show
  1. README.md +7 -7
  2. app.py +208 -139
  3. requirements.txt +1 -2
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Animal Doodle Classifier
3
  emoji: 🎨
4
  colorFrom: blue
5
  colorTo: purple
6
- sdk: streamlit
7
- sdk_version: "1.28.0"
8
  app_file: app.py
9
  pinned: false
10
  ---
@@ -14,8 +14,7 @@ pinned: false
14
  An RNN-based classifier that recognizes hand-drawn animal doodles in real-time!
15
 
16
  ## Supported Animals
17
- - butterfly, cow, elephant, giraffe, monkey
18
- - octopus, scorpion, shark, snake, spider
19
 
20
  ## Model
21
  - **Architecture:** Bidirectional GRU
@@ -24,5 +23,6 @@ An RNN-based classifier that recognizes hand-drawn animal doodles in real-time!
24
 
25
  ## How It Works
26
  1. Draw an animal on the canvas
27
- 2. Your strokes are captured and preprocessed to match Quick Draw format
28
- 3. The RNN model predicts which animal you drew
 
 
1
  ---
2
+ title: Classification Doodle RNN
3
  emoji: 🎨
4
  colorFrom: blue
5
  colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: "4.44.0"
8
  app_file: app.py
9
  pinned: false
10
  ---
 
14
  An RNN-based classifier that recognizes hand-drawn animal doodles in real-time!
15
 
16
  ## Supported Animals
17
+ butterfly, cow, elephant, giraffe, monkey, octopus, scorpion, shark, snake, spider
 
18
 
19
  ## Model
20
  - **Architecture:** Bidirectional GRU
 
23
 
24
  ## How It Works
25
  1. Draw an animal on the canvas
26
+ 2. Click **Predict**
27
+ 3. Your strokes are captured and preprocessed to match Quick Draw format
28
+ 4. The RNN model predicts which animal you drew
app.py CHANGED
@@ -1,18 +1,17 @@
1
  """
2
- RNN Animal Doodle Classifier - Hugging Face Spaces
3
- Self-contained Streamlit app with embedded model class
4
  """
5
  import ast
6
  import json
7
  from pathlib import Path
8
  import numpy as np
9
- import streamlit as st
10
- from streamlit_drawable_canvas import st_canvas
11
  import torch
12
  from torch import nn
13
 
14
  # ============================================================================
15
- # Model Definition (embedded from training-doodle.py)
16
  # ============================================================================
17
 
18
  class GRUClassifier(nn.Module):
@@ -22,11 +21,8 @@ class GRUClassifier(nn.Module):
22
  super().__init__()
23
  self.use_packing = use_packing
24
  self.gru = nn.GRU(
25
- input_size=input_size,
26
- hidden_size=hidden_size,
27
- num_layers=num_layers,
28
- batch_first=True,
29
- bidirectional=bidirectional,
30
  dropout=dropout if num_layers > 1 else 0.0,
31
  )
32
  out_dim = hidden_size * (2 if bidirectional else 1)
@@ -35,27 +31,21 @@ class GRUClassifier(nn.Module):
35
 
36
  def forward(self, x: torch.Tensor, lengths: torch.Tensor):
37
  if self.use_packing:
38
- packed = nn.utils.rnn.pack_padded_sequence(
39
- x, lengths.cpu(), batch_first=True, enforce_sorted=False
40
- )
41
  _, h_n = self.gru(packed)
42
  else:
43
  _, h_n = self.gru(x)
44
- if self.gru.bidirectional:
45
- h = torch.cat([h_n[-2], h_n[-1]], dim=1)
46
- else:
47
- h = h_n[-1]
48
- h = self.norm(h)
49
- return self.fc(h)
50
 
51
  def parse_drawing_to_seq(drawing_str: str) -> np.ndarray:
52
  """Convert drawing JSON to sequence of [dx, dy, pen_lift]."""
53
  try:
54
  strokes = json.loads(drawing_str)
55
- except Exception:
56
  try:
57
  strokes = ast.literal_eval(drawing_str)
58
- except Exception:
59
  return np.zeros((0, 3), dtype=np.float32)
60
 
61
  seq_parts = []
@@ -78,114 +68,45 @@ def parse_drawing_to_seq(drawing_str: str) -> np.ndarray:
78
 
79
  if not seq_parts:
80
  return np.zeros((0, 3), dtype=np.float32)
81
-
82
  seq = np.concatenate(seq_parts, axis=0)
83
  seq[:, :2] = np.clip(seq[:, :2], -1.0, 1.0)
84
  return seq.astype(np.float32)
85
 
86
  # ============================================================================
87
- # Constants
88
  # ============================================================================
89
 
90
- CANVAS_SIZE = 400
91
- STROKE_WIDTH = 3
92
  ANIMALS = ["butterfly", "cow", "elephant", "giraffe", "monkey",
93
  "octopus", "scorpion", "shark", "snake", "spider"]
94
 
95
- CALIB_TARGET_MEAN = 0.04
96
- CALIB_MAX_GAIN = 12.0
97
- CALIB_MIN_GAIN = 0.5
98
-
99
- def _calibrate_seq(seq: np.ndarray) -> np.ndarray:
100
- """Scale (dx, dy) so the mean step magnitude matches training data."""
101
- if seq is None or seq.ndim != 2 or seq.shape[1] < 2 or seq.shape[0] == 0:
102
  return seq
103
  steps = np.sqrt((seq[:, 0] ** 2) + (seq[:, 1] ** 2))
104
  curr = float(steps.mean()) if steps.size else 0.0
105
  if curr <= 1e-6:
106
  return seq
107
- gain = float(np.clip(CALIB_TARGET_MEAN / curr, CALIB_MIN_GAIN, CALIB_MAX_GAIN))
108
  out = seq.astype(np.float32).copy()
109
  out[:, 0:2] = np.clip(out[:, 0:2] * gain, -1.0, 1.0)
110
  return out
111
 
112
- # ============================================================================
113
- # Model Loading
114
- # ============================================================================
115
-
116
- @st.cache_resource
117
- def load_model():
118
- """Load the trained RNN model."""
119
- model_path = Path(__file__).parent / "rnn_animals_best.pt"
120
- if not model_path.exists():
121
- st.error(f"Model file not found: {model_path}")
122
- return None, None
123
-
124
- checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
125
- cfg = checkpoint.get("config", {})
126
-
127
- model = GRUClassifier(
128
- input_size=3,
129
- hidden_size=cfg.get("hidden_size", 512),
130
- num_layers=cfg.get("num_layers", 2),
131
- bidirectional=cfg.get("bidirectional", True),
132
- dropout=cfg.get("dropout", 0.3),
133
- num_classes=len(ANIMALS),
134
- use_packing=True
135
- )
136
- model.load_state_dict(checkpoint["model_state"])
137
- model.eval()
138
-
139
- class_to_idx = checkpoint.get("class_to_idx", {a: i for i, a in enumerate(ANIMALS)})
140
- idx_to_class = {v: k for k, v in class_to_idx.items()}
141
- return model, idx_to_class
142
-
143
- # ============================================================================
144
- # Stroke Processing
145
- # ============================================================================
146
-
147
- def canvas_strokes_to_quickdraw(canvas_json):
148
- """Convert canvas to QuickDraw format with preprocessing."""
149
- if canvas_json is None:
150
- return []
151
-
152
- objects = canvas_json.get("objects", [])
153
- raw_strokes = []
154
-
155
- for obj in objects:
156
- if obj.get("type") != "path":
157
- continue
158
- path = obj.get("path", [])
159
- xs, ys = [], []
160
- for cmd in path:
161
- if len(cmd) < 3:
162
- continue
163
- if cmd[0] == "M":
164
- xs.append(float(cmd[1]))
165
- ys.append(float(cmd[2]))
166
- elif cmd[0] == "Q" and len(cmd) >= 5:
167
- xs.append(float(cmd[3]))
168
- ys.append(float(cmd[4]))
169
- elif cmd[0] == "L":
170
- xs.append(float(cmd[1]))
171
- ys.append(float(cmd[2]))
172
- if len(xs) >= 2:
173
- raw_strokes.append((xs, ys))
174
-
175
  if not raw_strokes:
176
  return []
177
 
178
  # Downsample
179
- downsampled = []
180
  for xs, ys in raw_strokes:
181
  if len(xs) > 25:
182
  step = max(1, len(xs) // 25)
183
  xs, ys = xs[::step], ys[::step]
184
- downsampled.append((xs, ys))
185
 
186
  # Smooth
187
  smoothed = []
188
- for xs, ys in downsampled:
189
  if len(xs) >= 3:
190
  xs_s = [xs[0]] + [(xs[i-1]+xs[i]+xs[i+1])/3 for i in range(1, len(xs)-1)] + [xs[-1]]
191
  ys_s = [ys[0]] + [(ys[i-1]+ys[i]+ys[i+1])/3 for i in range(1, len(ys)-1)] + [ys[-1]]
@@ -196,9 +117,11 @@ def canvas_strokes_to_quickdraw(canvas_json):
196
  # Center and scale
197
  all_x = [x for xs, _ in smoothed for x in xs]
198
  all_y = [y for _, ys in smoothed for y in ys]
 
 
 
199
  min_x, max_x = min(all_x), max(all_x)
200
  min_y, max_y = min(all_y), max(all_y)
201
-
202
  scale = 235 / max(max(1, max_x - min_x), max(1, max_y - min_y))
203
  cx, cy = (min_x + max_x) / 2, (min_y + max_y) / 2
204
  ox, oy = 127.5 - cx * scale, 127.5 - cy * scale
@@ -210,56 +133,202 @@ def canvas_strokes_to_quickdraw(canvas_json):
210
  result.append([xs_n, ys_n])
211
  return result
212
 
213
- def predict(model, idx_to_class, strokes):
214
- """Make prediction from strokes."""
215
- if not strokes or model is None:
216
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  try:
218
- seq = parse_drawing_to_seq(json.dumps(strokes))
219
- if seq is None or len(seq) < 6:
220
- return None
 
 
 
 
 
 
 
 
 
 
 
 
221
  seq = _calibrate_seq(seq)
222
  seq_t = torch.tensor(seq, dtype=torch.float32).unsqueeze(0)
223
  lengths = torch.tensor([seq.shape[0]], dtype=torch.long)
 
224
  with torch.no_grad():
225
- probs = torch.softmax(model(seq_t, lengths), dim=1)
226
- top_p, top_i = torch.topk(probs, k=5, dim=1)
227
- return [(idx_to_class.get(top_i[0,i].item()), top_p[0,i].item()) for i in range(5)]
228
  except Exception as e:
229
- st.error(f"Error: {e}")
230
- return None
231
 
232
  # ============================================================================
233
- # Main App
234
  # ============================================================================
235
 
236
- def main():
237
- st.set_page_config(page_title="Animal Doodle Classifier", page_icon="🎨", layout="wide")
238
- st.title("🎨 Animal Doodle Classifier")
239
- st.caption("Draw: butterfly, cow, elephant, giraffe, monkey, octopus, scorpion, shark, snake, spider")
240
-
241
- model, idx_to_class = load_model()
242
- if model is None:
243
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
- col1, col2 = st.columns([1, 1])
246
- with col1:
247
- canvas = st_canvas(
248
- stroke_width=STROKE_WIDTH, stroke_color="#000000",
249
- background_color="#FFFFFF", height=CANVAS_SIZE, width=CANVAS_SIZE,
250
- drawing_mode="freedraw", key="canvas"
251
- )
 
252
 
253
- with col2:
254
- st.subheader("Predictions")
255
- if canvas.json_data:
256
- strokes = canvas_strokes_to_quickdraw(canvas.json_data)
257
- if strokes:
258
- results = predict(model, idx_to_class, strokes)
259
- if results:
260
- st.success(f"**{results[0][0].upper()}** ({results[0][1]*100:.1f}%)")
261
- for name, prob in results:
262
- st.progress(prob, text=f"{name}: {prob*100:.1f}%")
263
 
264
  if __name__ == "__main__":
265
- main()
 
1
  """
2
+ RNN Animal Doodle Classifier - Gradio App for HF Spaces
3
+ Uses custom HTML canvas to capture stroke coordinates (not rasterized)
4
  """
5
  import ast
6
  import json
7
  from pathlib import Path
8
  import numpy as np
9
+ import gradio as gr
 
10
  import torch
11
  from torch import nn
12
 
13
  # ============================================================================
14
+ # Model Definition
15
  # ============================================================================
16
 
17
  class GRUClassifier(nn.Module):
 
21
  super().__init__()
22
  self.use_packing = use_packing
23
  self.gru = nn.GRU(
24
+ input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
25
+ batch_first=True, bidirectional=bidirectional,
 
 
 
26
  dropout=dropout if num_layers > 1 else 0.0,
27
  )
28
  out_dim = hidden_size * (2 if bidirectional else 1)
 
31
 
32
  def forward(self, x: torch.Tensor, lengths: torch.Tensor):
33
  if self.use_packing:
34
+ packed = nn.utils.rnn.pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
 
 
35
  _, h_n = self.gru(packed)
36
  else:
37
  _, h_n = self.gru(x)
38
+ h = torch.cat([h_n[-2], h_n[-1]], dim=1) if self.gru.bidirectional else h_n[-1]
39
+ return self.fc(self.norm(h))
 
 
 
 
40
 
41
  def parse_drawing_to_seq(drawing_str: str) -> np.ndarray:
42
  """Convert drawing JSON to sequence of [dx, dy, pen_lift]."""
43
  try:
44
  strokes = json.loads(drawing_str)
45
+ except:
46
  try:
47
  strokes = ast.literal_eval(drawing_str)
48
+ except:
49
  return np.zeros((0, 3), dtype=np.float32)
50
 
51
  seq_parts = []
 
68
 
69
  if not seq_parts:
70
  return np.zeros((0, 3), dtype=np.float32)
 
71
  seq = np.concatenate(seq_parts, axis=0)
72
  seq[:, :2] = np.clip(seq[:, :2], -1.0, 1.0)
73
  return seq.astype(np.float32)
74
 
75
  # ============================================================================
76
+ # Constants & Utils
77
  # ============================================================================
78
 
 
 
79
  ANIMALS = ["butterfly", "cow", "elephant", "giraffe", "monkey",
80
  "octopus", "scorpion", "shark", "snake", "spider"]
81
 
82
+ def _calibrate_seq(seq, target=0.04, max_gain=12.0, min_gain=0.5):
83
+ if seq is None or len(seq) == 0:
 
 
 
 
 
84
  return seq
85
  steps = np.sqrt((seq[:, 0] ** 2) + (seq[:, 1] ** 2))
86
  curr = float(steps.mean()) if steps.size else 0.0
87
  if curr <= 1e-6:
88
  return seq
89
+ gain = float(np.clip(target / curr, min_gain, max_gain))
90
  out = seq.astype(np.float32).copy()
91
  out[:, 0:2] = np.clip(out[:, 0:2] * gain, -1.0, 1.0)
92
  return out
93
 
94
+ def preprocess_strokes(raw_strokes):
95
+ """Downsample, smooth, center, and scale strokes."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  if not raw_strokes:
97
  return []
98
 
99
  # Downsample
100
+ processed = []
101
  for xs, ys in raw_strokes:
102
  if len(xs) > 25:
103
  step = max(1, len(xs) // 25)
104
  xs, ys = xs[::step], ys[::step]
105
+ processed.append((list(xs), list(ys)))
106
 
107
  # Smooth
108
  smoothed = []
109
+ for xs, ys in processed:
110
  if len(xs) >= 3:
111
  xs_s = [xs[0]] + [(xs[i-1]+xs[i]+xs[i+1])/3 for i in range(1, len(xs)-1)] + [xs[-1]]
112
  ys_s = [ys[0]] + [(ys[i-1]+ys[i]+ys[i+1])/3 for i in range(1, len(ys)-1)] + [ys[-1]]
 
117
  # Center and scale
118
  all_x = [x for xs, _ in smoothed for x in xs]
119
  all_y = [y for _, ys in smoothed for y in ys]
120
+ if not all_x:
121
+ return []
122
+
123
  min_x, max_x = min(all_x), max(all_x)
124
  min_y, max_y = min(all_y), max(all_y)
 
125
  scale = 235 / max(max(1, max_x - min_x), max(1, max_y - min_y))
126
  cx, cy = (min_x + max_x) / 2, (min_y + max_y) / 2
127
  ox, oy = 127.5 - cx * scale, 127.5 - cy * scale
 
133
  result.append([xs_n, ys_n])
134
  return result
135
 
136
+ # ============================================================================
137
+ # Model Loading
138
+ # ============================================================================
139
+
140
def load_model():
    """Load the trained GRU checkpoint sitting next to this file.

    Returns:
        (model, idx_to_class) on success, or (None, None) when the
        checkpoint file is absent so the UI can degrade gracefully.
    """
    model_path = Path(__file__).parent / "rnn_animals_best.pt"
    if not model_path.exists():
        # Make the failure visible in the Space logs instead of failing
        # silently (the old Streamlit version surfaced st.error here).
        print(f"WARNING: model file not found: {model_path}")
        return None, None

    # weights_only=False is needed because the checkpoint stores plain
    # Python objects (config dict, class mapping). Only load files you
    # trust -- unpickling arbitrary checkpoints can execute code.
    ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
    cfg = ckpt.get("config", {})

    model = GRUClassifier(
        input_size=3, hidden_size=cfg.get("hidden_size", 512),
        num_layers=cfg.get("num_layers", 2), bidirectional=cfg.get("bidirectional", True),
        dropout=cfg.get("dropout", 0.3), num_classes=len(ANIMALS), use_packing=True
    )
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    # Fall back to alphabetical class order when the checkpoint predates
    # the stored class_to_idx mapping.
    class_to_idx = ckpt.get("class_to_idx", {a: i for i, a in enumerate(ANIMALS)})
    idx_to_class = {v: k for k, v in class_to_idx.items()}
    return model, idx_to_class
159
+
160
+ MODEL, IDX_TO_CLASS = load_model()
161
+
162
+ # ============================================================================
163
+ # Prediction
164
+ # ============================================================================
165
+
166
def predict(strokes_json):
    """Classify a doodle from raw canvas strokes.

    Args:
        strokes_json: either a JSON string or an already-parsed list of
            ``[xs, ys]`` stroke pairs from the drawing canvas.

    Returns:
        Mapping of animal name -> probability. All-zero scores are
        returned for an empty or too-short drawing; ``{"error": ...}``
        is returned when the model is unavailable or an exception occurs.
    """
    if MODEL is None:
        return {"error": "Model not loaded"}

    try:
        zero_scores = {animal: 0.0 for animal in ANIMALS}

        if isinstance(strokes_json, str):
            raw_strokes = json.loads(strokes_json)
        else:
            raw_strokes = strokes_json
        if not raw_strokes:
            return zero_scores

        # Keep only well-formed [xs, ys] pairs from the canvas payload.
        pairs = [(stroke[0], stroke[1]) for stroke in raw_strokes if len(stroke) == 2]
        cleaned = preprocess_strokes(pairs)
        if not cleaned:
            return zero_scores

        seq = parse_drawing_to_seq(json.dumps(cleaned))
        if seq is None or len(seq) < 3:
            # Too few points for the RNN to score meaningfully.
            return zero_scores

        seq = _calibrate_seq(seq)
        batch = torch.tensor(seq, dtype=torch.float32).unsqueeze(0)
        lengths = torch.tensor([seq.shape[0]], dtype=torch.long)

        with torch.no_grad():
            probs = torch.softmax(MODEL(batch, lengths), dim=1)[0]

        return {IDX_TO_CLASS.get(i, f"class_{i}"): float(probs[i]) for i in range(len(ANIMALS))}
    except Exception as e:
        return {"error": str(e)}
 
197
 
198
  # ============================================================================
199
+ # Custom Canvas HTML
200
  # ============================================================================
201
 
202
+ CANVAS_HTML = """
203
+ <div id="canvas-container" style="display: flex; flex-direction: column; align-items: center;">
204
+ <canvas id="drawing-canvas" width="400" height="400"
205
+ style="border: 2px solid #333; border-radius: 8px; background: white; cursor: crosshair;"></canvas>
206
+ <div style="margin-top: 10px;">
207
+ <button onclick="clearCanvas()" style="padding: 8px 16px; margin-right: 10px; cursor: pointer;">Clear</button>
208
+ <button onclick="sendStrokes()" style="padding: 8px 16px; background: #4CAF50; color: white; border: none; border-radius: 4px; cursor: pointer;">Predict</button>
209
+ </div>
210
+ <p style="color: #666; font-size: 12px; margin-top: 5px;">Draw an animal, then click Predict</p>
211
+ </div>
212
+
213
+ <script>
214
+ const canvas = document.getElementById('drawing-canvas');
215
+ const ctx = canvas.getContext('2d');
216
+ let isDrawing = false;
217
+ let strokes = [];
218
+ let currentStroke = {x: [], y: []};
219
+
220
+ ctx.strokeStyle = '#000';
221
+ ctx.lineWidth = 3;
222
+ ctx.lineCap = 'round';
223
+ ctx.lineJoin = 'round';
224
+
225
+ canvas.addEventListener('mousedown', (e) => {
226
+ isDrawing = true;
227
+ const rect = canvas.getBoundingClientRect();
228
+ const x = e.clientX - rect.left;
229
+ const y = e.clientY - rect.top;
230
+ currentStroke = {x: [x], y: [y]};
231
+ ctx.beginPath();
232
+ ctx.moveTo(x, y);
233
+ });
234
+
235
+ canvas.addEventListener('mousemove', (e) => {
236
+ if (!isDrawing) return;
237
+ const rect = canvas.getBoundingClientRect();
238
+ const x = e.clientX - rect.left;
239
+ const y = e.clientY - rect.top;
240
+ currentStroke.x.push(x);
241
+ currentStroke.y.push(y);
242
+ ctx.lineTo(x, y);
243
+ ctx.stroke();
244
+ });
245
+
246
+ canvas.addEventListener('mouseup', () => {
247
+ if (isDrawing && currentStroke.x.length > 1) {
248
+ strokes.push([currentStroke.x, currentStroke.y]);
249
+ }
250
+ isDrawing = false;
251
+ });
252
+
253
+ canvas.addEventListener('mouseleave', () => {
254
+ if (isDrawing && currentStroke.x.length > 1) {
255
+ strokes.push([currentStroke.x, currentStroke.y]);
256
+ }
257
+ isDrawing = false;
258
+ });
259
+
260
+ // Touch support
261
+ canvas.addEventListener('touchstart', (e) => {
262
+ e.preventDefault();
263
+ const touch = e.touches[0];
264
+ const rect = canvas.getBoundingClientRect();
265
+ const x = touch.clientX - rect.left;
266
+ const y = touch.clientY - rect.top;
267
+ isDrawing = true;
268
+ currentStroke = {x: [x], y: [y]};
269
+ ctx.beginPath();
270
+ ctx.moveTo(x, y);
271
+ });
272
+
273
+ canvas.addEventListener('touchmove', (e) => {
274
+ e.preventDefault();
275
+ if (!isDrawing) return;
276
+ const touch = e.touches[0];
277
+ const rect = canvas.getBoundingClientRect();
278
+ const x = touch.clientX - rect.left;
279
+ const y = touch.clientY - rect.top;
280
+ currentStroke.x.push(x);
281
+ currentStroke.y.push(y);
282
+ ctx.lineTo(x, y);
283
+ ctx.stroke();
284
+ });
285
+
286
+ canvas.addEventListener('touchend', () => {
287
+ if (isDrawing && currentStroke.x.length > 1) {
288
+ strokes.push([currentStroke.x, currentStroke.y]);
289
+ }
290
+ isDrawing = false;
291
+ });
292
+
293
+ function clearCanvas() {
294
+ ctx.clearRect(0, 0, canvas.width, canvas.height);
295
+ strokes = [];
296
+ }
297
+
298
+ function sendStrokes() {
299
+ const strokesJson = JSON.stringify(strokes);
300
+ // Update the hidden textbox with strokes data
301
+ const textbox = document.querySelector('#strokes-input textarea');
302
+ if (textbox) {
303
+ textbox.value = strokesJson;
304
+ textbox.dispatchEvent(new Event('input', { bubbles: true }));
305
+ }
306
+ // Also trigger the button
307
+ const btn = document.querySelector('#predict-btn');
308
+ if (btn) btn.click();
309
+ }
310
+ </script>
311
+ """
312
+
313
+ # ============================================================================
314
+ # Gradio App
315
+ # ============================================================================
316
+
317
+ with gr.Blocks(title="Animal Doodle Classifier", theme=gr.themes.Soft()) as app:
318
+ gr.Markdown("# 🎨 Animal Doodle Classifier")
319
+ gr.Markdown("Draw an animal and click **Predict**! Supported: butterfly, cow, elephant, giraffe, monkey, octopus, scorpion, shark, snake, spider")
320
 
321
+ with gr.Row():
322
+ with gr.Column(scale=1):
323
+ canvas = gr.HTML(CANVAS_HTML)
324
+ strokes_input = gr.Textbox(label="Strokes", elem_id="strokes-input", visible=False)
325
+ predict_btn = gr.Button("Predict", elem_id="predict-btn", visible=False)
326
+
327
+ with gr.Column(scale=1):
328
+ output = gr.Label(num_top_classes=5, label="Predictions")
329
 
330
+ predict_btn.click(fn=predict, inputs=strokes_input, outputs=output)
331
+ strokes_input.change(fn=predict, inputs=strokes_input, outputs=output)
 
 
 
 
 
 
 
 
332
 
333
  if __name__ == "__main__":
334
+ app.launch()
requirements.txt CHANGED
@@ -1,4 +1,3 @@
1
- streamlit>=1.28.0
2
- streamlit-drawable-canvas>=0.9.3
3
  torch>=2.0.0
4
  numpy>=1.24.0
 
1
+ gradio>=4.0.0
 
2
  torch>=2.0.0
3
  numpy>=1.24.0