File size: 4,312 Bytes
77169b4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 | """
统一的 YAML 配置加载。
优先级:
1. WEB2API_CONFIG_PATH 指定的路径
2. 项目根目录下的 config.local.yaml
3. 项目根目录下的 config.yaml
同时支持通过环境变量覆盖单个配置项:
- 通用规则:WEB2API_<SECTION>_<KEY>
- 额外兼容:server.host -> HOST,server.port -> PORT
"""
from __future__ import annotations

import logging
import os
from pathlib import Path
from typing import Any

import yaml
# File names looked up under the project root by _resolve_config_path().
_DEFAULT_CONFIG_NAME = "config.yaml"
_LOCAL_CONFIG_NAME = "config.local.yaml"
# Environment variable that, when non-blank, points directly at a config file.
_CONFIG_ENV_KEY = "WEB2API_CONFIG_PATH"
# Grandparent of the directory that contains this module.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent

# Sentinel meaning "no environment override is set"; distinct from None and
# from an empty-string override.
_ENV_MISSING = object()

# Extra env var names accepted for specific (section, key) pairs. They are
# tried after the generic WEB2API_<SECTION>_<KEY> name.
_ENV_OVERRIDE_ALIASES: dict[tuple[str, str], tuple[str, ...]] = {
    ("server", "host"): ("HOST",),
    ("server", "port"): ("PORT",),
}

# Env vars consulted, in this order, by get_database_url().
_DATABASE_URL_ENV_NAMES = ("WEB2API_DATABASE_URL", "DATABASE_URL")

# String tokens recognised by coerce_bool() after strip() + lower().
_BOOL_TRUE_VALUES = {"1", "true", "yes", "on"}
_BOOL_FALSE_VALUES = {"0", "false", "no", "off"}
def _resolve_config_path() -> Path:
    """Choose which config file to use, by priority.

    1. The path held in ``WEB2API_CONFIG_PATH`` (whitespace-stripped, ``~``
       expanded) when that variable is set to a non-blank value. The path is
       not checked for existence; a relative value stays relative, i.e. it
       is interpreted against the working directory when opened.
    2. ``<project root>/config.local.yaml`` if that file exists.
    3. ``<project root>/config.yaml`` otherwise, whether or not it exists.
    """
    explicit = os.environ.get(_CONFIG_ENV_KEY, "").strip()
    if not explicit:
        local = _PROJECT_ROOT / _LOCAL_CONFIG_NAME
        fallback = _PROJECT_ROOT / _DEFAULT_CONFIG_NAME
        return local if local.exists() else fallback
    return Path(explicit).expanduser()
# Result of load_config(); None means "not loaded yet (or reset)".
_config_cache: dict[str, Any] | None = None
# Resolved once at import time. reset_cache() only clears _config_cache and
# does not re-resolve this path.
_CONFIG_PATH = _resolve_config_path()
def _env_override_names(section: str, key: str) -> tuple[str, ...]:
    """Return the env var names that may override ``section.key``, in priority order.

    The generic ``WEB2API_<SECTION>_<KEY>`` name (upper-cased, ``-`` turned
    into ``_``) always comes first. Any aliases registered for the exact
    ``(section, key)`` pair in ``_ENV_OVERRIDE_ALIASES`` follow, skipping an
    alias that duplicates the generic name.
    """
    primary = f"WEB2API_{section}_{key}".upper().replace("-", "_")
    extras = tuple(
        name
        for name in _ENV_OVERRIDE_ALIASES.get((section, key), ())
        if name != primary
    )
    return (primary,) + extras
def _get_env_override(section: str, key: str) -> Any:
    """Return the raw value of the first override env var that is set.

    Names are tried in the order produced by ``_env_override_names``. A
    variable set to an empty string still counts as an override. When none
    of the names is present, the ``_ENV_MISSING`` sentinel is returned.
    """
    candidates = _env_override_names(section, key)
    return next(
        (os.environ[name] for name in candidates if name in os.environ),
        _ENV_MISSING,
    )
def has_env_override(section: str, key: str) -> bool:
    """Tell whether an environment variable currently overrides ``section.key``.

    This is True even when the variable is set to an empty string.
    """
    override = _get_env_override(section, key)
    return override is not _ENV_MISSING
def get_config_path() -> Path:
    """Return the config file path chosen at import time (the file may not exist)."""
    chosen: Path = _CONFIG_PATH
    return chosen
def reset_cache() -> None:
    """Forget the cached config so the next load_config() re-reads the file.

    The config file path is not re-resolved.
    """
    global _config_cache
    _config_cache = None
def load_config() -> dict[str, Any]:
    """Load the YAML config from ``_CONFIG_PATH`` and cache it.

    Falls back to an empty dict (which is cached too) when the file is
    missing, cannot be read or parsed, or its top level is not a mapping.
    A missing file is silent. Any other failure is logged as a warning so a
    broken config does not go unnoticed, but it never raises: loading stays
    best-effort.

    The same cached dict object is returned on every call until
    ``reset_cache()`` is called, so callers should treat it as read-only.
    """
    global _config_cache
    if _config_cache is not None:
        return _config_cache

    try:
        # EAFP: opening directly avoids the exists()/open() race, and keeps a
        # PermissionError from stat() inside the best-effort handling.
        with _CONFIG_PATH.open("r", encoding="utf-8") as f:
            data = yaml.safe_load(f) or {}
    except FileNotFoundError:
        # No config file is a normal setup: get() then uses env vars/defaults.
        data = {}
    except Exception:  # best-effort: a bad config file must not crash startup
        logging.getLogger(__name__).warning(
            "Failed to load config file %s; using an empty config instead",
            _CONFIG_PATH,
            exc_info=True,
        )
        data = {}

    if not isinstance(data, dict):
        logging.getLogger(__name__).warning(
            "Config file %s does not contain a mapping at the top level; "
            "using an empty config instead",
            _CONFIG_PATH,
        )
        data = {}

    # Always hand back the cached object, including on the first call.
    _config_cache = dict(data)
    return _config_cache
def get(section: str, key: str, default: Any = None) -> Any:
    """Read ``section.key``: env override first, then YAML, then ``default``.

    An env override is returned as its raw string, even if empty. From the
    YAML side, a missing key, an explicit null value, or a section that is
    absent, falsy, or not a mapping all yield ``default``.
    """
    override = _get_env_override(section, key)
    if override is _ENV_MISSING:
        section_data = load_config().get(section) or {}
        if isinstance(section_data, dict):
            value = section_data.get(key)
            if value is not None:
                return value
        return default
    return override
def coerce_bool(value: Any, default: bool = False) -> bool:
    """Interpret ``value`` as a boolean, falling back to ``default``.

    - bools are returned unchanged; ints and floats use Python truthiness;
    - strings are stripped and lower-cased, then matched against the
      true tokens (1/true/yes/on) and the false tokens (0/false/no/off);
    - anything else (None, unrecognised strings, other types) yields
      ``bool(default)``.
    """
    if isinstance(value, (bool, int, float)):
        # bool subclasses int and bool(True) is True, so one branch covers
        # both the "unchanged" and the truthiness cases.
        return bool(value)
    if isinstance(value, str):
        token = value.strip().lower()
        if token in _BOOL_TRUE_VALUES:
            return True
        if token in _BOOL_FALSE_VALUES:
            return False
    return bool(default)
def get_bool(section: str, key: str, default: bool = False) -> bool:
    """Read ``section.key`` as a bool (accepts true/false, 1/0, yes/no, on/off).

    Missing or unrecognised values fall back to ``default``.
    """
    raw = get(section, key, default)
    return coerce_bool(raw, default)
def get_server_host(default: str = "127.0.0.1") -> str:
    """Return the listen host from ``server.host``, else ``default``.

    The env vars ``WEB2API_SERVER_HOST`` and then ``HOST`` take precedence
    over the YAML value. Falsy or whitespace-only values fall back to
    ``default``.
    """
    host = get("server", "host")
    if not host:
        host = default
    cleaned = str(host).strip()
    return cleaned if cleaned else default
def get_server_port(default: int = 8001) -> int:
    """Return the listen port from ``server.port``, else ``default``.

    The env vars ``WEB2API_SERVER_PORT`` and then ``PORT`` take precedence
    over the YAML value. The value may be an int or a numeric string
    (surrounding whitespace is ignored). ``default`` is returned when the
    value is missing or falsy, cannot be parsed as an integer, or lies
    outside the valid port range 0-65535.
    """
    try:
        port = int(str(get("server", "port") or default).strip())
    except Exception:  # best-effort: any unusable value means "use default"
        return default
    # Values outside 0-65535 are not valid TCP/UDP ports; treat them like any
    # other unusable setting instead of returning them to the caller.
    if not 0 <= port <= 65535:
        return default
    return port
def get_database_url(default: str = "") -> str:
    """Return the database URL from the environment, else ``default``.

    ``WEB2API_DATABASE_URL`` is checked before ``DATABASE_URL``. Values are
    stripped and blank ones are skipped. The YAML config is not consulted.
    """
    stripped = (os.environ.get(name, "").strip() for name in _DATABASE_URL_ENV_NAMES)
    return next((url for url in stripped if url), default)
|