# Speech similarity model (author: MCplayer, revision 29c0409)
import math
import numpy as np
import torchaudio
import traceback
import triton_python_backend_utils as pb_utils
class TritonPythonModel:
    """Triton Python-backend model that scores how similar two speech clips are.

    Each request carries two audio file paths (string inputs ``AUDIO1`` and
    ``AUDIO2``). Both clips are loaded with ``torchaudio``, tiled by repetition
    when shorter than ``min_duration`` seconds, and sent to ``speaker_model``
    through a BLS ``InferenceRequest`` to obtain speaker embeddings. The cosine
    similarity of the two embeddings, rescaled from [-1, 1] to [0, 1], is
    returned as the float32 output ``SIMILARITY``.
    """

    def initialize(self, args):
        """Set up per-model configuration. ``args`` is not used."""
        # NOTE(review): sample_rate, feature_dim and vad_enabled are not read by
        # any method of this class -- audio is not resampled, no filterbank
        # features are computed and no VAD is applied; the raw waveform is sent
        # as "feats". Confirm what speaker_model actually expects.
        self.sample_rate = 16000
        self.feature_dim = 80
        self.vad_enabled = True
        # Minimum clip length in seconds; shorter clips are tiled up to it.
        self.min_duration = 0.1
        # Name of the Triton model queried (via BLS) for speaker embeddings.
        self.speaker_model_name = "speaker_model"

    def execute(self, requests):
        """Handle a batch of requests and return one response per request.

        A failure while handling one request is reported in that request's
        response (``InferenceResponse(error=...)``) instead of failing the
        whole batch.
        """
        responses = []
        for request in requests:
            try:
                audio1 = self._read_string_input(request, "AUDIO1")
                audio2 = self._read_string_input(request, "AUDIO2")
                feats1 = self.preprocess(audio1)
                feats2 = self.preprocess(audio2)
                similarity = self.compute_similarity(feats1, feats2)
                # pb_utils.Tensor takes (name, ndarray); the dtype travels with
                # the array, there is no separate dtype argument.
                output_tensor = pb_utils.Tensor(
                    "SIMILARITY", np.array([similarity], dtype=np.float32)
                )
                response = pb_utils.InferenceResponse(output_tensors=[output_tensor])
            except Exception as exc:  # per-request boundary: report, don't abort batch
                response = pb_utils.InferenceResponse(
                    output_tensors=[],
                    error=pb_utils.TritonError(f"{type(exc).__name__}: {exc}"),
                )
            responses.append(response)
        return responses

    def _read_string_input(self, request, name):
        """Return the first element of string input tensor ``name`` as ``str``.

        Raises:
            ValueError: if the request has no input called ``name``.
        """
        tensor = pb_utils.get_input_tensor_by_name(request, name)
        if tensor is None:
            raise ValueError(f"missing input tensor {name!r}")
        return self._decode_first_string(tensor.as_numpy())

    @staticmethod
    def _decode_first_string(values):
        """Return the first element of ``values`` (any shape) as a ``str``.

        String/BYTES tensors arrive as arrays of ``bytes``; flattening first
        accepts both ``[1]``- and ``[1, 1]``-shaped inputs.

        Raises:
            ValueError: if ``values`` is empty.
        """
        flat = np.asarray(values, dtype=object).reshape(-1)
        if flat.size == 0:
            raise ValueError("empty string input")
        first = flat[0]
        return first.decode("utf-8") if isinstance(first, bytes) else str(first)

    def preprocess(self, audio_path):
        """Load an audio file, tiling it up to ``self.min_duration`` if too short.

        Args:
            audio_path: path of an audio file readable by ``torchaudio.load``.

        Returns:
            The waveform tensor from ``torchaudio.load`` (channels x samples),
            repeated along the sample axis when its duration is below
            ``self.min_duration`` seconds.

        Raises:
            pb_utils.TritonModelException: if the file cannot be loaded or
                contains no samples.
        """
        try:
            waveform, sample_rate = torchaudio.load(audio_path)
        except Exception as exc:
            raise pb_utils.TritonModelException(
                f"failed to load audio {audio_path!r}: {exc}"
            ) from exc

        # NOTE(review): the file's own sample_rate is used as-is; nothing here
        # resamples to self.sample_rate.
        num_samples = waveform.shape[1]
        if num_samples == 0 or sample_rate <= 0:
            # Guard the division below: an empty clip cannot be tiled.
            raise pb_utils.TritonModelException(f"audio {audio_path!r} is empty")

        duration = num_samples / sample_rate
        if duration >= self.min_duration:
            return waveform

        # Too short: repeat the clip so it lasts at least min_duration seconds.
        repeat_times = math.ceil(self.min_duration / duration)
        return waveform.repeat(1, repeat_times)

    def compute_similarity(self, feats1, feats2):
        """Return the speaker similarity of two inputs, scaled to [0, 1].

        Both inputs are embedded by ``speaker_model`` and compared by cosine
        similarity; 1.0 means identical direction, 0.5 orthogonal, 0.0 opposite.
        """
        e1 = self.call_speaker_model(feats1)
        e2 = self.call_speaker_model(feats2)
        # Map cosine similarity from [-1, 1] to [0, 1].
        return (self._cosine_similarity(e1, e2) + 1.0) / 2.0

    @staticmethod
    def _cosine_similarity(e1, e2):
        """Cosine similarity of two embeddings, each flattened to 1-D.

        Flattening lets batched outputs such as shape ``(1, D)`` be compared
        directly (``np.dot`` rejects two ``(1, D)`` arrays).

        Raises:
            ValueError: if the embeddings differ in size or either is all zeros
                (cosine similarity is undefined there).
        """
        v1 = np.asarray(e1, dtype=np.float64).reshape(-1)
        v2 = np.asarray(e2, dtype=np.float64).reshape(-1)
        if v1.shape != v2.shape:
            raise ValueError(
                f"embedding size mismatch: {v1.shape[0]} vs {v2.shape[0]}"
            )
        norm_product = np.linalg.norm(v1) * np.linalg.norm(v2)
        if norm_product == 0.0:
            raise ValueError("cosine similarity is undefined for a zero embedding")
        cosine = float(np.dot(v1, v2) / norm_product)
        # Clamp floating-point rounding so the result stays within [-1, 1].
        return min(1.0, max(-1.0, cosine))

    def call_speaker_model(self, features):
        """Run ``speaker_model`` on ``features`` and return its ``embs`` output.

        Args:
            features: array-like input (a NumPy array or a CPU torch tensor such
                as the waveform from ``preprocess``), sent as float32 ``feats``.

        Returns:
            The ``embs`` output as a NumPy array.

        Raises:
            pb_utils.TritonModelException: if the inner request fails or the
                response has no ``embs`` output.
        """
        # np.asarray accepts both ndarrays and CPU torch tensors; torch tensors
        # (what torchaudio.load returns) have no .astype method.
        input_tensor = pb_utils.Tensor("feats", np.asarray(features, dtype=np.float32))
        inference_request = pb_utils.InferenceRequest(
            model_name=self.speaker_model_name,
            requested_output_names=["embs"],
            inputs=[input_tensor],
        )
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        output_tensor = pb_utils.get_output_tensor_by_name(inference_response, "embs")
        if output_tensor is None:
            raise pb_utils.TritonModelException(
                f"{self.speaker_model_name} returned no 'embs' output"
            )
        return output_tensor.as_numpy()