Spaces:
Running
Running
jhj0517
commited on
Commit
·
201b316
1
Parent(s):
595b5f3
add `--diarization_model_dir` cli arg
Browse files- app.py +36 -16
- modules/diarize_pipeline.py +1 -1
- modules/diarizer.py +1 -1
- modules/faster_whisper_inference.py +5 -2
- modules/insanely_fast_whisper_inference.py +5 -2
- modules/whisper_Inference.py +5 -2
- modules/whisper_base.py +6 -2
- modules/whisper_parameter.py +6 -0
app.py
CHANGED
|
@@ -36,23 +36,27 @@ class App:
|
|
| 36 |
if whisper_type in ["faster_whisper", "faster-whisper", "fasterwhisper"]:
|
| 37 |
whisper_inf = FasterWhisperInference(
|
| 38 |
model_dir=self.args.faster_whisper_model_dir,
|
| 39 |
-
output_dir=self.args.output_dir
|
|
|
|
| 40 |
)
|
| 41 |
elif whisper_type in ["whisper"]:
|
| 42 |
whisper_inf = WhisperInference(
|
| 43 |
model_dir=self.args.whisper_model_dir,
|
| 44 |
-
output_dir=self.args.output_dir
|
|
|
|
| 45 |
)
|
| 46 |
elif whisper_type in ["insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
|
| 47 |
"insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"]:
|
| 48 |
whisper_inf = InsanelyFastWhisperInference(
|
| 49 |
model_dir=self.args.insanely_fast_whisper_model_dir,
|
| 50 |
-
output_dir=self.args.output_dir
|
|
|
|
| 51 |
)
|
| 52 |
else:
|
| 53 |
whisper_inf = FasterWhisperInference(
|
| 54 |
model_dir=self.args.faster_whisper_model_dir,
|
| 55 |
-
output_dir=self.args.output_dir
|
|
|
|
| 56 |
)
|
| 57 |
return whisper_inf
|
| 58 |
|
|
@@ -90,7 +94,7 @@ class App:
|
|
| 90 |
cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
|
| 91 |
with gr.Row():
|
| 92 |
cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename", interactive=True)
|
| 93 |
-
with gr.Accordion("
|
| 94 |
nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
|
| 95 |
nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
|
| 96 |
nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
|
|
@@ -101,7 +105,7 @@ class App:
|
|
| 101 |
tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
|
| 102 |
sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
|
| 103 |
nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=2.4, interactive=True)
|
| 104 |
-
with gr.Accordion("VAD
|
| 105 |
cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
|
| 106 |
sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
|
| 107 |
nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
|
|
@@ -109,12 +113,14 @@ class App:
|
|
| 109 |
nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
|
| 110 |
nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
|
| 111 |
nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
with gr.Accordion("Insanely Fast Whisper Parameters", open=False, visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
|
| 113 |
nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
|
| 114 |
nb_batch_size = gr.Number(label="Batch Size", value=24, precision=0)
|
| 115 |
-
with gr.Accordion("Diarization Parameters", open=False):
|
| 116 |
-
cb_diarize = gr.Checkbox(label="Enable Diarization")
|
| 117 |
-
tb_hf_token = gr.Text(label="HuggingFace Token", value="")
|
| 118 |
with gr.Row():
|
| 119 |
btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
|
| 120 |
with gr.Row():
|
|
@@ -146,7 +152,8 @@ class App:
|
|
| 146 |
chunk_length_s=nb_chunk_length_s,
|
| 147 |
batch_size=nb_batch_size,
|
| 148 |
is_diarize=cb_diarize,
|
| 149 |
-
hf_token=tb_hf_token
|
|
|
|
| 150 |
|
| 151 |
btn_run.click(fn=self.whisper_inf.transcribe_file,
|
| 152 |
inputs=params + whisper_params.as_list(),
|
|
@@ -174,7 +181,7 @@ class App:
|
|
| 174 |
with gr.Row():
|
| 175 |
cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
|
| 176 |
interactive=True)
|
| 177 |
-
with gr.Accordion("
|
| 178 |
nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
|
| 179 |
nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
|
| 180 |
nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
|
|
@@ -185,7 +192,7 @@ class App:
|
|
| 185 |
tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
|
| 186 |
sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
|
| 187 |
nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=2.4, interactive=True)
|
| 188 |
-
with gr.Accordion("VAD
|
| 189 |
cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
|
| 190 |
sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
|
| 191 |
nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
|
|
@@ -193,6 +200,11 @@ class App:
|
|
| 193 |
nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
|
| 194 |
nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
|
| 195 |
nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
with gr.Accordion("Insanely Fast Whisper Parameters", open=False,
|
| 197 |
visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
|
| 198 |
nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
|
|
@@ -228,7 +240,8 @@ class App:
|
|
| 228 |
chunk_length_s=nb_chunk_length_s,
|
| 229 |
batch_size=nb_batch_size,
|
| 230 |
is_diarize=cb_diarize,
|
| 231 |
-
hf_token=tb_hf_token
|
|
|
|
| 232 |
|
| 233 |
btn_run.click(fn=self.whisper_inf.transcribe_youtube,
|
| 234 |
inputs=params + whisper_params.as_list(),
|
|
@@ -249,7 +262,7 @@ class App:
|
|
| 249 |
dd_file_format = gr.Dropdown(["SRT", "WebVTT", "txt"], value="SRT", label="File Format")
|
| 250 |
with gr.Row():
|
| 251 |
cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
|
| 252 |
-
with gr.Accordion("
|
| 253 |
nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
|
| 254 |
nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
|
| 255 |
nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
|
|
@@ -259,7 +272,7 @@ class App:
|
|
| 259 |
cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", value=True, interactive=True)
|
| 260 |
tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
|
| 261 |
sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
|
| 262 |
-
with gr.Accordion("VAD
|
| 263 |
cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
|
| 264 |
sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
|
| 265 |
nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
|
|
@@ -267,6 +280,11 @@ class App:
|
|
| 267 |
nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
|
| 268 |
nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
|
| 269 |
nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
with gr.Accordion("Insanely Fast Whisper Parameters", open=False,
|
| 271 |
visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
|
| 272 |
nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
|
|
@@ -302,7 +320,8 @@ class App:
|
|
| 302 |
chunk_length_s=nb_chunk_length_s,
|
| 303 |
batch_size=nb_batch_size,
|
| 304 |
is_diarize=cb_diarize,
|
| 305 |
-
hf_token=tb_hf_token
|
|
|
|
| 306 |
|
| 307 |
btn_run.click(fn=self.whisper_inf.transcribe_mic,
|
| 308 |
inputs=params + whisper_params.as_list(),
|
|
@@ -404,6 +423,7 @@ parser.add_argument('--api_open', type=bool, default=False, nargs='?', const=Tru
|
|
| 404 |
parser.add_argument('--whisper_model_dir', type=str, default=os.path.join("models", "Whisper"), help='Directory path of the whisper model')
|
| 405 |
parser.add_argument('--faster_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "faster-whisper"), help='Directory path of the faster-whisper model')
|
| 406 |
parser.add_argument('--insanely_fast_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "insanely-fast-whisper"), help='Directory path of the insanely-fast-whisper model')
|
|
|
|
| 407 |
parser.add_argument('--nllb_model_dir', type=str, default=os.path.join("models", "NLLB"), help='Directory path of the Facebook NLLB model')
|
| 408 |
parser.add_argument('--output_dir', type=str, default=os.path.join("outputs"), help='Directory path of the outputs')
|
| 409 |
_args = parser.parse_args()
|
|
|
|
| 36 |
if whisper_type in ["faster_whisper", "faster-whisper", "fasterwhisper"]:
|
| 37 |
whisper_inf = FasterWhisperInference(
|
| 38 |
model_dir=self.args.faster_whisper_model_dir,
|
| 39 |
+
output_dir=self.args.output_dir,
|
| 40 |
+
args=self.args
|
| 41 |
)
|
| 42 |
elif whisper_type in ["whisper"]:
|
| 43 |
whisper_inf = WhisperInference(
|
| 44 |
model_dir=self.args.whisper_model_dir,
|
| 45 |
+
output_dir=self.args.output_dir,
|
| 46 |
+
args=self.args
|
| 47 |
)
|
| 48 |
elif whisper_type in ["insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
|
| 49 |
"insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"]:
|
| 50 |
whisper_inf = InsanelyFastWhisperInference(
|
| 51 |
model_dir=self.args.insanely_fast_whisper_model_dir,
|
| 52 |
+
output_dir=self.args.output_dir,
|
| 53 |
+
args=self.args
|
| 54 |
)
|
| 55 |
else:
|
| 56 |
whisper_inf = FasterWhisperInference(
|
| 57 |
model_dir=self.args.faster_whisper_model_dir,
|
| 58 |
+
output_dir=self.args.output_dir,
|
| 59 |
+
args=self.args
|
| 60 |
)
|
| 61 |
return whisper_inf
|
| 62 |
|
|
|
|
| 94 |
cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
|
| 95 |
with gr.Row():
|
| 96 |
cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename", interactive=True)
|
| 97 |
+
with gr.Accordion("Advanced Parameters", open=False):
|
| 98 |
nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
|
| 99 |
nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
|
| 100 |
nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
|
|
|
|
| 105 |
tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
|
| 106 |
sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
|
| 107 |
nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=2.4, interactive=True)
|
| 108 |
+
with gr.Accordion("VAD", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
|
| 109 |
cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
|
| 110 |
sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
|
| 111 |
nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
|
|
|
|
| 113 |
nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
|
| 114 |
nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
|
| 115 |
nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
|
| 116 |
+
with gr.Accordion("Diarization", open=False):
|
| 117 |
+
cb_diarize = gr.Checkbox(label="Enable Diarization")
|
| 118 |
+
tb_hf_token = gr.Text(label="HuggingFace Token", value="",
|
| 119 |
+
info="This is only needed the first time you download the model. If you already have models, you don't need to enter.")
|
| 120 |
+
dd_diarization_device = gr.Dropdown(label="Device", choices=self.whisper_inf.diarizer.get_available_device(), value=self.whisper_inf.diarizer.get_device())
|
| 121 |
with gr.Accordion("Insanely Fast Whisper Parameters", open=False, visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
|
| 122 |
nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
|
| 123 |
nb_batch_size = gr.Number(label="Batch Size", value=24, precision=0)
|
|
|
|
|
|
|
|
|
|
| 124 |
with gr.Row():
|
| 125 |
btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
|
| 126 |
with gr.Row():
|
|
|
|
| 152 |
chunk_length_s=nb_chunk_length_s,
|
| 153 |
batch_size=nb_batch_size,
|
| 154 |
is_diarize=cb_diarize,
|
| 155 |
+
hf_token=tb_hf_token,
|
| 156 |
+
diarization_device=dd_diarization_device)
|
| 157 |
|
| 158 |
btn_run.click(fn=self.whisper_inf.transcribe_file,
|
| 159 |
inputs=params + whisper_params.as_list(),
|
|
|
|
| 181 |
with gr.Row():
|
| 182 |
cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
|
| 183 |
interactive=True)
|
| 184 |
+
with gr.Accordion("Advanced Parameters", open=False):
|
| 185 |
nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
|
| 186 |
nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
|
| 187 |
nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
|
|
|
|
| 192 |
tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
|
| 193 |
sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
|
| 194 |
nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=2.4, interactive=True)
|
| 195 |
+
with gr.Accordion("VAD", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
|
| 196 |
cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
|
| 197 |
sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
|
| 198 |
nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
|
|
|
|
| 200 |
nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
|
| 201 |
nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
|
| 202 |
nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
|
| 203 |
+
with gr.Accordion("Diarization", open=False):
|
| 204 |
+
cb_diarize = gr.Checkbox(label="Enable Diarization")
|
| 205 |
+
tb_hf_token = gr.Text(label="HuggingFace Token", value="",
|
| 206 |
+
info="This is only needed the first time you download the model. If you already have models, you don't need to enter.")
|
| 207 |
+
dd_diarization_device = gr.Dropdown(label="Device", choices=self.whisper_inf.diarizer.get_available_device(), value=self.whisper_inf.diarizer.get_device())
|
| 208 |
with gr.Accordion("Insanely Fast Whisper Parameters", open=False,
|
| 209 |
visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
|
| 210 |
nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
|
|
|
|
| 240 |
chunk_length_s=nb_chunk_length_s,
|
| 241 |
batch_size=nb_batch_size,
|
| 242 |
is_diarize=cb_diarize,
|
| 243 |
+
hf_token=tb_hf_token,
|
| 244 |
+
diarization_device=dd_diarization_device)
|
| 245 |
|
| 246 |
btn_run.click(fn=self.whisper_inf.transcribe_youtube,
|
| 247 |
inputs=params + whisper_params.as_list(),
|
|
|
|
| 262 |
dd_file_format = gr.Dropdown(["SRT", "WebVTT", "txt"], value="SRT", label="File Format")
|
| 263 |
with gr.Row():
|
| 264 |
cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
|
| 265 |
+
with gr.Accordion("Advanced Parameters", open=False):
|
| 266 |
nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
|
| 267 |
nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
|
| 268 |
nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
|
|
|
|
| 272 |
cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", value=True, interactive=True)
|
| 273 |
tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
|
| 274 |
sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
|
| 275 |
+
with gr.Accordion("VAD", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
|
| 276 |
cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
|
| 277 |
sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
|
| 278 |
nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
|
|
|
|
| 280 |
nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
|
| 281 |
nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
|
| 282 |
nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
|
| 283 |
+
with gr.Accordion("Diarization", open=False):
|
| 284 |
+
cb_diarize = gr.Checkbox(label="Enable Diarization")
|
| 285 |
+
tb_hf_token = gr.Text(label="HuggingFace Token", value="",
|
| 286 |
+
info="This is only needed the first time you download the model. If you already have models, you don't need to enter.")
|
| 287 |
+
dd_diarization_device = gr.Dropdown(label="Device", choices=self.whisper_inf.diarizer.get_available_device(), value=self.whisper_inf.diarizer.get_device())
|
| 288 |
with gr.Accordion("Insanely Fast Whisper Parameters", open=False,
|
| 289 |
visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
|
| 290 |
nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
|
|
|
|
| 320 |
chunk_length_s=nb_chunk_length_s,
|
| 321 |
batch_size=nb_batch_size,
|
| 322 |
is_diarize=cb_diarize,
|
| 323 |
+
hf_token=tb_hf_token,
|
| 324 |
+
diarization_device=dd_diarization_device)
|
| 325 |
|
| 326 |
btn_run.click(fn=self.whisper_inf.transcribe_mic,
|
| 327 |
inputs=params + whisper_params.as_list(),
|
|
|
|
| 423 |
parser.add_argument('--whisper_model_dir', type=str, default=os.path.join("models", "Whisper"), help='Directory path of the whisper model')
|
| 424 |
parser.add_argument('--faster_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "faster-whisper"), help='Directory path of the faster-whisper model')
|
| 425 |
parser.add_argument('--insanely_fast_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "insanely-fast-whisper"), help='Directory path of the insanely-fast-whisper model')
|
| 426 |
+
parser.add_argument('--diarization_model_dir', type=str, default=os.path.join("models", "Diarization"), help='Directory path of the diarization model')
|
| 427 |
parser.add_argument('--nllb_model_dir', type=str, default=os.path.join("models", "NLLB"), help='Directory path of the Facebook NLLB model')
|
| 428 |
parser.add_argument('--output_dir', type=str, default=os.path.join("outputs"), help='Directory path of the outputs')
|
| 429 |
_args = parser.parse_args()
|
modules/diarize_pipeline.py
CHANGED
|
@@ -11,7 +11,7 @@ class DiarizationPipeline:
|
|
| 11 |
def __init__(
|
| 12 |
self,
|
| 13 |
model_name="pyannote/speaker-diarization-3.1",
|
| 14 |
-
cache_dir: str = os.path.join("models", "
|
| 15 |
use_auth_token=None,
|
| 16 |
device: Optional[Union[str, torch.device]] = "cpu",
|
| 17 |
):
|
|
|
|
| 11 |
def __init__(
|
| 12 |
self,
|
| 13 |
model_name="pyannote/speaker-diarization-3.1",
|
| 14 |
+
cache_dir: str = os.path.join("models", "Diarization"),
|
| 15 |
use_auth_token=None,
|
| 16 |
device: Optional[Union[str, torch.device]] = "cpu",
|
| 17 |
):
|
modules/diarizer.py
CHANGED
|
@@ -9,7 +9,7 @@ from modules.diarize_pipeline import DiarizationPipeline
|
|
| 9 |
|
| 10 |
class Diarizer:
|
| 11 |
def __init__(self,
|
| 12 |
-
model_dir: str = os.path.join("models", "
|
| 13 |
):
|
| 14 |
self.device = self.get_device()
|
| 15 |
self.available_device = self.get_available_device()
|
|
|
|
| 9 |
|
| 10 |
class Diarizer:
|
| 11 |
def __init__(self,
|
| 12 |
+
model_dir: str = os.path.join("models", "Diarization")
|
| 13 |
):
|
| 14 |
self.device = self.get_device()
|
| 15 |
self.available_device = self.get_available_device()
|
modules/faster_whisper_inference.py
CHANGED
|
@@ -7,6 +7,7 @@ from faster_whisper.vad import VadOptions
|
|
| 7 |
import ctranslate2
|
| 8 |
import whisper
|
| 9 |
import gradio as gr
|
|
|
|
| 10 |
|
| 11 |
from modules.whisper_parameter import *
|
| 12 |
from modules.whisper_base import WhisperBase
|
|
@@ -15,11 +16,13 @@ from modules.whisper_base import WhisperBase
|
|
| 15 |
class FasterWhisperInference(WhisperBase):
|
| 16 |
def __init__(self,
|
| 17 |
model_dir: str,
|
| 18 |
-
output_dir: str
|
|
|
|
| 19 |
):
|
| 20 |
super().__init__(
|
| 21 |
model_dir=model_dir,
|
| 22 |
-
output_dir=output_dir
|
|
|
|
| 23 |
)
|
| 24 |
self.model_paths = self.get_model_paths()
|
| 25 |
self.available_models = self.model_paths.keys()
|
|
|
|
| 7 |
import ctranslate2
|
| 8 |
import whisper
|
| 9 |
import gradio as gr
|
| 10 |
+
from argparse import Namespace
|
| 11 |
|
| 12 |
from modules.whisper_parameter import *
|
| 13 |
from modules.whisper_base import WhisperBase
|
|
|
|
| 16 |
class FasterWhisperInference(WhisperBase):
|
| 17 |
def __init__(self,
|
| 18 |
model_dir: str,
|
| 19 |
+
output_dir: str,
|
| 20 |
+
args: Namespace
|
| 21 |
):
|
| 22 |
super().__init__(
|
| 23 |
model_dir=model_dir,
|
| 24 |
+
output_dir=output_dir,
|
| 25 |
+
args=args
|
| 26 |
)
|
| 27 |
self.model_paths = self.get_model_paths()
|
| 28 |
self.available_models = self.model_paths.keys()
|
modules/insanely_fast_whisper_inference.py
CHANGED
|
@@ -9,6 +9,7 @@ import gradio as gr
|
|
| 9 |
from huggingface_hub import hf_hub_download
|
| 10 |
import whisper
|
| 11 |
from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn
|
|
|
|
| 12 |
|
| 13 |
from modules.whisper_parameter import *
|
| 14 |
from modules.whisper_base import WhisperBase
|
|
@@ -17,11 +18,13 @@ from modules.whisper_base import WhisperBase
|
|
| 17 |
class InsanelyFastWhisperInference(WhisperBase):
|
| 18 |
def __init__(self,
|
| 19 |
model_dir: str,
|
| 20 |
-
output_dir: str
|
|
|
|
| 21 |
):
|
| 22 |
super().__init__(
|
| 23 |
model_dir=model_dir,
|
| 24 |
-
output_dir=output_dir
|
|
|
|
| 25 |
)
|
| 26 |
openai_models = whisper.available_models()
|
| 27 |
distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
|
|
|
|
| 9 |
from huggingface_hub import hf_hub_download
|
| 10 |
import whisper
|
| 11 |
from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn
|
| 12 |
+
from argparse import Namespace
|
| 13 |
|
| 14 |
from modules.whisper_parameter import *
|
| 15 |
from modules.whisper_base import WhisperBase
|
|
|
|
| 18 |
class InsanelyFastWhisperInference(WhisperBase):
|
| 19 |
def __init__(self,
|
| 20 |
model_dir: str,
|
| 21 |
+
output_dir: str,
|
| 22 |
+
args: Namespace
|
| 23 |
):
|
| 24 |
super().__init__(
|
| 25 |
model_dir=model_dir,
|
| 26 |
+
output_dir=output_dir,
|
| 27 |
+
args=args
|
| 28 |
)
|
| 29 |
openai_models = whisper.available_models()
|
| 30 |
distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
|
modules/whisper_Inference.py
CHANGED
|
@@ -5,6 +5,7 @@ import os
|
|
| 5 |
from typing import BinaryIO, Union, Tuple, List
|
| 6 |
import numpy as np
|
| 7 |
import torch
|
|
|
|
| 8 |
|
| 9 |
from modules.whisper_base import WhisperBase
|
| 10 |
from modules.whisper_parameter import *
|
|
@@ -13,11 +14,13 @@ from modules.whisper_parameter import *
|
|
| 13 |
class WhisperInference(WhisperBase):
|
| 14 |
def __init__(self,
|
| 15 |
model_dir: str,
|
| 16 |
-
output_dir: str
|
|
|
|
| 17 |
):
|
| 18 |
super().__init__(
|
| 19 |
model_dir=model_dir,
|
| 20 |
-
output_dir=output_dir
|
|
|
|
| 21 |
)
|
| 22 |
|
| 23 |
def transcribe(self,
|
|
|
|
| 5 |
from typing import BinaryIO, Union, Tuple, List
|
| 6 |
import numpy as np
|
| 7 |
import torch
|
| 8 |
+
from argparse import Namespace
|
| 9 |
|
| 10 |
from modules.whisper_base import WhisperBase
|
| 11 |
from modules.whisper_parameter import *
|
|
|
|
| 14 |
class WhisperInference(WhisperBase):
|
| 15 |
def __init__(self,
|
| 16 |
model_dir: str,
|
| 17 |
+
output_dir: str,
|
| 18 |
+
args: Namespace
|
| 19 |
):
|
| 20 |
super().__init__(
|
| 21 |
model_dir=model_dir,
|
| 22 |
+
output_dir=output_dir,
|
| 23 |
+
args=args
|
| 24 |
)
|
| 25 |
|
| 26 |
def transcribe(self,
|
modules/whisper_base.py
CHANGED
|
@@ -7,6 +7,7 @@ from abc import ABC, abstractmethod
|
|
| 7 |
from typing import BinaryIO, Union, Tuple, List
|
| 8 |
import numpy as np
|
| 9 |
from datetime import datetime
|
|
|
|
| 10 |
import time
|
| 11 |
|
| 12 |
from modules.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
|
|
@@ -18,7 +19,8 @@ from modules.diarizer import Diarizer
|
|
| 18 |
class WhisperBase(ABC):
|
| 19 |
def __init__(self,
|
| 20 |
model_dir: str,
|
| 21 |
-
output_dir: str
|
|
|
|
| 22 |
):
|
| 23 |
self.model = None
|
| 24 |
self.current_model_size = None
|
|
@@ -32,7 +34,9 @@ class WhisperBase(ABC):
|
|
| 32 |
self.device = self.get_device()
|
| 33 |
self.available_compute_types = ["float16", "float32"]
|
| 34 |
self.current_compute_type = "float16" if self.device == "cuda" else "float32"
|
| 35 |
-
self.diarizer = Diarizer(
|
|
|
|
|
|
|
| 36 |
|
| 37 |
@abstractmethod
|
| 38 |
def transcribe(self,
|
|
|
|
| 7 |
from typing import BinaryIO, Union, Tuple, List
|
| 8 |
import numpy as np
|
| 9 |
from datetime import datetime
|
| 10 |
+
from argparse import Namespace
|
| 11 |
import time
|
| 12 |
|
| 13 |
from modules.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
|
|
|
|
| 19 |
class WhisperBase(ABC):
|
| 20 |
def __init__(self,
|
| 21 |
model_dir: str,
|
| 22 |
+
output_dir: str,
|
| 23 |
+
args: Namespace
|
| 24 |
):
|
| 25 |
self.model = None
|
| 26 |
self.current_model_size = None
|
|
|
|
| 34 |
self.device = self.get_device()
|
| 35 |
self.available_compute_types = ["float16", "float32"]
|
| 36 |
self.current_compute_type = "float16" if self.device == "cuda" else "float32"
|
| 37 |
+
self.diarizer = Diarizer(
|
| 38 |
+
model_dir=args.diarization_model_dir
|
| 39 |
+
)
|
| 40 |
|
| 41 |
@abstractmethod
|
| 42 |
def transcribe(self,
|
modules/whisper_parameter.py
CHANGED
|
@@ -29,6 +29,7 @@ class WhisperParameters:
|
|
| 29 |
batch_size: gr.Number
|
| 30 |
is_diarize: gr.Checkbox
|
| 31 |
hf_token: gr.Textbox
|
|
|
|
| 32 |
"""
|
| 33 |
A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing.
|
| 34 |
This data class is used to mitigate the key-value problem between Gradio components and function parameters.
|
|
@@ -131,6 +132,9 @@ class WhisperParameters:
|
|
| 131 |
hf_token: gr.Textbox
|
| 132 |
This parameter is related with whisperx. Huggingface token is needed to download diarization models.
|
| 133 |
Read more about : https://huggingface.co/pyannote/speaker-diarization-3.1#requirements
|
|
|
|
|
|
|
|
|
|
| 134 |
"""
|
| 135 |
|
| 136 |
def as_list(self) -> list:
|
|
@@ -180,6 +184,7 @@ class WhisperParameters:
|
|
| 180 |
batch_size=args[21],
|
| 181 |
is_diarize=args[22],
|
| 182 |
hf_token=args[23],
|
|
|
|
| 183 |
)
|
| 184 |
|
| 185 |
|
|
@@ -209,6 +214,7 @@ class WhisperValues:
|
|
| 209 |
batch_size: int
|
| 210 |
is_diarize: bool
|
| 211 |
hf_token: str
|
|
|
|
| 212 |
"""
|
| 213 |
A data class to use Whisper parameters.
|
| 214 |
"""
|
|
|
|
| 29 |
batch_size: gr.Number
|
| 30 |
is_diarize: gr.Checkbox
|
| 31 |
hf_token: gr.Textbox
|
| 32 |
+
diarization_device: gr.Dropdown
|
| 33 |
"""
|
| 34 |
A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing.
|
| 35 |
This data class is used to mitigate the key-value problem between Gradio components and function parameters.
|
|
|
|
| 132 |
hf_token: gr.Textbox
|
| 133 |
This parameter is related with whisperx. Huggingface token is needed to download diarization models.
|
| 134 |
Read more about : https://huggingface.co/pyannote/speaker-diarization-3.1#requirements
|
| 135 |
+
|
| 136 |
+
diarization_device: gr.Dropdown
|
| 137 |
+
This parameter is related with whisperx. Device to run diarization model
|
| 138 |
"""
|
| 139 |
|
| 140 |
def as_list(self) -> list:
|
|
|
|
| 184 |
batch_size=args[21],
|
| 185 |
is_diarize=args[22],
|
| 186 |
hf_token=args[23],
|
| 187 |
+
diarization_device=args[24]
|
| 188 |
)
|
| 189 |
|
| 190 |
|
|
|
|
| 214 |
batch_size: int
|
| 215 |
is_diarize: bool
|
| 216 |
hf_token: str
|
| 217 |
+
diarization_device: str
|
| 218 |
"""
|
| 219 |
A data class to use Whisper parameters.
|
| 220 |
"""
|