| import whisper |
| import gradio as gr |
| import time |
| from typing import BinaryIO, Union, Tuple, List |
| import numpy as np |
| import torch |
| import os |
| from argparse import Namespace |
|
|
| from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, UVR_MODELS_DIR) |
| from modules.whisper.whisper_base import WhisperBase |
| from modules.whisper.whisper_parameter import * |
|
|
|
|
| class WhisperInference(WhisperBase): |
| def __init__(self, |
| model_dir: str = WHISPER_MODELS_DIR, |
| diarization_model_dir: str = DIARIZATION_MODELS_DIR, |
| uvr_model_dir: str = UVR_MODELS_DIR, |
| output_dir: str = OUTPUT_DIR, |
| ): |
| super().__init__( |
| model_dir=model_dir, |
| output_dir=output_dir, |
| diarization_model_dir=diarization_model_dir, |
| uvr_model_dir=uvr_model_dir |
| ) |
|
|
| def transcribe(self, |
| audio: Union[str, np.ndarray, torch.Tensor], |
| progress: gr.Progress = gr.Progress(), |
| *whisper_params, |
| ) -> Tuple[List[dict], float]: |
| """ |
| transcribe method for faster-whisper. |
| |
| Parameters |
| ---------- |
| audio: Union[str, BinaryIO, np.ndarray] |
| Audio path or file binary or Audio numpy array |
| progress: gr.Progress |
| Indicator to show progress directly in gradio. |
| *whisper_params: tuple |
| Parameters related with whisper. This will be dealt with "WhisperParameters" data class |
| |
| Returns |
| ---------- |
| segments_result: List[dict] |
| list of dicts that includes start, end timestamps and transcribed text |
| elapsed_time: float |
| elapsed time for transcription |
| """ |
| start_time = time.time() |
| params = WhisperParameters.as_value(*whisper_params) |
|
|
| if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: |
| self.update_model(params.model_size, params.compute_type, progress) |
|
|
| def progress_callback(progress_value): |
| progress(progress_value, desc="Transcribing...") |
|
|
| segments_result = self.model.transcribe(audio=audio, |
| language=params.lang, |
| verbose=False, |
| beam_size=params.beam_size, |
| logprob_threshold=params.log_prob_threshold, |
| no_speech_threshold=params.no_speech_threshold, |
| task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe", |
| fp16=True if params.compute_type == "float16" else False, |
| best_of=params.best_of, |
| patience=params.patience, |
| temperature=params.temperature, |
| compression_ratio_threshold=params.compression_ratio_threshold, |
| progress_callback=progress_callback,)["segments"] |
| elapsed_time = time.time() - start_time |
|
|
| return segments_result, elapsed_time |
|
|
| def update_model(self, |
| model_size: str, |
| compute_type: str, |
| progress: gr.Progress = gr.Progress(), |
| ): |
| """ |
| Update current model setting |
| |
| Parameters |
| ---------- |
| model_size: str |
| Size of whisper model |
| compute_type: str |
| Compute type for transcription. |
| see more info : https://opennmt.net/CTranslate2/quantization.html |
| progress: gr.Progress |
| Indicator to show progress directly in gradio. |
| """ |
| progress(0, desc="Initializing Model...") |
| self.current_compute_type = compute_type |
| self.current_model_size = model_size |
| self.model = whisper.load_model( |
| name=model_size, |
| device=self.device, |
| download_root=self.model_dir |
| ) |