Spaces:
Running
Running
| from __future__ import annotations | |
| import json | |
| import os | |
| import threading | |
| import time | |
| from pathlib import Path | |
| from typing import Any | |
| import requests | |
| from dotenv import load_dotenv | |
| from dod_logging import log_error, log_info | |
# Load .env once at import time so the os.getenv() lookups below can see its
# values; override=True lets the file win over variables already exported.
load_dotenv(override=True)

# Alias for the plain dicts this module uses to describe one endpoint.
EndpointConfig = dict[str, Any]

# Lower-case values that _optional_env_secret treats as "no secret configured".
PLACEHOLDER_SECRET_VALUES = {
    "your_token",
    "your_huggingface_token",
    "your_hf_token",
    "hf_token",
    "token",
}

# Location of the mapper JSON. The repo id and revision only feed the default
# MAPPER_URL; setting DOD_INFERENCE_MAPPER_URL replaces the whole URL.
MAPPER_DATASET_REPO_ID = os.getenv("DOD_INFERENCE_MAPPER_DATASET_REPO_ID", "elismasilva/dod-inference-mapper")
MAPPER_DATASET_REVISION = os.getenv("DOD_INFERENCE_MAPPER_DATASET_REVISION", "main")
MAPPER_DATASET_PATH = "inference_map.json"
MAPPER_URL = os.getenv(
    "DOD_INFERENCE_MAPPER_URL",
    f"https://huggingface.co/datasets/{MAPPER_DATASET_REPO_ID}/raw/{MAPPER_DATASET_REVISION}/{MAPPER_DATASET_PATH}",
)

# Tunables, all in seconds. They are parsed with float() at import time, so a
# non-numeric environment value raises ValueError while the module is loading.
MAPPER_CACHE_TTL_SECONDS = float(os.getenv("DOD_INFERENCE_MAPPER_TTL_SECONDS", "60"))
ENDPOINT_FAILURE_COOLDOWN_SECONDS = float(os.getenv("DOD_ENDPOINT_FAILURE_COOLDOWN_SECONDS", "180"))
ENDPOINT_WARMUP_TIMEOUT_SECONDS = float(os.getenv("DOD_ENDPOINT_WARMUP_TIMEOUT_SECONDS", "75"))

# Held while the mapper cache is refreshed and while cooldown entries are
# written or removed.
_mapper_lock = threading.Lock()
# Last fetched mapper JSON; None until the first fetch has happened.
_cached_mapper: dict[str, Any] | None = None
# time.time() value recorded for the most recent fetch.
_last_mapper_update = 0.0
# (service, url) -> time.time() value after which the endpoint may be used again.
_endpoint_cooldowns: dict[tuple[str, str], float] = {}
def _refresh_env() -> None:
    """Re-read the local .env file, letting its values replace already-set variables."""
    # override=True is what allows an edited .env to win over stale shell exports.
    load_dotenv(override=True)
def _env_enabled(name: str, fallback_name: str | None = None) -> bool:
    """Return whether an environment flag is truthy.

    The .env file is re-read first. ``fallback_name`` is consulted only when
    ``name`` is not set at all; a variable set to an empty string counts as
    set, and therefore as disabled.

    Args:
        name: Primary environment variable to read.
        fallback_name: Optional variable to read when ``name`` is unset.

    Returns:
        True when the value, ignoring case and surrounding whitespace, is one
        of "1", "true", "yes" or "on"; False otherwise.
    """
    _refresh_env()
    value = os.getenv(name)
    if value is None and fallback_name:
        value = os.getenv(fallback_name, "")
    # Strip before comparing: a padded value such as "true " (trailing space or
    # stray carriage return) would otherwise be read as disabled, unlike the
    # other helpers in this module, which all strip their values.
    return (value or "").strip().lower() in {"1", "true", "yes", "on"}
def _optional_env_secret(name: str) -> str:
    """Read a secret from the environment, mapping blank or placeholder values to ""."""
    _refresh_env()
    raw = os.getenv(name, "")
    # Drop surrounding whitespace first, then any quotes wrapped around the value.
    cleaned = raw.strip().strip("\"'")
    is_placeholder = cleaned.lower() in PLACEHOLDER_SECRET_VALUES
    return "" if (not cleaned or is_placeholder) else cleaned
def _local_data_dir() -> Path:
    """Resolve the local data directory: DOD_LOCAL_DATA_DIR, or ~/.dod when unset."""
    _refresh_env()
    home_default = Path.home() / ".dod"
    configured = os.getenv("DOD_LOCAL_DATA_DIR", home_default)
    return Path(configured).expanduser()
def _local_mapper_path() -> Path:
    """Resolve the path of the local inference mapper JSON file.

    DOD_LOCAL_INFERENCE_MAPPER_PATH wins when set; otherwise the mapper's file
    name is looked up inside the local data directory.
    """
    _refresh_env()
    mapper_file_name = Path(MAPPER_DATASET_PATH).name
    fallback_location = _local_data_dir() / mapper_file_name
    configured = os.getenv("DOD_LOCAL_INFERENCE_MAPPER_PATH", fallback_location)
    return Path(configured).expanduser()
def _service_priority(service: str) -> str:
    """Read the endpoint priority ("primary" or "fallback") configured for a service.

    "llm" reads LLM_URL_PRIORITY; every other service reads TTS_URL_PRIORITY.
    Unrecognised values are logged and treated as "primary".
    """
    _refresh_env()
    if service == "llm":
        env_name = "LLM_URL_PRIORITY"
    else:
        env_name = "TTS_URL_PRIORITY"
    priority = os.getenv(env_name, "primary").strip().lower()
    if priority in ("primary", "fallback"):
        return priority
    log_error(f"[Mapper] Ignored invalid {env_name}={priority!r}. Using primary.", flush=True)
    return "primary"
def _default_endpoint(service: str) -> EndpointConfig:
    """Return the local environment fallback endpoint for a service.

    Args:
        service: Service name; "llm" and "tts" have dedicated settings.

    Returns:
        For "llm": a gradio endpoint built from LLM_URL.
        For "tts": an endpoint built from TTS_API_URL and TTS_API_MODE; in
        gradio mode a trailing "/generate_api" is removed from the URL because
        the same path is already carried in "api_name".
        For any other service: a placeholder with an empty URL.
    """
    _refresh_env()
    if service == "llm":
        return {
            "name": "env-llm",
            "url": os.getenv("LLM_URL", "https://elismasilva-voxcpm2-nanovllm-service.hf.space"),
            "mode": "gradio",
            "api_name": "/generate_inference",
        }
    if service == "tts":
        tts_url = os.getenv("TTS_API_URL", "http://127.0.0.1:8000/generate_api")
        # Normalise like _normalize_endpoint does, so "Gradio" or " gradio "
        # still triggers the URL trimming below and callers see one spelling.
        tts_mode = os.getenv("TTS_API_MODE", "rest").strip().lower()
        if tts_mode == "gradio" and tts_url.rstrip("/").endswith("/generate_api"):
            tts_url = tts_url.rstrip("/")[: -len("/generate_api")]
        return {
            "name": "env-tts",
            "url": tts_url,
            "mode": tts_mode,
            "api_name": "/generate_api",
        }
    return {"name": f"env-{service}", "url": "", "mode": "rest"}
def _normalize_endpoint(raw_endpoint: Any, service: str, role: str) -> EndpointConfig | None:
    """Normalize one mapper entry into a consistent endpoint dictionary.

    Args:
        raw_endpoint: Mapper entry. A bare string is treated as the URL; a dict
            may carry "url" (or "space" / "src"), "mode", "api_name", "name",
            "timeout", "warmup_timeout" and "cooldown_seconds".
        service: Service name, used to look up default mode and api_name.
        role: Name given to the endpoint when the entry has no "name".

    Returns:
        A dict with the keys name, url, mode, api_name, timeout,
        warmup_timeout and cooldown_seconds, or None when the entry is neither
        a string nor a dict, or has no usable URL / Space id.
    """
    if isinstance(raw_endpoint, str):
        raw_endpoint = {"url": raw_endpoint}
    if not isinstance(raw_endpoint, dict):
        return None

    def _number(key: str, fallback: float) -> float:
        """Read a numeric field, using ``fallback`` when it is missing or malformed."""
        try:
            return float(raw_endpoint.get(key, fallback))
        except (TypeError, ValueError):
            # The mapper is external JSON: a null or non-numeric value must not
            # abort the whole endpoint lookup, so log it and use the default.
            log_error(
                f"[Mapper] Ignored invalid {key}={raw_endpoint.get(key)!r} for {service} endpoint. Using {fallback}.",
                flush=True,
            )
            return fallback

    defaults = _default_endpoint(service)
    mode = str(raw_endpoint.get("mode", defaults.get("mode", "rest"))).strip().lower()
    url = str(raw_endpoint.get("url") or raw_endpoint.get("space") or raw_endpoint.get("src") or "").strip()

    # Require a real scheme; a bare "http" prefix would also accept "httpfoo".
    is_http_url = url.startswith(("http://", "https://"))
    # In gradio mode an "owner/space" id is accepted in place of a URL.
    is_gradio_space_id = mode == "gradio" and "/" in url and " " not in url
    if not is_http_url and not is_gradio_space_id:
        return None

    api_name = str(raw_endpoint.get("api_name", defaults.get("api_name", ""))).strip()
    if api_name and not api_name.startswith("/"):
        api_name = f"/{api_name}"
    # A gradio URL that already ends in the api path is cut back to its base,
    # since the path is carried separately in "api_name".
    if mode == "gradio" and is_http_url and api_name and url.rstrip("/").endswith(api_name):
        url = url.rstrip("/")[: -len(api_name)]

    timeout = _number("timeout", 120.0)
    # Unless set explicitly, warm-up gets at least the regular timeout.
    warmup_timeout = _number("warmup_timeout", max(timeout, ENDPOINT_WARMUP_TIMEOUT_SECONDS))
    return {
        "name": str(raw_endpoint.get("name", role)),
        "url": url,
        "mode": mode,
        "api_name": api_name,
        "timeout": timeout,
        "warmup_timeout": warmup_timeout,
        "cooldown_seconds": _number("cooldown_seconds", ENDPOINT_FAILURE_COOLDOWN_SECONDS),
    }
def _extract_service_endpoints(mapper: dict[str, Any], service: str) -> list[EndpointConfig]:
    """Extract primary and fallback endpoints from mapper JSON.

    The value stored under ``service`` may be a list of entries, a dict with
    "primary" / "fallback" / "fallbacks" keys, a dict that is itself an
    endpoint (has "url"), or a single bare entry such as a URL string.

    Args:
        mapper: Parsed mapper JSON.
        service: Service name to look up in ``mapper``.

    Returns:
        Normalized endpoints in priority order. Invalid entries are logged and
        skipped; entries repeating an already seen URL are dropped.
    """
    service_config = mapper.get(service, {})
    if isinstance(service_config, list):
        raw_entries = service_config
    elif isinstance(service_config, dict):
        raw_entries = []
        if "primary" in service_config:
            raw_entries.append(service_config["primary"])
        if "fallback" in service_config:
            raw_entries.append(service_config["fallback"])
        # "fallbacks" comes from external JSON: tolerate null, and treat a
        # single string or dict as one entry rather than iterating over its
        # characters or keys.
        extra_fallbacks = service_config.get("fallbacks") or []
        if isinstance(extra_fallbacks, (str, dict)):
            extra_fallbacks = [extra_fallbacks]
        if isinstance(extra_fallbacks, list):
            raw_entries.extend(extra_fallbacks)
        else:
            log_error(f"[Mapper] Ignored invalid {service} fallbacks value: {extra_fallbacks!r}", flush=True)
        # A dict that carries its own "url" is an endpoint too and goes first.
        if "url" in service_config:
            raw_entries.insert(0, service_config)
    else:
        raw_entries = [service_config]

    endpoints: list[EndpointConfig] = []
    seen_urls = set()
    for idx, raw_entry in enumerate(raw_entries):
        endpoint = _normalize_endpoint(raw_entry, service, "primary" if idx == 0 else f"fallback-{idx}")
        if not endpoint:
            log_error(f"[Mapper] Ignored invalid {service} endpoint entry: {raw_entry}", flush=True)
            continue
        if endpoint["url"] in seen_urls:
            continue
        seen_urls.add(endpoint["url"])
        endpoints.append(endpoint)
    return endpoints
def _fetch_mapper() -> dict[str, Any]:
    """Load the mapper JSON, from local disk when DOD_USE_LOCAL_DATA is on, else over HTTP.

    Every failure is logged and results in an empty dict; this function does
    not let read, parse or network errors propagate.
    """
    if _env_enabled("DOD_USE_LOCAL_DATA"):
        source_path = _local_mapper_path()
        try:
            with source_path.open(mode="r", encoding="utf-8") as handle:
                payload = json.load(handle)
            if not isinstance(payload, dict):
                log_error(f"[Mapper] Local mapper at {source_path} is not a JSON object. Using environment defaults.", flush=True)
            else:
                log_info(f"[Mapper] Loaded local inference mapper from {source_path}", flush=True)
                return payload
        except FileNotFoundError:
            log_error(f"[Mapper] Local mapper not found at {source_path}. Using environment defaults.", flush=True)
        except Exception as error:
            log_error(f"[Mapper] Failed loading local inference mapper at {source_path}: {error}", flush=True)
        return {}

    try:
        _refresh_env()
        token = _optional_env_secret("HF_TOKEN_DATASET")
        # Send credentials only when a real token is configured.
        request_headers = {}
        if token:
            request_headers = {"Authorization": f"Bearer {token}"}
        reply = requests.get(MAPPER_URL, headers=request_headers, timeout=3.0)
        if reply.status_code != 200:
            log_error(f"[Mapper] Remote mapper failed with status {reply.status_code}.", flush=True)
        else:
            payload = reply.json()
            if isinstance(payload, dict):
                log_info(f"[Mapper] Loaded inference mapper from {MAPPER_URL}", flush=True)
                return payload
            log_error("[Mapper] Remote mapper is not a JSON object. Using environment defaults.", flush=True)
    except Exception as error:
        log_error(f"[Mapper] Failed fetching inference mapper, using defaults: {error}", flush=True)
    return {}
def get_inference_mapper() -> dict[str, Any]:
    """Return the mapper JSON, re-fetching it once the cached copy is older than the TTL.

    When DOD_USE_LOCAL_API is enabled the mapper is bypassed and an empty dict
    is returned without touching the cache.
    """
    global _cached_mapper, _last_mapper_update
    if _env_enabled("DOD_USE_LOCAL_API"):
        return {}

    # The timestamp is taken before waiting for the lock and is also what gets
    # recorded as the fetch time.
    started_at = time.time()
    with _mapper_lock:
        age = started_at - _last_mapper_update
        is_fresh = _cached_mapper is not None and age < MAPPER_CACHE_TTL_SECONDS
        if not is_fresh:
            _cached_mapper = _fetch_mapper()
            _last_mapper_update = started_at
        return _cached_mapper
def mark_endpoint_failed(service: str, endpoint: EndpointConfig, reason: str) -> None:
    """Put an endpoint on cooldown after a runtime failure.

    Endpoints without a URL are ignored.

    Args:
        service: Service name, such as llm or tts.
        endpoint: Endpoint configuration that failed; its "cooldown_seconds"
            is used when present, otherwise the module-wide default applies.
        reason: Short failure reason for logs.
    """
    url = endpoint.get("url", "")
    if url:
        cooldown = float(endpoint.get("cooldown_seconds", ENDPOINT_FAILURE_COOLDOWN_SECONDS))
        # Stored as an absolute time.time() value after which retry is allowed.
        available_again_at = time.time() + cooldown
        with _mapper_lock:
            _endpoint_cooldowns[(service, url)] = available_again_at
        log_info(f"[Mapper] Disabled {service} endpoint for {cooldown:.0f}s after failure: {url} ({reason})", flush=True)
def mark_endpoint_success(service: str, endpoint: EndpointConfig) -> None:
    """Remove any cooldown recorded for an endpoint once a call to it succeeds."""
    url = endpoint.get("url", "")
    if url:
        cooldown_key = (service, url)
        with _mapper_lock:
            # pop with a default: an endpoint that never failed has no entry.
            _endpoint_cooldowns.pop(cooldown_key, None)
def get_endpoint_chain(service: str) -> list[EndpointConfig]:
    """Build the ordered list of endpoints to try for a service.

    With DOD_USE_LOCAL_API enabled only the environment-defined endpoint is
    returned (or nothing when it has no URL). Otherwise the mapped endpoints
    are used, rotated so the first fallback leads when the service priority is
    "fallback", and with endpoints still on cooldown left out. If every
    endpoint is on cooldown, the full list is returned rather than none.
    """
    if _env_enabled("DOD_USE_LOCAL_API"):
        local_endpoint = _default_endpoint(service)
        if not local_endpoint.get("url"):
            return []
        log_info(f"[Mapper] DOD_USE_LOCAL_API=True. Using local {service} endpoint: {local_endpoint['url']}", flush=True)
        return [local_endpoint]

    mapper = get_inference_mapper()
    endpoints = _extract_service_endpoints(mapper, service) if mapper else []

    # The priority is read before the length check, so an invalid setting is
    # reported even when there is nothing to reorder.
    if _service_priority(service) == "fallback" and len(endpoints) > 1:
        priority_env = "LLM_URL_PRIORITY" if service == "llm" else "TTS_URL_PRIORITY"
        log_info(f"[Mapper] {priority_env}=fallback. Trying mapped fallback before primary for {service}.", flush=True)
        # Rotate left by one: the first fallback leads, the primary goes last.
        endpoints = endpoints[1:] + endpoints[:1]

    if not endpoints:
        log_error(f"[Mapper] No mapped {service} endpoints found. Set DOD_USE_LOCAL_API=True to use local environment URLs.", flush=True)
        return []

    now = time.time()
    ready = []
    for candidate in endpoints:
        retry_at = _endpoint_cooldowns.get((service, candidate["url"]), 0.0)
        if retry_at <= now:
            ready.append(candidate)

    skipped_count = len(endpoints) - len(ready)
    if skipped_count:
        log_info(f"[Mapper] Skipping {skipped_count} cooling-down {service} endpoint(s).", flush=True)

    # Never hand back an empty chain just because everything is cooling down.
    selected = ready if ready else endpoints
    if selected:
        labels = [f"{item.get('name', 'endpoint')}={item['url']}" for item in selected]
        names = ", ".join(labels)
        log_info(f"[Mapper] Active {service} endpoint chain: {names}", flush=True)
    return selected