Spaces:
Sleeping
Sleeping
| """ | |
| OpenAI sdk module - Interface with OpenAI API | |
| """ | |
| from typing import List, Dict, Any, Generator | |
| from openai import OpenAI | |
| from src.config.settings import settings | |
| from src.backends.utils.logger import app_logger | |
class ChatClient:
    """Chat client for an OpenAI-compatible API (OpenAI, SiliconFlow, ...).

    Wraps ``openai.OpenAI`` and fills unset request options from ``settings``.
    Failure handling differs per helper:

    * ``create_chat_completion`` logs the error and re-raises it.
    * ``get_streaming_response`` / ``get_single_response`` log the error and
      return (or yield) an ``"Error: ..."`` string instead of raising.
    * ``test_connection`` returns ``(False, message)``.
    """

    def __init__(self, api_key: str, base_url: str = None):
        """
        Initialize the client.

        Args:
            api_key: OpenAI / SiliconFlow API key.
            base_url: API base URL. When omitted (or empty),
                ``settings.SILICON_FLOW_BASE_URL`` is used.
        """
        self.api_key = api_key
        self.base_url = base_url or settings.SILICON_FLOW_BASE_URL
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
        )
        self.logger = app_logger.get_logger("ChatClient")

    def create_chat_completion(
        self,
        model: str,
        messages: List[Dict[str, str]],
        stream: bool = False,
        temperature: float = None,
        max_tokens: int = None,
    ) -> Any:
        """
        Send a chat-completion request.

        Args:
            model: Model name.
            messages: Chat history as a list of ``{"role": ..., "content": ...}`` dicts.
            stream: Whether to request a streamed response.
            temperature: Sampling temperature. ``None`` means
                ``settings.DEFAULT_TEMPERATURE``; an explicit ``0`` is honored.
            max_tokens: Maximum number of tokens to generate. ``None`` means
                ``settings.DEFAULT_MAX_TOKENS``.

        Returns:
            Whatever ``client.chat.completions.create`` returns: a completion
            object, or an iterable of chunks when ``stream`` is True.

        Raises:
            Exception: Any error from building or sending the request is logged
                and re-raised unchanged.
        """
        try:
            params = {
                "model": model,
                "messages": messages,
                "stream": stream,
                # Compare with None rather than using ``or``: a temperature of 0
                # is a valid setting and must not fall back to the default.
                "temperature": (
                    temperature if temperature is not None
                    else settings.DEFAULT_TEMPERATURE
                ),
                "max_tokens": (
                    max_tokens if max_tokens is not None
                    else settings.DEFAULT_MAX_TOKENS
                ),
            }
            self.logger.info(f"Sending request to model: {model}, with: {len(messages)} messages")
            response = self.client.chat.completions.create(**params)
            if not stream:
                self.logger.info(f"Get response from model: {model}")
            return response
        except Exception as e:
            self.logger.error(f"API request failed: {str(e)}")
            raise

    def get_streaming_response(
        self,
        model: str,
        messages: List[Dict[str, str]],
        temperature: float = None,
        max_tokens: int = None,
    ) -> Generator[str, None, None]:
        """
        Stream the text of a chat completion piece by piece.

        Args:
            model: Model name.
            messages: Chat history as a list of ``{"role": ..., "content": ...}`` dicts.
            temperature: Sampling temperature (``None`` -> settings default).
            max_tokens: Maximum number of tokens (``None`` -> settings default).

        Yields:
            Each text delta whose content is not None. If the request or the
            stream fails, a final ``"Error: <message>"`` string is yielded
            instead of raising.
        """
        try:
            response = self.create_chat_completion(
                model=model,
                messages=messages,
                stream=True,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            total_length = 0
            try:
                for chunk in response:
                    # Some chunks may carry no choices (e.g. a trailing
                    # usage-only chunk); indexing [0] on them would raise.
                    if not chunk.choices:
                        continue
                    content = chunk.choices[0].delta.content
                    if content is not None:
                        total_length += len(content)
                        yield content
            finally:
                # Release the underlying stream even if the consumer stops early.
                self._close_stream(response)
            self.logger.info(f"Get streaming response from model: {model}, total length count: {total_length}")
        except Exception as e:
            # NOTE: errors are yielded in-band as text (kept for backward
            # compatibility); callers cannot tell them apart from model output by type.
            self.logger.error(f"Get streaming response failed: {str(e)}")
            yield f"Error: {str(e)}"

    def _close_stream(self, response: Any) -> None:
        """Close ``response`` if it has a callable ``close``; cleanup errors are logged, not raised."""
        close = getattr(response, "close", None)
        if not callable(close):
            return
        try:
            close()
        except Exception as e:  # best-effort cleanup must not mask the real outcome
            self.logger.error(f"Closing stream failed: {str(e)}")

    def get_single_response(
        self,
        model: str,
        messages: List[Dict[str, str]],
        temperature: float = None,
        max_tokens: int = None,
    ) -> str:
        """
        Request a complete (non-streamed) reply and return its text.

        Args:
            model: Model name.
            messages: Chat history as a list of ``{"role": ..., "content": ...}`` dicts.
            temperature: Sampling temperature (``None`` -> settings default).
            max_tokens: Maximum number of tokens (``None`` -> settings default).

        Returns:
            The reply text (``""`` when the reply has no text content), or an
            ``"Error: <message>"`` string if the request failed.
        """
        try:
            response = self.create_chat_completion(
                model=model,
                messages=messages,
                stream=False,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            # ``message.content`` may be None; treat it as an empty reply
            # instead of failing on ``len(None)``.
            content = response.choices[0].message.content or ""
            self.logger.info(f"Get single response from model: {model},total length count: {len(content)}")
            return content
        except Exception as e:
            self.logger.error(f"Get single response failed: {str(e)}")
            return f"Error: {str(e)}"

    def test_connection(self) -> tuple[bool, str]:
        """
        Check API connectivity with a small request to ``settings.TEST_MODEL_NAME``.

        Returns:
            ``(True, message)`` when the model returned non-empty text,
            otherwise ``(False, message)`` describing the failure.
        """
        try:
            response = self.create_chat_completion(
                model=settings.TEST_MODEL_NAME,
                messages=[{"role": "user", "content": "Hello"}],
                max_tokens=256,
            )
            if response.choices and response.choices[0].message.content:
                self.logger.info("API connection success!")
                return True, "API connection success!"
            self.logger.error("API connection failed: empty response")
            return False, "API connection failed!"
        except Exception as e:
            self.logger.error(f"connection failed: {str(e)}")
            return False, f"connection failed!: {str(e)}"