Elliot Sones commited on
Commit
f5801b6
·
1 Parent(s): 8a44149

Fix canvas: run JS via Blocks

Browse files
Files changed (1) hide show
  1. app.py +138 -105
app.py CHANGED
@@ -237,113 +237,146 @@ CANVAS_HTML = """
237
  <canvas id="drawing-canvas" width="400" height="400"
238
  style="border: 2px solid #333; border-radius: 8px; background: white; cursor: crosshair; touch-action: none;"></canvas>
239
  <div style="margin-top: 10px;">
240
- <button onclick="clearCanvas()" style="padding: 8px 16px; margin-right: 10px; cursor: pointer; border: 1px solid #ccc; border-radius: 4px; background: #fff;">Clear</button>
241
- <button onclick="sendStrokes()" style="padding: 8px 16px; background: #4CAF50; color: white; border: none; border-radius: 4px; cursor: pointer;">Predict</button>
242
  </div>
243
  <p style="color: #666; font-size: 12px; margin-top: 5px;">Draw an animal, then click Predict</p>
244
  </div>
 
245
 
246
- <script>
247
- (function() {
248
- // Wait for DOM
249
- setTimeout(() => {
250
- const canvas = document.getElementById('drawing-canvas');
251
- if (!canvas) {
252
- console.error("Canvas not found!");
253
- return;
254
- }
255
-
256
- const ctx = canvas.getContext('2d', { willReadFrequently: true });
257
- let isDrawing = false;
258
- let strokes = [];
259
- let currentStroke = {x: [], y: []};
260
-
261
- ctx.strokeStyle = '#000';
262
- ctx.lineWidth = 3;
263
- ctx.lineCap = 'round';
264
- ctx.lineJoin = 'round';
265
-
266
- // Mouse Events
267
- canvas.addEventListener('mousedown', (e) => {
268
- isDrawing = true;
269
- const rect = canvas.getBoundingClientRect();
270
- const x = e.clientX - rect.left;
271
- const y = e.clientY - rect.top;
272
- currentStroke = {x: [x], y: [y]};
273
- ctx.beginPath();
274
- ctx.moveTo(x, y);
275
- });
276
-
277
- canvas.addEventListener('mousemove', (e) => {
278
- if (!isDrawing) return;
279
- const rect = canvas.getBoundingClientRect();
280
- const x = e.clientX - rect.left;
281
- const y = e.clientY - rect.top;
282
- currentStroke.x.push(x);
283
- currentStroke.y.push(y);
284
- ctx.lineTo(x, y);
285
- ctx.stroke();
286
- });
287
-
288
- const endStroke = () => {
289
- if (isDrawing && currentStroke.x.length > 0) {
290
- strokes.push([currentStroke.x, currentStroke.y]);
291
- }
292
- isDrawing = false;
293
- };
294
-
295
- canvas.addEventListener('mouseup', endStroke);
296
- canvas.addEventListener('mouseleave', endStroke);
297
-
298
- // Touch Events
299
- canvas.addEventListener('touchstart', (e) => {
300
- e.preventDefault();
301
- const touch = e.touches[0];
302
- const rect = canvas.getBoundingClientRect();
303
- const x = touch.clientX - rect.left;
304
- const y = touch.clientY - rect.top;
305
- isDrawing = true;
306
- currentStroke = {x: [x], y: [y]};
307
- ctx.beginPath();
308
- ctx.moveTo(x, y);
309
- }, { passive: false });
310
-
311
- canvas.addEventListener('touchmove', (e) => {
312
- e.preventDefault();
313
- if (!isDrawing) return;
314
- const touch = e.touches[0];
315
- const rect = canvas.getBoundingClientRect();
316
- const x = touch.clientX - rect.left;
317
- const y = touch.clientY - rect.top;
318
- currentStroke.x.push(x);
319
- currentStroke.y.push(y);
320
- ctx.lineTo(x, y);
321
- ctx.stroke();
322
- }, { passive: false });
323
-
324
- canvas.addEventListener('touchend', endStroke);
325
-
326
- // Global functions for buttons
327
- window.clearCanvas = function() {
328
- ctx.clearRect(0, 0, canvas.width, canvas.height);
329
- strokes = [];
330
- };
331
-
332
- window.sendStrokes = function() {
333
- const strokesJson = JSON.stringify(strokes);
334
- const textbox = document.querySelector('#strokes-input textarea');
335
- if (textbox) {
336
- textbox.value = strokesJson;
337
- textbox.dispatchEvent(new Event('input', { bubbles: true }));
338
- }
339
- const btn = document.querySelector('#predict-btn');
340
- if (btn) btn.click();
341
- };
342
-
343
- console.log("Canvas initialized!");
344
- }, 500); // Small delay to ensure render
345
- })();
346
- </script>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  """
348
 
349
  # ============================================================================
@@ -356,7 +389,7 @@ CSS = """
356
  }
357
  """
358
 
359
- with gr.Blocks(title="Animal Doodle Classifier", theme=gr.themes.Soft(), css=CSS) as app:
360
  gr.Markdown("# 🎨 Animal Doodle Classifier")
361
  gr.Markdown("Draw an animal and click **Predict**! Supported: butterfly, cow, elephant, giraffe, monkey, octopus, scorpion, shark, snake, spider")
362
 
@@ -364,7 +397,7 @@ with gr.Blocks(title="Animal Doodle Classifier", theme=gr.themes.Soft(), css=CSS
364
  with gr.Column(scale=1):
365
  canvas = gr.HTML(CANVAS_HTML)
366
  # visible=True so they are in DOM, hidden by CSS
367
- strokes_input = gr.Textbox(label="Strokes", elem_id="strokes-input", visible=True)
368
  predict_btn = gr.Button("Predict", elem_id="predict-btn", visible=True)
369
 
370
  with gr.Column(scale=1):
 
237
  <canvas id="drawing-canvas" width="400" height="400"
238
  style="border: 2px solid #333; border-radius: 8px; background: white; cursor: crosshair; touch-action: none;"></canvas>
239
  <div style="margin-top: 10px;">
240
+ <button id="clear-canvas-btn" style="padding: 8px 16px; margin-right: 10px; cursor: pointer; border: 1px solid #ccc; border-radius: 4px; background: #fff;">Clear</button>
241
+ <button id="predict-canvas-btn" style="padding: 8px 16px; background: #4CAF50; color: white; border: none; border-radius: 4px; cursor: pointer;">Predict</button>
242
  </div>
243
  <p style="color: #666; font-size: 12px; margin-top: 5px;">Draw an animal, then click Predict</p>
244
  </div>
245
+ """
246
 
247
+ CANVAS_JS = r"""() => {
248
+ const CANVAS_ID = "drawing-canvas";
249
+ const CLEAR_ID = "clear-canvas-btn";
250
+ const PREDICT_ID = "predict-canvas-btn";
251
+
252
+ const getTextInput = () =>
253
+ document.querySelector("#strokes-input textarea, #strokes-input input");
254
+
255
+ const getGradioPredictButton = () =>
256
+ document.querySelector("#predict-btn button") ||
257
+ document.querySelector("button#predict-btn") ||
258
+ document.querySelector("#predict-btn");
259
+
260
+ const initCanvas = () => {
261
+ const canvas = document.getElementById(CANVAS_ID);
262
+ const clearBtn = document.getElementById(CLEAR_ID);
263
+ const predictBtn = document.getElementById(PREDICT_ID);
264
+ if (!canvas || !clearBtn || !predictBtn) return false;
265
+ if (canvas.dataset.bound === "1") return true;
266
+
267
+ const ctx = canvas.getContext("2d", { willReadFrequently: true });
268
+ if (!ctx) return false;
269
+
270
+ canvas.dataset.bound = "1";
271
+
272
+ let isDrawing = false;
273
+ let strokes = [];
274
+ let currentStroke = { x: [], y: [] };
275
+
276
+ ctx.strokeStyle = "#000";
277
+ ctx.lineWidth = 3;
278
+ ctx.lineCap = "round";
279
+ ctx.lineJoin = "round";
280
+
281
+ const getPos = (clientX, clientY) => {
282
+ const rect = canvas.getBoundingClientRect();
283
+ return [clientX - rect.left, clientY - rect.top];
284
+ };
285
+
286
+ const startStroke = (x, y) => {
287
+ isDrawing = true;
288
+ currentStroke = { x: [x], y: [y] };
289
+ ctx.beginPath();
290
+ ctx.moveTo(x, y);
291
+ };
292
+
293
+ const moveStroke = (x, y) => {
294
+ if (!isDrawing) return;
295
+ currentStroke.x.push(x);
296
+ currentStroke.y.push(y);
297
+ ctx.lineTo(x, y);
298
+ ctx.stroke();
299
+ };
300
+
301
+ const endStroke = () => {
302
+ if (isDrawing && currentStroke.x.length > 0) {
303
+ strokes.push([currentStroke.x, currentStroke.y]);
304
+ }
305
+ isDrawing = false;
306
+ };
307
+
308
+ canvas.addEventListener("mousedown", (e) => {
309
+ const [x, y] = getPos(e.clientX, e.clientY);
310
+ startStroke(x, y);
311
+ });
312
+
313
+ canvas.addEventListener("mousemove", (e) => {
314
+ const [x, y] = getPos(e.clientX, e.clientY);
315
+ moveStroke(x, y);
316
+ });
317
+
318
+ canvas.addEventListener("mouseup", endStroke);
319
+ canvas.addEventListener("mouseleave", endStroke);
320
+
321
+ canvas.addEventListener(
322
+ "touchstart",
323
+ (e) => {
324
+ e.preventDefault();
325
+ const touch = e.touches[0];
326
+ const [x, y] = getPos(touch.clientX, touch.clientY);
327
+ startStroke(x, y);
328
+ },
329
+ { passive: false }
330
+ );
331
+
332
+ canvas.addEventListener(
333
+ "touchmove",
334
+ (e) => {
335
+ e.preventDefault();
336
+ if (!isDrawing) return;
337
+ const touch = e.touches[0];
338
+ const [x, y] = getPos(touch.clientX, touch.clientY);
339
+ moveStroke(x, y);
340
+ },
341
+ { passive: false }
342
+ );
343
+
344
+ canvas.addEventListener("touchend", endStroke);
345
+ canvas.addEventListener("touchcancel", endStroke);
346
+
347
+ clearBtn.addEventListener("click", () => {
348
+ ctx.clearRect(0, 0, canvas.width, canvas.height);
349
+ strokes = [];
350
+ const textbox = getTextInput();
351
+ if (textbox) {
352
+ textbox.value = "";
353
+ textbox.dispatchEvent(new Event("input", { bubbles: true }));
354
+ }
355
+ });
356
+
357
+ predictBtn.addEventListener("click", () => {
358
+ const strokesJson = JSON.stringify(strokes);
359
+ const textbox = getTextInput();
360
+ if (textbox) {
361
+ textbox.value = strokesJson;
362
+ textbox.dispatchEvent(new Event("input", { bubbles: true }));
363
+ }
364
+ const btn = getGradioPredictButton();
365
+ if (btn) btn.click();
366
+ });
367
+
368
+ return true;
369
+ };
370
+
371
+ const startedAt = Date.now();
372
+ const maxWaitMs = 10000;
373
+ const tick = () => {
374
+ if (initCanvas()) return;
375
+ if (Date.now() - startedAt > maxWaitMs) return;
376
+ requestAnimationFrame(tick);
377
+ };
378
+ tick();
379
+ }
380
  """
381
 
382
  # ============================================================================
 
389
  }
390
  """
391
 
392
+ with gr.Blocks(title="Animal Doodle Classifier", theme=gr.themes.Soft(), css=CSS, js=CANVAS_JS) as app:
393
  gr.Markdown("# 🎨 Animal Doodle Classifier")
394
  gr.Markdown("Draw an animal and click **Predict**! Supported: butterfly, cow, elephant, giraffe, monkey, octopus, scorpion, shark, snake, spider")
395
 
 
397
  with gr.Column(scale=1):
398
  canvas = gr.HTML(CANVAS_HTML)
399
  # visible=True so they are in DOM, hidden by CSS
400
+ strokes_input = gr.Textbox(label="Strokes", elem_id="strokes-input", visible=True, lines=3)
401
  predict_btn = gr.Button("Predict", elem_id="predict-btn", visible=True)
402
 
403
  with gr.Column(scale=1):