File size: 1,134 Bytes
5db02c6
7b9e393
 
dd7709d
7f02c6d
 
955b296
7f02c6d
7b9e393
 
7f02c6d
dd7709d
7f02c6d
7b9e393
 
 
 
 
7f02c6d
7b9e393
7f02c6d
7b9e393
 
 
7f02c6d
 
7b9e393
7f02c6d
 
 
ad29e59
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from models import StudentForAudioClassification
import torch
import torchaudio

class EndpointHandler:
    """Inference endpoint: classify raw audio with a student model on wav2vec2 features.

    A wav2vec2-BASE backbone extracts frame-level features, which are
    mean-pooled over time and truncated before being fed to the student
    classifier loaded from ``model_dir``.
    """

    # Sample rate expected by the wav2vec2 feature extractor.
    _TARGET_SR = 16000
    # Number of leading feature dimensions passed to the student model.
    # NOTE(review): wav2vec2-BASE emits 768-dim features; truncating to 512
    # is assumed to match how the student was trained — TODO confirm.
    _FEATURE_DIM = 512

    def __init__(self, model_dir, *args, **kwargs):
        """Load the student classifier from ``model_dir`` and the wav2vec2 backbone.

        Args:
            model_dir: path or hub id accepted by ``from_pretrained``.
        """
        self.model = StudentForAudioClassification.from_pretrained(model_dir, trust_remote_code=True)
        self.model.eval()
        bundle = torchaudio.pipelines.WAV2VEC2_BASE
        self.w2v_model = bundle.get_model()
        self.w2v_model.eval()

    def __call__(self, data):
        """Run classification on an audio payload.

        Args:
            data: dict with key ``"inputs"`` holding the audio file content
                as raw bytes, or as a base64-encoded string (common for
                JSON-serialized endpoint requests).

        Returns:
            dict with ``"probabilities"`` (list of per-class probabilities)
            and ``"label"`` (int index of the most probable class).
        """
        import base64
        import io

        audio = data["inputs"]
        # Generalization: JSON payloads carry bytes as base64 text; decode
        # str input so both raw-bytes and base64 callers work.
        if isinstance(audio, str):
            audio = base64.b64decode(audio)

        waveform, orig_sr = torchaudio.load(io.BytesIO(audio))
        # Downmix multi-channel audio to mono — the backbone takes one channel.
        waveform = waveform.mean(dim=0, keepdim=True)
        if orig_sr != self._TARGET_SR:
            waveform = torchaudio.transforms.Resample(orig_sr, self._TARGET_SR)(waveform)

        with torch.no_grad():
            # The torchaudio wav2vec2 model returns (features, lengths);
            # take the feature tensor.
            features = self.w2v_model(waveform)[0]
            # Mean-pool over the time axis into one utterance embedding.
            x_w2v = features.mean(dim=1)
            x_w2v = x_w2v[:, :self._FEATURE_DIM]
            outputs = self.model(x_w2v)
            probs = torch.softmax(outputs.logits, dim=-1)

        return {
            "probabilities": probs.squeeze(0).tolist(),
            "label": int(probs.argmax(dim=-1)[0]),
        }