Spaces:
Sleeping
Sleeping
Commit ·
101ad87
1
Parent(s): f5ba363
feat: update model router with NVIDIA support and latest model scores
Browse files
- backend/app/main.py +7 -1
- backend/app/models/providers/__init__.py +2 -0
- backend/app/models/router.py +50 -11
backend/app/main.py
CHANGED
|
@@ -62,7 +62,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|
| 62 |
await _memory_manager.initialize()
|
| 63 |
|
| 64 |
logger.info("Initializing model router...")
|
| 65 |
-
_model_router = SmartModelRouter(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
await _model_router.initialize()
|
| 67 |
|
| 68 |
logger.info("Initializing tool registry...")
|
|
|
|
| 62 |
await _memory_manager.initialize()
|
| 63 |
|
| 64 |
logger.info("Initializing model router...")
|
| 65 |
+
_model_router = SmartModelRouter(
|
| 66 |
+
openai_api_key=settings.openai_api_key,
|
| 67 |
+
anthropic_api_key=settings.anthropic_api_key,
|
| 68 |
+
google_api_key=settings.google_api_key,
|
| 69 |
+
groq_api_key=settings.groq_api_key,
|
| 70 |
+
nvidia_api_key=settings.nvidia_api_key,
|
| 71 |
+
)
|
| 72 |
await _model_router.initialize()
|
| 73 |
|
| 74 |
logger.info("Initializing tool registry...")
|
backend/app/models/providers/__init__.py
CHANGED
|
@@ -13,6 +13,7 @@ from app.models.providers.openai import OpenAIProvider
|
|
| 13 |
from app.models.providers.anthropic import AnthropicProvider
|
| 14 |
from app.models.providers.google import GoogleProvider
|
| 15 |
from app.models.providers.groq import GroqProvider
|
|
|
|
| 16 |
|
| 17 |
__all__ = [
|
| 18 |
# Base
|
|
@@ -28,4 +29,5 @@ __all__ = [
|
|
| 28 |
"AnthropicProvider",
|
| 29 |
"GoogleProvider",
|
| 30 |
"GroqProvider",
|
|
|
|
| 31 |
]
|
|
|
|
| 13 |
from app.models.providers.anthropic import AnthropicProvider
|
| 14 |
from app.models.providers.google import GoogleProvider
|
| 15 |
from app.models.providers.groq import GroqProvider
|
| 16 |
+
from app.models.providers.nvidia import NVIDIAProvider
|
| 17 |
|
| 18 |
__all__ = [
|
| 19 |
# Base
|
|
|
|
| 29 |
"AnthropicProvider",
|
| 30 |
"GoogleProvider",
|
| 31 |
"GroqProvider",
|
| 32 |
+
"NVIDIAProvider",
|
| 33 |
]
|
backend/app/models/router.py
CHANGED
|
@@ -22,6 +22,7 @@ from app.models.providers.openai import OpenAIProvider
|
|
| 22 |
from app.models.providers.anthropic import AnthropicProvider
|
| 23 |
from app.models.providers.google import GoogleProvider
|
| 24 |
from app.models.providers.groq import GroqProvider
|
|
|
|
| 25 |
|
| 26 |
logger = logging.getLogger(__name__)
|
| 27 |
|
|
@@ -60,14 +61,14 @@ class RoutingConfig:
|
|
| 60 |
|
| 61 |
# Task-specific model preferences
|
| 62 |
task_preferences: dict[TaskType, list[str]] = field(default_factory=lambda: {
|
| 63 |
-
TaskType.GENERAL: ["gpt-4o", "claude-3-5-sonnet-20241022", "gemini-
|
| 64 |
-
TaskType.CODE: ["claude-3-5-sonnet-20241022", "gpt-4o", "gemini-
|
| 65 |
-
TaskType.REASONING: ["claude-3-opus-20240229", "gpt-4o", "
|
| 66 |
-
TaskType.EXTRACTION: ["gpt-4o-mini", "claude-3-haiku-20240307", "gemini-
|
| 67 |
-
TaskType.SUMMARIZATION: ["gpt-4o-mini", "claude-3-5-haiku-20241022", "gemini-
|
| 68 |
TaskType.CLASSIFICATION: ["gpt-4o-mini", "claude-3-haiku-20240307", "llama-3.1-8b-instant"],
|
| 69 |
-
TaskType.CREATIVE: ["claude-3-5-sonnet-20241022", "gpt-4o", "gemini-
|
| 70 |
-
TaskType.FAST: ["llama-3.1-8b-instant", "gemini-
|
| 71 |
})
|
| 72 |
|
| 73 |
|
|
@@ -143,19 +144,35 @@ class SmartModelRouter:
|
|
| 143 |
"claude-3-sonnet-20240229": 0.88,
|
| 144 |
"claude-3-5-haiku-20241022": 0.82,
|
| 145 |
"claude-3-haiku-20240307": 0.75,
|
| 146 |
-
# Google
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
"gemini-1.5-pro": 0.91,
|
| 148 |
-
"gemini-2.0-flash-exp": 0.88,
|
| 149 |
"gemini-1.5-flash": 0.78,
|
| 150 |
"gemini-pro": 0.75,
|
| 151 |
# Groq
|
| 152 |
"llama-3.3-70b-versatile": 0.85,
|
|
|
|
| 153 |
"llama-3.1-70b-versatile": 0.84,
|
| 154 |
"llama3-70b-8192": 0.82,
|
| 155 |
"mixtral-8x7b-32768": 0.78,
|
| 156 |
"llama-3.1-8b-instant": 0.65,
|
| 157 |
"llama3-8b-8192": 0.60,
|
| 158 |
"gemma2-9b-it": 0.62,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
}
|
| 160 |
|
| 161 |
# Model speed rankings (relative, based on typical latency)
|
|
@@ -168,9 +185,22 @@ class SmartModelRouter:
|
|
| 168 |
"llama3-70b-8192": 0.92,
|
| 169 |
"llama-3.1-70b-versatile": 0.91,
|
| 170 |
"llama-3.3-70b-versatile": 0.90,
|
| 171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
"gemini-1.5-flash": 0.88,
|
| 173 |
-
"gemini-2.0-flash-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
# Mini models
|
| 175 |
"gpt-4o-mini": 0.85,
|
| 176 |
"claude-3-haiku-20240307": 0.84,
|
|
@@ -178,6 +208,7 @@ class SmartModelRouter:
|
|
| 178 |
"gpt-3.5-turbo": 0.82,
|
| 179 |
# Pro models
|
| 180 |
"gemini-pro": 0.75,
|
|
|
|
| 181 |
"gemini-1.5-pro": 0.70,
|
| 182 |
"gpt-4o": 0.68,
|
| 183 |
"claude-3-5-sonnet-20241022": 0.65,
|
|
@@ -193,6 +224,7 @@ class SmartModelRouter:
|
|
| 193 |
anthropic_api_key: str | SecretStr | None = None,
|
| 194 |
google_api_key: str | SecretStr | None = None,
|
| 195 |
groq_api_key: str | SecretStr | None = None,
|
|
|
|
| 196 |
config: RoutingConfig | None = None,
|
| 197 |
):
|
| 198 |
self.config = config or RoutingConfig()
|
|
@@ -207,6 +239,7 @@ class SmartModelRouter:
|
|
| 207 |
"anthropic": self._get_key_value(anthropic_api_key),
|
| 208 |
"google": self._get_key_value(google_api_key),
|
| 209 |
"groq": self._get_key_value(groq_api_key),
|
|
|
|
| 210 |
}
|
| 211 |
|
| 212 |
@staticmethod
|
|
@@ -248,6 +281,12 @@ class SmartModelRouter:
|
|
| 248 |
self.providers["groq"] = provider
|
| 249 |
logger.info("Initialized Groq provider")
|
| 250 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
if not self.providers:
|
| 252 |
logger.warning("No LLM providers configured")
|
| 253 |
|
|
|
|
| 22 |
from app.models.providers.anthropic import AnthropicProvider
|
| 23 |
from app.models.providers.google import GoogleProvider
|
| 24 |
from app.models.providers.groq import GroqProvider
|
| 25 |
+
from app.models.providers.nvidia import NVIDIAProvider
|
| 26 |
|
| 27 |
logger = logging.getLogger(__name__)
|
| 28 |
|
|
|
|
| 61 |
|
| 62 |
# Task-specific model preferences
|
| 63 |
task_preferences: dict[TaskType, list[str]] = field(default_factory=lambda: {
|
| 64 |
+
TaskType.GENERAL: ["gpt-4o", "claude-3-5-sonnet-20241022", "gemini-2.5-pro", "deepseek-r1"],
|
| 65 |
+
TaskType.CODE: ["claude-3-5-sonnet-20241022", "gpt-4o", "devstral-2-123b", "gemini-2.5-pro"],
|
| 66 |
+
TaskType.REASONING: ["claude-3-opus-20240229", "deepseek-r1", "gpt-4o", "step-3.5-flash"],
|
| 67 |
+
TaskType.EXTRACTION: ["gpt-4o-mini", "claude-3-haiku-20240307", "gemini-2.5-flash"],
|
| 68 |
+
TaskType.SUMMARIZATION: ["gpt-4o-mini", "claude-3-5-haiku-20241022", "gemini-2.5-flash"],
|
| 69 |
TaskType.CLASSIFICATION: ["gpt-4o-mini", "claude-3-haiku-20240307", "llama-3.1-8b-instant"],
|
| 70 |
+
TaskType.CREATIVE: ["claude-3-5-sonnet-20241022", "gpt-4o", "gemini-2.5-pro"],
|
| 71 |
+
TaskType.FAST: ["llama-3.1-8b-instant", "gemini-2.5-flash", "gpt-4o-mini"],
|
| 72 |
})
|
| 73 |
|
| 74 |
|
|
|
|
| 144 |
"claude-3-sonnet-20240229": 0.88,
|
| 145 |
"claude-3-5-haiku-20241022": 0.82,
|
| 146 |
"claude-3-haiku-20240307": 0.75,
|
| 147 |
+
# Google Gemini 2.5 & 3.0
|
| 148 |
+
"gemini-2.5-pro": 0.93,
|
| 149 |
+
"gemini-2.5-flash": 0.85,
|
| 150 |
+
"gemini-3-flash-preview": 0.87,
|
| 151 |
+
"gemini-3.1-flash-lite-preview": 0.82,
|
| 152 |
+
# Google Gemini 2.0
|
| 153 |
+
"gemini-2.0-flash": 0.88,
|
| 154 |
+
"gemini-2.0-flash-lite": 0.80,
|
| 155 |
+
# Google Gemini 1.5
|
| 156 |
"gemini-1.5-pro": 0.91,
|
|
|
|
| 157 |
"gemini-1.5-flash": 0.78,
|
| 158 |
"gemini-pro": 0.75,
|
| 159 |
# Groq
|
| 160 |
"llama-3.3-70b-versatile": 0.85,
|
| 161 |
+
"llama-3.2-90b-vision-preview": 0.84,
|
| 162 |
"llama-3.1-70b-versatile": 0.84,
|
| 163 |
"llama3-70b-8192": 0.82,
|
| 164 |
"mixtral-8x7b-32768": 0.78,
|
| 165 |
"llama-3.1-8b-instant": 0.65,
|
| 166 |
"llama3-8b-8192": 0.60,
|
| 167 |
"gemma2-9b-it": 0.62,
|
| 168 |
+
# NVIDIA
|
| 169 |
+
"deepseek-r1": 0.92,
|
| 170 |
+
"deepseek-v3.2": 0.90,
|
| 171 |
+
"step-3.5-flash": 0.88,
|
| 172 |
+
"glm4.7": 0.87,
|
| 173 |
+
"devstral-2-123b": 0.86,
|
| 174 |
+
"llama-3.3-70b": 0.85,
|
| 175 |
+
"nemotron-70b": 0.83,
|
| 176 |
}
|
| 177 |
|
| 178 |
# Model speed rankings (relative, based on typical latency)
|
|
|
|
| 185 |
"llama3-70b-8192": 0.92,
|
| 186 |
"llama-3.1-70b-versatile": 0.91,
|
| 187 |
"llama-3.3-70b-versatile": 0.90,
|
| 188 |
+
"llama-3.2-90b-vision-preview": 0.89,
|
| 189 |
+
# Google Flash models
|
| 190 |
+
"gemini-2.5-flash": 0.90,
|
| 191 |
+
"gemini-3-flash-preview": 0.89,
|
| 192 |
+
"gemini-2.0-flash": 0.88,
|
| 193 |
"gemini-1.5-flash": 0.88,
|
| 194 |
+
"gemini-2.0-flash-lite": 0.87,
|
| 195 |
+
"gemini-3.1-flash-lite-preview": 0.86,
|
| 196 |
+
# NVIDIA models
|
| 197 |
+
"step-3.5-flash": 0.85,
|
| 198 |
+
"devstral-2-123b": 0.84,
|
| 199 |
+
"llama-3.3-70b": 0.83,
|
| 200 |
+
"nemotron-70b": 0.82,
|
| 201 |
+
"glm4.7": 0.81,
|
| 202 |
+
"deepseek-v3.2": 0.80,
|
| 203 |
+
"deepseek-r1": 0.79,
|
| 204 |
# Mini models
|
| 205 |
"gpt-4o-mini": 0.85,
|
| 206 |
"claude-3-haiku-20240307": 0.84,
|
|
|
|
| 208 |
"gpt-3.5-turbo": 0.82,
|
| 209 |
# Pro models
|
| 210 |
"gemini-pro": 0.75,
|
| 211 |
+
"gemini-2.5-pro": 0.72,
|
| 212 |
"gemini-1.5-pro": 0.70,
|
| 213 |
"gpt-4o": 0.68,
|
| 214 |
"claude-3-5-sonnet-20241022": 0.65,
|
|
|
|
| 224 |
anthropic_api_key: str | SecretStr | None = None,
|
| 225 |
google_api_key: str | SecretStr | None = None,
|
| 226 |
groq_api_key: str | SecretStr | None = None,
|
| 227 |
+
nvidia_api_key: str | SecretStr | None = None,
|
| 228 |
config: RoutingConfig | None = None,
|
| 229 |
):
|
| 230 |
self.config = config or RoutingConfig()
|
|
|
|
| 239 |
"anthropic": self._get_key_value(anthropic_api_key),
|
| 240 |
"google": self._get_key_value(google_api_key),
|
| 241 |
"groq": self._get_key_value(groq_api_key),
|
| 242 |
+
"nvidia": self._get_key_value(nvidia_api_key),
|
| 243 |
}
|
| 244 |
|
| 245 |
@staticmethod
|
|
|
|
| 281 |
self.providers["groq"] = provider
|
| 282 |
logger.info("Initialized Groq provider")
|
| 283 |
|
| 284 |
+
if self._api_keys["nvidia"]:
|
| 285 |
+
provider = NVIDIAProvider(api_key=self._api_keys["nvidia"])
|
| 286 |
+
await provider.initialize()
|
| 287 |
+
self.providers["nvidia"] = provider
|
| 288 |
+
logger.info("Initialized NVIDIA provider")
|
| 289 |
+
|
| 290 |
if not self.providers:
|
| 291 |
logger.warning("No LLM providers configured")
|
| 292 |
|