NeerajCodz commited on
Commit
101ad87
·
1 Parent(s): f5ba363

feat: update model router with NVIDIA support and latest model scores

Browse files
backend/app/main.py CHANGED
@@ -62,7 +62,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
62
  await _memory_manager.initialize()
63
 
64
  logger.info("Initializing model router...")
65
- _model_router = SmartModelRouter(settings)
 
 
 
 
 
 
66
  await _model_router.initialize()
67
 
68
  logger.info("Initializing tool registry...")
 
62
  await _memory_manager.initialize()
63
 
64
  logger.info("Initializing model router...")
65
+ _model_router = SmartModelRouter(
66
+ openai_api_key=settings.openai_api_key,
67
+ anthropic_api_key=settings.anthropic_api_key,
68
+ google_api_key=settings.google_api_key,
69
+ groq_api_key=settings.groq_api_key,
70
+ nvidia_api_key=settings.nvidia_api_key,
71
+ )
72
  await _model_router.initialize()
73
 
74
  logger.info("Initializing tool registry...")
backend/app/models/providers/__init__.py CHANGED
@@ -13,6 +13,7 @@ from app.models.providers.openai import OpenAIProvider
13
  from app.models.providers.anthropic import AnthropicProvider
14
  from app.models.providers.google import GoogleProvider
15
  from app.models.providers.groq import GroqProvider
 
16
 
17
  __all__ = [
18
  # Base
@@ -28,4 +29,5 @@ __all__ = [
28
  "AnthropicProvider",
29
  "GoogleProvider",
30
  "GroqProvider",
 
31
  ]
 
13
  from app.models.providers.anthropic import AnthropicProvider
14
  from app.models.providers.google import GoogleProvider
15
  from app.models.providers.groq import GroqProvider
16
+ from app.models.providers.nvidia import NVIDIAProvider
17
 
18
  __all__ = [
19
  # Base
 
29
  "AnthropicProvider",
30
  "GoogleProvider",
31
  "GroqProvider",
32
+ "NVIDIAProvider",
33
  ]
backend/app/models/router.py CHANGED
@@ -22,6 +22,7 @@ from app.models.providers.openai import OpenAIProvider
22
  from app.models.providers.anthropic import AnthropicProvider
23
  from app.models.providers.google import GoogleProvider
24
  from app.models.providers.groq import GroqProvider
 
25
 
26
  logger = logging.getLogger(__name__)
27
 
@@ -60,14 +61,14 @@ class RoutingConfig:
60
 
61
  # Task-specific model preferences
62
  task_preferences: dict[TaskType, list[str]] = field(default_factory=lambda: {
63
- TaskType.GENERAL: ["gpt-4o", "claude-3-5-sonnet-20241022", "gemini-1.5-pro"],
64
- TaskType.CODE: ["claude-3-5-sonnet-20241022", "gpt-4o", "gemini-1.5-pro"],
65
- TaskType.REASONING: ["claude-3-opus-20240229", "gpt-4o", "gemini-1.5-pro"],
66
- TaskType.EXTRACTION: ["gpt-4o-mini", "claude-3-haiku-20240307", "gemini-1.5-flash"],
67
- TaskType.SUMMARIZATION: ["gpt-4o-mini", "claude-3-5-haiku-20241022", "gemini-1.5-flash"],
68
  TaskType.CLASSIFICATION: ["gpt-4o-mini", "claude-3-haiku-20240307", "llama-3.1-8b-instant"],
69
- TaskType.CREATIVE: ["claude-3-5-sonnet-20241022", "gpt-4o", "gemini-1.5-pro"],
70
- TaskType.FAST: ["llama-3.1-8b-instant", "gemini-1.5-flash", "gpt-4o-mini"],
71
  })
72
 
73
 
@@ -143,19 +144,35 @@ class SmartModelRouter:
143
  "claude-3-sonnet-20240229": 0.88,
144
  "claude-3-5-haiku-20241022": 0.82,
145
  "claude-3-haiku-20240307": 0.75,
146
- # Google
 
 
 
 
 
 
 
 
147
  "gemini-1.5-pro": 0.91,
148
- "gemini-2.0-flash-exp": 0.88,
149
  "gemini-1.5-flash": 0.78,
150
  "gemini-pro": 0.75,
151
  # Groq
152
  "llama-3.3-70b-versatile": 0.85,
 
153
  "llama-3.1-70b-versatile": 0.84,
154
  "llama3-70b-8192": 0.82,
155
  "mixtral-8x7b-32768": 0.78,
156
  "llama-3.1-8b-instant": 0.65,
157
  "llama3-8b-8192": 0.60,
158
  "gemma2-9b-it": 0.62,
 
 
 
 
 
 
 
 
159
  }
160
 
161
  # Model speed rankings (relative, based on typical latency)
@@ -168,9 +185,22 @@ class SmartModelRouter:
168
  "llama3-70b-8192": 0.92,
169
  "llama-3.1-70b-versatile": 0.91,
170
  "llama-3.3-70b-versatile": 0.90,
171
- # Google Flash is fast
 
 
 
 
172
  "gemini-1.5-flash": 0.88,
173
- "gemini-2.0-flash-exp": 0.87,
 
 
 
 
 
 
 
 
 
174
  # Mini models
175
  "gpt-4o-mini": 0.85,
176
  "claude-3-haiku-20240307": 0.84,
@@ -178,6 +208,7 @@ class SmartModelRouter:
178
  "gpt-3.5-turbo": 0.82,
179
  # Pro models
180
  "gemini-pro": 0.75,
 
181
  "gemini-1.5-pro": 0.70,
182
  "gpt-4o": 0.68,
183
  "claude-3-5-sonnet-20241022": 0.65,
@@ -193,6 +224,7 @@ class SmartModelRouter:
193
  anthropic_api_key: str | SecretStr | None = None,
194
  google_api_key: str | SecretStr | None = None,
195
  groq_api_key: str | SecretStr | None = None,
 
196
  config: RoutingConfig | None = None,
197
  ):
198
  self.config = config or RoutingConfig()
@@ -207,6 +239,7 @@ class SmartModelRouter:
207
  "anthropic": self._get_key_value(anthropic_api_key),
208
  "google": self._get_key_value(google_api_key),
209
  "groq": self._get_key_value(groq_api_key),
 
210
  }
211
 
212
  @staticmethod
@@ -248,6 +281,12 @@ class SmartModelRouter:
248
  self.providers["groq"] = provider
249
  logger.info("Initialized Groq provider")
250
 
 
 
 
 
 
 
251
  if not self.providers:
252
  logger.warning("No LLM providers configured")
253
 
 
22
  from app.models.providers.anthropic import AnthropicProvider
23
  from app.models.providers.google import GoogleProvider
24
  from app.models.providers.groq import GroqProvider
25
+ from app.models.providers.nvidia import NVIDIAProvider
26
 
27
  logger = logging.getLogger(__name__)
28
 
 
61
 
62
  # Task-specific model preferences
63
  task_preferences: dict[TaskType, list[str]] = field(default_factory=lambda: {
64
+ TaskType.GENERAL: ["gpt-4o", "claude-3-5-sonnet-20241022", "gemini-2.5-pro", "deepseek-r1"],
65
+ TaskType.CODE: ["claude-3-5-sonnet-20241022", "gpt-4o", "devstral-2-123b", "gemini-2.5-pro"],
66
+ TaskType.REASONING: ["claude-3-opus-20240229", "deepseek-r1", "gpt-4o", "step-3.5-flash"],
67
+ TaskType.EXTRACTION: ["gpt-4o-mini", "claude-3-haiku-20240307", "gemini-2.5-flash"],
68
+ TaskType.SUMMARIZATION: ["gpt-4o-mini", "claude-3-5-haiku-20241022", "gemini-2.5-flash"],
69
  TaskType.CLASSIFICATION: ["gpt-4o-mini", "claude-3-haiku-20240307", "llama-3.1-8b-instant"],
70
+ TaskType.CREATIVE: ["claude-3-5-sonnet-20241022", "gpt-4o", "gemini-2.5-pro"],
71
+ TaskType.FAST: ["llama-3.1-8b-instant", "gemini-2.5-flash", "gpt-4o-mini"],
72
  })
73
 
74
 
 
144
  "claude-3-sonnet-20240229": 0.88,
145
  "claude-3-5-haiku-20241022": 0.82,
146
  "claude-3-haiku-20240307": 0.75,
147
+ # Google Gemini 2.5 & 3.0
148
+ "gemini-2.5-pro": 0.93,
149
+ "gemini-2.5-flash": 0.85,
150
+ "gemini-3-flash-preview": 0.87,
151
+ "gemini-3.1-flash-lite-preview": 0.82,
152
+ # Google Gemini 2.0
153
+ "gemini-2.0-flash": 0.88,
154
+ "gemini-2.0-flash-lite": 0.80,
155
+ # Google Gemini 1.5
156
  "gemini-1.5-pro": 0.91,
 
157
  "gemini-1.5-flash": 0.78,
158
  "gemini-pro": 0.75,
159
  # Groq
160
  "llama-3.3-70b-versatile": 0.85,
161
+ "llama-3.2-90b-vision-preview": 0.84,
162
  "llama-3.1-70b-versatile": 0.84,
163
  "llama3-70b-8192": 0.82,
164
  "mixtral-8x7b-32768": 0.78,
165
  "llama-3.1-8b-instant": 0.65,
166
  "llama3-8b-8192": 0.60,
167
  "gemma2-9b-it": 0.62,
168
+ # NVIDIA
169
+ "deepseek-r1": 0.92,
170
+ "deepseek-v3.2": 0.90,
171
+ "step-3.5-flash": 0.88,
172
+ "glm4.7": 0.87,
173
+ "devstral-2-123b": 0.86,
174
+ "llama-3.3-70b": 0.85,
175
+ "nemotron-70b": 0.83,
176
  }
177
 
178
  # Model speed rankings (relative, based on typical latency)
 
185
  "llama3-70b-8192": 0.92,
186
  "llama-3.1-70b-versatile": 0.91,
187
  "llama-3.3-70b-versatile": 0.90,
188
+ "llama-3.2-90b-vision-preview": 0.89,
189
+ # Google Flash models
190
+ "gemini-2.5-flash": 0.90,
191
+ "gemini-3-flash-preview": 0.89,
192
+ "gemini-2.0-flash": 0.88,
193
  "gemini-1.5-flash": 0.88,
194
+ "gemini-2.0-flash-lite": 0.87,
195
+ "gemini-3.1-flash-lite-preview": 0.86,
196
+ # NVIDIA models
197
+ "step-3.5-flash": 0.85,
198
+ "devstral-2-123b": 0.84,
199
+ "llama-3.3-70b": 0.83,
200
+ "nemotron-70b": 0.82,
201
+ "glm4.7": 0.81,
202
+ "deepseek-v3.2": 0.80,
203
+ "deepseek-r1": 0.79,
204
  # Mini models
205
  "gpt-4o-mini": 0.85,
206
  "claude-3-haiku-20240307": 0.84,
 
208
  "gpt-3.5-turbo": 0.82,
209
  # Pro models
210
  "gemini-pro": 0.75,
211
+ "gemini-2.5-pro": 0.72,
212
  "gemini-1.5-pro": 0.70,
213
  "gpt-4o": 0.68,
214
  "claude-3-5-sonnet-20241022": 0.65,
 
224
  anthropic_api_key: str | SecretStr | None = None,
225
  google_api_key: str | SecretStr | None = None,
226
  groq_api_key: str | SecretStr | None = None,
227
+ nvidia_api_key: str | SecretStr | None = None,
228
  config: RoutingConfig | None = None,
229
  ):
230
  self.config = config or RoutingConfig()
 
239
  "anthropic": self._get_key_value(anthropic_api_key),
240
  "google": self._get_key_value(google_api_key),
241
  "groq": self._get_key_value(groq_api_key),
242
+ "nvidia": self._get_key_value(nvidia_api_key),
243
  }
244
 
245
  @staticmethod
 
281
  self.providers["groq"] = provider
282
  logger.info("Initialized Groq provider")
283
 
284
+ if self._api_keys["nvidia"]:
285
+ provider = NVIDIAProvider(api_key=self._api_keys["nvidia"])
286
+ await provider.initialize()
287
+ self.providers["nvidia"] = provider
288
+ logger.info("Initialized NVIDIA provider")
289
+
290
  if not self.providers:
291
  logger.warning("No LLM providers configured")
292