# Author: 赵福来
# First commit
# Commit: 90a947b
"""
OpenAI sdk module - Interface with OpenAI API
"""
from typing import Any, Dict, Generator, List, Optional

from openai import OpenAI

from src.config.settings import settings
from src.backends.utils.logger import app_logger
class ChatClient:
    """Thin wrapper around the OpenAI SDK for chat-completion requests.

    Works with any OpenAI-compatible endpoint (e.g. SiliconFlow) by
    pointing ``base_url`` at the provider's API root.
    """

    def __init__(self, api_key: str, base_url: Optional[str] = None):
        """Initialize the underlying OpenAI client.

        Args:
            api_key: OpenAI/SiliconFlow API key.
            base_url: API base URL. Defaults to
                ``settings.SILICON_FLOW_BASE_URL`` when not given.
        """
        self.api_key = api_key
        self.base_url = base_url or settings.SILICON_FLOW_BASE_URL
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
        )
        self.logger = app_logger.get_logger("ChatClient")

    def create_chat_completion(self,
                               model: str,
                               messages: List[Dict[str, str]],
                               stream: bool = False,
                               temperature: Optional[float] = None,
                               max_tokens: Optional[int] = None) -> Any:
        """Create a chat completion (streaming or not).

        Args:
            model: Model name.
            messages: Chat history as a list of ``{"role": ..., "content": ...}`` dicts.
            temperature: Sampling temperature; ``settings.DEFAULT_TEMPERATURE``
                when None.
            max_tokens: Maximum tokens to generate; ``settings.DEFAULT_MAX_TOKENS``
                when None.
            stream: Whether to request a streaming response.

        Returns:
            The OpenAI completion object, or a chunk iterator when
            ``stream`` is True.

        Raises:
            Exception: Re-raises any error from the underlying SDK call
                after logging it.
        """
        try:
            # BUGFIX: use explicit None checks instead of `or`, so valid
            # falsy values (temperature=0.0, max_tokens=0) are not silently
            # replaced by the configured defaults.
            params = {
                "model": model,
                "messages": messages,
                "stream": stream,
                "temperature": temperature if temperature is not None else settings.DEFAULT_TEMPERATURE,
                "max_tokens": max_tokens if max_tokens is not None else settings.DEFAULT_MAX_TOKENS
            }
            self.logger.info(f"Sending request to model: {model}, with: {len(messages)} messages")
            response = self.client.chat.completions.create(**params)
            if not stream:
                # A streamed response has nothing to log yet; completion is
                # logged by the streaming consumer instead.
                self.logger.info(f"Get response from model: {model}")
            return response
        except Exception as e:
            self.logger.error(f"API request failed: {str(e)}")
            raise

    def get_streaming_response(self,
                               model: str,
                               messages: List[Dict[str, str]],
                               temperature: Optional[float] = None,
                               max_tokens: Optional[int] = None) -> Generator[str, None, None]:
        """Yield response text chunks from a streaming completion.

        Args:
            model: Model name.
            messages: Chat history messages.
            temperature: Sampling temperature; provider default when None.
            max_tokens: Maximum tokens to generate; provider default when None.

        Yields:
            Content deltas as they arrive. NOTE: on failure this yields a
            single ``"Error: ..."`` string rather than raising, so stream
            consumers never see an exception mid-iteration.
        """
        try:
            response = self.create_chat_completion(
                model=model,
                messages=messages,
                stream=True,
                temperature=temperature,
                max_tokens=max_tokens
            )
            full_response = ""
            for chunk in response:
                # Deltas without content (e.g. role-only or final chunks)
                # are skipped.
                if chunk.choices[0].delta.content is not None:
                    content = chunk.choices[0].delta.content
                    full_response += content
                    yield content
            self.logger.info(f"Get streaming response from model: {model}, total length count: {len(full_response)}")
        except Exception as e:
            self.logger.error(f"Get streaming response failed: {str(e)}")
            yield f"Error: {str(e)}"

    def get_single_response(self,
                            model: str,
                            messages: List[Dict[str, str]],
                            temperature: Optional[float] = None,
                            max_tokens: Optional[int] = None) -> str:
        """Return the full response content of a non-streaming completion.

        Args:
            model: Model name.
            messages: Chat history messages.
            temperature: Sampling temperature; provider default when None.
            max_tokens: Maximum tokens to generate; provider default when None.

        Returns:
            The response content string. NOTE: on failure this returns an
            ``"Error: ..."`` string instead of raising, so callers should
            not assume the return value is always model output.
        """
        try:
            response = self.create_chat_completion(
                model=model,
                messages=messages,
                stream=False,
                temperature=temperature,
                max_tokens=max_tokens
            )
            content = response.choices[0].message.content
            self.logger.info(f"Get single response from model: {model},total length count: {len(content)}")
            return content
        except Exception as e:
            self.logger.error(f"Get single response failed: {str(e)}")
            return f"Error: {str(e)}"

    def test_connection(self) -> tuple[bool, str]:
        """Probe the API with a tiny request to verify connectivity.

        Returns:
            A ``(success, message)`` tuple; never raises.
        """
        try:
            # Minimal round-trip request against the configured test model.
            response = self.create_chat_completion(
                model=settings.TEST_MODEL_NAME,
                messages=[{"role": "user",
                           "content": "Hello"}],
                max_tokens=256
            )
            if response.choices[0].message.content:
                self.logger.info("API connection success!")
                return True, "API connection success!"
            else:
                # Reached the endpoint but got an empty completion.
                return False, "API connection failed!"
        except Exception as e:
            self.logger.error(f"connection failed: {str(e)}")
            return False, f"connection failed!: {str(e)}"