| import io |
| import os |
| import torch |
| from transformers import ( |
| AutomaticSpeechRecognitionPipeline, |
| WhisperForConditionalGeneration, |
| WhisperTokenizer, |
| WhisperProcessor, |
| ) |
| from peft import PeftModel, PeftConfig |
| import speech_recognition as sr |
| from datetime import datetime, timedelta |
| from queue import Queue |
| from tempfile import NamedTemporaryFile |
| from time import sleep |
| from sys import platform |
|
|
|
|
|
|
def main():
    """Run live Vietnamese speech transcription from the microphone.

    A background thread records short audio chunks and pushes their raw PCM
    bytes onto a queue. The main loop drains the queue, concatenates chunks
    into the current phrase (starting a new phrase after ``phrase_timeout``
    seconds of silence), transcribes the phrase with a PEFT/LoRA-adapted
    Whisper pipeline on CUDA, and redraws the running transcript. Press
    Ctrl+C to stop; the final transcript is printed on exit.
    """
    peft_model_id = "DuyTa/Vietnamese_ASR"
    language = "Vietnamese"
    task = "transcribe"
    default_energy_threshold = 900   # mic energy above which speech is detected
    default_record_timeout = 0.6     # seconds of audio per background chunk
    default_phrase_timeout = 3       # silence gap (s) that ends a phrase

    # Timestamp of the last chunk pulled from the queue; None until first chunk.
    phrase_time = None
    # Raw PCM bytes accumulated for the phrase currently being transcribed.
    last_sample = bytes()
    # Thread-safe queue filled by the background recorder callback.
    data_queue = Queue()

    recorder = sr.Recognizer()
    recorder.energy_threshold = default_energy_threshold
    # A dynamic threshold tends to drift low enough that recording never
    # stops, so keep the threshold fixed.
    recorder.dynamic_energy_threshold = False

    source = sr.Microphone(sample_rate=16000)  # Whisper expects 16 kHz input

    # Load the base Whisper model, then apply the PEFT (LoRA) adapter weights.
    peft_config = PeftConfig.from_pretrained(peft_model_id)
    model = WhisperForConditionalGeneration.from_pretrained(
        peft_config.base_model_name_or_path
    )
    model = PeftModel.from_pretrained(model, peft_model_id)
    model.to("cuda:0")

    processor = WhisperProcessor.from_pretrained(
        peft_config.base_model_name_or_path, language=language, task=task
    )
    pipe = AutomaticSpeechRecognitionPipeline(
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        batch_size=8,
        torch_dtype=torch.float32,
        device="cuda:0",
    )

    record_timeout = default_record_timeout
    phrase_timeout = default_phrase_timeout

    # BUGFIX: the original did NamedTemporaryFile().name, discarding the file
    # object — the handle is closed immediately and (on Windows) the file is
    # deleted, leaving a dangling path. delete=False keeps the path valid for
    # the lifetime of the loop; it is removed explicitly on shutdown.
    temp_file = NamedTemporaryFile(suffix=".wav", delete=False).name
    transcription = ['']

    with source:
        recorder.adjust_for_ambient_noise(source)

    def record_callback(_, audio: sr.AudioData) -> None:
        """Threaded callback: enqueue the raw bytes of a finished recording."""
        data_queue.put(audio.get_raw_data())

    # BUGFIX: keep the stop function returned by listen_in_background so the
    # background listener thread can be shut down cleanly on exit (the
    # original discarded it, leaving the thread running after Ctrl+C).
    stop_listening = recorder.listen_in_background(
        source, record_callback, phrase_time_limit=record_timeout
    )

    print("Model loaded.\n")

    try:
        while True:
            # datetime.utcnow() is deprecated; only relative differences are
            # compared here, so a naive local timestamp is sufficient.
            now = datetime.now()

            if not data_queue.empty():
                phrase_complete = False
                # Enough silence since the last chunk means the phrase is
                # done: reset the buffer and start a new transcript line.
                if phrase_time and now - phrase_time > timedelta(seconds=phrase_timeout):
                    last_sample = bytes()
                    phrase_complete = True
                phrase_time = now

                # Drain everything currently queued into the phrase buffer.
                while not data_queue.empty():
                    last_sample += data_queue.get()

                # Wrap the raw PCM in a WAV container and write it straight to
                # disk (the original round-tripped through io.BytesIO for no
                # benefit).
                audio_data = sr.AudioData(
                    last_sample, source.SAMPLE_RATE, source.SAMPLE_WIDTH
                )
                with open(temp_file, 'w+b') as f:
                    f.write(audio_data.get_wav_data())

                text = pipe(
                    temp_file,
                    chunk_length_s=30,
                    return_timestamps=False,
                    generate_kwargs={"language": language, "task": task},
                )["text"]

                # New phrase -> append a new line; otherwise the current
                # phrase grew, so replace its line with the fuller transcript.
                if phrase_complete:
                    transcription.append(text)
                else:
                    transcription[-1] = text

                # Clear the console and redraw the whole transcript.
                os.system('cls' if os.name == 'nt' else 'clear')
                for line in transcription:
                    print(line)
                print('', end='', flush=True)

            sleep(0.25)
    except KeyboardInterrupt:
        pass
    finally:
        # BUGFIX: stop the background listener thread and remove the temp
        # file instead of leaking both.
        stop_listening(wait_for_stop=False)
        try:
            os.remove(temp_file)
        except OSError:
            pass

    print("\n\nTranscription:")
    for line in transcription:
        print(line)
|
|
# Script entry point: run the live transcription loop when executed directly.
if __name__ == "__main__":
    main()
|
|