Spaces:
Sleeping
Sleeping
| """NVIDIA AI provider implementation via OpenAI-compatible API.""" | |
| import json | |
| import time | |
| from typing import Any, AsyncIterator | |
| import httpx | |
| from app.models.providers.base import ( | |
| AuthenticationError, | |
| BaseProvider, | |
| CompletionResponse, | |
| ModelInfo, | |
| ModelNotFoundError, | |
| ProviderError, | |
| RateLimitError, | |
| TokenUsage, | |
| ) | |
class NVIDIAProvider(BaseProvider):
    """NVIDIA AI API provider supporting reasoning and code models.

    Requests are sent with ``httpx`` to the ``/chat/completions`` endpoint
    under ``base_url``. Callers select a model by its short key in ``MODELS``;
    the key is translated to the full model id before the request is sent.
    """

    PROVIDER_NAME = "nvidia"
    DEFAULT_BASE_URL = "https://integrate.api.nvidia.com/v1"

    # Minimum spacing, in seconds, enforced between consecutive requests
    # issued through one provider instance (300ms).
    MIN_REQUEST_INTERVAL = 0.3

    # Model definitions with configurations
    MODELS = {
        # Reasoning models
        "step-3.5-flash": ModelInfo(
            id="stepfun-ai/step-3.5-flash",
            name="Step 3.5 Flash (Reasoning)",
            provider="nvidia",
            context_window=16384,
            max_output_tokens=16384,
            supports_functions=False,
            supports_vision=False,
            supports_streaming=True,
            cost_per_1k_input=0.0,  # Free tier
            cost_per_1k_output=0.0,
        ),
        "glm4.7": ModelInfo(
            id="z-ai/glm4.7",
            name="GLM 4.7 (Reasoning)",
            provider="nvidia",
            context_window=16384,
            max_output_tokens=16384,
            supports_functions=False,
            supports_vision=False,
            supports_streaming=True,
            cost_per_1k_input=0.0,
            cost_per_1k_output=0.0,
        ),
        "deepseek-v3.2": ModelInfo(
            id="deepseek-ai/deepseek-v3.2",
            name="DeepSeek V3.2 (Reasoning)",
            provider="nvidia",
            context_window=8192,
            max_output_tokens=8192,
            supports_functions=False,
            supports_vision=False,
            supports_streaming=True,
            cost_per_1k_input=0.0,
            cost_per_1k_output=0.0,
        ),
        "deepseek-r1": ModelInfo(
            id="deepseek-ai/deepseek-r1",
            name="DeepSeek R1 (Reasoning)",
            provider="nvidia",
            context_window=16384,
            max_output_tokens=16384,
            supports_functions=False,
            supports_vision=False,
            supports_streaming=True,
            cost_per_1k_input=0.0,
            cost_per_1k_output=0.0,
        ),
        # Code models
        "devstral-2-123b": ModelInfo(
            id="mistralai/devstral-2-123b-instruct-2512",
            name="Devstral 2 123B (Code)",
            provider="nvidia",
            context_window=8192,
            max_output_tokens=8192,
            supports_functions=False,
            supports_vision=False,
            supports_streaming=True,
            cost_per_1k_input=0.0,
            cost_per_1k_output=0.0,
        ),
        # General models
        "llama-3.3-70b": ModelInfo(
            id="meta/llama-3.3-70b-instruct",
            name="Llama 3.3 70B",
            provider="nvidia",
            context_window=8192,
            max_output_tokens=8192,
            supports_functions=False,
            supports_vision=False,
            supports_streaming=True,
            cost_per_1k_input=0.0,
            cost_per_1k_output=0.0,
        ),
        "nemotron-70b": ModelInfo(
            id="nvidia/llama-3.1-nemotron-70b-instruct",
            name="Nemotron 70B",
            provider="nvidia",
            context_window=4096,
            max_output_tokens=4096,
            supports_functions=False,
            supports_vision=False,
            supports_streaming=True,
            cost_per_1k_input=0.0,
            cost_per_1k_output=0.0,
        ),
    }

    # Reasoning model configs, keyed like MODELS.
    #   "top_p"       -> sent as the request's top_p.
    #   "extra_body"  -> extra request-body fields, merged into the top level
    #                    of the JSON payload.
    #   "temperature" -> NOTE(review): not applied by the request builder; the
    #                    caller's ``temperature`` argument is always sent.
    #                    Confirm whether these values should act as defaults.
    REASONING_CONFIGS = {
        "step-3.5-flash": {
            "temperature": 1.0,
            "top_p": 0.9,
        },
        "glm4.7": {
            "temperature": 1.0,
            "top_p": 1.0,
            "extra_body": {"chat_template_kwargs": {"enable_thinking": True, "clear_thinking": False}},
        },
        "deepseek-v3.2": {
            "temperature": 1.0,
            "top_p": 0.95,
            "extra_body": {"chat_template_kwargs": {"thinking": True}},
        },
        "deepseek-r1": {
            "temperature": 0.6,
            "top_p": 0.95,
        },
    }

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        timeout: float = 60.0,
        max_retries: int = 2,
    ):
        """
        Initialize NVIDIA provider.

        Args:
            api_key: NVIDIA API key
            base_url: Base URL for NVIDIA API (defaults to integrate.api.nvidia.com)
            timeout: Request timeout in seconds
            max_retries: Maximum number of retries for failed requests
        """
        super().__init__(api_key, base_url or self.DEFAULT_BASE_URL, timeout, max_retries)
        # Monotonic-clock timestamp of the most recent request (0.0 = none yet).
        self._last_request_time = 0.0

    def _get_headers(self) -> dict[str, str]:
        """Get headers for NVIDIA API requests."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    async def _apply_rate_limit(self) -> None:
        """Sleep as needed so requests are at least MIN_REQUEST_INTERVAL apart."""
        # Function-scope import: asyncio is only needed here.
        import asyncio

        # A monotonic clock keeps the interval correct across wall-clock jumps.
        elapsed = time.monotonic() - self._last_request_time
        if elapsed < self.MIN_REQUEST_INTERVAL:
            await asyncio.sleep(self.MIN_REQUEST_INTERVAL - elapsed)
        self._last_request_time = time.monotonic()

    def _build_chat_payload(
        self,
        model: str,
        messages: list[dict[str, Any]],
        temperature: float,
        max_tokens: int | None,
        extra: dict[str, Any],
        stream: bool = False,
    ) -> dict[str, Any]:
        """Assemble the JSON body shared by complete() and complete_stream().

        Args:
            model: Model key; must already be validated against ``MODELS``
            messages: Chat messages to send
            temperature: Sampling temperature
            max_tokens: Token limit; a falsy value selects the model's
                ``max_output_tokens``
            extra: Caller-supplied keyword arguments, applied last so they
                override the defaults built here
            stream: Whether to request a streamed response

        Returns:
            The request body as a JSON-serializable dict
        """
        model_info = self.MODELS[model]
        payload: dict[str, Any] = {
            "model": model_info.id,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens or model_info.max_output_tokens,
        }

        # Add reasoning model configs if applicable
        config = self.REASONING_CONFIGS.get(model, {})
        if "top_p" in config:
            payload["top_p"] = config["top_p"]
        # "extra_body" is an OpenAI-SDK convention for extra top-level request
        # fields. With raw httpx nothing unpacks it, so merge its contents
        # into the payload instead of sending a literal "extra_body" key.
        payload.update(config.get("extra_body", {}))

        # Add any additional kwargs, flattening a caller-supplied extra_body
        # the same way.
        for key, value in extra.items():
            if key == "extra_body" and isinstance(value, dict):
                payload.update(value)
            else:
                payload[key] = value

        if stream:
            # Set last so kwargs cannot turn streaming off underneath the
            # SSE parser in complete_stream().
            payload["stream"] = True
        return payload

    def _raise_for_error_status(self, status_code: int, detail: str) -> None:
        """Raise the provider exception matching an HTTP error status.

        Does nothing for status codes below 400.

        Raises:
            AuthenticationError: On 401
            RateLimitError: On 429
            ProviderError: On any other status >= 400
        """
        if status_code == 401:
            raise AuthenticationError(self.PROVIDER_NAME, "Invalid NVIDIA API key")
        if status_code == 429:
            raise RateLimitError(self.PROVIDER_NAME)
        if status_code >= 400:
            raise ProviderError(f"NVIDIA API error ({status_code}): {detail}", self.PROVIDER_NAME)

    async def complete(
        self,
        messages: list[dict[str, str]],
        model: str = "devstral-2-123b",
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs: Any,
    ) -> CompletionResponse:
        """
        Create a chat completion using NVIDIA models.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            model: Model key (e.g., 'devstral-2-123b', 'llama-3.3-70b')
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            **kwargs: Additional model-specific parameters, merged into the
                request body (an ``extra_body`` dict is flattened into it)

        Returns:
            CompletionResponse with generated text and metadata

        Raises:
            ModelNotFoundError: If model is not supported
            AuthenticationError: If API key is invalid
            RateLimitError: If rate limit is exceeded
            ProviderError: For other API errors
        """
        # Validate model
        if model not in self.MODELS:
            raise ModelNotFoundError(self.PROVIDER_NAME, model)

        # Apply rate limiting
        await self._apply_rate_limit()

        payload = self._build_chat_payload(model, messages, temperature, max_tokens, kwargs)

        try:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                response = await client.post(
                    f"{self.base_url}/chat/completions",
                    headers=self._get_headers(),
                    json=payload,
                )
                if response.status_code >= 400:
                    self._raise_for_error_status(response.status_code, response.text)

                data = response.json()

                # Extract response
                choices = data.get("choices")
                if not choices:
                    raise ProviderError("NVIDIA API returned no choices", self.PROVIDER_NAME)
                choice = choices[0]
                content = choice["message"]["content"]

                # Extract usage
                usage_data = data.get("usage", {})
                usage = TokenUsage(
                    prompt_tokens=usage_data.get("prompt_tokens", 0),
                    completion_tokens=usage_data.get("completion_tokens", 0),
                    total_tokens=usage_data.get("total_tokens", 0),
                )

                return CompletionResponse(
                    content=content,
                    model=model,
                    provider=self.PROVIDER_NAME,
                    usage=usage,
                    finish_reason=choice.get("finish_reason", "stop"),
                    raw_response=data,
                )
        except (AuthenticationError, RateLimitError, ProviderError, ModelNotFoundError):
            raise
        except Exception as e:
            # Transport, JSON and response-shape failures all surface as ProviderError.
            raise ProviderError(f"NVIDIA request failed: {str(e)}", self.PROVIDER_NAME) from e

    async def complete_stream(
        self,
        messages: list[dict[str, str]],
        model: str = "devstral-2-123b",
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        """
        Create a streaming chat completion.

        Args:
            messages: List of message dictionaries
            model: Model key
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            **kwargs: Additional parameters, merged into the request body

        Yields:
            Content chunks as they arrive

        Raises:
            Same as complete()
        """
        if model not in self.MODELS:
            raise ModelNotFoundError(self.PROVIDER_NAME, model)

        await self._apply_rate_limit()

        payload = self._build_chat_payload(
            model, messages, temperature, max_tokens, kwargs, stream=True
        )

        try:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                async with client.stream(
                    "POST",
                    f"{self.base_url}/chat/completions",
                    headers=self._get_headers(),
                    json=payload,
                ) as response:
                    if response.status_code >= 400:
                        # A streamed body must be read explicitly. Decode
                        # leniently so an odd error body cannot hide the
                        # real API error.
                        error_detail = await response.aread()
                        self._raise_for_error_status(
                            response.status_code,
                            error_detail.decode(errors="replace"),
                        )

                    async for line in response.aiter_lines():
                        # Only SSE "data" lines carry payloads; the space
                        # after the colon is optional.
                        if not line.startswith("data:"):
                            continue
                        data_str = line[len("data:"):].strip()
                        if not data_str:
                            continue
                        if data_str == "[DONE]":
                            break
                        try:
                            data = json.loads(data_str)
                        except json.JSONDecodeError:
                            # Skip lines that are not valid JSON rather than
                            # aborting the stream.
                            continue
                        if not isinstance(data, dict):
                            continue
                        choices = data.get("choices")
                        if not choices:
                            continue
                        delta = choices[0].get("delta") or {}
                        content = delta.get("content")
                        if content:
                            yield content
        except (AuthenticationError, RateLimitError, ProviderError, ModelNotFoundError):
            raise
        except Exception as e:
            raise ProviderError(f"NVIDIA streaming failed: {str(e)}", self.PROVIDER_NAME) from e

    def list_models(self) -> list[ModelInfo]:
        """List all available NVIDIA models."""
        return list(self.MODELS.values())

    def get_models(self) -> list[ModelInfo]:
        """Get list of available models (required by abstract base)."""
        return self.list_models()

    async def stream(
        self,
        messages: list[dict[str, Any]],
        model: str,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        """Stream a completion (delegates to complete_stream)."""
        async for chunk in self.complete_stream(messages, model, temperature, max_tokens, **kwargs):
            yield chunk

    def get_model_info(self, model: str) -> ModelInfo:
        """Get information about a specific model.

        Raises:
            ModelNotFoundError: If ``model`` is not a key of ``MODELS``
        """
        if model not in self.MODELS:
            raise ModelNotFoundError(self.PROVIDER_NAME, model)
        return self.MODELS[model]