Spaces:
Sleeping
Sleeping
| # SOURCE: https://github.com/Sulam-Group/IBYDMT/blob/main/ibydmt/utils/config.py | |
| import os | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| from itertools import product | |
| from typing import Any, Iterable, Mapping, Optional, Union | |
| import torch | |
| from ml_collections import ConfigDict | |
| from numpy import ndarray | |
# Type alias for array-like values: either a NumPy ``ndarray`` or a ``torch.Tensor``.
Array = Union[ndarray, torch.Tensor]
class TestType(Enum):
    """String-valued enumeration of the available test types."""

    GLOBAL = "global"
    GLOBAL_COND = "global_cond"
    LOCAL_COND = "local_cond"
class ConceptType(Enum):
    """String-valued enumeration of the available concept types."""

    DATASET = "dataset"
    CLASS = "class"
    IMAGE = "image"
class Constants:
    """Process-wide constants, computed once when this class body executes."""

    # Parent of the directory that contains this file; ``realpath`` resolves
    # symlinks before the two ``dirname`` steps are applied.
    WORKDIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    # CUDA device when torch reports CUDA as available, otherwise the CPU.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class DataConfig(ConfigDict):
    """Settings stored under the ``data`` section of a configuration.

    Each field is read from the same-named key of ``config_dict`` and is
    ``None`` when that key is absent or when no mapping is supplied.
    """

    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
        super().__init__()
        # A missing mapping behaves like an empty one: every lookup gives None.
        source = {} if config_dict is None else config_dict

        self.dataset: Optional[str] = source.get("dataset")
        self.backbone: Optional[str] = source.get("backbone")
        self.bottleneck: Optional[str] = source.get("bottleneck")
        self.classifier: Optional[str] = source.get("classifier")
        self.sampler: Optional[str] = source.get("sampler")
        self.num_concepts: Optional[int] = source.get("num_concepts")
class SpliceConfig(ConfigDict):
    """Settings stored under the ``splice`` section of a configuration.

    Each field is read from the same-named key of ``config_dict`` and is
    ``None`` when that key is absent or when no mapping is supplied.
    """

    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
        super().__init__()
        # A missing mapping behaves like an empty one: every lookup gives None.
        source = {} if config_dict is None else config_dict

        self.vocab: Optional[str] = source.get("vocab")
        self.vocab_size: Optional[int] = source.get("vocab_size")
        self.l1_penalty: Optional[float] = source.get("l1_penalty")
class PCBMConfig(ConfigDict):
    """Settings stored under the ``pcbm`` section of a configuration.

    Each field is read from the same-named key of ``config_dict`` and is
    ``None`` when that key is absent or when no mapping is supplied.
    """

    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
        super().__init__()
        # A missing mapping behaves like an empty one: every lookup gives None.
        source = {} if config_dict is None else config_dict

        self.alpha: Optional[float] = source.get("alpha")
        self.l1_ratio: Optional[float] = source.get("l1_ratio")
class cKDEConfig(ConfigDict):
    """Settings stored under the ``ckde`` section of a configuration.

    Each field is read from the same-named key of ``config_dict`` and is
    ``None`` when that key is absent or when no mapping is supplied.
    """

    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
        super().__init__()
        # A missing mapping behaves like an empty one: every lookup gives None.
        source = {} if config_dict is None else config_dict

        self.metric: Optional[str] = source.get("metric")
        self.scale_method: Optional[str] = source.get("scale_method")
        self.scale: Optional[float] = source.get("scale")
class TestingConfig(ConfigDict):
    """Settings stored under the ``testing`` section of a configuration.

    Each field is read from the same-named key of ``config_dict`` and is
    ``None`` when that key is absent or when no mapping is supplied.
    """

    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
        super().__init__()
        # A missing mapping behaves like an empty one: every lookup gives None.
        source = {} if config_dict is None else config_dict

        self.significance_level: Optional[float] = source.get("significance_level")
        self.wealth: Optional[str] = source.get("wealth")
        self.bet: Optional[str] = source.get("bet")
        self.kernel: Optional[str] = source.get("kernel")
        self.kernel_scale_method: Optional[str] = source.get("kernel_scale_method")
        self.kernel_scale: Optional[float] = source.get("kernel_scale")
        self.tau_max: Optional[int] = source.get("tau_max")
        self.images_per_class: Optional[int] = source.get("images_per_class")
        self.r: Optional[int] = source.get("r")
        self.cardinality: Optional[Iterable[int]] = source.get("cardinality")
class Config(ConfigDict):
    """Top-level configuration made of a name plus one object per section.

    The sections ``data``, ``splice``, ``pcbm``, ``ckde`` and ``testing`` are
    each built from the same-named sub-mapping of ``config_dict``. A missing
    sub-mapping yields a section whose fields are all ``None``.
    """

    def __init__(self, config_dict: Optional[Mapping[str, Any]] = None):
        super().__init__()
        if config_dict is None:
            config_dict = {}

        self.name: Optional[str] = config_dict.get("name", None)
        self.data = DataConfig(config_dict.get("data", None))
        self.splice = SpliceConfig(config_dict.get("splice", None))
        self.pcbm = PCBMConfig(config_dict.get("pcbm", None))
        self.ckde = cKDEConfig(config_dict.get("ckde", None))
        self.testing = TestingConfig(config_dict.get("testing", None))

    def backbone_name(self):
        """Return ``data.backbone`` in a normalised form.

        The value is stripped of surrounding whitespace, lower-cased, and has
        every ``/`` and ``:`` replaced by ``_``.

        NOTE: ``data.backbone`` is ``None`` unless it was provided, in which
        case the ``.strip()`` call below fails with ``AttributeError``.
        """
        backbone = self.data.backbone.strip().lower()
        return backbone.replace("/", "_").replace(":", "_")

    def sweep(self, keys: Iterable[str]) -> "list[Config]":
        """Expand this config into one ``Config`` per combination of swept values.

        Args:
            keys: Dotted paths into the config (e.g. ``"data.backbone"``). The
                value found at each path is swept over if it is a ``list``;
                any other value is treated as a single choice.

        Returns:
            A list with one ``Config`` for every element of the Cartesian
            product of the swept values, in ``itertools.product`` order. With
            no keys, the list holds a single copy of this config.

        Raises:
            KeyError: If a path does not exist in the config.
        """
        # Materialise once: the keys are walked again for every combination,
        # so a one-shot iterator (e.g. a generator) would be exhausted after
        # the first pass and the overrides silently skipped.
        keys = list(keys)

        def _get(tree, dotted_key):
            # Follow each path component down the nested dicts.
            node = tree
            for part in dotted_key.split("."):
                node = node[part]
            return node

        def _set(tree, dotted_key, value):
            # Walk to the parent of the last component, then assign the leaf.
            *parents, leaf = dotted_key.split(".")
            node = tree
            for part in parents:
                node = node[part]
            node[leaf] = value

        def _as_list(value):
            return value if isinstance(value, list) else [value]

        choices = [_as_list(_get(self.to_dict(), key)) for key in keys]

        swept = []
        for combination in product(*choices):
            # Start from a fresh nested dict so overrides written for one
            # combination are never shared with another through a common
            # section dict.
            overrides = self.to_dict()
            for key, value in zip(keys, combination):
                _set(overrides, key, value)
            swept.append(Config(overrides))
        return swept
def register_config(name: str):
    """Return a class decorator that registers a config class under ``name``.

    Args:
        name: Key under which the decorated class is stored in ``configs``.

    Returns:
        A decorator that stores the class it receives and returns that same
        class.

    Raises:
        ValueError: When the decorator is applied and ``name`` is already
            registered.
    """

    def register(cls: "type[Config]") -> "type[Config]":
        if name in configs:
            raise ValueError(f"Config {name} is already registered")
        configs[name] = cls
        # Hand the class back: a decorator's return value replaces the
        # decorated name, so returning nothing would rebind it to None.
        return cls

    return register


def get_config(name: str) -> "Config":
    """Instantiate, with no arguments, the config class registered as ``name``.

    Raises:
        KeyError: If no class is registered under ``name``.
    """
    try:
        config_cls = configs[name]
    except KeyError:
        available = ", ".join(sorted(configs)) or "<none>"
        raise KeyError(
            f"Config {name} is not registered (available: {available})"
        ) from None
    return config_cls()


# Registry mapping a config name to the class registered for it. The annotation
# is quoted so it is not evaluated at import time.
configs: "dict[str, type[Config]]" = {}