Merge pull request #181 from jhj0517/feature/integrate-whisperx
- app.py +63 -27
- modules/diarize/__init__.py +0 -0
- modules/diarize/audio_loader.py +161 -0
- modules/diarize/diarize_pipeline.py +92 -0
- modules/diarize/diarizer.py +127 -0
- modules/translation/__init__.py +0 -0
- modules/{deepl_api.py → translation/deepl_api.py} +3 -3
- modules/{nllb_inference.py → translation/nllb_inference.py} +1 -1
- modules/{translation_base.py → translation/translation_base.py} +6 -6
- modules/utils/__init__.py +0 -0
- modules/{subtitle_manager.py → utils/subtitle_manager.py} +0 -0
- modules/{youtube_manager.py → utils/youtube_manager.py} +0 -0
- modules/whisper/__init__.py +0 -0
- modules/{faster_whisper_inference.py → whisper/faster_whisper_inference.py} +9 -10
- modules/{insanely_fast_whisper_inference.py → whisper/insanely_fast_whisper_inference.py} +9 -6
- modules/{whisper_Inference.py → whisper/whisper_Inference.py} +9 -7
- modules/{whisper_base.py → whisper/whisper_base.py} +65 -11
- modules/{whisper_parameter.py → whisper/whisper_parameter.py} +22 -3
- requirements.txt +2 -1
app.py
CHANGED
@@ -1,15 +1,14 @@
-import gradio as gr
 import os
 import argparse
 
-from modules.whisper_Inference import WhisperInference
-from modules.faster_whisper_inference import FasterWhisperInference
-from modules.insanely_fast_whisper_inference import InsanelyFastWhisperInference
-from modules.nllb_inference import NLLBInference
+from modules.whisper.whisper_Inference import WhisperInference
+from modules.whisper.faster_whisper_inference import FasterWhisperInference
+from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
+from modules.translation.nllb_inference import NLLBInference
 from ui.htmls import *
-from modules.youtube_manager import get_ytmetas
-from modules.deepl_api import DeepLAPI
-from modules.whisper_parameter import *
+from modules.utils.youtube_manager import get_ytmetas
+from modules.translation.deepl_api import DeepLAPI
+from modules.whisper.whisper_parameter import *
 
 
 class App:
@@ -28,28 +27,35 @@ class App:
         )
 
     def init_whisper(self):
+        # Temporal fix of the issue : https://github.com/jhj0517/Whisper-WebUI/issues/144
+        os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
+
         whisper_type = self.args.whisper_type.lower().strip()
 
         if whisper_type in ["faster_whisper", "faster-whisper", "fasterwhisper"]:
             whisper_inf = FasterWhisperInference(
                 model_dir=self.args.faster_whisper_model_dir,
-                output_dir=self.args.output_dir
+                output_dir=self.args.output_dir,
+                args=self.args
             )
         elif whisper_type in ["whisper"]:
             whisper_inf = WhisperInference(
                 model_dir=self.args.whisper_model_dir,
-                output_dir=self.args.output_dir
+                output_dir=self.args.output_dir,
+                args=self.args
            )
         elif whisper_type in ["insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
                               "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"]:
             whisper_inf = InsanelyFastWhisperInference(
                 model_dir=self.args.insanely_fast_whisper_model_dir,
-                output_dir=self.args.output_dir
+                output_dir=self.args.output_dir,
+                args=self.args
             )
         else:
             whisper_inf = FasterWhisperInference(
                 model_dir=self.args.faster_whisper_model_dir,
-                output_dir=self.args.output_dir
+                output_dir=self.args.output_dir,
+                args=self.args
             )
         return whisper_inf
 
@@ -87,7 +93,7 @@ class App:
                     cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
                 with gr.Row():
                     cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename", interactive=True)
-                with gr.Accordion("
+                with gr.Accordion("Advanced Parameters", open=False):
                     nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
                     nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
                     nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
@@ -98,14 +104,20 @@
                     tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
                     sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
                     nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=2.4, interactive=True)
-                with gr.Accordion("VAD
+                with gr.Accordion("VAD", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
                     cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
-                    sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
+                    sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5, info="Lower it to be more sensitive to small sounds.")
                     nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
                     nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=9999)
                     nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
                     nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
                     nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
+                with gr.Accordion("Diarization", open=False):
+                    cb_diarize = gr.Checkbox(label="Enable Diarization")
+                    tb_hf_token = gr.Text(label="HuggingFace Token", value="",
+                                          info="This is only needed the first time you download the model. If you already have models, you don't need to enter. "
+                                               "To download the model, you must manually go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and agree to their requirement.")
+                    dd_diarization_device = gr.Dropdown(label="Device", choices=self.whisper_inf.diarizer.get_available_device(), value=self.whisper_inf.diarizer.get_device())
                 with gr.Accordion("Insanely Fast Whisper Parameters", open=False, visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
                     nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
                     nb_batch_size = gr.Number(label="Batch Size", value=24, precision=0)
@@ -138,10 +150,13 @@ class App:
                                                    window_size_sample=nb_window_size_sample,
                                                    speech_pad_ms=nb_speech_pad_ms,
                                                    chunk_length_s=nb_chunk_length_s,
-                                                   batch_size=nb_batch_size
+                                                   batch_size=nb_batch_size,
+                                                   is_diarize=cb_diarize,
+                                                   hf_token=tb_hf_token,
+                                                   diarization_device=dd_diarization_device)
 
                 btn_run.click(fn=self.whisper_inf.transcribe_file,
-                              inputs=params + whisper_params.
+                              inputs=params + whisper_params.as_list(),
                               outputs=[tb_indicator, files_subtitles])
                 btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
                 dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
@@ -166,7 +181,7 @@ class App:
                 with gr.Row():
                     cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
                                                interactive=True)
-                with gr.Accordion("
+                with gr.Accordion("Advanced Parameters", open=False):
                     nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
                     nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
                     nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
@@ -177,14 +192,20 @@
                     tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
                     sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
                     nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=2.4, interactive=True)
-                with gr.Accordion("VAD
+                with gr.Accordion("VAD", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
                     cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
-                    sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
+                    sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5, info="Lower it to be more sensitive to small sounds.")
                     nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
                     nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=9999)
                     nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
                     nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
                     nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
+                with gr.Accordion("Diarization", open=False):
+                    cb_diarize = gr.Checkbox(label="Enable Diarization")
+                    tb_hf_token = gr.Text(label="HuggingFace Token", value="",
+                                          info="This is only needed the first time you download the model. If you already have models, you don't need to enter. "
+                                               "To download the model, you must manually go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and agree to their requirement.")
+                    dd_diarization_device = gr.Dropdown(label="Device", choices=self.whisper_inf.diarizer.get_available_device(), value=self.whisper_inf.diarizer.get_device())
                 with gr.Accordion("Insanely Fast Whisper Parameters", open=False,
                                   visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
                     nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
@@ -218,10 +239,13 @@ class App:
                                                    window_size_sample=nb_window_size_sample,
                                                    speech_pad_ms=nb_speech_pad_ms,
                                                    chunk_length_s=nb_chunk_length_s,
-                                                   batch_size=nb_batch_size
+                                                   batch_size=nb_batch_size,
+                                                   is_diarize=cb_diarize,
+                                                   hf_token=tb_hf_token,
+                                                   diarization_device=dd_diarization_device)
 
                 btn_run.click(fn=self.whisper_inf.transcribe_youtube,
-                              inputs=params + whisper_params.
+                              inputs=params + whisper_params.as_list(),
                               outputs=[tb_indicator, files_subtitles])
                 tb_youtubelink.change(get_ytmetas, inputs=[tb_youtubelink],
                                       outputs=[img_thumbnail, tb_title, tb_description])
@@ -239,7 +263,7 @@ class App:
                     dd_file_format = gr.Dropdown(["SRT", "WebVTT", "txt"], value="SRT", label="File Format")
                 with gr.Row():
                     cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
-                with gr.Accordion("
+                with gr.Accordion("Advanced Parameters", open=False):
                     nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
                     nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
                     nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
@@ -249,14 +273,22 @@
                     cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", value=True, interactive=True)
                     tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
                     sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
-                with gr.Accordion("VAD
+                with gr.Accordion("VAD", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
                     cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
-                    sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
+                    sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5, info="Lower it to be more sensitive to small sounds.")
                     nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
                     nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=9999)
                     nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
                     nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
                     nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
+                with gr.Accordion("Diarization", open=False):
+                    cb_diarize = gr.Checkbox(label="Enable Diarization")
+                    tb_hf_token = gr.Text(label="HuggingFace Token", value="",
+                                          info="This is only needed the first time you download the model. If you already have models, you don't need to enter. "
+                                               "To download the model, you must manually go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and agree to their requirement.")
+                    dd_diarization_device = gr.Dropdown(label="Device",
+                                                        choices=self.whisper_inf.diarizer.get_available_device(),
+                                                        value=self.whisper_inf.diarizer.get_device())
                 with gr.Accordion("Insanely Fast Whisper Parameters", open=False,
                                   visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
                     nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
@@ -290,10 +322,13 @@ class App:
                                                    window_size_sample=nb_window_size_sample,
                                                    speech_pad_ms=nb_speech_pad_ms,
                                                    chunk_length_s=nb_chunk_length_s,
-                                                   batch_size=nb_batch_size
+                                                   batch_size=nb_batch_size,
+                                                   is_diarize=cb_diarize,
+                                                   hf_token=tb_hf_token,
+                                                   diarization_device=dd_diarization_device)
 
                 btn_run.click(fn=self.whisper_inf.transcribe_mic,
-                              inputs=params + whisper_params.
+                              inputs=params + whisper_params.as_list(),
                               outputs=[tb_indicator, files_subtitles])
                 btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
                 dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
@@ -392,6 +427,7 @@ parser.add_argument('--api_open', type=bool, default=False, nargs='?', const=Tru
 parser.add_argument('--whisper_model_dir', type=str, default=os.path.join("models", "Whisper"), help='Directory path of the whisper model')
 parser.add_argument('--faster_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "faster-whisper"), help='Directory path of the faster-whisper model')
 parser.add_argument('--insanely_fast_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "insanely-fast-whisper"), help='Directory path of the insanely-fast-whisper model')
+parser.add_argument('--diarization_model_dir', type=str, default=os.path.join("models", "Diarization"), help='Directory path of the diarization model')
 parser.add_argument('--nllb_model_dir', type=str, default=os.path.join("models", "NLLB"), help='Directory path of the Facebook NLLB model')
 parser.add_argument('--output_dir', type=str, default=os.path.join("outputs"), help='Directory path of the outputs')
 _args = parser.parse_args()
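For context, a minimal sketch of how the new wiring is meant to be used (illustrative only, not part of the diff): init_whisper() now forwards the whole argparse Namespace, and WhisperBase reads args.diarization_model_dir from it to build a Diarizer. The values below are the parser defaults shown above; fields of args used elsewhere in the codebase may be missing from this sketch.

from argparse import Namespace
from modules.whisper.faster_whisper_inference import FasterWhisperInference

# Illustrative sketch: mirrors what init_whisper() does, using the parser defaults.
args = Namespace(
    faster_whisper_model_dir="models/Whisper/faster-whisper",
    diarization_model_dir="models/Diarization",  # new argument in this PR
    output_dir="outputs",
)
whisper_inf = FasterWhisperInference(
    model_dir=args.faster_whisper_model_dir,
    output_dir=args.output_dir,
    args=args,
)
print(whisper_inf.diarizer.get_available_device())  # e.g. ['cpu', 'cuda']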
modules/diarize/__init__.py
ADDED
File without changes
modules/diarize/audio_loader.py
ADDED
@@ -0,0 +1,161 @@
+import os
+import subprocess
+from functools import lru_cache
+from typing import Optional, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+def exact_div(x, y):
+    assert x % y == 0
+    return x // y
+
+# hard-coded audio hyperparameters
+SAMPLE_RATE = 16000
+N_FFT = 400
+HOP_LENGTH = 160
+CHUNK_LENGTH = 30
+N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
+N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000 frames in a mel spectrogram input
+
+N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
+FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
+TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token
+
+
+def load_audio(file: str, sr: int = SAMPLE_RATE):
+    """
+    Open an audio file and read as mono waveform, resampling as necessary
+
+    Parameters
+    ----------
+    file: str
+        The audio file to open
+
+    sr: int
+        The sample rate to resample the audio if necessary
+
+    Returns
+    -------
+    A NumPy array containing the audio waveform, in float32 dtype.
+    """
+    try:
+        # Launches a subprocess to decode audio while down-mixing and resampling as necessary.
+        # Requires the ffmpeg CLI to be installed.
+        cmd = [
+            "ffmpeg",
+            "-nostdin",
+            "-threads",
+            "0",
+            "-i",
+            file,
+            "-f",
+            "s16le",
+            "-ac",
+            "1",
+            "-acodec",
+            "pcm_s16le",
+            "-ar",
+            str(sr),
+            "-",
+        ]
+        out = subprocess.run(cmd, capture_output=True, check=True).stdout
+    except subprocess.CalledProcessError as e:
+        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
+
+    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
+
+
+def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
+    """
+    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
+    """
+    if torch.is_tensor(array):
+        if array.shape[axis] > length:
+            array = array.index_select(
+                dim=axis, index=torch.arange(length, device=array.device)
+            )
+
+        if array.shape[axis] < length:
+            pad_widths = [(0, 0)] * array.ndim
+            pad_widths[axis] = (0, length - array.shape[axis])
+            array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
+    else:
+        if array.shape[axis] > length:
+            array = array.take(indices=range(length), axis=axis)
+
+        if array.shape[axis] < length:
+            pad_widths = [(0, 0)] * array.ndim
+            pad_widths[axis] = (0, length - array.shape[axis])
+            array = np.pad(array, pad_widths)
+
+    return array
+
+
+@lru_cache(maxsize=None)
+def mel_filters(device, n_mels: int) -> torch.Tensor:
+    """
+    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
+    Allows decoupling librosa dependency; saved using:
+
+        np.savez_compressed(
+            "mel_filters.npz",
+            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
+        )
+    """
+    assert n_mels in [80, 128], f"Unsupported n_mels: {n_mels}"
+    with np.load(
+        os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
+    ) as f:
+        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
+
+
+def log_mel_spectrogram(
+    audio: Union[str, np.ndarray, torch.Tensor],
+    n_mels: int,
+    padding: int = 0,
+    device: Optional[Union[str, torch.device]] = None,
+):
+    """
+    Compute the log-Mel spectrogram of
+
+    Parameters
+    ----------
+    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
+        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
+
+    n_mels: int
+        The number of Mel-frequency filters, only 80 is supported
+
+    padding: int
+        Number of zero samples to pad to the right
+
+    device: Optional[Union[str, torch.device]]
+        If given, the audio tensor is moved to this device before STFT
+
+    Returns
+    -------
+    torch.Tensor, shape = (80, n_frames)
+        A Tensor that contains the Mel spectrogram
+    """
+    if not torch.is_tensor(audio):
+        if isinstance(audio, str):
+            audio = load_audio(audio)
+        audio = torch.from_numpy(audio)
+
+    if device is not None:
+        audio = audio.to(device)
+    if padding > 0:
+        audio = F.pad(audio, (0, padding))
+    window = torch.hann_window(N_FFT).to(audio.device)
+    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
+    magnitudes = stft[..., :-1].abs() ** 2
+
+    filters = mel_filters(audio.device, n_mels)
+    mel_spec = filters @ magnitudes
+
+    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
+    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+    log_spec = (log_spec + 4.0) / 4.0
+    return log_spec
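This file mirrors the audio utilities from OpenAI Whisper / whisperx (the PR branch is named feature/integrate-whisperx). A hedged usage sketch — "sample.wav" is a placeholder path, load_audio shells out to ffmpeg so the ffmpeg CLI must be on PATH, and mel_filters() expects an assets/mel_filters.npz next to this module:

from modules.diarize.audio_loader import SAMPLE_RATE, load_audio, log_mel_spectrogram, pad_or_trim

audio = load_audio("sample.wav")             # mono float32 at 16 kHz
print(len(audio) / SAMPLE_RATE, "seconds")

chunk = pad_or_trim(audio)                   # pad or cut to N_SAMPLES (30 s)
mel = log_mel_spectrogram(chunk, n_mels=80)  # 80 matches the bundled filterbank
print(mel.shape)                             # torch.Size([80, 3000])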
modules/diarize/diarize_pipeline.py
ADDED
@@ -0,0 +1,92 @@
+import numpy as np
+import pandas as pd
+import os
+from pyannote.audio import Pipeline
+from typing import Optional, Union
+import torch
+
+from modules.diarize.audio_loader import load_audio, SAMPLE_RATE
+
+
+class DiarizationPipeline:
+    def __init__(
+        self,
+        model_name="pyannote/speaker-diarization-3.1",
+        cache_dir: str = os.path.join("models", "Diarization"),
+        use_auth_token=None,
+        device: Optional[Union[str, torch.device]] = "cpu",
+    ):
+        if isinstance(device, str):
+            device = torch.device(device)
+        self.model = Pipeline.from_pretrained(
+            model_name,
+            use_auth_token=use_auth_token,
+            cache_dir=cache_dir
+        ).to(device)
+
+    def __call__(self, audio: Union[str, np.ndarray], min_speakers=None, max_speakers=None):
+        if isinstance(audio, str):
+            audio = load_audio(audio)
+        audio_data = {
+            'waveform': torch.from_numpy(audio[None, :]),
+            'sample_rate': SAMPLE_RATE
+        }
+        segments = self.model(audio_data, min_speakers=min_speakers, max_speakers=max_speakers)
+        diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
+        diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start)
+        diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end)
+        return diarize_df
+
+
+def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
+    transcript_segments = transcript_result["segments"]
+    for seg in transcript_segments:
+        # assign speaker to segment (if any)
+        diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'],
+                                                                                            seg['start'])
+        diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start'])
+
+        intersected = diarize_df[diarize_df["intersection"] > 0]
+
+        speaker = None
+        if len(intersected) > 0:
+            # Choosing most strong intersection
+            speaker = intersected.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
+        elif fill_nearest:
+            # Otherwise choosing closest
+            speaker = diarize_df.sort_values(by=["intersection"], ascending=False)["speaker"].values[0]
+
+        if speaker is not None:
+            seg["speaker"] = speaker
+
+        # assign speaker to words
+        if 'words' in seg:
+            for word in seg['words']:
+                if 'start' in word:
+                    diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(
+                        diarize_df['start'], word['start'])
+                    diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'],
+                                                                                                  word['start'])
+
+                    intersected = diarize_df[diarize_df["intersection"] > 0]
+
+                    word_speaker = None
+                    if len(intersected) > 0:
+                        # Choosing most strong intersection
+                        word_speaker = \
+                            intersected.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
+                    elif fill_nearest:
+                        # Otherwise choosing closest
+                        word_speaker = diarize_df.sort_values(by=["intersection"], ascending=False)["speaker"].values[0]
+
+                    if word_speaker is not None:
+                        word["speaker"] = word_speaker
+
+    return transcript_result
+
+
+class Segment:
+    def __init__(self, start, end, speaker=None):
+        self.start = start
+        self.end = end
+        self.speaker = speaker
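A hedged usage sketch of the pipeline above — "hf_..." and "sample.wav" are placeholders, and pyannote/speaker-diarization-3.1 is a gated model, so the token is needed until the weights have been downloaded once:

from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers

pipeline = DiarizationPipeline(use_auth_token="hf_...", device="cpu")
diarize_df = pipeline("sample.wav")  # DataFrame with segment, label, speaker, start, end

# Attach the speaker with the largest time overlap to each transcribed segment.
transcript = {"segments": [{"start": 0.0, "end": 2.5, "text": " hello there"}]}
result = assign_word_speakers(diarize_df, transcript)
print(result["segments"][0].get("speaker"))  # e.g. "SPEAKER_00"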
modules/diarize/diarizer.py
ADDED
@@ -0,0 +1,127 @@
+import os
+import torch
+from typing import List
+import time
+import logging
+
+from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
+from modules.diarize.audio_loader import load_audio
+
+class Diarizer:
+    def __init__(self,
+                 model_dir: str = os.path.join("models", "Diarization")
+                 ):
+        self.device = self.get_device()
+        self.available_device = self.get_available_device()
+        self.compute_type = "float16"
+        self.model_dir = model_dir
+        os.makedirs(self.model_dir, exist_ok=True)
+        self.pipe = None
+
+    def run(self,
+            audio: str,
+            transcribed_result: List[dict],
+            use_auth_token: str,
+            device: str
+            ):
+        """
+        Diarize transcribed result as a post-processing
+
+        Parameters
+        ----------
+        audio: Union[str, BinaryIO, np.ndarray]
+            Audio input. This can be file path or binary type.
+        transcribed_result: List[dict]
+            transcribed result through whisper.
+        use_auth_token: str
+            Huggingface token with READ permission. This is only needed the first time you download the model.
+            You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model.
+        device: str
+            Device for diarization.
+
+        Returns
+        ----------
+        segments_result: List[dict]
+            list of dicts that includes start, end timestamps and transcribed text
+        elapsed_time: float
+            elapsed time for running
+        """
+        start_time = time.time()
+
+        if (device != self.device
+                or self.pipe is None):
+            self.update_pipe(
+                device=device,
+                use_auth_token=use_auth_token
+            )
+
+        audio = load_audio(audio)
+
+        diarization_segments = self.pipe(audio)
+        diarized_result = assign_word_speakers(
+            diarization_segments,
+            {"segments": transcribed_result}
+        )
+
+        for segment in diarized_result["segments"]:
+            speaker = "None"
+            if "speaker" in segment:
+                speaker = segment["speaker"]
+            segment["text"] = speaker + "|" + segment["text"][1:]
+
+        elapsed_time = time.time() - start_time
+        return diarized_result["segments"], elapsed_time
+
+    def update_pipe(self,
+                    use_auth_token: str,
+                    device: str
+                    ):
+        """
+        Set pipeline for diarization
+
+        Parameters
+        ----------
+        use_auth_token: str
+            Huggingface token with READ permission. This is only needed the first time you download the model.
+            You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model.
+        device: str
+            Device for diarization.
+        """
+
+        os.makedirs(self.model_dir, exist_ok=True)
+
+        if (not os.listdir(self.model_dir) and
+                not use_auth_token):
+            print(
+                "\nFailed to diarize. You need huggingface token and agree to their requirements to download the diarization model.\n"
+                "Go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and follow their instructions to download the model.\n"
+            )
+            return
+
+        logger = logging.getLogger("speechbrain.utils.train_logger")
+        # Disable redundant torchvision warning message
+        logger.disabled = True
+        self.pipe = DiarizationPipeline(
+            use_auth_token=use_auth_token,
+            device=device,
+            cache_dir=self.model_dir
+        )
+        logger.disabled = False
+
+    @staticmethod
+    def get_device():
+        if torch.cuda.is_available():
+            return "cuda"
+        elif torch.backends.mps.is_available():
+            return "mps"
+        else:
+            return "cpu"
+
+    @staticmethod
+    def get_available_device():
+        devices = ["cpu"]
+        if torch.cuda.is_available():
+            devices.append("cuda")
+        elif torch.backends.mps.is_available():
+            devices.append("mps")
+        return devices
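A hedged sketch of the wrapper above (path and token are placeholders): run() lazily builds the pipeline on first use or on a device change, then rewrites each segment's text to "SPEAKER_XX|<text>", which is how the speaker label later surfaces in the generated subtitles. Note that segment["text"][1:] drops the first character, so it assumes Whisper's usual leading space.

from modules.diarize.diarizer import Diarizer

diarizer = Diarizer()  # model_dir defaults to models/Diarization
segments = [{"start": 0.0, "end": 2.5, "text": " hello there"}]

diarized_segments, elapsed = diarizer.run(
    audio="sample.wav",          # placeholder path
    transcribed_result=segments,
    use_auth_token="hf_...",     # placeholder; only needed for the first download
    device="cpu",
)
print(diarized_segments[0]["text"])  # e.g. "SPEAKER_00|hello there"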
modules/translation/__init__.py
ADDED
File without changes
modules/{deepl_api.py → translation/deepl_api.py}
RENAMED
@@ -4,7 +4,7 @@ import os
 from datetime import datetime
 import gradio as gr
 
-from modules.subtitle_manager import *
+from modules.utils.subtitle_manager import *
 
 """
 This is written with reference to the DeepL API documentation.
@@ -144,7 +144,7 @@ class DeepLAPI:
                 timestamp = datetime.now().strftime("%m%d%H%M%S")
 
                 file_name = file_name[:-9]
-                output_path = os.path.join(self.output_dir, "
+                output_path = os.path.join(self.output_dir, "", f"{file_name}-{timestamp}.srt")
                 write_file(subtitle, output_path)
 
             elif file_ext == ".vtt":
@@ -164,7 +164,7 @@ class DeepLAPI:
                 timestamp = datetime.now().strftime("%m%d%H%M%S")
 
                 file_name = file_name[:-9]
-                output_path = os.path.join(self.output_dir, "
+                output_path = os.path.join(self.output_dir, "", f"{file_name}-{timestamp}.vtt")
 
                 write_file(subtitle, output_path)
modules/{nllb_inference.py → translation/nllb_inference.py}
RENAMED
@@ -2,7 +2,7 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 import gradio as gr
 import os
 
-from modules.translation_base import TranslationBase
+from modules.translation.translation_base import TranslationBase
 
 
 class NLLBInference(TranslationBase):
modules/{translation_base.py → translation/translation_base.py}
RENAMED
@@ -5,8 +5,8 @@ from abc import ABC, abstractmethod
 from typing import List
 from datetime import datetime
 
-from modules.whisper_parameter import *
-from modules.subtitle_manager import *
+from modules.whisper.whisper_parameter import *
+from modules.utils.subtitle_manager import *
 
 
 class TranslationBase(ABC):
@@ -90,9 +90,9 @@ class TranslationBase(ABC):
 
                 timestamp = datetime.now().strftime("%m%d%H%M%S")
                 if add_timestamp:
-                    output_path = os.path.join("outputs", "
+                    output_path = os.path.join("outputs", "", f"{file_name}-{timestamp}.srt")
                 else:
-                    output_path = os.path.join("outputs", "
+                    output_path = os.path.join("outputs", "", f"{file_name}.srt")
 
             elif file_ext == ".vtt":
                 parsed_dicts = parse_vtt(file_path=file_path)
@@ -105,9 +105,9 @@ class TranslationBase(ABC):
 
                 timestamp = datetime.now().strftime("%m%d%H%M%S")
                 if add_timestamp:
-                    output_path = os.path.join(self.output_dir, "
+                    output_path = os.path.join(self.output_dir, "", f"{file_name}-{timestamp}.vtt")
                 else:
-                    output_path = os.path.join(self.output_dir, "
+                    output_path = os.path.join(self.output_dir, "", f"{file_name}.vtt")
 
             write_file(subtitle, output_path)
             files_info[file_name] = subtitle
modules/utils/__init__.py
ADDED
File without changes
modules/{subtitle_manager.py → utils/subtitle_manager.py}
RENAMED
File without changes
modules/{youtube_manager.py → utils/youtube_manager.py}
RENAMED
File without changes
modules/whisper/__init__.py
ADDED
File without changes
modules/{faster_whisper_inference.py → whisper/faster_whisper_inference.py}
RENAMED
@@ -2,28 +2,27 @@ import os
 import time
 import numpy as np
 from typing import BinaryIO, Union, Tuple, List
-
 import faster_whisper
 from faster_whisper.vad import VadOptions
 import ctranslate2
 import whisper
 import gradio as gr
+from argparse import Namespace
 
-from modules.whisper_parameter import *
-from modules.whisper_base import WhisperBase
-
-# Temporal fix of the issue : https://github.com/jhj0517/Whisper-WebUI/issues/144
-os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
+from modules.whisper.whisper_parameter import *
+from modules.whisper.whisper_base import WhisperBase
 
 
 class FasterWhisperInference(WhisperBase):
     def __init__(self,
                  model_dir: str,
-                 output_dir: str
+                 output_dir: str,
+                 args: Namespace
                  ):
         super().__init__(
             model_dir=model_dir,
-            output_dir=output_dir
+            output_dir=output_dir,
+            args=args
         )
         self.model_paths = self.get_model_paths()
         self.available_models = self.model_paths.keys()
@@ -45,7 +44,7 @@ class FasterWhisperInference(WhisperBase):
         progress: gr.Progress
             Indicator to show progress directly in gradio.
         *whisper_params: tuple
-
+            Parameters related with whisper. This will be dealt with "WhisperParameters" data class
 
         Returns
         ----------
@@ -56,7 +55,7 @@ class FasterWhisperInference(WhisperBase):
         """
         start_time = time.time()
 
-        params = WhisperParameters.
+        params = WhisperParameters.as_value(*whisper_params)
 
         if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
             self.update_model(params.model_size, params.compute_type, progress)
modules/{insanely_fast_whisper_inference.py → whisper/insanely_fast_whisper_inference.py}
RENAMED
@@ -9,19 +9,22 @@ import gradio as gr
 from huggingface_hub import hf_hub_download
 import whisper
 from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn
+from argparse import Namespace
 
-from modules.whisper_parameter import *
-from modules.whisper_base import WhisperBase
+from modules.whisper.whisper_parameter import *
+from modules.whisper.whisper_base import WhisperBase
 
 
 class InsanelyFastWhisperInference(WhisperBase):
     def __init__(self,
                  model_dir: str,
-                 output_dir: str
+                 output_dir: str,
+                 args: Namespace
                  ):
         super().__init__(
             model_dir=model_dir,
-            output_dir=output_dir
+            output_dir=output_dir,
+            args=args
         )
         openai_models = whisper.available_models()
         distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
@@ -43,7 +46,7 @@ class InsanelyFastWhisperInference(WhisperBase):
         progress: gr.Progress
             Indicator to show progress directly in gradio.
         *whisper_params: tuple
-
+            Parameters related with whisper. This will be dealt with "WhisperParameters" data class
 
         Returns
         ----------
@@ -53,7 +56,7 @@ class InsanelyFastWhisperInference(WhisperBase):
             elapsed time for transcription
         """
         start_time = time.time()
-        params = WhisperParameters.
+        params = WhisperParameters.as_value(*whisper_params)
 
         if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
             self.update_model(params.model_size, params.compute_type, progress)
modules/{whisper_Inference.py → whisper/whisper_Inference.py}
RENAMED
@@ -1,23 +1,25 @@
 import whisper
 import gradio as gr
 import time
-import os
 from typing import BinaryIO, Union, Tuple, List
 import numpy as np
 import torch
+from argparse import Namespace
 
-from modules.whisper_base import WhisperBase
-from modules.whisper_parameter import *
+from modules.whisper.whisper_base import WhisperBase
+from modules.whisper.whisper_parameter import *
 
 
 class WhisperInference(WhisperBase):
     def __init__(self,
                  model_dir: str,
-                 output_dir: str
+                 output_dir: str,
+                 args: Namespace
                  ):
         super().__init__(
             model_dir=model_dir,
-            output_dir=output_dir
+            output_dir=output_dir,
+            args=args
         )
 
     def transcribe(self,
@@ -35,7 +37,7 @@ class WhisperInference(WhisperBase):
         progress: gr.Progress
             Indicator to show progress directly in gradio.
         *whisper_params: tuple
-
+            Parameters related with whisper. This will be dealt with "WhisperParameters" data class
 
         Returns
         ----------
@@ -45,7 +47,7 @@ class WhisperInference(WhisperBase):
             elapsed time for transcription
         """
         start_time = time.time()
-        params = WhisperParameters.
+        params = WhisperParameters.as_value(*whisper_params)
 
         if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
             self.update_model(params.model_size, params.compute_type, progress)
modules/{whisper_base.py → whisper/whisper_base.py}
RENAMED
@@ -1,22 +1,24 @@
 import os
 import torch
-from typing import List
 import whisper
 import gradio as gr
 from abc import ABC, abstractmethod
 from typing import BinaryIO, Union, Tuple, List
 import numpy as np
 from datetime import datetime
+from argparse import Namespace
 
-from modules.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
-from modules.youtube_manager import get_ytdata, get_ytaudio
-from modules.whisper_parameter import *
+from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
+from modules.utils.youtube_manager import get_ytdata, get_ytaudio
+from modules.whisper.whisper_parameter import *
+from modules.diarize.diarizer import Diarizer
 
 
 class WhisperBase(ABC):
     def __init__(self,
                  model_dir: str,
-                 output_dir: str
+                 output_dir: str,
+                 args: Namespace
                  ):
         self.model = None
         self.current_model_size = None
@@ -30,6 +32,9 @@ class WhisperBase(ABC):
         self.device = self.get_device()
         self.available_compute_types = ["float16", "float32"]
         self.current_compute_type = "float16" if self.device == "cuda" else "float32"
+        self.diarizer = Diarizer(
+            model_dir=args.diarization_model_dir
+        )
 
     @abstractmethod
     def transcribe(self,
@@ -47,6 +52,55 @@
                    ):
         pass
 
+    def run(self,
+            audio: Union[str, BinaryIO, np.ndarray],
+            progress: gr.Progress,
+            *whisper_params,
+            ) -> Tuple[List[dict], float]:
+        """
+        Run transcription with conditional post-processing.
+        The diarization will be performed in post-processing if enabled.
+
+        Parameters
+        ----------
+        audio: Union[str, BinaryIO, np.ndarray]
+            Audio input. This can be file path or binary type.
+        progress: gr.Progress
+            Indicator to show progress directly in gradio.
+        *whisper_params: tuple
+            Parameters related with whisper. This will be dealt with "WhisperParameters" data class
+
+        Returns
+        ----------
+        segments_result: List[dict]
+            list of dicts that includes start, end timestamps and transcribed text
+        elapsed_time: float
+            elapsed time for running
+        """
+        params = WhisperParameters.as_value(*whisper_params)
+
+        if params.lang == "Automatic Detection":
+            params.lang = None
+        else:
+            language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
+            params.lang = language_code_dict[params.lang]
+
+        result, elapsed_time = self.transcribe(
+            audio,
+            progress,
+            *whisper_params
+        )
+
+        if params.is_diarize:
+            result, elapsed_time_diarization = self.diarizer.run(
+                audio=audio,
+                use_auth_token=params.hf_token,
+                transcribed_result=result,
+                device=self.device
+            )
+            elapsed_time += elapsed_time_diarization
+        return result, elapsed_time
+
     def transcribe_file(self,
                         files: list,
                         file_format: str,
@@ -68,7 +122,7 @@
         progress: gr.Progress
             Indicator to show progress directly in gradio.
         *whisper_params: tuple
-
+            Parameters related with whisper. This will be dealt with "WhisperParameters" data class
 
         Returns
         ----------
@@ -80,7 +134,7 @@
         try:
             files_info = {}
             for file in files:
-                transcribed_segments, time_for_task = self.
+                transcribed_segments, time_for_task = self.run(
                     file.name,
                     progress,
                     *whisper_params,
@@ -135,7 +189,7 @@
         progress: gr.Progress
             Indicator to show progress directly in gradio.
         *whisper_params: tuple
-
+            Parameters related with whisper. This will be dealt with "WhisperParameters" data class
 
         Returns
         ----------
@@ -146,7 +200,7 @@
         """
         try:
             progress(0, desc="Loading Audio..")
-            transcribed_segments, time_for_task = self.
+            transcribed_segments, time_for_task = self.run(
                 mic_audio,
                 progress,
                 *whisper_params,
@@ -190,7 +244,7 @@
         progress: gr.Progress
             Indicator to show progress directly in gradio.
         *whisper_params: tuple
-
+            Parameters related with whisper. This will be dealt with "WhisperParameters" data class
 
         Returns
         ----------
@@ -204,7 +258,7 @@
             yt = get_ytdata(youtube_link)
             audio = get_ytaudio(yt)
 
-            transcribed_segments, time_for_task = self.
+            transcribed_segments, time_for_task = self.run(
                 audio,
                 progress,
                 *whisper_params,
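The language handling added to run() leans on the code-to-name table that openai-whisper ships as whisper.tokenizer.LANGUAGES; a quick sketch of the reverse lookup it performs, assuming only that table's documented shape ("en" -> "english"):

import whisper

# whisper.tokenizer.LANGUAGES maps ISO codes to display names, e.g. "en" -> "english".
# run() inverts it so the UI's language label can be turned back into a code.
language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
assert language_code_dict["english"] == "en"

With the code resolved, run() calls transcribe() and, when is_diarize is set, feeds the result through self.diarizer.run() and adds the two elapsed times together.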
modules/{whisper_parameter.py → whisper/whisper_parameter.py}
RENAMED
@@ -27,6 +27,9 @@ class WhisperParameters:
     speech_pad_ms: gr.Number
     chunk_length_s: gr.Number
     batch_size: gr.Number
+    is_diarize: gr.Checkbox
+    hf_token: gr.Textbox
+    diarization_device: gr.Dropdown
     """
     A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing.
     This data class is used to mitigate the key-value problem between Gradio components and function parameters.
@@ -122,9 +125,19 @@
 
     batch_size: gr.Number
         This parameter is related with insanely-fast-whisper pipe. Batch size to pass to the pipe
+
+    is_diarize: gr.Checkbox
+        This parameter is related with whisperx. Boolean value that determines whether to diarize or not.
+
+    hf_token: gr.Textbox
+        This parameter is related with whisperx. Huggingface token is needed to download diarization models.
+        Read more about : https://huggingface.co/pyannote/speaker-diarization-3.1#requirements
+
+    diarization_device: gr.Dropdown
+        This parameter is related with whisperx. Device to run diarization model
     """
 
-    def 
+    def as_list(self) -> list:
         """
         Converts the data class attributes into a list, Use in Gradio UI before Gradio pre-processing.
         See more about Gradio pre-processing: : https://www.gradio.app/docs/components
@@ -136,7 +149,7 @@
         return [getattr(self, f.name) for f in fields(self)]
 
     @staticmethod
-    def 
+    def as_value(*args) -> 'WhisperValues':
         """
         To use Whisper parameters in function after Gradio post-processing.
         See more about Gradio post-processing: : https://www.gradio.app/docs/components
@@ -168,7 +181,10 @@
             window_size_samples=args[18],
             speech_pad_ms=args[19],
             chunk_length_s=args[20],
-            batch_size=args[21]
+            batch_size=args[21],
+            is_diarize=args[22],
+            hf_token=args[23],
+            diarization_device=args[24]
         )
 
 
@@ -196,6 +212,9 @@ class WhisperValues:
     speech_pad_ms: int
     chunk_length_s: int
     batch_size: int
+    is_diarize: bool
+    hf_token: str
+    diarization_device: str
     """
     A data class to use Whisper parameters.
     """
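Because as_value() unpacks by position (args[0] through args[24]), the Gradio inputs list built from as_list() must stay in exactly the field order of the dataclass, which is why the three new fields are appended at the end. A reduced, hypothetical miniature of that round trip (three fields standing in for the real twenty-five):

from dataclasses import dataclass, fields

@dataclass
class MiniParams:
    # Stand-ins for the real class's is_diarize / hf_token / diarization_device tail.
    is_diarize: bool
    hf_token: str
    diarization_device: str

    def as_list(self) -> list:
        # Same trick as WhisperParameters.as_list(): field order defines arg order.
        return [getattr(self, f.name) for f in fields(self)]

    @staticmethod
    def as_value(*args) -> 'MiniParams':
        # Positional unpacking mirrors WhisperParameters.as_value().
        return MiniParams(is_diarize=args[0], hf_token=args[1], diarization_device=args[2])

mini = MiniParams(True, "hf_xxx", "cuda")
assert MiniParams.as_value(*mini.as_list()) == mini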
requirements.txt
CHANGED
@@ -4,4 +4,5 @@ git+https://github.com/jhj0517/jhj0517-whisper.git
 faster-whisper==1.0.2
 transformers
 gradio==4.29.0
-pytube
+pytube
+pyannote.audio==3.3.1
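The new pyannote.audio pin backs the Diarizer introduced above. As context, a sketch of what loading the gated diarization model looks like with that library; the model name is taken from the huggingface link in the whisper_parameter.py docstring, and the exact calls inside modules/diarize/diarizer.py may differ:

from pyannote.audio import Pipeline

# pyannote/speaker-diarization-3.1 is gated: accept its license on Hugging Face
# first, then pass the token that WhisperParameters carries as hf_token.
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token="hf_your_token_here",
)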