Yash030's picture
Deploy claude-code-nvidia proxy to Hugging Face Spaces
0157ac7
"""Base provider interface - extend this to implement your own provider."""
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import Any
from pydantic import BaseModel
from config.constants import HTTP_CONNECT_TIMEOUT_DEFAULT
from providers.model_listing import ProviderModelInfo, model_infos_from_ids
class ProviderConfig(BaseModel):
    """Settings shared by every provider implementation.

    Only provider-agnostic fields are declared here. Parameters that belong
    to a single provider (e.g. NIM temperature, top_p) are passed by that
    provider's constructor.
    """

    # --- Credentials and endpoint -------------------------------------
    # The only required field; every other field has a default.
    api_key: str
    base_url: str | None = None

    # --- Throttling ---------------------------------------------------
    # NOTE(review): presumably "rate_limit requests per rate_window
    # seconds", with None meaning no limit -- confirm against the limiter.
    rate_limit: int | None = None
    rate_window: int = 60
    max_concurrency: int = 5

    # --- HTTP timeouts ------------------------------------------------
    # NOTE(review): units are presumably seconds -- confirm with the
    # HTTP client that consumes them.
    http_read_timeout: float = 300.0
    http_write_timeout: float = 10.0
    http_connect_timeout: float = HTTP_CONNECT_TIMEOUT_DEFAULT

    # --- Behaviour switches -------------------------------------------
    # Provider-wide default consulted by BaseProvider._is_thinking_enabled
    # when the caller does not pass an explicit override.
    enable_thinking: bool = True
    proxy: str = ""

    # --- Diagnostics --------------------------------------------------
    log_raw_sse_events: bool = False
    # When True, BaseProvider._log_stream_transport_error also logs the
    # exception's string form instead of metadata only.
    log_api_error_tracebacks: bool = False
class BaseProvider(ABC):
    """Abstract foundation for providers; subclass it to add a new one.

    Concrete providers must implement :meth:`cleanup`,
    :meth:`list_model_ids` and :meth:`stream_response`. The remaining
    methods are shared helpers built on the stored ``ProviderConfig``.
    """

    def __init__(self, config: ProviderConfig):
        # Kept on the instance; the helper methods below read it.
        self._config = config

    def _is_thinking_enabled(
        self, request: Any, thinking_enabled: bool | None = None
    ) -> bool:
        """Decide whether thinking is active for ``request``.

        Both the provider side and the request side must agree.

        Args:
            request: Object that may carry a ``thinking`` attribute, which
                can be either a dict or an object exposing ``type`` /
                ``enabled``.
            thinking_enabled: Explicit override for the provider side; when
                ``None`` the configured ``enable_thinking`` value is used.

        Returns:
            The provider-side flag combined (``and``) with the request-side
            flag.
        """
        requested = getattr(request, "thinking", None)

        if thinking_enabled is None:
            allowed_by_config = self._config.enable_thinking
        else:
            allowed_by_config = thinking_enabled

        # A request without a ``thinking`` section does not opt out.
        allowed_by_request = True
        if requested is not None:

            def lookup(field: str) -> Any:
                # ``thinking`` may arrive as a plain dict or as an object.
                if isinstance(requested, dict):
                    return requested.get(field)
                return getattr(requested, field, None)

            if lookup("type") == "disabled":
                allowed_by_request = False

            # NOTE(review): an explicit ``enabled`` value is applied last,
            # so it wins even over type == "disabled" -- confirm intended.
            explicit = lookup("enabled")
            if explicit is not None:
                allowed_by_request = bool(explicit)

        return allowed_by_config and allowed_by_request

    def preflight_stream(
        self, request: Any, *, thinking_enabled: bool | None = None
    ) -> None:
        """Build the upstream request eagerly, before an SSE stream is opened.

        If the instance defines ``_build_request_body`` it is invoked with
        the request and its result is discarded; any exception it raises
        propagates to the caller. Subclasses with ``_build_request_body``
        (OpenAI and native) are expected to raise
        :class:`providers.exceptions.InvalidRequestError` on conversion
        failures. Instances without that attribute make this a no-op.
        """
        builder = getattr(self, "_build_request_body", None)
        if builder is not None:
            builder(request, thinking_enabled=thinking_enabled)

    def _log_stream_transport_error(
        self, tag: str, req_tag: str, error: Exception
    ) -> None:
        """Emit an error log entry for a failed streaming transport.

        With ``log_api_error_tracebacks`` enabled the entry contains the
        exception type name and the exception's string form. Otherwise only
        metadata is logged: the exception type name and the
        ``status_code`` of ``error.response`` (``None`` when unavailable).
        """
        from loguru import logger

        exc_name = type(error).__name__

        if self._config.log_api_error_tracebacks:
            # Verbose mode: include the exception text itself.
            logger.error("{}_ERROR:{} {}: {}", tag, req_tag, exc_name, error)
        else:
            response = getattr(error, "response", None)
            if response is None:
                status_code = None
            else:
                status_code = getattr(response, "status_code", None)
            logger.error(
                "{}_ERROR:{} exc_type={} http_status={}",
                tag,
                req_tag,
                exc_name,
                status_code,
            )

    @abstractmethod
    async def cleanup(self) -> None:
        """Free whatever resources this provider is holding."""

    @abstractmethod
    async def list_model_ids(self) -> frozenset[str]:
        """Return the set of model ids this provider currently advertises."""

    async def list_model_infos(self) -> frozenset[ProviderModelInfo]:
        """Return advertised models with optional capability metadata.

        The default implementation derives the infos from
        :meth:`list_model_ids` via ``model_infos_from_ids``.
        """
        advertised_ids = await self.list_model_ids()
        return model_infos_from_ids(advertised_ids)

    @abstractmethod
    async def stream_response(
        self,
        request: Any,
        input_tokens: int = 0,
        *,
        request_id: str | None = None,
        thinking_enabled: bool | None = None,
    ) -> AsyncIterator[str]:
        """Stream the response as Anthropic-format SSE strings."""
        # The unreachable ``yield`` below makes this abstract method an
        # async generator so the AsyncIterator[str] annotation type-checks;
        # it never runs.
        if False:
            yield ""