# eye-wiki/src/llm/openai_client.py
"""OpenAI-compatible client for LLM inference (supports Groq, DeepSeek, OpenAI, etc.)."""
import logging
from typing import Generator, List, Optional
from rich.console import Console
from config.settings import settings
from src.llm.llm_client import LLMClient

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class OpenAIClientError(Exception):
    """Raised when an OpenAI-compatible API call fails."""


class OpenAIClient(LLMClient):
    """
    Client for interacting with OpenAI-compatible APIs for LLM inference.

    Supports:
    - OpenAI (https://api.openai.com/v1)
    - Groq (https://api.groq.com/openai/v1)
    - DeepSeek (https://api.deepseek.com/v1)
    - Any OpenAI-compatible endpoint

    Features:
    - Non-streaming and streaming text generation
    - Configurable model, temperature, and max tokens
    - Automatic retry via the openai SDK
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ):
        """
        Initialize the OpenAI-compatible client.

        Args:
            api_key: API key (default: from settings)
            base_url: Base URL for the API (default: from settings, or OpenAI default)
            model: Model name (default: from settings)
            temperature: Default sampling temperature (default: from settings)
            max_tokens: Default max tokens (default: from settings)
        """
        try:
            from openai import OpenAI
        except ImportError as e:
            raise ImportError(
                "The 'openai' package is required for OpenAI-compatible providers. "
                "Install it with: pip install 'openai>=1.0.0'"
            ) from e

        self._api_key = api_key or settings.openai_api_key
        if not self._api_key:
            raise OpenAIClientError(
                "API key is required for OpenAI-compatible provider. "
                "Set the OPENAI_API_KEY environment variable or pass the api_key parameter."
            )

        self._base_url = base_url or settings.openai_base_url
        self.llm_model = model or settings.openai_model
        self._temperature = temperature if temperature is not None else settings.llm_temperature
        self._max_tokens = max_tokens if max_tokens is not None else settings.llm_max_tokens
        self.console = Console()

        # Initialize the underlying OpenAI SDK client; omit base_url so the
        # SDK falls back to its default endpoint (https://api.openai.com/v1).
        client_kwargs = {"api_key": self._api_key}
        if self._base_url:
            client_kwargs["base_url"] = self._base_url
        self._client = OpenAI(**client_kwargs)

        # Log initialization
        provider_name = self._base_url or "OpenAI (default)"
        self.console.print("[green]✓[/green] OpenAI-compatible client initialized")
        self.console.print(f"  Provider: {provider_name}")
        self.console.print(f"  Model: {self.llm_model}")
        logger.info(f"OpenAI client initialized: provider={provider_name}, model={self.llm_model}")
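
    # Example (sketch): with no arguments, every value falls back to
    # `config.settings`, e.g. the API key, base URL, and model resolved
    # from the environment:
    #
    #     client = OpenAIClient()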

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
    ) -> str:
        """
        Generate text using the OpenAI-compatible API (non-streaming).

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            temperature: Sampling temperature (default: from init/settings)
            max_tokens: Maximum tokens to generate (default: from init/settings)
            stop: Stop sequences

        Returns:
            Generated text

        Raises:
            OpenAIClientError: If generation fails
        """
        temperature = temperature if temperature is not None else self._temperature
        max_tokens = max_tokens if max_tokens is not None else self._max_tokens

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        try:
            response = self._client.chat.completions.create(
                model=self.llm_model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                stop=stop,
            )
            # The API may return None content (e.g. for tool-only responses);
            # normalize that to an empty string.
            return response.choices[0].message.content or ""
        except Exception as e:
            logger.error(f"Failed to generate text via OpenAI-compatible API: {e}")
            raise OpenAIClientError(f"Text generation failed: {e}") from e

    def stream_generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
    ) -> Generator[str, None, None]:
        """
        Generate text using the OpenAI-compatible API with streaming.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            temperature: Sampling temperature (default: from init/settings)
            max_tokens: Maximum tokens to generate (default: from init/settings)
            stop: Stop sequences

        Yields:
            Generated text chunks

        Raises:
            OpenAIClientError: If generation fails
        """
        temperature = temperature if temperature is not None else self._temperature
        max_tokens = max_tokens if max_tokens is not None else self._max_tokens

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        try:
            stream = self._client.chat.completions.create(
                model=self.llm_model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                stop=stop,
                stream=True,
            )
            for chunk in stream:
                # Skip chunks with no choices and deltas with no text content
                # (e.g. role-only deltas or trailing usage chunks).
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
        except Exception as e:
            logger.error(f"Failed to stream generate text via OpenAI-compatible API: {e}")
            raise OpenAIClientError(f"Streaming generation failed: {e}") from e