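# Extract speaker embeddings from parquet audio datasets with an ONNX
# speaker-embedding extractor (campplus.onnx by default). Each worker process
# handles one parquet file: it decodes the audio bytes, resamples to 16 kHz,
# computes mean-normalized 80-dim fbank features, runs the ONNX model, and
# keeps up to 10 embeddings per (speaker, language) pair. The main process
# merges the per-file results and writes one
# "{speaker}_{language}_embeddings.json" file per pair.
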
import os
import io
import json
import mmap
import multiprocessing

import numpy as np
import torch
import soundfile as sf
import onnxruntime
from onnxruntime import InferenceSession
import pyarrow.parquet as pq
import torchaudio.compliance.kaldi as kaldi
from librosa import resample
from tqdm import tqdm

def process_file(file_info):
    """Process a single parquet file; called once per worker process."""
    parquet_file, output_path, speaker_extractor, device = file_info

    # Build a CPU-only ONNX Runtime session for the speaker-embedding extractor.
    # `device` is currently unused: inference always runs on the CPUExecutionProvider.
    option = onnxruntime.SessionOptions()
    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    option.intra_op_num_threads = 1
    ort_session = onnxruntime.InferenceSession(speaker_extractor, sess_options=option,
                                               providers=["CPUExecutionProvider"])
    results = {}
    try:
        base_filename = os.path.splitext(os.path.basename(parquet_file))[0]
        output_file = os.path.join(output_path, f"{base_filename}_tokens.jsonl")

        parquet_metadata = pq.read_metadata(parquet_file)
        total_rows = parquet_metadata.num_rows
        batch_size = 100

        # Memory-map the parquet file and wrap it in a BytesIO buffer for pyarrow.
        with open(parquet_file, 'rb') as f:
            mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
        buffer = io.BytesIO(mm)
        pf = pq.ParquetFile(buffer)

        # Stagger the tqdm bars so concurrent workers don't overwrite each other.
        progress = tqdm(total=total_rows,
                        desc=f"Processing {os.path.basename(parquet_file)}",
                        position=multiprocessing.current_process()._identity[0] % 10)

        current_row = 0
        for batch in pf.iter_batches(batch_size=batch_size):
            df_batch = batch.to_pandas()
            for _, row in df_batch.iterrows():
                current_row += 1
                audio_obj = row['audio']
                audio_data = audio_obj['bytes']
                transcription = row['transcription']
                language = row['language']
                speaker = row['speaker']
                if speaker not in results:
                    results[speaker] = {}
                if language not in results[speaker]:
                    results[speaker][language] = []
                # Keep at most 10 embeddings per (speaker, language) pair.
                if len(results[speaker][language]) >= 10:
                    progress.update(1)
                    continue

                # Decode the audio bytes, keep the first channel, convert to float32.
                with io.BytesIO(audio_data) as audio_buffer:
                    prompt_data, sample_rate = sf.read(audio_buffer)
                if len(prompt_data.shape) > 1:
                    prompt_data = prompt_data[:, 0]
                prompt_data = prompt_data.astype(np.float32)

                # The extractor expects 16 kHz input.
                if sample_rate != 16000:
                    prompt_data = resample(prompt_data, orig_sr=sample_rate, target_sr=16000)

                prompt_speech_16k = torch.tensor(prompt_data).unsqueeze(0)

                # 80-dim fbank features, mean-normalized over time, fed to the ONNX model.
                feat = kaldi.fbank(prompt_speech_16k,
                                   num_mel_bins=80,
                                   dither=0,
                                   sample_frequency=16000)
                feat = feat - feat.mean(dim=0, keepdim=True)
                embedding = ort_session.run(
                    None,
                    {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()}
                )[0].flatten().tolist()

                results[speaker][language].append(embedding)
                progress.update(1)

        mm.close()

        print(f'All speakers {results.keys()}')
        for speaker in results:
            print(f'{speaker} : All languages {results[speaker].keys()} in {os.getpid()}')
        return results
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error processing {parquet_file}: {str(e)}"

def process_file_x(file_info):
    """Variant of process_file that reads the parquet file directly (no mmap);
    called once per worker process."""
    parquet_file, output_path, speaker_extractor, device = file_info

    option = onnxruntime.SessionOptions()
    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    option.intra_op_num_threads = 1
    ort_session = InferenceSession(speaker_extractor, sess_options=option,
                                   providers=["CPUExecutionProvider"])
    results = {}
    try:
        base_filename = os.path.splitext(os.path.basename(parquet_file))[0]
        output_file = os.path.join(output_path, f"{base_filename}_tokens.jsonl")

        parquet_metadata = pq.read_metadata(parquet_file)
        total_rows = parquet_metadata.num_rows
        batch_size = 100

        pf = pq.ParquetFile(parquet_file)

        progress = tqdm(total=total_rows,
                        desc=f"Processing {os.path.basename(parquet_file)}",
                        position=multiprocessing.current_process()._identity[0] % 10)

        current_row = 0
        for batch in pf.iter_batches(batch_size=batch_size):
            df_batch = batch.to_pandas()
            for _, row in df_batch.iterrows():
                current_row += 1
                audio_obj = row['audio']
                audio_data = audio_obj['bytes']
                transcription = row['transcription']
                language = row['language']
                speaker = row['speaker']
                if speaker not in results:
                    results[speaker] = {}
                if language not in results[speaker]:
                    results[speaker][language] = []
                # Keep at most 10 embeddings per (speaker, language) pair.
                if len(results[speaker][language]) >= 10:
                    progress.update(1)
                    continue

                with io.BytesIO(audio_data) as buffer:
                    prompt_data, sample_rate = sf.read(buffer)
                if len(prompt_data.shape) > 1:
                    prompt_data = prompt_data[:, 0]
                prompt_data = prompt_data.astype(np.float32)

                if sample_rate != 16000:
                    prompt_data = resample(prompt_data, orig_sr=sample_rate, target_sr=16000)

                prompt_speech_16k = torch.tensor(prompt_data).unsqueeze(0)

                feat = kaldi.fbank(prompt_speech_16k,
                                   num_mel_bins=80,
                                   dither=0,
                                   sample_frequency=16000)
                feat = feat - feat.mean(dim=0, keepdim=True)
                embedding = ort_session.run(
                    None,
                    {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()}
                )[0].flatten().tolist()

                results[speaker][language].append(embedding)
                progress.update(1)

        print(f'All speakers {results.keys()}')
        for speaker in results:
            print(f'{speaker} : All languages {results[speaker].keys()} in {os.getpid()}')
        return results
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error processing {parquet_file}: {str(e)}"

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, default='/external_data/yueyudata/starrail-voice')
    parser.add_argument('--output_path', type=str, default='/external_data/yueyudata/starrail-voice-speaker-embeddings')
    parser.add_argument('--speaker_extractor', type=str, default='/external_data/models/CosyVoice2-0.5B_RWKV_1.5B/campplus.onnx')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--num_processes', type=int, default=4)
    args = parser.parse_args()

    print(args)
    data_path = args.data_path
    output_path = args.output_path
    device = args.device
    speaker_extractor = args.speaker_extractor
    num_processes = args.num_processes

    os.makedirs(output_path, exist_ok=True)

    # Recursively collect all parquet files under data_path.
    parquet_files = []
    for root, dirs, files in os.walk(data_path):
        for file in files:
            if file.endswith('.parquet'):
                parquet_files.append(os.path.join(root, file))
    print(f'Found {len(parquet_files)} parquet files in {data_path}')

    file_info_list = [(file, output_path, speaker_extractor, device) for file in parquet_files]

    # One worker per parquet file, up to num_processes running in parallel.
    print(f"Starting processing with {num_processes} processes")
    with multiprocessing.Pool(processes=num_processes) as pool:
        results = pool.map(process_file, file_info_list)

    # Merge per-file results; failed files return an error string instead of a dict.
    print('Processing complete, merging results')
    final_results = {}
    for result in results:
        if isinstance(result, dict):
            for speaker in result:
                if speaker not in final_results:
                    final_results[speaker] = {}
                for language in result[speaker]:
                    if language not in final_results[speaker]:
                        final_results[speaker][language] = []
                    final_results[speaker][language].extend(result[speaker][language])
        else:
            print(result)

    # Write one JSON file per (speaker, language) pair.
    for speaker in final_results:
        for language in final_results[speaker]:
            output_file = os.path.join(output_path, f"{speaker}_{language}_embeddings.json")
            print(f"Writing embeddings for {speaker} ({language}) to {output_file}")
            with open(output_file, 'w', encoding='utf-8') as f_out:
                json.dump(final_results[speaker][language], f_out)
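
# Example invocation (the script filename below is illustrative; the paths are
# simply the argparse defaults above):
#   python extract_speaker_embeddings.py \
#       --data_path /external_data/yueyudata/starrail-voice \
#       --output_path /external_data/yueyudata/starrail-voice-speaker-embeddings \
#       --speaker_extractor /external_data/models/CosyVoice2-0.5B_RWKV_1.5B/campplus.onnx \
#       --num_processes 4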