| import os |
| import os.path as op |
# Locate the repository root.
# The component "Recurrent-Parameter-Generation" is searched for in this script's
# *absolute* path, so the lookup also works when __file__ is relative.
# If that component is missing (e.g. the repo was cloned or unzipped under a
# different folder name), fall back to the grandparent of this script. In the
# layout documented below, that grandparent is the repository folder
# (<root>/quick_start/set_configs.py). The basename assertion below then shows the
# rename instructions instead of an opaque ValueError from list.index().
_script_path = op.abspath(__file__)
_path_parts = _script_path.split(os.sep)
if "Recurrent-Parameter-Generation" in _path_parts:
    # Splitting an absolute path keeps the leading "" (POSIX) or the drive "C:"
    # (Windows) as the first element. Joining the prefix therefore yields a valid
    # absolute path on both platforms.
    root = os.sep.join(_path_parts[:_path_parts.index("Recurrent-Parameter-Generation") + 1])
else:
    root = op.dirname(op.dirname(_script_path))
del _script_path, _path_parts

assert op.exists(root), "Cannot find the executing root."
assert op.basename(root) == "Recurrent-Parameter-Generation", \
    f"""
You need to rename the repository folder to "Recurrent-Parameter-Generation" manually.
Because the whole project depends on this name.
The file structure is as follow:
└─Recurrent-Parameter-Generation
  ├─dataset
  │ ├─cifar10_cnnmedium
  │ ├─...(total 21 folders)
  │ ├─__init__.py
  │ ├─config.json
  │ ├─dataset.py
  │ └─register.py
  ├─model
  │ ├─__init__.py
  │ ├─denoiser.py
  │ ├─diffusion.py
  │ └─...(total 8 files)
  ├─quick_start
  │ ├─set_configs.py
  │ └─auto_start.sh
  ├─workspace
  │ ├─main
  │ ├─evaluate
  │ ├─...(total 6 folders)
  │ └─config.json
  ├─README.md
  └─requirements.txt
"""
|
|
|
|
# Step 1: choose where small datasets (CIFAR10, CIFAR100, ...) are downloaded.
print("\n1. Set an \033[91mABSOLUTE\033[0m path to download your small dataset, such as CIFAR10 and CIFAR100")
# Default: a "Dataset" folder placed next to the repository folder.
default_dataset_root = op.join(op.dirname(op.abspath(root)), 'Dataset')
# An empty (or whitespace-only) answer selects the default.
dataset_root = input(f"[{default_dataset_root} (default & \033[32mrecommanded\033[0m)]: ").strip() or default_dataset_root
# This value is written to dataset/config.json and read later from other working
# directories. Resolve "~" and any relative input to an absolute path now, as the
# prompt requires.
dataset_root = op.abspath(op.expanduser(dataset_root))
print(f"\033[32mdataset_root is set to {dataset_root}\033[0m")
|
|
|
|
# Step 2: optionally point at a local ImageNet1k copy with train/ and test/ splits.
print("\n2. Set the \033[91mABSOLUTE\033[0m path to your ImageNet1k dataset. "
      "\033[32m(Press ENTER if you don't want to use ImageNet1k)\033[0m")
print("""The ImageNet1k dataset should be organized as follow:
└─ImageNet1k
  ├─train
  │ ├─n01443537
  │ ├─n01484850
  │ ├─n########
  └─test
    ├─n01443537
    ├─n01484850
    └─n########""")
# Strip the answer so that a stray space is not mistaken for a path.
imagenet_root = input("[None (default)]: ").strip()
if not imagenet_root:
    # No ImageNet1k: both split paths are stored as None (JSON null).
    print("\033[32mWe don't use ImageNet1k.\033[0m")
    imagenet_root_train = None
    imagenet_root_test = None
else:
    # Normalize "~" and relative input, because the path is persisted to config.json.
    imagenet_root = op.abspath(op.expanduser(imagenet_root))
    print(f"\033[32mimagenet_root is set to {imagenet_root}\033[0m")
    imagenet_root_train = op.join(imagenet_root, "train")
    imagenet_root_test = op.join(imagenet_root, "test")
    # Each split must be a directory of class folders, not merely an existing path.
    assert op.isdir(imagenet_root_train), f"{imagenet_root_train} is not existed."
    assert op.isdir(imagenet_root_test), f"{imagenet_root_test} is not existed."
|
|
|
|
# Step 3: decide whether to log to Weights & Biases and collect the API key.
print("\n3. Do you want to use wandb?")
default_use_wandb = True
use_wandb = input("[True (default & \033[32mrecommanded\033[0m)) / False]: ").strip()
# Parse the answer explicitly rather than eval()-ing raw user input. eval() would
# run arbitrary expressions, fail with NameError on "true"/"false", and let
# non-bool values such as "1" end up in workspace/config.json.
if use_wandb == "":
    use_wandb = default_use_wandb
elif use_wandb.lower() in ("true", "t", "yes", "y", "1"):
    use_wandb = True
elif use_wandb.lower() in ("false", "f", "no", "n", "0"):
    use_wandb = False
else:
    raise ValueError(f"Cannot interpret {use_wandb!r} as True or False.")
print(f"\033[32muse_wandb is set to {use_wandb}\033[0m")
if use_wandb:
    wandb_api_key = input("Set your wandb api key: ").strip()
    assert wandb_api_key != "", "You need to set an API_KEY is you want to use wandb."
else:
    # Always define the name; it is stored as null in the workspace config.
    wandb_api_key = None
|
|
|
|
|
|
|
|
print()
import json
from pprint import pprint


# Merge the collected dataset paths into dataset/config.json, keeping all other keys.
print()
dataset_config_path = op.join(root, "dataset", "config.json")
with open(dataset_config_path, "r", encoding="utf-8") as f:
    dataset_config = json.load(f)
dataset_config.update({
    "dataset_root": dataset_root,
    "imagenet_root": {
        "train": imagenet_root_train,
        "test": imagenet_root_test,
    },
})
# Serialize before opening the file for writing. open(..., "w") truncates at once,
# so a failure during serialization would otherwise leave an empty config behind.
dataset_config_text = json.dumps(dataset_config, indent=4)
print("\033[32mUpdated dataset/config.json as follow:\033[0m")
pprint(dataset_config)
with open(dataset_config_path, "w", encoding="utf-8") as f:
    f.write(dataset_config_text)
|
|
| |
# Merge the wandb settings into workspace/config.json, keeping all other keys.
print()
workspace_config_path = op.join(root, "workspace", "config.json")
with open(workspace_config_path, "r", encoding="utf-8") as f:
    workspace_config = json.load(f)
workspace_config.update({
    "use_wandb": use_wandb,
    # wandb_api_key may be undefined when wandb is disabled, so fall back to None.
    "wandb_api_key": globals().get("wandb_api_key", None),
})
# Serialize first, so that a failure cannot leave the file truncated by open(..., "w").
workspace_config_text = json.dumps(workspace_config, indent=4)
print("\033[32mUpdated workspace/config.json as follow:\033[0m")
pprint(workspace_config)
with open(workspace_config_path, "w", encoding="utf-8") as f:
    f.write(workspace_config_text)
print()