| from dataclasses import dataclass, field |
| from enum import Enum |
| import logging |
| import os |
| from typing import ( |
| Any, |
| Awaitable, |
| Callable, |
| List, |
| Optional, |
| Iterator, |
| AsyncIterator, |
| Tuple, |
| TypedDict, |
| ) |
|
|
| from litellm import completion, acompletion, embedding |
| import litellm |
|
|
| from python.helpers import dotenv |
| from python.helpers.dotenv import load_dotenv |
| from python.helpers.providers import get_provider_config |
| from python.helpers.rate_limiter import RateLimiter |
| from python.helpers.tokens import approximate_tokens |
|
|
| from langchain_core.language_models.chat_models import SimpleChatModel |
| from langchain_core.outputs.chat_generation import ChatGenerationChunk |
| from langchain_core.callbacks.manager import ( |
| CallbackManagerForLLMRun, |
| AsyncCallbackManagerForLLMRun, |
| ) |
| from langchain_core.messages import ( |
| BaseMessage, |
| AIMessageChunk, |
| HumanMessage, |
| SystemMessage, |
| ) |
| from langchain.embeddings.base import Embeddings |
| from sentence_transformers import SentenceTransformer |
|
|
|
|
| |
def turn_off_logging():
    """Reduce LiteLLM's log output to errors only.

    Sets the ``LITELLM_LOG`` environment variable to ``"ERROR"``, enables
    ``litellm.suppress_debug_info`` and sets every logger whose name starts
    with ``litellm`` (case-insensitive) to ``logging.ERROR``. Only loggers
    already registered at call time are affected.
    """
    os.environ["LITELLM_LOG"] = "ERROR"
    litellm.suppress_debug_info = True

    # Walk a snapshot of the registry: it is a plain dict shared by the whole
    # process, and a logger registered elsewhere while we loop over the live
    # dict would raise "dictionary changed size during iteration".
    for name in list(logging.Logger.manager.loggerDict):
        if name.lower().startswith("litellm"):
            logging.getLogger(name).setLevel(logging.ERROR)
|
|
|
|
| |
# Import-time setup; these three statements run once, when the module is
# first imported.
load_dotenv()
turn_off_logging()
print("DEBUG: models.py loaded")
|
|
|
|
class ModelType(Enum):
    """Kinds of model a configuration entry can describe."""

    CHAT = "Chat"
    EMBEDDING = "Embedding"
|
|
|
|
@dataclass
class ModelConfig:
    """Settings describing one configured model."""

    type: ModelType
    provider: str
    name: str
    # Optional endpoint override; injected into the call kwargs by build_kwargs().
    api_base: str = ""
    ctx_length: int = 0
    # Rate-limit budgets handed to the rate limiter for this provider/model.
    limit_requests: int = 0
    limit_input: int = 0
    limit_output: int = 0
    vision: bool = False
    # Extra keyword arguments for the model call.
    kwargs: dict = field(default_factory=dict)

    def build_kwargs(self):
        """Return a fresh dict of call kwargs for this model.

        The result is a shallow copy of ``kwargs`` (so the config itself is
        never mutated) with ``api_base`` added when it is set and the caller
        did not already provide one.

        Returns:
            dict: keyword arguments; empty when nothing is configured.
        """
        # ``self.kwargs`` may have been set to None explicitly; the previous
        # ``self.kwargs.copy() or {}`` raised before its fallback could apply.
        kwargs = dict(self.kwargs or {})
        if self.api_base and "api_base" not in kwargs:
            kwargs["api_base"] = self.api_base
        return kwargs
|
|
|
|
class ChatChunk(TypedDict):
    """One parsed piece of a chat completion, split into answer and reasoning text."""

    # Text belonging to the visible answer.
    response_delta: str
    # Text belonging to the model's reasoning output.
    reasoning_delta: str
|
|
|
|
# Process-wide registries.
# One RateLimiter per provider/model key, filled in by get_rate_limiter().
rate_limiters: dict[str, RateLimiter] = {}
# Per-service call counter used by get_api_key() to rotate through key pools.
api_keys_round_robin: dict[str, int] = {}
|
|
def get_api_key(service: str) -> str:
    """Return the API key configured for ``service``.

    The key is looked up under ``API_KEY_<SERVICE>``, ``<SERVICE>_API_KEY``
    and ``<SERVICE>_API_TOKEN`` (in that order, with the service name
    upper-cased). A comma-separated value is treated as a pool of keys, and
    successive calls for the same service step through the pool in order.

    Returns:
        The selected key, or the literal string ``"None"`` when nothing usable
        is configured (callers in this module compare against that sentinel).
    """
    service_id = service.upper()
    key = (
        dotenv.get_dotenv_value(f"API_KEY_{service_id}")
        or dotenv.get_dotenv_value(f"{service_id}_API_KEY")
        or dotenv.get_dotenv_value(f"{service_id}_API_TOKEN")
        or "None"
    )

    if "," in key:
        api_keys = [k.strip() for k in key.split(",") if k.strip()]
        if not api_keys:
            # The value held nothing but separators (e.g. ","); without this
            # guard the modulo below divides by zero.
            return "None"
        api_keys_round_robin[service] = api_keys_round_robin.get(service, -1) + 1
        key = api_keys[api_keys_round_robin[service] % len(api_keys)]
    return key
|
|
|
|
def get_rate_limiter(
    provider: str, name: str, requests: int, input: int, output: int
) -> RateLimiter:
    """Return the shared rate limiter for a provider/model pair.

    The limiter is created on first use and cached in ``rate_limiters``; its
    ``requests``/``input``/``output`` limits are overwritten on every call so
    that configuration changes take effect immediately.

    Args:
        provider: Provider id, first half of the cache key.
        name: Model name, second half of the cache key.
        requests: Request limit; falsy values are stored as 0.
        input: Input limit; falsy values are stored as 0.
        output: Output limit; falsy values are stored as 0.

    Returns:
        The cached ``RateLimiter`` with the limits applied.
    """
    key = f"{provider}\\{name}"
    limiter = rate_limiters.get(key)
    if limiter is None:
        # Build lazily: passing ``RateLimiter(...)`` as the ``dict.get``
        # default constructed a throwaway instance on every cache hit.
        limiter = rate_limiters[key] = RateLimiter(seconds=60)
    limiter.limits["requests"] = requests or 0
    limiter.limits["input"] = input or 0
    limiter.limits["output"] = output or 0
    return limiter
|
|
async def apply_rate_limiter(
    model_config: ModelConfig | None,
    input_text: str,
    rate_limiter_callback: Callable[[str, str, int, int], Awaitable[bool]] | None = None,
):
    """Record one request against the model's rate limiter and wait on it.

    Args:
        model_config: Model whose limits apply; when falsy nothing happens.
        input_text: Text whose approximate token count is added as input usage.
        rate_limiter_callback: Passed through to the limiter's ``wait``.

    Returns:
        The limiter that was used, or ``None`` when no config was given.
    """
    if model_config:
        active = get_rate_limiter(
            model_config.provider,
            model_config.name,
            model_config.limit_requests,
            model_config.limit_input,
            model_config.limit_output,
        )
        # Book the input tokens and the request before waiting.
        active.add(input=approximate_tokens(input_text))
        active.add(requests=1)
        await active.wait(rate_limiter_callback)
        return active
    return None
|
|
def apply_rate_limiter_sync(
    model_config: ModelConfig | None,
    input_text: str,
    rate_limiter_callback: Callable[[str, str, int, int], Awaitable[bool]] | None = None,
):
    """Blocking counterpart of ``apply_rate_limiter`` for synchronous callers.

    Returns:
        Whatever ``apply_rate_limiter`` returns, or ``None`` when
        ``model_config`` is falsy.
    """
    if model_config:
        import asyncio
        import nest_asyncio

        # nest_asyncio patches asyncio so that asyncio.run() can be used even
        # if an event loop is already running in this thread.
        nest_asyncio.apply()
        pending = apply_rate_limiter(model_config, input_text, rate_limiter_callback)
        return asyncio.run(pending)
    return None
|
|
|
|
class LiteLLMChatWrapper(SimpleChatModel):
    """LangChain chat model backed by LiteLLM's ``completion``/``acompletion``.

    The LiteLLM model id is ``"<provider>/<model>"``. Keyword arguments given
    at construction are stored in ``kwargs`` and merged into every LiteLLM
    call; per-call keyword arguments win on conflict. When a ``ModelConfig``
    is supplied, every call first goes through that model's rate limiter.
    The ``run_manager`` arguments of the LangChain hooks are accepted but not
    used.
    """

    model_name: str
    provider: str
    kwargs: dict = {}

    class Config:
        arbitrary_types_allowed = True
        extra = "allow"
        validate_assignment = False

    def __init__(self, model: str, provider: str, model_config: Optional[ModelConfig] = None, **kwargs: Any):
        """Create a wrapper for ``provider``/``model``.

        Args:
            model: Model name, appended to the provider to form the LiteLLM id.
            provider: LiteLLM provider id.
            model_config: Optional config used for rate limiting.
            **kwargs: Default keyword arguments for every LiteLLM call.
        """
        model_value = f"{provider}/{model}"
        super().__init__(model_name=model_value, provider=provider, kwargs=kwargs)
        # Not a declared field; presumably accepted because Config.extra is
        # "allow" - confirm against the pydantic version in use.
        self.a0_model_conf = model_config

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this model type."""
        return "litellm-chat"

    def _convert_messages(self, messages: List[BaseMessage]) -> List[dict]:
        """Translate LangChain messages into role/content dicts for LiteLLM.

        Message types are mapped to roles (``human`` -> ``user``, ``ai`` ->
        ``assistant``); unknown types are passed through unchanged. Tool
        calls are emitted as ``function`` entries with string arguments, and
        ``tool_call_id`` is copied when present.
        """
        import json

        role_mapping = {
            "human": "user",
            "ai": "assistant",
            "system": "system",
            "tool": "tool",
        }
        result = []
        for m in messages:
            role = role_mapping.get(m.type, m.type)
            message_dict = {"role": role, "content": m.content}

            tool_calls = getattr(m, "tool_calls", None)
            if tool_calls:
                new_tool_calls = []
                for tool_call in tool_calls:
                    args = tool_call["args"]
                    # Arguments must be a string: dicts become JSON, anything
                    # else is stringified.
                    args_str = json.dumps(args) if isinstance(args, dict) else str(args)
                    new_tool_calls.append(
                        {
                            "id": tool_call.get("id", ""),
                            "type": "function",
                            "function": {
                                "name": tool_call["name"],
                                "arguments": args_str,
                            },
                        }
                    )
                message_dict["tool_calls"] = new_tool_calls

            tool_call_id = getattr(m, "tool_call_id", None)
            if tool_call_id:
                message_dict["tool_call_id"] = tool_call_id

            result.append(message_dict)
        return result

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Run one blocking, non-streaming completion and return its text."""
        msgs = self._convert_messages(messages)

        # Rate limiting is keyed on the stringified message payload.
        apply_rate_limiter_sync(self.a0_model_conf, str(msgs))

        resp = completion(
            model=self.model_name, messages=msgs, stop=stop, **{**self.kwargs, **kwargs}
        )

        parsed = _parse_chunk(resp)
        return parsed["response_delta"]

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream a completion synchronously.

        Yields one ``ChatGenerationChunk`` per chunk that carries answer text;
        chunks without answer text (including reasoning-only chunks) are
        skipped.
        """
        msgs = self._convert_messages(messages)

        apply_rate_limiter_sync(self.a0_model_conf, str(msgs))

        for chunk in completion(
            model=self.model_name,
            messages=msgs,
            stream=True,
            stop=stop,
            **{**self.kwargs, **kwargs},
        ):
            parsed = _parse_chunk(chunk)
            if parsed["response_delta"]:
                yield ChatGenerationChunk(
                    message=AIMessageChunk(content=parsed["response_delta"])
                )

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Async variant of ``_stream``; same skipping of chunks without answer text."""
        msgs = self._convert_messages(messages)

        await apply_rate_limiter(self.a0_model_conf, str(msgs))

        response = await acompletion(
            model=self.model_name,
            messages=msgs,
            stream=True,
            stop=stop,
            **{**self.kwargs, **kwargs},
        )
        async for chunk in response:
            parsed = _parse_chunk(chunk)
            if parsed["response_delta"]:
                yield ChatGenerationChunk(
                    message=AIMessageChunk(content=parsed["response_delta"])
                )

    async def unified_call(
        self,
        system_message="",
        user_message="",
        messages: List[BaseMessage] | None = None,
        response_callback: Callable[[str, str], Awaitable[None]] | None = None,
        reasoning_callback: Callable[[str, str], Awaitable[None]] | None = None,
        tokens_callback: Callable[[str, int], Awaitable[None]] | None = None,
        rate_limiter_callback: Callable[[str, str, int, int], Awaitable[bool]] | None = None,
        **kwargs: Any,
    ) -> Tuple[str, str]:
        """Stream a completion, reporting progress through callbacks.

        Args:
            system_message: If non-empty, placed before all other messages.
            user_message: If non-empty, placed after all other messages.
            messages: Conversation so far. The list is copied, not modified.
            response_callback: Awaited with ``(delta, answer_so_far)`` for
                each piece of answer text.
            reasoning_callback: Awaited with ``(delta, reasoning_so_far)`` for
                each piece of reasoning text.
            tokens_callback: Awaited with ``(delta, approximate_token_count)``
                for every piece of either kind.
            rate_limiter_callback: Forwarded to the rate limiter's wait.
            **kwargs: Per-call LiteLLM arguments; override the instance
                defaults.

        Returns:
            ``(response, reasoning)`` - the full answer text and the full
            reasoning text.
        """
        turn_off_logging()

        # Work on a private copy: inserting/appending into the caller's list
        # would leak the system/user message into whatever the caller does
        # with that list next (e.g. duplicates on a retry).
        messages = list(messages) if messages else []

        if system_message:
            messages.insert(0, SystemMessage(content=system_message))
        if user_message:
            messages.append(HumanMessage(content=user_message))

        msgs_conv = self._convert_messages(messages)

        limiter = await apply_rate_limiter(self.a0_model_conf, str(msgs_conv), rate_limiter_callback)

        # NOTE(review): debug print kept as-is; consider routing through logging.
        print(f"DEBUG: calling acompletion with model={self.model_name}")
        _completion = await acompletion(
            model=self.model_name,
            messages=msgs_conv,
            stream=True,
            **{**self.kwargs, **kwargs},
        )

        reasoning = ""
        response = ""

        async for chunk in _completion:
            parsed = _parse_chunk(chunk)

            if parsed["reasoning_delta"]:
                reasoning += parsed["reasoning_delta"]
                if reasoning_callback:
                    await reasoning_callback(parsed["reasoning_delta"], reasoning)
                if tokens_callback:
                    await tokens_callback(
                        parsed["reasoning_delta"],
                        approximate_tokens(parsed["reasoning_delta"]),
                    )
                # Reasoning text counts toward the output budget too.
                if limiter:
                    limiter.add(output=approximate_tokens(parsed["reasoning_delta"]))

            if parsed["response_delta"]:
                response += parsed["response_delta"]
                if response_callback:
                    await response_callback(parsed["response_delta"], response)
                if tokens_callback:
                    await tokens_callback(
                        parsed["response_delta"],
                        approximate_tokens(parsed["response_delta"]),
                    )
                if limiter:
                    limiter.add(output=approximate_tokens(parsed["response_delta"]))

        return response, reasoning
|
|
|
|
class BrowserCompatibleChatWrapper(LiteLLMChatWrapper):
    """Chat wrapper variant handed to the browser agent.

    Intended as the place to filter or sanitize messages before they reach
    the LLM. As written it forwards messages unchanged; the only additions
    over the base class are a ``model`` attribute mirroring ``model_name``
    and a ``turn_off_logging()`` call before each operation.
    """

    def __init__(self, *args, **kwargs):
        turn_off_logging()
        super().__init__(*args, **kwargs)
        # Expose the LiteLLM model id under the name ``model`` as well.
        self.model = self.model_name

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Silence LiteLLM logging, then delegate to the base implementation."""
        turn_off_logging()
        return super()._call(messages, stop, run_manager, **kwargs)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Silence LiteLLM logging, then re-yield the base class's stream."""
        turn_off_logging()
        upstream = super()._astream(messages, stop, run_manager, **kwargs)
        async for piece in upstream:
            yield piece
|
|
|
|
class LiteLLMEmbeddingWrapper(Embeddings):
    """LangChain ``Embeddings`` implementation that calls LiteLLM's ``embedding``."""

    model_name: str
    kwargs: dict = {}
    a0_model_conf: Optional[ModelConfig] = None

    def __init__(self, model: str, provider: str, model_config: Optional[ModelConfig] = None, **kwargs: Any):
        """Store the LiteLLM model id, default call kwargs and optional config.

        For the ``openai`` provider the bare model name is used; every other
        provider is prefixed as ``"<provider>/<model>"``.
        """
        if provider == "openai":
            self.model_name = model
        else:
            self.model_name = f"{provider}/{model}"
        self.kwargs = kwargs
        self.a0_model_conf = model_config

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed several texts in one LiteLLM call, one vector per text."""
        # The limiter sees all texts joined by single spaces.
        apply_rate_limiter_sync(self.a0_model_conf, " ".join(texts))

        resp = embedding(model=self.model_name, input=texts, **self.kwargs)
        vectors = []
        for entry in resp.data:
            # Entries may be plain dicts or objects with an ``embedding`` attribute.
            if isinstance(entry, dict):
                vectors.append(entry.get("embedding"))
            else:
                vectors.append(entry.embedding)
        return vectors

    def embed_query(self, text: str) -> List[float]:
        """Embed a single text and return its vector."""
        apply_rate_limiter_sync(self.a0_model_conf, text)

        resp = embedding(model=self.model_name, input=[text], **self.kwargs)
        first = resp.data[0]
        if isinstance(first, dict):
            return first.get("embedding")
        return first.embedding
|
|
|
|
class LocalSentenceTransformerWrapper(Embeddings):
    """Local wrapper for sentence-transformers models to avoid HuggingFace API calls"""

    def __init__(self, provider: str, model: str, model_config: Optional[ModelConfig] = None, **kwargs: Any):
        """Load a sentence-transformers model in-process.

        Args:
            provider: Accepted for signature parity with the other wrappers;
                not used.
            model: Model name. Surrounding whitespace and quotes and a leading
                ``sentence-transformers/`` prefix are stripped.
            model_config: Optional config used for rate limiting.
            **kwargs: Passed to the ``SentenceTransformer`` constructor, except
                ``api_key`` which is discarded.
        """
        # Clean common copy/paste mistakes in the configured name.
        model = model.strip().strip('"').strip("'")
        model = model.removeprefix("sentence-transformers/")

        # The provider-default merging in this module adds ``api_key`` to the
        # kwargs whenever a key is configured for the provider. The
        # SentenceTransformer constructor has no such parameter, so passing it
        # through raises TypeError; a local model does not need it.
        kwargs.pop("api_key", None)

        self.model = SentenceTransformer(model, **kwargs)
        self.model_name = model
        self.a0_model_conf = model_config

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Encode several texts and return one vector per text."""
        apply_rate_limiter_sync(self.a0_model_conf, " ".join(texts))

        vectors = self.model.encode(texts, convert_to_tensor=False)
        # Convert array-like results to plain lists when possible.
        return vectors.tolist() if hasattr(vectors, "tolist") else vectors

    def embed_query(self, text: str) -> List[float]:
        """Encode a single text and return its vector."""
        apply_rate_limiter_sync(self.a0_model_conf, text)

        # Named ``vectors`` (not ``embedding``) so the module-level LiteLLM
        # ``embedding`` import is not shadowed.
        vectors = self.model.encode([text], convert_to_tensor=False)
        first = vectors[0]
        return first.tolist() if hasattr(first, "tolist") else first
|
|
|
|
def _get_litellm_chat(
    cls: type = LiteLLMChatWrapper,
    model_name: str = "",
    provider_name: str = "",
    model_config: Optional[ModelConfig] = None,
    **kwargs: Any,
):
    """Instantiate ``cls`` for the given provider/model with the API key resolved.

    An ``api_key`` passed in ``kwargs`` wins; otherwise the key configured for
    the provider is looked up. The placeholder values ``"None"`` and ``"NA"``
    are never forwarded. Provider name and kwargs are then normalized by
    ``_adjust_call_args`` before the wrapper is built.
    """
    explicit_key = kwargs.pop("api_key", None)
    api_key = explicit_key or get_api_key(provider_name)

    # Forward only real keys, never the "no key" placeholders.
    has_real_key = bool(api_key) and api_key not in ("None", "NA")
    if has_real_key:
        kwargs["api_key"] = api_key

    adjusted = _adjust_call_args(provider_name, model_name, kwargs)
    provider_name, model_name, kwargs = adjusted
    print(f"DEBUG: Creating {cls.__name__} with provider={provider_name}, model={model_name}, api_base={kwargs.get('api_base')}")
    return cls(provider=provider_name, model=model_name, model_config=model_config, **kwargs)
|
|
|
|
def _get_litellm_embedding(model_name: str, provider_name: str, model_config: Optional[ModelConfig] = None, **kwargs: Any):
    """Build the embedding wrapper for a provider/model pair.

    ``huggingface`` models named ``sentence-transformers/...`` are served by
    ``LocalSentenceTransformerWrapper``; no API key lookup happens for them.
    Everything else goes through ``LiteLLMEmbeddingWrapper`` with the API key
    resolved (explicit ``api_key`` kwarg first, configured key otherwise;
    the placeholders ``"None"``/``"NA"`` are dropped).
    """
    # Decide on the raw names, before _adjust_call_args rewrites them.
    run_locally = provider_name == "huggingface" and model_name.startswith(
        "sentence-transformers/"
    )

    if not run_locally:
        api_key = kwargs.pop("api_key", None) or get_api_key(provider_name)
        if api_key and api_key not in ("None", "NA"):
            kwargs["api_key"] = api_key

    provider_name, model_name, kwargs = _adjust_call_args(
        provider_name, model_name, kwargs
    )

    if run_locally:
        return LocalSentenceTransformerWrapper(
            provider=provider_name, model=model_name, model_config=model_config, **kwargs
        )
    return LiteLLMEmbeddingWrapper(model=model_name, provider=provider_name, model_config=model_config, **kwargs)
|
|
|
|
def _chunk_field(source: Any, field_name: str) -> Any:
    """Read ``field_name`` from a dict or an attribute-style object ("" if absent)."""
    if isinstance(source, dict):
        return source.get(field_name, "")
    return getattr(source, field_name, "")


def _parse_chunk(chunk: Any) -> ChatChunk:
    """Extract answer and reasoning text from one LiteLLM response or stream chunk.

    Answer text comes from the first choice's ``delta.content`` and falls back
    to ``message.content`` (also looked up under ``model_extra``), which covers
    non-streaming responses. Reasoning text comes from
    ``delta.reasoning_content``.

    Args:
        chunk: Object supporting ``chunk["choices"]``; each choice must
            support ``.get``.

    Returns:
        A ``ChatChunk`` whose two fields are always strings; missing or
        ``None`` values become ``""``. A chunk with no choices yields two
        empty strings.
    """
    choices = chunk["choices"]
    if not choices:
        # Chunks without any choice carry no text; indexing [0] would raise.
        return {"reasoning_delta": "", "response_delta": ""}

    choice = choices[0]
    delta = choice.get("delta", {})
    message = choice.get("message", {}) or (choice.get("model_extra") or {}).get(
        "message", {}
    )

    # ``or ""`` keeps None (explicit null content) out of the str-typed fields.
    response_delta = (
        _chunk_field(delta, "content") or _chunk_field(message, "content") or ""
    )
    reasoning_delta = _chunk_field(delta, "reasoning_content") or ""
    return {"reasoning_delta": reasoning_delta, "response_delta": response_delta}
|
|
|
|
def _adjust_call_args(provider_name: str, model_name: str, kwargs: dict):
    """Normalize provider naming and add provider-specific call arguments.

    Display labels (matched case-insensitively) are mapped to provider ids;
    unknown names pass through unchanged. For ``openrouter`` the attribution
    headers are added to ``extra_headers``. The ``other`` (OpenAI-compatible)
    provider is finally rewritten to ``openai``.

    Args:
        provider_name: Provider label or id.
        model_name: Model name; returned unchanged.
        kwargs: Call arguments; updated in place and returned.

    Returns:
        ``(provider_name, model_name, kwargs)`` after adjustment.
    """
    label_to_id = {
        "other openai compatible": "other",
        "openai": "openai",
        "anthropic": "anthropic",
        "google": "google",
        "deepseek": "deepseek",
        "groq": "groq",
        "huggingface": "huggingface",
        "lm studio": "lm_studio",
        "mistral ai": "mistral",
        "ollama": "ollama",
        "openrouter": "openrouter",
        "sambanova": "sambanova",
        "venice": "venice",
    }
    provider_name = label_to_id.get(str(provider_name).lower(), provider_name)

    if provider_name == "openrouter":
        attribution_headers = {
            "HTTP-Referer": "https://agent-zero.ai",
            "X-Title": "Agent Zero",
        }
        # Merge instead of overwrite so headers supplied by the caller or the
        # provider config are kept; their values win on conflict.
        kwargs["extra_headers"] = {
            **attribution_headers,
            **(kwargs.get("extra_headers") or {}),
        }

    if provider_name == "other":
        provider_name = "openai"

    return provider_name, model_name, kwargs
|
|
|
|
def _merge_provider_defaults(
    provider_type: str, original_provider: str, kwargs: dict
) -> tuple[str, dict]:
    """Fill ``kwargs`` with provider-level defaults and resolve the LiteLLM provider.

    Steps: map a display label to its provider id, look up the provider's
    configuration, take the LiteLLM provider name and default kwargs from it
    (values already present in ``kwargs`` are kept), and finally add the
    provider's API key if none was given.

    Args:
        provider_type: Configuration section to consult (this module passes
            ``"chat"`` or ``"embedding"``).
        original_provider: Provider label or id as selected by the user.
        kwargs: Call arguments; updated in place.

    Returns:
        ``(provider_name, kwargs)`` where ``provider_name`` is the LiteLLM
        provider to use.
    """
    label_to_id = {
        "other openai compatible": "other",
        "openai": "openai",
        "anthropic": "anthropic",
        "google": "google",
        "deepseek": "deepseek",
        "groq": "groq",
        "huggingface": "huggingface",
        "lm studio": "lm_studio",
        "mistral ai": "mistral",
        "ollama": "ollama",
        "openrouter": "openrouter",
        "sambanova": "sambanova",
        "venice": "venice"
    }
    # Unknown labels are kept exactly as given.
    original_provider = label_to_id.get(
        str(original_provider).lower(), original_provider
    )
    provider_name = original_provider

    cfg = get_provider_config(provider_type, original_provider)
    if cfg:
        provider_name = cfg.get("litellm_provider", original_provider).lower()

        # Defaults live under the config's "kwargs" entry; explicit caller
        # values are never overridden.
        defaults = cfg.get("kwargs") if isinstance(cfg, dict) else None
        if isinstance(defaults, dict):
            for option, value in defaults.items():
                kwargs.setdefault(option, value)

    # The key is looked up under the original provider id, not the LiteLLM one.
    if "api_key" not in kwargs:
        found = get_api_key(original_provider)
        if found and found not in ("None", "NA"):
            kwargs["api_key"] = found

    return provider_name, kwargs
|
|
|
|
def get_chat_model(provider: str, name: str, model_config: Optional[ModelConfig] = None, **kwargs: Any) -> LiteLLMChatWrapper:
    """Create a ``LiteLLMChatWrapper`` for ``provider``/``name``.

    The provider is lower-cased, then provider defaults from the ``"chat"``
    configuration are merged into ``kwargs`` before the wrapper is built.
    """
    provider_name, merged = _merge_provider_defaults(
        "chat", str(provider).lower(), kwargs
    )
    return _get_litellm_chat(
        LiteLLMChatWrapper, name, provider_name, model_config, **merged
    )
|
|
|
|
def get_browser_model(
    provider: str, name: str, model_config: Optional[ModelConfig] = None, **kwargs: Any
) -> BrowserCompatibleChatWrapper:
    """Create a ``BrowserCompatibleChatWrapper`` for ``provider``/``name``.

    Uses the same ``"chat"`` provider defaults as ``get_chat_model``; only the
    wrapper class differs.
    """
    provider_name, merged = _merge_provider_defaults(
        "chat", str(provider).lower(), kwargs
    )
    return _get_litellm_chat(
        BrowserCompatibleChatWrapper, name, provider_name, model_config, **merged
    )
|
|
|
|
def get_embedding_model(
    provider: str, name: str, model_config: Optional[ModelConfig] = None, **kwargs: Any
) -> LiteLLMEmbeddingWrapper | LocalSentenceTransformerWrapper:
    """Create the embedding wrapper for ``provider``/``name``.

    The provider is lower-cased and provider defaults from the ``"embedding"``
    configuration are merged into ``kwargs``; the concrete wrapper class is
    chosen by ``_get_litellm_embedding``.
    """
    provider_name, merged = _merge_provider_defaults(
        "embedding", str(provider).lower(), kwargs
    )
    return _get_litellm_embedding(name, provider_name, model_config, **merged)
|
|