Commit d860e14 · Parent(s): 7b7fdec
Add simulstreaming_whisper module, update requirements, improve Dockerfile and model handling
- .gitignore +57 -0
- Dockerfile +2 -3
- requirements.txt +3 -0
- server_wrapper.py +14 -1
- simul_whisper/__init__.py +0 -0
- simul_whisper/beam.py +17 -0
- simul_whisper/config.py +31 -0
- simul_whisper/eow_detection.py +68 -0
- simul_whisper/generation_progress.py +43 -0
- simul_whisper/simul_whisper.py +649 -0
- simul_whisper/whisper/__init__.py +160 -0
- simul_whisper/whisper/__main__.py +3 -0
- simul_whisper/whisper/assets/gpt2.tiktoken +0 -0
- simul_whisper/whisper/assets/mel_filters.npz +0 -0
- simul_whisper/whisper/assets/multilingual.tiktoken +0 -0
- simul_whisper/whisper/audio.py +157 -0
- simul_whisper/whisper/decoding.py +833 -0
- simul_whisper/whisper/model.py +382 -0
- simul_whisper/whisper/normalizers/__init__.py +2 -0
- simul_whisper/whisper/normalizers/basic.py +76 -0
- simul_whisper/whisper/normalizers/english.py +550 -0
- simul_whisper/whisper/timing.py +401 -0
- simul_whisper/whisper/tokenizer.py +395 -0
- simul_whisper/whisper/trans_nopad.py +501 -0
- simul_whisper/whisper/transcribe.py +467 -0
- simul_whisper/whisper/triton_ops.py +109 -0
- simul_whisper/whisper/utils.py +258 -0
- simul_whisper/whisper/version.py +1 -0
- simulstreaming_whisper.py +260 -0
- token_buffer.py +73 -0
- whisper_streaming/base.py +46 -0
- whisper_streaming/line_packet.py +93 -0
- whisper_streaming/silero_vad_iterator.py +150 -0
- whisper_streaming/vac_online_processor.py +111 -0
- whisper_streaming/whisper_online_main.py +224 -0
- whisper_streaming/whisper_server.py +177 -0
.gitignore
ADDED
@@ -0,0 +1,57 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# Virtual environments
+venv/
+ENV/
+env/
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+*~
+
+# Cache
+.cache/
+*.cache
+
+# Models (large files)
+*.pt
+*.pth
+*.bin
+
+# Logs
+logs/
+*.log
+
+# OS
+.DS_Store
+Thumbs.db
+
+# Local development
+.env
+.env.local
Dockerfile
CHANGED
@@ -18,9 +18,8 @@ RUN pip install --no-cache-dir -r requirements.txt
 # Copy the application code
 COPY . .
 
-#
-
-# RUN python -c "import torch; torch.hub.load('openai/whisper', 'large-v3')"
+# Pre-download the model during the build
+RUN python -c "import torch; torch.hub.load('openai/whisper', 'large-v3')"
 
 EXPOSE 7860
 
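For reference, the pre-download step could also use the vendored loader added later in this commit instead of torch.hub; a minimal sketch, assuming simul_whisper is importable in the build context:

# Hypothetical alternative pre-download step using the vendored loader;
# load_model("large-v3") fetches large-v3.pt into ~/.cache/whisper by default.
from simul_whisper.whisper import load_model
load_model("large-v3")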
requirements.txt
CHANGED
@@ -7,3 +7,6 @@ numpy>=1.24.0
 torch>=2.0.0
 transformers>=4.30.0
 torchaudio>=2.0.0
+tqdm
+tiktoken
+triton>=2.0.0,<3;platform_machine=="x86_64" and sys_platform=="linux" or sys_platform=="linux2"
server_wrapper.py
CHANGED
@@ -14,13 +14,26 @@ _asr = None
 _online = None
 
 
+def _get_model_path():
+    """Get the path to the Whisper model. Download if needed."""
+    import os
+    model_dir = os.path.expanduser('~/.cache/torch/hub/checkpoints')
+    model_path = os.path.join(model_dir, 'large-v3.pt')
+
+    if not os.path.exists(model_path):
+        print(f"Model not found at {model_path}. Downloading...")
+        import torch
+        torch.hub.load('openai/whisper', 'large-v3')
+
+    return model_path
+
 def _make_args():
     # Minimal args required by simul_asr_factory
     return SimpleNamespace(
         log_level='INFO',
         decoder=None,
         beams=1,
-        model_path=
+        model_path=_get_model_path(),
         cif_ckpt_path=None,
         frame_threshold=25,
         audio_min_len=0.0,
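A minimal usage sketch of the new model-path handling, assuming server_wrapper is importable as a module (hypothetical call sites, not part of the commit):

from server_wrapper import _get_model_path, _make_args

path = _get_model_path()   # downloads the checkpoint on first call, then reuses the cache
args = _make_args()        # args.model_path now points at the cached large-v3.pt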
simul_whisper/__init__.py
ADDED
File without changes
simul_whisper/beam.py
ADDED
@@ -0,0 +1,17 @@
+from .whisper.decoding import PyTorchInference
+
+# extension of PyTorchInference for beam search
+class BeamPyTorchInference(PyTorchInference):
+
+    def _kv_modules(self):
+        key_modules = [block.attn.key.cache_id for block in self.model.decoder.blocks]
+        value_modules = [block.attn.value.cache_id for block in self.model.decoder.blocks]
+        return key_modules + value_modules
+
+    def rearrange_kv_cache(self, source_indices):
+        if source_indices != list(range(len(source_indices))):
+            for module_cache_id in self._kv_modules():
+                self.kv_cache[module_cache_id] = self.kv_cache[module_cache_id][source_indices].detach()
+    from torch import Tensor
+    def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
+        return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
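A toy illustration (made-up tensors, not part of the commit) of what rearrange_kv_cache does: when beam search decides, say, that beam 2's prefix should occupy slots 0 and 1, it passes source_indices=[2, 2, 0], and every cached key/value tensor is re-indexed along its beam dimension:

import torch

cache = torch.arange(3).view(3, 1, 1).float()        # stand-in for a (beam, seq, dim) kv tensor
source_indices = [2, 2, 0]
if source_indices != list(range(len(source_indices))):  # same guard as above
    cache = cache[source_indices].detach()
print(cache.flatten().tolist())                      # [2.0, 2.0, 0.0]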
simul_whisper/config.py
ADDED
@@ -0,0 +1,31 @@
+# This code was originally in simul_whisper/transcriber/simul_whisper.py. It has been heavily adapted for SimulStreaming.
+
+from dataclasses import dataclass, field
+from typing import Literal
+
+@dataclass
+class SimulWhisperConfig:
+    '''Options that are common for all simul policies that could be implemented in SimulWhisper.'''
+    model_path: str
+    language: str = field(default="zh")
+    nonspeech_prob: float = 1.0
+    audio_min_len: float = 1.0
+    decoder_type: Literal["greedy","beam"] = "greedy"
+    beam_size: int = 5
+    task: Literal["transcribe","translate"] = "transcribe"
+    init_prompt: str = field(default=None)
+    static_init_prompt: str = field(default=None)
+    max_context_tokens: int = field(default=None)
+
+    logdir: str = field(default="logdir", metadata={"help": "Directory to save audio segments and tokens for debugging purposes."})
+
+@dataclass
+class AlignAttConfig(SimulWhisperConfig):
+    '''Options specific to the AlignAtt policy.'''
+    eval_data_path: str = "tmp"
+    segment_length: float = field(default=1.0, metadata={"help": "in seconds"})
+    frame_threshold: int = 4
+    rewind_threshold: int = 200  # in frames. Max value is 1500. Higher value turns rewinds off.
+    audio_max_len: float = 30.0
+    cif_ckpt_path: str = ""
+    never_fire: bool = False
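A minimal sketch (not part of the commit) constructing the AlignAtt policy config; the checkpoint path is an assumption, any local Whisper .pt file works:

from simul_whisper.config import AlignAttConfig

cfg = AlignAttConfig(
    model_path="large-v3.pt",   # hypothetical local checkpoint
    language="en",
    decoder_type="beam",
    beam_size=5,
    frame_threshold=25,         # frames of headroom before a token is committed
)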
simul_whisper/eow_detection.py
ADDED
@@ -0,0 +1,68 @@
+import torch
+
+# code for the end-of-word detection based on the CIF model proposed in Simul-Whisper
+
+def load_cif(cfg, n_audio_state, device):
+    """cfg: AlignAttConfig, n_audio_state: int, device: torch.device"""
+    cif_linear = torch.nn.Linear(n_audio_state, 1)
+    if cfg.cif_ckpt_path is None or not cfg.cif_ckpt_path:
+        if cfg.never_fire:
+            never_fire = True
+            always_fire = False
+        else:
+            always_fire = True
+            never_fire = False
+    else:
+        always_fire = False
+        never_fire = cfg.never_fire
+        map_location = None
+        if not torch.cuda.is_available():
+            map_location = torch.device('cpu')
+        checkpoint = torch.load(cfg.cif_ckpt_path, map_location=map_location)
+        cif_linear.load_state_dict(checkpoint)
+    cif_linear.to(device)
+    return cif_linear, always_fire, never_fire
+
+
+# from https://github.com/dqqcasia/mosst/blob/master/fairseq/models/speech_to_text/convtransformer_wav2vec_cif.py
+def resize(alphas, target_lengths, threshold=0.999):
+    """
+    alpha in thresh=1.0 | (0.0, +0.21)
+    target_lengths: if None, apply round and resize, else apply scaling
+    """
+    # sum
+    _num = alphas.sum(-1)
+    num = target_lengths.float()
+    # scaling
+    _alphas = alphas * (num / _num)[:, None].repeat(1, alphas.size(1))
+    # remove attention values that exceed the threshold
+    count = 0
+    while len(torch.where(_alphas > threshold)[0]):
+        count += 1
+        if count > 10:
+            break
+        xs, ys = torch.where(_alphas > threshold)
+        for x, y in zip(xs, ys):
+            if _alphas[x][y] >= threshold:
+                mask = _alphas[x].ne(0).float()
+                mean = 0.5 * _alphas[x].sum() / mask.sum()
+                _alphas[x] = _alphas[x] * 0.5 + mean * mask
+
+    return _alphas, _num
+
+def fire_at_boundary(chunked_encoder_feature: torch.Tensor, cif_linear):
+    content_mel_len = chunked_encoder_feature.shape[1]  # B, T, D
+    alphas = cif_linear(chunked_encoder_feature).squeeze(dim=2)  # B, T
+    alphas = torch.sigmoid(alphas)
+    decode_length = torch.round(alphas.sum(-1)).int()
+    alphas, _ = resize(alphas, decode_length)
+    alphas = alphas.squeeze(0)  # (T, )
+    threshold = 0.999
+    integrate = torch.cumsum(alphas[:-1], dim=0)  # ignore the peak value at the end of the content chunk
+    exceed_count = integrate[-1] // threshold
+    integrate = integrate - exceed_count*1.0  # subtract 1 every time integrate exceeds the threshold
+    important_positions = (integrate >= 0).nonzero(as_tuple=True)[0]
+    if important_positions.numel() == 0:
+        return False
+    else:
+        return important_positions[0] >= content_mel_len-2
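A sketch with made-up shapes (not part of the commit) of the end-of-word check; 1280 matches n_audio_state of the large models, and the random features and untrained linear head are placeholders:

import torch
from simul_whisper.eow_detection import fire_at_boundary

cif_linear = torch.nn.Linear(1280, 1)          # untrained stand-in for the CIF head
features = torch.randn(1, 150, 1280)           # (B, T, D) encoder output of the chunk
print(fire_at_boundary(features, cif_linear))  # True if a boundary fires near the chunk end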
simul_whisper/generation_progress.py
ADDED
@@ -0,0 +1,43 @@
+class Tokens:
+    def __init__(self, tokens):
+        self.tokens = tokens
+
+    # def clone(self):
+    #     return Tokens(self.tokens.clone())
+
+    def __str__(self):
+        return str(self.tokens.tolist())
+
+    def __repr__(self):
+        return self.__str__()
+
+class BeamTokens(Tokens):
+    def __init__(self, tokens, beam_size):
+        self.tokens = tokens
+        self.beam_size = beam_size
+
+    def clone(self):
+        return BeamTokens(self.tokens.clone(), self.beam_size)
+
+    def __str__(self):
+        return f"BeamTokens({self.tokens.tolist()}, beam_size={self.beam_size})"
+
+    def __repr__(self):
+        return self.__str__()
+
+    def as_text(self, tokenizer):
+        return tokenizer.decode(self.tokens)
+
+class Logits(Tokens):
+    def __init__(self, logits):
+        super().__init__(logits)
+
+    # def clone(self):
+    #     return Logits(self.tokens.clone(), self.beam_size)
+
+    def __str__(self):
+        return f"Logits({self.tokens.shape})"
+
+    def __repr__(self):
+        return self.__str__()
simul_whisper/simul_whisper.py
ADDED
@@ -0,0 +1,649 @@
+# This code was originally in simul_whisper/transcriber/simul_whisper.py. It has been heavily adapted for SimulStreaming.
+
+import os
+import logging
+
+import torch
+import torch.nn.functional as F
+
+from .whisper import load_model, DecodingOptions, tokenizer
+from .config import AlignAttConfig
+from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
+from .whisper.timing import median_filter
+from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
+from .beam import BeamPyTorchInference
+from .eow_detection import fire_at_boundary, load_cif
+import os
+
+from token_buffer import TokenBuffer
+
+import numpy as np
+from .generation_progress import *
+
+DEC_PAD = 50257
+logger = logging.getLogger(__name__)
+
+import sys
+import wave
+
+# New features added to the original version of Simul-Whisper:
+# - large-v3 model support
+# - translation support
+# - beam search
+# - prompt -- static vs. non-static
+# - context
+class PaddedAlignAttWhisper:
+    def __init__(self, cfg: AlignAttConfig) -> None:
+        self.logdir_i = 0
+        self.log_segments = 0
+        if cfg.logdir is not None and not os.path.exists(cfg.logdir):
+            os.makedirs(cfg.logdir)
+        model_name = os.path.basename(cfg.model_path).replace(".pt", "")
+        model_path = os.path.dirname(os.path.abspath(cfg.model_path))
+        self.model = load_model(name=model_name, download_root=model_path)
+
+        logger.info(f"Model dimensions: {self.model.dims}")
+
+        self.decode_options = DecodingOptions(
+            language = cfg.language,
+            without_timestamps = True,
+            task=cfg.task
+        )
+        self.tokenizer_is_multilingual = not model_name.endswith(".en")
+        self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
+        self.detected_language = cfg.language if cfg.language != "auto" else None
+
+        self.max_text_len = self.model.dims.n_text_ctx
+        self.num_decoder_layers = len(self.model.decoder.blocks)
+        self.cfg = cfg
+
+        # model to detect end-of-word boundary at the end of the segment
+        self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg,
+                n_audio_state=self.model.dims.n_audio_state,
+                device=self.model.device)
+
+        # install hooks to access encoder-decoder attention
+        self.dec_attns = []
+        def layer_hook(module, net_input, net_output):
+            # net_output[1]: B*num_head*token_len*audio_len
+            t = F.softmax(net_output[1], dim=-1)
+            self.dec_attns.append(t.squeeze(0))
+        for b in self.model.decoder.blocks:
+            b.cross_attn.register_forward_hook(layer_hook)
+
+        self.kv_cache = {}
+        def kv_hook(module: torch.nn.Linear, _, net_output: torch.Tensor):
+            if module.cache_id not in self.kv_cache or net_output.shape[1] > self.max_text_len:
+                # save as-is, for the first token or cross attention
+                self.kv_cache[module.cache_id] = net_output
+            else:
+                x = self.kv_cache[module.cache_id]
+                self.kv_cache[module.cache_id] = torch.cat([x, net_output], dim=1).detach()
+            return self.kv_cache[module.cache_id]
+
+        for i,b in enumerate(self.model.decoder.blocks):
+            b.attn.key.register_forward_hook(kv_hook)
+            b.attn.value.register_forward_hook(kv_hook)
+            b.cross_attn.key.register_forward_hook(kv_hook)
+            b.cross_attn.value.register_forward_hook(kv_hook)
+
+        self.align_source = {}
+        self.num_align_heads = 0
+        for layer_rank, head_id in self.model.alignment_heads.indices().T:
+            layer_rank = layer_rank.item()
+            heads = self.align_source.get(layer_rank, [])
+            heads.append((self.num_align_heads, head_id.item()))
+            self.align_source[layer_rank] = heads
+            self.num_align_heads += 1
+
+
+        # tokens to be suppressed from decoding, to prevent hallucinations
+        suppress_tokens = [
+            self.tokenizer.transcribe,
+            self.tokenizer.translate,
+            self.tokenizer.sot,
+            self.tokenizer.sot_prev,
+            self.tokenizer.sot_lm,
+            # self.tokenizer.eot
+            self.tokenizer.no_timestamps,  # added by DM
+        ] + list(self.tokenizer.all_language_tokens)  # added by DM
+        if self.tokenizer.no_speech is not None:
+            suppress_tokens.append(self.tokenizer.no_speech)
+        suppress_tokens = tuple(sorted(set(suppress_tokens)))
+        logger.debug(f"Suppress tokens: {suppress_tokens}")
+        sup_tokens = SuppressTokens(suppress_tokens)
+        self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None)
+        # blank tokens are suppressed for new segments near line 334
+
+        # it's going to be regenerated after lang id
+        self.segments = []
+        self.init_tokens()
+
+        self.last_attend_frame = -self.cfg.rewind_threshold
+
+        if self.cfg.max_context_tokens is None:
+            self.max_context_tokens = self.max_text_len
+        else:
+            self.max_context_tokens = self.cfg.max_context_tokens
+        self.init_context()
+
+        # decoder type: greedy or beam
+        if cfg.decoder_type == "greedy":
+            logger.info("Using greedy decoder")
+            self.token_decoder = GreedyDecoder(0.0, self.tokenizer.eot)
+            self.decoder_type = "greedy"
+
+        elif cfg.decoder_type == "beam":
+            self.decoder_type = "beam"
+            self.inference = BeamPyTorchInference(self.model, self.initial_token_length)
+            self.inference.kv_cache = self.kv_cache
+
+            self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
+
+    def create_tokenizer(self, language=None):
+        self.tokenizer = tokenizer.get_tokenizer(
+            multilingual=self.tokenizer_is_multilingual,
+            language=language,
+            num_languages=self.model.num_languages,
+            task=self.decode_options.task
+        )
+
+    def init_context(self):
+        kw = {'tokenizer': self.tokenizer,
+              'device': self.model.device,
+              'prefix_token_ids': [self.tokenizer.sot_prev]}
+        self.context = TokenBuffer.empty(**kw)
+        if self.cfg.static_init_prompt is not None:
+            self.context = TokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
+        if self.cfg.init_prompt is not None:
+            self.context.text += self.cfg.init_prompt
+
+    def init_tokens(self):
+        logger.debug(f"init tokens, {len(self.segments)}")
+        # init tokens (mandatory prompt)
+        self.initial_tokens = torch.tensor(
+            self.tokenizer.sot_sequence_including_notimestamps,
+            dtype=torch.long,
+            device=self.model.device).unsqueeze(0)
+        self.initial_token_length = self.initial_tokens.shape[1]
+        self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
+        # self.segments = []
+        logger.debug(f"init tokens after, {len(self.segments)}")
+        self.tokens = [self.initial_tokens]
+
+    def trim_context(self):
+        logger.info("Trimming context")
+        c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids)
+        # logger.debug(f"c= {len(self.context.as_token_ids())}, {len(self.context.prefix_token_ids)}")
+        logger.info(f"Context text: {self.context.as_text()}")
+        # logger.debug(f"Context tensor: {self.context.as_tensor()}")
+        l = sum(t.shape[1] for t in self.tokens) + c
+        # logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
+        if self.cfg.static_init_prompt is None:
+            after = 0
+        else:
+            after = len(self.cfg.static_init_prompt)
+        # logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
+        while c > self.max_context_tokens or l > self.max_text_len - 20:
+            t = self.context.trim_words(after=after)
+            l -= t
+            c -= t
+            logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
+            if t == 0:
+                break
+        # logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
+        logger.info(f"Context after trim: {self.context.text} (len: {l})")
+
+
+    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor) -> torch.Tensor:
+        if self.cfg.decoder_type == "greedy":
+            logit = self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
+        else:
+            logger.debug(f"Logits shape: {tokens.shape}")
+            logit = self.inference.logits(tokens, audio_features)
+        return logit
+
+
+    def refresh_segment(self, complete=False):
+
+        logger.debug("Refreshing segment:")
+        self.init_tokens()
+        self.last_attend_frame = -self.cfg.rewind_threshold
+        self.detected_language = None
+        self.init_context()
+        logger.debug(f"Context: {self.context}")
+        if not complete and len(self.segments) > 2:
+            logger.debug("keeping the last two segments because the refresh is not complete.")
+            self.segments = self.segments[-2:]
+        else:
+            logger.debug("removing all segments.")
+            self.segments = []
+        self.log_segments += 1
+
+
+    def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
+        if self.always_fire: return True
+        if self.never_fire: return False
+        return fire_at_boundary(chunked_encoder_feature, self.CIFLinear)
+
+
+    def _current_tokens(self):
+
+        toks = self.tokens
+        # very first infer: duplicate start of seq to beam_size
+        if toks[0].shape[0] == 1:
+            toks[0] = toks[0].repeat_interleave(self.cfg.beam_size,dim=0)
+
+        if not self.context.is_empty():
+            context_toks = self.context.as_tensor_beam(self.cfg.beam_size, device=self.model.device)
+            toks = [context_toks] + toks
+
+        # make it one tensor
+        if len(toks) > 1:
+            current_tokens = torch.cat(toks, dim=1)
+        else:
+            current_tokens = toks[0]
+        logger.debug("debug print current_tokens:")
+        self.debug_print_tokens(current_tokens)
+        return current_tokens
+
+
+    def debug_print_tokens(self, tokens):
+        for i in range(self.cfg.beam_size):
+            logger.debug(self.tokenizer.decode_with_timestamps(tokens[i].tolist()))
+
+    ### audio buffer
+
+    def segments_len(self):
+        segments_len = sum(s.shape[0] for s in self.segments) / 16000
+        return segments_len
+
+    def _apply_minseglen(self):
+        segments_len = self.segments_len()
+        # wait for long enough audio to start
+        if segments_len < self.cfg.audio_min_len:
+            logger.debug("waiting for next segment")
+            return False
+        return True
+
+    def insert_audio(self, segment=None):
+        if segment is not None:
+            self.segments.append(segment)
+
+        removed_len = 0
+        # len of audio is bigger than buffer_len. Going to remove the first segment
+        segments_len = self.segments_len()
+        while len(self.segments) > 1 and segments_len > self.cfg.audio_max_len:
+            removed_len = self.segments[0].shape[0] / 16000
+            segments_len -= removed_len
+            self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len)
+            self.segments = self.segments[1:]
+            logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}")
+            if len(self.tokens) > 1:
+                self.context.append_token_ids(self.tokens[1][0,:])
+                self.tokens = [self.initial_tokens] + self.tokens[2:]
+        return removed_len
+
+    def _clean_cache(self):
+        '''clean the cache that stores the attention matrices and kv_cache.
+        It must be called every time after generation with the model.'''
+        # cleaning cache
+        self.dec_attns = []
+        self.kv_cache = {}
+        if self.decoder_type == "beam":
+            self.inference.kv_cache = self.kv_cache
+            self.token_decoder.reset()
+
+    @torch.no_grad()
+    def lang_id(self, encoder_features):
+        """Language detection from encoder features.
+        This code is trimmed and copy-pasted from whisper.decoding.detect_language .
+        """
+
+        # forward pass using a single token, startoftranscript
+        n_audio = encoder_features.shape[0]
+        x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device)  # [n_audio, 1]
+        logits = self.model.logits(x, encoder_features)[:, 0]
+
+        # collect detected languages; suppress all non-language tokens
+        mask = torch.ones(logits.shape[-1], dtype=torch.bool)
+        mask[list(self.tokenizer.all_language_tokens)] = False
+        logits[:, mask] = -np.inf
+        language_tokens = logits.argmax(dim=-1)
+        language_token_probs = logits.softmax(dim=-1).cpu()
+        language_probs = [
+            {
+                c: language_token_probs[i, j].item()
+                for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes)
+            }
+            for i in range(n_audio)
+        ]
+
+        single = encoder_features.ndim == 2
+        if single:
+            language_tokens = language_tokens[0]
+            language_probs = language_probs[0]
+
+        self._clean_cache()
+        return language_tokens, language_probs
+
+    ### transcription / translation
+
+    @torch.no_grad()
+    def infer(self, is_last=False):
+        new_segment = True
+        if len(self.segments) == 0:
+            logger.debug("No segments, nothing to do")
+            self.logdir_save([], [], {})
+            return [], {}
+        if not self._apply_minseglen():
+            logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
+            input_segments = torch.cat(self.segments, dim=0)
+            self.logdir_save(input_segments, [], {})
+            return [], {}
+
+        # input_segments is concatenation of audio, it's one array
+        if len(self.segments) > 1:
+            input_segments = torch.cat(self.segments, dim=0)
+        else:
+            input_segments = self.segments[0]
+
+
+
+        # mel + padding to 30s
+        mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
+                                         device=self.model.device).unsqueeze(0)
+        # trim to 3000
+        mel = pad_or_trim(mel_padded, N_FRAMES)
+
+        # the len of actual audio
+        content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
+
+        # encode
+        encoder_feature = self.model.encoder(mel)
+
+        # logger.debug(f"Encoder feature shape: {encoder_feature.shape}")
+        # if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
+        #     logger.debug("mel ")
+        if self.cfg.language == "auto" and self.detected_language is None:
+            language_tokens, language_probs = self.lang_id(encoder_feature)
+            logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}")
+            top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
+            logger.info(f"Detected language: {top_lan} with p={p:.4f}")
+            #self.tokenizer.language = top_lan
+            #self.tokenizer.__post_init__()
+            self.create_tokenizer(top_lan)
+            self.detected_language = top_lan
+            self.init_tokens()
+            logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
+
+        self.trim_context()
+        current_tokens = self._current_tokens()
+        #
+        fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
+
+
+        ####################### Decoding loop
+        logger.info("Decoding loop starts\n")
+
+        sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device)
+        completed = False
+
+        attn_of_alignment_heads = None
+        most_attended_frame = None
+
+        token_len_before_decoding = current_tokens.shape[1]
+
+        generation_progress = []
+        generation = {
+            "starting_tokens": BeamTokens(current_tokens[0,:].clone(), self.cfg.beam_size),
+            "token_len_before_decoding": token_len_before_decoding,
+            #"fire_detected": fire_detected,
+            "frames_len": content_mel_len,
+            "frames_threshold": 4 if is_last else self.cfg.frame_threshold,
+
+            # to be filled later
+            "logits_starting": None,
+
+            # to be filled later
+            "no_speech_prob": None,
+            "no_speech": False,
+
+            # to be filled in the loop
+            "progress": generation_progress,
+        }
+        while not completed and current_tokens.shape[1] < self.max_text_len:  # bos is 3 tokens
+            generation_progress_loop = []
+
+            if new_segment:
+                tokens_for_logits = current_tokens
+            else:
+                # only need to use the last token except in the first forward pass
+                tokens_for_logits = current_tokens[:,-1:]
+
+            logits = self.logits(tokens_for_logits, encoder_feature)  # B, len(tokens), token dict size
+            if new_segment:
+                generation["logits_starting"] = Logits(logits[:,:,:])
+
+            if new_segment and self.tokenizer.no_speech is not None:
+                probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1)
+                no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
+                generation["no_speech_prob"] = no_speech_probs[0]
+                if no_speech_probs[0] > self.cfg.nonspeech_prob:
+                    generation["no_speech"] = True
+                    logger.info("no speech, stop")
+                    break
+
+            logits = logits[:, -1, :]  # logits for the last token
+            generation_progress_loop.append(("logits_before_suppress",Logits(logits)))
+
+            # suppress blank tokens only at the beginning of the segment
+            if new_segment:
+                logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
+                new_segment = False
+            self.suppress_tokens(logits)
+            #generation_progress_loop.append(("logits_after_suppress",BeamLogits(logits[0,:].clone(), self.cfg.beam_size)))
+            generation_progress_loop.append(("logits_after_suppress",Logits(logits)))
+
+            current_tokens, completed = self.token_decoder.update(current_tokens, logits, sum_logprobs)
+            generation_progress_loop.append(("beam_tokens",Tokens(current_tokens[:,-1].clone())))
+            generation_progress_loop.append(("sum_logprobs",sum_logprobs.tolist()))
+            generation_progress_loop.append(("completed",completed))
+
+            logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
+            self.debug_print_tokens(current_tokens)
+
+
+            # if self.decoder_type == "beam":
+            #     logger.debug(f"Finished sequences: {self.token_decoder.finished_sequences}")
+
+            # logprobs = F.log_softmax(logits.float(), dim=-1)
+            # idx = 0
+            # logger.debug(f"Beam search topk: {logprobs[idx].topk(self.cfg.beam_size + 1)}")
+            # logger.debug(f"Greedy search argmax: {logits.argmax(dim=-1)}")
+            # if completed:
+            #     self.debug_print_tokens(current_tokens)
+
+            # logger.debug("decode stopped because decoder completed")
+
+            attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)]
+            for i, attn_mat in enumerate(self.dec_attns):
+                layer_rank = int(i % len(self.model.decoder.blocks))
+                align_heads_in_layer = self.align_source.get(layer_rank, [])
+                if len(align_heads_in_layer) == 0:
+                    continue
+                for align_head_rank, head_id in align_heads_in_layer:
+                    if self.cfg.beam_size == 1:
+                        a = attn_mat[head_id, :, :]
+                        a = a.unsqueeze(0)
+                    else:
+                        a = attn_mat[:, head_id, :, :]
+                    attn_of_alignment_heads[align_head_rank].append(a)
+            tmp = []
+            for mat in attn_of_alignment_heads:
+                t = torch.cat(mat, dim=1)
+                tmp.append(t)
+            attn_of_alignment_heads = torch.stack(tmp, dim=1)
+            # logger.debug(str(attn_of_alignment_heads.shape) + " here")
+            std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False)
+            attn_of_alignment_heads = (attn_of_alignment_heads - mean) / std
+            attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7)  # from whisper.timing
+            attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
+            # logger.debug(str(attn_of_alignment_heads.shape) + " after mean")
+            attn_of_alignment_heads = attn_of_alignment_heads[:,:, :content_mel_len]
+            # logger.debug(str(attn_of_alignment_heads.shape) + " then")
+
+            # for each beam, the most attended frame is:
+            most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1)
+            generation_progress_loop.append(("most_attended_frames",most_attended_frames.clone().tolist()))
+            logger.debug(str(most_attended_frames.tolist()) + " most att frames")
+
+            most_attended_frame = most_attended_frames[0].item()
+
+
+            generation_progress.append(dict(generation_progress_loop))
+            logger.debug("current tokens" + str(current_tokens.shape))
+            if completed:
+                # stripping the last token, the eot
+                current_tokens = current_tokens[:, :-1]
+                break
+
+            # for some rare cases where the attention fails
+            if not is_last and self.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
+                # TODO: check this
+                if current_tokens.shape[1] > 1 and current_tokens[0, -2] >= DEC_PAD:
+                    logger.debug("omit rewinding from special tokens")
+                    self.last_attend_frame = most_attended_frame
+                else:
+                    logger.debug(
+                        f"[rewind detected] current attention pos: {most_attended_frame}, "
+                        f"last attention pos: {self.last_attend_frame}; omit this segment")
+                    self.last_attend_frame = -self.cfg.rewind_threshold
+                    current_tokens = torch.cat(self.tokens, dim=1) if len(self.tokens) > 0 else self.tokens[0]
+                    break
+            else:
+                self.last_attend_frame = most_attended_frame
+
+            if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold):
+                logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}")
+                # stripping the last token, the one that is attended too close to the end
+                current_tokens = current_tokens[:, :-1]
+                break
+
+            # debug print
+            for i in range(self.cfg.beam_size):
+                logger.debug("attn: {}, current pos: {}, current token: {}({})".format(
+                    attn_of_alignment_heads.shape if attn_of_alignment_heads is not None else None,
+                    most_attended_frames[i],
+                    current_tokens[i, -1].item(),
+                    self.tokenizer.decode([current_tokens[i, -1].item()])
+                ))
+
+        # for k,v in generation.items():
+        #     print(k,v,file=sys.stderr)
+        # for x in generation_progress:
+        #     for y in x.items():
+        #         print("\t\t",*y,file=sys.stderr)
+        #     print("\t","----", file=sys.stderr)
+        #     print("\t", "end of generation_progress_loop", file=sys.stderr)
+        # sys.exit(1)
+        ####################### End of decoding loop
+
+        logger.info("End of decoding loop")
+
+        # if attn_of_alignment_heads is not None:
+        #     seg_len = int(segment.shape[0] / 16000 * TOKENS_PER_SECOND)
+
+        #     # Let's now consider only the top hypothesis in the beam search
+        #     top_beam_attn_of_alignment_heads = attn_of_alignment_heads[0]
+
+        #     # debug print: how is the new token attended?
+        #     new_token_attn = top_beam_attn_of_alignment_heads[token_len_before_decoding:, -seg_len:]
+        #     logger.debug(f"New token attention shape: {new_token_attn.shape}")
+        #     if new_token_attn.shape[0] == 0:  # it's not attended in the current audio segment
+        #         logger.debug("no token generated")
+        #     else:  # it is, and the max attention is:
+        #         new_token_max_attn, _ = new_token_attn.max(dim=-1)
+        #         logger.debug(f"segment max attention: {new_token_max_attn.mean().item()/len(self.segments)}")
+
+
+        # let's now operate only with the top beam hypothesis
+        tokens_to_split = current_tokens[0, token_len_before_decoding:]
+        if fire_detected or is_last:
+            new_hypothesis = tokens_to_split.flatten().tolist()
+        else:
+            # going to truncate the tokens after the last space
+            split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split.tolist())
+            generation["result"] = {"split_words": split_words[:-1], "split_tokens": split_tokens[:-1]}
+            generation["result_truncated"] = {"split_words": split_words[-1:], "split_tokens": split_tokens[-1:]}
+
+            # text_to_split = self.tokenizer.decode(tokens_to_split)
+            # logger.debug(f"text_to_split: {text_to_split}")
+            # logger.debug("text at current step: {}".format(text_to_split.replace(" ", "<space>")))
+            # text_before_space = " ".join(text_to_split.split(" ")[:-1])
+            # logger.debug("before the last space: {}".format(text_before_space.replace(" ", "<space>")))
+            if len(split_words) > 1:
+                new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
+            else:
+                new_hypothesis = []
+
+
+        ### new hypothesis
+        logger.debug(f"new_hypothesis: {new_hypothesis}")
+        new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
+            device=self.model.device,
+        )
+        self.tokens.append(new_tokens)
+        # TODO: test if this is redundant or not
+        # ret = ret[ret<DEC_PAD]
+
+        logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
+
+        self._clean_cache()
+
+        self.logdir_save(input_segments, new_hypothesis, generation)
+        return new_hypothesis, generation
+
+    def logdir_save(self, input_segments, new_hypothesis, generation):
+        """The audio and result from each iteration is saved to the logdir for debugging purposes"""
+
+        # only when the logdir arg is set
+        if self.cfg.logdir is None:
+            return
+
+        self.logdir_i += 1
+
+        # every VAD segment is in a separate directory
+        dir = os.path.join(self.cfg.logdir, f"seg_{self.log_segments:05d}")
+        if not os.path.exists(dir):
+            os.makedirs(dir)
+
+        logger.debug(f"Saving to {dir}, iteration {self.logdir_i:05d}")
+
+        # saving wav:
+        wav_path = os.path.join(dir, f"iter_{self.logdir_i:05d}_audio.wav")
+        audio_np = np.array(input_segments)
+        # Ensure audio is float32 in range [-1, 1], convert to int16 for wav
+        if audio_np.dtype != np.int16:
+            audio_int16 = np.clip(audio_np * 32767, -32768, 32767).astype(np.int16)
+        else:
+            audio_int16 = audio_np
+
+        with wave.open(wav_path, "wb") as wf:
+            wf.setnchannels(1)
+            wf.setsampwidth(2)  # 2 bytes for int16
+            wf.setframerate(16000)
+            wf.writeframes(audio_int16.tobytes())
+
+        # saving readable text: context + hypothesis
+        text = self.tokenizer.decode(new_hypothesis)
+        with open(os.path.join(dir, f"iter_{self.logdir_i:05d}_hypothesis.txt"), "w") as f:
+            if generation:
+                context = generation["starting_tokens"].as_text(self.tokenizer)
+            else:
+                context = ""
+            print("CONTEXT+FORCED:",context,sep="\t",file=f)
+            print("HYPOTHESIS:", text, sep="\t", file=f)
+
+        # TODO: generation progress can also be saved in a readable format
+        #logger.debug(f"generation progress: {generation}")
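An end-to-end sketch of the streaming loop (not part of the commit); audio_chunks is a hypothetical iterable of 16 kHz float32 tensors and the checkpoint path is assumed:

import torch
from simul_whisper.config import AlignAttConfig
from simul_whisper.simul_whisper import PaddedAlignAttWhisper

cfg = AlignAttConfig(model_path="large-v3.pt", language="en")  # assumed local checkpoint
asr = PaddedAlignAttWhisper(cfg)

for chunk in audio_chunks:            # e.g. 1-second torch.FloatTensor chunks
    asr.insert_audio(chunk)           # grow the rolling audio buffer
    tokens, generation = asr.infer()  # emit only tokens the attention has settled on
    print(asr.tokenizer.decode(tokens), end="", flush=True)
tokens, _ = asr.infer(is_last=True)   # flush the tail at end of stream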
simul_whisper/whisper/__init__.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hashlib
|
| 2 |
+
import io
|
| 3 |
+
import os
|
| 4 |
+
import urllib
|
| 5 |
+
import warnings
|
| 6 |
+
from typing import List, Optional, Union
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
| 12 |
+
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
| 13 |
+
from .model import ModelDimensions, Whisper
|
| 14 |
+
from .transcribe import transcribe
|
| 15 |
+
from .version import __version__
|
| 16 |
+
|
| 17 |
+
_MODELS = {
|
| 18 |
+
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
| 19 |
+
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
|
| 20 |
+
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
|
| 21 |
+
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
|
| 22 |
+
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
|
| 23 |
+
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
|
| 24 |
+
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
|
| 25 |
+
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
|
| 26 |
+
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
|
| 27 |
+
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
| 28 |
+
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
| 29 |
+
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
| 30 |
+
"large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
|
| 31 |
+
"turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
|
| 35 |
+
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
|
| 36 |
+
_ALIGNMENT_HEADS = {
|
| 37 |
+
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
|
| 38 |
+
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
|
| 39 |
+
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
|
| 40 |
+
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
|
| 41 |
+
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
|
| 42 |
+
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
|
| 43 |
+
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
|
| 44 |
+
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
|
| 45 |
+
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
|
| 46 |
+
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
| 47 |
+
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
| 48 |
+
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
| 49 |
+
"large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
|
| 50 |
+
"turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
| 55 |
+
os.makedirs(root, exist_ok=True)
|
+
+    expected_sha256 = url.split("/")[-2]
+    download_target = os.path.join(root, os.path.basename(url))
+
+    if os.path.exists(download_target) and not os.path.isfile(download_target):
+        raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+    if os.path.isfile(download_target):
+        with open(download_target, "rb") as f:
+            model_bytes = f.read()
+        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
+            return model_bytes if in_memory else download_target
+        else:
+            warnings.warn(
+                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
+            )
+
+    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+        with tqdm(
+            total=int(source.info().get("Content-Length")),
+            ncols=80,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as loop:
+            while True:
+                buffer = source.read(8192)
+                if not buffer:
+                    break
+
+                output.write(buffer)
+                loop.update(len(buffer))
+
+    model_bytes = open(download_target, "rb").read()
+    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
+        raise RuntimeError(
+            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
+        )
+
+    return model_bytes if in_memory else download_target
+
+
+def available_models() -> List[str]:
+    """Returns the names of available models"""
+    return list(_MODELS.keys())
+
+
+def load_model(
+    name: str,
+    device: Optional[Union[str, torch.device]] = None,
+    download_root: str = None,
+    in_memory: bool = False,
+) -> Whisper:
+    """
+    Load a Whisper ASR model
+
+    Parameters
+    ----------
+    name : str
+        one of the official model names listed by `whisper.available_models()`, or
+        path to a model checkpoint containing the model dimensions and the model state_dict.
+    device : Union[str, torch.device]
+        the PyTorch device to put the model into
+    download_root: str
+        path to download the model files; by default, it uses "~/.cache/whisper"
+    in_memory: bool
+        whether to preload the model weights into host memory
+
+    Returns
+    -------
+    model : Whisper
+        The Whisper ASR model instance
+    """
+
+    if device is None:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+    if download_root is None:
+        default = os.path.join(os.path.expanduser("~"), ".cache")
+        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
+
+    if name in _MODELS:
+        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
+        alignment_heads = _ALIGNMENT_HEADS[name]
+    elif os.path.isfile(name):
+        checkpoint_file = open(name, "rb").read() if in_memory else name
+        alignment_heads = None
+    else:
+        raise RuntimeError(
+            f"Model {name} not found; available models = {available_models()}"
+        )
+
+    with (
+        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
+    ) as fp:
+        checkpoint = torch.load(fp, map_location=device)
+    del checkpoint_file
+
+    dims = ModelDimensions(**checkpoint["dims"])
+    model = Whisper(dims)
+    model.load_state_dict(checkpoint["model_state_dict"])
+
+    if alignment_heads is not None:
+        model.set_alignment_heads(alignment_heads)
+
+    return model.to(device)
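
Since `load_model` and `available_models` are the public entry points of this vendored package, a minimal usage sketch may help. It is a sketch under assumptions: the import path mirrors the file layout above, and "base" is just one of the registered model names.

# Minimal sketch: loading a checkpoint through the vendored package.
from simul_whisper import whisper

print(whisper.available_models())   # the keys of _MODELS
model = whisper.load_model("base")  # cached under ~/.cache/whisper by default
print(model.dims)                   # ModelDimensions read from the checkpoint
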
simul_whisper/whisper/__main__.py
ADDED
@@ -0,0 +1,3 @@
+from .transcribe import cli
+
+cli()
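
This module only wires the package to `python -m`; assuming `transcribe.cli` parses the usual Whisper arguments, an invocation would look like `python -m simul_whisper.whisper audio.wav --model base` (file name and flags illustrative).
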
simul_whisper/whisper/assets/gpt2.tiktoken
ADDED
(diff too large to render; see the raw file)

simul_whisper/whisper/assets/mel_filters.npz
ADDED
(binary file, 4.27 kB)

simul_whisper/whisper/assets/multilingual.tiktoken
ADDED
(diff too large to render; see the raw file)
simul_whisper/whisper/audio.py
ADDED
@@ -0,0 +1,157 @@
+import os
+from functools import lru_cache
+from subprocess import CalledProcessError, run
+from typing import Optional, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from .utils import exact_div
+
+# hard-coded audio hyperparameters
+SAMPLE_RATE = 16000
+N_FFT = 400
+HOP_LENGTH = 160
+CHUNK_LENGTH = 30
+N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
+N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000 frames in a mel spectrogram input
+
+N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions have stride 2
+FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
+TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token
+
+
+def load_audio(file: str, sr: int = SAMPLE_RATE):
+    """
+    Open an audio file and read as mono waveform, resampling as necessary
+
+    Parameters
+    ----------
+    file: str
+        The audio file to open
+
+    sr: int
+        The sample rate to resample the audio if necessary
+
+    Returns
+    -------
+    A NumPy array containing the audio waveform, in float32 dtype.
+    """
+
+    # This launches a subprocess to decode audio while down-mixing
+    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
+    # fmt: off
+    cmd = [
+        "ffmpeg",
+        "-nostdin",
+        "-threads", "0",
+        "-i", file,
+        "-f", "s16le",
+        "-ac", "1",
+        "-acodec", "pcm_s16le",
+        "-ar", str(sr),
+        "-"
+    ]
+    # fmt: on
+    try:
+        out = run(cmd, capture_output=True, check=True).stdout
+    except CalledProcessError as e:
+        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
+
+    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
+
+
+def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
+    """
+    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
+    """
+    if torch.is_tensor(array):
+        if array.shape[axis] > length:
+            array = array.index_select(
+                dim=axis, index=torch.arange(length, device=array.device)
+            )
+
+        if array.shape[axis] < length:
+            pad_widths = [(0, 0)] * array.ndim
+            pad_widths[axis] = (0, length - array.shape[axis])
+            array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
+    else:
+        if array.shape[axis] > length:
+            array = array.take(indices=range(length), axis=axis)
+
+        if array.shape[axis] < length:
+            pad_widths = [(0, 0)] * array.ndim
+            pad_widths[axis] = (0, length - array.shape[axis])
+            array = np.pad(array, pad_widths)
+
+    return array
+
+
+@lru_cache(maxsize=None)
+def mel_filters(device, n_mels: int) -> torch.Tensor:
+    """
+    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
+    Allows decoupling librosa dependency; saved using:
+
+        np.savez_compressed(
+            "mel_filters.npz",
+            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
+            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
+        )
+    """
+    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
+
+    filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
+    with np.load(filters_path, allow_pickle=False) as f:
+        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
+
+
+def log_mel_spectrogram(
+    audio: Union[str, np.ndarray, torch.Tensor],
+    n_mels: int = 80,
+    padding: int = 0,
+    device: Optional[Union[str, torch.device]] = None,
+):
+    """
+    Compute the log-Mel spectrogram of the audio
+
+    Parameters
+    ----------
+    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
+        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
+
+    n_mels: int
+        The number of Mel-frequency filters, only 80 and 128 are supported
+
+    padding: int
+        Number of zero samples to pad to the right
+
+    device: Optional[Union[str, torch.device]]
+        If given, the audio tensor is moved to this device before STFT
+
+    Returns
+    -------
+    torch.Tensor, shape = (n_mels, n_frames)
+        A Tensor that contains the Mel spectrogram
+    """
+    if not torch.is_tensor(audio):
+        if isinstance(audio, str):
+            audio = load_audio(audio)
+        audio = torch.from_numpy(audio)
+
+    if device is not None:
+        audio = audio.to(device)
+    if padding > 0:
+        audio = F.pad(audio, (0, padding))
+    window = torch.hann_window(N_FFT).to(audio.device)
+    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
+    magnitudes = stft[..., :-1].abs() ** 2
+
+    filters = mel_filters(audio.device, n_mels)
+    mel_spec = filters @ magnitudes
+
+    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
+    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+    log_spec = (log_spec + 4.0) / 4.0
+    return log_spec
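
The constants and functions above form the whole preprocessing pipeline, so a short end-to-end sketch follows. It assumes ffmpeg is on PATH and uses a hypothetical input file "speech.wav":

# Preprocessing sketch built from the functions above.
from simul_whisper.whisper.audio import (
    N_SAMPLES, load_audio, log_mel_spectrogram, pad_or_trim,
)

audio = load_audio("speech.wav")             # float32 mono waveform at 16 kHz
audio = pad_or_trim(audio)                   # exactly N_SAMPLES = 480000 samples (30 s)
mel = log_mel_spectrogram(audio, n_mels=80)  # torch.Tensor of shape (80, 3000)
print(mel.shape)
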
simul_whisper/whisper/decoding.py
ADDED
@@ -0,0 +1,833 @@
+from dataclasses import dataclass, field, replace
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+from torch.distributions import Categorical
+
+from .audio import CHUNK_LENGTH
+from .tokenizer import Tokenizer, get_tokenizer
+from .utils import compression_ratio
+
+if TYPE_CHECKING:
+    from .model import Whisper
+
+
+@torch.no_grad()
+def detect_language(
+    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
+) -> Tuple[Tensor, List[dict]]:
+    """
+    Detect the spoken language in the audio, and return it as a list of strings, along with the ids
+    of the most probable language tokens and the probability distribution over all language tokens.
+    This is performed outside the main decode loop in order to not interfere with kv-caching.
+
+    Returns
+    -------
+    language_tokens : Tensor, shape = (n_audio,)
+        ids of the most probable language tokens, which appear after the startoftranscript token.
+    language_probs : List[Dict[str, float]], length = n_audio
+        list of dictionaries containing the probability distribution over all languages.
+    """
+    if tokenizer is None:
+        tokenizer = get_tokenizer(model.is_multilingual)
+    if (
+        tokenizer.language is None
+        or tokenizer.language_token not in tokenizer.sot_sequence
+    ):
+        raise ValueError(
+            "This model doesn't have language tokens so it can't perform lang id"
+        )
+
+    single = mel.ndim == 2
+    if single:
+        mel = mel.unsqueeze(0)
+
+    # skip encoder forward pass if already-encoded audio features were given
+    if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
+        mel = model.encoder(mel)
+
+    # forward pass using a single token, startoftranscript
+    n_audio = mel.shape[0]
+    x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
+    logits = model.logits(x, mel)[:, 0]
+
+    # collect detected languages; suppress all non-language tokens
+    mask = torch.ones(logits.shape[-1], dtype=torch.bool)
+    mask[list(tokenizer.all_language_tokens)] = False
+    logits[:, mask] = -np.inf
+    language_tokens = logits.argmax(dim=-1)
+    language_token_probs = logits.softmax(dim=-1).cpu()
+    language_probs = [
+        {
+            c: language_token_probs[i, j].item()
+            for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
+        }
+        for i in range(n_audio)
+    ]
+
+    if single:
+        language_tokens = language_tokens[0]
+        language_probs = language_probs[0]
+
+    return language_tokens, language_probs
+
+
+@dataclass(frozen=True)
+class DecodingOptions:
+    # whether to perform X->X "transcribe" or X->English "translate"
+    task: str = "transcribe"
+
+    # language that the audio is in; uses detected language if None
+    language: Optional[str] = None
+
+    # sampling-related options
+    temperature: float = 0.0
+    sample_len: Optional[int] = None  # maximum number of tokens to sample
+    best_of: Optional[int] = None  # number of independent sample trajectories, if t > 0
+    beam_size: Optional[int] = None  # number of beams in beam search, if t == 0
+    patience: Optional[float] = None  # patience in beam search (arxiv:2204.05424)
+
+    # "alpha" in Google NMT, or None for length norm, when ranking generations
+    # to select which to return among the beams or best-of-N samples
+    length_penalty: Optional[float] = None
+
+    # text or tokens to feed as the prompt or the prefix; for more info:
+    # https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
+    prompt: Optional[Union[str, List[int]]] = None  # for the previous context
+    prefix: Optional[Union[str, List[int]]] = None  # to prefix the current context
+
+    # list of token ids (or comma-separated token ids) to suppress
+    # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
+    suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
+    suppress_blank: bool = True  # this will suppress blank outputs
+
+    # timestamp sampling options
+    without_timestamps: bool = False  # use <|notimestamps|> to sample text tokens only
+    max_initial_timestamp: Optional[float] = 1.0
+
+    # implementation details
+    fp16: bool = True  # use fp16 for most of the calculation
+
+    # streaming
+    add_sot: Optional[bool] = True
+
+
+@dataclass(frozen=True)
+class DecodingResult:
+    audio_features: Tensor
+    language: str
+    language_probs: Optional[Dict[str, float]] = None
+    tokens: List[int] = field(default_factory=list)
+    text: str = ""
+    avg_logprob: float = np.nan
+    no_speech_prob: float = np.nan
+    temperature: float = np.nan
+    compression_ratio: float = np.nan
+
+
+class Inference:
+    def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
+        """Perform a forward pass on the decoder and return per-token logits"""
+        raise NotImplementedError
+
+    def rearrange_kv_cache(self, source_indices) -> None:
+        """Update the key-value cache according to the updated beams"""
+        raise NotImplementedError
+
+    def cleanup_caching(self) -> None:
+        """Clean up any resources or hooks after decoding is finished"""
+        pass
+
+
+class PyTorchInference(Inference):
+    def __init__(self, model: "Whisper", initial_token_length: int):
+        self.model: "Whisper" = model
+        self.initial_token_length = initial_token_length
+        self.kv_cache = {}
+        self.hooks = []
+
+        key_modules = [block.attn.key for block in self.model.decoder.blocks]
+        value_modules = [block.attn.value for block in self.model.decoder.blocks]
+        self.kv_modules = key_modules + value_modules
+
+    def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
+        if not self.kv_cache:
+            self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
+
+        if tokens.shape[-1] > self.initial_token_length:
+            # only need to use the last token except in the first forward pass
+            tokens = tokens[:, -1:]
+
+        return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
+
+    def cleanup_caching(self):
+        for hook in self.hooks:
+            hook.remove()
+
+        self.kv_cache = {}
+        self.hooks = []
+
+    def rearrange_kv_cache(self, source_indices):
+        if source_indices != list(range(len(source_indices))):
+            for module in self.kv_modules:
+                # update the key/value cache to contain the selected sequences
+                self.kv_cache[module] = self.kv_cache[module][source_indices].detach()
+
+
+class SequenceRanker:
+    def rank(
+        self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
+    ) -> List[int]:
+        """
+        Given a list of groups of samples and their cumulative log probabilities,
+        return the indices of the samples in each group to select as the final result
+        """
+        raise NotImplementedError
+
+
+class MaximumLikelihoodRanker(SequenceRanker):
+    """
+    Select the sample with the highest log probabilities, penalized using either
+    a simple length normalization or Google NMT paper's length penalty
+    """
+
+    def __init__(self, length_penalty: Optional[float]):
+        self.length_penalty = length_penalty
+
+    def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
+        def scores(logprobs, lengths):
+            result = []
+            for logprob, length in zip(logprobs, lengths):
+                if self.length_penalty is None:
+                    penalty = length
+                else:
+                    # from the Google NMT paper
+                    penalty = ((5 + length) / 6) ** self.length_penalty
+                result.append(logprob / penalty)
+            return result
+
+        # get the sequence with the highest score
+        lengths = [[len(t) for t in s] for s in tokens]
+        return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
+
+
+class TokenDecoder:
+    def reset(self):
+        """Initialize any stateful variables for decoding a new sequence"""
+
+    def update(
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
+    ) -> Tuple[Tensor, bool]:
+        """Specify how to select the next token, based on the current trace and logits
+
+        Parameters
+        ----------
+        tokens : Tensor, shape = (n_batch, current_sequence_length)
+            all tokens in the context so far, including the prefix and sot_sequence tokens
+
+        logits : Tensor, shape = (n_batch, vocab_size)
+            per-token logits of the probability distribution at the current step
+
+        sum_logprobs : Tensor, shape = (n_batch)
+            cumulative log probabilities for each sequence
+
+        Returns
+        -------
+        tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
+            the tokens, appended with the selected next token
+
+        completed : bool
+            True if all sequences have reached the end of text
+
+        """
+        raise NotImplementedError
+
+    def finalize(
+        self, tokens: Tensor, sum_logprobs: Tensor
+    ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
+        """Finalize search and return the final candidate sequences
+
+        Parameters
+        ----------
+        tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
+            all tokens in the context so far, including the prefix and sot_sequence
+
+        sum_logprobs : Tensor, shape = (n_audio, n_group)
+            cumulative log probabilities for each sequence
+
+        Returns
+        -------
+        tokens : Sequence[Sequence[Tensor]], length = n_audio
+            sequence of Tensors containing candidate token sequences, for each audio input
+
+        sum_logprobs : List[List[float]], length = n_audio
+            sequence of cumulative log probabilities corresponding to the above
+
+        """
+        raise NotImplementedError
+
+
+class GreedyDecoder(TokenDecoder):
+    def __init__(self, temperature: float, eot: int):
+        self.temperature = temperature
+        self.eot = eot
+
+    def update(
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
+    ) -> Tuple[Tensor, bool]:
+        if self.temperature == 0:
+            next_tokens = logits.argmax(dim=-1)
+        else:
+            next_tokens = Categorical(logits=logits / self.temperature).sample()
+
+        logprobs = F.log_softmax(logits.float(), dim=-1)
+        current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
+        sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
+
+        next_tokens[tokens[:, -1] == self.eot] = self.eot
+        tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
+
+        completed = (tokens[:, -1] == self.eot).all()
+        return tokens, completed
+
+    def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
+        # make sure each sequence has at least one EOT token at the end
+        tokens = F.pad(tokens, (0, 1), value=self.eot)
+        return tokens, sum_logprobs.tolist()
+
+
+class BeamSearchDecoder(TokenDecoder):
+    def __init__(
+        self,
+        beam_size: int,
+        eot: int,
+        inference: Inference,
+        patience: Optional[float] = None,
+    ):
+        self.beam_size = beam_size
+        self.eot = eot
+        self.inference = inference
+        self.patience = patience or 1.0
+        self.max_candidates: int = round(beam_size * self.patience)
+        self.finished_sequences = None
+
+        assert (
+            self.max_candidates > 0
+        ), f"Invalid beam size ({beam_size}) or patience ({patience})"
+
+    def reset(self):
+        self.finished_sequences = None
+
+    def update(
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
+    ) -> Tuple[Tensor, bool]:
+        if tokens.shape[0] % self.beam_size != 0:
+            raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
+
+        n_audio = tokens.shape[0] // self.beam_size
+        if self.finished_sequences is None:  # for the first update
+            self.finished_sequences = [{} for _ in range(n_audio)]
+
+        logprobs = F.log_softmax(logits.float(), dim=-1)
+        next_tokens, source_indices, finished_sequences = [], [], []
+        for i in range(n_audio):
+            scores, sources, finished = {}, {}, {}
+
+            # STEP 1: calculate the cumulative log probabilities for possible candidates
+            for j in range(self.beam_size):
+                idx = i * self.beam_size + j
+                prefix = tokens[idx].tolist()
+                for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
+                    new_logprob = (sum_logprobs[idx] + logprob).item()
+                    sequence = tuple(prefix + [token.item()])
+                    scores[sequence] = new_logprob
+                    sources[sequence] = idx
+
+            # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
+            saved = 0
+            for sequence in sorted(scores, key=scores.get, reverse=True):
+                if sequence[-1] == self.eot:
+                    finished[sequence] = scores[sequence]
+                else:
+                    sum_logprobs[len(next_tokens)] = scores[sequence]
+                    next_tokens.append(sequence)
+                    source_indices.append(sources[sequence])
+
+                    saved += 1
+                    if saved == self.beam_size:
+                        break
+
+            finished_sequences.append(finished)
+
+        tokens = torch.tensor(next_tokens, device=tokens.device)
+        self.inference.rearrange_kv_cache(source_indices)
+
+        # add newly finished sequences to self.finished_sequences
+        assert len(self.finished_sequences) == len(finished_sequences)
+        for previously_finished, newly_finished in zip(
+            self.finished_sequences, finished_sequences
+        ):
+            for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
+                if len(previously_finished) >= self.max_candidates:
+                    break  # the candidate list is full
+                previously_finished[seq] = newly_finished[seq]
+
+        # mark as completed if every audio input has enough finished sequences
+        completed = all(
+            len(sequences) >= self.max_candidates
+            for sequences in self.finished_sequences
+        )
+        return tokens, completed
+
+    def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
+        # collect all finished sequences, including patience, and add unfinished ones if not enough
+        sum_logprobs = sum_logprobs.cpu()
+        for i, sequences in enumerate(self.finished_sequences):
+            if (
+                len(sequences) < self.beam_size
+            ):  # when not enough sequences are finished
+                for j in list(np.argsort(sum_logprobs[i]))[::-1]:
+                    sequence = preceding_tokens[i, j].tolist() + [self.eot]
+                    sequences[tuple(sequence)] = sum_logprobs[i][j].item()
+                    if len(sequences) >= self.beam_size:
+                        break
+
+        tokens: List[List[Tensor]] = [
+            [torch.tensor(seq) for seq in sequences.keys()]
+            for sequences in self.finished_sequences
+        ]
+        sum_logprobs: List[List[float]] = [
+            list(sequences.values()) for sequences in self.finished_sequences
+        ]
+        return tokens, sum_logprobs
+
+
+class LogitFilter:
+    def apply(self, logits: Tensor, tokens: Tensor) -> None:
+        """Apply any filtering or masking to logits in-place
+
+        Parameters
+        ----------
+        logits : Tensor, shape = (n_batch, vocab_size)
+            per-token logits of the probability distribution at the current step
+
+        tokens : Tensor, shape = (n_batch, current_sequence_length)
+            all tokens in the context so far, including the prefix and sot_sequence tokens
+
+        """
+        raise NotImplementedError
+
+
+class SuppressBlank(LogitFilter):
+    def __init__(self, tokenizer: Tokenizer, sample_begin: int):
+        self.tokenizer = tokenizer
+        self.sample_begin = sample_begin
+
+    def apply(self, logits: Tensor, tokens: Tensor):
+        if tokens.shape[1] == self.sample_begin:
+            logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
+
+
+class SuppressTokens(LogitFilter):
+    def __init__(self, suppress_tokens: Sequence[int]):
+        self.suppress_tokens = list(suppress_tokens)
+
+    def apply(self, logits: Tensor, tokens: Tensor):
+        logits[:, self.suppress_tokens] = -np.inf
+
+
+class ApplyTimestampRules(LogitFilter):
+    def __init__(
+        self,
+        tokenizer: Tokenizer,
+        sample_begin: int,
+        max_initial_timestamp_index: Optional[int],
+    ):
+        self.tokenizer = tokenizer
+        self.sample_begin = sample_begin
+        self.max_initial_timestamp_index = max_initial_timestamp_index
+
+    def apply(self, logits: Tensor, tokens: Tensor):
+        # suppress <|notimestamps|> which is handled by without_timestamps
+        if self.tokenizer.no_timestamps is not None:
+            logits[:, self.tokenizer.no_timestamps] = -np.inf
+
+        # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
+        for k in range(tokens.shape[0]):
+            sampled_tokens = tokens[k, self.sample_begin :]
+            seq = [t for t in sampled_tokens.tolist()]
+            last_was_timestamp = (
+                len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
+            )
+            penultimate_was_timestamp = (
+                len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
+            )
+
+            if last_was_timestamp:
+                if penultimate_was_timestamp:  # has to be non-timestamp
+                    logits[k, self.tokenizer.timestamp_begin :] = -np.inf
+                else:  # cannot be normal text tokens
+                    logits[k, : self.tokenizer.eot] = -np.inf
+
+            timestamps = sampled_tokens[
+                sampled_tokens.ge(self.tokenizer.timestamp_begin)
+            ]
+            if timestamps.numel() > 0:
+                # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
+                # also force each segment to have a nonzero length, to prevent infinite looping
+                if last_was_timestamp and not penultimate_was_timestamp:
+                    timestamp_last = timestamps[-1]
+                else:
+                    timestamp_last = timestamps[-1] + 1
+                logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf
+
+        if tokens.shape[1] == self.sample_begin:
+            # suppress generating non-timestamp tokens at the beginning
+            logits[:, : self.tokenizer.timestamp_begin] = -np.inf
+
+            # apply the `max_initial_timestamp` option
+            if self.max_initial_timestamp_index is not None:
+                last_allowed = (
+                    self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
+                )
+                logits[:, last_allowed + 1 :] = -np.inf
+
+        # if sum of probability over timestamps is above any other token, sample timestamp
+        logprobs = F.log_softmax(logits.float(), dim=-1)
+        for k in range(tokens.shape[0]):
+            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
+                dim=-1
+            )
+            max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
+            if timestamp_logprob > max_text_token_logprob:
+                logits[k, : self.tokenizer.timestamp_begin] = -np.inf
+
+
+class DecodingTask:
+    inference: Inference
+    sequence_ranker: SequenceRanker
+    decoder: TokenDecoder
+    logit_filters: List[LogitFilter]
+
+    def __init__(self, model: "Whisper", options: DecodingOptions):
+        self.options: DecodingOptions = self._verify_options(options)
+        if self.options.fp16:
+            self.model = model.half()
+        else:
+            self.model = model
+
+        language = options.language or "en"
+        tokenizer = get_tokenizer(
+            model.is_multilingual, language=language, task=options.task
+        )
+        self.tokenizer: Tokenizer = tokenizer
+
+        # print(self.options)
+
+        self.n_group: int = options.beam_size or options.best_of or 1
+        self.n_ctx: int = model.dims.n_text_ctx
+        self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
+
+        self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
+        if self.options.without_timestamps:
+            self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
+
+        self.initial_tokens: Tuple[int] = self._get_initial_tokens()
+        self.sample_begin: int = len(self.initial_tokens)
+        self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
+
+        # inference: implements the forward pass through the decoder, including kv caching
+        self.inference = PyTorchInference(model, len(self.initial_tokens))
+
+        # sequence ranker: implements how to rank a group of sampled sequences
+        self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
+
+        # decoder: implements how to select the next tokens, given the autoregressive distribution
+        if options.beam_size is not None:
+            self.decoder = BeamSearchDecoder(
+                options.beam_size, tokenizer.eot, self.inference, options.patience
+            )
+        else:
+            self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
+
+        # logit filters: applies various rules to suppress or penalize certain tokens
+        self.logit_filters = []
+        if self.options.suppress_blank:
+            self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
+        if self.options.suppress_tokens:
+            self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
+        if not options.without_timestamps:
+            precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
+            max_initial_timestamp_index = None
+            if options.max_initial_timestamp:
+                max_initial_timestamp_index = round(
+                    self.options.max_initial_timestamp / precision
+                )
+            self.logit_filters.append(
+                ApplyTimestampRules(
+                    tokenizer, self.sample_begin, max_initial_timestamp_index
+                )
+            )
+
+    def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
+        if options.beam_size is not None and options.best_of is not None:
+            raise ValueError("beam_size and best_of can't be given together")
+        if options.temperature == 0:
+            if options.best_of is not None:
+                raise ValueError("best_of with greedy sampling (T=0) is not compatible")
+        if options.patience is not None and options.beam_size is None:
+            raise ValueError("patience requires beam_size to be given")
+        if options.length_penalty is not None and not (
+            0 <= options.length_penalty <= 1
+        ):
+            raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
+
+        return options
+
+    def _get_initial_tokens(self) -> Tuple[int]:
+        tokens = list(self.sot_sequence)
+        # print("prefix", prefix)
+        if prefix := self.options.prefix:
+            prefix_tokens = (
+                self.tokenizer.encode(" " + prefix.strip())
+                if isinstance(prefix, str)
+                else prefix
+            )
+            if self.sample_len is not None:
+                max_prefix_len = self.n_ctx // 2 - self.sample_len
+                prefix_tokens = prefix_tokens[-max_prefix_len:]
+            tokens = tokens + prefix_tokens
+
+        if prompt := self.options.prompt:
+            prompt_tokens = (
+                self.tokenizer.encode(" " + prompt.strip())
+                if isinstance(prompt, str)
+                else prompt
+            )
+            # if self.options.add_sot:
+            tokens = (
+                [self.tokenizer.sot_prev]
+                + prompt_tokens[-(self.n_ctx // 2 - 1) :]
+                + tokens
+            )
+            # else:
+            #     tokens = ([self.tokenizer.sot_prev] + tokens + prompt_tokens[-(self.n_ctx // 2 - 1) :])
+        # print("return", tokens)
+        return tuple(tokens)
+
+    def _get_suppress_tokens(self) -> Tuple[int]:
+        suppress_tokens = self.options.suppress_tokens
+
+        if isinstance(suppress_tokens, str):
+            suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
+
+        if -1 in suppress_tokens:
+            suppress_tokens = [t for t in suppress_tokens if t >= 0]
+            suppress_tokens.extend(self.tokenizer.non_speech_tokens)
+        elif suppress_tokens is None or len(suppress_tokens) == 0:
+            suppress_tokens = []  # interpret empty string as an empty list
+        else:
+            assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
+
+        suppress_tokens.extend(
+            [
+                self.tokenizer.transcribe,
+                self.tokenizer.translate,
+                self.tokenizer.sot,
+                self.tokenizer.sot_prev,
+                self.tokenizer.sot_lm,
+            ]
+        )
+        if self.tokenizer.no_speech is not None:
+            # no-speech probability is collected separately
+            suppress_tokens.append(self.tokenizer.no_speech)
+
+        return tuple(sorted(set(suppress_tokens)))
+
+    def _get_audio_features(self, mel: Tensor):
+        if self.options.fp16:
+            mel = mel.half()
+
+        if mel.shape[-2:] == (
+            self.model.dims.n_audio_ctx,
+            self.model.dims.n_audio_state,
+        ):
+            # encoded audio features are given; skip audio encoding
+            audio_features = mel
+        else:
+            audio_features = self.model.encoder(mel)
+
+        if audio_features.dtype != (
+            torch.float16 if self.options.fp16 else torch.float32
+        ):
+            raise TypeError(
+                f"audio_features has an incorrect dtype: {audio_features.dtype}"
+            )
+
+        return audio_features
+
+    def _detect_language(self, audio_features: Tensor, tokens: Tensor):
+        languages = [self.options.language] * audio_features.shape[0]
+        lang_probs = None
+
+        if self.options.language is None or self.options.task == "lang_id":
+            lang_tokens, lang_probs = self.model.detect_language(
+                audio_features, self.tokenizer
+            )
+            languages = [max(probs, key=probs.get) for probs in lang_probs]
+            if self.options.language is None:
+                tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
+
+        return languages, lang_probs
+
+    def _main_loop(self, audio_features: Tensor, tokens: Tensor):
+        n_batch = tokens.shape[0]
+        sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
+        no_speech_probs = [np.nan] * n_batch
+
+        try:
+            for i in range(self.sample_len):  # loop at most 448 times
+                # print("in decode main loop", i, tokens[0].tolist())
+                logits = self.inference.logits(tokens, audio_features)
+                # print(logits)
+                if (
+                    i == 0 and self.tokenizer.no_speech is not None
+                ):  # save no_speech_probs
+                    probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
+                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
+
+                # now we need to consider the logits at the last token only
+                logits = logits[:, -1]
+
+                # apply the logit filters, e.g. for suppressing or applying penalty to certain tokens
+                for logit_filter in self.logit_filters:
+                    logit_filter.apply(logits, tokens)
+
+                # expand the tokens tensor with the selected next tokens
+                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
+
+                if completed or tokens.shape[-1] > self.n_ctx:
+                    break
+        finally:
+            self.inference.cleanup_caching()
+
+        return tokens, sum_logprobs, no_speech_probs
+
+    @torch.no_grad()
+    def run(self, mel: Tensor) -> List[DecodingResult]:
+        self.decoder.reset()
+        tokenizer: Tokenizer = self.tokenizer
+        n_audio: int = mel.shape[0]
+
+        audio_features: Tensor = self._get_audio_features(mel)  # encoder forward pass
+        tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
+        # print("initial_tokens", self.initial_tokens)
+        # detect language if requested, overwriting the language token
+        languages, language_probs = self._detect_language(audio_features, tokens)
+        if self.options.task == "lang_id":
+            return [
+                DecodingResult(
+                    audio_features=features, language=language, language_probs=probs
+                )
+                for features, language, probs in zip(
+                    audio_features, languages, language_probs
+                )
+            ]
+
+        # repeat text tensors by the group size, for beam search or best-of-n sampling
+        tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
+
+        # call the main sampling loop
+        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
+
+        # reshape the tensors to have (n_audio, n_group) as the first two dimensions
+        audio_features = audio_features[:: self.n_group]
+        no_speech_probs = no_speech_probs[:: self.n_group]
+        assert audio_features.shape[0] == len(no_speech_probs) == n_audio
+
+        tokens = tokens.reshape(n_audio, self.n_group, -1)
+        sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
+
+        # get the final candidates for each group, and slice between the first sampled token and EOT
+        tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
+        tokens: List[List[Tensor]] = [
+            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
+            for s in tokens
+        ]
+
+        # select the top-ranked sample in each group
+        selected = self.sequence_ranker.rank(tokens, sum_logprobs)
+        tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
+        texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
+
+        sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
+        avg_logprobs: List[float] = [
+            lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
+        ]
+
+        fields = (
+            texts,
+            languages,
+            tokens,
+            audio_features,
+            avg_logprobs,
+            no_speech_probs,
+        )
+        if len(set(map(len, fields))) != 1:
+            raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
+
+        return [
+            DecodingResult(
+                audio_features=features,
+                language=language,
+                tokens=tokens,
+                text=text,
+                avg_logprob=avg_logprob,
+                no_speech_prob=no_speech_prob,
+                temperature=self.options.temperature,
+                compression_ratio=compression_ratio(text),
+            )
+            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
+                *fields
+            )
+        ]
+
+
+@torch.no_grad()
+def decode(
+    model: "Whisper",
+    mel: Tensor,
+    options: DecodingOptions = DecodingOptions(),
+    **kwargs,
+) -> Union[DecodingResult, List[DecodingResult]]:
+    """
+    Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
+
+    Parameters
+    ----------
+    model: Whisper
+        the Whisper model instance
+
+    mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
+        A tensor containing the Mel spectrogram(s)
+
+    options: DecodingOptions
+        A dataclass that contains all necessary options for decoding 30-second segments
+
+    Returns
+    -------
+    result: Union[DecodingResult, List[DecodingResult]]
+        The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
+    """
+    if single := mel.ndim == 2:
+        mel = mel.unsqueeze(0)
+
+    if kwargs:
+        options = replace(options, **kwargs)
+
+    result = DecodingTask(model, options).run(mel)
+
+    return result[0] if single else result
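
Tying `decode` and `DecodingOptions` to the preprocessing above, a single-segment decode might look like the sketch below. The model is assumed to come from `load_model` as shown earlier, and the option values are illustrative, not required:

# Sketch: decoding one 30-second mel segment with greedy sampling.
from simul_whisper.whisper.decoding import DecodingOptions, decode

options = DecodingOptions(language="en", without_timestamps=True, fp16=False)
mel = mel.to(next(model.parameters()).device)  # mel of shape (80, 3000) from above
result = decode(model, mel, options)           # a single DecodingResult
print(result.text, result.avg_logprob, result.no_speech_prob)
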
simul_whisper/whisper/model.py
ADDED
@@ -0,0 +1,382 @@
+import base64
+import gzip
+from contextlib import contextmanager
+from dataclasses import dataclass
+from typing import Dict, Iterable, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from .decoding import decode as decode_function
+from .decoding import detect_language as detect_language_function
+from .transcribe import transcribe as transcribe_function
+
+
+try:
+    from torch.nn.functional import scaled_dot_product_attention
+
+    SDPA_AVAILABLE = True
+except (ImportError, RuntimeError, OSError):
+    scaled_dot_product_attention = None
+    SDPA_AVAILABLE = False
+
+
+@dataclass
+class ModelDimensions:
+    n_mels: int
+    n_audio_ctx: int
+    n_audio_state: int
+    n_audio_head: int
+    n_audio_layer: int
+    n_vocab: int
+    n_text_ctx: int
+    n_text_state: int
+    n_text_head: int
+    n_text_layer: int
+
+
+# class LayerNorm(nn.LayerNorm):
+#     def forward(self, x: Tensor) -> Tensor:
+#         return super().forward(x.float()).type(x.dtype)
+
+# class Linear(nn.Linear):
+#     def forward(self, x: Tensor) -> Tensor:
+#         return F.linear(
+#             x,
+#             self.weight.to(x.dtype),
+#             None if self.bias is None else self.bias.to(x.dtype),
+#         )
+
+
+# class Conv1d(nn.Conv1d):
+#     def _conv_forward(
+#         self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
+#     ) -> Tensor:
+#         return super()._conv_forward(
+#             x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
+#         )
+
+
+def sinusoids(length, channels, max_timescale=10000):
+    """Returns sinusoids for positional embedding"""
+    assert channels % 2 == 0
+    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+
+import sys  # this is mine, for debugging
+class MultiHeadAttention(nn.Module):
+
+    use_sdpa = False  # disabling: https://github.com/linto-ai/whisper-timestamped/issues/212
+
+    def __init__(self, n_state: int, n_head: int, cache_id: str):
+        super().__init__()
+        self.n_head = n_head
+        self.query = nn.Linear(n_state, n_state)
+        self.key = nn.Linear(n_state, n_state, bias=False)
+        self.key.cache_id = f"{cache_id}_key"
+        self.value = nn.Linear(n_state, n_state)
+        self.value.cache_id = f"{cache_id}_value"
+        self.out = nn.Linear(n_state, n_state)
+        self.cache_id = cache_id
+
+    def forward(
+        self,
+        x: Tensor,
+        xa: Optional[Tensor] = None,
+        mask: Optional[Tensor] = None,
+        kv_cache: Optional[dict] = None,
+    ):
+        # print("MultiHeadAttention forward", file=sys.stderr)
+        q = self.query(x)
+        # print(q.shape, x is None, mask is None, list(kv_cache.keys()) if kv_cache is not None else None, file=sys.stderr)
+        # print(mask, kv_cache, xa, file=sys.stderr)
+
+        if kv_cache is None or xa is None or self.key.cache_id not in kv_cache:
+            k = self.key(x if xa is None else xa)
+            v = self.value(x if xa is None else xa)
+            # print(self.key.cache_id, "cache miss")  # , kv_cache is None, xa is None, self.key.cache_id not in kv_cache if kv_cache is not None else None, k.shape, x.shape)
+            # if kv_cache is not None:
+            #     print(kv_cache.keys())
+        else:
+            # print(self.key.cache_id, "cache hit")  # , kv_cache is None, xa is None, self.key.cache_id not in kv_cache
+            # if kv_cache is not None:
+            #     print(kv_cache.keys())
+            k = kv_cache[self.key.cache_id]
+            v = kv_cache[self.value.cache_id]
+        # print(self.key.cache_id, "qkv attention", q.shape, k.shape, v.shape)
+        wv, qk = self.qkv_attention(q, k, v, mask)
+        return self.out(wv), qk
+
+    # def qkv_attention(
+    #     self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
+    # ):
+    #     n_batch, n_ctx, n_state = q.shape
+    #     scale = (n_state // self.n_head) ** -0.25
+    #     q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
+    #     k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
+    #     v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+
+    #     qk = q @ k
+    #     if mask is not None:
+    #         qk = qk + mask[:n_ctx, :n_ctx]
+    #     # qk = qk.float()
+
+    #     w = F.softmax(qk, dim=-1)  # .to(q.dtype)
+    #     return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
+
+
+    def qkv_attention(
+        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        n_batch, n_ctx, n_state = q.shape
+        scale = (n_state // self.n_head) ** -0.25
+        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+
+        if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
+            a = scaled_dot_product_attention(
+                q, k, v, is_causal=mask is not None and n_ctx > 1
+            )
+            out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
+            qk = None
+        else:
+            qk = (q * scale) @ (k * scale).transpose(-1, -2)
+            if mask is not None:
+                qk = qk + mask[:n_ctx, :n_ctx]
+            qk = qk.float()
+
+            w = F.softmax(qk, dim=-1).to(q.dtype)
+            out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
+            qk = qk.detach()
+
+        return out, qk
+
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(self, n_state: int, n_head: int, cache_id: str = "", cross_attention: bool = False):
+        super().__init__()
+
+        self.attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_self_attn")
+        self.attn_ln = nn.LayerNorm(n_state)
+
+        self.cross_attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention else None
+
+        self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
+
+        n_mlp = n_state * 4
+        self.mlp = nn.Sequential(
+            nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
+        )
+        self.mlp_ln = nn.LayerNorm(n_state)
+
+    def forward(
+        self,
+        x: Tensor,
+        xa: Optional[Tensor] = None,
+        mask: Optional[Tensor] = None,
+        kv_cache: Optional[dict] = None,
+    ):
+        # print("ResidualAttentionBlock forward", file=sys.stderr)
+        # print(x.shape, file=sys.stderr)
+        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
+        if self.cross_attn:
+            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
+        x = x + self.mlp(self.mlp_ln(x))
+        return x
+
+
+class AudioEncoder(nn.Module):
+    def __init__(
+        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+    ):
+        super().__init__()
+        self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
+        self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
+        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
+
+        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
+            [ResidualAttentionBlock(n_state, n_head, cache_id=f"enc_layer{i}") for i in range(n_layer)]
+        )
+        self.ln_post = nn.LayerNorm(n_state)
| 206 |
+
|
| 207 |
+
def forward(self, x: Tensor, return_layer_results: bool=False):
|
| 208 |
+
"""
|
| 209 |
+
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
|
| 210 |
+
the mel spectrogram of the audio
|
| 211 |
+
"""
|
| 212 |
+
|
| 213 |
+
x = F.gelu(self.conv1(x))
|
| 214 |
+
x = F.gelu(self.conv2(x))
|
| 215 |
+
x = x.permute(0, 2, 1) # BDT -> BTD
|
| 216 |
+
|
| 217 |
+
# 两层卷积,2倍降采样
|
| 218 |
+
# 最终剩下1500帧
|
| 219 |
+
|
| 220 |
+
x = (x + self.positional_embedding[:x.shape[1], :]) #.to(x.dtype)
|
| 221 |
+
|
| 222 |
+
layer_results = []
|
| 223 |
+
i = 0
|
| 224 |
+
for block in self.blocks:
|
| 225 |
+
# print(f"encoder layer {i}")
|
| 226 |
+
x = block(x)
|
| 227 |
+
layer_results.append(x)
|
| 228 |
+
i += 1
|
| 229 |
+
|
| 230 |
+
x = self.ln_post(x)
|
| 231 |
+
|
| 232 |
+
if return_layer_results:
|
| 233 |
+
return x, layer_results
|
| 234 |
+
else:
|
| 235 |
+
return x
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class TextDecoder(nn.Module):
|
| 239 |
+
def __init__(
|
| 240 |
+
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
| 241 |
+
):
|
| 242 |
+
super().__init__()
|
| 243 |
+
|
| 244 |
+
self.token_embedding = nn.Embedding(n_vocab, n_state)
|
| 245 |
+
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
|
| 246 |
+
|
| 247 |
+
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
| 248 |
+
[
|
| 249 |
+
ResidualAttentionBlock(n_state, n_head, cross_attention=True, cache_id=f"dec_layer{i}")
|
| 250 |
+
for i in range(n_layer)
|
| 251 |
+
]
|
| 252 |
+
)
|
| 253 |
+
self.ln = nn.LayerNorm(n_state)
|
| 254 |
+
|
| 255 |
+
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
|
| 256 |
+
self.register_buffer("mask", mask, persistent=False)
|
| 257 |
+
|
| 258 |
+
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
|
| 259 |
+
"""
|
| 260 |
+
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
|
| 261 |
+
the text tokens
|
| 262 |
+
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
|
| 263 |
+
the encoded audio features to be attended on
|
| 264 |
+
"""
|
| 265 |
+
|
| 266 |
+
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
| 267 |
+
x = (
|
| 268 |
+
self.token_embedding(x)
|
| 269 |
+
+ self.positional_embedding[offset : offset + x.shape[-1]]
|
| 270 |
+
)
|
| 271 |
+
# x = x.to(xa.dtype)
|
| 272 |
+
|
| 273 |
+
i = 0
|
| 274 |
+
for block in self.blocks:
|
| 275 |
+
# print(f"decoder layer {i}")
|
| 276 |
+
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
| 277 |
+
i += 1
|
| 278 |
+
|
| 279 |
+
x = self.ln(x)
|
| 280 |
+
logits = x @ torch.transpose(self.token_embedding.weight, 0, 1)
|
| 281 |
+
|
| 282 |
+
return logits
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class Whisper(nn.Module):
|
| 286 |
+
def __init__(self, dims: ModelDimensions):
|
| 287 |
+
super().__init__()
|
| 288 |
+
self.dims = dims
|
| 289 |
+
self.encoder = AudioEncoder(
|
| 290 |
+
self.dims.n_mels,
|
| 291 |
+
self.dims.n_audio_ctx,
|
| 292 |
+
self.dims.n_audio_state,
|
| 293 |
+
self.dims.n_audio_head,
|
| 294 |
+
self.dims.n_audio_layer,
|
| 295 |
+
)
|
| 296 |
+
self.decoder = TextDecoder(
|
| 297 |
+
self.dims.n_vocab,
|
| 298 |
+
self.dims.n_text_ctx,
|
| 299 |
+
self.dims.n_text_state,
|
| 300 |
+
self.dims.n_text_head,
|
| 301 |
+
self.dims.n_text_layer,
|
| 302 |
+
)
|
| 303 |
+
# use the last half layers for alignment by default; see `set_alignment_heads()` below
|
| 304 |
+
all_heads = torch.zeros(
|
| 305 |
+
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
|
| 306 |
+
)
|
| 307 |
+
all_heads[self.dims.n_text_layer // 2 :] = True
|
| 308 |
+
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
|
| 309 |
+
|
| 310 |
+
def set_alignment_heads(self, dump: bytes):
|
| 311 |
+
array = np.frombuffer(
|
| 312 |
+
gzip.decompress(base64.b85decode(dump)), dtype=bool
|
| 313 |
+
).copy()
|
| 314 |
+
mask = torch.from_numpy(array).reshape(
|
| 315 |
+
self.dims.n_text_layer, self.dims.n_text_head
|
| 316 |
+
)
|
| 317 |
+
self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
|
| 318 |
+
|
| 319 |
+
def embed_audio(self, mel: torch.Tensor):
|
| 320 |
+
return self.encoder(mel)
|
| 321 |
+
|
| 322 |
+
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
| 323 |
+
# tokens = tokens.to(self.decoder.ln.weight.dtype)
|
| 324 |
+
# audio_features = audio_features.to(self.decoder.ln.weight.dtype)
|
| 325 |
+
return self.decoder(tokens, audio_features)
|
| 326 |
+
|
| 327 |
+
def forward(
|
| 328 |
+
self, mel: torch.Tensor, tokens: torch.Tensor
|
| 329 |
+
) -> Dict[str, torch.Tensor]:
|
| 330 |
+
# mel = mel.to(self.decoder.ln.weight.dtype)
|
| 331 |
+
# tokens = tokens.to(self.decoder.ln.weight.dtype)
|
| 332 |
+
return self.decoder(tokens, self.encoder(mel))
|
| 333 |
+
|
| 334 |
+
@property
|
| 335 |
+
def device(self):
|
| 336 |
+
return next(self.parameters()).device
|
| 337 |
+
|
| 338 |
+
@property
|
| 339 |
+
def is_multilingual(self):
|
| 340 |
+
return self.dims.n_vocab >= 51865
|
| 341 |
+
|
| 342 |
+
@property
|
| 343 |
+
def num_languages(self):
|
| 344 |
+
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
|
| 345 |
+
|
| 346 |
+
# 为decoder加入缓存机制,每次推理时保存上次的k和v,下次推理无需重新计算
|
| 347 |
+
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
|
| 348 |
+
"""
|
| 349 |
+
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
|
| 350 |
+
tensors calculated for the previous positions. This method returns a dictionary that stores
|
| 351 |
+
all caches, and the necessary hooks for the key and value projection modules that save the
|
| 352 |
+
intermediate tensors to be reused during later calculations.
|
| 353 |
+
|
| 354 |
+
Returns
|
| 355 |
+
-------
|
| 356 |
+
cache : Dict[nn.Module, torch.Tensor]
|
| 357 |
+
A dictionary object mapping the key/value projection modules to its cache
|
| 358 |
+
hooks : List[RemovableHandle]
|
| 359 |
+
List of PyTorch RemovableHandle objects to stop the hooks to be called
|
| 360 |
+
"""
|
| 361 |
+
cache = {**cache} if cache is not None else {}
|
| 362 |
+
hooks = []
|
| 363 |
+
|
| 364 |
+
def save_to_cache(module, _, output):
|
| 365 |
+
if module not in cache or output.shape[1] > self.dims.n_text_ctx:
|
| 366 |
+
# save as-is, for the first token or cross attention
|
| 367 |
+
cache[module] = output
|
| 368 |
+
else:
|
| 369 |
+
cache[module] = torch.cat([cache[module], output], dim=1).detach()
|
| 370 |
+
return cache[module]
|
| 371 |
+
|
| 372 |
+
def install_hooks(layer: nn.Module):
|
| 373 |
+
if isinstance(layer, MultiHeadAttention):
|
| 374 |
+
hooks.append(layer.key.register_forward_hook(save_to_cache))
|
| 375 |
+
hooks.append(layer.value.register_forward_hook(save_to_cache))
|
| 376 |
+
|
| 377 |
+
self.decoder.apply(install_hooks)
|
| 378 |
+
return cache, hooks
|
| 379 |
+
|
| 380 |
+
detect_language = detect_language_function
|
| 381 |
+
transcribe = transcribe_function
|
| 382 |
+
decode = decode_function
|
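The hook-based cache above is what makes incremental decoding cheap: each step feeds only the newest token, and the forward hooks on the key/value projections concatenate its activations onto the cached ones. A minimal sketch of the intended call pattern, not part of the commit (`model`, `audio_features`, and `sot_token` are assumed to already exist):

    import torch

    cache, hooks = model.install_kv_cache_hooks()
    tokens = torch.tensor([[sot_token]])  # hypothetical start-of-transcript token id
    for _ in range(16):
        # feed only the most recent token; keys/values for earlier positions
        # are read back from `cache` by the forward hooks
        logits = model.decoder(tokens[:, -1:], audio_features, kv_cache=cache)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=-1)
    for hook in hooks:
        hook.remove()  # stop caching once decoding is finished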
simul_whisper/whisper/normalizers/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .basic import BasicTextNormalizer as BasicTextNormalizer
from .english import EnglishTextNormalizer as EnglishTextNormalizer
simul_whisper/whisper/normalizers/basic.py
ADDED
@@ -0,0 +1,76 @@
import re
import unicodedata

import regex

# non-ASCII letters that are not separated by "NFKD" normalization
ADDITIONAL_DIACRITICS = {
    "œ": "oe",
    "Œ": "OE",
    "ø": "o",
    "Ø": "O",
    "æ": "ae",
    "Æ": "AE",
    "ß": "ss",
    "ẞ": "SS",
    "đ": "d",
    "Đ": "D",
    "ð": "d",
    "Ð": "D",
    "þ": "th",
    "Þ": "th",
    "ł": "l",
    "Ł": "L",
}


def remove_symbols_and_diacritics(s: str, keep=""):
    """
    Replace any other markers, symbols, and punctuations with a space,
    and drop any diacritics (category 'Mn' and some manual mappings)
    """
    return "".join(
        c
        if c in keep
        else ADDITIONAL_DIACRITICS[c]
        if c in ADDITIONAL_DIACRITICS
        else ""
        if unicodedata.category(c) == "Mn"
        else " "
        if unicodedata.category(c)[0] in "MSP"
        else c
        for c in unicodedata.normalize("NFKD", s)
    )


def remove_symbols(s: str):
    """
    Replace any other markers, symbols, punctuations with a space, keeping diacritics
    """
    return "".join(
        " " if unicodedata.category(c)[0] in "MSP" else c
        for c in unicodedata.normalize("NFKC", s)
    )


class BasicTextNormalizer:
    def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
        self.clean = (
            remove_symbols_and_diacritics if remove_diacritics else remove_symbols
        )
        self.split_letters = split_letters

    def __call__(self, s: str):
        s = s.lower()
        s = re.sub(r"[<\[][^>\]]*[>\]]", "", s)  # remove words between brackets
        s = re.sub(r"\(([^)]+?)\)", "", s)  # remove words between parentheses
        s = self.clean(s).lower()

        if self.split_letters:
            s = " ".join(regex.findall(r"\X", s, regex.U))

        s = re.sub(
            r"\s+", " ", s
        )  # replace any successive whitespace characters with a space

        return s
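The two construction flags select between the NFKC-based `remove_symbols` (diacritics preserved) and the NFKD-based `remove_symbols_and_diacritics`. A quick illustration, not part of the commit (outputs shown modulo surrounding whitespace):

    from simul_whisper.whisper.normalizers import BasicTextNormalizer

    keep_marks = BasicTextNormalizer()                         # NFKC, keeps diacritics
    strip_marks = BasicTextNormalizer(remove_diacritics=True)  # NFKD, drops diacritics
    print(keep_marks("Héllo, wörld!"))   # -> "héllo wörld"
    print(strip_marks("Héllo, wörld!"))  # -> "hello world"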
simul_whisper/whisper/normalizers/english.py
ADDED
@@ -0,0 +1,550 @@
import json
import os
import re
from fractions import Fraction
from typing import Iterator, List, Match, Optional, Union

from more_itertools import windowed

from .basic import remove_symbols_and_diacritics


class EnglishNumberNormalizer:
    """
    Convert any spelled-out numbers into arabic numbers, while handling:

    - remove any commas
    - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc.
    - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars`
    - spell out `one` and `ones`
    - interpret successive single-digit numbers as nominal: `one oh one` -> `101`
    """

    def __init__(self):
        super().__init__()

        self.zeros = {"o", "oh", "zero"}
        self.ones = {
            name: i
            for i, name in enumerate(
                [
                    "one",
                    "two",
                    "three",
                    "four",
                    "five",
                    "six",
                    "seven",
                    "eight",
                    "nine",
                    "ten",
                    "eleven",
                    "twelve",
                    "thirteen",
                    "fourteen",
                    "fifteen",
                    "sixteen",
                    "seventeen",
                    "eighteen",
                    "nineteen",
                ],
                start=1,
            )
        }
        self.ones_plural = {
            "sixes" if name == "six" else name + "s": (value, "s")
            for name, value in self.ones.items()
        }
        self.ones_ordinal = {
            "zeroth": (0, "th"),
            "first": (1, "st"),
            "second": (2, "nd"),
            "third": (3, "rd"),
            "fifth": (5, "th"),
            "twelfth": (12, "th"),
            **{
                name + ("h" if name.endswith("t") else "th"): (value, "th")
                for name, value in self.ones.items()
                if value > 3 and value != 5 and value != 12
            },
        }
        self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}

        self.tens = {
            "twenty": 20,
            "thirty": 30,
            "forty": 40,
            "fifty": 50,
            "sixty": 60,
            "seventy": 70,
            "eighty": 80,
            "ninety": 90,
        }
        self.tens_plural = {
            name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
        }
        self.tens_ordinal = {
            name.replace("y", "ieth"): (value, "th")
            for name, value in self.tens.items()
        }
        self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}

        self.multipliers = {
            "hundred": 100,
            "thousand": 1_000,
            "million": 1_000_000,
            "billion": 1_000_000_000,
            "trillion": 1_000_000_000_000,
            "quadrillion": 1_000_000_000_000_000,
            "quintillion": 1_000_000_000_000_000_000,
            "sextillion": 1_000_000_000_000_000_000_000,
            "septillion": 1_000_000_000_000_000_000_000_000,
            "octillion": 1_000_000_000_000_000_000_000_000_000,
            "nonillion": 1_000_000_000_000_000_000_000_000_000_000,
            "decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
        }
        self.multipliers_plural = {
            name + "s": (value, "s") for name, value in self.multipliers.items()
        }
        self.multipliers_ordinal = {
            name + "th": (value, "th") for name, value in self.multipliers.items()
        }
        self.multipliers_suffixed = {
            **self.multipliers_plural,
            **self.multipliers_ordinal,
        }
        self.decimals = {*self.ones, *self.tens, *self.zeros}

        self.preceding_prefixers = {
            "minus": "-",
            "negative": "-",
            "plus": "+",
            "positive": "+",
        }
        self.following_prefixers = {
            "pound": "£",
            "pounds": "£",
            "euro": "€",
            "euros": "€",
            "dollar": "$",
            "dollars": "$",
            "cent": "¢",
            "cents": "¢",
        }
        self.prefixes = set(
            list(self.preceding_prefixers.values())
            + list(self.following_prefixers.values())
        )
        self.suffixers = {
            "per": {"cent": "%"},
            "percent": "%",
        }
        self.specials = {"and", "double", "triple", "point"}

        self.words = set(
            [
                key
                for mapping in [
                    self.zeros,
                    self.ones,
                    self.ones_suffixed,
                    self.tens,
                    self.tens_suffixed,
                    self.multipliers,
                    self.multipliers_suffixed,
                    self.preceding_prefixers,
                    self.following_prefixers,
                    self.suffixers,
                    self.specials,
                ]
                for key in mapping
            ]
        )
        self.literal_words = {"one", "ones"}

    def process_words(self, words: List[str]) -> Iterator[str]:
        prefix: Optional[str] = None
        value: Optional[Union[str, int]] = None
        skip = False

        def to_fraction(s: str):
            try:
                return Fraction(s)
            except ValueError:
                return None

        def output(result: Union[str, int]):
            nonlocal prefix, value
            result = str(result)
            if prefix is not None:
                result = prefix + result
            value = None
            prefix = None
            return result

        if len(words) == 0:
            return

        for prev, current, next in windowed([None] + words + [None], 3):
            if skip:
                skip = False
                continue

            next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
            has_prefix = current[0] in self.prefixes
            current_without_prefix = current[1:] if has_prefix else current
            if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
                # arabic numbers (potentially with signs and fractions)
                f = to_fraction(current_without_prefix)
                assert f is not None
                if value is not None:
                    if isinstance(value, str) and value.endswith("."):
                        # concatenate decimals / ip address components
                        value = str(value) + str(current)
                        continue
                    else:
                        yield output(value)

                prefix = current[0] if has_prefix else prefix
                if f.denominator == 1:
                    value = f.numerator  # store integers as int
                else:
                    value = current_without_prefix
            elif current not in self.words:
                # non-numeric words
                if value is not None:
                    yield output(value)
                yield output(current)
            elif current in self.zeros:
                value = str(value or "") + "0"
            elif current in self.ones:
                ones = self.ones[current]

                if value is None:
                    value = ones
                elif isinstance(value, str) or prev in self.ones:
                    if (
                        prev in self.tens and ones < 10
                    ):  # replace the last zero with the digit
                        assert value[-1] == "0"
                        value = value[:-1] + str(ones)
                    else:
                        value = str(value) + str(ones)
                elif ones < 10:
                    if value % 10 == 0:
                        value += ones
                    else:
                        value = str(value) + str(ones)
                else:  # eleven to nineteen
                    if value % 100 == 0:
                        value += ones
                    else:
                        value = str(value) + str(ones)
            elif current in self.ones_suffixed:
                # ordinal or cardinal; yield the number right away
                ones, suffix = self.ones_suffixed[current]
                if value is None:
                    yield output(str(ones) + suffix)
                elif isinstance(value, str) or prev in self.ones:
                    if prev in self.tens and ones < 10:
                        assert value[-1] == "0"
                        yield output(value[:-1] + str(ones) + suffix)
                    else:
                        yield output(str(value) + str(ones) + suffix)
                elif ones < 10:
                    if value % 10 == 0:
                        yield output(str(value + ones) + suffix)
                    else:
                        yield output(str(value) + str(ones) + suffix)
                else:  # eleven to nineteen
                    if value % 100 == 0:
                        yield output(str(value + ones) + suffix)
                    else:
                        yield output(str(value) + str(ones) + suffix)
                value = None
            elif current in self.tens:
                tens = self.tens[current]
                if value is None:
                    value = tens
                elif isinstance(value, str):
                    value = str(value) + str(tens)
                else:
                    if value % 100 == 0:
                        value += tens
                    else:
                        value = str(value) + str(tens)
            elif current in self.tens_suffixed:
                # ordinal or cardinal; yield the number right away
                tens, suffix = self.tens_suffixed[current]
                if value is None:
                    yield output(str(tens) + suffix)
                elif isinstance(value, str):
                    yield output(str(value) + str(tens) + suffix)
                else:
                    if value % 100 == 0:
                        yield output(str(value + tens) + suffix)
                    else:
                        yield output(str(value) + str(tens) + suffix)
            elif current in self.multipliers:
                multiplier = self.multipliers[current]
                if value is None:
                    value = multiplier
                elif isinstance(value, str) or value == 0:
                    f = to_fraction(value)
                    p = f * multiplier if f is not None else None
                    if f is not None and p.denominator == 1:
                        value = p.numerator
                    else:
                        yield output(value)
                        value = multiplier
                else:
                    before = value // 1000 * 1000
                    residual = value % 1000
                    value = before + residual * multiplier
            elif current in self.multipliers_suffixed:
                multiplier, suffix = self.multipliers_suffixed[current]
                if value is None:
                    yield output(str(multiplier) + suffix)
                elif isinstance(value, str):
                    f = to_fraction(value)
                    p = f * multiplier if f is not None else None
                    if f is not None and p.denominator == 1:
                        yield output(str(p.numerator) + suffix)
                    else:
                        yield output(value)
                        yield output(str(multiplier) + suffix)
                else:  # int
                    before = value // 1000 * 1000
                    residual = value % 1000
                    value = before + residual * multiplier
                    yield output(str(value) + suffix)
                value = None
            elif current in self.preceding_prefixers:
                # apply prefix (positive, minus, etc.) if it precedes a number
                if value is not None:
                    yield output(value)

                if next in self.words or next_is_numeric:
                    prefix = self.preceding_prefixers[current]
                else:
                    yield output(current)
            elif current in self.following_prefixers:
                # apply prefix (dollars, cents, etc.) only after a number
                if value is not None:
                    prefix = self.following_prefixers[current]
                    yield output(value)
                else:
                    yield output(current)
            elif current in self.suffixers:
                # apply suffix symbols (percent -> '%')
                if value is not None:
                    suffix = self.suffixers[current]
                    if isinstance(suffix, dict):
                        if next in suffix:
                            yield output(str(value) + suffix[next])
                            skip = True
                        else:
                            yield output(value)
                            yield output(current)
                    else:
                        yield output(str(value) + suffix)
                else:
                    yield output(current)
            elif current in self.specials:
                if next not in self.words and not next_is_numeric:
                    # apply special handling only if the next word can be numeric
                    if value is not None:
                        yield output(value)
                    yield output(current)
                elif current == "and":
                    # ignore "and" after hundreds, thousands, etc.
                    if prev not in self.multipliers:
                        if value is not None:
                            yield output(value)
                        yield output(current)
                elif current == "double" or current == "triple":
                    if next in self.ones or next in self.zeros:
                        repeats = 2 if current == "double" else 3
                        ones = self.ones.get(next, 0)
                        value = str(value or "") + str(ones) * repeats
                        skip = True
                    else:
                        if value is not None:
                            yield output(value)
                        yield output(current)
                elif current == "point":
                    if next in self.decimals or next_is_numeric:
                        value = str(value or "") + "."
                else:
                    # should all have been covered at this point
                    raise ValueError(f"Unexpected token: {current}")
            else:
                # all should have been covered at this point
                raise ValueError(f"Unexpected token: {current}")

        if value is not None:
            yield output(value)

    def preprocess(self, s: str):
        # replace "<number> and a half" with "<number> point five"
        results = []

        segments = re.split(r"\band\s+a\s+half\b", s)
        for i, segment in enumerate(segments):
            if len(segment.strip()) == 0:
                continue
            if i == len(segments) - 1:
                results.append(segment)
            else:
                results.append(segment)
                last_word = segment.rsplit(maxsplit=2)[-1]
                if last_word in self.decimals or last_word in self.multipliers:
                    results.append("point five")
                else:
                    results.append("and a half")

        s = " ".join(results)

        # put a space at number/letter boundary
        s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
        s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)

        # but remove spaces which could be a suffix
        s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)

        return s

    def postprocess(self, s: str):
        def combine_cents(m: Match):
            try:
                currency = m.group(1)
                integer = m.group(2)
                cents = int(m.group(3))
                return f"{currency}{integer}.{cents:02d}"
            except ValueError:
                return m.string

        def extract_cents(m: Match):
            try:
                return f"¢{int(m.group(1))}"
            except ValueError:
                return m.string

        # apply currency postprocessing; "$2 and ¢7" -> "$2.07"
        s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
        s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s)

        # write "one(s)" instead of "1(s)", just for the readability
        s = re.sub(r"\b1(s?)\b", r"one\1", s)

        return s

    def __call__(self, s: str):
        s = self.preprocess(s)
        s = " ".join(word for word in self.process_words(s.split()) if word is not None)
        s = self.postprocess(s)

        return s


class EnglishSpellingNormalizer:
    """
    Applies British-American spelling mappings as listed in [1].

    [1] https://www.tysto.com/uk-us-spelling-list.html
    """

    def __init__(self):
        mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
        self.mapping = json.load(open(mapping_path))

    def __call__(self, s: str):
        return " ".join(self.mapping.get(word, word) for word in s.split())


class EnglishTextNormalizer:
    def __init__(self):
        self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
        self.replacers = {
            # common contractions
            r"\bwon't\b": "will not",
            r"\bcan't\b": "can not",
            r"\blet's\b": "let us",
            r"\bain't\b": "aint",
            r"\by'all\b": "you all",
            r"\bwanna\b": "want to",
            r"\bgotta\b": "got to",
            r"\bgonna\b": "going to",
            r"\bi'ma\b": "i am going to",
            r"\bimma\b": "i am going to",
            r"\bwoulda\b": "would have",
            r"\bcoulda\b": "could have",
            r"\bshoulda\b": "should have",
            r"\bma'am\b": "madam",
            # contractions in titles/prefixes
            r"\bmr\b": "mister ",
            r"\bmrs\b": "missus ",
            r"\bst\b": "saint ",
            r"\bdr\b": "doctor ",
            r"\bprof\b": "professor ",
            r"\bcapt\b": "captain ",
            r"\bgov\b": "governor ",
            r"\bald\b": "alderman ",
            r"\bgen\b": "general ",
            r"\bsen\b": "senator ",
            r"\brep\b": "representative ",
            r"\bpres\b": "president ",
            r"\brev\b": "reverend ",
            r"\bhon\b": "honorable ",
            r"\basst\b": "assistant ",
            r"\bassoc\b": "associate ",
            r"\blt\b": "lieutenant ",
            r"\bcol\b": "colonel ",
            r"\bjr\b": "junior ",
            r"\bsr\b": "senior ",
            r"\besq\b": "esquire ",
            # perfect tenses; ideally this would cover any past participle, but that's harder..
            r"'d been\b": " had been",
            r"'s been\b": " has been",
            r"'d gone\b": " had gone",
            r"'s gone\b": " has gone",
            r"'d done\b": " had done",  # "'s done" is ambiguous
            r"'s got\b": " has got",
            # general contractions
            r"n't\b": " not",
            r"'re\b": " are",
            r"'s\b": " is",
            r"'d\b": " would",
            r"'ll\b": " will",
            r"'t\b": " not",
            r"'ve\b": " have",
            r"'m\b": " am",
        }
        self.standardize_numbers = EnglishNumberNormalizer()
        self.standardize_spellings = EnglishSpellingNormalizer()

    def __call__(self, s: str):
        s = s.lower()

        s = re.sub(r"[<\[][^>\]]*[>\]]", "", s)  # remove words between brackets
        s = re.sub(r"\(([^)]+?)\)", "", s)  # remove words between parentheses
        s = re.sub(self.ignore_patterns, "", s)
        s = re.sub(r"\s+'", "'", s)  # when there's a space before an apostrophe

        for pattern, replacement in self.replacers.items():
            s = re.sub(pattern, replacement, s)

        s = re.sub(r"(\d),(\d)", r"\1\2", s)  # remove commas between digits
        s = re.sub(r"\.([^0-9]|$)", r" \1", s)  # remove periods not followed by numbers
        s = remove_symbols_and_diacritics(s, keep=".%$¢€£")  # keep numeric symbols

        s = self.standardize_numbers(s)
        s = self.standardize_spellings(s)

        # now remove prefix/suffix symbols that are not preceded/followed by numbers
        s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
        s = re.sub(r"([^0-9])%", r"\1 ", s)

        s = re.sub(r"\s+", " ", s)  # replace any successive whitespace characters with a space

        return s
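Of the classes above, `EnglishNumberNormalizer` is self-contained, while `EnglishTextNormalizer` additionally requires the `english.json` spelling table next to the module (not shown in this commit). A small demonstration of the number handling, not part of the commit:

    from simul_whisper.whisper.normalizers.english import EnglishNumberNormalizer

    norm = EnglishNumberNormalizer()
    print(norm("twenty dollars and seven cents"))  # -> "$20.07" (currency merged in postprocess)
    print(norm("one hundred and first"))           # -> "101st" ("and" absorbed after a multiplier)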
simul_whisper/whisper/timing.py
ADDED
@@ -0,0 +1,401 @@
import itertools
import subprocess
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, List

import numba
import numpy as np
import torch
import torch.nn.functional as F

from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
from .tokenizer import Tokenizer

if TYPE_CHECKING:
    from .model import Whisper


def median_filter(x: torch.Tensor, filter_width: int):
    """Apply a median filter of width `filter_width` along the last dimension of `x`"""
    pad_width = filter_width // 2
    if x.shape[-1] <= pad_width:
        # F.pad requires the padding width to be smaller than the input dimension
        return x

    if (ndim := x.ndim) <= 2:
        # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
        x = x[None, None, :]

    assert (
        filter_width > 0 and filter_width % 2 == 1
    ), "`filter_width` should be an odd number"

    result = None
    x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
    if x.is_cuda:
        try:
            from .triton_ops import median_filter_cuda

            result = median_filter_cuda(x, filter_width)
        except (RuntimeError, subprocess.CalledProcessError):
            warnings.warn(
                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
                "falling back to a slower median kernel implementation..."
            )

    if result is None:
        # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
        result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]

    if ndim <= 2:
        result = result[0, 0]

    return result


@numba.jit(nopython=True)
def backtrace(trace: np.ndarray):
    i = trace.shape[0] - 1  # trace: (N+1, M+1), i=N
    j = trace.shape[1] - 1  # j=M
    # the boundary cells are effectively meaningless?
    trace[0, :] = 2
    trace[:, 0] = 1

    result = []
    while i > 0 or j > 0:
        result.append((i - 1, j - 1))

        if trace[i, j] == 0:
            i -= 1
            j -= 1
        elif trace[i, j] == 1:
            i -= 1
        elif trace[i, j] == 2:
            j -= 1
        else:
            raise ValueError("Unexpected trace[i, j]")

    result = np.array(result)
    return result[::-1, :].T


@numba.jit(nopython=True, parallel=True)
def dtw_cpu(x: np.ndarray):
    N, M = x.shape
    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf  # cost[i, j]: minimal cost from x[0, 0] to x[i - 1, j - 1]
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)  # trace: backpointers (0 = diagonal, 1 = up, 2 = left)

    cost[0, 0] = 0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            c0 = cost[i - 1, j - 1]
            c1 = cost[i - 1, j]
            c2 = cost[i, j - 1]

            if c0 < c1 and c0 < c2:
                c, t = c0, 0
            elif c1 < c0 and c1 < c2:
                c, t = c1, 1
            else:
                c, t = c2, 2

            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t

    return backtrace(trace)


def dtw_cuda(x, BLOCK_SIZE=1024):
    from .triton_ops import dtw_kernel

    M, N = x.shape
    assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"

    x_skew = (
        F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
    )
    x_skew = x_skew.T.contiguous()
    cost = torch.ones(N + M + 2, M + 2) * np.inf
    cost[0, 0] = 0
    cost = cost.cuda()
    trace = torch.zeros_like(cost, dtype=torch.int32)

    dtw_kernel[(1,)](
        cost,
        trace,
        x_skew,
        x_skew.stride(0),
        cost.stride(0),
        trace.stride(0),
        N,
        M,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
        :, : N + 1
    ]
    return backtrace(trace.cpu().numpy())


def dtw(x: torch.Tensor) -> np.ndarray:
    if x.is_cuda:
        try:
            return dtw_cuda(x)
        except (RuntimeError, subprocess.CalledProcessError):
            warnings.warn(
                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
                "falling back to a slower DTW implementation..."
            )

    return dtw_cpu(x.double().cpu().numpy())


@dataclass
class WordTiming:
    word: str
    tokens: List[int]
    start: float
    end: float
    probability: float


def find_alignment(
    model: "Whisper",
    tokenizer: Tokenizer,
    text_tokens: List[int],
    mel: torch.Tensor,
    num_frames: int,
    *,
    medfilt_width: int = 7,
    qk_scale: float = 1.0,
) -> List[WordTiming]:
    if len(text_tokens) == 0:
        return []

    tokens = torch.tensor(
        [
            *tokenizer.sot_sequence,
            tokenizer.no_timestamps,
            *text_tokens,
            tokenizer.eot,
        ]
    ).to(model.device)

    # install hooks on the cross attention layers to retrieve the attention weights
    QKs = [None] * model.dims.n_text_layer
    hooks = [
        block.cross_attn.register_forward_hook(
            lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
        )
        for i, block in enumerate(model.decoder.blocks)
    ]

    # run a forward pass to obtain the token probabilities
    with torch.no_grad():
        logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
        sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
        token_probs = sampled_logits.softmax(dim=-1)
        text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
        text_token_probs = text_token_probs.tolist()

    # remove the hooks
    for hook in hooks:
        hook.remove()

    # heads * tokens * frames
    weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
    weights = weights[:, :, : num_frames // 2]
    weights = (weights * qk_scale).softmax(dim=-1)
    std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
    weights = (weights - mean) / std
    weights = median_filter(weights, medfilt_width)

    matrix = weights.mean(axis=0)
    print("attention", matrix.shape, matrix[:5, :5])
    matrix = matrix[len(tokenizer.sot_sequence) : -1]
    print("attention", matrix.shape, matrix[:5, :5])
    text_indices, time_indices = dtw(-matrix)

    print("num_frames", num_frames)
    print("attention", matrix.shape, matrix[:5, :5])
    print("text_indices", text_indices)
    print("time", time_indices)
    print("text_tokens", text_tokens, tokenizer.decode(text_tokens), len(text_tokens))
    print("eot", tokenizer.eot)

    words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
    if len(word_tokens) <= 1:
        # return on eot only
        # >>> np.pad([], (1, 0))
        # array([0.])
        # This results in crashes when we lookup jump_times with float, like
        # IndexError: arrays used as indices must be of integer (or boolean) type
        return []
    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))

    jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
    jump_times = time_indices[jumps] / TOKENS_PER_SECOND
    start_times = jump_times[word_boundaries[:-1]]
    end_times = jump_times[word_boundaries[1:]]
    word_probabilities = [
        np.mean(text_token_probs[i:j])
        for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
    ]

    return [
        WordTiming(word, tokens, start, end, probability)
        for word, tokens, start, end, probability in zip(
            words, word_tokens, start_times, end_times, word_probabilities
        )
    ]


def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
    # merge prepended punctuations
    i = len(alignment) - 2
    j = len(alignment) - 1
    while i >= 0:
        previous = alignment[i]
        following = alignment[j]
        if previous.word.startswith(" ") and previous.word.strip() in prepended:
            # prepend it to the following word
            following.word = previous.word + following.word
            following.tokens = previous.tokens + following.tokens
            previous.word = ""
            previous.tokens = []
        else:
            j = i
        i -= 1

    # merge appended punctuations
    i = 0
    j = 1
    while j < len(alignment):
        previous = alignment[i]
        following = alignment[j]
        if not previous.word.endswith(" ") and following.word in appended:
            # append it to the previous word
            previous.word = previous.word + following.word
            previous.tokens = previous.tokens + following.tokens
            following.word = ""
            following.tokens = []
        else:
            i = j
        j += 1


def add_word_timestamps(
    *,
    segments: List[dict],
    model: "Whisper",
    tokenizer: Tokenizer,
    mel: torch.Tensor,
    num_frames: int,
    prepend_punctuations: str = "\"'“¿([{-",
    append_punctuations: str = "\"'.。,,!!??::”)]}、",
    last_speech_timestamp: float,
    **kwargs,
):
    if len(segments) == 0:
        return

    text_tokens_per_segment = [
        [token for token in segment["tokens"] if token < tokenizer.eot]
        for segment in segments
    ]

    text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
    alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
    word_durations = np.array([t.end - t.start for t in alignment])
    word_durations = word_durations[word_durations.nonzero()]
    median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
    max_duration = median_duration * 2

    # hack: truncate long words at sentence boundaries.
    # a better segmentation algorithm based on VAD should be able to replace this.
    if len(word_durations) > 0:
        sentence_end_marks = ".。!!??"
        # ensure words at sentence boundaries are not longer than twice the median word duration.
        for i in range(1, len(alignment)):
            if alignment[i].end - alignment[i].start > max_duration:
                if alignment[i].word in sentence_end_marks:
                    alignment[i].end = alignment[i].start + max_duration
                elif alignment[i - 1].word in sentence_end_marks:
                    alignment[i].start = alignment[i].end - max_duration

    merge_punctuations(alignment, prepend_punctuations, append_punctuations)

    time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
    word_index = 0

    for segment, text_tokens in zip(segments, text_tokens_per_segment):
        saved_tokens = 0
        words = []

        while word_index < len(alignment) and saved_tokens < len(text_tokens):
            timing = alignment[word_index]

            if timing.word:
                words.append(
                    dict(
                        word=timing.word,
                        start=round(time_offset + timing.start, 2),
                        end=round(time_offset + timing.end, 2),
                        probability=timing.probability,
                    )
                )

            saved_tokens += len(timing.tokens)
            word_index += 1

        # hack: truncate long words at segment boundaries.
        # a better segmentation algorithm based on VAD should be able to replace this.
        if len(words) > 0:
            # ensure the first and second word after a pause is not longer than
            # twice the median word duration.
            if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
                words[0]["end"] - words[0]["start"] > max_duration
                or (
                    len(words) > 1
                    and words[1]["end"] - words[0]["start"] > max_duration * 2
                )
            ):
                if (
                    len(words) > 1
                    and words[1]["end"] - words[1]["start"] > max_duration
                ):
                    boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
                    words[0]["end"] = words[1]["start"] = boundary
                words[0]["start"] = max(0, words[0]["end"] - max_duration)

            # prefer the segment-level start timestamp if the first word is too long.
            if (
                segment["start"] < words[0]["end"]
                and segment["start"] - 0.5 > words[0]["start"]
            ):
                words[0]["start"] = max(
                    0, min(words[0]["end"] - median_duration, segment["start"])
                )
            else:
                segment["start"] = words[0]["start"]

            # prefer the segment-level end timestamp if the last word is too long.
            if (
                segment["end"] > words[-1]["start"]
                and segment["end"] + 0.5 < words[-1]["end"]
            ):
                words[-1]["end"] = max(
                    words[-1]["start"] + median_duration, segment["end"]
                )
            else:
                segment["end"] = words[-1]["end"]

            last_speech_timestamp = segment["end"]

        segment["words"] = words
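A toy check of the CPU dynamic-time-warping path, not part of the commit (assumes numba is installed): `dtw_cpu` returns the monotone (text_index, time_index) alignment path through a cost matrix, here a 3x4 matrix with an obvious diagonal.

    import numpy as np
    from simul_whisper.whisper.timing import dtw_cpu

    cost = np.array([
        [0.0, 1.0, 1.0, 1.0],
        [1.0, 0.0, 1.0, 1.0],
        [1.0, 1.0, 0.0, 0.0],
    ])
    text_idx, time_idx = dtw_cpu(cost)
    print(text_idx)  # [0 1 2 2] -- the third token spans the last two frames
    print(time_idx)  # [0 1 2 3]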
simul_whisper/whisper/tokenizer.py
ADDED
@@ -0,0 +1,395 @@
| 1 | + import base64
| 2 | + import os
| 3 | + import string
| 4 | + from dataclasses import dataclass, field
| 5 | + from functools import cached_property, lru_cache
| 6 | + from typing import Dict, List, Optional, Tuple
| 7 | +
| 8 | + import tiktoken
| 9 | +
| 10 | + LANGUAGES = {
| 11 | +     "en": "english",
| 12 | +     "zh": "chinese",
| 13 | +     "de": "german",
| 14 | +     "es": "spanish",
| 15 | +     "ru": "russian",
| 16 | +     "ko": "korean",
| 17 | +     "fr": "french",
| 18 | +     "ja": "japanese",
| 19 | +     "pt": "portuguese",
| 20 | +     "tr": "turkish",
| 21 | +     "pl": "polish",
| 22 | +     "ca": "catalan",
| 23 | +     "nl": "dutch",
| 24 | +     "ar": "arabic",
| 25 | +     "sv": "swedish",
| 26 | +     "it": "italian",
| 27 | +     "id": "indonesian",
| 28 | +     "hi": "hindi",
| 29 | +     "fi": "finnish",
| 30 | +     "vi": "vietnamese",
| 31 | +     "he": "hebrew",
| 32 | +     "uk": "ukrainian",
| 33 | +     "el": "greek",
| 34 | +     "ms": "malay",
| 35 | +     "cs": "czech",
| 36 | +     "ro": "romanian",
| 37 | +     "da": "danish",
| 38 | +     "hu": "hungarian",
| 39 | +     "ta": "tamil",
| 40 | +     "no": "norwegian",
| 41 | +     "th": "thai",
| 42 | +     "ur": "urdu",
| 43 | +     "hr": "croatian",
| 44 | +     "bg": "bulgarian",
| 45 | +     "lt": "lithuanian",
| 46 | +     "la": "latin",
| 47 | +     "mi": "maori",
| 48 | +     "ml": "malayalam",
| 49 | +     "cy": "welsh",
| 50 | +     "sk": "slovak",
| 51 | +     "te": "telugu",
| 52 | +     "fa": "persian",
| 53 | +     "lv": "latvian",
| 54 | +     "bn": "bengali",
| 55 | +     "sr": "serbian",
| 56 | +     "az": "azerbaijani",
| 57 | +     "sl": "slovenian",
| 58 | +     "kn": "kannada",
| 59 | +     "et": "estonian",
| 60 | +     "mk": "macedonian",
| 61 | +     "br": "breton",
| 62 | +     "eu": "basque",
| 63 | +     "is": "icelandic",
| 64 | +     "hy": "armenian",
| 65 | +     "ne": "nepali",
| 66 | +     "mn": "mongolian",
| 67 | +     "bs": "bosnian",
| 68 | +     "kk": "kazakh",
| 69 | +     "sq": "albanian",
| 70 | +     "sw": "swahili",
| 71 | +     "gl": "galician",
| 72 | +     "mr": "marathi",
| 73 | +     "pa": "punjabi",
| 74 | +     "si": "sinhala",
| 75 | +     "km": "khmer",
| 76 | +     "sn": "shona",
| 77 | +     "yo": "yoruba",
| 78 | +     "so": "somali",
| 79 | +     "af": "afrikaans",
| 80 | +     "oc": "occitan",
| 81 | +     "ka": "georgian",
| 82 | +     "be": "belarusian",
| 83 | +     "tg": "tajik",
| 84 | +     "sd": "sindhi",
| 85 | +     "gu": "gujarati",
| 86 | +     "am": "amharic",
| 87 | +     "yi": "yiddish",
| 88 | +     "lo": "lao",
| 89 | +     "uz": "uzbek",
| 90 | +     "fo": "faroese",
| 91 | +     "ht": "haitian creole",
| 92 | +     "ps": "pashto",
| 93 | +     "tk": "turkmen",
| 94 | +     "nn": "nynorsk",
| 95 | +     "mt": "maltese",
| 96 | +     "sa": "sanskrit",
| 97 | +     "lb": "luxembourgish",
| 98 | +     "my": "myanmar",
| 99 | +     "bo": "tibetan",
| 100 | +     "tl": "tagalog",
| 101 | +     "mg": "malagasy",
| 102 | +     "as": "assamese",
| 103 | +     "tt": "tatar",
| 104 | +     "haw": "hawaiian",
| 105 | +     "ln": "lingala",
| 106 | +     "ha": "hausa",
| 107 | +     "ba": "bashkir",
| 108 | +     "jw": "javanese",
| 109 | +     "su": "sundanese",
| 110 | +     "yue": "cantonese",
| 111 | + }
| 112 | +
| 113 | + # language code lookup by name, with a few language aliases
| 114 | + TO_LANGUAGE_CODE = {
| 115 | +     **{language: code for code, language in LANGUAGES.items()},
| 116 | +     "burmese": "my",
| 117 | +     "valencian": "ca",
| 118 | +     "flemish": "nl",
| 119 | +     "haitian": "ht",
| 120 | +     "letzeburgesch": "lb",
| 121 | +     "pushto": "ps",
| 122 | +     "panjabi": "pa",
| 123 | +     "moldavian": "ro",
| 124 | +     "moldovan": "ro",
| 125 | +     "sinhalese": "si",
| 126 | +     "castilian": "es",
| 127 | +     "mandarin": "zh",
| 128 | + }
| 129 | +
| 130 | +
| 131 | + @dataclass
| 132 | + class Tokenizer:
| 133 | +     """A thin wrapper around `tiktoken` providing quick access to special tokens"""
| 134 | +
| 135 | +     encoding: tiktoken.Encoding
| 136 | +     num_languages: int
| 137 | +     language: Optional[str] = None
| 138 | +     task: Optional[str] = None
| 139 | +     sot_sequence: Tuple[int] = ()
| 140 | +     special_tokens: Dict[str, int] = field(default_factory=dict)
| 141 | +
| 142 | +     def __post_init__(self):
| 143 | +         for special in self.encoding.special_tokens_set:
| 144 | +             special_token = self.encoding.encode_single_token(special)
| 145 | +             self.special_tokens[special] = special_token
| 146 | +
| 147 | +         sot: int = self.special_tokens["<|startoftranscript|>"]
| 148 | +         translate: int = self.special_tokens["<|translate|>"]
| 149 | +         transcribe: int = self.special_tokens["<|transcribe|>"]
| 150 | +
| 151 | +         langs = tuple(LANGUAGES.keys())[: self.num_languages]
| 152 | +         sot_sequence = [sot]
| 153 | +         if self.language is not None:
| 154 | +             sot_sequence.append(sot + 1 + langs.index(self.language))
| 155 | +         if self.task is not None:
| 156 | +             task_token: int = transcribe if self.task == "transcribe" else translate
| 157 | +             sot_sequence.append(task_token)
| 158 | +
| 159 | +         self.sot_sequence = tuple(sot_sequence)
| 160 | +
| 161 | +     def encode(self, text, **kwargs):
| 162 | +         return self.encoding.encode(text, **kwargs)
| 163 | +
| 164 | +     def decode(self, token_ids: List[int], **kwargs) -> str:
| 165 | +         token_ids = [t for t in token_ids if t < self.timestamp_begin]
| 166 | +         return self.encoding.decode(token_ids, **kwargs)
| 167 | +
| 168 | +     def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
| 169 | +         """
| 170 | +         Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
| 171 | +         This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
| 172 | +         """
| 173 | +         return self.encoding.decode(token_ids, **kwargs)
| 174 | +
| 175 | +     @cached_property
| 176 | +     def eot(self) -> int:
| 177 | +         return self.encoding.eot_token
| 178 | +
| 179 | +     @cached_property
| 180 | +     def transcribe(self) -> int:
| 181 | +         return self.special_tokens["<|transcribe|>"]
| 182 | +
| 183 | +     @cached_property
| 184 | +     def translate(self) -> int:
| 185 | +         return self.special_tokens["<|translate|>"]
| 186 | +
| 187 | +     @cached_property
| 188 | +     def sot(self) -> int:
| 189 | +         return self.special_tokens["<|startoftranscript|>"]
| 190 | +
| 191 | +     @cached_property
| 192 | +     def sot_lm(self) -> int:
| 193 | +         return self.special_tokens["<|startoflm|>"]
| 194 | +
| 195 | +     @cached_property
| 196 | +     def sot_prev(self) -> int:
| 197 | +         return self.special_tokens["<|startofprev|>"]
| 198 | +
| 199 | +     @cached_property
| 200 | +     def no_speech(self) -> int:
| 201 | +         return self.special_tokens["<|nospeech|>"]
| 202 | +
| 203 | +     @cached_property
| 204 | +     def no_timestamps(self) -> int:
| 205 | +         return self.special_tokens["<|notimestamps|>"]
| 206 | +
| 207 | +     @cached_property
| 208 | +     def timestamp_begin(self) -> int:
| 209 | +         return self.special_tokens["<|0.00|>"]
| 210 | +
| 211 | +     @cached_property
| 212 | +     def language_token(self) -> int:
| 213 | +         """Returns the token id corresponding to the value of the `language` field"""
| 214 | +         if self.language is None:
| 215 | +             raise ValueError("This tokenizer does not have language token configured")
| 216 | +
| 217 | +         return self.to_language_token(self.language)
| 218 | +
| 219 | +     def to_language_token(self, language):
| 220 | +         if token := self.special_tokens.get(f"<|{language}|>", None):
| 221 | +             return token
| 222 | +
| 223 | +         raise KeyError(f"Language {language} not found in tokenizer.")
| 224 | +
| 225 | +     @cached_property
| 226 | +     def all_language_tokens(self) -> Tuple[int]:
| 227 | +         result = []
| 228 | +         for token, token_id in self.special_tokens.items():
| 229 | +             if token.strip("<|>") in LANGUAGES:
| 230 | +                 result.append(token_id)
| 231 | +         return tuple(result)[: self.num_languages]
| 232 | +
| 233 | +     @cached_property
| 234 | +     def all_language_codes(self) -> Tuple[str]:
| 235 | +         return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
| 236 | +
| 237 | +     @cached_property
| 238 | +     def sot_sequence_including_notimestamps(self) -> Tuple[int]:
| 239 | +         return tuple(list(self.sot_sequence) + [self.no_timestamps])
| 240 | +
| 241 | +     @cached_property
| 242 | +     def non_speech_tokens(self) -> Tuple[int]:
| 243 | +         """
| 244 | +         Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
| 245 | +         annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
| 246 | +
| 247 | +         - ♪♪♪
| 248 | +         - ( SPEAKING FOREIGN LANGUAGE )
| 249 | +         - [DAVID] Hey there,
| 250 | +
| 251 | +         keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
| 252 | +         """
| 253 | +         symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
| 254 | +         symbols += (
| 255 | +             "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
| 256 | +         )
| 257 | +
| 258 | +         # symbols that may be a single token or multiple tokens depending on the tokenizer.
| 259 | +         # In case they're multiple tokens, suppress the first token, which is safe because:
| 260 | +         # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
| 261 | +         # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
| 262 | +         miscellaneous = set("♩♪♫♬♭♮♯")
| 263 | +         assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
| 264 | +
| 265 | +         # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
| 266 | +         result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
| 267 | +         for symbol in symbols + list(miscellaneous):
| 268 | +             for tokens in [
| 269 | +                 self.encoding.encode(symbol),
| 270 | +                 self.encoding.encode(" " + symbol),
| 271 | +             ]:
| 272 | +                 if len(tokens) == 1 or symbol in miscellaneous:
| 273 | +                     result.add(tokens[0])
| 274 | +
| 275 | +         return tuple(sorted(result))
| 276 | +
| 277 | +     def split_to_word_tokens(self, tokens: List[int]):
| 278 | +         if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
| 279 | +             # These languages don't typically use spaces, so it is difficult to split words
| 280 | +             # without morpheme analysis. Here, we instead split words at any
| 281 | +             # position where the tokens are decoded as valid unicode points
| 282 | +             return self.split_tokens_on_unicode(tokens)
| 283 | +
| 284 | +         return self.split_tokens_on_spaces(tokens)
| 285 | +
| 286 | +     def split_tokens_on_unicode(self, tokens: List[int]):
| 287 | +         decoded_full = self.decode_with_timestamps(tokens)
| 288 | +         replacement_char = "\ufffd"
| 289 | +
| 290 | +         words = []
| 291 | +         word_tokens = []
| 292 | +         current_tokens = []
| 293 | +         unicode_offset = 0
| 294 | +
| 295 | +         for token in tokens:
| 296 | +             current_tokens.append(token)
| 297 | +             decoded = self.decode_with_timestamps(current_tokens)
| 298 | +
| 299 | +             if (
| 300 | +                 replacement_char not in decoded
| 301 | +                 or decoded_full[unicode_offset + decoded.index(replacement_char)]
| 302 | +                 == replacement_char
| 303 | +             ):
| 304 | +                 words.append(decoded)
| 305 | +                 word_tokens.append(current_tokens)
| 306 | +                 current_tokens = []
| 307 | +                 unicode_offset += len(decoded)
| 308 | +
| 309 | +         return words, word_tokens
| 310 | +
| 311 | +     def split_tokens_on_spaces(self, tokens: List[int]):
| 312 | +         subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
| 313 | +         words = []
| 314 | +         word_tokens = []
| 315 | +
| 316 | +         for subword, subword_tokens in zip(subwords, subword_tokens_list):
| 317 | +             special = subword_tokens[0] >= self.eot
| 318 | +             with_space = subword.startswith(" ")
| 319 | +             punctuation = subword.strip() in string.punctuation
| 320 | +             if special or with_space or punctuation or len(words) == 0:
| 321 | +                 words.append(subword)
| 322 | +                 word_tokens.append(subword_tokens)
| 323 | +             else:
| 324 | +                 words[-1] = words[-1] + subword
| 325 | +                 word_tokens[-1].extend(subword_tokens)
| 326 | +
| 327 | +         return words, word_tokens
| 328 | +
| 329 | +
| 330 | + @lru_cache(maxsize=None)
| 331 | + def get_encoding(name: str = "gpt2", num_languages: int = 99):
| 332 | +     vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
| 333 | +     ranks = {
| 334 | +         base64.b64decode(token): int(rank)
| 335 | +         for token, rank in (line.split() for line in open(vocab_path) if line)
| 336 | +     }
| 337 | +     n_vocab = len(ranks)
| 338 | +     special_tokens = {}
| 339 | +
| 340 | +     specials = [
| 341 | +         "<|endoftext|>",
| 342 | +         "<|startoftranscript|>",
| 343 | +         *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
| 344 | +         "<|translate|>",
| 345 | +         "<|transcribe|>",
| 346 | +         "<|startoflm|>",
| 347 | +         "<|startofprev|>",
| 348 | +         "<|nospeech|>",
| 349 | +         "<|notimestamps|>",
| 350 | +         *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
| 351 | +     ]
| 352 | +
| 353 | +     for token in specials:
| 354 | +         special_tokens[token] = n_vocab
| 355 | +         n_vocab += 1
| 356 | +
| 357 | +     return tiktoken.Encoding(
| 358 | +         name=os.path.basename(vocab_path),
| 359 | +         explicit_n_vocab=n_vocab,
| 360 | +         pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
| 361 | +         mergeable_ranks=ranks,
| 362 | +         special_tokens=special_tokens,
| 363 | +     )
| 364 | +
| 365 | +
| 366 | + @lru_cache(maxsize=None)
| 367 | + def get_tokenizer(
| 368 | +     multilingual: bool,
| 369 | +     *,
| 370 | +     num_languages: int = 99,
| 371 | +     language: Optional[str] = None,
| 372 | +     task: Optional[str] = None,  # Literal["transcribe", "translate", None]
| 373 | + ) -> Tokenizer:
| 374 | +     if language is not None:
| 375 | +         language = language.lower()
| 376 | +         if language not in LANGUAGES:
| 377 | +             if language in TO_LANGUAGE_CODE:
| 378 | +                 language = TO_LANGUAGE_CODE[language]
| 379 | +             else:
| 380 | +                 raise ValueError(f"Unsupported language: {language}")
| 381 | +
| 382 | +     if multilingual:
| 383 | +         encoding_name = "multilingual"
| 384 | +         language = language or "en"
| 385 | +         task = task or "transcribe"
| 386 | +     else:
| 387 | +         encoding_name = "gpt2"
| 388 | +         language = None
| 389 | +         task = None
| 390 | +
| 391 | +     encoding = get_encoding(name=encoding_name, num_languages=num_languages)
| 392 | +
| 393 | +     return Tokenizer(
| 394 | +         encoding=encoding, num_languages=num_languages, language=language, task=task
| 395 | +     )
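A quick usage sketch for the tokenizer above (not part of this commit; it assumes `tiktoken` is installed and the bundled `assets/*.tiktoken` files are present next to tokenizer.py):

from simul_whisper.whisper.tokenizer import get_tokenizer

tok = get_tokenizer(multilingual=True, language="spanish", task="transcribe")
# "spanish" is normalized to "es" via TO_LANGUAGE_CODE, so sot_sequence holds
# the ids of <|startoftranscript|>, <|es|>, <|transcribe|>
print(tok.sot_sequence)
ids = tok.encode(" hola mundo")
print(tok.decode(ids))                                    # " hola mundo"
print(tok.decode_with_timestamps([tok.timestamp_begin]))  # "<|0.00|>"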
simul_whisper/whisper/trans_nopad.py
ADDED
@@ -0,0 +1,501 @@
| 1 | + import argparse
| 2 | + import os
| 3 | + import warnings
| 4 | + from typing import TYPE_CHECKING, Optional, Tuple, Union
| 5 | +
| 6 | + import numpy as np
| 7 | + import torch
| 8 | + import tqdm
| 9 | +
| 10 | + from whisper.audio import (
| 11 | +     FRAMES_PER_SECOND,
| 12 | +     HOP_LENGTH,
| 13 | +     N_FRAMES,
| 14 | +     N_SAMPLES,
| 15 | +     SAMPLE_RATE,
| 16 | +     log_mel_spectrogram,
| 17 | +     pad_or_trim,
| 18 | + )
| 19 | + from whisper.decoding import DecodingOptions, DecodingResult
| 20 | + from whisper.timing import add_word_timestamps
| 21 | + from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
| 22 | + from whisper.utils import (
| 23 | +     exact_div,
| 24 | +     format_timestamp,
| 25 | +     get_writer,
| 26 | +     make_safe,
| 27 | +     optional_float,
| 28 | +     optional_int,
| 29 | +     str2bool,
| 30 | + )
| 31 | +
| 32 | + if TYPE_CHECKING:
| 33 | +     from whisper.model import Whisper
| 34 | +
| 35 | +
| 36 | + def transcribe(
| 37 | +     model: "Whisper",
| 38 | +     audio: Union[str, np.ndarray, torch.Tensor],
| 39 | +     *,
| 40 | +     verbose: Optional[bool] = None,
| 41 | +     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
| 42 | +     compression_ratio_threshold: Optional[float] = 2.4,
| 43 | +     logprob_threshold: Optional[float] = -1.0,
| 44 | +     no_speech_threshold: Optional[float] = 0.6,
| 45 | +     condition_on_previous_text: bool = True,
| 46 | +     initial_prompt: Optional[str] = None,
| 47 | +     word_timestamps: bool = False,
| 48 | +     prepend_punctuations: str = "\"'“¿([{-",
| 49 | +     append_punctuations: str = "\"'.。,,!!??::”)]}、",
| 50 | +     **decode_options,
| 51 | + ):
| 52 | +     """
| 53 | +     Transcribe an audio file using Whisper
| 54 | +
| 55 | +     Parameters
| 56 | +     ----------
| 57 | +     model: Whisper
| 58 | +         The Whisper model instance
| 59 | +
| 60 | +     audio: Union[str, np.ndarray, torch.Tensor]
| 61 | +         The path to the audio file to open, or the audio waveform
| 62 | +
| 63 | +     verbose: bool
| 64 | +         Whether to display the text being decoded to the console. If True, displays all the details,
| 65 | +         If False, displays minimal details. If None, does not display anything
| 66 | +
| 67 | +     temperature: Union[float, Tuple[float, ...]]
| 68 | +         Temperature for sampling. It can be a tuple of temperatures, which will be successively used
| 69 | +         upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
| 70 | +
| 71 | +     compression_ratio_threshold: float
| 72 | +         If the gzip compression ratio is above this value, treat as failed
| 73 | +
| 74 | +     logprob_threshold: float
| 75 | +         If the average log probability over sampled tokens is below this value, treat as failed
| 76 | +
| 77 | +     no_speech_threshold: float
| 78 | +         If the no_speech probability is higher than this value AND the average log probability
| 79 | +         over sampled tokens is below `logprob_threshold`, consider the segment as silent
| 80 | +
| 81 | +     condition_on_previous_text: bool
| 82 | +         if True, the previous output of the model is provided as a prompt for the next window;
| 83 | +         disabling may make the text inconsistent across windows, but the model becomes less prone to
| 84 | +         getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
| 85 | +
| 86 | +     word_timestamps: bool
| 87 | +         Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
| 88 | +         and include the timestamps for each word in each segment.
| 89 | +
| 90 | +     prepend_punctuations: str
| 91 | +         If word_timestamps is True, merge these punctuation symbols with the next word
| 92 | +
| 93 | +     append_punctuations: str
| 94 | +         If word_timestamps is True, merge these punctuation symbols with the previous word
| 95 | +
| 96 | +     initial_prompt: Optional[str]
| 97 | +         Optional text to provide as a prompt for the first window. This can be used to provide, or
| 98 | +         "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
| 99 | +         to make it more likely to predict those words correctly.
| 100 | +
| 101 | +     decode_options: dict
| 102 | +         Keyword arguments to construct `DecodingOptions` instances
| 103 | +
| 104 | +     Returns
| 105 | +     -------
| 106 | +     A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
| 107 | +     the spoken language ("language"), which is detected when `decode_options["language"]` is None.
| 108 | +     """
| 109 | +     # print("HACKED")
| 110 | +     dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
| 111 | +     if model.device == torch.device("cpu"):
| 112 | +         if torch.cuda.is_available():
| 113 | +             warnings.warn("Performing inference on CPU when CUDA is available")
| 114 | +         if dtype == torch.float16:
| 115 | +             warnings.warn("FP16 is not supported on CPU; using FP32 instead")
| 116 | +             dtype = torch.float32
| 117 | +
| 118 | +     if dtype == torch.float32:
| 119 | +         decode_options["fp16"] = False
| 120 | +
| 121 | +     # Pad 30-seconds of silence to the input audio, for slicing
| 122 | +     mel = log_mel_spectrogram(audio, padding=0)  # log_mel_spectrogram(audio, padding=N_SAMPLES)  # adds 16000*30 = 480000 samples
| 123 | +     # mel = pad_or_trim(mel, 3000)
| 124 | +     content_frames = mel.shape[-1]  # - N_FRAMES  # corresponds to 3000 frames; the actual content is the data with the trailing 3000 frames removed
| 125 | +
| 126 | +     # detect the language
| 127 | +     if decode_options.get("language", None) is None:
| 128 | +         # if this is a monolingual model, set it to English directly
| 129 | +         if not model.is_multilingual:
| 130 | +             decode_options["language"] = "en"
| 131 | +         # otherwise one forward pass is needed
| 132 | +         else:
| 133 | +             if verbose:
| 134 | +                 print(
| 135 | +                     "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
| 136 | +                 )
| 137 | +             mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
| 138 | +             # print(mel_segment.shape)
| 139 | +             _, probs = model.detect_language(mel_segment)
| 140 | +             decode_options["language"] = max(probs, key=probs.get)
| 141 | +             if verbose is not None:
| 142 | +                 print(
| 143 | +                     f"Detected language: {LANGUAGES[decode_options['language']].title()}"
| 144 | +                 )
| 145 | +
| 146 | +     language: str = decode_options["language"]
| 147 | +     task: str = decode_options.get("task", "transcribe")
| 148 | +     # output tokenizer
| 149 | +     tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
| 150 | +
| 151 | +     # word-level timestamps
| 152 | +     if word_timestamps and task == "translate":
| 153 | +         warnings.warn("Word-level timestamps on translations may not be reliable.")
| 154 | +
| 155 | +     def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
| 156 | +         temperatures = (
| 157 | +             [temperature] if isinstance(temperature, (int, float)) else temperature
| 158 | +         )
| 159 | +         decode_result = None
| 160 | +
| 161 | +         for t in temperatures:
| 162 | +             kwargs = {**decode_options}
| 163 | +             if t > 0:
| 164 | +                 # disable beam_size and patience when t > 0
| 165 | +                 kwargs.pop("beam_size", None)
| 166 | +                 kwargs.pop("patience", None)
| 167 | +             else:
| 168 | +                 # disable best_of when t == 0
| 169 | +                 kwargs.pop("best_of", None)
| 170 | +
| 171 | +             options = DecodingOptions(**kwargs, temperature=t)
| 172 | +             decode_result = model.decode(segment, options)
| 173 | +
| 174 | +             # cases where decoding may have failed; in these cases decoding is retried
| 175 | +             # this feels like know-how; ChatGPT probably relies on plenty of tricks like this
| 176 | +             needs_fallback = False
| 177 | +             if (
| 178 | +                 compression_ratio_threshold is not None
| 179 | +                 and decode_result.compression_ratio > compression_ratio_threshold
| 180 | +             ):
| 181 | +                 needs_fallback = True  # too repetitive
| 182 | +             if (
| 183 | +                 logprob_threshold is not None
| 184 | +                 and decode_result.avg_logprob < logprob_threshold
| 185 | +             ):
| 186 | +                 needs_fallback = True  # average log probability is too low
| 187 | +             if (
| 188 | +                 no_speech_threshold is not None
| 189 | +                 and decode_result.no_speech_prob > no_speech_threshold
| 190 | +             ):
| 191 | +                 needs_fallback = False  # silence
| 192 | +             if not needs_fallback:
| 193 | +                 break
| 194 | +             # print("decode with temperature {} compress rate {:.3f}/{:.3f}, log_prob {:.3f}/{:.3f}, {:.3f}/{:.3f}".format(
| 195 | +             #     t,
| 196 | +             #     decode_result.compression_ratio, compression_ratio_threshold,
| 197 | +             #     -decode_result.avg_logprob, -logprob_threshold,
| 198 | +             #     decode_result.no_speech_prob, no_speech_threshold
| 199 | +             # ))
| 200 | +
| 201 | +         return decode_result
| 202 | +
| 203 | +     seek = 0
| 204 | +     input_stride = exact_div(
| 205 | +         N_FRAMES, model.dims.n_audio_ctx
| 206 | +     )  # mel frames per output token: 2
| 207 | +     # here "output token" should refer to the output of the CNN (encoder)
| 208 | +
| 209 | +     time_precision = (
| 210 | +         input_stride * HOP_LENGTH / SAMPLE_RATE
| 211 | +     )  # time per output token: 0.02 (seconds)
| 212 | +     all_tokens = []
| 213 | +     all_segments = []
| 214 | +     prompt_reset_since = 0
| 215 | +
| 216 | +     if initial_prompt is not None:
| 217 | +         initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
| 218 | +         all_tokens.extend(initial_prompt_tokens)
| 219 | +     else:
| 220 | +         initial_prompt_tokens = []
| 221 | +
| 222 | +     def new_segment(
| 223 | +         *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
| 224 | +     ):
| 225 | +         tokens = tokens.tolist()
| 226 | +         text_tokens = [token for token in tokens if token < tokenizer.eot]
| 227 | +         return {
| 228 | +             "seek": seek,
| 229 | +             "start": start,
| 230 | +             "end": end,
| 231 | +             "text": tokenizer.decode(text_tokens),
| 232 | +             "tokens": tokens,
| 233 | +             "temperature": result.temperature,
| 234 | +             "avg_logprob": result.avg_logprob,
| 235 | +             "compression_ratio": result.compression_ratio,
| 236 | +             "no_speech_prob": result.no_speech_prob,
| 237 | +         }
| 238 | +
| 239 | +     # show the progress bar when verbose is False (if True, transcribed text will be printed)
| 240 | +     with tqdm.tqdm(
| 241 | +         total=content_frames, unit="frames", disable=verbose is not False
| 242 | +     ) as pbar:
| 243 | +         last_speech_timestamp = 0.0
| 244 | +         while seek < content_frames:  # seek: marks the current frame position in the mel spectrogram; skips straight past the padded part
| 245 | +             # print("seek segments", seek, content_frames)
| 246 | +             time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)  # start time of this segment
| 247 | +             # mel_segment = mel[:, seek : seek + N_FRAMES]  # get the data of the current segment
| 248 | +             mel_segment = mel[:, seek:]
| 249 | +             segment_size = min(N_FRAMES, content_frames - seek)  # segment_size: the real length excluding padding. content_frames: the true length of the content; truncated if shorter than N_FRAMES
| 250 | +             segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE  # duration of the current segment
| 251 | +             mel_segment = mel_segment.to(model.device).to(dtype)  # pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)  # pad up to mel_segment frames
| 252 | +
| 253 | +             decode_options["prompt"] = all_tokens[prompt_reset_since:]
| 254 | +             result: DecodingResult = decode_with_fallback(mel_segment)
| 255 | +             tokens = torch.tensor(result.tokens)
| 256 | +
| 257 | +             # skip silent parts
| 258 | +             if no_speech_threshold is not None:
| 259 | +                 # no voice activity check
| 260 | +                 should_skip = result.no_speech_prob > no_speech_threshold
| 261 | +                 if (
| 262 | +                     logprob_threshold is not None
| 263 | +                     and result.avg_logprob > logprob_threshold
| 264 | +                 ):
| 265 | +                     # don't skip if the logprob is high enough, despite the no_speech_prob
| 266 | +                     should_skip = False
| 267 | +
| 268 | +                 if should_skip:
| 269 | +                     seek += segment_size  # fast-forward to the next segment boundary
| 270 | +                     continue
| 271 | +
| 272 | +             previous_seek = seek
| 273 | +             current_segments = []
| 274 | +
| 275 | +             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)  # timestamp_begin is the <|0.00|> token; BOS is larger than text tokens and EOS is even larger than BOS, hence ge
| 276 | +             timestamp_tokens[-1] = False
| 277 | +             single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]  # if the tail is [False, True], a sentence ended within this segment
| 278 | +
| 279 | +             consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
| 280 | +             # torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True).
| 281 | +             # timestamp_tokens is just a 1-D vector, so why not use nonzero directly?
| 282 | +             # if there are two consecutive timestamps, this is a 1-D tensor holding the end position of the consecutive pair
| 283 | +             # with several pairs it points at the second one; what if there are three?
| 284 | +             # otherwise it is a 0-dim tensor
| 285 | +
| 286 | +             consecutive.add_(1)  # a 0-dim tensor plus 1 is still 0-dim; where did they find these edge cases, is this JavaScript?
| 287 | +             if len(consecutive) > 0:
| 288 | +                 # if the output contains two consecutive timestamp tokens
| 289 | +                 slices = consecutive.tolist()
| 290 | +                 if single_timestamp_ending:
| 291 | +                     slices.append(len(tokens))  # also include the end of the last span
| 292 | +                 # print("many sentences", consecutive)
| 293 | +                 last_slice = 0
| 294 | +                 for current_slice in slices:
| 295 | +                     sliced_tokens = tokens[last_slice:current_slice]
| 296 | +                     # it looks like the speech start/end frame positions get encoded into start_timestamp
| 297 | +                     start_timestamp_pos = (
| 298 | +                         sliced_tokens[0].item() - tokenizer.timestamp_begin
| 299 | +                     )
| 300 | +                     end_timestamp_pos = (
| 301 | +                         sliced_tokens[-1].item() - tokenizer.timestamp_begin
| 302 | +                     )
| 303 | +                     # obtain a new speech segment
| 304 | +                     current_segments.append(
| 305 | +                         new_segment(
| 306 | +                             start=time_offset + start_timestamp_pos * time_precision,
| 307 | +                             end=time_offset + end_timestamp_pos * time_precision,
| 308 | +                             tokens=sliced_tokens,
| 309 | +                             result=result,
| 310 | +                         )
| 311 | +                     )
| 312 | +                     last_slice = current_slice
| 313 | +
| 314 | +                 if single_timestamp_ending:
| 315 | +                     # single timestamp at the end means no speech after the last timestamp.
| 316 | +                     seek += segment_size
| 317 | +                 else:
| 318 | +                     # otherwise, ignore the unfinished segment and seek to the last timestamp
| 319 | +                     # if speech has not ended yet, seek moves to the position of the last finished segment
| 320 | +                     # in other words, this is designed for speech arriving in 30s-long chunks
| 321 | +                     last_timestamp_pos = (
| 322 | +                         tokens[last_slice - 1].item() - tokenizer.timestamp_begin
| 323 | +                     )
| 324 | +                     seek += last_timestamp_pos * input_stride
| 325 | +             else:
| 326 | +                 duration = segment_duration
| 327 | +                 timestamps = tokens[timestamp_tokens.nonzero().flatten()]
| 328 | +                 # print(timestamps)
| 329 | +                 if (
| 330 | +                     len(timestamps) > 0
| 331 | +                     and timestamps[-1].item() != tokenizer.timestamp_begin
| 332 | +                 ):
| 333 | +                     # no consecutive timestamps but it has a timestamp; use the last one.
| 334 | +                     # take the last one; assumes there is either a single closing timestamp or a pair?
| 335 | +                     # if there is only an opening timestamp, everything after it seems to get dropped?
| 336 | +                     last_timestamp_pos = (
| 337 | +                         timestamps[-1].item() - tokenizer.timestamp_begin
| 338 | +                     )
| 339 | +                     duration = last_timestamp_pos * time_precision
| 340 | +
| 341 | +                 current_segments.append(
| 342 | +                     new_segment(
| 343 | +                         start=time_offset,
| 344 | +                         end=time_offset + duration,
| 345 | +                         tokens=tokens,
| 346 | +                         result=result,
| 347 | +                     )
| 348 | +                 )
| 349 | +                 seek += segment_size
| 350 | +
| 351 | +             # each token has its own timestamp
| 352 | +             if word_timestamps:
| 353 | +                 add_word_timestamps(
| 354 | +                     segments=current_segments,
| 355 | +                     model=model,
| 356 | +                     tokenizer=tokenizer,
| 357 | +                     mel=mel_segment,
| 358 | +                     num_frames=segment_size,
| 359 | +                     prepend_punctuations=prepend_punctuations,
| 360 | +                     append_punctuations=append_punctuations,
| 361 | +                     last_speech_timestamp=last_speech_timestamp,
| 362 | +                 )
| 363 | +                 word_end_timestamps = [
| 364 | +                     w["end"] for s in current_segments for w in s["words"]
| 365 | +                 ]
| 366 | +                 if len(word_end_timestamps) > 0:
| 367 | +                     last_speech_timestamp = word_end_timestamps[-1]
| 368 | +                 if not single_timestamp_ending and len(word_end_timestamps) > 0:
| 369 | +                     seek_shift = round(
| 370 | +                         (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
| 371 | +                     )
| 372 | +                     if seek_shift > 0:
| 373 | +                         seek = previous_seek + seek_shift
| 374 | +
| 375 | +             if verbose:
| 376 | +                 for segment in current_segments:
| 377 | +                     start, end, text = segment["start"], segment["end"], segment["text"]
| 378 | +                     line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
| 379 | +                     print(make_safe(line))
| 380 | +
| 381 | +             # if a segment is instantaneous or does not contain text, clear it
| 382 | +             for i, segment in enumerate(current_segments):
| 383 | +                 if segment["start"] == segment["end"] or segment["text"].strip() == "":
| 384 | +                     segment["text"] = ""
| 385 | +                     segment["tokens"] = []
| 386 | +                     segment["words"] = []
| 387 | +
| 388 | +             # update the results
| 389 | +             all_segments.extend(
| 390 | +                 [
| 391 | +                     {"id": i, **segment}
| 392 | +                     for i, segment in enumerate(
| 393 | +                         current_segments, start=len(all_segments)
| 394 | +                     )
| 395 | +                 ]
| 396 | +             )
| 397 | +             all_tokens.extend(
| 398 | +                 [token for segment in current_segments for token in segment["tokens"]]
| 399 | +             )
| 400 | +
| 401 | +             if not condition_on_previous_text or result.temperature > 0.5:
| 402 | +                 # do not feed the prompt tokens if a high temperature was used
| 403 | +                 prompt_reset_since = len(all_tokens)
| 404 | +
| 405 | +             # update progress bar
| 406 | +             pbar.update(min(content_frames, seek) - previous_seek)
| 407 | +
| 408 | +             # print("too long")
| 409 | +             # break
| 410 | +
| 411 | +     return dict(
| 412 | +         text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
| 413 | +         segments=all_segments,
| 414 | +         language=language,
| 415 | +     )
| 416 | +
| 417 | +
| 418 | + def cli():
| 419 | +     from . import available_models
| 420 | +
| 421 | +     # fmt: off
| 422 | +     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
| 423 | +     parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
| 424 | +     parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
| 425 | +     parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
| 426 | +     parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
| 427 | +     parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
| 428 | +     parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
| 429 | +     parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
| 430 | +
| 431 | +     parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
| 432 | +     parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
| 433 | +
| 434 | +     parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
| 435 | +     parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
| 436 | +     parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
| 437 | +     parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
| 438 | +     parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
| 439 | +
| 440 | +     parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
| 441 | +     parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
| 442 | +     parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
| 443 | +     parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
| 444 | +
| 445 | +     parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
| 446 | +     parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
| 447 | +     parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
| 448 | +     parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
| 449 | +     parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
| 450 | +     parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
| 451 | +     parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
| 452 | +     parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
| 453 | +     parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
| 454 | +     parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
| 455 | +     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supersedes MKL_NUM_THREADS/OMP_NUM_THREADS")
| 456 | +     # fmt: on
| 457 | +
| 458 | +     args = parser.parse_args().__dict__
| 459 | +     model_name: str = args.pop("model")
| 460 | +     model_dir: str = args.pop("model_dir")
| 461 | +     output_dir: str = args.pop("output_dir")
| 462 | +     output_format: str = args.pop("output_format")
| 463 | +     device: str = args.pop("device")
| 464 | +     os.makedirs(output_dir, exist_ok=True)
| 465 | +
| 466 | +     if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
| 467 | +         if args["language"] is not None:
| 468 | +             warnings.warn(
| 469 | +                 f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
| 470 | +             )
| 471 | +         args["language"] = "en"
| 472 | +
| 473 | +     temperature = args.pop("temperature")
| 474 | +     if (increment := args.pop("temperature_increment_on_fallback")) is not None:
| 475 | +         temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
| 476 | +     else:
| 477 | +         temperature = [temperature]
| 478 | +
| 479 | +     if (threads := args.pop("threads")) > 0:
| 480 | +         torch.set_num_threads(threads)
| 481 | +
| 482 | +     from . import load_model
| 483 | +
| 484 | +     model = load_model(model_name, device=device, download_root=model_dir)
| 485 | +
| 486 | +     writer = get_writer(output_format, output_dir)
| 487 | +     word_options = ["highlight_words", "max_line_count", "max_line_width"]
| 488 | +     if not args["word_timestamps"]:
| 489 | +         for option in word_options:
| 490 | +             if args[option]:
| 491 | +                 parser.error(f"--{option} requires --word_timestamps True")
| 492 | +     if args["max_line_count"] and not args["max_line_width"]:
| 493 | +         warnings.warn("--max_line_count has no effect without --max_line_width")
| 494 | +     writer_args = {arg: args.pop(arg) for arg in word_options}
| 495 | +     for audio_path in args.pop("audio"):
| 496 | +         result = transcribe(model, audio_path, temperature=temperature, **args)
| 497 | +         writer(result, audio_path, writer_args)
| 498 | +
| 499 | +
| 500 | + if __name__ == "__main__":
| 501 | +     cli()
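A usage sketch for the no-padding variant above (not part of this commit; "sample.wav" is a placeholder, and the module's absolute `whisper.*` imports mean the upstream `whisper` package must be importable):

import whisper  # upstream package; provides load_model and the helpers this module imports
from simul_whisper.whisper.trans_nopad import transcribe

model = whisper.load_model("small")
result = transcribe(model, "sample.wav", verbose=True)
print(result["text"])
for seg in result["segments"]:
    print(f'[{seg["start"]:.2f} --> {seg["end"]:.2f}] {seg["text"]}')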
simul_whisper/whisper/transcribe.py
ADDED
@@ -0,0 +1,467 @@
| 1 | + import argparse
| 2 | + import os
| 3 | + import warnings
| 4 | + from typing import TYPE_CHECKING, Optional, Tuple, Union
| 5 | +
| 6 | + import numpy as np
| 7 | + import torch
| 8 | + import tqdm
| 9 | +
| 10 | + from .audio import (
| 11 | +     FRAMES_PER_SECOND,
| 12 | +     HOP_LENGTH,
| 13 | +     N_FRAMES,
| 14 | +     N_SAMPLES,
| 15 | +     SAMPLE_RATE,
| 16 | +     log_mel_spectrogram,
| 17 | +     pad_or_trim,
| 18 | + )
| 19 | + from .decoding import DecodingOptions, DecodingResult
| 20 | + from .timing import add_word_timestamps
| 21 | + from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
| 22 | + from .utils import (
| 23 | +     exact_div,
| 24 | +     format_timestamp,
| 25 | +     get_writer,
| 26 | +     make_safe,
| 27 | +     optional_float,
| 28 | +     optional_int,
| 29 | +     str2bool,
| 30 | + )
| 31 | +
| 32 | + if TYPE_CHECKING:
| 33 | +     from .model import Whisper
| 34 | +
| 35 | +
| 36 | + def transcribe(
| 37 | +     model: "Whisper",
| 38 | +     audio: Union[str, np.ndarray, torch.Tensor],
| 39 | +     *,
| 40 | +     verbose: Optional[bool] = None,
| 41 | +     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
| 42 | +     compression_ratio_threshold: Optional[float] = 2.4,
| 43 | +     logprob_threshold: Optional[float] = -1.0,
| 44 | +     no_speech_threshold: Optional[float] = 0.6,
| 45 | +     condition_on_previous_text: bool = True,
| 46 | +     initial_prompt: Optional[str] = None,
| 47 | +     word_timestamps: bool = False,
| 48 | +     prepend_punctuations: str = "\"'“¿([{-",
| 49 | +     append_punctuations: str = "\"'.。,,!!??::”)]}、",
| 50 | +     **decode_options,
| 51 | + ):
| 52 | +     """
| 53 | +     Transcribe an audio file using Whisper
| 54 | +
| 55 | +     Parameters
| 56 | +     ----------
| 57 | +     model: Whisper
| 58 | +         The Whisper model instance
| 59 | +
| 60 | +     audio: Union[str, np.ndarray, torch.Tensor]
| 61 | +         The path to the audio file to open, or the audio waveform
| 62 | +
| 63 | +     verbose: bool
| 64 | +         Whether to display the text being decoded to the console. If True, displays all the details,
| 65 | +         If False, displays minimal details. If None, does not display anything
| 66 | +
| 67 | +     temperature: Union[float, Tuple[float, ...]]
| 68 | +         Temperature for sampling. It can be a tuple of temperatures, which will be successively used
| 69 | +         upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
| 70 | +
| 71 | +     compression_ratio_threshold: float
| 72 | +         If the gzip compression ratio is above this value, treat as failed
| 73 | +
| 74 | +     logprob_threshold: float
| 75 | +         If the average log probability over sampled tokens is below this value, treat as failed
| 76 | +
| 77 | +     no_speech_threshold: float
| 78 | +         If the no_speech probability is higher than this value AND the average log probability
| 79 | +         over sampled tokens is below `logprob_threshold`, consider the segment as silent
| 80 | +
| 81 | +     condition_on_previous_text: bool
| 82 | +         if True, the previous output of the model is provided as a prompt for the next window;
| 83 | +         disabling may make the text inconsistent across windows, but the model becomes less prone to
| 84 | +         getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
| 85 | +
| 86 | +     word_timestamps: bool
| 87 | +         Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
| 88 | +         and include the timestamps for each word in each segment.
| 89 | +
| 90 | +     prepend_punctuations: str
| 91 | +         If word_timestamps is True, merge these punctuation symbols with the next word
| 92 | +
| 93 | +     append_punctuations: str
| 94 | +         If word_timestamps is True, merge these punctuation symbols with the previous word
| 95 | +
| 96 | +     initial_prompt: Optional[str]
| 97 | +         Optional text to provide as a prompt for the first window. This can be used to provide, or
| 98 | +         "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
| 99 | +         to make it more likely to predict those words correctly.
| 100 | +
| 101 | +     decode_options: dict
| 102 | +         Keyword arguments to construct `DecodingOptions` instances
| 103 | +
| 104 | +     Returns
| 105 | +     -------
| 106 | +     A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
| 107 | +     the spoken language ("language"), which is detected when `decode_options["language"]` is None.
| 108 | +     """
| 109 | +     # print("transcribe")
| 110 | +     dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
| 111 | +     if model.device == torch.device("cpu"):
| 112 | +         if torch.cuda.is_available():
| 113 | +             warnings.warn("Performing inference on CPU when CUDA is available")
| 114 | +         if dtype == torch.float16:
| 115 | +             warnings.warn("FP16 is not supported on CPU; using FP32 instead")
| 116 | +             dtype = torch.float32
| 117 | +
| 118 | +     if dtype == torch.float32:
| 119 | +         decode_options["fp16"] = False
| 120 | +
| 121 | +     # Pad 30-seconds of silence to the input audio, for slicing
| 122 | +     mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
| 123 | +     content_frames = mel.shape[-1] - N_FRAMES
| 124 | +
| 125 | +     if decode_options.get("language", None) is None:
| 126 | +         if not model.is_multilingual:
| 127 | +             decode_options["language"] = "en"
| 128 | +         else:
| 129 | +             if verbose:
| 130 | +                 print(
| 131 | +                     "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
| 132 | +                 )
| 133 | +             mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
| 134 | +             # print(mel_segment.shape)
| 135 | +             _, probs = model.detect_language(mel_segment)
| 136 | +             decode_options["language"] = max(probs, key=probs.get)
| 137 | +             if verbose is not None:
| 138 | +                 print(
| 139 | +                     f"Detected language: {LANGUAGES[decode_options['language']].title()}"
| 140 | +                 )
| 141 | +
| 142 | +     language: str = decode_options["language"]
| 143 | +     task: str = decode_options.get("task", "transcribe")
| 144 | +     tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
| 145 | +
| 146 | +     if word_timestamps and task == "translate":
| 147 | +         warnings.warn("Word-level timestamps on translations may not be reliable.")
| 148 | +
| 149 | +     def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
| 150 | +         temperatures = (
| 151 | +             [temperature] if isinstance(temperature, (int, float)) else temperature
| 152 | +         )
| 153 | +         decode_result = None
| 154 | +
| 155 | +         for t in temperatures:
| 156 | +             kwargs = {**decode_options}
| 157 | +             if t > 0:
| 158 | +                 # disable beam_size and patience when t > 0
| 159 | +                 kwargs.pop("beam_size", None)
| 160 | +                 kwargs.pop("patience", None)
| 161 | +             else:
| 162 | +                 # disable best_of when t == 0
| 163 | +                 kwargs.pop("best_of", None)
| 164 | +
| 165 | +             options = DecodingOptions(**kwargs, temperature=t)
| 166 | +             decode_result = model.decode(segment, options)
| 167 | +
| 168 | +             needs_fallback = False
| 169 | +             if (
| 170 | +                 compression_ratio_threshold is not None
| 171 | +                 and decode_result.compression_ratio > compression_ratio_threshold
| 172 | +             ):
| 173 | +                 needs_fallback = True  # too repetitive
| 174 | +             if (
| 175 | +                 logprob_threshold is not None
| 176 | +                 and decode_result.avg_logprob < logprob_threshold
| 177 | +             ):
| 178 | +                 needs_fallback = True  # average log probability is too low
| 179 | +             if (
| 180 | +                 no_speech_threshold is not None
| 181 | +                 and decode_result.no_speech_prob > no_speech_threshold
| 182 | +             ):
| 183 | +                 needs_fallback = False  # silence
| 184 | +             if not needs_fallback:
| 185 | +                 break
| 186 | +
| 187 | +         return decode_result
| 188 | +
| 189 | +     seek = 0
| 190 | +     input_stride = exact_div(
| 191 | +         N_FRAMES, model.dims.n_audio_ctx
| 192 | +     )  # mel frames per output token: 2
| 193 | +     time_precision = (
| 194 | +         input_stride * HOP_LENGTH / SAMPLE_RATE
| 195 | +     )  # time per output token: 0.02 (seconds)
| 196 | +     all_tokens = []
| 197 | +     all_segments = []
| 198 | +     prompt_reset_since = 0
| 199 | +
| 200 | +     if initial_prompt is not None:
| 201 | +         initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
| 202 | +         all_tokens.extend(initial_prompt_tokens)
| 203 | +     else:
| 204 | +         initial_prompt_tokens = []
| 205 | +
| 206 | +     def new_segment(
| 207 | +         *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
| 208 | +     ):
| 209 | +         tokens = tokens.tolist()
| 210 | +         text_tokens = [token for token in tokens if token < tokenizer.eot]
| 211 | +         return {
| 212 | +             "seek": seek,
| 213 | +             "start": start,
| 214 | +             "end": end,
| 215 | +             "text": tokenizer.decode(text_tokens),
| 216 | +             "tokens": tokens,
| 217 | +             "temperature": result.temperature,
| 218 | +             "avg_logprob": result.avg_logprob,
| 219 | +             "compression_ratio": result.compression_ratio,
| 220 | +             "no_speech_prob": result.no_speech_prob,
| 221 | +         }
| 222 | +
| 223 | +     # show the progress bar when verbose is False (if True, transcribed text will be printed)
| 224 | +     with tqdm.tqdm(
| 225 | +         total=content_frames, unit="frames", disable=verbose is not False
| 226 | +     ) as pbar:
| 227 | +         last_speech_timestamp = 0.0
| 228 | +         while seek < content_frames:
| 229 | +             time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
| 230 | +             mel_segment = mel[:, seek : seek + N_FRAMES]
| 231 | +             segment_size = min(N_FRAMES, content_frames - seek)
| 232 | +             segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
| 233 | +             mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
| 234 | +
| 235 | +             # print("melshape", mel_segment.shape)
| 236 | +
| 237 | +             decode_options["prompt"] = all_tokens[prompt_reset_since:]
| 238 | +             result: DecodingResult = decode_with_fallback(mel_segment)
| 239 | +             tokens = torch.tensor(result.tokens)
| 240 | +
| 241 | +             if no_speech_threshold is not None:
| 242 | +                 # no voice activity check
| 243 | +                 should_skip = result.no_speech_prob > no_speech_threshold
| 244 | +                 if (
| 245 | +                     logprob_threshold is not None
| 246 | +                     and result.avg_logprob > logprob_threshold
| 247 | +                 ):
| 248 | +                     # don't skip if the logprob is high enough, despite the no_speech_prob
| 249 | +                     should_skip = False
| 250 | +
| 251 | +                 if should_skip:
| 252 | +                     seek += segment_size  # fast-forward to the next segment boundary
| 253 | +                     continue
| 254 | +
| 255 | +             previous_seek = seek
| 256 | +             current_segments = []
| 257 | +
| 258 | +             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
| 259 | +             single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
| 260 | +
| 261 | +             consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
| 262 | +             consecutive.add_(1)
| 263 | +             if len(consecutive) > 0:
| 264 | +                 # if the output contains two consecutive timestamp tokens
| 265 | +                 slices = consecutive.tolist()
| 266 | +                 if single_timestamp_ending:
| 267 | +                     slices.append(len(tokens))
| 268 | +
| 269 | +                 last_slice = 0
| 270 | +                 for current_slice in slices:
| 271 | +                     sliced_tokens = tokens[last_slice:current_slice]
| 272 | +                     start_timestamp_pos = (
| 273 | +                         sliced_tokens[0].item() - tokenizer.timestamp_begin
| 274 | +                     )
| 275 | +                     end_timestamp_pos = (
| 276 | +                         sliced_tokens[-1].item() - tokenizer.timestamp_begin
| 277 | +                     )
| 278 | +                     current_segments.append(
| 279 | +                         new_segment(
| 280 | +                             start=time_offset + start_timestamp_pos * time_precision,
| 281 | +                             end=time_offset + end_timestamp_pos * time_precision,
| 282 | +                             tokens=sliced_tokens,
| 283 | +                             result=result,
| 284 | +                         )
| 285 | +                     )
| 286 | +                     last_slice = current_slice
| 287 | +
| 288 | +                 if single_timestamp_ending:
| 289 | +                     # single timestamp at the end means no speech after the last timestamp.
| 290 | +                     seek += segment_size
| 291 | +                 else:
| 292 | +                     # otherwise, ignore the unfinished segment and seek to the last timestamp
| 293 | +                     last_timestamp_pos = (
| 294 | +                         tokens[last_slice - 1].item() - tokenizer.timestamp_begin
| 295 | +                     )
| 296 | +                     seek += last_timestamp_pos * input_stride
| 297 | +             else:
| 298 | +                 duration = segment_duration
| 299 | +                 timestamps = tokens[timestamp_tokens.nonzero().flatten()]
| 300 | +                 if (
| 301 | +                     len(timestamps) > 0
| 302 | +                     and timestamps[-1].item() != tokenizer.timestamp_begin
| 303 | +                 ):
| 304 | +                     # no consecutive timestamps but it has a timestamp; use the last one.
| 305 | +                     last_timestamp_pos = (
| 306 | +                         timestamps[-1].item() - tokenizer.timestamp_begin
| 307 | +                     )
| 308 | +                     duration = last_timestamp_pos * time_precision
| 309 | +
| 310 | +                 current_segments.append(
| 311 | +                     new_segment(
| 312 | +                         start=time_offset,
| 313 | +                         end=time_offset + duration,
| 314 | +                         tokens=tokens,
| 315 | +                         result=result,
| 316 | +                     )
| 317 | +                 )
| 318 | +                 seek += segment_size
| 319 | +
| 320 | +             # print("word_timestamps, ", word_timestamps)
| 321 | +             if word_timestamps:
| 322 | +                 # print("=========run timestamps here=========")
| 323 | +                 add_word_timestamps(
| 324 | +                     segments=current_segments,
| 325 | +                     model=model,
| 326 | +                     tokenizer=tokenizer,
| 327 | +                     mel=mel_segment,
| 328 | +                     num_frames=segment_size,
| 329 | +                     prepend_punctuations=prepend_punctuations,
| 330 | +                     append_punctuations=append_punctuations,
| 331 | +                     last_speech_timestamp=last_speech_timestamp,
| 332 | +                 )
| 333 | +                 word_end_timestamps = [
| 334 | +                     w["end"] for s in current_segments for w in s["words"]
| 335 | +                 ]
| 336 |
+
if len(word_end_timestamps) > 0:
|
| 337 |
+
last_speech_timestamp = word_end_timestamps[-1]
|
| 338 |
+
if not single_timestamp_ending and len(word_end_timestamps) > 0:
|
| 339 |
+
seek_shift = round(
|
| 340 |
+
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
|
| 341 |
+
)
|
| 342 |
+
if seek_shift > 0:
|
| 343 |
+
seek = previous_seek + seek_shift
|
| 344 |
+
|
| 345 |
+
if verbose:
|
| 346 |
+
for segment in current_segments:
|
| 347 |
+
start, end, text = segment["start"], segment["end"], segment["text"]
|
| 348 |
+
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
|
| 349 |
+
print(make_safe(line))
|
| 350 |
+
|
| 351 |
+
# if a segment is instantaneous or does not contain text, clear it
|
| 352 |
+
for i, segment in enumerate(current_segments):
|
| 353 |
+
if segment["start"] == segment["end"] or segment["text"].strip() == "":
|
| 354 |
+
segment["text"] = ""
|
| 355 |
+
segment["tokens"] = []
|
| 356 |
+
segment["words"] = []
|
| 357 |
+
|
| 358 |
+
all_segments.extend(
|
| 359 |
+
[
|
| 360 |
+
{"id": i, **segment}
|
| 361 |
+
for i, segment in enumerate(
|
| 362 |
+
current_segments, start=len(all_segments)
|
| 363 |
+
)
|
| 364 |
+
]
|
| 365 |
+
)
|
| 366 |
+
all_tokens.extend(
|
| 367 |
+
[token for segment in current_segments for token in segment["tokens"]]
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
if not condition_on_previous_text or result.temperature > 0.5:
|
| 371 |
+
# do not feed the prompt tokens if a high temperature was used
|
| 372 |
+
prompt_reset_since = len(all_tokens)
|
| 373 |
+
|
| 374 |
+
# update progress bar
|
| 375 |
+
pbar.update(min(content_frames, seek) - previous_seek)
|
| 376 |
+
|
| 377 |
+
return dict(
|
| 378 |
+
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
|
| 379 |
+
segments=all_segments,
|
| 380 |
+
language=language,
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def cli():
|
| 385 |
+
from . import available_models
|
| 386 |
+
|
| 387 |
+
# fmt: off
|
| 388 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
| 389 |
+
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
| 390 |
+
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
|
| 391 |
+
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
| 392 |
+
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
| 393 |
+
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
| 394 |
+
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
|
| 395 |
+
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
| 396 |
+
|
| 397 |
+
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
| 398 |
+
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
| 399 |
+
|
| 400 |
+
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
| 401 |
+
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
| 402 |
+
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
|
| 403 |
+
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
| 404 |
+
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
| 405 |
+
|
| 406 |
+
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
| 407 |
+
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
| 408 |
+
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
| 409 |
+
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
| 410 |
+
|
| 411 |
+
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
|
| 412 |
+
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
| 413 |
+
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
|
| 414 |
+
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
| 415 |
+
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
|
| 416 |
+
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
|
| 417 |
+
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
|
| 418 |
+
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
| 419 |
+
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
|
| 420 |
+
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
|
| 421 |
+
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
| 422 |
+
# fmt: on
|
| 423 |
+
|
| 424 |
+
args = parser.parse_args().__dict__
|
| 425 |
+
model_name: str = args.pop("model")
|
| 426 |
+
model_dir: str = args.pop("model_dir")
|
| 427 |
+
output_dir: str = args.pop("output_dir")
|
| 428 |
+
output_format: str = args.pop("output_format")
|
| 429 |
+
device: str = args.pop("device")
|
| 430 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 431 |
+
|
| 432 |
+
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
| 433 |
+
if args["language"] is not None:
|
| 434 |
+
warnings.warn(
|
| 435 |
+
f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
|
| 436 |
+
)
|
| 437 |
+
args["language"] = "en"
|
| 438 |
+
|
| 439 |
+
temperature = args.pop("temperature")
|
| 440 |
+
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
|
| 441 |
+
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
|
| 442 |
+
else:
|
| 443 |
+
temperature = [temperature]
|
| 444 |
+
|
| 445 |
+
if (threads := args.pop("threads")) > 0:
|
| 446 |
+
torch.set_num_threads(threads)
|
| 447 |
+
|
| 448 |
+
from . import load_model
|
| 449 |
+
|
| 450 |
+
model = load_model(model_name, device=device, download_root=model_dir)
|
| 451 |
+
|
| 452 |
+
writer = get_writer(output_format, output_dir)
|
| 453 |
+
word_options = ["highlight_words", "max_line_count", "max_line_width"]
|
| 454 |
+
if not args["word_timestamps"]:
|
| 455 |
+
for option in word_options:
|
| 456 |
+
if args[option]:
|
| 457 |
+
parser.error(f"--{option} requires --word_timestamps True")
|
| 458 |
+
if args["max_line_count"] and not args["max_line_width"]:
|
| 459 |
+
warnings.warn("--max_line_count has no effect without --max_line_width")
|
| 460 |
+
writer_args = {arg: args.pop(arg) for arg in word_options}
|
| 461 |
+
for audio_path in args.pop("audio"):
|
| 462 |
+
result = transcribe(model, audio_path, temperature=temperature, **args)
|
| 463 |
+
writer(result, audio_path, writer_args)
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
if __name__ == "__main__":
|
| 467 |
+
cli()
|
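For orientation, the transcribe() loop and CLI above can also be driven directly from Python. A minimal sketch, assuming the vendored package's __init__ re-exports load_model and transcribe the same way upstream openai-whisper does (the model name and audio path are placeholders):

    # Sketch only; API surface assumed to mirror upstream openai-whisper.
    from simul_whisper import whisper

    model = whisper.load_model("small")  # downloads to ~/.cache/whisper unless model_dir is set
    result = whisper.transcribe(model, "audio.wav", temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0))
    print(result["text"])
    for segment in result["segments"]:
        print(segment["start"], segment["end"], segment["text"])

The temperature tuple mirrors what cli() builds from --temperature and --temperature_increment_on_fallback: decode_with_fallback() retries with each successive value until the compression-ratio and log-probability thresholds are satisfied.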
simul_whisper/whisper/triton_ops.py
ADDED
@@ -0,0 +1,109 @@
+from functools import lru_cache
+
+import numpy as np
+import torch
+
+try:
+    import triton
+    import triton.language as tl
+except ImportError:
+    raise RuntimeError("triton import failed; try `pip install --pre triton`")
+
+
+@triton.jit
+def dtw_kernel(
+    cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr
+):
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < M
+
+    for k in range(1, N + M + 1):  # k = i + j
+        tl.debug_barrier()
+
+        p0 = cost + (k - 1) * cost_stride
+        p1 = cost + k * cost_stride
+        p2 = cost + k * cost_stride + 1
+
+        c0 = tl.load(p0 + offsets, mask=mask)
+        c1 = tl.load(p1 + offsets, mask=mask)
+        c2 = tl.load(p2 + offsets, mask=mask)
+
+        x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0)
+        cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)
+
+        cost_ptr = cost + (k + 1) * cost_stride + 1
+        tl.store(cost_ptr + offsets, cost_row, mask=mask)
+
+        trace_ptr = trace + (k + 1) * trace_stride + 1
+        tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1))
+        tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2))
+        tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2))
+
+
+@lru_cache(maxsize=None)
+def median_kernel(filter_width: int):
+    @triton.jit
+    def kernel(
+        y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr
+    ):  # x.shape[-1] == filter_width
+        row_idx = tl.program_id(0)
+        offsets = tl.arange(0, BLOCK_SIZE)
+        mask = offsets < y_stride
+
+        x_ptr = x + row_idx * x_stride  # noqa: F841
+        y_ptr = y + row_idx * y_stride
+
+        LOAD_ALL_ROWS_HERE  # noqa: F821
+
+        BUBBLESORT_HERE  # noqa: F821
+
+        tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask)  # noqa: F821
+
+    kernel = triton.JITFunction(kernel.fn)
+    kernel.src = kernel.src.replace(
+        "    LOAD_ALL_ROWS_HERE",
+        "\n".join(
+            [
+                f"    row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
+                for i in range(filter_width)
+            ]
+        ),
+    )
+    kernel.src = kernel.src.replace(
+        "    BUBBLESORT_HERE",
+        "\n\n".join(
+            [
+                "\n\n".join(
+                    [
+                        "\n".join(
+                            [
+                                f"    smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
+                                f"    larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
+                                f"    row{j} = smaller",
+                                f"    row{j + 1} = larger",
+                            ]
+                        )
+                        for j in range(filter_width - i - 1)
+                    ]
+                )
+                for i in range(filter_width // 2 + 1)
+            ]
+        ),
+    )
+    kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
+
+    return kernel
+
+
+def median_filter_cuda(x: torch.Tensor, filter_width: int):
+    """Apply a median filter of given width along the last dimension of x"""
+    slices = x.contiguous().unfold(-1, filter_width, 1)
+    grid = np.prod(slices.shape[:-2])
+
+    kernel = median_kernel(filter_width)
+    y = torch.empty_like(slices[..., 0])
+
+    BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length()
+    kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE)
+
+    return y
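The median kernel above is generated at JIT time by textually splicing a bubble-sort network into the kernel source. A quick way to sanity-check it is to compare median_filter_cuda against a plain PyTorch median over the same sliding windows; a sketch, assuming a CUDA device and a working triton install:

    import torch
    from simul_whisper.whisper.triton_ops import median_filter_cuda

    x = torch.randn(4, 1000, device="cuda")
    filter_width = 7  # odd width, so the middle row is the exact median

    # reference: median over each sliding window of `filter_width` samples
    reference = x.unfold(-1, filter_width, 1).median(dim=-1).values
    fast = median_filter_cuda(x, filter_width)
    assert torch.equal(fast, reference)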
simul_whisper/whisper/utils.py
ADDED
@@ -0,0 +1,258 @@
+import json
+import os
+import re
+import sys
+import zlib
+from typing import Callable, Optional, TextIO
+
+system_encoding = sys.getdefaultencoding()
+
+if system_encoding != "utf-8":
+
+    def make_safe(string):
+        # replaces any character not representable using the system default encoding with an '?',
+        # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
+        return string.encode(system_encoding, errors="replace").decode(system_encoding)
+
+else:
+
+    def make_safe(string):
+        # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
+        return string
+
+
+def exact_div(x, y):
+    assert x % y == 0
+    return x // y
+
+
+def str2bool(string):
+    str2val = {"True": True, "False": False}
+    if string in str2val:
+        return str2val[string]
+    else:
+        raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
+
+
+def optional_int(string):
+    return None if string == "None" else int(string)
+
+
+def optional_float(string):
+    return None if string == "None" else float(string)
+
+
+def compression_ratio(text) -> float:
+    text_bytes = text.encode("utf-8")
+    return len(text_bytes) / len(zlib.compress(text_bytes))
+
+
+def format_timestamp(
+    seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
+):
+    assert seconds >= 0, "non-negative timestamp expected"
+    milliseconds = round(seconds * 1000.0)
+
+    hours = milliseconds // 3_600_000
+    milliseconds -= hours * 3_600_000
+
+    minutes = milliseconds // 60_000
+    milliseconds -= minutes * 60_000
+
+    seconds = milliseconds // 1_000
+    milliseconds -= seconds * 1_000
+
+    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
+    return (
+        f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
+    )
+
+
+class ResultWriter:
+    extension: str
+
+    def __init__(self, output_dir: str):
+        self.output_dir = output_dir
+
+    def __call__(self, result: dict, audio_path: str, options: dict):
+        audio_basename = os.path.basename(audio_path)
+        audio_basename = os.path.splitext(audio_basename)[0]
+        output_path = os.path.join(
+            self.output_dir, audio_basename + "." + self.extension
+        )
+
+        with open(output_path, "w", encoding="utf-8") as f:
+            self.write_result(result, file=f, options=options)
+
+    def write_result(self, result: dict, file: TextIO, options: dict):
+        raise NotImplementedError
+
+
+class WriteTXT(ResultWriter):
+    extension: str = "txt"
+
+    def write_result(self, result: dict, file: TextIO, options: dict):
+        for segment in result["segments"]:
+            print(segment["text"].strip(), file=file, flush=True)
+
+
+class SubtitlesWriter(ResultWriter):
+    always_include_hours: bool
+    decimal_marker: str
+
+    def iterate_result(self, result: dict, options: dict):
+        raw_max_line_width: Optional[int] = options["max_line_width"]
+        max_line_count: Optional[int] = options["max_line_count"]
+        highlight_words: bool = options["highlight_words"]
+        max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
+        preserve_segments = max_line_count is None or raw_max_line_width is None
+
+        def iterate_subtitles():
+            line_len = 0
+            line_count = 1
+            # the next subtitle to yield (a list of word timings with whitespace)
+            subtitle: list[dict] = []
+            last = result["segments"][0]["words"][0]["start"]
+            for segment in result["segments"]:
+                for i, original_timing in enumerate(segment["words"]):
+                    timing = original_timing.copy()
+                    long_pause = not preserve_segments and timing["start"] - last > 3.0
+                    has_room = line_len + len(timing["word"]) <= max_line_width
+                    seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
+                    if line_len > 0 and has_room and not long_pause and not seg_break:
+                        # line continuation
+                        line_len += len(timing["word"])
+                    else:
+                        # new line
+                        timing["word"] = timing["word"].strip()
+                        if (
+                            len(subtitle) > 0
+                            and max_line_count is not None
+                            and (long_pause or line_count >= max_line_count)
+                            or seg_break
+                        ):
+                            # subtitle break
+                            yield subtitle
+                            subtitle = []
+                            line_count = 1
+                        elif line_len > 0:
+                            # line break
+                            line_count += 1
+                            timing["word"] = "\n" + timing["word"]
+                        line_len = len(timing["word"].strip())
+                    subtitle.append(timing)
+                    last = timing["start"]
+            if len(subtitle) > 0:
+                yield subtitle
+
+        if len(result["segments"]) > 0 and "words" in result["segments"][0]:
+            for subtitle in iterate_subtitles():
+                subtitle_start = self.format_timestamp(subtitle[0]["start"])
+                subtitle_end = self.format_timestamp(subtitle[-1]["end"])
+                subtitle_text = "".join([word["word"] for word in subtitle])
+                if highlight_words:
+                    last = subtitle_start
+                    all_words = [timing["word"] for timing in subtitle]
+                    for i, this_word in enumerate(subtitle):
+                        start = self.format_timestamp(this_word["start"])
+                        end = self.format_timestamp(this_word["end"])
+                        if last != start:
+                            yield last, start, subtitle_text
+
+                        yield start, end, "".join(
+                            [
+                                re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
+                                if j == i
+                                else word
+                                for j, word in enumerate(all_words)
+                            ]
+                        )
+                        last = end
+                else:
+                    yield subtitle_start, subtitle_end, subtitle_text
+        else:
+            for segment in result["segments"]:
+                segment_start = self.format_timestamp(segment["start"])
+                segment_end = self.format_timestamp(segment["end"])
+                segment_text = segment["text"].strip().replace("-->", "->")
+                yield segment_start, segment_end, segment_text
+
+    def format_timestamp(self, seconds: float):
+        return format_timestamp(
+            seconds=seconds,
+            always_include_hours=self.always_include_hours,
+            decimal_marker=self.decimal_marker,
+        )
+
+
+class WriteVTT(SubtitlesWriter):
+    extension: str = "vtt"
+    always_include_hours: bool = False
+    decimal_marker: str = "."
+
+    def write_result(self, result: dict, file: TextIO, options: dict):
+        print("WEBVTT\n", file=file)
+        for start, end, text in self.iterate_result(result, options):
+            print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
+
+
+class WriteSRT(SubtitlesWriter):
+    extension: str = "srt"
+    always_include_hours: bool = True
+    decimal_marker: str = ","
+
+    def write_result(self, result: dict, file: TextIO, options: dict):
+        for i, (start, end, text) in enumerate(
+            self.iterate_result(result, options), start=1
+        ):
+            print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
+
+
+class WriteTSV(ResultWriter):
+    """
+    Write a transcript to a file in TSV (tab-separated values) format containing lines like:
+    <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
+
+    Using integer milliseconds as start and end times means there's no chance of interference from
+    an environment setting a language encoding that causes the decimal in a floating point number
+    to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
+    """
+
+    extension: str = "tsv"
+
+    def write_result(self, result: dict, file: TextIO, options: dict):
+        print("start", "end", "text", sep="\t", file=file)
+        for segment in result["segments"]:
+            print(round(1000 * segment["start"]), file=file, end="\t")
+            print(round(1000 * segment["end"]), file=file, end="\t")
+            print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
+
+
+class WriteJSON(ResultWriter):
+    extension: str = "json"
+
+    def write_result(self, result: dict, file: TextIO, options: dict):
+        json.dump(result, file)
+
+
+def get_writer(
+    output_format: str, output_dir: str
+) -> Callable[[dict, TextIO, dict], None]:
+    writers = {
+        "txt": WriteTXT,
+        "vtt": WriteVTT,
+        "srt": WriteSRT,
+        "tsv": WriteTSV,
+        "json": WriteJSON,
+    }
+
+    if output_format == "all":
+        all_writers = [writer(output_dir) for writer in writers.values()]
+
+        def write_all(result: dict, file: TextIO, options: dict):
+            for writer in all_writers:
+                writer(result, file, options)
+
+        return write_all
+
+    return writers[output_format](output_dir)
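A short usage sketch of get_writer() above. The three word-level keys must be present in the options dict whenever a subtitle writer is selected, even when word timestamps were not requested; `result` is assumed to be the dict returned by transcribe():

    from simul_whisper.whisper.utils import get_writer

    writer = get_writer("srt", "out")  # or "txt", "vtt", "tsv", "json", "all"
    options = {"highlight_words": False, "max_line_count": None, "max_line_width": None}
    writer(result, "audio.wav", options)  # writes out/audio.srt

Each ResultWriter derives the output filename from the audio path, so one writer instance can serve many files.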
simul_whisper/whisper/version.py
ADDED
@@ -0,0 +1 @@
+__version__ = "20230918"
simulstreaming_whisper.py
ADDED
@@ -0,0 +1,260 @@
+from whisper_streaming.base import OnlineProcessorInterface, ASRBase
+import argparse
+
+import sys
+import logging
+import torch
+
+from simul_whisper.config import AlignAttConfig
+from simul_whisper.simul_whisper import PaddedAlignAttWhisper
+
+logger = logging.getLogger(__name__)
+
+def simulwhisper_args(parser):
+    group = parser.add_argument_group('Whisper arguments')
+    group.add_argument('--model_path', type=str, default='./large-v3.pt',
+                       help='The file path to the Whisper .pt model. If not present on the filesystem, the model is downloaded automatically.')
+    group.add_argument("--beams", "-b", type=int, default=1, help="Number of beams for beam search decoding. If 1, GreedyDecoder is used.")
+    group.add_argument("--decoder", type=str, default=None, help="Override automatic selection of beam or greedy decoder. "
+                       "If beams > 1 and greedy: invalid.")
+
+    group = parser.add_argument_group('Audio buffer')
+    group.add_argument('--audio_max_len', type=float, default=30.0,
+                       help='Max length of the audio buffer, in seconds.')
+    group.add_argument('--audio_min_len', type=float, default=0.0,
+                       help='Skip processing if the audio buffer is shorter than this length, in seconds. Useful when the --min-chunk-size is small.')
+
+
+    group = parser.add_argument_group('AlignAtt argument')
+    group.add_argument('--frame_threshold', type=int, default=25,
+                       help='Threshold for the attention-guided decoding. The AlignAtt policy will decode only '
+                       'until this number of frames from the end of audio. In frames: one frame is 0.02 seconds for the large-v3 model.')
+
+    group = parser.add_argument_group('Truncation of the last decoded word (from Simul-Whisper)')
+    group.add_argument('--cif_ckpt_path', type=str, default=None,
+                       help="The file path to Simul-Whisper's CIF model checkpoint that detects whether there is an "
+                       'end of word at the end of the chunk. If not, the last decoded space-separated word is truncated '
+                       'because it is often wrong -- transcribing a word in the middle. '
+                       'The CIF model adapted for the Whisper model version should be used. '
+                       'Find the models in https://github.com/backspacetg/simul_whisper/tree/main/cif_models . '
+                       'Note that there is no model for large-v3.')
+    group.add_argument("--never_fire", action=argparse.BooleanOptionalAction, default=False,
+                       help="Override the CIF model. If True, the last word is NEVER truncated, no matter what the CIF model detects. "
+                       "If False: if the CIF model path is set, the last word is SOMETIMES truncated, depending on the CIF detection. "
+                       "Otherwise, if the CIF model path is not set, the last word is ALWAYS trimmed.")
+
+    group = parser.add_argument_group("Prompt and context")
+    group.add_argument("--init_prompt", type=str, default=None, help="Init prompt for the model. It should be in the target language.")
+    group.add_argument("--static_init_prompt", type=str, default=None, help="Do not scroll over this text. It can contain terminology that should be relevant over the whole document.")
+    group.add_argument("--max_context_tokens", type=int, default=None, help="Max context tokens for the model. Default is None.")
+
+
+def simul_asr_factory(args):
+    logger.setLevel(args.log_level)
+    decoder = args.decoder
+    if args.beams > 1:
+        if decoder == "greedy":
+            raise ValueError("Invalid 'greedy' decoder type for beams > 1. Use 'beam'.")
+        elif decoder is None or decoder == "beam":
+            decoder = "beam"
+        else:
+            raise ValueError("Invalid decoder type. Use 'beam' or 'greedy'.")
+    else:
+        if decoder is None:
+            decoder = "greedy"
+        elif decoder not in ("beam", "greedy"):
+            raise ValueError("Invalid decoder type. Use 'beam' or 'greedy'.")
+        # else: it is greedy or beam, that's ok
+
+    a = {v: getattr(args, v) for v in ["model_path", "cif_ckpt_path", "frame_threshold", "audio_min_len", "audio_max_len", "beams", "task",
+                                       "never_fire", 'init_prompt', 'static_init_prompt', 'max_context_tokens', "logdir"
+                                       ]}
+    a["language"] = args.lan
+    a["segment_length"] = args.min_chunk_size
+    a["decoder_type"] = decoder
+
+    if args.min_chunk_size >= args.audio_max_len:
+        raise ValueError("min_chunk_size must be smaller than audio_max_len")
+    if args.audio_min_len > args.audio_max_len:
+        raise ValueError("audio_min_len must be smaller than audio_max_len")
+    logger.info(f"Arguments: {a}")
+    asr = SimulWhisperASR(**a)
+    return asr, SimulWhisperOnline(asr)
+
+class SimulWhisperASR(ASRBase):
+
+    sep = " "
+
+    def __init__(self, language, model_path, cif_ckpt_path, frame_threshold, audio_max_len, audio_min_len, segment_length, beams, task,
+                 decoder_type, never_fire, init_prompt, static_init_prompt, max_context_tokens, logdir):
+        cfg = AlignAttConfig(
+            model_path=model_path,
+            segment_length=segment_length,
+            frame_threshold=frame_threshold,
+            language=language,
+            audio_max_len=audio_max_len,
+            audio_min_len=audio_min_len,
+            cif_ckpt_path=cif_ckpt_path,
+            decoder_type=decoder_type,  # "greedy" if beams==1 else "beam"
+            beam_size=beams,
+            task=task,
+            never_fire=never_fire,
+            init_prompt=init_prompt,
+            max_context_tokens=max_context_tokens,
+            static_init_prompt=static_init_prompt,
+            logdir=logdir,
+        )
+        logger.info(f"Language: {language}")
+        self.model = PaddedAlignAttWhisper(cfg)
+
+    def transcribe(self, audio, init_prompt=""):
+        logger.info("SimulWhisperASR's transcribe() should not be used. It's here only temporarily. "
+                    "Instead, use SimulWhisperOnline.process_iter().")
+        raise NotImplementedError("Use SimulWhisperOnline.process_iter() instead of transcribe().")
+
+    def warmup(self, audio, init_prompt=""):
+        self.model.insert_audio(audio)
+        self.model.infer(True)
+        self.model.refresh_segment(complete=True)
+
+    def use_vad(self):
+        print("VAD not implemented", file=sys.stderr)
+
+    def set_translate_task(self):
+        # this is not used. Translate task is set another way.
+        pass
+
+
+class SimulWhisperOnline(OnlineProcessorInterface):
+
+    def __init__(self, asr):
+        self.model = asr.model
+        self.file = None
+        self.init()
+
+    def init(self, offset=None):
+        self.audio_chunks = []
+        if offset is not None:
+            self.offset = offset
+        else:
+            self.offset = 0
+        self.is_last = False
+        self.beg = self.offset
+        self.end = self.offset
+
+        self.audio_bufer_offset = self.offset
+        self.last_ts = -1
+        self.model.refresh_segment(complete=True)
+
+        self.unicode_buffer = []  # hide incomplete unicode character for the next iteration
+
+    def insert_audio_chunk(self, audio):
+        self.audio_chunks.append(torch.from_numpy(audio))
+
+    def timestamped_text(self, tokens, generation):
+        if not generation:
+            return []
+
+        pr = generation["progress"]
+        if "result" not in generation or self.unicode_buffer != []:
+            split_words, split_tokens = self.model.tokenizer.split_to_word_tokens(tokens)
+        else:
+            split_words, split_tokens = generation["result"]["split_words"], generation["result"]["split_tokens"]
+
+        frames = [p["most_attended_frames"][0] for p in pr]
+        if frames and self.unicode_buffer != []:
+            a = [frames[0]] * len(self.unicode_buffer)
+            frames = a + frames
+
+        tokens = tokens.copy()
+        ret = []
+        for sw, st in zip(split_words, split_tokens):
+            b = None
+            for stt in st:
+                t, f = tokens.pop(0), frames.pop(0)
+                if t != stt:
+                    raise ValueError(f"Token mismatch: {t} != {stt} at frame {f}.")
+                if b is None:
+                    b = f
+                e = f
+            out = {
+                'start': b * 0.02 + self.audio_bufer_offset,
+                'end': e * 0.02 + self.audio_bufer_offset,
+                'text': sw,
+                'tokens': st
+            }
+            ret.append(out)
+            logger.debug(f"TS-WORD-INFO: {out}")
+        return ret
+
+    def hide_incomplete_unicode(self, tokens):
+        """Sometimes, the last token is an incomplete unicode character, e.g. a part of "ň" or "ř".
+        Without this, the outputs can end with '�' = Unicode Replacement Character, and the next output also
+        starts with '�'.
+        This function hides the last incomplete unicode character and adds it in the next iteration.
+        """
+        if self.unicode_buffer != []:
+            logger.debug(f"Hiding incomplete unicode character: {self.unicode_buffer}")
+            tokens = self.unicode_buffer + tokens
+            self.unicode_buffer = []  # clear the buffer after processing
+        chars, _ = self.model.tokenizer.split_tokens_on_unicode(tokens)
+        if len(chars) > 0 and chars[-1].endswith('�'):
+            self.unicode_buffer = tokens[-1:]  # keep the last incomplete unicode character
+            logger.debug(f"Hiding incomplete unicode character: {tokens[-1:]}")
+            return tokens[:-1]  # remove the last token, which is an incomplete unicode character
+        return tokens
+
+    def process_iter(self):
+        if len(self.audio_chunks) == 0:
+            audio = None
+        else:
+            audio = torch.cat(self.audio_chunks, dim=0)
+            if audio.shape[0] == 0:
+                audio = None
+            else:
+                self.end += audio.shape[0] / self.SAMPLING_RATE
+        self.audio_chunks = []
+        self.audio_bufer_offset += self.model.insert_audio(audio)
+        tokens, generation_progress = self.model.infer(is_last=self.is_last)
+
+        tokens = self.hide_incomplete_unicode(tokens)
+
+        text = self.model.tokenizer.decode(tokens)
+        if len(text) == 0:
+            return {}
+
+        # word-level timestamps
+        ts_words = self.timestamped_text(tokens, generation_progress)
+
+        self.beg = min(word['start'] for word in ts_words)  # it should be this
+        self.beg = max(self.beg, self.last_ts + 0.001)  # but let's keep the timestamps non-decreasing -- at least last beg + 1 ms
+        if self.is_last:
+            e = self.end
+        else:
+            e = max(word['end'] for word in ts_words)
+        e = max(e, self.beg + 0.001)
+
+        self.last_ts = e
+
+        # return (self.beg, e, text)
+        return {
+            'start': self.beg,
+            'end': e,
+            'text': text,
+            'tokens': tokens,
+            'words': ts_words
+        }
+
+    def finish(self):
+        logger.info("Finish")
+        self.is_last = True
+        o = self.process_iter()
+        self.is_last = False
+        self.model.refresh_segment(complete=True)
+        return o
+
+
+if __name__ == "__main__":
+
+    from whisper_streaming.whisper_online_main import main_simulation_from_file
+    main_simulation_from_file(simul_asr_factory, add_args=simulwhisper_args)
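To see why hide_incomplete_unicode() above is needed: Whisper's byte-level BPE can split one multi-byte character across two tokens, so decoding a hypothesis that ends mid-character produces U+FFFD. A sketch, with `tokenizer` assumed to be the loaded tokenizer from simul_whisper.whisper.tokenizer:

    ids = tokenizer.encode("ř")                  # may yield two tokens, one per UTF-8 byte
    if len(ids) > 1:
        print(tokenizer.decode(ids[:1]))         # -> '�' until the second byte arrives
        chars, _ = tokenizer.split_tokens_on_unicode(ids[:1])
        print(chars[-1].endswith("�"))           # -> True: the buffer would hide this token

When that happens, the trailing token is parked in self.unicode_buffer and prepended to the next iteration's tokens, with its frame list padded so word timestamps stay aligned.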
token_buffer.py
ADDED
@@ -0,0 +1,73 @@
+import torch
+import sys
+
+
+class TokenBuffer:
+
+    def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=[]):
+        self.text = text
+        self.prefix_token_ids = prefix_token_ids
+        self.tokenizer = tokenizer
+        self.device = device
+
+    def as_token_ids(self, tokenizer=None):
+        if tokenizer is None:
+            tokenizer = self.tokenizer
+        if tokenizer is None:
+            raise ValueError("Tokenizer is not set.")
+        return self.prefix_token_ids + tokenizer.encode(self.text)
+
+    def as_tensor(self, device=None):
+        if device is None:
+            device = self.device
+        if device is None:
+            raise ValueError("Device is not set.")
+        tok_ids = self.as_token_ids()
+        return torch.tensor(tok_ids,
+                            dtype=torch.long, device=device).unsqueeze(0)
+
+    def as_tensor_beam(self, beam, device=None):
+        t = self.as_tensor(device=device)
+        return t.repeat_interleave(beam, dim=0)
+
+    def as_text(self):
+        return self.text
+
+    @staticmethod
+    def empty(*a, **kw):
+        return TokenBuffer(*a, **kw)
+
+    @staticmethod
+    def from_text(text, *a, **kw):
+        return TokenBuffer(*a, text=text, **kw)
+
+    def is_empty(self):
+        return self.text is None or self.text == ""
+
+    def trim_words(self, num=1, after=0):
+        '''
+        num: how many words to trim from the beginning
+        after: how many characters to skip (length of the static prompt)
+        '''
+        tokenizer = self.tokenizer
+        assert tokenizer is not None, "Tokenizer is not set."
+
+        ids = tokenizer.encode(self.text[after:])
+        words, wids = self.tokenizer.split_to_word_tokens(ids)
+        # print(words, file=sys.stderr)
+        # print(wids, file=sys.stderr)
+        if not words:
+            return 0
+        self.text = self.text[:after] + "".join(words[num:])
+        return sum(len(wi) for wi in wids[:num])
+
+    def append_token_ids(self, token_ids):
+        tokenizer = self.tokenizer
+        assert tokenizer is not None, "Tokenizer is not set."
+        self.text += self.tokenizer.decode(token_ids)
+
+    def as_split_word_tokens(self):
+        tokenizer = self.tokenizer
+        assert tokenizer is not None, "Tokenizer is not set."
+        ids = tokenizer.encode(self.text)
+        return tokenizer.split_to_word_tokens(ids)
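A usage sketch of TokenBuffer above; `tokenizer` is assumed to be the Whisper tokenizer, and the exact word splits and spacing depend on it:

    buf = TokenBuffer.from_text("hello world", tokenizer=tokenizer, device="cpu")
    ids = buf.as_token_ids()             # prefix_token_ids + encode("hello world")
    batch = buf.as_tensor_beam(5)        # shape (5, len(ids)), for beam-search decoding
    removed = buf.trim_words(num=1)      # drops "hello", returns the number of tokens removed
    print(buf.as_text())                 # -> " world" (spacing depends on the tokenizer)

The `after` argument of trim_words() lets a static prompt prefix survive trimming while the scrolling context behind it is shortened word by word.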
whisper_streaming/base.py
ADDED
@@ -0,0 +1,46 @@
+import sys
+
+
+class ASRBase:
+
+    sep = " "  # join transcribed words with this character (" " for whisper_timestamped,
+               # "" for faster-whisper because it emits the spaces when needed)
+
+    def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
+        self.logfile = logfile
+
+        self.transcribe_kargs = {}
+        if lan == "auto":
+            self.original_language = None
+        else:
+            self.original_language = lan
+
+        self.model = self.load_model(modelsize, cache_dir, model_dir)
+
+    def load_model(self, modelsize, cache_dir, model_dir):
+        raise NotImplementedError("must be implemented in the child class")
+
+    def transcribe(self, audio, init_prompt=""):
+        raise NotImplementedError("must be implemented in the child class")
+
+    def warmup(self, audio, init_prompt=""):
+        return self.transcribe(audio, init_prompt)
+
+    def use_vad(self):
+        raise NotImplementedError("must be implemented in the child class")
+
+    def set_translate_task(self):
+        raise NotImplementedError("must be implemented in the child class")
+
+
+class OnlineProcessorInterface:
+
+    SAMPLING_RATE = 16000
+
+    def insert_audio_chunk(self, audio):
+        raise NotImplementedError("must be implemented in child class")
+
+    def process_iter(self):
+        raise NotImplementedError("must be implemented in child class")
+
+    def finish(self):
+        raise NotImplementedError("must be implemented in child class")
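A minimal sketch of the OnlineProcessorInterface contract above. This dummy processor only tracks buffered duration and exists to illustrate the three methods a real processor (such as SimulWhisperOnline) must implement:

    import numpy as np
    from whisper_streaming.base import OnlineProcessorInterface

    class DummyOnline(OnlineProcessorInterface):
        """Buffers audio and reports its duration; all 'transcripts' are empty."""

        def __init__(self):
            self.buffer = np.array([], dtype=np.float32)

        def insert_audio_chunk(self, audio):
            self.buffer = np.append(self.buffer, audio)

        def process_iter(self):
            end = len(self.buffer) / self.SAMPLING_RATE
            return {"start": 0.0, "end": end, "text": ""}

        def finish(self):
            return self.process_iter()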
whisper_streaming/line_packet.py
ADDED
@@ -0,0 +1,93 @@
+#!/usr/bin/env python3
+
+"""Functions for sending and receiving individual lines of text over a socket.
+
+A line is transmitted using one or more fixed-size packets of UTF-8 bytes
+containing:
+
+  - Zero or more bytes of UTF-8, excluding \n and \0, followed by
+
+  - Zero or more \0 bytes as required to pad the packet to PACKET_SIZE
+
+Originally from the UEDIN team of the ELITR project.
+"""
+
+PACKET_SIZE = 65536
+
+
+def send_one_line(socket, text, pad_zeros=False):
+    """Sends a line of text over the given socket.
+
+    The 'text' argument should contain a single line of text (line break
+    characters are optional). Line boundaries are determined by Python's
+    str.splitlines() function [1]. We also count '\0' as a line terminator.
+    If 'text' contains multiple lines then only the first will be sent.
+
+    If the send fails then an exception will be raised.
+
+    [1] https://docs.python.org/3.5/library/stdtypes.html#str.splitlines
+
+    Args:
+        socket: a socket object.
+        text: string containing a line of text for transmission.
+    """
+    text = text.replace('\0', '\n')
+    lines = text.splitlines()
+    first_line = '' if len(lines) == 0 else lines[0]
+    # TODO Is there a better way of handling bad input than 'replace'?
+    data = first_line.encode('utf-8', errors='replace') + b'\n' + (b'\0' if pad_zeros else b'')
+    for offset in range(0, len(data), PACKET_SIZE):
+        bytes_remaining = len(data) - offset
+        if bytes_remaining < PACKET_SIZE:
+            padding_length = PACKET_SIZE - bytes_remaining
+            packet = data[offset:] + (b'\0' * padding_length if pad_zeros else b'')
+        else:
+            packet = data[offset:offset+PACKET_SIZE]
+        socket.sendall(packet)
+
+
+def receive_one_line(socket):
+    """Receives a line of text from the given socket.
+
+    This function will (attempt to) receive a single line of text. If data is
+    currently unavailable then it will block until data becomes available or
+    the sender has closed the connection (in which case it will return an
+    empty string).
+
+    The string should not contain any newline characters, but if it does then
+    only the first line will be returned.
+
+    Args:
+        socket: a socket object.
+
+    Returns:
+        A string representing a single line with a terminating newline or
+        None if the connection has been closed.
+    """
+    data = b''
+    while True:
+        packet = socket.recv(PACKET_SIZE)
+        if not packet:  # Connection has been closed.
+            return None
+        data += packet
+        if b'\0' in packet:
+            break
+    # TODO Is there a better way of handling bad input than 'replace'?
+    text = data.decode('utf-8', errors='replace').strip('\0')
+    lines = text.split('\n')
+    return lines[0] + '\n'
+
+
+def receive_lines(socket):
+    try:
+        data = socket.recv(PACKET_SIZE)
+    except BlockingIOError:
+        return []
+    if data is None:  # Connection has been closed.
+        return None
+    # TODO Is there a better way of handling bad input than 'replace'?
+    text = data.decode('utf-8', errors='replace').strip('\0')
+    lines = text.split('\n')
+    if len(lines) == 1 and not lines[0]:
+        return None
+    return lines
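The framing above can be exercised over an in-process socket pair. Note that receive_one_line() blocks until it sees a \0 byte, so the sender must use pad_zeros=True for this round trip to complete:

    import socket
    from whisper_streaming.line_packet import send_one_line, receive_one_line

    a, b = socket.socketpair()
    send_one_line(a, "hello world", pad_zeros=True)  # one PACKET_SIZE packet, \0-padded
    print(receive_one_line(b))                       # -> "hello world\n"
    a.close(); b.close()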
whisper_streaming/silero_vad_iterator.py
ADDED
@@ -0,0 +1,150 @@
+import torch
+
+# This is copied from silero-vad's vad_utils.py:
+# https://github.com/snakers4/silero-vad/blob/94811cbe1207ec24bc0f5370b895364b8934936f/src/silero_vad/utils_vad.py#L398C1-L489C20
+# (except changed defaults)
+
+# Their licence is MIT, same as ours: https://github.com/snakers4/silero-vad/blob/94811cbe1207ec24bc0f5370b895364b8934936f/LICENSE
+
+class VADIterator:
+    def __init__(self,
+                 model,
+                 threshold: float = 0.5,
+                 sampling_rate: int = 16000,
+                 min_silence_duration_ms: int = 500,  # makes sense on one recording that I checked
+                 speech_pad_ms: int = 100  # same
+                 ):
+
+        """
+        Class for stream imitation
+
+        Parameters
+        ----------
+        model: preloaded .jit/.onnx silero VAD model
+
+        threshold: float (default - 0.5)
+            Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
+            It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
+
+        sampling_rate: int (default - 16000)
+            Currently silero VAD models support 8000 and 16000 sample rates
+
+        min_silence_duration_ms: int (default - 500 milliseconds)
+            In the end of each speech chunk wait for min_silence_duration_ms before separating it
+
+        speech_pad_ms: int (default - 100 milliseconds)
+            Final speech chunks are padded by speech_pad_ms each side
+        """
+
+        self.model = model
+        self.threshold = threshold
+        self.sampling_rate = sampling_rate
+
+        if sampling_rate not in [8000, 16000]:
+            raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
+
+        self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
+        self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
+        self.reset_states()
+
+    def reset_states(self):
+
+        self.model.reset_states()
+        self.triggered = False
+        self.temp_end = 0
+        self.current_sample = 0
+
+    @torch.no_grad()
+    def __call__(self, x, return_seconds=False, time_resolution: int = 1):
+        """
+        x: torch.Tensor
+            audio chunk (see examples in repo)
+
+        return_seconds: bool (default - False)
+            whether return timestamps in seconds (default - samples)
+
+        time_resolution: int (default - 1)
+            time resolution of speech coordinates when requested as seconds
+        """
+
+        if not torch.is_tensor(x):
+            try:
+                x = torch.Tensor(x)
+            except:
+                raise TypeError("Audio cannot be cast to tensor. Cast it manually")
+
+        window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
+        self.current_sample += window_size_samples
+
+        speech_prob = self.model(x, self.sampling_rate).item()
+
+        if (speech_prob >= self.threshold) and self.temp_end:
+            self.temp_end = 0
+
+        if (speech_prob >= self.threshold) and not self.triggered:
+            self.triggered = True
+            speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
+            return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, time_resolution)}
+
+        if (speech_prob < self.threshold - 0.15) and self.triggered:
+            if not self.temp_end:
+                self.temp_end = self.current_sample
+            if self.current_sample - self.temp_end < self.min_silence_samples:
+                return None
+            else:
+                speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
+                self.temp_end = 0
+                self.triggered = False
+                return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, time_resolution)}
+
+        return None
+
+#######################
+# because Silero now requires exactly 512-sized audio chunks
+
+import numpy as np
+
+class FixedVADIterator(VADIterator):
+    '''It fixes VADIterator by allowing it to process audio of any length, not only exactly 512 frames at once.
+    If the audio to be processed at once is long and multiple voiced segments are detected,
+    then __call__ returns the start of the first segment, and the end (or middle, which means no end) of the last segment.
+    '''
+
+    def reset_states(self):
+        super().reset_states()
+        self.buffer = np.array([], dtype=np.float32)
+
+    def __call__(self, x, return_seconds=False):
+        self.buffer = np.append(self.buffer, x)
+        ret = None
+        while len(self.buffer) >= 512:
+            r = super().__call__(self.buffer[:512], return_seconds=return_seconds)
+            self.buffer = self.buffer[512:]
+            if ret is None:
+                ret = r
+            elif r is not None:
+                if 'end' in r:
+                    ret['end'] = r['end']  # the latter end
+                if 'start' in r and 'end' in ret:  # there is an earlier start.
+                    # Remove end, merging this segment with the previous one.
+                    del ret['end']
+        return ret if ret != {} else None
+
+
+if __name__ == "__main__":
+    # test/demonstrate the need for FixedVADIterator:
+
+    import torch
+    model, _ = torch.hub.load(
+        repo_or_dir='snakers4/silero-vad',
+        model='silero_vad'
+    )
+    vac = FixedVADIterator(model)
+    # vac = VADIterator(model)  # the second case crashes with this
+
+    # this works: for both
+    audio_buffer = np.array([0]*(512), dtype=np.float32)
|
| 145 |
+
vac(audio_buffer)
|
| 146 |
+
|
| 147 |
+
# this crashes on the non FixedVADIterator with
|
| 148 |
+
# ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
|
| 149 |
+
audio_buffer = np.array([0]*(512-1),dtype=np.float32)
|
| 150 |
+
vac(audio_buffer)
|
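
For orientation, a minimal usage sketch (not part of the commit above) of driving FixedVADIterator from a streaming loop. It assumes this module is importable as whisper_streaming.silero_vad_iterator and that the Silero VAD model can be fetched via torch.hub:

import numpy as np
import torch
from whisper_streaming.silero_vad_iterator import FixedVADIterator

model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
vad = FixedVADIterator(model)

sr = 16000
stream = np.zeros(2 * sr, dtype=np.float32)  # stand-in for 2 s of live audio
for off in range(0, len(stream), 640):       # 40 ms chunks; need not be multiples of 512
    event = vad(stream[off:off + 640], return_seconds=True)
    if event is not None:
        print(event)  # {'start': ...} and/or {'end': ...}, in seconds

Silence produces no events; on real speech, a {'start': ...} is emitted when the probability crosses the threshold and an {'end': ...} after 500 ms of non-voice (the changed defaults above).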
whisper_streaming/vac_online_processor.py
ADDED
@@ -0,0 +1,111 @@
+from whisper_streaming.base import OnlineProcessorInterface
+from whisper_streaming.silero_vad_iterator import FixedVADIterator
+import numpy as np
+
+import logging
+logger = logging.getLogger(__name__)
+import sys
+
+class VACOnlineASRProcessor(OnlineProcessorInterface):
+    '''Wraps OnlineASRProcessor with VAC (Voice Activity Controller).
+
+    It works the same way as OnlineASRProcessor: it receives chunks of audio (e.g. 0.04 seconds),
+    runs VAD on them, and continuously detects whether there is speech or not.
+    When it detects the end of speech (non-voice for 500 ms), it makes OnlineASRProcessor end the utterance immediately.
+    '''
+
+    def __init__(self, online_chunk_size, online, min_buffered_length=1):
+        self.online_chunk_size = online_chunk_size
+        self.online = online
+
+        self.min_buffered_frames = int(min_buffered_length * self.SAMPLING_RATE)
+
+        # VAC:
+        import torch
+        model, _ = torch.hub.load(
+            repo_or_dir='snakers4/silero-vad',
+            model='silero_vad'
+        )
+        self.vac = FixedVADIterator(model)  # we use the default options there: 500ms silence, 100ms padding, etc.
+
+        self.init()
+
+    def init(self):
+        self.online.init()
+        self.vac.reset_states()
+        self.current_online_chunk_buffer_size = 0
+
+        self.is_currently_final = False
+
+        self.status = None  # or "voice" or "nonvoice"
+        self.audio_buffer = np.array([], dtype=np.float32)
+        self.buffer_offset = 0  # in frames
+
+    def clear_buffer(self):
+        self.audio_buffer = np.array([], dtype=np.float32)
+
+    def insert_audio_chunk(self, audio):
+        res = self.vac(audio)
+        self.audio_buffer = np.append(self.audio_buffer, audio)
+        if res is not None:
+            frame = list(res.values())[0] - self.buffer_offset
+            frame = max(0, frame)
+            if 'start' in res and 'end' not in res:  # voice starts in this chunk
+                self.status = 'voice'
+                send_audio = self.audio_buffer[frame:]
+                self.online.init(offset=(frame + self.buffer_offset)/self.SAMPLING_RATE)
+                self.online.insert_audio_chunk(send_audio)
+                self.current_online_chunk_buffer_size += len(send_audio)
+                self.buffer_offset += len(self.audio_buffer)
+                self.clear_buffer()
+            elif 'end' in res and 'start' not in res:  # voice ends in this chunk
+                self.status = 'nonvoice'
+                if frame > 0:
+                    send_audio = self.audio_buffer[:frame]
+                    self.online.insert_audio_chunk(send_audio)
+                    self.current_online_chunk_buffer_size += len(send_audio)
+                self.is_currently_final = True
+                keep_frames = min(len(self.audio_buffer) - frame, self.min_buffered_frames)
+                self.buffer_offset += len(self.audio_buffer) - keep_frames
+                self.audio_buffer = self.audio_buffer[-keep_frames:]
+            else:  # voice both starts and ends within this chunk
+                beg = max(0, res["start"] - self.buffer_offset)
+                end = max(0, res["end"] - self.buffer_offset)
+                self.status = 'nonvoice'
+                if beg < end:
+                    send_audio = self.audio_buffer[beg:end]
+                    self.online.init(offset=((beg + self.buffer_offset)/self.SAMPLING_RATE))
+                    self.online.insert_audio_chunk(send_audio)
+                    self.current_online_chunk_buffer_size += len(send_audio)
+                self.is_currently_final = True
+                keep_frames = min(len(self.audio_buffer) - end, self.min_buffered_frames)
+                self.buffer_offset += len(self.audio_buffer) - keep_frames
+                self.audio_buffer = self.audio_buffer[-keep_frames:]
+        else:
+            if self.status == 'voice':
+                self.online.insert_audio_chunk(self.audio_buffer)
+                self.current_online_chunk_buffer_size += len(self.audio_buffer)
+                self.buffer_offset += len(self.audio_buffer)
+                self.clear_buffer()
+            else:
+                # We keep 1 second because VAD may later find start of voice in it.
+                # But we trim it to prevent OOM.
+                self.buffer_offset += max(0, len(self.audio_buffer) - self.min_buffered_frames)
+                self.audio_buffer = self.audio_buffer[-self.min_buffered_frames:]
+
+    def process_iter(self):
+        if self.is_currently_final:
+            return self.finish()
+        elif self.current_online_chunk_buffer_size > self.SAMPLING_RATE*self.online_chunk_size:
+            self.current_online_chunk_buffer_size = 0
+            ret = self.online.process_iter()
+            return ret
+        else:
+            logger.info(f"no online update, only VAD. {self.status}")
+            return {}
+
+    def finish(self):
+        ret = self.online.finish()
+        self.current_online_chunk_buffer_size = 0
+        self.is_currently_final = False
+        return ret
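
A minimal sketch (not part of the commit above) of the contract the wrapped online object must satisfy. The StubOnline class here is hypothetical, just to exercise the VAC gating without loading Whisper; it assumes OnlineProcessorInterface provides SAMPLING_RATE = 16000 and that torch.hub can fetch Silero VAD:

import numpy as np
from whisper_streaming.vac_online_processor import VACOnlineASRProcessor

class StubOnline:
    # the minimal interface VACOnlineASRProcessor calls on the wrapped processor
    def init(self, offset=None): self.offset = offset
    def insert_audio_chunk(self, audio): pass
    def process_iter(self): return {}   # a real backend returns {'start', 'end', 'text'}
    def finish(self): return {}

vac_online = VACOnlineASRProcessor(online_chunk_size=1.0, online=StubOnline())
vac_online.insert_audio_chunk(np.zeros(640, dtype=np.float32))  # 40 ms of silence
print(vac_online.process_iter())  # {} -- no online update until enough voiced audio is buffered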
whisper_streaming/whisper_online_main.py
ADDED
@@ -0,0 +1,224 @@
+#!/usr/bin/env python3
+
+# This code is taken from the original WhisperStreaming whisper_online.py.
+# It is refactored and simplified. Only the code that is needed for the
+# SimulWhisper backend is kept.
+
+
+import sys
+import numpy as np
+import librosa
+from functools import lru_cache
+import time
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+@lru_cache(10**6)
+def load_audio(fname):
+    a, _ = librosa.load(fname, sr=16000, dtype=np.float32)
+    return a
+
+def load_audio_chunk(fname, beg, end):
+    audio = load_audio(fname)
+    beg_s = int(beg*16000)
+    end_s = int(end*16000)
+    return audio[beg_s:end_s]
+
+def processor_args(parser):
+    """shared args for the online processors
+    parser: argparse.ArgumentParser object
+    """
+    group = parser.add_argument_group("WhisperStreaming processor arguments (shared for simulation from file and for the server)")
+    group.add_argument('--min-chunk-size', type=float, default=1.2,
+            help='Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes less '
+            'time, it waits; otherwise it processes the whole segment that was received by this time.')
+
+    group.add_argument('--lan', '--language', type=str, default="en",
+            help="Source language code, e.g. en, de, cs, or auto for automatic language detection from speech.")
+    group.add_argument('--task', type=str, default='transcribe',
+            choices=["transcribe","translate"],
+            help="Transcribe or translate.")
+
+    group.add_argument('--vac', action="store_true", default=False,
+            help='Use VAC = voice activity controller. Recommended. Requires torch.')
+    group.add_argument('--vac-chunk-size', type=float, default=0.04,
+            help='VAC sample size in seconds.')
+
+    parser.add_argument("-l", "--log-level", dest="log_level",
+            choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
+            help="Set the log level", default='DEBUG')
+
+    parser.add_argument("--logdir", help="Directory to save audio segments and generated texts for debugging.",
+            default=None)
+
+def asr_factory(args, factory=None):
+    """
+    Creates and configures the ASR and online processor objects through the factory that is implemented in the backend.
+    """
+    # if backend is None:
+    #     backend = args.backend
+    # if backend == "simul-whisper":
+    #     from simul_whisper_backend import simul_asr_factory
+    asr, online = factory(args)
+
+    # Create the OnlineASRProcessor
+    if args.vac:
+        from whisper_streaming.vac_online_processor import VACOnlineASRProcessor
+        online = VACOnlineASRProcessor(args.min_chunk_size, online)
+
+    if args.task == "translate":
+        if args.model_path.endswith(".en.pt"):
+            logger.error(f"The model {args.model_path} is English only. Translation is not available. Terminating.")
+            sys.exit(1)
+        asr.set_translate_task()
+
+    return asr, online
+
+def set_logging(args,logger):
+    logging.basicConfig(
+        # this format would include module name:
+        # format='%(levelname)s\t%(name)s\t%(message)s')
+        format='%(levelname)s\t%(message)s')
+    logger.setLevel(args.log_level)
+    logging.getLogger("simul_whisper").setLevel(args.log_level)
+    logging.getLogger("whisper_streaming").setLevel(args.log_level)
+
+
+def simulation_args(parser):
+    simulation_group = parser.add_argument_group("Arguments for simulation from file")
+    simulation_group.add_argument('audio_path', type=str, help="Filename of 16kHz mono channel wav, on which live streaming is simulated.")
+    simulation_group.add_argument('--start_at', type=float, default=0.0, help='Start processing audio at this time.')
+    # TODO: offline mode is not implemented in SimulStreaming yet
+    # simulation_group.add_argument('--offline', action="store_true", default=False, help='Offline mode.')
+    simulation_group.add_argument('--comp_unaware', action="store_true", default=False, help='Computationally unaware simulation.')
+
+def main_simulation_from_file(factory, add_args=None):
+    '''
+    factory: function that creates the ASR and online processor objects from args.
+        It is implemented in the backend (it could also be in the default WhisperStreaming local agreement backends, but that is not implemented).
+    add_args: function that adds backend-specific args to the parser
+    '''
+
+    import argparse
+    parser = argparse.ArgumentParser()
+
+    processor_args(parser)
+    if add_args is not None:
+        add_args(parser)
+
+    simulation_args(parser)
+
+    args = parser.parse_args()
+    args.offline = False  # TODO: offline mode is not implemented in SimulStreaming yet
+
+    if args.offline and args.comp_unaware:
+        logger.error("Options --offline and --comp_unaware cannot be used together. Use at most one of them. Exiting.")
+        sys.exit(1)
+
+    set_logging(args,logger)
+
+    audio_path = args.audio_path
+
+    SAMPLING_RATE = 16000
+    duration = len(load_audio(audio_path))/SAMPLING_RATE
+    logger.info("Audio duration is: %2.2f seconds" % duration)
+
+    asr, online = asr_factory(args, factory)
+    if args.vac:
+        min_chunk = args.vac_chunk_size
+    else:
+        min_chunk = args.min_chunk_size
+
+    # load the audio into the LRU cache before we start the timer
+    a = load_audio_chunk(audio_path,0,1)
+
+    # warm up the ASR because the very first transcribe takes much more time than the following ones
+    asr.warmup(a)
+
+    beg = args.start_at
+    start = time.time()-beg
+
+    def output_transcript(iteration_output, now=None):
+        # output format in stdout is like:
+        # 4186.3606 0 1720 Takhle to je
+        # - the first three words are:
+        # - emission time from beginning of processing, in milliseconds
+        # - beg and end timestamp of the text segment, as estimated by the Whisper model. The timestamps are not accurate, but they're useful anyway
+        # - the next words: segment transcript
+        if now is None:
+            now = time.time() - start
+
+        if iteration_output:
+            start_ts = iteration_output['start']
+            end_ts = iteration_output['end']
+            text = iteration_output['text']
+            logger.debug(f"{now * 1000:.4f} {start_ts * 1000:.0f} {end_ts * 1000:.0f} {text}")
+            print(f"{now * 1000:.4f} {start_ts * 1000:.0f} {end_ts * 1000:.0f} {text}", flush=True)
+        else:
+            logger.debug("No text in this segment")
+
+    if args.offline:  ## offline mode processing (for testing/debugging)
+        a = load_audio(audio_path)
+        online.insert_audio_chunk(a)
+        try:
+            o = online.process_iter()
+        except AssertionError as e:
+            logger.error(f"assertion error: {repr(e)}")
+        else:
+            output_transcript(o)
+        now = None
+    elif args.comp_unaware:  # computationally unaware mode
+        end = beg + min_chunk
+        while True:
+            a = load_audio_chunk(audio_path,beg,end)
+            online.insert_audio_chunk(a)
+            try:
+                o = online.process_iter()
+            except AssertionError as e:
+                logger.error(f"assertion error: {repr(e)}")
+                pass
+            else:
+                output_transcript(o, now=end)
+
+            logger.info(f"## last processed {end:.2f}s")
+
+            if end >= duration:
+                break
+
+            beg = end
+
+            if end + min_chunk > duration:
+                end = duration
+            else:
+                end += min_chunk
+        now = duration
+
+    else:  # online = simultaneous mode
+        end = 0
+        while True:
+            now = time.time() - start
+            if now < end+min_chunk:
+                time.sleep(min_chunk+end-now)
+            end = time.time() - start
+            a = load_audio_chunk(audio_path,beg,end)
+            beg = end
+            online.insert_audio_chunk(a)
+
+            try:
+                o = online.process_iter()
+            except AssertionError as e:
+                logger.error(f"assertion error: {e}")
+                pass
+            else:
+                output_transcript(o)
+            now = time.time() - start
+            logger.info(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}")
+
+            if end >= duration:
+                break
+        now = None
+
+    o = online.finish()
+    output_transcript(o, now=now)
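
A minimal sketch (not part of the commit above) of how a backend plugs into main_simulation_from_file. The factory contract is the one used by asr_factory above; the Dummy* classes are hypothetical placeholders for a real SimulWhisper backend:

from whisper_streaming.whisper_online_main import main_simulation_from_file

class DummyASR:
    def warmup(self, audio): pass          # called once, before the timer starts
    def set_translate_task(self): pass     # called only for --task translate

class DummyOnline:
    def init(self, offset=None): pass
    def insert_audio_chunk(self, audio): pass
    def process_iter(self): return {}      # {'start': s, 'end': s, 'text': str} when there is output
    def finish(self): return {}

def factory(args):
    return DummyASR(), DummyOnline()

if __name__ == '__main__':
    # e.g.: python this_script.py audio.wav --min-chunk-size 1.2
    main_simulation_from_file(factory)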
whisper_streaming/whisper_server.py
ADDED
@@ -0,0 +1,177 @@
+#!/usr/bin/env python3
+from whisper_streaming.whisper_online_main import *
+
+import sys
+import argparse
+import os
+import logging
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+SAMPLING_RATE = 16000
+
+
+######### Server objects
+
+import whisper_streaming.line_packet as line_packet
+import socket
+
+class Connection:
+    '''Wraps a socket conn object.'''
+    PACKET_SIZE = 32000*5*60  # 5 minutes # was: 65536
+
+    def __init__(self, conn):
+        self.conn = conn
+        self.last_line = ""
+
+        self.conn.setblocking(True)
+
+    def send(self, line):
+        '''It doesn't send the same line twice, because that was problematic in online-text-flow-events.'''
+        if line == self.last_line:
+            return
+        line_packet.send_one_line(self.conn, line)
+        self.last_line = line
+
+    def receive_lines(self):
+        in_line = line_packet.receive_lines(self.conn)
+        return in_line
+
+    def non_blocking_receive_audio(self):
+        try:
+            r = self.conn.recv(self.PACKET_SIZE)
+            return r
+        except ConnectionResetError:
+            return None
+
+import io
+import soundfile
+
+# Wraps the socket and ASR object, and serves one client connection.
+# The next client should be served by a new instance of this object.
+class ServerProcessor:
+
+    def __init__(self, c, online_asr_proc, min_chunk):
+        self.connection = c
+        self.online_asr_proc = online_asr_proc
+        self.min_chunk = min_chunk
+
+        self.is_first = True
+
+    def receive_audio_chunk(self):
+        # receive all audio that is available by this time
+        # blocks operation if less than self.min_chunk seconds is available
+        # unblocks if connection is closed or a chunk is available
+        out = []
+        minlimit = self.min_chunk*SAMPLING_RATE
+        while sum(len(x) for x in out) < minlimit:
+            raw_bytes = self.connection.non_blocking_receive_audio()
+            if not raw_bytes:
+                break
+            # print("received audio:",len(raw_bytes), "bytes", raw_bytes[:10])
+            sf = soundfile.SoundFile(io.BytesIO(raw_bytes), channels=1, endian="LITTLE", samplerate=SAMPLING_RATE, subtype="PCM_16", format="RAW")
+            audio, _ = librosa.load(sf, sr=SAMPLING_RATE, dtype=np.float32)
+            out.append(audio)
+        if not out:
+            return None
+        conc = np.concatenate(out)
+        if self.is_first and len(conc) < minlimit:
+            return None
+        self.is_first = False
+        return conc
+
+    def send_result(self, iteration_output):
+        # output format (sent to the client and logged to stderr) is like:
+        # 0 1720 Takhle to je
+        # - the first two words are:
+        # - beg and end timestamp of the text segment, as estimated by the Whisper model. The timestamps are not accurate, but they're useful anyway
+        # - the next words: segment transcript
+        if iteration_output:
+            message = "%1.0f %1.0f %s" % (iteration_output['start'] * 1000, iteration_output['end'] * 1000, iteration_output['text'])
+            print(message, flush=True, file=sys.stderr)
+            self.connection.send(message)
+        else:
+            logger.debug("No text in this segment")
+
+    def process(self):
+        # handle one client connection
+        self.online_asr_proc.init()
+        while True:
+            a = self.receive_audio_chunk()
+            if a is None:
+                break
+            self.online_asr_proc.insert_audio_chunk(a)
+            o = self.online_asr_proc.process_iter()
+            try:
+                self.send_result(o)
+            except BrokenPipeError:
+                logger.info("broken pipe -- connection closed?")
+                break
+
+        # o = online.finish()  # this should be working
+        # self.send_result(o)
+
+def main_server(factory, add_args):
+    '''
+    factory: function that creates the ASR and online processor objects from args.
+        It is implemented in the backend (it could also be in the default WhisperStreaming local agreement backends, but that is not implemented).
+    add_args: function that adds backend-specific args to the parser
+    '''
+    logger = logging.getLogger(__name__)
+    parser = argparse.ArgumentParser()
+
+    # server options
+    parser.add_argument("--host", type=str, default='localhost')
+    parser.add_argument("--port", type=int, default=43007)
+    parser.add_argument("--warmup-file", type=str, dest="warmup_file",
+            help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. "
+            "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .")
+
+    # options from whisper_online
+    processor_args(parser)
+
+    add_args(parser)
+
+    args = parser.parse_args()
+
+    set_logging(args,logger)
+
+    # setting whisper object by args
+
+
+    asr, online = asr_factory(args, factory)
+    if args.vac:
+        min_chunk = args.vac_chunk_size
+    else:
+        min_chunk = args.min_chunk_size
+
+    # warm up the ASR because the very first transcribe takes more time than the others.
+    # Test results in https://github.com/ufal/whisper_streaming/pull/81
+    msg = "Whisper is not warmed up. The first chunk processing may take longer."
+    if args.warmup_file:
+        if os.path.isfile(args.warmup_file):
+            a = load_audio_chunk(args.warmup_file,0,1)
+            asr.warmup(a)
+            logger.info("Whisper is warmed up.")
+        else:
+            logger.critical("The warm-up file is not available. "+msg)
+            sys.exit(1)
+    else:
+        logger.warning(msg)
+
+    # server loop
+
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind((args.host, args.port))
+        s.listen(1)
+        logger.info('Listening on '+str((args.host, args.port)))
+        while True:
+            conn, addr = s.accept()
+            logger.info('Connected to client on {}'.format(addr))
+            connection = Connection(conn)
+            proc = ServerProcessor(connection, online, min_chunk)
+            proc.process()
+            conn.close()
+            logger.info('Connection to client closed')
+    logger.info('Connection closed, terminating.')
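
A minimal client sketch (not part of the commit above) for the protocol main_server speaks: raw 16 kHz mono 16-bit little-endian PCM in, transcript lines ("beg end text", sent via line_packet) out. Port 43007 is the default above:

import socket
import numpy as np

pcm = np.zeros(16000, dtype=np.int16)    # 1 s of silence; replace with real PCM_16 audio

with socket.create_connection(('localhost', 43007)) as sock:
    sock.sendall(pcm.tobytes())          # stream the raw bytes
    sock.shutdown(socket.SHUT_WR)        # no more audio; recv() on the server then returns b''
    while True:
        data = sock.recv(4096)           # transcript lines sent by Connection.send
        if not data:
            break
        print(data.decode('utf-8', errors='replace'), end='')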