File size: 16,562 Bytes
f787576
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2faff30
f787576
 
2faff30
f787576
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
import hashlib
import io
import os
import urllib
import urllib.request
import warnings
from typing import Any, List, Optional, Union

import torch
from tqdm import tqdm

from .audio import load_audio, pad_or_trim, log_mel_spectrogram
from .model import ModelDimensions, Whisper
from .streaming_model import StreamingWhisper
from .version import __version__

# Official OpenAI Whisper checkpoints. The second-to-last URL path segment is the
# file's expected SHA256 digest, which `_download` extracts and verifies.
# "large" and "turbo" are aliases: they share the URL of "large-v3" and
# "large-v3-turbo" respectively.
_MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
    "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
    "large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
    "turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
}

# Absolute paths to fine-tuned (LoRA) streaming checkpoints on one specific
# filesystem, keyed by base model size and then by the same second-level keys as
# `_STREAMING_MODELS_HF` (chunk size, with a "-multi" suffix for the multilingual
# variant). NOTE(review): not referenced by the loaders in this file and not usable
# off that host -- presumably kept for internal experiments; confirm before use.
_STREAMING_MODELS = {
    "base": {
        "300": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_base_LIBRI-960-ALIGNED_32_full_streaming_eot_fixed_timings_LR-1e-05_r32_g15_eg1_top5_full-streamTrue_random-orderFalse_fraction0.25/checkpoint/checkpoint-epoch=0009.pt",
        "200": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_base_LIBRI-960-ALIGNED_32_full_streaming_eot_fixed_timings_LR-1e-05_r32_g10_eg2_top5_full-streamTrue_random-orderFalse_fraction0.1/checkpoint/checkpoint-epoch=0009.pt",
        "100": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_base_LIBRI-960-ALIGNED_32_full_streaming_eot_fixed_timings_LR-1e-05_r32_g5_eg5_top5_full-streamTrue_random-orderFalse_fraction0.05/checkpoint/checkpoint-epoch=0009.pt",
        "40": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_base_LIBRI-960-ALIGNED_32_full_streaming_eot_fixed_timings_LR-1e-05_r32_g2_eg14_top5_full-streamTrue_random-orderFalse_fraction0.02/checkpoint/checkpoint-epoch=0006.pt",
    },
    "small": {
        "1000": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_small_LIBRI-960-ALIGNED_16_full_streaming_eot_fixed_timings_LR-1e-05_r32_g50_eg0_top5_full-streamTrue_random-orderFalse_fraction0.4/checkpoint/checkpoint-epoch=0009.pt",
        "300": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_small_LIBRI-960-ALIGNED_16_full_streaming_eot_fixed_timings_LR-1e-05_r32_g15_eg1_top5_full-streamTrue_random-orderFalse_fraction0.25/checkpoint/checkpoint-epoch=0009.pt",
        "200": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_small_LIBRI-960-ALIGNED_16_full_streaming_eot_fixed_timings_LR-1e-05_r32_g10_eg2_top5_full-streamTrue_random-orderFalse_fraction0.1/checkpoint/checkpoint-epoch=0009.pt",
        "100": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_small_LIBRI-960-ALIGNED_16_full_streaming_eot_fixed_timings_LR-1e-05_r32_g5_eg5_top5_full-streamTrue_random-orderFalse_fraction0.05/checkpoint/checkpoint-epoch=0009.pt",
        "40": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_small_LIBRI-960-ALIGNED_16_full_streaming_eot_fixed_timings_LR-1e-05_r32_g2_eg14_top5_full-streamTrue_random-orderFalse_fraction0.02/checkpoint/checkpoint-epoch=0009.pt",
    },
    "large-v2": {
        "1000": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-960-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g50_eg0_top5_full-streamTrue_random-orderFalse_fraction0.3/checkpoint/checkpoint-epoch=0002.pt",
        "300": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-960-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g15_eg1_top5_full-streamTrue_random-orderFalse_fraction0.1/checkpoint/checkpoint-epoch=0002.pt",
        "200": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-960-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g10_eg2_top5_full-streamTrue_random-orderFalse_fraction0.07/checkpoint/checkpoint-epoch=0002.pt",
        "100": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-960-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g5_eg5_top5_full-streamTrue_random-orderFalse_fraction0.03/checkpoint/checkpoint-epoch=0002.pt",
        "40": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-960-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g2_eg14_top5_full-streamTrue_random-orderFalse_fraction0.01/checkpoint/checkpoint-epoch=0002.pt",
        "300-multi": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-BLEND-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g15_eg1_top5_full-streamTrue_random-orderFalse_fraction0.05/checkpoint/checkpoint-epoch=0001.pt",
    }
}

# Checkpoint filenames in the Hugging Face repo "MLSpeech/causal-whisper", keyed by
# base model size and then by `str(gran)` (the chunk size), with a "-multi" suffix
# for the multilingual variant. Looked up by `load_streaming_model_correct`.
_STREAMING_MODELS_HF = {
    "base": {
        "300": "base_300.pt",
        "200": "base_200.pt",
        "100": "base_100.pt",
        "40": "base_40.pt",
    },
    "small": {
        "1000": "small_1000.pt",
        "300": "small_300.pt",
        "200": "small_200.pt",
        "100": "small_100.pt",
        "40": "small_40.pt",
    },
    "large-v2": {
        "1000": "large-v2_1000.pt",
        "300": "large-v2_300.pt",
        "200": "large-v2_200.pt",
        "100": "large-v2_100.pt",
        "40": "large-v2_40.pt",
        "300-multi": "large-v2_300_multi.pt",
    }
}

# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
# Keyed like `_MODELS`. The loaders pass the matching entry to `model.set_alignment_heads` only
# for official model names; checkpoints loaded from a file path get no alignment heads.
_ALIGNMENT_HEADS = {
    "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
    "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
    "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
    "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
    "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
    "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
    "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
    "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
    "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
    "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
    "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
    "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
    "large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
    "turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
}


def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
    os.makedirs(root, exist_ok=True)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, os.path.basename(url))

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        with open(download_target, "rb") as f:
            model_bytes = f.read()
        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
            return model_bytes if in_memory else download_target
        else:
            warnings.warn(
                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
            )

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(
            total=int(source.info().get("Content-Length")),
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    model_bytes = open(download_target, "rb").read()
    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
        )

    return model_bytes if in_memory else download_target


def available_models() -> List[str]:
    """Return the official Whisper model names (the keys of ``_MODELS``), in declaration order."""
    names = [*_MODELS]
    return names


def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: Optional[str] = None,
    in_memory: bool = False,
) -> Whisper:
    """
    Load a Whisper ASR model

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into; defaults to CUDA when available, else CPU
    download_root: str
        path to download the model files; by default, it uses "~/.cache/whisper"
        (or "$XDG_CACHE_HOME/whisper" when that variable is set)
    in_memory: bool
        whether to preload the model weights into host memory

    Returns
    -------
    model : Whisper
        The Whisper ASR model instance

    Raises
    ------
    RuntimeError
        if ``name`` is neither an official model name nor an existing file
    """

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
        alignment_heads = _ALIGNMENT_HEADS[name]
    elif os.path.isfile(name):
        if in_memory:
            # Use a context manager so the handle is closed deterministically.
            with open(name, "rb") as f:
                checkpoint_file = f.read()
        else:
            checkpoint_file = name
        # Alignment heads are only known for the official checkpoints.
        alignment_heads = None
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        checkpoint = torch.load(fp, map_location=device)
    # Drop the raw bytes/path early so a large in-memory buffer can be freed.
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device)


def load_streaming_model(
    name: str,
    advisor_ckpt_path: Optional[str] = None,
    ft_model_ckpt_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    download_root: Optional[str] = None,
    in_memory: bool = False,
    cache_gran: bool = True,
    gran: int = 15,
    rank: int = 8,
    extra_gran_blocks: int = 0,
    n_advisor_class: int = 4,
    **kwargs: Any,
) -> StreamingWhisper:
    """
    Load a StreamingWhisper ASR model from Whisper (and optional advisor) checkpoints.

    The base weights come either from ``name`` (an official model name or a local
    checkpoint path) or, when given, from ``ft_model_ckpt_path``. Attention
    ``weight``/``bias`` keys are renamed to the ``base_layer.`` layout used by the
    streaming model, and any ``decoder_advisor`` entries from ``advisor_ckpt_path``
    are merged in. Loading uses ``strict=False``, so parameters missing from the
    checkpoints keep their freshly initialised values.

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
        Ignored when ``ft_model_ckpt_path`` is given.
    advisor_ckpt_path : str, optional
        checkpoint whose ``state_dict`` entries with "decoder_advisor" in the key are loaded
    ft_model_ckpt_path : str, optional
        checkpoint to load instead of ``name``; must contain "dims" and either
        "model_state_dict" or "state_dict"
    device : Union[str, torch.device]
        the PyTorch device to put the model into; defaults to CUDA when available, else CPU
    download_root: str
        path to download the model files; by default, it uses "~/.cache/whisper".
        Only used when ``ft_model_ckpt_path`` is None.
    in_memory: bool
        whether to preload the model weights into host memory.
        Only used when ``ft_model_ckpt_path`` is None.
    cache_gran, gran, rank, extra_gran_blocks, n_advisor_class, **kwargs
        forwarded unchanged to the ``StreamingWhisper`` constructor

    Returns
    -------
    model : StreamingWhisper
        The StreamingWhisper ASR model instance

    Raises
    ------
    RuntimeError
        if ``ft_model_ckpt_path`` is None and ``name`` is neither an official
        model name nor an existing file
    """
    # Resolve the device for every code path; previously it stayed None when
    # ft_model_ckpt_path was given, and checkpoints were loaded without map_location.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    alignment_heads = None
    if ft_model_ckpt_path is None:
        if download_root is None:
            default = os.path.join(os.path.expanduser("~"), ".cache")
            download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

        if name in _MODELS:
            checkpoint_file = _download(_MODELS[name], download_root, in_memory)
            alignment_heads = _ALIGNMENT_HEADS[name]
        elif os.path.isfile(name):
            if in_memory:
                with open(name, "rb") as f:
                    checkpoint_file = f.read()
            else:
                checkpoint_file = name
        else:
            raise RuntimeError(
                f"Model {name} not found; available models = {available_models()}"
            )

        with (
            io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
        ) as fp:
            checkpoint = torch.load(fp, map_location=device)
        del checkpoint_file
    else:
        # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
        checkpoint = torch.load(ft_model_ckpt_path, map_location=device, weights_only=False)

    if advisor_ckpt_path is not None:
        decoder_advisor_chkpt = torch.load(advisor_ckpt_path, map_location=device, weights_only=False)
    else:
        decoder_advisor_chkpt = {"state_dict": {}}
    advisor_state_dict = {
        k: v for k, v in decoder_advisor_chkpt["state_dict"].items() if "decoder_advisor" in k
    }

    if "model_state_dict" in checkpoint:
        whisper_dict = checkpoint["model_state_dict"]
    else:
        whisper_dict = checkpoint["state_dict"]

    def _to_base_layer_key(key: str) -> str:
        # Attention projections in StreamingWhisper wrap the original layer
        # (presumably a LoRA adapter, cf. `rank`), so their plain weight/bias
        # tensors live under a `base_layer.` prefix.
        if "attn." not in key:
            return key
        if "weight" in key:
            return key.replace("weight", "base_layer.weight")
        if "bias" in key:
            return key.replace("bias", "base_layer.bias")
        return key

    whisper_dict = {_to_base_layer_key(k): v for k, v in whisper_dict.items()}

    # Whisper weights take precedence over advisor entries on key collisions.
    streaming_whisper_state_dict = {**advisor_state_dict, **whisper_dict}

    dims = ModelDimensions(**checkpoint["dims"])

    model = StreamingWhisper(
        dims,
        cache_gran=cache_gran,
        gran=gran,
        rank=rank,
        extra_gran_blocks=extra_gran_blocks,
        n_advisor_class=n_advisor_class,
        **kwargs,
    )

    # strict=False: keys absent from the checkpoints (e.g. advisor weights when no
    # advisor checkpoint is given) keep their initial values.
    model.load_state_dict(streaming_whisper_state_dict, strict=False)

    # Only set for official model names (never for ft_model_ckpt_path).
    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device)


def load_streaming_model_correct(
    name: str,
    gran: int = 300,
    multilingual: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    download_root: Optional[str] = None,
    in_memory: bool = False,
) -> StreamingWhisper:
    """
    Download a released streaming checkpoint from the Hugging Face Hub and build the model.

    The checkpoint is fetched from the "MLSpeech/causal-whisper" repo, using the
    ``HF_TOKEN`` environment variable as the access token when set. The model is
    constructed from the "gran", "rank" and "extra_gran_blocks" values stored in
    the checkpoint's "cfg" entry.

    Parameters
    ----------
    name : str
        base model size; one of the keys of ``_STREAMING_MODELS_HF``
    gran : int
        chunk-size key of the released checkpoint (e.g. 40, 100, 200, 300, 1000)
    multilingual : bool
        select the "-multi" variant, available only for some size/chunk-size pairs
    device : Union[str, torch.device]
        the PyTorch device to put the model into; defaults to CUDA when available, else CPU
    download_root : str, optional
        Hugging Face cache directory (passed as ``cache_dir``); None uses the hub's default
    in_memory : bool
        whether to read the checkpoint into host memory before deserializing it

    Returns
    -------
    model : StreamingWhisper
        The StreamingWhisper ASR model instance

    Raises
    ------
    RuntimeError
        if no checkpoint is published for the requested size / chunk size / multilingual combination
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    subname = f"{gran}-multi" if multilingual else str(gran)

    # Look the file up before touching the network. Previously a KeyError was only
    # printed and execution continued into an UnboundLocalError on the checkpoint path.
    try:
        filename = _STREAMING_MODELS_HF[name][subname]
    except KeyError:
        raise RuntimeError(
            f"Streaming model with size {name}, multilingual: {multilingual} and chunk size: {gran} "
            f"is not available; available sizes = {list(_STREAMING_MODELS_HF)}"
        ) from None

    # Imported lazily so huggingface_hub is only required by this loader.
    from huggingface_hub import hf_hub_download

    ckpt_path = hf_hub_download(
        repo_id="MLSpeech/causal-whisper",
        filename=filename,
        repo_type="model",
        token=os.environ.get("HF_TOKEN"),
        cache_dir=download_root,
    )

    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    with open(ckpt_path, "rb") as f:
        source = io.BytesIO(f.read()) if in_memory else f
        checkpoint = torch.load(source, map_location=device, weights_only=False)

    dims = ModelDimensions(**checkpoint["dims"])
    cfg = checkpoint["cfg"]

    model = StreamingWhisper(
        dims,
        gran=cfg["gran"],
        rank=cfg["rank"],
        extra_gran_blocks=cfg["extra_gran_blocks"],
    )

    # strict=False: parameters absent from the checkpoint keep their initial values.
    model.load_state_dict(checkpoint["state_dict"], strict=False)

    return model.to(device)