Spaces:
Sleeping
Sleeping
| """ | |
| RNN Animal Doodle Classifier - Gradio App for HF Spaces | |
| Uses custom HTML canvas to capture stroke coordinates (not rasterized) | |
| """ | |
| import ast | |
| import json | |
| from pathlib import Path | |
| import numpy as np | |
| import gradio as gr | |
| import torch | |
| from torch import nn | |
| import os | |
# ============================================================================
# DIAGNOSTICS (Log to console for HF Spaces)
# ============================================================================
def run_startup_diagnostics(model_file: Path) -> None:
    """Print environment and checkpoint details to the console.

    Lists the working directory and its files, reports the checkpoint size,
    and warns (with a content preview) when the file is under 2000 bytes,
    which usually means a Git LFS pointer was committed instead of weights.

    Args:
        model_file: Checkpoint path to inspect. A missing or unreadable file
            is reported, never raised.
    """
    print("--- STARTING APP DIAGNOSTICS ---")
    print(f"CWD: {os.getcwd()}")
    print(f"Files in CWD: {os.listdir('.')}")
    if model_file.exists():
        size = model_file.stat().st_size
        print(f"Model file found. Size: {size} bytes ({size/1024/1024:.2f} MB)")
        if size < 2000:
            print("WARNING: Model file is suspiciously small! Likely an LFS pointer file.")
            try:
                # errors="replace": a tiny binary file must not raise UnicodeDecodeError.
                preview = model_file.read_text(encoding="utf-8", errors="replace")
            except OSError as exc:
                print(f"Could not read model file for preview: {exc}")
            else:
                print(f"Content preview: {preview}")
    else:
        print(f"ERROR: Model file '{model_file.name}' NOT FOUND at {model_file}!")
    print("--- END DIAGNOSTICS ---")


# Inspect the same location load_model() reads from (next to this script),
# not the CWD, so the diagnostics describe the file that will actually load.
model_file = Path(__file__).parent / "rnn_animals_best.pt"
run_startup_diagnostics(model_file)
| # ============================================================================ | |
| # Model Definition | |
| # ============================================================================ | |
class GRUClassifier(nn.Module):
    """Classify padded stroke sequences with a (bi)directional GRU.

    The top GRU layer's final hidden state (both directions concatenated when
    ``bidirectional``) is layer-normalised and projected to ``num_classes``
    logits. The submodule attribute names ``gru``, ``norm`` and ``fc`` are the
    ``state_dict`` keys of saved checkpoints.
    """

    def __init__(self, input_size: int, hidden_size: int, num_layers: int,
                 bidirectional: bool, dropout: float, num_classes: int, use_packing: bool = True):
        super().__init__()
        self.use_packing = use_packing
        # nn.GRU only applies dropout between stacked layers, so a single
        # layer is given 0.0.
        inter_layer_dropout = 0.0 if num_layers <= 1 else dropout
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=inter_layer_dropout,
        )
        directions = 2 if bidirectional else 1
        feature_dim = directions * hidden_size
        self.norm = nn.LayerNorm(feature_dim)
        self.fc = nn.Linear(feature_dim, num_classes)

    def forward(self, x: torch.Tensor, lengths: torch.Tensor):
        """Return ``(batch, num_classes)`` logits.

        Args:
            x: Padded batch-first input, ``(batch, seq_len, input_size)``.
            lengths: Valid length of each sequence. It is moved to the CPU
                for packing and ignored when ``use_packing`` is False.
                Presumably every length must be >= 1 (a pack_padded_sequence
                requirement); callers should confirm.
        """
        if not self.use_packing:
            gru_input = x
        else:
            # Packing stops the GRU at each sequence's true length, so
            # padding does not leak into the final hidden state.
            gru_input = nn.utils.rnn.pack_padded_sequence(
                x, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
        _, final_hidden = self.gru(gru_input)
        if self.gru.bidirectional:
            # The last two entries are the top layer's forward and backward states.
            summary = torch.cat((final_hidden[-2], final_hidden[-1]), dim=1)
        else:
            summary = final_hidden[-1]
        return self.fc(self.norm(summary))
def parse_drawing_to_seq(drawing_str: str) -> np.ndarray:
    """Convert a serialized drawing into an ``(N, 3)`` float32 sequence.

    ``drawing_str`` encodes a list of strokes, each ``[xs, ys]``, as JSON or,
    as a fallback, as a Python literal. Each output row is
    ``[dx, dy, pen_lift]``:

    * ``dx``/``dy`` are consecutive coordinate deltas divided by 255 and
      clipped to ``[-1, 1]``. Coordinates are truncated to integers first.
    * ``pen_lift`` is ``1.0`` on the last delta of every stroke and ``0.0``
      elsewhere.

    Malformed strokes and strokes with fewer than two points are skipped.
    Unparseable input, or input that is not a list of strokes, gives an empty
    ``(0, 3)`` array instead of raising.
    """
    empty = np.zeros((0, 3), dtype=np.float32)
    try:
        strokes = json.loads(drawing_str)
    except (TypeError, ValueError):  # ValueError covers json.JSONDecodeError
        try:
            # literal_eval evaluates literals only, so it is safe on untrusted text.
            strokes = ast.literal_eval(drawing_str)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return empty
    if not isinstance(strokes, (list, tuple)):
        return empty

    seq_parts = []
    for stroke in strokes:
        if not isinstance(stroke, (list, tuple)) or len(stroke) != 2:
            continue
        x, y = stroke
        try:
            n = min(len(x), len(y))
        except TypeError:  # coordinates are not sequences
            continue
        if n < 2:
            continue
        try:
            # int32 rather than int16 so large coordinates cannot overflow;
            # results for the usual 0..255 range are unchanged.
            x = np.asarray(x[:n], dtype=np.int32)
            y = np.asarray(y[:n], dtype=np.int32)
        except (TypeError, ValueError, OverflowError):  # non-numeric coordinates
            continue
        dx = np.diff(x).astype(np.float32) / 255.0
        dy = np.diff(y).astype(np.float32) / 255.0
        pen = np.zeros_like(dx, dtype=np.float32)
        pen[-1] = 1.0  # n >= 2 guarantees at least one delta
        seq_parts.append(np.stack([dx, dy, pen], axis=1))

    if not seq_parts:
        return empty
    seq = np.concatenate(seq_parts, axis=0)
    seq[:, :2] = np.clip(seq[:, :2], -1.0, 1.0)
    return seq.astype(np.float32)
| # ============================================================================ | |
| # Constants & Utils | |
| # ============================================================================ | |
# Known class labels. This order is the fallback class_to_idx mapping when a
# checkpoint does not provide one, and the label set of an all-zero result.
ANIMALS = [
    "butterfly",
    "cow",
    "elephant",
    "giraffe",
    "monkey",
    "octopus",
    "scorpion",
    "shark",
    "snake",
    "spider",
]
| def _calibrate_seq(seq, target=0.04, max_gain=12.0, min_gain=0.5): | |
| if seq is None or len(seq) == 0: | |
| return seq | |
| steps = np.sqrt((seq[:, 0] ** 2) + (seq[:, 1] ** 2)) | |
| curr = float(steps.mean()) if steps.size else 0.0 | |
| if curr <= 1e-6: | |
| return seq | |
| gain = float(np.clip(target / curr, min_gain, max_gain)) | |
| out = seq.astype(np.float32).copy() | |
| out[:, 0:2] = np.clip(out[:, 0:2] * gain, -1.0, 1.0) | |
| return out | |
def _downsample_stroke(xs, ys, max_points=25):
    """Keep every ``len(xs) // max_points``-th point of strokes longer than ``max_points``."""
    if len(xs) > max_points:
        # Integer step: strokes shorter than 2 * max_points keep every point.
        step = max(1, len(xs) // max_points)
        xs, ys = xs[::step], ys[::step]
    return list(xs), list(ys)


def _smooth_1d(values):
    """3-point moving average over interior values; endpoints are kept as-is."""
    interior = [(a + b + c) / 3 for a, b, c in zip(values, values[1:], values[2:])]
    return [values[0]] + interior + [values[-1]]


def _smooth_stroke(xs, ys):
    """Smooth both coordinate lists of a stroke that has at least 3 x-values."""
    if len(xs) >= 3:
        return _smooth_1d(xs), _smooth_1d(ys)
    return xs, ys


def _to_canvas(values, scale, offset):
    """Affine-map values to integers clipped to the 0..255 range."""
    return [int(np.clip(v * scale + offset, 0, 255)) for v in values]


def preprocess_strokes(raw_strokes, *, max_points=25):
    """Downsample, smooth, center, and scale strokes.

    Args:
        raw_strokes: Iterable of ``(xs, ys)`` coordinate pairs, one per stroke.
        max_points: Strokes longer than this are downsampled by an integer
            step of ``len(xs) // max_points``.

    Returns:
        A list of ``[xs, ys]`` integer lists. The drawing is scaled uniformly
        so its larger bounding-box side spans 235 units, centered at 127.5,
        and clipped to 0..255. Returns ``[]`` for empty input or when no
        points remain.
    """
    if not raw_strokes:
        return []
    strokes = [_smooth_stroke(*_downsample_stroke(xs, ys, max_points)) for xs, ys in raw_strokes]

    all_x = [x for xs, _ in strokes for x in xs]
    all_y = [y for _, ys in strokes for y in ys]
    if not all_x:
        return []
    min_x, max_x = min(all_x), max(all_x)
    min_y, max_y = min(all_y), max(all_y)
    # Uniform scale (aspect ratio preserved); max(1, ...) avoids dividing by
    # zero for a perfectly flat or single-point drawing.
    scale = 235 / max(max(1, max_x - min_x), max(1, max_y - min_y))
    cx, cy = (min_x + max_x) / 2, (min_y + max_y) / 2
    ox, oy = 127.5 - cx * scale, 127.5 - cy * scale
    return [[_to_canvas(xs, scale, ox), _to_canvas(ys, scale, oy)] for xs, ys in strokes]
| # ============================================================================ | |
| # Model Loading | |
| # ============================================================================ | |
def load_model(model_path=None):
    """Load the GRU checkpoint and its index-to-label mapping.

    Args:
        model_path: Checkpoint file (str or Path). Defaults to
            ``rnn_animals_best.pt`` next to this script.

    Returns:
        ``(model, idx_to_class)`` with the model in eval mode, or
        ``(None, {})`` when the checkpoint file does not exist.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise for an unreadable
        or incompatible checkpoint (e.g. a Git LFS pointer file). The
        module-level loader catches these.
    """
    if model_path is None:
        model_path = Path(__file__).parent / "rnn_animals_best.pt"
    model_path = Path(model_path)
    if not model_path.exists():
        # An empty mapping (not None) matches the IDX_TO_CLASS default type.
        return None, {}
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects.
    # This is acceptable only for the checkpoint bundled with the app; never
    # pass a user-supplied path here.
    ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
    cfg = ckpt.get("config", {})
    model = GRUClassifier(
        input_size=3,  # [dx, dy, pen_lift] per step
        hidden_size=cfg.get("hidden_size", 512),
        num_layers=cfg.get("num_layers", 2),
        bidirectional=cfg.get("bidirectional", True),
        dropout=cfg.get("dropout", 0.3),
        num_classes=len(ANIMALS),
        use_packing=True,
    )
    model.load_state_dict(ckpt["model_state"])
    model.eval()
    class_to_idx = ckpt.get("class_to_idx", {a: i for i, a in enumerate(ANIMALS)})
    idx_to_class = {v: k for k, v in class_to_idx.items()}
    return model, idx_to_class
# Module-level model state read by predict(). LOAD_ERROR holds a message
# whenever the model is unavailable, so the logs explain the all-zero output.
MODEL = None
IDX_TO_CLASS = {}
LOAD_ERROR = None
try:
    MODEL, IDX_TO_CLASS = load_model()
except Exception as e:  # top-level boundary: keep the app up and log the cause
    LOAD_ERROR = str(e)
    print(f"Failed to load model: {e}")
else:
    if MODEL is None:
        LOAD_ERROR = "Model checkpoint 'rnn_animals_best.pt' not found; predictions disabled."
        print(LOAD_ERROR)
# load_model() may return None for the mapping when the file is missing;
# keep IDX_TO_CLASS a dict either way.
IDX_TO_CLASS = IDX_TO_CLASS or {}
| # ============================================================================ | |
| # Prediction | |
| # ============================================================================ | |
def _zero_scores():
    """Return the "no prediction" result: every label in ANIMALS scored 0.0."""
    return {a: 0.0 for a in ANIMALS}


def predict(strokes_json):
    """Predict from JSON stroke data.

    Args:
        strokes_json: A JSON string, or an already-decoded list, of strokes,
            each ``[xs, ys]`` as written by the canvas script. ``None`` and
            empty input are allowed.

    Returns:
        A ``{label: probability}`` dict for ``gr.Label``, with one entry per
        model output. The all-zero ``ANIMALS`` dict is returned when:
        the model is unavailable; the input is empty, unparseable or not a
        list; the preprocessed drawing has fewer than 3 steps; or inference
        raises.
    """
    try:
        if LOAD_ERROR or MODEL is None or strokes_json is None:
            return _zero_scores()
        if isinstance(strokes_json, str):
            text = strokes_json.strip()
            if not text:
                return _zero_scores()
            try:
                raw_strokes = json.loads(text)
            except ValueError:  # includes json.JSONDecodeError
                return _zero_scores()
        else:
            raw_strokes = strokes_json
        if not raw_strokes or not isinstance(raw_strokes, (list, tuple)):
            return _zero_scores()

        # Keep only well-formed [xs, ys] pairs; other items are ignored
        # instead of aborting the whole prediction.
        stroke_tuples = [
            (stroke[0], stroke[1])
            for stroke in raw_strokes
            if isinstance(stroke, (list, tuple)) and len(stroke) == 2
        ]
        processed = preprocess_strokes(stroke_tuples)
        if not processed:
            return _zero_scores()
        # Round-trip through JSON to reuse parse_drawing_to_seq's string parser.
        seq = parse_drawing_to_seq(json.dumps(processed))
        if seq is None or len(seq) < 3:
            return _zero_scores()
        seq = _calibrate_seq(seq)

        seq_t = torch.tensor(seq, dtype=torch.float32).unsqueeze(0)  # batch of one
        lengths = torch.tensor([seq.shape[0]], dtype=torch.long)
        with torch.no_grad():
            probs = torch.softmax(MODEL(seq_t, lengths), dim=1)[0].tolist()
        # Enumerate the model's actual outputs so labels always match num_classes.
        idx_to_class = IDX_TO_CLASS or {}
        return {idx_to_class.get(i, f"class_{i}"): p for i, p in enumerate(probs)}
    except Exception as e:  # UI boundary: log and degrade instead of erroring
        print(f"Prediction failed: {e}")
        return _zero_scores()
| # ============================================================================ | |
| # Custom Canvas HTML | |
| # ============================================================================ | |
# Markup for the drawing surface, rendered through gr.HTML. The element ids
# drawing-canvas, clear-canvas-btn and predict-canvas-btn are the ones the
# CANVAS_JS script looks up; keep them in sync if either side changes.
CANVAS_HTML = """
<div id="canvas-container" style="display: flex; flex-direction: column; align-items: center; position: relative; z-index: 10;">
<canvas id="drawing-canvas" width="400" height="400"
style="border: 2px solid #333; border-radius: 8px; background: white; cursor: crosshair; touch-action: none;"></canvas>
<div style="margin-top: 10px;">
<button id="clear-canvas-btn" style="padding: 8px 16px; margin-right: 10px; cursor: pointer; border: 1px solid #ccc; border-radius: 4px; background: #fff;">Clear</button>
<button id="predict-canvas-btn" style="padding: 8px 16px; background: #4CAF50; color: white; border: none; border-radius: 4px; cursor: pointer;">Predict</button>
</div>
<p style="color: #666; font-size: 12px; margin-top: 5px;">Draw an animal, then click Predict</p>
</div>
"""
# Client-side script passed to gr.Blocks(js=...). It polls (for up to 10 s)
# until the canvas markup exists, then binds mouse/touch drawing, records each
# stroke as [xs, ys] in canvas pixel coordinates, mirrors the stroke list as
# JSON into the hidden #strokes-input textbox, and forwards the custom Predict
# button to the hidden Gradio #predict-btn.
CANVAS_JS = r"""() => {
  const CANVAS_ID = "drawing-canvas";
  const CLEAR_ID = "clear-canvas-btn";
  const PREDICT_ID = "predict-canvas-btn";

  const getTextInput = () =>
    document.querySelector("#strokes-input textarea, #strokes-input input");
  const getGradioPredictButton = () =>
    document.querySelector("#predict-btn button") ||
    document.querySelector("button#predict-btn") ||
    document.querySelector("#predict-btn");

  const initCanvas = () => {
    const canvas = document.getElementById(CANVAS_ID);
    const clearBtn = document.getElementById(CLEAR_ID);
    const predictBtn = document.getElementById(PREDICT_ID);
    if (!canvas || !clearBtn || !predictBtn) return false;
    // Never bind the listeners twice.
    if (canvas.dataset.bound === "1") return true;
    const ctx = canvas.getContext("2d", { willReadFrequently: true });
    if (!ctx) return false;
    canvas.dataset.bound = "1";

    let isDrawing = false;
    let strokes = [];
    let currentStroke = { x: [], y: [] };

    ctx.strokeStyle = "#000";
    ctx.lineWidth = 3;
    ctx.lineCap = "round";
    ctx.lineJoin = "round";

    // Map viewport coordinates onto the canvas' own pixel grid.
    // getBoundingClientRect() includes the CSS border, so subtract
    // clientLeft/clientTop, then rescale by the rendered size so ink stays
    // under the pointer even when CSS resizes the 400x400 canvas.
    const getPos = (clientX, clientY) => {
      const rect = canvas.getBoundingClientRect();
      const sx = canvas.clientWidth > 0 ? canvas.width / canvas.clientWidth : 1;
      const sy = canvas.clientHeight > 0 ? canvas.height / canvas.clientHeight : 1;
      return [
        (clientX - rect.left - canvas.clientLeft) * sx,
        (clientY - rect.top - canvas.clientTop) * sy,
      ];
    };

    // Write the full stroke list to the hidden textbox and fire an input
    // event so Gradio registers the new value.
    const syncToTextbox = () => {
      const textbox = getTextInput();
      if (!textbox) return;
      textbox.value = JSON.stringify(strokes);
      textbox.dispatchEvent(new Event("input", { bubbles: true }));
    };

    const startStroke = (x, y) => {
      isDrawing = true;
      currentStroke = { x: [x], y: [y] };
      ctx.beginPath();
      ctx.moveTo(x, y);
    };

    const moveStroke = (x, y) => {
      if (!isDrawing) return;
      currentStroke.x.push(x);
      currentStroke.y.push(y);
      ctx.lineTo(x, y);
      ctx.stroke();
    };

    // Commit the in-progress stroke. This is a no-op when no stroke is
    // active, so mouseleave/touchcancel outside a stroke do not re-sync.
    const endStroke = () => {
      if (!isDrawing) return;
      isDrawing = false;
      if (currentStroke.x.length > 0) {
        strokes.push([currentStroke.x, currentStroke.y]);
      }
      syncToTextbox();
    };

    canvas.addEventListener("mousedown", (e) => {
      const [x, y] = getPos(e.clientX, e.clientY);
      startStroke(x, y);
    });
    canvas.addEventListener("mousemove", (e) => {
      const [x, y] = getPos(e.clientX, e.clientY);
      moveStroke(x, y);
    });
    canvas.addEventListener("mouseup", endStroke);
    canvas.addEventListener("mouseleave", endStroke);

    // passive: false so preventDefault() can stop the page from scrolling.
    canvas.addEventListener(
      "touchstart",
      (e) => {
        e.preventDefault();
        const touch = e.touches[0];
        const [x, y] = getPos(touch.clientX, touch.clientY);
        startStroke(x, y);
      },
      { passive: false }
    );
    canvas.addEventListener(
      "touchmove",
      (e) => {
        e.preventDefault();
        if (!isDrawing) return;
        const touch = e.touches[0];
        const [x, y] = getPos(touch.clientX, touch.clientY);
        moveStroke(x, y);
      },
      { passive: false }
    );
    canvas.addEventListener("touchend", endStroke);
    canvas.addEventListener("touchcancel", endStroke);

    clearBtn.addEventListener("click", () => {
      ctx.clearRect(0, 0, canvas.width, canvas.height);
      strokes = [];
      syncToTextbox();
    });

    predictBtn.addEventListener("click", () => {
      syncToTextbox();
      const btn = getGradioPredictButton();
      if (btn) btn.click();
    });

    return true;
  };

  // Gradio renders the HTML asynchronously; retry each frame for up to 10 s.
  const startedAt = Date.now();
  const maxWaitMs = 10000;
  const tick = () => {
    if (initCanvas()) return;
    if (Date.now() - startedAt > maxWaitMs) return;
    requestAnimationFrame(tick);
  };
  tick();
}
"""
| # ============================================================================ | |
| # Gradio App | |
| # ============================================================================ | |
# Hide the helper textbox and Gradio button that the custom canvas drives
# through CANVAS_JS. They are created with visible=True so they stay in the
# DOM for the script to find; this CSS hides them from view instead.
CSS = """
#strokes-input, #predict-btn {
display: none !important;
}
"""
# Page layout: the custom canvas (gr.HTML) and the hidden textbox/button it
# writes to and clicks sit on the left; the prediction label sits on the
# right. CANVAS_JS is injected via js= and binds to the canvas once rendered.
with gr.Blocks(title="Animal Doodle Classifier", theme=gr.themes.Soft(), css=CSS, js=CANVAS_JS) as app:
    gr.Markdown("# 🎨 Animal Doodle Classifier")
    gr.Markdown("Draw an animal and click **Predict**! Supported: butterfly, cow, elephant, giraffe, monkey, octopus, scorpion, shark, snake, spider")
    with gr.Row():
        with gr.Column(scale=1):
            canvas = gr.HTML(CANVAS_HTML)
            # visible=True so they are in DOM, hidden by CSS
            strokes_input = gr.Textbox(label="Strokes", elem_id="strokes-input", visible=True, lines=3)
            predict_btn = gr.Button("Predict", elem_id="predict-btn", visible=True)
        with gr.Column(scale=1):
            output = gr.Label(num_top_classes=5, label="Predictions")
    # NOTE(review): predict runs on every textbox change (the canvas script
    # syncs after each stroke) and again on button click. Clicking Predict
    # after drawing therefore repeats the same prediction. This presumably
    # provides a live preview; confirm it is intentional.
    predict_btn.click(fn=predict, inputs=strokes_input, outputs=output)
    strokes_input.change(fn=predict, inputs=strokes_input, outputs=output)
if __name__ == "__main__":
    app.launch()