ccss17 committed on
Commit
3e7ca5f
·
1 Parent(s): 89760e4

fixed model

Browse files
Dockerfile DELETED
@@ -1,14 +0,0 @@
1
- FROM python:3.12-slim
2
- # Set the working directory
3
- # All subsequent commands will run from here.
4
- WORKDIR /app
5
- RUN apt-get update && apt-get upgrade -y && rm -rf /var/lib/apt/lists/*
6
- # Copy project files into the container.
7
- # This copies everything from your local directory into the container's /app directory.
8
- COPY . .
9
- # Install all the Python packages listed in your requirements.txt.
10
- # The --no-cache-dir flag keeps the final image size smaller.
11
- RUN pip install --no-cache-dir --upgrade -r requirements.txt
12
- # Define the default command to run when the Space starts.
13
- # This launches a JupyterLab server that is accessible from the web.
14
- CMD ["jupyter", "lab", "--ip=0.0.0.0", "--port=7860", "--allow-root", "--NotebookApp.token=''"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -14,3 +14,9 @@ pip install git+https://github.com/ccss17/toy-duration-predictor.git
14
  ```python
15
  import toy_duration_predictor as tdp
16
  ```
 
 
 
 
 
 
 
14
  ```python
15
  import toy_duration_predictor as tdp
16
  ```
17
+
18
+ ## Hugging Face
19
+
20
+ - Model: https://huggingface.co/ccss17/toy-duration-predictor
21
+ - Dataset: https://huggingface.co/datasets/ccss17/note-duration-dataset
22
+ - Spaces: https://huggingface.co/spaces/ccss17/toy-duration-predictor
requirements-dev.txt DELETED
@@ -1,162 +0,0 @@
1
- -e .
2
- aiohappyeyeballs==2.6.1
3
- aiohttp==3.12.13
4
- aiohttp-cors==0.8.1
5
- aiosignal==1.3.2
6
- annotated-types==0.7.0
7
- anyio==4.9.0
8
- appnope==0.1.4 ; sys_platform == 'darwin'
9
- asttokens==3.0.0
10
- async-timeout==5.0.1 ; python_full_version < '3.11'
11
- attrs==25.3.0
12
- black==25.1.0
13
- cachetools==5.5.2
14
- certifi==2025.6.15
15
- cffi==1.17.1 ; implementation_name == 'pypy'
16
- charset-normalizer==3.4.2
17
- click==8.1.8 ; python_full_version < '3.10'
18
- click==8.2.1 ; python_full_version >= '3.10'
19
- colorama==0.4.6 ; sys_platform == 'win32'
20
- colorful==0.5.6
21
- comm==0.2.2
22
- datasets==3.6.0
23
- debugpy==1.8.14
24
- decorator==5.2.1
25
- dill==0.3.8
26
- distlib==0.3.9
27
- exceptiongroup==1.3.0 ; python_full_version < '3.11'
28
- executing==2.2.0
29
- fastapi==0.115.13
30
- filelock==3.18.0
31
- frozenlist==1.7.0
32
- fsspec==2025.3.0
33
- google-api-core==2.25.1
34
- google-auth==2.40.3
35
- googleapis-common-protos==1.70.0
36
- grpcio==1.73.0
37
- h11==0.16.0
38
- hf-xet==1.1.5 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
39
- httptools==0.6.4
40
- huggingface-hub==0.33.0
41
- idna==3.10
42
- importlib-metadata==8.7.0
43
- ipykernel==6.29.5
44
- ipython==8.18.1 ; python_full_version < '3.10'
45
- ipython==8.37.0 ; python_full_version == '3.10.*'
46
- ipython==9.3.0 ; python_full_version >= '3.11'
47
- ipython-pygments-lexers==1.1.1 ; python_full_version >= '3.11'
48
- ipywidgets==8.1.7
49
- jedi==0.19.2
50
- jinja2==3.1.6
51
- jsonschema==4.24.0
52
- jsonschema-specifications==2025.4.1
53
- jupyter-client==8.6.3
54
- jupyter-core==5.8.1
55
- jupyterlab-widgets==3.0.15
56
- llvmlite==0.44.0 ; python_full_version >= '3.10'
57
- markdown-it-py==3.0.0
58
- markupsafe==2.1.5 ; python_full_version < '3.10'
59
- markupsafe==3.0.2 ; python_full_version >= '3.10'
60
- matplotlib-inline==0.1.7
61
- mdurl==0.1.2
62
- midii==0.1.19 ; python_full_version < '3.10'
63
- midii==0.1.41 ; python_full_version >= '3.10'
64
- mido==1.3.3
65
- mpmath==1.3.0
66
- msgpack==1.1.1
67
- multidict==6.5.0
68
- multiprocess==0.70.16
69
- mypy-extensions==1.1.0
70
- nest-asyncio==1.6.0
71
- networkx==3.2.1 ; python_full_version < '3.10'
72
- networkx==3.4.2 ; python_full_version == '3.10.*'
73
- networkx==3.5 ; python_full_version >= '3.11'
74
- numba==0.61.2 ; python_full_version >= '3.10'
75
- numpy==2.0.2 ; python_full_version < '3.10'
76
- numpy==2.2.6 ; python_full_version >= '3.10'
77
- nvidia-cublas-cu12==12.6.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
78
- nvidia-cuda-cupti-cu12==12.6.80 ; platform_machine == 'x86_64' and sys_platform == 'linux'
79
- nvidia-cuda-nvrtc-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
80
- nvidia-cuda-runtime-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
81
- nvidia-cudnn-cu12==9.5.1.17 ; platform_machine == 'x86_64' and sys_platform == 'linux'
82
- nvidia-cufft-cu12==11.3.0.4 ; platform_machine == 'x86_64' and sys_platform == 'linux'
83
- nvidia-cufile-cu12==1.11.1.6 ; platform_machine == 'x86_64' and sys_platform == 'linux'
84
- nvidia-curand-cu12==10.3.7.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
85
- nvidia-cusolver-cu12==11.7.1.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
86
- nvidia-cusparse-cu12==12.5.4.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
87
- nvidia-cusparselt-cu12==0.6.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
88
- nvidia-nccl-cu12==2.26.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
89
- nvidia-nvjitlink-cu12==12.6.85 ; platform_machine == 'x86_64' and sys_platform == 'linux'
90
- nvidia-nvtx-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
91
- opencensus==0.11.4
92
- opencensus-context==0.1.3
93
- opentelemetry-api==1.34.1
94
- opentelemetry-exporter-prometheus==0.55b1
95
- opentelemetry-proto==1.34.1
96
- opentelemetry-sdk==1.34.1
97
- opentelemetry-semantic-conventions==0.55b1
98
- packaging==25.0
99
- pandas==2.3.0
100
- parso==0.8.4
101
- pathspec==0.12.1
102
- pexpect==4.9.0 ; (python_full_version < '3.10' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32')
103
- platformdirs==4.3.8
104
- prometheus-client==0.22.1
105
- prompt-toolkit==3.0.51
106
- propcache==0.3.2
107
- proto-plus==1.26.1
108
- protobuf==5.29.5
109
- psutil==7.0.0
110
- ptyprocess==0.7.0 ; (python_full_version < '3.10' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32')
111
- pure-eval==0.2.3
112
- py-spy==0.4.0
113
- pyarrow==17.0.0
114
- pyasn1==0.6.1
115
- pyasn1-modules==0.4.2
116
- pycparser==2.22 ; implementation_name == 'pypy'
117
- pydantic==2.11.7
118
- pydantic-core==2.33.2
119
- pygments==2.19.2
120
- python-dateutil==2.9.0.post0
121
- python-dotenv==1.1.1
122
- pytz==2025.2
123
- pywin32==310 ; platform_python_implementation != 'PyPy' and sys_platform == 'win32'
124
- pyyaml==6.0.2
125
- pyzmq==27.0.0
126
- ray==2.47.1
127
- referencing==0.36.2
128
- requests==2.32.4
129
- rich==14.0.0
130
- rpds-py==0.25.1
131
- rsa==4.9.1
132
- ruff==0.12.0
133
- setuptools==80.9.0 ; (python_full_version >= '3.12' and platform_machine != 'x86_64') or (python_full_version >= '3.12' and sys_platform != 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')
134
- six==1.17.0
135
- smart-open==7.1.0
136
- sniffio==1.3.1
137
- stack-data==0.6.3
138
- starlette==0.46.2
139
- sympy==1.14.0
140
- tensorboardx==2.6.4
141
- tomli==2.2.1 ; python_full_version < '3.11'
142
- torch==2.7.1
143
- tornado==6.5.1
144
- tqdm==4.67.1
145
- traitlets==5.14.3
146
- triton==3.3.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
147
- typing-extensions==4.14.0
148
- typing-inspection==0.4.1
149
- tzdata==2025.2
150
- urllib3==2.5.0
151
- uvicorn==0.34.3
152
- uvloop==0.21.0 ; platform_python_implementation != 'PyPy' and sys_platform != 'cygwin' and sys_platform != 'win32'
153
- virtualenv==20.31.2
154
- watchfiles==1.1.0
155
- wcwidth==0.2.13
156
- websockets==12.0 ; python_full_version < '3.10'
157
- websockets==15.0.1 ; python_full_version >= '3.10'
158
- widgetsnbextension==4.0.14
159
- wrapt==1.17.2
160
- xxhash==3.5.0
161
- yarl==1.20.1
162
- zipp==3.23.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,23 +1,30 @@
1
- -e .
2
  aiohappyeyeballs==2.6.1
3
  aiohttp==3.12.13
4
  aiohttp-cors==0.8.1
5
  aiosignal==1.3.2
6
  annotated-types==0.7.0
7
  anyio==4.9.0
 
 
8
  async-timeout==5.0.1 ; python_full_version < '3.11'
9
  attrs==25.3.0
 
10
  cachetools==5.5.2
11
  certifi==2025.6.15
 
12
  charset-normalizer==3.4.2
13
  click==8.1.8 ; python_full_version < '3.10'
14
  click==8.2.1 ; python_full_version >= '3.10'
15
  colorama==0.4.6 ; sys_platform == 'win32'
16
  colorful==0.5.6
 
17
  datasets==3.6.0
 
 
18
  dill==0.3.8
19
  distlib==0.3.9
20
  exceptiongroup==1.3.0 ; python_full_version < '3.11'
 
21
  fastapi==0.115.13
22
  filelock==3.18.0
23
  frozenlist==1.7.0
@@ -32,13 +39,24 @@ httptools==0.6.4
32
  huggingface-hub==0.33.0
33
  idna==3.10
34
  importlib-metadata==8.7.0
 
 
 
 
 
 
 
35
  jinja2==3.1.6
36
  jsonschema==4.24.0
37
  jsonschema-specifications==2025.4.1
 
 
 
38
  llvmlite==0.44.0 ; python_full_version >= '3.10'
39
  markdown-it-py==3.0.0
40
  markupsafe==2.1.5 ; python_full_version < '3.10'
41
  markupsafe==3.0.2 ; python_full_version >= '3.10'
 
42
  mdurl==0.1.2
43
  midii==0.1.19 ; python_full_version < '3.10'
44
  midii==0.1.41 ; python_full_version >= '3.10'
@@ -47,6 +65,8 @@ mpmath==1.3.0
47
  msgpack==1.1.1
48
  multidict==6.5.0
49
  multiprocess==0.70.16
 
 
50
  networkx==3.2.1 ; python_full_version < '3.10'
51
  networkx==3.4.2 ; python_full_version == '3.10.*'
52
  networkx==3.5 ; python_full_version >= '3.11'
@@ -76,37 +96,52 @@ opentelemetry-sdk==1.34.1
76
  opentelemetry-semantic-conventions==0.55b1
77
  packaging==25.0
78
  pandas==2.3.0
 
 
 
79
  platformdirs==4.3.8
80
  prometheus-client==0.22.1
 
81
  propcache==0.3.2
82
  proto-plus==1.26.1
83
  protobuf==5.29.5
 
 
 
84
  py-spy==0.4.0
85
  pyarrow==17.0.0
86
  pyasn1==0.6.1
87
  pyasn1-modules==0.4.2
 
88
  pydantic==2.11.7
89
  pydantic-core==2.33.2
90
  pygments==2.19.2
91
  python-dateutil==2.9.0.post0
92
  python-dotenv==1.1.1
93
  pytz==2025.2
 
94
  pyyaml==6.0.2
 
95
  ray==2.47.1
96
  referencing==0.36.2
97
  requests==2.32.4
98
  rich==14.0.0
99
  rpds-py==0.25.1
100
  rsa==4.9.1
 
101
  setuptools==80.9.0 ; (python_full_version >= '3.12' and platform_machine != 'x86_64') or (python_full_version >= '3.12' and sys_platform != 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')
102
  six==1.17.0
103
  smart-open==7.1.0
104
  sniffio==1.3.1
 
105
  starlette==0.46.2
106
  sympy==1.14.0
107
  tensorboardx==2.6.4
 
108
  torch==2.7.1
 
109
  tqdm==4.67.1
 
110
  triton==3.3.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
111
  typing-extensions==4.14.0
112
  typing-inspection==0.4.1
@@ -116,9 +151,14 @@ uvicorn==0.34.3
116
  uvloop==0.21.0 ; platform_python_implementation != 'PyPy' and sys_platform != 'cygwin' and sys_platform != 'win32'
117
  virtualenv==20.31.2
118
  watchfiles==1.1.0
 
119
  websockets==12.0 ; python_full_version < '3.10'
120
  websockets==15.0.1 ; python_full_version >= '3.10'
 
121
  wrapt==1.17.2
122
  xxhash==3.5.0
123
  yarl==1.20.1
124
  zipp==3.23.0
 
 
 
 
 
1
  aiohappyeyeballs==2.6.1
2
  aiohttp==3.12.13
3
  aiohttp-cors==0.8.1
4
  aiosignal==1.3.2
5
  annotated-types==0.7.0
6
  anyio==4.9.0
7
+ appnope==0.1.4 ; sys_platform == 'darwin'
8
+ asttokens==3.0.0
9
  async-timeout==5.0.1 ; python_full_version < '3.11'
10
  attrs==25.3.0
11
+ black==25.1.0
12
  cachetools==5.5.2
13
  certifi==2025.6.15
14
+ cffi==1.17.1 ; implementation_name == 'pypy'
15
  charset-normalizer==3.4.2
16
  click==8.1.8 ; python_full_version < '3.10'
17
  click==8.2.1 ; python_full_version >= '3.10'
18
  colorama==0.4.6 ; sys_platform == 'win32'
19
  colorful==0.5.6
20
+ comm==0.2.2
21
  datasets==3.6.0
22
+ debugpy==1.8.14
23
+ decorator==5.2.1
24
  dill==0.3.8
25
  distlib==0.3.9
26
  exceptiongroup==1.3.0 ; python_full_version < '3.11'
27
+ executing==2.2.0
28
  fastapi==0.115.13
29
  filelock==3.18.0
30
  frozenlist==1.7.0
 
39
  huggingface-hub==0.33.0
40
  idna==3.10
41
  importlib-metadata==8.7.0
42
+ ipykernel==6.29.5
43
+ ipython==8.18.1 ; python_full_version < '3.10'
44
+ ipython==8.37.0 ; python_full_version == '3.10.*'
45
+ ipython==9.3.0 ; python_full_version >= '3.11'
46
+ ipython-pygments-lexers==1.1.1 ; python_full_version >= '3.11'
47
+ ipywidgets==8.1.7
48
+ jedi==0.19.2
49
  jinja2==3.1.6
50
  jsonschema==4.24.0
51
  jsonschema-specifications==2025.4.1
52
+ jupyter-client==8.6.3
53
+ jupyter-core==5.8.1
54
+ jupyterlab-widgets==3.0.15
55
  llvmlite==0.44.0 ; python_full_version >= '3.10'
56
  markdown-it-py==3.0.0
57
  markupsafe==2.1.5 ; python_full_version < '3.10'
58
  markupsafe==3.0.2 ; python_full_version >= '3.10'
59
+ matplotlib-inline==0.1.7
60
  mdurl==0.1.2
61
  midii==0.1.19 ; python_full_version < '3.10'
62
  midii==0.1.41 ; python_full_version >= '3.10'
 
65
  msgpack==1.1.1
66
  multidict==6.5.0
67
  multiprocess==0.70.16
68
+ mypy-extensions==1.1.0
69
+ nest-asyncio==1.6.0
70
  networkx==3.2.1 ; python_full_version < '3.10'
71
  networkx==3.4.2 ; python_full_version == '3.10.*'
72
  networkx==3.5 ; python_full_version >= '3.11'
 
96
  opentelemetry-semantic-conventions==0.55b1
97
  packaging==25.0
98
  pandas==2.3.0
99
+ parso==0.8.4
100
+ pathspec==0.12.1
101
+ pexpect==4.9.0 ; (python_full_version < '3.10' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32')
102
  platformdirs==4.3.8
103
  prometheus-client==0.22.1
104
+ prompt-toolkit==3.0.51
105
  propcache==0.3.2
106
  proto-plus==1.26.1
107
  protobuf==5.29.5
108
+ psutil==7.0.0
109
+ ptyprocess==0.7.0 ; (python_full_version < '3.10' and sys_platform == 'emscripten') or (sys_platform != 'emscripten' and sys_platform != 'win32')
110
+ pure-eval==0.2.3
111
  py-spy==0.4.0
112
  pyarrow==17.0.0
113
  pyasn1==0.6.1
114
  pyasn1-modules==0.4.2
115
+ pycparser==2.22 ; implementation_name == 'pypy'
116
  pydantic==2.11.7
117
  pydantic-core==2.33.2
118
  pygments==2.19.2
119
  python-dateutil==2.9.0.post0
120
  python-dotenv==1.1.1
121
  pytz==2025.2
122
+ pywin32==310 ; platform_python_implementation != 'PyPy' and sys_platform == 'win32'
123
  pyyaml==6.0.2
124
+ pyzmq==27.0.0
125
  ray==2.47.1
126
  referencing==0.36.2
127
  requests==2.32.4
128
  rich==14.0.0
129
  rpds-py==0.25.1
130
  rsa==4.9.1
131
+ ruff==0.12.0
132
  setuptools==80.9.0 ; (python_full_version >= '3.12' and platform_machine != 'x86_64') or (python_full_version >= '3.12' and sys_platform != 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')
133
  six==1.17.0
134
  smart-open==7.1.0
135
  sniffio==1.3.1
136
+ stack-data==0.6.3
137
  starlette==0.46.2
138
  sympy==1.14.0
139
  tensorboardx==2.6.4
140
+ tomli==2.2.1 ; python_full_version < '3.11'
141
  torch==2.7.1
142
+ tornado==6.5.1
143
  tqdm==4.67.1
144
+ traitlets==5.14.3
145
  triton==3.3.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
146
  typing-extensions==4.14.0
147
  typing-inspection==0.4.1
 
151
  uvloop==0.21.0 ; platform_python_implementation != 'PyPy' and sys_platform != 'cygwin' and sys_platform != 'win32'
152
  virtualenv==20.31.2
153
  watchfiles==1.1.0
154
+ wcwidth==0.2.13
155
  websockets==12.0 ; python_full_version < '3.10'
156
  websockets==15.0.1 ; python_full_version >= '3.10'
157
+ widgetsnbextension==4.0.14
158
  wrapt==1.17.2
159
  xxhash==3.5.0
160
  yarl==1.20.1
161
  zipp==3.23.0
162
+
163
+ jupyterlab
164
+ git+https://github.com/ccss17/toy-duration-predictor.git
src/toy_duration_predictor/__init__.py CHANGED
@@ -1,2 +1,3 @@
1
- # from .preprocess import mssv
2
- # from .preprocess import utils
 
 
1
+ from .train import *
2
+ from .upload import *
3
+ from .load import *
src/toy_duration_predictor/preprocess/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .mssv import *
src/toy_duration_predictor/preprocess/mssv.py CHANGED
@@ -137,7 +137,7 @@ def process_midi_flat_map(row: Dict[str, Any]) -> List[Dict[str, Any]]:
137
 
138
 
139
  def preprocess_dataset(midi_file_directory, output_parquet_path):
140
- context = ray.init()
141
  print(context.dashboard_url)
142
 
143
  all_midi_paths = Path(midi_file_directory).rglob("*.mid")
 
137
 
138
 
139
  def preprocess_dataset(midi_file_directory, output_parquet_path):
140
+ context = ray.init(ignore_reinit_error=True)
141
  print(context.dashboard_url)
142
 
143
  all_midi_paths = Path(midi_file_directory).rglob("*.mid")
src/toy_duration_predictor/train.py CHANGED
@@ -12,10 +12,10 @@ BATCH_SIZE = 64
12
 
13
  MODEL_CONFIG = {
14
  "num_singers": 18,
15
- "singer_embedding_dim": 64,
16
- "hidden_size": 1024,
17
- "num_layers": 4,
18
- "dropout": 0.5,
19
  }
20
 
21
  LEARNING_RATE = 1e-4
 
12
 
13
  MODEL_CONFIG = {
14
  "num_singers": 18,
15
+ "singer_embedding_dim": 32,
16
+ "hidden_size": 256,
17
+ "num_layers": 3,
18
+ "dropout": 0.4,
19
  }
20
 
21
  LEARNING_RATE = 1e-4
test.ipynb CHANGED
@@ -10,17 +10,17 @@
10
  },
11
  {
12
  "cell_type": "code",
13
- "execution_count": null,
14
  "id": "d05b06c2",
15
  "metadata": {},
16
  "outputs": [],
17
  "source": [
18
- "from toy_duration_predictor.preprocess import mssv"
19
  ]
20
  },
21
  {
22
  "cell_type": "code",
23
- "execution_count": 6,
24
  "id": "77c0576d",
25
  "metadata": {},
26
  "outputs": [],
@@ -37,7 +37,7 @@
37
  },
38
  {
39
  "cell_type": "code",
40
- "execution_count": 3,
41
  "id": "ce7e5029",
42
  "metadata": {},
43
  "outputs": [
@@ -45,138 +45,27 @@
45
  "name": "stderr",
46
  "output_type": "stream",
47
  "text": [
48
- "2025-06-23 15:44:45,964\tINFO worker.py:1908 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n"
49
- ]
50
- },
51
- {
52
- "name": "stdout",
53
- "output_type": "stream",
54
- "text": [
55
- "127.0.0.1:8265\n",
56
- "Ray cluster started: {'CPU': 12.0, 'node:__internal_head__': 1.0, 'object_store_memory': 918235545.0, 'node:172.20.119.165': 1.0, 'memory': 2142549607.0}\n",
57
- "Launching parallel processing tasks...\n",
58
- "\u001b[36m(process_midi_to_dict pid=30235)\u001b[0m Error processing /mnt/d/dataset/004.다화자 가창 데이터/01.데이터/1.Training/라벨링데이터/01.발라드R&B/A. 남성/01. 20대/가창자_s02/ba_05688_-4_a_s02_m_02.mid: 'quantized duration'\n",
59
- "\u001b[36m(process_midi_to_dict pid=30235)\u001b[0m Error processing /mnt/d/dataset/004.다화자 가창 데이터/01.데이터/1.Training/라벨링데이터/01.발라드R&B/A. 남성/02. 30대/가창자_s09/ba_14343_-1_a_s09_m_03.mid: 'quantized duration'\u001b[32m [repeated 208x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)\u001b[0m\n"
60
- ]
61
- },
62
- {
63
- "name": "stderr",
64
- "output_type": "stream",
65
- "text": [
66
- "2025-06-23 15:44:51,225\tWARNING arrow.py:165 -- Failed to convert column 'item' into pyarrow array due to: Error converting data to Arrow: [ObjectRef(c8ef45ccd0112571ffffffffffffffffffffffff0100000001000000), ObjectRef(16310a0f0a45af5cffffffffffffffffffffffff0100000001000000), ObjectRef(c2668a65bda616c1ffffffffffffffffffffffff01000000010...; falling back to serialize as pickled python objects\n",
67
- "Traceback (most recent call last):\n",
68
- " File \"/home/ccsss/repo/toy-duration-predictor/.venv/lib/python3.12/site-packages/ray/air/util/tensor_extensions/arrow.py\", line 204, in _convert_to_pyarrow_native_array\n",
69
- " pa_type = _infer_pyarrow_type(column_values)\n",
70
- " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
71
- " File \"/home/ccsss/repo/toy-duration-predictor/.venv/lib/python3.12/site-packages/ray/air/util/tensor_extensions/arrow.py\", line 317, in _infer_pyarrow_type\n",
72
- " inferred_pa_dtype = pa.infer_type(column_values)\n",
73
- " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
74
- " File \"pyarrow/array.pxi\", line 564, in pyarrow.lib.infer_type\n",
75
- " File \"pyarrow/error.pxi\", line 155, in pyarrow.lib.pyarrow_internal_check_status\n",
76
- " File \"pyarrow/error.pxi\", line 92, in pyarrow.lib.check_status\n",
77
- "pyarrow.lib.ArrowInvalid: Could not convert ObjectRef(c8ef45ccd0112571ffffffffffffffffffffffff0100000001000000) with type ray._raylet.ObjectRef: did not recognize Python value type when inferring an Arrow data type\n",
78
- "\n",
79
- "The above exception was the direct cause of the following exception:\n",
80
- "\n",
81
- "Traceback (most recent call last):\n",
82
- " File \"/home/ccsss/repo/toy-duration-predictor/.venv/lib/python3.12/site-packages/ray/air/util/tensor_extensions/arrow.py\", line 117, in convert_to_pyarrow_array\n",
83
- " return _convert_to_pyarrow_native_array(column_values, column_name)\n",
84
- " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
85
- " File \"/home/ccsss/repo/toy-duration-predictor/.venv/lib/python3.12/site-packages/ray/air/util/tensor_extensions/arrow.py\", line 231, in _convert_to_pyarrow_native_array\n",
86
- " raise ArrowConversionError(str(column_values)) from e\n",
87
- "ray.air.util.tensor_extensions.arrow.ArrowConversionError: Error converting data to Arrow: [ObjectRef(c8ef45ccd0112571ffffffffffffffffffffffff0100000001000000), ObjectRef(16310a0f0a45af5cffffffffffffffffffffffff0100000001000000), ObjectRef(c2668a65bda616c1ffffffffffffffffffffffff01000000010...\n"
88
- ]
89
- },
90
- {
91
- "name": "stdout",
92
- "output_type": "stream",
93
- "text": [
94
- "Creating a Ray Dataset from results...\n"
95
- ]
96
- },
97
- {
98
- "name": "stderr",
99
- "output_type": "stream",
100
- "text": [
101
- "2025-06-23 15:44:53,181\tINFO dataset.py:3046 -- Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.\n",
102
- "2025-06-23 15:44:53,357\tINFO logging.py:295 -- Registered dataset logger for dataset dataset_1_0\n"
103
- ]
104
- },
105
- {
106
- "name": "stdout",
107
- "output_type": "stream",
108
- "text": [
109
- "\n",
110
- "Dataset schema:\n",
111
- "\n",
112
- "First 5 rows:\n"
113
- ]
114
- },
115
- {
116
- "name": "stderr",
117
- "output_type": "stream",
118
- "text": [
119
- "2025-06-23 15:44:53,393\tINFO streaming_executor.py:117 -- Starting execution of Dataset dataset_1_0. Full logs are in /tmp/ray/session_2025-06-23_15-44-44_311677_29614/logs/ray-data\n",
120
- "2025-06-23 15:44:53,395\tINFO streaming_executor.py:118 -- Execution plan of Dataset dataset_1_0: InputDataBuffer[Input] -> LimitOperator[limit=5]\n"
121
- ]
122
- },
123
- {
124
- "name": "stdout",
125
- "output_type": "stream",
126
- "text": [
127
- "[dataset]: Run `pip install tqdm` to enable progress reporting.\n"
128
- ]
129
- },
130
- {
131
- "name": "stderr",
132
- "output_type": "stream",
133
- "text": [
134
- "2025-06-23 15:44:53,647\tINFO streaming_executor.py:227 -- ✔️ Dataset dataset_1_0 execution finished in 0.25 seconds\n",
135
- "2025-06-23 15:44:53,798\tWARNING plan.py:472 -- Warning: The Ray cluster currently does not have any available CPUs. The Dataset job will hang unless more CPUs are freed up. A common reason is that cluster resources are used by Actors or Tune trials; see the following link for more details: https://docs.ray.io/en/latest/data/data-internals.html#ray-data-and-tune\n",
136
- "2025-06-23 15:44:53,799\tINFO logging.py:295 -- Registered dataset logger for dataset dataset_4_0\n",
137
- "2025-06-23 15:44:53,806\tINFO streaming_executor.py:117 -- Starting execution of Dataset dataset_4_0. Full logs are in /tmp/ray/session_2025-06-23_15-44-44_311677_29614/logs/ray-data\n",
138
- "2025-06-23 15:44:53,809\tINFO streaming_executor.py:118 -- Execution plan of Dataset dataset_4_0: InputDataBuffer[Input] -> AllToAllOperator[Repartition] -> TaskPoolMapOperator[Write]\n"
139
- ]
140
- },
141
- {
142
- "name": "stdout",
143
- "output_type": "stream",
144
- "text": [
145
- "{'item': ObjectRef(c8ef45ccd0112571ffffffffffffffffffffffff0100000001000000)}\n",
146
- "{'item': ObjectRef(16310a0f0a45af5cffffffffffffffffffffffff0100000001000000)}\n",
147
- "{'item': ObjectRef(c2668a65bda616c1ffffffffffffffffffffffff0100000001000000)}\n",
148
- "{'item': ObjectRef(32d950ec0ccf9d2affffffffffffffffffffffff0100000001000000)}\n",
149
- "{'item': ObjectRef(e0dc174c83599034ffffffffffffffffffffffff0100000001000000)}\n",
150
- "Repartitioning dataset to control output file size...\n",
151
- "\n",
152
- "Writing dataset to Parquet format at: /mnt/d/dataset/mssv_preprocessed_duration\n"
153
- ]
154
- },
155
- {
156
- "name": "stderr",
157
- "output_type": "stream",
158
- "text": [
159
- "2025-06-23 15:44:55,806\tINFO streaming_executor.py:227 -- ✔️ Dataset dataset_4_0 execution finished in 2.00 seconds\n",
160
- "2025-06-23 15:45:01,968\tINFO dataset.py:4601 -- Data sink Parquet finished. 4205 rows and 839.9KB data written.\n"
161
- ]
162
- },
163
- {
164
- "name": "stdout",
165
- "output_type": "stream",
166
- "text": [
167
- "\n",
168
- "Processing complete! Your dataset is ready.\n",
169
- "\u001b[36m(process_midi_to_dict pid=30232)\u001b[0m Error processing /mnt/d/dataset/004.다화자 가창 데이터/01.데이터/1.Training/라벨링데이터/01.발라드R&B/A. 남성/03. 40대 이상/가창자_s13/ba_06120_-2_a_s13_m_04.mid: 'quantized duration'\u001b[32m [repeated 344x across cluster]\u001b[0m\n"
170
  ]
171
  }
172
  ],
173
  "source": [
174
- "mssv.preprocess_dataset(mssv_path, mssv_preprocessed_path)"
175
  ]
176
  },
177
  {
178
  "cell_type": "code",
179
- "execution_count": 3,
180
  "id": "2d4c561c",
181
  "metadata": {},
182
  "outputs": [
@@ -3140,19 +3029,19 @@
3140
  " 'duration': 480}]"
3141
  ]
3142
  },
3143
- "execution_count": 3,
3144
  "metadata": {},
3145
  "output_type": "execute_result"
3146
  }
3147
  ],
3148
  "source": [
3149
- "mssv_sample_list, ticks_per_beat = mssv.midi_to_note_list(mssv_sample_midi)\n",
3150
  "mssv_sample_list "
3151
  ]
3152
  },
3153
  {
3154
  "cell_type": "code",
3155
- "execution_count": 5,
3156
  "id": "0bac9307",
3157
  "metadata": {},
3158
  "outputs": [
@@ -3313,13 +3202,13 @@
3313
  }
3314
  ],
3315
  "source": [
3316
- "df = mssv.preprocess_notes(mssv_sample_list, ticks_per_beat=ticks_per_beat)\n",
3317
  "df"
3318
  ]
3319
  },
3320
  {
3321
  "cell_type": "code",
3322
- "execution_count": null,
3323
  "id": "c0b7965b",
3324
  "metadata": {},
3325
  "outputs": [
@@ -3335,7 +3224,7 @@
3335
  }
3336
  ],
3337
  "source": [
3338
- "singer_id = mssv.singer_id_from_filepath(mssv_sample_midi)\n",
3339
  "singer_id "
3340
  ]
3341
  },
@@ -3349,7 +3238,7 @@
3349
  },
3350
  {
3351
  "cell_type": "code",
3352
- "execution_count": 1,
3353
  "id": "0e0ef1e4",
3354
  "metadata": {},
3355
  "outputs": [],
@@ -3360,7 +3249,7 @@
3360
  },
3361
  {
3362
  "cell_type": "code",
3363
- "execution_count": 2,
3364
  "id": "c519f235",
3365
  "metadata": {},
3366
  "outputs": [
@@ -3373,7 +3262,7 @@
3373
  "})"
3374
  ]
3375
  },
3376
- "execution_count": 2,
3377
  "metadata": {},
3378
  "output_type": "execute_result"
3379
  }
@@ -3387,7 +3276,7 @@
3387
  },
3388
  {
3389
  "cell_type": "code",
3390
- "execution_count": 3,
3391
  "id": "6f62d4d5",
3392
  "metadata": {},
3393
  "outputs": [
@@ -3397,7 +3286,7 @@
3397
  "(4204, 4204)"
3398
  ]
3399
  },
3400
- "execution_count": 3,
3401
  "metadata": {},
3402
  "output_type": "execute_result"
3403
  }
@@ -3408,7 +3297,7 @@
3408
  },
3409
  {
3410
  "cell_type": "code",
3411
- "execution_count": 4,
3412
  "id": "ff7ec906",
3413
  "metadata": {},
3414
  "outputs": [
@@ -3418,7 +3307,7 @@
3418
  "(4204, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, 18)"
3419
  ]
3420
  },
3421
- "execution_count": 4,
3422
  "metadata": {},
3423
  "output_type": "execute_result"
3424
  }
@@ -3429,7 +3318,7 @@
3429
  },
3430
  {
3431
  "cell_type": "code",
3432
- "execution_count": 5,
3433
  "id": "a625ec7c",
3434
  "metadata": {},
3435
  "outputs": [
@@ -3449,7 +3338,7 @@
3449
  },
3450
  {
3451
  "cell_type": "code",
3452
- "execution_count": 6,
3453
  "id": "e22dc42e",
3454
  "metadata": {},
3455
  "outputs": [
@@ -3459,7 +3348,7 @@
3459
  "True"
3460
  ]
3461
  },
3462
- "execution_count": 6,
3463
  "metadata": {},
3464
  "output_type": "execute_result"
3465
  }
@@ -3470,7 +3359,7 @@
3470
  },
3471
  {
3472
  "cell_type": "code",
3473
- "execution_count": 7,
3474
  "id": "aa41450d",
3475
  "metadata": {},
3476
  "outputs": [
@@ -3480,7 +3369,7 @@
3480
  "(817, 817, 2)"
3481
  ]
3482
  },
3483
- "execution_count": 7,
3484
  "metadata": {},
3485
  "output_type": "execute_result"
3486
  }
@@ -3489,265 +3378,6 @@
3489
  "len(train_dataset[1]['durations']), len(train_dataset[1]['quantized_durations']), train_dataset[1]['singer_id']"
3490
  ]
3491
  },
3492
- {
3493
- "cell_type": "code",
3494
- "execution_count": 24,
3495
- "id": "cbcad2ca",
3496
- "metadata": {},
3497
- "outputs": [
3498
- {
3499
- "name": "stdout",
3500
- "output_type": "stream",
3501
- "text": [
3502
- "Chunking dataset into fixed-length sequences...\n"
3503
- ]
3504
- }
3505
- ],
3506
- "source": [
3507
- "# --- Step 2: Define the Preprocessing Function ---\n",
3508
- "# This function will take a batch of songs and chop their long duration lists\n",
3509
- "# into fixed-length chunks that the model can handle.\n",
3510
- "\n",
3511
- "SEQUENCE_LENGTH = 128 \n",
3512
- "\n",
3513
- "def chunk_examples(examples):\n",
3514
- " # Dictionaries to hold the new, chunked data\n",
3515
- " chunked_inputs = {\n",
3516
- " 'original_durations': [],\n",
3517
- " 'quantized_durations': [],\n",
3518
- " 'singer_id': [],\n",
3519
- " }\n",
3520
- " \n",
3521
- " # Iterate through each song in the batch\n",
3522
- " for i in range(len(examples[\"durations\"])):\n",
3523
- " durs = examples[\"durations\"][i]\n",
3524
- " q_durs = examples[\"quantized_durations\"][i]\n",
3525
- " sid = examples[\"singer_id\"][i]\n",
3526
- " \n",
3527
- " # Chop the long lists into smaller, fixed-length chunks\n",
3528
- " for j in range(0, len(durs) - SEQUENCE_LENGTH, SEQUENCE_LENGTH):\n",
3529
- " chunked_inputs['original_durations'].append(durs[j : j + SEQUENCE_LENGTH])\n",
3530
- " chunked_inputs['quantized_durations'].append(q_durs[j : j + SEQUENCE_LENGTH])\n",
3531
- " chunked_inputs['singer_id'].append(sid)\n",
3532
- " \n",
3533
- " return chunked_inputs\n",
3534
- "\n",
3535
- "\n",
3536
- "# --- Step 3: Apply the Preprocessing ---\n",
3537
- "# We use .map() to apply our function to the entire dataset.\n",
3538
- "# `batched=True` sends multiple rows at a time to our function for efficiency.\n",
3539
- "# `remove_columns` gets rid of the old, variable-length columns.\n",
3540
- "print(\"Chunking dataset into fixed-length sequences...\")\n",
3541
- "processed_dataset = train_dataset.map(\n",
3542
- " chunk_examples,\n",
3543
- " batched=True,\n",
3544
- " remove_columns=train_dataset.column_names\n",
3545
- ")"
3546
- ]
3547
- },
3548
- {
3549
- "cell_type": "code",
3550
- "execution_count": 9,
3551
- "id": "93b51c05",
3552
- "metadata": {},
3553
- "outputs": [
3554
- {
3555
- "data": {
3556
- "text/plain": [
3557
- "Dataset({\n",
3558
- " features: ['quantized_durations', 'singer_id', 'original_durations'],\n",
3559
- " num_rows: 17941\n",
3560
- "})"
3561
- ]
3562
- },
3563
- "execution_count": 9,
3564
- "metadata": {},
3565
- "output_type": "execute_result"
3566
- }
3567
- ],
3568
- "source": [
3569
- "processed_dataset"
3570
- ]
3571
- },
3572
- {
3573
- "cell_type": "code",
3574
- "execution_count": 10,
3575
- "id": "8d312061",
3576
- "metadata": {},
3577
- "outputs": [
3578
- {
3579
- "data": {
3580
- "text/plain": [
3581
- "17941"
3582
- ]
3583
- },
3584
- "execution_count": 10,
3585
- "metadata": {},
3586
- "output_type": "execute_result"
3587
- }
3588
- ],
3589
- "source": [
3590
- "len(processed_dataset['quantized_durations'])"
3591
- ]
3592
- },
3593
- {
3594
- "cell_type": "code",
3595
- "execution_count": 11,
3596
- "id": "a38b02c3",
3597
- "metadata": {},
3598
- "outputs": [
3599
- {
3600
- "data": {
3601
- "text/plain": [
3602
- "(128, 128, 2)"
3603
- ]
3604
- },
3605
- "execution_count": 11,
3606
- "metadata": {},
3607
- "output_type": "execute_result"
3608
- }
3609
- ],
3610
- "source": [
3611
- "len(processed_dataset['original_durations'][0]), len(processed_dataset['quantized_durations'][0]), processed_dataset['singer_id'][0]"
3612
- ]
3613
- },
3614
- {
3615
- "cell_type": "code",
3616
- "execution_count": 12,
3617
- "id": "07583462",
3618
- "metadata": {},
3619
- "outputs": [
3620
- {
3621
- "data": {
3622
- "application/vnd.jupyter.widget-view+json": {
3623
- "model_id": "45ae421c3aa0414386598a1213673229",
3624
- "version_major": 2,
3625
- "version_minor": 0
3626
- },
3627
- "text/plain": [
3628
- "Map: 0%| | 0/17941 [00:00<?, ? examples/s]"
3629
- ]
3630
- },
3631
- "metadata": {},
3632
- "output_type": "display_data"
3633
- },
3634
- {
3635
- "name": "stdout",
3636
- "output_type": "stream",
3637
- "text": [
3638
- "DataLoader for Control Model (original -> original) created.\n",
3639
- "DataLoader for Your Method's Model (quantized -> original) created.\n"
3640
- ]
3641
- }
3642
- ],
3643
- "source": [
3644
- "# --- Step 4: Set Format and Create DataLoaders ---\n",
3645
- "# Now we format the dataset to output PyTorch tensors and create the DataLoaders\n",
3646
- "# which will handle batching and shuffling for us during training.\n",
3647
- "\n",
3648
- "BATCH_SIZE = 32\n",
3649
- "\n",
3650
- "# --- Dataloader for Model A (Control) ---\n",
3651
- "# Goal: original -> original\n",
3652
- "# We need to duplicate the 'original_durations' column to use it for both input and labels.\n",
3653
- "model_A_dataset = processed_dataset.map(\n",
3654
- " lambda batch: {\"labels\": batch[\"original_durations\"]},\n",
3655
- " batched=True\n",
3656
- ")\n",
3657
- "# Now we have: 'original_durations', 'quantized_durations', 'singer_id', 'labels'\n",
3658
- "\n",
3659
- "# Rename 'original_durations' to 'input_ids' for the model's input\n",
3660
- "model_A_dataset = model_A_dataset.rename_column(\"original_durations\", \"input_ids\")\n",
3661
- "model_A_dataset.set_format(type='torch', columns=['input_ids', 'labels', 'singer_id'])\n",
3662
- "model_A_dataloader = DataLoader(model_A_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
3663
- "print(\"DataLoader for Control Model (original -> original) created.\")\n",
3664
- "\n",
3665
- "\n",
3666
- "# --- Dataloader for Model B (Your Method) ---\n",
3667
- "# Goal: quantized -> original\n",
3668
- "model_B_dataset = processed_dataset.map(\n",
3669
- " lambda batch: {\"labels\": batch[\"original_durations\"]},\n",
3670
- " batched=True\n",
3671
- ")\n",
3672
- "\n",
3673
- "# Rename 'quantized_durations' to 'input_ids' for the model's input\n",
3674
- "model_B_dataset = model_B_dataset.rename_column(\"quantized_durations\", \"input_ids\")\n",
3675
- "model_B_dataset.set_format(type='torch', columns=['input_ids', 'labels', 'singer_id'])\n",
3676
- "model_B_dataloader = DataLoader(model_B_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
3677
- "print(\"DataLoader for Your Method's Model (quantized -> original) created.\")"
3678
- ]
3679
- },
3680
- {
3681
- "cell_type": "code",
3682
- "execution_count": 13,
3683
- "id": "65c78eb3",
3684
- "metadata": {},
3685
- "outputs": [
3686
- {
3687
- "name": "stdout",
3688
- "output_type": "stream",
3689
- "text": [
3690
- "\n",
3691
- "--- Verifying one batch for Model B (Your Method) ---\n",
3692
- "Shape of the input tensor: torch.Size([32, 128])\n",
3693
- "Shape of the label tensor: torch.Size([32, 128])\n",
3694
- "Shape of the singer_id tensor: torch.Size([32])\n",
3695
- "\n",
3696
- "First input sequence in the batch (quantized):\n",
3697
- "tensor([ 240, 60, 360, 60, 300, 0, 780, 1560, 300, 0, 480, 0,\n",
3698
- " 540, 240, 240, 60, 420, 0, 420, 0, 420, 300, 300, 0,\n",
3699
- " 180, 60, 300, 60, 540, 0, 360, 0, 180, 60, 1080, 720,\n",
3700
- " 540, 60, 300, 0, 300, 0, 180, 0, 180, 0, 360, 60,\n",
3701
- " 540, 0, 300, 0, 420, 0, 300, 360, 540, 60, 360, 0,\n",
3702
- " 300, 0, 180, 60, 180, 60, 240, 0, 1680, 180, 480, 0,\n",
3703
- " 420, 60, 360, 0, 480, 0, 1860, 300, 360, 60, 240, 0,\n",
3704
- " 480, 60, 300, 0, 300, 60, 1260, 1020, 240, 120, 360, 60,\n",
3705
- " 240, 0, 240, 0, 240, 0, 300, 0, 300, 0, 240, 0,\n",
3706
- " 480, 120, 300, 0, 240, 60, 2160, 1440, 360, 0, 240, 60,\n",
3707
- " 360, 60, 480, 0, 300, 0, 1140, 780])\n",
3708
- "\n",
3709
- "Corresponding label sequence in the batch (original):\n",
3710
- "tensor([ 263, 40, 380, 33, 292, 14, 773, 1553, 278, 15, 483, 19,\n",
3711
- " 511, 253, 264, 40, 419, 13, 398, 15, 399, 278, 279, 13,\n",
3712
- " 198, 33, 277, 34, 539, 13, 347, 1, 205, 39, 1097, 712,\n",
3713
- " 564, 40, 293, 11, 271, 12, 178, 15, 177, 12, 375, 73,\n",
3714
- " 539, 13, 307, 24, 397, 16, 318, 353, 560, 34, 371, 24,\n",
3715
- " 289, 4, 199, 40, 205, 40, 237, 13, 1680, 172, 466, 3,\n",
3716
- " 433, 33, 358, 13, 458, 15, 1880, 312, 384, 42, 226, 0,\n",
3717
- " 457, 53, 280, 15, 325, 37, 1234, 1012, 265, 122, 385, 39,\n",
3718
- " 221, 12, 224, 15, 250, 25, 277, 16, 270, 14, 219, 14,\n",
3719
- " 473, 92, 278, 16, 256, 33, 2135, 1413, 337, 16, 267, 43,\n",
3720
- " 346, 74, 472, 7, 286, 4, 1123, 779])\n",
3721
- "\n",
3722
- "Corresponding singer_id:\n",
3723
- "tensor(4)\n"
3724
- ]
3725
- }
3726
- ],
3727
- "source": [
3728
- "# --- Step 5: Verify One Batch ---\n",
3729
- "# Let's pull one batch from the dataloader for Model B to see what it looks like.\n",
3730
- "# This is the data in its final form, \"right before being input into model\".\n",
3731
- "\n",
3732
- "print(\"\\n--- Verifying one batch for Model B (Your Method) ---\")\n",
3733
- "one_batch = next(iter(model_B_dataloader))\n",
3734
- "\n",
3735
- "# The dataloader gives us a dictionary of tensors\n",
3736
- "input_tensor = one_batch['input_ids']\n",
3737
- "label_tensor = one_batch['labels']\n",
3738
- "sid_tensor = one_batch['singer_id']\n",
3739
- "\n",
3740
- "print(f\"Shape of the input tensor: {input_tensor.shape}\")\n",
3741
- "print(f\"Shape of the label tensor: {label_tensor.shape}\")\n",
3742
- "print(f\"Shape of the singer_id tensor: {sid_tensor.shape}\")\n",
3743
- "print(\"\\nFirst input sequence in the batch (quantized):\")\n",
3744
- "print(input_tensor[0])\n",
3745
- "print(\"\\nCorresponding label sequence in the batch (original):\")\n",
3746
- "print(label_tensor[0])\n",
3747
- "print(\"\\nCorresponding singer_id:\")\n",
3748
- "print(sid_tensor[0])"
3749
- ]
3750
- },
3751
  {
3752
  "cell_type": "markdown",
3753
  "id": "3828f8dd",
@@ -3763,7 +3393,7 @@
3763
  "metadata": {},
3764
  "outputs": [],
3765
  "source": [
3766
- "import toy_duration_predictor.train as tt"
3767
  ]
3768
  },
3769
  {
@@ -3779,83 +3409,15 @@
3779
  "\n",
3780
  "==================== FINAL MODEL COMPARISON ====================\n",
3781
  "Using device for evaluation: cuda:0\n",
3782
- "--- Preparing data for Model B ---\n"
3783
- ]
3784
- },
3785
- {
3786
- "data": {
3787
- "application/vnd.jupyter.widget-view+json": {
3788
- "model_id": "3007639f434040c8aa28cef9039d10c1",
3789
- "version_major": 2,
3790
- "version_minor": 0
3791
- },
3792
- "text/plain": [
3793
- "Map: 0%| | 0/4204 [00:00<?, ? examples/s]"
3794
- ]
3795
- },
3796
- "metadata": {},
3797
- "output_type": "display_data"
3798
- },
3799
- {
3800
- "name": "stdout",
3801
- "output_type": "stream",
3802
- "text": [
3803
- "Calculated training set stats: Mean=128.66, Std=201.76\n"
3804
- ]
3805
- },
3806
- {
3807
- "data": {
3808
- "application/vnd.jupyter.widget-view+json": {
3809
- "model_id": "036eb8eb8a8f42ca9fc4223114c47566",
3810
- "version_major": 2,
3811
- "version_minor": 0
3812
- },
3813
- "text/plain": [
3814
- "Map: 0%| | 0/3363 [00:00<?, ? examples/s]"
3815
- ]
3816
- },
3817
- "metadata": {},
3818
- "output_type": "display_data"
3819
- },
3820
- {
3821
- "data": {
3822
- "application/vnd.jupyter.widget-view+json": {
3823
- "model_id": "418d0c402c41422db24d5f18dea4ecc4",
3824
- "version_major": 2,
3825
- "version_minor": 0
3826
- },
3827
- "text/plain": [
3828
- "Map: 0%| | 0/420 [00:00<?, ? examples/s]"
3829
- ]
3830
- },
3831
- "metadata": {},
3832
- "output_type": "display_data"
3833
- },
3834
- {
3835
- "data": {
3836
- "application/vnd.jupyter.widget-view+json": {
3837
- "model_id": "44c6f6c9620543f3baed8f09f66514d8",
3838
- "version_major": 2,
3839
- "version_minor": 0
3840
- },
3841
- "text/plain": [
3842
- "Map: 0%| | 0/421 [00:00<?, ? examples/s]"
3843
- ]
3844
- },
3845
- "metadata": {},
3846
- "output_type": "display_data"
3847
- },
3848
- {
3849
- "name": "stdout",
3850
- "output_type": "stream",
3851
- "text": [
3852
  "Evaluating models on the test set...\n"
3853
  ]
3854
  },
3855
  {
3856
  "data": {
3857
  "application/vnd.jupyter.widget-view+json": {
3858
- "model_id": "f989de925f684cefbd2e1bd6af6cc554",
3859
  "version_major": 2,
3860
  "version_minor": 0
3861
  },
@@ -3879,7 +3441,7 @@
3879
  "source": [
3880
  "model_a_path = 'model_stable/model_A.pth'\n",
3881
  "model_b_path = 'model_stable/model_B.pth'\n",
3882
- "tt.evaluate_and_compare(model_a_path, model_b_path, gpu_id=0)"
3883
  ]
3884
  },
3885
  {
@@ -3892,17 +3454,17 @@
3892
  },
3893
  {
3894
  "cell_type": "code",
3895
- "execution_count": 1,
3896
  "id": "7a54b60d",
3897
  "metadata": {},
3898
  "outputs": [],
3899
  "source": [
3900
- "import toy_duration_predictor.upload as up"
3901
  ]
3902
  },
3903
  {
3904
  "cell_type": "code",
3905
- "execution_count": 2,
3906
  "id": "fa70b5bf",
3907
  "metadata": {},
3908
  "outputs": [
@@ -3989,7 +3551,7 @@
3989
  }
3990
  ],
3991
  "source": [
3992
- "up.upload_models_to_hub()"
3993
  ]
3994
  },
3995
  {
@@ -4007,7 +3569,7 @@
4007
  "metadata": {},
4008
  "outputs": [],
4009
  "source": [
4010
- "from toy_duration_predictor.load import load_and_test"
4011
  ]
4012
  },
4013
  {
@@ -4036,7 +3598,7 @@
4036
  {
4037
  "data": {
4038
  "application/vnd.jupyter.widget-view+json": {
4039
- "model_id": "71f3f4ce46d343768862a8daf9aa6413",
4040
  "version_major": 2,
4041
  "version_minor": 0
4042
  },
@@ -4058,7 +3620,7 @@
4058
  }
4059
  ],
4060
  "source": [
4061
- "load_and_test()"
4062
  ]
4063
  },
4064
  {
 
10
  },
11
  {
12
  "cell_type": "code",
13
+ "execution_count": 1,
14
  "id": "d05b06c2",
15
  "metadata": {},
16
  "outputs": [],
17
  "source": [
18
+ "from toy_duration_predictor import preprocess"
19
  ]
20
  },
21
  {
22
  "cell_type": "code",
23
+ "execution_count": 2,
24
  "id": "77c0576d",
25
  "metadata": {},
26
  "outputs": [],
 
37
  },
38
  {
39
  "cell_type": "code",
40
+ "execution_count": null,
41
  "id": "ce7e5029",
42
  "metadata": {},
43
  "outputs": [
 
45
  "name": "stderr",
46
  "output_type": "stream",
47
  "text": [
48
+ "\u001b[33m(raylet)\u001b[0m [2025-06-26 14:32:29,958 E 2007289 2007315] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-26_14-31-47_408952_2006798 is over 95% full, available space: 108.749 GB; capacity: 3651.19 GB. Object creation will fail if spilling is required.\n",
49
+ "\u001b[33m(raylet)\u001b[0m [2025-06-26 14:32:39,971 E 2007289 2007315] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-26_14-31-47_408952_2006798 is over 95% full, available space: 108.749 GB; capacity: 3651.19 GB. Object creation will fail if spilling is required.\n",
50
+ "\u001b[33m(raylet)\u001b[0m [2025-06-26 14:32:49,985 E 2007289 2007315] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-26_14-31-47_408952_2006798 is over 95% full, available space: 108.749 GB; capacity: 3651.19 GB. Object creation will fail if spilling is required.\n",
51
+ "\u001b[33m(raylet)\u001b[0m [2025-06-26 14:32:59,997 E 2007289 2007315] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-26_14-31-47_408952_2006798 is over 95% full, available space: 108.748 GB; capacity: 3651.19 GB. Object creation will fail if spilling is required.\n",
52
+ "\u001b[33m(raylet)\u001b[0m [2025-06-26 14:33:10,008 E 2007289 2007315] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-26_14-31-47_408952_2006798 is over 95% full, available space: 108.748 GB; capacity: 3651.19 GB. Object creation will fail if spilling is required.\n",
53
+ "\u001b[33m(raylet)\u001b[0m [2025-06-26 14:33:20,020 E 2007289 2007315] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-26_14-31-47_408952_2006798 is over 95% full, available space: 108.748 GB; capacity: 3651.19 GB. Object creation will fail if spilling is required.\n",
54
+ "\u001b[33m(raylet)\u001b[0m [2025-06-26 14:33:30,034 E 2007289 2007315] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-26_14-31-47_408952_2006798 is over 95% full, available space: 108.748 GB; capacity: 3651.19 GB. Object creation will fail if spilling is required.\n",
55
+ "\u001b[33m(raylet)\u001b[0m [2025-06-26 14:33:40,048 E 2007289 2007315] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-26_14-31-47_408952_2006798 is over 95% full, available space: 108.748 GB; capacity: 3651.19 GB. Object creation will fail if spilling is required.\n",
56
+ "\u001b[33m(raylet)\u001b[0m [2025-06-26 14:33:50,061 E 2007289 2007315] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-26_14-31-47_408952_2006798 is over 95% full, available space: 108.748 GB; capacity: 3651.19 GB. Object creation will fail if spilling is required.\n",
57
+ "\u001b[33m(raylet)\u001b[0m [2025-06-26 14:34:00,074 E 2007289 2007315] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-26_14-31-47_408952_2006798 is over 95% full, available space: 108.748 GB; capacity: 3651.19 GB. Object creation will fail if spilling is required.\n",
58
+ "\u001b[33m(raylet)\u001b[0m [2025-06-26 14:34:10,087 E 2007289 2007315] (raylet) file_system_monitor.cc:116: /tmp/ray/session_2025-06-26_14-31-47_408952_2006798 is over 95% full, available space: 108.748 GB; capacity: 3651.19 GB. Object creation will fail if spilling is required.\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  ]
60
  }
61
  ],
62
  "source": [
63
+ "preprocess.preprocess_dataset(mssv_path, mssv_preprocessed_path)"
64
  ]
65
  },
66
  {
67
  "cell_type": "code",
68
+ "execution_count": 4,
69
  "id": "2d4c561c",
70
  "metadata": {},
71
  "outputs": [
 
3029
  " 'duration': 480}]"
3030
  ]
3031
  },
3032
+ "execution_count": 4,
3033
  "metadata": {},
3034
  "output_type": "execute_result"
3035
  }
3036
  ],
3037
  "source": [
3038
+ "mssv_sample_list, ticks_per_beat = preprocess.midi_to_note_list(mssv_sample_midi)\n",
3039
  "mssv_sample_list "
3040
  ]
3041
  },
3042
  {
3043
  "cell_type": "code",
3044
+ "execution_count": null,
3045
  "id": "0bac9307",
3046
  "metadata": {},
3047
  "outputs": [
 
3202
  }
3203
  ],
3204
  "source": [
3205
+ "df = preprocess.preprocess_notes(mssv_sample_list, ticks_per_beat=ticks_per_beat)\n",
3206
  "df"
3207
  ]
3208
  },
3209
  {
3210
  "cell_type": "code",
3211
+ "execution_count": 5,
3212
  "id": "c0b7965b",
3213
  "metadata": {},
3214
  "outputs": [
 
3224
  }
3225
  ],
3226
  "source": [
3227
+ "singer_id = preprocess.singer_id_from_filepath(mssv_sample_midi)\n",
3228
  "singer_id "
3229
  ]
3230
  },
 
3238
  },
3239
  {
3240
  "cell_type": "code",
3241
+ "execution_count": 6,
3242
  "id": "0e0ef1e4",
3243
  "metadata": {},
3244
  "outputs": [],
 
3249
  },
3250
  {
3251
  "cell_type": "code",
3252
+ "execution_count": 7,
3253
  "id": "c519f235",
3254
  "metadata": {},
3255
  "outputs": [
 
3262
  "})"
3263
  ]
3264
  },
3265
+ "execution_count": 7,
3266
  "metadata": {},
3267
  "output_type": "execute_result"
3268
  }
 
3276
  },
3277
  {
3278
  "cell_type": "code",
3279
+ "execution_count": 8,
3280
  "id": "6f62d4d5",
3281
  "metadata": {},
3282
  "outputs": [
 
3286
  "(4204, 4204)"
3287
  ]
3288
  },
3289
+ "execution_count": 8,
3290
  "metadata": {},
3291
  "output_type": "execute_result"
3292
  }
 
3297
  },
3298
  {
3299
  "cell_type": "code",
3300
+ "execution_count": 9,
3301
  "id": "ff7ec906",
3302
  "metadata": {},
3303
  "outputs": [
 
3307
  "(4204, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, 18)"
3308
  ]
3309
  },
3310
+ "execution_count": 9,
3311
  "metadata": {},
3312
  "output_type": "execute_result"
3313
  }
 
3318
  },
3319
  {
3320
  "cell_type": "code",
3321
+ "execution_count": 10,
3322
  "id": "a625ec7c",
3323
  "metadata": {},
3324
  "outputs": [
 
3338
  },
3339
  {
3340
  "cell_type": "code",
3341
+ "execution_count": 11,
3342
  "id": "e22dc42e",
3343
  "metadata": {},
3344
  "outputs": [
 
3348
  "True"
3349
  ]
3350
  },
3351
+ "execution_count": 11,
3352
  "metadata": {},
3353
  "output_type": "execute_result"
3354
  }
 
3359
  },
3360
  {
3361
  "cell_type": "code",
3362
+ "execution_count": 12,
3363
  "id": "aa41450d",
3364
  "metadata": {},
3365
  "outputs": [
 
3369
  "(817, 817, 2)"
3370
  ]
3371
  },
3372
+ "execution_count": 12,
3373
  "metadata": {},
3374
  "output_type": "execute_result"
3375
  }
 
3378
  "len(train_dataset[1]['durations']), len(train_dataset[1]['quantized_durations']), train_dataset[1]['singer_id']"
3379
  ]
3380
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3381
  {
3382
  "cell_type": "markdown",
3383
  "id": "3828f8dd",
 
3393
  "metadata": {},
3394
  "outputs": [],
3395
  "source": [
3396
+ "import toy_duration_predictor as tdp"
3397
  ]
3398
  },
3399
  {
 
3409
  "\n",
3410
  "==================== FINAL MODEL COMPARISON ====================\n",
3411
  "Using device for evaluation: cuda:0\n",
3412
+ "--- Preparing test data ---\n",
3413
+ "Calculated training set stats: Mean=128.66, Std=201.76\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3414
  "Evaluating models on the test set...\n"
3415
  ]
3416
  },
3417
  {
3418
  "data": {
3419
  "application/vnd.jupyter.widget-view+json": {
3420
+ "model_id": "de7de013364a40fa8d9af567b9140063",
3421
  "version_major": 2,
3422
  "version_minor": 0
3423
  },
 
3441
  "source": [
3442
  "model_a_path = 'model_stable/model_A.pth'\n",
3443
  "model_b_path = 'model_stable/model_B.pth'\n",
3444
+ "tdp.evaluate_and_compare(model_a_path, model_b_path, gpu_id=0)"
3445
  ]
3446
  },
3447
  {
 
3454
  },
3455
  {
3456
  "cell_type": "code",
3457
+ "execution_count": null,
3458
  "id": "7a54b60d",
3459
  "metadata": {},
3460
  "outputs": [],
3461
  "source": [
3462
+ "import toy_duration_predictor as tdp"
3463
  ]
3464
  },
3465
  {
3466
  "cell_type": "code",
3467
+ "execution_count": null,
3468
  "id": "fa70b5bf",
3469
  "metadata": {},
3470
  "outputs": [
 
3551
  }
3552
  ],
3553
  "source": [
3554
+ "tdp.upload_models_to_hub()"
3555
  ]
3556
  },
3557
  {
 
3569
  "metadata": {},
3570
  "outputs": [],
3571
  "source": [
3572
+ "import toy_duration_predictor as tdp"
3573
  ]
3574
  },
3575
  {
 
3598
  {
3599
  "data": {
3600
  "application/vnd.jupyter.widget-view+json": {
3601
+ "model_id": "65b38f6fa7724d73b91fc332eda22893",
3602
  "version_major": 2,
3603
  "version_minor": 0
3604
  },
 
3620
  }
3621
  ],
3622
  "source": [
3623
+ "tdp.load_and_test()"
3624
  ]
3625
  },
3626
  {