File size: 11,717 Bytes
1bd6269
 
 
 
 
 
 
 
 
 
 
 
4be4d85
 
1bd6269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4be4d85
1bd6269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4be4d85
1bd6269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4be4d85
1bd6269
4be4d85
1bd6269
4be4d85
1bd6269
4be4d85
1bd6269
 
 
 
 
 
 
 
 
 
4be4d85
1bd6269
4be4d85
1bd6269
4be4d85
1bd6269
4be4d85
1bd6269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4be4d85
1bd6269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4be4d85
1bd6269
 
 
 
 
 
 
 
4be4d85
1bd6269
 
 
4be4d85
1bd6269
 
 
 
 
 
 
 
 
 
4be4d85
1bd6269
 
 
 
4be4d85
1bd6269
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
from __future__ import annotations

import json
import os
import threading
import time
from pathlib import Path
from typing import Any

import requests
from dotenv import load_dotenv

from dod_logging import log_error, log_info


# Load .env at import time; override=True lets values from the file replace
# variables that are already set in the process environment.
load_dotenv(override=True)

# Alias for the endpoint dictionaries built by _default_endpoint / _normalize_endpoint.
EndpointConfig = dict[str, Any]
# Lower-cased secret values that _optional_env_secret treats as "not configured"
# (template placeholders left in a .env file).
PLACEHOLDER_SECRET_VALUES = {
    "your_token",
    "your_huggingface_token",
    "your_hf_token",
    "hf_token",
    "token",
}

# Location of the remote inference mapper; each piece can be overridden via env vars.
MAPPER_DATASET_REPO_ID = os.getenv("DOD_INFERENCE_MAPPER_DATASET_REPO_ID", "elismasilva/dod-inference-mapper")
MAPPER_DATASET_REVISION = os.getenv("DOD_INFERENCE_MAPPER_DATASET_REVISION", "main")
MAPPER_DATASET_PATH = "inference_map.json"
# Full raw-file URL; DOD_INFERENCE_MAPPER_URL replaces the URL built from the pieces above.
MAPPER_URL = os.getenv(
    "DOD_INFERENCE_MAPPER_URL",
    f"https://huggingface.co/datasets/{MAPPER_DATASET_REPO_ID}/raw/{MAPPER_DATASET_REVISION}/{MAPPER_DATASET_PATH}",
)
# Timing settings in seconds: mapper cache TTL, default per-endpoint failure cooldown,
# and the lower bound for an endpoint's default warmup timeout.
# NOTE(review): float() raises at import time if any of these env values is not numeric.
MAPPER_CACHE_TTL_SECONDS = float(os.getenv("DOD_INFERENCE_MAPPER_TTL_SECONDS", "60"))
ENDPOINT_FAILURE_COOLDOWN_SECONDS = float(os.getenv("DOD_ENDPOINT_FAILURE_COOLDOWN_SECONDS", "180"))
ENDPOINT_WARMUP_TIMEOUT_SECONDS = float(os.getenv("DOD_ENDPOINT_WARMUP_TIMEOUT_SECONDS", "75"))
# Held while refreshing the mapper cache and while writing _endpoint_cooldowns.
_mapper_lock = threading.Lock()
# Last mapper produced by _fetch_mapper (None until the first fetch) and the
# time.time() timestamp recorded for that fetch.
_cached_mapper: dict[str, Any] | None = None
_last_mapper_update = 0.0
# (service, url) -> time.time() timestamp before which that endpoint is skipped.
_endpoint_cooldowns: dict[tuple[str, str], float] = {}


def _refresh_env() -> None:
    """Re-read the local ``.env`` file into the process environment.

    ``override=True`` lets values from the file win over values already present
    in the environment. Every env-reading helper in this module calls this
    first, so development flags edited in ``.env`` are seen on the next call.
    """
    load_dotenv(override=True)


def _env_enabled(name: str, fallback_name: str | None = None) -> bool:
    """Report whether an environment flag is switched on.

    Args:
        name: Primary environment variable to read.
        fallback_name: Variable consulted only when ``name`` is entirely unset
            (an empty ``name`` value does not trigger the fallback).

    Returns:
        True when the value, lower-cased, is one of ``1``, ``true``, ``yes``, ``on``.
    """
    _refresh_env()
    raw = os.getenv(name)
    if raw is None and fallback_name:
        raw = os.getenv(fallback_name, "")
    return (raw or "").lower() in ("1", "true", "yes", "on")


def _optional_env_secret(name: str) -> str:
    """Read a secret from the environment, or ``""`` when it is not really set.

    Surrounding whitespace and quote characters are removed first. An empty
    result, or one matching ``PLACEHOLDER_SECRET_VALUES`` case-insensitively,
    yields ``""``.
    """
    _refresh_env()
    cleaned = os.getenv(name, "").strip().strip("\"'")
    looks_like_placeholder = cleaned.lower() in PLACEHOLDER_SECRET_VALUES
    return "" if (not cleaned or looks_like_placeholder) else cleaned


def _local_data_dir() -> Path:
    """Directory holding local data files: ``DOD_LOCAL_DATA_DIR`` or ``~/.dod``.

    A leading ``~`` in the configured value is expanded.
    """
    _refresh_env()
    configured_dir = os.getenv("DOD_LOCAL_DATA_DIR", Path.home() / ".dod")
    return Path(configured_dir).expanduser()


def _local_mapper_path() -> Path:
    """Path of the local inference mapper JSON file.

    ``DOD_LOCAL_INFERENCE_MAPPER_PATH`` wins when set; otherwise the mapper file
    name is looked up inside ``_local_data_dir()``. ``~`` is expanded either way.
    """
    _refresh_env()
    fallback_path = _local_data_dir() / Path(MAPPER_DATASET_PATH).name
    chosen_path = os.getenv("DOD_LOCAL_INFERENCE_MAPPER_PATH", fallback_path)
    return Path(chosen_path).expanduser()


def _service_priority(service: str) -> str:
    """Endpoint priority (``"primary"`` or ``"fallback"``) configured for a service.

    ``llm`` reads ``LLM_URL_PRIORITY``; every other service reads
    ``TTS_URL_PRIORITY``. Unknown values are logged and treated as ``"primary"``.
    """
    _refresh_env()
    env_name = "TTS_URL_PRIORITY"
    if service == "llm":
        env_name = "LLM_URL_PRIORITY"
    priority = os.getenv(env_name, "primary").strip().lower()
    if priority in {"primary", "fallback"}:
        return priority
    log_error(f"[Mapper] Ignored invalid {env_name}={priority!r}. Using primary.", flush=True)
    return "primary"


def _default_endpoint(service: str) -> EndpointConfig:
    """Return the endpoint built from local environment variables for a service.

    It is returned directly by ``get_endpoint_chain`` when ``DOD_USE_LOCAL_API``
    is enabled. ``_normalize_endpoint`` also uses it as the source of default
    ``mode``/``api_name`` values.

    Args:
        service: ``"llm"``, ``"tts"``, or any other service name.

    Returns:
        For ``"llm"``: a Gradio endpoint from ``LLM_URL``.
        For ``"tts"``: an endpoint from ``TTS_API_URL`` / ``TTS_API_MODE``. In
        Gradio mode a trailing ``/generate_api`` is removed from the URL, since
        that path is already given as ``api_name``.
        For anything else: a REST entry with an empty URL.
    """
    _refresh_env()
    if service == "llm":
        return {
            "name": "env-llm",
            "url": os.getenv("LLM_URL", "https://elismasilva-voxcpm2-nanovllm-service.hf.space"),
            "mode": "gradio",
            "api_name": "/generate_inference",
        }
    if service == "tts":
        tts_url = os.getenv("TTS_API_URL", "http://127.0.0.1:8000/generate_api")
        # Normalize the mode the same way _normalize_endpoint does, so values like
        # "Gradio" or " gradio " still trigger the URL trimming below and reach
        # callers as "gradio".
        tts_mode = os.getenv("TTS_API_MODE", "rest").strip().lower()
        if tts_mode == "gradio" and tts_url.rstrip("/").endswith("/generate_api"):
            tts_url = tts_url.rstrip("/")[: -len("/generate_api")]
        return {
            "name": "env-tts",
            "url": tts_url,
            "mode": tts_mode,
            "api_name": "/generate_api",
        }
    return {"name": f"env-{service}", "url": "", "mode": "rest"}


def _normalize_endpoint(raw_endpoint: Any, service: str, role: str) -> EndpointConfig | None:
    """Normalize one mapper entry into a consistent endpoint dictionary.

    Args:
        raw_endpoint: A mapper entry. It is either a bare URL string, or a dict
            with optional ``url``/``space``/``src``, ``mode``, ``api_name``,
            ``name``, ``timeout``, ``warmup_timeout`` and ``cooldown_seconds`` keys.
        service: Service name whose ``_default_endpoint`` supplies the default
            ``mode`` and ``api_name``.
        role: Name used when the entry does not provide one (e.g. ``"primary"``).

    Returns:
        The normalized endpoint dict, or ``None`` in any of these cases:
        the entry is neither a string nor a dict; it has no http(s) URL or
        Gradio space id; or one of its timing fields is not numeric. JSON
        ``null`` values are treated like missing keys.
    """
    if isinstance(raw_endpoint, str):
        raw_endpoint = {"url": raw_endpoint}
    if not isinstance(raw_endpoint, dict):
        return None

    def pick(key: str, fallback: Any) -> Any:
        # Treat an explicit null like an absent key, so it never turns into the
        # literal text "None" (e.g. api_name "/None" or mode "none").
        value = raw_endpoint.get(key)
        return fallback if value is None else value

    default = _default_endpoint(service)
    mode = str(pick("mode", default.get("mode", "rest"))).strip().lower()
    url = str(raw_endpoint.get("url") or raw_endpoint.get("space") or raw_endpoint.get("src") or "").strip()
    is_http_url = url.startswith("http")
    # Gradio clients also accept "owner/space" ids instead of full URLs.
    is_gradio_space_id = mode == "gradio" and "/" in url and " " not in url
    if not is_http_url and not is_gradio_space_id:
        return None

    api_name = str(pick("api_name", default.get("api_name", ""))).strip()
    if api_name and not api_name.startswith("/"):
        api_name = f"/{api_name}"
    # For Gradio the api_name is passed separately, so drop it from the URL tail.
    if mode == "gradio" and is_http_url and api_name and url.rstrip("/").endswith(api_name):
        url = url.rstrip("/")[: -len(api_name)]

    # The mapper is external data: a malformed numeric field rejects just this
    # entry (the caller logs and skips it) instead of crashing the whole chain.
    try:
        timeout = float(pick("timeout", 120.0))
        warmup_timeout = float(pick("warmup_timeout", max(timeout, ENDPOINT_WARMUP_TIMEOUT_SECONDS)))
        cooldown_seconds = float(pick("cooldown_seconds", ENDPOINT_FAILURE_COOLDOWN_SECONDS))
    except (TypeError, ValueError):
        return None

    return {
        "name": str(pick("name", role)),
        "url": url,
        "mode": mode,
        "api_name": api_name,
        "timeout": timeout,
        "warmup_timeout": warmup_timeout,
        "cooldown_seconds": cooldown_seconds,
    }


def _extract_service_endpoints(mapper: dict[str, Any], service: str) -> list[EndpointConfig]:
    """Extract the ordered, de-duplicated endpoints for a service from mapper JSON.

    ``mapper[service]`` may be one of:
    - a list of entries;
    - a dict, read as itself (when it has a ``url`` key), then ``primary``,
      then ``fallback``, then every item of ``fallbacks``;
    - a single entry.

    Invalid entries are logged and skipped. Later entries whose URL was already
    seen are dropped.

    Args:
        mapper: Parsed mapper JSON object.
        service: Service key to read (e.g. ``"llm"`` or ``"tts"``).

    Returns:
        Normalized endpoints in priority order (possibly empty).
    """
    service_config = mapper.get(service, {})

    if isinstance(service_config, list):
        raw_entries = list(service_config)
    elif isinstance(service_config, dict):
        raw_entries = []
        if "url" in service_config:
            raw_entries.append(service_config)
        if "primary" in service_config:
            raw_entries.append(service_config["primary"])
        if "fallback" in service_config:
            raw_entries.append(service_config["fallback"])
        extra_fallbacks = service_config.get("fallbacks")
        # Only a list is spread. A single string/dict is one entry: extending
        # with it would iterate characters or keys, and null would raise TypeError.
        if isinstance(extra_fallbacks, list):
            raw_entries.extend(extra_fallbacks)
        elif extra_fallbacks is not None:
            raw_entries.append(extra_fallbacks)
    else:
        raw_entries = [service_config]

    endpoints: list[EndpointConfig] = []
    seen_urls: set[str] = set()
    for idx, raw_entry in enumerate(raw_entries):
        endpoint = _normalize_endpoint(raw_entry, service, "primary" if idx == 0 else f"fallback-{idx}")
        if not endpoint:
            log_error(f"[Mapper] Ignored invalid {service} endpoint entry: {raw_entry}", flush=True)
            continue
        if endpoint["url"] in seen_urls:
            continue
        seen_urls.add(endpoint["url"])
        endpoints.append(endpoint)

    return endpoints


def _fetch_mapper() -> dict[str, Any]:
    """Load the inference mapper JSON object.

    With ``DOD_USE_LOCAL_DATA`` enabled, the mapper is read from
    ``_local_mapper_path()``. Otherwise it is downloaded from ``MAPPER_URL``,
    sending a bearer token when ``HF_TOKEN_DATASET`` holds a real secret.
    Every failure (missing file, bad JSON, non-object payload, HTTP error,
    network error) is logged and yields ``{}``.
    """
    if _env_enabled("DOD_USE_LOCAL_DATA"):
        mapper_path = _local_mapper_path()
        try:
            with mapper_path.open(mode="r", encoding="utf-8") as handle:
                loaded = json.load(handle)
            if not isinstance(loaded, dict):
                log_error(f"[Mapper] Local mapper at {mapper_path} is not a JSON object. Using environment defaults.", flush=True)
            else:
                log_info(f"[Mapper] Loaded local inference mapper from {mapper_path}", flush=True)
                return loaded
        except FileNotFoundError:
            log_error(f"[Mapper] Local mapper not found at {mapper_path}. Using environment defaults.", flush=True)
        except Exception as exc:
            log_error(f"[Mapper] Failed loading local inference mapper at {mapper_path}: {exc}", flush=True)
        return {}

    try:
        _refresh_env()
        token = _optional_env_secret("HF_TOKEN_DATASET")
        auth_headers = {} if not token else {"Authorization": f"Bearer {token}"}
        reply = requests.get(MAPPER_URL, headers=auth_headers, timeout=3.0)
        if reply.status_code != 200:
            log_error(f"[Mapper] Remote mapper failed with status {reply.status_code}.", flush=True)
        else:
            payload = reply.json()
            if isinstance(payload, dict):
                log_info(f"[Mapper] Loaded inference mapper from {MAPPER_URL}", flush=True)
                return payload
            log_error("[Mapper] Remote mapper is not a JSON object. Using environment defaults.", flush=True)
    except Exception as exc:
        log_error(f"[Mapper] Failed fetching inference mapper, using defaults: {exc}", flush=True)
    return {}


def get_inference_mapper() -> dict[str, Any]:
    """Return the mapper JSON, refetching it once the cache TTL has elapsed.

    Returns ``{}`` immediately when ``DOD_USE_LOCAL_API`` is enabled. The
    refresh runs while holding ``_mapper_lock``, so concurrent callers wait for
    one fetch instead of fetching in parallel. An empty (failed) result is
    cached like any other for ``MAPPER_CACHE_TTL_SECONDS``.
    """
    global _cached_mapper, _last_mapper_update

    if _env_enabled("DOD_USE_LOCAL_API"):
        return {}

    requested_at = time.time()
    with _mapper_lock:
        cache_is_fresh = (
            _cached_mapper is not None
            and requested_at - _last_mapper_update < MAPPER_CACHE_TTL_SECONDS
        )
        if not cache_is_fresh:
            _cached_mapper = _fetch_mapper()
            _last_mapper_update = requested_at
        return _cached_mapper


def mark_endpoint_failed(service: str, endpoint: EndpointConfig, reason: str) -> None:
    """Put an endpoint on cooldown after a runtime failure.

    The endpoint is keyed by ``(service, url)`` and is skipped by
    ``get_endpoint_chain`` until its ``cooldown_seconds`` (or the module
    default) have passed. Endpoints without a URL are ignored.

    Args:
        service: Service name, such as llm or tts.
        endpoint: Endpoint configuration that failed.
        reason: Short failure reason for logs.
    """
    target_url = endpoint.get("url", "")
    if not target_url:
        return

    cooldown = float(endpoint.get("cooldown_seconds", ENDPOINT_FAILURE_COOLDOWN_SECONDS))
    deadline = time.time() + cooldown
    cooldown_key = (service, target_url)
    with _mapper_lock:
        _endpoint_cooldowns[cooldown_key] = deadline
    log_info(f"[Mapper] Disabled {service} endpoint for {cooldown:.0f}s after failure: {target_url} ({reason})", flush=True)


def mark_endpoint_success(service: str, endpoint: EndpointConfig) -> None:
    """Lift any cooldown recorded for an endpoint after a successful call.

    Endpoints without a URL are ignored; clearing an absent entry is a no-op.
    """
    target_url = endpoint.get("url", "")
    if target_url:
        with _mapper_lock:
            _endpoint_cooldowns.pop((service, target_url), None)


def get_endpoint_chain(service: str) -> list[EndpointConfig]:
    """Return the ordered list of endpoints to try for a service.

    With ``DOD_USE_LOCAL_API`` enabled, this is just the environment endpoint
    from ``_default_endpoint`` (or ``[]`` when it has no URL).

    Otherwise the endpoints come from the inference mapper. When the service
    priority is ``"fallback"`` and there are several endpoints, the first
    (primary) endpoint is moved to the end. Endpoints still cooling down after
    ``mark_endpoint_failed`` are dropped; if every endpoint is cooling down,
    the full chain is returned so callers still have something to try.

    Args:
        service: Service name, such as llm or tts.

    Returns:
        Endpoint dicts in the order they should be attempted (possibly empty).
    """
    if _env_enabled("DOD_USE_LOCAL_API"):
        endpoint = _default_endpoint(service)
        if endpoint.get("url"):
            log_info(f"[Mapper] DOD_USE_LOCAL_API=True. Using local {service} endpoint: {endpoint['url']}", flush=True)
            return [endpoint]
        return []

    mapper = get_inference_mapper()
    endpoints = _extract_service_endpoints(mapper, service) if mapper else []

    if _service_priority(service) == "fallback" and len(endpoints) > 1:
        priority_env = "LLM_URL_PRIORITY" if service == "llm" else "TTS_URL_PRIORITY"
        log_info(f"[Mapper] {priority_env}=fallback. Trying mapped fallback before primary for {service}.", flush=True)
        endpoints = endpoints[1:] + endpoints[:1]

    if not endpoints:
        log_error(f"[Mapper] No mapped {service} endpoints found. Set DOD_USE_LOCAL_API=True to use local environment URLs.", flush=True)
        return []

    now = time.time()
    # Read the cooldown table under the same lock that mark_endpoint_failed /
    # mark_endpoint_success hold while writing it.
    with _mapper_lock:
        available = [
            endpoint
            for endpoint in endpoints
            if now >= _endpoint_cooldowns.get((service, endpoint["url"]), 0.0)
        ]

    if available:
        skipped_count = len(endpoints) - len(available)
        if skipped_count:
            log_info(f"[Mapper] Skipping {skipped_count} cooling-down {service} endpoint(s).", flush=True)
        selected = available
    else:
        # Every endpoint is cooling down: retry them all rather than return
        # nothing, and say so instead of claiming they were skipped.
        log_info(
            f"[Mapper] All {len(endpoints)} {service} endpoint(s) are cooling down. Retrying the full chain.",
            flush=True,
        )
        selected = endpoints

    names = ", ".join(f"{endpoint.get('name', 'endpoint')}={endpoint['url']}" for endpoint in selected)
    log_info(f"[Mapper] Active {service} endpoint chain: {names}", flush=True)
    return selected