Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/openai/_utils/__init__.py +60 -0
- .venv/lib/python3.11/site-packages/openai/_utils/_proxy.py +62 -0
- .venv/lib/python3.11/site-packages/openai/cli/_api/__init__.py +1 -0
- .venv/lib/python3.11/site-packages/openai/cli/_api/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/openai/cli/_api/__pycache__/_main.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/openai/cli/_api/__pycache__/audio.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/openai/cli/_api/__pycache__/completions.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/openai/cli/_api/__pycache__/files.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/openai/cli/_api/__pycache__/image.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/openai/cli/_api/__pycache__/models.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/openai/cli/_api/chat/__init__.py +13 -0
- .venv/lib/python3.11/site-packages/openai/cli/_api/chat/completions.py +160 -0
- .venv/lib/python3.11/site-packages/openai/lib/.keep +4 -0
- .venv/lib/python3.11/site-packages/openai/lib/__init__.py +2 -0
- .venv/lib/python3.11/site-packages/openai/lib/__pycache__/azure.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/openai/lib/_old_api.py +72 -0
- .venv/lib/python3.11/site-packages/openai/lib/_parsing/__init__.py +12 -0
- .venv/lib/python3.11/site-packages/openai/lib/_parsing/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/openai/lib/_parsing/__pycache__/_completions.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/openai/lib/_parsing/_completions.py +264 -0
- .venv/lib/python3.11/site-packages/openai/lib/_pydantic.py +155 -0
- .venv/lib/python3.11/site-packages/openai/lib/_tools.py +54 -0
- .venv/lib/python3.11/site-packages/openai/lib/_validators.py +809 -0
- .venv/lib/python3.11/site-packages/openai/lib/azure.py +587 -0
- .venv/lib/python3.11/site-packages/openai/lib/streaming/__init__.py +8 -0
- .venv/lib/python3.11/site-packages/openai/lib/streaming/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/openai/lib/streaming/__pycache__/_assistants.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/openai/lib/streaming/__pycache__/_deltas.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/openai/lib/streaming/_assistants.py +1038 -0
- .venv/lib/python3.11/site-packages/openai/lib/streaming/_deltas.py +64 -0
- .venv/lib/python3.11/site-packages/openai/lib/streaming/chat/__init__.py +27 -0
- .venv/lib/python3.11/site-packages/openai/lib/streaming/chat/__pycache__/_completions.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/openai/lib/streaming/chat/__pycache__/_events.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/openai/lib/streaming/chat/_completions.py +755 -0
- .venv/lib/python3.11/site-packages/openai/lib/streaming/chat/_events.py +123 -0
- .venv/lib/python3.11/site-packages/openai/lib/streaming/chat/_types.py +20 -0
- .venv/lib/python3.11/site-packages/openai/resources/__init__.py +173 -0
- .venv/lib/python3.11/site-packages/openai/resources/audio/__init__.py +61 -0
- .venv/lib/python3.11/site-packages/openai/resources/audio/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/openai/resources/audio/__pycache__/audio.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/openai/resources/audio/__pycache__/speech.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/openai/resources/audio/__pycache__/transcriptions.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/openai/resources/audio/__pycache__/translations.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/openai/resources/audio/audio.py +166 -0
- .venv/lib/python3.11/site-packages/openai/resources/audio/speech.py +234 -0
- .venv/lib/python3.11/site-packages/openai/resources/audio/transcriptions.py +415 -0
- .venv/lib/python3.11/site-packages/openai/resources/audio/translations.py +373 -0
- .venv/lib/python3.11/site-packages/openai/resources/batches.py +517 -0
- .venv/lib/python3.11/site-packages/openai/resources/beta/__init__.py +61 -0
- .venv/lib/python3.11/site-packages/openai/resources/beta/__pycache__/__init__.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/openai/_utils/__init__.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ._logs import SensitiveHeadersFilter as SensitiveHeadersFilter
|
| 2 |
+
from ._sync import asyncify as asyncify
|
| 3 |
+
from ._proxy import LazyProxy as LazyProxy
|
| 4 |
+
from ._utils import (
|
| 5 |
+
flatten as flatten,
|
| 6 |
+
is_dict as is_dict,
|
| 7 |
+
is_list as is_list,
|
| 8 |
+
is_given as is_given,
|
| 9 |
+
is_tuple as is_tuple,
|
| 10 |
+
json_safe as json_safe,
|
| 11 |
+
lru_cache as lru_cache,
|
| 12 |
+
is_mapping as is_mapping,
|
| 13 |
+
is_tuple_t as is_tuple_t,
|
| 14 |
+
parse_date as parse_date,
|
| 15 |
+
is_iterable as is_iterable,
|
| 16 |
+
is_sequence as is_sequence,
|
| 17 |
+
coerce_float as coerce_float,
|
| 18 |
+
is_mapping_t as is_mapping_t,
|
| 19 |
+
removeprefix as removeprefix,
|
| 20 |
+
removesuffix as removesuffix,
|
| 21 |
+
extract_files as extract_files,
|
| 22 |
+
is_sequence_t as is_sequence_t,
|
| 23 |
+
required_args as required_args,
|
| 24 |
+
coerce_boolean as coerce_boolean,
|
| 25 |
+
coerce_integer as coerce_integer,
|
| 26 |
+
file_from_path as file_from_path,
|
| 27 |
+
parse_datetime as parse_datetime,
|
| 28 |
+
is_azure_client as is_azure_client,
|
| 29 |
+
strip_not_given as strip_not_given,
|
| 30 |
+
deepcopy_minimal as deepcopy_minimal,
|
| 31 |
+
get_async_library as get_async_library,
|
| 32 |
+
maybe_coerce_float as maybe_coerce_float,
|
| 33 |
+
get_required_header as get_required_header,
|
| 34 |
+
maybe_coerce_boolean as maybe_coerce_boolean,
|
| 35 |
+
maybe_coerce_integer as maybe_coerce_integer,
|
| 36 |
+
is_async_azure_client as is_async_azure_client,
|
| 37 |
+
)
|
| 38 |
+
from ._typing import (
|
| 39 |
+
is_list_type as is_list_type,
|
| 40 |
+
is_union_type as is_union_type,
|
| 41 |
+
extract_type_arg as extract_type_arg,
|
| 42 |
+
is_iterable_type as is_iterable_type,
|
| 43 |
+
is_required_type as is_required_type,
|
| 44 |
+
is_annotated_type as is_annotated_type,
|
| 45 |
+
is_type_alias_type as is_type_alias_type,
|
| 46 |
+
strip_annotated_type as strip_annotated_type,
|
| 47 |
+
extract_type_var_from_base as extract_type_var_from_base,
|
| 48 |
+
)
|
| 49 |
+
from ._streams import consume_sync_iterator as consume_sync_iterator, consume_async_iterator as consume_async_iterator
|
| 50 |
+
from ._transform import (
|
| 51 |
+
PropertyInfo as PropertyInfo,
|
| 52 |
+
transform as transform,
|
| 53 |
+
async_transform as async_transform,
|
| 54 |
+
maybe_transform as maybe_transform,
|
| 55 |
+
async_maybe_transform as async_maybe_transform,
|
| 56 |
+
)
|
| 57 |
+
from ._reflection import (
|
| 58 |
+
function_has_argument as function_has_argument,
|
| 59 |
+
assert_signatures_in_sync as assert_signatures_in_sync,
|
| 60 |
+
)
|
.venv/lib/python3.11/site-packages/openai/_utils/_proxy.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from typing import Generic, TypeVar, Iterable, cast
|
| 5 |
+
from typing_extensions import override
|
| 6 |
+
|
| 7 |
+
T = TypeVar("T")
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class LazyProxy(Generic[T], ABC):
|
| 11 |
+
"""Implements data methods to pretend that an instance is another instance.
|
| 12 |
+
|
| 13 |
+
This includes forwarding attribute access and other methods.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
# Note: we have to special case proxies that themselves return proxies
|
| 17 |
+
# to support using a proxy as a catch-all for any random access, e.g. `proxy.foo.bar.baz`
|
| 18 |
+
|
| 19 |
+
def __getattr__(self, attr: str) -> object:
|
| 20 |
+
proxied = self.__get_proxied__()
|
| 21 |
+
if isinstance(proxied, LazyProxy):
|
| 22 |
+
return proxied # pyright: ignore
|
| 23 |
+
return getattr(proxied, attr)
|
| 24 |
+
|
| 25 |
+
@override
|
| 26 |
+
def __repr__(self) -> str:
|
| 27 |
+
proxied = self.__get_proxied__()
|
| 28 |
+
if isinstance(proxied, LazyProxy):
|
| 29 |
+
return proxied.__class__.__name__
|
| 30 |
+
return repr(self.__get_proxied__())
|
| 31 |
+
|
| 32 |
+
@override
|
| 33 |
+
def __str__(self) -> str:
|
| 34 |
+
proxied = self.__get_proxied__()
|
| 35 |
+
if isinstance(proxied, LazyProxy):
|
| 36 |
+
return proxied.__class__.__name__
|
| 37 |
+
return str(proxied)
|
| 38 |
+
|
| 39 |
+
@override
|
| 40 |
+
def __dir__(self) -> Iterable[str]:
|
| 41 |
+
proxied = self.__get_proxied__()
|
| 42 |
+
if isinstance(proxied, LazyProxy):
|
| 43 |
+
return []
|
| 44 |
+
return proxied.__dir__()
|
| 45 |
+
|
| 46 |
+
@property # type: ignore
|
| 47 |
+
@override
|
| 48 |
+
def __class__(self) -> type: # pyright: ignore
|
| 49 |
+
proxied = self.__get_proxied__()
|
| 50 |
+
if issubclass(type(proxied), LazyProxy):
|
| 51 |
+
return type(proxied)
|
| 52 |
+
return proxied.__class__
|
| 53 |
+
|
| 54 |
+
def __get_proxied__(self) -> T:
|
| 55 |
+
return self.__load__()
|
| 56 |
+
|
| 57 |
+
def __as_proxied__(self) -> T:
|
| 58 |
+
"""Helper method that returns the current proxy, typed as the loaded object"""
|
| 59 |
+
return cast(T, self)
|
| 60 |
+
|
| 61 |
+
@abstractmethod
|
| 62 |
+
def __load__(self) -> T: ...
|
.venv/lib/python3.11/site-packages/openai/cli/_api/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from ._main import register_commands as register_commands
|
.venv/lib/python3.11/site-packages/openai/cli/_api/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (249 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/openai/cli/_api/__pycache__/_main.cpython-311.pyc
ADDED
|
Binary file (1.16 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/openai/cli/_api/__pycache__/audio.cpython-311.pyc
ADDED
|
Binary file (6.34 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/openai/cli/_api/__pycache__/completions.cpython-311.pyc
ADDED
|
Binary file (9.58 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/openai/cli/_api/__pycache__/files.cpython-311.pyc
ADDED
|
Binary file (5.03 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/openai/cli/_api/__pycache__/image.cpython-311.pyc
ADDED
|
Binary file (8.43 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/openai/cli/_api/__pycache__/models.cpython-311.pyc
ADDED
|
Binary file (3.26 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/openai/cli/_api/chat/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import TYPE_CHECKING
|
| 4 |
+
from argparse import ArgumentParser
|
| 5 |
+
|
| 6 |
+
from . import completions
|
| 7 |
+
|
| 8 |
+
if TYPE_CHECKING:
|
| 9 |
+
from argparse import _SubParsersAction
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
|
| 13 |
+
completions.register(subparser)
|
.venv/lib/python3.11/site-packages/openai/cli/_api/chat/completions.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
from typing import TYPE_CHECKING, List, Optional, cast
|
| 5 |
+
from argparse import ArgumentParser
|
| 6 |
+
from typing_extensions import Literal, NamedTuple
|
| 7 |
+
|
| 8 |
+
from ..._utils import get_client
|
| 9 |
+
from ..._models import BaseModel
|
| 10 |
+
from ...._streaming import Stream
|
| 11 |
+
from ....types.chat import (
|
| 12 |
+
ChatCompletionRole,
|
| 13 |
+
ChatCompletionChunk,
|
| 14 |
+
CompletionCreateParams,
|
| 15 |
+
)
|
| 16 |
+
from ....types.chat.completion_create_params import (
|
| 17 |
+
CompletionCreateParamsStreaming,
|
| 18 |
+
CompletionCreateParamsNonStreaming,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
if TYPE_CHECKING:
|
| 22 |
+
from argparse import _SubParsersAction
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
|
| 26 |
+
sub = subparser.add_parser("chat.completions.create")
|
| 27 |
+
|
| 28 |
+
sub._action_groups.pop()
|
| 29 |
+
req = sub.add_argument_group("required arguments")
|
| 30 |
+
opt = sub.add_argument_group("optional arguments")
|
| 31 |
+
|
| 32 |
+
req.add_argument(
|
| 33 |
+
"-g",
|
| 34 |
+
"--message",
|
| 35 |
+
action="append",
|
| 36 |
+
nargs=2,
|
| 37 |
+
metavar=("ROLE", "CONTENT"),
|
| 38 |
+
help="A message in `{role} {content}` format. Use this argument multiple times to add multiple messages.",
|
| 39 |
+
required=True,
|
| 40 |
+
)
|
| 41 |
+
req.add_argument(
|
| 42 |
+
"-m",
|
| 43 |
+
"--model",
|
| 44 |
+
help="The model to use.",
|
| 45 |
+
required=True,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
opt.add_argument(
|
| 49 |
+
"-n",
|
| 50 |
+
"--n",
|
| 51 |
+
help="How many completions to generate for the conversation.",
|
| 52 |
+
type=int,
|
| 53 |
+
)
|
| 54 |
+
opt.add_argument("-M", "--max-tokens", help="The maximum number of tokens to generate.", type=int)
|
| 55 |
+
opt.add_argument(
|
| 56 |
+
"-t",
|
| 57 |
+
"--temperature",
|
| 58 |
+
help="""What sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
|
| 59 |
+
|
| 60 |
+
Mutually exclusive with `top_p`.""",
|
| 61 |
+
type=float,
|
| 62 |
+
)
|
| 63 |
+
opt.add_argument(
|
| 64 |
+
"-P",
|
| 65 |
+
"--top_p",
|
| 66 |
+
help="""An alternative to sampling with temperature, called nucleus sampling, where the considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%% probability mass are considered.
|
| 67 |
+
|
| 68 |
+
Mutually exclusive with `temperature`.""",
|
| 69 |
+
type=float,
|
| 70 |
+
)
|
| 71 |
+
opt.add_argument(
|
| 72 |
+
"--stop",
|
| 73 |
+
help="A stop sequence at which to stop generating tokens for the message.",
|
| 74 |
+
)
|
| 75 |
+
opt.add_argument("--stream", help="Stream messages as they're ready.", action="store_true")
|
| 76 |
+
sub.set_defaults(func=CLIChatCompletion.create, args_model=CLIChatCompletionCreateArgs)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class CLIMessage(NamedTuple):
|
| 80 |
+
role: ChatCompletionRole
|
| 81 |
+
content: str
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class CLIChatCompletionCreateArgs(BaseModel):
|
| 85 |
+
message: List[CLIMessage]
|
| 86 |
+
model: str
|
| 87 |
+
n: Optional[int] = None
|
| 88 |
+
max_tokens: Optional[int] = None
|
| 89 |
+
temperature: Optional[float] = None
|
| 90 |
+
top_p: Optional[float] = None
|
| 91 |
+
stop: Optional[str] = None
|
| 92 |
+
stream: bool = False
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class CLIChatCompletion:
|
| 96 |
+
@staticmethod
|
| 97 |
+
def create(args: CLIChatCompletionCreateArgs) -> None:
|
| 98 |
+
params: CompletionCreateParams = {
|
| 99 |
+
"model": args.model,
|
| 100 |
+
"messages": [
|
| 101 |
+
{"role": cast(Literal["user"], message.role), "content": message.content} for message in args.message
|
| 102 |
+
],
|
| 103 |
+
# type checkers are not good at inferring union types so we have to set stream afterwards
|
| 104 |
+
"stream": False,
|
| 105 |
+
}
|
| 106 |
+
if args.temperature is not None:
|
| 107 |
+
params['temperature'] = args.temperature
|
| 108 |
+
if args.stop is not None:
|
| 109 |
+
params['stop'] = args.stop
|
| 110 |
+
if args.top_p is not None:
|
| 111 |
+
params['top_p'] = args.top_p
|
| 112 |
+
if args.n is not None:
|
| 113 |
+
params['n'] = args.n
|
| 114 |
+
if args.stream:
|
| 115 |
+
params["stream"] = args.stream # type: ignore
|
| 116 |
+
if args.max_tokens is not None:
|
| 117 |
+
params["max_tokens"] = args.max_tokens
|
| 118 |
+
|
| 119 |
+
if args.stream:
|
| 120 |
+
return CLIChatCompletion._stream_create(cast(CompletionCreateParamsStreaming, params))
|
| 121 |
+
|
| 122 |
+
return CLIChatCompletion._create(cast(CompletionCreateParamsNonStreaming, params))
|
| 123 |
+
|
| 124 |
+
@staticmethod
|
| 125 |
+
def _create(params: CompletionCreateParamsNonStreaming) -> None:
|
| 126 |
+
completion = get_client().chat.completions.create(**params)
|
| 127 |
+
should_print_header = len(completion.choices) > 1
|
| 128 |
+
for choice in completion.choices:
|
| 129 |
+
if should_print_header:
|
| 130 |
+
sys.stdout.write("===== Chat Completion {} =====\n".format(choice.index))
|
| 131 |
+
|
| 132 |
+
content = choice.message.content if choice.message.content is not None else "None"
|
| 133 |
+
sys.stdout.write(content)
|
| 134 |
+
|
| 135 |
+
if should_print_header or not content.endswith("\n"):
|
| 136 |
+
sys.stdout.write("\n")
|
| 137 |
+
|
| 138 |
+
sys.stdout.flush()
|
| 139 |
+
|
| 140 |
+
@staticmethod
|
| 141 |
+
def _stream_create(params: CompletionCreateParamsStreaming) -> None:
|
| 142 |
+
# cast is required for mypy
|
| 143 |
+
stream = cast( # pyright: ignore[reportUnnecessaryCast]
|
| 144 |
+
Stream[ChatCompletionChunk], get_client().chat.completions.create(**params)
|
| 145 |
+
)
|
| 146 |
+
for chunk in stream:
|
| 147 |
+
should_print_header = len(chunk.choices) > 1
|
| 148 |
+
for choice in chunk.choices:
|
| 149 |
+
if should_print_header:
|
| 150 |
+
sys.stdout.write("===== Chat Completion {} =====\n".format(choice.index))
|
| 151 |
+
|
| 152 |
+
content = choice.delta.content or ""
|
| 153 |
+
sys.stdout.write(content)
|
| 154 |
+
|
| 155 |
+
if should_print_header:
|
| 156 |
+
sys.stdout.write("\n")
|
| 157 |
+
|
| 158 |
+
sys.stdout.flush()
|
| 159 |
+
|
| 160 |
+
sys.stdout.write("\n")
|
.venv/lib/python3.11/site-packages/openai/lib/.keep
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
File generated from our OpenAPI spec by Stainless.
|
| 2 |
+
|
| 3 |
+
This directory can be used to store custom files to expand the SDK.
|
| 4 |
+
It is ignored by Stainless code generation and its content (other than this keep file) won't be touched.
|
.venv/lib/python3.11/site-packages/openai/lib/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ._tools import pydantic_function_tool as pydantic_function_tool
|
| 2 |
+
from ._parsing import ResponseFormatT as ResponseFormatT
|
.venv/lib/python3.11/site-packages/openai/lib/__pycache__/azure.cpython-311.pyc
ADDED
|
Binary file (20.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/openai/lib/_old_api.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import TYPE_CHECKING, Any
|
| 4 |
+
from typing_extensions import override
|
| 5 |
+
|
| 6 |
+
from .._utils import LazyProxy
|
| 7 |
+
from .._exceptions import OpenAIError
|
| 8 |
+
|
| 9 |
+
INSTRUCTIONS = """
|
| 10 |
+
|
| 11 |
+
You tried to access openai.{symbol}, but this is no longer supported in openai>=1.0.0 - see the README at https://github.com/openai/openai-python for the API.
|
| 12 |
+
|
| 13 |
+
You can run `openai migrate` to automatically upgrade your codebase to use the 1.0.0 interface.
|
| 14 |
+
|
| 15 |
+
Alternatively, you can pin your installation to the old version, e.g. `pip install openai==0.28`
|
| 16 |
+
|
| 17 |
+
A detailed migration guide is available here: https://github.com/openai/openai-python/discussions/742
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class APIRemovedInV1(OpenAIError):
|
| 22 |
+
def __init__(self, *, symbol: str) -> None:
|
| 23 |
+
super().__init__(INSTRUCTIONS.format(symbol=symbol))
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class APIRemovedInV1Proxy(LazyProxy[Any]):
|
| 27 |
+
def __init__(self, *, symbol: str) -> None:
|
| 28 |
+
super().__init__()
|
| 29 |
+
self._symbol = symbol
|
| 30 |
+
|
| 31 |
+
@override
|
| 32 |
+
def __load__(self) -> Any:
|
| 33 |
+
# return the proxy until it is eventually called so that
|
| 34 |
+
# we don't break people that are just checking the attributes
|
| 35 |
+
# of a module
|
| 36 |
+
return self
|
| 37 |
+
|
| 38 |
+
def __call__(self, *_args: Any, **_kwargs: Any) -> Any:
|
| 39 |
+
raise APIRemovedInV1(symbol=self._symbol)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
SYMBOLS = [
|
| 43 |
+
"Edit",
|
| 44 |
+
"File",
|
| 45 |
+
"Audio",
|
| 46 |
+
"Image",
|
| 47 |
+
"Model",
|
| 48 |
+
"Engine",
|
| 49 |
+
"Customer",
|
| 50 |
+
"FineTune",
|
| 51 |
+
"Embedding",
|
| 52 |
+
"Completion",
|
| 53 |
+
"Deployment",
|
| 54 |
+
"Moderation",
|
| 55 |
+
"ErrorObject",
|
| 56 |
+
"FineTuningJob",
|
| 57 |
+
"ChatCompletion",
|
| 58 |
+
]
|
| 59 |
+
|
| 60 |
+
# we explicitly tell type checkers that nothing is exported
|
| 61 |
+
# from this file so that when we re-export the old symbols
|
| 62 |
+
# in `openai/__init__.py` they aren't added to the auto-complete
|
| 63 |
+
# suggestions given by editors
|
| 64 |
+
if TYPE_CHECKING:
|
| 65 |
+
__all__: list[str] = []
|
| 66 |
+
else:
|
| 67 |
+
__all__ = SYMBOLS
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
__locals = locals()
|
| 71 |
+
for symbol in SYMBOLS:
|
| 72 |
+
__locals[symbol] = APIRemovedInV1Proxy(symbol=symbol)
|
.venv/lib/python3.11/site-packages/openai/lib/_parsing/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ._completions import (
|
| 2 |
+
ResponseFormatT as ResponseFormatT,
|
| 3 |
+
has_parseable_input,
|
| 4 |
+
has_parseable_input as has_parseable_input,
|
| 5 |
+
maybe_parse_content as maybe_parse_content,
|
| 6 |
+
validate_input_tools as validate_input_tools,
|
| 7 |
+
parse_chat_completion as parse_chat_completion,
|
| 8 |
+
get_input_tool_by_name as get_input_tool_by_name,
|
| 9 |
+
solve_response_format_t as solve_response_format_t,
|
| 10 |
+
parse_function_tool_arguments as parse_function_tool_arguments,
|
| 11 |
+
type_to_response_format_param as type_to_response_format_param,
|
| 12 |
+
)
|
.venv/lib/python3.11/site-packages/openai/lib/_parsing/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (643 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/openai/lib/_parsing/__pycache__/_completions.cpython-311.pyc
ADDED
|
Binary file (11.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/openai/lib/_parsing/_completions.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from typing import TYPE_CHECKING, Any, Iterable, cast
|
| 5 |
+
from typing_extensions import TypeVar, TypeGuard, assert_never
|
| 6 |
+
|
| 7 |
+
import pydantic
|
| 8 |
+
|
| 9 |
+
from .._tools import PydanticFunctionTool
|
| 10 |
+
from ..._types import NOT_GIVEN, NotGiven
|
| 11 |
+
from ..._utils import is_dict, is_given
|
| 12 |
+
from ..._compat import PYDANTIC_V2, model_parse_json
|
| 13 |
+
from ..._models import construct_type_unchecked
|
| 14 |
+
from .._pydantic import is_basemodel_type, to_strict_json_schema, is_dataclass_like_type
|
| 15 |
+
from ...types.chat import (
|
| 16 |
+
ParsedChoice,
|
| 17 |
+
ChatCompletion,
|
| 18 |
+
ParsedFunction,
|
| 19 |
+
ParsedChatCompletion,
|
| 20 |
+
ChatCompletionMessage,
|
| 21 |
+
ParsedFunctionToolCall,
|
| 22 |
+
ChatCompletionToolParam,
|
| 23 |
+
ParsedChatCompletionMessage,
|
| 24 |
+
completion_create_params,
|
| 25 |
+
)
|
| 26 |
+
from ..._exceptions import LengthFinishReasonError, ContentFilterFinishReasonError
|
| 27 |
+
from ...types.shared_params import FunctionDefinition
|
| 28 |
+
from ...types.chat.completion_create_params import ResponseFormat as ResponseFormatParam
|
| 29 |
+
from ...types.chat.chat_completion_message_tool_call import Function
|
| 30 |
+
|
| 31 |
+
ResponseFormatT = TypeVar(
|
| 32 |
+
"ResponseFormatT",
|
| 33 |
+
# if it isn't given then we don't do any parsing
|
| 34 |
+
default=None,
|
| 35 |
+
)
|
| 36 |
+
_default_response_format: None = None
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def validate_input_tools(
|
| 40 |
+
tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
|
| 41 |
+
) -> None:
|
| 42 |
+
if not is_given(tools):
|
| 43 |
+
return
|
| 44 |
+
|
| 45 |
+
for tool in tools:
|
| 46 |
+
if tool["type"] != "function":
|
| 47 |
+
raise ValueError(
|
| 48 |
+
f'Currently only `function` tool types support auto-parsing; Received `{tool["type"]}`',
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
strict = tool["function"].get("strict")
|
| 52 |
+
if strict is not True:
|
| 53 |
+
raise ValueError(
|
| 54 |
+
f'`{tool["function"]["name"]}` is not strict. Only `strict` function tools can be auto-parsed'
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def parse_chat_completion(
|
| 59 |
+
*,
|
| 60 |
+
response_format: type[ResponseFormatT] | completion_create_params.ResponseFormat | NotGiven,
|
| 61 |
+
input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
|
| 62 |
+
chat_completion: ChatCompletion | ParsedChatCompletion[object],
|
| 63 |
+
) -> ParsedChatCompletion[ResponseFormatT]:
|
| 64 |
+
if is_given(input_tools):
|
| 65 |
+
input_tools = [t for t in input_tools]
|
| 66 |
+
else:
|
| 67 |
+
input_tools = []
|
| 68 |
+
|
| 69 |
+
choices: list[ParsedChoice[ResponseFormatT]] = []
|
| 70 |
+
for choice in chat_completion.choices:
|
| 71 |
+
if choice.finish_reason == "length":
|
| 72 |
+
raise LengthFinishReasonError(completion=chat_completion)
|
| 73 |
+
|
| 74 |
+
if choice.finish_reason == "content_filter":
|
| 75 |
+
raise ContentFilterFinishReasonError()
|
| 76 |
+
|
| 77 |
+
message = choice.message
|
| 78 |
+
|
| 79 |
+
tool_calls: list[ParsedFunctionToolCall] = []
|
| 80 |
+
if message.tool_calls:
|
| 81 |
+
for tool_call in message.tool_calls:
|
| 82 |
+
if tool_call.type == "function":
|
| 83 |
+
tool_call_dict = tool_call.to_dict()
|
| 84 |
+
tool_calls.append(
|
| 85 |
+
construct_type_unchecked(
|
| 86 |
+
value={
|
| 87 |
+
**tool_call_dict,
|
| 88 |
+
"function": {
|
| 89 |
+
**cast(Any, tool_call_dict["function"]),
|
| 90 |
+
"parsed_arguments": parse_function_tool_arguments(
|
| 91 |
+
input_tools=input_tools, function=tool_call.function
|
| 92 |
+
),
|
| 93 |
+
},
|
| 94 |
+
},
|
| 95 |
+
type_=ParsedFunctionToolCall,
|
| 96 |
+
)
|
| 97 |
+
)
|
| 98 |
+
elif TYPE_CHECKING: # type: ignore[unreachable]
|
| 99 |
+
assert_never(tool_call)
|
| 100 |
+
else:
|
| 101 |
+
tool_calls.append(tool_call)
|
| 102 |
+
|
| 103 |
+
choices.append(
|
| 104 |
+
construct_type_unchecked(
|
| 105 |
+
type_=cast(Any, ParsedChoice)[solve_response_format_t(response_format)],
|
| 106 |
+
value={
|
| 107 |
+
**choice.to_dict(),
|
| 108 |
+
"message": {
|
| 109 |
+
**message.to_dict(),
|
| 110 |
+
"parsed": maybe_parse_content(
|
| 111 |
+
response_format=response_format,
|
| 112 |
+
message=message,
|
| 113 |
+
),
|
| 114 |
+
"tool_calls": tool_calls,
|
| 115 |
+
},
|
| 116 |
+
},
|
| 117 |
+
)
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
return cast(
|
| 121 |
+
ParsedChatCompletion[ResponseFormatT],
|
| 122 |
+
construct_type_unchecked(
|
| 123 |
+
type_=cast(Any, ParsedChatCompletion)[solve_response_format_t(response_format)],
|
| 124 |
+
value={
|
| 125 |
+
**chat_completion.to_dict(),
|
| 126 |
+
"choices": choices,
|
| 127 |
+
},
|
| 128 |
+
),
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def get_input_tool_by_name(*, input_tools: list[ChatCompletionToolParam], name: str) -> ChatCompletionToolParam | None:
|
| 133 |
+
return next((t for t in input_tools if t.get("function", {}).get("name") == name), None)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def parse_function_tool_arguments(
|
| 137 |
+
*, input_tools: list[ChatCompletionToolParam], function: Function | ParsedFunction
|
| 138 |
+
) -> object:
|
| 139 |
+
input_tool = get_input_tool_by_name(input_tools=input_tools, name=function.name)
|
| 140 |
+
if not input_tool:
|
| 141 |
+
return None
|
| 142 |
+
|
| 143 |
+
input_fn = cast(object, input_tool.get("function"))
|
| 144 |
+
if isinstance(input_fn, PydanticFunctionTool):
|
| 145 |
+
return model_parse_json(input_fn.model, function.arguments)
|
| 146 |
+
|
| 147 |
+
input_fn = cast(FunctionDefinition, input_fn)
|
| 148 |
+
|
| 149 |
+
if not input_fn.get("strict"):
|
| 150 |
+
return None
|
| 151 |
+
|
| 152 |
+
return json.loads(function.arguments)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def maybe_parse_content(
    *,
    response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
    message: ChatCompletionMessage | ParsedChatCompletionMessage[object],
) -> ResponseFormatT | None:
    """Parse `message.content` into the rich response format, when applicable.

    Returns None when the response format isn't a rich type, when the message
    has no content, or when the model refused to answer.
    """
    if not has_rich_response_format(response_format):
        return None
    if not message.content or message.refusal:
        return None
    return _parse_content(response_format, message.content)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def solve_response_format_t(
    response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
) -> type[ResponseFormatT]:
    """Return the runtime type for the given response format.

    If no response format is given, or if we won't auto-parse the response format
    then we default to `None`.
    """
    if not has_rich_response_format(response_format):
        return cast("type[ResponseFormatT]", _default_response_format)
    return response_format
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def has_parseable_input(
    *,
    response_format: type | ResponseFormatParam | NotGiven,
    input_tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
) -> bool:
    """Whether the request contains anything we would automatically parse."""
    if has_rich_response_format(response_format):
        return True
    return any(is_parseable_tool(tool) for tool in (input_tools or []))
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def has_rich_response_format(
    response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
) -> TypeGuard[type[ResponseFormatT]]:
    """Whether `response_format` is a rich type (e.g. a pydantic model class)
    rather than a raw dict param or `NOT_GIVEN`.
    """
    return is_given(response_format) and not is_response_format_param(response_format)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def is_response_format_param(response_format: object) -> TypeGuard[ResponseFormatParam]:
    """Narrow `response_format` to the raw dict param shape the API accepts."""
    # a plain dict is assumed to already be in the API's `response_format`
    # param shape, as opposed to a rich type like a pydantic model
    return is_dict(response_format)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def is_parseable_tool(input_tool: ChatCompletionToolParam) -> bool:
    """Whether arguments for the given input tool can be auto-parsed."""
    fn = cast(object, input_tool.get("function"))
    if isinstance(fn, PydanticFunctionTool):
        # defined via `openai.pydantic_function_tool()` - always parseable
        return True
    return cast(FunctionDefinition, fn).get("strict") or False
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def _parse_content(response_format: type[ResponseFormatT], content: str) -> ResponseFormatT:
    """Validate the raw JSON `content` against the given rich response format type.

    Raises:
        TypeError: if `response_format` is not a type we know how to parse,
            or is a dataclass-like type under Pydantic v1.
    """
    if is_basemodel_type(response_format):
        return cast(ResponseFormatT, model_parse_json(response_format, content))

    if is_dataclass_like_type(response_format):
        if not PYDANTIC_V2:
            raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {response_format}")

        # `TypeAdapter` can validate arbitrary dataclass-like types (Pydantic v2 only)
        return pydantic.TypeAdapter(response_format).validate_json(content)

    raise TypeError(f"Unable to automatically parse response format type {response_format}")
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def type_to_response_format_param(
    response_format: type | completion_create_params.ResponseFormat | NotGiven,
) -> ResponseFormatParam | NotGiven:
    """Convert a rich `response_format` (pydantic model / dataclass-like type)
    into the raw `json_schema` request param the API expects.

    Dict params pass through unchanged and `NOT_GIVEN` is preserved.

    Raises:
        TypeError: for types that aren't a pydantic model or dataclass-like.
    """
    if not is_given(response_format):
        return NOT_GIVEN

    if is_response_format_param(response_format):
        # already in the raw param shape, nothing to convert
        return response_format

    # type checkers don't narrow the negation of a `TypeGuard` as it isn't
    # a safe default behaviour but we know that at this point the `response_format`
    # can only be a `type`
    response_format = cast(type, response_format)

    json_schema_type: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any] | None = None

    if is_basemodel_type(response_format):
        name = response_format.__name__
        json_schema_type = response_format
    elif is_dataclass_like_type(response_format):
        name = response_format.__name__
        json_schema_type = pydantic.TypeAdapter(response_format)
    else:
        raise TypeError(f"Unsupported response_format type - {response_format}")

    return {
        "type": "json_schema",
        "json_schema": {
            # `strict: True` requires the schema to follow the API's restricted subset
            "schema": to_strict_json_schema(json_schema_type),
            "name": name,
            "strict": True,
        },
    }
|
.venv/lib/python3.11/site-packages/openai/lib/_pydantic.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import inspect
|
| 4 |
+
from typing import Any, TypeVar
|
| 5 |
+
from typing_extensions import TypeGuard
|
| 6 |
+
|
| 7 |
+
import pydantic
|
| 8 |
+
|
| 9 |
+
from .._types import NOT_GIVEN
|
| 10 |
+
from .._utils import is_dict as _is_dict, is_list
|
| 11 |
+
from .._compat import PYDANTIC_V2, model_json_schema
|
| 12 |
+
|
| 13 |
+
_T = TypeVar("_T")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def to_strict_json_schema(model: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any]) -> dict[str, Any]:
    """Generate a JSON schema for `model` and coerce it into the API's `strict` subset.

    Raises:
        TypeError: for `TypeAdapter` inputs when not running under Pydantic v2.
    """
    if inspect.isclass(model) and is_basemodel_type(model):
        schema = model_json_schema(model)
    elif PYDANTIC_V2 and isinstance(model, pydantic.TypeAdapter):
        schema = model.json_schema()
    else:
        raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {model}")

    # rewrite the schema recursively, in place; `root` lets `$ref`s be resolved
    return _ensure_strict_json_schema(schema, path=(), root=schema)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _ensure_strict_json_schema(
    json_schema: object,
    *,
    path: tuple[str, ...],  # breadcrumb used only for error messages
    root: dict[str, object],  # top-level schema, used to resolve `$ref`s
) -> dict[str, Any]:
    """Mutates the given JSON schema to ensure it conforms to the `strict` standard
    that the API expects.
    """
    if not is_dict(json_schema):
        raise TypeError(f"Expected {json_schema} to be a dictionary; path={path}")

    # recurse into both spellings of schema definitions first
    defs = json_schema.get("$defs")
    if is_dict(defs):
        for def_name, def_schema in defs.items():
            _ensure_strict_json_schema(def_schema, path=(*path, "$defs", def_name), root=root)

    definitions = json_schema.get("definitions")
    if is_dict(definitions):
        for definition_name, definition_schema in definitions.items():
            _ensure_strict_json_schema(definition_schema, path=(*path, "definitions", definition_name), root=root)

    typ = json_schema.get("type")
    if typ == "object" and "additionalProperties" not in json_schema:
        # strict mode requires objects to explicitly forbid unknown keys
        json_schema["additionalProperties"] = False

    # object types
    # { 'type': 'object', 'properties': { 'a': {...} } }
    properties = json_schema.get("properties")
    if is_dict(properties):
        # strict mode requires every property to be marked required
        json_schema["required"] = [prop for prop in properties.keys()]
        json_schema["properties"] = {
            key: _ensure_strict_json_schema(prop_schema, path=(*path, "properties", key), root=root)
            for key, prop_schema in properties.items()
        }

    # arrays
    # { 'type': 'array', 'items': {...} }
    items = json_schema.get("items")
    if is_dict(items):
        json_schema["items"] = _ensure_strict_json_schema(items, path=(*path, "items"), root=root)

    # unions
    any_of = json_schema.get("anyOf")
    if is_list(any_of):
        json_schema["anyOf"] = [
            _ensure_strict_json_schema(variant, path=(*path, "anyOf", str(i)), root=root)
            for i, variant in enumerate(any_of)
        ]

    # intersections
    all_of = json_schema.get("allOf")
    if is_list(all_of):
        if len(all_of) == 1:
            # a single-entry `allOf` is equivalent to inlining that entry
            json_schema.update(_ensure_strict_json_schema(all_of[0], path=(*path, "allOf", "0"), root=root))
            json_schema.pop("allOf")
        else:
            json_schema["allOf"] = [
                _ensure_strict_json_schema(entry, path=(*path, "allOf", str(i)), root=root)
                for i, entry in enumerate(all_of)
            ]

    # strip `None` defaults as there's no meaningful distinction here
    # the schema will still be `nullable` and the model will default
    # to using `None` anyway
    if json_schema.get("default", NOT_GIVEN) is None:
        json_schema.pop("default")

    # we can't use `$ref`s if there are also other properties defined, e.g.
    # `{"$ref": "...", "description": "my description"}`
    #
    # so we unravel the ref
    # `{"type": "string", "description": "my description"}`
    ref = json_schema.get("$ref")
    if ref and has_more_than_n_keys(json_schema, 1):
        assert isinstance(ref, str), f"Received non-string $ref - {ref}"

        resolved = resolve_ref(root=root, ref=ref)
        if not is_dict(resolved):
            raise ValueError(f"Expected `$ref: {ref}` to resolved to a dictionary but got {resolved}")

        # properties from the json schema take priority over the ones on the `$ref`
        json_schema.update({**resolved, **json_schema})
        json_schema.pop("$ref")
        # Since the schema expanded from `$ref` might not have `additionalProperties: false` applied,
        # we call `_ensure_strict_json_schema` again to fix the inlined schema and ensure it's valid.
        return _ensure_strict_json_schema(json_schema, path=path, root=root)

    return json_schema
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def resolve_ref(*, root: dict[str, object], ref: str) -> object:
    """Follow a local JSON-schema `$ref` (e.g. `#/$defs/Foo`) through `root`."""
    if not ref.startswith("#/"):
        raise ValueError(f"Unexpected $ref format {ref!r}; Does not start with #/")

    resolved: object = root
    for segment in ref[2:].split("/"):
        child = resolved[segment]
        assert is_dict(child), f"encountered non-dictionary entry while resolving {ref} - {resolved}"
        resolved = child

    return resolved
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def is_basemodel_type(typ: type) -> TypeGuard[type[pydantic.BaseModel]]:
    """Whether `typ` is a class deriving from `pydantic.BaseModel`."""
    return inspect.isclass(typ) and issubclass(typ, pydantic.BaseModel)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def is_dataclass_like_type(typ: type) -> bool:
    """Returns True if the given type likely used `@pydantic.dataclass`"""
    try:
        typ.__pydantic_config__
    except AttributeError:
        return False
    return True
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def is_dict(obj: object) -> TypeGuard[dict[str, object]]:
    """Narrow `obj` to `dict[str, object]` without inspecting key types."""
    # just pretend that we know there are only `str` keys
    # as that check is not worth the performance cost
    return _is_dict(obj)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def has_more_than_n_keys(obj: dict[str, object], n: int) -> bool:
    """Return True if `obj` contains more than `n` keys.

    Dict `len()` is O(1), so the previous manual counting loop (iterate the
    keys, increment, early-exit) was redundant and has been replaced.
    """
    return len(obj) > n
|
.venv/lib/python3.11/site-packages/openai/lib/_tools.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, cast
|
| 4 |
+
|
| 5 |
+
import pydantic
|
| 6 |
+
|
| 7 |
+
from ._pydantic import to_strict_json_schema
|
| 8 |
+
from ..types.chat import ChatCompletionToolParam
|
| 9 |
+
from ..types.shared_params import FunctionDefinition
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class PydanticFunctionTool(Dict[str, Any]):
    """Dictionary wrapper so we can pass the given base model
    throughout the entire request stack without having to special
    case it.
    """

    # the pydantic model this function definition was generated from
    model: type[pydantic.BaseModel]

    def __init__(self, defn: FunctionDefinition, model: type[pydantic.BaseModel]) -> None:
        super().__init__(defn)
        self.model = model

    def cast(self) -> FunctionDefinition:
        # purely a typing aid - this instance *is* the definition dict
        return cast(FunctionDefinition, self)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def pydantic_function_tool(
    model: type[pydantic.BaseModel],
    *,
    name: str | None = None,  # inferred from class name by default
    description: str | None = None,  # inferred from class docstring by default
) -> ChatCompletionToolParam:
    """Build a `strict` function-tool param from a pydantic model."""
    if description is None:
        # note: we intentionally don't use `.getdoc()` to avoid
        # including pydantic's docstrings
        description = model.__doc__

    function = PydanticFunctionTool(
        {
            "name": name or model.__name__,
            "strict": True,
            "parameters": to_strict_json_schema(model),
        },
        model,
    ).cast()

    if description is not None:
        function["description"] = description

    tool: ChatCompletionToolParam = {
        "type": "function",
        "function": function,
    }
    return tool
|
.venv/lib/python3.11/site-packages/openai/lib/_validators.py
ADDED
|
@@ -0,0 +1,809 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pyright: basic
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
from typing import Any, TypeVar, Callable, Optional, NamedTuple
|
| 7 |
+
from typing_extensions import TypeAlias
|
| 8 |
+
|
| 9 |
+
from .._extras import pandas as pd
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Remediation(NamedTuple):
    """Outcome of a single dataset validation, plus any fixes to apply."""

    # identifier of the validation that produced this remediation
    name: str
    # message shown to the user as soon as the validation runs
    immediate_msg: Optional[str] = None
    # description of a fix that must be applied
    necessary_msg: Optional[str] = None
    # transform applying the necessary fix to the dataframe
    necessary_fn: Optional[Callable[[Any], Any]] = None
    # description of an optional, user-confirmable fix
    optional_msg: Optional[str] = None
    # transform applying the optional fix to the dataframe
    optional_fn: Optional[Callable[[Any], Any]] = None
    # fatal problem that cannot be remediated automatically
    error_msg: Optional[str] = None
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# TypeVar for helpers that accept and return a possibly-None DataFrame
OptionalDataFrameT = TypeVar("OptionalDataFrameT", bound="Optional[pd.DataFrame]")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def num_examples_validator(df: pd.DataFrame) -> Remediation:
    """Report the number of prompt/completion pairs, recommending more when
    the dataset has fewer than 100 examples."""
    MIN_EXAMPLES = 100
    n_examples = len(df)
    if n_examples >= MIN_EXAMPLES:
        optional_suggestion = ""
    else:
        optional_suggestion = ". In general, we recommend having at least a few hundred examples. We've found that performance tends to linearly increase for every doubling of the number of examples"
    return Remediation(
        name="num_examples",
        immediate_msg=f"\n- Your file contains {n_examples} prompt-completion pairs{optional_suggestion}",
    )
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def necessary_column_validator(df: pd.DataFrame, necessary_column: str) -> Remediation:
    """
    This validator will ensure that the necessary column is present in the dataframe.
    """

    def lower_case_column(df: pd.DataFrame, column: Any) -> pd.DataFrame:
        # find the existing column whose lowercased name matches and rename it in place
        cols = [c for c in df.columns if str(c).lower() == column]
        df.rename(columns={cols[0]: column.lower()}, inplace=True)
        return df

    immediate_msg = None
    necessary_fn = None
    necessary_msg = None
    error_msg = None

    if necessary_column not in df.columns:
        if necessary_column in [str(c).lower() for c in df.columns]:
            # the column exists with different casing - offer an automatic rename

            def lower_case_column_creator(df: pd.DataFrame) -> pd.DataFrame:
                return lower_case_column(df, necessary_column)

            necessary_fn = lower_case_column_creator
            immediate_msg = f"\n- The `{necessary_column}` column/key should be lowercase"
            necessary_msg = f"Lower case column name to `{necessary_column}`"
        else:
            # the column is missing entirely - nothing we can fix automatically
            error_msg = f"`{necessary_column}` column/key is missing. Please make sure you name your columns/keys appropriately, then retry"

    return Remediation(
        name="necessary_column",
        immediate_msg=immediate_msg,
        necessary_msg=necessary_msg,
        necessary_fn=necessary_fn,
        error_msg=error_msg,
    )
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def additional_column_validator(df: pd.DataFrame, fields: list[str] | None = None) -> Remediation:
    """
    This validator will remove additional columns from the dataframe.

    `fields` defaults to `["prompt", "completion"]`; the default is supplied
    via `None` to avoid the shared-mutable-default-argument pitfall.
    """
    fields = ["prompt", "completion"] if fields is None else fields

    additional_columns = []
    necessary_msg = None
    immediate_msg = None
    necessary_fn = None  # type: ignore

    if len(df.columns) > 2:
        additional_columns = [c for c in df.columns if c not in fields]
        warn_message = ""
        for ac in additional_columns:
            # NOTE(review): `ac in c` matches `c == ac` itself, so `dups` is never
            # empty and this warning fires for every additional column - confirm
            # whether the intent was to compare against the necessary `fields`
            dups = [c for c in additional_columns if ac in c]
            if len(dups) > 0:
                warn_message += f"\n WARNING: Some of the additional columns/keys contain `{ac}` in their name. These will be ignored, and the column/key `{ac}` will be used instead. This could also result from a duplicate column/key in the provided file."
        immediate_msg = f"\n- The input file should contain exactly two columns/keys per row. Additional columns/keys present are: {additional_columns}{warn_message}"
        necessary_msg = f"Remove additional columns/keys: {additional_columns}"

        def necessary_fn(x: Any) -> Any:
            # keep only the necessary columns, in the given order
            return x[fields]

    return Remediation(
        name="additional_column",
        immediate_msg=immediate_msg,
        necessary_msg=necessary_msg,
        necessary_fn=necessary_fn,
    )
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def non_empty_field_validator(df: pd.DataFrame, field: str = "completion") -> Remediation:
    """
    This validator will ensure that no completion is empty.
    """
    necessary_msg = None
    necessary_fn = None  # type: ignore
    immediate_msg = None

    if df[field].apply(lambda x: x == "").any() or df[field].isnull().any():
        # mask of rows whose field is an empty string or a missing value
        empty_rows = (df[field] == "") | (df[field].isnull())
        # positional indices so the message matches the row numbers users see
        empty_indexes = df.reset_index().index[empty_rows].tolist()
        immediate_msg = f"\n- `{field}` column/key should not contain empty strings. These are rows: {empty_indexes}"

        def necessary_fn(x: Any) -> Any:
            # drop empty-string rows, then rows with missing values
            return x[x[field] != ""].dropna(subset=[field])

        necessary_msg = f"Remove {len(empty_indexes)} rows with empty {field}s"

    return Remediation(
        name=f"empty_{field}",
        immediate_msg=immediate_msg,
        necessary_msg=necessary_msg,
        necessary_fn=necessary_fn,
    )
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def duplicated_rows_validator(df: pd.DataFrame, fields: list[str] | None = None) -> Remediation:
    """
    This validator will suggest to the user to remove duplicate rows if they exist.

    `fields` defaults to `["prompt", "completion"]`; the default is supplied
    via `None` to avoid the shared-mutable-default-argument pitfall.
    """
    fields = ["prompt", "completion"] if fields is None else fields

    duplicated_rows = df.duplicated(subset=fields)
    # positional indices so the message matches the row numbers users see
    duplicated_indexes = df.reset_index().index[duplicated_rows].tolist()
    immediate_msg = None
    optional_msg = None
    optional_fn = None  # type: ignore

    if len(duplicated_indexes) > 0:
        immediate_msg = f"\n- There are {len(duplicated_indexes)} duplicated {'-'.join(fields)} sets. These are rows: {duplicated_indexes}"
        optional_msg = f"Remove {len(duplicated_indexes)} duplicate rows"

        def optional_fn(x: Any) -> Any:
            return x.drop_duplicates(subset=fields)

    return Remediation(
        name="duplicated_rows",
        immediate_msg=immediate_msg,
        optional_msg=optional_msg,
        optional_fn=optional_fn,
    )
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def long_examples_validator(df: pd.DataFrame) -> Remediation:
    """
    This validator will suggest to the user to remove examples that are too long.
    """
    immediate_msg = None
    optional_msg = None
    optional_fn = None  # type: ignore

    ft_type = infer_task_type(df)
    if ft_type != "open-ended generation":

        def get_long_indexes(d: pd.DataFrame) -> Any:
            # 10000 characters is used as a rough character-count threshold
            long_examples = d.apply(lambda x: len(x.prompt) + len(x.completion) > 10000, axis=1)
            return d.reset_index().index[long_examples].tolist()

        long_indexes = get_long_indexes(df)

        if len(long_indexes) > 0:
            immediate_msg = f"\n- There are {len(long_indexes)} examples that are very long. These are rows: {long_indexes}\nFor conditional generation, and for classification the examples shouldn't be longer than 2048 tokens."
            optional_msg = f"Remove {len(long_indexes)} long examples"

            def optional_fn(x: Any) -> Any:
                # recompute because earlier remediations may have shifted row indices
                long_indexes_to_drop = get_long_indexes(x)
                if long_indexes != long_indexes_to_drop:
                    sys.stdout.write(
                        f"The indices of the long examples has changed as a result of a previously applied recommendation.\nThe {len(long_indexes_to_drop)} long examples to be dropped are now at the following indices: {long_indexes_to_drop}\n"
                    )
                return x.drop(long_indexes_to_drop)

    return Remediation(
        name="long_examples",
        immediate_msg=immediate_msg,
        optional_msg=optional_msg,
        optional_fn=optional_fn,
    )
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def common_prompt_suffix_validator(df: pd.DataFrame) -> Remediation:
    """
    This validator will suggest to add a common suffix to the prompt if one doesn't already exist in case of classification or conditional generation.
    """
    error_msg = None
    immediate_msg = None
    optional_msg = None
    optional_fn = None  # type: ignore

    # Find a suffix which is not contained within the prompt otherwise
    suggested_suffix = "\n\n### =>\n\n"
    suffix_options = [
        " ->",
        "\n\n###\n\n",
        "\n\n===\n\n",
        "\n\n---\n\n",
        "\n\n===>\n\n",
        "\n\n--->\n\n",
    ]
    for suffix_option in suffix_options:
        if suffix_option == " ->":
            # ` ->` is only unambiguous when prompts are single-line
            if df.prompt.str.contains("\n").any():
                continue
        if df.prompt.str.contains(suffix_option, regex=False).any():
            continue
        suggested_suffix = suffix_option
        break
    display_suggested_suffix = suggested_suffix.replace("\n", "\\n")

    ft_type = infer_task_type(df)
    if ft_type == "open-ended generation":
        # open-ended generation does not need a separator at all
        return Remediation(name="common_suffix")

    def add_suffix(x: Any, suffix: Any) -> Any:
        x["prompt"] += suffix
        return x

    common_suffix = get_common_xfix(df.prompt, xfix="suffix")
    if (df.prompt == common_suffix).all():
        # every prompt is the entire common suffix, i.e. all prompts are identical
        error_msg = f"All prompts are identical: `{common_suffix}`\nConsider leaving the prompts blank if you want to do open-ended generation, otherwise ensure prompts are different"
        return Remediation(name="common_suffix", error_msg=error_msg)

    if common_suffix != "":
        common_suffix_new_line_handled = common_suffix.replace("\n", "\\n")
        immediate_msg = f"\n- All prompts end with suffix `{common_suffix_new_line_handled}`"
        if len(common_suffix) > 10:
            immediate_msg += f". This suffix seems very long. Consider replacing with a shorter suffix, such as `{display_suggested_suffix}`"
        # check whether the suffix also appears *before* the final occurrence
        if df.prompt.str[: -len(common_suffix)].str.contains(common_suffix, regex=False).any():
            immediate_msg += f"\n WARNING: Some of your prompts contain the suffix `{common_suffix}` more than once. We strongly suggest that you review your prompts and add a unique suffix"

    else:
        immediate_msg = "\n- Your data does not contain a common separator at the end of your prompts. Having a separator string appended to the end of the prompt makes it clearer to the fine-tuned model where the completion should begin. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples. If you intend to do open-ended generation, then you should leave the prompts empty"

    if common_suffix == "":
        optional_msg = f"Add a suffix separator `{display_suggested_suffix}` to all prompts"

        def optional_fn(x: Any) -> Any:
            return add_suffix(x, suggested_suffix)

    return Remediation(
        name="common_completion_suffix",
        immediate_msg=immediate_msg,
        optional_msg=optional_msg,
        optional_fn=optional_fn,
        error_msg=error_msg,
    )
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def common_prompt_prefix_validator(df: pd.DataFrame) -> Remediation:
    """
    This validator will suggest to remove a common prefix from the prompt if a long one exist.
    """
    MAX_PREFIX_LEN = 12

    immediate_msg = None
    optional_msg = None
    optional_fn = None  # type: ignore

    common_prefix = get_common_xfix(df.prompt, xfix="prefix")
    if common_prefix == "":
        return Remediation(name="common_prefix")

    def remove_common_prefix(x: Any, prefix: Any) -> Any:
        x["prompt"] = x["prompt"].str[len(prefix) :]
        return x

    if (df.prompt == common_prefix).all():
        # already handled by common_suffix_validator
        return Remediation(name="common_prefix")

    # NOTE(review): this condition is always True at this point - the empty-prefix
    # case already returned above
    if common_prefix != "":
        immediate_msg = f"\n- All prompts start with prefix `{common_prefix}`"
        if MAX_PREFIX_LEN < len(common_prefix):
            immediate_msg += ". Fine-tuning doesn't require the instruction specifying the task, or a few-shot example scenario. Most of the time you should only add the input data into the prompt, and the desired output into the completion"
        optional_msg = f"Remove prefix `{common_prefix}` from all prompts"

        def optional_fn(x: Any) -> Any:
            return remove_common_prefix(x, common_prefix)

    return Remediation(
        name="common_prompt_prefix",
        immediate_msg=immediate_msg,
        optional_msg=optional_msg,
        optional_fn=optional_fn,
    )
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def common_completion_prefix_validator(df: pd.DataFrame) -> Remediation:
    """
    This validator will suggest to remove a common prefix from the completion if a long one exists.
    """
    MAX_PREFIX_LEN = 5  # only flag prefixes at least this long

    common_prefix = get_common_xfix(df.completion, xfix="prefix")
    # remember whether the shared prefix begins with a space so we can keep it
    ws_prefix = len(common_prefix) > 0 and common_prefix[0] == " "
    if len(common_prefix) < MAX_PREFIX_LEN:
        return Remediation(name="common_prefix")

    def remove_common_prefix(x: Any, prefix: Any, ws_prefix: Any) -> Any:
        x["completion"] = x["completion"].str[len(prefix) :]
        if ws_prefix:
            # keep the single whitespace as prefix
            # BUG FIX: the original `x["completion"] = f" {x['completion']}"`
            # formatted the entire pandas Series into one string; prepend the
            # space element-wise instead.
            x["completion"] = " " + x["completion"]
        return x

    if (df.completion == common_prefix).all():
        # already handled by common_suffix_validator
        return Remediation(name="common_prefix")

    immediate_msg = f"\n- All completions start with prefix `{common_prefix}`. Most of the time you should only add the output data into the completion, without any prefix"
    optional_msg = f"Remove prefix `{common_prefix}` from all completions"

    def optional_fn(x: Any) -> Any:
        return remove_common_prefix(x, common_prefix, ws_prefix)

    return Remediation(
        name="common_completion_prefix",
        immediate_msg=immediate_msg,
        optional_msg=optional_msg,
        optional_fn=optional_fn,
    )
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def common_completion_suffix_validator(df: pd.DataFrame) -> Remediation:
    """
    This validator will suggest to add a common suffix to the completion if one doesn't already exist in case of classification or conditional generation.
    """
    error_msg = None
    immediate_msg = None
    optional_msg = None
    optional_fn = None  # type: ignore

    # Only conditional generation benefits from an explicit stop sequence.
    ft_type = infer_task_type(df)
    if ft_type in ("open-ended generation", "classification"):
        return Remediation(name="common_suffix")

    common_suffix = get_common_xfix(df.completion, xfix="suffix")
    if (df.completion == common_suffix).all():
        error_msg = f"All completions are identical: `{common_suffix}`\nEnsure completions are different, otherwise the model will just repeat `{common_suffix}`"
        return Remediation(name="common_suffix", error_msg=error_msg)

    # Pick the first candidate ending that never appears inside a completion;
    # fall back to " [END]" when every candidate occurs somewhere.
    candidates = ["\n", ".", " END", "***", "+++", "&&&", "$$$", "@@@", "%%%"]
    suggested_suffix = next(
        (option for option in candidates if not df.completion.str.contains(option, regex=False).any()),
        " [END]",
    )
    display_suggested_suffix = suggested_suffix.replace("\n", "\\n")

    def add_suffix(x: Any, suffix: Any) -> Any:
        x["completion"] += suffix
        return x

    if common_suffix != "":
        escaped_suffix = common_suffix.replace("\n", "\\n")
        immediate_msg = f"\n- All completions end with suffix `{escaped_suffix}`"
        if len(common_suffix) > 10:
            immediate_msg += f". This suffix seems very long. Consider replacing with a shorter suffix, such as `{display_suggested_suffix}`"
        if df.completion.str[: -len(common_suffix)].str.contains(common_suffix, regex=False).any():
            immediate_msg += f"\n WARNING: Some of your completions contain the suffix `{common_suffix}` more than once. We suggest that you review your completions and add a unique ending"
    else:
        immediate_msg = "\n- Your data does not contain a common ending at the end of your completions. Having a common ending string appended to the end of the completion makes it clearer to the fine-tuned model where the completion should end. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples."
        optional_msg = f"Add a suffix ending `{display_suggested_suffix}` to all completions"

        def optional_fn(x: Any) -> Any:
            return add_suffix(x, suggested_suffix)

    return Remediation(
        name="common_completion_suffix",
        immediate_msg=immediate_msg,
        optional_msg=optional_msg,
        optional_fn=optional_fn,
        error_msg=error_msg,
    )
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def completions_space_start_validator(df: pd.DataFrame) -> Remediation:
    """
    This validator will suggest to add a space at the start of the completion if it doesn't already exist. This helps with tokenization.
    """

    def prepend_space(frame: Any) -> Any:
        # Element-wise: add a leading space only where one is missing.
        frame["completion"] = frame["completion"].apply(lambda text: ("" if text.startswith(" ") else " ") + text)
        return frame

    immediate_msg = None
    optional_msg = None
    optional_fn = None

    # Trigger when the first characters differ across rows, or when the single
    # shared first character is not a space.
    if df.completion.str[:1].nunique() != 1 or df.completion.values[0][0] != " ":
        immediate_msg = "\n- The completion should start with a whitespace character (` `). This tends to produce better results due to the tokenization we use. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details"
        optional_msg = "Add a whitespace character to the beginning of the completion"
        optional_fn = prepend_space
    return Remediation(
        name="completion_space_start",
        immediate_msg=immediate_msg,
        optional_msg=optional_msg,
        optional_fn=optional_fn,
    )
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
def lower_case_validator(df: pd.DataFrame, column: Any) -> Remediation | None:
    """
    This validator will suggest to lowercase the column values, if more than a third of letters are uppercase.
    """

    def lower_case(x: Any) -> Any:
        x[column] = x[column].str.lower()
        return x

    # Count upper/lower alphabetic characters across the whole column.
    upper_total = df[column].apply(lambda text: sum(ch.isalpha() and ch.isupper() for ch in text)).sum()
    lower_total = df[column].apply(lambda text: sum(ch.isalpha() and ch.islower() for ch in text)).sum()

    if upper_total * 2 <= lower_total:
        return None
    return Remediation(
        name="lower_case",
        immediate_msg=f"\n- More than a third of your `{column}` column/key is uppercase. Uppercase {column}s tends to perform worse than a mixture of case encountered in normal language. We recommend to lower case the data if that makes sense in your domain. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details",
        optional_msg=f"Lowercase all your data in column/key `{column}`",
        optional_fn=lower_case,
    )
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def read_any_format(
    fname: str, fields: list[str] = ["prompt", "completion"]
) -> tuple[pd.DataFrame | None, Remediation]:
    """
    This function will read a file saved in .csv, .json, .txt, .xlsx or .tsv format using pandas.
    - for .xlsx it will read the first sheet
    - for .txt it will assume completions and split on newline

    Returns the parsed dataframe (or None on failure) together with a
    Remediation describing any conversion messages or errors.
    """
    remediation = None
    necessary_msg = None
    immediate_msg = None
    error_msg = None
    df = None

    if os.path.isfile(fname):
        try:
            # Extension dispatch is case-insensitive throughout.
            if fname.lower().endswith(".csv") or fname.lower().endswith(".tsv"):
                file_extension_str, separator = ("CSV", ",") if fname.lower().endswith(".csv") else ("TSV", "\t")
                immediate_msg = (
                    f"\n- Based on your file extension, your file is formatted as a {file_extension_str} file"
                )
                necessary_msg = f"Your format `{file_extension_str}` will be converted to `JSONL`"
                # dtype=str + fillna("") normalizes every cell to a string.
                df = pd.read_csv(fname, sep=separator, dtype=str).fillna("")
            elif fname.lower().endswith(".xlsx"):
                immediate_msg = "\n- Based on your file extension, your file is formatted as an Excel file"
                necessary_msg = "Your format `XLSX` will be converted to `JSONL`"
                xls = pd.ExcelFile(fname)
                sheets = xls.sheet_names
                if len(sheets) > 1:
                    immediate_msg += "\n- Your Excel file contains more than one sheet. Please either save as csv or ensure all data is present in the first sheet. WARNING: Reading only the first sheet..."
                # read_excel defaults to the first sheet only.
                df = pd.read_excel(fname, dtype=str).fillna("")
            elif fname.lower().endswith(".txt"):
                immediate_msg = "\n- Based on your file extension, you provided a text file"
                necessary_msg = "Your format `TXT` will be converted to `JSONL`"
                # Each line becomes a completion with an empty prompt.
                with open(fname, "r") as f:
                    content = f.read()
                    df = pd.DataFrame(
                        [["", line] for line in content.split("\n")],
                        columns=fields,
                        dtype=str,
                    ).fillna("")
            elif fname.lower().endswith(".jsonl"):
                df = pd.read_json(fname, lines=True, dtype=str).fillna("")  # type: ignore
                if len(df) == 1:  # type: ignore
                    # this is NOT what we expect for a .jsonl file
                    immediate_msg = "\n- Your JSONL file appears to be in a JSON format. Your file will be converted to JSONL format"
                    necessary_msg = "Your format `JSON` will be converted to `JSONL`"
                    df = pd.read_json(fname, dtype=str).fillna("")  # type: ignore
                else:
                    pass  # this is what we expect for a .jsonl file
            elif fname.lower().endswith(".json"):
                try:
                    # to handle case where .json file is actually a .jsonl file
                    df = pd.read_json(fname, lines=True, dtype=str).fillna("")  # type: ignore
                    if len(df) == 1:  # type: ignore
                        # this code path corresponds to a .json file that has one line
                        df = pd.read_json(fname, dtype=str).fillna("")  # type: ignore
                    else:
                        # this is NOT what we expect for a .json file
                        immediate_msg = "\n- Your JSON file appears to be in a JSONL format. Your file will be converted to JSONL format"
                        necessary_msg = "Your format `JSON` will be converted to `JSONL`"
                except ValueError:
                    # this code path corresponds to a .json file that has multiple lines (i.e. it is indented)
                    df = pd.read_json(fname, dtype=str).fillna("")  # type: ignore
            else:
                error_msg = (
                    "Your file must have one of the following extensions: .CSV, .TSV, .XLSX, .TXT, .JSON or .JSONL"
                )
                if "." in fname:
                    error_msg += f" Your file `{fname}` ends with the extension `.{fname.split('.')[-1]}` which is not supported."
                else:
                    error_msg += f" Your file `{fname}` is missing a file extension."

        except (ValueError, TypeError):
            # Any pandas parse failure is reported against the claimed extension.
            file_extension_str = fname.split(".")[-1].upper()
            error_msg = f"Your file `{fname}` does not appear to be in valid {file_extension_str} format. Please ensure your file is formatted as a valid {file_extension_str} file."

    else:
        error_msg = f"File {fname} does not exist."

    remediation = Remediation(
        name="read_any_format",
        necessary_msg=necessary_msg,
        immediate_msg=immediate_msg,
        error_msg=error_msg,
    )
    return df, remediation
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
def format_inferrer_validator(df: pd.DataFrame) -> Remediation:
    """
    This validator will infer the likely fine-tuning format of the data, and display it to the user if it is classification.
    It will also suggest to use ada and explain train/validation split benefits.
    """
    task = infer_task_type(df)
    message = None
    if task == "classification":
        message = f"\n- Based on your data it seems like you're trying to fine-tune a model for {task}\n- For classification, we recommend you try one of the faster and cheaper models, such as `ada`\n- For classification, you can estimate the expected model performance by keeping a held out dataset, which is not used for training"
    return Remediation(name="num_examples", immediate_msg=message)
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
def apply_necessary_remediation(df: OptionalDataFrameT, remediation: Remediation) -> OptionalDataFrameT:
    """
    This function will apply a necessary remediation to a dataframe, or print an error message if one exists.
    """
    # A fatal validation error aborts the whole run immediately.
    if remediation.error_msg is not None:
        sys.stderr.write(f"\n\nERROR in {remediation.name} validator: {remediation.error_msg}\n\nAborting...")
        sys.exit(1)
    # Informational messages go straight to stdout.
    if remediation.immediate_msg is not None:
        sys.stdout.write(remediation.immediate_msg)
    # Mandatory transformations are applied without asking the user.
    transform = remediation.necessary_fn
    if transform is not None:
        df = transform(df)
    return df
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
def accept_suggestion(input_text: str, auto_accept: bool) -> bool:
    """Show a prompt and return True unless the user explicitly answers "n".

    With auto_accept the question is answered "Y" without reading stdin.
    """
    sys.stdout.write(input_text)
    if not auto_accept:
        return input().lower() != "n"
    sys.stdout.write("Y\n")
    return True
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
def apply_optional_remediation(
    df: pd.DataFrame, remediation: Remediation, auto_accept: bool
) -> tuple[pd.DataFrame, bool]:
    """
    This function will apply an optional remediation to a dataframe, based on the user input.
    """
    optional_applied = False
    input_text = f"- [Recommended] {remediation.optional_msg} [Y/n]: "
    # Only ask (and apply) when there actually is an optional suggestion.
    if remediation.optional_msg is not None and accept_suggestion(input_text, auto_accept):
        assert remediation.optional_fn is not None
        df = remediation.optional_fn(df)
        optional_applied = True
    # Necessary remediations were already applied; just report them here.
    if remediation.necessary_msg is not None:
        sys.stdout.write(f"- [Necessary] {remediation.necessary_msg}\n")
    return df, optional_applied
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
def estimate_fine_tuning_time(df: pd.DataFrame) -> None:
    """
    Estimate the time it'll take to fine-tune the dataset
    """
    task = infer_task_type(df)
    if task == "classification":
        # classification cost scales with the number of examples
        expected_time = len(df) * 1.44
    else:
        # otherwise scale with the in-memory size of the dataframe
        expected_time = df.memory_usage(index=True).sum() * 0.0515

    def format_time(time: float) -> str:
        # Pick the largest unit that keeps the number below 1 of the next unit.
        if time < 60:
            return f"{round(time, 2)} seconds"
        if time < 3600:
            return f"{round(time / 60, 2)} minutes"
        if time < 86400:
            return f"{round(time / 3600, 2)} hours"
        return f"{round(time / 86400, 2)} days"

    # 140s is a fixed overhead added on top of the per-dataset estimate.
    time_string = format_time(expected_time + 140)
    sys.stdout.write(
        f"Once your model starts training, it'll approximately take {time_string} to train a `curie` model, and less for `ada` and `babbage`. Queue will approximately take half an hour per job ahead of you.\n"
    )
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
def get_outfnames(fname: str, split: bool) -> list[str]:
    """Return output filenames derived from fname that do not clobber existing files.

    When split is True a train/valid pair is produced; a " (i)" marker is
    appended until no candidate exists on disk.
    """
    parts = ["_train", "_valid"] if split else [""]
    base = os.path.splitext(fname)[0]
    attempt = 0
    while True:
        marker = f" ({attempt})" if attempt > 0 else ""
        candidates = [f"{base}_prepared{part}{marker}.jsonl" for part in parts]
        if not any(os.path.isfile(path) for path in candidates):
            return candidates
        attempt += 1
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
def get_classification_hyperparams(df: pd.DataFrame) -> tuple[int, object]:
    """Return (number of classes, positive class for binary tasks).

    The positive class is the most frequent completion when there are exactly
    two classes; otherwise None.
    """
    n_classes = df.completion.nunique()
    pos_class = df.completion.value_counts().index[0] if n_classes == 2 else None
    return n_classes, pos_class
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
def write_out_file(df: pd.DataFrame, fname: str, any_remediations: bool, auto_accept: bool) -> None:
    """
    This function will write out a dataframe to a file, if the user would like to proceed, and also offer a fine-tuning command with the newly created file.
    For classification it will optionally ask the user if they would like to split the data into train/valid files, and modify the suggested command to include the valid set.
    """
    ft_format = infer_task_type(df)
    common_prompt_suffix = get_common_xfix(df.prompt, xfix="suffix")
    common_completion_suffix = get_common_xfix(df.completion, xfix="suffix")

    # Offer a train/valid split only for classification tasks.
    split = False
    input_text = "- [Recommended] Would you like to split into training and validation set? [Y/n]: "
    if ft_format == "classification":
        if accept_suggestion(input_text, auto_accept):
            split = True

    additional_params = ""
    # Escape literal newlines so the suffixes render on one line in messages.
    common_prompt_suffix_new_line_handled = common_prompt_suffix.replace("\n", "\\n")
    common_completion_suffix_new_line_handled = common_completion_suffix.replace("\n", "\\n")
    optional_ending_string = (
        f' Make sure to include `stop=["{common_completion_suffix_new_line_handled}"]` so that the generated texts ends at the expected place.'
        if len(common_completion_suffix_new_line_handled) > 0
        else ""
    )

    input_text = "\n\nYour data will be written to a new JSONL file. Proceed [Y/n]: "

    if not any_remediations and not split:
        # Nothing to change: keep the original file and just print the command.
        sys.stdout.write(
            f'\nYou can use your file for fine-tuning:\n> openai api fine_tunes.create -t "{fname}"{additional_params}\n\nAfter you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt.{optional_ending_string}\n'
        )
        estimate_fine_tuning_time(df)

    elif accept_suggestion(input_text, auto_accept):
        fnames = get_outfnames(fname, split)
        if split:
            assert len(fnames) == 2 and "train" in fnames[0] and "valid" in fnames[1]
            # Cap the validation set at 1000 examples (at most 20% of the data).
            MAX_VALID_EXAMPLES = 1000
            n_train = max(len(df) - MAX_VALID_EXAMPLES, int(len(df) * 0.8))
            # random_state=42 keeps the split reproducible across runs.
            df_train = df.sample(n=n_train, random_state=42)
            df_valid = df.drop(df_train.index)
            df_train[["prompt", "completion"]].to_json(  # type: ignore
                fnames[0], lines=True, orient="records", force_ascii=False, indent=None
            )
            df_valid[["prompt", "completion"]].to_json(
                fnames[1], lines=True, orient="records", force_ascii=False, indent=None
            )

            n_classes, pos_class = get_classification_hyperparams(df)
            additional_params += " --compute_classification_metrics"
            if n_classes == 2:
                additional_params += f' --classification_positive_class "{pos_class}"'
            else:
                additional_params += f" --classification_n_classes {n_classes}"
        else:
            assert len(fnames) == 1
            df[["prompt", "completion"]].to_json(
                fnames[0], lines=True, orient="records", force_ascii=False, indent=None
            )

        # Add -v VALID_FILE if we split the file into train / valid
        files_string = ("s" if split else "") + " to `" + ("` and `".join(fnames))
        valid_string = f' -v "{fnames[1]}"' if split else ""
        separator_reminder = (
            ""
            if len(common_prompt_suffix_new_line_handled) == 0
            else f"After you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt."
        )
        sys.stdout.write(
            f'\nWrote modified file{files_string}`\nFeel free to take a look!\n\nNow use that file when fine-tuning:\n> openai api fine_tunes.create -t "{fnames[0]}"{valid_string}{additional_params}\n\n{separator_reminder}{optional_ending_string}\n'
        )
        estimate_fine_tuning_time(df)
    else:
        sys.stdout.write("Aborting... did not write the file\n")
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
def infer_task_type(df: pd.DataFrame) -> str:
    """
    Infer the likely fine-tuning task type from the data
    """
    CLASSIFICATION_THRESHOLD = 3  # min_average instances of each class

    # Empty prompts everywhere -> the model is trained for free generation.
    if sum(df.prompt.str.len()) == 0:
        return "open-ended generation"
    # Few distinct completions relative to row count -> classification.
    if len(df.completion.unique()) < len(df) / CLASSIFICATION_THRESHOLD:
        return "classification"
    return "conditional generation"
|
| 727 |
+
|
| 728 |
+
|
| 729 |
+
def get_common_xfix(series: Any, xfix: str = "suffix") -> str:
    """
    Finds the longest common suffix or prefix of all the values in a series
    """
    common_xfix = ""
    while True:
        # Grow the candidate by one character and check it is still shared.
        length = len(common_xfix) + 1
        candidates = series.str[-length:] if xfix == "suffix" else series.str[:length]
        if candidates.nunique() != 1:
            # values diverge at this character: the previous xfix is the answer
            break
        if common_xfix == candidates.values[0]:
            # the shortest value is exhausted; it is an xfix of every row
            break
        common_xfix = candidates.values[0]
    return common_xfix
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
# A validator inspects a dataframe and may return a Remediation describing
# messages and optional/necessary fixes (None means nothing to report).
Validator: TypeAlias = "Callable[[pd.DataFrame], Remediation | None]"
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
def get_validators() -> list[Validator]:
    """Return the ordered list of validators applied to a prepared dataframe."""
    # Structural checks: required columns and non-empty fields.
    column_checks: list[Validator] = [
        num_examples_validator,
        lambda frame: necessary_column_validator(frame, "prompt"),
        lambda frame: necessary_column_validator(frame, "completion"),
        additional_column_validator,
        non_empty_field_validator,
    ]
    # Content checks: task inference, duplicates, length and casing.
    content_checks: list[Validator] = [
        format_inferrer_validator,
        duplicated_rows_validator,
        long_examples_validator,
        lambda frame: lower_case_validator(frame, "prompt"),
        lambda frame: lower_case_validator(frame, "completion"),
    ]
    # Prefix/suffix checks on prompts and completions.
    xfix_checks: list[Validator] = [
        common_prompt_suffix_validator,
        common_prompt_prefix_validator,
        common_completion_prefix_validator,
        common_completion_suffix_validator,
        completions_space_start_validator,
    ]
    return column_checks + content_checks + xfix_checks
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
def apply_validators(
    df: pd.DataFrame,
    fname: str,
    remediation: Remediation | None,
    validators: list[Validator],
    auto_accept: bool,
    write_out_file_func: Callable[..., Any],
) -> None:
    """Run all validators over df, apply remediations, then write the result.

    Necessary remediations are applied immediately as each validator runs;
    optional ones are offered to the user afterwards. Finally
    write_out_file_func is invoked with whether anything was changed.
    """
    optional_remediations: list[Remediation] = []
    if remediation is not None:
        # e.g. the remediation produced by read_any_format
        optional_remediations.append(remediation)
    for validator in validators:
        remediation = validator(df)
        if remediation is not None:
            optional_remediations.append(remediation)
            # necessary fixes mutate df right away (and may abort on error)
            df = apply_necessary_remediation(df, remediation)

    any_optional_or_necessary_remediations = any(
        [
            remediation
            for remediation in optional_remediations
            if remediation.optional_msg is not None or remediation.necessary_msg is not None
        ]
    )
    any_necessary_applied = any(
        [remediation for remediation in optional_remediations if remediation.necessary_msg is not None]
    )
    any_optional_applied = False

    if any_optional_or_necessary_remediations:
        sys.stdout.write("\n\nBased on the analysis we will perform the following actions:\n")
        for remediation in optional_remediations:
            # asks the user (unless auto_accept) before applying each one
            df, optional_applied = apply_optional_remediation(df, remediation, auto_accept)
            any_optional_applied = any_optional_applied or optional_applied
    else:
        sys.stdout.write("\n\nNo remediations found.\n")

    any_optional_or_necessary_applied = any_optional_applied or any_necessary_applied

    write_out_file_func(df, fname, any_optional_or_necessary_applied, auto_accept)
|
.venv/lib/python3.11/site-packages/openai/lib/azure.py
ADDED
|
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import inspect
|
| 5 |
+
from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, cast, overload
|
| 6 |
+
from typing_extensions import Self, override
|
| 7 |
+
|
| 8 |
+
import httpx
|
| 9 |
+
|
| 10 |
+
from .._types import NOT_GIVEN, Omit, Query, Timeout, NotGiven
|
| 11 |
+
from .._utils import is_given, is_mapping
|
| 12 |
+
from .._client import OpenAI, AsyncOpenAI
|
| 13 |
+
from .._compat import model_copy
|
| 14 |
+
from .._models import FinalRequestOptions
|
| 15 |
+
from .._streaming import Stream, AsyncStream
|
| 16 |
+
from .._exceptions import OpenAIError
|
| 17 |
+
from .._base_client import DEFAULT_MAX_RETRIES, BaseClient
|
| 18 |
+
|
| 19 |
+
# Endpoints that are scoped to an Azure deployment: `BaseAzureClient` rewrites
# their URL to `/deployments/{model}{url}` when a model is present in the body.
# (idiom fix: set literal instead of `set([...])` — flake8-comprehensions C405)
_deployments_endpoints = {
    "/completions",
    "/chat/completions",
    "/embeddings",
    "/audio/transcriptions",
    "/audio/translations",
    "/audio/speech",
    "/images/generations",
}


# Callables that supply an Azure AD token; the async variant may also return
# an awaitable resolving to the token.
AzureADTokenProvider = Callable[[], str]
AsyncAzureADTokenProvider = Callable[[], "str | Awaitable[str]"]
_HttpxClientT = TypeVar("_HttpxClientT", bound=Union[httpx.Client, httpx.AsyncClient])
_DefaultStreamT = TypeVar("_DefaultStreamT", bound=Union[Stream[Any], AsyncStream[Any]])


# we need to use a sentinel API key value for Azure AD
# as we don't want to make the `api_key` in the main client Optional
# and Azure AD tokens may be retrieved on a per-request basis
API_KEY_SENTINEL = "".join(["<", "missing API key", ">"])
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class MutuallyExclusiveAuthError(OpenAIError):
    """Raised when more than one of the mutually exclusive auth options is supplied."""

    def __init__(self) -> None:
        message = "The `api_key`, `azure_ad_token` and `azure_ad_token_provider` arguments are mutually exclusive; Only one can be passed at a time"
        super().__init__(message)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class BaseAzureClient(BaseClient[_HttpxClientT, _DefaultStreamT]):
    """Shared request-building behaviour for the sync and async Azure clients."""

    @override
    def _build_request(
        self,
        options: FinalRequestOptions,
        *,
        retries_taken: int = 0,
    ) -> httpx.Request:
        """Rewrite deployment-aware endpoints to route to the payload's model deployment.

        When the request targets one of ``_deployments_endpoints`` and the JSON
        payload names a ``model``, the URL is prefixed with ``/deployments/{model}``
        -- unless the client's base URL already pins a specific deployment, in
        which case prefixing again would produce a doubled path segment.
        """
        if options.url in _deployments_endpoints and is_mapping(options.json_data):
            model = options.json_data.get("model")
            # idiomatic `not in` (was the non-idiomatic `not "/deployments" in ...`)
            if model is not None and "/deployments" not in str(self.base_url):
                options.url = f"/deployments/{model}{options.url}"

        return super()._build_request(options, retries_taken=retries_taken)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
    """Synchronous client for the Azure OpenAI service.

    Authenticates with either an API key, a static Azure AD token, or an
    Azure AD token provider callable (mutually exclusive -- see ``__init__``).
    """

    @overload
    def __init__(
        self,
        *,
        azure_endpoint: str,
        azure_deployment: str | None = None,
        api_version: str | None = None,
        api_key: str | None = None,
        azure_ad_token: str | None = None,
        azure_ad_token_provider: AzureADTokenProvider | None = None,
        organization: str | None = None,
        websocket_base_url: str | httpx.URL | None = None,
        timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
        max_retries: int = DEFAULT_MAX_RETRIES,
        default_headers: Mapping[str, str] | None = None,
        default_query: Mapping[str, object] | None = None,
        http_client: httpx.Client | None = None,
        _strict_response_validation: bool = False,
    ) -> None: ...

    @overload
    def __init__(
        self,
        *,
        azure_deployment: str | None = None,
        api_version: str | None = None,
        api_key: str | None = None,
        azure_ad_token: str | None = None,
        azure_ad_token_provider: AzureADTokenProvider | None = None,
        organization: str | None = None,
        websocket_base_url: str | httpx.URL | None = None,
        timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
        max_retries: int = DEFAULT_MAX_RETRIES,
        default_headers: Mapping[str, str] | None = None,
        default_query: Mapping[str, object] | None = None,
        http_client: httpx.Client | None = None,
        _strict_response_validation: bool = False,
    ) -> None: ...

    @overload
    def __init__(
        self,
        *,
        base_url: str,
        api_version: str | None = None,
        api_key: str | None = None,
        azure_ad_token: str | None = None,
        azure_ad_token_provider: AzureADTokenProvider | None = None,
        organization: str | None = None,
        websocket_base_url: str | httpx.URL | None = None,
        timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
        max_retries: int = DEFAULT_MAX_RETRIES,
        default_headers: Mapping[str, str] | None = None,
        default_query: Mapping[str, object] | None = None,
        http_client: httpx.Client | None = None,
        _strict_response_validation: bool = False,
    ) -> None: ...

    def __init__(
        self,
        *,
        api_version: str | None = None,
        azure_endpoint: str | None = None,
        azure_deployment: str | None = None,
        api_key: str | None = None,
        azure_ad_token: str | None = None,
        azure_ad_token_provider: AzureADTokenProvider | None = None,
        organization: str | None = None,
        project: str | None = None,
        websocket_base_url: str | httpx.URL | None = None,
        base_url: str | None = None,
        timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
        max_retries: int = DEFAULT_MAX_RETRIES,
        default_headers: Mapping[str, str] | None = None,
        default_query: Mapping[str, object] | None = None,
        http_client: httpx.Client | None = None,
        _strict_response_validation: bool = False,
    ) -> None:
        """Construct a new synchronous azure openai client instance.

        This automatically infers the following arguments from their corresponding environment variables if they are not provided:
        - `api_key` from `AZURE_OPENAI_API_KEY`
        - `organization` from `OPENAI_ORG_ID`
        - `project` from `OPENAI_PROJECT_ID`
        - `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
        - `api_version` from `OPENAI_API_VERSION`
        - `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`

        Args:
            azure_endpoint: Your Azure endpoint, including the resource, e.g. `https://example-resource.azure.openai.com/`

            azure_ad_token: Your Azure Active Directory token, https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id

            azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on every request.

            azure_deployment: A model deployment, if given sets the base client URL to include `/deployments/{azure_deployment}`.
                Note: this means you won't be able to use non-deployment endpoints. Not supported with Assistants APIs.

        Raises:
            OpenAIError: if no credential (key, AD token, or token provider) is available.
            ValueError: if `api_version` is missing, if neither `base_url` nor
                `azure_endpoint` is available, or if both are given.
        """
        # fall back to environment variables for each credential input
        if api_key is None:
            api_key = os.environ.get("AZURE_OPENAI_API_KEY")

        if azure_ad_token is None:
            azure_ad_token = os.environ.get("AZURE_OPENAI_AD_TOKEN")

        if api_key is None and azure_ad_token is None and azure_ad_token_provider is None:
            raise OpenAIError(
                "Missing credentials. Please pass one of `api_key`, `azure_ad_token`, `azure_ad_token_provider`, or the `AZURE_OPENAI_API_KEY` or `AZURE_OPENAI_AD_TOKEN` environment variables."
            )

        if api_version is None:
            api_version = os.environ.get("OPENAI_API_VERSION")

        if api_version is None:
            raise ValueError(
                "Must provide either the `api_version` argument or the `OPENAI_API_VERSION` environment variable"
            )

        # Azure requires the api-version on every request; merge it into any
        # user-provided default query params (user keys are preserved).
        if default_query is None:
            default_query = {"api-version": api_version}
        else:
            default_query = {**default_query, "api-version": api_version}

        if base_url is None:
            if azure_endpoint is None:
                azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")

            if azure_endpoint is None:
                raise ValueError(
                    "Must provide one of the `base_url` or `azure_endpoint` arguments, or the `AZURE_OPENAI_ENDPOINT` environment variable"
                )

            if azure_deployment is not None:
                base_url = f"{azure_endpoint.rstrip('/')}/openai/deployments/{azure_deployment}"
            else:
                base_url = f"{azure_endpoint.rstrip('/')}/openai"
        else:
            if azure_endpoint is not None:
                raise ValueError("base_url and azure_endpoint are mutually exclusive")

        if api_key is None:
            # define a sentinel value to avoid any typing issues
            api_key = API_KEY_SENTINEL

        super().__init__(
            api_key=api_key,
            organization=organization,
            project=project,
            base_url=base_url,
            timeout=timeout,
            max_retries=max_retries,
            default_headers=default_headers,
            default_query=default_query,
            http_client=http_client,
            websocket_base_url=websocket_base_url,
            _strict_response_validation=_strict_response_validation,
        )
        self._api_version = api_version
        self._azure_ad_token = azure_ad_token
        self._azure_ad_token_provider = azure_ad_token_provider

    @override
    def copy(
        self,
        *,
        api_key: str | None = None,
        organization: str | None = None,
        project: str | None = None,
        websocket_base_url: str | httpx.URL | None = None,
        api_version: str | None = None,
        azure_ad_token: str | None = None,
        azure_ad_token_provider: AzureADTokenProvider | None = None,
        base_url: str | httpx.URL | None = None,
        timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
        http_client: httpx.Client | None = None,
        max_retries: int | NotGiven = NOT_GIVEN,
        default_headers: Mapping[str, str] | None = None,
        set_default_headers: Mapping[str, str] | None = None,
        default_query: Mapping[str, object] | None = None,
        set_default_query: Mapping[str, object] | None = None,
        _extra_kwargs: Mapping[str, Any] = {},
    ) -> Self:
        """
        Create a new client instance re-using the same options given to the current client with optional overriding.
        """
        return super().copy(
            api_key=api_key,
            organization=organization,
            project=project,
            websocket_base_url=websocket_base_url,
            base_url=base_url,
            timeout=timeout,
            http_client=http_client,
            max_retries=max_retries,
            default_headers=default_headers,
            set_default_headers=set_default_headers,
            default_query=default_query,
            set_default_query=set_default_query,
            _extra_kwargs={
                # carry the Azure-specific options forward unless overridden
                "api_version": api_version or self._api_version,
                "azure_ad_token": azure_ad_token or self._azure_ad_token,
                "azure_ad_token_provider": azure_ad_token_provider or self._azure_ad_token_provider,
                **_extra_kwargs,
            },
        )

    # alias kept for parity with the base ``OpenAI`` client API
    with_options = copy

    def _get_azure_ad_token(self) -> str | None:
        """Return an Azure AD token, preferring a static token over the provider.

        Returns ``None`` when neither a static token nor a provider is configured.

        Raises:
            ValueError: if the configured provider returns a falsy or non-string value.
        """
        if self._azure_ad_token is not None:
            return self._azure_ad_token

        provider = self._azure_ad_token_provider
        if provider is not None:
            token = provider()
            if not token or not isinstance(token, str):  # pyright: ignore[reportUnnecessaryIsInstance]
                raise ValueError(
                    f"Expected `azure_ad_token_provider` argument to return a string but it returned {token}",
                )
            return token

        return None

    @override
    def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
        """Attach per-request auth headers (Bearer token or api-key) to the options.

        A copy of ``options`` is mutated so the caller's object is untouched;
        explicitly-set user headers are never overwritten.
        """
        headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {}

        options = model_copy(options)
        options.headers = headers

        azure_ad_token = self._get_azure_ad_token()
        if azure_ad_token is not None:
            if headers.get("Authorization") is None:
                headers["Authorization"] = f"Bearer {azure_ad_token}"
        elif self.api_key is not API_KEY_SENTINEL:
            if headers.get("api-key") is None:
                headers["api-key"] = self.api_key
        else:
            # should never be hit
            raise ValueError("Unable to handle auth")

        return options

    def _configure_realtime(self, model: str, extra_query: Query) -> tuple[Query, dict[str, str]]:
        """Build the query params and auth headers for a realtime (websocket) connection."""
        auth_headers = {}
        query = {
            **extra_query,
            "api-version": self._api_version,
            "deployment": model,
        }
        # compare against the module-level sentinel rather than a duplicated
        # string literal so this check cannot drift from API_KEY_SENTINEL
        if self.api_key != API_KEY_SENTINEL:
            auth_headers = {"api-key": self.api_key}
        else:
            token = self._get_azure_ad_token()
            if token:
                auth_headers = {"Authorization": f"Bearer {token}"}
        return query, auth_headers
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], AsyncOpenAI):
    """Asynchronous client for the Azure OpenAI service.

    Authenticates with either an API key, a static Azure AD token, or an
    Azure AD token provider callable (sync or async; mutually exclusive).
    """

    @overload
    def __init__(
        self,
        *,
        azure_endpoint: str,
        azure_deployment: str | None = None,
        api_version: str | None = None,
        api_key: str | None = None,
        azure_ad_token: str | None = None,
        azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
        organization: str | None = None,
        project: str | None = None,
        websocket_base_url: str | httpx.URL | None = None,
        timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
        max_retries: int = DEFAULT_MAX_RETRIES,
        default_headers: Mapping[str, str] | None = None,
        default_query: Mapping[str, object] | None = None,
        http_client: httpx.AsyncClient | None = None,
        _strict_response_validation: bool = False,
    ) -> None: ...

    @overload
    def __init__(
        self,
        *,
        azure_deployment: str | None = None,
        api_version: str | None = None,
        api_key: str | None = None,
        azure_ad_token: str | None = None,
        azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
        organization: str | None = None,
        project: str | None = None,
        websocket_base_url: str | httpx.URL | None = None,
        timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
        max_retries: int = DEFAULT_MAX_RETRIES,
        default_headers: Mapping[str, str] | None = None,
        default_query: Mapping[str, object] | None = None,
        http_client: httpx.AsyncClient | None = None,
        _strict_response_validation: bool = False,
    ) -> None: ...

    @overload
    def __init__(
        self,
        *,
        base_url: str,
        api_version: str | None = None,
        api_key: str | None = None,
        azure_ad_token: str | None = None,
        azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
        organization: str | None = None,
        project: str | None = None,
        websocket_base_url: str | httpx.URL | None = None,
        timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
        max_retries: int = DEFAULT_MAX_RETRIES,
        default_headers: Mapping[str, str] | None = None,
        default_query: Mapping[str, object] | None = None,
        http_client: httpx.AsyncClient | None = None,
        _strict_response_validation: bool = False,
    ) -> None: ...

    def __init__(
        self,
        *,
        azure_endpoint: str | None = None,
        azure_deployment: str | None = None,
        api_version: str | None = None,
        api_key: str | None = None,
        azure_ad_token: str | None = None,
        azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
        organization: str | None = None,
        project: str | None = None,
        base_url: str | None = None,
        websocket_base_url: str | httpx.URL | None = None,
        timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
        max_retries: int = DEFAULT_MAX_RETRIES,
        default_headers: Mapping[str, str] | None = None,
        default_query: Mapping[str, object] | None = None,
        http_client: httpx.AsyncClient | None = None,
        _strict_response_validation: bool = False,
    ) -> None:
        """Construct a new asynchronous azure openai client instance.

        This automatically infers the following arguments from their corresponding environment variables if they are not provided:
        - `api_key` from `AZURE_OPENAI_API_KEY`
        - `organization` from `OPENAI_ORG_ID`
        - `project` from `OPENAI_PROJECT_ID`
        - `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
        - `api_version` from `OPENAI_API_VERSION`
        - `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`

        Args:
            azure_endpoint: Your Azure endpoint, including the resource, e.g. `https://example-resource.azure.openai.com/`

            azure_ad_token: Your Azure Active Directory token, https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id

            azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on every request.

            azure_deployment: A model deployment, if given sets the base client URL to include `/deployments/{azure_deployment}`.
                Note: this means you won't be able to use non-deployment endpoints. Not supported with Assistants APIs.

        Raises:
            OpenAIError: if no credential (key, AD token, or token provider) is available.
            ValueError: if `api_version` is missing, if neither `base_url` nor
                `azure_endpoint` is available, or if both are given.
        """
        # fall back to environment variables for each credential input
        if api_key is None:
            api_key = os.environ.get("AZURE_OPENAI_API_KEY")

        if azure_ad_token is None:
            azure_ad_token = os.environ.get("AZURE_OPENAI_AD_TOKEN")

        if api_key is None and azure_ad_token is None and azure_ad_token_provider is None:
            raise OpenAIError(
                "Missing credentials. Please pass one of `api_key`, `azure_ad_token`, `azure_ad_token_provider`, or the `AZURE_OPENAI_API_KEY` or `AZURE_OPENAI_AD_TOKEN` environment variables."
            )

        if api_version is None:
            api_version = os.environ.get("OPENAI_API_VERSION")

        if api_version is None:
            raise ValueError(
                "Must provide either the `api_version` argument or the `OPENAI_API_VERSION` environment variable"
            )

        # Azure requires the api-version on every request; merge it into any
        # user-provided default query params (user keys are preserved).
        if default_query is None:
            default_query = {"api-version": api_version}
        else:
            default_query = {**default_query, "api-version": api_version}

        if base_url is None:
            if azure_endpoint is None:
                azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")

            if azure_endpoint is None:
                raise ValueError(
                    "Must provide one of the `base_url` or `azure_endpoint` arguments, or the `AZURE_OPENAI_ENDPOINT` environment variable"
                )

            if azure_deployment is not None:
                base_url = f"{azure_endpoint.rstrip('/')}/openai/deployments/{azure_deployment}"
            else:
                base_url = f"{azure_endpoint.rstrip('/')}/openai"
        else:
            if azure_endpoint is not None:
                raise ValueError("base_url and azure_endpoint are mutually exclusive")

        if api_key is None:
            # define a sentinel value to avoid any typing issues
            api_key = API_KEY_SENTINEL

        super().__init__(
            api_key=api_key,
            organization=organization,
            project=project,
            base_url=base_url,
            timeout=timeout,
            max_retries=max_retries,
            default_headers=default_headers,
            default_query=default_query,
            http_client=http_client,
            websocket_base_url=websocket_base_url,
            _strict_response_validation=_strict_response_validation,
        )
        self._api_version = api_version
        self._azure_ad_token = azure_ad_token
        self._azure_ad_token_provider = azure_ad_token_provider

    @override
    def copy(
        self,
        *,
        api_key: str | None = None,
        organization: str | None = None,
        project: str | None = None,
        websocket_base_url: str | httpx.URL | None = None,
        api_version: str | None = None,
        azure_ad_token: str | None = None,
        azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
        base_url: str | httpx.URL | None = None,
        timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
        http_client: httpx.AsyncClient | None = None,
        max_retries: int | NotGiven = NOT_GIVEN,
        default_headers: Mapping[str, str] | None = None,
        set_default_headers: Mapping[str, str] | None = None,
        default_query: Mapping[str, object] | None = None,
        set_default_query: Mapping[str, object] | None = None,
        _extra_kwargs: Mapping[str, Any] = {},
    ) -> Self:
        """
        Create a new client instance re-using the same options given to the current client with optional overriding.
        """
        return super().copy(
            api_key=api_key,
            organization=organization,
            project=project,
            websocket_base_url=websocket_base_url,
            base_url=base_url,
            timeout=timeout,
            http_client=http_client,
            max_retries=max_retries,
            default_headers=default_headers,
            set_default_headers=set_default_headers,
            default_query=default_query,
            set_default_query=set_default_query,
            _extra_kwargs={
                # carry the Azure-specific options forward unless overridden
                "api_version": api_version or self._api_version,
                "azure_ad_token": azure_ad_token or self._azure_ad_token,
                "azure_ad_token_provider": azure_ad_token_provider or self._azure_ad_token_provider,
                **_extra_kwargs,
            },
        )

    # alias kept for parity with the base ``AsyncOpenAI`` client API
    with_options = copy

    async def _get_azure_ad_token(self) -> str | None:
        """Return an Azure AD token, preferring a static token over the provider.

        The provider callable may be sync or async; an awaitable result is awaited.
        Returns ``None`` when neither a static token nor a provider is configured.

        Raises:
            ValueError: if the provider's (resolved) result is falsy or not a string.
        """
        if self._azure_ad_token is not None:
            return self._azure_ad_token

        provider = self._azure_ad_token_provider
        if provider is not None:
            token = provider()
            if inspect.isawaitable(token):
                token = await token
            if not token or not isinstance(cast(Any, token), str):
                raise ValueError(
                    f"Expected `azure_ad_token_provider` argument to return a string but it returned {token}",
                )
            return str(token)

        return None

    @override
    async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
        """Attach per-request auth headers (Bearer token or api-key) to the options.

        A copy of ``options`` is mutated so the caller's object is untouched;
        explicitly-set user headers are never overwritten.
        """
        headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {}

        options = model_copy(options)
        options.headers = headers

        azure_ad_token = await self._get_azure_ad_token()
        if azure_ad_token is not None:
            if headers.get("Authorization") is None:
                headers["Authorization"] = f"Bearer {azure_ad_token}"
        elif self.api_key is not API_KEY_SENTINEL:
            if headers.get("api-key") is None:
                headers["api-key"] = self.api_key
        else:
            # should never be hit
            raise ValueError("Unable to handle auth")

        return options

    async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[Query, dict[str, str]]:
        """Build the query params and auth headers for a realtime (websocket) connection."""
        auth_headers = {}
        query = {
            **extra_query,
            "api-version": self._api_version,
            "deployment": model,
        }
        # compare against the module-level sentinel rather than a duplicated
        # string literal so this check cannot drift from API_KEY_SENTINEL
        if self.api_key != API_KEY_SENTINEL:
            auth_headers = {"api-key": self.api_key}
        else:
            token = await self._get_azure_ad_token()
            if token:
                auth_headers = {"Authorization": f"Bearer {token}"}
        return query, auth_headers
|
.venv/lib/python3.11/site-packages/openai/lib/streaming/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ._assistants import (
|
| 2 |
+
AssistantEventHandler as AssistantEventHandler,
|
| 3 |
+
AssistantEventHandlerT as AssistantEventHandlerT,
|
| 4 |
+
AssistantStreamManager as AssistantStreamManager,
|
| 5 |
+
AsyncAssistantEventHandler as AsyncAssistantEventHandler,
|
| 6 |
+
AsyncAssistantEventHandlerT as AsyncAssistantEventHandlerT,
|
| 7 |
+
AsyncAssistantStreamManager as AsyncAssistantStreamManager,
|
| 8 |
+
)
|
.venv/lib/python3.11/site-packages/openai/lib/streaming/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (509 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/openai/lib/streaming/__pycache__/_assistants.cpython-311.pyc
ADDED
|
Binary file (46.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/openai/lib/streaming/__pycache__/_deltas.cpython-311.pyc
ADDED
|
Binary file (3.01 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/openai/lib/streaming/_assistants.py
ADDED
|
@@ -0,0 +1,1038 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
from types import TracebackType
|
| 5 |
+
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Callable, Iterable, Iterator, cast
|
| 6 |
+
from typing_extensions import Awaitable, AsyncIterable, AsyncIterator, assert_never
|
| 7 |
+
|
| 8 |
+
import httpx
|
| 9 |
+
|
| 10 |
+
from ..._utils import is_dict, is_list, consume_sync_iterator, consume_async_iterator
|
| 11 |
+
from ..._compat import model_dump
|
| 12 |
+
from ..._models import construct_type
|
| 13 |
+
from ..._streaming import Stream, AsyncStream
|
| 14 |
+
from ...types.beta import AssistantStreamEvent
|
| 15 |
+
from ...types.beta.threads import (
|
| 16 |
+
Run,
|
| 17 |
+
Text,
|
| 18 |
+
Message,
|
| 19 |
+
ImageFile,
|
| 20 |
+
TextDelta,
|
| 21 |
+
MessageDelta,
|
| 22 |
+
MessageContent,
|
| 23 |
+
MessageContentDelta,
|
| 24 |
+
)
|
| 25 |
+
from ...types.beta.threads.runs import RunStep, ToolCall, RunStepDelta, ToolCallDelta
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class AssistantEventHandler:
    text_deltas: Iterable[str]
    """Iterator over just the text deltas in the stream.

    This corresponds to the `thread.message.delta` event
    in the API.

    ```py
    for text in stream.text_deltas:
        print(text, end="", flush=True)
    print()
    ```
    """

    def __init__(self) -> None:
        # Event currently being dispatched to callbacks; reset to None after
        # `_emit_sse_event` finishes handling it.
        self._current_event: AssistantStreamEvent | None = None
        # "Cursor" into the current message's content blocks / the current run
        # step's tool calls, used to detect when a delta starts a new block.
        self._current_message_content_index: int | None = None
        self._current_message_content: MessageContent | None = None
        self._current_tool_call_index: int | None = None
        self._current_tool_call: ToolCall | None = None
        self.__current_run_step_id: str | None = None
        self.__current_run: Run | None = None
        # Accumulated snapshots keyed by object id, built up from delta events.
        self.__run_step_snapshots: dict[str, RunStep] = {}
        self.__message_snapshots: dict[str, Message] = {}
        self.__current_message_snapshot: Message | None = None

        # Both generators share `self` as the event source; neither runs any
        # code until first iterated.
        self.text_deltas = self.__text_deltas__()
        self._iterator = self.__stream__()
        # Bound later via `_init()`; a handler is tied to at most one stream.
        self.__stream: Stream[AssistantStreamEvent] | None = None
|
| 57 |
+
|
| 58 |
+
    def _init(self, stream: Stream[AssistantStreamEvent]) -> None:
        # Guard against reuse: the handler accumulates per-stream state, so
        # binding it to a second stream would silently mix snapshots.
        if self.__stream:
            raise RuntimeError(
                "A single event handler cannot be shared between multiple streams; You will need to construct a new event handler instance"
            )

        self.__stream = stream
|
| 65 |
+
|
| 66 |
+
    def __next__(self) -> AssistantStreamEvent:
        # Delegate to the generator created in `__init__` by `__stream__()`.
        return self._iterator.__next__()

    def __iter__(self) -> Iterator[AssistantStreamEvent]:
        for item in self._iterator:
            yield item
|
| 72 |
+
|
| 73 |
+
    @property
    def current_event(self) -> AssistantStreamEvent | None:
        """The event currently being processed, or None between events."""
        return self._current_event

    @property
    def current_run(self) -> Run | None:
        """The most recent Run object received from a `thread.run.*` event."""
        return self.__current_run

    @property
    def current_run_step_snapshot(self) -> RunStep | None:
        """Accumulated snapshot of the run step currently in progress.

        None outside of a `thread.run.step.created` .. `thread.run.step.completed`
        (or cancelled/expired/failed) window.
        """
        if not self.__current_run_step_id:
            return None

        return self.__run_step_snapshots[self.__current_run_step_id]

    @property
    def current_message_snapshot(self) -> Message | None:
        """Accumulated snapshot of the message currently being streamed."""
        return self.__current_message_snapshot
|
| 91 |
+
|
| 92 |
+
    def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called when the context manager exits.
        """
        if self.__stream:
            self.__stream.close()

    def until_done(self) -> None:
        """Waits until the stream has been consumed"""
        # Drains `self`, which drives `__stream__()` and all callbacks.
        consume_sync_iterator(self)
|
| 104 |
+
|
| 105 |
+
def get_final_run(self) -> Run:
|
| 106 |
+
"""Wait for the stream to finish and returns the completed Run object"""
|
| 107 |
+
self.until_done()
|
| 108 |
+
|
| 109 |
+
if not self.__current_run:
|
| 110 |
+
raise RuntimeError("No final run object found")
|
| 111 |
+
|
| 112 |
+
return self.__current_run
|
| 113 |
+
|
| 114 |
+
def get_final_run_steps(self) -> list[RunStep]:
|
| 115 |
+
"""Wait for the stream to finish and returns the steps taken in this run"""
|
| 116 |
+
self.until_done()
|
| 117 |
+
|
| 118 |
+
if not self.__run_step_snapshots:
|
| 119 |
+
raise RuntimeError("No run steps found")
|
| 120 |
+
|
| 121 |
+
return [step for step in self.__run_step_snapshots.values()]
|
| 122 |
+
|
| 123 |
+
def get_final_messages(self) -> list[Message]:
|
| 124 |
+
"""Wait for the stream to finish and returns the messages emitted in this run"""
|
| 125 |
+
self.until_done()
|
| 126 |
+
|
| 127 |
+
if not self.__message_snapshots:
|
| 128 |
+
raise RuntimeError("No messages found")
|
| 129 |
+
|
| 130 |
+
return [message for message in self.__message_snapshots.values()]
|
| 131 |
+
|
| 132 |
+
    def __text_deltas__(self) -> Iterator[str]:
        # Filters the full event stream down to non-empty text delta values.
        # Iterating this also advances the shared main iterator (`self`).
        for event in self:
            if event.event != "thread.message.delta":
                continue

            for content_delta in event.data.delta.content or []:
                if content_delta.type == "text" and content_delta.text and content_delta.text.value:
                    yield content_delta.text.value
|
| 140 |
+
|
| 141 |
+
# event handlers
|
| 142 |
+
|
| 143 |
+
    def on_end(self) -> None:
        """Fires when the stream has finished.

        This happens if the stream is read to completion
        or if an exception occurs during iteration.
        """

    def on_event(self, event: AssistantStreamEvent) -> None:
        """Callback that is fired for every Server-Sent-Event"""

    def on_run_step_created(self, run_step: RunStep) -> None:
        """Callback that is fired when a run step is created"""

    def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None:
        """Callback that is fired whenever a run step delta is returned from the API

        The first argument is just the delta as sent by the API and the second argument
        is the accumulated snapshot of the run step. For example, a tool calls event may
        look like this:

        # delta
        tool_calls=[
            RunStepDeltaToolCallsCodeInterpreter(
                index=0,
                type='code_interpreter',
                id=None,
                code_interpreter=CodeInterpreter(input=' sympy', outputs=None)
            )
        ]
        # snapshot
        tool_calls=[
            CodeToolCall(
                id='call_wKayJlcYV12NiadiZuJXxcfx',
                code_interpreter=CodeInterpreter(input='from sympy', outputs=[]),
                type='code_interpreter',
                index=0
            )
        ],
        """

    def on_run_step_done(self, run_step: RunStep) -> None:
        """Callback that is fired when a run step is completed"""

    def on_tool_call_created(self, tool_call: ToolCall) -> None:
        """Callback that is fired when a tool call is created"""

    def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
        """Callback that is fired when a tool call delta is encountered"""

    def on_tool_call_done(self, tool_call: ToolCall) -> None:
        """Callback that is fired when a tool call is completed"""

    def on_exception(self, exception: Exception) -> None:
        """Fired whenever an exception happens during streaming"""

    def on_timeout(self) -> None:
        """Fires if the request times out"""

    def on_message_created(self, message: Message) -> None:
        """Callback that is fired when a message is created"""

    def on_message_delta(self, delta: MessageDelta, snapshot: Message) -> None:
        """Callback that is fired whenever a message delta is returned from the API

        The first argument is just the delta as sent by the API and the second argument
        is the accumulated snapshot of the message. For example, a text content event may
        look like this:

        # delta
        MessageDeltaText(
            index=0,
            type='text',
            text=Text(
                value=' Jane'
            ),
        )
        # snapshot
        MessageContentText(
            index=0,
            type='text',
            text=Text(
                value='Certainly, Jane'
            ),
        )
        """

    def on_message_done(self, message: Message) -> None:
        """Callback that is fired when a message is completed"""

    def on_text_created(self, text: Text) -> None:
        """Callback that is fired when a text content block is created"""

    def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
        """Callback that is fired whenever a text content delta is returned
        by the API.

        The first argument is just the delta as sent by the API and the second argument
        is the accumulated snapshot of the text. For example:

        on_text_delta(TextDelta(value="The"), Text(value="The")),
        on_text_delta(TextDelta(value=" solution"), Text(value="The solution")),
        on_text_delta(TextDelta(value=" to"), Text(value="The solution to")),
        on_text_delta(TextDelta(value=" the"), Text(value="The solution to the")),
        on_text_delta(TextDelta(value=" equation"), Text(value="The solution to the equation")),
        """

    def on_text_done(self, text: Text) -> None:
        """Callback that is fired when a text content block is finished"""

    def on_image_file_done(self, image_file: ImageFile) -> None:
        """Callback that is fired when an image file block is finished"""
|
| 254 |
+
|
| 255 |
+
    def _emit_sse_event(self, event: AssistantStreamEvent) -> None:
        """Accumulate `event` into the message/run-step snapshots and fan it
        out to the relevant `on_*` callbacks.
        """
        self._current_event = event
        self.on_event(event)

        # Fold the event into the running message snapshot; `new_content`
        # holds content deltas that started a brand-new content block.
        self.__current_message_snapshot, new_content = accumulate_event(
            event=event,
            current_message_snapshot=self.__current_message_snapshot,
        )
        if self.__current_message_snapshot is not None:
            self.__message_snapshots[self.__current_message_snapshot.id] = self.__current_message_snapshot

        # Fold run-step events into the per-step snapshot dict in place.
        accumulate_run_step(
            event=event,
            run_step_snapshots=self.__run_step_snapshots,
        )

        for content_delta in new_content:
            assert self.__current_message_snapshot is not None

            block = self.__current_message_snapshot.content[content_delta.index]
            if block.type == "text":
                self.on_text_created(block.text)

        # Terminal (or action-requiring) run states: record the run and flush
        # any tool call still considered in-flight.
        if (
            event.event == "thread.run.completed"
            or event.event == "thread.run.cancelled"
            or event.event == "thread.run.expired"
            or event.event == "thread.run.failed"
            or event.event == "thread.run.requires_action"
            or event.event == "thread.run.incomplete"
        ):
            self.__current_run = event.data
            if self._current_tool_call:
                self.on_tool_call_done(self._current_tool_call)
        elif (
            event.event == "thread.run.created"
            or event.event == "thread.run.in_progress"
            or event.event == "thread.run.cancelling"
            or event.event == "thread.run.queued"
        ):
            self.__current_run = event.data
        elif event.event == "thread.message.created":
            self.on_message_created(event.data)
        elif event.event == "thread.message.delta":
            # accumulate_event above guarantees a snapshot exists for deltas.
            snapshot = self.__current_message_snapshot
            assert snapshot is not None

            message_delta = event.data.delta
            if message_delta.content is not None:
                for content_delta in message_delta.content:
                    if content_delta.type == "text" and content_delta.text:
                        snapshot_content = snapshot.content[content_delta.index]
                        assert snapshot_content.type == "text"
                        self.on_text_delta(content_delta.text, snapshot_content.text)

                    # If the delta is for a new message content:
                    # - emit on_text_done/on_image_file_done for the previous message content
                    # - emit on_text_created/on_image_created for the new message content
                    if content_delta.index != self._current_message_content_index:
                        if self._current_message_content is not None:
                            if self._current_message_content.type == "text":
                                self.on_text_done(self._current_message_content.text)
                            elif self._current_message_content.type == "image_file":
                                self.on_image_file_done(self._current_message_content.image_file)

                        self._current_message_content_index = content_delta.index
                        self._current_message_content = snapshot.content[content_delta.index]

                    # Update the current_message_content (delta event is correctly emitted already)
                    self._current_message_content = snapshot.content[content_delta.index]

            self.on_message_delta(event.data.delta, snapshot)
        elif event.event == "thread.message.completed" or event.event == "thread.message.incomplete":
            # The completed event carries the full final message; replace the
            # accumulated snapshot with it.
            self.__current_message_snapshot = event.data
            self.__message_snapshots[event.data.id] = event.data

            if self._current_message_content_index is not None:
                content = event.data.content[self._current_message_content_index]
                if content.type == "text":
                    self.on_text_done(content.text)
                elif content.type == "image_file":
                    self.on_image_file_done(content.image_file)

            self.on_message_done(event.data)
        elif event.event == "thread.run.step.created":
            self.__current_run_step_id = event.data.id
            self.on_run_step_created(event.data)
        elif event.event == "thread.run.step.in_progress":
            self.__current_run_step_id = event.data.id
        elif event.event == "thread.run.step.delta":
            # accumulate_run_step above guarantees this snapshot exists.
            step_snapshot = self.__run_step_snapshots[event.data.id]

            run_step_delta = event.data.delta
            if (
                run_step_delta.step_details
                and run_step_delta.step_details.type == "tool_calls"
                and run_step_delta.step_details.tool_calls is not None
            ):
                assert step_snapshot.step_details.type == "tool_calls"
                for tool_call_delta in run_step_delta.step_details.tool_calls:
                    if tool_call_delta.index == self._current_tool_call_index:
                        self.on_tool_call_delta(
                            tool_call_delta,
                            step_snapshot.step_details.tool_calls[tool_call_delta.index],
                        )

                    # If the delta is for a new tool call:
                    # - emit on_tool_call_done for the previous tool_call
                    # - emit on_tool_call_created for the new tool_call
                    if tool_call_delta.index != self._current_tool_call_index:
                        if self._current_tool_call is not None:
                            self.on_tool_call_done(self._current_tool_call)

                        self._current_tool_call_index = tool_call_delta.index
                        self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]
                        self.on_tool_call_created(self._current_tool_call)

                    # Update the current_tool_call (delta event is correctly emitted already)
                    self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]

            self.on_run_step_delta(
                event.data.delta,
                step_snapshot,
            )
        elif (
            event.event == "thread.run.step.completed"
            or event.event == "thread.run.step.cancelled"
            or event.event == "thread.run.step.expired"
            or event.event == "thread.run.step.failed"
        ):
            if self._current_tool_call:
                self.on_tool_call_done(self._current_tool_call)

            self.on_run_step_done(event.data)
            self.__current_run_step_id = None
        elif event.event == "thread.created" or event.event == "thread.message.in_progress" or event.event == "error":
            # currently no special handling
            ...
        else:
            # we only want to error at build-time
            if TYPE_CHECKING:  # type: ignore[unreachable]
                assert_never(event)

        self._current_event = None
|
| 399 |
+
|
| 400 |
+
    def __stream__(self) -> Iterator[AssistantStreamEvent]:
        # Main generator: dispatches every raw SSE event through
        # `_emit_sse_event` before yielding it to the caller. The body only
        # runs on first iteration, after `_init()` has bound the stream.
        stream = self.__stream
        if not stream:
            raise RuntimeError("Stream has not been started yet")

        try:
            for event in stream:
                self._emit_sse_event(event)

                yield event
        except (httpx.TimeoutException, asyncio.TimeoutError) as exc:
            # Timeouts get their dedicated callback before the generic one.
            self.on_timeout()
            self.on_exception(exc)
            raise
        except Exception as exc:
            self.on_exception(exc)
            raise
        finally:
            # Fires on completion and on error/early close alike.
            self.on_end()
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
# Bound TypeVar so `AssistantStreamManager.__enter__` can return the caller's
# concrete `AssistantEventHandler` subclass rather than the base type.
AssistantEventHandlerT = TypeVar("AssistantEventHandlerT", bound=AssistantEventHandler)
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
class AssistantStreamManager(Generic[AssistantEventHandlerT]):
    """Context-manager wrapper over `AssistantEventHandler`, returned by `.stream()`.

    ```py
    with client.threads.create_and_run_stream(...) as stream:
        for event in stream:
            ...
    ```
    """

    def __init__(
        self,
        api_request: Callable[[], Stream[AssistantStreamEvent]],
        *,
        event_handler: AssistantEventHandlerT,
    ) -> None:
        # The request is deferred until `__enter__` so that no connection is
        # opened before the `with` block actually starts.
        self.__api_request = api_request
        self.__event_handler = event_handler
        self.__stream: Stream[AssistantStreamEvent] | None = None

    def __enter__(self) -> AssistantEventHandlerT:
        stream = self.__api_request()
        self.__stream = stream
        self.__event_handler._init(stream)
        return self.__event_handler

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # Release the connection whether or not the body raised.
        stream = self.__stream
        if stream is not None:
            stream.close()
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
class AsyncAssistantEventHandler:
|
| 461 |
+
text_deltas: AsyncIterable[str]
|
| 462 |
+
"""Iterator over just the text deltas in the stream.
|
| 463 |
+
|
| 464 |
+
This corresponds to the `thread.message.delta` event
|
| 465 |
+
in the API.
|
| 466 |
+
|
| 467 |
+
```py
|
| 468 |
+
async for text in stream.text_deltas:
|
| 469 |
+
print(text, end="", flush=True)
|
| 470 |
+
print()
|
| 471 |
+
```
|
| 472 |
+
"""
|
| 473 |
+
|
| 474 |
+
def __init__(self) -> None:
|
| 475 |
+
self._current_event: AssistantStreamEvent | None = None
|
| 476 |
+
self._current_message_content_index: int | None = None
|
| 477 |
+
self._current_message_content: MessageContent | None = None
|
| 478 |
+
self._current_tool_call_index: int | None = None
|
| 479 |
+
self._current_tool_call: ToolCall | None = None
|
| 480 |
+
self.__current_run_step_id: str | None = None
|
| 481 |
+
self.__current_run: Run | None = None
|
| 482 |
+
self.__run_step_snapshots: dict[str, RunStep] = {}
|
| 483 |
+
self.__message_snapshots: dict[str, Message] = {}
|
| 484 |
+
self.__current_message_snapshot: Message | None = None
|
| 485 |
+
|
| 486 |
+
self.text_deltas = self.__text_deltas__()
|
| 487 |
+
self._iterator = self.__stream__()
|
| 488 |
+
self.__stream: AsyncStream[AssistantStreamEvent] | None = None
|
| 489 |
+
|
| 490 |
+
def _init(self, stream: AsyncStream[AssistantStreamEvent]) -> None:
|
| 491 |
+
if self.__stream:
|
| 492 |
+
raise RuntimeError(
|
| 493 |
+
"A single event handler cannot be shared between multiple streams; You will need to construct a new event handler instance"
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
self.__stream = stream
|
| 497 |
+
|
| 498 |
+
async def __anext__(self) -> AssistantStreamEvent:
|
| 499 |
+
return await self._iterator.__anext__()
|
| 500 |
+
|
| 501 |
+
async def __aiter__(self) -> AsyncIterator[AssistantStreamEvent]:
|
| 502 |
+
async for item in self._iterator:
|
| 503 |
+
yield item
|
| 504 |
+
|
| 505 |
+
async def close(self) -> None:
|
| 506 |
+
"""
|
| 507 |
+
Close the response and release the connection.
|
| 508 |
+
|
| 509 |
+
Automatically called when the context manager exits.
|
| 510 |
+
"""
|
| 511 |
+
if self.__stream:
|
| 512 |
+
await self.__stream.close()
|
| 513 |
+
|
| 514 |
+
@property
|
| 515 |
+
def current_event(self) -> AssistantStreamEvent | None:
|
| 516 |
+
return self._current_event
|
| 517 |
+
|
| 518 |
+
@property
|
| 519 |
+
def current_run(self) -> Run | None:
|
| 520 |
+
return self.__current_run
|
| 521 |
+
|
| 522 |
+
@property
|
| 523 |
+
def current_run_step_snapshot(self) -> RunStep | None:
|
| 524 |
+
if not self.__current_run_step_id:
|
| 525 |
+
return None
|
| 526 |
+
|
| 527 |
+
return self.__run_step_snapshots[self.__current_run_step_id]
|
| 528 |
+
|
| 529 |
+
@property
|
| 530 |
+
def current_message_snapshot(self) -> Message | None:
|
| 531 |
+
return self.__current_message_snapshot
|
| 532 |
+
|
| 533 |
+
async def until_done(self) -> None:
|
| 534 |
+
"""Waits until the stream has been consumed"""
|
| 535 |
+
await consume_async_iterator(self)
|
| 536 |
+
|
| 537 |
+
async def get_final_run(self) -> Run:
|
| 538 |
+
"""Wait for the stream to finish and returns the completed Run object"""
|
| 539 |
+
await self.until_done()
|
| 540 |
+
|
| 541 |
+
if not self.__current_run:
|
| 542 |
+
raise RuntimeError("No final run object found")
|
| 543 |
+
|
| 544 |
+
return self.__current_run
|
| 545 |
+
|
| 546 |
+
async def get_final_run_steps(self) -> list[RunStep]:
|
| 547 |
+
"""Wait for the stream to finish and returns the steps taken in this run"""
|
| 548 |
+
await self.until_done()
|
| 549 |
+
|
| 550 |
+
if not self.__run_step_snapshots:
|
| 551 |
+
raise RuntimeError("No run steps found")
|
| 552 |
+
|
| 553 |
+
return [step for step in self.__run_step_snapshots.values()]
|
| 554 |
+
|
| 555 |
+
async def get_final_messages(self) -> list[Message]:
|
| 556 |
+
"""Wait for the stream to finish and returns the messages emitted in this run"""
|
| 557 |
+
await self.until_done()
|
| 558 |
+
|
| 559 |
+
if not self.__message_snapshots:
|
| 560 |
+
raise RuntimeError("No messages found")
|
| 561 |
+
|
| 562 |
+
return [message for message in self.__message_snapshots.values()]
|
| 563 |
+
|
| 564 |
+
async def __text_deltas__(self) -> AsyncIterator[str]:
|
| 565 |
+
async for event in self:
|
| 566 |
+
if event.event != "thread.message.delta":
|
| 567 |
+
continue
|
| 568 |
+
|
| 569 |
+
for content_delta in event.data.delta.content or []:
|
| 570 |
+
if content_delta.type == "text" and content_delta.text and content_delta.text.value:
|
| 571 |
+
yield content_delta.text.value
|
| 572 |
+
|
| 573 |
+
# event handlers
|
| 574 |
+
|
| 575 |
+
async def on_end(self) -> None:
|
| 576 |
+
"""Fires when the stream has finished.
|
| 577 |
+
|
| 578 |
+
This happens if the stream is read to completion
|
| 579 |
+
or if an exception occurs during iteration.
|
| 580 |
+
"""
|
| 581 |
+
|
| 582 |
+
async def on_event(self, event: AssistantStreamEvent) -> None:
|
| 583 |
+
"""Callback that is fired for every Server-Sent-Event"""
|
| 584 |
+
|
| 585 |
+
async def on_run_step_created(self, run_step: RunStep) -> None:
|
| 586 |
+
"""Callback that is fired when a run step is created"""
|
| 587 |
+
|
| 588 |
+
async def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None:
|
| 589 |
+
"""Callback that is fired whenever a run step delta is returned from the API
|
| 590 |
+
|
| 591 |
+
The first argument is just the delta as sent by the API and the second argument
|
| 592 |
+
is the accumulated snapshot of the run step. For example, a tool calls event may
|
| 593 |
+
look like this:
|
| 594 |
+
|
| 595 |
+
# delta
|
| 596 |
+
tool_calls=[
|
| 597 |
+
RunStepDeltaToolCallsCodeInterpreter(
|
| 598 |
+
index=0,
|
| 599 |
+
type='code_interpreter',
|
| 600 |
+
id=None,
|
| 601 |
+
code_interpreter=CodeInterpreter(input=' sympy', outputs=None)
|
| 602 |
+
)
|
| 603 |
+
]
|
| 604 |
+
# snapshot
|
| 605 |
+
tool_calls=[
|
| 606 |
+
CodeToolCall(
|
| 607 |
+
id='call_wKayJlcYV12NiadiZuJXxcfx',
|
| 608 |
+
code_interpreter=CodeInterpreter(input='from sympy', outputs=[]),
|
| 609 |
+
type='code_interpreter',
|
| 610 |
+
index=0
|
| 611 |
+
)
|
| 612 |
+
],
|
| 613 |
+
"""
|
| 614 |
+
|
| 615 |
+
async def on_run_step_done(self, run_step: RunStep) -> None:
|
| 616 |
+
"""Callback that is fired when a run step is completed"""
|
| 617 |
+
|
| 618 |
+
async def on_tool_call_created(self, tool_call: ToolCall) -> None:
|
| 619 |
+
"""Callback that is fired when a tool call is created"""
|
| 620 |
+
|
| 621 |
+
async def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
|
| 622 |
+
"""Callback that is fired when a tool call delta is encountered"""
|
| 623 |
+
|
| 624 |
+
    async def on_tool_call_done(self, tool_call: ToolCall) -> None:
        """Callback that is fired when a tool call is completed"""
|
| 626 |
+
|
| 627 |
+
async def on_exception(self, exception: Exception) -> None:
|
| 628 |
+
"""Fired whenever an exception happens during streaming"""
|
| 629 |
+
|
| 630 |
+
async def on_timeout(self) -> None:
|
| 631 |
+
"""Fires if the request times out"""
|
| 632 |
+
|
| 633 |
+
async def on_message_created(self, message: Message) -> None:
|
| 634 |
+
"""Callback that is fired when a message is created"""
|
| 635 |
+
|
| 636 |
+
async def on_message_delta(self, delta: MessageDelta, snapshot: Message) -> None:
|
| 637 |
+
"""Callback that is fired whenever a message delta is returned from the API
|
| 638 |
+
|
| 639 |
+
The first argument is just the delta as sent by the API and the second argument
|
| 640 |
+
is the accumulated snapshot of the message. For example, a text content event may
|
| 641 |
+
look like this:
|
| 642 |
+
|
| 643 |
+
# delta
|
| 644 |
+
MessageDeltaText(
|
| 645 |
+
index=0,
|
| 646 |
+
type='text',
|
| 647 |
+
text=Text(
|
| 648 |
+
value=' Jane'
|
| 649 |
+
),
|
| 650 |
+
)
|
| 651 |
+
# snapshot
|
| 652 |
+
MessageContentText(
|
| 653 |
+
index=0,
|
| 654 |
+
type='text',
|
| 655 |
+
text=Text(
|
| 656 |
+
value='Certainly, Jane'
|
| 657 |
+
),
|
| 658 |
+
)
|
| 659 |
+
"""
|
| 660 |
+
|
| 661 |
+
async def on_message_done(self, message: Message) -> None:
|
| 662 |
+
"""Callback that is fired when a message is completed"""
|
| 663 |
+
|
| 664 |
+
async def on_text_created(self, text: Text) -> None:
|
| 665 |
+
"""Callback that is fired when a text content block is created"""
|
| 666 |
+
|
| 667 |
+
    async def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
        """Callback that is fired whenever a text content delta is returned
        by the API.

        The first argument is just the delta as sent by the API and the second argument
        is the accumulated snapshot of the text. For example:

        on_text_delta(TextDelta(value="The"), Text(value="The")),
        on_text_delta(TextDelta(value=" solution"), Text(value="The solution")),
        on_text_delta(TextDelta(value=" to"), Text(value="The solution to")),
        on_text_delta(TextDelta(value=" the"), Text(value="The solution to the")),
        on_text_delta(TextDelta(value=" equation"), Text(value="The solution to the equation")),
        """
|
| 680 |
+
|
| 681 |
+
async def on_text_done(self, text: Text) -> None:
|
| 682 |
+
"""Callback that is fired when a text content block is finished"""
|
| 683 |
+
|
| 684 |
+
async def on_image_file_done(self, image_file: ImageFile) -> None:
|
| 685 |
+
"""Callback that is fired when an image file block is finished"""
|
| 686 |
+
|
| 687 |
+
async def _emit_sse_event(self, event: AssistantStreamEvent) -> None:
|
| 688 |
+
self._current_event = event
|
| 689 |
+
await self.on_event(event)
|
| 690 |
+
|
| 691 |
+
self.__current_message_snapshot, new_content = accumulate_event(
|
| 692 |
+
event=event,
|
| 693 |
+
current_message_snapshot=self.__current_message_snapshot,
|
| 694 |
+
)
|
| 695 |
+
if self.__current_message_snapshot is not None:
|
| 696 |
+
self.__message_snapshots[self.__current_message_snapshot.id] = self.__current_message_snapshot
|
| 697 |
+
|
| 698 |
+
accumulate_run_step(
|
| 699 |
+
event=event,
|
| 700 |
+
run_step_snapshots=self.__run_step_snapshots,
|
| 701 |
+
)
|
| 702 |
+
|
| 703 |
+
for content_delta in new_content:
|
| 704 |
+
assert self.__current_message_snapshot is not None
|
| 705 |
+
|
| 706 |
+
block = self.__current_message_snapshot.content[content_delta.index]
|
| 707 |
+
if block.type == "text":
|
| 708 |
+
await self.on_text_created(block.text)
|
| 709 |
+
|
| 710 |
+
if (
|
| 711 |
+
event.event == "thread.run.completed"
|
| 712 |
+
or event.event == "thread.run.cancelled"
|
| 713 |
+
or event.event == "thread.run.expired"
|
| 714 |
+
or event.event == "thread.run.failed"
|
| 715 |
+
or event.event == "thread.run.requires_action"
|
| 716 |
+
or event.event == "thread.run.incomplete"
|
| 717 |
+
):
|
| 718 |
+
self.__current_run = event.data
|
| 719 |
+
if self._current_tool_call:
|
| 720 |
+
await self.on_tool_call_done(self._current_tool_call)
|
| 721 |
+
elif (
|
| 722 |
+
event.event == "thread.run.created"
|
| 723 |
+
or event.event == "thread.run.in_progress"
|
| 724 |
+
or event.event == "thread.run.cancelling"
|
| 725 |
+
or event.event == "thread.run.queued"
|
| 726 |
+
):
|
| 727 |
+
self.__current_run = event.data
|
| 728 |
+
elif event.event == "thread.message.created":
|
| 729 |
+
await self.on_message_created(event.data)
|
| 730 |
+
elif event.event == "thread.message.delta":
|
| 731 |
+
snapshot = self.__current_message_snapshot
|
| 732 |
+
assert snapshot is not None
|
| 733 |
+
|
| 734 |
+
message_delta = event.data.delta
|
| 735 |
+
if message_delta.content is not None:
|
| 736 |
+
for content_delta in message_delta.content:
|
| 737 |
+
if content_delta.type == "text" and content_delta.text:
|
| 738 |
+
snapshot_content = snapshot.content[content_delta.index]
|
| 739 |
+
assert snapshot_content.type == "text"
|
| 740 |
+
await self.on_text_delta(content_delta.text, snapshot_content.text)
|
| 741 |
+
|
| 742 |
+
# If the delta is for a new message content:
|
| 743 |
+
# - emit on_text_done/on_image_file_done for the previous message content
|
| 744 |
+
# - emit on_text_created/on_image_created for the new message content
|
| 745 |
+
if content_delta.index != self._current_message_content_index:
|
| 746 |
+
if self._current_message_content is not None:
|
| 747 |
+
if self._current_message_content.type == "text":
|
| 748 |
+
await self.on_text_done(self._current_message_content.text)
|
| 749 |
+
elif self._current_message_content.type == "image_file":
|
| 750 |
+
await self.on_image_file_done(self._current_message_content.image_file)
|
| 751 |
+
|
| 752 |
+
self._current_message_content_index = content_delta.index
|
| 753 |
+
self._current_message_content = snapshot.content[content_delta.index]
|
| 754 |
+
|
| 755 |
+
# Update the current_message_content (delta event is correctly emitted already)
|
| 756 |
+
self._current_message_content = snapshot.content[content_delta.index]
|
| 757 |
+
|
| 758 |
+
await self.on_message_delta(event.data.delta, snapshot)
|
| 759 |
+
elif event.event == "thread.message.completed" or event.event == "thread.message.incomplete":
|
| 760 |
+
self.__current_message_snapshot = event.data
|
| 761 |
+
self.__message_snapshots[event.data.id] = event.data
|
| 762 |
+
|
| 763 |
+
if self._current_message_content_index is not None:
|
| 764 |
+
content = event.data.content[self._current_message_content_index]
|
| 765 |
+
if content.type == "text":
|
| 766 |
+
await self.on_text_done(content.text)
|
| 767 |
+
elif content.type == "image_file":
|
| 768 |
+
await self.on_image_file_done(content.image_file)
|
| 769 |
+
|
| 770 |
+
await self.on_message_done(event.data)
|
| 771 |
+
elif event.event == "thread.run.step.created":
|
| 772 |
+
self.__current_run_step_id = event.data.id
|
| 773 |
+
await self.on_run_step_created(event.data)
|
| 774 |
+
elif event.event == "thread.run.step.in_progress":
|
| 775 |
+
self.__current_run_step_id = event.data.id
|
| 776 |
+
elif event.event == "thread.run.step.delta":
|
| 777 |
+
step_snapshot = self.__run_step_snapshots[event.data.id]
|
| 778 |
+
|
| 779 |
+
run_step_delta = event.data.delta
|
| 780 |
+
if (
|
| 781 |
+
run_step_delta.step_details
|
| 782 |
+
and run_step_delta.step_details.type == "tool_calls"
|
| 783 |
+
and run_step_delta.step_details.tool_calls is not None
|
| 784 |
+
):
|
| 785 |
+
assert step_snapshot.step_details.type == "tool_calls"
|
| 786 |
+
for tool_call_delta in run_step_delta.step_details.tool_calls:
|
| 787 |
+
if tool_call_delta.index == self._current_tool_call_index:
|
| 788 |
+
await self.on_tool_call_delta(
|
| 789 |
+
tool_call_delta,
|
| 790 |
+
step_snapshot.step_details.tool_calls[tool_call_delta.index],
|
| 791 |
+
)
|
| 792 |
+
|
| 793 |
+
# If the delta is for a new tool call:
|
| 794 |
+
# - emit on_tool_call_done for the previous tool_call
|
| 795 |
+
# - emit on_tool_call_created for the new tool_call
|
| 796 |
+
if tool_call_delta.index != self._current_tool_call_index:
|
| 797 |
+
if self._current_tool_call is not None:
|
| 798 |
+
await self.on_tool_call_done(self._current_tool_call)
|
| 799 |
+
|
| 800 |
+
self._current_tool_call_index = tool_call_delta.index
|
| 801 |
+
self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]
|
| 802 |
+
await self.on_tool_call_created(self._current_tool_call)
|
| 803 |
+
|
| 804 |
+
# Update the current_tool_call (delta event is correctly emitted already)
|
| 805 |
+
self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]
|
| 806 |
+
|
| 807 |
+
await self.on_run_step_delta(
|
| 808 |
+
event.data.delta,
|
| 809 |
+
step_snapshot,
|
| 810 |
+
)
|
| 811 |
+
elif (
|
| 812 |
+
event.event == "thread.run.step.completed"
|
| 813 |
+
or event.event == "thread.run.step.cancelled"
|
| 814 |
+
or event.event == "thread.run.step.expired"
|
| 815 |
+
or event.event == "thread.run.step.failed"
|
| 816 |
+
):
|
| 817 |
+
if self._current_tool_call:
|
| 818 |
+
await self.on_tool_call_done(self._current_tool_call)
|
| 819 |
+
|
| 820 |
+
await self.on_run_step_done(event.data)
|
| 821 |
+
self.__current_run_step_id = None
|
| 822 |
+
elif event.event == "thread.created" or event.event == "thread.message.in_progress" or event.event == "error":
|
| 823 |
+
# currently no special handling
|
| 824 |
+
...
|
| 825 |
+
else:
|
| 826 |
+
# we only want to error at build-time
|
| 827 |
+
if TYPE_CHECKING: # type: ignore[unreachable]
|
| 828 |
+
assert_never(event)
|
| 829 |
+
|
| 830 |
+
self._current_event = None
|
| 831 |
+
|
| 832 |
+
async def __stream__(self) -> AsyncIterator[AssistantStreamEvent]:
|
| 833 |
+
stream = self.__stream
|
| 834 |
+
if not stream:
|
| 835 |
+
raise RuntimeError("Stream has not been started yet")
|
| 836 |
+
|
| 837 |
+
try:
|
| 838 |
+
async for event in stream:
|
| 839 |
+
await self._emit_sse_event(event)
|
| 840 |
+
|
| 841 |
+
yield event
|
| 842 |
+
except (httpx.TimeoutException, asyncio.TimeoutError) as exc:
|
| 843 |
+
await self.on_timeout()
|
| 844 |
+
await self.on_exception(exc)
|
| 845 |
+
raise
|
| 846 |
+
except Exception as exc:
|
| 847 |
+
await self.on_exception(exc)
|
| 848 |
+
raise
|
| 849 |
+
finally:
|
| 850 |
+
await self.on_end()
|
| 851 |
+
|
| 852 |
+
|
| 853 |
+
AsyncAssistantEventHandlerT = TypeVar("AsyncAssistantEventHandlerT", bound=AsyncAssistantEventHandler)
|
| 854 |
+
|
| 855 |
+
|
| 856 |
+
class AsyncAssistantStreamManager(Generic[AsyncAssistantEventHandlerT]):
|
| 857 |
+
"""Wrapper over AsyncAssistantStreamEventHandler that is returned by `.stream()`
|
| 858 |
+
so that an async context manager can be used without `await`ing the
|
| 859 |
+
original client call.
|
| 860 |
+
|
| 861 |
+
```py
|
| 862 |
+
async with client.threads.create_and_run_stream(...) as stream:
|
| 863 |
+
async for event in stream:
|
| 864 |
+
...
|
| 865 |
+
```
|
| 866 |
+
"""
|
| 867 |
+
|
| 868 |
+
def __init__(
|
| 869 |
+
self,
|
| 870 |
+
api_request: Awaitable[AsyncStream[AssistantStreamEvent]],
|
| 871 |
+
*,
|
| 872 |
+
event_handler: AsyncAssistantEventHandlerT,
|
| 873 |
+
) -> None:
|
| 874 |
+
self.__stream: AsyncStream[AssistantStreamEvent] | None = None
|
| 875 |
+
self.__event_handler = event_handler
|
| 876 |
+
self.__api_request = api_request
|
| 877 |
+
|
| 878 |
+
async def __aenter__(self) -> AsyncAssistantEventHandlerT:
|
| 879 |
+
self.__stream = await self.__api_request
|
| 880 |
+
self.__event_handler._init(self.__stream)
|
| 881 |
+
return self.__event_handler
|
| 882 |
+
|
| 883 |
+
async def __aexit__(
|
| 884 |
+
self,
|
| 885 |
+
exc_type: type[BaseException] | None,
|
| 886 |
+
exc: BaseException | None,
|
| 887 |
+
exc_tb: TracebackType | None,
|
| 888 |
+
) -> None:
|
| 889 |
+
if self.__stream is not None:
|
| 890 |
+
await self.__stream.close()
|
| 891 |
+
|
| 892 |
+
|
| 893 |
+
def accumulate_run_step(
|
| 894 |
+
*,
|
| 895 |
+
event: AssistantStreamEvent,
|
| 896 |
+
run_step_snapshots: dict[str, RunStep],
|
| 897 |
+
) -> None:
|
| 898 |
+
if event.event == "thread.run.step.created":
|
| 899 |
+
run_step_snapshots[event.data.id] = event.data
|
| 900 |
+
return
|
| 901 |
+
|
| 902 |
+
if event.event == "thread.run.step.delta":
|
| 903 |
+
data = event.data
|
| 904 |
+
snapshot = run_step_snapshots[data.id]
|
| 905 |
+
|
| 906 |
+
if data.delta:
|
| 907 |
+
merged = accumulate_delta(
|
| 908 |
+
cast(
|
| 909 |
+
"dict[object, object]",
|
| 910 |
+
model_dump(snapshot, exclude_unset=True, warnings=False),
|
| 911 |
+
),
|
| 912 |
+
cast(
|
| 913 |
+
"dict[object, object]",
|
| 914 |
+
model_dump(data.delta, exclude_unset=True, warnings=False),
|
| 915 |
+
),
|
| 916 |
+
)
|
| 917 |
+
run_step_snapshots[snapshot.id] = cast(RunStep, construct_type(type_=RunStep, value=merged))
|
| 918 |
+
|
| 919 |
+
return None
|
| 920 |
+
|
| 921 |
+
|
| 922 |
+
def accumulate_event(
|
| 923 |
+
*,
|
| 924 |
+
event: AssistantStreamEvent,
|
| 925 |
+
current_message_snapshot: Message | None,
|
| 926 |
+
) -> tuple[Message | None, list[MessageContentDelta]]:
|
| 927 |
+
"""Returns a tuple of message snapshot and newly created text message deltas"""
|
| 928 |
+
if event.event == "thread.message.created":
|
| 929 |
+
return event.data, []
|
| 930 |
+
|
| 931 |
+
new_content: list[MessageContentDelta] = []
|
| 932 |
+
|
| 933 |
+
if event.event != "thread.message.delta":
|
| 934 |
+
return current_message_snapshot, []
|
| 935 |
+
|
| 936 |
+
if not current_message_snapshot:
|
| 937 |
+
raise RuntimeError("Encountered a message delta with no previous snapshot")
|
| 938 |
+
|
| 939 |
+
data = event.data
|
| 940 |
+
if data.delta.content:
|
| 941 |
+
for content_delta in data.delta.content:
|
| 942 |
+
try:
|
| 943 |
+
block = current_message_snapshot.content[content_delta.index]
|
| 944 |
+
except IndexError:
|
| 945 |
+
current_message_snapshot.content.insert(
|
| 946 |
+
content_delta.index,
|
| 947 |
+
cast(
|
| 948 |
+
MessageContent,
|
| 949 |
+
construct_type(
|
| 950 |
+
# mypy doesn't allow Content for some reason
|
| 951 |
+
type_=cast(Any, MessageContent),
|
| 952 |
+
value=model_dump(content_delta, exclude_unset=True, warnings=False),
|
| 953 |
+
),
|
| 954 |
+
),
|
| 955 |
+
)
|
| 956 |
+
new_content.append(content_delta)
|
| 957 |
+
else:
|
| 958 |
+
merged = accumulate_delta(
|
| 959 |
+
cast(
|
| 960 |
+
"dict[object, object]",
|
| 961 |
+
model_dump(block, exclude_unset=True, warnings=False),
|
| 962 |
+
),
|
| 963 |
+
cast(
|
| 964 |
+
"dict[object, object]",
|
| 965 |
+
model_dump(content_delta, exclude_unset=True, warnings=False),
|
| 966 |
+
),
|
| 967 |
+
)
|
| 968 |
+
current_message_snapshot.content[content_delta.index] = cast(
|
| 969 |
+
MessageContent,
|
| 970 |
+
construct_type(
|
| 971 |
+
# mypy doesn't allow Content for some reason
|
| 972 |
+
type_=cast(Any, MessageContent),
|
| 973 |
+
value=merged,
|
| 974 |
+
),
|
| 975 |
+
)
|
| 976 |
+
|
| 977 |
+
return current_message_snapshot, new_content
|
| 978 |
+
|
| 979 |
+
|
| 980 |
+
def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> dict[object, object]:
|
| 981 |
+
for key, delta_value in delta.items():
|
| 982 |
+
if key not in acc:
|
| 983 |
+
acc[key] = delta_value
|
| 984 |
+
continue
|
| 985 |
+
|
| 986 |
+
acc_value = acc[key]
|
| 987 |
+
if acc_value is None:
|
| 988 |
+
acc[key] = delta_value
|
| 989 |
+
continue
|
| 990 |
+
|
| 991 |
+
# the `index` property is used in arrays of objects so it should
|
| 992 |
+
# not be accumulated like other values e.g.
|
| 993 |
+
# [{'foo': 'bar', 'index': 0}]
|
| 994 |
+
#
|
| 995 |
+
# the same applies to `type` properties as they're used for
|
| 996 |
+
# discriminated unions
|
| 997 |
+
if key == "index" or key == "type":
|
| 998 |
+
acc[key] = delta_value
|
| 999 |
+
continue
|
| 1000 |
+
|
| 1001 |
+
if isinstance(acc_value, str) and isinstance(delta_value, str):
|
| 1002 |
+
acc_value += delta_value
|
| 1003 |
+
elif isinstance(acc_value, (int, float)) and isinstance(delta_value, (int, float)):
|
| 1004 |
+
acc_value += delta_value
|
| 1005 |
+
elif is_dict(acc_value) and is_dict(delta_value):
|
| 1006 |
+
acc_value = accumulate_delta(acc_value, delta_value)
|
| 1007 |
+
elif is_list(acc_value) and is_list(delta_value):
|
| 1008 |
+
# for lists of non-dictionary items we'll only ever get new entries
|
| 1009 |
+
# in the array, existing entries will never be changed
|
| 1010 |
+
if all(isinstance(x, (str, int, float)) for x in acc_value):
|
| 1011 |
+
acc_value.extend(delta_value)
|
| 1012 |
+
continue
|
| 1013 |
+
|
| 1014 |
+
for delta_entry in delta_value:
|
| 1015 |
+
if not is_dict(delta_entry):
|
| 1016 |
+
raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}")
|
| 1017 |
+
|
| 1018 |
+
try:
|
| 1019 |
+
index = delta_entry["index"]
|
| 1020 |
+
except KeyError as exc:
|
| 1021 |
+
raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc
|
| 1022 |
+
|
| 1023 |
+
if not isinstance(index, int):
|
| 1024 |
+
raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}")
|
| 1025 |
+
|
| 1026 |
+
try:
|
| 1027 |
+
acc_entry = acc_value[index]
|
| 1028 |
+
except IndexError:
|
| 1029 |
+
acc_value.insert(index, delta_entry)
|
| 1030 |
+
else:
|
| 1031 |
+
if not is_dict(acc_entry):
|
| 1032 |
+
raise TypeError("not handled yet")
|
| 1033 |
+
|
| 1034 |
+
acc_value[index] = accumulate_delta(acc_entry, delta_entry)
|
| 1035 |
+
|
| 1036 |
+
acc[key] = acc_value
|
| 1037 |
+
|
| 1038 |
+
return acc
|
.venv/lib/python3.11/site-packages/openai/lib/streaming/_deltas.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from ..._utils import is_dict, is_list
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> dict[object, object]:
|
| 7 |
+
for key, delta_value in delta.items():
|
| 8 |
+
if key not in acc:
|
| 9 |
+
acc[key] = delta_value
|
| 10 |
+
continue
|
| 11 |
+
|
| 12 |
+
acc_value = acc[key]
|
| 13 |
+
if acc_value is None:
|
| 14 |
+
acc[key] = delta_value
|
| 15 |
+
continue
|
| 16 |
+
|
| 17 |
+
# the `index` property is used in arrays of objects so it should
|
| 18 |
+
# not be accumulated like other values e.g.
|
| 19 |
+
# [{'foo': 'bar', 'index': 0}]
|
| 20 |
+
#
|
| 21 |
+
# the same applies to `type` properties as they're used for
|
| 22 |
+
# discriminated unions
|
| 23 |
+
if key == "index" or key == "type":
|
| 24 |
+
acc[key] = delta_value
|
| 25 |
+
continue
|
| 26 |
+
|
| 27 |
+
if isinstance(acc_value, str) and isinstance(delta_value, str):
|
| 28 |
+
acc_value += delta_value
|
| 29 |
+
elif isinstance(acc_value, (int, float)) and isinstance(delta_value, (int, float)):
|
| 30 |
+
acc_value += delta_value
|
| 31 |
+
elif is_dict(acc_value) and is_dict(delta_value):
|
| 32 |
+
acc_value = accumulate_delta(acc_value, delta_value)
|
| 33 |
+
elif is_list(acc_value) and is_list(delta_value):
|
| 34 |
+
# for lists of non-dictionary items we'll only ever get new entries
|
| 35 |
+
# in the array, existing entries will never be changed
|
| 36 |
+
if all(isinstance(x, (str, int, float)) for x in acc_value):
|
| 37 |
+
acc_value.extend(delta_value)
|
| 38 |
+
continue
|
| 39 |
+
|
| 40 |
+
for delta_entry in delta_value:
|
| 41 |
+
if not is_dict(delta_entry):
|
| 42 |
+
raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}")
|
| 43 |
+
|
| 44 |
+
try:
|
| 45 |
+
index = delta_entry["index"]
|
| 46 |
+
except KeyError as exc:
|
| 47 |
+
raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc
|
| 48 |
+
|
| 49 |
+
if not isinstance(index, int):
|
| 50 |
+
raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}")
|
| 51 |
+
|
| 52 |
+
try:
|
| 53 |
+
acc_entry = acc_value[index]
|
| 54 |
+
except IndexError:
|
| 55 |
+
acc_value.insert(index, delta_entry)
|
| 56 |
+
else:
|
| 57 |
+
if not is_dict(acc_entry):
|
| 58 |
+
raise TypeError("not handled yet")
|
| 59 |
+
|
| 60 |
+
acc_value[index] = accumulate_delta(acc_entry, delta_entry)
|
| 61 |
+
|
| 62 |
+
acc[key] = acc_value
|
| 63 |
+
|
| 64 |
+
return acc
|
.venv/lib/python3.11/site-packages/openai/lib/streaming/chat/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ._types import (
|
| 2 |
+
ParsedChoiceSnapshot as ParsedChoiceSnapshot,
|
| 3 |
+
ParsedChatCompletionSnapshot as ParsedChatCompletionSnapshot,
|
| 4 |
+
ParsedChatCompletionMessageSnapshot as ParsedChatCompletionMessageSnapshot,
|
| 5 |
+
)
|
| 6 |
+
from ._events import (
|
| 7 |
+
ChunkEvent as ChunkEvent,
|
| 8 |
+
ContentDoneEvent as ContentDoneEvent,
|
| 9 |
+
RefusalDoneEvent as RefusalDoneEvent,
|
| 10 |
+
ContentDeltaEvent as ContentDeltaEvent,
|
| 11 |
+
RefusalDeltaEvent as RefusalDeltaEvent,
|
| 12 |
+
LogprobsContentDoneEvent as LogprobsContentDoneEvent,
|
| 13 |
+
LogprobsRefusalDoneEvent as LogprobsRefusalDoneEvent,
|
| 14 |
+
ChatCompletionStreamEvent as ChatCompletionStreamEvent,
|
| 15 |
+
LogprobsContentDeltaEvent as LogprobsContentDeltaEvent,
|
| 16 |
+
LogprobsRefusalDeltaEvent as LogprobsRefusalDeltaEvent,
|
| 17 |
+
ParsedChatCompletionSnapshot as ParsedChatCompletionSnapshot,
|
| 18 |
+
FunctionToolCallArgumentsDoneEvent as FunctionToolCallArgumentsDoneEvent,
|
| 19 |
+
FunctionToolCallArgumentsDeltaEvent as FunctionToolCallArgumentsDeltaEvent,
|
| 20 |
+
)
|
| 21 |
+
from ._completions import (
|
| 22 |
+
ChatCompletionStream as ChatCompletionStream,
|
| 23 |
+
AsyncChatCompletionStream as AsyncChatCompletionStream,
|
| 24 |
+
ChatCompletionStreamState as ChatCompletionStreamState,
|
| 25 |
+
ChatCompletionStreamManager as ChatCompletionStreamManager,
|
| 26 |
+
AsyncChatCompletionStreamManager as AsyncChatCompletionStreamManager,
|
| 27 |
+
)
|
.venv/lib/python3.11/site-packages/openai/lib/streaming/chat/__pycache__/_completions.cpython-311.pyc
ADDED
|
Binary file (31.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/openai/lib/streaming/chat/__pycache__/_events.cpython-311.pyc
ADDED
|
Binary file (5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/openai/lib/streaming/chat/_completions.py
ADDED
|
@@ -0,0 +1,755 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import inspect
|
| 4 |
+
from types import TracebackType
|
| 5 |
+
from typing import TYPE_CHECKING, Any, Generic, Callable, Iterable, Awaitable, AsyncIterator, cast
|
| 6 |
+
from typing_extensions import Self, Iterator, assert_never
|
| 7 |
+
|
| 8 |
+
from jiter import from_json
|
| 9 |
+
|
| 10 |
+
from ._types import ParsedChoiceSnapshot, ParsedChatCompletionSnapshot, ParsedChatCompletionMessageSnapshot
|
| 11 |
+
from ._events import (
|
| 12 |
+
ChunkEvent,
|
| 13 |
+
ContentDoneEvent,
|
| 14 |
+
RefusalDoneEvent,
|
| 15 |
+
ContentDeltaEvent,
|
| 16 |
+
RefusalDeltaEvent,
|
| 17 |
+
LogprobsContentDoneEvent,
|
| 18 |
+
LogprobsRefusalDoneEvent,
|
| 19 |
+
ChatCompletionStreamEvent,
|
| 20 |
+
LogprobsContentDeltaEvent,
|
| 21 |
+
LogprobsRefusalDeltaEvent,
|
| 22 |
+
FunctionToolCallArgumentsDoneEvent,
|
| 23 |
+
FunctionToolCallArgumentsDeltaEvent,
|
| 24 |
+
)
|
| 25 |
+
from .._deltas import accumulate_delta
|
| 26 |
+
from ...._types import NOT_GIVEN, IncEx, NotGiven
|
| 27 |
+
from ...._utils import is_given, consume_sync_iterator, consume_async_iterator
|
| 28 |
+
from ...._compat import model_dump
|
| 29 |
+
from ...._models import build, construct_type
|
| 30 |
+
from ..._parsing import (
|
| 31 |
+
ResponseFormatT,
|
| 32 |
+
has_parseable_input,
|
| 33 |
+
maybe_parse_content,
|
| 34 |
+
parse_chat_completion,
|
| 35 |
+
get_input_tool_by_name,
|
| 36 |
+
solve_response_format_t,
|
| 37 |
+
parse_function_tool_arguments,
|
| 38 |
+
)
|
| 39 |
+
from ...._streaming import Stream, AsyncStream
|
| 40 |
+
from ....types.chat import ChatCompletionChunk, ParsedChatCompletion, ChatCompletionToolParam
|
| 41 |
+
from ...._exceptions import LengthFinishReasonError, ContentFilterFinishReasonError
|
| 42 |
+
from ....types.chat.chat_completion import ChoiceLogprobs
|
| 43 |
+
from ....types.chat.chat_completion_chunk import Choice as ChoiceChunk
|
| 44 |
+
from ....types.chat.completion_create_params import ResponseFormat as ResponseFormatParam
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class ChatCompletionStream(Generic[ResponseFormatT]):
|
| 48 |
+
"""Wrapper over the Chat Completions streaming API that adds helpful
|
| 49 |
+
events such as `content.done`, supports automatically parsing
|
| 50 |
+
responses & tool calls and accumulates a `ChatCompletion` object
|
| 51 |
+
from each individual chunk.
|
| 52 |
+
|
| 53 |
+
https://platform.openai.com/docs/api-reference/streaming
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
def __init__(
|
| 57 |
+
self,
|
| 58 |
+
*,
|
| 59 |
+
raw_stream: Stream[ChatCompletionChunk],
|
| 60 |
+
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
|
| 61 |
+
input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
|
| 62 |
+
) -> None:
|
| 63 |
+
self._raw_stream = raw_stream
|
| 64 |
+
self._response = raw_stream.response
|
| 65 |
+
self._iterator = self.__stream__()
|
| 66 |
+
self._state = ChatCompletionStreamState(response_format=response_format, input_tools=input_tools)
|
| 67 |
+
|
| 68 |
+
def __next__(self) -> ChatCompletionStreamEvent[ResponseFormatT]:
|
| 69 |
+
return self._iterator.__next__()
|
| 70 |
+
|
| 71 |
+
def __iter__(self) -> Iterator[ChatCompletionStreamEvent[ResponseFormatT]]:
|
| 72 |
+
for item in self._iterator:
|
| 73 |
+
yield item
|
| 74 |
+
|
| 75 |
+
def __enter__(self) -> Self:
|
| 76 |
+
return self
|
| 77 |
+
|
| 78 |
+
def __exit__(
|
| 79 |
+
self,
|
| 80 |
+
exc_type: type[BaseException] | None,
|
| 81 |
+
exc: BaseException | None,
|
| 82 |
+
exc_tb: TracebackType | None,
|
| 83 |
+
) -> None:
|
| 84 |
+
self.close()
|
| 85 |
+
|
| 86 |
+
def close(self) -> None:
|
| 87 |
+
"""
|
| 88 |
+
Close the response and release the connection.
|
| 89 |
+
|
| 90 |
+
Automatically called if the response body is read to completion.
|
| 91 |
+
"""
|
| 92 |
+
self._response.close()
|
| 93 |
+
|
| 94 |
+
def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
|
| 95 |
+
"""Waits until the stream has been read to completion and returns
|
| 96 |
+
the accumulated `ParsedChatCompletion` object.
|
| 97 |
+
|
| 98 |
+
If you passed a class type to `.stream()`, the `completion.choices[0].message.parsed`
|
| 99 |
+
property will be the content deserialised into that class, if there was any content returned
|
| 100 |
+
by the API.
|
| 101 |
+
"""
|
| 102 |
+
self.until_done()
|
| 103 |
+
return self._state.get_final_completion()
|
| 104 |
+
|
| 105 |
+
def until_done(self) -> Self:
|
| 106 |
+
"""Blocks until the stream has been consumed."""
|
| 107 |
+
consume_sync_iterator(self)
|
| 108 |
+
return self
|
| 109 |
+
|
| 110 |
+
@property
|
| 111 |
+
def current_completion_snapshot(self) -> ParsedChatCompletionSnapshot:
|
| 112 |
+
return self._state.current_completion_snapshot
|
| 113 |
+
|
| 114 |
+
def __stream__(self) -> Iterator[ChatCompletionStreamEvent[ResponseFormatT]]:
|
| 115 |
+
for sse_event in self._raw_stream:
|
| 116 |
+
events_to_fire = self._state.handle_chunk(sse_event)
|
| 117 |
+
for event in events_to_fire:
|
| 118 |
+
yield event
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class ChatCompletionStreamManager(Generic[ResponseFormatT]):
|
| 122 |
+
"""Context manager over a `ChatCompletionStream` that is returned by `.stream()`.
|
| 123 |
+
|
| 124 |
+
This context manager ensures the response cannot be leaked if you don't read
|
| 125 |
+
the stream to completion.
|
| 126 |
+
|
| 127 |
+
Usage:
|
| 128 |
+
```py
|
| 129 |
+
with client.beta.chat.completions.stream(...) as stream:
|
| 130 |
+
for event in stream:
|
| 131 |
+
...
|
| 132 |
+
```
|
| 133 |
+
"""
|
| 134 |
+
|
| 135 |
+
def __init__(
|
| 136 |
+
self,
|
| 137 |
+
api_request: Callable[[], Stream[ChatCompletionChunk]],
|
| 138 |
+
*,
|
| 139 |
+
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
|
| 140 |
+
input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
|
| 141 |
+
) -> None:
|
| 142 |
+
self.__stream: ChatCompletionStream[ResponseFormatT] | None = None
|
| 143 |
+
self.__api_request = api_request
|
| 144 |
+
self.__response_format = response_format
|
| 145 |
+
self.__input_tools = input_tools
|
| 146 |
+
|
| 147 |
+
def __enter__(self) -> ChatCompletionStream[ResponseFormatT]:
|
| 148 |
+
raw_stream = self.__api_request()
|
| 149 |
+
|
| 150 |
+
self.__stream = ChatCompletionStream(
|
| 151 |
+
raw_stream=raw_stream,
|
| 152 |
+
response_format=self.__response_format,
|
| 153 |
+
input_tools=self.__input_tools,
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
return self.__stream
|
| 157 |
+
|
| 158 |
+
def __exit__(
|
| 159 |
+
self,
|
| 160 |
+
exc_type: type[BaseException] | None,
|
| 161 |
+
exc: BaseException | None,
|
| 162 |
+
exc_tb: TracebackType | None,
|
| 163 |
+
) -> None:
|
| 164 |
+
if self.__stream is not None:
|
| 165 |
+
self.__stream.close()
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class AsyncChatCompletionStream(Generic[ResponseFormatT]):
|
| 169 |
+
"""Wrapper over the Chat Completions streaming API that adds helpful
|
| 170 |
+
events such as `content.done`, supports automatically parsing
|
| 171 |
+
responses & tool calls and accumulates a `ChatCompletion` object
|
| 172 |
+
from each individual chunk.
|
| 173 |
+
|
| 174 |
+
https://platform.openai.com/docs/api-reference/streaming
|
| 175 |
+
"""
|
| 176 |
+
|
| 177 |
+
def __init__(
|
| 178 |
+
self,
|
| 179 |
+
*,
|
| 180 |
+
raw_stream: AsyncStream[ChatCompletionChunk],
|
| 181 |
+
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
|
| 182 |
+
input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
|
| 183 |
+
) -> None:
|
| 184 |
+
self._raw_stream = raw_stream
|
| 185 |
+
self._response = raw_stream.response
|
| 186 |
+
self._iterator = self.__stream__()
|
| 187 |
+
self._state = ChatCompletionStreamState(response_format=response_format, input_tools=input_tools)
|
| 188 |
+
|
| 189 |
+
async def __anext__(self) -> ChatCompletionStreamEvent[ResponseFormatT]:
|
| 190 |
+
return await self._iterator.__anext__()
|
| 191 |
+
|
| 192 |
+
async def __aiter__(self) -> AsyncIterator[ChatCompletionStreamEvent[ResponseFormatT]]:
|
| 193 |
+
async for item in self._iterator:
|
| 194 |
+
yield item
|
| 195 |
+
|
| 196 |
+
async def __aenter__(self) -> Self:
|
| 197 |
+
return self
|
| 198 |
+
|
| 199 |
+
async def __aexit__(
|
| 200 |
+
self,
|
| 201 |
+
exc_type: type[BaseException] | None,
|
| 202 |
+
exc: BaseException | None,
|
| 203 |
+
exc_tb: TracebackType | None,
|
| 204 |
+
) -> None:
|
| 205 |
+
await self.close()
|
| 206 |
+
|
| 207 |
+
async def close(self) -> None:
|
| 208 |
+
"""
|
| 209 |
+
Close the response and release the connection.
|
| 210 |
+
|
| 211 |
+
Automatically called if the response body is read to completion.
|
| 212 |
+
"""
|
| 213 |
+
await self._response.aclose()
|
| 214 |
+
|
| 215 |
+
async def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
|
| 216 |
+
"""Waits until the stream has been read to completion and returns
|
| 217 |
+
the accumulated `ParsedChatCompletion` object.
|
| 218 |
+
|
| 219 |
+
If you passed a class type to `.stream()`, the `completion.choices[0].message.parsed`
|
| 220 |
+
property will be the content deserialised into that class, if there was any content returned
|
| 221 |
+
by the API.
|
| 222 |
+
"""
|
| 223 |
+
await self.until_done()
|
| 224 |
+
return self._state.get_final_completion()
|
| 225 |
+
|
| 226 |
+
async def until_done(self) -> Self:
|
| 227 |
+
"""Blocks until the stream has been consumed."""
|
| 228 |
+
await consume_async_iterator(self)
|
| 229 |
+
return self
|
| 230 |
+
|
| 231 |
+
@property
|
| 232 |
+
def current_completion_snapshot(self) -> ParsedChatCompletionSnapshot:
|
| 233 |
+
return self._state.current_completion_snapshot
|
| 234 |
+
|
| 235 |
+
async def __stream__(self) -> AsyncIterator[ChatCompletionStreamEvent[ResponseFormatT]]:
|
| 236 |
+
async for sse_event in self._raw_stream:
|
| 237 |
+
events_to_fire = self._state.handle_chunk(sse_event)
|
| 238 |
+
for event in events_to_fire:
|
| 239 |
+
yield event
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class AsyncChatCompletionStreamManager(Generic[ResponseFormatT]):
|
| 243 |
+
"""Context manager over a `AsyncChatCompletionStream` that is returned by `.stream()`.
|
| 244 |
+
|
| 245 |
+
This context manager ensures the response cannot be leaked if you don't read
|
| 246 |
+
the stream to completion.
|
| 247 |
+
|
| 248 |
+
Usage:
|
| 249 |
+
```py
|
| 250 |
+
async with client.beta.chat.completions.stream(...) as stream:
|
| 251 |
+
for event in stream:
|
| 252 |
+
...
|
| 253 |
+
```
|
| 254 |
+
"""
|
| 255 |
+
|
| 256 |
+
def __init__(
|
| 257 |
+
self,
|
| 258 |
+
api_request: Awaitable[AsyncStream[ChatCompletionChunk]],
|
| 259 |
+
*,
|
| 260 |
+
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
|
| 261 |
+
input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
|
| 262 |
+
) -> None:
|
| 263 |
+
self.__stream: AsyncChatCompletionStream[ResponseFormatT] | None = None
|
| 264 |
+
self.__api_request = api_request
|
| 265 |
+
self.__response_format = response_format
|
| 266 |
+
self.__input_tools = input_tools
|
| 267 |
+
|
| 268 |
+
async def __aenter__(self) -> AsyncChatCompletionStream[ResponseFormatT]:
|
| 269 |
+
raw_stream = await self.__api_request
|
| 270 |
+
|
| 271 |
+
self.__stream = AsyncChatCompletionStream(
|
| 272 |
+
raw_stream=raw_stream,
|
| 273 |
+
response_format=self.__response_format,
|
| 274 |
+
input_tools=self.__input_tools,
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
return self.__stream
|
| 278 |
+
|
| 279 |
+
async def __aexit__(
|
| 280 |
+
self,
|
| 281 |
+
exc_type: type[BaseException] | None,
|
| 282 |
+
exc: BaseException | None,
|
| 283 |
+
exc_tb: TracebackType | None,
|
| 284 |
+
) -> None:
|
| 285 |
+
if self.__stream is not None:
|
| 286 |
+
await self.__stream.close()
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
class ChatCompletionStreamState(Generic[ResponseFormatT]):
|
| 290 |
+
"""Helper class for manually accumulating `ChatCompletionChunk`s into a final `ChatCompletion` object.
|
| 291 |
+
|
| 292 |
+
This is useful in cases where you can't always use the `.stream()` method, e.g.
|
| 293 |
+
|
| 294 |
+
```py
|
| 295 |
+
from openai.lib.streaming.chat import ChatCompletionStreamState
|
| 296 |
+
|
| 297 |
+
state = ChatCompletionStreamState()
|
| 298 |
+
|
| 299 |
+
stream = client.chat.completions.create(..., stream=True)
|
| 300 |
+
for chunk in response:
|
| 301 |
+
state.handle_chunk(chunk)
|
| 302 |
+
|
| 303 |
+
# can also access the accumulated `ChatCompletion` mid-stream
|
| 304 |
+
state.current_completion_snapshot
|
| 305 |
+
|
| 306 |
+
print(state.get_final_completion())
|
| 307 |
+
```
|
| 308 |
+
"""
|
| 309 |
+
|
| 310 |
+
def __init__(
|
| 311 |
+
self,
|
| 312 |
+
*,
|
| 313 |
+
input_tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
|
| 314 |
+
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven = NOT_GIVEN,
|
| 315 |
+
) -> None:
|
| 316 |
+
self.__current_completion_snapshot: ParsedChatCompletionSnapshot | None = None
|
| 317 |
+
self.__choice_event_states: list[ChoiceEventState] = []
|
| 318 |
+
|
| 319 |
+
self._input_tools = [tool for tool in input_tools] if is_given(input_tools) else []
|
| 320 |
+
self._response_format = response_format
|
| 321 |
+
self._rich_response_format: type | NotGiven = response_format if inspect.isclass(response_format) else NOT_GIVEN
|
| 322 |
+
|
| 323 |
+
def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
|
| 324 |
+
"""Parse the final completion object.
|
| 325 |
+
|
| 326 |
+
Note this does not provide any guarantees that the stream has actually finished, you must
|
| 327 |
+
only call this method when the stream is finished.
|
| 328 |
+
"""
|
| 329 |
+
return parse_chat_completion(
|
| 330 |
+
chat_completion=self.current_completion_snapshot,
|
| 331 |
+
response_format=self._rich_response_format,
|
| 332 |
+
input_tools=self._input_tools,
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
@property
|
| 336 |
+
def current_completion_snapshot(self) -> ParsedChatCompletionSnapshot:
|
| 337 |
+
assert self.__current_completion_snapshot is not None
|
| 338 |
+
return self.__current_completion_snapshot
|
| 339 |
+
|
| 340 |
+
def handle_chunk(self, chunk: ChatCompletionChunk) -> Iterable[ChatCompletionStreamEvent[ResponseFormatT]]:
|
| 341 |
+
"""Accumulate a new chunk into the snapshot and returns an iterable of events to yield."""
|
| 342 |
+
self.__current_completion_snapshot = self._accumulate_chunk(chunk)
|
| 343 |
+
|
| 344 |
+
return self._build_events(
|
| 345 |
+
chunk=chunk,
|
| 346 |
+
completion_snapshot=self.__current_completion_snapshot,
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
def _get_choice_state(self, choice: ChoiceChunk) -> ChoiceEventState:
|
| 350 |
+
try:
|
| 351 |
+
return self.__choice_event_states[choice.index]
|
| 352 |
+
except IndexError:
|
| 353 |
+
choice_state = ChoiceEventState(input_tools=self._input_tools)
|
| 354 |
+
self.__choice_event_states.append(choice_state)
|
| 355 |
+
return choice_state
|
| 356 |
+
|
| 357 |
+
def _accumulate_chunk(self, chunk: ChatCompletionChunk) -> ParsedChatCompletionSnapshot:
|
| 358 |
+
completion_snapshot = self.__current_completion_snapshot
|
| 359 |
+
|
| 360 |
+
if completion_snapshot is None:
|
| 361 |
+
return _convert_initial_chunk_into_snapshot(chunk)
|
| 362 |
+
|
| 363 |
+
for choice in chunk.choices:
|
| 364 |
+
try:
|
| 365 |
+
choice_snapshot = completion_snapshot.choices[choice.index]
|
| 366 |
+
previous_tool_calls = choice_snapshot.message.tool_calls or []
|
| 367 |
+
|
| 368 |
+
choice_snapshot.message = cast(
|
| 369 |
+
ParsedChatCompletionMessageSnapshot,
|
| 370 |
+
construct_type(
|
| 371 |
+
type_=ParsedChatCompletionMessageSnapshot,
|
| 372 |
+
value=accumulate_delta(
|
| 373 |
+
cast(
|
| 374 |
+
"dict[object, object]",
|
| 375 |
+
model_dump(
|
| 376 |
+
choice_snapshot.message,
|
| 377 |
+
# we don't want to serialise / deserialise our custom properties
|
| 378 |
+
# as they won't appear in the delta and we don't want to have to
|
| 379 |
+
# continuosly reparse the content
|
| 380 |
+
exclude=cast(
|
| 381 |
+
# cast required as mypy isn't smart enough to infer `True` here to `Literal[True]`
|
| 382 |
+
IncEx,
|
| 383 |
+
{
|
| 384 |
+
"parsed": True,
|
| 385 |
+
"tool_calls": {
|
| 386 |
+
idx: {"function": {"parsed_arguments": True}}
|
| 387 |
+
for idx, _ in enumerate(choice_snapshot.message.tool_calls or [])
|
| 388 |
+
},
|
| 389 |
+
},
|
| 390 |
+
),
|
| 391 |
+
),
|
| 392 |
+
),
|
| 393 |
+
cast("dict[object, object]", choice.delta.to_dict()),
|
| 394 |
+
),
|
| 395 |
+
),
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
# ensure tools that have already been parsed are added back into the newly
|
| 399 |
+
# constructed message snapshot
|
| 400 |
+
for tool_index, prev_tool in enumerate(previous_tool_calls):
|
| 401 |
+
new_tool = (choice_snapshot.message.tool_calls or [])[tool_index]
|
| 402 |
+
|
| 403 |
+
if prev_tool.type == "function":
|
| 404 |
+
assert new_tool.type == "function"
|
| 405 |
+
new_tool.function.parsed_arguments = prev_tool.function.parsed_arguments
|
| 406 |
+
elif TYPE_CHECKING: # type: ignore[unreachable]
|
| 407 |
+
assert_never(prev_tool)
|
| 408 |
+
except IndexError:
|
| 409 |
+
choice_snapshot = cast(
|
| 410 |
+
ParsedChoiceSnapshot,
|
| 411 |
+
construct_type(
|
| 412 |
+
type_=ParsedChoiceSnapshot,
|
| 413 |
+
value={
|
| 414 |
+
**choice.model_dump(exclude_unset=True, exclude={"delta"}),
|
| 415 |
+
"message": choice.delta.to_dict(),
|
| 416 |
+
},
|
| 417 |
+
),
|
| 418 |
+
)
|
| 419 |
+
completion_snapshot.choices.append(choice_snapshot)
|
| 420 |
+
|
| 421 |
+
if choice.finish_reason:
|
| 422 |
+
choice_snapshot.finish_reason = choice.finish_reason
|
| 423 |
+
|
| 424 |
+
if has_parseable_input(response_format=self._response_format, input_tools=self._input_tools):
|
| 425 |
+
if choice.finish_reason == "length":
|
| 426 |
+
# at the time of writing, `.usage` will always be `None` but
|
| 427 |
+
# we include it here in case that is changed in the future
|
| 428 |
+
raise LengthFinishReasonError(completion=completion_snapshot)
|
| 429 |
+
|
| 430 |
+
if choice.finish_reason == "content_filter":
|
| 431 |
+
raise ContentFilterFinishReasonError()
|
| 432 |
+
|
| 433 |
+
if (
|
| 434 |
+
choice_snapshot.message.content
|
| 435 |
+
and not choice_snapshot.message.refusal
|
| 436 |
+
and is_given(self._rich_response_format)
|
| 437 |
+
):
|
| 438 |
+
choice_snapshot.message.parsed = from_json(
|
| 439 |
+
bytes(choice_snapshot.message.content, "utf-8"),
|
| 440 |
+
partial_mode=True,
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
for tool_call_chunk in choice.delta.tool_calls or []:
|
| 444 |
+
tool_call_snapshot = (choice_snapshot.message.tool_calls or [])[tool_call_chunk.index]
|
| 445 |
+
|
| 446 |
+
if tool_call_snapshot.type == "function":
|
| 447 |
+
input_tool = get_input_tool_by_name(
|
| 448 |
+
input_tools=self._input_tools, name=tool_call_snapshot.function.name
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
if (
|
| 452 |
+
input_tool
|
| 453 |
+
and input_tool.get("function", {}).get("strict")
|
| 454 |
+
and tool_call_snapshot.function.arguments
|
| 455 |
+
):
|
| 456 |
+
tool_call_snapshot.function.parsed_arguments = from_json(
|
| 457 |
+
bytes(tool_call_snapshot.function.arguments, "utf-8"),
|
| 458 |
+
partial_mode=True,
|
| 459 |
+
)
|
| 460 |
+
elif TYPE_CHECKING: # type: ignore[unreachable]
|
| 461 |
+
assert_never(tool_call_snapshot)
|
| 462 |
+
|
| 463 |
+
if choice.logprobs is not None:
|
| 464 |
+
if choice_snapshot.logprobs is None:
|
| 465 |
+
choice_snapshot.logprobs = build(
|
| 466 |
+
ChoiceLogprobs,
|
| 467 |
+
content=choice.logprobs.content,
|
| 468 |
+
refusal=choice.logprobs.refusal,
|
| 469 |
+
)
|
| 470 |
+
else:
|
| 471 |
+
if choice.logprobs.content:
|
| 472 |
+
if choice_snapshot.logprobs.content is None:
|
| 473 |
+
choice_snapshot.logprobs.content = []
|
| 474 |
+
|
| 475 |
+
choice_snapshot.logprobs.content.extend(choice.logprobs.content)
|
| 476 |
+
|
| 477 |
+
if choice.logprobs.refusal:
|
| 478 |
+
if choice_snapshot.logprobs.refusal is None:
|
| 479 |
+
choice_snapshot.logprobs.refusal = []
|
| 480 |
+
|
| 481 |
+
choice_snapshot.logprobs.refusal.extend(choice.logprobs.refusal)
|
| 482 |
+
|
| 483 |
+
completion_snapshot.usage = chunk.usage
|
| 484 |
+
completion_snapshot.system_fingerprint = chunk.system_fingerprint
|
| 485 |
+
|
| 486 |
+
return completion_snapshot
|
| 487 |
+
|
| 488 |
+
def _build_events(
|
| 489 |
+
self,
|
| 490 |
+
*,
|
| 491 |
+
chunk: ChatCompletionChunk,
|
| 492 |
+
completion_snapshot: ParsedChatCompletionSnapshot,
|
| 493 |
+
) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
|
| 494 |
+
events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]] = []
|
| 495 |
+
|
| 496 |
+
events_to_fire.append(
|
| 497 |
+
build(ChunkEvent, type="chunk", chunk=chunk, snapshot=completion_snapshot),
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
for choice in chunk.choices:
|
| 501 |
+
choice_state = self._get_choice_state(choice)
|
| 502 |
+
choice_snapshot = completion_snapshot.choices[choice.index]
|
| 503 |
+
|
| 504 |
+
if choice.delta.content is not None and choice_snapshot.message.content is not None:
|
| 505 |
+
events_to_fire.append(
|
| 506 |
+
build(
|
| 507 |
+
ContentDeltaEvent,
|
| 508 |
+
type="content.delta",
|
| 509 |
+
delta=choice.delta.content,
|
| 510 |
+
snapshot=choice_snapshot.message.content,
|
| 511 |
+
parsed=choice_snapshot.message.parsed,
|
| 512 |
+
)
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
if choice.delta.refusal is not None and choice_snapshot.message.refusal is not None:
|
| 516 |
+
events_to_fire.append(
|
| 517 |
+
build(
|
| 518 |
+
RefusalDeltaEvent,
|
| 519 |
+
type="refusal.delta",
|
| 520 |
+
delta=choice.delta.refusal,
|
| 521 |
+
snapshot=choice_snapshot.message.refusal,
|
| 522 |
+
)
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
if choice.delta.tool_calls:
|
| 526 |
+
tool_calls = choice_snapshot.message.tool_calls
|
| 527 |
+
assert tool_calls is not None
|
| 528 |
+
|
| 529 |
+
for tool_call_delta in choice.delta.tool_calls:
|
| 530 |
+
tool_call = tool_calls[tool_call_delta.index]
|
| 531 |
+
|
| 532 |
+
if tool_call.type == "function":
|
| 533 |
+
assert tool_call_delta.function is not None
|
| 534 |
+
events_to_fire.append(
|
| 535 |
+
build(
|
| 536 |
+
FunctionToolCallArgumentsDeltaEvent,
|
| 537 |
+
type="tool_calls.function.arguments.delta",
|
| 538 |
+
name=tool_call.function.name,
|
| 539 |
+
index=tool_call_delta.index,
|
| 540 |
+
arguments=tool_call.function.arguments,
|
| 541 |
+
parsed_arguments=tool_call.function.parsed_arguments,
|
| 542 |
+
arguments_delta=tool_call_delta.function.arguments or "",
|
| 543 |
+
)
|
| 544 |
+
)
|
| 545 |
+
elif TYPE_CHECKING: # type: ignore[unreachable]
|
| 546 |
+
assert_never(tool_call)
|
| 547 |
+
|
| 548 |
+
if choice.logprobs is not None and choice_snapshot.logprobs is not None:
|
| 549 |
+
if choice.logprobs.content and choice_snapshot.logprobs.content:
|
| 550 |
+
events_to_fire.append(
|
| 551 |
+
build(
|
| 552 |
+
LogprobsContentDeltaEvent,
|
| 553 |
+
type="logprobs.content.delta",
|
| 554 |
+
content=choice.logprobs.content,
|
| 555 |
+
snapshot=choice_snapshot.logprobs.content,
|
| 556 |
+
),
|
| 557 |
+
)
|
| 558 |
+
|
| 559 |
+
if choice.logprobs.refusal and choice_snapshot.logprobs.refusal:
|
| 560 |
+
events_to_fire.append(
|
| 561 |
+
build(
|
| 562 |
+
LogprobsRefusalDeltaEvent,
|
| 563 |
+
type="logprobs.refusal.delta",
|
| 564 |
+
refusal=choice.logprobs.refusal,
|
| 565 |
+
snapshot=choice_snapshot.logprobs.refusal,
|
| 566 |
+
),
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
+
events_to_fire.extend(
|
| 570 |
+
choice_state.get_done_events(
|
| 571 |
+
choice_chunk=choice,
|
| 572 |
+
choice_snapshot=choice_snapshot,
|
| 573 |
+
response_format=self._response_format,
|
| 574 |
+
)
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
+
return events_to_fire
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
class ChoiceEventState:
|
| 581 |
+
def __init__(self, *, input_tools: list[ChatCompletionToolParam]) -> None:
|
| 582 |
+
self._input_tools = input_tools
|
| 583 |
+
|
| 584 |
+
self._content_done = False
|
| 585 |
+
self._refusal_done = False
|
| 586 |
+
self._logprobs_content_done = False
|
| 587 |
+
self._logprobs_refusal_done = False
|
| 588 |
+
self._done_tool_calls: set[int] = set()
|
| 589 |
+
self.__current_tool_call_index: int | None = None
|
| 590 |
+
|
| 591 |
+
def get_done_events(
|
| 592 |
+
self,
|
| 593 |
+
*,
|
| 594 |
+
choice_chunk: ChoiceChunk,
|
| 595 |
+
choice_snapshot: ParsedChoiceSnapshot,
|
| 596 |
+
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
|
| 597 |
+
) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
|
| 598 |
+
events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]] = []
|
| 599 |
+
|
| 600 |
+
if choice_snapshot.finish_reason:
|
| 601 |
+
events_to_fire.extend(
|
| 602 |
+
self._content_done_events(choice_snapshot=choice_snapshot, response_format=response_format)
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
if (
|
| 606 |
+
self.__current_tool_call_index is not None
|
| 607 |
+
and self.__current_tool_call_index not in self._done_tool_calls
|
| 608 |
+
):
|
| 609 |
+
self._add_tool_done_event(
|
| 610 |
+
events_to_fire=events_to_fire,
|
| 611 |
+
choice_snapshot=choice_snapshot,
|
| 612 |
+
tool_index=self.__current_tool_call_index,
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
for tool_call in choice_chunk.delta.tool_calls or []:
|
| 616 |
+
if self.__current_tool_call_index != tool_call.index:
|
| 617 |
+
events_to_fire.extend(
|
| 618 |
+
self._content_done_events(choice_snapshot=choice_snapshot, response_format=response_format)
|
| 619 |
+
)
|
| 620 |
+
|
| 621 |
+
if self.__current_tool_call_index is not None:
|
| 622 |
+
self._add_tool_done_event(
|
| 623 |
+
events_to_fire=events_to_fire,
|
| 624 |
+
choice_snapshot=choice_snapshot,
|
| 625 |
+
tool_index=self.__current_tool_call_index,
|
| 626 |
+
)
|
| 627 |
+
|
| 628 |
+
self.__current_tool_call_index = tool_call.index
|
| 629 |
+
|
| 630 |
+
return events_to_fire
|
| 631 |
+
|
| 632 |
+
def _content_done_events(
|
| 633 |
+
self,
|
| 634 |
+
*,
|
| 635 |
+
choice_snapshot: ParsedChoiceSnapshot,
|
| 636 |
+
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
|
| 637 |
+
) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
|
| 638 |
+
events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]] = []
|
| 639 |
+
|
| 640 |
+
if choice_snapshot.message.content and not self._content_done:
|
| 641 |
+
self._content_done = True
|
| 642 |
+
|
| 643 |
+
parsed = maybe_parse_content(
|
| 644 |
+
response_format=response_format,
|
| 645 |
+
message=choice_snapshot.message,
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
# update the parsed content to now use the richer `response_format`
|
| 649 |
+
# as opposed to the raw JSON-parsed object as the content is now
|
| 650 |
+
# complete and can be fully validated.
|
| 651 |
+
choice_snapshot.message.parsed = parsed
|
| 652 |
+
|
| 653 |
+
events_to_fire.append(
|
| 654 |
+
build(
|
| 655 |
+
# we do this dance so that when the `ContentDoneEvent` instance
|
| 656 |
+
# is printed at runtime the class name will include the solved
|
| 657 |
+
# type variable, e.g. `ContentDoneEvent[MyModelType]`
|
| 658 |
+
cast( # pyright: ignore[reportUnnecessaryCast]
|
| 659 |
+
"type[ContentDoneEvent[ResponseFormatT]]",
|
| 660 |
+
cast(Any, ContentDoneEvent)[solve_response_format_t(response_format)],
|
| 661 |
+
),
|
| 662 |
+
type="content.done",
|
| 663 |
+
content=choice_snapshot.message.content,
|
| 664 |
+
parsed=parsed,
|
| 665 |
+
),
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
if choice_snapshot.message.refusal is not None and not self._refusal_done:
|
| 669 |
+
self._refusal_done = True
|
| 670 |
+
events_to_fire.append(
|
| 671 |
+
build(RefusalDoneEvent, type="refusal.done", refusal=choice_snapshot.message.refusal),
|
| 672 |
+
)
|
| 673 |
+
|
| 674 |
+
if (
|
| 675 |
+
choice_snapshot.logprobs is not None
|
| 676 |
+
and choice_snapshot.logprobs.content is not None
|
| 677 |
+
and not self._logprobs_content_done
|
| 678 |
+
):
|
| 679 |
+
self._logprobs_content_done = True
|
| 680 |
+
events_to_fire.append(
|
| 681 |
+
build(LogprobsContentDoneEvent, type="logprobs.content.done", content=choice_snapshot.logprobs.content),
|
| 682 |
+
)
|
| 683 |
+
|
| 684 |
+
if (
|
| 685 |
+
choice_snapshot.logprobs is not None
|
| 686 |
+
and choice_snapshot.logprobs.refusal is not None
|
| 687 |
+
and not self._logprobs_refusal_done
|
| 688 |
+
):
|
| 689 |
+
self._logprobs_refusal_done = True
|
| 690 |
+
events_to_fire.append(
|
| 691 |
+
build(LogprobsRefusalDoneEvent, type="logprobs.refusal.done", refusal=choice_snapshot.logprobs.refusal),
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
return events_to_fire
|
| 695 |
+
|
| 696 |
+
def _add_tool_done_event(
|
| 697 |
+
self,
|
| 698 |
+
*,
|
| 699 |
+
events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]],
|
| 700 |
+
choice_snapshot: ParsedChoiceSnapshot,
|
| 701 |
+
tool_index: int,
|
| 702 |
+
) -> None:
|
| 703 |
+
if tool_index in self._done_tool_calls:
|
| 704 |
+
return
|
| 705 |
+
|
| 706 |
+
self._done_tool_calls.add(tool_index)
|
| 707 |
+
|
| 708 |
+
assert choice_snapshot.message.tool_calls is not None
|
| 709 |
+
tool_call_snapshot = choice_snapshot.message.tool_calls[tool_index]
|
| 710 |
+
|
| 711 |
+
if tool_call_snapshot.type == "function":
|
| 712 |
+
parsed_arguments = parse_function_tool_arguments(
|
| 713 |
+
input_tools=self._input_tools, function=tool_call_snapshot.function
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
# update the parsed content to potentially use a richer type
|
| 717 |
+
# as opposed to the raw JSON-parsed object as the content is now
|
| 718 |
+
# complete and can be fully validated.
|
| 719 |
+
tool_call_snapshot.function.parsed_arguments = parsed_arguments
|
| 720 |
+
|
| 721 |
+
events_to_fire.append(
|
| 722 |
+
build(
|
| 723 |
+
FunctionToolCallArgumentsDoneEvent,
|
| 724 |
+
type="tool_calls.function.arguments.done",
|
| 725 |
+
index=tool_index,
|
| 726 |
+
name=tool_call_snapshot.function.name,
|
| 727 |
+
arguments=tool_call_snapshot.function.arguments,
|
| 728 |
+
parsed_arguments=parsed_arguments,
|
| 729 |
+
)
|
| 730 |
+
)
|
| 731 |
+
elif TYPE_CHECKING: # type: ignore[unreachable]
|
| 732 |
+
assert_never(tool_call_snapshot)
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
def _convert_initial_chunk_into_snapshot(chunk: ChatCompletionChunk) -> ParsedChatCompletionSnapshot:
|
| 736 |
+
data = chunk.to_dict()
|
| 737 |
+
choices = cast("list[object]", data["choices"])
|
| 738 |
+
|
| 739 |
+
for choice in chunk.choices:
|
| 740 |
+
choices[choice.index] = {
|
| 741 |
+
**choice.model_dump(exclude_unset=True, exclude={"delta"}),
|
| 742 |
+
"message": choice.delta.to_dict(),
|
| 743 |
+
}
|
| 744 |
+
|
| 745 |
+
return cast(
|
| 746 |
+
ParsedChatCompletionSnapshot,
|
| 747 |
+
construct_type(
|
| 748 |
+
type_=ParsedChatCompletionSnapshot,
|
| 749 |
+
value={
|
| 750 |
+
"system_fingerprint": None,
|
| 751 |
+
**data,
|
| 752 |
+
"object": "chat.completion",
|
| 753 |
+
},
|
| 754 |
+
),
|
| 755 |
+
)
|
.venv/lib/python3.11/site-packages/openai/lib/streaming/chat/_events.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Union, Generic, Optional
|
| 2 |
+
from typing_extensions import Literal
|
| 3 |
+
|
| 4 |
+
from ._types import ParsedChatCompletionSnapshot
|
| 5 |
+
from ...._models import BaseModel, GenericModel
|
| 6 |
+
from ..._parsing import ResponseFormatT
|
| 7 |
+
from ....types.chat import ChatCompletionChunk, ChatCompletionTokenLogprob
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class ChunkEvent(BaseModel):
|
| 11 |
+
type: Literal["chunk"]
|
| 12 |
+
|
| 13 |
+
chunk: ChatCompletionChunk
|
| 14 |
+
|
| 15 |
+
snapshot: ParsedChatCompletionSnapshot
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ContentDeltaEvent(BaseModel):
|
| 19 |
+
"""This event is yielded for every chunk with `choice.delta.content` data."""
|
| 20 |
+
|
| 21 |
+
type: Literal["content.delta"]
|
| 22 |
+
|
| 23 |
+
delta: str
|
| 24 |
+
|
| 25 |
+
snapshot: str
|
| 26 |
+
|
| 27 |
+
parsed: Optional[object] = None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class ContentDoneEvent(GenericModel, Generic[ResponseFormatT]):
|
| 31 |
+
type: Literal["content.done"]
|
| 32 |
+
|
| 33 |
+
content: str
|
| 34 |
+
|
| 35 |
+
parsed: Optional[ResponseFormatT] = None
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class RefusalDeltaEvent(BaseModel):
|
| 39 |
+
type: Literal["refusal.delta"]
|
| 40 |
+
|
| 41 |
+
delta: str
|
| 42 |
+
|
| 43 |
+
snapshot: str
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class RefusalDoneEvent(BaseModel):
|
| 47 |
+
type: Literal["refusal.done"]
|
| 48 |
+
|
| 49 |
+
refusal: str
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class FunctionToolCallArgumentsDeltaEvent(BaseModel):
|
| 53 |
+
type: Literal["tool_calls.function.arguments.delta"]
|
| 54 |
+
|
| 55 |
+
name: str
|
| 56 |
+
|
| 57 |
+
index: int
|
| 58 |
+
|
| 59 |
+
arguments: str
|
| 60 |
+
"""Accumulated raw JSON string"""
|
| 61 |
+
|
| 62 |
+
parsed_arguments: object
|
| 63 |
+
"""The parsed arguments so far"""
|
| 64 |
+
|
| 65 |
+
arguments_delta: str
|
| 66 |
+
"""The JSON string delta"""
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class FunctionToolCallArgumentsDoneEvent(BaseModel):
|
| 70 |
+
type: Literal["tool_calls.function.arguments.done"]
|
| 71 |
+
|
| 72 |
+
name: str
|
| 73 |
+
|
| 74 |
+
index: int
|
| 75 |
+
|
| 76 |
+
arguments: str
|
| 77 |
+
"""Accumulated raw JSON string"""
|
| 78 |
+
|
| 79 |
+
parsed_arguments: object
|
| 80 |
+
"""The parsed arguments"""
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class LogprobsContentDeltaEvent(BaseModel):
|
| 84 |
+
type: Literal["logprobs.content.delta"]
|
| 85 |
+
|
| 86 |
+
content: List[ChatCompletionTokenLogprob]
|
| 87 |
+
|
| 88 |
+
snapshot: List[ChatCompletionTokenLogprob]
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class LogprobsContentDoneEvent(BaseModel):
|
| 92 |
+
type: Literal["logprobs.content.done"]
|
| 93 |
+
|
| 94 |
+
content: List[ChatCompletionTokenLogprob]
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class LogprobsRefusalDeltaEvent(BaseModel):
|
| 98 |
+
type: Literal["logprobs.refusal.delta"]
|
| 99 |
+
|
| 100 |
+
refusal: List[ChatCompletionTokenLogprob]
|
| 101 |
+
|
| 102 |
+
snapshot: List[ChatCompletionTokenLogprob]
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class LogprobsRefusalDoneEvent(BaseModel):
|
| 106 |
+
type: Literal["logprobs.refusal.done"]
|
| 107 |
+
|
| 108 |
+
refusal: List[ChatCompletionTokenLogprob]
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
ChatCompletionStreamEvent = Union[
|
| 112 |
+
ChunkEvent,
|
| 113 |
+
ContentDeltaEvent,
|
| 114 |
+
ContentDoneEvent[ResponseFormatT],
|
| 115 |
+
RefusalDeltaEvent,
|
| 116 |
+
RefusalDoneEvent,
|
| 117 |
+
FunctionToolCallArgumentsDeltaEvent,
|
| 118 |
+
FunctionToolCallArgumentsDoneEvent,
|
| 119 |
+
LogprobsContentDeltaEvent,
|
| 120 |
+
LogprobsContentDoneEvent,
|
| 121 |
+
LogprobsRefusalDeltaEvent,
|
| 122 |
+
LogprobsRefusalDoneEvent,
|
| 123 |
+
]
|
.venv/lib/python3.11/site-packages/openai/lib/streaming/chat/_types.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing_extensions import TypeAlias
|
| 4 |
+
|
| 5 |
+
from ....types.chat import ParsedChoice, ParsedChatCompletion, ParsedChatCompletionMessage
|
| 6 |
+
|
| 7 |
+
ParsedChatCompletionSnapshot: TypeAlias = ParsedChatCompletion[object]
|
| 8 |
+
"""Snapshot type representing an in-progress accumulation of
|
| 9 |
+
a `ParsedChatCompletion` object.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
ParsedChatCompletionMessageSnapshot: TypeAlias = ParsedChatCompletionMessage[object]
|
| 13 |
+
"""Snapshot type representing an in-progress accumulation of
|
| 14 |
+
a `ParsedChatCompletionMessage` object.
|
| 15 |
+
|
| 16 |
+
If the content has been fully accumulated, the `.parsed` content will be
|
| 17 |
+
the `response_format` instance, otherwise it'll be the raw JSON parsed version.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
ParsedChoiceSnapshot: TypeAlias = ParsedChoice[object]
|
.venv/lib/python3.11/site-packages/openai/resources/__init__.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
| 2 |
+
|
| 3 |
+
from .beta import (
|
| 4 |
+
Beta,
|
| 5 |
+
AsyncBeta,
|
| 6 |
+
BetaWithRawResponse,
|
| 7 |
+
AsyncBetaWithRawResponse,
|
| 8 |
+
BetaWithStreamingResponse,
|
| 9 |
+
AsyncBetaWithStreamingResponse,
|
| 10 |
+
)
|
| 11 |
+
from .chat import (
|
| 12 |
+
Chat,
|
| 13 |
+
AsyncChat,
|
| 14 |
+
ChatWithRawResponse,
|
| 15 |
+
AsyncChatWithRawResponse,
|
| 16 |
+
ChatWithStreamingResponse,
|
| 17 |
+
AsyncChatWithStreamingResponse,
|
| 18 |
+
)
|
| 19 |
+
from .audio import (
|
| 20 |
+
Audio,
|
| 21 |
+
AsyncAudio,
|
| 22 |
+
AudioWithRawResponse,
|
| 23 |
+
AsyncAudioWithRawResponse,
|
| 24 |
+
AudioWithStreamingResponse,
|
| 25 |
+
AsyncAudioWithStreamingResponse,
|
| 26 |
+
)
|
| 27 |
+
from .files import (
|
| 28 |
+
Files,
|
| 29 |
+
AsyncFiles,
|
| 30 |
+
FilesWithRawResponse,
|
| 31 |
+
AsyncFilesWithRawResponse,
|
| 32 |
+
FilesWithStreamingResponse,
|
| 33 |
+
AsyncFilesWithStreamingResponse,
|
| 34 |
+
)
|
| 35 |
+
from .images import (
|
| 36 |
+
Images,
|
| 37 |
+
AsyncImages,
|
| 38 |
+
ImagesWithRawResponse,
|
| 39 |
+
AsyncImagesWithRawResponse,
|
| 40 |
+
ImagesWithStreamingResponse,
|
| 41 |
+
AsyncImagesWithStreamingResponse,
|
| 42 |
+
)
|
| 43 |
+
from .models import (
|
| 44 |
+
Models,
|
| 45 |
+
AsyncModels,
|
| 46 |
+
ModelsWithRawResponse,
|
| 47 |
+
AsyncModelsWithRawResponse,
|
| 48 |
+
ModelsWithStreamingResponse,
|
| 49 |
+
AsyncModelsWithStreamingResponse,
|
| 50 |
+
)
|
| 51 |
+
from .batches import (
|
| 52 |
+
Batches,
|
| 53 |
+
AsyncBatches,
|
| 54 |
+
BatchesWithRawResponse,
|
| 55 |
+
AsyncBatchesWithRawResponse,
|
| 56 |
+
BatchesWithStreamingResponse,
|
| 57 |
+
AsyncBatchesWithStreamingResponse,
|
| 58 |
+
)
|
| 59 |
+
from .uploads import (
|
| 60 |
+
Uploads,
|
| 61 |
+
AsyncUploads,
|
| 62 |
+
UploadsWithRawResponse,
|
| 63 |
+
AsyncUploadsWithRawResponse,
|
| 64 |
+
UploadsWithStreamingResponse,
|
| 65 |
+
AsyncUploadsWithStreamingResponse,
|
| 66 |
+
)
|
| 67 |
+
from .embeddings import (
|
| 68 |
+
Embeddings,
|
| 69 |
+
AsyncEmbeddings,
|
| 70 |
+
EmbeddingsWithRawResponse,
|
| 71 |
+
AsyncEmbeddingsWithRawResponse,
|
| 72 |
+
EmbeddingsWithStreamingResponse,
|
| 73 |
+
AsyncEmbeddingsWithStreamingResponse,
|
| 74 |
+
)
|
| 75 |
+
from .completions import (
|
| 76 |
+
Completions,
|
| 77 |
+
AsyncCompletions,
|
| 78 |
+
CompletionsWithRawResponse,
|
| 79 |
+
AsyncCompletionsWithRawResponse,
|
| 80 |
+
CompletionsWithStreamingResponse,
|
| 81 |
+
AsyncCompletionsWithStreamingResponse,
|
| 82 |
+
)
|
| 83 |
+
from .fine_tuning import (
|
| 84 |
+
FineTuning,
|
| 85 |
+
AsyncFineTuning,
|
| 86 |
+
FineTuningWithRawResponse,
|
| 87 |
+
AsyncFineTuningWithRawResponse,
|
| 88 |
+
FineTuningWithStreamingResponse,
|
| 89 |
+
AsyncFineTuningWithStreamingResponse,
|
| 90 |
+
)
|
| 91 |
+
from .moderations import (
|
| 92 |
+
Moderations,
|
| 93 |
+
AsyncModerations,
|
| 94 |
+
ModerationsWithRawResponse,
|
| 95 |
+
AsyncModerationsWithRawResponse,
|
| 96 |
+
ModerationsWithStreamingResponse,
|
| 97 |
+
AsyncModerationsWithStreamingResponse,
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
__all__ = [
|
| 101 |
+
"Completions",
|
| 102 |
+
"AsyncCompletions",
|
| 103 |
+
"CompletionsWithRawResponse",
|
| 104 |
+
"AsyncCompletionsWithRawResponse",
|
| 105 |
+
"CompletionsWithStreamingResponse",
|
| 106 |
+
"AsyncCompletionsWithStreamingResponse",
|
| 107 |
+
"Chat",
|
| 108 |
+
"AsyncChat",
|
| 109 |
+
"ChatWithRawResponse",
|
| 110 |
+
"AsyncChatWithRawResponse",
|
| 111 |
+
"ChatWithStreamingResponse",
|
| 112 |
+
"AsyncChatWithStreamingResponse",
|
| 113 |
+
"Embeddings",
|
| 114 |
+
"AsyncEmbeddings",
|
| 115 |
+
"EmbeddingsWithRawResponse",
|
| 116 |
+
"AsyncEmbeddingsWithRawResponse",
|
| 117 |
+
"EmbeddingsWithStreamingResponse",
|
| 118 |
+
"AsyncEmbeddingsWithStreamingResponse",
|
| 119 |
+
"Files",
|
| 120 |
+
"AsyncFiles",
|
| 121 |
+
"FilesWithRawResponse",
|
| 122 |
+
"AsyncFilesWithRawResponse",
|
| 123 |
+
"FilesWithStreamingResponse",
|
| 124 |
+
"AsyncFilesWithStreamingResponse",
|
| 125 |
+
"Images",
|
| 126 |
+
"AsyncImages",
|
| 127 |
+
"ImagesWithRawResponse",
|
| 128 |
+
"AsyncImagesWithRawResponse",
|
| 129 |
+
"ImagesWithStreamingResponse",
|
| 130 |
+
"AsyncImagesWithStreamingResponse",
|
| 131 |
+
"Audio",
|
| 132 |
+
"AsyncAudio",
|
| 133 |
+
"AudioWithRawResponse",
|
| 134 |
+
"AsyncAudioWithRawResponse",
|
| 135 |
+
"AudioWithStreamingResponse",
|
| 136 |
+
"AsyncAudioWithStreamingResponse",
|
| 137 |
+
"Moderations",
|
| 138 |
+
"AsyncModerations",
|
| 139 |
+
"ModerationsWithRawResponse",
|
| 140 |
+
"AsyncModerationsWithRawResponse",
|
| 141 |
+
"ModerationsWithStreamingResponse",
|
| 142 |
+
"AsyncModerationsWithStreamingResponse",
|
| 143 |
+
"Models",
|
| 144 |
+
"AsyncModels",
|
| 145 |
+
"ModelsWithRawResponse",
|
| 146 |
+
"AsyncModelsWithRawResponse",
|
| 147 |
+
"ModelsWithStreamingResponse",
|
| 148 |
+
"AsyncModelsWithStreamingResponse",
|
| 149 |
+
"FineTuning",
|
| 150 |
+
"AsyncFineTuning",
|
| 151 |
+
"FineTuningWithRawResponse",
|
| 152 |
+
"AsyncFineTuningWithRawResponse",
|
| 153 |
+
"FineTuningWithStreamingResponse",
|
| 154 |
+
"AsyncFineTuningWithStreamingResponse",
|
| 155 |
+
"Beta",
|
| 156 |
+
"AsyncBeta",
|
| 157 |
+
"BetaWithRawResponse",
|
| 158 |
+
"AsyncBetaWithRawResponse",
|
| 159 |
+
"BetaWithStreamingResponse",
|
| 160 |
+
"AsyncBetaWithStreamingResponse",
|
| 161 |
+
"Batches",
|
| 162 |
+
"AsyncBatches",
|
| 163 |
+
"BatchesWithRawResponse",
|
| 164 |
+
"AsyncBatchesWithRawResponse",
|
| 165 |
+
"BatchesWithStreamingResponse",
|
| 166 |
+
"AsyncBatchesWithStreamingResponse",
|
| 167 |
+
"Uploads",
|
| 168 |
+
"AsyncUploads",
|
| 169 |
+
"UploadsWithRawResponse",
|
| 170 |
+
"AsyncUploadsWithRawResponse",
|
| 171 |
+
"UploadsWithStreamingResponse",
|
| 172 |
+
"AsyncUploadsWithStreamingResponse",
|
| 173 |
+
]
|
.venv/lib/python3.11/site-packages/openai/resources/audio/__init__.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
| 2 |
+
|
| 3 |
+
from .audio import (
|
| 4 |
+
Audio,
|
| 5 |
+
AsyncAudio,
|
| 6 |
+
AudioWithRawResponse,
|
| 7 |
+
AsyncAudioWithRawResponse,
|
| 8 |
+
AudioWithStreamingResponse,
|
| 9 |
+
AsyncAudioWithStreamingResponse,
|
| 10 |
+
)
|
| 11 |
+
from .speech import (
|
| 12 |
+
Speech,
|
| 13 |
+
AsyncSpeech,
|
| 14 |
+
SpeechWithRawResponse,
|
| 15 |
+
AsyncSpeechWithRawResponse,
|
| 16 |
+
SpeechWithStreamingResponse,
|
| 17 |
+
AsyncSpeechWithStreamingResponse,
|
| 18 |
+
)
|
| 19 |
+
from .translations import (
|
| 20 |
+
Translations,
|
| 21 |
+
AsyncTranslations,
|
| 22 |
+
TranslationsWithRawResponse,
|
| 23 |
+
AsyncTranslationsWithRawResponse,
|
| 24 |
+
TranslationsWithStreamingResponse,
|
| 25 |
+
AsyncTranslationsWithStreamingResponse,
|
| 26 |
+
)
|
| 27 |
+
from .transcriptions import (
|
| 28 |
+
Transcriptions,
|
| 29 |
+
AsyncTranscriptions,
|
| 30 |
+
TranscriptionsWithRawResponse,
|
| 31 |
+
AsyncTranscriptionsWithRawResponse,
|
| 32 |
+
TranscriptionsWithStreamingResponse,
|
| 33 |
+
AsyncTranscriptionsWithStreamingResponse,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
__all__ = [
|
| 37 |
+
"Transcriptions",
|
| 38 |
+
"AsyncTranscriptions",
|
| 39 |
+
"TranscriptionsWithRawResponse",
|
| 40 |
+
"AsyncTranscriptionsWithRawResponse",
|
| 41 |
+
"TranscriptionsWithStreamingResponse",
|
| 42 |
+
"AsyncTranscriptionsWithStreamingResponse",
|
| 43 |
+
"Translations",
|
| 44 |
+
"AsyncTranslations",
|
| 45 |
+
"TranslationsWithRawResponse",
|
| 46 |
+
"AsyncTranslationsWithRawResponse",
|
| 47 |
+
"TranslationsWithStreamingResponse",
|
| 48 |
+
"AsyncTranslationsWithStreamingResponse",
|
| 49 |
+
"Speech",
|
| 50 |
+
"AsyncSpeech",
|
| 51 |
+
"SpeechWithRawResponse",
|
| 52 |
+
"AsyncSpeechWithRawResponse",
|
| 53 |
+
"SpeechWithStreamingResponse",
|
| 54 |
+
"AsyncSpeechWithStreamingResponse",
|
| 55 |
+
"Audio",
|
| 56 |
+
"AsyncAudio",
|
| 57 |
+
"AudioWithRawResponse",
|
| 58 |
+
"AsyncAudioWithRawResponse",
|
| 59 |
+
"AudioWithStreamingResponse",
|
| 60 |
+
"AsyncAudioWithStreamingResponse",
|
| 61 |
+
]
|
.venv/lib/python3.11/site-packages/openai/resources/audio/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.59 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/openai/resources/audio/__pycache__/audio.cpython-311.pyc
ADDED
|
Binary file (9.64 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/openai/resources/audio/__pycache__/speech.cpython-311.pyc
ADDED
|
Binary file (8.62 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/openai/resources/audio/__pycache__/transcriptions.cpython-311.pyc
ADDED
|
Binary file (13.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/openai/resources/audio/__pycache__/translations.cpython-311.pyc
ADDED
|
Binary file (12.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/openai/resources/audio/audio.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from .speech import (
|
| 6 |
+
Speech,
|
| 7 |
+
AsyncSpeech,
|
| 8 |
+
SpeechWithRawResponse,
|
| 9 |
+
AsyncSpeechWithRawResponse,
|
| 10 |
+
SpeechWithStreamingResponse,
|
| 11 |
+
AsyncSpeechWithStreamingResponse,
|
| 12 |
+
)
|
| 13 |
+
from ..._compat import cached_property
|
| 14 |
+
from ..._resource import SyncAPIResource, AsyncAPIResource
|
| 15 |
+
from .translations import (
|
| 16 |
+
Translations,
|
| 17 |
+
AsyncTranslations,
|
| 18 |
+
TranslationsWithRawResponse,
|
| 19 |
+
AsyncTranslationsWithRawResponse,
|
| 20 |
+
TranslationsWithStreamingResponse,
|
| 21 |
+
AsyncTranslationsWithStreamingResponse,
|
| 22 |
+
)
|
| 23 |
+
from .transcriptions import (
|
| 24 |
+
Transcriptions,
|
| 25 |
+
AsyncTranscriptions,
|
| 26 |
+
TranscriptionsWithRawResponse,
|
| 27 |
+
AsyncTranscriptionsWithRawResponse,
|
| 28 |
+
TranscriptionsWithStreamingResponse,
|
| 29 |
+
AsyncTranscriptionsWithStreamingResponse,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
__all__ = ["Audio", "AsyncAudio"]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class Audio(SyncAPIResource):
|
| 36 |
+
@cached_property
|
| 37 |
+
def transcriptions(self) -> Transcriptions:
|
| 38 |
+
return Transcriptions(self._client)
|
| 39 |
+
|
| 40 |
+
@cached_property
|
| 41 |
+
def translations(self) -> Translations:
|
| 42 |
+
return Translations(self._client)
|
| 43 |
+
|
| 44 |
+
@cached_property
|
| 45 |
+
def speech(self) -> Speech:
|
| 46 |
+
return Speech(self._client)
|
| 47 |
+
|
| 48 |
+
@cached_property
|
| 49 |
+
def with_raw_response(self) -> AudioWithRawResponse:
|
| 50 |
+
"""
|
| 51 |
+
This property can be used as a prefix for any HTTP method call to return
|
| 52 |
+
the raw response object instead of the parsed content.
|
| 53 |
+
|
| 54 |
+
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
|
| 55 |
+
"""
|
| 56 |
+
return AudioWithRawResponse(self)
|
| 57 |
+
|
| 58 |
+
@cached_property
|
| 59 |
+
def with_streaming_response(self) -> AudioWithStreamingResponse:
|
| 60 |
+
"""
|
| 61 |
+
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
|
| 62 |
+
|
| 63 |
+
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
|
| 64 |
+
"""
|
| 65 |
+
return AudioWithStreamingResponse(self)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class AsyncAudio(AsyncAPIResource):
|
| 69 |
+
@cached_property
|
| 70 |
+
def transcriptions(self) -> AsyncTranscriptions:
|
| 71 |
+
return AsyncTranscriptions(self._client)
|
| 72 |
+
|
| 73 |
+
@cached_property
|
| 74 |
+
def translations(self) -> AsyncTranslations:
|
| 75 |
+
return AsyncTranslations(self._client)
|
| 76 |
+
|
| 77 |
+
@cached_property
|
| 78 |
+
def speech(self) -> AsyncSpeech:
|
| 79 |
+
return AsyncSpeech(self._client)
|
| 80 |
+
|
| 81 |
+
@cached_property
|
| 82 |
+
def with_raw_response(self) -> AsyncAudioWithRawResponse:
|
| 83 |
+
"""
|
| 84 |
+
This property can be used as a prefix for any HTTP method call to return
|
| 85 |
+
the raw response object instead of the parsed content.
|
| 86 |
+
|
| 87 |
+
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
|
| 88 |
+
"""
|
| 89 |
+
return AsyncAudioWithRawResponse(self)
|
| 90 |
+
|
| 91 |
+
@cached_property
|
| 92 |
+
def with_streaming_response(self) -> AsyncAudioWithStreamingResponse:
|
| 93 |
+
"""
|
| 94 |
+
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
|
| 95 |
+
|
| 96 |
+
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
|
| 97 |
+
"""
|
| 98 |
+
return AsyncAudioWithStreamingResponse(self)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class AudioWithRawResponse:
|
| 102 |
+
def __init__(self, audio: Audio) -> None:
|
| 103 |
+
self._audio = audio
|
| 104 |
+
|
| 105 |
+
@cached_property
|
| 106 |
+
def transcriptions(self) -> TranscriptionsWithRawResponse:
|
| 107 |
+
return TranscriptionsWithRawResponse(self._audio.transcriptions)
|
| 108 |
+
|
| 109 |
+
@cached_property
|
| 110 |
+
def translations(self) -> TranslationsWithRawResponse:
|
| 111 |
+
return TranslationsWithRawResponse(self._audio.translations)
|
| 112 |
+
|
| 113 |
+
@cached_property
|
| 114 |
+
def speech(self) -> SpeechWithRawResponse:
|
| 115 |
+
return SpeechWithRawResponse(self._audio.speech)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class AsyncAudioWithRawResponse:
|
| 119 |
+
def __init__(self, audio: AsyncAudio) -> None:
|
| 120 |
+
self._audio = audio
|
| 121 |
+
|
| 122 |
+
@cached_property
|
| 123 |
+
def transcriptions(self) -> AsyncTranscriptionsWithRawResponse:
|
| 124 |
+
return AsyncTranscriptionsWithRawResponse(self._audio.transcriptions)
|
| 125 |
+
|
| 126 |
+
@cached_property
|
| 127 |
+
def translations(self) -> AsyncTranslationsWithRawResponse:
|
| 128 |
+
return AsyncTranslationsWithRawResponse(self._audio.translations)
|
| 129 |
+
|
| 130 |
+
@cached_property
|
| 131 |
+
def speech(self) -> AsyncSpeechWithRawResponse:
|
| 132 |
+
return AsyncSpeechWithRawResponse(self._audio.speech)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class AudioWithStreamingResponse:
|
| 136 |
+
def __init__(self, audio: Audio) -> None:
|
| 137 |
+
self._audio = audio
|
| 138 |
+
|
| 139 |
+
@cached_property
|
| 140 |
+
def transcriptions(self) -> TranscriptionsWithStreamingResponse:
|
| 141 |
+
return TranscriptionsWithStreamingResponse(self._audio.transcriptions)
|
| 142 |
+
|
| 143 |
+
@cached_property
|
| 144 |
+
def translations(self) -> TranslationsWithStreamingResponse:
|
| 145 |
+
return TranslationsWithStreamingResponse(self._audio.translations)
|
| 146 |
+
|
| 147 |
+
@cached_property
|
| 148 |
+
def speech(self) -> SpeechWithStreamingResponse:
|
| 149 |
+
return SpeechWithStreamingResponse(self._audio.speech)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class AsyncAudioWithStreamingResponse:
|
| 153 |
+
def __init__(self, audio: AsyncAudio) -> None:
|
| 154 |
+
self._audio = audio
|
| 155 |
+
|
| 156 |
+
@cached_property
|
| 157 |
+
def transcriptions(self) -> AsyncTranscriptionsWithStreamingResponse:
|
| 158 |
+
return AsyncTranscriptionsWithStreamingResponse(self._audio.transcriptions)
|
| 159 |
+
|
| 160 |
+
@cached_property
|
| 161 |
+
def translations(self) -> AsyncTranslationsWithStreamingResponse:
|
| 162 |
+
return AsyncTranslationsWithStreamingResponse(self._audio.translations)
|
| 163 |
+
|
| 164 |
+
@cached_property
|
| 165 |
+
def speech(self) -> AsyncSpeechWithStreamingResponse:
|
| 166 |
+
return AsyncSpeechWithStreamingResponse(self._audio.speech)
|
.venv/lib/python3.11/site-packages/openai/resources/audio/speech.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Union
|
| 6 |
+
from typing_extensions import Literal
|
| 7 |
+
|
| 8 |
+
import httpx
|
| 9 |
+
|
| 10 |
+
from ... import _legacy_response
|
| 11 |
+
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
|
| 12 |
+
from ..._utils import (
|
| 13 |
+
maybe_transform,
|
| 14 |
+
async_maybe_transform,
|
| 15 |
+
)
|
| 16 |
+
from ..._compat import cached_property
|
| 17 |
+
from ..._resource import SyncAPIResource, AsyncAPIResource
|
| 18 |
+
from ..._response import (
|
| 19 |
+
StreamedBinaryAPIResponse,
|
| 20 |
+
AsyncStreamedBinaryAPIResponse,
|
| 21 |
+
to_custom_streamed_response_wrapper,
|
| 22 |
+
async_to_custom_streamed_response_wrapper,
|
| 23 |
+
)
|
| 24 |
+
from ...types.audio import speech_create_params
|
| 25 |
+
from ..._base_client import make_request_options
|
| 26 |
+
from ...types.audio.speech_model import SpeechModel
|
| 27 |
+
|
| 28 |
+
__all__ = ["Speech", "AsyncSpeech"]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Speech(SyncAPIResource):
|
| 32 |
+
@cached_property
|
| 33 |
+
def with_raw_response(self) -> SpeechWithRawResponse:
|
| 34 |
+
"""
|
| 35 |
+
This property can be used as a prefix for any HTTP method call to return
|
| 36 |
+
the raw response object instead of the parsed content.
|
| 37 |
+
|
| 38 |
+
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
|
| 39 |
+
"""
|
| 40 |
+
return SpeechWithRawResponse(self)
|
| 41 |
+
|
| 42 |
+
@cached_property
|
| 43 |
+
def with_streaming_response(self) -> SpeechWithStreamingResponse:
|
| 44 |
+
"""
|
| 45 |
+
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
|
| 46 |
+
|
| 47 |
+
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
|
| 48 |
+
"""
|
| 49 |
+
return SpeechWithStreamingResponse(self)
|
| 50 |
+
|
| 51 |
+
def create(
|
| 52 |
+
self,
|
| 53 |
+
*,
|
| 54 |
+
input: str,
|
| 55 |
+
model: Union[str, SpeechModel],
|
| 56 |
+
voice: Literal["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"],
|
| 57 |
+
response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] | NotGiven = NOT_GIVEN,
|
| 58 |
+
speed: float | NotGiven = NOT_GIVEN,
|
| 59 |
+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
| 60 |
+
# The extra values given here take precedence over values defined on the client or passed to this method.
|
| 61 |
+
extra_headers: Headers | None = None,
|
| 62 |
+
extra_query: Query | None = None,
|
| 63 |
+
extra_body: Body | None = None,
|
| 64 |
+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
| 65 |
+
) -> _legacy_response.HttpxBinaryResponseContent:
|
| 66 |
+
"""
|
| 67 |
+
Generates audio from the input text.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
input: The text to generate audio for. The maximum length is 4096 characters.
|
| 71 |
+
|
| 72 |
+
model:
|
| 73 |
+
One of the available [TTS models](https://platform.openai.com/docs/models#tts):
|
| 74 |
+
`tts-1` or `tts-1-hd`
|
| 75 |
+
|
| 76 |
+
voice: The voice to use when generating the audio. Supported voices are `alloy`, `ash`,
|
| 77 |
+
`coral`, `echo`, `fable`, `onyx`, `nova`, `sage` and `shimmer`. Previews of the
|
| 78 |
+
voices are available in the
|
| 79 |
+
[Text to speech guide](https://platform.openai.com/docs/guides/text-to-speech#voice-options).
|
| 80 |
+
|
| 81 |
+
response_format: The format to audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`,
|
| 82 |
+
`wav`, and `pcm`.
|
| 83 |
+
|
| 84 |
+
speed: The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is
|
| 85 |
+
the default.
|
| 86 |
+
|
| 87 |
+
extra_headers: Send extra headers
|
| 88 |
+
|
| 89 |
+
extra_query: Add additional query parameters to the request
|
| 90 |
+
|
| 91 |
+
extra_body: Add additional JSON properties to the request
|
| 92 |
+
|
| 93 |
+
timeout: Override the client-level default timeout for this request, in seconds
|
| 94 |
+
"""
|
| 95 |
+
extra_headers = {"Accept": "application/octet-stream", **(extra_headers or {})}
|
| 96 |
+
return self._post(
|
| 97 |
+
"/audio/speech",
|
| 98 |
+
body=maybe_transform(
|
| 99 |
+
{
|
| 100 |
+
"input": input,
|
| 101 |
+
"model": model,
|
| 102 |
+
"voice": voice,
|
| 103 |
+
"response_format": response_format,
|
| 104 |
+
"speed": speed,
|
| 105 |
+
},
|
| 106 |
+
speech_create_params.SpeechCreateParams,
|
| 107 |
+
),
|
| 108 |
+
options=make_request_options(
|
| 109 |
+
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
|
| 110 |
+
),
|
| 111 |
+
cast_to=_legacy_response.HttpxBinaryResponseContent,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class AsyncSpeech(AsyncAPIResource):
|
| 116 |
+
@cached_property
|
| 117 |
+
def with_raw_response(self) -> AsyncSpeechWithRawResponse:
|
| 118 |
+
"""
|
| 119 |
+
This property can be used as a prefix for any HTTP method call to return
|
| 120 |
+
the raw response object instead of the parsed content.
|
| 121 |
+
|
| 122 |
+
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
|
| 123 |
+
"""
|
| 124 |
+
return AsyncSpeechWithRawResponse(self)
|
| 125 |
+
|
| 126 |
+
@cached_property
|
| 127 |
+
def with_streaming_response(self) -> AsyncSpeechWithStreamingResponse:
|
| 128 |
+
"""
|
| 129 |
+
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
|
| 130 |
+
|
| 131 |
+
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
|
| 132 |
+
"""
|
| 133 |
+
return AsyncSpeechWithStreamingResponse(self)
|
| 134 |
+
|
| 135 |
+
async def create(
|
| 136 |
+
self,
|
| 137 |
+
*,
|
| 138 |
+
input: str,
|
| 139 |
+
model: Union[str, SpeechModel],
|
| 140 |
+
voice: Literal["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"],
|
| 141 |
+
response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] | NotGiven = NOT_GIVEN,
|
| 142 |
+
speed: float | NotGiven = NOT_GIVEN,
|
| 143 |
+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
| 144 |
+
# The extra values given here take precedence over values defined on the client or passed to this method.
|
| 145 |
+
extra_headers: Headers | None = None,
|
| 146 |
+
extra_query: Query | None = None,
|
| 147 |
+
extra_body: Body | None = None,
|
| 148 |
+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
| 149 |
+
) -> _legacy_response.HttpxBinaryResponseContent:
|
| 150 |
+
"""
|
| 151 |
+
Generates audio from the input text.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
input: The text to generate audio for. The maximum length is 4096 characters.
|
| 155 |
+
|
| 156 |
+
model:
|
| 157 |
+
One of the available [TTS models](https://platform.openai.com/docs/models#tts):
|
| 158 |
+
`tts-1` or `tts-1-hd`
|
| 159 |
+
|
| 160 |
+
voice: The voice to use when generating the audio. Supported voices are `alloy`, `ash`,
|
| 161 |
+
`coral`, `echo`, `fable`, `onyx`, `nova`, `sage` and `shimmer`. Previews of the
|
| 162 |
+
voices are available in the
|
| 163 |
+
[Text to speech guide](https://platform.openai.com/docs/guides/text-to-speech#voice-options).
|
| 164 |
+
|
| 165 |
+
response_format: The format to audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`,
|
| 166 |
+
`wav`, and `pcm`.
|
| 167 |
+
|
| 168 |
+
speed: The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is
|
| 169 |
+
the default.
|
| 170 |
+
|
| 171 |
+
extra_headers: Send extra headers
|
| 172 |
+
|
| 173 |
+
extra_query: Add additional query parameters to the request
|
| 174 |
+
|
| 175 |
+
extra_body: Add additional JSON properties to the request
|
| 176 |
+
|
| 177 |
+
timeout: Override the client-level default timeout for this request, in seconds
|
| 178 |
+
"""
|
| 179 |
+
extra_headers = {"Accept": "application/octet-stream", **(extra_headers or {})}
|
| 180 |
+
return await self._post(
|
| 181 |
+
"/audio/speech",
|
| 182 |
+
body=await async_maybe_transform(
|
| 183 |
+
{
|
| 184 |
+
"input": input,
|
| 185 |
+
"model": model,
|
| 186 |
+
"voice": voice,
|
| 187 |
+
"response_format": response_format,
|
| 188 |
+
"speed": speed,
|
| 189 |
+
},
|
| 190 |
+
speech_create_params.SpeechCreateParams,
|
| 191 |
+
),
|
| 192 |
+
options=make_request_options(
|
| 193 |
+
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
|
| 194 |
+
),
|
| 195 |
+
cast_to=_legacy_response.HttpxBinaryResponseContent,
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class SpeechWithRawResponse:
|
| 200 |
+
def __init__(self, speech: Speech) -> None:
|
| 201 |
+
self._speech = speech
|
| 202 |
+
|
| 203 |
+
self.create = _legacy_response.to_raw_response_wrapper(
|
| 204 |
+
speech.create,
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class AsyncSpeechWithRawResponse:
|
| 209 |
+
def __init__(self, speech: AsyncSpeech) -> None:
|
| 210 |
+
self._speech = speech
|
| 211 |
+
|
| 212 |
+
self.create = _legacy_response.async_to_raw_response_wrapper(
|
| 213 |
+
speech.create,
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class SpeechWithStreamingResponse:
|
| 218 |
+
def __init__(self, speech: Speech) -> None:
|
| 219 |
+
self._speech = speech
|
| 220 |
+
|
| 221 |
+
self.create = to_custom_streamed_response_wrapper(
|
| 222 |
+
speech.create,
|
| 223 |
+
StreamedBinaryAPIResponse,
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class AsyncSpeechWithStreamingResponse:
|
| 228 |
+
def __init__(self, speech: AsyncSpeech) -> None:
|
| 229 |
+
self._speech = speech
|
| 230 |
+
|
| 231 |
+
self.create = async_to_custom_streamed_response_wrapper(
|
| 232 |
+
speech.create,
|
| 233 |
+
AsyncStreamedBinaryAPIResponse,
|
| 234 |
+
)
|
.venv/lib/python3.11/site-packages/openai/resources/audio/transcriptions.py
ADDED
|
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
from typing import TYPE_CHECKING, List, Union, Mapping, cast
|
| 7 |
+
from typing_extensions import Literal, overload, assert_never
|
| 8 |
+
|
| 9 |
+
import httpx
|
| 10 |
+
|
| 11 |
+
from ... import _legacy_response
|
| 12 |
+
from ...types import AudioResponseFormat
|
| 13 |
+
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
|
| 14 |
+
from ..._utils import (
|
| 15 |
+
extract_files,
|
| 16 |
+
maybe_transform,
|
| 17 |
+
deepcopy_minimal,
|
| 18 |
+
async_maybe_transform,
|
| 19 |
+
)
|
| 20 |
+
from ..._compat import cached_property
|
| 21 |
+
from ..._resource import SyncAPIResource, AsyncAPIResource
|
| 22 |
+
from ..._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
|
| 23 |
+
from ...types.audio import transcription_create_params
|
| 24 |
+
from ..._base_client import make_request_options
|
| 25 |
+
from ...types.audio_model import AudioModel
|
| 26 |
+
from ...types.audio.transcription import Transcription
|
| 27 |
+
from ...types.audio_response_format import AudioResponseFormat
|
| 28 |
+
from ...types.audio.transcription_verbose import TranscriptionVerbose
|
| 29 |
+
|
| 30 |
+
__all__ = ["Transcriptions", "AsyncTranscriptions"]
|
| 31 |
+
|
| 32 |
+
log: logging.Logger = logging.getLogger("openai.audio.transcriptions")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class Transcriptions(SyncAPIResource):
    """Synchronous client resource for the `/audio/transcriptions` endpoint."""

    @cached_property
    def with_raw_response(self) -> TranscriptionsWithRawResponse:
        """
        This property can be used as a prefix for any HTTP method call to return
        the raw response object instead of the parsed content.

        For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
        """
        return TranscriptionsWithRawResponse(self)

    @cached_property
    def with_streaming_response(self) -> TranscriptionsWithStreamingResponse:
        """
        An alternative to `.with_raw_response` that doesn't eagerly read the response body.

        For more information, see https://www.github.com/openai/openai-python#with_streaming_response
        """
        return TranscriptionsWithStreamingResponse(self)

    # Overload: `json` response format (or omitted, which the API defaults to
    # `json`) -> parsed `Transcription` model.
    @overload
    def create(
        self,
        *,
        file: FileTypes,
        model: Union[str, AudioModel],
        response_format: Union[Literal["json"], NotGiven] = NOT_GIVEN,
        language: str | NotGiven = NOT_GIVEN,
        prompt: str | NotGiven = NOT_GIVEN,
        temperature: float | NotGiven = NOT_GIVEN,
        timestamp_granularities: List[Literal["word", "segment"]] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> Transcription: ...

    # Overload: `verbose_json` -> parsed `TranscriptionVerbose` model.
    @overload
    def create(
        self,
        *,
        file: FileTypes,
        model: Union[str, AudioModel],
        response_format: Literal["verbose_json"],
        language: str | NotGiven = NOT_GIVEN,
        prompt: str | NotGiven = NOT_GIVEN,
        temperature: float | NotGiven = NOT_GIVEN,
        timestamp_granularities: List[Literal["word", "segment"]] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> TranscriptionVerbose: ...

    # Overload: plain-text formats -> raw `str` response body.
    @overload
    def create(
        self,
        *,
        file: FileTypes,
        model: Union[str, AudioModel],
        response_format: Literal["text", "srt", "vtt"],
        language: str | NotGiven = NOT_GIVEN,
        prompt: str | NotGiven = NOT_GIVEN,
        temperature: float | NotGiven = NOT_GIVEN,
        timestamp_granularities: List[Literal["word", "segment"]] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> str: ...

    def create(
        self,
        *,
        file: FileTypes,
        model: Union[str, AudioModel],
        language: str | NotGiven = NOT_GIVEN,
        prompt: str | NotGiven = NOT_GIVEN,
        response_format: Union[AudioResponseFormat, NotGiven] = NOT_GIVEN,
        temperature: float | NotGiven = NOT_GIVEN,
        timestamp_granularities: List[Literal["word", "segment"]] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> Transcription | TranscriptionVerbose | str:
        """
        Transcribes audio into the input language.

        Args:
          file:
              The audio file object (not file name) to transcribe, in one of these formats:
              flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.

          model: ID of the model to use. Only `whisper-1` (which is powered by our open source
              Whisper V2 model) is currently available.

          language: The language of the input audio. Supplying the input language in
              [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) (e.g. `en`)
              format will improve accuracy and latency.

          prompt: An optional text to guide the model's style or continue a previous audio
              segment. The
              [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
              should match the audio language.

          response_format: The format of the output, in one of these options: `json`, `text`, `srt`,
              `verbose_json`, or `vtt`.

          temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the
              output more random, while lower values like 0.2 will make it more focused and
              deterministic. If set to 0, the model will use
              [log probability](https://en.wikipedia.org/wiki/Log_probability) to
              automatically increase the temperature until certain thresholds are hit.

          timestamp_granularities: The timestamp granularities to populate for this transcription.
              `response_format` must be set `verbose_json` to use timestamp granularities.
              Either or both of these options are supported: `word`, or `segment`. Note: There
              is no additional latency for segment timestamps, but generating word timestamps
              incurs additional latency.

          extra_headers: Send extra headers

          extra_query: Add additional query parameters to the request

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds
        """
        body = deepcopy_minimal(
            {
                "file": file,
                "model": model,
                "language": language,
                "prompt": prompt,
                "response_format": response_format,
                "temperature": temperature,
                "timestamp_granularities": timestamp_granularities,
            }
        )
        # The file is pulled out of the JSON-able body so it can be sent as a
        # multipart form part.
        files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
        # It should be noted that the actual Content-Type header that will be
        # sent to the server will contain a `boundary` parameter, e.g.
        # multipart/form-data; boundary=---abc--
        extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
        # The parse target depends on the requested response format; the
        # overloads above narrow the static return type accordingly.
        return self._post(  # type: ignore[return-value]
            "/audio/transcriptions",
            body=maybe_transform(body, transcription_create_params.TranscriptionCreateParams),
            files=files,
            options=make_request_options(
                extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
            ),
            cast_to=_get_response_format_type(response_format),
        )
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class AsyncTranscriptions(AsyncAPIResource):
    """Asynchronous client resource for the `/audio/transcriptions` endpoint."""

    @cached_property
    def with_raw_response(self) -> AsyncTranscriptionsWithRawResponse:
        """
        This property can be used as a prefix for any HTTP method call to return
        the raw response object instead of the parsed content.

        For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
        """
        return AsyncTranscriptionsWithRawResponse(self)

    @cached_property
    def with_streaming_response(self) -> AsyncTranscriptionsWithStreamingResponse:
        """
        An alternative to `.with_raw_response` that doesn't eagerly read the response body.

        For more information, see https://www.github.com/openai/openai-python#with_streaming_response
        """
        return AsyncTranscriptionsWithStreamingResponse(self)

    # Overload: `json` response format (or omitted, which the API defaults to
    # `json`) -> parsed `Transcription` model.
    @overload
    async def create(
        self,
        *,
        file: FileTypes,
        model: Union[str, AudioModel],
        response_format: Union[Literal["json"], NotGiven] = NOT_GIVEN,
        language: str | NotGiven = NOT_GIVEN,
        prompt: str | NotGiven = NOT_GIVEN,
        temperature: float | NotGiven = NOT_GIVEN,
        timestamp_granularities: List[Literal["word", "segment"]] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> Transcription: ...

    # Overload: `verbose_json` -> parsed `TranscriptionVerbose` model.
    @overload
    async def create(
        self,
        *,
        file: FileTypes,
        model: Union[str, AudioModel],
        response_format: Literal["verbose_json"],
        language: str | NotGiven = NOT_GIVEN,
        prompt: str | NotGiven = NOT_GIVEN,
        temperature: float | NotGiven = NOT_GIVEN,
        timestamp_granularities: List[Literal["word", "segment"]] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> TranscriptionVerbose: ...

    # Overload: plain-text formats -> raw `str` response body.
    @overload
    async def create(
        self,
        *,
        file: FileTypes,
        model: Union[str, AudioModel],
        response_format: Literal["text", "srt", "vtt"],
        language: str | NotGiven = NOT_GIVEN,
        prompt: str | NotGiven = NOT_GIVEN,
        temperature: float | NotGiven = NOT_GIVEN,
        timestamp_granularities: List[Literal["word", "segment"]] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> str: ...

    async def create(
        self,
        *,
        file: FileTypes,
        model: Union[str, AudioModel],
        language: str | NotGiven = NOT_GIVEN,
        prompt: str | NotGiven = NOT_GIVEN,
        response_format: Union[AudioResponseFormat, NotGiven] = NOT_GIVEN,
        temperature: float | NotGiven = NOT_GIVEN,
        timestamp_granularities: List[Literal["word", "segment"]] | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> Transcription | TranscriptionVerbose | str:
        """
        Transcribes audio into the input language.

        Args:
          file:
              The audio file object (not file name) to transcribe, in one of these formats:
              flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.

          model: ID of the model to use. Only `whisper-1` (which is powered by our open source
              Whisper V2 model) is currently available.

          language: The language of the input audio. Supplying the input language in
              [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) (e.g. `en`)
              format will improve accuracy and latency.

          prompt: An optional text to guide the model's style or continue a previous audio
              segment. The
              [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
              should match the audio language.

          response_format: The format of the output, in one of these options: `json`, `text`, `srt`,
              `verbose_json`, or `vtt`.

          temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the
              output more random, while lower values like 0.2 will make it more focused and
              deterministic. If set to 0, the model will use
              [log probability](https://en.wikipedia.org/wiki/Log_probability) to
              automatically increase the temperature until certain thresholds are hit.

          timestamp_granularities: The timestamp granularities to populate for this transcription.
              `response_format` must be set `verbose_json` to use timestamp granularities.
              Either or both of these options are supported: `word`, or `segment`. Note: There
              is no additional latency for segment timestamps, but generating word timestamps
              incurs additional latency.

          extra_headers: Send extra headers

          extra_query: Add additional query parameters to the request

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds
        """
        body = deepcopy_minimal(
            {
                "file": file,
                "model": model,
                "language": language,
                "prompt": prompt,
                "response_format": response_format,
                "temperature": temperature,
                "timestamp_granularities": timestamp_granularities,
            }
        )
        # The file is pulled out of the JSON-able body so it can be sent as a
        # multipart form part.
        files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
        # It should be noted that the actual Content-Type header that will be
        # sent to the server will contain a `boundary` parameter, e.g.
        # multipart/form-data; boundary=---abc--
        extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
        # The parse target depends on the requested response format; the
        # overloads above narrow the static return type accordingly.
        return await self._post(
            "/audio/transcriptions",
            body=await async_maybe_transform(body, transcription_create_params.TranscriptionCreateParams),
            files=files,
            options=make_request_options(
                extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
            ),
            cast_to=_get_response_format_type(response_format),
        )
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
class TranscriptionsWithRawResponse:
    """Companion to `Transcriptions` whose methods return the raw response."""

    def __init__(self, transcriptions: Transcriptions) -> None:
        self._transcriptions = transcriptions

        # Wrap each public method so calling it yields the un-parsed response.
        self.create = _legacy_response.to_raw_response_wrapper(
            transcriptions.create,
        )
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
class AsyncTranscriptionsWithRawResponse:
    """Companion to `AsyncTranscriptions` whose methods return the raw response."""

    def __init__(self, transcriptions: AsyncTranscriptions) -> None:
        self._transcriptions = transcriptions

        # Wrap each public method so awaiting it yields the un-parsed response.
        self.create = _legacy_response.async_to_raw_response_wrapper(
            transcriptions.create,
        )
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
class TranscriptionsWithStreamingResponse:
    """Companion to `Transcriptions` whose methods stream the response body."""

    def __init__(self, transcriptions: Transcriptions) -> None:
        self._transcriptions = transcriptions

        # Wrap each public method so the response body is not read eagerly.
        self.create = to_streamed_response_wrapper(
            transcriptions.create,
        )
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
class AsyncTranscriptionsWithStreamingResponse:
    """Companion to `AsyncTranscriptions` whose methods stream the response body."""

    def __init__(self, transcriptions: AsyncTranscriptions) -> None:
        self._transcriptions = transcriptions

        # Wrap each public method so the response body is not read eagerly.
        self.create = async_to_streamed_response_wrapper(
            transcriptions.create,
        )
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def _get_response_format_type(
    response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] | NotGiven,
) -> type[Transcription | TranscriptionVerbose | str]:
    """Map the requested `response_format` to the type the response is parsed into.

    `json` — and an omitted format, which the API defaults to `json` — parse into
    the `Transcription` model; `verbose_json` into `TranscriptionVerbose`; the
    plain-text formats (`text`, `srt`, `vtt`) are returned as raw strings.
    An unrecognized value at runtime falls back to `Transcription` with a warning.
    """
    if isinstance(response_format, NotGiven) or response_format is None:  # pyright: ignore[reportUnnecessaryComparison]
        return Transcription

    if response_format == "json":
        return Transcription
    elif response_format == "verbose_json":
        return TranscriptionVerbose
    elif response_format == "srt" or response_format == "text" or response_format == "vtt":
        return str
    elif TYPE_CHECKING:  # type: ignore[unreachable]
        # Statically unreachable if all literal members are handled above.
        assert_never(response_format)
    else:
        # `Logger.warn` is a deprecated alias of `warning` (raises a
        # DeprecationWarning); use the supported spelling.
        log.warning("Unexpected audio response format: %s", response_format)
        return Transcription
|
.venv/lib/python3.11/site-packages/openai/resources/audio/translations.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
from typing import TYPE_CHECKING, Union, Mapping, cast
|
| 7 |
+
from typing_extensions import Literal, overload, assert_never
|
| 8 |
+
|
| 9 |
+
import httpx
|
| 10 |
+
|
| 11 |
+
from ... import _legacy_response
|
| 12 |
+
from ...types import AudioResponseFormat
|
| 13 |
+
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
|
| 14 |
+
from ..._utils import (
|
| 15 |
+
extract_files,
|
| 16 |
+
maybe_transform,
|
| 17 |
+
deepcopy_minimal,
|
| 18 |
+
async_maybe_transform,
|
| 19 |
+
)
|
| 20 |
+
from ..._compat import cached_property
|
| 21 |
+
from ..._resource import SyncAPIResource, AsyncAPIResource
|
| 22 |
+
from ..._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
|
| 23 |
+
from ...types.audio import translation_create_params
|
| 24 |
+
from ..._base_client import make_request_options
|
| 25 |
+
from ...types.audio_model import AudioModel
|
| 26 |
+
from ...types.audio.translation import Translation
|
| 27 |
+
from ...types.audio_response_format import AudioResponseFormat
|
| 28 |
+
from ...types.audio.translation_verbose import TranslationVerbose
|
| 29 |
+
|
| 30 |
+
__all__ = ["Translations", "AsyncTranslations"]
|
| 31 |
+
|
| 32 |
+
log: logging.Logger = logging.getLogger("openai.audio.transcriptions")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class Translations(SyncAPIResource):
    """Synchronous client resource for the `/audio/translations` endpoint."""

    @cached_property
    def with_raw_response(self) -> TranslationsWithRawResponse:
        """
        This property can be used as a prefix for any HTTP method call to return
        the raw response object instead of the parsed content.

        For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
        """
        return TranslationsWithRawResponse(self)

    @cached_property
    def with_streaming_response(self) -> TranslationsWithStreamingResponse:
        """
        An alternative to `.with_raw_response` that doesn't eagerly read the response body.

        For more information, see https://www.github.com/openai/openai-python#with_streaming_response
        """
        return TranslationsWithStreamingResponse(self)

    # Overload: `json` response format (or omitted, which the API defaults to
    # `json`) -> parsed `Translation` model.
    @overload
    def create(
        self,
        *,
        file: FileTypes,
        model: Union[str, AudioModel],
        response_format: Union[Literal["json"], NotGiven] = NOT_GIVEN,
        prompt: str | NotGiven = NOT_GIVEN,
        temperature: float | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> Translation: ...

    # Overload: `verbose_json` -> parsed `TranslationVerbose` model.
    @overload
    def create(
        self,
        *,
        file: FileTypes,
        model: Union[str, AudioModel],
        response_format: Literal["verbose_json"],
        prompt: str | NotGiven = NOT_GIVEN,
        temperature: float | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> TranslationVerbose: ...

    # Overload: plain-text formats -> raw `str` response body.
    @overload
    def create(
        self,
        *,
        file: FileTypes,
        model: Union[str, AudioModel],
        response_format: Literal["text", "srt", "vtt"],
        prompt: str | NotGiven = NOT_GIVEN,
        temperature: float | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> str: ...

    def create(
        self,
        *,
        file: FileTypes,
        model: Union[str, AudioModel],
        prompt: str | NotGiven = NOT_GIVEN,
        response_format: Union[AudioResponseFormat, NotGiven] = NOT_GIVEN,
        temperature: float | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> Translation | TranslationVerbose | str:
        """
        Translates audio into English.

        Args:
          file: The audio file object (not file name) to translate, in one of these formats: flac,
              mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.

          model: ID of the model to use. Only `whisper-1` (which is powered by our open source
              Whisper V2 model) is currently available.

          prompt: An optional text to guide the model's style or continue a previous audio
              segment. The
              [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
              should be in English.

          response_format: The format of the output, in one of these options: `json`, `text`, `srt`,
              `verbose_json`, or `vtt`.

          temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the
              output more random, while lower values like 0.2 will make it more focused and
              deterministic. If set to 0, the model will use
              [log probability](https://en.wikipedia.org/wiki/Log_probability) to
              automatically increase the temperature until certain thresholds are hit.

          extra_headers: Send extra headers

          extra_query: Add additional query parameters to the request

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds
        """
        body = deepcopy_minimal(
            {
                "file": file,
                "model": model,
                "prompt": prompt,
                "response_format": response_format,
                "temperature": temperature,
            }
        )
        # The file is pulled out of the JSON-able body so it can be sent as a
        # multipart form part.
        files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
        # It should be noted that the actual Content-Type header that will be
        # sent to the server will contain a `boundary` parameter, e.g.
        # multipart/form-data; boundary=---abc--
        extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
        # The parse target depends on the requested response format; the
        # overloads above narrow the static return type accordingly.
        return self._post(  # type: ignore[return-value]
            "/audio/translations",
            body=maybe_transform(body, translation_create_params.TranslationCreateParams),
            files=files,
            options=make_request_options(
                extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
            ),
            cast_to=_get_response_format_type(response_format),
        )
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class AsyncTranslations(AsyncAPIResource):
|
| 179 |
+
@cached_property
|
| 180 |
+
def with_raw_response(self) -> AsyncTranslationsWithRawResponse:
|
| 181 |
+
"""
|
| 182 |
+
This property can be used as a prefix for any HTTP method call to return
|
| 183 |
+
the raw response object instead of the parsed content.
|
| 184 |
+
|
| 185 |
+
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
|
| 186 |
+
"""
|
| 187 |
+
return AsyncTranslationsWithRawResponse(self)
|
| 188 |
+
|
| 189 |
+
@cached_property
|
| 190 |
+
def with_streaming_response(self) -> AsyncTranslationsWithStreamingResponse:
|
| 191 |
+
"""
|
| 192 |
+
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
|
| 193 |
+
|
| 194 |
+
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
|
| 195 |
+
"""
|
| 196 |
+
return AsyncTranslationsWithStreamingResponse(self)
|
| 197 |
+
|
| 198 |
+
@overload
|
| 199 |
+
async def create(
|
| 200 |
+
self,
|
| 201 |
+
*,
|
| 202 |
+
file: FileTypes,
|
| 203 |
+
model: Union[str, AudioModel],
|
| 204 |
+
response_format: Union[Literal["json"], NotGiven] = NOT_GIVEN,
|
| 205 |
+
prompt: str | NotGiven = NOT_GIVEN,
|
| 206 |
+
temperature: float | NotGiven = NOT_GIVEN,
|
| 207 |
+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
| 208 |
+
# The extra values given here take precedence over values defined on the client or passed to this method.
|
| 209 |
+
extra_headers: Headers | None = None,
|
| 210 |
+
extra_query: Query | None = None,
|
| 211 |
+
extra_body: Body | None = None,
|
| 212 |
+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
| 213 |
+
) -> Translation: ...
|
| 214 |
+
|
| 215 |
+
@overload
|
| 216 |
+
async def create(
|
| 217 |
+
self,
|
| 218 |
+
*,
|
| 219 |
+
file: FileTypes,
|
| 220 |
+
model: Union[str, AudioModel],
|
| 221 |
+
response_format: Literal["verbose_json"],
|
| 222 |
+
prompt: str | NotGiven = NOT_GIVEN,
|
| 223 |
+
temperature: float | NotGiven = NOT_GIVEN,
|
| 224 |
+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
| 225 |
+
# The extra values given here take precedence over values defined on the client or passed to this method.
|
| 226 |
+
extra_headers: Headers | None = None,
|
| 227 |
+
extra_query: Query | None = None,
|
| 228 |
+
extra_body: Body | None = None,
|
| 229 |
+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
| 230 |
+
) -> TranslationVerbose: ...
|
| 231 |
+
|
| 232 |
+
@overload
|
| 233 |
+
async def create(
|
| 234 |
+
self,
|
| 235 |
+
*,
|
| 236 |
+
file: FileTypes,
|
| 237 |
+
model: Union[str, AudioModel],
|
| 238 |
+
response_format: Literal["text", "srt", "vtt"],
|
| 239 |
+
prompt: str | NotGiven = NOT_GIVEN,
|
| 240 |
+
temperature: float | NotGiven = NOT_GIVEN,
|
| 241 |
+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
| 242 |
+
# The extra values given here take precedence over values defined on the client or passed to this method.
|
| 243 |
+
extra_headers: Headers | None = None,
|
| 244 |
+
extra_query: Query | None = None,
|
| 245 |
+
extra_body: Body | None = None,
|
| 246 |
+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
| 247 |
+
) -> str: ...
|
| 248 |
+
|
| 249 |
+
async def create(
|
| 250 |
+
self,
|
| 251 |
+
*,
|
| 252 |
+
file: FileTypes,
|
| 253 |
+
model: Union[str, AudioModel],
|
| 254 |
+
prompt: str | NotGiven = NOT_GIVEN,
|
| 255 |
+
response_format: Union[AudioResponseFormat, NotGiven] = NOT_GIVEN,
|
| 256 |
+
temperature: float | NotGiven = NOT_GIVEN,
|
| 257 |
+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
| 258 |
+
# The extra values given here take precedence over values defined on the client or passed to this method.
|
| 259 |
+
extra_headers: Headers | None = None,
|
| 260 |
+
extra_query: Query | None = None,
|
| 261 |
+
extra_body: Body | None = None,
|
| 262 |
+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
| 263 |
+
) -> Translation | TranslationVerbose | str:
|
| 264 |
+
"""
|
| 265 |
+
Translates audio into English.
|
| 266 |
+
|
| 267 |
+
Args:
|
| 268 |
+
file: The audio file object (not file name) translate, in one of these formats: flac,
|
| 269 |
+
mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
|
| 270 |
+
|
| 271 |
+
model: ID of the model to use. Only `whisper-1` (which is powered by our open source
|
| 272 |
+
Whisper V2 model) is currently available.
|
| 273 |
+
|
| 274 |
+
prompt: An optional text to guide the model's style or continue a previous audio
|
| 275 |
+
segment. The
|
| 276 |
+
[prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
|
| 277 |
+
should be in English.
|
| 278 |
+
|
| 279 |
+
response_format: The format of the output, in one of these options: `json`, `text`, `srt`,
|
| 280 |
+
`verbose_json`, or `vtt`.
|
| 281 |
+
|
| 282 |
+
temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the
|
| 283 |
+
output more random, while lower values like 0.2 will make it more focused and
|
| 284 |
+
deterministic. If set to 0, the model will use
|
| 285 |
+
[log probability](https://en.wikipedia.org/wiki/Log_probability) to
|
| 286 |
+
automatically increase the temperature until certain thresholds are hit.
|
| 287 |
+
|
| 288 |
+
extra_headers: Send extra headers
|
| 289 |
+
|
| 290 |
+
extra_query: Add additional query parameters to the request
|
| 291 |
+
|
| 292 |
+
extra_body: Add additional JSON properties to the request
|
| 293 |
+
|
| 294 |
+
timeout: Override the client-level default timeout for this request, in seconds
|
| 295 |
+
"""
|
| 296 |
+
body = deepcopy_minimal(
|
| 297 |
+
{
|
| 298 |
+
"file": file,
|
| 299 |
+
"model": model,
|
| 300 |
+
"prompt": prompt,
|
| 301 |
+
"response_format": response_format,
|
| 302 |
+
"temperature": temperature,
|
| 303 |
+
}
|
| 304 |
+
)
|
| 305 |
+
files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
|
| 306 |
+
# It should be noted that the actual Content-Type header that will be
|
| 307 |
+
# sent to the server will contain a `boundary` parameter, e.g.
|
| 308 |
+
# multipart/form-data; boundary=---abc--
|
| 309 |
+
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
|
| 310 |
+
return await self._post(
|
| 311 |
+
"/audio/translations",
|
| 312 |
+
body=await async_maybe_transform(body, translation_create_params.TranslationCreateParams),
|
| 313 |
+
files=files,
|
| 314 |
+
options=make_request_options(
|
| 315 |
+
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
|
| 316 |
+
),
|
| 317 |
+
cast_to=_get_response_format_type(response_format),
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
class TranslationsWithRawResponse:
|
| 322 |
+
def __init__(self, translations: Translations) -> None:
|
| 323 |
+
self._translations = translations
|
| 324 |
+
|
| 325 |
+
self.create = _legacy_response.to_raw_response_wrapper(
|
| 326 |
+
translations.create,
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
class AsyncTranslationsWithRawResponse:
|
| 331 |
+
def __init__(self, translations: AsyncTranslations) -> None:
|
| 332 |
+
self._translations = translations
|
| 333 |
+
|
| 334 |
+
self.create = _legacy_response.async_to_raw_response_wrapper(
|
| 335 |
+
translations.create,
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
class TranslationsWithStreamingResponse:
|
| 340 |
+
def __init__(self, translations: Translations) -> None:
|
| 341 |
+
self._translations = translations
|
| 342 |
+
|
| 343 |
+
self.create = to_streamed_response_wrapper(
|
| 344 |
+
translations.create,
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
class AsyncTranslationsWithStreamingResponse:
|
| 349 |
+
def __init__(self, translations: AsyncTranslations) -> None:
|
| 350 |
+
self._translations = translations
|
| 351 |
+
|
| 352 |
+
self.create = async_to_streamed_response_wrapper(
|
| 353 |
+
translations.create,
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def _get_response_format_type(
|
| 358 |
+
response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] | NotGiven,
|
| 359 |
+
) -> type[Translation | TranslationVerbose | str]:
|
| 360 |
+
if isinstance(response_format, NotGiven) or response_format is None: # pyright: ignore[reportUnnecessaryComparison]
|
| 361 |
+
return Translation
|
| 362 |
+
|
| 363 |
+
if response_format == "json":
|
| 364 |
+
return Translation
|
| 365 |
+
elif response_format == "verbose_json":
|
| 366 |
+
return TranslationVerbose
|
| 367 |
+
elif response_format == "srt" or response_format == "text" or response_format == "vtt":
|
| 368 |
+
return str
|
| 369 |
+
elif TYPE_CHECKING: # type: ignore[unreachable]
|
| 370 |
+
assert_never(response_format)
|
| 371 |
+
else:
|
| 372 |
+
log.warn("Unexpected audio response format: %s", response_format)
|
| 373 |
+
return Transcription
|
.venv/lib/python3.11/site-packages/openai/resources/batches.py
ADDED
|
@@ -0,0 +1,517 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Optional
|
| 6 |
+
from typing_extensions import Literal
|
| 7 |
+
|
| 8 |
+
import httpx
|
| 9 |
+
|
| 10 |
+
from .. import _legacy_response
|
| 11 |
+
from ..types import batch_list_params, batch_create_params
|
| 12 |
+
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
|
| 13 |
+
from .._utils import (
|
| 14 |
+
maybe_transform,
|
| 15 |
+
async_maybe_transform,
|
| 16 |
+
)
|
| 17 |
+
from .._compat import cached_property
|
| 18 |
+
from .._resource import SyncAPIResource, AsyncAPIResource
|
| 19 |
+
from .._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
|
| 20 |
+
from ..pagination import SyncCursorPage, AsyncCursorPage
|
| 21 |
+
from ..types.batch import Batch
|
| 22 |
+
from .._base_client import AsyncPaginator, make_request_options
|
| 23 |
+
from ..types.shared_params.metadata import Metadata
|
| 24 |
+
|
| 25 |
+
__all__ = ["Batches", "AsyncBatches"]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class Batches(SyncAPIResource):
|
| 29 |
+
@cached_property
|
| 30 |
+
def with_raw_response(self) -> BatchesWithRawResponse:
|
| 31 |
+
"""
|
| 32 |
+
This property can be used as a prefix for any HTTP method call to return
|
| 33 |
+
the raw response object instead of the parsed content.
|
| 34 |
+
|
| 35 |
+
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
|
| 36 |
+
"""
|
| 37 |
+
return BatchesWithRawResponse(self)
|
| 38 |
+
|
| 39 |
+
@cached_property
|
| 40 |
+
def with_streaming_response(self) -> BatchesWithStreamingResponse:
|
| 41 |
+
"""
|
| 42 |
+
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
|
| 43 |
+
|
| 44 |
+
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
|
| 45 |
+
"""
|
| 46 |
+
return BatchesWithStreamingResponse(self)
|
| 47 |
+
|
| 48 |
+
def create(
|
| 49 |
+
self,
|
| 50 |
+
*,
|
| 51 |
+
completion_window: Literal["24h"],
|
| 52 |
+
endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
|
| 53 |
+
input_file_id: str,
|
| 54 |
+
metadata: Optional[Metadata] | NotGiven = NOT_GIVEN,
|
| 55 |
+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
| 56 |
+
# The extra values given here take precedence over values defined on the client or passed to this method.
|
| 57 |
+
extra_headers: Headers | None = None,
|
| 58 |
+
extra_query: Query | None = None,
|
| 59 |
+
extra_body: Body | None = None,
|
| 60 |
+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
| 61 |
+
) -> Batch:
|
| 62 |
+
"""
|
| 63 |
+
Creates and executes a batch from an uploaded file of requests
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
completion_window: The time frame within which the batch should be processed. Currently only `24h`
|
| 67 |
+
is supported.
|
| 68 |
+
|
| 69 |
+
endpoint: The endpoint to be used for all requests in the batch. Currently
|
| 70 |
+
`/v1/chat/completions`, `/v1/embeddings`, and `/v1/completions` are supported.
|
| 71 |
+
Note that `/v1/embeddings` batches are also restricted to a maximum of 50,000
|
| 72 |
+
embedding inputs across all requests in the batch.
|
| 73 |
+
|
| 74 |
+
input_file_id: The ID of an uploaded file that contains requests for the new batch.
|
| 75 |
+
|
| 76 |
+
See [upload file](https://platform.openai.com/docs/api-reference/files/create)
|
| 77 |
+
for how to upload a file.
|
| 78 |
+
|
| 79 |
+
Your input file must be formatted as a
|
| 80 |
+
[JSONL file](https://platform.openai.com/docs/api-reference/batch/request-input),
|
| 81 |
+
and must be uploaded with the purpose `batch`. The file can contain up to 50,000
|
| 82 |
+
requests, and can be up to 200 MB in size.
|
| 83 |
+
|
| 84 |
+
metadata: Set of 16 key-value pairs that can be attached to an object. This can be useful
|
| 85 |
+
for storing additional information about the object in a structured format, and
|
| 86 |
+
querying for objects via API or the dashboard.
|
| 87 |
+
|
| 88 |
+
Keys are strings with a maximum length of 64 characters. Values are strings with
|
| 89 |
+
a maximum length of 512 characters.
|
| 90 |
+
|
| 91 |
+
extra_headers: Send extra headers
|
| 92 |
+
|
| 93 |
+
extra_query: Add additional query parameters to the request
|
| 94 |
+
|
| 95 |
+
extra_body: Add additional JSON properties to the request
|
| 96 |
+
|
| 97 |
+
timeout: Override the client-level default timeout for this request, in seconds
|
| 98 |
+
"""
|
| 99 |
+
return self._post(
|
| 100 |
+
"/batches",
|
| 101 |
+
body=maybe_transform(
|
| 102 |
+
{
|
| 103 |
+
"completion_window": completion_window,
|
| 104 |
+
"endpoint": endpoint,
|
| 105 |
+
"input_file_id": input_file_id,
|
| 106 |
+
"metadata": metadata,
|
| 107 |
+
},
|
| 108 |
+
batch_create_params.BatchCreateParams,
|
| 109 |
+
),
|
| 110 |
+
options=make_request_options(
|
| 111 |
+
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
|
| 112 |
+
),
|
| 113 |
+
cast_to=Batch,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
def retrieve(
|
| 117 |
+
self,
|
| 118 |
+
batch_id: str,
|
| 119 |
+
*,
|
| 120 |
+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
| 121 |
+
# The extra values given here take precedence over values defined on the client or passed to this method.
|
| 122 |
+
extra_headers: Headers | None = None,
|
| 123 |
+
extra_query: Query | None = None,
|
| 124 |
+
extra_body: Body | None = None,
|
| 125 |
+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
| 126 |
+
) -> Batch:
|
| 127 |
+
"""
|
| 128 |
+
Retrieves a batch.
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
extra_headers: Send extra headers
|
| 132 |
+
|
| 133 |
+
extra_query: Add additional query parameters to the request
|
| 134 |
+
|
| 135 |
+
extra_body: Add additional JSON properties to the request
|
| 136 |
+
|
| 137 |
+
timeout: Override the client-level default timeout for this request, in seconds
|
| 138 |
+
"""
|
| 139 |
+
if not batch_id:
|
| 140 |
+
raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
|
| 141 |
+
return self._get(
|
| 142 |
+
f"/batches/{batch_id}",
|
| 143 |
+
options=make_request_options(
|
| 144 |
+
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
|
| 145 |
+
),
|
| 146 |
+
cast_to=Batch,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
def list(
|
| 150 |
+
self,
|
| 151 |
+
*,
|
| 152 |
+
after: str | NotGiven = NOT_GIVEN,
|
| 153 |
+
limit: int | NotGiven = NOT_GIVEN,
|
| 154 |
+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
| 155 |
+
# The extra values given here take precedence over values defined on the client or passed to this method.
|
| 156 |
+
extra_headers: Headers | None = None,
|
| 157 |
+
extra_query: Query | None = None,
|
| 158 |
+
extra_body: Body | None = None,
|
| 159 |
+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
| 160 |
+
) -> SyncCursorPage[Batch]:
|
| 161 |
+
"""List your organization's batches.
|
| 162 |
+
|
| 163 |
+
Args:
|
| 164 |
+
after: A cursor for use in pagination.
|
| 165 |
+
|
| 166 |
+
`after` is an object ID that defines your place
|
| 167 |
+
in the list. For instance, if you make a list request and receive 100 objects,
|
| 168 |
+
ending with obj_foo, your subsequent call can include after=obj_foo in order to
|
| 169 |
+
fetch the next page of the list.
|
| 170 |
+
|
| 171 |
+
limit: A limit on the number of objects to be returned. Limit can range between 1 and
|
| 172 |
+
100, and the default is 20.
|
| 173 |
+
|
| 174 |
+
extra_headers: Send extra headers
|
| 175 |
+
|
| 176 |
+
extra_query: Add additional query parameters to the request
|
| 177 |
+
|
| 178 |
+
extra_body: Add additional JSON properties to the request
|
| 179 |
+
|
| 180 |
+
timeout: Override the client-level default timeout for this request, in seconds
|
| 181 |
+
"""
|
| 182 |
+
return self._get_api_list(
|
| 183 |
+
"/batches",
|
| 184 |
+
page=SyncCursorPage[Batch],
|
| 185 |
+
options=make_request_options(
|
| 186 |
+
extra_headers=extra_headers,
|
| 187 |
+
extra_query=extra_query,
|
| 188 |
+
extra_body=extra_body,
|
| 189 |
+
timeout=timeout,
|
| 190 |
+
query=maybe_transform(
|
| 191 |
+
{
|
| 192 |
+
"after": after,
|
| 193 |
+
"limit": limit,
|
| 194 |
+
},
|
| 195 |
+
batch_list_params.BatchListParams,
|
| 196 |
+
),
|
| 197 |
+
),
|
| 198 |
+
model=Batch,
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
def cancel(
|
| 202 |
+
self,
|
| 203 |
+
batch_id: str,
|
| 204 |
+
*,
|
| 205 |
+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
| 206 |
+
# The extra values given here take precedence over values defined on the client or passed to this method.
|
| 207 |
+
extra_headers: Headers | None = None,
|
| 208 |
+
extra_query: Query | None = None,
|
| 209 |
+
extra_body: Body | None = None,
|
| 210 |
+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
| 211 |
+
) -> Batch:
|
| 212 |
+
"""Cancels an in-progress batch.
|
| 213 |
+
|
| 214 |
+
The batch will be in status `cancelling` for up to
|
| 215 |
+
10 minutes, before changing to `cancelled`, where it will have partial results
|
| 216 |
+
(if any) available in the output file.
|
| 217 |
+
|
| 218 |
+
Args:
|
| 219 |
+
extra_headers: Send extra headers
|
| 220 |
+
|
| 221 |
+
extra_query: Add additional query parameters to the request
|
| 222 |
+
|
| 223 |
+
extra_body: Add additional JSON properties to the request
|
| 224 |
+
|
| 225 |
+
timeout: Override the client-level default timeout for this request, in seconds
|
| 226 |
+
"""
|
| 227 |
+
if not batch_id:
|
| 228 |
+
raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
|
| 229 |
+
return self._post(
|
| 230 |
+
f"/batches/{batch_id}/cancel",
|
| 231 |
+
options=make_request_options(
|
| 232 |
+
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
|
| 233 |
+
),
|
| 234 |
+
cast_to=Batch,
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class AsyncBatches(AsyncAPIResource):
|
| 239 |
+
@cached_property
|
| 240 |
+
def with_raw_response(self) -> AsyncBatchesWithRawResponse:
|
| 241 |
+
"""
|
| 242 |
+
This property can be used as a prefix for any HTTP method call to return
|
| 243 |
+
the raw response object instead of the parsed content.
|
| 244 |
+
|
| 245 |
+
For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
|
| 246 |
+
"""
|
| 247 |
+
return AsyncBatchesWithRawResponse(self)
|
| 248 |
+
|
| 249 |
+
@cached_property
|
| 250 |
+
def with_streaming_response(self) -> AsyncBatchesWithStreamingResponse:
|
| 251 |
+
"""
|
| 252 |
+
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
|
| 253 |
+
|
| 254 |
+
For more information, see https://www.github.com/openai/openai-python#with_streaming_response
|
| 255 |
+
"""
|
| 256 |
+
return AsyncBatchesWithStreamingResponse(self)
|
| 257 |
+
|
| 258 |
+
async def create(
|
| 259 |
+
self,
|
| 260 |
+
*,
|
| 261 |
+
completion_window: Literal["24h"],
|
| 262 |
+
endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
|
| 263 |
+
input_file_id: str,
|
| 264 |
+
metadata: Optional[Metadata] | NotGiven = NOT_GIVEN,
|
| 265 |
+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
| 266 |
+
# The extra values given here take precedence over values defined on the client or passed to this method.
|
| 267 |
+
extra_headers: Headers | None = None,
|
| 268 |
+
extra_query: Query | None = None,
|
| 269 |
+
extra_body: Body | None = None,
|
| 270 |
+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
| 271 |
+
) -> Batch:
|
| 272 |
+
"""
|
| 273 |
+
Creates and executes a batch from an uploaded file of requests
|
| 274 |
+
|
| 275 |
+
Args:
|
| 276 |
+
completion_window: The time frame within which the batch should be processed. Currently only `24h`
|
| 277 |
+
is supported.
|
| 278 |
+
|
| 279 |
+
endpoint: The endpoint to be used for all requests in the batch. Currently
|
| 280 |
+
`/v1/chat/completions`, `/v1/embeddings`, and `/v1/completions` are supported.
|
| 281 |
+
Note that `/v1/embeddings` batches are also restricted to a maximum of 50,000
|
| 282 |
+
embedding inputs across all requests in the batch.
|
| 283 |
+
|
| 284 |
+
input_file_id: The ID of an uploaded file that contains requests for the new batch.
|
| 285 |
+
|
| 286 |
+
See [upload file](https://platform.openai.com/docs/api-reference/files/create)
|
| 287 |
+
for how to upload a file.
|
| 288 |
+
|
| 289 |
+
Your input file must be formatted as a
|
| 290 |
+
[JSONL file](https://platform.openai.com/docs/api-reference/batch/request-input),
|
| 291 |
+
and must be uploaded with the purpose `batch`. The file can contain up to 50,000
|
| 292 |
+
requests, and can be up to 200 MB in size.
|
| 293 |
+
|
| 294 |
+
metadata: Set of 16 key-value pairs that can be attached to an object. This can be useful
|
| 295 |
+
for storing additional information about the object in a structured format, and
|
| 296 |
+
querying for objects via API or the dashboard.
|
| 297 |
+
|
| 298 |
+
Keys are strings with a maximum length of 64 characters. Values are strings with
|
| 299 |
+
a maximum length of 512 characters.
|
| 300 |
+
|
| 301 |
+
extra_headers: Send extra headers
|
| 302 |
+
|
| 303 |
+
extra_query: Add additional query parameters to the request
|
| 304 |
+
|
| 305 |
+
extra_body: Add additional JSON properties to the request
|
| 306 |
+
|
| 307 |
+
timeout: Override the client-level default timeout for this request, in seconds
|
| 308 |
+
"""
|
| 309 |
+
return await self._post(
|
| 310 |
+
"/batches",
|
| 311 |
+
body=await async_maybe_transform(
|
| 312 |
+
{
|
| 313 |
+
"completion_window": completion_window,
|
| 314 |
+
"endpoint": endpoint,
|
| 315 |
+
"input_file_id": input_file_id,
|
| 316 |
+
"metadata": metadata,
|
| 317 |
+
},
|
| 318 |
+
batch_create_params.BatchCreateParams,
|
| 319 |
+
),
|
| 320 |
+
options=make_request_options(
|
| 321 |
+
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
|
| 322 |
+
),
|
| 323 |
+
cast_to=Batch,
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
async def retrieve(
|
| 327 |
+
self,
|
| 328 |
+
batch_id: str,
|
| 329 |
+
*,
|
| 330 |
+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
| 331 |
+
# The extra values given here take precedence over values defined on the client or passed to this method.
|
| 332 |
+
extra_headers: Headers | None = None,
|
| 333 |
+
extra_query: Query | None = None,
|
| 334 |
+
extra_body: Body | None = None,
|
| 335 |
+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
| 336 |
+
) -> Batch:
|
| 337 |
+
"""
|
| 338 |
+
Retrieves a batch.
|
| 339 |
+
|
| 340 |
+
Args:
|
| 341 |
+
extra_headers: Send extra headers
|
| 342 |
+
|
| 343 |
+
extra_query: Add additional query parameters to the request
|
| 344 |
+
|
| 345 |
+
extra_body: Add additional JSON properties to the request
|
| 346 |
+
|
| 347 |
+
timeout: Override the client-level default timeout for this request, in seconds
|
| 348 |
+
"""
|
| 349 |
+
if not batch_id:
|
| 350 |
+
raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
|
| 351 |
+
return await self._get(
|
| 352 |
+
f"/batches/{batch_id}",
|
| 353 |
+
options=make_request_options(
|
| 354 |
+
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
|
| 355 |
+
),
|
| 356 |
+
cast_to=Batch,
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
def list(
|
| 360 |
+
self,
|
| 361 |
+
*,
|
| 362 |
+
after: str | NotGiven = NOT_GIVEN,
|
| 363 |
+
limit: int | NotGiven = NOT_GIVEN,
|
| 364 |
+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
| 365 |
+
# The extra values given here take precedence over values defined on the client or passed to this method.
|
| 366 |
+
extra_headers: Headers | None = None,
|
| 367 |
+
extra_query: Query | None = None,
|
| 368 |
+
extra_body: Body | None = None,
|
| 369 |
+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
| 370 |
+
) -> AsyncPaginator[Batch, AsyncCursorPage[Batch]]:
|
| 371 |
+
"""List your organization's batches.
|
| 372 |
+
|
| 373 |
+
Args:
|
| 374 |
+
after: A cursor for use in pagination.
|
| 375 |
+
|
| 376 |
+
`after` is an object ID that defines your place
|
| 377 |
+
in the list. For instance, if you make a list request and receive 100 objects,
|
| 378 |
+
ending with obj_foo, your subsequent call can include after=obj_foo in order to
|
| 379 |
+
fetch the next page of the list.
|
| 380 |
+
|
| 381 |
+
limit: A limit on the number of objects to be returned. Limit can range between 1 and
|
| 382 |
+
100, and the default is 20.
|
| 383 |
+
|
| 384 |
+
extra_headers: Send extra headers
|
| 385 |
+
|
| 386 |
+
extra_query: Add additional query parameters to the request
|
| 387 |
+
|
| 388 |
+
extra_body: Add additional JSON properties to the request
|
| 389 |
+
|
| 390 |
+
timeout: Override the client-level default timeout for this request, in seconds
|
| 391 |
+
"""
|
| 392 |
+
return self._get_api_list(
|
| 393 |
+
"/batches",
|
| 394 |
+
page=AsyncCursorPage[Batch],
|
| 395 |
+
options=make_request_options(
|
| 396 |
+
extra_headers=extra_headers,
|
| 397 |
+
extra_query=extra_query,
|
| 398 |
+
extra_body=extra_body,
|
| 399 |
+
timeout=timeout,
|
| 400 |
+
query=maybe_transform(
|
| 401 |
+
{
|
| 402 |
+
"after": after,
|
| 403 |
+
"limit": limit,
|
| 404 |
+
},
|
| 405 |
+
batch_list_params.BatchListParams,
|
| 406 |
+
),
|
| 407 |
+
),
|
| 408 |
+
model=Batch,
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
async def cancel(
|
| 412 |
+
self,
|
| 413 |
+
batch_id: str,
|
| 414 |
+
*,
|
| 415 |
+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
| 416 |
+
# The extra values given here take precedence over values defined on the client or passed to this method.
|
| 417 |
+
extra_headers: Headers | None = None,
|
| 418 |
+
extra_query: Query | None = None,
|
| 419 |
+
extra_body: Body | None = None,
|
| 420 |
+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
| 421 |
+
) -> Batch:
|
| 422 |
+
"""Cancels an in-progress batch.
|
| 423 |
+
|
| 424 |
+
The batch will be in status `cancelling` for up to
|
| 425 |
+
10 minutes, before changing to `cancelled`, where it will have partial results
|
| 426 |
+
(if any) available in the output file.
|
| 427 |
+
|
| 428 |
+
Args:
|
| 429 |
+
extra_headers: Send extra headers
|
| 430 |
+
|
| 431 |
+
extra_query: Add additional query parameters to the request
|
| 432 |
+
|
| 433 |
+
extra_body: Add additional JSON properties to the request
|
| 434 |
+
|
| 435 |
+
timeout: Override the client-level default timeout for this request, in seconds
|
| 436 |
+
"""
|
| 437 |
+
if not batch_id:
|
| 438 |
+
raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
|
| 439 |
+
return await self._post(
|
| 440 |
+
f"/batches/{batch_id}/cancel",
|
| 441 |
+
options=make_request_options(
|
| 442 |
+
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
|
| 443 |
+
),
|
| 444 |
+
cast_to=Batch,
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
class BatchesWithRawResponse:
|
| 449 |
+
def __init__(self, batches: Batches) -> None:
|
| 450 |
+
self._batches = batches
|
| 451 |
+
|
| 452 |
+
self.create = _legacy_response.to_raw_response_wrapper(
|
| 453 |
+
batches.create,
|
| 454 |
+
)
|
| 455 |
+
self.retrieve = _legacy_response.to_raw_response_wrapper(
|
| 456 |
+
batches.retrieve,
|
| 457 |
+
)
|
| 458 |
+
self.list = _legacy_response.to_raw_response_wrapper(
|
| 459 |
+
batches.list,
|
| 460 |
+
)
|
| 461 |
+
self.cancel = _legacy_response.to_raw_response_wrapper(
|
| 462 |
+
batches.cancel,
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
class AsyncBatchesWithRawResponse:
|
| 467 |
+
def __init__(self, batches: AsyncBatches) -> None:
|
| 468 |
+
self._batches = batches
|
| 469 |
+
|
| 470 |
+
self.create = _legacy_response.async_to_raw_response_wrapper(
|
| 471 |
+
batches.create,
|
| 472 |
+
)
|
| 473 |
+
self.retrieve = _legacy_response.async_to_raw_response_wrapper(
|
| 474 |
+
batches.retrieve,
|
| 475 |
+
)
|
| 476 |
+
self.list = _legacy_response.async_to_raw_response_wrapper(
|
| 477 |
+
batches.list,
|
| 478 |
+
)
|
| 479 |
+
self.cancel = _legacy_response.async_to_raw_response_wrapper(
|
| 480 |
+
batches.cancel,
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
class BatchesWithStreamingResponse:
|
| 485 |
+
def __init__(self, batches: Batches) -> None:
|
| 486 |
+
self._batches = batches
|
| 487 |
+
|
| 488 |
+
self.create = to_streamed_response_wrapper(
|
| 489 |
+
batches.create,
|
| 490 |
+
)
|
| 491 |
+
self.retrieve = to_streamed_response_wrapper(
|
| 492 |
+
batches.retrieve,
|
| 493 |
+
)
|
| 494 |
+
self.list = to_streamed_response_wrapper(
|
| 495 |
+
batches.list,
|
| 496 |
+
)
|
| 497 |
+
self.cancel = to_streamed_response_wrapper(
|
| 498 |
+
batches.cancel,
|
| 499 |
+
)
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
class AsyncBatchesWithStreamingResponse:
|
| 503 |
+
def __init__(self, batches: AsyncBatches) -> None:
|
| 504 |
+
self._batches = batches
|
| 505 |
+
|
| 506 |
+
self.create = async_to_streamed_response_wrapper(
|
| 507 |
+
batches.create,
|
| 508 |
+
)
|
| 509 |
+
self.retrieve = async_to_streamed_response_wrapper(
|
| 510 |
+
batches.retrieve,
|
| 511 |
+
)
|
| 512 |
+
self.list = async_to_streamed_response_wrapper(
|
| 513 |
+
batches.list,
|
| 514 |
+
)
|
| 515 |
+
self.cancel = async_to_streamed_response_wrapper(
|
| 516 |
+
batches.cancel,
|
| 517 |
+
)
|
.venv/lib/python3.11/site-packages/openai/resources/beta/__init__.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
| 2 |
+
|
| 3 |
+
from .beta import (
|
| 4 |
+
Beta,
|
| 5 |
+
AsyncBeta,
|
| 6 |
+
BetaWithRawResponse,
|
| 7 |
+
AsyncBetaWithRawResponse,
|
| 8 |
+
BetaWithStreamingResponse,
|
| 9 |
+
AsyncBetaWithStreamingResponse,
|
| 10 |
+
)
|
| 11 |
+
from .threads import (
|
| 12 |
+
Threads,
|
| 13 |
+
AsyncThreads,
|
| 14 |
+
ThreadsWithRawResponse,
|
| 15 |
+
AsyncThreadsWithRawResponse,
|
| 16 |
+
ThreadsWithStreamingResponse,
|
| 17 |
+
AsyncThreadsWithStreamingResponse,
|
| 18 |
+
)
|
| 19 |
+
from .assistants import (
|
| 20 |
+
Assistants,
|
| 21 |
+
AsyncAssistants,
|
| 22 |
+
AssistantsWithRawResponse,
|
| 23 |
+
AsyncAssistantsWithRawResponse,
|
| 24 |
+
AssistantsWithStreamingResponse,
|
| 25 |
+
AsyncAssistantsWithStreamingResponse,
|
| 26 |
+
)
|
| 27 |
+
from .vector_stores import (
|
| 28 |
+
VectorStores,
|
| 29 |
+
AsyncVectorStores,
|
| 30 |
+
VectorStoresWithRawResponse,
|
| 31 |
+
AsyncVectorStoresWithRawResponse,
|
| 32 |
+
VectorStoresWithStreamingResponse,
|
| 33 |
+
AsyncVectorStoresWithStreamingResponse,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
__all__ = [
|
| 37 |
+
"VectorStores",
|
| 38 |
+
"AsyncVectorStores",
|
| 39 |
+
"VectorStoresWithRawResponse",
|
| 40 |
+
"AsyncVectorStoresWithRawResponse",
|
| 41 |
+
"VectorStoresWithStreamingResponse",
|
| 42 |
+
"AsyncVectorStoresWithStreamingResponse",
|
| 43 |
+
"Assistants",
|
| 44 |
+
"AsyncAssistants",
|
| 45 |
+
"AssistantsWithRawResponse",
|
| 46 |
+
"AsyncAssistantsWithRawResponse",
|
| 47 |
+
"AssistantsWithStreamingResponse",
|
| 48 |
+
"AsyncAssistantsWithStreamingResponse",
|
| 49 |
+
"Threads",
|
| 50 |
+
"AsyncThreads",
|
| 51 |
+
"ThreadsWithRawResponse",
|
| 52 |
+
"AsyncThreadsWithRawResponse",
|
| 53 |
+
"ThreadsWithStreamingResponse",
|
| 54 |
+
"AsyncThreadsWithStreamingResponse",
|
| 55 |
+
"Beta",
|
| 56 |
+
"AsyncBeta",
|
| 57 |
+
"BetaWithRawResponse",
|
| 58 |
+
"AsyncBetaWithRawResponse",
|
| 59 |
+
"BetaWithStreamingResponse",
|
| 60 |
+
"AsyncBetaWithStreamingResponse",
|
| 61 |
+
]
|
.venv/lib/python3.11/site-packages/openai/resources/beta/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.56 kB). View file
|
|
|