Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import os | |
| from typing import Callable, List, Dict, Any, Optional | |
| from dotenv import load_dotenv | |
| import litellm | |
# Populate os.environ from a local .env file so the os.getenv() lookups
# below can see it; those lookups run once, at import time.
load_dotenv()

# Per-provider settings read by get_default_model() and get_completion_fn():
#   default_model - model used when the caller does not supply one
#   model_prefix  - prepended to the model name to form the litellm model id
#   api_key       - key taken from the environment (None if the var is unset)
#   api_base      - ("custom" only) endpoint URL taken from CUSTOM_API_BASE
_PROVIDER_MAP = {
    "openai": dict(
        default_model="gpt-4o",
        model_prefix="openai/",
        api_key=os.getenv("OPENAI_API_KEY"),
    ),
    "mistral": dict(
        default_model="mistral-small-2503",
        model_prefix="mistral/",
        api_key=os.getenv("MISTRAL_API_KEY"),
    ),
    "gemini": dict(
        default_model="gemini-2.0-flash",
        model_prefix="gemini/",
        api_key=os.getenv("GOOGLE_API_KEY"),
    ),
    "custom": dict(
        default_model="gpt-3.5-turbo",
        model_prefix="",
        api_key=os.getenv("CUSTOM_API_KEY"),
        api_base=os.getenv("CUSTOM_API_BASE"),
    ),
}
def get_default_model(provider: str) -> str:
    """Return the default model name configured for ``provider``.

    Providers missing from ``_PROVIDER_MAP`` (or entries without a
    ``default_model`` key) yield ``"gpt-3.5-turbo"``.
    """
    try:
        return _PROVIDER_MAP[provider]["default_model"]
    except KeyError:
        return "gpt-3.5-turbo"
def _non_blank(value: Optional[str], fallback: Optional[str]) -> Optional[str]:
    """Return ``value`` stripped of surrounding whitespace, or ``fallback``.

    ``fallback`` is used when ``value`` is None, empty, or whitespace-only.
    Stripping keeps stray spaces/newlines (e.g. from a pasted key) out of the
    model id and the API key.
    """
    if value is None:
        return fallback
    stripped = value.strip()
    return stripped if stripped else fallback


def _extract_reply(resp: Any) -> str:
    """Convert a ``litellm.completion`` response into the string ``_call`` returns.

    If the first choice's message carries tool calls, only the FIRST tool call
    is returned, serialized via its ``json()`` method; any further tool calls
    are dropped. Otherwise the message text is returned with surrounding
    whitespace stripped. A ``None`` content (reply with no text) yields ``""``
    instead of raising ``AttributeError`` on ``None.strip()``.
    """
    message = resp.choices[0].message
    tool_calls = getattr(message, "tool_calls", None)
    if tool_calls:
        return tool_calls[0].json()
    return (message.content or "").strip()


def get_completion_fn(
    provider: str,
    model_name: Optional[str] = None,
    api_key: Optional[str] = None,
) -> Callable[..., str]:
    """Build a single-turn completion callable for ``provider``.

    Args:
        provider: Key into ``_PROVIDER_MAP``. Unknown providers fall back to
            the "custom" entry.
        model_name: Model name without the provider prefix. None, empty, or
            whitespace-only selects the provider's ``default_model``;
            otherwise surrounding whitespace is stripped.
        api_key: API key to send. None, empty, or whitespace-only selects the
            provider's ``api_key`` entry (read from the environment when
            ``_PROVIDER_MAP`` was built); otherwise it is stripped.

    Returns:
        ``_call(prompt, tools=None, tool_choice=None) -> str``. It sends
        ``prompt`` as a single user message through ``litellm.completion``
        and returns the first tool call as JSON if the model made one, else
        the stripped reply text ("" if the reply has no text).
    """
    cfg = _PROVIDER_MAP.get(provider, _PROVIDER_MAP["custom"])
    model_name = _non_blank(model_name, cfg["default_model"])
    api_key = _non_blank(api_key, cfg["api_key"])
    # The prefix (e.g. "openai/") tells litellm which provider to route to;
    # the "custom" entry uses an empty prefix.
    full_model = f"{cfg['model_prefix']}{model_name}"

    def _call(
        prompt: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
    ) -> str:
        """Send ``prompt`` and return the reply as described in ``get_completion_fn``."""
        # Forward tool options only when the caller supplied them.
        extra_params: Dict[str, Any] = {}
        if tools:
            extra_params["tools"] = tools
        if tool_choice:
            extra_params["tool_choice"] = tool_choice
        resp = litellm.completion(
            model=full_model,
            messages=[{"role": "user", "content": prompt}],
            api_key=api_key,
            api_base=cfg.get("api_base"),
            **extra_params,
        )
        return _extract_reply(resp)

    return _call