File size: 9,624 Bytes
0157ac7
 
ebba9d6
0157ac7
 
 
ebba9d6
0157ac7
 
ebba9d6
0157ac7
 
 
 
 
 
 
ebba9d6
0157ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa9c0b0
 
0157ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5ea640
aa9c0b0
 
 
0157ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebba9d6
0157ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa9c0b0
 
 
0157ac7
 
 
 
 
 
 
 
 
 
 
aa9c0b0
0157ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa9c0b0
 
 
0157ac7
 
 
 
 
 
 
 
 
 
 
 
ebba9d6
0157ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
aa9c0b0
 
 
0157ac7
 
 
 
 
 
aa9c0b0
 
 
0157ac7
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
"""NVIDIA NIM provider implementation."""

import asyncio
import inspect
import json
from typing import Any

import httpx
import openai
from loguru import logger
from openai import AsyncOpenAI

from config.nim import NimSettings
from config.settings import Settings
from providers.base import ProviderConfig
from providers.defaults import NVIDIA_NIM_DEFAULT_BASE
from providers.openai_compat import OpenAIChatTransport

from . import metrics as nim_metrics
from .request import (
    build_request_body,
    clone_body_without_chat_template,
    clone_body_without_reasoning_budget,
    clone_body_without_reasoning_content,
)


class NvidiaNimProvider(OpenAIChatTransport):
    """NVIDIA NIM provider using the official OpenAI client.

    Builds NIM-specific request bodies, selects the API key per model, retries
    once with a downgraded body when NIM rejects a known field with a 400, and
    fails over to configured fallback models on transient stream errors.
    """

    def __init__(
        self,
        config: ProviderConfig,
        *,
        nim_settings: NimSettings,
        settings: Settings,
    ) -> None:
        """Configure the shared OpenAI-compatible transport for NIM.

        Args:
            config: Provider configuration; ``base_url`` falls back to
                ``NVIDIA_NIM_DEFAULT_BASE`` when unset.
            nim_settings: NIM request-building settings passed to
                ``build_request_body``.
            settings: Application settings; supplies the rate-limit and
                concurrency values and the per-model API keys.
        """
        super().__init__(
            config,
            provider_name="NIM",
            base_url=config.base_url or NVIDIA_NIM_DEFAULT_BASE,
            api_key=config.api_key,
            nim_rate_limit=settings.nim_rate_limit,
            nim_max_concurrency=settings.nim_max_concurrency,
        )
        self._nim_settings = nim_settings
        self._settings = settings

    def _api_key_for_model(self, model_name: str) -> str:
        """Return the NIM API key configured for ``model_name``."""
        return self._settings.nvidia_nim_api_key_for_model(model_name)

    def _client_for_body(self, body: dict[str, Any]) -> AsyncOpenAI:
        """Return the client for the API key that matches ``body["model"]``.

        A missing or falsy model is looked up as the empty string.
        """
        model_name = str(body.get("model") or "")
        api_key = self._api_key_for_model(model_name)
        return self._client_for_api_key(api_key)

    def _build_request_body(
        self, request: Any, thinking_enabled: bool | None = None
    ) -> dict:
        """Build the NIM request body for ``request``.

        Internal helper for tests and shared building. ``thinking_enabled`` is
        resolved through ``_is_thinking_enabled`` before being passed on.
        """
        return build_request_body(
            request,
            self._nim_settings,
            thinking_enabled=self._is_thinking_enabled(request, thinking_enabled),
        )

    def _get_retry_request_body(self, error: Exception, body: dict) -> dict | None:
        """Return a downgraded body to retry once after NIM rejects a known field.

        Only 400 responses qualify (``openai.BadRequestError`` or any error with
        ``status_code == 400``). The error text plus the JSON-serialized
        ``error.body`` is searched, case-insensitively, for
        ``reasoning_budget``, ``chat_template`` and ``reasoning_content`` in
        that order. The first field mentioned is removed via its
        ``clone_body_without_*`` helper.

        Returns:
            The downgraded body, or ``None`` when the error is not a 400, no
            known field is mentioned, or the helper for the first matching
            field returns ``None``.
        """
        status_code = getattr(error, "status_code", None)
        if not isinstance(error, openai.BadRequestError) and status_code != 400:
            return None

        error_text = str(error)
        error_body = getattr(error, "body", None)
        if error_body is not None:
            # default=str keeps non-JSON values from breaking the match.
            error_text = f"{error_text} {json.dumps(error_body, default=str)}"
        error_text = error_text.lower()

        if "reasoning_budget" in error_text:
            retry_body = clone_body_without_reasoning_budget(body)
            if retry_body is None:
                return None
            logger.warning(
                "NIM_STREAM: retrying without reasoning_budget after 400 error"
            )
            return retry_body

        if "chat_template" in error_text:
            retry_body = clone_body_without_chat_template(body)
            if retry_body is None:
                return None
            logger.warning("NIM_STREAM: retrying without chat_template after 400 error")
            return retry_body

        if "reasoning_content" in error_text:
            retry_body = clone_body_without_reasoning_content(body)
            if retry_body is None:
                return None
            logger.warning(
                "NIM_STREAM: retrying without reasoning_content after 400 error"
            )
            return retry_body

        return None

    @staticmethod
    def _is_transient_stream_error(error: Exception) -> bool:
        """Return True when ``error`` should trigger fallback-model failover.

        Transient means any of the following:
        - HTTP 429.
        - A rate-limit, "too many requests", or connection refused/reset message.
        - A connect/read/timeout error from httpx or the OpenAI client.
        - A timeout from the connect or first-chunk deadlines.
        """
        if getattr(error, "status_code", None) == 429:
            return True
        text = str(error).lower()
        if "rate limit" in text or "too many requests" in text:
            return True
        if "connection" in text and ("refused" in text or "reset" in text):
            return True
        # The OpenAI client reports transport failures and timeouts as
        # APIConnectionError (APITimeoutError is a subclass) rather than raw
        # httpx exceptions. asyncio.TimeoutError differs from the builtin
        # TimeoutError before Python 3.11, so both are listed.
        return isinstance(
            error,
            (
                httpx.ConnectError,
                httpx.ReadTimeout,
                openai.APIConnectionError,
                asyncio.TimeoutError,
                TimeoutError,
            ),
        )

    @staticmethod
    def _parse_fallback_candidates(
        csv: str | None, current_model: Any
    ) -> list[tuple[str, str]]:
        """Parse the comma-separated fallback setting into ``(entry, model)`` pairs.

        An entry of the form ``provider/rest`` is kept only when ``provider`` is
        ``"nvidia_nim"``, and ``rest`` becomes the model name. An entry without
        ``/`` is used verbatim. Blank entries and entries whose model equals
        ``current_model`` are dropped. Input order is preserved.
        """
        pairs: list[tuple[str, str]] = []
        for raw in (csv or "").split(","):
            entry = raw.strip()
            if not entry:
                continue
            provider, slash, rest = entry.partition("/")
            model = rest if slash else entry
            if model == current_model:
                continue
            # NOTE(review): everything before the first "/" is read as a
            # provider prefix. A bare model id that itself contains "/"
            # (e.g. "org/model") is therefore skipped and must be written as
            # "nvidia_nim/org/model" — confirm this is the intended format.
            if slash and provider != "nvidia_nim":
                continue
            pairs.append((entry, model))
        return pairs

    @staticmethod
    def _record_fallback_metric(event: str, candidate: str) -> None:
        """Call ``nim_metrics.record_<event>(candidate)`` without ever raising.

        Metrics are best-effort, so failures are only logged at debug level.
        """
        try:
            getattr(nim_metrics, f"record_{event}")(candidate)
        except Exception:
            logger.debug("NIM_METRICS: failed to record {} for {}", event, candidate)

    @staticmethod
    async def _close_abandoned_stream(stream: Any) -> None:
        """Best-effort close of a stream that will not be consumed.

        Tries ``aclose()`` and then ``close()`` (the OpenAI ``AsyncStream``
        provides only the latter). The result is awaited when it is awaitable.
        Errors are logged at debug level and swallowed, so the caller's
        original exception is the one that propagates.
        """
        for method_name in ("aclose", "close"):
            closer = getattr(stream, method_name, None)
            if closer is None:
                continue
            try:
                result = closer()
                if inspect.isawaitable(result):
                    await result
            except Exception as exc:
                logger.debug(
                    "NIM_STREAM: failed to {} abandoned stream: {}", method_name, exc
                )
            return

    async def _open_stream_with_first_chunk(
        self,
        client: AsyncOpenAI,
        body: dict,
        *,
        first_chunk_timeout_s: float,
        connect_timeout_s: float | None = None,
    ) -> Any:
        """Open a streaming completion and wait for its first chunk.

        Args:
            client: Client used to issue the request.
            body: Request body. It is sent through the global rate limiter with
                ``stream=True`` and ``max_retries=1``.
            first_chunk_timeout_s: Deadline for the first chunk to arrive.
            connect_timeout_s: Optional deadline for the create call itself;
                ``None`` means no deadline.

        Returns:
            An async generator that yields the already-received first chunk and
            then the remainder of the stream.

        Raises:
            asyncio.TimeoutError: If either deadline expires. On a missed
                first-chunk deadline the stream is closed before re-raising.
            Exception: Whatever the create call or the first ``__anext__``
                raises, including ``StopAsyncIteration`` for an empty stream.
        """
        create = self._global_rate_limiter.execute_with_retry(
            client.chat.completions.create,
            **body,
            stream=True,
            max_retries=1,
        )
        if connect_timeout_s is None:
            stream = await create
        else:
            stream = await asyncio.wait_for(create, timeout=connect_timeout_s)

        # Probe for initial content. A stall here surfaces as a timeout, which
        # _is_transient_stream_error treats as a reason to fail over.
        try:
            first = await asyncio.wait_for(
                stream.__anext__(), timeout=first_chunk_timeout_s
            )
        except (TimeoutError, asyncio.TimeoutError):
            await self._close_abandoned_stream(stream)
            raise

        async def _replay_then_rest():
            yield first
            async for chunk in stream:
                yield chunk

        return _replay_then_rest()

    async def _create_stream(self, body: dict) -> tuple[Any, dict]:
        """Open a chat-completion stream, failing over to fallback models.

        The primary model in ``body`` is tried first, with a connect deadline
        and a first-chunk deadline. If that attempt fails with a transient
        error (see ``_is_transient_stream_error``), each applicable entry of
        the ``nvidia_nim_fallback_models`` setting is tried in order with a
        shorter first-chunk deadline. Attempt, success and failure metrics are
        recorded for each fallback.

        Returns:
            ``(stream, used_body)``. ``stream`` yields chunks starting with the
            already-received first one. ``used_body`` is the body that
            succeeded: ``body`` itself, or a copy with ``model`` replaced.

        Raises:
            Exception: The primary error when it is not transient or when no
                fallback applies; otherwise the last fallback error.
        """
        from config.settings import get_settings

        # Faster timeouts for quick failover detection.
        connect_timeout_s = 8
        first_chunk_timeout_s = 20
        fallback_first_chunk_timeout_s = 12

        try:
            client = self._client_for_body(body)
            stream = await self._open_stream_with_first_chunk(
                client,
                body,
                first_chunk_timeout_s=first_chunk_timeout_s,
                connect_timeout_s=connect_timeout_s,
            )
            return stream, body
        except Exception as error:  # primary model failed
            if not self._is_transient_stream_error(error):
                raise

            candidates = self._parse_fallback_candidates(
                get_settings().nvidia_nim_fallback_models, body.get("model")
            )
            if not candidates:
                raise

            last_exc: Exception = error
            for cand, try_model in candidates:
                retry_body = dict(body)
                retry_body["model"] = try_model
                client = self._client_for_body(retry_body)
                logger.warning(
                    "NIM_STREAM: primary model failed ({}); attempting fallback {}",
                    type(error).__name__,
                    cand,
                )
                self._record_fallback_metric("attempt", cand)
                try:
                    # NOTE(review): unlike the primary attempt, fallback attempts
                    # have no connect deadline — confirm this is intended.
                    stream = await self._open_stream_with_first_chunk(
                        client,
                        retry_body,
                        first_chunk_timeout_s=fallback_first_chunk_timeout_s,
                    )
                except Exception as fallback_error:
                    logger.warning(
                        "NIM_STREAM: fallback {} failed: {}", cand, fallback_error
                    )
                    self._record_fallback_metric("failure", cand)
                    last_exc = fallback_error
                    continue
                self._record_fallback_metric("success", cand)
                return stream, retry_body

            # No fallback succeeded; re-raise the last failure.
            raise last_exc