# NOTE: Hugging Face Spaces page banner ("Spaces: Running") removed — web-scrape artifact, not part of this module.
import os
import logging
import json
import time
import re
import textwrap
from typing import Dict, Any, Optional, List, Type, Union

# Global LiteLLM configuration to prevent atexit worker errors.
# MUST BE SET BEFORE IMPORTING LITELLM — litellm reads this env var at import.
os.environ["LITELLM_TELEMETRY"] = "False"

import litellm

# Disable telemetry/verbose output and suppress background-worker errors.
litellm.telemetry = False
litellm.suppress_worker_errors = True
litellm.set_verbose = False

# Internal flag to disable the background logging worker.
# The attribute is private and version-dependent, hence the hasattr guard.
if hasattr(litellm, "_disable_logging_worker"):
    litellm._disable_logging_worker = True

from pydantic import BaseModel

# Project observability helpers (structured event logging).
from observability import logger as obs_logger
from observability import components as obs_components

from .base import LLMClient, LLMCapabilities
from .structured import schema_guard, get_json_instruction, validate_structured_output

# Module-level logger, standard `logging` convention.
logger = logging.getLogger(__name__)
class LiteLLMClient(LLMClient):
    """
    LLMClient implementation using LiteLLM for OpenAI-compatible APIs.

    Credentials are resolved by LiteLLM from environment variables based on
    the model prefix (e.g. OPENAI_API_KEY, ANTHROPIC_API_KEY, HF_TOKEN)
    unless an explicit ``api_key``/``api_base`` is supplied.
    """

    def capabilities(self) -> LLMCapabilities:
        """Return the default capability set for this client."""
        return LLMCapabilities()  # Default capabilities

    def __init__(
        self,
        model_name: str,
        provider: Optional[str] = None,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        drop_params: bool = False,
        **kwargs,
    ):
        """
        Args:
            model_name: Model identifier; may be auto-prefixed per provider
                rules in :meth:`_resolve_model_name`.
            provider: Optional provider hint ("gemini", "anthropic", ...).
            api_base: Custom endpoint URL; non-OpenAI bases are forwarded to
                LiteLLM and trigger OpenAI-compatible routing.
            api_key: Explicit API key; otherwise LiteLLM resolves it from
                environment variables.
            temperature: Default sampling temperature for all calls.
            max_tokens: Default completion token limit.
            drop_params: Ask LiteLLM to drop provider-unsupported params.
            **kwargs: Extra parameters forwarded verbatim to
                ``litellm.acompletion``.
        """
        self.model_name = model_name
        self.provider = provider
        self.api_base = api_base
        self.api_key = api_key
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.drop_params = drop_params
        self.extra_params = kwargs

        if os.getenv("LITELLM_DEBUG", "false").lower() == "true":
            # Recommended way to enable verbose logging; litellm.set_verbose
            # is deprecated in favor of the LITELLM_LOG env var.
            os.environ["LITELLM_LOG"] = "DEBUG"
            logger.info("LiteLLM verbose logging enabled via LITELLM_LOG=DEBUG")

        # LiteLLM handles key resolution automatically from env vars based on
        # model prefix (e.g. OPENAI_API_KEY, ANTHROPIC_API_KEY, HF_TOKEN, etc.)
        obs_logger.log_event(
            level="info",
            message=f"LiteLLM client initialized for {model_name} (provider: {provider or 'auto'})",
            event="credentials_resolved",
            component=obs_components.LLM,
            provider="litellm",
            model=model_name,
            source="environment-automatic",
        )

    async def generate(
        self,
        prompt: str,
        *,
        instruction: Optional[str] = None,
        schema: Optional[Type[BaseModel]] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Any]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
    ) -> Union[str, Dict[str, Any], BaseModel]:
        """Single-turn convenience wrapper: build a message list and delegate to chat()."""
        messages = []
        if instruction:
            messages.append({"role": "system", "content": instruction})
        messages.append({"role": "user", "content": prompt})
        return await self.chat(
            messages,
            instruction=None,  # Already added to messages
            schema=schema,
            temperature=temperature,
            tools=tools,
            metadata=metadata,
            name=name,
        )

    @staticmethod
    def _prepare_messages(
        messages: List[Dict[str, str]], instruction: Optional[str]
    ) -> List[Dict[str, str]]:
        """Dedent/strip message contents and prepend the system instruction, if any."""
        chat_messages = []
        for msg in messages:
            content = msg.get("content", "")
            if content:
                content = textwrap.dedent(content).strip()
            chat_messages.append({"role": msg["role"], "content": content})
        if instruction:
            chat_messages.insert(
                0, {"role": "system", "content": textwrap.dedent(instruction).strip()}
            )
        return chat_messages

    def _resolve_model_name(self) -> str:
        """
        Return the model name with the provider prefix LiteLLM needs to pick
        the right adapter, especially when a custom api_base is in use.
        """
        model = self.model_name

        # 0. Basic prefixing for known providers if not already prefixed.
        if self.provider == "gemini" and not model.startswith("gemini/"):
            model = f"gemini/{model}"
        elif self.provider == "anthropic" and not model.startswith("anthropic/"):
            model = f"anthropic/{model}"

        # 1. Special case for Hugging Face Router vs Inference API.
        if self.api_base and ("huggingface.co" in self.api_base or "hf.co" in self.api_base):
            # NOTE(review): only the router URL receives the "huggingface/"
            # prefix; a plain huggingface.co api_base passes through
            # unchanged — confirm this asymmetry is intentional.
            if "router.huggingface.co" in self.api_base:
                if not model.startswith("huggingface/"):
                    if "/" in model:
                        # Covers both "openai/<model>" and "<org>/<model>"
                        # ids (the original code had two identical branches
                        # for these cases — merged here).
                        model = f"huggingface/{model}"
                    else:
                        model = f"huggingface/openai/{model}"
        # 2. General case for other custom OpenAI-compatible endpoints.
        elif self.api_base and "/" not in model and not model.startswith("openai/"):
            model = f"openai/{model}"
        return model

    async def chat(
        self,
        messages: List[Dict[str, str]],
        *,
        instruction: Optional[str] = None,
        schema: Optional[Type[BaseModel]] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Any]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
    ) -> Union[str, Dict[str, Any], BaseModel]:
        """
        Run a chat completion through ``litellm.acompletion``.

        Returns:
            The raw response text when no ``schema`` is given, otherwise the
            validated structured output from ``validate_structured_output``.

        Raises:
            ValueError: if the model returns an empty response while a
                ``schema`` was requested.
            Exception: any error raised by LiteLLM is logged and re-raised.
        """
        if schema:
            # Combine all content for the guard
            full_prompt = " ".join([m.get("content", "") for m in messages])
            schema_guard(full_prompt, instruction)
            instruction = get_json_instruction(schema, instruction)

        chat_messages = self._prepare_messages(messages, instruction)
        # Ensure model has a provider prefix if api_base is used, so LiteLLM
        # knows which adapter to use for the custom endpoint.
        model = self._resolve_model_name()

        # Prepare completion arguments
        completion_kwargs = {
            "model": model,
            "messages": chat_messages,
            "temperature": temperature if temperature is not None else self.temperature,
            "max_tokens": self.max_tokens,
            "drop_params": self.drop_params,
            **self.extra_params,
        }
        # Only pass api_base/key if they are explicitly provided and not empty
        if self.api_base and "api.openai.com" not in self.api_base:
            completion_kwargs["api_base"] = self.api_base
        if self.api_key:
            completion_kwargs["api_key"] = self.api_key
        # If we have a schema, we can try to use JSON mode if supported
        if schema:
            completion_kwargs["response_format"] = {"type": "json_object"}

        # Mask API key for logging
        log_kwargs = completion_kwargs.copy()
        if "api_key" in log_kwargs and log_kwargs["api_key"]:
            key = str(log_kwargs["api_key"])
            log_kwargs["api_key"] = f"{key[:6]}...{key[-4:]}" if len(key) > 10 else "***"
        logger.info(f"LiteLLM sending request to {model} at {self.api_base or 'default'}")
        logger.debug(f"Completion args: {log_kwargs}")

        # BUGFIX: start_time must be bound BEFORE entering the try block.
        # Previously it was assigned after the initial log_event call inside
        # the try, so a failure in that call raised NameError in the except
        # handler instead of the real error.
        start_time = time.time()
        try:
            obs_logger.log_event(
                "info",
                "LLM call started",
                event="start",
                component=obs_components.LLM,
                fields={"provider": "litellm", "model": model},
            )
            response = await litellm.acompletion(**completion_kwargs)
            duration_ms = (time.time() - start_time) * 1000
            obs_logger.log_event(
                "info",
                f"Generating completion with model: {self.model_name}",
                component=obs_components.LLM,
                fields={"model": self.model_name, "temperature": self.temperature, "duration_ms": duration_ms},
            )

            response_text = response.choices[0].message.content
            # Handle possible None content and check for reasoning_content
            # (for models like o1)
            if response_text is None:
                response_text = getattr(response.choices[0].message, "reasoning_content", "") or ""
            logger.debug(f"LiteLLM raw response: {response_text[:200]}...")

            if not response_text:
                logger.error(f"LiteLLM returned empty response for model {model}")
                if schema:
                    raise ValueError(f"LLM returned empty response for schema {schema.__name__}")
                return ""

            if not schema:
                return response_text

            # Get a prompt ID (truncated last user message) for error reporting
            prompt_id = "unknown"
            if chat_messages:
                last_user_msg = next(
                    (m for m in reversed(chat_messages) if m["role"] == "user"), None
                )
                if last_user_msg:
                    content = last_user_msg.get("content", "")
                    prompt_id = (content[:20] + "...") if len(content) > 20 else content
            return validate_structured_output(
                text=response_text,
                schema=schema,
                provider="litellm",
                model=model,
                prompt_id=prompt_id,
            )
        except Exception as e:
            duration_ms = (time.time() - start_time) * 1000
            obs_logger.log_event(
                "error",
                f"Async LLM call failed: {str(e)}",
                component=obs_components.LLM,
                fields={"provider": "litellm", "model": model, "duration_ms": duration_ms},
            )
            logger.error(f"LiteLLM completion failed: {e}")
            raise

    async def close(self):
        """
        Best-effort cleanup of LiteLLM sessions.

        LiteLLM manages HTTP sessions internally; ``cleanup_all_sessions``
        only exists in some versions, so any failure here is logged and
        swallowed rather than propagated.
        """
        try:
            import litellm

            # litellm manages sessions internally; clean up if the hook exists.
            if hasattr(litellm, "cleanup_all_sessions"):
                litellm.cleanup_all_sessions()
        except Exception as e:
            # BUGFIX: was a bare `except: pass`, which also swallowed
            # KeyboardInterrupt/SystemExit. Still best-effort, but narrowed
            # to Exception and logged for debuggability.
            logger.debug(f"LiteLLM session cleanup failed: {e}")