File size: 5,048 Bytes
f768eb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# Copyright (c) 2023 OpenAI. (authors: Whisper Team)
#               2024 Tsinghua Univ. (authors: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modified from
    https://github.com/openai/whisper/blob/main/whisper/__init__.py
"""

import hashlib
import os
import urllib
import warnings
from typing import List, Union

from tqdm import tqdm

from s3tokenizer.model_v2 import S3TokenizerV2

from .model import S3Tokenizer
from .utils import (load_audio, log_mel_spectrogram, make_non_pad_mask,
                    mask_to_bias, onnx2torch, padding, merge_tokenized_segments)

# Names exported by a star-import of this package (all re-exported from
# `.utils`).
# NOTE(review): `load_model` and `available_models`, defined later in this
# module, are not listed here, so a star-import does not expose them —
# confirm this is intended.
__all__ = [
    'load_audio', 'log_mel_spectrogram', 'make_non_pad_mask', 'mask_to_bias',
    'onnx2torch', 'padding', 'merge_tokenized_segments'
]
# Download URL of the ONNX checkpoint for each official model name.
# The keys are the names accepted by `load_model` and returned by
# `available_models`.
_MODELS = {
    "speech_tokenizer_v1":
    "https://www.modelscope.cn/models/iic/cosyvoice-300m/"
    "resolve/master/speech_tokenizer_v1.onnx",
    "speech_tokenizer_v1_25hz":
    "https://www.modelscope.cn/models/iic/CosyVoice-300M-25Hz/"
    "resolve/master/speech_tokenizer_v1.onnx",
    "speech_tokenizer_v2_25hz":
    "https://www.modelscope.cn/models/iic/CosyVoice2-0.5B/"
    "resolve/master/speech_tokenizer_v2.onnx",
}

# Expected SHA256 hex digest of each checkpoint, keyed by the same model
# names as `_MODELS`. `_download` compares a file's digest against these
# values to decide whether a cached or freshly downloaded file is valid.
_SHA256S = {
    "speech_tokenizer_v1":
    "23b5a723ed9143aebfd9ffda14ac4c21231f31c35ef837b6a13bb9e5488abb1e",
    "speech_tokenizer_v1_25hz":
    "56285ddd4a83e883ee0cb9f8d69c1089b53a94b1f78ff7e4a0224a27eb4cb486",
    "speech_tokenizer_v2_25hz":
    "d43342aa12163a80bf07bffb94c9de2e120a8df2f9917cd2f642e7f4219c6f71",
}


def _download(name: str, root: str) -> Union[bytes, str]:
    """Download the ONNX checkpoint of an official model and return its path.

    The checkpoint is stored as ``<root>/<name>.onnx``. When a file whose
    SHA256 digest matches the expected one is already present, it is reused
    and nothing is downloaded.

    Parameters
    ----------
    name : str
        key of `_MODELS` / `_SHA256S` identifying the checkpoint.
    root : str
        directory that holds the checkpoint; created if it does not exist.

    Returns
    -------
    str
        path of the verified checkpoint file.

    Raises
    ------
    KeyError
        if `name` is not a key of `_SHA256S` / `_MODELS`.
    RuntimeError
        if the target path exists but is not a regular file, or if the
        freshly downloaded file does not match the expected checksum.
    """
    # A bare `import urllib` does not guarantee that the `request` submodule
    # is loaded, so import it explicitly before touching `urllib.request`.
    import urllib.request

    os.makedirs(root, exist_ok=True)

    expected_sha256 = _SHA256S[name]
    url = _MODELS[name]
    download_target = os.path.join(root, f"{name}.onnx")

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(
            f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if _sha256_of_file(download_target) == expected_sha256:
            return download_target
        warnings.warn(
            f"{download_target} exists, but the SHA256 checksum does not"
            " match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target,
                                                     "wb") as output:
        # The Content-Length header is optional; without it the progress
        # bar simply runs with an unknown total instead of crashing.
        content_length = source.info().get("Content-Length")
        with tqdm(
                total=None if content_length is None else int(content_length),
                ncols=80,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
                desc="Downloading onnx checkpoint",
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    # Re-read what was actually written to disk so that a truncated or
    # corrupted download is detected before the path is handed back.
    if _sha256_of_file(download_target) != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not"
            " match. Please retry loading the model.")

    return download_target


def _sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the SHA256 hex digest of the file at `path`.

    The file is read in pieces of `chunk_size` bytes so that a large
    checkpoint is never held in memory as a whole, and the handle is closed
    as soon as hashing finishes.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def available_models() -> List[str]:
    """Return the official model names, i.e. the keys of `_MODELS`."""
    return list(_MODELS)


def load_model(
    name: str,
    download_root: Union[str, None] = None,
) -> S3Tokenizer:
    """
    Load an S3Tokenizer speech tokenizer from an ONNX checkpoint

    Parameters
    ----------
    name : str
        one of the official model names listed by
        `s3tokenizer.available_models()`, or the path to an existing local
        checkpoint file that `init_from_onnx` can read.
    download_root: str, optional
        directory in which downloaded checkpoints are stored; by default
        "$XDG_CACHE_HOME/s3tokenizer" when XDG_CACHE_HOME is set, otherwise
        "~/.cache/s3tokenizer". Only used for official model names.

    Returns
    -------
    model : S3Tokenizer
        The tokenizer instance, initialised from the checkpoint. An
        `S3TokenizerV2` is built instead when `name` contains "v2".

    Raises
    ------
    RuntimeError
        if `name` is neither an official model name nor an existing file.
        Errors raised while downloading an official checkpoint propagate
        unchanged.
    """

    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default),
                                     "s3tokenizer")

    if name in _MODELS:
        checkpoint_file = _download(name, download_root)
    elif os.path.isfile(name):
        checkpoint_file = name
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}")

    # NOTE(review): this is a substring test on the whole `name`. For a
    # local checkpoint, a path with "v2" anywhere in it (directory names
    # included) selects S3TokenizerV2 — confirm this is intended.
    # `name` (the model name or the local path) is also what the
    # constructor receives.
    if 'v2' in name:
        model = S3TokenizerV2(name)
    else:
        model = S3Tokenizer(name)
    model.init_from_onnx(checkpoint_file)

    return model