import logging import os import openai logging.basicConfig(level=logging.INFO) def get_openai_client(client_type: str): """ Refer to [this page](https://platform.openai.com/docs/models) for authentication using OpenAI. Refer to [this page](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/switching-endpoints) for authentication using Azure OpenAI. """ assert client_type in ["azure_openai", "openai"] if not os.environ.get('OPENAI_API_VERSION'): os.environ['OPENAI_API_VERSION'] = "2023-05-15" if client_type == "openai": # 构建客户端参数 client_params = { 'api_key': os.environ['OPENAI_API_KEY'] } # 如果 OPENAI_BASEURL 环境变量存在,则添加 base_url 参数 if os.environ.get('OPENAI_BASEURL'): client_params['base_url'] = os.environ['OPENAI_BASEURL'] client = openai.OpenAI(**client_params) elif client_type == "azure_openai": endpoint: str = os.environ['AZURE_ENDPOINT'] if not endpoint.startswith("https://"): endpoint = f"https://{endpoint}.openai.azure.com" os.environ['AZURE_ENDPOINT'] = endpoint client = openai.AzureOpenAI( api_key=os.environ['AZURE_OPENAI_KEY'], azure_endpoint=os.environ['AZURE_ENDPOINT'], # f"https://YOUR_END_POINT.openai.azure.com" azure_deployment=os.environ['AZURE_DEPLOYMENT'] ) else: raise NotImplementedError return client