Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| import argparse | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| import onnxruntime | |
| import torch | |
| import torchaudio | |
| import torchaudio.compliance.kaldi as kaldi | |
| from tqdm import tqdm | |
| import os | |
| import glob | |
| import logging | |
| logger = logging.getLogger() | |
| def process_single_audio(wav_path): | |
| # Extract utterance ID and speaker ID from filename | |
| utt = os.path.basename(wav_path).replace('.wav', '') | |
| spk = utt.split('_')[0] | |
| # Check if text file exists | |
| txt_path = wav_path.replace('.wav', '.normalized.txt') | |
| if not os.path.exists(txt_path): | |
| logger.warning(f'{txt_path} does not exist, skipping {wav_path}') | |
| return None | |
| # Process audio | |
| audio, sample_rate = torchaudio.load(wav_path) | |
| if sample_rate != 16000: | |
| audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio) | |
| feat = kaldi.fbank(audio, | |
| num_mel_bins=80, | |
| dither=0, | |
| sample_frequency=16000) | |
| feat = feat - feat.mean(dim=0, keepdim=True) | |
| # Generate embedding | |
| embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten() | |
| # Save individual embedding file | |
| embedding_path = wav_path.replace('.wav', '_embedding.pt') | |
| torch.save(embedding, embedding_path) | |
| return { | |
| 'wav_path': wav_path, | |
| 'utt': utt, | |
| 'spk': spk, | |
| 'embedding': embedding, | |
| 'embedding_path': embedding_path | |
| } | |
| def main(args): | |
| # Find all wav files | |
| wav_files = list(glob.glob('{}/*/*/*wav'.format(args.src_dir))) | |
| print(f"Found {len(wav_files)} wav files") | |
| # Process all audio files | |
| all_tasks = [executor.submit(process_single_audio, wav_path) for wav_path in wav_files] | |
| # Collect results | |
| spk2embeddings = {} | |
| successful_files = [] | |
| for future in tqdm(as_completed(all_tasks), total=len(all_tasks)): | |
| result = future.result() | |
| if result is None: | |
| continue | |
| successful_files.append(result) | |
| # Collect embeddings by speaker | |
| spk = result['spk'] | |
| if spk not in spk2embeddings: | |
| spk2embeddings[spk] = [] | |
| spk2embeddings[spk].append(result['embedding']) | |
| # Calculate and save speaker embeddings | |
| spk_embed_dir = os.path.join(args.src_dir, "spk_embeddings") | |
| os.makedirs(spk_embed_dir, exist_ok=True) | |
| for spk, embeddings in spk2embeddings.items(): | |
| spk_embedding = torch.stack([torch.tensor(e) for e in embeddings]).mean(dim=0) | |
| spk_embedding_path = os.path.join(spk_embed_dir, f"{spk}_embedding.pt") | |
| torch.save(spk_embedding, spk_embedding_path) | |
| print(f"Saved speaker embedding for {spk} with {len(embeddings)} utterances") | |
| # Save a summary file for reference | |
| summary_path = os.path.join(args.src_dir, "embedding_summary.txt") | |
| with open(summary_path, 'w') as f: | |
| f.write(f"Processed {len(successful_files)} files successfully\n") | |
| f.write(f"Found {len(spk2embeddings)} speakers\n") | |
| for result in successful_files: | |
| f.write(f"{result['utt']} {result['wav_path']} {result['embedding_path']}\n") | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--src_dir", type=str, help="Source directory containing audio files") | |
| parser.add_argument("--onnx_path", type=str, help="Path to campplus.onnx model") | |
| parser.add_argument("--num_thread", type=int, default=8) | |
| args = parser.parse_args() | |
| option = onnxruntime.SessionOptions() | |
| option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL | |
| option.intra_op_num_threads = 1 | |
| providers = ["CPUExecutionProvider"] | |
| ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers) | |
| executor = ThreadPoolExecutor(max_workers=args.num_thread) | |
| main(args) |