Jog-sama committed on
Commit af61511 · 1 Parent(s): 5883215

initial deployment

.gitignore ADDED
@@ -0,0 +1,12 @@
+ __pycache__/
+ *.py[cod]
+ .venv/
+ venv/
+ .env
+ data/raw/*.npy
+ data/processed/
+ data/outputs/
+ models/*.pkl
+ notebooks/.ipynb_checkpoints/
+ .DS_Store
+ *.egg-info/
Makefile ADDED
@@ -0,0 +1,27 @@
+ .PHONY: setup download features train app clean install
+
+ # Run full pipeline (download → features → train)
+ setup:
+ 	python setup.py
+
+ # Individual steps
+ download:
+ 	python scripts/make_dataset.py
+
+ features:
+ 	python scripts/build_features.py
+
+ train:
+ 	python scripts/model.py
+
+ # Run the app locally
+ app:
+ 	python app.py
+
+ # Remove all generated data and model files
+ clean:
+ 	rm -rf data/raw/*.npy data/processed/*.npy data/outputs/* models/*.pkl models/*.pth
+
+ # Install dependencies
+ install:
+ 	pip install -r requirements.txt
README.md CHANGED
@@ -1,13 +1,129 @@
  ---
- title: ScribbleBot Huggingface
- emoji: 🚀
- colorFrom: pink
- colorTo: yellow
- sdk: gradio
- sdk_version: 6.13.0
- app_file: app.py
- pinned: false
- short_description: Real-time sketch recognition using CNN
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # ScribblBot
+
+ Real-time sketch recognition powered by a lightweight CNN trained on Google's Quick Draw dataset. Draw anything in the browser and ScribblBot identifies it instantly.
+
+ **Live app:** [your HuggingFace Spaces URL here]
+
+ ---
+
+ ## Results
+
+ | Model | Architecture | Test Accuracy |
+ |---|---|---|
+ | Majority Class Baseline | Always predicts most frequent class | 6.67% |
+ | Random Forest | HOG features, 200 trees | 85.10% |
+ | ScribblNet | 3-layer CNN | **94.42%** |
+
+ ScribblNet was trained on 30,000 samples across 15 classes (2,000 per class) for 15 epochs using Adam with cosine annealing. Training took under 5 minutes on Apple M-series hardware with MPS acceleration. The Random Forest operates on 1,296-dimensional HOG feature vectors extracted from 28×28 grayscale bitmaps.
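
The 1,296-dimensional figure follows from the HOG settings in `config.py` (9 orientations, 4×4-pixel cells, 2×2-cell blocks). A minimal sketch of the arithmetic, assuming the usual HOG layout in which blocks slide one cell at a time; the helper name `hog_feature_length` is hypothetical and not part of the repository:

```python
def hog_feature_length(img_size: int, pixels_per_cell: int,
                       cells_per_block: int, orientations: int) -> int:
    """Length of a HOG descriptor for a square image (hypothetical helper)."""
    cells_per_side = img_size // pixels_per_cell            # 28 // 4 = 7
    blocks_per_side = cells_per_side - cells_per_block + 1  # 7 - 2 + 1 = 6
    values_per_block = cells_per_block ** 2 * orientations  # 2 * 2 * 9 = 36
    return blocks_per_side ** 2 * values_per_block


# Settings taken from config.py: IMG_SIZE, HOG_PIXELS_PER_CELL,
# HOG_CELLS_PER_BLOCK, HOG_ORIENTATIONS.
print(hog_feature_length(28, 4, 2, 9))
```

With these settings there are 6×6 overlapping blocks of 36 values each, which gives the 1,296 features quoted above.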
+
+ **Classes:** cat · dog · pizza · bicycle · house · sun · tree · car · fish · butterfly · guitar · hamburger · airplane · banana · star
+
+ ### Per-class performance (ScribblNet)
+
+ | Class | Precision | Recall | F1 |
+ |---|---|---|---|
+ | cat | 0.88 | 0.83 | 0.85 |
+ | dog | 0.81 | 0.82 | 0.82 |
+ | pizza | 0.94 | 0.97 | 0.96 |
+ | bicycle | 0.96 | 0.98 | 0.97 |
+ | house | 0.99 | 0.98 | 0.98 |
+ | sun | 0.95 | 0.96 | 0.96 |
+ | tree | 0.96 | 0.97 | 0.97 |
+ | car | 0.97 | 0.96 | 0.96 |
+ | fish | 0.97 | 0.95 | 0.96 |
+ | butterfly | 0.97 | 0.97 | 0.97 |
+ | guitar | 0.95 | 0.98 | 0.96 |
+ | hamburger | 0.99 | 0.97 | 0.98 |
+ | airplane | 0.91 | 0.89 | 0.90 |
+ | banana | 0.97 | 0.98 | 0.97 |
+ | star | 0.94 | 0.94 | 0.94 |
+
+ Cat and dog are the hardest classes, which is expected given their visual similarity in quick sketches. Airplane also underperforms, likely due to style variation in how people draw wings and fuselage.
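
For readers who want to relate the columns of the table, F1 is the harmonic mean of precision and recall. A small sketch using the `cat` row; the function name `f1_score` here is an illustrative helper, not the repository's evaluation code:

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Precision and recall for "cat", as reported in the table above.
cat_f1 = f1_score(0.88, 0.83)
print(f"cat F1 = {cat_f1:.2f}")
```

Because the table reports precision and recall already rounded to two decimals, recomputing F1 from them can differ from the reported value in the last digit for some rows.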
+
+ ---
+
+ ## Experiment: Training Size Sensitivity
+
+ Both ScribblNet and Random Forest were trained at 10%, 25%, 50%, 75%, and 100% of available training data. The sweep trains ScribblNet for 10 epochs (`EXPERIMENT_EPOCHS` in `config.py`) rather than the 15 used for the headline result, so the 100% row is not expected to match the Results table exactly.
+
+ | Fraction | Samples | ScribblNet | Random Forest |
+ |---|---|---|---|
+ | 10% | 3,000 | 86.12% | 77.92% |
+ | 25% | 7,500 | 90.37% | 80.35% |
+ | 50% | 15,000 | 92.70% | 81.57% |
+ | 75% | 22,500 | 93.00% | 82.88% |
+ | 100% | 30,000 | 94.02% | 83.03% |
+
+ The CNN scales more steeply with data volume than the Random Forest. At 10% of training data the gap is about 8 points; at 100% it grows to about 11 points. The Random Forest plateaus around 83% while ScribblNet continues improving, suggesting the CNN would benefit further from additional data.
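
The gap figures can be checked directly from the table. A short sketch that recomputes them; the dictionary simply restates the table, and the names `SWEEP` and `accuracy_gaps` are illustrative:

```python
# Test accuracy (%) by training fraction: (ScribblNet, Random Forest),
# copied from the table above.
SWEEP = {
    0.10: (86.12, 77.92),
    0.25: (90.37, 80.35),
    0.50: (92.70, 81.57),
    0.75: (93.00, 82.88),
    1.00: (94.02, 83.03),
}


def accuracy_gaps(sweep: dict) -> dict:
    """CNN-minus-Random-Forest accuracy gap, in points, per fraction."""
    return {frac: round(cnn - rf, 2) for frac, (cnn, rf) in sweep.items()}


gaps = accuracy_gaps(SWEEP)
for frac, gap in gaps.items():
    print(f"{frac:>4.0%}  gap = {gap:5.2f} points")

# Total improvement from 10% to 100% of the data for each model.
cnn_gain = round(SWEEP[1.00][0] - SWEEP[0.10][0], 2)
rf_gain = round(SWEEP[1.00][1] - SWEEP[0.10][1], 2)
print(f"ScribblNet gain: {cnn_gain} points, Random Forest gain: {rf_gain} points")
```

On these figures the gap is not strictly monotonic: it is widest at 50% of the data (11.13 points) and sits at 10.99 points with the full training set.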
+
+ ---
+
+ ## Dataset
+
+ [Quick Draw](https://quickdraw.withgoogle.com/data) by Google: 50 million drawings across 345 categories, collected from players of the Quick Draw game. Each drawing is a 28×28 grayscale bitmap stored as a flat 784-element uint8 vector. The dataset is publicly available via Google Cloud Storage.
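
To make the storage layout concrete, here is a dependency-free sketch that turns one flat 784-element vector back into 28 rows of 28 pixels. The data is synthetic and the helper `to_rows` is hypothetical; with NumPy the same step is a single `reshape(28, 28)` call.

```python
IMG_SIZE = 28  # Quick Draw bitmap side length, as in config.py


def to_rows(flat: bytes) -> list:
    """Split a flat row-major 784-byte drawing into 28 rows of 28 pixel values."""
    if len(flat) != IMG_SIZE * IMG_SIZE:
        raise ValueError(f"expected {IMG_SIZE * IMG_SIZE} values, got {len(flat)}")
    return [list(flat[r * IMG_SIZE:(r + 1) * IMG_SIZE]) for r in range(IMG_SIZE)]


# Synthetic stand-in for one drawing: 784 uint8 values, not real Quick Draw data.
flat_drawing = bytes(range(256)) * 3 + bytes(16)
rows = to_rows(flat_drawing)
print(len(rows), len(rows[0]))
```

Pixel `(row, col)` of the image is element `row * 28 + col` of the flat vector.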
+
  ---
+
+ ## Setup
+
+ ```bash
+ pip install -r requirements.txt
+ python setup.py
+ python app.py
+ ```
+
+ `setup.py` runs the full pipeline: downloads the raw `.npy` files, extracts HOG features, trains all three models, and runs the experiment.
+
+ Individual steps:
+
+ ```bash
+ python scripts/make_dataset.py
+ python scripts/build_features.py
+ python scripts/model.py
+ ```
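
`setup.py` itself is not shown in this commit view. As a rough illustration only, an orchestrator that chains the three scripts listed above might look like the sketch below; `PIPELINE_STEPS` and `run_pipeline` are assumed names, and the real `setup.py` may be organised quite differently.

```python
import subprocess
import sys

# Script order taken from the "Individual steps" listed above.
PIPELINE_STEPS = [
    "scripts/make_dataset.py",
    "scripts/build_features.py",
    "scripts/model.py",
]


def run_pipeline(runner=subprocess.run) -> None:
    """Run each step in order; check=True stops the pipeline on the first failure."""
    for script in PIPELINE_STEPS:
        runner([sys.executable, script], check=True)


# Usage (from the repository root): run_pipeline()
```

Passing the runner as a parameter is a design choice for this sketch: it lets the ordering be tested without actually downloading data or training models.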
+
  ---

+ ## Repository Structure
+
+ ```
+ scribblbot/
+ ├── README.md
+ ├── requirements.txt
+ ├── Makefile
+ ├── setup.py
+ ├── app.py
+ ├── config.py
+ ├── scripts/
+ │   ├── make_dataset.py
+ │   ├── build_features.py
+ │   └── model.py
+ ├── models/
+ ├── data/
+ │   ├── raw/
+ │   ├── processed/
+ │   └── outputs/
+ └── notebooks/
+ ```
+
+ | Component | Location |
+ |---|---|
+ | Naive baseline | `scripts/model.py`: `MajorityClassifier`, saved to `models/naive_model.pkl` |
+ | Random Forest | `scripts/model.py`: `train_classical()`, saved to `models/classical_model.pkl` |
+ | ScribblNet CNN | `scripts/model.py`: `ScribblNet`, `train_deep()`, saved to `models/deep_model.pth` |
+ | Inference app | `app.py` |
+ | Config | `config.py` |
+
+ ---
+
+ ## Deployment
+
+ 1. Run `python setup.py` to train the models and generate `models/deep_model.pth`
+ 2. Create a new Space on HuggingFace (SDK: Gradio)
+ 3. Push the full repo, including `models/deep_model.pth`
+
+ ---
+
+ ## Git Workflow
+
+ Working branches: `develop` for integration, `feature/*` for individual changes. All work merges into `develop` via pull requests. `develop` merges into `main` for releases. No direct commits to `main`.
app.py ADDED
@@ -0,0 +1,284 @@
+ """
+ app.py - ScribblBot inference application.
+ """
+
+ import sys
+ from pathlib import Path
+ from typing import Optional
+
+ import gradio as gr
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+
+ sys.path.insert(0, str(Path(__file__).parent))
+
+ from config import CLASSES, CLASS_EMOJIS, MODELS_DIR, NUM_CLASSES
+ from scripts.model import ScribblNet
+
+
+ def _load_model() -> tuple[ScribblNet, torch.device]:
+     """Load trained ScribblNet weights from disk."""
+     device = torch.device("cpu")
+     model_path = MODELS_DIR / "deep_model.pth"
+     if not model_path.exists():
+         raise FileNotFoundError(f"Weights not found at {model_path}. Run python setup.py first.")
+     model = ScribblNet(num_classes=NUM_CLASSES)
+     model.load_state_dict(torch.load(model_path, map_location=device))
+     model.eval()
+     return model, device
+
+
+ MODEL, DEVICE = _load_model()
+
+
+ def predict(sketch: Optional[dict], _counter: int) -> tuple[str, int]:
+     """Run inference on an ImageEditor drawing.
+
+     Args:
+         sketch: Dict from gr.ImageEditor with 'composite' key.
+         _counter: Click counter used to bust Gradio output caching.
+
+     Returns:
+         Tuple of (HTML results string, incremented counter).
+     """
+     _counter += 1
+     if sketch is None:
+         return _empty_state_html(), _counter
+
+     img_array = sketch.get("composite") if isinstance(sketch, dict) else sketch
+     if img_array is None:
+         return _empty_state_html(), _counter
+
+     try:
+         img_pil = Image.fromarray(img_array.astype(np.uint8))
+         if img_pil.mode == "RGBA":
+             white = Image.new("RGBA", img_pil.size, (248, 247, 242, 255))
+             img_pil = Image.alpha_composite(white, img_pil).convert("L")
+         else:
+             img_pil = img_pil.convert("L")
+         img_pil = img_pil.resize((28, 28), Image.LANCZOS)
+         arr = np.array(img_pil, dtype=np.float32)
+         arr = (255.0 - arr) / 255.0
+         tensor = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0).to(DEVICE)
+     except Exception as exc:
+         return _error_html(str(exc)), _counter
+
+     with torch.no_grad():
+         probs = F.softmax(MODEL(tensor), dim=1)[0].cpu().numpy()
+
+     top = [(CLASSES[i], float(probs[i])) for i in np.argsort(probs)[::-1][:5]]
+     return _results_html(top), _counter
+
+
+ def _results_html(top: list[tuple[str, float]]) -> str:
+     best_cls, best_prob = top[0]
+     conf_pct = best_prob * 100
+     label = "CONFIDENT" if best_prob > 0.7 else "LIKELY" if best_prob > 0.4 else "UNSURE"
+     bars = ""
+     for i, (cls, prob) in enumerate(top):
+         pct = prob * 100
+         bars += f"""
+         <div class="bar-row" style="animation-delay:{i*0.08}s">
+             <span class="bar-emoji">{CLASS_EMOJIS.get(cls,'')}</span>
+             <span class="bar-label">{cls.upper()}</span>
+             <div class="bar-track"><div class="bar-fill" style="width:{pct:.1f}%;animation-delay:{i*0.08+0.1}s"></div></div>
+             <span class="bar-pct">{pct:.1f}%</span>
+         </div>"""
+     return f"""
+     <div class="results-panel fade-in">
+         <div class="result-tag">[ PREDICTION ]</div>
+         <div class="top-result">
+             <span class="top-emoji">{CLASS_EMOJIS.get(best_cls,'?')}</span>
+             <div class="top-text">
+                 <div class="top-label">{best_cls.upper()}</div>
+                 <div class="top-conf">{conf_pct:.1f}% &nbsp;·&nbsp; {label}</div>
+             </div>
+         </div>
+         <div class="divider"></div>
+         <div class="section-label">TOP 5 PROBABILITIES</div>
+         {bars}
+     </div>"""
+
+
+ def _empty_state_html() -> str:
+     return """
+     <div class="results-panel empty-state">
+         <div class="empty-icon">✏️</div>
+         <div class="empty-title">DRAW SOMETHING</div>
+         <div class="empty-sub">then hit ANALYZE</div>
+         <div class="class-pills">
+             <span class="pill">🐱 cat</span><span class="pill">🐶 dog</span>
+             <span class="pill">🍕 pizza</span><span class="pill">🚲 bicycle</span>
+             <span class="pill">🏠 house</span><span class="pill">☀️ sun</span>
+             <span class="pill">🌳 tree</span><span class="pill">🚗 car</span>
+             <span class="pill">🐟 fish</span><span class="pill">🦋 butterfly</span>
+             <span class="pill">🎸 guitar</span><span class="pill">🍔 hamburger</span>
+             <span class="pill">✈️ airplane</span><span class="pill">🍌 banana</span>
+             <span class="pill">⭐ star</span>
+         </div>
+     </div>"""
+
+
+ def _error_html(msg: str) -> str:
+     return f'<div class="results-panel error-state"><p class="err-msg">⚠ {msg}</p></div>'
+
+
+ CUSTOM_CSS = """
+ @import url('https://fonts.googleapis.com/css2?family=VT323&family=IBM+Plex+Mono:wght@400;500&display=swap');
+
+ :root {
+     --bg: #080808;
+     --surface: #111111;
+     --surface2: #1a1a1a;
+     --border: #2a2a2a;
+     --accent: #b8ff57;
+     --text: #e8e8e0;
+     --text-muted:#888880;
+     --red: #ff5f57;
+     --mono: 'IBM Plex Mono', monospace;
+     --display: 'VT323', monospace;
+ }
+ body, .gradio-container, #root {
+     background: var(--bg) !important;
+     font-family: var(--mono) !important;
+     color: var(--text) !important;
+ }
+ .gradio-container { max-width: 1100px !important; margin: 0 auto !important; }
+ footer { display: none !important; }
+ .block, .gr-box { background: transparent !important; border: none !important; box-shadow: none !important; }
+
+ .app-header { text-align: center; padding: 36px 20px 20px; border-bottom: 1px solid var(--border); margin-bottom: 28px; }
+ .app-title { font-family: var(--display); font-size: 72px; line-height: 1; color: var(--accent); letter-spacing: 6px; text-shadow: 0 0 30px rgba(184,255,87,0.3); margin: 0; }
+ .app-subtitle { font-size: 12px; color: var(--text-muted); letter-spacing: 4px; margin-top: 6px; }
+
+ /* ImageEditor styling */
+ /* Override Gradio's orange accent with our green */
+ .sketch-col { --color-accent: #b8ff57 !important; --color-accent-soft: rgba(184,255,87,0.15) !important; }
+ .sketch-col .image-editor { border: 1px solid var(--border) !important; border-radius: 4px !important; background: var(--surface) !important; }
+ /* Hide color picker and swatch - we only need pen and eraser */
+ .sketch-col [aria-label="Color"],
+ .sketch-col [title="Color"],
+ .sketch-col .image-editor .toolbar > button:nth-child(3),
+ .sketch-col .image-editor .toolbar > button:nth-child(4) { display: none !important; }
+ /* Toolbar background */
+ .sketch-col .image-editor > div { background: var(--surface2) !important; }
+ /* All buttons */
+ .sketch-col .image-editor button {
+     background: var(--surface2) !important;
+     border: 1px solid var(--border) !important;
+     border-radius: 3px !important;
+     margin: 2px !important;
+     color: var(--text) !important;
+ }
+ .sketch-col .image-editor button:hover {
+     background: var(--accent) !important;
+     border-color: var(--accent) !important;
+     color: #000 !important;
+ }
+ /* Active tool */
+ .sketch-col .image-editor button[aria-pressed="true"] {
+     border: 2px solid var(--accent) !important;
+     background: rgba(184,255,87,0.15) !important;
+     color: var(--accent) !important;
+ }
+ /* Force all SVG icons to white/text color */
+ .sketch-col .image-editor svg * { color: inherit !important; stroke: currentColor !important; }
+ .sketch-col [data-testid="layer-wrap"] { display: none !important; }
+ .sketch-col .layers-panel { display: none !important; }
+ /* White canvas */
+ .sketch-col .konvajs-content,
+ .sketch-col .konvajs-content canvas,
+ .sketch-col canvas { background: #f8f7f2 !important; background-color: #f8f7f2 !important; }
+ .sketch-col canvas { cursor: crosshair !important; }
+ .sketch-col * { cursor: auto; }
+ .sketch-col canvas { cursor: crosshair !important; }
+
+ .results-panel { background: var(--surface); border: 1px solid var(--border); border-radius: 4px; padding: 20px; min-height: 420px; font-family: var(--mono); }
+ .result-tag { font-size: 11px; color: var(--accent); letter-spacing: 3px; margin-bottom: 16px; }
+ .top-result { display: flex; align-items: center; gap: 18px; margin-bottom: 18px; }
+ .top-emoji { font-size: 56px; line-height: 1; }
+ .top-label { font-family: var(--display); font-size: 52px; color: var(--text); line-height: 1; letter-spacing: 3px; }
+ .top-conf { font-size: 13px; color: var(--accent); margin-top: 4px; }
+ .divider { height: 1px; background: var(--border); margin: 16px 0; }
+ .section-label { font-size: 10px; color: var(--text-muted); letter-spacing: 3px; margin-bottom: 12px; }
+ .bar-row { display: grid; grid-template-columns: 28px 90px 1fr 50px; align-items: center; gap: 8px; margin-bottom: 10px; opacity: 0; animation: slideIn 0.3s ease forwards; }
+ .bar-emoji { font-size: 16px; text-align: center; }
+ .bar-label { font-size: 11px; color: var(--text-muted); letter-spacing: 1px; }
+ .bar-track { height: 6px; background: var(--surface2); border-radius: 3px; overflow: hidden; }
+ .bar-fill { height: 100%; background: var(--accent); border-radius: 3px; width: 0; animation: barGrow 0.4s ease forwards; }
+ .bar-pct { font-size: 11px; color: var(--text); text-align: right; }
+
+ .empty-state { display: flex; flex-direction: column; align-items: center; justify-content: center; min-height: 360px; }
+ .empty-icon { font-size: 48px; margin-bottom: 12px; }
+ .empty-title { font-family: var(--display); font-size: 36px; color: var(--accent); letter-spacing: 3px; }
+ .empty-sub { font-size: 12px; color: var(--text-muted); margin: 4px 0 24px; letter-spacing: 2px; }
+ .class-pills { display: flex; flex-wrap: wrap; gap: 6px; justify-content: center; max-width: 340px; }
+ .pill { background: var(--surface2); border: 1px solid var(--border); padding: 3px 10px; border-radius: 20px; font-size: 11px; color: var(--text-muted); }
+ .error-state { display: flex; align-items: center; justify-content: center; min-height: 200px; }
+ .err-msg { font-size: 13px; color: var(--red); }
+
+ .analyze-row { padding: 12px 0 0 !important; }
+ .analyze-row button { width: 100% !important; background: rgba(184,255,87,0.06) !important; border: 2px solid var(--accent) !important; color: var(--accent) !important; font-family: var(--mono) !important; font-size: 15px !important; letter-spacing: 4px !important; padding: 14px !important; border-radius: 2px !important; cursor: pointer !important; transition: background 0.15s !important; }
+ .analyze-row button:hover { background: var(--accent) !important; color: #000 !important; }
+
+ .app-footer { text-align: center; padding: 18px; font-size: 11px; color: var(--text-muted); letter-spacing: 1px; border-top: 1px solid var(--border); margin-top: 16px; }
+
+ @keyframes slideIn { from { opacity: 0; transform: translateX(-8px); } to { opacity: 1; transform: translateX(0); } }
+ @keyframes barGrow { from { width: 0; } }
+ .fade-in { animation: fadeIn 0.25s ease; }
+ @keyframes fadeIn { from { opacity: 0; } to { opacity: 1; } }
+ """
+
+
+ def build_app() -> gr.Blocks:
+     """Construct and return the Gradio Blocks application."""
+     with gr.Blocks(css=CUSTOM_CSS, title="ScribblBot") as app:
+
+         gr.HTML("""
+         <div class="app-header">
+             <h1 class="app-title">SCRIBBLBOT</h1>
+             <p class="app-subtitle">NEURAL SKETCH CLASSIFIER · 15 CATEGORIES · QUICK DRAW DATASET</p>
+         </div>
+         """)
+
+         click_counter = gr.State(0)
+
+         with gr.Row():
+             with gr.Column(elem_classes=["sketch-col"]):
+                 sketch_input = gr.ImageEditor(
+                     type="numpy",
+                     image_mode="RGBA",
+                     canvas_size=(480, 480),
+                     layers=False,
+                     sources=[],
+                     brush=gr.Brush(
+                         colors=["#111111"],
+                         default_size=14,
+                         color_mode="fixed",
+                     ),
+                     eraser=gr.Eraser(default_size=20),
+                     show_label=False,
+                 )
+
+             with gr.Column():
+                 result_html = gr.HTML(_empty_state_html())
+
+         with gr.Row(elem_classes=["analyze-row"]):
+             analyze_btn = gr.Button("ANALYZE")
+
+         gr.HTML('<div class="app-footer">ScribblBot · built with Quick Draw · PyTorch · Gradio</div>')
+
+         analyze_btn.click(
+             fn=predict,
+             inputs=[sketch_input, click_counter],
+             outputs=[result_html, click_counter],
+         )
+
+     return app
+
+
+ if __name__ == "__main__":
+     demo = build_app()
+     demo.launch()
config.py ADDED
@@ -0,0 +1,60 @@
+ """
+ Central configuration for ScribblBot.
+ All hyperparameters, paths, and constants live here.
+ """
+
+ from pathlib import Path
+
+ # Paths
+ PROJECT_ROOT = Path(__file__).parent
+ DATA_DIR = PROJECT_ROOT / "data"
+ RAW_DIR = DATA_DIR / "raw"
+ PROCESSED_DIR = DATA_DIR / "processed"
+ OUTPUTS_DIR = DATA_DIR / "outputs"
+ MODELS_DIR = PROJECT_ROOT / "models"
+
+ for _d in [RAW_DIR, PROCESSED_DIR, OUTPUTS_DIR, MODELS_DIR]:
+     _d.mkdir(parents=True, exist_ok=True)
+
+ # Classes
+ # 15 visually distinct Quick Draw categories
+ CLASSES = [
+     "cat", "dog", "pizza", "bicycle", "house",
+     "sun", "tree", "car", "fish", "butterfly",
+     "guitar", "hamburger", "airplane", "banana", "star",
+ ]
+ NUM_CLASSES = len(CLASSES)
+
+ CLASS_EMOJIS = {
+     "cat": "🐱", "dog": "🐶", "pizza": "🍕", "bicycle": "🚲",
+     "house": "🏠", "sun": "☀️", "tree": "🌳", "car": "🚗",
+     "fish": "🐟", "butterfly": "🦋", "guitar": "🎸", "hamburger": "🍔",
+     "airplane": "✈️", "banana": "🍌", "star": "⭐",
+ }
+
+ # Dataset
+ TRAIN_SAMPLES_PER_CLASS = 2000  # keeps training fast (~30k total)
+ TEST_SAMPLES_PER_CLASS = 400  # solid eval set (~6k total)
+ IMG_SIZE = 28  # Quick Draw native resolution
+
+ # Quick Draw public GCS bucket
+ QUICKDRAW_URL = (
+     "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/{cls}.npy"
+ )
+
+ # Deep Model
+ DEEP_BATCH_SIZE = 128
+ DEEP_EPOCHS = 15
+ DEEP_LR = 1e-3
+ DEEP_WEIGHT_DECAY = 1e-4
+
+ # Classical Model
+ RF_N_ESTIMATORS = 200
+ RF_MAX_DEPTH = None
+ HOG_ORIENTATIONS = 9
+ HOG_PIXELS_PER_CELL = (4, 4)
+ HOG_CELLS_PER_BLOCK = (2, 2)
+
+ # Experiment: training set size sensitivity
+ EXPERIMENT_FRACTIONS = [0.1, 0.25, 0.5, 0.75, 1.0]
+ EXPERIMENT_EPOCHS = 10  # shorter runs for the sweep
models/.gitkeep ADDED
File without changes
notebooks/colab_train.ipynb ADDED
@@ -0,0 +1,564 @@
+ {
+  "cells": [
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "# ScribblBot: Colab Training Notebook\n",
+     "\n",
+     "Run every cell top to bottom. At the end your trained model files will be saved to Google Drive.\n",
+     "\n",
+     "**Before running:** go to Runtime > Change runtime type > T4 GPU"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# 1. Install dependencies\n",
+     "!pip install scikit-image seaborn joblib --quiet"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# 2. Mount Google Drive so models persist after the session ends\n",
+     "from google.colab import drive\n",
+     "drive.mount('/content/drive')\n",
+     "\n",
+     "import os\n",
+     "SAVE_DIR = '/content/drive/MyDrive/scribblbot_models'\n",
+     "os.makedirs(SAVE_DIR, exist_ok=True)\n",
+     "print(f'Models will be saved to: {SAVE_DIR}')"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# 3. Config\n",
+     "from pathlib import Path\n",
+     "\n",
+     "CLASSES = [\n",
+     "    'cat', 'dog', 'pizza', 'bicycle', 'house',\n",
+     "    'sun', 'tree', 'car', 'fish', 'butterfly',\n",
+     "    'guitar', 'hamburger', 'airplane', 'banana', 'star',\n",
+     "]\n",
+     "NUM_CLASSES = len(CLASSES)\n",
+     "\n",
+     "TRAIN_SAMPLES_PER_CLASS = 2000\n",
+     "TEST_SAMPLES_PER_CLASS = 400\n",
+     "IMG_SIZE = 28\n",
+     "\n",
+     "DEEP_BATCH_SIZE = 256 # bigger batch since we have GPU\n",
+     "DEEP_EPOCHS = 15\n",
+     "DEEP_LR = 1e-3\n",
+     "DEEP_WEIGHT_DECAY = 1e-4\n",
+     "\n",
+     "RF_N_ESTIMATORS = 200\n",
+     "HOG_ORIENTATIONS = 9\n",
+     "HOG_PIXELS_PER_CELL = (4, 4)\n",
+     "HOG_CELLS_PER_BLOCK = (2, 2)\n",
+     "\n",
+     "EXPERIMENT_FRACTIONS = [0.1, 0.25, 0.5, 0.75, 1.0]\n",
+     "EXPERIMENT_EPOCHS = 10\n",
+     "\n",
+     "RAW_DIR = Path('/content/data/raw')\n",
+     "PROCESSED_DIR = Path('/content/data/processed')\n",
+     "OUTPUTS_DIR = Path('/content/data/outputs')\n",
+     "MODELS_DIR = Path('/content/models')\n",
+     "\n",
+     "for d in [RAW_DIR, PROCESSED_DIR, OUTPUTS_DIR, MODELS_DIR]:\n",
+     "    d.mkdir(parents=True, exist_ok=True)\n",
+     "\n",
+     "QUICKDRAW_URL = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/{cls}.npy'\n",
+     "\n",
+     "print(f'Config ready. {NUM_CLASSES} classes, {TRAIN_SAMPLES_PER_CLASS} train samples each.')"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# 4. Download Quick Draw data\n",
+     "import urllib.request\n",
+     "\n",
+     "def download_class(cls):\n",
+     "    url = QUICKDRAW_URL.format(cls=cls.replace(' ', '%20'))\n",
+     "    dest = RAW_DIR / f'{cls}.npy'\n",
+     "    if dest.exists():\n",
+     "        print(f'  already have {cls}.npy')\n",
+     "        return\n",
+     "    urllib.request.urlretrieve(url, dest)\n",
+     "    print(f'  downloaded {cls}.npy')\n",
+     "\n",
+     "print('Downloading dataset...')\n",
+     "for cls in CLASSES:\n",
+     "    download_class(cls)\n",
+     "print('Done.')"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# 5. Build train/test splits and HOG features\n",
+     "import numpy as np\n",
+     "from skimage.feature import hog\n",
+     "\n",
+     "def load_class_data(cls, n_train, n_test):\n",
+     "    data = np.load(RAW_DIR / f'{cls}.npy', mmap_mode='r')\n",
+     "    rng = np.random.default_rng(seed=42)\n",
+     "    indices = rng.permutation(len(data))[:n_train + n_test]\n",
+     "    data = data[indices]\n",
+     "    return data[:n_train], data[n_train:n_train + n_test]\n",
+     "\n",
+     "def extract_hog(pixel_matrix):\n",
+     "    features = []\n",
+     "    for row in pixel_matrix:\n",
+     "        img = row.reshape(IMG_SIZE, IMG_SIZE)\n",
+     "        desc = hog(img, orientations=HOG_ORIENTATIONS,\n",
+     "                   pixels_per_cell=HOG_PIXELS_PER_CELL,\n",
+     "                   cells_per_block=HOG_CELLS_PER_BLOCK,\n",
+     "                   visualize=False, channel_axis=None)\n",
+     "        features.append(desc)\n",
+     "    return np.array(features, dtype=np.float32)\n",
+     "\n",
+     "print('Loading raw data...')\n",
+     "train_raws, test_raws, train_labels, test_labels = [], [], [], []\n",
+     "for idx, cls in enumerate(CLASSES):\n",
+     "    tr, te = load_class_data(cls, TRAIN_SAMPLES_PER_CLASS, TEST_SAMPLES_PER_CLASS)\n",
+     "    train_raws.append(tr)\n",
+     "    test_raws.append(te)\n",
+     "    train_labels.append(np.full(len(tr), idx, dtype=np.int64))\n",
+     "    test_labels.append(np.full(len(te), idx, dtype=np.int64))\n",
+     "    print(f'  {cls}')\n",
+     "\n",
+     "X_train_raw = np.concatenate(train_raws)\n",
+     "X_test_raw = np.concatenate(test_raws)\n",
+     "y_train = np.concatenate(train_labels)\n",
+     "y_test = np.concatenate(test_labels)\n",
+     "\n",
+     "rng = np.random.default_rng(seed=0)\n",
+     "perm = rng.permutation(len(X_train_raw))\n",
+     "X_train_raw = X_train_raw[perm]\n",
+     "y_train = y_train[perm]\n",
+     "\n",
+     "print('\\nExtracting HOG features (train)...')\n",
+     "X_train_hog = extract_hog(X_train_raw)\n",
+     "print('Extracting HOG features (test)...')\n",
+     "X_test_hog = extract_hog(X_test_raw)\n",
+     "\n",
+     "np.save(PROCESSED_DIR / 'X_train_raw.npy', X_train_raw)\n",
+     "np.save(PROCESSED_DIR / 'X_test_raw.npy', X_test_raw)\n",
+     "np.save(PROCESSED_DIR / 'y_train.npy', y_train)\n",
+     "np.save(PROCESSED_DIR / 'y_test.npy', y_test)\n",
+     "np.save(PROCESSED_DIR / 'X_train_hog.npy', X_train_hog)\n",
+     "np.save(PROCESSED_DIR / 'X_test_hog.npy', X_test_hog)\n",
+     "\n",
+     "print(f'\\nTrain: {X_train_raw.shape}, Test: {X_test_raw.shape}, HOG features: {X_train_hog.shape[1]}')"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# 6. Naive baseline\n",
+     "from sklearn.metrics import accuracy_score\n",
+     "import joblib\n",
+     "\n",
+     "majority_class = int(np.bincount(y_train).argmax())\n",
+     "naive_preds = np.full(len(y_test), majority_class, dtype=np.int64)\n",
+     "naive_acc = accuracy_score(y_test, naive_preds)\n",
+     "\n",
+     "joblib.dump({'majority_class': majority_class, 'accuracy': naive_acc}, MODELS_DIR / 'naive_model.pkl')\n",
+     "print(f'Naive baseline accuracy: {naive_acc:.4f}')\n",
+     "print(f'Majority class: {CLASSES[majority_class]}')"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# 7. Classical ML: Random Forest on HOG features\n",
+     "from sklearn.ensemble import RandomForestClassifier\n",
+     "from sklearn.preprocessing import StandardScaler\n",
+     "from sklearn.metrics import classification_report, confusion_matrix\n",
+     "import time\n",
+     "\n",
+     "scaler = StandardScaler()\n",
+     "X_tr_scaled = scaler.fit_transform(X_train_hog)\n",
+     "X_te_scaled = scaler.transform(X_test_hog)\n",
+     "\n",
+     "rf = RandomForestClassifier(n_estimators=RF_N_ESTIMATORS, n_jobs=-1, random_state=42)\n",
+     "t0 = time.time()\n",
+     "rf.fit(X_tr_scaled, y_train)\n",
+     "elapsed = time.time() - t0\n",
+     "\n",
+     "rf_preds = rf.predict(X_te_scaled)\n",
+     "rf_acc = accuracy_score(y_test, rf_preds)\n",
+     "\n",
+     "joblib.dump({'clf': rf, 'scaler': scaler}, MODELS_DIR / 'classical_model.pkl')\n",
+     "\n",
+     "print(f'Random Forest trained in {elapsed:.1f}s')\n",
+     "print(f'Test accuracy: {rf_acc:.4f}')\n",
+     "print()\n",
+     "print(classification_report(y_test, rf_preds, target_names=CLASSES))"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# 8. Define ScribblNet\n",
+     "import torch\n",
+     "import torch.nn as nn\n",
+     "import torch.nn.functional as F\n",
+     "from torch.utils.data import DataLoader, TensorDataset\n",
+     "\n",
+     "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+     "print(f'Using device: {device}')\n",
+     "\n",
+     "class ScribblNet(nn.Module):\n",
+     "    def __init__(self, num_classes=NUM_CLASSES):\n",
+     "        super().__init__()\n",
+     "        self.features = nn.Sequential(\n",
+     "            nn.Conv2d(1, 32, kernel_size=3, padding=1),\n",
+     "            nn.BatchNorm2d(32),\n",
+     "            nn.ReLU(inplace=True),\n",
+     "            nn.MaxPool2d(2),\n",
+     "\n",
+     "            nn.Conv2d(32, 64, kernel_size=3, padding=1),\n",
+     "            nn.BatchNorm2d(64),\n",
+     "            nn.ReLU(inplace=True),\n",
+     "            nn.MaxPool2d(2),\n",
+     "\n",
+     "            nn.Conv2d(64, 128, kernel_size=3, padding=1),\n",
+     "            nn.BatchNorm2d(128),\n",
+     "            nn.ReLU(inplace=True),\n",
+     "            nn.MaxPool2d(2),\n",
+     "        )\n",
+     "        self.classifier = nn.Sequential(\n",
+     "            nn.Dropout(0.5),\n",
+     "            nn.Linear(128 * 3 * 3, 256),\n",
+     "            nn.ReLU(inplace=True),\n",
+     "            nn.Dropout(0.3),\n",
+     "            nn.Linear(256, num_classes),\n",
+     "        )\n",
+     "\n",
+     "    def forward(self, x):\n",
+     "        x = self.features(x)\n",
+     "        x = x.view(x.size(0), -1)\n",
+     "        return self.classifier(x)\n",
+     "\n",
+     "print('ScribblNet defined.')"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# 9. Train ScribblNet\n",
+     "def make_loaders(X_raw, y, X_test, y_test, batch_size, fraction=1.0):\n",
+     "    if fraction < 1.0:\n",
+     "        n = max(1, int(len(X_raw) * fraction))\n",
+     "        idx = np.random.default_rng(seed=7).permutation(len(X_raw))[:n]\n",
+     "        X_raw = X_raw[idx]\n",
+     "        y = y[idx]\n",
+     "    def to_ds(X, labels):\n",
+     "        imgs = torch.from_numpy(X.astype(np.float32) / 255.0).view(-1, 1, IMG_SIZE, IMG_SIZE)\n",
+     "        return TensorDataset(imgs, torch.from_numpy(labels))\n",
+     "    train_loader = DataLoader(to_ds(X_raw, y), batch_size=batch_size, shuffle=True, num_workers=2)\n",
+     "    test_loader = DataLoader(to_ds(X_test, y_test), batch_size=batch_size, shuffle=False, num_workers=2)\n",
+     "    return train_loader, test_loader\n",
+     "\n",
+     "def evaluate_model(model, loader):\n",
+     "    model.eval()\n",
+     "    all_preds, all_labels = [], []\n",
+     "    with torch.no_grad():\n",
+     "        for imgs, labels in loader:\n",
+     "            preds = model(imgs.to(device)).argmax(dim=1).cpu().numpy()\n",
+     "            all_preds.append(preds)\n",
+     "            all_labels.append(labels.numpy())\n",
+     "    preds = np.concatenate(all_preds)\n",
+     "    labels = np.concatenate(all_labels)\n",
+     "    return accuracy_score(labels, preds), preds\n",
+     "\n",
+ "train_loader, test_loader = make_loaders(X_train_raw, y_train, X_test_raw, y_test, DEEP_BATCH_SIZE)\n",
307
+ "\n",
308
+ "model = ScribblNet().to(device)\n",
309
+ "optimizer = torch.optim.Adam(model.parameters(), lr=DEEP_LR, weight_decay=DEEP_WEIGHT_DECAY)\n",
310
+ "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=DEEP_EPOCHS)\n",
311
+ "criterion = nn.CrossEntropyLoss()\n",
312
+ "\n",
313
+ "best_acc = 0.0\n",
314
+ "history = {'loss': [], 'val_acc': []}\n",
315
+ "\n",
316
+ "print(f'Training for {DEEP_EPOCHS} epochs...')\n",
317
+ "for epoch in range(1, DEEP_EPOCHS + 1):\n",
318
+ " model.train()\n",
319
+ " total_loss = 0.0\n",
320
+ " for imgs, labels in train_loader:\n",
321
+ " imgs, labels = imgs.to(device), labels.to(device)\n",
322
+ " optimizer.zero_grad()\n",
323
+ " loss = criterion(model(imgs), labels)\n",
324
+ " loss.backward()\n",
325
+ " optimizer.step()\n",
326
+ " total_loss += loss.item()\n",
327
+ " avg_loss = total_loss / len(train_loader)\n",
328
+ " val_acc, _ = evaluate_model(model, test_loader)\n",
329
+ " scheduler.step()\n",
330
+ " history['loss'].append(avg_loss)\n",
331
+ " history['val_acc'].append(val_acc)\n",
332
+ " print(f' epoch {epoch:02d}/{DEEP_EPOCHS} loss={avg_loss:.4f} val_acc={val_acc:.4f}')\n",
333
+ " if val_acc > best_acc:\n",
334
+ " best_acc = val_acc\n",
335
+ " torch.save(model.state_dict(), MODELS_DIR / 'deep_model.pth')\n",
336
+ "\n",
337
+ "print(f'\\nBest test accuracy: {best_acc:.4f}')"
338
+ ]
339
+ },
340
+ {
341
+ "cell_type": "code",
342
+ "execution_count": null,
343
+ "metadata": {},
344
+ "outputs": [],
345
+ "source": [
346
+ "# 10. Training curves\n",
347
+ "import matplotlib.pyplot as plt\n",
348
+ "\n",
349
+ "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4))\n",
350
+ "\n",
351
+ "ax1.plot(range(1, len(history['loss']) + 1), history['loss'],\n",
352
+ " color='steelblue', marker='o', linestyle='solid', markersize=5)\n",
353
+ "ax1.set_xlabel('Epoch')\n",
354
+ "ax1.set_ylabel('Training Loss')\n",
355
+ "ax1.set_title('ScribblNet Training Loss')\n",
356
+ "ax1.grid(True, alpha=0.3)\n",
357
+ "\n",
358
+ "ax2.plot(range(1, len(history['val_acc']) + 1), history['val_acc'],\n",
359
+ " color='seagreen', marker='o', linestyle='solid', markersize=5)\n",
360
+ "ax2.set_xlabel('Epoch')\n",
361
+ "ax2.set_ylabel('Test Accuracy')\n",
361
+ "ax2.set_title('ScribblNet Test Accuracy')\n",
363
+ "ax2.grid(True, alpha=0.3)\n",
364
+ "\n",
365
+ "plt.tight_layout()\n",
366
+ "plt.savefig(OUTPUTS_DIR / 'deep_training_curves.png', dpi=150)\n",
367
+ "plt.show()"
368
+ ]
369
+ },
370
+ {
371
+ "cell_type": "code",
372
+ "execution_count": null,
373
+ "metadata": {},
374
+ "outputs": [],
375
+ "source": [
376
+ "# 11. Confusion matrices\n",
377
+ "import seaborn as sns\n",
378
+ "\n",
379
+ "model.load_state_dict(torch.load(MODELS_DIR / 'deep_model.pth', map_location=device))\n",
380
+ "_, deep_preds = evaluate_model(model, test_loader)\n",
381
+ "\n",
382
+ "def plot_cm(y_true, y_pred, title, filename):\n",
383
+ " cm = confusion_matrix(y_true, y_pred, normalize='true')\n",
384
+ " fig, ax = plt.subplots(figsize=(10, 8))\n",
385
+ " sns.heatmap(cm, annot=True, fmt='.2f', xticklabels=CLASSES,\n",
386
+ " yticklabels=CLASSES, cmap='Blues', ax=ax, linewidths=0.5)\n",
387
+ " ax.set_xlabel('Predicted')\n",
388
+ " ax.set_ylabel('True')\n",
389
+ " ax.set_title(title)\n",
390
+ " plt.xticks(rotation=45, ha='right')\n",
391
+ " plt.tight_layout()\n",
392
+ " plt.savefig(OUTPUTS_DIR / filename, dpi=150)\n",
393
+ " plt.show()\n",
394
+ "\n",
395
+ "plot_cm(y_test, deep_preds, 'ScribblNet Confusion Matrix', 'deep_confusion_matrix.png')\n",
396
+ "plot_cm(y_test, rf_preds, 'Random Forest Confusion Matrix', 'classical_confusion_matrix.png')"
397
+ ]
398
+ },
399
+ {
400
+ "cell_type": "code",
401
+ "execution_count": null,
402
+ "metadata": {},
403
+ "outputs": [],
404
+ "source": [
405
+ "# 12. Model comparison bar chart\n",
406
+ "names = ['Naive Baseline', 'Random Forest', 'ScribblNet']\n",
407
+ "accs = [naive_acc, rf_acc, best_acc]\n",
408
+ "\n",
409
+ "fig, ax = plt.subplots(figsize=(7, 4))\n",
410
+ "bars = ax.bar(names, accs, color=['#94a3b8', '#60a5fa', '#34d399'], width=0.5)\n",
411
+ "ax.set_ylim(0, 1)\n",
412
+ "ax.set_ylabel('Test Accuracy')\n",
413
+ "ax.set_title('Model Comparison')\n",
414
+ "for bar, acc in zip(bars, accs):\n",
415
+ " ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,\n",
416
+ " f'{acc:.3f}', ha='center', fontsize=12)\n",
417
+ "ax.grid(True, axis='y', alpha=0.3)\n",
418
+ "plt.tight_layout()\n",
419
+ "plt.savefig(OUTPUTS_DIR / 'model_comparison.png', dpi=150)\n",
420
+ "plt.show()"
421
+ ]
422
+ },
423
+ {
424
+ "cell_type": "code",
425
+ "execution_count": null,
426
+ "metadata": {},
427
+ "outputs": [],
428
+ "source": [
429
+ "# 13. Experiment: training set size sensitivity\n",
430
+ "# Sweeps over fractions of training data for both models.\n",
431
+ "# Shows how data volume affects each approach.\n",
432
+ "print('Running experiment: training size sensitivity sweep...')\n",
433
+ "\n",
434
+ "deep_accs_exp = []\n",
435
+ "rf_accs_exp = []\n",
436
+ "n_samples_list = []\n",
437
+ "\n",
438
+ "for frac in EXPERIMENT_FRACTIONS:\n",
439
+ " n = int(len(X_train_raw) * frac)\n",
440
+ " n_samples_list.append(n)\n",
441
+ " print(f'\\nFraction {frac:.0%} (n={n})')\n",
442
+ "\n",
443
+ " # Deep model\n",
444
+ " tr_loader, te_loader = make_loaders(X_train_raw, y_train, X_test_raw, y_test,\n",
445
+ " DEEP_BATCH_SIZE, fraction=frac)\n",
446
+ " exp_model = ScribblNet().to(device)\n",
447
+ " exp_opt = torch.optim.Adam(exp_model.parameters(), lr=DEEP_LR, weight_decay=DEEP_WEIGHT_DECAY)\n",
448
+ " exp_sched = torch.optim.lr_scheduler.CosineAnnealingLR(exp_opt, T_max=EXPERIMENT_EPOCHS)\n",
449
+ " exp_model.train()\n",
450
+ " for ep in range(EXPERIMENT_EPOCHS):\n",
451
+ " for imgs, labels in tr_loader:\n",
452
+ " imgs, labels = imgs.to(device), labels.to(device)\n",
453
+ " exp_opt.zero_grad()\n",
454
+ " criterion(exp_model(imgs), labels).backward()\n",
455
+ " exp_opt.step()\n",
456
+ " exp_sched.step()\n",
457
+ " acc_deep, _ = evaluate_model(exp_model, te_loader)\n",
458
+ " deep_accs_exp.append(acc_deep)\n",
459
+ " print(f' ScribblNet acc: {acc_deep:.4f}')\n",
460
+ "\n",
461
+ " # Random Forest\n",
462
+ " idx = np.random.default_rng(seed=42).permutation(len(X_train_hog))[:n]\n",
463
+ " sc = StandardScaler()\n",
464
+ " X_tr_exp = sc.fit_transform(X_train_hog[idx])\n",
465
+ " X_te_exp = sc.transform(X_test_hog)\n",
466
+ " rf_exp = RandomForestClassifier(n_estimators=100, n_jobs=-1, random_state=42)\n",
467
+ " rf_exp.fit(X_tr_exp, y_train[idx])\n",
468
+ " acc_rf = accuracy_score(y_test, rf_exp.predict(X_te_exp))\n",
469
+ " rf_accs_exp.append(acc_rf)\n",
470
+ " print(f' Random Forest acc: {acc_rf:.4f}')\n",
471
+ "\n",
472
+ "# Plot\n",
473
+ "fig, ax = plt.subplots(figsize=(8, 5))\n",
474
+ "ax.plot(n_samples_list, deep_accs_exp, marker='o', linestyle='solid',\n",
475
+ " label='ScribblNet (CNN)', linewidth=2, markersize=7)\n",
476
+ "ax.plot(n_samples_list, rf_accs_exp, marker='s', linestyle='dashed',\n",
477
+ " label='Random Forest (HOG)', linewidth=2, markersize=7)\n",
478
+ "ax.set_xlabel('Training samples')\n",
479
+ "ax.set_ylabel('Test accuracy')\n",
480
+ "ax.set_title('Training Set Size Sensitivity')\n",
481
+ "ax.legend()\n",
482
+ "ax.grid(True, alpha=0.3)\n",
483
+ "ax.set_ylim(0, 1)\n",
484
+ "plt.tight_layout()\n",
485
+ "plt.savefig(OUTPUTS_DIR / 'experiment_sensitivity.png', dpi=150)\n",
486
+ "plt.show()\n",
487
+ "\n",
488
+ "import json\n",
489
+ "with open(OUTPUTS_DIR / 'experiment_results.json', 'w') as f:\n",
490
+ " json.dump({'fractions': EXPERIMENT_FRACTIONS, 'n_samples': n_samples_list,\n",
491
+ " 'deep_accs': deep_accs_exp, 'rf_accs': rf_accs_exp}, f, indent=2)\n",
492
+ "\n",
493
+ "print('Experiment complete.')"
494
+ ]
495
+ },
496
+ {
497
+ "cell_type": "code",
498
+ "execution_count": null,
499
+ "metadata": {},
500
+ "outputs": [],
501
+ "source": [
502
+ "# 14. Save everything to Google Drive\n",
503
+ "import shutil\n",
504
+ "\n",
505
+ "files_to_save = [\n",
506
+ " (MODELS_DIR / 'deep_model.pth', 'deep_model.pth'),\n",
507
+ " (MODELS_DIR / 'classical_model.pkl', 'classical_model.pkl'),\n",
508
+ " (MODELS_DIR / 'naive_model.pkl', 'naive_model.pkl'),\n",
509
+ " (OUTPUTS_DIR / 'deep_training_curves.png', 'deep_training_curves.png'),\n",
510
+ " (OUTPUTS_DIR / 'deep_confusion_matrix.png', 'deep_confusion_matrix.png'),\n",
511
+ " (OUTPUTS_DIR / 'classical_confusion_matrix.png', 'classical_confusion_matrix.png'),\n",
512
+ " (OUTPUTS_DIR / 'model_comparison.png', 'model_comparison.png'),\n",
513
+ " (OUTPUTS_DIR / 'experiment_sensitivity.png', 'experiment_sensitivity.png'),\n",
514
+ " (OUTPUTS_DIR / 'experiment_results.json', 'experiment_results.json'),\n",
515
+ "]\n",
516
+ "\n",
517
+ "for src, name in files_to_save:\n",
518
+ " dest = Path(SAVE_DIR) / name\n",
519
+ " shutil.copy(src, dest)\n",
520
+ " print(f'Saved {name}')\n",
521
+ "\n",
522
+ "print(f'\\nAll files saved to Google Drive at: {SAVE_DIR}')\n",
523
+ "print('Download deep_model.pth and put it in your local models/ folder.')"
524
+ ]
525
+ },
526
+ {
527
+ "cell_type": "code",
528
+ "execution_count": null,
529
+ "metadata": {},
530
+ "outputs": [],
531
+ "source": [
532
+ "# 15. Results summary\n",
533
+ "print('Results Summary')\n",
534
+ "print(f' Naive baseline: {naive_acc:.4f}')\n",
535
+ "print(f' Random Forest: {rf_acc:.4f}')\n",
536
+ "print(f' ScribblNet: {best_acc:.4f}')\n",
537
+ "\n",
538
+ "with open(OUTPUTS_DIR / 'results_summary.json', 'w') as f:\n",
539
+ " json.dump({'naive_accuracy': naive_acc, 'classical_accuracy': rf_acc,\n",
540
+ " 'deep_accuracy': best_acc}, f, indent=2)\n",
541
+ "shutil.copy(OUTPUTS_DIR / 'results_summary.json', Path(SAVE_DIR) / 'results_summary.json')\n",
542
+ "print('results_summary.json saved to Drive.')"
543
+ ]
544
+ }
545
+ ],
546
+ "metadata": {
547
+ "accelerator": "GPU",
548
+ "colab": {
549
+ "gpuType": "T4",
550
+ "provenance": []
551
+ },
552
+ "kernelspec": {
553
+ "display_name": "Python 3",
554
+ "language": "python",
555
+ "name": "python3"
556
+ },
557
+ "language_info": {
558
+ "name": "python",
559
+ "version": "3.10.0"
560
+ }
561
+ },
562
+ "nbformat": 4,
563
+ "nbformat_minor": 5
564
+ }
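A note on the `nn.Linear(128 * 3 * 3, 256)` layer in ScribblNet: its input size depends on the spatial size left after the three conv/pool stages. The arithmetic can be sketched as below, assuming 28×28 inputs and the usual output-size formulas for a convolution and for max pooling with default stride and no dilation. The helper names (`conv_out`, `pool_out`, `flattened_features`) are hypothetical and not part of the repository.

```python
def conv_out(size: int, kernel: int = 3, padding: int = 1, stride: int = 1) -> int:
    """Spatial size after a convolution (dilation 1, floor division)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size: int, kernel: int = 2) -> int:
    """Spatial size after max pooling with stride == kernel and no padding."""
    return (size - kernel) // kernel + 1


def flattened_features(img_size: int = 28, channels=(32, 64, 128)) -> int:
    """Length of the flattened feature map fed to the first Linear layer."""
    size = img_size
    for _ in channels:
        # A 3x3 conv with padding=1 keeps the size; a 2x2 pool halves it (floored).
        size = pool_out(conv_out(size))
    return channels[-1] * size * size


print(flattened_features())  # prints 1152, i.e. 128 * 3 * 3
```

If the input size ever changes, the Linear layer has to change with it: under the same assumptions, `flattened_features(32)` gives 128 * 4 * 4 = 2048.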
notebooks/exploration.ipynb ADDED
@@ -0,0 +1,43 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# ScribblBot — Data Exploration"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "sys.path.insert(0, '..')\n",
+ "\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "from config import CLASSES, RAW_DIR\n",
+ "\n",
+ "cls = 'cat'\n",
+ "data = np.load(RAW_DIR / f'{cls}.npy')\n",
+ "print(f'{cls}: {data.shape}')\n",
+ "\n",
+ "fig, axes = plt.subplots(2, 8, figsize=(16, 4))\n",
+ "for i, ax in enumerate(axes.flat):\n",
+ "    ax.imshow(data[i].reshape(28, 28), cmap='gray_r')\n",
+ "    ax.axis('off')\n",
+ "plt.suptitle(cls)\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" },
+ "language_info": { "name": "python", "version": "3.10.0" }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ torch>=2.1.0
+ torchvision>=0.16.0
+ scikit-learn>=1.3.0
+ scikit-image>=0.21.0
+ numpy>=1.24.0
+ pandas>=2.0.0
+ matplotlib>=3.7.0
+ seaborn>=0.12.0
+ Pillow>=9.5.0
+ tqdm>=4.65.0
+ joblib>=1.3.0
+ gradio>=5.0.0
+ huggingface_hub>=0.20
+ jinja2>=3.1.4
+ requests>=2.31.0
scripts/__init__.py ADDED
@@ -0,0 +1 @@
+ # scripts package
scripts/build_features.py ADDED
@@ -0,0 +1,145 @@
+ """
+ build_features.py – Load raw Quick Draw .npy files, split into train/test,
+ and extract HOG features for the classical ML pipeline.
+
+ Saved artefacts (under data/processed/):
+     X_train_raw.npy, y_train.npy -> pixel arrays for deep model
+     X_test_raw.npy, y_test.npy -> pixel arrays for evaluation
+     X_train_hog.npy -> HOG feature matrix for Random Forest
+     X_test_hog.npy
+
+ Usage:
+     python scripts/build_features.py
+ """
+
+ import sys
+ from pathlib import Path
+
+ import numpy as np
+ from skimage.feature import hog
+
+ sys.path.insert(0, str(Path(__file__).parent.parent))
+
+ from config import (
+     CLASSES,
+     RAW_DIR,
+     PROCESSED_DIR,
+     TRAIN_SAMPLES_PER_CLASS,
+     TEST_SAMPLES_PER_CLASS,
+     IMG_SIZE,
+     HOG_ORIENTATIONS,
+     HOG_PIXELS_PER_CELL,
+     HOG_CELLS_PER_BLOCK,
+ )
+
+
+ def load_class_data(cls: str, n_train: int, n_test: int) -> tuple[np.ndarray, np.ndarray]:
+     """Load and slice pixel data for a single class.
+
+     The Quick Draw .npy files contain rows of 784-element uint8 vectors
+     (28×28 flattened, pixel values 0–255, white stroke on black background).
+
+     Args:
+         cls: Class name.
+         n_train: Number of training samples to keep.
+         n_test: Number of test samples to keep.
+
+     Returns:
+         Tuple of (train_pixels, test_pixels) each shaped (n, 784).
+     """
+     path = RAW_DIR / f"{cls}.npy"
+     if not path.exists():
+         raise FileNotFoundError(
+             f"Missing {path}. Run scripts/make_dataset.py first."
+         )
+     data = np.load(path, mmap_mode="r") # memory mapped for large files
+
+     # Shuffle deterministically so splits are reproducible
+     rng = np.random.default_rng(seed=42)
+     indices = rng.permutation(len(data))[: n_train + n_test]
+     data = data[indices]
+
+     return data[:n_train], data[n_train : n_train + n_test]
+
+
+ def extract_hog_features(pixel_matrix: np.ndarray) -> np.ndarray:
+     """Compute HOG descriptors for a batch of flat pixel vectors.
+
+     Args:
+         pixel_matrix: Array of shape (N, 784), dtype uint8.
+
+     Returns:
+         Feature matrix of shape (N, D) where D is the HOG descriptor length.
+     """
+     features = []
+     for row in pixel_matrix:
+         img = row.reshape(IMG_SIZE, IMG_SIZE)
+         desc = hog(
+             img,
+             orientations=HOG_ORIENTATIONS,
+             pixels_per_cell=HOG_PIXELS_PER_CELL,
+             cells_per_block=HOG_CELLS_PER_BLOCK,
+             visualize=False,
+             channel_axis=None,
+         )
+         features.append(desc)
+     return np.array(features, dtype=np.float32)
+
+
+ def build_splits() -> None:
+     """Assemble train/test raw arrays and labels from all classes."""
+     train_raws, test_raws = [], []
+     train_labels, test_labels = [], []
+
+     print("Loading raw data …")
+     for label_idx, cls in enumerate(CLASSES):
+         print(f" {cls} ({label_idx + 1}/{len(CLASSES)})")
+         tr, te = load_class_data(cls, TRAIN_SAMPLES_PER_CLASS, TEST_SAMPLES_PER_CLASS)
+         train_raws.append(tr)
+         test_raws.append(te)
+         train_labels.append(np.full(len(tr), label_idx, dtype=np.int64))
+         test_labels.append(np.full(len(te), label_idx, dtype=np.int64))
+
+     X_train_raw = np.concatenate(train_raws)
+     X_test_raw = np.concatenate(test_raws)
+     y_train = np.concatenate(train_labels)
+     y_test = np.concatenate(test_labels)
+
+     # Shuffle training set
+     rng = np.random.default_rng(seed=0)
+     perm = rng.permutation(len(X_train_raw))
+     X_train_raw = X_train_raw[perm]
+     y_train = y_train[perm]
+
+     PROCESSED_DIR.mkdir(parents=True, exist_ok=True)
+     np.save(PROCESSED_DIR / "X_train_raw.npy", X_train_raw)
+     np.save(PROCESSED_DIR / "X_test_raw.npy", X_test_raw)
+     np.save(PROCESSED_DIR / "y_train.npy", y_train)
+     np.save(PROCESSED_DIR / "y_test.npy", y_test)
+     print(f"\nSaved raw splits → train {X_train_raw.shape}, test {X_test_raw.shape}")
+
+
+ def build_hog_features() -> None:
+     """Extract HOG features from saved raw arrays."""
+     X_train_raw = np.load(PROCESSED_DIR / "X_train_raw.npy")
+     X_test_raw = np.load(PROCESSED_DIR / "X_test_raw.npy")
+
+     print("Extracting HOG features (train) …")
+     X_train_hog = extract_hog_features(X_train_raw)
+     print("Extracting HOG features (test) …")
+     X_test_hog = extract_hog_features(X_test_raw)
+
+     np.save(PROCESSED_DIR / "X_train_hog.npy", X_train_hog)
+     np.save(PROCESSED_DIR / "X_test_hog.npy", X_test_hog)
+     print(f"Saved HOG features → train {X_train_hog.shape}, test {X_test_hog.shape}")
+
+
+ def build_all() -> None:
+     """Run the complete feature building pipeline."""
+     build_splits()
+     build_hog_features()
+     print("\nFeature pipeline complete.")
+
+
+ if __name__ == "__main__":
+     build_all()
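The descriptor length D mentioned in the `extract_hog_features` docstring follows directly from the HOG parameters. A rough sketch of that arithmetic is below, assuming square cells and blocks and a block stride of one cell (which is how scikit-image's `hog` slides its normalisation blocks). The values passed in (9 orientations, 4×4-pixel cells, 2×2-cell blocks) are illustrative guesses, since `config.py` is not shown in this diff, and `hog_length` is a hypothetical helper.

```python
def hog_length(img_size: int, orientations: int,
               pixels_per_cell: int, cells_per_block: int) -> int:
    """Length of the flattened HOG descriptor for a square image."""
    # Number of whole cells along one side of the image.
    n_cells = img_size // pixels_per_cell
    # Blocks slide one cell at a time, so they overlap.
    n_blocks = n_cells - cells_per_block + 1
    # Each block holds cells_per_block^2 histograms of `orientations` bins.
    return n_blocks * n_blocks * cells_per_block * cells_per_block * orientations


# Assumed parameters, not taken from config.py:
print(hog_length(28, orientations=9, pixels_per_cell=4, cells_per_block=2))  # prints 1296
```

Knowing D up front helps estimate the memory footprint of `X_train_hog.npy`: roughly N × D × 4 bytes, since the features are stored as float32.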
scripts/make_dataset.py ADDED
@@ -0,0 +1,65 @@
+ """
+ make_dataset.py – Download Quick Draw .npy files for all configured classes.
+
+ Usage:
+     python scripts/make_dataset.py
+ """
+
+ import sys
+ from pathlib import Path
+ import urllib.request
+
+ sys.path.insert(0, str(Path(__file__).parent.parent))
+
+ from config import CLASSES, RAW_DIR, QUICKDRAW_URL
+
+
+ def download_class(cls: str, dest_dir: Path, force: bool = False) -> Path:
+     """Download the numpy bitmap file for a single Quick Draw class.
+
+     Args:
+         cls: Class name matching a Quick Draw category (e.g. 'cat').
+         dest_dir: Directory to write the .npy file into.
+         force: Redownload even if the file already exists.
+
+     Returns:
+         Path to the downloaded file.
+     """
+     url = QUICKDRAW_URL.format(cls=cls.replace(" ", "%20"))
+     dest = dest_dir / f"{cls}.npy"
+     if dest.exists() and not force:
+         print(f" [skip] {cls}.npy already exists")
+         return dest
+
+     print(f" [down] {cls}.npy -> {url}")
+
+     def _reporthook(block_num: int, block_size: int, total_size: int) -> None:
+         downloaded = block_num * block_size
+         pct = min(100, downloaded * 100 // total_size) if total_size > 0 else 0
+         print(f"\r {pct:3d}%", end="", flush=True)
+
+     urllib.request.urlretrieve(url, dest, reporthook=_reporthook)
+     print()
+     return dest
+
+
+ def download_all(force: bool = False) -> None:
+     """Download .npy files for every class listed in config.CLASSES.
+
+     Args:
+         force: Redownload files that already exist on disk.
+     """
+     RAW_DIR.mkdir(parents=True, exist_ok=True)
+     print(f"Downloading {len(CLASSES)} classes to {RAW_DIR} …\n")
+     for cls in CLASSES:
+         download_class(cls, RAW_DIR, force=force)
+     print("\nAll downloads complete.")
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser(description="Download Quick Draw dataset")
+     parser.add_argument("--force", action="store_true", help="Redownload existing files")
+     args = parser.parse_args()
+     download_all(force=args.force)
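One thing to watch in `download_class`: it writes straight to the final path, so an interrupted download can leave a truncated `.npy` that the `[skip]` check later treats as complete. A possible mitigation, sketched here under assumptions, is to download to a temporary `.part` file and rename it only on success. `download_atomic`, `class_url`, and the example URL template are hypothetical and not part of the repository; `urllib.parse.quote` stands in for the manual `%20` replacement.

```python
from pathlib import Path
from urllib.parse import quote
from urllib.request import urlretrieve

# Hypothetical template; the real one lives in config.QUICKDRAW_URL.
URL_TEMPLATE = "https://example.com/quickdraw/{cls}.npy"


def class_url(cls: str, template: str = URL_TEMPLATE) -> str:
    """Build a download URL, percent-encoding spaces and other unsafe characters."""
    return template.format(cls=quote(cls))


def download_atomic(url: str, dest: Path) -> Path:
    """Download to '<dest>.part' and move it into place only if the transfer succeeds."""
    tmp = dest.with_name(dest.name + ".part")
    try:
        urlretrieve(url, tmp)
        tmp.replace(dest)  # atomic rename on the same filesystem
    finally:
        # Remove a leftover partial file; a no-op after a successful rename.
        tmp.unlink(missing_ok=True)
    return dest
```

With this shape, `dest.exists()` only becomes true once the whole file has arrived, so the existing skip logic stays valid after a failed run.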
scripts/model.py ADDED
@@ -0,0 +1,616 @@
+ """
+ model.py – Define, train, and evaluate all three models:
+     1. Naive baseline (majority class classifier)
+     2. Classical ML (Random Forest on HOG features)
+     3. Deep learning (ScribblNet CNN)
+
+ Also runs the training size sensitivity experiment and saves results/plots.
+
+ Usage:
+     python scripts/model.py
+ """
+
+ import json
+ import sys
+ import time
+ from pathlib import Path
+ from typing import Any
+
+ import joblib
+ import matplotlib
+ matplotlib.use("Agg") # headless backend
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from sklearn.ensemble import RandomForestClassifier
+ from sklearn.metrics import (
+     accuracy_score,
+     classification_report,
+     confusion_matrix,
+ )
+ from sklearn.preprocessing import StandardScaler
+ from torch.utils.data import DataLoader, TensorDataset
+ import seaborn as sns
+
+ sys.path.insert(0, str(Path(__file__).parent.parent))
+
+ from config import (
+     CLASSES,
+     MODELS_DIR,
+     OUTPUTS_DIR,
+     PROCESSED_DIR,
+     NUM_CLASSES,
+     RF_MAX_DEPTH,
+     RF_N_ESTIMATORS,
+     DEEP_BATCH_SIZE,
+     DEEP_EPOCHS,
+     DEEP_LR,
+     DEEP_WEIGHT_DECAY,
+     IMG_SIZE,
+     EXPERIMENT_FRACTIONS,
+     EXPERIMENT_EPOCHS,
+ )
+
+
+ # Utility
+
+ def get_device() -> torch.device:
+     """Return the best available torch device (MPS > CUDA > CPU)."""
+     if torch.backends.mps.is_available():
+         return torch.device("mps")
+     if torch.cuda.is_available():
+         return torch.device("cuda")
+     return torch.device("cpu")
+
+
+ def load_processed_data() -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+     """Load all processed arrays from disk.
+
+     Returns:
+         X_train_raw, X_test_raw, y_train, y_test, X_train_hog, X_test_hog
+     """
+     X_train_raw = np.load(PROCESSED_DIR / "X_train_raw.npy")
+     X_test_raw = np.load(PROCESSED_DIR / "X_test_raw.npy")
+     y_train = np.load(PROCESSED_DIR / "y_train.npy")
+     y_test = np.load(PROCESSED_DIR / "y_test.npy")
+     X_train_hog = np.load(PROCESSED_DIR / "X_train_hog.npy")
+     X_test_hog = np.load(PROCESSED_DIR / "X_test_hog.npy")
+     return X_train_raw, X_test_raw, y_train, y_test, X_train_hog, X_test_hog
+
+
+ # 1. Naive Baseline
+
+ class MajorityClassifier:
+     """Naive baseline: always predicts the most frequent class in training."""
+
+     def __init__(self) -> None:
+         self.majority_class: int = 0
+
+     def fit(self, y: np.ndarray) -> "MajorityClassifier":
+         """Fit by finding the majority class label.
+
+         Args:
+             y: 1-D array of integer class labels.
+
+         Returns:
+             self
+         """
+         counts = np.bincount(y)
+         self.majority_class = int(np.argmax(counts))
+         return self
+
+     def predict(self, n_samples: int) -> np.ndarray:
+         """Return the majority class repeated n_samples times.
+
+         Args:
+             n_samples: Number of predictions to generate.
+
+         Returns:
+             Array of length n_samples, all equal to majority_class.
+         """
+         return np.full(n_samples, self.majority_class, dtype=np.int64)
+
+
+ def train_naive(y_train: np.ndarray, y_test: np.ndarray) -> dict[str, Any]:
+     """Train and evaluate the majority class baseline.
+
+     Args:
+         y_train: Training labels.
+         y_test: Test labels.
+
+     Returns:
+         Dictionary of evaluation metrics.
+     """
+     print("\nNaive Baseline")
+     clf = MajorityClassifier().fit(y_train)
+     preds = clf.predict(len(y_test))
+     acc = accuracy_score(y_test, preds)
+     print(f" Majority class: {CLASSES[clf.majority_class]}")
+     print(f" Test accuracy: {acc:.4f}")
+
+     model_data = {"majority_class": clf.majority_class, "accuracy": acc}
+     joblib.dump(model_data, MODELS_DIR / "naive_model.pkl")
+
+     return {"model": "naive", "accuracy": acc}
+
+
+ # 2. Classical ML
+
+ def train_classical(
+     X_train_hog: np.ndarray,
+     X_test_hog: np.ndarray,
+     y_train: np.ndarray,
+     y_test: np.ndarray,
+ ) -> dict[str, Any]:
+     """Train Random Forest on HOG features and evaluate.
+
+     Args:
+         X_train_hog: Training HOG feature matrix.
+         X_test_hog: Test HOG feature matrix.
+         y_train: Training labels.
+         y_test: Test labels.
+
+     Returns:
+         Dictionary of evaluation metrics.
+     """
+     print("\nClassical ML (Random Forest on HOG)")
+
+     # Standardise features
+     scaler = StandardScaler()
+     X_tr = scaler.fit_transform(X_train_hog)
+     X_te = scaler.transform(X_test_hog)
+
+     clf = RandomForestClassifier(
+         n_estimators=RF_N_ESTIMATORS,
+         max_depth=RF_MAX_DEPTH,
+         n_jobs=-1,
+         random_state=42,
+     )
+     t0 = time.time()
+     clf.fit(X_tr, y_train)
+     elapsed = time.time() - t0
+
+     preds = clf.predict(X_te)
+     acc = accuracy_score(y_test, preds)
+     report = classification_report(y_test, preds, target_names=CLASSES)
+
+     print(f" Training time: {elapsed:.1f}s")
+     print(f" Test accuracy: {acc:.4f}")
+     print(f"\n{report}")
+
+     joblib.dump({"clf": clf, "scaler": scaler}, MODELS_DIR / "classical_model.pkl")
+     _save_confusion_matrix(y_test, preds, "classical_confusion_matrix.png")
+
+     return {"model": "classical", "accuracy": acc, "training_time_s": elapsed}
+
+
+ # 3. Deep Model
+
+ class ScribblNet(nn.Module):
+     """Lightweight CNN for 28×28 grayscale sketch classification.
+
+     Architecture:
+         3 × (Conv2d → BatchNorm → ReLU → MaxPool)
+         Dropout → FC(1152→256) → ReLU → Dropout → FC(256→num_classes)
+     """
+
+     def __init__(self, num_classes: int = NUM_CLASSES) -> None:
+         super().__init__()
+         self.features = nn.Sequential(
+             nn.Conv2d(1, 32, kernel_size=3, padding=1),
+             nn.BatchNorm2d(32),
+             nn.ReLU(inplace=True),
+             nn.MaxPool2d(2),
+
+             nn.Conv2d(32, 64, kernel_size=3, padding=1),
+             nn.BatchNorm2d(64),
+             nn.ReLU(inplace=True),
+             nn.MaxPool2d(2),
+
+             nn.Conv2d(64, 128, kernel_size=3, padding=1),
+             nn.BatchNorm2d(128),
+             nn.ReLU(inplace=True),
+             nn.MaxPool2d(2),
+         )
+         # 28→14→7→3 ∴ feature map is 128×3×3 = 1152
+         self.classifier = nn.Sequential(
+             nn.Dropout(0.5),
+             nn.Linear(128 * 3 * 3, 256),
+             nn.ReLU(inplace=True),
+             nn.Dropout(0.3),
+             nn.Linear(256, num_classes),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass.
+
+         Args:
+             x: Tensor of shape (B, 1, 28, 28), values in [0, 1].
+
+         Returns:
+             Logits tensor of shape (B, num_classes).
+         """
+         x = self.features(x)
+         x = x.view(x.size(0), -1)
+         return self.classifier(x)
+
+
+ def make_dataloaders(
+     X_raw: np.ndarray,
+     y: np.ndarray,
+     X_test_raw: np.ndarray,
+     y_test: np.ndarray,
+     batch_size: int = DEEP_BATCH_SIZE,
+     train_fraction: float = 1.0,
+ ) -> tuple[DataLoader, DataLoader]:
+     """Build PyTorch DataLoaders from raw pixel arrays.
+
+     Pixel values are normalised to [0, 1]. Training set can be subsampled
+     via train_fraction for the sensitivity experiment.
+
+     Args:
+         X_raw: Training pixel array (N, 784), uint8.
+         y: Training labels.
+         X_test_raw: Test pixel array.
+         y_test: Test labels.
+         batch_size: Minibatch size.
+         train_fraction: Fraction of training samples to use (0 < f ≤ 1).
+
+     Returns:
+         (train_loader, test_loader)
+     """
+     if train_fraction < 1.0:
+         n = max(1, int(len(X_raw) * train_fraction))
+         idx = np.random.default_rng(seed=7).permutation(len(X_raw))[:n]
+         X_raw = X_raw[idx]
+         y = y[idx]
+
+     def _to_tensor(X: np.ndarray, labels: np.ndarray) -> TensorDataset:
+         imgs = torch.from_numpy(X.astype(np.float32) / 255.0)
+         imgs = imgs.view(-1, 1, IMG_SIZE, IMG_SIZE)
+         return TensorDataset(imgs, torch.from_numpy(labels))
+
+     train_ds = _to_tensor(X_raw, y)
+     test_ds = _to_tensor(X_test_raw, y_test)
+
+     train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
+     test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0)
+     return train_loader, test_loader
+
+
+ def train_one_epoch(
+     model: nn.Module,
+     loader: DataLoader,
+     optimizer: torch.optim.Optimizer,
+     criterion: nn.Module,
+     device: torch.device,
+ ) -> float:
+     """Run one training epoch and return average loss.
+
+     Args:
+         model: ScribblNet instance.
+         loader: Training DataLoader.
+         optimizer: Optimiser (Adam).
+         criterion: Loss function (CrossEntropyLoss).
+         device: Torch device.
+
+     Returns:
+         Mean loss over all minibatches.
+     """
+     model.train()
+     total_loss = 0.0
+     for imgs, labels in loader:
+         imgs, labels = imgs.to(device), labels.to(device)
+         optimizer.zero_grad()
+         loss = criterion(model(imgs), labels)
+         loss.backward()
+         optimizer.step()
+         total_loss += loss.item()
+     return total_loss / len(loader)
+
+
+ def evaluate(
+     model: nn.Module,
+     loader: DataLoader,
+     device: torch.device,
+ ) -> tuple[float, np.ndarray]:
+     """Evaluate model on a DataLoader.
+
+     Args:
+         model: ScribblNet instance.
+         loader: Evaluation DataLoader.
+         device: Torch device.
+
+     Returns:
+         (accuracy, predictions_array)
+     """
+     model.eval()
+     all_preds, all_labels = [], []
+     with torch.no_grad():
+         for imgs, labels in loader:
+             imgs = imgs.to(device)
+             preds = model(imgs).argmax(dim=1).cpu().numpy()
+             all_preds.append(preds)
+             all_labels.append(labels.numpy())
+     preds = np.concatenate(all_preds)
+     labels = np.concatenate(all_labels)
+     return accuracy_score(labels, preds), preds
339
+
340
+
341
+ def train_deep(
342
+ X_train_raw: np.ndarray,
343
+ X_test_raw: np.ndarray,
344
+ y_train: np.ndarray,
345
+ y_test: np.ndarray,
346
+ epochs: int = DEEP_EPOCHS,
347
+ train_fraction: float = 1.0,
348
+ save_model: bool = True,
349
+ ) -> dict[str, Any]:
350
+ """Train ScribblNet and evaluate on test set.
351
+
352
+ Args:
353
+ X_train_raw: Raw training pixel array.
354
+ X_test_raw: Raw test pixel array.
355
+ y_train: Training labels.
356
+ y_test: Test labels.
357
+ epochs: Number of training epochs.
358
+ train_fraction: Fraction of training data to use.
359
+ save_model: Whether to save weights to disk.
360
+
361
+ Returns:
362
+ Dictionary of evaluation metrics and training history.
363
+ """
364
+ print(f"\nDeep Model (ScribblNet, fraction={train_fraction:.0%})")
365
+ device = get_device()
366
+ print(f" Device: {device}")
367
+
368
+ train_loader, test_loader = make_dataloaders(
369
+ X_train_raw, y_train, X_test_raw, y_test, train_fraction=train_fraction
370
+ )
371
+
372
+ model = ScribblNet(num_classes=NUM_CLASSES).to(device)
373
+ optimizer = torch.optim.Adam(
374
+ model.parameters(), lr=DEEP_LR, weight_decay=DEEP_WEIGHT_DECAY
375
+ )
376
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
377
+ criterion = nn.CrossEntropyLoss()
378
+
379
+ history = {"loss": [], "val_acc": []}
380
+ best_acc = 0.0
381
+
382
+ for epoch in range(1, epochs + 1):
383
+ loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
384
+ acc, _ = evaluate(model, test_loader, device)
385
+ scheduler.step()
386
+ history["loss"].append(loss)
387
+ history["val_acc"].append(acc)
388
+ print(f" epoch {epoch:02d}/{epochs} loss={loss:.4f} val_acc={acc:.4f}")
389
+
390
+ if acc > best_acc:
391
+ best_acc = acc
392
+ if save_model:
393
+ torch.save(model.state_dict(), MODELS_DIR / "deep_model.pth")
394
+
395
+ # Final evaluation with best weights
396
+ if save_model:
397
+ model.load_state_dict(torch.load(MODELS_DIR / "deep_model.pth", map_location=device))
398
+
399
+ final_acc, final_preds = evaluate(model, test_loader, device)
400
+ print(f"\n Best test accuracy: {best_acc:.4f}")
401
+
402
+ if save_model:
403
+ report = classification_report(y_test, final_preds, target_names=CLASSES)
404
+ print(f"\n{report}")
405
+ _save_confusion_matrix(y_test, final_preds, "deep_confusion_matrix.png")
406
+ _save_training_curves(history)
407
+
408
+ return {"model": "deep", "accuracy": best_acc, "history": history}
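
The `CosineAnnealingLR(optimizer, T_max=epochs)` schedule used in `train_deep` follows a closed-form cosine curve when it is stepped once per epoch. A minimal stdlib sketch of that curve follows. The helper name `cosine_annealing_lr` and the values `base_lr = 1e-3` and `epochs = 10` are illustrative assumptions, not values taken from this repo (the real `DEEP_LR` and `DEEP_EPOCHS` are defined elsewhere), and `eta_min` is assumed to be 0, the PyTorch default.

```python
import math


def cosine_annealing_lr(step: int, t_max: int, base_lr: float, eta_min: float = 0.0) -> float:
    """Learning rate after `step` scheduler steps (0 <= step <= t_max)."""
    # eta_min + (base_lr - eta_min) * (1 + cos(pi * step / t_max)) / 2
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


# Hypothetical values, chosen only for illustration.
base_lr = 1e-3
epochs = 10

# schedule[k] is the learning rate after k calls to scheduler.step().
schedule = [cosine_annealing_lr(k, epochs, base_lr) for k in range(epochs + 1)]

for k, lr in enumerate(schedule):
    print(f"after {k:2d} steps: lr={lr:.6f}")
```

Because `scheduler.step()` runs after each epoch, epoch `k` (1-indexed) trains at `schedule[k - 1]`; the final value, `eta_min`, is reached only after the last epoch has finished.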
409
+
410
+
411
+ # Experiment: Training Size Sensitivity
412
+
413
+ def run_experiment(
414
+ X_train_raw: np.ndarray,
415
+ X_test_raw: np.ndarray,
416
+ y_train: np.ndarray,
417
+ y_test: np.ndarray,
418
+ X_train_hog: np.ndarray,
419
+ X_test_hog: np.ndarray,
420
+ ) -> None:
421
+ """Training set size sensitivity analysis.
422
+
423
+ Sweeps over EXPERIMENT_FRACTIONS, training both the deep model and Random
424
+ Forest at each fraction, then plots accuracy vs number of training samples.
425
+
426
+ Motivation: Understanding how each model scales with data volume helps
427
+ justify architectural choices and highlights when more data is beneficial.
428
+
429
+ Args:
430
+ X_train_raw: Raw training pixels.
431
+ X_test_raw: Raw test pixels.
432
+ y_train: Training labels.
433
+ y_test: Test labels.
434
+ X_train_hog: HOG training features.
435
+ X_test_hog: HOG test features.
436
+ """
437
+ print("\nExperiment: Training Size Sensitivity")
438
+ deep_accs, rf_accs, n_samples = [], [], []
439
+ scaler = StandardScaler()
440
+ # The scaler is fit on each training subset below; the test set is only transformed.
441
+
442
+ for frac in EXPERIMENT_FRACTIONS:
443
+ n = int(len(X_train_raw) * frac)
444
+ n_samples.append(n)
445
+ print(f"\n Fraction={frac:.0%} (n={n})")
446
+
447
+ # Deep model
448
+ result = train_deep(
449
+ X_train_raw, X_test_raw, y_train, y_test,
450
+ epochs=EXPERIMENT_EPOCHS, train_fraction=frac, save_model=False,
451
+ )
452
+ deep_accs.append(result["accuracy"])
453
+
454
+ # Random Forest
455
+ idx = np.random.default_rng(seed=42).permutation(len(X_train_hog))[:n]
456
+ X_tr = scaler.fit_transform(X_train_hog[idx])
457
+ rf = RandomForestClassifier(
458
+ n_estimators=100, n_jobs=-1, random_state=42
459
+ )
460
+ rf.fit(X_tr, y_train[idx])
461
+ rf_pred = rf.predict(scaler.transform(X_test_hog))
462
+ rf_accs.append(accuracy_score(y_test, rf_pred))
463
+ print(f" RF acc={rf_accs[-1]:.4f}")
464
+
465
+ # Plot
466
+ fig, ax = plt.subplots(figsize=(8, 5))
467
+ ax.plot(n_samples, deep_accs, marker="o", linestyle="solid", label="ScribblNet (CNN)", linewidth=2, markersize=7)
468
+ ax.plot(n_samples, rf_accs, marker="s", linestyle="dashed", label="Random Forest (HOG)", linewidth=2, markersize=7)
469
+ ax.set_xlabel("Training samples", fontsize=12)
470
+ ax.set_ylabel("Test accuracy", fontsize=12)
471
+ ax.set_title("Training Set Size Sensitivity", fontsize=14)
472
+ ax.legend(fontsize=11)
473
+ ax.grid(True, alpha=0.3)
474
+ ax.set_ylim(0, 1)
475
+ plt.tight_layout()
476
+ out_path = OUTPUTS_DIR / "experiment_sensitivity.png"
477
+ fig.savefig(out_path, dpi=150)
478
+ plt.close(fig)
479
+ print(f"\n Saved experiment plot → {out_path}")
480
+
481
+ results = {
482
+ "fractions": EXPERIMENT_FRACTIONS,
483
+ "n_samples": n_samples,
484
+ "deep_accs": deep_accs,
485
+ "rf_accs": rf_accs,
486
+ }
487
+ with open(OUTPUTS_DIR / "experiment_results.json", "w") as f:
488
+ json.dump(results, f, indent=2)
489
+ print(" Saved experiment_results.json")
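
When features are standardised for a train/test comparison like the Random Forest branch of `run_experiment`, the usual convention is to compute the scaling statistics on the training data only and reuse them for the test set, so both sets go through the same transform. The sketch below shows this for a single feature column using only the standard library. The helper names `fit_standardiser` and `standardise` and the sample values are hypothetical. It mirrors what `StandardScaler.fit` followed by `transform` does per feature (population standard deviation, with a scale of 1 for constant features), but it is not the repo's code.

```python
from statistics import fmean, pstdev


def fit_standardiser(column: list[float]) -> tuple[float, float]:
    """Return (mean, std) estimated from training values only."""
    mean = fmean(column)
    std = pstdev(column) or 1.0  # constant feature: avoid division by zero
    return mean, std


def standardise(column: list[float], mean: float, std: float) -> list[float]:
    """Apply previously fitted statistics to any split."""
    return [(x - mean) / std for x in column]


# Hypothetical single-feature data.
train = [2.0, 4.0, 6.0, 8.0]
test = [10.0, 12.0]

mean, std = fit_standardiser(train)            # statistics come from train only
train_scaled = standardise(train, mean, std)
test_scaled = standardise(test, mean, std)     # test reuses the train statistics

print(train_scaled)
print(test_scaled)
```

With shared statistics, a test value larger than every training value stays larger after scaling, so thresholds learned on the scaled training data keep their meaning on the test set.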
490
+
491
+
492
+ # Plotting Helpers
493
+
494
+ def _save_confusion_matrix(
495
+ y_true: np.ndarray,
496
+ y_pred: np.ndarray,
497
+ filename: str,
498
+ ) -> None:
499
+ """Save a normalised confusion matrix heatmap.
500
+
501
+ Args:
502
+ y_true: Ground truth labels.
503
+ y_pred: Predicted labels.
504
+ filename: Output filename (saved under OUTPUTS_DIR).
505
+ """
506
+ cm = confusion_matrix(y_true, y_pred, normalize="true")
507
+ fig, ax = plt.subplots(figsize=(10, 8))
508
+ sns.heatmap(
509
+ cm,
510
+ annot=True,
511
+ fmt=".2f",
512
+ xticklabels=CLASSES,
513
+ yticklabels=CLASSES,
514
+ cmap="Blues",
515
+ ax=ax,
516
+ linewidths=0.5,
517
+ )
518
+ ax.set_xlabel("Predicted", fontsize=11)
519
+ ax.set_ylabel("True", fontsize=11)
520
+ ax.set_title(filename.replace("_", " ").replace(".png", "").title(), fontsize=13)
521
+ plt.xticks(rotation=45, ha="right")
522
+ plt.tight_layout()
523
+ fig.savefig(OUTPUTS_DIR / filename, dpi=150)
524
+ plt.close(fig)
525
+ print(f" Saved {filename}")
526
+
527
+
528
+ def _save_training_curves(history: dict[str, list[float]]) -> None:
529
+ """Save loss and validation accuracy curves for the deep model.
530
+
531
+ Args:
532
+ history: Dict with keys 'loss' and 'val_acc', each a list of per-epoch values.
533
+ """
534
+ epochs = range(1, len(history["loss"]) + 1)
535
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4))
536
+
537
+ ax1.plot(epochs, history["loss"], color="steelblue", marker="o", linestyle="solid", markersize=5)
538
+ ax1.set_xlabel("Epoch")
539
+ ax1.set_ylabel("Training Loss")
540
+ ax1.set_title("ScribblNet Training Loss")
541
+ ax1.grid(True, alpha=0.3)
542
+
543
+ ax2.plot(epochs, history["val_acc"], color="seagreen", marker="o", linestyle="solid", markersize=5)
544
+ ax2.set_xlabel("Epoch")
545
+ ax2.set_ylabel("Validation Accuracy")
546
+ ax2.set_title("ScribblNet Validation Accuracy")
547
+ ax2.grid(True, alpha=0.3)
548
+
549
+ plt.tight_layout()
550
+ fig.savefig(OUTPUTS_DIR / "deep_training_curves.png", dpi=150)
551
+ plt.close(fig)
552
+ print(" Saved deep_training_curves.png")
553
+
554
+
555
+ def _save_model_comparison(results: list[dict[str, Any]]) -> None:
556
+ """Bar chart comparing test accuracy across all three models.
557
+
558
+ Args:
559
+ results: List of result dicts each containing 'model' and 'accuracy'.
560
+ """
561
+ names = [r["model"].capitalize() for r in results]
562
+ accs = [r["accuracy"] for r in results]
563
+
564
+ fig, ax = plt.subplots(figsize=(7, 4))
565
+ bars = ax.bar(names, accs, color=["#94a3b8", "#60a5fa", "#34d399"], width=0.5)
566
+ ax.set_ylim(0, 1)
567
+ ax.set_ylabel("Test Accuracy")
568
+ ax.set_title("Model Comparison")
569
+ for bar, acc in zip(bars, accs):
570
+ ax.text(
571
+ bar.get_x() + bar.get_width() / 2,
572
+ bar.get_height() + 0.01,
573
+ f"{acc:.3f}",
574
+ ha="center",
575
+ fontsize=12,
576
+ )
577
+ ax.grid(True, axis="y", alpha=0.3)
578
+ plt.tight_layout()
579
+ fig.savefig(OUTPUTS_DIR / "model_comparison.png", dpi=150)
580
+ plt.close(fig)
581
+ print(" Saved model_comparison.png")
582
+
583
+
584
+ # Orchestrator
585
+
586
+ def train_all() -> None:
587
+ """Train all three models, run the experiment, and save all artefacts."""
588
+ X_train_raw, X_test_raw, y_train, y_test, X_train_hog, X_test_hog = (
589
+ load_processed_data()
590
+ )
591
+
592
+ r_naive = train_naive(y_train, y_test)
593
+ r_classical = train_classical(X_train_hog, X_test_hog, y_train, y_test)
594
+ r_deep = train_deep(X_train_raw, X_test_raw, y_train, y_test)
595
+
596
+ _save_model_comparison([r_naive, r_classical, r_deep])
597
+
598
+ run_experiment(
599
+ X_train_raw, X_test_raw, y_train, y_test, X_train_hog, X_test_hog
600
+ )
601
+
602
+ summary = {
603
+ "naive_accuracy": r_naive["accuracy"],
604
+ "classical_accuracy": r_classical["accuracy"],
605
+ "deep_accuracy": r_deep["accuracy"],
606
+ }
607
+ with open(OUTPUTS_DIR / "results_summary.json", "w") as f:
608
+ json.dump(summary, f, indent=2)
609
+
610
+ print("\nTraining complete. Summary:")
611
+ for k, v in summary.items():
612
+ print(f" {k}: {v:.4f}")
613
+
614
+
615
+ if __name__ == "__main__":
616
+ train_all()
setup.py ADDED
@@ -0,0 +1,56 @@
1
+ """
2
+ setup.py – Orchestrates the full ScribblBot pipeline:
3
+ 1. Download Quick Draw data (make_dataset.py)
4
+ 2. Build features (build_features.py)
5
+ 3. Train all models (model.py)
6
+
7
+ Usage:
8
+ python setup.py # run full pipeline
9
+ python setup.py --skip_download # skip if data already downloaded
10
+ """
11
+
12
+ import argparse
13
+ import sys
14
+ from pathlib import Path
15
+
16
+ sys.path.insert(0, str(Path(__file__).parent))
17
+
18
+ from scripts.make_dataset import download_all
19
+ from scripts.build_features import build_all
20
+ from scripts.model import train_all
21
+
22
+
23
+ def run(skip_download: bool = False) -> None:
24
+ """Execute the complete data and training pipeline.
25
+
26
+ Args:
27
+ skip_download: If True, skip the dataset download step.
28
+ Useful when raw .npy files are already present.
29
+ """
30
+ print("ScribblBot setup pipeline starting")
31
+
32
+ if not skip_download:
33
+ print("\n[1/3] Downloading dataset ...")
34
+ download_all()
35
+ else:
36
+ print("\n[1/3] Skipping download (--skip_download)")
37
+
38
+ print("\n[2/3] Building features ...")
39
+ build_all()
40
+
41
+ print("\n[3/3] Training models ...")
42
+ train_all()
43
+
44
+ print("\nSetup complete. Run the app with:")
45
+ print(" python app.py")
46
+
47
+
48
+ if __name__ == "__main__":
49
+ parser = argparse.ArgumentParser(description="ScribblBot full pipeline setup")
50
+ parser.add_argument(
51
+ "--skip_download",
52
+ action="store_true",
53
+ help="Skip dataset download (use if .npy files already exist in data/raw/)",
54
+ )
55
+ args = parser.parse_args()
56
+ run(skip_download=args.skip_download)
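
The `--skip_download` flag above uses `action="store_true"`, so it defaults to `False` and becomes `True` only when the flag is passed. A small sketch of that behaviour, parsing explicit argument lists instead of `sys.argv`, is given below. The `build_parser` helper is a hypothetical wrapper added for illustration and is not part of `setup.py`.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical helper that mirrors the parser defined in setup.py."""
    parser = argparse.ArgumentParser(description="ScribblBot full pipeline setup")
    parser.add_argument(
        "--skip_download",
        action="store_true",
        help="Skip dataset download (use if .npy files already exist in data/raw/)",
    )
    return parser


parser = build_parser()

# No flag given: store_true leaves the attribute at its default of False.
default_args = parser.parse_args([])

# Flag given: the attribute is set to True.
skip_args = parser.parse_args(["--skip_download"])

print(default_args.skip_download, skip_args.skip_download)
```

Passing an explicit list to `parse_args` keeps the example independent of how the script is invoked, which also makes the flag handling easy to check in isolation.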