"""Base provider interface - extend this to implement your own provider."""
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import Any
from pydantic import BaseModel
from config.constants import HTTP_CONNECT_TIMEOUT_DEFAULT
from providers.model_listing import ProviderModelInfo, model_infos_from_ids
class ProviderConfig(BaseModel):
    """Configuration for a provider.

    Base fields apply to all providers. Provider-specific parameters
    (e.g. NIM temperature, top_p) are passed by the provider constructor.
    """

    api_key: str  # upstream API credential (required)
    base_url: str | None = None  # override the provider's default endpoint
    rate_limit: int | None = None  # max requests per rate_window; None = unlimited
    rate_window: int = 60  # window length in seconds; pairs with rate_limit
    max_concurrency: int = 5  # cap on simultaneous in-flight requests
    http_read_timeout: float = 300.0  # generous read timeout for long SSE streams
    http_write_timeout: float = 10.0
    http_connect_timeout: float = HTTP_CONNECT_TIMEOUT_DEFAULT
    enable_thinking: bool = True  # default; per-request `thinking` field can override
    proxy: str = ""  # proxy URL; empty string means no proxy
    log_raw_sse_events: bool = False  # verbose: log every raw SSE event
    log_api_error_tracebacks: bool = False  # verbose: full exception detail in logs
class BaseProvider(ABC):
    """Base class for all providers. Extend this to add your own."""

    def __init__(self, config: ProviderConfig):
        self._config = config

    def _is_thinking_enabled(
        self, request: Any, thinking_enabled: bool | None = None
    ) -> bool:
        """Return whether thinking should be enabled for this request.

        Combines the config-level default (or the explicit ``thinking_enabled``
        override) with the request's ``thinking`` field, which may be either a
        dict or an object exposing ``type``/``enabled``.
        """
        if thinking_enabled is None:
            thinking_enabled = self._config.enable_thinking

        thinking = getattr(request, "thinking", None)
        if thinking is None:
            # No per-request preference: the config-level setting decides.
            return thinking_enabled

        def field(name: str) -> Any:
            # Read a field from either a dict-shaped or attr-shaped payload.
            return (
                thinking.get(name)
                if isinstance(thinking, dict)
                else getattr(thinking, name, None)
            )

        request_allows = field("type") != "disabled"
        explicit = field("enabled")
        if explicit is not None:
            # An explicit enabled flag takes precedence over the type field.
            request_allows = bool(explicit)
        return thinking_enabled and request_allows

    def preflight_stream(
        self, request: Any, *, thinking_enabled: bool | None = None
    ) -> None:
        """Eagerly validate/build the upstream request before opening an SSE stream.

        Subclasses with ``_build_request_body`` (OpenAI and native) raise
        :class:`providers.exceptions.InvalidRequestError` on conversion failures.
        """
        builder = getattr(self, "_build_request_body", None)
        if builder is not None:
            builder(request, thinking_enabled=thinking_enabled)

    def _log_stream_transport_error(
        self, tag: str, req_tag: str, error: Exception
    ) -> None:
        """Log streaming transport failures (metadata-only unless verbose is enabled)."""
        from loguru import logger

        if self._config.log_api_error_tracebacks:
            # Verbose mode: include the exception's own message.
            logger.error(
                "{}_ERROR:{} {}: {}", tag, req_tag, type(error).__name__, error
            )
        else:
            # Metadata-only: exception type plus HTTP status when available.
            response = getattr(error, "response", None)
            status_code = (
                None if response is None else getattr(response, "status_code", None)
            )
            logger.error(
                "{}_ERROR:{} exc_type={} http_status={}",
                tag,
                req_tag,
                type(error).__name__,
                status_code,
            )

    @abstractmethod
    async def cleanup(self) -> None:
        """Release any resources held by this provider."""

    @abstractmethod
    async def list_model_ids(self) -> frozenset[str]:
        """Return the model ids currently advertised by this provider."""

    async def list_model_infos(self) -> frozenset[ProviderModelInfo]:
        """Return advertised model ids with optional provider capability metadata."""
        ids = await self.list_model_ids()
        return model_infos_from_ids(ids)

    @abstractmethod
    async def stream_response(
        self,
        request: Any,
        input_tokens: int = 0,
        *,
        request_id: str | None = None,
        thinking_enabled: bool | None = None,
    ) -> AsyncIterator[str]:
        """Stream response in Anthropic SSE format."""
        # Typing: an abstract async generator needs a yield somewhere so the
        # AsyncIterator[str] return type is inferred; this branch never runs.
        if False:
            yield ""