NeerajCodz commited on
Commit
f5ba363
·
1 Parent(s): f080be2

feat: add NVIDIA provider and update Gemini/Groq with latest models

Browse files
backend/app/models/providers/google.py CHANGED
@@ -26,9 +26,10 @@ class GoogleProvider(BaseProvider):
26
 
27
  # Model definitions with pricing (per 1K tokens)
28
  MODELS = {
29
- "gemini-1.5-pro": ModelInfo(
30
- id="gemini-1.5-pro",
31
- name="Gemini 1.5 Pro",
 
32
  provider="google",
33
  context_window=2097152,
34
  max_output_tokens=8192,
@@ -38,9 +39,9 @@ class GoogleProvider(BaseProvider):
38
  cost_per_1k_input=0.00125,
39
  cost_per_1k_output=0.005,
40
  ),
41
- "gemini-1.5-flash": ModelInfo(
42
- id="gemini-1.5-flash",
43
- name="Gemini 1.5 Flash",
44
  provider="google",
45
  context_window=1048576,
46
  max_output_tokens=8192,
@@ -50,9 +51,35 @@ class GoogleProvider(BaseProvider):
50
  cost_per_1k_input=0.000075,
51
  cost_per_1k_output=0.0003,
52
  ),
53
- "gemini-2.0-flash-exp": ModelInfo(
54
- id="gemini-2.0-flash-exp",
55
- name="Gemini 2.0 Flash (Experimental)",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  provider="google",
57
  context_window=1048576,
58
  max_output_tokens=8192,
@@ -62,6 +89,43 @@ class GoogleProvider(BaseProvider):
62
  cost_per_1k_input=0.0,
63
  cost_per_1k_output=0.0,
64
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  "gemini-pro": ModelInfo(
66
  id="gemini-pro",
67
  name="Gemini Pro",
@@ -78,7 +142,8 @@ class GoogleProvider(BaseProvider):
78
 
79
  # Aliases
80
  MODEL_ALIASES = {
81
- "gemini-flash": "gemini-1.5-flash",
 
82
  "gemini-1.5": "gemini-1.5-pro",
83
  }
84
 
 
26
 
27
  # Model definitions with pricing (per 1K tokens)
28
  MODELS = {
29
+ # Gemini 2.5 Series
30
+ "gemini-2.5-pro": ModelInfo(
31
+ id="gemini-2.5-pro",
32
+ name="Gemini 2.5 Pro",
33
  provider="google",
34
  context_window=2097152,
35
  max_output_tokens=8192,
 
39
  cost_per_1k_input=0.00125,
40
  cost_per_1k_output=0.005,
41
  ),
42
+ "gemini-2.5-flash": ModelInfo(
43
+ id="gemini-2.5-flash",
44
+ name="Gemini 2.5 Flash",
45
  provider="google",
46
  context_window=1048576,
47
  max_output_tokens=8192,
 
51
  cost_per_1k_input=0.000075,
52
  cost_per_1k_output=0.0003,
53
  ),
54
+ # Gemini 2.0 Series
55
+ "gemini-2.0-flash": ModelInfo(
56
+ id="gemini-2.0-flash",
57
+ name="Gemini 2.0 Flash",
58
+ provider="google",
59
+ context_window=1048576,
60
+ max_output_tokens=8192,
61
+ supports_functions=True,
62
+ supports_vision=True,
63
+ supports_streaming=True,
64
+ cost_per_1k_input=0.0,
65
+ cost_per_1k_output=0.0,
66
+ ),
67
+ "gemini-2.0-flash-lite": ModelInfo(
68
+ id="gemini-2.0-flash-lite",
69
+ name="Gemini 2.0 Flash Lite",
70
+ provider="google",
71
+ context_window=524288,
72
+ max_output_tokens=8192,
73
+ supports_functions=True,
74
+ supports_vision=True,
75
+ supports_streaming=True,
76
+ cost_per_1k_input=0.0,
77
+ cost_per_1k_output=0.0,
78
+ ),
79
+ # Gemini 3.0 Series (Preview)
80
+ "gemini-3-flash-preview": ModelInfo(
81
+ id="gemini-3-flash-preview",
82
+ name="Gemini 3 Flash Preview",
83
  provider="google",
84
  context_window=1048576,
85
  max_output_tokens=8192,
 
89
  cost_per_1k_input=0.0,
90
  cost_per_1k_output=0.0,
91
  ),
92
+ "gemini-3.1-flash-lite-preview": ModelInfo(
93
+ id="gemini-3.1-flash-lite-preview",
94
+ name="Gemini 3.1 Flash Lite Preview",
95
+ provider="google",
96
+ context_window=524288,
97
+ max_output_tokens=8192,
98
+ supports_functions=True,
99
+ supports_vision=True,
100
+ supports_streaming=True,
101
+ cost_per_1k_input=0.0,
102
+ cost_per_1k_output=0.0,
103
+ ),
104
+ # Gemini 1.5 Series (Stable)
105
+ "gemini-1.5-pro": ModelInfo(
106
+ id="gemini-1.5-pro",
107
+ name="Gemini 1.5 Pro",
108
+ provider="google",
109
+ context_window=2097152,
110
+ max_output_tokens=8192,
111
+ supports_functions=True,
112
+ supports_vision=True,
113
+ supports_streaming=True,
114
+ cost_per_1k_input=0.00125,
115
+ cost_per_1k_output=0.005,
116
+ ),
117
+ "gemini-1.5-flash": ModelInfo(
118
+ id="gemini-1.5-flash",
119
+ name="Gemini 1.5 Flash",
120
+ provider="google",
121
+ context_window=1048576,
122
+ max_output_tokens=8192,
123
+ supports_functions=True,
124
+ supports_vision=True,
125
+ supports_streaming=True,
126
+ cost_per_1k_input=0.000075,
127
+ cost_per_1k_output=0.0003,
128
+ ),
129
  "gemini-pro": ModelInfo(
130
  id="gemini-pro",
131
  name="Gemini Pro",
 
142
 
143
  # Aliases
144
  MODEL_ALIASES = {
145
+ "gemini-flash": "gemini-2.5-flash",
146
+ "gemini-pro-latest": "gemini-2.5-pro",
147
  "gemini-1.5": "gemini-1.5-pro",
148
  }
149
 
backend/app/models/providers/groq.py CHANGED
@@ -38,6 +38,18 @@ class GroqProvider(BaseProvider):
38
  cost_per_1k_input=0.00059,
39
  cost_per_1k_output=0.00079,
40
  ),
 
 
 
 
 
 
 
 
 
 
 
 
41
  "llama-3.1-70b-versatile": ModelInfo(
42
  id="llama-3.1-70b-versatile",
43
  name="Llama 3.1 70B Versatile",
@@ -98,6 +110,18 @@ class GroqProvider(BaseProvider):
98
  cost_per_1k_input=0.00024,
99
  cost_per_1k_output=0.00024,
100
  ),
 
 
 
 
 
 
 
 
 
 
 
 
101
  "gemma2-9b-it": ModelInfo(
102
  id="gemma2-9b-it",
103
  name="Gemma 2 9B IT",
 
38
  cost_per_1k_input=0.00059,
39
  cost_per_1k_output=0.00079,
40
  ),
41
+ "llama-3.2-90b-vision-preview": ModelInfo(
42
+ id="llama-3.2-90b-vision-preview",
43
+ name="Llama 3.2 90B Vision",
44
+ provider="groq",
45
+ context_window=128000,
46
+ max_output_tokens=8192,
47
+ supports_functions=True,
48
+ supports_vision=True,
49
+ supports_streaming=True,
50
+ cost_per_1k_input=0.0009,
51
+ cost_per_1k_output=0.0009,
52
+ ),
53
  "llama-3.1-70b-versatile": ModelInfo(
54
  id="llama-3.1-70b-versatile",
55
  name="Llama 3.1 70B Versatile",
 
110
  cost_per_1k_input=0.00024,
111
  cost_per_1k_output=0.00024,
112
  ),
113
+ "gemma2-9b-it": ModelInfo(
114
+ id="gemma2-9b-it",
115
+ name="Gemma 2 9B",
116
+ provider="groq",
117
+ context_window=8192,
118
+ max_output_tokens=8192,
119
+ supports_functions=True,
120
+ supports_vision=False,
121
+ supports_streaming=True,
122
+ cost_per_1k_input=0.0002,
123
+ cost_per_1k_output=0.0002,
124
+ ),
125
  "gemma2-9b-it": ModelInfo(
126
  id="gemma2-9b-it",
127
  name="Gemma 2 9B IT",
backend/app/models/providers/nvidia.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """NVIDIA AI provider implementation via OpenAI-compatible API."""
2
+
3
+ import json
4
+ import time
5
+ from typing import Any, AsyncIterator
6
+
7
+ import httpx
8
+
9
+ from app.models.providers.base import (
10
+ AuthenticationError,
11
+ BaseProvider,
12
+ CompletionResponse,
13
+ ModelInfo,
14
+ ModelNotFoundError,
15
+ ProviderError,
16
+ RateLimitError,
17
+ TokenUsage,
18
+ )
19
+
20
+
21
+ class NVIDIAProvider(BaseProvider):
22
+ """NVIDIA AI API provider supporting reasoning and code models."""
23
+
24
+ PROVIDER_NAME = "nvidia"
25
+ DEFAULT_BASE_URL = "https://integrate.api.nvidia.com/v1"
26
+
27
+ # Model definitions with configurations
28
+ MODELS = {
29
+ # Reasoning models
30
+ "step-3.5-flash": ModelInfo(
31
+ id="stepfun-ai/step-3.5-flash",
32
+ name="Step 3.5 Flash (Reasoning)",
33
+ provider="nvidia",
34
+ context_window=16384,
35
+ max_output_tokens=16384,
36
+ supports_functions=False,
37
+ supports_vision=False,
38
+ supports_streaming=True,
39
+ cost_per_1k_input=0.0, # Free tier
40
+ cost_per_1k_output=0.0,
41
+ ),
42
+ "glm4.7": ModelInfo(
43
+ id="z-ai/glm4.7",
44
+ name="GLM 4.7 (Reasoning)",
45
+ provider="nvidia",
46
+ context_window=16384,
47
+ max_output_tokens=16384,
48
+ supports_functions=False,
49
+ supports_vision=False,
50
+ supports_streaming=True,
51
+ cost_per_1k_input=0.0,
52
+ cost_per_1k_output=0.0,
53
+ ),
54
+ "deepseek-v3.2": ModelInfo(
55
+ id="deepseek-ai/deepseek-v3.2",
56
+ name="DeepSeek V3.2 (Reasoning)",
57
+ provider="nvidia",
58
+ context_window=8192,
59
+ max_output_tokens=8192,
60
+ supports_functions=False,
61
+ supports_vision=False,
62
+ supports_streaming=True,
63
+ cost_per_1k_input=0.0,
64
+ cost_per_1k_output=0.0,
65
+ ),
66
+ "deepseek-r1": ModelInfo(
67
+ id="deepseek-ai/deepseek-r1",
68
+ name="DeepSeek R1 (Reasoning)",
69
+ provider="nvidia",
70
+ context_window=16384,
71
+ max_output_tokens=16384,
72
+ supports_functions=False,
73
+ supports_vision=False,
74
+ supports_streaming=True,
75
+ cost_per_1k_input=0.0,
76
+ cost_per_1k_output=0.0,
77
+ ),
78
+ # Code models
79
+ "devstral-2-123b": ModelInfo(
80
+ id="mistralai/devstral-2-123b-instruct-2512",
81
+ name="Devstral 2 123B (Code)",
82
+ provider="nvidia",
83
+ context_window=8192,
84
+ max_output_tokens=8192,
85
+ supports_functions=False,
86
+ supports_vision=False,
87
+ supports_streaming=True,
88
+ cost_per_1k_input=0.0,
89
+ cost_per_1k_output=0.0,
90
+ ),
91
+ # General models
92
+ "llama-3.3-70b": ModelInfo(
93
+ id="meta/llama-3.3-70b-instruct",
94
+ name="Llama 3.3 70B",
95
+ provider="nvidia",
96
+ context_window=8192,
97
+ max_output_tokens=8192,
98
+ supports_functions=False,
99
+ supports_vision=False,
100
+ supports_streaming=True,
101
+ cost_per_1k_input=0.0,
102
+ cost_per_1k_output=0.0,
103
+ ),
104
+ "nemotron-70b": ModelInfo(
105
+ id="nvidia/llama-3.1-nemotron-70b-instruct",
106
+ name="Nemotron 70B",
107
+ provider="nvidia",
108
+ context_window=4096,
109
+ max_output_tokens=4096,
110
+ supports_functions=False,
111
+ supports_vision=False,
112
+ supports_streaming=True,
113
+ cost_per_1k_input=0.0,
114
+ cost_per_1k_output=0.0,
115
+ ),
116
+ }
117
+
118
+ # Reasoning model configs
119
+ REASONING_CONFIGS = {
120
+ "step-3.5-flash": {
121
+ "temperature": 1.0,
122
+ "top_p": 0.9,
123
+ },
124
+ "glm4.7": {
125
+ "temperature": 1.0,
126
+ "top_p": 1.0,
127
+ "extra_body": {"chat_template_kwargs": {"enable_thinking": True, "clear_thinking": False}},
128
+ },
129
+ "deepseek-v3.2": {
130
+ "temperature": 1.0,
131
+ "top_p": 0.95,
132
+ "extra_body": {"chat_template_kwargs": {"thinking": True}},
133
+ },
134
+ "deepseek-r1": {
135
+ "temperature": 0.6,
136
+ "top_p": 0.95,
137
+ },
138
+ }
139
+
140
+ def __init__(
141
+ self,
142
+ api_key: str | None = None,
143
+ base_url: str | None = None,
144
+ timeout: float = 60.0,
145
+ max_retries: int = 2,
146
+ ):
147
+ """
148
+ Initialize NVIDIA provider.
149
+
150
+ Args:
151
+ api_key: NVIDIA API key
152
+ base_url: Base URL for NVIDIA API (defaults to integrate.api.nvidia.com)
153
+ timeout: Request timeout in seconds
154
+ max_retries: Maximum number of retries for failed requests
155
+ """
156
+ super().__init__(api_key, base_url or self.DEFAULT_BASE_URL, timeout, max_retries)
157
+ self._last_request_time = 0.0
158
+
159
+ def _get_headers(self) -> dict[str, str]:
160
+ """Get headers for NVIDIA API requests."""
161
+ return {
162
+ "Authorization": f"Bearer {self.api_key}",
163
+ "Content-Type": "application/json",
164
+ }
165
+
166
+ async def _rate_limit(self) -> None:
167
+ """Apply rate limiting between requests."""
168
+ elapsed = time.time() - self._last_request_time
169
+ min_interval = 0.3 # 300ms between requests
170
+ if elapsed < min_interval:
171
+ import asyncio
172
+ await asyncio.sleep(min_interval - elapsed)
173
+ self._last_request_time = time.time()
174
+
175
+ async def complete(
176
+ self,
177
+ messages: list[dict[str, str]],
178
+ model: str = "devstral-2-123b",
179
+ temperature: float = 0.7,
180
+ max_tokens: int | None = None,
181
+ **kwargs: Any,
182
+ ) -> CompletionResponse:
183
+ """
184
+ Create a chat completion using NVIDIA models.
185
+
186
+ Args:
187
+ messages: List of message dictionaries with 'role' and 'content'
188
+ model: Model key (e.g., 'devstral-2-123b', 'llama-3.3-70b')
189
+ temperature: Sampling temperature
190
+ max_tokens: Maximum tokens to generate
191
+ **kwargs: Additional model-specific parameters
192
+
193
+ Returns:
194
+ CompletionResponse with generated text and metadata
195
+
196
+ Raises:
197
+ ModelNotFoundError: If model is not supported
198
+ AuthenticationError: If API key is invalid
199
+ RateLimitError: If rate limit is exceeded
200
+ ProviderError: For other API errors
201
+ """
202
+ # Validate model
203
+ if model not in self.MODELS:
204
+ raise ModelNotFoundError(f"Model {model} not found. Available: {list(self.MODELS.keys())}")
205
+
206
+ model_info = self.MODELS[model]
207
+ model_id = model_info.id
208
+
209
+ # Apply rate limiting
210
+ await self._rate_limit()
211
+
212
+ # Build request payload
213
+ payload: dict[str, Any] = {
214
+ "model": model_id,
215
+ "messages": messages,
216
+ "temperature": temperature,
217
+ "max_tokens": max_tokens or model_info.max_output_tokens,
218
+ }
219
+
220
+ # Add reasoning model configs if applicable
221
+ if model in self.REASONING_CONFIGS:
222
+ config = self.REASONING_CONFIGS[model]
223
+ if "extra_body" in config:
224
+ payload["extra_body"] = config["extra_body"]
225
+ if "top_p" in config:
226
+ payload["top_p"] = config["top_p"]
227
+
228
+ # Add any additional kwargs
229
+ payload.update(kwargs)
230
+
231
+ try:
232
+ async with httpx.AsyncClient(timeout=self.timeout) as client:
233
+ response = await client.post(
234
+ f"{self.base_url}/chat/completions",
235
+ headers=self._get_headers(),
236
+ json=payload,
237
+ )
238
+
239
+ if response.status_code == 401:
240
+ raise AuthenticationError("Invalid NVIDIA API key")
241
+ elif response.status_code == 429:
242
+ raise RateLimitError("NVIDIA API rate limit exceeded")
243
+ elif response.status_code >= 400:
244
+ error_detail = response.text
245
+ raise ProviderError(f"NVIDIA API error ({response.status_code}): {error_detail}")
246
+
247
+ data = response.json()
248
+
249
+ # Extract response
250
+ choice = data["choices"][0]
251
+ content = choice["message"]["content"]
252
+
253
+ # Extract usage
254
+ usage_data = data.get("usage", {})
255
+ usage = TokenUsage(
256
+ prompt_tokens=usage_data.get("prompt_tokens", 0),
257
+ completion_tokens=usage_data.get("completion_tokens", 0),
258
+ total_tokens=usage_data.get("total_tokens", 0),
259
+ )
260
+
261
+ return CompletionResponse(
262
+ content=content,
263
+ model=model,
264
+ provider=self.PROVIDER_NAME,
265
+ usage=usage,
266
+ finish_reason=choice.get("finish_reason", "stop"),
267
+ raw_response=data,
268
+ )
269
+
270
+ except (AuthenticationError, RateLimitError, ProviderError, ModelNotFoundError):
271
+ raise
272
+ except Exception as e:
273
+ raise ProviderError(f"NVIDIA request failed: {str(e)}") from e
274
+
275
+ async def complete_stream(
276
+ self,
277
+ messages: list[dict[str, str]],
278
+ model: str = "devstral-2-123b",
279
+ temperature: float = 0.7,
280
+ max_tokens: int | None = None,
281
+ **kwargs: Any,
282
+ ) -> AsyncIterator[str]:
283
+ """
284
+ Create a streaming chat completion.
285
+
286
+ Args:
287
+ messages: List of message dictionaries
288
+ model: Model key
289
+ temperature: Sampling temperature
290
+ max_tokens: Maximum tokens to generate
291
+ **kwargs: Additional parameters
292
+
293
+ Yields:
294
+ Content chunks as they arrive
295
+
296
+ Raises:
297
+ Same as complete()
298
+ """
299
+ if model not in self.MODELS:
300
+ raise ModelNotFoundError(f"Model {model} not found")
301
+
302
+ model_info = self.MODELS[model]
303
+ model_id = model_info.id
304
+
305
+ await self._rate_limit()
306
+
307
+ payload: dict[str, Any] = {
308
+ "model": model_id,
309
+ "messages": messages,
310
+ "temperature": temperature,
311
+ "max_tokens": max_tokens or model_info.max_output_tokens,
312
+ "stream": True,
313
+ }
314
+
315
+ if model in self.REASONING_CONFIGS:
316
+ config = self.REASONING_CONFIGS[model]
317
+ if "extra_body" in config:
318
+ payload["extra_body"] = config["extra_body"]
319
+ if "top_p" in config:
320
+ payload["top_p"] = config["top_p"]
321
+
322
+ payload.update(kwargs)
323
+
324
+ try:
325
+ async with httpx.AsyncClient(timeout=self.timeout) as client:
326
+ async with client.stream(
327
+ "POST",
328
+ f"{self.base_url}/chat/completions",
329
+ headers=self._get_headers(),
330
+ json=payload,
331
+ ) as response:
332
+ if response.status_code == 401:
333
+ raise AuthenticationError("Invalid NVIDIA API key")
334
+ elif response.status_code == 429:
335
+ raise RateLimitError("NVIDIA API rate limit exceeded")
336
+ elif response.status_code >= 400:
337
+ error_detail = await response.aread()
338
+ raise ProviderError(f"NVIDIA API error: {error_detail.decode()}")
339
+
340
+ async for line in response.aiter_lines():
341
+ if not line.strip() or not line.startswith("data: "):
342
+ continue
343
+
344
+ data_str = line[6:] # Remove 'data: ' prefix
345
+ if data_str == "[DONE]":
346
+ break
347
+
348
+ try:
349
+ data = json.loads(data_str)
350
+ if "choices" in data and data["choices"]:
351
+ delta = data["choices"][0].get("delta", {})
352
+ content = delta.get("content")
353
+ if content:
354
+ yield content
355
+ except json.JSONDecodeError:
356
+ continue
357
+
358
+ except (AuthenticationError, RateLimitError, ProviderError, ModelNotFoundError):
359
+ raise
360
+ except Exception as e:
361
+ raise ProviderError(f"NVIDIA streaming failed: {str(e)}") from e
362
+
363
+ def list_models(self) -> list[ModelInfo]:
364
+ """List all available NVIDIA models."""
365
+ return list(self.MODELS.values())
366
+
367
+ def get_model_info(self, model: str) -> ModelInfo:
368
+ """Get information about a specific model."""
369
+ if model not in self.MODELS:
370
+ raise ModelNotFoundError(f"Model {model} not found")
371
+ return self.MODELS[model]