File size: 6,766 Bytes
29c0409
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import io
import math
import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import traceback
from torch.utils.dlpack import from_dlpack
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def initialize(self, args):
        self.sample_rate = 16000
        self.feature_dim = 80
        self.vad_enabled = True # This variable is declared but not used.
        self.min_duration = 0.1
        
        # This seems correct for BLS (Business Logic Scripting)
        self.speaker_model_name = "speaker_model"

    def execute(self, requests):
        """Score the speaker similarity of the two audio inputs of each request.

        For every request the "AUDIO_BYTES_1" and "AUDIO_BYTES_2" inputs are
        read, turned into features by preprocess(), and compared with
        compute_similarity(). The score is returned as a one-element float32
        "SIMILARITY" output tensor.

        A failure while handling one request is logged and converted into an
        error response for that request only; the remaining requests are
        still processed.

        Args:
            requests: Iterable of inference requests to handle.

        Returns:
            list: Exactly one InferenceResponse per request, in request order.
        """
        responses = []
        for request in requests:
            try:
                # 1. Fetch the first element of each audio input.
                payloads = []
                for input_name in ("AUDIO_BYTES_1", "AUDIO_BYTES_2"):
                    tensor = pb_utils.get_input_tensor_by_name(request, input_name)
                    if tensor is None:
                        raise pb_utils.TritonModelException(
                            f"Missing required input tensor: {input_name}"
                        )
                    values = tensor.as_numpy()
                    if values.size == 0:
                        raise pb_utils.TritonModelException(
                            f"Input tensor {input_name} is empty"
                        )
                    # .flat[0] yields the first element whatever the rank of
                    # the array; chained [0][0] would index *into* the bytes
                    # object when the tensor is one-dimensional.
                    payloads.append(values.flat[0])

                # 2. Turn both payloads into feature tensors.
                feats1 = self.preprocess(payloads[0])
                feats2 = self.preprocess(payloads[1])

                # 3. Embed both clips via the speaker model and compare them.
                similarity = self.compute_similarity(feats1, feats2)

                output_tensor = pb_utils.Tensor(
                    "SIMILARITY", np.array([similarity], dtype=np.float32)
                )
                responses.append(
                    pb_utils.InferenceResponse(output_tensors=[output_tensor])
                )

            except pb_utils.TritonModelException as e:
                # Log the message text: handing the response object itself to
                # the logger raises inside this handler and would abort the
                # whole batch instead of failing just this request.
                error_message = str(e)
                pb_utils.Logger.log_error(error_message)
                responses.append(
                    pb_utils.InferenceResponse(
                        output_tensors=[],
                        error=pb_utils.TritonError(error_message),
                    )
                )
            except Exception as e:
                # Last-resort boundary: anything unexpected is reported with
                # its traceback rather than escaping execute().
                error_message = f"Unexpected error: {e}\n{traceback.format_exc()}"
                pb_utils.Logger.log_error(error_message)
                responses.append(
                    pb_utils.InferenceResponse(
                        output_tensors=[],
                        error=pb_utils.TritonError(error_message),
                    )
                )

        return responses

    def preprocess(self, audio_bytes: bytes):
        """Decode one audio clip and return its Kaldi-style filterbank features.

        The clip is downmixed to mono, resampled to ``self.sample_rate`` when
        needed, and repeated end-to-end until it lasts at least
        ``self.min_duration`` before features are computed.

        Args:
            audio_bytes: Either the encoded audio file content itself (for
                example the bytes of a WAV file) or the UTF-8 encoded path of
                an existing audio file on the server. The path form is kept
                for callers that already send paths.

        Returns:
            torch.Tensor: 2-D tensor of shape ``[num_frames, self.feature_dim]``.

        Raises:
            pb_utils.TritonModelException: If the audio cannot be loaded, has
                no samples, or feature extraction fails.
        """
        import os  # function-scope import: only needed for the path check

        try:
            # Decide whether the payload names a file or *is* the file.
            try:
                candidate_path = bytes(audio_bytes).decode("utf-8")
            except UnicodeDecodeError:
                candidate_path = None

            if (
                candidate_path
                and "\x00" not in candidate_path
                and os.path.isfile(candidate_path)
            ):
                # NOTE(security): a client-supplied path lets callers make the
                # server open any file it can read; expose this model only to
                # trusted clients or restrict paths to a known directory.
                source = candidate_path
            else:
                # Wrap the raw bytes in a file-like object for torchaudio.
                source = io.BytesIO(audio_bytes)

            waveform, sample_rate = torchaudio.load(source)

            # fbank works on a single channel, so average multi-channel audio
            # down to shape [1, T].
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)

            if sample_rate != self.sample_rate:
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=self.sample_rate
                )
                waveform = resampler(waveform)

            num_samples = waveform.shape[1]
            if num_samples == 0:
                # Without this guard the duration maths below divides by zero.
                raise ValueError("audio contains no samples")

            duration = num_samples / self.sample_rate
            if duration < self.min_duration:
                # Too short: tile the clip along time to reach the minimum.
                repeat_times = math.ceil(self.min_duration / duration)
                waveform = waveform.repeat(1, repeat_times)

            # kaldi.fbank takes a [1, T] waveform and already returns a 2-D
            # [num_frames, num_mel_bins] tensor, so no squeeze is needed.
            return kaldi.fbank(
                waveform,
                num_mel_bins=self.feature_dim,
                sample_frequency=self.sample_rate,
                frame_length=25,
                frame_shift=10,
            )

        except Exception as e:
            # Re-raise as the Triton exception type that execute() handles,
            # keeping the original exception as the cause.
            raise pb_utils.TritonModelException(
                f"Failed during audio preprocessing: {e}"
            ) from e
    
    def compute_similarity(self, waveform1, waveform2):
        # Call speaker_model to get embeddings
        # Assuming speaker_model takes a waveform and outputs an embedding
        e1 = torch.from_numpy(self.call_speaker_model(waveform1)).to("cuda")
        e2 = torch.from_numpy(self.call_speaker_model(waveform2)).to("cuda")
        
        # Flatten the tensors
        e1 = e1.flatten()
        e2 = e2.flatten()
        
        # Calculate cosine similarity
        dot_product = torch.dot(e1, e2)
        norm_e1 = torch.norm(e1)
        norm_e2 = torch.norm(e2)
        
        # Handle zero norms
        if norm_e1 == 0 or norm_e2 == 0:
            return 0.0
        
        similarity = (dot_product / (norm_e1 * norm_e2)).item()
        
        # Normalize from [-1, 1] to [0, 1]
        return (similarity + 1) / 2
    
    def call_speaker_model(self, waveform):
        """Run the speaker model on one input tensor and return its "embs" output.

        Args:
            waveform: torch tensor sent as the model's "feats" input (cast to
                float32). A 2-D tensor gets a leading axis of size 1 added;
                tensors of any other rank are sent as they are.

        Returns:
            numpy.ndarray: The "embs" output, copied to host memory if the
            model produced it on another device.

        Raises:
            pb_utils.TritonModelException: If the request reports an error or
                the response contains no "embs" output.
        """
        if waveform.dim() == 2:
            waveform = waveform.unsqueeze(0)

        # detach() keeps .numpy() working even if the tensor tracks gradients.
        # The input/output names must match the speaker model's configuration.
        feats = waveform.detach().cpu().numpy().astype(np.float32)
        inference_request = pb_utils.InferenceRequest(
            model_name=self.speaker_model_name,
            requested_output_names=["embs"],
            inputs=[pb_utils.Tensor("feats", feats)],
        )

        inference_response = inference_request.exec()

        if inference_response.has_error():
            raise pb_utils.TritonModelException(
                f"Error from speaker_model: {inference_response.error().message()}"
            )

        output_tensor = pb_utils.get_output_tensor_by_name(inference_response, "embs")
        if output_tensor is None:
            # The lookup returns None for an absent output; fail with a clear
            # message instead of an AttributeError further down.
            raise pb_utils.TritonModelException(
                "speaker_model response does not contain the 'embs' output"
            )

        if output_tensor.is_cpu():
            return output_tensor.as_numpy()
        # Non-CPU output: go through DLPack and copy it back to host memory.
        return from_dlpack(output_tensor.to_dlpack()).detach().cpu().numpy()