| """Configurations for the project.""" | |
| from pathlib import Path | |
| from typing import Any, Dict | |
| import torch | |
# Use CUDA when torch reports it as available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directory two levels above this file, as an absolute path.
repo_path = Path(__file__).parent.parent.absolute()
# "models" directory under repo_path; stored below as "output_dir".
output_path = repo_path / "models"

# Default configuration values for the project.
# NOTE(review): the meaning of the individual hyperparameters (Ng, D, the
# lambda*/gamma* constants, "snapshot") is not visible from this module --
# confirm against the code that consumes them before documenting further.
config_dict: Dict[str, Any] = {
    "Ng": 32,
    "D": 256,
    "condition_dim": 100,
    "noise_dim": 100,
    # Presumably per-component learning rates -- verify against the trainer.
    "lr_config": {
        "disc_lr": 2e-4,
        "gen_lr": 2e-4,
        "img_encoder_lr": 3e-3,
        "text_encoder_lr": 3e-3,
    },
    "batch_size": 64,
    "device": device,
    "epochs": 200,
    "output_dir": output_path,
    "snapshot": 5,
    # Scalar constants grouped under a single key.
    "const_dict": {
        "smooth_val_gen": 0.999,
        "lambda1": 1,
        "lambda2": 1,
        "lambda3": 1,
        "lambda4": 1,
        "gamma1": 4,
        "gamma2": 5,
        "gamma3": 10,
    },
}
def update_config(cfg_dict: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
    """
    Apply keyword overrides to a configuration dictionary.

    Every keyword argument is stored in ``cfg_dict`` under its own name,
    replacing an existing entry with that key or adding a new one. Values
    are replaced wholesale: a nested dictionary passed as an override is
    stored as-is, not merged with the previous value. The dictionary is
    modified in place.

    Args:
        cfg_dict: Configuration dictionary to modify.
        **kwargs: Key/value pairs to write into ``cfg_dict``.

    Returns:
        The same ``cfg_dict`` object that was passed in, after the update.
    """
    cfg_dict.update(kwargs)
    return cfg_dict