Spaces:
Sleeping
Sleeping
fixed model
Browse files- Dockerfile +0 -14
- README.md +6 -0
- requirements-dev.txt +0 -162
- requirements.txt +41 -1
- src/toy_duration_predictor/__init__.py +3 -2
- src/toy_duration_predictor/preprocess/__init__.py +1 -0
- src/toy_duration_predictor/preprocess/mssv.py +1 -1
- src/toy_duration_predictor/train.py +4 -4
- test.ipynb +47 -485
Dockerfile
DELETED
|
@@ -1,14 +0,0 @@
|
|
| 1 |
-
FROM python:3.12-slim
|
| 2 |
-
# Set the working directory
|
| 3 |
-
# All subsequent commands will run from here.
|
| 4 |
-
WORKDIR /app
|
| 5 |
-
RUN apt-get update && apt-get upgrade -y && rm -rf /var/lib/apt/lists/*
|
| 6 |
-
# Copy project files into the container.
|
| 7 |
-
# This copies everything from your local directory into the container's /app directory.
|
| 8 |
-
COPY . .
|
| 9 |
-
# Install all the Python packages listed in your requirements.txt.
|
| 10 |
-
# The --no-cache-dir flag keeps the final image size smaller.
|
| 11 |
-
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
| 12 |
-
# Define the default command to run when the Space starts.
|
| 13 |
-
# This launches a JupyterLab server that is accessible from the web.
|
| 14 |
-
CMD ["jupyter", "lab", "--ip=0.0.0.0", "--port=7860", "--allow-root", "--NotebookApp.token=''"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -14,3 +14,9 @@ pip install git+https://github.com/ccss17/toy-duration-predictor.git
|
|
| 14 |
```python
|
| 15 |
import toy_duration_predictor as tdp
|
| 16 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
```python
|
| 15 |
import toy_duration_predictor as tdp
|
| 16 |
```
|
| 17 |
+
|
| 18 |
+
## Hugging Face
|
| 19 |
+
|
| 20 |
+
- Model: https://huggingface.co/ccss17/toy-duration-predictor
|
| 21 |
+
- Dataset: https://huggingface.co/datasets/ccss17/note-duration-dataset
|
| 22 |
+
- Spaces: https://huggingface.co/spaces/ccss17/toy-duration-predictor
|
requirements-dev.txt
DELETED
|
@@ -1,162 +0,0 @@
|
|
| 1 |
-
-e .
|
| 2 |
-
aiohappyeyeballs==2.6.1
|
| 3 |
-
aiohttp==3.12.13
|
| 4 |
-
aiohttp-cors==0.8.1
|
| 5 |
-
aiosignal==1.3.2
|
| 6 |
-
annotated-types==0.7.0
|
| 7 |
-
anyio==4.9.0
|
| 8 |
-
appnope==0.1.4 ; sys_platform == 'darwin'
|
| 9 |
-
asttokens==3.0.0
|
| 10 |
-
async-timeout==5.0.1 ; python_full_version < '3.11'
|
| 11 |
-
attrs==25.3.0
|
| 12 |
-
black==25.1.0
|
| 13 |
-
cachetools==5.5.2
|
| 14 |
-
certifi==2025.6.15
|
| 15 |
-
cffi==1.17.1 ; implementation_name == 'pypy'
|
| 16 |
-
charset-normalizer==3.4.2
|
| 17 |
-
click==8.1.8 ; python_full_version < '3.10'
|
| 18 |
-
click==8.2.1 ; python_full_version >= '3.10'
|
| 19 |
-
colorama==0.4.6 ; sys_platform == 'win32'
|
| 20 |
-
colorful==0.5.6
|
| 21 |
-
comm==0.2.2
|
| 22 |
-
datasets==3.6.0
|
| 23 |
-
debugpy==1.8.14
|
| 24 |
-
decorator==5.2.1
|
| 25 |
-
dill==0.3.8
|
| 26 |
-
distlib==0.3.9
|
| 27 |
-
exceptiongroup==1.3.0 ; python_full_version < '3.11'
|
| 28 |
-
executing==2.2.0
|
| 29 |
-
fastapi==0.115.13
|
| 30 |
-
filelock==3.18.0
|
| 31 |
-
frozenlist==1.7.0
|
| 32 |
-
fsspec==2025.3.0
|
| 33 |
-
google-api-core==2.25.1
|
| 34 |
-
google-auth==2.40.3
|
| 35 |
-
googleapis-common-protos==1.70.0
|
| 36 |
-
grpcio==1.73.0
|
| 37 |
-
h11==0.16.0
|
| 38 |
-
hf-xet==1.1.5 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
|
| 39 |
-
httptools==0.6.4
|
| 40 |
-
huggingface-hub==0.33.0
|
| 41 |
-
idna==3.10
|
| 42 |
-
importlib-metadata==8.7.0
|
| 43 |
-
ipykernel==6.29.5
|
| 44 |
-
ipython==8.18.1 ; python_full_version < '3.10'
|
| 45 |
-
ipython==8.37.0 ; python_full_version == '3.10.*'
|
| 46 |
-
ipython==9.3.0 ; python_full_version >= '3.11'
|
| 47 |
-
ipython-pygments-lexers==1.1.1 ; python_full_version >= '3.11'
|
| 48 |
-
ipywidgets==8.1.7
|
| 49 |
-
jedi==0.19.2
|
| 50 |
-
jinja2==3.1.6
|
| 51 |
-
jsonschema==4.24.0
|
| 52 |
-
jsonschema-specifications==2025.4.1
|
| 53 |
-
jupyter-client==8.6.3
|
| 54 |
-
jupyter-core==5.8.1
|
| 55 |
-
jupyterlab-widgets==3.0.15
|
| 56 |
-
llvmlite==0.44.0 ; python_full_version >= '3.10'
|
| 57 |
-
markdown-it-py==3.0.0
|
| 58 |
-
markupsafe==2.1.5 ; python_full_version < '3.10'
|
| 59 |
-
markupsafe==3.0.2 ; python_full_version >= '3.10'
|
| 60 |
-
matplotlib-inline==0.1.7
|
| 61 |
-
mdurl==0.1.2
|
| 62 |
-
midii==0.1.19 ; python_full_version < '3.10'
|
| 63 |
-
midii==0.1.41 ; python_full_version >= '3.10'
|
| 64 |
-
mido==1.3.3
|
| 65 |
-
mpmath==1.3.0
|
| 66 |
-
msgpack==1.1.1
|
| 67 |
-
multidict==6.5.0
|
| 68 |
-
multiprocess==0.70.16
|
| 69 |
-
mypy-extensions==1.1.0
|
| 70 |
-
nest-asyncio==1.6.0
|
| 71 |
-
networkx==3.2.1 ; python_full_version < '3.10'
|
| 72 |
-
networkx==3.4.2 ; python_full_version == '3.10.*'
|
| 73 |
-
networkx==3.5 ; python_full_version >= '3.11'
|
| 74 |
-
numba==0.61.2 ; python_full_version >= '3.10'
|
| 75 |
-
numpy==2.0.2 ; python_full_version < '3.10'
|
| 76 |
-
numpy==2.2.6 ; python_full_version >= '3.10'
|
| 77 |
-
nvidia-cublas-cu12==12.6.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 78 |
-
nvidia-cuda-cupti-cu12==12.6.80 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 79 |
-
nvidia-cuda-nvrtc-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 80 |
-
nvidia-cuda-runtime-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 81 |
-
nvidia-cudnn-cu12==9.5.1.17 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 82 |
-
nvidia-cufft-cu12==11.3.0.4 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 83 |
-
nvidia-cufile-cu12==1.11.1.6 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 84 |
-
nvidia-curand-cu12==10.3.7.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 85 |
-
nvidia-cusolver-cu12==11.7.1.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 86 |
-
nvidia-cusparse-cu12==12.5.4.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 87 |
-
nvidia-cusparselt-cu12==0.6.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 88 |
-
nvidia-nccl-cu12==2.26.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 89 |
-
nvidia-nvjitlink-cu12==12.6.85 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 90 |
-
nvidia-nvtx-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 91 |
-
opencensus==0.11.4
|
| 92 |
-
opencensus-context==0.1.3
|
| 93 |
-
opentelemetry-api==1.34.1
|
| 94 |
-
opentelemetry-exporter-prometheus==0.55b1
|
| 95 |
-
opentelemetry-proto==1.34.1
|
| 96 |
-
opentelemetry-sdk==1.34.1
|
| 97 |
-
opentelemetry-semantic-conventions==0.55b1
|
| 98 |
-
packaging==25.0
|
| 99 |
-
pandas==2.3.0
|
| 100 |
-
parso==0.8.4
|
| 101 |
-
pathspec==0.12.1
|
| 102 |
-
pexpect==4.9.0 ; (python_full_version < '3.10' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32')
|
| 103 |
-
platformdirs==4.3.8
|
| 104 |
-
prometheus-client==0.22.1
|
| 105 |
-
prompt-toolkit==3.0.51
|
| 106 |
-
propcache==0.3.2
|
| 107 |
-
proto-plus==1.26.1
|
| 108 |
-
protobuf==5.29.5
|
| 109 |
-
psutil==7.0.0
|
| 110 |
-
ptyprocess==0.7.0 ; (python_full_version < '3.10' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32')
|
| 111 |
-
pure-eval==0.2.3
|
| 112 |
-
py-spy==0.4.0
|
| 113 |
-
pyarrow==17.0.0
|
| 114 |
-
pyasn1==0.6.1
|
| 115 |
-
pyasn1-modules==0.4.2
|
| 116 |
-
pycparser==2.22 ; implementation_name == 'pypy'
|
| 117 |
-
pydantic==2.11.7
|
| 118 |
-
pydantic-core==2.33.2
|
| 119 |
-
pygments==2.19.2
|
| 120 |
-
python-dateutil==2.9.0.post0
|
| 121 |
-
python-dotenv==1.1.1
|
| 122 |
-
pytz==2025.2
|
| 123 |
-
pywin32==310 ; platform_python_implementation != 'PyPy' and sys_platform == 'win32'
|
| 124 |
-
pyyaml==6.0.2
|
| 125 |
-
pyzmq==27.0.0
|
| 126 |
-
ray==2.47.1
|
| 127 |
-
referencing==0.36.2
|
| 128 |
-
requests==2.32.4
|
| 129 |
-
rich==14.0.0
|
| 130 |
-
rpds-py==0.25.1
|
| 131 |
-
rsa==4.9.1
|
| 132 |
-
ruff==0.12.0
|
| 133 |
-
setuptools==80.9.0 ; (python_full_version >= '3.12' and platform_machine != 'x86_64') or (python_full_version >= '3.12' and sys_platform != 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')
|
| 134 |
-
six==1.17.0
|
| 135 |
-
smart-open==7.1.0
|
| 136 |
-
sniffio==1.3.1
|
| 137 |
-
stack-data==0.6.3
|
| 138 |
-
starlette==0.46.2
|
| 139 |
-
sympy==1.14.0
|
| 140 |
-
tensorboardx==2.6.4
|
| 141 |
-
tomli==2.2.1 ; python_full_version < '3.11'
|
| 142 |
-
torch==2.7.1
|
| 143 |
-
tornado==6.5.1
|
| 144 |
-
tqdm==4.67.1
|
| 145 |
-
traitlets==5.14.3
|
| 146 |
-
triton==3.3.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 147 |
-
typing-extensions==4.14.0
|
| 148 |
-
typing-inspection==0.4.1
|
| 149 |
-
tzdata==2025.2
|
| 150 |
-
urllib3==2.5.0
|
| 151 |
-
uvicorn==0.34.3
|
| 152 |
-
uvloop==0.21.0 ; platform_python_implementation != 'PyPy' and sys_platform != 'cygwin' and sys_platform != 'win32'
|
| 153 |
-
virtualenv==20.31.2
|
| 154 |
-
watchfiles==1.1.0
|
| 155 |
-
wcwidth==0.2.13
|
| 156 |
-
websockets==12.0 ; python_full_version < '3.10'
|
| 157 |
-
websockets==15.0.1 ; python_full_version >= '3.10'
|
| 158 |
-
widgetsnbextension==4.0.14
|
| 159 |
-
wrapt==1.17.2
|
| 160 |
-
xxhash==3.5.0
|
| 161 |
-
yarl==1.20.1
|
| 162 |
-
zipp==3.23.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,23 +1,30 @@
|
|
| 1 |
-
-e .
|
| 2 |
aiohappyeyeballs==2.6.1
|
| 3 |
aiohttp==3.12.13
|
| 4 |
aiohttp-cors==0.8.1
|
| 5 |
aiosignal==1.3.2
|
| 6 |
annotated-types==0.7.0
|
| 7 |
anyio==4.9.0
|
|
|
|
|
|
|
| 8 |
async-timeout==5.0.1 ; python_full_version < '3.11'
|
| 9 |
attrs==25.3.0
|
|
|
|
| 10 |
cachetools==5.5.2
|
| 11 |
certifi==2025.6.15
|
|
|
|
| 12 |
charset-normalizer==3.4.2
|
| 13 |
click==8.1.8 ; python_full_version < '3.10'
|
| 14 |
click==8.2.1 ; python_full_version >= '3.10'
|
| 15 |
colorama==0.4.6 ; sys_platform == 'win32'
|
| 16 |
colorful==0.5.6
|
|
|
|
| 17 |
datasets==3.6.0
|
|
|
|
|
|
|
| 18 |
dill==0.3.8
|
| 19 |
distlib==0.3.9
|
| 20 |
exceptiongroup==1.3.0 ; python_full_version < '3.11'
|
|
|
|
| 21 |
fastapi==0.115.13
|
| 22 |
filelock==3.18.0
|
| 23 |
frozenlist==1.7.0
|
|
@@ -32,13 +39,24 @@ httptools==0.6.4
|
|
| 32 |
huggingface-hub==0.33.0
|
| 33 |
idna==3.10
|
| 34 |
importlib-metadata==8.7.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
jinja2==3.1.6
|
| 36 |
jsonschema==4.24.0
|
| 37 |
jsonschema-specifications==2025.4.1
|
|
|
|
|
|
|
|
|
|
| 38 |
llvmlite==0.44.0 ; python_full_version >= '3.10'
|
| 39 |
markdown-it-py==3.0.0
|
| 40 |
markupsafe==2.1.5 ; python_full_version < '3.10'
|
| 41 |
markupsafe==3.0.2 ; python_full_version >= '3.10'
|
|
|
|
| 42 |
mdurl==0.1.2
|
| 43 |
midii==0.1.19 ; python_full_version < '3.10'
|
| 44 |
midii==0.1.41 ; python_full_version >= '3.10'
|
|
@@ -47,6 +65,8 @@ mpmath==1.3.0
|
|
| 47 |
msgpack==1.1.1
|
| 48 |
multidict==6.5.0
|
| 49 |
multiprocess==0.70.16
|
|
|
|
|
|
|
| 50 |
networkx==3.2.1 ; python_full_version < '3.10'
|
| 51 |
networkx==3.4.2 ; python_full_version == '3.10.*'
|
| 52 |
networkx==3.5 ; python_full_version >= '3.11'
|
|
@@ -76,37 +96,52 @@ opentelemetry-sdk==1.34.1
|
|
| 76 |
opentelemetry-semantic-conventions==0.55b1
|
| 77 |
packaging==25.0
|
| 78 |
pandas==2.3.0
|
|
|
|
|
|
|
|
|
|
| 79 |
platformdirs==4.3.8
|
| 80 |
prometheus-client==0.22.1
|
|
|
|
| 81 |
propcache==0.3.2
|
| 82 |
proto-plus==1.26.1
|
| 83 |
protobuf==5.29.5
|
|
|
|
|
|
|
|
|
|
| 84 |
py-spy==0.4.0
|
| 85 |
pyarrow==17.0.0
|
| 86 |
pyasn1==0.6.1
|
| 87 |
pyasn1-modules==0.4.2
|
|
|
|
| 88 |
pydantic==2.11.7
|
| 89 |
pydantic-core==2.33.2
|
| 90 |
pygments==2.19.2
|
| 91 |
python-dateutil==2.9.0.post0
|
| 92 |
python-dotenv==1.1.1
|
| 93 |
pytz==2025.2
|
|
|
|
| 94 |
pyyaml==6.0.2
|
|
|
|
| 95 |
ray==2.47.1
|
| 96 |
referencing==0.36.2
|
| 97 |
requests==2.32.4
|
| 98 |
rich==14.0.0
|
| 99 |
rpds-py==0.25.1
|
| 100 |
rsa==4.9.1
|
|
|
|
| 101 |
setuptools==80.9.0 ; (python_full_version >= '3.12' and platform_machine != 'x86_64') or (python_full_version >= '3.12' and sys_platform != 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')
|
| 102 |
six==1.17.0
|
| 103 |
smart-open==7.1.0
|
| 104 |
sniffio==1.3.1
|
|
|
|
| 105 |
starlette==0.46.2
|
| 106 |
sympy==1.14.0
|
| 107 |
tensorboardx==2.6.4
|
|
|
|
| 108 |
torch==2.7.1
|
|
|
|
| 109 |
tqdm==4.67.1
|
|
|
|
| 110 |
triton==3.3.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 111 |
typing-extensions==4.14.0
|
| 112 |
typing-inspection==0.4.1
|
|
@@ -116,9 +151,14 @@ uvicorn==0.34.3
|
|
| 116 |
uvloop==0.21.0 ; platform_python_implementation != 'PyPy' and sys_platform != 'cygwin' and sys_platform != 'win32'
|
| 117 |
virtualenv==20.31.2
|
| 118 |
watchfiles==1.1.0
|
|
|
|
| 119 |
websockets==12.0 ; python_full_version < '3.10'
|
| 120 |
websockets==15.0.1 ; python_full_version >= '3.10'
|
|
|
|
| 121 |
wrapt==1.17.2
|
| 122 |
xxhash==3.5.0
|
| 123 |
yarl==1.20.1
|
| 124 |
zipp==3.23.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
aiohappyeyeballs==2.6.1
|
| 2 |
aiohttp==3.12.13
|
| 3 |
aiohttp-cors==0.8.1
|
| 4 |
aiosignal==1.3.2
|
| 5 |
annotated-types==0.7.0
|
| 6 |
anyio==4.9.0
|
| 7 |
+
appnope==0.1.4 ; sys_platform == 'darwin'
|
| 8 |
+
asttokens==3.0.0
|
| 9 |
async-timeout==5.0.1 ; python_full_version < '3.11'
|
| 10 |
attrs==25.3.0
|
| 11 |
+
black==25.1.0
|
| 12 |
cachetools==5.5.2
|
| 13 |
certifi==2025.6.15
|
| 14 |
+
cffi==1.17.1 ; implementation_name == 'pypy'
|
| 15 |
charset-normalizer==3.4.2
|
| 16 |
click==8.1.8 ; python_full_version < '3.10'
|
| 17 |
click==8.2.1 ; python_full_version >= '3.10'
|
| 18 |
colorama==0.4.6 ; sys_platform == 'win32'
|
| 19 |
colorful==0.5.6
|
| 20 |
+
comm==0.2.2
|
| 21 |
datasets==3.6.0
|
| 22 |
+
debugpy==1.8.14
|
| 23 |
+
decorator==5.2.1
|
| 24 |
dill==0.3.8
|
| 25 |
distlib==0.3.9
|
| 26 |
exceptiongroup==1.3.0 ; python_full_version < '3.11'
|
| 27 |
+
executing==2.2.0
|
| 28 |
fastapi==0.115.13
|
| 29 |
filelock==3.18.0
|
| 30 |
frozenlist==1.7.0
|
|
|
|
| 39 |
huggingface-hub==0.33.0
|
| 40 |
idna==3.10
|
| 41 |
importlib-metadata==8.7.0
|
| 42 |
+
ipykernel==6.29.5
|
| 43 |
+
ipython==8.18.1 ; python_full_version < '3.10'
|
| 44 |
+
ipython==8.37.0 ; python_full_version == '3.10.*'
|
| 45 |
+
ipython==9.3.0 ; python_full_version >= '3.11'
|
| 46 |
+
ipython-pygments-lexers==1.1.1 ; python_full_version >= '3.11'
|
| 47 |
+
ipywidgets==8.1.7
|
| 48 |
+
jedi==0.19.2
|
| 49 |
jinja2==3.1.6
|
| 50 |
jsonschema==4.24.0
|
| 51 |
jsonschema-specifications==2025.4.1
|
| 52 |
+
jupyter-client==8.6.3
|
| 53 |
+
jupyter-core==5.8.1
|
| 54 |
+
jupyterlab-widgets==3.0.15
|
| 55 |
llvmlite==0.44.0 ; python_full_version >= '3.10'
|
| 56 |
markdown-it-py==3.0.0
|
| 57 |
markupsafe==2.1.5 ; python_full_version < '3.10'
|
| 58 |
markupsafe==3.0.2 ; python_full_version >= '3.10'
|
| 59 |
+
matplotlib-inline==0.1.7
|
| 60 |
mdurl==0.1.2
|
| 61 |
midii==0.1.19 ; python_full_version < '3.10'
|
| 62 |
midii==0.1.41 ; python_full_version >= '3.10'
|
|
|
|
| 65 |
msgpack==1.1.1
|
| 66 |
multidict==6.5.0
|
| 67 |
multiprocess==0.70.16
|
| 68 |
+
mypy-extensions==1.1.0
|
| 69 |
+
nest-asyncio==1.6.0
|
| 70 |
networkx==3.2.1 ; python_full_version < '3.10'
|
| 71 |
networkx==3.4.2 ; python_full_version == '3.10.*'
|
| 72 |
networkx==3.5 ; python_full_version >= '3.11'
|
|
|
|
| 96 |
opentelemetry-semantic-conventions==0.55b1
|
| 97 |
packaging==25.0
|
| 98 |
pandas==2.3.0
|
| 99 |
+
parso==0.8.4
|
| 100 |
+
pathspec==0.12.1
|
| 101 |
+
pexpect==4.9.0 ; (python_full_version < '3.10' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32')
|
| 102 |
platformdirs==4.3.8
|
| 103 |
prometheus-client==0.22.1
|
| 104 |
+
prompt-toolkit==3.0.51
|
| 105 |
propcache==0.3.2
|
| 106 |
proto-plus==1.26.1
|
| 107 |
protobuf==5.29.5
|
| 108 |
+
psutil==7.0.0
|
| 109 |
+
ptyprocess==0.7.0 ; (python_full_version < '3.10' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32')
|
| 110 |
+
pure-eval==0.2.3
|
| 111 |
py-spy==0.4.0
|
| 112 |
pyarrow==17.0.0
|
| 113 |
pyasn1==0.6.1
|
| 114 |
pyasn1-modules==0.4.2
|
| 115 |
+
pycparser==2.22 ; implementation_name == 'pypy'
|
| 116 |
pydantic==2.11.7
|
| 117 |
pydantic-core==2.33.2
|
| 118 |
pygments==2.19.2
|
| 119 |
python-dateutil==2.9.0.post0
|
| 120 |
python-dotenv==1.1.1
|
| 121 |
pytz==2025.2
|
| 122 |
+
pywin32==310 ; platform_python_implementation != 'PyPy' and sys_platform == 'win32'
|
| 123 |
pyyaml==6.0.2
|
| 124 |
+
pyzmq==27.0.0
|
| 125 |
ray==2.47.1
|
| 126 |
referencing==0.36.2
|
| 127 |
requests==2.32.4
|
| 128 |
rich==14.0.0
|
| 129 |
rpds-py==0.25.1
|
| 130 |
rsa==4.9.1
|
| 131 |
+
ruff==0.12.0
|
| 132 |
setuptools==80.9.0 ; (python_full_version >= '3.12' and platform_machine != 'x86_64') or (python_full_version >= '3.12' and sys_platform != 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')
|
| 133 |
six==1.17.0
|
| 134 |
smart-open==7.1.0
|
| 135 |
sniffio==1.3.1
|
| 136 |
+
stack-data==0.6.3
|
| 137 |
starlette==0.46.2
|
| 138 |
sympy==1.14.0
|
| 139 |
tensorboardx==2.6.4
|
| 140 |
+
tomli==2.2.1 ; python_full_version < '3.11'
|
| 141 |
torch==2.7.1
|
| 142 |
+
tornado==6.5.1
|
| 143 |
tqdm==4.67.1
|
| 144 |
+
traitlets==5.14.3
|
| 145 |
triton==3.3.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 146 |
typing-extensions==4.14.0
|
| 147 |
typing-inspection==0.4.1
|
|
|
|
| 151 |
uvloop==0.21.0 ; platform_python_implementation != 'PyPy' and sys_platform != 'cygwin' and sys_platform != 'win32'
|
| 152 |
virtualenv==20.31.2
|
| 153 |
watchfiles==1.1.0
|
| 154 |
+
wcwidth==0.2.13
|
| 155 |
websockets==12.0 ; python_full_version < '3.10'
|
| 156 |
websockets==15.0.1 ; python_full_version >= '3.10'
|
| 157 |
+
widgetsnbextension==4.0.14
|
| 158 |
wrapt==1.17.2
|
| 159 |
xxhash==3.5.0
|
| 160 |
yarl==1.20.1
|
| 161 |
zipp==3.23.0
|
| 162 |
+
|
| 163 |
+
jupyterlab
|
| 164 |
+
git+https://github.com/ccss17/toy-duration-predictor.git
|
src/toy_duration_predictor/__init__.py
CHANGED
|
@@ -1,2 +1,3 @@
|
|
| 1 |
-
|
| 2 |
-
|
|
|
|
|
|
| 1 |
+
from .train import *
|
| 2 |
+
from .upload import *
|
| 3 |
+
from .load import *
|
src/toy_duration_predictor/preprocess/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .mssv import *
|
src/toy_duration_predictor/preprocess/mssv.py
CHANGED
|
@@ -137,7 +137,7 @@ def process_midi_flat_map(row: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
| 137 |
|
| 138 |
|
| 139 |
def preprocess_dataset(midi_file_directory, output_parquet_path):
|
| 140 |
-
context = ray.init()
|
| 141 |
print(context.dashboard_url)
|
| 142 |
|
| 143 |
all_midi_paths = Path(midi_file_directory).rglob("*.mid")
|
|
|
|
| 137 |
|
| 138 |
|
| 139 |
def preprocess_dataset(midi_file_directory, output_parquet_path):
|
| 140 |
+
context = ray.init(ignore_reinit_error=True)
|
| 141 |
print(context.dashboard_url)
|
| 142 |
|
| 143 |
all_midi_paths = Path(midi_file_directory).rglob("*.mid")
|
src/toy_duration_predictor/train.py
CHANGED
|
@@ -12,10 +12,10 @@ BATCH_SIZE = 64
|
|
| 12 |
|
| 13 |
MODEL_CONFIG = {
|
| 14 |
"num_singers": 18,
|
| 15 |
-
"singer_embedding_dim":
|
| 16 |
-
"hidden_size":
|
| 17 |
-
"num_layers":
|
| 18 |
-
"dropout": 0.
|
| 19 |
}
|
| 20 |
|
| 21 |
LEARNING_RATE = 1e-4
|
|
|
|
| 12 |
|
| 13 |
MODEL_CONFIG = {
|
| 14 |
"num_singers": 18,
|
| 15 |
+
"singer_embedding_dim": 32,
|
| 16 |
+
"hidden_size": 256,
|
| 17 |
+
"num_layers": 3,
|
| 18 |
+
"dropout": 0.4,
|
| 19 |
}
|
| 20 |
|
| 21 |
LEARNING_RATE = 1e-4
|
test.ipynb
CHANGED
|
@@ -10,17 +10,17 @@
|
|
| 10 |
},
|
| 11 |
{
|
| 12 |
"cell_type": "code",
|
| 13 |
-
"execution_count":
|
| 14 |
"id": "d05b06c2",
|
| 15 |
"metadata": {},
|
| 16 |
"outputs": [],
|
| 17 |
"source": [
|
| 18 |
-
"from toy_duration_predictor
|
| 19 |
]
|
| 20 |
},
|
| 21 |
{
|
| 22 |
"cell_type": "code",
|
| 23 |
-
"execution_count":
|
| 24 |
"id": "77c0576d",
|
| 25 |
"metadata": {},
|
| 26 |
"outputs": [],
|
|
@@ -37,7 +37,7 @@
|
|
| 37 |
},
|
| 38 |
{
|
| 39 |
"cell_type": "code",
|
| 40 |
-
"execution_count":
|
| 41 |
"id": "ce7e5029",
|
| 42 |
"metadata": {},
|
| 43 |
"outputs": [
|
|
@@ -45,138 +45,27 @@
|
|
| 45 |
"name": "stderr",
|
| 46 |
"output_type": "stream",
|
| 47 |
"text": [
|
| 48 |
-
"2025-06-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
"
|
| 56 |
-
"
|
| 57 |
-
"
|
| 58 |
-
"\u001b[
|
| 59 |
-
"\u001b[36m(process_midi_to_dict pid=30235)\u001b[0m Error processing /mnt/d/dataset/004.다화자 가창 데이터/01.데이터/1.Training/라벨링데이터/01.발라드R&B/A. 남성/02. 30대/가창자_s09/ba_14343_-1_a_s09_m_03.mid: 'quantized duration'\u001b[32m [repeated 208x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)\u001b[0m\n"
|
| 60 |
-
]
|
| 61 |
-
},
|
| 62 |
-
{
|
| 63 |
-
"name": "stderr",
|
| 64 |
-
"output_type": "stream",
|
| 65 |
-
"text": [
|
| 66 |
-
"2025-06-23 15:44:51,225\tWARNING arrow.py:165 -- Failed to convert column 'item' into pyarrow array due to: Error converting data to Arrow: [ObjectRef(c8ef45ccd0112571ffffffffffffffffffffffff0100000001000000), ObjectRef(16310a0f0a45af5cffffffffffffffffffffffff0100000001000000), ObjectRef(c2668a65bda616c1ffffffffffffffffffffffff01000000010...; falling back to serialize as pickled python objects\n",
|
| 67 |
-
"Traceback (most recent call last):\n",
|
| 68 |
-
" File \"/home/ccsss/repo/toy-duration-predictor/.venv/lib/python3.12/site-packages/ray/air/util/tensor_extensions/arrow.py\", line 204, in _convert_to_pyarrow_native_array\n",
|
| 69 |
-
" pa_type = _infer_pyarrow_type(column_values)\n",
|
| 70 |
-
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
| 71 |
-
" File \"/home/ccsss/repo/toy-duration-predictor/.venv/lib/python3.12/site-packages/ray/air/util/tensor_extensions/arrow.py\", line 317, in _infer_pyarrow_type\n",
|
| 72 |
-
" inferred_pa_dtype = pa.infer_type(column_values)\n",
|
| 73 |
-
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
| 74 |
-
" File \"pyarrow/array.pxi\", line 564, in pyarrow.lib.infer_type\n",
|
| 75 |
-
" File \"pyarrow/error.pxi\", line 155, in pyarrow.lib.pyarrow_internal_check_status\n",
|
| 76 |
-
" File \"pyarrow/error.pxi\", line 92, in pyarrow.lib.check_status\n",
|
| 77 |
-
"pyarrow.lib.ArrowInvalid: Could not convert ObjectRef(c8ef45ccd0112571ffffffffffffffffffffffff0100000001000000) with type ray._raylet.ObjectRef: did not recognize Python value type when inferring an Arrow data type\n",
|
| 78 |
-
"\n",
|
| 79 |
-
"The above exception was the direct cause of the following exception:\n",
|
| 80 |
-
"\n",
|
| 81 |
-
"Traceback (most recent call last):\n",
|
| 82 |
-
" File \"/home/ccsss/repo/toy-duration-predictor/.venv/lib/python3.12/site-packages/ray/air/util/tensor_extensions/arrow.py\", line 117, in convert_to_pyarrow_array\n",
|
| 83 |
-
" return _convert_to_pyarrow_native_array(column_values, column_name)\n",
|
| 84 |
-
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
| 85 |
-
" File \"/home/ccsss/repo/toy-duration-predictor/.venv/lib/python3.12/site-packages/ray/air/util/tensor_extensions/arrow.py\", line 231, in _convert_to_pyarrow_native_array\n",
|
| 86 |
-
" raise ArrowConversionError(str(column_values)) from e\n",
|
| 87 |
-
"ray.air.util.tensor_extensions.arrow.ArrowConversionError: Error converting data to Arrow: [ObjectRef(c8ef45ccd0112571ffffffffffffffffffffffff0100000001000000), ObjectRef(16310a0f0a45af5cffffffffffffffffffffffff0100000001000000), ObjectRef(c2668a65bda616c1ffffffffffffffffffffffff01000000010...\n"
|
| 88 |
-
]
|
| 89 |
-
},
|
| 90 |
-
{
|
| 91 |
-
"name": "stdout",
|
| 92 |
-
"output_type": "stream",
|
| 93 |
-
"text": [
|
| 94 |
-
"Creating a Ray Dataset from results...\n"
|
| 95 |
-
]
|
| 96 |
-
},
|
| 97 |
-
{
|
| 98 |
-
"name": "stderr",
|
| 99 |
-
"output_type": "stream",
|
| 100 |
-
"text": [
|
| 101 |
-
"2025-06-23 15:44:53,181\tINFO dataset.py:3046 -- Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.\n",
|
| 102 |
-
"2025-06-23 15:44:53,357\tINFO logging.py:295 -- Registered dataset logger for dataset dataset_1_0\n"
|
| 103 |
-
]
|
| 104 |
-
},
|
| 105 |
-
{
|
| 106 |
-
"name": "stdout",
|
| 107 |
-
"output_type": "stream",
|
| 108 |
-
"text": [
|
| 109 |
-
"\n",
|
| 110 |
-
"Dataset schema:\n",
|
| 111 |
-
"\n",
|
| 112 |
-
"First 5 rows:\n"
|
| 113 |
-
]
|
| 114 |
-
},
|
| 115 |
-
{
|
| 116 |
-
"name": "stderr",
|
| 117 |
-
"output_type": "stream",
|
| 118 |
-
"text": [
|
| 119 |
-
"2025-06-23 15:44:53,393\tINFO streaming_executor.py:117 -- Starting execution of Dataset dataset_1_0. Full logs are in /tmp/ray/session_2025-06-23_15-44-44_311677_29614/logs/ray-data\n",
|
| 120 |
-
"2025-06-23 15:44:53,395\tINFO streaming_executor.py:118 -- Execution plan of Dataset dataset_1_0: InputDataBuffer[Input] -> LimitOperator[limit=5]\n"
|
| 121 |
-
]
|
| 122 |
-
},
|
| 123 |
-
{
|
| 124 |
-
"name": "stdout",
|
| 125 |
-
"output_type": "stream",
|
| 126 |
-
"text": [
|
| 127 |
-
"[dataset]: Run `pip install tqdm` to enable progress reporting.\n"
|
| 128 |
-
]
|
| 129 |
-
},
|
| 130 |
-
{
|
| 131 |
-
"name": "stderr",
|
| 132 |
-
"output_type": "stream",
|
| 133 |
-
"text": [
|
| 134 |
-
"2025-06-23 15:44:53,647\tINFO streaming_executor.py:227 -- ✔️ Dataset dataset_1_0 execution finished in 0.25 seconds\n",
|
| 135 |
-
"2025-06-23 15:44:53,798\tWARNING plan.py:472 -- Warning: The Ray cluster currently does not have any available CPUs. The Dataset job will hang unless more CPUs are freed up. A common reason is that cluster resources are used by Actors or Tune trials; see the following link for more details: https://docs.ray.io/en/latest/data/data-internals.html#ray-data-and-tune\n",
|
| 136 |
-
"2025-06-23 15:44:53,799\tINFO logging.py:295 -- Registered dataset logger for dataset dataset_4_0\n",
|
| 137 |
-
"2025-06-23 15:44:53,806\tINFO streaming_executor.py:117 -- Starting execution of Dataset dataset_4_0. Full logs are in /tmp/ray/session_2025-06-23_15-44-44_311677_29614/logs/ray-data\n",
|
| 138 |
-
"2025-06-23 15:44:53,809\tINFO streaming_executor.py:118 -- Execution plan of Dataset dataset_4_0: InputDataBuffer[Input] -> AllToAllOperator[Repartition] -> TaskPoolMapOperator[Write]\n"
|
| 139 |
-
]
|
| 140 |
-
},
|
| 141 |
-
{
|
| 142 |
-
"name": "stdout",
|
| 143 |
-
"output_type": "stream",
|
| 144 |
-
"text": [
|
| 145 |
-
"{'item': ObjectRef(c8ef45ccd0112571ffffffffffffffffffffffff0100000001000000)}\n",
|
| 146 |
-
"{'item': ObjectRef(16310a0f0a45af5cffffffffffffffffffffffff0100000001000000)}\n",
|
| 147 |
-
"{'item': ObjectRef(c2668a65bda616c1ffffffffffffffffffffffff0100000001000000)}\n",
|
| 148 |
-
"{'item': ObjectRef(32d950ec0ccf9d2affffffffffffffffffffffff0100000001000000)}\n",
|
| 149 |
-
"{'item': ObjectRef(e0dc174c83599034ffffffffffffffffffffffff0100000001000000)}\n",
|
| 150 |
-
"Repartitioning dataset to control output file size...\n",
|
| 151 |
-
"\n",
|
| 152 |
-
"Writing dataset to Parquet format at: /mnt/d/dataset/mssv_preprocessed_duration\n"
|
| 153 |
-
]
|
| 154 |
-
},
|
| 155 |
-
{
|
| 156 |
-
"name": "stderr",
|
| 157 |
-
"output_type": "stream",
|
| 158 |
-
"text": [
|
| 159 |
-
"2025-06-23 15:44:55,806\tINFO streaming_executor.py:227 -- ✔️ Dataset dataset_4_0 execution finished in 2.00 seconds\n",
|
| 160 |
-
"2025-06-23 15:45:01,968\tINFO dataset.py:4601 -- Data sink Parquet finished. 4205 rows and 839.9KB data written.\n"
|
| 161 |
-
]
|
| 162 |
-
},
|
| 163 |
-
{
|
| 164 |
-
"name": "stdout",
|
| 165 |
-
"output_type": "stream",
|
| 166 |
-
"text": [
|
| 167 |
-
"\n",
|
| 168 |
-
"Processing complete! Your dataset is ready.\n",
|
| 169 |
-
"\u001b[36m(process_midi_to_dict pid=30232)\u001b[0m Error processing /mnt/d/dataset/004.다화자 가창 데이터/01.데이터/1.Training/라벨링데이터/01.발라드R&B/A. 남성/03. 40대 이상/가창자_s13/ba_06120_-2_a_s13_m_04.mid: 'quantized duration'\u001b[32m [repeated 344x across cluster]\u001b[0m\n"
|
| 170 |
]
|
| 171 |
}
|
| 172 |
],
|
| 173 |
"source": [
|
| 174 |
-
"
|
| 175 |
]
|
| 176 |
},
|
| 177 |
{
|
| 178 |
"cell_type": "code",
|
| 179 |
-
"execution_count":
|
| 180 |
"id": "2d4c561c",
|
| 181 |
"metadata": {},
|
| 182 |
"outputs": [
|
|
@@ -3140,19 +3029,19 @@
|
|
| 3140 |
" 'duration': 480}]"
|
| 3141 |
]
|
| 3142 |
},
|
| 3143 |
-
"execution_count":
|
| 3144 |
"metadata": {},
|
| 3145 |
"output_type": "execute_result"
|
| 3146 |
}
|
| 3147 |
],
|
| 3148 |
"source": [
|
| 3149 |
-
"mssv_sample_list, ticks_per_beat =
|
| 3150 |
"mssv_sample_list "
|
| 3151 |
]
|
| 3152 |
},
|
| 3153 |
{
|
| 3154 |
"cell_type": "code",
|
| 3155 |
-
"execution_count":
|
| 3156 |
"id": "0bac9307",
|
| 3157 |
"metadata": {},
|
| 3158 |
"outputs": [
|
|
@@ -3313,13 +3202,13 @@
|
|
| 3313 |
}
|
| 3314 |
],
|
| 3315 |
"source": [
|
| 3316 |
-
"df =
|
| 3317 |
"df"
|
| 3318 |
]
|
| 3319 |
},
|
| 3320 |
{
|
| 3321 |
"cell_type": "code",
|
| 3322 |
-
"execution_count":
|
| 3323 |
"id": "c0b7965b",
|
| 3324 |
"metadata": {},
|
| 3325 |
"outputs": [
|
|
@@ -3335,7 +3224,7 @@
|
|
| 3335 |
}
|
| 3336 |
],
|
| 3337 |
"source": [
|
| 3338 |
-
"singer_id =
|
| 3339 |
"singer_id "
|
| 3340 |
]
|
| 3341 |
},
|
|
@@ -3349,7 +3238,7 @@
|
|
| 3349 |
},
|
| 3350 |
{
|
| 3351 |
"cell_type": "code",
|
| 3352 |
-
"execution_count":
|
| 3353 |
"id": "0e0ef1e4",
|
| 3354 |
"metadata": {},
|
| 3355 |
"outputs": [],
|
|
@@ -3360,7 +3249,7 @@
|
|
| 3360 |
},
|
| 3361 |
{
|
| 3362 |
"cell_type": "code",
|
| 3363 |
-
"execution_count":
|
| 3364 |
"id": "c519f235",
|
| 3365 |
"metadata": {},
|
| 3366 |
"outputs": [
|
|
@@ -3373,7 +3262,7 @@
|
|
| 3373 |
"})"
|
| 3374 |
]
|
| 3375 |
},
|
| 3376 |
-
"execution_count":
|
| 3377 |
"metadata": {},
|
| 3378 |
"output_type": "execute_result"
|
| 3379 |
}
|
|
@@ -3387,7 +3276,7 @@
|
|
| 3387 |
},
|
| 3388 |
{
|
| 3389 |
"cell_type": "code",
|
| 3390 |
-
"execution_count":
|
| 3391 |
"id": "6f62d4d5",
|
| 3392 |
"metadata": {},
|
| 3393 |
"outputs": [
|
|
@@ -3397,7 +3286,7 @@
|
|
| 3397 |
"(4204, 4204)"
|
| 3398 |
]
|
| 3399 |
},
|
| 3400 |
-
"execution_count":
|
| 3401 |
"metadata": {},
|
| 3402 |
"output_type": "execute_result"
|
| 3403 |
}
|
|
@@ -3408,7 +3297,7 @@
|
|
| 3408 |
},
|
| 3409 |
{
|
| 3410 |
"cell_type": "code",
|
| 3411 |
-
"execution_count":
|
| 3412 |
"id": "ff7ec906",
|
| 3413 |
"metadata": {},
|
| 3414 |
"outputs": [
|
|
@@ -3418,7 +3307,7 @@
|
|
| 3418 |
"(4204, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, 18)"
|
| 3419 |
]
|
| 3420 |
},
|
| 3421 |
-
"execution_count":
|
| 3422 |
"metadata": {},
|
| 3423 |
"output_type": "execute_result"
|
| 3424 |
}
|
|
@@ -3429,7 +3318,7 @@
|
|
| 3429 |
},
|
| 3430 |
{
|
| 3431 |
"cell_type": "code",
|
| 3432 |
-
"execution_count":
|
| 3433 |
"id": "a625ec7c",
|
| 3434 |
"metadata": {},
|
| 3435 |
"outputs": [
|
|
@@ -3449,7 +3338,7 @@
|
|
| 3449 |
},
|
| 3450 |
{
|
| 3451 |
"cell_type": "code",
|
| 3452 |
-
"execution_count":
|
| 3453 |
"id": "e22dc42e",
|
| 3454 |
"metadata": {},
|
| 3455 |
"outputs": [
|
|
@@ -3459,7 +3348,7 @@
|
|
| 3459 |
"True"
|
| 3460 |
]
|
| 3461 |
},
|
| 3462 |
-
"execution_count":
|
| 3463 |
"metadata": {},
|
| 3464 |
"output_type": "execute_result"
|
| 3465 |
}
|
|
@@ -3470,7 +3359,7 @@
|
|
| 3470 |
},
|
| 3471 |
{
|
| 3472 |
"cell_type": "code",
|
| 3473 |
-
"execution_count":
|
| 3474 |
"id": "aa41450d",
|
| 3475 |
"metadata": {},
|
| 3476 |
"outputs": [
|
|
@@ -3480,7 +3369,7 @@
|
|
| 3480 |
"(817, 817, 2)"
|
| 3481 |
]
|
| 3482 |
},
|
| 3483 |
-
"execution_count":
|
| 3484 |
"metadata": {},
|
| 3485 |
"output_type": "execute_result"
|
| 3486 |
}
|
|
@@ -3489,265 +3378,6 @@
|
|
| 3489 |
"len(train_dataset[1]['durations']), len(train_dataset[1]['quantized_durations']), train_dataset[1]['singer_id']"
|
| 3490 |
]
|
| 3491 |
},
|
| 3492 |
-
{
|
| 3493 |
-
"cell_type": "code",
|
| 3494 |
-
"execution_count": 24,
|
| 3495 |
-
"id": "cbcad2ca",
|
| 3496 |
-
"metadata": {},
|
| 3497 |
-
"outputs": [
|
| 3498 |
-
{
|
| 3499 |
-
"name": "stdout",
|
| 3500 |
-
"output_type": "stream",
|
| 3501 |
-
"text": [
|
| 3502 |
-
"Chunking dataset into fixed-length sequences...\n"
|
| 3503 |
-
]
|
| 3504 |
-
}
|
| 3505 |
-
],
|
| 3506 |
-
"source": [
|
| 3507 |
-
"# --- Step 2: Define the Preprocessing Function ---\n",
|
| 3508 |
-
"# This function will take a batch of songs and chop their long duration lists\n",
|
| 3509 |
-
"# into fixed-length chunks that the model can handle.\n",
|
| 3510 |
-
"\n",
|
| 3511 |
-
"SEQUENCE_LENGTH = 128 \n",
|
| 3512 |
-
"\n",
|
| 3513 |
-
"def chunk_examples(examples):\n",
|
| 3514 |
-
" # Dictionaries to hold the new, chunked data\n",
|
| 3515 |
-
" chunked_inputs = {\n",
|
| 3516 |
-
" 'original_durations': [],\n",
|
| 3517 |
-
" 'quantized_durations': [],\n",
|
| 3518 |
-
" 'singer_id': [],\n",
|
| 3519 |
-
" }\n",
|
| 3520 |
-
" \n",
|
| 3521 |
-
" # Iterate through each song in the batch\n",
|
| 3522 |
-
" for i in range(len(examples[\"durations\"])):\n",
|
| 3523 |
-
" durs = examples[\"durations\"][i]\n",
|
| 3524 |
-
" q_durs = examples[\"quantized_durations\"][i]\n",
|
| 3525 |
-
" sid = examples[\"singer_id\"][i]\n",
|
| 3526 |
-
" \n",
|
| 3527 |
-
" # Chop the long lists into smaller, fixed-length chunks\n",
|
| 3528 |
-
" for j in range(0, len(durs) - SEQUENCE_LENGTH, SEQUENCE_LENGTH):\n",
|
| 3529 |
-
" chunked_inputs['original_durations'].append(durs[j : j + SEQUENCE_LENGTH])\n",
|
| 3530 |
-
" chunked_inputs['quantized_durations'].append(q_durs[j : j + SEQUENCE_LENGTH])\n",
|
| 3531 |
-
" chunked_inputs['singer_id'].append(sid)\n",
|
| 3532 |
-
" \n",
|
| 3533 |
-
" return chunked_inputs\n",
|
| 3534 |
-
"\n",
|
| 3535 |
-
"\n",
|
| 3536 |
-
"# --- Step 3: Apply the Preprocessing ---\n",
|
| 3537 |
-
"# We use .map() to apply our function to the entire dataset.\n",
|
| 3538 |
-
"# `batched=True` sends multiple rows at a time to our function for efficiency.\n",
|
| 3539 |
-
"# `remove_columns` gets rid of the old, variable-length columns.\n",
|
| 3540 |
-
"print(\"Chunking dataset into fixed-length sequences...\")\n",
|
| 3541 |
-
"processed_dataset = train_dataset.map(\n",
|
| 3542 |
-
" chunk_examples,\n",
|
| 3543 |
-
" batched=True,\n",
|
| 3544 |
-
" remove_columns=train_dataset.column_names\n",
|
| 3545 |
-
")"
|
| 3546 |
-
]
|
| 3547 |
-
},
|
| 3548 |
-
{
|
| 3549 |
-
"cell_type": "code",
|
| 3550 |
-
"execution_count": 9,
|
| 3551 |
-
"id": "93b51c05",
|
| 3552 |
-
"metadata": {},
|
| 3553 |
-
"outputs": [
|
| 3554 |
-
{
|
| 3555 |
-
"data": {
|
| 3556 |
-
"text/plain": [
|
| 3557 |
-
"Dataset({\n",
|
| 3558 |
-
" features: ['quantized_durations', 'singer_id', 'original_durations'],\n",
|
| 3559 |
-
" num_rows: 17941\n",
|
| 3560 |
-
"})"
|
| 3561 |
-
]
|
| 3562 |
-
},
|
| 3563 |
-
"execution_count": 9,
|
| 3564 |
-
"metadata": {},
|
| 3565 |
-
"output_type": "execute_result"
|
| 3566 |
-
}
|
| 3567 |
-
],
|
| 3568 |
-
"source": [
|
| 3569 |
-
"processed_dataset"
|
| 3570 |
-
]
|
| 3571 |
-
},
|
| 3572 |
-
{
|
| 3573 |
-
"cell_type": "code",
|
| 3574 |
-
"execution_count": 10,
|
| 3575 |
-
"id": "8d312061",
|
| 3576 |
-
"metadata": {},
|
| 3577 |
-
"outputs": [
|
| 3578 |
-
{
|
| 3579 |
-
"data": {
|
| 3580 |
-
"text/plain": [
|
| 3581 |
-
"17941"
|
| 3582 |
-
]
|
| 3583 |
-
},
|
| 3584 |
-
"execution_count": 10,
|
| 3585 |
-
"metadata": {},
|
| 3586 |
-
"output_type": "execute_result"
|
| 3587 |
-
}
|
| 3588 |
-
],
|
| 3589 |
-
"source": [
|
| 3590 |
-
"len(processed_dataset['quantized_durations'])"
|
| 3591 |
-
]
|
| 3592 |
-
},
|
| 3593 |
-
{
|
| 3594 |
-
"cell_type": "code",
|
| 3595 |
-
"execution_count": 11,
|
| 3596 |
-
"id": "a38b02c3",
|
| 3597 |
-
"metadata": {},
|
| 3598 |
-
"outputs": [
|
| 3599 |
-
{
|
| 3600 |
-
"data": {
|
| 3601 |
-
"text/plain": [
|
| 3602 |
-
"(128, 128, 2)"
|
| 3603 |
-
]
|
| 3604 |
-
},
|
| 3605 |
-
"execution_count": 11,
|
| 3606 |
-
"metadata": {},
|
| 3607 |
-
"output_type": "execute_result"
|
| 3608 |
-
}
|
| 3609 |
-
],
|
| 3610 |
-
"source": [
|
| 3611 |
-
"len(processed_dataset['original_durations'][0]), len(processed_dataset['quantized_durations'][0]), processed_dataset['singer_id'][0]"
|
| 3612 |
-
]
|
| 3613 |
-
},
|
| 3614 |
-
{
|
| 3615 |
-
"cell_type": "code",
|
| 3616 |
-
"execution_count": 12,
|
| 3617 |
-
"id": "07583462",
|
| 3618 |
-
"metadata": {},
|
| 3619 |
-
"outputs": [
|
| 3620 |
-
{
|
| 3621 |
-
"data": {
|
| 3622 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 3623 |
-
"model_id": "45ae421c3aa0414386598a1213673229",
|
| 3624 |
-
"version_major": 2,
|
| 3625 |
-
"version_minor": 0
|
| 3626 |
-
},
|
| 3627 |
-
"text/plain": [
|
| 3628 |
-
"Map: 0%| | 0/17941 [00:00<?, ? examples/s]"
|
| 3629 |
-
]
|
| 3630 |
-
},
|
| 3631 |
-
"metadata": {},
|
| 3632 |
-
"output_type": "display_data"
|
| 3633 |
-
},
|
| 3634 |
-
{
|
| 3635 |
-
"name": "stdout",
|
| 3636 |
-
"output_type": "stream",
|
| 3637 |
-
"text": [
|
| 3638 |
-
"DataLoader for Control Model (original -> original) created.\n",
|
| 3639 |
-
"DataLoader for Your Method's Model (quantized -> original) created.\n"
|
| 3640 |
-
]
|
| 3641 |
-
}
|
| 3642 |
-
],
|
| 3643 |
-
"source": [
|
| 3644 |
-
"# --- Step 4: Set Format and Create DataLoaders ---\n",
|
| 3645 |
-
"# Now we format the dataset to output PyTorch tensors and create the DataLoaders\n",
|
| 3646 |
-
"# which will handle batching and shuffling for us during training.\n",
|
| 3647 |
-
"\n",
|
| 3648 |
-
"BATCH_SIZE = 32\n",
|
| 3649 |
-
"\n",
|
| 3650 |
-
"# --- Dataloader for Model A (Control) ---\n",
|
| 3651 |
-
"# Goal: original -> original\n",
|
| 3652 |
-
"# We need to duplicate the 'original_durations' column to use it for both input and labels.\n",
|
| 3653 |
-
"model_A_dataset = processed_dataset.map(\n",
|
| 3654 |
-
" lambda batch: {\"labels\": batch[\"original_durations\"]},\n",
|
| 3655 |
-
" batched=True\n",
|
| 3656 |
-
")\n",
|
| 3657 |
-
"# Now we have: 'original_durations', 'quantized_durations', 'singer_id', 'labels'\n",
|
| 3658 |
-
"\n",
|
| 3659 |
-
"# Rename 'original_durations' to 'input_ids' for the model's input\n",
|
| 3660 |
-
"model_A_dataset = model_A_dataset.rename_column(\"original_durations\", \"input_ids\")\n",
|
| 3661 |
-
"model_A_dataset.set_format(type='torch', columns=['input_ids', 'labels', 'singer_id'])\n",
|
| 3662 |
-
"model_A_dataloader = DataLoader(model_A_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
|
| 3663 |
-
"print(\"DataLoader for Control Model (original -> original) created.\")\n",
|
| 3664 |
-
"\n",
|
| 3665 |
-
"\n",
|
| 3666 |
-
"# --- Dataloader for Model B (Your Method) ---\n",
|
| 3667 |
-
"# Goal: quantized -> original\n",
|
| 3668 |
-
"model_B_dataset = processed_dataset.map(\n",
|
| 3669 |
-
" lambda batch: {\"labels\": batch[\"original_durations\"]},\n",
|
| 3670 |
-
" batched=True\n",
|
| 3671 |
-
")\n",
|
| 3672 |
-
"\n",
|
| 3673 |
-
"# Rename 'quantized_durations' to 'input_ids' for the model's input\n",
|
| 3674 |
-
"model_B_dataset = model_B_dataset.rename_column(\"quantized_durations\", \"input_ids\")\n",
|
| 3675 |
-
"model_B_dataset.set_format(type='torch', columns=['input_ids', 'labels', 'singer_id'])\n",
|
| 3676 |
-
"model_B_dataloader = DataLoader(model_B_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
|
| 3677 |
-
"print(\"DataLoader for Your Method's Model (quantized -> original) created.\")"
|
| 3678 |
-
]
|
| 3679 |
-
},
|
| 3680 |
-
{
|
| 3681 |
-
"cell_type": "code",
|
| 3682 |
-
"execution_count": 13,
|
| 3683 |
-
"id": "65c78eb3",
|
| 3684 |
-
"metadata": {},
|
| 3685 |
-
"outputs": [
|
| 3686 |
-
{
|
| 3687 |
-
"name": "stdout",
|
| 3688 |
-
"output_type": "stream",
|
| 3689 |
-
"text": [
|
| 3690 |
-
"\n",
|
| 3691 |
-
"--- Verifying one batch for Model B (Your Method) ---\n",
|
| 3692 |
-
"Shape of the input tensor: torch.Size([32, 128])\n",
|
| 3693 |
-
"Shape of the label tensor: torch.Size([32, 128])\n",
|
| 3694 |
-
"Shape of the singer_id tensor: torch.Size([32])\n",
|
| 3695 |
-
"\n",
|
| 3696 |
-
"First input sequence in the batch (quantized):\n",
|
| 3697 |
-
"tensor([ 240, 60, 360, 60, 300, 0, 780, 1560, 300, 0, 480, 0,\n",
|
| 3698 |
-
" 540, 240, 240, 60, 420, 0, 420, 0, 420, 300, 300, 0,\n",
|
| 3699 |
-
" 180, 60, 300, 60, 540, 0, 360, 0, 180, 60, 1080, 720,\n",
|
| 3700 |
-
" 540, 60, 300, 0, 300, 0, 180, 0, 180, 0, 360, 60,\n",
|
| 3701 |
-
" 540, 0, 300, 0, 420, 0, 300, 360, 540, 60, 360, 0,\n",
|
| 3702 |
-
" 300, 0, 180, 60, 180, 60, 240, 0, 1680, 180, 480, 0,\n",
|
| 3703 |
-
" 420, 60, 360, 0, 480, 0, 1860, 300, 360, 60, 240, 0,\n",
|
| 3704 |
-
" 480, 60, 300, 0, 300, 60, 1260, 1020, 240, 120, 360, 60,\n",
|
| 3705 |
-
" 240, 0, 240, 0, 240, 0, 300, 0, 300, 0, 240, 0,\n",
|
| 3706 |
-
" 480, 120, 300, 0, 240, 60, 2160, 1440, 360, 0, 240, 60,\n",
|
| 3707 |
-
" 360, 60, 480, 0, 300, 0, 1140, 780])\n",
|
| 3708 |
-
"\n",
|
| 3709 |
-
"Corresponding label sequence in the batch (original):\n",
|
| 3710 |
-
"tensor([ 263, 40, 380, 33, 292, 14, 773, 1553, 278, 15, 483, 19,\n",
|
| 3711 |
-
" 511, 253, 264, 40, 419, 13, 398, 15, 399, 278, 279, 13,\n",
|
| 3712 |
-
" 198, 33, 277, 34, 539, 13, 347, 1, 205, 39, 1097, 712,\n",
|
| 3713 |
-
" 564, 40, 293, 11, 271, 12, 178, 15, 177, 12, 375, 73,\n",
|
| 3714 |
-
" 539, 13, 307, 24, 397, 16, 318, 353, 560, 34, 371, 24,\n",
|
| 3715 |
-
" 289, 4, 199, 40, 205, 40, 237, 13, 1680, 172, 466, 3,\n",
|
| 3716 |
-
" 433, 33, 358, 13, 458, 15, 1880, 312, 384, 42, 226, 0,\n",
|
| 3717 |
-
" 457, 53, 280, 15, 325, 37, 1234, 1012, 265, 122, 385, 39,\n",
|
| 3718 |
-
" 221, 12, 224, 15, 250, 25, 277, 16, 270, 14, 219, 14,\n",
|
| 3719 |
-
" 473, 92, 278, 16, 256, 33, 2135, 1413, 337, 16, 267, 43,\n",
|
| 3720 |
-
" 346, 74, 472, 7, 286, 4, 1123, 779])\n",
|
| 3721 |
-
"\n",
|
| 3722 |
-
"Corresponding singer_id:\n",
|
| 3723 |
-
"tensor(4)\n"
|
| 3724 |
-
]
|
| 3725 |
-
}
|
| 3726 |
-
],
|
| 3727 |
-
"source": [
|
| 3728 |
-
"# --- Step 5: Verify One Batch ---\n",
|
| 3729 |
-
"# Let's pull one batch from the dataloader for Model B to see what it looks like.\n",
|
| 3730 |
-
"# This is the data in its final form, \"right before being input into model\".\n",
|
| 3731 |
-
"\n",
|
| 3732 |
-
"print(\"\\n--- Verifying one batch for Model B (Your Method) ---\")\n",
|
| 3733 |
-
"one_batch = next(iter(model_B_dataloader))\n",
|
| 3734 |
-
"\n",
|
| 3735 |
-
"# The dataloader gives us a dictionary of tensors\n",
|
| 3736 |
-
"input_tensor = one_batch['input_ids']\n",
|
| 3737 |
-
"label_tensor = one_batch['labels']\n",
|
| 3738 |
-
"sid_tensor = one_batch['singer_id']\n",
|
| 3739 |
-
"\n",
|
| 3740 |
-
"print(f\"Shape of the input tensor: {input_tensor.shape}\")\n",
|
| 3741 |
-
"print(f\"Shape of the label tensor: {label_tensor.shape}\")\n",
|
| 3742 |
-
"print(f\"Shape of the singer_id tensor: {sid_tensor.shape}\")\n",
|
| 3743 |
-
"print(\"\\nFirst input sequence in the batch (quantized):\")\n",
|
| 3744 |
-
"print(input_tensor[0])\n",
|
| 3745 |
-
"print(\"\\nCorresponding label sequence in the batch (original):\")\n",
|
| 3746 |
-
"print(label_tensor[0])\n",
|
| 3747 |
-
"print(\"\\nCorresponding singer_id:\")\n",
|
| 3748 |
-
"print(sid_tensor[0])"
|
| 3749 |
-
]
|
| 3750 |
-
},
|
| 3751 |
{
|
| 3752 |
"cell_type": "markdown",
|
| 3753 |
"id": "3828f8dd",
|
|
@@ -3763,7 +3393,7 @@
|
|
| 3763 |
"metadata": {},
|
| 3764 |
"outputs": [],
|
| 3765 |
"source": [
|
| 3766 |
-
"import toy_duration_predictor
|
| 3767 |
]
|
| 3768 |
},
|
| 3769 |
{
|
|
@@ -3779,83 +3409,15 @@
|
|
| 3779 |
"\n",
|
| 3780 |
"==================== FINAL MODEL COMPARISON ====================\n",
|
| 3781 |
"Using device for evaluation: cuda:0\n",
|
| 3782 |
-
"--- Preparing data
|
| 3783 |
-
|
| 3784 |
-
},
|
| 3785 |
-
{
|
| 3786 |
-
"data": {
|
| 3787 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 3788 |
-
"model_id": "3007639f434040c8aa28cef9039d10c1",
|
| 3789 |
-
"version_major": 2,
|
| 3790 |
-
"version_minor": 0
|
| 3791 |
-
},
|
| 3792 |
-
"text/plain": [
|
| 3793 |
-
"Map: 0%| | 0/4204 [00:00<?, ? examples/s]"
|
| 3794 |
-
]
|
| 3795 |
-
},
|
| 3796 |
-
"metadata": {},
|
| 3797 |
-
"output_type": "display_data"
|
| 3798 |
-
},
|
| 3799 |
-
{
|
| 3800 |
-
"name": "stdout",
|
| 3801 |
-
"output_type": "stream",
|
| 3802 |
-
"text": [
|
| 3803 |
-
"Calculated training set stats: Mean=128.66, Std=201.76\n"
|
| 3804 |
-
]
|
| 3805 |
-
},
|
| 3806 |
-
{
|
| 3807 |
-
"data": {
|
| 3808 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 3809 |
-
"model_id": "036eb8eb8a8f42ca9fc4223114c47566",
|
| 3810 |
-
"version_major": 2,
|
| 3811 |
-
"version_minor": 0
|
| 3812 |
-
},
|
| 3813 |
-
"text/plain": [
|
| 3814 |
-
"Map: 0%| | 0/3363 [00:00<?, ? examples/s]"
|
| 3815 |
-
]
|
| 3816 |
-
},
|
| 3817 |
-
"metadata": {},
|
| 3818 |
-
"output_type": "display_data"
|
| 3819 |
-
},
|
| 3820 |
-
{
|
| 3821 |
-
"data": {
|
| 3822 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 3823 |
-
"model_id": "418d0c402c41422db24d5f18dea4ecc4",
|
| 3824 |
-
"version_major": 2,
|
| 3825 |
-
"version_minor": 0
|
| 3826 |
-
},
|
| 3827 |
-
"text/plain": [
|
| 3828 |
-
"Map: 0%| | 0/420 [00:00<?, ? examples/s]"
|
| 3829 |
-
]
|
| 3830 |
-
},
|
| 3831 |
-
"metadata": {},
|
| 3832 |
-
"output_type": "display_data"
|
| 3833 |
-
},
|
| 3834 |
-
{
|
| 3835 |
-
"data": {
|
| 3836 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 3837 |
-
"model_id": "44c6f6c9620543f3baed8f09f66514d8",
|
| 3838 |
-
"version_major": 2,
|
| 3839 |
-
"version_minor": 0
|
| 3840 |
-
},
|
| 3841 |
-
"text/plain": [
|
| 3842 |
-
"Map: 0%| | 0/421 [00:00<?, ? examples/s]"
|
| 3843 |
-
]
|
| 3844 |
-
},
|
| 3845 |
-
"metadata": {},
|
| 3846 |
-
"output_type": "display_data"
|
| 3847 |
-
},
|
| 3848 |
-
{
|
| 3849 |
-
"name": "stdout",
|
| 3850 |
-
"output_type": "stream",
|
| 3851 |
-
"text": [
|
| 3852 |
"Evaluating models on the test set...\n"
|
| 3853 |
]
|
| 3854 |
},
|
| 3855 |
{
|
| 3856 |
"data": {
|
| 3857 |
"application/vnd.jupyter.widget-view+json": {
|
| 3858 |
-
"model_id": "
|
| 3859 |
"version_major": 2,
|
| 3860 |
"version_minor": 0
|
| 3861 |
},
|
|
@@ -3879,7 +3441,7 @@
|
|
| 3879 |
"source": [
|
| 3880 |
"model_a_path = 'model_stable/model_A.pth'\n",
|
| 3881 |
"model_b_path = 'model_stable/model_B.pth'\n",
|
| 3882 |
-
"
|
| 3883 |
]
|
| 3884 |
},
|
| 3885 |
{
|
|
@@ -3892,17 +3454,17 @@
|
|
| 3892 |
},
|
| 3893 |
{
|
| 3894 |
"cell_type": "code",
|
| 3895 |
-
"execution_count":
|
| 3896 |
"id": "7a54b60d",
|
| 3897 |
"metadata": {},
|
| 3898 |
"outputs": [],
|
| 3899 |
"source": [
|
| 3900 |
-
"import toy_duration_predictor
|
| 3901 |
]
|
| 3902 |
},
|
| 3903 |
{
|
| 3904 |
"cell_type": "code",
|
| 3905 |
-
"execution_count":
|
| 3906 |
"id": "fa70b5bf",
|
| 3907 |
"metadata": {},
|
| 3908 |
"outputs": [
|
|
@@ -3989,7 +3551,7 @@
|
|
| 3989 |
}
|
| 3990 |
],
|
| 3991 |
"source": [
|
| 3992 |
-
"
|
| 3993 |
]
|
| 3994 |
},
|
| 3995 |
{
|
|
@@ -4007,7 +3569,7 @@
|
|
| 4007 |
"metadata": {},
|
| 4008 |
"outputs": [],
|
| 4009 |
"source": [
|
| 4010 |
-
"
|
| 4011 |
]
|
| 4012 |
},
|
| 4013 |
{
|
|
@@ -4036,7 +3598,7 @@
|
|
| 4036 |
{
|
| 4037 |
"data": {
|
| 4038 |
"application/vnd.jupyter.widget-view+json": {
|
| 4039 |
-
"model_id": "
|
| 4040 |
"version_major": 2,
|
| 4041 |
"version_minor": 0
|
| 4042 |
},
|
|
@@ -4058,7 +3620,7 @@
|
|
| 4058 |
}
|
| 4059 |
],
|
| 4060 |
"source": [
|
| 4061 |
-
"load_and_test()"
|
| 4062 |
]
|
| 4063 |
},
|
| 4064 |
{
|
|
|
|
| 10 |
},
|
| 11 |
{
|
| 12 |
"cell_type": "code",
|
| 13 |
+
"execution_count": 1,
|
| 14 |
"id": "d05b06c2",
|
| 15 |
"metadata": {},
|
| 16 |
"outputs": [],
|
| 17 |
"source": [
|
| 18 |
+
"from toy_duration_predictor import preprocess"
|
| 19 |
]
|
| 20 |
},
|
| 21 |
{
|
| 22 |
"cell_type": "code",
|
| 23 |
+
"execution_count": 2,
|
| 24 |
"id": "77c0576d",
|
| 25 |
"metadata": {},
|
| 26 |
"outputs": [],
|
|
|
|
| 37 |
},
|
| 38 |
{
|
| 39 |
"cell_type": "code",
|
| 40 |
+
"execution_count": null,
|
| 41 |
"id": "ce7e5029",
|
| 42 |
"metadata": {},
|
| 43 |
"outputs": [
|
|
|
|
| 45 |
"name": "stderr",
|
| 46 |
"output_type": "stream",
|
| 47 |
"text": [
|
| 48 |
+
"\u001b[33m(raylet)\u001b[0m [2025-06-26 14:32:29,958 E 2007289 2007315] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-26_14-31-47_408952_2006798 is over 95% full, available space: 108.749 GB; capacity: 3651.19 GB. Object creation will fail if spilling is required.\n",
|
| 49 |
+
"\u001b[33m(raylet)\u001b[0m [2025-06-26 14:32:39,971 E 2007289 2007315] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-26_14-31-47_408952_2006798 is over 95% full, available space: 108.749 GB; capacity: 3651.19 GB. Object creation will fail if spilling is required.\n",
|
| 50 |
+
"\u001b[33m(raylet)\u001b[0m [2025-06-26 14:32:49,985 E 2007289 2007315] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-26_14-31-47_408952_2006798 is over 95% full, available space: 108.749 GB; capacity: 3651.19 GB. Object creation will fail if spilling is required.\n",
|
| 51 |
+
"\u001b[33m(raylet)\u001b[0m [2025-06-26 14:32:59,997 E 2007289 2007315] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-26_14-31-47_408952_2006798 is over 95% full, available space: 108.748 GB; capacity: 3651.19 GB. Object creation will fail if spilling is required.\n",
|
| 52 |
+
"\u001b[33m(raylet)\u001b[0m [2025-06-26 14:33:10,008 E 2007289 2007315] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-26_14-31-47_408952_2006798 is over 95% full, available space: 108.748 GB; capacity: 3651.19 GB. Object creation will fail if spilling is required.\n",
|
| 53 |
+
"\u001b[33m(raylet)\u001b[0m [2025-06-26 14:33:20,020 E 2007289 2007315] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-26_14-31-47_408952_2006798 is over 95% full, available space: 108.748 GB; capacity: 3651.19 GB. Object creation will fail if spilling is required.\n",
|
| 54 |
+
"\u001b[33m(raylet)\u001b[0m [2025-06-26 14:33:30,034 E 2007289 2007315] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-26_14-31-47_408952_2006798 is over 95% full, available space: 108.748 GB; capacity: 3651.19 GB. Object creation will fail if spilling is required.\n",
|
| 55 |
+
"\u001b[33m(raylet)\u001b[0m [2025-06-26 14:33:40,048 E 2007289 2007315] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-26_14-31-47_408952_2006798 is over 95% full, available space: 108.748 GB; capacity: 3651.19 GB. Object creation will fail if spilling is required.\n",
|
| 56 |
+
"\u001b[33m(raylet)\u001b[0m [2025-06-26 14:33:50,061 E 2007289 2007315] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-26_14-31-47_408952_2006798 is over 95% full, available space: 108.748 GB; capacity: 3651.19 GB. Object creation will fail if spilling is required.\n",
|
| 57 |
+
"\u001b[33m(raylet)\u001b[0m [2025-06-26 14:34:00,074 E 2007289 2007315] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-26_14-31-47_408952_2006798 is over 95% full, available space: 108.748 GB; capacity: 3651.19 GB. Object creation will fail if spilling is required.\n",
|
| 58 |
+
"\u001b[33m(raylet)\u001b[0m [2025-06-26 14:34:10,087 E 2007289 2007315] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-26_14-31-47_408952_2006798 is over 95% full, available space: 108.748 GB; capacity: 3651.19 GB. Object creation will fail if spilling is required.\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
]
|
| 60 |
}
|
| 61 |
],
|
| 62 |
"source": [
|
| 63 |
+
"preprocess.preprocess_dataset(mssv_path, mssv_preprocessed_path)"
|
| 64 |
]
|
| 65 |
},
|
| 66 |
{
|
| 67 |
"cell_type": "code",
|
| 68 |
+
"execution_count": 4,
|
| 69 |
"id": "2d4c561c",
|
| 70 |
"metadata": {},
|
| 71 |
"outputs": [
|
|
|
|
| 3029 |
" 'duration': 480}]"
|
| 3030 |
]
|
| 3031 |
},
|
| 3032 |
+
"execution_count": 4,
|
| 3033 |
"metadata": {},
|
| 3034 |
"output_type": "execute_result"
|
| 3035 |
}
|
| 3036 |
],
|
| 3037 |
"source": [
|
| 3038 |
+
"mssv_sample_list, ticks_per_beat = preprocess.midi_to_note_list(mssv_sample_midi)\n",
|
| 3039 |
"mssv_sample_list "
|
| 3040 |
]
|
| 3041 |
},
|
| 3042 |
{
|
| 3043 |
"cell_type": "code",
|
| 3044 |
+
"execution_count": null,
|
| 3045 |
"id": "0bac9307",
|
| 3046 |
"metadata": {},
|
| 3047 |
"outputs": [
|
|
|
|
| 3202 |
}
|
| 3203 |
],
|
| 3204 |
"source": [
|
| 3205 |
+
"df = preprocess.preprocess_notes(mssv_sample_list, ticks_per_beat=ticks_per_beat)\n",
|
| 3206 |
"df"
|
| 3207 |
]
|
| 3208 |
},
|
| 3209 |
{
|
| 3210 |
"cell_type": "code",
|
| 3211 |
+
"execution_count": 5,
|
| 3212 |
"id": "c0b7965b",
|
| 3213 |
"metadata": {},
|
| 3214 |
"outputs": [
|
|
|
|
| 3224 |
}
|
| 3225 |
],
|
| 3226 |
"source": [
|
| 3227 |
+
"singer_id = preprocess.singer_id_from_filepath(mssv_sample_midi)\n",
|
| 3228 |
"singer_id "
|
| 3229 |
]
|
| 3230 |
},
|
|
|
|
| 3238 |
},
|
| 3239 |
{
|
| 3240 |
"cell_type": "code",
|
| 3241 |
+
"execution_count": 6,
|
| 3242 |
"id": "0e0ef1e4",
|
| 3243 |
"metadata": {},
|
| 3244 |
"outputs": [],
|
|
|
|
| 3249 |
},
|
| 3250 |
{
|
| 3251 |
"cell_type": "code",
|
| 3252 |
+
"execution_count": 7,
|
| 3253 |
"id": "c519f235",
|
| 3254 |
"metadata": {},
|
| 3255 |
"outputs": [
|
|
|
|
| 3262 |
"})"
|
| 3263 |
]
|
| 3264 |
},
|
| 3265 |
+
"execution_count": 7,
|
| 3266 |
"metadata": {},
|
| 3267 |
"output_type": "execute_result"
|
| 3268 |
}
|
|
|
|
| 3276 |
},
|
| 3277 |
{
|
| 3278 |
"cell_type": "code",
|
| 3279 |
+
"execution_count": 8,
|
| 3280 |
"id": "6f62d4d5",
|
| 3281 |
"metadata": {},
|
| 3282 |
"outputs": [
|
|
|
|
| 3286 |
"(4204, 4204)"
|
| 3287 |
]
|
| 3288 |
},
|
| 3289 |
+
"execution_count": 8,
|
| 3290 |
"metadata": {},
|
| 3291 |
"output_type": "execute_result"
|
| 3292 |
}
|
|
|
|
| 3297 |
},
|
| 3298 |
{
|
| 3299 |
"cell_type": "code",
|
| 3300 |
+
"execution_count": 9,
|
| 3301 |
"id": "ff7ec906",
|
| 3302 |
"metadata": {},
|
| 3303 |
"outputs": [
|
|
|
|
| 3307 |
"(4204, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, 18)"
|
| 3308 |
]
|
| 3309 |
},
|
| 3310 |
+
"execution_count": 9,
|
| 3311 |
"metadata": {},
|
| 3312 |
"output_type": "execute_result"
|
| 3313 |
}
|
|
|
|
| 3318 |
},
|
| 3319 |
{
|
| 3320 |
"cell_type": "code",
|
| 3321 |
+
"execution_count": 10,
|
| 3322 |
"id": "a625ec7c",
|
| 3323 |
"metadata": {},
|
| 3324 |
"outputs": [
|
|
|
|
| 3338 |
},
|
| 3339 |
{
|
| 3340 |
"cell_type": "code",
|
| 3341 |
+
"execution_count": 11,
|
| 3342 |
"id": "e22dc42e",
|
| 3343 |
"metadata": {},
|
| 3344 |
"outputs": [
|
|
|
|
| 3348 |
"True"
|
| 3349 |
]
|
| 3350 |
},
|
| 3351 |
+
"execution_count": 11,
|
| 3352 |
"metadata": {},
|
| 3353 |
"output_type": "execute_result"
|
| 3354 |
}
|
|
|
|
| 3359 |
},
|
| 3360 |
{
|
| 3361 |
"cell_type": "code",
|
| 3362 |
+
"execution_count": 12,
|
| 3363 |
"id": "aa41450d",
|
| 3364 |
"metadata": {},
|
| 3365 |
"outputs": [
|
|
|
|
| 3369 |
"(817, 817, 2)"
|
| 3370 |
]
|
| 3371 |
},
|
| 3372 |
+
"execution_count": 12,
|
| 3373 |
"metadata": {},
|
| 3374 |
"output_type": "execute_result"
|
| 3375 |
}
|
|
|
|
| 3378 |
"len(train_dataset[1]['durations']), len(train_dataset[1]['quantized_durations']), train_dataset[1]['singer_id']"
|
| 3379 |
]
|
| 3380 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3381 |
{
|
| 3382 |
"cell_type": "markdown",
|
| 3383 |
"id": "3828f8dd",
|
|
|
|
| 3393 |
"metadata": {},
|
| 3394 |
"outputs": [],
|
| 3395 |
"source": [
|
| 3396 |
+
"import toy_duration_predictor as tdp"
|
| 3397 |
]
|
| 3398 |
},
|
| 3399 |
{
|
|
|
|
| 3409 |
"\n",
|
| 3410 |
"==================== FINAL MODEL COMPARISON ====================\n",
|
| 3411 |
"Using device for evaluation: cuda:0\n",
|
| 3412 |
+
"--- Preparing test data ---\n",
|
| 3413 |
+
"Calculated training set stats: Mean=128.66, Std=201.76\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3414 |
"Evaluating models on the test set...\n"
|
| 3415 |
]
|
| 3416 |
},
|
| 3417 |
{
|
| 3418 |
"data": {
|
| 3419 |
"application/vnd.jupyter.widget-view+json": {
|
| 3420 |
+
"model_id": "de7de013364a40fa8d9af567b9140063",
|
| 3421 |
"version_major": 2,
|
| 3422 |
"version_minor": 0
|
| 3423 |
},
|
|
|
|
| 3441 |
"source": [
|
| 3442 |
"model_a_path = 'model_stable/model_A.pth'\n",
|
| 3443 |
"model_b_path = 'model_stable/model_B.pth'\n",
|
| 3444 |
+
"tdp.evaluate_and_compare(model_a_path, model_b_path, gpu_id=0)"
|
| 3445 |
]
|
| 3446 |
},
|
| 3447 |
{
|
|
|
|
| 3454 |
},
|
| 3455 |
{
|
| 3456 |
"cell_type": "code",
|
| 3457 |
+
"execution_count": null,
|
| 3458 |
"id": "7a54b60d",
|
| 3459 |
"metadata": {},
|
| 3460 |
"outputs": [],
|
| 3461 |
"source": [
|
| 3462 |
+
"import toy_duration_predictor as tdp"
|
| 3463 |
]
|
| 3464 |
},
|
| 3465 |
{
|
| 3466 |
"cell_type": "code",
|
| 3467 |
+
"execution_count": null,
|
| 3468 |
"id": "fa70b5bf",
|
| 3469 |
"metadata": {},
|
| 3470 |
"outputs": [
|
|
|
|
| 3551 |
}
|
| 3552 |
],
|
| 3553 |
"source": [
|
| 3554 |
+
"tdp.upload_models_to_hub()"
|
| 3555 |
]
|
| 3556 |
},
|
| 3557 |
{
|
|
|
|
| 3569 |
"metadata": {},
|
| 3570 |
"outputs": [],
|
| 3571 |
"source": [
|
| 3572 |
+
"import toy_duration_predictor as tdp"
|
| 3573 |
]
|
| 3574 |
},
|
| 3575 |
{
|
|
|
|
| 3598 |
{
|
| 3599 |
"data": {
|
| 3600 |
"application/vnd.jupyter.widget-view+json": {
|
| 3601 |
+
"model_id": "65b38f6fa7724d73b91fc332eda22893",
|
| 3602 |
"version_major": 2,
|
| 3603 |
"version_minor": 0
|
| 3604 |
},
|
|
|
|
| 3620 |
}
|
| 3621 |
],
|
| 3622 |
"source": [
|
| 3623 |
+
"tdp.load_and_test()"
|
| 3624 |
]
|
| 3625 |
},
|
| 3626 |
{
|