Spaces:
Sleeping
Sleeping
Commit
·
3e210b5
1
Parent(s):
02e32e0
Add application file
Browse files- Dockerfile +3 -1
- main.py → app.py +0 -0
- models/models will be saved here.txt +0 -0
- modules/__init__.py +0 -0
- modules/__pycache__/__init__.cpython-310.pyc +0 -0
- modules/diarize/__init__.py +0 -0
- modules/diarize/__pycache__/__init__.cpython-310.pyc +0 -0
- modules/diarize/__pycache__/diarize_pipeline.cpython-310.pyc +0 -0
- modules/diarize/__pycache__/diarizer.cpython-310.pyc +0 -0
- modules/diarize/audio_loader.py +0 -179
- modules/diarize/diarize_pipeline.py +0 -94
- modules/diarize/diarizer.py +0 -132
- modules/translation/__init__.py +0 -0
- modules/translation/deepl_api.py +0 -201
- modules/translation/nllb_inference.py +0 -276
- modules/translation/translation_base.py +0 -151
- modules/utils/__init__.py +0 -0
- modules/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- modules/utils/__pycache__/files_manager.cpython-310.pyc +0 -0
- modules/utils/__pycache__/subtitle_manager.cpython-310.pyc +0 -0
- modules/utils/__pycache__/youtube_manager.cpython-310.pyc +0 -0
- modules/utils/files_manager.py +0 -39
- modules/utils/subtitle_manager.py +0 -135
- modules/utils/youtube_manager.py +0 -15
- modules/vad/__init__.py +0 -0
- modules/vad/silero_vad.py +0 -264
- modules/whisper/__init__.py +0 -0
- modules/whisper/__pycache__/__init__.cpython-310.pyc +0 -0
- modules/whisper/__pycache__/faster_whisper_inference.cpython-310.pyc +0 -0
- modules/whisper/__pycache__/whisper_base.cpython-310.pyc +0 -0
- modules/whisper/__pycache__/whisper_factory.cpython-310.pyc +0 -0
- modules/whisper/__pycache__/whisper_parameter.cpython-310.pyc +0 -0
- modules/whisper/faster_whisper_inference.py +0 -191
- modules/whisper/insanely_fast_whisper_inference.py +0 -185
- modules/whisper/whisper_Inference.py +0 -101
- modules/whisper/whisper_base.py +0 -436
- modules/whisper/whisper_factory.py +0 -81
- modules/whisper/whisper_parameter.py +0 -277
- outputs/outputs are saved here.txt +0 -0
- outputs/translations/outputs for translation are saved here.txt +0 -0
Dockerfile
CHANGED
|
@@ -25,11 +25,13 @@ WORKDIR /Whisper-WebUI
|
|
| 25 |
COPY . .
|
| 26 |
COPY --from=builder /Whisper-WebUI/venv /Whisper-WebUI/venv
|
| 27 |
|
|
|
|
|
|
|
|
|
|
| 28 |
VOLUME [ "/Whisper-WebUI/models" ]
|
| 29 |
VOLUME [ "/Whisper-WebUI/outputs" ]
|
| 30 |
|
| 31 |
ENV PATH="/Whisper-WebUI/venv/bin:$PATH"
|
| 32 |
ENV LD_LIBRARY_PATH=/Whisper-WebUI/venv/lib64/python3.11/site-packages/nvidia/cublas/lib:/Whisper-WebUI/venv/lib64/python3.11/site-packages/nvidia/cudnn/lib
|
| 33 |
|
| 34 |
-
COPY --chown=user . /app
|
| 35 |
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
|
|
|
| 25 |
COPY . .
|
| 26 |
COPY --from=builder /Whisper-WebUI/venv /Whisper-WebUI/venv
|
| 27 |
|
| 28 |
+
# Set permissions
|
| 29 |
+
RUN chown -R 1000:1000 /Whisper-WebUI/models /Whisper-WebUI/outputs
|
| 30 |
+
|
| 31 |
VOLUME [ "/Whisper-WebUI/models" ]
|
| 32 |
VOLUME [ "/Whisper-WebUI/outputs" ]
|
| 33 |
|
| 34 |
ENV PATH="/Whisper-WebUI/venv/bin:$PATH"
|
| 35 |
ENV LD_LIBRARY_PATH=/Whisper-WebUI/venv/lib64/python3.11/site-packages/nvidia/cublas/lib:/Whisper-WebUI/venv/lib64/python3.11/site-packages/nvidia/cudnn/lib
|
| 36 |
|
|
|
|
| 37 |
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
main.py → app.py
RENAMED
|
File without changes
|
models/models will be saved here.txt
DELETED
|
File without changes
|
modules/__init__.py
DELETED
|
File without changes
|
modules/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (167 Bytes)
|
|
|
modules/diarize/__init__.py
DELETED
|
File without changes
|
modules/diarize/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (175 Bytes)
|
|
|
modules/diarize/__pycache__/diarize_pipeline.cpython-310.pyc
DELETED
|
Binary file (3.06 kB)
|
|
|
modules/diarize/__pycache__/diarizer.cpython-310.pyc
DELETED
|
Binary file (4.14 kB)
|
|
|
modules/diarize/audio_loader.py
DELETED
|
@@ -1,179 +0,0 @@
|
|
| 1 |
-
# Adapted from https://github.com/m-bain/whisperX/blob/main/whisperx/audio.py
|
| 2 |
-
|
| 3 |
-
import os
|
| 4 |
-
import subprocess
|
| 5 |
-
from functools import lru_cache
|
| 6 |
-
from typing import Optional, Union
|
| 7 |
-
from scipy.io.wavfile import write
|
| 8 |
-
import tempfile
|
| 9 |
-
|
| 10 |
-
import numpy as np
|
| 11 |
-
import torch
|
| 12 |
-
import torch.nn.functional as F
|
| 13 |
-
|
| 14 |
-
def exact_div(x, y):
|
| 15 |
-
assert x % y == 0
|
| 16 |
-
return x // y
|
| 17 |
-
|
| 18 |
-
# hard-coded audio hyperparameters
|
| 19 |
-
SAMPLE_RATE = 16000
|
| 20 |
-
N_FFT = 400
|
| 21 |
-
HOP_LENGTH = 160
|
| 22 |
-
CHUNK_LENGTH = 30
|
| 23 |
-
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
|
| 24 |
-
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
|
| 25 |
-
|
| 26 |
-
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
|
| 27 |
-
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
|
| 28 |
-
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
def load_audio(file: Union[str, np.ndarray], sr: int = SAMPLE_RATE) -> np.ndarray:
|
| 32 |
-
"""
|
| 33 |
-
Open an audio file or process a numpy array containing audio data as mono waveform, resampling as necessary.
|
| 34 |
-
|
| 35 |
-
Parameters
|
| 36 |
-
----------
|
| 37 |
-
file: Union[str, np.ndarray]
|
| 38 |
-
The audio file to open or a numpy array containing the audio data.
|
| 39 |
-
|
| 40 |
-
sr: int
|
| 41 |
-
The sample rate to resample the audio if necessary.
|
| 42 |
-
|
| 43 |
-
Returns
|
| 44 |
-
-------
|
| 45 |
-
A NumPy array containing the audio waveform, in float32 dtype.
|
| 46 |
-
"""
|
| 47 |
-
if isinstance(file, np.ndarray):
|
| 48 |
-
if file.dtype != np.float32:
|
| 49 |
-
file = file.astype(np.float32)
|
| 50 |
-
if file.ndim > 1:
|
| 51 |
-
file = np.mean(file, axis=1)
|
| 52 |
-
|
| 53 |
-
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
|
| 54 |
-
write(temp_file.name, SAMPLE_RATE, (file * 32768).astype(np.int16))
|
| 55 |
-
temp_file_path = temp_file.name
|
| 56 |
-
temp_file.close()
|
| 57 |
-
else:
|
| 58 |
-
temp_file_path = file
|
| 59 |
-
|
| 60 |
-
try:
|
| 61 |
-
cmd = [
|
| 62 |
-
"ffmpeg",
|
| 63 |
-
"-nostdin",
|
| 64 |
-
"-threads",
|
| 65 |
-
"0",
|
| 66 |
-
"-i",
|
| 67 |
-
temp_file_path,
|
| 68 |
-
"-f",
|
| 69 |
-
"s16le",
|
| 70 |
-
"-ac",
|
| 71 |
-
"1",
|
| 72 |
-
"-acodec",
|
| 73 |
-
"pcm_s16le",
|
| 74 |
-
"-ar",
|
| 75 |
-
str(sr),
|
| 76 |
-
"-",
|
| 77 |
-
]
|
| 78 |
-
out = subprocess.run(cmd, capture_output=True, check=True).stdout
|
| 79 |
-
except subprocess.CalledProcessError as e:
|
| 80 |
-
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
|
| 81 |
-
finally:
|
| 82 |
-
if isinstance(file, np.ndarray):
|
| 83 |
-
os.remove(temp_file_path)
|
| 84 |
-
|
| 85 |
-
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
|
| 89 |
-
"""
|
| 90 |
-
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
|
| 91 |
-
"""
|
| 92 |
-
if torch.is_tensor(array):
|
| 93 |
-
if array.shape[axis] > length:
|
| 94 |
-
array = array.index_select(
|
| 95 |
-
dim=axis, index=torch.arange(length, device=array.device)
|
| 96 |
-
)
|
| 97 |
-
|
| 98 |
-
if array.shape[axis] < length:
|
| 99 |
-
pad_widths = [(0, 0)] * array.ndim
|
| 100 |
-
pad_widths[axis] = (0, length - array.shape[axis])
|
| 101 |
-
array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
|
| 102 |
-
else:
|
| 103 |
-
if array.shape[axis] > length:
|
| 104 |
-
array = array.take(indices=range(length), axis=axis)
|
| 105 |
-
|
| 106 |
-
if array.shape[axis] < length:
|
| 107 |
-
pad_widths = [(0, 0)] * array.ndim
|
| 108 |
-
pad_widths[axis] = (0, length - array.shape[axis])
|
| 109 |
-
array = np.pad(array, pad_widths)
|
| 110 |
-
|
| 111 |
-
return array
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
@lru_cache(maxsize=None)
|
| 115 |
-
def mel_filters(device, n_mels: int) -> torch.Tensor:
|
| 116 |
-
"""
|
| 117 |
-
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
|
| 118 |
-
Allows decoupling librosa dependency; saved using:
|
| 119 |
-
|
| 120 |
-
np.savez_compressed(
|
| 121 |
-
"mel_filters.npz",
|
| 122 |
-
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
|
| 123 |
-
)
|
| 124 |
-
"""
|
| 125 |
-
assert n_mels in [80, 128], f"Unsupported n_mels: {n_mels}"
|
| 126 |
-
with np.load(
|
| 127 |
-
os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
|
| 128 |
-
) as f:
|
| 129 |
-
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
def log_mel_spectrogram(
|
| 133 |
-
audio: Union[str, np.ndarray, torch.Tensor],
|
| 134 |
-
n_mels: int,
|
| 135 |
-
padding: int = 0,
|
| 136 |
-
device: Optional[Union[str, torch.device]] = None,
|
| 137 |
-
):
|
| 138 |
-
"""
|
| 139 |
-
Compute the log-Mel spectrogram of
|
| 140 |
-
|
| 141 |
-
Parameters
|
| 142 |
-
----------
|
| 143 |
-
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
|
| 144 |
-
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
|
| 145 |
-
|
| 146 |
-
n_mels: int
|
| 147 |
-
The number of Mel-frequency filters, only 80 is supported
|
| 148 |
-
|
| 149 |
-
padding: int
|
| 150 |
-
Number of zero samples to pad to the right
|
| 151 |
-
|
| 152 |
-
device: Optional[Union[str, torch.device]]
|
| 153 |
-
If given, the audio tensor is moved to this device before STFT
|
| 154 |
-
|
| 155 |
-
Returns
|
| 156 |
-
-------
|
| 157 |
-
torch.Tensor, shape = (80, n_frames)
|
| 158 |
-
A Tensor that contains the Mel spectrogram
|
| 159 |
-
"""
|
| 160 |
-
if not torch.is_tensor(audio):
|
| 161 |
-
if isinstance(audio, str):
|
| 162 |
-
audio = load_audio(audio)
|
| 163 |
-
audio = torch.from_numpy(audio)
|
| 164 |
-
|
| 165 |
-
if device is not None:
|
| 166 |
-
audio = audio.to(device)
|
| 167 |
-
if padding > 0:
|
| 168 |
-
audio = F.pad(audio, (0, padding))
|
| 169 |
-
window = torch.hann_window(N_FFT).to(audio.device)
|
| 170 |
-
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
|
| 171 |
-
magnitudes = stft[..., :-1].abs() ** 2
|
| 172 |
-
|
| 173 |
-
filters = mel_filters(audio.device, n_mels)
|
| 174 |
-
mel_spec = filters @ magnitudes
|
| 175 |
-
|
| 176 |
-
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
| 177 |
-
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
| 178 |
-
log_spec = (log_spec + 4.0) / 4.0
|
| 179 |
-
return log_spec
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/diarize/diarize_pipeline.py
DELETED
|
@@ -1,94 +0,0 @@
|
|
| 1 |
-
# Adapted from https://github.com/m-bain/whisperX/blob/main/whisperx/diarize.py
|
| 2 |
-
|
| 3 |
-
import numpy as np
|
| 4 |
-
import pandas as pd
|
| 5 |
-
import os
|
| 6 |
-
from pyannote.audio import Pipeline
|
| 7 |
-
from typing import Optional, Union
|
| 8 |
-
import torch
|
| 9 |
-
|
| 10 |
-
from modules.diarize.audio_loader import load_audio, SAMPLE_RATE
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class DiarizationPipeline:
|
| 14 |
-
def __init__(
|
| 15 |
-
self,
|
| 16 |
-
model_name="pyannote/speaker-diarization-3.1",
|
| 17 |
-
cache_dir: str = os.path.join("models", "Diarization"),
|
| 18 |
-
use_auth_token=None,
|
| 19 |
-
device: Optional[Union[str, torch.device]] = "cpu",
|
| 20 |
-
):
|
| 21 |
-
if isinstance(device, str):
|
| 22 |
-
device = torch.device(device)
|
| 23 |
-
self.model = Pipeline.from_pretrained(
|
| 24 |
-
model_name,
|
| 25 |
-
use_auth_token=use_auth_token,
|
| 26 |
-
cache_dir=cache_dir
|
| 27 |
-
).to(device)
|
| 28 |
-
|
| 29 |
-
def __call__(self, audio: Union[str, np.ndarray], min_speakers=None, max_speakers=None):
|
| 30 |
-
if isinstance(audio, str):
|
| 31 |
-
audio = load_audio(audio)
|
| 32 |
-
audio_data = {
|
| 33 |
-
'waveform': torch.from_numpy(audio[None, :]),
|
| 34 |
-
'sample_rate': SAMPLE_RATE
|
| 35 |
-
}
|
| 36 |
-
segments = self.model(audio_data, min_speakers=min_speakers, max_speakers=max_speakers)
|
| 37 |
-
diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
|
| 38 |
-
diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start)
|
| 39 |
-
diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end)
|
| 40 |
-
return diarize_df
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
|
| 44 |
-
transcript_segments = transcript_result["segments"]
|
| 45 |
-
for seg in transcript_segments:
|
| 46 |
-
# assign speaker to segment (if any)
|
| 47 |
-
diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'],
|
| 48 |
-
seg['start'])
|
| 49 |
-
diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start'])
|
| 50 |
-
|
| 51 |
-
intersected = diarize_df[diarize_df["intersection"] > 0]
|
| 52 |
-
|
| 53 |
-
speaker = None
|
| 54 |
-
if len(intersected) > 0:
|
| 55 |
-
# Choosing most strong intersection
|
| 56 |
-
speaker = intersected.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
|
| 57 |
-
elif fill_nearest:
|
| 58 |
-
# Otherwise choosing closest
|
| 59 |
-
speaker = diarize_df.sort_values(by=["intersection"], ascending=False)["speaker"].values[0]
|
| 60 |
-
|
| 61 |
-
if speaker is not None:
|
| 62 |
-
seg["speaker"] = speaker
|
| 63 |
-
|
| 64 |
-
# assign speaker to words
|
| 65 |
-
if 'words' in seg:
|
| 66 |
-
for word in seg['words']:
|
| 67 |
-
if 'start' in word:
|
| 68 |
-
diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(
|
| 69 |
-
diarize_df['start'], word['start'])
|
| 70 |
-
diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'],
|
| 71 |
-
word['start'])
|
| 72 |
-
|
| 73 |
-
intersected = diarize_df[diarize_df["intersection"] > 0]
|
| 74 |
-
|
| 75 |
-
word_speaker = None
|
| 76 |
-
if len(intersected) > 0:
|
| 77 |
-
# Choosing most strong intersection
|
| 78 |
-
word_speaker = \
|
| 79 |
-
intersected.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
|
| 80 |
-
elif fill_nearest:
|
| 81 |
-
# Otherwise choosing closest
|
| 82 |
-
word_speaker = diarize_df.sort_values(by=["intersection"], ascending=False)["speaker"].values[0]
|
| 83 |
-
|
| 84 |
-
if word_speaker is not None:
|
| 85 |
-
word["speaker"] = word_speaker
|
| 86 |
-
|
| 87 |
-
return transcript_result
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
class Segment:
|
| 91 |
-
def __init__(self, start, end, speaker=None):
|
| 92 |
-
self.start = start
|
| 93 |
-
self.end = end
|
| 94 |
-
self.speaker = speaker
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/diarize/diarizer.py
DELETED
|
@@ -1,132 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import torch
|
| 3 |
-
from typing import List, Union, BinaryIO, Optional
|
| 4 |
-
import numpy as np
|
| 5 |
-
import time
|
| 6 |
-
import logging
|
| 7 |
-
|
| 8 |
-
from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
|
| 9 |
-
from modules.diarize.audio_loader import load_audio
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class Diarizer:
|
| 13 |
-
def __init__(self,
|
| 14 |
-
model_dir: str = os.path.join("models", "Diarization")
|
| 15 |
-
):
|
| 16 |
-
self.device = self.get_device()
|
| 17 |
-
self.available_device = self.get_available_device()
|
| 18 |
-
self.compute_type = "float16"
|
| 19 |
-
self.model_dir = model_dir
|
| 20 |
-
os.makedirs(self.model_dir, exist_ok=True)
|
| 21 |
-
self.pipe = None
|
| 22 |
-
|
| 23 |
-
def run(self,
|
| 24 |
-
audio: Union[str, BinaryIO, np.ndarray],
|
| 25 |
-
transcribed_result: List[dict],
|
| 26 |
-
use_auth_token: str,
|
| 27 |
-
device: Optional[str] = None
|
| 28 |
-
):
|
| 29 |
-
"""
|
| 30 |
-
Diarize transcribed result as a post-processing
|
| 31 |
-
|
| 32 |
-
Parameters
|
| 33 |
-
----------
|
| 34 |
-
audio: Union[str, BinaryIO, np.ndarray]
|
| 35 |
-
Audio input. This can be file path or binary type.
|
| 36 |
-
transcribed_result: List[dict]
|
| 37 |
-
transcribed result through whisper.
|
| 38 |
-
use_auth_token: str
|
| 39 |
-
Huggingface token with READ permission. This is only needed the first time you download the model.
|
| 40 |
-
You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model.
|
| 41 |
-
device: Optional[str]
|
| 42 |
-
Device for diarization.
|
| 43 |
-
|
| 44 |
-
Returns
|
| 45 |
-
----------
|
| 46 |
-
segments_result: List[dict]
|
| 47 |
-
list of dicts that includes start, end timestamps and transcribed text
|
| 48 |
-
elapsed_time: float
|
| 49 |
-
elapsed time for running
|
| 50 |
-
"""
|
| 51 |
-
start_time = time.time()
|
| 52 |
-
|
| 53 |
-
if device is None:
|
| 54 |
-
device = self.device
|
| 55 |
-
|
| 56 |
-
if device != self.device or self.pipe is None:
|
| 57 |
-
self.update_pipe(
|
| 58 |
-
device=device,
|
| 59 |
-
use_auth_token=use_auth_token
|
| 60 |
-
)
|
| 61 |
-
|
| 62 |
-
audio = load_audio(audio)
|
| 63 |
-
|
| 64 |
-
diarization_segments = self.pipe(audio)
|
| 65 |
-
diarized_result = assign_word_speakers(
|
| 66 |
-
diarization_segments,
|
| 67 |
-
{"segments": transcribed_result}
|
| 68 |
-
)
|
| 69 |
-
|
| 70 |
-
for segment in diarized_result["segments"]:
|
| 71 |
-
speaker = "None"
|
| 72 |
-
if "speaker" in segment:
|
| 73 |
-
speaker = segment["speaker"]
|
| 74 |
-
segment["text"] = speaker + "|" + segment["text"].strip()
|
| 75 |
-
|
| 76 |
-
elapsed_time = time.time() - start_time
|
| 77 |
-
return diarized_result["segments"], elapsed_time
|
| 78 |
-
|
| 79 |
-
def update_pipe(self,
|
| 80 |
-
use_auth_token: str,
|
| 81 |
-
device: str
|
| 82 |
-
):
|
| 83 |
-
"""
|
| 84 |
-
Set pipeline for diarization
|
| 85 |
-
|
| 86 |
-
Parameters
|
| 87 |
-
----------
|
| 88 |
-
use_auth_token: str
|
| 89 |
-
Huggingface token with READ permission. This is only needed the first time you download the model.
|
| 90 |
-
You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model.
|
| 91 |
-
device: str
|
| 92 |
-
Device for diarization.
|
| 93 |
-
"""
|
| 94 |
-
self.device = device
|
| 95 |
-
|
| 96 |
-
os.makedirs(self.model_dir, exist_ok=True)
|
| 97 |
-
|
| 98 |
-
if (not os.listdir(self.model_dir) and
|
| 99 |
-
not use_auth_token):
|
| 100 |
-
print(
|
| 101 |
-
"\nFailed to diarize. You need huggingface token and agree to their requirements to download the diarization model.\n"
|
| 102 |
-
"Go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and follow their instructions to download the model.\n"
|
| 103 |
-
)
|
| 104 |
-
return
|
| 105 |
-
|
| 106 |
-
logger = logging.getLogger("speechbrain.utils.train_logger")
|
| 107 |
-
# Disable redundant torchvision warning message
|
| 108 |
-
logger.disabled = True
|
| 109 |
-
self.pipe = DiarizationPipeline(
|
| 110 |
-
use_auth_token=use_auth_token,
|
| 111 |
-
device=device,
|
| 112 |
-
cache_dir=self.model_dir
|
| 113 |
-
)
|
| 114 |
-
logger.disabled = False
|
| 115 |
-
|
| 116 |
-
@staticmethod
|
| 117 |
-
def get_device():
|
| 118 |
-
if torch.cuda.is_available():
|
| 119 |
-
return "cuda"
|
| 120 |
-
elif torch.backends.mps.is_available():
|
| 121 |
-
return "mps"
|
| 122 |
-
else:
|
| 123 |
-
return "cpu"
|
| 124 |
-
|
| 125 |
-
@staticmethod
|
| 126 |
-
def get_available_device():
|
| 127 |
-
devices = ["cpu"]
|
| 128 |
-
if torch.cuda.is_available():
|
| 129 |
-
devices.append("cuda")
|
| 130 |
-
elif torch.backends.mps.is_available():
|
| 131 |
-
devices.append("mps")
|
| 132 |
-
return devices
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/translation/__init__.py
DELETED
|
File without changes
|
modules/translation/deepl_api.py
DELETED
|
@@ -1,201 +0,0 @@
|
|
| 1 |
-
import requests
|
| 2 |
-
import time
|
| 3 |
-
import os
|
| 4 |
-
from datetime import datetime
|
| 5 |
-
import gradio as gr
|
| 6 |
-
|
| 7 |
-
from modules.utils.subtitle_manager import *
|
| 8 |
-
|
| 9 |
-
"""
|
| 10 |
-
This is written with reference to the DeepL API documentation.
|
| 11 |
-
If you want to know the information of the DeepL API, see here: https://www.deepl.com/docs-api/documents
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
DEEPL_AVAILABLE_TARGET_LANGS = {
|
| 15 |
-
'Bulgarian': 'BG',
|
| 16 |
-
'Czech': 'CS',
|
| 17 |
-
'Danish': 'DA',
|
| 18 |
-
'German': 'DE',
|
| 19 |
-
'Greek': 'EL',
|
| 20 |
-
'English': 'EN',
|
| 21 |
-
'English (British)': 'EN-GB',
|
| 22 |
-
'English (American)': 'EN-US',
|
| 23 |
-
'Spanish': 'ES',
|
| 24 |
-
'Estonian': 'ET',
|
| 25 |
-
'Finnish': 'FI',
|
| 26 |
-
'French': 'FR',
|
| 27 |
-
'Hungarian': 'HU',
|
| 28 |
-
'Indonesian': 'ID',
|
| 29 |
-
'Italian': 'IT',
|
| 30 |
-
'Japanese': 'JA',
|
| 31 |
-
'Korean': 'KO',
|
| 32 |
-
'Lithuanian': 'LT',
|
| 33 |
-
'Latvian': 'LV',
|
| 34 |
-
'Norwegian (Bokmål)': 'NB',
|
| 35 |
-
'Dutch': 'NL',
|
| 36 |
-
'Polish': 'PL',
|
| 37 |
-
'Portuguese': 'PT',
|
| 38 |
-
'Portuguese (Brazilian)': 'PT-BR',
|
| 39 |
-
'Portuguese (all Portuguese varieties excluding Brazilian Portuguese)': 'PT-PT',
|
| 40 |
-
'Romanian': 'RO',
|
| 41 |
-
'Russian': 'RU',
|
| 42 |
-
'Slovak': 'SK',
|
| 43 |
-
'Slovenian': 'SL',
|
| 44 |
-
'Swedish': 'SV',
|
| 45 |
-
'Turkish': 'TR',
|
| 46 |
-
'Ukrainian': 'UK',
|
| 47 |
-
'Chinese (simplified)': 'ZH'
|
| 48 |
-
}
|
| 49 |
-
|
| 50 |
-
DEEPL_AVAILABLE_SOURCE_LANGS = {
|
| 51 |
-
'Automatic Detection': None,
|
| 52 |
-
'Bulgarian': 'BG',
|
| 53 |
-
'Czech': 'CS',
|
| 54 |
-
'Danish': 'DA',
|
| 55 |
-
'German': 'DE',
|
| 56 |
-
'Greek': 'EL',
|
| 57 |
-
'English': 'EN',
|
| 58 |
-
'Spanish': 'ES',
|
| 59 |
-
'Estonian': 'ET',
|
| 60 |
-
'Finnish': 'FI',
|
| 61 |
-
'French': 'FR',
|
| 62 |
-
'Hungarian': 'HU',
|
| 63 |
-
'Indonesian': 'ID',
|
| 64 |
-
'Italian': 'IT',
|
| 65 |
-
'Japanese': 'JA',
|
| 66 |
-
'Korean': 'KO',
|
| 67 |
-
'Lithuanian': 'LT',
|
| 68 |
-
'Latvian': 'LV',
|
| 69 |
-
'Norwegian (Bokmål)': 'NB',
|
| 70 |
-
'Dutch': 'NL',
|
| 71 |
-
'Polish': 'PL',
|
| 72 |
-
'Portuguese (all Portuguese varieties mixed)': 'PT',
|
| 73 |
-
'Romanian': 'RO',
|
| 74 |
-
'Russian': 'RU',
|
| 75 |
-
'Slovak': 'SK',
|
| 76 |
-
'Slovenian': 'SL',
|
| 77 |
-
'Swedish': 'SV',
|
| 78 |
-
'Turkish': 'TR',
|
| 79 |
-
'Ukrainian': 'UK',
|
| 80 |
-
'Chinese': 'ZH'
|
| 81 |
-
}
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
class DeepLAPI:
|
| 85 |
-
def __init__(self,
|
| 86 |
-
output_dir: str = os.path.join("outputs", "translations")
|
| 87 |
-
):
|
| 88 |
-
self.api_interval = 1
|
| 89 |
-
self.max_text_batch_size = 50
|
| 90 |
-
self.available_target_langs = DEEPL_AVAILABLE_TARGET_LANGS
|
| 91 |
-
self.available_source_langs = DEEPL_AVAILABLE_SOURCE_LANGS
|
| 92 |
-
self.output_dir = output_dir
|
| 93 |
-
|
| 94 |
-
def translate_deepl(self,
|
| 95 |
-
auth_key: str,
|
| 96 |
-
fileobjs: list,
|
| 97 |
-
source_lang: str,
|
| 98 |
-
target_lang: str,
|
| 99 |
-
is_pro: bool,
|
| 100 |
-
add_timestamp: bool,
|
| 101 |
-
progress=gr.Progress()) -> list:
|
| 102 |
-
"""
|
| 103 |
-
Translate subtitle files using DeepL API
|
| 104 |
-
Parameters
|
| 105 |
-
----------
|
| 106 |
-
auth_key: str
|
| 107 |
-
API Key for DeepL from gr.Textbox()
|
| 108 |
-
fileobjs: list
|
| 109 |
-
List of files to transcribe from gr.Files()
|
| 110 |
-
source_lang: str
|
| 111 |
-
Source language of the file to transcribe from gr.Dropdown()
|
| 112 |
-
target_lang: str
|
| 113 |
-
Target language of the file to transcribe from gr.Dropdown()
|
| 114 |
-
is_pro: str
|
| 115 |
-
Boolean value that is about pro user or not from gr.Checkbox().
|
| 116 |
-
add_timestamp: bool
|
| 117 |
-
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
|
| 118 |
-
progress: gr.Progress
|
| 119 |
-
Indicator to show progress directly in gradio.
|
| 120 |
-
|
| 121 |
-
Returns
|
| 122 |
-
----------
|
| 123 |
-
A List of
|
| 124 |
-
String to return to gr.Textbox()
|
| 125 |
-
Files to return to gr.Files()
|
| 126 |
-
"""
|
| 127 |
-
|
| 128 |
-
files_info = {}
|
| 129 |
-
for fileobj in fileobjs:
|
| 130 |
-
file_path = fileobj.name
|
| 131 |
-
file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
|
| 132 |
-
|
| 133 |
-
if file_ext == ".srt":
|
| 134 |
-
parsed_dicts = parse_srt(file_path=file_path)
|
| 135 |
-
|
| 136 |
-
batch_size = self.max_text_batch_size
|
| 137 |
-
for batch_start in range(0, len(parsed_dicts), batch_size):
|
| 138 |
-
batch_end = min(batch_start + batch_size, len(parsed_dicts))
|
| 139 |
-
sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
|
| 140 |
-
translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
|
| 141 |
-
target_lang, is_pro)
|
| 142 |
-
for i, translated_text in enumerate(translated_texts):
|
| 143 |
-
parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
|
| 144 |
-
progress(batch_end / len(parsed_dicts), desc="Translating..")
|
| 145 |
-
|
| 146 |
-
subtitle = get_serialized_srt(parsed_dicts)
|
| 147 |
-
|
| 148 |
-
elif file_ext == ".vtt":
|
| 149 |
-
parsed_dicts = parse_vtt(file_path=file_path)
|
| 150 |
-
|
| 151 |
-
batch_size = self.max_text_batch_size
|
| 152 |
-
for batch_start in range(0, len(parsed_dicts), batch_size):
|
| 153 |
-
batch_end = min(batch_start + batch_size, len(parsed_dicts))
|
| 154 |
-
sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
|
| 155 |
-
translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
|
| 156 |
-
target_lang, is_pro)
|
| 157 |
-
for i, translated_text in enumerate(translated_texts):
|
| 158 |
-
parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
|
| 159 |
-
progress(batch_end / len(parsed_dicts), desc="Translating..")
|
| 160 |
-
|
| 161 |
-
subtitle = get_serialized_vtt(parsed_dicts)
|
| 162 |
-
|
| 163 |
-
if add_timestamp:
|
| 164 |
-
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
| 165 |
-
file_name += f"-{timestamp}"
|
| 166 |
-
|
| 167 |
-
output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
|
| 168 |
-
write_file(subtitle, output_path)
|
| 169 |
-
|
| 170 |
-
files_info[file_name] = {"subtitle": subtitle, "path": output_path}
|
| 171 |
-
|
| 172 |
-
total_result = ''
|
| 173 |
-
for file_name, info in files_info.items():
|
| 174 |
-
total_result += '------------------------------------\n'
|
| 175 |
-
total_result += f'{file_name}\n\n'
|
| 176 |
-
total_result += f'{info["subtitle"]}'
|
| 177 |
-
gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
|
| 178 |
-
|
| 179 |
-
output_file_paths = [item["path"] for key, item in files_info.items()]
|
| 180 |
-
return [gr_str, output_file_paths]
|
| 181 |
-
|
| 182 |
-
def request_deepl_translate(self,
|
| 183 |
-
auth_key: str,
|
| 184 |
-
text: list,
|
| 185 |
-
source_lang: str,
|
| 186 |
-
target_lang: str,
|
| 187 |
-
is_pro: bool):
|
| 188 |
-
"""Request API response to DeepL server"""
|
| 189 |
-
|
| 190 |
-
url = 'https://api.deepl.com/v2/translate' if is_pro else 'https://api-free.deepl.com/v2/translate'
|
| 191 |
-
headers = {
|
| 192 |
-
'Authorization': f'DeepL-Auth-Key {auth_key}'
|
| 193 |
-
}
|
| 194 |
-
data = {
|
| 195 |
-
'text': text,
|
| 196 |
-
'source_lang': DEEPL_AVAILABLE_SOURCE_LANGS[source_lang],
|
| 197 |
-
'target_lang': DEEPL_AVAILABLE_TARGET_LANGS[target_lang]
|
| 198 |
-
}
|
| 199 |
-
response = requests.post(url, headers=headers, data=data).json()
|
| 200 |
-
time.sleep(self.api_interval)
|
| 201 |
-
return response["translations"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/translation/nllb_inference.py
DELETED
|
@@ -1,276 +0,0 @@
|
|
| 1 |
-
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
| 2 |
-
import gradio as gr
|
| 3 |
-
import os
|
| 4 |
-
|
| 5 |
-
from modules.translation.translation_base import TranslationBase
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
class NLLBInference(TranslationBase):
|
| 9 |
-
def __init__(self,
|
| 10 |
-
model_dir: str = os.path.join("models", "NLLB"),
|
| 11 |
-
output_dir: str = os.path.join("outputs", "translations")
|
| 12 |
-
):
|
| 13 |
-
super().__init__(
|
| 14 |
-
model_dir=model_dir,
|
| 15 |
-
output_dir=output_dir
|
| 16 |
-
)
|
| 17 |
-
self.tokenizer = None
|
| 18 |
-
self.available_models = ["facebook/nllb-200-3.3B", "facebook/nllb-200-1.3B", "facebook/nllb-200-distilled-600M"]
|
| 19 |
-
self.available_source_langs = list(NLLB_AVAILABLE_LANGS.keys())
|
| 20 |
-
self.available_target_langs = list(NLLB_AVAILABLE_LANGS.keys())
|
| 21 |
-
self.pipeline = None
|
| 22 |
-
|
| 23 |
-
def translate(self,
|
| 24 |
-
text: str,
|
| 25 |
-
max_length: int
|
| 26 |
-
):
|
| 27 |
-
result = self.pipeline(
|
| 28 |
-
text,
|
| 29 |
-
max_length=max_length
|
| 30 |
-
)
|
| 31 |
-
return result[0]['translation_text']
|
| 32 |
-
|
| 33 |
-
def update_model(self,
|
| 34 |
-
model_size: str,
|
| 35 |
-
src_lang: str,
|
| 36 |
-
tgt_lang: str,
|
| 37 |
-
progress: gr.Progress
|
| 38 |
-
):
|
| 39 |
-
if model_size != self.current_model_size or self.model is None:
|
| 40 |
-
print("\nInitializing NLLB Model..\n")
|
| 41 |
-
progress(0, desc="Initializing NLLB Model..")
|
| 42 |
-
self.current_model_size = model_size
|
| 43 |
-
local_files_only = self.is_model_exists(self.current_model_size)
|
| 44 |
-
self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
|
| 45 |
-
cache_dir=self.model_dir,
|
| 46 |
-
local_files_only=local_files_only)
|
| 47 |
-
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
|
| 48 |
-
cache_dir=os.path.join(self.model_dir, "tokenizers"),
|
| 49 |
-
local_files_only=local_files_only)
|
| 50 |
-
src_lang = NLLB_AVAILABLE_LANGS[src_lang]
|
| 51 |
-
tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
|
| 52 |
-
self.pipeline = pipeline("translation",
|
| 53 |
-
model=self.model,
|
| 54 |
-
tokenizer=self.tokenizer,
|
| 55 |
-
src_lang=src_lang,
|
| 56 |
-
tgt_lang=tgt_lang,
|
| 57 |
-
device=self.device)
|
| 58 |
-
|
| 59 |
-
def is_model_exists(self,
|
| 60 |
-
model_size: str):
|
| 61 |
-
"""Check if model exists or not (Only facebook model)"""
|
| 62 |
-
prefix = "models--facebook--"
|
| 63 |
-
_id, model_size_name = model_size.split("/")
|
| 64 |
-
model_dir_name = prefix + model_size_name
|
| 65 |
-
model_dir_path = os.path.join(self.model_dir, model_dir_name)
|
| 66 |
-
if os.path.exists(model_dir_path) and os.listdir(model_dir_path):
|
| 67 |
-
return True
|
| 68 |
-
return False
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
NLLB_AVAILABLE_LANGS = {
|
| 72 |
-
"Acehnese (Arabic script)": "ace_Arab",
|
| 73 |
-
"Acehnese (Latin script)": "ace_Latn",
|
| 74 |
-
"Mesopotamian Arabic": "acm_Arab",
|
| 75 |
-
"Ta’izzi-Adeni Arabic": "acq_Arab",
|
| 76 |
-
"Tunisian Arabic": "aeb_Arab",
|
| 77 |
-
"Afrikaans": "afr_Latn",
|
| 78 |
-
"South Levantine Arabic": "ajp_Arab",
|
| 79 |
-
"Akan": "aka_Latn",
|
| 80 |
-
"Amharic": "amh_Ethi",
|
| 81 |
-
"North Levantine Arabic": "apc_Arab",
|
| 82 |
-
"Modern Standard Arabic": "arb_Arab",
|
| 83 |
-
"Modern Standard Arabic (Romanized)": "arb_Latn",
|
| 84 |
-
"Najdi Arabic": "ars_Arab",
|
| 85 |
-
"Moroccan Arabic": "ary_Arab",
|
| 86 |
-
"Egyptian Arabic": "arz_Arab",
|
| 87 |
-
"Assamese": "asm_Beng",
|
| 88 |
-
"Asturian": "ast_Latn",
|
| 89 |
-
"Awadhi": "awa_Deva",
|
| 90 |
-
"Central Aymara": "ayr_Latn",
|
| 91 |
-
"South Azerbaijani": "azb_Arab",
|
| 92 |
-
"North Azerbaijani": "azj_Latn",
|
| 93 |
-
"Bashkir": "bak_Cyrl",
|
| 94 |
-
"Bambara": "bam_Latn",
|
| 95 |
-
"Balinese": "ban_Latn",
|
| 96 |
-
"Belarusian": "bel_Cyrl",
|
| 97 |
-
"Bemba": "bem_Latn",
|
| 98 |
-
"Bengali": "ben_Beng",
|
| 99 |
-
"Bhojpuri": "bho_Deva",
|
| 100 |
-
"Banjar (Arabic script)": "bjn_Arab",
|
| 101 |
-
"Banjar (Latin script)": "bjn_Latn",
|
| 102 |
-
"Standard Tibetan": "bod_Tibt",
|
| 103 |
-
"Bosnian": "bos_Latn",
|
| 104 |
-
"Buginese": "bug_Latn",
|
| 105 |
-
"Bulgarian": "bul_Cyrl",
|
| 106 |
-
"Catalan": "cat_Latn",
|
| 107 |
-
"Cebuano": "ceb_Latn",
|
| 108 |
-
"Czech": "ces_Latn",
|
| 109 |
-
"Chokwe": "cjk_Latn",
|
| 110 |
-
"Central Kurdish": "ckb_Arab",
|
| 111 |
-
"Crimean Tatar": "crh_Latn",
|
| 112 |
-
"Welsh": "cym_Latn",
|
| 113 |
-
"Danish": "dan_Latn",
|
| 114 |
-
"German": "deu_Latn",
|
| 115 |
-
"Southwestern Dinka": "dik_Latn",
|
| 116 |
-
"Dyula": "dyu_Latn",
|
| 117 |
-
"Dzongkha": "dzo_Tibt",
|
| 118 |
-
"Greek": "ell_Grek",
|
| 119 |
-
"English": "eng_Latn",
|
| 120 |
-
"Esperanto": "epo_Latn",
|
| 121 |
-
"Estonian": "est_Latn",
|
| 122 |
-
"Basque": "eus_Latn",
|
| 123 |
-
"Ewe": "ewe_Latn",
|
| 124 |
-
"Faroese": "fao_Latn",
|
| 125 |
-
"Fijian": "fij_Latn",
|
| 126 |
-
"Finnish": "fin_Latn",
|
| 127 |
-
"Fon": "fon_Latn",
|
| 128 |
-
"French": "fra_Latn",
|
| 129 |
-
"Friulian": "fur_Latn",
|
| 130 |
-
"Nigerian Fulfulde": "fuv_Latn",
|
| 131 |
-
"Scottish Gaelic": "gla_Latn",
|
| 132 |
-
"Irish": "gle_Latn",
|
| 133 |
-
"Galician": "glg_Latn",
|
| 134 |
-
"Guarani": "grn_Latn",
|
| 135 |
-
"Gujarati": "guj_Gujr",
|
| 136 |
-
"Haitian Creole": "hat_Latn",
|
| 137 |
-
"Hausa": "hau_Latn",
|
| 138 |
-
"Hebrew": "heb_Hebr",
|
| 139 |
-
"Hindi": "hin_Deva",
|
| 140 |
-
"Chhattisgarhi": "hne_Deva",
|
| 141 |
-
"Croatian": "hrv_Latn",
|
| 142 |
-
"Hungarian": "hun_Latn",
|
| 143 |
-
"Armenian": "hye_Armn",
|
| 144 |
-
"Igbo": "ibo_Latn",
|
| 145 |
-
"Ilocano": "ilo_Latn",
|
| 146 |
-
"Indonesian": "ind_Latn",
|
| 147 |
-
"Icelandic": "isl_Latn",
|
| 148 |
-
"Italian": "ita_Latn",
|
| 149 |
-
"Javanese": "jav_Latn",
|
| 150 |
-
"Japanese": "jpn_Jpan",
|
| 151 |
-
"Kabyle": "kab_Latn",
|
| 152 |
-
"Jingpho": "kac_Latn",
|
| 153 |
-
"Kamba": "kam_Latn",
|
| 154 |
-
"Kannada": "kan_Knda",
|
| 155 |
-
"Kashmiri (Arabic script)": "kas_Arab",
|
| 156 |
-
"Kashmiri (Devanagari script)": "kas_Deva",
|
| 157 |
-
"Georgian": "kat_Geor",
|
| 158 |
-
"Central Kanuri (Arabic script)": "knc_Arab",
|
| 159 |
-
"Central Kanuri (Latin script)": "knc_Latn",
|
| 160 |
-
"Kazakh": "kaz_Cyrl",
|
| 161 |
-
"Kabiyè": "kbp_Latn",
|
| 162 |
-
"Kabuverdianu": "kea_Latn",
|
| 163 |
-
"Khmer": "khm_Khmr",
|
| 164 |
-
"Kikuyu": "kik_Latn",
|
| 165 |
-
"Kinyarwanda": "kin_Latn",
|
| 166 |
-
"Kyrgyz": "kir_Cyrl",
|
| 167 |
-
"Kimbundu": "kmb_Latn",
|
| 168 |
-
"Northern Kurdish": "kmr_Latn",
|
| 169 |
-
"Kikongo": "kon_Latn",
|
| 170 |
-
"Korean": "kor_Hang",
|
| 171 |
-
"Lao": "lao_Laoo",
|
| 172 |
-
"Ligurian": "lij_Latn",
|
| 173 |
-
"Limburgish": "lim_Latn",
|
| 174 |
-
"Lingala": "lin_Latn",
|
| 175 |
-
"Lithuanian": "lit_Latn",
|
| 176 |
-
"Lombard": "lmo_Latn",
|
| 177 |
-
"Latgalian": "ltg_Latn",
|
| 178 |
-
"Luxembourgish": "ltz_Latn",
|
| 179 |
-
"Luba-Kasai": "lua_Latn",
|
| 180 |
-
"Ganda": "lug_Latn",
|
| 181 |
-
"Luo": "luo_Latn",
|
| 182 |
-
"Mizo": "lus_Latn",
|
| 183 |
-
"Standard Latvian": "lvs_Latn",
|
| 184 |
-
"Magahi": "mag_Deva",
|
| 185 |
-
"Maithili": "mai_Deva",
|
| 186 |
-
"Malayalam": "mal_Mlym",
|
| 187 |
-
"Marathi": "mar_Deva",
|
| 188 |
-
"Minangkabau (Arabic script)": "min_Arab",
|
| 189 |
-
"Minangkabau (Latin script)": "min_Latn",
|
| 190 |
-
"Macedonian": "mkd_Cyrl",
|
| 191 |
-
"Plateau Malagasy": "plt_Latn",
|
| 192 |
-
"Maltese": "mlt_Latn",
|
| 193 |
-
"Meitei (Bengali script)": "mni_Beng",
|
| 194 |
-
"Halh Mongolian": "khk_Cyrl",
|
| 195 |
-
"Mossi": "mos_Latn",
|
| 196 |
-
"Maori": "mri_Latn",
|
| 197 |
-
"Burmese": "mya_Mymr",
|
| 198 |
-
"Dutch": "nld_Latn",
|
| 199 |
-
"Norwegian Nynorsk": "nno_Latn",
|
| 200 |
-
"Norwegian Bokmål": "nob_Latn",
|
| 201 |
-
"Nepali": "npi_Deva",
|
| 202 |
-
"Northern Sotho": "nso_Latn",
|
| 203 |
-
"Nuer": "nus_Latn",
|
| 204 |
-
"Nyanja": "nya_Latn",
|
| 205 |
-
"Occitan": "oci_Latn",
|
| 206 |
-
"West Central Oromo": "gaz_Latn",
|
| 207 |
-
"Odia": "ory_Orya",
|
| 208 |
-
"Pangasinan": "pag_Latn",
|
| 209 |
-
"Eastern Panjabi": "pan_Guru",
|
| 210 |
-
"Papiamento": "pap_Latn",
|
| 211 |
-
"Western Persian": "pes_Arab",
|
| 212 |
-
"Polish": "pol_Latn",
|
| 213 |
-
"Portuguese": "por_Latn",
|
| 214 |
-
"Dari": "prs_Arab",
|
| 215 |
-
"Southern Pashto": "pbt_Arab",
|
| 216 |
-
"Ayacucho Quechua": "quy_Latn",
|
| 217 |
-
"Romanian": "ron_Latn",
|
| 218 |
-
"Rundi": "run_Latn",
|
| 219 |
-
"Russian": "rus_Cyrl",
|
| 220 |
-
"Sango": "sag_Latn",
|
| 221 |
-
"Sanskrit": "san_Deva",
|
| 222 |
-
"Santali": "sat_Olck",
|
| 223 |
-
"Sicilian": "scn_Latn",
|
| 224 |
-
"Shan": "shn_Mymr",
|
| 225 |
-
"Sinhala": "sin_Sinh",
|
| 226 |
-
"Slovak": "slk_Latn",
|
| 227 |
-
"Slovenian": "slv_Latn",
|
| 228 |
-
"Samoan": "smo_Latn",
|
| 229 |
-
"Shona": "sna_Latn",
|
| 230 |
-
"Sindhi": "snd_Arab",
|
| 231 |
-
"Somali": "som_Latn",
|
| 232 |
-
"Southern Sotho": "sot_Latn",
|
| 233 |
-
"Spanish": "spa_Latn",
|
| 234 |
-
"Tosk Albanian": "als_Latn",
|
| 235 |
-
"Sardinian": "srd_Latn",
|
| 236 |
-
"Serbian": "srp_Cyrl",
|
| 237 |
-
"Swati": "ssw_Latn",
|
| 238 |
-
"Sundanese": "sun_Latn",
|
| 239 |
-
"Swedish": "swe_Latn",
|
| 240 |
-
"Swahili": "swh_Latn",
|
| 241 |
-
"Silesian": "szl_Latn",
|
| 242 |
-
"Tamil": "tam_Taml",
|
| 243 |
-
"Tatar": "tat_Cyrl",
|
| 244 |
-
"Telugu": "tel_Telu",
|
| 245 |
-
"Tajik": "tgk_Cyrl",
|
| 246 |
-
"Tagalog": "tgl_Latn",
|
| 247 |
-
"Thai": "tha_Thai",
|
| 248 |
-
"Tigrinya": "tir_Ethi",
|
| 249 |
-
"Tamasheq (Latin script)": "taq_Latn",
|
| 250 |
-
"Tamasheq (Tifinagh script)": "taq_Tfng",
|
| 251 |
-
"Tok Pisin": "tpi_Latn",
|
| 252 |
-
"Tswana": "tsn_Latn",
|
| 253 |
-
"Tsonga": "tso_Latn",
|
| 254 |
-
"Turkmen": "tuk_Latn",
|
| 255 |
-
"Tumbuka": "tum_Latn",
|
| 256 |
-
"Turkish": "tur_Latn",
|
| 257 |
-
"Twi": "twi_Latn",
|
| 258 |
-
"Central Atlas Tamazight": "tzm_Tfng",
|
| 259 |
-
"Uyghur": "uig_Arab",
|
| 260 |
-
"Ukrainian": "ukr_Cyrl",
|
| 261 |
-
"Umbundu": "umb_Latn",
|
| 262 |
-
"Urdu": "urd_Arab",
|
| 263 |
-
"Northern Uzbek": "uzn_Latn",
|
| 264 |
-
"Venetian": "vec_Latn",
|
| 265 |
-
"Vietnamese": "vie_Latn",
|
| 266 |
-
"Waray": "war_Latn",
|
| 267 |
-
"Wolof": "wol_Latn",
|
| 268 |
-
"Xhosa": "xho_Latn",
|
| 269 |
-
"Eastern Yiddish": "ydd_Hebr",
|
| 270 |
-
"Yoruba": "yor_Latn",
|
| 271 |
-
"Yue Chinese": "yue_Hant",
|
| 272 |
-
"Chinese (Simplified)": "zho_Hans",
|
| 273 |
-
"Chinese (Traditional)": "zho_Hant",
|
| 274 |
-
"Standard Malay": "zsm_Latn",
|
| 275 |
-
"Zulu": "zul_Latn",
|
| 276 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/translation/translation_base.py
DELETED
|
@@ -1,151 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import torch
|
| 3 |
-
import gradio as gr
|
| 4 |
-
from abc import ABC, abstractmethod
|
| 5 |
-
from typing import List
|
| 6 |
-
from datetime import datetime
|
| 7 |
-
|
| 8 |
-
from modules.whisper.whisper_parameter import *
|
| 9 |
-
from modules.utils.subtitle_manager import *
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class TranslationBase(ABC):
|
| 13 |
-
def __init__(self,
|
| 14 |
-
model_dir: str = os.path.join("models", "NLLB"),
|
| 15 |
-
output_dir: str = os.path.join("outputs", "translations")
|
| 16 |
-
):
|
| 17 |
-
super().__init__()
|
| 18 |
-
self.model = None
|
| 19 |
-
self.model_dir = model_dir
|
| 20 |
-
self.output_dir = output_dir
|
| 21 |
-
os.makedirs(self.model_dir, exist_ok=True)
|
| 22 |
-
os.makedirs(self.output_dir, exist_ok=True)
|
| 23 |
-
self.current_model_size = None
|
| 24 |
-
self.device = self.get_device()
|
| 25 |
-
|
| 26 |
-
@abstractmethod
|
| 27 |
-
def translate(self,
|
| 28 |
-
text: str,
|
| 29 |
-
max_length: int
|
| 30 |
-
):
|
| 31 |
-
pass
|
| 32 |
-
|
| 33 |
-
@abstractmethod
|
| 34 |
-
def update_model(self,
|
| 35 |
-
model_size: str,
|
| 36 |
-
src_lang: str,
|
| 37 |
-
tgt_lang: str,
|
| 38 |
-
progress: gr.Progress
|
| 39 |
-
):
|
| 40 |
-
pass
|
| 41 |
-
|
| 42 |
-
def translate_file(self,
|
| 43 |
-
fileobjs: list,
|
| 44 |
-
model_size: str,
|
| 45 |
-
src_lang: str,
|
| 46 |
-
tgt_lang: str,
|
| 47 |
-
max_length: int,
|
| 48 |
-
add_timestamp: bool,
|
| 49 |
-
progress=gr.Progress()) -> list:
|
| 50 |
-
"""
|
| 51 |
-
Translate subtitle file from source language to target language
|
| 52 |
-
|
| 53 |
-
Parameters
|
| 54 |
-
----------
|
| 55 |
-
fileobjs: list
|
| 56 |
-
List of files to transcribe from gr.Files()
|
| 57 |
-
model_size: str
|
| 58 |
-
Whisper model size from gr.Dropdown()
|
| 59 |
-
src_lang: str
|
| 60 |
-
Source language of the file to translate from gr.Dropdown()
|
| 61 |
-
tgt_lang: str
|
| 62 |
-
Target language of the file to translate from gr.Dropdown()
|
| 63 |
-
max_length: int
|
| 64 |
-
Max length per line to translate
|
| 65 |
-
add_timestamp: bool
|
| 66 |
-
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
|
| 67 |
-
progress: gr.Progress
|
| 68 |
-
Indicator to show progress directly in gradio.
|
| 69 |
-
I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
|
| 70 |
-
|
| 71 |
-
Returns
|
| 72 |
-
----------
|
| 73 |
-
A List of
|
| 74 |
-
String to return to gr.Textbox()
|
| 75 |
-
Files to return to gr.Files()
|
| 76 |
-
"""
|
| 77 |
-
try:
|
| 78 |
-
self.update_model(model_size=model_size,
|
| 79 |
-
src_lang=src_lang,
|
| 80 |
-
tgt_lang=tgt_lang,
|
| 81 |
-
progress=progress)
|
| 82 |
-
|
| 83 |
-
files_info = {}
|
| 84 |
-
for fileobj in fileobjs:
|
| 85 |
-
file_path = fileobj.name
|
| 86 |
-
file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
|
| 87 |
-
if file_ext == ".srt":
|
| 88 |
-
parsed_dicts = parse_srt(file_path=file_path)
|
| 89 |
-
total_progress = len(parsed_dicts)
|
| 90 |
-
for index, dic in enumerate(parsed_dicts):
|
| 91 |
-
progress(index / total_progress, desc="Translating..")
|
| 92 |
-
translated_text = self.translate(dic["sentence"], max_length=max_length)
|
| 93 |
-
dic["sentence"] = translated_text
|
| 94 |
-
subtitle = get_serialized_srt(parsed_dicts)
|
| 95 |
-
|
| 96 |
-
elif file_ext == ".vtt":
|
| 97 |
-
parsed_dicts = parse_vtt(file_path=file_path)
|
| 98 |
-
total_progress = len(parsed_dicts)
|
| 99 |
-
for index, dic in enumerate(parsed_dicts):
|
| 100 |
-
progress(index / total_progress, desc="Translating..")
|
| 101 |
-
translated_text = self.translate(dic["sentence"], max_length=max_length)
|
| 102 |
-
dic["sentence"] = translated_text
|
| 103 |
-
subtitle = get_serialized_vtt(parsed_dicts)
|
| 104 |
-
|
| 105 |
-
if add_timestamp:
|
| 106 |
-
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
| 107 |
-
file_name += f"-{timestamp}"
|
| 108 |
-
|
| 109 |
-
output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
|
| 110 |
-
write_file(subtitle, output_path)
|
| 111 |
-
|
| 112 |
-
files_info[file_name] = {"subtitle": subtitle, "path": output_path}
|
| 113 |
-
|
| 114 |
-
total_result = ''
|
| 115 |
-
for file_name, info in files_info.items():
|
| 116 |
-
total_result += '------------------------------------\n'
|
| 117 |
-
total_result += f'{file_name}\n\n'
|
| 118 |
-
total_result += f'{info["subtitle"]}'
|
| 119 |
-
gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
|
| 120 |
-
|
| 121 |
-
output_file_paths = [item["path"] for key, item in files_info.items()]
|
| 122 |
-
return [gr_str, output_file_paths]
|
| 123 |
-
|
| 124 |
-
except Exception as e:
|
| 125 |
-
print(f"Error: {str(e)}")
|
| 126 |
-
finally:
|
| 127 |
-
self.release_cuda_memory()
|
| 128 |
-
|
| 129 |
-
@staticmethod
|
| 130 |
-
def get_device():
|
| 131 |
-
if torch.cuda.is_available():
|
| 132 |
-
return "cuda"
|
| 133 |
-
elif torch.backends.mps.is_available():
|
| 134 |
-
return "mps"
|
| 135 |
-
else:
|
| 136 |
-
return "cpu"
|
| 137 |
-
|
| 138 |
-
@staticmethod
|
| 139 |
-
def release_cuda_memory():
|
| 140 |
-
if torch.cuda.is_available():
|
| 141 |
-
torch.cuda.empty_cache()
|
| 142 |
-
torch.cuda.reset_max_memory_allocated()
|
| 143 |
-
|
| 144 |
-
@staticmethod
|
| 145 |
-
def remove_input_files(file_paths: List[str]):
|
| 146 |
-
if not file_paths:
|
| 147 |
-
return
|
| 148 |
-
|
| 149 |
-
for file_path in file_paths:
|
| 150 |
-
if file_path and os.path.exists(file_path):
|
| 151 |
-
os.remove(file_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/utils/__init__.py
DELETED
|
File without changes
|
modules/utils/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (173 Bytes)
|
|
|
modules/utils/__pycache__/files_manager.cpython-310.pyc
DELETED
|
Binary file (1.43 kB)
|
|
|
modules/utils/__pycache__/subtitle_manager.cpython-310.pyc
DELETED
|
Binary file (3.38 kB)
|
|
|
modules/utils/__pycache__/youtube_manager.cpython-310.pyc
DELETED
|
Binary file (748 Bytes)
|
|
|
modules/utils/files_manager.py
DELETED
|
@@ -1,39 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import fnmatch
|
| 3 |
-
|
| 4 |
-
from gradio.utils import NamedString
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
def get_media_files(folder_path, include_sub_directory=False):
|
| 8 |
-
video_extensions = ['*.mp4', '*.mkv', '*.flv', '*.avi', '*.mov', '*.wmv']
|
| 9 |
-
audio_extensions = ['*.mp3', '*.wav', '*.aac', '*.flac', '*.ogg', '*.m4a']
|
| 10 |
-
media_extensions = video_extensions + audio_extensions
|
| 11 |
-
|
| 12 |
-
media_files = []
|
| 13 |
-
|
| 14 |
-
if include_sub_directory:
|
| 15 |
-
for root, _, files in os.walk(folder_path):
|
| 16 |
-
for extension in media_extensions:
|
| 17 |
-
media_files.extend(
|
| 18 |
-
os.path.join(root, file) for file in fnmatch.filter(files, extension)
|
| 19 |
-
if os.path.exists(os.path.join(root, file))
|
| 20 |
-
)
|
| 21 |
-
else:
|
| 22 |
-
for extension in media_extensions:
|
| 23 |
-
media_files.extend(
|
| 24 |
-
os.path.join(folder_path, file) for file in fnmatch.filter(os.listdir(folder_path), extension)
|
| 25 |
-
if os.path.isfile(os.path.join(folder_path, file)) and os.path.exists(os.path.join(folder_path, file))
|
| 26 |
-
)
|
| 27 |
-
|
| 28 |
-
return media_files
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
def format_gradio_files(files: list):
|
| 32 |
-
if not files:
|
| 33 |
-
return files
|
| 34 |
-
|
| 35 |
-
gradio_files = []
|
| 36 |
-
for file in files:
|
| 37 |
-
gradio_files.append(NamedString(file))
|
| 38 |
-
return gradio_files
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/utils/subtitle_manager.py
DELETED
|
@@ -1,135 +0,0 @@
|
|
| 1 |
-
import re
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
def timeformat_srt(time):
|
| 5 |
-
hours = time // 3600
|
| 6 |
-
minutes = (time - hours * 3600) // 60
|
| 7 |
-
seconds = time - hours * 3600 - minutes * 60
|
| 8 |
-
milliseconds = (time - int(time)) * 1000
|
| 9 |
-
return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}"
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
def timeformat_vtt(time):
|
| 13 |
-
hours = time // 3600
|
| 14 |
-
minutes = (time - hours * 3600) // 60
|
| 15 |
-
seconds = time - hours * 3600 - minutes * 60
|
| 16 |
-
milliseconds = (time - int(time)) * 1000
|
| 17 |
-
return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}"
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def write_file(subtitle, output_file):
|
| 21 |
-
with open(output_file, 'w', encoding='utf-8') as f:
|
| 22 |
-
f.write(subtitle)
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def get_srt(segments):
|
| 26 |
-
output = ""
|
| 27 |
-
for i, segment in enumerate(segments):
|
| 28 |
-
output += f"{i + 1}\n"
|
| 29 |
-
output += f"{timeformat_srt(segment['start'])} --> {timeformat_srt(segment['end'])}\n"
|
| 30 |
-
if segment['text'].startswith(' '):
|
| 31 |
-
segment['text'] = segment['text'][1:]
|
| 32 |
-
output += f"{segment['text']}\n\n"
|
| 33 |
-
return output
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def get_vtt(segments):
|
| 37 |
-
output = "WebVTT\n\n"
|
| 38 |
-
for i, segment in enumerate(segments):
|
| 39 |
-
output += f"{i + 1}\n"
|
| 40 |
-
output += f"{timeformat_vtt(segment['start'])} --> {timeformat_vtt(segment['end'])}\n"
|
| 41 |
-
if segment['text'].startswith(' '):
|
| 42 |
-
segment['text'] = segment['text'][1:]
|
| 43 |
-
output += f"{segment['text']}\n\n"
|
| 44 |
-
return output
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def get_txt(segments):
|
| 48 |
-
output = ""
|
| 49 |
-
for i, segment in enumerate(segments):
|
| 50 |
-
if segment['text'].startswith(' '):
|
| 51 |
-
segment['text'] = segment['text'][1:]
|
| 52 |
-
output += f"{segment['text']}\n"
|
| 53 |
-
return output
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
def parse_srt(file_path):
|
| 57 |
-
"""Reads SRT file and returns as dict"""
|
| 58 |
-
with open(file_path, 'r', encoding='utf-8') as file:
|
| 59 |
-
srt_data = file.read()
|
| 60 |
-
|
| 61 |
-
data = []
|
| 62 |
-
blocks = srt_data.split('\n\n')
|
| 63 |
-
|
| 64 |
-
for block in blocks:
|
| 65 |
-
if block.strip() != '':
|
| 66 |
-
lines = block.strip().split('\n')
|
| 67 |
-
index = lines[0]
|
| 68 |
-
timestamp = lines[1]
|
| 69 |
-
sentence = ' '.join(lines[2:])
|
| 70 |
-
|
| 71 |
-
data.append({
|
| 72 |
-
"index": index,
|
| 73 |
-
"timestamp": timestamp,
|
| 74 |
-
"sentence": sentence
|
| 75 |
-
})
|
| 76 |
-
return data
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
def parse_vtt(file_path):
|
| 80 |
-
"""Reads WebVTT file and returns as dict"""
|
| 81 |
-
with open(file_path, 'r', encoding='utf-8') as file:
|
| 82 |
-
webvtt_data = file.read()
|
| 83 |
-
|
| 84 |
-
data = []
|
| 85 |
-
blocks = webvtt_data.split('\n\n')
|
| 86 |
-
|
| 87 |
-
for block in blocks:
|
| 88 |
-
if block.strip() != '' and not block.strip().startswith("WebVTT"):
|
| 89 |
-
lines = block.strip().split('\n')
|
| 90 |
-
index = lines[0]
|
| 91 |
-
timestamp = lines[1]
|
| 92 |
-
sentence = ' '.join(lines[2:])
|
| 93 |
-
|
| 94 |
-
data.append({
|
| 95 |
-
"index": index,
|
| 96 |
-
"timestamp": timestamp,
|
| 97 |
-
"sentence": sentence
|
| 98 |
-
})
|
| 99 |
-
|
| 100 |
-
return data
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
def get_serialized_srt(dicts):
|
| 104 |
-
output = ""
|
| 105 |
-
for dic in dicts:
|
| 106 |
-
output += f'{dic["index"]}\n'
|
| 107 |
-
output += f'{dic["timestamp"]}\n'
|
| 108 |
-
output += f'{dic["sentence"]}\n\n'
|
| 109 |
-
return output
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
def get_serialized_vtt(dicts):
|
| 113 |
-
output = "WebVTT\n\n"
|
| 114 |
-
for dic in dicts:
|
| 115 |
-
output += f'{dic["index"]}\n'
|
| 116 |
-
output += f'{dic["timestamp"]}\n'
|
| 117 |
-
output += f'{dic["sentence"]}\n\n'
|
| 118 |
-
return output
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
def safe_filename(name):
|
| 122 |
-
from app import _args
|
| 123 |
-
INVALID_FILENAME_CHARS = r'[<>:"/\\|?*\x00-\x1f]'
|
| 124 |
-
safe_name = re.sub(INVALID_FILENAME_CHARS, '_', name)
|
| 125 |
-
if not _args.colab:
|
| 126 |
-
return safe_name
|
| 127 |
-
# Truncate the filename if it exceeds the max_length (20)
|
| 128 |
-
if len(safe_name) > 20:
|
| 129 |
-
file_extension = safe_name.split('.')[-1]
|
| 130 |
-
if len(file_extension) + 1 < 20:
|
| 131 |
-
truncated_name = safe_name[:20 - len(file_extension) - 1]
|
| 132 |
-
safe_name = truncated_name + '.' + file_extension
|
| 133 |
-
else:
|
| 134 |
-
safe_name = safe_name[:20]
|
| 135 |
-
return safe_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/utils/youtube_manager.py
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
from pytubefix import YouTube
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
def get_ytdata(link):
|
| 6 |
-
return YouTube(link)
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
def get_ytmetas(link):
|
| 10 |
-
yt = YouTube(link)
|
| 11 |
-
return yt.thumbnail_url, yt.title, yt.description
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def get_ytaudio(ytdata: YouTube):
|
| 15 |
-
return ytdata.streams.get_audio_only().download(filename=os.path.join("modules", "yt_tmp.wav"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/vad/__init__.py
DELETED
|
File without changes
|
modules/vad/silero_vad.py
DELETED
|
@@ -1,264 +0,0 @@
|
|
| 1 |
-
# Adapted from https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py
|
| 2 |
-
|
| 3 |
-
from faster_whisper.vad import VadOptions, get_vad_model
|
| 4 |
-
import numpy as np
|
| 5 |
-
from typing import BinaryIO, Union, List, Optional, Tuple
|
| 6 |
-
import warnings
|
| 7 |
-
import faster_whisper
|
| 8 |
-
from faster_whisper.transcribe import SpeechTimestampsMap, Segment
|
| 9 |
-
import gradio as gr
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class SileroVAD:
|
| 13 |
-
def __init__(self):
|
| 14 |
-
self.sampling_rate = 16000
|
| 15 |
-
self.window_size_samples = 512
|
| 16 |
-
self.model = None
|
| 17 |
-
|
| 18 |
-
def run(self,
|
| 19 |
-
audio: Union[str, BinaryIO, np.ndarray],
|
| 20 |
-
vad_parameters: VadOptions,
|
| 21 |
-
progress: gr.Progress = gr.Progress()
|
| 22 |
-
) -> Tuple[np.ndarray, List[dict]]:
|
| 23 |
-
"""
|
| 24 |
-
Run VAD
|
| 25 |
-
|
| 26 |
-
Parameters
|
| 27 |
-
----------
|
| 28 |
-
audio: Union[str, BinaryIO, np.ndarray]
|
| 29 |
-
Audio path or file binary or Audio numpy array
|
| 30 |
-
vad_parameters:
|
| 31 |
-
Options for VAD processing.
|
| 32 |
-
progress: gr.Progress
|
| 33 |
-
Indicator to show progress directly in gradio.
|
| 34 |
-
|
| 35 |
-
Returns
|
| 36 |
-
----------
|
| 37 |
-
np.ndarray
|
| 38 |
-
Pre-processed audio with VAD
|
| 39 |
-
List[dict]
|
| 40 |
-
Chunks of speeches to be used to restore the timestamps later
|
| 41 |
-
"""
|
| 42 |
-
|
| 43 |
-
sampling_rate = self.sampling_rate
|
| 44 |
-
|
| 45 |
-
if not isinstance(audio, np.ndarray):
|
| 46 |
-
audio = faster_whisper.decode_audio(audio, sampling_rate=sampling_rate)
|
| 47 |
-
|
| 48 |
-
duration = audio.shape[0] / sampling_rate
|
| 49 |
-
duration_after_vad = duration
|
| 50 |
-
|
| 51 |
-
if vad_parameters is None:
|
| 52 |
-
vad_parameters = VadOptions()
|
| 53 |
-
elif isinstance(vad_parameters, dict):
|
| 54 |
-
vad_parameters = VadOptions(**vad_parameters)
|
| 55 |
-
speech_chunks = self.get_speech_timestamps(
|
| 56 |
-
audio=audio,
|
| 57 |
-
vad_options=vad_parameters,
|
| 58 |
-
progress=progress
|
| 59 |
-
)
|
| 60 |
-
audio = self.collect_chunks(audio, speech_chunks)
|
| 61 |
-
duration_after_vad = audio.shape[0] / sampling_rate
|
| 62 |
-
|
| 63 |
-
return audio, speech_chunks
|
| 64 |
-
|
| 65 |
-
def get_speech_timestamps(
|
| 66 |
-
self,
|
| 67 |
-
audio: np.ndarray,
|
| 68 |
-
vad_options: Optional[VadOptions] = None,
|
| 69 |
-
progress: gr.Progress = gr.Progress(),
|
| 70 |
-
**kwargs,
|
| 71 |
-
) -> List[dict]:
|
| 72 |
-
"""This method is used for splitting long audios into speech chunks using silero VAD.
|
| 73 |
-
|
| 74 |
-
Args:
|
| 75 |
-
audio: One dimensional float array.
|
| 76 |
-
vad_options: Options for VAD processing.
|
| 77 |
-
kwargs: VAD options passed as keyword arguments for backward compatibility.
|
| 78 |
-
progress: Gradio progress to indicate progress.
|
| 79 |
-
|
| 80 |
-
Returns:
|
| 81 |
-
List of dicts containing begin and end samples of each speech chunk.
|
| 82 |
-
"""
|
| 83 |
-
|
| 84 |
-
if self.model is None:
|
| 85 |
-
self.update_model()
|
| 86 |
-
|
| 87 |
-
if vad_options is None:
|
| 88 |
-
vad_options = VadOptions(**kwargs)
|
| 89 |
-
|
| 90 |
-
threshold = vad_options.threshold
|
| 91 |
-
min_speech_duration_ms = vad_options.min_speech_duration_ms
|
| 92 |
-
max_speech_duration_s = vad_options.max_speech_duration_s
|
| 93 |
-
min_silence_duration_ms = vad_options.min_silence_duration_ms
|
| 94 |
-
window_size_samples = self.window_size_samples
|
| 95 |
-
speech_pad_ms = vad_options.speech_pad_ms
|
| 96 |
-
sampling_rate = 16000
|
| 97 |
-
min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
|
| 98 |
-
speech_pad_samples = sampling_rate * speech_pad_ms / 1000
|
| 99 |
-
max_speech_samples = (
|
| 100 |
-
sampling_rate * max_speech_duration_s
|
| 101 |
-
- window_size_samples
|
| 102 |
-
- 2 * speech_pad_samples
|
| 103 |
-
)
|
| 104 |
-
min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
|
| 105 |
-
min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
|
| 106 |
-
|
| 107 |
-
audio_length_samples = len(audio)
|
| 108 |
-
|
| 109 |
-
state, context = self.model.get_initial_states(batch_size=1)
|
| 110 |
-
|
| 111 |
-
speech_probs = []
|
| 112 |
-
for current_start_sample in range(0, audio_length_samples, window_size_samples):
|
| 113 |
-
progress(current_start_sample/audio_length_samples, desc="Detecting speeches only using VAD...")
|
| 114 |
-
|
| 115 |
-
chunk = audio[current_start_sample: current_start_sample + window_size_samples]
|
| 116 |
-
if len(chunk) < window_size_samples:
|
| 117 |
-
chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
|
| 118 |
-
speech_prob, state, context = self.model(chunk, state, context, sampling_rate)
|
| 119 |
-
speech_probs.append(speech_prob)
|
| 120 |
-
|
| 121 |
-
triggered = False
|
| 122 |
-
speeches = []
|
| 123 |
-
current_speech = {}
|
| 124 |
-
neg_threshold = threshold - 0.15
|
| 125 |
-
|
| 126 |
-
# to save potential segment end (and tolerate some silence)
|
| 127 |
-
temp_end = 0
|
| 128 |
-
# to save potential segment limits in case of maximum segment size reached
|
| 129 |
-
prev_end = next_start = 0
|
| 130 |
-
|
| 131 |
-
for i, speech_prob in enumerate(speech_probs):
|
| 132 |
-
if (speech_prob >= threshold) and temp_end:
|
| 133 |
-
temp_end = 0
|
| 134 |
-
if next_start < prev_end:
|
| 135 |
-
next_start = window_size_samples * i
|
| 136 |
-
|
| 137 |
-
if (speech_prob >= threshold) and not triggered:
|
| 138 |
-
triggered = True
|
| 139 |
-
current_speech["start"] = window_size_samples * i
|
| 140 |
-
continue
|
| 141 |
-
|
| 142 |
-
if (
|
| 143 |
-
triggered
|
| 144 |
-
and (window_size_samples * i) - current_speech["start"] > max_speech_samples
|
| 145 |
-
):
|
| 146 |
-
if prev_end:
|
| 147 |
-
current_speech["end"] = prev_end
|
| 148 |
-
speeches.append(current_speech)
|
| 149 |
-
current_speech = {}
|
| 150 |
-
# previously reached silence (< neg_thres) and is still not speech (< thres)
|
| 151 |
-
if next_start < prev_end:
|
| 152 |
-
triggered = False
|
| 153 |
-
else:
|
| 154 |
-
current_speech["start"] = next_start
|
| 155 |
-
prev_end = next_start = temp_end = 0
|
| 156 |
-
else:
|
| 157 |
-
current_speech["end"] = window_size_samples * i
|
| 158 |
-
speeches.append(current_speech)
|
| 159 |
-
current_speech = {}
|
| 160 |
-
prev_end = next_start = temp_end = 0
|
| 161 |
-
triggered = False
|
| 162 |
-
continue
|
| 163 |
-
|
| 164 |
-
if (speech_prob < neg_threshold) and triggered:
|
| 165 |
-
if not temp_end:
|
| 166 |
-
temp_end = window_size_samples * i
|
| 167 |
-
# condition to avoid cutting in very short silence
|
| 168 |
-
if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech:
|
| 169 |
-
prev_end = temp_end
|
| 170 |
-
if (window_size_samples * i) - temp_end < min_silence_samples:
|
| 171 |
-
continue
|
| 172 |
-
else:
|
| 173 |
-
current_speech["end"] = temp_end
|
| 174 |
-
if (
|
| 175 |
-
current_speech["end"] - current_speech["start"]
|
| 176 |
-
) > min_speech_samples:
|
| 177 |
-
speeches.append(current_speech)
|
| 178 |
-
current_speech = {}
|
| 179 |
-
prev_end = next_start = temp_end = 0
|
| 180 |
-
triggered = False
|
| 181 |
-
continue
|
| 182 |
-
|
| 183 |
-
if (
|
| 184 |
-
current_speech
|
| 185 |
-
and (audio_length_samples - current_speech["start"]) > min_speech_samples
|
| 186 |
-
):
|
| 187 |
-
current_speech["end"] = audio_length_samples
|
| 188 |
-
speeches.append(current_speech)
|
| 189 |
-
|
| 190 |
-
for i, speech in enumerate(speeches):
|
| 191 |
-
if i == 0:
|
| 192 |
-
speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
|
| 193 |
-
if i != len(speeches) - 1:
|
| 194 |
-
silence_duration = speeches[i + 1]["start"] - speech["end"]
|
| 195 |
-
if silence_duration < 2 * speech_pad_samples:
|
| 196 |
-
speech["end"] += int(silence_duration // 2)
|
| 197 |
-
speeches[i + 1]["start"] = int(
|
| 198 |
-
max(0, speeches[i + 1]["start"] - silence_duration // 2)
|
| 199 |
-
)
|
| 200 |
-
else:
|
| 201 |
-
speech["end"] = int(
|
| 202 |
-
min(audio_length_samples, speech["end"] + speech_pad_samples)
|
| 203 |
-
)
|
| 204 |
-
speeches[i + 1]["start"] = int(
|
| 205 |
-
max(0, speeches[i + 1]["start"] - speech_pad_samples)
|
| 206 |
-
)
|
| 207 |
-
else:
|
| 208 |
-
speech["end"] = int(
|
| 209 |
-
min(audio_length_samples, speech["end"] + speech_pad_samples)
|
| 210 |
-
)
|
| 211 |
-
|
| 212 |
-
return speeches
|
| 213 |
-
|
| 214 |
-
def update_model(self):
|
| 215 |
-
self.model = get_vad_model()
|
| 216 |
-
|
| 217 |
-
@staticmethod
|
| 218 |
-
def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
|
| 219 |
-
"""Collects and concatenates audio chunks."""
|
| 220 |
-
if not chunks:
|
| 221 |
-
return np.array([], dtype=np.float32)
|
| 222 |
-
|
| 223 |
-
return np.concatenate([audio[chunk["start"]: chunk["end"]] for chunk in chunks])
|
| 224 |
-
|
| 225 |
-
@staticmethod
|
| 226 |
-
def format_timestamp(
|
| 227 |
-
seconds: float,
|
| 228 |
-
always_include_hours: bool = False,
|
| 229 |
-
decimal_marker: str = ".",
|
| 230 |
-
) -> str:
|
| 231 |
-
assert seconds >= 0, "non-negative timestamp expected"
|
| 232 |
-
milliseconds = round(seconds * 1000.0)
|
| 233 |
-
|
| 234 |
-
hours = milliseconds // 3_600_000
|
| 235 |
-
milliseconds -= hours * 3_600_000
|
| 236 |
-
|
| 237 |
-
minutes = milliseconds // 60_000
|
| 238 |
-
milliseconds -= minutes * 60_000
|
| 239 |
-
|
| 240 |
-
seconds = milliseconds // 1_000
|
| 241 |
-
milliseconds -= seconds * 1_000
|
| 242 |
-
|
| 243 |
-
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
|
| 244 |
-
return (
|
| 245 |
-
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
| 246 |
-
)
|
| 247 |
-
|
| 248 |
-
def restore_speech_timestamps(
|
| 249 |
-
self,
|
| 250 |
-
segments: List[dict],
|
| 251 |
-
speech_chunks: List[dict],
|
| 252 |
-
sampling_rate: Optional[int] = None,
|
| 253 |
-
) -> List[dict]:
|
| 254 |
-
if sampling_rate is None:
|
| 255 |
-
sampling_rate = self.sampling_rate
|
| 256 |
-
|
| 257 |
-
ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
|
| 258 |
-
|
| 259 |
-
for segment in segments:
|
| 260 |
-
segment["start"] = ts_map.get_original_time(segment["start"])
|
| 261 |
-
segment["end"] = ts_map.get_original_time(segment["end"])
|
| 262 |
-
|
| 263 |
-
return segments
|
| 264 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/whisper/__init__.py
DELETED
|
File without changes
|
modules/whisper/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (175 Bytes)
|
|
|
modules/whisper/__pycache__/faster_whisper_inference.cpython-310.pyc
DELETED
|
Binary file (6.51 kB)
|
|
|
modules/whisper/__pycache__/whisper_base.cpython-310.pyc
DELETED
|
Binary file (12.9 kB)
|
|
|
modules/whisper/__pycache__/whisper_factory.cpython-310.pyc
DELETED
|
Binary file (2.87 kB)
|
|
|
modules/whisper/__pycache__/whisper_parameter.cpython-310.pyc
DELETED
|
Binary file (3.68 kB)
|
|
|
modules/whisper/faster_whisper_inference.py
DELETED
|
@@ -1,191 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import time
|
| 3 |
-
import numpy as np
|
| 4 |
-
import torch
|
| 5 |
-
from typing import BinaryIO, Union, Tuple, List
|
| 6 |
-
import faster_whisper
|
| 7 |
-
from faster_whisper.vad import VadOptions
|
| 8 |
-
import ast
|
| 9 |
-
import ctranslate2
|
| 10 |
-
import whisper
|
| 11 |
-
import gradio as gr
|
| 12 |
-
from argparse import Namespace
|
| 13 |
-
|
| 14 |
-
from modules.whisper.whisper_parameter import *
|
| 15 |
-
from modules.whisper.whisper_base import WhisperBase
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class FasterWhisperInference(WhisperBase):
|
| 19 |
-
def __init__(self,
|
| 20 |
-
model_dir: str = os.path.join("models", "Whisper", "faster-whisper"),
|
| 21 |
-
diarization_model_dir: str = os.path.join("models", "Diarization"),
|
| 22 |
-
output_dir: str = os.path.join("outputs"),
|
| 23 |
-
):
|
| 24 |
-
super().__init__(
|
| 25 |
-
model_dir=model_dir,
|
| 26 |
-
diarization_model_dir=diarization_model_dir,
|
| 27 |
-
output_dir=output_dir
|
| 28 |
-
)
|
| 29 |
-
self.model_dir = model_dir
|
| 30 |
-
os.makedirs(self.model_dir, exist_ok=True)
|
| 31 |
-
|
| 32 |
-
self.model_paths = self.get_model_paths()
|
| 33 |
-
self.device = self.get_device()
|
| 34 |
-
self.available_models = self.model_paths.keys()
|
| 35 |
-
self.available_compute_types = ctranslate2.get_supported_compute_types(
|
| 36 |
-
"cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
|
| 37 |
-
|
| 38 |
-
def transcribe(self,
|
| 39 |
-
audio: Union[str, BinaryIO, np.ndarray],
|
| 40 |
-
progress: gr.Progress,
|
| 41 |
-
*whisper_params,
|
| 42 |
-
) -> Tuple[List[dict], float]:
|
| 43 |
-
"""
|
| 44 |
-
transcribe method for faster-whisper.
|
| 45 |
-
|
| 46 |
-
Parameters
|
| 47 |
-
----------
|
| 48 |
-
audio: Union[str, BinaryIO, np.ndarray]
|
| 49 |
-
Audio path or file binary or Audio numpy array
|
| 50 |
-
progress: gr.Progress
|
| 51 |
-
Indicator to show progress directly in gradio.
|
| 52 |
-
*whisper_params: tuple
|
| 53 |
-
Parameters related with whisper. This will be dealt with "WhisperParameters" data class
|
| 54 |
-
|
| 55 |
-
Returns
|
| 56 |
-
----------
|
| 57 |
-
segments_result: List[dict]
|
| 58 |
-
list of dicts that includes start, end timestamps and transcribed text
|
| 59 |
-
elapsed_time: float
|
| 60 |
-
elapsed time for transcription
|
| 61 |
-
"""
|
| 62 |
-
start_time = time.time()
|
| 63 |
-
|
| 64 |
-
params = WhisperParameters.as_value(*whisper_params)
|
| 65 |
-
|
| 66 |
-
if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
|
| 67 |
-
self.update_model(params.model_size, params.compute_type, progress)
|
| 68 |
-
|
| 69 |
-
# None parameters with Textboxes: https://github.com/gradio-app/gradio/issues/8723
|
| 70 |
-
if not params.initial_prompt:
|
| 71 |
-
params.initial_prompt = None
|
| 72 |
-
if not params.prefix:
|
| 73 |
-
params.prefix = None
|
| 74 |
-
if not params.hotwords:
|
| 75 |
-
params.hotwords = None
|
| 76 |
-
|
| 77 |
-
params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens)
|
| 78 |
-
|
| 79 |
-
segments, info = self.model.transcribe(
|
| 80 |
-
audio=audio,
|
| 81 |
-
language=params.lang,
|
| 82 |
-
task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
|
| 83 |
-
beam_size=params.beam_size,
|
| 84 |
-
log_prob_threshold=params.log_prob_threshold,
|
| 85 |
-
no_speech_threshold=params.no_speech_threshold,
|
| 86 |
-
best_of=params.best_of,
|
| 87 |
-
patience=params.patience,
|
| 88 |
-
temperature=params.temperature,
|
| 89 |
-
initial_prompt=params.initial_prompt,
|
| 90 |
-
compression_ratio_threshold=params.compression_ratio_threshold,
|
| 91 |
-
length_penalty=params.length_penalty,
|
| 92 |
-
repetition_penalty=params.repetition_penalty,
|
| 93 |
-
no_repeat_ngram_size=params.no_repeat_ngram_size,
|
| 94 |
-
prefix=params.prefix,
|
| 95 |
-
suppress_blank=params.suppress_blank,
|
| 96 |
-
suppress_tokens=params.suppress_tokens,
|
| 97 |
-
max_initial_timestamp=params.max_initial_timestamp,
|
| 98 |
-
word_timestamps=params.word_timestamps,
|
| 99 |
-
prepend_punctuations=params.prepend_punctuations,
|
| 100 |
-
append_punctuations=params.append_punctuations,
|
| 101 |
-
max_new_tokens=params.max_new_tokens,
|
| 102 |
-
chunk_length=params.chunk_length,
|
| 103 |
-
hallucination_silence_threshold=params.hallucination_silence_threshold,
|
| 104 |
-
hotwords=params.hotwords,
|
| 105 |
-
language_detection_threshold=params.language_detection_threshold,
|
| 106 |
-
language_detection_segments=params.language_detection_segments,
|
| 107 |
-
prompt_reset_on_temperature=params.prompt_reset_on_temperature,
|
| 108 |
-
)
|
| 109 |
-
progress(0, desc="Loading audio..")
|
| 110 |
-
|
| 111 |
-
segments_result = []
|
| 112 |
-
for segment in segments:
|
| 113 |
-
progress(segment.start / info.duration, desc="Transcribing..")
|
| 114 |
-
segments_result.append({
|
| 115 |
-
"start": segment.start,
|
| 116 |
-
"end": segment.end,
|
| 117 |
-
"text": segment.text
|
| 118 |
-
})
|
| 119 |
-
|
| 120 |
-
elapsed_time = time.time() - start_time
|
| 121 |
-
return segments_result, elapsed_time
|
| 122 |
-
|
| 123 |
-
def update_model(self,
|
| 124 |
-
model_size: str,
|
| 125 |
-
compute_type: str,
|
| 126 |
-
progress: gr.Progress
|
| 127 |
-
):
|
| 128 |
-
"""
|
| 129 |
-
Update current model setting
|
| 130 |
-
|
| 131 |
-
Parameters
|
| 132 |
-
----------
|
| 133 |
-
model_size: str
|
| 134 |
-
Size of whisper model
|
| 135 |
-
compute_type: str
|
| 136 |
-
Compute type for transcription.
|
| 137 |
-
see more info : https://opennmt.net/CTranslate2/quantization.html
|
| 138 |
-
progress: gr.Progress
|
| 139 |
-
Indicator to show progress directly in gradio.
|
| 140 |
-
"""
|
| 141 |
-
progress(0, desc="Initializing Model..")
|
| 142 |
-
self.current_model_size = self.model_paths[model_size]
|
| 143 |
-
self.current_compute_type = compute_type
|
| 144 |
-
self.model = faster_whisper.WhisperModel(
|
| 145 |
-
device=self.device,
|
| 146 |
-
model_size_or_path=self.current_model_size,
|
| 147 |
-
download_root=self.model_dir,
|
| 148 |
-
compute_type=self.current_compute_type
|
| 149 |
-
)
|
| 150 |
-
|
| 151 |
-
def get_model_paths(self):
|
| 152 |
-
"""
|
| 153 |
-
Get available models from models path including fine-tuned model.
|
| 154 |
-
|
| 155 |
-
Returns
|
| 156 |
-
----------
|
| 157 |
-
Name list of models
|
| 158 |
-
"""
|
| 159 |
-
model_paths = {model:model for model in whisper.available_models()}
|
| 160 |
-
faster_whisper_prefix = "models--Systran--faster-whisper-"
|
| 161 |
-
|
| 162 |
-
existing_models = os.listdir(self.model_dir)
|
| 163 |
-
wrong_dirs = [".locks"]
|
| 164 |
-
existing_models = list(set(existing_models) - set(wrong_dirs))
|
| 165 |
-
|
| 166 |
-
webui_dir = os.getcwd()
|
| 167 |
-
|
| 168 |
-
for model_name in existing_models:
|
| 169 |
-
if faster_whisper_prefix in model_name:
|
| 170 |
-
model_name = model_name[len(faster_whisper_prefix):]
|
| 171 |
-
|
| 172 |
-
if model_name not in whisper.available_models():
|
| 173 |
-
model_paths[model_name] = os.path.join(webui_dir, self.model_dir, model_name)
|
| 174 |
-
return model_paths
|
| 175 |
-
|
| 176 |
-
@staticmethod
|
| 177 |
-
def get_device():
|
| 178 |
-
if torch.cuda.is_available():
|
| 179 |
-
return "cuda"
|
| 180 |
-
else:
|
| 181 |
-
return "auto"
|
| 182 |
-
|
| 183 |
-
@staticmethod
|
| 184 |
-
def format_suppress_tokens_str(suppress_tokens_str: str) -> List[int]:
|
| 185 |
-
try:
|
| 186 |
-
suppress_tokens = ast.literal_eval(suppress_tokens_str)
|
| 187 |
-
if not isinstance(suppress_tokens, list) or not all(isinstance(item, int) for item in suppress_tokens):
|
| 188 |
-
raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
|
| 189 |
-
return suppress_tokens
|
| 190 |
-
except Exception as e:
|
| 191 |
-
raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/whisper/insanely_fast_whisper_inference.py
DELETED
|
@@ -1,185 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import time
|
| 3 |
-
import numpy as np
|
| 4 |
-
from typing import BinaryIO, Union, Tuple, List
|
| 5 |
-
import torch
|
| 6 |
-
from transformers import pipeline
|
| 7 |
-
from transformers.utils import is_flash_attn_2_available
|
| 8 |
-
import gradio as gr
|
| 9 |
-
from huggingface_hub import hf_hub_download
|
| 10 |
-
import whisper
|
| 11 |
-
from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn
|
| 12 |
-
from argparse import Namespace
|
| 13 |
-
|
| 14 |
-
from modules.whisper.whisper_parameter import *
|
| 15 |
-
from modules.whisper.whisper_base import WhisperBase
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class InsanelyFastWhisperInference(WhisperBase):
|
| 19 |
-
def __init__(self,
|
| 20 |
-
model_dir: str = os.path.join("models", "Whisper", "insanely-fast-whisper"),
|
| 21 |
-
diarization_model_dir: str = os.path.join("models", "Diarization"),
|
| 22 |
-
output_dir: str = os.path.join("outputs"),
|
| 23 |
-
):
|
| 24 |
-
super().__init__(
|
| 25 |
-
model_dir=model_dir,
|
| 26 |
-
output_dir=output_dir,
|
| 27 |
-
diarization_model_dir=diarization_model_dir
|
| 28 |
-
)
|
| 29 |
-
self.model_dir = model_dir
|
| 30 |
-
os.makedirs(self.model_dir, exist_ok=True)
|
| 31 |
-
|
| 32 |
-
openai_models = whisper.available_models()
|
| 33 |
-
distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
|
| 34 |
-
self.available_models = openai_models + distil_models
|
| 35 |
-
self.available_compute_types = ["float16"]
|
| 36 |
-
|
| 37 |
-
def transcribe(self,
|
| 38 |
-
audio: Union[str, np.ndarray, torch.Tensor],
|
| 39 |
-
progress: gr.Progress,
|
| 40 |
-
*whisper_params,
|
| 41 |
-
) -> Tuple[List[dict], float]:
|
| 42 |
-
"""
|
| 43 |
-
transcribe method for faster-whisper.
|
| 44 |
-
|
| 45 |
-
Parameters
|
| 46 |
-
----------
|
| 47 |
-
audio: Union[str, BinaryIO, np.ndarray]
|
| 48 |
-
Audio path or file binary or Audio numpy array
|
| 49 |
-
progress: gr.Progress
|
| 50 |
-
Indicator to show progress directly in gradio.
|
| 51 |
-
*whisper_params: tuple
|
| 52 |
-
Parameters related with whisper. This will be dealt with "WhisperParameters" data class
|
| 53 |
-
|
| 54 |
-
Returns
|
| 55 |
-
----------
|
| 56 |
-
segments_result: List[dict]
|
| 57 |
-
list of dicts that includes start, end timestamps and transcribed text
|
| 58 |
-
elapsed_time: float
|
| 59 |
-
elapsed time for transcription
|
| 60 |
-
"""
|
| 61 |
-
start_time = time.time()
|
| 62 |
-
params = WhisperParameters.as_value(*whisper_params)
|
| 63 |
-
|
| 64 |
-
if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
|
| 65 |
-
self.update_model(params.model_size, params.compute_type, progress)
|
| 66 |
-
|
| 67 |
-
progress(0, desc="Transcribing...Progress is not shown in insanely-fast-whisper.")
|
| 68 |
-
with Progress(
|
| 69 |
-
TextColumn("[progress.description]{task.description}"),
|
| 70 |
-
BarColumn(style="yellow1", pulse_style="white"),
|
| 71 |
-
TimeElapsedColumn(),
|
| 72 |
-
) as progress:
|
| 73 |
-
progress.add_task("[yellow]Transcribing...", total=None)
|
| 74 |
-
|
| 75 |
-
segments = self.model(
|
| 76 |
-
inputs=audio,
|
| 77 |
-
return_timestamps=True,
|
| 78 |
-
chunk_length_s=params.chunk_length_s,
|
| 79 |
-
batch_size=params.batch_size,
|
| 80 |
-
generate_kwargs={
|
| 81 |
-
"language": params.lang,
|
| 82 |
-
"task": "translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
|
| 83 |
-
"no_speech_threshold": params.no_speech_threshold,
|
| 84 |
-
"temperature": params.temperature,
|
| 85 |
-
"compression_ratio_threshold": params.compression_ratio_threshold
|
| 86 |
-
}
|
| 87 |
-
)
|
| 88 |
-
|
| 89 |
-
segments_result = self.format_result(
|
| 90 |
-
transcribed_result=segments,
|
| 91 |
-
)
|
| 92 |
-
elapsed_time = time.time() - start_time
|
| 93 |
-
return segments_result, elapsed_time
|
| 94 |
-
|
| 95 |
-
def update_model(self,
|
| 96 |
-
model_size: str,
|
| 97 |
-
compute_type: str,
|
| 98 |
-
progress: gr.Progress,
|
| 99 |
-
):
|
| 100 |
-
"""
|
| 101 |
-
Update current model setting
|
| 102 |
-
|
| 103 |
-
Parameters
|
| 104 |
-
----------
|
| 105 |
-
model_size: str
|
| 106 |
-
Size of whisper model
|
| 107 |
-
compute_type: str
|
| 108 |
-
Compute type for transcription.
|
| 109 |
-
see more info : https://opennmt.net/CTranslate2/quantization.html
|
| 110 |
-
progress: gr.Progress
|
| 111 |
-
Indicator to show progress directly in gradio.
|
| 112 |
-
"""
|
| 113 |
-
progress(0, desc="Initializing Model..")
|
| 114 |
-
model_path = os.path.join(self.model_dir, model_size)
|
| 115 |
-
if not os.path.isdir(model_path) or not os.listdir(model_path):
|
| 116 |
-
self.download_model(
|
| 117 |
-
model_size=model_size,
|
| 118 |
-
download_root=model_path,
|
| 119 |
-
progress=progress
|
| 120 |
-
)
|
| 121 |
-
|
| 122 |
-
self.current_compute_type = compute_type
|
| 123 |
-
self.current_model_size = model_size
|
| 124 |
-
self.model = pipeline(
|
| 125 |
-
"automatic-speech-recognition",
|
| 126 |
-
model=os.path.join(self.model_dir, model_size),
|
| 127 |
-
torch_dtype=self.current_compute_type,
|
| 128 |
-
device=self.device,
|
| 129 |
-
model_kwargs={"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"},
|
| 130 |
-
)
|
| 131 |
-
|
| 132 |
-
@staticmethod
|
| 133 |
-
def format_result(
|
| 134 |
-
transcribed_result: dict
|
| 135 |
-
) -> List[dict]:
|
| 136 |
-
"""
|
| 137 |
-
Format the transcription result of insanely_fast_whisper as the same with other implementation.
|
| 138 |
-
|
| 139 |
-
Parameters
|
| 140 |
-
----------
|
| 141 |
-
transcribed_result: dict
|
| 142 |
-
Transcription result of the insanely_fast_whisper
|
| 143 |
-
|
| 144 |
-
Returns
|
| 145 |
-
----------
|
| 146 |
-
result: List[dict]
|
| 147 |
-
Formatted result as the same with other implementation
|
| 148 |
-
"""
|
| 149 |
-
result = transcribed_result["chunks"]
|
| 150 |
-
for item in result:
|
| 151 |
-
start, end = item["timestamp"][0], item["timestamp"][1]
|
| 152 |
-
if end is None:
|
| 153 |
-
end = start
|
| 154 |
-
item["start"] = start
|
| 155 |
-
item["end"] = end
|
| 156 |
-
return result
|
| 157 |
-
|
| 158 |
-
@staticmethod
|
| 159 |
-
def download_model(
|
| 160 |
-
model_size: str,
|
| 161 |
-
download_root: str,
|
| 162 |
-
progress: gr.Progress
|
| 163 |
-
):
|
| 164 |
-
progress(0, 'Initializing model..')
|
| 165 |
-
print(f'Downloading {model_size} to "{download_root}"....')
|
| 166 |
-
|
| 167 |
-
os.makedirs(download_root, exist_ok=True)
|
| 168 |
-
download_list = [
|
| 169 |
-
"model.safetensors",
|
| 170 |
-
"config.json",
|
| 171 |
-
"generation_config.json",
|
| 172 |
-
"preprocessor_config.json",
|
| 173 |
-
"tokenizer.json",
|
| 174 |
-
"tokenizer_config.json",
|
| 175 |
-
"added_tokens.json",
|
| 176 |
-
"special_tokens_map.json",
|
| 177 |
-
"vocab.json",
|
| 178 |
-
]
|
| 179 |
-
|
| 180 |
-
if model_size.startswith("distil"):
|
| 181 |
-
repo_id = f"distil-whisper/{model_size}"
|
| 182 |
-
else:
|
| 183 |
-
repo_id = f"openai/whisper-{model_size}"
|
| 184 |
-
for item in download_list:
|
| 185 |
-
hf_hub_download(repo_id=repo_id, filename=item, local_dir=download_root)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/whisper/whisper_Inference.py
DELETED
|
@@ -1,101 +0,0 @@
|
|
| 1 |
-
import whisper
|
| 2 |
-
import gradio as gr
|
| 3 |
-
import time
|
| 4 |
-
from typing import BinaryIO, Union, Tuple, List
|
| 5 |
-
import numpy as np
|
| 6 |
-
import torch
|
| 7 |
-
import os
|
| 8 |
-
from argparse import Namespace
|
| 9 |
-
|
| 10 |
-
from modules.whisper.whisper_base import WhisperBase
|
| 11 |
-
from modules.whisper.whisper_parameter import *
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class WhisperInference(WhisperBase):
|
| 15 |
-
def __init__(self,
|
| 16 |
-
model_dir: str = os.path.join("models", "Whisper"),
|
| 17 |
-
diarization_model_dir: str = os.path.join("models", "Diarization"),
|
| 18 |
-
output_dir: str = os.path.join("outputs"),
|
| 19 |
-
):
|
| 20 |
-
super().__init__(
|
| 21 |
-
model_dir=model_dir,
|
| 22 |
-
output_dir=output_dir,
|
| 23 |
-
diarization_model_dir=diarization_model_dir
|
| 24 |
-
)
|
| 25 |
-
|
| 26 |
-
def transcribe(self,
|
| 27 |
-
audio: Union[str, np.ndarray, torch.Tensor],
|
| 28 |
-
progress: gr.Progress,
|
| 29 |
-
*whisper_params,
|
| 30 |
-
) -> Tuple[List[dict], float]:
|
| 31 |
-
"""
|
| 32 |
-
transcribe method for faster-whisper.
|
| 33 |
-
|
| 34 |
-
Parameters
|
| 35 |
-
----------
|
| 36 |
-
audio: Union[str, BinaryIO, np.ndarray]
|
| 37 |
-
Audio path or file binary or Audio numpy array
|
| 38 |
-
progress: gr.Progress
|
| 39 |
-
Indicator to show progress directly in gradio.
|
| 40 |
-
*whisper_params: tuple
|
| 41 |
-
Parameters related with whisper. This will be dealt with "WhisperParameters" data class
|
| 42 |
-
|
| 43 |
-
Returns
|
| 44 |
-
----------
|
| 45 |
-
segments_result: List[dict]
|
| 46 |
-
list of dicts that includes start, end timestamps and transcribed text
|
| 47 |
-
elapsed_time: float
|
| 48 |
-
elapsed time for transcription
|
| 49 |
-
"""
|
| 50 |
-
start_time = time.time()
|
| 51 |
-
params = WhisperParameters.as_value(*whisper_params)
|
| 52 |
-
|
| 53 |
-
if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
|
| 54 |
-
self.update_model(params.model_size, params.compute_type, progress)
|
| 55 |
-
|
| 56 |
-
def progress_callback(progress_value):
|
| 57 |
-
progress(progress_value, desc="Transcribing..")
|
| 58 |
-
|
| 59 |
-
segments_result = self.model.transcribe(audio=audio,
|
| 60 |
-
language=params.lang,
|
| 61 |
-
verbose=False,
|
| 62 |
-
beam_size=params.beam_size,
|
| 63 |
-
logprob_threshold=params.log_prob_threshold,
|
| 64 |
-
no_speech_threshold=params.no_speech_threshold,
|
| 65 |
-
task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
|
| 66 |
-
fp16=True if params.compute_type == "float16" else False,
|
| 67 |
-
best_of=params.best_of,
|
| 68 |
-
patience=params.patience,
|
| 69 |
-
temperature=params.temperature,
|
| 70 |
-
compression_ratio_threshold=params.compression_ratio_threshold,
|
| 71 |
-
progress_callback=progress_callback,)["segments"]
|
| 72 |
-
elapsed_time = time.time() - start_time
|
| 73 |
-
|
| 74 |
-
return segments_result, elapsed_time
|
| 75 |
-
|
| 76 |
-
def update_model(self,
|
| 77 |
-
model_size: str,
|
| 78 |
-
compute_type: str,
|
| 79 |
-
progress: gr.Progress,
|
| 80 |
-
):
|
| 81 |
-
"""
|
| 82 |
-
Update current model setting
|
| 83 |
-
|
| 84 |
-
Parameters
|
| 85 |
-
----------
|
| 86 |
-
model_size: str
|
| 87 |
-
Size of whisper model
|
| 88 |
-
compute_type: str
|
| 89 |
-
Compute type for transcription.
|
| 90 |
-
see more info : https://opennmt.net/CTranslate2/quantization.html
|
| 91 |
-
progress: gr.Progress
|
| 92 |
-
Indicator to show progress directly in gradio.
|
| 93 |
-
"""
|
| 94 |
-
progress(0, desc="Initializing Model..")
|
| 95 |
-
self.current_compute_type = compute_type
|
| 96 |
-
self.current_model_size = model_size
|
| 97 |
-
self.model = whisper.load_model(
|
| 98 |
-
name=model_size,
|
| 99 |
-
device=self.device,
|
| 100 |
-
download_root=self.model_dir
|
| 101 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/whisper/whisper_base.py
DELETED
|
@@ -1,436 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import torch
|
| 3 |
-
import whisper
|
| 4 |
-
import gradio as gr
|
| 5 |
-
from abc import ABC, abstractmethod
|
| 6 |
-
from typing import BinaryIO, Union, Tuple, List
|
| 7 |
-
import numpy as np
|
| 8 |
-
from datetime import datetime
|
| 9 |
-
from faster_whisper.vad import VadOptions
|
| 10 |
-
from dataclasses import astuple
|
| 11 |
-
|
| 12 |
-
from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
|
| 13 |
-
from modules.utils.youtube_manager import get_ytdata, get_ytaudio
|
| 14 |
-
from modules.utils.files_manager import get_media_files, format_gradio_files
|
| 15 |
-
from modules.whisper.whisper_parameter import *
|
| 16 |
-
from modules.diarize.diarizer import Diarizer
|
| 17 |
-
from modules.vad.silero_vad import SileroVAD
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
class WhisperBase(ABC):
|
| 21 |
-
def __init__(self,
|
| 22 |
-
model_dir: str = os.path.join("models", "Whisper"),
|
| 23 |
-
diarization_model_dir: str = os.path.join("models", "Diarization"),
|
| 24 |
-
output_dir: str = os.path.join("outputs"),
|
| 25 |
-
):
|
| 26 |
-
self.model_dir = model_dir
|
| 27 |
-
self.output_dir = output_dir
|
| 28 |
-
os.makedirs(self.output_dir, exist_ok=True)
|
| 29 |
-
os.makedirs(self.model_dir, exist_ok=True)
|
| 30 |
-
self.diarizer = Diarizer(
|
| 31 |
-
model_dir=diarization_model_dir
|
| 32 |
-
)
|
| 33 |
-
self.vad = SileroVAD()
|
| 34 |
-
|
| 35 |
-
self.model = None
|
| 36 |
-
self.current_model_size = None
|
| 37 |
-
self.available_models = whisper.available_models()
|
| 38 |
-
self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
|
| 39 |
-
self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
|
| 40 |
-
self.device = self.get_device()
|
| 41 |
-
self.available_compute_types = ["float16", "float32"]
|
| 42 |
-
self.current_compute_type = "float16" if self.device == "cuda" else "float32"
|
| 43 |
-
|
| 44 |
-
@abstractmethod
|
| 45 |
-
def transcribe(self,
|
| 46 |
-
audio: Union[str, BinaryIO, np.ndarray],
|
| 47 |
-
progress: gr.Progress,
|
| 48 |
-
*whisper_params,
|
| 49 |
-
):
|
| 50 |
-
"""Inference whisper model to transcribe"""
|
| 51 |
-
pass
|
| 52 |
-
|
| 53 |
-
@abstractmethod
|
| 54 |
-
def update_model(self,
|
| 55 |
-
model_size: str,
|
| 56 |
-
compute_type: str,
|
| 57 |
-
progress: gr.Progress
|
| 58 |
-
):
|
| 59 |
-
"""Initialize whisper model"""
|
| 60 |
-
pass
|
| 61 |
-
|
| 62 |
-
def run(self,
|
| 63 |
-
audio: Union[str, BinaryIO, np.ndarray],
|
| 64 |
-
progress: gr.Progress,
|
| 65 |
-
*whisper_params,
|
| 66 |
-
) -> Tuple[List[dict], float]:
|
| 67 |
-
"""
|
| 68 |
-
Run transcription with conditional pre-processing and post-processing.
|
| 69 |
-
The VAD will be performed to remove noise from the audio input in pre-processing, if enabled.
|
| 70 |
-
The diarization will be performed in post-processing, if enabled.
|
| 71 |
-
|
| 72 |
-
Parameters
|
| 73 |
-
----------
|
| 74 |
-
audio: Union[str, BinaryIO, np.ndarray]
|
| 75 |
-
Audio input. This can be file path or binary type.
|
| 76 |
-
progress: gr.Progress
|
| 77 |
-
Indicator to show progress directly in gradio.
|
| 78 |
-
*whisper_params: tuple
|
| 79 |
-
Parameters related with whisper. This will be dealt with "WhisperParameters" data class
|
| 80 |
-
|
| 81 |
-
Returns
|
| 82 |
-
----------
|
| 83 |
-
segments_result: List[dict]
|
| 84 |
-
list of dicts that includes start, end timestamps and transcribed text
|
| 85 |
-
elapsed_time: float
|
| 86 |
-
elapsed time for running
|
| 87 |
-
"""
|
| 88 |
-
params = WhisperParameters.as_value(*whisper_params)
|
| 89 |
-
|
| 90 |
-
if params.lang == "Automatic Detection":
|
| 91 |
-
params.lang = None
|
| 92 |
-
else:
|
| 93 |
-
language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
|
| 94 |
-
params.lang = language_code_dict[params.lang]
|
| 95 |
-
|
| 96 |
-
speech_chunks = None
|
| 97 |
-
if params.vad_filter:
|
| 98 |
-
# Explicit value set for float('inf') from gr.Number()
|
| 99 |
-
if params.max_speech_duration_s >= 9999:
|
| 100 |
-
params.max_speech_duration_s = float('inf')
|
| 101 |
-
|
| 102 |
-
vad_options = VadOptions(
|
| 103 |
-
threshold=params.threshold,
|
| 104 |
-
min_speech_duration_ms=params.min_speech_duration_ms,
|
| 105 |
-
max_speech_duration_s=params.max_speech_duration_s,
|
| 106 |
-
min_silence_duration_ms=params.min_silence_duration_ms,
|
| 107 |
-
speech_pad_ms=params.speech_pad_ms
|
| 108 |
-
)
|
| 109 |
-
|
| 110 |
-
audio, speech_chunks = self.vad.run(
|
| 111 |
-
audio=audio,
|
| 112 |
-
vad_parameters=vad_options,
|
| 113 |
-
progress=progress
|
| 114 |
-
)
|
| 115 |
-
|
| 116 |
-
result, elapsed_time = self.transcribe(
|
| 117 |
-
audio,
|
| 118 |
-
progress,
|
| 119 |
-
*astuple(params)
|
| 120 |
-
)
|
| 121 |
-
|
| 122 |
-
if params.vad_filter:
|
| 123 |
-
result = self.vad.restore_speech_timestamps(
|
| 124 |
-
segments=result,
|
| 125 |
-
speech_chunks=speech_chunks,
|
| 126 |
-
)
|
| 127 |
-
|
| 128 |
-
if params.is_diarize:
|
| 129 |
-
result, elapsed_time_diarization = self.diarizer.run(
|
| 130 |
-
audio=audio,
|
| 131 |
-
use_auth_token=params.hf_token,
|
| 132 |
-
transcribed_result=result,
|
| 133 |
-
)
|
| 134 |
-
elapsed_time += elapsed_time_diarization
|
| 135 |
-
return result, elapsed_time
|
| 136 |
-
|
| 137 |
-
def transcribe_file(self,
|
| 138 |
-
files: list,
|
| 139 |
-
input_folder_path: str,
|
| 140 |
-
file_format: str,
|
| 141 |
-
add_timestamp: bool,
|
| 142 |
-
progress=gr.Progress(),
|
| 143 |
-
*whisper_params,
|
| 144 |
-
) -> list:
|
| 145 |
-
"""
|
| 146 |
-
Write subtitle file from Files
|
| 147 |
-
|
| 148 |
-
Parameters
|
| 149 |
-
----------
|
| 150 |
-
files: list
|
| 151 |
-
List of files to transcribe from gr.Files()
|
| 152 |
-
input_folder_path: str
|
| 153 |
-
Input folder path to transcribe from gr.Textbox(). If this is provided, `files` will be ignored and
|
| 154 |
-
this will be used instead.
|
| 155 |
-
file_format: str
|
| 156 |
-
Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
|
| 157 |
-
add_timestamp: bool
|
| 158 |
-
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
|
| 159 |
-
progress: gr.Progress
|
| 160 |
-
Indicator to show progress directly in gradio.
|
| 161 |
-
*whisper_params: tuple
|
| 162 |
-
Parameters related with whisper. This will be dealt with "WhisperParameters" data class
|
| 163 |
-
|
| 164 |
-
Returns
|
| 165 |
-
----------
|
| 166 |
-
result_str:
|
| 167 |
-
Result of transcription to return to gr.Textbox()
|
| 168 |
-
result_file_path:
|
| 169 |
-
Output file path to return to gr.Files()
|
| 170 |
-
"""
|
| 171 |
-
try:
|
| 172 |
-
if input_folder_path:
|
| 173 |
-
files = get_media_files(input_folder_path)
|
| 174 |
-
files = format_gradio_files(files)
|
| 175 |
-
|
| 176 |
-
files_info = {}
|
| 177 |
-
for file in files:
|
| 178 |
-
transcribed_segments, time_for_task = self.run(
|
| 179 |
-
file.name,
|
| 180 |
-
progress,
|
| 181 |
-
*whisper_params,
|
| 182 |
-
)
|
| 183 |
-
|
| 184 |
-
file_name, file_ext = os.path.splitext(os.path.basename(file.name))
|
| 185 |
-
subtitle, file_path = self.generate_and_write_file(
|
| 186 |
-
file_name=file_name,
|
| 187 |
-
transcribed_segments=transcribed_segments,
|
| 188 |
-
add_timestamp=add_timestamp,
|
| 189 |
-
file_format=file_format,
|
| 190 |
-
output_dir=self.output_dir
|
| 191 |
-
)
|
| 192 |
-
files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
|
| 193 |
-
|
| 194 |
-
total_result = ''
|
| 195 |
-
total_time = 0
|
| 196 |
-
for file_name, info in files_info.items():
|
| 197 |
-
total_result += '------------------------------------\n'
|
| 198 |
-
total_result += f'{file_name}\n\n'
|
| 199 |
-
total_result += f'{info["subtitle"]}'
|
| 200 |
-
total_time += info["time_for_task"]
|
| 201 |
-
|
| 202 |
-
result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
|
| 203 |
-
result_file_path = [info['path'] for info in files_info.values()]
|
| 204 |
-
|
| 205 |
-
return [result_str, result_file_path]
|
| 206 |
-
|
| 207 |
-
except Exception as e:
|
| 208 |
-
print(f"Error transcribing file: {e}")
|
| 209 |
-
finally:
|
| 210 |
-
self.release_cuda_memory()
|
| 211 |
-
if not files:
|
| 212 |
-
self.remove_input_files([file.name for file in files])
|
| 213 |
-
|
| 214 |
-
def transcribe_mic(self,
|
| 215 |
-
mic_audio: str,
|
| 216 |
-
file_format: str,
|
| 217 |
-
progress=gr.Progress(),
|
| 218 |
-
*whisper_params,
|
| 219 |
-
) -> list:
|
| 220 |
-
"""
|
| 221 |
-
Write subtitle file from microphone
|
| 222 |
-
|
| 223 |
-
Parameters
|
| 224 |
-
----------
|
| 225 |
-
mic_audio: str
|
| 226 |
-
Audio file path from gr.Microphone()
|
| 227 |
-
file_format: str
|
| 228 |
-
Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
|
| 229 |
-
progress: gr.Progress
|
| 230 |
-
Indicator to show progress directly in gradio.
|
| 231 |
-
*whisper_params: tuple
|
| 232 |
-
Parameters related with whisper. This will be dealt with "WhisperParameters" data class
|
| 233 |
-
|
| 234 |
-
Returns
|
| 235 |
-
----------
|
| 236 |
-
result_str:
|
| 237 |
-
Result of transcription to return to gr.Textbox()
|
| 238 |
-
result_file_path:
|
| 239 |
-
Output file path to return to gr.Files()
|
| 240 |
-
"""
|
| 241 |
-
try:
|
| 242 |
-
progress(0, desc="Loading Audio..")
|
| 243 |
-
transcribed_segments, time_for_task = self.run(
|
| 244 |
-
mic_audio,
|
| 245 |
-
progress,
|
| 246 |
-
*whisper_params,
|
| 247 |
-
)
|
| 248 |
-
progress(1, desc="Completed!")
|
| 249 |
-
|
| 250 |
-
subtitle, result_file_path = self.generate_and_write_file(
|
| 251 |
-
file_name="Mic",
|
| 252 |
-
transcribed_segments=transcribed_segments,
|
| 253 |
-
add_timestamp=True,
|
| 254 |
-
file_format=file_format,
|
| 255 |
-
output_dir=self.output_dir
|
| 256 |
-
)
|
| 257 |
-
|
| 258 |
-
result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
|
| 259 |
-
return [result_str, result_file_path]
|
| 260 |
-
except Exception as e:
|
| 261 |
-
print(f"Error transcribing file: {e}")
|
| 262 |
-
finally:
|
| 263 |
-
self.release_cuda_memory()
|
| 264 |
-
self.remove_input_files([mic_audio])
|
| 265 |
-
|
| 266 |
-
def transcribe_youtube(self,
|
| 267 |
-
youtube_link: str,
|
| 268 |
-
file_format: str,
|
| 269 |
-
add_timestamp: bool,
|
| 270 |
-
progress=gr.Progress(),
|
| 271 |
-
*whisper_params,
|
| 272 |
-
) -> list:
|
| 273 |
-
"""
|
| 274 |
-
Write subtitle file from Youtube
|
| 275 |
-
|
| 276 |
-
Parameters
|
| 277 |
-
----------
|
| 278 |
-
youtube_link: str
|
| 279 |
-
URL of the Youtube video to transcribe from gr.Textbox()
|
| 280 |
-
file_format: str
|
| 281 |
-
Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
|
| 282 |
-
add_timestamp: bool
|
| 283 |
-
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
|
| 284 |
-
progress: gr.Progress
|
| 285 |
-
Indicator to show progress directly in gradio.
|
| 286 |
-
*whisper_params: tuple
|
| 287 |
-
Parameters related with whisper. This will be dealt with "WhisperParameters" data class
|
| 288 |
-
|
| 289 |
-
Returns
|
| 290 |
-
----------
|
| 291 |
-
result_str:
|
| 292 |
-
Result of transcription to return to gr.Textbox()
|
| 293 |
-
result_file_path:
|
| 294 |
-
Output file path to return to gr.Files()
|
| 295 |
-
"""
|
| 296 |
-
try:
|
| 297 |
-
progress(0, desc="Loading Audio from Youtube..")
|
| 298 |
-
yt = get_ytdata(youtube_link)
|
| 299 |
-
audio = get_ytaudio(yt)
|
| 300 |
-
|
| 301 |
-
transcribed_segments, time_for_task = self.run(
|
| 302 |
-
audio,
|
| 303 |
-
progress,
|
| 304 |
-
*whisper_params,
|
| 305 |
-
)
|
| 306 |
-
|
| 307 |
-
progress(1, desc="Completed!")
|
| 308 |
-
|
| 309 |
-
file_name = safe_filename(yt.title)
|
| 310 |
-
subtitle, result_file_path = self.generate_and_write_file(
|
| 311 |
-
file_name=file_name,
|
| 312 |
-
transcribed_segments=transcribed_segments,
|
| 313 |
-
add_timestamp=add_timestamp,
|
| 314 |
-
file_format=file_format,
|
| 315 |
-
output_dir=self.output_dir
|
| 316 |
-
)
|
| 317 |
-
result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
|
| 318 |
-
|
| 319 |
-
return [result_str, result_file_path]
|
| 320 |
-
|
| 321 |
-
except Exception as e:
|
| 322 |
-
print(f"Error transcribing file: {e}")
|
| 323 |
-
finally:
|
| 324 |
-
try:
|
| 325 |
-
if 'yt' not in locals():
|
| 326 |
-
yt = get_ytdata(youtube_link)
|
| 327 |
-
file_path = get_ytaudio(yt)
|
| 328 |
-
else:
|
| 329 |
-
file_path = get_ytaudio(yt)
|
| 330 |
-
|
| 331 |
-
self.release_cuda_memory()
|
| 332 |
-
self.remove_input_files([file_path])
|
| 333 |
-
except Exception as cleanup_error:
|
| 334 |
-
pass
|
| 335 |
-
|
| 336 |
-
@staticmethod
|
| 337 |
-
def generate_and_write_file(file_name: str,
|
| 338 |
-
transcribed_segments: list,
|
| 339 |
-
add_timestamp: bool,
|
| 340 |
-
file_format: str,
|
| 341 |
-
output_dir: str
|
| 342 |
-
) -> str:
|
| 343 |
-
"""
|
| 344 |
-
Writes subtitle file
|
| 345 |
-
|
| 346 |
-
Parameters
|
| 347 |
-
----------
|
| 348 |
-
file_name: str
|
| 349 |
-
Output file name
|
| 350 |
-
transcribed_segments: list
|
| 351 |
-
Text segments transcribed from audio
|
| 352 |
-
add_timestamp: bool
|
| 353 |
-
Determines whether to add a timestamp to the end of the filename.
|
| 354 |
-
file_format: str
|
| 355 |
-
File format to write. Supported formats: [SRT, WebVTT, txt]
|
| 356 |
-
output_dir: str
|
| 357 |
-
Directory path of the output
|
| 358 |
-
|
| 359 |
-
Returns
|
| 360 |
-
----------
|
| 361 |
-
content: str
|
| 362 |
-
Result of the transcription
|
| 363 |
-
output_path: str
|
| 364 |
-
output file path
|
| 365 |
-
"""
|
| 366 |
-
if add_timestamp:
|
| 367 |
-
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
| 368 |
-
output_path = os.path.join(output_dir, f"{file_name}-{timestamp}")
|
| 369 |
-
else:
|
| 370 |
-
output_path = os.path.join(output_dir, f"{file_name}")
|
| 371 |
-
|
| 372 |
-
if file_format == "SRT":
|
| 373 |
-
content = get_srt(transcribed_segments)
|
| 374 |
-
output_path += '.srt'
|
| 375 |
-
|
| 376 |
-
elif file_format == "WebVTT":
|
| 377 |
-
content = get_vtt(transcribed_segments)
|
| 378 |
-
output_path += '.vtt'
|
| 379 |
-
|
| 380 |
-
elif file_format == "txt":
|
| 381 |
-
content = get_txt(transcribed_segments)
|
| 382 |
-
output_path += '.txt'
|
| 383 |
-
|
| 384 |
-
write_file(content, output_path)
|
| 385 |
-
return content, output_path
|
| 386 |
-
|
| 387 |
-
@staticmethod
|
| 388 |
-
def format_time(elapsed_time: float) -> str:
|
| 389 |
-
"""
|
| 390 |
-
Get {hours} {minutes} {seconds} time format string
|
| 391 |
-
|
| 392 |
-
Parameters
|
| 393 |
-
----------
|
| 394 |
-
elapsed_time: str
|
| 395 |
-
Elapsed time for transcription
|
| 396 |
-
|
| 397 |
-
Returns
|
| 398 |
-
----------
|
| 399 |
-
Time format string
|
| 400 |
-
"""
|
| 401 |
-
hours, rem = divmod(elapsed_time, 3600)
|
| 402 |
-
minutes, seconds = divmod(rem, 60)
|
| 403 |
-
|
| 404 |
-
time_str = ""
|
| 405 |
-
if hours:
|
| 406 |
-
time_str += f"{hours} hours "
|
| 407 |
-
if minutes:
|
| 408 |
-
time_str += f"{minutes} minutes "
|
| 409 |
-
seconds = round(seconds)
|
| 410 |
-
time_str += f"{seconds} seconds"
|
| 411 |
-
|
| 412 |
-
return time_str.strip()
|
| 413 |
-
|
| 414 |
-
@staticmethod
|
| 415 |
-
def get_device():
|
| 416 |
-
if torch.cuda.is_available():
|
| 417 |
-
return "cuda"
|
| 418 |
-
elif torch.backends.mps.is_available():
|
| 419 |
-
return "mps"
|
| 420 |
-
else:
|
| 421 |
-
return "cpu"
|
| 422 |
-
|
| 423 |
-
@staticmethod
|
| 424 |
-
def release_cuda_memory():
|
| 425 |
-
if torch.cuda.is_available():
|
| 426 |
-
torch.cuda.empty_cache()
|
| 427 |
-
torch.cuda.reset_max_memory_allocated()
|
| 428 |
-
|
| 429 |
-
@staticmethod
|
| 430 |
-
def remove_input_files(file_paths: List[str]):
|
| 431 |
-
if not file_paths:
|
| 432 |
-
return
|
| 433 |
-
|
| 434 |
-
for file_path in file_paths:
|
| 435 |
-
if file_path and os.path.exists(file_path):
|
| 436 |
-
os.remove(file_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/whisper/whisper_factory.py
DELETED
|
@@ -1,81 +0,0 @@
|
|
| 1 |
-
from typing import Optional
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
from modules.whisper.faster_whisper_inference import FasterWhisperInference
|
| 5 |
-
from modules.whisper.whisper_Inference import WhisperInference
|
| 6 |
-
from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
|
| 7 |
-
from modules.whisper.whisper_base import WhisperBase
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class WhisperFactory:
|
| 11 |
-
@staticmethod
|
| 12 |
-
def create_whisper_inference(
|
| 13 |
-
whisper_type: str,
|
| 14 |
-
whisper_model_dir: str = os.path.join("models", "Whisper"),
|
| 15 |
-
faster_whisper_model_dir: str = os.path.join("models", "Whisper", "faster-whisper"),
|
| 16 |
-
insanely_fast_whisper_model_dir: str = os.path.join("models", "Whisper", "insanely-fast-whisper"),
|
| 17 |
-
diarization_model_dir: str = os.path.join("models", "Diarization"),
|
| 18 |
-
output_dir: str = os.path.join("outputs"),
|
| 19 |
-
) -> "WhisperBase":
|
| 20 |
-
"""
|
| 21 |
-
Create a whisper inference class based on the provided whisper_type.
|
| 22 |
-
|
| 23 |
-
Parameters
|
| 24 |
-
----------
|
| 25 |
-
whisper_type : str
|
| 26 |
-
The type of Whisper implementation to use. Supported values (case-insensitive):
|
| 27 |
-
- "faster-whisper": https://github.com/openai/whisper
|
| 28 |
-
- "whisper": https://github.com/openai/whisper
|
| 29 |
-
- "insanely-fast-whisper": https://github.com/Vaibhavs10/insanely-fast-whisper
|
| 30 |
-
whisper_model_dir : str
|
| 31 |
-
Directory path for the Whisper model.
|
| 32 |
-
faster_whisper_model_dir : str
|
| 33 |
-
Directory path for the Faster Whisper model.
|
| 34 |
-
insanely_fast_whisper_model_dir : str
|
| 35 |
-
Directory path for the Insanely Fast Whisper model.
|
| 36 |
-
diarization_model_dir : str
|
| 37 |
-
Directory path for the diarization model.
|
| 38 |
-
output_dir : str
|
| 39 |
-
Directory path where output files will be saved.
|
| 40 |
-
|
| 41 |
-
Returns
|
| 42 |
-
-------
|
| 43 |
-
WhisperBase
|
| 44 |
-
An instance of the appropriate whisper inference class based on the whisper_type.
|
| 45 |
-
"""
|
| 46 |
-
# Temporal fix of the bug : https://github.com/jhj0517/Whisper-WebUI/issues/144
|
| 47 |
-
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
| 48 |
-
|
| 49 |
-
whisper_type = whisper_type.lower().strip()
|
| 50 |
-
|
| 51 |
-
faster_whisper_typos = ["faster_whisper", "faster-whisper", "fasterwhisper"]
|
| 52 |
-
whisper_typos = ["whisper"]
|
| 53 |
-
insanely_fast_whisper_typos = [
|
| 54 |
-
"insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
|
| 55 |
-
"insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"
|
| 56 |
-
]
|
| 57 |
-
|
| 58 |
-
if whisper_type in faster_whisper_typos:
|
| 59 |
-
return FasterWhisperInference(
|
| 60 |
-
model_dir=faster_whisper_model_dir,
|
| 61 |
-
output_dir=output_dir,
|
| 62 |
-
diarization_model_dir=diarization_model_dir
|
| 63 |
-
)
|
| 64 |
-
elif whisper_type in whisper_typos:
|
| 65 |
-
return WhisperInference(
|
| 66 |
-
model_dir=whisper_model_dir,
|
| 67 |
-
output_dir=output_dir,
|
| 68 |
-
diarization_model_dir=diarization_model_dir
|
| 69 |
-
)
|
| 70 |
-
elif whisper_type in insanely_fast_whisper_typos:
|
| 71 |
-
return InsanelyFastWhisperInference(
|
| 72 |
-
model_dir=insanely_fast_whisper_model_dir,
|
| 73 |
-
output_dir=output_dir,
|
| 74 |
-
diarization_model_dir=diarization_model_dir
|
| 75 |
-
)
|
| 76 |
-
else:
|
| 77 |
-
return FasterWhisperInference(
|
| 78 |
-
model_dir=faster_whisper_model_dir,
|
| 79 |
-
output_dir=output_dir,
|
| 80 |
-
diarization_model_dir=diarization_model_dir
|
| 81 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modules/whisper/whisper_parameter.py
DELETED
|
@@ -1,277 +0,0 @@
|
|
| 1 |
-
from dataclasses import dataclass, fields
|
| 2 |
-
import gradio as gr
|
| 3 |
-
from typing import Optional
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
@dataclass
|
| 7 |
-
class WhisperParameters:
|
| 8 |
-
model_size: gr.Dropdown
|
| 9 |
-
lang: gr.Dropdown
|
| 10 |
-
is_translate: gr.Checkbox
|
| 11 |
-
beam_size: gr.Number
|
| 12 |
-
log_prob_threshold: gr.Number
|
| 13 |
-
no_speech_threshold: gr.Number
|
| 14 |
-
compute_type: gr.Dropdown
|
| 15 |
-
best_of: gr.Number
|
| 16 |
-
patience: gr.Number
|
| 17 |
-
condition_on_previous_text: gr.Checkbox
|
| 18 |
-
prompt_reset_on_temperature: gr.Slider
|
| 19 |
-
initial_prompt: gr.Textbox
|
| 20 |
-
temperature: gr.Slider
|
| 21 |
-
compression_ratio_threshold: gr.Number
|
| 22 |
-
vad_filter: gr.Checkbox
|
| 23 |
-
threshold: gr.Slider
|
| 24 |
-
min_speech_duration_ms: gr.Number
|
| 25 |
-
max_speech_duration_s: gr.Number
|
| 26 |
-
min_silence_duration_ms: gr.Number
|
| 27 |
-
speech_pad_ms: gr.Number
|
| 28 |
-
chunk_length_s: gr.Number
|
| 29 |
-
batch_size: gr.Number
|
| 30 |
-
is_diarize: gr.Checkbox
|
| 31 |
-
hf_token: gr.Textbox
|
| 32 |
-
diarization_device: gr.Dropdown
|
| 33 |
-
length_penalty: gr.Number
|
| 34 |
-
repetition_penalty: gr.Number
|
| 35 |
-
no_repeat_ngram_size: gr.Number
|
| 36 |
-
prefix: gr.Textbox
|
| 37 |
-
suppress_blank: gr.Checkbox
|
| 38 |
-
suppress_tokens: gr.Textbox
|
| 39 |
-
max_initial_timestamp: gr.Number
|
| 40 |
-
word_timestamps: gr.Checkbox
|
| 41 |
-
prepend_punctuations: gr.Textbox
|
| 42 |
-
append_punctuations: gr.Textbox
|
| 43 |
-
max_new_tokens: gr.Number
|
| 44 |
-
chunk_length: gr.Number
|
| 45 |
-
hallucination_silence_threshold: gr.Number
|
| 46 |
-
hotwords: gr.Textbox
|
| 47 |
-
language_detection_threshold: gr.Number
|
| 48 |
-
language_detection_segments: gr.Number
|
| 49 |
-
"""
|
| 50 |
-
A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing.
|
| 51 |
-
This data class is used to mitigate the key-value problem between Gradio components and function parameters.
|
| 52 |
-
Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471
|
| 53 |
-
See more about Gradio pre-processing: https://www.gradio.app/docs/components
|
| 54 |
-
|
| 55 |
-
Attributes
|
| 56 |
-
----------
|
| 57 |
-
model_size: gr.Dropdown
|
| 58 |
-
Whisper model size.
|
| 59 |
-
|
| 60 |
-
lang: gr.Dropdown
|
| 61 |
-
Source language of the file to transcribe.
|
| 62 |
-
|
| 63 |
-
is_translate: gr.Checkbox
|
| 64 |
-
Boolean value that determines whether to translate to English.
|
| 65 |
-
It's Whisper's feature to translate speech from another language directly into English end-to-end.
|
| 66 |
-
|
| 67 |
-
beam_size: gr.Number
|
| 68 |
-
Int value that is used for decoding option.
|
| 69 |
-
|
| 70 |
-
log_prob_threshold: gr.Number
|
| 71 |
-
If the average log probability over sampled tokens is below this value, treat as failed.
|
| 72 |
-
|
| 73 |
-
no_speech_threshold: gr.Number
|
| 74 |
-
If the no_speech probability is higher than this value AND
|
| 75 |
-
the average log probability over sampled tokens is below `log_prob_threshold`,
|
| 76 |
-
consider the segment as silent.
|
| 77 |
-
|
| 78 |
-
compute_type: gr.Dropdown
|
| 79 |
-
compute type for transcription.
|
| 80 |
-
see more info : https://opennmt.net/CTranslate2/quantization.html
|
| 81 |
-
|
| 82 |
-
best_of: gr.Number
|
| 83 |
-
Number of candidates when sampling with non-zero temperature.
|
| 84 |
-
|
| 85 |
-
patience: gr.Number
|
| 86 |
-
Beam search patience factor.
|
| 87 |
-
|
| 88 |
-
condition_on_previous_text: gr.Checkbox
|
| 89 |
-
if True, the previous output of the model is provided as a prompt for the next window;
|
| 90 |
-
disabling may make the text inconsistent across windows, but the model becomes less prone to
|
| 91 |
-
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
|
| 92 |
-
|
| 93 |
-
initial_prompt: gr.Textbox
|
| 94 |
-
Optional text to provide as a prompt for the first window. This can be used to provide, or
|
| 95 |
-
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
| 96 |
-
to make it more likely to predict those word correctly.
|
| 97 |
-
|
| 98 |
-
temperature: gr.Slider
|
| 99 |
-
Temperature for sampling. It can be a tuple of temperatures,
|
| 100 |
-
which will be successively used upon failures according to either
|
| 101 |
-
`compression_ratio_threshold` or `log_prob_threshold`.
|
| 102 |
-
|
| 103 |
-
compression_ratio_threshold: gr.Number
|
| 104 |
-
If the gzip compression ratio is above this value, treat as failed
|
| 105 |
-
|
| 106 |
-
vad_filter: gr.Checkbox
|
| 107 |
-
Enable the voice activity detection (VAD) to filter out parts of the audio
|
| 108 |
-
without speech. This step is using the Silero VAD model
|
| 109 |
-
https://github.com/snakers4/silero-vad.
|
| 110 |
-
|
| 111 |
-
threshold: gr.Slider
|
| 112 |
-
This parameter is related with Silero VAD. Speech threshold.
|
| 113 |
-
Silero VAD outputs speech probabilities for each audio chunk,
|
| 114 |
-
probabilities ABOVE this value are considered as SPEECH. It is better to tune this
|
| 115 |
-
parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
|
| 116 |
-
|
| 117 |
-
min_speech_duration_ms: gr.Number
|
| 118 |
-
This parameter is related with Silero VAD. Final speech chunks shorter min_speech_duration_ms are thrown out.
|
| 119 |
-
|
| 120 |
-
max_speech_duration_s: gr.Number
|
| 121 |
-
This parameter is related with Silero VAD. Maximum duration of speech chunks in seconds. Chunks longer
|
| 122 |
-
than max_speech_duration_s will be split at the timestamp of the last silence that
|
| 123 |
-
lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
|
| 124 |
-
split aggressively just before max_speech_duration_s.
|
| 125 |
-
|
| 126 |
-
min_silence_duration_ms: gr.Number
|
| 127 |
-
This parameter is related with Silero VAD. In the end of each speech chunk wait for min_silence_duration_ms
|
| 128 |
-
before separating it
|
| 129 |
-
|
| 130 |
-
speech_pad_ms: gr.Number
|
| 131 |
-
This parameter is related with Silero VAD. Final speech chunks are padded by speech_pad_ms each side
|
| 132 |
-
|
| 133 |
-
chunk_length_s: gr.Number
|
| 134 |
-
This parameter is related with insanely-fast-whisper pipe.
|
| 135 |
-
Maximum length of each chunk
|
| 136 |
-
|
| 137 |
-
batch_size: gr.Number
|
| 138 |
-
This parameter is related with insanely-fast-whisper pipe. Batch size to pass to the pipe
|
| 139 |
-
|
| 140 |
-
is_diarize: gr.Checkbox
|
| 141 |
-
This parameter is related with whisperx. Boolean value that determines whether to diarize or not.
|
| 142 |
-
|
| 143 |
-
hf_token: gr.Textbox
|
| 144 |
-
This parameter is related with whisperx. Huggingface token is needed to download diarization models.
|
| 145 |
-
Read more about : https://huggingface.co/pyannote/speaker-diarization-3.1#requirements
|
| 146 |
-
|
| 147 |
-
diarization_device: gr.Dropdown
|
| 148 |
-
This parameter is related with whisperx. Device to run diarization model
|
| 149 |
-
|
| 150 |
-
length_penalty:
|
| 151 |
-
This parameter is related to faster-whisper. Exponential length penalty constant.
|
| 152 |
-
|
| 153 |
-
repetition_penalty:
|
| 154 |
-
This parameter is related to faster-whisper. Penalty applied to the score of previously generated tokens
|
| 155 |
-
(set > 1 to penalize).
|
| 156 |
-
|
| 157 |
-
no_repeat_ngram_size:
|
| 158 |
-
This parameter is related to faster-whisper. Prevent repetitions of n-grams with this size (set 0 to disable).
|
| 159 |
-
|
| 160 |
-
prefix:
|
| 161 |
-
This parameter is related to faster-whisper. Optional text to provide as a prefix for the first window.
|
| 162 |
-
|
| 163 |
-
suppress_blank:
|
| 164 |
-
This parameter is related to faster-whisper. Suppress blank outputs at the beginning of the sampling.
|
| 165 |
-
|
| 166 |
-
suppress_tokens:
|
| 167 |
-
This parameter is related to faster-whisper. List of token IDs to suppress. -1 will suppress a default set
|
| 168 |
-
of symbols as defined in the model config.json file.
|
| 169 |
-
|
| 170 |
-
max_initial_timestamp:
|
| 171 |
-
This parameter is related to faster-whisper. The initial timestamp cannot be later than this.
|
| 172 |
-
|
| 173 |
-
word_timestamps:
|
| 174 |
-
This parameter is related to faster-whisper. Extract word-level timestamps using the cross-attention pattern
|
| 175 |
-
and dynamic time warping, and include the timestamps for each word in each segment.
|
| 176 |
-
|
| 177 |
-
prepend_punctuations:
|
| 178 |
-
This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
|
| 179 |
-
with the next word.
|
| 180 |
-
|
| 181 |
-
append_punctuations:
|
| 182 |
-
This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
|
| 183 |
-
with the previous word.
|
| 184 |
-
|
| 185 |
-
max_new_tokens:
|
| 186 |
-
This parameter is related to faster-whisper. Maximum number of new tokens to generate per-chunk. If not set,
|
| 187 |
-
the maximum will be set by the default max_length.
|
| 188 |
-
|
| 189 |
-
chunk_length:
|
| 190 |
-
This parameter is related to faster-whisper. The length of audio segments. If it is not None, it will overwrite the
|
| 191 |
-
default chunk_length of the FeatureExtractor.
|
| 192 |
-
|
| 193 |
-
hallucination_silence_threshold:
|
| 194 |
-
This parameter is related to faster-whisper. When word_timestamps is True, skip silent periods longer than this threshold
|
| 195 |
-
(in seconds) when a possible hallucination is detected.
|
| 196 |
-
|
| 197 |
-
hotwords:
|
| 198 |
-
This parameter is related to faster-whisper. Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.
|
| 199 |
-
|
| 200 |
-
language_detection_threshold:
|
| 201 |
-
This parameter is related to faster-whisper. If the maximum probability of the language tokens is higher than this value, the language is detected.
|
| 202 |
-
|
| 203 |
-
language_detection_segments:
|
| 204 |
-
This parameter is related to faster-whisper. Number of segments to consider for the language detection.
|
| 205 |
-
"""
|
| 206 |
-
|
| 207 |
-
def as_list(self) -> list:
|
| 208 |
-
"""
|
| 209 |
-
Converts the data class attributes into a list, Use in Gradio UI before Gradio pre-processing.
|
| 210 |
-
See more about Gradio pre-processing: : https://www.gradio.app/docs/components
|
| 211 |
-
|
| 212 |
-
Returns
|
| 213 |
-
----------
|
| 214 |
-
A list of Gradio components
|
| 215 |
-
"""
|
| 216 |
-
return [getattr(self, f.name) for f in fields(self)]
|
| 217 |
-
|
| 218 |
-
@staticmethod
|
| 219 |
-
def as_value(*args) -> 'WhisperValues':
|
| 220 |
-
"""
|
| 221 |
-
To use Whisper parameters in function after Gradio post-processing.
|
| 222 |
-
See more about Gradio post-processing: : https://www.gradio.app/docs/components
|
| 223 |
-
|
| 224 |
-
Returns
|
| 225 |
-
----------
|
| 226 |
-
WhisperValues
|
| 227 |
-
Data class that has values of parameters
|
| 228 |
-
"""
|
| 229 |
-
return WhisperValues(*args)
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
@dataclass
|
| 233 |
-
class WhisperValues:
|
| 234 |
-
model_size: str
|
| 235 |
-
lang: str
|
| 236 |
-
is_translate: bool
|
| 237 |
-
beam_size: int
|
| 238 |
-
log_prob_threshold: float
|
| 239 |
-
no_speech_threshold: float
|
| 240 |
-
compute_type: str
|
| 241 |
-
best_of: int
|
| 242 |
-
patience: float
|
| 243 |
-
condition_on_previous_text: bool
|
| 244 |
-
prompt_reset_on_temperature: float
|
| 245 |
-
initial_prompt: Optional[str]
|
| 246 |
-
temperature: float
|
| 247 |
-
compression_ratio_threshold: float
|
| 248 |
-
vad_filter: bool
|
| 249 |
-
threshold: float
|
| 250 |
-
min_speech_duration_ms: int
|
| 251 |
-
max_speech_duration_s: float
|
| 252 |
-
min_silence_duration_ms: int
|
| 253 |
-
speech_pad_ms: int
|
| 254 |
-
chunk_length_s: int
|
| 255 |
-
batch_size: int
|
| 256 |
-
is_diarize: bool
|
| 257 |
-
hf_token: str
|
| 258 |
-
diarization_device: str
|
| 259 |
-
length_penalty: float
|
| 260 |
-
repetition_penalty: float
|
| 261 |
-
no_repeat_ngram_size: int
|
| 262 |
-
prefix: Optional[str]
|
| 263 |
-
suppress_blank: bool
|
| 264 |
-
suppress_tokens: Optional[str]
|
| 265 |
-
max_initial_timestamp: float
|
| 266 |
-
word_timestamps: bool
|
| 267 |
-
prepend_punctuations: Optional[str]
|
| 268 |
-
append_punctuations: Optional[str]
|
| 269 |
-
max_new_tokens: Optional[int]
|
| 270 |
-
chunk_length: Optional[int]
|
| 271 |
-
hallucination_silence_threshold: Optional[float]
|
| 272 |
-
hotwords: Optional[str]
|
| 273 |
-
language_detection_threshold: Optional[float]
|
| 274 |
-
language_detection_segments: int
|
| 275 |
-
"""
|
| 276 |
-
A data class to use Whisper parameters.
|
| 277 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
outputs/outputs are saved here.txt
DELETED
|
File without changes
|
outputs/translations/outputs for translation are saved here.txt
DELETED
|
File without changes
|