dahyedahye committed
Commit ae3884d · 1 Parent(s): 9271e46

Add application file

Files changed (38)
  1. models/models will be saved here.txt +0 -0
  2. modules/__init__.py +0 -0
  3. modules/__pycache__/__init__.cpython-310.pyc +0 -0
  4. modules/diarize/__init__.py +0 -0
  5. modules/diarize/__pycache__/__init__.cpython-310.pyc +0 -0
  6. modules/diarize/__pycache__/diarize_pipeline.cpython-310.pyc +0 -0
  7. modules/diarize/__pycache__/diarizer.cpython-310.pyc +0 -0
  8. modules/diarize/audio_loader.py +179 -0
  9. modules/diarize/diarize_pipeline.py +94 -0
  10. modules/diarize/diarizer.py +132 -0
  11. modules/translation/__init__.py +0 -0
  12. modules/translation/deepl_api.py +201 -0
  13. modules/translation/nllb_inference.py +276 -0
  14. modules/translation/translation_base.py +151 -0
  15. modules/utils/__init__.py +0 -0
  16. modules/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  17. modules/utils/__pycache__/files_manager.cpython-310.pyc +0 -0
  18. modules/utils/__pycache__/subtitle_manager.cpython-310.pyc +0 -0
  19. modules/utils/__pycache__/youtube_manager.cpython-310.pyc +0 -0
  20. modules/utils/files_manager.py +39 -0
  21. modules/utils/subtitle_manager.py +135 -0
  22. modules/utils/youtube_manager.py +15 -0
  23. modules/vad/__init__.py +0 -0
  24. modules/vad/silero_vad.py +264 -0
  25. modules/whisper/__init__.py +0 -0
  26. modules/whisper/__pycache__/__init__.cpython-310.pyc +0 -0
  27. modules/whisper/__pycache__/faster_whisper_inference.cpython-310.pyc +0 -0
  28. modules/whisper/__pycache__/whisper_base.cpython-310.pyc +0 -0
  29. modules/whisper/__pycache__/whisper_factory.cpython-310.pyc +0 -0
  30. modules/whisper/__pycache__/whisper_parameter.cpython-310.pyc +0 -0
  31. modules/whisper/faster_whisper_inference.py +191 -0
  32. modules/whisper/insanely_fast_whisper_inference.py +185 -0
  33. modules/whisper/whisper_Inference.py +101 -0
  34. modules/whisper/whisper_base.py +436 -0
  35. modules/whisper/whisper_factory.py +81 -0
  36. modules/whisper/whisper_parameter.py +277 -0
  37. outputs/outputs are saved here.txt +0 -0
  38. outputs/translations/outputs for translation are saved here.txt +0 -0
models/models will be saved here.txt ADDED
File without changes
modules/__init__.py ADDED
File without changes
modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (167 Bytes).
 
modules/diarize/__init__.py ADDED
File without changes
modules/diarize/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (175 Bytes).
 
modules/diarize/__pycache__/diarize_pipeline.cpython-310.pyc ADDED
Binary file (3.06 kB).
 
modules/diarize/__pycache__/diarizer.cpython-310.pyc ADDED
Binary file (4.14 kB).
 
modules/diarize/audio_loader.py ADDED
@@ -0,0 +1,179 @@
1
+ # Adapted from https://github.com/m-bain/whisperX/blob/main/whisperx/audio.py
2
+
3
+ import os
4
+ import subprocess
5
+ from functools import lru_cache
6
+ from typing import Optional, Union
7
+ from scipy.io.wavfile import write
8
+ import tempfile
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+ def exact_div(x, y):
15
+ assert x % y == 0
16
+ return x // y
17
+
18
+ # hard-coded audio hyperparameters
19
+ SAMPLE_RATE = 16000
20
+ N_FFT = 400
21
+ HOP_LENGTH = 160
22
+ CHUNK_LENGTH = 30
23
+ N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
24
+ N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
25
+
26
+ N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolution has stride 2
27
+ FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
28
+ TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
29
+
30
+
31
+ def load_audio(file: Union[str, np.ndarray], sr: int = SAMPLE_RATE) -> np.ndarray:
32
+ """
33
+ Open an audio file or process a numpy array containing audio data as mono waveform, resampling as necessary.
34
+
35
+ Parameters
36
+ ----------
37
+ file: Union[str, np.ndarray]
38
+ The audio file to open or a numpy array containing the audio data.
39
+
40
+ sr: int
41
+ The sample rate to resample the audio if necessary.
42
+
43
+ Returns
44
+ -------
45
+ A NumPy array containing the audio waveform, in float32 dtype.
46
+ """
47
+ if isinstance(file, np.ndarray):
48
+ if file.dtype != np.float32:
49
+ file = file.astype(np.float32)
50
+ if file.ndim > 1:
51
+ file = np.mean(file, axis=1)
52
+
53
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
54
+ write(temp_file.name, SAMPLE_RATE, (file * 32768).astype(np.int16))
55
+ temp_file_path = temp_file.name
56
+ temp_file.close()
57
+ else:
58
+ temp_file_path = file
59
+
60
+ try:
61
+ cmd = [
62
+ "ffmpeg",
63
+ "-nostdin",
64
+ "-threads",
65
+ "0",
66
+ "-i",
67
+ temp_file_path,
68
+ "-f",
69
+ "s16le",
70
+ "-ac",
71
+ "1",
72
+ "-acodec",
73
+ "pcm_s16le",
74
+ "-ar",
75
+ str(sr),
76
+ "-",
77
+ ]
78
+ out = subprocess.run(cmd, capture_output=True, check=True).stdout
79
+ except subprocess.CalledProcessError as e:
80
+ raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
81
+ finally:
82
+ if isinstance(file, np.ndarray):
83
+ os.remove(temp_file_path)
84
+
85
+ return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
86
+
87
+
88
+ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
89
+ """
90
+ Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
91
+ """
92
+ if torch.is_tensor(array):
93
+ if array.shape[axis] > length:
94
+ array = array.index_select(
95
+ dim=axis, index=torch.arange(length, device=array.device)
96
+ )
97
+
98
+ if array.shape[axis] < length:
99
+ pad_widths = [(0, 0)] * array.ndim
100
+ pad_widths[axis] = (0, length - array.shape[axis])
101
+ array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
102
+ else:
103
+ if array.shape[axis] > length:
104
+ array = array.take(indices=range(length), axis=axis)
105
+
106
+ if array.shape[axis] < length:
107
+ pad_widths = [(0, 0)] * array.ndim
108
+ pad_widths[axis] = (0, length - array.shape[axis])
109
+ array = np.pad(array, pad_widths)
110
+
111
+ return array
112
+
113
+
114
+ @lru_cache(maxsize=None)
115
+ def mel_filters(device, n_mels: int) -> torch.Tensor:
116
+ """
117
+ load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
118
+ Allows decoupling librosa dependency; saved using:
119
+
120
+ np.savez_compressed(
121
+ "mel_filters.npz",
122
+ mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
123
+ )
124
+ """
125
+ assert n_mels in [80, 128], f"Unsupported n_mels: {n_mels}"
126
+ with np.load(
127
+ os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
128
+ ) as f:
129
+ return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
130
+
131
+
132
+ def log_mel_spectrogram(
133
+ audio: Union[str, np.ndarray, torch.Tensor],
134
+ n_mels: int,
135
+ padding: int = 0,
136
+ device: Optional[Union[str, torch.device]] = None,
137
+ ):
138
+ """
139
+ Compute the log-Mel spectrogram of the given audio.
140
+
141
+ Parameters
142
+ ----------
143
+ audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
144
+ The path to the audio file, or a NumPy array or Tensor containing the audio waveform sampled at 16 kHz
145
+
146
+ n_mels: int
147
+ The number of Mel-frequency filters; only 80 and 128 are supported
148
+
149
+ padding: int
150
+ Number of zero samples to pad to the right
151
+
152
+ device: Optional[Union[str, torch.device]]
153
+ If given, the audio tensor is moved to this device before STFT
154
+
155
+ Returns
156
+ -------
157
+ torch.Tensor, shape = (n_mels, n_frames)
158
+ A Tensor that contains the Mel spectrogram
159
+ """
160
+ if not torch.is_tensor(audio):
161
+ if isinstance(audio, str):
162
+ audio = load_audio(audio)
163
+ audio = torch.from_numpy(audio)
164
+
165
+ if device is not None:
166
+ audio = audio.to(device)
167
+ if padding > 0:
168
+ audio = F.pad(audio, (0, padding))
169
+ window = torch.hann_window(N_FFT).to(audio.device)
170
+ stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
171
+ magnitudes = stft[..., :-1].abs() ** 2
172
+
173
+ filters = mel_filters(audio.device, n_mels)
174
+ mel_spec = filters @ magnitudes
175
+
176
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
177
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
178
+ log_spec = (log_spec + 4.0) / 4.0
179
+ return log_spec
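As a quick orientation for this module, here is a minimal usage sketch. It assumes ffmpeg is on PATH, that a file named sample.wav exists (placeholder name), and that the whisperX mel_filters.npz asset is available under modules/diarize/assets; none of these are shipped by this commit itself.

# Hypothetical usage sketch of modules/diarize/audio_loader (not part of the diff).
from modules.diarize.audio_loader import load_audio, pad_or_trim, log_mel_spectrogram, N_SAMPLES

waveform = load_audio("sample.wav")             # float32 mono waveform resampled to 16 kHz via ffmpeg
waveform = pad_or_trim(waveform, N_SAMPLES)     # pad or trim to one 30-second chunk
mel = log_mel_spectrogram(waveform, n_mels=80)  # requires the assets/mel_filters.npz file
print(mel.shape)                                # (80, 3000) for a 30-second chunk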
modules/diarize/diarize_pipeline.py ADDED
@@ -0,0 +1,94 @@
1
+ # Adapted from https://github.com/m-bain/whisperX/blob/main/whisperx/diarize.py
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import os
6
+ from pyannote.audio import Pipeline
7
+ from typing import Optional, Union
8
+ import torch
9
+
10
+ from modules.diarize.audio_loader import load_audio, SAMPLE_RATE
11
+
12
+
13
+ class DiarizationPipeline:
14
+ def __init__(
15
+ self,
16
+ model_name="pyannote/speaker-diarization-3.1",
17
+ cache_dir: str = os.path.join("models", "Diarization"),
18
+ use_auth_token=None,
19
+ device: Optional[Union[str, torch.device]] = "cpu",
20
+ ):
21
+ if isinstance(device, str):
22
+ device = torch.device(device)
23
+ self.model = Pipeline.from_pretrained(
24
+ model_name,
25
+ use_auth_token=use_auth_token,
26
+ cache_dir=cache_dir
27
+ ).to(device)
28
+
29
+ def __call__(self, audio: Union[str, np.ndarray], min_speakers=None, max_speakers=None):
30
+ if isinstance(audio, str):
31
+ audio = load_audio(audio)
32
+ audio_data = {
33
+ 'waveform': torch.from_numpy(audio[None, :]),
34
+ 'sample_rate': SAMPLE_RATE
35
+ }
36
+ segments = self.model(audio_data, min_speakers=min_speakers, max_speakers=max_speakers)
37
+ diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
38
+ diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start)
39
+ diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end)
40
+ return diarize_df
41
+
42
+
43
+ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
44
+ transcript_segments = transcript_result["segments"]
45
+ for seg in transcript_segments:
46
+ # assign speaker to segment (if any)
47
+ diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'],
48
+ seg['start'])
49
+ diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start'])
50
+
51
+ intersected = diarize_df[diarize_df["intersection"] > 0]
52
+
53
+ speaker = None
54
+ if len(intersected) > 0:
55
+ # Choose the speaker with the strongest total intersection
56
+ speaker = intersected.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
57
+ elif fill_nearest:
58
+ # Otherwise fall back to the closest segment
59
+ speaker = diarize_df.sort_values(by=["intersection"], ascending=False)["speaker"].values[0]
60
+
61
+ if speaker is not None:
62
+ seg["speaker"] = speaker
63
+
64
+ # assign speaker to words
65
+ if 'words' in seg:
66
+ for word in seg['words']:
67
+ if 'start' in word:
68
+ diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(
69
+ diarize_df['start'], word['start'])
70
+ diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'],
71
+ word['start'])
72
+
73
+ intersected = diarize_df[diarize_df["intersection"] > 0]
74
+
75
+ word_speaker = None
76
+ if len(intersected) > 0:
77
+ # Choose the speaker with the strongest total intersection
78
+ word_speaker = \
79
+ intersected.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
80
+ elif fill_nearest:
81
+ # Otherwise fall back to the closest segment
82
+ word_speaker = diarize_df.sort_values(by=["intersection"], ascending=False)["speaker"].values[0]
83
+
84
+ if word_speaker is not None:
85
+ word["speaker"] = word_speaker
86
+
87
+ return transcript_result
88
+
89
+
90
+ class Segment:
91
+ def __init__(self, start, end, speaker=None):
92
+ self.start = start
93
+ self.end = end
94
+ self.speaker = speaker
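A minimal usage sketch for this pipeline; the audio path and token below are placeholders, and the pyannote/speaker-diarization-3.1 terms of use must be accepted before the first download.

# Hypothetical usage sketch of modules/diarize/diarize_pipeline (not part of the diff).
from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers

pipeline = DiarizationPipeline(use_auth_token="hf_xxx", device="cpu")  # placeholder token
diarize_df = pipeline("audio.wav")  # DataFrame with segment, label, speaker, start, end columns
result = {"segments": [{"start": 0.0, "end": 2.5, "text": "Hello there."}]}
result = assign_word_speakers(diarize_df, result)
print(result["segments"][0].get("speaker"))  # e.g. "SPEAKER_00"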
modules/diarize/diarizer.py ADDED
@@ -0,0 +1,132 @@
1
+ import os
2
+ import torch
3
+ from typing import List, Union, BinaryIO, Optional
4
+ import numpy as np
5
+ import time
6
+ import logging
7
+
8
+ from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
9
+ from modules.diarize.audio_loader import load_audio
10
+
11
+
12
+ class Diarizer:
13
+ def __init__(self,
14
+ model_dir: str = os.path.join("models", "Diarization")
15
+ ):
16
+ self.device = self.get_device()
17
+ self.available_device = self.get_available_device()
18
+ self.compute_type = "float16"
19
+ self.model_dir = model_dir
20
+ os.makedirs(self.model_dir, exist_ok=True)
21
+ self.pipe = None
22
+
23
+ def run(self,
24
+ audio: Union[str, BinaryIO, np.ndarray],
25
+ transcribed_result: List[dict],
26
+ use_auth_token: str,
27
+ device: Optional[str] = None
28
+ ):
29
+ """
30
+ Diarize transcribed result as a post-processing
31
+
32
+ Parameters
33
+ ----------
34
+ audio: Union[str, BinaryIO, np.ndarray]
35
+ Audio input. This can be a file path, a binary file, or a numpy array.
36
+ transcribed_result: List[dict]
37
+ Transcribed result from Whisper.
38
+ use_auth_token: str
39
+ Huggingface token with READ permission. This is only needed the first time you download the model.
40
+ You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model.
41
+ device: Optional[str]
42
+ Device for diarization.
43
+
44
+ Returns
45
+ ----------
46
+ segments_result: List[dict]
47
+ list of dicts that includes start, end timestamps and transcribed text
48
+ elapsed_time: float
49
+ elapsed time for running
50
+ """
51
+ start_time = time.time()
52
+
53
+ if device is None:
54
+ device = self.device
55
+
56
+ if device != self.device or self.pipe is None:
57
+ self.update_pipe(
58
+ device=device,
59
+ use_auth_token=use_auth_token
60
+ )
61
+
62
+ audio = load_audio(audio)
63
+
64
+ diarization_segments = self.pipe(audio)
65
+ diarized_result = assign_word_speakers(
66
+ diarization_segments,
67
+ {"segments": transcribed_result}
68
+ )
69
+
70
+ for segment in diarized_result["segments"]:
71
+ speaker = "None"
72
+ if "speaker" in segment:
73
+ speaker = segment["speaker"]
74
+ segment["text"] = speaker + "|" + segment["text"].strip()
75
+
76
+ elapsed_time = time.time() - start_time
77
+ return diarized_result["segments"], elapsed_time
78
+
79
+ def update_pipe(self,
80
+ use_auth_token: str,
81
+ device: str
82
+ ):
83
+ """
84
+ Set pipeline for diarization
85
+
86
+ Parameters
87
+ ----------
88
+ use_auth_token: str
89
+ Huggingface token with READ permission. This is only needed the first time you download the model.
90
+ You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model.
91
+ device: str
92
+ Device for diarization.
93
+ """
94
+ self.device = device
95
+
96
+ os.makedirs(self.model_dir, exist_ok=True)
97
+
98
+ if (not os.listdir(self.model_dir) and
99
+ not use_auth_token):
100
+ print(
101
+ "\nFailed to diarize. You need a Hugging Face token and must agree to the model's terms of use to download the diarization model.\n"
102
+ "Go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and follow their instructions to download the model.\n"
103
+ )
104
+ return
105
+
106
+ logger = logging.getLogger("speechbrain.utils.train_logger")
107
+ # Disable redundant torchvision warning message
108
+ logger.disabled = True
109
+ self.pipe = DiarizationPipeline(
110
+ use_auth_token=use_auth_token,
111
+ device=device,
112
+ cache_dir=self.model_dir
113
+ )
114
+ logger.disabled = False
115
+
116
+ @staticmethod
117
+ def get_device():
118
+ if torch.cuda.is_available():
119
+ return "cuda"
120
+ elif torch.backends.mps.is_available():
121
+ return "mps"
122
+ else:
123
+ return "cpu"
124
+
125
+ @staticmethod
126
+ def get_available_device():
127
+ devices = ["cpu"]
128
+ if torch.cuda.is_available():
129
+ devices.append("cuda")
130
+ elif torch.backends.mps.is_available():
131
+ devices.append("mps")
132
+ return devices
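A minimal usage sketch of the Diarizer wrapper; the audio path and token are placeholders, and Whisper-style segments are passed in as plain dicts.

# Hypothetical usage sketch of modules/diarize/diarizer (not part of the diff).
from modules.diarize.diarizer import Diarizer

diarizer = Diarizer()
segments = [{"start": 0.0, "end": 2.0, "text": "Hello there."}]
diarized, elapsed = diarizer.run(
    audio="audio.wav",           # placeholder path
    transcribed_result=segments,
    use_auth_token="hf_xxx",     # placeholder Hugging Face token
)
print(diarized[0]["text"])       # e.g. "SPEAKER_00|Hello there."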
modules/translation/__init__.py ADDED
File without changes
modules/translation/deepl_api.py ADDED
@@ -0,0 +1,201 @@
1
+ import requests
2
+ import time
3
+ import os
4
+ from datetime import datetime
5
+ import gradio as gr
6
+
7
+ from modules.utils.subtitle_manager import *
8
+
9
+ """
10
+ This is written with reference to the DeepL API documentation.
11
+ For more information about the DeepL API, see: https://www.deepl.com/docs-api/documents
12
+ """
13
+
14
+ DEEPL_AVAILABLE_TARGET_LANGS = {
15
+ 'Bulgarian': 'BG',
16
+ 'Czech': 'CS',
17
+ 'Danish': 'DA',
18
+ 'German': 'DE',
19
+ 'Greek': 'EL',
20
+ 'English': 'EN',
21
+ 'English (British)': 'EN-GB',
22
+ 'English (American)': 'EN-US',
23
+ 'Spanish': 'ES',
24
+ 'Estonian': 'ET',
25
+ 'Finnish': 'FI',
26
+ 'French': 'FR',
27
+ 'Hungarian': 'HU',
28
+ 'Indonesian': 'ID',
29
+ 'Italian': 'IT',
30
+ 'Japanese': 'JA',
31
+ 'Korean': 'KO',
32
+ 'Lithuanian': 'LT',
33
+ 'Latvian': 'LV',
34
+ 'Norwegian (Bokmål)': 'NB',
35
+ 'Dutch': 'NL',
36
+ 'Polish': 'PL',
37
+ 'Portuguese': 'PT',
38
+ 'Portuguese (Brazilian)': 'PT-BR',
39
+ 'Portuguese (all Portuguese varieties excluding Brazilian Portuguese)': 'PT-PT',
40
+ 'Romanian': 'RO',
41
+ 'Russian': 'RU',
42
+ 'Slovak': 'SK',
43
+ 'Slovenian': 'SL',
44
+ 'Swedish': 'SV',
45
+ 'Turkish': 'TR',
46
+ 'Ukrainian': 'UK',
47
+ 'Chinese (simplified)': 'ZH'
48
+ }
49
+
50
+ DEEPL_AVAILABLE_SOURCE_LANGS = {
51
+ 'Automatic Detection': None,
52
+ 'Bulgarian': 'BG',
53
+ 'Czech': 'CS',
54
+ 'Danish': 'DA',
55
+ 'German': 'DE',
56
+ 'Greek': 'EL',
57
+ 'English': 'EN',
58
+ 'Spanish': 'ES',
59
+ 'Estonian': 'ET',
60
+ 'Finnish': 'FI',
61
+ 'French': 'FR',
62
+ 'Hungarian': 'HU',
63
+ 'Indonesian': 'ID',
64
+ 'Italian': 'IT',
65
+ 'Japanese': 'JA',
66
+ 'Korean': 'KO',
67
+ 'Lithuanian': 'LT',
68
+ 'Latvian': 'LV',
69
+ 'Norwegian (Bokmål)': 'NB',
70
+ 'Dutch': 'NL',
71
+ 'Polish': 'PL',
72
+ 'Portuguese (all Portuguese varieties mixed)': 'PT',
73
+ 'Romanian': 'RO',
74
+ 'Russian': 'RU',
75
+ 'Slovak': 'SK',
76
+ 'Slovenian': 'SL',
77
+ 'Swedish': 'SV',
78
+ 'Turkish': 'TR',
79
+ 'Ukrainian': 'UK',
80
+ 'Chinese': 'ZH'
81
+ }
82
+
83
+
84
+ class DeepLAPI:
85
+ def __init__(self,
86
+ output_dir: str = os.path.join("outputs", "translations")
87
+ ):
88
+ self.api_interval = 1
89
+ self.max_text_batch_size = 50
90
+ self.available_target_langs = DEEPL_AVAILABLE_TARGET_LANGS
91
+ self.available_source_langs = DEEPL_AVAILABLE_SOURCE_LANGS
92
+ self.output_dir = output_dir
93
+
94
+ def translate_deepl(self,
95
+ auth_key: str,
96
+ fileobjs: list,
97
+ source_lang: str,
98
+ target_lang: str,
99
+ is_pro: bool,
100
+ add_timestamp: bool,
101
+ progress=gr.Progress()) -> list:
102
+ """
103
+ Translate subtitle files using DeepL API
104
+ Parameters
105
+ ----------
106
+ auth_key: str
107
+ API Key for DeepL from gr.Textbox()
108
+ fileobjs: list
109
+ List of subtitle files to translate from gr.Files()
110
+ source_lang: str
111
+ Source language of the files to translate from gr.Dropdown()
112
+ target_lang: str
113
+ Target language of the files to translate from gr.Dropdown()
114
+ is_pro: bool
115
+ Boolean value from gr.Checkbox() indicating whether a DeepL Pro plan is used.
116
+ add_timestamp: bool
117
+ Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
118
+ progress: gr.Progress
119
+ Indicator to show progress directly in gradio.
120
+
121
+ Returns
122
+ ----------
123
+ A List of
124
+ String to return to gr.Textbox()
125
+ Files to return to gr.Files()
126
+ """
127
+
128
+ files_info = {}
129
+ for fileobj in fileobjs:
130
+ file_path = fileobj.name
131
+ file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
132
+
133
+ if file_ext == ".srt":
134
+ parsed_dicts = parse_srt(file_path=file_path)
135
+
136
+ batch_size = self.max_text_batch_size
137
+ for batch_start in range(0, len(parsed_dicts), batch_size):
138
+ batch_end = min(batch_start + batch_size, len(parsed_dicts))
139
+ sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
140
+ translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
141
+ target_lang, is_pro)
142
+ for i, translated_text in enumerate(translated_texts):
143
+ parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
144
+ progress(batch_end / len(parsed_dicts), desc="Translating..")
145
+
146
+ subtitle = get_serialized_srt(parsed_dicts)
147
+
148
+ elif file_ext == ".vtt":
149
+ parsed_dicts = parse_vtt(file_path=file_path)
150
+
151
+ batch_size = self.max_text_batch_size
152
+ for batch_start in range(0, len(parsed_dicts), batch_size):
153
+ batch_end = min(batch_start + batch_size, len(parsed_dicts))
154
+ sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
155
+ translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
156
+ target_lang, is_pro)
157
+ for i, translated_text in enumerate(translated_texts):
158
+ parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
159
+ progress(batch_end / len(parsed_dicts), desc="Translating..")
160
+
161
+ subtitle = get_serialized_vtt(parsed_dicts)
162
+
163
+ if add_timestamp:
164
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
165
+ file_name += f"-{timestamp}"
166
+
167
+ output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
168
+ write_file(subtitle, output_path)
169
+
170
+ files_info[file_name] = {"subtitle": subtitle, "path": output_path}
171
+
172
+ total_result = ''
173
+ for file_name, info in files_info.items():
174
+ total_result += '------------------------------------\n'
175
+ total_result += f'{file_name}\n\n'
176
+ total_result += f'{info["subtitle"]}'
177
+ gr_str = f"Done! Subtitle is in the outputs/translations folder.\n\n{total_result}"
178
+
179
+ output_file_paths = [item["path"] for key, item in files_info.items()]
180
+ return [gr_str, output_file_paths]
181
+
182
+ def request_deepl_translate(self,
183
+ auth_key: str,
184
+ text: list,
185
+ source_lang: str,
186
+ target_lang: str,
187
+ is_pro: bool):
188
+ """Request API response to DeepL server"""
189
+
190
+ url = 'https://api.deepl.com/v2/translate' if is_pro else 'https://api-free.deepl.com/v2/translate'
191
+ headers = {
192
+ 'Authorization': f'DeepL-Auth-Key {auth_key}'
193
+ }
194
+ data = {
195
+ 'text': text,
196
+ 'source_lang': DEEPL_AVAILABLE_SOURCE_LANGS[source_lang],
197
+ 'target_lang': DEEPL_AVAILABLE_TARGET_LANGS[target_lang]
198
+ }
199
+ response = requests.post(url, headers=headers, data=data).json()
200
+ time.sleep(self.api_interval)
201
+ return response["translations"]
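A minimal usage sketch of the DeepL helper; the API key is a placeholder, and a free-tier key works with is_pro=False.

# Hypothetical usage sketch of modules/translation/deepl_api (not part of the diff).
from modules.translation.deepl_api import DeepLAPI

deepl = DeepLAPI()
translations = deepl.request_deepl_translate(
    auth_key="your-deepl-api-key",   # placeholder key
    text=["Hello", "How are you?"],
    source_lang="English",
    target_lang="Korean",
    is_pro=False,
)
print([t["text"] for t in translations])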
modules/translation/nllb_inference.py ADDED
@@ -0,0 +1,276 @@
1
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
2
+ import gradio as gr
3
+ import os
4
+
5
+ from modules.translation.translation_base import TranslationBase
6
+
7
+
8
+ class NLLBInference(TranslationBase):
9
+ def __init__(self,
10
+ model_dir: str = os.path.join("models", "NLLB"),
11
+ output_dir: str = os.path.join("outputs", "translations")
12
+ ):
13
+ super().__init__(
14
+ model_dir=model_dir,
15
+ output_dir=output_dir
16
+ )
17
+ self.tokenizer = None
18
+ self.available_models = ["facebook/nllb-200-3.3B", "facebook/nllb-200-1.3B", "facebook/nllb-200-distilled-600M"]
19
+ self.available_source_langs = list(NLLB_AVAILABLE_LANGS.keys())
20
+ self.available_target_langs = list(NLLB_AVAILABLE_LANGS.keys())
21
+ self.pipeline = None
22
+
23
+ def translate(self,
24
+ text: str,
25
+ max_length: int
26
+ ):
27
+ result = self.pipeline(
28
+ text,
29
+ max_length=max_length
30
+ )
31
+ return result[0]['translation_text']
32
+
33
+ def update_model(self,
34
+ model_size: str,
35
+ src_lang: str,
36
+ tgt_lang: str,
37
+ progress: gr.Progress
38
+ ):
39
+ if model_size != self.current_model_size or self.model is None:
40
+ print("\nInitializing NLLB Model..\n")
41
+ progress(0, desc="Initializing NLLB Model..")
42
+ self.current_model_size = model_size
43
+ local_files_only = self.is_model_exists(self.current_model_size)
44
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
45
+ cache_dir=self.model_dir,
46
+ local_files_only=local_files_only)
47
+ self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
48
+ cache_dir=os.path.join(self.model_dir, "tokenizers"),
49
+ local_files_only=local_files_only)
50
+ src_lang = NLLB_AVAILABLE_LANGS[src_lang]
51
+ tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
52
+ self.pipeline = pipeline("translation",
53
+ model=self.model,
54
+ tokenizer=self.tokenizer,
55
+ src_lang=src_lang,
56
+ tgt_lang=tgt_lang,
57
+ device=self.device)
58
+
59
+ def is_model_exists(self,
60
+ model_size: str):
61
+ """Check whether the model already exists locally (facebook models only)"""
62
+ prefix = "models--facebook--"
63
+ _id, model_size_name = model_size.split("/")
64
+ model_dir_name = prefix + model_size_name
65
+ model_dir_path = os.path.join(self.model_dir, model_dir_name)
66
+ if os.path.exists(model_dir_path) and os.listdir(model_dir_path):
67
+ return True
68
+ return False
69
+
70
+
71
+ NLLB_AVAILABLE_LANGS = {
72
+ "Acehnese (Arabic script)": "ace_Arab",
73
+ "Acehnese (Latin script)": "ace_Latn",
74
+ "Mesopotamian Arabic": "acm_Arab",
75
+ "Ta’izzi-Adeni Arabic": "acq_Arab",
76
+ "Tunisian Arabic": "aeb_Arab",
77
+ "Afrikaans": "afr_Latn",
78
+ "South Levantine Arabic": "ajp_Arab",
79
+ "Akan": "aka_Latn",
80
+ "Amharic": "amh_Ethi",
81
+ "North Levantine Arabic": "apc_Arab",
82
+ "Modern Standard Arabic": "arb_Arab",
83
+ "Modern Standard Arabic (Romanized)": "arb_Latn",
84
+ "Najdi Arabic": "ars_Arab",
85
+ "Moroccan Arabic": "ary_Arab",
86
+ "Egyptian Arabic": "arz_Arab",
87
+ "Assamese": "asm_Beng",
88
+ "Asturian": "ast_Latn",
89
+ "Awadhi": "awa_Deva",
90
+ "Central Aymara": "ayr_Latn",
91
+ "South Azerbaijani": "azb_Arab",
92
+ "North Azerbaijani": "azj_Latn",
93
+ "Bashkir": "bak_Cyrl",
94
+ "Bambara": "bam_Latn",
95
+ "Balinese": "ban_Latn",
96
+ "Belarusian": "bel_Cyrl",
97
+ "Bemba": "bem_Latn",
98
+ "Bengali": "ben_Beng",
99
+ "Bhojpuri": "bho_Deva",
100
+ "Banjar (Arabic script)": "bjn_Arab",
101
+ "Banjar (Latin script)": "bjn_Latn",
102
+ "Standard Tibetan": "bod_Tibt",
103
+ "Bosnian": "bos_Latn",
104
+ "Buginese": "bug_Latn",
105
+ "Bulgarian": "bul_Cyrl",
106
+ "Catalan": "cat_Latn",
107
+ "Cebuano": "ceb_Latn",
108
+ "Czech": "ces_Latn",
109
+ "Chokwe": "cjk_Latn",
110
+ "Central Kurdish": "ckb_Arab",
111
+ "Crimean Tatar": "crh_Latn",
112
+ "Welsh": "cym_Latn",
113
+ "Danish": "dan_Latn",
114
+ "German": "deu_Latn",
115
+ "Southwestern Dinka": "dik_Latn",
116
+ "Dyula": "dyu_Latn",
117
+ "Dzongkha": "dzo_Tibt",
118
+ "Greek": "ell_Grek",
119
+ "English": "eng_Latn",
120
+ "Esperanto": "epo_Latn",
121
+ "Estonian": "est_Latn",
122
+ "Basque": "eus_Latn",
123
+ "Ewe": "ewe_Latn",
124
+ "Faroese": "fao_Latn",
125
+ "Fijian": "fij_Latn",
126
+ "Finnish": "fin_Latn",
127
+ "Fon": "fon_Latn",
128
+ "French": "fra_Latn",
129
+ "Friulian": "fur_Latn",
130
+ "Nigerian Fulfulde": "fuv_Latn",
131
+ "Scottish Gaelic": "gla_Latn",
132
+ "Irish": "gle_Latn",
133
+ "Galician": "glg_Latn",
134
+ "Guarani": "grn_Latn",
135
+ "Gujarati": "guj_Gujr",
136
+ "Haitian Creole": "hat_Latn",
137
+ "Hausa": "hau_Latn",
138
+ "Hebrew": "heb_Hebr",
139
+ "Hindi": "hin_Deva",
140
+ "Chhattisgarhi": "hne_Deva",
141
+ "Croatian": "hrv_Latn",
142
+ "Hungarian": "hun_Latn",
143
+ "Armenian": "hye_Armn",
144
+ "Igbo": "ibo_Latn",
145
+ "Ilocano": "ilo_Latn",
146
+ "Indonesian": "ind_Latn",
147
+ "Icelandic": "isl_Latn",
148
+ "Italian": "ita_Latn",
149
+ "Javanese": "jav_Latn",
150
+ "Japanese": "jpn_Jpan",
151
+ "Kabyle": "kab_Latn",
152
+ "Jingpho": "kac_Latn",
153
+ "Kamba": "kam_Latn",
154
+ "Kannada": "kan_Knda",
155
+ "Kashmiri (Arabic script)": "kas_Arab",
156
+ "Kashmiri (Devanagari script)": "kas_Deva",
157
+ "Georgian": "kat_Geor",
158
+ "Central Kanuri (Arabic script)": "knc_Arab",
159
+ "Central Kanuri (Latin script)": "knc_Latn",
160
+ "Kazakh": "kaz_Cyrl",
161
+ "Kabiyè": "kbp_Latn",
162
+ "Kabuverdianu": "kea_Latn",
163
+ "Khmer": "khm_Khmr",
164
+ "Kikuyu": "kik_Latn",
165
+ "Kinyarwanda": "kin_Latn",
166
+ "Kyrgyz": "kir_Cyrl",
167
+ "Kimbundu": "kmb_Latn",
168
+ "Northern Kurdish": "kmr_Latn",
169
+ "Kikongo": "kon_Latn",
170
+ "Korean": "kor_Hang",
171
+ "Lao": "lao_Laoo",
172
+ "Ligurian": "lij_Latn",
173
+ "Limburgish": "lim_Latn",
174
+ "Lingala": "lin_Latn",
175
+ "Lithuanian": "lit_Latn",
176
+ "Lombard": "lmo_Latn",
177
+ "Latgalian": "ltg_Latn",
178
+ "Luxembourgish": "ltz_Latn",
179
+ "Luba-Kasai": "lua_Latn",
180
+ "Ganda": "lug_Latn",
181
+ "Luo": "luo_Latn",
182
+ "Mizo": "lus_Latn",
183
+ "Standard Latvian": "lvs_Latn",
184
+ "Magahi": "mag_Deva",
185
+ "Maithili": "mai_Deva",
186
+ "Malayalam": "mal_Mlym",
187
+ "Marathi": "mar_Deva",
188
+ "Minangkabau (Arabic script)": "min_Arab",
189
+ "Minangkabau (Latin script)": "min_Latn",
190
+ "Macedonian": "mkd_Cyrl",
191
+ "Plateau Malagasy": "plt_Latn",
192
+ "Maltese": "mlt_Latn",
193
+ "Meitei (Bengali script)": "mni_Beng",
194
+ "Halh Mongolian": "khk_Cyrl",
195
+ "Mossi": "mos_Latn",
196
+ "Maori": "mri_Latn",
197
+ "Burmese": "mya_Mymr",
198
+ "Dutch": "nld_Latn",
199
+ "Norwegian Nynorsk": "nno_Latn",
200
+ "Norwegian Bokmål": "nob_Latn",
201
+ "Nepali": "npi_Deva",
202
+ "Northern Sotho": "nso_Latn",
203
+ "Nuer": "nus_Latn",
204
+ "Nyanja": "nya_Latn",
205
+ "Occitan": "oci_Latn",
206
+ "West Central Oromo": "gaz_Latn",
207
+ "Odia": "ory_Orya",
208
+ "Pangasinan": "pag_Latn",
209
+ "Eastern Panjabi": "pan_Guru",
210
+ "Papiamento": "pap_Latn",
211
+ "Western Persian": "pes_Arab",
212
+ "Polish": "pol_Latn",
213
+ "Portuguese": "por_Latn",
214
+ "Dari": "prs_Arab",
215
+ "Southern Pashto": "pbt_Arab",
216
+ "Ayacucho Quechua": "quy_Latn",
217
+ "Romanian": "ron_Latn",
218
+ "Rundi": "run_Latn",
219
+ "Russian": "rus_Cyrl",
220
+ "Sango": "sag_Latn",
221
+ "Sanskrit": "san_Deva",
222
+ "Santali": "sat_Olck",
223
+ "Sicilian": "scn_Latn",
224
+ "Shan": "shn_Mymr",
225
+ "Sinhala": "sin_Sinh",
226
+ "Slovak": "slk_Latn",
227
+ "Slovenian": "slv_Latn",
228
+ "Samoan": "smo_Latn",
229
+ "Shona": "sna_Latn",
230
+ "Sindhi": "snd_Arab",
231
+ "Somali": "som_Latn",
232
+ "Southern Sotho": "sot_Latn",
233
+ "Spanish": "spa_Latn",
234
+ "Tosk Albanian": "als_Latn",
235
+ "Sardinian": "srd_Latn",
236
+ "Serbian": "srp_Cyrl",
237
+ "Swati": "ssw_Latn",
238
+ "Sundanese": "sun_Latn",
239
+ "Swedish": "swe_Latn",
240
+ "Swahili": "swh_Latn",
241
+ "Silesian": "szl_Latn",
242
+ "Tamil": "tam_Taml",
243
+ "Tatar": "tat_Cyrl",
244
+ "Telugu": "tel_Telu",
245
+ "Tajik": "tgk_Cyrl",
246
+ "Tagalog": "tgl_Latn",
247
+ "Thai": "tha_Thai",
248
+ "Tigrinya": "tir_Ethi",
249
+ "Tamasheq (Latin script)": "taq_Latn",
250
+ "Tamasheq (Tifinagh script)": "taq_Tfng",
251
+ "Tok Pisin": "tpi_Latn",
252
+ "Tswana": "tsn_Latn",
253
+ "Tsonga": "tso_Latn",
254
+ "Turkmen": "tuk_Latn",
255
+ "Tumbuka": "tum_Latn",
256
+ "Turkish": "tur_Latn",
257
+ "Twi": "twi_Latn",
258
+ "Central Atlas Tamazight": "tzm_Tfng",
259
+ "Uyghur": "uig_Arab",
260
+ "Ukrainian": "ukr_Cyrl",
261
+ "Umbundu": "umb_Latn",
262
+ "Urdu": "urd_Arab",
263
+ "Northern Uzbek": "uzn_Latn",
264
+ "Venetian": "vec_Latn",
265
+ "Vietnamese": "vie_Latn",
266
+ "Waray": "war_Latn",
267
+ "Wolof": "wol_Latn",
268
+ "Xhosa": "xho_Latn",
269
+ "Eastern Yiddish": "ydd_Hebr",
270
+ "Yoruba": "yor_Latn",
271
+ "Yue Chinese": "yue_Hant",
272
+ "Chinese (Simplified)": "zho_Hans",
273
+ "Chinese (Traditional)": "zho_Hant",
274
+ "Standard Malay": "zsm_Latn",
275
+ "Zulu": "zul_Latn",
276
+ }
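A minimal usage sketch of the NLLB wrapper; subtitle.srt is a placeholder path, the model is downloaded on first use, and the default gr.Progress() mirrors how the Gradio handlers call this method.

# Hypothetical usage sketch of modules/translation/nllb_inference (not part of the diff).
from modules.translation.nllb_inference import NLLBInference
from modules.utils.files_manager import format_gradio_files

nllb = NLLBInference()
result_text, output_paths = nllb.translate_file(
    fileobjs=format_gradio_files(["subtitle.srt"]),  # placeholder subtitle file
    model_size="facebook/nllb-200-distilled-600M",
    src_lang="English",
    tgt_lang="Korean",
    max_length=200,
    add_timestamp=True,
)
print(output_paths)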
modules/translation/translation_base.py ADDED
@@ -0,0 +1,151 @@
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ from abc import ABC, abstractmethod
5
+ from typing import List
6
+ from datetime import datetime
7
+
8
+ from modules.whisper.whisper_parameter import *
9
+ from modules.utils.subtitle_manager import *
10
+
11
+
12
+ class TranslationBase(ABC):
13
+ def __init__(self,
14
+ model_dir: str = os.path.join("models", "NLLB"),
15
+ output_dir: str = os.path.join("outputs", "translations")
16
+ ):
17
+ super().__init__()
18
+ self.model = None
19
+ self.model_dir = model_dir
20
+ self.output_dir = output_dir
21
+ os.makedirs(self.model_dir, exist_ok=True)
22
+ os.makedirs(self.output_dir, exist_ok=True)
23
+ self.current_model_size = None
24
+ self.device = self.get_device()
25
+
26
+ @abstractmethod
27
+ def translate(self,
28
+ text: str,
29
+ max_length: int
30
+ ):
31
+ pass
32
+
33
+ @abstractmethod
34
+ def update_model(self,
35
+ model_size: str,
36
+ src_lang: str,
37
+ tgt_lang: str,
38
+ progress: gr.Progress
39
+ ):
40
+ pass
41
+
42
+ def translate_file(self,
43
+ fileobjs: list,
44
+ model_size: str,
45
+ src_lang: str,
46
+ tgt_lang: str,
47
+ max_length: int,
48
+ add_timestamp: bool,
49
+ progress=gr.Progress()) -> list:
50
+ """
51
+ Translate subtitle file from source language to target language
52
+
53
+ Parameters
54
+ ----------
55
+ fileobjs: list
56
+ List of subtitle files to translate from gr.Files()
57
+ model_size: str
58
+ Translation model size from gr.Dropdown()
59
+ src_lang: str
60
+ Source language of the file to translate from gr.Dropdown()
61
+ tgt_lang: str
62
+ Target language of the file to translate from gr.Dropdown()
63
+ max_length: int
64
+ Max length per line to translate
65
+ add_timestamp: bool
66
+ Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
67
+ progress: gr.Progress
68
+ Indicator to show progress directly in gradio.
69
+ I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
70
+
71
+ Returns
72
+ ----------
73
+ A List of
74
+ String to return to gr.Textbox()
75
+ Files to return to gr.Files()
76
+ """
77
+ try:
78
+ self.update_model(model_size=model_size,
79
+ src_lang=src_lang,
80
+ tgt_lang=tgt_lang,
81
+ progress=progress)
82
+
83
+ files_info = {}
84
+ for fileobj in fileobjs:
85
+ file_path = fileobj.name
86
+ file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
87
+ if file_ext == ".srt":
88
+ parsed_dicts = parse_srt(file_path=file_path)
89
+ total_progress = len(parsed_dicts)
90
+ for index, dic in enumerate(parsed_dicts):
91
+ progress(index / total_progress, desc="Translating..")
92
+ translated_text = self.translate(dic["sentence"], max_length=max_length)
93
+ dic["sentence"] = translated_text
94
+ subtitle = get_serialized_srt(parsed_dicts)
95
+
96
+ elif file_ext == ".vtt":
97
+ parsed_dicts = parse_vtt(file_path=file_path)
98
+ total_progress = len(parsed_dicts)
99
+ for index, dic in enumerate(parsed_dicts):
100
+ progress(index / total_progress, desc="Translating..")
101
+ translated_text = self.translate(dic["sentence"], max_length=max_length)
102
+ dic["sentence"] = translated_text
103
+ subtitle = get_serialized_vtt(parsed_dicts)
104
+
105
+ if add_timestamp:
106
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
107
+ file_name += f"-{timestamp}"
108
+
109
+ output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
110
+ write_file(subtitle, output_path)
111
+
112
+ files_info[file_name] = {"subtitle": subtitle, "path": output_path}
113
+
114
+ total_result = ''
115
+ for file_name, info in files_info.items():
116
+ total_result += '------------------------------------\n'
117
+ total_result += f'{file_name}\n\n'
118
+ total_result += f'{info["subtitle"]}'
119
+ gr_str = f"Done! Subtitle is in the outputs/translations folder.\n\n{total_result}"
120
+
121
+ output_file_paths = [item["path"] for key, item in files_info.items()]
122
+ return [gr_str, output_file_paths]
123
+
124
+ except Exception as e:
125
+ print(f"Error: {str(e)}")
126
+ finally:
127
+ self.release_cuda_memory()
128
+
129
+ @staticmethod
130
+ def get_device():
131
+ if torch.cuda.is_available():
132
+ return "cuda"
133
+ elif torch.backends.mps.is_available():
134
+ return "mps"
135
+ else:
136
+ return "cpu"
137
+
138
+ @staticmethod
139
+ def release_cuda_memory():
140
+ if torch.cuda.is_available():
141
+ torch.cuda.empty_cache()
142
+ torch.cuda.reset_max_memory_allocated()
143
+
144
+ @staticmethod
145
+ def remove_input_files(file_paths: List[str]):
146
+ if not file_paths:
147
+ return
148
+
149
+ for file_path in file_paths:
150
+ if file_path and os.path.exists(file_path):
151
+ os.remove(file_path)
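Since TranslationBase is abstract, a new backend only needs the two hooks shown below; this echo backend is purely hypothetical and exists to illustrate the contract.

# Hypothetical TranslationBase subclass (not part of the diff).
import gradio as gr

from modules.translation.translation_base import TranslationBase


class EchoTranslation(TranslationBase):
    def update_model(self, model_size: str, src_lang: str, tgt_lang: str, progress: gr.Progress):
        self.current_model_size = model_size  # nothing to load for this dummy backend

    def translate(self, text: str, max_length: int):
        return text[:max_length]  # placeholder "translation" that just truncates the input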
modules/utils/__init__.py ADDED
File without changes
modules/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (173 Bytes).
 
modules/utils/__pycache__/files_manager.cpython-310.pyc ADDED
Binary file (1.43 kB).
 
modules/utils/__pycache__/subtitle_manager.cpython-310.pyc ADDED
Binary file (3.38 kB).
 
modules/utils/__pycache__/youtube_manager.cpython-310.pyc ADDED
Binary file (748 Bytes).
 
modules/utils/files_manager.py ADDED
@@ -0,0 +1,39 @@
1
+ import os
2
+ import fnmatch
3
+
4
+ from gradio.utils import NamedString
5
+
6
+
7
+ def get_media_files(folder_path, include_sub_directory=False):
8
+ video_extensions = ['*.mp4', '*.mkv', '*.flv', '*.avi', '*.mov', '*.wmv']
9
+ audio_extensions = ['*.mp3', '*.wav', '*.aac', '*.flac', '*.ogg', '*.m4a']
10
+ media_extensions = video_extensions + audio_extensions
11
+
12
+ media_files = []
13
+
14
+ if include_sub_directory:
15
+ for root, _, files in os.walk(folder_path):
16
+ for extension in media_extensions:
17
+ media_files.extend(
18
+ os.path.join(root, file) for file in fnmatch.filter(files, extension)
19
+ if os.path.exists(os.path.join(root, file))
20
+ )
21
+ else:
22
+ for extension in media_extensions:
23
+ media_files.extend(
24
+ os.path.join(folder_path, file) for file in fnmatch.filter(os.listdir(folder_path), extension)
25
+ if os.path.isfile(os.path.join(folder_path, file)) and os.path.exists(os.path.join(folder_path, file))
26
+ )
27
+
28
+ return media_files
29
+
30
+
31
+ def format_gradio_files(files: list):
32
+ if not files:
33
+ return files
34
+
35
+ gradio_files = []
36
+ for file in files:
37
+ gradio_files.append(NamedString(file))
38
+ return gradio_files
39
+
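A short usage sketch; "inputs" is a placeholder folder name.

# Hypothetical usage sketch of modules/utils/files_manager (not part of the diff).
from modules.utils.files_manager import get_media_files, format_gradio_files

media = get_media_files("inputs", include_sub_directory=True)  # placeholder folder
gradio_files = format_gradio_files(media)  # wrap the paths so gr.Files() components accept them
print(len(gradio_files), "media files found")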
modules/utils/subtitle_manager.py ADDED
@@ -0,0 +1,135 @@
1
+ import re
2
+
3
+
4
+ def timeformat_srt(time):
5
+ hours = time // 3600
6
+ minutes = (time - hours * 3600) // 60
7
+ seconds = time - hours * 3600 - minutes * 60
8
+ milliseconds = (time - int(time)) * 1000
9
+ return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}"
10
+
11
+
12
+ def timeformat_vtt(time):
13
+ hours = time // 3600
14
+ minutes = (time - hours * 3600) // 60
15
+ seconds = time - hours * 3600 - minutes * 60
16
+ milliseconds = (time - int(time)) * 1000
17
+ return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}"
18
+
19
+
20
+ def write_file(subtitle, output_file):
21
+ with open(output_file, 'w', encoding='utf-8') as f:
22
+ f.write(subtitle)
23
+
24
+
25
+ def get_srt(segments):
26
+ output = ""
27
+ for i, segment in enumerate(segments):
28
+ output += f"{i + 1}\n"
29
+ output += f"{timeformat_srt(segment['start'])} --> {timeformat_srt(segment['end'])}\n"
30
+ if segment['text'].startswith(' '):
31
+ segment['text'] = segment['text'][1:]
32
+ output += f"{segment['text']}\n\n"
33
+ return output
34
+
35
+
36
+ def get_vtt(segments):
37
+ output = "WebVTT\n\n"
38
+ for i, segment in enumerate(segments):
39
+ output += f"{i + 1}\n"
40
+ output += f"{timeformat_vtt(segment['start'])} --> {timeformat_vtt(segment['end'])}\n"
41
+ if segment['text'].startswith(' '):
42
+ segment['text'] = segment['text'][1:]
43
+ output += f"{segment['text']}\n\n"
44
+ return output
45
+
46
+
47
+ def get_txt(segments):
48
+ output = ""
49
+ for i, segment in enumerate(segments):
50
+ if segment['text'].startswith(' '):
51
+ segment['text'] = segment['text'][1:]
52
+ output += f"{segment['text']}\n"
53
+ return output
54
+
55
+
56
+ def parse_srt(file_path):
57
+ """Reads an SRT file and returns its entries as a list of dicts"""
58
+ with open(file_path, 'r', encoding='utf-8') as file:
59
+ srt_data = file.read()
60
+
61
+ data = []
62
+ blocks = srt_data.split('\n\n')
63
+
64
+ for block in blocks:
65
+ if block.strip() != '':
66
+ lines = block.strip().split('\n')
67
+ index = lines[0]
68
+ timestamp = lines[1]
69
+ sentence = ' '.join(lines[2:])
70
+
71
+ data.append({
72
+ "index": index,
73
+ "timestamp": timestamp,
74
+ "sentence": sentence
75
+ })
76
+ return data
77
+
78
+
79
+ def parse_vtt(file_path):
80
+ """Reads a WebVTT file and returns its entries as a list of dicts"""
81
+ with open(file_path, 'r', encoding='utf-8') as file:
82
+ webvtt_data = file.read()
83
+
84
+ data = []
85
+ blocks = webvtt_data.split('\n\n')
86
+
87
+ for block in blocks:
88
+ if block.strip() != '' and not block.strip().startswith("WebVTT"):
89
+ lines = block.strip().split('\n')
90
+ index = lines[0]
91
+ timestamp = lines[1]
92
+ sentence = ' '.join(lines[2:])
93
+
94
+ data.append({
95
+ "index": index,
96
+ "timestamp": timestamp,
97
+ "sentence": sentence
98
+ })
99
+
100
+ return data
101
+
102
+
103
+ def get_serialized_srt(dicts):
104
+ output = ""
105
+ for dic in dicts:
106
+ output += f'{dic["index"]}\n'
107
+ output += f'{dic["timestamp"]}\n'
108
+ output += f'{dic["sentence"]}\n\n'
109
+ return output
110
+
111
+
112
+ def get_serialized_vtt(dicts):
113
+ output = "WebVTT\n\n"
114
+ for dic in dicts:
115
+ output += f'{dic["index"]}\n'
116
+ output += f'{dic["timestamp"]}\n'
117
+ output += f'{dic["sentence"]}\n\n'
118
+ return output
119
+
120
+
121
+ def safe_filename(name):
122
+ from app import _args
123
+ INVALID_FILENAME_CHARS = r'[<>:"/\\|?*\x00-\x1f]'
124
+ safe_name = re.sub(INVALID_FILENAME_CHARS, '_', name)
125
+ if not _args.colab:
126
+ return safe_name
127
+ # Truncate the filename if it exceeds the max_length (20)
128
+ if len(safe_name) > 20:
129
+ file_extension = safe_name.split('.')[-1]
130
+ if len(file_extension) + 1 < 20:
131
+ truncated_name = safe_name[:20 - len(file_extension) - 1]
132
+ safe_name = truncated_name + '.' + file_extension
133
+ else:
134
+ safe_name = safe_name[:20]
135
+ return safe_name
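A short round-trip sketch for the SRT helpers; input.srt and output.srt are placeholder paths.

# Hypothetical usage sketch of modules/utils/subtitle_manager (not part of the diff).
from modules.utils.subtitle_manager import parse_srt, get_serialized_srt, write_file

entries = parse_srt("input.srt")            # list of {"index", "timestamp", "sentence"} dicts
for entry in entries:
    entry["sentence"] = entry["sentence"].upper()  # any per-line transformation
write_file(get_serialized_srt(entries), "output.srt")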
modules/utils/youtube_manager.py ADDED
@@ -0,0 +1,15 @@
1
+ from pytubefix import YouTube
2
+ import os
3
+
4
+
5
+ def get_ytdata(link):
6
+ return YouTube(link)
7
+
8
+
9
+ def get_ytmetas(link):
10
+ yt = YouTube(link)
11
+ return yt.thumbnail_url, yt.title, yt.description
12
+
13
+
14
+ def get_ytaudio(ytdata: YouTube):
15
+ return ytdata.streams.get_audio_only().download(filename=os.path.join("modules", "yt_tmp.wav"))
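A short usage sketch; the video URL is a placeholder.

# Hypothetical usage sketch of modules/utils/youtube_manager (not part of the diff).
from modules.utils.youtube_manager import get_ytdata, get_ytmetas, get_ytaudio

url = "https://www.youtube.com/watch?v=XXXXXXXXXXX"  # placeholder URL
thumbnail, title, description = get_ytmetas(url)
audio_path = get_ytaudio(get_ytdata(url))  # downloads the audio-only stream to modules/yt_tmp.wav
print(title, audio_path)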
modules/vad/__init__.py ADDED
File without changes
modules/vad/silero_vad.py ADDED
@@ -0,0 +1,264 @@
1
+ # Adapted from https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py
2
+
3
+ from faster_whisper.vad import VadOptions, get_vad_model
4
+ import numpy as np
5
+ from typing import BinaryIO, Union, List, Optional, Tuple
6
+ import warnings
7
+ import faster_whisper
8
+ from faster_whisper.transcribe import SpeechTimestampsMap, Segment
9
+ import gradio as gr
10
+
11
+
12
+ class SileroVAD:
13
+ def __init__(self):
14
+ self.sampling_rate = 16000
15
+ self.window_size_samples = 512
16
+ self.model = None
17
+
18
+ def run(self,
19
+ audio: Union[str, BinaryIO, np.ndarray],
20
+ vad_parameters: VadOptions,
21
+ progress: gr.Progress = gr.Progress()
22
+ ) -> Tuple[np.ndarray, List[dict]]:
23
+ """
24
+ Run VAD
25
+
26
+ Parameters
27
+ ----------
28
+ audio: Union[str, BinaryIO, np.ndarray]
29
+ Audio path or file binary or Audio numpy array
30
+ vad_parameters: VadOptions
31
+ Options for VAD processing.
32
+ progress: gr.Progress
33
+ Indicator to show progress directly in gradio.
34
+
35
+ Returns
36
+ ----------
37
+ np.ndarray
38
+ Pre-processed audio with VAD
39
+ List[dict]
40
+ Chunks of speeches to be used to restore the timestamps later
41
+ """
42
+
43
+ sampling_rate = self.sampling_rate
44
+
45
+ if not isinstance(audio, np.ndarray):
46
+ audio = faster_whisper.decode_audio(audio, sampling_rate=sampling_rate)
47
+
48
+ duration = audio.shape[0] / sampling_rate
49
+ duration_after_vad = duration
50
+
51
+ if vad_parameters is None:
52
+ vad_parameters = VadOptions()
53
+ elif isinstance(vad_parameters, dict):
54
+ vad_parameters = VadOptions(**vad_parameters)
55
+ speech_chunks = self.get_speech_timestamps(
56
+ audio=audio,
57
+ vad_options=vad_parameters,
58
+ progress=progress
59
+ )
60
+ audio = self.collect_chunks(audio, speech_chunks)
61
+ duration_after_vad = audio.shape[0] / sampling_rate
62
+
63
+ return audio, speech_chunks
64
+
65
+ def get_speech_timestamps(
66
+ self,
67
+ audio: np.ndarray,
68
+ vad_options: Optional[VadOptions] = None,
69
+ progress: gr.Progress = gr.Progress(),
70
+ **kwargs,
71
+ ) -> List[dict]:
72
+ """This method is used for splitting long audios into speech chunks using silero VAD.
73
+
74
+ Args:
75
+ audio: One dimensional float array.
76
+ vad_options: Options for VAD processing.
77
+ kwargs: VAD options passed as keyword arguments for backward compatibility.
78
+ progress: Gradio progress to indicate progress.
79
+
80
+ Returns:
81
+ List of dicts containing begin and end samples of each speech chunk.
82
+ """
83
+
84
+ if self.model is None:
85
+ self.update_model()
86
+
87
+ if vad_options is None:
88
+ vad_options = VadOptions(**kwargs)
89
+
90
+ threshold = vad_options.threshold
91
+ min_speech_duration_ms = vad_options.min_speech_duration_ms
92
+ max_speech_duration_s = vad_options.max_speech_duration_s
93
+ min_silence_duration_ms = vad_options.min_silence_duration_ms
94
+ window_size_samples = self.window_size_samples
95
+ speech_pad_ms = vad_options.speech_pad_ms
96
+ sampling_rate = 16000
97
+ min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
98
+ speech_pad_samples = sampling_rate * speech_pad_ms / 1000
99
+ max_speech_samples = (
100
+ sampling_rate * max_speech_duration_s
101
+ - window_size_samples
102
+ - 2 * speech_pad_samples
103
+ )
104
+ min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
105
+ min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
106
+
107
+ audio_length_samples = len(audio)
108
+
109
+ state, context = self.model.get_initial_states(batch_size=1)
110
+
111
+ speech_probs = []
112
+ for current_start_sample in range(0, audio_length_samples, window_size_samples):
113
+ progress(current_start_sample/audio_length_samples, desc="Detecting speeches only using VAD...")
114
+
115
+ chunk = audio[current_start_sample: current_start_sample + window_size_samples]
116
+ if len(chunk) < window_size_samples:
117
+ chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
118
+ speech_prob, state, context = self.model(chunk, state, context, sampling_rate)
119
+ speech_probs.append(speech_prob)
120
+
121
+ triggered = False
122
+ speeches = []
123
+ current_speech = {}
124
+ neg_threshold = threshold - 0.15
125
+
126
+ # to save potential segment end (and tolerate some silence)
127
+ temp_end = 0
128
+ # to save potential segment limits in case of maximum segment size reached
129
+ prev_end = next_start = 0
130
+
131
+ for i, speech_prob in enumerate(speech_probs):
132
+ if (speech_prob >= threshold) and temp_end:
133
+ temp_end = 0
134
+ if next_start < prev_end:
135
+ next_start = window_size_samples * i
136
+
137
+ if (speech_prob >= threshold) and not triggered:
138
+ triggered = True
139
+ current_speech["start"] = window_size_samples * i
140
+ continue
141
+
142
+ if (
143
+ triggered
144
+ and (window_size_samples * i) - current_speech["start"] > max_speech_samples
145
+ ):
146
+ if prev_end:
147
+ current_speech["end"] = prev_end
148
+ speeches.append(current_speech)
149
+ current_speech = {}
150
+ # previously reached silence (< neg_thres) and is still not speech (< thres)
151
+ if next_start < prev_end:
152
+ triggered = False
153
+ else:
154
+ current_speech["start"] = next_start
155
+ prev_end = next_start = temp_end = 0
156
+ else:
157
+ current_speech["end"] = window_size_samples * i
158
+ speeches.append(current_speech)
159
+ current_speech = {}
160
+ prev_end = next_start = temp_end = 0
161
+ triggered = False
162
+ continue
163
+
164
+ if (speech_prob < neg_threshold) and triggered:
165
+ if not temp_end:
166
+ temp_end = window_size_samples * i
167
+ # condition to avoid cutting in very short silence
168
+ if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech:
169
+ prev_end = temp_end
170
+ if (window_size_samples * i) - temp_end < min_silence_samples:
171
+ continue
172
+ else:
173
+ current_speech["end"] = temp_end
174
+ if (
175
+ current_speech["end"] - current_speech["start"]
176
+ ) > min_speech_samples:
177
+ speeches.append(current_speech)
178
+ current_speech = {}
179
+ prev_end = next_start = temp_end = 0
180
+ triggered = False
181
+ continue
182
+
183
+ if (
184
+ current_speech
185
+ and (audio_length_samples - current_speech["start"]) > min_speech_samples
186
+ ):
187
+ current_speech["end"] = audio_length_samples
188
+ speeches.append(current_speech)
189
+
190
+ for i, speech in enumerate(speeches):
191
+ if i == 0:
192
+ speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
193
+ if i != len(speeches) - 1:
194
+ silence_duration = speeches[i + 1]["start"] - speech["end"]
195
+ if silence_duration < 2 * speech_pad_samples:
196
+ speech["end"] += int(silence_duration // 2)
197
+ speeches[i + 1]["start"] = int(
198
+ max(0, speeches[i + 1]["start"] - silence_duration // 2)
199
+ )
200
+ else:
201
+ speech["end"] = int(
202
+ min(audio_length_samples, speech["end"] + speech_pad_samples)
203
+ )
204
+ speeches[i + 1]["start"] = int(
205
+ max(0, speeches[i + 1]["start"] - speech_pad_samples)
206
+ )
207
+ else:
208
+ speech["end"] = int(
209
+ min(audio_length_samples, speech["end"] + speech_pad_samples)
210
+ )
211
+
212
+ return speeches
213
+
214
+ def update_model(self):
215
+ self.model = get_vad_model()
216
+
217
+ @staticmethod
218
+ def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
219
+ """Collects and concatenates audio chunks."""
220
+ if not chunks:
221
+ return np.array([], dtype=np.float32)
222
+
223
+ return np.concatenate([audio[chunk["start"]: chunk["end"]] for chunk in chunks])
224
+
225
+ @staticmethod
226
+ def format_timestamp(
227
+ seconds: float,
228
+ always_include_hours: bool = False,
229
+ decimal_marker: str = ".",
230
+ ) -> str:
231
+ assert seconds >= 0, "non-negative timestamp expected"
232
+ milliseconds = round(seconds * 1000.0)
233
+
234
+ hours = milliseconds // 3_600_000
235
+ milliseconds -= hours * 3_600_000
236
+
237
+ minutes = milliseconds // 60_000
238
+ milliseconds -= minutes * 60_000
239
+
240
+ seconds = milliseconds // 1_000
241
+ milliseconds -= seconds * 1_000
242
+
243
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
244
+ return (
245
+ f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
246
+ )
247
+
248
+ def restore_speech_timestamps(
249
+ self,
250
+ segments: List[dict],
251
+ speech_chunks: List[dict],
252
+ sampling_rate: Optional[int] = None,
253
+ ) -> List[dict]:
254
+ if sampling_rate is None:
255
+ sampling_rate = self.sampling_rate
256
+
257
+ ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
258
+
259
+ for segment in segments:
260
+ segment["start"] = ts_map.get_original_time(segment["start"])
261
+ segment["end"] = ts_map.get_original_time(segment["end"])
262
+
263
+ return segments
264
+
modules/whisper/__init__.py ADDED
File without changes
modules/whisper/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (175 Bytes).
 
modules/whisper/__pycache__/faster_whisper_inference.cpython-310.pyc ADDED
Binary file (6.51 kB).
 
modules/whisper/__pycache__/whisper_base.cpython-310.pyc ADDED
Binary file (12.9 kB).
 
modules/whisper/__pycache__/whisper_factory.cpython-310.pyc ADDED
Binary file (2.87 kB).
 
modules/whisper/__pycache__/whisper_parameter.cpython-310.pyc ADDED
Binary file (3.68 kB).
 
modules/whisper/faster_whisper_inference.py ADDED
@@ -0,0 +1,191 @@
1
+ import os
2
+ import time
3
+ import numpy as np
4
+ import torch
5
+ from typing import BinaryIO, Union, Tuple, List
6
+ import faster_whisper
7
+ from faster_whisper.vad import VadOptions
8
+ import ast
9
+ import ctranslate2
10
+ import whisper
11
+ import gradio as gr
12
+ from argparse import Namespace
13
+
14
+ from modules.whisper.whisper_parameter import *
15
+ from modules.whisper.whisper_base import WhisperBase
16
+
17
+
18
+ class FasterWhisperInference(WhisperBase):
19
+ def __init__(self,
20
+ model_dir: str = os.path.join("models", "Whisper", "faster-whisper"),
21
+ diarization_model_dir: str = os.path.join("models", "Diarization"),
22
+ output_dir: str = os.path.join("outputs"),
23
+ ):
24
+ super().__init__(
25
+ model_dir=model_dir,
26
+ diarization_model_dir=diarization_model_dir,
27
+ output_dir=output_dir
28
+ )
29
+ self.model_dir = model_dir
30
+ os.makedirs(self.model_dir, exist_ok=True)
31
+
32
+ self.model_paths = self.get_model_paths()
33
+ self.device = self.get_device()
34
+ self.available_models = self.model_paths.keys()
35
+ self.available_compute_types = ctranslate2.get_supported_compute_types(
36
+ "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
37
+
38
+ def transcribe(self,
39
+ audio: Union[str, BinaryIO, np.ndarray],
40
+ progress: gr.Progress,
41
+ *whisper_params,
42
+ ) -> Tuple[List[dict], float]:
43
+ """
44
+ transcribe method for faster-whisper.
45
+
46
+ Parameters
47
+ ----------
48
+ audio: Union[str, BinaryIO, np.ndarray]
49
+ Audio path or file binary or Audio numpy array
50
+ progress: gr.Progress
51
+ Indicator to show progress directly in gradio.
52
+ *whisper_params: tuple
53
+ Parameters related with whisper. This will be dealt with "WhisperParameters" data class
54
+
55
+ Returns
56
+ ----------
57
+ segments_result: List[dict]
58
+ list of dicts that includes start, end timestamps and transcribed text
59
+ elapsed_time: float
60
+ elapsed time for transcription
61
+ """
62
+ start_time = time.time()
63
+
64
+ params = WhisperParameters.as_value(*whisper_params)
65
+
66
+ if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
67
+ self.update_model(params.model_size, params.compute_type, progress)
68
+
69
+ # None parameters with Textboxes: https://github.com/gradio-app/gradio/issues/8723
70
+ if not params.initial_prompt:
71
+ params.initial_prompt = None
72
+ if not params.prefix:
73
+ params.prefix = None
74
+ if not params.hotwords:
75
+ params.hotwords = None
76
+
77
+ params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens)
78
+
79
+ segments, info = self.model.transcribe(
80
+ audio=audio,
81
+ language=params.lang,
82
+ task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
83
+ beam_size=params.beam_size,
84
+ log_prob_threshold=params.log_prob_threshold,
85
+ no_speech_threshold=params.no_speech_threshold,
86
+ best_of=params.best_of,
87
+ patience=params.patience,
88
+ temperature=params.temperature,
89
+ initial_prompt=params.initial_prompt,
90
+ compression_ratio_threshold=params.compression_ratio_threshold,
91
+ length_penalty=params.length_penalty,
92
+ repetition_penalty=params.repetition_penalty,
93
+ no_repeat_ngram_size=params.no_repeat_ngram_size,
94
+ prefix=params.prefix,
95
+ suppress_blank=params.suppress_blank,
96
+ suppress_tokens=params.suppress_tokens,
97
+ max_initial_timestamp=params.max_initial_timestamp,
98
+ word_timestamps=params.word_timestamps,
99
+ prepend_punctuations=params.prepend_punctuations,
100
+ append_punctuations=params.append_punctuations,
101
+ max_new_tokens=params.max_new_tokens,
102
+ chunk_length=params.chunk_length,
103
+ hallucination_silence_threshold=params.hallucination_silence_threshold,
104
+ hotwords=params.hotwords,
105
+ language_detection_threshold=params.language_detection_threshold,
106
+ language_detection_segments=params.language_detection_segments,
107
+ prompt_reset_on_temperature=params.prompt_reset_on_temperature,
108
+ )
109
+ progress(0, desc="Loading audio..")
110
+
111
+ segments_result = []
112
+ for segment in segments:
113
+ progress(segment.start / info.duration, desc="Transcribing..")
114
+ segments_result.append({
115
+ "start": segment.start,
116
+ "end": segment.end,
117
+ "text": segment.text
118
+ })
119
+
120
+ elapsed_time = time.time() - start_time
121
+ return segments_result, elapsed_time
122
+
123
+ def update_model(self,
124
+ model_size: str,
125
+ compute_type: str,
126
+ progress: gr.Progress
127
+ ):
128
+ """
129
+ Update current model setting
130
+
131
+ Parameters
132
+ ----------
133
+ model_size: str
134
+ Size of whisper model
135
+ compute_type: str
136
+ Compute type for transcription.
137
+ see more info : https://opennmt.net/CTranslate2/quantization.html
138
+ progress: gr.Progress
139
+ Indicator to show progress directly in gradio.
140
+ """
141
+ progress(0, desc="Initializing Model..")
142
+ self.current_model_size = self.model_paths[model_size]
143
+ self.current_compute_type = compute_type
144
+ self.model = faster_whisper.WhisperModel(
145
+ device=self.device,
146
+ model_size_or_path=self.current_model_size,
147
+ download_root=self.model_dir,
148
+ compute_type=self.current_compute_type
149
+ )
150
+
151
+ def get_model_paths(self):
152
+ """
153
+ Get available models from models path including fine-tuned model.
154
+
155
+ Returns
156
+ ----------
157
+ Name list of models
158
+ """
159
+ model_paths = {model:model for model in whisper.available_models()}
160
+ faster_whisper_prefix = "models--Systran--faster-whisper-"
161
+
162
+ existing_models = os.listdir(self.model_dir)
163
+ wrong_dirs = [".locks"]
164
+ existing_models = list(set(existing_models) - set(wrong_dirs))
165
+
166
+ webui_dir = os.getcwd()
167
+
168
+ for model_name in existing_models:
169
+ if faster_whisper_prefix in model_name:
170
+ model_name = model_name[len(faster_whisper_prefix):]
171
+
172
+ if model_name not in whisper.available_models():
173
+ model_paths[model_name] = os.path.join(webui_dir, self.model_dir, model_name)
174
+ return model_paths
175
+
176
+ @staticmethod
177
+ def get_device():
178
+ if torch.cuda.is_available():
179
+ return "cuda"
180
+ else:
181
+ return "auto"
182
+
183
+ @staticmethod
184
+ def format_suppress_tokens_str(suppress_tokens_str: str) -> List[int]:
185
+ try:
186
+ suppress_tokens = ast.literal_eval(suppress_tokens_str)
187
+ if not isinstance(suppress_tokens, list) or not all(isinstance(item, int) for item in suppress_tokens):
188
+ raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
189
+ return suppress_tokens
190
+ except Exception as e:
191
+ raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
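Because `suppress_tokens` arrives from a Gradio textbox as a string, `format_suppress_tokens_str` above parses it with `ast.literal_eval` and rejects anything that is not a list of ints. A small sketch of the expected round-trip (not part of the commit, assuming the module imports cleanly from the project root):

```python
# Hypothetical sketch of the suppress_tokens parsing shown above.
from modules.whisper.faster_whisper_inference import FasterWhisperInference

print(FasterWhisperInference.format_suppress_tokens_str("[-1]"))       # [-1]
print(FasterWhisperInference.format_suppress_tokens_str("[1, 2, 3]"))  # [1, 2, 3]

try:
    FasterWhisperInference.format_suppress_tokens_str("not a list")
except ValueError as err:
    print(err)  # Invalid Suppress Tokens. The value must be type of List[int]
```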
modules/whisper/insanely_fast_whisper_inference.py ADDED
@@ -0,0 +1,185 @@
1
+ import os
2
+ import time
3
+ import numpy as np
4
+ from typing import BinaryIO, Union, Tuple, List
5
+ import torch
6
+ from transformers import pipeline
7
+ from transformers.utils import is_flash_attn_2_available
8
+ import gradio as gr
9
+ from huggingface_hub import hf_hub_download
10
+ import whisper
11
+ from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn
12
+ from argparse import Namespace
13
+
14
+ from modules.whisper.whisper_parameter import *
15
+ from modules.whisper.whisper_base import WhisperBase
16
+
17
+
18
+ class InsanelyFastWhisperInference(WhisperBase):
19
+ def __init__(self,
20
+ model_dir: str = os.path.join("models", "Whisper", "insanely-fast-whisper"),
21
+ diarization_model_dir: str = os.path.join("models", "Diarization"),
22
+ output_dir: str = os.path.join("outputs"),
23
+ ):
24
+ super().__init__(
25
+ model_dir=model_dir,
26
+ output_dir=output_dir,
27
+ diarization_model_dir=diarization_model_dir
28
+ )
29
+ self.model_dir = model_dir
30
+ os.makedirs(self.model_dir, exist_ok=True)
31
+
32
+ openai_models = whisper.available_models()
33
+ distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
34
+ self.available_models = openai_models + distil_models
35
+ self.available_compute_types = ["float16"]
36
+
37
+ def transcribe(self,
38
+ audio: Union[str, np.ndarray, torch.Tensor],
39
+ progress: gr.Progress,
40
+ *whisper_params,
41
+ ) -> Tuple[List[dict], float]:
42
+ """
43
+ transcribe method for insanely-fast-whisper.
44
+
45
+ Parameters
46
+ ----------
47
+ audio: Union[str, BinaryIO, np.ndarray]
48
+ Audio path or file binary or Audio numpy array
49
+ progress: gr.Progress
50
+ Indicator to show progress directly in gradio.
51
+ *whisper_params: tuple
52
+ Parameters related with whisper. This will be dealt with "WhisperParameters" data class
53
+
54
+ Returns
55
+ ----------
56
+ segments_result: List[dict]
57
+ list of dicts that includes start, end timestamps and transcribed text
58
+ elapsed_time: float
59
+ elapsed time for transcription
60
+ """
61
+ start_time = time.time()
62
+ params = WhisperParameters.as_value(*whisper_params)
63
+
64
+ if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
65
+ self.update_model(params.model_size, params.compute_type, progress)
66
+
67
+ progress(0, desc="Transcribing...Progress is not shown in insanely-fast-whisper.")
68
+ with Progress(
69
+ TextColumn("[progress.description]{task.description}"),
70
+ BarColumn(style="yellow1", pulse_style="white"),
71
+ TimeElapsedColumn(),
72
+ ) as progress:
73
+ progress.add_task("[yellow]Transcribing...", total=None)
74
+
75
+ segments = self.model(
76
+ inputs=audio,
77
+ return_timestamps=True,
78
+ chunk_length_s=params.chunk_length_s,
79
+ batch_size=params.batch_size,
80
+ generate_kwargs={
81
+ "language": params.lang,
82
+ "task": "translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
83
+ "no_speech_threshold": params.no_speech_threshold,
84
+ "temperature": params.temperature,
85
+ "compression_ratio_threshold": params.compression_ratio_threshold
86
+ }
87
+ )
88
+
89
+ segments_result = self.format_result(
90
+ transcribed_result=segments,
91
+ )
92
+ elapsed_time = time.time() - start_time
93
+ return segments_result, elapsed_time
94
+
95
+ def update_model(self,
96
+ model_size: str,
97
+ compute_type: str,
98
+ progress: gr.Progress,
99
+ ):
100
+ """
101
+ Update current model setting
102
+
103
+ Parameters
104
+ ----------
105
+ model_size: str
106
+ Size of whisper model
107
+ compute_type: str
108
+ Compute type for transcription.
109
+ see more info : https://opennmt.net/CTranslate2/quantization.html
110
+ progress: gr.Progress
111
+ Indicator to show progress directly in gradio.
112
+ """
113
+ progress(0, desc="Initializing Model..")
114
+ model_path = os.path.join(self.model_dir, model_size)
115
+ if not os.path.isdir(model_path) or not os.listdir(model_path):
116
+ self.download_model(
117
+ model_size=model_size,
118
+ download_root=model_path,
119
+ progress=progress
120
+ )
121
+
122
+ self.current_compute_type = compute_type
123
+ self.current_model_size = model_size
124
+ self.model = pipeline(
125
+ "automatic-speech-recognition",
126
+ model=os.path.join(self.model_dir, model_size),
127
+ torch_dtype=self.current_compute_type,
128
+ device=self.device,
129
+ model_kwargs={"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"},
130
+ )
131
+
132
+ @staticmethod
133
+ def format_result(
134
+ transcribed_result: dict
135
+ ) -> List[dict]:
136
+ """
137
+ Format the transcription result of insanely_fast_whisper to match the other implementations.
138
+
139
+ Parameters
140
+ ----------
141
+ transcribed_result: dict
142
+ Transcription result of the insanely_fast_whisper
143
+
144
+ Returns
145
+ ----------
146
+ result: List[dict]
147
+ Formatted result, in the same format as the other implementations
148
+ """
149
+ result = transcribed_result["chunks"]
150
+ for item in result:
151
+ start, end = item["timestamp"][0], item["timestamp"][1]
152
+ if end is None:
153
+ end = start
154
+ item["start"] = start
155
+ item["end"] = end
156
+ return result
157
+
158
+ @staticmethod
159
+ def download_model(
160
+ model_size: str,
161
+ download_root: str,
162
+ progress: gr.Progress
163
+ ):
164
+ progress(0, 'Initializing model..')
165
+ print(f'Downloading {model_size} to "{download_root}"....')
166
+
167
+ os.makedirs(download_root, exist_ok=True)
168
+ download_list = [
169
+ "model.safetensors",
170
+ "config.json",
171
+ "generation_config.json",
172
+ "preprocessor_config.json",
173
+ "tokenizer.json",
174
+ "tokenizer_config.json",
175
+ "added_tokens.json",
176
+ "special_tokens_map.json",
177
+ "vocab.json",
178
+ ]
179
+
180
+ if model_size.startswith("distil"):
181
+ repo_id = f"distil-whisper/{model_size}"
182
+ else:
183
+ repo_id = f"openai/whisper-{model_size}"
184
+ for item in download_list:
185
+ hf_hub_download(repo_id=repo_id, filename=item, local_dir=download_root)
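`format_result` normalizes the Hugging Face pipeline output (`{"chunks": [...]}`) into the same start/end/text dictionaries the other backends return. A sketch (not part of the commit) using a hand-written payload, so no model download is needed:

```python
# Hypothetical sketch of format_result with a fake pipeline-style payload.
from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference

fake_output = {
    "chunks": [
        {"timestamp": (0.0, 2.4), "text": " Hello there."},
        {"timestamp": (2.4, None), "text": " Last chunk without an end timestamp."},
    ]
}

segments = InsanelyFastWhisperInference.format_result(fake_output)
print(segments[0]["start"], segments[0]["end"])  # 0.0 2.4
print(segments[1]["start"], segments[1]["end"])  # 2.4 2.4  (missing end falls back to start)
```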
modules/whisper/whisper_Inference.py ADDED
@@ -0,0 +1,101 @@
1
+ import whisper
2
+ import gradio as gr
3
+ import time
4
+ from typing import BinaryIO, Union, Tuple, List
5
+ import numpy as np
6
+ import torch
7
+ import os
8
+ from argparse import Namespace
9
+
10
+ from modules.whisper.whisper_base import WhisperBase
11
+ from modules.whisper.whisper_parameter import *
12
+
13
+
14
+ class WhisperInference(WhisperBase):
15
+ def __init__(self,
16
+ model_dir: str = os.path.join("models", "Whisper"),
17
+ diarization_model_dir: str = os.path.join("models", "Diarization"),
18
+ output_dir: str = os.path.join("outputs"),
19
+ ):
20
+ super().__init__(
21
+ model_dir=model_dir,
22
+ output_dir=output_dir,
23
+ diarization_model_dir=diarization_model_dir
24
+ )
25
+
26
+ def transcribe(self,
27
+ audio: Union[str, np.ndarray, torch.Tensor],
28
+ progress: gr.Progress,
29
+ *whisper_params,
30
+ ) -> Tuple[List[dict], float]:
31
+ """
32
+ transcribe method for whisper.
33
+
34
+ Parameters
35
+ ----------
36
+ audio: Union[str, BinaryIO, np.ndarray]
37
+ Audio path or file binary or Audio numpy array
38
+ progress: gr.Progress
39
+ Indicator to show progress directly in gradio.
40
+ *whisper_params: tuple
41
+ Parameters related with whisper. This will be dealt with "WhisperParameters" data class
42
+
43
+ Returns
44
+ ----------
45
+ segments_result: List[dict]
46
+ list of dicts that includes start, end timestamps and transcribed text
47
+ elapsed_time: float
48
+ elapsed time for transcription
49
+ """
50
+ start_time = time.time()
51
+ params = WhisperParameters.as_value(*whisper_params)
52
+
53
+ if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
54
+ self.update_model(params.model_size, params.compute_type, progress)
55
+
56
+ def progress_callback(progress_value):
57
+ progress(progress_value, desc="Transcribing..")
58
+
59
+ segments_result = self.model.transcribe(audio=audio,
60
+ language=params.lang,
61
+ verbose=False,
62
+ beam_size=params.beam_size,
63
+ logprob_threshold=params.log_prob_threshold,
64
+ no_speech_threshold=params.no_speech_threshold,
65
+ task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
66
+ fp16=True if params.compute_type == "float16" else False,
67
+ best_of=params.best_of,
68
+ patience=params.patience,
69
+ temperature=params.temperature,
70
+ compression_ratio_threshold=params.compression_ratio_threshold,
71
+ progress_callback=progress_callback,)["segments"]
72
+ elapsed_time = time.time() - start_time
73
+
74
+ return segments_result, elapsed_time
75
+
76
+ def update_model(self,
77
+ model_size: str,
78
+ compute_type: str,
79
+ progress: gr.Progress,
80
+ ):
81
+ """
82
+ Update current model setting
83
+
84
+ Parameters
85
+ ----------
86
+ model_size: str
87
+ Size of whisper model
88
+ compute_type: str
89
+ Compute type for transcription.
90
+ see more info : https://opennmt.net/CTranslate2/quantization.html
91
+ progress: gr.Progress
92
+ Indicator to show progress directly in gradio.
93
+ """
94
+ progress(0, desc="Initializing Model..")
95
+ self.current_compute_type = compute_type
96
+ self.current_model_size = model_size
97
+ self.model = whisper.load_model(
98
+ name=model_size,
99
+ device=self.device,
100
+ download_root=self.model_dir
101
+ )
modules/whisper/whisper_base.py ADDED
@@ -0,0 +1,436 @@
1
+ import os
2
+ import torch
3
+ import whisper
4
+ import gradio as gr
5
+ from abc import ABC, abstractmethod
6
+ from typing import BinaryIO, Union, Tuple, List
7
+ import numpy as np
8
+ from datetime import datetime
9
+ from faster_whisper.vad import VadOptions
10
+ from dataclasses import astuple
11
+
12
+ from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
13
+ from modules.utils.youtube_manager import get_ytdata, get_ytaudio
14
+ from modules.utils.files_manager import get_media_files, format_gradio_files
15
+ from modules.whisper.whisper_parameter import *
16
+ from modules.diarize.diarizer import Diarizer
17
+ from modules.vad.silero_vad import SileroVAD
18
+
19
+
20
+ class WhisperBase(ABC):
21
+ def __init__(self,
22
+ model_dir: str = os.path.join("models", "Whisper"),
23
+ diarization_model_dir: str = os.path.join("models", "Diarization"),
24
+ output_dir: str = os.path.join("outputs"),
25
+ ):
26
+ self.model_dir = model_dir
27
+ self.output_dir = output_dir
28
+ os.makedirs(self.output_dir, exist_ok=True)
29
+ os.makedirs(self.model_dir, exist_ok=True)
30
+ self.diarizer = Diarizer(
31
+ model_dir=diarization_model_dir
32
+ )
33
+ self.vad = SileroVAD()
34
+
35
+ self.model = None
36
+ self.current_model_size = None
37
+ self.available_models = whisper.available_models()
38
+ self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
39
+ self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
40
+ self.device = self.get_device()
41
+ self.available_compute_types = ["float16", "float32"]
42
+ self.current_compute_type = "float16" if self.device == "cuda" else "float32"
43
+
44
+ @abstractmethod
45
+ def transcribe(self,
46
+ audio: Union[str, BinaryIO, np.ndarray],
47
+ progress: gr.Progress,
48
+ *whisper_params,
49
+ ):
50
+ """Inference whisper model to transcribe"""
51
+ pass
52
+
53
+ @abstractmethod
54
+ def update_model(self,
55
+ model_size: str,
56
+ compute_type: str,
57
+ progress: gr.Progress
58
+ ):
59
+ """Initialize whisper model"""
60
+ pass
61
+
62
+ def run(self,
63
+ audio: Union[str, BinaryIO, np.ndarray],
64
+ progress: gr.Progress,
65
+ *whisper_params,
66
+ ) -> Tuple[List[dict], float]:
67
+ """
68
+ Run transcription with conditional pre-processing and post-processing.
69
+ The VAD will be performed to remove noise from the audio input in pre-processing, if enabled.
70
+ The diarization will be performed in post-processing, if enabled.
71
+
72
+ Parameters
73
+ ----------
74
+ audio: Union[str, BinaryIO, np.ndarray]
75
+ Audio input. This can be file path or binary type.
76
+ progress: gr.Progress
77
+ Indicator to show progress directly in gradio.
78
+ *whisper_params: tuple
79
+ Parameters related with whisper. This will be dealt with "WhisperParameters" data class
80
+
81
+ Returns
82
+ ----------
83
+ segments_result: List[dict]
84
+ list of dicts that includes start, end timestamps and transcribed text
85
+ elapsed_time: float
86
+ elapsed time for running
87
+ """
88
+ params = WhisperParameters.as_value(*whisper_params)
89
+
90
+ if params.lang == "Automatic Detection":
91
+ params.lang = None
92
+ else:
93
+ language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
94
+ params.lang = language_code_dict[params.lang]
95
+
96
+ speech_chunks = None
97
+ if params.vad_filter:
98
+ # Explicit value set for float('inf') from gr.Number()
99
+ if params.max_speech_duration_s >= 9999:
100
+ params.max_speech_duration_s = float('inf')
101
+
102
+ vad_options = VadOptions(
103
+ threshold=params.threshold,
104
+ min_speech_duration_ms=params.min_speech_duration_ms,
105
+ max_speech_duration_s=params.max_speech_duration_s,
106
+ min_silence_duration_ms=params.min_silence_duration_ms,
107
+ speech_pad_ms=params.speech_pad_ms
108
+ )
109
+
110
+ audio, speech_chunks = self.vad.run(
111
+ audio=audio,
112
+ vad_parameters=vad_options,
113
+ progress=progress
114
+ )
115
+
116
+ result, elapsed_time = self.transcribe(
117
+ audio,
118
+ progress,
119
+ *astuple(params)
120
+ )
121
+
122
+ if params.vad_filter:
123
+ result = self.vad.restore_speech_timestamps(
124
+ segments=result,
125
+ speech_chunks=speech_chunks,
126
+ )
127
+
128
+ if params.is_diarize:
129
+ result, elapsed_time_diarization = self.diarizer.run(
130
+ audio=audio,
131
+ use_auth_token=params.hf_token,
132
+ transcribed_result=result,
133
+ )
134
+ elapsed_time += elapsed_time_diarization
135
+ return result, elapsed_time
136
+
137
+ def transcribe_file(self,
138
+ files: list,
139
+ input_folder_path: str,
140
+ file_format: str,
141
+ add_timestamp: bool,
142
+ progress=gr.Progress(),
143
+ *whisper_params,
144
+ ) -> list:
145
+ """
146
+ Write subtitle file from Files
147
+
148
+ Parameters
149
+ ----------
150
+ files: list
151
+ List of files to transcribe from gr.Files()
152
+ input_folder_path: str
153
+ Input folder path to transcribe from gr.Textbox(). If this is provided, `files` will be ignored and
154
+ this will be used instead.
155
+ file_format: str
156
+ Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
157
+ add_timestamp: bool
158
+ Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
159
+ progress: gr.Progress
160
+ Indicator to show progress directly in gradio.
161
+ *whisper_params: tuple
162
+ Parameters related with whisper. This will be dealt with "WhisperParameters" data class
163
+
164
+ Returns
165
+ ----------
166
+ result_str:
167
+ Result of transcription to return to gr.Textbox()
168
+ result_file_path:
169
+ Output file path to return to gr.Files()
170
+ """
171
+ try:
172
+ if input_folder_path:
173
+ files = get_media_files(input_folder_path)
174
+ files = format_gradio_files(files)
175
+
176
+ files_info = {}
177
+ for file in files:
178
+ transcribed_segments, time_for_task = self.run(
179
+ file.name,
180
+ progress,
181
+ *whisper_params,
182
+ )
183
+
184
+ file_name, file_ext = os.path.splitext(os.path.basename(file.name))
185
+ subtitle, file_path = self.generate_and_write_file(
186
+ file_name=file_name,
187
+ transcribed_segments=transcribed_segments,
188
+ add_timestamp=add_timestamp,
189
+ file_format=file_format,
190
+ output_dir=self.output_dir
191
+ )
192
+ files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
193
+
194
+ total_result = ''
195
+ total_time = 0
196
+ for file_name, info in files_info.items():
197
+ total_result += '------------------------------------\n'
198
+ total_result += f'{file_name}\n\n'
199
+ total_result += f'{info["subtitle"]}'
200
+ total_time += info["time_for_task"]
201
+
202
+ result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
203
+ result_file_path = [info['path'] for info in files_info.values()]
204
+
205
+ return [result_str, result_file_path]
206
+
207
+ except Exception as e:
208
+ print(f"Error transcribing file: {e}")
209
+ finally:
210
+ self.release_cuda_memory()
211
+ if not files:
212
+ self.remove_input_files([file.name for file in files])
213
+
214
+ def transcribe_mic(self,
215
+ mic_audio: str,
216
+ file_format: str,
217
+ progress=gr.Progress(),
218
+ *whisper_params,
219
+ ) -> list:
220
+ """
221
+ Write subtitle file from microphone
222
+
223
+ Parameters
224
+ ----------
225
+ mic_audio: str
226
+ Audio file path from gr.Microphone()
227
+ file_format: str
228
+ Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
229
+ progress: gr.Progress
230
+ Indicator to show progress directly in gradio.
231
+ *whisper_params: tuple
232
+ Parameters related with whisper. This will be dealt with "WhisperParameters" data class
233
+
234
+ Returns
235
+ ----------
236
+ result_str:
237
+ Result of transcription to return to gr.Textbox()
238
+ result_file_path:
239
+ Output file path to return to gr.Files()
240
+ """
241
+ try:
242
+ progress(0, desc="Loading Audio..")
243
+ transcribed_segments, time_for_task = self.run(
244
+ mic_audio,
245
+ progress,
246
+ *whisper_params,
247
+ )
248
+ progress(1, desc="Completed!")
249
+
250
+ subtitle, result_file_path = self.generate_and_write_file(
251
+ file_name="Mic",
252
+ transcribed_segments=transcribed_segments,
253
+ add_timestamp=True,
254
+ file_format=file_format,
255
+ output_dir=self.output_dir
256
+ )
257
+
258
+ result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
259
+ return [result_str, result_file_path]
260
+ except Exception as e:
261
+ print(f"Error transcribing file: {e}")
262
+ finally:
263
+ self.release_cuda_memory()
264
+ self.remove_input_files([mic_audio])
265
+
266
+ def transcribe_youtube(self,
267
+ youtube_link: str,
268
+ file_format: str,
269
+ add_timestamp: bool,
270
+ progress=gr.Progress(),
271
+ *whisper_params,
272
+ ) -> list:
273
+ """
274
+ Write subtitle file from Youtube
275
+
276
+ Parameters
277
+ ----------
278
+ youtube_link: str
279
+ URL of the Youtube video to transcribe from gr.Textbox()
280
+ file_format: str
281
+ Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
282
+ add_timestamp: bool
283
+ Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
284
+ progress: gr.Progress
285
+ Indicator to show progress directly in gradio.
286
+ *whisper_params: tuple
287
+ Parameters related with whisper. This will be dealt with "WhisperParameters" data class
288
+
289
+ Returns
290
+ ----------
291
+ result_str:
292
+ Result of transcription to return to gr.Textbox()
293
+ result_file_path:
294
+ Output file path to return to gr.Files()
295
+ """
296
+ try:
297
+ progress(0, desc="Loading Audio from Youtube..")
298
+ yt = get_ytdata(youtube_link)
299
+ audio = get_ytaudio(yt)
300
+
301
+ transcribed_segments, time_for_task = self.run(
302
+ audio,
303
+ progress,
304
+ *whisper_params,
305
+ )
306
+
307
+ progress(1, desc="Completed!")
308
+
309
+ file_name = safe_filename(yt.title)
310
+ subtitle, result_file_path = self.generate_and_write_file(
311
+ file_name=file_name,
312
+ transcribed_segments=transcribed_segments,
313
+ add_timestamp=add_timestamp,
314
+ file_format=file_format,
315
+ output_dir=self.output_dir
316
+ )
317
+ result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
318
+
319
+ return [result_str, result_file_path]
320
+
321
+ except Exception as e:
322
+ print(f"Error transcribing file: {e}")
323
+ finally:
324
+ try:
325
+ if 'yt' not in locals():
326
+ yt = get_ytdata(youtube_link)
327
+ file_path = get_ytaudio(yt)
328
+ else:
329
+ file_path = get_ytaudio(yt)
330
+
331
+ self.release_cuda_memory()
332
+ self.remove_input_files([file_path])
333
+ except Exception as cleanup_error:
334
+ pass
335
+
336
+ @staticmethod
337
+ def generate_and_write_file(file_name: str,
338
+ transcribed_segments: list,
339
+ add_timestamp: bool,
340
+ file_format: str,
341
+ output_dir: str
342
+ ) -> Tuple[str, str]:
343
+ """
344
+ Writes subtitle file
345
+
346
+ Parameters
347
+ ----------
348
+ file_name: str
349
+ Output file name
350
+ transcribed_segments: list
351
+ Text segments transcribed from audio
352
+ add_timestamp: bool
353
+ Determines whether to add a timestamp to the end of the filename.
354
+ file_format: str
355
+ File format to write. Supported formats: [SRT, WebVTT, txt]
356
+ output_dir: str
357
+ Directory path of the output
358
+
359
+ Returns
360
+ ----------
361
+ content: str
362
+ Result of the transcription
363
+ output_path: str
364
+ output file path
365
+ """
366
+ if add_timestamp:
367
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
368
+ output_path = os.path.join(output_dir, f"{file_name}-{timestamp}")
369
+ else:
370
+ output_path = os.path.join(output_dir, f"{file_name}")
371
+
372
+ if file_format == "SRT":
373
+ content = get_srt(transcribed_segments)
374
+ output_path += '.srt'
375
+
376
+ elif file_format == "WebVTT":
377
+ content = get_vtt(transcribed_segments)
378
+ output_path += '.vtt'
379
+
380
+ elif file_format == "txt":
381
+ content = get_txt(transcribed_segments)
382
+ output_path += '.txt'
383
+
384
+ write_file(content, output_path)
385
+ return content, output_path
386
+
387
+ @staticmethod
388
+ def format_time(elapsed_time: float) -> str:
389
+ """
390
+ Get {hours} {minutes} {seconds} time format string
391
+
392
+ Parameters
393
+ ----------
394
+ elapsed_time: float
395
+ Elapsed time for transcription
396
+
397
+ Returns
398
+ ----------
399
+ Time format string
400
+ """
401
+ hours, rem = divmod(elapsed_time, 3600)
402
+ minutes, seconds = divmod(rem, 60)
403
+
404
+ time_str = ""
405
+ if hours:
406
+ time_str += f"{hours} hours "
407
+ if minutes:
408
+ time_str += f"{minutes} minutes "
409
+ seconds = round(seconds)
410
+ time_str += f"{seconds} seconds"
411
+
412
+ return time_str.strip()
413
+
414
+ @staticmethod
415
+ def get_device():
416
+ if torch.cuda.is_available():
417
+ return "cuda"
418
+ elif torch.backends.mps.is_available():
419
+ return "mps"
420
+ else:
421
+ return "cpu"
422
+
423
+ @staticmethod
424
+ def release_cuda_memory():
425
+ if torch.cuda.is_available():
426
+ torch.cuda.empty_cache()
427
+ torch.cuda.reset_max_memory_allocated()
428
+
429
+ @staticmethod
430
+ def remove_input_files(file_paths: List[str]):
431
+ if not file_paths:
432
+ return
433
+
434
+ for file_path in file_paths:
435
+ if file_path and os.path.exists(file_path):
436
+ os.remove(file_path)
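`format_time` above builds the human-readable elapsed-time string with `divmod`; on a float input the hour and minute components come back as floats. A quick sketch (not part of the commit, assuming the base module imports cleanly from the project root):

```python
# Hypothetical sketch of WhisperBase.format_time.
from modules.whisper.whisper_base import WhisperBase

print(WhisperBase.format_time(42.4))    # "42 seconds"
print(WhisperBase.format_time(3723.6))  # "1.0 hours 2.0 minutes 4 seconds" (divmod on floats)
```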
modules/whisper/whisper_factory.py ADDED
@@ -0,0 +1,81 @@
1
+ from typing import Optional
2
+ import os
3
+
4
+ from modules.whisper.faster_whisper_inference import FasterWhisperInference
5
+ from modules.whisper.whisper_Inference import WhisperInference
6
+ from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
7
+ from modules.whisper.whisper_base import WhisperBase
8
+
9
+
10
+ class WhisperFactory:
11
+ @staticmethod
12
+ def create_whisper_inference(
13
+ whisper_type: str,
14
+ whisper_model_dir: str = os.path.join("models", "Whisper"),
15
+ faster_whisper_model_dir: str = os.path.join("models", "Whisper", "faster-whisper"),
16
+ insanely_fast_whisper_model_dir: str = os.path.join("models", "Whisper", "insanely-fast-whisper"),
17
+ diarization_model_dir: str = os.path.join("models", "Diarization"),
18
+ output_dir: str = os.path.join("outputs"),
19
+ ) -> "WhisperBase":
20
+ """
21
+ Create a whisper inference class based on the provided whisper_type.
22
+
23
+ Parameters
24
+ ----------
25
+ whisper_type : str
26
+ The type of Whisper implementation to use. Supported values (case-insensitive):
27
+ - "faster-whisper": https://github.com/openai/whisper
28
+ - "whisper": https://github.com/openai/whisper
29
+ - "insanely-fast-whisper": https://github.com/Vaibhavs10/insanely-fast-whisper
30
+ whisper_model_dir : str
31
+ Directory path for the Whisper model.
32
+ faster_whisper_model_dir : str
33
+ Directory path for the Faster Whisper model.
34
+ insanely_fast_whisper_model_dir : str
35
+ Directory path for the Insanely Fast Whisper model.
36
+ diarization_model_dir : str
37
+ Directory path for the diarization model.
38
+ output_dir : str
39
+ Directory path where output files will be saved.
40
+
41
+ Returns
42
+ -------
43
+ WhisperBase
44
+ An instance of the appropriate whisper inference class based on the whisper_type.
45
+ """
46
+ # Temporary fix for the bug: https://github.com/jhj0517/Whisper-WebUI/issues/144
47
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
48
+
49
+ whisper_type = whisper_type.lower().strip()
50
+
51
+ faster_whisper_typos = ["faster_whisper", "faster-whisper", "fasterwhisper"]
52
+ whisper_typos = ["whisper"]
53
+ insanely_fast_whisper_typos = [
54
+ "insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
55
+ "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"
56
+ ]
57
+
58
+ if whisper_type in faster_whisper_typos:
59
+ return FasterWhisperInference(
60
+ model_dir=faster_whisper_model_dir,
61
+ output_dir=output_dir,
62
+ diarization_model_dir=diarization_model_dir
63
+ )
64
+ elif whisper_type in whisper_typos:
65
+ return WhisperInference(
66
+ model_dir=whisper_model_dir,
67
+ output_dir=output_dir,
68
+ diarization_model_dir=diarization_model_dir
69
+ )
70
+ elif whisper_type in insanely_fast_whisper_typos:
71
+ return InsanelyFastWhisperInference(
72
+ model_dir=insanely_fast_whisper_model_dir,
73
+ output_dir=output_dir,
74
+ diarization_model_dir=diarization_model_dir
75
+ )
76
+ else:
77
+ return FasterWhisperInference(
78
+ model_dir=faster_whisper_model_dir,
79
+ output_dir=output_dir,
80
+ diarization_model_dir=diarization_model_dir
81
+ )
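A usage sketch (not part of the commit) for the factory above: the `whisper_type` string is lower-cased, matched against a few tolerated spellings, and unknown values fall back to `FasterWhisperInference`. Instantiating the returned class pulls in its heavy dependencies (faster-whisper, ctranslate2, etc.) and creates the model directories, so treat this as illustrative only:

```python
# Hypothetical usage sketch for WhisperFactory.create_whisper_inference.
from modules.whisper.whisper_factory import WhisperFactory

inferencer = WhisperFactory.create_whisper_inference(
    whisper_type="faster-whisper",   # "faster_whisper" / "fasterwhisper" are accepted too
    output_dir="outputs",
)
print(type(inferencer).__name__)     # FasterWhisperInference
```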
modules/whisper/whisper_parameter.py ADDED
@@ -0,0 +1,277 @@
1
+ from dataclasses import dataclass, fields
2
+ import gradio as gr
3
+ from typing import Optional
4
+
5
+
6
+ @dataclass
7
+ class WhisperParameters:
8
+ model_size: gr.Dropdown
9
+ lang: gr.Dropdown
10
+ is_translate: gr.Checkbox
11
+ beam_size: gr.Number
12
+ log_prob_threshold: gr.Number
13
+ no_speech_threshold: gr.Number
14
+ compute_type: gr.Dropdown
15
+ best_of: gr.Number
16
+ patience: gr.Number
17
+ condition_on_previous_text: gr.Checkbox
18
+ prompt_reset_on_temperature: gr.Slider
19
+ initial_prompt: gr.Textbox
20
+ temperature: gr.Slider
21
+ compression_ratio_threshold: gr.Number
22
+ vad_filter: gr.Checkbox
23
+ threshold: gr.Slider
24
+ min_speech_duration_ms: gr.Number
25
+ max_speech_duration_s: gr.Number
26
+ min_silence_duration_ms: gr.Number
27
+ speech_pad_ms: gr.Number
28
+ chunk_length_s: gr.Number
29
+ batch_size: gr.Number
30
+ is_diarize: gr.Checkbox
31
+ hf_token: gr.Textbox
32
+ diarization_device: gr.Dropdown
33
+ length_penalty: gr.Number
34
+ repetition_penalty: gr.Number
35
+ no_repeat_ngram_size: gr.Number
36
+ prefix: gr.Textbox
37
+ suppress_blank: gr.Checkbox
38
+ suppress_tokens: gr.Textbox
39
+ max_initial_timestamp: gr.Number
40
+ word_timestamps: gr.Checkbox
41
+ prepend_punctuations: gr.Textbox
42
+ append_punctuations: gr.Textbox
43
+ max_new_tokens: gr.Number
44
+ chunk_length: gr.Number
45
+ hallucination_silence_threshold: gr.Number
46
+ hotwords: gr.Textbox
47
+ language_detection_threshold: gr.Number
48
+ language_detection_segments: gr.Number
49
+ """
50
+ A data class for Gradio components of the Whisper parameters. Used before Gradio pre-processing.
51
+ This data class is used to mitigate the key-value problem between Gradio components and function parameters.
52
+ Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471
53
+ See more about Gradio pre-processing: https://www.gradio.app/docs/components
54
+
55
+ Attributes
56
+ ----------
57
+ model_size: gr.Dropdown
58
+ Whisper model size.
59
+
60
+ lang: gr.Dropdown
61
+ Source language of the file to transcribe.
62
+
63
+ is_translate: gr.Checkbox
64
+ Boolean value that determines whether to translate to English.
65
+ It's Whisper's feature to translate speech from another language directly into English end-to-end.
66
+
67
+ beam_size: gr.Number
68
+ Int value that is used for decoding option.
69
+
70
+ log_prob_threshold: gr.Number
71
+ If the average log probability over sampled tokens is below this value, treat as failed.
72
+
73
+ no_speech_threshold: gr.Number
74
+ If the no_speech probability is higher than this value AND
75
+ the average log probability over sampled tokens is below `log_prob_threshold`,
76
+ consider the segment as silent.
77
+
78
+ compute_type: gr.Dropdown
79
+ compute type for transcription.
80
+ see more info : https://opennmt.net/CTranslate2/quantization.html
81
+
82
+ best_of: gr.Number
83
+ Number of candidates when sampling with non-zero temperature.
84
+
85
+ patience: gr.Number
86
+ Beam search patience factor.
87
+
88
+ condition_on_previous_text: gr.Checkbox
89
+ if True, the previous output of the model is provided as a prompt for the next window;
90
+ disabling may make the text inconsistent across windows, but the model becomes less prone to
91
+ getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
92
+
93
+ initial_prompt: gr.Textbox
94
+ Optional text to provide as a prompt for the first window. This can be used to provide, or
95
+ "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
96
+ to make it more likely to predict those words correctly.
97
+
98
+ temperature: gr.Slider
99
+ Temperature for sampling. It can be a tuple of temperatures,
100
+ which will be successively used upon failures according to either
101
+ `compression_ratio_threshold` or `log_prob_threshold`.
102
+
103
+ compression_ratio_threshold: gr.Number
104
+ If the gzip compression ratio is above this value, treat as failed
105
+
106
+ vad_filter: gr.Checkbox
107
+ Enable the voice activity detection (VAD) to filter out parts of the audio
108
+ without speech. This step is using the Silero VAD model
109
+ https://github.com/snakers4/silero-vad.
110
+
111
+ threshold: gr.Slider
112
+ This parameter is related with Silero VAD. Speech threshold.
113
+ Silero VAD outputs speech probabilities for each audio chunk,
114
+ probabilities ABOVE this value are considered as SPEECH. It is better to tune this
115
+ parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
116
+
117
+ min_speech_duration_ms: gr.Number
118
+ This parameter is related with Silero VAD. Final speech chunks shorter than min_speech_duration_ms are thrown out.
119
+
120
+ max_speech_duration_s: gr.Number
121
+ This parameter is related with Silero VAD. Maximum duration of speech chunks in seconds. Chunks longer
122
+ than max_speech_duration_s will be split at the timestamp of the last silence that
123
+ lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
124
+ split aggressively just before max_speech_duration_s.
125
+
126
+ min_silence_duration_ms: gr.Number
127
+ This parameter is related with Silero VAD. At the end of each speech chunk, wait for min_silence_duration_ms
128
+ before separating it.
129
+
130
+ speech_pad_ms: gr.Number
131
+ This parameter is related with Silero VAD. Final speech chunks are padded by speech_pad_ms on each side.
132
+
133
+ chunk_length_s: gr.Number
134
+ This parameter is related with insanely-fast-whisper pipe.
135
+ Maximum length of each chunk
136
+
137
+ batch_size: gr.Number
138
+ This parameter is related with insanely-fast-whisper pipe. Batch size to pass to the pipe
139
+
140
+ is_diarize: gr.Checkbox
141
+ This parameter is related with whisperx. Boolean value that determines whether to diarize or not.
142
+
143
+ hf_token: gr.Textbox
144
+ This parameter is related with whisperx. Huggingface token is needed to download diarization models.
145
+ Read more about it: https://huggingface.co/pyannote/speaker-diarization-3.1#requirements
146
+
147
+ diarization_device: gr.Dropdown
148
+ This parameter is related with whisperx. Device to run diarization model
149
+
150
+ length_penalty:
151
+ This parameter is related to faster-whisper. Exponential length penalty constant.
152
+
153
+ repetition_penalty:
154
+ This parameter is related to faster-whisper. Penalty applied to the score of previously generated tokens
155
+ (set > 1 to penalize).
156
+
157
+ no_repeat_ngram_size:
158
+ This parameter is related to faster-whisper. Prevent repetitions of n-grams with this size (set 0 to disable).
159
+
160
+ prefix:
161
+ This parameter is related to faster-whisper. Optional text to provide as a prefix for the first window.
162
+
163
+ suppress_blank:
164
+ This parameter is related to faster-whisper. Suppress blank outputs at the beginning of the sampling.
165
+
166
+ suppress_tokens:
167
+ This parameter is related to faster-whisper. List of token IDs to suppress. -1 will suppress a default set
168
+ of symbols as defined in the model config.json file.
169
+
170
+ max_initial_timestamp:
171
+ This parameter is related to faster-whisper. The initial timestamp cannot be later than this.
172
+
173
+ word_timestamps:
174
+ This parameter is related to faster-whisper. Extract word-level timestamps using the cross-attention pattern
175
+ and dynamic time warping, and include the timestamps for each word in each segment.
176
+
177
+ prepend_punctuations:
178
+ This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
179
+ with the next word.
180
+
181
+ append_punctuations:
182
+ This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
183
+ with the previous word.
184
+
185
+ max_new_tokens:
186
+ This parameter is related to faster-whisper. Maximum number of new tokens to generate per-chunk. If not set,
187
+ the maximum will be set by the default max_length.
188
+
189
+ chunk_length:
190
+ This parameter is related to faster-whisper. The length of audio segments. If it is not None, it will overwrite the
191
+ default chunk_length of the FeatureExtractor.
192
+
193
+ hallucination_silence_threshold:
194
+ This parameter is related to faster-whisper. When word_timestamps is True, skip silent periods longer than this threshold
195
+ (in seconds) when a possible hallucination is detected.
196
+
197
+ hotwords:
198
+ This parameter is related to faster-whisper. Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.
199
+
200
+ language_detection_threshold:
201
+ This parameter is related to faster-whisper. If the maximum probability of the language tokens is higher than this value, the language is detected.
202
+
203
+ language_detection_segments:
204
+ This parameter is related to faster-whisper. Number of segments to consider for the language detection.
205
+ """
206
+
207
+ def as_list(self) -> list:
208
+ """
209
+ Converts the data class attributes into a list. Used in the Gradio UI before Gradio pre-processing.
210
+ See more about Gradio pre-processing: https://www.gradio.app/docs/components
211
+
212
+ Returns
213
+ ----------
214
+ A list of Gradio components
215
+ """
216
+ return [getattr(self, f.name) for f in fields(self)]
217
+
218
+ @staticmethod
219
+ def as_value(*args) -> 'WhisperValues':
220
+ """
221
+ To use Whisper parameters in a function after Gradio post-processing.
222
+ See more about Gradio post-processing: https://www.gradio.app/docs/components
223
+
224
+ Returns
225
+ ----------
226
+ WhisperValues
227
+ Data class that has values of parameters
228
+ """
229
+ return WhisperValues(*args)
230
+
231
+
232
+ @dataclass
233
+ class WhisperValues:
234
+ model_size: str
235
+ lang: str
236
+ is_translate: bool
237
+ beam_size: int
238
+ log_prob_threshold: float
239
+ no_speech_threshold: float
240
+ compute_type: str
241
+ best_of: int
242
+ patience: float
243
+ condition_on_previous_text: bool
244
+ prompt_reset_on_temperature: float
245
+ initial_prompt: Optional[str]
246
+ temperature: float
247
+ compression_ratio_threshold: float
248
+ vad_filter: bool
249
+ threshold: float
250
+ min_speech_duration_ms: int
251
+ max_speech_duration_s: float
252
+ min_silence_duration_ms: int
253
+ speech_pad_ms: int
254
+ chunk_length_s: int
255
+ batch_size: int
256
+ is_diarize: bool
257
+ hf_token: str
258
+ diarization_device: str
259
+ length_penalty: float
260
+ repetition_penalty: float
261
+ no_repeat_ngram_size: int
262
+ prefix: Optional[str]
263
+ suppress_blank: bool
264
+ suppress_tokens: Optional[str]
265
+ max_initial_timestamp: float
266
+ word_timestamps: bool
267
+ prepend_punctuations: Optional[str]
268
+ append_punctuations: Optional[str]
269
+ max_new_tokens: Optional[int]
270
+ chunk_length: Optional[int]
271
+ hallucination_silence_threshold: Optional[float]
272
+ hotwords: Optional[str]
273
+ language_detection_threshold: Optional[float]
274
+ language_detection_segments: int
275
+ """
276
+ A data class to use Whisper parameters.
277
+ """
outputs/outputs are saved here.txt ADDED
File without changes
outputs/translations/outputs for translation are saved here.txt ADDED
File without changes