rjzevallos committed on
Commit d860e14 · 1 Parent(s): 7b7fdec

Add simulstreaming_whisper module, update requirements, improve Dockerfile and model handling

.gitignore ADDED
@@ -0,0 +1,57 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# Virtual environments
venv/
ENV/
env/

# IDE
.vscode/
.idea/
*.swp
*.swo
*~

# Cache
.cache/
*.cache

# Models (large files)
*.pt
*.pth
*.bin

# Logs
logs/
*.log

# OS
.DS_Store
Thumbs.db

# Local development
.env
.env.local
Dockerfile CHANGED
@@ -18,9 +18,8 @@ RUN pip install --no-cache-dir -r requirements.txt
 # Copy the application code
 COPY . .
 
-# (Optional) Pre-download the Whisper model if you have a connection at build time
-# Uncomment if you want the model to be fetched during the Docker build
-# RUN python -c "import torch; torch.hub.load('openai/whisper', 'large-v3')"
+# Pre-download the model during the build
+RUN python -c "import torch; torch.hub.load('openai/whisper', 'large-v3')"
 
 EXPOSE 7860
 
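The RUN line above bakes the checkpoint into the image via torch.hub, exactly as committed. If that hub entry point is not available in a given build environment, one alternative (an assumption on my part, not part of this commit) is to pre-fetch through the loader vendored in simul_whisper/whisper/__init__.py, which caches under ~/.cache/whisper:

# hypothetical alternative pre-download step, placed after COPY . .
RUN python -c "from simul_whisper.whisper import load_model; load_model('large-v3', device='cpu')"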
requirements.txt CHANGED
@@ -7,3 +7,6 @@ numpy>=1.24.0
 torch>=2.0.0
 transformers>=4.30.0
 torchaudio>=2.0.0
+tqdm
+tiktoken
+triton>=2.0.0,<3;platform_machine=="x86_64" and sys_platform=="linux" or sys_platform=="linux2"
server_wrapper.py CHANGED
@@ -14,13 +14,26 @@ _asr = None
 _online = None
 
 
+def _get_model_path():
+    """Get the path to the Whisper model. Download if needed."""
+    import os
+    model_dir = os.path.expanduser('~/.cache/torch/hub/checkpoints')
+    model_path = os.path.join(model_dir, 'large-v3.pt')
+
+    if not os.path.exists(model_path):
+        print(f"Model not found at {model_path}. Downloading...")
+        import torch
+        torch.hub.load('openai/whisper', 'large-v3')
+
+    return model_path
+
 def _make_args():
     # Minimal args required by simul_asr_factory
     return SimpleNamespace(
         log_level='INFO',
         decoder=None,
         beams=1,
-        model_path='./large-v3.pt',
+        model_path=_get_model_path(),
         cif_ckpt_path=None,
         frame_threshold=25,
         audio_min_len=0.0,
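A quick way to see what the wrapper now resolves, assuming server_wrapper.py is importable from the app root (a sketch, not part of the commit):

from server_wrapper import _make_args

args = _make_args()          # calls _get_model_path(), downloading the checkpoint on first use
print(args.model_path)       # ~/.cache/torch/hub/checkpoints/large-v3.pt
print(args.beams, args.frame_threshold)   # 1 25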
simul_whisper/__init__.py ADDED
File without changes
simul_whisper/beam.py ADDED
@@ -0,0 +1,17 @@
from torch import Tensor

from .whisper.decoding import PyTorchInference

# extension of PyTorchInference for beam search
class BeamPyTorchInference(PyTorchInference):

    def _kv_modules(self):
        key_modules = [block.attn.key.cache_id for block in self.model.decoder.blocks]
        value_modules = [block.attn.value.cache_id for block in self.model.decoder.blocks]
        return key_modules + value_modules

    def rearrange_kv_cache(self, source_indices):
        if source_indices != list(range(len(source_indices))):
            for module_cache_id in self._kv_modules():
                self.kv_cache[module_cache_id] = self.kv_cache[module_cache_id][source_indices].detach()

    def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
        return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
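How this class is wired up is visible in simul_whisper/simul_whisper.py further down: the inference object shares the transcriber's hook-populated kv_cache and feeds a standard BeamSearchDecoder. A condensed excerpt of that wiring (the names refer to the transcriber's own attributes):

# condensed from PaddedAlignAttWhisper.__init__, beam-decoding branch
inference = BeamPyTorchInference(model, initial_token_length)
inference.kv_cache = kv_cache                  # share the cache filled by the forward hooks
token_decoder = BeamSearchDecoder(inference=inference,
                                  eot=tokenizer.eot,
                                  beam_size=cfg.beam_size)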
simul_whisper/config.py ADDED
@@ -0,0 +1,31 @@
# This code was originally in simul_whisper/transcriber/simul_whisper.py . It is adapted a lot for SimulStreaming.

from dataclasses import dataclass, field
from typing import Literal

@dataclass
class SimulWhisperConfig:
    '''Options that are common for all simul policies that could be implemented in SimulWhisper.'''
    model_path: str
    language: str = field(default="zh")
    nonspeech_prob: float = 1.0
    audio_min_len: float = 1.0
    decoder_type: Literal["greedy", "beam"] = "greedy"
    beam_size: int = 5
    task: Literal["transcribe", "translate"] = "transcribe"
    init_prompt: str = field(default=None)
    static_init_prompt: str = field(default=None)
    max_context_tokens: int = field(default=None)

    logdir: str = field(default="logdir", metadata={"help": "Directory to save audio segments and tokens for debugging purposes."})

@dataclass
class AlignAttConfig(SimulWhisperConfig):
    '''Options specific to the AlignAtt policy.'''
    eval_data_path: str = "tmp"
    segment_length: float = field(default=1.0, metadata={"help": "in second"})
    frame_threshold: int = 4
    rewind_threshold: int = 200  # in frames. Max value is 1500. Higher value turns rewinds off.
    audio_max_len: float = 30.0
    cif_ckpt_path: str = ""
    never_fire: bool = False
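A minimal configuration mirroring the defaults that server_wrapper.py passes in (a sketch; the field names come from the dataclasses in this file):

from simul_whisper.config import AlignAttConfig

cfg = AlignAttConfig(
    model_path="./large-v3.pt",   # any local Whisper checkpoint
    language="auto",              # or a fixed code such as "en"
    decoder_type="greedy",        # "beam" switches to BeamPyTorchInference
    frame_threshold=25,
    audio_min_len=0.0,
)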
simul_whisper/eow_detection.py ADDED
@@ -0,0 +1,68 @@
import torch

# code for the end-of-word detection based on the CIF model proposed in Simul-Whisper

def load_cif(cfg, n_audio_state, device):
    """cfg: AlignAttConfig, n_audio_state: int, device: torch.device"""
    cif_linear = torch.nn.Linear(n_audio_state, 1)
    if cfg.cif_ckpt_path is None or not cfg.cif_ckpt_path:
        if cfg.never_fire:
            never_fire = True
            always_fire = False
        else:
            always_fire = True
            never_fire = False
    else:
        always_fire = False
        never_fire = cfg.never_fire
        map_location = None
        if not torch.cuda.is_available():
            map_location = torch.device('cpu')
        checkpoint = torch.load(cfg.cif_ckpt_path, map_location=map_location)
        cif_linear.load_state_dict(checkpoint)
    cif_linear.to(device)
    return cif_linear, always_fire, never_fire


# from https://github.com/dqqcasia/mosst/blob/master/fairseq/models/speech_to_text/convtransformer_wav2vec_cif.py
def resize(alphas, target_lengths, threshold=0.999):
    """
    alpha in thresh=1.0 | (0.0, +0.21)
    target_lengths: if None, apply round and resize, else apply scaling
    """
    # sum
    _num = alphas.sum(-1)
    num = target_lengths.float()
    # scaling
    _alphas = alphas * (num / _num)[:, None].repeat(1, alphas.size(1))
    # rm attention value that exceeds threshold
    count = 0
    while len(torch.where(_alphas > threshold)[0]):
        count += 1
        if count > 10:
            break
        xs, ys = torch.where(_alphas > threshold)
        for x, y in zip(xs, ys):
            if _alphas[x][y] >= threshold:
                mask = _alphas[x].ne(0).float()
                mean = 0.5 * _alphas[x].sum() / mask.sum()
                _alphas[x] = _alphas[x] * 0.5 + mean * mask

    return _alphas, _num

def fire_at_boundary(chunked_encoder_feature: torch.Tensor, cif_linear):
    content_mel_len = chunked_encoder_feature.shape[1]  # B, T, D
    alphas = cif_linear(chunked_encoder_feature).squeeze(dim=2)  # B, T
    alphas = torch.sigmoid(alphas)
    decode_length = torch.round(alphas.sum(-1)).int()
    alphas, _ = resize(alphas, decode_length)
    alphas = alphas.squeeze(0)  # (T, )
    threshold = 0.999
    integrate = torch.cumsum(alphas[:-1], dim=0)  # ignore the peak value at the end of the content chunk
    exceed_count = integrate[-1] // threshold
    integrate = integrate - exceed_count * 1.0  # minus 1 every time integrate exceeds the threshold
    important_positions = (integrate >= 0).nonzero(as_tuple=True)[0]
    if important_positions.numel() == 0:
        return False
    else:
        return important_positions[0] >= content_mel_len - 2
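Shapes expected by these helpers, as a sketch (the encoder feature is (B, T, D); with an empty cif_ckpt_path, load_cif falls back to always_fire unless never_fire is set; n_audio_state=1280 assumes a large model):

import torch
from simul_whisper.config import AlignAttConfig
from simul_whisper.eow_detection import load_cif, fire_at_boundary

cfg = AlignAttConfig(model_path="large-v3.pt", cif_ckpt_path="", never_fire=False)
cif_linear, always_fire, never_fire = load_cif(cfg, n_audio_state=1280, device="cpu")
feats = torch.randn(1, 150, 1280)   # (B, T, D): roughly 3 s of encoder output
fired = always_fire or (not never_fire and fire_at_boundary(feats, cif_linear))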
simul_whisper/generation_progress.py ADDED
@@ -0,0 +1,43 @@
class Tokens:
    def __init__(self, tokens):
        self.tokens = tokens

    # def clone(self):
    #     return Tokens(self.tokens.clone())

    def __str__(self):
        return str(self.tokens.tolist())

    def __repr__(self):
        return self.__str__()

class BeamTokens(Tokens):
    def __init__(self, tokens, beam_size):
        self.tokens = tokens
        self.beam_size = beam_size

    def clone(self):
        return BeamTokens(self.tokens.clone(), self.beam_size)

    def __str__(self):
        return f"BeamTokens({self.tokens.tolist()}, beam_size={self.beam_size})"

    def __repr__(self):
        return self.__str__()

    def as_text(self, tokenizer):
        return tokenizer.decode(self.tokens)

class Logits(Tokens):
    def __init__(self, logits):
        super().__init__(logits)

    # def clone(self):
    #     return Logits(self.tokens.clone(), self.beam_size)

    def __str__(self):
        # return "abc"
        return f"Logits({self.tokens.shape})"

    def __repr__(self):
        return self.__str__()
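These wrappers record generation progress for debugging (see the generation dict and logdir_save in simul_whisper/simul_whisper.py below); their __str__ keeps the logs compact. For example:

import torch
from simul_whisper.generation_progress import BeamTokens

t = BeamTokens(torch.tensor([50258, 50259, 50359]), beam_size=5)
print(t)   # BeamTokens([50258, 50259, 50359], beam_size=5)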
simul_whisper/simul_whisper.py ADDED
@@ -0,0 +1,649 @@
1
+ # This code was originally in simul_whisper/transcriber/simul_whisper.py . It is adapted a lot for SimulStreaming.
2
+
3
+ import os
4
+ import logging
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from .whisper import load_model, DecodingOptions, tokenizer
10
+ from .config import AlignAttConfig
11
+ from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
12
+ from .whisper.timing import median_filter
13
+ from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
14
+ from .beam import BeamPyTorchInference
15
+ from .eow_detection import fire_at_boundary, load_cif
16
+ import os
17
+
18
+ from token_buffer import TokenBuffer
19
+
20
+ import numpy as np
21
+ from .generation_progress import *
22
+
23
+ DEC_PAD = 50257
24
+ logger = logging.getLogger(__name__)
25
+
26
+ import sys
27
+ import wave
28
+
29
+ # New features added to the original version of Simul-Whisper:
30
+ # - large-v3 model support
31
+ # - translation support
32
+ # - beam search
33
+ # - prompt -- static vs. non-static
34
+ # - context
35
+ class PaddedAlignAttWhisper:
36
+ def __init__(self, cfg: AlignAttConfig) -> None:
37
+ self.logdir_i = 0
38
+ self.log_segments = 0
39
+ if cfg.logdir is not None and not os.path.exists(cfg.logdir):
40
+ os.makedirs(cfg.logdir)
41
+ model_name = os.path.basename(cfg.model_path).replace(".pt", "")
42
+ model_path = os.path.dirname(os.path.abspath(cfg.model_path))
43
+ self.model = load_model(name=model_name, download_root=model_path)
44
+
45
+ logger.info(f"Model dimensions: {self.model.dims}")
46
+
47
+ self.decode_options = DecodingOptions(
48
+ language = cfg.language,
49
+ without_timestamps = True,
50
+ task=cfg.task
51
+ )
52
+ self.tokenizer_is_multilingual = not model_name.endswith(".en")
53
+ self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
54
+ self.detected_language = cfg.language if cfg.language != "auto" else None
55
+
56
+ self.max_text_len = self.model.dims.n_text_ctx
57
+ self.num_decoder_layers = len(self.model.decoder.blocks)
58
+ self.cfg = cfg
59
+
60
+ # model to detect end-of-word boundary at the end of the segment
61
+ self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg,
62
+ n_audio_state=self.model.dims.n_audio_state,
63
+ device=self.model.device)
64
+
65
+ # install hooks to access encoder-decoder attention
66
+ self.dec_attns = []
67
+ def layer_hook(module, net_input, net_output):
68
+ # net_output[1]: B*num_head*token_len*audio_len
69
+ t = F.softmax(net_output[1], dim=-1)
70
+ self.dec_attns.append(t.squeeze(0))
71
+ for b in self.model.decoder.blocks:
72
+ b.cross_attn.register_forward_hook(layer_hook)
73
+
74
+ self.kv_cache = {}
75
+ def kv_hook(module: torch.nn.Linear, _, net_output: torch.Tensor):
76
+ if module.cache_id not in self.kv_cache or net_output.shape[1] > self.max_text_len:
77
+ # save as-is, for the first token or cross attention
78
+ self.kv_cache[module.cache_id] = net_output
79
+ else:
80
+ x = self.kv_cache[module.cache_id]
81
+ self.kv_cache[module.cache_id] = torch.cat([x, net_output], dim=1).detach()
82
+ return self.kv_cache[module.cache_id]
83
+
84
+ for i,b in enumerate(self.model.decoder.blocks):
85
+ b.attn.key.register_forward_hook(kv_hook)
86
+ b.attn.value.register_forward_hook(kv_hook)
87
+ b.cross_attn.key.register_forward_hook(kv_hook)
88
+ b.cross_attn.value.register_forward_hook(kv_hook)
89
+
90
+ self.align_source = {}
91
+ self.num_align_heads = 0
92
+ for layer_rank, head_id in self.model.alignment_heads.indices().T:
93
+ layer_rank = layer_rank.item()
94
+ heads = self.align_source.get(layer_rank, [])
95
+ heads.append((self.num_align_heads, head_id.item()))
96
+ self.align_source[layer_rank] = heads
97
+ self.num_align_heads += 1
98
+
99
+
100
+ # tokens to be suppressed from decoding, to prevent hallucinations
101
+ suppress_tokens = [
102
+ self.tokenizer.transcribe,
103
+ self.tokenizer.translate,
104
+ self.tokenizer.sot,
105
+ self.tokenizer.sot_prev,
106
+ self.tokenizer.sot_lm,
107
+ # self.tokenizer.eot
108
+ self.tokenizer.no_timestamps, # added by DM
109
+ ] + list(self.tokenizer.all_language_tokens) # added by DM
110
+ if self.tokenizer.no_speech is not None:
111
+ suppress_tokens.append(self.tokenizer.no_speech)
112
+ suppress_tokens = tuple(sorted(set(suppress_tokens)))
113
+ logger.debug(f"Suppress tokens: {suppress_tokens}")
114
+ sup_tokens = SuppressTokens(suppress_tokens)
115
+ self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None)
116
+ # blank tokens are suppresed for new segments near the line 334
117
+
118
+ # it's going to be regenerated after lang id
119
+ self.segments = []
120
+ self.init_tokens()
121
+
122
+ self.last_attend_frame = -self.cfg.rewind_threshold
123
+
124
+ if self.cfg.max_context_tokens is None:
125
+ self.max_context_tokens = self.max_text_len
126
+ else:
127
+ self.max_context_tokens = self.cfg.max_context_tokens
128
+ self.init_context()
129
+
130
+ # decoder type: greedy or beam
131
+ if cfg.decoder_type == "greedy":
132
+ logger.info("Using greedy decoder")
133
+ self.token_decoder = GreedyDecoder(0.0, self.tokenizer.eot)
134
+ self.decoder_type = "greedy"
135
+
136
+ elif cfg.decoder_type == "beam":
137
+ self.decoder_type = "beam"
138
+ self.inference = BeamPyTorchInference(self.model, self.initial_token_length)
139
+ self.inference.kv_cache = self.kv_cache
140
+
141
+ self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
142
+
143
+ def create_tokenizer(self, language=None):
144
+ self.tokenizer = tokenizer.get_tokenizer(
145
+ multilingual=self.tokenizer_is_multilingual,
146
+ language=language,
147
+ num_languages=self.model.num_languages,
148
+ task=self.decode_options.task
149
+ )
150
+
151
+ def init_context(self):
152
+ kw = {'tokenizer': self.tokenizer,
153
+ 'device': self.model.device,
154
+ 'prefix_token_ids': [self.tokenizer.sot_prev]}
155
+ self.context = TokenBuffer.empty(**kw)
156
+ if self.cfg.static_init_prompt is not None:
157
+ self.context = TokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
158
+ if self.cfg.init_prompt is not None:
159
+ self.context.text += self.cfg.init_prompt
160
+
161
+ def init_tokens(self):
162
+ logger.debug(f"init tokens, {len(self.segments)}")
163
+ # init tokens (mandatory prompt)
164
+ self.initial_tokens = torch.tensor(
165
+ self.tokenizer.sot_sequence_including_notimestamps,
166
+ dtype=torch.long,
167
+ device=self.model.device).unsqueeze(0)
168
+ self.initial_token_length = self.initial_tokens.shape[1]
169
+ self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
170
+ # self.segments = []
171
+ logger.debug(f"init tokens after, {len(self.segments)}")
172
+ self.tokens = [self.initial_tokens]
173
+
174
+ def trim_context(self):
175
+ logger.info("Trimming context")
176
+ c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids)
177
+ # logger.debug(f"c= {len(self.context.as_token_ids())}, {len(self.context.prefix_token_ids)}")
178
+ logger.info(f"Context text: {self.context.as_text()}")
179
+ # logger.debug(f"Context tensor: {self.context.as_tensor()}")
180
+ l = sum(t.shape[1] for t in self.tokens) + c
181
+ # logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
182
+ if self.cfg.static_init_prompt is None:
183
+ after = 0
184
+ else:
185
+ after = len(self.cfg.static_init_prompt)
186
+ # logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
187
+ while c > self.max_context_tokens or l > self.max_text_len - 20:
188
+ t = self.context.trim_words(after=after)
189
+ l -= t
190
+ c -= t
191
+ logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
192
+ if t == 0:
193
+ break
194
+ # logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
195
+ logger.info(f"Context after trim: {self.context.text} (len: {l})")
196
+
197
+
198
+ def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor) -> torch.Tensor:
199
+ if self.cfg.decoder_type == "greedy":
200
+ logit = self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
201
+ else:
202
+ logger.debug(f"Logits shape: {tokens.shape}")
203
+ logit = self.inference.logits(tokens, audio_features)
204
+ return logit
205
+
206
+
207
+ def refresh_segment(self, complete=False):
208
+
209
+ logger.debug("Refreshing segment:")
210
+ self.init_tokens()
211
+ self.last_attend_frame = -self.cfg.rewind_threshold
212
+ self.detected_language = None
213
+ self.init_context()
214
+ logger.debug(f"Context: {self.context}")
215
+ if not complete and len(self.segments) > 2:
216
+ logger.debug("keeping last two segments because they are and it is not complete.")
217
+ self.segments = self.segments[-2:]
218
+ else:
219
+ logger.debug("removing all segments.")
220
+ self.segments = []
221
+ self.log_segments += 1
222
+
223
+
224
+ def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
225
+ if self.always_fire: return True
226
+ if self.never_fire: return False
227
+ return fire_at_boundary(chunked_encoder_feature, self.CIFLinear)
228
+
229
+
230
+ def _current_tokens(self):
231
+
232
+ toks = self.tokens
233
+ # very first infer: duplicate start of seq to beam_size
234
+ if toks[0].shape[0] == 1:
235
+ toks[0] = toks[0].repeat_interleave(self.cfg.beam_size,dim=0)
236
+
237
+ if not self.context.is_empty():
238
+ context_toks = self.context.as_tensor_beam(self.cfg.beam_size, device=self.model.device)
239
+ toks = [context_toks] + toks
240
+
241
+ # make it one tensor
242
+ if len(toks) > 1:
243
+ current_tokens = torch.cat(toks, dim=1)
244
+ else:
245
+ current_tokens = toks[0]
246
+ logger.debug("debug print current_tokens:")
247
+ self.debug_print_tokens(current_tokens)
248
+ return current_tokens
249
+
250
+
251
+ def debug_print_tokens(self, tokens):
252
+ for i in range(self.cfg.beam_size):
253
+ logger.debug(self.tokenizer.decode_with_timestamps(tokens[i].tolist()))
254
+
255
+ ### audio buffer
256
+
257
+ def segments_len(self):
258
+ segments_len = sum(s.shape[0] for s in self.segments) / 16000
259
+ return segments_len
260
+
261
+ def _apply_minseglen(self):
262
+ segments_len = self.segments_len()
263
+ # wait for long enough audio to start
264
+ if segments_len < self.cfg.audio_min_len:
265
+ logger.debug("waiting for next segment")
266
+ return False
267
+ return True
268
+
269
+ def insert_audio(self, segment=None):
270
+ if segment is not None:
271
+ self.segments.append(segment)
272
+
273
+ removed_len = 0
274
+ # len of audio is bigger than buffer_len. Going to remove the first segment
275
+ segments_len = self.segments_len()
276
+ while len(self.segments) > 1 and segments_len > self.cfg.audio_max_len:
277
+ removed_len = self.segments[0].shape[0] / 16000
278
+ segments_len -= removed_len
279
+ self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len)
280
+ self.segments = self.segments[1:]
281
+ logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}")
282
+ if len(self.tokens) > 1:
283
+ self.context.append_token_ids(self.tokens[1][0,:])
284
+ self.tokens = [self.initial_tokens] + self.tokens[2:]
285
+ return removed_len
286
+
287
+ def _clean_cache(self):
288
+ '''clean the cache that stores the attention matrices and kv_cache.
289
+ It must be called every time after generation with the model.'''
290
+ # cleaning cache
291
+ self.dec_attns = []
292
+ self.kv_cache = {}
293
+ if self.decoder_type == "beam":
294
+ self.inference.kv_cache = self.kv_cache
295
+ self.token_decoder.reset()
296
+
297
+ @torch.no_grad()
298
+ def lang_id(self, encoder_features):
299
+ """Language detection from encoder features.
300
+ This code is trimmed and copy-pasted from whisper.decoding.detect_language .
301
+ """
302
+
303
+ # forward pass using a single token, startoftranscript
304
+ n_audio = encoder_features.shape[0]
305
+ x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device) # [n_audio, 1]
306
+ logits = self.model.logits(x, encoder_features)[:, 0]
307
+
308
+ # collect detected languages; suppress all non-language tokens
309
+ mask = torch.ones(logits.shape[-1], dtype=torch.bool)
310
+ mask[list(self.tokenizer.all_language_tokens)] = False
311
+ logits[:, mask] = -np.inf
312
+ language_tokens = logits.argmax(dim=-1)
313
+ language_token_probs = logits.softmax(dim=-1).cpu()
314
+ language_probs = [
315
+ {
316
+ c: language_token_probs[i, j].item()
317
+ for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes)
318
+ }
319
+ for i in range(n_audio)
320
+ ]
321
+
322
+ single = encoder_features.ndim == 2
323
+ if single:
324
+ language_tokens = language_tokens[0]
325
+ language_probs = language_probs[0]
326
+
327
+ self._clean_cache()
328
+ return language_tokens, language_probs
329
+
330
+ ### transcription / translation
331
+
332
+ @torch.no_grad()
333
+ def infer(self, is_last=False):
334
+ new_segment = True
335
+ if len(self.segments) == 0:
336
+ logger.debug("No segments, nothing to do")
337
+ self.logdir_save([], [], {})
338
+ return [], {}
339
+ if not self._apply_minseglen():
340
+ logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
341
+ input_segments = torch.cat(self.segments, dim=0)
342
+ self.logdir_save(input_segments, [], {})
343
+ return [], {}
344
+
345
+ # input_segments is concatenation of audio, it's one array
346
+ if len(self.segments) > 1:
347
+ input_segments = torch.cat(self.segments, dim=0)
348
+ else:
349
+ input_segments = self.segments[0]
350
+
351
+
352
+
353
+ # mel + padding to 30s
354
+ mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
355
+ device=self.model.device).unsqueeze(0)
356
+ # trim to 3000
357
+ mel = pad_or_trim(mel_padded, N_FRAMES)
358
+
359
+ # the len of actual audio
360
+ content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
361
+
362
+ # encode
363
+ encoder_feature = self.model.encoder(mel)
364
+
365
+ # logger.debug(f"Encoder feature shape: {encoder_feature.shape}")
366
+ # if mel.shape[-2:] != (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
367
+ # logger.debug("mel ")
368
+ if self.cfg.language == "auto" and self.detected_language is None:
369
+ language_tokens, language_probs = self.lang_id(encoder_feature)
370
+ logger.debug(f"Language tokens: {language_tokens}, probs: {language_probs}")
371
+ top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
372
+ logger.info(f"Detected language: {top_lan} with p={p:.4f}")
373
+ #self.tokenizer.language = top_lan
374
+ #self.tokenizer.__post_init__()
375
+ self.create_tokenizer(top_lan)
376
+ self.detected_language = top_lan
377
+ self.init_tokens()
378
+ logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
379
+
380
+ self.trim_context()
381
+ current_tokens = self._current_tokens()
382
+ #
383
+ fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
384
+
385
+
386
+ ####################### Decoding loop
387
+ logger.info("Decoding loop starts\n")
388
+
389
+ sum_logprobs = torch.zeros(self.cfg.beam_size, device=mel.device)
390
+ completed = False
391
+
392
+ attn_of_alignment_heads = None
393
+ most_attended_frame = None
394
+
395
+ token_len_before_decoding = current_tokens.shape[1]
396
+
397
+ generation_progress = []
398
+ generation = {
399
+ "starting_tokens": BeamTokens(current_tokens[0,:].clone(), self.cfg.beam_size),
400
+ "token_len_before_decoding": token_len_before_decoding,
401
+ #"fire_detected": fire_detected,
402
+ "frames_len": content_mel_len,
403
+ "frames_threshold": 4 if is_last else self.cfg.frame_threshold,
404
+
405
+ # to be filled later
406
+ "logits_starting": None,
407
+
408
+ # to be filled later
409
+ "no_speech_prob": None,
410
+ "no_speech": False,
411
+
412
+ # to be filled in the loop
413
+ "progress": generation_progress,
414
+ }
415
+ while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
416
+ generation_progress_loop = []
417
+
418
+ if new_segment:
419
+ tokens_for_logits = current_tokens
420
+ else:
421
+ # only need to use the last token except in the first forward pass
422
+ tokens_for_logits = current_tokens[:,-1:]
423
+
424
+ logits = self.logits(tokens_for_logits, encoder_feature) # B, len(tokens), token dict size
425
+ if new_segment:
426
+ generation["logits_starting"] = Logits(logits[:,:,:])
427
+
428
+ if new_segment and self.tokenizer.no_speech is not None:
429
+ probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1)
430
+ no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
431
+ generation["no_speech_prob"] = no_speech_probs[0]
432
+ if no_speech_probs[0] > self.cfg.nonspeech_prob:
433
+ generation["no_speech"] = True
434
+ logger.info("no speech, stop")
435
+ break
436
+
437
+ logits = logits[:, -1, :] # logits for the last token
438
+ generation_progress_loop.append(("logits_before_suppress",Logits(logits)))
439
+
440
+ # supress blank tokens only at the beginning of the segment
441
+ if new_segment:
442
+ logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
443
+ new_segment = False
444
+ self.suppress_tokens(logits)
445
+ #generation_progress_loop.append(("logits_after_suppres",BeamLogits(logits[0,:].clone(), self.cfg.beam_size)))
446
+ generation_progress_loop.append(("logits_after_suppress",Logits(logits)))
447
+
448
+ current_tokens, completed = self.token_decoder.update(current_tokens, logits, sum_logprobs)
449
+ generation_progress_loop.append(("beam_tokens",Tokens(current_tokens[:,-1].clone())))
450
+ generation_progress_loop.append(("sum_logprobs",sum_logprobs.tolist()))
451
+ generation_progress_loop.append(("completed",completed))
452
+
453
+ logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
454
+ self.debug_print_tokens(current_tokens)
455
+
456
+
457
+ # if self.decoder_type == "beam":
458
+ # logger.debug(f"Finished sequences: {self.token_decoder.finished_sequences}")
459
+
460
+ # logprobs = F.log_softmax(logits.float(), dim=-1)
461
+ # idx = 0
462
+ # logger.debug(f"Beam search topk: {logprobs[idx].topk(self.cfg.beam_size + 1)}")
463
+ # logger.debug(f"Greedy search argmax: {logits.argmax(dim=-1)}")
464
+ # if completed:
465
+ # self.debug_print_tokens(current_tokens)
466
+
467
+ # logger.debug("decode stopped because decoder completed")
468
+
469
+ attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)]
470
+ for i, attn_mat in enumerate(self.dec_attns):
471
+ layer_rank = int(i % len(self.model.decoder.blocks))
472
+ align_heads_in_layer = self.align_source.get(layer_rank, [])
473
+ if len(align_heads_in_layer) == 0:
474
+ continue
475
+ for align_head_rank, head_id in align_heads_in_layer:
476
+ if self.cfg.beam_size == 1:
477
+ a = attn_mat[head_id, :, :]
478
+ a = a.unsqueeze(0)
479
+ else:
480
+ a = attn_mat[:, head_id, :, :]
481
+ attn_of_alignment_heads[align_head_rank].append(a)
482
+ tmp = []
483
+ for mat in attn_of_alignment_heads:
484
+ t = torch.cat(mat, dim=1)
485
+ tmp.append(t)
486
+ attn_of_alignment_heads = torch.stack(tmp, dim=1)
487
+ # logger.debug(str(attn_of_alignment_heads.shape) + " tttady")
488
+ std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False)
489
+ attn_of_alignment_heads = (attn_of_alignment_heads - mean) / std
490
+ attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7) # from whisper.timing
491
+ attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
492
+ # logger.debug(str(attn_of_alignment_heads.shape) + " po mean")
493
+ attn_of_alignment_heads = attn_of_alignment_heads[:,:, :content_mel_len]
494
+ # logger.debug(str(attn_of_alignment_heads.shape) + " pak ")
495
+
496
+ # for each beam, the most attended frame is:
497
+ most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1)
498
+ generation_progress_loop.append(("most_attended_frames",most_attended_frames.clone().tolist()))
499
+ logger.debug(str(most_attended_frames.tolist()) + " most att frames")
500
+
501
+ most_attended_frame = most_attended_frames[0].item()
502
+
503
+
504
+ generation_progress.append(dict(generation_progress_loop))
505
+ logger.debug("current tokens" + str(current_tokens.shape))
506
+ if completed:
507
+ # # stripping the last token, the eot
508
+ current_tokens = current_tokens[:, :-1]
509
+ break
510
+
511
+ # for some rare cases where the attention fails
512
+ if not is_last and self.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
513
+ # TODO: check this
514
+ if current_tokens.shape[1] > 1 and current_tokens[0, -2] >= DEC_PAD:
515
+ logger.debug("ommit rewinding from special tokens")
516
+ self.last_attend_frame = most_attended_frame
517
+ else:
518
+ logger.debug(
519
+ f"[rewind detected] current attention pos: {most_attended_frame}, "
520
+ f"last attention pos: {self.last_attend_frame}; omit this segment")
521
+ self.last_attend_frame = -self.cfg.rewind_threshold
522
+ current_tokens = torch.cat(self.tokens, dim=1) if len(self.tokens) > 0 else self.tokens[0]
523
+ break
524
+ else:
525
+ self.last_attend_frame = most_attended_frame
526
+
527
+ if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold):
528
+ logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}")
529
+ # stripping the last token, the one that is attended too close to the end
530
+ current_tokens = current_tokens[:, :-1]
531
+ break
532
+
533
+ # debug print
534
+ for i in range(self.cfg.beam_size):
535
+ logger.debug("attn: {}, current pos: {}, current token: {}({})".format(
536
+ attn_of_alignment_heads.shape if attn_of_alignment_heads is not None else None,
537
+ most_attended_frames[i],
538
+ current_tokens[i, -1].item(),
539
+ self.tokenizer.decode([current_tokens[i, -1].item()])
540
+ ))
541
+
542
+ # for k,v in generation.items():
543
+ # print(k,v,file=sys.stderr)
544
+ # for x in generation_progress:
545
+ # for y in x.items():
546
+ # print("\t\t",*y,file=sys.stderr)
547
+ # print("\t","----", file=sys.stderr)
548
+ # print("\t", "end of generation_progress_loop", file=sys.stderr)
549
+ # sys.exit(1)
550
+ ####################### End of decoding loop
551
+
552
+ logger.info("End of decoding loop")
553
+
554
+ # if attn_of_alignment_heads is not None:
555
+ # seg_len = int(segment.shape[0] / 16000 * TOKENS_PER_SECOND)
556
+
557
+ # # Lets' now consider only the top hypothesis in the beam search
558
+ # top_beam_attn_of_alignment_heads = attn_of_alignment_heads[0]
559
+
560
+ # # debug print: how is the new token attended?
561
+ # new_token_attn = top_beam_attn_of_alignment_heads[token_len_before_decoding:, -seg_len:]
562
+ # logger.debug(f"New token attention shape: {new_token_attn.shape}")
563
+ # if new_token_attn.shape[0] == 0: # it's not attended in the current audio segment
564
+ # logger.debug("no token generated")
565
+ # else: # it is, and the max attention is:
566
+ # new_token_max_attn, _ = new_token_attn.max(dim=-1)
567
+ # logger.debug(f"segment max attention: {new_token_max_attn.mean().item()/len(self.segments)}")
568
+
569
+
570
+ # let's now operate only with the top beam hypothesis
571
+ tokens_to_split = current_tokens[0, token_len_before_decoding:]
572
+ if fire_detected or is_last:
573
+ new_hypothesis = tokens_to_split.flatten().tolist()
574
+ else:
575
+ # going to truncate the tokens after the last space
576
+ split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split.tolist())
577
+ generation["result"] = {"split_words": split_words[:-1], "split_tokens": split_tokens[:-1]}
578
+ generation["result_truncated"] = {"split_words": split_words[-1:], "split_tokens": split_tokens[-1:]}
579
+
580
+ # text_to_split = self.tokenizer.decode(tokens_to_split)
581
+ # logger.debug(f"text_to_split: {text_to_split}")
582
+ # logger.debug("text at current step: {}".format(text_to_split.replace(" ", "<space>")))
583
+ # text_before_space = " ".join(text_to_split.split(" ")[:-1])
584
+ # logger.debug("before the last space: {}".format(text_before_space.replace(" ", "<space>")))
585
+ if len(split_words) > 1:
586
+ new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
587
+ else:
588
+ new_hypothesis = []
589
+
590
+
591
+ ### new hypothesis
592
+ logger.debug(f"new_hypothesis: {new_hypothesis}")
593
+ new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
594
+ device=self.model.device,
595
+ )
596
+ self.tokens.append(new_tokens)
597
+ # TODO: test if this is redundant or not
598
+ # ret = ret[ret<DEC_PAD]
599
+
600
+ logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
601
+
602
+ self._clean_cache()
603
+
604
+ self.logdir_save(input_segments, new_hypothesis, generation)
605
+ return new_hypothesis, generation
606
+
607
+ def logdir_save(self, input_segments, new_hypothesis, generation):
608
+ """The audio and result from each iteration is saved to the logdir for debugging purposes"""
609
+
610
+ # only when the logdir arg is set
611
+ if self.cfg.logdir is None:
612
+ return
613
+
614
+ self.logdir_i += 1
615
+
616
+ # every VAD segment is in a separate directory
617
+ dir = os.path.join(self.cfg.logdir, f"seg_{self.log_segments:05d}")
618
+ if not os.path.exists(dir):
619
+ os.makedirs(dir)
620
+
621
+ logger.debug(f"Saving to {dir}, iteration {self.logdir_i:05d}")
622
+
623
+ # saving wav:
624
+ wav_path = os.path.join(dir, f"iter_{self.logdir_i:05d}_audio.wav")
625
+ audio_np = np.array(input_segments)
626
+ # Ensure audio is float32 in range [-1, 1], convert to int16 for wav
627
+ if audio_np.dtype != np.int16:
628
+ audio_int16 = np.clip(audio_np * 32767, -32768, 32767).astype(np.int16)
629
+ else:
630
+ audio_int16 = audio_np
631
+
632
+ with wave.open(wav_path, "wb") as wf:
633
+ wf.setnchannels(1)
634
+ wf.setsampwidth(2) # 2 bytes for int16
635
+ wf.setframerate(16000)
636
+ wf.writeframes(audio_int16.tobytes())
637
+
638
+ # saving readable text: context + hypothesis
639
+ text = self.tokenizer.decode(new_hypothesis)
640
+ with open(os.path.join(dir, f"iter_{self.logdir_i:05d}_hypothesis.txt"), "w") as f:
641
+ if generation:
642
+ context = generation["starting_tokens"].as_text(self.tokenizer)
643
+ else:
644
+ context = ""
645
+ print("CONTEXT+FORCED:",context,sep="\t",file=f)
646
+ print("HYPOTHESIS:", text, sep="\t", file=f)
647
+
648
+ # TODO: generation progress can be also saved in a readable format
649
+ #logger.debug(f"generation progress: {generation}")
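Putting the pieces together, a minimal driving loop would look roughly like this (a sketch based on the methods above, not part of the commit; audio_chunks is a placeholder for any 16 kHz float32 source, fed in roughly one-second tensors as insert_audio/segments_len expect):

import torch
from simul_whisper.config import AlignAttConfig
from simul_whisper.simul_whisper import PaddedAlignAttWhisper

cfg = AlignAttConfig(model_path="./large-v3.pt", language="auto",
                     decoder_type="greedy", beam_size=1,
                     frame_threshold=25, logdir=None)
whisper = PaddedAlignAttWhisper(cfg)

for chunk in audio_chunks:                        # hypothetical stream of 16 kHz float32 tensors
    whisper.insert_audio(torch.as_tensor(chunk))
    tokens, generation = whisper.infer(is_last=False)
    if tokens:
        print(whisper.tokenizer.decode(tokens), end="", flush=True)

tokens, _ = whisper.infer(is_last=True)           # flush whatever hypothesis remains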
simul_whisper/whisper/__init__.py ADDED
@@ -0,0 +1,160 @@
1
+ import hashlib
2
+ import io
3
+ import os
4
+ import urllib
5
+ import warnings
6
+ from typing import List, Optional, Union
7
+
8
+ import torch
9
+ from tqdm import tqdm
10
+
11
+ from .audio import load_audio, log_mel_spectrogram, pad_or_trim
12
+ from .decoding import DecodingOptions, DecodingResult, decode, detect_language
13
+ from .model import ModelDimensions, Whisper
14
+ from .transcribe import transcribe
15
+ from .version import __version__
16
+
17
+ _MODELS = {
18
+ "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
19
+ "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
20
+ "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
21
+ "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
22
+ "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
23
+ "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
24
+ "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
25
+ "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
26
+ "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
27
+ "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
28
+ "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
29
+ "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
30
+ "large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
31
+ "turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
32
+ }
33
+
34
+ # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
35
+ # highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
36
+ _ALIGNMENT_HEADS = {
37
+ "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
38
+ "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
39
+ "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
40
+ "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
41
+ "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
42
+ "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
43
+ "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
44
+ "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
45
+ "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
46
+ "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
47
+ "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
48
+ "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
49
+ "large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
50
+ "turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
51
+ }
52
+
53
+
54
+ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
55
+ os.makedirs(root, exist_ok=True)
56
+
57
+ expected_sha256 = url.split("/")[-2]
58
+ download_target = os.path.join(root, os.path.basename(url))
59
+
60
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
61
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
62
+
63
+ if os.path.isfile(download_target):
64
+ with open(download_target, "rb") as f:
65
+ model_bytes = f.read()
66
+ if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
67
+ return model_bytes if in_memory else download_target
68
+ else:
69
+ warnings.warn(
70
+ f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
71
+ )
72
+
73
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
74
+ with tqdm(
75
+ total=int(source.info().get("Content-Length")),
76
+ ncols=80,
77
+ unit="iB",
78
+ unit_scale=True,
79
+ unit_divisor=1024,
80
+ ) as loop:
81
+ while True:
82
+ buffer = source.read(8192)
83
+ if not buffer:
84
+ break
85
+
86
+ output.write(buffer)
87
+ loop.update(len(buffer))
88
+
89
+ model_bytes = open(download_target, "rb").read()
90
+ if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
91
+ raise RuntimeError(
92
+ "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
93
+ )
94
+
95
+ return model_bytes if in_memory else download_target
96
+
97
+
98
+ def available_models() -> List[str]:
99
+ """Returns the names of available models"""
100
+ return list(_MODELS.keys())
101
+
102
+
103
+ def load_model(
104
+ name: str,
105
+ device: Optional[Union[str, torch.device]] = None,
106
+ download_root: str = None,
107
+ in_memory: bool = False,
108
+ ) -> Whisper:
109
+ """
110
+ Load a Whisper ASR model
111
+
112
+ Parameters
113
+ ----------
114
+ name : str
115
+ one of the official model names listed by `whisper.available_models()`, or
116
+ path to a model checkpoint containing the model dimensions and the model state_dict.
117
+ device : Union[str, torch.device]
118
+ the PyTorch device to put the model into
119
+ download_root: str
120
+ path to download the model files; by default, it uses "~/.cache/whisper"
121
+ in_memory: bool
122
+ whether to preload the model weights into host memory
123
+
124
+ Returns
125
+ -------
126
+ model : Whisper
127
+ The Whisper ASR model instance
128
+ """
129
+
130
+ if device is None:
131
+ device = "cuda" if torch.cuda.is_available() else "cpu"
132
+ if download_root is None:
133
+ default = os.path.join(os.path.expanduser("~"), ".cache")
134
+ download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
135
+
136
+ if name in _MODELS:
137
+ checkpoint_file = _download(_MODELS[name], download_root, in_memory)
138
+ alignment_heads = _ALIGNMENT_HEADS[name]
139
+ elif os.path.isfile(name):
140
+ checkpoint_file = open(name, "rb").read() if in_memory else name
141
+ alignment_heads = None
142
+ else:
143
+ raise RuntimeError(
144
+ f"Model {name} not found; available models = {available_models()}"
145
+ )
146
+
147
+ with (
148
+ io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
149
+ ) as fp:
150
+ checkpoint = torch.load(fp, map_location=device)
151
+ del checkpoint_file
152
+
153
+ dims = ModelDimensions(**checkpoint["dims"])
154
+ model = Whisper(dims)
155
+ model.load_state_dict(checkpoint["model_state_dict"])
156
+
157
+ if alignment_heads is not None:
158
+ model.set_alignment_heads(alignment_heads)
159
+
160
+ return model.to(device)
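This is the upstream openai-whisper loader vendored into the package; PaddedAlignAttWhisper calls it with a model name and download_root derived from cfg.model_path. Used standalone, the pattern is:

from simul_whisper.whisper import load_model, available_models

print(available_models())          # includes "large-v3" and "large-v3-turbo"
model = load_model("large-v3")     # downloads to ~/.cache/whisper unless download_root is given
print(model.dims.n_audio_state)    # 1280 for the large models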
simul_whisper/whisper/__main__.py ADDED
@@ -0,0 +1,3 @@
from .transcribe import cli

cli()
simul_whisper/whisper/assets/gpt2.tiktoken ADDED
The diff for this file is too large to render.
 
simul_whisper/whisper/assets/mel_filters.npz ADDED
Binary file (4.27 kB).
 
simul_whisper/whisper/assets/multilingual.tiktoken ADDED
The diff for this file is too large to render.
 
simul_whisper/whisper/audio.py ADDED
@@ -0,0 +1,157 @@
import os
from functools import lru_cache
from subprocess import CalledProcessError, run
from typing import Optional, Union

import numpy as np
import torch
import torch.nn.functional as F

from .utils import exact_div

# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000 frames in a mel spectrogram input

N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token


def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read as mono waveform, resampling as necessary

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """

    # This launches a subprocess to decode audio while down-mixing
    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
    # fmt: off
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr),
        "-"
    ]
    # fmt: on
    try:
        out = run(cmd, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0


def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
    """
    if torch.is_tensor(array):
        if array.shape[axis] > length:
            array = array.index_select(
                dim=axis, index=torch.arange(length, device=array.device)
            )

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
    else:
        if array.shape[axis] > length:
            array = array.take(indices=range(length), axis=axis)

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = np.pad(array, pad_widths)

    return array


@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
    """
    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
        )
    """
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"

    filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
    with np.load(filters_path, allow_pickle=False) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)


def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int = 80,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute the log-Mel spectrogram of

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz

    n_mels: int
        The number of Mel-frequency filters, only 80 and 128 are supported

    padding: int
        Number of zero samples to pad to the right

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor, shape = (n_mels, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2

    filters = mel_filters(audio.device, n_mels)
    mel_spec = filters @ magnitudes

    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec
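The same front end is used by infer() above: log-mel with 30 s of padding, then trim to N_FRAMES. As a standalone sketch ("speech.wav" is a hypothetical input; load_audio needs the ffmpeg CLI):

from simul_whisper.whisper.audio import (
    N_FRAMES, N_SAMPLES, load_audio, log_mel_spectrogram, pad_or_trim)

audio = load_audio("speech.wav")                           # float32 mono at 16 kHz
mel = log_mel_spectrogram(audio, n_mels=128, padding=N_SAMPLES)
mel = pad_or_trim(mel, N_FRAMES)                           # shape (128, 3000), ready for the encoder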
simul_whisper/whisper/decoding.py ADDED
@@ -0,0 +1,833 @@
1
+ from dataclasses import dataclass, field, replace
2
+ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import Tensor
8
+ from torch.distributions import Categorical
9
+
10
+ from .audio import CHUNK_LENGTH
11
+ from .tokenizer import Tokenizer, get_tokenizer
12
+ from .utils import compression_ratio
13
+
14
+ if TYPE_CHECKING:
15
+ from .model import Whisper
16
+
17
+
18
+ @torch.no_grad()
19
+ def detect_language(
20
+ model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
21
+ ) -> Tuple[Tensor, List[dict]]:
22
+ """
23
+ Detect the spoken language in the audio, and return them as list of strings, along with the ids
24
+ of the most probable language tokens and the probability distribution over all language tokens.
25
+ This is performed outside the main decode loop in order to not interfere with kv-caching.
26
+
27
+ Returns
28
+ -------
29
+ language_tokens : Tensor, shape = (n_audio,)
30
+ ids of the most probable language tokens, which appears after the startoftranscript token.
31
+ language_probs : List[Dict[str, float]], length = n_audio
32
+ list of dictionaries containing the probability distribution over all languages.
33
+ """
34
+ if tokenizer is None:
35
+ tokenizer = get_tokenizer(model.is_multilingual)
36
+ if (
37
+ tokenizer.language is None
38
+ or tokenizer.language_token not in tokenizer.sot_sequence
39
+ ):
40
+ raise ValueError(
41
+ "This model doesn't have language tokens so it can't perform lang id"
42
+ )
43
+
44
+ single = mel.ndim == 2
45
+ if single:
46
+ mel = mel.unsqueeze(0)
47
+
48
+ # skip encoder forward pass if already-encoded audio features were given
49
+ if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
50
+ mel = model.encoder(mel)
51
+
52
+ # forward pass using a single token, startoftranscript
53
+ n_audio = mel.shape[0]
54
+ x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
55
+ logits = model.logits(x, mel)[:, 0]
56
+
57
+ # collect detected languages; suppress all non-language tokens
58
+ mask = torch.ones(logits.shape[-1], dtype=torch.bool)
59
+ mask[list(tokenizer.all_language_tokens)] = False
60
+ logits[:, mask] = -np.inf
61
+ language_tokens = logits.argmax(dim=-1)
62
+ language_token_probs = logits.softmax(dim=-1).cpu()
63
+ language_probs = [
64
+ {
65
+ c: language_token_probs[i, j].item()
66
+ for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
67
+ }
68
+ for i in range(n_audio)
69
+ ]
70
+
71
+ if single:
72
+ language_tokens = language_tokens[0]
73
+ language_probs = language_probs[0]
74
+
75
+ return language_tokens, language_probs
76
+
77
+
78
+ @dataclass(frozen=True)
79
+ class DecodingOptions:
80
+ # whether to perform X->X "transcribe" or X->English "translate"
81
+ task: str = "transcribe"
82
+
83
+ # language that the audio is in; uses detected language if None
84
+ language: Optional[str] = None
85
+
86
+ # sampling-related options
87
+ temperature: float = 0.0
88
+ sample_len: Optional[int] = None # maximum number of tokens to sample
89
+ best_of: Optional[int] = None # number of independent sample trajectories, if t > 0
90
+ beam_size: Optional[int] = None # number of beams in beam search, if t == 0
91
+ patience: Optional[float] = None # patience in beam search (arxiv:2204.05424)
92
+
93
+ # "alpha" in Google NMT, or None for length norm, when ranking generations
94
+ # to select which to return among the beams or best-of-N samples
95
+ length_penalty: Optional[float] = None
96
+
97
+ # text or tokens to feed as the prompt or the prefix; for more info:
98
+ # https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
99
+ prompt: Optional[Union[str, List[int]]] = None # for the previous context
100
+ prefix: Optional[Union[str, List[int]]] = None # to prefix the current context
101
+
102
+ # list of tokens ids (or comma-separated token ids) to suppress
103
+ # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
104
+ suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
105
+ suppress_blank: bool = True # this will suppress blank outputs
106
+
107
+ # timestamp sampling options
108
+ without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
109
+ max_initial_timestamp: Optional[float] = 1.0
110
+
111
+ # implementation details
112
+ fp16: bool = True # use fp16 for most of the calculation
113
+
114
+ # streaming
115
+ add_sot: Optional[bool] = True
116
+
117
+
118
+ @dataclass(frozen=True)
119
+ class DecodingResult:
120
+ audio_features: Tensor
121
+ language: str
122
+ language_probs: Optional[Dict[str, float]] = None
123
+ tokens: List[int] = field(default_factory=list)
124
+ text: str = ""
125
+ avg_logprob: float = np.nan
126
+ no_speech_prob: float = np.nan
127
+ temperature: float = np.nan
128
+ compression_ratio: float = np.nan
129
+
130
+
131
+ class Inference:
132
+ def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
133
+ """Perform a forward pass on the decoder and return per-token logits"""
134
+ raise NotImplementedError
135
+
136
+ def rearrange_kv_cache(self, source_indices) -> None:
137
+ """Update the key-value cache according to the updated beams"""
138
+ raise NotImplementedError
139
+
140
+ def cleanup_caching(self) -> None:
141
+ """Clean up any resources or hooks after decoding is finished"""
142
+ pass
143
+
144
+
145
+ class PyTorchInference(Inference):
146
+ def __init__(self, model: "Whisper", initial_token_length: int):
147
+ self.model: "Whisper" = model
148
+ self.initial_token_length = initial_token_length
149
+ self.kv_cache = {}
150
+ self.hooks = []
151
+
152
+ key_modules = [block.attn.key for block in self.model.decoder.blocks]
153
+ value_modules = [block.attn.value for block in self.model.decoder.blocks]
154
+ self.kv_modules = key_modules + value_modules
155
+
156
+ def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
157
+ if not self.kv_cache:
158
+ self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
159
+
160
+ if tokens.shape[-1] > self.initial_token_length:
161
+ # only need to use the last token except in the first forward pass
162
+ tokens = tokens[:, -1:]
163
+
164
+ return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
165
+
166
+ def cleanup_caching(self):
167
+ for hook in self.hooks:
168
+ hook.remove()
169
+
170
+ self.kv_cache = {}
171
+ self.hooks = []
172
+
173
+ def rearrange_kv_cache(self, source_indices):
174
+ if source_indices != list(range(len(source_indices))):
175
+ for module in self.kv_modules:
176
+ # update the key/value cache to contain the selected sequences
177
+ self.kv_cache[module] = self.kv_cache[module][source_indices].detach()
178
+
179
+
180
+ class SequenceRanker:
181
+ def rank(
182
+ self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
183
+ ) -> List[int]:
184
+ """
185
+ Given a list of groups of samples and their cumulative log probabilities,
186
+ return the indices of the samples in each group to select as the final result
187
+ """
188
+ raise NotImplementedError
189
+
190
+
191
+ class MaximumLikelihoodRanker(SequenceRanker):
192
+ """
193
+ Select the sample with the highest log probabilities, penalized using either
194
+ a simple length normalization or Google NMT paper's length penalty
195
+ """
196
+
197
+ def __init__(self, length_penalty: Optional[float]):
198
+ self.length_penalty = length_penalty
199
+
200
+ def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
201
+ def scores(logprobs, lengths):
202
+ result = []
203
+ for logprob, length in zip(logprobs, lengths):
204
+ if self.length_penalty is None:
205
+ penalty = length
206
+ else:
207
+ # from the Google NMT paper
208
+ penalty = ((5 + length) / 6) ** self.length_penalty
209
+ result.append(logprob / penalty)
210
+ return result
211
+
212
+ # get the sequence with the highest score
213
+ lengths = [[len(t) for t in s] for s in tokens]
214
+ return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
215
+
216
+
217
+ class TokenDecoder:
218
+ def reset(self):
219
+ """Initialize any stateful variables for decoding a new sequence"""
220
+
221
+ def update(
222
+ self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
223
+ ) -> Tuple[Tensor, bool]:
224
+ """Specify how to select the next token, based on the current trace and logits
225
+
226
+ Parameters
227
+ ----------
228
+ tokens : Tensor, shape = (n_batch, current_sequence_length)
229
+ all tokens in the context so far, including the prefix and sot_sequence tokens
230
+
231
+ logits : Tensor, shape = (n_batch, vocab_size)
232
+ per-token logits of the probability distribution at the current step
233
+
234
+ sum_logprobs : Tensor, shape = (n_batch)
235
+ cumulative log probabilities for each sequence
236
+
237
+ Returns
238
+ -------
239
+ tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
240
+ the tokens, appended with the selected next token
241
+
242
+ completed : bool
243
+ True if all sequences have reached the end of text
244
+
245
+ """
246
+ raise NotImplementedError
247
+
248
+ def finalize(
249
+ self, tokens: Tensor, sum_logprobs: Tensor
250
+ ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
251
+ """Finalize search and return the final candidate sequences
252
+
253
+ Parameters
254
+ ----------
255
+ tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
256
+ all tokens in the context so far, including the prefix and sot_sequence
257
+
258
+ sum_logprobs : Tensor, shape = (n_audio, n_group)
259
+ cumulative log probabilities for each sequence
260
+
261
+ Returns
262
+ -------
263
+ tokens : Sequence[Sequence[Tensor]], length = n_audio
264
+ sequence of Tensors containing candidate token sequences, for each audio input
265
+
266
+ sum_logprobs : List[List[float]], length = n_audio
267
+ sequence of cumulative log probabilities corresponding to the above
268
+
269
+ """
270
+ raise NotImplementedError
271
+
272
+
273
+ class GreedyDecoder(TokenDecoder):
274
+ def __init__(self, temperature: float, eot: int):
275
+ self.temperature = temperature
276
+ self.eot = eot
277
+
278
+ def update(
279
+ self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
280
+ ) -> Tuple[Tensor, bool]:
281
+ if self.temperature == 0:
282
+ next_tokens = logits.argmax(dim=-1)
283
+ else:
284
+ next_tokens = Categorical(logits=logits / self.temperature).sample()
285
+
286
+ logprobs = F.log_softmax(logits.float(), dim=-1)
287
+ current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
288
+ sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
289
+
290
+ next_tokens[tokens[:, -1] == self.eot] = self.eot
291
+ tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
292
+
293
+ completed = (tokens[:, -1] == self.eot).all()
294
+ return tokens, completed
295
+
296
+ def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
297
+ # make sure each sequence has at least one EOT token at the end
298
+ tokens = F.pad(tokens, (0, 1), value=self.eot)
299
+ return tokens, sum_logprobs.tolist()
300
+
301
+
302
+ class BeamSearchDecoder(TokenDecoder):
303
+ def __init__(
304
+ self,
305
+ beam_size: int,
306
+ eot: int,
307
+ inference: Inference,
308
+ patience: Optional[float] = None,
309
+ ):
310
+ self.beam_size = beam_size
311
+ self.eot = eot
312
+ self.inference = inference
313
+ self.patience = patience or 1.0
314
+ self.max_candidates: int = round(beam_size * self.patience)
315
+ self.finished_sequences = None
316
+
317
+ assert (
318
+ self.max_candidates > 0
319
+ ), f"Invalid beam size ({beam_size}) or patience ({patience})"
320
+
321
+ def reset(self):
322
+ self.finished_sequences = None
323
+
324
+ def update(
325
+ self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
326
+ ) -> Tuple[Tensor, bool]:
327
+ if tokens.shape[0] % self.beam_size != 0:
328
+ raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
329
+
330
+ n_audio = tokens.shape[0] // self.beam_size
331
+ if self.finished_sequences is None: # for the first update
332
+ self.finished_sequences = [{} for _ in range(n_audio)]
333
+
334
+ logprobs = F.log_softmax(logits.float(), dim=-1)
335
+ next_tokens, source_indices, finished_sequences = [], [], []
336
+ for i in range(n_audio):
337
+ scores, sources, finished = {}, {}, {}
338
+
339
+ # STEP 1: calculate the cumulative log probabilities for possible candidates
340
+ for j in range(self.beam_size):
341
+ idx = i * self.beam_size + j
342
+ prefix = tokens[idx].tolist()
343
+ for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
344
+ new_logprob = (sum_logprobs[idx] + logprob).item()
345
+ sequence = tuple(prefix + [token.item()])
346
+ scores[sequence] = new_logprob
347
+ sources[sequence] = idx
348
+
349
+ # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
350
+ saved = 0
351
+ for sequence in sorted(scores, key=scores.get, reverse=True):
352
+ if sequence[-1] == self.eot:
353
+ finished[sequence] = scores[sequence]
354
+ else:
355
+ sum_logprobs[len(next_tokens)] = scores[sequence]
356
+ next_tokens.append(sequence)
357
+ source_indices.append(sources[sequence])
358
+
359
+ saved += 1
360
+ if saved == self.beam_size:
361
+ break
362
+
363
+ finished_sequences.append(finished)
364
+
365
+ tokens = torch.tensor(next_tokens, device=tokens.device)
366
+ self.inference.rearrange_kv_cache(source_indices)
367
+
368
+ # add newly finished sequences to self.finished_sequences
369
+ assert len(self.finished_sequences) == len(finished_sequences)
370
+ for previously_finished, newly_finished in zip(
371
+ self.finished_sequences, finished_sequences
372
+ ):
373
+ for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
374
+ if len(previously_finished) >= self.max_candidates:
375
+ break # the candidate list is full
376
+ previously_finished[seq] = newly_finished[seq]
377
+
378
+ # mark as completed if all audio has enough number of samples
379
+ completed = all(
380
+ len(sequences) >= self.max_candidates
381
+ for sequences in self.finished_sequences
382
+ )
383
+ return tokens, completed
384
+
385
+ def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
386
+ # collect all finished sequences, including patience, and add unfinished ones if not enough
387
+ sum_logprobs = sum_logprobs.cpu()
388
+ for i, sequences in enumerate(self.finished_sequences):
389
+ if (
390
+ len(sequences) < self.beam_size
391
+ ): # when not enough sequences are finished
392
+ for j in list(np.argsort(sum_logprobs[i]))[::-1]:
393
+ sequence = preceding_tokens[i, j].tolist() + [self.eot]
394
+ sequences[tuple(sequence)] = sum_logprobs[i][j].item()
395
+ if len(sequences) >= self.beam_size:
396
+ break
397
+
398
+ tokens: List[List[Tensor]] = [
399
+ [torch.tensor(seq) for seq in sequences.keys()]
400
+ for sequences in self.finished_sequences
401
+ ]
402
+ sum_logprobs: List[List[float]] = [
403
+ list(sequences.values()) for sequences in self.finished_sequences
404
+ ]
405
+ return tokens, sum_logprobs
406
+
407
+
408
+ class LogitFilter:
409
+ def apply(self, logits: Tensor, tokens: Tensor) -> None:
410
+ """Apply any filtering or masking to logits in-place
411
+
412
+ Parameters
413
+ ----------
414
+ logits : Tensor, shape = (n_batch, vocab_size)
415
+ per-token logits of the probability distribution at the current step
416
+
417
+ tokens : Tensor, shape = (n_batch, current_sequence_length)
418
+ all tokens in the context so far, including the prefix and sot_sequence tokens
419
+
420
+ """
421
+ raise NotImplementedError
422
+
423
+
424
+ class SuppressBlank(LogitFilter):
425
+ def __init__(self, tokenizer: Tokenizer, sample_begin: int):
426
+ self.tokenizer = tokenizer
427
+ self.sample_begin = sample_begin
428
+
429
+ def apply(self, logits: Tensor, tokens: Tensor):
430
+ if tokens.shape[1] == self.sample_begin:
431
+ logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
432
+
433
+
434
+ class SuppressTokens(LogitFilter):
435
+ def __init__(self, suppress_tokens: Sequence[int]):
436
+ self.suppress_tokens = list(suppress_tokens)
437
+
438
+ def apply(self, logits: Tensor, tokens: Tensor):
439
+ logits[:, self.suppress_tokens] = -np.inf
440
+
441
+
442
+ class ApplyTimestampRules(LogitFilter):
443
+ def __init__(
444
+ self,
445
+ tokenizer: Tokenizer,
446
+ sample_begin: int,
447
+ max_initial_timestamp_index: Optional[int],
448
+ ):
449
+ self.tokenizer = tokenizer
450
+ self.sample_begin = sample_begin
451
+ self.max_initial_timestamp_index = max_initial_timestamp_index
452
+
453
+ def apply(self, logits: Tensor, tokens: Tensor):
454
+ # suppress <|notimestamps|> which is handled by without_timestamps
455
+ if self.tokenizer.no_timestamps is not None:
456
+ logits[:, self.tokenizer.no_timestamps] = -np.inf
457
+
458
+ # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
459
+ for k in range(tokens.shape[0]):
460
+ sampled_tokens = tokens[k, self.sample_begin :]
461
+ seq = [t for t in sampled_tokens.tolist()]
462
+ last_was_timestamp = (
463
+ len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
464
+ )
465
+ penultimate_was_timestamp = (
466
+ len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
467
+ )
468
+
469
+ if last_was_timestamp:
470
+ if penultimate_was_timestamp: # has to be non-timestamp
471
+ logits[k, self.tokenizer.timestamp_begin :] = -np.inf
472
+ else: # cannot be normal text tokens
473
+ logits[k, : self.tokenizer.eot] = -np.inf
474
+
475
+ timestamps = sampled_tokens[
476
+ sampled_tokens.ge(self.tokenizer.timestamp_begin)
477
+ ]
478
+ if timestamps.numel() > 0:
479
+ # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
480
+ # also force each segment to have a nonzero length, to prevent infinite looping
481
+ if last_was_timestamp and not penultimate_was_timestamp:
482
+ timestamp_last = timestamps[-1]
483
+ else:
484
+ timestamp_last = timestamps[-1] + 1
485
+ logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf
486
+
487
+ if tokens.shape[1] == self.sample_begin:
488
+ # suppress generating non-timestamp tokens at the beginning
489
+ logits[:, : self.tokenizer.timestamp_begin] = -np.inf
490
+
491
+ # apply the `max_initial_timestamp` option
492
+ if self.max_initial_timestamp_index is not None:
493
+ last_allowed = (
494
+ self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
495
+ )
496
+ logits[:, last_allowed + 1 :] = -np.inf
497
+
498
+ # if sum of probability over timestamps is above any other token, sample timestamp
499
+ logprobs = F.log_softmax(logits.float(), dim=-1)
500
+ for k in range(tokens.shape[0]):
501
+ timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
502
+ dim=-1
503
+ )
504
+ max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
505
+ if timestamp_logprob > max_text_token_logprob:
506
+ logits[k, : self.tokenizer.timestamp_begin] = -np.inf
507
+
508
+
509
+ class DecodingTask:
510
+ inference: Inference
511
+ sequence_ranker: SequenceRanker
512
+ decoder: TokenDecoder
513
+ logit_filters: List[LogitFilter]
514
+
515
+ def __init__(self, model: "Whisper", options: DecodingOptions):
516
+ self.options: DecodingOptions = self._verify_options(options)
517
+ if self.options.fp16:
518
+ self.model = model.half()
519
+ else:
520
+ self.model = model
521
+
522
+ language = options.language or "en"
523
+ tokenizer = get_tokenizer(
524
+ model.is_multilingual, language=language, task=options.task
525
+ )
526
+ self.tokenizer: Tokenizer = tokenizer
527
+
528
+ # print(self.options)
529
+
530
+ self.n_group: int = options.beam_size or options.best_of or 1
531
+ self.n_ctx: int = model.dims.n_text_ctx
532
+ self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
533
+
534
+ self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
535
+ if self.options.without_timestamps:
536
+ self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
537
+
538
+ self.initial_tokens: Tuple[int] = self._get_initial_tokens()
539
+ self.sample_begin: int = len(self.initial_tokens)
540
+ self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
541
+
542
+ # inference: implements the forward pass through the decoder, including kv caching
543
+ self.inference = PyTorchInference(model, len(self.initial_tokens))
544
+
545
+ # sequence ranker: implements how to rank a group of sampled sequences
546
+ self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
547
+
548
+ # decoder: implements how to select the next tokens, given the autoregressive distribution
549
+ if options.beam_size is not None:
550
+ self.decoder = BeamSearchDecoder(
551
+ options.beam_size, tokenizer.eot, self.inference, options.patience
552
+ )
553
+ else:
554
+ self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
555
+
556
+ # logit filters: applies various rules to suppress or penalize certain tokens
557
+ self.logit_filters = []
558
+ if self.options.suppress_blank:
559
+ self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
560
+ if self.options.suppress_tokens:
561
+ self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
562
+ if not options.without_timestamps:
563
+ precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
564
+ max_initial_timestamp_index = None
565
+ if options.max_initial_timestamp:
566
+ max_initial_timestamp_index = round(
567
+ self.options.max_initial_timestamp / precision
568
+ )
569
+ self.logit_filters.append(
570
+ ApplyTimestampRules(
571
+ tokenizer, self.sample_begin, max_initial_timestamp_index
572
+ )
573
+ )
574
+
575
+ def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
576
+ if options.beam_size is not None and options.best_of is not None:
577
+ raise ValueError("beam_size and best_of can't be given together")
578
+ if options.temperature == 0:
579
+ if options.best_of is not None:
580
+ raise ValueError("best_of with greedy sampling (T=0) is not compatible")
581
+ if options.patience is not None and options.beam_size is None:
582
+ raise ValueError("patience requires beam_size to be given")
583
+ if options.length_penalty is not None and not (
584
+ 0 <= options.length_penalty <= 1
585
+ ):
586
+ raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
587
+
588
+ return options
589
+
590
+ def _get_initial_tokens(self) -> Tuple[int]:
591
+ tokens = list(self.sot_sequence)
592
+ # print("prefix", prefix)
593
+ if prefix := self.options.prefix:
594
+ prefix_tokens = (
595
+ self.tokenizer.encode(" " + prefix.strip())
596
+ if isinstance(prefix, str)
597
+ else prefix
598
+ )
599
+ if self.sample_len is not None:
600
+ max_prefix_len = self.n_ctx // 2 - self.sample_len
601
+ prefix_tokens = prefix_tokens[-max_prefix_len:]
602
+ tokens = tokens + prefix_tokens
603
+
604
+ if prompt := self.options.prompt:
605
+ prompt_tokens = (
606
+ self.tokenizer.encode(" " + prompt.strip())
607
+ if isinstance(prompt, str)
608
+ else prompt
609
+ )
610
+ # if self.options.add_sot:
611
+ tokens = (
612
+ [self.tokenizer.sot_prev]
613
+ + prompt_tokens[-(self.n_ctx // 2 - 1) :]
614
+ + tokens
615
+ )
616
+ #else:
617
+ # tokens = ([self.tokenizer.sot_prev] + tokens + prompt_tokens[-(self.n_ctx // 2 - 1) :])
618
+ # print("return", tokens)
619
+ return tuple(tokens)
620
+
621
+ def _get_suppress_tokens(self) -> Tuple[int]:
622
+ suppress_tokens = self.options.suppress_tokens
623
+
624
+ if isinstance(suppress_tokens, str):
625
+ suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
626
+
627
+ if -1 in suppress_tokens:
628
+ suppress_tokens = [t for t in suppress_tokens if t >= 0]
629
+ suppress_tokens.extend(self.tokenizer.non_speech_tokens)
630
+ elif suppress_tokens is None or len(suppress_tokens) == 0:
631
+ suppress_tokens = [] # interpret empty string as an empty list
632
+ else:
633
+ assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
634
+
635
+ suppress_tokens.extend(
636
+ [
637
+ self.tokenizer.transcribe,
638
+ self.tokenizer.translate,
639
+ self.tokenizer.sot,
640
+ self.tokenizer.sot_prev,
641
+ self.tokenizer.sot_lm,
642
+ ]
643
+ )
644
+ if self.tokenizer.no_speech is not None:
645
+ # no-speech probability is collected separately
646
+ suppress_tokens.append(self.tokenizer.no_speech)
647
+
648
+ return tuple(sorted(set(suppress_tokens)))
649
+
650
+ def _get_audio_features(self, mel: Tensor):
651
+ if self.options.fp16:
652
+ mel = mel.half()
653
+
654
+ if mel.shape[-2:] == (
655
+ self.model.dims.n_audio_ctx,
656
+ self.model.dims.n_audio_state,
657
+ ):
658
+ # encoded audio features are given; skip audio encoding
659
+ audio_features = mel
660
+ else:
661
+ audio_features = self.model.encoder(mel)
662
+
663
+ if audio_features.dtype != (
664
+ torch.float16 if self.options.fp16 else torch.float32
665
+ ):
666
+ raise TypeError(
667
+ f"audio_features has an incorrect dtype: {audio_features.dtype}"
668
+ )
669
+
670
+ return audio_features
671
+
672
+ def _detect_language(self, audio_features: Tensor, tokens: Tensor):
673
+ languages = [self.options.language] * audio_features.shape[0]
674
+ lang_probs = None
675
+
676
+ if self.options.language is None or self.options.task == "lang_id":
677
+ lang_tokens, lang_probs = self.model.detect_language(
678
+ audio_features, self.tokenizer
679
+ )
680
+ languages = [max(probs, key=probs.get) for probs in lang_probs]
681
+ if self.options.language is None:
682
+ tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
683
+
684
+ return languages, lang_probs
685
+
686
+ def _main_loop(self, audio_features: Tensor, tokens: Tensor):
687
+ n_batch = tokens.shape[0]
688
+ sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
689
+ no_speech_probs = [np.nan] * n_batch
690
+
691
+ try:
692
+ for i in range(self.sample_len): # loop at most 448 times
693
+ # print("in decode main loop", i , tokens[0].tolist())
694
+ logits = self.inference.logits(tokens, audio_features)
695
+ # print(logits)
696
+ if (
697
+ i == 0 and self.tokenizer.no_speech is not None
698
+ ): # save no_speech_probs
699
+ probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
700
+ no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
701
+
702
+ # now we need to consider the logits at the last token only
703
+ logits = logits[:, -1]
704
+
705
+ # apply the logit filters, e.g. for suppressing or applying penalty to
706
+ for logit_filter in self.logit_filters:
707
+ logit_filter.apply(logits, tokens)
708
+
709
+ # expand the tokens tensor with the selected next tokens
710
+ tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
711
+
712
+ if completed or tokens.shape[-1] > self.n_ctx:
713
+ break
714
+ finally:
715
+ self.inference.cleanup_caching()
716
+
717
+ return tokens, sum_logprobs, no_speech_probs
718
+
719
+ @torch.no_grad()
720
+ def run(self, mel: Tensor) -> List[DecodingResult]:
721
+ self.decoder.reset()
722
+ tokenizer: Tokenizer = self.tokenizer
723
+ n_audio: int = mel.shape[0]
724
+
725
+ audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
726
+ tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
727
+ # print("initial_tokens", self.initial_tokens)
728
+ # detect language if requested, overwriting the language token
729
+ languages, language_probs = self._detect_language(audio_features, tokens)
730
+ if self.options.task == "lang_id":
731
+ return [
732
+ DecodingResult(
733
+ audio_features=features, language=language, language_probs=probs
734
+ )
735
+ for features, language, probs in zip(
736
+ audio_features, languages, language_probs
737
+ )
738
+ ]
739
+
740
+ # repeat text tensors by the group size, for beam search or best-of-n sampling
741
+ tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
742
+
743
+ # call the main sampling loop
744
+ tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
745
+
746
+ # reshape the tensors to have (n_audio, n_group) as the first two dimensions
747
+ audio_features = audio_features[:: self.n_group]
748
+ no_speech_probs = no_speech_probs[:: self.n_group]
749
+ assert audio_features.shape[0] == len(no_speech_probs) == n_audio
750
+
751
+ tokens = tokens.reshape(n_audio, self.n_group, -1)
752
+ sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
753
+
754
+ # get the final candidates for each group, and slice between the first sampled token and EOT
755
+ tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
756
+ tokens: List[List[Tensor]] = [
757
+ [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
758
+ for s in tokens
759
+ ]
760
+
761
+ # select the top-ranked sample in each group
762
+ selected = self.sequence_ranker.rank(tokens, sum_logprobs)
763
+ tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
764
+ texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
765
+
766
+ sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
767
+ avg_logprobs: List[float] = [
768
+ lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
769
+ ]
770
+
771
+ fields = (
772
+ texts,
773
+ languages,
774
+ tokens,
775
+ audio_features,
776
+ avg_logprobs,
777
+ no_speech_probs,
778
+ )
779
+ if len(set(map(len, fields))) != 1:
780
+ raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
781
+
782
+ return [
783
+ DecodingResult(
784
+ audio_features=features,
785
+ language=language,
786
+ tokens=tokens,
787
+ text=text,
788
+ avg_logprob=avg_logprob,
789
+ no_speech_prob=no_speech_prob,
790
+ temperature=self.options.temperature,
791
+ compression_ratio=compression_ratio(text),
792
+ )
793
+ for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
794
+ *fields
795
+ )
796
+ ]
797
+
798
+
799
+ @torch.no_grad()
800
+ def decode(
801
+ model: "Whisper",
802
+ mel: Tensor,
803
+ options: DecodingOptions = DecodingOptions(),
804
+ **kwargs,
805
+ ) -> Union[DecodingResult, List[DecodingResult]]:
806
+ """
807
+ Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
808
+
809
+ Parameters
810
+ ----------
811
+ model: Whisper
812
+ the Whisper model instance
813
+
814
+ mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
815
+ A tensor containing the Mel spectrogram(s)
816
+
817
+ options: DecodingOptions
818
+ A dataclass that contains all necessary options for decoding 30-second segments
819
+
820
+ Returns
821
+ -------
822
+ result: Union[DecodingResult, List[DecodingResult]]
823
+ The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
824
+ """
825
+ if single := mel.ndim == 2:
826
+ mel = mel.unsqueeze(0)
827
+
828
+ if kwargs:
829
+ options = replace(options, **kwargs)
830
+
831
+ result = DecodingTask(model, options).run(mel)
832
+
833
+ return result[0] if single else result
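
Not part of the commit: a minimal sketch of how the decoding API above is intended to be called on a single 30-second segment. `load_model`, `log_mel_spectrogram` and `pad_or_trim` are assumed to match upstream openai/whisper and are not defined in this diff.

import torch
from simul_whisper.whisper import load_model                               # assumed helper
from simul_whisper.whisper.audio import log_mel_spectrogram, pad_or_trim   # assumed helpers
from simul_whisper.whisper.decoding import DecodingOptions, decode

model = load_model("large-v3")
audio = torch.zeros(16000 * 30)                                  # placeholder: 30 s of silence
mel = log_mel_spectrogram(pad_or_trim(audio), n_mels=model.dims.n_mels).to(model.device)

options = DecodingOptions(task="transcribe", language="en", beam_size=5, fp16=False)  # fp16=False for CPU
result = decode(model, mel, options)   # a single DecodingResult, since mel.ndim == 2
print(result.text, result.avg_logprob)

Equivalently, `model.decode(mel, options)` works, since `decode` is attached to the `Whisper` class in model.py below.
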
simul_whisper/whisper/model.py ADDED
@@ -0,0 +1,382 @@
1
+ import base64
2
+ import gzip
3
+ from contextlib import contextmanager
4
+ from dataclasses import dataclass
5
+ from typing import Dict, Iterable, Optional, Tuple
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import Tensor, nn
11
+
12
+ from .decoding import decode as decode_function
13
+ from .decoding import detect_language as detect_language_function
14
+ from .transcribe import transcribe as transcribe_function
15
+
16
+
17
+ try:
18
+ from torch.nn.functional import scaled_dot_product_attention
19
+
20
+ SDPA_AVAILABLE = True
21
+ except (ImportError, RuntimeError, OSError):
22
+ scaled_dot_product_attention = None
23
+ SDPA_AVAILABLE = False
24
+
25
+
26
+ @dataclass
27
+ class ModelDimensions:
28
+ n_mels: int
29
+ n_audio_ctx: int
30
+ n_audio_state: int
31
+ n_audio_head: int
32
+ n_audio_layer: int
33
+ n_vocab: int
34
+ n_text_ctx: int
35
+ n_text_state: int
36
+ n_text_head: int
37
+ n_text_layer: int
38
+
39
+
40
+ # class LayerNorm(nn.LayerNorm):
41
+ # def forward(self, x: Tensor) -> Tensor:
42
+ # return super().forward(x.float()).type(x.dtype)
43
+
44
+ # class Linear(nn.Linear):
45
+ # def forward(self, x: Tensor) -> Tensor:
46
+ # return F.linear(
47
+ # x,
48
+ # self.weight.to(x.dtype),
49
+ # None if self.bias is None else self.bias.to(x.dtype),
50
+ # )
51
+
52
+
53
+ # class Conv1d(nn.Conv1d):
54
+ # def _conv_forward(
55
+ # self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
56
+ # ) -> Tensor:
57
+ # return super()._conv_forward(
58
+ # x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
59
+ # )
60
+
61
+
62
+ def sinusoids(length, channels, max_timescale=10000):
63
+ """Returns sinusoids for positional embedding"""
64
+ assert channels % 2 == 0
65
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
66
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
67
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
68
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
69
+
70
+ import sys ## this is mine, for debugging
71
+ class MultiHeadAttention(nn.Module):
72
+
73
+ use_sdpa = False # disabling: https://github.com/linto-ai/whisper-timestamped/issues/212
74
+
75
+ def __init__(self, n_state: int, n_head: int, cache_id: str):
76
+ super().__init__()
77
+ self.n_head = n_head
78
+ self.query = nn.Linear(n_state, n_state)
79
+ self.key = nn.Linear(n_state, n_state, bias=False)
80
+ self.key.cache_id = f"{cache_id}_key"
81
+ self.value = nn.Linear(n_state, n_state)
82
+ self.value.cache_id = f"{cache_id}_value"
83
+ self.out = nn.Linear(n_state, n_state)
84
+ self.cache_id = cache_id
85
+
86
+ def forward(
87
+ self,
88
+ x: Tensor,
89
+ xa: Optional[Tensor] = None,
90
+ mask: Optional[Tensor] = None,
91
+ kv_cache: Optional[dict] = None,
92
+ ):
93
+ #print("MultiHeadAttention forward",file=sys.stderr)
94
+ q = self.query(x)
95
+ # print(q.shape, x is None, mask is None, list(kv_cache.keys()) if kv_cache is not None else None, file=sys.stderr)
96
+ # print(mask, kv_cache, xa, file=sys.stderr)
97
+
98
+ if kv_cache is None or xa is None or self.key.cache_id not in kv_cache:
99
+ k = self.key(x if xa is None else xa)
100
+ v = self.value(x if xa is None else xa)
101
+ # print(self.key.cache_id, "cache miss") # , kv_cache is None, xa is None, self.key.cache_id not in kv_cache if kv_cache is not None else None, k.shape, x.shape)
102
+ # if kv_cache is not None:
103
+ # print(kv_cache.keys())
104
+ else:
105
+ # print(self.key.cache_id, "cache hit") #, kv_cache is None, xa is None, self.key.cache_id not in kv_cache)
106
+ # if kv_cache is not None:
107
+ # print(kv_cache.keys())
108
+ k = kv_cache[self.key.cache_id]
109
+ v = kv_cache[self.value.cache_id]
110
+ # print(self.key.cache_id, "qkv attention", q.shape, k.shape, v.shape)
111
+ wv, qk = self.qkv_attention(q, k, v, mask)
112
+ return self.out(wv), qk
113
+
114
+ # def qkv_attention(
115
+ # self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
116
+ # ):
117
+ # n_batch, n_ctx, n_state = q.shape
118
+ # scale = (n_state // self.n_head) ** -0.25
119
+ # q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
120
+ # k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
121
+ # v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
122
+
123
+ # qk = q @ k
124
+ # if mask is not None:
125
+ # qk = qk + mask[:n_ctx, :n_ctx]
126
+ # # qk = qk.float()
127
+
128
+ # w = F.softmax(qk, dim=-1) # .to(q.dtype)
129
+ # return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
130
+
131
+
132
+ def qkv_attention(
133
+ self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
134
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
135
+ n_batch, n_ctx, n_state = q.shape
136
+ scale = (n_state // self.n_head) ** -0.25
137
+ q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
138
+ k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
139
+ v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
140
+
141
+ if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
142
+ a = scaled_dot_product_attention(
143
+ q, k, v, is_causal=mask is not None and n_ctx > 1
144
+ )
145
+ out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
146
+ qk = None
147
+ else:
148
+ qk = (q * scale) @ (k * scale).transpose(-1, -2)
149
+ if mask is not None:
150
+ qk = qk + mask[:n_ctx, :n_ctx]
151
+ qk = qk.float()
152
+
153
+ w = F.softmax(qk, dim=-1).to(q.dtype)
154
+ out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
155
+ qk = qk.detach()
156
+
157
+ return out, qk
158
+
159
+
160
+ class ResidualAttentionBlock(nn.Module):
161
+ def __init__(self, n_state: int, n_head: int, cache_id: str="", cross_attention: bool = False):
162
+ super().__init__()
163
+
164
+ self.attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_self_attn")
165
+ self.attn_ln = nn.LayerNorm(n_state)
166
+
167
+ self.cross_attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention else None
168
+
169
+ self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
170
+
171
+ n_mlp = n_state * 4
172
+ self.mlp = nn.Sequential(
173
+ nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
174
+ )
175
+ self.mlp_ln = nn.LayerNorm(n_state)
176
+
177
+ def forward(
178
+ self,
179
+ x: Tensor,
180
+ xa: Optional[Tensor] = None,
181
+ mask: Optional[Tensor] = None,
182
+ kv_cache: Optional[dict] = None,
183
+ ):
184
+ # print("ResidualAttentionBlock forward",file=sys.stderr)
185
+ # print(x.shape, file=sys.stderr)
186
+ x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
187
+ if self.cross_attn:
188
+ x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
189
+ x = x + self.mlp(self.mlp_ln(x))
190
+ return x
191
+
192
+
193
+ class AudioEncoder(nn.Module):
194
+ def __init__(
195
+ self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
196
+ ):
197
+ super().__init__()
198
+ self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
199
+ self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
200
+ self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
201
+
202
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
203
+ [ResidualAttentionBlock(n_state, n_head, cache_id=f"enc_layer{i}") for i in range(n_layer)]
204
+ )
205
+ self.ln_post = nn.LayerNorm(n_state)
206
+
207
+ def forward(self, x: Tensor, return_layer_results: bool=False):
208
+ """
209
+ x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
210
+ the mel spectrogram of the audio
211
+ """
212
+
213
+ x = F.gelu(self.conv1(x))
214
+ x = F.gelu(self.conv2(x))
215
+ x = x.permute(0, 2, 1) # BDT -> BTD
216
+
217
+ # two convolution layers, 2x downsampling
218
+ # 1500 frames remain in the end
219
+
220
+ x = (x + self.positional_embedding[:x.shape[1], :]) #.to(x.dtype)
221
+
222
+ layer_results = []
223
+ i = 0
224
+ for block in self.blocks:
225
+ # print(f"encoder layer {i}")
226
+ x = block(x)
227
+ layer_results.append(x)
228
+ i += 1
229
+
230
+ x = self.ln_post(x)
231
+
232
+ if return_layer_results:
233
+ return x, layer_results
234
+ else:
235
+ return x
236
+
237
+
238
+ class TextDecoder(nn.Module):
239
+ def __init__(
240
+ self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
241
+ ):
242
+ super().__init__()
243
+
244
+ self.token_embedding = nn.Embedding(n_vocab, n_state)
245
+ self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
246
+
247
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
248
+ [
249
+ ResidualAttentionBlock(n_state, n_head, cross_attention=True, cache_id=f"dec_layer{i}")
250
+ for i in range(n_layer)
251
+ ]
252
+ )
253
+ self.ln = nn.LayerNorm(n_state)
254
+
255
+ mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
256
+ self.register_buffer("mask", mask, persistent=False)
257
+
258
+ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
259
+ """
260
+ x : torch.LongTensor, shape = (batch_size, <= n_ctx)
261
+ the text tokens
262
+ xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
263
+ the encoded audio features to be attended on
264
+ """
265
+
266
+ offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
267
+ x = (
268
+ self.token_embedding(x)
269
+ + self.positional_embedding[offset : offset + x.shape[-1]]
270
+ )
271
+ # x = x.to(xa.dtype)
272
+
273
+ i = 0
274
+ for block in self.blocks:
275
+ # print(f"decoder layer {i}")
276
+ x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
277
+ i += 1
278
+
279
+ x = self.ln(x)
280
+ logits = x @ torch.transpose(self.token_embedding.weight, 0, 1)
281
+
282
+ return logits
283
+
284
+
285
+ class Whisper(nn.Module):
286
+ def __init__(self, dims: ModelDimensions):
287
+ super().__init__()
288
+ self.dims = dims
289
+ self.encoder = AudioEncoder(
290
+ self.dims.n_mels,
291
+ self.dims.n_audio_ctx,
292
+ self.dims.n_audio_state,
293
+ self.dims.n_audio_head,
294
+ self.dims.n_audio_layer,
295
+ )
296
+ self.decoder = TextDecoder(
297
+ self.dims.n_vocab,
298
+ self.dims.n_text_ctx,
299
+ self.dims.n_text_state,
300
+ self.dims.n_text_head,
301
+ self.dims.n_text_layer,
302
+ )
303
+ # use the last half layers for alignment by default; see `set_alignment_heads()` below
304
+ all_heads = torch.zeros(
305
+ self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
306
+ )
307
+ all_heads[self.dims.n_text_layer // 2 :] = True
308
+ self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
309
+
310
+ def set_alignment_heads(self, dump: bytes):
311
+ array = np.frombuffer(
312
+ gzip.decompress(base64.b85decode(dump)), dtype=bool
313
+ ).copy()
314
+ mask = torch.from_numpy(array).reshape(
315
+ self.dims.n_text_layer, self.dims.n_text_head
316
+ )
317
+ self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
318
+
319
+ def embed_audio(self, mel: torch.Tensor):
320
+ return self.encoder(mel)
321
+
322
+ def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
323
+ # tokens = tokens.to(self.decoder.ln.weight.dtype)
324
+ # audio_features = audio_features.to(self.decoder.ln.weight.dtype)
325
+ return self.decoder(tokens, audio_features)
326
+
327
+ def forward(
328
+ self, mel: torch.Tensor, tokens: torch.Tensor
329
+ ) -> Dict[str, torch.Tensor]:
330
+ # mel = mel.to(self.decoder.ln.weight.dtype)
331
+ # tokens = tokens.to(self.decoder.ln.weight.dtype)
332
+ return self.decoder(tokens, self.encoder(mel))
333
+
334
+ @property
335
+ def device(self):
336
+ return next(self.parameters()).device
337
+
338
+ @property
339
+ def is_multilingual(self):
340
+ return self.dims.n_vocab >= 51865
341
+
342
+ @property
343
+ def num_languages(self):
344
+ return self.dims.n_vocab - 51765 - int(self.is_multilingual)
345
+
346
+ # add a caching mechanism to the decoder: save k and v from the previous inference step so they don't need to be recomputed next time
347
+ def install_kv_cache_hooks(self, cache: Optional[dict] = None):
348
+ """
349
+ The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
350
+ tensors calculated for the previous positions. This method returns a dictionary that stores
351
+ all caches, and the necessary hooks for the key and value projection modules that save the
352
+ intermediate tensors to be reused during later calculations.
353
+
354
+ Returns
355
+ -------
356
+ cache : Dict[nn.Module, torch.Tensor]
357
+ A dictionary object mapping the key/value projection modules to its cache
358
+ hooks : List[RemovableHandle]
359
+ List of PyTorch RemovableHandle objects to stop the hooks to be called
360
+ """
361
+ cache = {**cache} if cache is not None else {}
362
+ hooks = []
363
+
364
+ def save_to_cache(module, _, output):
365
+ if module not in cache or output.shape[1] > self.dims.n_text_ctx:
366
+ # save as-is, for the first token or cross attention
367
+ cache[module] = output
368
+ else:
369
+ cache[module] = torch.cat([cache[module], output], dim=1).detach()
370
+ return cache[module]
371
+
372
+ def install_hooks(layer: nn.Module):
373
+ if isinstance(layer, MultiHeadAttention):
374
+ hooks.append(layer.key.register_forward_hook(save_to_cache))
375
+ hooks.append(layer.value.register_forward_hook(save_to_cache))
376
+
377
+ self.decoder.apply(install_hooks)
378
+ return cache, hooks
379
+
380
+ detect_language = detect_language_function
381
+ transcribe = transcribe_function
382
+ decode = decode_function
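
For context (not part of the commit): the kv-cache hooks above are what make incremental decoding cheap. The decoder's key/value projections are captured by forward hooks and concatenated across steps, so after the first pass only the newest token has to be fed. A minimal greedy loop using only the pieces defined in this file might look like the sketch below; `get_tokenizer` is assumed to take the same arguments as in decoding.py above, and `mel` is assumed to already live on `model.device`.

import torch
from simul_whisper.whisper.tokenizer import get_tokenizer

@torch.no_grad()
def greedy_decode(model, mel, max_tokens=64):
    # mel: (1, n_mels, 3000) log-mel spectrogram; model: the Whisper module defined above
    tokenizer = get_tokenizer(model.is_multilingual)
    audio_features = model.embed_audio(mel)
    cache, hooks = model.install_kv_cache_hooks()
    tokens = torch.tensor([list(tokenizer.sot_sequence)], device=model.device)
    try:
        for _ in range(max_tokens):
            # after the first pass the hooks hold the history, so only the last token is fed
            inp = tokens if not cache else tokens[:, -1:]
            logits = model.decoder(inp, audio_features, kv_cache=cache)
            next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=-1)
            if next_token.item() == tokenizer.eot:
                break
    finally:
        for hook in hooks:
            hook.remove()
    # drop special tokens (sot sequence, eot) before decoding to text
    text_tokens = [t for t in tokens[0].tolist() if t < tokenizer.eot]
    return tokenizer.decode(text_tokens)
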
simul_whisper/whisper/normalizers/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .basic import BasicTextNormalizer as BasicTextNormalizer
2
+ from .english import EnglishTextNormalizer as EnglishTextNormalizer
simul_whisper/whisper/normalizers/basic.py ADDED
@@ -0,0 +1,76 @@
1
+ import re
2
+ import unicodedata
3
+
4
+ import regex
5
+
6
+ # non-ASCII letters that are not separated by "NFKD" normalization
7
+ ADDITIONAL_DIACRITICS = {
8
+ "œ": "oe",
9
+ "Œ": "OE",
10
+ "ø": "o",
11
+ "Ø": "O",
12
+ "æ": "ae",
13
+ "Æ": "AE",
14
+ "ß": "ss",
15
+ "ẞ": "SS",
16
+ "đ": "d",
17
+ "Đ": "D",
18
+ "ð": "d",
19
+ "Ð": "D",
20
+ "þ": "th",
21
+ "Þ": "th",
22
+ "ł": "l",
23
+ "Ł": "L",
24
+ }
25
+
26
+
27
+ def remove_symbols_and_diacritics(s: str, keep=""):
28
+ """
29
+ Replace any other markers, symbols, and punctuations with a space,
30
+ and drop any diacritics (category 'Mn' and some manual mappings)
31
+ """
32
+ return "".join(
33
+ c
34
+ if c in keep
35
+ else ADDITIONAL_DIACRITICS[c]
36
+ if c in ADDITIONAL_DIACRITICS
37
+ else ""
38
+ if unicodedata.category(c) == "Mn"
39
+ else " "
40
+ if unicodedata.category(c)[0] in "MSP"
41
+ else c
42
+ for c in unicodedata.normalize("NFKD", s)
43
+ )
44
+
45
+
46
+ def remove_symbols(s: str):
47
+ """
48
+ Replace any other markers, symbols, punctuations with a space, keeping diacritics
49
+ """
50
+ return "".join(
51
+ " " if unicodedata.category(c)[0] in "MSP" else c
52
+ for c in unicodedata.normalize("NFKC", s)
53
+ )
54
+
55
+
56
+ class BasicTextNormalizer:
57
+ def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
58
+ self.clean = (
59
+ remove_symbols_and_diacritics if remove_diacritics else remove_symbols
60
+ )
61
+ self.split_letters = split_letters
62
+
63
+ def __call__(self, s: str):
64
+ s = s.lower()
65
+ s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
66
+ s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
67
+ s = self.clean(s).lower()
68
+
69
+ if self.split_letters:
70
+ s = " ".join(regex.findall(r"\X", s, regex.U))
71
+
72
+ s = re.sub(
73
+ r"\s+", " ", s
74
+ ) # replace any successive whitespace characters with a space
75
+
76
+ return s
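
Not part of the commit: a quick illustration of the normalizer defined above, run on a made-up string (assuming the usual package `__init__.py` files so the import path resolves).

from simul_whisper.whisper.normalizers import BasicTextNormalizer

normalizer = BasicTextNormalizer(remove_diacritics=True)
print(normalizer("Hëllo [noise],   WORLD"))   # -> "hello world"
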
simul_whisper/whisper/normalizers/english.py ADDED
@@ -0,0 +1,550 @@
1
+ import json
2
+ import os
3
+ import re
4
+ from fractions import Fraction
5
+ from typing import Iterator, List, Match, Optional, Union
6
+
7
+ from more_itertools import windowed
8
+
9
+ from .basic import remove_symbols_and_diacritics
10
+
11
+
12
+ class EnglishNumberNormalizer:
13
+ """
14
+ Convert any spelled-out numbers into arabic numbers, while handling:
15
+
16
+ - remove any commas
17
+ - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc.
18
+ - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars`
19
+ - spell out `one` and `ones`
20
+ - interpret successive single-digit numbers as nominal: `one oh one` -> `101`
21
+ """
22
+
23
+ def __init__(self):
24
+ super().__init__()
25
+
26
+ self.zeros = {"o", "oh", "zero"}
27
+ self.ones = {
28
+ name: i
29
+ for i, name in enumerate(
30
+ [
31
+ "one",
32
+ "two",
33
+ "three",
34
+ "four",
35
+ "five",
36
+ "six",
37
+ "seven",
38
+ "eight",
39
+ "nine",
40
+ "ten",
41
+ "eleven",
42
+ "twelve",
43
+ "thirteen",
44
+ "fourteen",
45
+ "fifteen",
46
+ "sixteen",
47
+ "seventeen",
48
+ "eighteen",
49
+ "nineteen",
50
+ ],
51
+ start=1,
52
+ )
53
+ }
54
+ self.ones_plural = {
55
+ "sixes" if name == "six" else name + "s": (value, "s")
56
+ for name, value in self.ones.items()
57
+ }
58
+ self.ones_ordinal = {
59
+ "zeroth": (0, "th"),
60
+ "first": (1, "st"),
61
+ "second": (2, "nd"),
62
+ "third": (3, "rd"),
63
+ "fifth": (5, "th"),
64
+ "twelfth": (12, "th"),
65
+ **{
66
+ name + ("h" if name.endswith("t") else "th"): (value, "th")
67
+ for name, value in self.ones.items()
68
+ if value > 3 and value != 5 and value != 12
69
+ },
70
+ }
71
+ self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}
72
+
73
+ self.tens = {
74
+ "twenty": 20,
75
+ "thirty": 30,
76
+ "forty": 40,
77
+ "fifty": 50,
78
+ "sixty": 60,
79
+ "seventy": 70,
80
+ "eighty": 80,
81
+ "ninety": 90,
82
+ }
83
+ self.tens_plural = {
84
+ name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
85
+ }
86
+ self.tens_ordinal = {
87
+ name.replace("y", "ieth"): (value, "th")
88
+ for name, value in self.tens.items()
89
+ }
90
+ self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
91
+
92
+ self.multipliers = {
93
+ "hundred": 100,
94
+ "thousand": 1_000,
95
+ "million": 1_000_000,
96
+ "billion": 1_000_000_000,
97
+ "trillion": 1_000_000_000_000,
98
+ "quadrillion": 1_000_000_000_000_000,
99
+ "quintillion": 1_000_000_000_000_000_000,
100
+ "sextillion": 1_000_000_000_000_000_000_000,
101
+ "septillion": 1_000_000_000_000_000_000_000_000,
102
+ "octillion": 1_000_000_000_000_000_000_000_000_000,
103
+ "nonillion": 1_000_000_000_000_000_000_000_000_000_000,
104
+ "decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
105
+ }
106
+ self.multipliers_plural = {
107
+ name + "s": (value, "s") for name, value in self.multipliers.items()
108
+ }
109
+ self.multipliers_ordinal = {
110
+ name + "th": (value, "th") for name, value in self.multipliers.items()
111
+ }
112
+ self.multipliers_suffixed = {
113
+ **self.multipliers_plural,
114
+ **self.multipliers_ordinal,
115
+ }
116
+ self.decimals = {*self.ones, *self.tens, *self.zeros}
117
+
118
+ self.preceding_prefixers = {
119
+ "minus": "-",
120
+ "negative": "-",
121
+ "plus": "+",
122
+ "positive": "+",
123
+ }
124
+ self.following_prefixers = {
125
+ "pound": "£",
126
+ "pounds": "£",
127
+ "euro": "€",
128
+ "euros": "€",
129
+ "dollar": "$",
130
+ "dollars": "$",
131
+ "cent": "¢",
132
+ "cents": "¢",
133
+ }
134
+ self.prefixes = set(
135
+ list(self.preceding_prefixers.values())
136
+ + list(self.following_prefixers.values())
137
+ )
138
+ self.suffixers = {
139
+ "per": {"cent": "%"},
140
+ "percent": "%",
141
+ }
142
+ self.specials = {"and", "double", "triple", "point"}
143
+
144
+ self.words = set(
145
+ [
146
+ key
147
+ for mapping in [
148
+ self.zeros,
149
+ self.ones,
150
+ self.ones_suffixed,
151
+ self.tens,
152
+ self.tens_suffixed,
153
+ self.multipliers,
154
+ self.multipliers_suffixed,
155
+ self.preceding_prefixers,
156
+ self.following_prefixers,
157
+ self.suffixers,
158
+ self.specials,
159
+ ]
160
+ for key in mapping
161
+ ]
162
+ )
163
+ self.literal_words = {"one", "ones"}
164
+
165
+ def process_words(self, words: List[str]) -> Iterator[str]:
166
+ prefix: Optional[str] = None
167
+ value: Optional[Union[str, int]] = None
168
+ skip = False
169
+
170
+ def to_fraction(s: str):
171
+ try:
172
+ return Fraction(s)
173
+ except ValueError:
174
+ return None
175
+
176
+ def output(result: Union[str, int]):
177
+ nonlocal prefix, value
178
+ result = str(result)
179
+ if prefix is not None:
180
+ result = prefix + result
181
+ value = None
182
+ prefix = None
183
+ return result
184
+
185
+ if len(words) == 0:
186
+ return
187
+
188
+ for prev, current, next in windowed([None] + words + [None], 3):
189
+ if skip:
190
+ skip = False
191
+ continue
192
+
193
+ next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
194
+ has_prefix = current[0] in self.prefixes
195
+ current_without_prefix = current[1:] if has_prefix else current
196
+ if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
197
+ # arabic numbers (potentially with signs and fractions)
198
+ f = to_fraction(current_without_prefix)
199
+ assert f is not None
200
+ if value is not None:
201
+ if isinstance(value, str) and value.endswith("."):
202
+ # concatenate decimals / ip address components
203
+ value = str(value) + str(current)
204
+ continue
205
+ else:
206
+ yield output(value)
207
+
208
+ prefix = current[0] if has_prefix else prefix
209
+ if f.denominator == 1:
210
+ value = f.numerator # store integers as int
211
+ else:
212
+ value = current_without_prefix
213
+ elif current not in self.words:
214
+ # non-numeric words
215
+ if value is not None:
216
+ yield output(value)
217
+ yield output(current)
218
+ elif current in self.zeros:
219
+ value = str(value or "") + "0"
220
+ elif current in self.ones:
221
+ ones = self.ones[current]
222
+
223
+ if value is None:
224
+ value = ones
225
+ elif isinstance(value, str) or prev in self.ones:
226
+ if (
227
+ prev in self.tens and ones < 10
228
+ ): # replace the last zero with the digit
229
+ assert value[-1] == "0"
230
+ value = value[:-1] + str(ones)
231
+ else:
232
+ value = str(value) + str(ones)
233
+ elif ones < 10:
234
+ if value % 10 == 0:
235
+ value += ones
236
+ else:
237
+ value = str(value) + str(ones)
238
+ else: # eleven to nineteen
239
+ if value % 100 == 0:
240
+ value += ones
241
+ else:
242
+ value = str(value) + str(ones)
243
+ elif current in self.ones_suffixed:
244
+ # ordinal or cardinal; yield the number right away
245
+ ones, suffix = self.ones_suffixed[current]
246
+ if value is None:
247
+ yield output(str(ones) + suffix)
248
+ elif isinstance(value, str) or prev in self.ones:
249
+ if prev in self.tens and ones < 10:
250
+ assert value[-1] == "0"
251
+ yield output(value[:-1] + str(ones) + suffix)
252
+ else:
253
+ yield output(str(value) + str(ones) + suffix)
254
+ elif ones < 10:
255
+ if value % 10 == 0:
256
+ yield output(str(value + ones) + suffix)
257
+ else:
258
+ yield output(str(value) + str(ones) + suffix)
259
+ else: # eleven to nineteen
260
+ if value % 100 == 0:
261
+ yield output(str(value + ones) + suffix)
262
+ else:
263
+ yield output(str(value) + str(ones) + suffix)
264
+ value = None
265
+ elif current in self.tens:
266
+ tens = self.tens[current]
267
+ if value is None:
268
+ value = tens
269
+ elif isinstance(value, str):
270
+ value = str(value) + str(tens)
271
+ else:
272
+ if value % 100 == 0:
273
+ value += tens
274
+ else:
275
+ value = str(value) + str(tens)
276
+ elif current in self.tens_suffixed:
277
+ # ordinal or cardinal; yield the number right away
278
+ tens, suffix = self.tens_suffixed[current]
279
+ if value is None:
280
+ yield output(str(tens) + suffix)
281
+ elif isinstance(value, str):
282
+ yield output(str(value) + str(tens) + suffix)
283
+ else:
284
+ if value % 100 == 0:
285
+ yield output(str(value + tens) + suffix)
286
+ else:
287
+ yield output(str(value) + str(tens) + suffix)
288
+ elif current in self.multipliers:
289
+ multiplier = self.multipliers[current]
290
+ if value is None:
291
+ value = multiplier
292
+ elif isinstance(value, str) or value == 0:
293
+ f = to_fraction(value)
294
+ p = f * multiplier if f is not None else None
295
+ if f is not None and p.denominator == 1:
296
+ value = p.numerator
297
+ else:
298
+ yield output(value)
299
+ value = multiplier
300
+ else:
301
+ before = value // 1000 * 1000
302
+ residual = value % 1000
303
+ value = before + residual * multiplier
304
+ elif current in self.multipliers_suffixed:
305
+ multiplier, suffix = self.multipliers_suffixed[current]
306
+ if value is None:
307
+ yield output(str(multiplier) + suffix)
308
+ elif isinstance(value, str):
309
+ f = to_fraction(value)
310
+ p = f * multiplier if f is not None else None
311
+ if f is not None and p.denominator == 1:
312
+ yield output(str(p.numerator) + suffix)
313
+ else:
314
+ yield output(value)
315
+ yield output(str(multiplier) + suffix)
316
+ else: # int
317
+ before = value // 1000 * 1000
318
+ residual = value % 1000
319
+ value = before + residual * multiplier
320
+ yield output(str(value) + suffix)
321
+ value = None
322
+ elif current in self.preceding_prefixers:
323
+ # apply prefix (positive, minus, etc.) if it precedes a number
324
+ if value is not None:
325
+ yield output(value)
326
+
327
+ if next in self.words or next_is_numeric:
328
+ prefix = self.preceding_prefixers[current]
329
+ else:
330
+ yield output(current)
331
+ elif current in self.following_prefixers:
332
+ # apply prefix (dollars, cents, etc.) only after a number
333
+ if value is not None:
334
+ prefix = self.following_prefixers[current]
335
+ yield output(value)
336
+ else:
337
+ yield output(current)
338
+ elif current in self.suffixers:
339
+ # apply suffix symbols (percent -> '%')
340
+ if value is not None:
341
+ suffix = self.suffixers[current]
342
+ if isinstance(suffix, dict):
343
+ if next in suffix:
344
+ yield output(str(value) + suffix[next])
345
+ skip = True
346
+ else:
347
+ yield output(value)
348
+ yield output(current)
349
+ else:
350
+ yield output(str(value) + suffix)
351
+ else:
352
+ yield output(current)
353
+ elif current in self.specials:
354
+ if next not in self.words and not next_is_numeric:
355
+ # apply special handling only if the next word can be numeric
356
+ if value is not None:
357
+ yield output(value)
358
+ yield output(current)
359
+ elif current == "and":
360
+ # ignore "and" after hundreds, thousands, etc.
361
+ if prev not in self.multipliers:
362
+ if value is not None:
363
+ yield output(value)
364
+ yield output(current)
365
+ elif current == "double" or current == "triple":
366
+ if next in self.ones or next in self.zeros:
367
+ repeats = 2 if current == "double" else 3
368
+ ones = self.ones.get(next, 0)
369
+ value = str(value or "") + str(ones) * repeats
370
+ skip = True
371
+ else:
372
+ if value is not None:
373
+ yield output(value)
374
+ yield output(current)
375
+ elif current == "point":
376
+ if next in self.decimals or next_is_numeric:
377
+ value = str(value or "") + "."
378
+ else:
379
+ # should all have been covered at this point
380
+ raise ValueError(f"Unexpected token: {current}")
381
+ else:
382
+ # all should have been covered at this point
383
+ raise ValueError(f"Unexpected token: {current}")
384
+
385
+ if value is not None:
386
+ yield output(value)
387
+
388
+ def preprocess(self, s: str):
389
+ # replace "<number> and a half" with "<number> point five"
390
+ results = []
391
+
392
+ segments = re.split(r"\band\s+a\s+half\b", s)
393
+ for i, segment in enumerate(segments):
394
+ if len(segment.strip()) == 0:
395
+ continue
396
+ if i == len(segments) - 1:
397
+ results.append(segment)
398
+ else:
399
+ results.append(segment)
400
+ last_word = segment.rsplit(maxsplit=2)[-1]
401
+ if last_word in self.decimals or last_word in self.multipliers:
402
+ results.append("point five")
403
+ else:
404
+ results.append("and a half")
405
+
406
+ s = " ".join(results)
407
+
408
+ # put a space at number/letter boundary
409
+ s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
410
+ s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)
411
+
412
+ # but remove spaces which could be a suffix
413
+ s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)
414
+
415
+ return s
416
+
417
+ def postprocess(self, s: str):
418
+ def combine_cents(m: Match):
419
+ try:
420
+ currency = m.group(1)
421
+ integer = m.group(2)
422
+ cents = int(m.group(3))
423
+ return f"{currency}{integer}.{cents:02d}"
424
+ except ValueError:
425
+ return m.string
426
+
427
+ def extract_cents(m: Match):
428
+ try:
429
+ return f"¢{int(m.group(1))}"
430
+ except ValueError:
431
+ return m.string
432
+
433
+ # apply currency postprocessing; "$2 and ¢7" -> "$2.07"
434
+ s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
435
+ s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s)
436
+
437
+ # write "one(s)" instead of "1(s)", just for the readability
438
+ s = re.sub(r"\b1(s?)\b", r"one\1", s)
439
+
440
+ return s
441
+
442
+ def __call__(self, s: str):
443
+ s = self.preprocess(s)
444
+ s = " ".join(word for word in self.process_words(s.split()) if word is not None)
445
+ s = self.postprocess(s)
446
+
447
+ return s
448
+
449
+
450
+ class EnglishSpellingNormalizer:
451
+ """
452
+ Applies British-American spelling mappings as listed in [1].
453
+
454
+ [1] https://www.tysto.com/uk-us-spelling-list.html
455
+ """
456
+
457
+ def __init__(self):
458
+ mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
459
+ self.mapping = json.load(open(mapping_path))
460
+
461
+ def __call__(self, s: str):
462
+ return " ".join(self.mapping.get(word, word) for word in s.split())
463
+
464
+
465
+ class EnglishTextNormalizer:
466
+ def __init__(self):
467
+ self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
468
+ self.replacers = {
469
+ # common contractions
470
+ r"\bwon't\b": "will not",
471
+ r"\bcan't\b": "can not",
472
+ r"\blet's\b": "let us",
473
+ r"\bain't\b": "aint",
474
+ r"\by'all\b": "you all",
475
+ r"\bwanna\b": "want to",
476
+ r"\bgotta\b": "got to",
477
+ r"\bgonna\b": "going to",
478
+ r"\bi'ma\b": "i am going to",
479
+ r"\bimma\b": "i am going to",
480
+ r"\bwoulda\b": "would have",
481
+ r"\bcoulda\b": "could have",
482
+ r"\bshoulda\b": "should have",
483
+ r"\bma'am\b": "madam",
484
+ # contractions in titles/prefixes
485
+ r"\bmr\b": "mister ",
486
+ r"\bmrs\b": "missus ",
487
+ r"\bst\b": "saint ",
488
+ r"\bdr\b": "doctor ",
489
+ r"\bprof\b": "professor ",
490
+ r"\bcapt\b": "captain ",
491
+ r"\bgov\b": "governor ",
492
+ r"\bald\b": "alderman ",
493
+ r"\bgen\b": "general ",
494
+ r"\bsen\b": "senator ",
495
+ r"\brep\b": "representative ",
496
+ r"\bpres\b": "president ",
497
+ r"\brev\b": "reverend ",
498
+ r"\bhon\b": "honorable ",
499
+ r"\basst\b": "assistant ",
500
+ r"\bassoc\b": "associate ",
501
+ r"\blt\b": "lieutenant ",
502
+ r"\bcol\b": "colonel ",
503
+ r"\bjr\b": "junior ",
504
+ r"\bsr\b": "senior ",
505
+ r"\besq\b": "esquire ",
506
+ # perfect tenses; ideally this should cover any past participle, but that's harder..
507
+ r"'d been\b": " had been",
508
+ r"'s been\b": " has been",
509
+ r"'d gone\b": " had gone",
510
+ r"'s gone\b": " has gone",
511
+ r"'d done\b": " had done", # "'s done" is ambiguous
512
+ r"'s got\b": " has got",
513
+ # general contractions
514
+ r"n't\b": " not",
515
+ r"'re\b": " are",
516
+ r"'s\b": " is",
517
+ r"'d\b": " would",
518
+ r"'ll\b": " will",
519
+ r"'t\b": " not",
520
+ r"'ve\b": " have",
521
+ r"'m\b": " am",
522
+ }
523
+ self.standardize_numbers = EnglishNumberNormalizer()
524
+ self.standardize_spellings = EnglishSpellingNormalizer()
525
+
526
+ def __call__(self, s: str):
527
+ s = s.lower()
528
+
529
+ s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
530
+ s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parentheses
531
+ s = re.sub(self.ignore_patterns, "", s)
532
+ s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe
533
+
534
+ for pattern, replacement in self.replacers.items():
535
+ s = re.sub(pattern, replacement, s)
536
+
537
+ s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
538
+ s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
539
+ s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols
540
+
541
+ s = self.standardize_numbers(s)
542
+ s = self.standardize_spellings(s)
543
+
544
+ # now remove prefix/suffix symbols that are not preceded/followed by numbers
545
+ s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
546
+ s = re.sub(r"([^0-9])%", r"\1 ", s)
547
+
548
+ s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space
549
+
550
+ return s
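A minimal usage sketch of the normalizers defined above (not part of the diff; the sample strings are illustrative, and the expected effects follow the comments in the code itself):

    normalizer = EnglishTextNormalizer()
    # titles are expanded and currency is combined, e.g. "$2 and ¢7" -> "$2.07"
    print(normalizer("Mr. Smith paid $2 and ¢7."))
    # spelled-out numbers are converted to digits by EnglishNumberNormalizer
    print(normalizer("he bought one hundred and twenty three apples"))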
simul_whisper/whisper/timing.py ADDED
@@ -0,0 +1,401 @@
1
+ import itertools
2
+ import subprocess
3
+ import warnings
4
+ from dataclasses import dataclass
5
+ from typing import TYPE_CHECKING, List
6
+
7
+ import numba
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+ from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
13
+ from .tokenizer import Tokenizer
14
+
15
+ if TYPE_CHECKING:
16
+ from .model import Whisper
17
+
18
+
19
+ def median_filter(x: torch.Tensor, filter_width: int):
20
+ """Apply a median filter of width `filter_width` along the last dimension of `x`"""
21
+ pad_width = filter_width // 2
22
+ if x.shape[-1] <= pad_width:
23
+ # F.pad requires the padding width to be smaller than the input dimension
24
+ return x
25
+
26
+ if (ndim := x.ndim) <= 2:
27
+ # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
28
+ x = x[None, None, :]
29
+
30
+ assert (
31
+ filter_width > 0 and filter_width % 2 == 1
32
+ ), "`filter_width` should be an odd number"
33
+
34
+ result = None
35
+ x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
36
+ if x.is_cuda:
37
+ try:
38
+ from .triton_ops import median_filter_cuda
39
+
40
+ result = median_filter_cuda(x, filter_width)
41
+ except (RuntimeError, subprocess.CalledProcessError):
42
+ warnings.warn(
43
+ "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
44
+ "falling back to a slower median kernel implementation..."
45
+ )
46
+
47
+ if result is None:
48
+ # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
49
+ result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
50
+
51
+ if ndim <= 2:
52
+ result = result[0, 0]
53
+
54
+ return result
55
+
56
+
57
+ @numba.jit(nopython=True)
58
+ def backtrace(trace: np.ndarray):
59
+ i = trace.shape[0] - 1 # trace: (N+1, M+1), i=N
60
+ j = trace.shape[1] - 1 # j=M
61
+ # the boundary entries are effectively meaningless here?
62
+ trace[0, :] = 2
63
+ trace[:, 0] = 1
64
+
65
+ result = []
66
+ while i > 0 or j > 0:
67
+ result.append((i - 1, j - 1))
68
+
69
+ if trace[i, j] == 0:
70
+ i -= 1
71
+ j -= 1
72
+ elif trace[i, j] == 1:
73
+ i -= 1
74
+ elif trace[i, j] == 2:
75
+ j -= 1
76
+ else:
77
+ raise ValueError("Unexpected trace[i, j]")
78
+
79
+ result = np.array(result)
80
+ return result[::-1, :].T
81
+
82
+
83
+ @numba.jit(nopython=True, parallel=True)
84
+ def dtw_cpu(x: np.ndarray):
85
+ N, M = x.shape
86
+ cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf # cost[i, j]: minimum cost of a path from x[0, 0] to x[i-1, j-1]
87
+ trace = -np.ones((N + 1, M + 1), dtype=np.float32) # trace[i, j]: which predecessor gave the minimum (0: diagonal, 1: i-1, 2: j-1)
88
+
89
+ cost[0, 0] = 0
90
+ for j in range(1, M + 1):
91
+ for i in range(1, N + 1):
92
+ c0 = cost[i - 1, j - 1]
93
+ c1 = cost[i - 1, j]
94
+ c2 = cost[i, j - 1]
95
+
96
+ if c0 < c1 and c0 < c2:
97
+ c, t = c0, 0
98
+ elif c1 < c0 and c1 < c2:
99
+ c, t = c1, 1
100
+ else:
101
+ c, t = c2, 2
102
+
103
+ cost[i, j] = x[i - 1, j - 1] + c
104
+ trace[i, j] = t
105
+
106
+ return backtrace(trace)
107
+
108
+
109
+ def dtw_cuda(x, BLOCK_SIZE=1024):
110
+ from .triton_ops import dtw_kernel
111
+
112
+ M, N = x.shape
113
+ assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
114
+
115
+ x_skew = (
116
+ F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
117
+ )
118
+ x_skew = x_skew.T.contiguous()
119
+ cost = torch.ones(N + M + 2, M + 2) * np.inf
120
+ cost[0, 0] = 0
121
+ cost = cost.cuda()
122
+ trace = torch.zeros_like(cost, dtype=torch.int32)
123
+
124
+ dtw_kernel[(1,)](
125
+ cost,
126
+ trace,
127
+ x_skew,
128
+ x_skew.stride(0),
129
+ cost.stride(0),
130
+ trace.stride(0),
131
+ N,
132
+ M,
133
+ BLOCK_SIZE=BLOCK_SIZE,
134
+ )
135
+
136
+ trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
137
+ :, : N + 1
138
+ ]
139
+ return backtrace(trace.cpu().numpy())
140
+
141
+
142
+ def dtw(x: torch.Tensor) -> np.ndarray:
143
+ if x.is_cuda:
144
+ try:
145
+ return dtw_cuda(x)
146
+ except (RuntimeError, subprocess.CalledProcessError):
147
+ warnings.warn(
148
+ "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
149
+ "falling back to a slower DTW implementation..."
150
+ )
151
+
152
+ return dtw_cpu(x.double().cpu().numpy())
153
+
154
+
155
+ @dataclass
156
+ class WordTiming:
157
+ word: str
158
+ tokens: List[int]
159
+ start: float
160
+ end: float
161
+ probability: float
162
+
163
+
164
+ def find_alignment(
165
+ model: "Whisper",
166
+ tokenizer: Tokenizer,
167
+ text_tokens: List[int],
168
+ mel: torch.Tensor,
169
+ num_frames: int,
170
+ *,
171
+ medfilt_width: int = 7,
172
+ qk_scale: float = 1.0,
173
+ ) -> List[WordTiming]:
174
+ if len(text_tokens) == 0:
175
+ return []
176
+
177
+ tokens = torch.tensor(
178
+ [
179
+ *tokenizer.sot_sequence,
180
+ tokenizer.no_timestamps,
181
+ *text_tokens,
182
+ tokenizer.eot,
183
+ ]
184
+ ).to(model.device)
185
+
186
+ # install hooks on the cross attention layers to retrieve the attention weights
187
+ QKs = [None] * model.dims.n_text_layer
188
+ hooks = [
189
+ block.cross_attn.register_forward_hook(
190
+ lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
191
+ )
192
+ for i, block in enumerate(model.decoder.blocks)
193
+ ]
194
+
195
+ # run a forward pass to obtain the token probabilities
196
+ with torch.no_grad():
197
+ logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
198
+ sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
199
+ token_probs = sampled_logits.softmax(dim=-1)
200
+ text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
201
+ text_token_probs = text_token_probs.tolist()
202
+
203
+ # remove the hooks
204
+ for hook in hooks:
205
+ hook.remove()
206
+
207
+ # heads * tokens * frames
208
+ # print(model.alignment_heads)
209
+ # exit(0)
210
+ weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
211
+ weights = weights[:, :, : num_frames // 2]
212
+ weights = (weights * qk_scale).softmax(dim=-1)
213
+ std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
214
+ weights = (weights - mean) / std
215
+ weights = median_filter(weights, medfilt_width)
216
+
217
+ matrix = weights.mean(axis=0)
218
+ print("attention", matrix.shape, matrix[:5, :5])
219
+ matrix = matrix[len(tokenizer.sot_sequence) : -1]
220
+ print("attention", matrix.shape, matrix[:5, :5])
221
+ text_indices, time_indices = dtw(-matrix)
222
+
223
+ print("num_frames", num_frames)
224
+ print("attention", matrix.shape, matrix[:5, :5])
225
+ print("text_indices", text_indices)
226
+ print("time", time_indices)
227
+ print("text_tokens", text_tokens, tokenizer.decode(text_tokens), len(text_tokens))
228
+ print("eot", tokenizer.eot)
229
+
230
+ words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
231
+ if len(word_tokens) <= 1:
232
+ # return on eot only
233
+ # >>> np.pad([], (1, 0))
234
+ # array([0.])
235
+ # This results in crashes when we look up jump_times with a float index, like
236
+ # IndexError: arrays used as indices must be of integer (or boolean) type
237
+ return []
238
+ word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
239
+
240
+ jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
241
+ # print("jumps", jumps, jumps.shape)
242
+ jump_times = time_indices[jumps] / TOKENS_PER_SECOND
243
+ # print("jump_times", jump_times)
244
+ start_times = jump_times[word_boundaries[:-1]]
245
+ end_times = jump_times[word_boundaries[1:]]
246
+ word_probabilities = [
247
+ np.mean(text_token_probs[i:j])
248
+ for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
249
+ ]
250
+
251
+ return [
252
+ WordTiming(word, tokens, start, end, probability)
253
+ for word, tokens, start, end, probability in zip(
254
+ words, word_tokens, start_times, end_times, word_probabilities
255
+ )
256
+ ]
257
+
258
+
259
+ def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
260
+ # merge prepended punctuations
261
+ i = len(alignment) - 2
262
+ j = len(alignment) - 1
263
+ while i >= 0:
264
+ previous = alignment[i]
265
+ following = alignment[j]
266
+ if previous.word.startswith(" ") and previous.word.strip() in prepended:
267
+ # prepend it to the following word
268
+ following.word = previous.word + following.word
269
+ following.tokens = previous.tokens + following.tokens
270
+ previous.word = ""
271
+ previous.tokens = []
272
+ else:
273
+ j = i
274
+ i -= 1
275
+
276
+ # merge appended punctuations
277
+ i = 0
278
+ j = 1
279
+ while j < len(alignment):
280
+ previous = alignment[i]
281
+ following = alignment[j]
282
+ if not previous.word.endswith(" ") and following.word in appended:
283
+ # append it to the previous word
284
+ previous.word = previous.word + following.word
285
+ previous.tokens = previous.tokens + following.tokens
286
+ following.word = ""
287
+ following.tokens = []
288
+ else:
289
+ i = j
290
+ j += 1
291
+
292
+
293
+ def add_word_timestamps(
294
+ *,
295
+ segments: List[dict],
296
+ model: "Whisper",
297
+ tokenizer: Tokenizer,
298
+ mel: torch.Tensor,
299
+ num_frames: int,
300
+ prepend_punctuations: str = "\"'“¿([{-",
301
+ append_punctuations: str = "\"'.。,,!!??::”)]}、",
302
+ last_speech_timestamp: float,
303
+ **kwargs,
304
+ ):
305
+ if len(segments) == 0:
306
+ return
307
+
308
+ text_tokens_per_segment = [
309
+ [token for token in segment["tokens"] if token < tokenizer.eot]
310
+ for segment in segments
311
+ ]
312
+
313
+ text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
314
+ alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
315
+ word_durations = np.array([t.end - t.start for t in alignment])
316
+ word_durations = word_durations[word_durations.nonzero()]
317
+ median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
318
+ max_duration = median_duration * 2
319
+
320
+ # hack: truncate long words at sentence boundaries.
321
+ # a better segmentation algorithm based on VAD should be able to replace this.
322
+ if len(word_durations) > 0:
323
+ sentence_end_marks = ".。!!??"
324
+ # ensure words at sentence boundaries are not longer than twice the median word duration.
325
+ for i in range(1, len(alignment)):
326
+ if alignment[i].end - alignment[i].start > max_duration:
327
+ if alignment[i].word in sentence_end_marks:
328
+ alignment[i].end = alignment[i].start + max_duration
329
+ elif alignment[i - 1].word in sentence_end_marks:
330
+ alignment[i].start = alignment[i].end - max_duration
331
+
332
+ merge_punctuations(alignment, prepend_punctuations, append_punctuations)
333
+
334
+ time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
335
+ word_index = 0
336
+
337
+ for segment, text_tokens in zip(segments, text_tokens_per_segment):
338
+ saved_tokens = 0
339
+ words = []
340
+
341
+ while word_index < len(alignment) and saved_tokens < len(text_tokens):
342
+ timing = alignment[word_index]
343
+
344
+ if timing.word:
345
+ words.append(
346
+ dict(
347
+ word=timing.word,
348
+ start=round(time_offset + timing.start, 2),
349
+ end=round(time_offset + timing.end, 2),
350
+ probability=timing.probability,
351
+ )
352
+ )
353
+
354
+ saved_tokens += len(timing.tokens)
355
+ word_index += 1
356
+
357
+ # hack: truncate long words at segment boundaries.
358
+ # a better segmentation algorithm based on VAD should be able to replace this.
359
+ if len(words) > 0:
360
+ # ensure the first and second word after a pause is not longer than
361
+ # twice the median word duration.
362
+ if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
363
+ words[0]["end"] - words[0]["start"] > max_duration
364
+ or (
365
+ len(words) > 1
366
+ and words[1]["end"] - words[0]["start"] > max_duration * 2
367
+ )
368
+ ):
369
+ if (
370
+ len(words) > 1
371
+ and words[1]["end"] - words[1]["start"] > max_duration
372
+ ):
373
+ boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
374
+ words[0]["end"] = words[1]["start"] = boundary
375
+ words[0]["start"] = max(0, words[0]["end"] - max_duration)
376
+
377
+ # prefer the segment-level start timestamp if the first word is too long.
378
+ if (
379
+ segment["start"] < words[0]["end"]
380
+ and segment["start"] - 0.5 > words[0]["start"]
381
+ ):
382
+ words[0]["start"] = max(
383
+ 0, min(words[0]["end"] - median_duration, segment["start"])
384
+ )
385
+ else:
386
+ segment["start"] = words[0]["start"]
387
+
388
+ # prefer the segment-level end timestamp if the last word is too long.
389
+ if (
390
+ segment["end"] > words[-1]["start"]
391
+ and segment["end"] + 0.5 < words[-1]["end"]
392
+ ):
393
+ words[-1]["end"] = max(
394
+ words[-1]["start"] + median_duration, segment["end"]
395
+ )
396
+ else:
397
+ segment["end"] = words[-1]["end"]
398
+
399
+ last_speech_timestamp = segment["end"]
400
+
401
+ segment["words"] = words
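A small sketch of what the DTW helpers above compute (not part of the diff; the toy cost matrix is illustrative only — find_alignment calls dtw(-matrix), so lower values mean a better token/frame match):

    import numpy as np

    # 3 "tokens" x 5 "frames"; each row is cheap over the frames its token should cover
    cost = np.array([
        [0.1, 0.2, 0.9, 0.9, 0.9],
        [0.9, 0.9, 0.1, 0.2, 0.9],
        [0.9, 0.9, 0.9, 0.9, 0.1],
    ])
    text_idx, time_idx = dtw_cpu(cost)
    # both index arrays are non-decreasing, tracing a monotonic token-to-frame path;
    # find_alignment uses the jumps in text_idx to read off word start/end times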
simul_whisper/whisper/tokenizer.py ADDED
@@ -0,0 +1,395 @@
1
+ import base64
2
+ import os
3
+ import string
4
+ from dataclasses import dataclass, field
5
+ from functools import cached_property, lru_cache
6
+ from typing import Dict, List, Optional, Tuple
7
+
8
+ import tiktoken
9
+
10
+ LANGUAGES = {
11
+ "en": "english",
12
+ "zh": "chinese",
13
+ "de": "german",
14
+ "es": "spanish",
15
+ "ru": "russian",
16
+ "ko": "korean",
17
+ "fr": "french",
18
+ "ja": "japanese",
19
+ "pt": "portuguese",
20
+ "tr": "turkish",
21
+ "pl": "polish",
22
+ "ca": "catalan",
23
+ "nl": "dutch",
24
+ "ar": "arabic",
25
+ "sv": "swedish",
26
+ "it": "italian",
27
+ "id": "indonesian",
28
+ "hi": "hindi",
29
+ "fi": "finnish",
30
+ "vi": "vietnamese",
31
+ "he": "hebrew",
32
+ "uk": "ukrainian",
33
+ "el": "greek",
34
+ "ms": "malay",
35
+ "cs": "czech",
36
+ "ro": "romanian",
37
+ "da": "danish",
38
+ "hu": "hungarian",
39
+ "ta": "tamil",
40
+ "no": "norwegian",
41
+ "th": "thai",
42
+ "ur": "urdu",
43
+ "hr": "croatian",
44
+ "bg": "bulgarian",
45
+ "lt": "lithuanian",
46
+ "la": "latin",
47
+ "mi": "maori",
48
+ "ml": "malayalam",
49
+ "cy": "welsh",
50
+ "sk": "slovak",
51
+ "te": "telugu",
52
+ "fa": "persian",
53
+ "lv": "latvian",
54
+ "bn": "bengali",
55
+ "sr": "serbian",
56
+ "az": "azerbaijani",
57
+ "sl": "slovenian",
58
+ "kn": "kannada",
59
+ "et": "estonian",
60
+ "mk": "macedonian",
61
+ "br": "breton",
62
+ "eu": "basque",
63
+ "is": "icelandic",
64
+ "hy": "armenian",
65
+ "ne": "nepali",
66
+ "mn": "mongolian",
67
+ "bs": "bosnian",
68
+ "kk": "kazakh",
69
+ "sq": "albanian",
70
+ "sw": "swahili",
71
+ "gl": "galician",
72
+ "mr": "marathi",
73
+ "pa": "punjabi",
74
+ "si": "sinhala",
75
+ "km": "khmer",
76
+ "sn": "shona",
77
+ "yo": "yoruba",
78
+ "so": "somali",
79
+ "af": "afrikaans",
80
+ "oc": "occitan",
81
+ "ka": "georgian",
82
+ "be": "belarusian",
83
+ "tg": "tajik",
84
+ "sd": "sindhi",
85
+ "gu": "gujarati",
86
+ "am": "amharic",
87
+ "yi": "yiddish",
88
+ "lo": "lao",
89
+ "uz": "uzbek",
90
+ "fo": "faroese",
91
+ "ht": "haitian creole",
92
+ "ps": "pashto",
93
+ "tk": "turkmen",
94
+ "nn": "nynorsk",
95
+ "mt": "maltese",
96
+ "sa": "sanskrit",
97
+ "lb": "luxembourgish",
98
+ "my": "myanmar",
99
+ "bo": "tibetan",
100
+ "tl": "tagalog",
101
+ "mg": "malagasy",
102
+ "as": "assamese",
103
+ "tt": "tatar",
104
+ "haw": "hawaiian",
105
+ "ln": "lingala",
106
+ "ha": "hausa",
107
+ "ba": "bashkir",
108
+ "jw": "javanese",
109
+ "su": "sundanese",
110
+ "yue": "cantonese",
111
+ }
112
+
113
+ # language code lookup by name, with a few language aliases
114
+ TO_LANGUAGE_CODE = {
115
+ **{language: code for code, language in LANGUAGES.items()},
116
+ "burmese": "my",
117
+ "valencian": "ca",
118
+ "flemish": "nl",
119
+ "haitian": "ht",
120
+ "letzeburgesch": "lb",
121
+ "pushto": "ps",
122
+ "panjabi": "pa",
123
+ "moldavian": "ro",
124
+ "moldovan": "ro",
125
+ "sinhalese": "si",
126
+ "castilian": "es",
127
+ "mandarin": "zh",
128
+ }
129
+
130
+
131
+ @dataclass
132
+ class Tokenizer:
133
+ """A thin wrapper around `tiktoken` providing quick access to special tokens"""
134
+
135
+ encoding: tiktoken.Encoding
136
+ num_languages: int
137
+ language: Optional[str] = None
138
+ task: Optional[str] = None
139
+ sot_sequence: Tuple[int] = ()
140
+ special_tokens: Dict[str, int] = field(default_factory=dict)
141
+
142
+ def __post_init__(self):
143
+ for special in self.encoding.special_tokens_set:
144
+ special_token = self.encoding.encode_single_token(special)
145
+ self.special_tokens[special] = special_token
146
+
147
+ sot: int = self.special_tokens["<|startoftranscript|>"]
148
+ translate: int = self.special_tokens["<|translate|>"]
149
+ transcribe: int = self.special_tokens["<|transcribe|>"]
150
+
151
+ langs = tuple(LANGUAGES.keys())[: self.num_languages]
152
+ sot_sequence = [sot]
153
+ if self.language is not None:
154
+ sot_sequence.append(sot + 1 + langs.index(self.language))
155
+ if self.task is not None:
156
+ task_token: int = transcribe if self.task == "transcribe" else translate
157
+ sot_sequence.append(task_token)
158
+
159
+ self.sot_sequence = tuple(sot_sequence)
160
+
161
+ def encode(self, text, **kwargs):
162
+ return self.encoding.encode(text, **kwargs)
163
+
164
+ def decode(self, token_ids: List[int], **kwargs) -> str:
165
+ token_ids = [t for t in token_ids if t < self.timestamp_begin]
166
+ return self.encoding.decode(token_ids, **kwargs)
167
+
168
+ def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
169
+ """
170
+ Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
171
+ This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
172
+ """
173
+ return self.encoding.decode(token_ids, **kwargs)
174
+
175
+ @cached_property
176
+ def eot(self) -> int:
177
+ return self.encoding.eot_token
178
+
179
+ @cached_property
180
+ def transcribe(self) -> int:
181
+ return self.special_tokens["<|transcribe|>"]
182
+
183
+ @cached_property
184
+ def translate(self) -> int:
185
+ return self.special_tokens["<|translate|>"]
186
+
187
+ @cached_property
188
+ def sot(self) -> int:
189
+ return self.special_tokens["<|startoftranscript|>"]
190
+
191
+ @cached_property
192
+ def sot_lm(self) -> int:
193
+ return self.special_tokens["<|startoflm|>"]
194
+
195
+ @cached_property
196
+ def sot_prev(self) -> int:
197
+ return self.special_tokens["<|startofprev|>"]
198
+
199
+ @cached_property
200
+ def no_speech(self) -> int:
201
+ return self.special_tokens["<|nospeech|>"]
202
+
203
+ @cached_property
204
+ def no_timestamps(self) -> int:
205
+ return self.special_tokens["<|notimestamps|>"]
206
+
207
+ @cached_property
208
+ def timestamp_begin(self) -> int:
209
+ return self.special_tokens["<|0.00|>"]
210
+
211
+ @cached_property
212
+ def language_token(self) -> int:
213
+ """Returns the token id corresponding to the value of the `language` field"""
214
+ if self.language is None:
215
+ raise ValueError("This tokenizer does not have language token configured")
216
+
217
+ return self.to_language_token(self.language)
218
+
219
+ def to_language_token(self, language):
220
+ if token := self.special_tokens.get(f"<|{language}|>", None):
221
+ return token
222
+
223
+ raise KeyError(f"Language {language} not found in tokenizer.")
224
+
225
+ @cached_property
226
+ def all_language_tokens(self) -> Tuple[int]:
227
+ result = []
228
+ for token, token_id in self.special_tokens.items():
229
+ if token.strip("<|>") in LANGUAGES:
230
+ result.append(token_id)
231
+ return tuple(result)[: self.num_languages]
232
+
233
+ @cached_property
234
+ def all_language_codes(self) -> Tuple[str]:
235
+ return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
236
+
237
+ @cached_property
238
+ def sot_sequence_including_notimestamps(self) -> Tuple[int]:
239
+ return tuple(list(self.sot_sequence) + [self.no_timestamps])
240
+
241
+ @cached_property
242
+ def non_speech_tokens(self) -> Tuple[int]:
243
+ """
244
+ Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
245
+ annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
246
+
247
+ - ♪♪♪
248
+ - ( SPEAKING FOREIGN LANGUAGE )
249
+ - [DAVID] Hey there,
250
+
251
+ keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
252
+ """
253
+ symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
254
+ symbols += (
255
+ "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
256
+ )
257
+
258
+ # symbols that may be a single token or multiple tokens depending on the tokenizer.
259
+ # In case they're multiple tokens, suppress the first token, which is safe because:
260
+ # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
261
+ # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
262
+ miscellaneous = set("♩♪♫♬♭♮♯")
263
+ assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
264
+
265
+ # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
266
+ result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
267
+ for symbol in symbols + list(miscellaneous):
268
+ for tokens in [
269
+ self.encoding.encode(symbol),
270
+ self.encoding.encode(" " + symbol),
271
+ ]:
272
+ if len(tokens) == 1 or symbol in miscellaneous:
273
+ result.add(tokens[0])
274
+
275
+ return tuple(sorted(result))
276
+
277
+ def split_to_word_tokens(self, tokens: List[int]):
278
+ if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
279
+ # These languages don't typically use spaces, so it is difficult to split words
280
+ # without morpheme analysis. Here, we instead split words at any
281
+ # position where the tokens are decoded as valid unicode points
282
+ return self.split_tokens_on_unicode(tokens)
283
+
284
+ return self.split_tokens_on_spaces(tokens)
285
+
286
+ def split_tokens_on_unicode(self, tokens: List[int]):
287
+ decoded_full = self.decode_with_timestamps(tokens)
288
+ replacement_char = "\ufffd"
289
+
290
+ words = []
291
+ word_tokens = []
292
+ current_tokens = []
293
+ unicode_offset = 0
294
+
295
+ for token in tokens:
296
+ current_tokens.append(token)
297
+ decoded = self.decode_with_timestamps(current_tokens)
298
+
299
+ if (
300
+ replacement_char not in decoded
301
+ or decoded_full[unicode_offset + decoded.index(replacement_char)]
302
+ == replacement_char
303
+ ):
304
+ words.append(decoded)
305
+ word_tokens.append(current_tokens)
306
+ current_tokens = []
307
+ unicode_offset += len(decoded)
308
+
309
+ return words, word_tokens
310
+
311
+ def split_tokens_on_spaces(self, tokens: List[int]):
312
+ subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
313
+ words = []
314
+ word_tokens = []
315
+
316
+ for subword, subword_tokens in zip(subwords, subword_tokens_list):
317
+ special = subword_tokens[0] >= self.eot
318
+ with_space = subword.startswith(" ")
319
+ punctuation = subword.strip() in string.punctuation
320
+ if special or with_space or punctuation or len(words) == 0:
321
+ words.append(subword)
322
+ word_tokens.append(subword_tokens)
323
+ else:
324
+ words[-1] = words[-1] + subword
325
+ word_tokens[-1].extend(subword_tokens)
326
+
327
+ return words, word_tokens
328
+
329
+
330
+ @lru_cache(maxsize=None)
331
+ def get_encoding(name: str = "gpt2", num_languages: int = 99):
332
+ vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
333
+ ranks = {
334
+ base64.b64decode(token): int(rank)
335
+ for token, rank in (line.split() for line in open(vocab_path) if line)
336
+ }
337
+ n_vocab = len(ranks)
338
+ special_tokens = {}
339
+
340
+ specials = [
341
+ "<|endoftext|>",
342
+ "<|startoftranscript|>",
343
+ *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
344
+ "<|translate|>",
345
+ "<|transcribe|>",
346
+ "<|startoflm|>",
347
+ "<|startofprev|>",
348
+ "<|nospeech|>",
349
+ "<|notimestamps|>",
350
+ *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
351
+ ]
352
+
353
+ for token in specials:
354
+ special_tokens[token] = n_vocab
355
+ n_vocab += 1
356
+
357
+ return tiktoken.Encoding(
358
+ name=os.path.basename(vocab_path),
359
+ explicit_n_vocab=n_vocab,
360
+ pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
361
+ mergeable_ranks=ranks,
362
+ special_tokens=special_tokens,
363
+ )
364
+
365
+
366
+ @lru_cache(maxsize=None)
367
+ def get_tokenizer(
368
+ multilingual: bool,
369
+ *,
370
+ num_languages: int = 99,
371
+ language: Optional[str] = None,
372
+ task: Optional[str] = None, # Literal["transcribe", "translate", None]
373
+ ) -> Tokenizer:
374
+ if language is not None:
375
+ language = language.lower()
376
+ if language not in LANGUAGES:
377
+ if language in TO_LANGUAGE_CODE:
378
+ language = TO_LANGUAGE_CODE[language]
379
+ else:
380
+ raise ValueError(f"Unsupported language: {language}")
381
+
382
+ if multilingual:
383
+ encoding_name = "multilingual"
384
+ language = language or "en"
385
+ task = task or "transcribe"
386
+ else:
387
+ encoding_name = "gpt2"
388
+ language = None
389
+ task = None
390
+
391
+ encoding = get_encoding(name=encoding_name, num_languages=num_languages)
392
+
393
+ return Tokenizer(
394
+ encoding=encoding, num_languages=num_languages, language=language, task=task
395
+ )
simul_whisper/whisper/trans_nopad.py ADDED
@@ -0,0 +1,501 @@
1
+ import argparse
2
+ import os
3
+ import warnings
4
+ from typing import TYPE_CHECKING, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import tqdm
9
+
10
+ from whisper.audio import (
11
+ FRAMES_PER_SECOND,
12
+ HOP_LENGTH,
13
+ N_FRAMES,
14
+ N_SAMPLES,
15
+ SAMPLE_RATE,
16
+ log_mel_spectrogram,
17
+ pad_or_trim,
18
+ )
19
+ from whisper.decoding import DecodingOptions, DecodingResult
20
+ from whisper.timing import add_word_timestamps
21
+ from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
22
+ from whisper.utils import (
23
+ exact_div,
24
+ format_timestamp,
25
+ get_writer,
26
+ make_safe,
27
+ optional_float,
28
+ optional_int,
29
+ str2bool,
30
+ )
31
+
32
+ if TYPE_CHECKING:
33
+ from whisper.model import Whisper
34
+
35
+
36
+ def transcribe(
37
+ model: "Whisper",
38
+ audio: Union[str, np.ndarray, torch.Tensor],
39
+ *,
40
+ verbose: Optional[bool] = None,
41
+ temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
42
+ compression_ratio_threshold: Optional[float] = 2.4,
43
+ logprob_threshold: Optional[float] = -1.0,
44
+ no_speech_threshold: Optional[float] = 0.6,
45
+ condition_on_previous_text: bool = True,
46
+ initial_prompt: Optional[str] = None,
47
+ word_timestamps: bool = False,
48
+ prepend_punctuations: str = "\"'“¿([{-",
49
+ append_punctuations: str = "\"'.。,,!!??::”)]}、",
50
+ **decode_options,
51
+ ):
52
+ """
53
+ Transcribe an audio file using Whisper
54
+
55
+ Parameters
56
+ ----------
57
+ model: Whisper
58
+ The Whisper model instance
59
+
60
+ audio: Union[str, np.ndarray, torch.Tensor]
61
+ The path to the audio file to open, or the audio waveform
62
+
63
+ verbose: bool
64
+ Whether to display the text being decoded to the console. If True, displays all the details,
65
+ If False, displays minimal details. If None, does not display anything
66
+
67
+ temperature: Union[float, Tuple[float, ...]]
68
+ Temperature for sampling. It can be a tuple of temperatures, which will be successively used
69
+ upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
70
+
71
+ compression_ratio_threshold: float
72
+ If the gzip compression ratio is above this value, treat as failed
73
+
74
+ logprob_threshold: float
75
+ If the average log probability over sampled tokens is below this value, treat as failed
76
+
77
+ no_speech_threshold: float
78
+ If the no_speech probability is higher than this value AND the average log probability
79
+ over sampled tokens is below `logprob_threshold`, consider the segment as silent
80
+
81
+ condition_on_previous_text: bool
82
+ if True, the previous output of the model is provided as a prompt for the next window;
83
+ disabling may make the text inconsistent across windows, but the model becomes less prone to
84
+ getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
85
+
86
+ word_timestamps: bool
87
+ Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
88
+ and include the timestamps for each word in each segment.
89
+
90
+ prepend_punctuations: str
91
+ If word_timestamps is True, merge these punctuation symbols with the next word
92
+
93
+ append_punctuations: str
94
+ If word_timestamps is True, merge these punctuation symbols with the previous word
95
+
96
+ initial_prompt: Optional[str]
97
+ Optional text to provide as a prompt for the first window. This can be used to provide, or
98
+ "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
99
+ to make it more likely to predict those words correctly.
100
+
101
+ decode_options: dict
102
+ Keyword arguments to construct `DecodingOptions` instances
103
+
104
+ Returns
105
+ -------
106
+ A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
107
+ the spoken language ("language"), which is detected when `decode_options["language"]` is None.
108
+ """
109
+ # print("HACKED")
110
+ dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
111
+ if model.device == torch.device("cpu"):
112
+ if torch.cuda.is_available():
113
+ warnings.warn("Performing inference on CPU when CUDA is available")
114
+ if dtype == torch.float16:
115
+ warnings.warn("FP16 is not supported on CPU; using FP32 instead")
116
+ dtype = torch.float32
117
+
118
+ if dtype == torch.float32:
119
+ decode_options["fp16"] = False
120
+
121
+ # Pad 30-seconds of silence to the input audio, for slicing
122
+ mel = log_mel_spectrogram(audio, padding=0) # log_mel_spectrogram(audio, padding=N_SAMPLES) # would add 16000*30 = 480000 samples of padding
123
+ # mel = pad_or_trim(mel, 3000)
124
+ content_frames = mel.shape[-1] # - N_FRAMES # corresponds to 3000 frames; the actual content is what remains after dropping the trailing 3000 padding frames
125
+
126
+ # detect the language
127
+ if decode_options.get("language", None) is None:
128
+ # if this is an English-only model, set the language to English directly
129
+ if not model.is_multilingual:
130
+ decode_options["language"] = "en"
131
+ # otherwise a forward pass is needed to detect the language
132
+ else:
133
+ if verbose:
134
+ print(
135
+ "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
136
+ )
137
+ mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
138
+ # print(mel_segment.shape)
139
+ _, probs = model.detect_language(mel_segment)
140
+ decode_options["language"] = max(probs, key=probs.get)
141
+ if verbose is not None:
142
+ print(
143
+ f"Detected language: {LANGUAGES[decode_options['language']].title()}"
144
+ )
145
+
146
+ language: str = decode_options["language"]
147
+ task: str = decode_options.get("task", "transcribe")
148
+ # tokenizer for the output text
149
+ tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
150
+
151
+ # word-level timestamps
152
+ if word_timestamps and task == "translate":
153
+ warnings.warn("Word-level timestamps on translations may not be reliable.")
154
+
155
+ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
156
+ temperatures = (
157
+ [temperature] if isinstance(temperature, (int, float)) else temperature
158
+ )
159
+ decode_result = None
160
+
161
+ for t in temperatures:
162
+ kwargs = {**decode_options}
163
+ if t > 0:
164
+ # disable beam_size and patience when t > 0
165
+ kwargs.pop("beam_size", None)
166
+ kwargs.pop("patience", None)
167
+ else:
168
+ # disable best_of when t == 0
169
+ kwargs.pop("best_of", None)
170
+
171
+ options = DecodingOptions(**kwargs, temperature=t)
172
+ decode_result = model.decode(segment, options)
173
+
174
+ # several conditions under which decoding is considered to have failed; decoding is retried in those cases
175
+ # feels like empirical know-how; ChatGPT probably relies on plenty of tricks like this
176
+ needs_fallback = False
177
+ if (
178
+ compression_ratio_threshold is not None
179
+ and decode_result.compression_ratio > compression_ratio_threshold
180
+ ):
181
+ needs_fallback = True # too repetitive
182
+ if (
183
+ logprob_threshold is not None
184
+ and decode_result.avg_logprob < logprob_threshold
185
+ ):
186
+ needs_fallback = True # average log probability is too low
187
+ if (
188
+ no_speech_threshold is not None
189
+ and decode_result.no_speech_prob > no_speech_threshold
190
+ ):
191
+ needs_fallback = False # silence
192
+ if not needs_fallback:
193
+ break
194
+ # print("decode with temperature {} compress rate {:.3f}/{:.3f}, log_prob {:.3f}/{:.3f}, {:.3f}/{:.3f}".format(
195
+ # t,
196
+ # decode_result.compression_ratio, compression_ratio_threshold,
197
+ # -decode_result.avg_logprob, -logprob_threshold,
198
+ # decode_result.no_speech_prob, no_speech_threshold
199
+ # ))
200
+
201
+ return decode_result
202
+
203
+ seek = 0
204
+ input_stride = exact_div(
205
+ N_FRAMES, model.dims.n_audio_ctx
206
+ ) # mel frames per output token: 2
207
+ # "output token" here presumably refers to the encoder's (post-CNN) output frames
208
+
209
+ time_precision = (
210
+ input_stride * HOP_LENGTH / SAMPLE_RATE
211
+ ) # time per output token: 0.02 (seconds)
212
+ all_tokens = []
213
+ all_segments = []
214
+ prompt_reset_since = 0
215
+
216
+ if initial_prompt is not None:
217
+ initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
218
+ all_tokens.extend(initial_prompt_tokens)
219
+ else:
220
+ initial_prompt_tokens = []
221
+
222
+ def new_segment(
223
+ *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
224
+ ):
225
+ tokens = tokens.tolist()
226
+ text_tokens = [token for token in tokens if token < tokenizer.eot]
227
+ return {
228
+ "seek": seek,
229
+ "start": start,
230
+ "end": end,
231
+ "text": tokenizer.decode(text_tokens),
232
+ "tokens": tokens,
233
+ "temperature": result.temperature,
234
+ "avg_logprob": result.avg_logprob,
235
+ "compression_ratio": result.compression_ratio,
236
+ "no_speech_prob": result.no_speech_prob,
237
+ }
238
+
239
+ # show the progress bar when verbose is False (if True, transcribed text will be printed)
240
+ with tqdm.tqdm(
241
+ total=content_frames, unit="frames", disable=verbose is not False
242
+ ) as pbar:
243
+ last_speech_timestamp = 0.0
244
+ while seek < content_frames: # seek: current frame position in the mel spectrogram; padding frames are simply skipped over
245
+ # print("seek segments", seek, content_frames)
246
+ time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) # start time of this segment
247
+ # mel_segment = mel[:, seek : seek + N_FRAMES] # take the data for the current segment
248
+ mel_segment = mel[:, seek:]
249
+ segment_size = min(N_FRAMES, content_frames - seek) # segment_size: real length excluding padding; content_frames: actual length of the content, truncated when shorter than N_FRAMES
250
+ segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE # duration of the current segment
251
+ mel_segment = mel_segment.to(model.device).to(dtype) # pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) # pad to N_FRAMES frames
252
+
253
+ decode_options["prompt"] = all_tokens[prompt_reset_since:]
254
+ result: DecodingResult = decode_with_fallback(mel_segment)
255
+ tokens = torch.tensor(result.tokens)
256
+
257
+ # skip silent parts
258
+ if no_speech_threshold is not None:
259
+ # no voice activity check
260
+ should_skip = result.no_speech_prob > no_speech_threshold
261
+ if (
262
+ logprob_threshold is not None
263
+ and result.avg_logprob > logprob_threshold
264
+ ):
265
+ # don't skip if the logprob is high enough, despite the no_speech_prob
266
+ should_skip = False
267
+
268
+ if should_skip:
269
+ seek += segment_size # fast-forward to the next segment boundary
270
+ continue
271
+
272
+ previous_seek = seek
273
+ current_segments = []
274
+
275
+ timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) # timestamp_begin is the <|0.00|> token; special token ids are larger than text tokens and timestamp ids are larger still, hence ge
276
+ timestamp_tokens[-1] = False
277
+ single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] # if the end is [False, True], a sentence has finished within this segment
278
+
279
+ consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
280
+ # torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True).
281
+ # timestamp_tokens is just a 1-D vector, so why not call nonzero directly?
282
+ # if there are two consecutive timestamp tokens, this is a 1-D tensor holding the position of the second token of the pair
283
+ # with multiple pairs it points at each second element; what happens if there are three in a row?
284
+ # otherwise it is a 0-dim tensor
285
+
286
+ consecutive.add_(1) # adding 1 to a 0-dim tensor is still 0-dim; where do these edge cases come from, is this JavaScript?
287
+ if len(consecutive) > 0:
288
+ # if the output contains two consecutive timestamp tokens
289
+ slices = consecutive.tolist()
290
+ if single_timestamp_ending:
291
+ slices.append(len(tokens)) # also include the end of the final slice
292
+ # print("many sentenses", consecutive)
293
+ last_slice = 0
294
+ for current_slice in slices:
295
+ sliced_tokens = tokens[last_slice:current_slice]
296
+ # it looks like the speech start/end frame positions are encoded in the start/end timestamp tokens
297
+ start_timestamp_pos = (
298
+ sliced_tokens[0].item() - tokenizer.timestamp_begin
299
+ )
300
+ end_timestamp_pos = (
301
+ sliced_tokens[-1].item() - tokenizer.timestamp_begin
302
+ )
303
+ # create a new speech segment
304
+ current_segments.append(
305
+ new_segment(
306
+ start=time_offset + start_timestamp_pos * time_precision,
307
+ end=time_offset + end_timestamp_pos * time_precision,
308
+ tokens=sliced_tokens,
309
+ result=result,
310
+ )
311
+ )
312
+ last_slice = current_slice
313
+
314
+ if single_timestamp_ending:
315
+ # single timestamp at the end means no speech after the last timestamp.
316
+ seek += segment_size
317
+ else:
318
+ # otherwise, ignore the unfinished segment and seek to the last timestamp
319
+ # if the speech has not finished, seek moves to the position of the last completed segment
320
+ # in other words, this is designed around 30-second speech chunks
321
+ last_timestamp_pos = (
322
+ tokens[last_slice - 1].item() - tokenizer.timestamp_begin
323
+ )
324
+ seek += last_timestamp_pos * input_stride
325
+ else:
326
+ duration = segment_duration
327
+ timestamps = tokens[timestamp_tokens.nonzero().flatten()]
328
+ # print(timestamps)
329
+ if (
330
+ len(timestamps) > 0
331
+ and timestamps[-1].item() != tokenizer.timestamp_begin
332
+ ):
333
+ # no consecutive timestamps but it has a timestamp; use the last one.
334
+ # take the last one; the assumption is that there is either a single ending timestamp or a pair?
335
+ # if there is only a starting timestamp, everything after it seems to get dropped?
336
+ last_timestamp_pos = (
337
+ timestamps[-1].item() - tokenizer.timestamp_begin
338
+ )
339
+ duration = last_timestamp_pos * time_precision
340
+
341
+ current_segments.append(
342
+ new_segment(
343
+ start=time_offset,
344
+ end=time_offset + duration,
345
+ tokens=tokens,
346
+ result=result,
347
+ )
348
+ )
349
+ seek += segment_size
350
+
351
+ # each token gets its own timestamp
352
+ if word_timestamps:
353
+ add_word_timestamps(
354
+ segments=current_segments,
355
+ model=model,
356
+ tokenizer=tokenizer,
357
+ mel=mel_segment,
358
+ num_frames=segment_size,
359
+ prepend_punctuations=prepend_punctuations,
360
+ append_punctuations=append_punctuations,
361
+ last_speech_timestamp=last_speech_timestamp,
362
+ )
363
+ word_end_timestamps = [
364
+ w["end"] for s in current_segments for w in s["words"]
365
+ ]
366
+ if len(word_end_timestamps) > 0:
367
+ last_speech_timestamp = word_end_timestamps[-1]
368
+ if not single_timestamp_ending and len(word_end_timestamps) > 0:
369
+ seek_shift = round(
370
+ (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
371
+ )
372
+ if seek_shift > 0:
373
+ seek = previous_seek + seek_shift
374
+
375
+ if verbose:
376
+ for segment in current_segments:
377
+ start, end, text = segment["start"], segment["end"], segment["text"]
378
+ line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
379
+ print(make_safe(line))
380
+
381
+ # if a segment is instantaneous or does not contain text, clear it
382
+ for i, segment in enumerate(current_segments):
383
+ if segment["start"] == segment["end"] or segment["text"].strip() == "":
384
+ segment["text"] = ""
385
+ segment["tokens"] = []
386
+ segment["words"] = []
387
+
388
+ # update the results
389
+ all_segments.extend(
390
+ [
391
+ {"id": i, **segment}
392
+ for i, segment in enumerate(
393
+ current_segments, start=len(all_segments)
394
+ )
395
+ ]
396
+ )
397
+ all_tokens.extend(
398
+ [token for segment in current_segments for token in segment["tokens"]]
399
+ )
400
+
401
+ if not condition_on_previous_text or result.temperature > 0.5:
402
+ # do not feed the prompt tokens if a high temperature was used
403
+ prompt_reset_since = len(all_tokens)
404
+
405
+ # update progress bar
406
+ pbar.update(min(content_frames, seek) - previous_seek)
407
+
408
+ # print("too long")
409
+ # break
410
+
411
+ return dict(
412
+ text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
413
+ segments=all_segments,
414
+ language=language,
415
+ )
416
+
417
+
418
+ def cli():
419
+ from . import available_models
420
+
421
+ # fmt: off
422
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
423
+ parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
424
+ parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
425
+ parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
426
+ parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
427
+ parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
428
+ parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
429
+ parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
430
+
431
+ parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
432
+ parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
433
+
434
+ parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
435
+ parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
436
+ parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
437
+ parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
438
+ parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
439
+
440
+ parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
441
+ parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
442
+ parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
443
+ parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
444
+
445
+ parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
446
+ parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
447
+ parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
448
+ parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
449
+ parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
450
+ parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
451
+ parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
452
+ parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
453
+ parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
454
+ parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
455
+ parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supersedes MKL_NUM_THREADS/OMP_NUM_THREADS")
456
+ # fmt: on
457
+
458
+ args = parser.parse_args().__dict__
459
+ model_name: str = args.pop("model")
460
+ model_dir: str = args.pop("model_dir")
461
+ output_dir: str = args.pop("output_dir")
462
+ output_format: str = args.pop("output_format")
463
+ device: str = args.pop("device")
464
+ os.makedirs(output_dir, exist_ok=True)
465
+
466
+ if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
467
+ if args["language"] is not None:
468
+ warnings.warn(
469
+ f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
470
+ )
471
+ args["language"] = "en"
472
+
473
+ temperature = args.pop("temperature")
474
+ if (increment := args.pop("temperature_increment_on_fallback")) is not None:
475
+ temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
476
+ else:
477
+ temperature = [temperature]
478
+
479
+ if (threads := args.pop("threads")) > 0:
480
+ torch.set_num_threads(threads)
481
+
482
+ from . import load_model
483
+
484
+ model = load_model(model_name, device=device, download_root=model_dir)
485
+
486
+ writer = get_writer(output_format, output_dir)
487
+ word_options = ["highlight_words", "max_line_count", "max_line_width"]
488
+ if not args["word_timestamps"]:
489
+ for option in word_options:
490
+ if args[option]:
491
+ parser.error(f"--{option} requires --word_timestamps True")
492
+ if args["max_line_count"] and not args["max_line_width"]:
493
+ warnings.warn("--max_line_count has no effect without --max_line_width")
494
+ writer_args = {arg: args.pop(arg) for arg in word_options}
495
+ for audio_path in args.pop("audio"):
496
+ result = transcribe(model, audio_path, temperature=temperature, **args)
497
+ writer(result, audio_path, writer_args)
498
+
499
+
500
+ if __name__ == "__main__":
501
+ cli()
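A minimal sketch of calling this no-padding transcription variant directly (not part of the diff; the model name and audio path are placeholders, and loading the model through the openai-whisper package is an assumption — the vendored package's own load_model is what cli() above uses):

    import torch
    import whisper

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = whisper.load_model("small", device=device)
    result = transcribe(model, "audio.wav", language="en", word_timestamps=True, verbose=False)
    print(result["text"])
    for seg in result["segments"]:
        print(seg["start"], seg["end"], seg["text"])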
simul_whisper/whisper/transcribe.py ADDED
@@ -0,0 +1,467 @@
+ import argparse
+ import os
+ import warnings
+ from typing import TYPE_CHECKING, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ import tqdm
+
+ from .audio import (
+     FRAMES_PER_SECOND,
+     HOP_LENGTH,
+     N_FRAMES,
+     N_SAMPLES,
+     SAMPLE_RATE,
+     log_mel_spectrogram,
+     pad_or_trim,
+ )
+ from .decoding import DecodingOptions, DecodingResult
+ from .timing import add_word_timestamps
+ from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
+ from .utils import (
+     exact_div,
+     format_timestamp,
+     get_writer,
+     make_safe,
+     optional_float,
+     optional_int,
+     str2bool,
+ )
+
+ if TYPE_CHECKING:
+     from .model import Whisper
+
+
+ def transcribe(
+     model: "Whisper",
+     audio: Union[str, np.ndarray, torch.Tensor],
+     *,
+     verbose: Optional[bool] = None,
+     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
+     compression_ratio_threshold: Optional[float] = 2.4,
+     logprob_threshold: Optional[float] = -1.0,
+     no_speech_threshold: Optional[float] = 0.6,
+     condition_on_previous_text: bool = True,
+     initial_prompt: Optional[str] = None,
+     word_timestamps: bool = False,
+     prepend_punctuations: str = "\"'“¿([{-",
+     append_punctuations: str = "\"'.。,,!!??::”)]}、",
+     **decode_options,
+ ):
+     """
+     Transcribe an audio file using Whisper
+
+     Parameters
+     ----------
+     model: Whisper
+         The Whisper model instance
+
+     audio: Union[str, np.ndarray, torch.Tensor]
+         The path to the audio file to open, or the audio waveform
+
+     verbose: bool
+         Whether to display the text being decoded to the console. If True, displays all the details,
+         If False, displays minimal details. If None, does not display anything
+
+     temperature: Union[float, Tuple[float, ...]]
+         Temperature for sampling. It can be a tuple of temperatures, which will be successively used
+         upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
+
+     compression_ratio_threshold: float
+         If the gzip compression ratio is above this value, treat as failed
+
+     logprob_threshold: float
+         If the average log probability over sampled tokens is below this value, treat as failed
+
+     no_speech_threshold: float
+         If the no_speech probability is higher than this value AND the average log probability
+         over sampled tokens is below `logprob_threshold`, consider the segment as silent
+
+     condition_on_previous_text: bool
+         if True, the previous output of the model is provided as a prompt for the next window;
+         disabling may make the text inconsistent across windows, but the model becomes less prone to
+         getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
+
+     word_timestamps: bool
+         Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
+         and include the timestamps for each word in each segment.
+
+     prepend_punctuations: str
+         If word_timestamps is True, merge these punctuation symbols with the next word
+
+     append_punctuations: str
+         If word_timestamps is True, merge these punctuation symbols with the previous word
+
+     initial_prompt: Optional[str]
+         Optional text to provide as a prompt for the first window. This can be used to provide, or
+         "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
+         to make it more likely to predict those words correctly.
+
+     decode_options: dict
+         Keyword arguments to construct `DecodingOptions` instances
+
+     Returns
+     -------
+     A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
+     the spoken language ("language"), which is detected when `decode_options["language"]` is None.
+     """
+     # print("transcribe")
+     dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
+     if model.device == torch.device("cpu"):
+         if torch.cuda.is_available():
+             warnings.warn("Performing inference on CPU when CUDA is available")
+         if dtype == torch.float16:
+             warnings.warn("FP16 is not supported on CPU; using FP32 instead")
+             dtype = torch.float32
+
+     if dtype == torch.float32:
+         decode_options["fp16"] = False
+
+     # Pad 30-seconds of silence to the input audio, for slicing
+     mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
+     content_frames = mel.shape[-1] - N_FRAMES
+
+     if decode_options.get("language", None) is None:
+         if not model.is_multilingual:
+             decode_options["language"] = "en"
+         else:
+             if verbose:
+                 print(
+                     "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
+                 )
+             mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
+             # print(mel_segment.shape)
+             _, probs = model.detect_language(mel_segment)
+             decode_options["language"] = max(probs, key=probs.get)
+             if verbose is not None:
+                 print(
+                     f"Detected language: {LANGUAGES[decode_options['language']].title()}"
+                 )
+
+     language: str = decode_options["language"]
+     task: str = decode_options.get("task", "transcribe")
+     tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
+
+     if word_timestamps and task == "translate":
+         warnings.warn("Word-level timestamps on translations may not be reliable.")
+
+     def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
+         temperatures = (
+             [temperature] if isinstance(temperature, (int, float)) else temperature
+         )
+         decode_result = None
+
+         for t in temperatures:
+             kwargs = {**decode_options}
+             if t > 0:
+                 # disable beam_size and patience when t > 0
+                 kwargs.pop("beam_size", None)
+                 kwargs.pop("patience", None)
+             else:
+                 # disable best_of when t == 0
+                 kwargs.pop("best_of", None)
+
+             options = DecodingOptions(**kwargs, temperature=t)
+             decode_result = model.decode(segment, options)
+
+             needs_fallback = False
+             if (
+                 compression_ratio_threshold is not None
+                 and decode_result.compression_ratio > compression_ratio_threshold
+             ):
+                 needs_fallback = True  # too repetitive
+             if (
+                 logprob_threshold is not None
+                 and decode_result.avg_logprob < logprob_threshold
+             ):
+                 needs_fallback = True  # average log probability is too low
+             if (
+                 no_speech_threshold is not None
+                 and decode_result.no_speech_prob > no_speech_threshold
+             ):
+                 needs_fallback = False  # silence
+             if not needs_fallback:
+                 break
+
+         return decode_result
+
+     seek = 0
+     input_stride = exact_div(
+         N_FRAMES, model.dims.n_audio_ctx
+     )  # mel frames per output token: 2
+     time_precision = (
+         input_stride * HOP_LENGTH / SAMPLE_RATE
+     )  # time per output token: 0.02 (seconds)
+     all_tokens = []
+     all_segments = []
+     prompt_reset_since = 0
+
+     if initial_prompt is not None:
+         initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
+         all_tokens.extend(initial_prompt_tokens)
+     else:
+         initial_prompt_tokens = []
+
+     def new_segment(
+         *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
+     ):
+         tokens = tokens.tolist()
+         text_tokens = [token for token in tokens if token < tokenizer.eot]
+         return {
+             "seek": seek,
+             "start": start,
+             "end": end,
+             "text": tokenizer.decode(text_tokens),
+             "tokens": tokens,
+             "temperature": result.temperature,
+             "avg_logprob": result.avg_logprob,
+             "compression_ratio": result.compression_ratio,
+             "no_speech_prob": result.no_speech_prob,
+         }
+
+     # show the progress bar when verbose is False (if True, transcribed text will be printed)
+     with tqdm.tqdm(
+         total=content_frames, unit="frames", disable=verbose is not False
+     ) as pbar:
+         last_speech_timestamp = 0.0
+         while seek < content_frames:
+             time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
+             mel_segment = mel[:, seek : seek + N_FRAMES]
+             segment_size = min(N_FRAMES, content_frames - seek)
+             segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
+             mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
+
+             # print("melshape", mel_segment.shape)
+
+             decode_options["prompt"] = all_tokens[prompt_reset_since:]
+             result: DecodingResult = decode_with_fallback(mel_segment)
+             tokens = torch.tensor(result.tokens)
+
+             if no_speech_threshold is not None:
+                 # no voice activity check
+                 should_skip = result.no_speech_prob > no_speech_threshold
+                 if (
+                     logprob_threshold is not None
+                     and result.avg_logprob > logprob_threshold
+                 ):
+                     # don't skip if the logprob is high enough, despite the no_speech_prob
+                     should_skip = False
+
+                 if should_skip:
+                     seek += segment_size  # fast-forward to the next segment boundary
+                     continue
+
+             previous_seek = seek
+             current_segments = []
+
+             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
+             single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
+
+             consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
+             consecutive.add_(1)
+             if len(consecutive) > 0:
+                 # if the output contains two consecutive timestamp tokens
+                 slices = consecutive.tolist()
+                 if single_timestamp_ending:
+                     slices.append(len(tokens))
+
+                 last_slice = 0
+                 for current_slice in slices:
+                     sliced_tokens = tokens[last_slice:current_slice]
+                     start_timestamp_pos = (
+                         sliced_tokens[0].item() - tokenizer.timestamp_begin
+                     )
+                     end_timestamp_pos = (
+                         sliced_tokens[-1].item() - tokenizer.timestamp_begin
+                     )
+                     current_segments.append(
+                         new_segment(
+                             start=time_offset + start_timestamp_pos * time_precision,
+                             end=time_offset + end_timestamp_pos * time_precision,
+                             tokens=sliced_tokens,
+                             result=result,
+                         )
+                     )
+                     last_slice = current_slice
+
+                 if single_timestamp_ending:
+                     # single timestamp at the end means no speech after the last timestamp.
+                     seek += segment_size
+                 else:
+                     # otherwise, ignore the unfinished segment and seek to the last timestamp
+                     last_timestamp_pos = (
+                         tokens[last_slice - 1].item() - tokenizer.timestamp_begin
+                     )
+                     seek += last_timestamp_pos * input_stride
+             else:
+                 duration = segment_duration
+                 timestamps = tokens[timestamp_tokens.nonzero().flatten()]
+                 if (
+                     len(timestamps) > 0
+                     and timestamps[-1].item() != tokenizer.timestamp_begin
+                 ):
+                     # no consecutive timestamps but it has a timestamp; use the last one.
+                     last_timestamp_pos = (
+                         timestamps[-1].item() - tokenizer.timestamp_begin
+                     )
+                     duration = last_timestamp_pos * time_precision
+
+                 current_segments.append(
+                     new_segment(
+                         start=time_offset,
+                         end=time_offset + duration,
+                         tokens=tokens,
+                         result=result,
+                     )
+                 )
+                 seek += segment_size
+
+             # print("word_timestamps, ", word_timestamps)
+             if word_timestamps:
+                 # print("=========run timestamps here=========")
+                 add_word_timestamps(
+                     segments=current_segments,
+                     model=model,
+                     tokenizer=tokenizer,
+                     mel=mel_segment,
+                     num_frames=segment_size,
+                     prepend_punctuations=prepend_punctuations,
+                     append_punctuations=append_punctuations,
+                     last_speech_timestamp=last_speech_timestamp,
+                 )
+                 word_end_timestamps = [
+                     w["end"] for s in current_segments for w in s["words"]
+                 ]
+                 if len(word_end_timestamps) > 0:
+                     last_speech_timestamp = word_end_timestamps[-1]
+                 if not single_timestamp_ending and len(word_end_timestamps) > 0:
+                     seek_shift = round(
+                         (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
+                     )
+                     if seek_shift > 0:
+                         seek = previous_seek + seek_shift
+
+             if verbose:
+                 for segment in current_segments:
+                     start, end, text = segment["start"], segment["end"], segment["text"]
+                     line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
+                     print(make_safe(line))
+
+             # if a segment is instantaneous or does not contain text, clear it
+             for i, segment in enumerate(current_segments):
+                 if segment["start"] == segment["end"] or segment["text"].strip() == "":
+                     segment["text"] = ""
+                     segment["tokens"] = []
+                     segment["words"] = []
+
+             all_segments.extend(
+                 [
+                     {"id": i, **segment}
+                     for i, segment in enumerate(
+                         current_segments, start=len(all_segments)
+                     )
+                 ]
+             )
+             all_tokens.extend(
+                 [token for segment in current_segments for token in segment["tokens"]]
+             )
+
+             if not condition_on_previous_text or result.temperature > 0.5:
+                 # do not feed the prompt tokens if a high temperature was used
+                 prompt_reset_since = len(all_tokens)
+
+             # update progress bar
+             pbar.update(min(content_frames, seek) - previous_seek)
+
+     return dict(
+         text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
+         segments=all_segments,
+         language=language,
+     )
+
+
+ def cli():
+     from . import available_models
+
+     # fmt: off
+     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+     parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
+     parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
+     parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
+     parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
+     parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
+     parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
+     parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
+
+     parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
+     parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
+
+     parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
+     parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
+     parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
+     parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
+     parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
+
+     parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
+     parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
+     parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
+     parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
+
+     parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
+     parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
+     parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
+     parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
+     parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
+     parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
+     parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
+     parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
+     parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
+     parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
+     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supersedes MKL_NUM_THREADS/OMP_NUM_THREADS")
+     # fmt: on
+
+     args = parser.parse_args().__dict__
+     model_name: str = args.pop("model")
+     model_dir: str = args.pop("model_dir")
+     output_dir: str = args.pop("output_dir")
+     output_format: str = args.pop("output_format")
+     device: str = args.pop("device")
+     os.makedirs(output_dir, exist_ok=True)
+
+     if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
+         if args["language"] is not None:
+             warnings.warn(
+                 f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
+             )
+         args["language"] = "en"
+
+     temperature = args.pop("temperature")
+     if (increment := args.pop("temperature_increment_on_fallback")) is not None:
+         temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
+     else:
+         temperature = [temperature]
+
+     if (threads := args.pop("threads")) > 0:
+         torch.set_num_threads(threads)
+
+     from . import load_model
+
+     model = load_model(model_name, device=device, download_root=model_dir)
+
+     writer = get_writer(output_format, output_dir)
+     word_options = ["highlight_words", "max_line_count", "max_line_width"]
+     if not args["word_timestamps"]:
+         for option in word_options:
+             if args[option]:
+                 parser.error(f"--{option} requires --word_timestamps True")
+     if args["max_line_count"] and not args["max_line_width"]:
+         warnings.warn("--max_line_count has no effect without --max_line_width")
+     writer_args = {arg: args.pop(arg) for arg in word_options}
+     for audio_path in args.pop("audio"):
+         result = transcribe(model, audio_path, temperature=temperature, **args)
+         writer(result, audio_path, writer_args)
+
+
+ if __name__ == "__main__":
+     cli()
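
For orientation, a minimal sketch of driving this module's transcribe() programmatically instead of through cli(). The import path and sample.wav are assumptions for illustration, not part of this commit:

    # Illustrative sketch: load a model and transcribe one file with a
    # temperature-fallback schedule similar to what cli() builds.
    from simul_whisper.whisper import load_model            # assumed package path
    from simul_whisper.whisper.transcribe import transcribe

    model = load_model("small", device="cpu")
    result = transcribe(
        model,
        "sample.wav",                 # hypothetical input file
        temperature=(0.0, 0.2, 0.4),  # fallback schedule
        fp16=False,                   # required on CPU
    )
    for seg in result["segments"]:
        print(f"[{seg['start']:.2f} --> {seg['end']:.2f}]{seg['text']}")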
simul_whisper/whisper/triton_ops.py ADDED
@@ -0,0 +1,109 @@
+ from functools import lru_cache
+
+ import numpy as np
+ import torch
+
+ try:
+     import triton
+     import triton.language as tl
+ except ImportError:
+     raise RuntimeError("triton import failed; try `pip install --pre triton`")
+
+
+ @triton.jit
+ def dtw_kernel(
+     cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr
+ ):
+     offsets = tl.arange(0, BLOCK_SIZE)
+     mask = offsets < M
+
+     for k in range(1, N + M + 1):  # k = i + j
+         tl.debug_barrier()
+
+         p0 = cost + (k - 1) * cost_stride
+         p1 = cost + k * cost_stride
+         p2 = cost + k * cost_stride + 1
+
+         c0 = tl.load(p0 + offsets, mask=mask)
+         c1 = tl.load(p1 + offsets, mask=mask)
+         c2 = tl.load(p2 + offsets, mask=mask)
+
+         x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0)
+         cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)
+
+         cost_ptr = cost + (k + 1) * cost_stride + 1
+         tl.store(cost_ptr + offsets, cost_row, mask=mask)
+
+         trace_ptr = trace + (k + 1) * trace_stride + 1
+         tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1))
+         tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2))
+         tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2))
+
+
+ @lru_cache(maxsize=None)
+ def median_kernel(filter_width: int):
+     @triton.jit
+     def kernel(
+         y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr
+     ):  # x.shape[-1] == filter_width
+         row_idx = tl.program_id(0)
+         offsets = tl.arange(0, BLOCK_SIZE)
+         mask = offsets < y_stride
+
+         x_ptr = x + row_idx * x_stride  # noqa: F841
+         y_ptr = y + row_idx * y_stride
+
+         LOAD_ALL_ROWS_HERE  # noqa: F821
+
+         BUBBLESORT_HERE  # noqa: F821
+
+         tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask)  # noqa: F821
+
+     kernel = triton.JITFunction(kernel.fn)
+     kernel.src = kernel.src.replace(
+         "    LOAD_ALL_ROWS_HERE",
+         "\n".join(
+             [
+                 f"    row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
+                 for i in range(filter_width)
+             ]
+         ),
+     )
+     kernel.src = kernel.src.replace(
+         "    BUBBLESORT_HERE",
+         "\n\n".join(
+             [
+                 "\n\n".join(
+                     [
+                         "\n".join(
+                             [
+                                 f"    smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
+                                 f"    larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
+                                 f"    row{j} = smaller",
+                                 f"    row{j + 1} = larger",
+                             ]
+                         )
+                         for j in range(filter_width - i - 1)
+                     ]
+                 )
+                 for i in range(filter_width // 2 + 1)
+             ]
+         ),
+     )
+     kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
+
+     return kernel
+
+
+ def median_filter_cuda(x: torch.Tensor, filter_width: int):
+     """Apply a median filter of given width along the last dimension of x"""
+     slices = x.contiguous().unfold(-1, filter_width, 1)
+     grid = np.prod(slices.shape[:-2])
+
+     kernel = median_kernel(filter_width)
+     y = torch.empty_like(slices[..., 0])
+
+     BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length()
+     kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE)
+
+     return y
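
A small sketch of what median_filter_cuda computes, assuming a CUDA device and a working triton install (shapes are illustrative; in Whisper's timing code the input is typically padded by the caller first):

    # Illustrative sketch: sliding-window median over the last dimension.
    # unfold(-1, filter_width, 1) inside the function yields one window per
    # output element, so the result has filter_width - 1 fewer columns.
    import torch
    from simul_whisper.whisper.triton_ops import median_filter_cuda  # assumed path

    x = torch.randn(4, 1500, device="cuda")
    y = median_filter_cuda(x, filter_width=7)  # shape (4, 1494)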
simul_whisper/whisper/utils.py ADDED
@@ -0,0 +1,258 @@
+ import json
+ import os
+ import re
+ import sys
+ import zlib
+ from typing import Callable, Optional, TextIO
+
+ system_encoding = sys.getdefaultencoding()
+
+ if system_encoding != "utf-8":
+
+     def make_safe(string):
+         # replaces any character not representable using the system default encoding with an '?',
+         # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
+         return string.encode(system_encoding, errors="replace").decode(system_encoding)
+
+ else:
+
+     def make_safe(string):
+         # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
+         return string
+
+
+ def exact_div(x, y):
+     assert x % y == 0
+     return x // y
+
+
+ def str2bool(string):
+     str2val = {"True": True, "False": False}
+     if string in str2val:
+         return str2val[string]
+     else:
+         raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
+
+
+ def optional_int(string):
+     return None if string == "None" else int(string)
+
+
+ def optional_float(string):
+     return None if string == "None" else float(string)
+
+
+ def compression_ratio(text) -> float:
+     text_bytes = text.encode("utf-8")
+     return len(text_bytes) / len(zlib.compress(text_bytes))
+
+
+ def format_timestamp(
+     seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
+ ):
+     assert seconds >= 0, "non-negative timestamp expected"
+     milliseconds = round(seconds * 1000.0)
+
+     hours = milliseconds // 3_600_000
+     milliseconds -= hours * 3_600_000
+
+     minutes = milliseconds // 60_000
+     milliseconds -= minutes * 60_000
+
+     seconds = milliseconds // 1_000
+     milliseconds -= seconds * 1_000
+
+     hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
+     return (
+         f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
+     )
+
+
+ class ResultWriter:
+     extension: str
+
+     def __init__(self, output_dir: str):
+         self.output_dir = output_dir
+
+     def __call__(self, result: dict, audio_path: str, options: dict):
+         audio_basename = os.path.basename(audio_path)
+         audio_basename = os.path.splitext(audio_basename)[0]
+         output_path = os.path.join(
+             self.output_dir, audio_basename + "." + self.extension
+         )
+
+         with open(output_path, "w", encoding="utf-8") as f:
+             self.write_result(result, file=f, options=options)
+
+     def write_result(self, result: dict, file: TextIO, options: dict):
+         raise NotImplementedError
+
+
+ class WriteTXT(ResultWriter):
+     extension: str = "txt"
+
+     def write_result(self, result: dict, file: TextIO, options: dict):
+         for segment in result["segments"]:
+             print(segment["text"].strip(), file=file, flush=True)
+
+
+ class SubtitlesWriter(ResultWriter):
+     always_include_hours: bool
+     decimal_marker: str
+
+     def iterate_result(self, result: dict, options: dict):
+         raw_max_line_width: Optional[int] = options["max_line_width"]
+         max_line_count: Optional[int] = options["max_line_count"]
+         highlight_words: bool = options["highlight_words"]
+         max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
+         preserve_segments = max_line_count is None or raw_max_line_width is None
+
+         def iterate_subtitles():
+             line_len = 0
+             line_count = 1
+             # the next subtitle to yield (a list of word timings with whitespace)
+             subtitle: list[dict] = []
+             last = result["segments"][0]["words"][0]["start"]
+             for segment in result["segments"]:
+                 for i, original_timing in enumerate(segment["words"]):
+                     timing = original_timing.copy()
+                     long_pause = not preserve_segments and timing["start"] - last > 3.0
+                     has_room = line_len + len(timing["word"]) <= max_line_width
+                     seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
+                     if line_len > 0 and has_room and not long_pause and not seg_break:
+                         # line continuation
+                         line_len += len(timing["word"])
+                     else:
+                         # new line
+                         timing["word"] = timing["word"].strip()
+                         if (
+                             len(subtitle) > 0
+                             and max_line_count is not None
+                             and (long_pause or line_count >= max_line_count)
+                             or seg_break
+                         ):
+                             # subtitle break
+                             yield subtitle
+                             subtitle = []
+                             line_count = 1
+                         elif line_len > 0:
+                             # line break
+                             line_count += 1
+                             timing["word"] = "\n" + timing["word"]
+                         line_len = len(timing["word"].strip())
+                     subtitle.append(timing)
+                     last = timing["start"]
+             if len(subtitle) > 0:
+                 yield subtitle
+
+         if len(result["segments"]) > 0 and "words" in result["segments"][0]:
+             for subtitle in iterate_subtitles():
+                 subtitle_start = self.format_timestamp(subtitle[0]["start"])
+                 subtitle_end = self.format_timestamp(subtitle[-1]["end"])
+                 subtitle_text = "".join([word["word"] for word in subtitle])
+                 if highlight_words:
+                     last = subtitle_start
+                     all_words = [timing["word"] for timing in subtitle]
+                     for i, this_word in enumerate(subtitle):
+                         start = self.format_timestamp(this_word["start"])
+                         end = self.format_timestamp(this_word["end"])
+                         if last != start:
+                             yield last, start, subtitle_text
+
+                         yield start, end, "".join(
+                             [
+                                 re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
+                                 if j == i
+                                 else word
+                                 for j, word in enumerate(all_words)
+                             ]
+                         )
+                         last = end
+                 else:
+                     yield subtitle_start, subtitle_end, subtitle_text
+         else:
+             for segment in result["segments"]:
+                 segment_start = self.format_timestamp(segment["start"])
+                 segment_end = self.format_timestamp(segment["end"])
+                 segment_text = segment["text"].strip().replace("-->", "->")
+                 yield segment_start, segment_end, segment_text
+
+     def format_timestamp(self, seconds: float):
+         return format_timestamp(
+             seconds=seconds,
+             always_include_hours=self.always_include_hours,
+             decimal_marker=self.decimal_marker,
+         )
+
+
+ class WriteVTT(SubtitlesWriter):
+     extension: str = "vtt"
+     always_include_hours: bool = False
+     decimal_marker: str = "."
+
+     def write_result(self, result: dict, file: TextIO, options: dict):
+         print("WEBVTT\n", file=file)
+         for start, end, text in self.iterate_result(result, options):
+             print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
+
+
+ class WriteSRT(SubtitlesWriter):
+     extension: str = "srt"
+     always_include_hours: bool = True
+     decimal_marker: str = ","
+
+     def write_result(self, result: dict, file: TextIO, options: dict):
+         for i, (start, end, text) in enumerate(
+             self.iterate_result(result, options), start=1
+         ):
+             print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
+
+
+ class WriteTSV(ResultWriter):
+     """
+     Write a transcript to a file in TSV (tab-separated values) format containing lines like:
+     <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
+
+     Using integer milliseconds as start and end times means there's no chance of interference from
+     an environment setting a language encoding that causes the decimal in a floating point number
+     to appear as a comma; it is also faster and more efficient to parse & store, e.g., in C++.
+     """
+
+     extension: str = "tsv"
+
+     def write_result(self, result: dict, file: TextIO, options: dict):
+         print("start", "end", "text", sep="\t", file=file)
+         for segment in result["segments"]:
+             print(round(1000 * segment["start"]), file=file, end="\t")
+             print(round(1000 * segment["end"]), file=file, end="\t")
+             print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
+
+
+ class WriteJSON(ResultWriter):
+     extension: str = "json"
+
+     def write_result(self, result: dict, file: TextIO, options: dict):
+         json.dump(result, file)
+
+
+ def get_writer(
+     output_format: str, output_dir: str
+ ) -> Callable[[dict, TextIO, dict], None]:
+     writers = {
+         "txt": WriteTXT,
+         "vtt": WriteVTT,
+         "srt": WriteSRT,
+         "tsv": WriteTSV,
+         "json": WriteJSON,
+     }
+
+     if output_format == "all":
+         all_writers = [writer(output_dir) for writer in writers.values()]
+
+         def write_all(result: dict, file: TextIO, options: dict):
+             for writer in all_writers:
+                 writer(result, file, options)
+
+         return write_all
+
+     return writers[output_format](output_dir)
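
A short sketch of the writer API above. The result dict mirrors what transcribe() returns, and the word-level options must be present in the options dict even when unused (file names are assumptions):

    # Illustrative sketch: write one result as SRT; get_writer("all", ...)
    # would produce every format instead.
    from simul_whisper.whisper.utils import get_writer  # assumed path

    result = {"segments": [{"start": 0.0, "end": 1.5, "text": " Hello world"}]}
    writer = get_writer("srt", output_dir=".")
    options = {"highlight_words": False, "max_line_count": None, "max_line_width": None}
    writer(result, "sample.wav", options)  # writes ./sample.srt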
simul_whisper/whisper/version.py ADDED
@@ -0,0 +1 @@
+ __version__ = "20230918"
simulstreaming_whisper.py ADDED
@@ -0,0 +1,260 @@
+ from whisper_streaming.base import OnlineProcessorInterface, ASRBase
+ import argparse
+
+ import sys
+ import logging
+ import torch
+
+ from simul_whisper.config import AlignAttConfig
+ from simul_whisper.simul_whisper import PaddedAlignAttWhisper
+
+ logger = logging.getLogger(__name__)
+
+ def simulwhisper_args(parser):
+     group = parser.add_argument_group('Whisper arguments')
+     group.add_argument('--model_path', type=str, default='./large-v3.pt',
+                        help='The file path to the Whisper .pt model. If not present on the filesystem, the model is downloaded automatically.')
+     group.add_argument("--beams", "-b", type=int, default=1, help="Number of beams for beam search decoding. If 1, GreedyDecoder is used.")
+     group.add_argument("--decoder", type=str, default=None, help="Override automatic selection of beam or greedy decoder. "
+                        "If beams > 1 and greedy: invalid.")
+
+     group = parser.add_argument_group('Audio buffer')
+     group.add_argument('--audio_max_len', type=float, default=30.0,
+                        help='Max length of the audio buffer, in seconds.')
+     group.add_argument('--audio_min_len', type=float, default=0.0,
+                        help='Skip processing if the audio buffer is shorter than this length, in seconds. Useful when the --min-chunk-size is small.')
+
+     group = parser.add_argument_group('AlignAtt argument')
+     group.add_argument('--frame_threshold', type=int, default=25,
+                        help='Threshold for the attention-guided decoding. The AlignAtt policy will decode only '
+                             'until this number of frames from the end of audio. In frames: one frame is 0.02 seconds for the large-v3 model.')
+
+     group = parser.add_argument_group('Truncation of the last decoded word (from Simul-Whisper)')
+     group.add_argument('--cif_ckpt_path', type=str, default=None,
+                        help='The file path to the Simul-Whisper\'s CIF model checkpoint that detects whether there is '
+                             'an end of word at the end of the chunk. If not, the last decoded space-separated word is truncated, '
+                             'because it is often wrong -- transcribing a word in the middle. '
+                             'The CIF model adapted for the Whisper model version should be used. '
+                             'Find the models in https://github.com/backspacetg/simul_whisper/tree/main/cif_models . '
+                             'Note that there is no model for large-v3.')
+     group.add_argument("--never_fire", action=argparse.BooleanOptionalAction, default=False,
+                        help="Override the CIF model. If True, the last word is NEVER truncated, no matter what the CIF model detects. "
+                             "If False: if the CIF model path is set, the last word is SOMETIMES truncated, depending on the CIF detection. "
+                             "Otherwise, if the CIF model path is not set, the last word is ALWAYS trimmed.")
+
+     group = parser.add_argument_group("Prompt and context")
+     group.add_argument("--init_prompt", type=str, default=None, help="Init prompt for the model. It should be in the target language.")
+     group.add_argument("--static_init_prompt", type=str, default=None, help="Do not scroll over this text. It can contain terminology that should stay relevant over the whole document.")
+     group.add_argument("--max_context_tokens", type=int, default=None, help="Max context tokens for the model. Default is None.")
+
+
+ def simul_asr_factory(args):
+     logger.setLevel(args.log_level)
+     decoder = args.decoder
+     if args.beams > 1:
+         if decoder == "greedy":
+             raise ValueError("Invalid 'greedy' decoder type for beams > 1. Use 'beam'.")
+         elif decoder is None or decoder == "beam":
+             decoder = "beam"
+         else:
+             raise ValueError("Invalid decoder type. Use 'beam' or 'greedy'.")
+     else:
+         if decoder is None:
+             decoder = "greedy"
+         elif decoder not in ("beam", "greedy"):
+             raise ValueError("Invalid decoder type. Use 'beam' or 'greedy'.")
+         # else: it is greedy or beam, that's ok
+
+     a = {v: getattr(args, v) for v in ["model_path", "cif_ckpt_path", "frame_threshold", "audio_min_len", "audio_max_len", "beams", "task",
+                                        "never_fire", 'init_prompt', 'static_init_prompt', 'max_context_tokens', "logdir"
+                                        ]}
+     a["language"] = args.lan
+     a["segment_length"] = args.min_chunk_size
+     a["decoder_type"] = decoder
+
+     if args.min_chunk_size >= args.audio_max_len:
+         raise ValueError("min_chunk_size must be smaller than audio_max_len")
+     if args.audio_min_len > args.audio_max_len:
+         raise ValueError("audio_min_len must be smaller than audio_max_len")
+     logger.info(f"Arguments: {a}")
+     asr = SimulWhisperASR(**a)
+     return asr, SimulWhisperOnline(asr)
+
+ class SimulWhisperASR(ASRBase):
+
+     sep = " "
+
+     def __init__(self, language, model_path, cif_ckpt_path, frame_threshold, audio_max_len, audio_min_len, segment_length, beams, task,
+                  decoder_type, never_fire, init_prompt, static_init_prompt, max_context_tokens, logdir):
+         cfg = AlignAttConfig(
+             model_path=model_path,
+             segment_length=segment_length,
+             frame_threshold=frame_threshold,
+             language=language,
+             audio_max_len=audio_max_len,
+             audio_min_len=audio_min_len,
+             cif_ckpt_path=cif_ckpt_path,
+             decoder_type=decoder_type,  # "greedy" if beams == 1 else "beam"
+             beam_size=beams,
+             task=task,
+             never_fire=never_fire,
+             init_prompt=init_prompt,
+             max_context_tokens=max_context_tokens,
+             static_init_prompt=static_init_prompt,
+             logdir=logdir,
+         )
+         logger.info(f"Language: {language}")
+         self.model = PaddedAlignAttWhisper(cfg)
+
+     def transcribe(self, audio, init_prompt=""):
+         logger.info("SimulWhisperASR's transcribe() should not be used. It's here only temporarily. "
+                     "Instead, use SimulWhisperOnline.process_iter().")
+         raise NotImplementedError("Use SimulWhisperOnline.process_iter() instead of transcribe().")
+
+     def warmup(self, audio, init_prompt=""):
+         self.model.insert_audio(audio)
+         self.model.infer(True)
+         self.model.refresh_segment(complete=True)
+
+     def use_vad(self):
+         print("VAD not implemented", file=sys.stderr)
+
+     def set_translate_task(self):
+         # this is not used. The translate task is set another way.
+         pass
+
+
+ class SimulWhisperOnline(OnlineProcessorInterface):
+
+     def __init__(self, asr):
+         self.model = asr.model
+         self.file = None
+         self.init()
+
+     def init(self, offset=None):
+         self.audio_chunks = []
+         if offset is not None:
+             self.offset = offset
+         else:
+             self.offset = 0
+         self.is_last = False
+         self.beg = self.offset
+         self.end = self.offset
+
+         self.audio_bufer_offset = self.offset
+         self.last_ts = -1
+         self.model.refresh_segment(complete=True)
+
+         self.unicode_buffer = []  # hide an incomplete unicode character for the next iteration
+
+     def insert_audio_chunk(self, audio):
+         self.audio_chunks.append(torch.from_numpy(audio))
+
+     def timestamped_text(self, tokens, generation):
+         if not generation:
+             return []
+
+         pr = generation["progress"]
+         if "result" not in generation or self.unicode_buffer != []:
+             split_words, split_tokens = self.model.tokenizer.split_to_word_tokens(tokens)
+         else:
+             split_words, split_tokens = generation["result"]["split_words"], generation["result"]["split_tokens"]
+
+         frames = [p["most_attended_frames"][0] for p in pr]
+         if frames and self.unicode_buffer != []:
+             a = [frames[0]] * len(self.unicode_buffer)
+             frames = a + frames
+
+         tokens = tokens.copy()
+         ret = []
+         for sw, st in zip(split_words, split_tokens):
+             b = None
+             for stt in st:
+                 t, f = tokens.pop(0), frames.pop(0)
+                 if t != stt:
+                     raise ValueError(f"Token mismatch: {t} != {stt} at frame {f}.")
+                 if b is None:
+                     b = f
+                 e = f
+             out = {
+                 'start': b * 0.02 + self.audio_bufer_offset,
+                 'end': e * 0.02 + self.audio_bufer_offset,
+                 'text': sw,
+                 'tokens': st
+             }
+             ret.append(out)
+             logger.debug(f"TS-WORD-INFO: {out}")
+         return ret
+
+     def hide_incomplete_unicode(self, tokens):
+         """Sometimes, the last token is an incomplete unicode character, e.g. a part of "ň" or "ř".
+         Without this, the outputs can end with '�' = Unicode Replacement Character, and the next output also
+         starts with '�'.
+         This function hides the last incomplete unicode character and adds it in the next iteration.
+         """
+         if self.unicode_buffer != []:
+             logger.debug(f"Hiding incomplete unicode character: {self.unicode_buffer}")
+             tokens = self.unicode_buffer + tokens
+             self.unicode_buffer = []  # clear the buffer after processing
+         chars, _ = self.model.tokenizer.split_tokens_on_unicode(tokens)
+         if len(chars) > 0 and chars[-1].endswith('�'):
+             self.unicode_buffer = tokens[-1:]  # keep the last incomplete unicode character
+             logger.debug(f"Hiding incomplete unicode character: {tokens[-1:]}")
+             return tokens[:-1]  # remove the last token, which is an incomplete unicode character
+         return tokens
+
+     def process_iter(self):
+         if len(self.audio_chunks) == 0:
+             audio = None
+         else:
+             audio = torch.cat(self.audio_chunks, dim=0)
+             if audio.shape[0] == 0:
+                 audio = None
+             else:
+                 self.end += audio.shape[0] / self.SAMPLING_RATE
+         self.audio_chunks = []
+         self.audio_bufer_offset += self.model.insert_audio(audio)
+         tokens, generation_progress = self.model.infer(is_last=self.is_last)
+
+         tokens = self.hide_incomplete_unicode(tokens)
+
+         text = self.model.tokenizer.decode(tokens)
+         if len(text) == 0:
+             return {}
+
+         # word-level timestamps
+         ts_words = self.timestamped_text(tokens, generation_progress)
+
+         self.beg = min(word['start'] for word in ts_words)  # it should be this
+         self.beg = max(self.beg, self.last_ts + 0.001)  # but keep the timestamps non-decreasing -- at least the last beg + 1 ms
+         if self.is_last:
+             e = self.end
+         else:
+             e = max(word['end'] for word in ts_words)
+             e = max(e, self.beg + 0.001)
+
+         self.last_ts = e
+
+         # return (self.beg, e, text)
+         return {
+             'start': self.beg,
+             'end': e,
+             'text': text,
+             'tokens': tokens,
+             'words': ts_words
+         }
+
+     def finish(self):
+         logger.info("Finish")
+         self.is_last = True
+         o = self.process_iter()
+         self.is_last = False
+         self.model.refresh_segment(complete=True)
+         return o
+
+
+ if __name__ == "__main__":
+
+     from whisper_streaming.whisper_online_main import main_simulation_from_file
+     main_simulation_from_file(simul_asr_factory, add_args=simulwhisper_args)
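
A rough sketch of how the factory and online processor above are driven; the args object and audio array are assumptions based on simul_asr_factory's signature (in practice whisper_online_main wires this up):

    # Illustrative sketch: stream a 16 kHz float32 waveform in 1-second
    # chunks through the online processor, then flush with finish().
    import numpy as np

    asr, online = simul_asr_factory(args)  # args as parsed with simulwhisper_args
    for i in range(0, len(audio), 16000):
        online.insert_audio_chunk(audio[i:i + 16000].astype(np.float32))
        out = online.process_iter()
        if out:
            print(out["start"], out["end"], out["text"])
    print(online.finish())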
token_buffer.py ADDED
@@ -0,0 +1,73 @@
+ import torch
+ import sys
+ class TokenBuffer:
+
+     def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=[]):
+         self.text = text
+         self.prefix_token_ids = prefix_token_ids
+         self.tokenizer = tokenizer
+         self.device = device
+
+     def as_token_ids(self, tokenizer=None):
+
+         if tokenizer is None:
+             tokenizer = self.tokenizer
+         if tokenizer is None:
+             raise ValueError("Tokenizer is not set.")
+         return self.prefix_token_ids + tokenizer.encode(self.text)
+
+     def as_tensor(self, device=None):
+         if device is None:
+             device = self.device
+         if device is None:
+             raise ValueError("Device is not set.")
+         tok_ids = self.as_token_ids()
+         return torch.tensor(tok_ids,
+                             dtype=torch.long, device=device).unsqueeze(0)
+
+     def as_tensor_beam(self, beam, device=None):
+         t = self.as_tensor(device=device)
+         return t.repeat_interleave(beam, dim=0)
+
+
+     def as_text(self):
+         return self.text
+
+     @staticmethod
+     def empty(*a, **kw):
+         return TokenBuffer(*a, **kw)
+
+     @staticmethod
+     def from_text(text, *a, **kw):
+         return TokenBuffer(*a, text=text, **kw)
+
+     def is_empty(self):
+         return self.text is None or self.text == ""
+
+     def trim_words(self, num=1, after=0):
+         '''
+         num: how many words to trim from the beginning
+         after: how many characters to skip (length of the static prompt)
+         '''
+         tokenizer = self.tokenizer
+         assert tokenizer is not None, "Tokenizer is not set."
+
+         ids = tokenizer.encode(self.text[after:])
+         words, wids = self.tokenizer.split_to_word_tokens(ids)
+         # print(words, file=sys.stderr)
+         # print(wids, file=sys.stderr)
+         if not words:
+             return 0
+         self.text = self.text[:after] + "".join(words[num:])
+         return sum(len(wi) for wi in wids[:num])
+
+     def append_token_ids(self, token_ids):
+         tokenizer = self.tokenizer
+         assert tokenizer is not None, "Tokenizer is not set."
+         self.text += self.tokenizer.decode(token_ids)
+
+     def as_split_word_tokens(self):
+         tokenizer = self.tokenizer
+         assert tokenizer is not None, "Tokenizer is not set."
+         ids = tokenizer.encode(self.text)
+         return tokenizer.split_to_word_tokens(ids)
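
A small sketch of the intended TokenBuffer round trip; the tokenizer is assumed to be the Whisper tokenizer used elsewhere in this commit, which reconstructs leading spaces on decode:

    # Illustrative sketch: keep a rolling text prompt and trim whole words
    # from its beginning when the context grows too long.
    buf = TokenBuffer.from_text(" one two three", tokenizer=tokenizer, device="cpu")
    ids = buf.as_token_ids()   # prefix ids + encoded text
    n = buf.trim_words(num=1)  # drops the first word, returns the number of tokens removed
    print(buf.as_text())       # roughly " two three"; the exact split is tokenizer-dependent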
whisper_streaming/base.py ADDED
@@ -0,0 +1,46 @@
+ import sys
+ class ASRBase:
+
+     sep = " "  # join transcribed words with this character (" " for whisper_timestamped,
+                # "" for faster-whisper because it emits the spaces when needed)
+
+     def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
+         self.logfile = logfile
+
+         self.transcribe_kargs = {}
+         if lan == "auto":
+             self.original_language = None
+         else:
+             self.original_language = lan
+
+         self.model = self.load_model(modelsize, cache_dir, model_dir)
+
+
+     def load_model(self, modelsize, cache_dir, model_dir):
+         raise NotImplementedError("must be implemented in the child class")
+
+     def transcribe(self, audio, init_prompt=""):
+         raise NotImplementedError("must be implemented in the child class")
+
+     def warmup(self, audio, init_prompt=""):
+         return self.transcribe(audio, init_prompt)
+
+     def use_vad(self):
+         raise NotImplementedError("must be implemented in the child class")
+
+     def set_translate_task(self):
+         raise NotImplementedError("must be implemented in the child class")
+
+
+ class OnlineProcessorInterface:
+
+     SAMPLING_RATE = 16000
+
+     def insert_audio_chunk(self, audio):
+         raise NotImplementedError("must be implemented in child class")
+
+     def process_iter(self):
+         raise NotImplementedError("must be implemented in child class")
+
+     def finish(self):
+         raise NotImplementedError("must be implemented in child class")
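
The interface above is deliberately minimal; a sketch of the smallest conforming processor, purely illustrative:

    # Illustrative sketch: a processor that satisfies
    # OnlineProcessorInterface by only counting buffered audio.
    class CountingProcessor(OnlineProcessorInterface):
        def __init__(self):
            self.n = 0

        def insert_audio_chunk(self, audio):
            self.n += len(audio)

        def process_iter(self):
            return {"seconds": self.n / self.SAMPLING_RATE}

        def finish(self):
            return self.process_iter()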
whisper_streaming/line_packet.py ADDED
@@ -0,0 +1,93 @@
+ #!/usr/bin/env python3
+
+ """Functions for sending and receiving individual lines of text over a socket.
+
+ A line is transmitted using one or more fixed-size packets of UTF-8 bytes
+ containing:
+
+   - Zero or more bytes of UTF-8, excluding \n and \0, followed by
+
+   - Zero or more \0 bytes as required to pad the packet to PACKET_SIZE
+
+ Originally from the UEDIN team of the ELITR project.
+ """
+
+ PACKET_SIZE = 65536
+
+
+ def send_one_line(socket, text, pad_zeros=False):
+     """Sends a line of text over the given socket.
+
+     The 'text' argument should contain a single line of text (line break
+     characters are optional). Line boundaries are determined by Python's
+     str.splitlines() function [1]. We also count '\0' as a line terminator.
+     If 'text' contains multiple lines then only the first will be sent.
+
+     If the send fails then an exception will be raised.
+
+     [1] https://docs.python.org/3.5/library/stdtypes.html#str.splitlines
+
+     Args:
+         socket: a socket object.
+         text: string containing a line of text for transmission.
+     """
+     text = text.replace('\0', '\n')
+     lines = text.splitlines()
+     first_line = '' if len(lines) == 0 else lines[0]
+     # TODO Is there a better way of handling bad input than 'replace'?
+     data = first_line.encode('utf-8', errors='replace') + b'\n' + (b'\0' if pad_zeros else b'')
+     for offset in range(0, len(data), PACKET_SIZE):
+         bytes_remaining = len(data) - offset
+         if bytes_remaining < PACKET_SIZE:
+             padding_length = PACKET_SIZE - bytes_remaining
+             packet = data[offset:] + (b'\0' * padding_length if pad_zeros else b'')
+         else:
+             packet = data[offset:offset + PACKET_SIZE]
+         socket.sendall(packet)
+
+
+ def receive_one_line(socket):
+     """Receives a line of text from the given socket.
+
+     This function will (attempt to) receive a single line of text. If data is
+     currently unavailable then it will block until data becomes available or
+     the sender has closed the connection (in which case it will return an
+     empty string).
+
+     The string should not contain any newline characters, but if it does then
+     only the first line will be returned.
+
+     Args:
+         socket: a socket object.
+
+     Returns:
+         A string representing a single line with a terminating newline or
+         None if the connection has been closed.
+     """
+     data = b''
+     while True:
+         packet = socket.recv(PACKET_SIZE)
+         if not packet:  # Connection has been closed.
+             return None
+         data += packet
+         if b'\0' in packet:
+             break
+     # TODO Is there a better way of handling bad input than 'replace'?
+     text = data.decode('utf-8', errors='replace').strip('\0')
+     lines = text.split('\n')
+     return lines[0] + '\n'
+
+
+ def receive_lines(socket):
+     try:
+         data = socket.recv(PACKET_SIZE)
+     except BlockingIOError:
+         return []
+     if not data:  # Connection has been closed (recv returns b'', not None).
+         return None
+     # TODO Is there a better way of handling bad input than 'replace'?
+     text = data.decode('utf-8', errors='replace').strip('\0')
+     lines = text.split('\n')
+     if len(lines) == 1 and not lines[0]:
+         return None
+     return lines
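
A quick sketch of the framing in action over a local socketpair, so it is self-contained; pad_zeros=True is needed here because receive_one_line reads until it sees a '\0':

    # Illustrative sketch: round-trip one line through the fixed-size
    # packet protocol defined above.
    import socket

    a, b = socket.socketpair()
    send_one_line(a, "hello world", pad_zeros=True)  # one zero-padded 65536-byte packet
    print(receive_one_line(b))                       # -> "hello world\n"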
whisper_streaming/silero_vad_iterator.py ADDED
@@ -0,0 +1,150 @@
1
+ import torch
2
+
3
+ # This is copied from silero-vad's vad_utils.py:
4
+ # https://github.com/snakers4/silero-vad/blob/94811cbe1207ec24bc0f5370b895364b8934936f/src/silero_vad/utils_vad.py#L398C1-L489C20
5
+ # (except changed defaults)
6
+
7
+ # Their licence is MIT, same as ours: https://github.com/snakers4/silero-vad/blob/94811cbe1207ec24bc0f5370b895364b8934936f/LICENSE
8
+
9
+ class VADIterator:
10
+ def __init__(self,
11
+ model,
12
+ threshold: float = 0.5,
13
+ sampling_rate: int = 16000,
14
+ min_silence_duration_ms: int = 500, # makes sense on one recording that I checked
15
+ speech_pad_ms: int = 100 # same
16
+ ):
17
+
18
+ """
19
+ Class for stream imitation
20
+
21
+ Parameters
22
+ ----------
23
+ model: preloaded .jit/.onnx silero VAD model
24
+
25
+ threshold: float (default - 0.5)
26
+ Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
27
+ It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
28
+
29
+ sampling_rate: int (default - 16000)
30
+ Currently silero VAD models support 8000 and 16000 sample rates
31
+
32
+ min_silence_duration_ms: int (default - 100 milliseconds)
33
+ In the end of each speech chunk wait for min_silence_duration_ms before separating it
34
+
35
+ speech_pad_ms: int (default - 30 milliseconds)
36
+ Final speech chunks are padded by speech_pad_ms each side
37
+ """
38
+
39
+ self.model = model
40
+ self.threshold = threshold
41
+ self.sampling_rate = sampling_rate
42
+
43
+ if sampling_rate not in [8000, 16000]:
44
+ raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
45
+
46
+ self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
47
+ self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
48
+ self.reset_states()
49
+
50
+ def reset_states(self):
51
+
52
+ self.model.reset_states()
53
+ self.triggered = False
54
+ self.temp_end = 0
55
+ self.current_sample = 0
56
+
57
+ @torch.no_grad()
58
+ def __call__(self, x, return_seconds=False, time_resolution: int = 1):
59
+ """
60
+ x: torch.Tensor
61
+ audio chunk (see examples in repo)
62
+
63
+ return_seconds: bool (default - False)
64
+ whether return timestamps in seconds (default - samples)
65
+
66
+ time_resolution: int (default - 1)
67
+ time resolution of speech coordinates when requested as seconds
68
+ """
69
+
70
+ if not torch.is_tensor(x):
71
+ try:
72
+ x = torch.Tensor(x)
73
+ except:
74
+ raise TypeError("Audio cannot be casted to tensor. Cast it manually")
75
+
76
+ window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
77
+ self.current_sample += window_size_samples
78
+
79
+ speech_prob = self.model(x, self.sampling_rate).item()
80
+
81
+ if (speech_prob >= self.threshold) and self.temp_end:
82
+ self.temp_end = 0
83
+
84
+ if (speech_prob >= self.threshold) and not self.triggered:
85
+ self.triggered = True
86
+ speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
87
+ return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, time_resolution)}
88
+
89
+ if (speech_prob < self.threshold - 0.15) and self.triggered:
90
+ if not self.temp_end:
91
+ self.temp_end = self.current_sample
92
+ if self.current_sample - self.temp_end < self.min_silence_samples:
93
+ return None
94
+ else:
95
+ speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
96
+ self.temp_end = 0
97
+ self.triggered = False
98
+ return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, time_resolution)}
99
+
100
+ return None
101
+
102
+ #######################
103
+ # because Silero now requires exactly 512-sized audio chunks
104
+
105
+ import numpy as np
106
+ class FixedVADIterator(VADIterator):
107
+ '''It fixes VADIterator by allowing to process any audio length, not only exactly 512 frames at once.
108
+ If audio to be processed at once is long and multiple voiced segments detected,
109
+ then __call__ returns the start of the first segment, and end (or middle, which means no end) of the last segment.
110
+ '''
111
+
112
+ def reset_states(self):
113
+ super().reset_states()
114
+ self.buffer = np.array([],dtype=np.float32)
115
+
116
+ def __call__(self, x, return_seconds=False):
117
+ self.buffer = np.append(self.buffer, x)
118
+ ret = None
119
+ while len(self.buffer) >= 512:
120
+ r = super().__call__(self.buffer[:512], return_seconds=return_seconds)
121
+ self.buffer = self.buffer[512:]
122
+ if ret is None:
123
+ ret = r
124
+ elif r is not None:
125
+ if 'end' in r:
126
+ ret['end'] = r['end'] # the latter end
127
+ if 'start' in r and 'end' in ret: # there is an earlier start.
128
+ # Remove end, merging this segment with the previous one.
129
+ del ret['end']
130
+ return ret if ret != {} else None
131
+
132
+ if __name__ == "__main__":
133
+ # test/demonstrate the need for FixedVADIterator:
134
+
135
+ import torch
136
+ model, _ = torch.hub.load(
137
+ repo_or_dir='snakers4/silero-vad',
138
+ model='silero_vad'
139
+ )
140
+ vac = FixedVADIterator(model)
141
+ # vac = VADIterator(model) # the second case crashes with this
142
+
143
+ # this works for both
144
+ audio_buffer = np.array([0]*(512),dtype=np.float32)
145
+ vac(audio_buffer)
146
+
147
+ # this crashes on the plain (non-fixed) VADIterator with
148
+ # ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
149
+ audio_buffer = np.array([0]*(512-1),dtype=np.float32)
150
+ vac(audio_buffer)
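As a complement to the demo above, a short sketch (not part of the commit) of what FixedVADIterator adds in practice: chunks of arbitrary, unaligned sizes are buffered internally and only complete 512-sample windows are fed to the model, so the caller's chunking does not matter.

import numpy as np
import torch

model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
vac = FixedVADIterator(model)
for size in (160, 1000, 512, 37):  # arbitrary chunk sizes; leftovers stay in vac.buffer
    event = vac(np.zeros(size, dtype=np.float32), return_seconds=True)
    if event is not None:
        print(event)  # may contain 'start', 'end', or both (merged segments)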
whisper_streaming/vac_online_processor.py ADDED
@@ -0,0 +1,111 @@
1
+ from whisper_streaming.base import OnlineProcessorInterface
2
+ from whisper_streaming.silero_vad_iterator import FixedVADIterator
3
+ import numpy as np
4
+
5
+ import logging
6
+ logger = logging.getLogger(__name__)
7
+ import sys
8
+
9
+ class VACOnlineASRProcessor(OnlineProcessorInterface):
10
+ '''Wraps OnlineASRProcessor with VAC (Voice Activity Controller).
11
+
12
+ It works the same way as OnlineASRProcessor: it receives chunks of audio (e.g. 0.04 seconds),
13
+ runs VAD on them, and continuously detects whether there is speech or not.
14
+ When it detects the end of speech (non-voice for the configured minimum silence duration), it makes OnlineASRProcessor end the utterance immediately.
15
+ '''
16
+
17
+ def __init__(self, online_chunk_size, online, min_buffered_length=1):
18
+ self.online_chunk_size = online_chunk_size
19
+ self.online = online
20
+
21
+ self.min_buffered_frames = int(min_buffered_length * self.SAMPLING_RATE)
22
+
23
+ # VAC:
24
+ import torch
25
+ model, _ = torch.hub.load(
26
+ repo_or_dir='snakers4/silero-vad',
27
+ model='silero_vad'
28
+ )
29
+ self.vac = FixedVADIterator(model) # default VADIterator options (100 ms min silence, 30 ms padding; see silero_vad_iterator.py)
30
+
31
+ self.init()
32
+
33
+ def init(self):
34
+ self.online.init()
35
+ self.vac.reset_states()
36
+ self.current_online_chunk_buffer_size = 0
37
+
38
+ self.is_currently_final = False
39
+
40
+ self.status = None # or "voice" or "nonvoice"
41
+ self.audio_buffer = np.array([],dtype=np.float32)
42
+ self.buffer_offset = 0 # in frames
43
+
44
+ def clear_buffer(self):
45
+ self.audio_buffer = np.array([],dtype=np.float32)
46
+
47
+ def insert_audio_chunk(self, audio):
48
+ res = self.vac(audio)
49
+ self.audio_buffer = np.append(self.audio_buffer, audio)
50
+ if res is not None:
51
+ frame = list(res.values())[0] - self.buffer_offset
52
+ frame = max(0, frame)
53
+ if 'start' in res and 'end' not in res:
54
+ self.status = 'voice'
55
+ send_audio = self.audio_buffer[frame:]
56
+ self.online.init(offset=(frame + self.buffer_offset)/self.SAMPLING_RATE)
57
+ self.online.insert_audio_chunk(send_audio)
58
+ self.current_online_chunk_buffer_size += len(send_audio)
59
+ self.buffer_offset += len(self.audio_buffer)
60
+ self.clear_buffer()
61
+ elif 'end' in res and 'start' not in res:
62
+ self.status = 'nonvoice'
63
+ if frame > 0:
64
+ send_audio = self.audio_buffer[:frame]
65
+ self.online.insert_audio_chunk(send_audio)
66
+ self.current_online_chunk_buffer_size += len(send_audio)
67
+ self.is_currently_final = True
68
+ keep_frames = min(len(self.audio_buffer) - frame, self.min_buffered_frames)
69
+ self.buffer_offset += len(self.audio_buffer) - keep_frames
70
+ self.audio_buffer = self.audio_buffer[-keep_frames:]
71
+ else:
72
+ beg = max(0, res["start"] - self.buffer_offset)
73
+ end = max(0, res["end"] - self.buffer_offset)
74
+ self.status = 'nonvoice'
75
+ if beg < end:
76
+ send_audio = self.audio_buffer[beg:end]
77
+ self.online.init(offset=((beg + self.buffer_offset)/self.SAMPLING_RATE))
78
+ self.online.insert_audio_chunk(send_audio)
79
+ self.current_online_chunk_buffer_size += len(send_audio)
80
+ self.is_currently_final = True
81
+ keep_frames = min(len(self.audio_buffer) - end, self.min_buffered_frames)
82
+ self.buffer_offset += len(self.audio_buffer) - keep_frames
83
+ self.audio_buffer = self.audio_buffer[-keep_frames:]
84
+ else:
85
+ if self.status == 'voice':
86
+ self.online.insert_audio_chunk(self.audio_buffer)
87
+ self.current_online_chunk_buffer_size += len(self.audio_buffer)
88
+ self.buffer_offset += len(self.audio_buffer)
89
+ self.clear_buffer()
90
+ else:
91
+ # We keep up to min_buffered_frames (1 second by default) because VAD may later find the start of voice in it,
92
+ # but we trim anything beyond that to prevent OOM.
93
+ self.buffer_offset += max(0, len(self.audio_buffer) - self.min_buffered_frames)
94
+ self.audio_buffer = self.audio_buffer[-self.min_buffered_frames:]
95
+
96
+ def process_iter(self):
97
+ if self.is_currently_final:
98
+ return self.finish()
99
+ elif self.current_online_chunk_buffer_size > self.SAMPLING_RATE*self.online_chunk_size:
100
+ self.current_online_chunk_buffer_size = 0
101
+ ret = self.online.process_iter()
102
+ return ret
103
+ else:
104
+ logger.info(f"no online update, only VAD. {self.status}")
105
+ return {}
106
+
107
+ def finish(self):
108
+ ret = self.online.finish()
109
+ self.current_online_chunk_buffer_size = 0
110
+ self.is_currently_final = False
111
+ return ret
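A hedged usage sketch of this wrapper (not part of the commit). It assumes `online` is an already-constructed object implementing the OnlineProcessorInterface methods used above (init, insert_audio_chunk, process_iter, finish) and that OnlineProcessorInterface defines SAMPLING_RATE = 16000, which __init__ relies on; read_chunk() is a hypothetical source of 0.04 s float32 audio chunks.

vac_online = VACOnlineASRProcessor(online_chunk_size=1.2, online=online)
vac_online.init()
while True:
    chunk = read_chunk()  # hypothetical: 0.04 s of 16 kHz float32 audio, or None at the end of the stream
    if chunk is None:
        break
    vac_online.insert_audio_chunk(chunk)
    out = vac_online.process_iter()  # {} while only VAD ran; otherwise a dict with 'start', 'end', 'text'
    if out:
        print(out['start'], out['end'], out['text'])
print(vac_online.finish())  # flush whatever is still buffered in the underlying processor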
whisper_streaming/whisper_online_main.py ADDED
@@ -0,0 +1,224 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # This code is adapted from the original WhisperStreaming whisper_online.py.
4
+ # It is refactored and simplified. Only the code that is needed for the
5
+ # SimulWhisper backend is kept.
6
+
7
+
8
+ import sys
9
+ import numpy as np
10
+ import librosa
11
+ from functools import lru_cache
12
+ import time
13
+ import logging
14
+
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ @lru_cache(10**6)
19
+ def load_audio(fname):
20
+ a, _ = librosa.load(fname, sr=16000, dtype=np.float32)
21
+ return a
22
+
23
+ def load_audio_chunk(fname, beg, end):
24
+ audio = load_audio(fname)
25
+ beg_s = int(beg*16000)
26
+ end_s = int(end*16000)
27
+ return audio[beg_s:end_s]
28
+
29
+ def processor_args(parser):
30
+ """shared args for the online processors
31
+ parser: argparse.ArgumentParser object
32
+ """
33
+ group = parser.add_argument_group("WhisperStreaming processor arguments (shared for simulation from file and for the server)")
34
+ group.add_argument('--min-chunk-size', type=float, default=1.2,
35
+ help='Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes a shorter '
36
+ 'time, it waits; otherwise it processes the whole segment that was received by this time.')
37
+
38
+ group.add_argument('--lan', '--language', type=str, default="en",
39
+ help="Source language code, e.g. en, de, cs, or auto for automatic language detection from speech.")
40
+ group.add_argument('--task', type=str, default='transcribe',
41
+ choices=["transcribe","translate"],
42
+ help="Transcribe or translate.")
43
+
44
+ group.add_argument('--vac', action="store_true", default=False,
45
+ help='Use VAC = voice activity controller. Recommended. Requires torch.')
46
+ group.add_argument('--vac-chunk-size', type=float, default=0.04,
47
+ help='VAC sample size in seconds.')
48
+
49
+ parser.add_argument("-l", "--log-level", dest="log_level",
50
+ choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
51
+ help="Set the log level", default='DEBUG')
52
+
53
+ parser.add_argument("--logdir", help="Directory to save audio segments and generated texts for debugging.",
54
+ default=None)
55
+
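A minimal sketch (not part of the commit) of how these shared flags are meant to be combined with the simulation arguments into one parser; simulation_args is defined further below in this file.

import argparse

parser = argparse.ArgumentParser()
processor_args(parser)   # shared flags defined above (--min-chunk-size, --lan, --task, --vac, ...)
simulation_args(parser)  # positional audio_path plus --start_at and --comp_unaware
args = parser.parse_args(['audio.wav', '--lan', 'en', '--min-chunk-size', '1.2', '--vac'])
print(args.lan, args.min_chunk_size, args.vac)  # en 1.2 True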
56
+ def asr_factory(args, factory=None):
57
+ """
58
+ Creates and configures an ASR and online processor object through the factory implemented in the backend.
59
+ """
60
+ # if backend is None:
61
+ # backend = args.backend
62
+ # if backend == "simul-whisper":
63
+ # from simul_whisper_backend import simul_asr_factory
64
+ asr, online = factory(args)
65
+
66
+ # Create the OnlineASRProcessor
67
+ if args.vac:
68
+ from whisper_streaming.vac_online_processor import VACOnlineASRProcessor
69
+ online = VACOnlineASRProcessor(args.min_chunk_size, online)
70
+
71
+ if args.task == "translate":
72
+ if args.model_path.endswith(".en.pt"):
73
+ logger.error(f"The model {args.model_path} is English only. Translation is not available. Terminating.")
74
+ sys.exit(1)
75
+ asr.set_translate_task()
76
+
77
+ return asr, online
78
+
79
+ def set_logging(args,logger):
80
+ logging.basicConfig(
81
+ # this format would include module name:
82
+ # format='%(levelname)s\t%(name)s\t%(message)s')
83
+ format='%(levelname)s\t%(message)s')
84
+ logger.setLevel(args.log_level)
85
+ logging.getLogger("simul_whisper").setLevel(args.log_level)
86
+ logging.getLogger("whisper_streaming").setLevel(args.log_level)
87
+
88
+
89
+ def simulation_args(parser):
90
+ simulation_group = parser.add_argument_group("Arguments for simulation from file")
91
+ simulation_group.add_argument('audio_path', type=str, help="Filename of 16kHz mono channel wav, on which live streaming is simulated.")
92
+ simulation_group.add_argument('--start_at', type=float, default=0.0, help='Start processing audio at this time.')
93
+ # TODO: offline mode is not implemented in SimulStreaming yet
94
+ # simulation_group.add_argument('--offline', action="store_true", default=False, help='Offline mode.')
95
+ simulation_group.add_argument('--comp_unaware', action="store_true", default=False, help='Computationally unaware simulation.')
96
+
97
+ def main_simulation_from_file(factory, add_args=None):
98
+ '''
99
+ factory: function that creates the ASR and online processor object from args and logger,
100
+ e.g. the SimulWhisper backend, or the default WhisperStreaming local agreement backends (not implemented here, but could be).
101
+ add_args: function that adds backend-specific args to the parser
102
+ '''
103
+
104
+ import argparse
105
+ parser = argparse.ArgumentParser()
106
+
107
+ processor_args(parser)
108
+ if add_args is not None:
109
+ add_args(parser)
110
+
111
+ simulation_args(parser)
112
+
113
+ args = parser.parse_args()
114
+ args.offline = False # TODO: offline mode is not implemented in SimulStreaming yet
115
+
116
+ if args.offline and args.comp_unaware:
117
+ logger.error("No or one option from --offline and --comp_unaware are available, not both. Exiting.")
118
+ sys.exit(1)
119
+
120
+ set_logging(args,logger)
121
+
122
+ audio_path = args.audio_path
123
+
124
+ SAMPLING_RATE = 16000
125
+ duration = len(load_audio(audio_path))/SAMPLING_RATE
126
+ logger.info("Audio duration is: %2.2f seconds" % duration)
127
+
128
+ asr, online = asr_factory(args, factory)
129
+ if args.vac:
130
+ min_chunk = args.vac_chunk_size
131
+ else:
132
+ min_chunk = args.min_chunk_size
133
+
134
+ # load the audio into the LRU cache before we start the timer
135
+ a = load_audio_chunk(audio_path,0,1)
136
+
137
+ # warm up the ASR because the very first transcribe takes much more time than the others
138
+ asr.warmup(a)
139
+
140
+ beg = args.start_at
141
+ start = time.time()-beg
142
+
143
+ def output_transcript(iteration_output, now=None):
144
+ # output format in stdout is like:
145
+ # 4186.3606 0 1720 Takhle to je
146
+ # - the first three words are:
147
+ # - emission time from beginning of processing, in milliseconds
148
+ # - beg and end timestamp of the text segment, as estimated by Whisper model. The timestamps are not accurate, but they're useful anyway
149
+ # - the next words: segment transcript
150
+ if now is None:
151
+ now = time.time() - start
152
+
153
+ if iteration_output:
154
+ start_ts = iteration_output['start']
155
+ end_ts = iteration_output['end']
156
+ text = iteration_output['text']
157
+ logger.debug(f"{now * 1000:.4f} {start_ts * 1000:.0f} {end_ts * 1000:.0f} {text}")
158
+ print(f"{now * 1000:.4f} {start_ts * 1000:.0f} {end_ts * 1000:.0f} {text}", flush=True)
159
+ else:
160
+ logger.debug("No text in this segment")
161
+
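The line format documented in output_transcript above is what the file simulation prints (the server in whisper_server.py emits the same fields minus the leading emission time); a small parsing sketch, not part of the commit:

line = "4186.3606 0 1720 Takhle to je"  # example line from the comment above
emission_ms, beg_ms, end_ms, text = line.split(" ", 3)
print(float(emission_ms), float(beg_ms), float(end_ms), text)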
162
+ if args.offline: ## offline mode processing (for testing/debugging)
163
+ a = load_audio(audio_path)
164
+ online.insert_audio_chunk(a)
165
+ try:
166
+ o = online.process_iter()
167
+ except AssertionError as e:
168
+ logger.error(f"assertion error: {repr(e)}")
169
+ else:
170
+ output_transcript(o)
171
+ now = None
172
+ elif args.comp_unaware: # computational unaware mode
173
+ end = beg + min_chunk
174
+ while True:
175
+ a = load_audio_chunk(audio_path,beg,end)
176
+ online.insert_audio_chunk(a)
177
+ try:
178
+ o = online.process_iter()
179
+ except AssertionError as e:
180
+ logger.error(f"assertion error: {repr(e)}")
181
+ pass
182
+ else:
183
+ output_transcript(o, now=end)
184
+
185
+ logger.info(f"## last processed {end:.2f}s")
186
+
187
+ if end >= duration:
188
+ break
189
+
190
+ beg = end
191
+
192
+ if end + min_chunk > duration:
193
+ end = duration
194
+ else:
195
+ end += min_chunk
196
+ now = duration
197
+
198
+ else: # online = simultaneous mode
199
+ end = 0
200
+ while True:
201
+ now = time.time() - start
202
+ if now < end+min_chunk:
203
+ time.sleep(min_chunk+end-now)
204
+ end = time.time() - start
205
+ a = load_audio_chunk(audio_path,beg,end)
206
+ beg = end
207
+ online.insert_audio_chunk(a)
208
+
209
+ try:
210
+ o = online.process_iter()
211
+ except AssertionError as e:
212
+ logger.error(f"assertion error: {e}")
213
+ pass
214
+ else:
215
+ output_transcript(o)
216
+ now = time.time() - start
217
+ logger.info(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}")
218
+
219
+ if end >= duration:
220
+ break
221
+ now = None
222
+
223
+ o = online.finish()
224
+ output_transcript(o, now=now)
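A hedged sketch (not part of the commit) of how a backend plugs into this entry point: factory(args) must return the (asr, online) pair consumed by asr_factory above, and add_args contributes backend-specific flags such as --model_path that asr_factory checks for translation. All names below are illustrative, not real APIs from this repository.

def my_add_args(parser):
    # hypothetical backend flag; asr_factory above reads args.model_path
    parser.add_argument('--model_path', type=str, default='path/to/whisper-model.pt')

def my_factory(args):
    # hypothetical constructors; asr must provide warmup() and set_translate_task(),
    # online must implement init / insert_audio_chunk / process_iter / finish as used above
    asr = MyBackendASR(args.model_path, language=args.lan)
    online = MyOnlineProcessor(asr)
    return asr, online

if __name__ == '__main__':
    main_simulation_from_file(my_factory, add_args=my_add_args)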
whisper_streaming/whisper_server.py ADDED
@@ -0,0 +1,177 @@
1
+ #!/usr/bin/env python3
2
+ from whisper_streaming.whisper_online_main import *
3
+
4
+ import sys
5
+ import argparse
6
+ import os
7
+ import logging
8
+ import numpy as np
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ SAMPLING_RATE = 16000
13
+
14
+
15
+ ######### Server objects
16
+
17
+ import whisper_streaming.line_packet as line_packet
18
+ import socket
19
+
20
+ class Connection:
21
+ '''Wraps a socket connection object.'''
22
+ PACKET_SIZE = 32000*5*60 # 5 minutes # was: 65536
23
+
24
+ def __init__(self, conn):
25
+ self.conn = conn
26
+ self.last_line = ""
27
+
28
+ self.conn.setblocking(True)
29
+
30
+ def send(self, line):
31
+ '''it doesn't send the same line twice, because it was problematic in online-text-flow-events'''
32
+ if line == self.last_line:
33
+ return
34
+ line_packet.send_one_line(self.conn, line)
35
+ self.last_line = line
36
+
37
+ def receive_lines(self):
38
+ in_line = line_packet.receive_lines(self.conn)
39
+ return in_line
40
+
41
+ def non_blocking_receive_audio(self):
42
+ try:
43
+ r = self.conn.recv(self.PACKET_SIZE)
44
+ return r
45
+ except ConnectionResetError:
46
+ return None
47
+
48
+ import io
49
+ import soundfile
50
+
51
+ # wraps socket and ASR object, and serves one client connection.
52
+ # next client should be served by a new instance of this object
53
+ class ServerProcessor:
54
+
55
+ def __init__(self, c, online_asr_proc, min_chunk):
56
+ self.connection = c
57
+ self.online_asr_proc = online_asr_proc
58
+ self.min_chunk = min_chunk
59
+
60
+ self.is_first = True
61
+
62
+ def receive_audio_chunk(self):
63
+ # receive all audio that is available by this time
64
+ # blocks operation if less than self.min_chunk seconds is available
65
+ # unblocks if connection is closed or a chunk is available
66
+ out = []
67
+ minlimit = self.min_chunk*SAMPLING_RATE
68
+ while sum(len(x) for x in out) < minlimit:
69
+ raw_bytes = self.connection.non_blocking_receive_audio()
70
+ if not raw_bytes:
71
+ break
72
+ # print("received audio:",len(raw_bytes), "bytes", raw_bytes[:10])
73
+ sf = soundfile.SoundFile(io.BytesIO(raw_bytes), channels=1,endian="LITTLE",samplerate=SAMPLING_RATE, subtype="PCM_16",format="RAW")
74
+ audio, _ = librosa.load(sf,sr=SAMPLING_RATE,dtype=np.float32)
75
+ out.append(audio)
76
+ if not out:
77
+ return None
78
+ conc = np.concatenate(out)
79
+ if self.is_first and len(conc) < minlimit:
80
+ return None
81
+ self.is_first = False
82
+ return conc
83
+
84
+ def send_result(self, iteration_output):
85
+ # output format in stdout is like:
86
+ # 0 1720 Takhle to je
87
+ # - the first two words are:
88
+ # - beg and end timestamp of the text segment, as estimated by Whisper model. The timestamps are not accurate, but they're useful anyway
89
+ # - the next words: segment transcript
90
+ if iteration_output:
91
+ message = "%1.0f %1.0f %s" % (iteration_output['start'] * 1000, iteration_output['end'] * 1000, iteration_output['text'])
92
+ print(message, flush=True, file=sys.stderr)
93
+ self.connection.send(message)
94
+ else:
95
+ logger.debug("No text in this segment")
96
+
97
+ def process(self):
98
+ # handle one client connection
99
+ self.online_asr_proc.init()
100
+ while True:
101
+ a = self.receive_audio_chunk()
102
+ if a is None:
103
+ break
104
+ self.online_asr_proc.insert_audio_chunk(a)
105
+ o = self.online_asr_proc.process_iter()
106
+ try:
107
+ self.send_result(o)
108
+ except BrokenPipeError:
109
+ logger.info("broken pipe -- connection closed?")
110
+ break
111
+
112
+ # o = online.finish() # this should be working
113
+ # self.send_result(o)
114
+
115
+ def main_server(factory, add_args):
116
+ '''
117
+ factory: function that creates the ASR and online processor object from args and logger,
118
+ e.g. the SimulWhisper backend, or the default WhisperStreaming local agreement backends (not implemented here, but could be).
119
+ add_args: function that adds backend-specific args to the parser
120
+ '''
121
+ logger = logging.getLogger(__name__)
122
+ parser = argparse.ArgumentParser()
123
+
124
+ # server options
125
+ parser.add_argument("--host", type=str, default='localhost')
126
+ parser.add_argument("--port", type=int, default=43007)
127
+ parser.add_argument("--warmup-file", type=str, dest="warmup_file",
128
+ help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. "
129
+ "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .")
130
+
131
+ # options from whisper_online
132
+ processor_args(parser)
133
+
134
+ add_args(parser)
135
+
136
+ args = parser.parse_args()
137
+
138
+ set_logging(args,logger)
139
+
140
+ # setting whisper object by args
141
+
142
+
143
+ asr, online = asr_factory(args, factory)
144
+ if args.vac:
145
+ min_chunk = args.vac_chunk_size
146
+ else:
147
+ min_chunk = args.min_chunk_size
148
+
149
+ # warm up the ASR because the very first transcribe takes more time than the others.
150
+ # Test results in https://github.com/ufal/whisper_streaming/pull/81
151
+ msg = "Whisper is not warmed up. The first chunk processing may take longer."
152
+ if args.warmup_file:
153
+ if os.path.isfile(args.warmup_file):
154
+ a = load_audio_chunk(args.warmup_file,0,1)
155
+ asr.warmup(a)
156
+ logger.info("Whisper is warmed up.")
157
+ else:
158
+ logger.critical("The warm up file is not available. "+msg)
159
+ sys.exit(1)
160
+ else:
161
+ logger.warning(msg)
162
+
163
+ # server loop
164
+
165
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
166
+ s.bind((args.host, args.port))
167
+ s.listen(1)
168
+ logger.info('Listening on '+str((args.host, args.port)))
169
+ while True:
170
+ conn, addr = s.accept()
171
+ logger.info('Connected to client on {}'.format(addr))
172
+ connection = Connection(conn)
173
+ proc = ServerProcessor(connection, online, min_chunk)
174
+ proc.process()
175
+ conn.close()
176
+ logger.info('Connection to client closed')
177
+ logger.info('Connection closed, terminating.')
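To make the wire format concrete, a hedged client sketch (not part of the commit): it streams raw 16 kHz mono 16-bit little-endian PCM, which is exactly what ServerProcessor.receive_audio_chunk decodes above, and prints whatever the server sends back. Host and port match the defaults in main_server; the exact line framing of the replies comes from line_packet and is not reproduced here, so the raw bytes are simply decoded for display.

import socket
import numpy as np

HOST, PORT = 'localhost', 43007  # defaults from main_server above
SAMPLING_RATE = 16000

def pcm16(audio_float32):
    # float32 in [-1, 1] -> 16-bit little-endian PCM, the format the server expects
    return (np.clip(audio_float32, -1.0, 1.0) * 32767).astype('<i2').tobytes()

with socket.create_connection((HOST, PORT)) as sock:
    # two seconds of silence stand in for real microphone audio
    sock.sendall(pcm16(np.zeros(2 * SAMPLING_RATE, dtype=np.float32)))
    sock.shutdown(socket.SHUT_WR)  # signal that no more audio will arrive
    while True:
        data = sock.recv(65536)
        if not data:
            break
        print(data.decode('utf-8', errors='replace'))  # lines like "0 1720 Takhle to je"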