Spaces:
Running
Running
jhj0517
commited on
Commit
·
e862b08
1
Parent(s):
cee12df
Handle gradio none values
Browse files
modules/whisper/faster_whisper_inference.py
CHANGED
|
@@ -67,16 +67,6 @@ class FasterWhisperInference(WhisperBase):
|
|
| 67 |
if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
|
| 68 |
self.update_model(params.model_size, params.compute_type, progress)
|
| 69 |
|
| 70 |
-
# None parameters with Textboxes: https://github.com/gradio-app/gradio/issues/8723
|
| 71 |
-
if not params.initial_prompt:
|
| 72 |
-
params.initial_prompt = None
|
| 73 |
-
if not params.prefix:
|
| 74 |
-
params.prefix = None
|
| 75 |
-
if not params.hotwords:
|
| 76 |
-
params.hotwords = None
|
| 77 |
-
|
| 78 |
-
params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens)
|
| 79 |
-
|
| 80 |
segments, info = self.model.transcribe(
|
| 81 |
audio=audio,
|
| 82 |
language=params.lang,
|
|
|
|
| 67 |
if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
|
| 68 |
self.update_model(params.model_size, params.compute_type, progress)
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
segments, info = self.model.transcribe(
|
| 71 |
audio=audio,
|
| 72 |
language=params.lang,
|
modules/whisper/whisper_base.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import os
|
| 2 |
import torch
|
|
|
|
| 3 |
import whisper
|
| 4 |
import ctranslate2
|
| 5 |
import gradio as gr
|
|
@@ -14,7 +15,7 @@ from dataclasses import astuple
|
|
| 14 |
from modules.uvr.music_separator import MusicSeparator
|
| 15 |
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
|
| 16 |
UVR_MODELS_DIR)
|
| 17 |
-
from modules.utils.constants import AUTOMATIC_DETECTION
|
| 18 |
from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
|
| 19 |
from modules.utils.youtube_manager import get_ytdata, get_ytaudio
|
| 20 |
from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
|
|
@@ -101,16 +102,9 @@ class WhisperBase(ABC):
|
|
| 101 |
elapsed time for running
|
| 102 |
"""
|
| 103 |
params = TranscriptionPipelineParams.from_list(list(pipeline_params))
|
|
|
|
| 104 |
bgm_params, vad_params, whisper_params, diarization_params = params.bgm_separation, params.vad, params.whisper, params.diarization
|
| 105 |
|
| 106 |
-
if whisper_params.lang is None:
|
| 107 |
-
pass
|
| 108 |
-
elif whisper_params.lang == AUTOMATIC_DETECTION:
|
| 109 |
-
whisper_params.lang = None
|
| 110 |
-
else:
|
| 111 |
-
language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
|
| 112 |
-
whisper_params.lang = language_code_dict[params.lang]
|
| 113 |
-
|
| 114 |
if bgm_params.is_separate_bgm:
|
| 115 |
music, audio, _ = self.music_separator.separate(
|
| 116 |
audio=audio,
|
|
@@ -515,25 +509,57 @@ class WhisperBase(ABC):
|
|
| 515 |
if file_path and os.path.exists(file_path):
|
| 516 |
os.remove(file_path)
|
| 517 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 518 |
@staticmethod
|
| 519 |
def cache_parameters(
|
| 520 |
params: TranscriptionPipelineParams,
|
| 521 |
add_timestamp: bool
|
| 522 |
):
|
| 523 |
-
"""
|
| 524 |
-
|
| 525 |
cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
|
| 526 |
param_to_cache = params.to_dict()
|
| 527 |
|
| 528 |
-
print(param_to_cache)
|
| 529 |
-
|
| 530 |
cached_yaml = {**cached_params, **param_to_cache}
|
| 531 |
cached_yaml["whisper"]["add_timestamp"] = add_timestamp
|
| 532 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 533 |
if cached_yaml["whisper"].get("lang", None) is None:
|
| 534 |
-
cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION
|
| 535 |
|
| 536 |
-
|
|
|
|
| 537 |
|
| 538 |
@staticmethod
|
| 539 |
def resample_audio(audio: Union[str, np.ndarray],
|
|
|
|
| 1 |
import os
|
| 2 |
import torch
|
| 3 |
+
import ast
|
| 4 |
import whisper
|
| 5 |
import ctranslate2
|
| 6 |
import gradio as gr
|
|
|
|
| 15 |
from modules.uvr.music_separator import MusicSeparator
|
| 16 |
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
|
| 17 |
UVR_MODELS_DIR)
|
| 18 |
+
from modules.utils.constants import AUTOMATIC_DETECTION, GRADIO_NONE_VALUES
|
| 19 |
from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
|
| 20 |
from modules.utils.youtube_manager import get_ytdata, get_ytaudio
|
| 21 |
from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
|
|
|
|
| 102 |
elapsed time for running
|
| 103 |
"""
|
| 104 |
params = TranscriptionPipelineParams.from_list(list(pipeline_params))
|
| 105 |
+
params = self.handle_gradio_values(params)
|
| 106 |
bgm_params, vad_params, whisper_params, diarization_params = params.bgm_separation, params.vad, params.whisper, params.diarization
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
if bgm_params.is_separate_bgm:
|
| 109 |
music, audio, _ = self.music_separator.separate(
|
| 110 |
audio=audio,
|
|
|
|
| 509 |
if file_path and os.path.exists(file_path):
|
| 510 |
os.remove(file_path)
|
| 511 |
|
| 512 |
+
@staticmethod
|
| 513 |
+
def handle_gradio_values(params: TranscriptionPipelineParams):
|
| 514 |
+
"""
|
| 515 |
+
Handle gradio specific values that can't be displayed as None in the UI.
|
| 516 |
+
Related issue : https://github.com/gradio-app/gradio/issues/8723
|
| 517 |
+
"""
|
| 518 |
+
if params.whisper.lang is None:
|
| 519 |
+
pass
|
| 520 |
+
elif params.whisper.lang == AUTOMATIC_DETECTION:
|
| 521 |
+
params.whisper.lang = None
|
| 522 |
+
else:
|
| 523 |
+
language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
|
| 524 |
+
params.whisper.lang = language_code_dict[params.lang]
|
| 525 |
+
|
| 526 |
+
if not params.whisper.initial_prompt:
|
| 527 |
+
params.whisper.initial_prompt = None
|
| 528 |
+
if not params.whisper.prefix:
|
| 529 |
+
params.whisper.prefix = None
|
| 530 |
+
if not params.whisper.hotwords:
|
| 531 |
+
params.whisper.hotwords = None
|
| 532 |
+
if params.whisper.max_new_tokens == 0:
|
| 533 |
+
params.whisper.max_new_tokens = None
|
| 534 |
+
if params.whisper.hallucination_silence_threshold == 0:
|
| 535 |
+
params.whisper.hallucination_silence_threshold = None
|
| 536 |
+
if params.whisper.language_detection_threshold == 0:
|
| 537 |
+
params.whisper.language_detection_threshold = None
|
| 538 |
+
if params.whisper.max_speech_duration_s >= 9999:
|
| 539 |
+
params.whisper.max_speech_duration_s = float('inf')
|
| 540 |
+
return params
|
| 541 |
+
|
| 542 |
@staticmethod
|
| 543 |
def cache_parameters(
|
| 544 |
params: TranscriptionPipelineParams,
|
| 545 |
add_timestamp: bool
|
| 546 |
):
|
| 547 |
+
"""Cache parameters to the yaml file"""
|
|
|
|
| 548 |
cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
|
| 549 |
param_to_cache = params.to_dict()
|
| 550 |
|
|
|
|
|
|
|
| 551 |
cached_yaml = {**cached_params, **param_to_cache}
|
| 552 |
cached_yaml["whisper"]["add_timestamp"] = add_timestamp
|
| 553 |
|
| 554 |
+
supress_token = cached_yaml["whisper"].get("suppress_tokens", None)
|
| 555 |
+
if supress_token and isinstance(supress_token, list):
|
| 556 |
+
cached_yaml["whisper"]["suppress_tokens"] = str(supress_token)
|
| 557 |
+
|
| 558 |
if cached_yaml["whisper"].get("lang", None) is None:
|
| 559 |
+
cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION.unwrap()
|
| 560 |
|
| 561 |
+
if cached_yaml is not None and cached_yaml:
|
| 562 |
+
save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
|
| 563 |
|
| 564 |
@staticmethod
|
| 565 |
def resample_audio(audio: Union[str, np.ndarray],
|