# File: Heavy/src/multi_client.py
# Origin: commit ea81a05 by justinhew ("Deploy to HF Spaces")
"""Multi-model client supporting GPT-5, GPT-5.1, Gemini models, and Claude 4.5 Sonnet."""
import os
from typing import Optional, List, Dict, Any
from openai import OpenAI
from dotenv import load_dotenv
load_dotenv()
class MultiModelClient:
    """Unified client for multiple AI model providers.

    Every configured model is currently routed through OpenRouter's
    OpenAI-compatible endpoint. A direct Google Gemini path
    (``_chat_google``) is retained as an optional fallback for callers
    that supply a Google API key and have ``google-generativeai``
    installed, but the default ``chat()`` routing does not use it.
    """

    # Registry of selectable models. Keys are the short names callers pass
    # to chat(); "model_id" is the provider-qualified OpenRouter model ID.
    MODELS = {
        "gpt-5": {
            "provider": "openrouter",
            "model_id": "openai/gpt-5",
            "display_name": "GPT-5"
        },
        "gpt-5.1": {
            "provider": "openrouter",
            "model_id": "openai/gpt-5.1",
            "display_name": "GPT-5.1"
        },
        "gemini-2.5-pro": {
            "provider": "openrouter",
            "model_id": "google/gemini-2.5-pro",
            "display_name": "Gemini 2.5 Pro"
        },
        "gemini-3-pro-preview": {
            "provider": "openrouter",
            "model_id": "google/gemini-3-pro-preview",
            "display_name": "Gemini 3 Pro Preview"
        },
        "claude-4.5-sonnet": {
            "provider": "openrouter",
            "model_id": "anthropic/claude-sonnet-4.5",
            "display_name": "Claude 4.5 Sonnet"
        },
        "claude-4.5-opus": {
            "provider": "openrouter",
            "model_id": "anthropic/claude-opus-4.5",
            "display_name": "Claude 4.5 Opus"
        },
        "gpt-4.1-mini": {
            "provider": "openrouter",
            "model_id": "openai/gpt-4.1-mini",
            "display_name": "GPT-4.1 Mini (make-it-heavy default)"
        },
        "gemini-2.0-flash": {
            "provider": "openrouter",
            "model_id": "google/gemini-2.0-flash-001",
            "display_name": "Gemini 2.0 Flash (fast)"
        },
        "llama-3.1-70b": {
            "provider": "openrouter",
            "model_id": "meta-llama/llama-3.1-70b",
            "display_name": "Llama 3.1 70B (open source)"
        }
    }

    def __init__(
        self,
        openrouter_api_key: Optional[str] = None,
        google_api_key: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 4000
    ):
        """Initialize multi-model client.

        Args:
            openrouter_api_key: OpenRouter API key (for all OpenRouter-hosted
                models). Falls back to the OPENROUTER_API_KEY env var.
            google_api_key: Google API key (optional, for direct Gemini API
                access). Falls back to the GOOGLE_API_KEY env var.
            temperature: Default sampling temperature.
            max_tokens: Default maximum tokens per completion.
        """
        self.openrouter_api_key = openrouter_api_key or os.getenv("OPENROUTER_API_KEY")
        self.google_api_key = google_api_key or os.getenv("GOOGLE_API_KEY")
        self.temperature = temperature
        self.max_tokens = max_tokens

        # Initialize OpenRouter client (handles all OpenRouter-hosted models).
        # Left as None when no key is available; chat() raises a clear error
        # later instead of failing here at construction time.
        if self.openrouter_api_key:
            self.openrouter_client = OpenAI(
                base_url="https://openrouter.ai/api/v1",
                api_key=self.openrouter_api_key,
            )
        else:
            self.openrouter_client = None

        # Google client is optional; only load the SDK if a key is provided.
        self._google_available = False
        if self.google_api_key:
            try:
                import google.generativeai as genai  # type: ignore
                genai.configure(api_key=self.google_api_key)
                self._google_available = True
            except ImportError:
                # Library not installed; direct Gemini access stays unavailable
                # (OpenRouter-hosted Gemini models still work).
                self._google_available = False

    def chat(
        self,
        messages: List[Dict[str, str]],
        model: str = "claude-4.5-sonnet",
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> str:
        """Send a chat completion request to the specified model.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            model: Model key; any key in ``MODELS`` (see
                ``get_available_models()`` for the full list).
            temperature: Override default temperature.
            max_tokens: Override default max tokens.

        Returns:
            Model response content ("" if the provider returned no content).

        Raises:
            ValueError: If ``model`` is not a known key, or the required
                API key was never configured.
            RuntimeError: If the provider API call fails.
        """
        if model not in self.MODELS:
            raise ValueError(f"Unknown model: {model}. Available: {list(self.MODELS.keys())}")

        model_info = self.MODELS[model]
        provider = model_info["provider"]
        temp = temperature if temperature is not None else self.temperature
        max_tok = max_tokens if max_tokens is not None else self.max_tokens

        # All models now route through OpenRouter; "openai" and "google" are
        # accepted as legacy provider tags for backward compatibility.
        if provider in ["openai", "openrouter", "google"]:
            return self._chat_openrouter(messages, model_info["model_id"], temp, max_tok)
        else:
            raise ValueError(f"Unknown provider: {provider}")

    def _chat_openrouter(
        self,
        messages: List[Dict[str, str]],
        model_id: str,
        temperature: float,
        max_tokens: int
    ) -> str:
        """Chat via OpenRouter's OpenAI-compatible API (all hosted models)."""
        if not self.openrouter_client:
            raise ValueError("OpenRouter API key not configured")

        try:
            response = self.openrouter_client.chat.completions.create(
                model=model_id,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens
            )
            # The SDK types `message.content` as Optional; normalize None to ""
            # so the declared `-> str` contract holds.
            return response.choices[0].message.content or ""
        except Exception as e:
            # Chain the original exception so callers can inspect the cause.
            raise RuntimeError(f"OpenRouter API error: {str(e)}") from e

    def _chat_google(
        self,
        messages: List[Dict[str, str]],
        model_id: str,
        temperature: float,
        max_tokens: int
    ) -> str:
        """Chat using the Google Gemini API directly (optional fallback path).

        Converts OpenAI-style messages to Gemini's format: the system prompt
        is folded into the first user message, and 'assistant' turns become
        'model' turns.
        """
        if not self.google_api_key:
            raise ValueError("Google API key not configured")
        try:
            import google.generativeai as genai  # type: ignore
            from google.generativeai import types as genai_types  # type: ignore
        except ImportError as e:
            raise ImportError(
                "google-generativeai is required for direct Gemini access. "
                "Install it or use OpenRouter-hosted models instead."
            ) from e

        try:
            genai.configure(api_key=self.google_api_key)
            model = genai.GenerativeModel(model_id)

            # Translate OpenAI-style roles to Gemini's user/model scheme.
            gemini_messages = []
            system_instruction = None
            for msg in messages:
                if msg["role"] == "system":
                    system_instruction = msg["content"]
                elif msg["role"] == "user":
                    gemini_messages.append({"role": "user", "parts": [msg["content"]]})
                elif msg["role"] == "assistant":
                    gemini_messages.append({"role": "model", "parts": [msg["content"]]})

            generation_config = genai_types.GenerationConfig(
                temperature=temperature,
                max_output_tokens=max_tokens
            )

            # Gemini has no system role here; prepend the system prompt to the
            # first user message instead.
            if system_instruction and gemini_messages and gemini_messages[0]["role"] == "user":
                gemini_messages[0]["parts"][0] = f"{system_instruction}\n\n{gemini_messages[0]['parts'][0]}"

            # Single-turn request: no chat history needed.
            if len(gemini_messages) == 1 and gemini_messages[0]["role"] == "user":
                response = model.generate_content(
                    gemini_messages[0]["parts"][0],
                    generation_config=generation_config
                )
                return response.text

            # Multi-turn: replay all but the last message as history, then send
            # the final user message.
            chat = model.start_chat(history=gemini_messages[:-1])
            response = chat.send_message(
                gemini_messages[-1]["parts"][0],
                generation_config=generation_config
            )
            return response.text
        except Exception as e:
            raise RuntimeError(f"Google API error: {str(e)}") from e

    async def async_chat(
        self,
        messages: List[Dict[str, str]],
        model: str = "claude-4.5-sonnet",
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> str:
        """Async chat completion request.

        Runs the blocking ``chat()`` call in the default thread-pool executor
        so it does not block the event loop.

        Args:
            messages: List of message dicts.
            model: Model key.
            temperature: Override default temperature.
            max_tokens: Override default max tokens.

        Returns:
            Model response content.
        """
        import asyncio
        # get_running_loop() is the correct call inside a coroutine;
        # get_event_loop() is deprecated here since Python 3.10.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: self.chat(messages, model, temperature, max_tokens)
        )

    @classmethod
    def get_available_models(cls) -> List[Dict[str, str]]:
        """Get list of available models with metadata.

        Returns:
            List of dicts with 'key', 'name', and 'provider' entries,
            one per entry in ``MODELS``.
        """
        return [
            {
                "key": key,
                "name": info["display_name"],
                "provider": info["provider"]
            }
            for key, info in cls.MODELS.items()
        ]