diff --git "a/Final_Inference.ipynb" "b/Final_Inference.ipynb" new file mode 100644--- /dev/null +++ "b/Final_Inference.ipynb" @@ -0,0 +1,2402 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "bSu3U-SNu8v6", + "outputId": "494f98f4-d490-4010-8e79-9ba10fb7a06b" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: huggingface_hub in /usr/local/lib/python3.12/dist-packages (0.36.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from huggingface_hub) (3.20.0)\n", + "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.12/dist-packages (from huggingface_hub) (2025.3.0)\n", + "Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.12/dist-packages (from huggingface_hub) (25.0)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from huggingface_hub) (6.0.3)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from huggingface_hub) (2.32.4)\n", + "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.12/dist-packages (from huggingface_hub) (4.67.1)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.12/dist-packages (from huggingface_hub) (4.15.0)\n", + "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface_hub) (1.2.0)\n", + "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface_hub) (3.4.4)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface_hub) (3.11)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface_hub) 
(2.5.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface_hub) (2025.11.12)\n", + "Requirement already satisfied: opencv-python in /usr/local/lib/python3.12/dist-packages (4.12.0.88)\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.12/dist-packages (2.9.0+cu126)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (2.0.2)\n", + "Requirement already satisfied: torchvision in /usr/local/lib/python3.12/dist-packages (0.24.0+cu126)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (4.67.1)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.12/dist-packages (2.2.2)\n", + "Requirement already satisfied: ipywidgets in /usr/local/lib/python3.12/dist-packages (7.7.1)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch) (3.20.0)\n", + "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.12/dist-packages (from torch) (4.15.0)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch) (75.2.0)\n", + "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch) (1.14.0)\n", + "Requirement already satisfied: networkx>=2.5.1 in /usr/local/lib/python3.12/dist-packages (from torch) (3.6.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch) (3.1.6)\n", + "Requirement already satisfied: fsspec>=0.8.5 in /usr/local/lib/python3.12/dist-packages (from torch) (2025.3.0)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.77)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.77)\n", + "Requirement already satisfied: 
nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.80)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch) (9.10.2.21)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.4.1)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch) (11.3.0.4)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch) (10.3.7.77)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch) (11.7.1.2)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch) (12.5.4.2)\n", + "Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch) (0.7.1)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.27.5 in /usr/local/lib/python3.12/dist-packages (from torch) (2.27.5)\n", + "Requirement already satisfied: nvidia-nvshmem-cu12==3.3.20 in /usr/local/lib/python3.12/dist-packages (from torch) (3.3.20)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.77)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.85)\n", + "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch) (1.11.1.6)\n", + "Requirement already satisfied: triton==3.5.0 in /usr/local/lib/python3.12/dist-packages (from torch) (3.5.0)\n", + "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.12/dist-packages (from torchvision) (11.3.0)\n", + "Requirement already satisfied: 
python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas) (2025.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas) (2025.3)\n", + "Requirement already satisfied: ipykernel>=4.5.1 in /usr/local/lib/python3.12/dist-packages (from ipywidgets) (6.17.1)\n", + "Requirement already satisfied: ipython-genutils~=0.2.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets) (0.2.0)\n", + "Requirement already satisfied: traitlets>=4.3.1 in /usr/local/lib/python3.12/dist-packages (from ipywidgets) (5.7.1)\n", + "Requirement already satisfied: widgetsnbextension~=3.6.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets) (3.6.10)\n", + "Requirement already satisfied: ipython>=4.0.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets) (7.34.0)\n", + "Requirement already satisfied: jupyterlab-widgets>=1.0.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets) (3.0.16)\n", + "Requirement already satisfied: debugpy>=1.0 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets) (1.8.15)\n", + "Requirement already satisfied: jupyter-client>=6.1.12 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets) (7.4.9)\n", + "Requirement already satisfied: matplotlib-inline>=0.1 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets) (0.2.1)\n", + "Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets) (1.6.0)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets) (25.0)\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets) (5.9.5)\n", + "Requirement already satisfied: pyzmq>=17 in 
/usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets) (26.2.1)\n", + "Requirement already satisfied: tornado>=6.1 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets) (6.5.1)\n", + "Requirement already satisfied: jedi>=0.16 in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets) (0.19.2)\n", + "Requirement already satisfied: decorator in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets) (4.4.2)\n", + "Requirement already satisfied: pickleshare in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets) (0.7.5)\n", + "Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets) (3.0.52)\n", + "Requirement already satisfied: pygments in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets) (2.19.2)\n", + "Requirement already satisfied: backcall in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets) (0.2.0)\n", + "Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets) (4.9.0)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.8.2->pandas) (1.17.0)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch) (1.3.0)\n", + "Requirement already satisfied: notebook>=4.4.1 in /usr/local/lib/python3.12/dist-packages (from widgetsnbextension~=3.6.0->ipywidgets) (6.5.7)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch) (3.0.3)\n", + "Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.12/dist-packages (from jedi>=0.16->ipython>=4.0.0->ipywidgets) (0.8.5)\n", + "Requirement already satisfied: entrypoints in /usr/local/lib/python3.12/dist-packages (from 
jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets) (0.4)\n", + "Requirement already satisfied: jupyter-core>=4.9.2 in /usr/local/lib/python3.12/dist-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets) (5.9.1)\n", + "Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (25.1.0)\n", + "Requirement already satisfied: nbformat in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (5.10.4)\n", + "Requirement already satisfied: nbconvert>=5 in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (7.16.6)\n", + "Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (1.8.3)\n", + "Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (0.18.1)\n", + "Requirement already satisfied: prometheus-client in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (0.23.1)\n", + "Requirement already satisfied: nbclassic>=0.4.7 in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (1.3.3)\n", + "Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.12/dist-packages (from pexpect>4.3->ipython>=4.0.0->ipywidgets) (0.7.0)\n", + "Requirement already satisfied: wcwidth in /usr/local/lib/python3.12/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython>=4.0.0->ipywidgets) (0.2.14)\n", + "Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.12/dist-packages (from jupyter-core>=4.9.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets) (4.5.1)\n", + "Requirement already satisfied: notebook-shim>=0.2.3 in 
/usr/local/lib/python3.12/dist-packages (from nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (0.2.4)\n", + "Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (4.13.5)\n", + "Requirement already satisfied: bleach!=5.0.0 in /usr/local/lib/python3.12/dist-packages (from bleach[css]!=5.0.0->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (6.3.0)\n", + "Requirement already satisfied: defusedxml in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (0.7.1)\n", + "Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (0.3.0)\n", + "Requirement already satisfied: mistune<4,>=2.0.3 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (3.1.4)\n", + "Requirement already satisfied: nbclient>=0.5.0 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (0.10.2)\n", + "Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (1.5.1)\n", + "Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.12/dist-packages (from nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (2.21.2)\n", + "Requirement already satisfied: jsonschema>=2.6 in /usr/local/lib/python3.12/dist-packages (from nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (4.25.1)\n", + "Requirement already satisfied: argon2-cffi-bindings in /usr/local/lib/python3.12/dist-packages (from argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (25.1.0)\n", + "Requirement already 
satisfied: webencodings in /usr/local/lib/python3.12/dist-packages (from bleach!=5.0.0->bleach[css]!=5.0.0->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (0.5.1)\n", + "Requirement already satisfied: tinycss2<1.5,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from bleach[css]!=5.0.0->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (1.4.0)\n", + "Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (25.4.0)\n", + "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (2025.9.1)\n", + "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (0.37.0)\n", + "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (0.30.0)\n", + "Requirement already satisfied: jupyter-server<3,>=1.8 in /usr/local/lib/python3.12/dist-packages (from notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (2.14.0)\n", + "Requirement already satisfied: cffi>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from argon2-cffi-bindings->argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (2.0.0)\n", + "Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.12/dist-packages (from beautifulsoup4->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (2.8)\n", + "Requirement already satisfied: pycparser in /usr/local/lib/python3.12/dist-packages (from 
cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (2.23)\n", + "Requirement already satisfied: anyio>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (4.12.0)\n", + "Requirement already satisfied: jupyter-events>=0.9.0 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (0.12.0)\n", + "Requirement already satisfied: jupyter-server-terminals>=0.4.4 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (0.5.3)\n", + "Requirement already satisfied: overrides>=5.0 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (7.7.0)\n", + "Requirement already satisfied: websocket-client>=1.7 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (1.9.0)\n", + "Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.12/dist-packages (from anyio>=3.1.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (3.11)\n", + "Requirement already satisfied: python-json-logger>=2.0.4 in /usr/local/lib/python3.12/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (4.0.0)\n", + "Requirement already satisfied: pyyaml>=5.3 in /usr/local/lib/python3.12/dist-packages (from 
jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (6.0.3)\n", + "Requirement already satisfied: rfc3339-validator in /usr/local/lib/python3.12/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (0.1.4)\n", + "Requirement already satisfied: rfc3986-validator>=0.1.1 in /usr/local/lib/python3.12/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (0.1.1)\n", + "Requirement already satisfied: fqdn in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (1.5.1)\n", + "Requirement already satisfied: isoduration in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (20.11.0)\n", + "Requirement already satisfied: jsonpointer>1.13 in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (3.0.0)\n", + "Requirement already satisfied: rfc3987-syntax>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (1.1.0)\n", + "Requirement already satisfied: uri-template in /usr/local/lib/python3.12/dist-packages (from 
jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (1.3.0)\n", + "Requirement already satisfied: webcolors>=24.6.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (25.10.0)\n", + "Requirement already satisfied: lark>=1.2.2 in /usr/local/lib/python3.12/dist-packages (from rfc3987-syntax>=1.1.0->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (1.3.1)\n", + "Requirement already satisfied: arrow>=0.15.0 in /usr/local/lib/python3.12/dist-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets) (1.4.0)\n", + "Requirement already satisfied: lap in /usr/local/lib/python3.12/dist-packages (0.5.12)\n", + "Requirement already satisfied: numpy>=1.21.6 in /usr/local/lib/python3.12/dist-packages (from lap) (2.0.2)\n", + "Requirement already satisfied: ultralytics in /usr/local/lib/python3.12/dist-packages (8.3.248)\n", + "Requirement already satisfied: numpy>=1.23.0 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (2.0.2)\n", + "Requirement already satisfied: matplotlib>=3.3.0 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (3.10.0)\n", + "Requirement already satisfied: opencv-python>=4.6.0 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (4.12.0.88)\n", + "Requirement already satisfied: pillow>=7.1.2 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (11.3.0)\n", + "Requirement already satisfied: pyyaml>=5.3.1 in /usr/local/lib/python3.12/dist-packages 
(from ultralytics) (6.0.3)\n", + "Requirement already satisfied: requests>=2.23.0 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (2.32.4)\n", + "Requirement already satisfied: scipy>=1.4.1 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (1.16.3)\n", + "Requirement already satisfied: torch>=1.8.0 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (2.9.0+cu126)\n", + "Requirement already satisfied: torchvision>=0.9.0 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (0.24.0+cu126)\n", + "Requirement already satisfied: psutil>=5.8.0 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (5.9.5)\n", + "Requirement already satisfied: polars>=0.20.0 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (1.31.0)\n", + "Requirement already satisfied: ultralytics-thop>=2.0.18 in /usr/local/lib/python3.12/dist-packages (from ultralytics) (2.0.18)\n", + "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.3.0->ultralytics) (1.3.3)\n", + "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.3.0->ultralytics) (0.12.1)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.3.0->ultralytics) (4.61.1)\n", + "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.3.0->ultralytics) (1.4.9)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.3.0->ultralytics) (25.0)\n", + "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.3.0->ultralytics) (3.2.5)\n", + "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.3.0->ultralytics) (2.9.0.post0)\n", + "Requirement already satisfied: charset_normalizer<4,>=2 in 
/usr/local/lib/python3.12/dist-packages (from requests>=2.23.0->ultralytics) (3.4.4)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests>=2.23.0->ultralytics) (3.11)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests>=2.23.0->ultralytics) (2.5.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests>=2.23.0->ultralytics) (2025.11.12)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (3.20.0)\n", + "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (4.15.0)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (75.2.0)\n", + "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (1.14.0)\n", + "Requirement already satisfied: networkx>=2.5.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (3.6.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (3.1.6)\n", + "Requirement already satisfied: fsspec>=0.8.5 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (2025.3.0)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (12.6.77)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (12.6.77)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (12.6.80)\n", + "Requirement already satisfied: 
nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (9.10.2.21)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (12.6.4.1)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (11.3.0.4)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (10.3.7.77)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (11.7.1.2)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (12.5.4.2)\n", + "Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (0.7.1)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.27.5 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (2.27.5)\n", + "Requirement already satisfied: nvidia-nvshmem-cu12==3.3.20 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (3.3.20)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (12.6.77)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (12.6.85)\n", + "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (1.11.1.6)\n", + "Requirement already satisfied: triton==3.5.0 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8.0->ultralytics) (3.5.0)\n", + "Requirement already satisfied: six>=1.5 in 
/usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.7->matplotlib>=3.3.0->ultralytics) (1.17.0)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch>=1.8.0->ultralytics) (1.3.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch>=1.8.0->ultralytics) (3.0.3)\n" + ] + } + ], + "source": [ + "# ===== INSTALL DEPENDENCIES =====\n", + "!pip install huggingface_hub\n", + "!pip install boto3 -q\n", + "!pip install opencv-python torch numpy torchvision tqdm pandas ipywidgets\n", + "!pip install lap\n", + "!pip install ultralytics" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PMr99Yo7x8N1" + }, + "source": [ + "# Please double, triple, quadruple check that the below code runs without errors before submitting." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4YhoE1nF2Pee" + }, + "source": [ + "## TODO 1 - Enter your HuggingFace username below:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "id": "ENyfncieqs6i" + }, + "outputs": [], + "source": [ + "hf_username = \"maatt4face\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0lC0jUquq_06" + }, + "source": [ + "## TODO 2 - Define your model EXACTLY as you did in your training code (otherwise there will be errors, and, possibly, tears).\n", + "\n", + "Note below the classname is 'YourModelArchitecture'. That's because it literally needs to be YOUR MODEL ARCHITECTURE. This class definition is later referred to below in the 'load_model_from_hub' method. The architecture must match here, or it will not be able to instantiate the model weights correctly once it downloads them from HuggingFace. Pay very close attention to getting this right, please.\n", + "\n", + "Replace the below code, and replace the corresponding line in the 'load_model_from_hub' method." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gX0WITdlisN1" + }, + "source": [ + "### Parameters and Global Variables" + ] + }, + { + "cell_type": "code", + "source": [ + "# Import the required libraries\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.utils.data import Dataset, DataLoader\n", + "from huggingface_hub import hf_hub_download\n", + "import boto3\n", + "from botocore import UNSIGNED\n", + "from botocore.config import Config\n", + "import os\n", + "import cv2\n", + "import numpy as np\n", + "from tqdm import tqdm\n", + "import time" + ], + "metadata": { + "id": "sjIQLKZAj2fz" + }, + "execution_count": 17, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "jtZ65UWcisN1", + "outputId": "7ce01e4d-e6f3-4d67-e9c1-54695927f2f2" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Device: cuda\n", + "Training Start Time (UNIX): 1767566954\n" + ] + } + ], + "source": [ + "# Admin =================================================================================\n", + "import os\n", + "DOWNLOAD_DIR = 'test-data'\n", + "CACHE_DIR = 'cache-data'\n", + "WEIGHTS_DIR = 'model-weights'\n", + "os.makedirs(DOWNLOAD_DIR, exist_ok=True)\n", + "os.makedirs(CACHE_DIR, exist_ok=True)\n", + "os.makedirs(WEIGHTS_DIR, exist_ok=True)\n", + "\n", + "# To GPU or not to GPU =================================================================\n", + "import torch\n", + "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Device: {DEVICE}\")\n", + "\n", + "# Landmark Extraction ==================================================================\n", + "from ultralytics import YOLO\n", + "YOLO_DIR = './yolo-weights'\n", + "os.makedirs(YOLO_DIR, exist_ok=True)\n", + "YOLO_MODEL = \"yolov8s\"\n", + "YOLO_MODEL_FILE = os.path.join(YOLO_DIR, f\"{YOLO_MODEL}-pose.pt\")\n", + "#assert 
class PushupCounterSingleOutput(nn.Module):
    """Bi-LSTM regressor mapping a 1-D per-frame signal to a single push-up count.

    Accepts either a (batch, seq_len) or a (batch, seq_len, input_size) float
    tensor together with per-sample valid lengths, and returns a (batch, 1)
    unbounded scalar prediction.

    NOTE(review): architecture must match the training notebook exactly so the
    HuggingFace weights load — do not alter layer shapes or ordering.
    """

    def __init__(self, input_size=1, hidden_size=64):
        super().__init__()
        # Per-channel normalisation of the raw signal (camera/scale invariance).
        self.bn = nn.BatchNorm1d(input_size)
        # Single-layer bidirectional LSTM summarising the whole sequence.
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            bidirectional=True,
            batch_first=True,
        )
        # Small MLP head: concatenated fwd/bwd final states -> scalar count.
        self.regressor = nn.Sequential(
            nn.Linear(hidden_size * 2, 64),
            nn.ReLU(),
            nn.Linear(64, 1),  # single scalar output
        )
        # Custom initialisation applied module-by-module (order matters for RNG).
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """He init for Linear, Xavier/orthogonal for LSTM, defaults for BN."""
        if isinstance(module, nn.Linear):
            init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                init.constant_(module.bias, 0)
        elif isinstance(module, nn.LSTM):
            for name, param in module.named_parameters():
                if 'weight_ih' in name:
                    init.xavier_uniform_(param.data)
                elif 'weight_hh' in name:
                    init.orthogonal_(param.data)
                elif 'bias' in name:
                    init.constant_(param.data, 0)
                    # Set the forget-gate quarter of the bias to 1.0 — the
                    # standard LSTM trick to ease gradient flow early on.
                    gate_size = param.size(0)
                    param.data[gate_size // 4:gate_size // 2].fill_(1.0)
        elif isinstance(module, nn.BatchNorm1d):
            init.constant_(module.weight, 1)
            init.constant_(module.bias, 0)

    def forward(self, data, lengths):
        """Run the sequence through BN -> packed Bi-LSTM -> regressor head.

        data:    (B, T) or (B, T, input_size) float tensor.
        lengths: (B,) tensor of valid (unpadded) timesteps per sample.
        returns: (B, 1) predicted count.
        """
        # Promote a raw (B, T) scalar series to (B, T, 1).
        if data.dim() == 2:
            data = data.unsqueeze(-1)

        # BatchNorm1d expects (B, C, T): swap, normalise, swap back.
        normed = self.bn(data.transpose(1, 2)).transpose(1, 2)

        # Pack so the LSTM ignores padded timesteps beyond each length.
        packed = pack_padded_sequence(
            normed, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, (hidden, _) = self.lstm(packed)

        # hidden: (num_layers * num_directions, B, H). The last two rows are
        # the top layer's forward and backward final states; their
        # concatenation is the fixed-size summary of the whole video.
        summary = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)

        return self.regressor(summary)
Predict the final count\n", + " count = self.regressor(combined)\n", + " return count\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "R6OEci-2isN2" + }, + "source": [ + "### Secret Sauce" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zhXVPMi6isN2" + }, + "source": [ + "#### Transformation" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "id": "FSFBSh2BisN2" + }, + "outputs": [], + "source": [ + "import torchvision.transforms.functional as TF\n", + "import random\n", + "import torch # Needed for TF.to_tensor which uses torch internally\n", + "\n", + "\n", + "class VideoResizerSingle:\n", + " def __init__(self, largest_dim):\n", + " # resize such that aspect ratio is retained\n", + " self.largest_dim = largest_dim\n", + "\n", + " def __call__(self, single_frame):\n", + " original_width, original_height = single_frame.size\n", + " scale = max(self.largest_dim/original_width, self.largest_dim/original_height)\n", + " new_width = int(original_width * scale)\n", + " new_height = int(original_height * scale)\n", + " single_frame = TF.resize(single_frame, (new_height, new_width))\n", + " return single_frame\n", + "\n", + "\n", + "class VideoResizer:\n", + " def __init__(self, largest_dim):\n", + " self.largest_dim = largest_dim\n", + "\n", + " def __call__(self, frames_list_pil):\n", + " if not frames_list_pil:\n", + " return []\n", + "\n", + " transformed_frames_pil = []\n", + " helper = VideoResizerSingle(self.largest_dim)\n", + " for img_pil in frames_list_pil:\n", + " img_pil = helper(img_pil)\n", + " transformed_frames_pil.append(img_pil)\n", + " return transformed_frames_pil\n", + "\n", + "\n", + "class VideoTransform:\n", + " def __init__(self, rotation_degrees, hflip_p, vflip_p, color_jitter_params, resize_dims=None):\n", + " self.rotation_degrees = rotation_degrees\n", + " self.hflip_p = hflip_p\n", + " self.vflip_p = vflip_p\n", + " self.color_jitter_params = color_jitter_params # (brightness, 
class VideoTransform:
    """Video-level augmentation: samples ONE set of random parameters
    (rotation, flips, color jitter) per video and applies them identically to
    every frame, so temporal consistency is preserved.

    Args:
        rotation_degrees: max absolute rotation angle sampled uniformly.
        hflip_p / vflip_p: probabilities of horizontal / vertical flip.
        color_jitter_params: (brightness, contrast, saturation, hue) jitter ranges.
        resize_dims: optional (height, width); frames are letterboxed
            (aspect-preserving resize + centered padding) to this size.
    """
    def __init__(self, rotation_degrees, hflip_p, vflip_p, color_jitter_params, resize_dims=None):
        self.rotation_degrees = rotation_degrees
        self.hflip_p = hflip_p
        self.vflip_p = vflip_p
        self.color_jitter_params = color_jitter_params  # (brightness, contrast, saturation, hue)
        self.resize_dims = resize_dims  # (height, width)

    def __call__(self, frames_list_pil):
        """Return a new list of augmented PIL frames (input list untouched)."""
        if not frames_list_pil:
            return []

        # Sample parameters once per video
        angle = random.uniform(-self.rotation_degrees, self.rotation_degrees)
        h_flip = random.random() < self.hflip_p
        v_flip = random.random() < self.vflip_p

        # Manually compute ColorJitter factors for consistency across frames
        brightness_param = self.color_jitter_params[0]
        contrast_param = self.color_jitter_params[1]
        saturation_param = self.color_jitter_params[2]
        hue_param = self.color_jitter_params[3]

        brightness_factor = random.uniform(max(0, 1 - brightness_param), 1 + brightness_param)
        contrast_factor = random.uniform(max(0, 1 - contrast_param), 1 + contrast_param)
        saturation_factor = random.uniform(max(0, 1 - saturation_param), 1 + saturation_param)
        hue_factor = random.uniform(-hue_param, hue_param)

        transformed_frames_pil = []
        for img_pil in frames_list_pil:
            # Apply the SAME random parameters to each frame
            img_pil = TF.rotate(img_pil, angle)
            if h_flip:
                img_pil = TF.hflip(img_pil)
            if v_flip:
                img_pil = TF.vflip(img_pil)
            img_pil = TF.adjust_brightness(img_pil, brightness_factor)
            img_pil = TF.adjust_contrast(img_pil, contrast_factor)
            img_pil = TF.adjust_saturation(img_pil, saturation_factor)
            img_pil = TF.adjust_hue(img_pil, hue_factor)

            # Resize while maintaining aspect ratio, then pad to target size
            if self.resize_dims:
                original_width, original_height = img_pil.size
                target_height, target_width = self.resize_dims

                # Fit within target while maintaining aspect ratio
                scale_w = target_width / original_width
                scale_h = target_height / original_height
                scale = min(scale_w, scale_h)

                new_width = int(original_width * scale)
                new_height = int(original_height * scale)

                img_pil = TF.resize(img_pil, (new_height, new_width))

                # Center-pad up to exactly (target_height, target_width)
                pad_left = (target_width - new_width) // 2
                pad_right = target_width - new_width - pad_left
                pad_top = (target_height - new_height) // 2
                pad_bottom = target_height - new_height - pad_top

                img_pil = TF.pad(img_pil, (pad_left, pad_top, pad_right, pad_bottom))

            transformed_frames_pil.append(img_pil)  # keep as PIL; tensor conversion happens later

        return transformed_frames_pil  # Return list of PIL images

TRANSFORMATIONS = VideoTransform(
    rotation_degrees=180,
    hflip_p=0.5,
    vflip_p=0.5,
    color_jitter_params=(0.2, 0.2, 0.2, 0.2),
    resize_dims=(640, 640)  # letterbox to 640x640 (was mislabeled "224x224" in a comment)
)

RESIZER = VideoResizerSingle(MAX_FRAME_LENGTH)

import pandas as pd
import torch
import torch.nn as nn
import numpy as np

class SignalPerturbator:
    """Dynamic 1-D augmentation for angle signals.

    FIX: was declared ``@staticmethod def __call__(self, signals)`` — calling
    an instance then bound the only argument to ``self`` and raised TypeError.
    Now a regular instance method.
    """
    def __call__(self, signals):
        """Applies dynamic 1D augmentation to the signal sequence.

        Args:
            signals: float tensor of normalized angles in [0, 1].
        Returns:
            Perturbed tensor, clamped back into [0, 1].
        """
        # Add Gaussian noise
        noise = torch.randn_like(signals) * 0.01
        # Random scaling (simulates different range of motion)
        scale = random.uniform(0.9, 1.1)
        # Random baseline shift
        shift = random.uniform(-0.05, 0.05)
        return torch.clamp(signals * scale + noise + shift, 0, 1)
class SignalAverager:
    """Moving-average smoother for 1-D angle signals.

    NOTE(review): ``__call__`` is a ``@staticmethod`` without ``self``; calling
    an *instance* still works because the descriptor lookup returns the plain
    function — but it is an unusual pattern; confirm callers before changing.
    """
    @staticmethod
    def __call__(signal_tensor, window_size=3):
        """
        Applies a moving average to smooth signals.

        Args:
            signal_tensor: 1-D (Length,) or 2-D (Batch, Length) float tensor.
            window_size: averaging window (odd sizes keep length unchanged).
        Returns:
            Smoothed tensor of shape (Batch, Length).
        """
        # 1. Ensure the tensor is at least 2D (Batch, Length)
        if signal_tensor.dim() == 1:
            signal_tensor = signal_tensor.unsqueeze(0)

        # 2. Reshape to (Batch, Channels, Length) for avg_pool1d
        # We treat the elbow angle as a single channel
        x = signal_tensor.unsqueeze(1)

        # 3. Apply average pooling
        # stride=1 keeps the resolution the same
        # padding=window_size//2 keeps the output length equal to the input length
        smoothed = nn.functional.avg_pool1d(x, kernel_size=window_size, stride=1, padding=window_size//2)

        # 4. Remove the extra channel dimension and return
        return smoothed.squeeze(1)

class SignalMedianator:
    """Moving-median smoother followed by z-score normalization."""
    @staticmethod
    def __call__(signal_tensor, window_size=3):
        """
        Applies a moving median to smooth the signals.
        Good for bringing out local structure dwarfed by global normalization.
        Note: the smoothed signal is also z-scored (mean 0, std 1) before return.
        """
        def moving_median(data, window):
            # Median over a clamped window around each index (edges shrink the window).
            return np.array([np.median(data[max(0, i-window//2):min(len(data), i+window//2+1)])
                             for i in range(len(data))])
        smoothed = moving_median(signal_tensor, window_size)
        mean = np.mean(smoothed)
        std = np.std(smoothed) + 1e-6  # epsilon guards against a constant signal
        smoothed = (smoothed - mean) / std
        return torch.tensor(smoothed, dtype=torch.float32)

import numpy as np
from enum import Enum


class COCO(Enum):
    """Keypoint indices used by AngleCalculator.

    NOTE(review): the left/right labels for shoulders/elbows/wrists are
    swapped relative to the conventional COCO ordering (where index 5 is the
    LEFT shoulder), while hips/knees follow the convention. Harmless below
    because both arms are averaged — confirm before reusing elsewhere.
    """
    R_SHOULDER = 5
    L_SHOULDER = 6
    R_ELBOW = 7
    L_ELBOW = 8
    R_WRIST = 9
    L_WRIST = 10
    L_HIP = 11
    R_HIP = 12
    L_KNEE = 13
    R_KNEE = 14


class AngleCalculator:
    """Calculates the elbow angle of a person from landmarks."""
    def __init__(self, landmarks) -> None:
        # landmarks as returned by Ultralytics: np.ndarray[np._AnyShape, np.dtype[np.Any]] | np.Any
        # NOTE(review): only columns [0] and [1] (pixel x, y) are read below,
        # indexed by COCO.*.value rows — assumes at least 15 keypoint rows.
        self.landmarks = landmarks

    def _ref_len(self):
        """Body-scale reference length in pixels: shoulder-to-hip (torso)
        distance when hips are detected, else shoulder width / 0.75."""
        # Determine Scale Reference (Hips or Fallback to Shoulder Width)
        # An all-zero row means the keypoint was not detected, hence .any().
        if self.landmarks[COCO.L_HIP.value].any() and self.landmarks[COCO.R_HIP.value].any():
            ref_len = np.linalg.norm(self.landmarks[COCO.L_SHOULDER.value] - self.landmarks[COCO.L_HIP.value])
        else:
            ref_len = np.linalg.norm(self.landmarks[COCO.L_SHOULDER.value] - self.landmarks[COCO.R_SHOULDER.value]) / 0.75
        return ref_len

    def _uplift_arm_to_3d(self, sh_idx, el_idx, wr_idx, ref_len):
        """Uplifts 2D keypoints to 3D using Da Vinci's ratios.

        Assumes fixed upper-arm/forearm lengths relative to the torso and
        recovers the depth (z) that makes the 2-D projection consistent with
        those lengths; foreshortening shorter than the assumed segment length
        yields positive z, otherwise z is clamped to 0 by the max(0, ...).
        """
        # Ratios (relative to torso length)
        R_UPPER_ARM = 0.45
        R_FOREARM = 0.42

        # 1. Shoulder is the root (z=0)
        sh_3d = np.array([self.landmarks[sh_idx][0], self.landmarks[sh_idx][1], 0.0])

        # 2. Lift Elbow (Relative to Shoulder)
        L_upper = R_UPPER_ARM * ref_len
        dx1 = self.landmarks[el_idx][0] - self.landmarks[sh_idx][0]
        dy1 = self.landmarks[el_idx][1] - self.landmarks[sh_idx][1]
        dz1 = np.sqrt(max(0, L_upper**2 - (dx1**2 + dy1**2)))
        el_3d = np.array([self.landmarks[el_idx][0], self.landmarks[el_idx][1], dz1])

        # 3. Lift Wrist (Relative to Elbow)
        L_fore = R_FOREARM * ref_len
        dx2 = self.landmarks[wr_idx][0] - self.landmarks[el_idx][0]
        dy2 = self.landmarks[wr_idx][1] - self.landmarks[el_idx][1]
        dz2 = np.sqrt(max(0, L_fore**2 - (dx2**2 + dy2**2)))
        # Z-coordinate is cumulative
        wr_3d = np.array([self.landmarks[wr_idx][0], self.landmarks[wr_idx][1], dz1 + dz2])

        return sh_3d, el_3d, wr_3d

    def __call__(self) -> float:
        """Return the mean left/right elbow angle in degrees (0–180).

        NOTE(review): get_angle divides by the product of segment norms —
        degenerate (coincident) keypoints would produce a divide-by-zero /
        NaN; callers appear to filter via the 0<=angle<=180 check downstream.
        """
        def get_angle(a, b, c) -> float:
            # Angle at vertex b formed by points a-b-c, via the dot-product formula.
            ba, bc = a - b, c - b
            cosine = np.dot(ba, bc) / (np.linalg.norm(ba) * np.linalg.norm(bc))
            return np.degrees(np.arccos(np.clip(cosine, -1.0, 1.0)))

        ref_len = self._ref_len()

        # Calculate elbow angle for both arms
        l_sh, l_el, l_wr = self._uplift_arm_to_3d(COCO.L_SHOULDER.value, COCO.L_ELBOW.value, COCO.L_WRIST.value, ref_len)
        r_sh, r_el, r_wr = self._uplift_arm_to_3d(COCO.R_SHOULDER.value, COCO.R_ELBOW.value, COCO.R_WRIST.value, ref_len)

        l_angle = get_angle(l_sh, l_el, l_wr)
        r_angle = get_angle(r_sh, r_el, r_wr)

        # Use average angle for robustness (handles side-on views better)
        avg_angle = (l_angle + r_angle) / 2
        return avg_angle
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_Fw6Rn1JisN3" + }, + "source": [ + "### Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "id": "XBukVn9qrnFZ", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "outputId": "e00098db-5f0a-4e92-b698-62930b543f2c" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading test data:\n", + "\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 77/77 [00:00<00:00, 73068.19it/s]" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "File already exists, skipping: 1_dksksjfwijf.mp4\n", + "File already exists, skipping: 2_dfsaeklnvvalkej.mp4\n", + "File already exists, skipping: 2_difficult_2.mp4\n", + "File already exists, skipping: 2_difficult_sdafkljsalkfj.mp4\n", + "File already exists, skipping: 2_dkdjwkndkfw.mp4\n", + "File already exists, skipping: 2_dkdmkejkeimdh.mp4\n", + "File already exists, skipping: 2_dkjd823kjf.mp4\n", + "File already exists, skipping: 2_dsalkfjalwkenlke.mp4\n", + "File already exists, skipping: 2_kling_20251205_Text_to_Video_On_a_sandy_4976_0.mp4\n", + "File already exists, skipping: 2_kling_20251206_Text_to_Video_Generate_a_71_1.mp4\n", + "File already exists, skipping: 2_sadfasjldkfjaseifj.mp4\n", + "File already exists, skipping: 2_sdafkjaslkclaksdjkas.mp4\n", + "File already exists, skipping: 2_sdfkjsaleijflaskdjf.mp4\n", + "File already exists, skipping: 2_sdjfhafsldkjhjk.mp4\n", + "File already exists, skipping: 2_sdkjdsflkjfwa.mp4\n", + "File already exists, skipping: 2_sdlfjlewlkjkj.mp4\n", + "File already exists, skipping: 2_sdlkjsaelijfksdjf.mp4\n", + "File already exists, skipping: 3_asldkfjalwieaskdfaskdf.mp4\n", + "File already exists, skipping: 3_dkk873lkjlksajdf.mp4\n", + "File already exists, skipping: 3_dsjlaeijlksjdfie.mp4\n", + "File already exists, skipping: 3_dsksdfjbvsdkj.mp4\n", + "File 
already exists, skipping: 3_dslkaldskjflakjs.mp4\n", + "File already exists, skipping: 3_ewdfkjwaeoihjlkasdjf.mp4\n", + "File already exists, skipping: 3_kling_20251205_Text_to_Video_In_a_grass_4697_0.mp4\n", + "File already exists, skipping: 3_kling_20251205_Text_to_Video_On_a_playg_5028_0.mp4\n", + "File already exists, skipping: 3_kling_20251205_Text_to_Video_On_a_playg_5064_0.mp4\n", + "File already exists, skipping: 3_kling_20251206_Text_to_Video_Generate_a_17_0.mp4\n", + "File already exists, skipping: 3_kling_20251206_Text_to_Video_Generate_a_315_0.mp4\n", + "File already exists, skipping: 3_kling_20251206_Text_to_Video_Generate_a_315_2.mp4\n", + "File already exists, skipping: 3_kling_20251206_Text_to_Video_Generate_a_712_3.mp4\n", + "File already exists, skipping: 3_kling_20251206_Text_to_Video_Generate_a_71_0.mp4\n", + "File already exists, skipping: 3_kling_20251206_Text_to_Video_Generate_a_71_2.mp4\n", + "File already exists, skipping: 3_kling_20251206_Text_to_Video_Generate_a_71_3.mp4\n", + "File already exists, skipping: 3_kling_20251209_Image_to_Video_Generate_a_613_1.mp4\n", + "File already exists, skipping: 3_kling_20251209_Image_to_Video_Generate_a_635_0.mp4\n", + "File already exists, skipping: 3_kling_20251209_Text_to_Video_Generate_a_190_1.mp4\n", + "File already exists, skipping: 3_kling_20251209_Text_to_Video_Generate_a_403_1.mp4\n", + "File already exists, skipping: 3_kling_20251209_Text_to_Video_Generate_a_491_0.mp4\n", + "File already exists, skipping: 3_kling_20251209_Text_to_Video_Generate_a_491_1.mp4\n", + "File already exists, skipping: 3_kling_20251209_Text_to_Video_Generate_a_491_2.mp4\n", + "File already exists, skipping: 3_kling_dskfseu.mp4\n", + "File already exists, skipping: 3_kling_kdjflaskdjf.mp4\n", + "File already exists, skipping: 3_sadklfjasbnlkjlfkj.mp4\n", + "File already exists, skipping: 3_sadlfkjasldkfjasleijlkjfd.mp4\n", + "File already exists, skipping: 3_sadlfkjawelnflksdjf.mp4\n", + "File already exists, skipping: 
3_sdfjwaiejflkasjdf.mp4\n", + "File already exists, skipping: 3_sdflkjliejkjdf.mp4\n", + "File already exists, skipping: 3_sdlkfjaleknaksej.mp4\n", + "File already exists, skipping: 3_sdlkfjalkjejafe.mp4\n", + "File already exists, skipping: 3_sdlkjfaslkjfalskjdf.mp4\n", + "File already exists, skipping: 3_sdlkjslndflkseijlkjef.mp4\n", + "File already exists, skipping: 4_20251209_Text_to_Video_Generate_a_561_0.mp4\n", + "File already exists, skipping: 4_asdlkfjalsflnekj.mp4\n", + "File already exists, skipping: 4_aslkcasckmwlejk.mp4\n", + "File already exists, skipping: 4_aslkjasmcalkewjlkje.mp4\n", + "File already exists, skipping: 4_dssalsdkfjweijf.mp4\n", + "File already exists, skipping: 4_kling_20251206_Text_to_Video_Generate_a_28_0.mp4\n", + "File already exists, skipping: 4_kling_20251206_Text_to_Video_Generate_a_315_3.mp4\n", + "File already exists, skipping: 4_kling_20251206_Text_to_Video_Generate_a_58_0.mp4\n", + "File already exists, skipping: 4_kling_20251207_Text_to_Video_Generate_a_521_1.mp4\n", + "File already exists, skipping: 4_kling_20251209_Image_to_Video_Generate_a_635_1.mp4\n", + "File already exists, skipping: 4_kling_20251209_Text_to_Video_Generate_a_190_0.mp4\n", + "File already exists, skipping: 4_kling_20251209_Text_to_Video_Generate_a_218_0.mp4\n", + "File already exists, skipping: 4_kling_20251209_Text_to_Video_Generate_a_263_1.mp4\n", + "File already exists, skipping: 4_kling_20251209_Text_to_Video_Generate_a_377_1.mp4\n", + "File already exists, skipping: 4_kling_20251209_Text_to_Video_Generate_a_452_0.mp4\n", + "File already exists, skipping: 4_kling_20251209_Text_to_Video_Generate_a_452_1.mp4\n", + "File already exists, skipping: 4_kling_20251209_Text_to_Video_Generate_a_561_1.mp4\n", + "File already exists, skipping: 4_kling_20251209_Text_to_Video_Generate_a_588_2.mp4\n", + "File already exists, skipping: 4_pushup_1f2da596-7619-4d55-9376-069e15a42a1a_h264.mp4\n", + "File already exists, skipping: 4_sadflkjasldkjfalseij.mp4\n", + 
# =============================================================================
# DOWNLOAD TEST DATA FROM S3
# =============================================================================
def download_test_data(bucket_name='training-and-validation-data', download_dir=DOWNLOAD_DIR):
    """Download all test videos from the public S3 bucket into download_dir.

    Files already present locally are skipped, so re-running is cheap.

    Args:
        bucket_name: kept for backward compatibility but overridden below —
            the assignment data lives at a fixed bucket/prefix.
        download_dir: local directory to populate (created if missing).
    Returns:
        download_dir (so callers can chain on the path).
    """
    # Anonymous (unsigned) client: the bucket is publicly readable.
    s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))

    bucket_name = 'prism-mvta'
    prefix = 'training-and-validation-data/'

    os.makedirs(download_dir, exist_ok=True)

    paginator = s3.get_paginator('list_objects_v2')
    pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix)

    video_names = []

    # Fix: print the banner once, not once per result page.
    print("Downloading test data:\n")
    for page in pages:
        if 'Contents' not in page:
            print("No files found at the specified path!")
            break

        for obj in tqdm(page['Contents']):
            key = obj['Key']
            filename = os.path.basename(key)

            # The prefix itself lists as a key with an empty basename — skip it.
            if not filename:
                continue

            video_names.append(filename)
            local_path = os.path.join(download_dir, filename)
            if os.path.exists(local_path):
                # Fix: interpolate the actual filename (message previously
                # printed a literal placeholder).
                print(f"File already exists, skipping: {filename}")
                continue
            s3.download_file(bucket_name, key, local_path)

    print(f"\nDownloaded {len(video_names)} test videos")
    return download_dir

download_test_data()
def collate_fn(batch):
    """Collate (angles, label, length) triples into padded batch tensors.

    Each angle sequence is flattened, then zero-padded (or truncated) to
    exactly MAX_FRAMES samples so sequences can be stacked.

    Returns:
        (angles_batch, labels_batch, lengths_batch): angles_batch is
        (B, MAX_FRAMES) float32; lengths_batch holds true lengths, clamped to
        MAX_FRAMES so pack_padded_sequence never sees a length longer than
        the padded tensor (previously a truncated sequence kept its original
        length and would crash downstream).
    """
    angles_list, labels, lengths = zip(*batch)

    padded_angles = []
    for angle_tensor, current_length in zip(angles_list, lengths):
        angle_tensor = angle_tensor.flatten()
        if current_length < MAX_FRAMES:
            # Pad with zeros at the end (fix: removed a redundant second
            # .flatten() inside the cat — the tensor is already flat).
            padding = torch.zeros(MAX_FRAMES - current_length, dtype=torch.float32)
            angles = torch.cat([angle_tensor, padding], dim=0)
        elif current_length > MAX_FRAMES:
            # Truncate if longer
            angles = angle_tensor[:MAX_FRAMES]
        else:
            angles = angle_tensor
        padded_angles.append(angles)

    angles_batch = torch.stack(padded_angles, dim=0)
    labels_batch = torch.tensor(labels)
    # Clamp so reported lengths never exceed the padded sequence length.
    lengths_batch = torch.tensor(lengths).clamp(max=MAX_FRAMES)

    return angles_batch, labels_batch, lengths_batch


def get_dataloaders(video_dir, batch_size=4, val_split=0.2, transform=None, feature_dir=None, signal_transform=None):
    """Create train and validation dataloaders with a fixed 80/20 random split.

    The split generator is seeded (42) so the partition is reproducible.
    """
    full_dataset = VideoDataset(video_dir, transform=transform, feature_dir=feature_dir, signal_transform=signal_transform)

    val_size = int(len(full_dataset) * val_split)
    train_size = len(full_dataset) - val_size

    train_dataset, val_dataset = random_split(
        full_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,  # 0 avoids multiprocessing issues with the YOLO model
        collate_fn=collate_fn
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_fn
    )

    print(f"Train: {len(train_dataset)} videos, Val: {len(val_dataset)} videos\n")

    return train_loader, val_loader


def get_balanced_dataloaders(video_dir, batch_size=4, val_split=0.2, transform=None, feature_dir=None, signal_transform=None):
    """Create balanced train and validation dataloaders using WeightedRandomSampler.

    Rare push-up counts are oversampled (with replacement) so every count is
    seen roughly equally often during training; validation is left untouched.
    """
    # 1. Initialize the full dataset
    full_dataset = VideoDataset(video_dir, transform=transform, feature_dir=feature_dir, signal_transform=signal_transform)

    # 2. Manual index split (to keep track of labels for the sampler)
    dataset_size = len(full_dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(val_split * dataset_size))

    # Shuffle indices for a random split. NOTE(review): this seeds the GLOBAL
    # numpy RNG as a side effect — confirm nothing downstream relies on it.
    np.random.seed(42)
    np.random.shuffle(indices)

    train_indices, val_indices = indices[split:], indices[:split]

    # 3. Calculate weights for the training set ONLY
    train_labels = [full_dataset.labels[i] for i in train_indices]
    train_labels = np.array(train_labels).astype(int)

    # Count occurrences of each push-up count (class)
    class_sample_count = np.bincount(train_labels)
    # Avoid division by zero for classes absent from the training subset
    class_sample_count[class_sample_count == 0] = 1

    weight = 1. / class_sample_count

    # Assign a weight to every sample in the training set
    samples_weight = torch.from_numpy(weight[train_labels]).double()

    # 4. Create the sampler.
    # replacement=True is required to oversample rare classes (like 1s, 6s, 7s)
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight), replacement=True)

    # 5. Create subsets and dataloaders
    train_dataset = Subset(full_dataset, train_indices)
    val_dataset = Subset(full_dataset, val_indices)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=sampler,  # shuffle must be False/omitted when using a sampler
        num_workers=0,
        collate_fn=collate_fn
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_fn
    )

    print(f"Dataset Balanced: Oversampling rare labels to match common labels.")
    print(f"Train: {len(train_dataset)} videos, Val: {len(val_dataset)} videos\n")

    return train_loader, val_loader


# # Run pre-computation once. This will take time but saves hours during training.
# def precompute_features(video_dir, output_dir, transform=None):
#     os.makedirs(output_dir, exist_ok=True)
#     # Create a temporary dataset without transforms for extraction
#     ds = VideoDataset(video_dir, transform=transform)
#     for i in range(len(ds)):
#         filename = ds.video_files[i]
#         save_path = os.path.join(output_dir, filename + ".pt")
#         if os.path.exists(save_path):
#             continue
#         print(f"Extracting features for {filename} ({i+1}/{len(ds)})...")
#         angles, label, length = ds[i]
#         torch.save({'angles': angles, 'length': length, 'label': label}, save_path)

# precompute_features(DOWNLOAD_DIR, CACHE_DIR, RESIZER)


train_loader, val_loader = get_dataloaders(  # or: get_balanced_dataloaders
    DOWNLOAD_DIR,
    batch_size=4,
    val_split=0.2,
    transform=RESIZER,
    feature_dir=CACHE_DIR,
    signal_transform=None
)
# =============================================================================
# RUN INFERENCE
# =============================================================================

def run_inference(model):
    """Run the trained counter over every test video and report accuracy/latency.

    Args:
        model: an instantiated PushupCounterSingleOutput with loaded weights.
    Returns:
        (accuracy, preds, labels) — preds rounded to integer counts.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Test data was already downloaded into DOWNLOAD_DIR above.
    test_dir = DOWNLOAD_DIR

    model = model.to(device)

    # Create dataloader (batch_size=1 so per-video latency can be measured)
    test_dataset = VideoDataset(test_dir)
    test_loader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_fn
    )

    print(f"\nRunning inference on {len(test_dataset)} test videos...")

    # Warmup (optional, helps get consistent GPU timings): the first CUDA
    # call pays kernel/context init costs that would pollute the first video's
    # measured latency.
    if device.type == 'cuda':
        dummy_data = torch.randn(1, MAX_FRAMES, device=device)  # (Batch, SeqLen)
        dummy_lengths = torch.tensor([MAX_FRAMES], dtype=torch.long, device=device)  # (Batch,)
        with torch.no_grad():
            _ = model(dummy_data, dummy_lengths)
        torch.cuda.synchronize()

    total_start = time.time()
    accuracy, preds, labels, times = evaluate(model, test_loader, test_dataset, device)
    total_end = time.time()

    # Summary — evaluate() returns raw regression outputs; round to int counts.
    preds = np.round(preds).astype(int)
    num_correct = sum(p == l for p, l in zip(preds, labels))  # fix: preds already int, drop redundant astype
    num_wrong = len(preds) - num_correct

    print("\n" + "=" * 50)
    print("SUMMARY")
    print("=" * 50)
    print(f"Total videos: {len(preds)}")
    print(f"Correct: {num_correct}")
    print(f"Incorrect: {num_wrong}")
    print(f"")
    print(f"ACCURACY: {accuracy*100:.2f}%")
    print(f"")
    print(f"Total time: {total_end - total_start:.2f}s")
    # Guard: an empty test set would otherwise raise on the division/min/max.
    if times:
        print(f"Avg per video: {sum(times) / len(times):.1f}ms")
        print(f"Min latency: {min(times):.1f}ms")
        print(f"Max latency: {max(times):.1f}ms")
    print("=" * 50)
    return accuracy, preds, labels

_, _, _ = run_inference(model)
"name": "python3" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "96b1904645b74a9186787e146019e4da": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_79e68ccbe7e14868bc1b621d9c2c3bcb", + "IPY_MODEL_b0dd851eb3cd48c29f2eaf0eab7c3aa8", + "IPY_MODEL_950ea68d2f8e492890a4e901ff38293d" + ], + "layout": "IPY_MODEL_29f2223b7d2b40a8a07ff058d0b40b58" + } + }, + "79e68ccbe7e14868bc1b621d9c2c3bcb": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_05bec595f2db4ec8b9481ad7b4623b58", + "placeholder": "​", + "style": "IPY_MODEL_bfd5b83934a74596a02cb18c7ec46451", + "value": "./yolo-weights/yolov8s-pose.pt: 100%" + } + }, + "b0dd851eb3cd48c29f2eaf0eab7c3aa8": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": 
"IPY_MODEL_61bf2f24595b4619852d5f2286b1f67e", + "max": 23513657, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_cf1a3dc687e8454a89ffd73519ab5f5d", + "value": 23513657 + } + }, + "950ea68d2f8e492890a4e901ff38293d": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a0ce61edaee2404ba6c19818f7956e7b", + "placeholder": "​", + "style": "IPY_MODEL_bc18c6adf87e487ab48345ce2a89fe9f", + "value": " 23.5M/23.5M [00:05<00:00, 5.34MB/s]" + } + }, + "29f2223b7d2b40a8a07ff058d0b40b58": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, 
+ "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "05bec595f2db4ec8b9481ad7b4623b58": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "bfd5b83934a74596a02cb18c7ec46451": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "61bf2f24595b4619852d5f2286b1f67e": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + 
"_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "cf1a3dc687e8454a89ffd73519ab5f5d": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "a0ce61edaee2404ba6c19818f7956e7b": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + 
"align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "bc18c6adf87e487ab48345ce2a89fe9f": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "0ff1e7ba711c49328b27f605a7db8976": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_cbb9bdc774454aa595fb566a0fb5eaf6", + "IPY_MODEL_8107c233e5254eac8b6b2be9d4430682", + "IPY_MODEL_35445b91cb07407aa2146a5f4b45ab75" + ], + "layout": "IPY_MODEL_1241c4cfb97c4a93a7ca506f8b293dd6" + } + }, + "cbb9bdc774454aa595fb566a0fb5eaf6": 
{ + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_16a558aee1094a40a3a2a0c979aa42d7", + "placeholder": "​", + "style": "IPY_MODEL_a2cb6489aa3c4dcebbfa0296f5b11da6", + "value": "model-weights/final_model_weights_Pushup(…): 100%" + } + }, + "8107c233e5254eac8b6b2be9d4430682": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_424d904e40274016a4b75dad1d0a6c37", + "max": 530345, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_01b7c22856774189bd94ecbca7a4b9cd", + "value": 530345 + } + }, + "35445b91cb07407aa2146a5f4b45ab75": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_72ac9afc25814e79927704caa9c8ef12", + "placeholder": "​", + "style": 
"IPY_MODEL_aa0aa9d015ee414d81fe78b1b3c75b75", + "value": " 530k/530k [00:02<00:00, 232kB/s]" + } + }, + "1241c4cfb97c4a93a7ca506f8b293dd6": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "16a558aee1094a40a3a2a0c979aa42d7": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + 
"grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a2cb6489aa3c4dcebbfa0296f5b11da6": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "424d904e40274016a4b75dad1d0a6c37": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + 
"max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "01b7c22856774189bd94ecbca7a4b9cd": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "72ac9afc25814e79927704caa9c8ef12": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + 
"visibility": null, + "width": null + } + }, + "aa0aa9d015ee414d81fe78b1b3c75b75": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file