| """ |
| new_lr = 1e-7 # The new learning rate you want to use. |
| |
| print(f"Old learning rate: {optimizer.param_groups[0]['lr']}") |
| |
| for param_group in optimizer.param_groups: |
| param_group['lr'] = new_lr |
| |
| print(f"New learning rate: {optimizer.param_groups[0]['lr']}") |
| """ |
|
|
| |
|
|
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
|
|
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import ExponentialLR
import torchaudio
import torchcrepe
from tqdm import tqdm
import os
import glob
import signal

import questionary

from beltout import BeltOutTTM

from arro3_dataset_streamer import Arro3StreamingDataset
|
|
# Graceful shutdown: on Ctrl+C, set a flag so the training loop can finish the
# current step and save a checkpoint before exiting.
interrupted = False
def signal_handler(sig, frame):
    global interrupted
    print("\nCtrl+C detected! Finishing current step and saving checkpoint...")
    interrupted = True

signal.signal(signal.SIGINT, signal_handler)
|
|
# Dataset columns: each row of the streamed dataset bundles one clip from each
# of these sources.
COMPONENTS = [
    "english1-1", "english1-2", "english1-3", "english1-4",
    "english2-1", "english2-2", "english2-3",
    "other-1", "other-2", "other-3",
    "musical-1", "musical-2", "musical-3", "musical-4",
    "musical-5", "musical-6", "musical-7", "musical-8",
    "japanese-1", "japanese-2", "japanese-3",
    "persian-1", "greek-1",
    "esd-1", "esd-2", "esd-3",
    "chinese-1", "chinese-2",
    "romance-1", "romance-2", "romance-3", "romance-4",
]

class AudioDataProcessor:
    def __init__(self, segment_len_s=4):
        self.target_sr = 24000
        self.segment_len_24k = segment_len_s * self.target_sr
|
|
    def transform_example(self, example):
        try:
            waveforms = []

            for component in COMPONENTS:
                audio_data = example[component]
                waveform = torch.from_numpy(audio_data['array']).unsqueeze(0).float()
                sr = audio_data['sampling_rate']

                # Resample to 24 kHz if needed.
                if sr != self.target_sr:
                    resampler_24k = torchaudio.transforms.Resample(sr, self.target_sr)
                    waveform_24k = resampler_24k(waveform)
                else:
                    waveform_24k = waveform

                # Randomly crop clips longer than the target segment length.
                current_len = waveform_24k.shape[1]
                if current_len > self.segment_len_24k:
                    start = torch.randint(0, current_len - self.segment_len_24k, (1,)).item()
                    final_waveform = waveform_24k[:, start:start + self.segment_len_24k]
                else:
                    # Shorter clips are kept whole and padded below.
                    final_waveform = waveform_24k

                waveforms.append(final_waveform.squeeze(0))

            # Pad all component clips up to the longest one in this example.
            return {"waveform_24k_batch": torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True, padding_value=0.0)}
        except Exception as e:
            # Returning None lets the training loop skip bad samples.
            print(f"Skipping a problematic sample. Error: {e}")
            return None
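
# Usage sketch (shapes assume the 4 s / 24 kHz defaults above):
#   processor = AudioDataProcessor(segment_len_s=4)
#   out = processor.transform_example(example)   # one dataset row
#   out["waveform_24k_batch"]  # (len(COMPONENTS), <= 4 * 24000) padded tensor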
|
|
# Checkpointing: only the trainable modules (decoder + pitchmvmt) plus the
# optimizer/scheduler state are saved.
def save_checkpoint(step, loss, decoder, pitchmvmt, optimizer, scheduler, checkpoint_dir="checkpoints_training"):
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # Keep only the newest checkpoint to save disk space.
    for old_ckpt in glob.glob(os.path.join(checkpoint_dir, "*.pth")):
        print(f"Removing old checkpoint: {old_ckpt}")
        os.remove(old_ckpt)

    save_path = os.path.join(checkpoint_dir, f"training_step_{step}.pth")

    torch.save({
        'step': step,
        'loss': loss,
        'decoder_state_dict': decoder.state_dict(),
        'pitchmvmt_state_dict': pitchmvmt.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }, save_path)
    print(f"Saved checkpoint to {save_path} with loss {loss:.4f}")
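
# Inspection sketch (the step number in the filename is illustrative):
#   ckpt = torch.load("checkpoints_training/training_step_999.pth", map_location="cpu")
#   ckpt["step"], ckpt["loss"]  # plus the four state dicts saved above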
|
|
def load_latest_checkpoint(checkpoint_dir, decoder, pitchmvmt, optimizer, scheduler):
    if not os.path.exists(checkpoint_dir):
        print("No checkpoint directory found. Starting from scratch.")
        return 0, None

    checkpoints = glob.glob(os.path.join(checkpoint_dir, "*.pth"))
    if not checkpoints:
        print("No checkpoints found in directory. Starting from scratch.")
        return 0, None

    latest_ckpt_path = max(checkpoints, key=os.path.getmtime)
    print(f"Loading latest checkpoint by last modified time: {latest_ckpt_path}")

    checkpoint = torch.load(latest_ckpt_path, map_location=torch.device('cpu'))

    decoder.load_state_dict(checkpoint['decoder_state_dict'])
    pitchmvmt.load_state_dict(checkpoint['pitchmvmt_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # Move the modules and the optimizer state back onto the training device.
    decoder.to(DEVICE)
    pitchmvmt.to(DEVICE)
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(DEVICE)

    start_step = checkpoint.get('step', 0) + 1
    last_loss = checkpoint.get('loss', None)

    if last_loss is not None:
        print(f"Resuming training from step {start_step}. Last saved loss: {last_loss:.4f}")
    else:
        print(f"Resuming training from step {start_step}.")

    return start_step, last_loss
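
# Round-trip note: save_checkpoint(step, ...) writes training_step_<step>.pth,
# and load_latest_checkpoint returns (step + 1, loss), so training resumes on
# the step after the one that was saved.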
|
|
def select_file_from_menu(folder_path: str, prefix: str) -> str | None:
    """
    Scans a folder for files with a specific prefix, displays an interactive
    menu, and returns the user's selection.

    Args:
        folder_path: The path to the folder to search.
        prefix: The file prefix to match (e.g., 'cfm_step_').

    Returns:
        The full path of the selected file, or None if no file was selected
        or no matching files were found.
    """
    try:
        all_files = os.listdir(folder_path)
        matching_files = [
            f for f in all_files
            if f.startswith(prefix) and os.path.isfile(os.path.join(folder_path, f))
        ]

        if not matching_files:
            print(f"No checkpoints found in '{folder_path}' with prefix '{prefix}'! Make sure to have at least one checkpoint downloaded for each model.")
            return None

        # Sort newest first, by the trailing step number in the filename; files
        # without a parseable number sort last.
        def key(name):
            try:
                return int(os.path.splitext(name)[0].replace(prefix, "").split("_")[-1])
            except ValueError:
                return 0
        matching_files.sort(reverse=True, key=key)

        selected_file_name = questionary.select(
            "Checkpoints:",
            choices=matching_files,
            use_indicator=True
        ).ask()

        if selected_file_name:
            return os.path.join(folder_path, selected_file_name)
        else:
            # The user cancelled the menu (e.g. pressed Esc or Ctrl+C).
            return None

    except FileNotFoundError:
        print(f"Error: The folder '{folder_path}' does not exist.")
        return None
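
# Usage sketch (the returned filename is illustrative):
#   select_file_from_menu("./checkpoints", "cfm_")
#   -> "./checkpoints/cfm_step_240000.pth" once the user confirms a choice.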
|
|
def select_checkpoint_from_menu(model_name: str) -> str:
    print("Searching inside path './checkpoints' for available checkpoints...")
    ckpt_path = select_file_from_menu("./checkpoints", model_name + "_")
    if ckpt_path is None:
        print("No checkpoints found for model '" + model_name + "'!")
        raise SystemExit(1)
    return ckpt_path
|
|
|
|
def main():
    global interrupted

    global DEVICE
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Hyperparameters.
    LEARNING_RATE = 1e-4
    MAX_STEPS = 1000000
    SAVE_EVERY_N_STEPS = 1000
    CHECKPOINT_DIR = "checkpoints_training"
    # Fraction of each batch whose pitch conditioning is zeroed out, so the
    # decoder also learns to generate without F0 guidance.
    F0_DROPOUT_PROB = 0.2

    # Pick the starting checkpoints for every component of the model.
    print("Pick starting checkpoints.")
    decoder_ckpt_path = select_checkpoint_from_menu("cfm")
    pitchmvmt_ckpt_path = select_checkpoint_from_menu("pitchmvmt")
    encoder_ckpt_path = select_checkpoint_from_menu("encoder")
    flow_ckpt_path = select_checkpoint_from_menu("flow")
    mel2wav_ckpt_path = select_checkpoint_from_menu("mel2wav")
    speaker_encoder_ckpt_path = select_checkpoint_from_menu("speaker_encoder")
    tokenizer_ckpt_path = select_checkpoint_from_menu("tokenizer")
    model = BeltOutTTM.from_local(decoder_ckpt_path,
                                  pitchmvmt_ckpt_path,
                                  encoder_ckpt_path,
                                  flow_ckpt_path,
                                  mel2wav_ckpt_path,
                                  speaker_encoder_ckpt_path,
                                  tokenizer_ckpt_path, device=DEVICE, eval=True)
    # Only the decoder and the pitch-movement projector are trained.
    model.decoder.train()
    model.pitchmvmt.train()
    print(f"Model loaded on {DEVICE}.")

    params_to_train = list(model.decoder.parameters()) + list(model.pitchmvmt.parameters())
    optimizer = AdamW(params_to_train, lr=LEARNING_RATE)
    scheduler = ExponentialLR(optimizer, gamma=0.9999)
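    # Note: scheduler.step() runs once per optimizer step below, so with
    # gamma=0.9999 the learning rate halves roughly every
    # ln(2) / ln(1/0.9999) ~= 6,931 steps.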

    start_step, last_loss = load_latest_checkpoint(CHECKPOINT_DIR, model.decoder, model.pitchmvmt, optimizer, scheduler)
|
|
    data_processor = AudioDataProcessor(segment_len_s=4)
    print("Loading dataset...")
    dataset = Arro3StreamingDataset(
        "Bill13579/combined-set-1",
        transform_fn=data_processor.transform_example,
        shuffle=True,
        shuffle_buffer_size=32,
        prefetch_files=2
    )
    # batch_size=None: the dataset already yields fully-formed batches.
    dataloader = DataLoader(dataset, batch_size=None, num_workers=2, prefetch_factor=32)

    # CREPE operates on 16 kHz audio.
    resampler_16k = torchaudio.transforms.Resample(24000, 16000).to(DEVICE)
    crepe_sr = 16000

    print("Starting...")
    step = start_step
    postfix_dict = {"loss": f"{last_loss:.4f}" if last_loss is not None else "N/A", "lr": f"{scheduler.get_last_lr()[0]:.1e}"}
    training_progress = tqdm(total=MAX_STEPS, initial=start_step)
    training_progress.set_postfix(postfix_dict)

    data_iterator = iter(dataloader)

    epoch = 0
|
|
    while step < MAX_STEPS:
        if interrupted:
            break

        try:
            batch = next(data_iterator)
        except StopIteration:
            print(f"Epoch {epoch} completed. Restarting dataloader!")
            epoch += 1

            data_iterator = iter(dataloader)
            batch = next(data_iterator)

        if batch is None: continue

        waveform_24k_batch = batch['waveform_24k_batch'].to(DEVICE)

        # Feature extraction runs under no_grad: these modules stay frozen, and
        # only the decoder and pitchmvmt (further below) receive gradients.
        with torch.no_grad():
            gt_mel = model.mel_extractor(waveform_24k_batch)
            mel_len = gt_mel.shape[2]

            waveform_16k_batch = resampler_16k(waveform_24k_batch)
            s3_tokens, _ = model.tokenizer(waveform_16k_batch)

            x_vectors = torch.cat([
                model.embed_ref_x_vector(wf_24k.unsqueeze(0), 24000, device=DEVICE)
                for wf_24k in waveform_24k_batch
            ], dim=0)
            speaker_embedding = model.flow.spk_embed_affine_layer(x_vectors)

            # CREPE runs at 100 frames/s, the mel extractor at 50 frames/s, so
            # two CREPE frames line up with each mel frame.
            mel_frames_per_second = 50
            crepe_frames_per_second = 100
            crepe_hop_length = int(crepe_sr / float(crepe_frames_per_second))
            n_crepe_frame_in_mel_frame = int(crepe_frames_per_second / mel_frames_per_second)

            # Pad the 16 kHz audio so CREPE covers every mel frame.
            crepe_samples_needed = mel_len * n_crepe_frame_in_mel_frame * crepe_hop_length
            padded_waveform_16k = waveform_16k_batch
            pad_amount = crepe_samples_needed - padded_waveform_16k.shape[1]
            if pad_amount > 0:
                padded_waveform_16k = F.pad(torch.clone(padded_waveform_16k), (0, pad_amount))

            # Run CREPE one clip at a time to bound memory usage.
            crepe_embeddings_list = []
            for i in range(padded_waveform_16k.shape[0]):
                single_waveform = padded_waveform_16k[i:i+1, :]

                embedding = torchcrepe.embed(
                    single_waveform,
                    crepe_sr,
                    hop_length=crepe_hop_length,
                    model='tiny',
                    device=DEVICE,
                )

                # Keep exactly two CREPE frames per mel frame.
                crepe_embeddings_list.append(embedding[:, :mel_len*2, :, :])

            crepe_embedding = torch.cat(crepe_embeddings_list, dim=0)

            mel_lengths = torch.tensor([mel_len] * gt_mel.shape[0], device=DEVICE)
            mask = (torch.arange(mel_len, device=DEVICE).unsqueeze(0) < mel_lengths.unsqueeze(1)).unsqueeze(1)

            # Encode the speech tokens into the conditioning sequence mu.
            token_embeddings = model.flow.input_embedding(s3_tokens)
            token_len = torch.tensor([token_embeddings.shape[1]] * token_embeddings.shape[0], device=DEVICE)
            h, _ = model.encoder(token_embeddings, token_len)
            encoded_tokens = model.flow.encoder_proj(h)
            mu = encoded_tokens.transpose(1, 2)

        # The pitch-movement projector is trained, so it runs with gradients:
        # each pair of CREPE embedding frames becomes one 80-dim conditioning
        # vector per mel frame.
        B = crepe_embedding.shape[0]
        projector_input = crepe_embedding.view(-1, n_crepe_frame_in_mel_frame, 256)
        pitch_mvmt_encode_flat = model.pitchmvmt(projector_input)
        pitch_mvmt_encode = pitch_mvmt_encode_flat.view(B, -1, 80).transpose(1, 2)
        optimizer.zero_grad()

        # Conditioning dropout: zero the pitch conditioning for a fixed
        # fraction of the batch so the decoder also learns to work without it.
        current_batch_size = pitch_mvmt_encode.shape[0]

        # e.g. a batch of 32 with F0_DROPOUT_PROB = 0.2 drops 6 samples.
        num_to_drop = int(current_batch_size * F0_DROPOUT_PROB)

        # Build a mask with exactly num_to_drop True entries...
        base_mask_false = torch.zeros(current_batch_size - num_to_drop, dtype=torch.bool)
        base_mask_true = torch.ones(num_to_drop, dtype=torch.bool)
        base_mask = torch.cat([base_mask_false, base_mask_true], dim=0)

        # ...then shuffle it so different samples are dropped each step.
        shuffled_indices = torch.randperm(current_batch_size)
        dropout_mask = base_mask[shuffled_indices].to(DEVICE)

        # Broadcast the per-sample mask over the (80, T_mel) conditioning.
        expanded_mask = dropout_mask.view(-1, 1, 1)
        final_pitch_mvmt_encode = pitch_mvmt_encode.masked_fill(expanded_mask, 0.0)

        # Flow-matching loss on the ground-truth mel, conditioned on the token
        # encoding, speaker embedding, and (possibly dropped) pitch movement.
        loss, _ = model.decoder.compute_loss(
            x1=gt_mel, mask=mask, mu=mu,
            spks=speaker_embedding, cond=final_pitch_mvmt_encode
        )
        loss.backward()
        optimizer.step()
        scheduler.step()
|
|
        # Progress reporting.
        current_loss = loss.item()
        postfix_dict["loss"] = f"{current_loss:.4f}"
        postfix_dict["lr"] = f"{scheduler.get_last_lr()[0]:.1e}"
        training_progress.set_postfix(postfix_dict)
        training_progress.update(1)

        if (step + 1) % SAVE_EVERY_N_STEPS == 0:
            save_checkpoint(step, current_loss, model.decoder, model.pitchmvmt, optimizer, scheduler, CHECKPOINT_DIR)

        step += 1

    training_progress.close()
| |
    # Save on exit (including Ctrl+C) if at least one step was completed.
    if start_step < step:
        print("Performing final save...")

        final_loss = loss.item() if 'loss' in locals() else last_loss
        save_checkpoint(step - 1, final_loss, model.decoder, model.pitchmvmt, optimizer, scheduler, CHECKPOINT_DIR)

    print("Training finished.")
|
|
|
|
if __name__ == "__main__":
    main()
|
|
|
|