File size: 8,941 Bytes
ea81a05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
"""Multi-model client supporting GPT-5, GPT-5.1, Gemini models, and Claude 4.5 Sonnet."""

import os
from typing import Optional, List, Dict, Any
from openai import OpenAI
from dotenv import load_dotenv

load_dotenv()


class MultiModelClient:
    """Unified client for multiple AI model providers.

    Every entry in ``MODELS`` is served through OpenRouter's
    OpenAI-compatible endpoint. A direct Google Gemini code path
    (``_chat_google``) is kept in the class, but ``chat`` does not
    dispatch to it.
    """

    # Registry of supported models, keyed by the short name callers pass
    # to chat(). "model_id" is the identifier sent to the provider.
    MODELS: Dict[str, Dict[str, str]] = {
        "gpt-5": {
            "provider": "openrouter",
            "model_id": "openai/gpt-5",
            "display_name": "GPT-5"
        },
        "gpt-5.1": {
            "provider": "openrouter",
            "model_id": "openai/gpt-5.1",
            "display_name": "GPT-5.1"
        },
        "gemini-2.5-pro": {
            "provider": "openrouter",
            "model_id": "google/gemini-2.5-pro",
            "display_name": "Gemini 2.5 Pro"
        },
        "gemini-3-pro-preview": {
            "provider": "openrouter",
            "model_id": "google/gemini-3-pro-preview",
            "display_name": "Gemini 3 Pro Preview"
        },
        "claude-4.5-sonnet": {
            "provider": "openrouter",
            "model_id": "anthropic/claude-sonnet-4.5",
            "display_name": "Claude 4.5 Sonnet"
        },
        "claude-4.5-opus": {
            "provider": "openrouter",
            "model_id": "anthropic/claude-opus-4.5",
            "display_name": "Claude 4.5 Opus"
        },
        "gpt-4.1-mini": {
            "provider": "openrouter",
            "model_id": "openai/gpt-4.1-mini",
            "display_name": "GPT-4.1 Mini (make-it-heavy default)"
        },
        "gemini-2.0-flash": {
            "provider": "openrouter",
            "model_id": "google/gemini-2.0-flash-001",
            "display_name": "Gemini 2.0 Flash (fast)"
        },
        # NOTE(review): other OpenRouter Llama slugs usually carry an
        # "-instruct" suffix; confirm this ID resolves before relying on it.
        "llama-3.1-70b": {
            "provider": "openrouter",
            "model_id": "meta-llama/llama-3.1-70b",
            "display_name": "Llama 3.1 70B (open source)"
        }
    }

    # Provider tags that chat() sends through the OpenRouter client.
    _OPENROUTER_PROVIDERS = frozenset({"openai", "openrouter", "google"})

    def __init__(
        self,
        openrouter_api_key: Optional[str] = None,
        google_api_key: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 4000
    ):
        """Initialize multi-model client.

        Construction never fails because of a missing key; the absence is
        reported when a request is actually attempted.

        Args:
            openrouter_api_key: OpenRouter API key (for all OpenRouter-hosted
                models). Falls back to the ``OPENROUTER_API_KEY`` environment
                variable.
            google_api_key: Google API key (optional, for direct Gemini API
                access). Falls back to ``GOOGLE_API_KEY``.
            temperature: Default sampling temperature
            max_tokens: Default maximum tokens
        """
        self.openrouter_api_key = openrouter_api_key or os.getenv("OPENROUTER_API_KEY")
        self.google_api_key = google_api_key or os.getenv("GOOGLE_API_KEY")

        self.temperature = temperature
        self.max_tokens = max_tokens

        # Initialize OpenRouter client (handles all OpenRouter-hosted models)
        if self.openrouter_api_key:
            self.openrouter_client = OpenAI(
                base_url="https://openrouter.ai/api/v1",
                api_key=self.openrouter_api_key,
            )
        else:
            self.openrouter_client = None

        # Google client is optional; only load the SDK if a key is provided
        self._google_available = False
        if self.google_api_key:
            try:
                import google.generativeai as genai  # type: ignore

                genai.configure(api_key=self.google_api_key)
                self._google_available = True
            except ImportError:
                # Library not installed; Gemini direct access will be unavailable
                self._google_available = False

    def chat(
        self,
        messages: List[Dict[str, str]],
        model: str = "claude-4.5-sonnet",
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> str:
        """Send a chat completion request to the specified model.

        Args:
            messages: List of message dicts with 'role' and 'content'
            model: Model key; must be one of the keys of ``MODELS``
            temperature: Override default temperature
            max_tokens: Override default max tokens

        Returns:
            Model response content

        Raises:
            ValueError: If ``model`` is not a known key, its provider is not
                recognised, or no OpenRouter API key is configured.
            RuntimeError: If the OpenRouter request fails.
        """
        if model not in self.MODELS:
            raise ValueError(f"Unknown model: {model}. Available: {list(self.MODELS.keys())}")

        model_info = self.MODELS[model]
        provider = model_info["provider"]

        # Explicit None checks so that 0 / 0.0 overrides are honoured.
        temp = temperature if temperature is not None else self.temperature
        max_tok = max_tokens if max_tokens is not None else self.max_tokens

        # All models now route through OpenRouter
        if provider not in self._OPENROUTER_PROVIDERS:
            raise ValueError(f"Unknown provider: {provider}")
        return self._chat_openrouter(messages, model_info["model_id"], temp, max_tok)

    def _chat_openrouter(
        self,
        messages: List[Dict[str, str]],
        model_id: str,
        temperature: float,
        max_tokens: int
    ) -> str:
        """Chat using OpenRouter's OpenAI-compatible chat completions API.

        Args:
            messages: List of message dicts with 'role' and 'content'
            model_id: Provider-side model identifier (``MODELS[...]["model_id"]``)
            temperature: Sampling temperature for this request
            max_tokens: Maximum tokens for this request

        Returns:
            Content of the first choice; an empty string when the provider
            returns no text content for it.

        Raises:
            ValueError: If no OpenRouter client was configured.
            RuntimeError: If the request raises or the response has no choices.
                The original exception is attached as ``__cause__``.
        """
        if not self.openrouter_client:
            raise ValueError("OpenRouter API key not configured")

        try:
            response = self.openrouter_client.chat.completions.create(
                model=model_id,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens
            )
        except Exception as exc:
            # Chain the cause so the provider's own error stays inspectable.
            raise RuntimeError(f"OpenRouter API error: {exc}") from exc

        choices = getattr(response, "choices", None)
        if not choices:
            raise RuntimeError("OpenRouter API error: response contained no choices")

        content = choices[0].message.content
        # Honour the declared ``str`` return type when content is None.
        return content if content is not None else ""

    @staticmethod
    def _build_gemini_turns(messages: List[Dict[str, str]]) -> List[Dict[str, Any]]:
        """Convert OpenAI-style messages into Gemini ``role``/``parts`` turns.

        ``user`` messages map to role ``"user"`` and ``assistant`` messages to
        role ``"model"``; other non-system roles are ignored. If several
        system messages are present, the last one wins. The system text is
        prepended to the first turn only when that turn is a user turn;
        otherwise it is not sent.

        Raises:
            ValueError: If ``messages`` has no user or assistant message.
        """
        turns: List[Dict[str, Any]] = []
        system_instruction = None

        for msg in messages:
            role = msg["role"]
            if role == "system":
                system_instruction = msg["content"]
            elif role == "user":
                turns.append({"role": "user", "parts": [msg["content"]]})
            elif role == "assistant":
                turns.append({"role": "model", "parts": [msg["content"]]})

        if not turns:
            raise ValueError("messages must contain at least one user or assistant message")

        if system_instruction and turns[0]["role"] == "user":
            turns[0]["parts"][0] = f"{system_instruction}\n\n{turns[0]['parts'][0]}"

        return turns

    def _chat_google(
        self,
        messages: List[Dict[str, str]],
        model_id: str,
        temperature: float,
        max_tokens: int
    ) -> str:
        """Chat using Google Gemini.

        Not called by ``chat``, which routes every provider through
        OpenRouter.

        Args:
            messages: List of message dicts with 'role' and 'content'
            model_id: Gemini model name passed to ``GenerativeModel``
            temperature: Sampling temperature for this request
            max_tokens: Maximum output tokens for this request

        Returns:
            The ``text`` of the Gemini response.

        Raises:
            ValueError: If no Google API key is configured, or ``messages``
                has no user or assistant message.
            ImportError: If ``google-generativeai`` is not installed.
            RuntimeError: If the Gemini request fails; the original exception
                is attached as ``__cause__``.
        """
        if not self.google_api_key:
            raise ValueError("Google API key not configured")

        try:
            import google.generativeai as genai  # type: ignore
            from google.generativeai import types as genai_types  # type: ignore
        except ImportError as exc:
            raise ImportError(
                "google-generativeai is required for direct Gemini access. "
                "Install it or use OpenRouter-hosted models instead."
            ) from exc

        # Validate and convert before touching the network so that bad input
        # surfaces as ValueError rather than a wrapped IndexError.
        turns = self._build_gemini_turns(messages)

        try:
            genai.configure(api_key=self.google_api_key)
            gemini_model = genai.GenerativeModel(model_id)

            generation_config = genai_types.GenerationConfig(
                temperature=temperature,
                max_output_tokens=max_tokens
            )

            # A single user turn needs no chat session.
            if len(turns) == 1 and turns[0]["role"] == "user":
                response = gemini_model.generate_content(
                    turns[0]["parts"][0],
                    generation_config=generation_config
                )
                return response.text

            # Everything but the last turn is history; the last is sent now.
            chat_session = gemini_model.start_chat(history=turns[:-1])
            response = chat_session.send_message(
                turns[-1]["parts"][0],
                generation_config=generation_config
            )
            return response.text

        except Exception as exc:
            raise RuntimeError(f"Google API error: {exc}") from exc

    async def async_chat(
        self,
        messages: List[Dict[str, str]],
        model: str = "claude-4.5-sonnet",
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> str:
        """Async chat completion request.

        Runs the blocking ``chat`` call in the event loop's default executor
        so the loop is not blocked while the request is in flight.

        Args:
            messages: List of message dicts
            model: Model key
            temperature: Override default temperature
            max_tokens: Override default max tokens

        Returns:
            Model response content
        """
        import asyncio
        from functools import partial

        # Inside a coroutine a loop is always running; get_running_loop()
        # is the supported way to obtain it.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            partial(self.chat, messages, model, temperature, max_tokens)
        )

    @classmethod
    def get_available_models(cls) -> List[Dict[str, str]]:
        """Get list of available models with metadata.

        Returns:
            One dict per ``MODELS`` entry with ``key``, ``name`` (the display
            name) and ``provider``, in registry order.
        """
        return [
            {
                "key": key,
                "name": info["display_name"],
                "provider": info["provider"]
            }
            for key, info in cls.MODELS.items()
        ]