| | """ |
| | Routines for loading DeepSpeech model. |
| | """ |
| |
|
| | __all__ = ['get_deepspeech_model_file'] |
| |
|
| | import os |
| | import zipfile |
| | import logging |
| | import hashlib |
| |
|
| |
|
# Base URL of the GitHub repository whose release assets host the pretrained model archive.
deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features'
| |
|
| |
|
def get_deepspeech_model_file(local_model_store_dir_path=os.path.join("~", ".tensorflow", "models")):
    """
    Return the location of the pretrained model on the local file system. This function will download from the
    online model zoo when the model cannot be found or has a hash mismatch. The root directory will be created
    if it doesn't exist.

    Parameters
    ----------
    local_model_store_dir_path : str, default $TENSORFLOW_HOME/models
        Location for keeping the model parameters.

    Returns
    -------
    file_path
        Path to the requested pretrained model file.

    Raises
    ------
    ValueError
        If the downloaded file's sha1 hash does not match the expected hash.
    """
    # Pinned release artifact: file name embeds a hash prefix, full sha1 verifies integrity.
    sha1_hash = "b90017e816572ddce84f5843f1fa21e6a377975e"
    file_name = "deepspeech-0_1_0-b90017e8.pb"
    local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path)
    file_path = os.path.join(local_model_store_dir_path, file_name)
    if os.path.exists(file_path):
        if _check_sha1(file_path, sha1_hash):
            # Cached copy is intact — reuse it.
            return file_path
        else:
            logging.warning("Mismatch in the content of model file detected. Downloading again.")
    else:
        logging.info("Model file not found. Downloading to {}.".format(file_path))

    # exist_ok avoids a TOCTOU race between an existence check and directory creation
    # (e.g. two processes fetching the model concurrently).
    os.makedirs(local_model_store_dir_path, exist_ok=True)

    # The release asset is a zip wrapping the .pb file; download next to the target,
    # extract in place, then discard the archive.
    zip_file_path = file_path + ".zip"
    _download(
        url="{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip".format(
            repo_url=deepspeech_features_repo_url,
            repo_release_tag="v0.0.1",
            file_name=file_name),
        path=zip_file_path,
        overwrite=True)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(local_model_store_dir_path)
    os.remove(zip_file_path)

    # Verify what was actually extracted before handing the path to the caller.
    if _check_sha1(file_path, sha1_hash):
        return file_path
    else:
        raise ValueError("Downloaded file has different hash. Please try again.")
| |
|
| |
|
def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
    """
    Download a given URL.

    Parameters
    ----------
    url : str
        URL to download
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with same name as in url.
    overwrite : bool, optional
        Whether to overwrite destination file if already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
        but doesn't match.
    retries : integer, default 5
        The number of times to attempt the download in case of failure or non 200 return codes
    verify_ssl : bool, default True
        Verify SSL certificates.

    Returns
    -------
    str
        The file path of the downloaded file.
    """
    import warnings
    try:
        import requests
    except ImportError:
        # Defer the failure until requests is actually used, so merely importing
        # this module does not require requests to be installed.
        class requests_failed_to_import(object):
            pass
        requests = requests_failed_to_import

    if path is None:
        fname = url.split("/")[-1]
        # Empty string filename is invalid (e.g. URL ending in '/').
        assert fname, "Can't construct file-name from this URL. Please set the `path` option manually."
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            # A directory path: keep the URL's basename inside it.
            fname = os.path.join(path, url.split("/")[-1])
        else:
            fname = path
    assert retries >= 0, "Number of retries should be at least 0"

    if not verify_ssl:
        warnings.warn(
            "Unverified HTTPS request is being made (verify_ssl=False). "
            "Adding certificate verification is strongly advised.")

    # Skip the download entirely when a valid file is already in place.
    if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        # exist_ok avoids a race if another process creates the directory concurrently.
        os.makedirs(dirname, exist_ok=True)
        # `retries + 1` total attempts: the initial try plus `retries` re-tries.
        while retries + 1 > 0:
            try:
                print("Downloading {} from {}...".format(fname, url))
                r = requests.get(url, stream=True, verify=verify_ssl)
                try:
                    if r.status_code != 200:
                        raise RuntimeError("Failed downloading url {}".format(url))
                    with open(fname, "wb") as f:
                        for chunk in r.iter_content(chunk_size=1024):
                            if chunk:  # filter out keep-alive chunks
                                f.write(chunk)
                finally:
                    # Fix: streamed responses hold the connection open until closed;
                    # release it even when an error path raises above.
                    r.close()
                if sha1_hash and not _check_sha1(fname, sha1_hash):
                    raise UserWarning("File {} is downloaded but the content hash does not match."
                                      " The repo may be outdated or download may be incomplete. "
                                      "If the `repo_url` is overridden, consider switching to "
                                      "the default repo.".format(fname))
                break
            except Exception as e:
                retries -= 1
                if retries <= 0:
                    raise e
                else:
                    print("download failed, retrying, {} attempt{} left"
                          .format(retries, "s" if retries > 1 else ""))

    return fname
| |
|
| |
|
| | def _check_sha1(filename, sha1_hash): |
| | """ |
| | Check whether the sha1 hash of the file content matches the expected hash. |
| | |
| | Parameters |
| | ---------- |
| | filename : str |
| | Path to the file. |
| | sha1_hash : str |
| | Expected sha1 hash in hexadecimal digits. |
| | |
| | Returns |
| | ------- |
| | bool |
| | Whether the file content matches the expected hash. |
| | """ |
| | sha1 = hashlib.sha1() |
| | with open(filename, "rb") as f: |
| | while True: |
| | data = f.read(1048576) |
| | if not data: |
| | break |
| | sha1.update(data) |
| |
|
| | return sha1.hexdigest() == sha1_hash |
| |
|