Spaces:
Paused
Paused
| from typing import Any, Coroutine, Optional, Union | |
| import httpx | |
| from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI | |
| from openai.types.fine_tuning import FineTuningJob | |
| from litellm._logging import verbose_logger | |
class OpenAIFineTuningAPI:
    """
    OpenAI methods to support fine-tuning jobs.

    Each operation (create / cancel / list / retrieve) has a synchronous entry
    point that takes an ``_is_async`` flag. When the flag is ``True`` the entry
    point returns the *un-awaited* coroutine of its ``a``-prefixed counterpart;
    otherwise it performs the call and returns the result directly.
    """

    def __init__(self) -> None:
        super().__init__()

    def get_openai_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[
            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
        ] = None,
        _is_async: bool = False,
        api_version: Optional[str] = None,
        litellm_params: Optional[dict] = None,
    ) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]]:
        """
        Return the client to use for a fine-tuning request.

        Args:
            api_key: API key forwarded to the client constructor when set.
            api_base: Forwarded to the client constructor as ``base_url``
                when set.
            timeout: Request timeout forwarded to the client constructor.
            max_retries: Retry count forwarded to the client constructor
                when set.
            organization: Organization forwarded to the client constructor
                when set.
            client: An already-built client. When given it is returned
                unchanged and no other argument is consulted.
            _is_async: When ``True`` (and ``client`` is ``None``) an
                ``AsyncOpenAI`` client is built instead of ``OpenAI``.
            api_version: Accepted so callers can use one call shape for every
                provider; not forwarded, because the ``OpenAI`` /
                ``AsyncOpenAI`` constructors take no such argument.
            litellm_params: Accepted for the same reason as ``api_version``;
                not forwarded.

        Returns:
            The supplied ``client`` or a newly constructed one.
        """
        if client is not None:
            return client

        # Build the constructor kwargs explicitly. The previous
        # ``locals()``-based loop also forwarded ``api_version`` and
        # ``litellm_params``, which the OpenAI constructors reject with a
        # TypeError as soon as either is not None.
        candidate_kwargs = {
            "api_key": api_key,
            "base_url": api_base,
            "timeout": timeout,
            "max_retries": max_retries,
            "organization": organization,
        }
        # Unset values are dropped so the SDK's own defaults apply.
        client_kwargs = {
            key: value for key, value in candidate_kwargs.items() if value is not None
        }

        if _is_async is True:
            return AsyncOpenAI(**client_kwargs)
        return OpenAI(**client_kwargs)  # type: ignore

    def _resolve_client(
        self,
        _is_async: bool,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]],
    ) -> Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]:
        """
        Obtain a client via ``get_openai_client`` and validate it.

        Shared by all public entry points so they apply identical checks.

        Raises:
            ValueError: If no client could be obtained, or if ``_is_async`` is
                ``True`` and the client is not an ``AsyncOpenAI`` /
                ``AsyncAzureOpenAI`` instance.
        """
        openai_client: Optional[
            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
        ] = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            _is_async=_is_async,
            api_version=api_version,
        )
        if openai_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )
        # Only the async path is type-checked; a caller-supplied client is
        # used as-is on the synchronous path.
        if _is_async is True and not isinstance(
            openai_client, (AsyncOpenAI, AsyncAzureOpenAI)
        ):
            raise ValueError(
                "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
            )
        return openai_client

    async def acreate_fine_tuning_job(
        self,
        create_fine_tuning_job_data: dict,
        openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
    ) -> FineTuningJob:
        """Create a fine-tuning job with an async client and return it."""
        response = await openai_client.fine_tuning.jobs.create(
            **create_fine_tuning_job_data
        )
        return response

    def create_fine_tuning_job(
        self,
        _is_async: bool,
        create_fine_tuning_job_data: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[
            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
        ] = None,
    ) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
        """
        Create a fine-tuning job.

        Args:
            _is_async: When ``True`` return the coroutine from
                ``acreate_fine_tuning_job`` instead of the job.
            create_fine_tuning_job_data: Keyword arguments expanded into
                ``fine_tuning.jobs.create``.
            api_key, api_base, api_version, timeout, max_retries, organization,
            client: Passed to ``get_openai_client`` to obtain the client.

        Returns:
            The created job, or a coroutine resolving to it.

        Raises:
            ValueError: If no client is available, or ``_is_async`` is ``True``
                and the client is not an async client.
        """
        openai_client = self._resolve_client(
            _is_async=_is_async,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        if _is_async is True:
            return self.acreate_fine_tuning_job(  # type: ignore
                create_fine_tuning_job_data=create_fine_tuning_job_data,
                openai_client=openai_client,
            )
        verbose_logger.debug(
            "creating fine tuning job, args= %s", create_fine_tuning_job_data
        )
        response = openai_client.fine_tuning.jobs.create(**create_fine_tuning_job_data)
        return response

    async def acancel_fine_tuning_job(
        self,
        fine_tuning_job_id: str,
        openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
    ) -> FineTuningJob:
        """Cancel the given fine-tuning job with an async client."""
        response = await openai_client.fine_tuning.jobs.cancel(
            fine_tuning_job_id=fine_tuning_job_id
        )
        return response

    def cancel_fine_tuning_job(
        self,
        _is_async: bool,
        fine_tuning_job_id: str,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[
            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
        ] = None,
    ) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
        """
        Cancel a fine-tuning job.

        Args:
            _is_async: When ``True`` return the coroutine from
                ``acancel_fine_tuning_job`` instead of the job.
            fine_tuning_job_id: Identifier of the job to cancel.
            api_key, api_base, api_version, timeout, max_retries, organization,
            client: Passed to ``get_openai_client`` to obtain the client.

        Returns:
            The response of ``fine_tuning.jobs.cancel``, or a coroutine
            resolving to it.

        Raises:
            ValueError: If no client is available, or ``_is_async`` is ``True``
                and the client is not an async client.
        """
        openai_client = self._resolve_client(
            _is_async=_is_async,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        if _is_async is True:
            return self.acancel_fine_tuning_job(  # type: ignore
                fine_tuning_job_id=fine_tuning_job_id,
                openai_client=openai_client,
            )
        verbose_logger.debug("canceling fine tuning job, args= %s", fine_tuning_job_id)
        response = openai_client.fine_tuning.jobs.cancel(
            fine_tuning_job_id=fine_tuning_job_id
        )
        return response

    async def alist_fine_tuning_jobs(
        self,
        openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
        after: Optional[str] = None,
        limit: Optional[int] = None,
    ):
        """List fine-tuning jobs with an async client, passing ``after`` and ``limit`` through."""
        response = await openai_client.fine_tuning.jobs.list(after=after, limit=limit)  # type: ignore
        return response

    def list_fine_tuning_jobs(
        self,
        _is_async: bool,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[
            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
        ] = None,
        after: Optional[str] = None,
        limit: Optional[int] = None,
    ):
        """
        List fine-tuning jobs.

        Args:
            _is_async: When ``True`` return the coroutine from
                ``alist_fine_tuning_jobs`` instead of the listing.
            api_key, api_base, api_version, timeout, max_retries, organization,
            client: Passed to ``get_openai_client`` to obtain the client.
            after: Passed through to ``fine_tuning.jobs.list``.
            limit: Passed through to ``fine_tuning.jobs.list``.

        Returns:
            The response of ``fine_tuning.jobs.list``, or a coroutine
            resolving to it.

        Raises:
            ValueError: If no client is available, or ``_is_async`` is ``True``
                and the client is not an async client.
        """
        openai_client = self._resolve_client(
            _is_async=_is_async,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        if _is_async is True:
            return self.alist_fine_tuning_jobs(  # type: ignore
                after=after,
                limit=limit,
                openai_client=openai_client,
            )
        verbose_logger.debug("list fine tuning job, after= %s, limit= %s", after, limit)
        response = openai_client.fine_tuning.jobs.list(after=after, limit=limit)  # type: ignore
        return response

    async def aretrieve_fine_tuning_job(
        self,
        fine_tuning_job_id: str,
        openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
    ) -> FineTuningJob:
        """Retrieve the given fine-tuning job with an async client."""
        response = await openai_client.fine_tuning.jobs.retrieve(
            fine_tuning_job_id=fine_tuning_job_id
        )
        return response

    def retrieve_fine_tuning_job(
        self,
        _is_async: bool,
        fine_tuning_job_id: str,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[
            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
        ] = None,
    ) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
        """
        Retrieve a fine-tuning job.

        Args:
            _is_async: When ``True`` return the coroutine from
                ``aretrieve_fine_tuning_job`` instead of the job.
            fine_tuning_job_id: Identifier of the job to retrieve.
            api_key, api_base, api_version, timeout, max_retries, organization,
            client: Passed to ``get_openai_client`` to obtain the client.

        Returns:
            The response of ``fine_tuning.jobs.retrieve``, or a coroutine
            resolving to it.

        Raises:
            ValueError: If no client is available, or ``_is_async`` is ``True``
                and the client is not an async client.
        """
        openai_client = self._resolve_client(
            _is_async=_is_async,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        if _is_async is True:
            return self.aretrieve_fine_tuning_job(  # type: ignore
                fine_tuning_job_id=fine_tuning_job_id,
                openai_client=openai_client,
            )
        verbose_logger.debug("retrieving fine tuning job, id= %s", fine_tuning_job_id)
        response = openai_client.fine_tuning.jobs.retrieve(
            fine_tuning_job_id=fine_tuning_job_id
        )
        return response