# Smart-Auto-Complete / src/api_client.py
# Sandipan Haldar
# modify UI
# 3789a0e
"""
API Client for Smart Auto-Complete
Handles communication with OpenAI and Anthropic APIs
"""
import logging
import time
from typing import Dict, List, Optional, Union
import anthropic
import openai
from .utils import validate_api_key
logger = logging.getLogger(__name__)


class APIClient:
    """
    Unified API client for multiple AI providers.

    Supports OpenAI GPT (including o-series reasoning models) and Anthropic
    Claude models behind a single ``get_completion`` entry point. Requests are
    throttled to at most one per ``_MIN_REQUEST_INTERVAL`` seconds, and simple
    usage statistics are tracked.
    """

    # Model-name prefixes of OpenAI reasoning models. These models need
    # ``max_completion_tokens`` instead of ``max_tokens`` and do not accept
    # custom temperature / presence_penalty / frequency_penalty values.
    _REASONING_MODEL_PREFIXES = ("o1", "o3", "o4")

    # Minimum number of seconds between two requests (60 requests per minute).
    _MIN_REQUEST_INTERVAL = 1.0

    def __init__(self, settings=None):
        """
        Initialize the API client with settings.

        Args:
            settings: Application settings object. When present, the client
                reads OPENAI_API_KEY, ANTHROPIC_API_KEY and DEFAULT_PROVIDER
                attributes and calls get_model_for_provider(provider).
        """
        self.settings = settings
        self.openai_client = None
        self.anthropic_client = None
        self.current_provider = None
        self.request_count = 0
        self.last_request_time = 0
        self._initialize_clients()

    def _is_reasoning_model(self, model: str) -> bool:
        """Return True if ``model`` names an OpenAI o-series reasoning model."""
        return model.startswith(self._REASONING_MODEL_PREFIXES)

    def _get_token_param_name(self, model: str) -> str:
        """
        Get the correct token-limit parameter name for an OpenAI model.

        Args:
            model: The model name.

        Returns:
            'max_completion_tokens' for reasoning models, 'max_tokens' otherwise.
        """
        if self._is_reasoning_model(model):
            return "max_completion_tokens"
        return "max_tokens"

    def _initialize_clients(self):
        """
        Create a client for every valid API key in settings and pick a provider.

        The configured DEFAULT_PROVIDER is used only when its client was
        actually created. Otherwise the first available provider (OpenAI,
        then Anthropic) is selected. Errors are logged, never raised.
        """
        try:
            # getattr(None, ..., None) is None, so a missing settings object is
            # handled the same way as a missing or empty key.
            openai_key = getattr(self.settings, "OPENAI_API_KEY", None)
            if openai_key and validate_api_key(openai_key, "openai"):
                self.openai_client = openai.OpenAI(api_key=openai_key)
                logger.info("OpenAI client initialized successfully")

            anthropic_key = getattr(self.settings, "ANTHROPIC_API_KEY", None)
            if anthropic_key and validate_api_key(anthropic_key, "anthropic"):
                self.anthropic_client = anthropic.Anthropic(api_key=anthropic_key)
                logger.info("Anthropic client initialized successfully")

            available = self.get_available_providers()
            default_provider = getattr(self.settings, "DEFAULT_PROVIDER", None)
            if default_provider in available:
                self.current_provider = default_provider
            elif available:
                if default_provider:
                    # Do not advertise a provider we cannot actually call.
                    logger.warning(
                        "Default provider %r is not available; using %r",
                        default_provider,
                        available[0],
                    )
                self.current_provider = available[0]
            else:
                logger.warning("No valid API clients initialized")
        except Exception as e:
            logger.error("Error initializing API clients: %s", e)

    def get_completion(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.7,
        max_tokens: int = 150,
        provider: Optional[str] = None,
    ) -> Optional[str]:
        """
        Get a completion from the specified (or current) provider.

        If the requested provider has no client, any available provider is
        used instead (OpenAI first, then Anthropic).

        Args:
            messages: List of message dictionaries with 'role' and 'content'.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens in the response.
            provider: Specific provider to use ('openai' or 'anthropic').

        Returns:
            Generated completion text, or None if throttled or failed.
        """
        try:
            if not self._check_rate_limit():
                logger.warning("Rate limit exceeded, skipping request")
                return None

            use_provider = provider or self.current_provider
            if use_provider == "openai" and self.openai_client:
                return self._get_openai_completion(messages, temperature, max_tokens)
            if use_provider == "anthropic" and self.anthropic_client:
                return self._get_anthropic_completion(messages, temperature, max_tokens)

            # Requested provider unavailable: fall back to any available one.
            if self.openai_client:
                return self._get_openai_completion(messages, temperature, max_tokens)
            if self.anthropic_client:
                return self._get_anthropic_completion(messages, temperature, max_tokens)

            logger.error("No API clients available")
            return None
        except Exception as e:
            logger.error("Error getting completion: %s", e)
            return None

    def _get_openai_completion(
        self, messages: List[Dict[str, str]], temperature: float, max_tokens: int
    ) -> Optional[str]:
        """
        Get a completion from the OpenAI Chat Completions API.

        Reasoning models get ``max_completion_tokens`` and no sampling
        parameters. Other models get ``max_tokens``, the given temperature and
        small presence/frequency penalties.

        Returns:
            The stripped text of the first choice, or None on error, when no
            choices come back, or when the choice carries no text content.
        """
        try:
            model = (
                self.settings.get_model_for_provider("openai")
                if self.settings
                else "gpt-4o-mini"
            )
            logger.debug("Using OpenAI model: %s", model)

            token_param = self._get_token_param_name(model)
            logger.debug("Using token parameter: %s = %s", token_param, max_tokens)

            request_params = {
                "model": model,
                "messages": messages,
                token_param: max_tokens,
                "n": 1,
                "stop": None,
            }

            if self._is_reasoning_model(model):
                # Reasoning models only accept their default temperature and
                # reject the penalty parameters.
                logger.debug("Using default temperature for reasoning model %s", model)
            else:
                request_params["temperature"] = temperature
                logger.debug("Using custom temperature: %s", temperature)
                request_params["presence_penalty"] = 0.1
                request_params["frequency_penalty"] = 0.1

            response = self.openai_client.chat.completions.create(**request_params)
            self._update_request_stats()

            if not response.choices:
                logger.warning("No choices returned from OpenAI API")
                return None

            # The SDK types message.content as optional; calling .strip() on
            # None would raise and be misreported as an unexpected error.
            content = response.choices[0].message.content
            if content is None:
                logger.warning("OpenAI API returned a choice without text content")
                return None
            return content.strip()
        except openai.RateLimitError:
            # Must precede APIError: RateLimitError is a subclass of it.
            logger.warning("OpenAI rate limit exceeded")
            return None
        except openai.APIError as e:
            logger.error("OpenAI API error: %s", e)
            return None
        except Exception as e:
            logger.error("Unexpected error with OpenAI: %s", e)
            return None

    def _get_anthropic_completion(
        self, messages: List[Dict[str, str]], temperature: float, max_tokens: int
    ) -> Optional[str]:
        """
        Get a completion from the Anthropic Messages API.

        Anthropic takes the system prompt as a separate ``system`` argument.
        All 'system' messages are therefore joined (blank-line separated) and
        sent there, and every other message is passed through unchanged.

        Returns:
            The stripped text of the first text block, or None on error or when
            the response has no text.
        """
        try:
            system_parts = []
            chat_messages = []
            for msg in messages:
                if msg["role"] == "system":
                    system_parts.append(msg["content"])
                else:
                    chat_messages.append(msg)

            model = (
                self.settings.get_model_for_provider("anthropic")
                if self.settings
                else "claude-3-haiku-20240307"
            )
            logger.debug("Using Anthropic model: %s", model)

            request_params = {
                "model": model,
                "max_tokens": max_tokens,
                "temperature": temperature,
                "messages": chat_messages,
            }
            system_message = "\n\n".join(part for part in system_parts if part)
            # Only send a system prompt when there is one.
            if system_message:
                request_params["system"] = system_message

            response = self.anthropic_client.messages.create(**request_params)
            self._update_request_stats()

            # Skip non-text blocks (they carry no ``text`` attribute).
            for block in response.content or []:
                text = getattr(block, "text", None)
                if isinstance(text, str):
                    return text.strip()

            logger.warning("No content returned from Anthropic API")
            return None
        except anthropic.RateLimitError:
            # Must precede APIError: RateLimitError is a subclass of it.
            logger.warning("Anthropic rate limit exceeded")
            return None
        except anthropic.APIError as e:
            logger.error("Anthropic API error: %s", e)
            return None
        except Exception as e:
            logger.error("Unexpected error with Anthropic: %s", e)
            return None

    def _check_rate_limit(self) -> bool:
        """
        Return True if at least ``_MIN_REQUEST_INTERVAL`` seconds have passed
        since the last completed API call.

        Simple implementation; can be replaced with a sliding window later.
        """
        elapsed = time.time() - self.last_request_time
        return elapsed >= self._MIN_REQUEST_INTERVAL

    def _update_request_stats(self):
        """Record a completed API call (count and timestamp)."""
        self.request_count += 1
        self.last_request_time = time.time()

    def get_available_providers(self) -> List[str]:
        """Return the names of providers that have a client, OpenAI first."""
        candidates = (
            ("openai", self.openai_client),
            ("anthropic", self.anthropic_client),
        )
        return [name for name, client in candidates if client]

    def switch_provider(self, provider: str) -> bool:
        """
        Switch to a different provider.

        Args:
            provider: Provider name ('openai' or 'anthropic').

        Returns:
            True if the provider has a client and was selected, False otherwise.
        """
        display_names = {"openai": "OpenAI", "anthropic": "Anthropic"}
        if provider in self.get_available_providers():
            self.current_provider = provider
            logger.info("Switched to %s provider", display_names[provider])
            return True
        logger.warning("Cannot switch to provider: %s", provider)
        return False

    def get_stats(self) -> Dict[str, Union[int, float, str, List[str], None]]:
        """
        Get API usage statistics.

        Returns:
            Dict with 'request_count', 'current_provider' (may be None),
            'available_providers' (list of names) and 'last_request_time'.
        """
        return {
            "request_count": self.request_count,
            "current_provider": self.current_provider,
            "available_providers": self.get_available_providers(),
            "last_request_time": self.last_request_time,
        }

    def test_connection(self, provider: Optional[str] = None) -> bool:
        """
        Test connection to an API provider by requesting a one-word reply.

        Note: this goes through get_completion, so it is subject to the same
        request throttle. A call made within ``_MIN_REQUEST_INTERVAL`` seconds
        of the previous completed API call returns False without contacting the
        API.

        Args:
            provider: Specific provider to test, or None for the current one.

        Returns:
            True if a non-empty completion came back, False otherwise.
        """
        try:
            test_messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Say 'Hello' in one word."},
            ]
            result = self.get_completion(
                messages=test_messages,
                temperature=0.1,
                max_tokens=10,
                provider=provider,
            )
            return bool(result and result.strip())
        except Exception as e:
            logger.error("Connection test failed: %s", e)
            return False