""" 统一的 YAML 配置加载。 优先级: 1. WEB2API_CONFIG_PATH 指定的路径 2. 项目根目录下的 config.local.yaml 3. 项目根目录下的 config.yaml 同时支持通过环境变量覆盖单个配置项: - 通用规则:WEB2API_
_ - 额外兼容:server.host -> HOST,server.port -> PORT """ from __future__ import annotations import os from pathlib import Path from typing import Any import yaml _PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent _CONFIG_ENV_KEY = "WEB2API_CONFIG_PATH" _LOCAL_CONFIG_NAME = "config.local.yaml" _DEFAULT_CONFIG_NAME = "config.yaml" _ENV_MISSING = object() _ENV_OVERRIDE_ALIASES: dict[tuple[str, str], tuple[str, ...]] = { ("server", "host"): ("HOST",), ("server", "port"): ("PORT",), } _DATABASE_URL_ENV_NAMES = ("WEB2API_DATABASE_URL", "DATABASE_URL") _BOOL_TRUE_VALUES = {"1", "true", "yes", "on"} _BOOL_FALSE_VALUES = {"0", "false", "no", "off"} def _resolve_config_path() -> Path: configured = os.environ.get(_CONFIG_ENV_KEY, "").strip() if configured: return Path(configured).expanduser() local_config = _PROJECT_ROOT / _LOCAL_CONFIG_NAME if local_config.exists(): return local_config return _PROJECT_ROOT / _DEFAULT_CONFIG_NAME _CONFIG_PATH = _resolve_config_path() _config_cache: dict[str, Any] | None = None def _env_override_names(section: str, key: str) -> tuple[str, ...]: generic = f"WEB2API_{section}_{key}".upper().replace("-", "_") aliases = _ENV_OVERRIDE_ALIASES.get((section, key), ()) ordered = [generic] ordered.extend(alias for alias in aliases if alias != generic) return tuple(ordered) def _get_env_override(section: str, key: str) -> Any: for env_name in _env_override_names(section, key): if env_name in os.environ: return os.environ[env_name] return _ENV_MISSING def has_env_override(section: str, key: str) -> bool: return _get_env_override(section, key) is not _ENV_MISSING def get_config_path() -> Path: return _CONFIG_PATH def reset_cache() -> None: global _config_cache _config_cache = None def load_config() -> dict[str, Any]: """按优先级加载配置文件,不存在时返回空 dict。""" global _config_cache if _config_cache is not None: return _config_cache if not _CONFIG_PATH.exists(): _config_cache = {} return {} try: with _CONFIG_PATH.open("r", encoding="utf-8") as f: data = yaml.safe_load(f) or {} if not isinstance(data, dict): _config_cache = {} else: _config_cache = dict(data) except Exception: _config_cache = {} return _config_cache def get(section: str, key: str, default: Any = None) -> Any: """从配置中读取 section.key,环境变量优先,其次 YAML,最后返回 default。""" env_override = _get_env_override(section, key) if env_override is not _ENV_MISSING: return env_override cfg = load_config().get(section) or {} if not isinstance(cfg, dict): return default val = cfg.get(key) return val if val is not None else default def coerce_bool(value: Any, default: bool = False) -> bool: if isinstance(value, bool): return value if isinstance(value, (int, float)): return bool(value) if isinstance(value, str): normalized = value.strip().lower() if normalized in _BOOL_TRUE_VALUES: return True if normalized in _BOOL_FALSE_VALUES: return False return bool(default) def get_bool(section: str, key: str, default: bool = False) -> bool: """从配置读取布尔值,兼容 true/false、1/0、yes/no、on/off。""" return coerce_bool(get(section, key, default), default) def get_server_host(default: str = "127.0.0.1") -> str: return str(get("server", "host") or default).strip() or default def get_server_port(default: int = 8001) -> int: try: return int(str(get("server", "port") or default).strip()) except Exception: return default def get_database_url(default: str = "") -> str: for env_name in _DATABASE_URL_ENV_NAMES: value = os.environ.get(env_name, "").strip() if value: return value return default