|
|
import hashlib |
|
|
import io |
|
|
import os |
|
|
import urllib |
|
|
import warnings |
|
|
from typing import List, Optional, Union |
|
|
|
|
|
import torch |
|
|
from tqdm import tqdm |
|
|
|
|
|
from .audio import load_audio, pad_or_trim, log_mel_spectrogram |
|
|
from .model import ModelDimensions, Whisper |
|
|
from .streaming_model import StreamingWhisper |
|
|
from .version import __version__ |
|
|
|
|
|
# Official Whisper checkpoints, keyed by model name.
# The second-to-last path component of every URL is the SHA256 hex digest of the
# file; `_download` extracts it with url.split("/")[-2] and verifies the download.
# "large" is an alias of "large-v3" and "turbo" an alias of "large-v3-turbo"
# (identical URLs). Insertion order is the order returned by `available_models()`.
_MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
    "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
    "large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
    "turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
}
|
|
|
|
|
# Local checkpoint paths for the LoRA streaming models, keyed by base model size
# and then by a chunk-size string ("-multi" suffix = multilingual variant, the
# same key scheme used with `_STREAMING_MODELS_HF`).
# The keys appear to be chunk sizes in milliseconds: the g<N> tag in each path is
# exactly key / 20 (g50 -> "1000", g15 -> "300", g10 -> "200", g5 -> "100", g2 -> "40")
# — TODO confirm units.
# NOTE(review): these are absolute paths on one specific research filesystem and
# will not exist on other machines; nothing in this chunk references this table
# (`load_streaming_model_correct` uses `_STREAMING_MODELS_HF` instead).
_STREAMING_MODELS = {
    "base": {
        "300": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_base_LIBRI-960-ALIGNED_32_full_streaming_eot_fixed_timings_LR-1e-05_r32_g15_eg1_top5_full-streamTrue_random-orderFalse_fraction0.25/checkpoint/checkpoint-epoch=0009.pt",
        "200": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_base_LIBRI-960-ALIGNED_32_full_streaming_eot_fixed_timings_LR-1e-05_r32_g10_eg2_top5_full-streamTrue_random-orderFalse_fraction0.1/checkpoint/checkpoint-epoch=0009.pt",
        "100": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_base_LIBRI-960-ALIGNED_32_full_streaming_eot_fixed_timings_LR-1e-05_r32_g5_eg5_top5_full-streamTrue_random-orderFalse_fraction0.05/checkpoint/checkpoint-epoch=0009.pt",
        "40": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_base_LIBRI-960-ALIGNED_32_full_streaming_eot_fixed_timings_LR-1e-05_r32_g2_eg14_top5_full-streamTrue_random-orderFalse_fraction0.02/checkpoint/checkpoint-epoch=0006.pt",
    },
    "small": {
        "1000": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_small_LIBRI-960-ALIGNED_16_full_streaming_eot_fixed_timings_LR-1e-05_r32_g50_eg0_top5_full-streamTrue_random-orderFalse_fraction0.4/checkpoint/checkpoint-epoch=0009.pt",
        "300": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_small_LIBRI-960-ALIGNED_16_full_streaming_eot_fixed_timings_LR-1e-05_r32_g15_eg1_top5_full-streamTrue_random-orderFalse_fraction0.25/checkpoint/checkpoint-epoch=0009.pt",
        "200": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_small_LIBRI-960-ALIGNED_16_full_streaming_eot_fixed_timings_LR-1e-05_r32_g10_eg2_top5_full-streamTrue_random-orderFalse_fraction0.1/checkpoint/checkpoint-epoch=0009.pt",
        "100": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_small_LIBRI-960-ALIGNED_16_full_streaming_eot_fixed_timings_LR-1e-05_r32_g5_eg5_top5_full-streamTrue_random-orderFalse_fraction0.05/checkpoint/checkpoint-epoch=0009.pt",
        "40": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_small_LIBRI-960-ALIGNED_16_full_streaming_eot_fixed_timings_LR-1e-05_r32_g2_eg14_top5_full-streamTrue_random-orderFalse_fraction0.02/checkpoint/checkpoint-epoch=0009.pt",
    },
    "large-v2": {
        "1000": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-960-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g50_eg0_top5_full-streamTrue_random-orderFalse_fraction0.3/checkpoint/checkpoint-epoch=0002.pt",
        "300": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-960-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g15_eg1_top5_full-streamTrue_random-orderFalse_fraction0.1/checkpoint/checkpoint-epoch=0002.pt",
        "200": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-960-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g10_eg2_top5_full-streamTrue_random-orderFalse_fraction0.07/checkpoint/checkpoint-epoch=0002.pt",
        "100": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-960-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g5_eg5_top5_full-streamTrue_random-orderFalse_fraction0.03/checkpoint/checkpoint-epoch=0002.pt",
        "40": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-960-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g2_eg14_top5_full-streamTrue_random-orderFalse_fraction0.01/checkpoint/checkpoint-epoch=0002.pt",
        "300-multi": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-BLEND-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g15_eg1_top5_full-streamTrue_random-orderFalse_fraction0.05/checkpoint/checkpoint-epoch=0001.pt",
    }
}
|
|
|
|
|
# Checkpoint filenames in the Hugging Face repo "MLSpeech/causal-whisper",
# keyed by base model size and then by the chunk-size key that
# `load_streaming_model_correct` builds: str(gran), or str(gran) + "-multi"
# for the multilingual variant (only "large-v2" has one).
_STREAMING_MODELS_HF = {
    "base": {
        "300": "base_300.pt",
        "200": "base_200.pt",
        "100": "base_100.pt",
        "40": "base_40.pt",
    },
    "small": {
        "1000": "small_1000.pt",
        "300": "small_300.pt",
        "200": "small_200.pt",
        "100": "small_100.pt",
        "40": "small_40.pt",
    },
    "large-v2": {
        "1000": "large-v2_1000.pt",
        "300": "large-v2_300.pt",
        "200": "large-v2_200.pt",
        "100": "large-v2_100.pt",
        "40": "large-v2_40.pt",
        "300-multi": "large-v2_300_multi.pt",
    }
}
|
|
|
|
|
|
|
|
|
|
|
# Per-model encoded byte strings passed to `model.set_alignment_heads(...)` when an
# official model name is loaded. Keys mirror `_MODELS`; "large"/"turbo" reuse the
# "large-v3"/"large-v3-turbo" values.
# NOTE(review): presumably an encoded mask of the cross-attention heads used for
# word-level timing — the decoding format lives in `set_alignment_heads`, which is
# not visible here; confirm before editing any value.
_ALIGNMENT_HEADS = {
    "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
    "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
    "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
    "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
    "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
    "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
    "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
    "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
    "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
    "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
    "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
    "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
    "large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
    "turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
}
|
|
|
|
|
|
|
|
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: |
|
|
os.makedirs(root, exist_ok=True) |
|
|
|
|
|
expected_sha256 = url.split("/")[-2] |
|
|
download_target = os.path.join(root, os.path.basename(url)) |
|
|
|
|
|
if os.path.exists(download_target) and not os.path.isfile(download_target): |
|
|
raise RuntimeError(f"{download_target} exists and is not a regular file") |
|
|
|
|
|
if os.path.isfile(download_target): |
|
|
with open(download_target, "rb") as f: |
|
|
model_bytes = f.read() |
|
|
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: |
|
|
return model_bytes if in_memory else download_target |
|
|
else: |
|
|
warnings.warn( |
|
|
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" |
|
|
) |
|
|
|
|
|
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: |
|
|
with tqdm( |
|
|
total=int(source.info().get("Content-Length")), |
|
|
ncols=80, |
|
|
unit="iB", |
|
|
unit_scale=True, |
|
|
unit_divisor=1024, |
|
|
) as loop: |
|
|
while True: |
|
|
buffer = source.read(8192) |
|
|
if not buffer: |
|
|
break |
|
|
|
|
|
output.write(buffer) |
|
|
loop.update(len(buffer)) |
|
|
|
|
|
model_bytes = open(download_target, "rb").read() |
|
|
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: |
|
|
raise RuntimeError( |
|
|
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." |
|
|
) |
|
|
|
|
|
return model_bytes if in_memory else download_target |
|
|
|
|
|
|
|
|
def available_models() -> List[str]:
    """Return the official Whisper model names accepted by `load_model`, in table order."""
    return [model_name for model_name in _MODELS]
|
|
|
|
|
|
|
|
def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: Optional[str] = None,
    in_memory: bool = False,
) -> Whisper:
    """
    Load a Whisper ASR model

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into; defaults to "cuda" when available,
        otherwise "cpu"
    download_root: str
        path to download the model files; by default, it uses "$XDG_CACHE_HOME/whisper",
        falling back to "~/.cache/whisper"
    in_memory: bool
        whether to preload the model weights into host memory

    Returns
    -------
    model : Whisper
        The Whisper ASR model instance

    Raises
    ------
    RuntimeError
        if ``name`` is neither an official model name nor an existing file
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
        alignment_heads = _ALIGNMENT_HEADS[name]
    elif os.path.isfile(name):
        if in_memory:
            # Use a context manager so the file handle is closed deterministically.
            with open(name, "rb") as f:
                checkpoint_file = f.read()
        else:
            checkpoint_file = name
        # Custom checkpoints carry no known alignment-head table.
        alignment_heads = None
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    # checkpoint_file is raw bytes when in_memory, otherwise a filesystem path.
    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        checkpoint = torch.load(fp, map_location=device)
    # Drop the (possibly large) in-memory copy as soon as it has been deserialized.
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device)
|
|
|
|
|
|
|
|
def load_streaming_model(
    name: str,
    advisor_ckpt_path: Optional[str] = None,
    ft_model_ckpt_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    download_root: Optional[str] = None,
    in_memory: bool = False,
    cache_gran: bool = True,
    gran: int = 15,
    rank: int = 8,
    extra_gran_blocks: int = 0,
    n_advisor_class: int = 4,
    **kwargs,
) -> StreamingWhisper:
    """
    Load a StreamingWhisper ASR model from Whisper weights.

    The base weights come from ``ft_model_ckpt_path`` when it is given (then
    ``name``, ``download_root`` and ``in_memory`` are ignored), otherwise from
    ``name``. Entries of ``advisor_ckpt_path``'s ``"state_dict"`` whose key
    contains ``"decoder_advisor"`` are merged in as well.

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    advisor_ckpt_path : str, optional
        checkpoint with a ``"state_dict"`` holding decoder-advisor weights
    ft_model_ckpt_path : str, optional
        checkpoint with ``"dims"`` and either ``"model_state_dict"`` or ``"state_dict"``
    device : Union[str, torch.device]
        the PyTorch device to put the model into (and to map all loaded
        checkpoints onto); defaults to "cuda" when available, otherwise "cpu"
    download_root: str
        path to download the model files; by default, it uses "$XDG_CACHE_HOME/whisper",
        falling back to "~/.cache/whisper"
    in_memory: bool
        whether to preload the model weights into host memory
    cache_gran, gran, rank, extra_gran_blocks, n_advisor_class, **kwargs
        forwarded unchanged to the ``StreamingWhisper`` constructor

    Returns
    -------
    model : StreamingWhisper
        The StreamingWhisper ASR model instance

    Raises
    ------
    RuntimeError
        if ``ft_model_ckpt_path`` is None and ``name`` is neither an official
        model name nor an existing file
    """
    # Resolve the default device up front so every loading path honours it
    # (the fine-tuned path used to end with model.to(None)).
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    alignment_heads = None
    if ft_model_ckpt_path is None:
        if download_root is None:
            default = os.path.join(os.path.expanduser("~"), ".cache")
            download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

        if name in _MODELS:
            checkpoint_file = _download(_MODELS[name], download_root, in_memory)
            alignment_heads = _ALIGNMENT_HEADS[name]
        elif os.path.isfile(name):
            if in_memory:
                with open(name, "rb") as f:
                    checkpoint_file = f.read()
            else:
                checkpoint_file = name
        else:
            raise RuntimeError(
                f"Model {name} not found; available models = {available_models()}"
            )

        with (
            io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
        ) as fp:
            checkpoint = torch.load(fp, map_location=device)
        del checkpoint_file
    else:
        # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
        # map_location lets GPU-saved checkpoints load on CPU-only hosts.
        checkpoint = torch.load(ft_model_ckpt_path, map_location=device, weights_only=False)

    if advisor_ckpt_path is not None:
        decoder_advisor_chkpt = torch.load(advisor_ckpt_path, map_location=device, weights_only=False)
    else:
        decoder_advisor_chkpt = {"state_dict": {}}
    advisor_state_dict = {
        k: v for k, v in decoder_advisor_chkpt["state_dict"].items() if "decoder_advisor" in k
    }

    if "model_state_dict" in checkpoint:
        whisper_dict = checkpoint["model_state_dict"]
    else:
        whisper_dict = checkpoint["state_dict"]

    # Rename attention parameters to the "base_layer" layout expected by
    # StreamingWhisper, e.g. "...attn.query.weight" -> "...attn.query.base_layer.weight".
    # Presumably StreamingWhisper wraps these projections (LoRA-style, see `rank`) —
    # TODO confirm. NOTE(review): str.replace rewrites every occurrence, so a key
    # that already contains "base_layer.weight" would be rewritten again.
    remapped_whisper_dict = {}
    for key, value in whisper_dict.items():
        if "attn." in key and "weight" in key:
            key = key.replace("weight", "base_layer.weight")
        elif "attn." in key and "bias" in key:
            key = key.replace("bias", "base_layer.bias")
        remapped_whisper_dict[key] = value

    # On key collisions the Whisper weights win over the advisor entries.
    streaming_whisper_state_dict = {**advisor_state_dict, **remapped_whisper_dict}

    dims = ModelDimensions(**checkpoint["dims"])

    model = StreamingWhisper(
        dims,
        cache_gran=cache_gran,
        gran=gran,
        rank=rank,
        extra_gran_blocks=extra_gran_blocks,
        n_advisor_class=n_advisor_class,
        **kwargs,
    )

    # strict=False: parameters missing from the checkpoints keep their initial
    # values and unexpected keys are ignored silently.
    model.load_state_dict(streaming_whisper_state_dict, strict=False)

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device)
|
|
|
|
|
|
|
|
def load_streaming_model_correct(
    name: str,
    gran: int = 300,
    multilingual: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    download_root: Optional[str] = None,
    in_memory: bool = False,
) -> StreamingWhisper:
    """
    Download a released StreamingWhisper checkpoint from Hugging Face and load it.

    Parameters
    ----------
    name : str
        base model size; a key of `_STREAMING_MODELS_HF` ("base", "small", "large-v2")
    gran : int
        chunk-size key of the released checkpoint (e.g. 40, 100, 200, 300, 1000);
        the StreamingWhisper ``gran``/``rank``/``extra_gran_blocks`` themselves are
        read from the checkpoint's ``"cfg"`` entry
    multilingual : bool
        select the "<gran>-multi" variant
    device : Union[str, torch.device]
        the PyTorch device to put the model into; defaults to "cuda" when
        available, otherwise "cpu"
    download_root : str, optional
        Hugging Face cache directory; None uses the huggingface_hub default cache
    in_memory : bool
        read the checkpoint file into host memory before deserializing it

    Returns
    -------
    model : StreamingWhisper
        The StreamingWhisper ASR model instance

    Raises
    ------
    RuntimeError
        if no released checkpoint matches (name, gran, multilingual)
    """
    subname = f"{gran}-multi" if multilingual else str(gran)

    # Validate the configuration before touching the network. Previously a miss
    # only printed a message and then failed with UnboundLocalError on ckpt_path.
    try:
        filename = _STREAMING_MODELS_HF[name][subname]
    except KeyError as e:
        available = {size: list(variants) for size, variants in _STREAMING_MODELS_HF.items()}
        raise RuntimeError(
            f"Streaming model with size {name}, multilingual: {multilingual} and chunk size: {gran} "
            f"is not available; available configurations = {available}"
        ) from e

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Imported lazily so huggingface_hub is only required by this loader.
    from huggingface_hub import hf_hub_download

    hf_token = os.environ.get("HF_TOKEN")
    ckpt_path = hf_hub_download(
        repo_id="MLSpeech/causal-whisper",
        filename=filename,
        repo_type="model",
        token=hf_token,
        cache_dir=download_root,
    )

    if in_memory:
        with open(ckpt_path, "rb") as f:
            checkpoint_source = io.BytesIO(f.read())
    else:
        checkpoint_source = ckpt_path

    # NOTE: weights_only=False unpickles arbitrary objects; the checkpoint comes
    # from a fixed repository, so it is trusted only as far as that repo is.
    # map_location lets GPU-saved checkpoints load on CPU-only hosts.
    checkpoint = torch.load(checkpoint_source, map_location=device, weights_only=False)

    dims = ModelDimensions(**checkpoint["dims"])
    cfg = checkpoint["cfg"]

    model = StreamingWhisper(
        dims,
        gran=cfg["gran"],
        rank=cfg["rank"],
        extra_gran_blocks=cfg["extra_gran_blocks"],
    )

    # strict=False: parameters absent from the checkpoint keep their initial values.
    model.load_state_dict(checkpoint["state_dict"], strict=False)

    return model.to(device)
|
|
|