File size: 6,766 Bytes
29c0409 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import io
import math
import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import traceback
from torch.utils.dlpack import from_dlpack
import triton_python_backend_utils as pb_utils
class TritonPythonModel:
    """Triton BLS model: speaker similarity between two audio clips.

    Receives two audio files as raw bytes (TYPE_STRING tensors), extracts
    80-dim Fbank features, forwards each to the ``speaker_model`` via a BLS
    inference request to obtain embeddings, and returns their cosine
    similarity normalized from [-1, 1] to [0, 1].
    """

    def initialize(self, args):
        """Called once at model load; set up configuration constants."""
        self.sample_rate = 16000   # target sample rate (Hz); inputs are resampled to this
        self.feature_dim = 80      # number of Fbank mel bins the speaker model expects
        self.min_duration = 0.1    # minimum clip length (s); shorter clips are tiled up
        # BLS target: must match the model-repository name of the embedding model.
        self.speaker_model_name = "speaker_model"

    def execute(self, requests):
        """Handle a batch of requests.

        Each request carries two TYPE_STRING tensors, AUDIO_BYTES_1 and
        AUDIO_BYTES_2, holding raw encoded audio bytes. Produces one
        SIMILARITY float32 tensor per request; errors become per-request
        error responses rather than failing the whole batch.
        """
        responses = []
        for request in requests:
            try:
                # TYPE_STRING tensors hold bytes; [0][0] unwraps the
                # [1, 1]-shaped numpy object array to the raw bytes object.
                audio1_bytes = pb_utils.get_input_tensor_by_name(request, "AUDIO_BYTES_1").as_numpy()[0][0]
                audio2_bytes = pb_utils.get_input_tensor_by_name(request, "AUDIO_BYTES_2").as_numpy()[0][0]
                feats1 = self.preprocess(audio1_bytes)
                feats2 = self.preprocess(audio2_bytes)
                similarity = self.compute_similarity(feats1, feats2)
                output_tensor = pb_utils.Tensor("SIMILARITY", np.array([similarity], dtype=np.float32))
                responses.append(pb_utils.InferenceResponse(output_tensors=[output_tensor]))
            except pb_utils.TritonModelException as e:
                # BUG FIX: log the error message, not the InferenceResponse object.
                error_message = str(e)
                pb_utils.Logger.log_error(error_message)
                responses.append(pb_utils.InferenceResponse(error=pb_utils.TritonError(error_message)))
            except Exception as e:
                # Unexpected failure: log full traceback and return it to the client.
                error_message = f"Unexpected error: {e}\n{traceback.format_exc()}"
                pb_utils.Logger.log_error(error_message)
                responses.append(pb_utils.InferenceResponse(error=pb_utils.TritonError(error_message)))
        return responses

    def preprocess(self, audio_bytes: bytes):
        """Decode in-memory audio bytes and compute Fbank features.

        Args:
            audio_bytes: raw encoded audio (e.g. a whole WAV file) as bytes.

        Returns:
            torch.Tensor of shape [num_frames, feature_dim].

        Raises:
            pb_utils.TritonModelException: on any decode/resample/feature failure.
        """
        try:
            # BUG FIX: the payload is binary encoded audio, not UTF-8 text.
            # decode('utf-8') raised on real audio data; wrap the bytes in a
            # file-like object so torchaudio can decode them in memory.
            buffer = io.BytesIO(audio_bytes)
            waveform, sample_rate = torchaudio.load(buffer)
            if waveform.numel() == 0:
                # Guard: an empty clip would cause a ZeroDivisionError below.
                raise ValueError("decoded waveform is empty")
            # Keep only the first channel so stereo input doesn't change the
            # feature layout fed to the speaker model.
            if waveform.shape[0] > 1:
                waveform = waveform[:1, :]
            if sample_rate != self.sample_rate:
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=self.sample_rate
                )
                waveform = resampler(waveform)
            duration = waveform.shape[1] / self.sample_rate
            if duration < self.min_duration:
                # Too short: tile the clip until it meets the minimum duration.
                repeat_times = math.ceil(self.min_duration / duration)
                waveform = waveform.repeat(1, repeat_times)
            # kaldi.fbank takes [channels, time] and returns a 2-D tensor of
            # shape [num_frames, num_mel_bins] — no batch dim, so the old
            # squeeze(0) was a no-op and its shape comment was wrong.
            features = kaldi.fbank(
                waveform,
                num_mel_bins=self.feature_dim,
                sample_frequency=self.sample_rate,
                frame_length=25,   # 25 ms analysis window
                frame_shift=10,    # 10 ms hop
            )
            return features  # [num_frames, feature_dim]
        except Exception as e:
            # Re-raise as a Triton error so execute() can report it per-request.
            raise pb_utils.TritonModelException(f"Failed during audio preprocessing: {e}")

    def compute_similarity(self, feats1, feats2):
        """Return cosine similarity of the two embeddings, mapped to [0, 1]."""
        # FIX: don't hard-require CUDA — fall back to CPU when no GPU is visible.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        e1 = torch.from_numpy(self.call_speaker_model(feats1)).to(device).flatten()
        e2 = torch.from_numpy(self.call_speaker_model(feats2)).to(device).flatten()
        norm_e1 = torch.norm(e1)
        norm_e2 = torch.norm(e2)
        if norm_e1 == 0 or norm_e2 == 0:
            # Degenerate (all-zero) embedding: cosine is undefined; report 0.
            return 0.0
        similarity = (torch.dot(e1, e2) / (norm_e1 * norm_e2)).item()
        # Normalize from [-1, 1] to [0, 1].
        return (similarity + 1) / 2

    def call_speaker_model(self, feats):
        """Run a BLS request against ``speaker_model``; return the embedding.

        Args:
            feats: feature tensor of shape [num_frames, feature_dim] (a batch
                dim is added) or already-batched [1, num_frames, feature_dim].

        Returns:
            numpy.ndarray embedding from the "embs" output.

        Raises:
            pb_utils.TritonModelException: if the downstream model errors.
        """
        if feats.dim() == 2:
            feats = feats.unsqueeze(0)  # -> [1, num_frames, feature_dim]
        # "feats" / "embs" must match the I/O names in speaker_model's config.pbtxt.
        input_tensor = pb_utils.Tensor("feats", feats.cpu().numpy().astype(np.float32))
        inference_request = pb_utils.InferenceRequest(
            model_name=self.speaker_model_name,
            requested_output_names=["embs"],
            inputs=[input_tensor],
        )
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(
                f"Error from speaker_model: {inference_response.error().message()}"
            )
        output_tensor = pb_utils.get_output_tensor_by_name(inference_response, "embs")
        if output_tensor.is_cpu():
            return output_tensor.as_numpy()
        # GPU-resident output: zero-copy via DLPack, then move to host memory.
        return from_dlpack(output_tensor.to_dlpack()).detach().cpu().numpy()