|
|
import math |
|
|
import numpy as np |
|
|
import torchaudio |
|
|
import traceback |
|
|
import triton_python_backend_utils as pb_utils |
|
|
|
|
|
|
|
|
class TritonPythonModel:
    """Triton Python-backend model that scores whether two recordings share a speaker.

    Each request carries two audio file paths (inputs "AUDIO1" and "AUDIO2").
    Both files are loaded, tiled up to ``min_duration`` seconds, embedded by
    ``speaker_model`` through a BLS call, and compared with cosine similarity
    rescaled to [0, 1] (output "SIMILARITY", float32).
    """

    def initialize(self, args):
        """Set model constants. ``args`` (Triton's model config dict) is not used."""
        # NOTE(review): sample_rate, feature_dim and vad_enabled are never read in
        # this class -- audio is not resampled, not converted to filterbank
        # features and not VAD-filtered. Confirm what speaker_model expects.
        self.sample_rate = 16000
        self.feature_dim = 80
        self.vad_enabled = True
        # Clips shorter than this many seconds are tiled until they reach it.
        self.min_duration = 0.1
        # Triton model invoked via BLS to turn audio into speaker embeddings.
        self.speaker_model_name = "speaker_model"

    def execute(self, requests):
        """Return one InferenceResponse per request, in request order.

        A failure while handling one request (missing input, unreadable audio,
        speaker-model error) is reported as an error response for that request
        only, so the rest of the batch still succeeds.
        """
        responses = []
        for request in requests:
            try:
                response = self._process_request(request)
            except Exception as err:  # top-level boundary: report, don't crash the batch
                print(traceback.format_exc(), flush=True)
                response = pb_utils.InferenceResponse(
                    output_tensors=[],
                    error=pb_utils.TritonError(f"speaker similarity failed: {err}"),
                )
            responses.append(response)
        return responses

    def _process_request(self, request):
        """Compute the SIMILARITY output for a single request; raises on failure."""
        audio1 = self._read_path(request, "AUDIO1")
        audio2 = self._read_path(request, "AUDIO2")

        feats1 = self.preprocess(audio1)
        if feats1 is None:
            raise ValueError(f"failed to load audio '{audio1}'")
        feats2 = self.preprocess(audio2)
        if feats2 is None:
            raise ValueError(f"failed to load audio '{audio2}'")

        similarity = self.compute_similarity(feats1, feats2)

        # pb_utils.Tensor takes (name, ndarray) only; the dtype belongs on the array.
        output_tensor = pb_utils.Tensor(
            "SIMILARITY", np.array([similarity], dtype=np.float32)
        )
        return pb_utils.InferenceResponse(output_tensors=[output_tensor])

    def _read_path(self, request, name):
        """Fetch input tensor ``name`` from ``request`` and decode it to a path string."""
        tensor = pb_utils.get_input_tensor_by_name(request, name)
        if tensor is None:
            raise ValueError(f"missing required input tensor '{name}'")
        return self._decode_path(tensor.as_numpy())

    @staticmethod
    def _decode_path(array):
        """Return the first element of ``array`` as a ``str``.

        Accepts any shape (e.g. ``[1]`` or ``[1, 1]``) and either bytes or str
        elements; bytes are decoded as UTF-8.
        """
        flat = np.asarray(array).reshape(-1)
        if flat.size == 0:
            raise ValueError("empty audio path tensor")
        value = flat[0]
        if isinstance(value, bytes):  # np.bytes_ is a bytes subclass
            return value.decode("utf-8")
        return str(value)

    def preprocess(self, audio_path):
        """Load an audio file, tiling it if it is shorter than ``self.min_duration``.

        Args:
            audio_path: Path of the audio file to load with ``torchaudio.load``.

        Returns:
            The waveform tensor from ``torchaudio.load`` (channels on the first
            axis, samples on the second). Short clips are repeated along the
            sample axis until they are at least ``min_duration`` seconds long.
            Returns ``None`` if the file cannot be loaded or contains no samples.
        """
        try:
            waveform, file_rate = torchaudio.load(audio_path)
        except Exception:
            print(f"failed to load audio '{audio_path}':\n{traceback.format_exc()}", flush=True)
            return None

        num_samples = waveform.shape[1]
        if num_samples == 0:
            # Previously surfaced as a swallowed ZeroDivisionError below.
            print(f"audio '{audio_path}' contains no samples", flush=True)
            return None

        # Duration uses the file's own rate; no resampling to self.sample_rate happens here.
        duration = num_samples / file_rate
        if duration >= self.min_duration:
            return waveform

        repeat_times = math.ceil(self.min_duration / duration)
        return waveform.repeat(1, repeat_times)

    def compute_similarity(self, feats1, feats2):
        """Embed both inputs with the speaker model and return their similarity in [0, 1]."""
        e1 = self.call_speaker_model(feats1)
        e2 = self.call_speaker_model(feats2)
        return self._cosine_score(e1, e2)

    @staticmethod
    def _cosine_score(e1, e2, eps=1e-12):
        """Return cosine similarity of two embeddings mapped from [-1, 1] to [0, 1].

        Embeddings are flattened first, so ``(D,)`` and ``(1, D)`` outputs both
        work; this assumes each call yields a single embedding vector. A
        zero-norm embedding scores 0.5 (cosine 0) instead of producing NaN.
        """
        v1 = np.asarray(e1, dtype=np.float64).reshape(-1)
        v2 = np.asarray(e2, dtype=np.float64).reshape(-1)
        if v1.shape != v2.shape:
            raise ValueError(
                f"embedding size mismatch: {v1.shape[0]} vs {v2.shape[0]}"
            )
        denom = max(float(np.linalg.norm(v1) * np.linalg.norm(v2)), eps)
        cosine = float(np.dot(v1, v2)) / denom
        # Clip float rounding so the rescaled score stays inside [0, 1].
        cosine = min(1.0, max(-1.0, cosine))
        return (cosine + 1.0) / 2.0

    def call_speaker_model(self, features):
        """Run ``self.speaker_model_name`` via BLS and return its "embs" output as numpy.

        Raises:
            pb_utils.TritonModelException: if the speaker model reports an error.
        """
        # preprocess() returns a torch tensor, which has no .astype(); convert via
        # numpy so both torch tensors and ndarrays are accepted.
        # NOTE(review): the raw waveform is sent as "feats" even though
        # feature_dim=80 hints at filterbank input -- confirm the model's input spec.
        input_tensor = pb_utils.Tensor(
            "feats", np.ascontiguousarray(features, dtype=np.float32)
        )

        inference_request = pb_utils.InferenceRequest(
            model_name=self.speaker_model_name,
            requested_output_names=["embs"],
            inputs=[input_tensor],
        )
        inference_response = inference_request.exec()

        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        output_tensor = pb_utils.get_output_tensor_by_name(inference_response, "embs")
        return output_tensor.as_numpy()
|
|
|