"""
RNN Animal Doodle Classifier - Gradio App for HF Spaces
Uses custom HTML canvas to capture stroke coordinates (not rasterized)
"""

import ast
import json
import os
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from torch import nn

# ============================================================================
# DIAGNOSTICS (Log to console for HF Spaces)
# ============================================================================
print("--- STARTING APP DIAGNOSTICS ---")
print(f"CWD: {os.getcwd()}")
print(f"Files in CWD: {os.listdir('.')}")
# NOTE: this check is relative to the current working directory, whereas
# load_model() below resolves the checkpoint next to this source file.
model_file = Path("rnn_animals_best.pt")
if model_file.exists():
    size = model_file.stat().st_size
    print(f"Model file found. Size: {size} bytes ({size/1024/1024:.2f} MB)")
    if size < 2000:
        print("WARNING: Model file is suspiciously small! Likely an LFS pointer file.")
        # Best-effort preview: a failure here must never stop the app from
        # starting, but it is reported instead of being silently dropped.
        try:
            with open(model_file, "r", encoding="utf-8", errors="replace") as f:
                print(f"Content preview: {f.read()}")
        except OSError as exc:
            print(f"Could not read model file preview: {exc}")
else:
    print("ERROR: Model file 'rnn_animals_best.pt' NOT FOUND in CWD!")
print("--- END DIAGNOSTICS ---")


# ============================================================================
# Model Definition
# ============================================================================

class GRUClassifier(nn.Module):
    """GRU sequence classifier (optionally bidirectional).

    The final hidden state of the last GRU layer is layer-normalised and
    projected to ``num_classes`` logits.
    """

    def __init__(self, input_size: int, hidden_size: int, num_layers: int,
                 bidirectional: bool, dropout: float, num_classes: int,
                 use_packing: bool = True):
        """Build the GRU, the LayerNorm and the output projection.

        Args:
            input_size: Number of features per time step.
            hidden_size: GRU hidden size per direction.
            num_layers: Number of stacked GRU layers.
            bidirectional: Whether the GRU runs in both directions.
            dropout: Inter-layer dropout; only applied when ``num_layers > 1``.
            num_classes: Number of output logits.
            use_packing: If true, ``forward`` packs the padded batch using the
                supplied lengths before running the GRU.
        """
        super().__init__()
        self.use_packing = use_packing
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            # nn.GRU warns when dropout is set on a single-layer network.
            dropout=dropout if num_layers > 1 else 0.0,
        )
        out_dim = hidden_size * (2 if bidirectional else 1)
        self.norm = nn.LayerNorm(out_dim)
        self.fc = nn.Linear(out_dim, num_classes)

    def forward(self, x: torch.Tensor, lengths: torch.Tensor):
        """Return class logits for a batch-first batch of sequences.

        Args:
            x: Input batch (batch-first, as configured on the GRU).
            lengths: Per-sequence lengths; only used when packing is enabled.
        """
        if self.use_packing:
            # pack_padded_sequence expects the lengths on the CPU.
            packed = nn.utils.rnn.pack_padded_sequence(
                x, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            _, h_n = self.gru(packed)
        else:
            _, h_n = self.gru(x)
        # For a bidirectional GRU the last two entries of h_n are concatenated;
        # otherwise only the last entry is used.
        if self.gru.bidirectional:
            h = torch.cat([h_n[-2], h_n[-1]], dim=1)
        else:
            h = h_n[-1]
        return self.fc(self.norm(h))


def parse_drawing_to_seq(drawing_str: str) -> np.ndarray:
    """Convert drawing JSON to sequence of [dx, dy, pen_lift].

    ``drawing_str`` is parsed as JSON, falling back to a Python literal. The
    expected payload is a list of strokes, each stroke being ``[xs, ys]``.
    Per stroke, consecutive coordinate differences are divided by 255 and the
    last step of the stroke gets ``pen_lift = 1``.

    Returns:
        A float32 array of shape ``(steps, 3)`` with dx/dy clipped to
        ``[-1, 1]``; shape ``(0, 3)`` when nothing usable could be parsed.
    """
    empty = np.zeros((0, 3), dtype=np.float32)

    try:
        strokes = json.loads(drawing_str)
    except (TypeError, ValueError, RecursionError):
        try:
            strokes = ast.literal_eval(drawing_str)
        except (TypeError, ValueError, SyntaxError, MemoryError, RecursionError):
            return empty

    # A scalar payload (e.g. "5" or "null") is not a drawing.
    if not isinstance(strokes, (list, tuple)):
        return empty

    seq_parts = []
    for stroke in strokes:
        if not isinstance(stroke, (list, tuple)) or len(stroke) != 2:
            continue
        x, y = stroke
        # Skip strokes whose coordinates are scalars rather than sequences.
        if not isinstance(x, (list, tuple)) or not isinstance(y, (list, tuple)):
            continue
        n = min(len(x), len(y))
        if n < 2:
            continue
        x = np.asarray(x[:n], dtype=np.int16)
        y = np.asarray(y[:n], dtype=np.int16)
        dx = np.diff(x).astype(np.float32) / 255.0
        dy = np.diff(y).astype(np.float32) / 255.0
        if dx.size == 0:
            continue
        pen = np.zeros_like(dx, dtype=np.float32)
        pen[-1] = 1.0  # mark the end of the stroke
        seq_parts.append(np.stack([dx, dy, pen], axis=1))

    if not seq_parts:
        return empty
    seq = np.concatenate(seq_parts, axis=0)
    seq[:, :2] = np.clip(seq[:, :2], -1.0, 1.0)
    return seq.astype(np.float32)


# ============================================================================
# Constants & Utils
# ============================================================================

ANIMALS = ["butterfly", "cow", "elephant", "giraffe", "monkey",
           "octopus", "scorpion", "shark", "snake", "spider"]


def _calibrate_seq(seq, target=0.04, max_gain=12.0, min_gain=0.5):
    """Rescale dx/dy so the mean step length moves towards ``target``.

    The gain ``target / mean_step`` is clipped to ``[min_gain, max_gain]`` and
    the scaled dx/dy are clipped to ``[-1, 1]``. ``None``, empty sequences and
    sequences with a (near-)zero mean step are returned unchanged.
    """
    if seq is None or len(seq) == 0:
        return seq
    steps = np.sqrt((seq[:, 0] ** 2) + (seq[:, 1] ** 2))
    curr = float(steps.mean()) if steps.size else 0.0
    if curr <= 1e-6:
        return seq
    gain = float(np.clip(target / curr, min_gain, max_gain))
    out = seq.astype(np.float32).copy()
    out[:, 0:2] = np.clip(out[:, 0:2] * gain, -1.0, 1.0)
    return out


def preprocess_strokes(raw_strokes):
    """Downsample, smooth, center, and scale strokes.

    Args:
        raw_strokes: Iterable of ``(xs, ys)`` coordinate sequences.

    Returns:
        A list of ``[xs, ys]`` integer lists in the 0..255 range, with the
        drawing's larger extent scaled to 235 and its bounding box centred on
        127.5. Returns ``[]`` when there are no points.
    """
    if not raw_strokes:
        return []

    # Downsample: strokes longer than 25 points keep every ``len // 25``-th.
    processed = []
    for xs, ys in raw_strokes:
        if len(xs) > 25:
            step = max(1, len(xs) // 25)
            xs, ys = xs[::step], ys[::step]
        processed.append((list(xs), list(ys)))

    # Smooth: 3-point moving average, end points kept as-is.
    smoothed = []
    for xs, ys in processed:
        if len(xs) >= 3:
            xs_s = [xs[0]] + [(xs[i-1]+xs[i]+xs[i+1])/3 for i in range(1, len(xs)-1)] + [xs[-1]]
            ys_s = [ys[0]] + [(ys[i-1]+ys[i]+ys[i+1])/3 for i in range(1, len(ys)-1)] + [ys[-1]]
            smoothed.append((xs_s, ys_s))
        else:
            smoothed.append((xs, ys))

    # Center and scale
    all_x = [x for xs, _ in smoothed for x in xs]
    all_y = [y for _, ys in smoothed for y in ys]
    if not all_x:
        return []

    min_x, max_x = min(all_x), max(all_x)
    min_y, max_y = min(all_y), max(all_y)
    # The floor of 1 avoids dividing by zero for a single-point drawing.
    scale = 235 / max(max(1, max_x - min_x), max(1, max_y - min_y))
    cx, cy = (min_x + max_x) / 2, (min_y + max_y) / 2
    ox, oy = 127.5 - cx * scale, 127.5 - cy * scale

    result = []
    for xs, ys in smoothed:
        xs_n = [int(np.clip(x * scale + ox, 0, 255)) for x in xs]
        ys_n = [int(np.clip(y * scale + oy, 0, 255)) for y in ys]
        result.append([xs_n, ys_n])
    return result


# ============================================================================
# Model Loading
# ============================================================================

def load_model():
    """Load the checkpoint stored next to this file.

    Returns:
        ``(model, idx_to_class)`` with the model in eval mode, or
        ``(None, None)`` when the checkpoint file does not exist.
    """
    model_path = Path(__file__).parent / "rnn_animals_best.pt"
    if not model_path.exists():
        return None, None

    # SECURITY: weights_only=False unpickles arbitrary Python objects. Only
    # load checkpoints from a trusted source.
    ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
    cfg = ckpt.get("config", {})

    model = GRUClassifier(
        input_size=3,
        hidden_size=cfg.get("hidden_size", 512),
        num_layers=cfg.get("num_layers", 2),
        bidirectional=cfg.get("bidirectional", True),
        dropout=cfg.get("dropout", 0.3),
        num_classes=len(ANIMALS),
        use_packing=True
    )
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    class_to_idx = ckpt.get("class_to_idx", {a: i for i, a in enumerate(ANIMALS)})
    idx_to_class = {v: k for k, v in class_to_idx.items()}
    return model, idx_to_class


MODEL = None
IDX_TO_CLASS = {}
LOAD_ERROR = None
try:
    MODEL, _loaded_idx_to_class = load_model()
    # load_model() returns None for the mapping when the checkpoint is
    # missing; keep IDX_TO_CLASS a dict so lookups never hit None.
    IDX_TO_CLASS = _loaded_idx_to_class or {}
except Exception as e:
    LOAD_ERROR = str(e)
    print(f"Failed to load model: {e}")
============================================================================ # Prediction # ============================================================================ def predict(strokes_json): """Predict from JSON stroke data.""" try: if LOAD_ERROR or MODEL is None: return {a: 0.0 for a in ANIMALS} if strokes_json is None: return {a: 0.0 for a in ANIMALS} if isinstance(strokes_json, str): s = strokes_json.strip() if not s: return {a: 0.0 for a in ANIMALS} try: raw_strokes = json.loads(s) except Exception: return {a: 0.0 for a in ANIMALS} else: raw_strokes = strokes_json if not raw_strokes: return {a: 0.0 for a in ANIMALS} # Convert to list of (xs, ys) tuples stroke_tuples = [(s[0], s[1]) for s in raw_strokes if len(s) == 2] processed = preprocess_strokes(stroke_tuples) if not processed: return {a: 0.0 for a in ANIMALS} seq = parse_drawing_to_seq(json.dumps(processed)) if seq is None or len(seq) < 3: return {a: 0.0 for a in ANIMALS} seq = _calibrate_seq(seq) seq_t = torch.tensor(seq, dtype=torch.float32).unsqueeze(0) lengths = torch.tensor([seq.shape[0]], dtype=torch.long) with torch.no_grad(): probs = torch.softmax(MODEL(seq_t, lengths), dim=1)[0] return {IDX_TO_CLASS.get(i, f"class_{i}"): float(probs[i]) for i in range(len(ANIMALS))} except Exception as e: print(f"Prediction failed: {e}") return {a: 0.0 for a in ANIMALS} # ============================================================================ # Custom Canvas HTML # ============================================================================ CANVAS_HTML = """
Draw an animal, then click Predict