| """Integration with OpenAI's API.""" |
| import copy |
| import functools |
| from dataclasses import asdict, dataclass, field, replace |
| from typing import Callable, Dict, List, Optional, Tuple, Union |
|
|
| import numpy as np |
|
|
| from outlines.base import vectorize |
| from outlines.caching import cache |
|
|
| __all__ = ["OpenAI", "openai", "azure_openai"] |
|
|
|
|
@dataclass(frozen=True)
class OpenAIConfig:
    """Represents the parameters of the OpenAI API.

    The information was last fetched on 2023/11/20. We document below the
    properties that are specific to the OpenAI API. Not all these properties are
    supported by Outlines.

    Properties
    ----------
    model
        The name of the model. Available models can be found on OpenAI's website.
    frequency_penalty
        Number between -2.0 and 2.0. Positive values penalize new tokens based on
        their existing frequency in the text.
    logit_bias
        Modifies the likelihood of specified tokens to appear in the completion.
        Number between -100 (forbid) and +100 (only allow).
    max_tokens
        The maximum number of tokens to generate for each completion.
    n
        The number of completions to return for each prompt.
    presence_penalty
        Similar to frequency penalty.
    response_format
        Specifies the format the model must output. `{"type": "json_object"}`
        enables JSON mode.
    seed
        Two completions with the same `seed` value should return the same
        completion. This is however not guaranteed.
    stop
        Up to 4 words where the API will stop the completion.
    temperature
        Number between 0 and 2. Higher values make the output more random, while
        lower values make it more deterministic.
    top_p
        Number between 0 and 1. Parameter for nucleus sampling.
    user
        A unique identifier for the end-user.

    """

    model: str = ""
    frequency_penalty: float = 0
    logit_bias: Dict[int, int] = field(default_factory=dict)
    max_tokens: Optional[int] = None
    n: int = 1
    presence_penalty: float = 0
    response_format: Optional[Dict[str, str]] = None
    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    temperature: float = 1.0
    # `top_p` is a probability mass; the previous `int` annotation was
    # misleading (the default value of 1 is still valid as a float).
    top_p: float = 1
    user: str = field(default_factory=str)
|
|
|
|
class OpenAI:
    """An object that represents the OpenAI API."""

    def __init__(
        self,
        client,
        config,
        system_prompt: Optional[str] = None,
    ):
        """Create an `OpenAI` instance.

        This class supports the standard OpenAI API, the Azure OpenAI API as
        well as compatible APIs that rely on the OpenAI client.

        Parameters
        ----------
        client
            An instance of the API's async client.
        config
            An instance of `OpenAIConfig`. Can be useful to specify some
            parameters that cannot be set by calling this class' methods.
        system_prompt
            Default content of the system message that precedes the user's
            prompt. It can be overridden on a per-call basis.

        """
        self.client = client
        self.config = config
        # Fix: the original accepted `system_prompt` but never stored it, so
        # the argument was silently ignored. It now serves as the default for
        # calls that do not pass their own system prompt.
        self.system_prompt = system_prompt

        # Cumulative token usage across all calls made through this instance.
        self.prompt_tokens = 0
        self.completion_tokens = 0

        # Hook applied to the raw completion(s) before returning them;
        # identity by default.
        self.format_sequence = lambda seq: seq

    def __call__(
        self,
        prompt: Union[str, List[str]],
        max_tokens: Optional[int] = None,
        stop_at: Optional[Union[List[str], str]] = None,
        *,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        samples: Optional[int] = None,
    ) -> np.ndarray:
        """Call the OpenAI API to generate text.

        Parameters
        ----------
        prompt
            A string or list of strings that will be used to prompt the model
        max_tokens
            The maximum number of tokens to generate
        stop_at
            A string or array of strings (up to 4) which, when generated,
            cause the API to stop the completion.
        system_prompt
            The content of the system message that precedes the user's prompt.
            Defaults to the one given at construction time, if any.
        temperature
            The value of the temperature used to sample tokens
        samples
            The number of completions to generate for each prompt

        """
        # Fall back to the instance-wide configuration (and default system
        # prompt) for every parameter that is not explicitly provided.
        if max_tokens is None:
            max_tokens = self.config.max_tokens
        if stop_at is None:
            stop_at = self.config.stop
        if temperature is None:
            temperature = self.config.temperature
        if samples is None:
            samples = self.config.n
        if system_prompt is None:
            system_prompt = self.system_prompt

        config = replace(
            self.config,
            max_tokens=max_tokens,
            temperature=temperature,
            n=samples,
            stop=stop_at,
        )

        response, prompt_tokens, completion_tokens = generate_chat(
            prompt, system_prompt, self.client, config
        )
        self.prompt_tokens += prompt_tokens
        self.completion_tokens += completion_tokens

        return self.format_sequence(response)

    def stream(self, *args, **kwargs):
        """Streaming is not implemented for this integration."""
        raise NotImplementedError(
            "Streaming is currently not supported for the OpenAI API"
        )

    def new_with_replacements(self, **kwargs):
        """Return a shallow copy of the model whose config fields named in
        `kwargs` have been replaced."""
        new_instance = copy.copy(self)
        new_instance.config = replace(new_instance.config, **kwargs)
        return new_instance

    def __str__(self):
        return self.__class__.__name__ + " API"

    def __repr__(self):
        return str(self.config)
|
|
|
|
# The signature vectorizes over four scalar inputs and produces one vector
# output (the `n` completions) plus two scalar usage counters per input.
@functools.partial(vectorize, signature="(),(),(),()->(s),(),()")
async def generate_chat(
    prompt: str,
    system_prompt: Union[str, None],
    client,
    config: OpenAIConfig,
) -> Tuple[np.ndarray, int, int]:
    """Call OpenAI's Chat Completion API.

    Parameters
    ----------
    prompt
        The prompt we use to start the generation. Passed to the model
        with the "user" role.
    system_prompt
        The system prompt, passed to the model with the "system" role
        before the prompt.
    client
        The API client
    config
        An `OpenAIConfig` instance.

    Returns
    -------
    A tuple that contains the model's response(s) and usage statistics.

    """

    @error_handler
    @cache()
    async def call_api(prompt, system_prompt, config):
        # `system_message` and `user_message` are closed over from the
        # enclosing scope; they are assigned below, before `call_api` is
        # awaited. `prompt`/`system_prompt`/`config` are taken as explicit
        # parameters (presumably so the cache key is derived from them)
        # while the unhashable `client` stays out of the argument list.
        responses = await client.chat.completions.create(
            messages=system_message + user_message,
            **asdict(config),
        )
        return responses.model_dump()

    # An empty system prompt produces no system message at all, rather than
    # a message with empty content.
    system_message = (
        [{"role": "system", "content": system_prompt}] if system_prompt else []
    )
    user_message = [{"role": "user", "content": prompt}]

    responses = await call_api(prompt, system_prompt, config)

    # One completion per requested choice (`config.n` of them).
    results = np.array(
        [responses["choices"][i]["message"]["content"] for i in range(config.n)]
    )
    usage = responses["usage"]

    return results, usage["prompt_tokens"], usage["completion_tokens"]
|
|
|
|
def error_handler(api_call_fn: Callable) -> Callable:
    """Handle OpenAI API errors and missing API key.

    Wraps an *async* API call. The wrapper awaits the call inside the
    `try` block — the original returned the un-awaited coroutine, so the
    exceptions listed below were raised at await-time in the caller and
    never reached this handler.

    Parameters
    ----------
    api_call_fn
        The async function performing the API request.

    Returns
    -------
    An async wrapper that re-raises connectivity problems as `OSError`
    and client-side API errors unchanged.

    """

    @functools.wraps(api_call_fn)
    async def call(*args, **kwargs):
        # Imported lazily so the module can be imported without the
        # `openai` package installed.
        import openai

        try:
            return await api_call_fn(*args, **kwargs)
        except (
            openai.APITimeoutError,
            openai.InternalServerError,
            openai.RateLimitError,
        ) as e:
            # Transient/connectivity-style failures surface as OSError.
            raise OSError(f"Could not connect to the OpenAI API: {e}")
        except (
            openai.AuthenticationError,
            openai.BadRequestError,
            openai.ConflictError,
            openai.PermissionDeniedError,
            openai.NotFoundError,
            openai.UnprocessableEntityError,
        ) as e:
            # Caller-side errors are propagated as-is.
            raise e

    return call
|
|
|
|
@functools.singledispatch
def openai(model_or_client, *args, **kwargs):
    """Build an `OpenAI` model.

    Generic path: the first argument is assumed to be an already-constructed
    API client and is wrapped directly. A `str` first argument is handled by
    the registration below.
    """
    return OpenAI(model_or_client, *args, **kwargs)


@openai.register(str)
def openai_model(
    model_name: str,
    config: Optional[OpenAIConfig] = None,
    **openai_client_params,
):
    """Build an `OpenAI` model from a model name, creating the async client.

    Raises
    ------
    ImportError
        If the `openai` library is not installed.
    """
    try:
        from openai import AsyncOpenAI
    except ImportError:
        raise ImportError(
            "The `openai` library needs to be installed in order to use Outlines' OpenAI integration."
        )

    # Whatever configuration was provided, its `model` field is set to the
    # requested model name.
    if config is None:
        config = OpenAIConfig(model=model_name)
    else:
        config = replace(config, model=model_name)

    return OpenAI(AsyncOpenAI(**openai_client_params), config)
|
|
|
|
def azure_openai(
    deployment_name: str,
    model_name: Optional[str] = None,
    config: Optional[OpenAIConfig] = None,
    **azure_openai_client_params,
):
    """Create an `OpenAI` instance backed by the Azure OpenAI API.

    Parameters
    ----------
    deployment_name
        The name of the Azure deployment; stored as the `model` field of the
        configuration and sent with every request.
    model_name
        NOTE(review): currently unused — kept for signature compatibility.
        Presumably meant to name the underlying model (e.g. for tokenizer
        selection); confirm intended semantics before wiring it in.
    config
        An optional `OpenAIConfig`; its `model` field is overwritten with
        `deployment_name`.
    azure_openai_client_params
        Keyword arguments forwarded verbatim to `AsyncAzureOpenAI`.

    Raises
    ------
    ImportError
        If the `openai` library is not installed.

    """
    try:
        from openai import AsyncAzureOpenAI
    except ImportError:
        raise ImportError(
            "The `openai` library needs to be installed in order to use Outlines' Azure OpenAI integration."
        )

    # Single if/else (instead of two independent `if`s) for consistency with
    # `openai_model`; behavior is identical.
    if config is not None:
        config = replace(config, model=deployment_name)
    else:
        config = OpenAIConfig(model=deployment_name)

    client = AsyncAzureOpenAI(**azure_openai_client_params)

    return OpenAI(client, config)
|
|