Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import sys | |
| import logging | |
| import argparse | |
| from typing import Any, List, Type, Optional | |
| from typing_extensions import ClassVar | |
| import httpx | |
| import pydantic | |
| import openai | |
| from . import _tools | |
| from .. import _ApiType, __version__ | |
| from ._api import register_commands | |
| from ._utils import can_use_http2 | |
| from ._errors import CLIError, display_error | |
| from .._compat import PYDANTIC_V1, ConfigDict, model_parse | |
| from .._models import BaseModel | |
| from .._exceptions import APIError | |
# CLI logging setup: a stderr handler with a timestamped format is attached to
# the *root* logger (``getLogger()`` with no name), so records from any logger
# that propagates to the root are written to stderr in this format.
logger = logging.getLogger()
handler = logging.StreamHandler(sys.stderr)
formatter = logging.Formatter("[%(asctime)s] %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
class Arguments(BaseModel):
    """Global CLI options, validated from the top-level argparse namespace.

    `_parse_args` builds this with ``model_parse(Arguments, vars(parsed))``,
    passing the whole namespace. That namespace can also contain keys
    registered by subcommands, hence the ``extra="ignore"`` configuration below.
    """

    # Silently drop namespace keys that are not declared fields (e.g.
    # subcommand-specific arguments) instead of rejecting or storing them.
    # The config spelling differs between pydantic v1 and v2.
    if PYDANTIC_V1:

        class Config(pydantic.BaseConfig):  # type: ignore
            extra: Any = pydantic.Extra.ignore  # type: ignore

    else:
        model_config: ClassVar[ConfigDict] = ConfigDict(
            extra="ignore",
        )

    # Number of times -v/--verbose was given (argparse `count` action, default 0).
    verbosity: int
    # NOTE(review): -V/--version is handled by argparse's `version` action in
    # `_build_parser`, so this is presumably always None here — confirm.
    version: Optional[str] = None

    # No defaults: these must be present in the namespace. argparse always
    # sets them (to None when the flag is not passed).
    api_key: Optional[str]
    api_base: Optional[str]
    organization: Optional[str]
    # One or more proxy URLs from -p/--proxy (nargs="+").
    proxy: Optional[List[str]]

    # Backend selection from -t/--api-type (argparse restricts it to "openai"/"azure").
    api_type: Optional[_ApiType] = None
    api_version: Optional[str] = None

    # azure
    azure_endpoint: Optional[str] = None
    azure_ad_token: Optional[str] = None

    # internal, set by subparsers to parse their specific args
    args_model: Optional[Type[BaseModel]] = None

    # internal, used so that subparsers can forward unknown arguments
    # NOTE(review): the `[]` default relies on pydantic copying mutable
    # defaults per instance — confirm for the pydantic versions in use.
    unknown_args: List[str] = []
    # When False, `_parse_args` re-parses argv strictly so unknown arguments error out.
    allow_unknown_args: bool = False
def _build_parser() -> argparse.ArgumentParser:
    """Build the top-level ``openai`` argument parser.

    The parser is set up as follows:

    - Registers the global options: verbosity, API base/key, proxy,
      organization, API type/version, the Azure settings and ``--version``.
    - Installs a default ``func`` that prints the help text, used when no
      subcommand is selected.
    - Attaches the ``api`` and ``tools`` subcommand groups.

    Returns:
        The fully configured :class:`argparse.ArgumentParser`.
    """
    parser = argparse.ArgumentParser(prog="openai")
    parser.add_argument(
        "-v",
        "--verbose",
        action="count",
        dest="verbosity",
        default=0,
        help="Set verbosity.",
    )
    parser.add_argument("-b", "--api-base", help="What API base url to use.")
    parser.add_argument("-k", "--api-key", help="What API key to use.")
    parser.add_argument("-p", "--proxy", nargs="+", help="What proxy to use.")
    parser.add_argument(
        "-o",
        "--organization",
        help="Which organization to run as (will use your default organization if not specified)",
    )
    parser.add_argument(
        "-t",
        "--api-type",
        type=str,
        choices=("openai", "azure"),
        help="The backend API to call, must be `openai` or `azure`",
    )
    parser.add_argument(
        "--api-version",
        # The example must be a version string; the URL documents the valid values.
        help="The Azure API version, e.g. '2023-05-15'; see 'https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning'",
    )

    # azure
    parser.add_argument(
        "--azure-endpoint",
        help="The Azure endpoint, e.g. 'https://endpoint.openai.azure.com'",
    )
    parser.add_argument(
        "--azure-ad-token",
        help="A token from Azure Active Directory, https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id",
    )

    # prints the package version
    parser.add_argument(
        "-V",
        "--version",
        action="version",
        version="%(prog)s " + __version__,
    )

    # Named so it does not shadow the `help` builtin.
    def _print_help() -> None:
        parser.print_help()

    # Fallback handler when no subcommand is given: `_main` calls `parsed.func()`.
    parser.set_defaults(func=_print_help)

    subparsers = parser.add_subparsers()
    sub_api = subparsers.add_parser("api", help="Direct API calls")

    register_commands(sub_api)

    sub_tools = subparsers.add_parser("tools", help="Client side tools for convenience")
    _tools.register_commands(sub_tools, subparsers)

    return parser
def main() -> int:
    """Run the CLI and translate failures into a process exit status.

    API errors, CLI errors and pydantic validation errors are rendered with
    ``display_error``. A keyboard interrupt just ends the current stderr line.

    Returns:
        0 on success, 1 if any of the handled errors (or Ctrl-C) occurred.
    """
    exit_code = 0
    try:
        _main()
    except (APIError, CLIError, pydantic.ValidationError) as err:
        display_error(err)
        exit_code = 1
    except KeyboardInterrupt:
        # Leave the terminal on a fresh line after ^C.
        sys.stderr.write("\n")
        exit_code = 1
    return exit_code
def _parse_args(parser: argparse.ArgumentParser) -> tuple[argparse.Namespace, Arguments, list[str]]:
    """Parse ``sys.argv`` into the raw namespace, global options and passthrough args.

    Everything from the first ``--`` onward (including the ``--`` itself) is
    kept aside, because argparse would otherwise swallow the separator. That
    tail is appended to whatever argparse could not recognise.

    Returns:
        A ``(namespace, Arguments, unknown_args)`` tuple.
    """
    argv = sys.argv
    try:
        cut = argv.index("--")
    except ValueError:
        cut = len(argv)

    known_args = argv[1:cut]
    passthrough = argv[cut:]

    parsed, leftovers = parser.parse_known_args(known_args)
    leftovers += passthrough

    args = model_parse(Arguments, vars(parsed))
    if not args.allow_unknown_args:
        # Second, strict parse: argparse errors out on anything it does not
        # recognise when the selected command does not accept extra arguments.
        parser.parse_args()

    return parsed, args, leftovers
def _build_proxy_mounts(proxies: Optional[List[str]]) -> dict[str, httpx.BaseTransport]:
    """Map ``--proxy`` URLs to httpx transports keyed by the URL prefix they serve.

    A proxy whose URL starts with ``https`` is mounted for ``https://`` and any
    other proxy for ``http://``.

    Raises:
        CLIError: if two proxies resolve to the same prefix.
    """
    if not proxies:
        return {}

    # Resolve and validate every key before building any transport, so a
    # duplicate is rejected without leaving half-built transports behind.
    keyed: list[tuple[str, str]] = []
    seen: set[str] = set()
    for proxy in proxies:
        key = "https://" if proxy.startswith("https") else "http://"
        if key in seen:
            raise CLIError(f"Multiple {key} proxies given - only the last one would be used")
        seen.add(key)
        keyed.append((key, proxy))

    return {key: httpx.HTTPTransport(proxy=httpx.Proxy(httpx.URL(proxy))) for key, proxy in keyed}


def _apply_global_options(args: Arguments, http_client: httpx.Client) -> None:
    """Copy the global CLI options onto the ``openai`` module-level settings."""
    openai.http_client = http_client

    if args.organization:
        openai.organization = args.organization

    if args.api_key:
        openai.api_key = args.api_key

    if args.api_base:
        openai.base_url = args.api_base

    # azure
    if args.api_type is not None:
        openai.api_type = args.api_type

    if args.azure_endpoint is not None:
        openai.azure_endpoint = args.azure_endpoint

    if args.api_version is not None:
        openai.api_version = args.api_version

    if args.azure_ad_token is not None:
        openai.azure_ad_token = args.azure_ad_token


def _main() -> None:
    """Parse the command line, configure the ``openai`` module and run the command.

    The steps are:

    1. Build the HTTP client, honouring ``--proxy``.
    2. Apply the global options to the ``openai`` module.
    3. Invoke the selected subcommand's ``func``. If the subcommand declared an
       ``args_model``, ``func`` is called with that model parsed from the
       namespace plus the unknown/passthrough arguments.

    The HTTP client is always closed afterwards.

    Raises:
        CLIError: for conflicting ``--proxy`` values. Errors raised by the
        subcommand itself are propagated unchanged.
    """
    parser = _build_parser()
    parsed, args, unknown = _parse_args(parser)

    if args.verbosity != 0:
        sys.stderr.write("Warning: --verbosity isn't supported yet\n")

    proxies = _build_proxy_mounts(args.proxy)
    http_client = httpx.Client(
        mounts=proxies or None,
        http2=can_use_http2(),
    )

    try:
        _apply_global_options(args, http_client)

        if args.args_model:
            # we omit None values so that they can be defaulted to `NotGiven`
            # and we'll strip it from the API request
            command_args = {key: value for key, value in vars(parsed).items() if value is not None}
            command_args["unknown_args"] = unknown
            parsed.func(model_parse(args.args_model, command_args))
        else:
            parsed.func()
    finally:
        try:
            http_client.close()
        except Exception:
            # Closing is best-effort: it must not mask the command's own result
            # or exception. Keep a trace instead of discarding it silently.
            logger.debug("Failed to close the HTTP client", exc_info=True)
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())