Spaces:
Paused
Paused
"""
Unified YAML configuration loading.

Config file lookup priority:
1. The path given by the WEB2API_CONFIG_PATH environment variable
2. config.local.yaml in the project root directory
3. config.yaml in the project root directory

The config file path is resolved once, when this module is imported. The
parsed file contents are cached after the first load; reset_cache() clears
that cache.

Individual settings can also be overridden via environment variables, which
take precedence over values from the YAML file:
- General rule: WEB2API_<SECTION>_<KEY> (upper-cased, "-" replaced by "_")
- Extra compatibility aliases: server.host -> HOST, server.port -> PORT
"""
from __future__ import annotations

import logging
import os
from pathlib import Path
from typing import Any

import yaml
# Sentinel meaning "no environment variable override is set"; distinct from
# every real value, including an empty string.
_ENV_MISSING = object()

# Three directory levels above this file's resolved location.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent

# Config file discovery: explicit env var first, then these names under
# _PROJECT_ROOT.
_CONFIG_ENV_KEY = "WEB2API_CONFIG_PATH"
_LOCAL_CONFIG_NAME = "config.local.yaml"
_DEFAULT_CONFIG_NAME = "config.yaml"

# Extra env var names accepted for a (section, key) pair, checked after the
# generic WEB2API_<SECTION>_<KEY> name.
# NOTE(review): HOST/PORT are generic names that other tooling may also set;
# confirm this aliasing is intended in every deployment environment.
_ENV_OVERRIDE_ALIASES: dict[tuple[str, str], tuple[str, ...]] = {
    ("server", "host"): ("HOST",),
    ("server", "port"): ("PORT",),
}

# Env vars consulted, in this order, by get_database_url().
_DATABASE_URL_ENV_NAMES = ("WEB2API_DATABASE_URL", "DATABASE_URL")

# Accepted spellings for coerce_bool(), compared after strip() + lower().
_BOOL_TRUE_VALUES = {"1", "true", "yes", "on"}
_BOOL_FALSE_VALUES = {"0", "false", "no", "off"}
def _resolve_config_path() -> Path:
    """Choose which YAML config file to read.

    Returns, in priority order:
    1. the path in ``WEB2API_CONFIG_PATH`` (whitespace-stripped, ``~``
       expanded) when that variable is non-blank;
    2. ``config.local.yaml`` under the project root, if it exists;
    3. ``config.yaml`` under the project root.

    Cases 1 and 3 do not check that the file exists.
    """
    explicit = os.environ.get(_CONFIG_ENV_KEY, "").strip()
    if explicit:
        return Path(explicit).expanduser()

    local_path = _PROJECT_ROOT / _LOCAL_CONFIG_NAME
    return local_path if local_path.exists() else _PROJECT_ROOT / _DEFAULT_CONFIG_NAME


# Resolved once at import time; changing WEB2API_CONFIG_PATH later does not
# change which file is read, and reset_cache() only clears _config_cache.
_CONFIG_PATH = _resolve_config_path()

# Parsed config file contents; None until load_config() populates it.
_config_cache: dict[str, Any] | None = None
def _env_override_names(section: str, key: str) -> tuple[str, ...]:
    """List the env var names that may override ``section.key``, in lookup order.

    The generic ``WEB2API_<SECTION>_<KEY>`` name (upper-cased, with ``-``
    replaced by ``_``) always comes first. It is followed by any aliases
    registered in ``_ENV_OVERRIDE_ALIASES`` that differ from it.
    """
    generic = f"WEB2API_{section}_{key}".upper().replace("-", "_")
    registered = _ENV_OVERRIDE_ALIASES.get((section, key), ())
    extras = tuple(name for name in registered if name != generic)
    return (generic, *extras)
def _get_env_override(section: str, key: str) -> Any:
    """Return the raw env override for ``section.key``, or ``_ENV_MISSING``.

    Candidate names come from ``_env_override_names``; the first one present
    in ``os.environ`` wins. Its value is returned verbatim as a string, so an
    empty string still counts as an override.
    """
    environ = os.environ
    return next(
        (environ[name] for name in _env_override_names(section, key) if name in environ),
        _ENV_MISSING,
    )
def has_env_override(section: str, key: str) -> bool:
    """Report whether any environment variable overrides ``section.key``.

    A variable that is set to an empty string still counts as an override.
    """
    found = _get_env_override(section, key)
    return found is not _ENV_MISSING
def get_config_path() -> Path:
    """Return the YAML config file path chosen when this module was imported.

    The file may not exist; load_config() then returns an empty dict.
    """
    config_path: Path = _CONFIG_PATH
    return config_path
def reset_cache() -> None:
    """Discard the cached config so the next load_config() re-reads the file.

    Only the cached contents are cleared; the config file path resolved at
    import time stays the same.
    """
    global _config_cache

    _config_cache = None
def load_config() -> dict[str, Any]:
    """Load and cache the YAML config file resolved at import time.

    Returns the file's top-level mapping. Returns an empty dict when:
    - the file does not exist or is empty;
    - it cannot be read or parsed;
    - its top level is not a mapping.

    Read/parse failures and a non-mapping top level are logged as warnings
    rather than raised, so a broken config file does not stop startup. The
    result is cached, and the same dict object is returned on every call
    until reset_cache() is called.
    """
    global _config_cache
    if _config_cache is not None:
        return _config_cache

    logger = logging.getLogger(__name__)
    data: Any = {}
    try:
        if _CONFIG_PATH.exists():
            with _CONFIG_PATH.open("r", encoding="utf-8") as f:
                # `or {}` maps an empty document (None) to an empty config.
                data = yaml.safe_load(f) or {}
    except Exception:
        # Deliberately broad: besides OSError/UnicodeDecodeError/YAMLError,
        # yaml.safe_load can raise plain ValueError (e.g. an invalid date
        # scalar). Keep best-effort behaviour, but make the failure visible.
        logger.warning(
            "Failed to load config file %s; using empty config",
            _CONFIG_PATH,
            exc_info=True,
        )
        data = {}

    if not isinstance(data, dict):
        logger.warning(
            "Config file %s does not contain a top-level mapping; ignoring it",
            _CONFIG_PATH,
        )
        data = {}

    _config_cache = dict(data)
    return _config_cache
def get(section: str, key: str, default: Any = None) -> Any:
    """Read ``section.key``: environment override first, then YAML, then *default*.

    An environment override is returned as a raw string. From the YAML file,
    *default* is returned when:
    - the section is missing or is not a mapping;
    - the key is absent or set to null.
    """
    override = _get_env_override(section, key)
    if override is not _ENV_MISSING:
        return override

    section_values = load_config().get(section) or {}
    if isinstance(section_values, dict):
        value = section_values.get(key)
        if value is not None:
            return value
    return default
def coerce_bool(value: Any, default: bool = False) -> bool:
    """Interpret *value* as a boolean, falling back to ``bool(default)``.

    - Booleans and numbers (int/float) are passed through ``bool()``.
    - Strings are stripped and lower-cased, then matched against
      ``_BOOL_TRUE_VALUES`` / ``_BOOL_FALSE_VALUES``.
    - Unrecognised strings and every other type yield ``bool(default)``.
    """
    if isinstance(value, str):
        token = value.strip().lower()
        if token in _BOOL_TRUE_VALUES:
            return True
        if token in _BOOL_FALSE_VALUES:
            return False
        return bool(default)
    if isinstance(value, (bool, int, float)):
        return bool(value)
    return bool(default)
def get_bool(section: str, key: str, default: bool = False) -> bool:
    """Read ``section.key`` as a boolean.

    Accepts true/false, 1/0, yes/no and on/off (case-insensitive). Falls back
    to *default* when the value is missing or unrecognised.
    """
    raw_value = get(section, key, default)
    return coerce_bool(raw_value, default=default)
def get_server_host(default: str = "127.0.0.1") -> str:
    """Return ``server.host`` as a stripped string.

    *default* is used when the value is missing or falsy, or becomes empty
    after stripping.
    """
    configured = get("server", "host")
    text = str(configured if configured else default).strip()
    return text if text else default
def get_server_port(default: int = 8001) -> int:
    """Return the server port from ``server.port``, or *default* when unusable.

    The value comes from get(), so WEB2API_SERVER_PORT / PORT take precedence
    over the YAML file. *default* is returned when the value:
    - is missing or falsy;
    - cannot be parsed as an integer;
    - lies outside the TCP port range 0-65535.
    """
    raw = get("server", "port") or default
    try:
        port = int(str(raw).strip())
    except (TypeError, ValueError):
        return default
    # An out-of-range port would otherwise only fail later, at bind time,
    # with a less obvious error; fall back like any other unusable value.
    if not 0 <= port <= 65535:
        return default
    return port
def get_database_url(default: str = "") -> str:
    """Return the database URL from the environment, or *default*.

    Checks ``_DATABASE_URL_ENV_NAMES`` in order and returns the first value
    that is non-blank after stripping. The YAML config file is not consulted.
    """
    candidates = (os.environ.get(name, "").strip() for name in _DATABASE_URL_ENV_NAMES)
    return next((url for url in candidates if url), default)