| | |
| | |
| | |
| | |
| |
|
| | |
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import hashlib |
| | import io |
| | import os |
| | import urllib |
| | import warnings |
| | from typing import List, Optional, Union |
| |
|
| | import torch |
| | from tqdm import tqdm |
| |
|
| | from .audio import load_audio, log_mel_spectrogram, pad_or_trim |
| | from .decoding import DecodingOptions, DecodingResult, decode, detect_language |
| | from .model import Whisper, ModelDimensions |
| | from .transcribe import transcribe |
| | from .version import __version__ |
| |
|
| |
|
# Official checkpoints, keyed by the model name accepted by `load_model`.
# Each entry is (model name, checkpoint file stem, SHA256 digest); the digest is
# embedded in the URL as the second-to-last path component, which is where
# `_download` reads the expected checksum from. "large" is an alias that points
# at the same file as "large-v2".
_MODELS = {
    alias: f"https://openaipublic.azureedge.net/main/whisper/models/{digest}/{stem}.pt"
    for alias, stem, digest in (
        ("tiny.en", "tiny.en", "d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03"),
        ("tiny", "tiny", "65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9"),
        ("base.en", "base.en", "25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead"),
        ("base", "base", "ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e"),
        ("small.en", "small.en", "f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872"),
        ("small", "small", "9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794"),
        ("medium.en", "medium.en", "d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f"),
        ("medium", "medium", "345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1"),
        ("large-v1", "large-v1", "e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a"),
        ("large-v2", "large-v2", "81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524"),
        ("large", "large-v2", "81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524"),
    )
}
| |
|
| |
|
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
    """
    Fetch a model checkpoint into `root`, reusing a cached copy when it is valid

    The expected SHA256 digest is taken from the second-to-last path component
    of `url`, and the file name from its last component.

    Parameters
    ----------
    url : str
        download URL of the form ".../<sha256>/<filename>"
    root : str
        directory to store the file in; created if it does not exist
    in_memory : bool
        if True return the file contents, otherwise return the path on disk

    Returns
    -------
    Union[bytes, str]
        the checkpoint bytes when `in_memory` is True, else the local file path

    Raises
    ------
    RuntimeError
        if the target path exists but is not a regular file, or if the
        downloaded file does not match the expected SHA256 digest
    """
    # A bare `import urllib` does not guarantee that the `urllib.request`
    # submodule is loaded, so import it explicitly before using it.
    import urllib.request

    os.makedirs(root, exist_ok=True)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, os.path.basename(url))

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        with open(download_target, "rb") as f:
            model_bytes = f.read()
        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
            return model_bytes if in_memory else download_target
        # Cached copy is corrupt or stale: fall through and overwrite it.
        warnings.warn(
            f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
        )

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        # The header may be missing (e.g. chunked responses); tqdm accepts
        # total=None and then simply omits the percentage/ETA.
        content_length = source.info().get("Content-Length")
        with tqdm(
            total=int(content_length) if content_length is not None else None,
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    # Re-read from disk so the checksum covers what was actually written.
    with open(download_target, "rb") as f:
        model_bytes = f.read()
    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
        )

    return model_bytes if in_memory else download_target
| |
|
| |
|
def available_models() -> List[str]:
    """List the official model names known to this module, in declaration order."""
    names = [*_MODELS]
    return names
| |
|
| |
|
def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: Optional[str] = None,
    in_memory: bool = False,
    checkpoint_file=None,
) -> Whisper:
    """
    Load a Whisper ASR model

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into; defaults to "cuda" when
        available, otherwise "cpu"
    download_root: str
        path to download the model files; by default the value of the
        XDG_CACHE_HOME environment variable if it is set, otherwise
        "~/.cache/whisper"
    in_memory: bool
        whether to preload the model weights into host memory
    checkpoint_file : optional
        path to an existing checkpoint on disk; when given and present it is
        loaded directly and `name` is not consulted. If it is None or does not
        exist, the checkpoint is resolved from `name` instead.

    Returns
    -------
    model : Whisper
        The Whisper ASR model instance

    Raises
    ------
    RuntimeError
        if no usable `checkpoint_file` was given and `name` is neither an
        official model name nor a path to an existing file
    """

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        # NOTE(review): when XDG_CACHE_HOME is set, files are stored directly in
        # that directory, without a "whisper" subdirectory - confirm intended.
        download_root = os.getenv(
            "XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache", "whisper")
        )

    # `os.path.exists(None)` raises TypeError, so the default
    # checkpoint_file=None has to be ruled out before touching the filesystem.
    if checkpoint_file is not None and os.path.exists(checkpoint_file):
        if in_memory:
            with open(checkpoint_file, "rb") as f:
                checkpoint_file = f.read()
    elif name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
    elif os.path.isfile(name):
        if in_memory:
            with open(name, "rb") as f:
                checkpoint_file = f.read()
        else:
            checkpoint_file = name
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    # At this point checkpoint_file is raw bytes when in_memory, else a path.
    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        # NOTE: torch.load unpickles the file; only load checkpoints from
        # sources you trust.
        checkpoint = torch.load(fp, map_location=device)
    # Drop the (possibly very large) in-memory copy before building the model.
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    return model.to(device)
| |
|