# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import pyarrow.parquet as pq
from io import BytesIO
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import pyworld as pw
import glob
import os
import json
import traceback

AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}

def individual_file_opener(data, mode='train', tts_data={}, token_latent_ratio=3):
    """ Load data from individual files listed in files.txt

        Args:
            data: Iterable[{src}] where src is the path to a files.txt listing audio paths
            mode: 'train' or 'test'
            tts_data: Dict of tts texts keyed by utt, used in non-train mode
            token_latent_ratio: number of latent frames per speech token; 0 disables trimming

        Yields:
            Dict with all required fields for training
    """
    for sample in data:
        assert 'src' in sample
        src = sample['src']
        file_list = []
        if src.endswith('.txt'):
            # Load the file list from files.txt, one wav path per line
            with open(src, 'r') as f:
                wav_files = [line.strip() for line in f if line.strip()]
            for wav_path in wav_files:
                # Skip empty lines and comments
                if not wav_path or wav_path.startswith('#'):
                    continue
                # Verify the wav file exists
                if not os.path.exists(wav_path):
                    logging.warning(f'Audio file not found: {wav_path}, skipping')
                    continue
                # Check that all companion files exist
                txt_path = wav_path.replace('.wav', '.txt')
                token_path = wav_path.replace('.wav', '_fsq.pt')
                latent_path = wav_path.replace('.wav', '_latent2x.pt')
                if not os.path.exists(txt_path):
                    logging.warning(f'Text file not found for {wav_path}, skipping')
                    continue
                if not os.path.exists(token_path):
                    logging.warning(f'Token file not found for {wav_path}, skipping')
                    continue
                if not os.path.exists(latent_path):
                    logging.warning(f'Latent file not found for {wav_path}, skipping')
                    continue
                # Extract metadata; the speaker is guessed from the filename
                # (assuming the format spk_*.wav)
                utt = os.path.basename(wav_path).replace('.wav', '')
                spk = utt.split('_')[0] if '_' in utt else 'default'
                file_info = {
                    'utt': utt,
                    'spk': spk,
                    'wav': wav_path,
                    'text_path': txt_path,
                    'token_path': token_path,
                    'latent_path': latent_path,
                }
                logging.debug(f'file_info {file_info}')
                file_list.append(file_info)
        elif src.endswith('.json'):
            # Keep backward compatibility with JSON index files
            with open(src, 'r') as f:
                index_data = json.load(f)
            file_list = index_data.get('data', [])
        else:
            # Assume src is a directory, for backward compatibility
            wav_files = glob.glob(os.path.join(src, '*/*/*wav'))
            if not wav_files:
                wav_files = glob.glob(os.path.join(src, '**/*.wav'), recursive=True)
            for wav_path in wav_files:
                txt_path = wav_path.replace('.wav', '.txt')
                token_path = wav_path.replace('.wav', '_fsq.pt')
                latent_path = wav_path.replace('.wav', '_latent2x.pt')
                if not os.path.exists(txt_path):
                    logging.warning(f'Text file not found for {wav_path}, skipping')
                    continue
                utt = os.path.basename(wav_path).replace('.wav', '')
                spk = utt.split('_')[0]
                file_info = {
                    'utt': utt,
                    'spk': spk,
                    'wav': wav_path,
                    'text_path': txt_path,
                    'token_path': token_path,
                    'latent_path': latent_path,
                }
                file_list.append(file_info)
        logging.info(f'Found {len(file_list)} valid audio files from {src}')
        # Process each file
        for file_info in file_list:
            try:
                # Read raw audio bytes
                with open(file_info['wav'], 'rb') as f:
                    audio_data = f.read()
                # Read the transcript
                with open(file_info['text_path'], 'r', encoding='utf-8') as f:
                    text = ''.join(l.strip() for l in f.readlines())
                # Load the speech token sequence
                speech_token = torch.load(file_info['token_path'], map_location='cpu', weights_only=False)
                if isinstance(speech_token, torch.Tensor):
                    speech_token = speech_token.tolist()
                # Load the speech latent
                speech_latent = torch.load(file_info['latent_path'], map_location='cpu', weights_only=False)
                speech_latent = speech_latent['z'].transpose(0, 1)
                if token_latent_ratio != 0:
                    # Trim so that speech_token and speech_latent stay aligned
                    logging.debug(f'before align speech_latent: {speech_latent.shape}')
                    token_len = int(min(speech_latent.shape[0] / token_latent_ratio, len(speech_token)))
                    speech_latent = speech_latent[:token_latent_ratio * token_len]
                    speech_token = speech_token[:token_len]
                    logging.debug(f'after align speech_latent: {speech_latent.shape}')
                # Build the sample dict
                sample_dict = {
                    'utt': file_info['utt'],
                    'spk': file_info['spk'],
                    'audio_data': audio_data,
                    'text': text,
                    'text_token': [],  # Will be filled by the tokenize processor
                    'speech_token': speech_token,
                    'wav': file_info['wav'],  # Keep the original path for reference
                    'speech_latent': speech_latent,
                }
                # Copy over any additional fields from the original sample
                for key, value in sample.items():
                    if key not in sample_dict:
                        sample_dict[key] = value
                if mode == 'train':
                    yield sample_dict
                else:
                    # TTS mode: fan out one sample per tts text
                    if file_info['utt'] in tts_data:
                        for index, tts_text in enumerate(tts_data[file_info['utt']]):
                            yield {**sample_dict, 'tts_index': index, 'tts_text': tts_text}
                    else:
                        yield sample_dict
            except Exception as ex:
                logging.warning(f'Failed to process {file_info["wav"]}: {ex}')
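

# The opener above expects each WAV to sit beside three companion files that
# share its stem: '<stem>.txt' (transcript), '<stem>_fsq.pt' (speech tokens)
# and '<stem>_latent2x.pt' (a dict with latents under key 'z'). A minimal
# sketch of a manifest builder for the files.txt mode follows; the helper
# name and directory layout are assumptions for illustration, not part of
# the training code.
def _write_files_txt_example(wav_dir, out_path='files.txt'):
    """Write one WAV path per line, keeping only fully prepared utterances."""
    with open(out_path, 'w') as f:
        for wav_path in sorted(glob.glob(os.path.join(wav_dir, '**/*.wav'), recursive=True)):
            stem = wav_path[:-len('.wav')]
            companions = (stem + '.txt', stem + '_fsq.pt', stem + '_latent2x.pt')
            if all(os.path.exists(p) for p in companions):
                f.write(wav_path + '\n')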

def parquet_opener(data, mode='train', tts_data={}):
    """ Given a url or local parquet file, yield one sample dict per row

        Inplace operation.

        Args:
            data(Iterable[{src}]): url or local parquet file list

        Returns:
            Iterable[{src, ...parquet row fields}]
    """
    for sample in data:
        assert 'src' in sample
        url = sample['src']
        try:
            for df in pq.ParquetFile(url).iter_batches(batch_size=64):
                df = df.to_pandas()
                for i in range(len(df)):
                    sample.update(dict(df.loc[i]))
                    if mode == 'train':
                        # NOTE do not return sample directly, must initialize a new dict
                        yield {**sample}
                    else:
                        for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
                            yield {**sample, 'tts_index': index, 'tts_text': text}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(url, ex))
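

# Note: each parquet row is merged into the sample dict as-is, so the columns
# are expected to carry the same fields the individual-file opener produces
# (at least 'utt', 'audio_data', 'text' and 'speech_token'; the exact schema
# is an assumption here and depends on how the parquet files were written).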

def filter(data,
           max_length=10240,
           min_length=10,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1,
           mode='train'):
    """ Filter samples according to feature and label length

        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterances longer than max_length (10ms frames)
            min_length: drop utterances shorter than min_length (10ms frames)
            token_max_length: drop utterances with more tokens than
                token_max_length, especially when using char units for
                English modeling
            token_min_length: drop utterances with fewer tokens than
                token_min_length
            min_output_input_ratio: minimal ratio of
                token_length / feats_length (10ms frames)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length (10ms frames)

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
        del sample['audio_data']
        # sample['speech'] is a torch.Tensor; there are 100 frames per second
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if len(sample['text_token']) < token_min_length:
            continue
        if len(sample['text_token']) > token_max_length:
            continue
        if len(sample['speech_token']) == 0:
            continue
        if 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
            continue
        if num_frames != 0:
            if len(sample['text_token']) / num_frames < min_output_input_ratio:
                logging.debug('dropped sample: token/frame ratio below min_output_input_ratio')
                continue
            if len(sample['text_token']) / num_frames > max_output_input_ratio:
                logging.debug('dropped sample: token/frame ratio above max_output_input_ratio')
                continue
        yield sample
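

# Worked example of the 10 ms frame arithmetic used by filter() above (the
# clip duration and sample rate are illustrative): a 3 s clip at 24 kHz has
# 72000 samples, so num_frames = 72000 / 24000 * 100 = 300, which passes the
# default bounds min_length=10 (0.1 s) and max_length=10240 (~102 s); with
# 20 text tokens the ratio 20 / 300 ≈ 0.067 also falls inside
# [min_output_input_ratio, max_output_input_ratio].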

def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
    """ Resample data.

        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            resample_rate: target resample rate
            min_sample_rate: drop samples whose original rate is below this

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['speech']
        if sample_rate != resample_rate:
            if sample_rate < min_sample_rate:
                logging.debug(f'dropped sample: sample_rate {sample_rate} below {min_sample_rate}')
                continue
            sample['sample_rate'] = resample_rate
            sample['speech'] = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=resample_rate)(waveform)
        max_val = sample['speech'].abs().max()
        if max_val > 1:
            sample['speech'] /= max_val
        yield sample

def truncate(data, truncate_length=24576, mode='train'):
    """ Truncate data.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            truncate_length: truncate length

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        waveform = sample['speech']
        if waveform.shape[1] > truncate_length:
            # Random crop down to truncate_length
            start = random.randint(0, waveform.shape[1] - truncate_length)
            waveform = waveform[:, start: start + truncate_length]
        else:
            # Right-pad with zeros up to truncate_length
            waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
        sample['speech'] = waveform
        yield sample

def extract_reference_mel_from_speech(
    data,
    feat_extractor,
    min_length=2.0,
    max_length=6.0,
    num_crops=2,  # Multiple random crops from the same utterance
    training=True,
    sample_rate=24000,
    mode='train'
):
    """
    Extract mel spectrograms from the current speech waveform with random cropping.
    This creates multiple random crops from the same utterance for training diversity.
    """
    for sample in data:
        # Use the current speech waveform
        waveform = sample['speech']  # [1, T]
        speech_length = waveform.shape[1]
        # Convert time to samples
        min_samples = int(min_length * sample_rate)
        max_samples = int(max_length * sample_rate)
        reference_mels = []
        reference_mel_lengths = []
        # Skip if the utterance is too short
        if speech_length < min_samples:
            logging.warning(f"Speech for {sample['utt']} is too short ({speech_length / sample_rate:.2f}s)")
            sample['reference_mels'] = []
            sample['reference_mel_lengths'] = []
            sample['num_references'] = 0
            yield sample
            continue
        # Generate multiple crops from the same utterance
        crops_to_generate = num_crops if training else 1
        for i in range(crops_to_generate):
            if training and speech_length > max_samples:
                # Random crop during training
                crop_length = random.randint(min_samples, min(max_samples, speech_length))
                start_idx = random.randint(0, speech_length - crop_length)
                audio_segment = waveform[:, start_idx:start_idx + crop_length]
            elif speech_length > max_samples:
                # Center crop during inference
                start_idx = (speech_length - max_samples) // 2
                audio_segment = waveform[:, start_idx:start_idx + max_samples]
            else:
                # Use the full audio if it is shorter than max_length
                audio_segment = waveform
                # For training, if we need multiple crops but the audio is
                # short, add slight variations
                if training and i > 0:
                    # Add very slight noise for variation
                    noise = torch.randn_like(audio_segment) * 0.001
                    audio_segment = audio_segment + noise
            # Normalize the audio segment
            max_val = audio_segment.abs().max()
            if max_val > 0:
                audio_segment = audio_segment / max_val
            # Extract the mel spectrogram
            mel = feat_extractor(audio_segment).squeeze(0)  # Remove batch dim -> [C, T]
            reference_mels.append(mel)
            reference_mel_lengths.append(mel.shape[1])
        sample['reference_mels'] = reference_mels
        sample['reference_mel_lengths'] = reference_mel_lengths
        sample['num_references'] = len(reference_mels)
        yield sample
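

# Note: this processor must run before padding(..., use_speaker_encoder=True)
# so every sample carries 'reference_mels', 'reference_mel_lengths' and
# 'num_references'. With the defaults (num_crops=2, 2-6 s at 24 kHz) each
# training sample yields two mel crops computed from 48000-144000 samples
# of audio.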

def compute_fbank(data,
                  feat_extractor,
                  token_mel_ratio=0,
                  mode='train'):
    """ Extract fbank

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        waveform = sample['speech']
        feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
        # if token_mel_ratio != 0:
        #     # trim to align speech_token and speech_feat
        #     token_len = int(min(feat.shape[0] / token_mel_ratio, sample["speech_token"].shape[0]))
        #     feat = feat[:token_mel_ratio * token_len]
        #     sample["speech_token"] = sample["speech_token"][:token_len]
        sample['speech_mel'] = feat
        yield sample

def tokenize(data, get_tokenizer, allowed_special, mode='train'):
    """ Decode text to chars or BPE

        Inplace operation

        Args:
            data: Iterable[{key, wav, txt, sample_rate}]

        Returns:
            Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    tokenizer = get_tokenizer()
    for sample in data:
        assert 'text' in sample
        sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
        yield sample

def shuffle(data, shuffle_size=10000, mode='train'):
    """ Local shuffle the data

        Args:
            data: Iterable[{key, feat, label}]
            shuffle_size: buffer size for shuffle

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            for x in buf:
                yield x
            buf = []
    # The samples left over
    random.shuffle(buf)
    for x in buf:
        yield x

def sort(data, sort_size=500, mode='train'):
    """ Sort the data by feature length.
        Sort is used after shuffle and before batch, so we can group
        utts with similar lengths into a batch, and `sort_size` should
        be less than `shuffle_size`

        Args:
            data: Iterable[{key, feat, label}]
            sort_size: buffer size for sort

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= sort_size:
            buf.sort(key=lambda x: x['speech_latent'].size(0))
            for x in buf:
                yield x
            buf = []
    # The samples left over
    buf.sort(key=lambda x: x['speech_latent'].size(0))
    for x in buf:
        yield x

def static_batch(data, batch_size=16):
    """ Static batch the data by `batch_size`

        Args:
            data: Iterable[{key, feat, label}]
            batch_size: batch size

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if len(buf) > 0:
        yield buf

def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
    """ Dynamic batch the data until the total frames in batch
        reach `max_frames_in_batch`

        Args:
            data: Iterable[{key, feat, label}]
            max_frames_in_batch: max_frames in one batch

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert 'speech_latent' in sample
        assert isinstance(sample['speech_latent'], torch.Tensor)
        new_sample_frames = sample['speech_latent'].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
        if frames_after_padding > max_frames_in_batch:
            # Guard against yielding an empty batch when a single sample
            # already exceeds max_frames_in_batch
            if len(buf) > 0:
                yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf
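

# Worked example of the padding-aware budget above (frame counts are
# illustrative): with max_frames_in_batch=12000 and successive samples of
# 5000, 5500 and 6100 latent frames, the third sample would make
# frames_after_padding = 6100 * 3 = 18300 > 12000, so the first two samples
# are emitted as one batch and the 6100-frame sample starts the next buffer.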

def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
    """ Wrapper for static/dynamic batch
    """
    if batch_type == 'static':
        return static_batch(data, batch_size)
    elif batch_type == 'dynamic':
        return dynamic_batch(data, max_frames_in_batch)
    else:
        # logging.fatal does not stop execution, so raise explicitly
        logging.fatal('Unsupported batch type {}'.format(batch_type))
        raise ValueError('Unsupported batch type {}'.format(batch_type))

def padding(data, mode='train', gan=False, dpo=False, use_speaker_encoder=False):
    """ Padding the data into training data

        Args:
            data: Iterable[List[{key, feat, label}]]
            use_speaker_encoder: whether to prepare reference mels for the speaker encoder

        Returns:
            Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
    """
    for sample in data:
        assert isinstance(sample, list)
        speech_latent_len = torch.tensor([x['speech_latent'].size(0) for x in sample],
                                         dtype=torch.int32)
        order = torch.argsort(speech_latent_len, descending=True)
        utts = [sample[i]['utt'] for i in order]
        speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
        speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
        speech = pad_sequence(speech, batch_first=True, padding_value=0)
        # Handle speech_token, which may or may not already be a tensor
        speech_token = []
        for i in order:
            if isinstance(sample[i]['speech_token'], torch.Tensor):
                speech_token.append(sample[i]['speech_token'])
            else:
                speech_token.append(torch.tensor(sample[i]['speech_token']))
        speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
        speech_token = pad_sequence(speech_token,
                                    batch_first=True,
                                    padding_value=0)
        speech_latent = [sample[i]['speech_latent'] for i in order]
        speech_latent = pad_sequence(speech_latent,
                                     batch_first=True,
                                     padding_value=0)
        speech_mel = [sample[i]['speech_mel'] for i in order]
        speech_mel_len = torch.tensor([i.size(0) for i in speech_mel], dtype=torch.int32)
        speech_mel = pad_sequence(speech_mel,
                                  batch_first=True,
                                  padding_value=0)
        text = [sample[i]['text'] for i in order]
        text_token = [torch.tensor(sample[i]['text_token']) for i in order]
        text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
        text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
        batch = {
            "utts": utts,
            "speech": speech,
            "speech_len": speech_len,
            "speech_token": speech_token,
            "speech_token_len": speech_token_len,
            "speech_mel": speech_mel,
            "speech_mel_len": speech_mel_len,
            "speech_latent": speech_latent,
            "speech_latent_len": speech_latent_len,
            "text": text,
            "text_token": text_token,
            "text_token_len": text_token_len,
        }
        # Handle reference mels for the speaker encoder
        if use_speaker_encoder:
            # Collect all reference mels
            all_reference_mels = []
            all_reference_mel_lengths = []
            all_num_references = []
            for i in order:
                ref_mels = sample[i].get('reference_mels', [])
                ref_lengths = sample[i].get('reference_mel_lengths', [])
                num_refs = sample[i].get('num_references', 0)
                all_reference_mels.append(ref_mels)
                all_reference_mel_lengths.append(ref_lengths)
                all_num_references.append(num_refs)
            # Determine the max number of references in the batch
            max_num_refs = max(all_num_references) if all_num_references else 0
            if max_num_refs > 0:
                # Find dimensions
                batch_size = len(order)
                max_mel_length = 0
                mel_dim = 80  # default
                # Find the max mel length and the mel dimension
                for ref_mels in all_reference_mels:
                    for mel in ref_mels:
                        if isinstance(mel, torch.Tensor) and mel.numel() > 0:
                            max_mel_length = max(max_mel_length, mel.shape[1])
                            mel_dim = mel.shape[0]
                if max_mel_length > 0:
                    # Create padded tensors [B, N, C, T]
                    padded_reference_mels = torch.zeros(batch_size, max_num_refs, mel_dim, max_mel_length)
                    padded_reference_mel_lengths = torch.zeros(batch_size, max_num_refs, dtype=torch.int32)
                    reference_mel_masks = torch.zeros(batch_size, max_num_refs, max_mel_length)
                    for b_idx, (ref_mels, ref_lengths) in enumerate(zip(all_reference_mels, all_reference_mel_lengths)):
                        for r_idx in range(min(len(ref_mels), max_num_refs)):
                            if isinstance(ref_mels[r_idx], torch.Tensor):
                                mel = ref_mels[r_idx]
                                length = ref_lengths[r_idx] if r_idx < len(ref_lengths) else mel.shape[1]
                                actual_length = min(length, mel.shape[1], max_mel_length)
                                padded_reference_mels[b_idx, r_idx, :, :actual_length] = mel[:, :actual_length]
                                padded_reference_mel_lengths[b_idx, r_idx] = actual_length
                                reference_mel_masks[b_idx, r_idx, :actual_length] = 1.0
                    batch['reference_mels'] = padded_reference_mels
                    batch['reference_mel_lengths'] = padded_reference_mel_lengths
                    batch['reference_mel_masks'] = reference_mel_masks
        if gan is True:
            # In GAN training, we need pitch_feat
            pitch_feat = [sample[i]['pitch_feat'] for i in order]
            pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
            pitch_feat = pad_sequence(pitch_feat,
                                      batch_first=True,
                                      padding_value=0)
            batch["pitch_feat"] = pitch_feat
            batch["pitch_feat_len"] = pitch_feat_len
        else:
            # Only GAN training needs speech; delete it to save memory
            del batch["speech"]
            del batch["speech_len"]
        if dpo is True:
            reject_speech_token = []
            for i in order:
                if isinstance(sample[i]['reject_speech_token'], torch.Tensor):
                    reject_speech_token.append(sample[i]['reject_speech_token'])
                else:
                    reject_speech_token.append(torch.tensor(sample[i]['reject_speech_token']))
            reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
            reject_speech_token = pad_sequence(reject_speech_token,
                                               batch_first=True,
                                               padding_value=0)
            batch['reject_speech_token'] = reject_speech_token
            batch['reject_speech_token_len'] = reject_speech_token_len
        yield batch
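

# A minimal sketch of how these processors chain into a training iterator.
# The manifest path, tokenizer factory, feature extractor and the
# dynamic-batch budget are assumptions supplied by the dataset config in a
# real setup; this is illustrative, not the project's actual entry point.
def _example_pipeline(get_tokenizer, feat_extractor, src='files.txt'):
    data = individual_file_opener([{'src': src}])
    data = tokenize(data, get_tokenizer, allowed_special='all')
    data = filter(data)  # decodes audio_data into speech and drops outliers
    data = resample(data, resample_rate=24000)
    data = compute_fbank(data, feat_extractor)
    data = shuffle(data)
    data = sort(data)
    data = batch(data, batch_type='dynamic', max_frames_in_batch=12000)
    return padding(data)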