Audio Classification
PyTorch
Safetensors
whisper
biology
File size: 3,631 Bytes
071b3bf
 
 
 
 
 
 
 
 
3f85019
071b3bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f85019
 
 
071b3bf
 
 
 
e3b42e3
 
071b3bf
 
e3b42e3
071b3bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3b42e3
 
071b3bf
 
e3b42e3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""
handler.py
Custom handler that lets this model be served by a Hugging Face Inference Endpoint.
"""
from typing import Dict, Any
import torch
import torchaudio
from transformers import WhisperForAudioClassification, WhisperFeatureExtractor
import numpy as np
import base64

class EndpointHandler:
    """
    Wrap the ``DORI-SRKW/whisper-base-mm`` Whisper audio classifier for a
    Hugging Face inference endpoint.

    The handler returns a JSON-serialisable dict. Preprocessing is fixed in
    this class: resample to 32 kHz, apply a 1 kHz high-pass filter, then cut
    into 15 s segments. Only the decision threshold is configurable.
    """

    # Rate (Hz) the incoming audio is resampled to before feature extraction.
    _RESAMPLE_RATE = 32000
    # Length of each analysis window, in seconds of resampled audio.
    _SEGMENT_SECONDS = 15
    # Threshold used when the constructor's argument is not numeric.
    _DEFAULT_THRESHOLD = 0.5

    def __init__(self, threshold=0.5):
        """
        Load the classifier and its feature extractor onto the available device.

        Args:
            threshold: probability at or above which a clip is reported as
                "present". A value that cannot be converted to float falls back
                to 0.5 with a warning. NOTE(review): the Inference Endpoints
                toolkit appears to instantiate custom handlers as
                ``EndpointHandler(model_dir)``, which would put the model
                directory here. Confirm against the toolkit version in use.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # float16 weights on GPU, float32 on CPU.
        torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        model_id = 'DORI-SRKW/whisper-base-mm'

        try:
            # Preferred path: safetensors weights with low_cpu_mem_usage.
            self.model = WhisperForAudioClassification.from_pretrained(
                model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
        except Exception:
            # Best-effort fallback to default loading. Presumably this covers a
            # missing safetensors file or an environment that cannot honour
            # low_cpu_mem_usage. Any error here propagates to the caller.
            self.model = WhisperForAudioClassification.from_pretrained(model_id, torch_dtype=torch_dtype)
        self.feature_extractor = WhisperFeatureExtractor.from_pretrained(model_id)

        self.model.eval()
        self.model.to(self.device)
        self.threshold = self._coerce_threshold(threshold)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Classify one base64-encoded audio clip.

        Args:
            data: request payload containing
                - "audio": base64 text of a raw float32 sample buffer, encoded by
                  the caller with ``base64.b64encode(filebytes).decode('utf-8')``;
                - "sampling_rate": sample rate of that buffer, in Hz.

        Returns:
            ``{"logit": p, "classification": c}``.
            - ``p`` (a Python float) is the highest softmax probability of
              class index 1 over all 15 s segments. Despite the key name it is
              a probability, not a raw logit.
            - ``c`` is "present" if ``p >= self.threshold``, else "absent".

        Raises:
            ValueError: if the decoded audio contains no samples.
        """
        # Undo the caller's base64 encoding to recover the raw float32 buffer.
        raw = base64.b64decode(data['audio'].encode('utf-8'))
        fs = data['sampling_rate']

        samples = np.frombuffer(raw, dtype=np.float32)
        if samples.size == 0:
            raise ValueError("decoded audio is empty")
        audio = torch.tensor(samples).reshape(1, -1)

        # Resample to 32 kHz, then high-pass at 1 kHz (Q = 0.707).
        audio = torchaudio.functional.resample(audio, orig_freq=fs, new_freq=self._RESAMPLE_RATE)
        audio = torchaudio.functional.highpass_biquad(audio, self._RESAMPLE_RATE, 1000, 0.707)

        segment_samples = self._RESAMPLE_RATE * self._SEGMENT_SECONDS
        segments = self._split_into_segments(audio[0].cpu().numpy(), segment_samples)

        # NOTE(review): the 32 kHz audio is deliberately passed as
        # sampling_rate=16000 and padded to 32000*15 = 480000 samples, the
        # length Whisper expects for a 30 s window at 16 kHz. Presumably the
        # model was fine-tuned with this same convention. Confirm before
        # changing it.
        features = self.feature_extractor(
            segments, sampling_rate=16000, padding='max_length',
            max_length=segment_samples, return_tensors='pt')
        # One row per segment along dim 0. Do not squeeze that dimension: with
        # a single segment it would disappear and break the per-segment loop.
        input_features = features['input_features'].to(self.device)

        scores = []
        with torch.inference_mode(), torch.amp.autocast(device_type=self.device):
            for segment in input_features:
                logits = self.model(segment.unsqueeze(0)).logits
                # Softmax probability of class index 1 for this segment.
                scores.append(float(torch.softmax(logits, dim=1)[0, 1]))

        return self._summarize(scores, self.threshold)

    @staticmethod
    def _coerce_threshold(threshold):
        """Return ``threshold`` as a float, or 0.5 (with a warning) if it is not numeric."""
        try:
            return float(threshold)
        except (TypeError, ValueError):
            import warnings
            warnings.warn(
                f"EndpointHandler: non-numeric threshold {threshold!r}; "
                f"using {EndpointHandler._DEFAULT_THRESHOLD}",
                stacklevel=3,
            )
            return EndpointHandler._DEFAULT_THRESHOLD

    @staticmethod
    def _split_into_segments(samples, segment_samples):
        """Cut a 1-D array into consecutive chunks of ``segment_samples``; the last may be shorter."""
        return [samples[start:start + segment_samples]
                for start in range(0, len(samples), segment_samples)]

    @staticmethod
    def _summarize(scores, threshold):
        """Reduce per-segment probabilities to the response dict; ``scores`` must be non-empty."""
        best = max(scores)
        return {'logit': best, 'classification': 'present' if best >= threshold else 'absent'}