Update modules/whisper/whisper_base.py
Browse files
modules/whisper/whisper_base.py
CHANGED
|
@@ -671,16 +671,33 @@ class WhisperBase(ABC):
|
|
| 671 |
|
| 672 |
@staticmethod
|
| 673 |
def cache_parameters(
|
| 674 |
-
|
| 675 |
-
|
|
|
|
| 676 |
):
|
| 677 |
-
"""
|
| 678 |
cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
|
| 679 |
-
|
| 680 |
-
|
|
|
|
| 681 |
cached_yaml["whisper"]["add_timestamp"] = add_timestamp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 682 |
|
| 683 |
-
|
|
|
|
| 684 |
|
| 685 |
@staticmethod
|
| 686 |
def resample_audio(audio: Union[str, np.ndarray],
|
|
|
|
| 671 |
|
| 672 |
@staticmethod
def cache_parameters(
    params: WhisperValues,
    file_format: str = "SRT",
    add_timestamp: bool = True
):
    """Cache parameters to the yaml file.

    Loads the yaml at ``DEFAULT_PARAMETERS_CONFIG_PATH``, overlays
    ``params.to_dict()`` on top of it, rewrites a few values into the form
    stored in the cache and saves the result back to the same path.

    Args:
        params: Parameter object; only its ``to_dict()`` method is used here.
            The resulting dict must contain ``"whisper"`` and ``"vad"``
            sections (either from ``params`` or from the cached yaml).
        file_format: Value stored under ``whisper.file_format``.
        add_timestamp: Value stored under ``whisper.add_timestamp``.

    Raises:
        KeyError: If neither the cached yaml nor ``params.to_dict()`` provides
            a ``"whisper"`` or ``"vad"`` section.
    """
    # NOTE(review): presumably load_yaml returns None for an empty yaml file;
    # fall back to an empty mapping so the merge below cannot fail on None.
    cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH) or {}
    param_to_cache = params.to_dict()

    # Shallow merge: a top-level section present in `param_to_cache` replaces
    # the cached section of the same name as a whole.
    cached_yaml = {**cached_params, **param_to_cache}

    whisper_cfg = cached_yaml["whisper"]
    whisper_cfg["add_timestamp"] = add_timestamp
    whisper_cfg["file_format"] = file_format

    # A non-empty token list is cached in its string form.
    suppress_token = whisper_cfg.get("suppress_tokens", None)
    if suppress_token and isinstance(suppress_token, list):
        whisper_cfg["suppress_tokens"] = str(suppress_token)

    lang = whisper_cfg.get("lang", None)
    if lang is None:
        whisper_cfg["lang"] = AUTOMATIC_DETECTION.unwrap()
    else:
        language_dict = whisper.tokenizer.LANGUAGES
        # Values that are not keys of the mapping (e.g. a value that was
        # already translated) are cached unchanged instead of raising KeyError.
        whisper_cfg["lang"] = language_dict.get(lang, lang)

    # An unset or infinite max speech duration is cached as the finite
    # GRADIO_NONE_NUMBER_MAX placeholder.
    vad_cfg = cached_yaml["vad"]
    if vad_cfg.get("max_speech_duration_s", float('inf')) == float('inf'):
        vad_cfg["max_speech_duration_s"] = GRADIO_NONE_NUMBER_MAX

    if cached_yaml:
        save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
|
| 701 |
|
| 702 |
@staticmethod
|
| 703 |
def resample_audio(audio: Union[str, np.ndarray],
|