"""Access-tracking wrapper around OmegaConf configuration objects.

``AccessTrackedConfig`` wraps a ``DictConfig``/``ListConfig`` and records which
keys and indices are read or written, so that only the part of a configuration
that was actually used can be exported or saved.

Importing this module also monkey-patches a handful of ``OmegaConf`` helpers
(``to_container``, ``save``, ``to_yaml``, ``is_config``, ``merge``) so that they
accept wrapped configs transparently.
"""

import json
import re
from pathlib import Path
from typing import Any, Optional, Set, Union

from omegaconf import DictConfig, ListConfig, OmegaConf

# Attributes stored directly on wrapper instances. They must never be treated
# as configuration keys, otherwise ``__getattr__`` could recurse while the
# instance is only partially initialised (e.g. during copy/unpickling).
_INTERNAL_ATTRS = frozenset({
    '_cfg',
    '_parent',
    '_key_path',
    '_local_accessed',
    '_children',
    '_original_cfg_snapshot',
})

# One path token: either a list index written as ``[N]`` or a dictionary key
# made of any characters other than ``.``, ``[`` and ``]``.
_PATH_TOKEN_RE = re.compile(r'\[(\d+)\]|([^.\[\]]+)')


def _parse_path(path: str) -> list:
    """Split a tracked path such as ``"a.b[0].c"`` into ``['a', 'b', 0, 'c']``.

    List indices are returned as ``int``; dictionary keys are returned as ``str``.
    """
    return [int(index) if index else key for index, key in _PATH_TOKEN_RE.findall(path)]


def _pad_list(items: list, index: int) -> None:
    """Grow ``items`` with ``None`` placeholders until ``items[index]`` exists."""
    shortfall = index + 1 - len(items)
    if shortfall > 0:
        items.extend([None] * shortfall)


def _descend(container: Any, token: Any, fresh: Any) -> Any:
    """Return ``container[token]``, creating it as ``fresh`` when it is absent."""
    if isinstance(container, list):
        if not isinstance(token, int):
            raise TypeError(f"Cannot use key {token!r} on a list")
        _pad_list(container, token)
        if container[token] is None:
            container[token] = fresh
        return container[token]
    return container.setdefault(token, fresh)


def _store(container: Any, token: Any, value: Any) -> None:
    """Assign ``value`` at ``container[token]``, padding lists as required."""
    if isinstance(container, list):
        if not isinstance(token, int):
            raise TypeError(f"Cannot use key {token!r} on a list")
        _pad_list(container, token)
    container[token] = value


class AccessTrackedConfig:
    """
    Wrapper for OmegaConf to track accessed parameters.
    Only saves configuration items that were actually accessed during execution.

    Dictionary keys are tracked under their own name and list positions as
    ``"[i]"``. Nested configs are returned wrapped (and cached per parent) so
    that accesses are tracked at every level.
    """

    # Snapshot of the most recently created root config. Kept at class level
    # for backward compatibility; every root instance additionally stores its
    # own snapshot, which is the one used when exporting.
    _original_cfg_snapshot: Optional[OmegaConf] = None

    # Wrappers are mutable and define ``__eq__``; keep them explicitly unhashable.
    __hash__ = None

    def __init__(self, cfg: Union[DictConfig, ListConfig],
                 parent: Optional['AccessTrackedConfig'] = None, key_path: str = ""):
        """Wrap ``cfg``.

        Args:
            cfg: The OmegaConf container to wrap. An already wrapped config is
                unwrapped first.
            parent: The wrapper this one was obtained from, ``None`` for a root.
            key_path: Path of this node relative to the root (used by ``repr``).
        """
        if isinstance(cfg, AccessTrackedConfig):
            cfg = cfg.unwrap()
        # ``object.__setattr__`` bypasses our own ``__setattr__``.
        object.__setattr__(self, '_cfg', cfg)
        object.__setattr__(self, '_parent', parent)
        object.__setattr__(self, '_key_path', key_path)
        object.__setattr__(self, '_local_accessed', set())
        object.__setattr__(self, '_children', {})
        if parent is None:
            snapshot = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True))
            # Per-root snapshot: creating another root (copy(), merge_with(), ...)
            # must not replace the original values of this one.
            object.__setattr__(self, '_original_cfg_snapshot', snapshot)
            AccessTrackedConfig._original_cfg_snapshot = snapshot

    # ------------------------------------------------------------------ helpers

    def _is_list_config(self) -> bool:
        """Check if underlying config is a ListConfig"""
        return isinstance(self._cfg, ListConfig)

    def _is_dict_config(self) -> bool:
        """Check if underlying config is a DictConfig"""
        return isinstance(self._cfg, DictConfig)

    @staticmethod
    def _unwrap_value(value: Any) -> Any:
        """Return the underlying OmegaConf object when ``value`` is a wrapper."""
        return value._cfg if isinstance(value, AccessTrackedConfig) else value

    def _is_python_attribute(self, name: str) -> bool:
        """Tell whether ``name`` is a plain Python attribute rather than a config key.

        Internal and dunder names always are. Other underscore-prefixed names
        (e.g. ``_target_``) count as config keys only when they exist in the
        wrapped DictConfig.
        """
        if not name.startswith('_'):
            return False
        if name in _INTERNAL_ATTRS or (name.startswith('__') and name.endswith('__')):
            return True
        try:
            cfg = object.__getattribute__(self, '_cfg')
        except AttributeError:
            # Instance not initialised yet (copy / unpickling in progress).
            return True
        return not (isinstance(cfg, DictConfig) and name in cfg)

    def _index_key(self, index: int) -> str:
        """Return the tracking key ``"[i]"`` for a list position.

        Negative positions of a ListConfig are normalised so that the tracked
        path can be resolved again when exporting.
        """
        if index < 0 and self._is_list_config():
            size = len(self._cfg)
            if index >= -size:
                index += size
        return f"[{index}]"

    def _tracked_key(self, key: Any) -> Any:
        """Return the key under which an access to ``key`` is recorded."""
        return self._index_key(key) if isinstance(key, int) else key

    def _wrap_child(self, cache_key: Any, value: Any, indexed: bool) -> Any:
        """Wrap ``value`` in a cached child tracker when it is a nested config."""
        if not OmegaConf.is_config(value):
            return value
        if cache_key not in self._children:
            if indexed:
                new_path = f"{self._key_path}{cache_key}"
            else:
                new_path = f"{self._key_path}.{cache_key}" if self._key_path else cache_key
            self._children[cache_key] = AccessTrackedConfig(value, parent=self, key_path=new_path)
        return self._children[cache_key]

    def _lookup_key(self, name: Any) -> Any:
        """Track and return the value stored under a dictionary-style key.

        Raises:
            AttributeError: If the key cannot be read (keeps ``hasattr()`` semantics).
        """
        self._local_accessed.add(name)
        try:
            value = self._cfg[name]
        except Exception as exc:
            raise AttributeError(f"Config has no attribute '{name}'") from exc
        return self._wrap_child(name, value, indexed=False)

    def _mark_all_accessed(self) -> None:
        """Record every key (or index) of this node as accessed."""
        if self._is_list_config():
            self._local_accessed.update(f"[{i}]" for i in range(len(self._cfg)))
        else:
            self._local_accessed.update(self._tracked_key(key) for key in self._cfg.keys())

    # ------------------------------------------------------------ item protocol

    def __getattr__(self, name: str) -> Any:
        """Attribute-style read of a config key.

        Only called when normal attribute lookup fails, so internal attributes
        and methods never reach this point once the instance is initialised.

        Raises:
            AttributeError: If ``name`` is not a readable config key.
        """
        if self._is_python_attribute(name):
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
        return self._lookup_key(name)

    def __getitem__(self, key) -> Any:
        """Support both dict-style and list-style access.

        Integer keys are tracked as ``"[i]"``; slices return a list of the
        (tracked) selected elements; any other key is a dictionary key.
        A missing dictionary key raises ``AttributeError``.
        """
        if isinstance(key, slice):
            return [self[i] for i in range(*key.indices(len(self._cfg)))]
        if isinstance(key, int):
            cache_key = self._index_key(key)
            self._local_accessed.add(cache_key)
            return self._wrap_child(cache_key, self._cfg[key], indexed=True)
        return self._lookup_key(key)

    def __setattr__(self, name: str, value: Any):
        """Attribute-style write.

        Python-level attributes (internal names, and underscore names that are
        not existing config keys) are set on the wrapper itself; everything
        else is written to the wrapped config and tracked.
        """
        if self._is_python_attribute(name):
            object.__setattr__(self, name, value)
        else:
            self[name] = value

    def __setitem__(self, key, value: Any):
        """Support both dict-style and list-style setting"""
        cache_key = self._tracked_key(key)
        self._local_accessed.add(cache_key)
        self._cfg[key] = self._unwrap_value(value)
        # The cached child (if any) wraps the node that was just replaced.
        self._children.pop(cache_key, None)

    def __contains__(self, key) -> bool:
        """Support 'in' operator - tracks the key check as an access"""
        self._local_accessed.add(f"[{key}]" if isinstance(key, int) else key)
        return key in self._cfg

    def __len__(self) -> int:
        """Return number of keys/items"""
        return len(self._cfg)

    def __iter__(self):
        """Iterate like the wrapped config, marking every key/index as accessed.

        The yielded items come straight from the wrapped config (keys for a
        DictConfig, values for a ListConfig) and are not wrapped.
        """
        self._mark_all_accessed()
        return iter(self._cfg)

    def __repr__(self) -> str:
        """String representation"""
        location = self._key_path or 'root'
        if self._is_list_config():
            return f"AccessTrackedConfig({location}, list_len={len(self._cfg)})"
        return f"AccessTrackedConfig({location}, keys={list(self._cfg.keys())})"

    def __str__(self) -> str:
        """Return the wrapped config rendered as YAML."""
        return OmegaConf.to_yaml(self._cfg)

    def __bool__(self) -> bool:
        """Boolean evaluation - True if config has any keys/items"""
        return len(self._cfg) > 0

    def __eq__(self, other) -> bool:
        """Compare with another wrapper, an OmegaConf config, or a plain dict/list."""
        if isinstance(other, AccessTrackedConfig):
            return self._cfg == other._cfg
        if OmegaConf.is_config(other):
            return self._cfg == other
        if isinstance(other, (dict, list)):
            return OmegaConf.to_container(self._cfg, resolve=True) == other
        return False

    # --------------------------------------------------------- mapping helpers

    def keys(self):
        """Return config keys (required for dict unpacking)

        Tracks all keys as accessed. Only works for DictConfig.

        Raises:
            TypeError: If the wrapped config is a ListConfig.
        """
        if self._is_list_config():
            raise TypeError("ListConfig does not support keys()")
        self._mark_all_accessed()
        return self._cfg.keys()

    def values(self):
        """Yield config values (tracks all keys as accessed)"""
        if self._is_list_config():
            for i in range(len(self._cfg)):
                yield self[i]
        else:
            for key in self._cfg.keys():
                yield self.get(key)

    def items(self):
        """Yield ``(key, value)`` pairs (tracks all keys as accessed).

        Raises:
            TypeError: If the wrapped config is a ListConfig (raised when the
                generator is first advanced).
        """
        if self._is_list_config():
            raise TypeError("ListConfig does not support items()")
        for key in self._cfg.keys():
            yield key, self.get(key)

    def get(self, key: str, default: Any = None) -> Any:
        """Get value with default fallback.

        The access is tracked even when the key is missing. Integer keys are
        tracked as list positions (``"[i]"``).
        """
        indexed = isinstance(key, int)
        tracked = self._index_key(key) if indexed else key
        self._local_accessed.add(tracked)
        value = self._cfg.get(key, default)
        if value is default:
            return value
        return self._wrap_child(tracked, value, indexed=indexed)

    def update(self, other: Any = None, **kwargs):
        """Update config with values from another dict/config.

        Args:
            other: A wrapper, an OmegaConf config, a mapping, or an iterable of
                key/value pairs.
            **kwargs: Additional key/value pairs to set.

        Raises:
            TypeError: If this is a ListConfig or ``other`` cannot be converted.
        """
        if self._is_list_config():
            raise TypeError("ListConfig does not support update()")

        if other is not None:
            # Normalise every supported input to a plain dict first.
            if isinstance(other, AccessTrackedConfig):
                other = OmegaConf.to_container(other._cfg, resolve=True)
            elif OmegaConf.is_config(other):
                other = OmegaConf.to_container(other, resolve=True)
            elif hasattr(other, 'items'):
                other = dict(other.items())
            elif hasattr(other, '__iter__'):
                other = dict(other)
            else:
                raise TypeError(f"Cannot update from {type(other)}")

            for key, value in other.items():
                self[key] = value

        for key, value in kwargs.items():
            self[key] = value

    def pop(self, key, *args):
        """Remove and return a value.

        Args:
            key: Dictionary key or list position.
            *args: Optional default, forwarded to the wrapped config's ``pop``.
        """
        indexed = isinstance(key, int)
        cache_key = self._index_key(key) if indexed else key
        self._local_accessed.add(cache_key)

        popped = self._cfg.pop(key, args[0]) if args else self._cfg.pop(key)

        if indexed and self._is_list_config():
            # Removing a list element shifts every following position, so all
            # cached children would now be bound to the wrong index.
            self._children.clear()
        else:
            self._children.pop(cache_key, None)
        return popped

    def append(self, value: Any):
        """Append value to list (only for ListConfig)"""
        if not self._is_list_config():
            raise TypeError("append() only supported for ListConfig")
        self._cfg.append(self._unwrap_value(value))
        self._local_accessed.add(f"[{len(self._cfg) - 1}]")

    def extend(self, values):
        """Extend list with values (only for ListConfig)"""
        if not self._is_list_config():
            raise TypeError("extend() only supported for ListConfig")
        start_idx = len(self._cfg)
        self._cfg.extend(values)
        self._local_accessed.update(f"[{i}]" for i in range(start_idx, len(self._cfg)))

    def setdefault(self, key: str, default: Any = None) -> Any:
        """Set default value if key doesn't exist, then return the stored value."""
        if self._is_list_config():
            raise TypeError("ListConfig does not support setdefault()")
        self._local_accessed.add(self._tracked_key(key))
        if key not in self._cfg:
            self._cfg[key] = self._unwrap_value(default)
        return self.get(key)

    # ------------------------------------------------------ copies / conversion

    def copy(self) -> 'AccessTrackedConfig':
        """Return an independent copy (does not copy access tracking state).

        The copy is rebuilt from the resolved plain container, so it shares no
        nodes with this config and interpolations are replaced by their values.
        """
        new_cfg = OmegaConf.create(OmegaConf.to_container(self._cfg, resolve=True))
        return AccessTrackedConfig(new_cfg)

    def deepcopy(self) -> 'AccessTrackedConfig':
        """Return a deep copy (does not copy access tracking state)"""
        return self.copy()

    def merge_with(self, *others) -> 'AccessTrackedConfig':
        """Merge with other configs and return new tracked config"""
        configs = [self._cfg]
        for other in others:
            if isinstance(other, AccessTrackedConfig):
                configs.append(other._cfg)
            elif OmegaConf.is_config(other):
                configs.append(other)
            else:
                configs.append(OmegaConf.create(other))
        return AccessTrackedConfig(OmegaConf.merge(*configs))

    def to_dict(self, resolve: bool = True) -> dict:
        """Convert to plain dictionary or list"""
        return OmegaConf.to_container(self._cfg, resolve=resolve)

    def to_yaml(self, resolve: bool = False) -> str:
        """Convert to YAML string"""
        return OmegaConf.to_yaml(self._cfg, resolve=resolve)

    def unwrap(self) -> Union[DictConfig, ListConfig]:
        """Get the underlying OmegaConf object"""
        return self._cfg

    def get_root(self) -> 'AccessTrackedConfig':
        """Get root config object"""
        current = self
        while current._parent is not None:
            current = current._parent
        return current

    # ------------------------------------------------------------------ export

    def _collect_all_paths(self, node: Optional['AccessTrackedConfig'] = None,
                           prefix: str = "") -> Set[str]:
        """Recursively collect all accessed paths, starting at the root by default."""
        if node is None:
            node = self.get_root()

        paths = set()
        for key in node._local_accessed:
            label = key if isinstance(key, str) else str(key)
            if not prefix:
                current_path = label
            elif label.startswith('['):
                current_path = f"{prefix}{label}"
            else:
                current_path = f"{prefix}.{label}"
            paths.add(current_path)

            # Compare with None: an empty child config is falsy.
            child = node._children.get(key)
            if child is not None:
                paths.update(self._collect_all_paths(child, current_path))
        return paths

    def _filter_leaf_paths(self, paths: Set[str]) -> Set[str]:
        """Filter to only leaf paths (no sub-paths).

        A path is dropped when another path continues it with ``.`` or ``[``.
        """
        ancestors = {
            path[:pos]
            for path in paths
            for pos, char in enumerate(path)
            if char in '.['
        }
        return set(paths) - ancestors

    @staticmethod
    def _get_nested_value(cfg, path: str) -> Any:
        """Get nested value through dot-separated path with bracket notation support.

        Nested configs are returned as resolved plain containers.

        Raises:
            KeyError: If ``path`` contains no usable segment.
        """
        tokens = _parse_path(path)
        if not tokens:
            raise KeyError(path)
        value = cfg
        for token in tokens:
            value = value[token]
        return OmegaConf.to_container(value, resolve=True) if OmegaConf.is_config(value) else value

    @staticmethod
    def _set_nested_value(d: dict, path: str, value: Any):
        """Set nested value through dot-separated path.

        Intermediate containers are created on demand: a list when the next
        segment is an index, a dict otherwise. Lists are padded with ``None``
        for positions that were not accessed.

        Raises:
            KeyError: If ``path`` contains no usable segment.
        """
        tokens = _parse_path(path)
        if not tokens:
            raise KeyError(path)
        container = d
        for token, following in zip(tokens, tokens[1:]):
            fresh = [] if isinstance(following, int) else {}
            container = _descend(container, token, fresh)
        _store(container, tokens[-1], value)

    def export_accessed_config(self, use_original_values: bool = True) -> dict:
        """Export accessed configuration as dictionary (only leaf values).

        Args:
            use_original_values: Read values from the snapshot taken when the
                root wrapper was created, falling back to the live config for
                paths the snapshot does not contain. When ``False``, only the
                live config is read.

        Returns:
            A plain dict (or a list when the root is a ListConfig) holding the
            accessed leaf values. Paths that cannot be resolved are skipped.
        """
        root = self.get_root()
        leaf_paths = self._filter_leaf_paths(self._collect_all_paths())

        live_cfg = root._cfg
        sources = [root._original_cfg_snapshot, live_cfg] if use_original_values else [live_cfg]

        result = [] if isinstance(live_cfg, ListConfig) else {}
        for path in sorted(leaf_paths):
            for source in sources:
                try:
                    value = self._get_nested_value(source, path)
                    self._set_nested_value(result, path, value)
                except Exception:
                    # Best effort: the path may not exist in this source
                    # (e.g. a key that was probed but never defined).
                    continue
                break
        return result

    def save_accessed_config(self, filepath: Path, use_original_values: bool = True):
        """Save accessed configuration to file.

        Args:
            filepath: Destination; the suffix selects the format
                (``.json``, ``.yaml`` or ``.yml``).
            use_original_values: Forwarded to ``export_accessed_config``.

        Raises:
            ValueError: If the suffix is not supported. Nothing is written in
                that case.
        """
        filepath = Path(filepath)
        suffix = filepath.suffix
        # Validate before touching the filesystem so that a bad suffix does
        # not leave an empty file behind.
        if suffix not in ('.json', '.yaml', '.yml'):
            raise ValueError(f"Unsupported file format: {filepath.suffix}")

        accessed_config = self.export_accessed_config(use_original_values=use_original_values)
        filepath.parent.mkdir(parents=True, exist_ok=True)
        with open(filepath, 'w', encoding='utf-8') as f:
            if suffix == '.json':
                json.dump(accessed_config, f, indent=2)
            else:
                OmegaConf.save(OmegaConf.create(accessed_config), f)

    def get_access_summary(self) -> dict:
        """Get summary of accessed configuration"""
        all_paths = self._collect_all_paths()
        leaf_paths = self._filter_leaf_paths(all_paths)
        return {
            "total_accessed_keys": len(all_paths),
            "leaf_accessed_keys": len(leaf_paths),
            "leaf_accessed_paths": sorted(leaf_paths),
            # key=str keeps the sort well-defined if a non-string key was tracked.
            "top_level_keys": sorted(self.get_root()._local_accessed, key=str),
        }

    def print_access_summary(self):
        """Print a formatted summary of accessed configuration"""
        summary = self.get_access_summary()
        rule = '=' * 60
        print(f"\n{rule}")
        print("Configuration Access Summary")
        print(rule)
        print(f"Total accessed keys: {summary['total_accessed_keys']}")
        print(f"Leaf accessed keys: {summary['leaf_accessed_keys']}")
        print(f"\nTop-level keys accessed: {summary['top_level_keys']}")
        print("\nLeaf paths accessed:")
        for path in summary['leaf_accessed_paths']:
            print(f"  - {path}")
        print(f"{rule}\n")


def wrap_config(cfg: OmegaConf) -> AccessTrackedConfig:
    """Wrap OmegaConf configuration to enable access tracking"""
    return AccessTrackedConfig(cfg)


def unwrap_config(cfg) -> OmegaConf:
    """Unwrap AccessTrackedConfig to get underlying OmegaConf object"""
    return cfg.unwrap() if isinstance(cfg, AccessTrackedConfig) else cfg


# ========== Monkey Patch OmegaConf for Compatibility ==========

_original_to_container = OmegaConf.to_container
_original_save = OmegaConf.save
_original_to_yaml = OmegaConf.to_yaml
_original_is_config = OmegaConf.is_config
_original_merge = OmegaConf.merge


def _patched_to_container(cfg, resolve=True, enum_to_str=False, structured_config_mode=None,
                          **kwargs):
    """Patched OmegaConf.to_container that handles AccessTrackedConfig.

    NOTE(review): ``resolve`` defaults to ``True`` here; confirm against the
    installed OmegaConf whether that differs from the upstream default. It is
    kept because existing callers may rely on it. Extra keyword arguments are
    forwarded unchanged.
    """
    cfg = unwrap_config(cfg)
    options = dict(kwargs, resolve=resolve, enum_to_str=enum_to_str)
    if structured_config_mode is not None:
        options['structured_config_mode'] = structured_config_mode
    try:
        return _original_to_container(cfg, **options)
    except TypeError:
        # Fallback for an OmegaConf whose to_container rejects these keywords.
        return _original_to_container(cfg, resolve=resolve)


def _patched_save(config, f, resolve=False):
    """Patched OmegaConf.save that handles AccessTrackedConfig"""
    return _original_save(unwrap_config(config), f, resolve=resolve)


def _patched_to_yaml(cfg, resolve=False, sort_keys=False):
    """Patched OmegaConf.to_yaml that handles AccessTrackedConfig"""
    cfg = unwrap_config(cfg)
    try:
        return _original_to_yaml(cfg, resolve=resolve, sort_keys=sort_keys)
    except TypeError:
        # Fallback for an OmegaConf whose to_yaml rejects ``sort_keys``.
        return _original_to_yaml(cfg, resolve=resolve)


def _patched_is_config(obj):
    """Patched OmegaConf.is_config that handles AccessTrackedConfig"""
    return True if isinstance(obj, AccessTrackedConfig) else _original_is_config(obj)


def _patched_merge(*configs):
    """Patched OmegaConf.merge that handles AccessTrackedConfig"""
    return _original_merge(*(unwrap_config(cfg) for cfg in configs))


# Apply patches
OmegaConf.to_container = _patched_to_container
OmegaConf.save = _patched_save
OmegaConf.to_yaml = _patched_to_yaml
OmegaConf.is_config = _patched_is_config
OmegaConf.merge = _patched_merge