| """OpenAI-style chat base for :class:`OpenAIChatTransport` (NIM, etc.). | |
| ``AnthropicMessagesTransport``-based providers (OpenRouter, LM Studio, DeepSeek, …) live | |
| in separate modules; do not list them as subclasses of this class. | |
| """ | |
| import asyncio | |
| import json | |
| import uuid | |
| from abc import abstractmethod | |
| from collections.abc import AsyncIterator, Iterator | |
| from typing import Any | |
| import httpx | |
| from loguru import logger | |
| from openai import AsyncOpenAI | |
| from core.anthropic import ( | |
| ContentType, | |
| HeuristicToolParser, | |
| SSEBuilder, | |
| ThinkTagParser, | |
| append_request_id, | |
| map_stop_reason, | |
| ) | |
| from providers.base import BaseProvider, ProviderConfig | |
| from providers.error_mapping import ( | |
| map_error, | |
| user_visible_message_for_mapped_provider_error, | |
| ) | |
| from providers.model_listing import ( | |
| ProviderModelInfo, | |
| extract_openai_model_ids, | |
| model_infos_from_ids, | |
| ) | |
| from providers.rate_limit import GlobalRateLimiter | |

def _iter_heuristic_tool_use_sse(
    sse: SSEBuilder, tool_use: dict[str, Any]
) -> Iterator[str]:
    """Emit SSE for one heuristic tool_use block (closes open text/thinking first)."""
    if tool_use.get("name") == "Task" and isinstance(tool_use.get("input"), dict):
        task_input = tool_use["input"]
        if task_input.get("run_in_background") is not False:
            task_input["run_in_background"] = False
    yield from sse.close_content_blocks()
    block_idx = sse.blocks.allocate_index()
    yield sse.content_block_start(
        block_idx,
        "tool_use",
        id=tool_use["id"],
        name=tool_use["name"],
    )
    yield sse.content_block_delta(
        block_idx,
        "input_json_delta",
        json.dumps(tool_use["input"]),
    )
    yield sse.content_block_stop(block_idx)
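
# Illustrative sketch (not part of the runtime path; values made up): for a heuristic
# tool_use such as
#   {"id": "tool_abc", "name": "Read", "input": {"file_path": "notes.txt"}}
# the generator above yields, in order: any events needed to close open text/thinking
# blocks, then content_block_start (type "tool_use"), a single input_json_delta carrying
# the full JSON-encoded input, and finally content_block_stop for the allocated index.
# Task tool calls are additionally forced to run_in_background=False before emission.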


class OpenAIChatTransport(BaseProvider):
    """Base for OpenAI-compatible ``/chat/completions`` adapters (NIM, …)."""

    def __init__(
        self,
        config: ProviderConfig,
        *,
        provider_name: str,
        base_url: str,
        api_key: str,
        nim_rate_limit: int = 100,
        nim_max_concurrency: int = 40,
    ):
        super().__init__(config)
        self._provider_name = provider_name
        self._api_key = api_key
        self._base_url = base_url.rstrip("/")
        self._http_client = None
        self._client_cache: dict[str, AsyncOpenAI] = {}
        # NIM gets an adaptive rate limit starting at 100 req/min (leaves headroom);
        # Zen is effectively unlimited (9999).
        if provider_name.lower() == "zen":
            effective_rate_limit = 9999
            effective_max_concurrency = config.max_concurrency * 4
            use_adaptive = None
        else:
            effective_rate_limit = nim_rate_limit
            effective_max_concurrency = max(
                nim_max_concurrency, config.max_concurrency * 4
            )
            use_adaptive = nim_rate_limit
        self._global_rate_limiter = GlobalRateLimiter.get_scoped_instance(
            provider_name.lower(),
            rate_limit=effective_rate_limit,
            rate_window=config.rate_window,
            max_concurrency=effective_max_concurrency,
            adaptive_rate=use_adaptive,
            adaptive_min_rate=10,
        )
        # Connection pool tuned for maximum throughput: increased keepalive and
        # connection counts for high concurrency.
        http_client_args = {
            "timeout": httpx.Timeout(
                config.http_read_timeout,
                connect=config.http_connect_timeout,
                read=config.http_read_timeout,
                write=config.http_write_timeout,
            ),
            "trust_env": False,
            "http2": True,
            "limits": httpx.Limits(
                max_keepalive_connections=100,
                max_connections=500,
                keepalive_expiry=5.0,
            ),
        }
        if config.proxy:
            http_client_args["proxy"] = config.proxy
        self._http_client = httpx.AsyncClient(**http_client_args)
        self._client = AsyncOpenAI(
            api_key=self._api_key,
            base_url=self._base_url,
            max_retries=0,
            timeout=httpx.Timeout(
                config.http_read_timeout,
                connect=config.http_connect_timeout,
                read=config.http_read_timeout,
                write=config.http_write_timeout,
            ),
            http_client=self._http_client,
        )
        self._client_cache[self._api_key] = self._client
        if not self._api_key:
            logger.warning(
                "OpenAIChatTransport initialized with EMPTY API key: provider={} base_url={}",
                provider_name,
                base_url,
            )
        else:
            logger.info(
                "OpenAIChatTransport initialized: provider={} base_url={} api_key_set={} key_prefix={}",
                provider_name,
                base_url,
                bool(self._api_key),
                self._api_key[:10] if self._api_key else "EMPTY",
            )
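
    # Hedged example (hypothetical subclass, not defined in this module): a concrete
    # NIM-style adapter would typically just forward its settings through this
    # constructor and override the request-building hooks below, e.g.
    #
    #   super().__init__(
    #       config,
    #       provider_name="nim",
    #       base_url="https://nim.example.invalid/v1",  # placeholder endpoint
    #       api_key=api_key,
    #   )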

    def _client_for_api_key(self, api_key: str) -> AsyncOpenAI:
        """Return a cached OpenAI client for the given API key."""
        if api_key == self._api_key:
            return self._client
        client = self._client_cache.get(api_key)
        if client is not None:
            return client
        client = AsyncOpenAI(
            api_key=api_key,
            base_url=self._base_url,
            max_retries=0,
            timeout=httpx.Timeout(
                self._config.http_read_timeout,
                connect=self._config.http_connect_timeout,
                read=self._config.http_read_timeout,
                write=self._config.http_write_timeout,
            ),
            http_client=self._http_client,
        )
        self._client_cache[api_key] = client
        return client

    async def cleanup(self) -> None:
        """Release HTTP client resources."""
        seen: set[int] = set()
        for client in list(self._client_cache.values()):
            client_id = id(client)
            if client_id in seen:
                continue
            seen.add(client_id)
            # AsyncOpenAI exposes close(), not aclose(); it closes the shared httpx client.
            await client.close()

    async def list_model_infos(self) -> frozenset[ProviderModelInfo]:
        """Return model metadata from the provider's OpenAI-compatible models endpoint."""
        model_ids = await self.list_model_ids()
        # Default all models to supports_thinking=None (unknown) unless the provider overrides.
        return model_infos_from_ids(model_ids, supports_thinking=None)

    async def list_model_ids(self) -> frozenset[str]:
        """Return model ids from the provider's OpenAI-compatible models endpoint."""
        payload = await self._client.models.list()
        return extract_openai_model_ids(payload, provider_name=self._provider_name)

    @abstractmethod
    def _build_request_body(
        self, request: Any, thinking_enabled: bool | None = None
    ) -> dict:
        """Build the request body. Must be implemented by subclasses."""

    def _handle_extra_reasoning(
        self, delta: Any, sse: SSEBuilder, *, thinking_enabled: bool
    ) -> Iterator[str]:
        """Hook for provider-specific reasoning (e.g. OpenRouter reasoning_details)."""
        return iter(())

    def _get_retry_request_body(self, error: Exception, body: dict) -> dict | None:
        """Return a modified request body for one retry, or None."""
        return None
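
    # Hedged sketch of the subclass contract (helper names below are hypothetical and
    # not part of this module): an adapter overrides _build_request_body to map the
    # incoming Anthropic-style request onto an OpenAI /chat/completions payload, and may
    # override _get_retry_request_body to strip an option the upstream rejected, e.g.
    #
    #   def _build_request_body(self, request, thinking_enabled=None):
    #       return {
    #           "model": request.model,
    #           "messages": convert_messages(request),  # hypothetical helper
    #           "tools": convert_tools(request),        # hypothetical helper
    #           "max_tokens": request.max_tokens,
    #       }
    #
    #   def _get_retry_request_body(self, error, body):
    #       if "tools" in body and "tool" in str(error).lower():
    #           retry = dict(body)
    #           retry.pop("tools", None)  # retry once without tool definitions
    #           return retry
    #       return None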

    async def _create_stream(self, body: dict) -> tuple[Any, dict]:
        """Create a streaming chat completion, optionally retrying once."""
        try:
            logger.info(
                "{}_CREATE_STREAM: calling API with model={}",
                self._provider_name,
                body.get("model"),
            )
            stream = await self._global_rate_limiter.execute_with_retry(
                self._client.chat.completions.create,
                **body,
                stream=True,
                max_retries=1,
            )
            return stream, body
        except Exception as error:
            logger.error(
                "{}_CREATE_STREAM_ERROR: {} - {} - response={}",
                self._provider_name,
                type(error).__name__,
                str(error),
                getattr(error, "response", None),
            )
            retry_body = self._get_retry_request_body(error, body)
            if retry_body is None:
                raise
            stream = await self._global_rate_limiter.execute_with_retry(
                self._client.chat.completions.create,
                **retry_body,
                stream=True,
                max_retries=1,
            )
            return stream, retry_body

    def _emit_tool_arg_delta(
        self, sse: SSEBuilder, tc_index: int, args: str
    ) -> Iterator[str]:
        """Emit one argument fragment for a started tool block (Task buffer or raw JSON)."""
        if not args:
            return
        state = sse.blocks.tool_states.get(tc_index)
        if state is None:
            return
        if state.name == "Task":
            parsed = sse.blocks.buffer_task_args(tc_index, args)
            if parsed is not None:
                yield sse.emit_tool_delta(tc_index, json.dumps(parsed))
            return
        yield sse.emit_tool_delta(tc_index, args)

    def _process_tool_call(self, tc: dict, sse: SSEBuilder) -> Iterator[str]:
        """Process a single tool call delta and yield SSE events."""
        tc_index = tc.get("index", 0)
        if tc_index < 0:
            tc_index = len(sse.blocks.tool_states)
        fn_delta = tc.get("function", {})
        incoming_name = fn_delta.get("name")
        arguments = fn_delta.get("arguments", "") or ""
        if tc.get("id") is not None:
            sse.blocks.set_stream_tool_id(tc_index, tc.get("id"))
        if incoming_name is not None:
            sse.blocks.register_tool_name(tc_index, incoming_name)
        state = sse.blocks.tool_states.get(tc_index)
        resolved_id = (state.tool_id if state and state.tool_id else None) or tc.get(
            "id"
        )
        resolved_name = (state.name if state else "") or ""
        if not state or not state.started:
            name_ok = bool((resolved_name or "").strip())
            if name_ok:
                tool_id = str(resolved_id) if resolved_id else f"tool_{uuid.uuid4()}"
                display_name = (resolved_name or "").strip() or "tool_call"
                yield sse.start_tool_block(tc_index, tool_id, display_name)
                state = sse.blocks.tool_states[tc_index]
                if state.pre_start_args:
                    pre = state.pre_start_args
                    state.pre_start_args = ""
                    yield from self._emit_tool_arg_delta(sse, tc_index, pre)
                state = sse.blocks.tool_states.get(tc_index)
        if not arguments:
            return
        if state is None or not state.started:
            state = sse.blocks.ensure_tool_state(tc_index)
            if not (resolved_name or "").strip():
                state.pre_start_args += arguments
                return
        yield from self._emit_tool_arg_delta(sse, tc_index, arguments)
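
    # Illustrative delta sequence (values made up): OpenAI-style streams typically send
    # the tool call id and name in the first fragment and only argument text afterwards:
    #   {"index": 0, "id": "call_1", "function": {"name": "Bash", "arguments": ""}}
    #   {"index": 0, "id": None, "function": {"name": None, "arguments": "{\"command\""}}
    #   {"index": 0, "id": None, "function": {"name": None, "arguments": ": \"ls\"}"}}
    # _process_tool_call starts the tool block once a non-empty name is known and buffers
    # any argument fragments that arrive before that point in pre_start_args.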

    def _flush_task_arg_buffers(self, sse: SSEBuilder) -> Iterator[str]:
        """Emit buffered Task args as a single JSON delta (best-effort)."""
        for tool_index, out in sse.blocks.flush_task_arg_buffers():
            yield sse.emit_tool_delta(tool_index, out)

    async def stream_response(
        self,
        request: Any,
        input_tokens: int = 0,
        *,
        request_id: str | None = None,
        thinking_enabled: bool | None = None,
    ) -> AsyncIterator[str]:
        """Stream the response in Anthropic SSE format."""
        with logger.contextualize(request_id=request_id):
            async for event in self._stream_response_impl(
                request, input_tokens, request_id, thinking_enabled=thinking_enabled
            ):
                yield event

    async def _stream_response_impl(
        self,
        request: Any,
        input_tokens: int,
        request_id: str | None,
        *,
        thinking_enabled: bool | None,
    ) -> AsyncIterator[str]:
        """Shared streaming implementation."""
        tag = self._provider_name
        message_id = f"msg_{uuid.uuid4()}"
        sse = SSEBuilder(
            message_id,
            request.model,
            input_tokens,
            log_raw_events=self._config.log_raw_sse_events,
        )
        body = self._build_request_body(request, thinking_enabled=thinking_enabled)
        thinking_enabled = self._is_thinking_enabled(request, thinking_enabled)
        req_tag = f" request_id={request_id}" if request_id else ""
        # Log the actual model being sent to the provider for debugging.
        actual_model = body.get("model", "")
        logger.info(
            "{}_STREAM:{} model={} msgs={} tools={} actual_model_to_provider={}",
            tag,
            req_tag,
            body.get("model"),
            len(body.get("messages", [])),
            len(body.get("tools", [])),
            actual_model,
        )
        think_parser = ThinkTagParser()
        heuristic_parser = HeuristicToolParser()
        finish_reason = None
        usage_info = None
        async with self._global_rate_limiter.concurrency_slot():
            try:
                yield sse.message_start()
                stream, body = await self._create_stream(body)
                async for chunk in stream:
                    if getattr(chunk, "usage", None):
                        usage_info = chunk.usage
                    if not chunk.choices:
                        continue
                    choice = chunk.choices[0]
                    delta = choice.delta
                    if delta is None:
                        continue
                    if choice.finish_reason:
                        finish_reason = choice.finish_reason
                        logger.debug("{} finish_reason: {}", tag, finish_reason)
                    # Handle reasoning_content (OpenAI extended format)
                    reasoning = getattr(delta, "reasoning_content", None)
                    if thinking_enabled and reasoning:
                        for event in sse.ensure_thinking_block():
                            yield event
                        yield sse.emit_thinking_delta(reasoning)
                    # Provider-specific extra reasoning (e.g. OpenRouter reasoning_details)
                    for event in self._handle_extra_reasoning(
                        delta,
                        sse,
                        thinking_enabled=thinking_enabled,
                    ):
                        yield event
                    # Handle text content
                    if delta.content:
                        for part in think_parser.feed(delta.content):
                            if part.type == ContentType.THINKING:
                                if not thinking_enabled:
                                    continue
                                for event in sse.ensure_thinking_block():
                                    yield event
                                yield sse.emit_thinking_delta(part.content)
                            else:
                                filtered_text, detected_tools = heuristic_parser.feed(
                                    part.content
                                )
                                if filtered_text:
                                    for event in sse.ensure_text_block():
                                        yield event
                                    yield sse.emit_text_delta(filtered_text)
                                for tool_use in detected_tools:
                                    for event in _iter_heuristic_tool_use_sse(
                                        sse, tool_use
                                    ):
                                        yield event
                    # Handle native tool calls
                    if delta.tool_calls:
                        for event in sse.close_content_blocks():
                            yield event
                        for tc in delta.tool_calls:
                            tc_info = {
                                "index": tc.index,
                                "id": tc.id,
                                "function": {
                                    "name": tc.function.name,
                                    "arguments": tc.function.arguments,
                                },
                            }
                            for event in self._process_tool_call(tc_info, sse):
                                yield event
            except asyncio.CancelledError:
                raise
            except Exception as e:
                self._log_stream_transport_error(tag, req_tag, e)
                mapped_e = map_error(e, rate_limiter=self._global_rate_limiter)
                has_started_tool = any(
                    s.started for s in sse.blocks.tool_states.values()
                )
                has_content_blocks = (
                    sse.blocks.text_index != -1
                    or sse.blocks.thinking_index != -1
                    or has_started_tool
                    or len(sse._accumulated_text_parts) > 0
                    or len(sse._accumulated_reasoning_parts) > 0
                )
                if has_content_blocks and isinstance(
                    e,
                    (
                        httpx.RemoteProtocolError,
                        httpx.ReadTimeout,
                        asyncio.TimeoutError,
                        httpx.ConnectError,
                    ),
                ):
                    logger.warning(
                        "{}_STREAM: Transient error mid-stream. Faking max_tokens to resume. {}",
                        tag,
                        e,
                    )
                    for event in sse.close_all_blocks():
                        yield event
                    yield sse.message_delta("max_tokens", sse.estimate_output_tokens())
                    yield sse.message_stop()
                    return
                base_message = user_visible_message_for_mapped_provider_error(
                    mapped_e,
                    provider_name=tag,
                    read_timeout_s=self._config.http_read_timeout,
                )
                error_message = append_request_id(base_message, request_id)
                logger.info(
                    "{}_STREAM: Emitting SSE error event for {}{}",
                    tag,
                    type(e).__name__,
                    req_tag,
                )
                for event in sse.close_all_blocks():
                    yield event
                if sse.blocks.has_emitted_tool_block():
                    # Avoid a second assistant text block after an emitted tool_use, which
                    # breaks OpenAI history replay (issue #206) when Claude Code stores it.
                    yield sse.emit_top_level_error(error_message)
                else:
                    for event in sse.emit_error(error_message):
                        yield event
                yield sse.message_delta("end_turn", 1)
                yield sse.message_stop()
                return
            # Flush remaining content
            remaining = think_parser.flush()
            if remaining:
                if remaining.type == ContentType.THINKING:
                    if not thinking_enabled:
                        remaining = None
                    else:
                        for event in sse.ensure_thinking_block():
                            yield event
                        yield sse.emit_thinking_delta(remaining.content)
            if remaining and remaining.type == ContentType.TEXT:
                for event in sse.ensure_text_block():
                    yield event
                yield sse.emit_text_delta(remaining.content)
            for tool_use in heuristic_parser.flush():
                for event in _iter_heuristic_tool_use_sse(sse, tool_use):
                    yield event
            has_started_tool = any(s.started for s in sse.blocks.tool_states.values())
            has_content_blocks = (
                sse.blocks.text_index != -1
                or sse.blocks.thinking_index != -1
                or has_started_tool
            )
            if not has_content_blocks:
                for event in sse.ensure_text_block():
                    yield event
                yield sse.emit_text_delta(" ")
            elif (
                not has_started_tool
                and not sse.accumulated_text.strip()
                and sse.accumulated_reasoning.strip()
            ):
                # Some OpenAI-compatible models (e.g. NIM reasoning templates) stream only
                # ``reasoning_content`` with no ``content``; emit a minimal text block so
                # clients and smoke ``text_content()`` see a completed assistant message.
                for event in sse.ensure_text_block():
                    yield event
                yield sse.emit_text_delta(" ")
            for event in self._flush_task_arg_buffers(sse):
                yield event
            for event in sse.close_all_blocks():
                yield event
            completion = (
                getattr(usage_info, "completion_tokens", None)
                if usage_info is not None
                else None
            )
            if isinstance(completion, int):
                output_tokens = completion
            else:
                output_tokens = sse.estimate_output_tokens()
            if usage_info and hasattr(usage_info, "prompt_tokens"):
                provider_input = usage_info.prompt_tokens
                if isinstance(provider_input, int):
                    logger.debug(
                        "TOKEN_ESTIMATE: our={} provider={} diff={:+d}",
                        input_tokens,
                        provider_input,
                        provider_input - input_tokens,
                    )
            yield sse.message_delta(map_stop_reason(finish_reason), output_tokens)
            yield sse.message_stop()
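

# Hedged usage sketch (assumes a concrete subclass instance and an Anthropic-style
# request object; neither is constructed here): the strings yielded by stream_response
# are presumably pre-formatted Anthropic SSE chunks and are meant to be forwarded to the
# client verbatim, e.g.
#
#   async def forward(provider: "OpenAIChatTransport", request) -> None:
#       async for event in provider.stream_response(request, input_tokens=0):
#           print(event, end="")  # or write to the HTTP response stream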