primepake committed
Commit eb584bd · 1 Parent(s): 631dfe2
speech/tools/create_data_list.py ADDED
@@ -0,0 +1,37 @@
+ #!/usr/bin/env python3
+ """Create data list files for training with individual files"""
+
+ import argparse
+ import os
+ import json
+
+ def create_data_lists(src_dir, output_dir):
+     """Create data list files pointing to directories or index files
+
+     Args:
+         src_dir: Directory containing processed audio files
+         output_dir: Directory to save list files
+     """
+     os.makedirs(output_dir, exist_ok=True)
+
+     # Option 1: Create a list pointing to the source directory
+     with open(os.path.join(output_dir, 'data.list'), 'w') as f:
+         f.write(src_dir + '\n')
+
+     # Option 2: If you have an index file, point to it
+     index_file = os.path.join(src_dir, 'data_index.json')
+     if os.path.exists(index_file):
+         with open(os.path.join(output_dir, 'data_index.list'), 'w') as f:
+             f.write(index_file + '\n')
+
+     print(f"Created data lists in {output_dir}")
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--src_dir', type=str, required=True,
+                         help='Source directory with processed files')
+     parser.add_argument('--output_dir', type=str, required=True,
+                         help='Output directory for list files')
+     args = parser.parse_args()
+
+     create_data_lists(args.src_dir, args.output_dir)
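
For context, each list file written above holds one path per line. A minimal sketch of a reader that turns those lines into the `{'src': ...}` samples consumed by `individual_file_opener` (added below in inv_file_processor.py); the helper name `read_data_list` and this wiring are assumptions for illustration, not part of this commit:

# Hypothetical helper: yield each non-empty line of data.list as a
# {'src': ...} sample, the input shape individual_file_opener expects.
def read_data_list(list_path):
    with open(list_path) as f:
        for line in f:
            line = line.strip()
            if line:
                yield {'src': line}

# e.g. samples = list(read_data_list('lists/data.list'))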
speech/tools/extract_embedding.py CHANGED
@@ -1,17 +1,4 @@
  #!/usr/bin/env python3
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
  import argparse
  from concurrent.futures import ThreadPoolExecutor, as_completed
  import onnxruntime
@@ -19,54 +6,102 @@ import torch
  import torchaudio
  import torchaudio.compliance.kaldi as kaldi
  from tqdm import tqdm
+ import os
+ import glob
+ import logging
 
+ logger = logging.getLogger()
 
- def single_job(utt):
-     audio, sample_rate = torchaudio.load(utt2wav[utt])
+
+ def process_single_audio(wav_path):
+     # Extract utterance ID and speaker ID from filename
+     utt = os.path.basename(wav_path).replace('.wav', '')
+     spk = utt.split('_')[0]
+
+     # Check if text file exists
+     txt_path = wav_path.replace('.wav', '.normalized.txt')
+     if not os.path.exists(txt_path):
+         logger.warning(f'{txt_path} does not exist, skipping {wav_path}')
+         return None
+
+     # Process audio
+     audio, sample_rate = torchaudio.load(wav_path)
      if sample_rate != 16000:
          audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
+
      feat = kaldi.fbank(audio,
                         num_mel_bins=80,
                         dither=0,
                         sample_frequency=16000)
      feat = feat - feat.mean(dim=0, keepdim=True)
-     embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
-     return utt, embedding
+
+     # Generate embedding
+     embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten()
+
+     # Save individual embedding file
+     embedding_path = wav_path.replace('.wav', '_embedding.pt')
+     torch.save(embedding, embedding_path)
+
+     return {
+         'wav_path': wav_path,
+         'utt': utt,
+         'spk': spk,
+         'embedding': embedding,
+         'embedding_path': embedding_path
+     }
 
 
  def main(args):
-     all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()]
-     utt2embedding, spk2embedding = {}, {}
-     for future in tqdm(as_completed(all_task)):
-         utt, embedding = future.result()
-         utt2embedding[utt] = embedding
-         spk = utt2spk[utt]
-         if spk not in spk2embedding:
-             spk2embedding[spk] = []
-         spk2embedding[spk].append(embedding)
-     for k, v in spk2embedding.items():
-         spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
-     torch.save(utt2embedding, "{}/utt2embedding.pt".format(args.dir))
-     torch.save(spk2embedding, "{}/spk2embedding.pt".format(args.dir))
+     # Find all wav files
+     wav_files = list(glob.glob('{}/*/*/*wav'.format(args.src_dir)))
+     print(f"Found {len(wav_files)} wav files")
+
+     # Process all audio files
+     all_tasks = [executor.submit(process_single_audio, wav_path) for wav_path in wav_files]
+
+     # Collect results
+     spk2embeddings = {}
+     successful_files = []
+
+     for future in tqdm(as_completed(all_tasks), total=len(all_tasks)):
+         result = future.result()
+         if result is None:
+             continue
+
+         successful_files.append(result)
+
+         # Collect embeddings by speaker
+         spk = result['spk']
+         if spk not in spk2embeddings:
+             spk2embeddings[spk] = []
+         spk2embeddings[spk].append(result['embedding'])
+
+     # Calculate and save speaker embeddings
+     spk_embed_dir = os.path.join(args.src_dir, "spk_embeddings")
+     os.makedirs(spk_embed_dir, exist_ok=True)
+
+     for spk, embeddings in spk2embeddings.items():
+         spk_embedding = torch.stack([torch.tensor(e) for e in embeddings]).mean(dim=0)
+         spk_embedding_path = os.path.join(spk_embed_dir, f"{spk}_embedding.pt")
+         torch.save(spk_embedding, spk_embedding_path)
+         print(f"Saved speaker embedding for {spk} with {len(embeddings)} utterances")
+
+     # Save a summary file for reference
+     summary_path = os.path.join(args.src_dir, "embedding_summary.txt")
+     with open(summary_path, 'w') as f:
+         f.write(f"Processed {len(successful_files)} files successfully\n")
+         f.write(f"Found {len(spk2embeddings)} speakers\n")
+         for result in successful_files:
+             f.write(f"{result['utt']} {result['wav_path']} {result['embedding_path']}\n")
 
 
  if __name__ == "__main__":
      parser = argparse.ArgumentParser()
-     parser.add_argument("--dir", type=str)
-     parser.add_argument("--onnx_path", type=str)
+     parser.add_argument("--src_dir", type=str, help="Source directory containing audio files")
+     parser.add_argument("--onnx_path", type=str, help="Path to campplus.onnx model")
      parser.add_argument("--num_thread", type=int, default=8)
      args = parser.parse_args()
 
-     utt2wav, utt2spk = {}, {}
-     with open('{}/wav.scp'.format(args.dir)) as f:
-         for l in f:
-             l = l.replace('\n', '').split()
-             utt2wav[l[0]] = l[1]
-     with open('{}/utt2spk'.format(args.dir)) as f:
-         for l in f:
-             l = l.replace('\n', '').split()
-             utt2spk[l[0]] = l[1]
-
      option = onnxruntime.SessionOptions()
      option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
      option.intra_op_num_threads = 1
@@ -74,4 +109,4 @@ if __name__ == "__main__":
      ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
      executor = ThreadPoolExecutor(max_workers=args.num_thread)
 
-     main(args)
+     main(args)
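
The script now writes one `_embedding.pt` per utterance (a flat numpy array) plus a mean tensor per speaker under `spk_embeddings/`. A rough spot-check sketch for those outputs; the concrete paths are placeholders following the naming scheme in the diff, and `weights_only=False` is an assumption for newer torch versions, where `torch.load` rejects numpy payloads by default:

import torch
import torch.nn.functional as F

# Utterance embeddings were saved as numpy arrays, speaker means as tensors.
utt_emb = torch.as_tensor(torch.load('data/spk1/book1/spk1_0001_embedding.pt', weights_only=False))
spk_emb = torch.load('data/spk_embeddings/spk1_embedding.pt', weights_only=False)

# A same-speaker utterance should sit close to its speaker's mean embedding.
cos = F.cosine_similarity(utt_emb.unsqueeze(0), spk_emb.unsqueeze(0)).item()
print(f'cosine(utt, spk) = {cos:.3f}')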
speech/tools/extract_speech_token.py CHANGED
@@ -1,17 +1,4 @@
  #!/usr/bin/env python3
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
  import argparse
  from concurrent.futures import ThreadPoolExecutor, as_completed
  import logging
@@ -21,47 +8,87 @@ import onnxruntime
  import numpy as np
  import torchaudio
  import whisper
+ import glob
+ import os
 
+ logger = logging.getLogger()
 
- def single_job(utt):
-     audio, sample_rate = torchaudio.load(utt2wav[utt], backend='soundfile')
+
+ def process_single_audio(wav_path):
+     # Check if text file exists
+     txt_path = wav_path.replace('.wav', '.normalized.txt')
+     if not os.path.exists(txt_path):
+         logger.warning(f'{txt_path} does not exist, skipping {wav_path}')
+         return None
+
+     # Extract utterance ID
+     utt = os.path.basename(wav_path).replace('.wav', '')
+
+     # Process audio
+     audio, sample_rate = torchaudio.load(wav_path, backend='soundfile')
      if sample_rate != 16000:
          audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
+
      # Convert audio to mono
      if audio.shape[0] > 1:
          audio = audio.mean(dim=0, keepdim=True)
+
      if audio.shape[1] / 16000 > 30:
-         logging.warning('do not support extract speech token for audio longer than 30s')
+         logging.warning(f'Audio longer than 30s, skipping tokenization for {wav_path}')
          speech_token = []
      else:
          feat = whisper.log_mel_spectrogram(audio, n_mels=128)
-         speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
-                                               ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
-     return utt, speech_token
+         speech_token = ort_session.run(None, {
+             ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
+             ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)
+         })[0].flatten().tolist()
+
+     # Save individual token file
+     token_path = wav_path.replace('.wav', '_tokens.pt')
+     torch.save(speech_token, token_path)
+
+     return {
+         'wav_path': wav_path,
+         'utt': utt,
+         'token_path': token_path,
+         'num_tokens': len(speech_token)
+     }
 
 
  def main(args):
-     all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()]
-     utt2speech_token = {}
-     for future in tqdm(as_completed(all_task)):
-         utt, speech_token = future.result()
-         utt2speech_token[utt] = speech_token
-     torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir))
+     # Find all wav files
+     wav_files = list(glob.glob('{}/*/*/*wav'.format(args.src_dir)))
+     print(f"Found {len(wav_files)} wav files")
+
+     # Process all audio files
+     all_tasks = [executor.submit(process_single_audio, wav_path) for wav_path in wav_files]
+
+     # Collect results
+     successful_files = []
+
+     for future in tqdm(as_completed(all_tasks), total=len(all_tasks)):
+         result = future.result()
+         if result is None:
+             continue
+         successful_files.append(result)
+
+     # Save a summary file for reference
+     summary_path = os.path.join(args.src_dir, "token_summary.txt")
+     with open(summary_path, 'w') as f:
+         f.write(f"Processed {len(successful_files)} files successfully\n")
+         total_tokens = sum(r['num_tokens'] for r in successful_files)
+         f.write(f"Total tokens generated: {total_tokens}\n")
+         for result in successful_files:
+             f.write(f"{result['utt']} {result['wav_path']} {result['token_path']} {result['num_tokens']}\n")
 
 
  if __name__ == "__main__":
      parser = argparse.ArgumentParser()
-     parser.add_argument("--dir", type=str)
-     parser.add_argument("--onnx_path", type=str)
+     parser.add_argument("--src_dir", type=str, help="Source directory containing audio files")
+     parser.add_argument("--onnx_path", type=str, help="Path to speech_tokenizer_v2.onnx model")
      parser.add_argument("--num_thread", type=int, default=8)
      args = parser.parse_args()
 
-     utt2wav = {}
-     with open('{}/wav.scp'.format(args.dir)) as f:
-         for l in f:
-             l = l.replace('\n', '').split()
-             utt2wav[l[0]] = l[1]
-
      option = onnxruntime.SessionOptions()
      option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
      option.intra_op_num_threads = 1
@@ -69,4 +96,4 @@ if __name__ == "__main__":
      ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
      executor = ThreadPoolExecutor(max_workers=args.num_thread)
 
-     main(args)
+     main(args)
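
A matching spot-check for the token side: each `_tokens.pt` holds a plain list of ints, and an empty list marks a clip that exceeded the 30 s limit. A minimal sketch, with a placeholder path following the `_tokens.pt` convention above and the same `weights_only=False` assumption for newer torch:

import torch

# Token files hold list[int]; an empty list means the clip was over 30 s
# and tokenization was skipped.
tokens = torch.load('data/spk1/book1/spk1_0001_tokens.pt', weights_only=False)
print(f'{len(tokens)} tokens; first few: {tokens[:10]}')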
speech/tools/inv_file_processor.py ADDED
@@ -0,0 +1,109 @@
+ # Add this to your processor.py file or create a new file
+
+ import logging
+ import json
+ import torch
+ import glob
+ import os
+ from pathlib import Path
+
+ def individual_file_opener(data, mode='train', tts_data={}):
+     """Load data from individual files instead of parquet
+
+     Args:
+         data: Iterable[{src}] where src is either:
+             - Path to index JSON file
+             - Directory path containing wav files
+         mode: 'train' or 'test'
+         tts_data: Dict for TTS mode
+
+     Yields:
+         Dict with all required fields for training
+     """
+     for sample in data:
+         src = sample['src']
+
+         # Check if src is a JSON index file or a directory
+         if src.endswith('.json'):
+             # Load from index file
+             with open(src, 'r') as f:
+                 index_data = json.load(f)
+             file_list = index_data.get('data', [])
+         else:
+             # Find all wav files in directory
+             wav_files = glob.glob(os.path.join(src, '*/*/*wav'))
+             file_list = []
+             for wav_path in wav_files:
+                 # Check if all required files exist
+                 txt_path = wav_path.replace('.wav', '.normalized.txt')
+                 embedding_path = wav_path.replace('.wav', '_embedding.pt')
+                 token_path = wav_path.replace('.wav', '_tokens.pt')
+
+                 if not all(os.path.exists(p) for p in [txt_path, embedding_path, token_path]):
+                     logging.warning(f'Missing files for {wav_path}, skipping')
+                     continue
+
+                 # Extract metadata
+                 utt = os.path.basename(wav_path).replace('.wav', '')
+                 spk = utt.split('_')[0]
+
+                 file_list.append({
+                     'utt': utt,
+                     'spk': spk,
+                     'wav': wav_path,
+                     'text_path': txt_path,
+                     'embedding_path': embedding_path,
+                     'token_path': token_path,
+                     # Speaker embeddings live under src itself (extract_embedding.py
+                     # writes them to src_dir/spk_embeddings), so join on src, not dirname(src)
+                     'spk_embedding_path': os.path.join(src, f"spk_embeddings/{spk}_embedding.pt")
+                 })
+
+         # Process each file
+         for file_info in file_list:
+             try:
+                 # Read audio data
+                 with open(file_info['wav'], 'rb') as f:
+                     audio_data = f.read()
+
+                 # Read text
+                 with open(file_info['text_path'], 'r') as f:
+                     text = ''.join(l.strip() for l in f.readlines())
+
+                 # Load embeddings
+                 utt_embedding = torch.load(file_info['embedding_path']).tolist()
+                 speech_token = torch.load(file_info['token_path'])
+
+                 # Load speaker embedding
+                 if os.path.exists(file_info['spk_embedding_path']):
+                     spk_embedding = torch.load(file_info['spk_embedding_path']).tolist()
+                 else:
+                     logging.warning(f"Speaker embedding not found: {file_info['spk_embedding_path']}")
+                     spk_embedding = utt_embedding  # Fallback to utterance embedding
+
+                 # Build sample dict
+                 sample_dict = {
+                     'utt': file_info['utt'],
+                     'spk': file_info['spk'],
+                     'audio_data': audio_data,
+                     'text': text,
+                     'text_token': [],  # Will be filled by tokenize processor
+                     'utt_embedding': utt_embedding,
+                     'spk_embedding': spk_embedding,
+                     'speech_token': speech_token,
+                     'wav': file_info['wav'],  # Keep original path for reference
+                 }
+
+                 # Merge with original sample data
+                 sample_dict.update(sample)
+
+                 if mode == 'train':
+                     yield sample_dict
+                 else:
+                     # For TTS mode
+                     if file_info['utt'] in tts_data:
+                         for index, tts_text in enumerate(tts_data[file_info['utt']]):
+                             yield {**sample_dict, 'tts_index': index, 'tts_text': tts_text}
+                     else:
+                         yield sample_dict
+
+             except Exception as ex:
+                 logging.warning(f"Failed to process {file_info['wav']}: {ex}")
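
To smoke-test the opener end to end against a directory already populated by the two extract_* scripts above, it can be driven directly with one directory-style src entry. A minimal sketch; the path and the peek-and-break loop are illustrative only:

# Drive individual_file_opener directly with one directory-style src entry
# (in the real pipeline it would replace the parquet opener in the
# dataset processor chain).
samples = [{'src': '/path/to/processed_data'}]
for item in individual_file_opener(samples, mode='train'):
    print(item['utt'], item['spk'], len(item['speech_token']))
    break  # peek at the first yielded sample only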