"""Provider descriptors, factory, and runtime registry."""
from __future__ import annotations
import asyncio
from collections import defaultdict
from collections.abc import Callable, Iterable, MutableMapping
from contextlib import suppress
import httpx
from loguru import logger
from config.provider_catalog import (
PROVIDER_CATALOG,
SUPPORTED_PROVIDER_IDS,
ProviderDescriptor,
)
from config.settings import ConfiguredChatModelRef, Settings
from providers.base import BaseProvider, ProviderConfig
from providers.exceptions import (
AuthenticationError,
ModelListResponseError,
ProviderError,
ServiceUnavailableError,
UnknownProviderTypeError,
)
from providers.model_listing import ProviderModelInfo, model_infos_from_ids
# Signature every provider factory follows: build a provider instance from its
# resolved ProviderConfig plus the full application Settings.
ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider]
# Backwards-compatible name for the catalog (single source: ``config.provider_catalog``).
# This binds the very same dict object as PROVIDER_CATALOG, not a copy.
PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = PROVIDER_CATALOG
def _create_nvidia_nim(config: ProviderConfig, settings: Settings) -> BaseProvider:
    """Instantiate the NVIDIA NIM provider, handing it the dedicated NIM settings."""
    # Imported lazily so this module does not import the provider at load time.
    from providers.nvidia_nim import NvidiaNimProvider

    provider = NvidiaNimProvider(config, settings=settings, nim_settings=settings.nim)
    return provider
def _create_zen(config: ProviderConfig, settings: Settings) -> BaseProvider:
    """Instantiate the Zen provider from its resolved config."""
    # Imported lazily so this module does not import the provider at load time.
    from providers.zen import ZenProvider

    provider = ZenProvider(config, settings=settings)
    return provider
def _create_cerebras(config: ProviderConfig, settings: Settings) -> BaseProvider:
    """Instantiate the Cerebras provider from its resolved config."""
    # Imported lazily so this module does not import the provider at load time.
    from providers.cerebras import CerebrasProvider

    provider = CerebrasProvider(config, settings=settings)
    return provider
def _create_silicon(config: ProviderConfig, settings: Settings) -> BaseProvider:
    """Instantiate the Silicon provider from its resolved config."""
    # Imported lazily so this module does not import the provider at load time.
    from providers.silicon import SiliconProvider

    provider = SiliconProvider(config, settings=settings)
    return provider
def _create_groq(config: ProviderConfig, settings: Settings) -> BaseProvider:
    """Instantiate the Groq provider from its resolved config."""
    # Imported lazily so this module does not import the provider at load time.
    from providers.groq import GroqProvider

    provider = GroqProvider(config, settings=settings)
    return provider
# Provider id -> factory that builds the concrete provider instance.
# Keys must match PROVIDER_DESCRIPTORS / SUPPORTED_PROVIDER_IDS (checked below).
PROVIDER_FACTORIES: dict[str, ProviderFactory] = dict(
    nvidia_nim=_create_nvidia_nim,
    zen=_create_zen,
    cerebras=_create_cerebras,
    silicon=_create_silicon,
    groq=_create_groq,
)
# Import-time guard: the descriptor catalog, the factory table, and the list of
# supported ids must all name exactly the same providers, otherwise the module
# refuses to load.
if not (
    set(PROVIDER_DESCRIPTORS)
    == set(SUPPORTED_PROVIDER_IDS)
    == set(PROVIDER_FACTORIES)
):
    raise AssertionError(
        "PROVIDER_DESCRIPTORS, PROVIDER_FACTORIES, and SUPPORTED_PROVIDER_IDS are out of sync: "
        f"descriptors={set(PROVIDER_DESCRIPTORS)!r} factories={set(PROVIDER_FACTORIES)!r} "
        f"ids={set(SUPPORTED_PROVIDER_IDS)!r}"
    )
def _string_attr(settings: Settings, attr_name: str | None, default: str = "") -> str:
if attr_name is None:
return default
value = getattr(settings, attr_name, default)
return value if isinstance(value, str) else default
def _credential_for(descriptor: ProviderDescriptor, settings: Settings) -> str:
    """Return the credential a provider should use, or ``""`` when it has none.

    A descriptor's fixed ``static_credential`` wins; otherwise the value is
    read from the settings attribute named by ``credential_attr``.
    """
    static_value = descriptor.static_credential
    if static_value is not None:
        return static_value
    attr_name = descriptor.credential_attr
    if not attr_name:
        return ""
    return _string_attr(settings, attr_name)
def _require_credential(descriptor: ProviderDescriptor, credential: str) -> None:
if descriptor.credential_env is None:
return
if credential and credential.strip():
return
message = f"{descriptor.credential_env} is not set. Add it to your .env file."
if descriptor.credential_url:
message = f"{message} Get a key at {descriptor.credential_url}"
raise AuthenticationError(message)
def build_provider_config(
    descriptor: ProviderDescriptor, settings: Settings
) -> ProviderConfig:
    """Assemble the :class:`ProviderConfig` for *descriptor* from *settings*.

    Raises ``AuthenticationError`` (via ``_require_credential``) when the
    provider needs a credential that is blank. An empty configured base URL
    falls back to ``descriptor.default_base_url``; the proxy defaults to ``""``.
    Rate limiting, concurrency, timeouts, and logging flags are copied straight
    from the shared settings.
    """
    api_key = _credential_for(descriptor, settings)
    _require_credential(descriptor, api_key)

    default_url = descriptor.default_base_url
    configured_url = _string_attr(
        settings, descriptor.base_url_attr, default_url or ""
    )
    proxy_url = _string_attr(settings, descriptor.proxy_attr)

    return ProviderConfig(
        api_key=api_key,
        base_url=configured_url or default_url,
        rate_limit=settings.provider_rate_limit,
        rate_window=settings.provider_rate_window,
        max_concurrency=settings.provider_max_concurrency,
        http_read_timeout=settings.http_read_timeout,
        http_write_timeout=settings.http_write_timeout,
        http_connect_timeout=settings.http_connect_timeout,
        enable_thinking=settings.enable_model_thinking,
        proxy=proxy_url,
        log_raw_sse_events=settings.log_raw_sse_events,
        log_api_error_tracebacks=settings.log_api_error_tracebacks,
    )
def create_provider(provider_id: str, settings: Settings) -> BaseProvider:
    """Build a fresh provider instance for *provider_id*.

    Raises ``UnknownProviderTypeError`` for ids missing from the catalog and
    propagates any error from building the provider config (such as a missing
    credential) before the factory is looked up.
    """
    if (descriptor := PROVIDER_DESCRIPTORS.get(provider_id)) is None:
        supported = "', '".join(PROVIDER_DESCRIPTORS)
        raise UnknownProviderTypeError(
            f"Unknown provider_type: '{provider_id}'. Supported: '{supported}'"
        )
    provider_config = build_provider_config(descriptor, settings)
    # The import-time sync check should make this unreachable.
    if (factory := PROVIDER_FACTORIES.get(provider_id)) is None:
        raise AssertionError(f"Unhandled provider descriptor: {provider_id}")
    return factory(provider_config, settings)
def _format_provider_query_failures(
    refs: list[ConfiguredChatModelRef],
    exc: BaseException,
    settings: Settings,
) -> list[str]:
    """Produce one validation-failure line per ref, all sharing *exc*'s reason."""
    problem = _provider_query_failure_reason(exc, settings)
    return [
        _format_model_validation_failure(model_ref, problem) for model_ref in refs
    ]
def _format_missing_model_failure(ref: ConfiguredChatModelRef) -> str:
    """Describe *ref* as a configured model absent from its provider's model list."""
    problem = "missing model"
    return _format_model_validation_failure(ref, problem)
def _format_model_validation_failure(ref: ConfiguredChatModelRef, problem: str) -> str:
return (
f"sources={','.join(ref.sources)} provider={ref.provider_id} "
f"model={ref.model_id} problem={problem}"
)
def _provider_query_failure_reason(
exc: BaseException,
settings: Settings,
) -> str:
if isinstance(exc, ModelListResponseError):
return f"malformed model-list response: {exc.message}"
if isinstance(exc, httpx.HTTPStatusError):
return f"query failure: HTTP {exc.response.status_code}"
if isinstance(exc, AuthenticationError):
return f"query failure: {exc.message}"
if isinstance(exc, ProviderError) and settings.log_api_error_tracebacks:
return f"query failure: {exc.message}"
return f"query failure: {type(exc).__name__}"
def _referenced_provider_ids(settings: Settings) -> frozenset[str]:
return frozenset(ref.provider_id for ref in settings.configured_chat_model_refs())
def _model_list_provider_ids_for_settings(settings: Settings) -> tuple[str, ...]:
    """Return providers worth discovering for this process configuration.

    A provider with a static credential qualifies only when a configured chat
    model references it. A provider with a ``credential_env`` qualifies when its
    credential is non-blank. Result order follows ``PROVIDER_DESCRIPTORS``.
    """
    referenced = _referenced_provider_ids(settings)

    def _eligible(provider_id: str, descriptor: ProviderDescriptor) -> bool:
        # Static-credential providers are always usable, so only probe the
        # ones the configuration actually points at.
        if descriptor.static_credential is not None:
            return provider_id in referenced
        return descriptor.credential_env is not None and bool(
            _credential_for(descriptor, settings).strip()
        )

    return tuple(
        provider_id
        for provider_id, descriptor in PROVIDER_DESCRIPTORS.items()
        if _eligible(provider_id, descriptor)
    )
def _log_model_discovery_failure(
    provider_id: str, exc: BaseException, settings: Settings
) -> None:
    """Warn that *provider_id* was skipped during best-effort model discovery."""
    reason = _provider_query_failure_reason(exc, settings)
    logger.warning(
        "Provider model discovery skipped: provider={} reason={}",
        provider_id,
        reason,
    )
class ProviderRegistry:
    """Cache and clean up provider instances by provider id.

    Besides provider instances, the registry keeps a per-provider cache of
    model metadata (filled by discovery or validation). Model-list API
    responses can then be answered without querying upstream again.
    """

    def __init__(self, providers: MutableMapping[str, BaseProvider] | None = None):
        # A caller-supplied mapping (even an empty one) is used as-is.
        self._providers = providers if providers is not None else {}
        # provider id -> frozenset of raw (unprefixed) model ids.
        self._model_ids_by_provider: dict[str, frozenset[str]] = {}
        # provider id -> {raw model id: ProviderModelInfo}.
        self._model_infos_by_provider: dict[str, dict[str, ProviderModelInfo]] = {}
        # Background warmup task started by ``start_model_list_refresh``; the
        # reference is kept so the task is not garbage-collected mid-run.
        self._model_list_refresh_task: asyncio.Task[None] | None = None

    def is_cached(self, provider_id: str) -> bool:
        """Return whether a provider for this id is already in the cache."""
        return provider_id in self._providers

    def get(self, provider_id: str, settings: Settings) -> BaseProvider:
        """Return the cached provider for *provider_id*, creating it on first use.

        Errors from :func:`create_provider` (e.g. an unknown id or a missing
        credential) propagate, and nothing is cached in that case.
        """
        if provider_id not in self._providers:
            descriptor = PROVIDER_DESCRIPTORS.get(provider_id)
            if descriptor is not None and descriptor.credential_attr:
                # Never log credential material, not even a prefix: report only
                # whether a non-blank value is configured.
                credential_value = _string_attr(settings, descriptor.credential_attr)
                logger.info(
                    "Creating provider '{}' with credential '{}' = '{}'",
                    provider_id,
                    descriptor.credential_attr,
                    "<redacted>" if credential_value.strip() else "EMPTY",
                )
            self._providers[provider_id] = create_provider(provider_id, settings)
        return self._providers[provider_id]

    def cache_model_ids(self, provider_id: str, model_ids: Iterable[str]) -> None:
        """Store a provider model-list result for later instant API responses."""
        self.cache_model_infos(provider_id, model_infos_from_ids(model_ids))

    def cache_model_infos(
        self, provider_id: str, model_infos: Iterable[ProviderModelInfo]
    ) -> None:
        """Store provider model metadata for later instant API responses.

        Entries whose ``model_id`` is blank are dropped. The new result replaces
        any previously cached list for *provider_id*.
        """
        clean_infos = {
            info.model_id: info for info in model_infos if info.model_id.strip()
        }
        self._model_infos_by_provider[provider_id] = clean_infos
        self._model_ids_by_provider[provider_id] = frozenset(clean_infos)

    def cached_model_ids(self) -> dict[str, frozenset[str]]:
        """Return a copy of cached raw provider model ids."""
        return dict(self._model_ids_by_provider)

    def cached_model_supports_thinking(
        self, provider_id: str, model_id: str
    ) -> bool | None:
        """Return cached thinking support when a provider exposes it.

        Returns ``None`` when the model is not in the cache.
        """
        info = self._model_infos_by_provider.get(provider_id, {}).get(model_id)
        if info is None:
            return None
        return info.supports_thinking

    def cached_prefixed_model_refs(self) -> tuple[str, ...]:
        """Return cached provider models in user-selectable ``provider/model`` form."""
        return tuple(info.model_id for info in self.cached_prefixed_model_infos())

    def cached_prefixed_model_infos(self) -> tuple[ProviderModelInfo, ...]:
        """Return cached provider models with user-selectable prefixed ids.

        Providers appear in ``SUPPORTED_PROVIDER_IDS`` order. Within a provider,
        models are sorted by raw model id.
        """
        infos: list[ProviderModelInfo] = []
        for provider_id in SUPPORTED_PROVIDER_IDS:
            provider_infos = self._model_infos_by_provider.get(provider_id, {})
            infos.extend(
                ProviderModelInfo(
                    model_id=f"{provider_id}/{info.model_id}",
                    supports_thinking=info.supports_thinking,
                )
                for info in sorted(
                    provider_infos.values(), key=lambda item: item.model_id
                )
            )
        return tuple(infos)

    async def refresh_model_list_cache(
        self, settings: Settings, *, only_missing: bool = False
    ) -> None:
        """Best-effort refresh of model lists for providers usable in this process.

        With ``only_missing=True``, providers that already have a cached list
        are skipped.
        """
        provider_ids = _model_list_provider_ids_for_settings(settings)
        if only_missing:
            provider_ids = tuple(
                provider_id
                for provider_id in provider_ids
                if provider_id not in self._model_ids_by_provider
            )
        await self._refresh_model_ids(settings, provider_ids)

    def start_model_list_refresh(self, settings: Settings) -> None:
        """Start a non-blocking cache warmup for missing eligible provider lists.

        This is a no-op while a previous warmup task is still running, or when
        every eligible provider is already cached. It uses
        ``asyncio.create_task``, so it must be called with a running event loop.
        """
        if (
            self._model_list_refresh_task is not None
            and not self._model_list_refresh_task.done()
        ):
            return
        provider_ids = tuple(
            provider_id
            for provider_id in _model_list_provider_ids_for_settings(settings)
            if provider_id not in self._model_ids_by_provider
        )
        if not provider_ids:
            logger.info(
                "Provider model discovery cache already warm: providers={}",
                len(self._model_ids_by_provider),
            )
            return
        self._model_list_refresh_task = asyncio.create_task(
            self._run_model_list_refresh(settings, provider_ids)
        )

    async def _run_model_list_refresh(
        self, settings: Settings, provider_ids: tuple[str, ...]
    ) -> None:
        """Background wrapper: propagate cancellation, log any other failure."""
        try:
            await self._refresh_model_ids(settings, provider_ids)
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            logger.warning(
                "Provider model discovery task failed: exc_type={}",
                type(exc).__name__,
            )

    async def _refresh_model_ids(
        self, settings: Settings, provider_ids: tuple[str, ...]
    ) -> None:
        """Query ``list_model_infos`` concurrently and cache each success.

        Failures to construct a provider, and per-provider query errors, are
        logged and skipped. A ``CancelledError`` result is re-raised.
        """
        tasks: dict[str, asyncio.Task[frozenset[ProviderModelInfo]]] = {}
        for provider_id in provider_ids:
            try:
                provider = self.get(provider_id, settings)
            except Exception as exc:
                _log_model_discovery_failure(provider_id, exc, settings)
                continue
            tasks[provider_id] = asyncio.create_task(provider.list_model_infos())
        if not tasks:
            return
        logger.info(
            "Starting model discovery for providers: {}", ", ".join(tasks.keys())
        )
        results = await asyncio.gather(*tasks.values(), return_exceptions=True)
        logger.info("Model discovery finished for all providers.")
        for (provider_id, _task), result in zip(tasks.items(), results, strict=True):
            if isinstance(result, BaseException):
                if isinstance(result, asyncio.CancelledError):
                    raise result
                _log_model_discovery_failure(provider_id, result, settings)
                continue
            self.cache_model_infos(provider_id, result)
            logger.info(
                "Provider model discovery cached: provider={} models={}",
                provider_id,
                len(result),
            )

    async def validate_configured_models(self, settings: Settings) -> None:
        """Fail fast unless every configured chat model exists upstream.

        Each referenced provider's model list is fetched concurrently, and
        successful lists are cached as a side effect.

        Raises:
            ServiceUnavailableError: listing every provider-construction
                failure, query failure, and missing model.
        """
        refs = settings.configured_chat_model_refs()
        refs_by_provider: dict[str, list[ConfiguredChatModelRef]] = defaultdict(list)
        for ref in refs:
            refs_by_provider[ref.provider_id].append(ref)
        failures: list[str] = []
        tasks: dict[str, asyncio.Task[frozenset[ProviderModelInfo]]] = {}
        for provider_id, provider_refs in refs_by_provider.items():
            try:
                provider = self.get(provider_id, settings)
            except Exception as exc:
                failures.extend(
                    _format_provider_query_failures(provider_refs, exc, settings)
                )
                continue
            tasks[provider_id] = asyncio.create_task(provider.list_model_infos())
        if tasks:
            results = await asyncio.gather(*tasks.values(), return_exceptions=True)
            for (provider_id, _task), result in zip(
                tasks.items(), results, strict=True
            ):
                provider_refs = refs_by_provider[provider_id]
                if isinstance(result, BaseException):
                    if isinstance(result, asyncio.CancelledError):
                        raise result
                    failures.extend(
                        _format_provider_query_failures(provider_refs, result, settings)
                    )
                    continue
                self.cache_model_infos(provider_id, result)
                model_ids = self._model_ids_by_provider[provider_id]
                failures.extend(
                    _format_missing_model_failure(ref)
                    for ref in provider_refs
                    if ref.model_id not in model_ids
                )
        if failures:
            message = "Configured model validation failed:\n" + "\n".join(
                f"- {failure}" for failure in failures
            )
            raise ServiceUnavailableError(message)
        logger.info(
            "Configured provider models validated: models={} providers={}",
            len(refs),
            len(refs_by_provider),
        )

    async def cleanup(self) -> None:
        """Call ``cleanup`` on every cached provider, then clear the cache.

        First cancels and awaits any running model-list warmup task, then
        drops the reference to it. Attempts all providers even if one fails,
        and always clears the provider and model caches. A single failure is
        re-raised as-is; multiple failures are wrapped in :exc:`ExceptionGroup`.
        """
        if (
            self._model_list_refresh_task is not None
            and not self._model_list_refresh_task.done()
        ):
            self._model_list_refresh_task.cancel()
            # NOTE(review): this also swallows a cancellation aimed at the task
            # running ``cleanup`` itself; consider consulting
            # ``asyncio.current_task().cancelling()`` before suppressing.
            with suppress(asyncio.CancelledError):
                await self._model_list_refresh_task
        # Drop the finished/cancelled task so no stale reference outlives cleanup.
        self._model_list_refresh_task = None
        items = list(self._providers.items())
        errors: list[Exception] = []
        try:
            for _pid, provider in items:
                try:
                    await provider.cleanup()
                except Exception as e:
                    errors.append(e)
        finally:
            self._providers.clear()
            self._model_ids_by_provider.clear()
            self._model_infos_by_provider.clear()
        if len(errors) == 1:
            raise errors[0]
        if len(errors) > 1:
            msg = "One or more provider cleanups failed"
            raise ExceptionGroup(msg, errors)