Spaces:
Sleeping
Sleeping
| # Copyright (c) 2023 OpenAI. (authors: Whisper Team) | |
| # 2024 Tsinghua Univ. (authors: Xingchen Song) | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Modified from | |
| https://github.com/openai/whisper/blob/main/whisper/__init__.py | |
| """ | |
| import hashlib | |
| import os | |
| import urllib | |
| import warnings | |
| from typing import List, Union | |
| from tqdm import tqdm | |
| from s3tokenizer.model_v2 import S3TokenizerV2 | |
| from .model import S3Tokenizer | |
| from .utils import (load_audio, log_mel_spectrogram, make_non_pad_mask, | |
| mask_to_bias, onnx2torch, padding, merge_tokenized_segments) | |
| __all__ = [ | |
| 'load_audio', 'log_mel_spectrogram', 'make_non_pad_mask', 'mask_to_bias', | |
| 'onnx2torch', 'padding', 'merge_tokenized_segments' | |
| ] | |
| _MODELS = { | |
| "speech_tokenizer_v1": | |
| "https://www.modelscope.cn/models/iic/cosyvoice-300m/" | |
| "resolve/master/speech_tokenizer_v1.onnx", | |
| "speech_tokenizer_v1_25hz": | |
| "https://www.modelscope.cn/models/iic/CosyVoice-300M-25Hz/" | |
| "resolve/master/speech_tokenizer_v1.onnx", | |
| "speech_tokenizer_v2_25hz": | |
| "https://www.modelscope.cn/models/iic/CosyVoice2-0.5B/" | |
| "resolve/master/speech_tokenizer_v2.onnx", | |
| } | |
| _SHA256S = { | |
| "speech_tokenizer_v1": | |
| "23b5a723ed9143aebfd9ffda14ac4c21231f31c35ef837b6a13bb9e5488abb1e", | |
| "speech_tokenizer_v1_25hz": | |
| "56285ddd4a83e883ee0cb9f8d69c1089b53a94b1f78ff7e4a0224a27eb4cb486", | |
| "speech_tokenizer_v2_25hz": | |
| "d43342aa12163a80bf07bffb94c9de2e120a8df2f9917cd2f642e7f4219c6f71", | |
| } | |
| def _download(name: str, root: str) -> Union[bytes, str]: | |
| os.makedirs(root, exist_ok=True) | |
| expected_sha256 = _SHA256S[name] | |
| url = _MODELS[name] | |
| download_target = os.path.join(root, f"{name}.onnx") | |
| if os.path.exists(download_target) and not os.path.isfile(download_target): | |
| raise RuntimeError( | |
| f"{download_target} exists and is not a regular file") | |
| if os.path.isfile(download_target): | |
| with open(download_target, "rb") as f: | |
| model_bytes = f.read() | |
| if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: | |
| return download_target | |
| else: | |
| warnings.warn( | |
| f"{download_target} exists, but the SHA256 checksum does not" | |
| " match; re-downloading the file") | |
| with urllib.request.urlopen(url) as source, open(download_target, | |
| "wb") as output: | |
| with tqdm( | |
| total=int(source.info().get("Content-Length")), | |
| ncols=80, | |
| unit="iB", | |
| unit_scale=True, | |
| unit_divisor=1024, | |
| desc="Downloading onnx checkpoint", | |
| ) as loop: | |
| while True: | |
| buffer = source.read(8192) | |
| if not buffer: | |
| break | |
| output.write(buffer) | |
| loop.update(len(buffer)) | |
| model_bytes = open(download_target, "rb").read() | |
| if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: | |
| raise RuntimeError( | |
| "Model has been downloaded but the SHA256 checksum does not not" | |
| " match. Please retry loading the model.") | |
| return download_target | |
| def available_models() -> List[str]: | |
| """Returns the names of available models""" | |
| return list(_MODELS.keys()) | |
| def load_model( | |
| name: str, | |
| download_root: str = None, | |
| ) -> S3Tokenizer: | |
| """ | |
| Load a S3Tokenizer ASR model | |
| Parameters | |
| ---------- | |
| name : str | |
| one of the official model names listed by | |
| `s3tokenizer.available_models()`, or path to a model checkpoint | |
| containing the model dimensions and the model state_dict. | |
| download_root: str | |
| path to download the model files; by default, | |
| it uses "~/.cache/s3tokenizer" | |
| Returns | |
| ------- | |
| model : S3Tokenizer | |
| The S3Tokenizer model instance | |
| """ | |
| if download_root is None: | |
| default = os.path.join(os.path.expanduser("~"), ".cache") | |
| download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), | |
| "s3tokenizer") | |
| if name in _MODELS: | |
| checkpoint_file = _download(name, download_root) | |
| elif os.path.isfile(name): | |
| checkpoint_file = name | |
| else: | |
| raise RuntimeError( | |
| f"Model {name} not found; available models = {available_models()}") | |
| if 'v2' in name: | |
| model = S3TokenizerV2(name) | |
| else: | |
| model = S3Tokenizer(name) | |
| model.init_from_onnx(checkpoint_file) | |
| return model | |