Spaces:
Sleeping
Sleeping
Commit ·
d3782ca
0
Parent(s):
Initial
Browse files- .gitignore +144 -0
- .python-version +1 -0
- README.md +15 -0
- pyproject.toml +34 -0
- sample/gv/SINGER_16_10TO29_CLEAR_FEMALE_BALLAD_C0632.json +0 -0
- sample/gv/SINGER_66_30TO49_HUSKY_MALE_DANCE_C2835.json +0 -0
- sample/mssv/ba_05688_-4_a_s02_m_02.mid +0 -0
- sample/mssv/ba_09303_+0_a_s02_m_02.mid +0 -0
- src/toy_duration_predictor/__init__.py +2 -0
- src/toy_duration_predictor/_legacy/train_fastai.py +169 -0
- src/toy_duration_predictor/_legacy/train_jax.py +173 -0
- src/toy_duration_predictor/_legacy/train_tensorflow.py +183 -0
- src/toy_duration_predictor/_legacy/train_torch_mlops.py +219 -0
- src/toy_duration_predictor/_legacy/train_torch_vanilla.py +172 -0
- src/toy_duration_predictor/preprocess/mssv.py +188 -0
- src/toy_duration_predictor/preprocess/utils.py +117 -0
- src/toy_duration_predictor/py.typed +0 -0
- src/toy_duration_predictor/train_fastai.py +297 -0
- src/toy_duration_predictor/train_lightning.py +350 -0
- src/toy_duration_predictor/upload.py +142 -0
- test.ipynb +0 -0
- test.py +17 -0
- test_wandb.py +31 -0
.gitignore
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
uv.lock
|
| 2 |
+
|
| 3 |
+
wandb/
|
| 4 |
+
toy-duration-predictor-lightning/
|
| 5 |
+
|
| 6 |
+
sample/gv/json_filled_time_gaps
|
| 7 |
+
sample/gv/json_time_adjusted
|
| 8 |
+
sample/gv/json_preprocessed
|
| 9 |
+
sample/gv/split_json
|
| 10 |
+
sample/mssv/preprocessed
|
| 11 |
+
preprocessed_gv/
|
| 12 |
+
preprocessed_mssv/
|
| 13 |
+
|
| 14 |
+
# Python-generated files
|
| 15 |
+
__pycache__/
|
| 16 |
+
*.py[oc]
|
| 17 |
+
build/
|
| 18 |
+
dist/
|
| 19 |
+
wheels/
|
| 20 |
+
*.egg-info
|
| 21 |
+
|
| 22 |
+
# Virtual environments
|
| 23 |
+
.venv
|
| 24 |
+
|
| 25 |
+
# Byte-compiled / optimized / DLL files
|
| 26 |
+
*.py[cod]
|
| 27 |
+
*$py.class
|
| 28 |
+
|
| 29 |
+
# C extensions
|
| 30 |
+
*.so
|
| 31 |
+
|
| 32 |
+
# Distribution / packaging
|
| 33 |
+
.Python
|
| 34 |
+
build/
|
| 35 |
+
develop-eggs/
|
| 36 |
+
dist/
|
| 37 |
+
downloads/
|
| 38 |
+
eggs/
|
| 39 |
+
.eggs/
|
| 40 |
+
lib/
|
| 41 |
+
lib64/
|
| 42 |
+
parts/
|
| 43 |
+
sdist/
|
| 44 |
+
var/
|
| 45 |
+
wheels/
|
| 46 |
+
pip-wheel-metadata/
|
| 47 |
+
share/python-wheels/
|
| 48 |
+
*.egg-info/
|
| 49 |
+
.installed.cfg
|
| 50 |
+
*.egg
|
| 51 |
+
MANIFEST
|
| 52 |
+
|
| 53 |
+
*.spec
|
| 54 |
+
|
| 55 |
+
# Installer logs
|
| 56 |
+
pip-log.txt
|
| 57 |
+
pip-delete-this-directory.txt
|
| 58 |
+
|
| 59 |
+
# Unit test / coverage reports
|
| 60 |
+
htmlcov/
|
| 61 |
+
.tox/
|
| 62 |
+
.nox/
|
| 63 |
+
.coverage
|
| 64 |
+
.coverage.*
|
| 65 |
+
.cache
|
| 66 |
+
nosetests.xml
|
| 67 |
+
coverage.xml
|
| 68 |
+
*.cover
|
| 69 |
+
.hypothesis/
|
| 70 |
+
.pytest_cache/
|
| 71 |
+
|
| 72 |
+
# Translations
|
| 73 |
+
*.mo
|
| 74 |
+
*.pot
|
| 75 |
+
|
| 76 |
+
# Django stuff:
|
| 77 |
+
*.log
|
| 78 |
+
local_settings.py
|
| 79 |
+
db.sqlite3
|
| 80 |
+
db.sqlite3-journal
|
| 81 |
+
|
| 82 |
+
# Flask stuff:
|
| 83 |
+
instance/
|
| 84 |
+
.webassets-cache
|
| 85 |
+
|
| 86 |
+
# Jupyter Notebook
|
| 87 |
+
.ipynb_checkpoints
|
| 88 |
+
|
| 89 |
+
# IPython
|
| 90 |
+
profile_default/
|
| 91 |
+
ipython_config.py
|
| 92 |
+
|
| 93 |
+
# Environments
|
| 94 |
+
.env
|
| 95 |
+
.venv
|
| 96 |
+
env/
|
| 97 |
+
venv/
|
| 98 |
+
ENV/
|
| 99 |
+
env.bak/
|
| 100 |
+
venv.bak/
|
| 101 |
+
|
| 102 |
+
# Spyder project settings
|
| 103 |
+
.spyderproject
|
| 104 |
+
.spyproject
|
| 105 |
+
|
| 106 |
+
# Rope project settings
|
| 107 |
+
.ropeproject
|
| 108 |
+
|
| 109 |
+
# PyDev project settings
|
| 110 |
+
.pydevproject
|
| 111 |
+
|
| 112 |
+
# PyCharm specific files
|
| 113 |
+
.idea/
|
| 114 |
+
*.iml
|
| 115 |
+
|
| 116 |
+
# VS Code specific files
|
| 117 |
+
.vscode/
|
| 118 |
+
|
| 119 |
+
# Sass cache files
|
| 120 |
+
.sass-cache/
|
| 121 |
+
|
| 122 |
+
# Conda environments
|
| 123 |
+
.conda/
|
| 124 |
+
|
| 125 |
+
# Mypy cache
|
| 126 |
+
.mypy_cache/
|
| 127 |
+
|
| 128 |
+
# pytype
|
| 129 |
+
.pytype/
|
| 130 |
+
|
| 131 |
+
# Cache and log files created by tools
|
| 132 |
+
.cache/
|
| 133 |
+
*.log
|
| 134 |
+
log/
|
| 135 |
+
logs/
|
| 136 |
+
|
| 137 |
+
# OS generated files
|
| 138 |
+
.DS_Store
|
| 139 |
+
.DS_Store?
|
| 140 |
+
._*
|
| 141 |
+
.Spotlight-V100
|
| 142 |
+
.Trashes
|
| 143 |
+
ehthumbs.db
|
| 144 |
+
Thumbs.db
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.12
|
README.md
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# a
|
| 2 |
+
|
| 3 |
+
aaaaaaaaaa [midii](https://github.com/ccss17/midii)
|
| 4 |
+
|
| 5 |
+
## Installation
|
| 6 |
+
|
| 7 |
+
```shell
|
| 8 |
+
pip install git+https://github.com/ccss17/toy-duration-predictor.git
|
| 9 |
+
```
|
| 10 |
+
|
| 11 |
+
## Usage
|
| 12 |
+
|
| 13 |
+
```python
|
| 14 |
+
import toy_duration_predictor as tdp
|
| 15 |
+
```
|
pyproject.toml
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "toy-duration-predictor"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Add your description here"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
authors = [
|
| 7 |
+
{ name = "ccsss", email = "chansol0505@naver.com" }
|
| 8 |
+
]
|
| 9 |
+
requires-python = ">=3.9"
|
| 10 |
+
dependencies = [
|
| 11 |
+
"datasets>=3.6.0",
|
| 12 |
+
"fastai>=2.1.10",
|
| 13 |
+
"gradio>=4.44.1",
|
| 14 |
+
"lightning>=2.5.2",
|
| 15 |
+
"midii>=0.1.19",
|
| 16 |
+
"numpy>=2.0.2",
|
| 17 |
+
"pandas>=2.3.0",
|
| 18 |
+
"ray[data,serve,train,tune]>=2.47.1",
|
| 19 |
+
"torch>=2.7.1",
|
| 20 |
+
"wandb>=0.20.1",
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
[build-system]
|
| 24 |
+
requires = ["hatchling"]
|
| 25 |
+
build-backend = "hatchling.build"
|
| 26 |
+
|
| 27 |
+
[dependency-groups]
|
| 28 |
+
dev = [
|
| 29 |
+
"black>=25.1.0",
|
| 30 |
+
"ipykernel>=6.29.5",
|
| 31 |
+
"ipywidgets>=8.1.7",
|
| 32 |
+
"rich>=14.0.0",
|
| 33 |
+
"ruff>=0.12.0",
|
| 34 |
+
]
|
sample/gv/SINGER_16_10TO29_CLEAR_FEMALE_BALLAD_C0632.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sample/gv/SINGER_66_30TO49_HUSKY_MALE_DANCE_C2835.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sample/mssv/ba_05688_-4_a_s02_m_02.mid
ADDED
|
Binary file (4.81 kB). View file
|
|
|
sample/mssv/ba_09303_+0_a_s02_m_02.mid
ADDED
|
Binary file (4.5 kB). View file
|
|
|
src/toy_duration_predictor/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# from .preprocess import mssv
|
| 2 |
+
# from .preprocess import utils
|
src/toy_duration_predictor/_legacy/train_fastai.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --- 0. 필요 라이브러리 설치 ---
|
| 2 |
+
# 이 스크립트를 실행하기 전에 먼저 터미널에서 아래 명령어를 실행해주세요.
|
| 3 |
+
# pip install torch fastai wandb gradio
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch.utils.data import TensorDataset
|
| 8 |
+
from fastai.data.core import DataLoaders
|
| 9 |
+
from fastai.learner import Learner, pérdida_Calculada
|
| 10 |
+
from fastai.callback.wandb import WandbCallback
|
| 11 |
+
from fastai.callback.schedule import lr_find
|
| 12 |
+
import numpy as np
|
| 13 |
+
import pandas as pd
|
| 14 |
+
import gradio as gr
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
# --- 1. 하이퍼파라미터 및 상수 정의 ---
|
| 18 |
+
MAX_SEQ_LENGTH = 32
|
| 19 |
+
NUM_SINGERS = 100
|
| 20 |
+
NUM_SAMPLES = 100000
|
| 21 |
+
BATCH_SIZE = 256
|
| 22 |
+
|
| 23 |
+
# 모델 구조 관련 파라미터 (fastai Learner에 전달)
|
| 24 |
+
SID_EMBEDDING_DIM = 16
|
| 25 |
+
GRU_UNITS = 128
|
| 26 |
+
NUM_GRU_LAYERS = 2
|
| 27 |
+
|
| 28 |
+
# --- 2. PyTorch 모델 아키텍처 정의 (fastai는 순수 PyTorch 모델을 그대로 사용) ---
|
| 29 |
+
class DurationPredictorGRU(nn.Module):
|
| 30 |
+
"""
|
| 31 |
+
fastai의 Learner가 래핑할 순수 PyTorch 모델.
|
| 32 |
+
"""
|
| 33 |
+
def __init__(self, num_singers, sid_embedding_dim, gru_units, num_gru_layers):
|
| 34 |
+
super().__init__()
|
| 35 |
+
self.sid_embedding = nn.Embedding(num_singers, sid_embedding_dim)
|
| 36 |
+
gru_input_dim = 1 + sid_embedding_dim
|
| 37 |
+
self.gru = nn.GRU(
|
| 38 |
+
gru_input_dim, gru_units, num_gru_layers,
|
| 39 |
+
batch_first=True, bidirectional=True
|
| 40 |
+
)
|
| 41 |
+
self.fc_out = nn.Linear(gru_units * 2, 1)
|
| 42 |
+
|
| 43 |
+
def forward(self, x):
|
| 44 |
+
# fastai는 입력을 튜플로 묶어서 전달합니다.
|
| 45 |
+
duration_input, sid_input = x
|
| 46 |
+
sid_embedded = self.sid_embedding(sid_input)
|
| 47 |
+
duration_reshaped = duration_input.unsqueeze(-1)
|
| 48 |
+
|
| 49 |
+
features = torch.cat([duration_reshaped, sid_embedded], dim=-1)
|
| 50 |
+
gru_output, _ = self.gru(features)
|
| 51 |
+
predictions = self.fc_out(gru_output)
|
| 52 |
+
return predictions
|
| 53 |
+
|
| 54 |
+
# --- 3. 데이터 준비 ---
|
| 55 |
+
print("--- 데이터셋 준비 중... ---")
|
| 56 |
+
|
| 57 |
+
# 가상의 전체 데이터셋 생성 (DataFrame으로 관리하면 편리)
|
| 58 |
+
data = {
|
| 59 |
+
'durations': [torch.rand(MAX_SEQ_LENGTH) for _ in range(NUM_SAMPLES)],
|
| 60 |
+
'sids': [torch.randint(0, NUM_SINGERS, (MAX_SEQ_LENGTH,)) for _ in range(NUM_SAMPLES)],
|
| 61 |
+
'labels': [d * torch.rand_like(d) * 2 for d in [d['durations'] for d in [{'durations': data} for data in [{'durations': torch.rand(MAX_SEQ_LENGTH)}] * NUM_SAMPLES]]]
|
| 62 |
+
}
|
| 63 |
+
df = pd.DataFrame(data)
|
| 64 |
+
|
| 65 |
+
# 훈련(80%), 검증(10%), 테스트(10%) 인덱스 생성
|
| 66 |
+
np.random.seed(42)
|
| 67 |
+
indices = np.random.permutation(len(df))
|
| 68 |
+
test_split_idx = int(len(df) * 0.1)
|
| 69 |
+
val_split_idx = int(len(df) * 0.2)
|
| 70 |
+
|
| 71 |
+
test_indices = indices[:test_split_idx]
|
| 72 |
+
val_indices = indices[test_split_idx:val_split_idx]
|
| 73 |
+
train_indices = indices[val_split_idx:]
|
| 74 |
+
|
| 75 |
+
# fastai의 DataLoaders 객체 생성
|
| 76 |
+
# 입력(x)은 튜플, 출력(y)은 단일 텐서로 구성
|
| 77 |
+
train_ds = TensorDataset(torch.stack(df.loc[train_indices, 'durations'].tolist()),
|
| 78 |
+
torch.stack(df.loc[train_indices, 'sids'].tolist()),
|
| 79 |
+
torch.stack(df.loc[train_indices, 'labels'].tolist()).unsqueeze(-1))
|
| 80 |
+
|
| 81 |
+
val_ds = TensorDataset(torch.stack(df.loc[val_indices, 'durations'].tolist()),
|
| 82 |
+
torch.stack(df.loc[val_indices, 'sids'].tolist()),
|
| 83 |
+
torch.stack(df.loc[val_indices, 'labels'].tolist()).unsqueeze(-1))
|
| 84 |
+
|
| 85 |
+
test_ds = TensorDataset(torch.stack(df.loc[test_indices, 'durations'].tolist()),
|
| 86 |
+
torch.stack(df.loc[test_indices, 'sids'].tolist()),
|
| 87 |
+
torch.stack(df.loc[test_indices, 'labels'].tolist()).unsqueeze(-1))
|
| 88 |
+
|
| 89 |
+
# fastai의 DataLoaders로 래핑
|
| 90 |
+
# 입력(x)을 튜플로 묶기 위해 x_cat=2
|
| 91 |
+
dls = DataLoaders.from_dsets(train_ds, val_ds, bs=BATCH_SIZE, device='cuda' if torch.cuda.is_available() else 'cpu')
|
| 92 |
+
test_dl = dls.test_dl(test_ds, with_labels=True)
|
| 93 |
+
|
| 94 |
+
print(f"훈련 데이터 샘플 수: {len(train_ds)}")
|
| 95 |
+
print(f"검증 데이터 샘플 수: {len(val_ds)}")
|
| 96 |
+
print(f"테스트 데이터 샘플 수: {len(test_ds)}")
|
| 97 |
+
|
| 98 |
+
# --- 4. fastai Learner 생성 및 훈련 ---
|
| 99 |
+
|
| 100 |
+
# 모델 인스턴스화
|
| 101 |
+
model = DurationPredictorGRU(NUM_SINGERS, SID_EMBEDDING_DIM, GRU_UNITS, NUM_GRU_LAYERS)
|
| 102 |
+
|
| 103 |
+
# Learner 생성 (모델, 데이터, 손실 함수, 콜백 등을 모두 묶음)
|
| 104 |
+
learn = Learner(dls, model, loss_func=nn.MSELoss(), cbs=WandbCallback(log_preds=False))
|
| 105 |
+
|
| 106 |
+
# --- 4a. 최적의 학습률 탐색 (Optuna 대신 사용) ---
|
| 107 |
+
print("\n--- 1. 최적의 학습률 탐색 시작 (fastai lr_find) ---")
|
| 108 |
+
# lr_find() 실행 후, 가장 가파른 기울기를 가진 지점의 학습률을 사용하는 것이 일반적
|
| 109 |
+
suggested_lr = learn.lr_find(suggest_funcs=(lr_find.valley, lr_find.slide))
|
| 110 |
+
print(f"fastai가 제안하는 최적 학습률: {suggested_lr.valley:.2e}")
|
| 111 |
+
|
| 112 |
+
# --- 4b. 모델 훈련 ---
|
| 113 |
+
print("\n--- 2. 제안된 학습률로 모델 훈련 시작 ---")
|
| 114 |
+
# fine_tune은 헤드는 제안된 학습률로, 몸통은 더 낮은 학습률��� 훈련하는 등
|
| 115 |
+
# 여러 best practice가 적용된 강력한 훈련 메소드
|
| 116 |
+
learn.fine_tune(10, base_lr=suggested_lr.valley)
|
| 117 |
+
|
| 118 |
+
print("모델 훈련 완료!")
|
| 119 |
+
|
| 120 |
+
# --- 5. 최종 성능 평가 (테스트셋) ---
|
| 121 |
+
print("\n--- 3. 최종 모델 평가 시작 (테스트 데이터셋 사용) ---")
|
| 122 |
+
# get_preds를 사용하여 테스트셋에 대한 예측 및 손실 계산
|
| 123 |
+
preds, targs, test_loss = learn.get_preds(dl=test_dl, with_loss=True)
|
| 124 |
+
print(f"최종 테스트 손실 (MSE): {test_loss.item():.6f}")
|
| 125 |
+
|
| 126 |
+
# --- 6. Gradio 데모 실행 ---
|
| 127 |
+
print("\n--- 4. Gradio 데모 인터페이스 실행 ---")
|
| 128 |
+
learn.model.eval() # 추론을 위해 모델을 평가 모드로 전환
|
| 129 |
+
|
| 130 |
+
def predict_duration_fastai(singer_id_str, duration_sequence_str):
|
| 131 |
+
try:
|
| 132 |
+
# 입력 파싱 및 텐서화
|
| 133 |
+
singer_id = int(singer_id_str)
|
| 134 |
+
durations = [float(d.strip()) for d in duration_sequence_str.split(',')]
|
| 135 |
+
|
| 136 |
+
if len(durations) > MAX_SEQ_LENGTH:
|
| 137 |
+
durations = durations[:MAX_SEQ_LENGTH]
|
| 138 |
+
else:
|
| 139 |
+
durations += [0] * (MAX_SEQ_LENGTH - len(durations))
|
| 140 |
+
|
| 141 |
+
duration_tensor = torch.tensor(durations, dtype=torch.float32).unsqueeze(0)
|
| 142 |
+
sid_tensor = torch.full_like(duration_tensor, singer_id, dtype=torch.long)
|
| 143 |
+
|
| 144 |
+
# fastai Learner를 사용한 예측
|
| 145 |
+
# learn.predict는 단일 아이템에 대한 예측과 디코딩을 수행
|
| 146 |
+
# 여기서는 모델 직접 호출이 더 간단
|
| 147 |
+
with torch.no_grad():
|
| 148 |
+
prediction = learn.model((duration_tensor.to(learn.dls.device), sid_tensor.to(learn.dls.device)))
|
| 149 |
+
|
| 150 |
+
output_sequence = prediction.squeeze().cpu().tolist()
|
| 151 |
+
return ", ".join([f"{x:.4f}" for x in output_sequence])
|
| 152 |
+
|
| 153 |
+
except Exception as e:
|
| 154 |
+
return f"오류 발생: {e}"
|
| 155 |
+
|
| 156 |
+
# Gradio 인터페이스 생성 및 실행
|
| 157 |
+
iface = gr.Interface(
|
| 158 |
+
fn=predict_duration_fastai,
|
| 159 |
+
inputs=[
|
| 160 |
+
gr.Textbox(label="가수 ID (Singer ID)", value="10"),
|
| 161 |
+
gr.Textbox(label="음표 길이 시퀀스 (쉼표로 구분)",
|
| 162 |
+
value="0.1, 0.2, 0.15, 0.5, 0.4, 0.12, 0.1, 0.25")
|
| 163 |
+
],
|
| 164 |
+
outputs=gr.Textbox(label="예측된 음표 길이 시퀀스"),
|
| 165 |
+
title="🎵 Duration Predictor (fastai + MLOps)",
|
| 166 |
+
description="fastai로 훈련된 모델입니다. 가수 ID와 정규 음표 길이 시퀀스를 입력하면, 해당 가수의 고유한 리듬 표현이 적용된 음표 길이를 예측합니다."
|
| 167 |
+
)
|
| 168 |
+
iface.launch()
|
| 169 |
+
|
src/toy_duration_predictor/_legacy/train_jax.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --- 0. 필요 라이브러리 설치 ---
|
| 2 |
+
# 이 스크립트를 실행하기 전에 먼저 터미널에서 아래 명령어를 실행해주세요.
|
| 3 |
+
# CPU 버전: pip install jax flax optax elegy
|
| 4 |
+
# GPU 버전: pip install jax[cuda11_pip] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
| 5 |
+
# pip install flax optax elegy
|
| 6 |
+
|
| 7 |
+
import jax
|
| 8 |
+
import jax.numpy as jnp
|
| 9 |
+
import flax.linen as nn
|
| 10 |
+
import optax
|
| 11 |
+
import elegy
|
| 12 |
+
import numpy as np
|
| 13 |
+
from sklearn.model_selection import train_test_split
|
| 14 |
+
|
| 15 |
+
# --- 1. 하이퍼파라미터 및 상수 정의 ---
|
| 16 |
+
MAX_SEQ_LENGTH = 32
|
| 17 |
+
NUM_SINGERS = 100
|
| 18 |
+
NUM_SAMPLES = 100000
|
| 19 |
+
BATCH_SIZE = 256
|
| 20 |
+
|
| 21 |
+
# 모델 구조 관련 파라미터
|
| 22 |
+
SID_EMBEDDING_DIM = 16
|
| 23 |
+
GRU_UNITS = 128
|
| 24 |
+
NUM_GRU_LAYERS = 2 # Flax의 GRUCell은 num_layers를 직접 지원하지 않으므로, 루프로 구현해야 합니다.
|
| 25 |
+
# 이 예제에서는 간결성을 위해 1개 층으로 구현합니다.
|
| 26 |
+
LEARNING_RATE = 0.001
|
| 27 |
+
NUM_EPOCHS = 20
|
| 28 |
+
|
| 29 |
+
# --- 2. Flax를 사용한 모델 아키텍처 정의 ---
|
| 30 |
+
class DurationPredictorGRU(nn.Module):
|
| 31 |
+
"""
|
| 32 |
+
Flax를 사용하여 정의한 Duration Predictor 모델.
|
| 33 |
+
JAX의 순수 함수 철학을 따릅니다.
|
| 34 |
+
"""
|
| 35 |
+
num_singers: int
|
| 36 |
+
sid_embedding_dim: int
|
| 37 |
+
gru_units: int
|
| 38 |
+
|
| 39 |
+
@nn.compact
|
| 40 |
+
def __call__(self, x):
|
| 41 |
+
# Elegy는 입력을 튜플/리스트 대신 딕셔너리로 받는 것을 선호합니다.
|
| 42 |
+
duration_input = x['duration_input']
|
| 43 |
+
sid_input = x['sid_input']
|
| 44 |
+
|
| 45 |
+
# 1. SID 임베딩
|
| 46 |
+
sid_embedded = nn.Embed(
|
| 47 |
+
num_embeddings=self.num_singers,
|
| 48 |
+
features=self.sid_embedding_dim
|
| 49 |
+
)(sid_input)
|
| 50 |
+
|
| 51 |
+
# 2. 음표 길이 차원 확장
|
| 52 |
+
duration_reshaped = jnp.expand_dims(duration_input, axis=-1)
|
| 53 |
+
|
| 54 |
+
# 3. 피처 연결
|
| 55 |
+
features = jnp.concatenate([duration_reshaped, sid_embedded], axis=-1)
|
| 56 |
+
|
| 57 |
+
# 4. 양방향 GRU
|
| 58 |
+
# Flax의 Bidirectional 래퍼는 RNNCell을 감싸서 양방향으로 만듭니다.
|
| 59 |
+
gru_cell = nn.GRUCell(features=self.gru_units)
|
| 60 |
+
gru_output = nn.Bidirectional(gru_cell)(features)
|
| 61 |
+
|
| 62 |
+
# 5. 출력층
|
| 63 |
+
# Flax의 Dense는 시퀀스 입력에 대해 자동으로 Time-Distributed처럼 작동합니다.
|
| 64 |
+
predictions = nn.Dense(features=1)(gru_output)
|
| 65 |
+
|
| 66 |
+
return predictions
|
| 67 |
+
|
| 68 |
+
# --- 3. 데이터 준비 ---
|
| 69 |
+
print("--- 데이터셋 준비 중... ---")
|
| 70 |
+
|
| 71 |
+
# 가상의 Numpy 데이터셋 생성
|
| 72 |
+
durations = np.random.rand(NUM_SAMPLES, MAX_SEQ_LENGTH).astype(np.float32)
|
| 73 |
+
sids = np.random.randint(0, NUM_SINGERS, (NUM_SAMPLES, MAX_SEQ_LENGTH)).astype(np.int32)
|
| 74 |
+
labels = (durations * np.random.rand(NUM_SAMPLES, MAX_SEQ_LENGTH) * 2).astype(np.float32)
|
| 75 |
+
|
| 76 |
+
# 훈련(80%), 검증/테스트(20%)로 먼저 분할
|
| 77 |
+
dur_train, dur_rem, sids_train, sids_rem, y_train, y_rem = train_test_split(
|
| 78 |
+
durations, sids, labels, test_size=0.2, random_state=42)
|
| 79 |
+
|
| 80 |
+
# 검증(10%), 테스트(10%)로 분할
|
| 81 |
+
dur_val, dur_test, sids_val, sids_test, y_val, y_test = train_test_split(
|
| 82 |
+
dur_rem, sids_rem, y_rem, test_size=0.5, random_state=42)
|
| 83 |
+
|
| 84 |
+
# Elegy가 사용할 수 있도록 입력 데이터를 딕셔너리 형태로 묶습니다.
|
| 85 |
+
X_train = {'duration_input': dur_train, 'sid_input': sids_train}
|
| 86 |
+
X_val = {'duration_input': dur_val, 'sid_input': sids_val}
|
| 87 |
+
X_test = {'duration_input': dur_test, 'sid_input': sids_test}
|
| 88 |
+
|
| 89 |
+
print(f"훈련 데이터 샘플 수: {len(y_train)}")
|
| 90 |
+
print(f"검증 데이터 샘플 수: {len(y_val)}")
|
| 91 |
+
print(f"테스트 데이터 샘플 수: {len(y_test)}")
|
| 92 |
+
|
| 93 |
+
# --- 4. Elegy를 사용한 모델 훈련 및 평가 ---
|
| 94 |
+
|
| 95 |
+
# Elegy 모델 생성
|
| 96 |
+
# Keras와 매우 유사하게, 모듈, 손실함수, 옵티마이저, 메트릭을 정의합니다.
|
| 97 |
+
model = elegy.Model(
|
| 98 |
+
module=DurationPredictorGRU(
|
| 99 |
+
num_singers=NUM_SINGERS,
|
| 100 |
+
sid_embedding_dim=SID_EMBEDDING_DIM,
|
| 101 |
+
gru_units=GRU_UNITS
|
| 102 |
+
),
|
| 103 |
+
loss=elegy.losses.MeanSquaredError(),
|
| 104 |
+
optimizer=optax.adam(learning_rate=LEARNING_RATE),
|
| 105 |
+
metrics=[elegy.metrics.MeanAbsoluteError()]
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
print("\n--- 모델 훈련 시작 (Elegy)... ---")
|
| 109 |
+
# Keras와 거의 동일한 .fit() API를 사용
|
| 110 |
+
# (참고: WandB, Gradio 등은 Elegy의 콜백 시스템을 통해 연동 가능합니다)
|
| 111 |
+
history = model.fit(
|
| 112 |
+
x=X_train,
|
| 113 |
+
y=y_train,
|
| 114 |
+
epochs=NUM_EPOCHS,
|
| 115 |
+
batch_size=BATCH_SIZE,
|
| 116 |
+
validation_data=(X_val, y_val),
|
| 117 |
+
callbacks=[elegy.callbacks.EarlyStopping(monitor='val_loss', patience=5)],
|
| 118 |
+
shuffle=True
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
print("모델 훈련 완료!")
|
| 122 |
+
|
| 123 |
+
# --- 5. 최종 성능 평가 (테스트셋) ---
|
| 124 |
+
print("\n--- 최종 모델 평가 시작 (테스트 데이터셋 사용)... ---")
|
| 125 |
+
|
| 126 |
+
# .evaluate() API를 사용하여 최종 성능 측정
|
| 127 |
+
test_metrics = model.evaluate(X_test, y_test)
|
| 128 |
+
print("-" * 50)
|
| 129 |
+
print(f"최종 테스트 결과: {test_metrics}")
|
| 130 |
+
print("-" * 50)
|
| 131 |
+
|
| 132 |
+
# --- 6. Gradio 데모를 위한 예측 함수 (예시) ---
|
| 133 |
+
# Elegy 모델은 내부적으로 JAX의 JIT 컴파일을 사용하여 예측 속도가 매우 빠릅니다.
|
| 134 |
+
@jax.jit
|
| 135 |
+
def predict_fn(params, x):
|
| 136 |
+
return model.module.apply({'params': params}, x)
|
| 137 |
+
|
| 138 |
+
def gradio_predict(singer_id_str, duration_sequence_str):
|
| 139 |
+
try:
|
| 140 |
+
singer_id = int(singer_id_str)
|
| 141 |
+
durations = [float(d.strip()) for d in duration_sequence_str.split(',')]
|
| 142 |
+
|
| 143 |
+
if len(durations) > MAX_SEQ_LENGTH:
|
| 144 |
+
durations = durations[:MAX_SEQ_LENGTH]
|
| 145 |
+
else:
|
| 146 |
+
durations += [0] * (MAX_SEQ_LENGTH - len(durations))
|
| 147 |
+
|
| 148 |
+
duration_np = np.array(durations, dtype=np.float32).reshape(1, -1)
|
| 149 |
+
sid_np = np.full_like(duration_np, singer_id, dtype=np.int32)
|
| 150 |
+
|
| 151 |
+
input_dict = {'duration_input': duration_np, 'sid_input': sid_np}
|
| 152 |
+
|
| 153 |
+
# JIT 컴파일된 함수로 예측 실행
|
| 154 |
+
prediction = predict_fn(model.states.params, input_dict)
|
| 155 |
+
|
| 156 |
+
output_sequence = np.asarray(prediction).flatten().tolist()
|
| 157 |
+
return ", ".join([f"{x:.4f}" for x in output_sequence])
|
| 158 |
+
except Exception as e:
|
| 159 |
+
return f"오류 발생: {e}"
|
| 160 |
+
|
| 161 |
+
# (Gradio 실행 부분은 주석 처리. 필요시 주석 해제하여 사용)
|
| 162 |
+
# print("\n--- Gradio 데모 인터페이스 실행 ---")
|
| 163 |
+
# iface = gr.Interface(
|
| 164 |
+
# fn=gradio_predict,
|
| 165 |
+
# inputs=[
|
| 166 |
+
# gr.Textbox(label="가수 ID (Singer ID)", value="10"),
|
| 167 |
+
# gr.Textbox(label="음표 길이 시퀀스 (쉼표로 구분)", value="0.1, 0.2, 0.15, 0.5")
|
| 168 |
+
# ],
|
| 169 |
+
# outputs=gr.Textbox(label="예측된 음표 길이 시퀀스"),
|
| 170 |
+
# title="🎵 Duration Predictor (JAX/Flax/Elegy)",
|
| 171 |
+
# description="JAX 생태계로 훈련된 모델입니다. 가수 ID와 정규 음표 길이 시퀀스를 입력하면, 해당 가수의 고유한 리듬 표현이 적용된 음표 길이를 예측합니다."
|
| 172 |
+
# )
|
| 173 |
+
# iface.launch()
|
src/toy_duration_predictor/_legacy/train_tensorflow.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --- 0. 필요 라이브러리 설치 ---
|
| 2 |
+
# 이 스크립트를 실행하기 전에 먼저 터미널에서 아래 명령어를 실행해주세요.
|
| 3 |
+
# pip install tensorflow numpy wandb keras-tuner gradio
|
| 4 |
+
|
| 5 |
+
import tensorflow as tf
|
| 6 |
+
from tensorflow import keras
|
| 7 |
+
from tensorflow.keras import layers, Model
|
| 8 |
+
import numpy as np
|
| 9 |
+
import keras_tuner as kt
|
| 10 |
+
from wandb.keras import WandbCallback
|
| 11 |
+
import gradio as gr
|
| 12 |
+
import os
|
| 13 |
+
|
| 14 |
+
# --- 1. 하이퍼파라미터 및 상수 정의 ---
|
| 15 |
+
MAX_SEQ_LENGTH = 32
|
| 16 |
+
NUM_SINGERS = 100
|
| 17 |
+
NUM_SAMPLES = 100000
|
| 18 |
+
BATCH_SIZE = 256
|
| 19 |
+
BUFFER_SIZE = 10000 # tf.data.Dataset 셔플을 위한 버퍼 크기
|
| 20 |
+
|
| 21 |
+
# --- 2. 데이터 준비 (tf.data.Dataset 사용) ---
|
| 22 |
+
print("--- 데이터셋 준비 중... ---")
|
| 23 |
+
|
| 24 |
+
def generate_dummy_data():
|
| 25 |
+
"""가상의 데이터셋을 생성하는 제너레이터 함수"""
|
| 26 |
+
for _ in range(NUM_SAMPLES):
|
| 27 |
+
duration = np.random.rand(MAX_SEQ_LENGTH).astype(np.float32)
|
| 28 |
+
sid = np.random.randint(0, NUM_SINGERS, (MAX_SEQ_LENGTH,)).astype(np.int32)
|
| 29 |
+
label = (duration * np.random.rand(MAX_SEQ_LENGTH) * 2).astype(np.float32)
|
| 30 |
+
# Keras 모델은 입력과 출력을 딕셔너리 형태로 받는 것이 편리합니다.
|
| 31 |
+
yield {'duration_input': duration, 'sid_input': sid}, label
|
| 32 |
+
|
| 33 |
+
# tf.data.Dataset 객체 생성
|
| 34 |
+
full_dataset = tf.data.Dataset.from_generator(
|
| 35 |
+
generate_dummy_data,
|
| 36 |
+
output_signature=(
|
| 37 |
+
{'duration_input': tf.TensorSpec(shape=(MAX_SEQ_LENGTH,), dtype=tf.float32),
|
| 38 |
+
'sid_input': tf.TensorSpec(shape=(MAX_SEQ_LENGTH,), dtype=tf.int32)},
|
| 39 |
+
tf.TensorSpec(shape=(MAX_SEQ_LENGTH,), dtype=tf.float32)
|
| 40 |
+
)
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
# 훈련(80%), 검증(10%), 테스트(10%) 데이터셋으로 분할
|
| 44 |
+
full_dataset = full_dataset.shuffle(BUFFER_SIZE, seed=42) # 분할 전 전체 셔플
|
| 45 |
+
train_size = int(0.8 * NUM_SAMPLES)
|
| 46 |
+
val_size = int(0.1 * NUM_SAMPLES)
|
| 47 |
+
|
| 48 |
+
train_dataset = full_dataset.take(train_size)
|
| 49 |
+
val_and_test_dataset = full_dataset.skip(train_size)
|
| 50 |
+
val_dataset = val_and_test_dataset.take(val_size)
|
| 51 |
+
test_dataset = val_and_test_dataset.skip(val_size)
|
| 52 |
+
|
| 53 |
+
# 데이터로더 생성 (배치, 프리페치 등 최적화)
|
| 54 |
+
train_loader = train_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
|
| 55 |
+
val_loader = val_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
|
| 56 |
+
test_loader = test_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
|
| 57 |
+
|
| 58 |
+
print(f"훈련 데이터 샘플 수: {train_size}")
|
| 59 |
+
print(f"검증 데이터 샘플 수: {val_size}")
|
| 60 |
+
print(f"테스트 데이터 샘플 수: {NUM_SAMPLES - train_size - val_size}")
|
| 61 |
+
|
| 62 |
+
# --- 3. KerasTuner를 사용한 하이퍼파라미터 최적화 ---
|
| 63 |
+
|
| 64 |
+
def build_model(hp: kt.HyperParameters):
|
| 65 |
+
"""KerasTuner가 하이퍼파라미터를 탐색하기 위한 모델 빌드 함수"""
|
| 66 |
+
|
| 67 |
+
# 입력층 정의
|
| 68 |
+
duration_input = layers.Input(shape=(MAX_SEQ_LENGTH,), name='duration_input')
|
| 69 |
+
sid_input = layers.Input(shape=(MAX_SEQ_LENGTH,), name='sid_input')
|
| 70 |
+
|
| 71 |
+
# 하이퍼파라미터 탐색 공간 정의
|
| 72 |
+
sid_embedding_dim = hp.Choice('sid_embedding_dim', values=[8, 16, 32])
|
| 73 |
+
gru_units = hp.Choice('gru_units', values=[64, 128])
|
| 74 |
+
learning_rate = hp.Float('learning_rate', min_value=1e-4, max_value=1e-2, sampling='log')
|
| 75 |
+
|
| 76 |
+
# 모델 레이어
|
| 77 |
+
sid_embedding = layers.Embedding(input_dim=NUM_SINGERS, output_dim=sid_embedding_dim)(sid_input)
|
| 78 |
+
duration_reshaped = layers.Reshape((MAX_SEQ_LENGTH, 1))(duration_input)
|
| 79 |
+
|
| 80 |
+
x = layers.Concatenate()([duration_reshaped, sid_embedding])
|
| 81 |
+
|
| 82 |
+
# Keras의 GRU는 기본적으로 dropout 인자를 가짐
|
| 83 |
+
x = layers.Bidirectional(layers.GRU(gru_units, return_sequences=True))(x)
|
| 84 |
+
x = layers.Bidirectional(layers.GRU(gru_units, return_sequences=True))(x)
|
| 85 |
+
|
| 86 |
+
outputs = layers.TimeDistributed(layers.Dense(1, activation='linear'))(x)
|
| 87 |
+
|
| 88 |
+
model = Model(inputs=[duration_input, sid_input], outputs=outputs)
|
| 89 |
+
|
| 90 |
+
model.compile(optimizer=keras.optimizers.Adam(learning_rate), loss='mean_squared_error')
|
| 91 |
+
|
| 92 |
+
return model
|
| 93 |
+
|
| 94 |
+
print("\n--- 1. 하이퍼파라미터 최적화 시작 (KerasTuner) ---")
|
| 95 |
+
tuner = kt.Hyperband(
|
| 96 |
+
build_model,
|
| 97 |
+
objective='val_loss',
|
| 98 |
+
max_epochs=10,
|
| 99 |
+
factor=3,
|
| 100 |
+
directory='keras_tuner_dir',
|
| 101 |
+
project_name='duration_predictor'
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
# KerasTuner 실행
|
| 105 |
+
tuner.search(train_loader, epochs=10, validation_data=val_loader, callbacks=[keras.callbacks.EarlyStopping(patience=3)])
|
| 106 |
+
|
| 107 |
+
# 최적 하이퍼파라미터 추출
|
| 108 |
+
best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]
|
| 109 |
+
print("최적화 완료!")
|
| 110 |
+
print(f"최적의 학습률: {best_hps.get('learning_rate')}")
|
| 111 |
+
print(f"최적의 임베딩 차원: {best_hps.get('sid_embedding_dim')}")
|
| 112 |
+
print(f"최적의 GRU 유닛 수: {best_hps.get('gru_units')}")
|
| 113 |
+
|
| 114 |
+
# --- 4. 최종 모델 훈련 및 평가 ---
|
| 115 |
+
print("\n--- 2. 최적의 하이퍼파라미터로 최종 모델 학습 및 평가 시작 ---")
|
| 116 |
+
|
| 117 |
+
# WandB 초기화
|
| 118 |
+
import wandb
|
| 119 |
+
wandb.init(project="duration_predictor_tf_keras", config=best_hps.values)
|
| 120 |
+
|
| 121 |
+
# 최적의 하이퍼파라미터로 최종 모델 빌드
|
| 122 |
+
final_model = tuner.hypermodel.build(best_hps)
|
| 123 |
+
|
| 124 |
+
# 체크포인트 및 조기 종료 콜백 설정
|
| 125 |
+
checkpoint_cb = keras.callbacks.ModelCheckpoint("best_model.keras", save_best_only=True)
|
| 126 |
+
early_stopping_cb = keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True)
|
| 127 |
+
|
| 128 |
+
# 모델 훈련
|
| 129 |
+
final_model.fit(
|
| 130 |
+
train_loader,
|
| 131 |
+
epochs=50,
|
| 132 |
+
validation_data=val_loader,
|
| 133 |
+
callbacks=[WandbCallback(), checkpoint_cb, early_stopping_cb]
|
| 134 |
+
)
|
| 135 |
+
print("최종 모델 학습 완료!")
|
| 136 |
+
|
| 137 |
+
# --- 5. 최종 성능 평가 (테스트셋) ---
|
| 138 |
+
print("\n--- 3. 최종 모델 평가 시작 (테스트 데이터셋 사용) ---")
|
| 139 |
+
best_model = keras.models.load_model("best_model.keras")
|
| 140 |
+
test_loss = best_model.evaluate(test_loader)
|
| 141 |
+
print("-" * 50)
|
| 142 |
+
print(f"최종 테스트 손실 (MSE): {test_loss:.6f}")
|
| 143 |
+
wandb.log({"test_loss": test_loss})
|
| 144 |
+
wandb.finish()
|
| 145 |
+
|
| 146 |
+
# --- 6. Gradio 데모 실행 ---
|
| 147 |
+
print("\n--- 4. Gradio 데모 인터페이스 실행 ---")
|
| 148 |
+
|
| 149 |
+
def predict_duration_keras(singer_id_str, duration_sequence_str):
    """Gradio handler: parse text inputs, pad/truncate, run the Keras model.

    Returns the predicted duration sequence as a comma-separated string,
    or an error message string if parsing or prediction fails.
    """
    try:
        sid = int(singer_id_str)
        parsed = [float(token.strip()) for token in duration_sequence_str.split(',')]

        # Truncate or zero-pad to the model's fixed input length.
        padded = parsed[:MAX_SEQ_LENGTH]
        padded += [0] * (MAX_SEQ_LENGTH - len(padded))

        # Shape (1, MAX_SEQ_LENGTH) numpy inputs matching the model's named inputs.
        dur_batch = np.array(padded, dtype=np.float32).reshape(1, -1)
        sid_batch = np.full_like(dur_batch, sid, dtype=np.int32)

        # Run inference through the loaded best model.
        prediction = best_model.predict({'duration_input': dur_batch, 'sid_input': sid_batch})

        return ", ".join(f"{value:.4f}" for value in prediction.flatten().tolist())
    except Exception as e:
        return f"오류 발생: {e}"
|
| 170 |
+
|
| 171 |
+
# Gradio wiring: two text inputs (singer id, duration CSV) -> one text output.
iface = gr.Interface(
    fn=predict_duration_keras,
    inputs=[
        gr.Textbox(label="가수 ID (Singer ID)", value="10"),
        gr.Textbox(label="음표 길이 시퀀스 (쉼표로 구분)",
                   value="0.1, 0.2, 0.15, 0.5, 0.4, 0.12, 0.1, 0.25")
    ],
    outputs=gr.Textbox(label="예측된 음표 길이 시퀀스"),
    title="🎵 Duration Predictor (Keras + MLOps)",
    description="Keras로 훈련된 모델입니다. 가수 ID와 정규 음표 길이 시퀀스를 입력하면, 해당 가수의 고유한 리듬 표현이 적용된 음표 길이를 예측합니다."
)

# Launch the demo server (blocks until shut down).
iface.launch()
|
src/toy_duration_predictor/_legacy/train_torch_mlops.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --- 0. 필요 라이브러리 설치 ---
|
| 2 |
+
# 이 스크립트를 실행하기 전에 먼저 터미널에서 아래 명령어를 실행해주세요.
|
| 3 |
+
# pip install torch numpy pytorch-lightning wandb optuna gradio
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch.utils.data import DataLoader, TensorDataset, random_split
|
| 8 |
+
import pytorch_lightning as pl
|
| 9 |
+
from pytorch_lightning.loggers import WandbLogger
|
| 10 |
+
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
| 11 |
+
import numpy as np
|
| 12 |
+
import optuna
|
| 13 |
+
import gradio as gr
|
| 14 |
+
import os
|
| 15 |
+
|
| 16 |
+
# --- 1. Synthetic data generation and PyTorch Lightning DataModule ---
# In a real pipeline this section would preprocess and load MIDI data instead.

# Dataset constants
MAX_SEQ_LENGTH = 32
NUM_SINGERS = 100
NUM_SAMPLES = 100000  # total number of synthetic samples (100k)
BATCH_SIZE = 256  # mini-batch size
|
| 24 |
+
|
| 25 |
+
class DurationDataModule(pl.LightningDataModule):
    """Lightning DataModule serving a synthetic duration dataset.

    Generates NUM_SAMPLES random (durations, singer_ids, labels) triples once
    in ``setup`` and splits them 80/10/10 into train/val/test with a fixed
    generator seed so every process sees the identical split.
    """

    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size
        # Half the CPUs for loader workers; `or 2` guards against
        # os.cpu_count() returning None (which would crash `None // 2`).
        self.num_workers = (os.cpu_count() or 2) // 2
        self.full_dataset = None
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def prepare_data(self):
        # Runs in a single process only; download/generation logic would go here.
        pass

    def setup(self, stage=None):
        # Runs on every GPU/TPU process: build the dataset and assign splits.
        # Use `is None` rather than truthiness: an empty TensorDataset is
        # falsy via __len__, which would silently regenerate the data.
        if self.full_dataset is None:
            # Synthetic full dataset.
            durations = torch.rand(NUM_SAMPLES, MAX_SEQ_LENGTH)
            sids = torch.randint(0, NUM_SINGERS, (NUM_SAMPLES, MAX_SEQ_LENGTH))
            labels = durations * torch.rand_like(durations) * 2
            self.full_dataset = TensorDataset(durations, sids, labels)

            # 80% train / 10% val / remainder test.
            train_size = int(0.8 * len(self.full_dataset))
            val_size = int(0.1 * len(self.full_dataset))
            test_size = len(self.full_dataset) - train_size - val_size

            # Seeded generator -> identical split across calls and processes.
            self.train_dataset, self.val_dataset, self.test_dataset = random_split(
                self.full_dataset, [train_size, val_size, test_size],
                generator=torch.Generator().manual_seed(42)
            )

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        # Test loader for the final evaluation pass.
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers)
|
| 68 |
+
|
| 69 |
+
# --- 2. PyTorch Lightning 모델 (업데이트) ---
|
| 70 |
+
# 테스트 스텝 추가
|
| 71 |
+
|
| 72 |
+
class DurationPredictor(pl.LightningModule):
    """Bidirectional-GRU duration predictor as a LightningModule.

    Concatenates each raw duration value with a learned singer-ID embedding,
    runs the sequence through a (possibly multi-layer) bidirectional GRU, and
    regresses one expressive duration per timestep with MSE loss.
    """

    def __init__(self, hparams):
        # hparams: dict-like with learning_rate, sid_embedding_dim, gru_units,
        # num_gru_layers, dropout_rate, num_singers.
        super().__init__()
        self.save_hyperparameters(hparams)
        self.sid_embedding = nn.Embedding(self.hparams.num_singers, self.hparams.sid_embedding_dim)
        # One scalar duration per step plus the singer embedding.
        gru_input_dim = 1 + self.hparams.sid_embedding_dim
        # Inter-layer dropout is only meaningful when num_layers > 1.
        self.gru = nn.GRU(gru_input_dim, self.hparams.gru_units, self.hparams.num_gru_layers,
                          batch_first=True, bidirectional=True,
                          dropout=self.hparams.dropout_rate if self.hparams.num_gru_layers > 1 else 0)
        # Bidirectional GRU concatenates forward/backward states -> 2x units.
        self.fc_out = nn.Linear(self.hparams.gru_units * 2, 1)
        self.loss_fn = nn.MSELoss()

    def forward(self, duration_input, sid_input):
        # duration_input: (batch, seq) floats; sid_input: (batch, seq) int ids.
        # Returns (batch, seq, 1) per-timestep predictions.
        sid_embedded = self.sid_embedding(sid_input)
        duration_reshaped = duration_input.unsqueeze(-1)
        x = torch.cat([duration_reshaped, sid_embedded], dim=-1)
        gru_output, _ = self.gru(x)
        predictions = self.fc_out(gru_output)
        return predictions

    def _common_step(self, batch, batch_idx):
        # Shared forward + MSE computation for train/val/test steps.
        durations, sids, labels = batch
        predictions = self.forward(durations, sids)
        loss = self.loss_fn(predictions.squeeze(-1), labels)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._common_step(batch, batch_idx)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._common_step(batch, batch_idx)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss = self._common_step(batch, batch_idx)
        self.log('test_loss', loss, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        # Plain Adam with the tuned learning rate; no LR scheduler.
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
|
| 115 |
+
|
| 116 |
+
# --- 3. Optuna를 사용한 하이퍼파라미터 최적화 ---
|
| 117 |
+
def objective(trial: optuna.Trial):
    """Optuna objective: train a short run and return the final val_loss.

    Each trial gets its own W&B run (grouped under "optuna-study") and a
    5-epoch Trainer with early stopping; checkpointing is disabled to keep
    trials cheap.
    """
    hparams = {
        'learning_rate': trial.suggest_float('learning_rate', 1e-4, 1e-2, log=True),
        'sid_embedding_dim': trial.suggest_categorical('sid_embedding_dim', [8, 16, 32]),
        'gru_units': trial.suggest_categorical('gru_units', [64, 128]),
        'num_gru_layers': trial.suggest_int('num_gru_layers', 1, 2),
        'dropout_rate': trial.suggest_float('dropout_rate', 0.1, 0.4),
        'num_singers': NUM_SINGERS
    }

    wandb_logger = WandbLogger(project="duration_predictor_optuna", name=f"trial-{trial.number}", group="optuna-study")
    wandb_logger.log_hyperparams(hparams)

    model = DurationPredictor(hparams)
    datamodule = DurationDataModule(batch_size=BATCH_SIZE)

    trainer = pl.Trainer(
        logger=wandb_logger, max_epochs=5, accelerator="auto", devices=1,
        enable_checkpointing=False, callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=2)]
    )
    trainer.fit(model, datamodule)
    # NOTE(review): the per-trial W&B run is never explicitly finished here;
    # confirm wandb.finish() is handled elsewhere or runs may bleed together.
    return trainer.callback_metrics["val_loss"].item()
|
| 139 |
+
|
| 140 |
+
# --- 4. 메인 실행 블록 (업데이트) ---
|
| 141 |
+
if __name__ == '__main__':
    # --- 1. Hyperparameter optimization ---
    print("--- 1. 하이퍼파라미터 최적화 시작 (Optuna) ---")
    study = optuna.create_study(direction='minimize')
    study.optimize(objective, n_trials=10)  # 50-100+ trials recommended in practice

    print("최적화 완료!")
    best_hparams = study.best_params
    print(f"최고의 val_loss: {study.best_value}\n최적 하이퍼파라미터: {best_hparams}")

    # --- 2. Train and save the best model ---
    print("\n--- 2. 최적의 하이퍼파라미터로 최종 모델 학습 및 평가 시작 ---")
    # NOTE(review): this aliases (does not copy) study.best_params before
    # adding num_singers — presumably intentional; verify nothing else reads it.
    final_hparams = best_hparams
    final_hparams['num_singers'] = NUM_SINGERS

    datamodule = DurationDataModule(batch_size=BATCH_SIZE)
    model = DurationPredictor(final_hparams)

    wandb_logger = WandbLogger(project="duration_predictor_final", name="final_best_model")
    wandb_logger.log_hyperparams(final_hparams)

    # Keep only the single best checkpoint by validation loss.
    checkpoint_callback = ModelCheckpoint(
        dirpath='checkpoints', filename='best-model', save_top_k=1, monitor='val_loss', mode='min'
    )

    trainer = pl.Trainer(
        logger=wandb_logger, max_epochs=20, accelerator="auto", devices=1,
        callbacks=[checkpoint_callback, EarlyStopping(monitor="val_loss", mode="min", patience=4)]
    )
    trainer.fit(model, datamodule)
    print("최종 모델 학습 완료!")

    # --- 3. Final evaluation on the test split ---
    print(f"저장된 최고 성능 모델 경로: {checkpoint_callback.best_model_path}")
    # trainer.test(ckpt_path='best') automatically reloads the best checkpoint.
    test_results = trainer.test(ckpt_path='best', datamodule=datamodule)
    print("최종 테스트 결과:", test_results)

    # --- 4. Launch the Gradio demo ---
    print("\n--- 4. Gradio 데모 인터페이스 실행 ---")

    # Reload the checkpointed best model for inference.
    best_model = DurationPredictor.load_from_checkpoint(checkpoint_callback.best_model_path)
    best_model.eval()

    def predict_duration(singer_id_str, duration_sequence_str):
        # Gradio handler: parse inputs, pad/truncate, run the model, format output.
        try:
            singer_id = int(singer_id_str)
            durations = [float(d.strip()) for d in duration_sequence_str.split(',')]

            # Truncate or zero-pad to the fixed sequence length.
            if len(durations) > MAX_SEQ_LENGTH:
                durations = durations[:MAX_SEQ_LENGTH]
            else:
                durations += [0] * (MAX_SEQ_LENGTH - len(durations))

            duration_tensor = torch.tensor(durations, dtype=torch.float32).unsqueeze(0)
            sid_tensor = torch.full_like(duration_tensor, singer_id, dtype=torch.long)

            with torch.no_grad():
                prediction = best_model(duration_tensor, sid_tensor)

            output_sequence = prediction.squeeze().tolist()
            return ", ".join([f"{x:.4f}" for x in output_sequence])

        except Exception as e:
            return f"오류 발생: {e}"

    iface = gr.Interface(
        fn=predict_duration,
        inputs=[
            gr.Textbox(label="가수 ID (Singer ID)", value="10"),
            gr.Textbox(label="음표 길이 시퀀스 (쉼표로 구분)",
                       value="0.1, 0.2, 0.15, 0.5, 0.4, 0.12, 0.1, 0.25")
        ],
        outputs=gr.Textbox(label="예측된 음표 길이 시퀀스"),
        title="🎵 Duration Predictor (리듬 표현 예측기)",
        description="가수 ID와 정규 음표 길이 시퀀스를 입력하면, 해당 가수의 고유한 리듬 표현이 적용된 음표 길이를 예측합니다."
    )

    iface.launch()
|
src/toy_duration_predictor/_legacy/train_torch_vanilla.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.utils.data import DataLoader, TensorDataset, random_split
|
| 4 |
+
import numpy as np
|
| 5 |
+
import os
|
| 6 |
+
import copy
|
| 7 |
+
|
| 8 |
+
# --- 1. Model architecture and hyperparameter definitions ---
# These values can be adjusted for the real dataset and experiment goals.

# Data parameters
MAX_SEQ_LENGTH = 32
NUM_SINGERS = 100
NUM_SAMPLES = 100000  # 100k samples
BATCH_SIZE = 256

# Model-structure parameters
SID_EMBEDDING_DIM = 16
GRU_UNITS = 128
NUM_GRU_LAYERS = 2
DROPOUT_RATE = 0.3

# Training parameters
LEARNING_RATE = 0.001
NUM_EPOCHS = 50  # maximum number of training epochs
EARLY_STOPPING_PATIENCE = 5  # stop if val loss does not improve for 5 epochs
|
| 27 |
+
|
| 28 |
+
class DurationPredictorGRU(nn.Module):
    """Bidirectional GRU that maps (singer ID, note-duration sequence) to
    an expressive note-duration sequence, one prediction per timestep."""

    def __init__(self):
        super().__init__()
        # Learned per-singer embedding, concatenated to each duration value.
        self.sid_embedding = nn.Embedding(NUM_SINGERS, SID_EMBEDDING_DIM)
        self.gru = nn.GRU(
            input_size=1 + SID_EMBEDDING_DIM,
            hidden_size=GRU_UNITS,
            num_layers=NUM_GRU_LAYERS,
            batch_first=True,
            bidirectional=True,
            # Inter-layer dropout only applies with more than one GRU layer.
            dropout=DROPOUT_RATE if NUM_GRU_LAYERS > 1 else 0,
        )
        # Forward + backward hidden states are concatenated, hence 2x units.
        self.fc_out = nn.Linear(GRU_UNITS * 2, 1)

    def forward(self, duration_input, sid_input):
        """Return (batch, seq, 1) predictions for (batch, seq) inputs."""
        embedded_sid = self.sid_embedding(sid_input)
        features = torch.cat([duration_input.unsqueeze(-1), embedded_sid], dim=-1)
        sequence_out, _ = self.gru(features)
        return self.fc_out(sequence_out)
|
| 54 |
+
|
| 55 |
+
# --- 2. Data preparation ---
print("--- 데이터셋 준비 중... ---")
# Synthetic full dataset.
durations = torch.rand(NUM_SAMPLES, MAX_SEQ_LENGTH)
sids = torch.randint(0, NUM_SINGERS, (NUM_SAMPLES, MAX_SEQ_LENGTH))
labels = (durations * torch.rand_like(durations) * 2).unsqueeze(-1)
full_dataset = TensorDataset(durations, sids, labels)

# Split into train (80%) / validation (10%) / test (10%).
train_size = int(0.8 * NUM_SAMPLES)
val_size = int(0.1 * NUM_SAMPLES)
test_size = NUM_SAMPLES - train_size - val_size
# random_split with a seeded generator -> reproducible split every run.
train_dataset, val_dataset, test_dataset = random_split(
    full_dataset, [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)
)

# One DataLoader per split.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

print(f"훈련 데이터 샘플 수: {len(train_dataset)}")
print(f"검증 데이터 샘플 수: {len(val_dataset)}")
print(f"테스트 데이터 샘플 수: {len(test_dataset)}")

# --- 3. Training and validation loop ---
print("\n--- 모델 훈련 시작... ---")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DurationPredictorGRU().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.MSELoss()

# Bookkeeping for best-model checkpointing and early stopping.
best_val_loss = float('inf')
patience_counter = 0
best_model_state = None

for epoch in range(NUM_EPOCHS):
    # --- Training phase ---
    model.train()  # enable dropout etc.
    total_train_loss = 0
    for batch_idx, (duration, sid, label) in enumerate(train_loader):
        duration, sid, label = duration.to(device), sid.to(device), label.to(device)

        # Forward pass
        predictions = model(duration, sid)
        loss = loss_fn(predictions, label)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()

    avg_train_loss = total_train_loss / len(train_loader)

    # --- Validation phase ---
    model.eval()  # disable dropout
    total_val_loss = 0
    with torch.no_grad():  # no gradient tracking during evaluation
        for duration, sid, label in val_loader:
            duration, sid, label = duration.to(device), sid.to(device), label.to(device)
            predictions = model(duration, sid)
            loss = loss_fn(predictions, label)
            total_val_loss += loss.item()

    avg_val_loss = total_val_loss / len(val_loader)

    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}")

    # --- Checkpointing and early-stopping logic ---
    # Did validation loss improve?
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        # Deep-copy the best weights so later training cannot mutate them.
        best_model_state = copy.deepcopy(model.state_dict())
        patience_counter = 0  # reset patience
        print(f" -> 검증 성능 개선! 최적 모델 저장됨. (Val Loss: {best_val_loss:.6f})")
    else:
        patience_counter += 1
        print(f" -> 검증 성능 개선 없음. (Patience: {patience_counter}/{EARLY_STOPPING_PATIENCE})")

    # Early-stopping check.
    if patience_counter >= EARLY_STOPPING_PATIENCE:
        print(f"\n조기 종료: {EARLY_STOPPING_PATIENCE} 에포크 동안 검증 성능 개선이 없어 훈련을 중단합니다.")
        break

# --- 4. Final model evaluation (test phase) ---
print("\n--- 최종 모델 평가 시작 (테스트 데이터셋 사용)... ---")

# Restore the best weights found during training, if any.
if best_model_state:
    model.load_state_dict(best_model_state)
else:
    print("경고: 저장된 최적 모델이 없습니다. 마지막 에포크 모델로 평가합니다.")

model.eval()
total_test_loss = 0
with torch.no_grad():
    for duration, sid, label in test_loader:
        duration, sid, label = duration.to(device), sid.to(device), label.to(device)
        predictions = model(duration, sid)
        loss = loss_fn(predictions, label)
        total_test_loss += loss.item()

avg_test_loss = total_test_loss / len(test_loader)

print("-" * 50)
print(f"최종 테스트 손실 (MSE): {avg_test_loss:.6f}")
print("이것이 논문에 보고할 최종 모델의 일반화 성능입니다.")
print("-" * 50)

# (Optional) persist the best weights to disk.
if best_model_state:
    torch.save(best_model_state, 'best_duration_predictor.pth')
    print("최적 모델 가중치가 'best_duration_predictor.pth' 파일로 저장되었습니다.")
|
src/toy_duration_predictor/preprocess/mssv.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import re
|
| 3 |
+
from typing import List, Dict, Any
|
| 4 |
+
|
| 5 |
+
import ray
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import midii
|
| 8 |
+
|
| 9 |
+
from .utils import (
|
| 10 |
+
_preprocess_sort_by_start_time,
|
| 11 |
+
_preprocess_remove_front_back_silence,
|
| 12 |
+
_preprocess_silence_pitch_zero,
|
| 13 |
+
_preprocess_merge_silence,
|
| 14 |
+
_preprocess_remove_short_silence,
|
| 15 |
+
_preprocess_add_quantized_duration_col,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def singer_id_from_filepath(filepath):
    """Extract the singer id from the first 'sNN' token in a file path.

    Accepts str or Path-like input (coerced via str()), e.g.
    "ba_05688_-4_a_s02_m_02.mid" -> 2.

    Raises:
        ValueError: if no 'sNN' token is present (instead of the opaque
        IndexError the previous findall-based version raised).
    """
    match = re.search(r"s(\d{2})", str(filepath))
    if match is None:
        raise ValueError(f"no singer id token 'sNN' found in: {filepath!r}")
    return int(match.group(1))
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def midi_to_note_list(midi_filepath, quantize=False):
    """Parse a (merged type-0) MIDI file into a flat list of note/silence dicts.

    Each emitted dict has start_time, end_time, pitch, duration (all in ticks)
    and lyric; gaps between notes become pitch-0 records with lyric " ".
    Returns (note_list, ticks_per_beat).
    """
    # Try UTF-8 lyrics first; fall back to cp949 if decoding fails.
    # NOTE(review): bare except also swallows I/O errors, not just decode
    # errors — the cp949 retry would then fail the same way.
    try:
        mid = midii.MidiFile(
            midi_filepath, convert_1_to_0=True, lyric_encoding="utf-8"
        )
        mid.lyrics
    except:  # noqa: E722
        mid = midii.MidiFile(
            midi_filepath, convert_1_to_0=True, lyric_encoding="cp949"
        )
    if quantize:
        # Snap event times to a 32nd-note grid.
        mid.quantize(unit="32")
    data = []
    total_duration = 0      # absolute time (ticks) of the last processed message
    residual_duration = 0   # accumulated gap time not yet emitted as silence
    active_note = {}        # note currently sounding (between note_on and note_off)
    silence_note = {}       # pending silence record being built from gaps
    for msg in mid.tracks[0]:
        msg_end_time = total_duration + msg.time
        if msg.type == "note_on" and msg.velocity > 0:
            # A new note starts: flush any accumulated gap as a silence record.
            residual_duration += msg.time
            if residual_duration > 0:
                if not silence_note:
                    silence_note = {
                        "start_time": total_duration,
                        "pitch": 0,
                        "lyric": " ",
                    }
                silence_note["end_time"] = msg_end_time
                silence_note["duration"] = (
                    msg_end_time - silence_note["start_time"]
                )
                data.append(silence_note.copy())
                silence_note.clear()
                residual_duration = 0
            active_note = {
                "start_time": msg_end_time,
                "pitch": msg.note,
            }
        elif msg.type == "lyrics":
            # Attach the decoded lyric text to the currently sounding note.
            active_note["lyric"] = midii.MessageAnalyzer_lyrics(
                msg=msg, encoding=mid.lyric_encoding
            ).lyric
        elif msg.type == "note_off" or (
            msg.type == "note_on" and msg.velocity == 0
        ):
            # Note ends (velocity-0 note_on is the conventional note-off form).
            active_note["end_time"] = msg_end_time
            active_note["duration"] = msg_end_time - active_note["start_time"]
            data.append(active_note.copy())
            active_note.clear()
        else:
            # Meta/other messages: if nothing is sounding, their delta time
            # counts toward the current gap (future silence record).
            if not active_note and not silence_note:
                silence_note = {
                    "start_time": total_duration,
                    "pitch": 0,
                    "lyric": " ",
                }
            if not active_note:
                residual_duration += msg.time
        total_duration = msg_end_time

    return data, mid.ticks_per_beat
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _preprocess_slice_actual_lyric(df):
|
| 88 |
+
j_indices = df.index[df["lyric"] == "J"].tolist()
|
| 89 |
+
idx_j = j_indices[0]
|
| 90 |
+
h_indices = df.index[df["lyric"] == "H"].tolist()
|
| 91 |
+
idx_h = h_indices[0]
|
| 92 |
+
slice_start_index = idx_j + 1
|
| 93 |
+
slice_end_index = idx_h
|
| 94 |
+
df = df.iloc[slice_start_index:slice_end_index].reset_index(drop=True)
|
| 95 |
+
return df
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def preprocess_notes(notes, ticks_per_beat, unit="32"):
    """Run the note-level cleaning pipeline over a raw note list.

    Steps: slice to the actual lyric span ("J".."H"), order by onset time,
    trim leading/trailing silence, zero out silence pitch, merge adjacent
    silences, drop silences shorter than 0.3, and append a
    quantized-duration column. Returns the cleaned DataFrame.
    """
    frame = pd.DataFrame(notes)
    frame = _preprocess_slice_actual_lyric(frame)
    frame = _preprocess_sort_by_start_time(frame)
    frame = _preprocess_remove_front_back_silence(frame)
    frame = _preprocess_silence_pitch_zero(frame)
    frame = _preprocess_merge_silence(frame)
    frame = _preprocess_remove_short_silence(frame, 0.3)
    frame = _preprocess_add_quantized_duration_col(frame, ticks_per_beat, unit=unit)
    return frame
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def process_midi_flat_map(row: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Process one MIDI file-path row for Ray Data's flat_map().

    Returns a single-element list with the extracted record on success and
    an empty list on any failure, so failed files simply vanish from the
    flattened dataset instead of aborting the job.
    """
    midi_path = row["path"]
    try:
        note_list, ticks_per_beat = midi_to_note_list(midi_path)
        note_frame = preprocess_notes(note_list, ticks_per_beat=ticks_per_beat)
        record = {
            "durations": note_frame["duration"].tolist(),
            "quantized_durations": note_frame["quantized_duration"].tolist(),
            "singer_id": singer_id_from_filepath(midi_path),
        }
        return [record]
    except Exception as e:
        # Best-effort: report and drop the file rather than kill the pipeline.
        print(f"CRITICAL ERROR processing {midi_path}: {e}")
        return []
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def preprocess_dataset(midi_file_directory, output_parquet_path):
    """Parallel-preprocess every .mid under a directory into one Parquet file.

    Starts a local Ray cluster, fans the files out through
    process_midi_flat_map, and writes the surviving records to
    output_parquet_path. Shuts Ray down when finished.
    """
    context = ray.init()
    print(context.dashboard_url)
    print(f"Ray cluster started: {ray.cluster_resources()}")

    # --- Step 1: discover input files ---
    all_midi_paths = list(Path(midi_file_directory).rglob("*.mid"))
    print(f"Found {len(all_midi_paths)} MIDI files to process.")

    # --- Step 2: Create a dataset of file paths ---
    # This is the idiomatic way. We create a dataset where each row is a path.
    # We pass a list of dictionaries to give the column a name: "path".
    print("Creating initial dataset of file paths...")
    ds = ray.data.from_items([{"path": str(p)} for p in all_midi_paths])

    # --- Step 3: process files in parallel ---
    # flat_map applies the function to each row; failures return [] and so
    # are dropped from the flattened result. Ray Data manages the tasks,
    # memory, and scheduling.
    print("Applying parallel processing function using .map()...")
    processed_ds = ds.flat_map(process_midi_flat_map)

    # --- Step 4: Filter out any rows that failed ---
    # Not needed with flat_map: failed files already contribute zero rows.
    # processed_ds = processed_ds.filter(lambda row: row)

    # Inspect the schema of the processed dataset.
    print("\nProcessed Dataset schema:")
    print(processed_ds.schema())

    print("\nFirst 1 rows of processed data:")
    processed_ds.show(1)

    # Repartition to a single block (fine while the dataset is < ~1GB);
    # this yields a single Parquet file in the output directory.
    print("Repartitioning dataset...")
    processed_ds = processed_ds.repartition(num_blocks=1)

    print(f"\nWriting dataset to Parquet format at: {output_parquet_path}")
    processed_ds.write_parquet(output_parquet_path)

    # Row count == number of successfully processed files.
    print(
        f"\nProcessing complete! {processed_ds.count()} files successfully processed."
    )

    ray.shutdown()
|
src/toy_duration_predictor/preprocess/utils.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import midii
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def get_files(dir_path, type, sort=False):
    """Recursively collect files with the given extension under dir_path.

    Args:
        dir_path: Root directory to search.
        type: File extension without the leading dot (e.g. "mid").
              (Parameter name kept for caller compatibility despite
              shadowing the builtin.)
        sort: If True, order the results by file stem.

    Returns:
        list[Path]: Matching paths. Previously the unsorted branch returned
        a lazy generator while the sorted branch returned a list; a list is
        now returned in both cases for a consistent, re-iterable result.
    """
    paths = list(Path(dir_path).rglob(f"*.{type}"))
    if sort:
        paths.sort(key=lambda p: p.stem)
    return paths
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _preprocess_remove_front_back_silence(df):
|
| 16 |
+
is_valid_lyric = df["lyric"] != " "
|
| 17 |
+
valid_indices = df.index[is_valid_lyric].tolist()
|
| 18 |
+
first_valid_idx = valid_indices[0]
|
| 19 |
+
last_valid_idx = valid_indices[-1]
|
| 20 |
+
df = df.iloc[first_valid_idx : last_valid_idx + 1].reset_index(drop=True)
|
| 21 |
+
return df
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _preprocess_sort_by_start_time(df):
|
| 25 |
+
df = df.sort_values(by="start_time").reset_index(drop=True)
|
| 26 |
+
return df
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# NOTE: an exact duplicate definition of _preprocess_remove_front_back_silence
# previously lived here; it silently re-bound the name to an identical copy.
# The duplicate has been removed — the first definition above is the one used.
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _preprocess_silence_pitch_zero(df):
|
| 39 |
+
df.loc[df["lyric"] == " ", "pitch"] = 0
|
| 40 |
+
return df
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _preprocess_merge_silence(df):
|
| 44 |
+
output_notes = []
|
| 45 |
+
i = 0
|
| 46 |
+
n = len(df)
|
| 47 |
+
while i < n:
|
| 48 |
+
current_row = df.iloc[i] # Pandas Series
|
| 49 |
+
if current_row["lyric"] == " ":
|
| 50 |
+
merged_start_time = current_row["start_time"]
|
| 51 |
+
merged_end_time = current_row["end_time"]
|
| 52 |
+
|
| 53 |
+
j = i + 1
|
| 54 |
+
while j < n and df.iloc[j]["lyric"] == " ":
|
| 55 |
+
merged_end_time = df.iloc[j][
|
| 56 |
+
"end_time"
|
| 57 |
+
] # 마지막 공백의 end_time으로 업데이트
|
| 58 |
+
j += 1
|
| 59 |
+
|
| 60 |
+
merged_item = {
|
| 61 |
+
"start_time": merged_start_time,
|
| 62 |
+
"end_time": merged_end_time,
|
| 63 |
+
"pitch": 0,
|
| 64 |
+
"lyric": " ",
|
| 65 |
+
"duration": merged_end_time - merged_start_time,
|
| 66 |
+
}
|
| 67 |
+
output_notes.append(merged_item)
|
| 68 |
+
i = j # 병합된 블록 다음으로 인덱스 이동
|
| 69 |
+
else:
|
| 70 |
+
non_space_item = {
|
| 71 |
+
"start_time": current_row["start_time"],
|
| 72 |
+
"end_time": current_row["end_time"],
|
| 73 |
+
"pitch": current_row["pitch"],
|
| 74 |
+
"lyric": current_row["lyric"],
|
| 75 |
+
"duration": current_row["duration"],
|
| 76 |
+
}
|
| 77 |
+
output_notes.append(non_space_item)
|
| 78 |
+
i += 1
|
| 79 |
+
df = pd.DataFrame(output_notes)
|
| 80 |
+
return df
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _preprocess_remove_short_silence(df, threshold=0.3):
|
| 84 |
+
processed_notes = []
|
| 85 |
+
absorbed_time = 0.0
|
| 86 |
+
|
| 87 |
+
for i in range(len(df)):
|
| 88 |
+
current_note_s = df.iloc[i]
|
| 89 |
+
if (
|
| 90 |
+
current_note_s["lyric"] == " "
|
| 91 |
+
and current_note_s["duration"] < threshold
|
| 92 |
+
):
|
| 93 |
+
absorbed_time += current_note_s["duration"]
|
| 94 |
+
continue
|
| 95 |
+
else:
|
| 96 |
+
note_to_add = current_note_s.to_dict()
|
| 97 |
+
if absorbed_time > 0:
|
| 98 |
+
note_to_add["start_time"] -= absorbed_time
|
| 99 |
+
note_to_add["duration"] = (
|
| 100 |
+
note_to_add["end_time"] - note_to_add["start_time"]
|
| 101 |
+
)
|
| 102 |
+
absorbed_time = 0.0
|
| 103 |
+
processed_notes.append(note_to_add)
|
| 104 |
+
|
| 105 |
+
df = pd.DataFrame(processed_notes)
|
| 106 |
+
return df
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def _preprocess_add_quantized_duration_col(df, ticks_per_beat, unit="32"):
    """Add a "quantized_duration" column snapped to the given note unit.

    Args:
        df: Note table with a "duration" column (in ticks); mutated in place.
        ticks_per_beat: MIDI resolution used to convert the note unit to ticks.
        unit: Note-value name understood by midii (e.g. "32" for a 32nd note).

    Returns:
        The same DataFrame with the new column attached.
    """
    # Tick length of one quantization step (e.g. one 32nd note).
    step = midii.beat2tick(
        midii.NOTE[f"n/{unit}"].beat, ticks_per_beat=ticks_per_beat
    )
    quantized, _ = midii.quantize(df["duration"].values, unit=step)
    df["quantized_duration"] = quantized
    return df
|
src/toy_duration_predictor/py.typed
ADDED
|
File without changes
|
src/toy_duration_predictor/train_fastai.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from datasets import load_dataset, DatasetDict
|
| 4 |
+
from fastai.vision.all import * # Import a more general base for DataBlock
|
| 5 |
+
|
| 6 |
+
# Import Ray and the correct modern modules
|
| 7 |
+
import ray
|
| 8 |
+
from ray import tune
|
| 9 |
+
from ray.air import session # The new way to report metrics
|
| 10 |
+
from ray.tune.search.optuna import OptunaSearch
|
| 11 |
+
|
| 12 |
+
# Import tools for logging and demos
|
| 13 |
+
import wandb
|
| 14 |
+
from ray.air.integrations.wandb import WandbLoggerCallback
|
| 15 |
+
import gradio as gr
|
| 16 |
+
|
| 17 |
+
# --- 1. Configuration & Hyperparameters ---
|
| 18 |
+
# You can adjust these values for your experiments
|
| 19 |
+
|
| 20 |
+
# Data Parameters
|
| 21 |
+
REPO_ID = "ccss17/note-duration-dataset"
|
| 22 |
+
SEQUENCE_LENGTH = 128
|
| 23 |
+
PAD_TOKEN = 0
|
| 24 |
+
BATCH_SIZE = 64
|
| 25 |
+
|
| 26 |
+
# Model Parameters
|
| 27 |
+
NUM_SINGERS = 18 # IMPORTANT: Set this to the total number of unique singers in your dataset
|
| 28 |
+
SINGER_EMBEDDING_DIM = 16
|
| 29 |
+
HIDDEN_SIZE = 256
|
| 30 |
+
NUM_LAYERS = 2
|
| 31 |
+
DROPOUT = 0.3
|
| 32 |
+
|
| 33 |
+
# --- NEW: Configuration for wandb ---
|
| 34 |
+
# Set this to False if you want to run the script without logging to wandb
|
| 35 |
+
WANDB_ENABLED = True
|
| 36 |
+
WANDB_ENTITY = "ccss17" # Your wandb username
|
| 37 |
+
WANDB_PROJECT = "toy-duration-predictor"
|
| 38 |
+
|
| 39 |
+
# --- 2. Model Architecture Definition (PyTorch) ---
|
| 40 |
+
# A bi-GRU model that takes a sequence and a singer ID as input.
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class ToyDurationPredictor(nn.Module):
    """Bi-directional GRU that maps a duration sequence plus a singer ID to a
    per-timestep duration prediction.

    The singer ID is embedded and the embedding is concatenated to every
    timestep of the input sequence before the recurrent layers; a linear head
    then projects each step's hidden state to one scalar.

    NOTE(review): ``vocab_size`` and ``embedding_dim`` are accepted but never
    used; callers in this file pass 0 for both.
    """

    def __init__(
        self,
        vocab_size,
        embedding_dim,
        hidden_size,
        num_layers,
        dropout,
        num_singers,
        singer_embedding_dim,
    ):
        super().__init__()
        self.num_singers = num_singers

        # Learned per-singer style vector.
        self.singer_embedding = nn.Embedding(num_singers, singer_embedding_dim)

        # Per step the GRU sees one scalar duration plus the singer embedding.
        self.rnn = nn.GRU(
            input_size=1 + singer_embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=True,
            dropout=dropout,
            batch_first=True,
        )

        # Bi-directional output is 2 * hidden_size wide; project to a scalar.
        self.fc = nn.Linear(hidden_size * 2, 1)

    def forward(self, x):
        """Run the model; ``x`` is a (sequence, singer_id) tuple as produced
        by the fastai DataBlock's ``get_x``."""
        seq, singer_ids = x
        features = seq.unsqueeze(-1).float()  # (B, T) -> (B, T, 1)
        style = self.singer_embedding(singer_ids).unsqueeze(1)
        style = style.expand(-1, features.size(1), -1)
        rnn_out, _ = self.rnn(torch.cat([features, style], dim=-1))
        return self.fc(rnn_out).squeeze(-1)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# --- 3. Data Loading and Preparation ---
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def get_dataloaders(model_type="B"):
    """
    Loads data from the Hub, splits it, processes it, and returns DataLoaders.

    Model "A" uses the raw durations as input; model "B" (default) uses the
    quantized durations. The regression target is always the raw durations.
    Returns (fastai DataLoaders built from the train split, held-out test split).
    """
    print(f"--- Preparing DataLoaders for Model {model_type} ---")
    dataset = load_dataset(REPO_ID, split="train")

    # Perform the 80/10/10 split (seeded, so splits are reproducible)
    train_test_split = dataset.train_test_split(test_size=0.2, seed=42)
    test_valid_split = train_test_split["test"].train_test_split(
        test_size=0.5, seed=42
    )
    split_dataset = DatasetDict(
        {
            "train": train_test_split["train"],
            "valid": test_valid_split["train"],
            "test": test_valid_split["test"],
        }
    )

    # The chunking function with padding remains the same.
    # Each song is cut into fixed-length windows; the final partial window is
    # right-padded with PAD_TOKEN so every chunk has SEQUENCE_LENGTH entries.
    def chunk_examples_with_padding(examples):
        chunked = {"durations": [], "quantized_durations": [], "singer_id": []}
        for i in range(len(examples["durations"])):
            durs, q_durs, sid = (
                examples["durations"][i],
                examples["quantized_durations"][i],
                examples["singer_id"][i],
            )
            for j in range(0, len(durs), SEQUENCE_LENGTH):
                d_chunk = durs[j : j + SEQUENCE_LENGTH]
                q_chunk = q_durs[j : j + SEQUENCE_LENGTH]
                if len(d_chunk) < SEQUENCE_LENGTH:
                    padding_needed = SEQUENCE_LENGTH - len(d_chunk)
                    d_chunk.extend([PAD_TOKEN] * padding_needed)
                    q_chunk.extend([PAD_TOKEN] * padding_needed)
                chunked["durations"].append(d_chunk)
                chunked["quantized_durations"].append(q_chunk)
                chunked["singer_id"].append(sid)
        return chunked

    processed_splits = split_dataset.map(
        chunk_examples_with_padding,
        batched=True,
        remove_columns=split_dataset["train"].column_names,
    )

    # --- NEW: Simpler, more robust DataBlock setup ---

    # Define functions to get the inputs (x) and target (y) from a row
    def get_x(row):
        # The input is a tuple of the sequence and the singer id
        if model_type == "A":
            seq = torch.tensor(row["durations"], dtype=torch.long)
        else:
            seq = torch.tensor(row["quantized_durations"], dtype=torch.long)
        sid = torch.tensor(row["singer_id"], dtype=torch.long)
        return (seq, sid)

    def get_y(row):
        # The target is always the original durations, as a float for regression
        return torch.tensor(row["durations"], dtype=torch.float32)

    # Create the DataBlock
    dblock = DataBlock(
        blocks=(
            TransformBlock,
            RegressionBlock,
        ),  # A generic transform block and a regression block
        get_x=get_x,
        get_y=get_y,
        # NOTE(review): this reaches into a private attribute of the *valid*
        # split to pick validation indices out of the *train* split — those
        # index spaces are different datasets, so the resulting validation
        # set looks wrong. Verify against fastai's IndexSplitter docs.
        splitter=IndexSplitter(
            split_dataset["valid"]._indices
        ),  # Use indices for splitting
    )

    # Create the DataLoaders from the processed training set
    dls = dblock.dataloaders(processed_splits["train"], bs=BATCH_SIZE)

    return dls, processed_splits["test"]
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# --- 4. Custom Callback for Ray Tune + fastai Integration ---
|
| 183 |
+
class TuneReportCallbackForFastAI(Callback):
    """fastai callback that forwards per-epoch metrics to Ray Tune.

    After each epoch it reads the smoothed training loss, the validation
    loss, and the first metric (assumed to be MAE) from the Learner's
    recorder and reports them via ``ray.air.session.report``.
    """

    def after_epoch(self):
        # Exponentially smoothed training loss tracked by fastai's Recorder.
        train_loss = self.learn.recorder.smooth_loss.item()
        # NOTE(review): recent fastai Recorders do not expose a ``val_loss``
        # attribute — confirm this works; otherwise read the last entry of
        # ``self.learn.recorder.values`` instead.
        valid_loss = self.learn.recorder.val_loss.item()
        # Assumes MAE is the first (and only) metric passed to the Learner.
        mae_metric = self.learn.recorder.metrics[0].value.item()
        session.report(
            {
                "train_loss": train_loss,
                "valid_loss": valid_loss,
                "mae": mae_metric,
            }
        )
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# --- 5. Training Function for Ray Tune ---
|
| 198 |
+
def train_tdp(config):
    """Ray Tune trainable: build DataLoaders and a model from *config*, then
    run one fastai one-cycle training, reporting metrics via the callback.

    Expected config keys: "hidden_size", "num_layers", "dropout", "lr",
    "epochs", plus an optional "model_type" ("A" or "B", default "B") which
    is popped from the dict before model construction.
    """
    variant = config.pop("model_type", "B")
    dls, _ = get_dataloaders(model_type=variant)

    model = ToyDurationPredictor(
        vocab_size=0,  # unused by the model
        embedding_dim=0,  # unused by the model
        hidden_size=config["hidden_size"],
        num_layers=config["num_layers"],
        dropout=config["dropout"],
        num_singers=NUM_SINGERS,
        singer_embedding_dim=SINGER_EMBEDDING_DIM,
    )

    learner = Learner(dls, model, loss_func=MSELossFlat(), metrics=mae).to_fp16()
    learner.fit_one_cycle(
        config["epochs"],
        lr_max=config["lr"],
        cbs=[TuneReportCallbackForFastAI()],
    )
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# --- 6. Main Execution Block ---
|
| 218 |
+
if __name__ == "__main__":
    # --- Option 1: Run a single training for quick testing ---
    # Trains Model B (quantized -> original durations), evaluates on the
    # held-out test split, then serves an interactive Gradio demo.
    print("--- Starting Single Training Run for Model B (Your Method) ---")

    if WANDB_ENABLED:
        try:
            wandb.init(
                project=WANDB_PROJECT,
                entity=WANDB_ENTITY,
                name="single_run_model_b",
            )
        except Exception as e:
            # Offline / missing-credentials runs degrade to local-only logging.
            print(f"Could not initialize wandb: {e}. Disabling for this run.")
            WANDB_ENABLED = False

    dls_B, test_ds_B = get_dataloaders(model_type="B")

    model = ToyDurationPredictor(
        vocab_size=0,  # unused by the model
        embedding_dim=0,  # unused by the model
        hidden_size=HIDDEN_SIZE,
        num_layers=NUM_LAYERS,
        dropout=DROPOUT,
        num_singers=NUM_SINGERS,
        singer_embedding_dim=SINGER_EMBEDDING_DIM,
    )

    callbacks = [WandbCallback()] if WANDB_ENABLED else []
    learn = Learner(
        dls_B, model, loss_func=MSELossFlat(), metrics=mae, cbs=callbacks
    )

    print("Training the model...")
    learn.fit_one_cycle(5, 1e-3)

    print("\n--- Evaluating on the held-out test set ---")
    test_dl = dls_B.test_dl(test_ds_B)
    # validate() returns [loss, *metrics]; MAE is the single metric here.
    loss, mae_val = learn.validate(dl=test_dl)
    print(
        f"\nFinal Test Set Performance: Loss (MSE)={loss:.4f}, MAE={mae_val:.4f} ticks"
    )

    if WANDB_ENABLED:
        wandb.finish()

    # --- Gradio Demo Section ---
    print("\n--- Launching Gradio Demo ---")

    def predict_durations(quantized_durations_str, singer_id):
        # Parse a comma-separated duration string, run the trained learner,
        # and return the predictions as a comma-separated string. Any parse
        # or inference error is returned as text rather than raised.
        try:
            durs = [int(x.strip()) for x in quantized_durations_str.split(",")]
            inp_tensor = torch.tensor(durs, dtype=torch.long)
            sid_tensor = torch.tensor([int(singer_id)], dtype=torch.long)

            # The input to the learner's test_dl is a list of items
            # Each item should match what get_x would produce
            dl = learn.dls.test_dl([(inp_tensor, sid_tensor)])
            preds, _ = learn.get_preds(dl=dl)

            # int() truncates each predicted duration toward zero.
            return ", ".join([str(int(p)) for p in preds[0]])
        except Exception as e:
            return f"Error: {e}"

    iface = gr.Interface(
        fn=predict_durations,
        inputs=[
            gr.Textbox(
                label="Quantized Durations (comma-separated)",
                placeholder="30, 0, 75, 0, 45, 15, ...",
            ),
            gr.Number(label="Singer ID", value=2),
        ],
        outputs=gr.Textbox(label="Predicted Original Durations"),
        title="Toy Duration Predictor",
        description="Enter a sequence of quantized durations and a singer ID to see the model predict the original, stylistic performance.",
    )
    iface.launch()

    # --- Option 2: Run Ray Tune ---
    # ... (Ray Tune code remains the same) ...
|
src/toy_duration_predictor/train_lightning.py
ADDED
|
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.utils.data import DataLoader, Dataset
|
| 4 |
+
from datasets import load_dataset, DatasetDict
|
| 5 |
+
import pytorch_lightning as pl
|
| 6 |
+
from pytorch_lightning.loggers import WandbLogger
|
| 7 |
+
from pytorch_lightning.callbacks import TQDMProgressBar
|
| 8 |
+
|
| 9 |
+
# Import Ray and Tune for hyperparameter search
|
| 10 |
+
from ray import tune
|
| 11 |
+
from ray.tune.integration.pytorch_lightning import TuneReportCallback
|
| 12 |
+
|
| 13 |
+
# Import tools for demos
|
| 14 |
+
import wandb
|
| 15 |
+
import gradio as gr
|
| 16 |
+
import pandas as pd
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
# --- 1. Configuration & Hyperparameters ---
|
| 20 |
+
# Data Parameters
|
| 21 |
+
REPO_ID = "ccss17/note-duration-dataset"
|
| 22 |
+
SEQUENCE_LENGTH = 128
|
| 23 |
+
PAD_TOKEN = 0
|
| 24 |
+
BATCH_SIZE = 64
|
| 25 |
+
|
| 26 |
+
# Model Parameters are now set dynamically in the DataModule
|
| 27 |
+
SINGER_EMBEDDING_DIM = 16
|
| 28 |
+
HIDDEN_SIZE = 256
|
| 29 |
+
NUM_LAYERS = 2
|
| 30 |
+
DROPOUT = 0.3
|
| 31 |
+
LEARNING_RATE = 1e-3
|
| 32 |
+
|
| 33 |
+
# Configuration for wandb
|
| 34 |
+
WANDB_ENABLED = True
|
| 35 |
+
WANDB_ENTITY = "ccss17" # Your wandb username
|
| 36 |
+
WANDB_PROJECT = "toy-duration-predictor-lightning"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# --- 2. Model Architecture Definition (Vanilla PyTorch) ---
|
| 40 |
+
class ToyDurationPredictor(nn.Module):
    """Bi-directional GRU regressor over note-duration sequences.

    Every timestep of the input sequence is concatenated with a learned
    embedding of the singer, fed through a bi-GRU, and projected by a linear
    head to one predicted duration per step.
    """

    def __init__(
        self,
        hidden_size,
        num_layers,
        dropout,
        num_singers,
        singer_embedding_dim,
    ):
        super().__init__()
        # Learned per-singer style vector.
        self.singer_embedding = nn.Embedding(num_singers, singer_embedding_dim)
        # Per step the GRU sees one scalar duration plus the singer embedding.
        self.rnn = nn.GRU(
            input_size=1 + singer_embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=True,
            dropout=dropout,
            batch_first=True,
        )
        # Bi-directional output is 2 * hidden_size wide; project to a scalar.
        self.fc = nn.Linear(hidden_size * 2, 1)

    def forward(self, x_seq, x_sid):
        """Predict per-step durations for a (B, T) sequence and (B,) singer ids."""
        features = x_seq.unsqueeze(-1).float()  # (B, T) -> (B, T, 1)
        style = self.singer_embedding(x_sid).unsqueeze(1)
        style = style.expand(-1, features.size(1), -1)
        hidden, _ = self.rnn(torch.cat([features, style], dim=-1))
        return self.fc(hidden).squeeze(-1)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# --- 3. Data Preparation (PyTorch Dataset & Lightning DataModule) ---
|
| 72 |
+
class DurationDataset(Dataset):
    """Wraps a chunked (fixed-length) dataset as a PyTorch Dataset.

    Model "A" uses the raw durations as input; any other model type (the
    default "B") uses the quantized durations. The regression target is
    always the raw duration sequence.
    """

    def __init__(self, processed_hf_dataset, model_type="B"):
        self.data = processed_hf_dataset
        self.model_type = model_type

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        target = torch.tensor(record["durations"], dtype=torch.float32)

        source_key = (
            "durations" if self.model_type == "A" else "quantized_durations"
        )
        input_seq = torch.tensor(record[source_key], dtype=torch.long)

        # singer_idx is the re-indexed, contiguous 0-based singer id.
        singer_idx = torch.tensor(record["singer_idx"], dtype=torch.long)

        return {"input_seq": input_seq, "singer_id": singer_idx}, target
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class DurationDataModule(pl.LightningDataModule):
    """A LightningDataModule to handle loading, splitting, and batching.

    Downloads the dataset from the Hub, re-indexes singer IDs to a contiguous
    0-based range, performs a seeded 80/10/10 train/valid/test split, chunks
    each song into fixed-length padded windows, and exposes the three splits
    as DataLoaders.
    """

    def __init__(self, model_type="B", batch_size=32):
        super().__init__()
        self.model_type = model_type
        self.batch_size = batch_size
        # Filled in by setup(): original singer_id -> contiguous 0-based index.
        self.singer_id_map = {}
        self.num_singers = 0

    def setup(self, stage=None):
        dataset = load_dataset(REPO_ID, split="train")

        # --- FIX: Create a mapping for singer IDs ---
        # Embedding layers need indices in [0, num_singers); the raw
        # singer_id values are not contiguous, so remap them.
        unique_singer_ids = sorted(dataset.unique("singer_id"))
        self.num_singers = len(unique_singer_ids)
        self.singer_id_map = {
            sid: i for i, sid in enumerate(unique_singer_ids)
        }

        print(
            f"Found {self.num_singers} unique singers. Mapping IDs to [0, {self.num_singers - 1}]"
        )

        def map_singer_ids(example):
            example["singer_idx"] = self.singer_id_map[example["singer_id"]]
            return example

        dataset = dataset.map(map_singer_ids)
        # --- END FIX ---

        # Seeded 80/10/10 split, reproducible across runs.
        train_test_split = dataset.train_test_split(test_size=0.2, seed=42)
        test_valid_split = train_test_split["test"].train_test_split(
            test_size=0.5, seed=42
        )
        split_dataset = DatasetDict(
            {
                "train": train_test_split["train"],
                "valid": test_valid_split["train"],
                "test": test_valid_split["test"],
            }
        )

        # Cut each song into SEQUENCE_LENGTH windows; the final partial
        # window is right-padded with PAD_TOKEN.
        def chunk_examples_with_padding(examples):
            chunked = {
                "durations": [],
                "quantized_durations": [],
                "singer_idx": [],
            }
            for i in range(len(examples["durations"])):
                durs, q_durs, s_idx = (
                    examples["durations"][i],
                    examples["quantized_durations"][i],
                    examples["singer_idx"][i],
                )
                for j in range(0, len(durs), SEQUENCE_LENGTH):
                    d_chunk, q_chunk = (
                        durs[j : j + SEQUENCE_LENGTH],
                        q_durs[j : j + SEQUENCE_LENGTH],
                    )
                    if len(d_chunk) < SEQUENCE_LENGTH:
                        padding = [PAD_TOKEN] * (
                            SEQUENCE_LENGTH - len(d_chunk)
                        )
                        d_chunk.extend(padding)
                        q_chunk.extend(padding)
                    chunked["durations"].append(d_chunk)
                    chunked["quantized_durations"].append(q_chunk)
                    chunked["singer_idx"].append(s_idx)
            return chunked

        processed_splits = split_dataset.map(
            chunk_examples_with_padding,
            batched=True,
            remove_columns=dataset.column_names,
        )

        self.train_ds = DurationDataset(
            processed_splits["train"], self.model_type
        )
        self.val_ds = DurationDataset(
            processed_splits["valid"], self.model_type
        )
        self.test_ds = DurationDataset(
            processed_splits["test"], self.model_type
        )

    def train_dataloader(self):
        # Only the training split is shuffled.
        return DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4,
            persistent_workers=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_ds,
            batch_size=self.batch_size,
            num_workers=4,
            persistent_workers=True,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_ds,
            batch_size=self.batch_size,
            num_workers=4,
            persistent_workers=True,
        )
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
# --- 4. The LightningModule ---
|
| 213 |
+
class LightningTDP(pl.LightningModule):
    """LightningModule wrapper around ToyDurationPredictor.

    Implements the train/val/test steps (MSE loss, plus MAE at test time)
    and an Adam optimizer. Hyperparameters are captured with
    ``save_hyperparameters`` so checkpoints are self-describing.
    """

    def __init__(
        self,
        num_singers,
        model_type="B",
        hidden_size=HIDDEN_SIZE,
        num_layers=NUM_LAYERS,
        dropout=DROPOUT,
        learning_rate=LEARNING_RATE,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.model = ToyDurationPredictor(
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            num_singers=num_singers,
            singer_embedding_dim=SINGER_EMBEDDING_DIM,
        )
        self.loss_fn = nn.MSELoss()

    def forward(self, batch):
        return self.model(batch["input_seq"], batch["singer_id"])

    def _forward_with_loss(self, batch):
        # Shared forward pass + MSE used by all three step hooks.
        inputs, labels = batch
        preds = self(inputs)
        return preds, labels, self.loss_fn(preds, labels)

    def training_step(self, batch, batch_idx):
        _, _, loss = self._forward_with_loss(batch)
        self.log(
            "train_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        _, _, loss = self._forward_with_loss(batch)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        preds, labels, loss = self._forward_with_loss(batch)
        mae = nn.functional.l1_loss(preds, labels)
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_mae", mae, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(), lr=self.hparams.learning_rate
        )
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
# --- 5. Main Execution Block ---
|
| 272 |
+
if __name__ == "__main__":
    # Single training run: fit Model B, evaluate on the held-out test split,
    # then serve an interactive Gradio demo of the trained model.
    print("--- Starting Single Training Run with PyTorch Lightning ---")
    data_module = DurationDataModule(model_type="B", batch_size=BATCH_SIZE)

    # Must run setup() to access the number of singers
    # (Trainer.fit will call setup() again; the work is repeated.)
    data_module.setup()

    lightning_model = LightningTDP(
        num_singers=data_module.num_singers, model_type="B"
    )

    wandb_logger = None
    if WANDB_ENABLED:
        try:
            wandb_logger = WandbLogger(
                project=WANDB_PROJECT,
                # entity=WANDB_ENTITY,
                name="lightning_single_run",
            )
        except Exception as e:
            # Degrade gracefully when wandb credentials are unavailable;
            # wandb_logger stays None and Trainer falls back accordingly.
            print(f"Could not initialize wandb: {e}. Disabling for this run.")

    trainer = pl.Trainer(
        max_epochs=5,
        logger=wandb_logger,
        callbacks=[TQDMProgressBar(refresh_rate=10)],
        accelerator="auto",
    )

    print("Training Model B...")
    trainer.fit(lightning_model, datamodule=data_module)

    print("\n--- Evaluating on the held-out test set ---")
    trainer.test(lightning_model, datamodule=data_module)

    if WANDB_ENABLED and wandb.run:
        wandb.finish()

    print("\n--- Launching Gradio Demo ---")
    # Move the bare nn.Module to CPU for inference in the demo.
    model_for_demo = lightning_model.model.cpu()

    # Need the mapping for the demo
    # NOTE(review): singer_id_reverse_map is built but never used below.
    singer_id_reverse_map = {
        v: k for k, v in data_module.singer_id_map.items()
    }

    def predict_durations(quantized_durations_str, singer_id_from_user):
        # Parse a comma-separated duration string, map the user's original
        # singer ID to the model's internal index, and return predictions as
        # a comma-separated string. Errors are returned as text, not raised.
        try:
            # Map the user-provided singer ID to the model's internal index
            # NOTE(review): gr.Number passes a float, while the map keys are
            # the dataset's original (likely int) IDs — confirm the membership
            # test below does not always fail on type mismatch.
            if singer_id_from_user not in data_module.singer_id_map:
                return f"Error: Singer ID {singer_id_from_user} not found in the dataset."
            singer_idx = data_module.singer_id_map[singer_id_from_user]

            durs = [int(x.strip()) for x in quantized_durations_str.split(",")]
            inp_tensor = torch.tensor(durs, dtype=torch.long)
            sid_tensor = torch.tensor([singer_idx], dtype=torch.long)

            with torch.no_grad():
                model_for_demo.eval()
                preds = model_for_demo(inp_tensor.unsqueeze(0), sid_tensor)

            # int() truncates each predicted duration toward zero.
            return ", ".join([str(int(p)) for p in preds[0]])
        except Exception as e:
            return f"Error: {e}"

    iface = gr.Interface(
        fn=predict_durations,
        inputs=[
            gr.Textbox(
                label="Quantized Durations (comma-separated)",
                placeholder="30, 0, 75, 0, 45, 15, ...",
            ),
            gr.Number(label="Singer ID (Original)", value=2),
        ],
        outputs=gr.Textbox(label="Predicted Original Durations"),
        title="Toy Duration Predictor (Lightning)",
        description="Enter a sequence of quantized durations and an original singer ID to see the model predict the original, stylistic performance.",
    )
    iface.launch()
|
src/toy_duration_predictor/upload.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import torch
|
| 3 |
+
from huggingface_hub import HfApi
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
# --- IMPORTANT ---
|
| 7 |
+
# This script assumes you have already run the main training script and have
|
| 8 |
+
# the necessary files and class definitions available.
|
| 9 |
+
|
| 10 |
+
# Import your model class from your training script.
|
| 11 |
+
# Make sure the path is correct. For example, if your training script is in 'src/train.py':
|
| 12 |
+
# from src.train import LightningTDP, ToyDurationPredictor
|
| 13 |
+
#
|
| 14 |
+
# For this example, we will define the classes again to make the script standalone.
|
| 15 |
+
# In your real project, you should import them.
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import pytorch_lightning as pl
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ToyDurationPredictor(nn.Module):
    """Bidirectional GRU that maps a duration sequence, conditioned on a
    singer identity, to one real-valued duration prediction per timestep.

    Attribute names (``singer_embedding``, ``rnn``, ``fc``) are part of the
    checkpoint contract — they become state_dict keys — and must not change.
    """

    def __init__(
        self,
        hidden_size,
        num_layers,
        dropout,
        num_singers,
        singer_embedding_dim,
    ):
        super().__init__()
        # Learned per-singer style vector.
        self.singer_embedding = nn.Embedding(num_singers, singer_embedding_dim)
        # Each timestep sees the scalar duration plus the singer embedding.
        self.rnn = nn.GRU(
            input_size=1 + singer_embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=True,
            dropout=dropout,
            batch_first=True,
        )
        # Bidirectional GRU concatenates forward/backward states -> 2x width.
        self.fc = nn.Linear(hidden_size * 2, 1)

    def forward(self, x_seq, x_sid):
        """Predict a duration for every step of the input sequence.

        Args:
            x_seq: (batch, seq_len) integer duration sequence.
            x_sid: (batch,) integer singer indices.

        Returns:
            (batch, seq_len) float tensor of per-step predictions.
        """
        seq_feats = x_seq.float().unsqueeze(-1)  # (B, T, 1)
        style = self.singer_embedding(x_sid)  # (B, E)
        # Broadcast the singer vector across every timestep.
        style_per_step = style.unsqueeze(1).expand(-1, seq_feats.size(1), -1)
        rnn_in = torch.cat([seq_feats, style_per_step], dim=-1)
        rnn_out, _ = self.rnn(rnn_in)
        return self.fc(rnn_out).squeeze(-1)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class LightningTDP(pl.LightningModule):
    """LightningModule wrapper around :class:`ToyDurationPredictor`.

    All constructor arguments are captured via ``save_hyperparameters`` so
    the architecture can be rebuilt from a checkpoint alone.
    """

    def __init__(
        self,
        num_singers,
        learning_rate,
        hidden_size,
        num_layers,
        dropout,
        singer_embedding_dim=16,  # default matches the old hard-coded value,
        # so checkpoints saved before this argument existed still load.
    ):
        super().__init__()
        # Records every __init__ argument into self.hparams (and checkpoints).
        self.save_hyperparameters()
        self.model = ToyDurationPredictor(
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            num_singers=num_singers,
            singer_embedding_dim=singer_embedding_dim,
        )
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# --- 1. Configuration ---
# The path to the best checkpoint saved by PyTorch Lightning
# (relative to the working directory the training script ran from).
BEST_MODEL_PATH = "./checkpoints/best-model-B.ckpt"
# Your Hugging Face username and the desired model repo name
REPO_ID = "ccss17/toy-duration-predictor"
# A local directory to stage your files before uploading
STAGING_DIR = Path("./hf_upload_staging")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def upload_model_to_hub():
    """Load the best checkpoint, stage the release files, and push to the Hub.

    Steps:
        1. Restore the ``LightningTDP`` checkpoint at ``BEST_MODEL_PATH``.
        2. Save the bare PyTorch weights as ``pytorch_model.bin``.
        3. Write a ``config.json`` describing the architecture.
        4. Upload the staging folder to ``REPO_ID`` on the Hugging Face Hub.

    Requires a prior ``huggingface-cli login``. Prints progress and returns
    early (without raising) if the checkpoint file is missing.
    """
    # Create the staging directory (parents=True so a nested path also works).
    STAGING_DIR.mkdir(parents=True, exist_ok=True)

    # --- 2. Load the final model and extract its state ---
    print(f"Loading best model from: {BEST_MODEL_PATH}")
    try:
        lightning_model = LightningTDP.load_from_checkpoint(BEST_MODEL_PATH)
    except FileNotFoundError:
        print(f"ERROR: Checkpoint file not found at {BEST_MODEL_PATH}")
        print("Please make sure you have run the training script first.")
        return

    # Extract the underlying PyTorch model (the weights)
    final_pytorch_model = lightning_model.model

    # Save the model's weights in the standard Hugging Face format
    weights_path = STAGING_DIR / "pytorch_model.bin"
    torch.save(final_pytorch_model.state_dict(), weights_path)
    print(f"Model weights saved to {weights_path}")

    # --- 3. Create the configuration file ---
    # This saves all the hyperparameters needed to recreate the model.
    # Read the embedding size off the actual model instead of assuming 16,
    # so config.json can never drift out of sync with the saved weights.
    config = {
        "hidden_size": lightning_model.hparams.hidden_size,
        "num_layers": lightning_model.hparams.num_layers,
        "dropout": lightning_model.hparams.dropout,
        "num_singers": lightning_model.hparams.num_singers,
        "singer_embedding_dim": final_pytorch_model.singer_embedding.embedding_dim,
        "architectures": [
            "ToyDurationPredictor"
        ],  # Link to the model class name
    }
    config_path = STAGING_DIR / "config.json"
    with open(config_path, "w") as f:
        json.dump(config, f, indent=4)
    print(f"Model configuration saved to {config_path}")

    # --- 4. Upload all files to the Hub ---
    print(f"\nUploading files to repository: {REPO_ID}")

    # Ensure you are logged in (in your terminal run: huggingface-cli login).
    api = HfApi()

    # Create the repository on the Hub (if it doesn't exist)
    api.create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True)

    # Upload the entire staging folder
    api.upload_folder(
        folder_path=STAGING_DIR,
        repo_id=REPO_ID,
        repo_type="model",
    )

    print("\nUpload complete! Your model is now on the Hugging Face Hub.")
    print(
        f"You can load it elsewhere using: AutoModel.from_pretrained('{REPO_ID}', trust_remote_code=True)"
    )
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
if __name__ == "__main__":
    # Run the upload process (entry point: push the best checkpoint to the Hub).
    upload_model_to_hub()
|
test.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
test.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from toy_duration_predictor.preprocess import mssv
|
| 2 |
+
import toy_duration_predictor.train_fastai as train_fastai
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def preprocessing(
    mssv_path="/mnt/d/dataset/004.다화자 가창 데이터",
    mssv_preprocessed_path="/mnt/d/dataset/mssv_preprocessed_duration",
):
    """Run MSSV duration preprocessing.

    Args:
        mssv_path: Root directory of the raw MSSV (multi-singer vocal)
            dataset. Default preserves the original machine-specific path.
        mssv_preprocessed_path: Output directory for the preprocessed
            duration data. Default preserves the original path.
    """
    mssv.preprocess_dataset(mssv_path, mssv_preprocessed_path)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def test_train():
    """Delegate to the fastai training module's smoke-training routine."""
    train_fastai.test_train()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
if __name__ == "__main__":
    # NOTE(review): preprocessing() stays commented out so repeated runs don't
    # redo the full dataset pass — uncomment to regenerate preprocessed data.
    # preprocessing()
    test_train()
|
test_wandb.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Minimal Weights & Biases smoke test: logs a synthetic training curve."""

import random

import wandb

# Start a new wandb run to track this script.
tracked_run = wandb.init(
    # Set the wandb entity where your project will be logged (generally your team name).
    # entity="cccsss17",
    # Set the wandb project where this run will be logged.
    project="my-awesome-project",
    # Track hyperparameters and run metadata.
    config={
        "learning_rate": 0.02,
        "architecture": "CNN",
        "dataset": "CIFAR-100",
        "epochs": 10,
    },
)

# Simulate training: fabricate an improving accuracy / decaying loss curve.
total_epochs = 10
offset = random.random() / 5
for epoch in range(2, total_epochs):
    noisy_acc = 1 - 2**-epoch - random.random() / epoch - offset
    noisy_loss = 2**-epoch + random.random() / epoch + offset

    # Log metrics to wandb.
    tracked_run.log({"acc": noisy_acc, "loss": noisy_loss})

# Finish the run and upload any remaining data.
tracked_run.finish()
|