Spaces:
Running
Running
| """NVIDIA NIM provider implementation.""" | |
| import asyncio | |
| import json | |
| from typing import Any | |
| import httpx | |
| import openai | |
| from loguru import logger | |
| from openai import AsyncOpenAI | |
| from config.nim import NimSettings | |
| from config.settings import Settings | |
| from providers.base import ProviderConfig | |
| from providers.defaults import NVIDIA_NIM_DEFAULT_BASE | |
| from providers.openai_compat import OpenAIChatTransport | |
| from . import metrics as nim_metrics | |
| from .request import ( | |
| build_request_body, | |
| clone_body_without_chat_template, | |
| clone_body_without_reasoning_budget, | |
| clone_body_without_reasoning_content, | |
| ) | |
class NvidiaNimProvider(OpenAIChatTransport):
    """NVIDIA NIM provider using official OpenAI client.

    Adds three NIM-specific behaviors on top of the base transport:

    * the API key is resolved per model through ``Settings``;
    * a request rejected with HTTP 400 because of a known optional field is
      retried once without that field;
    * when the primary model fails with a transient error, the models listed
      in ``nvidia_nim_fallback_models`` are tried in order.
    """

    # Timeouts are deliberately short so a stalled model fails over quickly.
    # Seconds allowed for the ``chat.completions.create`` call to return a stream.
    _NIM_CONNECT_TIMEOUT_S = 8
    # Seconds allowed for the primary model to deliver its first chunk.
    _NIM_FIRST_CHUNK_TIMEOUT_S = 20
    # Seconds allowed for a fallback model to deliver its first chunk.
    _NIM_FALLBACK_FIRST_CHUNK_TIMEOUT_S = 12

    def __init__(
        self,
        config: ProviderConfig,
        *,
        nim_settings: NimSettings,
        settings: Settings,
    ):
        """Configure the transport for NIM.

        Args:
            config: Provider configuration; ``config.base_url`` falls back to
                ``NVIDIA_NIM_DEFAULT_BASE`` when empty.
            nim_settings: NIM-specific options forwarded to
                ``build_request_body``.
            settings: Application settings; supplies the rate limit, the
                concurrency cap and the per-model API key lookup.
        """
        super().__init__(
            config,
            provider_name="NIM",
            base_url=config.base_url or NVIDIA_NIM_DEFAULT_BASE,
            api_key=config.api_key,
            nim_rate_limit=settings.nim_rate_limit,
            nim_max_concurrency=settings.nim_max_concurrency,
        )
        self._nim_settings = nim_settings
        self._settings = settings

    def _api_key_for_model(self, model_name: str) -> str:
        """Return the API key that ``Settings`` associates with ``model_name``."""
        return self._settings.nvidia_nim_api_key_for_model(model_name)

    def _client_for_body(self, body: dict[str, Any]) -> AsyncOpenAI:
        """Return the client to use for ``body``, keyed on its ``model`` entry.

        A missing or falsy ``model`` is looked up as the empty string.
        """
        model_name = str(body.get("model") or "")
        api_key = self._api_key_for_model(model_name)
        # `_client_for_api_key` is not defined in this class body
        # (presumably inherited from the base transport).
        return self._client_for_api_key(api_key)

    def _build_request_body(
        self, request: Any, thinking_enabled: bool | None = None
    ) -> dict:
        """Internal helper for tests and shared building."""
        return build_request_body(
            request,
            self._nim_settings,
            thinking_enabled=self._is_thinking_enabled(request, thinking_enabled),
        )

    def _get_retry_request_body(self, error: Exception, body: dict) -> dict | None:
        """Retry once with a downgraded body when NIM rejects a known field.

        Args:
            error: The exception raised by the failed request.
            body: The request body that was rejected.

        Returns:
            A copy of ``body`` without the offending field, or ``None`` when
            the error is not a 400, names no known field, or the matching
            ``clone_body_without_*`` helper returns ``None``.
        """
        status_code = getattr(error, "status_code", None)
        if not isinstance(error, openai.BadRequestError) and status_code != 400:
            return None

        # Search both the exception text and its structured body, since the
        # field name may appear in either.
        error_text = str(error)
        error_body = getattr(error, "body", None)
        if error_body is not None:
            error_text = f"{error_text} {json.dumps(error_body, default=str)}"
        error_text = error_text.lower()

        # Checked in this order; only the first field mentioned is handled.
        downgrades = (
            ("reasoning_budget", clone_body_without_reasoning_budget),
            ("chat_template", clone_body_without_chat_template),
            ("reasoning_content", clone_body_without_reasoning_content),
        )
        for field, strip_field in downgrades:
            if field not in error_text:
                continue
            retry_body = strip_field(body)
            if retry_body is None:
                return None
            logger.warning("NIM_STREAM: retrying without {} after 400 error", field)
            return retry_body
        return None

    @staticmethod
    def _nim_is_transient(error: Exception) -> bool:
        """Return True for failures worth retrying on another model.

        Covers HTTP 429, rate-limit wording in the message, refused/reset
        connections, and connect/read/first-chunk timeouts.
        """
        if getattr(error, "status_code", None) == 429:
            return True
        if isinstance(
            error, (httpx.ConnectError, httpx.ReadTimeout, asyncio.TimeoutError)
        ):
            return True
        text = str(error).lower()
        if "rate limit" in text or "too many requests" in text:
            return True
        return "connection" in text and ("refused" in text or "reset" in text)

    @staticmethod
    def _nim_fallback_candidates(
        csv: str, primary_model: Any
    ) -> list[tuple[str, str]]:
        """Parse the fallback CSV into ``(candidate, model)`` pairs.

        Entries of the form ``nvidia_nim/<model>`` have the provider prefix
        removed; entries with any other prefix and entries equal to
        ``primary_model`` are dropped.
        """
        pairs: list[tuple[str, str]] = []
        for raw in csv.split(","):
            cand = raw.strip()
            if not cand:
                continue
            provider, slash, model = cand.partition("/")
            if not slash:
                model = cand  # bare model name, no provider prefix
            elif provider != "nvidia_nim":
                # NOTE(review): a bare id that itself contains "/" (e.g.
                # "org/model") is treated as provider-prefixed and skipped
                # here - confirm the config always carries the prefix.
                continue
            if model == primary_model:
                continue
            pairs.append((cand, model))
        return pairs

    @staticmethod
    def _nim_record_metric(event: str, cand: str) -> None:
        """Call ``nim_metrics.record_<event>(cand)``; never let metrics fail a request."""
        try:
            getattr(nim_metrics, f"record_{event}")(cand)
        except Exception:
            logger.debug("NIM_METRICS: failed to record {} for {}", event, cand)

    @staticmethod
    async def _nim_close_quietly(stream: Any) -> None:
        """Best-effort close of a stream abandoned after a timeout."""
        # Accept either spelling; the OpenAI async stream is expected to
        # expose `close()` rather than `aclose()` - TODO confirm against the
        # installed client version.
        closer = getattr(stream, "aclose", None) or getattr(stream, "close", None)
        if closer is None:
            return
        try:
            result = closer()
            if hasattr(result, "__await__"):
                await result
        except Exception as close_error:
            logger.debug("NIM_STREAM: failed to close stalled stream: {}", close_error)

    async def _nim_open_probed_stream(
        self, body: dict, first_chunk_timeout_s: float
    ) -> Any:
        """Open a stream for ``body`` and wait for its first chunk.

        Returns an async generator that replays the probed chunk followed by
        the rest of the stream. A stream that ends before producing anything
        yields nothing.

        Raises:
            asyncio.TimeoutError: if creating the stream or receiving the
                first chunk exceeds its timeout (the stream is closed first).
        """
        client = self._client_for_body(body)
        stream = await asyncio.wait_for(
            self._global_rate_limiter.execute_with_retry(
                client.chat.completions.create,
                **body,
                stream=True,
                max_retries=1,
            ),
            timeout=self._NIM_CONNECT_TIMEOUT_S,
        )

        head: list[Any] = []
        exhausted = False
        try:
            head.append(
                await asyncio.wait_for(
                    stream.__anext__(), timeout=first_chunk_timeout_s
                )
            )
        except StopAsyncIteration:
            # Empty stream: letting this escape would surface as a
            # RuntimeError inside a caller's async generator.
            exhausted = True
        except asyncio.TimeoutError:
            # asyncio.TimeoutError only aliases builtin TimeoutError from 3.11.
            await self._nim_close_quietly(stream)
            raise

        async def _replay():
            # Yield the already-received first chunk, then the rest.
            for chunk in head:
                yield chunk
            if not exhausted:
                async for chunk in stream:
                    yield chunk

        return _replay()

    async def _create_stream(self, body: dict) -> tuple[Any, dict]:
        """Override to support fallback models on transient failures (429/connection/timeouts).

        Attempts the primary model first; on a transient error, iterates the
        fallback models configured in settings ``nvidia_nim_fallback_models``.
        Every attempt, primary or fallback, is bounded by the connect timeout
        and a first-chunk timeout.

        Returns:
            ``(stream, used_body)`` where ``used_body`` is the body that
            actually produced the stream (``body`` itself, or a copy with
            ``model`` replaced by the fallback that succeeded).

        Raises:
            Exception: the primary error when it is not transient or no
                fallback is configured or applicable; otherwise the error of
                the last fallback attempted.
        """
        from config.settings import get_settings

        try:
            stream = await self._nim_open_probed_stream(
                body, self._NIM_FIRST_CHUNK_TIMEOUT_S
            )
            return stream, body
        except Exception as error:  # primary model failed
            if not self._nim_is_transient(error):
                raise
            csv = (get_settings().nvidia_nim_fallback_models or "").strip()
            if not csv:
                raise

            last_exc: Exception = error
            for cand, model in self._nim_fallback_candidates(csv, body.get("model")):
                retry_body = dict(body)
                retry_body["model"] = model
                logger.warning(
                    "NIM_STREAM: primary model failed ({}); attempting fallback {}",
                    type(error).__name__,
                    cand,
                )
                self._nim_record_metric("attempt", cand)
                try:
                    stream = await self._nim_open_probed_stream(
                        retry_body, self._NIM_FALLBACK_FIRST_CHUNK_TIMEOUT_S
                    )
                except Exception as fallback_error:
                    logger.warning(
                        "NIM_STREAM: fallback {} failed: {}", cand, fallback_error
                    )
                    self._nim_record_metric("failure", cand)
                    last_exc = fallback_error
                    continue
                self._nim_record_metric("success", cand)
                return stream, retry_body

            # No fallback succeeded; re-raise last exception
            raise last_exc