"""Base provider interface - extend this to implement your own provider."""

from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import Any

from pydantic import BaseModel

from config.constants import HTTP_CONNECT_TIMEOUT_DEFAULT
from providers.model_listing import ProviderModelInfo, model_infos_from_ids


class ProviderConfig(BaseModel):
    """Configuration for a provider.

    Base fields apply to all providers. Provider-specific parameters
    (e.g. NIM temperature, top_p) are passed by the provider constructor.
    """

    # Credential for the upstream provider API; the only required field.
    api_key: str
    # Override for the provider's default endpoint; None keeps the default.
    base_url: str | None = None
    # Max requests permitted per rate_window; None disables client-side
    # rate limiting entirely.
    rate_limit: int | None = None
    # Rate-limit window length, in seconds.
    rate_window: int = 60
    # Cap on concurrent in-flight requests to this provider.
    max_concurrency: int = 5
    # HTTP timeouts in seconds. The read timeout is generous because SSE
    # streams can sit idle between events for a long time.
    http_read_timeout: float = 300.0
    http_write_timeout: float = 10.0
    http_connect_timeout: float = HTTP_CONNECT_TIMEOUT_DEFAULT
    # Config-level default for extended thinking; individual requests can
    # still opt out (see BaseProvider._is_thinking_enabled).
    enable_thinking: bool = True
    # Outbound proxy URL; empty string means no proxy is configured.
    proxy: str = ""
    # Debug verbosity switches — both default off so logs stay metadata-only.
    log_raw_sse_events: bool = False
    log_api_error_tracebacks: bool = False


class BaseProvider(ABC):
    """Base class for all providers. Extend this to add your own.

    Subclasses must implement :meth:`cleanup`, :meth:`list_model_ids`, and
    :meth:`stream_response`; the remaining methods are shared plumbing for
    thinking-mode resolution, preflight validation, and error logging.
    """

    def __init__(self, config: ProviderConfig):
        # Shared provider configuration (timeouts, limits, debug switches).
        self._config = config

    def _is_thinking_enabled(
        self, request: Any, thinking_enabled: bool | None = None
    ) -> bool:
        """Return whether thinking should be enabled for this request.

        The result is the conjunction of two settings:

        * config level — the explicit ``thinking_enabled`` override when
          given, otherwise ``self._config.enable_thinking``;
        * request level — derived from ``request.thinking`` (a dict or an
          object exposing ``type`` / ``enabled``). ``type == "disabled"``
          turns thinking off, but an explicit boolean ``enabled`` value,
          when present, takes precedence over ``type``.
        """
        thinking = getattr(request, "thinking", None)
        config_enabled = (
            self._config.enable_thinking
            if thinking_enabled is None
            else thinking_enabled
        )
        request_enabled = True
        if thinking is not None:
            # `thinking` may arrive as a plain dict or a typed object;
            # read fields uniformly from either shape.
            def _field(name: str) -> Any:
                return (
                    thinking.get(name)
                    if isinstance(thinking, dict)
                    else getattr(thinking, name, None)
                )

            if _field("type") == "disabled":
                request_enabled = False
            # An explicit boolean `enabled` wins over `type`.
            enabled = _field("enabled")
            if enabled is not None:
                request_enabled = bool(enabled)
        return config_enabled and request_enabled

    def preflight_stream(
        self, request: Any, *, thinking_enabled: bool | None = None
    ) -> None:
        """Eagerly validate/build the upstream request before opening an SSE stream.

        Subclasses with ``_build_request_body`` (OpenAI and native) raise
        :class:`providers.exceptions.InvalidRequestError` on conversion failures.
        Providers without that hook pass preflight trivially.
        """
        build = getattr(self, "_build_request_body", None)
        if build is None:
            return
        # Result is discarded — only the validation side effect matters here.
        build(request, thinking_enabled=thinking_enabled)

    def _log_stream_transport_error(
        self, tag: str, req_tag: str, error: Exception
    ) -> None:
        """Log streaming transport failures (metadata-only unless verbose is enabled).

        With ``log_api_error_tracebacks`` enabled, the full exception —
        including its traceback — is logged; otherwise only exception type
        and HTTP status (when the error carries a response) are recorded.
        """
        # Imported lazily to keep loguru off the module import path.
        from loguru import logger

        if self._config.log_api_error_tracebacks:
            # Fix: attach the exception so loguru actually emits the
            # traceback the config flag promises (plain logger.error
            # would record only the message text).
            logger.opt(exception=error).error(
                "{}_ERROR:{} {}: {}", tag, req_tag, type(error).__name__, error
            )
            return
        # Metadata-only path: never log the error message body, just the
        # exception type and (if available) the upstream HTTP status.
        response = getattr(error, "response", None)
        status_code = (
            getattr(response, "status_code", None) if response is not None else None
        )
        logger.error(
            "{}_ERROR:{} exc_type={} http_status={}",
            tag,
            req_tag,
            type(error).__name__,
            status_code,
        )

    @abstractmethod
    async def cleanup(self) -> None:
        """Release any resources held by this provider."""

    @abstractmethod
    async def list_model_ids(self) -> frozenset[str]:
        """Return the model ids currently advertised by this provider."""

    async def list_model_infos(self) -> frozenset[ProviderModelInfo]:
        """Return advertised model ids with optional provider capability metadata.

        Default implementation wraps :meth:`list_model_ids`; providers with
        richer metadata should override this.
        """
        return model_infos_from_ids(await self.list_model_ids())

    @abstractmethod
    async def stream_response(
        self,
        request: Any,
        input_tokens: int = 0,
        *,
        request_id: str | None = None,
        thinking_enabled: bool | None = None,
    ) -> AsyncIterator[str]:
        """Stream response in Anthropic SSE format."""
        # Typing: abstract async generators need a yield for AsyncIterator[str]
        # inference; this branch is never executed.
        if False:
            yield ""