# dod-uno / inference_mapper.py
# Last commit by elismasilva: "update flag disable logs" (4be4d85)
from __future__ import annotations
import json
import os
import threading
import time
from pathlib import Path
from typing import Any
import requests
from dotenv import load_dotenv
from dod_logging import log_error, log_info
# Load the local .env file at import time; override=True lets .env values
# replace variables that are already present in the process environment.
load_dotenv(override=True)
# An endpoint description as a plain dict. Keys produced in this module:
# "name", "url", "mode", "api_name", "timeout", "warmup_timeout" and
# "cooldown_seconds".
EndpointConfig = dict[str, Any]
# Template values that _optional_env_secret() treats as "no secret set".
# The comparison there is made against the lowercased value.
PLACEHOLDER_SECRET_VALUES = {
"your_token",
"your_huggingface_token",
"your_hf_token",
"hf_token",
"token",
}
# Location of the remote mapper JSON. Each part can be overridden through the
# environment; DOD_INFERENCE_MAPPER_URL replaces the assembled URL entirely.
MAPPER_DATASET_REPO_ID = os.getenv("DOD_INFERENCE_MAPPER_DATASET_REPO_ID", "elismasilva/dod-inference-mapper")
MAPPER_DATASET_REVISION = os.getenv("DOD_INFERENCE_MAPPER_DATASET_REVISION", "main")
MAPPER_DATASET_PATH = "inference_map.json"
MAPPER_URL = os.getenv(
"DOD_INFERENCE_MAPPER_URL",
f"https://huggingface.co/datasets/{MAPPER_DATASET_REPO_ID}/raw/{MAPPER_DATASET_REVISION}/{MAPPER_DATASET_PATH}",
)
# NOTE: the three float() conversions below run at import time, so a
# non-numeric environment value raises ValueError while the module loads.
# Seconds a fetched mapper is reused by get_inference_mapper() before a refresh.
MAPPER_CACHE_TTL_SECONDS = float(os.getenv("DOD_INFERENCE_MAPPER_TTL_SECONDS", "60"))
# Default number of seconds a failed endpoint is skipped (see mark_endpoint_failed()).
ENDPOINT_FAILURE_COOLDOWN_SECONDS = float(os.getenv("DOD_ENDPOINT_FAILURE_COOLDOWN_SECONDS", "180"))
# Lower bound for the default "warmup_timeout" computed in _normalize_endpoint().
ENDPOINT_WARMUP_TIMEOUT_SECONDS = float(os.getenv("DOD_ENDPOINT_WARMUP_TIMEOUT_SECONDS", "75"))
# Guards the cached mapper below and every write to _endpoint_cooldowns.
_mapper_lock = threading.Lock()
# Last mapper returned by _fetch_mapper(); None until the first fetch.
_cached_mapper: dict[str, Any] | None = None
# time.time() value recorded when the cache was last refreshed.
_last_mapper_update = 0.0
# Maps (service, url) to the time.time() value after which the endpoint may be tried again.
_endpoint_cooldowns: dict[tuple[str, str], float] = {}
def _refresh_env() -> None:
    """Re-read the local .env file so its values replace stale ones in the process environment."""
    # override=True makes the .env file win over variables that are already set.
    load_dotenv(override=True)
def _env_enabled(name: str, fallback_name: str | None = None) -> bool:
    """Return whether an environment flag is truthy.

    Args:
        name: Environment variable to read first.
        fallback_name: Optional variable consulted only when ``name`` is not
            set at all (a variable set to an empty string does not fall back).

    Returns:
        True when the value is one of "1", "true", "yes" or "on", ignoring
        case, surrounding whitespace and surrounding quote characters.
    """
    _refresh_env()
    value = os.getenv(name)
    if value is None and fallback_name:
        value = os.getenv(fallback_name, "")
    # Trim the same way _optional_env_secret() does, so values such as
    # "true " or '"on"' (e.g. from an env file that keeps quotes) still count.
    normalized = (value or "").strip().strip("\"'").strip().lower()
    return normalized in {"1", "true", "yes", "on"}
def _optional_env_secret(name: str) -> str:
    """Read a secret from the environment, treating blank or placeholder values as unset.

    Surrounding whitespace and then surrounding quote characters are removed.
    The result is compared, lowercased, against PLACEHOLDER_SECRET_VALUES.

    Args:
        name: Environment variable holding the secret.

    Returns:
        The cleaned secret, or an empty string when nothing usable is set.
    """
    _refresh_env()
    raw_value = os.getenv(name, "")
    secret = raw_value.strip().strip("\"'")
    looks_like_placeholder = secret.lower() in PLACEHOLDER_SECRET_VALUES
    if secret and not looks_like_placeholder:
        return secret
    return ""
def _local_data_dir() -> Path:
    """Return the directory that holds local data files.

    DOD_LOCAL_DATA_DIR wins when set; otherwise ``~/.dod`` is used. A leading
    ``~`` in the configured value is expanded.
    """
    _refresh_env()
    home_default = Path.home() / ".dod"
    configured_dir = os.getenv("DOD_LOCAL_DATA_DIR", home_default)
    return Path(configured_dir).expanduser()
def _local_mapper_path() -> Path:
    """Return the path of the local inference mapper JSON file.

    DOD_LOCAL_INFERENCE_MAPPER_PATH wins when set; otherwise the file name of
    MAPPER_DATASET_PATH inside the local data directory is used.
    """
    _refresh_env()
    data_dir = _local_data_dir()
    mapper_filename = Path(MAPPER_DATASET_PATH).name
    fallback_path = data_dir / mapper_filename
    configured_path = os.getenv("DOD_LOCAL_INFERENCE_MAPPER_PATH", fallback_path)
    return Path(configured_path).expanduser()
def _service_priority(service: str) -> str:
    """Return which mapped endpoint a service should try first.

    Args:
        service: Service name; "llm" reads LLM_URL_PRIORITY, anything else
            reads TTS_URL_PRIORITY.

    Returns:
        "primary" or "fallback". Any other configured value is logged and
        replaced by "primary".
    """
    _refresh_env()
    if service == "llm":
        env_name = "LLM_URL_PRIORITY"
    else:
        env_name = "TTS_URL_PRIORITY"
    priority = os.getenv(env_name, "primary").strip().lower()
    if priority in {"primary", "fallback"}:
        return priority
    log_error(f"[Mapper] Ignored invalid {env_name}={priority!r}. Using primary.", flush=True)
    return "primary"
def _default_endpoint(service: str) -> EndpointConfig:
    """Return the endpoint described by local environment variables.

    Args:
        service: "llm" or "tts". Any other name yields a placeholder entry
            with an empty URL.

    Returns:
        A dict with "name", "url" and "mode", plus "api_name" for the llm and
        tts services.
    """
    _refresh_env()
    if service == "llm":
        return {
            "name": "env-llm",
            "url": os.getenv("LLM_URL", "https://elismasilva-voxcpm2-nanovllm-service.hf.space"),
            "mode": "gradio",
            "api_name": "/generate_inference",
        }
    if service == "tts":
        tts_url = os.getenv("TTS_API_URL", "http://127.0.0.1:8000/generate_api")
        # Normalize the mode the same way _normalize_endpoint() does, so
        # "Gradio" or " gradio " is recognised; an empty value means "rest".
        tts_mode = os.getenv("TTS_API_MODE", "rest").strip().lower() or "rest"
        trimmed_url = tts_url.rstrip("/")
        # In gradio mode the URL is the app root, so drop a trailing API route.
        if tts_mode == "gradio" and trimmed_url.endswith("/generate_api"):
            tts_url = trimmed_url[: -len("/generate_api")]
        return {
            "name": "env-tts",
            "url": tts_url,
            "mode": tts_mode,
            "api_name": "/generate_api",
        }
    return {"name": f"env-{service}", "url": "", "mode": "rest"}
def _normalize_endpoint(raw_endpoint: Any, service: str, role: str) -> EndpointConfig | None:
    """Normalize one mapper entry into a consistent endpoint dictionary.

    The mapper JSON is external input, so malformed fields fall back to
    defaults instead of raising.

    Args:
        raw_endpoint: A URL string or a dict read from the mapper JSON.
        service: Service name used to look up environment defaults.
        role: Name given to the endpoint when the entry has no "name".

    Returns:
        The normalized endpoint, or None when the entry is neither a string
        nor a dict, or carries no usable URL / Gradio space id.
    """
    if isinstance(raw_endpoint, str):
        raw_endpoint = {"url": raw_endpoint}
    if not isinstance(raw_endpoint, dict):
        return None

    def _text(key: str, fallback: Any) -> str:
        """Return a field as stripped text; a missing or JSON null value uses the fallback."""
        value = raw_endpoint.get(key)
        return str(fallback if value is None else value).strip()

    def _seconds(key: str, fallback: float) -> float:
        """Return a numeric field as float; a missing, null or non-numeric value uses the fallback."""
        value = raw_endpoint.get(key)
        if value is None:
            return fallback
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            log_error(
                f"[Mapper] Ignored invalid {key}={value!r} for {service} endpoint. Using {fallback}.",
                flush=True,
            )
            return fallback

    default = _default_endpoint(service)
    mode = _text("mode", default.get("mode", "rest")).lower()
    url = str(raw_endpoint.get("url") or raw_endpoint.get("space") or raw_endpoint.get("src") or "").strip()
    # Require a real scheme so strings such as "httpfoo/bar" are not taken for URLs.
    is_http_url = url.startswith(("http://", "https://"))
    # A Gradio endpoint may also be given as an "owner/space" id.
    is_gradio_space_id = mode == "gradio" and "/" in url and " " not in url
    if not is_http_url and not is_gradio_space_id:
        return None

    api_name = _text("api_name", default.get("api_name", ""))
    if api_name and not api_name.startswith("/"):
        api_name = f"/{api_name}"
    # In gradio mode the URL is the app root, so drop a trailing API route.
    if mode == "gradio" and is_http_url and api_name and url.rstrip("/").endswith(api_name):
        url = url.rstrip("/")[: -len(api_name)]

    timeout = _seconds("timeout", 120.0)
    warmup_timeout = _seconds("warmup_timeout", max(timeout, ENDPOINT_WARMUP_TIMEOUT_SECONDS))
    name = raw_endpoint.get("name")
    return {
        "name": str(role if name is None else name),
        "url": url,
        "mode": mode,
        "api_name": api_name,
        "timeout": timeout,
        "warmup_timeout": warmup_timeout,
        "cooldown_seconds": _seconds("cooldown_seconds", ENDPOINT_FAILURE_COOLDOWN_SECONDS),
    }
def _extract_service_endpoints(mapper: dict[str, Any], service: str) -> list[EndpointConfig]:
    """Extract primary and fallback endpoints for a service from mapper JSON.

    Accepted shapes for ``mapper[service]``:
      * a list of entries, tried in order;
      * a dict with "primary", "fallback" and/or "fallbacks" keys, or a dict
        that is itself an endpoint (has "url");
      * any other value, treated as a single entry.

    Args:
        mapper: Parsed mapper JSON.
        service: Service name to look up.

    Returns:
        Normalized endpoints in priority order, without duplicate URLs.
        Invalid entries are logged and skipped.
    """
    service_config = mapper.get(service, {})
    if isinstance(service_config, list):
        raw_entries = service_config
    elif isinstance(service_config, dict):
        raw_entries = [service_config[key] for key in ("primary", "fallback") if key in service_config]
        extra_fallbacks = service_config.get("fallbacks")
        if isinstance(extra_fallbacks, (list, tuple)):
            raw_entries.extend(extra_fallbacks)
        elif extra_fallbacks is not None:
            # A single string/dict is one entry; extending with it would
            # split a URL into characters or a dict into its keys.
            raw_entries.append(extra_fallbacks)
        if "url" in service_config:
            raw_entries.insert(0, service_config)
    else:
        raw_entries = [service_config]

    endpoints: list[EndpointConfig] = []
    seen_urls: set[str] = set()
    for idx, raw_entry in enumerate(raw_entries):
        endpoint = _normalize_endpoint(raw_entry, service, "primary" if idx == 0 else f"fallback-{idx}")
        if not endpoint:
            log_error(f"[Mapper] Ignored invalid {service} endpoint entry: {raw_entry}", flush=True)
            continue
        if endpoint["url"] in seen_urls:
            continue
        seen_urls.add(endpoint["url"])
        endpoints.append(endpoint)
    return endpoints
def _fetch_mapper() -> dict[str, Any]:
    """Load the mapper JSON from local disk or from the remote dataset.

    With DOD_USE_LOCAL_DATA enabled only the local file is consulted.
    Otherwise MAPPER_URL is requested, sending HF_TOKEN_DATASET as a bearer
    token when one is configured.

    Returns:
        The mapper as a dict, or an empty dict when it cannot be loaded or is
        not a JSON object. Failures are logged, never raised.
    """
    if _env_enabled("DOD_USE_LOCAL_DATA"):
        local_path = _local_mapper_path()
        try:
            with local_path.open(mode="r", encoding="utf-8") as mapper_file:
                local_payload = json.load(mapper_file)
            if not isinstance(local_payload, dict):
                log_error(f"[Mapper] Local mapper at {local_path} is not a JSON object. Using environment defaults.", flush=True)
            else:
                log_info(f"[Mapper] Loaded local inference mapper from {local_path}", flush=True)
                return local_payload
        except FileNotFoundError:
            log_error(f"[Mapper] Local mapper not found at {local_path}. Using environment defaults.", flush=True)
        except Exception as exc:
            log_error(f"[Mapper] Failed loading local inference mapper at {local_path}: {exc}", flush=True)
        # Local mode never falls through to the remote dataset.
        return {}

    try:
        _refresh_env()
        hf_token = _optional_env_secret("HF_TOKEN_DATASET")
        headers = {}
        if hf_token:
            headers["Authorization"] = f"Bearer {hf_token}"
        response = requests.get(MAPPER_URL, headers=headers, timeout=3.0)
        if response.status_code != 200:
            log_error(f"[Mapper] Remote mapper failed with status {response.status_code}.", flush=True)
        else:
            remote_payload = response.json()
            if isinstance(remote_payload, dict):
                log_info(f"[Mapper] Loaded inference mapper from {MAPPER_URL}", flush=True)
                return remote_payload
            log_error("[Mapper] Remote mapper is not a JSON object. Using environment defaults.", flush=True)
    except Exception as exc:
        log_error(f"[Mapper] Failed fetching inference mapper, using defaults: {exc}", flush=True)
    return {}
def get_inference_mapper() -> dict[str, Any]:
    """Return the cached mapper JSON, refreshing it after the configured TTL.

    When DOD_USE_LOCAL_API is enabled the mapper is bypassed and an empty
    dict is returned. A refresh that yields no data keeps the previously
    cached mapper, so one failed fetch does not drop every mapped endpoint
    for a whole TTL.

    Returns:
        The mapper dict; empty when nothing has been loaded successfully.
    """
    global _cached_mapper, _last_mapper_update
    if _env_enabled("DOD_USE_LOCAL_API"):
        return {}
    with _mapper_lock:
        # Read the clock under the lock so a caller that waited for another
        # thread's refresh compares against an up-to-date timestamp.
        now = time.time()
        if _cached_mapper is not None and now - _last_mapper_update < MAPPER_CACHE_TTL_SECONDS:
            return _cached_mapper
        refreshed = _fetch_mapper()
        if refreshed or not _cached_mapper:
            _cached_mapper = refreshed
        else:
            # _fetch_mapper() reports every failure as {}; keep the last good data.
            log_error("[Mapper] Mapper refresh returned no data. Keeping the previously cached mapper.", flush=True)
        # Stamp failed refreshes too, so the next retry waits one TTL.
        _last_mapper_update = now
        return _cached_mapper
def mark_endpoint_failed(service: str, endpoint: EndpointConfig, reason: str) -> None:
    """Put an endpoint on cooldown after a runtime failure.

    Endpoints without a URL are ignored.

    Args:
        service: Service name, such as llm or tts.
        endpoint: Endpoint configuration that failed; its "cooldown_seconds"
            is used when present, otherwise ENDPOINT_FAILURE_COOLDOWN_SECONDS.
        reason: Short failure reason for logs.
    """
    url = endpoint.get("url", "")
    if url:
        cooldown = float(endpoint.get("cooldown_seconds", ENDPOINT_FAILURE_COOLDOWN_SECONDS))
        retry_at = time.time() + cooldown
        cooldown_key = (service, url)
        with _mapper_lock:
            _endpoint_cooldowns[cooldown_key] = retry_at
        log_info(f"[Mapper] Disabled {service} endpoint for {cooldown:.0f}s after failure: {url} ({reason})", flush=True)
def mark_endpoint_success(service: str, endpoint: EndpointConfig) -> None:
    """Remove any cooldown recorded for an endpoint after it served a call.

    Args:
        service: Service name, such as llm or tts.
        endpoint: Endpoint configuration that succeeded; ignored when it has
            no URL.
    """
    url = endpoint.get("url", "")
    if url:
        cooldown_key = (service, url)
        with _mapper_lock:
            _endpoint_cooldowns.pop(cooldown_key, None)
def get_endpoint_chain(service: str) -> list[EndpointConfig]:
    """Return the endpoints to try for a service, in order.

    With DOD_USE_LOCAL_API enabled only the environment-defined endpoint is
    returned. Otherwise the mapped endpoints are used, rotated when the
    service priority is "fallback", and endpoints on cooldown are left out.
    If every endpoint is cooling down, the full chain is returned so the
    caller still has something to try.

    Args:
        service: Service name, such as llm or tts.

    Returns:
        Endpoint configurations in the order they should be tried; empty when
        nothing is configured.
    """
    if _env_enabled("DOD_USE_LOCAL_API"):
        endpoint = _default_endpoint(service)
        if not endpoint.get("url"):
            return []
        log_info(f"[Mapper] DOD_USE_LOCAL_API=True. Using local {service} endpoint: {endpoint['url']}", flush=True)
        return [endpoint]

    mapper = get_inference_mapper()
    endpoints = _extract_service_endpoints(mapper, service) if mapper else []
    if not endpoints:
        log_error(f"[Mapper] No mapped {service} endpoints found. Set DOD_USE_LOCAL_API=True to use local environment URLs.", flush=True)
        return []
    if _service_priority(service) == "fallback" and len(endpoints) > 1:
        priority_env = "LLM_URL_PRIORITY" if service == "llm" else "TTS_URL_PRIORITY"
        log_info(f"[Mapper] {priority_env}=fallback. Trying mapped fallback before primary for {service}.", flush=True)
        # Rotate so the first fallback leads and the primary goes last.
        endpoints = endpoints[1:] + endpoints[:1]

    # Writers update the cooldown table under _mapper_lock, so copy it under
    # the same lock. get_inference_mapper() has already released the lock.
    with _mapper_lock:
        cooldowns = dict(_endpoint_cooldowns)
    now = time.time()
    available = [
        endpoint
        for endpoint in endpoints
        if now >= cooldowns.get((service, endpoint["url"]), 0.0)
    ]
    skipped_count = len(endpoints) - len(available)
    if not available:
        # Nothing is usable, so ignore the cooldowns rather than return no endpoint.
        log_info(f"[Mapper] All {skipped_count} {service} endpoint(s) are cooling down. Retrying the full chain.", flush=True)
        selected = endpoints
    else:
        if skipped_count:
            log_info(f"[Mapper] Skipping {skipped_count} cooling-down {service} endpoint(s).", flush=True)
        selected = available
    names = ", ".join(f"{endpoint.get('name', 'endpoint')}={endpoint['url']}" for endpoint in selected)
    log_info(f"[Mapper] Active {service} endpoint chain: {names}", flush=True)
    return selected