| ```python | |
| """the interface to interact with wakeword model""" | |
| import pyaudio | |
| import threading | |
| import time | |
| import torchaudio | |
| import torch | |
| import numpy as np | |
| import queue | |
| from transformers import WavLMForSequenceClassification | |
| from transformers import AutoFeatureExtractor | |
def int2float(sound):
    """Convert integer PCM samples to float32, peak-normalised to [-1, 1].

    Samples are divided by their own maximum absolute value (not by the
    integer type's full-scale value), so any non-silent input ends up with a
    peak magnitude of exactly 1.0.

    Args:
        sound: numpy array of integer samples (e.g. int16 from PyAudio).

    Returns:
        A float32 numpy array with singleton dimensions squeezed out. An
        all-zero input stays all-zero; an empty input yields an empty array.
    """
    # Cast BEFORE taking abs(): np.abs on int16 overflows for -32768 (the
    # result stays -32768), which would corrupt the peak estimate.
    # astype() copies, so the caller's array is never modified.
    sound = np.asarray(sound).astype('float32')
    if sound.size > 0:  # .max() raises on an empty array
        abs_max = np.abs(sound).max()
        if abs_max > 0:
            sound *= 1 / abs_max
    return sound.squeeze()  # depends on the use case
class RealtimeDecoder():
    """Stream microphone audio through Silero VAD and a wake-word classifier.

    ``start_recording`` and ``start_decoding`` only *create* threads; the
    caller starts and joins them. Recorded chunks travel from the recording
    thread to the decoding thread through ``audio_queue``.
    """

    # Minimum softmax probability of class index 1 that counts as a detection.
    WAKEWORD_THRESHOLD = 0.6

    def __init__(self,
                 model,
                 feature_extractor=None,
                 ) -> None:
        """Load Silero VAD and set up the shared recording/decoding state.

        Args:
            model: classifier called as ``model(**features).logits``; index 1
                of the softmax output is treated as the wake-word class.
            feature_extractor: object whose ``pad(items, padding=True,
                return_tensors="pt")`` builds the model inputs. When omitted,
                the module-level name ``feature_extractor`` is used instead.
        """
        self.model = model
        self.feature_extractor = feature_extractor
        # Fetch Silero VAD via torch.hub; the first util is get_speech_timestamps.
        self.vad_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                               model='silero_vad',
                                               force_reload=False,
                                               onnx=False)
        (self.get_speech_timestamps, _, _, _, _) = utils
        self.SAMPLE_RATE = 16000
        self.cache_output = self._empty_cache()
        # NOTE: despite its name this event is a *stop* flag: once set, the
        # recording and decoding loops exit.
        self.continue_recording = threading.Event()
        self.frame_duration_ms = 1000  # length of each recorded chunk
        self.audio_queue = queue.SimpleQueue()
        self.speech_queue = queue.SimpleQueue()  # not read by this class

    @staticmethod
    def _empty_cache():
        """Return a fresh decoding cache (its ``cache`` tensor is not read by this class)."""
        return {
            "cache": torch.zeros(0, 0, 0, dtype=torch.float),
            "wavchunks": [],
        }

    def start_recording(self, wait_enter_to_stop=True):
        """Create (but do not start) the recording thread and optional stop listener.

        Args:
            wait_enter_to_stop: if True, also create a thread that sets the
                stop flag when the user presses Enter. If False, the listener
                slot is None and something else must set
                ``self.continue_recording`` to stop recording.

        Returns:
            ``(stop_listener_thread_or_None, recording_thread)``.
        """
        def stop():
            input("Press Enter to stop the recording:\n\n")
            self.continue_recording.set()

        def record():
            chunk_frames = int(self.SAMPLE_RATE * self.frame_duration_ms / 1000.0)
            audio = pyaudio.PyAudio()
            stream = None
            try:
                stream = audio.open(format=pyaudio.paInt16,
                                    channels=1,
                                    rate=self.SAMPLE_RATE,
                                    input=True,
                                    frames_per_buffer=int(self.SAMPLE_RATE / 10))
                while not self.continue_recording.is_set():
                    audio_chunk = stream.read(chunk_frames, exception_on_overflow=False)
                    audio_int16 = np.frombuffer(audio_chunk, np.int16)
                    self.audio_queue.put(torch.from_numpy(int2float(audio_int16)))
                print("Finish record")
            finally:
                # Always release the input stream and the PortAudio session,
                # even if opening or reading the device fails.
                if stream is not None:
                    stream.stop_stream()
                    stream.close()
                audio.terminate()

        stop_listener_thread = (
            threading.Thread(target=stop, daemon=False) if wait_enter_to_stop else None
        )
        recording_thread = threading.Thread(target=record, daemon=False)
        return stop_listener_thread, recording_thread

    def finish_realtime_decode(self):
        """Discard any buffered audio chunks by resetting ``cache_output``."""
        self.cache_output = self._empty_cache()

    def _resolve_feature_extractor(self):
        """Return the extractor given to ``__init__``, else the module-level one."""
        extractor = getattr(self, "feature_extractor", None)
        if extractor is not None:
            return extractor
        # Fallback for the ``__main__`` script, which defines a global named
        # ``feature_extractor``; raises NameError if no such global exists.
        return feature_extractor

    def _wakeword_probs(self, waveform):
        """Return class probabilities for ``waveform``, or None if VAD finds no speech."""
        speech_timestamps = self.get_speech_timestamps(
            waveform, self.vad_model, sampling_rate=self.SAMPLE_RATE)
        if not speech_timestamps:
            return None
        extractor = self._resolve_feature_extractor()
        input_features = extractor.pad([{"input_values": waveform}],
                                       padding=True, return_tensors="pt")
        # Inference only: don't build an autograd graph for every chunk.
        with torch.no_grad():
            logits = self.model(**input_features).logits
        return logits.softmax(dim=-1).squeeze()

    def start_decoding(self):
        """Create (but do not start) the wake-word decoding thread.

        The thread keeps up to four recent chunks from ``audio_queue`` and
        scores the concatenation of the two most recent ones. On a detection
        it prints ``"hey armar"`` and clears the buffer; otherwise it prints
        one dot plus one per chunk still waiting in the queue.

        Returns:
            The decoding ``threading.Thread``.
        """
        def decode():
            while not self.continue_recording.is_set():
                try:
                    # Block briefly instead of busy-polling qsize(); the short
                    # timeout keeps the stop flag re-checked regularly.
                    current_waveform = self.audio_queue.get(timeout=0.01)
                except queue.Empty:
                    continue
                if current_waveform is None:
                    continue
                chunks = self.cache_output['wavchunks']
                chunks.append(current_waveform)
                chunks = chunks[-4:]
                self.cache_output['wavchunks'] = chunks
                if len(chunks) < 2:
                    continue
                waveform = torch.cat(chunks[-2:], dim=-1)
                probs = self._wakeword_probs(waveform)
                if probs is not None and probs[1] > self.WAKEWORD_THRESHOLD:
                    print("hey armar", probs, waveform.size(-1) / self.SAMPLE_RATE)
                    self.cache_output['wavchunks'] = []
                else:
                    print('.' + '.' * self.audio_queue.qsize())
            print("KWS thread finish")

        return threading.Thread(target=decode, daemon=False)
if __name__ == "__main__":
    print("Model loading....")
    kws_model = WavLMForSequenceClassification.from_pretrained('nguyenvulebinh/heyarmar')
    # Must stay a module-level global: RealtimeDecoder's decode loop looks up
    # the name ``feature_extractor`` when it scores a speech chunk.
    feature_extractor = AutoFeatureExtractor.from_pretrained('nguyenvulebinh/heyarmar')
    print("Model loaded....")
    # Offline sanity check on a single file (kept for debugging):
    # file_wave = './99.wav'
    # wav, rate = torchaudio.load(file_wave)
    # input_features = feature_extractor.pad([{"input_values": item} for item in wav], padding=True, return_tensors="pt")
    # output = kws_model(**input_features)
    # print(output.logits.softmax(dim=-1))
    obj_decode = RealtimeDecoder(kws_model)
    recording_threads = obj_decode.start_recording()
    kws_decode_thread = obj_decode.start_decoding()
    # start_recording() puts None in place of the Enter listener when disabled.
    worker_threads = [t for t in (*recording_threads, kws_decode_thread) if t is not None]
    for thread in worker_threads:
        thread.start()
    try:
        for thread in worker_threads:
            thread.join()
    except KeyboardInterrupt:
        # Ctrl+C: raise the stop flag so the recording and decoding loops exit
        # instead of running forever. (The Enter-listener thread, if any, stays
        # blocked in input() until a line is entered.)
        obj_decode.continue_recording.set()
        raise
| ``` |