Spaces:
Running
Running
| """Base provider interface - extend this to implement your own provider.""" | |
| from abc import ABC, abstractmethod | |
| from collections.abc import AsyncIterator | |
| from typing import Any | |
| from pydantic import BaseModel | |
| from config.constants import HTTP_CONNECT_TIMEOUT_DEFAULT | |
| from providers.model_listing import ProviderModelInfo, model_infos_from_ids | |
| class ProviderConfig(BaseModel): | |
| """Configuration for a provider. | |
| Base fields apply to all providers. Provider-specific parameters | |
| (e.g. NIM temperature, top_p) are passed by the provider constructor. | |
| """ | |
| api_key: str | |
| base_url: str | None = None | |
| rate_limit: int | None = None | |
| rate_window: int = 60 | |
| max_concurrency: int = 5 | |
| http_read_timeout: float = 300.0 | |
| http_write_timeout: float = 10.0 | |
| http_connect_timeout: float = HTTP_CONNECT_TIMEOUT_DEFAULT | |
| enable_thinking: bool = True | |
| proxy: str = "" | |
| log_raw_sse_events: bool = False | |
| log_api_error_tracebacks: bool = False | |
| class BaseProvider(ABC): | |
| """Base class for all providers. Extend this to add your own.""" | |
| def __init__(self, config: ProviderConfig): | |
| self._config = config | |
| def _is_thinking_enabled( | |
| self, request: Any, thinking_enabled: bool | None = None | |
| ) -> bool: | |
| """Return whether thinking should be enabled for this request.""" | |
| thinking = getattr(request, "thinking", None) | |
| config_enabled = ( | |
| self._config.enable_thinking | |
| if thinking_enabled is None | |
| else thinking_enabled | |
| ) | |
| request_enabled = True | |
| if thinking is not None: | |
| thinking_type = ( | |
| thinking.get("type") | |
| if isinstance(thinking, dict) | |
| else getattr(thinking, "type", None) | |
| ) | |
| if thinking_type == "disabled": | |
| request_enabled = False | |
| enabled = ( | |
| thinking.get("enabled") | |
| if isinstance(thinking, dict) | |
| else getattr(thinking, "enabled", None) | |
| ) | |
| if enabled is not None: | |
| request_enabled = bool(enabled) | |
| return config_enabled and request_enabled | |
| def preflight_stream( | |
| self, request: Any, *, thinking_enabled: bool | None = None | |
| ) -> None: | |
| """Eagerly validate/build the upstream request before opening an SSE stream. | |
| Subclasses with ``_build_request_body`` (OpenAI and native) raise | |
| :class:`providers.exceptions.InvalidRequestError` on conversion failures. | |
| """ | |
| build = getattr(self, "_build_request_body", None) | |
| if build is None: | |
| return | |
| build(request, thinking_enabled=thinking_enabled) | |
| def _log_stream_transport_error( | |
| self, tag: str, req_tag: str, error: Exception | |
| ) -> None: | |
| """Log streaming transport failures (metadata-only unless verbose is enabled).""" | |
| from loguru import logger | |
| if self._config.log_api_error_tracebacks: | |
| logger.error( | |
| "{}_ERROR:{} {}: {}", tag, req_tag, type(error).__name__, error | |
| ) | |
| return | |
| response = getattr(error, "response", None) | |
| status_code = ( | |
| getattr(response, "status_code", None) if response is not None else None | |
| ) | |
| logger.error( | |
| "{}_ERROR:{} exc_type={} http_status={}", | |
| tag, | |
| req_tag, | |
| type(error).__name__, | |
| status_code, | |
| ) | |
| async def cleanup(self) -> None: | |
| """Release any resources held by this provider.""" | |
| async def list_model_ids(self) -> frozenset[str]: | |
| """Return the model ids currently advertised by this provider.""" | |
| async def list_model_infos(self) -> frozenset[ProviderModelInfo]: | |
| """Return advertised model ids with optional provider capability metadata.""" | |
| return model_infos_from_ids(await self.list_model_ids()) | |
| async def stream_response( | |
| self, | |
| request: Any, | |
| input_tokens: int = 0, | |
| *, | |
| request_id: str | None = None, | |
| thinking_enabled: bool | None = None, | |
| ) -> AsyncIterator[str]: | |
| """Stream response in Anthropic SSE format.""" | |
| # Typing: abstract async generators need a yield for AsyncIterator[str] | |
| # inference; this branch is never executed. | |
| if False: | |
| yield "" | |