# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import pyarrow.parquet as pq
from io import BytesIO
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import pyworld as pw
import glob
import os
import json
import traceback
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
def individual_file_opener(data, mode='train', tts_data={}, token_latent_ratio=3):
"""Load data from individual files listed in files.txt
Args:
data: Iterable[{src}] where src is path to files.txt containing audio paths
mode: 'train' or 'test'
        tts_data: Dict mapping utt -> list of tts texts, used when mode != 'train'
        token_latent_ratio: number of latent frames per speech token; 0 disables
            token/latent length alignment
Yields:
Dict with all required fields for training
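    Example (hypothetical paths; the sibling-file naming matches the checks below):
        files.txt lists one wav path per line, e.g.
            /data/spk1_0001.wav
        and each wav is expected to have sibling files
            /data/spk1_0001.txt           (transcript)
            /data/spk1_0001_fsq.pt        (speech tokens)
            /data/spk1_0001_latent2x.pt   (latents, a dict with key 'z')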
"""
for sample in data:
assert 'src' in sample
src = sample['src']
# Load file list from files.txt
file_list = []
# Check if src is a files.txt file
if src.endswith('.txt'):
with open(src, 'r') as f:
                wav_files = [line.strip() for line in f if line.strip()]
            for wav_path in wav_files:
                # Skip comment lines (empty lines were already filtered above)
                if wav_path.startswith('#'):
                    continue
# Verify wav file exists
if not os.path.exists(wav_path):
logging.warning(f'Audio file not found: {wav_path}, skipping')
continue
# Check if all required files exist
txt_path = wav_path.replace('.wav', '.txt')
token_path = wav_path.replace('.wav', '_fsq.pt')
latent_path = wav_path.replace('.wav', '_latent2x.pt')
if not os.path.exists(txt_path):
logging.warning(f'Text file not found for {wav_path}, skipping')
continue
if not os.path.exists(token_path):
logging.warning(f'Token file not found for {wav_path}, skipping')
continue
if not os.path.exists(latent_path):
logging.warning(f'Latent file not found for {wav_path}, skipping')
continue
# Extract metadata
utt = os.path.basename(wav_path).replace('.wav', '')
# Try to extract speaker from filename (assuming format: spk_*.wav)
spk = utt.split('_')[0] if '_' in utt else 'default'
file_info = {
'utt': utt,
'spk': spk,
'wav': wav_path,
'text_path': txt_path,
'token_path': token_path,
'latent_path': latent_path,
}
                logging.debug(f'file_info {file_info}')
file_list.append(file_info)
elif src.endswith('.json'):
# Keep backward compatibility with JSON index files
with open(src, 'r') as f:
index_data = json.load(f)
file_list = index_data.get('data', [])
else:
# Assume it's a directory for backward compatibility
            wav_files = glob.glob(os.path.join(src, '*/*/*.wav'))
if not wav_files:
wav_files = glob.glob(os.path.join(src, '**/*.wav'), recursive=True)
for wav_path in wav_files:
txt_path = wav_path.replace('.wav', '.txt')
token_path = wav_path.replace('.wav', '_fsq.pt')
latent_path = wav_path.replace('.wav', '_latent2x.pt')
if not os.path.exists(txt_path):
logging.warning(f'Text file not found for {wav_path}, skipping')
continue
utt = os.path.basename(wav_path).replace('.wav', '')
                spk = utt.split('_')[0] if '_' in utt else 'default'
file_info = {
'utt': utt,
'spk': spk,
'wav': wav_path,
'text_path': txt_path,
'token_path': token_path,
'latent_path': latent_path,
}
file_list.append(file_info)
logging.info(f'Found {len(file_list)} valid audio files from {src}')
# Process each file
for file_info in file_list:
try:
# Read audio data
with open(file_info['wav'], 'rb') as f:
audio_data = f.read()
# Read text
with open(file_info['text_path'], 'r', encoding='utf-8') as f:
text = ''.join(l.strip() for l in f.readlines())
# Load speech token
speech_token = torch.load(file_info['token_path'], map_location='cpu', weights_only=False)
if isinstance(speech_token, torch.Tensor):
speech_token = speech_token.tolist()
# Load speech latent
speech_latent = torch.load(file_info['latent_path'], map_location='cpu', weights_only=False)
speech_latent = speech_latent['z'].transpose(0, 1)
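                # Worked example (assuming token_latent_ratio=3): a latent of 10 frames
                # with 4 tokens gives token_len = int(min(10 / 3, 4)) = 3, so the latent
                # is trimmed to 9 frames and the tokens to 3, i.e. 3 latent frames per token.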
                if token_latent_ratio != 0:
                    # trim so that speech_token and speech_latent stay aligned
                    logging.debug('before align speech_latent: %s', speech_latent.shape)
                    token_len = int(min(speech_latent.shape[0] / token_latent_ratio, len(speech_token)))
                    speech_latent = speech_latent[:token_latent_ratio * token_len]
                    speech_token = speech_token[:token_len]
                    logging.debug('after align speech_latent: %s', speech_latent.shape)
# Build sample dict
sample_dict = {
'utt': file_info['utt'],
'spk': file_info['spk'],
'audio_data': audio_data,
'text': text,
'text_token': [], # Will be filled by tokenize processor
'speech_token': speech_token,
'wav': file_info['wav'], # Keep original path for reference
'speech_latent': speech_latent,
}
# Copy over any additional fields from the original sample
for key, value in sample.items():
if key not in sample_dict:
sample_dict[key] = value
if mode == 'train':
yield sample_dict
else:
# For TTS mode
if file_info['utt'] in tts_data:
for index, tts_text in enumerate(tts_data[file_info['utt']]):
yield {**sample_dict, 'tts_index': index, 'tts_text': tts_text}
else:
yield sample_dict
except Exception as ex:
logging.warning(f'Failed to process {file_info["wav"]}: {ex}')
def parquet_opener(data, mode='train', tts_data={}):
""" Give url or local file, return file descriptor
Inplace operation.
Args:
data(Iterable[str]): url or local file list
Returns:
Iterable[{src, stream}]
"""
for sample in data:
assert 'src' in sample
url = sample['src']
try:
for df in pq.ParquetFile(url).iter_batches(batch_size=64):
df = df.to_pandas()
for i in range(len(df)):
sample.update(dict(df.loc[i]))
if mode == 'train':
# NOTE do not return sample directly, must initialize a new dict
yield {**sample}
else:
for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
yield {**sample, 'tts_index': index, 'tts_text': text}
except Exception as ex:
logging.warning('Failed to open {}, ex info {}'.format(url, ex))
def filter(data,
max_length=10240,
min_length=10,
token_max_length=200,
token_min_length=1,
min_output_input_ratio=0.0005,
max_output_input_ratio=1,
mode='train'):
""" Filter sample according to feature and label length
Inplace operation.
Args::
data: Iterable[{key, wav, label, sample_rate}]
max_length: drop utterance which is greater than max_length(10ms)
min_length: drop utterance which is less than min_length(10ms)
token_max_length: drop utterance which is greater than
token_max_length, especially when use char unit for
english modeling
token_min_length: drop utterance which is
less than token_max_length
min_output_input_ratio: minimal ration of
token_length / feats_length(10ms)
max_output_input_ratio: maximum ration of
token_length / feats_length(10ms)
Returns:
Iterable[{key, wav, label, sample_rate}]
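    Example:
        a 3 s clip at 16 kHz gives num_frames = 48000 / 16000 * 100 = 300,
        which passes the default min_length=10 / max_length=10240 bounds.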
"""
for sample in data:
sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
del sample['audio_data']
        # sample['speech'] is a torch.Tensor; at a 10 ms hop there are 100 frames per second
num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
if num_frames < min_length:
continue
if num_frames > max_length:
continue
if len(sample['text_token']) < token_min_length:
continue
if len(sample['text_token']) > token_max_length:
continue
if len(sample['speech_token']) == 0:
continue
if 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
continue
        if num_frames != 0:
            if len(sample['text_token']) / num_frames < min_output_input_ratio:
                logging.debug('filtered out: text/feat ratio below min_output_input_ratio')
                continue
            if len(sample['text_token']) / num_frames > max_output_input_ratio:
                logging.debug('filtered out: text/feat ratio above max_output_input_ratio')
                continue
yield sample
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
""" Resample data.
Inplace operation.
Args:
data: Iterable[{key, wav, label, sample_rate}]
resample_rate: target resample rate
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'speech' in sample
sample_rate = sample['sample_rate']
waveform = sample['speech']
if sample_rate != resample_rate:
            if sample_rate < min_sample_rate:
                logging.debug('filtered out: sample_rate below min_sample_rate')
                continue
sample['sample_rate'] = resample_rate
sample['speech'] = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
max_val = sample['speech'].abs().max()
if max_val > 1:
sample['speech'] /= max_val
yield sample
def truncate(data, truncate_length=24576, mode='train'):
""" Truncate data.
Args:
data: Iterable[{key, wav, label, sample_rate}]
truncate_length: truncate length
Returns:
Iterable[{key, wav, label, sample_rate}]
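    Note: waveforms longer than truncate_length get a random window; shorter
    ones are zero-padded on the right, so every yielded sample has exactly
    truncate_length samples.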
"""
for sample in data:
waveform = sample['speech']
if waveform.shape[1] > truncate_length:
start = random.randint(0, waveform.shape[1] - truncate_length)
waveform = waveform[:, start: start + truncate_length]
else:
waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
sample['speech'] = waveform
yield sample
def extract_reference_mel_from_speech(
data,
feat_extractor,
min_length=2.0,
max_length=6.0,
num_crops=2, # Multiple random crops from same utterance
training=True,
sample_rate=24000,
mode='train'
):
"""
Extract mel spectrograms from current speech waveform with random cropping.
This creates multiple random crops from the same utterance for training diversity.
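    For example, with the defaults (num_crops=2, min_length=2.0, max_length=6.0)
    a 10 s training utterance yields sample['reference_mels'] as a list of two
    [C, T_i] mels, each taken from a random 2-6 s window.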
"""
for sample in data:
# Use the current speech waveform
waveform = sample['speech'] # [1, T]
speech_length = waveform.shape[1]
# Convert time to samples
min_samples = int(min_length * sample_rate)
max_samples = int(max_length * sample_rate)
reference_mels = []
reference_mel_lengths = []
        # If the utterance is too short, yield it with empty references
        if speech_length < min_samples:
            logging.warning(f"Speech for {sample['utt']} is too short ({speech_length / sample_rate:.2f}s)")
            sample['reference_mels'] = []
            sample['reference_mel_lengths'] = []
            sample['num_references'] = 0
            yield sample
            continue
# Generate multiple crops from the same utterance
crops_to_generate = num_crops if training else 1
for i in range(crops_to_generate):
if training and speech_length > max_samples:
# Random crop during training
crop_length = random.randint(min_samples, min(max_samples, speech_length))
start_idx = random.randint(0, speech_length - crop_length)
audio_segment = waveform[:, start_idx:start_idx + crop_length]
elif speech_length > max_samples:
# Center crop during inference
start_idx = (speech_length - max_samples) // 2
audio_segment = waveform[:, start_idx:start_idx + max_samples]
else:
# Use full audio if shorter than max_length
audio_segment = waveform
# For training, if we need multiple crops but audio is short,
# we can add slight variations
if training and i > 0:
# Add very slight noise for variation
noise = torch.randn_like(audio_segment) * 0.001
audio_segment = audio_segment + noise
# Normalize audio segment
max_val = audio_segment.abs().max()
if max_val > 0:
audio_segment = audio_segment / max_val
# Extract mel spectrogram
mel = feat_extractor(audio_segment).squeeze(0) # Remove batch dim [C, T]
reference_mels.append(mel)
reference_mel_lengths.append(mel.shape[1])
sample['reference_mels'] = reference_mels
sample['reference_mel_lengths'] = reference_mel_lengths
sample['num_references'] = len(reference_mels)
yield sample
def compute_fbank(data,
feat_extractor,
token_mel_ratio=0,
mode='train'):
""" Extract fbank
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
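    Note: the extracted feature is squeezed and transposed to [T, num_mels]
    before being stored as sample['speech_mel'].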
"""
for sample in data:
assert 'sample_rate' in sample
assert 'speech' in sample
assert 'utt' in sample
assert 'text_token' in sample
waveform = sample['speech']
feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
        # NOTE: token_mel_ratio is currently unused here; token/latent length
        # alignment is done upstream in individual_file_opener.
        sample['speech_mel'] = feat
yield sample
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
""" Decode text to chars or BPE
Inplace operation
Args:
data: Iterable[{key, wav, txt, sample_rate}]
Returns:
Iterable[{key, wav, txt, tokens, label, sample_rate}]
"""
tokenizer = get_tokenizer()
for sample in data:
assert 'text' in sample
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
yield sample
def shuffle(data, shuffle_size=10000, mode='train'):
""" Local shuffle the data
Args:
data: Iterable[{key, feat, label}]
shuffle_size: buffer size for shuffle
Returns:
Iterable[{key, feat, label}]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= shuffle_size:
random.shuffle(buf)
for x in buf:
yield x
buf = []
# The sample left over
random.shuffle(buf)
for x in buf:
yield x
def sort(data, sort_size=500, mode='train'):
""" Sort the data by feature length.
Sort is used after shuffle and before batch, so we can group
utts with similar lengths into a batch, and `sort_size` should
be less than `shuffle_size`
Args:
data: Iterable[{key, feat, label}]
sort_size: buffer size for sort
Returns:
Iterable[{key, feat, label}]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= sort_size:
buf.sort(key=lambda x: x['speech_latent'].size(0))
for x in buf:
yield x
buf = []
# The sample left over
buf.sort(key=lambda x: x['speech_latent'].size(0))
for x in buf:
yield x
def static_batch(data, batch_size=16):
""" Static batch the data by `batch_size`
Args:
data: Iterable[{key, feat, label}]
batch_size: batch size
Returns:
Iterable[List[{key, feat, label}]]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= batch_size:
yield buf
buf = []
if len(buf) > 0:
yield buf
def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
""" Dynamic batch the data until the total frames in batch
reach `max_frames_in_batch`
Args:
data: Iterable[{key, feat, label}]
max_frames_in_batch: max_frames in one batch
Returns:
Iterable[List[{key, feat, label}]]
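    Example:
        with max_frames_in_batch=12000 and latent lengths 3000, 4000, 5000,
        adding the third sample would cost 5000 * 3 = 15000 padded frames,
        so [3000, 4000] is emitted as one batch and 5000 starts the next.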
"""
buf = []
longest_frames = 0
for sample in data:
assert 'speech_latent' in sample
assert isinstance(sample['speech_latent'], torch.Tensor)
new_sample_frames = sample['speech_latent'].size(0)
longest_frames = max(longest_frames, new_sample_frames)
frames_after_padding = longest_frames * (len(buf) + 1)
if frames_after_padding > max_frames_in_batch:
yield buf
buf = [sample]
longest_frames = new_sample_frames
else:
buf.append(sample)
if len(buf) > 0:
yield buf
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
""" Wrapper for static/dynamic batch
"""
if batch_type == 'static':
return static_batch(data, batch_size)
elif batch_type == 'dynamic':
return dynamic_batch(data, max_frames_in_batch)
    else:
        raise ValueError('Unsupported batch type {}'.format(batch_type))
def padding(data, mode='train', gan=False, dpo=False, use_speaker_encoder=False):
""" Padding the data into training data
Args:
data: Iterable[List[{key, feat, label}]]
use_speaker_encoder: Whether to prepare reference mels for speaker encoder
    Returns:
        Iterable[Dict]: padded batch dicts (see the sketch below)
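    A sketch of the yielded batch dict (B = batch size, per-field padded lengths):
        speech_token   [B, T_token]         speech_token_len   [B]
        speech_latent  [B, T_latent, D]     speech_latent_len  [B]
        speech_mel     [B, T_mel, C]        speech_mel_len     [B]
        text_token     [B, T_text]          text_token_len     [B]
        reference_mels [B, N_refs, C, T]    (only when use_speaker_encoder)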
"""
for sample in data:
assert isinstance(sample, list)
        speech_latent_len = torch.tensor([x['speech_latent'].size(0) for x in sample],
                                         dtype=torch.int32)
order = torch.argsort(speech_latent_len, descending=True)
utts = [sample[i]['utt'] for i in order]
speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
speech = pad_sequence(speech, batch_first=True, padding_value=0)
# Handle speech_token - check if it's already a tensor
speech_token = []
for i in order:
if isinstance(sample[i]['speech_token'], torch.Tensor):
speech_token.append(sample[i]['speech_token'])
else:
speech_token.append(torch.tensor(sample[i]['speech_token']))
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
speech_token = pad_sequence(speech_token,
batch_first=True,
padding_value=0)
speech_latent = [sample[i]['speech_latent'] for i in order]
speech_latent = pad_sequence(speech_latent,
batch_first=True,
padding_value=0)
speech_mel = [sample[i]['speech_mel'] for i in order]
speech_mel_len = torch.tensor([i.size(0) for i in speech_mel], dtype=torch.int32)
speech_mel = pad_sequence(speech_mel,
batch_first=True,
padding_value=0)
text = [sample[i]['text'] for i in order]
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
batch = {
"utts": utts,
"speech": speech,
"speech_len": speech_len,
"speech_token": speech_token,
"speech_token_len": speech_token_len,
"speech_mel": speech_mel,
"speech_mel_len": speech_mel_len,
"speech_latent": speech_latent,
"speech_latent_len": speech_latent_len,
"text": text,
"text_token": text_token,
"text_token_len": text_token_len,
}
# Handle reference mels for speaker encoder
if use_speaker_encoder:
# Collect all reference mels
all_reference_mels = []
all_reference_mel_lengths = []
all_num_references = []
for i in order:
ref_mels = sample[i].get('reference_mels', [])
ref_lengths = sample[i].get('reference_mel_lengths', [])
num_refs = sample[i].get('num_references', 0)
all_reference_mels.append(ref_mels)
all_reference_mel_lengths.append(ref_lengths)
all_num_references.append(num_refs)
# Determine max number of references in batch
max_num_refs = max(all_num_references) if all_num_references else 0
if max_num_refs > 0:
# Find dimensions
batch_size = len(order)
max_mel_length = 0
mel_dim = 80 # default
# Find max mel length and mel dimension
for ref_mels in all_reference_mels:
for mel in ref_mels:
if isinstance(mel, torch.Tensor) and mel.numel() > 0:
max_mel_length = max(max_mel_length, mel.shape[1])
mel_dim = mel.shape[0]
if max_mel_length > 0:
# Create padded tensor [B, N, C, T]
padded_reference_mels = torch.zeros(batch_size, max_num_refs, mel_dim, max_mel_length)
padded_reference_mel_lengths = torch.zeros(batch_size, max_num_refs, dtype=torch.int32)
reference_mel_masks = torch.zeros(batch_size, max_num_refs, max_mel_length)
for b_idx, (ref_mels, ref_lengths) in enumerate(zip(all_reference_mels, all_reference_mel_lengths)):
for r_idx in range(min(len(ref_mels), max_num_refs)):
if r_idx < len(ref_mels) and isinstance(ref_mels[r_idx], torch.Tensor):
mel = ref_mels[r_idx]
length = ref_lengths[r_idx] if r_idx < len(ref_lengths) else mel.shape[1]
actual_length = min(length, mel.shape[1], max_mel_length)
padded_reference_mels[b_idx, r_idx, :, :actual_length] = mel[:, :actual_length]
padded_reference_mel_lengths[b_idx, r_idx] = actual_length
reference_mel_masks[b_idx, r_idx, :actual_length] = 1.0
batch['reference_mels'] = padded_reference_mels
batch['reference_mel_lengths'] = padded_reference_mel_lengths
batch['reference_mel_masks'] = reference_mel_masks
if gan is True:
# in gan train, we need pitch_feat
pitch_feat = [sample[i]['pitch_feat'] for i in order]
pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
pitch_feat = pad_sequence(pitch_feat,
batch_first=True,
padding_value=0)
batch["pitch_feat"] = pitch_feat
batch["pitch_feat_len"] = pitch_feat_len
else:
# only gan train needs speech, delete it to save memory
del batch["speech"]
del batch["speech_len"]
if dpo is True:
reject_speech_token = []
for i in order:
if isinstance(sample[i]['reject_speech_token'], torch.Tensor):
reject_speech_token.append(sample[i]['reject_speech_token'])
else:
reject_speech_token.append(torch.tensor(sample[i]['reject_speech_token']))
reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
reject_speech_token = pad_sequence(reject_speech_token,
batch_first=True,
padding_value=0)
batch['reject_speech_token'] = reject_speech_token
batch['reject_speech_token_len'] = reject_speech_token_len
yield batch