pytorch initial commit

#1
by BobbyDUVA - opened
keypoint.csv ADDED
The diff for this file is too large to render. See raw diff
 
keypoint_classification_EN_pytorch.ipynb ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {
7
+ "colab": {
8
+ "base_uri": "https://localhost:8080/"
9
+ },
10
+ "id": "ypqky9tc9hE1",
11
+ "outputId": "5db082bb-30e3-4110-bf63-a1ee777ecd46"
12
+ },
13
+ "outputs": [
14
+ {
15
+ "name": "stdout",
16
+ "output_type": "stream",
17
+ "text": [
18
+ "Collecting torch\n",
19
+ " Using cached torch-2.9.1-cp312-cp312-win_amd64.whl.metadata (30 kB)\n",
20
+ "Collecting torchvision\n",
21
+ " Using cached torchvision-0.24.1-cp312-cp312-win_amd64.whl.metadata (5.9 kB)\n",
22
+ "Collecting filelock (from torch)\n",
23
+ " Using cached filelock-3.20.0-py3-none-any.whl.metadata (2.1 kB)\n",
24
+ "Requirement already satisfied: typing-extensions>=4.10.0 in c:\\users\\rfd\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch) (4.15.0)\n",
25
+ "Requirement already satisfied: sympy>=1.13.3 in c:\\users\\rfd\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch) (1.14.0)\n",
26
+ "Collecting networkx>=2.5.1 (from torch)\n",
27
+ " Using cached networkx-3.5-py3-none-any.whl.metadata (6.3 kB)\n",
28
+ "Collecting jinja2 (from torch)\n",
29
+ " Using cached jinja2-3.1.6-py3-none-any.whl.metadata (2.9 kB)\n",
30
+ "Collecting fsspec>=0.8.5 (from torch)\n",
31
+ " Using cached fsspec-2025.10.0-py3-none-any.whl.metadata (10 kB)\n",
32
+ "Collecting setuptools (from torch)\n",
33
+ " Using cached setuptools-80.9.0-py3-none-any.whl.metadata (6.6 kB)\n",
34
+ "Collecting numpy (from torchvision)\n",
35
+ " Using cached numpy-2.3.4-cp312-cp312-win_amd64.whl.metadata (60 kB)\n",
36
+ "Collecting pillow!=8.3.*,>=5.3.0 (from torchvision)\n",
37
+ " Using cached pillow-12.0.0-cp312-cp312-win_amd64.whl.metadata (9.0 kB)\n",
38
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in c:\\users\\rfd\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from sympy>=1.13.3->torch) (1.3.0)\n",
39
+ "Collecting MarkupSafe>=2.0 (from jinja2->torch)\n",
40
+ " Using cached markupsafe-3.0.3-cp312-cp312-win_amd64.whl.metadata (2.8 kB)\n",
41
+ "Using cached torch-2.9.1-cp312-cp312-win_amd64.whl (110.9 MB)\n",
42
+ "Using cached torchvision-0.24.1-cp312-cp312-win_amd64.whl (4.3 MB)\n",
43
+ "Using cached fsspec-2025.10.0-py3-none-any.whl (200 kB)\n",
44
+ "Using cached networkx-3.5-py3-none-any.whl (2.0 MB)\n",
45
+ "Using cached pillow-12.0.0-cp312-cp312-win_amd64.whl (7.0 MB)\n",
46
+ "Using cached filelock-3.20.0-py3-none-any.whl (16 kB)\n",
47
+ "Using cached jinja2-3.1.6-py3-none-any.whl (134 kB)\n",
48
+ "Using cached numpy-2.3.4-cp312-cp312-win_amd64.whl (12.8 MB)\n",
49
+ "Using cached setuptools-80.9.0-py3-none-any.whl (1.2 MB)\n",
50
+ "Using cached markupsafe-3.0.3-cp312-cp312-win_amd64.whl (15 kB)\n",
51
+ "Installing collected packages: setuptools, pillow, numpy, networkx, MarkupSafe, fsspec, filelock, jinja2, torch, torchvision\n",
52
+ "Successfully installed MarkupSafe-3.0.3 filelock-3.20.0 fsspec-2025.10.0 jinja2-3.1.6 networkx-3.5 numpy-2.3.4 pillow-12.0.0 setuptools-80.9.0 torch-2.9.1 torchvision-0.24.1\n"
53
+ ]
54
+ },
55
+ {
56
+ "name": "stderr",
57
+ "output_type": "stream",
58
+ "text": [
59
+ "\n",
60
+ "[notice] A new release of pip is available: 24.2 -> 25.3\n",
61
+ "[notice] To update, run: python.exe -m pip install --upgrade pip\n"
62
+ ]
63
+ }
64
+ ],
65
+ "source": [
66
+ "!pip install torch torchvision\n",
67
+ "import torch\n",
68
+ "import torch.nn as nn\n",
69
+ "import torch.nn.functional as F\n",
70
+ "\n",
71
+ "class KeypointClassifier(nn.Module):\n",
72
+ " def __init__(self, num_classes=12):\n",
73
+ " super().__init__()\n",
74
+ " self.dropout1 = nn.Dropout(0.2)\n",
75
+ " self.fc1 = nn.Linear(42, 20)\n",
76
+ " self.dropout2 = nn.Dropout(0.4)\n",
77
+ " self.fc2 = nn.Linear(20, 10)\n",
78
+ " self.fc3 = nn.Linear(10, num_classes)\n",
79
+ "\n",
80
+ " def forward(self, x):\n",
81
+ " x = self.dropout1(x)\n",
82
+ " x = F.relu(self.fc1(x))\n",
83
+ " x = self.dropout2(x)\n",
84
+ " x = F.relu(self.fc2(x))\n",
85
+ " x = self.fc3(x) # NO softmax here (PyTorch loss expects raw logits)\n",
86
+ " return x\n"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": null,
92
+ "metadata": {
93
+ "id": "MbMjOflQ9hE1"
94
+ },
95
+ "outputs": [
96
+ {
97
+ "name": "stdout",
98
+ "output_type": "stream",
99
+ "text": [
100
+ "Collecting sklearn\n",
101
+ " Downloading sklearn-0.0.post12.tar.gz (2.6 kB)\n",
102
+ " Installing build dependencies: started\n",
103
+ " Installing build dependencies: finished with status 'done'\n",
104
+ " Getting requirements to build wheel: started\n",
105
+ " Getting requirements to build wheel: finished with status 'error'\n"
106
+ ]
107
+ },
108
+ {
109
+ "name": "stderr",
110
+ "output_type": "stream",
111
+ "text": [
112
+ " error: subprocess-exited-with-error\n",
113
+ " \n",
114
+ " × Getting requirements to build wheel did not run successfully.\n",
115
+ " │ exit code: 1\n",
116
+ " ╰─> [15 lines of output]\n",
117
+ " The 'sklearn' PyPI package is deprecated, use 'scikit-learn'\n",
118
+ " rather than 'sklearn' for pip commands.\n",
119
+ " \n",
120
+ " Here is how to fix this error in the main use cases:\n",
121
+ " - use 'pip install scikit-learn' rather than 'pip install sklearn'\n",
122
+ " - replace 'sklearn' by 'scikit-learn' in your pip requirements files\n",
123
+ " (requirements.txt, setup.py, setup.cfg, Pipfile, etc ...)\n",
124
+ " - if the 'sklearn' package is used by one of your dependencies,\n",
125
+ " it would be great if you take some time to track which package uses\n",
126
+ " 'sklearn' instead of 'scikit-learn' and report it to their issue tracker\n",
127
+ " - as a last resort, set the environment variable\n",
128
+ " SKLEARN_ALLOW_DEPRECATED_SKLEARN_PACKAGE_INSTALL=True to avoid this error\n",
129
+ " \n",
130
+ " More information is available at\n",
131
+ " https://github.com/scikit-learn/sklearn-pypi-package\n",
132
+ " [end of output]\n",
133
+ " \n",
134
+ " note: This error originates from a subprocess, and is likely not a problem with pip.\n",
135
+ "\n",
136
+ "[notice] A new release of pip is available: 24.2 -> 25.3\n",
137
+ "[notice] To update, run: python.exe -m pip install --upgrade pip\n",
138
+ "error: subprocess-exited-with-error\n",
139
+ "\n",
140
+ "× Getting requirements to build wheel did not run successfully.\n",
141
+ "│ exit code: 1\n",
142
+ "╰─> See above for output.\n",
143
+ "\n",
144
+ "note: This error originates from a subprocess, and is likely not a problem with pip.\n"
145
+ ]
146
+ },
147
+ {
148
+ "ename": "ModuleNotFoundError",
149
+ "evalue": "No module named 'sklearn'",
150
+ "output_type": "error",
151
+ "traceback": [
152
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
153
+ "\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)",
154
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 4\u001b[39m\n\u001b[32m 1\u001b[39m get_ipython().system(\u001b[33m'\u001b[39m\u001b[33mpip install sklearn\u001b[39m\u001b[33m'\u001b[39m)\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnp\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m4\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01msklearn\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mmodel_selection\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m train_test_split\n\u001b[32m 5\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtorch\u001b[39;00m\n\u001b[32m 6\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtorch\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mutils\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mdata\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m TensorDataset, DataLoader\n",
155
+ "\u001b[31mModuleNotFoundError\u001b[39m: No module named 'sklearn'"
156
+ ]
157
+ }
158
+ ],
159
+ "source": [
160
+ "!pip install scikit-learn\n",
161
+ "\n",
162
+ "import numpy as np\n",
163
+ "from sklearn.model_selection import train_test_split\n",
164
+ "import torch\n",
165
+ "from torch.utils.data import TensorDataset, DataLoader\n",
166
+ "\n",
167
+ "dataset = 'model/keypoint_classifier/keypoint.csv'\n",
168
+ "\n",
169
+ "X = np.loadtxt(dataset, delimiter=',', dtype='float32', usecols=list(range(1, 43)))\n",
170
+ "y = np.loadtxt(dataset, delimiter=',', dtype='int64', usecols=(0))\n",
171
+ "\n",
172
+ "X_train, X_test, y_train, y_test = train_test_split(\n",
173
+ " X, y, train_size=0.75, random_state=42\n",
174
+ ")\n",
175
+ "\n",
176
+ "train_ds = TensorDataset(\n",
177
+ " torch.tensor(X_train, dtype=torch.float32),\n",
178
+ " torch.tensor(y_train, dtype=torch.long)\n",
179
+ ")\n",
180
+ "test_ds = TensorDataset(\n",
181
+ " torch.tensor(X_test, dtype=torch.float32),\n",
182
+ " torch.tensor(y_test, dtype=torch.long)\n",
183
+ ")\n",
184
+ "\n",
185
+ "train_dl = DataLoader(train_ds, batch_size=128, shuffle=True)\n",
186
+ "test_dl = DataLoader(test_ds, batch_size=128)\n"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": null,
192
+ "metadata": {
193
+ "id": "c3Dac0M_9hE2"
194
+ },
195
+ "outputs": [],
196
+ "source": [
197
+ "model = KeypointClassifier(num_classes=12)\n",
198
+ "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
199
+ "criterion = nn.CrossEntropyLoss()\n",
200
+ "\n",
201
+ "EPOCHS = 100\n",
202
+ "\n",
203
+ "for epoch in range(EPOCHS):\n",
204
+ " model.train()\n",
205
+ " total_loss = 0\n",
206
+ "\n",
207
+ " for xb, yb in train_dl:\n",
208
+ " optimizer.zero_grad()\n",
209
+ " logits = model(xb)\n",
210
+ " loss = criterion(logits, yb)\n",
211
+ " loss.backward()\n",
212
+ " optimizer.step()\n",
213
+ " total_loss += loss.item()\n",
214
+ "\n",
215
+ " print(f\"Epoch {epoch+1}/{EPOCHS} - Loss: {total_loss:.4f}\")\n"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "code",
220
+ "execution_count": null,
221
+ "metadata": {
222
+ "colab": {
223
+ "base_uri": "https://localhost:8080/"
224
+ },
225
+ "id": "WirBl-JE9hE3",
226
+ "outputId": "71b30ca2-8294-4d9d-8aa2-800d90d399de",
227
+ "scrolled": true
228
+ },
229
+ "outputs": [],
230
+ "source": [
231
+ "model.eval()\n",
232
+ "correct = 0\n",
233
+ "total = 0\n",
234
+ "\n",
235
+ "with torch.no_grad():\n",
236
+ " for xb, yb in test_dl:\n",
237
+ " preds = model(xb).argmax(dim=1)\n",
238
+ " correct += (preds == yb).sum().item()\n",
239
+ " total += yb.size(0)\n",
240
+ "\n",
241
+ "print(\"Accuracy:\", correct / total)\n"
242
+ ]
243
+ },
244
+ {
245
+ "cell_type": "code",
246
+ "execution_count": null,
247
+ "metadata": {
248
+ "colab": {
249
+ "base_uri": "https://localhost:8080/"
250
+ },
251
+ "id": "pxvb2Y299hE3",
252
+ "outputId": "59eb3185-2e37-4b9e-bc9d-ab1b8ac29b7f"
253
+ },
254
+ "outputs": [],
255
+ "source": [
256
+ "torch.save(model.state_dict(), \"keypoint_classifier_pytorch.pth\")\n"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": null,
262
+ "metadata": {
263
+ "id": "RBkmDeUW9hE4"
264
+ },
265
+ "outputs": [],
266
+ "source": [
267
+ "model_quant = torch.quantization.quantize_dynamic(\n",
268
+ " model, \n",
269
+ " {nn.Linear}, # Quantize only linear layers\n",
270
+ " dtype=torch.qint8\n",
271
+ ")\n",
272
+ "\n",
273
+ "torch.save(model_quant.state_dict(), \"keypoint_classifier_quantized.pth\")\n"
274
+ ]
275
+ },
276
+ {
277
+ "cell_type": "code",
278
+ "execution_count": null,
279
+ "metadata": {
280
+ "colab": {
281
+ "base_uri": "https://localhost:8080/"
282
+ },
283
+ "id": "tFz9Tb0I9hE4",
284
+ "outputId": "1c3b3528-54ae-4ee2-ab04-77429211cbef"
285
+ },
286
+ "outputs": [],
287
+ "source": []
288
+ },
289
+ {
290
+ "cell_type": "code",
291
+ "execution_count": null,
292
+ "metadata": {
293
+ "colab": {
294
+ "base_uri": "https://localhost:8080/",
295
+ "height": 582
296
+ },
297
+ "id": "AP1V6SCk9hE5",
298
+ "outputId": "08e41a80-7a4a-4619-8125-ecc371368d19"
299
+ },
300
+ "outputs": [],
301
+ "source": []
302
+ },
303
+ {
304
+ "cell_type": "code",
305
+ "execution_count": null,
306
+ "metadata": {
307
+ "id": "ODjnYyld9hE6"
308
+ },
309
+ "outputs": [],
310
+ "source": []
311
+ },
312
+ {
313
+ "cell_type": "code",
314
+ "execution_count": null,
315
+ "metadata": {
316
+ "colab": {
317
+ "base_uri": "https://localhost:8080/"
318
+ },
319
+ "id": "zRfuK8Y59hE6",
320
+ "outputId": "a4ca585c-b5d5-4244-8291-8674063209bb"
321
+ },
322
+ "outputs": [],
323
+ "source": []
324
+ },
325
+ {
326
+ "cell_type": "code",
327
+ "execution_count": null,
328
+ "metadata": {
329
+ "colab": {
330
+ "base_uri": "https://localhost:8080/"
331
+ },
332
+ "id": "s4FoAnuc9hE7",
333
+ "outputId": "91f18257-8d8b-4ef3-c558-e9b5f94fabbf",
334
+ "scrolled": true
335
+ },
336
+ "outputs": [],
337
+ "source": [
338
+ "\n"
339
+ ]
340
+ },
341
+ {
342
+ "cell_type": "code",
343
+ "execution_count": null,
344
+ "metadata": {
345
+ "colab": {
346
+ "base_uri": "https://localhost:8080/"
347
+ },
348
+ "id": "vONjp19J9hE8",
349
+ "outputId": "77205e24-fd00-42c4-f7b6-e06e527c2cba"
350
+ },
351
+ "outputs": [],
352
+ "source": []
353
+ }
354
+ ],
355
+ "metadata": {
356
+ "accelerator": "GPU",
357
+ "colab": {
358
+ "collapsed_sections": [],
359
+ "name": "keypoint_classification_EN_pytorch.ipynb",
360
+ "provenance": [],
361
+ "toc_visible": true
362
+ },
363
+ "kernelspec": {
364
+ "display_name": "Python 3",
365
+ "language": "python",
366
+ "name": "python3"
367
+ },
368
+ "language_info": {
369
+ "codemirror_mode": {
370
+ "name": "ipython",
371
+ "version": 3
372
+ },
373
+ "file_extension": ".py",
374
+ "mimetype": "text/x-python",
375
+ "name": "python",
376
+ "nbconvert_exporter": "python",
377
+ "pygments_lexer": "ipython3",
378
+ "version": "3.12.7"
379
+ }
380
+ },
381
+ "nbformat": 4,
382
+ "nbformat_minor": 0
383
+ }
keypoint_classifier.hdf5 ADDED
Binary file (23.2 kB). View file
 
keypoint_classifier_label.csv ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Open
2
+ Close
3
+ Pointer
4
+ Pinch
5
+ Thumbs Up
6
+ Thumbs Down
7
+ Thumbs Sideways
8
+ Pinch Pinky
9
+ L
10
+ Click Up
11
+ Click Down
12
+ Yolo
keypoint_classifier_pytorch.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2e2939ce6244459e743d8d098585051961be0c01761be81bd606fcb6731086b
3
+ size 8005
keypoint_classifier_pytorch.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
import torch.nn as nn


# -------------------------------
# PyTorch model definition
# -------------------------------
class KeyPointClassifierModel(nn.Module):
    """MLP mapping 42 hand-keypoint features (21 landmarks x 2 coords) to class logits."""

    def __init__(self, num_classes=12):  # default matches the shipped checkpoint
        super().__init__()
        self.fc1 = nn.Linear(42, 20)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(20, 10)
        self.relu2 = nn.ReLU()
        # Output width must match the checkpoint's class count (12 for the bundled .pth).
        self.fc3 = nn.Linear(10, num_classes)

    def forward(self, x):
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        return self.fc3(x)  # raw logits; softmax is applied by the wrapper below


# -------------------------------
# Wrapper class for easy usage
# -------------------------------
class KeyPointClassifier:
    """Loads a trained checkpoint and classifies one landmark feature vector per call.

    Args:
        model_path: Path to a state-dict checkpoint saved from KeyPointClassifierModel.
        device: Torch device string the model and inputs are placed on.
        num_classes: Number of output classes; must match the checkpoint.
    """

    def __init__(self, model_path="keypoint_classifier_pytorch.pth", device='cpu',
                 num_classes=12):
        self.device = device
        self.model = KeyPointClassifierModel(num_classes=num_classes)
        # weights_only=True refuses arbitrary pickled objects, preventing code
        # execution from an untrusted checkpoint file.
        state = torch.load(model_path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(state)
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, landmark_list):
        """Return (predicted_class_index, confidence) for a 42-element feature list."""
        with torch.no_grad():
            # Build the batch-of-one tensor directly on the target device.
            x = torch.tensor([landmark_list], dtype=torch.float32, device=self.device)
            prob = torch.softmax(self.model(x), dim=1)
            conf, pred = torch.max(prob, dim=1)
        return pred.item(), conf.item()
keypoint_classifier_quantized.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d06732f994e873e78791a9785ef120deb2e8ab30fdc066741ea4900bae92443
3
+ size 6731