# learnable-speech/speech/tools/extract_speech_token.py
# (source: primepake "inv code" repo, commit eb584bd)
#!/usr/bin/env python3
import argparse
from concurrent.futures import ThreadPoolExecutor, as_completed
import logging
import torch
from tqdm import tqdm
import onnxruntime
import numpy as np
import torchaudio
import whisper
import glob
import os
logger = logging.getLogger()
def process_single_audio(wav_path):
    """Tokenize one wav file with the shared ONNX speech tokenizer.

    Expects a sibling transcript at ``<stem>.normalized.txt`` and skips the
    file (returns None) when it is missing.  The resulting token list is
    saved next to the audio as ``<stem>_tokens.pt``.

    Args:
        wav_path: Path to a ``.wav`` file.

    Returns:
        dict with keys ``wav_path``, ``utt``, ``token_path``, ``num_tokens``,
        or None when the transcript is missing.
    """
    # splitext is safer than str.replace('.wav', ...), which would also
    # rewrite a '.wav' substring appearing earlier in the path.
    stem, _ = os.path.splitext(wav_path)
    txt_path = stem + '.normalized.txt'
    if not os.path.exists(txt_path):
        logger.warning(f'{txt_path} does not exist, skipping {wav_path}')
        return None
    # Utterance ID is the bare file name without extension.
    utt = os.path.basename(stem)
    audio, sample_rate = torchaudio.load(wav_path, backend='soundfile')
    # The tokenizer frontend operates at 16 kHz; resample anything else.
    if sample_rate != 16000:
        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
    # Down-mix multi-channel audio to mono by averaging channels.
    if audio.shape[0] > 1:
        audio = audio.mean(dim=0, keepdim=True)
    if audio.shape[1] / 16000 > 30:
        # Whisper's mel frontend assumes <= 30 s inputs; emit an empty token
        # list so the file is still recorded in the summary.
        # (was logging.warning — use the module logger consistently)
        logger.warning(f'Audio longer than 30s, skipping tokenization for {wav_path}')
        speech_token = []
    else:
        feat = whisper.log_mel_spectrogram(audio, n_mels=128)
        # NOTE(review): ort_session is a module-level global created under
        # __main__; InferenceSession.run is safe to call from worker threads.
        speech_token = ort_session.run(None, {
            ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
            ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)
        })[0].flatten().tolist()
    # Persist the per-file token list for downstream training.
    token_path = stem + '_tokens.pt'
    torch.save(speech_token, token_path)
    return {
        'wav_path': wav_path,
        'utt': utt,
        'token_path': token_path,
        'num_tokens': len(speech_token)
    }
def main(args):
    """Tokenize every wav under ``args.src_dir`` and write a summary file.

    Work is fanned out over the module-level ThreadPoolExecutor.  A failure
    on one file is logged and skipped so a single bad decode/inference cannot
    abort the whole run before the summary is written.
    """
    # Directory layout assumed to be <src_dir>/<a>/<b>/<utt>.wav.
    # NOTE(review): the pattern is '*wav' (no dot) at a fixed depth of two
    # subdirectories — presumably '*.wav' was intended; confirm before changing
    # since other layouts may rely on the current match.
    wav_files = list(glob.glob('{}/*/*/*wav'.format(args.src_dir)))
    print(f"Found {len(wav_files)} wav files")
    all_tasks = [executor.submit(process_single_audio, wav_path) for wav_path in wav_files]
    successful_files = []
    for future in tqdm(as_completed(all_tasks), total=len(all_tasks)):
        try:
            result = future.result()
        except Exception:
            # Previously an uncaught worker exception killed the entire run
            # here; log it with traceback and keep processing the rest.
            logger.exception('Worker failed; skipping file')
            continue
        if result is None:
            continue
        successful_files.append(result)
    # Save a summary file for reference.
    summary_path = os.path.join(args.src_dir, "token_summary.txt")
    with open(summary_path, 'w') as f:
        f.write(f"Processed {len(successful_files)} files successfully\n")
        total_tokens = sum(r['num_tokens'] for r in successful_files)
        f.write(f"Total tokens generated: {total_tokens}\n")
        for result in successful_files:
            f.write(f"{result['utt']} {result['wav_path']} {result['token_path']} {result['num_tokens']}\n")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Extract speech tokens from wav files with an ONNX speech tokenizer")
    # required=True gives a clear CLI error instead of a later TypeError on None.
    parser.add_argument("--src_dir", type=str, required=True,
                        help="Source directory containing audio files")
    parser.add_argument("--onnx_path", type=str, required=True,
                        help="Path to speech_tokenizer_v2.onnx model")
    parser.add_argument("--num_thread", type=int, default=8,
                        help="Worker threads for audio processing")
    args = parser.parse_args()

    option = onnxruntime.SessionOptions()
    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    # One intra-op thread per run() call; parallelism comes from the thread pool.
    option.intra_op_num_threads = 1
    # Fall back to CPU when CUDA is unavailable instead of failing session creation.
    providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    # Shared by all worker threads in process_single_audio.
    ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
    executor = ThreadPoolExecutor(max_workers=args.num_thread)
    try:
        main(args)
    finally:
        # Join worker threads even if main() raises.
        executor.shutdown(wait=True)