| """OpenAI-compatible client for LLM inference (supports Groq, DeepSeek, OpenAI, etc.).""" |
|
|
| import logging |
| from typing import Generator, List, Optional |
|
|
| from rich.console import Console |
|
|
| from config.settings import settings |
| from src.llm.llm_client import LLMClient |
|
|
|
|
| |
| logging.basicConfig(level=logging.INFO) |
| logger = logging.getLogger(__name__) |
|
|
|
|
| class OpenAIClientError(Exception): |
| """Raised when an OpenAI-compatible API call fails.""" |
|
|
| pass |
|
|
|
|
| class OpenAIClient(LLMClient): |
| """ |
| Client for interacting with OpenAI-compatible APIs for LLM inference. |
| |
| Supports: |
| - OpenAI (https://api.openai.com/v1) |
| - Groq (https://api.groq.com/openai/v1) |
| - DeepSeek (https://api.deepseek.com/v1) |
| - Any OpenAI-compatible endpoint |
| |
| Features: |
| - Non-streaming and streaming text generation |
| - Configurable model, temperature, and max tokens |
| - Automatic retry via the openai SDK |
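
    Example (a minimal sketch; the model name below is an illustrative
    placeholder, not a default of this module):

        client = OpenAIClient(
            api_key="sk-...",
            base_url="https://api.groq.com/openai/v1",
            model="llama-3.1-8b-instant",
        )
        print(client.generate("Say hello."))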
| """ |
|
|
| def __init__( |
| self, |
| api_key: Optional[str] = None, |
| base_url: Optional[str] = None, |
| model: Optional[str] = None, |
| temperature: Optional[float] = None, |
| max_tokens: Optional[int] = None, |
| ): |
| """ |
| Initialize OpenAI-compatible client. |
| |
| Args: |
| api_key: API key (default: from settings) |
| base_url: Base URL for the API (default: from settings, or OpenAI default) |
| model: Model name (default: from settings) |
| temperature: Default temperature (default: from settings) |
| max_tokens: Default max tokens (default: from settings) |
| """ |
| try: |
| from openai import OpenAI |
| except ImportError: |
| raise ImportError( |
| "The 'openai' package is required for OpenAI-compatible providers. " |
| "Install it with: pip install openai>=1.0.0" |
| ) |
|
|
| self._api_key = api_key or settings.openai_api_key |
| if not self._api_key: |
| raise OpenAIClientError( |
| "API key is required for OpenAI-compatible provider. " |
| "Set OPENAI_API_KEY environment variable or pass api_key parameter." |
| ) |
|
|
| self._base_url = base_url or settings.openai_base_url |
| self.llm_model = model or settings.openai_model |
| self._temperature = temperature if temperature is not None else settings.llm_temperature |
| self._max_tokens = max_tokens if max_tokens is not None else settings.llm_max_tokens |
|
|
| self.console = Console() |
|
|
| |
| client_kwargs = {"api_key": self._api_key} |
| if self._base_url: |
| client_kwargs["base_url"] = self._base_url |
|
|
| self._client = OpenAI(**client_kwargs) |
|
|
| |
| provider_name = self._base_url or "OpenAI (default)" |
| self.console.print(f"[green][/green] OpenAI-compatible client initialized") |
| self.console.print(f" Provider: {provider_name}") |
| self.console.print(f" Model: {self.llm_model}") |
| logger.info(f"OpenAI client initialized: provider={provider_name}, model={self.llm_model}") |
|
|
| def generate( |
| self, |
| prompt: str, |
| system_prompt: Optional[str] = None, |
| temperature: Optional[float] = None, |
| max_tokens: Optional[int] = None, |
| stop: Optional[List[str]] = None, |
| ) -> str: |
| """ |
| Generate text using the OpenAI-compatible API (non-streaming). |
| |
| Args: |
| prompt: User prompt |
| system_prompt: Optional system prompt |
| temperature: Sampling temperature (default: from init/settings) |
| max_tokens: Maximum tokens to generate (default: from init/settings) |
| stop: Stop sequences |
| |
| Returns: |
| Generated text |
| |
| Raises: |
| OpenAIClientError: If generation fails |
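
        Example (illustrative sketch, assuming a configured client):

            client = OpenAIClient()
            answer = client.generate(
                "Explain HTTP in one sentence.",
                system_prompt="You are a concise assistant.",
            )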
| """ |
        temperature = temperature if temperature is not None else self._temperature
        max_tokens = max_tokens if max_tokens is not None else self._max_tokens

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        try:
            response = self._client.chat.completions.create(
                model=self.llm_model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                stop=stop,
            )
            return response.choices[0].message.content or ""
        except Exception as e:
            logger.error(f"Failed to generate text via OpenAI-compatible API: {e}")
            raise OpenAIClientError(f"Text generation failed: {e}") from e

    def stream_generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
    ) -> Generator[str, None, None]:
        """
        Generate text using the OpenAI-compatible API with streaming.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            temperature: Sampling temperature (default: from init/settings)
            max_tokens: Maximum tokens to generate (default: from init/settings)
            stop: Stop sequences

        Yields:
            Generated text chunks

        Raises:
            OpenAIClientError: If generation fails
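
        Example (illustrative sketch, assuming a configured client):

            client = OpenAIClient()
            for chunk in client.stream_generate("Tell a short story."):
                print(chunk, end="", flush=True)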
| """ |
        temperature = temperature if temperature is not None else self._temperature
        max_tokens = max_tokens if max_tokens is not None else self._max_tokens

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        try:
            stream = self._client.chat.completions.create(
                model=self.llm_model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                stop=stop,
                stream=True,
            )
            for chunk in stream:
                # Streaming responses arrive as deltas; skip keep-alive chunks with no content.
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
        except Exception as e:
            logger.error(f"Failed to stream generate text via OpenAI-compatible API: {e}")
            raise OpenAIClientError(f"Streaming generation failed: {e}") from e
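

if __name__ == "__main__":
    # Minimal smoke-test sketch, an illustration rather than part of the library API.
    # Assumes OPENAI_API_KEY (or the equivalent config.settings values) is configured
    # and the endpoint is reachable.
    client = OpenAIClient()
    print(client.generate("Reply with the single word: pong"))
    for chunk in client.stream_generate("Count from 1 to 5."):
        print(chunk, end="", flush=True)
    print()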