import io
import math
import traceback

import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from torch.utils.dlpack import from_dlpack

import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    """Triton Python-backend model that scores the similarity of two audio clips.

    For every request the model reads two audio inputs (``AUDIO_BYTES_1`` and
    ``AUDIO_BYTES_2``), converts each one to filter-bank features, obtains an
    embedding for each from a second Triton model through a BLS call, and
    returns the cosine similarity rescaled to ``[0, 1]`` in the ``SIMILARITY``
    output tensor.
    """

    def initialize(self, args):
        """Set the fixed preprocessing parameters used by this model.

        Args:
            args: Model configuration supplied by the Triton Python backend.
                It is not read here.
        """
        # Sample rate (Hz) every clip is resampled to before feature extraction.
        self.sample_rate = 16000
        # Number of mel bins requested from kaldi.fbank.
        self.feature_dim = 80
        # Declared but not read anywhere in this class.
        self.vad_enabled = True
        # Minimum clip length in seconds; shorter clips are repeated to reach it.
        self.min_duration = 0.1
        # Name of the Triton model that is called via BLS to produce embeddings.
        self.speaker_model_name = "speaker_model"

    def execute(self, requests):
        """Compute one similarity score per request.

        Args:
            requests: List of ``pb_utils.InferenceRequest`` objects. Each must
                carry the string tensors ``AUDIO_BYTES_1`` and ``AUDIO_BYTES_2``.

        Returns:
            A list with exactly one ``pb_utils.InferenceResponse`` per request,
            in request order. A failing request yields an error response and
            does not affect the other requests in the batch.
        """
        responses = []
        for request in requests:
            try:
                # 1. Fetch the first element of each TYPE_STRING input tensor.
                audio1_bytes = self._read_audio_input(request, "AUDIO_BYTES_1")
                audio2_bytes = self._read_audio_input(request, "AUDIO_BYTES_2")

                # 2. Turn each clip into filter-bank features.
                feats1 = self.preprocess(audio1_bytes)
                feats2 = self.preprocess(audio2_bytes)

                # 3. Embed both clips with the speaker model and compare them.
                similarity = self.compute_similarity(feats1, feats2)

                output_tensor = pb_utils.Tensor(
                    "SIMILARITY", np.array([similarity], dtype=np.float32)
                )
                responses.append(
                    pb_utils.InferenceResponse(output_tensors=[output_tensor])
                )
            except pb_utils.TritonModelException as e:
                # Log the message text: the logger takes a string, not a
                # response object.
                error_message = str(e)
                pb_utils.Logger.log_error(error_message)
                responses.append(self._error_response(error_message))
            except Exception as e:
                # Any other failure is logged with its traceback and reported
                # to the client for this request only.
                error_message = f"Unexpected error: {e}\n{traceback.format_exc()}"
                pb_utils.Logger.log_error(error_message)
                responses.append(self._error_response(error_message))
        return responses

    @staticmethod
    def _error_response(message):
        """Build an error-only inference response carrying ``message``."""
        return pb_utils.InferenceResponse(
            output_tensors=[], error=pb_utils.TritonError(message)
        )

    @staticmethod
    def _read_audio_input(request, name):
        """Return the first element of the input tensor ``name``.

        Raises:
            pb_utils.TritonModelException: If the tensor is absent or empty.
        """
        tensor = pb_utils.get_input_tensor_by_name(request, name)
        if tensor is None:
            raise pb_utils.TritonModelException(
                f"Missing required input tensor '{name}'"
            )
        values = tensor.as_numpy().reshape(-1)
        if values.size == 0:
            raise pb_utils.TritonModelException(f"Input tensor '{name}' is empty")
        return values[0]

    @staticmethod
    def _open_audio_source(audio_bytes):
        """Return something ``torchaudio.load`` can read for ``audio_bytes``.

        Payloads that are clean UTF-8 text without NUL bytes are treated as a
        path/URI string, exactly as before. Anything else is treated as encoded
        audio content and wrapped in an in-memory file object.
        """
        if isinstance(audio_bytes, str):
            return audio_bytes
        try:
            text = audio_bytes.decode("utf-8")
        except UnicodeDecodeError:
            return io.BytesIO(audio_bytes)
        if not text or "\x00" in text:
            # A path cannot contain NUL; this is binary audio that happened to
            # decode (an empty payload is left for the loader to reject).
            return io.BytesIO(audio_bytes)
        # NOTE(security): this lets a client make the server open any audio
        # file it can read. Restrict or remove the path form if clients are
        # untrusted.
        return text

    def preprocess(self, audio_bytes: bytes):
        """Convert one audio input into filter-bank features.

        The audio is loaded, downmixed to mono, resampled to
        ``self.sample_rate`` when needed, and repeated until it lasts at least
        ``self.min_duration`` seconds before features are extracted.

        Args:
            audio_bytes: Either encoded audio content or a UTF-8 encoded
                path/URI pointing at an audio file.

        Returns:
            A 2-D tensor of fbank features with ``self.feature_dim`` columns.

        Raises:
            pb_utils.TritonModelException: On any failure while loading or
                processing the audio, including audio with no samples.
        """
        try:
            source = self._open_audio_source(audio_bytes)
            waveform, sample_rate = torchaudio.load(source)

            # kaldi.fbank works on a single channel, so average multi-channel
            # audio into one.
            if waveform.dim() == 2 and waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)

            if sample_rate != self.sample_rate:
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=self.sample_rate
                )
                waveform = resampler(waveform)

            num_samples = waveform.shape[1]
            if num_samples == 0:
                # Without this guard the duration ratio below divides by zero.
                raise ValueError("audio contains no samples")

            duration = num_samples / self.sample_rate
            if duration < self.min_duration:
                # Too short: tile the clip along time until it is long enough.
                repeat_times = math.ceil(self.min_duration / duration)
                waveform = waveform.repeat(1, repeat_times)

            # frame_length and frame_shift are in milliseconds. The result is
            # already 2-D, so it is returned unchanged.
            return kaldi.fbank(
                waveform,
                num_mel_bins=self.feature_dim,
                sample_frequency=self.sample_rate,
                frame_length=25,
                frame_shift=10,
            )
        except Exception as e:
            # Re-raise as the Triton exception type that execute() handles.
            raise pb_utils.TritonModelException(
                f"Failed during audio preprocessing: {e}"
            ) from e

    def compute_similarity(self, waveform1, waveform2):
        """Return the similarity of two clips as a float in ``[0, 1]``.

        Args:
            waveform1: Features of the first clip as returned by ``preprocess``
                (the parameter name is historical).
            waveform2: Features of the second clip as returned by ``preprocess``.

        Returns:
            The cosine similarity of the two embeddings mapped from ``[-1, 1]``
            to ``[0, 1]``, or ``0.0`` when either embedding has zero norm.
        """
        # The embeddings arrive as NumPy arrays and the comparison is a single
        # dot product, so it runs on the CPU and works without a GPU. float32
        # keeps the arithmetic independent of the model's output dtype.
        e1 = torch.from_numpy(self.call_speaker_model(waveform1)).float().flatten()
        e2 = torch.from_numpy(self.call_speaker_model(waveform2)).float().flatten()

        norm_e1 = torch.norm(e1)
        norm_e2 = torch.norm(e2)
        if norm_e1 == 0 or norm_e2 == 0:
            return 0.0

        similarity = (torch.dot(e1, e2) / (norm_e1 * norm_e2)).item()
        # Rounding can push the cosine slightly outside [-1, 1]; clamp so the
        # rescaled score stays inside [0, 1].
        similarity = max(-1.0, min(1.0, similarity))

        return (similarity + 1) / 2

    def call_speaker_model(self, waveform):
        """Run the speaker model on one feature tensor and return its embedding.

        Args:
            waveform: Feature tensor; a 2-D tensor gets a leading batch
                dimension added before it is sent.

        Returns:
            The ``embs`` output of the speaker model as a NumPy array.

        Raises:
            pb_utils.TritonModelException: If the BLS call reports an error or
                the response has no ``embs`` output.
        """
        if waveform.dim() == 2:
            waveform = waveform.unsqueeze(0)

        # "feats" and "embs" must match the tensor names declared in the
        # speaker model's config.pbtxt.
        input_tensor = pb_utils.Tensor(
            "feats", waveform.cpu().numpy().astype(np.float32)
        )
        inference_request = pb_utils.InferenceRequest(
            model_name=self.speaker_model_name,
            requested_output_names=["embs"],
            inputs=[input_tensor],
        )

        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(
                f"Error from speaker_model: {inference_response.error().message()}"
            )

        output_tensor = pb_utils.get_output_tensor_by_name(inference_response, "embs")
        if output_tensor is None:
            raise pb_utils.TritonModelException(
                "Error from speaker_model: output tensor 'embs' is missing"
            )

        if output_tensor.is_cpu():
            return output_tensor.as_numpy()
        # GPU-resident output: go through DLPack and copy to host memory.
        return from_dlpack(output_tensor.to_dlpack()).detach().cpu().numpy()