Spaces:
Sleeping
Sleeping
File size: 5,363 Bytes
"""
OpenAI SDK module - interface with the OpenAI API.
"""
from typing import List, Dict, Any, Generator
from openai import OpenAI
from src.config.settings import settings
from src.backends.utils.logger import app_logger
class ChatClient:
    """Thin wrapper around the OpenAI SDK client for chat completions.

    Holds one ``OpenAI`` client plus a named logger, and exposes helpers for
    raw, streaming and single-shot chat completions as well as a simple
    connectivity check.
    """

    def __init__(self, api_key: str, base_url: str = None):
        """
        Initialise the client.

        Args:
            api_key: OpenAI/SiliconFlow API key.
            base_url: API base URL; when falsy, ``settings.SILICON_FLOW_BASE_URL``
                is used instead.
        """
        self.api_key = api_key
        self.base_url = base_url or settings.SILICON_FLOW_BASE_URL
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
        )
        self.logger = app_logger.get_logger("ChatClient")

    def create_chat_completion(self,
                               model: str,
                               messages: List[Dict[str, str]],
                               stream: bool = False,
                               temperature: float = None,
                               max_tokens: int = None) -> Any:
        """
        Create a chat completion.

        Args:
            model: model name.
            messages: chat history as a list of message dicts.
            stream: whether to request a streamed response.
            temperature: sampling temperature; ``None`` selects
                ``settings.DEFAULT_TEMPERATURE``. An explicit ``0.0`` is
                honoured and passed through unchanged.
            max_tokens: maximum number of tokens; ``None`` (or ``0``) selects
                ``settings.DEFAULT_MAX_TOKENS``.

        Returns:
            Whatever ``client.chat.completions.create`` returns (a completion
            object, or an iterable of chunks when ``stream`` is true).

        Raises:
            Exception: any error raised while building or sending the request
                is logged and re-raised unchanged.
        """
        try:
            # Only fall back to the default when the caller gave no value:
            # 0.0 is a legitimate temperature and must not be overridden.
            if temperature is None:
                temperature = settings.DEFAULT_TEMPERATURE
            params = {
                "model": model,
                "messages": messages,
                "stream": stream,
                "temperature": temperature,
                # A zero token budget is not usable, so falsy values keep
                # falling back to the configured default.
                "max_tokens": max_tokens or settings.DEFAULT_MAX_TOKENS,
            }
            self.logger.info(f"Sending request to model: {model}, with: {len(messages)} messages")
            response = self.client.chat.completions.create(**params)
            if not stream:
                self.logger.info(f"Get response from model: {model}")
            return response
        except Exception as e:
            self.logger.error(f"API request failed: {str(e)}")
            raise

    def get_streaming_response(self,
                               model: str,
                               messages: List[Dict[str, str]],
                               temperature: float = None,
                               max_tokens: int = None) -> Generator[str, None, None]:
        """
        Stream a chat completion as text fragments.

        Args:
            model: model name.
            messages: chat history as a list of message dicts.
            temperature: sampling temperature (see ``create_chat_completion``).
            max_tokens: maximum number of tokens (see ``create_chat_completion``).

        Yields:
            Each non-``None`` ``delta.content`` fragment in arrival order. If
            an exception occurs, it is logged and a final ``"Error: ..."``
            string is yielded instead of raising.
        """
        try:
            response = self.create_chat_completion(
                model=model,
                messages=messages,
                stream=True,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            # Only the total length is logged, so count characters instead of
            # accumulating the whole response in memory.
            total_length = 0
            for chunk in response:
                # Chunks may arrive with an empty ``choices`` list (for
                # example usage-only chunks); they carry no text to emit.
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if content is not None:
                    total_length += len(content)
                    yield content
            self.logger.info(f"Get streaming response from model: {model}, total length count: {total_length}")
        except Exception as e:
            self.logger.error(f"Get streaming response failed: {str(e)}")
            yield f"Error: {str(e)}"

    def get_single_response(self,
                            model: str,
                            messages: List[Dict[str, str]],
                            temperature: float = None,
                            max_tokens: int = None) -> str:
        """
        Get a complete (non-streamed) response.

        Args:
            model: model name.
            messages: chat history as a list of message dicts.
            temperature: sampling temperature (see ``create_chat_completion``).
            max_tokens: maximum number of tokens (see ``create_chat_completion``).

        Returns:
            The content of the first choice's message, or ``""`` when that
            content is ``None``. If an exception occurs, it is logged and an
            ``"Error: ..."`` string is returned instead of raising.
        """
        try:
            response = self.create_chat_completion(
                model=model,
                messages=messages,
                stream=False,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            # The message content can be None; normalise it so the declared
            # ``str`` return type holds and ``len`` below cannot fail.
            content = response.choices[0].message.content or ""
            self.logger.info(f"Get single response from model: {model},total length count: {len(content)}")
            return content
        except Exception as e:
            self.logger.error(f"Get single response failed: {str(e)}")
            return f"Error: {str(e)}"

    def test_connection(self) -> tuple[bool, str]:
        """
        Check that the API is reachable.

        Sends a short "Hello" message to ``settings.TEST_MODEL_NAME``.

        Returns:
            ``(success, message)``: ``True`` when the reply has non-empty
            content, otherwise ``False`` with a failure message (including the
            exception text when the request raised).
        """
        try:
            # Minimal probe request.
            response = self.create_chat_completion(
                model=settings.TEST_MODEL_NAME,
                messages=[{"role": "user",
                           "content": "Hello"}],
                max_tokens=256,
            )
            if response.choices[0].message.content:
                self.logger.info("API connection success!")
                return True, "API connection success!"
            else:
                return False, "API connection failed!"
        except Exception as e:
            self.logger.error(f"connection failed: {str(e)}")
            return False, f"connection failed!: {str(e)}"