QuickDraw 345 Doodle Classifier (TFLite)

A doodle-recognition model trained on all 345 categories of Google's Quick Draw Dataset, exported to TFLite for on-device, offline inference in Flutter apps.

Model Performance

| Metric | Accuracy |
|---|---|
| Top-1 | 76.19% |
| Top-3 | 89.51% |
| Top-5 | 92.26% |
| Top-10 | 94.55% |
| TFLite (float16) | 76.40% |

Commonly reported top-1 accuracy for 345-class Quick Draw classification is roughly 73-75%; this model exceeds that range.

Files

| File | Size | Description |
|---|---|---|
| quickdraw_model.tflite | 8.44 MB | Float16 quantized; recommended for Flutter |
| quickdraw_model_int8.tflite | 4.34 MB | Int8 quantized; smallest and fastest |
| labels.txt | 2.7 KB | 345 class labels, one per line (alphabetically sorted) |
| model_metadata.json | 6.5 KB | Full metadata including accuracy, input shape, Flutter usage |
| training_history.json | 5.7 KB | Loss/accuracy per epoch |
| categories.txt | 2.7 KB | Raw category list |

Architecture

  • SE-ResNet (Squeeze-and-Excitation + ResNet blocks)
  • 3 stages: 64 → 128 → 256 filters
  • Input: 28×28 grayscale images
  • Output: 345-class softmax
  • ~3M parameters
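The squeeze-and-excitation mechanism at the heart of these blocks can be sketched framework-agnostically in NumPy. This is an illustration of the general SE technique only; the reduction ratio, weight shapes, and variable names below are assumptions, not values read from the model:

```python
import numpy as np

def se_block(feature_map, w_squeeze, w_excite):
    """Squeeze-and-Excitation: recalibrate the channels of an (H, W, C) map.

    w_squeeze: (C, C // r) bottleneck weights, w_excite: (C // r, C).
    """
    # Squeeze: global average pool over spatial dims -> per-channel vector (C,)
    z = feature_map.mean(axis=(0, 1))
    # Excitation: bottleneck MLP with ReLU, then a sigmoid gate in (0, 1)
    hidden = np.maximum(z @ w_squeeze, 0.0)
    gate = 1.0 / (1.0 + np.exp(-(hidden @ w_excite)))
    # Scale: reweight each channel of the original feature map
    return feature_map * gate

rng = np.random.default_rng(0)
x = rng.standard_normal((28, 28, 64))
out = se_block(x, rng.standard_normal((64, 4)), rng.standard_normal((4, 64)))
print(out.shape)  # (28, 28, 64) -- shape is preserved, channels are rescaled
```

Because the gate is a sigmoid, the block can only attenuate channels (never amplify them), which is what lets the network learn channel-wise attention.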

Training

  • Dataset: Google Quick Draw numpy bitmaps (GCS), 8,000 samples/class × 345 classes = 2.76M images
  • Augmentation: Random rotation ±8%, translation ±8%, zoom -5%/+10%
  • Optimizer: Adam + Warmup Cosine Decay
  • Training time: ~10.9 hours on Kaggle GPU P100
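The warmup cosine decay schedule can be sketched in plain Python. The peak learning rate, warmup length, and step counts below are illustrative assumptions, not the values used for this model:

```python
import math

def warmup_cosine_lr(step, total_steps, warmup_steps, peak_lr):
    """Linear warmup to peak_lr, then cosine decay down to zero."""
    if step < warmup_steps:
        # Linear ramp-up from 0 to peak_lr over the warmup phase
        return peak_lr * step / warmup_steps
    # Cosine decay from peak_lr at the end of warmup to 0 at total_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

# Example schedule: 1000 steps, 100-step warmup, assumed peak LR of 1e-3
schedule = [warmup_cosine_lr(s, 1000, 100, 1e-3) for s in range(1000)]
```

The schedule rises linearly to the peak at the end of warmup, then decays smoothly, which tends to stabilize Adam early in training.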

Flutter Integration

pubspec.yaml

```yaml
dependencies:
  tflite_flutter: ^0.10.4

flutter:
  assets:
    - assets/quickdraw_model.tflite
    - assets/labels.txt
```

Dart Usage

```dart
import 'dart:typed_data';

import 'package:flutter/services.dart' show rootBundle;
import 'package:tflite_flutter/tflite_flutter.dart';

class QuickDrawClassifier {
  late Interpreter _interpreter;
  late List<String> _labels;

  Future<void> load() async {
    _interpreter = await Interpreter.fromAsset('assets/quickdraw_model.tflite');
    final labelsData = await rootBundle.loadString('assets/labels.txt');
    _labels = labelsData.trim().split('\n');
  }

  /// [pixels] must be a 28x28 Float32List with values in [0.0, 1.0],
  /// where 0.0 = black stroke and 1.0 = white background.
  List<MapEntry<String, double>> predict(Float32List pixels, {int topK = 5}) {
    // Reshape to the model's input shape [1, 28, 28, 1].
    final input = pixels.reshape([1, 28, 28, 1]);
    final output = List.filled(1 * 345, 0.0).reshape([1, 345]);

    _interpreter.run(input, output);

    // Sort class indices by descending probability and keep the top K.
    final probs = List<double>.from(output[0]);
    final indexed = probs.asMap().entries.toList()
      ..sort((a, b) => b.value.compareTo(a.value));

    return indexed
        .take(topK)
        .map((e) => MapEntry(_labels[e.key], e.value))
        .toList();
  }
}
```

Preprocessing a drawing canvas

```dart
import 'dart:typed_data';
import 'dart:ui' as ui;

import 'package:flutter/painting.dart';

/// Convert your drawing canvas to a 28x28 normalized Float32List.
Future<Float32List> canvasToInput(ui.Image image) async {
  // Resize to 28x28 by redrawing the source image into a small canvas.
  final recorder = ui.PictureRecorder();
  final canvas = Canvas(recorder);
  canvas.drawImageRect(
    image,
    Rect.fromLTWH(0, 0, image.width.toDouble(), image.height.toDouble()),
    Rect.fromLTWH(0, 0, 28, 28),
    Paint(),
  );
  final resized = await recorder.endRecording().toImage(28, 28);
  final bytes =
      (await resized.toByteData(format: ui.ImageByteFormat.rawRgba))!;

  // Convert RGBA to grayscale float32, normalized to [0, 1]:
  // white background = 1.0, black strokes = 0.0.
  final pixels = Float32List(28 * 28);
  for (int i = 0; i < 28 * 28; i++) {
    final r = bytes.getUint8(i * 4);
    final g = bytes.getUint8(i * 4 + 1);
    final b = bytes.getUint8(i * 4 + 2);
    // ITU-R BT.601 luma weights.
    pixels[i] = (0.299 * r + 0.587 * g + 0.114 * b) / 255.0;
  }
  return pixels;
}
```

Input/Output Spec

| Property | Value |
|---|---|
| Input shape | [1, 28, 28, 1] |
| Input dtype | float32 |
| Input range | [0.0, 1.0] |
| Background | 1.0 (white) |
| Stroke | 0.0 (black) |
| Output shape | [1, 345] |
| Output dtype | float32 |
| Output | Softmax probabilities |
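The post-processing side of this contract (a 345-way softmax vector reduced to top-k labels, mirroring the Dart `predict` above) can be exercised outside Flutter. A minimal NumPy sketch, with a random vector standing in for real model output and placeholder label names:

```python
import numpy as np

def top_k(probs, labels, k=5):
    """Return the k (label, probability) pairs with the highest probability."""
    idx = np.argsort(probs)[::-1][:k]  # indices sorted by descending prob
    return [(labels[i], float(probs[i])) for i in idx]

# Stand-in for the model's [1, 345] float32 softmax output
rng = np.random.default_rng(0)
logits = rng.standard_normal(345)
probs = np.exp(logits) / np.exp(logits).sum()  # softmax: non-negative, sums to 1

# Placeholder labels; in practice these come from labels.txt
labels = [f"class_{i}" for i in range(345)]

best = top_k(probs, labels, k=5)
```

Because the model already emits softmax probabilities, no extra normalization is needed before ranking.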

License

Model weights: Apache 2.0
Dataset: Creative Commons Attribution 4.0 (Google Quick Draw)
