web2api / core /config /settings.py
ohmyapi's picture
feat: align hosted Space deployment with latest upstream
77169b4
"""
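Unified YAML configuration loading.
Config file precedence:
1. The path given by WEB2API_CONFIG_PATH
2. config.local.yaml in the project root
3. config.yaml in the project root
Individual settings can also be overridden through environment variables:
- General rule: WEB2API_<SECTION>_<KEY>
- Additional aliases: server.host -> HOST, server.port -> PORT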
"""
from __future__ import annotations
import os
from pathlib import Path
from typing import Any
import yaml
# Project root: three `.parent` hops above this file's resolved path
# (the first hop is the directory that contains this file).
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
# Name of the environment variable that, when non-empty, points at the config file.
_CONFIG_ENV_KEY = "WEB2API_CONFIG_PATH"
# File names looked up in the project root; the local file wins when it exists.
_LOCAL_CONFIG_NAME = "config.local.yaml"
_DEFAULT_CONFIG_NAME = "config.yaml"
# Sentinel meaning "no environment override is set"; needed because an
# override may legitimately be an empty string.
_ENV_MISSING = object()
# Extra environment variable names accepted for specific (section, key) pairs,
# consulted after the generic WEB2API_<SECTION>_<KEY> name.
_ENV_OVERRIDE_ALIASES: dict[tuple[str, str], tuple[str, ...]] = {
("server", "host"): ("HOST",),
("server", "port"): ("PORT",),
}
# Environment variables checked, in this order, for the database URL.
_DATABASE_URL_ENV_NAMES = ("WEB2API_DATABASE_URL", "DATABASE_URL")
# String spellings recognised as booleans; compared against the stripped,
# lower-cased input in coerce_bool().
_BOOL_TRUE_VALUES = {"1", "true", "yes", "on"}
_BOOL_FALSE_VALUES = {"0", "false", "no", "off"}
def _resolve_config_path() -> Path:
    """Choose which config file path this process should use.

    A non-blank ``WEB2API_CONFIG_PATH`` wins (with ``~`` expanded). Otherwise
    the local config file in the project root is used when it exists, and the
    default config file path is returned as the final fallback (whether or
    not that file exists).
    """
    override = os.environ.get(_CONFIG_ENV_KEY, "").strip()
    if not override:
        local_candidate = _PROJECT_ROOT / _LOCAL_CONFIG_NAME
        if not local_candidate.exists():
            return _PROJECT_ROOT / _DEFAULT_CONFIG_NAME
        return local_candidate
    return Path(override).expanduser()
# Resolved once, at import time; later changes to the environment or to the
# files on disk do not change which path is used.
_CONFIG_PATH = _resolve_config_path()
# Memoized result of load_config(); None means "not loaded yet".
_config_cache: dict[str, Any] | None = None
def _env_override_names(section: str, key: str) -> tuple[str, ...]:
    """Return the environment variable names that may override ``section.key``.

    The generic ``WEB2API_<SECTION>_<KEY>`` name (upper-cased, with ``-``
    replaced by ``_``) always comes first, followed by any registered aliases
    that differ from it.
    """
    primary = f"WEB2API_{section}_{key}".upper().replace("-", "_")
    registered = _ENV_OVERRIDE_ALIASES.get((section, key), ())
    extras = tuple(name for name in registered if name != primary)
    return (primary, *extras)
def _get_env_override(section: str, key: str) -> Any:
    """Return the raw environment override for ``section.key``.

    Candidate names are tried in the order given by ``_env_override_names``;
    the value of the first one present in ``os.environ`` is returned as-is
    (it may be an empty string). ``_ENV_MISSING`` is returned when none is set.
    """
    candidates = _env_override_names(section, key)
    return next(
        (os.environ[name] for name in candidates if name in os.environ),
        _ENV_MISSING,
    )
def has_env_override(section: str, key: str) -> bool:
    """Report whether any environment variable overrides ``section.key``."""
    override = _get_env_override(section, key)
    return override is not _ENV_MISSING
def get_config_path() -> Path:
    """Return the config file path that was resolved when this module was imported."""
    return _CONFIG_PATH
def reset_cache() -> None:
    """Drop the memoized config so the next ``load_config()`` re-reads the file."""
    global _config_cache
    _config_cache = None
def load_config() -> dict[str, Any]:
    """Load the YAML config file once and return its top-level mapping.

    The result is memoized in ``_config_cache``; call ``reset_cache()`` to
    force a re-read. Loading is best-effort: a missing file, an unreadable or
    malformed file, or a document whose root is not a mapping all yield an
    empty dict. Unlike a missing file, the failure cases are logged as
    warnings so a broken config is not silently ignored.

    Returns:
        The cached config dict. The same object is returned on every call
        until the cache is reset.
    """
    global _config_cache
    if _config_cache is not None:
        return _config_cache

    loaded: dict[str, Any] = {}
    if _CONFIG_PATH.exists():
        try:
            with _CONFIG_PATH.open("r", encoding="utf-8") as f:
                data = yaml.safe_load(f)
        except Exception:
            # Deliberately broad: config loading must never crash the caller.
            # Imported locally so this block does not depend on a file-level import.
            import logging

            logging.getLogger(__name__).warning(
                "Failed to load config file %s; falling back to empty config",
                _CONFIG_PATH,
                exc_info=True,
            )
        else:
            if isinstance(data, dict):
                # Shallow copy so the cache does not alias the parser's object.
                loaded = dict(data)
            elif data is not None:
                import logging

                logging.getLogger(__name__).warning(
                    "Config file %s does not contain a mapping at the top level "
                    "(got %s); falling back to empty config",
                    _CONFIG_PATH,
                    type(data).__name__,
                )

    # Always hand back the cached object itself, including on the
    # missing-file path, so every caller sees the same dict.
    _config_cache = loaded
    return _config_cache
def get(section: str, key: str, default: Any = None) -> Any:
    """Look up ``section.key``: environment override, then YAML, then ``default``.

    An environment override is returned as the raw string, even when empty.
    A YAML value of ``None``, a missing key, or a section that is not a
    mapping all fall back to ``default``.
    """
    from_env = _get_env_override(section, key)
    if from_env is _ENV_MISSING:
        section_cfg = load_config().get(section) or {}
        if isinstance(section_cfg, dict):
            yaml_value = section_cfg.get(key)
            if yaml_value is not None:
                return yaml_value
        return default
    return from_env
def coerce_bool(value: Any, default: bool = False) -> bool:
    """Convert a loosely-typed config value to ``bool``.

    Strings are stripped and lower-cased, then matched against the recognised
    true/false spellings; an unrecognised string yields ``bool(default)``.
    Booleans and numbers use their normal truthiness. Any other type yields
    ``bool(default)``.
    """
    if isinstance(value, str):
        token = value.strip().lower()
        if token in _BOOL_TRUE_VALUES:
            return True
        if token in _BOOL_FALSE_VALUES:
            return False
        return bool(default)
    if isinstance(value, (bool, int, float)):
        return bool(value)
    return bool(default)
def get_bool(section: str, key: str, default: bool = False) -> bool:
    """Read ``section.key`` as a boolean (accepts true/false, 1/0, yes/no, on/off)."""
    raw = get(section, key, default)
    return coerce_bool(raw, default)
def get_server_host(default: str = "127.0.0.1") -> str:
    """Return the configured server host, or ``default`` when unset or blank."""
    configured = get("server", "host")
    host = str(configured or default).strip()
    return host or default
def get_server_port(default: int = 8001) -> int:
    """Return the configured server port as an int.

    Falls back to ``default`` when the value is unset, falsy, or cannot be
    parsed as an integer (or when the lookup itself raises).
    """
    try:
        raw_port = get("server", "port") or default
        port = int(str(raw_port).strip())
    except Exception:
        return default
    return port
def get_database_url(default: str = "") -> str:
    """Return the first non-blank database URL found in the environment.

    The names in ``_DATABASE_URL_ENV_NAMES`` are checked in order and each
    value is stripped; ``default`` is returned when none is set to a
    non-blank value.
    """
    stripped_values = (
        os.environ.get(name, "").strip() for name in _DATABASE_URL_ENV_NAMES
    )
    return next((url for url in stripped_values if url), default)