import os
import time
import ast
from typing import BinaryIO, Union, Tuple, List

import numpy as np
import torch
import ctranslate2
import faster_whisper
import gradio as gr

from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
from modules.whisper.whisper_parameter import *
from modules.whisper.whisper_base import WhisperBase


class FasterWhisperInference(WhisperBase):
    def __init__(self,
                 model_dir: str = FASTER_WHISPER_MODELS_DIR,
                 diarization_model_dir: str = DIARIZATION_MODELS_DIR,
                 uvr_model_dir: str = UVR_MODELS_DIR,
                 output_dir: str = OUTPUT_DIR,
                 ):
        super().__init__(
            model_dir=model_dir,
            diarization_model_dir=diarization_model_dir,
            uvr_model_dir=uvr_model_dir,
            output_dir=output_dir
        )
        self.model_dir = model_dir
        os.makedirs(self.model_dir, exist_ok=True)

        self.model_paths = self.get_model_paths()
        self.device = self.get_device()
        self.available_models = self.model_paths.keys()
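        # Offer only the compute types (e.g. int8, float16) that CTranslate2
        # reports as supported on the selected device.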
        self.available_compute_types = ctranslate2.get_supported_compute_types(
            "cuda" if self.device == "cuda" else "cpu"
        )

    def transcribe(self,
                   audio: Union[str, BinaryIO, np.ndarray],
                   progress: gr.Progress = gr.Progress(),
                   *whisper_params,
                   ) -> Tuple[List[dict], float]:
        """
        Transcribe method for faster-whisper.

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio path, file binary, or audio numpy array
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related to Whisper, handled by the WhisperParameters data class.

        Returns
        -------
        segments_result: List[dict]
            List of dicts with start/end timestamps and transcribed text
        elapsed_time: float
            Elapsed time for transcription
        """
        start_time = time.time()

        params = WhisperParameters.as_value(*whisper_params)
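
        # Reload the model only when the requested size or compute type differs
        # from what is currently loaded.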
        if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
            self.update_model(params.model_size, params.compute_type, progress)
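
        # faster-whisper expects None rather than an empty string for unset
        # prompt-style options, so blank UI inputs are normalized here.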
        if not params.initial_prompt:
            params.initial_prompt = None
        if not params.prefix:
            params.prefix = None
        if not params.hotwords:
            params.hotwords = None

        params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens)

        progress(0, desc="Loading audio...")
        segments, info = self.model.transcribe(
            audio=audio,
            language=params.lang,
            task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
            beam_size=params.beam_size,
            log_prob_threshold=params.log_prob_threshold,
            no_speech_threshold=params.no_speech_threshold,
            best_of=params.best_of,
            patience=params.patience,
            temperature=params.temperature,
            initial_prompt=params.initial_prompt,
            compression_ratio_threshold=params.compression_ratio_threshold,
            length_penalty=params.length_penalty,
            repetition_penalty=params.repetition_penalty,
            no_repeat_ngram_size=params.no_repeat_ngram_size,
            prefix=params.prefix,
            suppress_blank=params.suppress_blank,
            suppress_tokens=params.suppress_tokens,
            max_initial_timestamp=params.max_initial_timestamp,
            word_timestamps=params.word_timestamps,
            prepend_punctuations=params.prepend_punctuations,
            append_punctuations=params.append_punctuations,
            max_new_tokens=params.max_new_tokens,
            chunk_length=params.chunk_length,
            hallucination_silence_threshold=params.hallucination_silence_threshold,
            hotwords=params.hotwords,
            language_detection_threshold=params.language_detection_threshold,
            language_detection_segments=params.language_detection_segments,
            prompt_reset_on_temperature=params.prompt_reset_on_temperature,
        )
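
        # transcribe() returns a lazy generator; the actual decoding happens
        # while iterating over the segments below.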
        segments_result = []
        for segment in segments:
            progress(segment.start / info.duration, desc="Transcribing...")
            segments_result.append({
                "start": segment.start,
                "end": segment.end,
                "text": segment.text
            })

        elapsed_time = time.time() - start_time
        return segments_result, elapsed_time

    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress = gr.Progress()
                     ):
| """ |
| Update current model setting |
| |
| Parameters |
| ---------- |
| model_size: str |
| Size of whisper model |
| compute_type: str |
| Compute type for transcription. |
| see more info : https://opennmt.net/CTranslate2/quantization.html |
| progress: gr.Progress |
| Indicator to show progress directly in gradio. |
| """ |
        progress(0, desc="Initializing Model...")
        self.current_model_size = model_size
        self.current_compute_type = compute_type
        self.model = faster_whisper.WhisperModel(
            device=self.device,
            model_size_or_path=self.model_paths[model_size],
            download_root=self.model_dir,
            compute_type=self.current_compute_type
        )

    def get_model_paths(self):
        """
        Get available models from the models directory, including fine-tuned models.

        Returns
        -------
        dict
            Mapping from model name to model path (the name itself for official models)
        """
        model_paths = {model: model for model in faster_whisper.available_models()}
        faster_whisper_prefix = "models--Systran--faster-whisper-"

        existing_models = os.listdir(self.model_dir)
        wrong_dirs = [".locks"]
        existing_models = list(set(existing_models) - set(wrong_dirs))
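
        # Models downloaded through huggingface_hub are cached in directories
        # named "models--Systran--faster-whisper-<size>"; strip that prefix to
        # recover the plain model name.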
        for model_name in existing_models:
            if faster_whisper_prefix in model_name:
                model_name = model_name[len(faster_whisper_prefix):]

            if model_name not in model_paths:
                model_paths[model_name] = os.path.join(self.model_dir, model_name)
        return model_paths

    @staticmethod
    def get_device():
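        # "auto" lets CTranslate2 pick the best available backend when CUDA
        # is not detected.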
        if torch.cuda.is_available():
            return "cuda"
        else:
            return "auto"

    @staticmethod
    def format_suppress_tokens_str(suppress_tokens_str: str) -> List[int]:
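        # Parse the UI string (e.g. "[-1]") with ast.literal_eval and validate
        # that it is a list of ints before handing it to faster-whisper.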
        try:
            suppress_tokens = ast.literal_eval(suppress_tokens_str)
        except (SyntaxError, ValueError) as e:
            raise ValueError("Invalid Suppress Tokens. The value must be of type List[int]") from e
        if not isinstance(suppress_tokens, list) or not all(isinstance(item, int) for item in suppress_tokens):
            raise ValueError("Invalid Suppress Tokens. The value must be of type List[int]")
        return suppress_tokens