import os
import logging
import json
import time
import re
import textwrap
from typing import Dict, Any, Optional, List, Type, Union

# Global LiteLLM configuration to prevent atexit worker errors.
# MUST BE SET BEFORE IMPORTING LITELLM.
os.environ["LITELLM_TELEMETRY"] = "False"

import litellm

litellm.telemetry = False
litellm.suppress_worker_errors = True
litellm.set_verbose = False

# Internal flag to disable the background logging worker (private litellm
# attribute -- guarded because it may not exist in all litellm versions).
if hasattr(litellm, "_disable_logging_worker"):
    litellm._disable_logging_worker = True

from pydantic import BaseModel

from observability import logger as obs_logger
from observability import components as obs_components

from .base import LLMClient, LLMCapabilities
from .structured import schema_guard, get_json_instruction, validate_structured_output

logger = logging.getLogger(__name__)


class LiteLLMClient(LLMClient):
    """
    LLMClient implementation using LiteLLM for OpenAI-compatible APIs.

    Credentials are resolved automatically by LiteLLM from environment
    variables based on the model prefix (OPENAI_API_KEY, ANTHROPIC_API_KEY,
    HF_TOKEN, ...), unless an explicit ``api_key``/``api_base`` is given.
    """

    @property
    def capabilities(self) -> LLMCapabilities:
        # Default capabilities; subclasses or future versions may override.
        return LLMCapabilities()

    def __init__(
        self,
        model_name: str,
        provider: Optional[str] = None,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        drop_params: bool = False,
        **kwargs,
    ):
        """
        Args:
            model_name: LiteLLM model identifier (may be bare or provider-prefixed).
            provider: Optional provider hint used to prefix the model name
                (e.g. "gemini", "anthropic").
            api_base: Optional custom endpoint URL.
            api_key: Optional explicit API key; otherwise resolved from env.
            temperature: Default sampling temperature for all calls.
            max_tokens: Default completion token limit.
            drop_params: Forwarded to LiteLLM; drops unsupported params
                instead of erroring.
            **kwargs: Extra parameters forwarded verbatim to ``acompletion``.
        """
        self.model_name = model_name
        self.provider = provider
        self.api_base = api_base
        self.api_key = api_key
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.drop_params = drop_params
        self.extra_params = kwargs

        if os.getenv("LITELLM_DEBUG", "false").lower() == "true":
            # Using the recommended way to enable logging
            # (LITELLM_LOG supersedes the deprecated litellm.set_verbose).
            os.environ["LITELLM_LOG"] = "DEBUG"
            logger.info("LiteLLM verbose logging enabled via LITELLM_LOG=DEBUG")

        # LiteLLM handles key resolution automatically from env vars based on
        # model prefix (e.g. OPENAI_API_KEY, ANTHROPIC_API_KEY, HF_TOKEN, etc.)
        obs_logger.log_event(
            level="info",
            message=f"LiteLLM client initialized for {model_name} (provider: {provider or 'auto'})",
            event="credentials_resolved",
            component=obs_components.LLM,
            provider="litellm",
            model=model_name,
            source="environment-automatic",
        )

    async def generate(
        self,
        prompt: str,
        *,
        instruction: Optional[str] = None,
        schema: Optional[Type[BaseModel]] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Any]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
    ) -> Union[str, Dict[str, Any], BaseModel]:
        """
        Single-turn convenience wrapper: builds a (system, user) message list
        from ``instruction``/``prompt`` and delegates to :meth:`chat`.
        """
        messages = []
        if instruction:
            messages.append({"role": "system", "content": instruction})
        messages.append({"role": "user", "content": prompt})

        return await self.chat(
            messages,
            instruction=None,  # Already added to messages
            schema=schema,
            temperature=temperature,
            tools=tools,
            metadata=metadata,
            name=name,
        )

    async def chat(
        self,
        messages: List[Dict[str, str]],
        *,
        instruction: Optional[str] = None,
        schema: Optional[Type[BaseModel]] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Any]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
    ) -> Union[str, Dict[str, Any], BaseModel]:
        """
        Run a chat completion through ``litellm.acompletion``.

        Returns the raw response text, or -- when ``schema`` is given -- the
        response validated against that Pydantic schema via
        ``validate_structured_output``.

        Raises:
            ValueError: if a schema was requested but the model returned an
                empty response.
            Exception: any error propagated from LiteLLM.
        """
        if schema:
            # Combine all content for the guard
            full_prompt = " ".join([m.get("content", "") for m in messages])
            schema_guard(full_prompt, instruction)
            instruction = get_json_instruction(schema, instruction)

        chat_messages = []
        for msg in messages:
            content = msg.get("content", "")
            if content:
                content = textwrap.dedent(content).strip()
            chat_messages.append({"role": msg["role"], "content": content})

        if instruction:
            chat_messages.insert(
                0, {"role": "system", "content": textwrap.dedent(instruction).strip()}
            )

        # Ensure model has a provider prefix if api_base is used,
        # so LiteLLM knows which adapter to use for the custom endpoint.
        model = self.model_name

        # 0. Basic prefixing for known providers if not already prefixed
        if self.provider == "gemini" and not model.startswith("gemini/"):
            model = f"gemini/{model}"
        elif self.provider == "anthropic" and not model.startswith("anthropic/"):
            model = f"anthropic/{model}"

        # 1. Special case for Hugging Face Router vs Inference API
        if self.api_base and ("huggingface.co" in self.api_base or "hf.co" in self.api_base):
            is_router = "router.huggingface.co" in self.api_base
            if is_router:
                if not model.startswith("huggingface/"):
                    # NOTE: "openai/..." already contains "/", so a single
                    # slash check covers both org-qualified and
                    # "openai/"-prefixed names identically.
                    if "/" in model:
                        model = f"huggingface/{model}"
                    else:
                        model = f"huggingface/openai/{model}"
        # 2. General case for other custom OpenAI-compatible endpoints
        #    (a bare name with no "/" cannot already carry an "openai/" prefix)
        elif self.api_base and "/" not in model:
            model = f"openai/{model}"

        # Prepare completion arguments
        completion_kwargs = {
            "model": model,
            "messages": chat_messages,
            "temperature": temperature if temperature is not None else self.temperature,
            "max_tokens": self.max_tokens,
            "drop_params": self.drop_params,
            **self.extra_params,
        }

        # Only pass api_base/key if they are explicitly provided and not empty
        if self.api_base and "api.openai.com" not in self.api_base:
            completion_kwargs["api_base"] = self.api_base
        if self.api_key:
            completion_kwargs["api_key"] = self.api_key

        # If we have a schema, we can try to use JSON mode if supported
        if schema:
            completion_kwargs["response_format"] = {"type": "json_object"}

        # Mask API key for logging
        log_kwargs = completion_kwargs.copy()
        if "api_key" in log_kwargs and log_kwargs["api_key"]:
            key = str(log_kwargs["api_key"])
            log_kwargs["api_key"] = f"{key[:6]}...{key[-4:]}" if len(key) > 10 else "***"

        logger.info(f"LiteLLM sending request to {model} at {self.api_base or 'default'}")
        logger.debug(f"Completion args: {log_kwargs}")

        # BUGFIX: start the clock *before* the try block so the except
        # handler's duration computation can never hit an unbound name
        # (previously start_time was assigned after the first log_event call
        # inside the try, so an early failure raised UnboundLocalError).
        start_time = time.time()
        try:
            obs_logger.log_event(
                "info",
                "LLM call started",
                event="start",
                component=obs_components.LLM,
                fields={"provider": "litellm", "model": model},
            )
            response = await litellm.acompletion(**completion_kwargs)
            duration_ms = (time.time() - start_time) * 1000

            obs_logger.log_event(
                "info",
                f"Generating completion with model: {self.model_name}",
                component=obs_components.LLM,
                fields={"model": self.model_name, "temperature": self.temperature, "duration_ms": duration_ms},
            )

            response_text = response.choices[0].message.content

            # Handle possible None content and check for reasoning_content (for models like o1)
            if response_text is None:
                response_text = getattr(response.choices[0].message, "reasoning_content", "") or ""

            logger.debug(f"LiteLLM raw response: {response_text[:200]}...")

            if not response_text:
                logger.error(f"LiteLLM returned empty response for model {model}")
                if schema:
                    raise ValueError(f"LLM returned empty response for schema {schema.__name__}")
                return ""

            if not schema:
                return response_text

            # Get a prompt ID for error reporting (truncated last user message)
            prompt_id = "unknown"
            if chat_messages:
                last_user_msg = next(
                    (m for m in reversed(chat_messages) if m["role"] == "user"), None
                )
                if last_user_msg:
                    content = last_user_msg.get("content", "")
                    prompt_id = (content[:20] + "...") if len(content) > 20 else content

            return validate_structured_output(
                text=response_text,
                schema=schema,
                provider="litellm",
                model=model,
                prompt_id=prompt_id,
            )
        except Exception as e:
            duration_ms = (time.time() - start_time) * 1000
            obs_logger.log_event(
                "error",
                f"Async LLM call failed: {str(e)}",
                component=obs_components.LLM,
                fields={"provider": "litellm", "model": model, "duration_ms": duration_ms},
            )
            logger.error(f"LiteLLM completion failed: {e}")
            raise

    async def close(self):
        """
        Close LiteLLM sessions (best-effort; never raises during shutdown).
        """
        try:
            import litellm

            # litellm manages sessions internally.
            # We can try to clean up if needed.
            if hasattr(litellm, "cleanup_all_sessions"):
                litellm.cleanup_all_sessions()
        except Exception:
            # Deliberately swallow cleanup errors -- shutdown must not fail
            # because of session teardown (was a bare except; narrowed so
            # KeyboardInterrupt/SystemExit still propagate).
            pass