Spaces:
Runtime error
Runtime error
add sample_rate and n_fft params
Browse files- README.md +3 -1
- scripts/audio_to_images.py +3 -1
- scripts/train_unconditional.py +5 -1
- scripts/train_vae.py +15 -3
README.md
CHANGED
|
@@ -71,7 +71,9 @@ python scripts/audio_to_images.py \
|
|
| 71 |
--output_dir data/audio-diffusion-256 \
|
| 72 |
--push_to_hub teticio/audio-diffusion-256
|
| 73 |
```
|
| 74 |
-
|
|
|
|
|
|
|
| 75 |
## Train model
|
| 76 |
#### Run training on local machine.
|
| 77 |
```bash
|
|
|
|
| 71 |
--output_dir data/audio-diffusion-256 \
|
| 72 |
--push_to_hub teticio/audio-diffusion-256
|
| 73 |
```
|
| 74 |
+
|
| 75 |
+
Note that the default `sample_rate` is 22050 and that audio files at a different rate will be resampled. If you change this value, you may find that the results in the `test_mel.ipynb` notebook are poor (for example, if `sample_rate` is 48000) and that you need to adjust `n_fft` (for example, to 2000 instead of the default value of 2048); alternatively, you can resample to a `sample_rate` of 44100. Make sure you use the same parameters for training and inference. Also bear in mind that not all resolutions work with the neural network architecture as currently configured - you should be safe if you stick to powers of 2.
|
| 76 |
+
|
| 77 |
## Train model
|
| 78 |
#### Run training on local machine.
|
| 79 |
```bash
|
scripts/audio_to_images.py
CHANGED
|
@@ -19,7 +19,8 @@ def main(args):
|
|
| 19 |
mel = Mel(x_res=args.resolution[0],
|
| 20 |
y_res=args.resolution[1],
|
| 21 |
hop_length=args.hop_length,
|
| 22 |
-
sample_rate=args.sample_rate
|
|
|
|
| 23 |
os.makedirs(args.output_dir, exist_ok=True)
|
| 24 |
audio_files = [
|
| 25 |
os.path.join(root, file) for root, _, files in os.walk(args.input_dir)
|
|
@@ -86,6 +87,7 @@ if __name__ == "__main__":
|
|
| 86 |
parser.add_argument("--hop_length", type=int, default=512)
|
| 87 |
parser.add_argument("--push_to_hub", type=str, default=None)
|
| 88 |
parser.add_argument("--sample_rate", type=int, default=22050)
|
|
|
|
| 89 |
args = parser.parse_args()
|
| 90 |
|
| 91 |
if args.input_dir is None:
|
|
|
|
| 19 |
mel = Mel(x_res=args.resolution[0],
|
| 20 |
y_res=args.resolution[1],
|
| 21 |
hop_length=args.hop_length,
|
| 22 |
+
sample_rate=args.sample_rate,
|
| 23 |
+
n_fft=args.n_fft)
|
| 24 |
os.makedirs(args.output_dir, exist_ok=True)
|
| 25 |
audio_files = [
|
| 26 |
os.path.join(root, file) for root, _, files in os.walk(args.input_dir)
|
|
|
|
| 87 |
parser.add_argument("--hop_length", type=int, default=512)
|
| 88 |
parser.add_argument("--push_to_hub", type=str, default=None)
|
| 89 |
parser.add_argument("--sample_rate", type=int, default=22050)
|
| 90 |
+
parser.add_argument("--n_fft", type=int, default=2048)
|
| 91 |
args = parser.parse_args()
|
| 92 |
|
| 93 |
if args.input_dir is None:
|
scripts/train_unconditional.py
CHANGED
|
@@ -173,7 +173,9 @@ def main(args):
|
|
| 173 |
|
| 174 |
mel = Mel(x_res=resolution[1],
|
| 175 |
y_res=resolution[0],
|
| 176 |
-
hop_length=args.hop_length
|
|
|
|
|
|
|
| 177 |
|
| 178 |
global_step = 0
|
| 179 |
for epoch in range(args.num_epochs):
|
|
@@ -362,6 +364,8 @@ if __name__ == "__main__":
|
|
| 362 |
"and an Nvidia Ampere GPU."),
|
| 363 |
)
|
| 364 |
parser.add_argument("--hop_length", type=int, default=512)
|
|
|
|
|
|
|
| 365 |
parser.add_argument("--from_pretrained", type=str, default=None)
|
| 366 |
parser.add_argument("--start_epoch", type=int, default=0)
|
| 367 |
parser.add_argument("--num_train_steps", type=int, default=1000)
|
|
|
|
| 173 |
|
| 174 |
mel = Mel(x_res=resolution[1],
|
| 175 |
y_res=resolution[0],
|
| 176 |
+
hop_length=args.hop_length,
|
| 177 |
+
sample_rate=args.sample_rate,
|
| 178 |
+
n_fft=args.n_fft)
|
| 179 |
|
| 180 |
global_step = 0
|
| 181 |
for epoch in range(args.num_epochs):
|
|
|
|
| 364 |
"and an Nvidia Ampere GPU."),
|
| 365 |
)
|
| 366 |
parser.add_argument("--hop_length", type=int, default=512)
|
| 367 |
+
parser.add_argument("--sample_rate", type=int, default=22050)
|
| 368 |
+
parser.add_argument("--n_fft", type=int, default=2048)
|
| 369 |
parser.add_argument("--from_pretrained", type=str, default=None)
|
| 370 |
parser.add_argument("--start_epoch", type=int, default=0)
|
| 371 |
parser.add_argument("--num_train_steps", type=int, default=1000)
|
scripts/train_vae.py
CHANGED
|
@@ -60,10 +60,16 @@ class AudioDiffusionDataModule(pl.LightningDataModule):
|
|
| 60 |
|
| 61 |
class ImageLogger(Callback):
|
| 62 |
|
| 63 |
-
def __init__(self,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
super().__init__()
|
| 65 |
self.every = every
|
| 66 |
self.hop_length = hop_length
|
|
|
|
|
|
|
| 67 |
|
| 68 |
@rank_zero_only
|
| 69 |
def log_images_and_audios(self, pl_module, batch):
|
|
@@ -76,7 +82,9 @@ class ImageLogger(Callback):
|
|
| 76 |
channels = image_shape[1]
|
| 77 |
mel = Mel(x_res=image_shape[2],
|
| 78 |
y_res=image_shape[3],
|
| 79 |
-
hop_length=self.hop_length
|
|
|
|
|
|
|
| 80 |
|
| 81 |
for k in images:
|
| 82 |
images[k] = images[k].detach().cpu()
|
|
@@ -145,6 +153,8 @@ if __name__ == "__main__":
|
|
| 145 |
type=int,
|
| 146 |
default=1)
|
| 147 |
parser.add_argument("--hop_length", type=int, default=512)
|
|
|
|
|
|
|
| 148 |
parser.add_argument("--save_images_batches", type=int, default=1000)
|
| 149 |
parser.add_argument("--max_epochs", type=int, default=100)
|
| 150 |
args = parser.parse_args()
|
|
@@ -166,7 +176,9 @@ if __name__ == "__main__":
|
|
| 166 |
resume_from_checkpoint=args.resume_from_checkpoint,
|
| 167 |
callbacks=[
|
| 168 |
ImageLogger(every=args.save_images_batches,
|
| 169 |
-
hop_length=args.hop_length
|
|
|
|
|
|
|
| 170 |
HFModelCheckpoint(ldm_config=config,
|
| 171 |
hf_checkpoint=args.hf_checkpoint_dir,
|
| 172 |
dirpath=args.ldm_checkpoint_dir,
|
|
|
|
| 60 |
|
| 61 |
class ImageLogger(Callback):
|
| 62 |
|
| 63 |
+
def __init__(self,
|
| 64 |
+
every=1000,
|
| 65 |
+
hop_length=512,
|
| 66 |
+
sample_rate=22050,
|
| 67 |
+
n_fft=2048):
|
| 68 |
super().__init__()
|
| 69 |
self.every = every
|
| 70 |
self.hop_length = hop_length
|
| 71 |
+
self.sample_rate = sample_rate
|
| 72 |
+
self.n_fft = n_fft
|
| 73 |
|
| 74 |
@rank_zero_only
|
| 75 |
def log_images_and_audios(self, pl_module, batch):
|
|
|
|
| 82 |
channels = image_shape[1]
|
| 83 |
mel = Mel(x_res=image_shape[2],
|
| 84 |
y_res=image_shape[3],
|
| 85 |
+
hop_length=self.hop_length,
|
| 86 |
+
sample_rate=self.sample_rate,
|
| 87 |
+
n_fft=self.n_fft)
|
| 88 |
|
| 89 |
for k in images:
|
| 90 |
images[k] = images[k].detach().cpu()
|
|
|
|
| 153 |
type=int,
|
| 154 |
default=1)
|
| 155 |
parser.add_argument("--hop_length", type=int, default=512)
|
| 156 |
+
parser.add_argument("--sample_rate", type=int, default=22050)
|
| 157 |
+
parser.add_argument("--n_fft", type=int, default=2048)
|
| 158 |
parser.add_argument("--save_images_batches", type=int, default=1000)
|
| 159 |
parser.add_argument("--max_epochs", type=int, default=100)
|
| 160 |
args = parser.parse_args()
|
|
|
|
| 176 |
resume_from_checkpoint=args.resume_from_checkpoint,
|
| 177 |
callbacks=[
|
| 178 |
ImageLogger(every=args.save_images_batches,
|
| 179 |
+
hop_length=args.hop_length,
|
| 180 |
+
sample_rate=args.sample_rate,
|
| 181 |
+
n_fft=args.n_fft),
|
| 182 |
HFModelCheckpoint(ldm_config=config,
|
| 183 |
hf_checkpoint=args.hf_checkpoint_dir,
|
| 184 |
dirpath=args.ldm_checkpoint_dir,
|