|
|
""" |
|
|
This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/utils/download.py |
|
|
""" |
|
|
import os |
|
|
import hashlib |
|
|
import requests |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
def check_sha1(filename, sha1_hash):
    """Check whether the sha1 hash of the file content matches the expected hash.

    The comparison is case-insensitive and only covers the common prefix of
    the two hex strings, so an abbreviated expected hash (e.g. the first 8
    hex digits) is accepted.

    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits (upper or lower case).

    Returns
    -------
    bool
        Whether the file content matches the expected hash.
    """
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        # Read in 1 MiB pieces so large files are never loaded into memory at once.
        for data in iter(lambda: f.read(1048576), b''):
            sha1.update(data)

    sha1_file = sha1.hexdigest()
    # hexdigest() is always lower-case; normalise the expected value so that
    # an upper-case hash is not reported as a mismatch.
    expected = sha1_hash.lower()
    # Compare only the shared prefix so that truncated hashes keep working.
    prefix_len = min(len(sha1_file), len(expected))
    return sha1_file[:prefix_len] == expected[:prefix_len]
|
|
|
|
|
|
|
|
def download_file(url, path=None, overwrite=False, sha1_hash=None):
    """Download a given URL.

    The download is skipped when the destination file already exists, unless
    ``overwrite`` is set or ``sha1_hash`` is given and does not match the
    existing file. Data is first written to a temporary file next to the
    destination and only moved into place once the transfer has finished, so
    an interrupted download never leaves a truncated file at the destination.

    Parameters
    ----------
    url : str
        URL to download
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with same name as in url. If ``path`` is an
        existing directory, the file is stored inside it under the name
        taken from the url.
    overwrite : bool, optional
        Whether to overwrite destination file if already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
        but doesn't match.

    Returns
    -------
    str
        The file path of the downloaded file.

    Raises
    ------
    RuntimeError
        If the server answers with a status code other than 200.
    UserWarning
        If ``sha1_hash`` is given and the downloaded content does not match it.
    """
    url_basename = url.split('/')[-1]
    if path is None:
        fname = url_basename
    else:
        path = os.path.expanduser(path)
        fname = os.path.join(path, url_basename) if os.path.isdir(path) else path

    needs_download = (
        overwrite
        or not os.path.exists(fname)
        or (sha1_hash and not check_sha1(fname, sha1_hash))
    )
    if not needs_download:
        return fname

    dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
    # exist_ok avoids a race when another process creates the directory
    # between an existence check and the makedirs call.
    os.makedirs(dirname, exist_ok=True)

    print('Downloading %s from %s...' % (fname, url))
    # Stream into a sibling temporary file; the pid keeps concurrent
    # processes downloading the same target from clobbering each other.
    tmp_fname = '%s.part.%d' % (fname, os.getpid())
    try:
        # The context manager releases the streamed connection on every
        # exit path, including the error raised for a non-200 status.
        with requests.get(url, stream=True) as r:
            if r.status_code != 200:
                raise RuntimeError("Failed downloading url %s" % url)

            chunks = r.iter_content(chunk_size=1024)
            total_length = r.headers.get('content-length')
            if total_length is not None:
                total_length = int(total_length)
                # Round up: a trailing partial chunk still counts as one step.
                chunks = tqdm(chunks,
                              total=(total_length + 1023) // 1024,
                              unit='KB',
                              unit_scale=False,
                              dynamic_ncols=True)

            with open(tmp_fname, 'wb') as f:
                for chunk in chunks:
                    if chunk:  # skip empty keep-alive chunks
                        f.write(chunk)

        # Atomically publish the completed download at the destination.
        os.replace(tmp_fname, fname)
    finally:
        # After a successful replace the temporary file no longer exists;
        # after a failure this removes the partial download.
        if os.path.exists(tmp_fname):
            os.remove(tmp_fname)

    if sha1_hash and not check_sha1(fname, sha1_hash):
        raise UserWarning('File {} is downloaded but the content hash does not match. '
                          'The repo may be outdated or download may be incomplete. '
                          'If the "repo_url" is overridden, consider switching to '
                          'the default repo.'.format(fname))

    return fname
|
|
|