# Commit a1a14b2 (Yash030): "Add credential logging at provider creation."
"""Provider descriptors, factory, and runtime registry."""
from __future__ import annotations
import asyncio
from collections import defaultdict
from collections.abc import Callable, Iterable, MutableMapping
from contextlib import suppress
import httpx
from loguru import logger
from config.provider_catalog import (
PROVIDER_CATALOG,
SUPPORTED_PROVIDER_IDS,
ProviderDescriptor,
)
from config.settings import ConfiguredChatModelRef, Settings
from providers.base import BaseProvider, ProviderConfig
from providers.exceptions import (
AuthenticationError,
ModelListResponseError,
ProviderError,
ServiceUnavailableError,
UnknownProviderTypeError,
)
from providers.model_listing import ProviderModelInfo, model_infos_from_ids
# Signature every provider factory must satisfy: (config, settings) -> provider.
ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider]
# Backwards-compatible name for the catalog (single source: ``config.provider_catalog``).
# NOTE: this is an alias, not a copy — mutations of PROVIDER_CATALOG are visible here.
PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = PROVIDER_CATALOG
def _create_nvidia_nim(config: ProviderConfig, settings: Settings) -> BaseProvider:
    """Build the NVIDIA NIM provider (import deferred to keep startup light)."""
    from providers.nvidia_nim import NvidiaNimProvider

    provider = NvidiaNimProvider(config, nim_settings=settings.nim, settings=settings)
    return provider
def _create_zen(config: ProviderConfig, settings: Settings) -> BaseProvider:
    """Build the Zen provider (import deferred to keep startup light)."""
    from providers.zen import ZenProvider

    provider = ZenProvider(config, settings=settings)
    return provider
def _create_cerebras(config: ProviderConfig, settings: Settings) -> BaseProvider:
    """Build the Cerebras provider (import deferred to keep startup light)."""
    from providers.cerebras import CerebrasProvider

    provider = CerebrasProvider(config, settings=settings)
    return provider
def _create_silicon(config: ProviderConfig, settings: Settings) -> BaseProvider:
    """Build the Silicon provider (import deferred to keep startup light)."""
    from providers.silicon import SiliconProvider

    provider = SiliconProvider(config, settings=settings)
    return provider
def _create_groq(config: ProviderConfig, settings: Settings) -> BaseProvider:
    """Build the Groq provider (import deferred to keep startup light)."""
    from providers.groq import GroqProvider

    provider = GroqProvider(config, settings=settings)
    return provider
# Maps each supported provider id to its factory; kept in lockstep with
# PROVIDER_DESCRIPTORS (checked at import time below).
PROVIDER_FACTORIES: dict[str, ProviderFactory] = {
    "nvidia_nim": _create_nvidia_nim,
    "zen": _create_zen,
    "cerebras": _create_cerebras,
    "silicon": _create_silicon,
    "groq": _create_groq,
}
# Import-time guard: the descriptor catalog, factory table, and the canonical
# id list must always name exactly the same providers.
if not (
    set(PROVIDER_DESCRIPTORS)
    == set(PROVIDER_FACTORIES)
    == set(SUPPORTED_PROVIDER_IDS)
):
    raise AssertionError(
        "PROVIDER_DESCRIPTORS, PROVIDER_FACTORIES, and SUPPORTED_PROVIDER_IDS are out of sync: "
        f"descriptors={set(PROVIDER_DESCRIPTORS)!r} factories={set(PROVIDER_FACTORIES)!r} "
        f"ids={set(SUPPORTED_PROVIDER_IDS)!r}"
    )
def _string_attr(settings: Settings, attr_name: str | None, default: str = "") -> str:
if attr_name is None:
return default
value = getattr(settings, attr_name, default)
return value if isinstance(value, str) else default
def _credential_for(descriptor: ProviderDescriptor, settings: Settings) -> str:
if descriptor.static_credential is not None:
return descriptor.static_credential
if descriptor.credential_attr:
return _string_attr(settings, descriptor.credential_attr)
return ""
def _require_credential(descriptor: ProviderDescriptor, credential: str) -> None:
if descriptor.credential_env is None:
return
if credential and credential.strip():
return
message = f"{descriptor.credential_env} is not set. Add it to your .env file."
if descriptor.credential_url:
message = f"{message} Get a key at {descriptor.credential_url}"
raise AuthenticationError(message)
def build_provider_config(
    descriptor: ProviderDescriptor, settings: Settings
) -> ProviderConfig:
    """Assemble a :class:`ProviderConfig` for *descriptor* from *settings*.

    Raises :class:`AuthenticationError` when the provider requires a
    credential that is not configured.
    """
    api_key = _credential_for(descriptor, settings)
    _require_credential(descriptor, api_key)
    fallback_url = descriptor.default_base_url or ""
    resolved_url = _string_attr(settings, descriptor.base_url_attr, fallback_url)
    return ProviderConfig(
        api_key=api_key,
        # An empty settings override falls back to the descriptor default.
        base_url=resolved_url or descriptor.default_base_url,
        rate_limit=settings.provider_rate_limit,
        rate_window=settings.provider_rate_window,
        max_concurrency=settings.provider_max_concurrency,
        http_read_timeout=settings.http_read_timeout,
        http_write_timeout=settings.http_write_timeout,
        http_connect_timeout=settings.http_connect_timeout,
        enable_thinking=settings.enable_model_thinking,
        proxy=_string_attr(settings, descriptor.proxy_attr),
        log_raw_sse_events=settings.log_raw_sse_events,
        log_api_error_tracebacks=settings.log_api_error_tracebacks,
    )
def create_provider(provider_id: str, settings: Settings) -> BaseProvider:
    """Build a fresh provider instance for *provider_id*.

    Raises :class:`UnknownProviderTypeError` for an unrecognized id and
    whatever :func:`build_provider_config` raises for bad credentials.
    """
    descriptor = PROVIDER_DESCRIPTORS.get(provider_id)
    if descriptor is None:
        known = "', '".join(PROVIDER_DESCRIPTORS)
        raise UnknownProviderTypeError(
            f"Unknown provider_type: '{provider_id}'. Supported: '{known}'"
        )
    provider_config = build_provider_config(descriptor, settings)
    factory = PROVIDER_FACTORIES.get(provider_id)
    if factory is None:
        # Unreachable if the import-time sync check passed.
        raise AssertionError(f"Unhandled provider descriptor: {provider_id}")
    return factory(provider_config, settings)
def _format_provider_query_failures(
    refs: list[ConfiguredChatModelRef],
    exc: BaseException,
    settings: Settings,
) -> list[str]:
    """Return one failure line per configured ref, sharing *exc* as the reason."""
    shared_reason = _provider_query_failure_reason(exc, settings)
    return [
        _format_model_validation_failure(ref, shared_reason)
        for ref in refs
    ]
def _format_missing_model_failure(ref: ConfiguredChatModelRef) -> str:
    """Format a validation-failure line for a model absent from the provider's list."""
    return _format_model_validation_failure(ref, "missing model")
def _format_model_validation_failure(ref: ConfiguredChatModelRef, problem: str) -> str:
return (
f"sources={','.join(ref.sources)} provider={ref.provider_id} "
f"model={ref.model_id} problem={problem}"
)
def _provider_query_failure_reason(
exc: BaseException,
settings: Settings,
) -> str:
if isinstance(exc, ModelListResponseError):
return f"malformed model-list response: {exc.message}"
if isinstance(exc, httpx.HTTPStatusError):
return f"query failure: HTTP {exc.response.status_code}"
if isinstance(exc, AuthenticationError):
return f"query failure: {exc.message}"
if isinstance(exc, ProviderError) and settings.log_api_error_tracebacks:
return f"query failure: {exc.message}"
return f"query failure: {type(exc).__name__}"
def _referenced_provider_ids(settings: Settings) -> frozenset[str]:
return frozenset(ref.provider_id for ref in settings.configured_chat_model_refs())
def _model_list_provider_ids_for_settings(settings: Settings) -> tuple[str, ...]:
    """Return providers worth discovering for this process configuration."""
    referenced = _referenced_provider_ids(settings)

    def _eligible(provider_id: str, descriptor: ProviderDescriptor) -> bool:
        # Statically-credentialed providers are only queried when the
        # configuration actually references them.
        if descriptor.static_credential is not None:
            return provider_id in referenced
        # Otherwise require an env-backed credential that is actually set.
        return descriptor.credential_env is not None and bool(
            _credential_for(descriptor, settings).strip()
        )

    return tuple(
        provider_id
        for provider_id, descriptor in PROVIDER_DESCRIPTORS.items()
        if _eligible(provider_id, descriptor)
    )
def _log_model_discovery_failure(
    provider_id: str, exc: BaseException, settings: Settings
) -> None:
    """Warn (without raising) that model discovery was skipped for a provider."""
    reason = _provider_query_failure_reason(exc, settings)
    logger.warning(
        "Provider model discovery skipped: provider={} reason={}",
        provider_id,
        reason,
    )
class ProviderRegistry:
    """Cache and clean up provider instances by provider id.

    Also caches per-provider model-list results so later API responses can be
    served without re-querying upstream providers.
    """

    def __init__(self, providers: MutableMapping[str, BaseProvider] | None = None):
        # Provider instances keyed by provider id.
        self._providers = providers if providers is not None else {}
        # Raw model ids per provider (mirrors the keys of the info cache).
        self._model_ids_by_provider: dict[str, frozenset[str]] = {}
        # Full model metadata per provider, keyed by model id.
        self._model_infos_by_provider: dict[str, dict[str, ProviderModelInfo]] = {}
        # Background warmup task started by start_model_list_refresh().
        self._model_list_refresh_task: asyncio.Task[None] | None = None

    def is_cached(self, provider_id: str) -> bool:
        """Return whether a provider for this id is already in the cache."""
        return provider_id in self._providers

    def get(self, provider_id: str, settings: Settings) -> BaseProvider:
        """Return the cached provider for *provider_id*, creating it on first use.

        Raises whatever :func:`create_provider` raises (unknown provider id,
        missing credential).
        """
        if provider_id not in self._providers:
            # SECURITY FIX: the previous revision logged a prefix of the API
            # credential here. Secrets (even truncated) must never reach the
            # logs — log only the provider id being instantiated.
            logger.debug("Creating provider instance: provider={}", provider_id)
            self._providers[provider_id] = create_provider(provider_id, settings)
        return self._providers[provider_id]

    def cache_model_ids(self, provider_id: str, model_ids: Iterable[str]) -> None:
        """Store a provider model-list result for later instant API responses."""
        self.cache_model_infos(provider_id, model_infos_from_ids(model_ids))

    def cache_model_infos(
        self, provider_id: str, model_infos: Iterable[ProviderModelInfo]
    ) -> None:
        """Store provider model metadata for later instant API responses.

        Entries with blank model ids are dropped; both caches are replaced
        atomically for the provider.
        """
        clean_infos = {
            info.model_id: info for info in model_infos if info.model_id.strip()
        }
        self._model_infos_by_provider[provider_id] = clean_infos
        self._model_ids_by_provider[provider_id] = frozenset(clean_infos)

    def cached_model_ids(self) -> dict[str, frozenset[str]]:
        """Return a copy of cached raw provider model ids."""
        return dict(self._model_ids_by_provider)

    def cached_model_supports_thinking(
        self, provider_id: str, model_id: str
    ) -> bool | None:
        """Return cached thinking support when a provider exposes it.

        ``None`` means the model is unknown to the cache.
        """
        info = self._model_infos_by_provider.get(provider_id, {}).get(model_id)
        if info is None:
            return None
        return info.supports_thinking

    def cached_prefixed_model_refs(self) -> tuple[str, ...]:
        """Return cached provider models in user-selectable ``provider/model`` form."""
        return tuple(info.model_id for info in self.cached_prefixed_model_infos())

    def cached_prefixed_model_infos(self) -> tuple[ProviderModelInfo, ...]:
        """Return cached provider models with user-selectable prefixed ids.

        Providers appear in SUPPORTED_PROVIDER_IDS order; models are sorted
        within each provider for stable output.
        """
        infos: list[ProviderModelInfo] = []
        for provider_id in SUPPORTED_PROVIDER_IDS:
            provider_infos = self._model_infos_by_provider.get(provider_id, {})
            infos.extend(
                ProviderModelInfo(
                    model_id=f"{provider_id}/{info.model_id}",
                    supports_thinking=info.supports_thinking,
                )
                for info in sorted(
                    provider_infos.values(), key=lambda item: item.model_id
                )
            )
        return tuple(infos)

    async def refresh_model_list_cache(
        self, settings: Settings, *, only_missing: bool = False
    ) -> None:
        """Best-effort refresh of model lists for providers usable in this process.

        With ``only_missing=True``, providers that already have a cached list
        are skipped.
        """
        provider_ids = _model_list_provider_ids_for_settings(settings)
        if only_missing:
            provider_ids = tuple(
                provider_id
                for provider_id in provider_ids
                if provider_id not in self._model_ids_by_provider
            )
        await self._refresh_model_ids(settings, provider_ids)

    def start_model_list_refresh(self, settings: Settings) -> None:
        """Start a non-blocking cache warmup for missing eligible provider lists.

        No-op when a refresh task is already running or the cache is warm.
        """
        if (
            self._model_list_refresh_task is not None
            and not self._model_list_refresh_task.done()
        ):
            return
        provider_ids = tuple(
            provider_id
            for provider_id in _model_list_provider_ids_for_settings(settings)
            if provider_id not in self._model_ids_by_provider
        )
        if not provider_ids:
            logger.info(
                "Provider model discovery cache already warm: providers={}",
                len(self._model_ids_by_provider),
            )
            return
        self._model_list_refresh_task = asyncio.create_task(
            self._run_model_list_refresh(settings, provider_ids)
        )

    async def _run_model_list_refresh(
        self, settings: Settings, provider_ids: tuple[str, ...]
    ) -> None:
        """Wrapper for the background refresh: log failures, re-raise cancellation."""
        try:
            await self._refresh_model_ids(settings, provider_ids)
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            logger.warning(
                "Provider model discovery task failed: exc_type={}",
                type(exc).__name__,
            )

    async def _refresh_model_ids(
        self, settings: Settings, provider_ids: tuple[str, ...]
    ) -> None:
        """Query each provider's model list concurrently and cache the results.

        Per-provider failures are logged and skipped; cancellation propagates.
        """
        tasks: dict[str, asyncio.Task[frozenset[ProviderModelInfo]]] = {}
        for provider_id in provider_ids:
            try:
                provider = self.get(provider_id, settings)
            except Exception as exc:
                # Provider could not even be constructed (e.g. bad credential).
                _log_model_discovery_failure(provider_id, exc, settings)
                continue
            tasks[provider_id] = asyncio.create_task(provider.list_model_infos())
        if not tasks:
            return
        logger.info(
            "Starting model discovery for providers: {}", ", ".join(tasks.keys())
        )
        results = await asyncio.gather(*tasks.values(), return_exceptions=True)
        logger.info("Model discovery finished for all providers.")
        for provider_id, result in zip(tasks, results, strict=True):
            if isinstance(result, BaseException):
                if isinstance(result, asyncio.CancelledError):
                    raise result
                _log_model_discovery_failure(provider_id, result, settings)
                continue
            self.cache_model_infos(provider_id, result)
            logger.info(
                "Provider model discovery cached: provider={} models={}",
                provider_id,
                len(result),
            )

    async def validate_configured_models(self, settings: Settings) -> None:
        """Fail fast unless every configured chat model exists upstream.

        Raises :class:`ServiceUnavailableError` listing every failure
        (provider query errors and missing models) in one message.
        """
        refs = settings.configured_chat_model_refs()
        refs_by_provider: dict[str, list[ConfiguredChatModelRef]] = defaultdict(list)
        for ref in refs:
            refs_by_provider[ref.provider_id].append(ref)
        failures: list[str] = []
        tasks: dict[str, asyncio.Task[frozenset[ProviderModelInfo]]] = {}
        for provider_id, provider_refs in refs_by_provider.items():
            try:
                provider = self.get(provider_id, settings)
            except Exception as exc:
                # Construction failure condemns every ref on this provider.
                failures.extend(
                    _format_provider_query_failures(provider_refs, exc, settings)
                )
                continue
            tasks[provider_id] = asyncio.create_task(provider.list_model_infos())
        if tasks:
            results = await asyncio.gather(*tasks.values(), return_exceptions=True)
            for provider_id, result in zip(tasks, results, strict=True):
                provider_refs = refs_by_provider[provider_id]
                if isinstance(result, BaseException):
                    if isinstance(result, asyncio.CancelledError):
                        raise result
                    failures.extend(
                        _format_provider_query_failures(provider_refs, result, settings)
                    )
                    continue
                self.cache_model_infos(provider_id, result)
                model_ids = self._model_ids_by_provider[provider_id]
                failures.extend(
                    _format_missing_model_failure(ref)
                    for ref in provider_refs
                    if ref.model_id not in model_ids
                )
        if failures:
            message = "Configured model validation failed:\n" + "\n".join(
                f"- {failure}" for failure in failures
            )
            raise ServiceUnavailableError(message)
        logger.info(
            "Configured provider models validated: models={} providers={}",
            len(refs),
            len(refs_by_provider),
        )

    async def cleanup(self) -> None:
        """Call ``cleanup`` on every cached provider, then clear the cache.

        Attempts all providers even if one fails. A single failure is re-raised
        as-is; multiple failures are wrapped in :exc:`ExceptionGroup`. Any
        in-flight model-list refresh task is cancelled first.
        """
        if (
            self._model_list_refresh_task is not None
            and not self._model_list_refresh_task.done()
        ):
            self._model_list_refresh_task.cancel()
            with suppress(asyncio.CancelledError):
                await self._model_list_refresh_task
        items = list(self._providers.items())
        errors: list[Exception] = []
        try:
            for _pid, provider in items:
                try:
                    await provider.cleanup()
                except Exception as e:
                    errors.append(e)
        finally:
            # Always clear caches, even when a provider's cleanup raised.
            self._providers.clear()
            self._model_ids_by_provider.clear()
            self._model_infos_by_provider.clear()
        if len(errors) == 1:
            raise errors[0]
        if len(errors) > 1:
            msg = "One or more provider cleanups failed"
            raise ExceptionGroup(msg, errors)