Commit bfafefe by primepake
Parent(s): 1c33894
Files changed (3):
  1. README.md +17 -2
  2. speech/config.yaml +1 -1
  3. speech/train.py +8 -0
README.md CHANGED
@@ -79,16 +79,31 @@ pip install -r requirements.txt
     --model speech_tokenizer_v2_25hz \
     --device "cuda" \
     --batch_size 64 \
-    --file_list /data/learnable-speech/speech/files_test.txt \
+    --file_list /speech/files_test.txt \
     --skip_existing
 ```
 
 2. **Extracting DAC-VAE latent**
 ```bash
 cd dac-vae
-python inference.py --checkpoint checkpoint.pt --config config.yml
+python extract_dac_latents.py --checkpoint checkpoint.pt --config config.yml --root_path dataset --output_dir dataset/dac
 ```
 
+After processing, you should have a root folder with the following files:
+
+```
+dataset_root/
+├── audio_name.wav
+├── audio_name.txt
+├── audio_name_fsq.pt
+├── audio_name_latent.pt
+├── another_audio.wav
+├── another_audio.txt
+├── another_audio_fsq.pt
+├── another_audio_latent.pt
+└── ...
+```
+
 3. **Stage 1: Auto Regressive Transformer**
 ```bash
 #!/bin/bash
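
The tree added to the README defines a per-utterance grouping: each `audio_name.wav` should sit next to its transcript (`.txt`), speech-token tensor (`_fsq.pt`), and DAC-VAE latent (`_latent.pt`). As a quick sanity check — a minimal sketch, not part of this commit, with `dataset_root` taken from the placeholder tree above — the grouping can be verified like this:

```python
# Sanity-check sketch (hypothetical helper, not from the repo): confirm every
# .wav in the dataset root has the companion files the README tree describes.
from pathlib import Path

def check_dataset(root: str) -> list[str]:
    missing = []
    for wav in Path(root).glob("*.wav"):
        stem = wav.with_suffix("")  # e.g. dataset_root/audio_name
        for suffix in (".txt", "_fsq.pt", "_latent.pt"):
            companion = Path(str(stem) + suffix)
            if not companion.exists():
                missing.append(str(companion))
    return missing

if __name__ == "__main__":
    gaps = check_dataset("dataset_root")
    print("all companion files present" if not gaps else f"missing: {gaps}")
```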
speech/config.yaml CHANGED
@@ -221,7 +221,7 @@ train_conf:
   scheduler_conf:
     warmup_steps: 500
   max_epoch: 2000
-  grad_clip: 5
+  grad_clip: 1
   accum_grad: 1
   log_interval: 5
   save_per_step: 2000
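
This change tightens gradient clipping from 5 to 1. For context — a hedged sketch, since the trainer code consuming this config is not shown in the diff — a `grad_clip` value in a `train_conf` block like this one is conventionally applied as the `max_norm` of a global-norm clip before each optimizer step:

```python
# Illustration only: the assumed, conventional way a grad_clip setting is
# applied in a PyTorch training step; not necessarily this repo's exact code.
import torch

grad_clip = 1.0  # was 5 before this commit

def training_step(model, loss, optimizer):
    optimizer.zero_grad()
    loss.backward()
    # clip the global gradient norm to grad_clip before stepping;
    # a lower bound (1 vs 5) clamps large updates more aggressively
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
    if torch.isfinite(grad_norm):  # skip the step on inf/nan gradients
        optimizer.step()
```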
speech/train.py CHANGED
@@ -54,6 +54,7 @@ def get_args():
         "--qwen_pretrain_path", required=False, help="qwen pretrain path"
     )
     parser.add_argument("--checkpoint", help="checkpoint model")
+    parser.add_argument("--pretrained_model", help="pretrained model")
     parser.add_argument("--model_dir", required=True, help="save model dir")
     parser.add_argument(
         "--tensorboard_dir", default="tensorboard", help="tensorboard log dir"
@@ -209,6 +210,13 @@ def main():
 
     model = configs[args.model]
     start_step, start_epoch = 0, -1
+
+    if args.pretrained_model is not None:
+        # load pretrained weights; keys that do not match are ignored (strict=False)
+        logger.info(f"Load pretrained model from {args.pretrained_model}")
+        state_dict = torch.load(args.pretrained_model, map_location="cpu")
+        model.load_state_dict(state_dict, strict=False)
+
     if args.checkpoint is not None:
         if os.path.exists(args.checkpoint):
             logger.info(f"Load checkpoint from {args.checkpoint}")
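
The new `--pretrained_model` flag differs from `--checkpoint`: it warm-starts the model with `strict=False`, so a state dict that only partially matches the current architecture loads without raising; PyTorch instead reports the non-matching keys. A self-contained illustration of that behavior (the two small models here are hypothetical, not from the repo):

```python
# Standalone illustration of strict=False loading, as used in the new
# --pretrained_model branch above. The layer shapes are hypothetical.
import torch.nn as nn

old = nn.Sequential(nn.Linear(4, 8))                    # "pretrained" model
new = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))   # current architecture

state_dict = old.state_dict()
# strict=False loads every matching key and reports the rest instead of raising
result = new.load_state_dict(state_dict, strict=False)
print("missing:", result.missing_keys)        # keys in `new` absent from the file
print("unexpected:", result.unexpected_keys)  # keys in the file absent from `new`
```

With the parser change above, a warm start would then be launched roughly as `python speech/train.py --pretrained_model <weights.pt> --model_dir exp/` plus the remaining required arguments from `get_args()`.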