Spaces:
Sleeping
Sleeping
primepake
commited on
Commit
·
eb584bd
1
Parent(s):
631dfe2
inv code
Browse files- speech/tools/create_data_list.py +37 -0
- speech/tools/extract_embedding.py +78 -43
- speech/tools/extract_speech_token.py +61 -34
- speech/tools/inv_file_processor.py +109 -0
speech/tools/create_data_list.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Create data list files for training with individual files"""
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
|
| 8 |
+
def create_data_lists(src_dir, output_dir):
    """Write the training list files that point at processed audio data.

    Args:
        src_dir: Directory containing processed audio files.
        output_dir: Directory to save list files.
    """
    os.makedirs(output_dir, exist_ok=True)

    # data.list always contains the source directory itself.
    data_list_path = os.path.join(output_dir, 'data.list')
    with open(data_list_path, 'w') as fh:
        fh.write(src_dir + '\n')

    # data_index.list is written only when an index JSON is present
    # alongside the processed files.
    index_file = os.path.join(src_dir, 'data_index.json')
    if os.path.exists(index_file):
        index_list_path = os.path.join(output_dir, 'data_index.list')
        with open(index_list_path, 'w') as fh:
            fh.write(index_file + '\n')

    print(f"Created data lists in {output_dir}")
| 28 |
+
|
| 29 |
+
if __name__ == "__main__":
    # CLI entry point: build the list files for a processed dataset.
    cli = argparse.ArgumentParser()
    cli.add_argument('--src_dir', type=str, required=True,
                     help='Source directory with processed files')
    cli.add_argument('--output_dir', type=str, required=True,
                     help='Output directory for list files')
    cli_args = cli.parse_args()

    create_data_lists(cli_args.src_dir, cli_args.output_dir)
|
speech/tools/extract_embedding.py
CHANGED
|
@@ -1,17 +1,4 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
-
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
import argparse
|
| 16 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 17 |
import onnxruntime
|
|
@@ -19,54 +6,102 @@ import torch
|
|
| 19 |
import torchaudio
|
| 20 |
import torchaudio.compliance.kaldi as kaldi
|
| 21 |
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
| 22 |
|
|
|
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
if sample_rate != 16000:
|
| 27 |
audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
|
|
|
|
| 28 |
feat = kaldi.fbank(audio,
|
| 29 |
num_mel_bins=80,
|
| 30 |
dither=0,
|
| 31 |
sample_frequency=16000)
|
| 32 |
feat = feat - feat.mean(dim=0, keepdim=True)
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
|
| 37 |
def main(args):
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
|
| 53 |
if __name__ == "__main__":
|
| 54 |
parser = argparse.ArgumentParser()
|
| 55 |
-
parser.add_argument("--
|
| 56 |
-
parser.add_argument("--onnx_path", type=str)
|
| 57 |
parser.add_argument("--num_thread", type=int, default=8)
|
| 58 |
args = parser.parse_args()
|
| 59 |
|
| 60 |
-
utt2wav, utt2spk = {}, {}
|
| 61 |
-
with open('{}/wav.scp'.format(args.dir)) as f:
|
| 62 |
-
for l in f:
|
| 63 |
-
l = l.replace('\n', '').split()
|
| 64 |
-
utt2wav[l[0]] = l[1]
|
| 65 |
-
with open('{}/utt2spk'.format(args.dir)) as f:
|
| 66 |
-
for l in f:
|
| 67 |
-
l = l.replace('\n', '').split()
|
| 68 |
-
utt2spk[l[0]] = l[1]
|
| 69 |
-
|
| 70 |
option = onnxruntime.SessionOptions()
|
| 71 |
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 72 |
option.intra_op_num_threads = 1
|
|
@@ -74,4 +109,4 @@ if __name__ == "__main__":
|
|
| 74 |
ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
|
| 75 |
executor = ThreadPoolExecutor(max_workers=args.num_thread)
|
| 76 |
|
| 77 |
-
main(args)
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import argparse
|
| 3 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 4 |
import onnxruntime
|
|
|
|
| 6 |
import torchaudio
|
| 7 |
import torchaudio.compliance.kaldi as kaldi
|
| 8 |
from tqdm import tqdm
|
| 9 |
+
import os
|
| 10 |
+
import glob
|
| 11 |
+
import logging
|
| 12 |
|
| 13 |
+
logger = logging.getLogger()
|
| 14 |
|
| 15 |
+
|
| 16 |
+
def process_single_audio(wav_path):
    """Extract and save a speaker embedding for one wav file.

    Runs an 80-bin fbank front-end and the module-level `ort_session`
    (campplus ONNX model) on a 16 kHz mono signal, then saves the embedding
    next to the wav as `<stem>_embedding.pt`.

    Args:
        wav_path: Path to a .wav file. A sibling `<stem>.normalized.txt`
            transcript must exist or the file is skipped.

    Returns:
        Dict with `wav_path`, `utt`, `spk`, the raw `embedding` array and
        `embedding_path`, or None when the transcript is missing.
    """
    # Extract utterance ID and speaker ID from filename.
    # NOTE(review): assumes the speaker ID is the stem prefix before the
    # first underscore — confirm this naming convention for the dataset.
    utt = os.path.basename(wav_path).replace('.wav', '')
    spk = utt.split('_')[0]

    # Check if text file exists; wavs without transcripts are not usable
    # for training, so they are skipped with a warning instead of failing.
    txt_path = wav_path.replace('.wav', '.normalized.txt')
    if not os.path.exists(txt_path):
        logger.warning(f'{txt_path} does not exist, skipping {wav_path}')
        return None

    # Process audio: resample to the 16 kHz rate the feature extractor expects.
    audio, sample_rate = torchaudio.load(wav_path)
    if sample_rate != 16000:
        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)

    feat = kaldi.fbank(audio,
                       num_mel_bins=80,
                       dither=0,
                       sample_frequency=16000)
    # Cepstral mean normalization over the time dimension.
    feat = feat - feat.mean(dim=0, keepdim=True)

    # Generate embedding with the shared ONNX session; input is a
    # (1, frames, 80) feature batch, output flattened to a 1-D vector.
    embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten()

    # Save individual embedding file next to the source wav.
    embedding_path = wav_path.replace('.wav', '_embedding.pt')
    torch.save(embedding, embedding_path)

    return {
        'wav_path': wav_path,
        'utt': utt,
        'spk': spk,
        'embedding': embedding,
        'embedding_path': embedding_path
    }
|
| 52 |
|
| 53 |
|
| 54 |
def main(args):
    """Extract embeddings for every wav under ``args.src_dir``.

    Submits each file to the module-level ``executor`` thread pool, then
    mean-pools the per-utterance embeddings of each speaker into
    ``<src_dir>/spk_embeddings/<spk>_embedding.pt`` and writes a summary
    text file listing everything that was processed.
    """
    # Find all wav files two directory levels below the source root.
    wav_files = list(glob.glob('{}/*/*/*wav'.format(args.src_dir)))
    print(f"Found {len(wav_files)} wav files")

    # Fan the per-file work out over the shared thread pool.
    futures = [executor.submit(process_single_audio, path) for path in wav_files]

    spk2embeddings = {}
    successful_files = []
    for fut in tqdm(as_completed(futures), total=len(futures)):
        info = fut.result()
        if info is None:
            continue
        successful_files.append(info)
        # Group utterance embeddings under their speaker for pooling below.
        spk2embeddings.setdefault(info['spk'], []).append(info['embedding'])

    # Mean-pool each speaker's utterance embeddings and persist them.
    spk_embed_dir = os.path.join(args.src_dir, "spk_embeddings")
    os.makedirs(spk_embed_dir, exist_ok=True)
    for spk, embeds in spk2embeddings.items():
        mean_embed = torch.stack([torch.tensor(e) for e in embeds]).mean(dim=0)
        torch.save(mean_embed, os.path.join(spk_embed_dir, f"{spk}_embedding.pt"))
        print(f"Saved speaker embedding for {spk} with {len(embeds)} utterances")

    # Human-readable index of everything that was processed successfully.
    summary_path = os.path.join(args.src_dir, "embedding_summary.txt")
    with open(summary_path, 'w') as f:
        f.write(f"Processed {len(successful_files)} files successfully\n")
        f.write(f"Found {len(spk2embeddings)} speakers\n")
        for info in successful_files:
            f.write(f"{info['utt']} {info['wav_path']} {info['embedding_path']}\n")
|
| 96 |
|
| 97 |
|
| 98 |
if __name__ == "__main__":
|
| 99 |
parser = argparse.ArgumentParser()
|
| 100 |
+
parser.add_argument("--src_dir", type=str, help="Source directory containing audio files")
|
| 101 |
+
parser.add_argument("--onnx_path", type=str, help="Path to campplus.onnx model")
|
| 102 |
parser.add_argument("--num_thread", type=int, default=8)
|
| 103 |
args = parser.parse_args()
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
option = onnxruntime.SessionOptions()
|
| 106 |
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 107 |
option.intra_op_num_threads = 1
|
|
|
|
| 109 |
ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
|
| 110 |
executor = ThreadPoolExecutor(max_workers=args.num_thread)
|
| 111 |
|
| 112 |
+
main(args)
|
speech/tools/extract_speech_token.py
CHANGED
|
@@ -1,17 +1,4 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
-
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
import argparse
|
| 16 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 17 |
import logging
|
|
@@ -21,47 +8,87 @@ import onnxruntime
|
|
| 21 |
import numpy as np
|
| 22 |
import torchaudio
|
| 23 |
import whisper
|
|
|
|
|
|
|
| 24 |
|
|
|
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
if sample_rate != 16000:
|
| 29 |
audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
|
|
|
|
| 30 |
# Convert audio to mono
|
| 31 |
if audio.shape[0] > 1:
|
| 32 |
audio = audio.mean(dim=0, keepdim=True)
|
|
|
|
| 33 |
if audio.shape[1] / 16000 > 30:
|
| 34 |
-
logging.warning('
|
| 35 |
speech_token = []
|
| 36 |
else:
|
| 37 |
feat = whisper.log_mel_spectrogram(audio, n_mels=128)
|
| 38 |
-
speech_token = ort_session.run(None, {
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
|
| 43 |
def main(args):
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
|
| 52 |
if __name__ == "__main__":
|
| 53 |
parser = argparse.ArgumentParser()
|
| 54 |
-
parser.add_argument("--
|
| 55 |
-
parser.add_argument("--onnx_path", type=str)
|
| 56 |
parser.add_argument("--num_thread", type=int, default=8)
|
| 57 |
args = parser.parse_args()
|
| 58 |
|
| 59 |
-
utt2wav = {}
|
| 60 |
-
with open('{}/wav.scp'.format(args.dir)) as f:
|
| 61 |
-
for l in f:
|
| 62 |
-
l = l.replace('\n', '').split()
|
| 63 |
-
utt2wav[l[0]] = l[1]
|
| 64 |
-
|
| 65 |
option = onnxruntime.SessionOptions()
|
| 66 |
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 67 |
option.intra_op_num_threads = 1
|
|
@@ -69,4 +96,4 @@ if __name__ == "__main__":
|
|
| 69 |
ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
|
| 70 |
executor = ThreadPoolExecutor(max_workers=args.num_thread)
|
| 71 |
|
| 72 |
-
main(args)
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import argparse
|
| 3 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 4 |
import logging
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
import torchaudio
|
| 10 |
import whisper
|
| 11 |
+
import glob
|
| 12 |
+
import os
|
| 13 |
|
| 14 |
+
logger = logging.getLogger()
|
| 15 |
|
| 16 |
+
|
| 17 |
+
def process_single_audio(wav_path):
    """Extract and save discrete speech tokens for one wav file.

    Computes a 128-bin whisper log-mel spectrogram on 16 kHz mono audio and
    runs the module-level `ort_session` (speech tokenizer ONNX model),
    saving the token list next to the wav as `<stem>_tokens.pt`.

    Args:
        wav_path: Path to a .wav file. A sibling `<stem>.normalized.txt`
            transcript must exist or the file is skipped.

    Returns:
        Dict with `wav_path`, `utt`, `token_path` and `num_tokens`, or
        None when the transcript is missing.
    """
    # Check if text file exists; wavs without transcripts are skipped.
    txt_path = wav_path.replace('.wav', '.normalized.txt')
    if not os.path.exists(txt_path):
        logger.warning(f'{txt_path} does not exist, skipping {wav_path}')
        return None

    # Extract utterance ID from the file stem.
    utt = os.path.basename(wav_path).replace('.wav', '')

    # Process audio. NOTE(review): the explicit soundfile backend differs
    # from the default loader used in extract_embedding.py — confirm both
    # are intended.
    audio, sample_rate = torchaudio.load(wav_path, backend='soundfile')
    if sample_rate != 16000:
        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)

    # Convert audio to mono by averaging channels.
    if audio.shape[0] > 1:
        audio = audio.mean(dim=0, keepdim=True)

    # Clips longer than 30 s exceed the whisper front-end window; emit an
    # empty token list rather than truncating silently.
    if audio.shape[1] / 16000 > 30:
        logging.warning(f'Audio longer than 30s, skipping tokenization for {wav_path}')
        speech_token = []
    else:
        feat = whisper.log_mel_spectrogram(audio, n_mels=128)
        # Second model input is the frame count of the spectrogram.
        speech_token = ort_session.run(None, {
            ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
            ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)
        })[0].flatten().tolist()

    # Save individual token file next to the source wav.
    token_path = wav_path.replace('.wav', '_tokens.pt')
    torch.save(speech_token, token_path)

    return {
        'wav_path': wav_path,
        'utt': utt,
        'token_path': token_path,
        'num_tokens': len(speech_token)
    }
|
| 56 |
|
| 57 |
|
| 58 |
def main(args):
    """Tokenize every wav under ``args.src_dir``.

    Submits each file to the module-level ``executor`` thread pool and
    writes ``<src_dir>/token_summary.txt`` describing the generated token
    files.
    """
    # Find all wav files two directory levels below the source root.
    wav_files = list(glob.glob('{}/*/*/*wav'.format(args.src_dir)))
    print(f"Found {len(wav_files)} wav files")

    # Fan the per-file work out over the shared thread pool.
    futures = [executor.submit(process_single_audio, path) for path in wav_files]

    successful_files = []
    for fut in tqdm(as_completed(futures), total=len(futures)):
        info = fut.result()
        if info is not None:
            successful_files.append(info)

    # Human-readable index of everything that was processed successfully.
    summary_path = os.path.join(args.src_dir, "token_summary.txt")
    with open(summary_path, 'w') as f:
        f.write(f"Processed {len(successful_files)} files successfully\n")
        total_tokens = sum(r['num_tokens'] for r in successful_files)
        f.write(f"Total tokens generated: {total_tokens}\n")
        for info in successful_files:
            f.write(f"{info['utt']} {info['wav_path']} {info['token_path']} {info['num_tokens']}\n")
|
| 83 |
|
| 84 |
|
| 85 |
if __name__ == "__main__":
|
| 86 |
parser = argparse.ArgumentParser()
|
| 87 |
+
parser.add_argument("--src_dir", type=str, help="Source directory containing audio files")
|
| 88 |
+
parser.add_argument("--onnx_path", type=str, help="Path to speech_tokenizer_v2.onnx model")
|
| 89 |
parser.add_argument("--num_thread", type=int, default=8)
|
| 90 |
args = parser.parse_args()
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
option = onnxruntime.SessionOptions()
|
| 93 |
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 94 |
option.intra_op_num_threads = 1
|
|
|
|
| 96 |
ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
|
| 97 |
executor = ThreadPoolExecutor(max_workers=args.num_thread)
|
| 98 |
|
| 99 |
+
main(args)
|
speech/tools/inv_file_processor.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Add this to your processor.py file or create a new file
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import json
|
| 5 |
+
import torch
|
| 6 |
+
import glob
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
def individual_file_opener(data, mode='train', tts_data=None):
    """Load data from individual files instead of parquet.

    Args:
        data: Iterable[{src}] where src is either:
            - Path to an index JSON file with a top-level 'data' list
            - Directory path containing wav files (searched at
              ``<src>/*/*/<name>wav`` depth)
        mode: 'train' or 'test'
        tts_data: Dict mapping utt -> list of TTS texts (test mode).
            Defaults to an empty dict.

    Yields:
        Dict with all required fields for training.
    """
    # Bug fix: the original default was a mutable dict (`tts_data={}`),
    # which is shared across all calls; use a None sentinel instead.
    if tts_data is None:
        tts_data = {}

    for sample in data:
        src = sample['src']

        # Check if src is a JSON index file or a directory.
        if src.endswith('.json'):
            # Load the prebuilt file index.
            with open(src, 'r') as f:
                index_data = json.load(f)
            file_list = index_data.get('data', [])
        else:
            # Walk the directory and build the index on the fly.
            wav_files = glob.glob(os.path.join(src, '*/*/*wav'))
            file_list = []
            for wav_path in wav_files:
                # All companion files must exist or the wav is skipped.
                txt_path = wav_path.replace('.wav', '.normalized.txt')
                embedding_path = wav_path.replace('.wav', '_embedding.pt')
                token_path = wav_path.replace('.wav', '_tokens.pt')

                if not all(os.path.exists(p) for p in [txt_path, embedding_path, token_path]):
                    logging.warning(f'Missing files for {wav_path}, skipping')
                    continue

                # Speaker ID is assumed to be the stem prefix before the
                # first underscore — matches extract_embedding.py.
                utt = os.path.basename(wav_path).replace('.wav', '')
                spk = utt.split('_')[0]

                file_list.append({
                    'utt': utt,
                    'spk': spk,
                    'wav': wav_path,
                    'text_path': txt_path,
                    'embedding_path': embedding_path,
                    'token_path': token_path,
                    # NOTE(review): extract_embedding.py saves speaker
                    # embeddings under <src>/spk_embeddings, but this looks
                    # in the PARENT of src — confirm which layout is
                    # intended (missing files fall back to the utterance
                    # embedding below).
                    'spk_embedding_path': os.path.join(os.path.dirname(src), f"spk_embeddings/{spk}_embedding.pt")
                })

        # Process each file; failures are logged and skipped so one bad
        # file cannot abort the whole data pipeline.
        for file_info in file_list:
            try:
                # Read raw audio bytes (decoding happens downstream).
                with open(file_info['wav'], 'rb') as f:
                    audio_data = f.read()

                # Read the transcript, collapsing it to a single line.
                with open(file_info['text_path'], 'r') as f:
                    text = ''.join(l.strip() for l in f.readlines())

                # Load precomputed embeddings and tokens.
                utt_embedding = torch.load(file_info['embedding_path']).tolist()
                speech_token = torch.load(file_info['token_path'])

                # Load speaker embedding, falling back to the utterance
                # embedding when the pooled speaker file is absent.
                if os.path.exists(file_info['spk_embedding_path']):
                    spk_embedding = torch.load(file_info['spk_embedding_path']).tolist()
                else:
                    logging.warning(f"Speaker embedding not found: {file_info['spk_embedding_path']}")
                    spk_embedding = utt_embedding  # Fallback to utterance embedding

                # Build sample dict.
                sample_dict = {
                    'utt': file_info['utt'],
                    'spk': file_info['spk'],
                    'audio_data': audio_data,
                    'text': text,
                    'text_token': [],  # Will be filled by tokenize processor
                    'utt_embedding': utt_embedding,
                    'spk_embedding': spk_embedding,
                    'speech_token': speech_token,
                    'wav': file_info['wav'],  # Keep original path for reference
                }

                # Merge with original sample data (e.g. keeps 'src').
                sample_dict.update(sample)

                if mode == 'train':
                    yield sample_dict
                else:
                    # For TTS mode, fan out one sample per synthesis text.
                    if file_info['utt'] in tts_data:
                        for index, tts_text in enumerate(tts_data[file_info['utt']]):
                            yield {**sample_dict, 'tts_index': index, 'tts_text': tts_text}
                    else:
                        yield sample_dict

            except Exception as ex:
                logging.warning(f'Failed to process {file_info["wav"]}: {ex}')
|