Rahul-Samedavar commited on
Commit
ff53531
·
1 Parent(s): 84bc6da
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
2
+ # you will also find guides on how best to write your Dockerfile
3
+
4
+ FROM python:3.9
5
+
6
+ RUN useradd -m -u 1000 user
7
+ USER user
8
+ ENV PATH="/home/user/.local/bin:$PATH"
9
+
10
+ WORKDIR /app
11
+
12
+ COPY --chown=user ./requirements.txt requirements.txt
13
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
14
+
15
+ COPY --chown=user . /app
16
+ CMD ["gunicorn", "-b", "0.0.0.0:7860", "app:app"]
Weights/v_20.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2e2874d6534697f6938dcd11696c9ddd52ca8bb1e7b6f3bed49afe4aa921d2a
3
+ size 925328
Weights/v_42.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6bdd4ef8c5593d8af016b1535cf1c91fdca78699dd6c7195d5c1f5a3416d7701
3
+ size 925328
Weights/v_44.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9cc5f9c6ed3359f0be14b099518aa840c7f35e15ce5ab3b9f45e9e434fb8f4ab
3
+ size 925328
Weights/v_45.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0228934153db8266623ad6157ffbd60fb9ee0778a8c13fad19ade62b06bc81f9
3
+ size 925328
Weights/v_46.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b2a62d01d0bdd3c2d8fa398c8f48a7ba71c5408bd946fd7c4a4098424aea333
3
+ size 925328
Weights/v_48.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3324e3d20980ea57e3293b2c8787d920e30032b43a5670109e6d180b63bc3033
3
+ size 925328
Weights/v_50.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f8955b9091e76edb2b63a0e809674b823ffca05dbe5132e5db7c865f43f7368
3
+ size 925328
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, render_template
2
+ import numpy as np
3
+ from tensorflow.keras.models import load_model
4
+
5
+ from tensorflow.keras import Sequential
6
+ from tensorflow.keras.layers import Conv2D, MaxPooling2D, Conv2D, Flatten, Dense, Dropout
7
+ from tensorflow.keras.metrics import Precision, Recall, TopKCategoricalAccuracy
8
+ from tensorflow.keras.optimizers import Adamax
9
+
10
+
11
+
12
+ # Replace this with any version of interest: (Available: 20, 42, 44, 45, 46, 48, 50 )
13
+ version = 50
14
+ WEIGHTS_PATH = f"Weights/v_{version}.h5"
15
+
16
+
17
+ model = Sequential([
18
+ Conv2D(16, (3,3), activation='relu', input_shape=(28, 28, 1)),
19
+ MaxPooling2D(2,2),
20
+ Conv2D(64, (3,3), activation='relu'),
21
+ MaxPooling2D(2,2),
22
+ Flatten(),
23
+ Dropout(0.2),
24
+ Dense(128, activation='relu'),
25
+ Dropout(0.2),
26
+ Dense(64, activation='relu'),
27
+ Dropout(0.2),
28
+ Dense(35, activation='softmax')
29
+ ])
30
+ model.compile(
31
+ optimizer=Adamax(0.001),
32
+ loss='categorical_crossentropy',
33
+ metrics=['accuracy', TopKCategoricalAccuracy(3), Precision(), Recall()]
34
+ )
35
+
36
+ model.load_weights(WEIGHTS_PATH)
37
+
38
+ app = Flask(__name__)
39
+
40
+ classes = ['Airplane', 'Alarm Clock', 'Ant', 'Bear', 'Beard', 'Bird', 'Bus',
41
+ 'Cookie', 'Cow', 'Donut', 'Hand', 'Hat', 'Key', 'Moon',
42
+ 'Motorbike', 'Octagon', 'Pizza', 'Rabbit', 'School Bus', 'Shark',
43
+ 'Skull', 'Smiley Face', 'Snake', 'Spider', 'Square', 'Star', 'Sun',
44
+ 'Swing Set', 'Table', 'Tent', 'Tree', 'Triangle', 'Whale', 'Wheel',
45
+ 'Windmill']
46
+
47
+ def label(pred):
48
+ return {classes[i]: float(pred[0][i]) for i in range(len(classes))}
49
+
50
+ @app.route('/')
51
+ def home():
52
+ return render_template('index.html')
53
+
54
+ @app.route('/classify', methods=['POST'])
55
+ def classify():
56
+ doodle = request.get_json()['doodle']
57
+ doodle = np.array(doodle)
58
+ pred = model.predict(np.expand_dims(doodle, axis=0).astype(np.float16))[0].astype(np.float64)
59
+ return {classes[i]: pred[i] for i in range(35)}
60
+
61
+ if __name__ == '__main__':
62
+ app.run(debug=True)
static/script.js ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ document.addEventListener('DOMContentLoaded', function () {
2
+ const canvas = document.getElementById('drawingCanvas');
3
+ const ctx = canvas.getContext('2d');
4
+ const clearButton = document.getElementById('clearButton');
5
+ const resultsBody = document.getElementById('resultsBody');
6
+
7
+
8
+ const canvasSize = { width: 280, height: 280 };
9
+ canvas.width = canvasSize.width;
10
+ canvas.height = canvasSize.height;
11
+
12
+ const gridSize = 28;
13
+
14
+ let isDrawing = false;
15
+
16
+ ctx.lineJoin = 'round';
17
+ ctx.lineCap = 'round';
18
+ ctx.lineWidth = 15;
19
+ ctx.strokeStyle = 'black';
20
+
21
+ function clearCanvas() {
22
+ ctx.fillStyle = 'white';
23
+ ctx.fillRect(0, 0, canvasSize.width, canvasSize.height);
24
+ resultsBody.innerHTML = `
25
+ <tr class="empty-state">
26
+ <td colspan="2">Draw something to see classification results</td>
27
+ </tr>
28
+ `;
29
+ }
30
+
31
+ clearCanvas();
32
+
33
+ function canvasToArray() {
34
+ const imageData = ctx.getImageData(0, 0, canvasSize.width, canvasSize.height);
35
+ const data = imageData.data;
36
+
37
+ const result = Array(gridSize).fill(0).map(() => Array(gridSize).fill(0));
38
+
39
+ const cellWidth = canvasSize.width / gridSize;
40
+ const cellHeight = canvasSize.height / gridSize;
41
+
42
+ for (let i = 0; i < gridSize; i++) {
43
+ for (let j = 0; j < gridSize; j++) {
44
+ let sum = 0;
45
+ let count = 0;
46
+
47
+ for (let x = Math.floor(j * cellWidth); x < Math.floor((j + 1) * cellWidth); x++) {
48
+ for (let y = Math.floor(i * cellHeight); y < Math.floor((i + 1) * cellHeight); y++) {
49
+ const idx = (y * canvasSize.width + x) * 4;
50
+ const gray = 1 - (data[idx] + data[idx + 1] + data[idx + 2]) / (3 * 255);
51
+ sum += gray;
52
+ count++;
53
+ }
54
+ }
55
+
56
+ result[i][j] = count > 0 ? sum / count : 0;
57
+ }
58
+ }
59
+
60
+ return result;
61
+ }
62
+
63
+ function debounce(func, delay) {
64
+ let timeoutId;
65
+ return function (...args) {
66
+ clearTimeout(timeoutId);
67
+ timeoutId = setTimeout(() => func.apply(this, args), delay);
68
+ };
69
+ }
70
+
71
+ async function classifyDoodle() {
72
+ showLoadingState();
73
+
74
+ const doodleData = canvasToArray();
75
+
76
+ try {
77
+ const response = await fetch('/classify', {
78
+ method: 'POST',
79
+ headers: { 'Content-Type': 'application/json' },
80
+ body: JSON.stringify({ doodle: doodleData }),
81
+ });
82
+
83
+ if (!response.ok) {
84
+ throw new Error(`API error: ${response.status}`);
85
+ }
86
+
87
+ const results = await response.json();
88
+ displayResults(results);
89
+ } catch (error) {
90
+ console.error('Error classifying doodle:', error);
91
+ }
92
+ }
93
+
94
+ const debouncedClassify = debounce(classifyDoodle, 500);
95
+
96
+ function showLoadingState() {
97
+ let loadingRows = '';
98
+ for (let i = 0; i < 5; i++) {
99
+ loadingRows += `
100
+ <tr>
101
+ <td><div class="loading-placeholder">Loading...</div></td>
102
+ <td><div class="loading-placeholder"></div></td>
103
+ </tr>
104
+ `;
105
+ }
106
+ resultsBody.innerHTML = loadingRows;
107
+ }
108
+
109
+ function displayResults(results) {
110
+ const topResults = Object.entries(results)
111
+ .sort((a, b) => b[1] - a[1])
112
+ .slice(0, 5);
113
+
114
+ if (topResults.length === 0) {
115
+ resultsBody.innerHTML = `
116
+ <tr>
117
+ <td colspan="2">No results found</td>
118
+ </tr>
119
+ `;
120
+ return;
121
+ }
122
+
123
+ let tableRows = '';
124
+ topResults.forEach(([category, probability]) => {
125
+ const probabilityPercent = (probability * 100).toFixed(2);
126
+ tableRows += `
127
+ <tr>
128
+ <td>${category}</td>
129
+ <td>
130
+ ${probabilityPercent}%
131
+ <div class="probability-bar">
132
+ <div class="probability-fill" style="width: ${probabilityPercent}%;"></div>
133
+ </div>
134
+ </td>
135
+ </tr>
136
+ `;
137
+ });
138
+
139
+ resultsBody.innerHTML = tableRows;
140
+ }
141
+
142
+ function startDrawing(e) {
143
+ isDrawing = true;
144
+ draw(e);
145
+ }
146
+
147
+ function stopDrawing() {
148
+ if (isDrawing) {
149
+ isDrawing = false;
150
+ ctx.beginPath();
151
+ debouncedClassify();
152
+ }
153
+ }
154
+
155
+ function draw(e) {
156
+ if (!isDrawing) return;
157
+ e.preventDefault();
158
+
159
+ const rect = canvas.getBoundingClientRect();
160
+ const x = (e.clientX || e.touches[0].clientX) - rect.left;
161
+ const y = (e.clientY || e.touches[0].clientY) - rect.top;
162
+
163
+ ctx.lineTo(x, y);
164
+ ctx.stroke();
165
+ ctx.beginPath();
166
+ ctx.moveTo(x, y);
167
+ }
168
+
169
+ canvas.addEventListener('mousedown', startDrawing);
170
+ canvas.addEventListener('mousemove', draw);
171
+ canvas.addEventListener('mouseup', stopDrawing);
172
+ clearButton.addEventListener('click', clearCanvas);
173
+ });
static/style.css ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ body {
2
+ font-family: 'Poppins', sans-serif;
3
+ background-color: #121212;
4
+ color: #ffffff;
5
+ margin: 0;
6
+ padding: 0;
7
+ display: flex;
8
+ flex-direction: column;
9
+ justify-content: center;
10
+ align-items: center;
11
+ height: 100vh;
12
+ }
13
+
14
+ .main-container {
15
+ display: grid;
16
+ grid-template-columns: 1fr 1.5fr;
17
+ gap: 30px;
18
+ width: 80%;
19
+ max-width: 1000px;
20
+ background: rgba(255, 255, 255, 0.1);
21
+ padding: 20px;
22
+ border-radius: 15px;
23
+ box-shadow: 0 8px 16px rgba(0, 255, 255, 0.2);
24
+ backdrop-filter: blur(10px);
25
+ }
26
+
27
+ .left-section {
28
+ display: flex;
29
+ flex-direction: column;
30
+ align-items: center;
31
+ justify-content: center;
32
+ }
33
+
34
+ canvas {
35
+ border-radius: 15px;
36
+ border: 3px solid rgba(0, 255, 255, 0.4);
37
+ background-color: white;
38
+ box-shadow: 0 0 10px rgba(0, 255, 255, 0.5);
39
+ transition: transform 0.3s ease-in-out;
40
+ }
41
+
42
+ canvas:hover {
43
+ transform: scale(1.02);
44
+ box-shadow: 0 0 20px rgba(0, 255, 255, 0.8);
45
+ }
46
+
47
+ h1 {
48
+ font-size: 24px;
49
+ text-transform: uppercase;
50
+ letter-spacing: 2px;
51
+ color: #0ff;
52
+ text-shadow: 0 0 10px #0ff;
53
+ text-align: center;
54
+ }
55
+
56
+ table {
57
+ width: 100%;
58
+ margin-top: 15px;
59
+ border-collapse: collapse;
60
+ }
61
+
62
+ th, td {
63
+ padding: 10px;
64
+ text-align: left;
65
+ border-bottom: 2px solid rgba(255, 255, 255, 0.2);
66
+ }
67
+
68
+ th {
69
+ font-size: 16px;
70
+ text-transform: uppercase;
71
+ letter-spacing: 1px;
72
+ color: #0ff;
73
+ }
74
+
75
+ tr:hover {
76
+ background-color: rgba(255, 255, 255, 0.1);
77
+ }
78
+
79
+ .probability-bar {
80
+ height: 12px;
81
+ background-color: rgba(255, 255, 255, 0.2);
82
+ border-radius: 5px;
83
+ overflow: hidden;
84
+ position: relative;
85
+ }
86
+
87
+ .probability-fill {
88
+ height: 100%;
89
+ background: linear-gradient(90deg, #ff00ff, #0ff);
90
+ transition: width 0.4s ease-in-out;
91
+ }
92
+
93
+ button {
94
+ display: block;
95
+ margin: 20px auto 0;
96
+ padding: 12px 20px;
97
+ font-size: 16px;
98
+ font-weight: bold;
99
+ text-transform: uppercase;
100
+ color: #fff;
101
+ background: linear-gradient(90deg, #ff00ff, #0ff);
102
+ border: none;
103
+ border-radius: 25px;
104
+ cursor: pointer;
105
+ transition: 0.3s ease-in-out;
106
+ box-shadow: 0 4px 10px rgba(255, 0, 255, 0.5);
107
+ }
108
+
109
+ button:hover {
110
+ transform: scale(1.05);
111
+ box-shadow: 0 6px 15px rgba(255, 0, 255, 0.8);
112
+ }
113
+
114
+ button:active {
115
+ transform: scale(0.95);
116
+ }
117
+
118
+ @media (max-width: 768px) {
119
+ .main-container {
120
+ grid-template-columns: 1fr;
121
+ }
122
+ }
templates/index.html ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
6
+ <title>Doodle Classifier</title>
7
+ <link
8
+ rel="stylesheet"
9
+ href="{{ url_for('static', filename='style.css') }}"
10
+ />
11
+
12
+ </head>
13
+ <body>
14
+ <h1>Doodle Classifier</h1>
15
+ <div class="main-container">
16
+ <div class="left-section" >
17
+ <canvas id="drawingCanvas"></canvas>
18
+ <button id="clearButton">Clear</button>
19
+ </div>
20
+ <table >
21
+ <thead>
22
+ <tr>
23
+ <th>Prediction</th>
24
+ <th>Probability</th>
25
+ </tr>
26
+ </thead>
27
+ <tbody id="resultsBody">
28
+ <tr>
29
+ <td colspan="2">Draw something to classify</td>
30
+ </tr>
31
+ </tbody>
32
+ </table>
33
+ </div>
34
+ <script src="{{ url_for('static', filename='script.js') }}" defer></script>
35
+ </body>
36
+ </html>