Spaces:
Running
Running
| import hashlib | |
| import os | |
| import tarfile | |
| import urllib.request | |
| from tqdm import tqdm | |
def print_arguments(args):
    """Print every attribute of *args* as a ``name: value`` line.

    The listing is framed by a header and a footer rule. Attributes are
    read with ``vars(args)``, so *args* must expose a ``__dict__`` (for
    example an ``argparse.Namespace``).
    """
    header = "----------- Configuration Arguments -----------"
    footer = "------------------------------------------------"
    print(header)
    settings = vars(args)
    for name in settings:
        # ``!s`` forces str(), matching the original ``%s`` formatting.
        print(f"{name!s}: {settings[name]!s}")
    print(footer)
def strtobool(val):
    """Convert a textual truth value to a ``bool``.

    The comparison is case-insensitive. Accepted true spellings are
    ``y, yes, t, true, on, 1``; accepted false spellings are
    ``n, no, f, false, off, 0``.

    Raises:
        ValueError: if the lowercased text is not one of the spellings above.
    """
    truthy = ('y', 'yes', 't', 'true', 'on', '1')
    falsy = ('n', 'no', 'f', 'false', 'off', '0')

    normalized = val.lower()
    if normalized in truthy:
        return True
    if normalized in falsy:
        return False
    # The error reports the lowercased text, not the caller's original casing.
    raise ValueError("invalid truth value %r" % (normalized,))
def str_none(val):
    """Map the literal text ``'None'`` to ``None``; return anything else unchanged."""
    return None if val == 'None' else val
def add_arguments(argname, type, default, help, argparser, **kwargs):
    """Register a ``--<argname>`` option on *argparser*.

    Args:
        argname: option name without the leading dashes.
        type: converter for the option value. ``bool`` is swapped for
            ``strtobool`` and ``str`` for ``str_none``; any other value is
            passed to argparse as given.
        default: default value of the option.
        help: help text; ``' Default: %(default)s.'`` is appended to it.
        argparser: object whose ``add_argument`` method is called.
        **kwargs: forwarded unchanged to ``add_argument``.
    """
    converter = type
    if converter == bool:
        converter = strtobool
    if converter == str:
        converter = str_none

    flag = "--" + argname
    help_text = help + ' Default: %(default)s.'
    argparser.add_argument(flag,
                           default=default,
                           type=converter,
                           help=help_text,
                           **kwargs)
def md5file(fname):
    """Return the hexadecimal MD5 digest of the file at *fname*.

    The file is read in binary mode in 4096-byte chunks, so it is never
    loaded into memory as a whole.

    Args:
        fname: path of the file to hash.

    Returns:
        The digest as a lowercase hexadecimal string.
    """
    hash_md5 = hashlib.md5()
    # The context manager closes the handle even if a read fails midway.
    with open(fname, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()
def download(url, md5sum, target_dir):
    """Download file from url to target_dir, and check md5sum.

    The local file name is the last ``/``-separated component of *url*.
    If that file already exists in *target_dir* and its MD5 digest equals
    *md5sum*, nothing is downloaded.

    Args:
        url: address of the file to fetch.
        md5sum: expected hexadecimal MD5 digest of the file.
        target_dir: directory to store the file in; created if missing.

    Returns:
        Path of the local file.

    Raises:
        RuntimeError: if the digest of the downloaded file differs from
            *md5sum*.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(target_dir, exist_ok=True)
    filepath = os.path.join(target_dir, url.split("/")[-1])

    if os.path.exists(filepath) and md5file(filepath) == md5sum:
        print(f"File exists, skip downloading. ({filepath})")
        return filepath

    print(f"Downloading {url} to {filepath} ...")
    with urllib.request.urlopen(url) as source, open(filepath, "wb") as output:
        # Content-Length can be absent (e.g. chunked responses); int(None)
        # would raise, so use an unknown total for the progress bar then.
        content_length = source.info().get("Content-Length")
        total = int(content_length) if content_length is not None else None
        with tqdm(total=total, ncols=80, unit='iB', unit_scale=True,
                  unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break
                output.write(buffer)
                loop.update(len(buffer))

    print(f"\nMD5 Chesksum {filepath} ...")
    if md5file(filepath) != md5sum:
        raise RuntimeError("MD5 checksum failed.")
    return filepath
def unpack(filepath, target_dir, rm_tar=False):
    """Unpack the file to the target_dir.

    Args:
        filepath: path of the tar archive to extract.
        target_dir: directory the archive members are extracted into.
        rm_tar: when true, delete the archive after a successful extraction.
    """
    print("Unpacking %s ..." % filepath)
    # The context manager closes the archive even if extraction fails.
    # NOTE(review): extractall() trusts member paths inside the archive;
    # only use this on archives from a trusted source.
    with tarfile.open(filepath) as tar:
        tar.extractall(target_dir)
    if rm_tar:
        os.remove(filepath)
def make_inputs_require_grad(module, input, output):
    """Switch gradient tracking on for *output* in place.

    Calls ``output.requires_grad_(True)``; *module* and *input* are accepted
    but unused. NOTE(review): the three-argument signature looks like a
    forward-hook callback, but the registration is not visible here, so
    confirm at the call site.
    """
    output.requires_grad_(True)