| import os |
| import time |
| import numpy as np |
| from typing import BinaryIO, Union, Tuple, List |
| import torch |
| from transformers import pipeline |
| from transformers.utils import is_flash_attn_2_available |
| import gradio as gr |
| from huggingface_hub import hf_hub_download |
| import whisper |
| from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn |
| from argparse import Namespace |
|
|
| from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR) |
| from modules.whisper.whisper_parameter import * |
| from modules.whisper.whisper_base import WhisperBase |
|
|
|
|
class InsanelyFastWhisperInference(WhisperBase):
    """Whisper inference that runs models through the ``transformers``
    "automatic-speech-recognition" pipeline.

    Model files are kept under ``model_dir`` in one sub-directory per model
    size and are downloaded on demand when that sub-directory is missing or
    empty.
    """

    def __init__(self,
                 model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR,
                 diarization_model_dir: str = DIARIZATION_MODELS_DIR,
                 uvr_model_dir: str = UVR_MODELS_DIR,
                 output_dir: str = OUTPUT_DIR,
                 ):
        """
        Parameters
        ----------
        model_dir: str
            Directory that holds the downloaded whisper models. Created if missing.
        diarization_model_dir: str
            Directory forwarded to the base class for diarization models.
        uvr_model_dir: str
            Directory forwarded to the base class for UVR models.
        output_dir: str
            Directory forwarded to the base class for outputs.
        """
        super().__init__(
            model_dir=model_dir,
            output_dir=output_dir,
            diarization_model_dir=diarization_model_dir,
            uvr_model_dir=uvr_model_dir
        )
        self.model_dir = model_dir
        os.makedirs(self.model_dir, exist_ok=True)

        # Selectable models: everything openai-whisper reports plus the distil variants.
        openai_models = whisper.available_models()
        distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
        self.available_models = openai_models + distil_models
        self.available_compute_types = ["float16"]

    def transcribe(self,
                   audio: Union[str, np.ndarray, torch.Tensor],
                   progress: gr.Progress = gr.Progress(),
                   *whisper_params,
                   ) -> Tuple[List[dict], float]:
        """
        transcribe method for insanely-fast-whisper.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            Audio input, passed to the pipeline as ``inputs``.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related with whisper. This will be dealt with "WhisperParameters" data class

        Returns
        ----------
        segments_result: List[dict]
            list of dicts that includes start, end timestamps and transcribed text
        elapsed_time: float
            elapsed time for transcription
        """
        start_time = time.time()
        params = WhisperParameters.as_value(*whisper_params)

        # (Re)load the pipeline only when the requested model or compute type changed.
        if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
            self.update_model(params.model_size, params.compute_type, progress)

        progress(0, desc="Transcribing...Progress is not shown in insanely-fast-whisper.")

        # The rich console indicator gets its own name so the gradio ``progress``
        # argument is not rebound inside this method.
        with Progress(
                TextColumn("[progress.description]{task.description}"),
                BarColumn(style="yellow1", pulse_style="white"),
                TimeElapsedColumn(),
        ) as console_progress:
            # total=None: no total is known, so the task carries no completion target.
            console_progress.add_task("[yellow]Transcribing...", total=None)

            kwargs = {
                "no_speech_threshold": params.no_speech_threshold,
                "temperature": params.temperature,
                "compression_ratio_threshold": params.compression_ratio_threshold,
                "logprob_threshold": params.log_prob_threshold,
            }

            # Language/task are only passed for models whose name does not end in ".en".
            if not self.current_model_size.endswith(".en"):
                kwargs["language"] = params.lang
                kwargs["task"] = "translate" if params.is_translate else "transcribe"

            segments = self.model(
                inputs=audio,
                return_timestamps=True,
                chunk_length_s=params.chunk_length,
                batch_size=params.batch_size,
                generate_kwargs=kwargs
            )

        segments_result = self.format_result(
            transcribed_result=segments,
        )
        elapsed_time = time.time() - start_time
        return segments_result, elapsed_time

    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress = gr.Progress(),
                     ):
        """
        Update current model setting

        Parameters
        ----------
        model_size: str
            Size of whisper model
        compute_type: str
            Compute type for transcription. Passed to the pipeline as ``torch_dtype``.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        """
        progress(0, desc="Initializing Model...")
        model_path = os.path.join(self.model_dir, model_size)
        # A missing or empty directory means the model was never downloaded.
        if not os.path.isdir(model_path) or not os.listdir(model_path):
            self.download_model(
                model_size=model_size,
                download_root=model_path,
                progress=progress
            )

        attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
        new_model = pipeline(
            "automatic-speech-recognition",
            model=model_path,
            torch_dtype=compute_type,
            device=self.device,
            model_kwargs={"attn_implementation": attn_implementation},
        )

        # Commit the new state only after the pipeline was built successfully, so a
        # failed load cannot leave the size/compute-type markers pointing at a model
        # that is not the one stored in ``self.model``.
        self.model = new_model
        self.current_compute_type = compute_type
        self.current_model_size = model_size

    @staticmethod
    def format_result(
        transcribed_result: dict
    ) -> List[dict]:
        """
        Format the transcription result of insanely_fast_whisper as the same with other implementation.

        Parameters
        ----------
        transcribed_result: dict
            Transcription result of the insanely_fast_whisper. Its "chunks" entry is
            updated in place: every chunk gains "start" and "end" keys.

        Returns
        ----------
        result: List[dict]
            Formatted result as the same with other implementation
        """
        result = transcribed_result["chunks"]
        for item in result:
            start, end = item["timestamp"][0], item["timestamp"][1]
            # A chunk without an end timestamp is treated as zero-length.
            if end is None:
                end = start
            item["start"] = start
            item["end"] = end
        return result

    @staticmethod
    def download_model(
        model_size: str,
        download_root: str,
        progress: gr.Progress
    ):
        """
        Download the files of a whisper model into ``download_root``.

        Parameters
        ----------
        model_size: str
            Size of whisper model. Names starting with "distil" are fetched from the
            "distil-whisper/<model_size>" repository, all others from
            "openai/whisper-<model_size>".
        download_root: str
            Directory the files are downloaded to. Created if missing.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        """
        progress(0, 'Initializing model..')
        print(f'Downloading {model_size} to "{download_root}"....')

        os.makedirs(download_root, exist_ok=True)
        download_list = [
            "model.safetensors",
            "config.json",
            "generation_config.json",
            "preprocessor_config.json",
            "tokenizer.json",
            "tokenizer_config.json",
            "added_tokens.json",
            "special_tokens_map.json",
            "vocab.json",
        ]

        if model_size.startswith("distil"):
            repo_id = f"distil-whisper/{model_size}"
        else:
            repo_id = f"openai/whisper-{model_size}"
        for item in download_list:
            hf_hub_download(repo_id=repo_id, filename=item, local_dir=download_root)
|
|