Spaces:
Paused
Paused
| from typing import Any, Coroutine, Dict, Iterable, Literal, Optional, Union | |
| import httpx | |
| from openai import AsyncAzureOpenAI, AzureOpenAI | |
| from typing_extensions import overload | |
| from ...types.llms.openai import ( | |
| Assistant, | |
| AssistantEventHandler, | |
| AssistantStreamManager, | |
| AssistantToolParam, | |
| AsyncAssistantEventHandler, | |
| AsyncAssistantStreamManager, | |
| AsyncCursorPage, | |
| OpenAICreateThreadParamsMessage, | |
| OpenAIMessage, | |
| Run, | |
| SyncCursorPage, | |
| Thread, | |
| ) | |
| from .common_utils import BaseAzureLLM | |
class AzureAssistantsAPI(BaseAzureLLM):
    """Azure OpenAI Assistants API wrapper.

    Builds sync/async ``AzureOpenAI`` clients (via ``BaseAzureLLM``) and
    exposes assistant, thread, message, and run operations. Each public
    operation has a sync entry point that can dispatch to an async
    implementation through a boolean flag (e.g. ``aget_assistants``).
    """

    def __init__(self) -> None:
        super().__init__()
| def get_azure_client( | |
| self, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client: Optional[AzureOpenAI] = None, | |
| litellm_params: Optional[dict] = None, | |
| ) -> AzureOpenAI: | |
| if client is None: | |
| azure_client_params = self.initialize_azure_sdk_client( | |
| litellm_params=litellm_params or {}, | |
| api_key=api_key, | |
| api_base=api_base, | |
| model_name="", | |
| api_version=api_version, | |
| is_async=False, | |
| ) | |
| azure_openai_client = AzureOpenAI(**azure_client_params) # type: ignore | |
| else: | |
| azure_openai_client = client | |
| return azure_openai_client | |
| def async_get_azure_client( | |
| self, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client: Optional[AsyncAzureOpenAI] = None, | |
| litellm_params: Optional[dict] = None, | |
| ) -> AsyncAzureOpenAI: | |
| if client is None: | |
| azure_client_params = self.initialize_azure_sdk_client( | |
| litellm_params=litellm_params or {}, | |
| api_key=api_key, | |
| api_base=api_base, | |
| model_name="", | |
| api_version=api_version, | |
| is_async=True, | |
| ) | |
| azure_openai_client = AsyncAzureOpenAI(**azure_client_params) | |
| # azure_openai_client = AsyncAzureOpenAI(**data) # type: ignore | |
| else: | |
| azure_openai_client = client | |
| return azure_openai_client | |
| ### ASSISTANTS ### | |
| async def async_get_assistants( | |
| self, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client: Optional[AsyncAzureOpenAI], | |
| litellm_params: Optional[dict] = None, | |
| ) -> AsyncCursorPage[Assistant]: | |
| azure_openai_client = self.async_get_azure_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| azure_ad_token=azure_ad_token, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| litellm_params=litellm_params, | |
| ) | |
| response = await azure_openai_client.beta.assistants.list() | |
| return response | |
| # fmt: off | |
| def get_assistants( | |
| self, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client: Optional[AsyncAzureOpenAI], | |
| aget_assistants: Literal[True], | |
| ) -> Coroutine[None, None, AsyncCursorPage[Assistant]]: | |
| ... | |
| def get_assistants( | |
| self, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client: Optional[AzureOpenAI], | |
| aget_assistants: Optional[Literal[False]], | |
| ) -> SyncCursorPage[Assistant]: | |
| ... | |
| # fmt: on | |
| def get_assistants( | |
| self, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client=None, | |
| aget_assistants=None, | |
| litellm_params: Optional[dict] = None, | |
| ): | |
| if aget_assistants is not None and aget_assistants is True: | |
| return self.async_get_assistants( | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| azure_ad_token=azure_ad_token, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| litellm_params=litellm_params, | |
| ) | |
| azure_openai_client = self.get_azure_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| azure_ad_token=azure_ad_token, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| api_version=api_version, | |
| litellm_params=litellm_params, | |
| ) | |
| response = azure_openai_client.beta.assistants.list() | |
| return response | |
| ### MESSAGES ### | |
| async def a_add_message( | |
| self, | |
| thread_id: str, | |
| message_data: dict, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client: Optional[AsyncAzureOpenAI] = None, | |
| litellm_params: Optional[dict] = None, | |
| ) -> OpenAIMessage: | |
| openai_client = self.async_get_azure_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| azure_ad_token=azure_ad_token, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| litellm_params=litellm_params, | |
| ) | |
| thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore | |
| thread_id, **message_data # type: ignore | |
| ) | |
| response_obj: Optional[OpenAIMessage] = None | |
| if getattr(thread_message, "status", None) is None: | |
| thread_message.status = "completed" | |
| response_obj = OpenAIMessage(**thread_message.dict()) | |
| else: | |
| response_obj = OpenAIMessage(**thread_message.dict()) | |
| return response_obj | |
| # fmt: off | |
| def add_message( | |
| self, | |
| thread_id: str, | |
| message_data: dict, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client: Optional[AsyncAzureOpenAI], | |
| a_add_message: Literal[True], | |
| litellm_params: Optional[dict] = None, | |
| ) -> Coroutine[None, None, OpenAIMessage]: | |
| ... | |
| def add_message( | |
| self, | |
| thread_id: str, | |
| message_data: dict, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client: Optional[AzureOpenAI], | |
| a_add_message: Optional[Literal[False]], | |
| litellm_params: Optional[dict] = None, | |
| ) -> OpenAIMessage: | |
| ... | |
| # fmt: on | |
| def add_message( | |
| self, | |
| thread_id: str, | |
| message_data: dict, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client=None, | |
| a_add_message: Optional[bool] = None, | |
| litellm_params: Optional[dict] = None, | |
| ): | |
| if a_add_message is not None and a_add_message is True: | |
| return self.a_add_message( | |
| thread_id=thread_id, | |
| message_data=message_data, | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| azure_ad_token=azure_ad_token, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| litellm_params=litellm_params, | |
| ) | |
| openai_client = self.get_azure_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| azure_ad_token=azure_ad_token, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| litellm_params=litellm_params, | |
| ) | |
| thread_message: OpenAIMessage = openai_client.beta.threads.messages.create( # type: ignore | |
| thread_id, **message_data # type: ignore | |
| ) | |
| response_obj: Optional[OpenAIMessage] = None | |
| if getattr(thread_message, "status", None) is None: | |
| thread_message.status = "completed" | |
| response_obj = OpenAIMessage(**thread_message.dict()) | |
| else: | |
| response_obj = OpenAIMessage(**thread_message.dict()) | |
| return response_obj | |
| async def async_get_messages( | |
| self, | |
| thread_id: str, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client: Optional[AsyncAzureOpenAI] = None, | |
| litellm_params: Optional[dict] = None, | |
| ) -> AsyncCursorPage[OpenAIMessage]: | |
| openai_client = self.async_get_azure_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| azure_ad_token=azure_ad_token, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| litellm_params=litellm_params, | |
| ) | |
| response = await openai_client.beta.threads.messages.list(thread_id=thread_id) | |
| return response | |
| # fmt: off | |
| def get_messages( | |
| self, | |
| thread_id: str, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client: Optional[AsyncAzureOpenAI], | |
| aget_messages: Literal[True], | |
| litellm_params: Optional[dict] = None, | |
| ) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]: | |
| ... | |
| def get_messages( | |
| self, | |
| thread_id: str, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client: Optional[AzureOpenAI], | |
| aget_messages: Optional[Literal[False]], | |
| litellm_params: Optional[dict] = None, | |
| ) -> SyncCursorPage[OpenAIMessage]: | |
| ... | |
| # fmt: on | |
| def get_messages( | |
| self, | |
| thread_id: str, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client=None, | |
| aget_messages=None, | |
| litellm_params: Optional[dict] = None, | |
| ): | |
| if aget_messages is not None and aget_messages is True: | |
| return self.async_get_messages( | |
| thread_id=thread_id, | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| azure_ad_token=azure_ad_token, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| litellm_params=litellm_params, | |
| ) | |
| openai_client = self.get_azure_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| azure_ad_token=azure_ad_token, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| litellm_params=litellm_params, | |
| ) | |
| response = openai_client.beta.threads.messages.list(thread_id=thread_id) | |
| return response | |
| ### THREADS ### | |
| async def async_create_thread( | |
| self, | |
| metadata: Optional[dict], | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client: Optional[AsyncAzureOpenAI], | |
| messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], | |
| litellm_params: Optional[dict] = None, | |
| ) -> Thread: | |
| openai_client = self.async_get_azure_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| azure_ad_token=azure_ad_token, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| litellm_params=litellm_params, | |
| ) | |
| data = {} | |
| if messages is not None: | |
| data["messages"] = messages # type: ignore | |
| if metadata is not None: | |
| data["metadata"] = metadata # type: ignore | |
| message_thread = await openai_client.beta.threads.create(**data) # type: ignore | |
| return Thread(**message_thread.dict()) | |
| # fmt: off | |
| def create_thread( | |
| self, | |
| metadata: Optional[dict], | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], | |
| client: Optional[AsyncAzureOpenAI], | |
| acreate_thread: Literal[True], | |
| litellm_params: Optional[dict] = None, | |
| ) -> Coroutine[None, None, Thread]: | |
| ... | |
| def create_thread( | |
| self, | |
| metadata: Optional[dict], | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], | |
| client: Optional[AzureOpenAI], | |
| acreate_thread: Optional[Literal[False]], | |
| litellm_params: Optional[dict] = None, | |
| ) -> Thread: | |
| ... | |
| # fmt: on | |
| def create_thread( | |
| self, | |
| metadata: Optional[dict], | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], | |
| client=None, | |
| acreate_thread=None, | |
| litellm_params: Optional[dict] = None, | |
| ): | |
| """ | |
| Here's an example: | |
| ``` | |
| from litellm.llms.openai.openai import OpenAIAssistantsAPI, MessageData | |
| # create thread | |
| message: MessageData = {"role": "user", "content": "Hey, how's it going?"} | |
| openai_api.create_thread(messages=[message]) | |
| ``` | |
| """ | |
| if acreate_thread is not None and acreate_thread is True: | |
| return self.async_create_thread( | |
| metadata=metadata, | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| azure_ad_token=azure_ad_token, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| messages=messages, | |
| litellm_params=litellm_params, | |
| ) | |
| azure_openai_client = self.get_azure_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| azure_ad_token=azure_ad_token, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| litellm_params=litellm_params, | |
| ) | |
| data = {} | |
| if messages is not None: | |
| data["messages"] = messages # type: ignore | |
| if metadata is not None: | |
| data["metadata"] = metadata # type: ignore | |
| message_thread = azure_openai_client.beta.threads.create(**data) # type: ignore | |
| return Thread(**message_thread.dict()) | |
| async def async_get_thread( | |
| self, | |
| thread_id: str, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client: Optional[AsyncAzureOpenAI], | |
| litellm_params: Optional[dict] = None, | |
| ) -> Thread: | |
| openai_client = self.async_get_azure_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| azure_ad_token=azure_ad_token, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| litellm_params=litellm_params, | |
| ) | |
| response = await openai_client.beta.threads.retrieve(thread_id=thread_id) | |
| return Thread(**response.dict()) | |
| # fmt: off | |
| def get_thread( | |
| self, | |
| thread_id: str, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client: Optional[AsyncAzureOpenAI], | |
| aget_thread: Literal[True], | |
| litellm_params: Optional[dict] = None, | |
| ) -> Coroutine[None, None, Thread]: | |
| ... | |
| def get_thread( | |
| self, | |
| thread_id: str, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client: Optional[AzureOpenAI], | |
| aget_thread: Optional[Literal[False]], | |
| litellm_params: Optional[dict] = None, | |
| ) -> Thread: | |
| ... | |
| # fmt: on | |
| def get_thread( | |
| self, | |
| thread_id: str, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client=None, | |
| aget_thread=None, | |
| litellm_params: Optional[dict] = None, | |
| ): | |
| if aget_thread is not None and aget_thread is True: | |
| return self.async_get_thread( | |
| thread_id=thread_id, | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| azure_ad_token=azure_ad_token, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| litellm_params=litellm_params, | |
| ) | |
| openai_client = self.get_azure_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| azure_ad_token=azure_ad_token, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| litellm_params=litellm_params, | |
| ) | |
| response = openai_client.beta.threads.retrieve(thread_id=thread_id) | |
| return Thread(**response.dict()) | |
| # def delete_thread(self): | |
| # pass | |
| ### RUNS ### | |
| async def arun_thread( | |
| self, | |
| thread_id: str, | |
| assistant_id: str, | |
| additional_instructions: Optional[str], | |
| instructions: Optional[str], | |
| metadata: Optional[Dict], | |
| model: Optional[str], | |
| stream: Optional[bool], | |
| tools: Optional[Iterable[AssistantToolParam]], | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client: Optional[AsyncAzureOpenAI], | |
| litellm_params: Optional[dict] = None, | |
| ) -> Run: | |
| openai_client = self.async_get_azure_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| api_version=api_version, | |
| azure_ad_token=azure_ad_token, | |
| client=client, | |
| litellm_params=litellm_params, | |
| ) | |
| response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore | |
| thread_id=thread_id, | |
| assistant_id=assistant_id, | |
| additional_instructions=additional_instructions, | |
| instructions=instructions, | |
| metadata=metadata, # type: ignore | |
| model=model, | |
| tools=tools, | |
| ) | |
| return response | |
| def async_run_thread_stream( | |
| self, | |
| client: AsyncAzureOpenAI, | |
| thread_id: str, | |
| assistant_id: str, | |
| additional_instructions: Optional[str], | |
| instructions: Optional[str], | |
| metadata: Optional[Dict], | |
| model: Optional[str], | |
| tools: Optional[Iterable[AssistantToolParam]], | |
| event_handler: Optional[AssistantEventHandler], | |
| litellm_params: Optional[dict] = None, | |
| ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]: | |
| data: Dict[str, Any] = { | |
| "thread_id": thread_id, | |
| "assistant_id": assistant_id, | |
| "additional_instructions": additional_instructions, | |
| "instructions": instructions, | |
| "metadata": metadata, | |
| "model": model, | |
| "tools": tools, | |
| } | |
| if event_handler is not None: | |
| data["event_handler"] = event_handler | |
| return client.beta.threads.runs.stream(**data) # type: ignore | |
| def run_thread_stream( | |
| self, | |
| client: AzureOpenAI, | |
| thread_id: str, | |
| assistant_id: str, | |
| additional_instructions: Optional[str], | |
| instructions: Optional[str], | |
| metadata: Optional[Dict], | |
| model: Optional[str], | |
| tools: Optional[Iterable[AssistantToolParam]], | |
| event_handler: Optional[AssistantEventHandler], | |
| litellm_params: Optional[dict] = None, | |
| ) -> AssistantStreamManager[AssistantEventHandler]: | |
| data: Dict[str, Any] = { | |
| "thread_id": thread_id, | |
| "assistant_id": assistant_id, | |
| "additional_instructions": additional_instructions, | |
| "instructions": instructions, | |
| "metadata": metadata, | |
| "model": model, | |
| "tools": tools, | |
| } | |
| if event_handler is not None: | |
| data["event_handler"] = event_handler | |
| return client.beta.threads.runs.stream(**data) # type: ignore | |
| # fmt: off | |
| def run_thread( | |
| self, | |
| thread_id: str, | |
| assistant_id: str, | |
| additional_instructions: Optional[str], | |
| instructions: Optional[str], | |
| metadata: Optional[Dict], | |
| model: Optional[str], | |
| stream: Optional[bool], | |
| tools: Optional[Iterable[AssistantToolParam]], | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client: Optional[AsyncAzureOpenAI], | |
| arun_thread: Literal[True], | |
| ) -> Coroutine[None, None, Run]: | |
| ... | |
| def run_thread( | |
| self, | |
| thread_id: str, | |
| assistant_id: str, | |
| additional_instructions: Optional[str], | |
| instructions: Optional[str], | |
| metadata: Optional[Dict], | |
| model: Optional[str], | |
| stream: Optional[bool], | |
| tools: Optional[Iterable[AssistantToolParam]], | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client: Optional[AzureOpenAI], | |
| arun_thread: Optional[Literal[False]], | |
| ) -> Run: | |
| ... | |
| # fmt: on | |
    def run_thread(
        self,
        thread_id: str,
        assistant_id: str,
        additional_instructions: Optional[str],
        instructions: Optional[str],
        metadata: Optional[Dict],
        model: Optional[str],
        stream: Optional[bool],
        tools: Optional[Iterable[AssistantToolParam]],
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        azure_ad_token: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client=None,
        arun_thread=None,
        event_handler: Optional[AssistantEventHandler] = None,
        litellm_params: Optional[dict] = None,
    ):
        """Run an assistant on a thread, dispatching on two flags.

        - ``arun_thread`` True + ``stream`` True: builds an async client and
          returns an async stream manager from ``async_run_thread_stream``.
        - ``arun_thread`` True, no stream: returns the ``arun_thread`` coroutine.
        - sync + ``stream`` True: returns a sync stream manager.
        - sync, no stream: blocks in ``create_and_poll`` and returns the
          finished run.
        """
        if arun_thread is not None and arun_thread is True:
            if stream is not None and stream is True:
                # The client is created eagerly here (sync context) so the
                # returned stream manager is ready to be entered by the caller.
                azure_client = self.async_get_azure_client(
                    api_key=api_key,
                    api_base=api_base,
                    api_version=api_version,
                    azure_ad_token=azure_ad_token,
                    timeout=timeout,
                    max_retries=max_retries,
                    client=client,
                    litellm_params=litellm_params,
                )
                return self.async_run_thread_stream(
                    client=azure_client,
                    thread_id=thread_id,
                    assistant_id=assistant_id,
                    additional_instructions=additional_instructions,
                    instructions=instructions,
                    metadata=metadata,
                    model=model,
                    tools=tools,
                    event_handler=event_handler,
                    litellm_params=litellm_params,
                )
            # Non-streaming async path: hand back the coroutine for the caller
            # to await.
            return self.arun_thread(
                thread_id=thread_id,
                assistant_id=assistant_id,
                additional_instructions=additional_instructions,
                instructions=instructions,
                metadata=metadata,  # type: ignore
                model=model,
                stream=stream,
                tools=tools,
                api_key=api_key,
                api_base=api_base,
                api_version=api_version,
                azure_ad_token=azure_ad_token,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
                litellm_params=litellm_params,
            )
        openai_client = self.get_azure_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )
        if stream is not None and stream is True:
            return self.run_thread_stream(
                client=openai_client,
                thread_id=thread_id,
                assistant_id=assistant_id,
                additional_instructions=additional_instructions,
                instructions=instructions,
                metadata=metadata,
                model=model,
                tools=tools,
                event_handler=event_handler,
                litellm_params=litellm_params,
            )
        # Sync non-streaming: create the run and poll until terminal.
        response = openai_client.beta.threads.runs.create_and_poll(  # type: ignore
            thread_id=thread_id,
            assistant_id=assistant_id,
            additional_instructions=additional_instructions,
            instructions=instructions,
            metadata=metadata,  # type: ignore
            model=model,
            tools=tools,
        )
        return response
| # Create Assistant | |
| async def async_create_assistants( | |
| self, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client: Optional[AsyncAzureOpenAI], | |
| create_assistant_data: dict, | |
| litellm_params: Optional[dict] = None, | |
| ) -> Assistant: | |
| azure_openai_client = self.async_get_azure_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| azure_ad_token=azure_ad_token, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| litellm_params=litellm_params, | |
| ) | |
| response = await azure_openai_client.beta.assistants.create( | |
| **create_assistant_data | |
| ) | |
| return response | |
| def create_assistants( | |
| self, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| create_assistant_data: dict, | |
| client=None, | |
| async_create_assistants=None, | |
| litellm_params: Optional[dict] = None, | |
| ): | |
| if async_create_assistants is not None and async_create_assistants is True: | |
| return self.async_create_assistants( | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| azure_ad_token=azure_ad_token, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| create_assistant_data=create_assistant_data, | |
| litellm_params=litellm_params, | |
| ) | |
| azure_openai_client = self.get_azure_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| azure_ad_token=azure_ad_token, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| litellm_params=litellm_params, | |
| ) | |
| response = azure_openai_client.beta.assistants.create(**create_assistant_data) | |
| return response | |
| # Delete Assistant | |
| async def async_delete_assistant( | |
| self, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| client: Optional[AsyncAzureOpenAI], | |
| assistant_id: str, | |
| litellm_params: Optional[dict] = None, | |
| ): | |
| azure_openai_client = self.async_get_azure_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| azure_ad_token=azure_ad_token, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| litellm_params=litellm_params, | |
| ) | |
| response = await azure_openai_client.beta.assistants.delete( | |
| assistant_id=assistant_id | |
| ) | |
| return response | |
| def delete_assistant( | |
| self, | |
| api_key: Optional[str], | |
| api_base: Optional[str], | |
| api_version: Optional[str], | |
| azure_ad_token: Optional[str], | |
| timeout: Union[float, httpx.Timeout], | |
| max_retries: Optional[int], | |
| assistant_id: str, | |
| async_delete_assistants: Optional[bool] = None, | |
| client=None, | |
| litellm_params: Optional[dict] = None, | |
| ): | |
| if async_delete_assistants is not None and async_delete_assistants is True: | |
| return self.async_delete_assistant( | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| azure_ad_token=azure_ad_token, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| assistant_id=assistant_id, | |
| litellm_params=litellm_params, | |
| ) | |
| azure_openai_client = self.get_azure_client( | |
| api_key=api_key, | |
| api_base=api_base, | |
| api_version=api_version, | |
| azure_ad_token=azure_ad_token, | |
| timeout=timeout, | |
| max_retries=max_retries, | |
| client=client, | |
| litellm_params=litellm_params, | |
| ) | |
| response = azure_openai_client.beta.assistants.delete(assistant_id=assistant_id) | |
| return response | |