File size: 6,441 Bytes
0c354cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)

import os.path
from typing import List,  Tuple

import numpy as np

from utils.utils.utils import  read_yaml
from utils.utils.frontend import WavFrontend
from utils.utils.e2e_vad import E2EVadModel
import axengine as axe

class AX_Fsmn_vad:
    """FSMN-based voice activity detection (VAD) on the axengine runtime.

    Loads the exported ``vad.axmodel`` plus the FunASR ``config.yaml`` /
    ``am.mvn`` files from ``model_dir``, then scores audio in fixed-size
    chunks of up to 6000 feature frames, threading the FSMN layer caches
    from one chunk into the next.
    """

    def __init__(self, model_dir, batch_size=1, max_end_sil=None):
        """Initialize the VAD model for inference.

        Args:
            model_dir: Directory holding ``vad.axmodel`` and the ``vad/``
                subfolder with ``config.yaml`` and ``am.mvn``.
            batch_size: Number of waveforms scored per inference batch.
            max_end_sil: Trailing-silence cutoff (ms); falls back to
                ``model_conf.max_end_silence_time`` from the config.
        """
        model_file = os.path.join(model_dir, "vad.axmodel")

        # Load config and CMVN-based feature frontend.
        config_file = os.path.join(model_dir, "vad/config.yaml")
        cmvn_file = os.path.join(model_dir, "vad/am.mvn")
        self.config = read_yaml(config_file)
        self.frontend = WavFrontend(cmvn_file=cmvn_file, **self.config["frontend_conf"])
        self.session = axe.InferenceSession(model_file)
        self.batch_size = batch_size
        self.vad_scorer = E2EVadModel(self.config["model_conf"])
        self.max_end_sil = (
            max_end_sil
            if max_end_sil is not None
            else self.config["model_conf"]["max_end_silence_time"]
        )

    def extract_feat(self, waveform_list):
        """Compute fbank + LFR/CMVN features for a batch of waveforms.

        Args:
            waveform_list: List of 1-D waveform arrays.

        Returns:
            (feats, feats_len): float32 array of shape
            (batch, max_len, dim), zero-padded to the longest item, and an
            int32 array of the un-padded frame counts.
        """
        feats, feats_len = [], []
        for waveform in waveform_list:
            speech, _ = self.frontend.fbank(waveform)
            feat, feat_len = self.frontend.lfr_cmvn(speech)
            feats.append(feat)
            feats_len.append(feat_len)

        # Right-pad every feature matrix to the batch maximum.
        max_len = max(feats_len)
        padded = [np.pad(f, ((0, max_len - f.shape[0]), (0, 0)), "constant") for f in feats]
        return (
            np.array(padded).astype(np.float32),
            np.array(feats_len).astype(np.int32),
        )

    def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
        """Run one forward pass on the axengine session.

        Args:
            feats: ``[feature_chunk, cache_0, ..., cache_N]`` tensors in the
                model's declared input order.

        Returns:
            (scores, out_caches): frame-level VAD scores and the updated
            FSMN caches to feed into the next chunk.
        """
        input_names = [inp.name for inp in self.session.get_inputs()]
        output_names = [out.name for out in self.session.get_outputs()]

        # Map tensors to model inputs positionally.
        input_dict = dict(zip(input_names, feats))

        outputs = self.session.run(output_names, input_dict)
        # First output is the score tensor; the remainder are the new caches.
        return outputs[0], outputs[1:]

    def __call__(self, wav_file, **kwargs):
        """Segment an utterance into speech regions with a sliding window.

        Args:
            wav_file: Decoded 1-D waveform array (despite the name, not a
                file path — see the commented-out loaders in history).
            **kwargs: Optional ``param_dict`` carrying ``in_cache`` to
                resume from earlier chunks; optional ``is_final`` flag.

        Returns:
            A list of ``batch_size`` entries, each a list of detected
            speech segments produced by the E2E VAD scorer.
        """
        waveform_list = [wav_file]
        waveform_nums = len(waveform_list)
        # BUG FIX: this flag was read from the non-existent "kwargs" key.
        # It is recomputed per chunk below, so it only seeds the loop.
        is_final = kwargs.get("is_final", False)
        # BUG FIX: "[[]] * n" aliases ONE shared list across all batch
        # slots; build independent accumulators instead.
        segments = [[] for _ in range(self.batch_size)]

        for beg_idx in range(0, waveform_nums, self.batch_size):
            # Fresh scorer per batch so VAD state never leaks across utterances.
            vad_scorer = E2EVadModel(self.config["model_conf"])
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            waveform = waveform_list[beg_idx:end_idx]
            feats, feats_len = self.extract_feat(waveform)
            waveform = np.array(waveform)
            param_dict = kwargs.get("param_dict", dict())
            in_cache = self.prepare_cache(param_dict.get("in_cache", list()))

            # The axmodel consumes at most 6000 feature frames per pass.
            total_frames = int(feats_len.max())
            chunk_size = min(total_frames, 6000)
            for t_offset in range(0, total_frames, max(chunk_size, 1)):
                step = chunk_size
                if t_offset + step >= total_frames - 1:
                    step = total_frames - t_offset
                    is_final = True
                else:
                    is_final = False

                # Feature slice for this window.
                feats_package = feats[:, t_offset:t_offset + step, :]

                if is_final:
                    # The model takes fixed 6000-frame inputs: zero-pad the tail.
                    pad_length = 6000 - int(step)
                    feats_package = np.pad(
                        feats_package,
                        ((0, 0), (0, pad_length), (0, 0)),
                        mode='constant',
                        constant_values=0,
                    )

                # Matching raw-audio span: 160-sample frame shift, 400-sample window.
                waveform_package = waveform[
                    :,
                    t_offset * 160:min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400),
                ]

                if is_final:
                    # Audio length implied by 6000 frames: (6000-1)*160 + 400.
                    expected_wave_length = 6000 * 160 + 240
                    pad_wave_length = expected_wave_length - waveform_package.shape[-1]
                    if pad_wave_length > 0:
                        waveform_package = np.pad(
                            waveform_package,
                            ((0, 0), (0, pad_wave_length)),
                            mode='constant',
                            constant_values=0,
                        )

                # Forward pass; carry the returned caches into the next chunk.
                scores, in_cache = self.infer([feats_package, *in_cache])

                # Score this chunk and accumulate any emitted segments.
                segments_part = vad_scorer(
                    scores,
                    waveform_package,
                    is_final=is_final,
                    max_end_sil=self.max_end_sil,
                    online=False,
                )
                if segments_part:
                    for batch_num in range(self.batch_size):
                        segments[batch_num] += segments_part[batch_num]

        return segments

    def prepare_cache(self, in_cache: list = None):
        """Return FSMN caches, building zero-initialized ones when absent.

        BUG FIX: the previous signature used a mutable default (``[]``)
        that was appended to, so state was shared across calls. A falsy
        ``in_cache`` (None or empty) now always yields fresh caches; a
        non-empty one is returned unchanged.
        """
        if in_cache:
            return in_cache
        fsmn_layers = 4
        proj_dim = 128
        lorder = 20
        # One (1, proj_dim, lorder - 1, 1) zero cache per FSMN layer.
        return [
            np.zeros((1, proj_dim, lorder - 1, 1), dtype=np.float32)
            for _ in range(fsmn_layers)
        ]