koichi12's picture
Add files using upload-large-folder tool
762d748 verified
raw
history blame
9.01 kB
"""Integration with OpenAI's API."""
import copy
import functools
from dataclasses import asdict, dataclass, field, replace
from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
from outlines.base import vectorize
from outlines.caching import cache
__all__ = ["OpenAI", "openai", "azure_openai"]
@dataclass(frozen=True)
class OpenAIConfig:
    """Represents the parameters of the OpenAI API.

    The information was last fetched on 2023/11/20. We document below the
    properties that are specific to the OpenAI API. Not all these properties
    are supported by Outlines.

    Properties
    ----------
    model
        The name of the model. Available models can be found on OpenAI's
        website.
    frequency_penalty
        Number between -2.0 and 2.0. Positive values penalize new tokens
        based on their existing frequency in the text.
    logit_bias
        Modifies the likelihood of specified tokens to appear in the
        completion. Number between -100 (forbid) and +100 (only allows).
    max_tokens
        The maximum number of tokens to generate for each completion.
    n
        The number of completions to return for each prompt.
    presence_penalty
        Similar to frequency penalty.
    response_format
        Specifies the format the model must output. `{"type": "json_object"}`
        enables JSON mode.
    seed
        Two completions with the same `seed` value should return the same
        completion. This is however not guaranteed.
    stop
        Up to 4 words where the API will stop the completion.
    temperature
        Number between 0 and 2. Higher values make the output more random,
        while lower values make it more deterministic.
    top_p
        Number between 0 and 1. Parameter for nucleus sampling.
    user
        A unique identifier for the end-user.
    """

    model: str = ""
    frequency_penalty: float = 0
    logit_bias: Dict[int, int] = field(default_factory=dict)
    max_tokens: Optional[int] = None
    n: int = 1
    presence_penalty: float = 0
    response_format: Optional[Dict[str, str]] = None
    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    temperature: float = 1.0
    # `top_p` is a probability mass for nucleus sampling, so its annotation
    # is `float` (the previous `int` annotation was wrong; 1.0 == 1, so the
    # default is value-compatible).
    top_p: float = 1.0
    user: str = field(default_factory=str)
class OpenAI:
    """An object that represents the OpenAI API."""

    def __init__(
        self,
        client,
        config,
        system_prompt: Optional[str] = None,
    ):
        """Create an `OpenAI` instance.

        This class supports the standard OpenAI API, the Azure OpenAI API as
        well as compatible APIs that rely on the OpenAI client.

        Parameters
        ----------
        client
            An instance of the API's async client.
        config
            An instance of `OpenAIConfig`. Can be useful to specify some
            parameters that cannot be set by calling this class' methods.
        system_prompt
            Default content of the system message that precedes the user's
            prompt. Can be overridden per call through the `system_prompt`
            argument of `__call__`.
        """
        self.client = client
        self.config = config
        # Bug fix: `system_prompt` used to be accepted and silently
        # discarded; it is now stored and used as the per-call default.
        self.system_prompt = system_prompt

        # We count the total number of prompt and generated tokens as returned
        # by the OpenAI API, summed over all the requests performed with this
        # model instance.
        self.prompt_tokens = 0
        self.completion_tokens = 0

        # Identity by default; subclasses/users may replace it to
        # post-process the raw completion(s).
        self.format_sequence = lambda seq: seq

    def __call__(
        self,
        prompt: Union[str, List[str]],
        max_tokens: Optional[int] = None,
        stop_at: Optional[Union[List[str], str]] = None,
        *,
        system_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        samples: Optional[int] = None,
    ) -> np.ndarray:
        """Call the OpenAI API to generate text.

        Parameters
        ----------
        prompt
            A string or list of strings that will be used to prompt the model
        max_tokens
            The maximum number of tokens to generate
        stop_at
            A string or list of strings; the generation stops when any of
            them is generated (up to 4 words).
        system_prompt
            The content of the system message that precedes the user's
            prompt. Defaults to the `system_prompt` given at construction.
        temperature
            The value of the temperature used to sample tokens
        samples
            The number of completions to generate for each prompt
        """
        # Fall back to the instance-level configuration for every argument
        # that was not explicitly provided.
        if max_tokens is None:
            max_tokens = self.config.max_tokens
        if stop_at is None:
            stop_at = self.config.stop
        if temperature is None:
            temperature = self.config.temperature
        if samples is None:
            samples = self.config.n
        if system_prompt is None:
            system_prompt = self.system_prompt

        config = replace(self.config, max_tokens=max_tokens, temperature=temperature, n=samples, stop=stop_at)  # type: ignore

        response, prompt_tokens, completion_tokens = generate_chat(
            prompt, system_prompt, self.client, config
        )
        self.prompt_tokens += prompt_tokens
        self.completion_tokens += completion_tokens

        return self.format_sequence(response)

    def stream(self, *args, **kwargs):
        """Streaming is not implemented for this backend."""
        raise NotImplementedError(
            "Streaming is currently not supported for the OpenAI API"
        )

    def new_with_replacements(self, **kwargs):
        """Return a shallow copy of the model with updated config fields."""
        new_instance = copy.copy(self)
        new_instance.config = replace(new_instance.config, **kwargs)
        return new_instance

    def __str__(self):
        return self.__class__.__name__ + " API"

    def __repr__(self):
        return str(self.config)
@functools.partial(vectorize, signature="(),(),(),()->(s),(),()")
async def generate_chat(
    prompt: str,
    system_prompt: Union[str, None],
    client,
    config: OpenAIConfig,
) -> Tuple[np.ndarray, int, int]:
    """Call OpenAI's Chat Completion API.

    Vectorized over its four arguments ("(),(),(),()->(s),(),()"): for each
    scalar prompt it returns an array of `config.n` completions plus two
    scalar usage counters.

    Parameters
    ----------
    prompt
        The prompt we use to start the generation. Passed to the model
        with the "user" role.
    system_prompt
        The system prompt, passed to the model with the "system" role
        before the prompt.
    client
        The API client
    config
        An `OpenAIConfig` instance.

    Returns
    -------
    A tuple that contains the model's response(s) and usage statistics.
    """

    # Cached, error-wrapped inner call. `system_message` and `user_message`
    # are free variables of this closure: they are assigned *below*, before
    # `call_api` is awaited, so they resolve correctly at call time.
    @error_handler
    @cache()
    async def call_api(prompt, system_prompt, config):
        responses = await client.chat.completions.create(
            messages=system_message + user_message,
            **asdict(config),  # type: ignore
        )
        # Convert the SDK response object into a plain dict (presumably also
        # required for `cache()` to serialize the result — TODO confirm).
        return responses.model_dump()

    # An empty/None system prompt produces no system message at all.
    system_message = (
        [{"role": "system", "content": system_prompt}] if system_prompt else []
    )
    user_message = [{"role": "user", "content": prompt}]

    responses = await call_api(prompt, system_prompt, config)

    # One completion per choice; the API returns exactly `config.n` choices.
    results = np.array(
        [responses["choices"][i]["message"]["content"] for i in range(config.n)]
    )
    usage = responses["usage"]

    return results, usage["prompt_tokens"], usage["completion_tokens"]
def error_handler(api_call_fn: Callable) -> Callable:
    """Handle OpenAI API errors and missing API key.

    The wrapped function is a coroutine function: OpenAI client errors are
    raised when the coroutine is *awaited*, not when it is created. The
    previous synchronous wrapper returned the un-awaited coroutine, so its
    `try/except` could never catch an API error — the wrapper must itself be
    async and await the call inside the `try` block.
    """

    @functools.wraps(api_call_fn)
    async def call(*args, **kwargs):
        # Imported lazily so this module can be imported without `openai`
        # installed.
        import openai

        try:
            return await api_call_fn(*args, **kwargs)
        except (
            openai.APITimeoutError,
            openai.InternalServerError,
            openai.RateLimitError,
        ) as e:
            # Transient/connectivity-style failures are surfaced as OSError,
            # chained so the original exception is preserved.
            raise OSError(f"Could not connect to the OpenAI API: {e}") from e
        except (
            openai.AuthenticationError,
            openai.BadRequestError,
            openai.ConflictError,
            openai.PermissionDeniedError,
            openai.NotFoundError,
            openai.UnprocessableEntityError,
        ) as e:
            # Caller errors are re-raised unchanged.
            raise e

    return call
@functools.singledispatch
def openai(model_or_client, *args, **kwargs):
    """Create an `OpenAI` model.

    Default dispatch: `model_or_client` is assumed to be an already-built
    async API client and is wrapped directly. A `str` overload registered
    separately treats the argument as a model name instead.
    """
    return OpenAI(model_or_client, *args, **kwargs)
@openai.register(str)
def openai_model(
    model_name: str,
    config: Optional[OpenAIConfig] = None,
    **openai_client_params,
):
    """Create an `OpenAI` model from a model name.

    Parameters
    ----------
    model_name
        The name of the OpenAI model; stored in the config's `model` field.
    config
        An optional `OpenAIConfig`; its `model` field is overwritten with
        `model_name`.
    openai_client_params
        Keyword arguments forwarded to the `AsyncOpenAI` client constructor
        (e.g. `api_key`).

    Raises
    ------
    ImportError
        If the `openai` package is not installed.
    """
    try:
        from openai import AsyncOpenAI
    except ImportError as e:
        # Chain the original error so the real cause stays visible.
        raise ImportError(
            "The `openai` library needs to be installed in order to use Outlines' OpenAI integration."
        ) from e

    if config is not None:
        config = replace(config, model=model_name)  # type: ignore
    else:
        config = OpenAIConfig(model=model_name)

    client = AsyncOpenAI(**openai_client_params)

    return OpenAI(client, config)
def azure_openai(
deployment_name: str,
model_name: Optional[str] = None,
config: Optional[OpenAIConfig] = None,
**azure_openai_client_params,
):
try:
from openai import AsyncAzureOpenAI
except ImportError:
raise ImportError(
"The `openai` library needs to be installed in order to use Outlines' Azure OpenAI integration."
)
if config is not None:
config = replace(config, model=deployment_name) # type: ignore
if config is None:
config = OpenAIConfig(model=deployment_name)
client = AsyncAzureOpenAI(**azure_openai_client_params)
return OpenAI(client, config)