"""
API Client for Smart Auto-Complete

Handles communication with OpenAI and Anthropic APIs
"""

import logging
import time
from typing import Dict, List, Optional, Union

import anthropic
import openai

from .utils import validate_api_key

logger = logging.getLogger(__name__)


class APIClient:
    """
    Unified API client for multiple AI providers.

    Supports OpenAI GPT and Anthropic Claude models.
    """

    def __init__(self, settings=None):
        """
        Initialize the API client with settings.

        Args:
            settings: Application settings object
        """
        self.settings = settings
        self.openai_client = None
        self.anthropic_client = None
        self.current_provider = None
        self.request_count = 0
        self.last_request_time = 0

        self._initialize_clients()

    def _get_token_param_name(self, model: str) -> str:
        """
        Get the correct token-limit parameter name for the given model.

        OpenAI's reasoning models (o1/o3 families) expect
        "max_completion_tokens"; other chat models use "max_tokens".

        Args:
            model: The model name

        Returns:
            The correct parameter name ("max_tokens" or "max_completion_tokens")
        """
        if model.startswith(("o3", "o1")):
            return "max_completion_tokens"
        return "max_tokens"
    def _initialize_clients(self):
        """Initialize API clients based on available keys."""
        try:
            # OpenAI client (only if a valid-looking key is configured)
            if (
                self.settings
                and hasattr(self.settings, "OPENAI_API_KEY")
                and self.settings.OPENAI_API_KEY
                and validate_api_key(self.settings.OPENAI_API_KEY, "openai")
            ):
                self.openai_client = openai.OpenAI(api_key=self.settings.OPENAI_API_KEY)
                logger.info("OpenAI client initialized successfully")

            # Anthropic client (only if a valid-looking key is configured)
            if (
                self.settings
                and hasattr(self.settings, "ANTHROPIC_API_KEY")
                and self.settings.ANTHROPIC_API_KEY
                and validate_api_key(self.settings.ANTHROPIC_API_KEY, "anthropic")
            ):
                self.anthropic_client = anthropic.Anthropic(
                    api_key=self.settings.ANTHROPIC_API_KEY
                )
                logger.info("Anthropic client initialized successfully")

            # Choose the active provider: configured default, else whichever client exists
            if hasattr(self.settings, "DEFAULT_PROVIDER"):
                self.current_provider = self.settings.DEFAULT_PROVIDER
            elif self.openai_client:
                self.current_provider = "openai"
            elif self.anthropic_client:
                self.current_provider = "anthropic"
            else:
                logger.warning("No valid API clients initialized")

        except Exception as e:
            logger.error(f"Error initializing API clients: {str(e)}")

    def get_completion(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.7,
        max_tokens: int = 150,
        provider: Optional[str] = None,
    ) -> Optional[str]:
        """
        Get a completion from the specified provider.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            temperature: Sampling temperature (0.0 to 1.0)
            max_tokens: Maximum tokens in the response
            provider: Specific provider to use ('openai' or 'anthropic')

        Returns:
            Generated completion text, or None if the request failed
        """
        try:
            if not self._check_rate_limit():
                logger.warning("Rate limit exceeded, skipping request")
                return None

            use_provider = provider or self.current_provider

            if use_provider == "openai" and self.openai_client:
                return self._get_openai_completion(messages, temperature, max_tokens)
            elif use_provider == "anthropic" and self.anthropic_client:
                return self._get_anthropic_completion(messages, temperature, max_tokens)
            else:
                # Fall back to whichever client is available
                if self.openai_client:
                    return self._get_openai_completion(
                        messages, temperature, max_tokens
                    )
                elif self.anthropic_client:
                    return self._get_anthropic_completion(
                        messages, temperature, max_tokens
                    )
                else:
                    logger.error("No API clients available")
                    return None

        except Exception as e:
            logger.error(f"Error getting completion: {str(e)}")
            return None
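
    # Illustrative call (hypothetical prompt values):
    #   client.get_completion(
    #       messages=[
    #           {"role": "system", "content": "You are an autocomplete engine."},
    #           {"role": "user", "content": "Complete: 'Dear team, '"},
    #       ],
    #       temperature=0.3,
    #       max_tokens=60,
    #   )
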
    def _get_openai_completion(
        self, messages: List[Dict[str, str]], temperature: float, max_tokens: int
    ) -> Optional[str]:
        """Get a completion from the OpenAI API."""
        try:
            model = (
                self.settings.get_model_for_provider("openai")
                if self.settings
                else "gpt-4o-mini"
            )
            logger.debug(f"Using OpenAI model: {model}")

            # Reasoning models take 'max_completion_tokens' rather than 'max_tokens'
            token_param = self._get_token_param_name(model)
            logger.debug(f"Using token parameter: {token_param} = {max_tokens}")

            request_params = {
                "model": model,
                "messages": messages,
                token_param: max_tokens,
                "n": 1,
                "stop": None,
            }

            # Reasoning models (o1/o3) only support the default temperature
            if not model.startswith(("o3", "o1")):
                request_params["temperature"] = temperature
                logger.debug(f"Using custom temperature: {temperature}")
            else:
                logger.debug(f"Using default temperature for reasoning model {model}")

            # Penalty parameters are likewise skipped for reasoning models
            if not model.startswith(("o3", "o1")):
                request_params["presence_penalty"] = 0.1
                request_params["frequency_penalty"] = 0.1

            response = self.openai_client.chat.completions.create(**request_params)

            self._update_request_stats()

            if response.choices and len(response.choices) > 0:
                return response.choices[0].message.content.strip()
            else:
                logger.warning("No choices returned from OpenAI API")
                return None

        except openai.RateLimitError:
            logger.warning("OpenAI rate limit exceeded")
            return None
        except openai.APIError as e:
            logger.error(f"OpenAI API error: {str(e)}")
            return None
        except Exception as e:
            logger.error(f"Unexpected error with OpenAI: {str(e)}")
            return None

    def _get_anthropic_completion(
        self, messages: List[Dict[str, str]], temperature: float, max_tokens: int
    ) -> Optional[str]:
        """Get a completion from the Anthropic API."""
        try:
            # Anthropic's Messages API takes the system prompt as a separate
            # argument, so split it out of the chat history
            system_message = ""
            user_messages = []

            for msg in messages:
                if msg["role"] == "system":
                    system_message = msg["content"]
                else:
                    user_messages.append(msg)

            model = (
                self.settings.get_model_for_provider("anthropic")
                if self.settings
                else "claude-3-haiku-20240307"
            )
            logger.debug(f"Using Anthropic model: {model}")

            response = self.anthropic_client.messages.create(
                model=model,
                max_tokens=max_tokens,
                temperature=temperature,
                system=system_message,
                messages=user_messages,
            )

            self._update_request_stats()

            if response.content and len(response.content) > 0:
                return response.content[0].text.strip()
            else:
                logger.warning("No content returned from Anthropic API")
                return None

        except anthropic.RateLimitError:
            logger.warning("Anthropic rate limit exceeded")
            return None
        except anthropic.APIError as e:
            logger.error(f"Anthropic API error: {str(e)}")
            return None
        except Exception as e:
            logger.error(f"Unexpected error with Anthropic: {str(e)}")
            return None
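
    # Illustration of the split above (hypothetical messages): an input of
    #   [{"role": "system", "content": "Be brief."},
    #    {"role": "user", "content": "Hi"}]
    # is sent as system="Be brief." with messages=[{"role": "user", "content": "Hi"}].
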
    def _check_rate_limit(self) -> bool:
        """
        Check whether we are within rate limits.

        Simple implementation (a minimum interval between requests); it can be
        enhanced with more sophisticated logic.
        """
        current_time = time.time()

        # Require at least one second between consecutive requests
        if current_time - self.last_request_time < 1.0:
            return False

        return True
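
    # One possible enhancement (illustrative sketch only, not wired in): a
    # sliding-window limiter allowing up to N requests per 60-second window.
    # The attribute and parameter names below are hypothetical.
    #
    #   from collections import deque
    #
    #   self._recent_requests: deque = deque()  # timestamps of recent requests
    #
    #   def _check_rate_limit(self, max_per_minute: int = 30) -> bool:
    #       now = time.time()
    #       while self._recent_requests and now - self._recent_requests[0] > 60:
    #           self._recent_requests.popleft()
    #       if len(self._recent_requests) >= max_per_minute:
    #           return False
    #       self._recent_requests.append(now)
    #       return True
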
    def _update_request_stats(self):
        """Update request statistics."""
        self.request_count += 1
        self.last_request_time = time.time()

    def get_available_providers(self) -> List[str]:
        """Get the list of available providers."""
        providers = []
        if self.openai_client:
            providers.append("openai")
        if self.anthropic_client:
            providers.append("anthropic")
        return providers

    def switch_provider(self, provider: str) -> bool:
        """
        Switch to a different provider.

        Args:
            provider: Provider name ('openai' or 'anthropic')

        Returns:
            True if the switch was successful, False otherwise
        """
        if provider == "openai" and self.openai_client:
            self.current_provider = "openai"
            logger.info("Switched to OpenAI provider")
            return True
        elif provider == "anthropic" and self.anthropic_client:
            self.current_provider = "anthropic"
            logger.info("Switched to Anthropic provider")
            return True
        else:
            logger.warning(f"Cannot switch to provider: {provider}")
            return False

    def get_stats(self) -> Dict[str, Union[int, float, str, List[str]]]:
        """Get API usage statistics."""
        return {
            "request_count": self.request_count,
            "current_provider": self.current_provider,
            "available_providers": self.get_available_providers(),
            "last_request_time": self.last_request_time,
        }

    def test_connection(self, provider: Optional[str] = None) -> bool:
        """
        Test the connection to an API provider.

        Args:
            provider: Specific provider to test, or None for the current provider

        Returns:
            True if the connection works, False otherwise
        """
        try:
            test_messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Say 'Hello' in one word."},
            ]

            result = self.get_completion(
                messages=test_messages,
                temperature=0.1,
                max_tokens=10,
                provider=provider,
            )

            return result is not None and len(result.strip()) > 0

        except Exception as e:
            logger.error(f"Connection test failed: {str(e)}")
            return False
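

# Minimal usage sketch (illustrative, not part of the library). It assumes a
# settings object exposing the attributes used above (OPENAI_API_KEY,
# ANTHROPIC_API_KEY, DEFAULT_PROVIDER) plus a get_model_for_provider() helper;
# SimpleNamespace stands in for the real settings class here. Because this
# module uses a relative import, run it as part of its package
# (python -m <package>.api_client) or adapt the snippet into a separate script.
if __name__ == "__main__":
    import os
    from types import SimpleNamespace

    demo_settings = SimpleNamespace(
        OPENAI_API_KEY=os.getenv("OPENAI_API_KEY", ""),
        ANTHROPIC_API_KEY=os.getenv("ANTHROPIC_API_KEY", ""),
        DEFAULT_PROVIDER="openai",
        get_model_for_provider=lambda provider: (
            "gpt-4o-mini" if provider == "openai" else "claude-3-haiku-20240307"
        ),
    )

    client = APIClient(settings=demo_settings)
    print("Available providers:", client.get_available_providers())

    if client.test_connection():
        completion = client.get_completion(
            messages=[{"role": "user", "content": "Complete this phrase: 'Dear team,'"}],
            temperature=0.3,
            max_tokens=20,
        )
        print("Sample completion:", completion)
    print("Stats:", client.get_stats())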