DylanJHJ/APRIL / src /autollmrerank /config_manager.py
DylanJHJ's picture
download
raw
3.38 kB
import argparse
import importlib.resources as pkg_resources
import os
from pprint import pprint
from typing import Any, Dict, List, Optional, Union

import yaml
# Location of `default.yaml` inside the `autollmrerank.configs` package resources;
# ConfigManager loads this file when it is constructed without an explicit path.
DEFAULT_CONFIG_PATH = pkg_resources.files("autollmrerank.configs").joinpath("default.yaml")
class ConfigManager:
    """Assemble a configuration dict from YAML, command-line and keyword sources.

    Sources are layered in this order, later ones winning:

    1. the YAML file at ``path`` (``DEFAULT_CONFIG_PATH`` when ``path`` is falsy);
    2. the YAML file named by ``--config`` on the command line (its top-level
       keys replace existing ones wholesale);
    3. ``--dotted.key=value`` command-line arguments (merged recursively);
    4. keyword arguments given to the constructor (merged recursively).
    """

    def __init__(self, path=None, **kwargs):
        """Load the base config, then apply CLI and keyword overrides.

        Args:
            path: Path to a YAML config file; falls back to
                ``DEFAULT_CONFIG_PATH`` when falsy.
            **kwargs: Overrides applied last, on top of any CLI overrides.
        """
        self.config = self.load_yaml(path or DEFAULT_CONFIG_PATH)
        self.parse_and_override()
        self.apply_overrides(self.config, kwargs)

    def load_yaml(self, path: str) -> Dict[str, Any]:
        """Read a YAML file and return its top-level mapping.

        Args:
            path: Location of the YAML file (made absolute before use).

        Returns:
            The parsed mapping; an empty dict when the file has no content.

        Raises:
            FileNotFoundError: If ``path`` does not point to a regular file.
            ValueError: If the document's top level is not a mapping.
        """
        path = os.path.abspath(path)
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Config file not found: {path}")
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        # safe_load yields None for an empty document; every later merge step
        # needs a dict, so normalise that case here.
        if data is None:
            return {}
        if not isinstance(data, dict):
            raise ValueError(
                f"Config file must contain a YAML mapping at the top level: {path}"
            )
        return data

    def parse_and_override(self, argv: Optional[List[str]] = None):
        """Apply command-line overrides to ``self.config`` in place.

        Args:
            argv: Argument list to parse; ``None`` (the default) makes argparse
                read ``sys.argv[1:]``.
        """
        # allow_abbrev=False: otherwise an override such as ``--con=5`` would be
        # treated as an abbreviation of ``--config`` and swallowed.
        parser = argparse.ArgumentParser(allow_abbrev=False)
        # NOTE: only --config is declared explicitly; all other settings are
        # picked up generically from the unrecognised --key=value arguments.
        parser.add_argument("--config", type=str, default=None, help="Path to YAML config file")
        # Everything argparse does not recognise lands in `unknown`.
        args, unknown = parser.parse_known_args(argv)
        if args.config:
            # dict.update replaces top-level keys wholesale (no recursive merge).
            # NOTE(review): CLI/kwargs overrides merge recursively -- confirm the
            # shallow merge here is intended.
            self.config.update(self.load_yaml(args.config))
        overrides = self.parse_unknown_args(unknown)
        self.apply_overrides(self.config, overrides)

    def parse_unknown_args(self, args: list) -> Dict[str, Any]:
        """Turn ``--dotted.key=value`` strings into a nested override dict.

        Arguments that do not start with ``--`` or contain no ``=`` are ignored.
        Values are converted with ``_infer_type``.
        """
        result = {}
        for arg in args:
            if not arg.startswith("--") or "=" not in arg:
                continue
            # Split on the first "=" only so values may themselves contain "=".
            key, value = arg[2:].split("=", 1)
            self.insert_nested_key(result, key, self._infer_type(value))
        return result

    def insert_nested_key(self, d: Dict[str, Any], key: str, value: Any):
        """Insert value into nested dictionary using dot notation.

        Intermediate levels that are missing, or that hold a non-dict value,
        are replaced by fresh dicts.
        """
        keys = key.split(".")
        current = d
        for k in keys[:-1]:
            if k not in current or not isinstance(current[k], dict):
                current[k] = {}
            current = current[k]
        current[keys[-1]] = value

    def apply_overrides(self, base: Dict[str, Any], overrides: Dict[str, Any]):
        """Recursively apply overrides to the base config.

        ``base`` is modified in place. A dict override is merged key by key
        when ``base`` already holds a dict under that key; in every other case
        the override value replaces the existing entry.
        """
        for key, value in overrides.items():
            if isinstance(value, dict) and isinstance(base.get(key), dict):
                self.apply_overrides(base[key], value)
            else:
                base[key] = value

    def _infer_type(self, value: str) -> Any:
        """Convert a command-line string to bool, None, int or float if possible.

        ``true``/``false`` and ``none``/``null`` are matched case-insensitively
        (a message is printed for these conversions); otherwise ``int`` is
        tried, then ``float``, and the original string is returned when neither
        conversion succeeds.
        """
        lowered = value.lower()
        if lowered in {"true", "false"}:
            print(f"Converting '{value}' to boolean")
            return lowered == "true"
        if lowered in {"none", "null"}:
            print(f"Converting '{value}' to None")
            return None
        try:
            return int(value)
        except ValueError:
            pass
        try:
            return float(value)
        except ValueError:
            return value

    def get_config(self, return_dict=False) -> Union[Dict[str, Any], argparse.Namespace]:
        """Return the assembled configuration.

        Args:
            return_dict: When true, return the internal dict itself (not a
                copy). Otherwise return an ``argparse.Namespace`` tree.

        Returns:
            The config dict, or a namespace in which every nested dict has
            been converted to a nested namespace. Other values (including
            lists) are attached unchanged.
        """
        if return_dict:
            return self.config

        def dict_to_namespace(d):
            namespace = argparse.Namespace()
            for key, value in d.items():
                if isinstance(value, dict):
                    setattr(namespace, key, dict_to_namespace(value))
                else:
                    setattr(namespace, key, value)
            return namespace

        return dict_to_namespace(self.config)

Xet Storage Details

Size:
3.38 kB
·
Xet hash:
1dcf007f098bb4caac927f9c8156c23d90c17a7c2a17b598ae8d0076a352a60f

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.