| """Multi-model client supporting GPT-5, GPT-5.1, Gemini models, and Claude 4.5 Sonnet.""" | |
| import os | |
| from typing import Optional, List, Dict, Any | |
| from openai import OpenAI | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
class MultiModelClient:
    """Unified client for multiple AI model providers."""

    MODELS = {
        "gpt-5": {
            "provider": "openrouter",
            "model_id": "openai/gpt-5",
            "display_name": "GPT-5",
        },
        "gpt-5.1": {
            "provider": "openrouter",
            "model_id": "openai/gpt-5.1",
            "display_name": "GPT-5.1",
        },
        "gemini-2.5-pro": {
            "provider": "openrouter",
            "model_id": "google/gemini-2.5-pro",
            "display_name": "Gemini 2.5 Pro",
        },
        "gemini-3-pro-preview": {
            "provider": "openrouter",
            "model_id": "google/gemini-3-pro-preview",
            "display_name": "Gemini 3 Pro Preview",
        },
        "claude-4.5-sonnet": {
            "provider": "openrouter",
            "model_id": "anthropic/claude-sonnet-4.5",
            "display_name": "Claude 4.5 Sonnet",
        },
        "claude-4.5-opus": {
            "provider": "openrouter",
            "model_id": "anthropic/claude-opus-4.5",
            "display_name": "Claude 4.5 Opus",
        },
        "gpt-4.1-mini": {
            "provider": "openrouter",
            "model_id": "openai/gpt-4.1-mini",
            "display_name": "GPT-4.1 Mini (make-it-heavy default)",
        },
        "gemini-2.0-flash": {
            "provider": "openrouter",
            "model_id": "google/gemini-2.0-flash-001",
            "display_name": "Gemini 2.0 Flash (fast)",
        },
        "llama-3.1-70b": {
            "provider": "openrouter",
            "model_id": "meta-llama/llama-3.1-70b",
            "display_name": "Llama 3.1 70B (open source)",
        },
    }

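    # Keys above are the names callers pass to chat(); each model_id is the
    # corresponding OpenRouter route string. To expose another model, add an
    # entry with provider "openrouter" and its OpenRouter model ID.
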
    def __init__(
        self,
        openrouter_api_key: Optional[str] = None,
        google_api_key: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 4000,
    ):
        """Initialize the multi-model client.

        Args:
            openrouter_api_key: OpenRouter API key (for all OpenRouter-hosted models)
            google_api_key: Google API key (optional, for direct Gemini API access)
            temperature: Default sampling temperature
            max_tokens: Default maximum tokens
        """
        self.openrouter_api_key = openrouter_api_key or os.getenv("OPENROUTER_API_KEY")
        self.google_api_key = google_api_key or os.getenv("GOOGLE_API_KEY")
        self.temperature = temperature
        self.max_tokens = max_tokens

        # Initialize the OpenRouter client (handles all OpenRouter-hosted models).
        if self.openrouter_api_key:
            self.openrouter_client = OpenAI(
                base_url="https://openrouter.ai/api/v1",
                api_key=self.openrouter_api_key,
            )
        else:
            self.openrouter_client = None

        # The Google client is optional; only load the SDK if a key is provided.
        self._google_available = False
        if self.google_api_key:
            try:
                import google.generativeai as genai  # type: ignore

                genai.configure(api_key=self.google_api_key)
                self._google_available = True
            except ImportError:
                # Library not installed; direct Gemini access will be unavailable.
                self._google_available = False

    def chat(
        self,
        messages: List[Dict[str, str]],
        model: str = "claude-4.5-sonnet",
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> str:
        """Send a chat completion request to the specified model.

        Args:
            messages: List of message dicts with 'role' and 'content'
            model: Any key in MODELS (e.g. "gpt-5", "gemini-2.5-pro",
                "claude-4.5-sonnet")
            temperature: Override the default temperature
            max_tokens: Override the default max tokens

        Returns:
            Model response content
        """
        if model not in self.MODELS:
            raise ValueError(f"Unknown model: {model}. Available: {list(self.MODELS.keys())}")

        model_info = self.MODELS[model]
        provider = model_info["provider"]
        temp = temperature if temperature is not None else self.temperature
        max_tok = max_tokens if max_tokens is not None else self.max_tokens

        # All models now route through OpenRouter; the legacy provider names
        # are still accepted for backward compatibility.
        if provider in ["openai", "openrouter", "google"]:
            return self._chat_openrouter(messages, model_info["model_id"], temp, max_tok)
        else:
            raise ValueError(f"Unknown provider: {provider}")

    def _chat_openrouter(
        self,
        messages: List[Dict[str, str]],
        model_id: str,
        temperature: float,
        max_tokens: int,
    ) -> str:
        """Chat via OpenRouter (used for every model in MODELS)."""
        if not self.openrouter_client:
            raise ValueError("OpenRouter API key not configured")
        try:
            response = self.openrouter_client.chat.completions.create(
                model=model_id,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            return response.choices[0].message.content
        except Exception as e:
            raise RuntimeError(f"OpenRouter API error: {e}") from e

    def _chat_google(
        self,
        messages: List[Dict[str, str]],
        model_id: str,
        temperature: float,
        max_tokens: int,
    ) -> str:
        """Chat using the Google Gemini API directly (bypassing OpenRouter)."""
        if not self.google_api_key:
            raise ValueError("Google API key not configured")
        try:
            import google.generativeai as genai  # type: ignore
            from google.generativeai import types as genai_types  # type: ignore
        except ImportError:
            raise ImportError(
                "google-generativeai is required for direct Gemini access. "
                "Install it or use OpenRouter-hosted models instead."
            )
        try:
            genai.configure(api_key=self.google_api_key)
            model = genai.GenerativeModel(model_id)

            # Convert OpenAI-style messages to the Gemini format: the system
            # prompt is extracted and assistant turns become "model" turns.
            gemini_messages = []
            system_instruction = None
            for msg in messages:
                if msg["role"] == "system":
                    system_instruction = msg["content"]
                elif msg["role"] == "user":
                    gemini_messages.append({"role": "user", "parts": [msg["content"]]})
                elif msg["role"] == "assistant":
                    gemini_messages.append({"role": "model", "parts": [msg["content"]]})

            generation_config = genai_types.GenerationConfig(
                temperature=temperature,
                max_output_tokens=max_tokens,
            )

            # Fold the system prompt into the first user turn, since this code
            # path does not pass a separate system instruction.
            if system_instruction and gemini_messages and gemini_messages[0]["role"] == "user":
                gemini_messages[0]["parts"][0] = (
                    f"{system_instruction}\n\n{gemini_messages[0]['parts'][0]}"
                )

            # Single-turn request: no chat history is needed.
            if len(gemini_messages) == 1 and gemini_messages[0]["role"] == "user":
                response = model.generate_content(
                    gemini_messages[0]["parts"][0],
                    generation_config=generation_config,
                )
                return response.text

            # Multi-turn request: seed the chat with all but the last message.
            chat = model.start_chat(history=gemini_messages[:-1])
            response = chat.send_message(
                gemini_messages[-1]["parts"][0],
                generation_config=generation_config,
            )
            return response.text
        except Exception as e:
            raise RuntimeError(f"Google API error: {e}") from e

    async def async_chat(
        self,
        messages: List[Dict[str, str]],
        model: str = "claude-4.5-sonnet",
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> str:
        """Async chat completion request.

        Args:
            messages: List of message dicts
            model: Model key
            temperature: Override the default temperature
            max_tokens: Override the default max tokens

        Returns:
            Model response content
        """
        # Run the blocking chat() call in the default thread-pool executor so
        # the event loop is not blocked. get_running_loop() is the current,
        # non-deprecated way to obtain the loop inside a coroutine.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: self.chat(messages, model, temperature, max_tokens),
        )

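    # Usage sketch for async_chat (hypothetical caller; assumes a configured
    # client and a running event loop):
    #
    #     reply = await client.async_chat(
    #         [{"role": "user", "content": "Hello"}],
    #         model="gemini-2.5-pro",
    #     )
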
    @classmethod
    def get_available_models(cls) -> List[Dict[str, str]]:
        """Get the list of available models with metadata.

        Returns:
            List of model info dicts
        """
        return [
            {
                "key": key,
                "name": info["display_name"],
                "provider": info["provider"],
            }
            for key, info in cls.MODELS.items()
        ]
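

# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumes OPENROUTER_API_KEY is set in the environment
# (or in a .env file picked up by load_dotenv) and that outbound network
# access is available; the model catalog on OpenRouter may change over time.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    client = MultiModelClient(temperature=0.2, max_tokens=512)

    # List everything the client knows how to route.
    for info in MultiModelClient.get_available_models():
        print(f"{info['key']:24} -> {info['name']} ({info['provider']})")

    # Send one request through the default model.
    reply = client.chat(
        [{"role": "user", "content": "Summarize OpenRouter in one sentence."}],
        model="claude-4.5-sonnet",
    )
    print(reply)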