# Source: runner-ai-intelligence/src/llm/litellm_client.py
# Snapshot: "HF Space deploy snapshot (minimal allow-list)"
# Commit: d64fd55 (author: avfranco)
import os
import logging
import json
import time
import re
import textwrap
from typing import Dict, Any, Optional, List, Type, Union
# Global LiteLLM configuration to prevent atexit worker errors.
# The telemetry opt-out is written to the environment first because it
# MUST BE SET BEFORE IMPORTING LITELLM.
os.environ["LITELLM_TELEMETRY"] = "False"
import litellm
# Module-level switches, applied once at import time of this module.
litellm.telemetry = False
litellm.suppress_worker_errors = True
litellm.set_verbose = False
# Internal flag to disable the background logging worker. It is a private
# attribute, so it is only set when the installed litellm exposes it.
if hasattr(litellm, "_disable_logging_worker"):
    litellm._disable_logging_worker = True
from pydantic import BaseModel
from observability import logger as obs_logger
from observability import components as obs_components
from .base import LLMClient, LLMCapabilities
from .structured import schema_guard, get_json_instruction, validate_structured_output
logger = logging.getLogger(__name__)
class LiteLLMClient(LLMClient):
    """LLMClient implementation using LiteLLM for OpenAI-compatible APIs.

    The client stores the connection settings given at construction time and
    issues one ``litellm.acompletion`` call per ``chat``/``generate`` request.
    When a pydantic ``schema`` is supplied, the response text is validated
    through ``validate_structured_output``.
    """

    def __init__(
        self,
        model_name: str,
        provider: Optional[str] = None,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        drop_params: bool = False,
        **kwargs,
    ):
        """Store the client configuration and log the initialization.

        Args:
            model_name: Model identifier, with or without a provider prefix.
            provider: Optional provider hint ("gemini" and "anthropic" trigger
                automatic model prefixing in ``chat``).
            api_base: Optional custom endpoint URL.
            api_key: Optional explicit API key; when omitted, no key is passed
                to LiteLLM by this client.
            temperature: Default sampling temperature for requests.
            max_tokens: ``max_tokens`` value sent with every request.
            drop_params: Forwarded as LiteLLM's ``drop_params`` argument.
            **kwargs: Extra keyword arguments merged into every completion
                call (they take precedence over the values above).
        """
        self.model_name = model_name
        self.provider = provider
        self.api_base = api_base
        self.api_key = api_key
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.drop_params = drop_params
        self.extra_params = kwargs

        if os.getenv("LITELLM_DEBUG", "false").lower() == "true":
            # Using the recommended way to enable logging
            os.environ["LITELLM_LOG"] = "DEBUG"
            logger.info("LiteLLM verbose logging enabled via LITELLM_LOG=DEBUG")

        # LiteLLM handles key resolution automatically from env vars based on
        # model prefix (e.g. OPENAI_API_KEY, ANTHROPIC_API_KEY, HF_TOKEN, etc.)
        obs_logger.log_event(
            level="info",
            message=f"LiteLLM client initialized for {model_name} (provider: {provider or 'auto'})",
            event="credentials_resolved",
            component=obs_components.LLM,
            provider="litellm",
            model=model_name,
            source="environment-automatic",
        )

    @property
    def capabilities(self) -> LLMCapabilities:
        """Return the default ``LLMCapabilities`` for this client."""
        return LLMCapabilities()  # Default capabilities

    @staticmethod
    def _content_text(content: Any) -> str:
        """Return message content as text (``None`` -> "", non-str -> ``str()``)."""
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        return str(content)

    @staticmethod
    def _mask_api_key(key: Any) -> str:
        """Return a log-safe rendering of an API key.

        A short prefix/suffix is shown only for keys of at least 20
        characters, so the visible part never covers most of the secret.
        """
        text = str(key)
        if len(text) >= 20:
            return f"{text[:6]}...{text[-4:]}"
        return "***"

    def _resolve_model(self) -> str:
        """Return the model name with the provider prefix LiteLLM should route on."""
        model = self.model_name

        # 0. Basic prefixing for known providers if not already prefixed
        if self.provider == "gemini" and not model.startswith("gemini/"):
            model = f"gemini/{model}"
        elif self.provider == "anthropic" and not model.startswith("anthropic/"):
            model = f"anthropic/{model}"

        api_base = self.api_base
        if not api_base:
            return model

        if "huggingface.co" in api_base or "hf.co" in api_base:
            # 1. Hugging Face endpoints: only the router gets the
            #    "huggingface/" prefix; other HF URLs keep the model untouched.
            if "router.huggingface.co" in api_base and not model.startswith("huggingface/"):
                if "/" in model:
                    # Covers both "openai/<name>" and "<org>/<name>".
                    model = f"huggingface/{model}"
                else:
                    model = f"huggingface/openai/{model}"
        elif "/" not in model:
            # 2. Any other custom endpoint is treated as OpenAI-compatible.
            model = f"openai/{model}"
        return model

    async def generate(
        self,
        prompt: str,
        *,
        instruction: Optional[str] = None,
        schema: Optional[Type[BaseModel]] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Any]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
    ) -> Union[str, Dict[str, Any], BaseModel]:
        """Run a single-prompt completion by delegating to ``chat``.

        The optional ``instruction`` becomes a system message placed before
        the user ``prompt``; all other arguments are forwarded unchanged.
        """
        messages = []
        if instruction:
            messages.append({"role": "system", "content": instruction})
        messages.append({"role": "user", "content": prompt})
        return await self.chat(
            messages,
            instruction=None,  # Already added to messages
            schema=schema,
            temperature=temperature,
            tools=tools,
            metadata=metadata,
            name=name,
        )

    async def chat(
        self,
        messages: List[Dict[str, str]],
        *,
        instruction: Optional[str] = None,
        schema: Optional[Type[BaseModel]] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Any]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
    ) -> Union[str, Dict[str, Any], BaseModel]:
        """Send a chat completion request through LiteLLM.

        Args:
            messages: Chat messages, each with a "role" and a "content" entry.
            instruction: Optional system instruction inserted as the first
                message. With a ``schema`` it is first expanded via
                ``get_json_instruction``.
            schema: Optional pydantic model; enables JSON response format and
                validation of the reply.
            temperature: Per-call override of the client's temperature.
            tools: Accepted for interface compatibility; not forwarded to
                LiteLLM by this implementation.
            metadata: Accepted for interface compatibility; unused here.
            name: Accepted for interface compatibility; unused here.

        Returns:
            The raw response text when no schema is given (an empty string if
            the model returned nothing), otherwise the result of
            ``validate_structured_output``.

        Raises:
            ValueError: If a schema was requested and the response is empty.
            Exception: Any error raised by LiteLLM or by validation is logged
                and re-raised unchanged.
        """
        if schema:
            # Combine all content for the guard
            full_prompt = " ".join(self._content_text(m.get("content")) for m in messages)
            schema_guard(full_prompt, instruction)
            instruction = get_json_instruction(schema, instruction)

        chat_messages = []
        for msg in messages:
            content = msg.get("content", "")
            # Only plain-text content can be dedented; anything else
            # (None, structured parts) is passed through untouched.
            if content and isinstance(content, str):
                content = textwrap.dedent(content).strip()
            chat_messages.append({"role": msg["role"], "content": content})

        if instruction:
            chat_messages.insert(
                0, {"role": "system", "content": textwrap.dedent(instruction).strip()}
            )

        # Ensure model has a provider prefix if api_base is used,
        # so LiteLLM knows which adapter to use for the custom endpoint.
        model = self._resolve_model()

        # Prepare completion arguments
        completion_kwargs = {
            "model": model,
            "messages": chat_messages,
            "temperature": temperature if temperature is not None else self.temperature,
            "max_tokens": self.max_tokens,
            "drop_params": self.drop_params,
            **self.extra_params,
        }

        # Only pass api_base/key if they are explicitly provided and not empty
        if self.api_base and "api.openai.com" not in self.api_base:
            completion_kwargs["api_base"] = self.api_base
        if self.api_key:
            completion_kwargs["api_key"] = self.api_key

        # If we have a schema, we can try to use JSON mode if supported
        if schema:
            completion_kwargs["response_format"] = {"type": "json_object"}

        # Mask API key for logging
        log_kwargs = completion_kwargs.copy()
        if log_kwargs.get("api_key"):
            log_kwargs["api_key"] = self._mask_api_key(log_kwargs["api_key"])

        logger.info("LiteLLM sending request to %s at %s", model, self.api_base or "default")
        logger.debug("Completion args: %s", log_kwargs)

        # Started before the try block so the error handler can always
        # compute a duration, even if the first log call itself fails.
        start_time = time.perf_counter()
        try:
            obs_logger.log_event(
                "info",
                "LLM call started",
                event="start",
                component=obs_components.LLM,
                fields={"provider": "litellm", "model": model},
            )

            response = await litellm.acompletion(**completion_kwargs)
            duration_ms = (time.perf_counter() - start_time) * 1000

            obs_logger.log_event(
                "info",
                f"Generating completion with model: {self.model_name}",
                component=obs_components.LLM,
                fields={
                    "model": self.model_name,
                    # Report the temperature that was actually sent.
                    "temperature": completion_kwargs.get("temperature"),
                    "duration_ms": duration_ms,
                },
            )

            message = response.choices[0].message
            response_text = message.content
            # Handle possible None content and check for reasoning_content (for models like o1)
            if response_text is None:
                response_text = getattr(message, "reasoning_content", "") or ""

            logger.debug("LiteLLM raw response: %s...", response_text[:200])

            if not response_text:
                logger.error("LiteLLM returned empty response for model %s", model)
                if schema:
                    raise ValueError(f"LLM returned empty response for schema {schema.__name__}")
                return ""

            if not schema:
                return response_text

            # Get a prompt ID for error reporting: the start of the last user message.
            prompt_id = "unknown"
            last_user_msg = next(
                (m for m in reversed(chat_messages) if m["role"] == "user"), None
            )
            if last_user_msg:
                content = self._content_text(last_user_msg.get("content", ""))
                prompt_id = (content[:20] + "...") if len(content) > 20 else content

            return validate_structured_output(
                text=response_text,
                schema=schema,
                provider="litellm",
                model=model,
                prompt_id=prompt_id,
            )
        except Exception as e:
            duration_ms = (time.perf_counter() - start_time) * 1000
            obs_logger.log_event(
                "error",
                f"Async LLM call failed: {str(e)}",
                component=obs_components.LLM,
                fields={"provider": "litellm", "model": model, "duration_ms": duration_ms},
            )
            logger.error("LiteLLM completion failed: %s", e)
            raise

    async def close(self):
        """Close LiteLLM sessions on a best-effort basis.

        LiteLLM manages its sessions internally; if the installed version
        exposes ``cleanup_all_sessions`` it is invoked (and awaited when it
        returns an awaitable). Cleanup failures are logged at debug level and
        never propagated.
        """
        import inspect

        cleanup = getattr(litellm, "cleanup_all_sessions", None)
        if not callable(cleanup):
            return
        try:
            result = cleanup()
            if inspect.isawaitable(result):
                await result
        except Exception as exc:
            # Deliberately non-fatal: shutdown must not fail because of cleanup.
            logger.debug("LiteLLM session cleanup failed (ignored): %s", exc)