File size: 4,312 Bytes
77169b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""
统一的 YAML 配置加载。

优先级:
1. WEB2API_CONFIG_PATH 指定的路径
2. 项目根目录下的 config.local.yaml
3. 项目根目录下的 config.yaml

同时支持通过环境变量覆盖单个配置项:
- 通用规则:WEB2API_<SECTION>_<KEY>
- 额外兼容:server.host -> HOST,server.port -> PORT
"""

from __future__ import annotations

import logging
import os
from pathlib import Path
from typing import Any

import yaml


# Project root: three `.parent` steps up from this file's resolved path.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
# Environment variable that, when non-blank, names the config file explicitly.
_CONFIG_ENV_KEY = "WEB2API_CONFIG_PATH"
# Preferred per-machine config file in the project root (used only if present).
_LOCAL_CONFIG_NAME = "config.local.yaml"
# Fallback config file in the project root.
_DEFAULT_CONFIG_NAME = "config.yaml"
# Sentinel meaning "no override variable is set"; distinct from a variable
# that is set to an empty string (which still counts as an override).
_ENV_MISSING = object()
# Extra environment variable names accepted for specific (section, key)
# pairs; consulted after the generic WEB2API_<SECTION>_<KEY> name.
_ENV_OVERRIDE_ALIASES: dict[tuple[str, str], tuple[str, ...]] = {
    ("server", "host"): ("HOST",),
    ("server", "port"): ("PORT",),
}
# Environment variables checked, in this order, for the database URL.
_DATABASE_URL_ENV_NAMES = ("WEB2API_DATABASE_URL", "DATABASE_URL")
# String spellings recognised by coerce_bool, compared after strip().lower().
_BOOL_TRUE_VALUES = {"1", "true", "yes", "on"}
_BOOL_FALSE_VALUES = {"0", "false", "no", "off"}


def _resolve_config_path() -> Path:
    """Choose the config file path according to the documented priority.

    A non-blank WEB2API_CONFIG_PATH wins (with ``~`` expanded). Otherwise
    config.local.yaml in the project root is used when it exists, falling
    back to config.yaml there. The returned path is not guaranteed to exist.
    """
    explicit = os.environ.get(_CONFIG_ENV_KEY, "").strip()
    if not explicit:
        candidate = _PROJECT_ROOT / _LOCAL_CONFIG_NAME
        if candidate.exists():
            return candidate
        return _PROJECT_ROOT / _DEFAULT_CONFIG_NAME
    return Path(explicit).expanduser()


# Resolved once at import time: later changes to WEB2API_CONFIG_PATH, or to
# which config files exist, are not picked up (reset_cache() does not
# re-resolve this path).
_CONFIG_PATH = _resolve_config_path()

# Parsed config contents, filled lazily by load_config(); None = not loaded yet.
_config_cache: dict[str, Any] | None = None


def _env_override_names(section: str, key: str) -> tuple[str, ...]:
    """Return the environment variable names that may override ``section.key``.

    The generic ``WEB2API_<SECTION>_<KEY>`` name (upper-cased, ``-`` turned
    into ``_``) comes first, followed by any registered aliases that differ
    from it, in their registered order.
    """
    primary = f"WEB2API_{section}_{key}".upper().replace("-", "_")
    registered = _ENV_OVERRIDE_ALIASES.get((section, key), ())
    extras = tuple(name for name in registered if name != primary)
    return (primary, *extras)


def _get_env_override(section: str, key: str) -> Any:
    """Return the first environment override set for ``section.key``.

    The raw string value (possibly empty) of the first candidate name present
    in ``os.environ`` is returned; ``_ENV_MISSING`` means none is set.
    """
    present = (
        os.environ[name]
        for name in _env_override_names(section, key)
        if name in os.environ
    )
    return next(present, _ENV_MISSING)


def has_env_override(section: str, key: str) -> bool:
    """Report whether an environment variable overrides ``section.key``.

    A variable that is set to an empty string still counts as an override.
    """
    override = _get_env_override(section, key)
    return override is not _ENV_MISSING


def get_config_path() -> Path:
    """Return the config file path chosen at import time (it may not exist)."""
    return _CONFIG_PATH


def reset_cache() -> None:
    """Drop the cached config so the next load_config() re-reads the file.

    Only the parsed contents are discarded; the config file path itself is
    not re-resolved.
    """
    global _config_cache
    _config_cache = None


def load_config() -> dict[str, Any]:
    """Load the config file selected at import time, caching the result.

    Returns the file's top-level mapping, shallow-copied into a new dict.
    An empty dict is returned (and cached) when the file does not exist,
    has falsy content such as an empty document, does not contain a mapping
    at top level, or cannot be read or parsed. Read/parse failures and
    non-mapping content are logged as warnings. Without those warnings, a
    typo in the config file would quietly revert every setting to its default.

    The same dict object is returned on every call until reset_cache() is
    called, so callers should treat it as read-only.
    """
    global _config_cache
    if _config_cache is not None:
        return _config_cache

    logger = logging.getLogger(__name__)
    config: dict[str, Any] = {}
    if _CONFIG_PATH.exists():
        try:
            with _CONFIG_PATH.open("r", encoding="utf-8") as f:
                data = yaml.safe_load(f)
        except Exception:
            # Deliberately best-effort: an unreadable or malformed file must
            # not crash the caller, but the failure must not be silent.
            logger.warning(
                "Failed to load config file %s; using empty config",
                _CONFIG_PATH,
                exc_info=True,
            )
        else:
            if isinstance(data, dict):
                config = dict(data)
            elif data:
                # A truthy non-mapping, e.g. a top-level list or scalar.
                logger.warning(
                    "Config file %s does not contain a mapping at top level "
                    "(got %s); ignoring it",
                    _CONFIG_PATH,
                    type(data).__name__,
                )
    _config_cache = config
    return _config_cache


def get(section: str, key: str, default: Any = None) -> Any:
    """Read ``section.key``: environment override first, then YAML, then *default*.

    An environment override is returned as its raw string, even when empty.
    From the YAML file, a missing or None value, or a section that is not a
    mapping, yields *default*.
    """
    override = _get_env_override(section, key)
    if override is _ENV_MISSING:
        section_cfg = load_config().get(section) or {}
        if isinstance(section_cfg, dict):
            found = section_cfg.get(key)
            if found is not None:
                return found
        return default
    return override


def coerce_bool(value: Any, default: bool = False) -> bool:
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        return bool(value)
    if isinstance(value, str):
        normalized = value.strip().lower()
        if normalized in _BOOL_TRUE_VALUES:
            return True
        if normalized in _BOOL_FALSE_VALUES:
            return False
    return bool(default)


def get_bool(section: str, key: str, default: bool = False) -> bool:
    """Read ``section.key`` as a boolean.

    Accepts real booleans, numbers, and the strings true/false, 1/0, yes/no,
    on/off (case-insensitive, surrounding whitespace ignored). Unrecognised
    values fall back to *default*.
    """
    raw = get(section, key, default)
    return coerce_bool(raw, default)


def get_server_host(default: str = "127.0.0.1") -> str:
    """Return the server bind host (``server.host`` / ``HOST``), trimmed.

    Falls back to *default* when the value is missing, falsy, or blank.
    """
    configured = get("server", "host")
    host = str(configured if configured else default).strip()
    return host if host else default


def get_server_port(default: int = 8001) -> int:
    """Return the server port (``server.port`` / ``PORT``) as an int.

    Falls back to *default* when the value is missing or falsy (e.g. 0 or
    an empty string), is not an integer literal (surrounding whitespace is
    ignored), or lies outside the valid TCP port range 1-65535.
    """
    raw = get("server", "port") or default
    try:
        port = int(str(raw).strip())
    except (TypeError, ValueError):
        return default
    # A syntactically valid but out-of-range port would only fail later, at
    # bind time, with a less obvious error; treat it like any other bad value.
    if not 1 <= port <= 65535:
        return default
    return port


def get_database_url(default: str = "") -> str:
    """Return the database URL from the environment, or *default*.

    The names in ``_DATABASE_URL_ENV_NAMES`` are checked in order. The first
    one with a non-blank value wins and is returned trimmed. The YAML config
    file is not consulted.
    """
    candidates = (
        os.environ.get(name, "").strip() for name in _DATABASE_URL_ENV_NAMES
    )
    return next((url for url in candidates if url), default)