initial deployment
- .gitignore +12 -0
- Makefile +27 -0
- README.md +126 -10
- app.py +284 -0
- config.py +60 -0
- models/.gitkeep +0 -0
- notebooks/colab_train.ipynb +564 -0
- notebooks/exploration.ipynb +43 -0
- requirements.txt +15 -0
- scripts/__init__.py +1 -0
- scripts/build_features.py +145 -0
- scripts/make_dataset.py +65 -0
- scripts/model.py +616 -0
- setup.py +56 -0
.gitignore
ADDED
@@ -0,0 +1,12 @@
__pycache__/
*.py[cod]
.venv/
venv/
.env
data/raw/*.npy
data/processed/
data/outputs/
models/*.pkl
notebooks/.ipynb_checkpoints/
.DS_Store
*.egg-info/
Makefile
ADDED
@@ -0,0 +1,27 @@
.PHONY: setup download features train app clean install

# Run full pipeline (download → features → train)
setup:
	python setup.py

# Individual steps
download:
	python scripts/make_dataset.py

features:
	python scripts/build_features.py

train:
	python scripts/model.py

# Run the app locally
app:
	python app.py

# Remove all generated data and model files
clean:
	rm -rf data/raw/*.npy data/processed/*.npy data/outputs/* models/*.pkl models/*.pth

# Install dependencies
install:
	pip install -r requirements.txt
README.md
CHANGED
@@ -1,13 +1,129 @@
# ScribblBot

Real-time sketch recognition powered by a lightweight CNN trained on Google's Quick Draw dataset. Draw anything in the browser and ScribblBot identifies it instantly.

**Live app:** [your HuggingFace Spaces URL here]

---

## Results

| Model | Architecture | Test Accuracy |
|---|---|---|
| Majority Class Baseline | Always predicts most frequent class | 6.67% |
| Random Forest | HOG features, 200 trees | 85.10% |
| ScribblNet | 3-layer CNN | **94.42%** |

ScribblNet was trained on 30,000 samples across 15 classes (2,000 per class) for 15 epochs using Adam with cosine annealing. Training took under 5 minutes on Apple M-series hardware with MPS acceleration. The Random Forest operates on 1,296-dimensional HOG feature vectors extracted from 28×28 grayscale bitmaps.
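
The 1,296 figure can be checked against the HOG settings in `config.py` (9 orientations, 4×4-pixel cells, 2×2-cell blocks). The sketch below assumes blocks slide one cell at a time, which is how scikit-image's `hog` lays them out; the helper name `hog_feature_length` is hypothetical and not part of the repository.

```python
def hog_feature_length(img_size, orientations, pixels_per_cell, cells_per_block):
    """Length of a flattened HOG descriptor for a square image.

    Assumes dense, overlapping blocks that advance by one cell per step.
    """
    cells_y = img_size // pixels_per_cell[0]
    cells_x = img_size // pixels_per_cell[1]
    # Number of block positions along each axis.
    blocks_y = cells_y - cells_per_block[0] + 1
    blocks_x = cells_x - cells_per_block[1] + 1
    # Each block contributes one histogram per cell it covers.
    values_per_block = cells_per_block[0] * cells_per_block[1] * orientations
    return blocks_y * blocks_x * values_per_block


# Values taken from config.py: IMG_SIZE, HOG_ORIENTATIONS,
# HOG_PIXELS_PER_CELL, HOG_CELLS_PER_BLOCK.
print(hog_feature_length(28, 9, (4, 4), (2, 2)))
```

With these settings there are 7×7 cells and 6×6 block positions, each holding 36 values, which gives 1,296.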

**Classes:** cat · dog · pizza · bicycle · house · sun · tree · car · fish · butterfly · guitar · hamburger · airplane · banana · star

### Per-class performance (ScribblNet)

| Class | Precision | Recall | F1 |
|---|---|---|---|
| cat | 0.88 | 0.83 | 0.85 |
| dog | 0.81 | 0.82 | 0.82 |
| pizza | 0.94 | 0.97 | 0.96 |
| bicycle | 0.96 | 0.98 | 0.97 |
| house | 0.99 | 0.98 | 0.98 |
| sun | 0.95 | 0.96 | 0.96 |
| tree | 0.96 | 0.97 | 0.97 |
| car | 0.97 | 0.96 | 0.96 |
| fish | 0.97 | 0.95 | 0.96 |
| butterfly | 0.97 | 0.97 | 0.97 |
| guitar | 0.95 | 0.98 | 0.96 |
| hamburger | 0.99 | 0.97 | 0.98 |
| airplane | 0.91 | 0.89 | 0.90 |
| banana | 0.97 | 0.98 | 0.97 |
| star | 0.94 | 0.94 | 0.94 |

Cat and dog are the hardest classes, which is expected given their visual similarity in quick sketches. Airplane also underperforms, likely due to style variation in how people draw wings and fuselage.
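
For readers who want to relate the three columns, F1 is the harmonic mean of precision and recall. The snippet below is only an illustration using the rounded `cat` row; `f1_score` is a hypothetical helper, not the function the repository uses to build the table.

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall (0.0 when both are zero)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Rounded precision and recall for "cat" from the table above.
print(round(f1_score(0.88, 0.83), 2))
```

Because the table reports values rounded to two decimals, recomputing F1 from the rounded precision and recall can differ from the published F1 in the last digit for some rows.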

---

## Experiment: Training Size Sensitivity

Both ScribblNet and Random Forest were trained at 10%, 25%, 50%, 75%, and 100% of the available training data. Sweep runs use the shorter `EXPERIMENT_EPOCHS = 10` setting from `config.py`, so the 100% row is not expected to match the Results table exactly.

| Fraction | Samples | ScribblNet | Random Forest |
|---|---|---|---|
| 10% | 3,000 | 86.12% | 77.92% |
| 25% | 7,500 | 90.37% | 80.35% |
| 50% | 15,000 | 92.70% | 81.57% |
| 75% | 22,500 | 93.00% | 82.88% |
| 100% | 30,000 | 94.02% | 83.03% |

The CNN scales more steeply with data volume than the Random Forest. At 10% of the training data the gap is about 8 points; at 100% it is about 11 points. The Random Forest plateaus around 83% while ScribblNet continues improving, suggesting the CNN would benefit further from additional data.
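
The sample counts and accuracy gaps quoted above can be recomputed directly from the table. This is a small arithmetic sketch, assuming the full training set is 15 classes × 2,000 samples; the names `SWEEP` and `sweep_summary` are illustrative only.

```python
# (fraction, ScribblNet accuracy %, Random Forest accuracy %) from the table.
SWEEP = [
    (0.10, 86.12, 77.92),
    (0.25, 90.37, 80.35),
    (0.50, 92.70, 81.57),
    (0.75, 93.00, 82.88),
    (1.00, 94.02, 83.03),
]
TOTAL_TRAIN_SAMPLES = 15 * 2000  # NUM_CLASSES * TRAIN_SAMPLES_PER_CLASS


def sweep_summary(rows, total=TOTAL_TRAIN_SAMPLES):
    """Return (samples, accuracy gap in points) for each sweep row."""
    return [(round(frac * total), round(cnn - rf, 2)) for frac, cnn, rf in rows]


for samples, gap in sweep_summary(SWEEP):
    print(f"{samples:>6} samples: CNN leads by {gap:.2f} points")
```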

---

## Dataset

[Quick Draw](https://quickdraw.withgoogle.com/data) by Google: 50 million drawings across 345 categories, collected from players of the Quick Draw game. Each drawing is a 28×28 grayscale bitmap stored as a flat 784-element uint8 vector. The dataset is publicly available via Google Cloud Storage.
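
To make the storage layout concrete, here is a standard-library sketch of building a per-class download URL from the `QUICKDRAW_URL` template in `config.py` and of viewing one flat 784-element vector as 28 rows of 28 pixels. The helpers `class_url` and `to_rows` are hypothetical; the repository itself works with NumPy arrays.

```python
from urllib.parse import quote

# URL template as defined in config.py.
QUICKDRAW_URL = (
    "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/{cls}.npy"
)


def class_url(cls):
    """Fill the template, percent-encoding spaces in multi-word class names."""
    return QUICKDRAW_URL.format(cls=quote(cls))


def to_rows(flat, size=28):
    """Split one flat drawing (size*size values) into a list of pixel rows."""
    if len(flat) != size * size:
        raise ValueError(f"expected {size * size} values, got {len(flat)}")
    return [flat[r * size:(r + 1) * size] for r in range(size)]


print(class_url("cat"))
rows = to_rows(list(range(784)))  # stand-in data, not a real drawing
print(len(rows), len(rows[0]))
```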

---

## Setup

```bash
pip install -r requirements.txt
python setup.py
python app.py
```

`setup.py` runs the full pipeline: downloads the raw `.npy` files, extracts HOG features, trains all three models, and runs the experiment.

Individual steps:

```bash
python scripts/make_dataset.py
python scripts/build_features.py
python scripts/model.py
```

---

## Repository Structure

```
scribblbot/
├── README.md
├── requirements.txt
├── Makefile
├── setup.py
├── app.py
├── config.py
├── scripts/
│   ├── make_dataset.py
│   ├── build_features.py
│   └── model.py
├── models/
├── data/
│   ├── raw/
│   ├── processed/
│   └── outputs/
└── notebooks/
```

| Component | Location |
|---|---|
| Naive baseline | `scripts/model.py`: `MajorityClassifier`, saved to `models/naive_model.pkl` |
| Random Forest | `scripts/model.py`: `train_classical()`, saved to `models/classical_model.pkl` |
| ScribblNet CNN | `scripts/model.py`: `ScribblNet`, `train_deep()`, saved to `models/deep_model.pth` |
| Inference app | `app.py` |
| Config | `config.py` |

---

## Deployment

1. `python setup.py` to train and generate `models/deep_model.pth`
2. Create a new Space on HuggingFace (SDK: Gradio)
3. Push the full repo including `models/deep_model.pth`

---

## Git Workflow

Working branches: `develop` for integration, `feature/*` for individual changes. All work merges into `develop` via pull requests. `develop` merges into `main` for releases. No direct commits to `main`.
app.py
ADDED
@@ -0,0 +1,284 @@
"""
app.py - ScribblBot inference application.
"""

import sys
from pathlib import Path
from typing import Optional

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

sys.path.insert(0, str(Path(__file__).parent))

from config import CLASSES, CLASS_EMOJIS, MODELS_DIR, NUM_CLASSES
from scripts.model import ScribblNet


def _load_model() -> tuple[ScribblNet, torch.device]:
    """Load trained ScribblNet weights from disk."""
    device = torch.device("cpu")
    model_path = MODELS_DIR / "deep_model.pth"
    if not model_path.exists():
        raise FileNotFoundError(f"Weights not found at {model_path}. Run python setup.py first.")
    model = ScribblNet(num_classes=NUM_CLASSES)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model, device


MODEL, DEVICE = _load_model()


def predict(sketch: Optional[dict], _counter: int) -> tuple[str, int]:
    """Run inference on an ImageEditor drawing.

    Args:
        sketch: Dict from gr.ImageEditor with 'composite' key.
        _counter: Click counter used to bust Gradio output caching.

    Returns:
        Tuple of (HTML results string, incremented counter).
    """
    _counter += 1
    if sketch is None:
        return _empty_state_html(), _counter

    img_array = sketch.get("composite") if isinstance(sketch, dict) else sketch
    if img_array is None:
        return _empty_state_html(), _counter

    try:
        img_pil = Image.fromarray(img_array.astype(np.uint8))
        if img_pil.mode == "RGBA":
            white = Image.new("RGBA", img_pil.size, (248, 247, 242, 255))
            img_pil = Image.alpha_composite(white, img_pil).convert("L")
        else:
            img_pil = img_pil.convert("L")
        img_pil = img_pil.resize((28, 28), Image.LANCZOS)
        arr = np.array(img_pil, dtype=np.float32)
        arr = (255.0 - arr) / 255.0
        tensor = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0).to(DEVICE)
    except Exception as exc:
        return _error_html(str(exc)), _counter

    with torch.no_grad():
        probs = F.softmax(MODEL(tensor), dim=1)[0].cpu().numpy()

    top = [(CLASSES[i], float(probs[i])) for i in np.argsort(probs)[::-1][:5]]
    return _results_html(top), _counter


def _results_html(top: list[tuple[str, float]]) -> str:
    best_cls, best_prob = top[0]
    conf_pct = best_prob * 100
    label = "CONFIDENT" if best_prob > 0.7 else "LIKELY" if best_prob > 0.4 else "UNSURE"
    bars = ""
    for i, (cls, prob) in enumerate(top):
        pct = prob * 100
        bars += f"""
        <div class="bar-row" style="animation-delay:{i*0.08}s">
          <span class="bar-emoji">{CLASS_EMOJIS.get(cls,'')}</span>
          <span class="bar-label">{cls.upper()}</span>
          <div class="bar-track"><div class="bar-fill" style="width:{pct:.1f}%;animation-delay:{i*0.08+0.1}s"></div></div>
          <span class="bar-pct">{pct:.1f}%</span>
        </div>"""
    return f"""
    <div class="results-panel fade-in">
      <div class="result-tag">[ PREDICTION ]</div>
      <div class="top-result">
        <span class="top-emoji">{CLASS_EMOJIS.get(best_cls,'?')}</span>
        <div class="top-text">
          <div class="top-label">{best_cls.upper()}</div>
          <div class="top-conf">{conf_pct:.1f}% · {label}</div>
        </div>
      </div>
      <div class="divider"></div>
      <div class="section-label">TOP 5 PROBABILITIES</div>
      {bars}
    </div>"""


def _empty_state_html() -> str:
    return """
    <div class="results-panel empty-state">
      <div class="empty-icon">✏️</div>
      <div class="empty-title">DRAW SOMETHING</div>
      <div class="empty-sub">then hit ANALYZE</div>
      <div class="class-pills">
        <span class="pill">🐱 cat</span><span class="pill">🐶 dog</span>
        <span class="pill">🍕 pizza</span><span class="pill">🚲 bicycle</span>
        <span class="pill">🏠 house</span><span class="pill">☀️ sun</span>
        <span class="pill">🌳 tree</span><span class="pill">🚗 car</span>
        <span class="pill">🐟 fish</span><span class="pill">🦋 butterfly</span>
        <span class="pill">🎸 guitar</span><span class="pill">🍔 hamburger</span>
        <span class="pill">✈️ airplane</span><span class="pill">🍌 banana</span>
        <span class="pill">⭐ star</span>
      </div>
    </div>"""


def _error_html(msg: str) -> str:
    return f'<div class="results-panel error-state"><p class="err-msg">⚠ {msg}</p></div>'


CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=VT323&family=IBM+Plex+Mono:wght@400;500&display=swap');

:root {
  --bg: #080808;
  --surface: #111111;
  --surface2: #1a1a1a;
  --border: #2a2a2a;
  --accent: #b8ff57;
  --text: #e8e8e0;
  --text-muted: #888880;
  --red: #ff5f57;
  --mono: 'IBM Plex Mono', monospace;
  --display: 'VT323', monospace;
}
body, .gradio-container, #root {
  background: var(--bg) !important;
  font-family: var(--mono) !important;
  color: var(--text) !important;
}
.gradio-container { max-width: 1100px !important; margin: 0 auto !important; }
footer { display: none !important; }
.block, .gr-box { background: transparent !important; border: none !important; box-shadow: none !important; }

.app-header { text-align: center; padding: 36px 20px 20px; border-bottom: 1px solid var(--border); margin-bottom: 28px; }
.app-title { font-family: var(--display); font-size: 72px; line-height: 1; color: var(--accent); letter-spacing: 6px; text-shadow: 0 0 30px rgba(184,255,87,0.3); margin: 0; }
.app-subtitle { font-size: 12px; color: var(--text-muted); letter-spacing: 4px; margin-top: 6px; }

/* ImageEditor styling */
/* Override Gradio's orange accent with our green */
.sketch-col { --color-accent: #b8ff57 !important; --color-accent-soft: rgba(184,255,87,0.15) !important; }
.sketch-col .image-editor { border: 1px solid var(--border) !important; border-radius: 4px !important; background: var(--surface) !important; }
/* Hide color picker and swatch - we only need pen and eraser */
.sketch-col [aria-label="Color"],
.sketch-col [title="Color"],
.sketch-col .image-editor .toolbar > button:nth-child(3),
.sketch-col .image-editor .toolbar > button:nth-child(4) { display: none !important; }
/* Toolbar background */
.sketch-col .image-editor > div { background: var(--surface2) !important; }
/* All buttons */
.sketch-col .image-editor button {
  background: var(--surface2) !important;
  border: 1px solid var(--border) !important;
  border-radius: 3px !important;
  margin: 2px !important;
  color: var(--text) !important;
}
.sketch-col .image-editor button:hover {
  background: var(--accent) !important;
  border-color: var(--accent) !important;
  color: #000 !important;
}
/* Active tool */
.sketch-col .image-editor button[aria-pressed="true"] {
  border: 2px solid var(--accent) !important;
  background: rgba(184,255,87,0.15) !important;
  color: var(--accent) !important;
}
/* Force all SVG icons to white/text color */
.sketch-col .image-editor svg * { color: inherit !important; stroke: currentColor !important; }
.sketch-col [data-testid="layer-wrap"] { display: none !important; }
.sketch-col .layers-panel { display: none !important; }
/* White canvas */
.sketch-col .konvajs-content,
.sketch-col .konvajs-content canvas,
.sketch-col canvas { background: #f8f7f2 !important; background-color: #f8f7f2 !important; }
.sketch-col canvas { cursor: crosshair !important; }
.sketch-col * { cursor: auto; }
.sketch-col canvas { cursor: crosshair !important; }

.results-panel { background: var(--surface); border: 1px solid var(--border); border-radius: 4px; padding: 20px; min-height: 420px; font-family: var(--mono); }
.result-tag { font-size: 11px; color: var(--accent); letter-spacing: 3px; margin-bottom: 16px; }
.top-result { display: flex; align-items: center; gap: 18px; margin-bottom: 18px; }
.top-emoji { font-size: 56px; line-height: 1; }
.top-label { font-family: var(--display); font-size: 52px; color: var(--text); line-height: 1; letter-spacing: 3px; }
.top-conf { font-size: 13px; color: var(--accent); margin-top: 4px; }
.divider { height: 1px; background: var(--border); margin: 16px 0; }
.section-label { font-size: 10px; color: var(--text-muted); letter-spacing: 3px; margin-bottom: 12px; }
.bar-row { display: grid; grid-template-columns: 28px 90px 1fr 50px; align-items: center; gap: 8px; margin-bottom: 10px; opacity: 0; animation: slideIn 0.3s ease forwards; }
.bar-emoji { font-size: 16px; text-align: center; }
.bar-label { font-size: 11px; color: var(--text-muted); letter-spacing: 1px; }
.bar-track { height: 6px; background: var(--surface2); border-radius: 3px; overflow: hidden; }
.bar-fill { height: 100%; background: var(--accent); border-radius: 3px; width: 0; animation: barGrow 0.4s ease forwards; }
.bar-pct { font-size: 11px; color: var(--text); text-align: right; }

.empty-state { display: flex; flex-direction: column; align-items: center; justify-content: center; min-height: 360px; }
.empty-icon { font-size: 48px; margin-bottom: 12px; }
.empty-title { font-family: var(--display); font-size: 36px; color: var(--accent); letter-spacing: 3px; }
.empty-sub { font-size: 12px; color: var(--text-muted); margin: 4px 0 24px; letter-spacing: 2px; }
.class-pills { display: flex; flex-wrap: wrap; gap: 6px; justify-content: center; max-width: 340px; }
.pill { background: var(--surface2); border: 1px solid var(--border); padding: 3px 10px; border-radius: 20px; font-size: 11px; color: var(--text-muted); }
.error-state { display: flex; align-items: center; justify-content: center; min-height: 200px; }
.err-msg { font-size: 13px; color: var(--red); }

.analyze-row { padding: 12px 0 0 !important; }
.analyze-row button { width: 100% !important; background: rgba(184,255,87,0.06) !important; border: 2px solid var(--accent) !important; color: var(--accent) !important; font-family: var(--mono) !important; font-size: 15px !important; letter-spacing: 4px !important; padding: 14px !important; border-radius: 2px !important; cursor: pointer !important; transition: background 0.15s !important; }
.analyze-row button:hover { background: var(--accent) !important; color: #000 !important; }

.app-footer { text-align: center; padding: 18px; font-size: 11px; color: var(--text-muted); letter-spacing: 1px; border-top: 1px solid var(--border); margin-top: 16px; }

@keyframes slideIn { from { opacity: 0; transform: translateX(-8px); } to { opacity: 1; transform: translateX(0); } }
@keyframes barGrow { from { width: 0; } }
.fade-in { animation: fadeIn 0.25s ease; }
@keyframes fadeIn { from { opacity: 0; } to { opacity: 1; } }
"""


def build_app() -> gr.Blocks:
    """Construct and return the Gradio Blocks application."""
    with gr.Blocks(css=CUSTOM_CSS, title="ScribblBot") as app:

        gr.HTML("""
        <div class="app-header">
          <h1 class="app-title">SCRIBBLBOT</h1>
          <p class="app-subtitle">NEURAL SKETCH CLASSIFIER · 15 CATEGORIES · QUICK DRAW DATASET</p>
        </div>
        """)

        click_counter = gr.State(0)

        with gr.Row():
            with gr.Column(elem_classes=["sketch-col"]):
                sketch_input = gr.ImageEditor(
                    type="numpy",
                    image_mode="RGBA",
                    canvas_size=(480, 480),
                    layers=False,
                    sources=[],
                    brush=gr.Brush(
                        colors=["#111111"],
                        default_size=14,
                        color_mode="fixed",
                    ),
                    eraser=gr.Eraser(default_size=20),
                    show_label=False,
                )

            with gr.Column():
                result_html = gr.HTML(_empty_state_html())

        with gr.Row(elem_classes=["analyze-row"]):
            analyze_btn = gr.Button("ANALYZE")

        gr.HTML('<div class="app-footer">ScribblBot · built with Quick Draw · PyTorch · Gradio</div>')

        analyze_btn.click(
            fn=predict,
            inputs=[sketch_input, click_counter],
            outputs=[result_html, click_counter],
        )

    return app


if __name__ == "__main__":
    demo = build_app()
    demo.launch()
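
The post-processing in `predict` and `_results_html` (softmax over the logits, top-5 selection, and the CONFIDENT / LIKELY / UNSURE thresholds at 0.7 and 0.4) can be tried without loading the model. The following is a dependency-free sketch of that logic; the helper names and the example logits are made up for illustration, and the app itself uses `torch.nn.functional.softmax` and `np.argsort`.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(classes, probs, k=5):
    """Return the k (class, probability) pairs with the highest probability."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [(classes[i], probs[i]) for i in order]


def confidence_label(prob):
    """Same thresholds as the label expression in _results_html."""
    return "CONFIDENT" if prob > 0.7 else "LIKELY" if prob > 0.4 else "UNSURE"


# Hypothetical logits for three classes, for illustration only.
classes = ["cat", "dog", "sun"]
probs = softmax([2.0, 1.0, 0.1])
best_cls, best_prob = top_k(classes, probs, k=2)[0]
print(best_cls, f"{best_prob * 100:.1f}%", confidence_label(best_prob))
```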
config.py
ADDED
@@ -0,0 +1,60 @@
"""
Central configuration for ScribblBot.
All hyperparameters, paths, and constants live here.
"""

from pathlib import Path

# Paths
PROJECT_ROOT = Path(__file__).parent
DATA_DIR = PROJECT_ROOT / "data"
RAW_DIR = DATA_DIR / "raw"
PROCESSED_DIR = DATA_DIR / "processed"
OUTPUTS_DIR = DATA_DIR / "outputs"
MODELS_DIR = PROJECT_ROOT / "models"

for _d in [RAW_DIR, PROCESSED_DIR, OUTPUTS_DIR, MODELS_DIR]:
    _d.mkdir(parents=True, exist_ok=True)

# Classes
# 15 visually distinct Quick Draw categories
CLASSES = [
    "cat", "dog", "pizza", "bicycle", "house",
    "sun", "tree", "car", "fish", "butterfly",
    "guitar", "hamburger", "airplane", "banana", "star",
]
NUM_CLASSES = len(CLASSES)

CLASS_EMOJIS = {
    "cat": "🐱", "dog": "🐶", "pizza": "🍕", "bicycle": "🚲",
    "house": "🏠", "sun": "☀️", "tree": "🌳", "car": "🚗",
    "fish": "🐟", "butterfly": "🦋", "guitar": "🎸", "hamburger": "🍔",
    "airplane": "✈️", "banana": "🍌", "star": "⭐",
}

# Dataset
TRAIN_SAMPLES_PER_CLASS = 2000  # keeps training fast (~30k total)
TEST_SAMPLES_PER_CLASS = 400  # solid eval set (~6k total)
IMG_SIZE = 28  # Quick Draw native resolution

# Quick Draw public GCS bucket
QUICKDRAW_URL = (
    "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/{cls}.npy"
)

# Deep Model
DEEP_BATCH_SIZE = 128
DEEP_EPOCHS = 15
DEEP_LR = 1e-3
DEEP_WEIGHT_DECAY = 1e-4

# Classical Model
RF_N_ESTIMATORS = 200
RF_MAX_DEPTH = None
HOG_ORIENTATIONS = 9
HOG_PIXELS_PER_CELL = (4, 4)
HOG_CELLS_PER_BLOCK = (2, 2)

# Experiment: training set size sensitivity
EXPERIMENT_FRACTIONS = [0.1, 0.25, 0.5, 0.75, 1.0]
EXPERIMENT_EPOCHS = 10  # shorter runs for the sweep
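
The README states that ScribblNet is trained with Adam and cosine annealing, and this file supplies `DEEP_LR` and `DEEP_EPOCHS`. As a rough illustration of what such a schedule looks like with these values, the sketch below evaluates the standard cosine annealing formula once per epoch. It assumes a minimum learning rate of 0 and a period equal to `DEEP_EPOCHS`; neither is confirmed by the files shown here, and `cosine_lr` is a hypothetical helper rather than the code in `scripts/model.py`.

```python
import math

DEEP_LR = 1e-3    # from config.py
DEEP_EPOCHS = 15  # from config.py


def cosine_lr(epoch, base_lr=DEEP_LR, total_epochs=DEEP_EPOCHS, min_lr=0.0):
    """Cosine-annealed learning rate at a given epoch (min_lr=0.0 is an assumption)."""
    cosine = math.cos(math.pi * epoch / total_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + cosine)


for epoch in range(0, DEEP_EPOCHS + 1, 5):
    print(epoch, f"{cosine_lr(epoch):.6f}")
```

The rate starts at `DEEP_LR`, passes through half of it at the midpoint, and decays smoothly toward the assumed minimum by the final epoch.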
models/.gitkeep
ADDED
File without changes
notebooks/colab_train.ipynb
ADDED
@@ -0,0 +1,564 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# ScribblBot — Colab Training Notebook\n",
+    "\n",
+    "Run every cell top to bottom. At the end your trained model files will be saved to Google Drive.\n",
+    "\n",
+    "**Before running:** go to Runtime > Change runtime type > T4 GPU"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 1. Install dependencies\n",
+    "!pip install scikit-image seaborn joblib --quiet"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 2. Mount Google Drive so models persist after the session ends\n",
+    "from google.colab import drive\n",
+    "drive.mount('/content/drive')\n",
+    "\n",
+    "import os\n",
+    "SAVE_DIR = '/content/drive/MyDrive/scribblbot_models'\n",
+    "os.makedirs(SAVE_DIR, exist_ok=True)\n",
+    "print(f'Models will be saved to: {SAVE_DIR}')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 3. Config\n",
+    "from pathlib import Path\n",
+    "\n",
+    "CLASSES = [\n",
+    "    'cat', 'dog', 'pizza', 'bicycle', 'house',\n",
+    "    'sun', 'tree', 'car', 'fish', 'butterfly',\n",
+    "    'guitar', 'hamburger', 'airplane', 'banana', 'star',\n",
+    "]\n",
+    "NUM_CLASSES = len(CLASSES)\n",
+    "\n",
+    "TRAIN_SAMPLES_PER_CLASS = 2000\n",
+    "TEST_SAMPLES_PER_CLASS = 400\n",
+    "IMG_SIZE = 28\n",
+    "\n",
+    "DEEP_BATCH_SIZE = 256  # bigger batch since we have GPU\n",
+    "DEEP_EPOCHS = 15\n",
+    "DEEP_LR = 1e-3\n",
+    "DEEP_WEIGHT_DECAY = 1e-4\n",
+    "\n",
+    "RF_N_ESTIMATORS = 200\n",
+    "HOG_ORIENTATIONS = 9\n",
+    "HOG_PIXELS_PER_CELL = (4, 4)\n",
+    "HOG_CELLS_PER_BLOCK = (2, 2)\n",
+    "\n",
+    "EXPERIMENT_FRACTIONS = [0.1, 0.25, 0.5, 0.75, 1.0]\n",
+    "EXPERIMENT_EPOCHS = 10\n",
+    "\n",
+    "RAW_DIR = Path('/content/data/raw')\n",
+    "PROCESSED_DIR = Path('/content/data/processed')\n",
+    "OUTPUTS_DIR = Path('/content/data/outputs')\n",
+    "MODELS_DIR = Path('/content/models')\n",
+    "\n",
+    "for d in [RAW_DIR, PROCESSED_DIR, OUTPUTS_DIR, MODELS_DIR]:\n",
+    "    d.mkdir(parents=True, exist_ok=True)\n",
+    "\n",
+    "QUICKDRAW_URL = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/{cls}.npy'\n",
+    "\n",
+    "print(f'Config ready. {NUM_CLASSES} classes, {TRAIN_SAMPLES_PER_CLASS} train samples each.')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 4. Download Quick Draw data\n",
+    "import urllib.request\n",
+    "\n",
+    "def download_class(cls):\n",
+    "    url = QUICKDRAW_URL.format(cls=cls.replace(' ', '%20'))\n",
+    "    dest = RAW_DIR / f'{cls}.npy'\n",
+    "    if dest.exists():\n",
+    "        print(f' already have {cls}.npy')\n",
+    "        return\n",
+    "    urllib.request.urlretrieve(url, dest)\n",
+    "    print(f' downloaded {cls}.npy')\n",
+    "\n",
+    "print('Downloading dataset...')\n",
+    "for cls in CLASSES:\n",
+    "    download_class(cls)\n",
+    "print('Done.')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 5. Build train/test splits and HOG features\n",
+    "import numpy as np\n",
+    "from skimage.feature import hog\n",
+    "\n",
+    "def load_class_data(cls, n_train, n_test):\n",
+    "    data = np.load(RAW_DIR / f'{cls}.npy', mmap_mode='r')\n",
+    "    rng = np.random.default_rng(seed=42)\n",
+    "    indices = rng.permutation(len(data))[:n_train + n_test]\n",
+    "    data = data[indices]\n",
+    "    return data[:n_train], data[n_train:n_train + n_test]\n",
+    "\n",
+    "def extract_hog(pixel_matrix):\n",
+    "    features = []\n",
+    "    for row in pixel_matrix:\n",
+    "        img = row.reshape(IMG_SIZE, IMG_SIZE)\n",
+    "        desc = hog(img, orientations=HOG_ORIENTATIONS,\n",
+    "                   pixels_per_cell=HOG_PIXELS_PER_CELL,\n",
+    "                   cells_per_block=HOG_CELLS_PER_BLOCK,\n",
+    "                   visualize=False, channel_axis=None)\n",
+    "        features.append(desc)\n",
+    "    return np.array(features, dtype=np.float32)\n",
+    "\n",
+    "print('Loading raw data...')\n",
+    "train_raws, test_raws, train_labels, test_labels = [], [], [], []\n",
+    "for idx, cls in enumerate(CLASSES):\n",
+    "    tr, te = load_class_data(cls, TRAIN_SAMPLES_PER_CLASS, TEST_SAMPLES_PER_CLASS)\n",
+    "    train_raws.append(tr)\n",
+    "    test_raws.append(te)\n",
+    "    train_labels.append(np.full(len(tr), idx, dtype=np.int64))\n",
+    "    test_labels.append(np.full(len(te), idx, dtype=np.int64))\n",
+    "    print(f' {cls}')\n",
+    "\n",
+    "X_train_raw = np.concatenate(train_raws)\n",
+    "X_test_raw = np.concatenate(test_raws)\n",
+    "y_train = np.concatenate(train_labels)\n",
+    "y_test = np.concatenate(test_labels)\n",
+    "\n",
+    "rng = np.random.default_rng(seed=0)\n",
+    "perm = rng.permutation(len(X_train_raw))\n",
+    "X_train_raw = X_train_raw[perm]\n",
+    "y_train = y_train[perm]\n",
+    "\n",
+    "print('\\nExtracting HOG features (train)...')\n",
+    "X_train_hog = extract_hog(X_train_raw)\n",
+    "print('Extracting HOG features (test)...')\n",
+    "X_test_hog = extract_hog(X_test_raw)\n",
+    "\n",
+    "np.save(PROCESSED_DIR / 'X_train_raw.npy', X_train_raw)\n",
+    "np.save(PROCESSED_DIR / 'X_test_raw.npy', X_test_raw)\n",
+    "np.save(PROCESSED_DIR / 'y_train.npy', y_train)\n",
+    "np.save(PROCESSED_DIR / 'y_test.npy', y_test)\n",
+    "np.save(PROCESSED_DIR / 'X_train_hog.npy', X_train_hog)\n",
+    "np.save(PROCESSED_DIR / 'X_test_hog.npy', X_test_hog)\n",
+    "\n",
+    "print(f'\\nTrain: {X_train_raw.shape}, Test: {X_test_raw.shape}, HOG features: {X_train_hog.shape[1]}')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 6. Naive baseline\n",
+    "from sklearn.metrics import accuracy_score\n",
+    "import joblib\n",
+    "\n",
+    "majority_class = int(np.bincount(y_train).argmax())\n",
+    "naive_preds = np.full(len(y_test), majority_class, dtype=np.int64)\n",
+    "naive_acc = accuracy_score(y_test, naive_preds)\n",
+    "\n",
+    "joblib.dump({'majority_class': majority_class, 'accuracy': naive_acc}, MODELS_DIR / 'naive_model.pkl')\n",
+    "print(f'Naive baseline accuracy: {naive_acc:.4f}')\n",
+    "print(f'Majority class: {CLASSES[majority_class]}')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 7. Classical ML: Random Forest on HOG features\n",
+    "from sklearn.ensemble import RandomForestClassifier\n",
+    "from sklearn.preprocessing import StandardScaler\n",
+    "from sklearn.metrics import classification_report, confusion_matrix\n",
+    "import time\n",
+    "\n",
+    "scaler = StandardScaler()\n",
+    "X_tr_scaled = scaler.fit_transform(X_train_hog)\n",
+    "X_te_scaled = scaler.transform(X_test_hog)\n",
+    "\n",
+    "rf = RandomForestClassifier(n_estimators=RF_N_ESTIMATORS, n_jobs=-1, random_state=42)\n",
+    "t0 = time.time()\n",
+    "rf.fit(X_tr_scaled, y_train)\n",
+    "elapsed = time.time() - t0\n",
+    "\n",
+    "rf_preds = rf.predict(X_te_scaled)\n",
+    "rf_acc = accuracy_score(y_test, rf_preds)\n",
+    "\n",
+    "joblib.dump({'clf': rf, 'scaler': scaler}, MODELS_DIR / 'classical_model.pkl')\n",
+    "\n",
+    "print(f'Random Forest trained in {elapsed:.1f}s')\n",
+    "print(f'Test accuracy: {rf_acc:.4f}')\n",
+    "print()\n",
+    "print(classification_report(y_test, rf_preds, target_names=CLASSES))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 8. Define ScribblNet\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "from torch.utils.data import DataLoader, TensorDataset\n",
+    "\n",
+    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+    "print(f'Using device: {device}')\n",
+    "\n",
+    "class ScribblNet(nn.Module):\n",
+    "    def __init__(self, num_classes=NUM_CLASSES):\n",
+    "        super().__init__()\n",
+    "        self.features = nn.Sequential(\n",
+    "            nn.Conv2d(1, 32, kernel_size=3, padding=1),\n",
+    "            nn.BatchNorm2d(32),\n",
+    "            nn.ReLU(inplace=True),\n",
+    "            nn.MaxPool2d(2),\n",
+    "\n",
+    "            nn.Conv2d(32, 64, kernel_size=3, padding=1),\n",
+    "            nn.BatchNorm2d(64),\n",
+    "            nn.ReLU(inplace=True),\n",
+    "            nn.MaxPool2d(2),\n",
+    "\n",
+    "            nn.Conv2d(64, 128, kernel_size=3, padding=1),\n",
+    "            nn.BatchNorm2d(128),\n",
+    "            nn.ReLU(inplace=True),\n",
+    "            nn.MaxPool2d(2),\n",
+    "        )\n",
+    "        self.classifier = nn.Sequential(\n",
+    "            nn.Dropout(0.5),\n",
+    "            nn.Linear(128 * 3 * 3, 256),\n",
+    "            nn.ReLU(inplace=True),\n",
+    "            nn.Dropout(0.3),\n",
+    "            nn.Linear(256, num_classes),\n",
+    "        )\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        x = self.features(x)\n",
+    "        x = x.view(x.size(0), -1)\n",
+    "        return self.classifier(x)\n",
+    "\n",
+    "print('ScribblNet defined.')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 9. Train ScribblNet\n",
+    "def make_loaders(X_raw, y, X_test, y_test, batch_size, fraction=1.0):\n",
+    "    if fraction < 1.0:\n",
+    "        n = max(1, int(len(X_raw) * fraction))\n",
+    "        idx = np.random.default_rng(seed=7).permutation(len(X_raw))[:n]\n",
+    "        X_raw = X_raw[idx]\n",
+    "        y = y[idx]\n",
+    "    def to_ds(X, labels):\n",
+    "        imgs = torch.from_numpy(X.astype(np.float32) / 255.0).view(-1, 1, IMG_SIZE, IMG_SIZE)\n",
+    "        return TensorDataset(imgs, torch.from_numpy(labels))\n",
+    "    train_loader = DataLoader(to_ds(X_raw, y), batch_size=batch_size, shuffle=True, num_workers=2)\n",
+    "    test_loader = DataLoader(to_ds(X_test, y_test), batch_size=batch_size, shuffle=False, num_workers=2)\n",
+    "    return train_loader, test_loader\n",
+    "\n",
+    "def evaluate_model(model, loader):\n",
+    "    model.eval()\n",
+    "    all_preds, all_labels = [], []\n",
+    "    with torch.no_grad():\n",
+    "        for imgs, labels in loader:\n",
+    "            preds = model(imgs.to(device)).argmax(dim=1).cpu().numpy()\n",
+    "            all_preds.append(preds)\n",
+    "            all_labels.append(labels.numpy())\n",
+    "    preds = np.concatenate(all_preds)\n",
+    "    labels = np.concatenate(all_labels)\n",
+    "    return accuracy_score(labels, preds), preds\n",
+    "\n",
+    "train_loader, test_loader = make_loaders(X_train_raw, y_train, X_test_raw, y_test, DEEP_BATCH_SIZE)\n",
+    "\n",
+    "model = ScribblNet().to(device)\n",
+    "optimizer = torch.optim.Adam(model.parameters(), lr=DEEP_LR, weight_decay=DEEP_WEIGHT_DECAY)\n",
+    "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=DEEP_EPOCHS)\n",
+    "criterion = nn.CrossEntropyLoss()\n",
+    "\n",
+    "best_acc = 0.0\n",
+    "history = {'loss': [], 'val_acc': []}\n",
+    "\n",
+    "print(f'Training for {DEEP_EPOCHS} epochs...')\n",
+    "for epoch in range(1, DEEP_EPOCHS + 1):\n",
+    "    model.train()\n",
+    "    total_loss = 0.0\n",
+    "    for imgs, labels in train_loader:\n",
+    "        imgs, labels = imgs.to(device), labels.to(device)\n",
+    "        optimizer.zero_grad()\n",
+    "        loss = criterion(model(imgs), labels)\n",
+    "        loss.backward()\n",
+    "        optimizer.step()\n",
+    "        total_loss += loss.item()\n",
+    "    avg_loss = total_loss / len(train_loader)\n",
+    "    val_acc, _ = evaluate_model(model, test_loader)\n",
+    "    scheduler.step()\n",
+    "    history['loss'].append(avg_loss)\n",
+    "    history['val_acc'].append(val_acc)\n",
+    "    print(f' epoch {epoch:02d}/{DEEP_EPOCHS} loss={avg_loss:.4f} val_acc={val_acc:.4f}')\n",
+    "    if val_acc > best_acc:\n",
+    "        best_acc = val_acc\n",
+    "        torch.save(model.state_dict(), MODELS_DIR / 'deep_model.pth')\n",
+    "\n",
+    "print(f'\\nBest test accuracy: {best_acc:.4f}')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 10. Training curves\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4))\n",
+    "\n",
+    "ax1.plot(range(1, len(history['loss']) + 1), history['loss'],\n",
+    "         color='steelblue', marker='o', linestyle='solid', markersize=5)\n",
+    "ax1.set_xlabel('Epoch')\n",
+    "ax1.set_ylabel('Training Loss')\n",
+    "ax1.set_title('ScribblNet Training Loss')\n",
+    "ax1.grid(True, alpha=0.3)\n",
+    "\n",
+    "ax2.plot(range(1, len(history['val_acc']) + 1), history['val_acc'],\n",
+    "         color='seagreen', marker='o', linestyle='solid', markersize=5)\n",
+    "ax2.set_xlabel('Epoch')\n",
+    "ax2.set_ylabel('Validation Accuracy')\n",
+    "ax2.set_title('ScribblNet Validation Accuracy')\n",
+    "ax2.grid(True, alpha=0.3)\n",
+    "\n",
+    "plt.tight_layout()\n",
+    "plt.savefig(OUTPUTS_DIR / 'deep_training_curves.png', dpi=150)\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 11. Confusion matrices\n",
+    "import seaborn as sns\n",
+    "\n",
+    "model.load_state_dict(torch.load(MODELS_DIR / 'deep_model.pth', map_location=device))\n",
+    "_, deep_preds = evaluate_model(model, test_loader)\n",
+    "\n",
+    "def plot_cm(y_true, y_pred, title, filename):\n",
+    "    cm = confusion_matrix(y_true, y_pred, normalize='true')\n",
+    "    fig, ax = plt.subplots(figsize=(10, 8))\n",
+    "    sns.heatmap(cm, annot=True, fmt='.2f', xticklabels=CLASSES,\n",
+    "                yticklabels=CLASSES, cmap='Blues', ax=ax, linewidths=0.5)\n",
+    "    ax.set_xlabel('Predicted')\n",
+    "    ax.set_ylabel('True')\n",
+    "    ax.set_title(title)\n",
+    "    plt.xticks(rotation=45, ha='right')\n",
+    "    plt.tight_layout()\n",
+    "    plt.savefig(OUTPUTS_DIR / filename, dpi=150)\n",
+    "    plt.show()\n",
+    "\n",
+    "plot_cm(y_test, deep_preds, 'ScribblNet Confusion Matrix', 'deep_confusion_matrix.png')\n",
+    "plot_cm(y_test, rf_preds, 'Random Forest Confusion Matrix', 'classical_confusion_matrix.png')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 12. Model comparison bar chart\n",
+    "names = ['Naive Baseline', 'Random Forest', 'ScribblNet']\n",
+    "accs = [naive_acc, rf_acc, best_acc]\n",
+    "\n",
+    "fig, ax = plt.subplots(figsize=(7, 4))\n",
+    "bars = ax.bar(names, accs, color=['#94a3b8', '#60a5fa', '#34d399'], width=0.5)\n",
+    "ax.set_ylim(0, 1)\n",
+    "ax.set_ylabel('Test Accuracy')\n",
+    "ax.set_title('Model Comparison')\n",
+    "for bar, acc in zip(bars, accs):\n",
+    "    ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,\n",
+    "            f'{acc:.3f}', ha='center', fontsize=12)\n",
+    "ax.grid(True, axis='y', alpha=0.3)\n",
+    "plt.tight_layout()\n",
+    "plt.savefig(OUTPUTS_DIR / 'model_comparison.png', dpi=150)\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 13. Experiment: training set size sensitivity\n",
+    "# Sweeps over fractions of training data for both models.\n",
+    "# Shows how data volume affects each approach.\n",
+    "print('Running experiment: training size sensitivity sweep...')\n",
+    "\n",
+    "deep_accs_exp = []\n",
+    "rf_accs_exp = []\n",
+    "n_samples_list = []\n",
+    "\n",
+    "for frac in EXPERIMENT_FRACTIONS:\n",
+    "    n = int(len(X_train_raw) * frac)\n",
+    "    n_samples_list.append(n)\n",
+    "    print(f'\\nFraction {frac:.0%} (n={n})')\n",
+    "\n",
+    "    # Deep model\n",
+    "    tr_loader, te_loader = make_loaders(X_train_raw, y_train, X_test_raw, y_test,\n",
+    "                                        DEEP_BATCH_SIZE, fraction=frac)\n",
+    "    exp_model = ScribblNet().to(device)\n",
+    "    exp_opt = torch.optim.Adam(exp_model.parameters(), lr=DEEP_LR, weight_decay=DEEP_WEIGHT_DECAY)\n",
+    "    exp_sched = torch.optim.lr_scheduler.CosineAnnealingLR(exp_opt, T_max=EXPERIMENT_EPOCHS)\n",
+    "    exp_model.train()\n",
+    "    for ep in range(EXPERIMENT_EPOCHS):\n",
+    "        for imgs, labels in tr_loader:\n",
+    "            imgs, labels = imgs.to(device), labels.to(device)\n",
+    "            exp_opt.zero_grad()\n",
+    "            criterion(exp_model(imgs), labels).backward()\n",
+    "            exp_opt.step()\n",
+    "        exp_sched.step()\n",
+    "    acc_deep, _ = evaluate_model(exp_model, te_loader)\n",
+    "    deep_accs_exp.append(acc_deep)\n",
+    "    print(f' ScribblNet acc: {acc_deep:.4f}')\n",
+    "\n",
+    "    # Random Forest\n",
+    "    idx = np.random.default_rng(seed=42).permutation(len(X_train_hog))[:n]\n",
+    "    sc = StandardScaler()\n",
+    "    X_tr_exp = sc.fit_transform(X_train_hog[idx])\n",
+    "    X_te_exp = sc.transform(X_test_hog)\n",
+    "    rf_exp = RandomForestClassifier(n_estimators=100, n_jobs=-1, random_state=42)\n",
+    "    rf_exp.fit(X_tr_exp, y_train[idx])\n",
+    "    acc_rf = accuracy_score(y_test, rf_exp.predict(X_te_exp))\n",
+    "    rf_accs_exp.append(acc_rf)\n",
+    "    print(f' Random Forest acc: {acc_rf:.4f}')\n",
+    "\n",
+    "# Plot\n",
+    "fig, ax = plt.subplots(figsize=(8, 5))\n",
+    "ax.plot(n_samples_list, deep_accs_exp, marker='o', linestyle='solid',\n",
+    "        label='ScribblNet (CNN)', linewidth=2, markersize=7)\n",
+    "ax.plot(n_samples_list, rf_accs_exp, marker='s', linestyle='dashed',\n",
+    "        label='Random Forest (HOG)', linewidth=2, markersize=7)\n",
+    "ax.set_xlabel('Training samples')\n",
+    "ax.set_ylabel('Test accuracy')\n",
+    "ax.set_title('Training Set Size Sensitivity')\n",
+    "ax.legend()\n",
+    "ax.grid(True, alpha=0.3)\n",
+    "ax.set_ylim(0, 1)\n",
+    "plt.tight_layout()\n",
+    "plt.savefig(OUTPUTS_DIR / 'experiment_sensitivity.png', dpi=150)\n",
+    "plt.show()\n",
+    "\n",
+    "import json\n",
+    "with open(OUTPUTS_DIR / 'experiment_results.json', 'w') as f:\n",
+    "    json.dump({'fractions': EXPERIMENT_FRACTIONS, 'n_samples': n_samples_list,\n",
+    "               'deep_accs': deep_accs_exp, 'rf_accs': rf_accs_exp}, f, indent=2)\n",
+    "\n",
+    "print('Experiment complete.')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 14. Save everything to Google Drive\n",
+    "import shutil\n",
+    "\n",
+    "files_to_save = [\n",
+    "    (MODELS_DIR / 'deep_model.pth', 'deep_model.pth'),\n",
+    "    (MODELS_DIR / 'classical_model.pkl', 'classical_model.pkl'),\n",
+    "    (MODELS_DIR / 'naive_model.pkl', 'naive_model.pkl'),\n",
+    "    (OUTPUTS_DIR / 'deep_training_curves.png', 'deep_training_curves.png'),\n",
+    "    (OUTPUTS_DIR / 'deep_confusion_matrix.png', 'deep_confusion_matrix.png'),\n",
+    "    (OUTPUTS_DIR / 'classical_confusion_matrix.png', 'classical_confusion_matrix.png'),\n",
+    "    (OUTPUTS_DIR / 'model_comparison.png', 'model_comparison.png'),\n",
+    "    (OUTPUTS_DIR / 'experiment_sensitivity.png', 'experiment_sensitivity.png'),\n",
+    "    (OUTPUTS_DIR / 'experiment_results.json', 'experiment_results.json'),\n",
+    "]\n",
+    "\n",
+    "for src, name in files_to_save:\n",
+    "    dest = Path(SAVE_DIR) / name\n",
+    "    shutil.copy(src, dest)\n",
+    "    print(f'Saved {name}')\n",
+    "\n",
+    "print(f'\\nAll files saved to Google Drive at: {SAVE_DIR}')\n",
+    "print('Download deep_model.pth and put it in your local models/ folder.')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 15. Results summary\n",
+    "print('Results Summary')\n",
+    "print(f' Naive baseline: {naive_acc:.4f}')\n",
+    "print(f' Random Forest: {rf_acc:.4f}')\n",
+    "print(f' ScribblNet: {best_acc:.4f}')\n",
+    "\n",
+    "with open(OUTPUTS_DIR / 'results_summary.json', 'w') as f:\n",
+    "    json.dump({'naive_accuracy': naive_acc, 'classical_accuracy': rf_acc,\n",
+    "               'deep_accuracy': best_acc}, f, indent=2)\n",
+    "shutil.copy(OUTPUTS_DIR / 'results_summary.json', Path(SAVE_DIR) / 'results_summary.json')\n",
+    "print('results_summary.json saved to Drive.')"
+   ]
+  }
+ ],
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "gpuType": "T4",
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python",
+   "version": "3.10.0"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
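A note on two sizes that the notebook above hard-codes or prints: the `nn.Linear(128 * 3 * 3, 256)` input width and the HOG descriptor length both follow from the 28x28 input. The sketch below recomputes them in plain Python under the usual shape formulas (floor division for convolution and pooling, one HOG block per sliding position with a stride of one cell). The helper names `conv_out`, `pool_out`, `flatten_size`, and `hog_length` are illustrative assumptions and do not appear in the repository.

```python
def conv_out(size, kernel=3, padding=1, stride=1):
    """Spatial size after a square convolution (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2):
    """Spatial size after max pooling with stride equal to the kernel size."""
    return (size - kernel) // kernel + 1


def flatten_size(img_size=28, channels=(32, 64, 128)):
    """Width of the flattened feature map after three conv + pool stages."""
    size = img_size
    for _ in channels:
        size = pool_out(conv_out(size))  # 28 -> 14 -> 7 -> 3
    return channels[-1] * size * size


def hog_length(img_size=28, orientations=9, pixels_per_cell=4, cells_per_block=2):
    """HOG descriptor length for a square image with square cells and blocks."""
    cells = img_size // pixels_per_cell   # 7 cells per side
    blocks = cells - cells_per_block + 1  # 6 block positions per side
    return blocks * blocks * cells_per_block * cells_per_block * orientations


print(flatten_size())  # 1152, i.e. 128 * 3 * 3
print(hog_length())    # 1296
```

If `IMG_SIZE` or the HOG settings change, the same arithmetic gives the values to compare against `X_train_hog.shape[1]` and the first `nn.Linear` layer.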
|
notebooks/exploration.ipynb
ADDED
|
@@ -0,0 +1,43 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# ScribblBot — Data Exploration"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sys\n",
+    "sys.path.insert(0, '..')\n",
+    "\n",
+    "import numpy as np\n",
+    "import matplotlib.pyplot as plt\n",
+    "from config import CLASSES, RAW_DIR\n",
+    "\n",
+    "cls = 'cat'\n",
+    "data = np.load(RAW_DIR / f'{cls}.npy')\n",
+    "print(f'{cls}: {data.shape}')\n",
+    "\n",
+    "fig, axes = plt.subplots(2, 8, figsize=(16, 4))\n",
+    "for i, ax in enumerate(axes.flat):\n",
+    "    ax.imshow(data[i].reshape(28, 28), cmap='gray_r')\n",
+    "    ax.axis('off')\n",
+    "plt.suptitle(cls)\n",
+    "plt.tight_layout()\n",
+    "plt.show()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" },
+  "language_info": { "name": "python", "version": "3.10.0" }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
+torch>=2.1.0
+torchvision>=0.16.0
+scikit-learn>=1.3.0
+scikit-image>=0.21.0
+numpy>=1.24.0
+pandas>=2.0.0
+matplotlib>=3.7.0
+seaborn>=0.12.0
+Pillow>=9.5.0
+tqdm>=4.65.0
+joblib>=1.3.0
+gradio>=5.0.0
+huggingface_hub>=0.20
+jinja2>=3.1.4
+requests>=2.31.0
|
scripts/__init__.py
ADDED
|
@@ -0,0 +1 @@
+# scripts package
|
scripts/build_features.py
ADDED
|
@@ -0,0 +1,145 @@
+"""
+build_features.py – Load raw Quick Draw .npy files, split into train/test,
+and extract HOG features for the classical ML pipeline.
+
+Saved artefacts (under data/processed/):
+    X_train_raw.npy, y_train.npy  -> pixel arrays for deep model
+    X_test_raw.npy, y_test.npy    -> pixel arrays for evaluation
+    X_train_hog.npy               -> HOG feature matrix for Random Forest
+    X_test_hog.npy
+
+Usage:
+    python scripts/build_features.py
+"""
+
+import sys
+from pathlib import Path
+
+import numpy as np
+from skimage.feature import hog
+
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from config import (
+    CLASSES,
+    RAW_DIR,
+    PROCESSED_DIR,
+    TRAIN_SAMPLES_PER_CLASS,
+    TEST_SAMPLES_PER_CLASS,
+    IMG_SIZE,
+    HOG_ORIENTATIONS,
+    HOG_PIXELS_PER_CELL,
+    HOG_CELLS_PER_BLOCK,
+)
+
+
+def load_class_data(cls: str, n_train: int, n_test: int) -> tuple[np.ndarray, np.ndarray]:
+    """Load and slice pixel data for a single class.
+
+    The Quick Draw .npy files contain rows of 784-element uint8 vectors
+    (28×28 flattened, pixel values 0–255, white stroke on black background).
+
+    Args:
+        cls: Class name.
+        n_train: Number of training samples to keep.
+        n_test: Number of test samples to keep.
+
+    Returns:
+        Tuple of (train_pixels, test_pixels) each shaped (n, 784).
+    """
+    path = RAW_DIR / f"{cls}.npy"
+    if not path.exists():
+        raise FileNotFoundError(
+            f"Missing {path}. Run scripts/make_dataset.py first."
+        )
+    data = np.load(path, mmap_mode="r")  # memory mapped for large files
+
+    # Shuffle deterministically so splits are reproducible
+    rng = np.random.default_rng(seed=42)
+    indices = rng.permutation(len(data))[: n_train + n_test]
+    data = data[indices]
+
+    return data[:n_train], data[n_train : n_train + n_test]
+
+
+def extract_hog_features(pixel_matrix: np.ndarray) -> np.ndarray:
+    """Compute HOG descriptors for a batch of flat pixel vectors.
+
+    Args:
+        pixel_matrix: Array of shape (N, 784), dtype uint8.
+
+    Returns:
+        Feature matrix of shape (N, D) where D is the HOG descriptor length.
+    """
+    features = []
+    for row in pixel_matrix:
+        img = row.reshape(IMG_SIZE, IMG_SIZE)
+        desc = hog(
+            img,
+            orientations=HOG_ORIENTATIONS,
+            pixels_per_cell=HOG_PIXELS_PER_CELL,
+            cells_per_block=HOG_CELLS_PER_BLOCK,
+            visualize=False,
+            channel_axis=None,
+        )
+        features.append(desc)
+    return np.array(features, dtype=np.float32)
+
+
+def build_splits() -> None:
+    """Assemble train/test raw arrays and labels from all classes."""
+    train_raws, test_raws = [], []
+    train_labels, test_labels = [], []
+
+    print("Loading raw data …")
+    for label_idx, cls in enumerate(CLASSES):
+        print(f" {cls} ({label_idx + 1}/{len(CLASSES)})")
+        tr, te = load_class_data(cls, TRAIN_SAMPLES_PER_CLASS, TEST_SAMPLES_PER_CLASS)
+        train_raws.append(tr)
+        test_raws.append(te)
+        train_labels.append(np.full(len(tr), label_idx, dtype=np.int64))
+        test_labels.append(np.full(len(te), label_idx, dtype=np.int64))
+
+    X_train_raw = np.concatenate(train_raws)
+    X_test_raw = np.concatenate(test_raws)
+    y_train = np.concatenate(train_labels)
+    y_test = np.concatenate(test_labels)
+
+    # Shuffle training set
+    rng = np.random.default_rng(seed=0)
+    perm = rng.permutation(len(X_train_raw))
+    X_train_raw = X_train_raw[perm]
+    y_train = y_train[perm]
+
+    PROCESSED_DIR.mkdir(parents=True, exist_ok=True)
+    np.save(PROCESSED_DIR / "X_train_raw.npy", X_train_raw)
+    np.save(PROCESSED_DIR / "X_test_raw.npy", X_test_raw)
+    np.save(PROCESSED_DIR / "y_train.npy", y_train)
+    np.save(PROCESSED_DIR / "y_test.npy", y_test)
+    print(f"\nSaved raw splits → train {X_train_raw.shape}, test {X_test_raw.shape}")
+
+
+def build_hog_features() -> None:
+    """Extract HOG features from saved raw arrays."""
+    X_train_raw = np.load(PROCESSED_DIR / "X_train_raw.npy")
+    X_test_raw = np.load(PROCESSED_DIR / "X_test_raw.npy")
+
+    print("Extracting HOG features (train) …")
+    X_train_hog = extract_hog_features(X_train_raw)
+    print("Extracting HOG features (test) …")
+    X_test_hog = extract_hog_features(X_test_raw)
+
+    np.save(PROCESSED_DIR / "X_train_hog.npy", X_train_hog)
+    np.save(PROCESSED_DIR / "X_test_hog.npy", X_test_hog)
+    print(f"Saved HOG features → train {X_train_hog.shape}, test {X_test_hog.shape}")
+
+
+def build_all() -> None:
+    """Run the complete feature building pipeline."""
+    build_splits()
+    build_hog_features()
+    print("\nFeature pipeline complete.")
+
+
+if __name__ == "__main__":
+    build_all()
|
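The shuffle step in `build_splits` applies a single permutation to both the pixel array and the label array, which keeps each row paired with its label. Below is a minimal standard-library sketch of that idea. The function name `shuffle_in_unison` and the use of plain lists instead of NumPy arrays are assumptions made for illustration; they are not part of the repository.

```python
import random


def shuffle_in_unison(features, labels, seed=0):
    """Shuffle two equal-length sequences with one shared permutation.

    Hypothetical helper: mirrors the idea of indexing X and y with the
    same `perm`, but uses the standard library instead of NumPy.
    """
    if len(features) != len(labels):
        raise ValueError("features and labels must have the same length")
    perm = list(range(len(features)))
    random.Random(seed).shuffle(perm)  # seeded, so the order is reproducible
    return [features[i] for i in perm], [labels[i] for i in perm]


if __name__ == "__main__":
    X = ["cat_0", "cat_1", "dog_0", "dog_1"]
    y = [0, 0, 1, 1]
    X_shuf, y_shuf = shuffle_in_unison(X, y)
    for name, label in zip(X_shuf, y_shuf):
        print(name, label)  # every "cat_*" row still carries label 0
```

Shuffling the two sequences separately, each with its own random order, would silently break the pairing between samples and labels, which is why one shared index list is used.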
scripts/make_dataset.py
ADDED
|
@@ -0,0 +1,65 @@
|
| 1 |
+
"""
|
| 2 |
+
make_dataset.py – Download Quick Draw .npy files for all configured classes.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python scripts/make_dataset.py
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import sys
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import urllib.request
|
| 11 |
+
|
| 12 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 13 |
+
|
| 14 |
+
from config import CLASSES, RAW_DIR, QUICKDRAW_URL
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def download_class(cls: str, dest_dir: Path, force: bool = False) -> Path:
|
| 18 |
+
"""Download the numpy bitmap file for a single Quick Draw class.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
cls: Class name matching a Quick Draw category (e.g. 'cat').
|
| 22 |
+
dest_dir: Directory to write the .npy file into.
|
| 23 |
+
force: Redownload even if the file already exists.
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
Path to the downloaded file.
|
| 27 |
+
"""
|
| 28 |
+
url = QUICKDRAW_URL.format(cls=cls.replace(" ", "%20"))
|
| 29 |
+
dest = dest_dir / f"{cls}.npy"
|
| 30 |
+
if dest.exists() and not force:
|
| 31 |
+
print(f" [skip] {cls}.npy already exists")
|
| 32 |
+
return dest
|
| 33 |
+
|
| 34 |
+
print(f" [down] {cls}.npy -> {url}")
|
| 35 |
+
|
| 36 |
+
def _reporthook(block_num: int, block_size: int, total_size: int) -> None:
|
| 37 |
+
downloaded = block_num * block_size
|
| 38 |
+
pct = min(100, downloaded * 100 // total_size) if total_size > 0 else 0
|
| 39 |
+
print(f"\r {pct:3d}%", end="", flush=True)
|
| 40 |
+
|
| 41 |
+
urllib.request.urlretrieve(url, dest, reporthook=_reporthook)
|
| 42 |
+
print()
|
| 43 |
+
return dest
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def download_all(force: bool = False) -> None:
|
| 47 |
+
"""Download .npy files for every class listed in config.CLASSES.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
force: Redownload files that already exist on disk.
|
| 51 |
+
"""
|
| 52 |
+
RAW_DIR.mkdir(parents=True, exist_ok=True)
|
| 53 |
+
print(f"Downloading {len(CLASSES)} classes to {RAW_DIR} …\n")
|
| 54 |
+
for cls in CLASSES:
|
| 55 |
+
download_class(cls, RAW_DIR, force=force)
|
| 56 |
+
print("\nAll downloads complete.")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
if __name__ == "__main__":
|
| 60 |
+
import argparse
|
| 61 |
+
|
| 62 |
+
parser = argparse.ArgumentParser(description="Download Quick Draw dataset")
|
| 63 |
+
parser.add_argument("--force", action="store_true", help="Redownload existing files")
|
| 64 |
+
args = parser.parse_args()
|
| 65 |
+
download_all(force=args.force)
|
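The download helper above reports progress through a `reporthook` callback and escapes spaces in class names by hand. The sketch below isolates those two pieces so they can be checked without network access. The names `progress_percent` and `encode_class_name` are hypothetical, and using `urllib.parse.quote` instead of `str.replace(" ", "%20")` is an alternative shown for illustration, not what the script does.

```python
from urllib.parse import quote


def progress_percent(block_num: int, block_size: int, total_size: int) -> int:
    """Percentage downloaded, clamped to 100; 0 when the size is unknown.

    urlretrieve may report total_size <= 0 when the server does not send
    a content length, so that case is handled explicitly.
    """
    if total_size <= 0:
        return 0
    return min(100, block_num * block_size * 100 // total_size)


def encode_class_name(cls: str) -> str:
    """Percent-encode a class name for use inside a URL path segment."""
    return quote(cls)


if __name__ == "__main__":
    print(progress_percent(1, 8192, 20000))
    print(progress_percent(3, 8192, 20000))  # last block overshoots, clamp applies
    print(encode_class_name("ice cream"))
```

The clamp matters because the final block is usually only partly filled, so `block_num * block_size` can exceed the real file size.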
scripts/model.py
ADDED
|
@@ -0,0 +1,616 @@
|
| 1 |
+
"""
|
| 2 |
+
model.py – Define, train, and evaluate all three models:
|
| 3 |
+
1. Naive baseline (majority class classifier)
|
| 4 |
+
2. Classical ML (Random Forest on HOG features)
|
| 5 |
+
3. Deep learning (ScribblNet CNN)
|
| 6 |
+
|
| 7 |
+
Also runs the training size sensitivity experiment and saves results/plots.
|
| 8 |
+
|
| 9 |
+
Usage:
|
| 10 |
+
python scripts/model.py
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import json
|
| 14 |
+
import sys
|
| 15 |
+
import time
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from typing import Any
|
| 18 |
+
|
| 19 |
+
import joblib
|
| 20 |
+
import matplotlib
|
| 21 |
+
matplotlib.use("Agg") # headless backend
|
| 22 |
+
import matplotlib.pyplot as plt
|
| 23 |
+
import numpy as np
|
| 24 |
+
import torch
|
| 25 |
+
import torch.nn as nn
|
| 26 |
+
from sklearn.ensemble import RandomForestClassifier
|
| 27 |
+
from sklearn.metrics import (
|
| 28 |
+
accuracy_score,
|
| 29 |
+
classification_report,
|
| 30 |
+
confusion_matrix,
|
| 31 |
+
)
|
| 32 |
+
from sklearn.preprocessing import StandardScaler
|
| 33 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 34 |
+
import seaborn as sns
|
| 35 |
+
|
| 36 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 37 |
+
|
| 38 |
+
from config import (
|
| 39 |
+
CLASSES,
|
| 40 |
+
MODELS_DIR,
|
| 41 |
+
OUTPUTS_DIR,
|
| 42 |
+
PROCESSED_DIR,
|
| 43 |
+
NUM_CLASSES,
|
| 44 |
+
RF_MAX_DEPTH,
|
| 45 |
+
RF_N_ESTIMATORS,
|
| 46 |
+
DEEP_BATCH_SIZE,
|
| 47 |
+
DEEP_EPOCHS,
|
| 48 |
+
DEEP_LR,
|
| 49 |
+
DEEP_WEIGHT_DECAY,
|
| 50 |
+
IMG_SIZE,
|
| 51 |
+
EXPERIMENT_FRACTIONS,
|
| 52 |
+
EXPERIMENT_EPOCHS,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# Utility
|
| 57 |
+
|
| 58 |
+
def get_device() -> torch.device:
|
| 59 |
+
"""Return the best available torch device (MPS > CUDA > CPU)."""
|
| 60 |
+
if torch.backends.mps.is_available():
|
| 61 |
+
return torch.device("mps")
|
| 62 |
+
if torch.cuda.is_available():
|
| 63 |
+
return torch.device("cuda")
|
| 64 |
+
return torch.device("cpu")
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def load_processed_data() -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
| 68 |
+
"""Load all processed arrays from disk.
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
X_train_raw, X_test_raw, y_train, y_test, X_train_hog, X_test_hog
|
| 72 |
+
"""
|
| 73 |
+
X_train_raw = np.load(PROCESSED_DIR / "X_train_raw.npy")
|
| 74 |
+
X_test_raw = np.load(PROCESSED_DIR / "X_test_raw.npy")
|
| 75 |
+
y_train = np.load(PROCESSED_DIR / "y_train.npy")
|
| 76 |
+
y_test = np.load(PROCESSED_DIR / "y_test.npy")
|
| 77 |
+
X_train_hog = np.load(PROCESSED_DIR / "X_train_hog.npy")
|
| 78 |
+
X_test_hog = np.load(PROCESSED_DIR / "X_test_hog.npy")
|
| 79 |
+
return X_train_raw, X_test_raw, y_train, y_test, X_train_hog, X_test_hog
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# 1. Naive Baseline
|
| 83 |
+
|
| 84 |
+
class MajorityClassifier:
|
| 85 |
+
"""Naive baseline: always predicts the most frequent class in training."""
|
| 86 |
+
|
| 87 |
+
def __init__(self) -> None:
|
| 88 |
+
self.majority_class: int = 0
|
| 89 |
+
|
| 90 |
+
def fit(self, y: np.ndarray) -> "MajorityClassifier":
|
| 91 |
+
"""Fit by finding the majority class label.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
y: 1-D array of integer class labels.
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
self
|
| 98 |
+
"""
|
| 99 |
+
counts = np.bincount(y)
|
| 100 |
+
self.majority_class = int(np.argmax(counts))
|
| 101 |
+
return self
|
| 102 |
+
|
| 103 |
+
def predict(self, n_samples: int) -> np.ndarray:
|
| 104 |
+
"""Return the majority class repeated n_samples times.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
n_samples: Number of predictions to generate.
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
Array of length n_samples, all equal to majority_class.
|
| 111 |
+
"""
|
| 112 |
+
return np.full(n_samples, self.majority_class, dtype=np.int64)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def train_naive(y_train: np.ndarray, y_test: np.ndarray) -> dict[str, Any]:
|
| 116 |
+
"""Train and evaluate the majority class baseline.
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
y_train: Training labels.
|
| 120 |
+
y_test: Test labels.
|
| 121 |
+
|
| 122 |
+
Returns:
|
| 123 |
+
Dictionary of evaluation metrics.
|
| 124 |
+
"""
|
| 125 |
+
print("\nNaive Baseline")
|
| 126 |
+
clf = MajorityClassifier().fit(y_train)
|
| 127 |
+
preds = clf.predict(len(y_test))
|
| 128 |
+
acc = accuracy_score(y_test, preds)
|
| 129 |
+
print(f" Majority class: {CLASSES[clf.majority_class]}")
|
| 130 |
+
print(f" Test accuracy: {acc:.4f}")
|
| 131 |
+
|
| 132 |
+
model_data = {"majority_class": clf.majority_class, "accuracy": acc}
|
| 133 |
+
joblib.dump(model_data, MODELS_DIR / "naive_model.pkl")
|
| 134 |
+
|
| 135 |
+
return {"model": "naive", "accuracy": acc}
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# 2. Classical ML
|
| 139 |
+
|
| 140 |
+
def train_classical(
|
| 141 |
+
X_train_hog: np.ndarray,
|
| 142 |
+
X_test_hog: np.ndarray,
|
| 143 |
+
y_train: np.ndarray,
|
| 144 |
+
y_test: np.ndarray,
|
| 145 |
+
) -> dict[str, Any]:
|
| 146 |
+
"""Train Random Forest on HOG features and evaluate.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
X_train_hog: Training HOG feature matrix.
|
| 150 |
+
X_test_hog: Test HOG feature matrix.
|
| 151 |
+
y_train: Training labels.
|
| 152 |
+
y_test: Test labels.
|
| 153 |
+
|
| 154 |
+
Returns:
|
| 155 |
+
Dictionary of evaluation metrics.
|
| 156 |
+
"""
|
| 157 |
+
print("\nClassical ML (Random Forest on HOG)")
|
| 158 |
+
|
| 159 |
+
# Standardise features
|
| 160 |
+
scaler = StandardScaler()
|
| 161 |
+
X_tr = scaler.fit_transform(X_train_hog)
|
| 162 |
+
X_te = scaler.transform(X_test_hog)
|
| 163 |
+
|
| 164 |
+
clf = RandomForestClassifier(
|
| 165 |
+
n_estimators=RF_N_ESTIMATORS,
|
| 166 |
+
max_depth=RF_MAX_DEPTH,
|
| 167 |
+
n_jobs=-1,
|
| 168 |
+
random_state=42,
|
| 169 |
+
)
|
| 170 |
+
t0 = time.time()
|
| 171 |
+
clf.fit(X_tr, y_train)
|
| 172 |
+
elapsed = time.time() - t0
|
| 173 |
+
|
| 174 |
+
preds = clf.predict(X_te)
|
| 175 |
+
acc = accuracy_score(y_test, preds)
|
| 176 |
+
report = classification_report(y_test, preds, target_names=CLASSES)
|
| 177 |
+
|
| 178 |
+
print(f" Training time: {elapsed:.1f}s")
|
| 179 |
+
print(f" Test accuracy: {acc:.4f}")
|
| 180 |
+
print(f"\n{report}")
|
| 181 |
+
|
| 182 |
+
joblib.dump({"clf": clf, "scaler": scaler}, MODELS_DIR / "classical_model.pkl")
|
| 183 |
+
_save_confusion_matrix(y_test, preds, "classical_confusion_matrix.png")
|
| 184 |
+
|
| 185 |
+
return {"model": "classical", "accuracy": acc, "training_time_s": elapsed}
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
# 3. Deep Model
|
| 189 |
+
|
| 190 |
+
class ScribblNet(nn.Module):
|
| 191 |
+
"""Lightweight CNN for 28×28 grayscale sketch classification.
|
| 192 |
+
|
| 193 |
+
Architecture:
|
| 194 |
+
3 × (Conv2d → BatchNorm → ReLU → MaxPool)
|
| 195 |
+
Dropout → FC(1152→256) → ReLU → Dropout → FC(256→num_classes)
|
| 196 |
+
"""
|
| 197 |
+
|
| 198 |
+
def __init__(self, num_classes: int = NUM_CLASSES) -> None:
|
| 199 |
+
super().__init__()
|
| 200 |
+
self.features = nn.Sequential(
|
| 201 |
+
nn.Conv2d(1, 32, kernel_size=3, padding=1),
|
| 202 |
+
nn.BatchNorm2d(32),
|
| 203 |
+
nn.ReLU(inplace=True),
|
| 204 |
+
nn.MaxPool2d(2),
|
| 205 |
+
|
| 206 |
+
nn.Conv2d(32, 64, kernel_size=3, padding=1),
|
| 207 |
+
nn.BatchNorm2d(64),
|
| 208 |
+
nn.ReLU(inplace=True),
|
| 209 |
+
nn.MaxPool2d(2),
|
| 210 |
+
|
| 211 |
+
nn.Conv2d(64, 128, kernel_size=3, padding=1),
|
| 212 |
+
nn.BatchNorm2d(128),
|
| 213 |
+
nn.ReLU(inplace=True),
|
| 214 |
+
nn.MaxPool2d(2),
|
| 215 |
+
)
|
| 216 |
+
# 28→14→7→3 ∴ feature map is 128×3×3 = 1152
|
| 217 |
+
self.classifier = nn.Sequential(
|
| 218 |
+
nn.Dropout(0.5),
|
| 219 |
+
nn.Linear(128 * 3 * 3, 256),
|
| 220 |
+
nn.ReLU(inplace=True),
|
| 221 |
+
nn.Dropout(0.3),
|
| 222 |
+
nn.Linear(256, num_classes),
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 226 |
+
"""Forward pass.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
x: Tensor of shape (B, 1, 28, 28), values in [0, 1].
|
| 230 |
+
|
| 231 |
+
Returns:
|
| 232 |
+
Logits tensor of shape (B, num_classes).
|
| 233 |
+
"""
|
| 234 |
+
x = self.features(x)
|
| 235 |
+
x = x.view(x.size(0), -1)
|
| 236 |
+
return self.classifier(x)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def make_dataloaders(
|
| 240 |
+
X_raw: np.ndarray,
|
| 241 |
+
y: np.ndarray,
|
| 242 |
+
X_test_raw: np.ndarray,
|
| 243 |
+
y_test: np.ndarray,
|
| 244 |
+
batch_size: int = DEEP_BATCH_SIZE,
|
| 245 |
+
train_fraction: float = 1.0,
|
| 246 |
+
) -> tuple[DataLoader, DataLoader]:
|
| 247 |
+
"""Build PyTorch DataLoaders from raw pixel arrays.
|
| 248 |
+
|
| 249 |
+
Pixel values are normalised to [0, 1]. Training set can be subsampled
|
| 250 |
+
via train_fraction for the sensitivity experiment.
|
| 251 |
+
|
| 252 |
+
Args:
|
| 253 |
+
X_raw: Training pixel array (N, 784), uint8.
|
| 254 |
+
y: Training labels.
|
| 255 |
+
X_test_raw: Test pixel array.
|
| 256 |
+
y_test: Test labels.
|
| 257 |
+
batch_size: Minibatch size.
|
| 258 |
+
train_fraction: Fraction of training samples to use (0 < f ≤ 1).
|
| 259 |
+
|
| 260 |
+
Returns:
|
| 261 |
+
(train_loader, test_loader)
|
| 262 |
+
"""
|
| 263 |
+
if train_fraction < 1.0:
|
| 264 |
+
n = max(1, int(len(X_raw) * train_fraction))
|
| 265 |
+
idx = np.random.default_rng(seed=7).permutation(len(X_raw))[:n]
|
| 266 |
+
X_raw = X_raw[idx]
|
| 267 |
+
y = y[idx]
|
| 268 |
+
|
| 269 |
+
def _to_tensor(X: np.ndarray, labels: np.ndarray) -> TensorDataset:
|
| 270 |
+
imgs = torch.from_numpy(X.astype(np.float32) / 255.0)
|
| 271 |
+
imgs = imgs.view(-1, 1, IMG_SIZE, IMG_SIZE)
|
| 272 |
+
return TensorDataset(imgs, torch.from_numpy(labels))
|
| 273 |
+
|
| 274 |
+
train_ds = _to_tensor(X_raw, y)
|
| 275 |
+
test_ds = _to_tensor(X_test_raw, y_test)
|
| 276 |
+
|
| 277 |
+
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
|
| 278 |
+
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0)
|
| 279 |
+
return train_loader, test_loader
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def train_one_epoch(
|
| 283 |
+
model: nn.Module,
|
| 284 |
+
loader: DataLoader,
|
| 285 |
+
optimizer: torch.optim.Optimizer,
|
| 286 |
+
criterion: nn.Module,
|
| 287 |
+
device: torch.device,
|
| 288 |
+
) -> float:
|
| 289 |
+
"""Run one training epoch and return average loss.
|
| 290 |
+
|
| 291 |
+
Args:
|
| 292 |
+
model: ScribblNet instance.
|
| 293 |
+
loader: Training DataLoader.
|
| 294 |
+
optimizer: Optimiser (Adam).
|
| 295 |
+
criterion: Loss function (CrossEntropyLoss).
|
| 296 |
+
device: Torch device.
|
| 297 |
+
|
| 298 |
+
Returns:
|
| 299 |
+
Mean loss over all minibatches.
|
| 300 |
+
"""
|
| 301 |
+
model.train()
|
| 302 |
+
total_loss = 0.0
|
| 303 |
+
for imgs, labels in loader:
|
| 304 |
+
imgs, labels = imgs.to(device), labels.to(device)
|
| 305 |
+
optimizer.zero_grad()
|
| 306 |
+
loss = criterion(model(imgs), labels)
|
| 307 |
+
loss.backward()
|
| 308 |
+
optimizer.step()
|
| 309 |
+
total_loss += loss.item()
|
| 310 |
+
return total_loss / len(loader)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def evaluate(
|
| 314 |
+
model: nn.Module,
|
| 315 |
+
loader: DataLoader,
|
| 316 |
+
device: torch.device,
|
| 317 |
+
) -> tuple[float, np.ndarray]:
|
| 318 |
+
"""Evaluate model on a DataLoader.
|
| 319 |
+
|
| 320 |
+
Args:
|
| 321 |
+
model: ScribblNet instance.
|
| 322 |
+
loader: Evaluation DataLoader.
|
| 323 |
+
device: Torch device.
|
| 324 |
+
|
| 325 |
+
Returns:
|
| 326 |
+
(accuracy, predictions_array)
|
| 327 |
+
"""
|
| 328 |
+
model.eval()
|
| 329 |
+
all_preds, all_labels = [], []
|
| 330 |
+
with torch.no_grad():
|
| 331 |
+
for imgs, labels in loader:
|
| 332 |
+
imgs = imgs.to(device)
|
| 333 |
+
preds = model(imgs).argmax(dim=1).cpu().numpy()
|
| 334 |
+
all_preds.append(preds)
|
| 335 |
+
all_labels.append(labels.numpy())
|
| 336 |
+
preds = np.concatenate(all_preds)
|
| 337 |
+
labels = np.concatenate(all_labels)
|
| 338 |
+
return accuracy_score(labels, preds), preds
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def train_deep(
|
| 342 |
+
X_train_raw: np.ndarray,
|
| 343 |
+
X_test_raw: np.ndarray,
|
| 344 |
+
y_train: np.ndarray,
|
| 345 |
+
y_test: np.ndarray,
|
| 346 |
+
epochs: int = DEEP_EPOCHS,
|
| 347 |
+
train_fraction: float = 1.0,
|
| 348 |
+
save_model: bool = True,
|
| 349 |
+
) -> dict[str, Any]:
|
| 350 |
+
"""Train ScribblNet and evaluate on test set.
|
| 351 |
+
|
| 352 |
+
Args:
|
| 353 |
+
X_train_raw: Raw training pixel array.
|
| 354 |
+
X_test_raw: Raw test pixel array.
|
| 355 |
+
y_train: Training labels.
|
| 356 |
+
y_test: Test labels.
|
| 357 |
+
epochs: Number of training epochs.
|
| 358 |
+
train_fraction: Fraction of training data to use.
|
| 359 |
+
save_model: Whether to save weights to disk.
|
| 360 |
+
|
| 361 |
+
Returns:
|
| 362 |
+
Dictionary of evaluation metrics and training history.
|
| 363 |
+
"""
|
| 364 |
+
print(f"\nDeep Model (ScribblNet, fraction={train_fraction:.0%})")
|
| 365 |
+
device = get_device()
|
| 366 |
+
print(f" Device: {device}")
|
| 367 |
+
|
| 368 |
+
train_loader, test_loader = make_dataloaders(
|
| 369 |
+
X_train_raw, y_train, X_test_raw, y_test, train_fraction=train_fraction
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
model = ScribblNet(num_classes=NUM_CLASSES).to(device)
|
| 373 |
+
optimizer = torch.optim.Adam(
|
| 374 |
+
model.parameters(), lr=DEEP_LR, weight_decay=DEEP_WEIGHT_DECAY
|
| 375 |
+
)
|
| 376 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
|
| 377 |
+
criterion = nn.CrossEntropyLoss()
|
| 378 |
+
|
| 379 |
+
history = {"loss": [], "val_acc": []}
|
| 380 |
+
best_acc = 0.0
|
| 381 |
+
|
| 382 |
+
for epoch in range(1, epochs + 1):
|
| 383 |
+
loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
|
| 384 |
+
acc, _ = evaluate(model, test_loader, device)
|
| 385 |
+
scheduler.step()
|
| 386 |
+
history["loss"].append(loss)
|
| 387 |
+
history["val_acc"].append(acc)
|
| 388 |
+
print(f" epoch {epoch:02d}/{epochs} loss={loss:.4f} val_acc={acc:.4f}")
|
| 389 |
+
|
| 390 |
+
if acc > best_acc:
|
| 391 |
+
best_acc = acc
|
| 392 |
+
if save_model:
|
| 393 |
+
torch.save(model.state_dict(), MODELS_DIR / "deep_model.pth")
|
| 394 |
+
|
| 395 |
+
# Final evaluation with best weights
|
| 396 |
+
if save_model:
|
| 397 |
+
model.load_state_dict(torch.load(MODELS_DIR / "deep_model.pth", map_location=device))
|
| 398 |
+
|
| 399 |
+
final_acc, final_preds = evaluate(model, test_loader, device)
|
| 400 |
+
print(f"\n Best test accuracy: {best_acc:.4f}")
|
| 401 |
+
|
| 402 |
+
if save_model:
|
| 403 |
+
report = classification_report(y_test, final_preds, target_names=CLASSES)
|
| 404 |
+
print(f"\n{report}")
|
| 405 |
+
_save_confusion_matrix(y_test, final_preds, "deep_confusion_matrix.png")
|
| 406 |
+
_save_training_curves(history)
|
| 407 |
+
|
| 408 |
+
return {"model": "deep", "accuracy": best_acc, "history": history}
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
# Experiment: Training Size Sensitivity
|
| 412 |
+
|
| 413 |
+
def run_experiment(
|
| 414 |
+
X_train_raw: np.ndarray,
|
| 415 |
+
X_test_raw: np.ndarray,
|
| 416 |
+
y_train: np.ndarray,
|
| 417 |
+
y_test: np.ndarray,
|
| 418 |
+
X_train_hog: np.ndarray,
|
| 419 |
+
X_test_hog: np.ndarray,
|
| 420 |
+
) -> None:
|
| 421 |
+
"""Training set size sensitivity analysis.
|
| 422 |
+
|
| 423 |
+
Sweeps over EXPERIMENT_FRACTIONS, training both the deep model and Random
|
| 424 |
+
Forest at each fraction, then plots accuracy vs number of training samples.
|
| 425 |
+
|
| 426 |
+
Motivation: Understanding how each model scales with data volume helps
|
| 427 |
+
justify architectural choices and highlights when more data is beneficial.
|
| 428 |
+
|
| 429 |
+
Args:
|
| 430 |
+
X_train_raw: Raw training pixels.
|
| 431 |
+
X_test_raw: Raw test pixels.
|
| 432 |
+
y_train: Training labels.
|
| 433 |
+
y_test: Test labels.
|
| 434 |
+
X_train_hog: HOG training features.
|
| 435 |
+
X_test_hog: HOG test features.
|
| 436 |
+
"""
|
| 437 |
+
print("\nExperiment: Training Size Sensitivity")
|
| 438 |
+
deep_accs, rf_accs, n_samples = [], [], []
|
| 439 |
+
scaler = StandardScaler()
|
| 440 |
+
X_test_scaled = scaler.fit_transform(X_test_hog)
|
| 441 |
+
|
| 442 |
+
for frac in EXPERIMENT_FRACTIONS:
|
| 443 |
+
n = int(len(X_train_raw) * frac)
|
| 444 |
+
n_samples.append(n)
|
| 445 |
+
print(f"\n Fraction={frac:.0%} (n={n})")
|
| 446 |
+
|
| 447 |
+
# Deep model
|
| 448 |
+
result = train_deep(
|
| 449 |
+
X_train_raw, X_test_raw, y_train, y_test,
|
| 450 |
+
epochs=EXPERIMENT_EPOCHS, train_fraction=frac, save_model=False,
|
| 451 |
+
)
|
| 452 |
+
deep_accs.append(result["accuracy"])
|
| 453 |
+
|
| 454 |
+
# Random Forest
|
| 455 |
+
idx = np.random.default_rng(seed=42).permutation(len(X_train_hog))[:n]
|
| 456 |
+
X_tr = scaler.fit_transform(X_train_hog[idx])
|
| 457 |
+
rf = RandomForestClassifier(
|
| 458 |
+
n_estimators=100, n_jobs=-1, random_state=42
|
| 459 |
+
)
|
| 460 |
+
rf.fit(X_tr, y_train[idx])
|
| 461 |
+
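rf_pred = rf.predict(scaler.transform(X_test_hog))  # scale test features with the train-subset statistics, not statistics fitted on the test set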
|
| 462 |
+
rf_accs.append(accuracy_score(y_test, rf_pred))
|
| 463 |
+
print(f" RF acc={rf_accs[-1]:.4f}")
|
| 464 |
+
|
| 465 |
+
# Plot
|
| 466 |
+
fig, ax = plt.subplots(figsize=(8, 5))
|
| 467 |
+
ax.plot(n_samples, deep_accs, marker="o", linestyle="solid", label="ScribblNet (CNN)", linewidth=2, markersize=7)
|
| 468 |
+
ax.plot(n_samples, rf_accs, marker="s", linestyle="dashed", label="Random Forest (HOG)", linewidth=2, markersize=7)
|
| 469 |
+
ax.set_xlabel("Training samples", fontsize=12)
|
| 470 |
+
ax.set_ylabel("Test accuracy", fontsize=12)
|
| 471 |
+
ax.set_title("Training Set Size Sensitivity", fontsize=14)
|
| 472 |
+
ax.legend(fontsize=11)
|
| 473 |
+
ax.grid(True, alpha=0.3)
|
| 474 |
+
ax.set_ylim(0, 1)
|
| 475 |
+
plt.tight_layout()
|
| 476 |
+
out_path = OUTPUTS_DIR / "experiment_sensitivity.png"
|
| 477 |
+
fig.savefig(out_path, dpi=150)
|
| 478 |
+
plt.close(fig)
|
| 479 |
+
print(f"\n Saved experiment plot → {out_path}")
|
| 480 |
+
|
| 481 |
+
results = {
|
| 482 |
+
"fractions": EXPERIMENT_FRACTIONS,
|
| 483 |
+
"n_samples": n_samples,
|
| 484 |
+
"deep_accs": deep_accs,
|
| 485 |
+
"rf_accs": rf_accs,
|
| 486 |
+
}
|
| 487 |
+
with open(OUTPUTS_DIR / "experiment_results.json", "w") as f:
|
| 488 |
+
json.dump(results, f, indent=2)
|
| 489 |
+
print(" Saved experiment_results.json")
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
# Plotting Helpers
|
| 493 |
+
|
| 494 |
+
def _save_confusion_matrix(
|
| 495 |
+
y_true: np.ndarray,
|
| 496 |
+
y_pred: np.ndarray,
|
| 497 |
+
filename: str,
|
| 498 |
+
) -> None:
|
| 499 |
+
"""Save a normalised confusion matrix heatmap.
|
| 500 |
+
|
| 501 |
+
Args:
|
| 502 |
+
y_true: Ground truth labels.
|
| 503 |
+
y_pred: Predicted labels.
|
| 504 |
+
filename: Output filename (saved under OUTPUTS_DIR).
|
| 505 |
+
"""
|
| 506 |
+
cm = confusion_matrix(y_true, y_pred, normalize="true")
|
| 507 |
+
fig, ax = plt.subplots(figsize=(10, 8))
|
| 508 |
+
sns.heatmap(
|
| 509 |
+
cm,
|
| 510 |
+
annot=True,
|
| 511 |
+
fmt=".2f",
|
| 512 |
+
xticklabels=CLASSES,
|
| 513 |
+
yticklabels=CLASSES,
|
| 514 |
+
cmap="Blues",
|
| 515 |
+
ax=ax,
|
| 516 |
+
linewidths=0.5,
|
| 517 |
+
)
|
| 518 |
+
ax.set_xlabel("Predicted", fontsize=11)
|
| 519 |
+
ax.set_ylabel("True", fontsize=11)
|
| 520 |
+
ax.set_title(filename.replace("_", " ").replace(".png", "").title(), fontsize=13)
|
| 521 |
+
plt.xticks(rotation=45, ha="right")
|
| 522 |
+
plt.tight_layout()
|
| 523 |
+
fig.savefig(OUTPUTS_DIR / filename, dpi=150)
|
| 524 |
+
plt.close(fig)
|
| 525 |
+
print(f" Saved {filename}")
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def _save_training_curves(history: dict[str, list[float]]) -> None:
|
| 529 |
+
"""Save loss and validation accuracy curves for the deep model.
|
| 530 |
+
|
| 531 |
+
Args:
|
| 532 |
+
history: Dict with keys 'loss' and 'val_acc', each a list of per epoch values.
|
| 533 |
+
"""
|
| 534 |
+
epochs = range(1, len(history["loss"]) + 1)
|
| 535 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4))
|
| 536 |
+
|
| 537 |
+
ax1.plot(epochs, history["loss"], color="steelblue", marker="o", linestyle="solid", markersize=5)
|
| 538 |
+
ax1.set_xlabel("Epoch")
|
| 539 |
+
ax1.set_ylabel("Training Loss")
|
| 540 |
+
ax1.set_title("ScribblNet Training Loss")
|
| 541 |
+
ax1.grid(True, alpha=0.3)
|
| 542 |
+
|
| 543 |
+
ax2.plot(epochs, history["val_acc"], color="seagreen", marker="o", linestyle="solid", markersize=5)
|
| 544 |
+
ax2.set_xlabel("Epoch")
|
| 545 |
+
ax2.set_ylabel("Validation Accuracy")
|
| 546 |
+
ax2.set_title("ScribblNet Validation Accuracy")
|
| 547 |
+
ax2.grid(True, alpha=0.3)
|
| 548 |
+
|
| 549 |
+
plt.tight_layout()
|
| 550 |
+
fig.savefig(OUTPUTS_DIR / "deep_training_curves.png", dpi=150)
|
| 551 |
+
plt.close(fig)
|
| 552 |
+
print(" Saved deep_training_curves.png")
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
def _save_model_comparison(results: list[dict[str, Any]]) -> None:
|
| 556 |
+
"""Bar chart comparing test accuracy across all three models.
|
| 557 |
+
|
| 558 |
+
Args:
|
| 559 |
+
results: List of result dicts each containing 'model' and 'accuracy'.
|
| 560 |
+
"""
|
| 561 |
+
names = [r["model"].capitalize() for r in results]
|
| 562 |
+
accs = [r["accuracy"] for r in results]
|
| 563 |
+
|
| 564 |
+
fig, ax = plt.subplots(figsize=(7, 4))
|
| 565 |
+
bars = ax.bar(names, accs, color=["#94a3b8", "#60a5fa", "#34d399"], width=0.5)
|
| 566 |
+
ax.set_ylim(0, 1)
|
| 567 |
+
ax.set_ylabel("Test Accuracy")
|
| 568 |
+
ax.set_title("Model Comparison")
|
| 569 |
+
for bar, acc in zip(bars, accs):
|
| 570 |
+
ax.text(
|
| 571 |
+
bar.get_x() + bar.get_width() / 2,
|
| 572 |
+
bar.get_height() + 0.01,
|
| 573 |
+
f"{acc:.3f}",
|
| 574 |
+
ha="center",
|
| 575 |
+
fontsize=12,
|
| 576 |
+
)
|
| 577 |
+
ax.grid(True, axis="y", alpha=0.3)
|
| 578 |
+
plt.tight_layout()
|
| 579 |
+
fig.savefig(OUTPUTS_DIR / "model_comparison.png", dpi=150)
|
| 580 |
+
plt.close(fig)
|
| 581 |
+
print(" Saved model_comparison.png")
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
# Orchestrator
|
| 585 |
+
|
| 586 |
+
def train_all() -> None:
|
| 587 |
+
"""Train all three models, run the experiment, and save all artefacts."""
|
| 588 |
+
X_train_raw, X_test_raw, y_train, y_test, X_train_hog, X_test_hog = (
|
| 589 |
+
load_processed_data()
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
r_naive = train_naive(y_train, y_test)
|
| 593 |
+
r_classical = train_classical(X_train_hog, X_test_hog, y_train, y_test)
|
| 594 |
+
r_deep = train_deep(X_train_raw, X_test_raw, y_train, y_test)
|
| 595 |
+
|
| 596 |
+
_save_model_comparison([r_naive, r_classical, r_deep])
|
| 597 |
+
|
| 598 |
+
run_experiment(
|
| 599 |
+
X_train_raw, X_test_raw, y_train, y_test, X_train_hog, X_test_hog
|
| 600 |
+
)
|
| 601 |
+
|
| 602 |
+
summary = {
|
| 603 |
+
"naive_accuracy": r_naive["accuracy"],
|
| 604 |
+
"classical_accuracy": r_classical["accuracy"],
|
| 605 |
+
"deep_accuracy": r_deep["accuracy"],
|
| 606 |
+
}
|
| 607 |
+
with open(OUTPUTS_DIR / "results_summary.json", "w") as f:
|
| 608 |
+
json.dump(summary, f, indent=2)
|
| 609 |
+
|
| 610 |
+
print("\nTraining complete. Summary:")
|
| 611 |
+
for k, v in summary.items():
|
| 612 |
+
print(f" {k}: {v:.4f}")
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
if __name__ == "__main__":
|
| 616 |
+
train_all()
|
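The `ScribblNet` comment "28→14→7→3 ∴ feature map is 128×3×3 = 1152" can be checked with the standard output-size formulas for convolution and pooling layers. The sketch below does that arithmetic in plain Python. The helper names (`conv_out`, `pool_out`, `flattened_features`) are hypothetical, and the sketch assumes the layer settings shown in the class: 3×3 convolutions with padding 1, and 2×2 max pooling with stride equal to the kernel size and floor rounding.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size: int, kernel: int, stride=None) -> int:
    """Spatial output size of max pooling (stride defaults to kernel size)."""
    stride = kernel if stride is None else stride
    return (size - kernel) // stride + 1


def flattened_features(img_size: int = 28, channels=(32, 64, 128)) -> int:
    """Input width of the first fully connected layer after the conv stack."""
    size = img_size
    for _ in channels:
        size = conv_out(size, kernel=3, padding=1)  # padding=1 keeps the size
        size = pool_out(size, kernel=2)             # halves it, rounding down
    return channels[-1] * size * size


if __name__ == "__main__":
    print(flattened_features())  # 128 * 3 * 3
```

The last pooling step maps 7 to 3 rather than 3.5 because the division rounds down, which is where the 3×3 map in the comment comes from.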
setup.py
ADDED
|
@@ -0,0 +1,56 @@
|
| 1 |
+
"""
|
| 2 |
+
setup.py – Orchestrates the full ScribblBot pipeline:
|
| 3 |
+
1. Download Quick Draw data (make_dataset.py)
|
| 4 |
+
2. Build features (build_features.py)
|
| 5 |
+
3. Train all models (model.py)
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python setup.py # run full pipeline
|
| 9 |
+
python setup.py --skip_download # skip if data already downloaded
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import sys
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 17 |
+
|
| 18 |
+
from scripts.make_dataset import download_all
|
| 19 |
+
from scripts.build_features import build_all
|
| 20 |
+
from scripts.model import train_all
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def run(skip_download: bool = False) -> None:
|
| 24 |
+
"""Execute the complete data and training pipeline.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
skip_download: If True, skip the dataset download step.
|
| 28 |
+
Useful when raw .npy files are already present.
|
| 29 |
+
"""
|
| 30 |
+
print("ScribblBot setup pipeline starting")
|
| 31 |
+
|
| 32 |
+
if not skip_download:
|
| 33 |
+
print("\n[1/3] Downloading dataset ...")
|
| 34 |
+
download_all()
|
| 35 |
+
else:
|
| 36 |
+
print("\n[1/3] Skipping download (--skip_download)")
|
| 37 |
+
|
| 38 |
+
print("\n[2/3] Building features ...")
|
| 39 |
+
build_all()
|
| 40 |
+
|
| 41 |
+
print("\n[3/3] Training models ...")
|
| 42 |
+
train_all()
|
| 43 |
+
|
| 44 |
+
print("\nSetup complete. Run the app with:")
|
| 45 |
+
print(" python app.py")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
if __name__ == "__main__":
|
| 49 |
+
parser = argparse.ArgumentParser(description="ScribblBot full pipeline setup")
|
| 50 |
+
parser.add_argument(
|
| 51 |
+
"--skip_download",
|
| 52 |
+
action="store_true",
|
| 53 |
+
help="Skip dataset download (use if .npy files already exist in data/raw/)",
|
| 54 |
+
)
|
| 55 |
+
args = parser.parse_args()
|
| 56 |
+
run(skip_download=args.skip_download)
|
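The `--skip_download` switch relies on argparse's `store_true` action, which defaults to `False` and flips to `True` only when the flag is present. A small sketch of that behaviour follows. The `build_parser` helper is hypothetical, and the argument lists are passed explicitly so the example does not depend on the real command line.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Parser with a single boolean flag, mirroring the pipeline's CLI shape."""
    parser = argparse.ArgumentParser(description="pipeline flags (sketch)")
    parser.add_argument(
        "--skip_download",
        action="store_true",  # False unless the flag is given
        help="Skip the dataset download step",
    )
    return parser


if __name__ == "__main__":
    parser = build_parser()
    print(parser.parse_args([]).skip_download)
    print(parser.parse_args(["--skip_download"]).skip_download)
```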