import logging
import math
import traceback

import numpy as np
import torchaudio
import triton_python_backend_utils as pb_utils

logger = logging.getLogger(__name__)


class TritonPythonModel:
    """Triton Python-backend model that scores how similar two audio files are.

    Each request carries two audio file paths (``AUDIO1`` and ``AUDIO2``).
    Both files are loaded, padded by repetition when shorter than
    ``min_duration``, sent to the model named by ``speaker_model_name`` to
    obtain embeddings, and the cosine similarity of the two embeddings is
    returned in the ``SIMILARITY`` output, rescaled to ``[0, 1]``.
    """

    def initialize(self, args):
        """Set the model's configuration attributes.

        Args:
            args: Initialization arguments supplied by Triton (not read here).
        """
        # NOTE(review): sample_rate, feature_dim and vad_enabled are stored
        # but not read anywhere else in this class -- confirm whether
        # resampling / feature extraction / VAD were meant to be applied.
        self.sample_rate = 16000
        self.feature_dim = 80
        self.vad_enabled = True
        # Minimum audio length in seconds; shorter clips are repeated.
        self.min_duration = 0.1

        # Name of the model that is queried for speaker embeddings.
        self.speaker_model_name = "speaker_model"

    def execute(self, requests):
        """Compute one similarity score per request.

        A failure while handling one request produces an error response for
        that request only; the remaining requests are still processed.

        Args:
            requests: Iterable of ``pb_utils.InferenceRequest`` objects, each
                providing the ``AUDIO1`` and ``AUDIO2`` inputs.

        Returns:
            A list of ``pb_utils.InferenceResponse`` objects, one per request
            and in the same order.
        """
        responses = []
        for request in requests:
            try:
                # Read the two audio paths from the request.
                audio1 = self._read_path(request, "AUDIO1")
                audio2 = self._read_path(request, "AUDIO2")

                # Load and, if necessary, lengthen the audio.
                feats1 = self.preprocess(audio1)
                feats2 = self.preprocess(audio2)
                if feats1 is None or feats2 is None:
                    raise pb_utils.TritonModelException(
                        "failed to load or preprocess input audio"
                    )

                # Query the speaker model and compare the embeddings.
                similarity = self.compute_similarity(feats1, feats2)

                # pb_utils.Tensor takes only (name, array); the dtype is
                # set on the NumPy array itself.
                output_tensor = pb_utils.Tensor(
                    "SIMILARITY", np.array([similarity], dtype=np.float32)
                )
                response = pb_utils.InferenceResponse(
                    output_tensors=[output_tensor]
                )
            except Exception as exc:  # per-request boundary: report, don't crash
                logger.error(
                    "Request failed: %s\n%s", exc, traceback.format_exc()
                )
                response = pb_utils.InferenceResponse(
                    output_tensors=[], error=pb_utils.TritonError(str(exc))
                )
            responses.append(response)

        return responses

    @staticmethod
    def _read_path(request, name):
        """Return the first element of input tensor *name* as a ``str``."""
        tensor = pb_utils.get_input_tensor_by_name(request, name)
        if tensor is None:
            raise pb_utils.TritonModelException(
                f"missing required input tensor '{name}'"
            )
        # ravel() makes this work for both [1] and [1, 1] shaped inputs.
        value = tensor.as_numpy().ravel()[0]
        if isinstance(value, bytes):
            return value.decode("utf-8")
        return str(value)

    def preprocess(self, audio_path):
        """Load an audio file and repeat it until it reaches ``min_duration``.

        Args:
            audio_path: Path of the audio file to load with ``torchaudio``.

        Returns:
            The loaded waveform tensor, tiled along its second dimension when
            the clip is shorter than ``self.min_duration`` seconds, or
            ``None`` if the file could not be loaded or contains no samples
            (the failure is logged).
        """
        try:
            waveform, sample_rate = torchaudio.load(audio_path)
            num_samples = waveform.shape[1]
            if num_samples == 0:
                # An empty clip cannot be lengthened by repetition.
                raise ValueError(
                    f"audio file contains no samples: {audio_path}"
                )

            duration = num_samples / sample_rate
            if duration >= self.min_duration:
                # Long enough already; use the waveform unchanged.
                return waveform

            # Too short: repeat the clip enough times to reach min_duration.
            repeat_times = math.ceil(self.min_duration / duration)
            return waveform.repeat(1, repeat_times)
        except Exception:
            logger.error(
                "Failed to preprocess %r:\n%s",
                audio_path,
                traceback.format_exc(),
            )
            return None

    def compute_similarity(self, feats1, feats2):
        """Return the cosine similarity of two inputs' embeddings in [0, 1].

        Args:
            feats1: Model input for the first audio clip.
            feats2: Model input for the second audio clip.

        Returns:
            ``(cos + 1) / 2`` as a Python float, where ``cos`` is the cosine
            similarity of the two embeddings. If either embedding has zero
            norm the cosine is treated as 0, giving 0.5.
        """
        # Flatten so that a leading batch dimension such as (1, D) does not
        # break the dot product.
        e1 = np.asarray(self.call_speaker_model(feats1), dtype=np.float64).ravel()
        e2 = np.asarray(self.call_speaker_model(feats2), dtype=np.float64).ravel()

        denominator = np.linalg.norm(e1) * np.linalg.norm(e2)
        if denominator == 0:
            similarity = 0.0
        else:
            # Clip to guard against tiny floating-point overshoot.
            similarity = float(np.clip(np.dot(e1, e2) / denominator, -1.0, 1.0))

        # Rescale from [-1, 1] to [0, 1].
        return (similarity + 1) / 2

    def call_speaker_model(self, features):
        """Run the speaker model on *features* and return its embedding.

        Args:
            features: A NumPy array or a tensor object exposing
                ``detach().cpu().numpy()`` (as returned by ``preprocess``).

        Returns:
            The ``embs`` output of the speaker model as a NumPy array.

        Raises:
            pb_utils.TritonModelException: If the inference response reports
                an error.
        """
        # preprocess() returns a tensor from torchaudio, which has no
        # .astype(); convert it to NumPy before building the input.
        if hasattr(features, "detach"):
            features = features.detach().cpu().numpy()
        feats_array = np.ascontiguousarray(features, dtype=np.float32)

        # Build the input tensor and the inference request.
        input_tensor = pb_utils.Tensor("feats", feats_array)
        inference_request = pb_utils.InferenceRequest(
            model_name=self.speaker_model_name,
            requested_output_names=["embs"],
            inputs=[input_tensor],
        )

        # Send the request and wait for the result.
        inference_response = inference_request.exec()

        if inference_response.has_error():
            raise pb_utils.TritonModelException(
                inference_response.error().message()
            )

        # Extract the embedding from the response.
        output_tensor = pb_utils.get_output_tensor_by_name(
            inference_response, "embs"
        )
        return output_tensor.as_numpy()