"""Provider descriptors, factory, and runtime registry.""" from __future__ import annotations import asyncio from collections import defaultdict from collections.abc import Callable, Iterable, MutableMapping from contextlib import suppress import httpx from loguru import logger from config.provider_catalog import ( PROVIDER_CATALOG, SUPPORTED_PROVIDER_IDS, ProviderDescriptor, ) from config.settings import ConfiguredChatModelRef, Settings from providers.base import BaseProvider, ProviderConfig from providers.exceptions import ( AuthenticationError, ModelListResponseError, ProviderError, ServiceUnavailableError, UnknownProviderTypeError, ) from providers.model_listing import ProviderModelInfo, model_infos_from_ids ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider] # Backwards-compatible name for the catalog (single source: ``config.provider_catalog``). PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = PROVIDER_CATALOG def _create_nvidia_nim(config: ProviderConfig, settings: Settings) -> BaseProvider: from providers.nvidia_nim import NvidiaNimProvider return NvidiaNimProvider(config, nim_settings=settings.nim, settings=settings) def _create_zen(config: ProviderConfig, settings: Settings) -> BaseProvider: from providers.zen import ZenProvider return ZenProvider(config, settings=settings) def _create_cerebras(config: ProviderConfig, settings: Settings) -> BaseProvider: from providers.cerebras import CerebrasProvider return CerebrasProvider(config, settings=settings) def _create_silicon(config: ProviderConfig, settings: Settings) -> BaseProvider: from providers.silicon import SiliconProvider return SiliconProvider(config, settings=settings) def _create_groq(config: ProviderConfig, settings: Settings) -> BaseProvider: from providers.groq import GroqProvider return GroqProvider(config, settings=settings) PROVIDER_FACTORIES: dict[str, ProviderFactory] = { "nvidia_nim": _create_nvidia_nim, "zen": _create_zen, "cerebras": _create_cerebras, "silicon": _create_silicon, "groq": _create_groq, } if set(PROVIDER_DESCRIPTORS) != set(SUPPORTED_PROVIDER_IDS) or set( PROVIDER_FACTORIES ) != set(SUPPORTED_PROVIDER_IDS): raise AssertionError( "PROVIDER_DESCRIPTORS, PROVIDER_FACTORIES, and SUPPORTED_PROVIDER_IDS are out of sync: " f"descriptors={set(PROVIDER_DESCRIPTORS)!r} factories={set(PROVIDER_FACTORIES)!r} " f"ids={set(SUPPORTED_PROVIDER_IDS)!r}" ) def _string_attr(settings: Settings, attr_name: str | None, default: str = "") -> str: if attr_name is None: return default value = getattr(settings, attr_name, default) return value if isinstance(value, str) else default def _credential_for(descriptor: ProviderDescriptor, settings: Settings) -> str: if descriptor.static_credential is not None: return descriptor.static_credential if descriptor.credential_attr: return _string_attr(settings, descriptor.credential_attr) return "" def _require_credential(descriptor: ProviderDescriptor, credential: str) -> None: if descriptor.credential_env is None: return if credential and credential.strip(): return message = f"{descriptor.credential_env} is not set. Add it to your .env file." if descriptor.credential_url: message = f"{message} Get a key at {descriptor.credential_url}" raise AuthenticationError(message) def build_provider_config( descriptor: ProviderDescriptor, settings: Settings ) -> ProviderConfig: credential = _credential_for(descriptor, settings) _require_credential(descriptor, credential) base_url = _string_attr( settings, descriptor.base_url_attr, descriptor.default_base_url or "" ) proxy = _string_attr(settings, descriptor.proxy_attr) return ProviderConfig( api_key=credential, base_url=base_url or descriptor.default_base_url, rate_limit=settings.provider_rate_limit, rate_window=settings.provider_rate_window, max_concurrency=settings.provider_max_concurrency, http_read_timeout=settings.http_read_timeout, http_write_timeout=settings.http_write_timeout, http_connect_timeout=settings.http_connect_timeout, enable_thinking=settings.enable_model_thinking, proxy=proxy, log_raw_sse_events=settings.log_raw_sse_events, log_api_error_tracebacks=settings.log_api_error_tracebacks, ) def create_provider(provider_id: str, settings: Settings) -> BaseProvider: descriptor = PROVIDER_DESCRIPTORS.get(provider_id) if descriptor is None: supported = "', '".join(PROVIDER_DESCRIPTORS) raise UnknownProviderTypeError( f"Unknown provider_type: '{provider_id}'. Supported: '{supported}'" ) config = build_provider_config(descriptor, settings) factory = PROVIDER_FACTORIES.get(provider_id) if factory is None: raise AssertionError(f"Unhandled provider descriptor: {provider_id}") return factory(config, settings) def _format_provider_query_failures( refs: list[ConfiguredChatModelRef], exc: BaseException, settings: Settings, ) -> list[str]: reason = _provider_query_failure_reason(exc, settings) return [_format_model_validation_failure(ref, reason) for ref in refs] def _format_missing_model_failure(ref: ConfiguredChatModelRef) -> str: return _format_model_validation_failure(ref, "missing model") def _format_model_validation_failure(ref: ConfiguredChatModelRef, problem: str) -> str: return ( f"sources={','.join(ref.sources)} provider={ref.provider_id} " f"model={ref.model_id} problem={problem}" ) def _provider_query_failure_reason( exc: BaseException, settings: Settings, ) -> str: if isinstance(exc, ModelListResponseError): return f"malformed model-list response: {exc.message}" if isinstance(exc, httpx.HTTPStatusError): return f"query failure: HTTP {exc.response.status_code}" if isinstance(exc, AuthenticationError): return f"query failure: {exc.message}" if isinstance(exc, ProviderError) and settings.log_api_error_tracebacks: return f"query failure: {exc.message}" return f"query failure: {type(exc).__name__}" def _referenced_provider_ids(settings: Settings) -> frozenset[str]: return frozenset(ref.provider_id for ref in settings.configured_chat_model_refs()) def _model_list_provider_ids_for_settings(settings: Settings) -> tuple[str, ...]: """Return providers worth discovering for this process configuration.""" referenced_provider_ids = _referenced_provider_ids(settings) provider_ids: list[str] = [] for provider_id, descriptor in PROVIDER_DESCRIPTORS.items(): if descriptor.static_credential is not None: if provider_id in referenced_provider_ids: provider_ids.append(provider_id) continue if ( descriptor.credential_env is not None and _credential_for(descriptor, settings).strip() ): provider_ids.append(provider_id) return tuple(provider_ids) def _log_model_discovery_failure( provider_id: str, exc: BaseException, settings: Settings ) -> None: logger.warning( "Provider model discovery skipped: provider={} reason={}", provider_id, _provider_query_failure_reason(exc, settings), ) class ProviderRegistry: """Cache and clean up provider instances by provider id.""" def __init__(self, providers: MutableMapping[str, BaseProvider] | None = None): self._providers = providers if providers is not None else {} self._model_ids_by_provider: dict[str, frozenset[str]] = {} self._model_infos_by_provider: dict[str, dict[str, ProviderModelInfo]] = {} self._model_list_refresh_task: asyncio.Task[None] | None = None def is_cached(self, provider_id: str) -> bool: """Return whether a provider for this id is already in the cache.""" return provider_id in self._providers def get(self, provider_id: str, settings: Settings) -> BaseProvider: from loguru import logger if provider_id not in self._providers: # Log what credentials are being used from config.provider_catalog import PROVIDER_CATALOG desc = PROVIDER_CATALOG.get(provider_id) if desc and desc.credential_attr: cred_value = getattr(settings, desc.credential_attr, "") logger.info( "Creating provider '{}' with credential '{}' = '{}'", provider_id, desc.credential_attr, cred_value[:10] + "..." if cred_value else "EMPTY", ) self._providers[provider_id] = create_provider(provider_id, settings) return self._providers[provider_id] def cache_model_ids(self, provider_id: str, model_ids: Iterable[str]) -> None: """Store a provider model-list result for later instant API responses.""" self.cache_model_infos(provider_id, model_infos_from_ids(model_ids)) def cache_model_infos( self, provider_id: str, model_infos: Iterable[ProviderModelInfo] ) -> None: """Store provider model metadata for later instant API responses.""" clean_infos = { info.model_id: info for info in model_infos if info.model_id.strip() } self._model_infos_by_provider[provider_id] = clean_infos self._model_ids_by_provider[provider_id] = frozenset(clean_infos) def cached_model_ids(self) -> dict[str, frozenset[str]]: """Return a copy of cached raw provider model ids.""" return dict(self._model_ids_by_provider) def cached_model_supports_thinking( self, provider_id: str, model_id: str ) -> bool | None: """Return cached thinking support when a provider exposes it.""" info = self._model_infos_by_provider.get(provider_id, {}).get(model_id) if info is None: return None return info.supports_thinking def cached_prefixed_model_refs(self) -> tuple[str, ...]: """Return cached provider models in user-selectable ``provider/model`` form.""" return tuple(info.model_id for info in self.cached_prefixed_model_infos()) def cached_prefixed_model_infos(self) -> tuple[ProviderModelInfo, ...]: """Return cached provider models with user-selectable prefixed ids.""" infos: list[ProviderModelInfo] = [] for provider_id in SUPPORTED_PROVIDER_IDS: provider_infos = self._model_infos_by_provider.get(provider_id, {}) infos.extend( ProviderModelInfo( model_id=f"{provider_id}/{info.model_id}", supports_thinking=info.supports_thinking, ) for info in sorted( provider_infos.values(), key=lambda item: item.model_id ) ) return tuple(infos) async def refresh_model_list_cache( self, settings: Settings, *, only_missing: bool = False ) -> None: """Best-effort refresh of model lists for providers usable in this process.""" provider_ids = _model_list_provider_ids_for_settings(settings) if only_missing: provider_ids = tuple( provider_id for provider_id in provider_ids if provider_id not in self._model_ids_by_provider ) await self._refresh_model_ids(settings, provider_ids) def start_model_list_refresh(self, settings: Settings) -> None: """Start a non-blocking cache warmup for missing eligible provider lists.""" if ( self._model_list_refresh_task is not None and not self._model_list_refresh_task.done() ): return provider_ids = tuple( provider_id for provider_id in _model_list_provider_ids_for_settings(settings) if provider_id not in self._model_ids_by_provider ) if not provider_ids: logger.info( "Provider model discovery cache already warm: providers={}", len(self._model_ids_by_provider), ) return self._model_list_refresh_task = asyncio.create_task( self._run_model_list_refresh(settings, provider_ids) ) async def _run_model_list_refresh( self, settings: Settings, provider_ids: tuple[str, ...] ) -> None: try: await self._refresh_model_ids(settings, provider_ids) except asyncio.CancelledError: raise except Exception as exc: logger.warning( "Provider model discovery task failed: exc_type={}", type(exc).__name__, ) async def _refresh_model_ids( self, settings: Settings, provider_ids: tuple[str, ...] ) -> None: tasks: dict[str, asyncio.Task[frozenset[ProviderModelInfo]]] = {} for provider_id in provider_ids: try: provider = self.get(provider_id, settings) except Exception as exc: _log_model_discovery_failure(provider_id, exc, settings) continue tasks[provider_id] = asyncio.create_task(provider.list_model_infos()) if not tasks: return logger.info( "Starting model discovery for providers: {}", ", ".join(tasks.keys()) ) results = await asyncio.gather(*tasks.values(), return_exceptions=True) logger.info("Model discovery finished for all providers.") for (provider_id, _task), result in zip(tasks.items(), results, strict=True): if isinstance(result, BaseException): if isinstance(result, asyncio.CancelledError): raise result _log_model_discovery_failure(provider_id, result, settings) continue self.cache_model_infos(provider_id, result) logger.info( "Provider model discovery cached: provider={} models={}", provider_id, len(result), ) async def validate_configured_models(self, settings: Settings) -> None: """Fail fast unless every configured chat model exists upstream.""" refs = settings.configured_chat_model_refs() refs_by_provider: dict[str, list[ConfiguredChatModelRef]] = defaultdict(list) for ref in refs: refs_by_provider[ref.provider_id].append(ref) failures: list[str] = [] tasks: dict[str, asyncio.Task[frozenset[ProviderModelInfo]]] = {} for provider_id, provider_refs in refs_by_provider.items(): try: provider = self.get(provider_id, settings) except Exception as exc: failures.extend( _format_provider_query_failures(provider_refs, exc, settings) ) continue tasks[provider_id] = asyncio.create_task(provider.list_model_infos()) if tasks: results = await asyncio.gather(*tasks.values(), return_exceptions=True) for (provider_id, _task), result in zip( tasks.items(), results, strict=True ): provider_refs = refs_by_provider[provider_id] if isinstance(result, BaseException): if isinstance(result, asyncio.CancelledError): raise result failures.extend( _format_provider_query_failures(provider_refs, result, settings) ) continue self.cache_model_infos(provider_id, result) model_ids = self._model_ids_by_provider[provider_id] failures.extend( _format_missing_model_failure(ref) for ref in provider_refs if ref.model_id not in model_ids ) if failures: message = "Configured model validation failed:\n" + "\n".join( f"- {failure}" for failure in failures ) raise ServiceUnavailableError(message) logger.info( "Configured provider models validated: models={} providers={}", len(refs), len(refs_by_provider), ) async def cleanup(self) -> None: """Call ``cleanup`` on every cached provider, then clear the cache. Attempts all providers even if one fails. A single failure is re-raised as-is; multiple failures are wrapped in :exc:`ExceptionGroup`. """ if ( self._model_list_refresh_task is not None and not self._model_list_refresh_task.done() ): self._model_list_refresh_task.cancel() with suppress(asyncio.CancelledError): await self._model_list_refresh_task items = list(self._providers.items()) errors: list[Exception] = [] try: for _pid, provider in items: try: await provider.cleanup() except Exception as e: errors.append(e) finally: self._providers.clear() self._model_ids_by_provider.clear() self._model_infos_by_provider.clear() if len(errors) == 1: raise errors[0] if len(errors) > 1: msg = "One or more provider cleanups failed" raise ExceptionGroup(msg, errors)