| | import os, hashlib |
| | import requests |
| | from tqdm import tqdm |
| | import importlib |
| |
|
# Download URL for each known checkpoint name.
URL_MAP = {
    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1",
}

# File name (relative to the caller-supplied root) each checkpoint is stored as.
CKPT_MAP = {
    "vgg_lpips": "vgg.pth",
}

# Expected MD5 hex digest of each checkpoint file.
MD5_MAP = {
    "vgg_lpips": "d507d7349b931f0638a25a48a722f98a",
}
| |
|
| |
|
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"package.module.Name"`` to an object.

    Parameters
    ----------
    string : str
        Dotted path; everything before the last dot is imported as a module
        and the part after it is looked up as an attribute of that module.
    reload : bool
        If ``True``, the module is re-imported via :func:`importlib.reload`
        before the attribute lookup.

    Returns
    -------
    The attribute named by the last path component.

    Raises
    ------
    ValueError
        If ``string`` contains no dot (the two-way unpacking fails).
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    target_module = importlib.import_module(module_name, package=None)
    return getattr(target_module, attr_name)
| |
|
| |
|
def instantiate_from_config(config):
    """Build the object described by ``config``.

    ``config["target"]`` is a dotted path resolved with
    :func:`get_obj_from_str`; the resulting callable is invoked with the
    keyword arguments found under ``config["params"]`` (none if the key is
    absent).

    Raises
    ------
    KeyError
        If ``config`` has no ``"target"`` entry.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    factory = get_obj_from_str(config["target"])
    kwargs = config.get("params", dict())
    return factory(**kwargs)
| |
|
| |
|
def download(url, local_path, chunk_size=1024):
    """Stream the resource at ``url`` into ``local_path`` with a progress bar.

    Parameters
    ----------
    url : str
        Address passed to ``requests.get``.
    local_path : str
        Destination file; its parent directory is created if necessary.
    chunk_size : int
        Number of bytes requested per iteration of the response stream.

    Raises
    ------
    requests.HTTPError
        If the server answers with an error status, so that an error page is
        never written to ``local_path`` as if it were the payload.
    """
    parent_dir = os.path.dirname(local_path)
    if parent_dir:
        # A bare file name has no directory part; os.makedirs("") would raise.
        os.makedirs(parent_dir, exist_ok=True)
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        # The header may be missing; the bar then runs with total=0.
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        # Advance by what was actually received; the final
                        # chunk is usually shorter than chunk_size.
                        pbar.update(len(data))
| |
|
| |
|
def md5_hash(path, chunk_size=1 << 20):
    """Return the hexadecimal MD5 digest of the file at ``path``.

    The file is hashed incrementally so that large files are never held in
    memory in full.

    Parameters
    ----------
    path : str
        File to hash; opened in binary mode.
    chunk_size : int
        Number of bytes read per step.

    Returns
    -------
    str
        The digest as a lowercase hex string.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter() with a sentinel stops once read() returns b"" at end of file.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
| |
|
| |
|
def get_ckpt_path(name, root, check=False):
    """Return the local path of checkpoint ``name``, downloading it if needed.

    Parameters
    ----------
    name : str
        Key into ``URL_MAP`` / ``CKPT_MAP`` / ``MD5_MAP``.
    root : str
        Directory in which the checkpoint file lives (or will be stored).
    check : bool
        If ``True``, an already existing file is hashed and downloaded again
        when its MD5 digest does not match ``MD5_MAP[name]``.

    Returns
    -------
    str
        ``os.path.join(root, CKPT_MAP[name])``.

    Raises
    ------
    AssertionError
        If ``name`` is unknown, or if a freshly downloaded file does not have
        the expected MD5 digest (the offending digest is the exception
        argument).
    """
    # Raised explicitly rather than via `assert`, so the validation and the
    # integrity check still run under `python -O`.
    if name not in URL_MAP:
        raise AssertionError("Unknown checkpoint name: {!r}".format(name))
    path = os.path.join(root, CKPT_MAP[name])
    needs_download = not os.path.exists(path) or (
        check and md5_hash(path) != MD5_MAP[name]
    )
    if needs_download:
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        if md5 != MD5_MAP[name]:
            raise AssertionError(md5)
    return path
| |
|
| |
|
class KeyNotFoundError(Exception):
    """Error describing a failed lookup along a path of keys.

    The constructor arguments are kept as the attributes ``cause``, ``keys``
    and ``visited``; the exception message lists whichever of ``keys`` and
    ``visited`` were given, followed by the cause.
    """

    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited

        # Optional sections appear only when the corresponding value is set.
        optional_parts = (
            ("Key not found: {}", keys),
            ("Visited: {}", visited),
        )
        lines = [
            template.format(value)
            for template, value in optional_parts
            if value is not None
        ]
        lines.append("Cause:\n{}".format(cause))
        super().__init__("\n".join(lines))
| |
|
| |
|
def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place.

    Parameters
    ----------
    list_or_dict : list or dict
        Possibly nested list or dictionary. Only ``dict`` instances are
        indexed with the raw key string; every other node is indexed with
        ``int(key_part)``.
    key : str
        key/to/value, path like string describing all keys necessary to
        consider to get to the desired value. List indices can also be
        passed here.
    splitval : str
        String that defines the delimiter between keys of the
        different depth levels in `key`.
    default : obj
        Value returned if :attr:`key` is not found. ``None`` means "no
        default", so ``None`` itself cannot be used as a fallback value.
    expand : bool
        Whether to expand callable nodes on the path or not.
    pass_success : bool
        If ``True``, return a tuple ``(value, success)`` where ``success``
        is ``False`` when ``default`` was substituted.

    Returns
    -------
    The desired value or if :attr:`default` is not ``None`` and the
    :attr:`key` is not found returns ``default``. With
    :attr:`pass_success` the value is paired with the success flag.

    Raises
    ------
    KeyNotFoundError if ``key`` not in ``list_or_dict`` and :attr:`default`
    is ``None``.
    """
    keys = key.split(splitval)

    success = True
    node = list_or_dict
    try:
        visited = []
        parent = None
        last_index = None
        for part in keys:
            if callable(node):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                node = node()
                # A callable root has no container to write the result into.
                if parent is not None:
                    parent[last_index] = node

            parent = node
            try:
                # Remember the index actually used (int for sequences) so a
                # later write-back addresses the same slot.
                last_index = part if isinstance(node, dict) else int(part)
                node = node[last_index]
            except (KeyError, IndexError, ValueError, TypeError) as e:
                # TypeError covers nodes that cannot be indexed at all.
                raise KeyNotFoundError(e, keys=keys, visited=visited) from e

            visited.append(part)

        # Expand the retrieved value itself and store the result in place.
        if expand and callable(node):
            node = node()
            parent[last_index] = node
    except KeyNotFoundError:
        if default is None:
            raise
        node = default
        success = False

    if pass_success:
        return node, success
    return node
| |
|
| |
|
if __name__ == "__main__":
    # Manual smoke test: build a small nested config and look up one key.
    from omegaconf import OmegaConf

    nested = {"cc1": 1, "cc2": 2}
    config = OmegaConf.create({"keya": "a", "keyb": "b", "keyc": nested})
    print(config)
    # NOTE(review): `retrieve` only treats `dict` instances as mappings and
    # indexes everything else with int(key) -- confirm that the object
    # returned by OmegaConf.create takes the intended branch.
    retrieve(config, "keya")
| |
|