# RWKV-TTS / data /utils /create_embeddings_from_raw.py
# kevinwang676's picture
# Add files using upload-large-folder tool
# fd82c69 verified
import os
from re import A
import whisper
from librosa import resample
import multiprocessing
from tqdm import tqdm
import onnxruntime
from onnxruntime import InferenceSession
import torch
import pyarrow.parquet as pq
import numpy as np
import json
import io
import soundfile as sf
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import mmap
import os
import pyarrow.parquet as pq
import io
import soundfile as sf
import torchaudio.compliance.kaldi as kaldi
import torch
import numpy as np
import onnxruntime
def process_file(file_info):
    """Extract speaker embeddings from a single parquet file.

    Intended to be called once per file from a worker process.

    Args:
        file_info: Tuple ``(parquet_file, output_path, speaker_extractor, device)``.
            ``parquet_file`` is the path of the parquet file to read and
            ``speaker_extractor`` is the path of the ONNX model that is run on
            every utterance. ``output_path`` and ``device`` are unpacked for
            call-site compatibility but are not used by this function
            (inference always runs on ``CPUExecutionProvider``).

    Returns:
        On success, a dict ``{speaker: {language: [embedding, ...]}}`` where
        every embedding is the flattened model output converted to a Python
        list. At most 10 embeddings are kept per speaker/language pair within
        this file. On failure, the traceback is printed and the string
        ``"Error processing <parquet_file>: <message>"`` is returned instead.
    """
    parquet_file, output_path, speaker_extractor, device = file_info
    max_embeddings = 10  # cap per (speaker, language) pair within this file
    batch_size = 100

    # Every worker process builds its own single-threaded ONNX session.
    option = onnxruntime.SessionOptions()
    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    option.intra_op_num_threads = 1
    ort_session = onnxruntime.InferenceSession(speaker_extractor, sess_options=option,
                                               providers=["CPUExecutionProvider"])
    results = {}
    try:
        input_name = ort_session.get_inputs()[0].name

        # Row count from the parquet metadata, used only for the progress bar.
        total_rows = pq.read_metadata(parquet_file).num_rows

        # Let pyarrow memory-map the file itself. Wrapping an mmap object in
        # io.BytesIO would copy the whole file into memory first.
        pf = pq.ParquetFile(parquet_file, memory_map=True)

        # `_identity` is empty when this is not running inside a pool worker;
        # fall back to the first progress-bar slot in that case.
        identity = multiprocessing.current_process()._identity
        position = identity[0] % 10 if identity else 0

        # The context manager closes the bar even when a row raises.
        with tqdm(total=total_rows,
                  desc=f"Processing {os.path.basename(parquet_file)}",
                  position=position) as progress:
            for batch in pf.iter_batches(batch_size=batch_size):
                df_batch = batch.to_pandas()
                for _, row in df_batch.iterrows():
                    language = row['language']
                    speaker = row['speaker']
                    embeddings = results.setdefault(speaker, {}).setdefault(language, [])

                    # Enough samples for this speaker/language: skip decoding.
                    if len(embeddings) >= max_embeddings:
                        progress.update(1)
                        continue

                    audio_data = row['audio']['bytes']
                    with io.BytesIO(audio_data) as audio_buffer:
                        prompt_data, sample_rate = sf.read(audio_buffer)

                    # Keep only the first channel and convert to float32.
                    if len(prompt_data.shape) > 1:
                        prompt_data = prompt_data[:, 0]
                    prompt_data = prompt_data.astype(np.float32)

                    # Resample to 16 kHz when the source rate differs.
                    if sample_rate != 16000:
                        prompt_data = resample(prompt_data, orig_sr=sample_rate, target_sr=16000)

                    prompt_speech_16k = torch.tensor(prompt_data).unsqueeze(0)
                    feat = kaldi.fbank(prompt_speech_16k,
                                       num_mel_bins=80,
                                       dither=0,
                                       sample_frequency=16000)
                    # Subtract the mean over dim 0 of the fbank output.
                    feat = feat - feat.mean(dim=0, keepdim=True)

                    model_input = feat.unsqueeze(dim=0).cpu().numpy()
                    embedding = ort_session.run(None, {input_name: model_input})[0].flatten().tolist()
                    embeddings.append(embedding)
                    progress.update(1)

        print(f'All speakers {results.keys()}')
        for speaker in results:
            print(f'{speaker} : All languages {results[speaker].keys()} in {os.getpid()}')
        return results
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error processing {parquet_file}: {str(e)}"
def process_file_x(file_info):
    """Extract speaker embeddings from a single parquet file, opened by path.

    Variant of the per-file worker that hands the parquet path straight to
    ``pq.ParquetFile``. Intended to be called once per file from a worker
    process.

    Args:
        file_info: Tuple ``(parquet_file, output_path, speaker_extractor, device)``.
            ``parquet_file`` is the path of the parquet file to read and
            ``speaker_extractor`` is the path of the ONNX model that is run on
            every utterance. ``output_path`` and ``device`` are unpacked for
            call-site compatibility but are not used by this function
            (inference always runs on ``CPUExecutionProvider``).

    Returns:
        On success, a dict ``{speaker: {language: [embedding, ...]}}`` where
        every embedding is the flattened model output converted to a Python
        list. At most 10 embeddings are kept per speaker/language pair within
        this file. On failure, the traceback is printed and the string
        ``"Error processing <parquet_file>: <message>"`` is returned instead.
    """
    parquet_file, output_path, speaker_extractor, device = file_info
    per_pair_limit = 10  # cap per (speaker, language) pair within this file
    batch_size = 100

    # Every worker process builds its own single-threaded ONNX session.
    option = onnxruntime.SessionOptions()
    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    option.intra_op_num_threads = 1
    ort_session = InferenceSession(speaker_extractor, sess_options=option,
                                   providers=["CPUExecutionProvider"])
    results = {}
    try:
        input_name = ort_session.get_inputs()[0].name

        # Row count from the parquet metadata, used only for the progress bar.
        total_rows = pq.read_metadata(parquet_file).num_rows
        pf = pq.ParquetFile(parquet_file)

        # `_identity` is empty when this is not running inside a pool worker;
        # fall back to the first progress-bar slot in that case.
        identity = multiprocessing.current_process()._identity
        position = identity[0] % 10 if identity else 0

        # The context manager closes the bar even when a row raises.
        with tqdm(total=total_rows,
                  desc=f"Processing {os.path.basename(parquet_file)}",
                  position=position) as progress:
            for batch in pf.iter_batches(batch_size=batch_size):
                df_batch = batch.to_pandas()
                for _, row in df_batch.iterrows():
                    language = row['language']
                    speaker = row['speaker']
                    collected = results.setdefault(speaker, {}).setdefault(language, [])

                    # Enough samples for this speaker/language: skip decoding.
                    if len(collected) >= per_pair_limit:
                        progress.update(1)
                        continue

                    audio_data = row['audio']['bytes']
                    with io.BytesIO(audio_data) as buffer:
                        prompt_data, sample_rate = sf.read(buffer)

                    # Keep only the first channel and convert to float32.
                    if len(prompt_data.shape) > 1:
                        prompt_data = prompt_data[:, 0]
                    prompt_data = prompt_data.astype(np.float32)

                    # Resample to 16 kHz when the source rate differs.
                    if sample_rate != 16000:
                        prompt_data = resample(prompt_data, orig_sr=sample_rate, target_sr=16000)

                    prompt_speech_16k = torch.tensor(prompt_data).unsqueeze(0)
                    feat = kaldi.fbank(prompt_speech_16k,
                                       num_mel_bins=80,
                                       dither=0,
                                       sample_frequency=16000)
                    # Subtract the mean over dim 0 of the fbank output.
                    feat = feat - feat.mean(dim=0, keepdim=True)

                    model_input = feat.unsqueeze(dim=0).cpu().numpy()
                    embedding = ort_session.run(None, {input_name: model_input})[0].flatten().tolist()
                    collected.append(embedding)
                    progress.update(1)

        print(f'All speakers {results.keys()}')
        for speaker in results:
            print(f'{speaker} : All languages {results[speaker].keys()} in {os.getpid()}')
        return results
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error processing {parquet_file}: {str(e)}"
if __name__ == '__main__':
    # Script entry point: find every parquet file under --data_path, extract
    # speaker embeddings from each file in a process pool, merge the per-file
    # results, and write one JSON file per (speaker, language) pair.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, default='/external_data/yueyudata/starrail-voice')
    parser.add_argument('--output_path', type=str, default='/external_data/yueyudata/starrail-voice-speaker-embeddings')
    parser.add_argument('--speaker_extractor', type=str, default='/external_data/models/CosyVoice2-0.5B_RWKV_1.5B/campplus.onnx')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--num_processes', type=int, default=4)
    args = parser.parse_args()
    print(args)

    data_path = args.data_path
    output_path = args.output_path
    device = args.device
    speaker_extractor = args.speaker_extractor
    num_processes = args.num_processes

    # Make sure the output directory exists.
    os.makedirs(output_path, exist_ok=True)

    # Collect every parquet file below data_path (recursively).
    parquet_files = []
    for root, dirs, files in os.walk(data_path):
        for file in files:
            if file.endswith('.parquet'):
                parquet_files.append(os.path.join(root, file))
    print(f'Found {len(parquet_files)} parquet files in {data_path}')

    # One argument tuple per file for the worker function.
    file_info_list = [(file, output_path, speaker_extractor, device) for file in parquet_files]

    # Process the files in a pool of worker processes.
    print(f"Starting processing with {num_processes} processes")
    with multiprocessing.Pool(processes=num_processes) as pool:
        results = pool.map(process_file, file_info_list)

    # Merge the per-file dicts; a worker that failed returns an error string,
    # which is printed and otherwise ignored.
    print('Processing complete,merge results')
    final_results = {}
    for result in results:
        if not isinstance(result, dict):
            print(result)
            continue
        for speaker, by_language in result.items():
            merged_speaker = final_results.setdefault(speaker, {})
            for language, embeddings in by_language.items():
                merged_speaker.setdefault(language, []).extend(embeddings)

    # Write one JSON file per (speaker, language) pair.
    for speaker, by_language in final_results.items():
        for language, embeddings in by_language.items():
            output_file = os.path.join(output_path, f"{speaker}_{language}_embeddings.json")
            print(f"Writing embeddings for {speaker} ({language}) to {output_file}")
            with open(output_file, 'w', encoding='utf-8') as f_out:
                json.dump(embeddings, f_out)