Spaces:
Sleeping
Sleeping
| """ | |
| Routines for loading DeepSpeech model. | |
| """ | |
| __all__ = ['get_deepspeech_model_file'] | |
| import os | |
| import zipfile | |
| import logging | |
| import hashlib | |
| deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features' | |
def get_deepspeech_model_file(local_model_store_dir_path=os.path.join("~", ".tensorflow", "models")):
    """
    Return location for the pretrained model on the local file system. This function will download
    from the online model zoo when the model cannot be found or has a hash mismatch. The root
    directory will be created if it doesn't exist.

    Parameters
    ----------
    local_model_store_dir_path : str, default $TENSORFLOW_HOME/models
        Location for keeping the model parameters.

    Returns
    -------
    file_path : str
        Path to the requested pretrained model file.

    Raises
    ------
    ValueError
        If the downloaded file's SHA-1 hash does not match the expected hash.
    """
    sha1_hash = "b90017e816572ddce84f5843f1fa21e6a377975e"
    file_name = "deepspeech-0_1_0-b90017e8.pb"
    local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path)
    file_path = os.path.join(local_model_store_dir_path, file_name)
    if os.path.exists(file_path):
        if _check_sha1(file_path, sha1_hash):
            # Cached copy is valid -- nothing to download.
            return file_path
        logging.warning("Mismatch in the content of model file detected. Downloading again.")
    else:
        # Lazy %s formatting so the message is only built if this level is enabled.
        logging.info("Model file not found. Downloading to %s.", file_path)

    # exist_ok avoids the check-then-create race of `if not exists: makedirs`.
    os.makedirs(local_model_store_dir_path, exist_ok=True)

    zip_file_path = file_path + ".zip"
    _download(
        url="{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip".format(
            repo_url=deepspeech_features_repo_url,
            repo_release_tag="v0.0.1",
            file_name=file_name),
        path=zip_file_path,
        overwrite=True)
    try:
        with zipfile.ZipFile(zip_file_path) as zf:
            zf.extractall(local_model_store_dir_path)
    finally:
        # Always clean up the temporary archive, even if extraction fails.
        os.remove(zip_file_path)

    if _check_sha1(file_path, sha1_hash):
        return file_path
    raise ValueError("Downloaded file has different hash. Please try again.")
def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
    """
    Download a given URL.

    Parameters
    ----------
    url : str
        URL to download.
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with the same name as in the URL.
    overwrite : bool, optional
        Whether to overwrite the destination file if it already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
        but doesn't match.
    retries : integer, default 5
        The number of times to attempt the download in case of failure or non-200 return codes.
    verify_ssl : bool, default True
        Verify SSL certificates.

    Returns
    -------
    str
        The file path of the downloaded file.

    Raises
    ------
    ValueError
        If the URL has no file name component and `path` is not set, or `retries` is negative.
    RuntimeError
        If the server does not answer with status 200 on the final attempt.
    """
    import warnings
    try:
        import requests
    except ImportError:
        # Defer the hard failure until requests is actually needed, so importing
        # this module without requests installed still works.
        class requests_failed_to_import(object):
            pass
        requests = requests_failed_to_import

    if path is None:
        fname = url.split("/")[-1]
        # Explicit raise instead of assert: asserts are stripped under `python -O`.
        if not fname:
            raise ValueError("Can't construct file-name from this URL. "
                             "Please set the `path` option manually.")
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split("/")[-1])
        else:
            fname = path
    if retries < 0:
        raise ValueError("Number of retries should be at least 0")

    if not verify_ssl:
        warnings.warn(
            "Unverified HTTPS request is being made (verify_ssl=False). "
            "Adding certificate verification is strongly advised.")

    # Download only if forced, missing, or present-but-corrupt.
    if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        # exist_ok avoids the check-then-create race of `if not exists: makedirs`.
        os.makedirs(dirname, exist_ok=True)
        while retries + 1 > 0:
            # Disable pylint too broad Exception
            # pylint: disable=W0703
            try:
                print("Downloading {} from {}...".format(fname, url))
                # Context manager releases the underlying HTTP connection even
                # when streaming is aborted by an exception.
                with requests.get(url, stream=True, verify=verify_ssl) as r:
                    if r.status_code != 200:
                        raise RuntimeError("Failed downloading url {}".format(url))
                    with open(fname, "wb") as f:
                        for chunk in r.iter_content(chunk_size=1024):
                            if chunk:  # filter out keep-alive new chunks
                                f.write(chunk)
                if sha1_hash and not _check_sha1(fname, sha1_hash):
                    raise UserWarning("File {} is downloaded but the content hash does not match."
                                      " The repo may be outdated or download may be incomplete. "
                                      "If the `repo_url` is overridden, consider switching to "
                                      "the default repo.".format(fname))
                break
            except Exception as e:
                retries -= 1
                if retries <= 0:
                    raise e
                print("download failed, retrying, {} attempt{} left"
                      .format(retries, "s" if retries > 1 else ""))
    return fname
| def _check_sha1(filename, sha1_hash): | |
| """ | |
| Check whether the sha1 hash of the file content matches the expected hash. | |
| Parameters | |
| ---------- | |
| filename : str | |
| Path to the file. | |
| sha1_hash : str | |
| Expected sha1 hash in hexadecimal digits. | |
| Returns | |
| ------- | |
| bool | |
| Whether the file content matches the expected hash. | |
| """ | |
| sha1 = hashlib.sha1() | |
| with open(filename, "rb") as f: | |
| while True: | |
| data = f.read(1048576) | |
| if not data: | |
| break | |
| sha1.update(data) | |
| return sha1.hexdigest() == sha1_hash | |